Skip to main content

reinhardt_http/
middleware.rs

1//! Middleware and handler traits for HTTP request processing.
2//!
3//! This module provides the core abstractions for handling HTTP requests
4//! and composing middleware chains.
5//!
6//! ## Handler
7//!
8//! The `Handler` trait is the core abstraction for processing requests:
9//!
10//! ```rust
11//! use reinhardt_http::{Handler, Request, Response};
12//! use async_trait::async_trait;
13//!
14//! struct MyHandler;
15//!
16//! #[async_trait]
17//! impl Handler for MyHandler {
18//!     async fn handle(&self, request: Request) -> reinhardt_core::exception::Result<Response> {
19//!         Ok(Response::ok().with_body("Hello!"))
20//!     }
21//! }
22//! ```
23//!
24//! ## Middleware
25//!
26//! Middleware wraps handlers to add cross-cutting concerns:
27//!
28//! ```rust
29//! use reinhardt_http::{Handler, Middleware, Request, Response};
30//! use async_trait::async_trait;
31//! use std::sync::Arc;
32//!
33//! struct LoggingMiddleware;
34//!
35//! #[async_trait]
36//! impl Middleware for LoggingMiddleware {
37//!     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
38//!         println!("Request: {} {}", request.method, request.uri);
39//!         next.handle(request).await
40//!     }
41//! }
42//! ```
43
44use async_trait::async_trait;
45use reinhardt_core::exception::Result;
46use std::sync::Arc;
47
48use crate::{Request, Response};
49
50/// Handler trait for processing requests.
51///
52/// This is the core abstraction - all request handlers implement this trait.
53/// Handlers receive a request and produce a response or an error.
54#[async_trait]
55pub trait Handler: Send + Sync {
56	/// Handles an HTTP request and produces a response.
57	///
58	/// # Errors
59	///
60	/// Returns an error if the request cannot be processed.
61	async fn handle(&self, request: Request) -> Result<Response>;
62}
63
64/// Blanket implementation for `Arc<T>` where T: Handler.
65///
66/// This allows `Arc<dyn Handler>` to be used as a Handler,
67/// enabling shared ownership of handlers across threads.
68#[async_trait]
69impl<T: Handler + ?Sized> Handler for Arc<T> {
70	async fn handle(&self, request: Request) -> Result<Response> {
71		(**self).handle(request).await
72	}
73}
74
75/// Middleware trait for request/response processing.
76///
77/// Uses composition pattern instead of inheritance.
78/// Middleware can modify requests before passing to the next handler,
79/// or modify responses after the handler processes the request.
80#[async_trait]
81pub trait Middleware: Send + Sync {
82	/// Processes a request through this middleware.
83	///
84	/// # Arguments
85	///
86	/// * `request` - The incoming HTTP request
87	/// * `next` - The next handler in the chain to call
88	///
89	/// # Errors
90	///
91	/// Returns an error if the middleware or next handler fails.
92	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
93
94	/// Determines whether this middleware should be executed for the given request.
95	///
96	/// This method enables conditional execution of middleware, allowing the middleware
97	/// chain to skip unnecessary middleware based on request properties.
98	///
99	/// # Performance Benefits
100	///
101	/// By implementing this method, middleware chains can achieve O(k) complexity
102	/// instead of O(n), where k is the number of middleware that should run,
103	/// and k <= n (total middleware count).
104	///
105	/// # Common Use Cases
106	///
107	/// - Skip authentication middleware for public endpoints
108	/// - Skip compression middleware for already compressed responses
109	/// - Skip CORS middleware for same-origin requests
110	/// - Skip rate limiting for internal/admin requests
111	///
112	/// # Default Implementation
113	///
114	/// By default, returns `true` (always execute), maintaining backward compatibility.
115	fn should_continue(&self, _request: &Request) -> bool {
116		true
117	}
118}
119
120/// Middleware chain - composes multiple middleware into a single handler.
121///
122/// The chain processes requests through middleware in the order they were added,
123/// with optimizations for conditional execution and early termination.
124pub struct MiddlewareChain {
125	middlewares: Vec<Arc<dyn Middleware>>,
126	handler: Arc<dyn Handler>,
127}
128
129impl MiddlewareChain {
130	/// Creates a new middleware chain with the given handler.
131	///
132	/// # Examples
133	///
134	/// ```rust
135	/// use reinhardt_http::{MiddlewareChain, Handler, Request, Response};
136	/// use std::sync::Arc;
137	///
138	/// struct MyHandler;
139	///
140	/// #[async_trait::async_trait]
141	/// impl Handler for MyHandler {
142	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
143	///         Ok(Response::ok())
144	///     }
145	/// }
146	///
147	/// let handler = Arc::new(MyHandler);
148	/// let chain = MiddlewareChain::new(handler);
149	/// ```
150	pub fn new(handler: Arc<dyn Handler>) -> Self {
151		Self {
152			middlewares: Vec::new(),
153			handler,
154		}
155	}
156
157	/// Adds a middleware to the chain using builder pattern.
158	///
159	/// # Examples
160	///
161	/// ```rust
162	/// use reinhardt_http::{MiddlewareChain, Handler, Middleware, Request, Response};
163	/// use std::sync::Arc;
164	///
165	/// # struct MyHandler;
166	/// # struct MyMiddleware;
167	/// # #[async_trait::async_trait]
168	/// # impl Handler for MyHandler {
169	/// #     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
170	/// #         Ok(Response::ok())
171	/// #     }
172	/// # }
173	/// # #[async_trait::async_trait]
174	/// # impl Middleware for MyMiddleware {
175	/// #     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
176	/// #         next.handle(request).await
177	/// #     }
178	/// # }
179	/// let handler = Arc::new(MyHandler);
180	/// let middleware = Arc::new(MyMiddleware);
181	/// let chain = MiddlewareChain::new(handler)
182	///     .with_middleware(middleware);
183	/// ```
184	pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
185		self.middlewares.push(middleware);
186		self
187	}
188
189	/// Adds a middleware to the chain.
190	///
191	/// # Examples
192	///
193	/// ```rust
194	/// use reinhardt_http::{MiddlewareChain, Handler, Middleware, Request, Response};
195	/// use std::sync::Arc;
196	///
197	/// # struct MyHandler;
198	/// # struct MyMiddleware;
199	/// # #[async_trait::async_trait]
200	/// # impl Handler for MyHandler {
201	/// #     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
202	/// #         Ok(Response::ok())
203	/// #     }
204	/// # }
205	/// # #[async_trait::async_trait]
206	/// # impl Middleware for MyMiddleware {
207	/// #     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
208	/// #         next.handle(request).await
209	/// #     }
210	/// # }
211	/// let handler = Arc::new(MyHandler);
212	/// let middleware = Arc::new(MyMiddleware);
213	/// let mut chain = MiddlewareChain::new(handler);
214	/// chain.add_middleware(middleware);
215	/// ```
216	pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
217		self.middlewares.push(middleware);
218	}
219}
220
221#[async_trait]
222impl Handler for MiddlewareChain {
223	async fn handle(&self, request: Request) -> Result<Response> {
224		if self.middlewares.is_empty() {
225			return self.handler.handle(request).await;
226		}
227
228		// Build nested handler chain using composition with optimizations:
229		// 1. Conditional execution (skip middleware based on should_continue)
230		// 2. Short-circuiting (early return if response.should_stop_chain() is true)
231		//
232		// Performance improvements:
233		// - Condition check: O(1) per middleware
234		// - Skip unnecessary middleware: achieves O(k) where k <= n
235		// - Early return: stops processing on first stop_chain=true response
236		let mut current_handler = self.handler.clone();
237
238		// Filter middleware based on should_continue condition
239		// This achieves the O(k) optimization where k is the number of middleware that should run
240		let active_middlewares: Vec<_> = self
241			.middlewares
242			.iter()
243			.rev()
244			.filter(|mw| mw.should_continue(&request))
245			.collect();
246
247		for middleware in active_middlewares {
248			let mw = middleware.clone();
249			let handler = current_handler.clone();
250
251			current_handler = Arc::new(ConditionalComposedHandler {
252				middleware: mw,
253				next: handler,
254			});
255		}
256
257		current_handler.handle(request).await
258	}
259}
260
261/// Optimized internal handler that composes middleware with next handler.
262///
263/// Supports short-circuiting via `response.should_stop_chain()`.
264struct ConditionalComposedHandler {
265	middleware: Arc<dyn Middleware>,
266	next: Arc<dyn Handler>,
267}
268
269#[async_trait]
270impl Handler for ConditionalComposedHandler {
271	async fn handle(&self, request: Request) -> Result<Response> {
272		// Process the request through this middleware
273		let response = self.middleware.process(request, self.next.clone()).await?;
274
275		// Short-circuit: if response indicates chain should stop, return immediately
276		// This prevents further middleware/handlers from executing
277		if response.should_stop_chain() {
278			return Ok(response);
279		}
280
281		Ok(response)
282	}
283}
284
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};

	// Mock handler for testing: always responds `200 OK` with a fixed body.
	struct MockHandler {
		response_body: String,
	}

	#[async_trait]
	impl Handler for MockHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::ok().with_body(self.response_body.clone()))
		}
	}

	// Mock middleware for testing: prepends `prefix` to the body produced by the
	// rest of the chain. It always builds a fresh `200 OK` response, so the inner
	// response's status is not preserved.
	struct MockMiddleware {
		prefix: String,
	}

	#[async_trait]
	impl Middleware for MockMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			// Call the next handler
			let response = next.handle(request).await?;

			// Modify the response
			let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
			let new_body = format!("{}{}", self.prefix, current_body);

			Ok(Response::ok().with_body(new_body))
		}
	}

	/// Builds a `GET` request for `path` with no headers and an empty body.
	fn create_request(path: &'static str) -> Request {
		Request::builder()
			.method(Method::GET)
			.uri(path)
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Builds a `GET /` request.
	fn create_test_request() -> Request {
		create_request("/")
	}

	/// Decodes a response body as UTF-8, panicking on invalid data.
	fn body_string(response: &Response) -> String {
		String::from_utf8(response.body.to_vec()).unwrap()
	}

	#[tokio::test]
	async fn test_handler_basic() {
		let handler = MockHandler {
			response_body: "Hello".to_string(),
		};

		let response = handler.handle(create_test_request()).await.unwrap();

		assert_eq!(body_string(&response), "Hello");
	}

	#[tokio::test]
	async fn test_arc_handler_delegates() {
		let handler: Arc<dyn Handler> = Arc::new(MockHandler {
			response_body: "Shared".to_string(),
		});

		// Exercises the blanket `impl Handler for Arc<T>` explicitly.
		let response = Handler::handle(&handler, create_test_request())
			.await
			.unwrap();

		assert_eq!(body_string(&response), "Shared");
	}

	#[tokio::test]
	async fn test_middleware_basic() {
		let handler = Arc::new(MockHandler {
			response_body: "World".to_string(),
		});

		let middleware = MockMiddleware {
			prefix: "Hello, ".to_string(),
		};

		let response = middleware
			.process(create_test_request(), handler)
			.await
			.unwrap();

		assert_eq!(body_string(&response), "Hello, World");
	}

	#[test]
	fn test_should_continue_defaults_to_true() {
		let middleware = MockMiddleware {
			prefix: String::new(),
		};

		assert!(middleware.should_continue(&create_test_request()));
	}

	#[tokio::test]
	async fn test_middleware_chain_empty() {
		let handler = Arc::new(MockHandler {
			response_body: "Test".to_string(),
		});

		let chain = MiddlewareChain::new(handler);

		let response = chain.handle(create_test_request()).await.unwrap();

		assert_eq!(body_string(&response), "Test");
	}

	#[tokio::test]
	async fn test_middleware_chain_single() {
		let handler = Arc::new(MockHandler {
			response_body: "Handler".to_string(),
		});

		let middleware1 = Arc::new(MockMiddleware {
			prefix: "MW1:".to_string(),
		});

		let chain = MiddlewareChain::new(handler).with_middleware(middleware1);

		let response = chain.handle(create_test_request()).await.unwrap();

		assert_eq!(body_string(&response), "MW1:Handler");
	}

	#[tokio::test]
	async fn test_middleware_chain_multiple() {
		let handler = Arc::new(MockHandler {
			response_body: "Data".to_string(),
		});

		let middleware1 = Arc::new(MockMiddleware {
			prefix: "M1:".to_string(),
		});

		let middleware2 = Arc::new(MockMiddleware {
			prefix: "M2:".to_string(),
		});

		let chain = MiddlewareChain::new(handler)
			.with_middleware(middleware1)
			.with_middleware(middleware2);

		let response = chain.handle(create_test_request()).await.unwrap();

		// The first middleware added is the outermost layer, so its prefix is
		// applied last and therefore appears first.
		assert_eq!(body_string(&response), "M1:M2:Data");
	}

	#[tokio::test]
	async fn test_middleware_chain_add_middleware() {
		let handler = Arc::new(MockHandler {
			response_body: "Result".to_string(),
		});

		let middleware = Arc::new(MockMiddleware {
			prefix: "Prefix:".to_string(),
		});

		let mut chain = MiddlewareChain::new(handler);
		chain.add_middleware(middleware);

		let response = chain.handle(create_test_request()).await.unwrap();

		assert_eq!(body_string(&response), "Prefix:Result");
	}

	// Conditional middleware that only runs for /api/* paths
	struct ConditionalMiddleware {
		prefix: String,
	}

	#[async_trait]
	impl Middleware for ConditionalMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			let response = next.handle(request).await?;
			let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
			let new_body = format!("{}{}", self.prefix, current_body);
			Ok(Response::ok().with_body(new_body))
		}

		fn should_continue(&self, request: &Request) -> bool {
			request.uri.path().starts_with("/api/")
		}
	}

	#[tokio::test]
	async fn test_middleware_conditional_skip() {
		let handler = Arc::new(MockHandler {
			response_body: "Response".to_string(),
		});

		let conditional_mw = Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		});

		let chain = MiddlewareChain::new(handler).with_middleware(conditional_mw);

		// Test with /api/ path - middleware should run
		let response = chain.handle(create_request("/api/users")).await.unwrap();
		assert_eq!(body_string(&response), "API:Response");

		// Test with non-/api/ path - middleware should be skipped
		let response = chain.handle(create_request("/public")).await.unwrap();
		assert_eq!(body_string(&response), "Response"); // No prefix because middleware was skipped
	}

	// Middleware that returns early with stop_chain=true
	struct ShortCircuitMiddleware {
		should_stop: bool,
	}

	#[async_trait]
	impl Middleware for ShortCircuitMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			if self.should_stop {
				// Return early without calling next
				return Ok(Response::unauthorized()
					.with_body("Auth required")
					.with_stop_chain(true));
			}
			next.handle(request).await
		}
	}

	#[tokio::test]
	async fn test_middleware_short_circuit() {
		let handler = Arc::new(MockHandler {
			response_body: "Handler Response".to_string(),
		});

		let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: true });
		let normal_mw = Arc::new(MockMiddleware {
			prefix: "Normal:".to_string(),
		});

		let chain = MiddlewareChain::new(handler)
			.with_middleware(short_circuit_mw)
			.with_middleware(normal_mw);

		let response = chain.handle(create_test_request()).await.unwrap();

		// Should get unauthorized response, not the handler response
		assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
		assert_eq!(body_string(&response), "Auth required");
	}

	#[tokio::test]
	async fn test_middleware_no_short_circuit() {
		let handler = Arc::new(MockHandler {
			response_body: "Handler Response".to_string(),
		});

		let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: false });
		let normal_mw = Arc::new(MockMiddleware {
			prefix: "Normal:".to_string(),
		});

		let chain = MiddlewareChain::new(handler)
			.with_middleware(short_circuit_mw)
			.with_middleware(normal_mw);

		let response = chain.handle(create_test_request()).await.unwrap();

		// Should pass through to handler and apply normal middleware
		assert_eq!(response.status, hyper::StatusCode::OK);
		assert_eq!(body_string(&response), "Normal:Handler Response");
	}

	#[tokio::test]
	async fn test_middleware_multiple_conditions() {
		let handler = Arc::new(MockHandler {
			response_body: "Base".to_string(),
		});

		// Only runs for /api/* paths
		let api_mw = Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		});

		// Always runs
		let always_mw = Arc::new(MockMiddleware {
			prefix: "Always:".to_string(),
		});

		let chain = MiddlewareChain::new(handler)
			.with_middleware(api_mw)
			.with_middleware(always_mw);

		// Test with /api/ path - both middleware should run
		let response = chain.handle(create_request("/api/test")).await.unwrap();
		assert_eq!(body_string(&response), "API:Always:Base");

		// Test with non-/api/ path - only always_mw should run
		let response = chain.handle(create_request("/public")).await.unwrap();
		assert_eq!(body_string(&response), "Always:Base"); // Only always_mw prefix
	}

	// Pure construction/inspection of `Response`; no async runtime required.
	#[test]
	fn test_response_should_stop_chain() {
		let response = Response::ok();
		assert!(!response.should_stop_chain());

		let stopping_response = Response::unauthorized().with_stop_chain(true);
		assert!(stopping_response.should_stop_chain());
	}
}