Skip to main content

fastapi_core/
middleware.rs

1//! Middleware abstraction for request/response processing.
2//!
3//! This module provides a flexible middleware system that allows:
4//! - Pre-processing requests before handlers run
5//! - Post-processing responses after handlers complete
6//! - Short-circuiting to return early without calling handlers
7//! - Composable middleware stacks with defined ordering
8//!
9//! # Design Philosophy
10//!
11//! The middleware system follows these principles:
12//! - **Zero-cost when empty**: No overhead if no middleware is configured
13//! - **Async-native**: All hooks are async for I/O operations
14//! - **Cancel-aware**: Integrates with asupersync's cancellation
15//! - **Composable**: Middleware can be stacked and layered
16//!
17//! # Ordering Semantics
18//!
19//! Middleware executes in a specific order:
20//! 1. `before` hooks run in **registration order** (first registered, first run)
21//! 2. Handler executes
22//! 3. `after` hooks run in **reverse order** (last registered, first run)
23//!
24//! This creates an "onion" model where the first middleware wraps everything:
25//!
26//! ```text
27//! Request → MW1.before → MW2.before → MW3.before → Handler
28//!                                                     ↓
29//! Response ← MW1.after ← MW2.after ← MW3.after ← Response
30//! ```
31//!
32//! # Example
33//!
34//! ```ignore
35//! use fastapi_core::middleware::{Middleware, ControlFlow};
36//! use fastapi_core::{Request, Response, RequestContext};
37//!
38//! struct LoggingMiddleware;
39//!
40//! impl Middleware for LoggingMiddleware {
41//!     async fn before(&self, ctx: &RequestContext, req: &Request) -> ControlFlow {
42//!         println!("Request: {} {}", req.method(), req.path());
43//!         ControlFlow::Continue
44//!     }
45//!
46//!     async fn after(&self, _ctx: &RequestContext, _req: &Request, resp: Response) -> Response {
47//!         println!("Response: {}", resp.status().as_u16());
48//!         resp
49//!     }
50//! }
51//! ```
52
53use std::collections::HashSet;
54use std::future::Future;
55use std::ops::ControlFlow as StdControlFlow;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::Instant;
59
60use crate::context::RequestContext;
61use crate::dependency::DependencyOverrides;
62use crate::logging::{LogConfig, RequestLogger};
63use crate::request::{Body, Request};
64use crate::response::Response;
65
/// Heap-allocated, type-erased future returned by middleware hooks and handlers.
///
/// The future is `Send`, so it can be moved across threads, and it may borrow
/// data that lives for at least `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
68
69/// Control flow for middleware `before` hooks.
70///
71/// Determines whether request processing should continue to the handler
72/// or short-circuit with an early response.
73#[derive(Debug)]
74pub enum ControlFlow {
75    /// Continue processing - call the next middleware or handler.
76    Continue,
77    /// Short-circuit - return this response immediately without calling the handler.
78    ///
79    /// Subsequent `before` hooks and the handler will NOT run.
80    /// However, `after` hooks for middleware that already ran their `before` WILL run.
81    Break(Response),
82}
83
84impl ControlFlow {
85    /// Returns `true` if this is `Continue`.
86    #[must_use]
87    pub fn is_continue(&self) -> bool {
88        matches!(self, Self::Continue)
89    }
90
91    /// Returns `true` if this is `Break`.
92    #[must_use]
93    pub fn is_break(&self) -> bool {
94        matches!(self, Self::Break(_))
95    }
96}
97
98impl From<ControlFlow> for StdControlFlow<Response, ()> {
99    fn from(cf: ControlFlow) -> Self {
100        match cf {
101            ControlFlow::Continue => StdControlFlow::Continue(()),
102            ControlFlow::Break(r) => StdControlFlow::Break(r),
103        }
104    }
105}
106
107/// The core middleware trait.
108///
109/// Middleware wraps request handling with pre-processing and post-processing hooks.
110/// Implementations must be thread-safe (`Send + Sync`) as middleware may be shared
111/// across concurrent requests.
112///
113/// # Implementation Guide
114///
115/// - **`before`**: Inspect/modify the request, optionally short-circuit
116/// - **`after`**: Inspect/modify the response
117///
118/// Both methods have default implementations that do nothing, so you can
119/// implement only what you need.
120///
121/// # Cancel-Safety
122///
123/// Middleware should check `ctx.checkpoint()` for long operations to support
124/// graceful cancellation when clients disconnect or timeouts occur.
125///
126/// # Example: Request Timing
127///
128/// ```ignore
129/// use std::time::Instant;
130/// use fastapi_core::middleware::{Middleware, ControlFlow};
131///
132/// struct TimingMiddleware;
133///
134/// impl Middleware for TimingMiddleware {
135///     async fn before(&self, ctx: &RequestContext, req: &mut Request) -> ControlFlow {
136///         // Store start time in request extensions (future feature)
137///         ControlFlow::Continue
138///     }
139///
140///     async fn after(&self, _ctx: &RequestContext, _req: &Request, mut resp: Response) -> Response {
141///         // Add timing header
142///         resp = resp.header("X-Response-Time", b"42ms".to_vec());
143///         resp
144///     }
145/// }
146/// ```
147pub trait Middleware: Send + Sync {
148    /// Called before the handler executes.
149    ///
150    /// # Parameters
151    ///
152    /// - `ctx`: Request context with cancellation support
153    /// - `req`: Mutable request that can be inspected or modified
154    ///
155    /// # Returns
156    ///
157    /// - `ControlFlow::Continue` to proceed to the next middleware/handler
158    /// - `ControlFlow::Break(response)` to short-circuit and return immediately
159    ///
160    /// # Default Implementation
161    ///
162    /// Returns `ControlFlow::Continue` (no-op).
163    fn before<'a>(
164        &'a self,
165        _ctx: &'a RequestContext,
166        _req: &'a mut Request,
167    ) -> BoxFuture<'a, ControlFlow> {
168        Box::pin(async { ControlFlow::Continue })
169    }
170
171    /// Called after the handler executes.
172    ///
173    /// # Parameters
174    ///
175    /// - `ctx`: Request context with cancellation support
176    /// - `req`: The request (read-only at this point)
177    /// - `response`: The response from the handler or previous `after` hooks
178    ///
179    /// # Returns
180    ///
181    /// The response to pass to the next `after` hook or to return to the client.
182    ///
183    /// # Default Implementation
184    ///
185    /// Returns the response unchanged (no-op).
186    fn after<'a>(
187        &'a self,
188        _ctx: &'a RequestContext,
189        _req: &'a Request,
190        response: Response,
191    ) -> BoxFuture<'a, Response> {
192        Box::pin(async move { response })
193    }
194
195    /// Returns the middleware name for debugging and logging.
196    ///
197    /// Override this to provide a meaningful name for your middleware.
198    fn name(&self) -> &'static str {
199        std::any::type_name::<Self>()
200    }
201}
202
203/// A handler that processes requests into responses.
204///
205/// This trait abstracts over handler functions, allowing middleware to wrap
206/// any type that can handle requests.
207pub trait Handler: Send + Sync {
208    /// Process a request and return a response.
209    fn call<'a>(&'a self, ctx: &'a RequestContext, req: &'a mut Request)
210    -> BoxFuture<'a, Response>;
211
212    /// Optional dependency overrides to apply when building request contexts.
213    ///
214    /// Default implementation returns `None`, which means no overrides.
215    fn dependency_overrides(&self) -> Option<Arc<DependencyOverrides>> {
216        None
217    }
218}
219
220/// Implement Handler for async functions.
221///
222/// This allows any async function with the signature
223/// `async fn(&RequestContext, &mut Request) -> Response` to be used as a handler.
224impl<F, Fut> Handler for F
225where
226    F: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync,
227    Fut: Future<Output = Response> + Send + 'static,
228{
229    fn call<'a>(
230        &'a self,
231        ctx: &'a RequestContext,
232        req: &'a mut Request,
233    ) -> BoxFuture<'a, Response> {
234        let fut = self(ctx, req);
235        Box::pin(fut)
236    }
237}
238
239/// A stack of middleware that wraps a handler.
240///
241/// The stack executes middleware in order:
242/// 1. `before` hooks run first-to-last (registration order)
243/// 2. Handler executes (if no middleware short-circuited)
244/// 3. `after` hooks run last-to-first (reverse order)
245///
246/// # Example
247///
248/// ```ignore
249/// let mut stack = MiddlewareStack::new();
250/// stack.push(LoggingMiddleware);
251/// stack.push(AuthMiddleware);
252/// stack.push(CorsMiddleware);
253///
254/// let response = stack.execute(&handler, &ctx, &mut request).await;
255/// ```
256#[derive(Default)]
257pub struct MiddlewareStack {
258    middleware: Vec<Arc<dyn Middleware>>,
259}
260
261impl MiddlewareStack {
262    /// Creates an empty middleware stack.
263    #[must_use]
264    pub fn new() -> Self {
265        Self {
266            middleware: Vec::new(),
267        }
268    }
269
270    /// Creates a middleware stack with pre-allocated capacity.
271    #[must_use]
272    pub fn with_capacity(capacity: usize) -> Self {
273        Self {
274            middleware: Vec::with_capacity(capacity),
275        }
276    }
277
278    /// Adds middleware to the end of the stack.
279    ///
280    /// Middleware added first will have its `before` run first and `after` run last.
281    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
282        self.middleware.push(Arc::new(middleware));
283    }
284
285    /// Adds middleware wrapped in an Arc.
286    ///
287    /// Useful for sharing middleware across multiple stacks.
288    pub fn push_arc(&mut self, middleware: Arc<dyn Middleware>) {
289        self.middleware.push(middleware);
290    }
291
292    /// Returns the number of middleware in the stack.
293    #[must_use]
294    pub fn len(&self) -> usize {
295        self.middleware.len()
296    }
297
298    /// Returns `true` if the stack is empty.
299    #[must_use]
300    pub fn is_empty(&self) -> bool {
301        self.middleware.is_empty()
302    }
303
304    /// Executes the middleware stack with the given handler.
305    ///
306    /// # Execution Order
307    ///
308    /// 1. Each middleware's `before` hook runs in order
309    /// 2. If any `before` returns `Break`, skip remaining middleware and handler
310    /// 3. Handler executes
311    /// 4. Each middleware's `after` hook runs in reverse order
312    ///
313    /// # Short-Circuit Behavior
314    ///
315    /// If middleware N calls `Break(response)`:
316    /// - Middleware N+1..end `before` hooks do NOT run
317    /// - Handler does NOT run
318    /// - Middleware 0..N `after` hooks STILL run (in reverse: N, N-1, ..., 0)
319    ///
320    /// This ensures cleanup middleware (like timing or logging) always runs.
321    pub async fn execute<H: Handler>(
322        &self,
323        handler: &H,
324        ctx: &RequestContext,
325        req: &mut Request,
326    ) -> Response {
327        // Track which middleware ran their `before` hook
328        let mut ran_before_count = 0;
329
330        // Run before hooks in order
331        for mw in &self.middleware {
332            let _ = ctx.checkpoint();
333            match mw.before(ctx, req).await {
334                ControlFlow::Continue => {
335                    ran_before_count += 1;
336                }
337                ControlFlow::Break(response) => {
338                    // Short-circuit: run after hooks for middleware that already ran
339                    return self
340                        .run_after_hooks(ctx, req, response, ran_before_count)
341                        .await;
342                }
343            }
344        }
345
346        // All before hooks passed, call the handler
347        let _ = ctx.checkpoint();
348        let response = handler.call(ctx, req).await;
349
350        // Run after hooks in reverse order
351        self.run_after_hooks(ctx, req, response, ran_before_count)
352            .await
353    }
354
355    /// Runs after hooks for middleware that ran their before hook.
356    async fn run_after_hooks(
357        &self,
358        ctx: &RequestContext,
359        req: &Request,
360        mut response: Response,
361        count: usize,
362    ) -> Response {
363        // Run in reverse order (last middleware's after runs first)
364        for mw in self.middleware[..count].iter().rev() {
365            let _ = ctx.checkpoint();
366            response = mw.after(ctx, req, response).await;
367        }
368        response
369    }
370}
371
372/// A layer that can wrap handlers with middleware.
373///
374/// This provides a more functional composition style similar to Tower's Layer trait.
375///
376/// # Example
377///
378/// ```ignore
379/// let layer = Layer::new(LoggingMiddleware);
380/// let wrapped = layer.wrap(my_handler);
381/// ```
382pub struct Layer<M> {
383    middleware: M,
384}
385
386impl<M: Middleware + Clone> Layer<M> {
387    /// Creates a new layer with the given middleware.
388    pub fn new(middleware: M) -> Self {
389        Self { middleware }
390    }
391
392    /// Wraps a handler with this layer's middleware.
393    pub fn wrap<H: Handler>(&self, handler: H) -> Layered<M, H> {
394        Layered {
395            middleware: self.middleware.clone(),
396            inner: handler,
397        }
398    }
399}
400
401/// A handler wrapped with middleware via a Layer.
402pub struct Layered<M, H> {
403    middleware: M,
404    inner: H,
405}
406
407impl<M: Middleware, H: Handler> Handler for Layered<M, H> {
408    fn call<'a>(
409        &'a self,
410        ctx: &'a RequestContext,
411        req: &'a mut Request,
412    ) -> BoxFuture<'a, Response> {
413        Box::pin(async move {
414            // Run before hook
415            let _ = ctx.checkpoint();
416            match self.middleware.before(ctx, req).await {
417                ControlFlow::Continue => {
418                    // Call inner handler
419                    let _ = ctx.checkpoint();
420                    let response = self.inner.call(ctx, req).await;
421                    // Run after hook
422                    let _ = ctx.checkpoint();
423                    self.middleware.after(ctx, req, response).await
424                }
425                ControlFlow::Break(response) => {
426                    // Short-circuit: still run after for this middleware
427                    let _ = ctx.checkpoint();
428                    self.middleware.after(ctx, req, response).await
429                }
430            }
431        })
432    }
433}
434
435// ============================================================================
436// Common Middleware Implementations
437// ============================================================================
438
439/// No-op middleware that does nothing.
440///
441/// Useful as a placeholder or for testing.
442#[derive(Debug, Clone, Copy, Default)]
443pub struct NoopMiddleware;
444
445impl Middleware for NoopMiddleware {
446    fn name(&self) -> &'static str {
447        "Noop"
448    }
449}
450
/// Middleware that appends a fixed header to every response.
///
/// # Example
///
/// ```ignore
/// // Advertise the framework on every response.
/// let mw = AddResponseHeader::new("X-Powered-By", "fastapi_rust");
/// stack.push(mw);
/// ```
#[derive(Debug, Clone)]
pub struct AddResponseHeader {
    /// Header name to set.
    name: String,
    /// Raw header value bytes.
    value: Vec<u8>,
}

impl AddResponseHeader {
    /// Builds the middleware for header `name` with value `value`.
    pub fn new(name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let name: String = name.into();
        let value: Vec<u8> = value.into();
        Self { name, value }
    }
}
475
476impl Middleware for AddResponseHeader {
477    fn after<'a>(
478        &'a self,
479        _ctx: &'a RequestContext,
480        _req: &'a Request,
481        response: Response,
482    ) -> BoxFuture<'a, Response> {
483        let name = self.name.clone();
484        let value = self.value.clone();
485        Box::pin(async move { response.header(name, value) })
486    }
487
488    fn name(&self) -> &'static str {
489        "AddResponseHeader"
490    }
491}
492
/// Middleware that rejects requests lacking a particular header.
///
/// A missing header short-circuits with `400 Bad Request` and a plain-text body
/// naming the header.
///
/// # Example
///
/// ```ignore
/// // Insist on an API key header.
/// let mw = RequireHeader::new("X-Api-Key");
/// stack.push(mw);
/// ```
#[derive(Debug, Clone)]
pub struct RequireHeader {
    /// Name of the header that must be present.
    name: String,
}

impl RequireHeader {
    /// Builds the middleware requiring header `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self { name }
    }
}
515
516impl Middleware for RequireHeader {
517    fn before<'a>(
518        &'a self,
519        _ctx: &'a RequestContext,
520        req: &'a mut Request,
521    ) -> BoxFuture<'a, ControlFlow> {
522        let has_header = req.headers().get(&self.name).is_some();
523        let name = self.name.clone();
524        Box::pin(async move {
525            if has_header {
526                ControlFlow::Continue
527            } else {
528                let body = format!("Missing required header: {name}");
529                ControlFlow::Break(
530                    Response::with_status(crate::response::StatusCode::BAD_REQUEST)
531                        .header("content-type", b"text/plain".to_vec())
532                        .body(crate::response::ResponseBody::Bytes(body.into_bytes())),
533                )
534            }
535        })
536    }
537
538    fn name(&self) -> &'static str {
539        "RequireHeader"
540    }
541}
542
/// Middleware that only admits requests whose path starts with a prefix.
///
/// Any other request is short-circuited with `404 Not Found`.
///
/// Matching is a plain string prefix test, not a path-segment test. For example,
/// the prefix `"/api"` also admits `"/apiv2"`; use `"/api/"` if that matters.
///
/// # Example
///
/// ```ignore
/// // Only serve requests under /api
/// let mw = PathPrefixFilter::new("/api");
/// stack.push(mw);
/// ```
#[derive(Debug, Clone)]
pub struct PathPrefixFilter {
    /// Required leading portion of the request path.
    prefix: String,
}

impl PathPrefixFilter {
    /// Builds the filter admitting only paths that start with `prefix`.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        Self { prefix }
    }
}
568
569impl Middleware for PathPrefixFilter {
570    fn before<'a>(
571        &'a self,
572        _ctx: &'a RequestContext,
573        req: &'a mut Request,
574    ) -> BoxFuture<'a, ControlFlow> {
575        let path_matches = req.path().starts_with(&self.prefix);
576        Box::pin(async move {
577            if path_matches {
578                ControlFlow::Continue
579            } else {
580                ControlFlow::Break(Response::with_status(
581                    crate::response::StatusCode::NOT_FOUND,
582                ))
583            }
584        })
585    }
586
587    fn name(&self) -> &'static str {
588        "PathPrefixFilter"
589    }
590}
591
592/// Middleware that sets response status code based on a condition.
593///
594/// This is useful for implementing health checks or conditional responses.
595#[derive(Debug, Clone)]
596pub struct ConditionalStatus<F>
597where
598    F: Fn(&Request) -> bool + Send + Sync,
599{
600    condition: F,
601    status_if_true: crate::response::StatusCode,
602    status_if_false: crate::response::StatusCode,
603}
604
605impl<F> ConditionalStatus<F>
606where
607    F: Fn(&Request) -> bool + Send + Sync,
608{
609    /// Creates a new conditional status middleware.
610    ///
611    /// If the condition returns true, the response gets `status_if_true`.
612    /// Otherwise, it gets `status_if_false`.
613    pub fn new(
614        condition: F,
615        status_if_true: crate::response::StatusCode,
616        status_if_false: crate::response::StatusCode,
617    ) -> Self {
618        Self {
619            condition,
620            status_if_true,
621            status_if_false,
622        }
623    }
624}
625
626impl<F> Middleware for ConditionalStatus<F>
627where
628    F: Fn(&Request) -> bool + Send + Sync,
629{
630    fn after<'a>(
631        &'a self,
632        _ctx: &'a RequestContext,
633        req: &'a Request,
634        response: Response,
635    ) -> BoxFuture<'a, Response> {
636        let matches = (self.condition)(req);
637        let status = if matches {
638            self.status_if_true
639        } else {
640            self.status_if_false
641        };
642        Box::pin(async move { Response::with_status(status).body(response.body_ref().into()) })
643    }
644
645    fn name(&self) -> &'static str {
646        "ConditionalStatus"
647    }
648}
649
650// ============================================================================
651// CORS Middleware
652// ============================================================================
653
654/// Origin matching pattern for CORS.
655#[derive(Debug, Clone)]
656pub enum OriginPattern {
657    /// Allow any origin.
658    Any,
659    /// Exact match.
660    Exact(String),
661    /// Wildcard match (supports `*`).
662    Wildcard(String),
663    /// Simple regex match (supports `^`, `$`, `.`, `*`).
664    Regex(String),
665}
666
667impl OriginPattern {
668    fn matches(&self, origin: &str) -> bool {
669        match self {
670            Self::Any => true,
671            Self::Exact(value) => value == origin,
672            Self::Wildcard(pattern) => wildcard_match(pattern, origin),
673            Self::Regex(pattern) => regex_match(pattern, origin),
674        }
675    }
676}
677
678/// Cross-Origin Resource Sharing (CORS) configuration.
679///
680/// Controls which origins, methods, and headers are allowed for
681/// cross-origin requests. By default, no origins are allowed.
682///
683/// # Defaults
684///
685/// | Setting | Default |
686/// |---------|---------|
687/// | `allow_any_origin` | `false` |
688/// | `allow_credentials` | `false` |
689/// | `allowed_methods` | GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD |
690/// | `allowed_headers` | none |
691/// | `expose_headers` | none |
692/// | `max_age` | none |
693///
694/// # Security: Credentials and Wildcards
695///
696/// According to the CORS specification (Fetch Standard), when credentials
697/// mode is enabled (`allow_credentials: true`), the following headers
698/// **cannot** use the `*` wildcard value:
699///
700/// - `Access-Control-Allow-Origin` (must echo the specific origin)
701/// - `Access-Control-Allow-Headers` (must list specific headers)
702/// - `Access-Control-Allow-Methods` (must list specific methods)
703/// - `Access-Control-Expose-Headers` (must list specific headers)
704///
705/// This implementation enforces this: when `allow_credentials(true)` is
706/// combined with `allow_any_origin()`, the response echoes back the
707/// specific request origin instead of returning `*`.
708///
709/// # Example
710///
711/// ```ignore
712/// use fastapi_core::Cors;
713///
714/// // Secure: specific origin with credentials
715/// let cors = Cors::new()
716///     .allow_origin("https://myapp.example.com")
717///     .allow_credentials(true)
718///     .expose_headers(["X-Request-Id"]);
719///
720/// // Also secure: any origin echoes back specific origin when credentials enabled
721/// // (not recommended - prefer explicit origins for security)
722/// let cors = Cors::new()
723///     .allow_any_origin()
724///     .allow_credentials(true);
725/// ```
726#[derive(Debug, Clone)]
727pub struct CorsConfig {
728    allow_any_origin: bool,
729    allow_credentials: bool,
730    allowed_methods: Vec<crate::request::Method>,
731    allowed_headers: Vec<String>,
732    expose_headers: Vec<String>,
733    max_age: Option<u32>,
734    origins: Vec<OriginPattern>,
735}
736
737impl Default for CorsConfig {
738    fn default() -> Self {
739        Self {
740            allow_any_origin: false,
741            allow_credentials: false,
742            allowed_methods: vec![
743                crate::request::Method::Get,
744                crate::request::Method::Post,
745                crate::request::Method::Put,
746                crate::request::Method::Patch,
747                crate::request::Method::Delete,
748                crate::request::Method::Options,
749                crate::request::Method::Head,
750            ],
751            allowed_headers: Vec::new(),
752            expose_headers: Vec::new(),
753            max_age: None,
754            origins: Vec::new(),
755        }
756    }
757}
758
759/// CORS middleware.
760#[derive(Debug, Clone)]
761pub struct Cors {
762    config: CorsConfig,
763}
764
765impl Cors {
766    /// Create a new CORS middleware with default configuration.
767    #[must_use]
768    pub fn new() -> Self {
769        Self {
770            config: CorsConfig::default(),
771        }
772    }
773
774    /// Replace the configuration entirely.
775    #[must_use]
776    pub fn config(mut self, config: CorsConfig) -> Self {
777        self.config = config;
778        self
779    }
780
781    /// Allow any origin.
782    #[must_use]
783    pub fn allow_any_origin(mut self) -> Self {
784        self.config.allow_any_origin = true;
785        self
786    }
787
788    /// Allow a single exact origin.
789    #[must_use]
790    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
791        self.config
792            .origins
793            .push(OriginPattern::Exact(origin.into()));
794        self
795    }
796
797    /// Allow a wildcard origin pattern (supports `*`).
798    #[must_use]
799    pub fn allow_origin_wildcard(mut self, pattern: impl Into<String>) -> Self {
800        self.config
801            .origins
802            .push(OriginPattern::Wildcard(pattern.into()));
803        self
804    }
805
806    /// Allow a simple regex origin pattern (supports `^`, `$`, `.`, `*`).
807    #[must_use]
808    pub fn allow_origin_regex(mut self, pattern: impl Into<String>) -> Self {
809        self.config
810            .origins
811            .push(OriginPattern::Regex(pattern.into()));
812        self
813    }
814
815    /// Allow credentials for CORS responses.
816    #[must_use]
817    pub fn allow_credentials(mut self, allow: bool) -> Self {
818        self.config.allow_credentials = allow;
819        self
820    }
821
822    /// Override allowed HTTP methods for preflight.
823    #[must_use]
824    pub fn allow_methods<I>(mut self, methods: I) -> Self
825    where
826        I: IntoIterator<Item = crate::request::Method>,
827    {
828        self.config.allowed_methods = methods.into_iter().collect();
829        self
830    }
831
832    /// Override allowed headers for preflight.
833    #[must_use]
834    pub fn allow_headers<I, S>(mut self, headers: I) -> Self
835    where
836        I: IntoIterator<Item = S>,
837        S: Into<String>,
838    {
839        self.config.allowed_headers = headers.into_iter().map(Into::into).collect();
840        self
841    }
842
843    /// Add exposed headers for responses.
844    #[must_use]
845    pub fn expose_headers<I, S>(mut self, headers: I) -> Self
846    where
847        I: IntoIterator<Item = S>,
848        S: Into<String>,
849    {
850        self.config.expose_headers = headers.into_iter().map(Into::into).collect();
851        self
852    }
853
854    /// Set the preflight max-age in seconds.
855    #[must_use]
856    pub fn max_age(mut self, seconds: u32) -> Self {
857        self.config.max_age = Some(seconds);
858        self
859    }
860
861    fn is_origin_allowed(&self, origin: &str) -> bool {
862        if self.config.allow_any_origin {
863            return true;
864        }
865        self.config
866            .origins
867            .iter()
868            .any(|pattern| pattern.matches(origin))
869    }
870
871    fn allow_origin_value(&self, origin: &str) -> Option<String> {
872        if !self.is_origin_allowed(origin) {
873            return None;
874        }
875        if self.config.allow_any_origin && !self.config.allow_credentials {
876            Some("*".to_string())
877        } else {
878            Some(origin.to_string())
879        }
880    }
881
882    fn allow_methods_value(&self) -> String {
883        self.config
884            .allowed_methods
885            .iter()
886            .map(|method| method.as_str())
887            .collect::<Vec<_>>()
888            .join(", ")
889    }
890
891    fn allow_headers_value(&self, request: &Request) -> Option<String> {
892        if !self.config.allowed_headers.is_empty() {
893            return Some(self.config.allowed_headers.join(", "));
894        }
895
896        request
897            .headers()
898            .get("access-control-request-headers")
899            .and_then(|value| std::str::from_utf8(value).ok())
900            .map(ToString::to_string)
901    }
902
903    fn apply_common_headers(&self, mut response: Response, origin: &str) -> Response {
904        if let Some(allow_origin) = self.allow_origin_value(origin) {
905            let is_wildcard = allow_origin == "*";
906            response = response.header("access-control-allow-origin", allow_origin.into_bytes());
907            if !is_wildcard {
908                response = response.header("vary", b"Origin".to_vec());
909            }
910            if self.config.allow_credentials {
911                response = response.header("access-control-allow-credentials", b"true".to_vec());
912            }
913            if !self.config.expose_headers.is_empty() {
914                response = response.header(
915                    "access-control-expose-headers",
916                    self.config.expose_headers.join(", ").into_bytes(),
917                );
918            }
919        }
920        response
921    }
922}
923
924impl Default for Cors {
925    fn default() -> Self {
926        Self::new()
927    }
928}
929
930#[derive(Debug, Clone)]
931struct CorsOrigin(String);
932
933impl Middleware for Cors {
934    fn before<'a>(
935        &'a self,
936        _ctx: &'a RequestContext,
937        req: &'a mut Request,
938    ) -> BoxFuture<'a, ControlFlow> {
939        let origin = req
940            .headers()
941            .get("origin")
942            .and_then(|value| std::str::from_utf8(value).ok())
943            .map(ToString::to_string);
944
945        let Some(origin) = origin else {
946            return Box::pin(async { ControlFlow::Continue });
947        };
948
949        if !self.is_origin_allowed(&origin) {
950            let is_preflight = req.method() == crate::request::Method::Options
951                && req.headers().get("access-control-request-method").is_some();
952            if is_preflight {
953                return Box::pin(async {
954                    ControlFlow::Break(Response::with_status(
955                        crate::response::StatusCode::FORBIDDEN,
956                    ))
957                });
958            }
959            return Box::pin(async { ControlFlow::Continue });
960        }
961
962        let is_preflight = req.method() == crate::request::Method::Options
963            && req.headers().get("access-control-request-method").is_some();
964
965        if is_preflight {
966            let mut response = Response::no_content();
967            response = self.apply_common_headers(response, &origin);
968            response = response.header(
969                "access-control-allow-methods",
970                self.allow_methods_value().into_bytes(),
971            );
972
973            if let Some(value) = self.allow_headers_value(req) {
974                response = response.header("access-control-allow-headers", value.into_bytes());
975            }
976
977            if let Some(max_age) = self.config.max_age {
978                response =
979                    response.header("access-control-max-age", max_age.to_string().into_bytes());
980            }
981
982            return Box::pin(async move { ControlFlow::Break(response) });
983        }
984
985        req.insert_extension(CorsOrigin(origin));
986        Box::pin(async { ControlFlow::Continue })
987    }
988
989    fn after<'a>(
990        &'a self,
991        _ctx: &'a RequestContext,
992        req: &'a Request,
993        response: Response,
994    ) -> BoxFuture<'a, Response> {
995        let origin = req.get_extension::<CorsOrigin>().map(|v| v.0.clone());
996        Box::pin(async move {
997            if let Some(origin) = origin {
998                return self.apply_common_headers(response, &origin);
999            }
1000            response
1001        })
1002    }
1003
1004    fn name(&self) -> &'static str {
1005        "Cors"
1006    }
1007}
1008
/// Matches `value` against a simple glob `pattern`.
///
/// `*` matches any run of characters (including none); every other character must match
/// literally. The whole value must be consumed, so `https://*.example.com` matches
/// `https://api.example.com` but not `https://example.com`.
///
/// Uses the classic single-star backtracking scheme: on a mismatch, or when the pattern
/// runs out before the value does, resume just after the most recent `*` with one more
/// value character absorbed by that star. Patterns with a trailing star such as
/// `http://localhost:*` therefore match any suffix.
fn wildcard_match(pattern: &str, value: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let val: Vec<char> = value.chars().collect();

    let mut p = 0;
    let mut v = 0;
    // (index of the last `*` in `pat`, index in `val` where that star's match currently ends)
    let mut last_star: Option<(usize, usize)> = None;

    while v < val.len() {
        match pat.get(p) {
            Some('*') => {
                // Let the star match nothing for now; it is widened only on a later failure.
                last_star = Some((p, v));
                p += 1;
            }
            Some(&c) if c == val[v] => {
                p += 1;
                v += 1;
            }
            _ => match last_star {
                Some((star_p, star_v)) => {
                    // Absorb one more value character into the star and retry after it.
                    last_star = Some((star_p, star_v + 1));
                    p = star_p + 1;
                    v = star_v + 1;
                }
                None => return false,
            },
        }
    }

    // Value fully consumed: any remaining pattern may only consist of `*`s.
    pat[p..].iter().all(|&c| c == '*')
}
1062
/// Matches `value` against a minimal regular expression operating on raw bytes.
///
/// Supported syntax:
/// - `^` as the first byte anchors the match at the start (elsewhere it is a literal);
/// - `$` as the last byte anchors the match at the end (elsewhere it is a literal);
/// - `.` matches any single byte;
/// - `x*` matches zero or more repetitions of the preceding byte (or `.`).
///
/// Every other byte, including `\`, is a literal.
/// NOTE: there is no escape syntax, so a literal `.` cannot be expressed; `.` always
/// matches any byte. Without a leading `^`, the pattern may match starting at any offset.
fn regex_match(pattern: &str, value: &str) -> bool {
    let (pat, text) = (pattern.as_bytes(), value.as_bytes());
    match pat.split_first() {
        Some((b'^', anchored)) => regex_match_here(anchored, text),
        _ => (0..=text.len()).any(|start| regex_match_here(pat, &text[start..])),
    }
}

/// Returns whether `pattern` matches a prefix of `text` (trailing text may be left over
/// unless the pattern ends in `$`).
fn regex_match_here(pattern: &[u8], text: &[u8]) -> bool {
    match pattern {
        [] => true,
        [b'$'] => text.is_empty(),
        [atom, b'*', rest @ ..] => regex_match_star(*atom, rest, text),
        [atom, rest @ ..] => match text.split_first() {
            Some((&byte, tail)) if *atom == b'.' || *atom == byte => regex_match_here(rest, tail),
            _ => false,
        },
    }
}

/// Matches `ch*` followed by `pattern` against a prefix of `text`, trying the fewest
/// repetitions of `ch` first; a `ch` of `.` repeats any byte.
fn regex_match_star(ch: u8, pattern: &[u8], text: &[u8]) -> bool {
    let mut remaining = text;
    loop {
        if regex_match_here(pattern, remaining) {
            return true;
        }
        match remaining.split_first() {
            Some((&byte, tail)) if ch == b'.' || byte == ch => remaining = tail,
            _ => return false,
        }
    }
}
1116
1117// ============================================================================
1118// Request/Response Logging Middleware
1119// ============================================================================
1120
1121/// Middleware that logs requests and responses with configurable redaction.
1122#[derive(Debug, Clone)]
1123pub struct RequestResponseLogger {
1124    log_config: LogConfig,
1125    redact_headers: HashSet<String>,
1126    log_request_headers: bool,
1127    log_response_headers: bool,
1128    log_body: bool,
1129    max_body_bytes: usize,
1130}
1131
1132impl Default for RequestResponseLogger {
1133    fn default() -> Self {
1134        Self {
1135            log_config: LogConfig::production(),
1136            redact_headers: default_redacted_headers(),
1137            log_request_headers: true,
1138            log_response_headers: true,
1139            log_body: false,
1140            max_body_bytes: 1024,
1141        }
1142    }
1143}
1144
1145impl RequestResponseLogger {
1146    /// Create a new logger middleware with defaults.
1147    #[must_use]
1148    pub fn new() -> Self {
1149        Self::default()
1150    }
1151
1152    /// Override the logging configuration.
1153    #[must_use]
1154    pub fn log_config(mut self, config: LogConfig) -> Self {
1155        self.log_config = config;
1156        self
1157    }
1158
1159    /// Enable or disable request header logging.
1160    #[must_use]
1161    pub fn log_request_headers(mut self, enabled: bool) -> Self {
1162        self.log_request_headers = enabled;
1163        self
1164    }
1165
1166    /// Enable or disable response header logging.
1167    #[must_use]
1168    pub fn log_response_headers(mut self, enabled: bool) -> Self {
1169        self.log_response_headers = enabled;
1170        self
1171    }
1172
1173    /// Enable or disable request/response body logging.
1174    #[must_use]
1175    pub fn log_body(mut self, enabled: bool) -> Self {
1176        self.log_body = enabled;
1177        self
1178    }
1179
1180    /// Set the maximum number of body bytes to include in logs.
1181    #[must_use]
1182    pub fn max_body_bytes(mut self, max: usize) -> Self {
1183        self.max_body_bytes = max;
1184        self
1185    }
1186
1187    /// Add a header name to redact (case-insensitive).
1188    #[must_use]
1189    pub fn redact_header(mut self, name: impl Into<String>) -> Self {
1190        self.redact_headers.insert(name.into().to_ascii_lowercase());
1191        self
1192    }
1193}
1194
/// Request-extension marker holding the instant at which `RequestResponseLogger::before`
/// saw the request; `after` uses it to compute the logged duration.
#[derive(Clone, Debug)]
struct RequestStart(Instant);
1197
1198impl Middleware for RequestResponseLogger {
1199    fn before<'a>(
1200        &'a self,
1201        ctx: &'a RequestContext,
1202        req: &'a mut Request,
1203    ) -> BoxFuture<'a, ControlFlow> {
1204        let logger = RequestLogger::new(ctx, self.log_config.clone());
1205        req.insert_extension(RequestStart(Instant::now()));
1206
1207        let method = req.method();
1208        let path = req.path();
1209        let query = req.query();
1210        let body_bytes = body_len(req.body());
1211
1212        logger.info_with_fields("request", |entry| {
1213            let mut entry = entry
1214                .field("method", method)
1215                .field("path", path)
1216                .field("body_bytes", body_bytes);
1217
1218            if let Some(q) = query {
1219                entry = entry.field("query", q);
1220            }
1221
1222            if self.log_request_headers {
1223                let headers = format_headers(req.headers().iter(), &self.redact_headers);
1224                entry = entry.field("headers", headers);
1225            }
1226
1227            if self.log_body {
1228                if let Some(body) = preview_body(req.body(), self.max_body_bytes) {
1229                    entry = entry.field("body", body);
1230                }
1231            }
1232
1233            entry
1234        });
1235
1236        Box::pin(async { ControlFlow::Continue })
1237    }
1238
1239    fn after<'a>(
1240        &'a self,
1241        ctx: &'a RequestContext,
1242        req: &'a Request,
1243        response: Response,
1244    ) -> BoxFuture<'a, Response> {
1245        let logger = RequestLogger::new(ctx, self.log_config.clone());
1246        let duration = req
1247            .get_extension::<RequestStart>()
1248            .map(|start| start.0.elapsed())
1249            .unwrap_or_default();
1250
1251        let status = response.status();
1252        let body_bytes = response.body_ref().len();
1253
1254        logger.info_with_fields("response", |entry| {
1255            let mut entry = entry
1256                .field("status", status.as_u16())
1257                .field("duration_us", duration.as_micros())
1258                .field("body_bytes", body_bytes);
1259
1260            if self.log_response_headers {
1261                let headers = format_response_headers(response.headers(), &self.redact_headers);
1262                entry = entry.field("headers", headers);
1263            }
1264
1265            if self.log_body {
1266                if let Some(body) = preview_response_body(response.body_ref(), self.max_body_bytes)
1267                {
1268                    entry = entry.field("body", body);
1269                }
1270            }
1271
1272            entry
1273        });
1274
1275        Box::pin(async move { response })
1276    }
1277
1278    fn name(&self) -> &'static str {
1279        "RequestResponseLogger"
1280    }
1281}
1282
/// The header names redacted by default: credentials and cookies, all lower-cased.
fn default_redacted_headers() -> HashSet<String> {
    const SENSITIVE: [&str; 4] = ["authorization", "proxy-authorization", "cookie", "set-cookie"];
    SENSITIVE.iter().map(|name| (*name).to_owned()).collect()
}
1294
1295fn body_len(body: &Body) -> usize {
1296    match body {
1297        Body::Empty => 0,
1298        Body::Bytes(bytes) => bytes.len(),
1299        Body::Stream(_) => 0, // Length unknown for streaming bodies
1300    }
1301}
1302
1303fn preview_body(body: &Body, max_bytes: usize) -> Option<String> {
1304    if max_bytes == 0 {
1305        return None;
1306    }
1307    match body {
1308        Body::Empty => None,
1309        Body::Bytes(bytes) => {
1310            if bytes.is_empty() {
1311                None
1312            } else {
1313                Some(format_bytes(bytes, max_bytes))
1314            }
1315        }
1316        Body::Stream(_) => None, // Cannot preview streaming body
1317    }
1318}
1319
1320fn preview_response_body(body: &crate::response::ResponseBody, max_bytes: usize) -> Option<String> {
1321    if max_bytes == 0 {
1322        return None;
1323    }
1324    match body {
1325        crate::response::ResponseBody::Empty => None,
1326        crate::response::ResponseBody::Bytes(bytes) => {
1327            if bytes.is_empty() {
1328                None
1329            } else {
1330                Some(format_bytes(bytes, max_bytes))
1331            }
1332        }
1333        crate::response::ResponseBody::Stream(_) => None,
1334    }
1335}
1336
/// Renders headers as `name=value` pairs separated by `", "`.
///
/// Names found in `redacted` (compared lower-cased) have their value replaced by
/// `<redacted>`; values that are not valid UTF-8 are shown as `<binary>`.
fn format_headers<'a>(
    headers: impl Iterator<Item = (&'a str, &'a [u8])>,
    redacted: &HashSet<String>,
) -> String {
    headers
        .map(|(name, value)| {
            let shown = if redacted.contains(&name.to_ascii_lowercase()) {
                "<redacted>"
            } else {
                std::str::from_utf8(value).unwrap_or("<binary>")
            };
            format!("{}={}", name, shown)
        })
        .collect::<Vec<_>>()
        .join(", ")
}
1362
1363fn format_response_headers(headers: &[(String, Vec<u8>)], redacted: &HashSet<String>) -> String {
1364    format_headers(
1365        headers
1366            .iter()
1367            .map(|(name, value)| (name.as_str(), value.as_slice())),
1368        redacted,
1369    )
1370}
1371
/// Renders at most `max_bytes` bytes of `bytes` for a log preview.
///
/// - Valid UTF-8 is shown as text, followed by `...` when the input is longer than
///   `max_bytes`.
/// - When the cut at `max_bytes` lands inside a multi-byte character, the preview stops at
///   the last complete character rather than misreporting the text as binary.
/// - Any other invalid UTF-8 is summarised as `<N bytes binary>`, where `N` is the full
///   input length.
fn format_bytes(bytes: &[u8], max_bytes: usize) -> String {
    let truncated = bytes.len() > max_bytes;
    let prefix = &bytes[..max_bytes.min(bytes.len())];

    let text = match std::str::from_utf8(prefix) {
        Ok(text) => text,
        // `error_len() == None` means the input merely ends mid-character. When that end is
        // our own truncation point, keep the complete characters before it.
        Err(err) if truncated && err.error_len().is_none() => {
            std::str::from_utf8(&prefix[..err.valid_up_to()]).unwrap_or_default()
        }
        Err(_) => return format!("<{} bytes binary>", bytes.len()),
    };

    let mut output = String::with_capacity(text.len() + 3);
    output.push_str(text);
    if truncated {
        output.push_str("...");
    }
    output
}
1385
1386// Helper for ResponseBody conversion
1387impl From<&crate::response::ResponseBody> for crate::response::ResponseBody {
1388    fn from(body: &crate::response::ResponseBody) -> Self {
1389        match body {
1390            crate::response::ResponseBody::Empty => crate::response::ResponseBody::Empty,
1391            crate::response::ResponseBody::Bytes(b) => {
1392                crate::response::ResponseBody::Bytes(b.clone())
1393            }
1394            crate::response::ResponseBody::Stream(_) => crate::response::ResponseBody::Empty,
1395        }
1396    }
1397}
1398
1399// ============================================================================
1400// Request ID Middleware
1401// ============================================================================
1402
/// A request ID that was extracted from the client or generated for the current request.
///
/// It is stored in the request extensions, where handlers and other middleware can fetch it
/// for logging and tracing.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Wraps the given value as a request ID.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        RequestId(id.into())
    }

    /// Borrows the ID as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Generates a new request ID without any external UUID dependency.
    ///
    /// The ID has the form `<timestamp>-<counter>`:
    /// - `<timestamp>` is the wall-clock time in microseconds since the Unix epoch, in
    ///   lower-case hex (0 if the clock reads earlier than the epoch);
    /// - `<counter>` is the low 16 bits of a process-wide counter, as 4 lower-case hex digits.
    ///
    /// IDs generated by different processes are not coordinated.
    #[must_use]
    pub fn generate() -> Self {
        static NEXT_SEQUENCE: std::sync::atomic::AtomicU64 =
            std::sync::atomic::AtomicU64::new(0);

        let micros = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_micros() as u64,
            Err(_) => 0,
        };
        let sequence =
            NEXT_SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed) & 0xFFFF;

        RequestId(format!("{:x}-{:04x}", micros, sequence))
    }
}

impl std::fmt::Display for RequestId {
    /// Writes the raw ID text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for RequestId {
    fn from(s: String) -> Self {
        RequestId(s)
    }
}

impl From<&str> for RequestId {
    fn from(s: &str) -> Self {
        RequestId(s.to_owned())
    }
}
1462
/// Configuration for request ID middleware.
///
/// All fields are public; the builder-style methods are a convenience for chaining.
#[derive(Debug, Clone)]
pub struct RequestIdConfig {
    /// Header name to read/write request ID (default: "x-request-id").
    pub header_name: String,
    /// Whether to accept request ID from client (default: true).
    pub accept_from_client: bool,
    /// Whether to add request ID to response headers (default: true).
    pub add_to_response: bool,
    /// Maximum length, in bytes, of a client-provided request ID (default: 128).
    pub max_client_id_length: usize,
}

impl Default for RequestIdConfig {
    fn default() -> Self {
        RequestIdConfig {
            header_name: String::from("x-request-id"),
            accept_from_client: true,
            add_to_response: true,
            max_client_id_length: 128,
        }
    }
}

impl RequestIdConfig {
    /// Same as `RequestIdConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Uses `name` as the header to read the client ID from and write the ID back to.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }

    /// Chooses whether a valid client-supplied ID is reused instead of generating one.
    #[must_use]
    pub fn accept_from_client(self, accept: bool) -> Self {
        Self {
            accept_from_client: accept,
            ..self
        }
    }

    /// Chooses whether the ID is echoed in the response headers.
    #[must_use]
    pub fn add_to_response(self, add: bool) -> Self {
        Self {
            add_to_response: add,
            ..self
        }
    }

    /// Sets the longest client-supplied ID (in bytes) that will be accepted.
    #[must_use]
    pub fn max_client_id_length(self, max: usize) -> Self {
        Self {
            max_client_id_length: max,
            ..self
        }
    }
}
1522
1523/// Middleware that adds unique request IDs to requests and responses.
1524///
1525/// This middleware:
1526/// 1. Checks for an existing X-Request-ID header from the client
1527/// 2. If present and valid, uses it; otherwise generates a new ID
1528/// 3. Stores the ID in request extensions for handlers to access
1529/// 4. Adds the ID to response headers
1530///
1531/// # Example
1532///
1533/// ```ignore
1534/// use fastapi_core::middleware::RequestIdMiddleware;
1535///
1536/// let mut stack = MiddlewareStack::new();
1537/// stack.push(RequestIdMiddleware::new());
1538///
1539/// // In your handler:
1540/// async fn handler(ctx: &RequestContext, req: &Request) -> Response {
1541///     if let Some(request_id) = req.get_extension::<RequestId>() {
1542///         println!("Request ID: {}", request_id);
1543///     }
1544///     Response::ok()
1545/// }
1546/// ```
1547#[derive(Debug, Clone)]
1548pub struct RequestIdMiddleware {
1549    config: RequestIdConfig,
1550}
1551
1552impl Default for RequestIdMiddleware {
1553    fn default() -> Self {
1554        Self::new()
1555    }
1556}
1557
1558impl RequestIdMiddleware {
1559    /// Creates a new request ID middleware with default configuration.
1560    #[must_use]
1561    pub fn new() -> Self {
1562        Self {
1563            config: RequestIdConfig::default(),
1564        }
1565    }
1566
1567    /// Creates a new request ID middleware with the given configuration.
1568    #[must_use]
1569    pub fn with_config(config: RequestIdConfig) -> Self {
1570        Self { config }
1571    }
1572
1573    /// Extracts or generates a request ID for the given request.
1574    fn get_or_generate_id(&self, req: &Request) -> RequestId {
1575        if self.config.accept_from_client {
1576            if let Some(header_value) = req.headers().get(&self.config.header_name) {
1577                if let Ok(client_id) = std::str::from_utf8(header_value) {
1578                    // Validate length and basic content
1579                    if !client_id.is_empty()
1580                        && client_id.len() <= self.config.max_client_id_length
1581                        && is_valid_request_id(client_id)
1582                    {
1583                        return RequestId::new(client_id);
1584                    }
1585                }
1586            }
1587        }
1588        RequestId::generate()
1589    }
1590}
1591
/// Returns `true` when `id` is non-empty and consists solely of ASCII letters, digits,
/// `-`, `_` or `.`, characters that are safe to echo into headers and logs.
fn is_valid_request_id(id: &str) -> bool {
    let allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.');
    !id.is_empty() && id.chars().all(allowed)
}
1599
1600impl Middleware for RequestIdMiddleware {
1601    fn before<'a>(
1602        &'a self,
1603        _ctx: &'a RequestContext,
1604        req: &'a mut Request,
1605    ) -> BoxFuture<'a, ControlFlow> {
1606        let request_id = self.get_or_generate_id(req);
1607        req.insert_extension(request_id);
1608        Box::pin(async { ControlFlow::Continue })
1609    }
1610
1611    fn after<'a>(
1612        &'a self,
1613        _ctx: &'a RequestContext,
1614        req: &'a Request,
1615        response: Response,
1616    ) -> BoxFuture<'a, Response> {
1617        if !self.config.add_to_response {
1618            return Box::pin(async move { response });
1619        }
1620
1621        let request_id = req.get_extension::<RequestId>().cloned();
1622        let header_name = self.config.header_name.clone();
1623
1624        Box::pin(async move {
1625            if let Some(id) = request_id {
1626                response.header(header_name, id.0.into_bytes())
1627            } else {
1628                response
1629            }
1630        })
1631    }
1632
1633    fn name(&self) -> &'static str {
1634        "RequestId"
1635    }
1636}
1637
1638// ============================================================================
1639// Security Headers Middleware
1640// ============================================================================
1641
/// X-Frame-Options header value.
///
/// Controls whether browsers may render the page inside a frame (clickjacking protection).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum XFrameOptions {
    /// `DENY`: no site, including this one, may frame the page.
    Deny,
    /// `SAMEORIGIN`: only pages from the same origin may frame the page.
    SameOrigin,
}

impl XFrameOptions {
    /// Returns the header value token for this option.
    fn as_bytes(self) -> &'static [u8] {
        let token = match self {
            XFrameOptions::Deny => "DENY",
            XFrameOptions::SameOrigin => "SAMEORIGIN",
        };
        token.as_bytes()
    }
}
1661
/// Referrer-Policy header value.
///
/// Controls how much referrer information the browser attaches to requests made from the
/// page. "Full URL" below means origin, path and query; a "downgrade" is a request from an
/// HTTPS page to an HTTP URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferrerPolicy {
    /// `no-referrer`: never send referrer information.
    NoReferrer,
    /// `no-referrer-when-downgrade`: send the full URL, except on a downgrade, where nothing
    /// is sent.
    NoReferrerWhenDowngrade,
    /// `origin`: send only the origin (no path or query).
    Origin,
    /// `origin-when-cross-origin`: send the full URL for same-origin requests and only the
    /// origin for cross-origin requests.
    OriginWhenCrossOrigin,
    /// `same-origin`: send the full URL for same-origin requests and nothing for
    /// cross-origin requests.
    SameOrigin,
    /// `strict-origin`: send only the origin, and nothing on a downgrade.
    StrictOrigin,
    /// `strict-origin-when-cross-origin`: send the full URL for same-origin requests and the
    /// origin for cross-origin requests. Nothing is sent on a downgrade.
    StrictOriginWhenCrossOrigin,
    /// `unsafe-url`: always send the full URL, even on a downgrade (not recommended).
    UnsafeUrl,
}

impl ReferrerPolicy {
    /// Returns the header value token for this policy.
    fn as_bytes(self) -> &'static [u8] {
        let token = match self {
            ReferrerPolicy::NoReferrer => "no-referrer",
            ReferrerPolicy::NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
            ReferrerPolicy::Origin => "origin",
            ReferrerPolicy::OriginWhenCrossOrigin => "origin-when-cross-origin",
            ReferrerPolicy::SameOrigin => "same-origin",
            ReferrerPolicy::StrictOrigin => "strict-origin",
            ReferrerPolicy::StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
            ReferrerPolicy::UnsafeUrl => "unsafe-url",
        };
        token.as_bytes()
    }
}
1699
1700/// Configuration for the Security Headers middleware.
1701///
1702/// All headers are optional. Set a value to `Some(...)` to include the header,
1703/// or `None` to skip it.
1704///
1705/// # Defaults
1706///
1707/// The default configuration provides secure defaults:
1708/// - `X-Content-Type-Options: nosniff`
1709/// - `X-Frame-Options: DENY`
1710/// - `X-XSS-Protection: 0` (disabled as modern browsers have built-in protection)
1711/// - `Referrer-Policy: strict-origin-when-cross-origin`
1712///
1713/// # Example
1714///
1715/// ```ignore
1716/// use fastapi_core::middleware::{SecurityHeadersConfig, XFrameOptions, ReferrerPolicy};
1717///
1718/// let config = SecurityHeadersConfig::default()
1719///     .x_frame_options(XFrameOptions::SameOrigin)
1720///     .content_security_policy("default-src 'self'")
1721///     .hsts(31536000, true);  // 1 year, includeSubDomains
1722/// ```
1723#[derive(Debug, Clone)]
1724pub struct SecurityHeadersConfig {
1725    /// X-Content-Type-Options header.
1726    /// Default: `Some("nosniff")`
1727    pub x_content_type_options: Option<&'static str>,
1728    /// X-Frame-Options header.
1729    /// Default: `Some(XFrameOptions::Deny)`
1730    pub x_frame_options: Option<XFrameOptions>,
1731    /// X-XSS-Protection header.
1732    /// Default: `Some("0")` (disabled - modern browsers have built-in protection)
1733    ///
1734    /// Note: This header is largely obsolete. Setting it to "0" is recommended
1735    /// to prevent potential security issues in older browsers.
1736    pub x_xss_protection: Option<&'static str>,
1737    /// Content-Security-Policy header.
1738    /// Default: `None` (should be configured based on your application)
1739    pub content_security_policy: Option<String>,
1740    /// Strict-Transport-Security (HSTS) header.
1741    /// Tuple of (max_age_seconds, include_sub_domains, preload)
1742    /// Default: `None` (only set this for HTTPS-only sites)
1743    pub hsts: Option<(u64, bool, bool)>,
1744    /// Referrer-Policy header.
1745    /// Default: `Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)`
1746    pub referrer_policy: Option<ReferrerPolicy>,
1747    /// Permissions-Policy header (formerly Feature-Policy).
1748    /// Default: `None` (should be configured based on your application)
1749    pub permissions_policy: Option<String>,
1750}
1751
1752impl Default for SecurityHeadersConfig {
1753    fn default() -> Self {
1754        Self {
1755            x_content_type_options: Some("nosniff"),
1756            x_frame_options: Some(XFrameOptions::Deny),
1757            x_xss_protection: Some("0"),
1758            content_security_policy: None,
1759            hsts: None,
1760            referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
1761            permissions_policy: None,
1762        }
1763    }
1764}
1765
1766impl SecurityHeadersConfig {
1767    /// Creates a new configuration with secure defaults.
1768    #[must_use]
1769    pub fn new() -> Self {
1770        Self::default()
1771    }
1772
1773    /// Creates an empty configuration (no headers).
1774    #[must_use]
1775    pub fn none() -> Self {
1776        Self {
1777            x_content_type_options: None,
1778            x_frame_options: None,
1779            x_xss_protection: None,
1780            content_security_policy: None,
1781            hsts: None,
1782            referrer_policy: None,
1783            permissions_policy: None,
1784        }
1785    }
1786
1787    /// Creates a strict configuration for high-security applications.
1788    ///
1789    /// Includes:
1790    /// - All default headers
1791    /// - HSTS with 1 year max-age and includeSubDomains
1792    /// - A basic CSP that only allows same-origin resources
1793    #[must_use]
1794    pub fn strict() -> Self {
1795        Self {
1796            x_content_type_options: Some("nosniff"),
1797            x_frame_options: Some(XFrameOptions::Deny),
1798            x_xss_protection: Some("0"),
1799            content_security_policy: Some("default-src 'self'".to_string()),
1800            hsts: Some((31536000, true, false)), // 1 year, includeSubDomains
1801            referrer_policy: Some(ReferrerPolicy::NoReferrer),
1802            permissions_policy: Some("geolocation=(), camera=(), microphone=()".to_string()),
1803        }
1804    }
1805
1806    /// Sets the X-Content-Type-Options header.
1807    #[must_use]
1808    pub fn x_content_type_options(mut self, value: Option<&'static str>) -> Self {
1809        self.x_content_type_options = value;
1810        self
1811    }
1812
1813    /// Sets the X-Frame-Options header.
1814    #[must_use]
1815    pub fn x_frame_options(mut self, value: Option<XFrameOptions>) -> Self {
1816        self.x_frame_options = value;
1817        self
1818    }
1819
1820    /// Sets the X-XSS-Protection header.
1821    #[must_use]
1822    pub fn x_xss_protection(mut self, value: Option<&'static str>) -> Self {
1823        self.x_xss_protection = value;
1824        self
1825    }
1826
1827    /// Sets the Content-Security-Policy header.
1828    #[must_use]
1829    pub fn content_security_policy(mut self, value: impl Into<String>) -> Self {
1830        self.content_security_policy = Some(value.into());
1831        self
1832    }
1833
1834    /// Clears the Content-Security-Policy header.
1835    #[must_use]
1836    pub fn no_content_security_policy(mut self) -> Self {
1837        self.content_security_policy = None;
1838        self
1839    }
1840
1841    /// Sets the Strict-Transport-Security (HSTS) header.
1842    ///
1843    /// # Arguments
1844    ///
1845    /// - `max_age`: Maximum time (in seconds) the browser should remember HTTPS
1846    /// - `include_sub_domains`: Whether to apply to all subdomains
1847    /// - `preload`: Whether to include in browser preload lists (use with caution)
1848    ///
1849    /// # Warning
1850    ///
1851    /// Only enable HSTS for sites that are HTTPS-only. Enabling HSTS incorrectly
1852    /// can make your site inaccessible.
1853    #[must_use]
1854    pub fn hsts(mut self, max_age: u64, include_sub_domains: bool, preload: bool) -> Self {
1855        self.hsts = Some((max_age, include_sub_domains, preload));
1856        self
1857    }
1858
1859    /// Clears the HSTS header.
1860    #[must_use]
1861    pub fn no_hsts(mut self) -> Self {
1862        self.hsts = None;
1863        self
1864    }
1865
1866    /// Sets the Referrer-Policy header.
1867    #[must_use]
1868    pub fn referrer_policy(mut self, value: Option<ReferrerPolicy>) -> Self {
1869        self.referrer_policy = value;
1870        self
1871    }
1872
1873    /// Sets the Permissions-Policy header.
1874    #[must_use]
1875    pub fn permissions_policy(mut self, value: impl Into<String>) -> Self {
1876        self.permissions_policy = Some(value.into());
1877        self
1878    }
1879
1880    /// Clears the Permissions-Policy header.
1881    #[must_use]
1882    pub fn no_permissions_policy(mut self) -> Self {
1883        self.permissions_policy = None;
1884        self
1885    }
1886
1887    /// Builds the HSTS header value.
1888    fn build_hsts_value(&self) -> Option<String> {
1889        self.hsts.map(|(max_age, include_sub, preload)| {
1890            let mut value = format!("max-age={}", max_age);
1891            if include_sub {
1892                value.push_str("; includeSubDomains");
1893            }
1894            if preload {
1895                value.push_str("; preload");
1896            }
1897            value
1898        })
1899    }
1900}
1901
/// Middleware that adds security-related HTTP headers to responses.
///
/// This middleware helps protect against common web vulnerabilities by setting
/// appropriate security headers. It's recommended for all web applications.
///
/// Headers are added in the middleware's `after` hook. Any header whose
/// configured value is `None` is simply omitted from the response.
///
/// # Headers
///
/// - **X-Content-Type-Options**: Prevents MIME type sniffing
/// - **X-Frame-Options**: Controls iframe embedding (clickjacking protection)
/// - **X-XSS-Protection**: Legacy XSS filter control (disabled by default)
/// - **Content-Security-Policy**: Controls resource loading
/// - **Strict-Transport-Security**: Enforces HTTPS
/// - **Referrer-Policy**: Controls referrer information
/// - **Permissions-Policy**: Controls browser features
///
/// # Example
///
/// ```ignore
/// use fastapi_core::middleware::{SecurityHeaders, SecurityHeadersConfig};
///
/// // Use defaults
/// let mw = SecurityHeaders::new();
///
/// // Custom configuration
/// let config = SecurityHeadersConfig::default()
///     .content_security_policy("default-src 'self'; img-src *")
///     .hsts(86400, false, false);  // 1 day
///
/// let mw = SecurityHeaders::with_config(config);
/// ```
#[derive(Debug, Clone)]
pub struct SecurityHeaders {
    // Header values applied to every response passing through `after`.
    config: SecurityHeadersConfig,
}
1936
1937impl Default for SecurityHeaders {
1938    fn default() -> Self {
1939        Self::new()
1940    }
1941}
1942
1943impl SecurityHeaders {
1944    /// Creates a new middleware with default configuration.
1945    #[must_use]
1946    pub fn new() -> Self {
1947        Self {
1948            config: SecurityHeadersConfig::default(),
1949        }
1950    }
1951
1952    /// Creates a new middleware with custom configuration.
1953    #[must_use]
1954    pub fn with_config(config: SecurityHeadersConfig) -> Self {
1955        Self { config }
1956    }
1957
1958    /// Creates a middleware with strict security settings.
1959    #[must_use]
1960    pub fn strict() -> Self {
1961        Self {
1962            config: SecurityHeadersConfig::strict(),
1963        }
1964    }
1965}
1966
1967impl Middleware for SecurityHeaders {
1968    fn after<'a>(
1969        &'a self,
1970        _ctx: &'a RequestContext,
1971        _req: &'a Request,
1972        response: Response,
1973    ) -> BoxFuture<'a, Response> {
1974        let config = self.config.clone();
1975        Box::pin(async move {
1976            let mut resp = response;
1977
1978            // X-Content-Type-Options
1979            if let Some(value) = config.x_content_type_options {
1980                resp = resp.header("X-Content-Type-Options", value.as_bytes().to_vec());
1981            }
1982
1983            // X-Frame-Options
1984            if let Some(value) = config.x_frame_options {
1985                resp = resp.header("X-Frame-Options", value.as_bytes().to_vec());
1986            }
1987
1988            // X-XSS-Protection
1989            if let Some(value) = config.x_xss_protection {
1990                resp = resp.header("X-XSS-Protection", value.as_bytes().to_vec());
1991            }
1992
1993            // Content-Security-Policy
1994            if let Some(ref value) = config.content_security_policy {
1995                resp = resp.header("Content-Security-Policy", value.as_bytes().to_vec());
1996            }
1997
1998            // Strict-Transport-Security
1999            if let Some(ref hsts_value) = config.build_hsts_value() {
2000                resp = resp.header("Strict-Transport-Security", hsts_value.as_bytes().to_vec());
2001            }
2002
2003            // Referrer-Policy
2004            if let Some(value) = config.referrer_policy {
2005                resp = resp.header("Referrer-Policy", value.as_bytes().to_vec());
2006            }
2007
2008            // Permissions-Policy
2009            if let Some(ref value) = config.permissions_policy {
2010                resp = resp.header("Permissions-Policy", value.as_bytes().to_vec());
2011            }
2012
2013            resp
2014        })
2015    }
2016
2017    fn name(&self) -> &'static str {
2018        "SecurityHeaders"
2019    }
2020}
2021
2022// ============================================================================
2023// CSRF Protection Middleware
2024// ============================================================================
2025
/// CSRF token stored in request extensions.
///
/// `CsrfMiddleware` stores this after generating or validating a token,
/// allowing handlers to access the current CSRF token.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Creates a new CSRF token with the given value.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        Self(token.into())
    }

    /// Returns the token as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Generates a new unique CSRF token using cryptographic randomness.
    ///
    /// Reads 32 bytes from `/dev/urandom` and hex-encodes them, producing a
    /// 64-character lowercase hexadecimal string.
    ///
    /// # Panics
    ///
    /// Panics if `/dev/urandom` cannot be opened or read (for example on
    /// platforms that do not provide it). CSRF tokens MUST be
    /// cryptographically unpredictable - there is no safe fallback. The panic
    /// message includes the underlying I/O error to aid diagnosis.
    #[must_use]
    pub fn generate() -> Self {
        // CSRF tokens must be cryptographically secure - no weak fallback.
        let bytes = Self::read_urandom(32).unwrap_or_else(|err| {
            panic!(
                "FATAL: Cryptographically secure random source (/dev/urandom) is unavailable. \
                 CSRF token generation requires a CSPRNG. Cannot safely generate CSRF tokens \
                 without cryptographic entropy. Underlying error: {err}"
            )
        });
        Self(Self::bytes_to_hex(&bytes))
    }

    /// Reads exactly `len` bytes from `/dev/urandom`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the device cannot be opened or if fewer than
    /// `len` bytes can be read.
    fn read_urandom(len: usize) -> std::io::Result<Vec<u8>> {
        use std::io::Read;
        let mut f = std::fs::File::open("/dev/urandom")?;
        let mut buf = vec![0u8; len];
        f.read_exact(&mut buf)?;
        Ok(buf)
    }

    /// Encodes `bytes` as a lowercase hexadecimal string (two chars per byte).
    fn bytes_to_hex(bytes: &[u8]) -> String {
        use std::fmt::Write;
        let mut s = String::with_capacity(bytes.len() * 2);
        for b in bytes {
            // Writing into a String cannot fail.
            let _ = write!(s, "{b:02x}");
        }
        s
    }
}

impl std::fmt::Display for CsrfToken {
    /// Writes the raw token value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for CsrfToken {
    /// Wraps a borrowed string as a token (copies it).
    fn from(s: &str) -> Self {
        Self(s.to_string())
    }
}
2096
/// CSRF protection mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CsrfMode {
    /// Double-submit cookie pattern: the token in the cookie must match the
    /// token in the request header. This is the default and most common mode.
    #[default]
    DoubleSubmit,
    /// Require the token in the header only (for APIs where cookies are not used).
    HeaderOnly,
}

/// Configuration for CSRF protection middleware.
///
/// Start from [`CsrfConfig::new`] (or `Default`) and adjust it with the
/// chainable setters, each of which consumes and returns the config.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Cookie name for CSRF token (default: `"csrf_token"`).
    pub cookie_name: String,
    /// Header name for CSRF token (default: `"x-csrf-token"`).
    pub header_name: String,
    /// CSRF protection mode (default: [`CsrfMode::DoubleSubmit`]).
    pub mode: CsrfMode,
    /// Whether to rotate token on each request (default: `false`).
    pub rotate_token: bool,
    /// Whether in production mode, which controls the `Secure` cookie flag
    /// (default: `true`).
    pub production: bool,
    /// Custom error message for CSRF failures (default: `None`, meaning the
    /// built-in messages are used).
    pub error_message: Option<String>,
}

impl Default for CsrfConfig {
    fn default() -> Self {
        Self {
            cookie_name: String::from("csrf_token"),
            header_name: String::from("x-csrf-token"),
            mode: CsrfMode::default(),
            rotate_token: false,
            production: true,
            error_message: None,
        }
    }
}

impl CsrfConfig {
    /// Returns the default configuration; identical to `CsrfConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the cookie name that carries the CSRF token.
    #[must_use]
    pub fn cookie_name(self, name: impl Into<String>) -> Self {
        Self {
            cookie_name: name.into(),
            ..self
        }
    }

    /// Replaces the request header name that carries the CSRF token.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }

    /// Selects the CSRF protection mode.
    #[must_use]
    pub fn mode(self, mode: CsrfMode) -> Self {
        Self { mode, ..self }
    }

    /// Enables or disables token rotation on each request.
    #[must_use]
    pub fn rotate_token(self, rotate: bool) -> Self {
        Self {
            rotate_token: rotate,
            ..self
        }
    }

    /// Toggles production mode, which controls the `Secure` cookie flag.
    #[must_use]
    pub fn production(self, production: bool) -> Self {
        Self { production, ..self }
    }

    /// Overrides the message reported when CSRF validation fails.
    #[must_use]
    pub fn error_message(self, message: impl Into<String>) -> Self {
        let text: String = message.into();
        Self {
            error_message: Some(text),
            ..self
        }
    }
}
2187
/// CSRF protection middleware.
///
/// Implements protection against Cross-Site Request Forgery attacks using
/// the double-submit cookie pattern by default.
///
/// # How It Works
///
/// 1. For safe methods (GET, HEAD, OPTIONS, TRACE): generates a CSRF token
///    and sets it in a cookie if not present.
/// 2. For all other methods (POST, PUT, DELETE, PATCH, ...): validates that
///    the token in the header matches the token in the cookie. In
///    [`CsrfMode::HeaderOnly`] only a non-empty header token is required.
///
/// The current token is stored as a [`CsrfToken`] request extension so
/// handlers can embed it in pages or forms. Validation failures short-circuit
/// with a `403 Forbidden` JSON error body.
///
/// The cookie is issued with `Path=/; SameSite=Strict`, plus `Secure` when
/// `production` is enabled. It is deliberately not `HttpOnly`, so client-side
/// JavaScript can read it and echo it in the header.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::middleware::{CsrfMiddleware, CsrfConfig};
///
/// let mut stack = MiddlewareStack::new();
/// stack.push(CsrfMiddleware::new());
///
/// // Or with custom configuration:
/// let csrf = CsrfMiddleware::with_config(
///     CsrfConfig::new()
///         .header_name("X-XSRF-Token")
///         .cookie_name("XSRF-TOKEN")
///         .production(false)
/// );
/// stack.push(csrf);
/// ```
#[derive(Debug, Clone)]
pub struct CsrfMiddleware {
    // Cookie/header names, protection mode, rotation and error-message settings.
    config: CsrfConfig,
}
2221
2222impl Default for CsrfMiddleware {
2223    fn default() -> Self {
2224        Self::new()
2225    }
2226}
2227
2228impl CsrfMiddleware {
2229    /// Creates a new CSRF middleware with default configuration.
2230    #[must_use]
2231    pub fn new() -> Self {
2232        Self {
2233            config: CsrfConfig::default(),
2234        }
2235    }
2236
2237    /// Creates a new CSRF middleware with the given configuration.
2238    #[must_use]
2239    pub fn with_config(config: CsrfConfig) -> Self {
2240        Self { config }
2241    }
2242
2243    /// Checks if the HTTP method is safe (does not modify state).
2244    fn is_safe_method(method: crate::request::Method) -> bool {
2245        matches!(
2246            method,
2247            crate::request::Method::Get
2248                | crate::request::Method::Head
2249                | crate::request::Method::Options
2250                | crate::request::Method::Trace
2251        )
2252    }
2253
2254    /// Extracts the CSRF token from the cookie header.
2255    fn get_cookie_token(&self, req: &Request) -> Option<String> {
2256        let cookie_header = req.headers().get("cookie")?;
2257        let cookie_str = std::str::from_utf8(cookie_header).ok()?;
2258
2259        // Parse cookie header: "name1=value1; name2=value2"
2260        for part in cookie_str.split(';') {
2261            let part = part.trim();
2262            if let Some((name, value)) = part.split_once('=') {
2263                if name.trim() == self.config.cookie_name {
2264                    return Some(value.trim().to_string());
2265                }
2266            }
2267        }
2268        None
2269    }
2270
2271    /// Extracts the CSRF token from the request header.
2272    fn get_header_token(&self, req: &Request) -> Option<String> {
2273        let header_value = req.headers().get(&self.config.header_name)?;
2274        std::str::from_utf8(header_value)
2275            .ok()
2276            .map(|s| s.trim().to_string())
2277    }
2278
2279    /// Validates the CSRF token for state-changing requests.
2280    fn validate_token(&self, req: &Request) -> Result<Option<CsrfToken>, Response> {
2281        let header_token = self.get_header_token(req);
2282
2283        match self.config.mode {
2284            CsrfMode::DoubleSubmit => {
2285                let cookie_token = self.get_cookie_token(req);
2286
2287                match (header_token, cookie_token) {
2288                    (Some(header), Some(cookie))
2289                        if !header.is_empty()
2290                            && crate::extract::constant_time_eq(
2291                                header.as_bytes(),
2292                                cookie.as_bytes(),
2293                            ) =>
2294                    {
2295                        Ok(Some(CsrfToken::new(header)))
2296                    }
2297                    (None, _) | (_, None) => Err(self.csrf_error_response("CSRF token missing")),
2298                    _ => Err(self.csrf_error_response("CSRF token mismatch")),
2299                }
2300            }
2301            CsrfMode::HeaderOnly => match header_token {
2302                Some(token) if !token.is_empty() => Ok(Some(CsrfToken::new(token))),
2303                _ => Err(self.csrf_error_response("CSRF token missing in header")),
2304            },
2305        }
2306    }
2307
2308    /// Creates a 403 Forbidden response for CSRF failures.
2309    fn csrf_error_response(&self, default_message: &str) -> Response {
2310        let message = self
2311            .config
2312            .error_message
2313            .as_deref()
2314            .unwrap_or(default_message);
2315
2316        // Create a FastAPI-compatible error response
2317        let body = format!(
2318            r#"{{"detail":[{{"type":"csrf_error","loc":["header","{}"],"msg":"{}"}}]}}"#,
2319            self.config.header_name, message
2320        );
2321
2322        Response::with_status(crate::response::StatusCode::FORBIDDEN)
2323            .header("content-type", b"application/json".to_vec())
2324            .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
2325    }
2326
2327    /// Creates the Set-Cookie header value for a CSRF token.
2328    fn make_set_cookie_header_value(cookie_name: &str, token: &str, production: bool) -> Vec<u8> {
2329        let mut cookie = format!("{}={}; Path=/; SameSite=Strict", cookie_name, token);
2330
2331        if production {
2332            cookie.push_str("; Secure");
2333        }
2334
2335        // Note: HttpOnly is NOT set - CSRF cookies must be readable by JavaScript
2336
2337        cookie.into_bytes()
2338    }
2339}
2340
2341impl Middleware for CsrfMiddleware {
2342    fn before<'a>(
2343        &'a self,
2344        _ctx: &'a RequestContext,
2345        req: &'a mut Request,
2346    ) -> BoxFuture<'a, ControlFlow> {
2347        Box::pin(async move {
2348            if Self::is_safe_method(req.method()) {
2349                // Safe methods: generate token if not present
2350                let existing_token = self.get_cookie_token(req);
2351                let token = existing_token
2352                    .map(CsrfToken::new)
2353                    .unwrap_or_else(CsrfToken::generate);
2354                req.insert_extension(token);
2355                ControlFlow::Continue
2356            } else {
2357                // State-changing methods: validate token
2358                match self.validate_token(req) {
2359                    Ok(Some(token)) => {
2360                        req.insert_extension(token);
2361                        ControlFlow::Continue
2362                    }
2363                    Ok(None) => ControlFlow::Continue,
2364                    Err(response) => ControlFlow::Break(response),
2365                }
2366            }
2367        })
2368    }
2369
2370    fn after<'a>(
2371        &'a self,
2372        _ctx: &'a RequestContext,
2373        req: &'a Request,
2374        response: Response,
2375    ) -> BoxFuture<'a, Response> {
2376        let config = self.config.clone();
2377        let is_safe = Self::is_safe_method(req.method());
2378        let existing_cookie_token = self.get_cookie_token(req);
2379        let token = req.get_extension::<CsrfToken>().cloned();
2380
2381        Box::pin(async move {
2382            // Set cookie for safe methods if:
2383            // 1. No cookie exists yet, or
2384            // 2. Token rotation is enabled
2385            if is_safe {
2386                let should_set_cookie = existing_cookie_token.is_none() || config.rotate_token;
2387
2388                if should_set_cookie {
2389                    if let Some(token) = token {
2390                        let cookie_value = Self::make_set_cookie_header_value(
2391                            &config.cookie_name,
2392                            token.as_str(),
2393                            config.production,
2394                        );
2395                        return response.header("set-cookie", cookie_value);
2396                    }
2397                }
2398            }
2399            response
2400        })
2401    }
2402
2403    fn name(&self) -> &'static str {
2404        "CSRF"
2405    }
2406}
2407
2408// ============================================================================
2409// Compression Middleware (requires "compression" feature)
2410// ============================================================================
2411
/// Configuration for response compression.
///
/// Controls when and how responses are compressed using gzip.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::middleware::{CompressionMiddleware, CompressionConfig};
///
/// // Use defaults (min size 1024, level 6)
/// let mw = CompressionMiddleware::new();
///
/// // Custom configuration
/// let config = CompressionConfig::new()
///     .min_size(512)
///     .level(9);  // Maximum compression
/// let mw = CompressionMiddleware::with_config(config);
/// ```
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Minimum response size in bytes to compress.
    /// Responses smaller than this are not compressed.
    /// Default: 1024 bytes (1 KB)
    pub min_size: usize,
    /// Compression level (1-9).
    /// 1 = fastest, 9 = best compression, 6 = balanced (default)
    ///
    /// Note: the `level()` builder clamps values to 1-9, but assigning this
    /// public field directly bypasses that clamp.
    pub level: u32,
    /// Content types that are already compressed and should be skipped.
    /// Default includes common compressed formats.
    ///
    /// Entries ending in `/` are prefix matches (e.g. `"video/"`); other
    /// entries must match the response's media type. Response content types
    /// are lowercased before comparison but these entries are not, so
    /// entries should be written in lowercase.
    pub skip_content_types: Vec<&'static str>,
}
2444
#[cfg(feature = "compression")]
impl Default for CompressionConfig {
    /// Compress bodies of 1024 bytes or more at gzip level 6, skipping media
    /// types that are normally already compressed.
    fn default() -> Self {
        // Ordered: images, audio/video prefixes, archives, then other
        // pre-compressed document and font formats.
        let skip_content_types = [
            "image/jpeg",
            "image/png",
            "image/gif",
            "image/webp",
            "image/avif",
            "video/",
            "audio/",
            "application/zip",
            "application/gzip",
            "application/x-gzip",
            "application/x-bzip2",
            "application/x-xz",
            "application/x-7z-compressed",
            "application/x-rar-compressed",
            "application/pdf",
            "application/woff",
            "application/woff2",
            "font/woff",
            "font/woff2",
        ]
        .to_vec();

        Self {
            min_size: 1024,
            level: 6,
            skip_content_types,
        }
    }
}
2479
#[cfg(feature = "compression")]
impl CompressionConfig {
    /// Creates a new configuration with default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum response size to compress.
    ///
    /// Responses smaller than this threshold will not be compressed,
    /// as compression overhead may exceed the savings.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Sets the compression level (1-9).
    ///
    /// - 1: Fastest compression, lowest ratio
    /// - 6: Balanced (default)
    /// - 9: Best compression ratio, slowest
    ///
    /// Values outside 1-9 are clamped.
    #[must_use]
    pub fn level(mut self, level: u32) -> Self {
        self.level = level.clamp(1, 9);
        self
    }

    /// Adds a content type to skip during compression.
    ///
    /// Content types can be exact matches or prefixes (e.g., "video/" matches all video types).
    #[must_use]
    pub fn skip_content_type(mut self, content_type: &'static str) -> Self {
        self.skip_content_types.push(content_type);
        self
    }

    /// Checks if the given content type should be skipped.
    ///
    /// Only the media type is compared: the text before any `;` parameters,
    /// trimmed and lowercased. This means `"image/png"`, `"image/png;q=1"` and
    /// `"Image/PNG ; charset=x"` are all treated alike. Entries ending in `/`
    /// match as prefixes (e.g. `"video/"` matches `"video/mp4"`); other
    /// entries must equal the media type. Entries are expected to be lowercase.
    fn should_skip_content_type(&self, content_type: &str) -> bool {
        let media_type = content_type
            .split(';')
            .next()
            .unwrap_or("")
            .trim()
            .to_ascii_lowercase();

        self.skip_content_types.iter().any(|skip| {
            if skip.ends_with('/') {
                // Prefix match (e.g., "video/" matches "video/mp4")
                media_type.starts_with(*skip)
            } else {
                media_type == *skip
            }
        })
    }
}
2539
/// Middleware that compresses responses using gzip.
///
/// This middleware inspects the `Accept-Encoding` header and compresses
/// eligible responses with gzip. Compression is skipped for:
/// - Responses smaller than `min_size`
/// - Responses with already-compressed content types
/// - Responses that already have a `Content-Encoding` header
/// - Responses whose body is not an in-memory byte buffer (e.g. empty or
///   streaming bodies)
/// - Clients that don't accept gzip
///
/// If the gzip output is not smaller than the original body, the original
/// response is returned unchanged.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::middleware::{CompressionMiddleware, CompressionConfig, MiddlewareStack};
///
/// let mut stack = MiddlewareStack::new();
///
/// // Default configuration
/// stack.push(CompressionMiddleware::new());
///
/// // Or with custom settings
/// let config = CompressionConfig::new()
///     .min_size(256)   // Compress smaller responses
///     .level(9);       // Maximum compression
/// stack.push(CompressionMiddleware::with_config(config));
/// ```
///
/// # Headers
///
/// When compression is applied:
/// - `Content-Encoding: gzip` is added
/// - `Vary: Accept-Encoding` is added (for caching)
/// - any existing `Content-Length` header is dropped, since it described the
///   uncompressed body (NOTE(review): presumably the response writer
///   recomputes it from the new body — confirm)
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionMiddleware {
    // Size threshold, gzip level and content-type skip list.
    config: CompressionConfig,
}
2577
#[cfg(feature = "compression")]
impl Default for CompressionMiddleware {
    /// Same as `CompressionMiddleware::new()`: uses `CompressionConfig::default()`.
    fn default() -> Self {
        Self::with_config(CompressionConfig::default())
    }
}
2584
#[cfg(feature = "compression")]
impl CompressionMiddleware {
    /// Creates compression middleware with default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }

    /// Creates compression middleware with custom configuration.
    #[must_use]
    pub fn with_config(config: CompressionConfig) -> Self {
        Self { config }
    }

    /// Checks if the client accepts gzip encoding.
    ///
    /// Returns `false` when the `Accept-Encoding` header is absent or not
    /// UTF-8; otherwise defers to [`Self::accept_encoding_allows_gzip`].
    fn accepts_gzip(req: &Request) -> bool {
        if let Some(accept_encoding) = req.headers().get("accept-encoding") {
            if let Ok(value) = std::str::from_utf8(accept_encoding) {
                return Self::accept_encoding_allows_gzip(value);
            }
        }
        false
    }

    /// Decides whether an `Accept-Encoding` header value permits gzip.
    ///
    /// Examples: `"gzip"`, `"gzip, deflate"`, `"gzip;q=1.0, identity;q=0.5"`.
    /// An explicit `gzip` entry decides the outcome; otherwise a `*` wildcard
    /// entry does; with neither, gzip is not accepted. An entry whose
    /// quality value is zero (`gzip;q=0`, `*;q=0`) means "not acceptable".
    /// If a coding is listed more than once, its first occurrence wins.
    fn accept_encoding_allows_gzip(value: &str) -> bool {
        let mut gzip: Option<bool> = None;
        let mut wildcard: Option<bool> = None;

        for entry in value.split(',') {
            let mut params = entry.split(';');
            let coding = params.next().unwrap_or("").trim();
            let slot = if coding.eq_ignore_ascii_case("gzip") {
                &mut gzip
            } else if coding == "*" {
                &mut wildcard
            } else {
                continue;
            };
            if slot.is_none() {
                *slot = Some(Self::quality_is_nonzero(params));
            }
        }

        gzip.or(wildcard).unwrap_or(false)
    }

    /// Returns `false` only when the parameters carry a `q` value of zero.
    ///
    /// A missing or unparseable `q` is treated as acceptable, matching the
    /// lenient handling of entries without parameters.
    fn quality_is_nonzero<'s>(params: impl Iterator<Item = &'s str>) -> bool {
        for param in params {
            if let Some((name, q)) = param.split_once('=') {
                if name.trim().eq_ignore_ascii_case("q") {
                    return q.trim().parse::<f32>().map_or(true, |q| q > 0.0);
                }
            }
        }
        true
    }

    /// Gets the Content-Type from response headers (first match, case-insensitive).
    fn get_content_type(headers: &[(String, Vec<u8>)]) -> Option<String> {
        for (name, value) in headers {
            if name.eq_ignore_ascii_case("content-type") {
                return std::str::from_utf8(value).ok().map(String::from);
            }
        }
        None
    }

    /// Checks if response already has Content-Encoding header.
    fn has_content_encoding(headers: &[(String, Vec<u8>)]) -> bool {
        headers
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
    }

    /// Compresses data using gzip at the given level.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported by the gzip encoder.
    fn compress_gzip(data: &[u8], level: u32) -> Result<Vec<u8>, std::io::Error> {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;

        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
        encoder.write_all(data)?;
        encoder.finish()
    }
}
2650
#[cfg(feature = "compression")]
impl Middleware for CompressionMiddleware {
    /// Gzip-compresses eligible response bodies.
    ///
    /// The response is returned unchanged when any of these hold:
    /// - the client does not accept gzip;
    /// - the response already has a `Content-Encoding`;
    /// - the body is not a byte buffer;
    /// - the body is below `min_size`;
    /// - the content type is in the skip list;
    /// - compression fails or does not shrink the body.
    ///
    /// The configuration is borrowed for the future's lifetime rather than
    /// cloned per response.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let config = &self.config;

        Box::pin(async move {
            // Check if client accepts gzip
            if !Self::accepts_gzip(req) {
                return response;
            }

            // Decompose response to inspect body
            let (status, headers, body) = response.into_parts();

            // Check if already compressed
            if Self::has_content_encoding(&headers) {
                return Response::with_status(status)
                    .body(body)
                    .rebuild_with_headers(headers);
            }

            // Get body bytes (only compress Bytes variant, not streaming)
            let body_bytes = match body {
                crate::response::ResponseBody::Bytes(bytes) => bytes,
                other => {
                    // Can't compress Empty or Stream bodies
                    return Response::with_status(status)
                        .body(other)
                        .rebuild_with_headers(headers);
                }
            };

            // Check minimum size
            if body_bytes.len() < config.min_size {
                return Response::with_status(status)
                    .body(crate::response::ResponseBody::Bytes(body_bytes))
                    .rebuild_with_headers(headers);
            }

            // Check content type
            if let Some(content_type) = Self::get_content_type(&headers) {
                if config.should_skip_content_type(&content_type) {
                    return Response::with_status(status)
                        .body(crate::response::ResponseBody::Bytes(body_bytes))
                        .rebuild_with_headers(headers);
                }
            }

            // Compress the body
            match Self::compress_gzip(&body_bytes, config.level) {
                Ok(compressed) => {
                    // Only use compressed if it's actually smaller
                    if compressed.len() >= body_bytes.len() {
                        return Response::with_status(status)
                            .body(crate::response::ResponseBody::Bytes(body_bytes))
                            .rebuild_with_headers(headers);
                    }

                    // Build response with compression headers
                    let mut resp = Response::with_status(status)
                        .body(crate::response::ResponseBody::Bytes(compressed));

                    // Copy original headers, dropping Content-Length: it
                    // described the uncompressed body.
                    for (name, value) in headers {
                        if !name.eq_ignore_ascii_case("content-length") {
                            resp = resp.header(name, value);
                        }
                    }

                    // Add compression headers
                    resp = resp.header("Content-Encoding", b"gzip".to_vec());
                    resp = resp.header("Vary", b"Accept-Encoding".to_vec());

                    resp
                }
                Err(_) => {
                    // Compression failed, return original
                    Response::with_status(status)
                        .body(crate::response::ResponseBody::Bytes(body_bytes))
                        .rebuild_with_headers(headers)
                }
            }
        })
    }

    fn name(&self) -> &'static str {
        "Compression"
    }
}
2745
2746// ---------------------------------------------------------------------------
2747// Rate Limiting Middleware
2748// ---------------------------------------------------------------------------
2749
2750use parking_lot::Mutex;
2751use std::collections::HashMap as StdHashMap;
2752use std::time::Duration;
2753
/// Rate limiting algorithm.
///
/// Selects how request counts are accumulated and replenished for each key
/// produced by a [`KeyExtractor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitAlgorithm {
    /// Token bucket: steady refill rate, allows short bursts.
    TokenBucket,
    /// Fixed window: resets at the start of each interval.
    FixedWindow,
    /// Sliding window: weighted combination of current and previous window.
    SlidingWindow,
}
2764
/// Result of a rate limit check.
///
/// Describes the outcome of a rate-limit check for a single request, plus the
/// quota figures needed to report the limit back to the client.
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// Whether the request is allowed.
    pub allowed: bool,
    /// Maximum requests per window.
    pub limit: u64,
    /// Remaining requests in the current window.
    pub remaining: u64,
    /// Seconds until the window resets.
    pub reset_after_secs: u64,
}
2777
/// Extracts a rate limit key from a request.
///
/// Different extractors allow rate limiting by different criteria:
/// IP address, API key header, path, or custom logic.
///
/// Requests that map to the same key share one quota. Implementors must be
/// `Send + Sync` (presumably so a single extractor can be shared across
/// concurrently handled requests).
pub trait KeyExtractor: Send + Sync {
    /// Extract the key string from the request.
    ///
    /// Returns `None` if no key can be extracted (request is not rate-limited).
    fn extract_key(&self, req: &Request) -> Option<String>;
}
2788
/// The remote address (peer IP) of the TCP connection.
///
/// This should be set by the HTTP server layer as a request extension to enable
/// secure IP-based rate limiting. Unlike `X-Forwarded-For` headers, this value
/// cannot be spoofed by clients.
///
/// # Example
///
/// ```ignore
/// // In your HTTP server code:
/// use fastapi_core::middleware::RemoteAddr;
/// use std::net::IpAddr;
///
/// // When accepting a connection:
/// let peer_addr: IpAddr = socket.peer_addr()?.ip();
/// request.insert_extension(RemoteAddr(peer_addr));
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RemoteAddr(pub std::net::IpAddr);

impl std::fmt::Display for RemoteAddr {
    /// Formats the wrapped IP address exactly as `std::net::IpAddr` does.
    ///
    /// Delegating to the inner `Display` impl, rather than re-formatting via
    /// `write!`, keeps formatter flags such as width and alignment
    /// (`{:>15}`) working.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
2814
2815/// Rate limit by the actual TCP connection IP address.
2816///
2817/// This is the **secure** way to do IP-based rate limiting. It uses the
2818/// `RemoteAddr` extension set by the HTTP server, which represents the actual
2819/// TCP peer address and cannot be spoofed by clients.
2820///
2821/// # Prerequisites
2822///
2823/// Your HTTP server must set the `RemoteAddr` extension on each request:
2824///
2825/// ```ignore
2826/// request.insert_extension(RemoteAddr(peer_addr.ip()));
2827/// ```
2828///
2829/// If `RemoteAddr` is not set, this extractor returns `None` (request is not rate-limited).
2830///
2831/// # Security
2832///
2833/// This extractor is safe to use without a reverse proxy, as it relies on the
2834/// TCP connection's peer address rather than client-supplied headers.
2835#[derive(Debug, Clone)]
2836pub struct ConnectedIpKeyExtractor;
2837
2838impl KeyExtractor for ConnectedIpKeyExtractor {
2839    fn extract_key(&self, req: &Request) -> Option<String> {
2840        req.get_extension::<RemoteAddr>().map(ToString::to_string)
2841    }
2842}
2843
2844/// Rate limit by client IP address from `X-Forwarded-For` or `X-Real-IP` headers.
2845///
2846/// # Security Warning
2847///
2848/// **This extractor trusts client-supplied headers, which can be spoofed!**
2849///
2850/// Only use this extractor when:
2851/// 1. Your application runs behind a trusted reverse proxy (nginx, Cloudflare, etc.)
2852/// 2. The proxy is configured to set/override these headers
2853/// 3. Clients cannot connect directly to your application
2854///
2855/// For direct client connections, use [`ConnectedIpKeyExtractor`] instead.
2856///
2857/// # How Proxies Work
2858///
2859/// When a request passes through proxies:
2860/// - `X-Forwarded-For: client_ip, proxy1_ip, proxy2_ip`
2861/// - The first IP is typically the original client
2862/// - Each proxy appends its own IP
2863///
2864/// This extractor takes the **first** IP from `X-Forwarded-For`, which is correct
2865/// only if your trusted proxy always sets/overwrites this header.
2866///
2867/// # Fallback Behavior
2868///
2869/// Falls back to `"unknown"` when no IP header is present, which means all such
2870/// requests share the same rate limit bucket. This may not be desirable in
2871/// production - consider using [`TrustedProxyIpKeyExtractor`] for better control.
2872#[derive(Debug, Clone)]
2873pub struct IpKeyExtractor;
2874
2875impl KeyExtractor for IpKeyExtractor {
2876    fn extract_key(&self, req: &Request) -> Option<String> {
2877        // Try X-Forwarded-For first, then X-Real-IP, then fall back
2878        if let Some(forwarded) = req.headers().get("x-forwarded-for") {
2879            if let Ok(s) = std::str::from_utf8(forwarded) {
2880                // Take the first IP (client IP) from the chain
2881                if let Some(ip) = s.split(',').next() {
2882                    return Some(ip.trim().to_string());
2883                }
2884            }
2885        }
2886        if let Some(real_ip) = req.headers().get("x-real-ip") {
2887            if let Ok(s) = std::str::from_utf8(real_ip) {
2888                return Some(s.trim().to_string());
2889            }
2890        }
2891        Some("unknown".to_string())
2892    }
2893}
2894
2895/// Rate limit by client IP with trusted proxy validation.
2896///
2897/// This is a **secure** IP extractor that only trusts `X-Forwarded-For` headers
2898/// when the immediate upstream (TCP peer) is a known trusted proxy.
2899///
2900/// # How It Works
2901///
2902/// 1. If `RemoteAddr` extension is set and matches a trusted proxy CIDR:
2903///    - Extract client IP from `X-Forwarded-For` (first IP in chain)
2904/// 2. If `RemoteAddr` is set but NOT a trusted proxy:
2905///    - Use the `RemoteAddr` directly (the client connected directly)
2906/// 3. If `RemoteAddr` is not set:
2907///    - Returns `None` (request is not rate-limited) - safer than guessing
2908///
2909/// # Example
2910///
2911/// ```ignore
2912/// use fastapi_core::middleware::{TrustedProxyIpKeyExtractor, RateLimitMiddleware};
2913///
2914/// let extractor = TrustedProxyIpKeyExtractor::new()
2915///     .trust_cidr("10.0.0.0/8")      // Internal network
2916///     .trust_cidr("172.16.0.0/12")   // Docker default
2917///     .trust_loopback();              // localhost
2918///
2919/// let rate_limiter = RateLimitMiddleware::builder()
2920///     .requests(100)
2921///     .per(Duration::from_secs(60))
2922///     .key_extractor(extractor)
2923///     .build();
2924/// ```
2925#[derive(Debug, Clone)]
2926pub struct TrustedProxyIpKeyExtractor {
2927    /// List of trusted proxy CIDRs (stored as (ip, prefix_len))
2928    trusted_cidrs: Vec<(std::net::IpAddr, u8)>,
2929}
2930
2931impl TrustedProxyIpKeyExtractor {
2932    /// Create a new trusted proxy IP extractor with no trusted proxies.
2933    #[must_use]
2934    pub fn new() -> Self {
2935        Self {
2936            trusted_cidrs: Vec::new(),
2937        }
2938    }
2939
2940    /// Add a trusted CIDR range (e.g., "10.0.0.0/8", "192.168.1.0/24").
2941    ///
2942    /// # Panics
2943    ///
2944    /// Panics if the CIDR string is invalid.
2945    #[must_use]
2946    pub fn trust_cidr(mut self, cidr: &str) -> Self {
2947        let (ip, prefix) = parse_cidr(cidr).expect("invalid CIDR notation");
2948        self.trusted_cidrs.push((ip, prefix));
2949        self
2950    }
2951
2952    /// Trust loopback addresses (127.0.0.0/8 for IPv4, ::1/128 for IPv6).
2953    #[must_use]
2954    pub fn trust_loopback(mut self) -> Self {
2955        self.trusted_cidrs.push((
2956            std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 0)),
2957            8,
2958        ));
2959        self.trusted_cidrs
2960            .push((std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST), 128));
2961        self
2962    }
2963
2964    /// Check if an IP is within any trusted CIDR range.
2965    fn is_trusted(&self, ip: std::net::IpAddr) -> bool {
2966        self.trusted_cidrs
2967            .iter()
2968            .any(|(cidr_ip, prefix)| ip_in_cidr(ip, *cidr_ip, *prefix))
2969    }
2970
2971    /// Extract client IP from X-Forwarded-For header.
2972    fn extract_from_header(&self, req: &Request) -> Option<String> {
2973        if let Some(forwarded) = req.headers().get("x-forwarded-for") {
2974            if let Ok(s) = std::str::from_utf8(forwarded) {
2975                if let Some(ip) = s.split(',').next() {
2976                    return Some(ip.trim().to_string());
2977                }
2978            }
2979        }
2980        if let Some(real_ip) = req.headers().get("x-real-ip") {
2981            if let Ok(s) = std::str::from_utf8(real_ip) {
2982                return Some(s.trim().to_string());
2983            }
2984        }
2985        None
2986    }
2987}
2988
2989impl Default for TrustedProxyIpKeyExtractor {
2990    fn default() -> Self {
2991        Self::new()
2992    }
2993}
2994
2995impl KeyExtractor for TrustedProxyIpKeyExtractor {
2996    fn extract_key(&self, req: &Request) -> Option<String> {
2997        let remote = req.get_extension::<RemoteAddr>()?;
2998
2999        if self.is_trusted(remote.0) {
3000            // Request came from trusted proxy - use header value
3001            self.extract_from_header(req)
3002                .or_else(|| Some(remote.to_string()))
3003        } else {
3004            // Request came directly from client - use connection IP
3005            Some(remote.to_string())
3006        }
3007    }
3008}
3009
/// Parse CIDR notation such as `"192.168.1.0/24"` or `"fd00::/8"` into
/// `(address, prefix_length)`.
///
/// Returns `None` when the `/` separator is missing, either side fails to
/// parse, or the prefix is longer than the address family allows (32 bits for
/// IPv4, 128 for IPv6). Host bits in the address are returned as written.
fn parse_cidr(cidr: &str) -> Option<(std::net::IpAddr, u8)> {
    let slash = cidr.find('/')?;
    let address = cidr[..slash].parse::<std::net::IpAddr>().ok()?;
    let prefix_len = cidr[slash + 1..].parse::<u8>().ok()?;

    let family_bits: u8 = if address.is_ipv4() { 32 } else { 128 };
    (prefix_len <= family_bits).then_some((address, prefix_len))
}
3027
/// Check if an IP address is within a CIDR range.
///
/// Only the leading `prefix` bits of both addresses are compared, so host bits
/// in `cidr_ip` are ignored. A prefix of 0 matches every address of the same
/// family; a prefix longer than the family's width is treated as an exact match
/// (instead of overflowing the shift).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are compared against IPv4
/// ranges using their embedded IPv4 address. Dual-stack listeners report IPv4
/// peers this way, and without the mapping rules like `127.0.0.0/8` would never
/// match. Any other IPv4/IPv6 mix does not match.
fn ip_in_cidr(ip: std::net::IpAddr, cidr_ip: std::net::IpAddr, prefix: u8) -> bool {
    use std::net::IpAddr;

    match (ip, cidr_ip) {
        (IpAddr::V4(ip), IpAddr::V4(cidr)) => {
            // checked_shl yields None for a full-width shift (prefix 0) -> mask 0.
            let mask = u32::MAX
                .checked_shl(32u32.saturating_sub(u32::from(prefix)))
                .unwrap_or(0);
            (u32::from(ip) & mask) == (u32::from(cidr) & mask)
        }
        (IpAddr::V6(ip), IpAddr::V6(cidr)) => {
            let mask = u128::MAX
                .checked_shl(128u32.saturating_sub(u32::from(prefix)))
                .unwrap_or(0);
            (u128::from(ip) & mask) == (u128::from(cidr) & mask)
        }
        (IpAddr::V6(ip), IpAddr::V4(_)) => ip
            .to_ipv4_mapped()
            .map_or(false, |v4| ip_in_cidr(IpAddr::V4(v4), cidr_ip, prefix)),
        (IpAddr::V4(_), IpAddr::V6(_)) => false,
    }
}
3052
/// Keys requests by the value of a chosen header (for example `X-API-Key`),
/// so each distinct header value gets its own budget.
#[derive(Clone, Debug)]
pub struct HeaderKeyExtractor {
    /// Name of the header whose value becomes the rate limit key.
    header_name: String,
}

impl HeaderKeyExtractor {
    /// Build an extractor that reads `header_name` from each request.
    #[must_use]
    pub fn new(header_name: impl Into<String>) -> Self {
        let header_name: String = header_name.into();
        HeaderKeyExtractor { header_name }
    }
}
3068
3069impl KeyExtractor for HeaderKeyExtractor {
3070    fn extract_key(&self, req: &Request) -> Option<String> {
3071        req.headers()
3072            .get(&self.header_name)
3073            .and_then(|v| std::str::from_utf8(v).ok())
3074            .map(str::to_string)
3075    }
3076}
3077
3078/// Rate limit by request path.
3079#[derive(Debug, Clone)]
3080pub struct PathKeyExtractor;
3081
3082impl KeyExtractor for PathKeyExtractor {
3083    fn extract_key(&self, req: &Request) -> Option<String> {
3084        Some(req.path().to_string())
3085    }
3086}
3087
3088/// A composite key extractor that combines multiple extractors.
3089///
3090/// Keys from all extractors are joined with `:` to form a composite key.
3091/// If any extractor returns `None`, that part is omitted.
3092pub struct CompositeKeyExtractor {
3093    extractors: Vec<Box<dyn KeyExtractor>>,
3094}
3095
3096impl CompositeKeyExtractor {
3097    /// Create a composite key extractor from multiple extractors.
3098    #[must_use]
3099    pub fn new(extractors: Vec<Box<dyn KeyExtractor>>) -> Self {
3100        Self { extractors }
3101    }
3102}
3103
3104impl KeyExtractor for CompositeKeyExtractor {
3105    fn extract_key(&self, req: &Request) -> Option<String> {
3106        let parts: Vec<String> = self
3107            .extractors
3108            .iter()
3109            .filter_map(|e| e.extract_key(req))
3110            .collect();
3111        if parts.is_empty() {
3112            None
3113        } else {
3114            Some(parts.join(":"))
3115        }
3116    }
3117}
3118
/// Per-key state for [`RateLimitAlgorithm::TokenBucket`].
#[derive(Clone, Debug)]
struct TokenBucketState {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// When `tokens` was last topped up.
    last_refill: Instant,
}

/// Per-key state for [`RateLimitAlgorithm::FixedWindow`].
#[derive(Clone, Debug)]
struct FixedWindowState {
    /// Requests admitted since `window_start`.
    count: u64,
    /// Start of the current window.
    window_start: Instant,
}

/// Per-key state for [`RateLimitAlgorithm::SlidingWindow`].
#[derive(Clone, Debug)]
struct SlidingWindowState {
    /// Requests admitted in the current window.
    current_count: u64,
    /// Count carried over from the previous window when the windows rotated.
    previous_count: u64,
    /// Start of the current window.
    current_window_start: Instant,
}
3140
/// In-memory rate limit store.
///
/// Uses a `HashMap` protected by a `Mutex` for thread-safe access.
/// Suitable for single-process deployments. For distributed systems,
/// implement a custom store using Redis or similar.
///
/// Each algorithm has its own map, so state is not shared between algorithms.
///
/// # Memory
///
/// An entry is created the first time a key is checked, and `check` never
/// evicts entries, so memory grows with the number of distinct keys seen.
/// When keys derive from client-controlled input (for example a spoofable
/// `X-Forwarded-For` header), clients can inflate the maps by varying it.
pub struct InMemoryRateLimitStore {
    /// Per-key state for `RateLimitAlgorithm::TokenBucket`.
    token_buckets: Mutex<StdHashMap<String, TokenBucketState>>,
    /// Per-key state for `RateLimitAlgorithm::FixedWindow`.
    fixed_windows: Mutex<StdHashMap<String, FixedWindowState>>,
    /// Per-key state for `RateLimitAlgorithm::SlidingWindow`.
    sliding_windows: Mutex<StdHashMap<String, SlidingWindowState>>,
}
3151
3152impl InMemoryRateLimitStore {
3153    /// Create a new in-memory store.
3154    #[must_use]
3155    pub fn new() -> Self {
3156        Self {
3157            token_buckets: Mutex::new(StdHashMap::new()),
3158            fixed_windows: Mutex::new(StdHashMap::new()),
3159            sliding_windows: Mutex::new(StdHashMap::new()),
3160        }
3161    }
3162
3163    #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
3164    fn check_token_bucket(
3165        &self,
3166        key: &str,
3167        max_tokens: u64,
3168        refill_rate: f64,
3169        window: Duration,
3170    ) -> RateLimitResult {
3171        let mut buckets = self.token_buckets.lock();
3172        let now = Instant::now();
3173
3174        let state = buckets
3175            .entry(key.to_string())
3176            .or_insert_with(|| TokenBucketState {
3177                tokens: max_tokens as f64,
3178                last_refill: now,
3179            });
3180
3181        // Refill tokens based on elapsed time
3182        let elapsed = now.duration_since(state.last_refill);
3183        let refill = elapsed.as_secs_f64() * refill_rate;
3184        state.tokens = (state.tokens + refill).min(max_tokens as f64);
3185        state.last_refill = now;
3186
3187        if state.tokens >= 1.0 {
3188            state.tokens -= 1.0;
3189            RateLimitResult {
3190                allowed: true,
3191                limit: max_tokens,
3192                remaining: state.tokens as u64,
3193                reset_after_secs: if state.tokens < max_tokens as f64 {
3194                    ((max_tokens as f64 - state.tokens) / refill_rate).ceil() as u64
3195                } else {
3196                    window.as_secs()
3197                },
3198            }
3199        } else {
3200            let wait_secs = ((1.0 - state.tokens) / refill_rate).ceil() as u64;
3201            RateLimitResult {
3202                allowed: false,
3203                limit: max_tokens,
3204                remaining: 0,
3205                reset_after_secs: wait_secs,
3206            }
3207        }
3208    }
3209
3210    fn check_fixed_window(
3211        &self,
3212        key: &str,
3213        max_requests: u64,
3214        window: Duration,
3215    ) -> RateLimitResult {
3216        let mut windows = self.fixed_windows.lock();
3217        let now = Instant::now();
3218
3219        let state = windows
3220            .entry(key.to_string())
3221            .or_insert_with(|| FixedWindowState {
3222                count: 0,
3223                window_start: now,
3224            });
3225
3226        // Check if window has expired
3227        let elapsed = now.duration_since(state.window_start);
3228        if elapsed >= window {
3229            state.count = 0;
3230            state.window_start = now;
3231        }
3232
3233        let remaining_time = window
3234            .checked_sub(now.duration_since(state.window_start))
3235            .unwrap_or(Duration::ZERO);
3236
3237        if state.count < max_requests {
3238            state.count += 1;
3239            RateLimitResult {
3240                allowed: true,
3241                limit: max_requests,
3242                remaining: max_requests - state.count,
3243                reset_after_secs: remaining_time.as_secs(),
3244            }
3245        } else {
3246            RateLimitResult {
3247                allowed: false,
3248                limit: max_requests,
3249                remaining: 0,
3250                reset_after_secs: remaining_time.as_secs(),
3251            }
3252        }
3253    }
3254
3255    #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
3256    fn check_sliding_window(
3257        &self,
3258        key: &str,
3259        max_requests: u64,
3260        window: Duration,
3261    ) -> RateLimitResult {
3262        let mut windows = self.sliding_windows.lock();
3263        let now = Instant::now();
3264
3265        let state = windows
3266            .entry(key.to_string())
3267            .or_insert_with(|| SlidingWindowState {
3268                current_count: 0,
3269                previous_count: 0,
3270                current_window_start: now,
3271            });
3272
3273        // Check if we need to rotate windows
3274        let elapsed = now.duration_since(state.current_window_start);
3275        if elapsed >= window {
3276            // Rotate: current becomes previous
3277            state.previous_count = state.current_count;
3278            state.current_count = 0;
3279            state.current_window_start = now;
3280        }
3281
3282        // Calculate weighted count using the proportion of the previous window
3283        // that overlaps with the current sliding window
3284        let window_elapsed = now.duration_since(state.current_window_start);
3285        let window_fraction = window_elapsed.as_secs_f64() / window.as_secs_f64();
3286        let previous_weight = 1.0 - window_fraction;
3287        let weighted_count =
3288            (state.previous_count as f64 * previous_weight) + state.current_count as f64;
3289
3290        let remaining_time = window.checked_sub(window_elapsed).unwrap_or(Duration::ZERO);
3291
3292        if weighted_count < max_requests as f64 {
3293            state.current_count += 1;
3294            let new_weighted =
3295                (state.previous_count as f64 * previous_weight) + state.current_count as f64;
3296            let remaining = (max_requests as f64 - new_weighted).max(0.0) as u64;
3297            RateLimitResult {
3298                allowed: true,
3299                limit: max_requests,
3300                remaining,
3301                reset_after_secs: remaining_time.as_secs(),
3302            }
3303        } else {
3304            RateLimitResult {
3305                allowed: false,
3306                limit: max_requests,
3307                remaining: 0,
3308                reset_after_secs: remaining_time.as_secs(),
3309            }
3310        }
3311    }
3312
3313    /// Check and consume a request against the rate limit.
3314    #[allow(clippy::cast_precision_loss)]
3315    pub fn check(
3316        &self,
3317        key: &str,
3318        algorithm: RateLimitAlgorithm,
3319        max_requests: u64,
3320        window: Duration,
3321    ) -> RateLimitResult {
3322        match algorithm {
3323            RateLimitAlgorithm::TokenBucket => {
3324                let refill_rate = max_requests as f64 / window.as_secs_f64();
3325                self.check_token_bucket(key, max_requests, refill_rate, window)
3326            }
3327            RateLimitAlgorithm::FixedWindow => self.check_fixed_window(key, max_requests, window),
3328            RateLimitAlgorithm::SlidingWindow => {
3329                self.check_sliding_window(key, max_requests, window)
3330            }
3331        }
3332    }
3333}
3334
3335impl Default for InMemoryRateLimitStore {
3336    fn default() -> Self {
3337        Self::new()
3338    }
3339}
3340
3341/// Configuration for the rate limiting middleware.
3342///
3343/// Controls request rate limits using token bucket or sliding window algorithms.
3344/// When the limit is exceeded, a 429 Too Many Requests response is returned.
3345///
3346/// # Defaults
3347///
3348/// | Setting | Default |
3349/// |---------|---------|
3350/// | `max_requests` | 100 |
3351/// | `window` | 60s |
3352/// | `algorithm` | `TokenBucket` |
3353/// | `include_headers` | `true` |
3354/// | `retry_message` | "Rate limit exceeded. Please retry later." |
3355///
3356/// # Response Headers (when `include_headers` is `true`)
3357///
3358/// - `X-RateLimit-Limit`: Maximum requests per window
3359/// - `X-RateLimit-Remaining`: Remaining requests in current window
3360/// - `X-RateLimit-Reset`: Seconds until window resets
3361/// - `Retry-After`: Seconds to wait (only on 429 responses)
3362///
3363/// # Example
3364///
3365/// ```ignore
3366/// use fastapi_core::middleware::{RateLimitBuilder, RateLimitAlgorithm};
3367///
3368/// let rate_limit = RateLimitBuilder::new()
3369///     .max_requests(1000)
3370///     .window_secs(3600) // 1000 req/hour
3371///     .algorithm(RateLimitAlgorithm::SlidingWindow)
3372///     .build();
3373/// ```
3374#[derive(Clone)]
3375pub struct RateLimitConfig {
3376    /// Maximum number of requests allowed per window.
3377    pub max_requests: u64,
3378    /// Time window for the rate limit.
3379    pub window: Duration,
3380    /// The algorithm to use.
3381    pub algorithm: RateLimitAlgorithm,
3382    /// Whether to include rate limit headers in responses.
3383    pub include_headers: bool,
3384    /// Custom message for 429 responses.
3385    pub retry_message: String,
3386}
3387
3388impl Default for RateLimitConfig {
3389    fn default() -> Self {
3390        Self {
3391            max_requests: 100,
3392            window: Duration::from_secs(60),
3393            algorithm: RateLimitAlgorithm::TokenBucket,
3394            include_headers: true,
3395            retry_message: "Rate limit exceeded. Please retry later.".to_string(),
3396        }
3397    }
3398}
3399
3400/// Builder for `RateLimitConfig`.
3401pub struct RateLimitBuilder {
3402    config: RateLimitConfig,
3403    key_extractor: Option<Box<dyn KeyExtractor>>,
3404}
3405
3406impl RateLimitBuilder {
3407    /// Create a new rate limit builder with default configuration.
3408    #[must_use]
3409    pub fn new() -> Self {
3410        Self {
3411            config: RateLimitConfig::default(),
3412            key_extractor: None,
3413        }
3414    }
3415
3416    /// Set the maximum number of requests per window.
3417    #[must_use]
3418    pub fn requests(mut self, max: u64) -> Self {
3419        self.config.max_requests = max;
3420        self
3421    }
3422
3423    /// Set the time window.
3424    #[must_use]
3425    pub fn per(mut self, window: Duration) -> Self {
3426        self.config.window = window;
3427        self
3428    }
3429
3430    /// Shorthand: set the window to the given number of seconds.
3431    #[must_use]
3432    pub fn per_second(self, secs: u64) -> Self {
3433        self.per(Duration::from_secs(secs))
3434    }
3435
3436    /// Shorthand: set the window to the given number of minutes.
3437    #[must_use]
3438    pub fn per_minute(self, minutes: u64) -> Self {
3439        self.per(Duration::from_secs(minutes * 60))
3440    }
3441
3442    /// Shorthand: set the window to the given number of hours.
3443    #[must_use]
3444    pub fn per_hour(self, hours: u64) -> Self {
3445        self.per(Duration::from_secs(hours * 3600))
3446    }
3447
3448    /// Set the rate limiting algorithm.
3449    #[must_use]
3450    pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
3451        self.config.algorithm = algo;
3452        self
3453    }
3454
3455    /// Set the key extractor.
3456    #[must_use]
3457    pub fn key_extractor(mut self, extractor: impl KeyExtractor + 'static) -> Self {
3458        self.key_extractor = Some(Box::new(extractor));
3459        self
3460    }
3461
3462    /// Whether to include rate limit headers in responses.
3463    #[must_use]
3464    pub fn include_headers(mut self, include: bool) -> Self {
3465        self.config.include_headers = include;
3466        self
3467    }
3468
3469    /// Set the custom message for 429 responses.
3470    #[must_use]
3471    pub fn retry_message(mut self, msg: impl Into<String>) -> Self {
3472        self.config.retry_message = msg.into();
3473        self
3474    }
3475
3476    /// Build the rate limiting middleware.
3477    #[must_use]
3478    pub fn build(self) -> RateLimitMiddleware {
3479        let key_extractor = self
3480            .key_extractor
3481            .unwrap_or_else(|| Box::new(IpKeyExtractor));
3482        RateLimitMiddleware {
3483            config: self.config,
3484            store: Arc::new(InMemoryRateLimitStore::new()),
3485            key_extractor: Arc::from(key_extractor),
3486        }
3487    }
3488}
3489
3490impl Default for RateLimitBuilder {
3491    fn default() -> Self {
3492        Self::new()
3493    }
3494}
3495
3496/// Extension type stored on requests to carry rate limit info to `after` hook.
3497#[derive(Debug, Clone)]
3498struct RateLimitInfo {
3499    result: RateLimitResult,
3500}
3501
3502/// Rate limiting middleware.
3503///
3504/// Tracks request rates per key and returns 429 Too Many Requests
3505/// when a client exceeds the configured limit.
3506///
3507/// # Example
3508///
3509/// ```ignore
3510/// use fastapi_core::middleware::{RateLimitMiddleware, RateLimitAlgorithm, IpKeyExtractor};
3511/// use std::time::Duration;
3512///
3513/// let rate_limiter = RateLimitMiddleware::builder()
3514///     .requests(100)
3515///     .per(Duration::from_secs(60))
3516///     .algorithm(RateLimitAlgorithm::TokenBucket)
3517///     .key_extractor(IpKeyExtractor)
3518///     .build();
3519///
3520/// let app = App::builder()
3521///     .middleware(rate_limiter)
3522///     .build();
3523/// ```
3524pub struct RateLimitMiddleware {
3525    config: RateLimitConfig,
3526    store: Arc<InMemoryRateLimitStore>,
3527    key_extractor: Arc<dyn KeyExtractor>,
3528}
3529
3530impl RateLimitMiddleware {
3531    /// Create a new rate limiter with default settings (100 requests/minute, token bucket, IP-based).
3532    #[must_use]
3533    pub fn new() -> Self {
3534        Self::builder().build()
3535    }
3536
3537    /// Create a builder for configuring the rate limiter.
3538    #[must_use]
3539    pub fn builder() -> RateLimitBuilder {
3540        RateLimitBuilder::new()
3541    }
3542
3543    /// Format a 429 response body as JSON.
3544    fn too_many_requests_body(&self, result: &RateLimitResult) -> Vec<u8> {
3545        format!(
3546            r#"{{"detail":"{}","retry_after_secs":{}}}"#,
3547            self.config.retry_message, result.reset_after_secs
3548        )
3549        .into_bytes()
3550    }
3551
3552    /// Add rate limit headers to a response.
3553    fn add_headers(&self, response: Response, result: &RateLimitResult) -> Response {
3554        response
3555            .header("X-RateLimit-Limit", result.limit.to_string().into_bytes())
3556            .header(
3557                "X-RateLimit-Remaining",
3558                result.remaining.to_string().into_bytes(),
3559            )
3560            .header(
3561                "X-RateLimit-Reset",
3562                result.reset_after_secs.to_string().into_bytes(),
3563            )
3564    }
3565}
3566
3567impl Default for RateLimitMiddleware {
3568    fn default() -> Self {
3569        Self::new()
3570    }
3571}
3572
3573impl Middleware for RateLimitMiddleware {
3574    fn before<'a>(
3575        &'a self,
3576        _ctx: &'a RequestContext,
3577        req: &'a mut Request,
3578    ) -> BoxFuture<'a, ControlFlow> {
3579        Box::pin(async move {
3580            // Extract the key for this request
3581            let Some(key) = self.key_extractor.extract_key(req) else {
3582                // No key extracted — skip rate limiting for this request
3583                return ControlFlow::Continue;
3584            };
3585
3586            // Check the rate limit
3587            let result = self.store.check(
3588                &key,
3589                self.config.algorithm,
3590                self.config.max_requests,
3591                self.config.window,
3592            );
3593
3594            if result.allowed {
3595                // Store the result for the `after` hook to add headers
3596                req.insert_extension(RateLimitInfo { result });
3597                ControlFlow::Continue
3598            } else {
3599                // Return 429 Too Many Requests
3600                let body = self.too_many_requests_body(&result);
3601                let mut response =
3602                    Response::with_status(crate::response::StatusCode::TOO_MANY_REQUESTS)
3603                        .header("Content-Type", b"application/json".to_vec())
3604                        .header(
3605                            "Retry-After",
3606                            result.reset_after_secs.to_string().into_bytes(),
3607                        )
3608                        .body(crate::response::ResponseBody::Bytes(body));
3609
3610                if self.config.include_headers {
3611                    response = self.add_headers(response, &result);
3612                }
3613
3614                ControlFlow::Break(response)
3615            }
3616        })
3617    }
3618
3619    fn after<'a>(
3620        &'a self,
3621        _ctx: &'a RequestContext,
3622        req: &'a Request,
3623        response: Response,
3624    ) -> BoxFuture<'a, Response> {
3625        Box::pin(async move {
3626            if !self.config.include_headers {
3627                return response;
3628            }
3629
3630            // Retrieve the rate limit info stored in `before`
3631            if let Some(info) = req.get_extension::<RateLimitInfo>() {
3632                self.add_headers(response, &info.result)
3633            } else {
3634                response
3635            }
3636        })
3637    }
3638
3639    fn name(&self) -> &'static str {
3640        "RateLimit"
3641    }
3642}
3643
3644// ---------------------------------------------------------------------------
3645// End Rate Limiting Middleware
3646// ---------------------------------------------------------------------------
3647
3648// ============================================================================
3649// Request Inspection Middleware (Development)
3650// ============================================================================
3651
/// How much detail the request inspection middleware prints per exchange.
///
/// Each level includes everything shown by the previous one.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InspectionVerbosity {
    /// One summary line for the request and one for the response.
    ///
    /// Shows: `-->  GET /path` and `<--  200 OK (12ms)`
    Minimal,

    /// Summary lines plus headers.
    ///
    /// Shows method/path, all headers (filtered), and status/timing.
    Normal,

    /// Summary lines, headers, and body previews.
    ///
    /// Adds request/response body previews to `Normal`, pretty-printing
    /// bodies that look like JSON.
    Verbose,
}
3673
3674/// Development middleware that logs detailed, human-readable request/response
3675/// information using arrow-style formatting.
3676///
3677/// This middleware is designed for development and debugging. It outputs
3678/// concise inspection lines showing request flow:
3679///
3680/// ```text
3681/// -->  POST /api/users
3682///      Content-Type: application/json
3683///      Content-Length: 42
3684///      {"name": "Alice"}
3685/// <--  201 Created (12ms)
3686///      Content-Type: application/json
3687///      {"id": 1, "name": "Alice"}
3688/// ```
3689///
3690/// # Features
3691///
3692/// - **Configurable verbosity**: Minimal (one-liner), Normal (+ headers),
3693///   Verbose (+ body preview with JSON pretty-printing)
3694/// - **Slow request highlighting**: Marks requests exceeding a threshold
3695/// - **Sensitive header filtering**: Redacts authorization, cookie, etc.
3696/// - **JSON pretty-printing**: Detects JSON bodies and formats them
3697/// - **Body size limits**: Truncates large bodies to a configurable max
3698///
3699/// # Example
3700///
3701/// ```ignore
3702/// use fastapi_core::middleware::RequestInspectionMiddleware;
3703///
3704/// let inspector = RequestInspectionMiddleware::new()
3705///     .verbosity(InspectionVerbosity::Verbose)
3706///     .slow_threshold_ms(500)
3707///     .max_body_preview(4096);
3708///
3709/// let mut stack = MiddlewareStack::new();
3710/// stack.push(inspector);
3711/// ```
3712pub struct RequestInspectionMiddleware {
3713    log_config: LogConfig,
3714    verbosity: InspectionVerbosity,
3715    redact_headers: HashSet<String>,
3716    slow_threshold_ms: u64,
3717    max_body_preview: usize,
3718}
3719
3720impl Default for RequestInspectionMiddleware {
3721    fn default() -> Self {
3722        Self {
3723            log_config: LogConfig::development(),
3724            verbosity: InspectionVerbosity::Normal,
3725            redact_headers: default_redacted_headers(),
3726            slow_threshold_ms: 1000,
3727            max_body_preview: 2048,
3728        }
3729    }
3730}
3731
3732impl RequestInspectionMiddleware {
3733    /// Create a new inspection middleware with development defaults.
3734    #[must_use]
3735    pub fn new() -> Self {
3736        Self::default()
3737    }
3738
3739    /// Set the logging configuration.
3740    #[must_use]
3741    pub fn log_config(mut self, config: LogConfig) -> Self {
3742        self.log_config = config;
3743        self
3744    }
3745
3746    /// Set the verbosity level.
3747    #[must_use]
3748    pub fn verbosity(mut self, level: InspectionVerbosity) -> Self {
3749        self.verbosity = level;
3750        self
3751    }
3752
3753    /// Set the threshold (in milliseconds) above which requests are flagged as slow.
3754    #[must_use]
3755    pub fn slow_threshold_ms(mut self, ms: u64) -> Self {
3756        self.slow_threshold_ms = ms;
3757        self
3758    }
3759
3760    /// Set the maximum number of bytes to show in body previews.
3761    #[must_use]
3762    pub fn max_body_preview(mut self, max: usize) -> Self {
3763        self.max_body_preview = max;
3764        self
3765    }
3766
3767    /// Add a header name to the redaction set (case-insensitive).
3768    #[must_use]
3769    pub fn redact_header(mut self, name: impl Into<String>) -> Self {
3770        self.redact_headers.insert(name.into().to_ascii_lowercase());
3771        self
3772    }
3773
3774    /// Format a request body for display, with optional JSON pretty-printing.
3775    fn format_body_preview(&self, bytes: &[u8], content_type: Option<&[u8]>) -> Option<String> {
3776        if bytes.is_empty() || self.max_body_preview == 0 {
3777            return None;
3778        }
3779
3780        let is_json = content_type
3781            .and_then(|ct| std::str::from_utf8(ct).ok())
3782            .is_some_and(|ct| ct.contains("application/json"));
3783
3784        let limit = self.max_body_preview.min(bytes.len());
3785        let truncated = bytes.len() > self.max_body_preview;
3786
3787        match std::str::from_utf8(&bytes[..limit]) {
3788            Ok(text) => {
3789                if is_json {
3790                    // Attempt JSON pretty-printing on the full available text
3791                    if let Some(pretty) = try_pretty_json(text) {
3792                        let mut output = pretty;
3793                        if truncated {
3794                            output.push_str("\n     ... (truncated)");
3795                        }
3796                        return Some(output);
3797                    }
3798                }
3799                let mut output = text.to_string();
3800                if truncated {
3801                    output.push_str("...");
3802                }
3803                Some(output)
3804            }
3805            Err(_) => Some(format!("<{} bytes binary>", bytes.len())),
3806        }
3807    }
3808
3809    /// Format a response body for display.
3810    fn format_response_preview(
3811        &self,
3812        body: &crate::response::ResponseBody,
3813        content_type: Option<&[u8]>,
3814    ) -> Option<String> {
3815        match body {
3816            crate::response::ResponseBody::Empty => None,
3817            crate::response::ResponseBody::Bytes(bytes) => {
3818                self.format_body_preview(bytes, content_type)
3819            }
3820            crate::response::ResponseBody::Stream(_) => Some("<streaming body>".to_string()),
3821        }
3822    }
3823
3824    /// Build the formatted header block for display.
3825    fn format_inspection_headers<'a>(
3826        &self,
3827        headers: impl Iterator<Item = (&'a str, &'a [u8])>,
3828    ) -> String {
3829        let mut out = String::new();
3830        for (name, value) in headers {
3831            out.push_str("\n     ");
3832            out.push_str(name);
3833            out.push_str(": ");
3834
3835            let lowered = name.to_ascii_lowercase();
3836            if self.redact_headers.contains(&lowered) {
3837                out.push_str("[REDACTED]");
3838            } else {
3839                match std::str::from_utf8(value) {
3840                    Ok(text) => out.push_str(text),
3841                    Err(_) => out.push_str("<binary>"),
3842                }
3843            }
3844        }
3845        out
3846    }
3847
3848    /// Build the response header block from (String, Vec<u8>) pairs.
3849    fn format_response_inspection_headers(&self, headers: &[(String, Vec<u8>)]) -> String {
3850        self.format_inspection_headers(
3851            headers
3852                .iter()
3853                .map(|(name, value)| (name.as_str(), value.as_slice())),
3854        )
3855    }
3856}
3857
/// Request extension recording when inspection of a request began.
///
/// Inserted by the inspection middleware's `before` hook and read back in
/// `after` to compute the elapsed time.
#[derive(Clone, Debug)]
struct InspectionStart(Instant);
3861
3862impl Middleware for RequestInspectionMiddleware {
3863    fn before<'a>(
3864        &'a self,
3865        ctx: &'a RequestContext,
3866        req: &'a mut Request,
3867    ) -> BoxFuture<'a, ControlFlow> {
3868        let logger = RequestLogger::new(ctx, self.log_config.clone());
3869        req.insert_extension(InspectionStart(Instant::now()));
3870
3871        let method = req.method();
3872        let path = req.path();
3873        let query = req.query();
3874
3875        // Build the request line: "-->  GET /path?query"
3876        let mut request_line = format!("-->  {method} {path}");
3877        if let Some(q) = query {
3878            request_line.push('?');
3879            request_line.push_str(q);
3880        }
3881
3882        let body_size = body_len(req.body());
3883        if body_size > 0 {
3884            request_line.push_str(&format!(" ({body_size} bytes)"));
3885        }
3886
3887        match self.verbosity {
3888            InspectionVerbosity::Minimal => {
3889                logger.info(request_line);
3890            }
3891            InspectionVerbosity::Normal => {
3892                let headers = self.format_inspection_headers(req.headers().iter());
3893                logger.info(format!("{request_line}{headers}"));
3894            }
3895            InspectionVerbosity::Verbose => {
3896                let headers = self.format_inspection_headers(req.headers().iter());
3897                let content_type = req.headers().get("content-type");
3898                let body_preview = match req.body() {
3899                    Body::Empty => None,
3900                    Body::Bytes(bytes) => self.format_body_preview(bytes, content_type),
3901                    Body::Stream(_) => Some("<streaming body>".to_string()),
3902                };
3903
3904                let mut output = format!("{request_line}{headers}");
3905                if let Some(body) = body_preview {
3906                    output.push_str("\n     ");
3907                    // Indent multi-line body previews
3908                    output.push_str(&body.replace('\n', "\n     "));
3909                }
3910                logger.info(output);
3911            }
3912        }
3913
3914        Box::pin(async { ControlFlow::Continue })
3915    }
3916
3917    fn after<'a>(
3918        &'a self,
3919        ctx: &'a RequestContext,
3920        req: &'a Request,
3921        response: Response,
3922    ) -> BoxFuture<'a, Response> {
3923        let logger = RequestLogger::new(ctx, self.log_config.clone());
3924        let duration = req
3925            .get_extension::<InspectionStart>()
3926            .map(|start| start.0.elapsed())
3927            .unwrap_or_default();
3928
3929        let status = response.status();
3930        let duration_ms = duration.as_millis();
3931
3932        // Build the response line: "<--  200 OK (12ms)"
3933        let mut response_line = format!(
3934            "<--  {} {} ({duration_ms}ms)",
3935            status.as_u16(),
3936            status.canonical_reason(),
3937        );
3938
3939        // Flag slow requests
3940        if duration_ms >= u128::from(self.slow_threshold_ms) {
3941            response_line.push_str(" [SLOW]");
3942        }
3943
3944        match self.verbosity {
3945            InspectionVerbosity::Minimal => {
3946                if duration_ms >= u128::from(self.slow_threshold_ms) {
3947                    logger.warn(response_line);
3948                } else {
3949                    logger.info(response_line);
3950                }
3951            }
3952            InspectionVerbosity::Normal => {
3953                let headers = self.format_response_inspection_headers(response.headers());
3954                let output = format!("{response_line}{headers}");
3955                if duration_ms >= u128::from(self.slow_threshold_ms) {
3956                    logger.warn(output);
3957                } else {
3958                    logger.info(output);
3959                }
3960            }
3961            InspectionVerbosity::Verbose => {
3962                let headers = self.format_response_inspection_headers(response.headers());
3963
3964                // Find content-type from response headers for JSON detection
3965                let resp_content_type: Option<&[u8]> = response
3966                    .headers()
3967                    .iter()
3968                    .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
3969                    .map(|(_, value)| value.as_slice());
3970
3971                let body_preview =
3972                    self.format_response_preview(response.body_ref(), resp_content_type);
3973
3974                let mut output = format!("{response_line}{headers}");
3975                if let Some(body) = body_preview {
3976                    output.push_str("\n     ");
3977                    output.push_str(&body.replace('\n', "\n     "));
3978                }
3979
3980                if duration_ms >= u128::from(self.slow_threshold_ms) {
3981                    logger.warn(output);
3982                } else {
3983                    logger.info(output);
3984                }
3985            }
3986        }
3987
3988        Box::pin(async move { response })
3989    }
3990
3991    fn name(&self) -> &'static str {
3992        "RequestInspection"
3993    }
3994}
3995
3996/// Attempt to parse and pretty-print a JSON string.
3997///
3998/// Returns `None` if the input is not valid JSON. Uses a minimal
3999/// recursive formatter to avoid external dependencies.
4000fn try_pretty_json(input: &str) -> Option<String> {
4001    let trimmed = input.trim();
4002    if !trimmed.starts_with('{') && !trimmed.starts_with('[') {
4003        return None;
4004    }
4005
4006    // Validate it's actual JSON by attempting a parse, then pretty-format.
4007    let mut output = String::with_capacity(trimmed.len() * 2);
4008    if json_pretty_format(trimmed, &mut output).is_ok() {
4009        Some(output)
4010    } else {
4011        None
4012    }
4013}
4014
4015/// Minimal JSON pretty-formatter without external dependencies.
4016///
4017/// Handles objects, arrays, strings, numbers, booleans, and null.
4018/// Produces 2-space indented output.
4019fn json_pretty_format(input: &str, output: &mut String) -> Result<(), ()> {
4020    let bytes = input.as_bytes();
4021    let mut pos = 0;
4022    let mut indent: usize = 0;
4023    let mut in_string = false;
4024    let mut escape_next = false;
4025
4026    while pos < bytes.len() {
4027        let ch = bytes[pos] as char;
4028
4029        if escape_next {
4030            output.push(ch);
4031            escape_next = false;
4032            pos += 1;
4033            continue;
4034        }
4035
4036        if in_string {
4037            output.push(ch);
4038            if ch == '\\' {
4039                escape_next = true;
4040            } else if ch == '"' {
4041                in_string = false;
4042            }
4043            pos += 1;
4044            continue;
4045        }
4046
4047        match ch {
4048            '"' => {
4049                in_string = true;
4050                output.push('"');
4051            }
4052            '{' | '[' => {
4053                output.push(ch);
4054                // Peek ahead: if the next non-whitespace is the closing bracket, keep compact
4055                let peek = skip_whitespace(bytes, pos + 1);
4056                let closing = if ch == '{' { '}' } else { ']' };
4057                if peek < bytes.len() && bytes[peek] as char == closing {
4058                    output.push(closing);
4059                    pos = peek + 1;
4060                    continue;
4061                }
4062                indent += 1;
4063                output.push('\n');
4064                push_indent(output, indent);
4065            }
4066            '}' | ']' => {
4067                indent = indent.saturating_sub(1);
4068                output.push('\n');
4069                push_indent(output, indent);
4070                output.push(ch);
4071            }
4072            ':' => {
4073                output.push_str(": ");
4074            }
4075            ',' => {
4076                output.push(',');
4077                output.push('\n');
4078                push_indent(output, indent);
4079            }
4080            c if c.is_ascii_whitespace() => {
4081                // Skip whitespace outside strings
4082            }
4083            _ => {
4084                output.push(ch);
4085            }
4086        }
4087
4088        pos += 1;
4089    }
4090
4091    if in_string || indent != 0 {
4092        return Err(());
4093    }
4094
4095    Ok(())
4096}
4097
/// Return the index of the first non-ASCII-whitespace byte at or after
/// `start`, or `start` itself when `start` is already at/after the end.
fn skip_whitespace(bytes: &[u8], start: usize) -> usize {
    bytes.get(start..).map_or(start, |rest| {
        start + rest.iter().take_while(|b| b.is_ascii_whitespace()).count()
    })
}
4105
/// Append `level` indentation units (two spaces each) to `output`.
fn push_indent(output: &mut String, level: usize) {
    output.extend(std::iter::repeat("  ").take(level));
}
4111
4112// ---------------------------------------------------------------------------
4113// End Request Inspection Middleware
4114// ---------------------------------------------------------------------------
4115
4116// ===========================================================================
4117// ETag Middleware
4118// ===========================================================================
4119
/// Configuration for ETag generation strategy.
///
/// Defaults to [`ETagMode::Auto`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ETagMode {
    /// Automatically generate ETag from response body hash.
    /// Uses FNV-1a hash for fast, consistent ETag generation.
    #[default]
    Auto,
    /// Expect handler to set ETag manually. Middleware only handles
    /// conditional request logic (If-None-Match checking).
    Manual,
    /// Disable ETag handling entirely.
    Disabled,
}
4138
4139/// Configuration for ETag middleware.
4140#[derive(Debug, Clone)]
4141pub struct ETagConfig {
4142    /// ETag generation mode.
4143    pub mode: ETagMode,
4144    /// Generate weak ETags (W/"...") instead of strong ETags.
4145    /// Weak ETags indicate semantic equivalence, allowing minor changes.
4146    pub weak: bool,
4147    /// Minimum response body size to generate ETag.
4148    /// Responses smaller than this won't get an ETag.
4149    pub min_size: usize,
4150}
4151
4152impl Default for ETagConfig {
4153    fn default() -> Self {
4154        Self {
4155            mode: ETagMode::Auto,
4156            weak: false,
4157            min_size: 0,
4158        }
4159    }
4160}
4161
4162impl ETagConfig {
4163    /// Create a new ETag configuration with default settings.
4164    #[must_use]
4165    pub fn new() -> Self {
4166        Self::default()
4167    }
4168
4169    /// Set the ETag generation mode.
4170    #[must_use]
4171    pub fn mode(mut self, mode: ETagMode) -> Self {
4172        self.mode = mode;
4173        self
4174    }
4175
4176    /// Enable weak ETags.
4177    #[must_use]
4178    pub fn weak(mut self, weak: bool) -> Self {
4179        self.weak = weak;
4180        self
4181    }
4182
4183    /// Set minimum body size for ETag generation.
4184    #[must_use]
4185    pub fn min_size(mut self, size: usize) -> Self {
4186        self.min_size = size;
4187        self
4188    }
4189}
4190
4191/// Middleware for ETag generation and conditional request handling.
4192///
4193/// Implements HTTP caching through ETags as defined in RFC 7232.
4194///
4195/// # Features
4196///
4197/// - **Automatic ETag generation**: Computes ETag from response body hash
4198/// - **If-None-Match handling**: Returns 304 Not Modified for GET/HEAD when ETag matches
4199/// - **Weak and strong ETags**: Configurable ETag strength
4200///
4201/// # Example
4202///
4203/// ```ignore
4204/// use fastapi_core::middleware::{ETagMiddleware, ETagConfig, ETagMode};
4205///
4206/// // Default: auto-generate strong ETags
4207/// let middleware = ETagMiddleware::new();
4208///
4209/// // With custom configuration
4210/// let middleware = ETagMiddleware::with_config(
4211///     ETagConfig::new()
4212///         .mode(ETagMode::Auto)
4213///         .weak(true)
4214///         .min_size(1024)
4215/// );
4216/// ```
4217///
4218/// # Conditional Request Flow
4219///
4220/// For GET/HEAD requests with `If-None-Match` header:
4221/// 1. Generate ETag for response body
4222/// 2. Compare with client's cached ETag
4223/// 3. If match: return 304 Not Modified (empty body)
4224/// 4. If no match: return full response with ETag header
4225///
4226/// # Note on If-Match
4227///
4228/// `If-Match` handling for PUT/PATCH/DELETE is typically done at the
4229/// application level since it requires knowledge of the current resource
4230/// state before the modification occurs.
4231pub struct ETagMiddleware {
4232    config: ETagConfig,
4233}
4234
4235impl Default for ETagMiddleware {
4236    fn default() -> Self {
4237        Self::new()
4238    }
4239}
4240
4241impl ETagMiddleware {
4242    /// Create ETag middleware with default configuration.
4243    #[must_use]
4244    pub fn new() -> Self {
4245        Self {
4246            config: ETagConfig::default(),
4247        }
4248    }
4249
4250    /// Create ETag middleware with custom configuration.
4251    #[must_use]
4252    pub fn with_config(config: ETagConfig) -> Self {
4253        Self { config }
4254    }
4255
4256    /// Generate an ETag from response body bytes using FNV-1a hash.
4257    ///
4258    /// FNV-1a is chosen for:
4259    /// - Speed: Very fast for small to medium data
4260    /// - Consistency: Deterministic output
4261    /// - Simplicity: No external dependencies
4262    fn generate_etag(data: &[u8], weak: bool) -> String {
4263        // FNV-1a 64-bit hash
4264        const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
4265        const FNV_PRIME: u64 = 0x100000001b3;
4266
4267        let mut hash = FNV_OFFSET_BASIS;
4268        for &byte in data {
4269            hash ^= u64::from(byte);
4270            hash = hash.wrapping_mul(FNV_PRIME);
4271        }
4272
4273        // Format as quoted hex string
4274        if weak {
4275            format!("W/\"{:016x}\"", hash)
4276        } else {
4277            format!("\"{:016x}\"", hash)
4278        }
4279    }
4280
4281    /// Parse ETags from If-None-Match header value.
4282    ///
4283    /// Handles:
4284    /// - Single ETag: "abc123"
4285    /// - Multiple ETags: "abc123", "def456"
4286    /// - Wildcard: *
4287    /// - Weak ETags: W/"abc123"
4288    fn parse_if_none_match(value: &str) -> Vec<String> {
4289        let trimmed = value.trim();
4290
4291        // Handle wildcard
4292        if trimmed == "*" {
4293            return vec!["*".to_string()];
4294        }
4295
4296        let mut etags = Vec::new();
4297        let mut current = String::new();
4298        let mut in_quote = false;
4299        let mut prev_char = '\0';
4300
4301        for ch in trimmed.chars() {
4302            match ch {
4303                '"' if prev_char != '\\' => {
4304                    current.push(ch);
4305                    if in_quote {
4306                        // End of ETag value
4307                        let etag = current.trim().to_string();
4308                        if !etag.is_empty() {
4309                            etags.push(etag);
4310                        }
4311                        current.clear();
4312                    }
4313                    in_quote = !in_quote;
4314                }
4315                ',' if !in_quote => {
4316                    // ETag separator, already handled by quote closing
4317                    current.clear();
4318                }
4319                _ => {
4320                    current.push(ch);
4321                }
4322            }
4323            prev_char = ch;
4324        }
4325
4326        etags
4327    }
4328
4329    /// Check if two ETags match according to weak comparison rules.
4330    ///
4331    /// Weak comparison (for If-None-Match with GET/HEAD):
4332    /// - W/"a" matches W/"a"
4333    /// - W/"a" matches "a"
4334    /// - "a" matches W/"a"
4335    /// - "a" matches "a"
4336    fn etags_match_weak(etag1: &str, etag2: &str) -> bool {
4337        // Strip W/ prefix for weak comparison
4338        let e1 = Self::strip_weak_prefix(etag1);
4339        let e2 = Self::strip_weak_prefix(etag2);
4340        e1 == e2
4341    }
4342
4343    /// Strip the weak ETag prefix (W/) if present.
4344    fn strip_weak_prefix(s: &str) -> &str {
4345        if s.starts_with("W/") || s.starts_with("w/") {
4346            &s[2..]
4347        } else {
4348            s
4349        }
4350    }
4351
4352    /// Check if request method is cacheable (GET or HEAD).
4353    fn is_cacheable_method(method: crate::request::Method) -> bool {
4354        matches!(
4355            method,
4356            crate::request::Method::Get | crate::request::Method::Head
4357        )
4358    }
4359
4360    /// Get existing ETag from response headers.
4361    fn get_existing_etag(headers: &[(String, Vec<u8>)]) -> Option<String> {
4362        for (name, value) in headers {
4363            if name.eq_ignore_ascii_case("etag") {
4364                return std::str::from_utf8(value).ok().map(String::from);
4365            }
4366        }
4367        None
4368    }
4369}
4370
4371impl Middleware for ETagMiddleware {
4372    fn after<'a>(
4373        &'a self,
4374        _ctx: &'a RequestContext,
4375        req: &'a Request,
4376        response: Response,
4377    ) -> BoxFuture<'a, Response> {
4378        let config = self.config.clone();
4379
4380        Box::pin(async move {
4381            // Skip if disabled
4382            if config.mode == ETagMode::Disabled {
4383                return response;
4384            }
4385
4386            // Only handle cacheable methods
4387            if !Self::is_cacheable_method(req.method()) {
4388                return response;
4389            }
4390
4391            // Decompose response to work with parts
4392            let (status, headers, body) = response.into_parts();
4393
4394            // Check for existing ETag (for Manual mode or pre-set ETags)
4395            let existing_etag = Self::get_existing_etag(&headers);
4396
4397            // Get body bytes if available
4398            let body_bytes = match &body {
4399                crate::response::ResponseBody::Bytes(bytes) => Some(bytes.clone()),
4400                crate::response::ResponseBody::Empty => Some(Vec::new()),
4401                crate::response::ResponseBody::Stream(_) => None,
4402            };
4403
4404            // Determine the ETag to use
4405            let etag = if let Some(existing) = existing_etag {
4406                Some(existing)
4407            } else if config.mode == ETagMode::Auto {
4408                if let Some(ref bytes) = body_bytes {
4409                    if bytes.len() >= config.min_size {
4410                        Some(Self::generate_etag(bytes, config.weak))
4411                    } else {
4412                        None
4413                    }
4414                } else {
4415                    None
4416                }
4417            } else {
4418                None
4419            };
4420
4421            // Check If-None-Match header
4422            if let Some(ref etag_value) = etag {
4423                if let Some(if_none_match) = req.headers().get("if-none-match") {
4424                    if let Ok(value) = std::str::from_utf8(if_none_match) {
4425                        let client_etags = Self::parse_if_none_match(value);
4426
4427                        // Check for wildcard or matching ETag
4428                        let matches = client_etags.iter().any(|client_etag| {
4429                            client_etag == "*" || Self::etags_match_weak(client_etag, etag_value)
4430                        });
4431
4432                        if matches {
4433                            // Return 304 Not Modified with ETag header
4434                            return Response::with_status(
4435                                crate::response::StatusCode::NOT_MODIFIED,
4436                            )
4437                            .header("etag", etag_value.as_bytes().to_vec());
4438                        }
4439                    }
4440                }
4441            }
4442
4443            // Rebuild response with ETag header if we have one
4444            let mut new_response = Response::with_status(status)
4445                .body(body)
4446                .rebuild_with_headers(headers);
4447
4448            if let Some(etag_value) = etag {
4449                new_response = new_response.header("etag", etag_value.into_bytes());
4450            }
4451
4452            new_response
4453        })
4454    }
4455
4456    fn name(&self) -> &'static str {
4457        "ETagMiddleware"
4458    }
4459}
4460
4461// ===========================================================================
4462// HTTP Cache Control Middleware
4463// ===========================================================================
4464
/// Individual Cache-Control directives.
///
/// These directives control how responses are cached by browsers, proxies,
/// and CDNs. The core directives are specified in RFC 9111 (which obsoletes
/// RFC 7234). `stale-while-revalidate` and `stale-if-error` come from
/// RFC 5861, and `immutable` from RFC 8246.
///
/// `SMaxAge`, `StaleWhileRevalidate` and `StaleIfError` take a
/// delta-seconds argument on the wire; `as_str` yields only the bare
/// directive name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheDirective {
    /// Response may be stored by any cache, including shared caches.
    Public,
    /// Response is intended for a single user; shared caches (CDNs, proxies)
    /// must not store it.
    Private,
    /// Caches must not store any part of the request or response.
    NoStore,
    /// Caches may store the response but must revalidate it with the origin
    /// before each reuse.
    NoCache,
    /// Intermediaries must not transform the payload (e.g. recompress images).
    NoTransform,
    /// Once stale, the response must not be reused without successful
    /// revalidation with the origin.
    MustRevalidate,
    /// Like `must-revalidate`, but applies only to shared caches.
    ProxyRevalidate,
    /// A stale response may be served when revalidation fails with an error
    /// (RFC 5861).
    StaleIfError,
    /// A stale response may be served while it is revalidated in the
    /// background (RFC 5861).
    StaleWhileRevalidate,
    /// Maximum age for shared caches; overrides `max-age` for them.
    SMaxAge,
    /// Request directive: the client wants only an already-stored response
    /// and no request to the origin.
    OnlyIfCached,
    /// Indicates an immutable response that won't change during its freshness
    /// lifetime (RFC 8246).
    Immutable,
}

impl CacheDirective {
    /// Returns the directive name as it appears in a Cache-Control header.
    fn as_str(self) -> &'static str {
        match self {
            Self::Public => "public",
            Self::Private => "private",
            Self::NoStore => "no-store",
            Self::NoCache => "no-cache",
            Self::NoTransform => "no-transform",
            Self::MustRevalidate => "must-revalidate",
            Self::ProxyRevalidate => "proxy-revalidate",
            Self::StaleIfError => "stale-if-error",
            Self::StaleWhileRevalidate => "stale-while-revalidate",
            Self::SMaxAge => "s-maxage",
            Self::OnlyIfCached => "only-if-cached",
            Self::Immutable => "immutable",
        }
    }
}
4516
4517/// Builder for constructing Cache-Control header values.
4518///
4519/// Provides a fluent API for building complex cache control policies.
4520///
4521/// # Example
4522///
4523/// ```ignore
4524/// use fastapi_core::middleware::CacheControlBuilder;
4525///
4526/// // Public, cacheable for 1 hour, must revalidate after
4527/// let cache = CacheControlBuilder::new()
4528///     .public()
4529///     .max_age_secs(3600)
4530///     .must_revalidate()
4531///     .build();
4532///
4533/// // Private, no caching
4534/// let no_cache = CacheControlBuilder::new()
4535///     .private()
4536///     .no_store()
4537///     .build();
4538///
4539/// // CDN-friendly: public with different browser/CDN TTLs
4540/// let cdn = CacheControlBuilder::new()
4541///     .public()
4542///     .max_age_secs(60)        // Browser caches for 1 minute
4543///     .s_maxage_secs(3600)     // CDN caches for 1 hour
4544///     .build();
4545/// ```
4546#[derive(Debug, Clone, Default)]
4547pub struct CacheControlBuilder {
4548    directives: Vec<CacheDirective>,
4549    max_age: Option<u32>,
4550    s_maxage: Option<u32>,
4551    stale_while_revalidate: Option<u32>,
4552    stale_if_error: Option<u32>,
4553}
4554
4555impl CacheControlBuilder {
4556    /// Create a new empty Cache-Control builder.
4557    #[must_use]
4558    pub fn new() -> Self {
4559        Self::default()
4560    }
4561
4562    /// Add the `public` directive - response may be cached by any cache.
4563    #[must_use]
4564    pub fn public(mut self) -> Self {
4565        self.directives.push(CacheDirective::Public);
4566        self
4567    }
4568
4569    /// Add the `private` directive - response may only be cached by browser.
4570    #[must_use]
4571    pub fn private(mut self) -> Self {
4572        self.directives.push(CacheDirective::Private);
4573        self
4574    }
4575
4576    /// Add the `no-store` directive - response must not be cached.
4577    #[must_use]
4578    pub fn no_store(mut self) -> Self {
4579        self.directives.push(CacheDirective::NoStore);
4580        self
4581    }
4582
4583    /// Add the `no-cache` directive - must revalidate before using cache.
4584    #[must_use]
4585    pub fn no_cache(mut self) -> Self {
4586        self.directives.push(CacheDirective::NoCache);
4587        self
4588    }
4589
4590    /// Add the `no-transform` directive - caches must not modify response.
4591    #[must_use]
4592    pub fn no_transform(mut self) -> Self {
4593        self.directives.push(CacheDirective::NoTransform);
4594        self
4595    }
4596
4597    /// Add the `must-revalidate` directive - cache must check origin when stale.
4598    #[must_use]
4599    pub fn must_revalidate(mut self) -> Self {
4600        self.directives.push(CacheDirective::MustRevalidate);
4601        self
4602    }
4603
4604    /// Add the `proxy-revalidate` directive - shared caches must check origin when stale.
4605    #[must_use]
4606    pub fn proxy_revalidate(mut self) -> Self {
4607        self.directives.push(CacheDirective::ProxyRevalidate);
4608        self
4609    }
4610
4611    /// Add the `immutable` directive - response won't change during freshness lifetime.
4612    #[must_use]
4613    pub fn immutable(mut self) -> Self {
4614        self.directives.push(CacheDirective::Immutable);
4615        self
4616    }
4617
4618    /// Set `max-age` directive - maximum time response is fresh (in seconds).
4619    #[must_use]
4620    pub fn max_age_secs(mut self, seconds: u32) -> Self {
4621        self.max_age = Some(seconds);
4622        self
4623    }
4624
4625    /// Set `max-age` directive from a Duration.
4626    #[must_use]
4627    pub fn max_age(self, duration: std::time::Duration) -> Self {
4628        self.max_age_secs(duration.as_secs() as u32)
4629    }
4630
4631    /// Set `s-maxage` directive - maximum time for shared caches (in seconds).
4632    #[must_use]
4633    pub fn s_maxage_secs(mut self, seconds: u32) -> Self {
4634        self.s_maxage = Some(seconds);
4635        self
4636    }
4637
4638    /// Set `s-maxage` directive from a Duration.
4639    #[must_use]
4640    pub fn s_maxage(self, duration: std::time::Duration) -> Self {
4641        self.s_maxage_secs(duration.as_secs() as u32)
4642    }
4643
4644    /// Set `stale-while-revalidate` directive - serve stale while revalidating (in seconds).
4645    #[must_use]
4646    pub fn stale_while_revalidate_secs(mut self, seconds: u32) -> Self {
4647        self.stale_while_revalidate = Some(seconds);
4648        self
4649    }
4650
4651    /// Set `stale-if-error` directive - serve stale if origin errors (in seconds).
4652    #[must_use]
4653    pub fn stale_if_error_secs(mut self, seconds: u32) -> Self {
4654        self.stale_if_error = Some(seconds);
4655        self
4656    }
4657
4658    /// Build the Cache-Control header value string.
4659    #[must_use]
4660    pub fn build(&self) -> String {
4661        let mut parts = Vec::new();
4662
4663        // Add directives
4664        for directive in &self.directives {
4665            parts.push(directive.as_str().to_string());
4666        }
4667
4668        // Add max-age
4669        if let Some(age) = self.max_age {
4670            parts.push(format!("max-age={age}"));
4671        }
4672
4673        // Add s-maxage
4674        if let Some(age) = self.s_maxage {
4675            parts.push(format!("s-maxage={age}"));
4676        }
4677
4678        // Add stale-while-revalidate
4679        if let Some(seconds) = self.stale_while_revalidate {
4680            parts.push(format!("stale-while-revalidate={seconds}"));
4681        }
4682
4683        // Add stale-if-error
4684        if let Some(seconds) = self.stale_if_error {
4685            parts.push(format!("stale-if-error={seconds}"));
4686        }
4687
4688        parts.join(", ")
4689    }
4690
4691    /// Check if this represents a no-cache policy.
4692    #[must_use]
4693    pub fn is_no_cache(&self) -> bool {
4694        self.directives.contains(&CacheDirective::NoStore)
4695            || self.directives.contains(&CacheDirective::NoCache)
4696    }
4697}
4698
/// Ready-made `Cache-Control` policies for common caching strategies.
///
/// Each variant's doc shows the exact header value produced by
/// [`CachePreset::to_header_value`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CachePreset {
    /// Disable caching entirely: `no-store, no-cache, must-revalidate`
    NoCache,
    /// Browser-only, revalidated on every use: `private, max-age=0, must-revalidate`
    PrivateNoCache,
    /// Any cache may keep it for one hour: `public, max-age=3600`
    PublicOneHour,
    /// Cache for a year without revalidation: `public, max-age=31536000, immutable`
    Immutable,
    /// One minute in browsers, one hour in CDNs: `public, max-age=60, s-maxage=3600`
    CdnFriendly,
    /// One day, suited to static files: `public, max-age=86400`
    StaticAssets,
}
4715
4716impl CachePreset {
4717    /// Convert preset to Cache-Control header value.
4718    #[must_use]
4719    pub fn to_header_value(&self) -> String {
4720        match self {
4721            Self::NoCache => "no-store, no-cache, must-revalidate".to_string(),
4722            Self::PrivateNoCache => "private, max-age=0, must-revalidate".to_string(),
4723            Self::PublicOneHour => "public, max-age=3600".to_string(),
4724            Self::Immutable => "public, max-age=31536000, immutable".to_string(),
4725            Self::CdnFriendly => "public, max-age=60, s-maxage=3600".to_string(),
4726            Self::StaticAssets => "public, max-age=86400".to_string(),
4727        }
4728    }
4729
4730    /// Convert preset to a CacheControlBuilder for further customization.
4731    #[must_use]
4732    pub fn to_builder(&self) -> CacheControlBuilder {
4733        match self {
4734            Self::NoCache => CacheControlBuilder::new()
4735                .no_store()
4736                .no_cache()
4737                .must_revalidate(),
4738            Self::PrivateNoCache => CacheControlBuilder::new()
4739                .private()
4740                .max_age_secs(0)
4741                .must_revalidate(),
4742            Self::PublicOneHour => CacheControlBuilder::new().public().max_age_secs(3600),
4743            Self::Immutable => CacheControlBuilder::new()
4744                .public()
4745                .max_age_secs(31536000)
4746                .immutable(),
4747            Self::CdnFriendly => CacheControlBuilder::new()
4748                .public()
4749                .max_age_secs(60)
4750                .s_maxage_secs(3600),
4751            Self::StaticAssets => CacheControlBuilder::new().public().max_age_secs(86400),
4752        }
4753    }
4754}
4755
4756/// Configuration for the Cache Control middleware.
4757#[derive(Debug, Clone)]
4758pub struct CacheControlConfig {
4759    /// The Cache-Control header value to set.
4760    pub cache_control: String,
4761    /// Optional Vary header values for content negotiation.
4762    pub vary: Vec<String>,
4763    /// Whether to set Expires header (deprecated but still used).
4764    pub set_expires: bool,
4765    /// Whether to preserve existing Cache-Control headers.
4766    pub preserve_existing: bool,
4767    /// HTTP methods to apply caching to (default: GET, HEAD).
4768    pub methods: Vec<crate::request::Method>,
4769    /// Path patterns to match (empty = match all).
4770    pub path_patterns: Vec<String>,
4771    /// Status codes to cache (default: 200-299).
4772    pub cacheable_statuses: Vec<u16>,
4773}
4774
4775impl Default for CacheControlConfig {
4776    fn default() -> Self {
4777        Self {
4778            cache_control: CachePreset::NoCache.to_header_value(),
4779            vary: Vec::new(),
4780            set_expires: false,
4781            preserve_existing: true,
4782            methods: vec![crate::request::Method::Get, crate::request::Method::Head],
4783            path_patterns: Vec::new(),
4784            cacheable_statuses: (200..300).collect(),
4785        }
4786    }
4787}
4788
4789impl CacheControlConfig {
4790    /// Create a new configuration with the default no-cache policy.
4791    #[must_use]
4792    pub fn new() -> Self {
4793        Self::default()
4794    }
4795
4796    /// Create configuration from a preset.
4797    #[must_use]
4798    pub fn from_preset(preset: CachePreset) -> Self {
4799        Self {
4800            cache_control: preset.to_header_value(),
4801            ..Self::default()
4802        }
4803    }
4804
4805    /// Create configuration from a custom builder.
4806    #[must_use]
4807    pub fn from_builder(builder: CacheControlBuilder) -> Self {
4808        Self {
4809            cache_control: builder.build(),
4810            ..Self::default()
4811        }
4812    }
4813
4814    /// Set the Cache-Control header value.
4815    #[must_use]
4816    pub fn cache_control(mut self, value: impl Into<String>) -> Self {
4817        self.cache_control = value.into();
4818        self
4819    }
4820
4821    /// Add a Vary header value (for content negotiation).
4822    #[must_use]
4823    pub fn vary(mut self, header: impl Into<String>) -> Self {
4824        self.vary.push(header.into());
4825        self
4826    }
4827
4828    /// Add multiple Vary header values.
4829    #[must_use]
4830    pub fn vary_headers(mut self, headers: Vec<String>) -> Self {
4831        self.vary.extend(headers);
4832        self
4833    }
4834
4835    /// Enable setting the Expires header.
4836    #[must_use]
4837    pub fn with_expires(mut self, enable: bool) -> Self {
4838        self.set_expires = enable;
4839        self
4840    }
4841
4842    /// Whether to preserve existing Cache-Control headers.
4843    #[must_use]
4844    pub fn preserve_existing(mut self, preserve: bool) -> Self {
4845        self.preserve_existing = preserve;
4846        self
4847    }
4848
4849    /// Set the HTTP methods to apply caching to.
4850    #[must_use]
4851    pub fn methods(mut self, methods: Vec<crate::request::Method>) -> Self {
4852        self.methods = methods;
4853        self
4854    }
4855
4856    /// Set path patterns to match (glob-style).
4857    #[must_use]
4858    pub fn path_patterns(mut self, patterns: Vec<String>) -> Self {
4859        self.path_patterns = patterns;
4860        self
4861    }
4862
4863    /// Set cacheable status codes.
4864    #[must_use]
4865    pub fn cacheable_statuses(mut self, statuses: Vec<u16>) -> Self {
4866        self.cacheable_statuses = statuses;
4867        self
4868    }
4869}
4870
4871/// Middleware for setting HTTP cache control headers.
4872///
4873/// This middleware adds Cache-Control, Vary, and optionally Expires headers
4874/// to responses. It supports various caching strategies from no-cache to
4875/// aggressive caching for static assets.
4876///
4877/// # Features
4878///
4879/// - **Cache-Control directives**: Full support for RFC 7234 directives
4880/// - **Vary header**: Content negotiation support for Accept-Encoding, Accept-Language, etc.
4881/// - **Expires header**: Optional legacy header support
4882/// - **Per-route configuration**: Apply different policies via middleware stacks
4883/// - **Method filtering**: Only cache GET/HEAD by default
4884/// - **Status filtering**: Only cache successful responses
4885///
4886/// # Example
4887///
4888/// ```ignore
4889/// use fastapi_core::middleware::{CacheControlMiddleware, CacheControlConfig, CachePreset};
4890///
4891/// // No caching for API responses (default)
4892/// let api_cache = CacheControlMiddleware::new();
4893///
4894/// // Public caching for static assets
4895/// let static_cache = CacheControlMiddleware::with_preset(CachePreset::StaticAssets);
4896///
4897/// // Custom caching with Vary header
4898/// let custom_cache = CacheControlMiddleware::with_config(
4899///     CacheControlConfig::from_preset(CachePreset::PublicOneHour)
4900///         .vary("Accept-Encoding")
4901///         .vary("Accept-Language")
4902///         .with_expires(true)
4903/// );
4904///
4905/// // CDN-friendly caching
4906/// let cdn_cache = CacheControlMiddleware::with_preset(CachePreset::CdnFriendly);
4907/// ```
4908///
4909/// # Response Headers Set
4910///
4911/// | Header | Description |
4912/// |--------|-------------|
4913/// | `Cache-Control` | Main caching directive |
4914/// | `Vary` | Headers that affect caching |
4915/// | `Expires` | Legacy expiration (if enabled) |
4916///
4917pub struct CacheControlMiddleware {
4918    config: CacheControlConfig,
4919}
4920
4921impl Default for CacheControlMiddleware {
4922    fn default() -> Self {
4923        Self::new()
4924    }
4925}
4926
4927impl CacheControlMiddleware {
4928    /// Create middleware with default no-cache policy.
4929    ///
4930    /// This is the safest default - no caching unless explicitly configured.
4931    #[must_use]
4932    pub fn new() -> Self {
4933        Self {
4934            config: CacheControlConfig::default(),
4935        }
4936    }
4937
4938    /// Create middleware with a preset caching policy.
4939    #[must_use]
4940    pub fn with_preset(preset: CachePreset) -> Self {
4941        Self {
4942            config: CacheControlConfig::from_preset(preset),
4943        }
4944    }
4945
4946    /// Create middleware with custom configuration.
4947    #[must_use]
4948    pub fn with_config(config: CacheControlConfig) -> Self {
4949        Self { config }
4950    }
4951
4952    /// Check if the request method is cacheable.
4953    fn is_cacheable_method(&self, method: crate::request::Method) -> bool {
4954        self.config.methods.contains(&method)
4955    }
4956
4957    /// Check if the response status is cacheable.
4958    fn is_cacheable_status(&self, status: u16) -> bool {
4959        self.config.cacheable_statuses.contains(&status)
4960    }
4961
4962    /// Check if the path matches any configured patterns.
4963    fn matches_path(&self, path: &str) -> bool {
4964        if self.config.path_patterns.is_empty() {
4965            return true; // Match all if no patterns configured
4966        }
4967
4968        for pattern in &self.config.path_patterns {
4969            if path_matches_pattern(path, pattern) {
4970                return true;
4971            }
4972        }
4973        false
4974    }
4975
4976    /// Check if response already has a Cache-Control header.
4977    fn has_cache_control(headers: &[(String, Vec<u8>)]) -> bool {
4978        headers
4979            .iter()
4980            .any(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
4981    }
4982
4983    /// Calculate Expires date from max-age value.
4984    fn calculate_expires(cache_control: &str) -> Option<String> {
4985        // Extract max-age value if present
4986        for directive in cache_control.split(',') {
4987            let directive = directive.trim();
4988            if directive.starts_with("max-age=") {
4989                if let Ok(seconds) = directive[8..].parse::<u64>() {
4990                    // Calculate expiration time
4991                    let now = std::time::SystemTime::now();
4992                    if let Some(expires) = now.checked_add(std::time::Duration::from_secs(seconds))
4993                    {
4994                        return Some(format_http_date(expires));
4995                    }
4996                }
4997            }
4998        }
4999        None
5000    }
5001}
5002
/// Glob-style path matching where `*` matches any (possibly empty) run of
/// characters, including `/`.
///
/// The pattern is anchored at both ends:
/// - literal text before the first `*` must be a prefix of `path`;
/// - literal text after the last `*` must be a suffix;
/// - the literal pieces in between must appear in order without overlapping.
///
/// A pattern without `*` matches only the identical path.
fn path_matches_pattern(path: &str, pattern: &str) -> bool {
    let mut pieces = pattern.split('*');
    // `split` always yields at least one item: the literal text before the first `*`.
    let prefix = pieces.next().unwrap_or("");
    let mut rest = match path.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };

    let tail: Vec<&str> = pieces.collect();
    let (suffix, middle) = match tail.split_last() {
        Some((&suffix, middle)) => (suffix, middle),
        // No `*` at all: the whole pattern is literal, so nothing may remain.
        None => return rest.is_empty(),
    };

    // Leftmost matching of the middle pieces is sufficient for `*`-only globs:
    // it leaves the longest possible remainder for the anchored suffix check.
    for &piece in middle {
        match rest.find(piece) {
            Some(pos) => rest = &rest[pos + piece.len()..],
            None => return false,
        }
    }

    // Checking against the unconsumed remainder prevents the suffix from
    // overlapping text already matched by the prefix or middle pieces.
    rest.ends_with(suffix)
}
5031
5032/// Format a SystemTime as an HTTP date (RFC 7231).
5033fn format_http_date(time: std::time::SystemTime) -> String {
5034    // Use UNIX_EPOCH to calculate duration
5035    match time.duration_since(std::time::UNIX_EPOCH) {
5036        Ok(duration) => {
5037            // Calculate date components
5038            let secs = duration.as_secs();
5039            // Days since epoch
5040            let days = secs / 86400;
5041            let remaining_secs = secs % 86400;
5042            let hours = remaining_secs / 3600;
5043            let minutes = (remaining_secs % 3600) / 60;
5044            let seconds = remaining_secs % 60;
5045
5046            // Calculate day of week (Jan 1, 1970 was Thursday = 4)
5047            let day_of_week = ((days + 4) % 7) as usize;
5048            let day_names = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
5049
5050            // Calculate date (simplified - doesn't account for leap years perfectly but good enough)
5051            let (year, month, day) = days_to_date(days);
5052            let month_names = [
5053                "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
5054            ];
5055
5056            format!(
5057                "{}, {:02} {} {} {:02}:{:02}:{:02} GMT",
5058                day_names[day_of_week],
5059                day,
5060                month_names[(month - 1) as usize],
5061                year,
5062                hours,
5063                minutes,
5064                seconds
5065            )
5066        }
5067        Err(_) => "Thu, 01 Jan 1970 00:00:00 GMT".to_string(),
5068    }
5069}
5070
5071/// Convert days since UNIX epoch to (year, month, day).
5072fn days_to_date(days: u64) -> (u64, u64, u64) {
5073    // Simplified algorithm - works for dates 1970-2099
5074    let mut remaining_days = days;
5075    let mut year = 1970u64;
5076
5077    loop {
5078        let days_in_year = if is_leap_year(year) { 366 } else { 365 };
5079        if remaining_days < days_in_year {
5080            break;
5081        }
5082        remaining_days -= days_in_year;
5083        year += 1;
5084    }
5085
5086    let leap = is_leap_year(year);
5087    let month_days: [u64; 12] = if leap {
5088        [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
5089    } else {
5090        [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
5091    };
5092
5093    let mut month = 1u64;
5094    for &days_in_month in &month_days {
5095        if remaining_days < days_in_month {
5096            break;
5097        }
5098        remaining_days -= days_in_month;
5099        month += 1;
5100    }
5101
5102    (year, month, remaining_days + 1)
5103}
5104
/// Gregorian leap-year test: every 4th year, except century years that are not
/// divisible by 400.
fn is_leap_year(year: u64) -> bool {
    if year % 100 == 0 {
        year % 400 == 0
    } else {
        year % 4 == 0
    }
}
5109
5110impl Middleware for CacheControlMiddleware {
5111    fn after<'a>(
5112        &'a self,
5113        _ctx: &'a RequestContext,
5114        req: &'a Request,
5115        response: Response,
5116    ) -> BoxFuture<'a, Response> {
5117        let config = self.config.clone();
5118
5119        Box::pin(async move {
5120            // Check if this request/response is cacheable
5121            if !self.is_cacheable_method(req.method()) {
5122                return response;
5123            }
5124
5125            if !self.is_cacheable_status(response.status().as_u16()) {
5126                return response;
5127            }
5128
5129            if !self.matches_path(req.path()) {
5130                return response;
5131            }
5132
5133            // Decompose response to modify headers
5134            let (status, mut headers, body) = response.into_parts();
5135
5136            // Check for existing Cache-Control header
5137            if config.preserve_existing && Self::has_cache_control(&headers) {
5138                // Reconstruct and return unchanged
5139                let mut resp = Response::with_status(status);
5140                for (name, value) in headers {
5141                    resp = resp.header(name, value);
5142                }
5143                return resp.body(body);
5144            }
5145
5146            // Add Cache-Control header
5147            headers.push((
5148                "Cache-Control".to_string(),
5149                config.cache_control.as_bytes().to_vec(),
5150            ));
5151
5152            // Add Vary header if configured
5153            if !config.vary.is_empty() {
5154                let vary_value = config.vary.join(", ");
5155                headers.push(("Vary".to_string(), vary_value.into_bytes()));
5156            }
5157
5158            // Add Expires header if configured
5159            if config.set_expires {
5160                if let Some(expires) = Self::calculate_expires(&config.cache_control) {
5161                    headers.push(("Expires".to_string(), expires.into_bytes()));
5162                }
5163            }
5164
5165            // Reconstruct response
5166            let mut resp = Response::with_status(status);
5167            for (name, value) in headers {
5168                resp = resp.header(name, value);
5169            }
5170            resp.body(body)
5171        })
5172    }
5173
5174    fn name(&self) -> &'static str {
5175        "CacheControlMiddleware"
5176    }
5177}
5178
5179// ===========================================================================
5180// End Cache Control Middleware
5181// ===========================================================================
5182
5183// ===========================================================================
5184// TRACE Method Rejection Middleware (Security)
5185// ===========================================================================
5186
/// Middleware that refuses HTTP `TRACE` requests to block Cross-Site Tracing (XST).
///
/// `TRACE` asks the server to echo the received request back to the client. Combined
/// with an XSS foothold, that echo can leak sensitive headers such as `Authorization`
/// or cookies.
///
/// # Security Rationale
///
/// - TRACE can expose Authorization headers via XSS attacks
/// - Modern APIs have no legitimate use for it
/// - OWASP recommends disabling TRACE
///
/// # Example
///
/// ```ignore
/// use fastapi_core::middleware::TraceRejectionMiddleware;
///
/// let app = App::builder()
///     .middleware(TraceRejectionMiddleware::new())
///     .build();
/// ```
///
/// # Behavior
///
/// - Every TRACE request is answered with 405 Method Not Allowed
/// - Blocked attempts are reported as security events while `log_attempts` is true
/// - There is deliberately no per-route opt-out
#[derive(Clone, Debug)]
pub struct TraceRejectionMiddleware {
    /// When true, each blocked TRACE request is reported on stderr as a security event.
    log_attempts: bool,
}
5218
5219impl Default for TraceRejectionMiddleware {
5220    fn default() -> Self {
5221        Self::new()
5222    }
5223}
5224
5225impl TraceRejectionMiddleware {
5226    /// Create a new TRACE rejection middleware with default settings.
5227    ///
5228    /// By default, logging of TRACE attempts is enabled.
5229    #[must_use]
5230    pub fn new() -> Self {
5231        Self { log_attempts: true }
5232    }
5233
5234    /// Configure whether to log TRACE attempts.
5235    ///
5236    /// When enabled, each TRACE request is logged as a security event
5237    /// including the remote IP (if available) and request path.
5238    #[must_use]
5239    pub fn log_attempts(mut self, log: bool) -> Self {
5240        self.log_attempts = log;
5241        self
5242    }
5243
5244    /// Create a response for rejected TRACE requests.
5245    fn rejection_response(path: &str) -> Response {
5246        let body = format!(
5247            r#"{{"detail":"HTTP TRACE method is not allowed","path":"{}"}}"#,
5248            path.replace('"', "\\\"")
5249        );
5250        Response::with_status(crate::response::StatusCode::METHOD_NOT_ALLOWED)
5251            .header("Content-Type", b"application/json".to_vec())
5252            .header(
5253                "Allow",
5254                b"GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD".to_vec(),
5255            )
5256            .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
5257    }
5258}
5259
5260impl Middleware for TraceRejectionMiddleware {
5261    fn before<'a>(
5262        &'a self,
5263        _ctx: &'a RequestContext,
5264        req: &'a mut Request,
5265    ) -> BoxFuture<'a, ControlFlow> {
5266        Box::pin(async move {
5267            if req.method() == crate::request::Method::Trace {
5268                if self.log_attempts {
5269                    // Log as security event
5270                    let path = req.path();
5271                    let remote_ip = req
5272                        .headers()
5273                        .get("X-Forwarded-For")
5274                        .or_else(|| req.headers().get("X-Real-IP"))
5275                        .map(|v| String::from_utf8_lossy(v).to_string())
5276                        .unwrap_or_else(|| "unknown".to_string());
5277
5278                    eprintln!(
5279                        "[SECURITY] TRACE request blocked: path={}, remote_ip={}",
5280                        path, remote_ip
5281                    );
5282                }
5283
5284                return ControlFlow::Break(Self::rejection_response(req.path()));
5285            }
5286
5287            ControlFlow::Continue
5288        })
5289    }
5290
5291    fn name(&self) -> &'static str {
5292        "TraceRejection"
5293    }
5294}
5295
5296// ===========================================================================
5297// End TRACE Rejection Middleware
5298// ===========================================================================
5299
5300// ===========================================================================
5301// HTTPS Redirect and HSTS Middleware (Security)
5302// ===========================================================================
5303
/// Settings for [`HttpsRedirectMiddleware`]: redirect behaviour plus the HSTS policy.
#[derive(Clone, Debug)]
// Each flag is an independent, user-facing switch, so separate bools read best here.
#[allow(clippy::struct_excessive_bools)]
pub struct HttpsRedirectConfig {
    /// Redirect plain-HTTP requests to HTTPS when true.
    pub redirect_enabled: bool,
    /// `true` issues permanent (301) redirects, `false` temporary (307) ones.
    pub permanent_redirect: bool,
    /// HSTS `max-age` in seconds; 0 disables the HSTS header.
    pub hsts_max_age_secs: u64,
    /// Append `includeSubDomains` to the HSTS header.
    pub hsts_include_subdomains: bool,
    /// Append `preload` to the HSTS header.
    pub hsts_preload: bool,
    /// Request paths exempt from the redirect (e.g. health checks).
    pub exclude_paths: Vec<String>,
    /// Port of the HTTPS endpoint that redirects point to (default 443).
    pub https_port: u16,
}
5323
5324impl Default for HttpsRedirectConfig {
5325    fn default() -> Self {
5326        Self {
5327            redirect_enabled: true,
5328            permanent_redirect: true,      // 301
5329            hsts_max_age_secs: 31_536_000, // 1 year
5330            hsts_include_subdomains: false,
5331            hsts_preload: false,
5332            exclude_paths: Vec::new(),
5333            https_port: 443,
5334        }
5335    }
5336}
5337
5338/// Middleware that redirects HTTP requests to HTTPS and sets HSTS headers.
5339///
5340/// This middleware provides two critical security features:
5341///
5342/// 1. **HTTP to HTTPS Redirect**: Automatically redirects insecure HTTP requests
5343///    to their HTTPS equivalents, ensuring all traffic is encrypted.
5344///
5345/// 2. **HSTS (Strict Transport Security)**: Adds the `Strict-Transport-Security`
5346///    header to HTTPS responses, instructing browsers to always use HTTPS.
5347///
5348/// # Proxy Awareness
5349///
5350/// The middleware respects the `X-Forwarded-Proto` header, so it works correctly
5351/// behind reverse proxies like nginx or HAProxy. If the proxy sets this header
5352/// to "https", the request is treated as secure.
5353///
5354/// # Example
5355///
5356/// ```ignore
5357/// use fastapi_core::middleware::HttpsRedirectMiddleware;
5358///
5359/// let app = App::builder()
5360///     .middleware(HttpsRedirectMiddleware::new()
5361///         .hsts_max_age_secs(31536000)  // 1 year
5362///         .include_subdomains(true)
5363///         .preload(true)
5364///         .exclude_path("/health")
5365///         .exclude_path("/readiness"))
5366///     .build();
5367/// ```
5368///
5369/// # Configuration Options
5370///
5371/// - `redirect_enabled`: Enable/disable redirects (default: true)
5372/// - `permanent_redirect`: Use 301 (true) or 307 (false) redirects
5373/// - `hsts_max_age_secs`: HSTS max-age value in seconds
5374/// - `include_subdomains`: Apply HSTS to all subdomains
5375/// - `preload`: Mark site for HSTS preload list
5376/// - `exclude_path`: Paths that should remain accessible over HTTP
5377#[derive(Debug, Clone)]
5378pub struct HttpsRedirectMiddleware {
5379    config: HttpsRedirectConfig,
5380}
5381
5382impl Default for HttpsRedirectMiddleware {
5383    fn default() -> Self {
5384        Self::new()
5385    }
5386}
5387
5388impl HttpsRedirectMiddleware {
5389    /// Create a new HTTPS redirect middleware with default settings.
5390    #[must_use]
5391    pub fn new() -> Self {
5392        Self {
5393            config: HttpsRedirectConfig::default(),
5394        }
5395    }
5396
5397    /// Enable or disable HTTP to HTTPS redirects.
5398    #[must_use]
5399    pub fn redirect_enabled(mut self, enabled: bool) -> Self {
5400        self.config.redirect_enabled = enabled;
5401        self
5402    }
5403
5404    /// Use permanent (301) redirects instead of temporary (307).
5405    ///
5406    /// Default is true (permanent redirects).
5407    #[must_use]
5408    pub fn permanent_redirect(mut self, permanent: bool) -> Self {
5409        self.config.permanent_redirect = permanent;
5410        self
5411    }
5412
5413    /// Set the HSTS max-age in seconds.
5414    ///
5415    /// Set to 0 to disable HSTS header.
5416    /// Default is 31536000 (1 year).
5417    #[must_use]
5418    pub fn hsts_max_age_secs(mut self, secs: u64) -> Self {
5419        self.config.hsts_max_age_secs = secs;
5420        self
5421    }
5422
5423    /// Include subdomains in HSTS policy.
5424    #[must_use]
5425    pub fn include_subdomains(mut self, include: bool) -> Self {
5426        self.config.hsts_include_subdomains = include;
5427        self
5428    }
5429
5430    /// Enable HSTS preload.
5431    ///
5432    /// Only enable this if you're ready to submit your site to the
5433    /// HSTS preload list at hstspreload.org.
5434    #[must_use]
5435    pub fn preload(mut self, preload: bool) -> Self {
5436        self.config.hsts_preload = preload;
5437        self
5438    }
5439
5440    /// Add a path to exclude from redirects.
5441    ///
5442    /// Use this for health check endpoints that need to remain
5443    /// accessible over HTTP for load balancer probes.
5444    #[must_use]
5445    pub fn exclude_path(mut self, path: impl Into<String>) -> Self {
5446        self.config.exclude_paths.push(path.into());
5447        self
5448    }
5449
5450    /// Set multiple excluded paths at once.
5451    #[must_use]
5452    pub fn exclude_paths(mut self, paths: Vec<String>) -> Self {
5453        self.config.exclude_paths = paths;
5454        self
5455    }
5456
5457    /// Set the HTTPS port (default 443).
5458    #[must_use]
5459    pub fn https_port(mut self, port: u16) -> Self {
5460        self.config.https_port = port;
5461        self
5462    }
5463
5464    /// Check if the request is using HTTPS.
5465    ///
5466    /// This checks both the scheme and the X-Forwarded-Proto header
5467    /// for proxy-aware detection.
5468    fn is_secure(&self, req: &Request) -> bool {
5469        // Check X-Forwarded-Proto header first (for reverse proxy)
5470        if let Some(proto) = req.headers().get("X-Forwarded-Proto") {
5471            return proto.eq_ignore_ascii_case(b"https");
5472        }
5473
5474        // Check X-Forwarded-Ssl header (alternative)
5475        if let Some(ssl) = req.headers().get("X-Forwarded-Ssl") {
5476            return ssl.eq_ignore_ascii_case(b"on");
5477        }
5478
5479        // Check Front-End-Https header (Microsoft IIS)
5480        if let Some(https) = req.headers().get("Front-End-Https") {
5481            return https.eq_ignore_ascii_case(b"on");
5482        }
5483
5484        // No forwarding headers - assume HTTP for now
5485        // In a real server, we'd check the connection's TLS status
5486        false
5487    }
5488
5489    /// Check if a path should be excluded from redirects.
5490    fn is_excluded(&self, path: &str) -> bool {
5491        self.config
5492            .exclude_paths
5493            .iter()
5494            .any(|p| path.starts_with(p))
5495    }
5496
5497    /// Build the HSTS header value.
5498    fn build_hsts_header(&self) -> Option<Vec<u8>> {
5499        if self.config.hsts_max_age_secs == 0 {
5500            return None;
5501        }
5502
5503        let mut value = format!("max-age={}", self.config.hsts_max_age_secs);
5504
5505        if self.config.hsts_include_subdomains {
5506            value.push_str("; includeSubDomains");
5507        }
5508
5509        if self.config.hsts_preload {
5510            value.push_str("; preload");
5511        }
5512
5513        Some(value.into_bytes())
5514    }
5515
5516    /// Build the redirect URL.
5517    fn build_redirect_url(&self, req: &Request) -> String {
5518        let host = req
5519            .headers()
5520            .get("Host")
5521            .map(|h| String::from_utf8_lossy(h).to_string())
5522            .unwrap_or_else(|| "localhost".to_string());
5523
5524        // Remove port from host if present
5525        let host_without_port = host.split(':').next().unwrap_or(&host);
5526
5527        let path = req.path();
5528        let query = req.query();
5529
5530        if self.config.https_port == 443 {
5531            match query {
5532                Some(q) => format!("https://{}{}?{}", host_without_port, path, q),
5533                None => format!("https://{}{}", host_without_port, path),
5534            }
5535        } else {
5536            match query {
5537                Some(q) => format!(
5538                    "https://{}:{}{}?{}",
5539                    host_without_port, self.config.https_port, path, q
5540                ),
5541                None => format!(
5542                    "https://{}:{}{}",
5543                    host_without_port, self.config.https_port, path
5544                ),
5545            }
5546        }
5547    }
5548}
5549
5550impl Middleware for HttpsRedirectMiddleware {
5551    fn before<'a>(
5552        &'a self,
5553        _ctx: &'a RequestContext,
5554        req: &'a mut Request,
5555    ) -> BoxFuture<'a, ControlFlow> {
5556        Box::pin(async move {
5557            // Skip if redirects are disabled
5558            if !self.config.redirect_enabled {
5559                return ControlFlow::Continue;
5560            }
5561
5562            // Skip if already HTTPS
5563            if self.is_secure(req) {
5564                return ControlFlow::Continue;
5565            }
5566
5567            // Skip excluded paths (e.g., health checks)
5568            if self.is_excluded(req.path()) {
5569                return ControlFlow::Continue;
5570            }
5571
5572            // Build redirect URL
5573            let redirect_url = self.build_redirect_url(req);
5574
5575            // Choose status code
5576            let status = if self.config.permanent_redirect {
5577                crate::response::StatusCode::MOVED_PERMANENTLY
5578            } else {
5579                crate::response::StatusCode::TEMPORARY_REDIRECT
5580            };
5581
5582            // Create redirect response
5583            let response = Response::with_status(status)
5584                .header("Location", redirect_url.into_bytes())
5585                .header("Content-Type", b"text/plain".to_vec())
5586                .body(crate::response::ResponseBody::Bytes(
5587                    b"Redirecting to HTTPS...".to_vec(),
5588                ));
5589
5590            ControlFlow::Break(response)
5591        })
5592    }
5593
5594    fn after<'a>(
5595        &'a self,
5596        _ctx: &'a RequestContext,
5597        req: &'a Request,
5598        response: Response,
5599    ) -> BoxFuture<'a, Response> {
5600        Box::pin(async move {
5601            // Only add HSTS to secure responses
5602            if !self.is_secure(req) {
5603                return response;
5604            }
5605
5606            // Add HSTS header if configured
5607            if let Some(hsts_value) = self.build_hsts_header() {
5608                response.header("Strict-Transport-Security", hsts_value)
5609            } else {
5610                response
5611            }
5612        })
5613    }
5614
5615    fn name(&self) -> &'static str {
5616        "HttpsRedirect"
5617    }
5618}
5619
5620// ===========================================================================
5621// End HTTPS Redirect Middleware
5622// ===========================================================================
5623
5624// ===========================================================================
5625// Response Interceptors and Transformers
5626// ===========================================================================
5627//
5628// This section provides a simplified abstraction for response-only processing.
5629// Unlike full Middleware, ResponseInterceptor only handles post-handler processing,
5630// making it lighter weight and easier to compose for response transformations.
5631
5632/// A response interceptor that processes responses after handler execution.
5633///
5634/// Unlike the full [`Middleware`] trait, `ResponseInterceptor` only handles
5635/// the post-handler phase, making it simpler to implement for response-only
5636/// processing like:
5637/// - Adding timing headers
5638/// - Transforming response bodies
5639/// - Adding debug information
5640/// - Logging response details
5641///
5642/// # Example
5643///
5644/// ```ignore
5645/// use fastapi_core::middleware::{ResponseInterceptor, ResponseInterceptorContext};
5646///
5647/// struct TimingInterceptor {
5648///     start_time: Instant,
5649/// }
5650///
5651/// impl ResponseInterceptor for TimingInterceptor {
5652///     fn intercept(&self, ctx: &ResponseInterceptorContext, response: Response) -> Response {
5653///         let elapsed = self.start_time.elapsed();
5654///         response.header("X-Response-Time", format!("{}ms", elapsed.as_millis()).into_bytes())
5655///     }
5656/// }
5657/// ```
5658pub trait ResponseInterceptor: Send + Sync {
5659    /// Process a response after the handler has executed.
5660    ///
5661    /// # Parameters
5662    ///
5663    /// - `ctx`: Context containing request information and timing data
5664    /// - `response`: The response from the handler or previous interceptors
5665    ///
5666    /// # Returns
5667    ///
5668    /// The modified response to pass to the next interceptor or return to client.
5669    fn intercept<'a>(
5670        &'a self,
5671        ctx: &'a ResponseInterceptorContext<'a>,
5672        response: Response,
5673    ) -> BoxFuture<'a, Response>;
5674
5675    /// Returns the interceptor name for debugging and logging.
5676    fn name(&self) -> &'static str {
5677        std::any::type_name::<Self>()
5678    }
5679}
5680
5681/// Context provided to response interceptors.
5682///
5683/// Contains information about the original request and timing data
5684/// that interceptors might need to process responses.
5685#[derive(Debug)]
5686pub struct ResponseInterceptorContext<'a> {
5687    /// The original request (read-only).
5688    pub request: &'a Request,
5689    /// When the request processing started.
5690    pub start_time: Instant,
5691    /// The request context for cancellation support.
5692    pub request_ctx: &'a RequestContext,
5693}
5694
5695impl<'a> ResponseInterceptorContext<'a> {
5696    /// Create a new interceptor context.
5697    pub fn new(request: &'a Request, request_ctx: &'a RequestContext, start_time: Instant) -> Self {
5698        Self {
5699            request,
5700            start_time,
5701            request_ctx,
5702        }
5703    }
5704
5705    /// Get the elapsed time since request processing started.
5706    pub fn elapsed(&self) -> std::time::Duration {
5707        self.start_time.elapsed()
5708    }
5709
5710    /// Get the elapsed time in milliseconds.
5711    pub fn elapsed_ms(&self) -> u128 {
5712        self.start_time.elapsed().as_millis()
5713    }
5714}
5715
5716/// A stack of response interceptors that run in order.
5717///
5718/// Interceptors are executed in registration order (first registered, first run).
5719/// Each interceptor receives the response from the previous one and can modify it.
5720///
5721/// # Example
5722///
5723/// ```ignore
5724/// let mut stack = ResponseInterceptorStack::new();
5725/// stack.push(TimingInterceptor);
5726/// stack.push(DebugHeadersInterceptor::new());
5727///
5728/// let response = stack.process(&ctx, response).await;
5729/// ```
5730#[derive(Default)]
5731pub struct ResponseInterceptorStack {
5732    interceptors: Vec<Arc<dyn ResponseInterceptor>>,
5733}
5734
5735impl ResponseInterceptorStack {
5736    /// Create an empty interceptor stack.
5737    #[must_use]
5738    pub fn new() -> Self {
5739        Self {
5740            interceptors: Vec::new(),
5741        }
5742    }
5743
5744    /// Create a stack with pre-allocated capacity.
5745    #[must_use]
5746    pub fn with_capacity(capacity: usize) -> Self {
5747        Self {
5748            interceptors: Vec::with_capacity(capacity),
5749        }
5750    }
5751
5752    /// Add an interceptor to the end of the stack.
5753    pub fn push<I: ResponseInterceptor + 'static>(&mut self, interceptor: I) {
5754        self.interceptors.push(Arc::new(interceptor));
5755    }
5756
5757    /// Add an Arc-wrapped interceptor.
5758    pub fn push_arc(&mut self, interceptor: Arc<dyn ResponseInterceptor>) {
5759        self.interceptors.push(interceptor);
5760    }
5761
5762    /// Return the number of interceptors in the stack.
5763    #[must_use]
5764    pub fn len(&self) -> usize {
5765        self.interceptors.len()
5766    }
5767
5768    /// Return true if the stack is empty.
5769    #[must_use]
5770    pub fn is_empty(&self) -> bool {
5771        self.interceptors.is_empty()
5772    }
5773
5774    /// Process a response through all interceptors.
5775    pub async fn process(
5776        &self,
5777        ctx: &ResponseInterceptorContext<'_>,
5778        mut response: Response,
5779    ) -> Response {
5780        for interceptor in &self.interceptors {
5781            let _ = ctx.request_ctx.checkpoint();
5782            response = interceptor.intercept(ctx, response).await;
5783        }
5784        response
5785    }
5786}
5787
5788// ---------------------------------------------------------------------------
5789// Timing Interceptor
5790// ---------------------------------------------------------------------------
5791
/// Interceptor that stamps responses with how long the request took.
///
/// Always sets a millisecond timing header with a value such as `12ms`. The
/// header is `X-Response-Time` unless renamed with
/// [`TimingInterceptor::header_name`].
///
/// When enabled via [`TimingInterceptor::with_server_timing`], it also emits a
/// `Server-Timing` header of the form `<metric>;dur=<ms>`, which browser
/// DevTools display.
///
/// # Example
///
/// ```ignore
/// let interceptor = TimingInterceptor::new();
/// // Additionally emit `Server-Timing: app;dur=<ms>`:
/// let interceptor = TimingInterceptor::new().with_server_timing("app");
/// ```
#[derive(Debug, Clone)]
pub struct TimingInterceptor {
    /// Header name for the response time (default: X-Response-Time).
    header_name: String,
    /// Whether to also emit a Server-Timing header.
    include_server_timing: bool,
    /// Metric name used in Server-Timing (default: "total").
    server_timing_name: String,
}
5813
5814impl Default for TimingInterceptor {
5815    fn default() -> Self {
5816        Self::new()
5817    }
5818}
5819
5820impl TimingInterceptor {
5821    /// Create a new timing interceptor with default settings.
5822    #[must_use]
5823    pub fn new() -> Self {
5824        Self {
5825            header_name: "X-Response-Time".to_string(),
5826            include_server_timing: false,
5827            server_timing_name: "total".to_string(),
5828        }
5829    }
5830
5831    /// Enable Server-Timing header with the given metric name.
5832    #[must_use]
5833    pub fn with_server_timing(mut self, metric_name: impl Into<String>) -> Self {
5834        self.include_server_timing = true;
5835        self.server_timing_name = metric_name.into();
5836        self
5837    }
5838
5839    /// Set a custom header name instead of X-Response-Time.
5840    #[must_use]
5841    pub fn header_name(mut self, name: impl Into<String>) -> Self {
5842        self.header_name = name.into();
5843        self
5844    }
5845}
5846
5847impl ResponseInterceptor for TimingInterceptor {
5848    fn intercept<'a>(
5849        &'a self,
5850        ctx: &'a ResponseInterceptorContext<'a>,
5851        response: Response,
5852    ) -> BoxFuture<'a, Response> {
5853        Box::pin(async move {
5854            let elapsed_ms = ctx.elapsed_ms();
5855            let timing_value = format!("{}ms", elapsed_ms);
5856
5857            let response = response.header(&self.header_name, timing_value.clone().into_bytes());
5858
5859            if self.include_server_timing {
5860                // Server-Timing format: name;dur=value;desc="description"
5861                let server_timing = format!("{};dur={}", self.server_timing_name, elapsed_ms);
5862                response.header("Server-Timing", server_timing.into_bytes())
5863            } else {
5864                response
5865            }
5866        })
5867    }
5868
5869    fn name(&self) -> &'static str {
5870        "TimingInterceptor"
5871    }
5872}
5873
5874// ---------------------------------------------------------------------------
5875// Debug Headers Interceptor
5876// ---------------------------------------------------------------------------
5877
/// Interceptor that adds debug information headers.
///
/// Useful for development/staging environments to expose internal
/// processing information in response headers.
///
/// # Headers Added
///
/// Each header is individually switchable. Names below use the default
/// `X-Debug-` prefix; see [`DebugInfoInterceptor::header_prefix`].
///
/// - `X-Debug-Path`: the request path
/// - `X-Debug-Method`: the HTTP method
/// - `X-Debug-Request-Id`: the request's `RequestId` extension (omitted when
///   the request has none)
/// - `X-Debug-Handler-Time`: milliseconds elapsed since the interceptor
///   context's start time (e.g. `3ms`)
///
/// # Example
///
/// ```ignore
/// let interceptor = DebugInfoInterceptor::new()
///     .include_path(true)
///     .include_method(true);
/// ```
#[derive(Debug, Clone)]
// One independent on/off switch per header is the intended configuration shape.
#[allow(clippy::struct_excessive_bools)]
pub struct DebugInfoInterceptor {
    /// Include path in debug headers.
    include_path: bool,
    /// Include HTTP method in debug headers.
    include_method: bool,
    /// Include request ID in debug headers.
    include_request_id: bool,
    /// Include timing information.
    include_timing: bool,
    /// Header prefix (default: "X-Debug-").
    header_prefix: String,
}
5911
5912impl Default for DebugInfoInterceptor {
5913    fn default() -> Self {
5914        Self::new()
5915    }
5916}
5917
5918impl DebugInfoInterceptor {
5919    /// Create a new debug info interceptor with all options enabled.
5920    #[must_use]
5921    pub fn new() -> Self {
5922        Self {
5923            include_path: true,
5924            include_method: true,
5925            include_request_id: true,
5926            include_timing: true,
5927            header_prefix: "X-Debug-".to_string(),
5928        }
5929    }
5930
5931    /// Set whether to include the path.
5932    #[must_use]
5933    pub fn include_path(mut self, include: bool) -> Self {
5934        self.include_path = include;
5935        self
5936    }
5937
5938    /// Set whether to include the HTTP method.
5939    #[must_use]
5940    pub fn include_method(mut self, include: bool) -> Self {
5941        self.include_method = include;
5942        self
5943    }
5944
5945    /// Set whether to include the request ID.
5946    #[must_use]
5947    pub fn include_request_id(mut self, include: bool) -> Self {
5948        self.include_request_id = include;
5949        self
5950    }
5951
5952    /// Set whether to include timing information.
5953    #[must_use]
5954    pub fn include_timing(mut self, include: bool) -> Self {
5955        self.include_timing = include;
5956        self
5957    }
5958
5959    /// Set a custom header prefix.
5960    #[must_use]
5961    pub fn header_prefix(mut self, prefix: impl Into<String>) -> Self {
5962        self.header_prefix = prefix.into();
5963        self
5964    }
5965}
5966
5967impl ResponseInterceptor for DebugInfoInterceptor {
5968    fn intercept<'a>(
5969        &'a self,
5970        ctx: &'a ResponseInterceptorContext<'a>,
5971        response: Response,
5972    ) -> BoxFuture<'a, Response> {
5973        Box::pin(async move {
5974            let mut resp = response;
5975
5976            if self.include_path {
5977                let header_name = format!("{}Path", self.header_prefix);
5978                resp = resp.header(header_name, ctx.request.path().as_bytes().to_vec());
5979            }
5980
5981            if self.include_method {
5982                let header_name = format!("{}Method", self.header_prefix);
5983                resp = resp.header(
5984                    header_name,
5985                    ctx.request.method().as_str().as_bytes().to_vec(),
5986                );
5987            }
5988
5989            if self.include_request_id {
5990                if let Some(request_id) = ctx.request.get_extension::<RequestId>() {
5991                    let header_name = format!("{}Request-Id", self.header_prefix);
5992                    resp = resp.header(header_name, request_id.0.as_bytes().to_vec());
5993                }
5994            }
5995
5996            if self.include_timing {
5997                let header_name = format!("{}Handler-Time", self.header_prefix);
5998                let timing = format!("{}ms", ctx.elapsed_ms());
5999                resp = resp.header(header_name, timing.into_bytes());
6000            }
6001
6002            resp
6003        })
6004    }
6005
6006    fn name(&self) -> &'static str {
6007        "DebugInfoInterceptor"
6008    }
6009}
6010
6011// ---------------------------------------------------------------------------
6012// Response Body Transform
6013// ---------------------------------------------------------------------------
6014
/// Interceptor that rewrites buffered response bodies with a user function.
///
/// Typical uses:
/// - Minification
/// - Pretty-printing
/// - Wrapping responses in an envelope
/// - Filtering content
///
/// The function receives the body bytes (empty for an empty body) and
/// returns the replacement bytes. Streaming bodies are passed through
/// unchanged.
///
/// # Example
///
/// ```ignore
/// // Wrap JSON responses in an envelope
/// let transformer = ResponseBodyTransform::new(|body| {
///     format!(r#"{{"data": {}}}"#, String::from_utf8_lossy(&body)).into_bytes()
/// });
/// ```
pub struct ResponseBodyTransform<F>
where
    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
{
    /// Body rewrite function.
    transform_fn: F,
    /// When set, only responses whose content type starts with this value are rewritten.
    content_type_filter: Option<String>,
}
6039
6040impl<F> ResponseBodyTransform<F>
6041where
6042    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6043{
6044    /// Create a new body transformer with the given function.
6045    pub fn new(transform_fn: F) -> Self {
6046        Self {
6047            transform_fn,
6048            content_type_filter: None,
6049        }
6050    }
6051
6052    /// Only apply transformation if the response content type starts with this value.
6053    #[must_use]
6054    pub fn for_content_type(mut self, content_type: impl Into<String>) -> Self {
6055        self.content_type_filter = Some(content_type.into());
6056        self
6057    }
6058
6059    fn should_transform(&self, response: &Response) -> bool {
6060        match &self.content_type_filter {
6061            Some(filter) => response
6062                .headers()
6063                .iter()
6064                .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
6065                .and_then(|(_, ct)| std::str::from_utf8(ct).ok())
6066                .map(|ct| ct.starts_with(filter))
6067                .unwrap_or(false),
6068            None => true,
6069        }
6070    }
6071}
6072
6073impl<F> ResponseInterceptor for ResponseBodyTransform<F>
6074where
6075    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6076{
6077    fn intercept<'a>(
6078        &'a self,
6079        _ctx: &'a ResponseInterceptorContext<'a>,
6080        response: Response,
6081    ) -> BoxFuture<'a, Response> {
6082        Box::pin(async move {
6083            if !self.should_transform(&response) {
6084                return response;
6085            }
6086
6087            // Extract the body bytes
6088            let body_bytes = match response.body_ref() {
6089                crate::response::ResponseBody::Empty => Vec::new(),
6090                crate::response::ResponseBody::Bytes(b) => b.clone(),
6091                crate::response::ResponseBody::Stream(_) => {
6092                    // Cannot transform streaming responses
6093                    return response;
6094                }
6095            };
6096
6097            // Apply transformation
6098            let transformed = (self.transform_fn)(body_bytes);
6099
6100            // Rebuild response with new body
6101            response.body(crate::response::ResponseBody::Bytes(transformed))
6102        })
6103    }
6104
6105    fn name(&self) -> &'static str {
6106        "ResponseBodyTransform"
6107    }
6108}
6109
6110// ---------------------------------------------------------------------------
6111// Header Transform Interceptor
6112// ---------------------------------------------------------------------------
6113
/// An interceptor that adds and copies response headers.
///
/// Configured rules are applied to every response in this order:
/// 1. renames: the value of the first header matching the old name
///    (case-insensitively) is added under the new name; the original header
///    is **kept**;
/// 2. additions: each configured header is appended.
///
/// Removal rules can be recorded with
/// [`HeaderTransformInterceptor::remove`] but are **not applied**. The
/// interceptor currently has no way to delete a header from a response.
///
/// # Example
///
/// ```ignore
/// let interceptor = HeaderTransformInterceptor::new()
///     .add("X-Powered-By", "fastapi_rust")
///     // Adds X-Trace-Id alongside the existing X-Request-Id.
///     .rename("X-Request-Id", "X-Trace-Id");
/// ```
#[derive(Debug, Clone, Default)]
pub struct HeaderTransformInterceptor {
    /// Headers to add.
    add_headers: Vec<(String, Vec<u8>)>,
    /// Headers requested for removal (recorded but not applied).
    remove_headers: Vec<String>,
    /// Headers to copy under a new name (old_name -> new_name).
    rename_headers: Vec<(String, String)>,
}
6135
6136impl HeaderTransformInterceptor {
6137    /// Create a new header transform interceptor.
6138    #[must_use]
6139    pub fn new() -> Self {
6140        Self::default()
6141    }
6142
6143    /// Add a header to the response.
6144    #[must_use]
6145    pub fn add(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
6146        self.add_headers.push((name.into(), value.into()));
6147        self
6148    }
6149
6150    /// Remove a header from the response.
6151    #[must_use]
6152    pub fn remove(mut self, name: impl Into<String>) -> Self {
6153        self.remove_headers.push(name.into());
6154        self
6155    }
6156
6157    /// Rename a header (if it exists).
6158    #[must_use]
6159    pub fn rename(mut self, old_name: impl Into<String>, new_name: impl Into<String>) -> Self {
6160        self.rename_headers.push((old_name.into(), new_name.into()));
6161        self
6162    }
6163}
6164
6165impl ResponseInterceptor for HeaderTransformInterceptor {
6166    fn intercept<'a>(
6167        &'a self,
6168        _ctx: &'a ResponseInterceptorContext<'a>,
6169        response: Response,
6170    ) -> BoxFuture<'a, Response> {
6171        let add_headers = self.add_headers.clone();
6172        let remove_headers = self.remove_headers.clone();
6173        let rename_headers = self.rename_headers.clone();
6174
6175        Box::pin(async move {
6176            let mut resp = response;
6177
6178            // Handle renames first - get values of headers to rename
6179            for (old_name, new_name) in &rename_headers {
6180                let header_value = resp
6181                    .headers()
6182                    .iter()
6183                    .find(|(name, _)| name.eq_ignore_ascii_case(old_name))
6184                    .map(|(_, v)| v.clone());
6185
6186                if let Some(value) = header_value {
6187                    resp = resp.header(new_name, value);
6188                    // Note: We can't remove the old header without rebuild
6189                    // so we just add the new one
6190                }
6191            }
6192
6193            // Add new headers
6194            for (name, value) in add_headers {
6195                resp = resp.header(name, value);
6196            }
6197
6198            // Note: Header removal would require Response to support remove_header
6199            // For now, this is a no-op but documented as a limitation
6200            let _ = remove_headers;
6201
6202            resp
6203        })
6204    }
6205
6206    fn name(&self) -> &'static str {
6207        "HeaderTransformInterceptor"
6208    }
6209}
6210
6211// ---------------------------------------------------------------------------
6212// Conditional Interceptor Wrapper
6213// ---------------------------------------------------------------------------
6214
6215/// Wrapper that applies an interceptor only when a condition is met.
6216///
6217/// # Example
6218///
6219/// ```ignore
6220/// // Only add debug headers for non-production requests
6221/// let interceptor = ConditionalInterceptor::new(
6222///     DebugInfoInterceptor::new(),
6223///     |ctx, resp| ctx.request.headers().get("X-Debug").is_some()
6224/// );
6225/// ```
6226pub struct ConditionalInterceptor<I, F>
6227where
6228    I: ResponseInterceptor,
6229    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6230{
6231    inner: I,
6232    condition: F,
6233}
6234
6235impl<I, F> ConditionalInterceptor<I, F>
6236where
6237    I: ResponseInterceptor,
6238    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6239{
6240    /// Create a new conditional interceptor.
6241    pub fn new(inner: I, condition: F) -> Self {
6242        Self { inner, condition }
6243    }
6244}
6245
6246impl<I, F> ResponseInterceptor for ConditionalInterceptor<I, F>
6247where
6248    I: ResponseInterceptor,
6249    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6250{
6251    fn intercept<'a>(
6252        &'a self,
6253        ctx: &'a ResponseInterceptorContext<'a>,
6254        response: Response,
6255    ) -> BoxFuture<'a, Response> {
6256        Box::pin(async move {
6257            if (self.condition)(ctx, &response) {
6258                self.inner.intercept(ctx, response).await
6259            } else {
6260                response
6261            }
6262        })
6263    }
6264
6265    fn name(&self) -> &'static str {
6266        "ConditionalInterceptor"
6267    }
6268}
6269
6270// ---------------------------------------------------------------------------
6271// Error Response Transformer
6272// ---------------------------------------------------------------------------
6273
/// Interceptor that rewrites responses carrying selected status codes.
///
/// Typical uses:
/// - Hiding internal error details in production
/// - Adding consistent error formatting
/// - Tagging error responses with an id for log correlation
///
/// Responses whose status code was not registered pass through unchanged.
///
/// # Example
///
/// ```ignore
/// let interceptor = ErrorResponseTransformer::new()
///     .hide_details_for_status(StatusCode::INTERNAL_SERVER_ERROR)
///     .with_replacement_body(b"An internal error occurred".to_vec());
/// ```
#[derive(Debug, Clone)]
pub struct ErrorResponseTransformer {
    /// Numeric status codes whose responses are rewritten.
    status_codes: HashSet<u16>,
    /// Body substituted into matching responses, if configured.
    replacement_body: Option<Vec<u8>>,
    /// Whether matching responses get an `X-Error-Id` header.
    add_error_id: bool,
}
6297
6298impl Default for ErrorResponseTransformer {
6299    fn default() -> Self {
6300        Self::new()
6301    }
6302}
6303
6304impl ErrorResponseTransformer {
6305    /// Create a new error response transformer.
6306    #[must_use]
6307    pub fn new() -> Self {
6308        Self {
6309            status_codes: HashSet::new(),
6310            replacement_body: None,
6311            add_error_id: false,
6312        }
6313    }
6314
6315    /// Hide details for the given status code.
6316    #[must_use]
6317    pub fn hide_details_for_status(mut self, status: crate::response::StatusCode) -> Self {
6318        self.status_codes.insert(status.as_u16());
6319        self
6320    }
6321
6322    /// Set the replacement body for error responses.
6323    #[must_use]
6324    pub fn with_replacement_body(mut self, body: impl Into<Vec<u8>>) -> Self {
6325        self.replacement_body = Some(body.into());
6326        self
6327    }
6328
6329    /// Enable adding an error ID header for tracking.
6330    #[must_use]
6331    pub fn add_error_id(mut self, enable: bool) -> Self {
6332        self.add_error_id = enable;
6333        self
6334    }
6335}
6336
6337impl ResponseInterceptor for ErrorResponseTransformer {
6338    fn intercept<'a>(
6339        &'a self,
6340        ctx: &'a ResponseInterceptorContext<'a>,
6341        response: Response,
6342    ) -> BoxFuture<'a, Response> {
6343        Box::pin(async move {
6344            let status_code = response.status().as_u16();
6345
6346            if !self.status_codes.contains(&status_code) {
6347                return response;
6348            }
6349
6350            let mut resp = response;
6351
6352            // Replace body if configured
6353            if let Some(ref replacement) = self.replacement_body {
6354                resp = resp.body(crate::response::ResponseBody::Bytes(replacement.clone()));
6355            }
6356
6357            // Add error ID header if enabled
6358            if self.add_error_id {
6359                // Use request ID if available, otherwise generate a simple one
6360                let error_id = ctx
6361                    .request
6362                    .get_extension::<RequestId>()
6363                    .map(|r| r.0.clone())
6364                    .unwrap_or_else(|| format!("err-{}", ctx.elapsed_ms()));
6365                resp = resp.header("X-Error-Id", error_id.into_bytes());
6366            }
6367
6368            resp
6369        })
6370    }
6371
6372    fn name(&self) -> &'static str {
6373        "ErrorResponseTransformer"
6374    }
6375}
6376
6377// ---------------------------------------------------------------------------
6378// Middleware adapter for ResponseInterceptor
6379// ---------------------------------------------------------------------------
6380
6381/// Adapter that wraps a `ResponseInterceptor` as a `Middleware`.
6382///
6383/// This allows using response interceptors in the existing middleware stack.
6384///
6385/// # Example
6386///
6387/// ```ignore
6388/// let timing = TimingInterceptor::new();
6389/// let middleware = ResponseInterceptorMiddleware::new(timing);
6390/// stack.push(middleware);
6391/// ```
6392pub struct ResponseInterceptorMiddleware<I>
6393where
6394    I: ResponseInterceptor,
6395{
6396    interceptor: I,
6397}
6398
6399impl<I> ResponseInterceptorMiddleware<I>
6400where
6401    I: ResponseInterceptor,
6402{
6403    /// Wrap a response interceptor as middleware.
6404    pub fn new(interceptor: I) -> Self {
6405        Self { interceptor }
6406    }
6407}
6408
6409impl<I> Middleware for ResponseInterceptorMiddleware<I>
6410where
6411    I: ResponseInterceptor,
6412{
6413    fn before<'a>(
6414        &'a self,
6415        _ctx: &'a RequestContext,
6416        req: &'a mut Request,
6417    ) -> BoxFuture<'a, ControlFlow> {
6418        // Store the start time in request extensions
6419        req.insert_extension(InterceptorStartTime(Instant::now()));
6420        Box::pin(async { ControlFlow::Continue })
6421    }
6422
6423    fn after<'a>(
6424        &'a self,
6425        ctx: &'a RequestContext,
6426        req: &'a Request,
6427        response: Response,
6428    ) -> BoxFuture<'a, Response> {
6429        Box::pin(async move {
6430            // Retrieve start time from extensions
6431            let start_time = req
6432                .get_extension::<InterceptorStartTime>()
6433                .map(|t| t.0)
6434                .unwrap_or_else(Instant::now);
6435
6436            let interceptor_ctx = ResponseInterceptorContext::new(req, ctx, start_time);
6437            self.interceptor.intercept(&interceptor_ctx, response).await
6438        })
6439    }
6440
6441    fn name(&self) -> &'static str {
6442        self.interceptor.name()
6443    }
6444}
6445
/// Request-extension marker carrying the instant at which a
/// `ResponseInterceptorMiddleware` saw the request in its `before` hook.
#[derive(Clone, Copy, Debug)]
struct InterceptorStartTime(Instant);
6449
6450// ===========================================================================
6451// End Response Interceptors and Transformers
6452// ===========================================================================
6453
6454// ===========================================================================
6455// Response Timing Metrics Collection
6456// ===========================================================================
6457//
6458// This section provides comprehensive timing metrics for monitoring:
6459// - Request duration
6460// - Time-to-first-byte (TTFB)
6461// - Server-Timing header with multiple metrics
6462// - Histogram collection for aggregation
6463// - Integration with logging
6464
/// A single entry in the Server-Timing header.
///
/// Each entry has a name, duration in milliseconds, and optional description.
///
/// # Server-Timing Format
///
/// ```text
/// Server-Timing: name;dur=value;desc="description"
/// ```
///
/// The description is emitted as an HTTP quoted-string: embedded `"` and `\`
/// are backslash-escaped, and control characters other than horizontal tab
/// are dropped. The name is written verbatim and should be a valid HTTP token
/// (no whitespace, `;`, `,`, `=` or quotes).
///
/// # Example
///
/// ```ignore
/// let entry = ServerTimingEntry::new("db", 42.5)
///     .with_description("Database query");
/// ```
#[derive(Debug, Clone)]
pub struct ServerTimingEntry {
    /// The metric name (e.g., "db", "cache", "render").
    name: String,
    /// Duration in milliseconds (supports sub-millisecond precision).
    duration_ms: f64,
    /// Optional description for the metric.
    description: Option<String>,
}

impl ServerTimingEntry {
    /// Create a new Server-Timing entry named `name` lasting `duration_ms`
    /// milliseconds, with no description.
    #[must_use]
    pub fn new(name: impl Into<String>, duration_ms: f64) -> Self {
        Self {
            name: name.into(),
            duration_ms,
            description: None,
        }
    }

    /// Attach (or replace) the human-readable description of the entry.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Format this entry for the Server-Timing header.
    ///
    /// Returns `name;dur=<ms>` or `name;dur=<ms>;desc="<description>"`, with
    /// the duration rendered to three decimal places. The description is
    /// escaped as a quoted-string, so a `"` inside it cannot terminate the
    /// value early and a CR/LF cannot split the header.
    #[must_use]
    pub fn to_header_value(&self) -> String {
        let mut value = format!("{};dur={:.3}", self.name, self.duration_ms);
        if let Some(desc) = &self.description {
            value.push_str(";desc=\"");
            value.push_str(&Self::escape_quoted_string(desc));
            value.push('"');
        }
        value
    }

    /// Escape `raw` for use inside an HTTP quoted-string (RFC 9110 §5.6.4).
    fn escape_quoted_string(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                // These two may only appear inside a quoted-string as a quoted-pair.
                '"' | '\\' => {
                    escaped.push('\\');
                    escaped.push(ch);
                }
                // HTAB is the only control character qdtext permits.
                '\t' => escaped.push(ch),
                // CR, LF, NUL, DEL, ... are not allowed in a header value;
                // dropping them also prevents header splitting.
                c if c.is_control() => {}
                c => escaped.push(c),
            }
        }
        escaped
    }
}
6521
6522/// Builder for constructing Server-Timing headers with multiple metrics.
6523///
6524/// Collects multiple timing entries and formats them as a single header value.
6525///
6526/// # Example
6527///
6528/// ```ignore
6529/// let timing = ServerTimingBuilder::new()
6530///     .add("total", 150.5)
6531///     .add_with_desc("db", 42.0, "Database queries")
6532///     .add_with_desc("cache", 5.0, "Cache lookup")
6533///     .build();
6534///
6535/// // Result: "total;dur=150.500, db;dur=42.000;desc=\"Database queries\", cache;dur=5.000;desc=\"Cache lookup\""
6536/// ```
6537#[derive(Debug, Clone, Default)]
6538pub struct ServerTimingBuilder {
6539    entries: Vec<ServerTimingEntry>,
6540}
6541
6542impl ServerTimingBuilder {
6543    /// Create a new empty builder.
6544    #[must_use]
6545    pub fn new() -> Self {
6546        Self::default()
6547    }
6548
6549    /// Add a timing entry with just a name and duration.
6550    #[must_use]
6551    pub fn add(mut self, name: impl Into<String>, duration_ms: f64) -> Self {
6552        self.entries.push(ServerTimingEntry::new(name, duration_ms));
6553        self
6554    }
6555
6556    /// Add a timing entry with a description.
6557    #[must_use]
6558    pub fn add_with_desc(
6559        mut self,
6560        name: impl Into<String>,
6561        duration_ms: f64,
6562        description: impl Into<String>,
6563    ) -> Self {
6564        self.entries
6565            .push(ServerTimingEntry::new(name, duration_ms).with_description(description));
6566        self
6567    }
6568
6569    /// Add a pre-built entry.
6570    #[must_use]
6571    pub fn add_entry(mut self, entry: ServerTimingEntry) -> Self {
6572        self.entries.push(entry);
6573        self
6574    }
6575
6576    /// Build the Server-Timing header value.
6577    #[must_use]
6578    pub fn build(&self) -> String {
6579        self.entries
6580            .iter()
6581            .map(ServerTimingEntry::to_header_value)
6582            .collect::<Vec<_>>()
6583            .join(", ")
6584    }
6585
6586    /// Return true if no entries have been added.
6587    #[must_use]
6588    pub fn is_empty(&self) -> bool {
6589        self.entries.is_empty()
6590    }
6591
6592    /// Return the number of entries.
6593    #[must_use]
6594    pub fn len(&self) -> usize {
6595        self.entries.len()
6596    }
6597}
6598
/// Collected timing metrics for a single request.
///
/// This struct is stored in request extensions and can be read by
/// interceptors or logging middleware to expose timing data.
///
/// # Usage
///
/// Handlers can record extra measurements on the instance held in the request
/// extensions. NOTE(review): the snippet below assumes `Request` offers a
/// mutable accessor named `get_extension_mut` — confirm before relying on it.
///
/// ```ignore
/// // Add a custom timing metric
/// if let Some(metrics) = req.get_extension_mut::<TimingMetrics>() {
///     metrics.add_metric("db", db_time.as_secs_f64() * 1000.0);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct TimingMetrics {
    /// When the request processing started.
    pub start_time: Instant,
    /// When the first byte of the response was sent, if that was recorded.
    pub first_byte_time: Option<Instant>,
    /// Custom metrics added by handlers, in insertion order, as
    /// `(name, duration_ms, optional description)` tuples.
    pub custom_metrics: Vec<(String, f64, Option<String>)>,
}
6623
6624impl TimingMetrics {
6625    /// Create new timing metrics starting now.
6626    #[must_use]
6627    pub fn new() -> Self {
6628        Self {
6629            start_time: Instant::now(),
6630            first_byte_time: None,
6631            custom_metrics: Vec::new(),
6632        }
6633    }
6634
6635    /// Create timing metrics with a specific start time.
6636    #[must_use]
6637    pub fn with_start_time(start_time: Instant) -> Self {
6638        Self {
6639            start_time,
6640            first_byte_time: None,
6641            custom_metrics: Vec::new(),
6642        }
6643    }
6644
6645    /// Mark the time when the first byte of the response was sent.
6646    pub fn mark_first_byte(&mut self) {
6647        self.first_byte_time = Some(Instant::now());
6648    }
6649
6650    /// Add a custom metric (e.g., database query time).
6651    pub fn add_metric(&mut self, name: impl Into<String>, duration_ms: f64) {
6652        self.custom_metrics.push((name.into(), duration_ms, None));
6653    }
6654
6655    /// Add a custom metric with a description.
6656    pub fn add_metric_with_desc(
6657        &mut self,
6658        name: impl Into<String>,
6659        duration_ms: f64,
6660        desc: impl Into<String>,
6661    ) {
6662        self.custom_metrics
6663            .push((name.into(), duration_ms, Some(desc.into())));
6664    }
6665
6666    /// Get the total elapsed time in milliseconds.
6667    #[must_use]
6668    pub fn total_ms(&self) -> f64 {
6669        self.start_time.elapsed().as_secs_f64() * 1000.0
6670    }
6671
6672    /// Get the time-to-first-byte in milliseconds (if available).
6673    #[must_use]
6674    pub fn ttfb_ms(&self) -> Option<f64> {
6675        self.first_byte_time
6676            .map(|t| t.duration_since(self.start_time).as_secs_f64() * 1000.0)
6677    }
6678
6679    /// Build a Server-Timing header from the collected metrics.
6680    #[must_use]
6681    pub fn to_server_timing(&self) -> ServerTimingBuilder {
6682        let mut builder = ServerTimingBuilder::new().add_with_desc(
6683            "total",
6684            self.total_ms(),
6685            "Total request time",
6686        );
6687
6688        if let Some(ttfb) = self.ttfb_ms() {
6689            builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
6690        }
6691
6692        for (name, duration, desc) in &self.custom_metrics {
6693            match desc {
6694                Some(d) => builder = builder.add_with_desc(name, *duration, d),
6695                None => builder = builder.add(name, *duration),
6696            }
6697        }
6698
6699        builder
6700    }
6701}
6702
6703impl Default for TimingMetrics {
6704    fn default() -> Self {
6705        Self::new()
6706    }
6707}
6708
/// Configuration for the timing metrics middleware.
///
/// Selects which timing headers are written and what the `Server-Timing`
/// header contains.
#[derive(Clone, Debug)]
// Every flag is an independent on/off switch, so one bool per option is the
// simplest faithful representation.
#[allow(clippy::struct_excessive_bools)]
pub struct TimingMetricsConfig {
    /// Emit a `Server-Timing` header.
    pub add_server_timing_header: bool,
    /// Emit a response-time header (named by `response_time_header_name`).
    pub add_response_time_header: bool,
    /// Name of the response-time header; `"X-Response-Time"` by default.
    pub response_time_header_name: String,
    /// Include handler-supplied custom metrics in `Server-Timing`.
    pub include_custom_metrics: bool,
    /// Include time-to-first-byte in `Server-Timing` when it was recorded.
    pub include_ttfb: bool,
}

impl Default for TimingMetricsConfig {
    /// Everything enabled; the response-time header is `X-Response-Time`.
    fn default() -> Self {
        TimingMetricsConfig {
            add_server_timing_header: true,
            add_response_time_header: true,
            response_time_header_name: String::from("X-Response-Time"),
            include_custom_metrics: true,
            include_ttfb: true,
        }
    }
}

impl TimingMetricsConfig {
    /// Same as [`TimingMetricsConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Turn the `Server-Timing` header on or off.
    #[must_use]
    pub fn server_timing(self, enabled: bool) -> Self {
        Self {
            add_server_timing_header: enabled,
            ..self
        }
    }

    /// Turn the response-time header on or off.
    #[must_use]
    pub fn response_time(self, enabled: bool) -> Self {
        Self {
            add_response_time_header: enabled,
            ..self
        }
    }

    /// Rename the response-time header.
    #[must_use]
    pub fn response_time_header(self, name: impl Into<String>) -> Self {
        Self {
            response_time_header_name: name.into(),
            ..self
        }
    }

    /// Include or exclude handler-supplied custom metrics.
    #[must_use]
    pub fn custom_metrics(self, enabled: bool) -> Self {
        Self {
            include_custom_metrics: enabled,
            ..self
        }
    }

    /// Include or exclude time-to-first-byte.
    #[must_use]
    pub fn ttfb(self, enabled: bool) -> Self {
        Self {
            include_ttfb: enabled,
            ..self
        }
    }

    /// Minimal exposure: only the response-time header stays enabled.
    #[must_use]
    pub fn production() -> Self {
        Self {
            add_server_timing_header: false,
            include_custom_metrics: false,
            include_ttfb: false,
            ..Self::default()
        }
    }

    /// Full exposure: identical to the default configuration.
    #[must_use]
    pub fn development() -> Self {
        Self::default()
    }
}
6797
6798/// Middleware that collects and exposes timing metrics.
6799///
6800/// This middleware:
6801/// 1. Records the request start time
6802/// 2. Injects `TimingMetrics` into request extensions for handlers to use
6803/// 3. Adds timing headers to the response
6804///
6805/// # Example
6806///
6807/// ```ignore
6808/// let timing = TimingMetricsMiddleware::new();
6809/// // Or with custom config:
6810/// let timing = TimingMetricsMiddleware::with_config(
6811///     TimingMetricsConfig::production()
6812/// );
6813///
6814/// middleware_stack.push(timing);
6815/// ```
6816#[derive(Debug, Clone)]
6817pub struct TimingMetricsMiddleware {
6818    config: TimingMetricsConfig,
6819}
6820
6821impl TimingMetricsMiddleware {
6822    /// Create a new timing metrics middleware with default config.
6823    #[must_use]
6824    pub fn new() -> Self {
6825        Self {
6826            config: TimingMetricsConfig::default(),
6827        }
6828    }
6829
6830    /// Create with a custom configuration.
6831    #[must_use]
6832    pub fn with_config(config: TimingMetricsConfig) -> Self {
6833        Self { config }
6834    }
6835
6836    /// Create a production-safe instance (minimal headers).
6837    #[must_use]
6838    pub fn production() -> Self {
6839        Self {
6840            config: TimingMetricsConfig::production(),
6841        }
6842    }
6843
6844    /// Create a development instance (all timing info exposed).
6845    #[must_use]
6846    pub fn development() -> Self {
6847        Self {
6848            config: TimingMetricsConfig::development(),
6849        }
6850    }
6851}
6852
6853impl Default for TimingMetricsMiddleware {
6854    fn default() -> Self {
6855        Self::new()
6856    }
6857}
6858
6859impl Middleware for TimingMetricsMiddleware {
6860    fn before<'a>(
6861        &'a self,
6862        _ctx: &'a RequestContext,
6863        req: &'a mut Request,
6864    ) -> BoxFuture<'a, ControlFlow> {
6865        // Store timing metrics in request extensions
6866        req.insert_extension(TimingMetrics::new());
6867        Box::pin(async { ControlFlow::Continue })
6868    }
6869
6870    fn after<'a>(
6871        &'a self,
6872        _ctx: &'a RequestContext,
6873        req: &'a Request,
6874        response: Response,
6875    ) -> BoxFuture<'a, Response> {
6876        let config = self.config.clone();
6877
6878        Box::pin(async move {
6879            let mut resp = response;
6880
6881            // Get timing metrics from extensions
6882            let metrics = req.get_extension::<TimingMetrics>();
6883
6884            match metrics {
6885                Some(metrics) => {
6886                    // Add X-Response-Time header
6887                    if config.add_response_time_header {
6888                        let timing = format!("{:.3}ms", metrics.total_ms());
6889                        resp = resp.header(&config.response_time_header_name, timing.into_bytes());
6890                    }
6891
6892                    // Add Server-Timing header
6893                    if config.add_server_timing_header {
6894                        let mut builder = ServerTimingBuilder::new().add_with_desc(
6895                            "total",
6896                            metrics.total_ms(),
6897                            "Total request time",
6898                        );
6899
6900                        // Add TTFB if available and enabled
6901                        if config.include_ttfb {
6902                            if let Some(ttfb) = metrics.ttfb_ms() {
6903                                builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
6904                            }
6905                        }
6906
6907                        // Add custom metrics if enabled
6908                        if config.include_custom_metrics {
6909                            for (name, duration, desc) in &metrics.custom_metrics {
6910                                match desc {
6911                                    Some(d) => builder = builder.add_with_desc(name, *duration, d),
6912                                    None => builder = builder.add(name, *duration),
6913                                }
6914                            }
6915                        }
6916
6917                        let header_value = builder.build();
6918                        resp = resp.header("Server-Timing", header_value.into_bytes());
6919                    }
6920                }
6921                None => {
6922                    // No timing metrics in extensions - add basic timing
6923                    // This shouldn't happen if middleware is properly registered
6924                    if config.add_response_time_header {
6925                        resp = resp.header(&config.response_time_header_name, b"0.000ms".to_vec());
6926                    }
6927                }
6928            }
6929
6930            resp
6931        })
6932    }
6933
6934    fn name(&self) -> &'static str {
6935        "TimingMetrics"
6936    }
6937}
6938
/// One cumulative bucket of a [`TimingHistogram`], as returned by
/// [`TimingHistogram::buckets`].
#[derive(Clone, Debug)]
pub struct TimingHistogramBucket {
    /// Inclusive upper bound of the bucket, in milliseconds.
    pub le: f64,
    /// Number of observations less than or equal to `le`.
    pub count: u64,
}

/// A histogram for collecting timing distributions.
///
/// Buckets are cumulative in the Prometheus style: each one counts every
/// observation `<=` its upper bound. Observations above the largest bound are
/// reflected only in [`count`](Self::count) and [`sum`](Self::sum).
///
/// # Example
///
/// ```ignore
/// let mut histogram = TimingHistogram::with_buckets(vec![
///     1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0
/// ]);
///
/// histogram.observe(42.5);  // 42.5ms response time
/// histogram.observe(150.0);
///
/// let buckets = histogram.buckets();
/// let avg = histogram.mean();
/// ```
#[derive(Clone, Debug)]
pub struct TimingHistogram {
    /// Bucket upper bounds in milliseconds.
    bucket_bounds: Vec<f64>,
    /// Cumulative count per bucket, parallel to `bucket_bounds`.
    bucket_counts: Vec<u64>,
    /// Sum of every observed value.
    sum: f64,
    /// Number of observations recorded.
    count: u64,
}

impl TimingHistogram {
    /// Create an empty histogram with the given bucket upper bounds.
    ///
    /// Bounds are expected in ascending order; they are stored as given.
    #[must_use]
    pub fn with_buckets(bucket_bounds: Vec<f64>) -> Self {
        let slots = bucket_bounds.len();
        TimingHistogram {
            bucket_bounds,
            bucket_counts: vec![0_u64; slots],
            sum: 0.0,
            count: 0,
        }
    }

    /// Create a histogram with typical HTTP latency buckets:
    /// 1ms, 5ms, 10ms, 25ms, 50ms, 100ms, 250ms, 500ms, 1s, 2.5s, 5s, 10s.
    #[must_use]
    pub fn http_latency() -> Self {
        const HTTP_LATENCY_BOUNDS_MS: [f64; 12] = [
            1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0, 2500.0, 5000.0, 10000.0,
        ];
        Self::with_buckets(HTTP_LATENCY_BOUNDS_MS.to_vec())
    }

    /// Record one observation, in milliseconds.
    pub fn observe(&mut self, value_ms: f64) {
        self.sum += value_ms;
        self.count += 1;

        // Cumulative buckets: every bucket whose bound is >= the value counts it.
        for (bound, slot) in self.bucket_bounds.iter().zip(self.bucket_counts.iter_mut()) {
            if value_ms <= *bound {
                *slot += 1;
            }
        }
    }

    /// Number of observations recorded so far.
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Sum of all observations recorded so far.
    #[must_use]
    pub fn sum(&self) -> f64 {
        self.sum
    }

    /// Arithmetic mean of the observations, or `0.0` when there are none.
    #[must_use]
    // u64 -> f64 only loses precision beyond 2^53 observations.
    #[allow(clippy::cast_precision_loss)]
    pub fn mean(&self) -> f64 {
        match self.count {
            0 => 0.0,
            observations => self.sum / observations as f64,
        }
    }

    /// Snapshot of every bucket's bound and cumulative count.
    #[must_use]
    pub fn buckets(&self) -> Vec<TimingHistogramBucket> {
        let mut snapshot = Vec::with_capacity(self.bucket_bounds.len());
        for (&le, &count) in self.bucket_bounds.iter().zip(self.bucket_counts.iter()) {
            snapshot.push(TimingHistogramBucket { le, count });
        }
        snapshot
    }

    /// Discard all observations while keeping the bucket bounds.
    pub fn reset(&mut self) {
        self.bucket_counts.iter_mut().for_each(|slot| *slot = 0);
        self.count = 0;
        self.sum = 0.0;
    }
}

impl Default for TimingHistogram {
    /// Same as [`TimingHistogram::http_latency`].
    fn default() -> Self {
        TimingHistogram::http_latency()
    }
}
7068
7069// ===========================================================================
7070// End Response Timing Metrics Collection
7071// ===========================================================================
7072
// Unit tests for the timing-metrics section: Server-Timing formatting,
// `TimingMetrics`, `TimingMetricsConfig`, `TimingMetricsMiddleware` hooks and
// `TimingHistogram`.
#[cfg(test)]
mod timing_metrics_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Build a `RequestContext` backed by `asupersync::Cx::for_testing()`.
    ///
    /// NOTE(review): the second argument (`1`) is presumably a request id —
    /// confirm against `RequestContext::new`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// The `GET /test` request used as the default fixture.
    fn test_request() -> Request {
        Request::new(Method::Get, "/test")
    }

    /// Block on `mw.before` with a fresh test context and return its verdict.
    fn run_middleware_before(mw: &impl Middleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        futures_executor::block_on(mw.before(&ctx, req))
    }

    /// Block on `mw.after` with a fresh test context and return the response.
    fn run_middleware_after(mw: &impl Middleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        futures_executor::block_on(mw.after(&ctx, req, resp))
    }

    #[test]
    fn server_timing_entry_basic() {
        let entry = ServerTimingEntry::new("db", 42.5);
        assert_eq!(entry.to_header_value(), "db;dur=42.500");
    }

    #[test]
    fn server_timing_entry_with_description() {
        let entry = ServerTimingEntry::new("db", 42.5).with_description("Database query");
        assert_eq!(
            entry.to_header_value(),
            "db;dur=42.500;desc=\"Database query\""
        );
    }

    #[test]
    fn server_timing_builder_single_entry() {
        let timing = ServerTimingBuilder::new().add("total", 150.0).build();
        assert_eq!(timing, "total;dur=150.000");
    }

    #[test]
    fn server_timing_builder_multiple_entries() {
        let timing = ServerTimingBuilder::new()
            .add("total", 150.0)
            .add_with_desc("db", 42.0, "Database")
            .add("cache", 5.0)
            .build();

        // Checks membership only; entry order is not asserted here.
        assert!(timing.contains("total;dur=150.000"));
        assert!(timing.contains("db;dur=42.000;desc=\"Database\""));
        assert!(timing.contains("cache;dur=5.000"));
        assert!(timing.contains(", ")); // Multiple entries separated by comma
    }

    #[test]
    fn server_timing_builder_empty() {
        let builder = ServerTimingBuilder::new();
        assert!(builder.is_empty());
        assert_eq!(builder.len(), 0);
        assert_eq!(builder.build(), "");
    }

    // Uses a real 5ms sleep and asserts only a lower bound, so scheduler
    // jitter can make it slower but not flaky.
    #[test]
    fn timing_metrics_basic() {
        let metrics = TimingMetrics::new();
        std::thread::sleep(std::time::Duration::from_millis(5));

        let total = metrics.total_ms();
        assert!(total >= 5.0, "Total should be at least 5ms");
        assert!(metrics.ttfb_ms().is_none(), "TTFB should not be set");
    }

    #[test]
    fn timing_metrics_custom_metrics() {
        let mut metrics = TimingMetrics::new();
        metrics.add_metric("db", 42.5);
        metrics.add_metric_with_desc("cache", 5.0, "Cache lookup");

        let timing = metrics.to_server_timing();
        assert_eq!(timing.len(), 3); // total + 2 custom

        let header = timing.build();
        assert!(header.contains("total"));
        assert!(header.contains("db;dur=42.500"));
        assert!(header.contains("cache;dur=5.000;desc=\"Cache lookup\""));
    }

    // Real sleep before marking the first byte; lower-bound assertion only.
    #[test]
    fn timing_metrics_ttfb() {
        let mut metrics = TimingMetrics::new();
        std::thread::sleep(std::time::Duration::from_millis(5));
        metrics.mark_first_byte();

        let ttfb = metrics.ttfb_ms().unwrap();
        assert!(ttfb >= 5.0, "TTFB should be at least 5ms");
    }

    #[test]
    fn timing_metrics_config_default() {
        let config = TimingMetricsConfig::default();
        assert!(config.add_server_timing_header);
        assert!(config.add_response_time_header);
        assert!(config.include_custom_metrics);
        assert!(config.include_ttfb);
    }

    #[test]
    fn timing_metrics_config_production() {
        let config = TimingMetricsConfig::production();
        assert!(!config.add_server_timing_header);
        assert!(config.add_response_time_header);
        assert!(!config.include_custom_metrics);
    }

    #[test]
    fn timing_middleware_adds_metrics_to_request() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        // Before should insert TimingMetrics
        let result = run_middleware_before(&mw, &mut req);
        assert!(result.is_continue());

        let metrics = req.get_extension::<TimingMetrics>();
        assert!(metrics.is_some(), "TimingMetrics should be in extensions");
    }

    #[test]
    fn timing_middleware_adds_response_time_header() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        // Run before to insert TimingMetrics
        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        let has_timing = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Response-Time");
        assert!(has_timing, "Should have X-Response-Time header");
    }

    #[test]
    fn timing_middleware_adds_server_timing_header() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        // Decode the header bytes leniently so the assertion can inspect text.
        let server_timing = result
            .headers()
            .iter()
            .find(|(name, _)| name == "Server-Timing")
            .map(|(_, v)| String::from_utf8_lossy(v).to_string());

        assert!(server_timing.is_some(), "Should have Server-Timing header");
        let header = server_timing.unwrap();
        assert!(header.contains("total"), "Should have total timing");
    }

    #[test]
    fn timing_middleware_production_mode() {
        let mw = TimingMetricsMiddleware::production();
        let mut req = test_request();

        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        // Should have X-Response-Time
        let has_response_time = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Response-Time");
        assert!(has_response_time);

        // Should NOT have Server-Timing
        let has_server_timing = result
            .headers()
            .iter()
            .any(|(name, _)| name == "Server-Timing");
        assert!(!has_server_timing);
    }

    // 42 + 150 + 5 = 197 over 3 observations; the mean is ~65.667.
    #[test]
    #[allow(clippy::float_cmp)]
    fn timing_histogram_basic() {
        let mut histogram = TimingHistogram::http_latency();
        assert_eq!(histogram.count(), 0);
        assert_eq!(histogram.sum(), 0.0);

        histogram.observe(42.0);
        histogram.observe(150.0);
        histogram.observe(5.0);

        assert_eq!(histogram.count(), 3);
        assert_eq!(histogram.sum(), 197.0);
        assert!((histogram.mean() - 65.666).abs() < 0.01);
    }

    #[test]
    fn timing_histogram_buckets() {
        let mut histogram = TimingHistogram::with_buckets(vec![10.0, 50.0, 100.0]);

        histogram.observe(5.0); // Falls in 10 bucket
        histogram.observe(25.0); // Falls in 50 bucket
        histogram.observe(75.0); // Falls in 100 bucket
        histogram.observe(150.0); // Above all buckets

        let buckets = histogram.buckets();
        assert_eq!(buckets.len(), 3);

        // Buckets are cumulative
        assert_eq!(buckets[0].count, 1); // <= 10: 1
        assert_eq!(buckets[1].count, 2); // <= 50: 2
        assert_eq!(buckets[2].count, 3); // <= 100: 3
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn timing_histogram_reset() {
        let mut histogram = TimingHistogram::http_latency();
        histogram.observe(100.0);
        histogram.observe(200.0);

        assert_eq!(histogram.count(), 2);

        histogram.reset();

        assert_eq!(histogram.count(), 0);
        assert_eq!(histogram.sum(), 0.0);
    }
}
7318
#[cfg(test)]
mod response_interceptor_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Builds a request context backed by asupersync's testing `Cx`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// A `GET /test` request shared by the tests in this module.
    fn test_request() -> Request {
        Request::new(Method::Get, "/test")
    }

    /// Runs a single interceptor on `resp` to completion, using a fresh
    /// context whose start time is taken just before the call.
    fn run_interceptor<I: ResponseInterceptor>(
        interceptor: &I,
        req: &Request,
        resp: Response,
    ) -> Response {
        let ctx = test_context();
        let start_time = Instant::now();
        let interceptor_ctx = ResponseInterceptorContext::new(req, &ctx, start_time);
        futures_executor::block_on(interceptor.intercept(&interceptor_ctx, resp))
    }

    /// Returns `true` if `resp` carries a header named `name`.
    ///
    /// HTTP field names are case-insensitive, so the comparison ignores ASCII
    /// case; an exact-case check would make these tests depend on how an
    /// interceptor happens to spell the name.
    fn has_header(resp: &Response, name: &str) -> bool {
        resp.headers()
            .iter()
            .any(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
    }

    /// Returns the bytes of a `ResponseBody::Bytes` body, panicking on any
    /// other body variant.
    fn body_bytes(resp: &Response) -> &[u8] {
        match resp.body_ref() {
            crate::response::ResponseBody::Bytes(b) => b.as_slice(),
            _ => panic!("Expected bytes body"),
        }
    }

    #[test]
    fn timing_interceptor_adds_header() {
        let interceptor = TimingInterceptor::new();
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "X-Response-Time"),
            "Should have X-Response-Time header"
        );
    }

    #[test]
    fn timing_interceptor_with_server_timing() {
        let interceptor = TimingInterceptor::new().with_server_timing("app");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "Server-Timing"),
            "Should have Server-Timing header"
        );
    }

    #[test]
    fn timing_interceptor_custom_header_name() {
        let interceptor = TimingInterceptor::new().header_name("X-Custom-Time");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "X-Custom-Time"),
            "Should have X-Custom-Time header"
        );
    }

    #[test]
    fn debug_info_interceptor_adds_headers() {
        let interceptor = DebugInfoInterceptor::new();
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "X-Debug-Path"),
            "Should have X-Debug-Path header"
        );
        assert!(
            has_header(&result, "X-Debug-Method"),
            "Should have X-Debug-Method header"
        );
        assert!(
            has_header(&result, "X-Debug-Handler-Time"),
            "Should have X-Debug-Handler-Time header"
        );
    }

    #[test]
    fn debug_info_interceptor_custom_prefix() {
        let interceptor = DebugInfoInterceptor::new().header_prefix("X-Trace-");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "X-Trace-Path"),
            "Should have X-Trace-Path header"
        );
    }

    #[test]
    fn debug_info_interceptor_selective_options() {
        let interceptor = DebugInfoInterceptor::new()
            .include_path(true)
            .include_method(false)
            .include_timing(false)
            .include_request_id(false);
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "X-Debug-Path"),
            "Should have X-Debug-Path header"
        );
        assert!(
            !has_header(&result, "X-Debug-Method"),
            "Should NOT have X-Debug-Method header"
        );
        // Timing was switched off as well, so its header must be absent too.
        assert!(
            !has_header(&result, "X-Debug-Handler-Time"),
            "Should NOT have X-Debug-Handler-Time header"
        );
    }

    #[test]
    fn header_transform_adds_headers() {
        let interceptor = HeaderTransformInterceptor::new()
            .add("X-Powered-By", b"fastapi_rust".to_vec())
            .add("X-Version", b"1.0".to_vec());
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "X-Powered-By"),
            "Should have X-Powered-By header"
        );
        assert!(
            has_header(&result, "X-Version"),
            "Should have X-Version header"
        );
    }

    #[test]
    fn response_body_transform_modifies_body() {
        let transformer = ResponseBodyTransform::new(|body| {
            let mut result = b"[".to_vec();
            result.extend_from_slice(&body);
            result.extend_from_slice(b"]");
            result
        });
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"hello".to_vec()));

        let result = run_interceptor(&transformer, &req, resp);

        assert_eq!(body_bytes(&result), b"[hello]");
    }

    #[test]
    fn response_body_transform_with_content_type_filter() {
        let transformer =
            ResponseBodyTransform::new(|_| b"transformed".to_vec()).for_content_type("text/plain");
        let req = test_request();

        // JSON response should NOT be transformed
        let json_resp = Response::with_status(StatusCode::OK)
            .header("content-type", b"application/json".to_vec())
            .body(crate::response::ResponseBody::Bytes(b"original".to_vec()));

        let result = run_interceptor(&transformer, &req, json_resp);
        assert_eq!(
            body_bytes(&result),
            b"original",
            "JSON should not be transformed"
        );

        // Plain text response SHOULD be transformed
        let text_resp = Response::with_status(StatusCode::OK)
            .header("content-type", b"text/plain".to_vec())
            .body(crate::response::ResponseBody::Bytes(b"original".to_vec()));

        let result = run_interceptor(&transformer, &req, text_resp);
        assert_eq!(
            body_bytes(&result),
            b"transformed",
            "Text should be transformed"
        );
    }

    #[test]
    fn error_response_transformer_hides_details() {
        let transformer = ErrorResponseTransformer::new()
            .hide_details_for_status(StatusCode::INTERNAL_SERVER_ERROR)
            .with_replacement_body(b"An error occurred");

        let req = test_request();

        // 500 response should be transformed
        let error_resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR).body(
            crate::response::ResponseBody::Bytes(b"Sensitive error details".to_vec()),
        );

        let result = run_interceptor(&transformer, &req, error_resp);
        assert_eq!(body_bytes(&result), b"An error occurred");

        // 200 response should NOT be transformed
        let ok_resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"Success".to_vec()));

        let result = run_interceptor(&transformer, &req, ok_resp);
        assert_eq!(body_bytes(&result), b"Success");
    }

    #[test]
    fn response_interceptor_stack_chains_interceptors() {
        let mut stack = ResponseInterceptorStack::new();
        stack.push(TimingInterceptor::new());
        stack.push(HeaderTransformInterceptor::new().add("X-Extra", b"value".to_vec()));

        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let ctx = test_context();
        let start_time = Instant::now();
        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);
        let result = futures_executor::block_on(stack.process(&interceptor_ctx, resp));

        assert!(
            has_header(&result, "X-Response-Time"),
            "Should have timing header from first interceptor"
        );
        assert!(
            has_header(&result, "X-Extra"),
            "Should have extra header from second interceptor"
        );
    }

    #[test]
    fn response_interceptor_stack_empty_is_noop() {
        let stack = ResponseInterceptorStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let req = test_request();
        let resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"unchanged".to_vec()));

        let ctx = test_context();
        let start_time = Instant::now();
        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);
        let result = futures_executor::block_on(stack.process(&interceptor_ctx, resp));

        assert_eq!(body_bytes(&result), b"unchanged");
    }

    #[test]
    fn interceptor_context_provides_timing() {
        let ctx = test_context();
        let req = test_request();
        let start_time = Instant::now();
        std::thread::sleep(std::time::Duration::from_millis(5));

        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);

        assert!(
            interceptor_ctx.elapsed_ms() >= 5,
            "Elapsed time should be at least 5ms"
        );
        assert!(interceptor_ctx.elapsed().as_millis() >= 5);
    }

    #[test]
    fn conditional_interceptor_applies_conditionally() {
        // Only add header if response is 200 OK
        let inner = HeaderTransformInterceptor::new().add("X-Success", b"true".to_vec());
        let conditional =
            ConditionalInterceptor::new(inner, |_ctx, resp| resp.status().as_u16() == 200);

        let req = test_request();

        // 200 response should get the header
        let ok_resp = Response::with_status(StatusCode::OK);
        let result = run_interceptor(&conditional, &req, ok_resp);
        assert!(
            has_header(&result, "X-Success"),
            "200 response should get X-Success header"
        );

        // 404 response should NOT get the header
        let not_found = Response::with_status(StatusCode::NOT_FOUND);
        let result = run_interceptor(&conditional, &req, not_found);
        assert!(
            !has_header(&result, "X-Success"),
            "404 response should NOT get X-Success header"
        );
    }
}
7659
#[cfg(test)]
mod cache_control_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Builds a request context backed by asupersync's testing `Cx`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Drives the middleware's `after` hook to completion with a fresh context.
    fn run_after(mw: &CacheControlMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let fut = mw.after(&ctx, req, resp);
        futures_executor::block_on(fut)
    }

    /// Returns the value of the first header named `name` (ASCII
    /// case-insensitive), decoded lossily as UTF-8, or `None` if absent.
    fn header_value(resp: &Response, name: &str) -> Option<String> {
        resp.headers()
            .iter()
            .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
            .map(|(_, value)| String::from_utf8_lossy(value).into_owned())
    }

    #[test]
    fn cache_directive_as_str_works() {
        assert_eq!(CacheDirective::Public.as_str(), "public");
        assert_eq!(CacheDirective::Private.as_str(), "private");
        assert_eq!(CacheDirective::NoStore.as_str(), "no-store");
        assert_eq!(CacheDirective::NoCache.as_str(), "no-cache");
        assert_eq!(CacheDirective::MustRevalidate.as_str(), "must-revalidate");
        assert_eq!(CacheDirective::Immutable.as_str(), "immutable");
    }

    #[test]
    fn cache_control_builder_basic() {
        let cc = CacheControlBuilder::new()
            .public()
            .max_age_secs(3600)
            .build();
        assert!(cc.contains("public"));
        assert!(cc.contains("max-age=3600"));
    }

    #[test]
    fn cache_control_builder_complex() {
        let cc = CacheControlBuilder::new()
            .public()
            .max_age_secs(60)
            .s_maxage_secs(3600)
            .stale_while_revalidate_secs(86400)
            .build();
        assert!(cc.contains("public"));
        assert!(cc.contains("max-age=60"));
        assert!(cc.contains("s-maxage=3600"));
        assert!(cc.contains("stale-while-revalidate=86400"));
    }

    #[test]
    fn cache_control_builder_no_cache() {
        let cc = CacheControlBuilder::new()
            .no_store()
            .no_cache()
            .must_revalidate()
            .build();
        assert!(cc.contains("no-store"));
        assert!(cc.contains("no-cache"));
        assert!(cc.contains("must-revalidate"));
    }

    #[test]
    fn cache_preset_no_cache() {
        let value = CachePreset::NoCache.to_header_value();
        assert!(value.contains("no-store"));
        assert!(value.contains("no-cache"));
        assert!(value.contains("must-revalidate"));
    }

    #[test]
    fn cache_preset_immutable() {
        let value = CachePreset::Immutable.to_header_value();
        assert!(value.contains("public"));
        assert!(value.contains("max-age=31536000"));
        assert!(value.contains("immutable"));
    }

    #[test]
    fn cache_preset_static_assets() {
        let value = CachePreset::StaticAssets.to_header_value();
        assert!(value.contains("public"));
        assert!(value.contains("max-age=86400"));
    }

    #[test]
    fn middleware_adds_cache_control_header() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let value = header_value(&result, "cache-control")
            .expect("Cache-Control header should be present");
        assert!(value.contains("public"));
        assert!(value.contains("max-age=3600"));
    }

    #[test]
    fn middleware_skips_post_requests() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Post, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        assert!(
            header_value(&result, "cache-control").is_none(),
            "Cache-Control should not be added for POST"
        );
    }

    #[test]
    fn middleware_skips_error_responses() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);

        let result = run_after(&mw, &req, resp);
        assert!(
            header_value(&result, "cache-control").is_none(),
            "Cache-Control should not be added for error responses"
        );
    }

    #[test]
    fn middleware_with_vary_header() {
        let mw = CacheControlMiddleware::with_config(
            CacheControlConfig::from_preset(CachePreset::PublicOneHour)
                .vary("Accept-Encoding")
                .vary("Accept-Language"),
        );
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let value = header_value(&result, "vary").expect("Vary header should be present");
        assert!(value.contains("Accept-Encoding"));
        assert!(value.contains("Accept-Language"));
    }

    #[test]
    fn middleware_preserves_existing_cache_control() {
        let mw = CacheControlMiddleware::with_config(
            CacheControlConfig::from_preset(CachePreset::PublicOneHour).preserve_existing(true),
        );
        let req = Request::new(Method::Get, "/api/test");
        let resp =
            Response::with_status(StatusCode::OK).header("Cache-Control", b"max-age=60".to_vec());

        let result = run_after(&mw, &req, resp);
        let cc_values: Vec<String> = result
            .headers()
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
            .map(|(_, value)| String::from_utf8_lossy(value).into_owned())
            .collect();
        // Should only have the original header, not add a new one
        assert_eq!(cc_values, ["max-age=60"]);
    }

    #[test]
    fn path_pattern_matching_exact() {
        assert!(path_matches_pattern("/api/users", "/api/users"));
        assert!(!path_matches_pattern("/api/users", "/api/items"));
    }

    #[test]
    fn path_pattern_matching_wildcard() {
        assert!(path_matches_pattern("/api/users/123", "/api/users/*"));
        assert!(path_matches_pattern("/static/css/style.css", "/static/*"));
        assert!(path_matches_pattern("/anything", "*"));
    }

    #[test]
    fn date_formatting_works() {
        // Test that format_http_date doesn't panic and produces valid format
        let now = std::time::SystemTime::now();
        let formatted = format_http_date(now);
        // Should contain GMT
        assert!(formatted.ends_with(" GMT"));
        // Should have day name
        let days = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
        assert!(days.iter().any(|d| formatted.starts_with(d)));
    }

    /// Pins the exact output for fixed instants, so weekday, month and
    /// leap-year arithmetic are checked rather than only the overall shape.
    #[test]
    fn date_formatting_known_instants() {
        use std::time::{Duration, UNIX_EPOCH};

        // The Unix epoch fell on a Thursday.
        assert_eq!(format_http_date(UNIX_EPOCH), "Thu, 01 Jan 1970 00:00:00 GMT");
        // 2024-02-29T00:00:00Z (1_709_164_800 s) exercises the leap-day path.
        assert_eq!(
            format_http_date(UNIX_EPOCH + Duration::from_secs(1_709_164_800)),
            "Thu, 29 Feb 2024 00:00:00 GMT"
        );
    }

    #[test]
    fn leap_year_detection() {
        assert!(!is_leap_year(1900)); // Divisible by 100 but not 400
        assert!(is_leap_year(2000)); // Divisible by 400
        assert!(is_leap_year(2024)); // Divisible by 4 but not 100
        assert!(!is_leap_year(2023)); // Not divisible by 4
    }
}
7877
7878// ===========================================================================
7879// TRACE Rejection Middleware Tests
7880// ===========================================================================
7881
#[cfg(test)]
mod trace_rejection_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Request context backed by asupersync's testing `Cx`.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 1)
    }

    /// Drives the middleware's `before` hook to completion.
    fn run_before(mw: &TraceRejectionMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        futures_executor::block_on(mw.before(&ctx, req))
    }

    /// First header value whose name matches `name`, ignoring ASCII case.
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers.iter().find_map(|(header_name, value)| {
            header_name
                .eq_ignore_ascii_case(name)
                .then_some(value.as_slice())
        })
    }

    /// Sends `method path` through a default middleware and returns the
    /// short-circuit response; panics if the request was let through.
    fn expect_rejection(method: Method, path: &'static str) -> Response {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(method, path);
        let ControlFlow::Break(response) = run_before(&mw, &mut req) else {
            panic!("TRACE request should have been rejected");
        };
        response
    }

    /// Sends `method path` through a default middleware and panics with
    /// "`label` request should be allowed" if it was short-circuited.
    fn expect_allowed(method: Method, path: &'static str, label: &str) {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(method, path);
        let outcome = run_before(&mw, &mut req);
        assert!(
            matches!(outcome, ControlFlow::Continue),
            "{label} request should be allowed"
        );
    }

    #[test]
    fn trace_request_rejected() {
        let response = expect_rejection(Method::Trace, "/");
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn trace_request_with_path() {
        let response = expect_rejection(Method::Trace, "/api/users/123");
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn get_request_allowed() {
        expect_allowed(Method::Get, "/", "GET");
    }

    #[test]
    fn post_request_allowed() {
        expect_allowed(Method::Post, "/api/users", "POST");
    }

    #[test]
    fn put_request_allowed() {
        expect_allowed(Method::Put, "/api/users/1", "PUT");
    }

    #[test]
    fn delete_request_allowed() {
        expect_allowed(Method::Delete, "/api/users/1", "DELETE");
    }

    #[test]
    fn patch_request_allowed() {
        expect_allowed(Method::Patch, "/api/users/1", "PATCH");
    }

    #[test]
    fn options_request_allowed() {
        expect_allowed(Method::Options, "/api/users", "OPTIONS");
    }

    #[test]
    fn head_request_allowed() {
        expect_allowed(Method::Head, "/", "HEAD");
    }

    #[test]
    fn response_includes_allow_header() {
        let response = expect_rejection(Method::Trace, "/");
        assert!(
            find_header(response.headers(), "Allow").is_some(),
            "Response should include Allow header"
        );
    }

    #[test]
    fn response_has_json_content_type() {
        let response = expect_rejection(Method::Trace, "/");
        assert_eq!(
            find_header(response.headers(), "Content-Type"),
            Some(b"application/json".as_slice())
        );
    }

    #[test]
    fn default_enables_logging() {
        assert!(TraceRejectionMiddleware::new().log_attempts);
    }

    #[test]
    fn log_attempts_can_be_disabled() {
        let disabled = TraceRejectionMiddleware::new().log_attempts(false);
        assert!(!disabled.log_attempts);
    }

    #[test]
    fn middleware_name() {
        assert_eq!(TraceRejectionMiddleware::new().name(), "TraceRejection");
    }

    #[test]
    fn default_impl() {
        let mw: TraceRejectionMiddleware = Default::default();
        assert!(mw.log_attempts);
    }
}
8085
8086// ===========================================================================
8087// End TRACE Rejection Middleware Tests
8088// ===========================================================================
8089
8090// ===========================================================================
8091// HTTPS Redirect Middleware Tests
8092// ===========================================================================
8093
8094#[cfg(test)]
8095mod https_redirect_tests {
8096    use super::*;
8097    use crate::request::Method;
8098    use crate::response::StatusCode;
8099
8100    fn test_context() -> RequestContext {
8101        RequestContext::new(asupersync::Cx::for_testing(), 1)
8102    }
8103
8104    fn run_before(mw: &HttpsRedirectMiddleware, req: &mut Request) -> ControlFlow {
8105        let ctx = test_context();
8106        let fut = mw.before(&ctx, req);
8107        futures_executor::block_on(fut)
8108    }
8109
8110    fn run_after(mw: &HttpsRedirectMiddleware, req: &Request, resp: Response) -> Response {
8111        let ctx = test_context();
8112        let fut = mw.after(&ctx, req, resp);
8113        futures_executor::block_on(fut)
8114    }
8115
8116    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
8117        headers
8118            .iter()
8119            .find(|(n, _)| n.eq_ignore_ascii_case(name))
8120            .map(|(_, v)| v.as_slice())
8121    }
8122
8123    #[test]
8124    fn http_request_redirected() {
8125        let mw = HttpsRedirectMiddleware::new();
8126        let mut req = Request::new(Method::Get, "/");
8127        req.headers_mut().insert("Host", b"example.com".to_vec());
8128
8129        let result = run_before(&mw, &mut req);
8130
8131        match result {
8132            ControlFlow::Break(response) => {
8133                assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY);
8134                let location = find_header(response.headers(), "Location");
8135                assert_eq!(location, Some(b"https://example.com/".as_slice()));
8136            }
8137            ControlFlow::Continue => panic!("HTTP request should be redirected"),
8138        }
8139    }
8140
8141    #[test]
8142    fn http_request_with_path_and_query() {
8143        let mw = HttpsRedirectMiddleware::new();
8144        let mut req = Request::new(Method::Get, "/api/users?page=1");
8145        req.headers_mut().insert("Host", b"example.com".to_vec());
8146
8147        let result = run_before(&mw, &mut req);
8148
8149        match result {
8150            ControlFlow::Break(response) => {
8151                let location = find_header(response.headers(), "Location");
8152                assert_eq!(
8153                    location,
8154                    Some(b"https://example.com/api/users?page=1".as_slice())
8155                );
8156            }
8157            ControlFlow::Continue => panic!("HTTP request should be redirected"),
8158        }
8159    }
8160
8161    #[test]
8162    fn https_request_not_redirected() {
8163        let mw = HttpsRedirectMiddleware::new();
8164        let mut req = Request::new(Method::Get, "/");
8165        req.headers_mut().insert("Host", b"example.com".to_vec());
8166        req.headers_mut()
8167            .insert("X-Forwarded-Proto", b"https".to_vec());
8168
8169        let result = run_before(&mw, &mut req);
8170
8171        match result {
8172            ControlFlow::Continue => {} // Expected
8173            ControlFlow::Break(_) => panic!("HTTPS request should not be redirected"),
8174        }
8175    }
8176
8177    #[test]
8178    fn x_forwarded_ssl_recognized() {
8179        let mw = HttpsRedirectMiddleware::new();
8180        let mut req = Request::new(Method::Get, "/");
8181        req.headers_mut().insert("Host", b"example.com".to_vec());
8182        req.headers_mut().insert("X-Forwarded-Ssl", b"on".to_vec());
8183
8184        let result = run_before(&mw, &mut req);
8185
8186        match result {
8187            ControlFlow::Continue => {} // Expected
8188            ControlFlow::Break(_) => panic!("Request with X-Forwarded-Ssl=on should not redirect"),
8189        }
8190    }
8191
8192    #[test]
8193    fn excluded_path_not_redirected() {
8194        let mw = HttpsRedirectMiddleware::new().exclude_path("/health");
8195        let mut req = Request::new(Method::Get, "/health");
8196        req.headers_mut().insert("Host", b"example.com".to_vec());
8197
8198        let result = run_before(&mw, &mut req);
8199
8200        match result {
8201            ControlFlow::Continue => {} // Expected
8202            ControlFlow::Break(_) => panic!("Excluded path should not be redirected"),
8203        }
8204    }
8205
8206    #[test]
8207    fn excluded_path_prefix_matches() {
8208        let mw = HttpsRedirectMiddleware::new().exclude_path("/health");
8209        let mut req = Request::new(Method::Get, "/health/live");
8210        req.headers_mut().insert("Host", b"example.com".to_vec());
8211
8212        let result = run_before(&mw, &mut req);
8213
8214        match result {
8215            ControlFlow::Continue => {} // Expected
8216            ControlFlow::Break(_) => panic!("Path with excluded prefix should not be redirected"),
8217        }
8218    }
8219
8220    #[test]
8221    fn temporary_redirect_option() {
8222        let mw = HttpsRedirectMiddleware::new().permanent_redirect(false);
8223        let mut req = Request::new(Method::Get, "/");
8224        req.headers_mut().insert("Host", b"example.com".to_vec());
8225
8226        let result = run_before(&mw, &mut req);
8227
8228        match result {
8229            ControlFlow::Break(response) => {
8230                assert_eq!(response.status(), StatusCode::TEMPORARY_REDIRECT);
8231            }
8232            ControlFlow::Continue => panic!("HTTP request should be redirected"),
8233        }
8234    }
8235
8236    #[test]
8237    fn redirect_disabled() {
8238        let mw = HttpsRedirectMiddleware::new().redirect_enabled(false);
8239        let mut req = Request::new(Method::Get, "/");
8240        req.headers_mut().insert("Host", b"example.com".to_vec());
8241
8242        let result = run_before(&mw, &mut req);
8243
8244        match result {
8245            ControlFlow::Continue => {} // Expected
8246            ControlFlow::Break(_) => panic!("Redirects are disabled, should continue"),
8247        }
8248    }
8249
8250    #[test]
8251    fn hsts_header_on_https_response() {
8252        let mw = HttpsRedirectMiddleware::new();
8253        let mut req = Request::new(Method::Get, "/");
8254        req.headers_mut()
8255            .insert("X-Forwarded-Proto", b"https".to_vec());
8256
8257        let response = Response::with_status(StatusCode::OK);
8258        let result = run_after(&mw, &req, response);
8259
8260        let hsts = find_header(result.headers(), "Strict-Transport-Security");
8261        assert!(
8262            hsts.is_some(),
8263            "HSTS header should be present on HTTPS response"
8264        );
8265        let hsts_str = String::from_utf8_lossy(hsts.unwrap());
8266        assert!(hsts_str.contains("max-age=31536000"));
8267    }
8268
8269    #[test]
8270    fn hsts_header_not_on_http_response() {
8271        let mw = HttpsRedirectMiddleware::new().redirect_enabled(false);
8272        let req = Request::new(Method::Get, "/");
8273        // No X-Forwarded-Proto, so this is HTTP
8274
8275        let response = Response::with_status(StatusCode::OK);
8276        let result = run_after(&mw, &req, response);
8277
8278        let hsts = find_header(result.headers(), "Strict-Transport-Security");
8279        assert!(hsts.is_none(), "HSTS header should not be on HTTP response");
8280    }
8281
8282    #[test]
8283    fn hsts_with_include_subdomains() {
8284        let mw = HttpsRedirectMiddleware::new().include_subdomains(true);
8285        let mut req = Request::new(Method::Get, "/");
8286        req.headers_mut()
8287            .insert("X-Forwarded-Proto", b"https".to_vec());
8288
8289        let response = Response::with_status(StatusCode::OK);
8290        let result = run_after(&mw, &req, response);
8291
8292        let hsts = find_header(result.headers(), "Strict-Transport-Security");
8293        let hsts_str = String::from_utf8_lossy(hsts.unwrap());
8294        assert!(hsts_str.contains("includeSubDomains"));
8295    }
8296
8297    #[test]
8298    fn hsts_with_preload() {
8299        let mw = HttpsRedirectMiddleware::new().preload(true);
8300        let mut req = Request::new(Method::Get, "/");
8301        req.headers_mut()
8302            .insert("X-Forwarded-Proto", b"https".to_vec());
8303
8304        let response = Response::with_status(StatusCode::OK);
8305        let result = run_after(&mw, &req, response);
8306
8307        let hsts = find_header(result.headers(), "Strict-Transport-Security");
8308        let hsts_str = String::from_utf8_lossy(hsts.unwrap());
8309        assert!(hsts_str.contains("preload"));
8310    }
8311
8312    #[test]
8313    fn hsts_disabled_with_zero_max_age() {
8314        let mw = HttpsRedirectMiddleware::new().hsts_max_age_secs(0);
8315        let mut req = Request::new(Method::Get, "/");
8316        req.headers_mut()
8317            .insert("X-Forwarded-Proto", b"https".to_vec());
8318
8319        let response = Response::with_status(StatusCode::OK);
8320        let result = run_after(&mw, &req, response);
8321
8322        let hsts = find_header(result.headers(), "Strict-Transport-Security");
8323        assert!(hsts.is_none(), "HSTS should be disabled with max-age=0");
8324    }
8325
8326    #[test]
8327    fn custom_https_port() {
8328        let mw = HttpsRedirectMiddleware::new().https_port(8443);
8329        let mut req = Request::new(Method::Get, "/");
8330        req.headers_mut().insert("Host", b"example.com".to_vec());
8331
8332        let result = run_before(&mw, &mut req);
8333
8334        match result {
8335            ControlFlow::Break(response) => {
8336                let location = find_header(response.headers(), "Location");
8337                assert_eq!(location, Some(b"https://example.com:8443/".as_slice()));
8338            }
8339            ControlFlow::Continue => panic!("HTTP request should be redirected"),
8340        }
8341    }
8342
8343    #[test]
8344    fn host_with_port_stripped() {
8345        let mw = HttpsRedirectMiddleware::new();
8346        let mut req = Request::new(Method::Get, "/");
8347        req.headers_mut()
8348            .insert("Host", b"example.com:8080".to_vec());
8349
8350        let result = run_before(&mw, &mut req);
8351
8352        match result {
8353            ControlFlow::Break(response) => {
8354                let location = find_header(response.headers(), "Location");
8355                // Port should be stripped from host, using default 443
8356                assert_eq!(location, Some(b"https://example.com/".as_slice()));
8357            }
8358            ControlFlow::Continue => panic!("HTTP request should be redirected"),
8359        }
8360    }
8361
8362    #[test]
8363    fn middleware_name() {
8364        let mw = HttpsRedirectMiddleware::new();
8365        assert_eq!(mw.name(), "HttpsRedirect");
8366    }
8367
8368    #[test]
8369    fn default_impl() {
8370        let mw = HttpsRedirectMiddleware::default();
8371        assert!(mw.config.redirect_enabled);
8372        assert!(mw.config.permanent_redirect);
8373        assert_eq!(mw.config.hsts_max_age_secs, 31_536_000);
8374    }
8375
8376    #[test]
8377    fn config_builder() {
8378        let mw = HttpsRedirectMiddleware::new()
8379            .redirect_enabled(false)
8380            .permanent_redirect(false)
8381            .hsts_max_age_secs(86400)
8382            .include_subdomains(true)
8383            .preload(true)
8384            .https_port(8443);
8385
8386        assert!(!mw.config.redirect_enabled);
8387        assert!(!mw.config.permanent_redirect);
8388        assert_eq!(mw.config.hsts_max_age_secs, 86400);
8389        assert!(mw.config.hsts_include_subdomains);
8390        assert!(mw.config.hsts_preload);
8391        assert_eq!(mw.config.https_port, 8443);
8392    }
8393
8394    #[test]
8395    fn exclude_paths_method() {
8396        let mw = HttpsRedirectMiddleware::new()
8397            .exclude_paths(vec!["/health".to_string(), "/ready".to_string()]);
8398
8399        assert_eq!(mw.config.exclude_paths.len(), 2);
8400        assert!(mw.config.exclude_paths.contains(&"/health".to_string()));
8401        assert!(mw.config.exclude_paths.contains(&"/ready".to_string()));
8402    }
8403}
8404
8405// ===========================================================================
8406// End HTTPS Redirect Middleware Tests
8407// ===========================================================================
8408
8409// ===========================================================================
8410// End ETag Middleware
8411// ===========================================================================
8412
8413#[cfg(test)]
8414mod tests {
8415    use super::*;
8416    use crate::response::{ResponseBody, StatusCode};
8417
8418    // Test middleware that adds a header
8419    #[allow(dead_code)]
8420    struct AddHeaderMiddleware {
8421        name: &'static str,
8422        value: &'static [u8],
8423    }
8424
8425    impl Middleware for AddHeaderMiddleware {
8426        fn after<'a>(
8427            &'a self,
8428            _ctx: &'a RequestContext,
8429            _req: &'a Request,
8430            response: Response,
8431        ) -> BoxFuture<'a, Response> {
8432            Box::pin(async move { response.header(self.name, self.value.to_vec()) })
8433        }
8434    }
8435
8436    // Test middleware that short-circuits
8437    #[allow(dead_code)]
8438    struct BlockingMiddleware;
8439
8440    impl Middleware for BlockingMiddleware {
8441        fn before<'a>(
8442            &'a self,
8443            _ctx: &'a RequestContext,
8444            _req: &'a mut Request,
8445        ) -> BoxFuture<'a, ControlFlow> {
8446            Box::pin(async {
8447                ControlFlow::Break(
8448                    Response::with_status(StatusCode::FORBIDDEN)
8449                        .body(ResponseBody::Bytes(b"blocked".to_vec())),
8450                )
8451            })
8452        }
8453    }
8454
8455    // Test middleware that tracks calls
8456    #[allow(dead_code)]
8457    struct TrackingMiddleware {
8458        before_count: std::sync::atomic::AtomicUsize,
8459        after_count: std::sync::atomic::AtomicUsize,
8460    }
8461
8462    #[allow(dead_code)]
8463    impl TrackingMiddleware {
8464        fn new() -> Self {
8465            Self {
8466                before_count: std::sync::atomic::AtomicUsize::new(0),
8467                after_count: std::sync::atomic::AtomicUsize::new(0),
8468            }
8469        }
8470
8471        fn before_count(&self) -> usize {
8472            self.before_count.load(std::sync::atomic::Ordering::SeqCst)
8473        }
8474
8475        fn after_count(&self) -> usize {
8476            self.after_count.load(std::sync::atomic::Ordering::SeqCst)
8477        }
8478    }
8479
8480    impl Middleware for TrackingMiddleware {
8481        fn before<'a>(
8482            &'a self,
8483            _ctx: &'a RequestContext,
8484            _req: &'a mut Request,
8485        ) -> BoxFuture<'a, ControlFlow> {
8486            self.before_count
8487                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
8488            Box::pin(async { ControlFlow::Continue })
8489        }
8490
8491        fn after<'a>(
8492            &'a self,
8493            _ctx: &'a RequestContext,
8494            _req: &'a Request,
8495            response: Response,
8496        ) -> BoxFuture<'a, Response> {
8497            self.after_count
8498                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
8499            Box::pin(async move { response })
8500        }
8501    }
8502
8503    #[test]
8504    fn control_flow_variants() {
8505        let cont = ControlFlow::Continue;
8506        assert!(cont.is_continue());
8507        assert!(!cont.is_break());
8508
8509        let brk = ControlFlow::Break(Response::ok());
8510        assert!(!brk.is_continue());
8511        assert!(brk.is_break());
8512    }
8513
8514    #[test]
8515    fn middleware_stack_empty() {
8516        let stack = MiddlewareStack::new();
8517        assert!(stack.is_empty());
8518        assert_eq!(stack.len(), 0);
8519    }
8520
8521    #[test]
8522    fn middleware_stack_push() {
8523        let mut stack = MiddlewareStack::new();
8524        stack.push(NoopMiddleware);
8525        stack.push(NoopMiddleware);
8526        assert_eq!(stack.len(), 2);
8527        assert!(!stack.is_empty());
8528    }
8529
8530    #[test]
8531    fn noop_middleware_name() {
8532        let mw = NoopMiddleware;
8533        assert_eq!(mw.name(), "Noop");
8534    }
8535
8536    #[test]
8537    fn logging_redacts_sensitive_headers() {
8538        let mut headers = crate::request::Headers::new();
8539        headers.insert("Authorization", b"secret".to_vec());
8540        headers.insert("X-Request-Id", b"abc123".to_vec());
8541
8542        let redacted = super::default_redacted_headers();
8543        let formatted = super::format_headers(headers.iter(), &redacted);
8544
8545        assert!(formatted.contains("authorization=<redacted>"));
8546        assert!(formatted.contains("x-request-id=abc123"));
8547    }
8548
8549    #[test]
8550    fn logging_body_truncation() {
8551        let body = b"abcdef";
8552        let preview = super::format_bytes(body, 4);
8553        assert_eq!(preview, "abcd...");
8554
8555        let preview_full = super::format_bytes(body, 10);
8556        assert_eq!(preview_full, "abcdef");
8557    }
8558
8559    fn test_context() -> RequestContext {
8560        let cx = asupersync::Cx::for_testing();
8561        RequestContext::new(cx, 1)
8562    }
8563
8564    fn header_value(response: &Response, name: &str) -> Option<String> {
8565        response
8566            .headers()
8567            .iter()
8568            .find(|(n, _)| n.eq_ignore_ascii_case(name))
8569            .and_then(|(_, v)| std::str::from_utf8(v).ok())
8570            .map(ToString::to_string)
8571    }
8572
8573    #[test]
8574    fn cors_exact_origin_allows() {
8575        let cors = Cors::new().allow_origin("https://example.com");
8576        let ctx = test_context();
8577        let mut req = Request::new(crate::request::Method::Get, "/");
8578        req.headers_mut()
8579            .insert("origin", b"https://example.com".to_vec());
8580
8581        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8582        assert!(matches!(result, ControlFlow::Continue));
8583
8584        let response = Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()));
8585        let response = futures_executor::block_on(cors.after(&ctx, &req, response));
8586
8587        assert_eq!(
8588            header_value(&response, "access-control-allow-origin"),
8589            Some("https://example.com".to_string())
8590        );
8591        assert_eq!(header_value(&response, "vary"), Some("Origin".to_string()));
8592    }
8593
8594    #[test]
8595    fn cors_wildcard_origin_allows() {
8596        let cors = Cors::new().allow_origin_wildcard("https://*.example.com");
8597        let ctx = test_context();
8598        let mut req = Request::new(crate::request::Method::Get, "/");
8599        req.headers_mut()
8600            .insert("origin", b"https://api.example.com".to_vec());
8601
8602        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8603        assert!(matches!(result, ControlFlow::Continue));
8604    }
8605
8606    #[test]
8607    fn cors_regex_origin_allows() {
8608        let cors = Cors::new().allow_origin_regex(r"^https://.*\.example\.com$");
8609        let ctx = test_context();
8610        let mut req = Request::new(crate::request::Method::Get, "/");
8611        req.headers_mut()
8612            .insert("origin", b"https://svc.example.com".to_vec());
8613
8614        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8615        assert!(matches!(result, ControlFlow::Continue));
8616    }
8617
8618    #[test]
8619    fn cors_preflight_handled() {
8620        let cors = Cors::new()
8621            .allow_any_origin()
8622            .allow_headers(["x-test", "content-type"])
8623            .max_age(600);
8624        let ctx = test_context();
8625        let mut req = Request::new(crate::request::Method::Options, "/");
8626        req.headers_mut()
8627            .insert("origin", b"https://example.com".to_vec());
8628        req.headers_mut()
8629            .insert("access-control-request-method", b"POST".to_vec());
8630        req.headers_mut().insert(
8631            "access-control-request-headers",
8632            b"x-test, content-type".to_vec(),
8633        );
8634
8635        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8636        let ControlFlow::Break(response) = result else {
8637            panic!("expected preflight break");
8638        };
8639
8640        assert_eq!(response.status().as_u16(), 204);
8641        assert_eq!(
8642            header_value(&response, "access-control-allow-origin"),
8643            Some("*".to_string())
8644        );
8645        assert_eq!(
8646            header_value(&response, "access-control-allow-methods"),
8647            Some("GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD".to_string())
8648        );
8649        assert_eq!(
8650            header_value(&response, "access-control-allow-headers"),
8651            Some("x-test, content-type".to_string())
8652        );
8653        assert_eq!(
8654            header_value(&response, "access-control-max-age"),
8655            Some("600".to_string())
8656        );
8657    }
8658
8659    #[test]
8660    fn cors_credentials_echo_origin() {
8661        let cors = Cors::new().allow_any_origin().allow_credentials(true);
8662        let ctx = test_context();
8663        let mut req = Request::new(crate::request::Method::Get, "/");
8664        req.headers_mut()
8665            .insert("origin", b"https://example.com".to_vec());
8666
8667        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8668        assert!(matches!(result, ControlFlow::Continue));
8669
8670        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8671        assert_eq!(
8672            header_value(&response, "access-control-allow-origin"),
8673            Some("https://example.com".to_string())
8674        );
8675        assert_eq!(
8676            header_value(&response, "access-control-allow-credentials"),
8677            Some("true".to_string())
8678        );
8679    }
8680
8681    // CORS Spec Compliance Tests (bd-l1qe)
8682    // According to the Fetch Standard, when credentials mode is true,
8683    // the Access-Control-Allow-Origin header MUST NOT be "*".
8684
8685    #[test]
8686    fn cors_spec_compliance_credentials_never_wildcard_origin() {
8687        // When credentials are enabled, Access-Control-Allow-Origin
8688        // must echo the specific origin, never "*"
8689        let cors = Cors::new().allow_any_origin().allow_credentials(true);
8690        let ctx = test_context();
8691
8692        // Test with various origins
8693        for origin in &[
8694            "https://example.com",
8695            "https://api.example.com",
8696            "http://localhost:3000",
8697        ] {
8698            let mut req = Request::new(crate::request::Method::Get, "/");
8699            req.headers_mut()
8700                .insert("origin", origin.as_bytes().to_vec());
8701
8702            futures_executor::block_on(cors.before(&ctx, &mut req));
8703            let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8704
8705            let allow_origin = header_value(&response, "access-control-allow-origin");
8706            assert_eq!(
8707                allow_origin,
8708                Some((*origin).to_string()),
8709                "With credentials enabled, Access-Control-Allow-Origin must echo '{}', not '*'",
8710                origin
8711            );
8712            assert_ne!(
8713                allow_origin,
8714                Some("*".to_string()),
8715                "CORS spec violation: credentials + wildcard origin is forbidden"
8716            );
8717        }
8718    }
8719
8720    #[test]
8721    fn cors_spec_compliance_preflight_with_credentials() {
8722        // Preflight response with credentials should also echo origin, not "*"
8723        let cors = Cors::new()
8724            .allow_any_origin()
8725            .allow_credentials(true)
8726            .allow_headers(["content-type", "x-custom-header"]);
8727        let ctx = test_context();
8728
8729        let mut req = Request::new(crate::request::Method::Options, "/");
8730        req.headers_mut()
8731            .insert("origin", b"https://example.com".to_vec());
8732        req.headers_mut()
8733            .insert("access-control-request-method", b"POST".to_vec());
8734        req.headers_mut()
8735            .insert("access-control-request-headers", b"content-type".to_vec());
8736
8737        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8738        let ControlFlow::Break(response) = result else {
8739            panic!("expected preflight break");
8740        };
8741
8742        // Verify Access-Control-Allow-Origin is NOT "*" with credentials
8743        let allow_origin = header_value(&response, "access-control-allow-origin");
8744        assert_eq!(allow_origin, Some("https://example.com".to_string()));
8745        assert_ne!(
8746            allow_origin,
8747            Some("*".to_string()),
8748            "CORS spec violation: preflight with credentials must not use wildcard origin"
8749        );
8750
8751        // Verify credentials header is set
8752        assert_eq!(
8753            header_value(&response, "access-control-allow-credentials"),
8754            Some("true".to_string())
8755        );
8756    }
8757
8758    #[test]
8759    fn cors_spec_without_credentials_allows_wildcard() {
8760        // When credentials are NOT enabled, "*" is allowed for Access-Control-Allow-Origin
8761        let cors = Cors::new().allow_any_origin();
8762        let ctx = test_context();
8763        let mut req = Request::new(crate::request::Method::Get, "/");
8764        req.headers_mut()
8765            .insert("origin", b"https://example.com".to_vec());
8766
8767        futures_executor::block_on(cors.before(&ctx, &mut req));
8768        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8769
8770        // Without credentials, wildcard IS allowed
8771        assert_eq!(
8772            header_value(&response, "access-control-allow-origin"),
8773            Some("*".to_string())
8774        );
8775        // Should NOT have credentials header
8776        assert!(header_value(&response, "access-control-allow-credentials").is_none());
8777    }
8778
8779    #[test]
8780    fn cors_disallowed_preflight_forbidden() {
8781        let cors = Cors::new().allow_origin("https://good.example");
8782        let ctx = test_context();
8783        let mut req = Request::new(crate::request::Method::Options, "/");
8784        req.headers_mut()
8785            .insert("origin", b"https://evil.example".to_vec());
8786        req.headers_mut()
8787            .insert("access-control-request-method", b"GET".to_vec());
8788
8789        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8790        let ControlFlow::Break(response) = result else {
8791            panic!("expected forbidden preflight");
8792        };
8793        assert_eq!(response.status().as_u16(), 403);
8794    }
8795
8796    #[test]
8797    fn cors_simple_request_disallowed_origin_no_headers() {
8798        // Non-preflight request from disallowed origin should proceed but not get CORS headers
8799        let cors = Cors::new().allow_origin("https://good.example");
8800        let ctx = test_context();
8801        let mut req = Request::new(crate::request::Method::Get, "/");
8802        req.headers_mut()
8803            .insert("origin", b"https://evil.example".to_vec());
8804
8805        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8806        // Simple requests proceed (browser will block based on missing headers)
8807        assert!(matches!(result, ControlFlow::Continue));
8808
8809        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8810        // No CORS headers should be added for disallowed origin
8811        assert!(header_value(&response, "access-control-allow-origin").is_none());
8812    }
8813
8814    #[test]
8815    fn cors_expose_headers_configuration() {
8816        let cors = Cors::new()
8817            .allow_any_origin()
8818            .expose_headers(["x-custom-header", "x-another-header"]);
8819        let ctx = test_context();
8820        let mut req = Request::new(crate::request::Method::Get, "/");
8821        req.headers_mut()
8822            .insert("origin", b"https://example.com".to_vec());
8823
8824        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8825        assert!(matches!(result, ControlFlow::Continue));
8826
8827        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8828        assert_eq!(
8829            header_value(&response, "access-control-expose-headers"),
8830            Some("x-custom-header, x-another-header".to_string())
8831        );
8832    }
8833
8834    #[test]
8835    fn cors_any_origin_sets_wildcard() {
8836        let cors = Cors::new().allow_any_origin();
8837        let ctx = test_context();
8838        let mut req = Request::new(crate::request::Method::Get, "/");
8839        req.headers_mut()
8840            .insert("origin", b"https://any-site.com".to_vec());
8841
8842        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8843        assert!(matches!(result, ControlFlow::Continue));
8844
8845        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8846        assert_eq!(
8847            header_value(&response, "access-control-allow-origin"),
8848            Some("*".to_string())
8849        );
8850    }
8851
8852    #[test]
8853    fn cors_config_allows_method_override() {
8854        // Test that allow_methods overrides defaults
8855        let cors = Cors::new()
8856            .allow_any_origin()
8857            .allow_methods([crate::request::Method::Get, crate::request::Method::Post]);
8858        let ctx = test_context();
8859        let mut req = Request::new(crate::request::Method::Options, "/");
8860        req.headers_mut()
8861            .insert("origin", b"https://example.com".to_vec());
8862        req.headers_mut()
8863            .insert("access-control-request-method", b"POST".to_vec());
8864
8865        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8866        let ControlFlow::Break(response) = result else {
8867            panic!("expected preflight break");
8868        };
8869        assert_eq!(
8870            header_value(&response, "access-control-allow-methods"),
8871            Some("GET, POST".to_string())
8872        );
8873    }
8874
8875    #[test]
8876    fn cors_no_origin_header_skips_cors() {
8877        // Request without Origin header should not get CORS headers
8878        let cors = Cors::new().allow_any_origin();
8879        let ctx = test_context();
8880        let mut req = Request::new(crate::request::Method::Get, "/");
8881
8882        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8883        assert!(matches!(result, ControlFlow::Continue));
8884
8885        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8886        assert!(header_value(&response, "access-control-allow-origin").is_none());
8887    }
8888
8889    #[test]
8890    fn cors_middleware_name() {
8891        let cors = Cors::new();
8892        assert_eq!(cors.name(), "Cors");
8893    }
8894
8895    // =========================================================================
8896    // Request ID Middleware tests
8897    // =========================================================================
8898
8899    #[test]
8900    fn request_id_generates_unique_ids() {
8901        let id1 = RequestId::generate();
8902        let id2 = RequestId::generate();
8903        let id3 = RequestId::generate();
8904
8905        assert_ne!(id1, id2);
8906        assert_ne!(id2, id3);
8907        assert_ne!(id1, id3);
8908
8909        // IDs should be non-empty
8910        assert!(!id1.as_str().is_empty());
8911        assert!(!id2.as_str().is_empty());
8912        assert!(!id3.as_str().is_empty());
8913    }
8914
8915    #[test]
8916    fn request_id_display() {
8917        let id = RequestId::new("test-request-123");
8918        assert_eq!(format!("{}", id), "test-request-123");
8919    }
8920
8921    #[test]
8922    fn request_id_from_string() {
8923        let id: RequestId = "my-id".into();
8924        assert_eq!(id.as_str(), "my-id");
8925
8926        let id2: RequestId = String::from("my-id-2").into();
8927        assert_eq!(id2.as_str(), "my-id-2");
8928    }
8929
8930    #[test]
8931    fn request_id_config_defaults() {
8932        let config = RequestIdConfig::default();
8933        assert_eq!(config.header_name, "x-request-id");
8934        assert!(config.accept_from_client);
8935        assert!(config.add_to_response);
8936        assert_eq!(config.max_client_id_length, 128);
8937    }
8938
8939    #[test]
8940    fn request_id_config_builder() {
8941        let config = RequestIdConfig::new()
8942            .header_name("X-Trace-ID")
8943            .accept_from_client(false)
8944            .add_to_response(false)
8945            .max_client_id_length(64);
8946
8947        assert_eq!(config.header_name, "X-Trace-ID");
8948        assert!(!config.accept_from_client);
8949        assert!(!config.add_to_response);
8950        assert_eq!(config.max_client_id_length, 64);
8951    }
8952
8953    #[test]
8954    fn request_id_middleware_generates_id() {
8955        let middleware = RequestIdMiddleware::new();
8956        let ctx = test_context();
8957        let mut req = Request::new(crate::request::Method::Get, "/");
8958
8959        let result = futures_executor::block_on(middleware.before(&ctx, &mut req));
8960        assert!(matches!(result, ControlFlow::Continue));
8961
8962        let stored_id = req.get_extension::<RequestId>();
8963        assert!(stored_id.is_some());
8964        assert!(!stored_id.unwrap().as_str().is_empty());
8965    }
8966
8967    #[test]
8968    fn request_id_middleware_accepts_client_id() {
8969        let middleware = RequestIdMiddleware::new();
8970        let ctx = test_context();
8971        let mut req = Request::new(crate::request::Method::Get, "/");
8972        req.headers_mut()
8973            .insert("x-request-id", b"client-provided-id-123".to_vec());
8974
8975        futures_executor::block_on(middleware.before(&ctx, &mut req));
8976
8977        let stored_id = req.get_extension::<RequestId>().unwrap();
8978        assert_eq!(stored_id.as_str(), "client-provided-id-123");
8979    }
8980
8981    #[test]
8982    fn request_id_middleware_rejects_invalid_client_id() {
8983        let middleware = RequestIdMiddleware::new();
8984        let ctx = test_context();
8985
8986        // Test with invalid characters
8987        let mut req = Request::new(crate::request::Method::Get, "/");
8988        req.headers_mut()
8989            .insert("x-request-id", b"invalid<script>id".to_vec());
8990
8991        futures_executor::block_on(middleware.before(&ctx, &mut req));
8992
8993        let stored_id = req.get_extension::<RequestId>().unwrap();
8994        // Should have generated a new ID instead of using the invalid one
8995        assert_ne!(stored_id.as_str(), "invalid<script>id");
8996    }
8997
8998    #[test]
8999    fn request_id_middleware_rejects_too_long_client_id() {
9000        let config = RequestIdConfig::new().max_client_id_length(10);
9001        let middleware = RequestIdMiddleware::with_config(config);
9002        let ctx = test_context();
9003
9004        let mut req = Request::new(crate::request::Method::Get, "/");
9005        req.headers_mut()
9006            .insert("x-request-id", b"this-id-is-way-too-long".to_vec());
9007
9008        futures_executor::block_on(middleware.before(&ctx, &mut req));
9009
9010        let stored_id = req.get_extension::<RequestId>().unwrap();
9011        // Should have generated a new ID instead of using the too-long one
9012        assert_ne!(stored_id.as_str(), "this-id-is-way-too-long");
9013    }
9014
9015    #[test]
9016    fn request_id_middleware_adds_to_response() {
9017        let middleware = RequestIdMiddleware::new();
9018        let ctx = test_context();
9019        let mut req = Request::new(crate::request::Method::Get, "/");
9020
9021        futures_executor::block_on(middleware.before(&ctx, &mut req));
9022        let stored_id = req.get_extension::<RequestId>().unwrap().clone();
9023
9024        let response = Response::ok();
9025        let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9026
9027        let header = header_value(&response, "x-request-id");
9028        assert_eq!(header, Some(stored_id.0));
9029    }
9030
9031    #[test]
9032    fn request_id_middleware_respects_add_to_response_false() {
9033        let config = RequestIdConfig::new().add_to_response(false);
9034        let middleware = RequestIdMiddleware::with_config(config);
9035        let ctx = test_context();
9036        let mut req = Request::new(crate::request::Method::Get, "/");
9037
9038        futures_executor::block_on(middleware.before(&ctx, &mut req));
9039
9040        let response = Response::ok();
9041        let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9042
9043        let header = header_value(&response, "x-request-id");
9044        assert!(header.is_none());
9045    }
9046
9047    #[test]
9048    fn request_id_middleware_respects_accept_from_client_false() {
9049        let config = RequestIdConfig::new().accept_from_client(false);
9050        let middleware = RequestIdMiddleware::with_config(config);
9051        let ctx = test_context();
9052        let mut req = Request::new(crate::request::Method::Get, "/");
9053        req.headers_mut()
9054            .insert("x-request-id", b"client-id".to_vec());
9055
9056        futures_executor::block_on(middleware.before(&ctx, &mut req));
9057
9058        let stored_id = req.get_extension::<RequestId>().unwrap();
9059        // Should ignore client ID and generate new one
9060        assert_ne!(stored_id.as_str(), "client-id");
9061    }
9062
9063    #[test]
9064    fn request_id_middleware_custom_header_name() {
9065        let config = RequestIdConfig::new().header_name("X-Trace-ID");
9066        let middleware = RequestIdMiddleware::with_config(config);
9067        let ctx = test_context();
9068        let mut req = Request::new(crate::request::Method::Get, "/");
9069        req.headers_mut()
9070            .insert("X-Trace-ID", b"trace-123".to_vec());
9071
9072        futures_executor::block_on(middleware.before(&ctx, &mut req));
9073
9074        let stored_id = req.get_extension::<RequestId>().unwrap();
9075        assert_eq!(stored_id.as_str(), "trace-123");
9076
9077        let response = Response::ok();
9078        let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9079
9080        let header = header_value(&response, "X-Trace-ID");
9081        assert_eq!(header, Some("trace-123".to_string()));
9082    }
9083
9084    #[test]
9085    fn is_valid_request_id_accepts_valid() {
9086        assert!(super::is_valid_request_id("abc123"));
9087        assert!(super::is_valid_request_id("request-id-123"));
9088        assert!(super::is_valid_request_id("request_id_123"));
9089        assert!(super::is_valid_request_id("request.id.123"));
9090        assert!(super::is_valid_request_id("ABC123"));
9091        assert!(super::is_valid_request_id("a-b_c.D"));
9092    }
9093
9094    #[test]
9095    fn is_valid_request_id_rejects_invalid() {
9096        assert!(!super::is_valid_request_id(""));
9097        assert!(!super::is_valid_request_id("id with spaces"));
9098        assert!(!super::is_valid_request_id("id<script>"));
9099        assert!(!super::is_valid_request_id("id\nwith\nnewlines"));
9100        assert!(!super::is_valid_request_id("id;with;semicolons"));
9101        assert!(!super::is_valid_request_id("id/with/slashes"));
9102    }
9103
9104    #[test]
9105    fn request_id_middleware_name() {
9106        let middleware = RequestIdMiddleware::new();
9107        assert_eq!(middleware.name(), "RequestId");
9108    }
9109
9110    // =========================================================================
9111    // Middleware Stack Execution Order Tests
9112    // =========================================================================
9113
9114    /// Test middleware that records when its before/after hooks run
9115    struct OrderTrackingMiddleware {
9116        id: &'static str,
9117        log: Arc<std::sync::Mutex<Vec<String>>>,
9118    }
9119
9120    impl OrderTrackingMiddleware {
9121        fn new(id: &'static str, log: Arc<std::sync::Mutex<Vec<String>>>) -> Self {
9122            Self { id, log }
9123        }
9124    }
9125
9126    impl Middleware for OrderTrackingMiddleware {
9127        fn before<'a>(
9128            &'a self,
9129            _ctx: &'a RequestContext,
9130            _req: &'a mut Request,
9131        ) -> BoxFuture<'a, ControlFlow> {
9132            self.log.lock().unwrap().push(format!("{}.before", self.id));
9133            Box::pin(async { ControlFlow::Continue })
9134        }
9135
9136        fn after<'a>(
9137            &'a self,
9138            _ctx: &'a RequestContext,
9139            _req: &'a Request,
9140            response: Response,
9141        ) -> BoxFuture<'a, Response> {
9142            self.log.lock().unwrap().push(format!("{}.after", self.id));
9143            Box::pin(async move { response })
9144        }
9145    }
9146
9147    /// Test middleware that short-circuits with a configurable condition
9148    struct ConditionalBreakMiddleware {
9149        id: &'static str,
9150        should_break: bool,
9151        log: Arc<std::sync::Mutex<Vec<String>>>,
9152    }
9153
9154    impl ConditionalBreakMiddleware {
9155        fn new(
9156            id: &'static str,
9157            should_break: bool,
9158            log: Arc<std::sync::Mutex<Vec<String>>>,
9159        ) -> Self {
9160            Self {
9161                id,
9162                should_break,
9163                log,
9164            }
9165        }
9166    }
9167
9168    impl Middleware for ConditionalBreakMiddleware {
9169        fn before<'a>(
9170            &'a self,
9171            _ctx: &'a RequestContext,
9172            _req: &'a mut Request,
9173        ) -> BoxFuture<'a, ControlFlow> {
9174            self.log.lock().unwrap().push(format!("{}.before", self.id));
9175            let should_break = self.should_break;
9176            Box::pin(async move {
9177                if should_break {
9178                    ControlFlow::Break(
9179                        Response::with_status(StatusCode::FORBIDDEN)
9180                            .body(ResponseBody::Bytes(b"blocked".to_vec())),
9181                    )
9182                } else {
9183                    ControlFlow::Continue
9184                }
9185            })
9186        }
9187
9188        fn after<'a>(
9189            &'a self,
9190            _ctx: &'a RequestContext,
9191            _req: &'a Request,
9192            response: Response,
9193        ) -> BoxFuture<'a, Response> {
9194            self.log.lock().unwrap().push(format!("{}.after", self.id));
9195            Box::pin(async move { response })
9196        }
9197    }
9198
9199    /// Simple test handler that returns 200 OK
9200    struct OkHandler;
9201
9202    impl Handler for OkHandler {
9203        fn call<'a>(
9204            &'a self,
9205            _ctx: &'a RequestContext,
9206            _req: &'a mut Request,
9207        ) -> BoxFuture<'a, Response> {
9208            Box::pin(async move { Response::ok().body(ResponseBody::Bytes(b"handler".to_vec())) })
9209        }
9210    }
9211
9212    /// Handler that checks for a header injected by middleware.
9213    struct CheckHeaderHandler;
9214
9215    impl Handler for CheckHeaderHandler {
9216        fn call<'a>(
9217            &'a self,
9218            _ctx: &'a RequestContext,
9219            req: &'a mut Request,
9220        ) -> BoxFuture<'a, Response> {
9221            let has_header = req.headers().get("X-Modified-By").is_some();
9222            Box::pin(async move {
9223                if has_header {
9224                    Response::ok().body(ResponseBody::Bytes(b"header-present".to_vec()))
9225                } else {
9226                    Response::with_status(StatusCode::BAD_REQUEST)
9227                }
9228            })
9229        }
9230    }
9231
9232    /// Handler that returns an error status.
9233    struct ErrorHandler;
9234
9235    impl Handler for ErrorHandler {
9236        fn call<'a>(
9237            &'a self,
9238            _ctx: &'a RequestContext,
9239            _req: &'a mut Request,
9240        ) -> BoxFuture<'a, Response> {
9241            Box::pin(async move { Response::with_status(StatusCode::INTERNAL_SERVER_ERROR) })
9242        }
9243    }
9244
9245    #[test]
9246    fn middleware_stack_executes_in_correct_order() {
9247        // Verify the "onion" model: before hooks run first-to-last,
9248        // after hooks run last-to-first
9249        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9250
9251        let mut stack = MiddlewareStack::new();
9252        stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
9253        stack.push(OrderTrackingMiddleware::new("mw2", log.clone()));
9254        stack.push(OrderTrackingMiddleware::new("mw3", log.clone()));
9255
9256        let ctx = test_context();
9257        let mut req = Request::new(crate::request::Method::Get, "/");
9258
9259        futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9260
9261        let calls = log.lock().unwrap().clone();
9262        assert_eq!(
9263            calls,
9264            vec![
9265                "mw1.before",
9266                "mw2.before",
9267                "mw3.before",
9268                "mw3.after",
9269                "mw2.after",
9270                "mw1.after",
9271            ]
9272        );
9273    }
9274
9275    #[test]
9276    fn middleware_stack_short_circuit_skips_later_middleware() {
9277        // When middleware 2 breaks, middleware 3's before should NOT run
9278        // But middleware 1 and 2's after hooks should still run
9279        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9280
9281        let mut stack = MiddlewareStack::new();
9282        stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
9283        stack.push(ConditionalBreakMiddleware::new("mw2", true, log.clone()));
9284        stack.push(OrderTrackingMiddleware::new("mw3", log.clone()));
9285
9286        let ctx = test_context();
9287        let mut req = Request::new(crate::request::Method::Get, "/");
9288
9289        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9290
9291        // Should get 403 from the break
9292        assert_eq!(response.status().as_u16(), 403);
9293
9294        let calls = log.lock().unwrap().clone();
9295        assert_eq!(
9296            calls,
9297            vec![
9298                "mw1.before",
9299                "mw2.before",
9300                // mw3.before NOT called because mw2 broke
9301                // mw2.after NOT called because it was the one that broke (ran_before_count = 1)
9302                "mw1.after",
9303            ]
9304        );
9305    }
9306
9307    #[test]
9308    fn middleware_stack_first_middleware_breaks() {
9309        // When the first middleware breaks, no other middleware should run
9310        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9311
9312        let mut stack = MiddlewareStack::new();
9313        stack.push(ConditionalBreakMiddleware::new("mw1", true, log.clone()));
9314        stack.push(OrderTrackingMiddleware::new("mw2", log.clone()));
9315
9316        let ctx = test_context();
9317        let mut req = Request::new(crate::request::Method::Get, "/");
9318
9319        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9320
9321        assert_eq!(response.status().as_u16(), 403);
9322
9323        let calls = log.lock().unwrap().clone();
9324        assert_eq!(calls, vec!["mw1.before"]);
9325        // No after hooks because ran_before_count = 0
9326    }
9327
9328    #[test]
9329    fn middleware_stack_last_middleware_breaks() {
9330        // When the last middleware breaks, all previous after hooks should run
9331        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9332
9333        let mut stack = MiddlewareStack::new();
9334        stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
9335        stack.push(OrderTrackingMiddleware::new("mw2", log.clone()));
9336        stack.push(ConditionalBreakMiddleware::new("mw3", true, log.clone()));
9337
9338        let ctx = test_context();
9339        let mut req = Request::new(crate::request::Method::Get, "/");
9340
9341        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9342
9343        assert_eq!(response.status().as_u16(), 403);
9344
9345        let calls = log.lock().unwrap().clone();
9346        assert_eq!(
9347            calls,
9348            vec![
9349                "mw1.before",
9350                "mw2.before",
9351                "mw3.before",
9352                // mw3 broke, so only mw1 and mw2 after hooks run
9353                "mw2.after",
9354                "mw1.after",
9355            ]
9356        );
9357    }
9358
9359    #[test]
9360    fn middleware_stack_empty_executes_handler_directly() {
9361        let stack = MiddlewareStack::new();
9362        let ctx = test_context();
9363        let mut req = Request::new(crate::request::Method::Get, "/");
9364
9365        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9366
9367        assert_eq!(response.status().as_u16(), 200);
9368    }
9369
9370    #[test]
9371    fn middleware_stack_with_capacity() {
9372        let stack = MiddlewareStack::with_capacity(10);
9373        assert!(stack.is_empty());
9374        assert_eq!(stack.len(), 0);
9375    }
9376
9377    #[test]
9378    fn middleware_stack_push_arc() {
9379        let mut stack = MiddlewareStack::new();
9380        let mw: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
9381        stack.push_arc(mw);
9382        assert_eq!(stack.len(), 1);
9383    }
9384
9385    // =========================================================================
9386    // AddResponseHeader Middleware Tests
9387    // =========================================================================
9388
9389    #[test]
9390    fn add_response_header_adds_header() {
9391        let mw = AddResponseHeader::new("X-Custom", b"custom-value".to_vec());
9392        let ctx = test_context();
9393        let req = Request::new(crate::request::Method::Get, "/");
9394
9395        let response = Response::ok();
9396        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
9397
9398        assert_eq!(
9399            header_value(&response, "X-Custom"),
9400            Some("custom-value".to_string())
9401        );
9402    }
9403
9404    #[test]
9405    fn add_response_header_preserves_existing_headers() {
9406        let mw = AddResponseHeader::new("X-New", b"new".to_vec());
9407        let ctx = test_context();
9408        let req = Request::new(crate::request::Method::Get, "/");
9409
9410        let response = Response::ok().header("X-Existing", b"existing".to_vec());
9411        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
9412
9413        assert_eq!(
9414            header_value(&response, "X-Existing"),
9415            Some("existing".to_string())
9416        );
9417        assert_eq!(header_value(&response, "X-New"), Some("new".to_string()));
9418    }
9419
9420    #[test]
9421    fn add_response_header_name() {
9422        let mw = AddResponseHeader::new("X-Test", b"test".to_vec());
9423        assert_eq!(mw.name(), "AddResponseHeader");
9424    }
9425
9426    // =========================================================================
9427    // RequireHeader Middleware Tests
9428    // =========================================================================
9429
9430    #[test]
9431    fn require_header_allows_with_header() {
9432        let mw = RequireHeader::new("X-Api-Key");
9433        let ctx = test_context();
9434        let mut req = Request::new(crate::request::Method::Get, "/");
9435        req.headers_mut()
9436            .insert("X-Api-Key", b"secret-key".to_vec());
9437
9438        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9439        assert!(matches!(result, ControlFlow::Continue));
9440    }
9441
9442    #[test]
9443    fn require_header_blocks_without_header() {
9444        let mw = RequireHeader::new("X-Api-Key");
9445        let ctx = test_context();
9446        let mut req = Request::new(crate::request::Method::Get, "/");
9447
9448        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9449
9450        match result {
9451            ControlFlow::Break(response) => {
9452                assert_eq!(response.status().as_u16(), 400);
9453            }
9454            ControlFlow::Continue => panic!("Expected Break, got Continue"),
9455        }
9456    }
9457
9458    #[test]
9459    fn require_header_name() {
9460        let mw = RequireHeader::new("X-Test");
9461        assert_eq!(mw.name(), "RequireHeader");
9462    }
9463
9464    // =========================================================================
9465    // PathPrefixFilter Middleware Tests
9466    // =========================================================================
9467
9468    #[test]
9469    fn path_prefix_filter_allows_matching_path() {
9470        let mw = PathPrefixFilter::new("/api");
9471        let ctx = test_context();
9472        let mut req = Request::new(crate::request::Method::Get, "/api/users");
9473
9474        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9475        assert!(matches!(result, ControlFlow::Continue));
9476    }
9477
9478    #[test]
9479    fn path_prefix_filter_allows_exact_prefix() {
9480        let mw = PathPrefixFilter::new("/api");
9481        let ctx = test_context();
9482        let mut req = Request::new(crate::request::Method::Get, "/api");
9483
9484        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9485        assert!(matches!(result, ControlFlow::Continue));
9486    }
9487
9488    #[test]
9489    fn path_prefix_filter_blocks_non_matching_path() {
9490        let mw = PathPrefixFilter::new("/api");
9491        let ctx = test_context();
9492        let mut req = Request::new(crate::request::Method::Get, "/admin/users");
9493
9494        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9495
9496        match result {
9497            ControlFlow::Break(response) => {
9498                assert_eq!(response.status().as_u16(), 404);
9499            }
9500            ControlFlow::Continue => panic!("Expected Break, got Continue"),
9501        }
9502    }
9503
9504    #[test]
9505    fn path_prefix_filter_name() {
9506        let mw = PathPrefixFilter::new("/api");
9507        assert_eq!(mw.name(), "PathPrefixFilter");
9508    }
9509
9510    // =========================================================================
9511    // ConditionalStatus Middleware Tests
9512    // =========================================================================
9513
9514    #[test]
9515    fn conditional_status_applies_true_status() {
9516        let mw = ConditionalStatus::new(
9517            |req| req.path() == "/health",
9518            StatusCode::OK,
9519            StatusCode::NOT_FOUND,
9520        );
9521        let ctx = test_context();
9522        let req = Request::new(crate::request::Method::Get, "/health");
9523        let response = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
9524
9525        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
9526        assert_eq!(response.status().as_u16(), 200);
9527    }
9528
9529    #[test]
9530    fn conditional_status_applies_false_status() {
9531        let mw = ConditionalStatus::new(
9532            |req| req.path() == "/health",
9533            StatusCode::OK,
9534            StatusCode::NOT_FOUND,
9535        );
9536        let ctx = test_context();
9537        let req = Request::new(crate::request::Method::Get, "/other");
9538        let response = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
9539
9540        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
9541        assert_eq!(response.status().as_u16(), 404);
9542    }
9543
9544    #[test]
9545    fn conditional_status_name() {
9546        let mw = ConditionalStatus::new(|_| true, StatusCode::OK, StatusCode::NOT_FOUND);
9547        assert_eq!(mw.name(), "ConditionalStatus");
9548    }
9549
9550    // =========================================================================
9551    // Layer and Layered Tests
9552    // =========================================================================
9553
9554    #[derive(Clone)]
9555    struct LayerTestMiddleware {
9556        prefix: String,
9557    }
9558
9559    impl LayerTestMiddleware {
9560        fn new(prefix: impl Into<String>) -> Self {
9561            Self {
9562                prefix: prefix.into(),
9563            }
9564        }
9565    }
9566
9567    impl Middleware for LayerTestMiddleware {
9568        fn after<'a>(
9569            &'a self,
9570            _ctx: &'a RequestContext,
9571            _req: &'a Request,
9572            response: Response,
9573        ) -> BoxFuture<'a, Response> {
9574            let prefix = self.prefix.clone();
9575            Box::pin(async move { response.header("X-Layer", prefix.into_bytes()) })
9576        }
9577    }
9578
9579    #[test]
9580    fn layer_wraps_handler() {
9581        let layer = Layer::new(LayerTestMiddleware::new("wrapped"));
9582        let wrapped = layer.wrap(OkHandler);
9583
9584        let ctx = test_context();
9585        let mut req = Request::new(crate::request::Method::Get, "/");
9586
9587        let response = futures_executor::block_on(wrapped.call(&ctx, &mut req));
9588
9589        assert_eq!(response.status().as_u16(), 200);
9590        assert_eq!(
9591            header_value(&response, "X-Layer"),
9592            Some("wrapped".to_string())
9593        );
9594    }
9595
9596    #[test]
9597    fn layered_handles_break() {
9598        #[derive(Clone)]
9599        struct BreakingMiddleware;
9600
9601        impl Middleware for BreakingMiddleware {
9602            fn before<'a>(
9603                &'a self,
9604                _ctx: &'a RequestContext,
9605                _req: &'a mut Request,
9606            ) -> BoxFuture<'a, ControlFlow> {
9607                Box::pin(async {
9608                    ControlFlow::Break(Response::with_status(StatusCode::UNAUTHORIZED))
9609                })
9610            }
9611
9612            fn after<'a>(
9613                &'a self,
9614                _ctx: &'a RequestContext,
9615                _req: &'a Request,
9616                response: Response,
9617            ) -> BoxFuture<'a, Response> {
9618                Box::pin(async move { response.header("X-After", b"ran".to_vec()) })
9619            }
9620        }
9621
9622        let layer = Layer::new(BreakingMiddleware);
9623        let wrapped = layer.wrap(OkHandler);
9624
9625        let ctx = test_context();
9626        let mut req = Request::new(crate::request::Method::Get, "/");
9627
9628        let response = futures_executor::block_on(wrapped.call(&ctx, &mut req));
9629
9630        // Should get 401 from break
9631        assert_eq!(response.status().as_u16(), 401);
9632        // After hook should still run
9633        assert_eq!(header_value(&response, "X-After"), Some("ran".to_string()));
9634    }
9635
9636    // =========================================================================
9637    // RequestResponseLogger Tests
9638    // =========================================================================
9639
9640    #[test]
9641    fn request_response_logger_default() {
9642        let logger = RequestResponseLogger::default();
9643        assert!(logger.log_request_headers);
9644        assert!(logger.log_response_headers);
9645        assert!(!logger.log_body);
9646        assert_eq!(logger.max_body_bytes, 1024);
9647    }
9648
9649    #[test]
9650    fn request_response_logger_builder() {
9651        let logger = RequestResponseLogger::new()
9652            .log_request_headers(false)
9653            .log_response_headers(false)
9654            .log_body(true)
9655            .max_body_bytes(2048)
9656            .redact_header("x-secret");
9657
9658        assert!(!logger.log_request_headers);
9659        assert!(!logger.log_response_headers);
9660        assert!(logger.log_body);
9661        assert_eq!(logger.max_body_bytes, 2048);
9662        assert!(logger.redact_headers.contains("x-secret"));
9663    }
9664
9665    #[test]
9666    fn request_response_logger_name() {
9667        let logger = RequestResponseLogger::new();
9668        assert_eq!(logger.name(), "RequestResponseLogger");
9669    }
9670
9671    // =========================================================================
9672    // Integration Tests with Handlers
9673    // =========================================================================
9674
9675    #[test]
9676    fn middleware_stack_modifies_request_for_handler() {
9677        /// Middleware that adds a header that the handler can see
9678        struct RequestModifier;
9679
9680        impl Middleware for RequestModifier {
9681            fn before<'a>(
9682                &'a self,
9683                _ctx: &'a RequestContext,
9684                req: &'a mut Request,
9685            ) -> BoxFuture<'a, ControlFlow> {
9686                req.headers_mut()
9687                    .insert("X-Modified-By", b"middleware".to_vec());
9688                Box::pin(async { ControlFlow::Continue })
9689            }
9690        }
9691
9692        let mut stack = MiddlewareStack::new();
9693        stack.push(RequestModifier);
9694
9695        let ctx = test_context();
9696        let mut req = Request::new(crate::request::Method::Get, "/");
9697
9698        let response =
9699            futures_executor::block_on(stack.execute(&CheckHeaderHandler, &ctx, &mut req));
9700
9701        assert_eq!(response.status().as_u16(), 200);
9702    }
9703
9704    #[test]
9705    fn middleware_stack_multiple_response_modifications() {
9706        let mut stack = MiddlewareStack::new();
9707        stack.push(AddResponseHeader::new("X-First", b"1".to_vec()));
9708        stack.push(AddResponseHeader::new("X-Second", b"2".to_vec()));
9709        stack.push(AddResponseHeader::new("X-Third", b"3".to_vec()));
9710
9711        let ctx = test_context();
9712        let mut req = Request::new(crate::request::Method::Get, "/");
9713
9714        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9715
9716        // All headers should be present (after hooks run in reverse)
9717        assert_eq!(header_value(&response, "X-First"), Some("1".to_string()));
9718        assert_eq!(header_value(&response, "X-Second"), Some("2".to_string()));
9719        assert_eq!(header_value(&response, "X-Third"), Some("3".to_string()));
9720    }
9721
9722    #[test]
9723    fn middleware_stack_handler_receives_response_after_break() {
9724        // Verify that when middleware breaks, the response body is from the break
9725        let mut stack = MiddlewareStack::new();
9726        stack.push(ConditionalBreakMiddleware::new(
9727            "breaker",
9728            true,
9729            Arc::new(std::sync::Mutex::new(Vec::new())),
9730        ));
9731
9732        let ctx = test_context();
9733        let mut req = Request::new(crate::request::Method::Get, "/");
9734
9735        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9736
9737        assert_eq!(response.status().as_u16(), 403);
9738        // Body should be from the breaking middleware, not the handler
9739        match response.body_ref() {
9740            ResponseBody::Bytes(b) => assert_eq!(b, b"blocked"),
9741            _ => panic!("Expected Bytes body"),
9742        }
9743    }
9744
9745    // =========================================================================
9746    // Error Propagation Tests
9747    // =========================================================================
9748
9749    #[test]
9750    fn middleware_after_can_change_status() {
9751        struct StatusChanger;
9752
9753        impl Middleware for StatusChanger {
9754            fn after<'a>(
9755                &'a self,
9756                _ctx: &'a RequestContext,
9757                _req: &'a Request,
9758                _response: Response,
9759            ) -> BoxFuture<'a, Response> {
9760                Box::pin(async { Response::with_status(StatusCode::SERVICE_UNAVAILABLE) })
9761            }
9762        }
9763
9764        let mut stack = MiddlewareStack::new();
9765        stack.push(StatusChanger);
9766
9767        let ctx = test_context();
9768        let mut req = Request::new(crate::request::Method::Get, "/");
9769
9770        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9771
9772        // Should be changed by after hook
9773        assert_eq!(response.status().as_u16(), 503);
9774    }
9775
9776    #[test]
9777    fn middleware_after_runs_even_on_error_status() {
9778        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9779        let mut stack = MiddlewareStack::new();
9780        stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
9781
9782        let ctx = test_context();
9783        let mut req = Request::new(crate::request::Method::Get, "/");
9784
9785        let response = futures_executor::block_on(stack.execute(&ErrorHandler, &ctx, &mut req));
9786
9787        assert_eq!(response.status().as_u16(), 500);
9788
9789        let calls = log.lock().unwrap().clone();
9790        // After should run even when handler returns error status
9791        assert_eq!(calls, vec!["mw1.before", "mw1.after"]);
9792    }
9793
9794    // =========================================================================
9795    // Wildcard and Regex Matching Tests
9796    // =========================================================================
9797
9798    #[test]
9799    fn wildcard_match_simple() {
9800        assert!(super::wildcard_match("*.example.com", "api.example.com"));
9801        assert!(super::wildcard_match("*.example.com", "www.example.com"));
9802        assert!(!super::wildcard_match("*.example.com", "example.com"));
9803    }
9804
9805    #[test]
9806    fn wildcard_match_suffix_pattern() {
9807        // Wildcard at start with fixed suffix - primary use case for CORS
9808        assert!(super::wildcard_match("*.txt", "file.txt"));
9809        assert!(super::wildcard_match("*.txt", "document.txt"));
9810        assert!(!super::wildcard_match("*.txt", "file.doc"));
9811        assert!(super::wildcard_match("*-suffix", "any-suffix"));
9812    }
9813
9814    #[test]
9815    fn wildcard_match_no_wildcard() {
9816        assert!(super::wildcard_match("exact", "exact"));
9817        assert!(!super::wildcard_match("exact", "different"));
9818    }
9819
9820    #[test]
9821    fn regex_match_anchored() {
9822        assert!(super::regex_match("^hello$", "hello"));
9823        assert!(!super::regex_match("^hello$", "hello world"));
9824        assert!(!super::regex_match("^hello$", "say hello"));
9825    }
9826
9827    #[test]
9828    fn regex_match_dot_wildcard() {
9829        assert!(super::regex_match("h.llo", "hello"));
9830        assert!(super::regex_match("h.llo", "hallo"));
9831    }
9832
9833    #[test]
9834    fn regex_match_star() {
9835        assert!(super::regex_match("hel*o", "hello"));
9836        assert!(super::regex_match("hel*o", "helo"));
9837        assert!(super::regex_match("hel*o", "hellllllo"));
9838    }
9839
9840    // =========================================================================
9841    // Middleware Trait Default Implementation Tests
9842    // =========================================================================
9843
9844    #[test]
9845    fn middleware_default_before_continues() {
9846        struct DefaultBefore;
9847        impl Middleware for DefaultBefore {}
9848
9849        let mw = DefaultBefore;
9850        let ctx = test_context();
9851        let mut req = Request::new(crate::request::Method::Get, "/");
9852
9853        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9854        assert!(matches!(result, ControlFlow::Continue));
9855    }
9856
9857    #[test]
9858    fn middleware_default_after_passes_through() {
9859        struct DefaultAfter;
9860        impl Middleware for DefaultAfter {}
9861
9862        let mw = DefaultAfter;
9863        let ctx = test_context();
9864        let req = Request::new(crate::request::Method::Get, "/");
9865        let response = Response::with_status(StatusCode::CREATED);
9866
9867        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
9868        assert_eq!(result.status().as_u16(), 201);
9869    }
9870
9871    #[test]
9872    fn middleware_default_name_is_type_name() {
9873        struct MyCustomMiddleware;
9874        impl Middleware for MyCustomMiddleware {}
9875
9876        let mw = MyCustomMiddleware;
9877        assert!(mw.name().contains("MyCustomMiddleware"));
9878    }
9879
9880    // =========================================================================
9881    // Security Headers Middleware Tests
9882    // =========================================================================
9883
9884    #[test]
9885    fn security_headers_default_config() {
9886        let config = SecurityHeadersConfig::default();
9887        assert_eq!(config.x_content_type_options, Some("nosniff"));
9888        assert_eq!(config.x_frame_options, Some(XFrameOptions::Deny));
9889        assert_eq!(config.x_xss_protection, Some("0"));
9890        assert!(config.content_security_policy.is_none());
9891        assert!(config.hsts.is_none());
9892        assert_eq!(
9893            config.referrer_policy,
9894            Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)
9895        );
9896        assert!(config.permissions_policy.is_none());
9897    }
9898
9899    #[test]
9900    fn security_headers_none_config() {
9901        let config = SecurityHeadersConfig::none();
9902        assert!(config.x_content_type_options.is_none());
9903        assert!(config.x_frame_options.is_none());
9904        assert!(config.x_xss_protection.is_none());
9905        assert!(config.content_security_policy.is_none());
9906        assert!(config.hsts.is_none());
9907        assert!(config.referrer_policy.is_none());
9908        assert!(config.permissions_policy.is_none());
9909    }
9910
9911    #[test]
9912    fn security_headers_strict_config() {
9913        let config = SecurityHeadersConfig::strict();
9914        assert_eq!(config.x_content_type_options, Some("nosniff"));
9915        assert_eq!(config.x_frame_options, Some(XFrameOptions::Deny));
9916        assert_eq!(
9917            config.content_security_policy,
9918            Some("default-src 'self'".to_string())
9919        );
9920        assert_eq!(config.hsts, Some((31536000, true, false)));
9921        assert_eq!(config.referrer_policy, Some(ReferrerPolicy::NoReferrer));
9922        assert!(config.permissions_policy.is_some());
9923    }
9924
9925    #[test]
9926    fn security_headers_config_builder() {
9927        let config = SecurityHeadersConfig::new()
9928            .x_frame_options(Some(XFrameOptions::SameOrigin))
9929            .content_security_policy("default-src 'self'")
9930            .hsts(86400, false, false)
9931            .referrer_policy(Some(ReferrerPolicy::Origin));
9932
9933        assert_eq!(config.x_frame_options, Some(XFrameOptions::SameOrigin));
9934        assert_eq!(
9935            config.content_security_policy,
9936            Some("default-src 'self'".to_string())
9937        );
9938        assert_eq!(config.hsts, Some((86400, false, false)));
9939        assert_eq!(config.referrer_policy, Some(ReferrerPolicy::Origin));
9940    }
9941
9942    #[test]
9943    fn security_headers_hsts_value_format() {
9944        // Basic HSTS
9945        let config = SecurityHeadersConfig::none().hsts(3600, false, false);
9946        assert_eq!(config.build_hsts_value(), Some("max-age=3600".to_string()));
9947
9948        // With includeSubDomains
9949        let config = SecurityHeadersConfig::none().hsts(3600, true, false);
9950        assert_eq!(
9951            config.build_hsts_value(),
9952            Some("max-age=3600; includeSubDomains".to_string())
9953        );
9954
9955        // With preload
9956        let config = SecurityHeadersConfig::none().hsts(3600, false, true);
9957        assert_eq!(
9958            config.build_hsts_value(),
9959            Some("max-age=3600; preload".to_string())
9960        );
9961
9962        // With both
9963        let config = SecurityHeadersConfig::none().hsts(3600, true, true);
9964        assert_eq!(
9965            config.build_hsts_value(),
9966            Some("max-age=3600; includeSubDomains; preload".to_string())
9967        );
9968    }
9969
9970    #[test]
9971    fn security_headers_middleware_adds_default_headers() {
9972        let mw = SecurityHeaders::new();
9973        let ctx = test_context();
9974        let req = Request::new(crate::request::Method::Get, "/");
9975        let response = Response::ok();
9976
9977        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
9978
9979        // Check that default headers are present
9980        assert!(header_value(&result, "X-Content-Type-Options").is_some());
9981        assert!(header_value(&result, "X-Frame-Options").is_some());
9982        assert!(header_value(&result, "X-XSS-Protection").is_some());
9983        assert!(header_value(&result, "Referrer-Policy").is_some());
9984
9985        // Check that optional headers are NOT present by default
9986        assert!(header_value(&result, "Content-Security-Policy").is_none());
9987        assert!(header_value(&result, "Strict-Transport-Security").is_none());
9988        assert!(header_value(&result, "Permissions-Policy").is_none());
9989    }
9990
9991    #[test]
9992    fn security_headers_middleware_with_csp() {
9993        let config = SecurityHeadersConfig::new()
9994            .content_security_policy("default-src 'self'; script-src 'self' 'unsafe-inline'");
9995        let mw = SecurityHeaders::with_config(config);
9996        let ctx = test_context();
9997        let req = Request::new(crate::request::Method::Get, "/");
9998        let response = Response::ok();
9999
10000        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
10001
10002        let csp = header_value(&result, "Content-Security-Policy");
10003        assert!(csp.is_some());
10004        assert_eq!(
10005            csp.unwrap(),
10006            "default-src 'self'; script-src 'self' 'unsafe-inline'"
10007        );
10008    }
10009
10010    #[test]
10011    fn security_headers_middleware_with_hsts() {
10012        let config = SecurityHeadersConfig::new().hsts(31536000, true, false);
10013        let mw = SecurityHeaders::with_config(config);
10014        let ctx = test_context();
10015        let req = Request::new(crate::request::Method::Get, "/");
10016        let response = Response::ok();
10017
10018        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
10019
10020        let hsts = header_value(&result, "Strict-Transport-Security");
10021        assert!(hsts.is_some());
10022        assert_eq!(hsts.unwrap(), "max-age=31536000; includeSubDomains");
10023    }
10024
10025    #[test]
10026    fn security_headers_middleware_name() {
10027        let mw = SecurityHeaders::new();
10028        assert_eq!(mw.name(), "SecurityHeaders");
10029    }
10030
10031    #[test]
10032    fn x_frame_options_values() {
10033        assert_eq!(XFrameOptions::Deny.as_bytes(), b"DENY");
10034        assert_eq!(XFrameOptions::SameOrigin.as_bytes(), b"SAMEORIGIN");
10035    }
10036
10037    #[test]
10038    fn referrer_policy_values() {
10039        assert_eq!(ReferrerPolicy::NoReferrer.as_bytes(), b"no-referrer");
10040        assert_eq!(
10041            ReferrerPolicy::NoReferrerWhenDowngrade.as_bytes(),
10042            b"no-referrer-when-downgrade"
10043        );
10044        assert_eq!(ReferrerPolicy::Origin.as_bytes(), b"origin");
10045        assert_eq!(
10046            ReferrerPolicy::OriginWhenCrossOrigin.as_bytes(),
10047            b"origin-when-cross-origin"
10048        );
10049        assert_eq!(ReferrerPolicy::SameOrigin.as_bytes(), b"same-origin");
10050        assert_eq!(ReferrerPolicy::StrictOrigin.as_bytes(), b"strict-origin");
10051        assert_eq!(
10052            ReferrerPolicy::StrictOriginWhenCrossOrigin.as_bytes(),
10053            b"strict-origin-when-cross-origin"
10054        );
10055        assert_eq!(ReferrerPolicy::UnsafeUrl.as_bytes(), b"unsafe-url");
10056    }
10057
10058    #[test]
10059    fn security_headers_strict_preset() {
10060        let mw = SecurityHeaders::strict();
10061        let ctx = test_context();
10062        let req = Request::new(crate::request::Method::Get, "/");
10063        let response = Response::ok();
10064
10065        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
10066
10067        // All headers should be present with strict config
10068        assert!(header_value(&result, "X-Content-Type-Options").is_some());
10069        assert!(header_value(&result, "X-Frame-Options").is_some());
10070        assert!(header_value(&result, "Content-Security-Policy").is_some());
10071        assert!(header_value(&result, "Strict-Transport-Security").is_some());
10072        assert!(header_value(&result, "Referrer-Policy").is_some());
10073        assert!(header_value(&result, "Permissions-Policy").is_some());
10074    }
10075
10076    #[test]
10077    fn security_headers_config_clearing_methods() {
10078        let config = SecurityHeadersConfig::strict()
10079            .no_content_security_policy()
10080            .no_hsts()
10081            .no_permissions_policy();
10082
10083        assert!(config.content_security_policy.is_none());
10084        assert!(config.hsts.is_none());
10085        assert!(config.permissions_policy.is_none());
10086    }
10087
10088    // =========================================================================
10089    // CSRF Middleware Tests
10090    // =========================================================================
10091
10092    #[test]
10093    fn csrf_token_generate_produces_unique_tokens() {
10094        let token1 = CsrfToken::generate();
10095        let token2 = CsrfToken::generate();
10096        assert_ne!(token1, token2);
10097        assert!(!token1.as_str().is_empty());
10098        assert!(!token2.as_str().is_empty());
10099    }
10100
10101    #[test]
10102    fn csrf_token_display() {
10103        let token = CsrfToken::new("test-token-123");
10104        assert_eq!(format!("{}", token), "test-token-123");
10105    }
10106
10107    #[test]
10108    fn csrf_config_defaults() {
10109        let config = CsrfConfig::default();
10110        assert_eq!(config.cookie_name, "csrf_token");
10111        assert_eq!(config.header_name, "x-csrf-token");
10112        assert_eq!(config.mode, CsrfMode::DoubleSubmit);
10113        assert!(!config.rotate_token);
10114        assert!(config.production);
10115        assert!(config.error_message.is_none());
10116    }
10117
10118    #[test]
10119    fn csrf_config_builder() {
10120        let config = CsrfConfig::new()
10121            .cookie_name("XSRF-TOKEN")
10122            .header_name("X-XSRF-Token")
10123            .mode(CsrfMode::HeaderOnly)
10124            .rotate_token(true)
10125            .production(false)
10126            .error_message("Custom CSRF error");
10127
10128        assert_eq!(config.cookie_name, "XSRF-TOKEN");
10129        assert_eq!(config.header_name, "X-XSRF-Token");
10130        assert_eq!(config.mode, CsrfMode::HeaderOnly);
10131        assert!(config.rotate_token);
10132        assert!(!config.production);
10133        assert_eq!(config.error_message, Some("Custom CSRF error".to_string()));
10134    }
10135
10136    #[test]
10137    fn csrf_middleware_allows_get_without_token() {
10138        let csrf = CsrfMiddleware::new();
10139        let ctx = test_context();
10140        let mut req = Request::new(crate::request::Method::Get, "/");
10141
10142        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10143        assert!(result.is_continue());
10144        // Token should be generated and stored
10145        assert!(req.get_extension::<CsrfToken>().is_some());
10146    }
10147
10148    #[test]
10149    fn csrf_middleware_allows_head_without_token() {
10150        let csrf = CsrfMiddleware::new();
10151        let ctx = test_context();
10152        let mut req = Request::new(crate::request::Method::Head, "/");
10153
10154        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10155        assert!(result.is_continue());
10156    }
10157
10158    #[test]
10159    fn csrf_middleware_allows_options_without_token() {
10160        let csrf = CsrfMiddleware::new();
10161        let ctx = test_context();
10162        let mut req = Request::new(crate::request::Method::Options, "/");
10163
10164        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10165        assert!(result.is_continue());
10166    }
10167
10168    #[test]
10169    fn csrf_middleware_blocks_post_without_token() {
10170        let csrf = CsrfMiddleware::new();
10171        let ctx = test_context();
10172        let mut req = Request::new(crate::request::Method::Post, "/");
10173
10174        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10175        assert!(result.is_break());
10176
10177        if let ControlFlow::Break(response) = result {
10178            assert_eq!(response.status(), StatusCode::FORBIDDEN);
10179        }
10180    }
10181
10182    #[test]
10183    fn csrf_middleware_blocks_put_without_token() {
10184        let csrf = CsrfMiddleware::new();
10185        let ctx = test_context();
10186        let mut req = Request::new(crate::request::Method::Put, "/");
10187
10188        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10189        assert!(result.is_break());
10190    }
10191
10192    #[test]
10193    fn csrf_middleware_blocks_delete_without_token() {
10194        let csrf = CsrfMiddleware::new();
10195        let ctx = test_context();
10196        let mut req = Request::new(crate::request::Method::Delete, "/");
10197
10198        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10199        assert!(result.is_break());
10200    }
10201
10202    #[test]
10203    fn csrf_middleware_blocks_patch_without_token() {
10204        let csrf = CsrfMiddleware::new();
10205        let ctx = test_context();
10206        let mut req = Request::new(crate::request::Method::Patch, "/");
10207
10208        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10209        assert!(result.is_break());
10210    }
10211
10212    #[test]
10213    fn csrf_middleware_allows_post_with_matching_tokens() {
10214        let csrf = CsrfMiddleware::new();
10215        let ctx = test_context();
10216        let mut req = Request::new(crate::request::Method::Post, "/");
10217
10218        // Set matching cookie and header
10219        let token = "valid-csrf-token-12345";
10220        req.headers_mut()
10221            .insert("cookie", format!("csrf_token={}", token).into_bytes());
10222        req.headers_mut()
10223            .insert("x-csrf-token", token.as_bytes().to_vec());
10224
10225        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10226        assert!(result.is_continue());
10227
10228        // Token should be stored in extensions
10229        let stored_token = req.get_extension::<CsrfToken>().unwrap();
10230        assert_eq!(stored_token.as_str(), token);
10231    }
10232
10233    #[test]
10234    fn csrf_middleware_blocks_post_with_mismatched_tokens() {
10235        let csrf = CsrfMiddleware::new();
10236        let ctx = test_context();
10237        let mut req = Request::new(crate::request::Method::Post, "/");
10238
10239        // Set mismatched cookie and header
10240        req.headers_mut()
10241            .insert("cookie", b"csrf_token=token-in-cookie".to_vec());
10242        req.headers_mut()
10243            .insert("x-csrf-token", b"different-token".to_vec());
10244
10245        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10246        assert!(result.is_break());
10247
10248        if let ControlFlow::Break(response) = result {
10249            assert_eq!(response.status(), StatusCode::FORBIDDEN);
10250        }
10251    }
10252
10253    #[test]
10254    fn csrf_middleware_blocks_post_with_header_only_in_double_submit_mode() {
10255        let csrf = CsrfMiddleware::new();
10256        let ctx = test_context();
10257        let mut req = Request::new(crate::request::Method::Post, "/");
10258
10259        // Only header, no cookie
10260        req.headers_mut()
10261            .insert("x-csrf-token", b"some-token".to_vec());
10262
10263        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10264        assert!(result.is_break());
10265    }
10266
10267    #[test]
10268    fn csrf_middleware_blocks_post_with_cookie_only_in_double_submit_mode() {
10269        let csrf = CsrfMiddleware::new();
10270        let ctx = test_context();
10271        let mut req = Request::new(crate::request::Method::Post, "/");
10272
10273        // Only cookie, no header
10274        req.headers_mut()
10275            .insert("cookie", b"csrf_token=some-token".to_vec());
10276
10277        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10278        assert!(result.is_break());
10279    }
10280
10281    #[test]
10282    fn csrf_middleware_header_only_mode_accepts_header_token() {
10283        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10284        let ctx = test_context();
10285        let mut req = Request::new(crate::request::Method::Post, "/");
10286
10287        req.headers_mut()
10288            .insert("x-csrf-token", b"valid-token".to_vec());
10289
10290        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10291        assert!(result.is_continue());
10292    }
10293
10294    #[test]
10295    fn csrf_middleware_header_only_mode_rejects_empty_header() {
10296        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10297        let ctx = test_context();
10298        let mut req = Request::new(crate::request::Method::Post, "/");
10299
10300        req.headers_mut().insert("x-csrf-token", b"".to_vec());
10301
10302        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10303        assert!(result.is_break());
10304    }
10305
10306    #[test]
10307    fn csrf_middleware_sets_cookie_on_get() {
10308        let csrf = CsrfMiddleware::new();
10309        let ctx = test_context();
10310        let mut req = Request::new(crate::request::Method::Get, "/");
10311
10312        // Run before to generate token
10313        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10314
10315        // Run after to set cookie
10316        let response = Response::ok();
10317        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10318
10319        // Check Set-Cookie header
10320        let cookie_value = header_value(&result, "set-cookie");
10321        assert!(cookie_value.is_some());
10322
10323        let cookie_value = cookie_value.unwrap();
10324        assert!(cookie_value.starts_with("csrf_token="));
10325        assert!(cookie_value.contains("SameSite=Strict"));
10326        assert!(cookie_value.contains("Secure")); // Production mode
10327    }
10328
10329    #[test]
10330    fn csrf_middleware_no_secure_in_dev_mode() {
10331        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().production(false));
10332        let ctx = test_context();
10333        let mut req = Request::new(crate::request::Method::Get, "/");
10334
10335        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10336
10337        let response = Response::ok();
10338        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10339
10340        let cookie_value = header_value(&result, "set-cookie").unwrap();
10341        assert!(!cookie_value.contains("Secure")); // No Secure in dev mode
10342    }
10343
10344    #[test]
10345    fn csrf_middleware_does_not_set_cookie_if_already_present() {
10346        let csrf = CsrfMiddleware::new();
10347        let ctx = test_context();
10348        let mut req = Request::new(crate::request::Method::Get, "/");
10349
10350        // Cookie already present
10351        req.headers_mut()
10352            .insert("cookie", b"csrf_token=existing-token".to_vec());
10353
10354        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10355
10356        let response = Response::ok();
10357        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10358
10359        // Should not set a new cookie
10360        assert!(header_value(&result, "set-cookie").is_none());
10361    }
10362
10363    #[test]
10364    fn csrf_middleware_rotates_token_when_configured() {
10365        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
10366        let ctx = test_context();
10367        let mut req = Request::new(crate::request::Method::Get, "/");
10368
10369        // Cookie already present
10370        req.headers_mut()
10371            .insert("cookie", b"csrf_token=old-token".to_vec());
10372
10373        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10374
10375        let response = Response::ok();
10376        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10377
10378        // Should set a new cookie even though one exists
10379        assert!(header_value(&result, "set-cookie").is_some());
10380    }
10381
10382    #[test]
10383    fn csrf_middleware_custom_header_name() {
10384        let csrf = CsrfMiddleware::with_config(
10385            CsrfConfig::new()
10386                .header_name("X-XSRF-Token")
10387                .cookie_name("XSRF-TOKEN"),
10388        );
10389        let ctx = test_context();
10390        let mut req = Request::new(crate::request::Method::Post, "/");
10391
10392        let token = "custom-token-value";
10393        req.headers_mut()
10394            .insert("cookie", format!("XSRF-TOKEN={}", token).into_bytes());
10395        req.headers_mut()
10396            .insert("x-xsrf-token", token.as_bytes().to_vec());
10397
10398        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10399        assert!(result.is_continue());
10400    }
10401
10402    #[test]
10403    fn csrf_middleware_error_response_is_json() {
10404        let csrf = CsrfMiddleware::new();
10405        let ctx = test_context();
10406        let mut req = Request::new(crate::request::Method::Post, "/");
10407
10408        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10409
10410        if let ControlFlow::Break(response) = result {
10411            let content_type = header_value(&response, "content-type");
10412            assert_eq!(content_type, Some("application/json".to_string()));
10413
10414            // Check body contains proper error structure
10415            if let ResponseBody::Bytes(body) = response.body_ref() {
10416                let body_str = std::str::from_utf8(body).unwrap();
10417                assert!(body_str.contains("csrf_error"));
10418                assert!(body_str.contains("x-csrf-token"));
10419            } else {
10420                panic!("Expected Bytes body");
10421            }
10422        } else {
10423            panic!("Expected Break");
10424        }
10425    }
10426
10427    #[test]
10428    fn csrf_middleware_custom_error_message() {
10429        let csrf = CsrfMiddleware::with_config(
10430            CsrfConfig::new().error_message("Access denied: invalid security token"),
10431        );
10432        let ctx = test_context();
10433        let mut req = Request::new(crate::request::Method::Post, "/");
10434
10435        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10436
10437        if let ControlFlow::Break(response) = result {
10438            if let ResponseBody::Bytes(body) = response.body_ref() {
10439                let body_str = std::str::from_utf8(body).unwrap();
10440                assert!(body_str.contains("Access denied: invalid security token"));
10441            }
10442        }
10443    }
10444
10445    #[test]
10446    fn csrf_middleware_name() {
10447        let csrf = CsrfMiddleware::new();
10448        assert_eq!(csrf.name(), "CSRF");
10449    }
10450
10451    #[test]
10452    fn csrf_middleware_parses_cookie_with_multiple_cookies() {
10453        let csrf = CsrfMiddleware::new();
10454        let ctx = test_context();
10455        let mut req = Request::new(crate::request::Method::Post, "/");
10456
10457        // Multiple cookies in the header
10458        let token = "the-csrf-token";
10459        req.headers_mut().insert(
10460            "cookie",
10461            format!("session=abc123; csrf_token={}; user=test", token).into_bytes(),
10462        );
10463        req.headers_mut()
10464            .insert("x-csrf-token", token.as_bytes().to_vec());
10465
10466        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10467        assert!(result.is_continue());
10468    }
10469
10470    #[test]
10471    fn csrf_middleware_handles_empty_token_value() {
10472        let csrf = CsrfMiddleware::new();
10473        let ctx = test_context();
10474        let mut req = Request::new(crate::request::Method::Post, "/");
10475
10476        // Empty token values
10477        req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10478        req.headers_mut().insert("x-csrf-token", b"".to_vec());
10479
10480        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10481        assert!(result.is_break()); // Should reject empty tokens
10482    }
10483
10484    // ---- Comprehensive CSRF tests (bd-3v0c) ----
10485
10486    #[test]
10487    fn csrf_token_generate_many_unique() {
10488        // Generate many tokens and verify all are unique
10489        let mut tokens = std::collections::HashSet::new();
10490        for _ in 0..100 {
10491            let token = CsrfToken::generate();
10492            assert!(
10493                tokens.insert(token.0.clone()),
10494                "Duplicate token generated: {}",
10495                token.0
10496            );
10497        }
10498        assert_eq!(tokens.len(), 100);
10499    }
10500
10501    #[test]
10502    fn csrf_token_generate_format_is_hex() {
10503        let token = CsrfToken::generate();
10504        let s = token.as_str();
10505        // Token should be all hex characters, at least 64 chars (32 bytes from urandom)
10506        assert!(
10507            s.len() >= 64,
10508            "Expected at least 64 hex characters, got {} in '{s}'",
10509            s.len()
10510        );
10511        assert!(
10512            s.chars().all(|c| c.is_ascii_hexdigit()),
10513            "Non-hex character in token: {s}"
10514        );
10515    }
10516
10517    #[test]
10518    fn csrf_token_generate_minimum_length() {
10519        let token = CsrfToken::generate();
10520        // 32 bytes from urandom = 64 hex chars
10521        assert!(
10522            token.as_str().len() >= 64,
10523            "Token too short: {} (len={})",
10524            token.as_str(),
10525            token.as_str().len()
10526        );
10527    }
10528
10529    #[test]
10530    fn csrf_token_from_str() {
10531        let token: CsrfToken = "my-token".into();
10532        assert_eq!(token.as_str(), "my-token");
10533        assert_eq!(token.0, "my-token");
10534    }
10535
10536    #[test]
10537    fn csrf_token_clone_eq() {
10538        let t1 = CsrfToken::new("abc");
10539        let t2 = t1.clone();
10540        assert_eq!(t1, t2);
10541        assert_eq!(t1.as_str(), t2.as_str());
10542    }
10543
10544    #[test]
10545    fn csrf_middleware_allows_trace_without_token() {
10546        let csrf = CsrfMiddleware::new();
10547        let ctx = test_context();
10548        let mut req = Request::new(crate::request::Method::Trace, "/");
10549
10550        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10551        assert!(result.is_continue());
10552        // Token should be generated
10553        assert!(req.get_extension::<CsrfToken>().is_some());
10554    }
10555
10556    #[test]
10557    fn csrf_safe_method_generates_token_into_extension() {
10558        let csrf = CsrfMiddleware::new();
10559        let ctx = test_context();
10560
10561        for method in [
10562            crate::request::Method::Get,
10563            crate::request::Method::Head,
10564            crate::request::Method::Options,
10565            crate::request::Method::Trace,
10566        ] {
10567            let mut req = Request::new(method, "/test");
10568            let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10569            assert!(result.is_continue());
10570            let token = req.get_extension::<CsrfToken>().expect("token missing");
10571            assert!(!token.as_str().is_empty());
10572        }
10573    }
10574
10575    #[test]
10576    fn csrf_safe_method_preserves_existing_cookie_token() {
10577        let csrf = CsrfMiddleware::new();
10578        let ctx = test_context();
10579        let mut req = Request::new(crate::request::Method::Get, "/");
10580        req.headers_mut()
10581            .insert("cookie", b"csrf_token=my-existing-token".to_vec());
10582
10583        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10584
10585        // Extension should contain the existing cookie token, not a new one
10586        let token = req.get_extension::<CsrfToken>().unwrap();
10587        assert_eq!(token.as_str(), "my-existing-token");
10588    }
10589
10590    #[test]
10591    fn csrf_valid_post_stores_token_in_extension() {
10592        let csrf = CsrfMiddleware::new();
10593        let ctx = test_context();
10594        let mut req = Request::new(crate::request::Method::Post, "/submit");
10595
10596        let tk = "valid-token-xyz";
10597        req.headers_mut()
10598            .insert("cookie", format!("csrf_token={}", tk).into_bytes());
10599        req.headers_mut()
10600            .insert("x-csrf-token", tk.as_bytes().to_vec());
10601
10602        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10603        assert!(result.is_continue());
10604        let stored = req.get_extension::<CsrfToken>().unwrap();
10605        assert_eq!(stored.as_str(), tk);
10606    }
10607
10608    #[test]
10609    fn csrf_double_submit_both_empty_strings_rejected() {
10610        let csrf = CsrfMiddleware::new();
10611        let ctx = test_context();
10612        let mut req = Request::new(crate::request::Method::Post, "/");
10613
10614        // Both cookie and header have empty string values
10615        req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10616        req.headers_mut().insert("x-csrf-token", b"".to_vec());
10617
10618        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10619        assert!(result.is_break());
10620    }
10621
10622    #[test]
10623    fn csrf_double_submit_matching_empty_rejected() {
10624        // Even if both are technically "equal" (empty), should reject
10625        let csrf = CsrfMiddleware::new();
10626        let ctx = test_context();
10627        let mut req = Request::new(crate::request::Method::Post, "/");
10628
10629        req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10630        req.headers_mut().insert("x-csrf-token", b"".to_vec());
10631
10632        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10633        assert!(
10634            result.is_break(),
10635            "Empty matching tokens should be rejected"
10636        );
10637    }
10638
10639    #[test]
10640    fn csrf_header_only_mode_does_not_need_cookie() {
10641        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10642        let ctx = test_context();
10643        let mut req = Request::new(crate::request::Method::Post, "/");
10644
10645        // Header only, no cookie
10646        req.headers_mut()
10647            .insert("x-csrf-token", b"header-only-token".to_vec());
10648
10649        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10650        assert!(result.is_continue());
10651        let token = req.get_extension::<CsrfToken>().unwrap();
10652        assert_eq!(token.as_str(), "header-only-token");
10653    }
10654
10655    #[test]
10656    fn csrf_header_only_mode_ignores_mismatched_cookie() {
10657        // In HeaderOnly mode, the cookie value is irrelevant
10658        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10659        let ctx = test_context();
10660        let mut req = Request::new(crate::request::Method::Post, "/");
10661
10662        req.headers_mut()
10663            .insert("cookie", b"csrf_token=different-value".to_vec());
10664        req.headers_mut()
10665            .insert("x-csrf-token", b"header-value".to_vec());
10666
10667        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10668        assert!(result.is_continue(), "HeaderOnly should ignore cookie");
10669    }
10670
10671    #[test]
10672    fn csrf_header_only_mode_rejects_no_header() {
10673        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10674        let ctx = test_context();
10675        let mut req = Request::new(crate::request::Method::Post, "/");
10676        // No header at all
10677        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10678        assert!(result.is_break());
10679    }
10680
10681    #[test]
10682    fn csrf_header_only_error_message_mentions_header() {
10683        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10684        let ctx = test_context();
10685        let mut req = Request::new(crate::request::Method::Post, "/");
10686
10687        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10688        if let ControlFlow::Break(response) = result {
10689            if let ResponseBody::Bytes(body) = response.body_ref() {
10690                let body_str = std::str::from_utf8(body).unwrap();
10691                assert!(
10692                    body_str.contains("missing in header"),
10693                    "Expected 'missing in header' in: {}",
10694                    body_str
10695                );
10696            }
10697        } else {
10698            panic!("Expected Break");
10699        }
10700    }
10701
10702    #[test]
10703    fn csrf_mismatch_error_differs_from_missing_error() {
10704        let csrf = CsrfMiddleware::new();
10705        let ctx = test_context();
10706
10707        // Missing: no header or cookie
10708        let mut req_missing = Request::new(crate::request::Method::Post, "/");
10709        let missing_result = futures_executor::block_on(csrf.before(&ctx, &mut req_missing));
10710        let missing_body = match missing_result {
10711            ControlFlow::Break(r) => match r.body_ref() {
10712                ResponseBody::Bytes(b) => std::str::from_utf8(b).unwrap().to_string(),
10713                ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes"),
10714            },
10715            ControlFlow::Continue => panic!("Expected Break"),
10716        };
10717
10718        // Mismatch: both present but different
10719        let mut req_mismatch = Request::new(crate::request::Method::Post, "/");
10720        req_mismatch
10721            .headers_mut()
10722            .insert("cookie", b"csrf_token=aaa".to_vec());
10723        req_mismatch
10724            .headers_mut()
10725            .insert("x-csrf-token", b"bbb".to_vec());
10726        let mismatch_result = futures_executor::block_on(csrf.before(&ctx, &mut req_mismatch));
10727        let mismatch_body = match mismatch_result {
10728            ControlFlow::Break(r) => match r.body_ref() {
10729                ResponseBody::Bytes(b) => std::str::from_utf8(b).unwrap().to_string(),
10730                ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes"),
10731            },
10732            ControlFlow::Continue => panic!("Expected Break"),
10733        };
10734
10735        // Error messages should differ
10736        assert_ne!(
10737            missing_body, mismatch_body,
10738            "Missing vs mismatch should have different error messages"
10739        );
10740        assert!(missing_body.contains("missing"));
10741        assert!(mismatch_body.contains("mismatch"));
10742    }
10743
10744    #[test]
10745    fn csrf_cookie_not_httponly() {
10746        // CSRF cookies MUST be readable by JavaScript (no HttpOnly)
10747        let csrf = CsrfMiddleware::new();
10748        let ctx = test_context();
10749        let mut req = Request::new(crate::request::Method::Get, "/");
10750
10751        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10752        let response = Response::ok();
10753        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10754
10755        let cookie_value = header_value(&result, "set-cookie").unwrap();
10756        assert!(
10757            !cookie_value.to_lowercase().contains("httponly"),
10758            "CSRF cookie must NOT be HttpOnly (needs JS access), got: {}",
10759            cookie_value
10760        );
10761    }
10762
10763    #[test]
10764    fn csrf_cookie_has_path_slash() {
10765        let csrf = CsrfMiddleware::new();
10766        let ctx = test_context();
10767        let mut req = Request::new(crate::request::Method::Get, "/");
10768
10769        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10770        let response = Response::ok();
10771        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10772
10773        let cookie_value = header_value(&result, "set-cookie").unwrap();
10774        assert!(
10775            cookie_value.contains("Path=/"),
10776            "Cookie should have Path=/, got: {}",
10777            cookie_value
10778        );
10779    }
10780
10781    #[test]
10782    fn csrf_cookie_has_samesite_strict() {
10783        let csrf = CsrfMiddleware::new();
10784        let ctx = test_context();
10785        let mut req = Request::new(crate::request::Method::Get, "/");
10786
10787        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10788        let response = Response::ok();
10789        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10790
10791        let cookie_value = header_value(&result, "set-cookie").unwrap();
10792        assert!(
10793            cookie_value.contains("SameSite=Strict"),
10794            "Cookie should have SameSite=Strict, got: {}",
10795            cookie_value
10796        );
10797    }
10798
10799    #[test]
10800    fn csrf_production_mode_sets_secure_flag() {
10801        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().production(true));
10802        let ctx = test_context();
10803        let mut req = Request::new(crate::request::Method::Get, "/");
10804
10805        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10806        let response = Response::ok();
10807        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10808
10809        let cookie_value = header_value(&result, "set-cookie").unwrap();
10810        assert!(
10811            cookie_value.contains("Secure"),
10812            "Production cookie must have Secure flag, got: {}",
10813            cookie_value
10814        );
10815    }
10816
10817    #[test]
10818    fn csrf_no_set_cookie_on_post_response() {
10819        // Set-Cookie should only be added for safe methods, not POST
10820        let csrf = CsrfMiddleware::new();
10821        let ctx = test_context();
10822        let mut req = Request::new(crate::request::Method::Post, "/");
10823
10824        let token = "valid-token";
10825        req.headers_mut()
10826            .insert("cookie", format!("csrf_token={}", token).into_bytes());
10827        req.headers_mut()
10828            .insert("x-csrf-token", token.as_bytes().to_vec());
10829
10830        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10831        let response = Response::ok();
10832        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10833
10834        assert!(
10835            header_value(&result, "set-cookie").is_none(),
10836            "POST response should not set CSRF cookie"
10837        );
10838    }
10839
10840    #[test]
10841    fn csrf_head_method_sets_cookie() {
10842        let csrf = CsrfMiddleware::new();
10843        let ctx = test_context();
10844        let mut req = Request::new(crate::request::Method::Head, "/");
10845
10846        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10847        let response = Response::ok();
10848        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10849
10850        assert!(
10851            header_value(&result, "set-cookie").is_some(),
10852            "HEAD response should set CSRF cookie"
10853        );
10854    }
10855
10856    #[test]
10857    fn csrf_options_method_sets_cookie() {
10858        let csrf = CsrfMiddleware::new();
10859        let ctx = test_context();
10860        let mut req = Request::new(crate::request::Method::Options, "/");
10861
10862        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10863        let response = Response::ok();
10864        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10865
10866        assert!(
10867            header_value(&result, "set-cookie").is_some(),
10868            "OPTIONS response should set CSRF cookie"
10869        );
10870    }
10871
10872    #[test]
10873    fn csrf_rotation_produces_different_token_in_cookie() {
10874        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
10875        let ctx = test_context();
10876        let mut req = Request::new(crate::request::Method::Get, "/");
10877
10878        let old_token = "old-token-value";
10879        req.headers_mut()
10880            .insert("cookie", format!("csrf_token={}", old_token).into_bytes());
10881
10882        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10883        let response = Response::ok();
10884        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10885
10886        let cookie_value = header_value(&result, "set-cookie").unwrap();
10887        // When rotation is enabled, old token is reused from cookie parse, but
10888        // the cookie IS set (which the before phase stored in extension).
10889        // The existing token from cookie is used, so cookie_value will contain old_token.
10890        // This verifies the Set-Cookie is emitted even with an existing cookie.
10891        assert!(cookie_value.starts_with("csrf_token="));
10892    }
10893
10894    #[test]
10895    fn csrf_no_rotation_skips_set_cookie_when_present() {
10896        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(false));
10897        let ctx = test_context();
10898        let mut req = Request::new(crate::request::Method::Get, "/");
10899
10900        req.headers_mut()
10901            .insert("cookie", b"csrf_token=existing".to_vec());
10902
10903        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10904        let response = Response::ok();
10905        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10906
10907        assert!(
10908            header_value(&result, "set-cookie").is_none(),
10909            "Without rotation, should not re-set existing cookie"
10910        );
10911    }
10912
10913    #[test]
10914    fn csrf_custom_cookie_name_in_set_cookie_response() {
10915        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().cookie_name("XSRF-TOKEN"));
10916        let ctx = test_context();
10917        let mut req = Request::new(crate::request::Method::Get, "/");
10918
10919        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10920        let response = Response::ok();
10921        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10922
10923        let cookie_value = header_value(&result, "set-cookie").unwrap();
10924        assert!(
10925            cookie_value.starts_with("XSRF-TOKEN="),
10926            "Custom cookie name should appear in Set-Cookie, got: {}",
10927            cookie_value
10928        );
10929    }
10930
10931    #[test]
10932    fn csrf_custom_header_name_validated() {
10933        let csrf = CsrfMiddleware::with_config(
10934            CsrfConfig::new()
10935                .header_name("X-Custom-CSRF")
10936                .cookie_name("my_csrf"),
10937        );
10938        let ctx = test_context();
10939        let mut req = Request::new(crate::request::Method::Post, "/");
10940
10941        let token = "custom-tok";
10942        req.headers_mut()
10943            .insert("cookie", format!("my_csrf={}", token).into_bytes());
10944        req.headers_mut()
10945            .insert("x-custom-csrf", token.as_bytes().to_vec());
10946
10947        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10948        assert!(result.is_continue());
10949    }
10950
10951    #[test]
10952    fn csrf_custom_header_name_wrong_header_rejected() {
10953        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().header_name("X-Custom-CSRF"));
10954        let ctx = test_context();
10955        let mut req = Request::new(crate::request::Method::Post, "/");
10956
10957        let token = "some-token";
10958        req.headers_mut()
10959            .insert("cookie", format!("csrf_token={}", token).into_bytes());
10960        // Using default header name instead of custom one
10961        req.headers_mut()
10962            .insert("x-csrf-token", token.as_bytes().to_vec());
10963
10964        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10965        assert!(result.is_break(), "Wrong header name should be rejected");
10966    }
10967
10968    #[test]
10969    fn csrf_cookie_parsing_multiple_cookies_picks_correct() {
10970        let csrf = CsrfMiddleware::new();
10971        let ctx = test_context();
10972        let mut req = Request::new(crate::request::Method::Post, "/");
10973
10974        let token = "correct-csrf";
10975        req.headers_mut().insert(
10976            "cookie",
10977            format!("session=abc; other=xyz; csrf_token={}; tracking=123", token).into_bytes(),
10978        );
10979        req.headers_mut()
10980            .insert("x-csrf-token", token.as_bytes().to_vec());
10981
10982        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10983        assert!(result.is_continue());
10984    }
10985
10986    #[test]
10987    fn csrf_cookie_parsing_spaces_around_semicolons() {
10988        let csrf = CsrfMiddleware::new();
10989        let ctx = test_context();
10990        let mut req = Request::new(crate::request::Method::Post, "/");
10991
10992        let token = "spaced-token";
10993        req.headers_mut().insert(
10994            "cookie",
10995            format!("session=abc ;  csrf_token={}  ; other=xyz", token).into_bytes(),
10996        );
10997        req.headers_mut()
10998            .insert("x-csrf-token", token.as_bytes().to_vec());
10999
11000        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11001        assert!(result.is_continue());
11002    }
11003
11004    #[test]
11005    fn csrf_error_response_status_is_403() {
11006        let csrf = CsrfMiddleware::new();
11007        let ctx = test_context();
11008
11009        // Test all state-changing methods return 403
11010        for method in [
11011            crate::request::Method::Post,
11012            crate::request::Method::Put,
11013            crate::request::Method::Delete,
11014            crate::request::Method::Patch,
11015        ] {
11016            let mut req = Request::new(method, "/");
11017            let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11018            match result {
11019                ControlFlow::Break(response) => {
11020                    assert_eq!(
11021                        response.status(),
11022                        StatusCode::FORBIDDEN,
11023                        "Expected 403 for {:?}",
11024                        method
11025                    );
11026                }
11027                ControlFlow::Continue => panic!("Expected Break for {:?}", method),
11028            }
11029        }
11030    }
11031
11032    #[test]
11033    fn csrf_error_body_json_structure() {
11034        let csrf = CsrfMiddleware::new();
11035        let ctx = test_context();
11036        let mut req = Request::new(crate::request::Method::Post, "/");
11037
11038        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11039        if let ControlFlow::Break(response) = result {
11040            if let ResponseBody::Bytes(body) = response.body_ref() {
11041                let body_str = std::str::from_utf8(body).unwrap();
11042                // Verify JSON structure
11043                let parsed: serde_json::Value = serde_json::from_str(body_str)
11044                    .unwrap_or_else(|e| panic!("Invalid JSON: {}: {}", body_str, e));
11045                assert!(parsed["detail"].is_array());
11046                let detail = &parsed["detail"][0];
11047                assert_eq!(detail["type"], "csrf_error");
11048                assert!(detail["loc"].is_array());
11049                assert_eq!(detail["loc"][0], "header");
11050                assert_eq!(detail["loc"][1], "x-csrf-token");
11051                assert!(detail["msg"].is_string());
11052            } else {
11053                panic!("Expected Bytes body");
11054            }
11055        } else {
11056            panic!("Expected Break");
11057        }
11058    }
11059
11060    #[test]
11061    fn csrf_default_trait() {
11062        let csrf = CsrfMiddleware::default();
11063        assert_eq!(csrf.name(), "CSRF");
11064        // Should behave identically to new()
11065        let ctx = test_context();
11066        let mut req = Request::new(crate::request::Method::Get, "/");
11067        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11068        assert!(result.is_continue());
11069    }
11070
11071    #[test]
11072    fn csrf_mode_default_is_double_submit() {
11073        assert_eq!(CsrfMode::default(), CsrfMode::DoubleSubmit);
11074    }
11075
11076    #[test]
11077    fn csrf_double_submit_both_present_same_non_empty_passes() {
11078        // Explicit test of the core double-submit pattern
11079        let csrf = CsrfMiddleware::new();
11080        let ctx = test_context();
11081
11082        let token = "a1b2c3d4e5f6";
11083        let mut req = Request::new(crate::request::Method::Delete, "/resource/1");
11084        req.headers_mut()
11085            .insert("cookie", format!("csrf_token={}", token).into_bytes());
11086        req.headers_mut()
11087            .insert("x-csrf-token", token.as_bytes().to_vec());
11088
11089        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11090        assert!(result.is_continue());
11091    }
11092
11093    #[test]
11094    fn csrf_double_submit_case_sensitive() {
11095        // Token comparison should be case-sensitive
11096        let csrf = CsrfMiddleware::new();
11097        let ctx = test_context();
11098        let mut req = Request::new(crate::request::Method::Post, "/");
11099
11100        req.headers_mut()
11101            .insert("cookie", b"csrf_token=AbCdEf".to_vec());
11102        req.headers_mut().insert("x-csrf-token", b"abcdef".to_vec());
11103
11104        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11105        assert!(
11106            result.is_break(),
11107            "Token comparison should be case-sensitive"
11108        );
11109    }
11110
11111    #[test]
11112    fn csrf_token_cookie_extractor_reads_csrf_cookie() {
11113        // Test that CsrfTokenCookie works as a cookie name marker
11114        use crate::extract::{CookieName, CsrfTokenCookie};
11115        assert_eq!(CsrfTokenCookie::NAME, "csrf_token");
11116    }
11117
11118    #[test]
11119    fn csrf_make_set_cookie_header_value_production() {
11120        let value = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", true);
11121        let s = std::str::from_utf8(&value).unwrap();
11122        assert!(s.contains("csrf_token=tok123"));
11123        assert!(s.contains("Path=/"));
11124        assert!(s.contains("SameSite=Strict"));
11125        assert!(s.contains("Secure"));
11126        assert!(!s.to_lowercase().contains("httponly"));
11127    }
11128
11129    #[test]
11130    fn csrf_make_set_cookie_header_value_development() {
11131        let value = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", false);
11132        let s = std::str::from_utf8(&value).unwrap();
11133        assert!(s.contains("csrf_token=tok123"));
11134        assert!(s.contains("Path=/"));
11135        assert!(s.contains("SameSite=Strict"));
11136        assert!(!s.contains("Secure"));
11137    }
11138
11139    #[test]
11140    fn csrf_before_after_full_cycle_get_then_post() {
11141        // Simulate a full CSRF flow: GET sets cookie, POST uses it
11142        let csrf = CsrfMiddleware::new();
11143        let ctx = test_context();
11144
11145        // Step 1: GET request - generates token and sets cookie
11146        let mut get_req = Request::new(crate::request::Method::Get, "/form");
11147        let _ = futures_executor::block_on(csrf.before(&ctx, &mut get_req));
11148        let get_response = Response::ok();
11149        let get_result = futures_executor::block_on(csrf.after(&ctx, &get_req, get_response));
11150
11151        let set_cookie = header_value(&get_result, "set-cookie").expect("GET should set cookie");
11152        // Extract token value from "csrf_token=<value>; Path=/; ..."
11153        let token_value = set_cookie
11154            .strip_prefix("csrf_token=")
11155            .unwrap()
11156            .split(';')
11157            .next()
11158            .unwrap();
11159        assert!(!token_value.is_empty());
11160
11161        // Step 2: POST request - uses the token from cookie + header
11162        let mut post_req = Request::new(crate::request::Method::Post, "/form");
11163        post_req
11164            .headers_mut()
11165            .insert("cookie", format!("csrf_token={}", token_value).into_bytes());
11166        post_req
11167            .headers_mut()
11168            .insert("x-csrf-token", token_value.as_bytes().to_vec());
11169
11170        let result = futures_executor::block_on(csrf.before(&ctx, &mut post_req));
11171        assert!(result.is_continue(), "POST with valid token should pass");
11172    }
11173
11174    #[test]
11175    fn csrf_all_state_changing_methods_require_token() {
11176        let csrf = CsrfMiddleware::new();
11177        let ctx = test_context();
11178
11179        for method in [
11180            crate::request::Method::Post,
11181            crate::request::Method::Put,
11182            crate::request::Method::Delete,
11183            crate::request::Method::Patch,
11184        ] {
11185            let mut req = Request::new(method, "/resource");
11186            let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11187            assert!(
11188                result.is_break(),
11189                "{:?} without token should be rejected",
11190                method
11191            );
11192        }
11193    }
11194
11195    #[test]
11196    fn csrf_all_safe_methods_pass_without_token() {
11197        let csrf = CsrfMiddleware::new();
11198        let ctx = test_context();
11199
11200        for method in [
11201            crate::request::Method::Get,
11202            crate::request::Method::Head,
11203            crate::request::Method::Options,
11204            crate::request::Method::Trace,
11205        ] {
11206            let mut req = Request::new(method, "/resource");
11207            let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11208            assert!(
11209                result.is_continue(),
11210                "{:?} should be allowed without token",
11211                method
11212            );
11213        }
11214    }
11215
11216    // =========================================================================
11217    // Middleware Stack Ordering Tests (Onion Model)
11218    // =========================================================================
11219
11220    /// Middleware that records execution order to a shared Vec.
11221    /// Used to verify the onion model (before in order, after in reverse).
11222    #[derive(Clone)]
11223    struct OrderRecordingMiddleware {
11224        id: &'static str,
11225        log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11226    }
11227
11228    impl OrderRecordingMiddleware {
11229        fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11230            Self { id, log }
11231        }
11232    }
11233
11234    impl Middleware for OrderRecordingMiddleware {
11235        fn before<'a>(
11236            &'a self,
11237            _ctx: &'a RequestContext,
11238            _req: &'a mut Request,
11239        ) -> BoxFuture<'a, ControlFlow> {
11240            let id = self.id;
11241            let log = self.log.clone();
11242            Box::pin(async move {
11243                log.lock().unwrap().push(format!("{id}:before"));
11244                ControlFlow::Continue
11245            })
11246        }
11247
11248        fn after<'a>(
11249            &'a self,
11250            _ctx: &'a RequestContext,
11251            _req: &'a Request,
11252            response: Response,
11253        ) -> BoxFuture<'a, Response> {
11254            let id = self.id;
11255            let log = self.log.clone();
11256            Box::pin(async move {
11257                log.lock().unwrap().push(format!("{id}:after"));
11258                response
11259            })
11260        }
11261
11262        fn name(&self) -> &'static str {
11263            "OrderRecording"
11264        }
11265    }
11266
11267    /// Middleware that short-circuits in its before hook.
11268    struct ShortCircuitMiddleware {
11269        id: &'static str,
11270        log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11271    }
11272
11273    impl ShortCircuitMiddleware {
11274        fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11275            Self { id, log }
11276        }
11277    }
11278
11279    impl Middleware for ShortCircuitMiddleware {
11280        fn before<'a>(
11281            &'a self,
11282            _ctx: &'a RequestContext,
11283            _req: &'a mut Request,
11284        ) -> BoxFuture<'a, ControlFlow> {
11285            let id = self.id;
11286            let log = self.log.clone();
11287            Box::pin(async move {
11288                log.lock().unwrap().push(format!("{id}:before:break"));
11289                ControlFlow::Break(
11290                    Response::with_status(StatusCode::FORBIDDEN)
11291                        .body(ResponseBody::Bytes(b"short-circuited".to_vec())),
11292                )
11293            })
11294        }
11295
11296        fn after<'a>(
11297            &'a self,
11298            _ctx: &'a RequestContext,
11299            _req: &'a Request,
11300            response: Response,
11301        ) -> BoxFuture<'a, Response> {
11302            let id = self.id;
11303            let log = self.log.clone();
11304            Box::pin(async move {
11305                log.lock().unwrap().push(format!("{id}:after"));
11306                response
11307            })
11308        }
11309
11310        fn name(&self) -> &'static str {
11311            "ShortCircuit"
11312        }
11313    }
11314
11315    /// Simple handler that records when it runs.
11316    struct RecordingHandler {
11317        log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11318    }
11319
11320    impl RecordingHandler {
11321        fn new(log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11322            Self { log }
11323        }
11324    }
11325
11326    impl Handler for RecordingHandler {
11327        fn call<'a>(
11328            &'a self,
11329            _ctx: &'a RequestContext,
11330            _req: &'a mut Request,
11331        ) -> BoxFuture<'a, Response> {
11332            let log = self.log.clone();
11333            Box::pin(async move {
11334                log.lock().unwrap().push("handler".to_string());
11335                Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()))
11336            })
11337        }
11338    }
11339
11340    #[test]
11341    fn middleware_stack_three_middleware_onion_order() {
11342        // Test that three middleware follow the onion model:
11343        // Before hooks run in order: 1 -> 2 -> 3
11344        // After hooks run in reverse: 3 -> 2 -> 1
11345        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11346
11347        let mut stack = MiddlewareStack::new();
11348        stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11349        stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11350        stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
11351
11352        let handler = RecordingHandler::new(log.clone());
11353        let ctx = test_context();
11354        let mut req = Request::new(crate::request::Method::Get, "/");
11355
11356        let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11357
11358        let execution_log = log.lock().unwrap().clone();
11359        assert_eq!(
11360            execution_log,
11361            vec![
11362                "mw1:before",
11363                "mw2:before",
11364                "mw3:before",
11365                "handler",
11366                "mw3:after",
11367                "mw2:after",
11368                "mw1:after",
11369            ]
11370        );
11371    }
11372
11373    #[test]
11374    fn middleware_stack_short_circuit_runs_prior_after_hooks() {
11375        // When middleware 2 short-circuits:
11376        // - mw1:before runs (returns Continue, count=1)
11377        // - mw2:before short-circuits (returns Break, count stays at 1)
11378        // - mw3:before does NOT run
11379        // - handler does NOT run
11380        // - Only middleware that successfully completed before (mw1) have after run
11381        // - mw1:after runs
11382        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11383
11384        let mut stack = MiddlewareStack::new();
11385        stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11386        stack.push(ShortCircuitMiddleware::new("mw2", log.clone()));
11387        stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
11388
11389        let handler = RecordingHandler::new(log.clone());
11390        let ctx = test_context();
11391        let mut req = Request::new(crate::request::Method::Get, "/");
11392
11393        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11394
11395        // Should return the short-circuit response
11396        assert_eq!(response.status().as_u16(), 403);
11397
11398        let execution_log = log.lock().unwrap().clone();
11399        // Note: mw2's after hook does NOT run because it didn't return Continue
11400        // Only middleware that successfully completed before (returned Continue) have after run
11401        assert_eq!(
11402            execution_log,
11403            vec!["mw1:before", "mw2:before:break", "mw1:after",]
11404        );
11405    }
11406
11407    #[test]
11408    fn middleware_stack_first_middleware_short_circuits() {
11409        // When the first middleware short-circuits:
11410        // - mw1:before short-circuits (returns Break, count=0)
11411        // - No after hooks run (count=0)
11412        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11413
11414        let mut stack = MiddlewareStack::new();
11415        stack.push(ShortCircuitMiddleware::new("mw1", log.clone()));
11416        stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11417
11418        let handler = RecordingHandler::new(log.clone());
11419        let ctx = test_context();
11420        let mut req = Request::new(crate::request::Method::Get, "/");
11421
11422        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11423        assert_eq!(response.status().as_u16(), 403);
11424
11425        let execution_log = log.lock().unwrap().clone();
11426        // No after hooks run because no middleware returned Continue
11427        assert_eq!(execution_log, vec!["mw1:before:break",]);
11428    }
11429
11430    #[test]
11431    fn middleware_stack_empty_runs_handler_only() {
11432        // Empty stack should just run the handler (onion ordering variant)
11433        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11434
11435        let stack = MiddlewareStack::new();
11436        let handler = RecordingHandler::new(log.clone());
11437        let ctx = test_context();
11438        let mut req = Request::new(crate::request::Method::Get, "/");
11439
11440        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11441        assert_eq!(response.status().as_u16(), 200);
11442
11443        let execution_log = log.lock().unwrap().clone();
11444        assert_eq!(execution_log, vec!["handler"]);
11445    }
11446
11447    #[test]
11448    fn middleware_stack_single_middleware_ordering() {
11449        // Single middleware should have before -> handler -> after
11450        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11451
11452        let mut stack = MiddlewareStack::new();
11453        stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11454
11455        let handler = RecordingHandler::new(log.clone());
11456        let ctx = test_context();
11457        let mut req = Request::new(crate::request::Method::Get, "/");
11458
11459        let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11460
11461        let execution_log = log.lock().unwrap().clone();
11462        assert_eq!(execution_log, vec!["mw1:before", "handler", "mw1:after",]);
11463    }
11464
11465    #[test]
11466    fn middleware_stack_five_middleware_onion_order() {
11467        // Test with five middleware for a longer chain
11468        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11469
11470        let mut stack = MiddlewareStack::new();
11471        stack.push(OrderRecordingMiddleware::new("a", log.clone()));
11472        stack.push(OrderRecordingMiddleware::new("b", log.clone()));
11473        stack.push(OrderRecordingMiddleware::new("c", log.clone()));
11474        stack.push(OrderRecordingMiddleware::new("d", log.clone()));
11475        stack.push(OrderRecordingMiddleware::new("e", log.clone()));
11476
11477        let handler = RecordingHandler::new(log.clone());
11478        let ctx = test_context();
11479        let mut req = Request::new(crate::request::Method::Get, "/");
11480
11481        let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11482
11483        let execution_log = log.lock().unwrap().clone();
11484        assert_eq!(
11485            execution_log,
11486            vec![
11487                "a:before", "b:before", "c:before", "d:before", "e:before", "handler", "e:after",
11488                "d:after", "c:after", "b:after", "a:after",
11489            ]
11490        );
11491    }
11492
11493    #[test]
11494    fn middleware_stack_short_circuit_at_end_runs_prior_afters() {
11495        // When the last middleware short-circuits:
11496        // - mw1:before runs (Continue, count=1)
11497        // - mw2:before runs (Continue, count=2)
11498        // - mw3:before short-circuits (Break, count stays at 2)
11499        // - handler does NOT run
11500        // - After hooks run for mw1 and mw2 only (they returned Continue)
11501        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11502
11503        let mut stack = MiddlewareStack::new();
11504        stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11505        stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11506        stack.push(ShortCircuitMiddleware::new("mw3", log.clone()));
11507
11508        let handler = RecordingHandler::new(log.clone());
11509        let ctx = test_context();
11510        let mut req = Request::new(crate::request::Method::Get, "/");
11511
11512        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11513        assert_eq!(response.status().as_u16(), 403);
11514
11515        let execution_log = log.lock().unwrap().clone();
11516        // mw3's after hook does NOT run because it didn't return Continue
11517        assert_eq!(
11518            execution_log,
11519            vec![
11520                "mw1:before",
11521                "mw2:before",
11522                "mw3:before:break",
11523                "mw2:after",
11524                "mw1:after",
11525            ]
11526        );
11527    }
11528
11529    /// Middleware that modifies the request in before and response in after.
11530    struct ModifyingMiddleware {
11531        id: &'static str,
11532        log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11533    }
11534
11535    impl ModifyingMiddleware {
11536        fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11537            Self { id, log }
11538        }
11539    }
11540
11541    impl Middleware for ModifyingMiddleware {
11542        fn before<'a>(
11543            &'a self,
11544            _ctx: &'a RequestContext,
11545            req: &'a mut Request,
11546        ) -> BoxFuture<'a, ControlFlow> {
11547            let id = self.id;
11548            let log = self.log.clone();
11549            Box::pin(async move {
11550                // Add a header to track middleware order
11551                req.headers_mut()
11552                    .insert(format!("x-{id}-before"), b"true".to_vec());
11553                log.lock().unwrap().push(format!("{id}:before"));
11554                ControlFlow::Continue
11555            })
11556        }
11557
11558        fn after<'a>(
11559            &'a self,
11560            _ctx: &'a RequestContext,
11561            _req: &'a Request,
11562            response: Response,
11563        ) -> BoxFuture<'a, Response> {
11564            let id = self.id;
11565            let log = self.log.clone();
11566            Box::pin(async move {
11567                log.lock().unwrap().push(format!("{id}:after"));
11568                // Add a header to the response
11569                response.header(format!("x-{id}-after"), b"true".to_vec())
11570            })
11571        }
11572
11573        fn name(&self) -> &'static str {
11574            "Modifying"
11575        }
11576    }
11577
11578    #[test]
11579    fn middleware_stack_modifications_accumulate_correctly() {
11580        // Test that request modifications in before hooks accumulate,
11581        // and response modifications in after hooks accumulate
11582        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11583
11584        let mut stack = MiddlewareStack::new();
11585        stack.push(ModifyingMiddleware::new("mw1", log.clone()));
11586        stack.push(ModifyingMiddleware::new("mw2", log.clone()));
11587        stack.push(ModifyingMiddleware::new("mw3", log.clone()));
11588
11589        let handler = RecordingHandler::new(log.clone());
11590        let ctx = test_context();
11591        let mut req = Request::new(crate::request::Method::Get, "/");
11592
11593        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11594
11595        // Check that all after hooks added their headers
11596        assert!(header_value(&response, "x-mw1-after").is_some());
11597        assert!(header_value(&response, "x-mw2-after").is_some());
11598        assert!(header_value(&response, "x-mw3-after").is_some());
11599
11600        // Check that the request was modified by all before hooks
11601        assert!(req.headers().contains("x-mw1-before"));
11602        assert!(req.headers().contains("x-mw2-before"));
11603        assert!(req.headers().contains("x-mw3-before"));
11604    }
11605
11606    #[test]
11607    fn layer_wrap_maintains_middleware_order() {
11608        // Test that Layer::wrap creates a Layered handler that maintains before->after ordering
11609        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11610
11611        // Create a layer with our recording middleware
11612        let layer = Layer::new(OrderRecordingMiddleware::new("layer", log.clone()));
11613
11614        // Wrap the recording handler
11615        let handler = RecordingHandler::new(log.clone());
11616        let layered_handler = layer.wrap(handler);
11617
11618        let ctx = test_context();
11619        let mut req = Request::new(crate::request::Method::Get, "/");
11620
11621        // Execute the layered handler directly (not via middleware stack)
11622        let _response = futures_executor::block_on(layered_handler.call(&ctx, &mut req));
11623
11624        let execution_log = log.lock().unwrap().clone();
11625        assert_eq!(
11626            execution_log,
11627            vec!["layer:before", "handler", "layer:after",]
11628        );
11629    }
11630}
11631
11632// ============================================================================
11633// Compression Middleware Tests (requires "compression" feature)
11634// ============================================================================
11635
#[cfg(all(test, feature = "compression"))]
mod compression_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::ResponseBody;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Builds a `GET /` request whose `Accept-Encoding` header is `value`.
    fn request_with_accept_encoding(value: &[u8]) -> Request {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("accept-encoding", value.to_vec());
        req
    }

    /// Builds a `200 OK` response with the given `Content-Type` and byte body.
    fn response_with_body(content_type: &[u8], body: Vec<u8>) -> Response {
        Response::ok()
            .header("content-type", content_type.to_vec())
            .body(ResponseBody::Bytes(body))
    }

    /// Reports whether `response` carries a `Content-Encoding` header
    /// (header names are compared case-insensitively).
    fn has_content_encoding(response: &Response) -> bool {
        response
            .headers()
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
    }

    #[test]
    fn compression_config_defaults() {
        let config = CompressionConfig::default();
        assert_eq!(config.min_size, 1024);
        assert_eq!(config.level, 6);
        assert!(!config.skip_content_types.is_empty());
    }

    #[test]
    fn compression_config_builder() {
        let config = CompressionConfig::new().min_size(512).level(9);
        assert_eq!(config.min_size, 512);
        assert_eq!(config.level, 9);
    }

    #[test]
    fn compression_level_clamped() {
        let config = CompressionConfig::new().level(100);
        assert_eq!(config.level, 9);

        let config = CompressionConfig::new().level(0);
        assert_eq!(config.level, 1);
    }

    #[test]
    fn skip_content_type_exact_match() {
        let config = CompressionConfig::default();
        assert!(config.should_skip_content_type("image/jpeg"));
        assert!(config.should_skip_content_type("image/jpeg; charset=utf-8"));
        assert!(!config.should_skip_content_type("text/html"));
    }

    #[test]
    fn skip_content_type_prefix_match() {
        let config = CompressionConfig::default();
        // "video/" prefix should match any video type
        assert!(config.should_skip_content_type("video/mp4"));
        assert!(config.should_skip_content_type("video/webm"));
        assert!(config.should_skip_content_type("audio/mpeg"));
    }

    #[test]
    fn compression_skips_small_responses() {
        // Default config: a 13-byte body is below the 1024-byte `min_size`.
        let middleware = CompressionMiddleware::new();
        let ctx = test_context();
        let req = request_with_accept_encoding(b"gzip");
        let response = response_with_body(b"text/plain", b"Hello, World!".to_vec());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_content_encoding(&result),
            "Small response should not be compressed"
        );
    }

    #[test]
    fn compression_works_for_large_responses() {
        let config = CompressionConfig::new().min_size(10); // Lower threshold
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_with_accept_encoding(b"gzip");

        // Repetitive content compresses well, so the output must shrink.
        let body = "Hello, World! ".repeat(100);
        let original_size = body.len();
        let response = response_with_body(b"text/plain", body.into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        let (_, value) = result
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .expect("Large response should be compressed");
        assert_eq!(value, b"gzip");

        // A compressed response must also carry a Vary header.
        let has_vary = result
            .headers()
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("vary"));
        assert!(has_vary, "Should have Vary header");

        // Verify compressed size is smaller
        if let ResponseBody::Bytes(compressed) = result.body_ref() {
            assert!(
                compressed.len() < original_size,
                "Compressed size should be smaller"
            );
        } else {
            panic!("Expected Bytes body");
        }
    }

    #[test]
    fn compression_skips_without_accept_encoding() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();

        // Request WITHOUT Accept-Encoding
        let req = Request::new(Method::Get, "/");
        let body = "Hello, World! ".repeat(100);
        let response = response_with_body(b"text/plain", body.into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_content_encoding(&result),
            "Should not compress without Accept-Encoding"
        );
    }

    #[test]
    fn compression_skips_already_compressed_content() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_with_accept_encoding(b"gzip");

        // image/jpeg is on the default skip list of already-compressed types.
        let body = "Some image data".repeat(100);
        let response = response_with_body(b"image/jpeg", body.into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_content_encoding(&result),
            "Should not compress already-compressed content types"
        );
    }

    #[test]
    fn compression_skips_if_already_has_content_encoding() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_with_accept_encoding(b"gzip");

        // Create response that already has Content-Encoding
        let body = "Hello, World! ".repeat(100);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .header("content-encoding", b"br".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        // Should NOT double-compress
        let encodings: Vec<_> = result
            .headers()
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .collect();

        // Should still have exactly one Content-Encoding header (the original br)
        assert_eq!(encodings.len(), 1);
        assert_eq!(encodings[0].1, b"br");
    }

    #[test]
    fn accepts_gzip_parses_header_correctly() {
        // Header values that must allow gzip: a bare token, gzip inside a
        // list, gzip with a quality value, and the `*` wildcard.
        let accepting: [&[u8]; 4] = [
            b"gzip",
            b"deflate, gzip, br",
            b"gzip;q=1.0, identity;q=0.5",
            b"*",
        ];
        for value in accepting {
            let req = request_with_accept_encoding(value);
            assert!(
                CompressionMiddleware::accepts_gzip(&req),
                "expected gzip to be accepted for {:?}",
                String::from_utf8_lossy(value)
            );
        }

        // No gzip
        let req = request_with_accept_encoding(b"deflate, br");
        assert!(!CompressionMiddleware::accepts_gzip(&req));

        // No header
        let req_no_header = Request::new(Method::Get, "/");
        assert!(!CompressionMiddleware::accepts_gzip(&req_no_header));
    }

    #[test]
    fn compression_middleware_name() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(middleware.name(), "Compression");
    }
}
11893
11894// ============================================================================
11895// Request Inspection Middleware Tests
11896// ============================================================================
11897
11898#[cfg(test)]
11899mod request_inspection_tests {
11900    use super::*;
11901    use crate::request::Method;
11902    use crate::response::ResponseBody;
11903
11904    fn test_context() -> RequestContext {
11905        RequestContext::new(asupersync::Cx::for_testing(), 1)
11906    }
11907
11908    #[test]
11909    fn inspection_middleware_default_creates_normal_verbosity() {
11910        let mw = RequestInspectionMiddleware::new();
11911        assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
11912        assert_eq!(mw.slow_threshold_ms, 1000);
11913        assert_eq!(mw.max_body_preview, 2048);
11914        assert_eq!(mw.name(), "RequestInspection");
11915    }
11916
11917    #[test]
11918    fn inspection_middleware_builder_methods() {
11919        let mw = RequestInspectionMiddleware::new()
11920            .verbosity(InspectionVerbosity::Verbose)
11921            .slow_threshold_ms(500)
11922            .max_body_preview(4096)
11923            .log_config(LogConfig::development())
11924            .redact_header("x-api-key");
11925
11926        assert_eq!(mw.verbosity, InspectionVerbosity::Verbose);
11927        assert_eq!(mw.slow_threshold_ms, 500);
11928        assert_eq!(mw.max_body_preview, 4096);
11929        assert!(mw.redact_headers.contains("x-api-key"));
11930        // Default redacted headers should still be present
11931        assert!(mw.redact_headers.contains("authorization"));
11932        assert!(mw.redact_headers.contains("cookie"));
11933    }
11934
11935    #[test]
11936    fn inspection_before_continues_processing() {
11937        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
11938        let ctx = test_context();
11939        let mut req = Request::new(Method::Post, "/api/users");
11940
11941        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
11942        assert!(result.is_continue());
11943    }
11944
11945    #[test]
11946    fn inspection_after_returns_response_unchanged() {
11947        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
11948        let ctx = test_context();
11949        let mut req = Request::new(Method::Get, "/health");
11950
11951        // Run before to set the InspectionStart extension
11952        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
11953
11954        let response = Response::ok().body(ResponseBody::Bytes(b"OK".to_vec()));
11955
11956        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
11957        assert_eq!(result.status().as_u16(), 200);
11958        assert_eq!(result.body_ref().len(), 2);
11959    }
11960
11961    #[test]
11962    fn inspection_stores_start_extension() {
11963        let mw = RequestInspectionMiddleware::new();
11964        let ctx = test_context();
11965        let mut req = Request::new(Method::Get, "/");
11966
11967        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
11968
11969        // Verify the InspectionStart extension was set
11970        assert!(req.get_extension::<InspectionStart>().is_some());
11971    }
11972
11973    #[test]
11974    fn inspection_all_verbosity_levels_continue() {
11975        for verbosity in [
11976            InspectionVerbosity::Minimal,
11977            InspectionVerbosity::Normal,
11978            InspectionVerbosity::Verbose,
11979        ] {
11980            let mw = RequestInspectionMiddleware::new().verbosity(verbosity);
11981            let ctx = test_context();
11982            let mut req = Request::new(Method::Get, "/test");
11983            req.headers_mut()
11984                .insert("content-type", b"text/plain".to_vec());
11985
11986            let result = futures_executor::block_on(mw.before(&ctx, &mut req));
11987            assert!(
11988                result.is_continue(),
11989                "Verbosity {verbosity:?} should continue"
11990            );
11991        }
11992    }
11993
11994    #[test]
11995    fn inspection_verbose_with_json_body() {
11996        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
11997        let ctx = test_context();
11998        let body = br#"{"name":"Alice","age":30}"#;
11999        let mut req = Request::new(Method::Post, "/api/users");
12000        req.headers_mut()
12001            .insert("content-type", b"application/json".to_vec());
12002        req.set_body(Body::Bytes(body.to_vec()));
12003
12004        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12005        assert!(result.is_continue());
12006    }
12007
12008    #[test]
12009    fn inspection_verbose_after_with_json_response() {
12010        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
12011        let ctx = test_context();
12012        let mut req = Request::new(Method::Get, "/api/users/1");
12013
12014        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
12015
12016        let response = Response::ok()
12017            .header("content-type", b"application/json".to_vec())
12018            .body(ResponseBody::Bytes(br#"{"id":1,"name":"Alice"}"#.to_vec()));
12019
12020        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
12021        assert_eq!(result.status().as_u16(), 200);
12022    }
12023
12024    #[test]
12025    fn inspection_redacts_sensitive_headers() {
12026        let mw = RequestInspectionMiddleware::new();
12027
12028        // Verify default redacted headers are present
12029        assert!(mw.redact_headers.contains("authorization"));
12030        assert!(mw.redact_headers.contains("proxy-authorization"));
12031        assert!(mw.redact_headers.contains("cookie"));
12032        assert!(mw.redact_headers.contains("set-cookie"));
12033    }
12034
12035    #[test]
12036    fn inspection_format_headers_redacts() {
12037        let mw = RequestInspectionMiddleware::new().redact_header("x-secret");
12038
12039        let headers = vec![
12040            ("content-type", b"text/plain".as_slice()),
12041            ("x-secret", b"my-secret-value".as_slice()),
12042            ("x-normal", b"visible".as_slice()),
12043        ];
12044
12045        let output = mw.format_inspection_headers(headers.into_iter());
12046        assert!(output.contains("content-type: text/plain"));
12047        assert!(output.contains("x-secret: [REDACTED]"));
12048        assert!(output.contains("x-normal: visible"));
12049        assert!(!output.contains("my-secret-value"));
12050    }
12051
12052    #[test]
12053    fn inspection_format_body_preview_truncates() {
12054        let mw = RequestInspectionMiddleware::new().max_body_preview(10);
12055
12056        let body = b"Hello, World! This is a long body.";
12057        let result = mw.format_body_preview(body, None);
12058        assert!(result.is_some());
12059        let text = result.unwrap();
12060        assert!(text.ends_with("..."));
12061        assert!(text.len() <= 15); // 10 chars + "..."
12062    }
12063
12064    #[test]
12065    fn inspection_format_body_preview_empty() {
12066        let mw = RequestInspectionMiddleware::new();
12067        assert!(mw.format_body_preview(b"", None).is_none());
12068    }
12069
12070    #[test]
12071    fn inspection_format_body_preview_zero_max() {
12072        let mw = RequestInspectionMiddleware::new().max_body_preview(0);
12073        assert!(mw.format_body_preview(b"hello", None).is_none());
12074    }
12075
12076    #[test]
12077    fn inspection_format_body_preview_json_pretty() {
12078        let mw = RequestInspectionMiddleware::new();
12079        let body = br#"{"key":"value","num":42}"#;
12080        let ct = b"application/json".as_slice();
12081        let result = mw.format_body_preview(body, Some(ct));
12082        assert!(result.is_some());
12083        let text = result.unwrap();
12084        // Pretty-printed JSON should contain newlines
12085        assert!(text.contains('\n'));
12086        assert!(text.contains("\"key\": \"value\""));
12087    }
12088
12089    #[test]
12090    fn inspection_format_body_preview_non_json() {
12091        let mw = RequestInspectionMiddleware::new();
12092        let body = b"Hello, World!";
12093        let ct = b"text/plain".as_slice();
12094        let result = mw.format_body_preview(body, Some(ct));
12095        assert_eq!(result.unwrap(), "Hello, World!");
12096    }
12097
12098    #[test]
12099    fn inspection_format_body_preview_binary() {
12100        let mw = RequestInspectionMiddleware::new();
12101        let body: &[u8] = &[0xFF, 0xFE, 0xFD, 0x00];
12102        let result = mw.format_body_preview(body, None);
12103        assert!(result.is_some());
12104        assert!(result.unwrap().contains("binary"));
12105    }
12106
12107    #[test]
12108    fn try_pretty_json_valid_object() {
12109        let result = try_pretty_json(r#"{"a":"b","c":1}"#);
12110        assert!(result.is_some());
12111        let pretty = result.unwrap();
12112        assert!(pretty.contains('\n'));
12113        assert!(pretty.contains("  \"a\": \"b\""));
12114    }
12115
12116    #[test]
12117    fn try_pretty_json_valid_array() {
12118        let result = try_pretty_json(r"[1,2,3]");
12119        assert!(result.is_some());
12120        let pretty = result.unwrap();
12121        assert!(pretty.contains('\n'));
12122    }
12123
12124    #[test]
12125    fn try_pretty_json_empty_object() {
12126        let result = try_pretty_json("{}");
12127        assert!(result.is_some());
12128        assert_eq!(result.unwrap(), "{}");
12129    }
12130
12131    #[test]
12132    fn try_pretty_json_empty_array() {
12133        let result = try_pretty_json("[]");
12134        assert!(result.is_some());
12135        assert_eq!(result.unwrap(), "[]");
12136    }
12137
12138    #[test]
12139    fn try_pretty_json_not_json() {
12140        assert!(try_pretty_json("hello world").is_none());
12141        assert!(try_pretty_json("12345").is_none());
12142    }
12143
12144    #[test]
12145    fn try_pretty_json_nested() {
12146        let input = r#"{"user":{"name":"Alice","roles":["admin","user"]}}"#;
12147        let result = try_pretty_json(input);
12148        assert!(result.is_some());
12149        let pretty = result.unwrap();
12150        assert!(pretty.contains("\"user\":"));
12151        assert!(pretty.contains("\"name\": \"Alice\""));
12152        assert!(pretty.contains("\"roles\":"));
12153    }
12154
12155    #[test]
12156    fn try_pretty_json_with_escapes() {
12157        let input = r#"{"msg":"hello \"world\""}"#;
12158        let result = try_pretty_json(input);
12159        assert!(result.is_some());
12160        let pretty = result.unwrap();
12161        assert!(pretty.contains(r#"\"world\""#));
12162    }
12163
12164    #[test]
12165    fn inspection_name() {
12166        let mw = RequestInspectionMiddleware::new();
12167        assert_eq!(mw.name(), "RequestInspection");
12168    }
12169
12170    #[test]
12171    fn inspection_default_via_default_trait() {
12172        let mw = RequestInspectionMiddleware::default();
12173        assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
12174        assert_eq!(mw.slow_threshold_ms, 1000);
12175    }
12176
12177    #[test]
12178    fn inspection_with_query_string() {
12179        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
12180        let ctx = test_context();
12181        let mut req = Request::new(Method::Get, "/search");
12182        req.set_query(Some("q=rust&page=1".to_string()));
12183
12184        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12185        assert!(result.is_continue());
12186    }
12187
12188    #[test]
12189    fn inspection_response_body_stream() {
12190        let mw = RequestInspectionMiddleware::new();
12191        let result = mw.format_response_preview(&ResponseBody::Empty, None);
12192        assert!(result.is_none());
12193    }
12194}
12195
12196// ============================================================================
12197// Rate Limiting Middleware Tests
12198// ============================================================================
12199
12200#[cfg(test)]
12201mod rate_limit_tests {
12202    use super::*;
12203    use crate::request::Method;
12204    use crate::response::{ResponseBody, StatusCode};
12205    use std::time::Duration;
12206
12207    fn test_context() -> RequestContext {
12208        RequestContext::new(asupersync::Cx::for_testing(), 1)
12209    }
12210
12211    fn run_rate_limit_before(mw: &RateLimitMiddleware, req: &mut Request) -> ControlFlow {
12212        let ctx = test_context();
12213        let fut = mw.before(&ctx, req);
12214        futures_executor::block_on(fut)
12215    }
12216
12217    fn run_rate_limit_after(mw: &RateLimitMiddleware, req: &Request, resp: Response) -> Response {
12218        let ctx = test_context();
12219        let fut = mw.after(&ctx, req, resp);
12220        futures_executor::block_on(fut)
12221    }
12222
12223    #[test]
12224    fn rate_limit_default_allows_requests() {
12225        let mw = RateLimitMiddleware::new();
12226        let mut req = Request::new(Method::Get, "/api/test");
12227        req.headers_mut()
12228            .insert("x-forwarded-for", b"192.168.1.1".to_vec());
12229
12230        let result = run_rate_limit_before(&mw, &mut req);
12231        assert!(result.is_continue(), "first request should be allowed");
12232    }
12233
12234    #[test]
12235    fn rate_limit_fixed_window_blocks_after_limit() {
12236        let mw = RateLimitMiddleware::builder()
12237            .requests(3)
12238            .per(Duration::from_secs(60))
12239            .algorithm(RateLimitAlgorithm::FixedWindow)
12240            .key_extractor(IpKeyExtractor)
12241            .build();
12242
12243        for i in 0..3 {
12244            let mut req = Request::new(Method::Get, "/api/test");
12245            req.headers_mut()
12246                .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12247            let result = run_rate_limit_before(&mw, &mut req);
12248            assert!(
12249                result.is_continue(),
12250                "request {i} should be allowed within limit"
12251            );
12252        }
12253
12254        // Fourth request should be blocked
12255        let mut req = Request::new(Method::Get, "/api/test");
12256        req.headers_mut()
12257            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12258        let result = run_rate_limit_before(&mw, &mut req);
12259        assert!(result.is_break(), "fourth request should be blocked");
12260
12261        // Verify 429 status
12262        if let ControlFlow::Break(resp) = result {
12263            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
12264        }
12265    }
12266
12267    #[test]
12268    fn rate_limit_different_keys_independent() {
12269        let mw = RateLimitMiddleware::builder()
12270            .requests(2)
12271            .per(Duration::from_secs(60))
12272            .algorithm(RateLimitAlgorithm::FixedWindow)
12273            .key_extractor(IpKeyExtractor)
12274            .build();
12275
12276        // Two requests from IP A
12277        for _ in 0..2 {
12278            let mut req = Request::new(Method::Get, "/");
12279            req.headers_mut()
12280                .insert("x-forwarded-for", b"1.1.1.1".to_vec());
12281            assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12282        }
12283
12284        // IP A is now exhausted
12285        let mut req = Request::new(Method::Get, "/");
12286        req.headers_mut()
12287            .insert("x-forwarded-for", b"1.1.1.1".to_vec());
12288        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12289
12290        // IP B should still be fine
12291        let mut req = Request::new(Method::Get, "/");
12292        req.headers_mut()
12293            .insert("x-forwarded-for", b"2.2.2.2".to_vec());
12294        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12295    }
12296
12297    #[test]
12298    fn rate_limit_token_bucket_allows_burst() {
12299        let mw = RateLimitMiddleware::builder()
12300            .requests(5)
12301            .per(Duration::from_secs(60))
12302            .algorithm(RateLimitAlgorithm::TokenBucket)
12303            .key_extractor(IpKeyExtractor)
12304            .build();
12305
12306        // Should allow 5 rapid requests (full bucket)
12307        for i in 0..5 {
12308            let mut req = Request::new(Method::Get, "/");
12309            req.headers_mut()
12310                .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12311            let result = run_rate_limit_before(&mw, &mut req);
12312            assert!(result.is_continue(), "burst request {i} should be allowed");
12313        }
12314
12315        // 6th request should be blocked (bucket empty)
12316        let mut req = Request::new(Method::Get, "/");
12317        req.headers_mut()
12318            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12319        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12320    }
12321
12322    #[test]
12323    fn rate_limit_sliding_window_basic() {
12324        let mw = RateLimitMiddleware::builder()
12325            .requests(3)
12326            .per(Duration::from_secs(60))
12327            .algorithm(RateLimitAlgorithm::SlidingWindow)
12328            .key_extractor(IpKeyExtractor)
12329            .build();
12330
12331        for i in 0..3 {
12332            let mut req = Request::new(Method::Get, "/");
12333            req.headers_mut()
12334                .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12335            assert!(
12336                run_rate_limit_before(&mw, &mut req).is_continue(),
12337                "sliding window request {i} should be allowed"
12338            );
12339        }
12340
12341        // Should block once limit reached
12342        let mut req = Request::new(Method::Get, "/");
12343        req.headers_mut()
12344            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12345        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12346    }
12347
12348    #[test]
12349    fn rate_limit_header_key_extractor() {
12350        let mw = RateLimitMiddleware::builder()
12351            .requests(2)
12352            .per(Duration::from_secs(60))
12353            .algorithm(RateLimitAlgorithm::FixedWindow)
12354            .key_extractor(HeaderKeyExtractor::new("x-api-key"))
12355            .build();
12356
12357        // Two requests with same API key
12358        for _ in 0..2 {
12359            let mut req = Request::new(Method::Get, "/");
12360            req.headers_mut().insert("x-api-key", b"key-abc".to_vec());
12361            assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12362        }
12363
12364        // Same key blocked
12365        let mut req = Request::new(Method::Get, "/");
12366        req.headers_mut().insert("x-api-key", b"key-abc".to_vec());
12367        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12368
12369        // Different key still allowed
12370        let mut req = Request::new(Method::Get, "/");
12371        req.headers_mut().insert("x-api-key", b"key-xyz".to_vec());
12372        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12373    }
12374
12375    #[test]
12376    fn rate_limit_path_key_extractor() {
12377        let mw = RateLimitMiddleware::builder()
12378            .requests(1)
12379            .per(Duration::from_secs(60))
12380            .algorithm(RateLimitAlgorithm::FixedWindow)
12381            .key_extractor(PathKeyExtractor)
12382            .build();
12383
12384        let mut req = Request::new(Method::Get, "/api/a");
12385        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12386
12387        // Same path is blocked
12388        let mut req = Request::new(Method::Get, "/api/a");
12389        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12390
12391        // Different path is allowed
12392        let mut req = Request::new(Method::Get, "/api/b");
12393        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12394    }
12395
12396    #[test]
12397    fn rate_limit_no_key_skips_limiting() {
12398        let mw = RateLimitMiddleware::builder()
12399            .requests(1)
12400            .per(Duration::from_secs(60))
12401            .algorithm(RateLimitAlgorithm::FixedWindow)
12402            .key_extractor(HeaderKeyExtractor::new("x-api-key"))
12403            .build();
12404
12405        // Request without the header — no key extracted, should pass
12406        let mut req = Request::new(Method::Get, "/");
12407        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12408
12409        // Still passes even with many requests (no key = no limiting)
12410        for _ in 0..10 {
12411            let mut req = Request::new(Method::Get, "/");
12412            assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12413        }
12414    }
12415
12416    #[test]
12417    fn rate_limit_response_headers_on_success() {
12418        let mw = RateLimitMiddleware::builder()
12419            .requests(10)
12420            .per(Duration::from_secs(60))
12421            .algorithm(RateLimitAlgorithm::FixedWindow)
12422            .key_extractor(IpKeyExtractor)
12423            .build();
12424
12425        let mut req = Request::new(Method::Get, "/");
12426        req.headers_mut()
12427            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12428        let cf = run_rate_limit_before(&mw, &mut req);
12429        assert!(cf.is_continue());
12430
12431        let resp = Response::with_status(StatusCode::OK);
12432        let resp = run_rate_limit_after(&mw, &req, resp);
12433
12434        // Verify rate limit headers are present
12435        let headers = resp.headers();
12436        let has_limit = headers
12437            .iter()
12438            .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"));
12439        let has_remaining = headers
12440            .iter()
12441            .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-remaining"));
12442        let has_reset = headers
12443            .iter()
12444            .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-reset"));
12445
12446        assert!(has_limit, "should have X-RateLimit-Limit header");
12447        assert!(has_remaining, "should have X-RateLimit-Remaining header");
12448        assert!(has_reset, "should have X-RateLimit-Reset header");
12449
12450        // Check limit value
12451        let limit_val = headers
12452            .iter()
12453            .find(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"))
12454            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
12455            .unwrap();
12456        assert_eq!(limit_val, "10");
12457    }
12458
12459    #[test]
12460    fn rate_limit_429_response_has_retry_after() {
12461        let mw = RateLimitMiddleware::builder()
12462            .requests(1)
12463            .per(Duration::from_secs(60))
12464            .algorithm(RateLimitAlgorithm::FixedWindow)
12465            .key_extractor(IpKeyExtractor)
12466            .build();
12467
12468        // Consume the single allowed request
12469        let mut req = Request::new(Method::Get, "/");
12470        req.headers_mut()
12471            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12472        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12473
12474        // Second request should be blocked with 429
12475        let mut req = Request::new(Method::Get, "/");
12476        req.headers_mut()
12477            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12478        let result = run_rate_limit_before(&mw, &mut req);
12479
12480        if let ControlFlow::Break(resp) = result {
12481            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
12482
12483            // Should have Retry-After header
12484            let has_retry = resp
12485                .headers()
12486                .iter()
12487                .any(|(n, _)| n.eq_ignore_ascii_case("retry-after"));
12488            assert!(has_retry, "429 response should have Retry-After header");
12489
12490            // Should have JSON body
12491            let has_ct = resp
12492                .headers()
12493                .iter()
12494                .any(|(n, v)| n.eq_ignore_ascii_case("content-type") && v == b"application/json");
12495            assert!(has_ct, "429 response should have JSON content type");
12496        } else {
12497            panic!("expected Break(429)");
12498        }
12499    }
12500
12501    #[test]
12502    fn rate_limit_no_headers_when_disabled() {
12503        let mw = RateLimitMiddleware::builder()
12504            .requests(10)
12505            .per(Duration::from_secs(60))
12506            .algorithm(RateLimitAlgorithm::FixedWindow)
12507            .key_extractor(IpKeyExtractor)
12508            .include_headers(false)
12509            .build();
12510
12511        let mut req = Request::new(Method::Get, "/");
12512        req.headers_mut()
12513            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12514        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12515
12516        let resp = Response::with_status(StatusCode::OK);
12517        let resp = run_rate_limit_after(&mw, &req, resp);
12518
12519        let has_limit = resp
12520            .headers()
12521            .iter()
12522            .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"));
12523        assert!(
12524            !has_limit,
12525            "should NOT have rate limit headers when disabled"
12526        );
12527    }
12528
12529    #[test]
12530    fn rate_limit_custom_retry_message() {
12531        let mw = RateLimitMiddleware::builder()
12532            .requests(1)
12533            .per(Duration::from_secs(60))
12534            .algorithm(RateLimitAlgorithm::FixedWindow)
12535            .key_extractor(IpKeyExtractor)
12536            .retry_message("Slow down, partner!")
12537            .build();
12538
12539        // Exhaust limit
12540        let mut req = Request::new(Method::Get, "/");
12541        req.headers_mut()
12542            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12543        run_rate_limit_before(&mw, &mut req);
12544
12545        // Check custom message in 429 body
12546        let mut req = Request::new(Method::Get, "/");
12547        req.headers_mut()
12548            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12549        if let ControlFlow::Break(resp) = run_rate_limit_before(&mw, &mut req) {
12550            if let ResponseBody::Bytes(body) = resp.body_ref() {
12551                let body_str = std::str::from_utf8(body).unwrap();
12552                assert!(
12553                    body_str.contains("Slow down, partner!"),
12554                    "expected custom message in body, got: {body_str}"
12555                );
12556            } else {
12557                panic!("expected Bytes body");
12558            }
12559        } else {
12560            panic!("expected Break(429)");
12561        }
12562    }
12563
12564    #[test]
12565    fn rate_limit_ip_extractor_x_forwarded_for() {
12566        let extractor = IpKeyExtractor;
12567        let mut req = Request::new(Method::Get, "/");
12568        req.headers_mut()
12569            .insert("x-forwarded-for", b"1.2.3.4, 5.6.7.8".to_vec());
12570        assert_eq!(extractor.extract_key(&req), Some("1.2.3.4".to_string()));
12571    }
12572
12573    #[test]
12574    fn rate_limit_ip_extractor_x_real_ip() {
12575        let extractor = IpKeyExtractor;
12576        let mut req = Request::new(Method::Get, "/");
12577        req.headers_mut().insert("x-real-ip", b"9.8.7.6".to_vec());
12578        assert_eq!(extractor.extract_key(&req), Some("9.8.7.6".to_string()));
12579    }
12580
12581    #[test]
12582    fn rate_limit_ip_extractor_fallback() {
12583        let extractor = IpKeyExtractor;
12584        let req = Request::new(Method::Get, "/");
12585        assert_eq!(extractor.extract_key(&req), Some("unknown".to_string()));
12586    }
12587
12588    // Tests for secure ConnectedIpKeyExtractor (bd-u9gw)
12589    #[test]
12590    fn connected_ip_extractor_with_remote_addr() {
12591        use std::net::{IpAddr, Ipv4Addr};
12592
12593        let extractor = ConnectedIpKeyExtractor;
12594        let mut req = Request::new(Method::Get, "/");
12595        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))));
12596
12597        assert_eq!(
12598            extractor.extract_key(&req),
12599            Some("192.168.1.100".to_string())
12600        );
12601    }
12602
12603    #[test]
12604    fn connected_ip_extractor_without_remote_addr() {
12605        let extractor = ConnectedIpKeyExtractor;
12606        let req = Request::new(Method::Get, "/");
12607
12608        // Should return None when no RemoteAddr is set
12609        assert_eq!(extractor.extract_key(&req), None);
12610    }
12611
12612    #[test]
12613    fn connected_ip_extractor_ignores_headers() {
12614        use std::net::{IpAddr, Ipv4Addr};
12615
12616        let extractor = ConnectedIpKeyExtractor;
12617        let mut req = Request::new(Method::Get, "/");
12618        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
12619        // Add spoofed header - should be ignored
12620        req.headers_mut()
12621            .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12622
12623        // Should use RemoteAddr, not the header
12624        assert_eq!(extractor.extract_key(&req), Some("10.0.0.1".to_string()));
12625    }
12626
12627    // Tests for TrustedProxyIpKeyExtractor (bd-u9gw)
12628    #[test]
12629    fn trusted_proxy_extractor_from_trusted_proxy() {
12630        use std::net::{IpAddr, Ipv4Addr};
12631
12632        let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");
12633
12634        let mut req = Request::new(Method::Get, "/");
12635        // Request came from trusted proxy 10.0.0.1
12636        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
12637        // Proxy set X-Forwarded-For with real client IP
12638        req.headers_mut()
12639            .insert("x-forwarded-for", b"203.0.113.50".to_vec());
12640
12641        // Should trust the header and extract client IP
12642        assert_eq!(
12643            extractor.extract_key(&req),
12644            Some("203.0.113.50".to_string())
12645        );
12646    }
12647
12648    #[test]
12649    fn trusted_proxy_extractor_from_untrusted_direct() {
12650        use std::net::{IpAddr, Ipv4Addr};
12651
12652        let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");
12653
12654        let mut req = Request::new(Method::Get, "/");
12655        // Request came directly from client (not a trusted proxy)
12656        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50))));
12657        // Client tries to spoof X-Forwarded-For
12658        req.headers_mut()
12659            .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12660
12661        // Should ignore header and use RemoteAddr
12662        assert_eq!(
12663            extractor.extract_key(&req),
12664            Some("203.0.113.50".to_string())
12665        );
12666    }
12667
12668    #[test]
12669    fn trusted_proxy_extractor_no_remote_addr() {
12670        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12671
12672        let mut req = Request::new(Method::Get, "/");
12673        // No RemoteAddr set - should return None (safer than guessing)
12674        req.headers_mut()
12675            .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12676
12677        assert_eq!(extractor.extract_key(&req), None);
12678    }
12679
12680    #[test]
12681    fn trusted_proxy_extractor_loopback_ipv4() {
12682        use std::net::{IpAddr, Ipv4Addr};
12683
12684        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12685
12686        let mut req = Request::new(Method::Get, "/");
12687        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)));
12688        req.headers_mut()
12689            .insert("x-forwarded-for", b"8.8.8.8".to_vec());
12690
12691        assert_eq!(extractor.extract_key(&req), Some("8.8.8.8".to_string()));
12692    }
12693
12694    #[test]
12695    fn trusted_proxy_extractor_loopback_ipv6() {
12696        use std::net::{IpAddr, Ipv6Addr};
12697
12698        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12699
12700        let mut req = Request::new(Method::Get, "/");
12701        req.insert_extension(RemoteAddr(IpAddr::V6(Ipv6Addr::LOCALHOST)));
12702        req.headers_mut()
12703            .insert("x-forwarded-for", b"8.8.8.8".to_vec());
12704
12705        assert_eq!(extractor.extract_key(&req), Some("8.8.8.8".to_string()));
12706    }
12707
12708    #[test]
12709    fn cidr_parsing() {
12710        // Valid CIDRs
12711        assert!(parse_cidr("10.0.0.0/8").is_some());
12712        assert!(parse_cidr("192.168.1.0/24").is_some());
12713        assert!(parse_cidr("0.0.0.0/0").is_some());
12714        assert!(parse_cidr("::1/128").is_some());
12715        assert!(parse_cidr("::/0").is_some());
12716
12717        // Invalid CIDRs
12718        assert!(parse_cidr("10.0.0.0/33").is_none()); // Prefix too large for IPv4
12719        assert!(parse_cidr("invalid").is_none());
12720        assert!(parse_cidr("10.0.0.0").is_none()); // Missing prefix
12721    }
12722
12723    #[test]
12724    fn ip_in_cidr_matching() {
12725        use std::net::{IpAddr, Ipv4Addr};
12726
12727        let cidr_10 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0));
12728
12729        // In range
12730        assert!(ip_in_cidr(
12731            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
12732            cidr_10,
12733            8
12734        ));
12735        assert!(ip_in_cidr(
12736            IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255)),
12737            cidr_10,
12738            8
12739        ));
12740
12741        // Out of range
12742        assert!(!ip_in_cidr(
12743            IpAddr::V4(Ipv4Addr::new(11, 0, 0, 1)),
12744            cidr_10,
12745            8
12746        ));
12747        assert!(!ip_in_cidr(
12748            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
12749            cidr_10,
12750            8
12751        ));
12752    }
12753
12754    #[test]
12755    fn rate_limit_composite_key_extractor() {
12756        let extractor =
12757            CompositeKeyExtractor::new(vec![Box::new(IpKeyExtractor), Box::new(PathKeyExtractor)]);
12758
12759        let mut req = Request::new(Method::Get, "/api/users");
12760        req.headers_mut()
12761            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12762
12763        let key = extractor.extract_key(&req);
12764        assert_eq!(key, Some("10.0.0.1:/api/users".to_string()));
12765    }
12766
12767    #[test]
12768    fn rate_limit_builder_defaults() {
12769        let mw = RateLimitMiddleware::builder().build();
12770        assert_eq!(mw.config.max_requests, 100);
12771        assert_eq!(mw.config.window, Duration::from_secs(60));
12772        assert_eq!(mw.config.algorithm, RateLimitAlgorithm::TokenBucket);
12773        assert!(mw.config.include_headers);
12774    }
12775
12776    #[test]
12777    fn rate_limit_builder_per_minute() {
12778        let mw = RateLimitMiddleware::builder()
12779            .requests(50)
12780            .per_minute(2)
12781            .algorithm(RateLimitAlgorithm::SlidingWindow)
12782            .build();
12783        assert_eq!(mw.config.max_requests, 50);
12784        assert_eq!(mw.config.window, Duration::from_secs(120));
12785        assert_eq!(mw.config.algorithm, RateLimitAlgorithm::SlidingWindow);
12786    }
12787
12788    #[test]
12789    fn rate_limit_builder_per_hour() {
12790        let mw = RateLimitMiddleware::builder()
12791            .requests(1000)
12792            .per_hour(1)
12793            .build();
12794        assert_eq!(mw.config.window, Duration::from_secs(3600));
12795    }
12796
12797    #[test]
12798    fn rate_limit_middleware_name() {
12799        let mw = RateLimitMiddleware::new();
12800        assert_eq!(mw.name(), "RateLimit");
12801    }
12802
12803    #[test]
12804    fn rate_limit_default_via_default_trait() {
12805        let mw = RateLimitMiddleware::default();
12806        assert_eq!(mw.config.max_requests, 100);
12807    }
12808
12809    // ========================================================================
12810    // ETag Middleware Tests
12811    // ========================================================================
12812
12813    #[test]
12814    fn etag_middleware_generates_etag_for_get() {
12815        let mw = ETagMiddleware::new();
12816        let ctx = test_context();
12817        let req = Request::new(crate::request::Method::Get, "/resource");
12818
12819        // Create response with body
12820        let response = Response::ok()
12821            .header("content-type", b"application/json".to_vec())
12822            .body(ResponseBody::Bytes(br#"{"status":"ok"}"#.to_vec()));
12823
12824        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12825
12826        // Should have ETag header
12827        let etag = response
12828            .headers()
12829            .iter()
12830            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
12831        assert!(etag.is_some(), "Response should have ETag header");
12832
12833        // ETag should be a quoted hex string
12834        let etag_value = std::str::from_utf8(&etag.unwrap().1).unwrap();
12835        assert!(etag_value.starts_with('"'), "ETag should start with quote");
12836        assert!(etag_value.ends_with('"'), "ETag should end with quote");
12837    }
12838
12839    #[test]
12840    fn etag_middleware_returns_304_on_match() {
12841        let mw = ETagMiddleware::new();
12842        let ctx = test_context();
12843
12844        // First request to get the ETag
12845        let req1 = Request::new(crate::request::Method::Get, "/resource");
12846        let body = br#"{"status":"ok"}"#.to_vec();
12847        let response1 = Response::ok().body(ResponseBody::Bytes(body.clone()));
12848        let response1 = futures_executor::block_on(mw.after(&ctx, &req1, response1));
12849
12850        let etag = response1
12851            .headers()
12852            .iter()
12853            .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
12854            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
12855            .unwrap();
12856
12857        // Second request with If-None-Match header
12858        let mut req2 = Request::new(crate::request::Method::Get, "/resource");
12859        req2.headers_mut()
12860            .insert("if-none-match", etag.as_bytes().to_vec());
12861
12862        let response2 = Response::ok().body(ResponseBody::Bytes(body));
12863        let response2 = futures_executor::block_on(mw.after(&ctx, &req2, response2));
12864
12865        // Should return 304 Not Modified
12866        assert_eq!(response2.status().as_u16(), 304);
12867        assert!(response2.body_ref().is_empty());
12868    }
12869
12870    #[test]
12871    fn etag_middleware_returns_full_response_on_mismatch() {
12872        let mw = ETagMiddleware::new();
12873        let ctx = test_context();
12874
12875        let mut req = Request::new(crate::request::Method::Get, "/resource");
12876        req.headers_mut()
12877            .insert("if-none-match", b"\"old-etag\"".to_vec());
12878
12879        let body = br#"{"status":"updated"}"#.to_vec();
12880        let response = Response::ok().body(ResponseBody::Bytes(body.clone()));
12881        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12882
12883        // Should return 200 OK with body
12884        assert_eq!(response.status().as_u16(), 200);
12885        assert!(!response.body_ref().is_empty());
12886    }
12887
12888    #[test]
12889    fn etag_middleware_weak_etag_generation() {
12890        let config = ETagConfig::new().weak(true);
12891        let mw = ETagMiddleware::with_config(config);
12892        let ctx = test_context();
12893        let req = Request::new(crate::request::Method::Get, "/resource");
12894
12895        let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
12896        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12897
12898        let etag = response
12899            .headers()
12900            .iter()
12901            .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
12902            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
12903            .unwrap();
12904
12905        assert!(etag.starts_with("W/"), "Weak ETag should start with W/");
12906    }
12907
12908    #[test]
12909    fn etag_middleware_skips_post_requests() {
12910        let mw = ETagMiddleware::new();
12911        let ctx = test_context();
12912        let req = Request::new(crate::request::Method::Post, "/resource");
12913
12914        let response = Response::ok().body(ResponseBody::Bytes(b"created".to_vec()));
12915        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12916
12917        // POST should not get ETag
12918        let etag = response
12919            .headers()
12920            .iter()
12921            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
12922        assert!(etag.is_none(), "POST should not have ETag");
12923    }
12924
12925    #[test]
12926    fn etag_middleware_handles_head_requests() {
12927        let mw = ETagMiddleware::new();
12928        let ctx = test_context();
12929        let req = Request::new(crate::request::Method::Head, "/resource");
12930
12931        let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
12932        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12933
12934        // HEAD should get ETag
12935        let etag = response
12936            .headers()
12937            .iter()
12938            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
12939        assert!(etag.is_some(), "HEAD should have ETag");
12940    }
12941
12942    #[test]
12943    fn etag_middleware_disabled_mode() {
12944        let config = ETagConfig::new().mode(ETagMode::Disabled);
12945        let mw = ETagMiddleware::with_config(config);
12946        let ctx = test_context();
12947        let req = Request::new(crate::request::Method::Get, "/resource");
12948
12949        let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
12950        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12951
12952        // Should not have ETag when disabled
12953        let etag = response
12954            .headers()
12955            .iter()
12956            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
12957        assert!(etag.is_none(), "Disabled mode should not add ETag");
12958    }
12959
12960    #[test]
12961    fn etag_middleware_min_size_filter() {
12962        let config = ETagConfig::new().min_size(1000);
12963        let mw = ETagMiddleware::with_config(config);
12964        let ctx = test_context();
12965        let req = Request::new(crate::request::Method::Get, "/resource");
12966
12967        // Small body below min_size
12968        let response = Response::ok().body(ResponseBody::Bytes(b"small".to_vec()));
12969        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12970
12971        // Should not have ETag for small body
12972        let etag = response
12973            .headers()
12974            .iter()
12975            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
12976        assert!(etag.is_none(), "Small body should not get ETag");
12977    }
12978
12979    #[test]
12980    fn etag_middleware_preserves_existing_etag() {
12981        let config = ETagConfig::new().mode(ETagMode::Manual);
12982        let mw = ETagMiddleware::with_config(config);
12983        let ctx = test_context();
12984
12985        // First request to set up cached ETag
12986        let mut req = Request::new(crate::request::Method::Get, "/resource");
12987        req.headers_mut()
12988            .insert("if-none-match", b"\"custom-etag\"".to_vec());
12989
12990        // Response with pre-set ETag matching the request
12991        let response = Response::ok()
12992            .header("etag", b"\"custom-etag\"".to_vec())
12993            .body(ResponseBody::Bytes(b"data".to_vec()));
12994        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12995
12996        // Should return 304 since custom ETag matches
12997        assert_eq!(response.status().as_u16(), 304);
12998    }
12999
13000    #[test]
13001    fn etag_middleware_wildcard_if_none_match() {
13002        let mw = ETagMiddleware::new();
13003        let ctx = test_context();
13004        let mut req = Request::new(crate::request::Method::Get, "/resource");
13005        req.headers_mut().insert("if-none-match", b"*".to_vec());
13006
13007        let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13008        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13009
13010        // Wildcard should match any ETag
13011        assert_eq!(response.status().as_u16(), 304);
13012    }
13013
13014    #[test]
13015    fn etag_middleware_weak_comparison_matches() {
13016        let mw = ETagMiddleware::new();
13017        let ctx = test_context();
13018
13019        // Get the strong ETag
13020        let req1 = Request::new(crate::request::Method::Get, "/resource");
13021        let body = b"test data".to_vec();
13022        let response1 = Response::ok().body(ResponseBody::Bytes(body.clone()));
13023        let response1 = futures_executor::block_on(mw.after(&ctx, &req1, response1));
13024
13025        let etag = response1
13026            .headers()
13027            .iter()
13028            .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
13029            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
13030            .unwrap();
13031
13032        // Send request with weak version of the same ETag
13033        let mut req2 = Request::new(crate::request::Method::Get, "/resource");
13034        let weak_etag = format!("W/{}", etag);
13035        req2.headers_mut()
13036            .insert("if-none-match", weak_etag.as_bytes().to_vec());
13037
13038        let response2 = Response::ok().body(ResponseBody::Bytes(body));
13039        let response2 = futures_executor::block_on(mw.after(&ctx, &req2, response2));
13040
13041        // Weak comparison should match
13042        assert_eq!(response2.status().as_u16(), 304);
13043    }
13044
13045    #[test]
13046    fn etag_middleware_name() {
13047        let mw = ETagMiddleware::new();
13048        assert_eq!(mw.name(), "ETagMiddleware");
13049    }
13050
13051    #[test]
13052    fn etag_config_builder() {
13053        let config = ETagConfig::new()
13054            .mode(ETagMode::Auto)
13055            .weak(true)
13056            .min_size(512);
13057
13058        assert_eq!(config.mode, ETagMode::Auto);
13059        assert!(config.weak);
13060        assert_eq!(config.min_size, 512);
13061    }
13062
13063    #[test]
13064    fn etag_generates_consistent_hash() {
13065        // Same data should produce same ETag
13066        let etag1 = ETagMiddleware::generate_etag(b"hello world", false);
13067        let etag2 = ETagMiddleware::generate_etag(b"hello world", false);
13068        assert_eq!(etag1, etag2);
13069
13070        // Different data should produce different ETag
13071        let etag3 = ETagMiddleware::generate_etag(b"hello world!", false);
13072        assert_ne!(etag1, etag3);
13073    }
13074}