aws_lambda_router/
middleware.rs

1use async_trait::async_trait;
2use lambda_runtime::Error;
3use crate::{Request, Response};
4
5/// Next function type for middleware chain
6pub type Next = Box<dyn Fn(Request) -> futures::future::BoxFuture<'static, Result<Response, Error>> + Send + Sync>;
7
8/// Middleware trait
9#[async_trait]
10pub trait Middleware: Send + Sync {
11    /// Execute middleware
12    async fn handle(&self, req: Request, next: Next) -> Result<Response, Error>;
13}
14
/// Adapter that turns a closure into a [`Middleware`].
///
/// The wrapped closure receives the [`Request`] and the [`Next`] continuation
/// and must return a boxed future of `Result<Response, Error>`.
///
/// The trait bounds on `F` are declared on the `impl` blocks that need them,
/// not on the struct itself. This lets generic code name `MiddlewareFn<F>`
/// without repeating the full closure signature.
pub struct MiddlewareFn<F> {
    /// The user-supplied middleware closure.
    func: F,
}

23impl<F> MiddlewareFn<F>
24where
25    F: Fn(Request, Next) -> futures::future::BoxFuture<'static, Result<Response, Error>> + Send + Sync,
26{
27    pub fn new(func: F) -> Self {
28        Self { func }
29    }
30}
31
32#[async_trait]
33impl<F> Middleware for MiddlewareFn<F>
34where
35    F: Fn(Request, Next) -> futures::future::BoxFuture<'static, Result<Response, Error>> + Send + Sync,
36{
37    async fn handle(&self, req: Request, next: Next) -> Result<Response, Error> {
38        (self.func)(req, next).await
39    }
40}
41
/// Middleware that prints the request line to stdout on the way in
/// (`→ METHOD PATH`) and the status code on the way out (`← STATUS`).
///
/// The outbound line is printed only when the downstream handler returns a
/// response; errors propagate without it.
///
/// It is a stateless unit struct, so it is cheap to copy and can be built via
/// [`Default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct LoggingMiddleware;

45#[async_trait]
46impl Middleware for LoggingMiddleware {
47    async fn handle(&self, req: Request, next: Next) -> Result<Response, Error> {
48        println!("→ {} {}", req.method, req.path);
49        let response = next(req).await?;
50        println!("← {}", response.status_code);
51        Ok(response)
52    }
53}
54
/// Middleware that answers CORS preflight requests and attaches CORS headers
/// to responses.
///
/// Each field holds the literal header value that is emitted. Configure the
/// fields through the builder methods on `CorsMiddleware`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorsMiddleware {
    /// Value for `Access-Control-Allow-Origin`.
    allow_origin: String,
    /// Value for `Access-Control-Allow-Methods`.
    allow_methods: String,
    /// Value for `Access-Control-Allow-Headers`.
    allow_headers: String,
    /// Value for `Access-Control-Max-Age`: preflight cache lifetime in seconds.
    max_age: String,
}

63impl CorsMiddleware {
64    pub fn new() -> Self {
65        Self {
66            allow_origin: "*".to_string(),
67            allow_methods: "GET, POST, PUT, DELETE, OPTIONS".to_string(),
68            allow_headers: "Content-Type, Authorization".to_string(),
69            max_age: "3600".to_string(),
70        }
71    }
72    
73    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
74        self.allow_origin = origin.into();
75        self
76    }
77    
78    pub fn allow_methods(mut self, methods: impl Into<String>) -> Self {
79        self.allow_methods = methods.into();
80        self
81    }
82    
83    pub fn allow_headers(mut self, headers: impl Into<String>) -> Self {
84        self.allow_headers = headers.into();
85        self
86    }
87}
88
89impl Default for CorsMiddleware {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95#[async_trait]
96impl Middleware for CorsMiddleware {
97    async fn handle(&self, req: Request, next: Next) -> Result<Response, Error> {
98        // Handle preflight
99        if req.is_preflight() {
100            return Ok(Response::cors_preflight());
101        }
102        
103        // Add CORS headers to response
104        let mut response = next(req).await?;
105        response = response
106            .header("Access-Control-Allow-Origin", &self.allow_origin)
107            .header("Access-Control-Allow-Methods", &self.allow_methods)
108            .header("Access-Control-Allow-Headers", &self.allow_headers)
109            .header("Access-Control-Max-Age", &self.max_age);
110        
111        Ok(response)
112    }
113}