elif_http/
middleware.rs

1//! # Middleware
2//!
3//! Basic middleware system for processing requests and responses.
4//! Provides async middleware trait and pipeline composition.
5
6pub mod logging;
7pub mod timing;
8pub mod tracing;
9pub mod timeout;
10pub mod body_limit;
11
12use std::future::Future;
13use std::pin::Pin;
14use axum::{
15    response::{Response, IntoResponse},
16    extract::Request,
17};
18
19use crate::{HttpResult, HttpError};
20
/// Convenience alias for a middleware outcome expressed as `HttpResult<Response>`.
///
/// NOTE(review): the pipeline hooks in this file return `Result<Request, Response>`
/// / `Response` rather than this alias — confirm where (or whether) it is used.
pub type MiddlewareResult = HttpResult<Response>;
23
/// A pinned, heap-allocated, type-erased future that resolves to `Out`.
///
/// The future must be `Send` and may borrow data for the lifetime `'fut`.
/// Returning this from trait methods lets the trait be used as a trait object
/// (e.g. `Box<dyn Middleware>`).
pub type BoxFuture<'fut, Out> = Pin<Box<dyn Future<Output = Out> + 'fut + Send>>;
26
27/// Core middleware trait that can process requests before handlers
28/// and responses after handlers.
29pub trait Middleware: Send + Sync {
30    /// Process the request before it reaches the handler.
31    /// Can modify the request or return early response.
32    fn process_request<'a>(
33        &'a self, 
34        request: Request
35    ) -> BoxFuture<'a, Result<Request, Response>> {
36        Box::pin(async move { Ok(request) })
37    }
38    
39    /// Process the response after the handler processes it.
40    /// Can modify the response before returning to client.
41    fn process_response<'a>(
42        &'a self, 
43        response: Response
44    ) -> BoxFuture<'a, Response> {
45        Box::pin(async move { response })
46    }
47    
48    /// Optional middleware name for debugging
49    fn name(&self) -> &'static str {
50        "Middleware"
51    }
52}
53
54/// Middleware pipeline that composes multiple middleware in sequence
55#[derive(Default)]
56pub struct MiddlewarePipeline {
57    middleware: Vec<Box<dyn Middleware>>,
58}
59
60impl MiddlewarePipeline {
61    /// Create a new empty middleware pipeline
62    pub fn new() -> Self {
63        Self {
64            middleware: Vec::new(),
65        }
66    }
67    
68    /// Add middleware to the pipeline
69    pub fn add<M: Middleware + 'static>(mut self, middleware: M) -> Self {
70        self.middleware.push(Box::new(middleware));
71        self
72    }
73    
74    /// Process request through all middleware in order
75    pub async fn process_request(&self, mut request: Request) -> Result<Request, Response> {
76        for middleware in &self.middleware {
77            match middleware.process_request(request).await {
78                Ok(req) => request = req,
79                Err(response) => return Err(response),
80            }
81        }
82        Ok(request)
83    }
84    
85    /// Process response through all middleware in reverse order
86    pub async fn process_response(&self, mut response: Response) -> Response {
87        // Process in reverse order - last middleware added processes response first
88        for middleware in self.middleware.iter().rev() {
89            response = middleware.process_response(response).await;
90        }
91        response
92    }
93    
94    /// Get number of middleware in pipeline
95    pub fn len(&self) -> usize {
96        self.middleware.len()
97    }
98    
99    /// Check if pipeline is empty
100    pub fn is_empty(&self) -> bool {
101        self.middleware.is_empty()
102    }
103    
104    /// Get middleware names for debugging
105    pub fn names(&self) -> Vec<&'static str> {
106        self.middleware.iter().map(|m| m.name()).collect()
107    }
108}
109
110/// Middleware wrapper that can handle errors and convert them to responses
111pub struct ErrorHandlingMiddleware<M> {
112    inner: M,
113}
114
115impl<M> ErrorHandlingMiddleware<M> {
116    pub fn new(middleware: M) -> Self {
117        Self { inner: middleware }
118    }
119}
120
121impl<M: Middleware> Middleware for ErrorHandlingMiddleware<M> {
122    fn process_request<'a>(
123        &'a self, 
124        request: Request
125    ) -> BoxFuture<'a, Result<Request, Response>> {
126        Box::pin(async move {
127            // Delegate to inner middleware with error handling
128            match self.inner.process_request(request).await {
129                Ok(req) => Ok(req),
130                Err(response) => Err(response),
131            }
132        })
133    }
134    
135    fn process_response<'a>(
136        &'a self, 
137        response: Response
138    ) -> BoxFuture<'a, Response> {
139        Box::pin(async move {
140            self.inner.process_response(response).await
141        })
142    }
143    
144    fn name(&self) -> &'static str {
145        self.inner.name()
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, StatusCode};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Middleware that tags requests and responses with its own name.
    struct TestMiddleware {
        name: &'static str,
    }

    impl TestMiddleware {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    impl Middleware for TestMiddleware {
        fn process_request<'a>(
            &'a self,
            mut request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                // Add a header to track middleware execution.
                let headers = request.headers_mut();
                headers.insert("X-Middleware", self.name.parse().unwrap());
                Ok(request)
            })
        }

        fn process_response<'a>(
            &'a self,
            mut response: Response,
        ) -> BoxFuture<'a, Response> {
            Box::pin(async move {
                // Add a response header.
                let headers = response.headers_mut();
                headers.insert("X-Response-Middleware", self.name.parse().unwrap());
                response
            })
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Middleware that rejects every request with `403 Forbidden`.
    struct RejectMiddleware;

    impl Middleware for RejectMiddleware {
        fn process_request<'a>(
            &'a self,
            _request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                Err(Response::builder()
                    .status(StatusCode::FORBIDDEN)
                    .body(axum::body::Body::empty())
                    .unwrap())
            })
        }

        fn name(&self) -> &'static str {
            "Reject"
        }
    }

    /// Middleware that counts how many requests reached it.
    struct CountingMiddleware {
        hits: Arc<AtomicUsize>,
    }

    impl Middleware for CountingMiddleware {
        fn process_request<'a>(
            &'a self,
            request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                self.hits.fetch_add(1, Ordering::SeqCst);
                Ok(request)
            })
        }
    }

    fn test_request() -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap()
    }

    fn test_response() -> Response {
        Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_middleware_pipeline() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("First"))
            .add(TestMiddleware::new("Second"));

        let processed_request = pipeline.process_request(test_request()).await.unwrap();

        // Should have the header from the last middleware (Second overwrites First).
        assert_eq!(
            processed_request.headers().get("X-Middleware").unwrap(),
            "Second"
        );

        let processed_response = pipeline.process_response(test_response()).await;

        // Should have the header from the first middleware (reverse order).
        assert_eq!(
            processed_response.headers().get("X-Response-Middleware").unwrap(),
            "First"
        );
    }

    #[tokio::test]
    async fn test_short_circuit_stops_pipeline() {
        let hits = Arc::new(AtomicUsize::new(0));
        let pipeline = MiddlewarePipeline::new()
            .add(CountingMiddleware { hits: Arc::clone(&hits) })
            .add(RejectMiddleware)
            .add(CountingMiddleware { hits: Arc::clone(&hits) });

        let response = match pipeline.process_request(test_request()).await {
            Ok(_) => panic!("RejectMiddleware should short-circuit the pipeline"),
            Err(response) => response,
        };

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        // Only the middleware before the rejection ran.
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_error_handling_middleware_delegates() {
        let wrapped = ErrorHandlingMiddleware::new(TestMiddleware::new("Inner"));
        assert_eq!(wrapped.name(), "Inner");

        let request = wrapped.process_request(test_request()).await.unwrap();
        assert_eq!(request.headers().get("X-Middleware").unwrap(), "Inner");

        let response = wrapped.process_response(test_response()).await;
        assert_eq!(
            response.headers().get("X-Response-Middleware").unwrap(),
            "Inner"
        );

        let rejecting = ErrorHandlingMiddleware::new(RejectMiddleware);
        match rejecting.process_request(test_request()).await {
            Ok(_) => panic!("wrapped RejectMiddleware should still short-circuit"),
            Err(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        }
    }

    #[test]
    fn test_pipeline_info() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("Test1"))
            .add(TestMiddleware::new("Test2"));

        assert_eq!(pipeline.len(), 2);
        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.names(), vec!["Test1", "Test2"]);
    }

    #[tokio::test]
    async fn test_empty_pipeline() {
        let pipeline = MiddlewarePipeline::new();
        assert!(pipeline.is_empty());

        let processed_request = pipeline.process_request(test_request()).await.unwrap();

        // Request should pass through unchanged.
        assert_eq!(processed_request.method(), Method::GET);
        assert_eq!(processed_request.uri().path(), "/test");
    }
}