Skip to main content

tiny_proxy/api/
middleware.rs

1//! Middleware for API requests
2
3use bytes::Bytes;
4use http_body::Body;
5use http_body_util::Full;
6use hyper::{Request, Response, StatusCode};
7
8/// API authentication middleware
9pub async fn auth_middleware<B>(
10    req: Request<B>,
11    api_key: &str,
12) -> Result<Request<B>, Response<Full<Bytes>>>
13where
14    B: Body,
15{
16    let provided_key = req
17        .headers()
18        .get("X-API-Key")
19        .and_then(|h: &hyper::header::HeaderValue| h.to_str().ok());
20
21    match provided_key {
22        Some(key) if key == api_key => Ok(req),
23        Some(_) => Err(unauthorized_response("Invalid API key")),
24        None => Err(unauthorized_response("Missing API key")),
25    }
26}
27
28fn unauthorized_response(message: &str) -> Response<Full<Bytes>> {
29    let body = format!(r#"{{"error": "Unauthorized", "message": "{}"}}"#, message);
30
31    Response::builder()
32        .status(StatusCode::UNAUTHORIZED)
33        .header("Content-Type", "application/json")
34        .body(Full::new(Bytes::from(body)))
35        .unwrap()
36}
37
38/// Logging middleware
39pub fn logging_middleware<B: Body>(req: Request<B>) -> Request<B> {
40    let method = req.method();
41    let path = req.uri().path();
42    let client_ip = req
43        .headers()
44        .get("X-Real-IP")
45        .or_else(|| req.headers().get("X-Forwarded-For"))
46        .and_then(|h: &hyper::header::HeaderValue| h.to_str().ok())
47        .unwrap_or("unknown");
48
49    tracing::info!("API request from {}: {} {}", client_ip, method, path);
50
51    req
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::{BodyExt, Empty};

    const API_KEY: &str = "secret-key-123";

    /// Builds a bodiless request carrying a single header.
    fn request_with_header(name: &str, value: &str) -> Request<Empty<Bytes>> {
        Request::builder()
            .header(name, value)
            .body(Empty::new())
            .unwrap()
    }

    /// Collects a response body into a UTF-8 string.
    async fn body_string(response: Response<Full<Bytes>>) -> String {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_auth_middleware_valid_key() {
        let req = request_with_header("X-API-Key", API_KEY);

        let Ok(passed) = auth_middleware(req, API_KEY).await else {
            panic!("a matching API key must be accepted");
        };

        // The request must come back intact so downstream handlers still see it.
        let forwarded_key = passed
            .headers()
            .get("X-API-Key")
            .and_then(|h| h.to_str().ok());
        assert_eq!(forwarded_key, Some(API_KEY));
    }

    #[tokio::test]
    async fn test_auth_middleware_invalid_key() {
        let req = request_with_header("X-API-Key", "wrong-key");

        let Err(response) = auth_middleware(req, API_KEY).await else {
            panic!("a non-matching API key must be rejected");
        };

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            body_string(response).await,
            r#"{"error": "Unauthorized", "message": "Invalid API key"}"#
        );
    }

    #[tokio::test]
    async fn test_auth_middleware_missing_key() {
        let req: Request<Empty<Bytes>> = Request::builder().body(Empty::new()).unwrap();

        let Err(response) = auth_middleware(req, API_KEY).await else {
            panic!("a request without an API key must be rejected");
        };

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            body_string(response).await,
            r#"{"error": "Unauthorized", "message": "Missing API key"}"#
        );
    }

    #[test]
    fn test_logging_middleware() {
        let req = Request::builder()
            .method("POST")
            .uri("/api/config")
            .header("X-Real-IP", "192.168.1.1")
            .body(Empty::<Bytes>::new())
            .unwrap();

        let logged = logging_middleware(req);

        // Logging must be side-effect free with respect to the request itself.
        assert_eq!(*logged.method(), hyper::Method::POST);
        assert_eq!(logged.uri().path(), "/api/config");
        assert_eq!(
            logged
                .headers()
                .get("X-Real-IP")
                .and_then(|h| h.to_str().ok()),
            Some("192.168.1.1")
        );
    }

    #[tokio::test]
    async fn test_unauthorized_response() {
        let response = unauthorized_response("Test error");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

        let content_type = response
            .headers()
            .get("Content-Type")
            .and_then(|h| h.to_str().ok());
        assert_eq!(content_type, Some("application/json"));

        assert_eq!(
            body_string(response).await,
            r#"{"error": "Unauthorized", "message": "Test error"}"#
        );
    }
}