Skip to main content

richat_shared/jsonrpc/
requests.rs

1use {
2    crate::jsonrpc::{
3        helpers::{
4            RpcResponse, get_x_bigtable_disabled, get_x_subscription_id, response_200,
5            response_400, response_500, to_vec,
6        },
7        metrics::{
8            RPC_REQUESTS_DURATION_SECONDS, RPC_REQUESTS_GENERATED_BYTES_TOTAL, RPC_REQUESTS_TOTAL,
9        },
10    },
11    futures::{
12        future::BoxFuture,
13        stream::{FuturesOrdered, StreamExt},
14    },
15    http_body_util::{BodyExt, Limited},
16    hyper::{
17        HeaderMap,
18        body::{Bytes, Incoming as BodyIncoming},
19        http::Result as HttpResult,
20    },
21    jsonrpsee_types::{
22        Extensions, Request, Response, ResponsePayload, TwoPointZero, error::ErrorCode,
23    },
24    metrics::{counter, histogram},
25    quanta::Instant,
26    richat_metrics::duration_to_seconds,
27    std::{collections::HashMap, fmt, sync::Arc},
28};
29
30pub type RpcRequestResult = anyhow::Result<Vec<u8>>;
31
32pub type RpcRequestHandler<S> =
33    Box<dyn Fn(S, Arc<str>, bool, Request<'_>) -> BoxFuture<'_, RpcRequestResult> + Send + Sync>;
34
35#[derive(Debug)]
36enum RpcRequests<'a> {
37    Single(Request<'a>),
38    Batch(Vec<Request<'a>>),
39}
40
41impl<'a> RpcRequests<'a> {
42    fn parse(bytes: &'a Bytes) -> serde_json::Result<Self> {
43        for i in 0..bytes.len() {
44            if bytes[i] == b'[' {
45                return serde_json::from_slice::<Vec<Request<'_>>>(bytes).map(Self::Batch);
46            } else if bytes[i] == b'{' {
47                break;
48            }
49        }
50        serde_json::from_slice::<Request<'_>>(bytes).map(Self::Single)
51    }
52}
53
/// Routes JSON-RPC requests received over HTTP to registered method handlers.
pub struct RpcRequestsProcessor<S> {
    /// Maximum request body size, used as the `Limited` body limit when the
    /// incoming body is collected.
    body_limit: usize,
    /// State cloned and passed to every handler invocation.
    state: S,
    /// Extra headers passed to `response_200` when building a successful
    /// response (presumably appended to it; see that helper).
    extra_headers: HeaderMap,
    /// Registered handlers keyed by JSON-RPC method name.
    methods: HashMap<&'static str, RpcRequestHandler<S>>,
}
60
61impl<S> fmt::Debug for RpcRequestsProcessor<S> {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        f.debug_struct("RpcRequestsProcessor").finish()
64    }
65}
66
67impl<S: Clone> RpcRequestsProcessor<S> {
68    pub fn new(body_limit: usize, state: S, extra_headers: HeaderMap) -> Self {
69        Self {
70            body_limit,
71            state,
72            extra_headers,
73            methods: HashMap::new(),
74        }
75    }
76
77    pub fn add_handler(
78        &mut self,
79        method: &'static str,
80        handler: RpcRequestHandler<S>,
81    ) -> &mut Self {
82        self.methods.insert(method, handler);
83        self
84    }
85
86    pub async fn on_request(&self, req: hyper::Request<BodyIncoming>) -> HttpResult<RpcResponse> {
87        let (parts, body) = req.into_parts();
88
89        let x_subscription_id = get_x_subscription_id(&parts.headers);
90        let upstream_disabled = get_x_bigtable_disabled(&parts.headers);
91
92        let bytes = match Limited::new(body, self.body_limit).collect().await {
93            Ok(body) => body.to_bytes(),
94            Err(error) => return response_400(error),
95        };
96        let requests = match RpcRequests::parse(&bytes) {
97            Ok(requests) => requests,
98            Err(error) => return response_400(error),
99        };
100
101        let mut buffer = match requests {
102            RpcRequests::Single(request) => {
103                match self
104                    .process(Arc::clone(&x_subscription_id), upstream_disabled, request)
105                    .await
106                {
107                    Ok(response) => response,
108                    Err(error) => return response_500(error),
109                }
110            }
111            RpcRequests::Batch(requests) => {
112                let mut futures = FuturesOrdered::new();
113                for request in requests {
114                    let x_subscription_id = Arc::clone(&x_subscription_id);
115                    futures.push_back(self.process(
116                        Arc::clone(&x_subscription_id),
117                        upstream_disabled,
118                        request,
119                    ));
120                }
121
122                let mut buffer = Vec::new();
123                buffer.push(b'[');
124                while let Some(result) = futures.next().await {
125                    match result {
126                        Ok(mut response) => {
127                            buffer.append(&mut response);
128                        }
129                        Err(error) => return response_500(error),
130                    }
131                    if !futures.is_empty() {
132                        buffer.push(b',');
133                    }
134                }
135                buffer.push(b']');
136                buffer
137            }
138        };
139        buffer.push(b'\n');
140        counter!(
141            RPC_REQUESTS_GENERATED_BYTES_TOTAL,
142            "x_subscription_id" => x_subscription_id,
143        )
144        .increment(buffer.len() as u64);
145        response_200(buffer, &self.extra_headers)
146    }
147
148    async fn process<'a>(
149        &'a self,
150        x_subscription_id: Arc<str>,
151        upstream_disabled: bool,
152        request: Request<'a>,
153    ) -> anyhow::Result<Vec<u8>> {
154        let Some((method, handle)) = self.methods.get_key_value(request.method.as_ref()) else {
155            return Ok(to_vec(&Response {
156                jsonrpc: Some(TwoPointZero),
157                payload: ResponsePayload::<()>::error(ErrorCode::MethodNotFound),
158                id: request.id.into_owned(),
159                extensions: Extensions::default(), // doesn't matter, as it is not used in serialize
160            }));
161        };
162
163        let ts = Instant::now();
164        let result = handle(
165            self.state.clone(),
166            Arc::clone(&x_subscription_id),
167            upstream_disabled,
168            request,
169        )
170        .await;
171        counter!(
172            RPC_REQUESTS_TOTAL,
173            "x_subscription_id" => Arc::clone(&x_subscription_id),
174            "method" => *method,
175        )
176        .increment(1);
177        histogram!(
178            RPC_REQUESTS_DURATION_SECONDS,
179            "x_subscription_id" => x_subscription_id,
180            "method" => *method,
181        )
182        .record(duration_to_seconds(ts.elapsed()));
183        result
184    }
185}