richat_shared/jsonrpc/
requests.rs

1use {
2    crate::jsonrpc::{
3        helpers::{
4            get_x_bigtable_disabled, get_x_subscription_id, response_200, response_400,
5            response_500, to_vec, RpcResponse,
6        },
7        metrics::{
8            RPC_REQUESTS_DURATION_SECONDS, RPC_REQUESTS_GENERATED_BYTES_TOTAL, RPC_REQUESTS_TOTAL,
9        },
10    },
11    futures::{
12        future::BoxFuture,
13        stream::{FuturesOrdered, StreamExt},
14    },
15    http_body_util::{BodyExt, Limited},
16    hyper::{
17        body::{Bytes, Incoming as BodyIncoming},
18        http::Result as HttpResult,
19        HeaderMap,
20    },
21    jsonrpsee_types::{error::ErrorCode, Request, Response, ResponsePayload, TwoPointZero},
22    metrics::{counter, histogram},
23    quanta::Instant,
24    richat_metrics::duration_to_seconds,
25    std::{collections::HashMap, fmt, sync::Arc},
26};
27
28pub type RpcRequestResult = anyhow::Result<Vec<u8>>;
29
30pub type RpcRequestHandler<S> =
31    Box<dyn Fn(S, Arc<str>, bool, Request<'_>) -> BoxFuture<'_, RpcRequestResult> + Send + Sync>;
32
33#[derive(Debug)]
34enum RpcRequests<'a> {
35    Single(Request<'a>),
36    Batch(Vec<Request<'a>>),
37}
38
39impl<'a> RpcRequests<'a> {
40    fn parse(bytes: &'a Bytes) -> serde_json::Result<Self> {
41        for i in 0..bytes.len() {
42            if bytes[i] == b'[' {
43                return serde_json::from_slice::<Vec<Request<'_>>>(bytes).map(Self::Batch);
44            } else if bytes[i] == b'{' {
45                break;
46            }
47        }
48        serde_json::from_slice::<Request<'_>>(bytes).map(Self::Single)
49    }
50}
51
52pub struct RpcRequestsProcessor<S> {
53    body_limit: usize,
54    state: S,
55    extra_headers: HeaderMap,
56    methods: HashMap<&'static str, RpcRequestHandler<S>>,
57}
58
59impl<S> fmt::Debug for RpcRequestsProcessor<S> {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        f.debug_struct("RpcRequestsProcessor").finish()
62    }
63}
64
65impl<S: Clone> RpcRequestsProcessor<S> {
66    pub fn new(body_limit: usize, state: S, extra_headers: HeaderMap) -> Self {
67        Self {
68            body_limit,
69            state,
70            extra_headers,
71            methods: HashMap::new(),
72        }
73    }
74
75    pub fn add_handler(
76        &mut self,
77        method: &'static str,
78        handler: RpcRequestHandler<S>,
79    ) -> &mut Self {
80        self.methods.insert(method, handler);
81        self
82    }
83
84    pub async fn on_request(&self, req: hyper::Request<BodyIncoming>) -> HttpResult<RpcResponse> {
85        let (parts, body) = req.into_parts();
86
87        let x_subscription_id = get_x_subscription_id(&parts.headers);
88        let upstream_disabled = get_x_bigtable_disabled(&parts.headers);
89
90        let bytes = match Limited::new(body, self.body_limit).collect().await {
91            Ok(body) => body.to_bytes(),
92            Err(error) => return response_400(error),
93        };
94        let requests = match RpcRequests::parse(&bytes) {
95            Ok(requests) => requests,
96            Err(error) => return response_400(error),
97        };
98
99        let mut buffer = match requests {
100            RpcRequests::Single(request) => {
101                match self
102                    .process(Arc::clone(&x_subscription_id), upstream_disabled, request)
103                    .await
104                {
105                    Ok(response) => response,
106                    Err(error) => return response_500(error),
107                }
108            }
109            RpcRequests::Batch(requests) => {
110                let mut futures = FuturesOrdered::new();
111                for request in requests {
112                    let x_subscription_id = Arc::clone(&x_subscription_id);
113                    futures.push_back(self.process(
114                        Arc::clone(&x_subscription_id),
115                        upstream_disabled,
116                        request,
117                    ));
118                }
119
120                let mut buffer = Vec::new();
121                buffer.push(b'[');
122                while let Some(result) = futures.next().await {
123                    match result {
124                        Ok(mut response) => {
125                            buffer.append(&mut response);
126                        }
127                        Err(error) => return response_500(error),
128                    }
129                    if !futures.is_empty() {
130                        buffer.push(b',');
131                    }
132                }
133                buffer.push(b']');
134                buffer
135            }
136        };
137        buffer.push(b'\n');
138        counter!(
139            RPC_REQUESTS_GENERATED_BYTES_TOTAL,
140            "x_subscription_id" => x_subscription_id,
141        )
142        .increment(buffer.len() as u64);
143        response_200(buffer, &self.extra_headers)
144    }
145
146    async fn process<'a>(
147        &'a self,
148        x_subscription_id: Arc<str>,
149        upstream_disabled: bool,
150        request: Request<'a>,
151    ) -> anyhow::Result<Vec<u8>> {
152        let Some((method, handle)) = self.methods.get_key_value(request.method.as_ref()) else {
153            return Ok(to_vec(&Response {
154                jsonrpc: Some(TwoPointZero),
155                payload: ResponsePayload::<()>::error(ErrorCode::MethodNotFound),
156                id: request.id.into_owned(),
157            }));
158        };
159
160        let ts = Instant::now();
161        let result = handle(
162            self.state.clone(),
163            Arc::clone(&x_subscription_id),
164            upstream_disabled,
165            request,
166        )
167        .await;
168        counter!(
169            RPC_REQUESTS_TOTAL,
170            "x_subscription_id" => Arc::clone(&x_subscription_id),
171            "method" => *method,
172        )
173        .increment(1);
174        histogram!(
175            RPC_REQUESTS_DURATION_SECONDS,
176            "x_subscription_id" => x_subscription_id,
177            "method" => *method,
178        )
179        .record(duration_to_seconds(ts.elapsed()));
180        result
181    }
182}