1use {
2 crate::jsonrpc::{
3 helpers::{
4 get_x_bigtable_disabled, get_x_subscription_id, response_200, response_400,
5 response_500, to_vec, RpcResponse,
6 },
7 metrics::{
8 RPC_REQUESTS_DURATION_SECONDS, RPC_REQUESTS_GENERATED_BYTES_TOTAL, RPC_REQUESTS_TOTAL,
9 },
10 },
11 futures::{
12 future::BoxFuture,
13 stream::{FuturesOrdered, StreamExt},
14 },
15 http_body_util::{BodyExt, Limited},
16 hyper::{
17 body::{Bytes, Incoming as BodyIncoming},
18 http::Result as HttpResult,
19 },
20 jsonrpsee_types::{error::ErrorCode, Request, Response, ResponsePayload, TwoPointZero},
21 metrics::{counter, histogram},
22 quanta::Instant,
23 richat_metrics::duration_to_seconds,
24 std::{collections::HashMap, fmt, sync::Arc},
25};
26
27pub type RpcRequestResult = anyhow::Result<Vec<u8>>;
28
29pub type RpcRequestHandler<S> =
30 Box<dyn Fn(S, Arc<str>, bool, Request<'_>) -> BoxFuture<'_, RpcRequestResult> + Send + Sync>;
31
32#[derive(Debug)]
33enum RpcRequests<'a> {
34 Single(Request<'a>),
35 Batch(Vec<Request<'a>>),
36}
37
38impl<'a> RpcRequests<'a> {
39 fn parse(bytes: &'a Bytes) -> serde_json::Result<Self> {
40 for i in 0..bytes.len() {
41 if bytes[i] == b'[' {
42 return serde_json::from_slice::<Vec<Request<'_>>>(bytes).map(Self::Batch);
43 } else if bytes[i] == b'{' {
44 break;
45 }
46 }
47 serde_json::from_slice::<Request<'_>>(bytes).map(Self::Single)
48 }
49}
50
51pub struct RpcRequestsProcessor<S> {
52 body_limit: usize,
53 state: S,
54 methods: HashMap<&'static str, RpcRequestHandler<S>>,
55}
56
57impl<S> fmt::Debug for RpcRequestsProcessor<S> {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 f.debug_struct("RpcRequestsProcessor").finish()
60 }
61}
62
63impl<S: Clone> RpcRequestsProcessor<S> {
64 pub fn new(body_limit: usize, state: S) -> Self {
65 Self {
66 body_limit,
67 state,
68 methods: HashMap::new(),
69 }
70 }
71
72 pub fn add_handler(
73 &mut self,
74 method: &'static str,
75 handler: RpcRequestHandler<S>,
76 ) -> &mut Self {
77 self.methods.insert(method, handler);
78 self
79 }
80
81 pub async fn on_request(&self, req: hyper::Request<BodyIncoming>) -> HttpResult<RpcResponse> {
82 let (parts, body) = req.into_parts();
83
84 let x_subscription_id = get_x_subscription_id(&parts.headers);
85 let upstream_disabled = get_x_bigtable_disabled(&parts.headers);
86
87 let bytes = match Limited::new(body, self.body_limit).collect().await {
88 Ok(body) => body.to_bytes(),
89 Err(error) => return response_400(error),
90 };
91 let requests = match RpcRequests::parse(&bytes) {
92 Ok(requests) => requests,
93 Err(error) => return response_400(error),
94 };
95
96 let mut buffer = match requests {
97 RpcRequests::Single(request) => {
98 match self
99 .process(Arc::clone(&x_subscription_id), upstream_disabled, request)
100 .await
101 {
102 Ok(response) => response,
103 Err(error) => return response_500(error),
104 }
105 }
106 RpcRequests::Batch(requests) => {
107 let mut futures = FuturesOrdered::new();
108 for request in requests {
109 let x_subscription_id = Arc::clone(&x_subscription_id);
110 futures.push_back(self.process(
111 Arc::clone(&x_subscription_id),
112 upstream_disabled,
113 request,
114 ));
115 }
116
117 let mut buffer = Vec::new();
118 buffer.push(b'[');
119 while let Some(result) = futures.next().await {
120 match result {
121 Ok(mut response) => {
122 buffer.append(&mut response);
123 }
124 Err(error) => return response_500(error),
125 }
126 if !futures.is_empty() {
127 buffer.push(b',');
128 }
129 }
130 buffer.push(b']');
131 buffer
132 }
133 };
134 buffer.push(b'\n');
135 counter!(
136 RPC_REQUESTS_GENERATED_BYTES_TOTAL,
137 "x_subscription_id" => x_subscription_id,
138 )
139 .increment(buffer.len() as u64);
140 response_200(buffer)
141 }
142
143 async fn process<'a>(
144 &'a self,
145 x_subscription_id: Arc<str>,
146 upstream_disabled: bool,
147 request: Request<'a>,
148 ) -> anyhow::Result<Vec<u8>> {
149 let Some((method, handle)) = self.methods.get_key_value(request.method.as_ref()) else {
150 return Ok(to_vec(&Response {
151 jsonrpc: Some(TwoPointZero),
152 payload: ResponsePayload::<()>::error(ErrorCode::MethodNotFound),
153 id: request.id.into_owned(),
154 }));
155 };
156
157 let ts = Instant::now();
158 let result = handle(
159 self.state.clone(),
160 Arc::clone(&x_subscription_id),
161 upstream_disabled,
162 request,
163 )
164 .await;
165 counter!(
166 RPC_REQUESTS_TOTAL,
167 "x_subscription_id" => Arc::clone(&x_subscription_id),
168 "method" => *method,
169 )
170 .increment(1);
171 histogram!(
172 RPC_REQUESTS_DURATION_SECONDS,
173 "x_subscription_id" => x_subscription_id,
174 "method" => *method,
175 )
176 .record(duration_to_seconds(ts.elapsed()));
177 result
178 }
179}