elif_http/
middleware.rs

1//! # Middleware
2//!
3//! Basic middleware system for processing requests and responses.
4//! Provides async middleware trait and pipeline composition.
5
6pub mod logging;
7pub mod enhanced_logging;
8pub mod timing;
9pub mod tracing;
10pub mod timeout;
11pub mod body_limit;
12
13use std::future::Future;
14use std::pin::Pin;
15use axum::{
16    response::{Response, IntoResponse},
17    extract::Request,
18};
19
20use crate::{HttpResult, HttpError};
21
22/// Type alias for async middleware function result
23pub type MiddlewareResult = HttpResult<Response>;
24
/// Heap-allocated, type-erased future resolving to `T`.
///
/// The future is `Send` and may borrow data for the lifetime `'a`. Middleware
/// hooks return this concrete type (rather than being `async fn`s) so that
/// `dyn Middleware` trait objects can be stored and called.
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
27
28/// Core middleware trait that can process requests before handlers
29/// and responses after handlers.
30pub trait Middleware: Send + Sync {
31    /// Process the request before it reaches the handler.
32    /// Can modify the request or return early response.
33    fn process_request<'a>(
34        &'a self, 
35        request: Request
36    ) -> BoxFuture<'a, Result<Request, Response>> {
37        Box::pin(async move { Ok(request) })
38    }
39    
40    /// Process the response after the handler processes it.
41    /// Can modify the response before returning to client.
42    fn process_response<'a>(
43        &'a self, 
44        response: Response
45    ) -> BoxFuture<'a, Response> {
46        Box::pin(async move { response })
47    }
48    
49    /// Optional middleware name for debugging
50    fn name(&self) -> &'static str {
51        "Middleware"
52    }
53}
54
55/// Middleware pipeline that composes multiple middleware in sequence
56#[derive(Default)]
57pub struct MiddlewarePipeline {
58    middleware: Vec<Box<dyn Middleware>>,
59}
60
61impl MiddlewarePipeline {
62    /// Create a new empty middleware pipeline
63    pub fn new() -> Self {
64        Self {
65            middleware: Vec::new(),
66        }
67    }
68    
69    /// Add middleware to the pipeline
70    pub fn add<M: Middleware + 'static>(mut self, middleware: M) -> Self {
71        self.middleware.push(Box::new(middleware));
72        self
73    }
74    
75    /// Process request through all middleware in order
76    pub async fn process_request(&self, mut request: Request) -> Result<Request, Response> {
77        for middleware in &self.middleware {
78            match middleware.process_request(request).await {
79                Ok(req) => request = req,
80                Err(response) => return Err(response),
81            }
82        }
83        Ok(request)
84    }
85    
86    /// Process response through all middleware in reverse order
87    pub async fn process_response(&self, mut response: Response) -> Response {
88        // Process in reverse order - last middleware added processes response first
89        for middleware in self.middleware.iter().rev() {
90            response = middleware.process_response(response).await;
91        }
92        response
93    }
94    
95    /// Get number of middleware in pipeline
96    pub fn len(&self) -> usize {
97        self.middleware.len()
98    }
99    
100    /// Check if pipeline is empty
101    pub fn is_empty(&self) -> bool {
102        self.middleware.is_empty()
103    }
104    
105    /// Get middleware names for debugging
106    pub fn names(&self) -> Vec<&'static str> {
107        self.middleware.iter().map(|m| m.name()).collect()
108    }
109}
110
/// Wrapper that forwards every [`Middleware`] hook to an inner middleware.
///
/// NOTE: despite its name, this wrapper currently adds no error handling of its
/// own. An `Err(response)` from the inner middleware's `process_request` is
/// propagated unchanged, and responses go straight to the inner
/// `process_response`. Its [`Middleware::name`] is the inner middleware's name.
pub struct ErrorHandlingMiddleware<M> {
    /// The wrapped middleware that every hook delegates to.
    inner: M,
}

