Skip to main content

kora_lib/rpc_server/
recaptcha.rs

1use crate::{
2    constant::X_RECAPTCHA_TOKEN,
3    rpc_server::{
4        middleware_utils::{extract_parts_and_body_bytes, get_jsonrpc_method},
5        recaptcha_util::RecaptchaConfig,
6    },
7};
8use http::{Request, Response};
9use jsonrpsee::server::logger::Body;
10
11#[derive(Clone)]
12pub struct RecaptchaLayer {
13    config: RecaptchaConfig,
14}
15
16impl RecaptchaLayer {
17    pub fn new(config: RecaptchaConfig) -> Self {
18        Self { config }
19    }
20}
21
22impl<S> tower::Layer<S> for RecaptchaLayer {
23    type Service = RecaptchaService<S>;
24
25    fn layer(&self, inner: S) -> Self::Service {
26        RecaptchaService { inner, config: self.config.clone() }
27    }
28}
29
30#[derive(Clone)]
31pub struct RecaptchaService<S> {
32    inner: S,
33    config: RecaptchaConfig,
34}
35
36impl<S> tower::Service<Request<Body>> for RecaptchaService<S>
37where
38    S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
39    S::Future: Send + 'static,
40{
41    type Response = S::Response;
42    type Error = S::Error;
43    type Future = std::pin::Pin<
44        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
45    >;
46
47    fn poll_ready(
48        &mut self,
49        cx: &mut std::task::Context<'_>,
50    ) -> std::task::Poll<Result<(), Self::Error>> {
51        self.inner.poll_ready(cx)
52    }
53
54    fn call(&mut self, request: Request<Body>) -> Self::Future {
55        let config = self.config.clone();
56        let mut inner = self.inner.clone();
57
58        Box::pin(async move {
59            let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;
60
61            if let Some(method) = get_jsonrpc_method(&body_bytes) {
62                if !config.is_protected_method(&method) {
63                    let new_request = Request::from_parts(parts, Body::from(body_bytes));
64                    return inner.call(new_request).await;
65                }
66            }
67
68            let new_request = Request::from_parts(parts, Body::from(body_bytes));
69            let recaptcha_token =
70                new_request.headers().get(X_RECAPTCHA_TOKEN).and_then(|v| v.to_str().ok());
71
72            if let Err(resp) = config.validate(recaptcha_token).await {
73                return Ok(resp);
74            }
75
76            inner.call(new_request).await
77        })
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, StatusCode};
    use std::{
        future::Ready,
        task::{Context, Poll},
    };
    use tower::{Layer, Service, ServiceExt};

    /// JSON-RPC body for a method the test config leaves unprotected.
    const GET_CONFIG: &str = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
    /// JSON-RPC body for `signTransaction`, protected by the test config.
    const SIGN_TRANSACTION: &str = r#"{"jsonrpc":"2.0","method":"signTransaction","id":1}"#;
    /// JSON-RPC body for `signAndSendTransaction`, protected by the test config.
    const SIGN_AND_SEND_TRANSACTION: &str =
        r#"{"jsonrpc":"2.0","method":"signAndSendTransaction","id":1}"#;

    /// Inner service stub that answers every request with an empty `200 OK`.
    #[derive(Clone)]
    struct MockService;

    impl tower::Service<Request<Body>> for MockService {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<Body>) -> Self::Future {
            let ok = Response::builder().status(200).body(Body::empty()).unwrap();
            std::future::ready(Ok(ok))
        }
    }

    /// Config with a dummy secret, `0.5` as the second argument, and the two signing methods
    /// listed as protected.
    fn test_recaptcha_config() -> RecaptchaConfig {
        RecaptchaConfig::new(
            "test-recaptcha-secret".to_string(),
            0.5,
            vec!["signTransaction".to_string(), "signAndSendTransaction".to_string()],
        )
    }

    /// Sends `body` as a POST to `/` through a freshly layered mock service and returns the
    /// response status.
    ///
    /// When `token` is `Some`, it is attached under the reCAPTCHA token header.
    async fn status_for(body: &'static str, token: Option<&str>) -> StatusCode {
        let mut service = RecaptchaLayer::new(test_recaptcha_config()).layer(MockService);

        let mut builder = Request::builder().method(Method::POST).uri("/");
        if let Some(value) = token {
            builder = builder.header(X_RECAPTCHA_TOKEN, value);
        }
        let request = builder.body(Body::from(body)).unwrap();

        let response = service.ready().await.unwrap().call(request).await.unwrap();
        response.status()
    }

    #[tokio::test]
    async fn test_recaptcha_layer_bypasses_unprotected_method() {
        assert_eq!(status_for(GET_CONFIG, None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_recaptcha_layer_rejects_protected_method_missing_token() {
        assert_eq!(status_for(SIGN_TRANSACTION, None).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_recaptcha_layer_rejects_sign_and_send_missing_token() {
        assert_eq!(status_for(SIGN_AND_SEND_TRANSACTION, None).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_recaptcha_layer_rejects_empty_token() {
        assert_eq!(status_for(SIGN_TRANSACTION, Some("")).await, StatusCode::UNAUTHORIZED);
    }
}
171
172        let response = service.ready().await.unwrap().call(request).await.unwrap();
173        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
174    }
175}