1use {
2 crate::jsonrpc::{
3 helpers::{
4 get_x_bigtable_disabled, get_x_subscription_id, response_200, response_400,
5 response_500, to_vec, RpcResponse,
6 },
7 metrics::{
8 RPC_REQUESTS_DURATION_SECONDS, RPC_REQUESTS_GENERATED_BYTES_TOTAL, RPC_REQUESTS_TOTAL,
9 },
10 },
11 futures::{
12 future::BoxFuture,
13 stream::{FuturesOrdered, StreamExt},
14 },
15 http_body_util::{BodyExt, Limited},
16 hyper::{
17 body::{Bytes, Incoming as BodyIncoming},
18 http::Result as HttpResult,
19 HeaderMap,
20 },
21 jsonrpsee_types::{error::ErrorCode, Request, Response, ResponsePayload, TwoPointZero},
22 metrics::{counter, histogram},
23 quanta::Instant,
24 richat_metrics::duration_to_seconds,
25 std::{collections::HashMap, fmt, sync::Arc},
26};
27
/// Outcome of handling one JSON-RPC request: the bytes that are written into the
/// HTTP response body on success, or an error that aborts the whole HTTP request.
pub type RpcRequestResult = anyhow::Result<Vec<u8>>;
29
/// Boxed handler for a single JSON-RPC method.
///
/// It is called with, in order: a clone of the processor state, the subscription id
/// taken from the request headers, the `upstream_disabled` flag taken from the request
/// headers, and the parsed request. It returns a boxed future resolving to
/// [`RpcRequestResult`].
pub type RpcRequestHandler<S> =
    Box<dyn Fn(S, Arc<str>, bool, Request<'_>) -> BoxFuture<'_, RpcRequestResult> + Send + Sync>;
32
/// Parsed HTTP body: either one JSON-RPC request or a batch (JSON array) of requests.
#[derive(Debug)]
enum RpcRequests<'a> {
    /// Body was deserialized as a single request object.
    Single(Request<'a>),
    /// Body was deserialized as an array of request objects.
    Batch(Vec<Request<'a>>),
}
38
39impl<'a> RpcRequests<'a> {
40 fn parse(bytes: &'a Bytes) -> serde_json::Result<Self> {
41 for i in 0..bytes.len() {
42 if bytes[i] == b'[' {
43 return serde_json::from_slice::<Vec<Request<'_>>>(bytes).map(Self::Batch);
44 } else if bytes[i] == b'{' {
45 break;
46 }
47 }
48 serde_json::from_slice::<Request<'_>>(bytes).map(Self::Single)
49 }
50}
51
/// Dispatches JSON-RPC requests received over HTTP to handlers registered by method name.
pub struct RpcRequestsProcessor<S> {
    /// Limit handed to `Limited` when the request body is collected.
    body_limit: usize,
    /// State that is cloned and passed to a handler on every call.
    state: S,
    /// Headers passed to `response_200` when a reply is built.
    extra_headers: HeaderMap,
    /// Registered handlers keyed by JSON-RPC method name.
    methods: HashMap<&'static str, RpcRequestHandler<S>>,
}
58
59impl<S> fmt::Debug for RpcRequestsProcessor<S> {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.debug_struct("RpcRequestsProcessor").finish()
62 }
63}
64
65impl<S: Clone> RpcRequestsProcessor<S> {
66 pub fn new(body_limit: usize, state: S, extra_headers: HeaderMap) -> Self {
67 Self {
68 body_limit,
69 state,
70 extra_headers,
71 methods: HashMap::new(),
72 }
73 }
74
75 pub fn add_handler(
76 &mut self,
77 method: &'static str,
78 handler: RpcRequestHandler<S>,
79 ) -> &mut Self {
80 self.methods.insert(method, handler);
81 self
82 }
83
84 pub async fn on_request(&self, req: hyper::Request<BodyIncoming>) -> HttpResult<RpcResponse> {
85 let (parts, body) = req.into_parts();
86
87 let x_subscription_id = get_x_subscription_id(&parts.headers);
88 let upstream_disabled = get_x_bigtable_disabled(&parts.headers);
89
90 let bytes = match Limited::new(body, self.body_limit).collect().await {
91 Ok(body) => body.to_bytes(),
92 Err(error) => return response_400(error),
93 };
94 let requests = match RpcRequests::parse(&bytes) {
95 Ok(requests) => requests,
96 Err(error) => return response_400(error),
97 };
98
99 let mut buffer = match requests {
100 RpcRequests::Single(request) => {
101 match self
102 .process(Arc::clone(&x_subscription_id), upstream_disabled, request)
103 .await
104 {
105 Ok(response) => response,
106 Err(error) => return response_500(error),
107 }
108 }
109 RpcRequests::Batch(requests) => {
110 let mut futures = FuturesOrdered::new();
111 for request in requests {
112 let x_subscription_id = Arc::clone(&x_subscription_id);
113 futures.push_back(self.process(
114 Arc::clone(&x_subscription_id),
115 upstream_disabled,
116 request,
117 ));
118 }
119
120 let mut buffer = Vec::new();
121 buffer.push(b'[');
122 while let Some(result) = futures.next().await {
123 match result {
124 Ok(mut response) => {
125 buffer.append(&mut response);
126 }
127 Err(error) => return response_500(error),
128 }
129 if !futures.is_empty() {
130 buffer.push(b',');
131 }
132 }
133 buffer.push(b']');
134 buffer
135 }
136 };
137 buffer.push(b'\n');
138 counter!(
139 RPC_REQUESTS_GENERATED_BYTES_TOTAL,
140 "x_subscription_id" => x_subscription_id,
141 )
142 .increment(buffer.len() as u64);
143 response_200(buffer, &self.extra_headers)
144 }
145
146 async fn process<'a>(
147 &'a self,
148 x_subscription_id: Arc<str>,
149 upstream_disabled: bool,
150 request: Request<'a>,
151 ) -> anyhow::Result<Vec<u8>> {
152 let Some((method, handle)) = self.methods.get_key_value(request.method.as_ref()) else {
153 return Ok(to_vec(&Response {
154 jsonrpc: Some(TwoPointZero),
155 payload: ResponsePayload::<()>::error(ErrorCode::MethodNotFound),
156 id: request.id.into_owned(),
157 }));
158 };
159
160 let ts = Instant::now();
161 let result = handle(
162 self.state.clone(),
163 Arc::clone(&x_subscription_id),
164 upstream_disabled,
165 request,
166 )
167 .await;
168 counter!(
169 RPC_REQUESTS_TOTAL,
170 "x_subscription_id" => Arc::clone(&x_subscription_id),
171 "method" => *method,
172 )
173 .increment(1);
174 histogram!(
175 RPC_REQUESTS_DURATION_SECONDS,
176 "x_subscription_id" => x_subscription_id,
177 "method" => *method,
178 )
179 .record(duration_to_seconds(ts.elapsed()));
180 result
181 }
182}