impl<M> ErrorHandlingMiddleware<M> {
    /// Wraps `middleware`.
    #[must_use]
    pub fn new(middleware: M) -> Self {
        Self { inner: middleware }
    }
}
121
122impl<M: Middleware> Middleware for ErrorHandlingMiddleware<M> {
123    fn process_request<'a>(
124        &'a self, 
125        request: Request
126    ) -> BoxFuture<'a, Result<Request, Response>> {
127        Box::pin(async move {
128            // Delegate to inner middleware with error handling
129            match self.inner.process_request(request).await {
130                Ok(req) => Ok(req),
131                Err(response) => Err(response),
132            }
133        })
134    }
135    
136    fn process_response<'a>(
137        &'a self, 
138        response: Response
139    ) -> BoxFuture<'a, Response> {
140        Box::pin(async move {
141            self.inner.process_response(response).await
142        })
143    }
144    
145    fn name(&self) -> &'static str {
146        self.inner.name()
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::ElifStatusCode as StatusCode;
    use axum::http::Method; // TODO: Replace with framework type when available
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Test middleware that tags requests and responses with its name.
    struct TestMiddleware {
        name: &'static str,
    }

    impl TestMiddleware {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    impl Middleware for TestMiddleware {
        fn process_request<'a>(
            &'a self,
            mut request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                // Add a header to track middleware execution
                let headers = request.headers_mut();
                headers.insert("X-Middleware", self.name.parse().unwrap());
                Ok(request)
            })
        }

        fn process_response<'a>(
            &'a self,
            mut response: Response,
        ) -> BoxFuture<'a, Response> {
            Box::pin(async move {
                // Add response header
                let headers = response.headers_mut();
                headers.insert("X-Response-Middleware", self.name.parse().unwrap());
                response
            })
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Test middleware that always short-circuits with an early response.
    struct RejectMiddleware;

    impl Middleware for RejectMiddleware {
        fn process_request<'a>(
            &'a self,
            _request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                let mut response = Response::new(axum::body::Body::empty());
                response
                    .headers_mut()
                    .insert("X-Rejected-By", "Reject".parse().unwrap());
                Err(response)
            })
        }

        fn name(&self) -> &'static str {
            "Reject"
        }
    }

    /// Test middleware that counts how many requests reach it.
    struct CountingMiddleware {
        calls: Arc<AtomicUsize>,
    }

    impl Middleware for CountingMiddleware {
        fn process_request<'a>(
            &'a self,
            request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                self.calls.fetch_add(1, Ordering::SeqCst);
                Ok(request)
            })
        }
    }

    /// Builds an empty `GET /test` request.
    fn test_request() -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_middleware_pipeline() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("First"))
            .add(TestMiddleware::new("Second"));

        // Process request
        let processed_request = pipeline.process_request(test_request()).await.unwrap();

        // Should have header from last middleware (Second overwrites First)
        assert_eq!(
            processed_request.headers().get("X-Middleware").unwrap(),
            "Second"
        );

        // Create test response
        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();

        // Process response
        let processed_response = pipeline.process_response(response).await;

        // Should have header from first middleware (reverse order)
        assert_eq!(
            processed_response.headers().get("X-Response-Middleware").unwrap(),
            "First"
        );
    }

    #[test]
    fn test_pipeline_info() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("Test1"))
            .add(TestMiddleware::new("Test2"));

        assert_eq!(pipeline.len(), 2);
        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.names(), vec!["Test1", "Test2"]);
    }

    #[tokio::test]
    async fn test_empty_pipeline() {
        let pipeline = MiddlewarePipeline::new();
        assert!(pipeline.is_empty());

        let processed_request = pipeline.process_request(test_request()).await.unwrap();

        // Request should pass through unchanged
        assert_eq!(processed_request.method(), Method::GET);
        assert_eq!(processed_request.uri().path(), "/test");

        // Response should pass through unchanged as well
        let response = Response::new(axum::body::Body::empty());
        let processed_response = pipeline.process_response(response).await;
        assert!(processed_response.headers().is_empty());
    }

    #[tokio::test]
    async fn test_short_circuit_stops_pipeline() {
        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let pipeline = MiddlewarePipeline::new()
            .add(CountingMiddleware { calls: Arc::clone(&before) })
            .add(RejectMiddleware)
            .add(CountingMiddleware { calls: Arc::clone(&after) });

        let response = match pipeline.process_request(test_request()).await {
            Ok(_) => panic!("RejectMiddleware should short-circuit the pipeline"),
            Err(response) => response,
        };

        assert_eq!(response.headers().get("X-Rejected-By").unwrap(), "Reject");
        // Middleware before the rejecting one ran exactly once...
        assert_eq!(before.load(Ordering::SeqCst), 1);
        // ...and middleware after it never saw the request.
        assert_eq!(after.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_error_handling_middleware_delegates() {
        let wrapped = ErrorHandlingMiddleware::new(TestMiddleware::new("Wrapped"));
        assert_eq!(wrapped.name(), "Wrapped");

        let request = wrapped.process_request(test_request()).await.unwrap();
        assert_eq!(request.headers().get("X-Middleware").unwrap(), "Wrapped");

        let response = Response::new(axum::body::Body::empty());
        let response = wrapped.process_response(response).await;
        assert_eq!(
            response.headers().get("X-Response-Middleware").unwrap(),
            "Wrapped"
        );
    }

    #[tokio::test]
    async fn test_error_handling_middleware_propagates_short_circuit() {
        let wrapped = ErrorHandlingMiddleware::new(RejectMiddleware);
        assert_eq!(wrapped.name(), "Reject");

        match wrapped.process_request(test_request()).await {
            Ok(_) => panic!("inner short-circuit should propagate through the wrapper"),
            Err(response) => {
                assert_eq!(response.headers().get("X-Rejected-By").unwrap(), "Reject");
            }
        }
    }
}