use {
    crate::{
        jsonrpc::{
            helpers::{
                get_x_bigtable_disabled, get_x_subscription_id, response_200, response_400,
                response_500, RpcResponse,
            },
            metrics::{
                RPC_REQUESTS_DURATION_SECONDS, RPC_REQUESTS_GENERATED_BYTES_TOTAL,
                RPC_REQUESTS_TOTAL,
            },
        },
        metrics::duration_to_seconds,
    },
    futures::{
        future::BoxFuture,
        stream::{FuturesOrdered, StreamExt},
    },
    http_body_util::{BodyExt, Limited},
    hyper::{
        body::{Bytes, Incoming as BodyIncoming},
        http::Result as HttpResult,
    },
    jsonrpsee_types::{error::ErrorCode, Id, Request, Response, ResponsePayload, TwoPointZero},
    metrics::{counter, histogram},
    quanta::Instant,
    std::{collections::HashMap, fmt, sync::Arc},
};
29
30pub type RpcRequestResult<'a> = anyhow::Result<Response<'a, serde_json::Value>>;
31
32pub type RpcRequestHandler<S> = Box<
33 dyn Fn(S, Arc<str>, bool, Request<'_>) -> BoxFuture<'_, RpcRequestResult<'_>> + Send + Sync,
34>;
35
36#[derive(Debug)]
37enum RpcRequests<'a> {
38 Single(Request<'a>),
39 Batch(Vec<Request<'a>>),
40}
41
42impl<'a> RpcRequests<'a> {
43 fn parse(bytes: &'a Bytes) -> serde_json::Result<Self> {
44 for i in 0..bytes.len() {
45 if bytes[i] == b'[' {
46 return serde_json::from_slice::<Vec<Request<'_>>>(bytes).map(Self::Batch);
47 } else if bytes[i] == b'{' {
48 break;
49 }
50 }
51 serde_json::from_slice::<Request<'_>>(bytes).map(Self::Single)
52 }
53}
54
55pub struct RpcRequestsProcessor<S> {
56 body_limit: usize,
57 state: S,
58 methods: HashMap<&'static str, RpcRequestHandler<S>>,
59}
60
61impl<S> fmt::Debug for RpcRequestsProcessor<S> {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 f.debug_struct("RpcRequestsProcessor").finish()
64 }
65}
66
67impl<S: Clone> RpcRequestsProcessor<S> {
68 pub fn new(body_limit: usize, state: S) -> Self {
69 Self {
70 body_limit,
71 state,
72 methods: HashMap::new(),
73 }
74 }
75
76 pub fn add_handler(
77 &mut self,
78 method: &'static str,
79 handler: RpcRequestHandler<S>,
80 ) -> &mut Self {
81 self.methods.insert(method, handler);
82 self
83 }
84
85 pub async fn on_request(&self, req: hyper::Request<BodyIncoming>) -> HttpResult<RpcResponse> {
86 let (parts, body) = req.into_parts();
87
88 let x_subscription_id = get_x_subscription_id(&parts.headers);
89 let upstream_disabled = get_x_bigtable_disabled(&parts.headers);
90
91 let bytes = match Limited::new(body, self.body_limit).collect().await {
92 Ok(body) => body.to_bytes(),
93 Err(error) => return response_400(error),
94 };
95 let requests = match RpcRequests::parse(&bytes) {
96 Ok(requests) => requests,
97 Err(error) => return response_400(error),
98 };
99
100 let mut buffer = match requests {
101 RpcRequests::Single(request) => {
102 match self
103 .process(Arc::clone(&x_subscription_id), upstream_disabled, request)
104 .await
105 {
106 Ok(response) => {
107 serde_json::to_vec(&response).expect("json serialization never fail")
108 }
109 Err(error) => return response_500(error),
110 }
111 }
112 RpcRequests::Batch(requests) => {
113 let mut futures = FuturesOrdered::new();
114 for request in requests {
115 let x_subscription_id = Arc::clone(&x_subscription_id);
116 futures.push_back(self.process(
117 Arc::clone(&x_subscription_id),
118 upstream_disabled,
119 request,
120 ));
121 }
122
123 let mut buffer = Vec::new();
124 buffer.push(b'[');
125 while let Some(result) = futures.next().await {
126 match result {
127 Ok(response) => serde_json::to_writer(&mut buffer, &response)
128 .expect("json serialization never fail"),
129 Err(error) => return response_500(error),
130 }
131 if !futures.is_empty() {
132 buffer.push(b',');
133 }
134 }
135 buffer.push(b']');
136 buffer
137 }
138 };
139 buffer.push(b'\n');
140 counter!(
141 RPC_REQUESTS_GENERATED_BYTES_TOTAL,
142 "x_subscription_id" => x_subscription_id,
143 )
144 .increment(buffer.len() as u64);
145 response_200(buffer)
146 }
147
148 async fn process<'a>(
149 &'a self,
150 x_subscription_id: Arc<str>,
151 upstream_disabled: bool,
152 request: Request<'a>,
153 ) -> anyhow::Result<Response<'a, serde_json::Value>> {
154 let Some((method, handle)) = self.methods.get_key_value(request.method.as_ref()) else {
155 return Ok(Response {
156 jsonrpc: Some(TwoPointZero),
157 payload: ResponsePayload::error(ErrorCode::MethodNotFound),
158 id: request.id.into_owned(),
159 });
160 };
161
162 let ts = Instant::now();
163 let result = handle(
164 self.state.clone(),
165 Arc::clone(&x_subscription_id),
166 upstream_disabled,
167 request,
168 )
169 .await;
170 counter!(
171 RPC_REQUESTS_TOTAL,
172 "x_subscription_id" => Arc::clone(&x_subscription_id),
173 "method" => *method,
174 )
175 .increment(1);
176 histogram!(
177 RPC_REQUESTS_DURATION_SECONDS,
178 "x_subscription_id" => x_subscription_id,
179 "method" => *method,
180 )
181 .record(duration_to_seconds(ts.elapsed()));
182 result
183 }
184}