Skip to main content

reinhardt_http/
middleware.rs

1//! Middleware and handler traits for HTTP request processing.
2//!
3//! This module provides the core abstractions for handling HTTP requests
4//! and composing middleware chains.
5//!
6//! ## Handler
7//!
8//! The `Handler` trait is the core abstraction for processing requests:
9//!
10//! ```rust
11//! use reinhardt_http::{Handler, Request, Response};
12//! use async_trait::async_trait;
13//!
14//! struct MyHandler;
15//!
16//! #[async_trait]
17//! impl Handler for MyHandler {
18//!     async fn handle(&self, request: Request) -> reinhardt_core::exception::Result<Response> {
19//!         Ok(Response::ok().with_body("Hello!"))
20//!     }
21//! }
22//! ```
23//!
24//! ## Middleware
25//!
26//! Middleware wraps handlers to add cross-cutting concerns:
27//!
28//! ```rust
29//! use reinhardt_http::{Handler, Middleware, Request, Response};
30//! use async_trait::async_trait;
31//! use std::sync::Arc;
32//!
33//! struct LoggingMiddleware;
34//!
35//! #[async_trait]
36//! impl Middleware for LoggingMiddleware {
37//!     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
38//!         println!("Request: {} {}", request.method, request.uri);
39//!         next.handle(request).await
40//!     }
41//! }
42//! ```
43
44use async_trait::async_trait;
45use reinhardt_core::exception::Result;
46use std::sync::Arc;
47
48use crate::{Request, Response};
49
50/// Handler trait for processing requests.
51///
52/// This is the core abstraction - all request handlers implement this trait.
53/// Handlers receive a request and produce a response or an error.
54#[async_trait]
55pub trait Handler: Send + Sync {
56	/// Handles an HTTP request and produces a response.
57	///
58	/// # Errors
59	///
60	/// Returns an error if the request cannot be processed.
61	async fn handle(&self, request: Request) -> Result<Response>;
62}
63
64/// Blanket implementation for `Arc<T>` where T: Handler.
65///
66/// This allows `Arc<dyn Handler>` to be used as a Handler,
67/// enabling shared ownership of handlers across threads.
68#[async_trait]
69impl<T: Handler + ?Sized> Handler for Arc<T> {
70	async fn handle(&self, request: Request) -> Result<Response> {
71		(**self).handle(request).await
72	}
73}
74
75/// Middleware trait for request/response processing.
76///
77/// Uses composition pattern instead of inheritance.
78/// Middleware can modify requests before passing to the next handler,
79/// or modify responses after the handler processes the request.
80#[async_trait]
81pub trait Middleware: Send + Sync {
82	/// Processes a request through this middleware.
83	///
84	/// # Arguments
85	///
86	/// * `request` - The incoming HTTP request
87	/// * `next` - The next handler in the chain to call
88	///
89	/// # Errors
90	///
91	/// Returns an error if the middleware or next handler fails.
92	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
93
94	/// Determines whether this middleware should be executed for the given request.
95	///
96	/// This method enables conditional execution of middleware, allowing the middleware
97	/// chain to skip unnecessary middleware based on request properties.
98	///
99	/// # Performance Benefits
100	///
101	/// By implementing this method, middleware chains can achieve O(k) complexity
102	/// instead of O(n), where k is the number of middleware that should run,
103	/// and k <= n (total middleware count).
104	///
105	/// # Common Use Cases
106	///
107	/// - Skip authentication middleware for public endpoints
108	/// - Skip compression middleware for already compressed responses
109	/// - Skip CORS middleware for same-origin requests
110	/// - Skip rate limiting for internal/admin requests
111	///
112	/// # Default Implementation
113	///
114	/// By default, returns `true` (always execute), maintaining backward compatibility.
115	fn should_continue(&self, _request: &Request) -> bool {
116		true
117	}
118}
119
120/// Middleware chain - composes multiple middleware into a single handler.
121///
122/// The chain processes requests through middleware in the order they were added,
123/// with optimizations for conditional execution and early termination.
124pub struct MiddlewareChain {
125	middlewares: Vec<Arc<dyn Middleware>>,
126	handler: Arc<dyn Handler>,
127}
128
129impl MiddlewareChain {
130	/// Creates a new middleware chain with the given handler.
131	///
132	/// # Examples
133	///
134	/// ```rust
135	/// use reinhardt_http::{MiddlewareChain, Handler, Request, Response};
136	/// use std::sync::Arc;
137	///
138	/// struct MyHandler;
139	///
140	/// #[async_trait::async_trait]
141	/// impl Handler for MyHandler {
142	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
143	///         Ok(Response::ok())
144	///     }
145	/// }
146	///
147	/// let handler = Arc::new(MyHandler);
148	/// let chain = MiddlewareChain::new(handler);
149	/// ```
150	pub fn new(handler: Arc<dyn Handler>) -> Self {
151		Self {
152			middlewares: Vec::new(),
153			handler,
154		}
155	}
156
157	/// Adds a middleware to the chain using builder pattern.
158	///
159	/// # Examples
160	///
161	/// ```rust
162	/// use reinhardt_http::{MiddlewareChain, Handler, Middleware, Request, Response};
163	/// use std::sync::Arc;
164	///
165	/// # struct MyHandler;
166	/// # struct MyMiddleware;
167	/// # #[async_trait::async_trait]
168	/// # impl Handler for MyHandler {
169	/// #     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
170	/// #         Ok(Response::ok())
171	/// #     }
172	/// # }
173	/// # #[async_trait::async_trait]
174	/// # impl Middleware for MyMiddleware {
175	/// #     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
176	/// #         next.handle(request).await
177	/// #     }
178	/// # }
179	/// let handler = Arc::new(MyHandler);
180	/// let middleware = Arc::new(MyMiddleware);
181	/// let chain = MiddlewareChain::new(handler)
182	///     .with_middleware(middleware);
183	/// ```
184	pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
185		self.middlewares.push(middleware);
186		self
187	}
188
189	/// Adds a middleware to the chain.
190	///
191	/// # Examples
192	///
193	/// ```rust
194	/// use reinhardt_http::{MiddlewareChain, Handler, Middleware, Request, Response};
195	/// use std::sync::Arc;
196	///
197	/// # struct MyHandler;
198	/// # struct MyMiddleware;
199	/// # #[async_trait::async_trait]
200	/// # impl Handler for MyHandler {
201	/// #     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
202	/// #         Ok(Response::ok())
203	/// #     }
204	/// # }
205	/// # #[async_trait::async_trait]
206	/// # impl Middleware for MyMiddleware {
207	/// #     async fn process(&self, request: Request, next: Arc<dyn Handler>) -> reinhardt_core::exception::Result<Response> {
208	/// #         next.handle(request).await
209	/// #     }
210	/// # }
211	/// let handler = Arc::new(MyHandler);
212	/// let middleware = Arc::new(MyMiddleware);
213	/// let mut chain = MiddlewareChain::new(handler);
214	/// chain.add_middleware(middleware);
215	/// ```
216	pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
217		self.middlewares.push(middleware);
218	}
219}
220
221#[async_trait]
222impl Handler for MiddlewareChain {
223	async fn handle(&self, request: Request) -> Result<Response> {
224		if self.middlewares.is_empty() {
225			return self.handler.handle(request).await;
226		}
227
228		// Build nested handler chain using composition with optimizations:
229		// 1. Conditional execution (skip middleware based on should_continue)
230		// 2. Short-circuiting (early return if response.should_stop_chain() is true)
231		//
232		// Performance improvements:
233		// - Condition check: O(1) per middleware
234		// - Skip unnecessary middleware: achieves O(k) where k <= n
235		// - Early return: stops processing on first stop_chain=true response
236		// Wrap the base handler to convert errors to responses, ensuring
237		// all middleware post-processing runs even for error responses.
238		let mut current_handler: Arc<dyn Handler> = Arc::new(ErrorToResponseHandler {
239			inner: self.handler.clone(),
240		});
241
242		// Filter middleware based on should_continue condition
243		// This achieves the O(k) optimization where k is the number of middleware that should run
244		let active_middlewares: Vec<_> = self
245			.middlewares
246			.iter()
247			.rev()
248			.filter(|mw| mw.should_continue(&request))
249			.collect();
250
251		for middleware in active_middlewares {
252			let mw = middleware.clone();
253			let handler = current_handler.clone();
254
255			current_handler = Arc::new(ConditionalComposedHandler {
256				middleware: mw,
257				next: handler,
258			});
259		}
260
261		current_handler.handle(request).await
262	}
263}
264
265/// Middleware wrapper that excludes specific URL paths from execution.
266///
267/// When a request matches an excluded path, the middleware is skipped
268/// and the request passes directly to the next handler in the chain.
269///
270/// Path matching follows Django URL conventions:
271/// - Paths ending with `/` are treated as **prefix matches**
272///   (e.g., `"/api/auth/"` excludes `"/api/auth/login"`, `"/api/auth/register"`)
273/// - Paths without trailing `/` require an **exact match**
274///   (e.g., `"/health"` excludes only `"/health"`, not `"/health/check"`)
275///
276/// This struct is typically not used directly. Instead, use the
277/// `exclude` methods on the `ServerRouter` or `UnifiedRouter` types
278/// from the `reinhardt_urls::routers` module for declarative
279/// route exclusion at the router level.
280///
281/// # Examples
282///
283/// ```rust
284/// use reinhardt_http::middleware::ExcludeMiddleware;
285/// use reinhardt_http::{Middleware, Request};
286/// use std::sync::Arc;
287///
288/// # struct MyMiddleware;
289/// # #[async_trait::async_trait]
290/// # impl Middleware for MyMiddleware {
291/// #     async fn process(
292/// #         &self,
293/// #         request: Request,
294/// #         next: Arc<dyn reinhardt_http::Handler>,
295/// #     ) -> reinhardt_core::exception::Result<reinhardt_http::Response> {
296/// #         next.handle(request).await
297/// #     }
298/// # }
299/// let inner: Arc<dyn Middleware> = Arc::new(MyMiddleware);
300/// let excluded = ExcludeMiddleware::new(inner)
301///     .add_exclusion("/api/auth/")   // prefix match
302///     .add_exclusion("/health");     // exact match
303/// ```
304pub struct ExcludeMiddleware {
305	inner: Arc<dyn Middleware>,
306	exclusions: Vec<String>,
307}
308
309impl ExcludeMiddleware {
310	/// Creates a new `ExcludeMiddleware` wrapping the given middleware.
311	pub fn new(inner: Arc<dyn Middleware>) -> Self {
312		Self {
313			inner,
314			exclusions: Vec::new(),
315		}
316	}
317
318	/// Adds an exclusion pattern (builder pattern, consumes self).
319	///
320	/// Paths ending with `/` are prefix matches; others are exact matches.
321	pub fn add_exclusion(mut self, pattern: &str) -> Self {
322		self.exclusions.push(pattern.to_string());
323		self
324	}
325
326	/// Adds an exclusion pattern (mutable reference).
327	///
328	/// Paths ending with `/` are prefix matches; others are exact matches.
329	pub fn add_exclusion_mut(&mut self, pattern: &str) {
330		self.exclusions.push(pattern.to_string());
331	}
332
333	/// Checks whether the given path matches any exclusion pattern.
334	fn is_excluded(&self, path: &str) -> bool {
335		self.exclusions.iter().any(|pattern| {
336			if pattern.ends_with('/') {
337				// Prefix match: excluded if path starts with the pattern
338				path.starts_with(pattern.as_str())
339			} else {
340				// Exact match: excluded only if path equals the pattern
341				path == pattern
342			}
343		})
344	}
345}
346
347#[async_trait]
348impl Middleware for ExcludeMiddleware {
349	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
350		self.inner.process(request, next).await
351	}
352
353	fn should_continue(&self, request: &Request) -> bool {
354		if self.is_excluded(request.uri.path()) {
355			return false;
356		}
357		self.inner.should_continue(request)
358	}
359}
360
361/// Internal handler wrapper that converts errors to HTTP responses.
362///
363/// Wraps the base handler so that middleware always receives `Ok(Response)`
364/// from `next.handle()`, even when the handler returns an error. This ensures
365/// middleware post-processing (e.g., adding security headers) runs for all
366/// responses, matching Django's `process_response` semantics.
367struct ErrorToResponseHandler {
368	inner: Arc<dyn Handler>,
369}
370
371#[async_trait]
372impl Handler for ErrorToResponseHandler {
373	async fn handle(&self, request: Request) -> Result<Response> {
374		match self.inner.handle(request).await {
375			Ok(response) => Ok(response),
376			Err(e) => Ok(Response::from(e)),
377		}
378	}
379}
380
381/// Internal handler that composes a single middleware with the next handler.
382///
383/// Converts middleware errors to HTTP responses so that outer middleware
384/// post-processing (e.g., adding security headers) always runs.
385struct ConditionalComposedHandler {
386	middleware: Arc<dyn Middleware>,
387	next: Arc<dyn Handler>,
388}
389
390#[async_trait]
391impl Handler for ConditionalComposedHandler {
392	async fn handle(&self, request: Request) -> Result<Response> {
393		// Process the request through this middleware.
394		// Convert errors to responses so that outer middleware post-processing
395		// (e.g., security headers) always runs — matching Django's process_response
396		// semantics where the response hook executes for both success and error cases.
397		let response = match self.middleware.process(request, self.next.clone()).await {
398			Ok(response) => response,
399			Err(e) => Response::from(e),
400		};
401
402		Ok(response)
403	}
404}
405
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};

	// ========================================================================
	// Shared helpers
	// ========================================================================

	/// Builds a bodiless HTTP/1.1 GET request for `path`.
	fn create_request_with_path(path: &str) -> Request {
		Request::builder()
			.method(Method::GET)
			.uri(path)
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Builds a bodiless HTTP/1.1 GET request for the root path `/`.
	fn create_test_request() -> Request {
		create_request_with_path("/")
	}

	/// Decodes a response body as UTF-8, panicking if it is not valid.
	fn body_text(response: &Response) -> String {
		String::from_utf8(response.body.to_vec()).unwrap()
	}

	/// Reads header `name` from `response` as a string, if present.
	fn header_text<'a>(response: &'a Response, name: &str) -> Option<&'a str> {
		response.headers.get(name).map(|v| v.to_str().unwrap())
	}

	/// Builds a fresh `200 OK` response whose body is `prefix` followed by the
	/// body of `response` (treated as empty if it is not valid UTF-8).
	fn prepend_to_body(prefix: &str, response: &Response) -> Response {
		let existing = String::from_utf8(response.body.to_vec()).unwrap_or_default();
		Response::ok().with_body(format!("{}{}", prefix, existing))
	}

	/// Shorthand for an `Arc`-wrapped [`MockHandler`] answering with `body`.
	fn mock_handler(body: &str) -> Arc<dyn Handler> {
		Arc::new(MockHandler {
			response_body: body.to_string(),
		})
	}

	/// Shorthand for an `Arc`-wrapped [`MockMiddleware`] adding `prefix`.
	fn mock_middleware(prefix: &str) -> Arc<dyn Middleware> {
		Arc::new(MockMiddleware {
			prefix: prefix.to_string(),
		})
	}

	// ========================================================================
	// Basic handler / middleware / chain behavior
	// ========================================================================

	/// Handler that always answers `200 OK` with a fixed body.
	struct MockHandler {
		response_body: String,
	}

	#[async_trait]
	impl Handler for MockHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::ok().with_body(self.response_body.clone()))
		}
	}

	/// Middleware that prefixes the downstream response body with `prefix`.
	struct MockMiddleware {
		prefix: String,
	}

	#[async_trait]
	impl Middleware for MockMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			let downstream = next.handle(request).await?;
			Ok(prepend_to_body(&self.prefix, &downstream))
		}
	}

	#[tokio::test]
	async fn test_handler_basic() {
		let handler = MockHandler {
			response_body: "Hello".to_string(),
		};

		let response = handler.handle(create_test_request()).await.unwrap();

		assert_eq!(body_text(&response), "Hello");
	}

	#[tokio::test]
	async fn test_middleware_basic() {
		let handler = mock_handler("World");
		let middleware = mock_middleware("Hello, ");

		let response = middleware
			.process(create_test_request(), handler)
			.await
			.unwrap();

		assert_eq!(body_text(&response), "Hello, World");
	}

	#[tokio::test]
	async fn test_middleware_chain_empty() {
		let chain = MiddlewareChain::new(mock_handler("Test"));

		let response = chain.handle(create_test_request()).await.unwrap();

		assert_eq!(body_text(&response), "Test");
	}

	#[tokio::test]
	async fn test_middleware_chain_single() {
		let chain = MiddlewareChain::new(mock_handler("Handler"))
			.with_middleware(mock_middleware("MW1:"));

		let response = chain.handle(create_test_request()).await.unwrap();

		assert_eq!(body_text(&response), "MW1:Handler");
	}

	#[tokio::test]
	async fn test_middleware_chain_multiple() {
		let chain = MiddlewareChain::new(mock_handler("Data"))
			.with_middleware(mock_middleware("M1:"))
			.with_middleware(mock_middleware("M2:"));

		let response = chain.handle(create_test_request()).await.unwrap();

		// Middleware are applied in the order they were added
		assert_eq!(body_text(&response), "M1:M2:Data");
	}

	#[tokio::test]
	async fn test_middleware_chain_add_middleware() {
		let mut chain = MiddlewareChain::new(mock_handler("Result"));
		chain.add_middleware(mock_middleware("Prefix:"));

		let response = chain.handle(create_test_request()).await.unwrap();

		assert_eq!(body_text(&response), "Prefix:Result");
	}

	// ========================================================================
	// Conditional execution and short-circuiting
	// ========================================================================

	/// Middleware that prefixes the body, but only opts in for `/api/` paths.
	struct ConditionalMiddleware {
		prefix: String,
	}

	#[async_trait]
	impl Middleware for ConditionalMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			let downstream = next.handle(request).await?;
			Ok(prepend_to_body(&self.prefix, &downstream))
		}

		fn should_continue(&self, request: &Request) -> bool {
			request.uri.path().starts_with("/api/")
		}
	}

	#[tokio::test]
	async fn test_middleware_conditional_skip() {
		let conditional_mw = Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		});
		let chain = MiddlewareChain::new(mock_handler("Response")).with_middleware(conditional_mw);

		// /api/ path: the middleware runs
		let response = chain
			.handle(create_request_with_path("/api/users"))
			.await
			.unwrap();
		assert_eq!(body_text(&response), "API:Response");

		// Non-/api/ path: the middleware is skipped, so no prefix
		let response = chain
			.handle(create_request_with_path("/public"))
			.await
			.unwrap();
		assert_eq!(body_text(&response), "Response");
	}

	/// Middleware that, when `should_stop` is set, answers 401 itself
	/// (flagged with `stop_chain`) instead of calling `next`.
	struct ShortCircuitMiddleware {
		should_stop: bool,
	}

	#[async_trait]
	impl Middleware for ShortCircuitMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			if !self.should_stop {
				return next.handle(request).await;
			}
			Ok(Response::unauthorized()
				.with_body("Auth required")
				.with_stop_chain(true))
		}
	}

	#[tokio::test]
	async fn test_middleware_short_circuit() {
		let chain = MiddlewareChain::new(mock_handler("Handler Response"))
			.with_middleware(Arc::new(ShortCircuitMiddleware { should_stop: true }))
			.with_middleware(mock_middleware("Normal:"));

		let response = chain.handle(create_test_request()).await.unwrap();

		// The short-circuit response wins over the handler's
		assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
		assert_eq!(body_text(&response), "Auth required");
	}

	#[tokio::test]
	async fn test_middleware_no_short_circuit() {
		let chain = MiddlewareChain::new(mock_handler("Handler Response"))
			.with_middleware(Arc::new(ShortCircuitMiddleware { should_stop: false }))
			.with_middleware(mock_middleware("Normal:"));

		let response = chain.handle(create_test_request()).await.unwrap();

		// Request reaches the handler and the normal middleware decorates it
		assert_eq!(response.status, hyper::StatusCode::OK);
		assert_eq!(body_text(&response), "Normal:Handler Response");
	}

	#[tokio::test]
	async fn test_middleware_multiple_conditions() {
		// api_mw only opts in for /api/* paths; always_mw always runs
		let api_mw = Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		});
		let always_mw = mock_middleware("Always:");

		let chain = MiddlewareChain::new(mock_handler("Base"))
			.with_middleware(api_mw)
			.with_middleware(always_mw);

		// /api/ path: both middleware run
		let response = chain
			.handle(create_request_with_path("/api/test"))
			.await
			.unwrap();
		assert_eq!(body_text(&response), "API:Always:Base");

		// Non-/api/ path: only always_mw runs
		let response = chain
			.handle(create_request_with_path("/public"))
			.await
			.unwrap();
		assert_eq!(body_text(&response), "Always:Base");
	}

	#[tokio::test]
	async fn test_response_should_stop_chain() {
		assert!(!Response::ok().should_stop_chain());
		assert!(Response::unauthorized()
			.with_stop_chain(true)
			.should_stop_chain());
	}

	// ========================================================================
	// ExcludeMiddleware
	// ========================================================================

	#[rstest::rstest]
	#[case("/api/auth/login", true)]
	#[case("/api/auth/register", true)]
	#[case("/api/auth/", true)]
	#[case("/api/users", false)]
	#[case("/public", false)]
	fn test_exclude_middleware_prefix_match(#[case] path: &str, #[case] should_exclude: bool) {
		// Arrange
		let exclude_mw = ExcludeMiddleware::new(mock_middleware("MW:")).add_exclusion("/api/auth/");

		// Act
		let continues = exclude_mw.should_continue(&create_request_with_path(path));

		// Assert
		assert_eq!(continues, !should_exclude);
	}

	#[rstest::rstest]
	#[case("/health", true)]
	#[case("/health/check", false)]
	#[case("/healthz", false)]
	#[case("/api/health", false)]
	fn test_exclude_middleware_exact_match(#[case] path: &str, #[case] should_exclude: bool) {
		// Arrange
		let exclude_mw = ExcludeMiddleware::new(mock_middleware("MW:")).add_exclusion("/health");

		// Act
		let continues = exclude_mw.should_continue(&create_request_with_path(path));

		// Assert
		assert_eq!(continues, !should_exclude);
	}

	#[rstest::rstest]
	fn test_exclude_middleware_no_match_passes_through() {
		// Arrange
		let exclude_mw = ExcludeMiddleware::new(mock_middleware("MW:"))
			.add_exclusion("/api/auth/")
			.add_exclusion("/health");

		// Act
		let continues = exclude_mw.should_continue(&create_request_with_path("/api/users"));

		// Assert
		assert!(continues);
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_exclude_middleware_delegates_process() {
		// Arrange
		let exclude_mw =
			ExcludeMiddleware::new(mock_middleware("INNER:")).add_exclusion("/excluded/");
		let handler = mock_handler("Response");

		// Act
		let response = exclude_mw
			.process(create_request_with_path("/api/test"), handler)
			.await
			.unwrap();

		// Assert
		assert_eq!(body_text(&response), "INNER:Response");
	}

	#[rstest::rstest]
	fn test_exclude_middleware_multiple_exclusions() {
		// Arrange
		let mut exclude_mw = ExcludeMiddleware::new(mock_middleware("MW:"));
		for pattern in ["/api/auth/", "/admin/", "/health"] {
			exclude_mw.add_exclusion_mut(pattern);
		}
		let continues = |path: &str| exclude_mw.should_continue(&create_request_with_path(path));

		// Act & Assert
		assert!(!continues("/api/auth/login"));
		assert!(!continues("/admin/dashboard"));
		assert!(!continues("/health"));
		assert!(continues("/api/users"));
	}

	#[rstest::rstest]
	fn test_exclude_middleware_respects_inner_should_continue() {
		// Arrange - inner middleware that rejects non-/api/ paths
		let inner: Arc<dyn Middleware> = Arc::new(ConditionalMiddleware {
			prefix: "API:".to_string(),
		});
		let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");
		let continues = |path: &str| exclude_mw.should_continue(&create_request_with_path(path));

		// Act & Assert
		// Excluded path -> false (excluded by wrapper)
		assert!(!continues("/api/auth/login"));
		// Non-excluded, but inner rejects non-/api/ -> false (inner's should_continue)
		assert!(!continues("/public"));
		// Non-excluded, inner accepts /api/ -> true
		assert!(continues("/api/users"));
	}

	// ========================================================================
	// Error-to-response conversion tests (issue #3230)
	// ========================================================================

	/// Handler that always fails with `NotFound`.
	struct NotFoundHandler;

	#[async_trait]
	impl Handler for NotFoundHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Err(reinhardt_core::exception::Error::NotFound(
				"not found".into(),
			))
		}
	}

	/// Handler that always fails with `Authentication`.
	struct UnauthorizedHandler;

	#[async_trait]
	impl Handler for UnauthorizedHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Err(reinhardt_core::exception::Error::Authentication(
				"unauthorized".into(),
			))
		}
	}

	/// Middleware that adds a fixed header to whatever `next` returns.
	struct HeaderAddingMiddleware {
		header_name: &'static str,
		header_value: &'static str,
	}

	#[async_trait]
	impl Middleware for HeaderAddingMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			let downstream = next.handle(request).await?;
			Ok(downstream.with_header(self.header_name, self.header_value))
		}
	}

	/// Middleware that always fails (simulates a CSRF rejection).
	struct RejectingMiddleware;

	#[async_trait]
	impl Middleware for RejectingMiddleware {
		async fn process(&self, _request: Request, _next: Arc<dyn Handler>) -> Result<Response> {
			Err(reinhardt_core::exception::Error::Authorization(
				"CSRF check failed".into(),
			))
		}
	}

	/// Middleware that forwards to `next` without touching the response.
	struct PassthroughMiddleware;

	#[async_trait]
	impl Middleware for PassthroughMiddleware {
		async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
			next.handle(request).await
		}
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_chain_post_processing_runs_on_handler_error() {
		// Arrange: handler returns 404 error, outer middleware adds header
		let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);
		let mut chain = MiddlewareChain::new(handler);
		chain.add_middleware(Arc::new(HeaderAddingMiddleware {
			header_name: "X-Custom-Security",
			header_value: "applied",
		}));

		// Act
		let response = chain.handle(create_test_request()).await.unwrap();

		// Assert: error converted to 404 response AND header is present
		assert_eq!(response.status, hyper::StatusCode::NOT_FOUND);
		assert_eq!(header_text(&response, "X-Custom-Security"), Some("applied"));
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_chain_post_processing_runs_on_middleware_error() {
		// Arrange: outer middleware adds header, inner middleware rejects.
		// First add = outermost in this framework's chain ordering.
		let mut chain = MiddlewareChain::new(mock_handler("OK"));
		// Outer middleware adds a security header (post-processing)
		chain.add_middleware(Arc::new(HeaderAddingMiddleware {
			header_name: "X-Frame-Options",
			header_value: "DENY",
		}));
		// Inner middleware rejects the request
		chain.add_middleware(Arc::new(RejectingMiddleware));

		// Act
		let response = chain.handle(create_test_request()).await.unwrap();

		// Assert: inner middleware error converted to 403, outer middleware header present
		assert_eq!(response.status, hyper::StatusCode::FORBIDDEN);
		assert_eq!(header_text(&response, "X-Frame-Options"), Some("DENY"));
	}

	#[rstest::rstest]
	#[tokio::test]
	async fn test_chain_error_preserves_correct_status_code() {
		// Arrange: handler returns 401 Unauthorized, with at least one middleware
		// so that ConditionalComposedHandler is used (empty chain bypasses it)
		let handler: Arc<dyn Handler> = Arc::new(UnauthorizedHandler);
		let mut chain = MiddlewareChain::new(handler);
		chain.add_middleware(Arc::new(PassthroughMiddleware));

		// Act
		let response = chain.handle(create_test_request()).await.unwrap();

		// Assert: status code correctly reflects the error
		assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
	}
}