// kora_lib/rpc_server/middleware_utils.rs

use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};

use futures_util::TryStreamExt;
use http::{Request, Response, StatusCode};
use jsonrpsee::server::logger::Body;

use crate::KoraError;
8
/// Default value for the signature-verification flag: verification is off.
///
/// NOTE(review): presumably referenced as a `#[serde(default = "...")]` hook
/// elsewhere in the crate; confirm at the call sites.
pub fn default_sig_verify() -> bool {
    bool::default()
}
12
13pub async fn extract_parts_and_body_bytes(
14    request: Request<Body>,
15) -> (http::request::Parts, Vec<u8>) {
16    let (parts, body) = request.into_parts();
17    let body_bytes = body
18        .try_fold(Vec::new(), |mut acc, chunk| async move {
19            acc.extend_from_slice(&chunk);
20            Ok(acc)
21        })
22        .await
23        .unwrap_or_default();
24    (parts, body_bytes)
25}
26
27pub fn get_jsonrpc_method(body_bytes: &[u8]) -> Option<String> {
28    match serde_json::from_slice::<serde_json::Value>(body_bytes) {
29        Ok(val) => val.get("method").and_then(|m| m.as_str()).map(|s| s.to_string()),
30        Err(_) => None,
31    }
32}
33
34pub fn verify_jsonrpc_method(
35    body_bytes: &[u8],
36    allowed_methods: &HashSet<String>,
37) -> Result<String, KoraError> {
38    let method = get_jsonrpc_method(body_bytes);
39    if let Some(method) = method {
40        if allowed_methods.contains(&method) {
41            return Ok(method);
42        }
43    }
44    Err(KoraError::InvalidRequest("Method not allowed".to_string()))
45}
46
47pub fn build_response_with_graceful_error(
48    headers: Option<HashMap<String, String>>,
49    status_code: StatusCode,
50    error_message: &str,
51) -> Response<Body> {
52    let mut builder = Response::builder();
53
54    if let Some(headers) = headers {
55        for (key, value) in headers.iter() {
56            builder = builder.header(key, value);
57        }
58    }
59
60    builder.status(status_code).body(Body::from(error_message.to_string())).unwrap_or_else(|e| {
61        log::error!("Failed to build response, error: {e:?}");
62        let mut response = Response::new(Body::empty());
63        *response.status_mut() = status_code;
64        response
65    })
66}
67
68/// Method validation layer - applies first in middleware stack to fail fast
69#[derive(Clone)]
70pub struct MethodValidationLayer {
71    allowed_methods: HashSet<String>,
72}
73
74impl MethodValidationLayer {
75    pub fn new(allowed_methods: Vec<String>) -> Self {
76        Self { allowed_methods: allowed_methods.into_iter().collect() }
77    }
78}
79
80#[derive(Clone)]
81pub struct MethodValidationService<S> {
82    inner: S,
83    allowed_methods: HashSet<String>,
84}
85
86impl<S> tower::Layer<S> for MethodValidationLayer {
87    type Service = MethodValidationService<S>;
88
89    fn layer(&self, inner: S) -> Self::Service {
90        MethodValidationService { inner, allowed_methods: self.allowed_methods.clone() }
91    }
92}
93
94impl<S> tower::Service<Request<Body>> for MethodValidationService<S>
95where
96    S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
97    S::Future: Send + 'static,
98{
99    type Response = S::Response;
100    type Error = S::Error;
101    type Future = std::pin::Pin<
102        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
103    >;
104
105    fn poll_ready(
106        &mut self,
107        cx: &mut std::task::Context<'_>,
108    ) -> std::task::Poll<Result<(), Self::Error>> {
109        self.inner.poll_ready(cx)
110    }
111
112    fn call(&mut self, request: Request<Body>) -> Self::Future {
113        let allowed_methods = self.allowed_methods.clone();
114        let mut inner = self.inner.clone();
115
116        Box::pin(async move {
117            let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;
118
119            match verify_jsonrpc_method(&body_bytes, &allowed_methods) {
120                Ok(_) => {}
121                Err(_) => {
122                    return Ok(build_response_with_graceful_error(
123                        None,
124                        StatusCode::METHOD_NOT_ALLOWED,
125                        "",
126                    ));
127                }
128            }
129
130            let new_body = Body::from(body_bytes);
131            let new_request = Request::from_parts(parts, new_body);
132            inner.call(new_request).await
133        })
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use std::{
        future::Ready,
        task::{Context, Poll},
    };
    use tower::{Layer, Service, ServiceExt};

    /// Allow-list shared by most tests.
    const DEFAULT_METHODS: &[&str] = &["liveness", "getConfig"];

    /// Mock inner service that always answers `200 OK` with an empty body.
    #[derive(Clone)]
    struct MockService;

    impl tower::Service<Request<Body>> for MockService {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<Body>) -> Self::Future {
            std::future::ready(Ok(Response::builder().status(200).body(Body::empty()).unwrap()))
        }
    }

    /// Mock inner service that answers `200 OK` and echoes the request body back,
    /// so tests can observe exactly what the middleware forwarded.
    #[derive(Clone)]
    struct EchoService;

    impl tower::Service<Request<Body>> for EchoService {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, request: Request<Body>) -> Self::Future {
            std::future::ready(Ok(Response::new(request.into_body())))
        }
    }

    /// Builds a `POST /test` request carrying `body`.
    fn post(body: impl Into<Body>) -> Request<Body> {
        Request::builder().method(Method::POST).uri("/test").body(body.into()).unwrap()
    }

    /// Wraps `inner` in a [`MethodValidationLayer`] that allows only `methods`.
    fn validated<S>(inner: S, methods: &[&str]) -> MethodValidationService<S> {
        MethodValidationLayer::new(methods.iter().map(|m| m.to_string()).collect()).layer(inner)
    }

    /// Sends `body` through `service` and returns the response status.
    async fn status_for(
        service: &mut MethodValidationService<MockService>,
        body: impl Into<Body>,
    ) -> StatusCode {
        service.ready().await.unwrap().call(post(body)).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_method_validation_disallowed_method() {
        let mut service = validated(MockService, DEFAULT_METHODS);
        let body = r#"{"jsonrpc":"2.0","method":"unknownMethod","id":1}"#;
        assert_eq!(status_for(&mut service, body).await, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_malformed_json() {
        let mut service = validated(MockService, DEFAULT_METHODS);
        let body = r#"{"invalid json"#;
        assert_eq!(status_for(&mut service, body).await, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_missing_method_field() {
        let mut service = validated(MockService, DEFAULT_METHODS);
        let body = r#"{"jsonrpc":"2.0","id":1}"#;
        assert_eq!(status_for(&mut service, body).await, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_multiple_allowed_methods() {
        let methods = ["liveness", "getConfig", "signTransaction", "estimateTransactionFee"];
        // One service instance handles every request, exercising repeated calls.
        let mut service = validated(MockService, &methods);

        for method in methods {
            let body = format!(r#"{{"jsonrpc":"2.0","method":"{method}","id":1}}"#);
            let status = status_for(&mut service, body).await;
            assert_eq!(status, StatusCode::OK, "Method {method} should be allowed");
        }
    }

    #[tokio::test]
    async fn test_method_validation_rejects_batch_requests() {
        // A batch is a JSON array with no top-level `method`, so it fails closed
        // even when every inner call would be allowed.
        let mut service = validated(MockService, DEFAULT_METHODS);
        let body = r#"[{"jsonrpc":"2.0","method":"liveness","id":1}]"#;
        assert_eq!(status_for(&mut service, body).await, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn test_method_validation_forwards_body_unchanged() {
        let mut service = validated(EchoService, DEFAULT_METHODS);
        let body = r#"{"jsonrpc":"2.0","method":"getConfig","id":7}"#;

        let response = service.ready().await.unwrap().call(post(body)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // Reuse the module's body collector by re-wrapping the echoed body.
        let (_, echoed) = extract_parts_and_body_bytes(Request::new(response.into_body())).await;
        assert_eq!(echoed, body.as_bytes());
    }
}