// kora_lib/rpc_server/auth.rs

1use crate::{
2    constant::{X_API_KEY, X_HMAC_SIGNATURE, X_TIMESTAMP},
3    rpc_server::middleware_utils::{
4        build_response_with_graceful_error, extract_parts_and_body_bytes, get_jsonrpc_method,
5    },
6};
7use hmac::{Hmac, Mac};
8use http::{Request, Response, StatusCode};
9use jsonrpsee::server::logger::Body;
10use sha2::Sha256;
11use subtle::ConstantTimeEq;
12
/// Tower layer that wraps services with API-key authentication.
///
/// The key is held as plain text, which is why this type intentionally does not derive `Debug`.
#[derive(Clone)]
pub struct ApiKeyAuthLayer {
    /// Key that incoming requests must present.
    api_key: String,
}

impl ApiKeyAuthLayer {
    /// Creates a layer that admits only requests presenting `api_key`.
    pub fn new(api_key: String) -> Self {
        ApiKeyAuthLayer { api_key }
    }
}
23
24#[derive(Clone)]
25pub struct ApiKeyAuthService<S> {
26    inner: S,
27    api_key: String,
28}
29
30impl<S> tower::Layer<S> for ApiKeyAuthLayer {
31    type Service = ApiKeyAuthService<S>;
32    fn layer(&self, inner: S) -> Self::Service {
33        ApiKeyAuthService { inner, api_key: self.api_key.clone() }
34    }
35}
36
37impl<S> tower::Service<Request<Body>> for ApiKeyAuthService<S>
38where
39    S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
40    S::Future: Send + 'static,
41{
42    type Response = S::Response;
43    type Error = S::Error;
44    type Future = std::pin::Pin<
45        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
46    >;
47
48    fn poll_ready(
49        &mut self,
50        cx: &mut std::task::Context<'_>,
51    ) -> std::task::Poll<Result<(), Self::Error>> {
52        self.inner.poll_ready(cx)
53    }
54
55    fn call(&mut self, request: Request<Body>) -> Self::Future {
56        let api_key = self.api_key.clone();
57        let mut inner = self.inner.clone();
58
59        Box::pin(async move {
60            let unauthorized_response =
61                build_response_with_graceful_error(None, StatusCode::UNAUTHORIZED, "");
62
63            let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;
64
65            // Bypass auth for liveness endpoint
66            if let Some(method) = get_jsonrpc_method(&body_bytes) {
67                if method == "liveness" {
68                    let new_body = Body::from(body_bytes);
69                    let new_request = Request::from_parts(parts, new_body);
70                    return inner.call(new_request).await;
71                }
72            }
73
74            // Check for API key header
75            let req = Request::from_parts(parts, Body::from(body_bytes));
76            if let Some(provided_key) = req.headers().get(X_API_KEY) {
77                // Constant-time comparison prevents timing attacks
78                if provided_key.as_bytes().ct_eq(api_key.as_bytes()).into() {
79                    return inner.call(req).await;
80                }
81            }
82
83            Ok(unauthorized_response)
84        })
85    }
86}
87
/// Tower layer that requires an HMAC-SHA256 request signature together with a fresh timestamp.
///
/// The secret is held as plain text, which is why this type intentionally does not derive `Debug`.
#[derive(Clone)]
pub struct HmacAuthLayer {
    /// Shared secret used as the HMAC key.
    secret: String,
    /// Largest accepted distance, in seconds, between the request timestamp and the server clock.
    max_timestamp_age: i64,
}

