tiny_proxy/api/
middleware.rs1use bytes::Bytes;
4use http_body::Body;
5use http_body_util::Full;
6use hyper::{Request, Response, StatusCode};
7
8pub async fn auth_middleware<B>(
10 req: Request<B>,
11 api_key: &str,
12) -> Result<Request<B>, Response<Full<Bytes>>>
13where
14 B: Body,
15{
16 let provided_key = req
17 .headers()
18 .get("X-API-Key")
19 .and_then(|h: &hyper::header::HeaderValue| h.to_str().ok());
20
21 match provided_key {
22 Some(key) if key == api_key => Ok(req),
23 Some(_) => Err(unauthorized_response("Invalid API key")),
24 None => Err(unauthorized_response("Missing API key")),
25 }
26}
27
28fn unauthorized_response(message: &str) -> Response<Full<Bytes>> {
29 let body = format!(r#"{{"error": "Unauthorized", "message": "{}"}}"#, message);
30
31 Response::builder()
32 .status(StatusCode::UNAUTHORIZED)
33 .header("Content-Type", "application/json")
34 .body(Full::new(Bytes::from(body)))
35 .unwrap()
36}
37
38pub fn logging_middleware<B: Body>(req: Request<B>) -> Request<B> {
40 let method = req.method();
41 let path = req.uri().path();
42 let client_ip = req
43 .headers()
44 .get("X-Real-IP")
45 .or_else(|| req.headers().get("X-Forwarded-For"))
46 .and_then(|h: &hyper::header::HeaderValue| h.to_str().ok())
47 .unwrap_or("unknown");
48
49 tracing::info!("API request from {}: {} {}", client_ip, method, path);
50
51 req
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::{BodyExt, Empty};

    const API_KEY: &str = "secret-key-123";

    /// Builds a body-less request carrying the given `(name, value)` headers.
    fn request_with_headers(headers: &[(&str, &str)]) -> Request<Empty<Bytes>> {
        headers
            .iter()
            .fold(Request::builder(), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(Empty::new())
            .unwrap()
    }

    /// Collects a response body into a UTF-8 string.
    async fn body_string(response: Response<Full<Bytes>>) -> String {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_auth_middleware_valid_key() {
        let req = request_with_headers(&[("X-API-Key", API_KEY)]);
        let result = auth_middleware(req, API_KEY).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_auth_middleware_invalid_key() {
        let req = request_with_headers(&[("X-API-Key", "wrong-key")]);
        let response = auth_middleware(req, API_KEY)
            .await
            .err()
            .expect("a wrong key must be rejected");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(body_string(response).await.contains("Invalid API key"));
    }

    #[tokio::test]
    async fn test_auth_middleware_missing_key() {
        let req = request_with_headers(&[]);
        let response = auth_middleware(req, API_KEY)
            .await
            .err()
            .expect("a missing key must be rejected");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(body_string(response).await.contains("Missing API key"));
    }

    #[test]
    fn test_logging_middleware_returns_request_unchanged() {
        let req = request_with_headers(&[("X-Real-IP", "192.168.1.1")]);
        let logged = logging_middleware(req);

        let real_ip = logged
            .headers()
            .get("X-Real-IP")
            .and_then(|h| h.to_str().ok());
        assert_eq!(real_ip, Some("192.168.1.1"));
    }

    #[tokio::test]
    async fn test_unauthorized_response() {
        let response = unauthorized_response("Test error");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

        let content_type = response
            .headers()
            .get("Content-Type")
            .and_then(|h| h.to_str().ok());
        assert_eq!(content_type, Some("application/json"));

        assert_eq!(
            body_string(response).await,
            r#"{"error": "Unauthorized", "message": "Test error"}"#
        );
    }
}