kora_lib/rpc_server/
middleware_utils.rs1use std::collections::{HashMap, HashSet};
2
3use futures_util::TryStreamExt;
4use http::{Request, Response, StatusCode};
5use jsonrpsee::server::logger::Body;
6
7use crate::KoraError;
8
/// Returns the default for the signature-verification flag, which is `false`
/// (verification off unless explicitly requested).
///
/// NOTE(review): presumably referenced as a serde `default = "..."` provider
/// elsewhere in the crate — confirm against callers.
pub fn default_sig_verify() -> bool {
    false
}
12
13pub async fn extract_parts_and_body_bytes(
14 request: Request<Body>,
15) -> (http::request::Parts, Vec<u8>) {
16 let (parts, body) = request.into_parts();
17 let body_bytes = body
18 .try_fold(Vec::new(), |mut acc, chunk| async move {
19 acc.extend_from_slice(&chunk);
20 Ok(acc)
21 })
22 .await
23 .unwrap_or_default();
24 (parts, body_bytes)
25}
26
27pub fn get_jsonrpc_method(body_bytes: &[u8]) -> Option<String> {
28 match serde_json::from_slice::<serde_json::Value>(body_bytes) {
29 Ok(val) => val.get("method").and_then(|m| m.as_str()).map(|s| s.to_string()),
30 Err(_) => None,
31 }
32}
33
34pub fn verify_jsonrpc_method(
35 body_bytes: &[u8],
36 allowed_methods: &HashSet<String>,
37) -> Result<String, KoraError> {
38 let method = get_jsonrpc_method(body_bytes);
39 if let Some(method) = method {
40 if allowed_methods.contains(&method) {
41 return Ok(method);
42 }
43 }
44 Err(KoraError::InvalidRequest("Method not allowed".to_string()))
45}
46
47pub fn build_response_with_graceful_error(
48 headers: Option<HashMap<String, String>>,
49 status_code: StatusCode,
50 error_message: &str,
51) -> Response<Body> {
52 let mut builder = Response::builder();
53
54 if let Some(headers) = headers {
55 for (key, value) in headers.iter() {
56 builder = builder.header(key, value);
57 }
58 }
59
60 builder.status(status_code).body(Body::from(error_message.to_string())).unwrap_or_else(|e| {
61 log::error!("Failed to build response, error: {e:?}");
62 let mut response = Response::new(Body::empty());
63 *response.status_mut() = status_code;
64 response
65 })
66}
67
/// Layer configuration holding the set of JSON-RPC method names that are
/// permitted to pass through.
#[derive(Clone)]
pub struct MethodValidationLayer {
    /// Allow-list of method names; duplicates are collapsed by the set.
    allowed_methods: HashSet<String>,
}

impl MethodValidationLayer {
    /// Creates a layer from a list of allowed method names.
    ///
    /// Repeated names in `allowed_methods` are stored only once.
    pub fn new(allowed_methods: Vec<String>) -> Self {
        let unique_methods: HashSet<String> = allowed_methods.into_iter().collect();
        Self { allowed_methods: unique_methods }
    }
}
79
/// Service wrapper that pairs an inner service with the set of method names
/// allowed to reach it.
#[derive(Clone)]
pub struct MethodValidationService<S> {
    /// The wrapped service that receives requests.
    inner: S,
    /// Allow-list of method names.
    allowed_methods: HashSet<String>,
}
85
86impl<S> tower::Layer<S> for MethodValidationLayer {
87 type Service = MethodValidationService<S>;
88
89 fn layer(&self, inner: S) -> Self::Service {
90 MethodValidationService { inner, allowed_methods: self.allowed_methods.clone() }
91 }
92}
93
94impl<S> tower::Service<Request<Body>> for MethodValidationService<S>
95where
96 S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
97 S::Future: Send + 'static,
98{
99 type Response = S::Response;
100 type Error = S::Error;
101 type Future = std::pin::Pin<
102 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
103 >;
104
105 fn poll_ready(
106 &mut self,
107 cx: &mut std::task::Context<'_>,
108 ) -> std::task::Poll<Result<(), Self::Error>> {
109 self.inner.poll_ready(cx)
110 }
111
112 fn call(&mut self, request: Request<Body>) -> Self::Future {
113 let allowed_methods = self.allowed_methods.clone();
114 let mut inner = self.inner.clone();
115
116 Box::pin(async move {
117 let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;
118
119 match verify_jsonrpc_method(&body_bytes, &allowed_methods) {
120 Ok(_) => {}
121 Err(_) => {
122 return Ok(build_response_with_graceful_error(
123 None,
124 StatusCode::METHOD_NOT_ALLOWED,
125 "",
126 ));
127 }
128 }
129
130 let new_body = Body::from(body_bytes);
131 let new_request = Request::from_parts(parts, new_body);
132 inner.call(new_request).await
133 })
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use std::{
        future::Ready,
        task::{Context, Poll},
    };
    use tower::{Layer, Service, ServiceExt};

    /// Inner service stub that is always ready and answers every request
    /// with an empty `200` response.
    #[derive(Clone)]
    struct MockService;

    impl tower::Service<Request<Body>> for MockService {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<Body>) -> Self::Future {
            std::future::ready(Ok(Response::builder().status(200).body(Body::empty()).unwrap()))
        }
    }

    /// Builds a validation service around `MockService` allowing `methods`.
    fn service_allowing(methods: &[&str]) -> MethodValidationService<MockService> {
        let allowed_methods: Vec<String> = methods.iter().map(|m| m.to_string()).collect();
        MethodValidationLayer::new(allowed_methods).layer(MockService)
    }

    /// Builds a `POST /test` request carrying `body`.
    fn post_request(body: String) -> Request<Body> {
        Request::builder().method(Method::POST).uri("/test").body(Body::from(body)).unwrap()
    }

    /// Sends `body` through `service` and returns the response status.
    async fn status_for(
        service: &mut MethodValidationService<MockService>,
        body: String,
    ) -> StatusCode {
        let request = post_request(body);
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        response.status()
    }

    #[tokio::test]
    async fn test_method_validation_disallowed_method() {
        let mut service = service_allowing(&["liveness", "getConfig"]);

        let body = r#"{"jsonrpc":"2.0","method":"unknownMethod","id":1}"#;

        let status = status_for(&mut service, body.to_string()).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_malformed_json() {
        let mut service = service_allowing(&["liveness", "getConfig"]);

        let body = r#"{"invalid json"#;

        let status = status_for(&mut service, body.to_string()).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_missing_method_field() {
        let mut service = service_allowing(&["liveness", "getConfig"]);

        let body = r#"{"jsonrpc":"2.0","id":1}"#;

        let status = status_for(&mut service, body.to_string()).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_multiple_allowed_methods() {
        let permitted = ["liveness", "getConfig", "signTransaction", "estimateTransactionFee"];
        let mut service = service_allowing(&permitted);

        for method in &permitted {
            let body = format!(r#"{{"jsonrpc":"2.0","method":"{}","id":1}}"#, method);

            let status = status_for(&mut service, body).await;
            assert_eq!(status, StatusCode::OK, "Method {} should be allowed", method);
        }
    }
}