impl HmacAuthLayer {
    /// Creates a layer that verifies signatures with `secret` and rejects timestamps lying more
    /// than `max_timestamp_age` seconds away from the current time.
    pub fn new(secret: String, max_timestamp_age: i64) -> Self {
        HmacAuthLayer { max_timestamp_age, secret }
    }
}
99
100impl<S> tower::Layer<S> for HmacAuthLayer {
101    type Service = HmacAuthService<S>;
102
103    fn layer(&self, inner: S) -> Self::Service {
104        HmacAuthService {
105            inner,
106            secret: self.secret.clone(),
107            max_timestamp_age: self.max_timestamp_age,
108        }
109    }
110}
111
112#[derive(Clone)]
113pub struct HmacAuthService<S> {
114    inner: S,
115    secret: String,
116    max_timestamp_age: i64,
117}
118
119impl<S> tower::Service<Request<Body>> for HmacAuthService<S>
120where
121    S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
122    S::Future: Send + 'static,
123{
124    type Response = S::Response;
125    type Error = S::Error;
126    type Future = std::pin::Pin<
127        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
128    >;
129
130    fn poll_ready(
131        &mut self,
132        cx: &mut std::task::Context<'_>,
133    ) -> std::task::Poll<Result<(), Self::Error>> {
134        self.inner.poll_ready(cx)
135    }
136
137    fn call(&mut self, request: Request<Body>) -> Self::Future {
138        let secret = self.secret.clone();
139        let max_timestamp_age = self.max_timestamp_age;
140        let mut inner = self.inner.clone();
141
142        Box::pin(async move {
143            let unauthorized_response =
144                build_response_with_graceful_error(None, StatusCode::UNAUTHORIZED, "");
145
146            let signature_header = request.headers().get(X_HMAC_SIGNATURE).cloned();
147            let timestamp_header = request.headers().get(X_TIMESTAMP).cloned();
148
149            let (parts, body_bytes) = extract_parts_and_body_bytes(request).await;
150
151            // Bypass auth for liveness endpoint
152            if let Some(method) = get_jsonrpc_method(&body_bytes) {
153                if method == "liveness" {
154                    let new_body = Body::from(body_bytes);
155                    let new_request = Request::from_parts(parts, new_body);
156                    return inner.call(new_request).await;
157                }
158            }
159
160            let (signature, timestamp) =
161                match (signature_header.as_ref(), timestamp_header.as_ref()) {
162                    (Some(sig), Some(ts)) => (sig, ts),
163                    _ => return Ok(unauthorized_response),
164                };
165
166            let signature = signature.to_str().unwrap_or("");
167            let timestamp = timestamp.to_str().unwrap_or("");
168
169            // Verify timestamp is within allowed age
170            let ts = match timestamp.parse::<i64>() {
171                Ok(ts) => ts,
172                Err(_) => return Ok(unauthorized_response),
173            };
174            let now = std::time::SystemTime::now()
175                .duration_since(std::time::UNIX_EPOCH)
176                .map_err(|e| {
177                    log::error!("System time error: {e:?}");
178                    e
179                })
180                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
181                .as_secs() as i64;
182
183            if (now - ts).abs() > max_timestamp_age {
184                return Ok(unauthorized_response);
185            }
186
187            // Verify HMAC signature using timestamp + body
188            let body_str = match std::str::from_utf8(&body_bytes) {
189                Ok(s) => s,
190                Err(_) => {
191                    log::error!("HMAC authentication failed: invalid UTF-8 in request body");
192                    return Ok(unauthorized_response);
193                }
194            };
195            let message = format!("{}{}", timestamp, body_str);
196
197            let mut mac = match Hmac::<Sha256>::new_from_slice(secret.as_bytes()) {
198                Ok(mac) => mac,
199                Err(_) => {
200                    log::error!("HMAC authentication failed");
201                    return Ok(unauthorized_response);
202                }
203            };
204
205            mac.update(message.as_bytes());
206
207            let signature_bytes = match hex::decode(signature) {
208                Ok(bytes) => bytes,
209                Err(_) => {
210                    log::error!("HMAC signature hex decode failed");
211                    return Ok(unauthorized_response);
212                }
213            };
214
215            // Constant time comparison prevents timing attacks
216            if mac.verify_slice(&signature_bytes).is_err() {
217                return Ok(unauthorized_response);
218            }
219
220            // Reconstruct the request with the consumed body
221            let new_body = Body::from(body_bytes);
222            let new_request = Request::from_parts(parts, new_body);
223
224            inner.call(new_request).await
225        })
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constant::{DEFAULT_MAX_TIMESTAMP_AGE, X_API_KEY, X_HMAC_SIGNATURE, X_TIMESTAMP};
    use hmac::{Hmac, Mac};
    use http::Method;
    use jsonrpsee::server::logger::Body;
    use sha2::Sha256;
    use std::{
        future::Ready,
        task::{Context, Poll},
    };
    use tower::{Layer, Service, ServiceExt};

    const API_KEY: &str = "test-key";
    const SECRET: &str = "test-secret";
    const GET_CONFIG_BODY: &str = r#"{"jsonrpc":"2.0","method":"getConfig","id":1}"#;
    const OTHER_BODY: &str = r#"{"jsonrpc":"2.0","method":"getConfig","id":2}"#;
    const LIVENESS_BODY: &str = r#"{"jsonrpc":"2.0","method":"liveness","params":[],"id":1}"#;

    // Mock service that always returns OK
    #[derive(Clone)]
    struct MockService;

    impl tower::Service<Request<Body>> for MockService {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<Body>) -> Self::Future {
            std::future::ready(Ok(Response::builder().status(200).body(Body::empty()).unwrap()))
        }
    }

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs()
    }

    /// Hex-encoded HMAC-SHA256 of `timestamp` followed by `body`: the value the HMAC layer
    /// expects in the signature header.
    fn sign(secret: &str, timestamp: &str, body: &str) -> String {
        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        mac.update(format!("{timestamp}{body}").as_bytes());
        hex::encode(mac.finalize().into_bytes())
    }

    /// Starts a `POST /test` request; callers add headers and a body.
    fn post() -> http::request::Builder {
        Request::builder().method(Method::POST).uri("/test")
    }

    /// Request carrying both HMAC headers and the given body.
    fn hmac_request(timestamp: &str, signature: &str, body: &'static str) -> Request<Body> {
        post()
            .header(X_TIMESTAMP, timestamp)
            .header(X_HMAC_SIGNATURE, signature)
            .body(Body::from(body))
            .unwrap()
    }

    /// Sends `request` through an API-key service expecting `API_KEY`; returns the status.
    async fn api_key_status(request: Request<Body>) -> StatusCode {
        let mut service = ApiKeyAuthLayer::new(API_KEY.to_string()).layer(MockService);
        service.ready().await.unwrap().call(request).await.unwrap().status()
    }

    /// Sends `request` through an HMAC service keyed with `SECRET`; returns the status.
    async fn hmac_status(request: Request<Body>) -> StatusCode {
        let mut service =
            HmacAuthLayer::new(SECRET.to_string(), DEFAULT_MAX_TIMESTAMP_AGE).layer(MockService);
        service.ready().await.unwrap().call(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_api_key_auth_valid_key() {
        let request = post().header(X_API_KEY, API_KEY).body(Body::from(GET_CONFIG_BODY)).unwrap();
        assert_eq!(api_key_status(request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_api_key_auth_invalid_key() {
        let request =
            post().header(X_API_KEY, "wrong-key").body(Body::from(GET_CONFIG_BODY)).unwrap();
        assert_eq!(api_key_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_api_key_auth_prefix_of_key_rejected() {
        let request =
            post().header(X_API_KEY, "test-ke").body(Body::from(GET_CONFIG_BODY)).unwrap();
        assert_eq!(api_key_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_api_key_auth_missing_header() {
        let request = post().body(Body::from(GET_CONFIG_BODY)).unwrap();
        assert_eq!(api_key_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_api_key_auth_liveness_bypass() {
        let request = post().body(Body::from(LIVENESS_BODY)).unwrap();
        assert_eq!(api_key_status(request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_hmac_auth_valid_signature() {
        let timestamp = unix_now().to_string();
        let signature = sign(SECRET, &timestamp, GET_CONFIG_BODY);
        let request = hmac_request(&timestamp, &signature, GET_CONFIG_BODY);
        assert_eq!(hmac_status(request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_hmac_auth_invalid_signature() {
        let timestamp = unix_now().to_string();
        let request = hmac_request(&timestamp, "invalid-signature", GET_CONFIG_BODY);
        assert_eq!(hmac_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_hmac_auth_wrong_secret() {
        let timestamp = unix_now().to_string();
        let signature = sign("other-secret", &timestamp, GET_CONFIG_BODY);
        let request = hmac_request(&timestamp, &signature, GET_CONFIG_BODY);
        assert_eq!(hmac_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_hmac_auth_tampered_body() {
        let timestamp = unix_now().to_string();
        // Signature covers GET_CONFIG_BODY, but a different body is sent.
        let signature = sign(SECRET, &timestamp, GET_CONFIG_BODY);
        let request = hmac_request(&timestamp, &signature, OTHER_BODY);
        assert_eq!(hmac_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_hmac_auth_missing_headers() {
        let request = post().body(Body::from(GET_CONFIG_BODY)).unwrap();
        assert_eq!(hmac_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_hmac_auth_missing_signature_header() {
        let timestamp = unix_now().to_string();
        let request = post().header(X_TIMESTAMP, &timestamp).body(Body::from(GET_CONFIG_BODY)).unwrap();
        assert_eq!(hmac_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_hmac_auth_expired_timestamp() {
        // Timestamp from 10 minutes ago (expired); assumes DEFAULT_MAX_TIMESTAMP_AGE < 600.
        let timestamp = (unix_now() - 600).to_string();
        let signature = sign(SECRET, &timestamp, GET_CONFIG_BODY);
        let request = hmac_request(&timestamp, &signature, GET_CONFIG_BODY);
        assert_eq!(hmac_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_hmac_auth_future_timestamp() {
        // Timestamp 10 minutes ahead; the window applies in both directions.
        let timestamp = (unix_now() + 600).to_string();
        let signature = sign(SECRET, &timestamp, GET_CONFIG_BODY);
        let request = hmac_request(&timestamp, &signature, GET_CONFIG_BODY);
        assert_eq!(hmac_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_hmac_auth_malformed_timestamp() {
        let request = hmac_request("not-a-number", "some-signature", GET_CONFIG_BODY);
        assert_eq!(hmac_status(request).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_hmac_auth_liveness_bypass() {
        let request = post().body(Body::from(LIVENESS_BODY)).unwrap();
        assert_eq!(hmac_status(request).await, StatusCode::OK);
    }
}