avx_http/
middleware.rs

//! Middleware system for the HTTP server.
//!
//! Middleware lets you process requests before they reach a handler
//! and modify responses before they are sent to the client.
5
6use crate::server::{Request, Response};
7use crate::error::Result;
8use async_trait::async_trait;
9use http::{HeaderName, HeaderValue, StatusCode};
10use std::sync::Arc;
11
12/// Middleware trait for processing HTTP requests and responses
13#[async_trait]
14pub trait Middleware: Send + Sync {
15    /// Process the request and response
16    ///
17    /// The middleware can:
18    /// - Inspect or modify the request
19    /// - Short-circuit by returning early
20    /// - Call `next.run(request).await` to continue the chain
21    async fn handle(&self, request: Request, next: Next) -> Result<Response>;
22}
23
24/// Represents the next middleware/handler in the chain
25pub struct Next {
26    middleware: Vec<Arc<dyn Middleware>>,
27    index: usize,
28    handler: Option<Arc<dyn Handler>>,
29}
30
31impl Next {
32    /// Create a new middleware chain
33    pub fn new(middleware: Vec<Arc<dyn Middleware>>, handler: Arc<dyn Handler>) -> Self {
34        Self {
35            middleware,
36            index: 0,
37            handler: Some(handler),
38        }
39    }
40
41    /// Run the next middleware or handler
42    pub async fn run(mut self, request: Request) -> Result<Response> {
43        if self.index < self.middleware.len() {
44            let middleware = self.middleware[self.index].clone();
45            self.index += 1;
46            middleware.handle(request, self).await
47        } else if let Some(handler) = &self.handler {
48            handler.handle(request).await
49        } else {
50            Ok(Response::text("Not Found").with_status(StatusCode::NOT_FOUND))
51        }
52    }
53}
54
55/// Handler trait for final request processing
56#[async_trait]
57pub trait Handler: Send + Sync {
58    /// Handle the request
59    async fn handle(&self, request: Request) -> Result<Response>;
60}
61
62/// Logger middleware - logs all requests
63pub struct Logger;
64
65impl Logger {
66    /// Create a new logger middleware
67    pub fn new() -> Self {
68        Self
69    }
70}
71
72#[async_trait]
73impl Middleware for Logger {
74    async fn handle(&self, request: Request, next: Next) -> Result<Response> {
75        let method = request.method.clone();
76        let path = request.path.clone();
77        let start = std::time::Instant::now();
78
79        println!("[REQUEST] {} {}", method, path);
80
81        let response = next.run(request).await?;
82
83        let duration = start.elapsed();
84        println!(
85            "[RESPONSE] {} {} - {} ({:?})",
86            method, path, response.status(), duration
87        );
88
89        Ok(response)
90    }
91}
92
93impl Default for Logger {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99/// CORS middleware - adds CORS headers
100pub struct Cors {
101    allow_origin: String,
102    allow_methods: Vec<String>,
103    allow_headers: Vec<String>,
104    max_age: u32,
105}
106
107impl Cors {
108    /// Create a new CORS middleware with permissive settings
109    pub fn permissive() -> Self {
110        Self {
111            allow_origin: "*".to_string(),
112            allow_methods: vec![
113                "GET".to_string(),
114                "POST".to_string(),
115                "PUT".to_string(),
116                "DELETE".to_string(),
117                "PATCH".to_string(),
118                "OPTIONS".to_string(),
119            ],
120            allow_headers: vec![
121                "Content-Type".to_string(),
122                "Authorization".to_string(),
123                "X-Requested-With".to_string(),
124            ],
125            max_age: 86400,
126        }
127    }
128
129    /// Create a new CORS middleware with custom origin
130    pub fn new(allow_origin: impl Into<String>) -> Self {
131        Self {
132            allow_origin: allow_origin.into(),
133            allow_methods: vec!["GET".to_string(), "POST".to_string()],
134            allow_headers: vec!["Content-Type".to_string()],
135            max_age: 3600,
136        }
137    }
138
139    /// Set allowed methods
140    pub fn allow_methods(mut self, methods: Vec<String>) -> Self {
141        self.allow_methods = methods;
142        self
143    }
144
145    /// Set allowed headers
146    pub fn allow_headers(mut self, headers: Vec<String>) -> Self {
147        self.allow_headers = headers;
148        self
149    }
150
151    /// Set max age
152    pub fn max_age(mut self, seconds: u32) -> Self {
153        self.max_age = seconds;
154        self
155    }
156}
157
158#[async_trait]
159impl Middleware for Cors {
160    async fn handle(&self, request: Request, next: Next) -> Result<Response> {
161        // Handle preflight OPTIONS request
162        if request.method.as_str() == "OPTIONS" {
163            let response = Response::text("")
164                .with_header(
165                    HeaderName::from_static("access-control-allow-origin"),
166                    HeaderValue::from_str(&self.allow_origin).unwrap(),
167                )
168                .with_header(
169                    HeaderName::from_static("access-control-allow-methods"),
170                    HeaderValue::from_str(&self.allow_methods.join(", ")).unwrap(),
171                )
172                .with_header(
173                    HeaderName::from_static("access-control-allow-headers"),
174                    HeaderValue::from_str(&self.allow_headers.join(", ")).unwrap(),
175                )
176                .with_header(
177                    HeaderName::from_static("access-control-max-age"),
178                    HeaderValue::from_str(&self.max_age.to_string()).unwrap(),
179                )
180                .with_status(StatusCode::NO_CONTENT);
181            return Ok(response);
182        }
183
184        // Add CORS headers to response
185        let response = next.run(request).await?;
186        let response = response.with_header(
187            HeaderName::from_static("access-control-allow-origin"),
188            HeaderValue::from_str(&self.allow_origin).unwrap(),
189        );
190
191        Ok(response)
192    }
193}
194
195/// Rate limiting middleware
196#[allow(dead_code)]
197pub struct RateLimit {
198    max_requests: usize,
199    window_secs: u64,
200    // In a real implementation, this would use a proper storage backend
201}
202
203impl RateLimit {
204    /// Create a new rate limiter
205    pub fn new(max_requests: usize, window_secs: u64) -> Self {
206        Self {
207            max_requests,
208            window_secs,
209        }
210    }
211}
212
213#[async_trait]
214impl Middleware for RateLimit {
215    async fn handle(&self, request: Request, next: Next) -> Result<Response> {
216        // TODO: Implement actual rate limiting with storage
217        // For now, just pass through
218        next.run(request).await
219    }
220}
221
222/// Authentication middleware
223pub struct Auth {
224    token: String,
225}
226
227impl Auth {
228    /// Create a new auth middleware with bearer token
229    pub fn bearer(token: impl Into<String>) -> Self {
230        Self {
231            token: token.into(),
232        }
233    }
234}
235
236#[async_trait]
237impl Middleware for Auth {
238    async fn handle(&self, request: Request, next: Next) -> Result<Response> {
239        // Check for Authorization header
240        if let Some(auth_header) = request.headers.get("authorization") {
241            if let Ok(auth_str) = auth_header.to_str() {
242                if auth_str.starts_with("Bearer ") {
243                    let token = &auth_str[7..];
244                    if token == self.token {
245                        return next.run(request).await;
246                    }
247                }
248            }
249        }
250
251        Ok(Response::text("Unauthorized")
252            .with_status(StatusCode::UNAUTHORIZED)
253            .with_header(
254                HeaderName::from_static("www-authenticate"),
255                HeaderValue::from_static("Bearer"),
256            ))
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;

    /// Terminal handler that always answers with a plain "Hello" body.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::text("Hello"))
        }
    }

    /// Builds a request with the given method, path and headers and an empty body.
    fn build_request(method: Method, path: &str, headers: http::HeaderMap) -> Request {
        Request {
            method,
            path: path.to_string(),
            headers,
            body: bytes::Bytes::new(),
        }
    }

    /// A chain with no middleware that ends directly in `TestHandler`.
    fn handler_only() -> Next {
        Next::new(Vec::new(), Arc::new(TestHandler))
    }

    #[tokio::test]
    async fn test_logger_middleware() {
        let subject = Logger::new();
        let req = build_request(Method::GET, "/test", http::HeaderMap::new());

        let res = subject.handle(req, handler_only()).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cors_middleware() {
        let subject = Cors::permissive();
        let req = build_request(Method::GET, "/test", http::HeaderMap::new());

        let res = subject.handle(req, handler_only()).await.unwrap();

        assert!(res.headers().contains_key("access-control-allow-origin"));
    }

    #[tokio::test]
    async fn test_cors_preflight() {
        let subject = Cors::permissive();
        let req = build_request(Method::OPTIONS, "/test", http::HeaderMap::new());

        let res = subject.handle(req, handler_only()).await.unwrap();

        assert_eq!(res.status(), StatusCode::NO_CONTENT);
        assert!(res.headers().contains_key("access-control-allow-methods"));
    }

    #[tokio::test]
    async fn test_auth_middleware_success() {
        let subject = Auth::bearer("secret-token");
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            http::HeaderValue::from_static("Bearer secret-token"),
        );
        let req = build_request(Method::GET, "/protected", headers);

        let res = subject.handle(req, handler_only()).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_failure() {
        let subject = Auth::bearer("secret-token");
        let req = build_request(Method::GET, "/protected", http::HeaderMap::new());

        let res = subject.handle(req, handler_only()).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }
}