1use {
2 crate::jsonrpc::{
3 helpers::{
4 RpcResponse, get_x_bigtable_disabled, get_x_subscription_id, response_200,
5 response_400, response_500, to_vec,
6 },
7 metrics::{
8 RPC_REQUESTS_DURATION_SECONDS, RPC_REQUESTS_GENERATED_BYTES_TOTAL, RPC_REQUESTS_TOTAL,
9 },
10 },
11 futures::{
12 future::BoxFuture,
13 stream::{FuturesOrdered, StreamExt},
14 },
15 http_body_util::{BodyExt, Limited},
16 hyper::{
17 HeaderMap,
18 body::{Bytes, Incoming as BodyIncoming},
19 http::Result as HttpResult,
20 },
21 jsonrpsee_types::{
22 Extensions, Request, Response, ResponsePayload, TwoPointZero, error::ErrorCode,
23 },
24 metrics::{counter, histogram},
25 quanta::Instant,
26 richat_metrics::duration_to_seconds,
27 std::{collections::HashMap, fmt, sync::Arc},
28};
29
30pub type RpcRequestResult = anyhow::Result<Vec<u8>>;
31
32pub type RpcRequestHandler<S> =
33 Box<dyn Fn(S, Arc<str>, bool, Request<'_>) -> BoxFuture<'_, RpcRequestResult> + Send + Sync>;
34
35#[derive(Debug)]
36enum RpcRequests<'a> {
37 Single(Request<'a>),
38 Batch(Vec<Request<'a>>),
39}
40
41impl<'a> RpcRequests<'a> {
42 fn parse(bytes: &'a Bytes) -> serde_json::Result<Self> {
43 for i in 0..bytes.len() {
44 if bytes[i] == b'[' {
45 return serde_json::from_slice::<Vec<Request<'_>>>(bytes).map(Self::Batch);
46 } else if bytes[i] == b'{' {
47 break;
48 }
49 }
50 serde_json::from_slice::<Request<'_>>(bytes).map(Self::Single)
51 }
52}
53
54pub struct RpcRequestsProcessor<S> {
55 body_limit: usize,
56 state: S,
57 extra_headers: HeaderMap,
58 methods: HashMap<&'static str, RpcRequestHandler<S>>,
59}
60
61impl<S> fmt::Debug for RpcRequestsProcessor<S> {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 f.debug_struct("RpcRequestsProcessor").finish()
64 }
65}
66
67impl<S: Clone> RpcRequestsProcessor<S> {
68 pub fn new(body_limit: usize, state: S, extra_headers: HeaderMap) -> Self {
69 Self {
70 body_limit,
71 state,
72 extra_headers,
73 methods: HashMap::new(),
74 }
75 }
76
77 pub fn add_handler(
78 &mut self,
79 method: &'static str,
80 handler: RpcRequestHandler<S>,
81 ) -> &mut Self {
82 self.methods.insert(method, handler);
83 self
84 }
85
86 pub async fn on_request(&self, req: hyper::Request<BodyIncoming>) -> HttpResult<RpcResponse> {
87 let (parts, body) = req.into_parts();
88
89 let x_subscription_id = get_x_subscription_id(&parts.headers);
90 let upstream_disabled = get_x_bigtable_disabled(&parts.headers);
91
92 let bytes = match Limited::new(body, self.body_limit).collect().await {
93 Ok(body) => body.to_bytes(),
94 Err(error) => return response_400(error),
95 };
96 let requests = match RpcRequests::parse(&bytes) {
97 Ok(requests) => requests,
98 Err(error) => return response_400(error),
99 };
100
101 let mut buffer = match requests {
102 RpcRequests::Single(request) => {
103 match self
104 .process(Arc::clone(&x_subscription_id), upstream_disabled, request)
105 .await
106 {
107 Ok(response) => response,
108 Err(error) => return response_500(error),
109 }
110 }
111 RpcRequests::Batch(requests) => {
112 let mut futures = FuturesOrdered::new();
113 for request in requests {
114 let x_subscription_id = Arc::clone(&x_subscription_id);
115 futures.push_back(self.process(
116 Arc::clone(&x_subscription_id),
117 upstream_disabled,
118 request,
119 ));
120 }
121
122 let mut buffer = Vec::new();
123 buffer.push(b'[');
124 while let Some(result) = futures.next().await {
125 match result {
126 Ok(mut response) => {
127 buffer.append(&mut response);
128 }
129 Err(error) => return response_500(error),
130 }
131 if !futures.is_empty() {
132 buffer.push(b',');
133 }
134 }
135 buffer.push(b']');
136 buffer
137 }
138 };
139 buffer.push(b'\n');
140 counter!(
141 RPC_REQUESTS_GENERATED_BYTES_TOTAL,
142 "x_subscription_id" => x_subscription_id,
143 )
144 .increment(buffer.len() as u64);
145 response_200(buffer, &self.extra_headers)
146 }
147
148 async fn process<'a>(
149 &'a self,
150 x_subscription_id: Arc<str>,
151 upstream_disabled: bool,
152 request: Request<'a>,
153 ) -> anyhow::Result<Vec<u8>> {
154 let Some((method, handle)) = self.methods.get_key_value(request.method.as_ref()) else {
155 return Ok(to_vec(&Response {
156 jsonrpc: Some(TwoPointZero),
157 payload: ResponsePayload::<()>::error(ErrorCode::MethodNotFound),
158 id: request.id.into_owned(),
159 extensions: Extensions::default(), }));
161 };
162
163 let ts = Instant::now();
164 let result = handle(
165 self.state.clone(),
166 Arc::clone(&x_subscription_id),
167 upstream_disabled,
168 request,
169 )
170 .await;
171 counter!(
172 RPC_REQUESTS_TOTAL,
173 "x_subscription_id" => Arc::clone(&x_subscription_id),
174 "method" => *method,
175 )
176 .increment(1);
177 histogram!(
178 RPC_REQUESTS_DURATION_SECONDS,
179 "x_subscription_id" => x_subscription_id,
180 "method" => *method,
181 )
182 .record(duration_to_seconds(ts.elapsed()));
183 result
184 }
185}