richat_shared/jsonrpc/
requests.rs

use {
    crate::{
        jsonrpc::{
            helpers::{
                get_x_bigtable_disabled, get_x_subscription_id, response_200, response_400,
                response_500, RpcResponse,
            },
            metrics::{
                RPC_REQUESTS_DURATION_SECONDS, RPC_REQUESTS_GENERATED_BYTES_TOTAL,
                RPC_REQUESTS_TOTAL,
            },
        },
        metrics::duration_to_seconds,
    },
    futures::{
        future::BoxFuture,
        stream::{FuturesOrdered, StreamExt},
    },
    http_body_util::{BodyExt, Limited},
    hyper::{
        body::{Bytes, Incoming as BodyIncoming},
        http::Result as HttpResult,
    },
    jsonrpsee_types::{error::ErrorCode, Id, Request, Response, ResponsePayload, TwoPointZero},
    metrics::{counter, histogram},
    quanta::Instant,
    std::{collections::HashMap, fmt, sync::Arc},
};
29
30pub type RpcRequestResult<'a> = anyhow::Result<Response<'a, serde_json::Value>>;
31
32pub type RpcRequestHandler<S> = Box<
33    dyn Fn(S, Arc<str>, bool, Request<'_>) -> BoxFuture<'_, RpcRequestResult<'_>> + Send + Sync,
34>;
35
36#[derive(Debug)]
37enum RpcRequests<'a> {
38    Single(Request<'a>),
39    Batch(Vec<Request<'a>>),
40}
41
42impl<'a> RpcRequests<'a> {
43    fn parse(bytes: &'a Bytes) -> serde_json::Result<Self> {
44        for i in 0..bytes.len() {
45            if bytes[i] == b'[' {
46                return serde_json::from_slice::<Vec<Request<'_>>>(bytes).map(Self::Batch);
47            } else if bytes[i] == b'{' {
48                break;
49            }
50        }
51        serde_json::from_slice::<Request<'_>>(bytes).map(Self::Single)
52    }
53}
54
55pub struct RpcRequestsProcessor<S> {
56    body_limit: usize,
57    state: S,
58    methods: HashMap<&'static str, RpcRequestHandler<S>>,
59}
60
61impl<S> fmt::Debug for RpcRequestsProcessor<S> {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        f.debug_struct("RpcRequestsProcessor").finish()
64    }
65}
66
67impl<S: Clone> RpcRequestsProcessor<S> {
68    pub fn new(body_limit: usize, state: S) -> Self {
69        Self {
70            body_limit,
71            state,
72            methods: HashMap::new(),
73        }
74    }
75
76    pub fn add_handler(
77        &mut self,
78        method: &'static str,
79        handler: RpcRequestHandler<S>,
80    ) -> &mut Self {
81        self.methods.insert(method, handler);
82        self
83    }
84
85    pub async fn on_request(&self, req: hyper::Request<BodyIncoming>) -> HttpResult<RpcResponse> {
86        let (parts, body) = req.into_parts();
87
88        let x_subscription_id = get_x_subscription_id(&parts.headers);
89        let upstream_disabled = get_x_bigtable_disabled(&parts.headers);
90
91        let bytes = match Limited::new(body, self.body_limit).collect().await {
92            Ok(body) => body.to_bytes(),
93            Err(error) => return response_400(error),
94        };
95        let requests = match RpcRequests::parse(&bytes) {
96            Ok(requests) => requests,
97            Err(error) => return response_400(error),
98        };
99
100        let mut buffer = match requests {
101            RpcRequests::Single(request) => {
102                match self
103                    .process(Arc::clone(&x_subscription_id), upstream_disabled, request)
104                    .await
105                {
106                    Ok(response) => {
107                        serde_json::to_vec(&response).expect("json serialization never fail")
108                    }
109                    Err(error) => return response_500(error),
110                }
111            }
112            RpcRequests::Batch(requests) => {
113                let mut futures = FuturesOrdered::new();
114                for request in requests {
115                    let x_subscription_id = Arc::clone(&x_subscription_id);
116                    futures.push_back(self.process(
117                        Arc::clone(&x_subscription_id),
118                        upstream_disabled,
119                        request,
120                    ));
121                }
122
123                let mut buffer = Vec::new();
124                buffer.push(b'[');
125                while let Some(result) = futures.next().await {
126                    match result {
127                        Ok(response) => serde_json::to_writer(&mut buffer, &response)
128                            .expect("json serialization never fail"),
129                        Err(error) => return response_500(error),
130                    }
131                    if !futures.is_empty() {
132                        buffer.push(b',');
133                    }
134                }
135                buffer.push(b']');
136                buffer
137            }
138        };
139        buffer.push(b'\n');
140        counter!(
141            RPC_REQUESTS_GENERATED_BYTES_TOTAL,
142            "x_subscription_id" => x_subscription_id,
143        )
144        .increment(buffer.len() as u64);
145        response_200(buffer)
146    }
147
148    async fn process<'a>(
149        &'a self,
150        x_subscription_id: Arc<str>,
151        upstream_disabled: bool,
152        request: Request<'a>,
153    ) -> anyhow::Result<Response<'a, serde_json::Value>> {
154        let Some((method, handle)) = self.methods.get_key_value(request.method.as_ref()) else {
155            return Ok(Response {
156                jsonrpc: Some(TwoPointZero),
157                payload: ResponsePayload::error(ErrorCode::MethodNotFound),
158                id: request.id.into_owned(),
159            });
160        };
161
162        let ts = Instant::now();
163        let result = handle(
164            self.state.clone(),
165            Arc::clone(&x_subscription_id),
166            upstream_disabled,
167            request,
168        )
169        .await;
170        counter!(
171            RPC_REQUESTS_TOTAL,
172            "x_subscription_id" => Arc::clone(&x_subscription_id),
173            "method" => *method,
174        )
175        .increment(1);
176        histogram!(
177            RPC_REQUESTS_DURATION_SECONDS,
178            "x_subscription_id" => x_subscription_id,
179            "method" => *method,
180        )
181        .record(duration_to_seconds(ts.elapsed()));
182        result
183    }
184}