Skip to main content

reinhardt_http/
middleware.rs

//! Middleware and handler traits for HTTP request processing.
//!
//! This module provides the core abstractions for handling HTTP requests
//! and composing middleware chains.
//!
//! ## Handler
//!
//! The `Handler` trait is the core abstraction for processing requests:
//!
//! ```rust
//! use reinhardt_http::{Handler, Request, Response};
//! use async_trait::async_trait;
//!
//! struct MyHandler;
//!
//! #[async_trait]
//! impl Handler for MyHandler {
//!     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
//!         Ok(Response::ok().with_body("Hello!"))
//!     }
//! }
//! ```
//!
//! ## Middleware
//!
//! Middleware wraps handlers to add cross-cutting concerns:
//!
//! ```rust
//! use reinhardt_http::{Handler, Middleware, Request, Response};
//! use async_trait::async_trait;
//! use std::sync::Arc;
//!
//! struct LoggingMiddleware;
//!
//! #[async_trait]
//! impl Middleware for LoggingMiddleware {
//!     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
//!         println!("Request: {} {}", request.method, request.uri);
//!         next.handle(request).await
//!     }
//! }
//! ```

44use async_trait::async_trait;
45use reinhardt_core::exception::Result;
46use std::any::{Any, TypeId};
47use std::sync::Arc;
48
49use crate::{Request, Response};
50
/// A single DI singleton contributed by a middleware.
///
/// The first element is the `TypeId` of the concrete singleton type `T`; the
/// second is the singleton value itself, type-erased as
/// `Arc<dyn Any + Send + Sync>` so it can be inserted directly into a DI
/// singleton scope keyed by that `TypeId`.
///
/// Only `std` types are used here so that `reinhardt-http` can offer a DI hook
/// on the `Middleware` trait without depending on `reinhardt-di` (which would
/// introduce a circular crate dependency).
pub type MiddlewareDiRegistration = (
	TypeId,
	Arc<dyn Any + Send + Sync>,
);
59
60/// Handler trait for processing requests.
61///
62/// This is the core abstraction - all request handlers implement this trait.
63/// Handlers receive a request and produce a response or an error.
64#[async_trait]
65pub trait Handler: Send + Sync {
66	/// Handles an HTTP request and produces a response.
67	///
68	/// # Errors
69	///
70	/// Returns an error if the request cannot be processed.
71	async fn handle(&self, request: Request) -> Result<Response>;
72}
73
74/// Blanket implementation for `Arc<T>` where T: Handler.
75///
76/// This allows `Arc<dyn Handler>` to be used as a Handler,
77/// enabling shared ownership of handlers across threads.
78#[async_trait]
79impl<T: Handler + ?Sized> Handler for Arc<T> {
80	async fn handle(&self, request: Request) -> Result<Response> {
81		(**self).handle(request).await
82	}
83}
84
85/// Middleware trait for request/response processing.
86///
87/// Uses composition pattern instead of inheritance.
88/// Middleware can modify requests before passing to the next handler,
89/// or modify responses after the handler processes the request.
90#[async_trait]
91pub trait Middleware: Send + Sync {
92	/// Processes a request through this middleware.
93	///
94	/// # Arguments
95	///
96	/// * `request` - The incoming HTTP request
97	/// * `next` - The next handler in the chain to call
98	///
99	/// # Errors
100	///
101	/// Returns an error if the middleware or next handler fails.
102	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
103
104	/// Determines whether this middleware should be executed for the given request.
105	///
106	/// This method enables conditional execution of middleware, allowing the middleware
107	/// chain to skip unnecessary middleware based on request properties.
108	///
109	/// # Performance Benefits
110	///
111	/// By implementing this method, middleware chains can achieve O(k) complexity
112	/// instead of O(n), where k is the number of middleware that should run,
113	/// and k <= n (total middleware count).
114	///
115	/// # Common Use Cases
116	///
117	/// - Skip authentication middleware for public endpoints
118	/// - Skip compression middleware for already compressed responses
119	/// - Skip CORS middleware for same-origin requests
120	/// - Skip rate limiting for internal/admin requests
121	///
122	/// # Default Implementation
123	///
124	/// By default, returns `true` (always execute), maintaining backward compatibility.
125	fn should_continue(&self, _request: &Request) -> bool {
126		true
127	}
128
129	/// Returns DI singleton registrations contributed by this middleware.
130	///
131	/// Each entry is a `(TypeId, Arc<dyn Any + Send + Sync>)` pair representing
132	/// a singleton that the middleware owns and wants to expose to handlers
133	/// resolved via `#[inject]`. The default implementation returns an empty
134	/// vector, preserving backward compatibility for middleware that does not
135	/// own any DI-visible state.
136	///
137	/// Routers such as `ServerRouter` / `UnifiedRouter` call this method when
138	/// the middleware is registered via `with_middleware()` and merge the
139	/// resulting list into the server's DI singleton scope. This lets a
140	/// middleware (for example `SessionMiddleware`) automatically register the
141	/// `Arc<T>` it constructs in its constructor, so callers no longer have to
142	/// thread a parallel `with_di_registrations(...)` call alongside every
143	/// `with_middleware(...)`.
144	///
145	/// # Example
146	///
147	/// ```rust,ignore
148	/// use std::any::TypeId;
149	/// use std::sync::Arc;
150	/// use reinhardt_http::{Middleware, MiddlewareDiRegistration};
151	///
152	/// struct MyStore;
153	/// struct MyMiddleware { store: Arc<MyStore> }
154	///
155	/// impl Middleware for MyMiddleware {
156	///     // ... process / should_continue ...
157	///     fn di_registrations(&self) -> Vec<MiddlewareDiRegistration> {
158	///         vec![(TypeId::of::<MyStore>(), Arc::clone(&self.store) as _)]
159	///     }
160	/// }
161	/// ```
162	fn di_registrations(&self) -> Vec<MiddlewareDiRegistration> {
163		Vec::new()
164	}
165}
166
167/// Middleware chain - composes multiple middleware into a single handler.
168///
169/// The chain processes requests through middleware in the order they were added,
170/// with optimizations for conditional execution and early termination.
171pub struct MiddlewareChain {
172	middlewares: Vec<Arc<dyn Middleware>>,
173	handler: Arc<dyn Handler>,
174}
175
176impl MiddlewareChain {
177	/// Creates a new middleware chain with the given handler.
178	///
179	/// # Examples
180	///
181	/// ```rust
182	/// use reinhardt_http::{MiddlewareChain, Handler, Request, Response};
183	/// use std::sync::Arc;
184	///
185	/// struct MyHandler;
186	///
187	/// #[async_trait::async_trait]
188	/// impl Handler for MyHandler {
189	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
190	///         Ok(Response::ok())
191	///     }
192	/// }
193	///
194	/// let handler = Arc::new(MyHandler);
195	/// let chain = MiddlewareChain::new(handler);
196	/// ```
197	pub fn new(handler: Arc<dyn Handler>) -> Self {
198		Self {
199			middlewares: Vec::new(),
200			handler,
201		}
202	}
203
204	/// Adds a middleware to the chain using builder pattern.
205	///
206	/// # Examples
207	///
208	/// ```rust
209	/// use reinhardt_http::{MiddlewareChain, Handler, Middleware, Request, Response};
210	/// use std::sync::Arc;
211	///
212	/// # struct MyHandler;
213	/// # struct MyMiddleware;
214	/// # #[async_trait::async_trait]
215	/// # impl Handler for MyHandler {
216	/// #     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
217	/// #         Ok(Response::ok())
218	/// #     }
219	/// # }
220	/// # #[async_trait::async_trait]
221	/// # impl Middleware for MyMiddleware {
222	/// #     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
223	/// #         next.handle(request).await
224	/// #     }
225	/// # }
226	/// let handler = Arc::new(MyHandler);
227	/// let middleware = Arc::new(MyMiddleware);
228	/// let chain = MiddlewareChain::new(handler)
229	///     .with_middleware(middleware);
230	/// ```
231	pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
232		self.middlewares.push(middleware);
233		self
234	}
235
236	/// Adds a middleware to the chain.
237	///
238	/// # Examples
239	///
240	/// ```rust
241	/// use reinhardt_http::{MiddlewareChain, Handler, Middleware, Request, Response};
242	/// use std::sync::Arc;
243	///
244	/// # struct MyHandler;
245	/// # struct MyMiddleware;
246	/// # #[async_trait::async_trait]
247	/// # impl Handler for MyHandler {
248	/// #     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
249	/// #         Ok(Response::ok())
250	/// #     }
251	/// # }
252	/// # #[async_trait::async_trait]
253	/// # impl Middleware for MyMiddleware {
254	/// #     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
255	/// #         next.handle(request).await
256	/// #     }
257	/// # }
258	/// let handler = Arc::new(MyHandler);
259	/// let middleware = Arc::new(MyMiddleware);
260	/// let mut chain = MiddlewareChain::new(handler);
261	/// chain.add_middleware(middleware);
262	/// ```
263	pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
264		self.middlewares.push(middleware);
265	}
266}
267
268#[async_trait]
269impl Handler for MiddlewareChain {
270	async fn handle(&self, request: Request) -> Result<Response> {
271		if self.middlewares.is_empty() {
272			return self.handler.handle(request).await;
273		}
274
275		// Build nested handler chain using composition with optimizations:
276		// 1. Conditional execution (skip middleware based on should_continue)
277		// 2. Short-circuiting (early return if response.should_stop_chain() is true)
278		//
279		// Performance improvements:
280		// - Condition check: O(1) per middleware
281		// - Skip unnecessary middleware: achieves O(k) where k <= n
282		// - Early return: stops processing on first stop_chain=true response
283		// Wrap the base handler to convert errors to responses, ensuring
284		// all middleware post-processing runs even for error responses.
285		let mut current_handler: Arc<dyn Handler> = Arc::new(ErrorToResponseHandler {
286			inner: self.handler.clone(),
287		});
288
289		// Filter middleware based on should_continue condition
290		// This achieves the O(k) optimization where k is the number of middleware that should run
291		let active_middlewares: Vec<_> = self
292			.middlewares
293			.iter()
294			.rev()
295			.filter(|mw| mw.should_continue(&request))
296			.collect();
297
298		for middleware in active_middlewares {
299			let mw = middleware.clone();
300			let handler = current_handler.clone();
301
302			current_handler = Arc::new(ConditionalComposedHandler {
303				middleware: mw,
304				next: handler,
305			});
306		}
307
308		current_handler.handle(request).await
309	}
310}
311
312/// Middleware wrapper that excludes specific URL paths from execution.
313///
314/// When a request matches an excluded path, the middleware is skipped
315/// and the request passes directly to the next handler in the chain.
316///
317/// Path matching follows Django URL conventions:
318/// - Paths ending with `/` are treated as **prefix matches**
319///   (e.g., `"/api/auth/"` excludes `"/api/auth/login"`, `"/api/auth/register"`)
320/// - Paths without trailing `/` require an **exact match**
321///   (e.g., `"/health"` excludes only `"/health"`, not `"/health/check"`)
322///
323/// This struct is typically not used directly. Instead, use the
324/// `exclude` methods on the `ServerRouter` or `UnifiedRouter` types
325/// from the `reinhardt_urls::routers` module for declarative
326/// route exclusion at the router level.
327///
328/// # Examples
329///
330/// ```rust
331/// use reinhardt_http::middleware::ExcludeMiddleware;
332/// use reinhardt_http::{Middleware, Request};
333/// use std::sync::Arc;
334///
335/// # struct MyMiddleware;
336/// # #[async_trait::async_trait]
337/// # impl Middleware for MyMiddleware {
338/// #     async fn process(
339/// #         &self,
340/// #         request: Request,
341/// #         next: Arc<dyn reinhardt_http::Handler>,
342/// #     ) -> reinhardt_core::exception::Result<reinhardt_http::Response> {
343/// #         next.handle(request).await
344/// #     }
345/// # }
346/// let inner: Arc<dyn Middleware> = Arc::new(MyMiddleware);
347/// let excluded = ExcludeMiddleware::new(inner)
348///     .add_exclusion("/api/auth/")   // prefix match
349///     .add_exclusion("/health");     // exact match
350/// ```
351pub struct ExcludeMiddleware {
352	inner: Arc<dyn Middleware>,
353	exclusions: Vec<String>,
354}
355
356impl ExcludeMiddleware {
357	/// Creates a new `ExcludeMiddleware` wrapping the given middleware.
358	pub fn new(inner: Arc<dyn Middleware>) -> Self {
359		Self {
360			inner,
361			exclusions: Vec::new(),
362		}
363	}
364
365	/// Adds an exclusion pattern (builder pattern, consumes self).
366	///
367	/// Paths ending with `/` are prefix matches; others are exact matches.
368	pub fn add_exclusion(mut self, pattern: &str) -> Self {
369		self.exclusions.push(pattern.to_string());
370		self
371	}
372
373	/// Adds an exclusion pattern (mutable reference).
374	///
375	/// Paths ending with `/` are prefix matches; others are exact matches.
376	pub fn add_exclusion_mut(&mut self, pattern: &str) {
377		self.exclusions.push(pattern.to_string());
378	}
379
380	/// Checks whether the given path matches any exclusion pattern.
381	fn is_excluded(&self, path: &str) -> bool {
382		self.exclusions.iter().any(|pattern| {
383			if pattern.ends_with('/') {
384				// Prefix match: excluded if path starts with the pattern
385				path.starts_with(pattern.as_str())
386			} else {
387				// Exact match: excluded only if path equals the pattern
388				path == pattern
389			}
390		})
391	}
392}
393
394#[async_trait]
395impl Middleware for ExcludeMiddleware {
396	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
397		self.inner.process(request, next).await
398	}
399
400	fn should_continue(&self, request: &Request) -> bool {
401		if self.is_excluded(request.uri.path()) {
402			return false;
403		}
404		self.inner.should_continue(request)
405	}
406}
407
408/// Internal handler wrapper that converts errors to HTTP responses.
409///
410/// Wraps the base handler so that middleware always receives `Ok(Response)`
411/// from `next.handle()`, even when the handler returns an error. This ensures
412/// middleware post-processing (e.g., adding security headers) runs for all
413/// responses, matching Django's `process_response` semantics.
414struct ErrorToResponseHandler {
415	inner: Arc<dyn Handler>,
416}
417
418#[async_trait]
419impl Handler for ErrorToResponseHandler {
420	async fn handle(&self, request: Request) -> Result<Response> {
421		match self.inner.handle(request).await {
422			Ok(response) => Ok(response),
423			Err(e) => Ok(Response::from(e)),
424		}
425	}
426}
427
428/// Internal handler that composes a single middleware with the next handler.
429///
430/// Converts middleware errors to HTTP responses so that outer middleware
431/// post-processing (e.g., adding security headers) always runs.
432struct ConditionalComposedHandler {
433	middleware: Arc<dyn Middleware>,
434	next: Arc<dyn Handler>,
435}
436
437#[async_trait]
438impl Handler for ConditionalComposedHandler {
439	async fn handle(&self, request: Request) -> Result<Response> {
440		// Process the request through this middleware.
441		// Convert errors to responses so that outer middleware post-processing
442		// (e.g., security headers) always runs — matching Django's process_response
443		// semantics where the response hook executes for both success and error cases.
444		let response = match self.middleware.process(request, self.next.clone()).await {
445			Ok(response) => response,
446			Err(e) => Response::from(e),
447		};
448
449		Ok(response)
450	}
451}
452
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};

	// Mock handler for testing: always succeeds with a fixed body.
	struct MockHandler {
		response_body: String,
	}

	#[async_trait]
	impl Handler for MockHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::ok().with_body(self.response_body.clone()))
		}
	}

	// Mock middleware for testing: prefixes the downstream body with `prefix`.
	// Relies on the trait's default `should_continue` / `di_registrations`.
	struct MockMiddleware {
		prefix: String,
	}

	#[async_trait]
	impl Middleware for MockMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			// Call the next handler
			let response = next.handle(request).await?;

			// Modify the response
			let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
			let new_body = format!("{}{}", self.prefix, current_body);

			Ok(Response::ok().with_body(new_body))
		}
	}

	fn create_test_request() -> Request {
		Request::builder()
			.method(Method::GET)
			.uri("/")
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	#[tokio::test]
	async fn test_handler_basic() {
		let handler = MockHandler {
			response_body: "Hello".to_string(),
		};

		let request = create_test_request();
		let response = handler.handle(request).await.unwrap();

		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "Hello");
	}

	#[tokio::test]
	async fn test_middleware_basic() {
		let handler = Arc::new(MockHandler {
			response_body: "World".to_string(),
		});

		let middleware = MockMiddleware {
			prefix: "Hello, ".to_string(),
		};

		let request = create_test_request();
		let response = middleware.process(request, handler).await.unwrap();

		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "Hello, World");
	}

	#[tokio::test]
	async fn test_middleware_chain_empty() {
		let handler = Arc::new(MockHandler {
			response_body: "Test".to_string(),
		});

		let chain = MiddlewareChain::new(handler);

		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "Test");
	}

	#[tokio::test]
	async fn test_middleware_chain_single() {
		let handler = Arc::new(MockHandler {
			response_body: "Handler".to_string(),
		});

		let middleware1 = Arc::new(MockMiddleware {
			prefix: "MW1:".to_string(),
		});

		let chain = MiddlewareChain::new(handler).with_middleware(middleware1);

		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "MW1:Handler");
	}

	#[tokio::test]
	async fn test_middleware_chain_multiple() {
		let handler = Arc::new(MockHandler {
			response_body: "Data".to_string(),
		});

		let middleware1 = Arc::new(MockMiddleware {
			prefix: "M1:".to_string(),
		});

		let middleware2 = Arc::new(MockMiddleware {
			prefix: "M2:".to_string(),
		});

		let chain = MiddlewareChain::new(handler)
			.with_middleware(middleware1)
			.with_middleware(middleware2);

		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		let body = String::from_utf8(response.body.to_vec()).unwrap();
		// Middleware are applied in the order they were added
		assert_eq!(body, "M1:M2:Data");
	}

	#[tokio::test]
	async fn test_middleware_chain_add_middleware() {
		let handler = Arc::new(MockHandler {
			response_body: "Result".to_string(),
		});

		let middleware = Arc::new(MockMiddleware {
			prefix: "Prefix:".to_string(),
		});

		let mut chain = MiddlewareChain::new(handler);
		chain.add_middleware(middleware);

		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "Prefix:Result");
	}

	// Conditional middleware that only runs for /api/* paths
	struct ConditionalMiddleware {
		prefix: String,
	}

	#[async_trait]
	impl Middleware for ConditionalMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			let response = next.handle(request).await?;
			let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
			let new_body = format!("{}{}", self.prefix, current_body);
			Ok(Response::ok().with_body(new_body))
		}

		fn should_continue(&self, request: &Request) -> bool {
			request.uri.path().starts_with("/api/")
		}
	}

	#[tokio::test]
	async fn test_middleware_conditional_skip() {
		let handler = Arc::new(MockHandler {
			response_body: "Response".to_string(),
		});

		let conditional_mw = Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		});

		let chain = MiddlewareChain::new(handler).with_middleware(conditional_mw);

		// Test with /api/ path - middleware should run
		let api_request = Request::builder()
			.method(Method::GET)
			.uri("/api/users")
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap();
		let response = chain.handle(api_request).await.unwrap();
		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "API:Response");

		// Test with non-/api/ path - middleware should be skipped
		let non_api_request = Request::builder()
			.method(Method::GET)
			.uri("/public")
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap();
		let response = chain.handle(non_api_request).await.unwrap();
		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "Response"); // No prefix because middleware was skipped
	}

	// Middleware that returns early with stop_chain=true
	struct ShortCircuitMiddleware {
		should_stop: bool,
	}

	#[async_trait]
	impl Middleware for ShortCircuitMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			if self.should_stop {
				// Return early without calling next
				return Ok(Response::unauthorized()
					.with_body("Auth required")
					.with_stop_chain(true));
			}
			next.handle(request).await
		}
	}

	#[tokio::test]
	async fn test_middleware_short_circuit() {
		let handler = Arc::new(MockHandler {
			response_body: "Handler Response".to_string(),
		});

		let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: true });
		let normal_mw = Arc::new(MockMiddleware {
			prefix: "Normal:".to_string(),
		});

		let chain = MiddlewareChain::new(handler)
			.with_middleware(short_circuit_mw)
			.with_middleware(normal_mw);

		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		// Should get unauthorized response, not the handler response
		assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "Auth required");
	}

	#[tokio::test]
	async fn test_middleware_no_short_circuit() {
		let handler = Arc::new(MockHandler {
			response_body: "Handler Response".to_string(),
		});

		let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: false });
		let normal_mw = Arc::new(MockMiddleware {
			prefix: "Normal:".to_string(),
		});

		let chain = MiddlewareChain::new(handler)
			.with_middleware(short_circuit_mw)
			.with_middleware(normal_mw);

		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		// Should pass through to handler and apply normal middleware
		assert_eq!(response.status, hyper::StatusCode::OK);
		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "Normal:Handler Response");
	}

	#[tokio::test]
	async fn test_middleware_multiple_conditions() {
		let handler = Arc::new(MockHandler {
			response_body: "Base".to_string(),
		});

		// Only runs for /api/* paths
		let api_mw = Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		});

		// Always runs
		let always_mw = Arc::new(MockMiddleware {
			prefix: "Always:".to_string(),
		});

		let chain = MiddlewareChain::new(handler)
			.with_middleware(api_mw)
			.with_middleware(always_mw);

		// Test with /api/ path - both middleware should run
		let api_request = Request::builder()
			.method(Method::GET)
			.uri("/api/test")
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap();
		let response = chain.handle(api_request).await.unwrap();
		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "API:Always:Base");

		// Test with non-/api/ path - only always_mw should run
		let non_api_request = Request::builder()
			.method(Method::GET)
			.uri("/public")
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap();
		let response = chain.handle(non_api_request).await.unwrap();
		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "Always:Base"); // Only always_mw prefix
	}

	// Nothing here is asynchronous, so a plain synchronous test suffices.
	#[test]
	fn test_response_should_stop_chain() {
		let response = Response::ok();
		assert!(!response.should_stop_chain());

		let stopping_response = Response::unauthorized().with_stop_chain(true);
		assert!(stopping_response.should_stop_chain());
	}

	// --- ExcludeMiddleware tests ---

	fn create_request_with_path(path: &str) -> Request {
		Request::builder()
			.method(Method::GET)
			.uri(path)
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	#[rstest::rstest]
	#[case("/api/auth/login", true)]
	#[case("/api/auth/register", true)]
	#[case("/api/auth/", true)]
	#[case("/api/users", false)]
	#[case("/public", false)]
	fn test_exclude_middleware_prefix_match(#[case] path: &str, #[case] should_exclude: bool) {
		// Arrange
		let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
			prefix: "MW:".to_string(),
		});
		let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");

		// Act
		let request = create_request_with_path(path);
		let result = exclude_mw.should_continue(&request);

		// Assert
		assert_eq!(result, !should_exclude);
	}

	#[rstest::rstest]
	#[case("/health", true)]
	#[case("/health/check", false)]
	#[case("/healthz", false)]
	#[case("/api/health", false)]
	fn test_exclude_middleware_exact_match(#[case] path: &str, #[case] should_exclude: bool) {
		// Arrange
		let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
			prefix: "MW:".to_string(),
		});
		let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/health");

		// Act
		let request = create_request_with_path(path);
		let result = exclude_mw.should_continue(&request);

		// Assert
		assert_eq!(result, !should_exclude);
	}

	#[rstest::rstest]
	fn test_exclude_middleware_no_match_passes_through() {
		// Arrange
		let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
			prefix: "MW:".to_string(),
		});
		let exclude_mw = ExcludeMiddleware::new(inner)
			.add_exclusion("/api/auth/")
			.add_exclusion("/health");

		// Act
		let request = create_request_with_path("/api/users");
		let result = exclude_mw.should_continue(&request);

		// Assert
		assert!(result);
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_exclude_middleware_delegates_process() {
		// Arrange
		let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
			prefix: "INNER:".to_string(),
		});
		let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/excluded/");

		let handler = Arc::new(MockHandler {
			response_body: "Response".to_string(),
		});

		// Act
		let request = create_request_with_path("/api/test");
		let response = exclude_mw.process(request, handler).await.unwrap();

		// Assert
		let body = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body, "INNER:Response");
	}

	#[rstest::rstest]
	fn test_exclude_middleware_multiple_exclusions() {
		// Arrange
		let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
			prefix: "MW:".to_string(),
		});
		let mut exclude_mw = ExcludeMiddleware::new(inner);
		exclude_mw.add_exclusion_mut("/api/auth/");
		exclude_mw.add_exclusion_mut("/admin/");
		exclude_mw.add_exclusion_mut("/health");

		// Act & Assert
		assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
		assert!(!exclude_mw.should_continue(&create_request_with_path("/admin/dashboard")));
		assert!(!exclude_mw.should_continue(&create_request_with_path("/health")));
		assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
	}

	#[rstest::rstest]
	fn test_exclude_middleware_respects_inner_should_continue() {
		// Arrange - inner middleware that rejects non-/api/ paths
		let inner: Arc<dyn Middleware> = Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		});
		let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");

		// Act & Assert
		// Excluded path -> false (excluded by wrapper)
		assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
		// Non-excluded, but inner rejects non-/api/ -> false (inner's should_continue)
		assert!(!exclude_mw.should_continue(&create_request_with_path("/public")));
		// Non-excluded, inner accepts /api/ -> true
		assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_exclude_middleware_skipped_within_chain() {
		// Arrange: wrapped middleware is excluded for "/health" only
		let handler = Arc::new(MockHandler {
			response_body: "Response".to_string(),
		});
		let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
			prefix: "MW:".to_string(),
		});
		let chain = MiddlewareChain::new(handler)
			.with_middleware(Arc::new(ExcludeMiddleware::new(inner).add_exclusion("/health")));

		// Act
		let excluded = chain
			.handle(create_request_with_path("/health"))
			.await
			.unwrap();
		let included = chain
			.handle(create_request_with_path("/api/users"))
			.await
			.unwrap();

		// Assert
		assert_eq!(String::from_utf8(excluded.body.to_vec()).unwrap(), "Response");
		assert_eq!(
			String::from_utf8(included.body.to_vec()).unwrap(),
			"MW:Response"
		);
	}

	// ========================================================================
	// Error-to-response conversion tests (issue #3230)
	// ========================================================================

	/// Handler that always returns an error.
	struct NotFoundHandler;

	#[async_trait]
	impl Handler for NotFoundHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Err(reinhardt_core::exception::Error::NotFound(
				"not found".into(),
			))
		}
	}

	struct UnauthorizedHandler;

	#[async_trait]
	impl Handler for UnauthorizedHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Err(reinhardt_core::exception::Error::Authentication(
				"unauthorized".into(),
			))
		}
	}

	/// Middleware that adds a custom header to the response after calling next.
	struct HeaderAddingMiddleware {
		header_name: &'static str,
		header_value: &'static str,
	}

	#[async_trait]
	impl Middleware for HeaderAddingMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			let response = next.handle(request).await?;
			Ok(response.with_header(self.header_name, self.header_value))
		}
	}

	/// Middleware that always returns an error (simulates CSRF rejection).
	struct RejectingMiddleware;

	#[async_trait]
	impl Middleware for RejectingMiddleware {
		async fn process(&self, _request: Request, _next: Arc<dyn Handler>) -> Result<Response> {
			Err(reinhardt_core::exception::Error::Authorization(
				"CSRF check failed".into(),
			))
		}
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_chain_post_processing_runs_on_handler_error() {
		// Arrange: handler returns 404 error, outer middleware adds header
		let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);
		let mut chain = MiddlewareChain::new(handler);
		chain.add_middleware(Arc::new(HeaderAddingMiddleware {
			header_name: "X-Custom-Security",
			header_value: "applied",
		}));

		// Act
		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		// Assert: error converted to 404 response AND header is present
		assert_eq!(response.status, hyper::StatusCode::NOT_FOUND);
		assert_eq!(
			response
				.headers
				.get("X-Custom-Security")
				.map(|v| v.to_str().unwrap()),
			Some("applied")
		);
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_chain_post_processing_runs_on_middleware_error() {
		// Arrange: outer middleware adds header, inner middleware rejects.
		// First add = outermost in this framework's chain ordering.
		let handler = Arc::new(MockHandler {
			response_body: "OK".into(),
		});
		let mut chain = MiddlewareChain::new(handler);
		// Outer middleware adds a security header (post-processing)
		chain.add_middleware(Arc::new(HeaderAddingMiddleware {
			header_name: "X-Frame-Options",
			header_value: "DENY",
		}));
		// Inner middleware rejects the request
		chain.add_middleware(Arc::new(RejectingMiddleware));

		// Act
		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		// Assert: inner middleware error converted to 403, outer middleware header present
		assert_eq!(response.status, hyper::StatusCode::FORBIDDEN);
		assert_eq!(
			response
				.headers
				.get("X-Frame-Options")
				.map(|v| v.to_str().unwrap()),
			Some("DENY")
		);
	}

	/// Passthrough middleware that does not modify the response.
	struct PassthroughMiddleware;

	#[async_trait]
	impl Middleware for PassthroughMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			next.handle(request).await
		}
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_chain_error_preserves_correct_status_code() {
		// Arrange: handler returns 401 Unauthorized, with at least one middleware
		// so that ConditionalComposedHandler is used (empty chain bypasses it)
		let handler: Arc<dyn Handler> = Arc::new(UnauthorizedHandler);
		let mut chain = MiddlewareChain::new(handler);
		chain.add_middleware(Arc::new(PassthroughMiddleware));

		// Act
		let request = create_test_request();
		let response = chain.handle(request).await.unwrap();

		// Assert: status code correctly reflects the error
		assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
	}

	// ========================================================================
	// Default trait methods and chain edge cases
	// ========================================================================

	#[rstest::rstest]
	fn test_middleware_default_trait_methods() {
		// Arrange: MockMiddleware does not override the defaulted methods
		let mw = MockMiddleware {
			prefix: "MW:".to_string(),
		};
		let request = create_test_request();

		// Act & Assert: defaults always run and contribute no DI singletons
		assert!(mw.should_continue(&request));
		assert!(mw.di_registrations().is_empty());
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_empty_chain_returns_handler_error_unconverted() {
		// Arrange: no middleware, so the fast path calls the handler directly
		let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);
		let chain = MiddlewareChain::new(handler);

		// Act
		let result = chain.handle(create_test_request()).await;

		// Assert: the error propagates as `Err` rather than being converted
		assert!(result.is_err());
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_chain_all_middleware_skipped_still_converts_handler_error() {
		// Arrange: the only middleware opts out for "/" (not under /api/)
		let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);
		let chain = MiddlewareChain::new(handler).with_middleware(Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		}));

		// Act
		let response = chain.handle(create_test_request()).await.unwrap();

		// Assert: converted to a 404 response even though no middleware ran
		assert_eq!(response.status, hyper::StatusCode::NOT_FOUND);
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_outer_post_processing_runs_after_inner_short_circuit() {
		// Arrange: outer middleware adds a header, inner middleware short-circuits
		let handler = Arc::new(MockHandler {
			response_body: "Handler Response".to_string(),
		});
		let mut chain = MiddlewareChain::new(handler);
		chain.add_middleware(Arc::new(HeaderAddingMiddleware {
			header_name: "X-Frame-Options",
			header_value: "DENY",
		}));
		chain.add_middleware(Arc::new(ShortCircuitMiddleware { should_stop: true }));

		// Act
		let response = chain.handle(create_test_request()).await.unwrap();

		// Assert: short-circuit response is returned and still post-processed
		assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
		assert_eq!(
			response
				.headers
				.get("X-Frame-Options")
				.map(|v| v.to_str().unwrap()),
			Some("DENY")
		);
	}
}