elif_http/
middleware.rs

1//! # Middleware
2//!
//! Basic middleware system for processing requests and responses.
//! Provides an async middleware trait and a pipeline for composing middleware in sequence.
5
6pub mod logging;
7pub mod timing;
8
9use std::future::Future;
10use std::pin::Pin;
11use axum::{
12    response::{Response, IntoResponse},
13    extract::Request,
14};
15
16use crate::{HttpResult, HttpError};
17
18/// Type alias for async middleware function result
19pub type MiddlewareResult = HttpResult<Response>;
20
/// Heap-allocated, pinned, `Send` future that may borrow for `'fut` and resolves to `T`.
///
/// Middleware hooks return this instead of being `async fn`s so the trait can be
/// used as `dyn Middleware`.
pub type BoxFuture<'fut, T> = Pin<Box<dyn Future<Output = T> + Send + 'fut>>;
23
24/// Core middleware trait that can process requests before handlers
25/// and responses after handlers.
26pub trait Middleware: Send + Sync {
27    /// Process the request before it reaches the handler.
28    /// Can modify the request or return early response.
29    fn process_request<'a>(
30        &'a self, 
31        request: Request
32    ) -> BoxFuture<'a, Result<Request, Response>> {
33        Box::pin(async move { Ok(request) })
34    }
35    
36    /// Process the response after the handler processes it.
37    /// Can modify the response before returning to client.
38    fn process_response<'a>(
39        &'a self, 
40        response: Response
41    ) -> BoxFuture<'a, Response> {
42        Box::pin(async move { response })
43    }
44    
45    /// Optional middleware name for debugging
46    fn name(&self) -> &'static str {
47        "Middleware"
48    }
49}
50
51/// Middleware pipeline that composes multiple middleware in sequence
52#[derive(Default)]
53pub struct MiddlewarePipeline {
54    middleware: Vec<Box<dyn Middleware>>,
55}
56
57impl MiddlewarePipeline {
58    /// Create a new empty middleware pipeline
59    pub fn new() -> Self {
60        Self {
61            middleware: Vec::new(),
62        }
63    }
64    
65    /// Add middleware to the pipeline
66    pub fn add<M: Middleware + 'static>(mut self, middleware: M) -> Self {
67        self.middleware.push(Box::new(middleware));
68        self
69    }
70    
71    /// Process request through all middleware in order
72    pub async fn process_request(&self, mut request: Request) -> Result<Request, Response> {
73        for middleware in &self.middleware {
74            match middleware.process_request(request).await {
75                Ok(req) => request = req,
76                Err(response) => return Err(response),
77            }
78        }
79        Ok(request)
80    }
81    
82    /// Process response through all middleware in reverse order
83    pub async fn process_response(&self, mut response: Response) -> Response {
84        // Process in reverse order - last middleware added processes response first
85        for middleware in self.middleware.iter().rev() {
86            response = middleware.process_response(response).await;
87        }
88        response
89    }
90    
91    /// Get number of middleware in pipeline
92    pub fn len(&self) -> usize {
93        self.middleware.len()
94    }
95    
96    /// Check if pipeline is empty
97    pub fn is_empty(&self) -> bool {
98        self.middleware.is_empty()
99    }
100    
101    /// Get middleware names for debugging
102    pub fn names(&self) -> Vec<&'static str> {
103        self.middleware.iter().map(|m| m.name()).collect()
104    }
105}
106
/// Wrapper around another [`Middleware`] that forwards every hook to it.
///
/// Despite its name, the wrapper currently adds no error handling of its own.
/// Early responses and processed responses from the inner middleware pass
/// through unchanged, and [`Middleware::name`] reports the inner middleware's name.
#[derive(Debug, Clone)]
pub struct ErrorHandlingMiddleware<M> {
    /// The wrapped middleware that all hooks are forwarded to.
    inner: M,
}
111
112impl<M> ErrorHandlingMiddleware<M> {
113    pub fn new(middleware: M) -> Self {
114        Self { inner: middleware }
115    }
116}
117
118impl<M: Middleware> Middleware for ErrorHandlingMiddleware<M> {
119    fn process_request<'a>(
120        &'a self, 
121        request: Request
122    ) -> BoxFuture<'a, Result<Request, Response>> {
123        Box::pin(async move {
124            // Delegate to inner middleware with error handling
125            match self.inner.process_request(request).await {
126                Ok(req) => Ok(req),
127                Err(response) => Err(response),
128            }
129        })
130    }
131    
132    fn process_response<'a>(
133        &'a self, 
134        response: Response
135    ) -> BoxFuture<'a, Response> {
136        Box::pin(async move {
137            self.inner.process_response(response).await
138        })
139    }
140    
141    fn name(&self) -> &'static str {
142        self.inner.name()
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, StatusCode};

    /// Middleware that stamps its name onto requests and responses, so tests
    /// can see which middleware ran last in each direction.
    struct TestMiddleware {
        name: &'static str,
    }

    impl TestMiddleware {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    impl Middleware for TestMiddleware {
        fn process_request<'a>(
            &'a self,
            mut request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                // Add a header to track middleware execution
                let headers = request.headers_mut();
                headers.insert("X-Middleware", self.name.parse().unwrap());
                Ok(request)
            })
        }

        fn process_response<'a>(
            &'a self,
            mut response: Response,
        ) -> BoxFuture<'a, Response> {
            Box::pin(async move {
                // Add response header
                let headers = response.headers_mut();
                headers.insert("X-Response-Middleware", self.name.parse().unwrap());
                response
            })
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Middleware that always short-circuits with `403 Forbidden`.
    struct RejectMiddleware;

    impl Middleware for RejectMiddleware {
        fn process_request<'a>(
            &'a self,
            _request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                Err(Response::builder()
                    .status(StatusCode::FORBIDDEN)
                    .body(axum::body::Body::empty())
                    .unwrap())
            })
        }

        fn name(&self) -> &'static str {
            "Reject"
        }
    }

    /// Builds a `GET /test` request with an empty body.
    fn test_request() -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Builds an empty `200 OK` response.
    fn test_response() -> Response {
        Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_middleware_pipeline() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("First"))
            .add(TestMiddleware::new("Second"));

        let processed_request = pipeline.process_request(test_request()).await.unwrap();

        // Should have header from last middleware (Second overwrites First)
        assert_eq!(
            processed_request.headers().get("X-Middleware").unwrap(),
            "Second"
        );

        let processed_response = pipeline.process_response(test_response()).await;

        // Should have header from first middleware (reverse order)
        assert_eq!(
            processed_response.headers().get("X-Response-Middleware").unwrap(),
            "First"
        );
    }

    // Nothing here is async, so a plain test avoids spinning up a runtime.
    #[test]
    fn test_pipeline_info() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("Test1"))
            .add(TestMiddleware::new("Test2"));

        assert_eq!(pipeline.len(), 2);
        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.names(), vec!["Test1", "Test2"]);
    }

    #[tokio::test]
    async fn test_empty_pipeline() {
        let pipeline = MiddlewarePipeline::new();
        assert!(pipeline.is_empty());

        let processed_request = pipeline.process_request(test_request()).await.unwrap();

        // Request should pass through unchanged
        assert_eq!(processed_request.method(), Method::GET);
        assert_eq!(processed_request.uri().path(), "/test");

        // Response should pass through unchanged
        let processed_response = pipeline.process_response(test_response()).await;
        assert_eq!(processed_response.status(), StatusCode::OK);
        assert!(processed_response.headers().get("X-Response-Middleware").is_none());
    }

    #[tokio::test]
    async fn test_request_short_circuits_on_early_response() {
        let pipeline = MiddlewarePipeline::new()
            .add(RejectMiddleware)
            .add(TestMiddleware::new("Unreached"));

        match pipeline.process_request(test_request()).await {
            Ok(_) => panic!("expected RejectMiddleware to short-circuit the pipeline"),
            Err(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        }
    }

    #[tokio::test]
    async fn test_error_handling_middleware_delegates() {
        let wrapped = ErrorHandlingMiddleware::new(TestMiddleware::new("Inner"));
        assert_eq!(wrapped.name(), "Inner");

        let processed_request = wrapped.process_request(test_request()).await.unwrap();
        assert_eq!(
            processed_request.headers().get("X-Middleware").unwrap(),
            "Inner"
        );

        let processed_response = wrapped.process_response(test_response()).await;
        assert_eq!(
            processed_response.headers().get("X-Response-Middleware").unwrap(),
            "Inner"
        );

        // Early responses from the inner middleware are forwarded unchanged.
        let rejecting = ErrorHandlingMiddleware::new(RejectMiddleware);
        match rejecting.process_request(test_request()).await {
            Ok(_) => panic!("expected the wrapped RejectMiddleware to short-circuit"),
            Err(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        }
    }
}
256}