Skip to main content

fastapi_core/
extract.rs

1//! Request extraction traits and extractors.
2//!
3//! This module provides the [`FromRequest`] trait and common extractors
4//! like [`Json`] and [`Path`] for parsing request data.
5
6use crate::context::RequestContext;
7use crate::error::{HttpError, ValidationError, ValidationErrors};
8use crate::request::{Body, Request};
9use crate::response::{IntoResponse, Response, ResponseBody};
10use serde::de::{
11    self, DeserializeOwned, Deserializer, IntoDeserializer, MapAccess, SeqAccess, Visitor,
12};
13use std::fmt;
14use std::future::Future;
15use std::ops::{Deref, DerefMut};
16
17/// Trait for types that can be extracted from a request.
18///
19/// This is the core abstraction for request handlers. Each parameter
20/// in a handler function implements this trait.
21///
22/// The `ctx` parameter provides access to the request context, including
23/// asupersync's capability context for cancellation checkpoints and
24/// budget-aware operations.
25///
26/// # Example
27///
28/// ```ignore
29/// use fastapi_core::{FromRequest, Request, RequestContext};
30///
31/// struct MyExtractor(String);
32///
33/// impl FromRequest for MyExtractor {
34///     type Error = std::convert::Infallible;
35///
36///     async fn from_request(
37///         ctx: &RequestContext,
38///         req: &mut Request,
39///     ) -> Result<Self, Self::Error> {
40///         // Check for cancellation before expensive work
41///         let _ = ctx.checkpoint();
42///         Ok(MyExtractor("extracted".to_string()))
43///     }
44/// }
45/// ```
46pub trait FromRequest: Sized {
47    /// Error type when extraction fails.
48    type Error: IntoResponse;
49
50    /// Extract a value from the request.
51    ///
52    /// # Parameters
53    ///
54    /// - `ctx`: The request context providing access to asupersync capabilities
55    /// - `req`: The HTTP request to extract from
56    fn from_request(
57        ctx: &RequestContext,
58        req: &mut Request,
59    ) -> impl Future<Output = Result<Self, Self::Error>> + Send;
60}
61
62// Implement for Option to make extractors optional
63impl<T: FromRequest> FromRequest for Option<T> {
64    type Error = std::convert::Infallible;
65
66    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
67        Ok(T::from_request(ctx, req).await.ok())
68    }
69}
70
71// Implement for RequestContext itself - allows handlers to receive the context
72impl FromRequest for RequestContext {
73    type Error = std::convert::Infallible;
74
75    async fn from_request(ctx: &RequestContext, _req: &mut Request) -> Result<Self, Self::Error> {
76        Ok(ctx.clone())
77    }
78}
79
80// ============================================================================
81// JSON Body Extractor
82// ============================================================================
83
/// Default maximum JSON body size (1MB).
pub const DEFAULT_JSON_LIMIT: usize = 1024 * 1024;

/// Configuration for JSON request body extraction.
///
/// Holds a maximum body size and an optional Content-Type to accept for
/// the [`Json`] extractor.
///
/// # Defaults
///
/// | Setting | Default |
/// |---------|---------|
/// | `limit` | 1 MB (`DEFAULT_JSON_LIMIT`) |
/// | `content_type` | `None` (accepts any `application/json` variant) |
// NOTE(review): the `Json` extractor in this file takes its size limit from
// `RequestContext::max_body_size()`; confirm where this config is consumed.
#[derive(Debug, Clone)]
pub struct JsonConfig {
    /// Maximum body size in bytes.
    limit: usize,
    /// Content-Type header value to accept (case-insensitive).
    /// If None, accepts any application/json variant.
    content_type: Option<String>,
}

impl Default for JsonConfig {
    /// Same as [`JsonConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl JsonConfig {
    /// Create a configuration with the default limit and no Content-Type
    /// override.
    #[must_use]
    pub fn new() -> Self {
        Self {
            limit: DEFAULT_JSON_LIMIT,
            content_type: None,
        }
    }

    /// Return this configuration with the maximum body size replaced by
    /// `limit` (in bytes).
    #[must_use]
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Return this configuration with a specific Content-Type to accept.
    #[must_use]
    pub fn content_type(self, content_type: impl Into<String>) -> Self {
        Self {
            content_type: Some(content_type.into()),
            ..self
        }
    }

    /// Returns the configured size limit in bytes.
    #[must_use]
    pub fn get_limit(&self) -> usize {
        self.limit
    }
}
143
/// JSON body extractor.
///
/// As an extractor it reads the request body as JSON and deserializes it
/// into `T`. The wrapped value is public and also reachable through
/// `Deref`/`DerefMut`.
///
/// # Error Responses
///
/// - **415 Unsupported Media Type**: Content-Type is not `application/json`
/// - **413 Payload Too Large**: Body exceeds configured size limit
/// - **422 Unprocessable Entity**: JSON parsing failed
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::Json;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct CreateUser {
///     name: String,
///     email: String,
/// }
///
/// async fn create_user(Json(user): Json<CreateUser>) -> impl IntoResponse {
///     format!("Created user: {}", user.name)
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Consume the wrapper and return the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> Deref for Json<T> {
    type Target = T;

    /// Borrow the wrapped value.
    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}

impl<T> DerefMut for Json<T> {
    /// Mutably borrow the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let Self(inner) = self;
        inner
    }
}
193
194impl<T: serde::Serialize> IntoResponse for Json<T> {
195    fn into_response(self) -> Response {
196        match serde_json::to_vec(&self.0) {
197            Ok(bytes) => Response::ok()
198                .header("content-type", b"application/json".to_vec())
199                .body(ResponseBody::Bytes(bytes)),
200            Err(e) => {
201                // Serialization error - use ResponseValidationError for proper handling
202                // This ensures error details are logged but not exposed to clients
203                crate::error::ResponseValidationError::serialization_failed(e.to_string())
204                    .into_response()
205            }
206        }
207    }
208}
209
/// Error returned when JSON extraction fails.
#[derive(Debug)]
pub enum JsonExtractError {
    /// Content-Type header is missing or not application/json.
    UnsupportedMediaType {
        /// The actual Content-Type received (if any).
        actual: Option<String>,
    },
    /// Request body exceeds the size limit.
    PayloadTooLarge {
        /// The actual body size.
        size: usize,
        /// The configured limit.
        limit: usize,
    },
    /// JSON deserialization failed.
    DeserializeError {
        /// The serde_json error message.
        message: String,
        /// Line number where error occurred (if available).
        line: Option<usize>,
        /// Column number where error occurred (if available).
        column: Option<usize>,
    },
    /// Streaming request bodies are not supported.
    StreamingNotSupported,
}

impl std::fmt::Display for JsonExtractError {
    /// Write the human-readable message for this error.
    ///
    /// The parse-error position is included only when both `line` and
    /// `column` are present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected Content-Type: application/json, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Missing Content-Type header, expected application/json")
            }
            Self::PayloadTooLarge { size, limit } => write!(
                f,
                "Request body too large: {size} bytes exceeds {limit} byte limit"
            ),
            Self::DeserializeError {
                message,
                line: Some(l),
                column: Some(c),
            } => write!(f, "JSON parse error at line {l}, column {c}: {message}"),
            // Position missing or incomplete: report the message alone.
            Self::DeserializeError { message, .. } => {
                write!(f, "JSON parse error: {message}")
            }
            Self::StreamingNotSupported => {
                f.write_str("Streaming request bodies are not supported for JSON extraction")
            }
        }
    }
}

impl std::error::Error for JsonExtractError {}
276
277impl IntoResponse for JsonExtractError {
278    fn into_response(self) -> crate::response::Response {
279        match self {
280            Self::UnsupportedMediaType { actual } => {
281                let detail = if let Some(ct) = actual {
282                    format!("Expected Content-Type: application/json, got: {ct}")
283                } else {
284                    "Missing Content-Type header, expected application/json".to_string()
285                };
286                HttpError::unsupported_media_type()
287                    .with_detail(detail)
288                    .into_response()
289            }
290            Self::PayloadTooLarge { size, limit } => HttpError::payload_too_large()
291                .with_detail(format!(
292                    "Request body too large: {size} bytes exceeds {limit} byte limit"
293                ))
294                .into_response(),
295            Self::DeserializeError {
296                message,
297                line,
298                column,
299            } => {
300                // Return a 422 with validation error format
301                let msg = if let (Some(l), Some(c)) = (line, column) {
302                    format!("JSON parse error at line {l}, column {c}: {message}")
303                } else {
304                    format!("JSON parse error: {message}")
305                };
306                ValidationErrors::single(ValidationError::json_invalid(
307                    crate::error::loc::body(),
308                    msg,
309                ))
310                .into_response()
311            }
312            Self::StreamingNotSupported => HttpError::bad_request()
313                .with_detail("Streaming request bodies are not supported for JSON extraction")
314                .into_response(),
315        }
316    }
317}
318
319impl<T: DeserializeOwned> FromRequest for Json<T> {
320    type Error = JsonExtractError;
321
322    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
323        // Validate Content-Type
324        let content_type = req
325            .headers()
326            .get("content-type")
327            .and_then(|v| std::str::from_utf8(v).ok());
328
329        let is_json = content_type.is_some_and(|ct| {
330            // Case-insensitive check without allocation
331            ct.get(..16)
332                .is_some_and(|prefix| prefix.eq_ignore_ascii_case("application/json"))
333                || (ct
334                    .get(..12)
335                    .is_some_and(|p| p.eq_ignore_ascii_case("application/"))
336                    && ct
337                        .as_bytes()
338                        .windows(5)
339                        .any(|w| w.eq_ignore_ascii_case(b"+json")))
340        });
341
342        if !is_json {
343            return Err(JsonExtractError::UnsupportedMediaType {
344                actual: content_type.map(String::from),
345            });
346        }
347
348        // Check cancellation before reading the body
349        let _ = ctx.checkpoint();
350
351        // Get body bytes
352        let body = req.take_body();
353        let bytes = match body {
354            Body::Empty => Vec::new(),
355            Body::Bytes(b) => b,
356            Body::Stream(_) => {
357                // Streaming bodies not yet supported in Json extractor
358                return Err(JsonExtractError::StreamingNotSupported);
359            }
360        };
361
362        // Check size limit using the configured limit from RequestContext.
363        // This respects both app-level config (AppConfig.max_body_size) and
364        // per-route overrides when available.
365        let limit = ctx.max_body_size();
366        if bytes.len() > limit {
367            return Err(JsonExtractError::PayloadTooLarge {
368                size: bytes.len(),
369                limit,
370            });
371        }
372
373        // Check cancellation before deserialization
374        let _ = ctx.checkpoint();
375
376        // Deserialize JSON
377        // NOTE: serde_json 1.0.114+ has a default recursion limit of 128,
378        // protecting against stack overflow from deeply nested JSON.
379        let value =
380            serde_json::from_slice(&bytes).map_err(|e| JsonExtractError::DeserializeError {
381                message: e.to_string(),
382                line: Some(e.line()),
383                column: Some(e.column()),
384            })?;
385
386        // Check cancellation after parsing
387        let _ = ctx.checkpoint();
388
389        Ok(Json(value))
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Method;

    /// Build a request context backed by the asupersync testing capability.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Build a `POST /test` request with an optional Content-Type and a body.
    fn post_request(content_type: Option<&str>, body: &[u8]) -> Request {
        let mut req = Request::new(Method::Post, "/test");
        if let Some(ct) = content_type {
            req.headers_mut()
                .insert("content-type", ct.as_bytes().to_vec());
        }
        req.set_body(Body::Bytes(body.to_vec()));
        req
    }

    /// Build a request carrying `body` as `application/json`.
    fn json_request(body: &str) -> Request {
        post_request(Some("application/json"), body.as_bytes())
    }

    /// Run the `Json<T>` extractor to completion against `req`.
    fn extract<T: DeserializeOwned>(req: &mut Request) -> Result<Json<T>, JsonExtractError> {
        let ctx = test_context();
        futures_executor::block_on(Json::<T>::from_request(&ctx, req))
    }

    #[test]
    fn json_config_defaults() {
        assert_eq!(JsonConfig::default().get_limit(), DEFAULT_JSON_LIMIT);
    }

    #[test]
    fn json_config_custom() {
        assert_eq!(JsonConfig::new().limit(1024).get_limit(), 1024);
    }

    #[test]
    fn json_deref() {
        let wrapped = Json(42i32);
        assert_eq!(*wrapped, 42);
    }

    #[test]
    fn json_into_inner() {
        let wrapped = Json("hello".to_string());
        assert_eq!(wrapped.into_inner(), "hello");
    }

    #[test]
    fn json_extract_success() {
        use serde::Deserialize;

        #[derive(Deserialize, Debug, PartialEq)]
        struct TestPayload {
            name: String,
            value: i32,
        }

        let mut req = json_request(r#"{"name": "test", "value": 42}"#);

        let Json(payload) = extract::<TestPayload>(&mut req).unwrap();
        assert_eq!(payload.name, "test");
        assert_eq!(payload.value, 42);
    }

    #[test]
    fn json_extract_wrong_content_type() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        struct TestPayload {
            #[allow(dead_code)]
            name: String,
        }

        let mut req = post_request(Some("text/plain"), b"{}");

        let result = extract::<TestPayload>(&mut req);
        assert!(matches!(
            result,
            Err(JsonExtractError::UnsupportedMediaType { actual: Some(ct) })
            if ct == "text/plain"
        ));
    }

    #[test]
    fn json_extract_missing_content_type() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        struct TestPayload {
            #[allow(dead_code)]
            name: String,
        }

        let mut req = post_request(None, b"{}");

        let result = extract::<TestPayload>(&mut req);
        assert!(matches!(
            result,
            Err(JsonExtractError::UnsupportedMediaType { actual: None })
        ));
    }

    #[test]
    fn json_extract_invalid_json() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        struct TestPayload {
            #[allow(dead_code)]
            name: String,
        }

        let mut req = json_request(r#"{"name": invalid}"#);

        let result = extract::<TestPayload>(&mut req);
        assert!(matches!(
            result,
            Err(JsonExtractError::DeserializeError { .. })
        ));
    }

    #[test]
    fn json_extract_application_json_charset() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct TestPayload {
            value: i32,
        }

        let mut req = post_request(
            Some("application/json; charset=utf-8"),
            b"{\"value\": 123}",
        );

        let Json(payload) = extract::<TestPayload>(&mut req).unwrap();
        assert_eq!(payload.value, 123);
    }

    #[test]
    fn json_extract_vendor_json() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct TestPayload {
            value: i32,
        }

        // Vendor media types like application/vnd.api+json should work
        let mut req = post_request(Some("application/vnd.api+json"), b"{\"value\": 456}");

        let Json(payload) = extract::<TestPayload>(&mut req).unwrap();
        assert_eq!(payload.value, 456);
    }

    #[test]
    fn json_error_display() {
        let unsupported = JsonExtractError::UnsupportedMediaType {
            actual: Some("text/html".to_string()),
        };
        assert!(unsupported.to_string().contains("text/html"));

        let too_large = JsonExtractError::PayloadTooLarge {
            size: 2000,
            limit: 1000,
        }
        .to_string();
        assert!(too_large.contains("2000"));
        assert!(too_large.contains("1000"));

        let parse = JsonExtractError::DeserializeError {
            message: "unexpected token".to_string(),
            line: Some(1),
            column: Some(10),
        }
        .to_string();
        assert!(parse.contains("line 1"));
        assert!(parse.contains("column 10"));
    }
}
585
586// ============================================================================
587// Form Body Extractor
588// ============================================================================
589
/// Default maximum form body size (1MB).
pub const DEFAULT_FORM_LIMIT: usize = 1024 * 1024;

/// Configuration for form extraction.
///
/// Currently holds only a maximum body size, defaulting to
/// [`DEFAULT_FORM_LIMIT`].
// NOTE(review): the `Form` extractor in this file takes its size limit from
// `RequestContext::max_body_size()`; confirm where this config is consumed.
#[derive(Debug, Clone)]
pub struct FormConfig {
    /// Maximum body size in bytes.
    limit: usize,
}

impl Default for FormConfig {
    /// Same as [`FormConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl FormConfig {
    /// Create a configuration with the default size limit.
    #[must_use]
    pub fn new() -> Self {
        Self {
            limit: DEFAULT_FORM_LIMIT,
        }
    }

    /// Return this configuration with the maximum body size replaced by
    /// `limit` (in bytes).
    #[must_use]
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Returns the configured size limit in bytes.
    #[must_use]
    pub fn get_limit(&self) -> usize {
        self.limit
    }
}
624
/// Form body extractor for `application/x-www-form-urlencoded`.
///
/// A thin wrapper around the deserialized value `T`. The value is public
/// and also reachable through `Deref`/`DerefMut`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Form<T>(pub T);

impl<T> Form<T> {
    /// Wrap `value` in a `Form`.
    pub fn new(value: T) -> Self {
        Form(value)
    }

    /// Consume the wrapper and return the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> Deref for Form<T> {
    type Target = T;

    /// Borrow the wrapped value.
    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}

impl<T> DerefMut for Form<T> {
    /// Mutably borrow the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let Self(inner) = self;
        inner
    }
}
651
/// Error for form extraction failures.
#[derive(Debug)]
pub enum FormExtractError {
    /// Content-Type header is missing or is not a form media type;
    /// `actual` carries the received value, if any.
    UnsupportedMediaType { actual: Option<String> },
    /// Body length `size` exceeds the allowed `limit` (both in bytes).
    PayloadTooLarge { size: usize, limit: usize },
    /// The form fields could not be deserialized into the target type.
    DeserializeError { message: String },
    /// The request body is a stream, which this extractor does not read.
    StreamingNotSupported,
    /// The request body is not valid UTF-8.
    InvalidUtf8,
}

impl std::fmt::Display for FormExtractError {
    /// Write the human-readable message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected application/x-www-form-urlencoded, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Missing Content-Type header")
            }
            Self::PayloadTooLarge { size, limit } => {
                write!(f, "Body too large: {size} > {limit}")
            }
            Self::DeserializeError { message } => write!(f, "Form error: {message}"),
            Self::StreamingNotSupported => f.write_str("Streaming not supported"),
            Self::InvalidUtf8 => f.write_str("Invalid UTF-8"),
        }
    }
}

impl std::error::Error for FormExtractError {}
683
684impl IntoResponse for FormExtractError {
685    fn into_response(self) -> Response {
686        match &self {
687            FormExtractError::UnsupportedMediaType { .. } => {
688                HttpError::unsupported_media_type().into_response()
689            }
690            FormExtractError::PayloadTooLarge { size, limit } => HttpError::payload_too_large()
691                .with_detail(format!("Body {size} > {limit}"))
692                .into_response(),
693            FormExtractError::DeserializeError { message } => {
694                use crate::error::error_types;
695                ValidationErrors::single(
696                    ValidationError::new(
697                        error_types::VALUE_ERROR,
698                        vec![crate::error::LocItem::field("body")],
699                    )
700                    .with_msg(message.clone()),
701                )
702                .into_response()
703            }
704            FormExtractError::StreamingNotSupported => HttpError::bad_request().into_response(),
705            FormExtractError::InvalidUtf8 => HttpError::bad_request()
706                .with_detail("Invalid UTF-8")
707                .into_response(),
708        }
709    }
710}
711
712impl<T: DeserializeOwned> FromRequest for Form<T> {
713    type Error = FormExtractError;
714
715    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
716        let ct = req
717            .headers()
718            .get("content-type")
719            .and_then(|v| std::str::from_utf8(v).ok());
720        let is_form = ct.is_some_and(|c| {
721            c.to_ascii_lowercase()
722                .starts_with("application/x-www-form-urlencoded")
723        });
724        if !is_form {
725            return Err(FormExtractError::UnsupportedMediaType {
726                actual: ct.map(String::from),
727            });
728        }
729        let _ = ctx.checkpoint();
730        let body = req.take_body();
731        let bytes = match body {
732            Body::Empty => Vec::new(),
733            Body::Bytes(b) => b,
734            Body::Stream(_) => return Err(FormExtractError::StreamingNotSupported),
735        };
736        let limit = ctx.max_body_size();
737        if bytes.len() > limit {
738            return Err(FormExtractError::PayloadTooLarge {
739                size: bytes.len(),
740                limit,
741            });
742        }
743        let _ = ctx.checkpoint();
744        let body_str = std::str::from_utf8(&bytes).map_err(|_| FormExtractError::InvalidUtf8)?;
745        let params = QueryParams::parse(body_str);
746        let value = T::deserialize(QueryDeserializer::new(&params)).map_err(|e| {
747            FormExtractError::DeserializeError {
748                message: e.to_string(),
749            }
750        })?;
751        let _ = ctx.checkpoint();
752        Ok(Form(value))
753    }
754}
755
#[cfg(test)]
mod form_tests {
    use super::*;
    use crate::request::Method;

    /// Build a request context backed by the asupersync testing capability.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Build a `POST /test` request with the given Content-Type and body.
    fn post_with(content_type: &str, body: &str) -> Request {
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", content_type.as_bytes().to_vec());
        req.set_body(Body::Bytes(body.as_bytes().to_vec()));
        req
    }

    /// Build a request carrying `body` as a urlencoded form.
    fn form_request(body: &str) -> Request {
        post_with("application/x-www-form-urlencoded", body)
    }

    /// Run the `Form<T>` extractor to completion against `req`.
    fn extract<T: DeserializeOwned>(req: &mut Request) -> Result<Form<T>, FormExtractError> {
        let ctx = test_context();
        futures_executor::block_on(Form::<T>::from_request(&ctx, req))
    }

    #[test]
    fn form_extract_success() {
        use serde::Deserialize;

        #[derive(Deserialize, Debug, PartialEq)]
        struct Login {
            username: String,
            password: String,
        }

        let mut req = form_request("username=alice&password=secret");

        let Form(login) = extract::<Login>(&mut req).unwrap();
        assert_eq!(login.username, "alice");
        assert_eq!(login.password, "secret");
    }

    #[test]
    fn form_wrong_content_type() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        struct T {
            #[allow(dead_code)]
            x: String,
        }

        let mut req = post_with("application/json", "x=1");

        let result = extract::<T>(&mut req);
        assert!(matches!(
            result,
            Err(FormExtractError::UnsupportedMediaType { .. })
        ));
    }
}
812
813// ============================================================================
814// Raw Body Extractors (Bytes/String)
815// ============================================================================
816
/// Default maximum raw body size (2MB).
pub const DEFAULT_RAW_BODY_LIMIT: usize = 2 * 1024 * 1024;

/// Configuration for raw body extraction.
///
/// Holds the maximum body size, defaulting to [`DEFAULT_RAW_BODY_LIMIT`].
#[derive(Debug, Clone)]
pub struct RawBodyConfig {
    /// Maximum body size in bytes.
    limit: usize,
}

impl Default for RawBodyConfig {
    /// Same as [`RawBodyConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RawBodyConfig {
    /// Create a new configuration with default settings.
    #[must_use]
    pub fn new() -> Self {
        Self {
            limit: DEFAULT_RAW_BODY_LIMIT,
        }
    }

    /// Return this configuration with the maximum body size replaced by
    /// `size` (in bytes).
    #[must_use]
    pub fn limit(self, size: usize) -> Self {
        Self { limit: size, ..self }
    }

    /// Get the maximum body size in bytes.
    #[must_use]
    pub fn get_limit(&self) -> usize {
        self.limit
    }
}
855
/// Error for raw body extraction failures.
#[derive(Debug)]
pub enum RawBodyError {
    /// Body exceeds maximum allowed size.
    PayloadTooLarge { size: usize, limit: usize },
    /// Streaming body not supported.
    StreamingNotSupported,
    /// Body is not valid UTF-8 (for String extractor).
    InvalidUtf8,
}

impl std::fmt::Display for RawBodyError {
    /// Write the human-readable message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PayloadTooLarge { size, limit } => write!(
                f,
                "Payload too large: {size} bytes exceeds limit of {limit}"
            ),
            Self::StreamingNotSupported => {
                f.write_str("Streaming body not supported for raw extraction")
            }
            Self::InvalidUtf8 => f.write_str("Body is not valid UTF-8"),
        }
    }
}

impl std::error::Error for RawBodyError {}
885
886impl IntoResponse for RawBodyError {
887    fn into_response(self) -> Response {
888        match &self {
889            RawBodyError::PayloadTooLarge { size, limit } => HttpError::payload_too_large()
890                .with_detail(format!("Body {size} bytes > {limit} limit"))
891                .into_response(),
892            RawBodyError::StreamingNotSupported => HttpError::bad_request()
893                .with_detail("Streaming body not supported")
894                .into_response(),
895            RawBodyError::InvalidUtf8 => HttpError::bad_request()
896                .with_detail("Body is not valid UTF-8")
897                .into_response(),
898        }
899    }
900}
901
/// Raw bytes body extractor.
///
/// Holds the request body as raw bytes. When used as an extractor, no
/// content-type validation is performed, which is useful when the raw
/// payload is needed regardless of format.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::{Bytes, FromRequest};
///
/// async fn upload(body: Bytes) -> String {
///     format!("Received {} bytes", body.len())
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Bytes(pub Vec<u8>);

impl Bytes {
    /// Create a new `Bytes` from a vector.
    #[must_use]
    pub fn new(data: Vec<u8>) -> Self {
        Bytes(data)
    }

    /// Number of bytes in the body.
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// `true` when the body holds no bytes.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.as_slice().is_empty()
    }

    /// Borrow the body as a byte slice.
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Consume the wrapper and return the underlying `Vec<u8>`.
    #[must_use]
    pub fn into_inner(self) -> Vec<u8> {
        let Self(data) = self;
        data
    }
}

impl AsRef<[u8]> for Bytes {
    /// Borrow the body as a byte slice.
    fn as_ref(&self) -> &[u8] {
        self.as_slice()
    }
}

impl std::ops::Deref for Bytes {
    type Target = [u8];

    /// Dereference to the body's byte slice, so slice methods are usable
    /// directly on `Bytes`.
    fn deref(&self) -> &[u8] {
        self.as_slice()
    }
}

impl From<Vec<u8>> for Bytes {
    /// Wrap an owned byte vector without copying.
    fn from(data: Vec<u8>) -> Self {
        Self::new(data)
    }
}

impl From<Bytes> for Vec<u8> {
    /// Unwrap into the owned byte vector without copying.
    fn from(bytes: Bytes) -> Self {
        bytes.into_inner()
    }
}
976
977impl FromRequest for Bytes {
978    type Error = RawBodyError;
979
980    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
981        let _ = ctx.checkpoint();
982
983        let body = req.take_body();
984        let bytes = match body {
985            Body::Empty => Vec::new(),
986            Body::Bytes(b) => b,
987            Body::Stream(_) => return Err(RawBodyError::StreamingNotSupported),
988        };
989
990        // Get limit from config or use default
991        let limit = req
992            .get_extension::<RawBodyConfig>()
993            .map(|c| c.limit)
994            .unwrap_or(DEFAULT_RAW_BODY_LIMIT);
995
996        if bytes.len() > limit {
997            return Err(RawBodyError::PayloadTooLarge {
998                size: bytes.len(),
999                limit,
1000            });
1001        }
1002
1003        let _ = ctx.checkpoint();
1004        Ok(Bytes(bytes))
1005    }
1006}
1007
/// String body extractor.
///
/// Hands the handler the request body decoded as UTF-8 text. Extraction
/// fails when the body is not valid UTF-8.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::{StringBody, FromRequest};
///
/// async fn process(body: StringBody) -> String {
///     format!("Received: {}", body.as_str())
/// }
/// ```
#[derive(Debug, Clone)]
pub struct StringBody(pub String);

impl StringBody {
    /// Wrap an owned string in a `StringBody`.
    #[must_use]
    pub fn new(data: String) -> Self {
        StringBody(data)
    }

    /// Length of the text in bytes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_str().len()
    }

    /// `true` when the text is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrow the text as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consume the wrapper and return the underlying `String`.
    #[must_use]
    pub fn into_inner(self) -> String {
        let StringBody(text) = self;
        text
    }
}

impl AsRef<str> for StringBody {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}

impl std::ops::Deref for StringBody {
    type Target = str;

    fn deref(&self) -> &str {
        self.as_str()
    }
}

impl From<String> for StringBody {
    fn from(data: String) -> Self {
        StringBody::new(data)
    }
}

impl From<StringBody> for String {
    fn from(text: StringBody) -> Self {
        text.into_inner()
    }
}
1082
1083impl FromRequest for StringBody {
1084    type Error = RawBodyError;
1085
1086    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
1087        let bytes = Bytes::from_request(ctx, req).await?;
1088
1089        let text = String::from_utf8(bytes.into_inner()).map_err(|_| RawBodyError::InvalidUtf8)?;
1090
1091        Ok(StringBody(text))
1092    }
1093}
1094
#[cfg(test)]
mod raw_body_tests {
    use super::*;
    use crate::request::Method;

    /// Request context built on the asupersync testing capability.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Build a POST request to `path` carrying `body`.
    fn post(path: &'static str, body: Body) -> Request {
        let mut req = Request::new(Method::Post, path);
        req.set_body(body);
        req
    }

    /// Drive the `Bytes` extractor to completion against `req`.
    fn extract_bytes(req: &mut Request) -> Result<Bytes, RawBodyError> {
        let ctx = test_context();
        futures_executor::block_on(Bytes::from_request(&ctx, req))
    }

    /// Drive the `StringBody` extractor to completion against `req`.
    fn extract_string(req: &mut Request) -> Result<StringBody, RawBodyError> {
        let ctx = test_context();
        futures_executor::block_on(StringBody::from_request(&ctx, req))
    }

    #[test]
    fn test_bytes_extract_success() {
        let mut req = post("/upload", Body::Bytes(b"hello world".to_vec()));

        let bytes = extract_bytes(&mut req).unwrap();
        assert_eq!(bytes.as_slice(), b"hello world");
        assert_eq!(bytes.len(), 11);
    }

    #[test]
    fn test_bytes_extract_empty() {
        let mut req = post("/upload", Body::Empty);

        let bytes = extract_bytes(&mut req).unwrap();
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_bytes_size_limit() {
        // One byte over the default limit must be rejected.
        let oversized = vec![0u8; DEFAULT_RAW_BODY_LIMIT + 1];
        let mut req = post("/upload", Body::Bytes(oversized));

        let outcome = extract_bytes(&mut req);
        assert!(matches!(outcome, Err(RawBodyError::PayloadTooLarge { .. })));
    }

    #[test]
    fn test_bytes_custom_limit() {
        let mut req = post("/upload", Body::Bytes(vec![0u8; 150]));
        req.insert_extension(RawBodyConfig::new().limit(100));

        let outcome = extract_bytes(&mut req);
        assert!(matches!(
            outcome,
            Err(RawBodyError::PayloadTooLarge {
                size: 150,
                limit: 100
            })
        ));
    }

    #[test]
    fn test_bytes_deref() {
        let bytes = Bytes::new(b"test".to_vec());
        assert_eq!(&*bytes, b"test");
    }

    #[test]
    fn test_bytes_from_vec() {
        let bytes: Bytes = vec![1, 2, 3].into();
        assert_eq!(bytes.as_slice(), &[1, 2, 3]);
    }

    #[test]
    fn test_string_body_extract_success() {
        let mut req = post("/text", Body::Bytes(b"hello world".to_vec()));

        let text = extract_string(&mut req).unwrap();
        assert_eq!(text.as_str(), "hello world");
        assert_eq!(text.len(), 11);
    }

    #[test]
    fn test_string_body_extract_empty() {
        let mut req = post("/text", Body::Empty);

        let text = extract_string(&mut req).unwrap();
        assert!(text.is_empty());
    }

    #[test]
    fn test_string_body_invalid_utf8() {
        // 0xff can never appear in well-formed UTF-8.
        let mut req = post("/text", Body::Bytes(vec![0xff, 0xfe, 0x00, 0x01]));

        let outcome = extract_string(&mut req);
        assert!(matches!(outcome, Err(RawBodyError::InvalidUtf8)));
    }

    #[test]
    fn test_string_body_deref() {
        let text = StringBody::new("hello".to_string());
        assert_eq!(&*text, "hello");
    }

    #[test]
    fn test_string_body_from_string() {
        let text: StringBody = "test".to_string().into();
        assert_eq!(text.as_str(), "test");
    }

    #[test]
    fn test_string_body_unicode() {
        let mut req = post("/text", Body::Bytes("こんにちは世界 🌍".as_bytes().to_vec()));

        let text = extract_string(&mut req).unwrap();
        assert_eq!(text.as_str(), "こんにちは世界 🌍");
    }
}
1224
1225// ============================================================================
1226// Multipart Form Extractor
1227// ============================================================================
1228
/// Default maximum file size for multipart uploads (10MB).
pub const DEFAULT_MULTIPART_FILE_SIZE: usize = 10 * 1024 * 1024;

/// Default maximum total size for multipart uploads (50MB).
pub const DEFAULT_MULTIPART_TOTAL_SIZE: usize = 50 * 1024 * 1024;

/// Default maximum number of fields in multipart form.
pub const DEFAULT_MULTIPART_MAX_FIELDS: usize = 100;

/// Limits applied while extracting a multipart form.
///
/// Bounds the size of a single uploaded file, the combined size of all
/// parts, and the number of parts, so that oversized uploads can be
/// rejected.
///
/// # Defaults
///
/// | Setting | Default |
/// |---------|---------|
/// | `max_file_size` | 10 MB |
/// | `max_total_size` | 50 MB |
/// | `max_fields` | 100 |
#[derive(Debug, Clone)]
pub struct MultipartConfig {
    max_file_size: usize,
    max_total_size: usize,
    max_fields: usize,
}

impl Default for MultipartConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl MultipartConfig {
    /// Create a configuration populated with the module defaults.
    #[must_use]
    pub fn new() -> Self {
        MultipartConfig {
            max_file_size: DEFAULT_MULTIPART_FILE_SIZE,
            max_total_size: DEFAULT_MULTIPART_TOTAL_SIZE,
            max_fields: DEFAULT_MULTIPART_MAX_FIELDS,
        }
    }

    /// Builder: set the maximum size of a single file, in bytes.
    #[must_use]
    pub fn max_file_size(self, size: usize) -> Self {
        Self {
            max_file_size: size,
            ..self
        }
    }

    /// Builder: set the maximum combined size of all parts, in bytes.
    #[must_use]
    pub fn max_total_size(self, size: usize) -> Self {
        Self {
            max_total_size: size,
            ..self
        }
    }

    /// Builder: set the maximum number of parts.
    #[must_use]
    pub fn max_fields(self, count: usize) -> Self {
        Self {
            max_fields: count,
            ..self
        }
    }

    /// Current per-file size limit, in bytes.
    #[must_use]
    pub fn get_max_file_size(&self) -> usize {
        self.max_file_size
    }

    /// Current total upload size limit, in bytes.
    #[must_use]
    pub fn get_max_total_size(&self) -> usize {
        self.max_total_size
    }

    /// Current limit on the number of parts.
    #[must_use]
    pub fn get_max_fields(&self) -> usize {
        self.max_fields
    }
}
1313
/// An uploaded file extracted from a multipart form.
///
/// Holds the form field name, the client-supplied filename and
/// Content-Type, and the file contents fully buffered in memory.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::{UploadedFile, FromRequest};
///
/// async fn upload(file: UploadedFile) -> String {
///     format!("Received file '{}' ({} bytes)", file.filename(), file.size())
/// }
/// ```
#[derive(Debug, Clone)]
pub struct UploadedFile {
    /// The form field name.
    field_name: String,
    /// The original filename.
    filename: String,
    /// The Content-Type of the file.
    content_type: String,
    /// The file contents.
    data: Vec<u8>,
}

impl UploadedFile {
    /// Create a new uploaded file.
    #[must_use]
    pub fn new(field_name: String, filename: String, content_type: String, data: Vec<u8>) -> Self {
        Self {
            field_name,
            filename,
            content_type,
            data,
        }
    }

    /// Get the form field name.
    #[must_use]
    pub fn field_name(&self) -> &str {
        &self.field_name
    }

    /// Get the original filename.
    #[must_use]
    pub fn filename(&self) -> &str {
        &self.filename
    }

    /// Get the Content-Type.
    #[must_use]
    pub fn content_type(&self) -> &str {
        &self.content_type
    }

    /// Get the file data.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Take ownership of the file data.
    #[must_use]
    pub fn into_data(self) -> Vec<u8> {
        self.data
    }

    /// Get the file size in bytes.
    #[must_use]
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Get the file extension from the filename.
    ///
    /// The extension is the text after the last `.` in the filename.
    /// Returns `None` when the filename has no `.`, when nothing follows
    /// the last `.`, or when the text after the last `.` contains a path
    /// separator (`/` or `\`). In that last case the dot belongs to a
    /// directory component rather than to the file name itself, e.g.
    /// `dir.v1/readme`.
    ///
    /// A leading-dot name such as `.bashrc` yields `Some("bashrc")`.
    #[must_use]
    pub fn extension(&self) -> Option<&str> {
        self.filename
            .rsplit_once('.')
            .map(|(_, ext)| ext)
            .filter(|ext| !ext.is_empty())
            // Filenames are client-controlled and may carry a path; a
            // separator after the last dot means the file itself has no
            // extension.
            .filter(|ext| !ext.contains(|c: char| matches!(c, '/' | '\\')))
    }

    /// Read the file data as UTF-8 text.
    ///
    /// Returns `None` if the data is not valid UTF-8.
    #[must_use]
    pub fn text(&self) -> Option<&str> {
        std::str::from_utf8(&self.data).ok()
    }
}
1402
/// Failure modes of the multipart form extractors.
#[derive(Debug)]
pub enum MultipartExtractError {
    /// Content-Type is not multipart/form-data; `actual` is the value seen, if any.
    UnsupportedMediaType { actual: Option<String> },
    /// The Content-Type carries no usable boundary parameter.
    MissingBoundary,
    /// A single file part is larger than the configured limit.
    FileTooLarge { size: usize, limit: usize },
    /// The combined size of the parts is larger than the configured limit.
    TotalTooLarge { size: usize, limit: usize },
    /// The form has more parts than the configured limit.
    TooManyFields { count: usize, limit: usize },
    /// The body does not follow the multipart format.
    InvalidFormat { detail: String },
    /// The request body is a stream, which the extractor does not handle.
    StreamingNotSupported,
    /// No file part with the requested field name was present.
    FileNotFound { field_name: String },
}

impl std::fmt::Display for MultipartExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected multipart/form-data, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Expected multipart/form-data, got empty Content-Type")
            }
            Self::MissingBoundary => f.write_str("Missing boundary in multipart Content-Type"),
            Self::FileTooLarge { size, limit } => {
                write!(f, "File too large: {size} bytes exceeds limit of {limit}")
            }
            Self::TotalTooLarge { size, limit } => write!(
                f,
                "Total upload too large: {size} bytes exceeds limit of {limit}"
            ),
            Self::TooManyFields { count, limit } => {
                write!(f, "Too many fields: {count} exceeds limit of {limit}")
            }
            Self::InvalidFormat { detail } => write!(f, "Invalid multipart format: {detail}"),
            Self::StreamingNotSupported => {
                f.write_str("Streaming body not supported for multipart extraction")
            }
            Self::FileNotFound { field_name } => {
                write!(f, "No file found with field name '{field_name}'")
            }
        }
    }
}

impl std::error::Error for MultipartExtractError {}
1461
1462impl IntoResponse for MultipartExtractError {
1463    fn into_response(self) -> Response {
1464        match &self {
1465            MultipartExtractError::UnsupportedMediaType { .. } => {
1466                HttpError::unsupported_media_type().into_response()
1467            }
1468            MultipartExtractError::MissingBoundary => HttpError::bad_request()
1469                .with_detail("Missing boundary in multipart Content-Type")
1470                .into_response(),
1471            MultipartExtractError::FileTooLarge { size, limit } => HttpError::payload_too_large()
1472                .with_detail(format!("File {size} bytes > {limit} limit"))
1473                .into_response(),
1474            MultipartExtractError::TotalTooLarge { size, limit } => HttpError::payload_too_large()
1475                .with_detail(format!("Total {size} bytes > {limit} limit"))
1476                .into_response(),
1477            MultipartExtractError::TooManyFields { count, limit } => HttpError::bad_request()
1478                .with_detail(format!("Too many fields: {count} > {limit}"))
1479                .into_response(),
1480            MultipartExtractError::InvalidFormat { detail } => HttpError::bad_request()
1481                .with_detail(format!("Invalid multipart: {detail}"))
1482                .into_response(),
1483            MultipartExtractError::StreamingNotSupported => HttpError::bad_request()
1484                .with_detail("Streaming body not supported")
1485                .into_response(),
1486            MultipartExtractError::FileNotFound { field_name } => {
1487                use crate::error::error_types;
1488                ValidationErrors::single(
1489                    ValidationError::new(
1490                        error_types::VALUE_ERROR,
1491                        vec![crate::error::LocItem::field(field_name)],
1492                    )
1493                    .with_msg(format!("Required file '{field_name}' not found")),
1494                )
1495                .into_response()
1496            }
1497        }
1498    }
1499}
1500
1501/// Multipart form data extractor.
1502///
1503/// Extracts a complete multipart form including all fields and files.
1504///
1505/// # Example
1506///
1507/// ```ignore
1508/// use fastapi_core::{Multipart, FromRequest};
1509///
1510/// async fn upload(form: Multipart) -> String {
1511///     let description = form.get_field("description").unwrap_or("No description");
1512///     let file = form.get_file("document");
1513///     format!("Description: {}, File: {:?}", description, file.map(|f| f.filename()))
1514/// }
1515/// ```
1516#[derive(Debug, Clone)]
1517pub struct Multipart {
1518    parts: Vec<MultipartPart>,
1519}
1520
1521/// A part of a multipart form (either a field or a file).
1522#[derive(Debug, Clone)]
1523pub struct MultipartPart {
1524    /// Field name.
1525    pub name: String,
1526    /// Filename if this is a file upload.
1527    pub filename: Option<String>,
1528    /// Content-Type if specified.
1529    pub content_type: Option<String>,
1530    /// The part data.
1531    pub data: Vec<u8>,
1532}
1533
1534impl Multipart {
1535    /// Create from parsed parts.
1536    #[must_use]
1537    pub fn from_parts(parts: Vec<MultipartPart>) -> Self {
1538        Self { parts }
1539    }
1540
1541    /// Get all parts.
1542    #[must_use]
1543    pub fn parts(&self) -> &[MultipartPart] {
1544        &self.parts
1545    }
1546
1547    /// Get a form field value by name.
1548    #[must_use]
1549    pub fn get_field(&self, name: &str) -> Option<&str> {
1550        self.parts
1551            .iter()
1552            .find(|p| p.name == name && p.filename.is_none())
1553            .and_then(|p| std::str::from_utf8(&p.data).ok())
1554    }
1555
1556    /// Get an uploaded file by field name.
1557    #[must_use]
1558    pub fn get_file(&self, name: &str) -> Option<UploadedFile> {
1559        self.parts
1560            .iter()
1561            .find(|p| p.name == name && p.filename.is_some())
1562            .map(|p| {
1563                UploadedFile::new(
1564                    p.name.clone(),
1565                    p.filename.clone().unwrap_or_default(),
1566                    p.content_type
1567                        .clone()
1568                        .unwrap_or_else(|| "application/octet-stream".to_string()),
1569                    p.data.clone(),
1570                )
1571            })
1572    }
1573
1574    /// Get all files.
1575    #[must_use]
1576    pub fn files(&self) -> Vec<UploadedFile> {
1577        self.parts
1578            .iter()
1579            .filter(|p| p.filename.is_some())
1580            .map(|p| {
1581                UploadedFile::new(
1582                    p.name.clone(),
1583                    p.filename.clone().unwrap_or_default(),
1584                    p.content_type
1585                        .clone()
1586                        .unwrap_or_else(|| "application/octet-stream".to_string()),
1587                    p.data.clone(),
1588                )
1589            })
1590            .collect()
1591    }
1592
1593    /// Get all files with a specific field name.
1594    #[must_use]
1595    pub fn get_files(&self, name: &str) -> Vec<UploadedFile> {
1596        self.parts
1597            .iter()
1598            .filter(|p| p.name == name && p.filename.is_some())
1599            .map(|p| {
1600                UploadedFile::new(
1601                    p.name.clone(),
1602                    p.filename.clone().unwrap_or_default(),
1603                    p.content_type
1604                        .clone()
1605                        .unwrap_or_else(|| "application/octet-stream".to_string()),
1606                    p.data.clone(),
1607                )
1608            })
1609            .collect()
1610    }
1611
1612    /// Get all field names and values.
1613    #[must_use]
1614    pub fn fields(&self) -> Vec<(&str, &str)> {
1615        self.parts
1616            .iter()
1617            .filter(|p| p.filename.is_none())
1618            .filter_map(|p| Some((p.name.as_str(), std::str::from_utf8(&p.data).ok()?)))
1619            .collect()
1620    }
1621
1622    /// Check if a field exists.
1623    #[must_use]
1624    pub fn has_field(&self, name: &str) -> bool {
1625        self.parts.iter().any(|p| p.name == name)
1626    }
1627
1628    /// Get the number of parts.
1629    #[must_use]
1630    pub fn len(&self) -> usize {
1631        self.parts.len()
1632    }
1633
1634    /// Check if the form is empty.
1635    #[must_use]
1636    pub fn is_empty(&self) -> bool {
1637        self.parts.is_empty()
1638    }
1639}
1640
1641impl FromRequest for Multipart {
1642    type Error = MultipartExtractError;
1643
1644    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
1645        // Check Content-Type
1646        let content_type = req
1647            .headers()
1648            .get("content-type")
1649            .and_then(|v| std::str::from_utf8(v).ok())
1650            .map(String::from);
1651
1652        let ct = content_type
1653            .as_deref()
1654            .ok_or(MultipartExtractError::UnsupportedMediaType { actual: None })?;
1655
1656        if !ct.to_ascii_lowercase().starts_with("multipart/form-data") {
1657            return Err(MultipartExtractError::UnsupportedMediaType {
1658                actual: Some(ct.to_string()),
1659            });
1660        }
1661
1662        // Parse boundary
1663        let boundary = parse_multipart_boundary(ct)?;
1664
1665        let _ = ctx.checkpoint();
1666
1667        // Get body
1668        let body = req.take_body();
1669        let bytes = match body {
1670            Body::Empty => Vec::new(),
1671            Body::Bytes(b) => b,
1672            Body::Stream(_) => return Err(MultipartExtractError::StreamingNotSupported),
1673        };
1674
1675        // Get config from request extensions or use default
1676        let config = req
1677            .get_extension::<MultipartConfig>()
1678            .cloned()
1679            .unwrap_or_default();
1680
1681        let _ = ctx.checkpoint();
1682
1683        // Parse multipart
1684        let parts = parse_multipart_body(&bytes, &boundary, &config)?;
1685
1686        let _ = ctx.checkpoint();
1687
1688        Ok(Multipart::from_parts(parts))
1689    }
1690}
1691
1692/// File extractor for a single file upload.
1693///
1694/// This extractor gets a single file from a multipart form by field name.
1695/// The field name is specified via `FileConfig` extension on the request,
1696/// or defaults to "file".
1697///
1698/// # Example
1699///
1700/// ```ignore
1701/// use fastapi_core::{File, FromRequest};
1702///
1703/// async fn upload(file: File) -> String {
1704///     format!("Received: {} ({} bytes)", file.filename(), file.size())
1705/// }
1706/// ```
1707#[derive(Debug, Clone)]
1708pub struct File(pub UploadedFile);
1709
1710impl File {
1711    /// Get the underlying uploaded file.
1712    #[must_use]
1713    pub fn into_inner(self) -> UploadedFile {
1714        self.0
1715    }
1716
1717    /// Get a reference to the uploaded file.
1718    #[must_use]
1719    pub fn inner(&self) -> &UploadedFile {
1720        &self.0
1721    }
1722}
1723
1724impl std::ops::Deref for File {
1725    type Target = UploadedFile;
1726
1727    fn deref(&self) -> &Self::Target {
1728        &self.0
1729    }
1730}
1731
/// Configuration for the File extractor: the form field to read the file from.
#[derive(Debug, Clone)]
pub struct FileConfig {
    field_name: String,
}

impl Default for FileConfig {
    /// Uses the field name "file".
    fn default() -> Self {
        Self::new("file")
    }
}

impl FileConfig {
    /// Create a config that selects the given form field.
    #[must_use]
    pub fn new(field_name: impl Into<String>) -> Self {
        let field_name = field_name.into();
        FileConfig { field_name }
    }

    /// The form field name this config selects.
    #[must_use]
    pub fn field_name(&self) -> &str {
        self.field_name.as_str()
    }
}
1761
1762impl FromRequest for File {
1763    type Error = MultipartExtractError;
1764
1765    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
1766        let field_name = req
1767            .get_extension::<FileConfig>()
1768            .map(|c| c.field_name.clone())
1769            .unwrap_or_else(|| "file".to_string());
1770
1771        let multipart = Multipart::from_request(ctx, req).await?;
1772
1773        let file = multipart
1774            .get_file(&field_name)
1775            .ok_or(MultipartExtractError::FileNotFound { field_name })?;
1776
1777        Ok(File(file))
1778    }
1779}
1780
1781/// Maximum boundary length per RFC 2046.
1782///
1783/// Boundaries can be 1-70 characters. We add a safety margin and reject
1784/// anything longer to prevent memory amplification attacks.
1785const MAX_MULTIPART_BOUNDARY_LEN: usize = 70;
1786
1787/// Parse boundary from Content-Type header.
1788fn parse_multipart_boundary(content_type: &str) -> Result<String, MultipartExtractError> {
1789    for part in content_type.split(';') {
1790        let part = part.trim();
1791        if let Some(boundary) = part
1792            .strip_prefix("boundary=")
1793            .or_else(|| part.strip_prefix("BOUNDARY="))
1794        {
1795            let boundary = boundary.trim_matches('"').trim_matches('\'');
1796            if boundary.is_empty() {
1797                return Err(MultipartExtractError::MissingBoundary);
1798            }
1799            // RFC 2046: boundary must be 1-70 characters
1800            if boundary.len() > MAX_MULTIPART_BOUNDARY_LEN {
1801                return Err(MultipartExtractError::InvalidFormat {
1802                    detail: format!(
1803                        "boundary too long: {} chars (max {})",
1804                        boundary.len(),
1805                        MAX_MULTIPART_BOUNDARY_LEN
1806                    ),
1807                });
1808            }
1809            return Ok(boundary.to_string());
1810        }
1811    }
1812    Err(MultipartExtractError::MissingBoundary)
1813}
1814
1815/// Parse multipart body into parts.
1816fn parse_multipart_body(
1817    body: &[u8],
1818    boundary: &str,
1819    config: &MultipartConfig,
1820) -> Result<Vec<MultipartPart>, MultipartExtractError> {
1821    let boundary_bytes = format!("--{boundary}").into_bytes();
1822    let mut parts = Vec::new();
1823    let mut total_size = 0usize;
1824    let mut pos = 0;
1825
1826    // Find first boundary
1827    pos = find_bytes(body, &boundary_bytes, pos).ok_or_else(|| {
1828        MultipartExtractError::InvalidFormat {
1829            detail: "no boundary found".to_string(),
1830        }
1831    })?;
1832
1833    loop {
1834        // Check field limit
1835        if parts.len() >= config.max_fields {
1836            return Err(MultipartExtractError::TooManyFields {
1837                count: parts.len() + 1,
1838                limit: config.max_fields,
1839            });
1840        }
1841
1842        // Check if this is the final boundary (--boundary--)
1843        let boundary_end = pos + boundary_bytes.len();
1844        if boundary_end + 2 <= body.len() && body[boundary_end..boundary_end + 2] == *b"--" {
1845            break;
1846        }
1847
1848        // Skip boundary and CRLF
1849        pos = boundary_end;
1850        if pos + 2 > body.len() {
1851            return Err(MultipartExtractError::InvalidFormat {
1852                detail: "unexpected end after boundary".to_string(),
1853            });
1854        }
1855        if body[pos..pos + 2] != *b"\r\n" {
1856            return Err(MultipartExtractError::InvalidFormat {
1857                detail: "expected CRLF after boundary".to_string(),
1858            });
1859        }
1860        pos += 2;
1861
1862        // Parse headers
1863        let mut name = None;
1864        let mut filename = None;
1865        let mut content_type = None;
1866
1867        loop {
1868            let line_end =
1869                find_crlf(body, pos).ok_or_else(|| MultipartExtractError::InvalidFormat {
1870                    detail: "unterminated headers".to_string(),
1871                })?;
1872
1873            let line = &body[pos..line_end];
1874            if line.is_empty() {
1875                pos = line_end + 2;
1876                break;
1877            }
1878
1879            if let Ok(line_str) = std::str::from_utf8(line) {
1880                if let Some((header_name, header_value)) = line_str.split_once(':') {
1881                    let header_name = header_name.trim().to_ascii_lowercase();
1882                    let header_value = header_value.trim();
1883
1884                    if header_name == "content-disposition" {
1885                        (name, filename) = parse_content_disposition_header(header_value);
1886                    } else if header_name == "content-type" {
1887                        content_type = Some(header_value.to_string());
1888                    }
1889                }
1890            }
1891
1892            pos = line_end + 2;
1893        }
1894
1895        let name = name.ok_or_else(|| MultipartExtractError::InvalidFormat {
1896            detail: "missing Content-Disposition name".to_string(),
1897        })?;
1898
1899        // Find next boundary
1900        let data_end = find_bytes(body, &boundary_bytes, pos).ok_or_else(|| {
1901            MultipartExtractError::InvalidFormat {
1902                detail: "missing closing boundary".to_string(),
1903            }
1904        })?;
1905
1906        // Data ends before \r\n--boundary
1907        let data = if data_end >= 2 && body[data_end - 2..data_end] == *b"\r\n" {
1908            &body[pos..data_end - 2]
1909        } else {
1910            &body[pos..data_end]
1911        };
1912
1913        // Check size limits for files
1914        if filename.is_some() && data.len() > config.max_file_size {
1915            return Err(MultipartExtractError::FileTooLarge {
1916                size: data.len(),
1917                limit: config.max_file_size,
1918            });
1919        }
1920
1921        total_size += data.len();
1922        if total_size > config.max_total_size {
1923            return Err(MultipartExtractError::TotalTooLarge {
1924                size: total_size,
1925                limit: config.max_total_size,
1926            });
1927        }
1928
1929        parts.push(MultipartPart {
1930            name,
1931            filename,
1932            content_type,
1933            data: data.to_vec(),
1934        });
1935
1936        pos = data_end;
1937    }
1938
1939    Ok(parts)
1940}
1941
1942/// Find a byte sequence in data starting from position.
1943///
1944/// Uses memchr for SIMD-accelerated searching when the needle starts with
1945/// a specific byte, providing 2-5x speedup over naive linear search.
1946fn find_bytes(data: &[u8], needle: &[u8], start: usize) -> Option<usize> {
1947    if needle.is_empty() {
1948        return Some(start);
1949    }
1950    if start >= data.len() {
1951        return None;
1952    }
1953    // Use memchr to find candidate positions for the first byte,
1954    // then verify the full needle match
1955    let first_byte = needle[0];
1956    let search_slice = &data[start..];
1957
1958    if needle.len() == 1 {
1959        // Single byte search - pure memchr
1960        return memchr::memchr(first_byte, search_slice).map(|pos| pos + start);
1961    }
1962
1963    // Multi-byte needle: find first byte candidates, then verify
1964    let mut search_offset = 0;
1965    while let Some(pos) = memchr::memchr(first_byte, &search_slice[search_offset..]) {
1966        let abs_pos = start + search_offset + pos;
1967        if abs_pos + needle.len() > data.len() {
1968            return None;
1969        }
1970        if data[abs_pos..].starts_with(needle) {
1971            return Some(abs_pos);
1972        }
1973        search_offset += pos + 1;
1974    }
1975    None
1976}
1977
1978/// Find CRLF in data starting from position.
1979///
1980/// Uses memchr for SIMD-accelerated CR byte search.
1981fn find_crlf(data: &[u8], start: usize) -> Option<usize> {
1982    if start >= data.len().saturating_sub(1) {
1983        return None;
1984    }
1985    let search_slice = &data[start..];
1986
1987    // Find '\r' candidates using SIMD, then verify '\n' follows
1988    let mut search_offset = 0;
1989    while let Some(pos) = memchr::memchr(b'\r', &search_slice[search_offset..]) {
1990        let abs_pos = start + search_offset + pos;
1991        if abs_pos + 1 < data.len() && data[abs_pos + 1] == b'\n' {
1992            return Some(abs_pos);
1993        }
1994        search_offset += pos + 1;
1995        if search_offset >= search_slice.len().saturating_sub(1) {
1996            break;
1997        }
1998    }
1999    None
2000}
2001
2002/// Parse Content-Disposition header value.
2003fn parse_content_disposition_header(value: &str) -> (Option<String>, Option<String>) {
2004    let mut name = None;
2005    let mut filename = None;
2006
2007    for part in value.split(';') {
2008        let part = part.trim();
2009        if let Some(n) = part
2010            .strip_prefix("name=")
2011            .or_else(|| part.strip_prefix("NAME="))
2012        {
2013            name = Some(unquote_param(n));
2014        } else if let Some(f) = part
2015            .strip_prefix("filename=")
2016            .or_else(|| part.strip_prefix("FILENAME="))
2017        {
2018            filename = Some(unquote_param(f));
2019        }
2020    }
2021
2022    (name, filename)
2023}
2024
/// Strip one pair of matching surrounding quotes from a parameter value.
///
/// The input is trimmed first. If the trimmed text is wrapped in a matching
/// pair of double quotes or a matching pair of single quotes, the text
/// between them is returned; otherwise the trimmed text is returned
/// unchanged. A lone quote character and mismatched quotes are left as-is.
fn unquote_param(s: &str) -> String {
    let trimmed = s.trim();

    for quote in ['"', '\''] {
        // Stripping the prefix first means a single quote character cannot
        // be counted as both the opening and the closing quote.
        let inner = trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = inner {
            return inner.to_string();
        }
    }

    trimmed.to_string()
}
2040
#[cfg(test)]
mod multipart_tests {
    use super::*;
    use crate::RequestContext;
    use crate::request::Method;
    use asupersync::Cx;

    /// Boundary token used by every body-parsing test.
    const BOUNDARY: &str = "----boundary";

    /// Closing delimiter terminating a body built with [`BOUNDARY`].
    const CLOSING_DELIMITER: &str = "------boundary--\r\n";

    fn test_context() -> RequestContext {
        let cx = Cx::for_testing();
        RequestContext::new(cx, 1)
    }

    /// Render one plain form field part (delimiter, headers, value).
    fn field_part(name: &str, value: &str) -> String {
        format!(
            "--{BOUNDARY}\r\nContent-Disposition: form-data; name=\"{name}\"\r\n\r\n{value}\r\n"
        )
    }

    /// Render one file part; the `Content-Type` header is emitted only when given.
    fn file_part(name: &str, filename: &str, content_type: Option<&str>, data: &str) -> String {
        let type_header = content_type
            .map(|ct| format!("Content-Type: {ct}\r\n"))
            .unwrap_or_default();
        format!(
            "--{BOUNDARY}\r\nContent-Disposition: form-data; name=\"{name}\"; filename=\"{filename}\"\r\n{type_header}\r\n{data}\r\n"
        )
    }

    /// Concatenate rendered parts and append the closing delimiter.
    fn form_body(parts: &[String]) -> String {
        let mut body = parts.concat();
        body.push_str(CLOSING_DELIMITER);
        body
    }

    #[test]
    fn test_parse_boundary() {
        let header = "multipart/form-data; boundary=----WebKit";
        assert_eq!(parse_multipart_boundary(header).unwrap(), "----WebKit");
    }

    #[test]
    fn test_parse_boundary_quoted() {
        let header = r#"multipart/form-data; boundary="simple""#;
        assert_eq!(parse_multipart_boundary(header).unwrap(), "simple");
    }

    #[test]
    fn test_parse_boundary_missing() {
        let outcome = parse_multipart_boundary("multipart/form-data");
        assert!(matches!(
            outcome,
            Err(MultipartExtractError::MissingBoundary)
        ));
    }

    #[test]
    fn test_parse_boundary_too_long() {
        // RFC 2046 caps the boundary at 70 characters; 100 must be rejected.
        let header = format!("multipart/form-data; boundary={}", "x".repeat(100));
        let outcome = parse_multipart_boundary(&header);
        assert!(
            matches!(outcome, Err(MultipartExtractError::InvalidFormat { .. })),
            "Expected InvalidFormat for boundary > 70 chars"
        );
    }

    #[test]
    fn test_parse_boundary_max_length() {
        // Exactly 70 characters sits on the limit and must be accepted.
        let boundary = "x".repeat(70);
        let header = format!("multipart/form-data; boundary={boundary}");
        let outcome = parse_multipart_boundary(&header);
        assert!(outcome.is_ok(), "70-char boundary should be accepted");
        assert_eq!(outcome.unwrap(), boundary);
    }

    #[test]
    fn test_parse_simple_form() {
        let expected = [("field1", "value1"), ("field2", "value2")];
        let body = form_body(&[
            field_part("field1", "value1"),
            field_part("field2", "value2"),
        ]);

        let limits = MultipartConfig::default();
        let parsed = parse_multipart_body(body.as_bytes(), BOUNDARY, &limits).unwrap();

        assert_eq!(parsed.len(), 2);
        for (part, (name, value)) in parsed.iter().zip(expected) {
            assert_eq!(part.name, name);
            assert_eq!(std::str::from_utf8(&part.data).unwrap(), value);
        }
    }

    #[test]
    fn test_parse_file_upload() {
        let body = form_body(&[file_part(
            "file",
            "test.txt",
            Some("text/plain"),
            "Hello!",
        )]);

        let limits = MultipartConfig::default();
        let parsed = parse_multipart_body(body.as_bytes(), BOUNDARY, &limits).unwrap();

        assert_eq!(parsed.len(), 1);
        let upload = &parsed[0];
        assert_eq!(upload.name, "file");
        assert_eq!(upload.filename, Some("test.txt".to_string()));
        assert_eq!(upload.content_type, Some("text/plain".to_string()));
        assert_eq!(std::str::from_utf8(&upload.data).unwrap(), "Hello!");
    }

    #[test]
    fn test_multipart_extractor() {
        let body = form_body(&[
            field_part("name", "John"),
            file_part("avatar", "pic.jpg", Some("image/jpeg"), "JPEG"),
        ]);

        let limits = MultipartConfig::default();
        let parsed = parse_multipart_body(body.as_bytes(), BOUNDARY, &limits).unwrap();
        let form = Multipart::from_parts(parsed);

        assert_eq!(form.get_field("name"), Some("John"));
        let avatar = form.get_file("avatar").unwrap();
        assert_eq!(avatar.filename(), "pic.jpg");
        assert_eq!(avatar.content_type(), "image/jpeg");
    }

    #[test]
    fn test_file_size_limit() {
        // A 1000-byte file against a 100-byte per-file limit.
        let payload = "x".repeat(1000);
        let body = form_body(&[file_part("file", "big.txt", None, &payload)]);

        let limits = MultipartConfig::default().max_file_size(100);
        let outcome = parse_multipart_body(body.as_bytes(), BOUNDARY, &limits);

        assert!(matches!(
            outcome,
            Err(MultipartExtractError::FileTooLarge { .. })
        ));
    }

    #[test]
    fn test_total_size_limit() {
        // Two 500-byte files: each is under the per-file limit, together
        // they exceed the 800-byte total.
        let payload = "x".repeat(500);
        let body = form_body(&[
            file_part("f1", "a.txt", None, &payload),
            file_part("f2", "b.txt", None, &payload),
        ]);

        let limits = MultipartConfig::default()
            .max_file_size(1000)
            .max_total_size(800);
        let outcome = parse_multipart_body(body.as_bytes(), BOUNDARY, &limits);

        assert!(matches!(
            outcome,
            Err(MultipartExtractError::TotalTooLarge { .. })
        ));
    }

    #[test]
    fn test_field_count_limit() {
        // Five fields against a limit of three.
        let fields: Vec<String> = (0..5)
            .map(|i| field_part(&format!("f{i}"), &format!("v{i}")))
            .collect();
        let body = form_body(&fields);

        let limits = MultipartConfig::default().max_fields(3);
        let outcome = parse_multipart_body(body.as_bytes(), BOUNDARY, &limits);

        assert!(matches!(
            outcome,
            Err(MultipartExtractError::TooManyFields { .. })
        ));
    }

    #[test]
    fn test_uploaded_file_extension() {
        let with_ext = UploadedFile::new(
            "doc".to_string(),
            "report.pdf".to_string(),
            "application/pdf".to_string(),
            vec![],
        );
        assert_eq!(with_ext.extension(), Some("pdf"));

        let without_ext = UploadedFile::new(
            "doc".to_string(),
            "README".to_string(),
            "text/plain".to_string(),
            vec![],
        );
        assert_eq!(without_ext.extension(), None);
    }

    #[test]
    fn test_multipart_from_request_wrong_content_type() {
        let ctx = test_context();
        let mut request = Request::new(Method::Post, "/upload");
        request
            .headers_mut()
            .insert("content-type", b"application/json".to_vec());
        request.set_body(Body::Bytes(b"{}".to_vec()));

        let outcome = futures_executor::block_on(Multipart::from_request(&ctx, &mut request));
        assert!(matches!(
            outcome,
            Err(MultipartExtractError::UnsupportedMediaType { .. })
        ));
    }

    #[test]
    fn test_file_extractor() {
        let body = form_body(&[file_part(
            "file",
            "doc.pdf",
            Some("application/pdf"),
            "PDF content",
        )]);

        let limits = MultipartConfig::default();
        let parsed = parse_multipart_body(body.as_bytes(), BOUNDARY, &limits).unwrap();
        let form = Multipart::from_parts(parsed);

        let document = form.get_file("file").unwrap();
        assert_eq!(document.filename(), "doc.pdf");
        assert_eq!(document.content_type(), "application/pdf");
        assert_eq!(document.text(), Some("PDF content"));
    }

    #[test]
    fn test_multiple_files() {
        // Two parts sharing the same field name must both be retrievable.
        let body = form_body(&[
            file_part("files", "a.txt", None, "file a"),
            file_part("files", "b.txt", None, "file b"),
        ]);

        let limits = MultipartConfig::default();
        let parsed = parse_multipart_body(body.as_bytes(), BOUNDARY, &limits).unwrap();
        let form = Multipart::from_parts(parsed);

        let uploads = form.get_files("files");
        assert_eq!(uploads.len(), 2);
        assert_eq!(uploads[0].filename(), "a.txt");
        assert_eq!(uploads[1].filename(), "b.txt");
    }

    // =========================================================================
    // unquote_param edge case tests
    // =========================================================================

    #[test]
    fn test_unquote_param_normal_quoted() {
        for input in ["\"hello\"", "'hello'"] {
            assert_eq!(unquote_param(input), "hello");
        }
    }

    #[test]
    fn test_unquote_param_empty_quotes() {
        // An empty quoted string unquotes to the empty string.
        for input in ["\"\"", "''"] {
            assert_eq!(unquote_param(input), "");
        }
    }

    #[test]
    fn test_unquote_param_single_char_no_panic() {
        // A lone quote character must come back unchanged, without panicking.
        for input in ["\"", "'"] {
            assert_eq!(unquote_param(input), input);
        }
    }

    #[test]
    fn test_unquote_param_unquoted() {
        for input in ["hello", ""] {
            assert_eq!(unquote_param(input), input);
        }
    }

    #[test]
    fn test_unquote_param_mismatched_quotes() {
        // Mismatched opening/closing quotes are not stripped.
        for input in ["\"hello'", "'hello\""] {
            assert_eq!(unquote_param(input), input);
        }
    }

    #[test]
    fn test_unquote_param_whitespace() {
        // Surrounding whitespace is trimmed before the quote check.
        for input in ["  \"hello\"  ", "  'hello'  "] {
            assert_eq!(unquote_param(input), "hello");
        }
    }
}
2370
2371// ============================================================================
2372// Path Parameter Extractor
2373// ============================================================================
2374
/// Ordered list of `(name, value)` path parameters for a matched route.
///
/// The [`Path`] extractor looks this type up in the request extensions, so
/// it is expected to be stored there by the router once a route has matched.
///
/// # Example
///
/// Matching `/users/42/posts/99` against the route
/// `/users/{user_id}/posts/{post_id}` would produce
/// `[("user_id", "42"), ("post_id", "99")]`.
#[derive(Debug, Clone, Default)]
pub struct PathParams(pub Vec<(String, String)>);

impl PathParams {
    /// Build a parameter list with no entries.
    #[must_use]
    pub fn new() -> Self {
        Self::from_pairs(Vec::new())
    }

    /// Wrap an existing vector of `(name, value)` pairs.
    #[must_use]
    pub fn from_pairs(pairs: Vec<(String, String)>) -> Self {
        Self(pairs)
    }

    /// Look up a value by parameter name.
    ///
    /// Returns the value of the first entry whose name equals `name`, or
    /// `None` if there is no such entry.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&str> {
        self.0
            .iter()
            .find_map(|(key, value)| (key == name).then(|| value.as_str()))
    }

    /// Borrow every `(name, value)` pair in stored order.
    #[must_use]
    pub fn as_slice(&self) -> &[(String, String)] {
        self.0.as_slice()
    }

    /// Whether the list holds no parameters at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self(pairs) = self;
        pairs.is_empty()
    }

    /// Number of stored parameters.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self(pairs) = self;
        pairs.len()
    }
}
2428
/// Extractor that deserializes the request's path parameters into `T`.
///
/// # Supported Types
///
/// - **Single value**: `Path<i64>` takes the first (or only) path parameter
/// - **Tuple**: `Path<(String, i64)>` takes parameters in order
/// - **Struct**: `Path<MyParams>` takes parameters by field name
///
/// # Error Responses
///
/// - **500 Internal Server Error**: path parameters were not set by the
///   router (server bug)
/// - **422 Unprocessable Entity**: a parameter is missing or could not be
///   converted to the requested type
///
/// # Examples
///
/// ## Single Parameter
///
/// ```ignore
/// #[get("/users/{id}")]
/// async fn get_user(Path(id): Path<i64>) -> impl IntoResponse {
///     format!("User ID: {id}")
/// }
/// ```
///
/// ## Multiple Parameters (Tuple)
///
/// ```ignore
/// #[get("/users/{user_id}/posts/{post_id}")]
/// async fn get_post(Path((user_id, post_id)): Path<(i64, i64)>) -> impl IntoResponse {
///     format!("User {user_id}, Post {post_id}")
/// }
/// ```
///
/// ## Struct Extraction
///
/// ```ignore
/// #[derive(Deserialize)]
/// struct PostPath {
///     user_id: i64,
///     post_id: i64,
/// }
///
/// #[get("/users/{user_id}/posts/{post_id}")]
/// async fn get_post(Path(path): Path<PostPath>) -> impl IntoResponse {
///     format!("User {}, Post {}", path.user_id, path.post_id)
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Consume the wrapper and hand back the extracted value.
    pub fn into_inner(self) -> T {
        let Path(inner) = self;
        inner
    }
}

impl<T> Deref for Path<T> {
    type Target = T;

    /// Borrow the extracted value.
    fn deref(&self) -> &Self::Target {
        let Path(inner) = self;
        inner
    }
}

impl<T> DerefMut for Path<T> {
    /// Mutably borrow the extracted value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Path(inner) = self;
        inner
    }
}
2501
/// Reasons the [`Path`] extractor can fail.
#[derive(Debug)]
pub enum PathExtractError {
    /// No path parameters were found in the request extensions.
    /// This points at a server configuration error (the router did not set them).
    MissingPathParams,
    /// A parameter required by the target type was not present.
    MissingParam {
        /// Name of the absent parameter.
        name: String,
    },
    /// A parameter was present but could not be converted to the expected type.
    InvalidValue {
        /// Name of the offending parameter.
        name: String,
        /// The raw value that failed to convert.
        value: String,
        /// Short description of the type that was expected.
        expected: &'static str,
        /// Further detail about the failure.
        message: String,
    },
    /// General deserialization failure (e.g. wrong number of parameters for a tuple).
    DeserializeError {
        /// Description of what went wrong.
        message: String,
    },
}

impl fmt::Display for PathExtractError {
    /// Render a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingPathParams => f.write_str("Path parameters not available in request"),
            Self::MissingParam { name } => write!(f, "Missing path parameter: {}", name),
            Self::InvalidValue {
                name,
                value,
                expected,
                message,
            } => write!(
                f,
                "Invalid value for path parameter '{}': expected {}, got '{}': {}",
                name, expected, value, message
            ),
            Self::DeserializeError { message } => {
                write!(f, "Path deserialization error: {}", message)
            }
        }
    }
}

impl std::error::Error for PathExtractError {}
2559
2560impl IntoResponse for PathExtractError {
2561    fn into_response(self) -> crate::response::Response {
2562        match self {
2563            Self::MissingPathParams => {
2564                // Server bug - path params should always be set by router
2565                HttpError::internal()
2566                    .with_detail("Path parameters not available")
2567                    .into_response()
2568            }
2569            Self::MissingParam { name } => ValidationErrors::single(
2570                ValidationError::missing(crate::error::loc::path(&name))
2571                    .with_msg("Path parameter is required"),
2572            )
2573            .into_response(),
2574            Self::InvalidValue {
2575                name,
2576                value,
2577                expected,
2578                message,
2579            } => ValidationErrors::single(
2580                ValidationError::type_error(crate::error::loc::path(&name), &expected)
2581                    .with_msg(format!("Expected {expected}: {message}"))
2582                    .with_input(serde_json::Value::String(value)),
2583            )
2584            .into_response(),
2585            Self::DeserializeError { message } => ValidationErrors::single(
2586                ValidationError::new(
2587                    crate::error::error_types::VALUE_ERROR,
2588                    vec![crate::error::LocItem::field("path")],
2589                )
2590                .with_msg(message),
2591            )
2592            .into_response(),
2593        }
2594    }
2595}
2596
2597impl<T: DeserializeOwned> FromRequest for Path<T> {
2598    type Error = PathExtractError;
2599
2600    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
2601        // Get path params from request extensions
2602        let params = req
2603            .get_extension::<PathParams>()
2604            .ok_or(PathExtractError::MissingPathParams)?
2605            .clone();
2606
2607        // Deserialize using our custom deserializer
2608        let value = T::deserialize(PathDeserializer::new(&params))?;
2609
2610        Ok(Path(value))
2611    }
2612}
2613
2614// ============================================================================
2615// Path Parameter Deserializer
2616// ============================================================================
2617
2618/// Custom serde deserializer for path parameters.
2619///
2620/// Handles three modes:
2621/// - Single value: Deserializes the first parameter value
2622/// - Sequence (tuple): Deserializes parameters in order
2623/// - Map (struct): Deserializes parameters by name
2624struct PathDeserializer<'de> {
2625    params: &'de PathParams,
2626}
2627
2628impl<'de> PathDeserializer<'de> {
2629    fn new(params: &'de PathParams) -> Self {
2630        Self { params }
2631    }
2632}
2633
2634impl<'de> Deserializer<'de> for PathDeserializer<'de> {
2635    type Error = PathExtractError;
2636
2637    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2638    where
2639        V: Visitor<'de>,
2640    {
2641        // Default: try as map (struct)
2642        self.deserialize_map(visitor)
2643    }
2644
2645    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2646    where
2647        V: Visitor<'de>,
2648    {
2649        let value = self.get_single_value()?;
2650        let b = value
2651            .parse::<bool>()
2652            .map_err(|_| PathExtractError::InvalidValue {
2653                name: self.get_first_name(),
2654                value: value.to_string(),
2655                expected: "boolean",
2656                message: "expected 'true' or 'false'".to_string(),
2657            })?;
2658        visitor.visit_bool(b)
2659    }
2660
2661    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2662    where
2663        V: Visitor<'de>,
2664    {
2665        let value = self.get_single_value()?;
2666        let n = value
2667            .parse::<i8>()
2668            .map_err(|e| PathExtractError::InvalidValue {
2669                name: self.get_first_name(),
2670                value: value.to_string(),
2671                expected: "i8",
2672                message: e.to_string(),
2673            })?;
2674        visitor.visit_i8(n)
2675    }
2676
2677    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2678    where
2679        V: Visitor<'de>,
2680    {
2681        let value = self.get_single_value()?;
2682        let n = value
2683            .parse::<i16>()
2684            .map_err(|e| PathExtractError::InvalidValue {
2685                name: self.get_first_name(),
2686                value: value.to_string(),
2687                expected: "i16",
2688                message: e.to_string(),
2689            })?;
2690        visitor.visit_i16(n)
2691    }
2692
2693    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2694    where
2695        V: Visitor<'de>,
2696    {
2697        let value = self.get_single_value()?;
2698        let n = value
2699            .parse::<i32>()
2700            .map_err(|e| PathExtractError::InvalidValue {
2701                name: self.get_first_name(),
2702                value: value.to_string(),
2703                expected: "i32",
2704                message: e.to_string(),
2705            })?;
2706        visitor.visit_i32(n)
2707    }
2708
2709    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2710    where
2711        V: Visitor<'de>,
2712    {
2713        let value = self.get_single_value()?;
2714        let n = value
2715            .parse::<i64>()
2716            .map_err(|e| PathExtractError::InvalidValue {
2717                name: self.get_first_name(),
2718                value: value.to_string(),
2719                expected: "i64",
2720                message: e.to_string(),
2721            })?;
2722        visitor.visit_i64(n)
2723    }
2724
2725    fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2726    where
2727        V: Visitor<'de>,
2728    {
2729        let value = self.get_single_value()?;
2730        let n = value
2731            .parse::<i128>()
2732            .map_err(|e| PathExtractError::InvalidValue {
2733                name: self.get_first_name(),
2734                value: value.to_string(),
2735                expected: "i128",
2736                message: e.to_string(),
2737            })?;
2738        visitor.visit_i128(n)
2739    }
2740
2741    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2742    where
2743        V: Visitor<'de>,
2744    {
2745        let value = self.get_single_value()?;
2746        let n = value
2747            .parse::<u8>()
2748            .map_err(|e| PathExtractError::InvalidValue {
2749                name: self.get_first_name(),
2750                value: value.to_string(),
2751                expected: "u8",
2752                message: e.to_string(),
2753            })?;
2754        visitor.visit_u8(n)
2755    }
2756
2757    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2758    where
2759        V: Visitor<'de>,
2760    {
2761        let value = self.get_single_value()?;
2762        let n = value
2763            .parse::<u16>()
2764            .map_err(|e| PathExtractError::InvalidValue {
2765                name: self.get_first_name(),
2766                value: value.to_string(),
2767                expected: "u16",
2768                message: e.to_string(),
2769            })?;
2770        visitor.visit_u16(n)
2771    }
2772
2773    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2774    where
2775        V: Visitor<'de>,
2776    {
2777        let value = self.get_single_value()?;
2778        let n = value
2779            .parse::<u32>()
2780            .map_err(|e| PathExtractError::InvalidValue {
2781                name: self.get_first_name(),
2782                value: value.to_string(),
2783                expected: "u32",
2784                message: e.to_string(),
2785            })?;
2786        visitor.visit_u32(n)
2787    }
2788
2789    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2790    where
2791        V: Visitor<'de>,
2792    {
2793        let value = self.get_single_value()?;
2794        let n = value
2795            .parse::<u64>()
2796            .map_err(|e| PathExtractError::InvalidValue {
2797                name: self.get_first_name(),
2798                value: value.to_string(),
2799                expected: "u64",
2800                message: e.to_string(),
2801            })?;
2802        visitor.visit_u64(n)
2803    }
2804
2805    fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2806    where
2807        V: Visitor<'de>,
2808    {
2809        let value = self.get_single_value()?;
2810        let n = value
2811            .parse::<u128>()
2812            .map_err(|e| PathExtractError::InvalidValue {
2813                name: self.get_first_name(),
2814                value: value.to_string(),
2815                expected: "u128",
2816                message: e.to_string(),
2817            })?;
2818        visitor.visit_u128(n)
2819    }
2820
2821    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2822    where
2823        V: Visitor<'de>,
2824    {
2825        let value = self.get_single_value()?;
2826        let n = value
2827            .parse::<f32>()
2828            .map_err(|e| PathExtractError::InvalidValue {
2829                name: self.get_first_name(),
2830                value: value.to_string(),
2831                expected: "f32",
2832                message: e.to_string(),
2833            })?;
2834        visitor.visit_f32(n)
2835    }
2836
2837    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2838    where
2839        V: Visitor<'de>,
2840    {
2841        let value = self.get_single_value()?;
2842        let n = value
2843            .parse::<f64>()
2844            .map_err(|e| PathExtractError::InvalidValue {
2845                name: self.get_first_name(),
2846                value: value.to_string(),
2847                expected: "f64",
2848                message: e.to_string(),
2849            })?;
2850        visitor.visit_f64(n)
2851    }
2852
2853    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2854    where
2855        V: Visitor<'de>,
2856    {
2857        let value = self.get_single_value()?;
2858        let mut chars = value.chars();
2859        let c = chars.next().ok_or_else(|| PathExtractError::InvalidValue {
2860            name: self.get_first_name(),
2861            value: value.to_string(),
2862            expected: "char",
2863            message: "empty string".to_string(),
2864        })?;
2865        if chars.next().is_some() {
2866            return Err(PathExtractError::InvalidValue {
2867                name: self.get_first_name(),
2868                value: value.to_string(),
2869                expected: "char",
2870                message: "expected single character".to_string(),
2871            });
2872        }
2873        visitor.visit_char(c)
2874    }
2875
2876    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2877    where
2878        V: Visitor<'de>,
2879    {
2880        let value = self.get_single_value()?;
2881        visitor.visit_str(value)
2882    }
2883
2884    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2885    where
2886        V: Visitor<'de>,
2887    {
2888        let value = self.get_single_value()?;
2889        visitor.visit_string(value.to_string())
2890    }
2891
2892    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
2893    where
2894        V: Visitor<'de>,
2895    {
2896        Err(PathExtractError::DeserializeError {
2897            message: "bytes deserialization not supported for path parameters".to_string(),
2898        })
2899    }
2900
2901    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
2902    where
2903        V: Visitor<'de>,
2904    {
2905        Err(PathExtractError::DeserializeError {
2906            message: "byte_buf deserialization not supported for path parameters".to_string(),
2907        })
2908    }
2909
2910    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2911    where
2912        V: Visitor<'de>,
2913    {
2914        // Path params are always present, so always Some
2915        visitor.visit_some(self)
2916    }
2917
2918    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2919    where
2920        V: Visitor<'de>,
2921    {
2922        visitor.visit_unit()
2923    }
2924
2925    fn deserialize_unit_struct<V>(
2926        self,
2927        _name: &'static str,
2928        visitor: V,
2929    ) -> Result<V::Value, Self::Error>
2930    where
2931        V: Visitor<'de>,
2932    {
2933        visitor.visit_unit()
2934    }
2935
2936    fn deserialize_newtype_struct<V>(
2937        self,
2938        _name: &'static str,
2939        visitor: V,
2940    ) -> Result<V::Value, Self::Error>
2941    where
2942        V: Visitor<'de>,
2943    {
2944        visitor.visit_newtype_struct(self)
2945    }
2946
2947    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2948    where
2949        V: Visitor<'de>,
2950    {
2951        visitor.visit_seq(PathSeqAccess::new(self.params))
2952    }
2953
2954    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
2955    where
2956        V: Visitor<'de>,
2957    {
2958        visitor.visit_seq(PathSeqAccess::new(self.params))
2959    }
2960
2961    fn deserialize_tuple_struct<V>(
2962        self,
2963        _name: &'static str,
2964        _len: usize,
2965        visitor: V,
2966    ) -> Result<V::Value, Self::Error>
2967    where
2968        V: Visitor<'de>,
2969    {
2970        visitor.visit_seq(PathSeqAccess::new(self.params))
2971    }
2972
2973    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2974    where
2975        V: Visitor<'de>,
2976    {
2977        visitor.visit_map(PathMapAccess::new(self.params))
2978    }
2979
2980    fn deserialize_struct<V>(
2981        self,
2982        _name: &'static str,
2983        _fields: &'static [&'static str],
2984        visitor: V,
2985    ) -> Result<V::Value, Self::Error>
2986    where
2987        V: Visitor<'de>,
2988    {
2989        visitor.visit_map(PathMapAccess::new(self.params))
2990    }
2991
2992    fn deserialize_enum<V>(
2993        self,
2994        _name: &'static str,
2995        _variants: &'static [&'static str],
2996        visitor: V,
2997    ) -> Result<V::Value, Self::Error>
2998    where
2999        V: Visitor<'de>,
3000    {
3001        let value = self.get_single_value()?;
3002        visitor.visit_enum(value.into_deserializer())
3003    }
3004
3005    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3006    where
3007        V: Visitor<'de>,
3008    {
3009        self.deserialize_str(visitor)
3010    }
3011
3012    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3013    where
3014        V: Visitor<'de>,
3015    {
3016        visitor.visit_unit()
3017    }
3018}
3019
3020impl PathDeserializer<'_> {
3021    fn get_single_value(&self) -> Result<&str, PathExtractError> {
3022        self.params
3023            .0
3024            .first()
3025            .map(|(_, v)| v.as_str())
3026            .ok_or_else(|| PathExtractError::DeserializeError {
3027                message: "no path parameters available".to_string(),
3028            })
3029    }
3030
3031    fn get_first_name(&self) -> String {
3032        self.params
3033            .0
3034            .first()
3035            .map_or_else(|| "unknown".to_string(), |(n, _)| n.clone())
3036    }
3037}
3038
3039impl de::Error for PathExtractError {
3040    fn custom<T: fmt::Display>(msg: T) -> Self {
3041        PathExtractError::DeserializeError {
3042            message: msg.to_string(),
3043        }
3044    }
3045}
3046
3047/// Sequence access for deserializing tuples from path params.
3048struct PathSeqAccess<'de> {
3049    params: &'de PathParams,
3050    index: usize,
3051}
3052
3053impl<'de> PathSeqAccess<'de> {
3054    fn new(params: &'de PathParams) -> Self {
3055        Self { params, index: 0 }
3056    }
3057}
3058
3059impl<'de> SeqAccess<'de> for PathSeqAccess<'de> {
3060    type Error = PathExtractError;
3061
3062    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
3063    where
3064        T: de::DeserializeSeed<'de>,
3065    {
3066        if self.index >= self.params.0.len() {
3067            return Ok(None);
3068        }
3069
3070        let (name, value) = &self.params.0[self.index];
3071        self.index += 1;
3072
3073        seed.deserialize(PathValueDeserializer::new(name, value))
3074            .map(Some)
3075    }
3076}
3077
3078/// Map access for deserializing structs from path params.
3079struct PathMapAccess<'de> {
3080    params: &'de PathParams,
3081    index: usize,
3082}
3083
3084impl<'de> PathMapAccess<'de> {
3085    fn new(params: &'de PathParams) -> Self {
3086        Self { params, index: 0 }
3087    }
3088}
3089
3090impl<'de> MapAccess<'de> for PathMapAccess<'de> {
3091    type Error = PathExtractError;
3092
3093    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
3094    where
3095        K: de::DeserializeSeed<'de>,
3096    {
3097        if self.index >= self.params.0.len() {
3098            return Ok(None);
3099        }
3100
3101        let (name, _) = &self.params.0[self.index];
3102        seed.deserialize(name.as_str().into_deserializer())
3103            .map(Some)
3104    }
3105
3106    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
3107    where
3108        V: de::DeserializeSeed<'de>,
3109    {
3110        let (name, value) = &self.params.0[self.index];
3111        self.index += 1;
3112
3113        seed.deserialize(PathValueDeserializer::new(name, value))
3114    }
3115}
3116
/// Deserializer over one path parameter: its name (used in error reports)
/// and its raw text value.
struct PathValueDeserializer<'de> {
    /// Parameter name, copied into `InvalidValue` errors.
    name: &'de str,
    /// Raw parameter text to be parsed.
    value: &'de str,
}
3122
3123impl<'de> PathValueDeserializer<'de> {
3124    fn new(name: &'de str, value: &'de str) -> Self {
3125        Self { name, value }
3126    }
3127}
3128
3129impl<'de> Deserializer<'de> for PathValueDeserializer<'de> {
3130    type Error = PathExtractError;
3131
3132    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3133    where
3134        V: Visitor<'de>,
3135    {
3136        // Default to string
3137        visitor.visit_str(self.value)
3138    }
3139
3140    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3141    where
3142        V: Visitor<'de>,
3143    {
3144        let b = self
3145            .value
3146            .parse::<bool>()
3147            .map_err(|_| PathExtractError::InvalidValue {
3148                name: self.name.to_string(),
3149                value: self.value.to_string(),
3150                expected: "boolean",
3151                message: "expected 'true' or 'false'".to_string(),
3152            })?;
3153        visitor.visit_bool(b)
3154    }
3155
3156    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3157    where
3158        V: Visitor<'de>,
3159    {
3160        let n = self
3161            .value
3162            .parse::<i8>()
3163            .map_err(|e| PathExtractError::InvalidValue {
3164                name: self.name.to_string(),
3165                value: self.value.to_string(),
3166                expected: "i8",
3167                message: e.to_string(),
3168            })?;
3169        visitor.visit_i8(n)
3170    }
3171
3172    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3173    where
3174        V: Visitor<'de>,
3175    {
3176        let n = self
3177            .value
3178            .parse::<i16>()
3179            .map_err(|e| PathExtractError::InvalidValue {
3180                name: self.name.to_string(),
3181                value: self.value.to_string(),
3182                expected: "i16",
3183                message: e.to_string(),
3184            })?;
3185        visitor.visit_i16(n)
3186    }
3187
3188    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3189    where
3190        V: Visitor<'de>,
3191    {
3192        let n = self
3193            .value
3194            .parse::<i32>()
3195            .map_err(|e| PathExtractError::InvalidValue {
3196                name: self.name.to_string(),
3197                value: self.value.to_string(),
3198                expected: "i32",
3199                message: e.to_string(),
3200            })?;
3201        visitor.visit_i32(n)
3202    }
3203
3204    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3205    where
3206        V: Visitor<'de>,
3207    {
3208        let n = self
3209            .value
3210            .parse::<i64>()
3211            .map_err(|e| PathExtractError::InvalidValue {
3212                name: self.name.to_string(),
3213                value: self.value.to_string(),
3214                expected: "i64",
3215                message: e.to_string(),
3216            })?;
3217        visitor.visit_i64(n)
3218    }
3219
3220    fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3221    where
3222        V: Visitor<'de>,
3223    {
3224        let n = self
3225            .value
3226            .parse::<i128>()
3227            .map_err(|e| PathExtractError::InvalidValue {
3228                name: self.name.to_string(),
3229                value: self.value.to_string(),
3230                expected: "i128",
3231                message: e.to_string(),
3232            })?;
3233        visitor.visit_i128(n)
3234    }
3235
3236    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3237    where
3238        V: Visitor<'de>,
3239    {
3240        let n = self
3241            .value
3242            .parse::<u8>()
3243            .map_err(|e| PathExtractError::InvalidValue {
3244                name: self.name.to_string(),
3245                value: self.value.to_string(),
3246                expected: "u8",
3247                message: e.to_string(),
3248            })?;
3249        visitor.visit_u8(n)
3250    }
3251
3252    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3253    where
3254        V: Visitor<'de>,
3255    {
3256        let n = self
3257            .value
3258            .parse::<u16>()
3259            .map_err(|e| PathExtractError::InvalidValue {
3260                name: self.name.to_string(),
3261                value: self.value.to_string(),
3262                expected: "u16",
3263                message: e.to_string(),
3264            })?;
3265        visitor.visit_u16(n)
3266    }
3267
3268    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3269    where
3270        V: Visitor<'de>,
3271    {
3272        let n = self
3273            .value
3274            .parse::<u32>()
3275            .map_err(|e| PathExtractError::InvalidValue {
3276                name: self.name.to_string(),
3277                value: self.value.to_string(),
3278                expected: "u32",
3279                message: e.to_string(),
3280            })?;
3281        visitor.visit_u32(n)
3282    }
3283
3284    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3285    where
3286        V: Visitor<'de>,
3287    {
3288        let n = self
3289            .value
3290            .parse::<u64>()
3291            .map_err(|e| PathExtractError::InvalidValue {
3292                name: self.name.to_string(),
3293                value: self.value.to_string(),
3294                expected: "u64",
3295                message: e.to_string(),
3296            })?;
3297        visitor.visit_u64(n)
3298    }
3299
3300    fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3301    where
3302        V: Visitor<'de>,
3303    {
3304        let n = self
3305            .value
3306            .parse::<u128>()
3307            .map_err(|e| PathExtractError::InvalidValue {
3308                name: self.name.to_string(),
3309                value: self.value.to_string(),
3310                expected: "u128",
3311                message: e.to_string(),
3312            })?;
3313        visitor.visit_u128(n)
3314    }
3315
3316    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3317    where
3318        V: Visitor<'de>,
3319    {
3320        let n = self
3321            .value
3322            .parse::<f32>()
3323            .map_err(|e| PathExtractError::InvalidValue {
3324                name: self.name.to_string(),
3325                value: self.value.to_string(),
3326                expected: "f32",
3327                message: e.to_string(),
3328            })?;
3329        visitor.visit_f32(n)
3330    }
3331
3332    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3333    where
3334        V: Visitor<'de>,
3335    {
3336        let n = self
3337            .value
3338            .parse::<f64>()
3339            .map_err(|e| PathExtractError::InvalidValue {
3340                name: self.name.to_string(),
3341                value: self.value.to_string(),
3342                expected: "f64",
3343                message: e.to_string(),
3344            })?;
3345        visitor.visit_f64(n)
3346    }
3347
3348    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3349    where
3350        V: Visitor<'de>,
3351    {
3352        let mut chars = self.value.chars();
3353        let c = chars.next().ok_or_else(|| PathExtractError::InvalidValue {
3354            name: self.name.to_string(),
3355            value: self.value.to_string(),
3356            expected: "char",
3357            message: "empty string".to_string(),
3358        })?;
3359        if chars.next().is_some() {
3360            return Err(PathExtractError::InvalidValue {
3361                name: self.name.to_string(),
3362                value: self.value.to_string(),
3363                expected: "char",
3364                message: "expected single character".to_string(),
3365            });
3366        }
3367        visitor.visit_char(c)
3368    }
3369
3370    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3371    where
3372        V: Visitor<'de>,
3373    {
3374        visitor.visit_str(self.value)
3375    }
3376
3377    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3378    where
3379        V: Visitor<'de>,
3380    {
3381        visitor.visit_string(self.value.to_string())
3382    }
3383
3384    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3385    where
3386        V: Visitor<'de>,
3387    {
3388        Err(PathExtractError::DeserializeError {
3389            message: "bytes deserialization not supported for path parameters".to_string(),
3390        })
3391    }
3392
3393    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3394    where
3395        V: Visitor<'de>,
3396    {
3397        Err(PathExtractError::DeserializeError {
3398            message: "byte_buf deserialization not supported for path parameters".to_string(),
3399        })
3400    }
3401
3402    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3403    where
3404        V: Visitor<'de>,
3405    {
3406        visitor.visit_some(self)
3407    }
3408
3409    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3410    where
3411        V: Visitor<'de>,
3412    {
3413        visitor.visit_unit()
3414    }
3415
3416    fn deserialize_unit_struct<V>(
3417        self,
3418        _name: &'static str,
3419        visitor: V,
3420    ) -> Result<V::Value, Self::Error>
3421    where
3422        V: Visitor<'de>,
3423    {
3424        visitor.visit_unit()
3425    }
3426
3427    fn deserialize_newtype_struct<V>(
3428        self,
3429        _name: &'static str,
3430        visitor: V,
3431    ) -> Result<V::Value, Self::Error>
3432    where
3433        V: Visitor<'de>,
3434    {
3435        visitor.visit_newtype_struct(self)
3436    }
3437
3438    fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3439    where
3440        V: Visitor<'de>,
3441    {
3442        Err(PathExtractError::DeserializeError {
3443            message: "sequence deserialization not supported for single path parameter".to_string(),
3444        })
3445    }
3446
3447    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
3448    where
3449        V: Visitor<'de>,
3450    {
3451        Err(PathExtractError::DeserializeError {
3452            message: "tuple deserialization not supported for single path parameter".to_string(),
3453        })
3454    }
3455
3456    fn deserialize_tuple_struct<V>(
3457        self,
3458        _name: &'static str,
3459        _len: usize,
3460        _visitor: V,
3461    ) -> Result<V::Value, Self::Error>
3462    where
3463        V: Visitor<'de>,
3464    {
3465        Err(PathExtractError::DeserializeError {
3466            message: "tuple struct deserialization not supported for single path parameter"
3467                .to_string(),
3468        })
3469    }
3470
3471    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3472    where
3473        V: Visitor<'de>,
3474    {
3475        Err(PathExtractError::DeserializeError {
3476            message: "map deserialization not supported for single path parameter".to_string(),
3477        })
3478    }
3479
3480    fn deserialize_struct<V>(
3481        self,
3482        _name: &'static str,
3483        _fields: &'static [&'static str],
3484        _visitor: V,
3485    ) -> Result<V::Value, Self::Error>
3486    where
3487        V: Visitor<'de>,
3488    {
3489        Err(PathExtractError::DeserializeError {
3490            message: "struct deserialization not supported for single path parameter".to_string(),
3491        })
3492    }
3493
3494    fn deserialize_enum<V>(
3495        self,
3496        _name: &'static str,
3497        _variants: &'static [&'static str],
3498        visitor: V,
3499    ) -> Result<V::Value, Self::Error>
3500    where
3501        V: Visitor<'de>,
3502    {
3503        visitor.visit_enum(self.value.into_deserializer())
3504    }
3505
3506    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3507    where
3508        V: Visitor<'de>,
3509    {
3510        visitor.visit_str(self.value)
3511    }
3512
3513    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3514    where
3515        V: Visitor<'de>,
3516    {
3517        visitor.visit_unit()
3518    }
3519}
3520
3521// ============================================================================
3522// Query String Extractor
3523// ============================================================================
3524
/// Extractor for typed query-string parameters.
///
/// The query string is deserialized with serde into `T`, so `T` must
/// implement `DeserializeOwned`.
///
/// # Features
///
/// - **Optional fields**: declare the field as `Option<T>`
/// - **Multi-value**: declare the field as `Vec<T>` for keys that repeat
/// - **Default values**: annotate the field with `#[serde(default)]`
/// - **Percent-decoding**: keys and values are percent-decoded while parsing
///
/// # Example
///
/// ```ignore
/// use fastapi_core::Query;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct SearchParams {
///     q: String,                      // Required
///     page: Option<i32>,              // Optional
///     #[serde(default)]
///     limit: i32,                     // Default (0)
///     tags: Vec<String>,              // Multi-value: ?tags=a&tags=b
/// }
///
/// #[get("/search")]
/// async fn search(cx: &Cx, params: Query<SearchParams>) -> impl IntoResponse {
///     // Access the inner value via params.0 or *params
///     let query = &params.q;
///     // ...
/// }
/// ```
///
/// # Error Handling
///
/// Extraction fails with HTTP 422 (Unprocessable Entity) when:
/// - a required field is missing
/// - a value cannot be converted to the field type (e.g., "abc" to i32)
/// - serde deserialization fails for any other reason
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
3569
3570impl<T> Query<T> {
3571    /// Create a new Query extractor with the given value.
3572    pub fn new(value: T) -> Self {
3573        Self(value)
3574    }
3575
3576    /// Consume the extractor and return the inner value.
3577    pub fn into_inner(self) -> T {
3578        self.0
3579    }
3580}
3581
3582impl<T> Deref for Query<T> {
3583    type Target = T;
3584
3585    fn deref(&self) -> &Self::Target {
3586        &self.0
3587    }
3588}
3589
3590impl<T> DerefMut for Query<T> {
3591    fn deref_mut(&mut self) -> &mut Self::Target {
3592        &mut self.0
3593    }
3594}
3595
/// Reasons extracting typed data from the query string can fail.
#[derive(Debug)]
pub enum QueryExtractError {
    /// A required parameter was not present.
    MissingParam {
        /// Name of the absent parameter.
        name: String,
    },
    /// A parameter was present but its text could not be converted.
    InvalidValue {
        /// Name of the offending parameter.
        name: String,
        /// The raw text that failed to convert.
        value: String,
        /// Label of the type that was expected (e.g. `"i8"`).
        expected: &'static str,
        /// Details reported by the failed conversion.
        message: String,
    },
    /// Serde reported a failure of its own.
    DeserializeError {
        /// The message serde supplied.
        message: String,
    },
}
3611
3612impl fmt::Display for QueryExtractError {
3613    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3614        match self {
3615            Self::MissingParam { name } => {
3616                write!(f, "Missing required query parameter: {}", name)
3617            }
3618            Self::InvalidValue {
3619                name,
3620                value,
3621                expected,
3622                message,
3623            } => {
3624                write!(
3625                    f,
3626                    "Invalid value '{}' for query parameter '{}' (expected {}): {}",
3627                    value, name, expected, message
3628                )
3629            }
3630            Self::DeserializeError { message } => {
3631                write!(f, "Query deserialization error: {}", message)
3632            }
3633        }
3634    }
3635}
3636
3637impl std::error::Error for QueryExtractError {}
3638
3639impl de::Error for QueryExtractError {
3640    fn custom<T: fmt::Display>(msg: T) -> Self {
3641        Self::DeserializeError {
3642            message: msg.to_string(),
3643        }
3644    }
3645}
3646
3647impl IntoResponse for QueryExtractError {
3648    fn into_response(self) -> crate::response::Response {
3649        match self {
3650            Self::MissingParam { name } => ValidationErrors::single(
3651                ValidationError::missing(crate::error::loc::query(&name))
3652                    .with_msg("Query parameter is required"),
3653            )
3654            .into_response(),
3655            Self::InvalidValue {
3656                name,
3657                value,
3658                expected,
3659                message,
3660            } => ValidationErrors::single(
3661                ValidationError::type_error(crate::error::loc::query(&name), &expected)
3662                    .with_msg(format!("Expected {expected}: {message}"))
3663                    .with_input(serde_json::Value::String(value)),
3664            )
3665            .into_response(),
3666            Self::DeserializeError { message } => ValidationErrors::single(
3667                ValidationError::new(
3668                    crate::error::error_types::VALUE_ERROR,
3669                    vec![crate::error::LocItem::field("query")],
3670                )
3671                .with_msg(message),
3672            )
3673            .into_response(),
3674        }
3675    }
3676}
3677
/// Decoded query-string parameters kept for extraction.
///
/// Like `PathParams`, but the same key may occur more than once.
/// The framework stores this in the request extensions.
#[derive(Debug, Clone, Default)]
pub struct QueryParams {
    /// Key/value pairs in their original order; a `Vec` (not a map) so that
    /// repeated keys are preserved.
    params: Vec<(String, String)>,
}
3687
3688impl QueryParams {
3689    /// Create empty query params.
3690    pub fn new() -> Self {
3691        Self { params: Vec::new() }
3692    }
3693
3694    /// Create from a vector of key-value pairs.
3695    pub fn from_pairs(pairs: Vec<(String, String)>) -> Self {
3696        Self { params: pairs }
3697    }
3698
3699    /// Parse from a query string (without leading '?').
3700    pub fn parse(query: &str) -> Self {
3701        let pairs: Vec<(String, String)> = query
3702            .split('&')
3703            .filter(|s| !s.is_empty())
3704            .map(|pair| {
3705                if let Some(eq_pos) = pair.find('=') {
3706                    let key = &pair[..eq_pos];
3707                    let value = &pair[eq_pos + 1..];
3708                    (
3709                        percent_decode(key).into_owned(),
3710                        percent_decode(value).into_owned(),
3711                    )
3712                } else {
3713                    // Key without value: "flag" -> ("flag", "")
3714                    (percent_decode(pair).into_owned(), String::new())
3715                }
3716            })
3717            .collect();
3718        Self { params: pairs }
3719    }
3720
3721    /// Get the first value for a key.
3722    pub fn get(&self, key: &str) -> Option<&str> {
3723        self.params
3724            .iter()
3725            .find(|(k, _)| k == key)
3726            .map(|(_, v)| v.as_str())
3727    }
3728
3729    /// Get all values for a key.
3730    pub fn get_all(&self, key: &str) -> Vec<&str> {
3731        self.params
3732            .iter()
3733            .filter(|(k, _)| k == key)
3734            .map(|(_, v)| v.as_str())
3735            .collect()
3736    }
3737
3738    /// Check if a key exists.
3739    pub fn contains(&self, key: &str) -> bool {
3740        self.params.iter().any(|(k, _)| k == key)
3741    }
3742
3743    /// Get all key-value pairs.
3744    pub fn pairs(&self) -> &[(String, String)] {
3745        &self.params
3746    }
3747
3748    /// Get iterator over unique keys.
3749    pub fn keys(&self) -> impl Iterator<Item = &str> {
3750        let mut seen = std::collections::HashSet::new();
3751        self.params.iter().filter_map(move |(k, _)| {
3752            if seen.insert(k.as_str()) {
3753                Some(k.as_str())
3754            } else {
3755                None
3756            }
3757        })
3758    }
3759
3760    /// Return the number of parameters (including duplicates).
3761    pub fn len(&self) -> usize {
3762        self.params.len()
3763    }
3764
3765    /// Check if empty.
3766    pub fn is_empty(&self) -> bool {
3767        self.params.is_empty()
3768    }
3769}
3770
3771/// Percent-decode a string.
3772///
3773/// Returns a `Cow::Borrowed` if no decoding was needed,
3774/// or `Cow::Owned` if percent sequences were decoded.
3775fn percent_decode(s: &str) -> std::borrow::Cow<'_, str> {
3776    use std::borrow::Cow;
3777
3778    // Fast path: no encoding
3779    if !s.contains('%') && !s.contains('+') {
3780        return Cow::Borrowed(s);
3781    }
3782
3783    let mut result = Vec::with_capacity(s.len());
3784    let bytes = s.as_bytes();
3785    let mut i = 0;
3786
3787    while i < bytes.len() {
3788        match bytes[i] {
3789            b'%' if i + 2 < bytes.len() => {
3790                // Try to decode hex pair
3791                if let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
3792                    result.push(hi << 4 | lo);
3793                    i += 3;
3794                } else {
3795                    // Invalid hex, keep as-is
3796                    result.push(b'%');
3797                    i += 1;
3798                }
3799            }
3800            b'+' => {
3801                // Plus as space (application/x-www-form-urlencoded)
3802                result.push(b' ');
3803                i += 1;
3804            }
3805            b => {
3806                result.push(b);
3807                i += 1;
3808            }
3809        }
3810    }
3811
3812    Cow::Owned(String::from_utf8_lossy(&result).into_owned())
3813}
3814
/// Maps an ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// Returns `None` for any other byte.
fn hex_digit(b: u8) -> Option<u8> {
    // `to_digit(16)` yields at most 15, so narrowing to `u8` is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
3824
3825impl<T: DeserializeOwned> FromRequest for Query<T> {
3826    type Error = QueryExtractError;
3827
3828    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
3829        // Get or parse query params
3830        let params = match req.get_extension::<QueryParams>() {
3831            Some(p) => p.clone(),
3832            None => {
3833                // Parse from request query string
3834                let query_str = req.query().unwrap_or("");
3835                QueryParams::parse(query_str)
3836            }
3837        };
3838
3839        // Deserialize using our custom deserializer
3840        let value = T::deserialize(QueryDeserializer::new(&params))?;
3841
3842        Ok(Query(value))
3843    }
3844}
3845
3846// ============================================================================
3847// Query String Deserializer
3848// ============================================================================
3849
3850/// Custom serde deserializer for query string parameters.
3851///
3852/// Handles:
3853/// - Single values: Deserializes from first matching parameter
3854/// - Sequences (Vec): Collects all values for a parameter
3855/// - Maps/Structs: Deserializes parameters by name
3856/// - Options: Missing parameters become None
3857struct QueryDeserializer<'de> {
3858    params: &'de QueryParams,
3859}
3860
3861impl<'de> QueryDeserializer<'de> {
3862    fn new(params: &'de QueryParams) -> Self {
3863        Self { params }
3864    }
3865}
3866
3867impl<'de> Deserializer<'de> for QueryDeserializer<'de> {
3868    type Error = QueryExtractError;
3869
3870    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3871    where
3872        V: Visitor<'de>,
3873    {
3874        // Default: try as map (struct)
3875        self.deserialize_map(visitor)
3876    }
3877
3878    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3879    where
3880        V: Visitor<'de>,
3881    {
3882        let value = self
3883            .params
3884            .pairs()
3885            .first()
3886            .map(|(_, v)| v.as_str())
3887            .ok_or_else(|| QueryExtractError::MissingParam {
3888                name: "value".to_string(),
3889            })?;
3890
3891        let b = parse_bool(value).map_err(|msg| QueryExtractError::InvalidValue {
3892            name: "value".to_string(),
3893            value: value.to_string(),
3894            expected: "bool",
3895            message: msg,
3896        })?;
3897        visitor.visit_bool(b)
3898    }
3899
3900    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3901    where
3902        V: Visitor<'de>,
3903    {
3904        let value = self.get_single_value()?;
3905        let n = value
3906            .parse::<i8>()
3907            .map_err(|e| QueryExtractError::InvalidValue {
3908                name: "value".to_string(),
3909                value: value.to_string(),
3910                expected: "i8",
3911                message: e.to_string(),
3912            })?;
3913        visitor.visit_i8(n)
3914    }
3915
3916    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3917    where
3918        V: Visitor<'de>,
3919    {
3920        let value = self.get_single_value()?;
3921        let n = value
3922            .parse::<i16>()
3923            .map_err(|e| QueryExtractError::InvalidValue {
3924                name: "value".to_string(),
3925                value: value.to_string(),
3926                expected: "i16",
3927                message: e.to_string(),
3928            })?;
3929        visitor.visit_i16(n)
3930    }
3931
3932    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3933    where
3934        V: Visitor<'de>,
3935    {
3936        let value = self.get_single_value()?;
3937        let n = value
3938            .parse::<i32>()
3939            .map_err(|e| QueryExtractError::InvalidValue {
3940                name: "value".to_string(),
3941                value: value.to_string(),
3942                expected: "i32",
3943                message: e.to_string(),
3944            })?;
3945        visitor.visit_i32(n)
3946    }
3947
3948    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3949    where
3950        V: Visitor<'de>,
3951    {
3952        let value = self.get_single_value()?;
3953        let n = value
3954            .parse::<i64>()
3955            .map_err(|e| QueryExtractError::InvalidValue {
3956                name: "value".to_string(),
3957                value: value.to_string(),
3958                expected: "i64",
3959                message: e.to_string(),
3960            })?;
3961        visitor.visit_i64(n)
3962    }
3963
3964    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3965    where
3966        V: Visitor<'de>,
3967    {
3968        let value = self.get_single_value()?;
3969        let n = value
3970            .parse::<u8>()
3971            .map_err(|e| QueryExtractError::InvalidValue {
3972                name: "value".to_string(),
3973                value: value.to_string(),
3974                expected: "u8",
3975                message: e.to_string(),
3976            })?;
3977        visitor.visit_u8(n)
3978    }
3979
3980    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3981    where
3982        V: Visitor<'de>,
3983    {
3984        let value = self.get_single_value()?;
3985        let n = value
3986            .parse::<u16>()
3987            .map_err(|e| QueryExtractError::InvalidValue {
3988                name: "value".to_string(),
3989                value: value.to_string(),
3990                expected: "u16",
3991                message: e.to_string(),
3992            })?;
3993        visitor.visit_u16(n)
3994    }
3995
3996    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3997    where
3998        V: Visitor<'de>,
3999    {
4000        let value = self.get_single_value()?;
4001        let n = value
4002            .parse::<u32>()
4003            .map_err(|e| QueryExtractError::InvalidValue {
4004                name: "value".to_string(),
4005                value: value.to_string(),
4006                expected: "u32",
4007                message: e.to_string(),
4008            })?;
4009        visitor.visit_u32(n)
4010    }
4011
4012    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4013    where
4014        V: Visitor<'de>,
4015    {
4016        let value = self.get_single_value()?;
4017        let n = value
4018            .parse::<u64>()
4019            .map_err(|e| QueryExtractError::InvalidValue {
4020                name: "value".to_string(),
4021                value: value.to_string(),
4022                expected: "u64",
4023                message: e.to_string(),
4024            })?;
4025        visitor.visit_u64(n)
4026    }
4027
4028    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4029    where
4030        V: Visitor<'de>,
4031    {
4032        let value = self.get_single_value()?;
4033        let n = value
4034            .parse::<f32>()
4035            .map_err(|e| QueryExtractError::InvalidValue {
4036                name: "value".to_string(),
4037                value: value.to_string(),
4038                expected: "f32",
4039                message: e.to_string(),
4040            })?;
4041        visitor.visit_f32(n)
4042    }
4043
4044    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4045    where
4046        V: Visitor<'de>,
4047    {
4048        let value = self.get_single_value()?;
4049        let n = value
4050            .parse::<f64>()
4051            .map_err(|e| QueryExtractError::InvalidValue {
4052                name: "value".to_string(),
4053                value: value.to_string(),
4054                expected: "f64",
4055                message: e.to_string(),
4056            })?;
4057        visitor.visit_f64(n)
4058    }
4059
4060    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4061    where
4062        V: Visitor<'de>,
4063    {
4064        let value = self.get_single_value()?;
4065        let mut chars = value.chars();
4066        match (chars.next(), chars.next()) {
4067            (Some(c), None) => visitor.visit_char(c),
4068            _ => Err(QueryExtractError::InvalidValue {
4069                name: "value".to_string(),
4070                value: value.to_string(),
4071                expected: "char",
4072                message: "expected single character".to_string(),
4073            }),
4074        }
4075    }
4076
4077    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4078    where
4079        V: Visitor<'de>,
4080    {
4081        let value = self.get_single_value()?;
4082        visitor.visit_str(value)
4083    }
4084
4085    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4086    where
4087        V: Visitor<'de>,
4088    {
4089        let value = self.get_single_value()?;
4090        visitor.visit_string(value.to_owned())
4091    }
4092
4093    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4094    where
4095        V: Visitor<'de>,
4096    {
4097        let value = self.get_single_value()?;
4098        visitor.visit_bytes(value.as_bytes())
4099    }
4100
4101    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4102    where
4103        V: Visitor<'de>,
4104    {
4105        let value = self.get_single_value()?;
4106        visitor.visit_byte_buf(value.as_bytes().to_vec())
4107    }
4108
4109    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4110    where
4111        V: Visitor<'de>,
4112    {
4113        // For top-level option, check if we have any params
4114        if self.params.is_empty() {
4115            visitor.visit_none()
4116        } else {
4117            visitor.visit_some(self)
4118        }
4119    }
4120
4121    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4122    where
4123        V: Visitor<'de>,
4124    {
4125        visitor.visit_unit()
4126    }
4127
4128    fn deserialize_unit_struct<V>(
4129        self,
4130        _name: &'static str,
4131        visitor: V,
4132    ) -> Result<V::Value, Self::Error>
4133    where
4134        V: Visitor<'de>,
4135    {
4136        visitor.visit_unit()
4137    }
4138
4139    fn deserialize_newtype_struct<V>(
4140        self,
4141        _name: &'static str,
4142        visitor: V,
4143    ) -> Result<V::Value, Self::Error>
4144    where
4145        V: Visitor<'de>,
4146    {
4147        visitor.visit_newtype_struct(self)
4148    }
4149
4150    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4151    where
4152        V: Visitor<'de>,
4153    {
4154        // For a sequence at the top level, use all values
4155        let values: Vec<&str> = self
4156            .params
4157            .pairs()
4158            .iter()
4159            .map(|(_, v)| v.as_str())
4160            .collect();
4161        visitor.visit_seq(QuerySeqAccess::new(values))
4162    }
4163
4164    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
4165    where
4166        V: Visitor<'de>,
4167    {
4168        // For a tuple, use values in order
4169        let values: Vec<&str> = self
4170            .params
4171            .pairs()
4172            .iter()
4173            .map(|(_, v)| v.as_str())
4174            .collect();
4175        visitor.visit_seq(QuerySeqAccess::new(values))
4176    }
4177
4178    fn deserialize_tuple_struct<V>(
4179        self,
4180        _name: &'static str,
4181        _len: usize,
4182        visitor: V,
4183    ) -> Result<V::Value, Self::Error>
4184    where
4185        V: Visitor<'de>,
4186    {
4187        let values: Vec<&str> = self
4188            .params
4189            .pairs()
4190            .iter()
4191            .map(|(_, v)| v.as_str())
4192            .collect();
4193        visitor.visit_seq(QuerySeqAccess::new(values))
4194    }
4195
4196    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4197    where
4198        V: Visitor<'de>,
4199    {
4200        visitor.visit_map(QueryMapAccess::new(self.params))
4201    }
4202
4203    fn deserialize_struct<V>(
4204        self,
4205        _name: &'static str,
4206        _fields: &'static [&'static str],
4207        visitor: V,
4208    ) -> Result<V::Value, Self::Error>
4209    where
4210        V: Visitor<'de>,
4211    {
4212        self.deserialize_map(visitor)
4213    }
4214
4215    fn deserialize_enum<V>(
4216        self,
4217        _name: &'static str,
4218        _variants: &'static [&'static str],
4219        visitor: V,
4220    ) -> Result<V::Value, Self::Error>
4221    where
4222        V: Visitor<'de>,
4223    {
4224        // For enum, use the first value as a unit variant name
4225        let value = self.get_single_value()?;
4226        visitor.visit_enum(value.into_deserializer())
4227    }
4228
4229    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4230    where
4231        V: Visitor<'de>,
4232    {
4233        let value = self.get_single_value()?;
4234        visitor.visit_str(value)
4235    }
4236
4237    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4238    where
4239        V: Visitor<'de>,
4240    {
4241        visitor.visit_unit()
4242    }
4243}
4244
4245impl<'de> QueryDeserializer<'de> {
4246    fn get_single_value(&self) -> Result<&'de str, QueryExtractError> {
4247        self.params
4248            .pairs()
4249            .first()
4250            .map(|(_, v)| v.as_str())
4251            .ok_or_else(|| QueryExtractError::MissingParam {
4252                name: "value".to_string(),
4253            })
4254    }
4255}
4256
/// Helper to parse a boolean from a query string value.
///
/// Matching is case-insensitive. `true`, `1`, `yes` and `on` parse as `true`;
/// `false`, `0`, `no`, `off` and the empty string parse as `false`.
///
/// # Errors
///
/// Returns a human-readable message quoting the input when it matches none of
/// the accepted spellings.
fn parse_bool(s: &str) -> Result<bool, String> {
    const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
    const FALSY: [&str; 5] = ["false", "0", "no", "off", ""];

    // Every accepted keyword is ASCII, so an ASCII case-insensitive comparison
    // is enough and avoids allocating a lowercased copy of the input.
    let matches_any = |words: &[&str]| words.iter().any(|word| s.eq_ignore_ascii_case(word));

    if matches_any(&TRUTHY) {
        Ok(true)
    } else if matches_any(&FALSY) {
        Ok(false)
    } else {
        Err(format!("cannot parse '{}' as boolean", s))
    }
}
4265
/// Sequence access for deserializing arrays/vectors from query params.
struct QuerySeqAccess<'de> {
    /// Raw string values handed out one element at a time.
    values: Vec<&'de str>,
    /// Position of the next value to hand out.
    index: usize,
}
4271
4272impl<'de> QuerySeqAccess<'de> {
4273    fn new(values: Vec<&'de str>) -> Self {
4274        Self { values, index: 0 }
4275    }
4276}
4277
4278impl<'de> SeqAccess<'de> for QuerySeqAccess<'de> {
4279    type Error = QueryExtractError;
4280
4281    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
4282    where
4283        T: de::DeserializeSeed<'de>,
4284    {
4285        if self.index >= self.values.len() {
4286            return Ok(None);
4287        }
4288
4289        let value = self.values[self.index];
4290        self.index += 1;
4291
4292        seed.deserialize(QueryValueDeserializer::new(value, None))
4293            .map(Some)
4294    }
4295
4296    fn size_hint(&self) -> Option<usize> {
4297        Some(self.values.len() - self.index)
4298    }
4299}
4300
4301/// Map access for deserializing structs from query params.
4302struct QueryMapAccess<'de> {
4303    params: &'de QueryParams,
4304    keys: Vec<&'de str>,
4305    index: usize,
4306}
4307
4308impl<'de> QueryMapAccess<'de> {
4309    fn new(params: &'de QueryParams) -> Self {
4310        let keys: Vec<&str> = params.keys().collect();
4311        Self {
4312            params,
4313            keys,
4314            index: 0,
4315        }
4316    }
4317}
4318
4319impl<'de> MapAccess<'de> for QueryMapAccess<'de> {
4320    type Error = QueryExtractError;
4321
4322    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
4323    where
4324        K: de::DeserializeSeed<'de>,
4325    {
4326        if self.index >= self.keys.len() {
4327            return Ok(None);
4328        }
4329
4330        let key = self.keys[self.index];
4331        seed.deserialize(key.into_deserializer()).map(Some)
4332    }
4333
4334    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
4335    where
4336        V: de::DeserializeSeed<'de>,
4337    {
4338        let key = self.keys[self.index];
4339        self.index += 1;
4340
4341        // Get all values for this key to support Vec<T>
4342        let values = self.params.get_all(key);
4343
4344        seed.deserialize(QueryFieldDeserializer::new(key, values))
4345    }
4346}
4347
/// Deserializer for a single query parameter value.
struct QueryValueDeserializer<'de> {
    /// Raw string value to deserialize.
    value: &'de str,
    /// Parameter name used in error reports, when one is known.
    name: Option<&'de str>,
}
4353
4354impl<'de> QueryValueDeserializer<'de> {
4355    fn new(value: &'de str, name: Option<&'de str>) -> Self {
4356        Self { value, name }
4357    }
4358
4359    fn field_name(&self) -> String {
4360        self.name.unwrap_or("value").to_string()
4361    }
4362}
4363
4364impl<'de> Deserializer<'de> for QueryValueDeserializer<'de> {
4365    type Error = QueryExtractError;
4366
4367    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4368    where
4369        V: Visitor<'de>,
4370    {
4371        visitor.visit_str(self.value)
4372    }
4373
4374    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4375    where
4376        V: Visitor<'de>,
4377    {
4378        let b = parse_bool(self.value).map_err(|msg| QueryExtractError::InvalidValue {
4379            name: self.field_name(),
4380            value: self.value.to_string(),
4381            expected: "bool",
4382            message: msg,
4383        })?;
4384        visitor.visit_bool(b)
4385    }
4386
4387    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4388    where
4389        V: Visitor<'de>,
4390    {
4391        let n = self
4392            .value
4393            .parse::<i8>()
4394            .map_err(|e| QueryExtractError::InvalidValue {
4395                name: self.field_name(),
4396                value: self.value.to_string(),
4397                expected: "i8",
4398                message: e.to_string(),
4399            })?;
4400        visitor.visit_i8(n)
4401    }
4402
4403    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4404    where
4405        V: Visitor<'de>,
4406    {
4407        let n = self
4408            .value
4409            .parse::<i16>()
4410            .map_err(|e| QueryExtractError::InvalidValue {
4411                name: self.field_name(),
4412                value: self.value.to_string(),
4413                expected: "i16",
4414                message: e.to_string(),
4415            })?;
4416        visitor.visit_i16(n)
4417    }
4418
4419    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4420    where
4421        V: Visitor<'de>,
4422    {
4423        let n = self
4424            .value
4425            .parse::<i32>()
4426            .map_err(|e| QueryExtractError::InvalidValue {
4427                name: self.field_name(),
4428                value: self.value.to_string(),
4429                expected: "i32",
4430                message: e.to_string(),
4431            })?;
4432        visitor.visit_i32(n)
4433    }
4434
4435    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4436    where
4437        V: Visitor<'de>,
4438    {
4439        let n = self
4440            .value
4441            .parse::<i64>()
4442            .map_err(|e| QueryExtractError::InvalidValue {
4443                name: self.field_name(),
4444                value: self.value.to_string(),
4445                expected: "i64",
4446                message: e.to_string(),
4447            })?;
4448        visitor.visit_i64(n)
4449    }
4450
4451    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4452    where
4453        V: Visitor<'de>,
4454    {
4455        let n = self
4456            .value
4457            .parse::<u8>()
4458            .map_err(|e| QueryExtractError::InvalidValue {
4459                name: self.field_name(),
4460                value: self.value.to_string(),
4461                expected: "u8",
4462                message: e.to_string(),
4463            })?;
4464        visitor.visit_u8(n)
4465    }
4466
4467    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4468    where
4469        V: Visitor<'de>,
4470    {
4471        let n = self
4472            .value
4473            .parse::<u16>()
4474            .map_err(|e| QueryExtractError::InvalidValue {
4475                name: self.field_name(),
4476                value: self.value.to_string(),
4477                expected: "u16",
4478                message: e.to_string(),
4479            })?;
4480        visitor.visit_u16(n)
4481    }
4482
4483    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4484    where
4485        V: Visitor<'de>,
4486    {
4487        let n = self
4488            .value
4489            .parse::<u32>()
4490            .map_err(|e| QueryExtractError::InvalidValue {
4491                name: self.field_name(),
4492                value: self.value.to_string(),
4493                expected: "u32",
4494                message: e.to_string(),
4495            })?;
4496        visitor.visit_u32(n)
4497    }
4498
4499    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4500    where
4501        V: Visitor<'de>,
4502    {
4503        let n = self
4504            .value
4505            .parse::<u64>()
4506            .map_err(|e| QueryExtractError::InvalidValue {
4507                name: self.field_name(),
4508                value: self.value.to_string(),
4509                expected: "u64",
4510                message: e.to_string(),
4511            })?;
4512        visitor.visit_u64(n)
4513    }
4514
4515    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4516    where
4517        V: Visitor<'de>,
4518    {
4519        let n = self
4520            .value
4521            .parse::<f32>()
4522            .map_err(|e| QueryExtractError::InvalidValue {
4523                name: self.field_name(),
4524                value: self.value.to_string(),
4525                expected: "f32",
4526                message: e.to_string(),
4527            })?;
4528        visitor.visit_f32(n)
4529    }
4530
4531    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4532    where
4533        V: Visitor<'de>,
4534    {
4535        let n = self
4536            .value
4537            .parse::<f64>()
4538            .map_err(|e| QueryExtractError::InvalidValue {
4539                name: self.field_name(),
4540                value: self.value.to_string(),
4541                expected: "f64",
4542                message: e.to_string(),
4543            })?;
4544        visitor.visit_f64(n)
4545    }
4546
4547    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4548    where
4549        V: Visitor<'de>,
4550    {
4551        let mut chars = self.value.chars();
4552        match (chars.next(), chars.next()) {
4553            (Some(c), None) => visitor.visit_char(c),
4554            _ => Err(QueryExtractError::InvalidValue {
4555                name: self.field_name(),
4556                value: self.value.to_string(),
4557                expected: "char",
4558                message: "expected single character".to_string(),
4559            }),
4560        }
4561    }
4562
4563    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4564    where
4565        V: Visitor<'de>,
4566    {
4567        visitor.visit_str(self.value)
4568    }
4569
4570    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4571    where
4572        V: Visitor<'de>,
4573    {
4574        visitor.visit_string(self.value.to_owned())
4575    }
4576
4577    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4578    where
4579        V: Visitor<'de>,
4580    {
4581        visitor.visit_bytes(self.value.as_bytes())
4582    }
4583
4584    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4585    where
4586        V: Visitor<'de>,
4587    {
4588        visitor.visit_byte_buf(self.value.as_bytes().to_vec())
4589    }
4590
4591    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4592    where
4593        V: Visitor<'de>,
4594    {
4595        if self.value.is_empty() {
4596            visitor.visit_none()
4597        } else {
4598            visitor.visit_some(self)
4599        }
4600    }
4601
4602    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4603    where
4604        V: Visitor<'de>,
4605    {
4606        visitor.visit_unit()
4607    }
4608
4609    fn deserialize_unit_struct<V>(
4610        self,
4611        _name: &'static str,
4612        visitor: V,
4613    ) -> Result<V::Value, Self::Error>
4614    where
4615        V: Visitor<'de>,
4616    {
4617        visitor.visit_unit()
4618    }
4619
4620    fn deserialize_newtype_struct<V>(
4621        self,
4622        _name: &'static str,
4623        visitor: V,
4624    ) -> Result<V::Value, Self::Error>
4625    where
4626        V: Visitor<'de>,
4627    {
4628        visitor.visit_newtype_struct(self)
4629    }
4630
4631    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4632    where
4633        V: Visitor<'de>,
4634    {
4635        // Single value as sequence of one
4636        visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
4637    }
4638
4639    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
4640    where
4641        V: Visitor<'de>,
4642    {
4643        visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
4644    }
4645
4646    fn deserialize_tuple_struct<V>(
4647        self,
4648        _name: &'static str,
4649        _len: usize,
4650        visitor: V,
4651    ) -> Result<V::Value, Self::Error>
4652    where
4653        V: Visitor<'de>,
4654    {
4655        visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
4656    }
4657
4658    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
4659    where
4660        V: Visitor<'de>,
4661    {
4662        // Can't deserialize a single value as a map
4663        Err(QueryExtractError::DeserializeError {
4664            message: "cannot deserialize single value as map".to_string(),
4665        })
4666    }
4667
4668    fn deserialize_struct<V>(
4669        self,
4670        _name: &'static str,
4671        _fields: &'static [&'static str],
4672        visitor: V,
4673    ) -> Result<V::Value, Self::Error>
4674    where
4675        V: Visitor<'de>,
4676    {
4677        self.deserialize_map(visitor)
4678    }
4679
4680    fn deserialize_enum<V>(
4681        self,
4682        _name: &'static str,
4683        _variants: &'static [&'static str],
4684        visitor: V,
4685    ) -> Result<V::Value, Self::Error>
4686    where
4687        V: Visitor<'de>,
4688    {
4689        visitor.visit_enum(self.value.into_deserializer())
4690    }
4691
4692    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4693    where
4694        V: Visitor<'de>,
4695    {
4696        visitor.visit_str(self.value)
4697    }
4698
4699    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4700    where
4701        V: Visitor<'de>,
4702    {
4703        visitor.visit_unit()
4704    }
4705}
4706
/// Deserializer for a query field that may have multiple values.
///
/// This handles the `Vec<T>` case: `?tags=a&tags=b` -> `tags: ["a", "b"]`
struct QueryFieldDeserializer<'de> {
    /// Parameter name, used when reporting errors.
    name: &'de str,
    /// Every raw value supplied for this parameter.
    values: Vec<&'de str>,
}
4714
4715impl<'de> QueryFieldDeserializer<'de> {
4716    fn new(name: &'de str, values: Vec<&'de str>) -> Self {
4717        Self { name, values }
4718    }
4719}
4720
4721impl<'de> Deserializer<'de> for QueryFieldDeserializer<'de> {
4722    type Error = QueryExtractError;
4723
4724    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4725    where
4726        V: Visitor<'de>,
4727    {
4728        // Default to first value as string
4729        if let Some(value) = self.values.first() {
4730            visitor.visit_str(value)
4731        } else {
4732            visitor.visit_none()
4733        }
4734    }
4735
4736    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4737    where
4738        V: Visitor<'de>,
4739    {
4740        let value = self
4741            .values
4742            .first()
4743            .ok_or_else(|| QueryExtractError::MissingParam {
4744                name: self.name.to_string(),
4745            })?;
4746        let b = parse_bool(value).map_err(|msg| QueryExtractError::InvalidValue {
4747            name: self.name.to_string(),
4748            value: (*value).to_string(),
4749            expected: "bool",
4750            message: msg,
4751        })?;
4752        visitor.visit_bool(b)
4753    }
4754
4755    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4756    where
4757        V: Visitor<'de>,
4758    {
4759        let value = self
4760            .values
4761            .first()
4762            .ok_or_else(|| QueryExtractError::MissingParam {
4763                name: self.name.to_string(),
4764            })?;
4765        let n = value
4766            .parse::<i8>()
4767            .map_err(|e| QueryExtractError::InvalidValue {
4768                name: self.name.to_string(),
4769                value: (*value).to_string(),
4770                expected: "i8",
4771                message: e.to_string(),
4772            })?;
4773        visitor.visit_i8(n)
4774    }
4775
4776    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4777    where
4778        V: Visitor<'de>,
4779    {
4780        let value = self
4781            .values
4782            .first()
4783            .ok_or_else(|| QueryExtractError::MissingParam {
4784                name: self.name.to_string(),
4785            })?;
4786        let n = value
4787            .parse::<i16>()
4788            .map_err(|e| QueryExtractError::InvalidValue {
4789                name: self.name.to_string(),
4790                value: (*value).to_string(),
4791                expected: "i16",
4792                message: e.to_string(),
4793            })?;
4794        visitor.visit_i16(n)
4795    }
4796
4797    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4798    where
4799        V: Visitor<'de>,
4800    {
4801        let value = self
4802            .values
4803            .first()
4804            .ok_or_else(|| QueryExtractError::MissingParam {
4805                name: self.name.to_string(),
4806            })?;
4807        let n = value
4808            .parse::<i32>()
4809            .map_err(|e| QueryExtractError::InvalidValue {
4810                name: self.name.to_string(),
4811                value: (*value).to_string(),
4812                expected: "i32",
4813                message: e.to_string(),
4814            })?;
4815        visitor.visit_i32(n)
4816    }
4817
4818    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4819    where
4820        V: Visitor<'de>,
4821    {
4822        let value = self
4823            .values
4824            .first()
4825            .ok_or_else(|| QueryExtractError::MissingParam {
4826                name: self.name.to_string(),
4827            })?;
4828        let n = value
4829            .parse::<i64>()
4830            .map_err(|e| QueryExtractError::InvalidValue {
4831                name: self.name.to_string(),
4832                value: (*value).to_string(),
4833                expected: "i64",
4834                message: e.to_string(),
4835            })?;
4836        visitor.visit_i64(n)
4837    }
4838
4839    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4840    where
4841        V: Visitor<'de>,
4842    {
4843        let value = self
4844            .values
4845            .first()
4846            .ok_or_else(|| QueryExtractError::MissingParam {
4847                name: self.name.to_string(),
4848            })?;
4849        let n = value
4850            .parse::<u8>()
4851            .map_err(|e| QueryExtractError::InvalidValue {
4852                name: self.name.to_string(),
4853                value: (*value).to_string(),
4854                expected: "u8",
4855                message: e.to_string(),
4856            })?;
4857        visitor.visit_u8(n)
4858    }
4859
4860    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4861    where
4862        V: Visitor<'de>,
4863    {
4864        let value = self
4865            .values
4866            .first()
4867            .ok_or_else(|| QueryExtractError::MissingParam {
4868                name: self.name.to_string(),
4869            })?;
4870        let n = value
4871            .parse::<u16>()
4872            .map_err(|e| QueryExtractError::InvalidValue {
4873                name: self.name.to_string(),
4874                value: (*value).to_string(),
4875                expected: "u16",
4876                message: e.to_string(),
4877            })?;
4878        visitor.visit_u16(n)
4879    }
4880
4881    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4882    where
4883        V: Visitor<'de>,
4884    {
4885        let value = self
4886            .values
4887            .first()
4888            .ok_or_else(|| QueryExtractError::MissingParam {
4889                name: self.name.to_string(),
4890            })?;
4891        let n = value
4892            .parse::<u32>()
4893            .map_err(|e| QueryExtractError::InvalidValue {
4894                name: self.name.to_string(),
4895                value: (*value).to_string(),
4896                expected: "u32",
4897                message: e.to_string(),
4898            })?;
4899        visitor.visit_u32(n)
4900    }
4901
4902    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4903    where
4904        V: Visitor<'de>,
4905    {
4906        let value = self
4907            .values
4908            .first()
4909            .ok_or_else(|| QueryExtractError::MissingParam {
4910                name: self.name.to_string(),
4911            })?;
4912        let n = value
4913            .parse::<u64>()
4914            .map_err(|e| QueryExtractError::InvalidValue {
4915                name: self.name.to_string(),
4916                value: (*value).to_string(),
4917                expected: "u64",
4918                message: e.to_string(),
4919            })?;
4920        visitor.visit_u64(n)
4921    }
4922
4923    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4924    where
4925        V: Visitor<'de>,
4926    {
4927        let value = self
4928            .values
4929            .first()
4930            .ok_or_else(|| QueryExtractError::MissingParam {
4931                name: self.name.to_string(),
4932            })?;
4933        let n = value
4934            .parse::<f32>()
4935            .map_err(|e| QueryExtractError::InvalidValue {
4936                name: self.name.to_string(),
4937                value: (*value).to_string(),
4938                expected: "f32",
4939                message: e.to_string(),
4940            })?;
4941        visitor.visit_f32(n)
4942    }
4943
4944    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4945    where
4946        V: Visitor<'de>,
4947    {
4948        let value = self
4949            .values
4950            .first()
4951            .ok_or_else(|| QueryExtractError::MissingParam {
4952                name: self.name.to_string(),
4953            })?;
4954        let n = value
4955            .parse::<f64>()
4956            .map_err(|e| QueryExtractError::InvalidValue {
4957                name: self.name.to_string(),
4958                value: (*value).to_string(),
4959                expected: "f64",
4960                message: e.to_string(),
4961            })?;
4962        visitor.visit_f64(n)
4963    }
4964
4965    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4966    where
4967        V: Visitor<'de>,
4968    {
4969        let value = self
4970            .values
4971            .first()
4972            .ok_or_else(|| QueryExtractError::MissingParam {
4973                name: self.name.to_string(),
4974            })?;
4975        let mut chars = value.chars();
4976        match (chars.next(), chars.next()) {
4977            (Some(c), None) => visitor.visit_char(c),
4978            _ => Err(QueryExtractError::InvalidValue {
4979                name: self.name.to_string(),
4980                value: (*value).to_string(),
4981                expected: "char",
4982                message: "expected single character".to_string(),
4983            }),
4984        }
4985    }
4986
4987    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4988    where
4989        V: Visitor<'de>,
4990    {
4991        let value = self
4992            .values
4993            .first()
4994            .ok_or_else(|| QueryExtractError::MissingParam {
4995                name: self.name.to_string(),
4996            })?;
4997        visitor.visit_str(value)
4998    }
4999
5000    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5001    where
5002        V: Visitor<'de>,
5003    {
5004        let value = self
5005            .values
5006            .first()
5007            .ok_or_else(|| QueryExtractError::MissingParam {
5008                name: self.name.to_string(),
5009            })?;
5010        visitor.visit_string((*value).to_owned())
5011    }
5012
5013    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5014    where
5015        V: Visitor<'de>,
5016    {
5017        let value = self
5018            .values
5019            .first()
5020            .ok_or_else(|| QueryExtractError::MissingParam {
5021                name: self.name.to_string(),
5022            })?;
5023        visitor.visit_bytes(value.as_bytes())
5024    }
5025
5026    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5027    where
5028        V: Visitor<'de>,
5029    {
5030        let value = self
5031            .values
5032            .first()
5033            .ok_or_else(|| QueryExtractError::MissingParam {
5034                name: self.name.to_string(),
5035            })?;
5036        visitor.visit_byte_buf(value.as_bytes().to_vec())
5037    }
5038
5039    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5040    where
5041        V: Visitor<'de>,
5042    {
5043        if self.values.is_empty() {
5044            visitor.visit_none()
5045        } else {
5046            visitor.visit_some(self)
5047        }
5048    }
5049
5050    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5051    where
5052        V: Visitor<'de>,
5053    {
5054        visitor.visit_unit()
5055    }
5056
5057    fn deserialize_unit_struct<V>(
5058        self,
5059        _name: &'static str,
5060        visitor: V,
5061    ) -> Result<V::Value, Self::Error>
5062    where
5063        V: Visitor<'de>,
5064    {
5065        visitor.visit_unit()
5066    }
5067
5068    fn deserialize_newtype_struct<V>(
5069        self,
5070        _name: &'static str,
5071        visitor: V,
5072    ) -> Result<V::Value, Self::Error>
5073    where
5074        V: Visitor<'de>,
5075    {
5076        visitor.visit_newtype_struct(self)
5077    }
5078
5079    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5080    where
5081        V: Visitor<'de>,
5082    {
5083        // This is the Vec<T> case: return all values as a sequence
5084        visitor.visit_seq(QuerySeqAccess::new(self.values))
5085    }
5086
5087    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
5088    where
5089        V: Visitor<'de>,
5090    {
5091        visitor.visit_seq(QuerySeqAccess::new(self.values))
5092    }
5093
5094    fn deserialize_tuple_struct<V>(
5095        self,
5096        _name: &'static str,
5097        _len: usize,
5098        visitor: V,
5099    ) -> Result<V::Value, Self::Error>
5100    where
5101        V: Visitor<'de>,
5102    {
5103        visitor.visit_seq(QuerySeqAccess::new(self.values))
5104    }
5105
5106    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
5107    where
5108        V: Visitor<'de>,
5109    {
5110        Err(QueryExtractError::DeserializeError {
5111            message: "cannot deserialize query field as map".to_string(),
5112        })
5113    }
5114
5115    fn deserialize_struct<V>(
5116        self,
5117        _name: &'static str,
5118        _fields: &'static [&'static str],
5119        visitor: V,
5120    ) -> Result<V::Value, Self::Error>
5121    where
5122        V: Visitor<'de>,
5123    {
5124        self.deserialize_map(visitor)
5125    }
5126
5127    fn deserialize_enum<V>(
5128        self,
5129        _name: &'static str,
5130        _variants: &'static [&'static str],
5131        visitor: V,
5132    ) -> Result<V::Value, Self::Error>
5133    where
5134        V: Visitor<'de>,
5135    {
5136        let value = self
5137            .values
5138            .first()
5139            .ok_or_else(|| QueryExtractError::MissingParam {
5140                name: self.name.to_string(),
5141            })?;
5142        visitor.visit_enum((*value).into_deserializer())
5143    }
5144
5145    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5146    where
5147        V: Visitor<'de>,
5148    {
5149        let value = self
5150            .values
5151            .first()
5152            .ok_or_else(|| QueryExtractError::MissingParam {
5153                name: self.name.to_string(),
5154            })?;
5155        visitor.visit_str(value)
5156    }
5157
5158    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5159    where
5160        V: Visitor<'de>,
5161    {
5162        visitor.visit_unit()
5163    }
5164}
5165
5166// ============================================================================
5167// Application State Extractor
5168// ============================================================================
5169
/// Application state container.
///
/// `AppState` holds typed state values that can be shared across request handlers.
/// State is typically set up when creating the application and injected into
/// requests by the router/server.
///
/// Values are keyed by their concrete type: at most one value per type is
/// stored, and registering a second value of the same type replaces the first.
/// Cloning an `AppState` is cheap because the underlying map is shared behind
/// an `Arc`.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::{AppState, State};
/// use std::sync::Arc;
///
/// // Define your state types
/// struct DatabasePool { /* ... */ }
/// struct Config { api_key: String }
///
/// // Build the app state
/// let state = AppState::new()
///     .with(Arc::new(DatabasePool::new()))
///     .with(Arc::new(Config { api_key: "secret".into() }));
///
/// // In handlers, extract the state
/// async fn handler(db: State<Arc<DatabasePool>>, config: State<Arc<Config>>) {
///     // Use db and config...
/// }
/// ```
#[derive(Debug, Default, Clone)]
pub struct AppState {
    inner: std::sync::Arc<
        std::collections::HashMap<
            std::any::TypeId,
            std::sync::Arc<dyn std::any::Any + Send + Sync>,
        >,
    >,
}

impl AppState {
    /// Create an empty application state container.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: std::sync::Arc::new(std::collections::HashMap::new()),
        }
    }

    /// Add a typed state value, returning the updated container.
    ///
    /// The value must be `Send + Sync + 'static` to be safely shared across
    /// requests and threads. If a value of type `T` is already present it is
    /// replaced.
    ///
    /// Other clones of this `AppState` are never affected: the map is updated
    /// in place when this handle is the only owner and copied first otherwise.
    #[must_use]
    pub fn with<T: Send + Sync + 'static>(mut self, value: T) -> Self {
        // `Arc::make_mut` is the standard copy-on-write primitive: it mutates in
        // place when uniquely owned (no re-allocation of the `Arc`) and clones
        // the map only when another handle still shares it.
        std::sync::Arc::make_mut(&mut self.inner)
            .insert(std::any::TypeId::of::<T>(), std::sync::Arc::new(value));
        self
    }

    /// Get a reference to a typed state value.
    ///
    /// Returns `None` when no value of type `T` has been registered.
    #[must_use]
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.inner
            .get(&std::any::TypeId::of::<T>())
            .and_then(|arc| arc.downcast_ref::<T>())
    }

    /// Check if state contains a value of type T.
    #[must_use]
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.inner.contains_key(&std::any::TypeId::of::<T>())
    }

    /// Return the number of state values.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns true if no state values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
5257
/// Extractor for application-wide shared state.
///
/// Wraps a value of type `T` that was cloned out of the [`AppState`] stored in
/// the request's extensions. The value must have been registered with the
/// application state beforehand.
///
/// # Type Requirements
///
/// Extraction requires `T: Clone + Send + Sync + 'static`.
///
/// # Error Responses
///
/// - **500 Internal Server Error**: State type not found (server configuration error)
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::State;
/// use std::sync::Arc;
///
/// struct DatabasePool { /* ... */ }
///
/// #[get("/users")]
/// async fn list_users(db: State<Arc<DatabasePool>>) -> impl IntoResponse {
///     // db.0 contains the Arc<DatabasePool>
///     let users = db.query_users().await;
///     Json(users)
/// }
/// ```
#[derive(Debug, Clone)]
pub struct State<T>(pub T);

impl<T> State<T> {
    /// Consume the extractor and return the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> Deref for State<T> {
    type Target = T;

    /// Borrow the wrapped value.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}

impl<T> DerefMut for State<T> {
    /// Mutably borrow the wrapped value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}
5309
/// Failure modes of the [`State`] extractor.
#[derive(Debug)]
pub enum StateExtractError {
    /// No application state was found in the request extensions.
    ///
    /// This indicates the server was not configured to inject state into requests.
    MissingAppState,
    /// The application state exists but holds no value of the requested type.
    ///
    /// The type was not registered with the application state.
    MissingStateType {
        /// The name of the type that was not found.
        type_name: &'static str,
    },
}

impl std::fmt::Display for StateExtractError {
    /// Writes a short human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::MissingStateType { type_name } => {
                write!(f, "State type not found: {type_name}")
            }
            Self::MissingAppState => f.write_str("Application state not configured in request"),
        }
    }
}

impl std::error::Error for StateExtractError {}
5340
5341impl IntoResponse for StateExtractError {
5342    fn into_response(self) -> crate::response::Response {
5343        // State extraction failures are server configuration errors (500)
5344        HttpError::internal()
5345            .with_detail(self.to_string())
5346            .into_response()
5347    }
5348}
5349
5350impl<T> FromRequest for State<T>
5351where
5352    T: Clone + Send + Sync + 'static,
5353{
5354    type Error = StateExtractError;
5355
5356    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
5357        // Get the AppState from request extensions
5358        let app_state = req
5359            .get_extension::<AppState>()
5360            .ok_or(StateExtractError::MissingAppState)?;
5361
5362        // Get the specific state type
5363        let value = app_state
5364            .get::<T>()
5365            .ok_or(StateExtractError::MissingStateType {
5366                type_name: std::any::type_name::<T>(),
5367            })?;
5368
5369        Ok(State(value.clone()))
5370    }
5371}
5372
5373#[cfg(test)]
5374mod state_tests {
5375    use super::*;
5376    use crate::request::Method;
5377
5378    fn test_context() -> RequestContext {
5379        let cx = asupersync::Cx::for_testing();
5380        RequestContext::new(cx, 12345)
5381    }
5382
5383    #[derive(Clone, Debug, PartialEq)]
5384    struct DatabasePool {
5385        connection_string: String,
5386    }
5387
5388    #[derive(Clone, Debug, PartialEq)]
5389    struct AppConfig {
5390        debug: bool,
5391        port: u16,
5392    }
5393
5394    #[test]
5395    fn app_state_new_is_empty() {
5396        let state = AppState::new();
5397        assert!(state.is_empty());
5398        assert_eq!(state.len(), 0);
5399    }
5400
5401    #[test]
5402    fn app_state_with_single_type() {
5403        let db = DatabasePool {
5404            connection_string: "postgres://localhost".into(),
5405        };
5406        let state = AppState::new().with(db.clone());
5407
5408        assert!(!state.is_empty());
5409        assert_eq!(state.len(), 1);
5410        assert!(state.contains::<DatabasePool>());
5411        assert_eq!(state.get::<DatabasePool>(), Some(&db));
5412    }
5413
5414    #[test]
5415    fn app_state_with_multiple_types() {
5416        let db = DatabasePool {
5417            connection_string: "postgres://localhost".into(),
5418        };
5419        let config = AppConfig {
5420            debug: true,
5421            port: 8080,
5422        };
5423
5424        let state = AppState::new().with(db.clone()).with(config.clone());
5425
5426        assert_eq!(state.len(), 2);
5427        assert_eq!(state.get::<DatabasePool>(), Some(&db));
5428        assert_eq!(state.get::<AppConfig>(), Some(&config));
5429    }
5430
5431    #[test]
5432    fn app_state_get_missing_type() {
5433        let state = AppState::new().with(42i32);
5434        assert!(state.get::<String>().is_none());
5435        assert!(!state.contains::<String>());
5436    }
5437
5438    #[test]
5439    fn state_deref() {
5440        let state = State(42i32);
5441        assert_eq!(*state, 42);
5442    }
5443
5444    #[test]
5445    fn state_into_inner() {
5446        let state = State("hello".to_string());
5447        assert_eq!(state.into_inner(), "hello");
5448    }
5449
5450    #[test]
5451    fn state_extract_success() {
5452        let ctx = test_context();
5453        let db = DatabasePool {
5454            connection_string: "postgres://localhost".into(),
5455        };
5456        let app_state = AppState::new().with(db.clone());
5457
5458        let mut req = Request::new(Method::Get, "/test");
5459        req.insert_extension(app_state);
5460
5461        let result =
5462            futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
5463        let State(extracted) = result.unwrap();
5464        assert_eq!(extracted, db);
5465    }
5466
5467    #[test]
5468    fn state_extract_multiple_types() {
5469        let ctx = test_context();
5470        let db = DatabasePool {
5471            connection_string: "postgres://localhost".into(),
5472        };
5473        let config = AppConfig {
5474            debug: true,
5475            port: 8080,
5476        };
5477        let app_state = AppState::new().with(db.clone()).with(config.clone());
5478
5479        let mut req = Request::new(Method::Get, "/test");
5480        req.insert_extension(app_state);
5481
5482        // Extract DatabasePool
5483        let result =
5484            futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
5485        let State(extracted_db) = result.unwrap();
5486        assert_eq!(extracted_db, db);
5487
5488        // Extract AppConfig
5489        let result = futures_executor::block_on(State::<AppConfig>::from_request(&ctx, &mut req));
5490        let State(extracted_config) = result.unwrap();
5491        assert_eq!(extracted_config, config);
5492    }
5493
5494    #[test]
5495    fn state_extract_missing_app_state() {
5496        let ctx = test_context();
5497        let mut req = Request::new(Method::Get, "/test");
5498        // Don't insert AppState
5499
5500        let result =
5501            futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
5502        assert!(matches!(result, Err(StateExtractError::MissingAppState)));
5503    }
5504
5505    #[test]
5506    fn state_extract_missing_type() {
5507        let ctx = test_context();
5508        let app_state = AppState::new().with(42i32);
5509
5510        let mut req = Request::new(Method::Get, "/test");
5511        req.insert_extension(app_state);
5512
5513        let result =
5514            futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
5515        assert!(matches!(
5516            result,
5517            Err(StateExtractError::MissingStateType { .. })
5518        ));
5519    }
5520
5521    #[test]
5522    fn state_error_display() {
5523        let err = StateExtractError::MissingAppState;
5524        assert!(err.to_string().contains("not configured"));
5525
5526        let err = StateExtractError::MissingStateType {
5527            type_name: "DatabasePool",
5528        };
5529        assert!(err.to_string().contains("DatabasePool"));
5530    }
5531
5532    #[test]
5533    fn app_state_clone() {
5534        let db = DatabasePool {
5535            connection_string: "postgres://localhost".into(),
5536        };
5537        let state1 = AppState::new().with(db.clone());
5538        let state2 = state1.clone();
5539
5540        assert_eq!(state2.get::<DatabasePool>(), Some(&db));
5541    }
5542
5543    #[test]
5544    fn state_with_arc() {
5545        use std::sync::Arc;
5546
5547        let ctx = test_context();
5548        let db = Arc::new(DatabasePool {
5549            connection_string: "postgres://localhost".into(),
5550        });
5551        let app_state = AppState::new().with(db.clone());
5552
5553        let mut req = Request::new(Method::Get, "/test");
5554        req.insert_extension(app_state);
5555
5556        let result =
5557            futures_executor::block_on(State::<Arc<DatabasePool>>::from_request(&ctx, &mut req));
5558        let State(extracted) = result.unwrap();
5559        assert_eq!(extracted.connection_string, "postgres://localhost");
5560    }
5561
5562    // ========================================================================
5563    // bd-2u8l: Atomic State Mutation Tests
5564    // ========================================================================
5565
5566    #[test]
5567    fn atomic_counter_fetch_add_concurrent() {
5568        // Test concurrent fetch_add operations don't lose increments (bd-2u8l)
5569        use std::sync::Arc;
5570        use std::sync::atomic::{AtomicUsize, Ordering};
5571        use std::thread;
5572
5573        const NUM_THREADS: usize = 100;
5574        const INCREMENTS_PER_THREAD: usize = 1000;
5575
5576        let counter = Arc::new(AtomicUsize::new(0));
5577        let app_state = AppState::new().with(counter.clone());
5578
5579        let handles: Vec<_> = (0..NUM_THREADS)
5580            .map(|_| {
5581                let state = app_state.clone();
5582                thread::spawn(move || {
5583                    // Each thread gets the counter from state and increments it
5584                    let counter = state.get::<Arc<AtomicUsize>>().expect("Counter not found");
5585                    for _ in 0..INCREMENTS_PER_THREAD {
5586                        counter.fetch_add(1, Ordering::SeqCst);
5587                    }
5588                })
5589            })
5590            .collect();
5591
5592        // Wait for all threads
5593        for handle in handles {
5594            handle.join().expect("Thread panicked");
5595        }
5596
5597        // Verify no lost increments
5598        let final_value = counter.load(Ordering::SeqCst);
5599        let expected = NUM_THREADS * INCREMENTS_PER_THREAD;
5600        assert_eq!(
5601            final_value, expected,
5602            "Lost increments: expected {expected}, got {final_value}"
5603        );
5604    }
5605
5606    #[test]
5607    fn atomic_compare_and_swap_concurrent() {
5608        // Test compare-and-swap (CAS) patterns under concurrency (bd-2u8l)
5609        use std::sync::Arc;
5610        use std::sync::atomic::{AtomicUsize, Ordering};
5611        use std::thread;
5612
5613        const NUM_THREADS: usize = 50;
5614        const CAS_ATTEMPTS_PER_THREAD: usize = 100;
5615
5616        // Counter that tracks successful CAS operations
5617        let counter = Arc::new(AtomicUsize::new(0));
5618        let success_count = Arc::new(AtomicUsize::new(0));
5619
5620        let handles: Vec<_> = (0..NUM_THREADS)
5621            .map(|_| {
5622                let counter = counter.clone();
5623                let success_count = success_count.clone();
5624                thread::spawn(move || {
5625                    for _ in 0..CAS_ATTEMPTS_PER_THREAD {
5626                        // Try to increment using CAS
5627                        let mut current = counter.load(Ordering::SeqCst);
5628                        loop {
5629                            match counter.compare_exchange(
5630                                current,
5631                                current + 1,
5632                                Ordering::SeqCst,
5633                                Ordering::SeqCst,
5634                            ) {
5635                                Ok(_) => {
5636                                    success_count.fetch_add(1, Ordering::SeqCst);
5637                                    break;
5638                                }
5639                                Err(actual) => {
5640                                    // Retry with the actual value
5641                                    current = actual;
5642                                }
5643                            }
5644                        }
5645                    }
5646                })
5647            })
5648            .collect();
5649
5650        for handle in handles {
5651            handle.join().expect("Thread panicked");
5652        }
5653
5654        let final_counter = counter.load(Ordering::SeqCst);
5655        let total_successes = success_count.load(Ordering::SeqCst);
5656        let expected = NUM_THREADS * CAS_ATTEMPTS_PER_THREAD;
5657
5658        // All CAS operations should succeed (with retries)
5659        assert_eq!(total_successes, expected);
5660        // Counter should equal total successful CAS operations
5661        assert_eq!(final_counter, expected);
5662    }
5663
5664    #[test]
5665    fn atomic_state_concurrent_reads() {
5666        // Test concurrent reads don't corrupt state (bd-2u8l)
5667        use std::sync::Arc;
5668        use std::sync::atomic::{AtomicUsize, Ordering};
5669        use std::thread;
5670
5671        const NUM_READERS: usize = 100;
5672        const READS_PER_THREAD: usize = 1000;
5673        const INITIAL_VALUE: usize = 42;
5674
5675        let counter = Arc::new(AtomicUsize::new(INITIAL_VALUE));
5676        let app_state = AppState::new().with(counter.clone());
5677
5678        let handles: Vec<_> = (0..NUM_READERS)
5679            .map(|_| {
5680                let state = app_state.clone();
5681                thread::spawn(move || {
5682                    let counter = state.get::<Arc<AtomicUsize>>().expect("Counter not found");
5683                    for _ in 0..READS_PER_THREAD {
5684                        let value = counter.load(Ordering::SeqCst);
5685                        assert_eq!(value, INITIAL_VALUE, "Value corrupted during read");
5686                    }
5687                })
5688            })
5689            .collect();
5690
5691        for handle in handles {
5692            handle.join().expect("Thread panicked");
5693        }
5694
5695        // Value should be unchanged
5696        assert_eq!(counter.load(Ordering::SeqCst), INITIAL_VALUE);
5697    }
5698
5699    #[test]
5700    fn atomic_rate_limiter_pattern() {
5701        // Test a rate-limiter pattern using atomics (bd-2u8l)
5702        use std::sync::Arc;
5703        use std::sync::atomic::{AtomicUsize, Ordering};
5704        use std::thread;
5705
5706        const MAX_REQUESTS: usize = 100;
5707        const NUM_CLIENTS: usize = 50;
5708        const REQUESTS_PER_CLIENT: usize = 5;
5709
5710        // Simple rate limiter state
5711        #[derive(Clone)]
5712        struct RateLimiter {
5713            current_count: Arc<AtomicUsize>,
5714            max_count: usize,
5715        }
5716
5717        impl RateLimiter {
5718            fn new(max: usize) -> Self {
5719                Self {
5720                    current_count: Arc::new(AtomicUsize::new(0)),
5721                    max_count: max,
5722                }
5723            }
5724
5725            fn try_acquire(&self) -> bool {
5726                let mut current = self.current_count.load(Ordering::SeqCst);
5727                loop {
5728                    if current >= self.max_count {
5729                        return false;
5730                    }
5731                    match self.current_count.compare_exchange(
5732                        current,
5733                        current + 1,
5734                        Ordering::SeqCst,
5735                        Ordering::SeqCst,
5736                    ) {
5737                        Ok(_) => return true,
5738                        Err(actual) => current = actual,
5739                    }
5740                }
5741            }
5742
5743            fn count(&self) -> usize {
5744                self.current_count.load(Ordering::SeqCst)
5745            }
5746        }
5747
5748        let limiter = RateLimiter::new(MAX_REQUESTS);
5749        let app_state = AppState::new().with(limiter.clone());
5750        let allowed_count = Arc::new(AtomicUsize::new(0));
5751        let denied_count = Arc::new(AtomicUsize::new(0));
5752
5753        let handles: Vec<_> = (0..NUM_CLIENTS)
5754            .map(|_| {
5755                let state = app_state.clone();
5756                let allowed = allowed_count.clone();
5757                let denied = denied_count.clone();
5758                thread::spawn(move || {
5759                    let limiter = state.get::<RateLimiter>().expect("Limiter not found");
5760                    for _ in 0..REQUESTS_PER_CLIENT {
5761                        if limiter.try_acquire() {
5762                            allowed.fetch_add(1, Ordering::SeqCst);
5763                        } else {
5764                            denied.fetch_add(1, Ordering::SeqCst);
5765                        }
5766                    }
5767                })
5768            })
5769            .collect();
5770
5771        for handle in handles {
5772            handle.join().expect("Thread panicked");
5773        }
5774
5775        let total_allowed = allowed_count.load(Ordering::SeqCst);
5776        let total_denied = denied_count.load(Ordering::SeqCst);
5777        let total_requests = NUM_CLIENTS * REQUESTS_PER_CLIENT;
5778
5779        // Exactly MAX_REQUESTS should be allowed
5780        assert_eq!(total_allowed, MAX_REQUESTS, "Allowed count mismatch");
5781        // Rest should be denied
5782        assert_eq!(
5783            total_denied,
5784            total_requests - MAX_REQUESTS,
5785            "Denied count mismatch"
5786        );
5787        // Counter should never exceed max
5788        assert!(limiter.count() <= MAX_REQUESTS, "Rate limiter exceeded max");
5789    }
5790
5791    #[test]
5792    fn atomic_concurrent_queue_pattern() {
5793        // Test a lock-free queue pattern using atomics (bd-2u8l)
5794        use std::sync::Arc;
5795        use std::sync::atomic::{AtomicUsize, Ordering};
5796        use std::thread;
5797
5798        const NUM_PRODUCERS: usize = 10;
5799        const ITEMS_PER_PRODUCER: usize = 100;
5800
5801        // Simple monotonic counter to simulate queue indices
5802        let _head = Arc::new(AtomicUsize::new(0));
5803        let tail = Arc::new(AtomicUsize::new(0));
5804        let produced_ids = Arc::new(parking_lot::Mutex::new(Vec::new()));
5805
5806        let handles: Vec<_> = (0..NUM_PRODUCERS)
5807            .map(|_| {
5808                let tail = tail.clone();
5809                let ids = produced_ids.clone();
5810                thread::spawn(move || {
5811                    for _ in 0..ITEMS_PER_PRODUCER {
5812                        // Atomically claim a slot in the "queue"
5813                        let slot = tail.fetch_add(1, Ordering::SeqCst);
5814                        ids.lock().push(slot);
5815                    }
5816                })
5817            })
5818            .collect();
5819
5820        for handle in handles {
5821            handle.join().expect("Thread panicked");
5822        }
5823
5824        let produced = produced_ids.lock();
5825        let expected_count = NUM_PRODUCERS * ITEMS_PER_PRODUCER;
5826
5827        // Should have produced the expected number of items
5828        assert_eq!(produced.len(), expected_count);
5829
5830        // Each ID should be unique (no lost slots)
5831        let mut sorted: Vec<_> = produced.iter().copied().collect();
5832        sorted.sort_unstable();
5833        for (i, &id) in sorted.iter().enumerate() {
5834            assert_eq!(id, i, "Slot {i} missing or duplicated");
5835        }
5836
5837        // Tail should match expected
5838        assert_eq!(tail.load(Ordering::SeqCst), expected_count);
5839    }
5840
5841    // ========================================================================
5842    // bd-tnw0: Concurrent State Reads Under Load
5843    // ========================================================================
5844
5845    #[test]
5846    fn concurrent_reads_basic_consistency() {
5847        // Test basic concurrent reads return consistent values (bd-tnw0)
5848        use std::sync::Arc;
5849        use std::thread;
5850
5851        const NUM_READERS: usize = 100;
5852        const READS_PER_THREAD: usize = 100;
5853
5854        // Create immutable state values
5855        let config_value = "production";
5856        let port_value = 8080u16;
5857        let enabled_value = true;
5858
5859        let app_state = AppState::new()
5860            .with(config_value.to_string())
5861            .with(port_value)
5862            .with(enabled_value);
5863
5864        let error_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
5865
5866        let handles: Vec<_> = (0..NUM_READERS)
5867            .map(|_| {
5868                let state = app_state.clone();
5869                let errors = error_count.clone();
5870                thread::spawn(move || {
5871                    for _ in 0..READS_PER_THREAD {
5872                        // Read all state values (get returns Option<&T>)
5873                        let config = state.get::<String>();
5874                        let port = state.get::<u16>();
5875                        let enabled = state.get::<bool>();
5876
5877                        // Verify consistency
5878                        if config.map(String::as_str) != Some(config_value) {
5879                            errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
5880                        }
5881                        if port != Some(&port_value) {
5882                            errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
5883                        }
5884                        if enabled != Some(&enabled_value) {
5885                            errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
5886                        }
5887                    }
5888                })
5889            })
5890            .collect();
5891
5892        for handle in handles {
5893            handle.join().expect("Thread panicked");
5894        }
5895
5896        assert_eq!(
5897            error_count.load(std::sync::atomic::Ordering::SeqCst),
5898            0,
5899            "Some reads returned inconsistent values"
5900        );
5901    }
5902
5903    #[test]
5904    fn concurrent_reads_varying_payload_sizes() {
5905        // Test reads with different payload sizes (bd-tnw0)
5906        use std::sync::Arc;
5907        use std::thread;
5908
5909        const NUM_READERS: usize = 50;
5910        const READS_PER_THREAD: usize = 50;
5911
5912        // Small payload (just a number)
5913        let small: i32 = 42;
5914        // Medium payload (string)
5915        let medium: String = "a".repeat(1000);
5916        // Large payload (vector)
5917        let large: Vec<u8> = vec![0u8; 10_000];
5918
5919        let app_state = AppState::new()
5920            .with(small)
5921            .with(medium.clone())
5922            .with(large.clone());
5923
5924        let error_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
5925
5926        let handles: Vec<_> = (0..NUM_READERS)
5927            .map(|_| {
5928                let state = app_state.clone();
5929                let _expected_medium = medium.clone();
5930                let expected_large = large.clone();
5931                let errors = error_count.clone();
5932                thread::spawn(move || {
5933                    for _ in 0..READS_PER_THREAD {
5934                        // Read all payload sizes
5935                        if state.get::<i32>() != Some(&42) {
5936                            errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
5937                        }
5938                        if state.get::<String>().is_some_and(|s| s.len() != 1000) {
5939                            errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
5940                        }
5941                        if state.get::<Vec<u8>>() != Some(&expected_large) {
5942                            errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
5943                        }
5944                    }
5945                })
5946            })
5947            .collect();
5948
5949        for handle in handles {
5950            handle.join().expect("Thread panicked");
5951        }
5952
5953        assert_eq!(
5954            error_count.load(std::sync::atomic::Ordering::SeqCst),
5955            0,
5956            "Payload size affected read consistency"
5957        );
5958    }
5959
5960    #[test]
5961    fn concurrent_reads_nested_structures() {
5962        // Test reads with nested state structures (bd-tnw0)
5963        use std::sync::Arc;
5964        use std::thread;
5965
5966        const NUM_READERS: usize = 50;
5967
5968        #[derive(Clone, Debug, PartialEq)]
5969        struct OuterConfig {
5970            inner: InnerConfig,
5971            name: String,
5972        }
5973
5974        #[derive(Clone, Debug, PartialEq)]
5975        struct InnerConfig {
5976            values: Vec<i32>,
5977            enabled: bool,
5978        }
5979
5980        let nested = OuterConfig {
5981            inner: InnerConfig {
5982                values: vec![1, 2, 3, 4, 5],
5983                enabled: true,
5984            },
5985            name: "nested_test".to_string(),
5986        };
5987
5988        let app_state = AppState::new().with(nested.clone());
5989        let error_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
5990
5991        let handles: Vec<_> = (0..NUM_READERS)
5992            .map(|_| {
5993                let state = app_state.clone();
5994                let expected = nested.clone();
5995                let errors = error_count.clone();
5996                thread::spawn(move || {
5997                    for _ in 0..100 {
5998                        let read = state.get::<OuterConfig>();
5999                        if read != Some(&expected) {
6000                            errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
6001                        }
6002                        // Also verify nested field access
6003                        if let Some(outer) = read {
6004                            if outer.inner.values.len() != 5 {
6005                                errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
6006                            }
6007                        }
6008                    }
6009                })
6010            })
6011            .collect();
6012
6013        for handle in handles {
6014            handle.join().expect("Thread panicked");
6015        }
6016
6017        assert_eq!(
6018            error_count.load(std::sync::atomic::Ordering::SeqCst),
6019            0,
6020            "Nested structure reads inconsistent"
6021        );
6022    }
6023
6024    #[test]
6025    #[allow(clippy::cast_possible_wrap)]
6026    fn concurrent_reads_with_arc_rwlock_pattern() {
6027        // Test Arc<RwLock<T>> patterns for mutable shared state (bd-tnw0)
6028        use parking_lot::RwLock;
6029        use std::sync::Arc;
6030        use std::thread;
6031
6032        const NUM_READERS: usize = 80;
6033        const NUM_WRITERS: usize = 20;
6034        const OPS_PER_THREAD: usize = 100;
6035
6036        #[derive(Clone)]
6037        struct MutableState {
6038            data: Arc<RwLock<Vec<i32>>>,
6039        }
6040
6041        impl MutableState {
6042            fn new() -> Self {
6043                Self {
6044                    data: Arc::new(RwLock::new(Vec::new())),
6045                }
6046            }
6047
6048            fn push(&self, value: i32) {
6049                self.data.write().push(value);
6050            }
6051
6052            fn len(&self) -> usize {
6053                self.data.read().len()
6054            }
6055        }
6056
6057        let mutable_state = MutableState::new();
6058        let app_state = AppState::new().with(mutable_state.clone());
6059        let read_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
6060        let write_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
6061
6062        // Spawn reader threads
6063        let reader_handles: Vec<_> = (0..NUM_READERS)
6064            .map(|_| {
6065                let state = app_state.clone();
6066                let reads = read_count.clone();
6067                thread::spawn(move || {
6068                    for _ in 0..OPS_PER_THREAD {
6069                        let ms = state.get::<MutableState>().expect("State not found");
6070                        let _ = ms.len(); // Just read the length
6071                        reads.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
6072                    }
6073                })
6074            })
6075            .collect();
6076
6077        // Spawn writer threads
6078        let writer_handles: Vec<_> = (0..NUM_WRITERS)
6079            .map(|i| {
6080                let state = app_state.clone();
6081                let writes = write_count.clone();
6082                thread::spawn(move || {
6083                    for j in 0..OPS_PER_THREAD {
6084                        let ms = state.get::<MutableState>().expect("State not found");
6085                        ms.push((i * OPS_PER_THREAD + j) as i32);
6086                        writes.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
6087                    }
6088                })
6089            })
6090            .collect();
6091
6092        // Wait for all threads
6093        for handle in reader_handles {
6094            handle.join().expect("Reader thread panicked");
6095        }
6096        for handle in writer_handles {
6097            handle.join().expect("Writer thread panicked");
6098        }
6099
6100        // Verify all operations completed
6101        let total_reads = read_count.load(std::sync::atomic::Ordering::SeqCst);
6102        let total_writes = write_count.load(std::sync::atomic::Ordering::SeqCst);
6103
6104        assert_eq!(total_reads, NUM_READERS * OPS_PER_THREAD);
6105        assert_eq!(total_writes, NUM_WRITERS * OPS_PER_THREAD);
6106
6107        // Verify final state has all writes
6108        assert_eq!(mutable_state.len(), NUM_WRITERS * OPS_PER_THREAD);
6109    }
6110}
6111
6112// ============================================================================
6113// Special Parameter Extractors (Request/Response Injection)
6114// ============================================================================
6115
/// Read-only request data access.
///
/// Provides access to request metadata without consuming the body.
/// For body access, use the `Json`, `Form`, or other body extractors.
///
/// The extractor stores owned copies of the method, path, query string and
/// headers, so the value does not borrow from the request it was built from.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::RequestRef;
///
/// async fn handler(req: RequestRef) -> impl IntoResponse {
///     format!("Method: {}, Path: {}", req.method(), req.path())
/// }
/// ```
#[derive(Debug, Clone)]
pub struct RequestRef {
    /// HTTP method of the request.
    method: crate::request::Method,
    /// Request path, copied into an owned string.
    path: String,
    /// Query string, or `None` when the request has none.
    query: Option<String>,
    /// Header `(name, raw value bytes)` pairs copied from the request.
    headers: Vec<(String, Vec<u8>)>,
}
6137
6138impl RequestRef {
6139    /// Get the HTTP method.
6140    #[must_use]
6141    pub fn method(&self) -> crate::request::Method {
6142        self.method
6143    }
6144
6145    /// Get the request path.
6146    #[must_use]
6147    pub fn path(&self) -> &str {
6148        &self.path
6149    }
6150
6151    /// Get the query string.
6152    #[must_use]
6153    pub fn query(&self) -> Option<&str> {
6154        self.query.as_deref()
6155    }
6156
6157    /// Get a header value by name (case-insensitive).
6158    #[must_use]
6159    pub fn header(&self, name: &str) -> Option<&[u8]> {
6160        let name_lower = name.to_ascii_lowercase();
6161        self.headers
6162            .iter()
6163            .find(|(n, _)| n.to_ascii_lowercase() == name_lower)
6164            .map(|(_, v)| v.as_slice())
6165    }
6166
6167    /// Iterate over all headers.
6168    pub fn headers(&self) -> impl Iterator<Item = (&str, &[u8])> {
6169        self.headers.iter().map(|(n, v)| (n.as_str(), v.as_slice()))
6170    }
6171}
6172
6173impl FromRequest for RequestRef {
6174    type Error = std::convert::Infallible;
6175
6176    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
6177        Ok(RequestRef {
6178            method: req.method(),
6179            path: req.path().to_string(),
6180            query: req.query().map(String::from),
6181            headers: req
6182                .headers()
6183                .iter()
6184                .map(|(name, value)| (name.to_string(), value.to_vec()))
6185                .collect(),
6186        })
6187    }
6188}
6189
/// Mutable response container for setting response headers and cookies.
///
/// This extractor allows handlers to set additional response headers and cookies
/// that will be merged into the final response. The handler's return value
/// determines the status code and body; `ResponseMut` adds headers on top.
///
/// This type is the plain data holder: it records headers to add, cookies to
/// set, and cookie names to delete, and [`ResponseMutations::apply`] turns them
/// into headers on a response.
///
/// NOTE(review): the example below calls `header`/`cookie` on a `ResponseMut`
/// value; those methods are not defined on `ResponseMutations` itself (it
/// exposes `add_header`, `add_cookie` and `remove_cookie`). Presumably
/// `ResponseMut` is a wrapper declared elsewhere — confirm the example matches
/// that API.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::ResponseMut;
/// use fastapi_core::response::Json;
///
/// async fn handler(mut resp: ResponseMut) -> Json<Data> {
///     resp.header("X-Custom-Header", "custom-value");
///     resp.cookie("session", "abc123");
///     Json(data)
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct ResponseMutations {
    /// Headers to add to the response, as `(name, raw value bytes)` pairs.
    pub headers: Vec<(String, Vec<u8>)>,
    /// Cookies to set; each one is emitted as its own `Set-Cookie` header.
    pub cookies: Vec<Cookie>,
    /// Names of cookies to delete (emitted as `Set-Cookie` with `Max-Age=0`).
    pub delete_cookies: Vec<String>,
}
6217
6218impl ResponseMutations {
6219    /// Create empty response mutations.
6220    #[must_use]
6221    pub fn new() -> Self {
6222        Self::default()
6223    }
6224
6225    /// Add a header.
6226    pub fn add_header(&mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) {
6227        self.headers.push((name.into(), value.into()));
6228    }
6229
6230    /// Set a cookie.
6231    pub fn add_cookie(&mut self, cookie: Cookie) {
6232        self.cookies.push(cookie);
6233    }
6234
6235    /// Delete a cookie by name.
6236    pub fn remove_cookie(&mut self, name: impl Into<String>) {
6237        self.delete_cookies.push(name.into());
6238    }
6239
6240    /// Apply mutations to a response.
6241    #[must_use]
6242    pub fn apply(self, mut response: crate::response::Response) -> crate::response::Response {
6243        // Add headers
6244        for (name, value) in self.headers {
6245            response = response.header(name, value);
6246        }
6247
6248        // Add Set-Cookie headers for cookies
6249        for cookie in self.cookies {
6250            response = response.header("Set-Cookie", cookie.to_header_value().into_bytes());
6251        }
6252
6253        // Add Set-Cookie headers to delete cookies
6254        for name in self.delete_cookies {
6255            // Sanitize the cookie name to prevent injection
6256            let sanitized_name = sanitize_cookie_token(&name);
6257            let delete_cookie = format!("{}=; Max-Age=0; Path=/", sanitized_name);
6258            response = response.header("Set-Cookie", delete_cookie.into_bytes());
6259        }
6260
6261        response
6262    }
6263}
6264
6265// ============================================================================
6266// Cookie Sanitization Helpers
6267// ============================================================================
6268
/// Strip every character from `name` that may not appear in a cookie name.
///
/// RFC 6265 defines cookie-name as an HTTP token. Only printable, non-space
/// ASCII characters that are not token delimiters are kept; everything else
/// (control characters, whitespace, non-ASCII, and the delimiters listed
/// below) is dropped so the name cannot be misread as part of another field.
fn sanitize_cookie_token(name: &str) -> String {
    let mut token = String::new();
    for ch in name.chars() {
        // Delimiters excluded from HTTP tokens: "(),/:;<=>?@[\]{}
        let is_delimiter = matches!(
            ch,
            '"' | '('
                | ')'
                | ','
                | '/'
                | ':'
                | ';'
                | '<'
                | '='
                | '>'
                | '?'
                | '@'
                | '['
                | '\\'
                | ']'
                | '{'
                | '}'
        );
        // `is_ascii_graphic` is exactly "ASCII, not a control char, not a space".
        if ch.is_ascii_graphic() && !is_delimiter {
            token.push(ch);
        }
    }
    token
}
6301
/// Strip every character from `value` that may not appear in a cookie value.
///
/// Following RFC 6265's cookie-value rules (unquoted form), this keeps only
/// printable, non-space ASCII characters and additionally drops the double
/// quote, comma, semicolon and backslash.
fn sanitize_cookie_value(value: &str) -> String {
    let mut cleaned = String::new();
    for ch in value.chars() {
        // `is_ascii_graphic` rules out control characters, spaces and non-ASCII.
        let forbidden_punctuation = matches!(ch, '"' | ',' | ';' | '\\');
        if ch.is_ascii_graphic() && !forbidden_punctuation {
            cleaned.push(ch);
        }
    }
    cleaned
}
6320
/// Sanitize a cookie attribute value (path, domain) to prevent injection.
///
/// Removes the attribute delimiter `;` and every control character. That
/// covers CR, LF and NUL (which could split the header or start a new
/// attribute) as well as the remaining control characters such as TAB, ESC
/// and DEL, which RFC 6265 does not allow in attribute values. All other
/// characters, including spaces and non-ASCII text, are kept unchanged.
fn sanitize_cookie_attr(attr: &str) -> String {
    attr.chars()
        .filter(|&c| c != ';' && !c.is_control())
        .collect()
}
6329
/// A cookie to set in the response.
///
/// Serialized into a `Set-Cookie` header value by [`Cookie::to_header_value`].
/// Optional attributes left as `None` (and flags left `false`) are omitted
/// from the header.
#[derive(Debug, Clone)]
pub struct Cookie {
    /// Cookie name.
    pub name: String,
    /// Cookie value.
    pub value: String,
    /// Max-Age in seconds (None = session cookie).
    pub max_age: Option<i64>,
    /// Path attribute; `None` leaves `Path` out of the header.
    pub path: Option<String>,
    /// Domain attribute; `None` leaves `Domain` out of the header.
    pub domain: Option<String>,
    /// Secure flag.
    pub secure: bool,
    /// HttpOnly flag.
    pub http_only: bool,
    /// SameSite attribute; `None` leaves `SameSite` out of the header.
    pub same_site: Option<SameSite>,
}
6350
6351impl Cookie {
6352    /// Create a new cookie with name and value.
6353    #[must_use]
6354    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
6355        Self {
6356            name: name.into(),
6357            value: value.into(),
6358            max_age: None,
6359            path: None,
6360            domain: None,
6361            secure: false,
6362            http_only: false,
6363            same_site: None,
6364        }
6365    }
6366
6367    /// Set the Max-Age attribute.
6368    #[must_use]
6369    pub fn max_age(mut self, seconds: i64) -> Self {
6370        self.max_age = Some(seconds);
6371        self
6372    }
6373
6374    /// Set the Path attribute.
6375    #[must_use]
6376    pub fn path(mut self, path: impl Into<String>) -> Self {
6377        self.path = Some(path.into());
6378        self
6379    }
6380
6381    /// Set the Domain attribute.
6382    #[must_use]
6383    pub fn domain(mut self, domain: impl Into<String>) -> Self {
6384        self.domain = Some(domain.into());
6385        self
6386    }
6387
6388    /// Set the Secure flag.
6389    #[must_use]
6390    pub fn secure(mut self, secure: bool) -> Self {
6391        self.secure = secure;
6392        self
6393    }
6394
6395    /// Set the HttpOnly flag.
6396    #[must_use]
6397    pub fn http_only(mut self, http_only: bool) -> Self {
6398        self.http_only = http_only;
6399        self
6400    }
6401
6402    /// Set the SameSite attribute.
6403    #[must_use]
6404    pub fn same_site(mut self, same_site: SameSite) -> Self {
6405        self.same_site = Some(same_site);
6406        self
6407    }
6408
6409    /// Convert to Set-Cookie header value.
6410    ///
6411    /// # Security
6412    ///
6413    /// Cookie names, values, and attribute values are sanitized to prevent
6414    /// attribute injection attacks. Characters that could be interpreted as
6415    /// attribute delimiters (`;`, `\r`, `\n`, `\0`) are removed.
6416    #[must_use]
6417    pub fn to_header_value(&self) -> String {
6418        // Sanitize cookie name: remove any characters that aren't valid tokens
6419        // RFC 6265: cookie-name = token (excludes CTLs, separators)
6420        let sanitized_name = sanitize_cookie_token(&self.name);
6421
6422        // Sanitize cookie value: remove characters that could inject attributes
6423        // RFC 6265: cookie-value excludes CTLs, whitespace, DQUOTE, comma, semicolon, backslash
6424        let sanitized_value = sanitize_cookie_value(&self.value);
6425
6426        let mut parts = vec![format!("{}={}", sanitized_name, sanitized_value)];
6427
6428        if let Some(max_age) = self.max_age {
6429            parts.push(format!("Max-Age={}", max_age));
6430        }
6431        if let Some(ref path) = self.path {
6432            // Sanitize path to prevent attribute injection
6433            let sanitized_path = sanitize_cookie_attr(path);
6434            parts.push(format!("Path={}", sanitized_path));
6435        }
6436        if let Some(ref domain) = self.domain {
6437            // Sanitize domain to prevent attribute injection
6438            let sanitized_domain = sanitize_cookie_attr(domain);
6439            parts.push(format!("Domain={}", sanitized_domain));
6440        }
6441        if self.secure {
6442            parts.push("Secure".to_string());
6443        }
6444        if self.http_only {
6445            parts.push("HttpOnly".to_string());
6446        }
6447        if let Some(ref same_site) = self.same_site {
6448            parts.push(format!("SameSite={}", same_site.as_str()));
6449        }
6450
6451        parts.join("; ")
6452    }
6453
6454    // =========================================================================
6455    // Secure Cookie Configuration Helpers
6456    // =========================================================================
6457
6458    /// Create a session cookie with secure defaults.
6459    ///
6460    /// Session cookies are:
6461    /// - HttpOnly (not accessible to JavaScript)
6462    /// - Secure (HTTPS only, unless `production` is false)
6463    /// - SameSite=Lax (sent with top-level navigations)
6464    /// - Path=/ (accessible site-wide)
6465    ///
6466    /// # Arguments
6467    ///
6468    /// * `name` - Cookie name
6469    /// * `value` - Cookie value
6470    /// * `production` - If true, sets Secure flag; if false, omits it for local development
6471    ///
6472    /// # Example
6473    ///
6474    /// ```ignore
6475    /// use fastapi_core::extract::Cookie;
6476    ///
6477    /// // Production session cookie
6478    /// let cookie = Cookie::session("session_id", "abc123", true);
6479    ///
6480    /// // Development session cookie (no Secure flag)
6481    /// let cookie = Cookie::session("session_id", "abc123", false);
6482    /// ```
6483    #[must_use]
6484    pub fn session(name: impl Into<String>, value: impl Into<String>, production: bool) -> Self {
6485        Self::new(name, value)
6486            .http_only(true)
6487            .secure(production)
6488            .same_site(SameSite::Lax)
6489            .path("/")
6490    }
6491
6492    /// Create an authentication cookie with strict security.
6493    ///
6494    /// Auth cookies are:
6495    /// - HttpOnly (not accessible to JavaScript)
6496    /// - Secure (HTTPS only, unless `production` is false)
6497    /// - SameSite=Strict (only sent in first-party context)
6498    /// - Path=/ (accessible site-wide)
6499    ///
6500    /// Use this for authentication tokens that should never be sent in cross-site requests.
6501    ///
6502    /// # Arguments
6503    ///
6504    /// * `name` - Cookie name
6505    /// * `value` - Cookie value
6506    /// * `production` - If true, sets Secure flag; if false, omits it for local development
6507    ///
6508    /// # Example
6509    ///
6510    /// ```ignore
6511    /// use fastapi_core::extract::Cookie;
6512    ///
6513    /// let cookie = Cookie::auth("auth_token", "jwt_here", true)
6514    ///     .max_age(86400); // 1 day
6515    /// ```
6516    #[must_use]
6517    pub fn auth(name: impl Into<String>, value: impl Into<String>, production: bool) -> Self {
6518        Self::new(name, value)
6519            .http_only(true)
6520            .secure(production)
6521            .same_site(SameSite::Strict)
6522            .path("/")
6523    }
6524
6525    /// Create a CSRF token cookie.
6526    ///
6527    /// CSRF cookies are:
6528    /// - NOT HttpOnly (must be readable by JavaScript to include in requests)
6529    /// - Secure (HTTPS only, unless `production` is false)
6530    /// - SameSite=Strict (only sent in first-party context)
6531    /// - Path=/ (accessible site-wide)
6532    ///
6533    /// The CSRF token must be accessible to JavaScript so it can be included in
6534    /// request headers or form data for validation.
6535    ///
6536    /// # Arguments
6537    ///
6538    /// * `name` - Cookie name (commonly "csrf_token" or "_csrf")
6539    /// * `value` - The CSRF token value
6540    /// * `production` - If true, sets Secure flag; if false, omits it for local development
6541    ///
6542    /// # Example
6543    ///
6544    /// ```ignore
6545    /// use fastapi_core::extract::Cookie;
6546    ///
6547    /// let cookie = Cookie::csrf("csrf_token", "random_token_here", true);
6548    /// ```
6549    #[must_use]
6550    pub fn csrf(name: impl Into<String>, value: impl Into<String>, production: bool) -> Self {
6551        Self::new(name, value)
6552            .http_only(false)
6553            .secure(production)
6554            .same_site(SameSite::Strict)
6555            .path("/")
6556    }
6557
6558    /// Create a cookie with the `__Host-` prefix.
6559    ///
6560    /// The `__Host-` prefix enforces that the cookie:
6561    /// - MUST have the Secure flag
6562    /// - MUST NOT have a Domain attribute
6563    /// - MUST have Path=/
6564    ///
6565    /// This provides the strongest cookie security by preventing the cookie from
6566    /// being set by subdomains or accessed across different paths.
6567    ///
6568    /// # Arguments
6569    ///
6570    /// * `name` - Cookie name (without the `__Host-` prefix - it will be added)
6571    /// * `value` - Cookie value
6572    ///
6573    /// # Example
6574    ///
6575    /// ```ignore
6576    /// use fastapi_core::extract::Cookie;
6577    ///
6578    /// // Creates cookie named "__Host-session"
6579    /// let cookie = Cookie::host_prefixed("session", "abc123")
6580    ///     .http_only(true)
6581    ///     .same_site(SameSite::Strict);
6582    /// ```
6583    #[must_use]
6584    pub fn host_prefixed(name: impl Into<String>, value: impl Into<String>) -> Self {
6585        let prefixed_name = format!("__Host-{}", name.into());
6586        Self::new(prefixed_name, value).secure(true).path("/")
6587    }
6588
6589    /// Create a cookie with the `__Secure-` prefix.
6590    ///
6591    /// The `__Secure-` prefix enforces that the cookie:
6592    /// - MUST have the Secure flag
6593    ///
6594    /// Unlike `__Host-`, this allows Domain and Path attributes.
6595    ///
6596    /// # Arguments
6597    ///
6598    /// * `name` - Cookie name (without the `__Secure-` prefix - it will be added)
6599    /// * `value` - Cookie value
6600    ///
6601    /// # Example
6602    ///
6603    /// ```ignore
6604    /// use fastapi_core::extract::Cookie;
6605    ///
6606    /// // Creates cookie named "__Secure-token"
6607    /// let cookie = Cookie::secure_prefixed("token", "abc123")
6608    ///     .domain(".example.com")
6609    ///     .http_only(true);
6610    /// ```
6611    #[must_use]
6612    pub fn secure_prefixed(name: impl Into<String>, value: impl Into<String>) -> Self {
6613        let prefixed_name = format!("__Secure-{}", name.into());
6614        Self::new(prefixed_name, value).secure(true)
6615    }
6616
6617    /// Validate that the cookie meets its prefix requirements.
6618    ///
6619    /// Returns `Ok(())` if valid, or `Err` with a description of the violation.
6620    ///
6621    /// # Cookie Prefix Rules
6622    ///
6623    /// - `__Host-`: Must have Secure=true, Path="/", and no Domain
6624    /// - `__Secure-`: Must have Secure=true
6625    ///
6626    /// # Example
6627    ///
6628    /// ```ignore
6629    /// use fastapi_core::extract::Cookie;
6630    ///
6631    /// let cookie = Cookie::host_prefixed("session", "abc123");
6632    /// assert!(cookie.validate_prefix().is_ok());
6633    ///
6634    /// // This would fail validation
6635    /// let invalid = Cookie::new("__Host-session", "abc123")
6636    ///     .domain("example.com"); // __Host- cannot have Domain
6637    /// assert!(invalid.validate_prefix().is_err());
6638    /// ```
6639    pub fn validate_prefix(&self) -> Result<(), CookiePrefixError> {
6640        if self.name.starts_with("__Host-") {
6641            if !self.secure {
6642                return Err(CookiePrefixError::HostRequiresSecure);
6643            }
6644            if self.domain.is_some() {
6645                return Err(CookiePrefixError::HostCannotHaveDomain);
6646            }
6647            if self.path.as_deref() != Some("/") {
6648                return Err(CookiePrefixError::HostRequiresRootPath);
6649            }
6650        } else if self.name.starts_with("__Secure-") && !self.secure {
6651            return Err(CookiePrefixError::SecureRequiresSecure);
6652        }
6653        Ok(())
6654    }
6655
6656    /// Check if this cookie has a security prefix.
6657    #[must_use]
6658    pub fn has_security_prefix(&self) -> bool {
6659        self.name.starts_with("__Host-") || self.name.starts_with("__Secure-")
6660    }
6661
6662    /// Get the security prefix type, if any.
6663    #[must_use]
6664    pub fn prefix(&self) -> Option<CookiePrefix> {
6665        if self.name.starts_with("__Host-") {
6666            Some(CookiePrefix::Host)
6667        } else if self.name.starts_with("__Secure-") {
6668            Some(CookiePrefix::Secure)
6669        } else {
6670            None
6671        }
6672    }
6673}
6674
/// Security prefixes that a cookie name may carry.
///
/// Modern browsers enforce extra requirements on cookies whose names start
/// with one of these prefixes:
/// - `__Host-`: strongest protection, locks the cookie to a single origin
/// - `__Secure-`: requires HTTPS, but still allows subdomain/path configuration
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CookiePrefix {
    /// The `__Host-` prefix.
    ///
    /// Requires Secure=true, Path="/", and no Domain attribute.
    Host,
    /// The `__Secure-` prefix.
    ///
    /// Requires Secure=true only.
    Secure,
}

impl CookiePrefix {
    /// Return the literal prefix text, including the trailing dash.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Secure => "__Secure-",
            Self::Host => "__Host-",
        }
    }
}
6702
/// Reasons a cookie can fail security-prefix validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CookiePrefixError {
    /// A `__Host-` cookie was not marked Secure.
    HostRequiresSecure,
    /// A `__Host-` cookie carried a Domain attribute.
    HostCannotHaveDomain,
    /// A `__Host-` cookie had a Path other than "/".
    HostRequiresRootPath,
    /// A `__Secure-` cookie was not marked Secure.
    SecureRequiresSecure,
}

impl std::fmt::Display for CookiePrefixError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every message is a fixed string, so pick it first and write it once.
        let message = match self {
            Self::HostRequiresSecure => "__Host- prefix requires Secure flag to be true",
            Self::HostCannotHaveDomain => "__Host- prefix cannot have a Domain attribute",
            Self::HostRequiresRootPath => "__Host- prefix requires Path=\"/\"",
            Self::SecureRequiresSecure => "__Secure- prefix requires Secure flag to be true",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CookiePrefixError {}
6736
/// Value of the `SameSite` cookie attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SameSite {
    /// Strict: cookie only sent in a first-party context.
    Strict,
    /// Lax: cookie sent with top-level navigations.
    Lax,
    /// None: cookie sent in all contexts (requires Secure).
    None,
}

impl SameSite {
    /// Return the attribute value as text (`"Strict"`, `"Lax"` or `"None"`).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::None => "None",
            Self::Lax => "Lax",
            Self::Strict => "Strict",
        }
    }
}
6759
6760// ============================================================================
6761// Cookie Request Extractors
6762// ============================================================================
6763
/// All cookies sent with the incoming request, keyed by name.
///
/// Built from the `Cookie` request header; each `name=value` pair becomes one
/// entry in the map.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::RequestCookies;
///
/// async fn handler(cookies: RequestCookies) -> impl IntoResponse {
///     if let Some(session_id) = cookies.get("session_id") {
///         format!("Session: {}", session_id)
///     } else {
///         "No session".to_string()
///     }
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct RequestCookies {
    cookies: std::collections::HashMap<String, String>,
}

impl RequestCookies {
    /// Create a collection that holds no cookies.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cookies: std::collections::HashMap::new(),
        }
    }

    /// Parse a `Cookie` header value of the form `"name1=value1; name2=value2"`.
    ///
    /// Segments are separated by `;`. Within a segment the first `=` splits
    /// name from value, and both sides are trimmed. Segments without `=` or
    /// with an empty name are ignored. When a name occurs more than once the
    /// last occurrence wins.
    #[must_use]
    pub fn from_header(header_value: &str) -> Self {
        let cookies = header_value
            .split(';')
            .filter_map(|segment| segment.trim().split_once('='))
            .map(|(raw_name, raw_value)| (raw_name.trim(), raw_value.trim()))
            .filter(|(key, _)| !key.is_empty())
            // Collecting into a map inserts in order, so later duplicates
            // overwrite earlier ones.
            .map(|(key, val)| (key.to_string(), val.to_string()))
            .collect();

        Self { cookies }
    }

    /// Look up a cookie value by name.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&str> {
        let found = self.cookies.get(name)?;
        Some(found.as_str())
    }

    /// Report whether a cookie with this name is present.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.cookies.get(name).is_some()
    }

    /// Number of distinct cookie names held.
    #[must_use]
    pub fn len(&self) -> usize {
        self.cookies.len()
    }

    /// Report whether the collection holds no cookies.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.cookies.is_empty()
    }

    /// Iterate over `(name, value)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.cookies
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_str()))
    }

    /// Iterate over the cookie names in unspecified order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.cookies.keys().map(|name| name.as_str())
    }
}
6847
6848impl FromRequest for RequestCookies {
6849    type Error = std::convert::Infallible;
6850
6851    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
6852        let cookies = req
6853            .headers()
6854            .get("cookie")
6855            .and_then(|v| std::str::from_utf8(v).ok())
6856            .map(Self::from_header)
6857            .unwrap_or_default();
6858
6859        Ok(cookies)
6860    }
6861}
6862
/// Extract a single cookie value by name from the incoming request.
///
/// The cookie name is specified via the `CookieName` trait, similar to how
/// `Header<T>` works with `HeaderName`.
///
/// `T` is only a compile-time marker: no value of `T` is ever stored, so
/// `RequestCookie<T>` is `Debug` and `Clone` for every `T`, including marker
/// types that implement neither trait.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::{RequestCookie, CookieName};
///
/// // Define a cookie name
/// struct SessionId;
/// impl CookieName for SessionId {
///     const NAME: &'static str = "session_id";
/// }
///
/// async fn handler(session: RequestCookie<SessionId>) -> impl IntoResponse {
///     format!("Session: {}", session.value())
/// }
///
/// // For optional cookies:
/// async fn optional_handler(session: Option<RequestCookie<SessionId>>) -> impl IntoResponse {
///     match session {
///         Some(s) => format!("Session: {}", s.value()),
///         None => "No session".to_string(),
///     }
/// }
/// ```
pub struct RequestCookie<T> {
    value: String,
    _marker: std::marker::PhantomData<T>,
}

// `Debug` and `Clone` are implemented by hand rather than derived: the derive
// macros would add `T: Debug` / `T: Clone` bounds, which plain marker types
// (unit structs with no derives) do not satisfy.
impl<T> std::fmt::Debug for RequestCookie<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same field layout the derived impl would print.
        f.debug_struct("RequestCookie")
            .field("value", &self.value)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T> Clone for RequestCookie<T> {
    fn clone(&self) -> Self {
        Self {
            value: self.value.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

impl<T> RequestCookie<T> {
    /// Create a new cookie extractor with the given value.
    #[must_use]
    pub fn new(value: impl Into<String>) -> Self {
        Self {
            value: value.into(),
            _marker: std::marker::PhantomData,
        }
    }

    /// Get the cookie value.
    #[must_use]
    pub fn value(&self) -> &str {
        &self.value
    }

    /// Consume and return the cookie value.
    #[must_use]
    pub fn into_value(self) -> String {
        self.value
    }
}

impl<T> Deref for RequestCookie<T> {
    type Target = str;

    /// Dereference to the cookie value so `str` methods can be called directly.
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<T> AsRef<str> for RequestCookie<T> {
    /// Borrow the cookie value as a string slice.
    fn as_ref(&self) -> &str {
        &self.value
    }
}
6933
/// Trait for defining cookie names used with `RequestCookie<T>`.
///
/// Implement it on a marker type to bind that type to one cookie name at
/// compile time; `RequestCookie<T>` reads `T::NAME` to decide which cookie to
/// look up and to name the cookie in its "not found" error.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::CookieName;
///
/// struct SessionId;
/// impl CookieName for SessionId {
///     const NAME: &'static str = "session_id";
/// }
/// ```
pub trait CookieName {
    /// The cookie name to extract.
    ///
    /// It is compared against the trimmed names parsed from the `Cookie`
    /// header as-is, so the match is exact and case-sensitive.
    const NAME: &'static str;
}
6950
/// Failure modes when extracting a cookie from a request.
#[derive(Debug)]
pub enum CookieExtractError {
    /// No cookie with the requested name was present.
    NotFound {
        /// Name of the cookie that was looked up.
        name: &'static str,
    },
    /// A cookie was present but its value did not have the expected form.
    InvalidValue {
        /// Name of the offending cookie.
        name: &'static str,
        /// The raw value that could not be parsed.
        value: String,
        /// Human-readable description of the expected format.
        expected: &'static str,
    },
}

impl fmt::Display for CookieExtractError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidValue {
                name: cookie,
                value: raw,
                expected: wanted,
            } => write!(
                f,
                "Invalid cookie '{}' value '{}': expected {}",
                cookie, raw, wanted
            ),
            Self::NotFound { name: missing } => write!(f, "Cookie '{}' not found", missing),
        }
    }
}

impl std::error::Error for CookieExtractError {}
6992
6993impl IntoResponse for CookieExtractError {
6994    fn into_response(self) -> crate::response::Response {
6995        match self {
6996            Self::NotFound { name } => ValidationErrors::single(
6997                ValidationError::missing(crate::error::loc::cookie(name))
6998                    .with_msg("Cookie is required"),
6999            )
7000            .into_response(),
7001            Self::InvalidValue {
7002                name,
7003                value,
7004                expected,
7005            } => ValidationErrors::single(
7006                ValidationError::type_error(crate::error::loc::cookie(name), expected)
7007                    .with_msg(format!("Expected {expected}"))
7008                    .with_input(serde_json::Value::String(value)),
7009            )
7010            .into_response(),
7011        }
7012    }
7013}
7014
7015impl<T: CookieName> FromRequest for RequestCookie<T> {
7016    type Error = CookieExtractError;
7017
7018    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
7019        let cookies = req
7020            .headers()
7021            .get("cookie")
7022            .and_then(|v| std::str::from_utf8(v).ok())
7023            .map(RequestCookies::from_header)
7024            .unwrap_or_default();
7025
7026        cookies
7027            .get(T::NAME)
7028            .map(|v| RequestCookie::new(v))
7029            .ok_or(CookieExtractError::NotFound { name: T::NAME })
7030    }
7031}
7032
7033// Common cookie name types for convenience
7034
7035/// Session ID cookie name marker.
7036pub struct SessionIdCookie;
7037impl CookieName for SessionIdCookie {
7038    const NAME: &'static str = "session_id";
7039}
7040
7041/// CSRF token cookie name marker.
7042pub struct CsrfTokenCookie;
7043impl CookieName for CsrfTokenCookie {
7044    const NAME: &'static str = "csrf_token";
7045}
7046
7047// ============================================================================
7048// Response Mutation Extractor
7049// ============================================================================
7050
/// Mutable response wrapper for setting headers and cookies.
///
/// Holds a mutable borrow of a `ResponseMutations` value and forwards every
/// header/cookie change to it. The wrapper records changes only; applying
/// them to the outgoing response is done elsewhere, after the handler
/// returns.
///
/// NOTE(review): the example below shows `ResponseMut` as a handler
/// parameter, but no `FromRequest` impl for it is visible in this part of the
/// file (only `ResponseMutations` has one) — confirm how handlers obtain it.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::ResponseMut;
///
/// async fn handler(mut resp: ResponseMut) -> &'static str {
///     resp.header("X-Powered-By", "fastapi-rust");
///     resp.cookie("visited", "true");
///     "Hello"
/// }
/// ```
pub struct ResponseMut<'a> {
    // Destination for every recorded header and cookie change.
    mutations: &'a mut ResponseMutations,
}
7071
7072impl<'a> ResponseMut<'a> {
7073    /// Set a response header.
7074    pub fn header(&mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) {
7075        self.mutations.add_header(name, value);
7076    }
7077
7078    /// Set a cookie.
7079    pub fn cookie(&mut self, name: impl Into<String>, value: impl Into<String>) {
7080        self.mutations.add_cookie(Cookie::new(name, value));
7081    }
7082
7083    /// Set a cookie with full configuration.
7084    pub fn set_cookie(&mut self, cookie: Cookie) {
7085        self.mutations.add_cookie(cookie);
7086    }
7087
7088    /// Delete a cookie by name.
7089    pub fn delete_cookie(&mut self, name: impl Into<String>) {
7090        self.mutations.remove_cookie(name);
7091    }
7092}
7093
7094impl<'a> std::fmt::Debug for ResponseMut<'a> {
7095    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7096        f.debug_struct("ResponseMut")
7097            .field("mutations", &self.mutations)
7098            .finish()
7099    }
7100}
7101
// Note: `ResponseMut` cannot implement `FromRequest` because it holds a
// borrowed `&mut ResponseMutations`. Instead, handlers should extract
// `ResponseMutations` and take a `&mut` reference to it. The App will apply
// the mutations after handler execution.
7105
7106impl FromRequest for ResponseMutations {
7107    type Error = std::convert::Infallible;
7108
7109    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
7110        // Get existing mutations or create new ones
7111        if let Some(mutations) = req.get_extension::<ResponseMutations>() {
7112            Ok(mutations.clone())
7113        } else {
7114            let mutations = ResponseMutations::new();
7115            req.insert_extension(mutations.clone());
7116            Ok(mutations)
7117        }
7118    }
7119}
7120
7121// ============================================================================
7122// Background Tasks Extractor
7123// ============================================================================
7124
7125use std::sync::Arc;
7126
/// Turn a panic payload into a human-readable message.
///
/// Payloads of type `&'static str` or `String` are returned as text; any
/// other payload type yields `"unknown panic"`.
///
/// The parameter stays a `&Box<..>` so callers can pass the payload returned
/// by `catch_unwind` by reference.
fn format_panic_message(panic_info: &Box<dyn std::any::Any + Send>) -> String {
    // Reborrow the boxed payload explicitly so the downcasts inspect the
    // payload itself rather than the `Box` that carries it.
    let payload: &(dyn std::any::Any + Send) = &**panic_info;

    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic".to_string())
}
7140
/// Type alias for a boxed async task function.
///
/// A task is a one-shot closure (`FnOnce`) that, when called, produces a
/// pinned, boxed future resolving to `()`. Both the closure and the future it
/// returns must be `Send`.
pub type BackgroundTask =
    Box<dyn FnOnce() -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send>;
7144
/// Internal storage for background tasks (thread-safe).
///
/// This uses `parking_lot::Mutex` for interior mutability while being Send + Sync.
///
/// The queue lives behind an `Arc`, so cloning this value yields another
/// handle to the same queue rather than a copy of the tasks.
#[derive(Default, Clone)]
pub struct BackgroundTasksInner {
    // Shared, lock-protected list of pending tasks in insertion order.
    inner: Arc<parking_lot::Mutex<Vec<BackgroundTask>>>,
}
7152
7153impl BackgroundTasksInner {
7154    /// Create a new empty task storage.
7155    #[must_use]
7156    pub fn new() -> Self {
7157        Self {
7158            inner: Arc::new(parking_lot::Mutex::new(Vec::new())),
7159        }
7160    }
7161
7162    /// Add a task to the queue.
7163    pub fn push(&self, task: BackgroundTask) {
7164        self.inner.lock().push(task);
7165    }
7166
7167    /// Take all tasks from the queue.
7168    pub fn take(&self) -> Vec<BackgroundTask> {
7169        std::mem::take(&mut *self.inner.lock())
7170    }
7171
7172    /// Returns the number of tasks.
7173    #[must_use]
7174    pub fn len(&self) -> usize {
7175        self.inner.lock().len()
7176    }
7177
7178    /// Returns true if there are no tasks.
7179    #[must_use]
7180    pub fn is_empty(&self) -> bool {
7181        self.inner.lock().is_empty()
7182    }
7183}
7184
7185impl std::fmt::Debug for BackgroundTasksInner {
7186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7187        f.debug_struct("BackgroundTasksInner")
7188            .field("task_count", &self.len())
7189            .finish()
7190    }
7191}
7192
/// Background task queue for running tasks after response is sent.
///
/// Tasks are executed in the order they are added, after the response
/// has been sent to the client. This is useful for:
/// - Sending emails
/// - Writing to external logs
/// - Triggering webhooks
/// - Updating caches
///
/// Cloning a `BackgroundTasks` produces another handle to the same underlying
/// queue (the storage is reference-counted), so tasks added through any clone
/// are seen by all of them.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::BackgroundTasks;
///
/// async fn handler(mut tasks: BackgroundTasks) -> &'static str {
///     tasks.add_task(|| async {
///         // Send notification email
///         send_email("user@example.com", "Welcome!").await;
///     });
///     "Response sent, email will be sent in background"
/// }
/// ```
///
/// # Note
///
/// Background tasks run after the response is sent but before the request
/// context is fully cleaned up. They share the same cancellation context
/// as the request, so long-running tasks should check for cancellation.
#[derive(Clone)]
pub struct BackgroundTasks {
    // Shared task storage; the same storage type is kept in request extensions.
    inner: BackgroundTasksInner,
}
7225
7226impl Default for BackgroundTasks {
7227    fn default() -> Self {
7228        Self::new()
7229    }
7230}
7231
7232impl BackgroundTasks {
7233    /// Create a new empty task queue.
7234    #[must_use]
7235    pub fn new() -> Self {
7236        Self {
7237            inner: BackgroundTasksInner::new(),
7238        }
7239    }
7240
7241    /// Create from inner storage.
7242    #[must_use]
7243    pub(crate) fn from_inner(inner: BackgroundTasksInner) -> Self {
7244        Self { inner }
7245    }
7246
7247    /// Add a background task.
7248    ///
7249    /// The task will be executed after the response is sent.
7250    pub fn add_task<F, Fut>(&mut self, task: F)
7251    where
7252        F: FnOnce() -> Fut + Send + 'static,
7253        Fut: std::future::Future<Output = ()> + Send + 'static,
7254    {
7255        self.inner.push(Box::new(move || Box::pin(task())));
7256    }
7257
7258    /// Add a synchronous background task.
7259    ///
7260    /// The task will be executed after the response is sent.
7261    pub fn add_sync_task<F>(&mut self, task: F)
7262    where
7263        F: FnOnce() + Send + 'static,
7264    {
7265        self.inner.push(Box::new(move || {
7266            Box::pin(async move {
7267                task();
7268            })
7269        }));
7270    }
7271
7272    /// Take all tasks from the queue.
7273    pub fn take_tasks(&mut self) -> Vec<BackgroundTask> {
7274        self.inner.take()
7275    }
7276
7277    /// Returns true if there are no tasks.
7278    #[must_use]
7279    pub fn is_empty(&self) -> bool {
7280        self.inner.is_empty()
7281    }
7282
7283    /// Returns the number of tasks.
7284    #[must_use]
7285    pub fn len(&self) -> usize {
7286        self.inner.len()
7287    }
7288
7289    /// Execute all tasks sequentially.
7290    ///
7291    /// This is called by the framework after the response is sent.
7292    /// Tasks run in the order they were added (FIFO).
7293    ///
7294    /// # Error Handling
7295    ///
7296    /// This method does NOT provide error isolation. If a task panics,
7297    /// subsequent tasks will not run. For panic isolation, use
7298    /// [`Self::execute_with_panic_isolation()`] instead.
7299    ///
7300    /// Since tasks run after the response is sent, panics do not affect
7301    /// the HTTP response that was already delivered to the client.
7302    ///
7303    /// # Example
7304    ///
7305    /// ```ignore
7306    /// let response = app.handle(&ctx, &mut request).await;
7307    /// // Send response to client...
7308    /// if let Some(tasks) = App::take_background_tasks(&mut request) {
7309    ///     tasks.execute_all().await;
7310    /// }
7311    /// ```
7312    pub async fn execute_all(mut self) {
7313        for task in self.take_tasks() {
7314            let future = task();
7315            future.await;
7316        }
7317    }
7318
7319    /// Execute all tasks with cancellation support via RequestContext.
7320    ///
7321    /// This version checks for cancellation between tasks and respects
7322    /// the request's cancellation state. Use this when you want background
7323    /// tasks to be cancelled along with the request.
7324    ///
7325    /// # Arguments
7326    ///
7327    /// * `ctx` - The request context for cancellation checking
7328    ///
7329    /// # Cancellation Behavior
7330    ///
7331    /// - Checks `ctx.is_cancelled()` before starting each task
7332    /// - If cancelled, remaining tasks are skipped
7333    /// - Already-running tasks complete before the check
7334    /// - Logs the number of skipped tasks
7335    ///
7336    /// # Integration with asupersync
7337    ///
7338    /// Background tasks run in the same region as the request. When the
7339    /// region is cancelled (client disconnect, timeout, etc.), subsequent
7340    /// tasks will be skipped. This ensures proper structured concurrency.
7341    ///
7342    /// # Example
7343    ///
7344    /// ```ignore
7345    /// let response = app.handle(&ctx, &mut request).await;
7346    /// // Send response to client...
7347    /// if let Some(tasks) = App::take_background_tasks(&mut request) {
7348    ///     tasks.execute_with_context(&ctx).await;
7349    /// }
7350    /// ```
7351    pub async fn execute_with_context(mut self, ctx: &RequestContext) {
7352        let tasks = self.take_tasks();
7353        let task_count = tasks.len();
7354        let mut executed_count = 0;
7355
7356        for (index, task) in tasks.into_iter().enumerate() {
7357            // Check for cancellation before starting each task
7358            if ctx.is_cancelled() {
7359                let remaining = task_count - index;
7360                if remaining > 0 {
7361                    ctx.trace(&format!(
7362                        "BackgroundTasks: Cancellation requested, skipping {} remaining task(s)",
7363                        remaining
7364                    ));
7365                }
7366                break;
7367            }
7368
7369            // Execute the task
7370            let future = task();
7371            future.await;
7372            executed_count += 1;
7373        }
7374
7375        // Trace completion for observability
7376        if task_count > 0 {
7377            ctx.trace(&format!(
7378                "BackgroundTasks: Executed {}/{} tasks",
7379                executed_count, task_count
7380            ));
7381        }
7382    }
7383
7384    /// Execute tasks with per-task panic isolation using catch_unwind.
7385    ///
7386    /// This version catches panics that occur when **calling the task closure**
7387    /// (i.e., when creating the future), allowing subsequent tasks to continue.
7388    ///
7389    /// # What Is Caught vs. Not Caught
7390    ///
7391    /// ```ignore
7392    /// // CAUGHT: Panic in closure before returning future
7393    /// tasks.add_task(|| {
7394    ///     panic!("this is caught");
7395    ///     async {}  // never reached
7396    /// });
7397    ///
7398    /// // NOT CAUGHT: Panic inside async block (during .await)
7399    /// tasks.add_task(|| async {
7400    ///     panic!("this is NOT caught - propagates and stops remaining tasks");
7401    /// });
7402    /// ```
7403    ///
7404    /// Due to Rust's async limitation, panics inside the async task body
7405    /// cannot be caught without additional crate dependencies (like `futures`).
7406    ///
7407    /// For most use cases where task code is well-behaved, `execute_all()`
7408    /// or `execute_with_context()` is sufficient.
7409    pub async fn execute_with_panic_isolation(mut self) {
7410        let tasks = self.take_tasks();
7411        let task_count = tasks.len();
7412        let mut success_count = 0;
7413        let mut panic_count = 0;
7414
7415        for (index, task) in tasks.into_iter().enumerate() {
7416            // Attempt to get the future - catch panics in the closure
7417            let future_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| task()));
7418
7419            match future_result {
7420                Ok(future) => {
7421                    // Got the future, now await it
7422                    // Note: Panics during await will still propagate
7423                    future.await;
7424                    success_count += 1;
7425                }
7426                Err(panic_info) => {
7427                    // Task closure panicked
7428                    panic_count += 1;
7429                    let panic_msg = format_panic_message(&panic_info);
7430                    eprintln!(
7431                        "[BackgroundTasks] Task {}/{} panicked: {}",
7432                        index + 1,
7433                        task_count,
7434                        panic_msg
7435                    );
7436                }
7437            }
7438        }
7439
7440        if panic_count > 0 {
7441            eprintln!(
7442                "[BackgroundTasks] Completed with {}/{} successful, {} panicked",
7443                success_count, task_count, panic_count
7444            );
7445        }
7446    }
7447
7448    /// Get the inner storage for request extensions.
7449    pub fn into_inner(self) -> BackgroundTasksInner {
7450        self.inner
7451    }
7452}
7453
7454impl std::fmt::Debug for BackgroundTasks {
7455    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7456        f.debug_struct("BackgroundTasks")
7457            .field("task_count", &self.len())
7458            .finish()
7459    }
7460}
7461
7462impl FromRequest for BackgroundTasks {
7463    type Error = std::convert::Infallible;
7464
7465    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
7466        // Get existing task storage or create new one
7467        if let Some(inner) = req.get_extension::<BackgroundTasksInner>() {
7468            Ok(BackgroundTasks::from_inner(inner.clone()))
7469        } else {
7470            let inner = BackgroundTasksInner::new();
7471            req.insert_extension(inner.clone());
7472            Ok(BackgroundTasks::from_inner(inner))
7473        }
7474    }
7475}
7476
7477#[cfg(test)]
7478mod special_extractor_tests {
7479    use super::*;
7480    use crate::request::Method;
7481
7482    fn test_context() -> RequestContext {
7483        let cx = asupersync::Cx::for_testing();
7484        RequestContext::new(cx, 12345)
7485    }
7486
7487    #[test]
7488    fn request_ref_extracts_metadata() {
7489        let ctx = test_context();
7490        let mut req = Request::new(Method::Get, "/users/42");
7491        req.set_query(Some("page=1".to_string()));
7492        req.headers_mut()
7493            .insert("content-type", b"application/json".to_vec());
7494
7495        let result = futures_executor::block_on(RequestRef::from_request(&ctx, &mut req));
7496        let req_ref = result.unwrap();
7497
7498        assert_eq!(req_ref.method(), Method::Get);
7499        assert_eq!(req_ref.path(), "/users/42");
7500        assert_eq!(req_ref.query(), Some("page=1"));
7501        assert_eq!(
7502            req_ref.header("content-type"),
7503            Some(b"application/json".as_slice())
7504        );
7505    }
7506
7507    #[test]
7508    fn request_ref_header_case_insensitive() {
7509        let ctx = test_context();
7510        let mut req = Request::new(Method::Get, "/");
7511        req.headers_mut()
7512            .insert("X-Custom-Header", b"value".to_vec());
7513
7514        let result = futures_executor::block_on(RequestRef::from_request(&ctx, &mut req));
7515        let req_ref = result.unwrap();
7516
7517        assert_eq!(req_ref.header("x-custom-header"), Some(b"value".as_slice()));
7518        assert_eq!(req_ref.header("X-CUSTOM-HEADER"), Some(b"value".as_slice()));
7519    }
7520
7521    #[test]
7522    fn cookie_to_header_value_simple() {
7523        let cookie = Cookie::new("session", "abc123");
7524        assert_eq!(cookie.to_header_value(), "session=abc123");
7525    }
7526
7527    #[test]
7528    fn cookie_to_header_value_with_attributes() {
7529        let cookie = Cookie::new("session", "abc123")
7530            .max_age(3600)
7531            .path("/")
7532            .secure(true)
7533            .http_only(true)
7534            .same_site(SameSite::Strict);
7535
7536        let header = cookie.to_header_value();
7537        assert!(header.contains("session=abc123"));
7538        assert!(header.contains("Max-Age=3600"));
7539        assert!(header.contains("Path=/"));
7540        assert!(header.contains("Secure"));
7541        assert!(header.contains("HttpOnly"));
7542        assert!(header.contains("SameSite=Strict"));
7543    }
7544
7545    #[test]
7546    fn response_mutations_apply_headers() {
7547        let mut mutations = ResponseMutations::new();
7548        mutations.add_header("X-Custom", "value");
7549        mutations.add_header("X-Another", "other");
7550
7551        let response = crate::response::Response::ok();
7552        let response = mutations.apply(response);
7553
7554        let headers = response.headers();
7555        assert!(
7556            headers
7557                .iter()
7558                .any(|(n, v)| n == "X-Custom" && v == b"value")
7559        );
7560        assert!(
7561            headers
7562                .iter()
7563                .any(|(n, v)| n == "X-Another" && v == b"other")
7564        );
7565    }
7566
7567    #[test]
7568    fn response_mutations_apply_cookies() {
7569        let mut mutations = ResponseMutations::new();
7570        mutations.add_cookie(Cookie::new("session", "abc").http_only(true));
7571
7572        let response = crate::response::Response::ok();
7573        let response = mutations.apply(response);
7574
7575        let headers = response.headers();
7576        let set_cookie = headers
7577            .iter()
7578            .find(|(n, _)| n == "Set-Cookie")
7579            .map(|(_, v)| String::from_utf8_lossy(v).to_string());
7580        assert!(set_cookie.is_some());
7581        assert!(set_cookie.unwrap().contains("session=abc"));
7582    }
7583
7584    #[test]
7585    fn response_mutations_delete_cookie() {
7586        let mut mutations = ResponseMutations::new();
7587        mutations.remove_cookie("session");
7588
7589        let response = crate::response::Response::ok();
7590        let response = mutations.apply(response);
7591
7592        let headers = response.headers();
7593        let set_cookie = headers
7594            .iter()
7595            .find(|(n, _)| n == "Set-Cookie")
7596            .map(|(_, v)| String::from_utf8_lossy(v).to_string());
7597        assert!(set_cookie.is_some());
7598        let cookie_header = set_cookie.unwrap();
7599        assert!(cookie_header.contains("session="));
7600        assert!(cookie_header.contains("Max-Age=0"));
7601    }
7602
7603    #[test]
7604    fn response_mutations_extract() {
7605        let ctx = test_context();
7606        let mut req = Request::new(Method::Get, "/");
7607
7608        let result = futures_executor::block_on(ResponseMutations::from_request(&ctx, &mut req));
7609        let mutations = result.unwrap();
7610        assert!(mutations.headers.is_empty());
7611        assert!(mutations.cookies.is_empty());
7612    }
7613
7614    // =========================================================================
7615    // Cookie Sanitization Security Tests
7616    // =========================================================================
7617
7618    #[test]
7619    fn cookie_sanitizes_semicolon_injection_in_value() {
7620        // Attacker tries to inject Domain attribute via value
7621        let cookie = Cookie::new("session", "abc; Domain=.evil.com");
7622        let header = cookie.to_header_value();
7623        // Semicolon should be removed, preventing attribute injection
7624        assert_eq!(header, "session=abcDomain=.evil.com");
7625        assert!(!header.contains("; Domain"));
7626    }
7627
7628    #[test]
7629    fn cookie_sanitizes_semicolon_injection_in_name() {
7630        // Attacker tries to inject via cookie name
7631        let cookie = Cookie::new("session; HttpOnly", "value");
7632        let header = cookie.to_header_value();
7633        // Semicolon should be removed from name
7634        assert!(!header.starts_with("session; "));
7635        assert!(header.starts_with("sessionHttpOnly="));
7636    }
7637
7638    #[test]
7639    fn cookie_sanitizes_path_injection() {
7640        // Attacker tries to inject attributes via path
7641        let cookie = Cookie::new("session", "abc").path("/; HttpOnly; Domain=.evil.com");
7642        let header = cookie.to_header_value();
7643        // Semicolons should be removed from path, preventing attribute injection
7644        // The path becomes "/HttpOnlyDomain=.evil.com" (gibberish but safe)
7645        assert!(!header.contains("; Domain"));
7646        assert!(!header.contains("; HttpOnly"));
7647        // Verify path is present but sanitized (no semicolons)
7648        assert!(header.contains("Path=/"));
7649    }
7650
7651    #[test]
7652    fn cookie_sanitizes_domain_injection() {
7653        // Attacker tries to inject attributes via domain
7654        let cookie = Cookie::new("session", "abc").domain(".example.com; HttpOnly=false");
7655        let header = cookie.to_header_value();
7656        // Semicolons should be removed from domain, preventing attribute injection
7657        assert!(!header.contains("; HttpOnly=false"));
7658        // Domain is sanitized (semicolon removed), but space is preserved in attr values
7659        assert!(header.contains("Domain=.example.com HttpOnly=false"));
7660    }
7661
7662    #[test]
7663    fn cookie_sanitizes_control_characters() {
7664        // Attacker tries CRLF injection
7665        let cookie = Cookie::new("session", "abc\r\nSet-Cookie: evil=value");
7666        let header = cookie.to_header_value();
7667        // Control characters and spaces should be removed
7668        assert!(!header.contains("\r"));
7669        assert!(!header.contains("\n"));
7670        assert!(!header.contains(" ")); // Space is also removed from cookie values
7671        // The sanitized value is "abcSet-Cookie:evil=value" (no CRLF injection possible)
7672        assert!(header.contains("session=abcSet-Cookie:evil=value"));
7673    }
7674
7675    #[test]
7676    fn delete_cookie_sanitizes_name() {
7677        // Attacker tries to inject via delete cookie name
7678        let mut mutations = ResponseMutations::new();
7679        mutations.remove_cookie("session; Domain=.evil.com");
7680
7681        let response = crate::response::Response::ok();
7682        let response = mutations.apply(response);
7683
7684        let headers = response.headers();
7685        let set_cookie = headers
7686            .iter()
7687            .find(|(n, _)| n == "Set-Cookie")
7688            .map(|(_, v)| String::from_utf8_lossy(v).to_string());
7689        assert!(set_cookie.is_some());
7690        let cookie_header = set_cookie.unwrap();
7691        // Semicolon should be removed, no Domain attribute injected
7692        assert!(!cookie_header.contains("; Domain"));
7693    }
7694
7695    // =========================================================================
7696    // Secure Cookie Configuration Helper Tests
7697    // =========================================================================
7698
7699    #[test]
7700    fn session_cookie_production() {
7701        let cookie = Cookie::session("session_id", "abc123", true);
7702        assert_eq!(cookie.name, "session_id");
7703        assert_eq!(cookie.value, "abc123");
7704        assert!(cookie.http_only);
7705        assert!(cookie.secure);
7706        assert_eq!(cookie.same_site, Some(SameSite::Lax));
7707        assert_eq!(cookie.path, Some("/".to_string()));
7708    }
7709
7710    #[test]
7711    fn session_cookie_development() {
7712        let cookie = Cookie::session("session_id", "abc123", false);
7713        assert!(cookie.http_only);
7714        assert!(!cookie.secure); // No Secure flag in development
7715        assert_eq!(cookie.same_site, Some(SameSite::Lax));
7716    }
7717
7718    #[test]
7719    fn auth_cookie_production() {
7720        let cookie = Cookie::auth("auth_token", "jwt_token", true);
7721        assert_eq!(cookie.name, "auth_token");
7722        assert!(cookie.http_only);
7723        assert!(cookie.secure);
7724        assert_eq!(cookie.same_site, Some(SameSite::Strict)); // Stricter than session
7725        assert_eq!(cookie.path, Some("/".to_string()));
7726    }
7727
7728    #[test]
7729    fn csrf_cookie_is_readable_by_js() {
7730        let cookie = Cookie::csrf("csrf_token", "random_value", true);
7731        assert_eq!(cookie.name, "csrf_token");
7732        assert!(!cookie.http_only); // Must be readable by JS
7733        assert!(cookie.secure);
7734        assert_eq!(cookie.same_site, Some(SameSite::Strict));
7735    }
7736
7737    #[test]
7738    fn host_prefixed_cookie() {
7739        let cookie = Cookie::host_prefixed("session", "abc123");
7740        assert_eq!(cookie.name, "__Host-session");
7741        assert!(cookie.secure);
7742        assert_eq!(cookie.path, Some("/".to_string()));
7743        assert!(cookie.domain.is_none());
7744        assert!(cookie.validate_prefix().is_ok());
7745    }
7746
7747    #[test]
7748    fn host_prefixed_cookie_validation_fails_without_secure() {
7749        let cookie = Cookie::new("__Host-session", "abc123")
7750            .path("/")
7751            .secure(false);
7752        assert_eq!(
7753            cookie.validate_prefix(),
7754            Err(CookiePrefixError::HostRequiresSecure)
7755        );
7756    }
7757
7758    #[test]
7759    fn host_prefixed_cookie_validation_fails_with_domain() {
7760        let cookie = Cookie::new("__Host-session", "abc123")
7761            .path("/")
7762            .secure(true)
7763            .domain("example.com");
7764        assert_eq!(
7765            cookie.validate_prefix(),
7766            Err(CookiePrefixError::HostCannotHaveDomain)
7767        );
7768    }
7769
7770    #[test]
7771    fn host_prefixed_cookie_validation_fails_without_root_path() {
7772        let cookie = Cookie::new("__Host-session", "abc123")
7773            .path("/api")
7774            .secure(true);
7775        assert_eq!(
7776            cookie.validate_prefix(),
7777            Err(CookiePrefixError::HostRequiresRootPath)
7778        );
7779    }
7780
7781    #[test]
7782    fn secure_prefixed_cookie() {
7783        let cookie = Cookie::secure_prefixed("token", "abc123");
7784        assert_eq!(cookie.name, "__Secure-token");
7785        assert!(cookie.secure);
7786        // __Secure- allows Domain and Path
7787        let cookie = cookie.domain("example.com").path("/api");
7788        assert!(cookie.validate_prefix().is_ok());
7789    }
7790
7791    #[test]
7792    fn secure_prefixed_cookie_validation_fails_without_secure() {
7793        let cookie = Cookie::new("__Secure-token", "abc123").secure(false);
7794        assert_eq!(
7795            cookie.validate_prefix(),
7796            Err(CookiePrefixError::SecureRequiresSecure)
7797        );
7798    }
7799
7800    #[test]
7801    fn cookie_prefix_detection() {
7802        let host_cookie = Cookie::host_prefixed("session", "abc");
7803        assert!(host_cookie.has_security_prefix());
7804        assert_eq!(host_cookie.prefix(), Some(CookiePrefix::Host));
7805
7806        let secure_cookie = Cookie::secure_prefixed("token", "abc");
7807        assert!(secure_cookie.has_security_prefix());
7808        assert_eq!(secure_cookie.prefix(), Some(CookiePrefix::Secure));
7809
7810        let normal_cookie = Cookie::new("regular", "abc");
7811        assert!(!normal_cookie.has_security_prefix());
7812        assert_eq!(normal_cookie.prefix(), None);
7813    }
7814
7815    #[test]
7816    fn cookie_prefix_as_str() {
7817        assert_eq!(CookiePrefix::Host.as_str(), "__Host-");
7818        assert_eq!(CookiePrefix::Secure.as_str(), "__Secure-");
7819    }
7820
7821    #[test]
7822    fn cookie_prefix_error_display() {
7823        assert_eq!(
7824            CookiePrefixError::HostRequiresSecure.to_string(),
7825            "__Host- prefix requires Secure flag to be true"
7826        );
7827        assert_eq!(
7828            CookiePrefixError::HostCannotHaveDomain.to_string(),
7829            "__Host- prefix cannot have a Domain attribute"
7830        );
7831        assert_eq!(
7832            CookiePrefixError::HostRequiresRootPath.to_string(),
7833            "__Host- prefix requires Path=\"/\""
7834        );
7835        assert_eq!(
7836            CookiePrefixError::SecureRequiresSecure.to_string(),
7837            "__Secure- prefix requires Secure flag to be true"
7838        );
7839    }
7840
7841    #[test]
7842    fn session_cookie_header_format() {
7843        let cookie = Cookie::session("sid", "abc", true);
7844        let header = cookie.to_header_value();
7845        assert!(header.contains("sid=abc"));
7846        assert!(header.contains("HttpOnly"));
7847        assert!(header.contains("Secure"));
7848        assert!(header.contains("SameSite=Lax"));
7849        assert!(header.contains("Path=/"));
7850    }
7851
7852    #[test]
7853    fn host_prefixed_cookie_header_format() {
7854        let cookie = Cookie::host_prefixed("session", "abc")
7855            .http_only(true)
7856            .same_site(SameSite::Strict);
7857        let header = cookie.to_header_value();
7858        assert!(header.contains("__Host-session=abc"));
7859        assert!(header.contains("Secure"));
7860        assert!(header.contains("Path=/"));
7861        assert!(header.contains("HttpOnly"));
7862        assert!(header.contains("SameSite=Strict"));
7863    }
7864
7865    // =========================================================================
7866    // Request Cookie Extractor Tests
7867    // =========================================================================
7868
7869    #[test]
7870    fn request_cookies_parses_single_cookie() {
7871        let cookies = RequestCookies::from_header("session_id=abc123");
7872        assert_eq!(cookies.len(), 1);
7873        assert_eq!(cookies.get("session_id"), Some("abc123"));
7874    }
7875
7876    #[test]
7877    fn request_cookies_parses_multiple_cookies() {
7878        let cookies = RequestCookies::from_header("session_id=abc123; user=bob; theme=dark");
7879        assert_eq!(cookies.len(), 3);
7880        assert_eq!(cookies.get("session_id"), Some("abc123"));
7881        assert_eq!(cookies.get("user"), Some("bob"));
7882        assert_eq!(cookies.get("theme"), Some("dark"));
7883    }
7884
7885    #[test]
7886    fn request_cookies_handles_whitespace() {
7887        let cookies = RequestCookies::from_header("  session_id = abc123 ;  user=bob  ");
7888        assert_eq!(cookies.get("session_id"), Some("abc123"));
7889        assert_eq!(cookies.get("user"), Some("bob"));
7890    }
7891
7892    #[test]
7893    fn request_cookies_handles_empty_header() {
7894        let cookies = RequestCookies::from_header("");
7895        assert!(cookies.is_empty());
7896    }
7897
7898    #[test]
7899    fn request_cookies_handles_malformed_pairs() {
7900        // Malformed pairs without = should be skipped
7901        let cookies = RequestCookies::from_header("valid=value; malformed; another=good");
7902        assert_eq!(cookies.len(), 2);
7903        assert_eq!(cookies.get("valid"), Some("value"));
7904        assert_eq!(cookies.get("another"), Some("good"));
7905        assert!(!cookies.contains("malformed"));
7906    }
7907
7908    #[test]
7909    fn request_cookies_contains_check() {
7910        let cookies = RequestCookies::from_header("session=abc");
7911        assert!(cookies.contains("session"));
7912        assert!(!cookies.contains("missing"));
7913    }
7914
7915    #[test]
7916    fn request_cookies_iter() {
7917        let cookies = RequestCookies::from_header("a=1; b=2");
7918        let pairs: Vec<_> = cookies.iter().collect();
7919        assert_eq!(pairs.len(), 2);
7920        assert!(pairs.contains(&("a", "1")));
7921        assert!(pairs.contains(&("b", "2")));
7922    }
7923
7924    #[test]
7925    fn request_cookies_from_request() {
7926        let ctx = test_context();
7927        let mut req = Request::new(Method::Get, "/");
7928        req.headers_mut()
7929            .insert("cookie", b"session=xyz; user=alice".to_vec());
7930
7931        let result = futures_executor::block_on(RequestCookies::from_request(&ctx, &mut req));
7932        let cookies = result.unwrap();
7933        assert_eq!(cookies.get("session"), Some("xyz"));
7934        assert_eq!(cookies.get("user"), Some("alice"));
7935    }
7936
7937    #[test]
7938    fn request_cookies_from_request_no_cookie_header() {
7939        let ctx = test_context();
7940        let mut req = Request::new(Method::Get, "/");
7941
7942        let result = futures_executor::block_on(RequestCookies::from_request(&ctx, &mut req));
7943        let cookies = result.unwrap();
7944        assert!(cookies.is_empty());
7945    }
7946
7947    #[test]
7948    fn request_cookie_extractor_found() {
7949        #[derive(Debug)]
7950        struct TestCookie;
7951        impl CookieName for TestCookie {
7952            const NAME: &'static str = "test_cookie";
7953        }
7954
7955        let ctx = test_context();
7956        let mut req = Request::new(Method::Get, "/");
7957        req.headers_mut()
7958            .insert("cookie", b"test_cookie=hello_world".to_vec());
7959
7960        let result =
7961            futures_executor::block_on(RequestCookie::<TestCookie>::from_request(&ctx, &mut req));
7962        let cookie = result.unwrap();
7963        assert_eq!(cookie.value(), "hello_world");
7964    }
7965
7966    #[test]
7967    fn request_cookie_extractor_not_found() {
7968        #[derive(Debug)]
7969        struct MissingCookie;
7970        impl CookieName for MissingCookie {
7971            const NAME: &'static str = "missing";
7972        }
7973
7974        let ctx = test_context();
7975        let mut req = Request::new(Method::Get, "/");
7976        req.headers_mut().insert("cookie", b"other=value".to_vec());
7977
7978        let result = futures_executor::block_on(RequestCookie::<MissingCookie>::from_request(
7979            &ctx, &mut req,
7980        ));
7981        assert!(result.is_err());
7982        let err = result.unwrap_err();
7983        assert!(matches!(
7984            err,
7985            CookieExtractError::NotFound { name: "missing" }
7986        ));
7987    }
7988
7989    #[test]
7990    fn request_cookie_deref() {
7991        #[derive(Debug)]
7992        struct TestCookie;
7993        impl CookieName for TestCookie {
7994            const NAME: &'static str = "test";
7995        }
7996
7997        let cookie = RequestCookie::<TestCookie>::new("test_value");
7998        // Test Deref to str
7999        assert_eq!(&*cookie, "test_value");
8000        // Test AsRef
8001        assert_eq!(cookie.as_ref(), "test_value");
8002    }
8003
8004    #[test]
8005    fn session_id_cookie_marker() {
8006        let ctx = test_context();
8007        let mut req = Request::new(Method::Get, "/");
8008        req.headers_mut()
8009            .insert("cookie", b"session_id=sess123".to_vec());
8010
8011        let result = futures_executor::block_on(RequestCookie::<SessionIdCookie>::from_request(
8012            &ctx, &mut req,
8013        ));
8014        let cookie = result.unwrap();
8015        assert_eq!(cookie.value(), "sess123");
8016    }
8017
8018    #[test]
8019    fn csrf_token_cookie_marker() {
8020        let ctx = test_context();
8021        let mut req = Request::new(Method::Get, "/");
8022        req.headers_mut()
8023            .insert("cookie", b"csrf_token=csrf_abc".to_vec());
8024
8025        let result = futures_executor::block_on(RequestCookie::<CsrfTokenCookie>::from_request(
8026            &ctx, &mut req,
8027        ));
8028        let cookie = result.unwrap();
8029        assert_eq!(cookie.value(), "csrf_abc");
8030    }
8031}
8032
8033#[cfg(test)]
8034mod background_tasks_tests {
8035    use super::*;
8036    use crate::request::Method;
8037    use std::sync::Arc;
8038    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
8039
8040    fn test_context() -> RequestContext {
8041        let cx = asupersync::Cx::for_testing();
8042        RequestContext::new(cx, 12345)
8043    }
8044
8045    #[test]
8046    fn background_tasks_inner_new_is_empty() {
8047        let inner = BackgroundTasksInner::new();
8048        assert!(inner.take().is_empty());
8049    }
8050
8051    #[test]
8052    fn background_tasks_inner_push_and_take() {
8053        let inner = BackgroundTasksInner::new();
8054
8055        let executed = Arc::new(AtomicBool::new(false));
8056        let executed_clone = executed.clone();
8057
8058        // BackgroundTask is FnOnce() -> Pin<Box<Future>>
8059        inner.push(Box::new(move || {
8060            Box::pin(async move {
8061                executed_clone.store(true, Ordering::SeqCst);
8062            })
8063        }));
8064
8065        let tasks = inner.take();
8066        assert_eq!(tasks.len(), 1);
8067
8068        // Execute the task: call to get future, then block_on
8069        let task_fn = tasks.into_iter().next().unwrap();
8070        let future = task_fn();
8071        futures_executor::block_on(future);
8072        assert!(executed.load(Ordering::SeqCst));
8073    }
8074
8075    #[test]
8076    fn background_tasks_inner_take_empties_queue() {
8077        let inner = BackgroundTasksInner::new();
8078
8079        inner.push(Box::new(|| Box::pin(async {})));
8080        inner.push(Box::new(|| Box::pin(async {})));
8081
8082        let tasks = inner.take();
8083        assert_eq!(tasks.len(), 2);
8084
8085        // Taking again should be empty
8086        let tasks = inner.take();
8087        assert!(tasks.is_empty());
8088    }
8089
8090    #[test]
8091    fn background_tasks_add_async_task() {
8092        let mut tasks = BackgroundTasks::new();
8093
8094        let counter = Arc::new(AtomicU32::new(0));
8095        let counter_clone = counter.clone();
8096
8097        tasks.add_task(move || async move {
8098            counter_clone.fetch_add(1, Ordering::SeqCst);
8099        });
8100
8101        let queued = tasks.take_tasks();
8102        assert_eq!(queued.len(), 1);
8103
8104        // Execute the task
8105        let task_fn = queued.into_iter().next().unwrap();
8106        futures_executor::block_on(task_fn());
8107        assert_eq!(counter.load(Ordering::SeqCst), 1);
8108    }
8109
8110    #[test]
8111    fn background_tasks_add_sync_task() {
8112        let mut tasks = BackgroundTasks::new();
8113
8114        let counter = Arc::new(AtomicU32::new(0));
8115        let counter_clone = counter.clone();
8116
8117        tasks.add_sync_task(move || {
8118            counter_clone.fetch_add(10, Ordering::SeqCst);
8119        });
8120
8121        let queued = tasks.take_tasks();
8122        assert_eq!(queued.len(), 1);
8123
8124        let task_fn = queued.into_iter().next().unwrap();
8125        futures_executor::block_on(task_fn());
8126        assert_eq!(counter.load(Ordering::SeqCst), 10);
8127    }
8128
8129    #[test]
8130    fn background_tasks_multiple_tasks_execute_in_order() {
8131        let mut tasks = BackgroundTasks::new();
8132
8133        let order = Arc::new(parking_lot::Mutex::new(Vec::new()));
8134        let order1 = order.clone();
8135        let order2 = order.clone();
8136        let order3 = order.clone();
8137
8138        tasks.add_task(move || async move {
8139            order1.lock().push(1);
8140        });
8141        tasks.add_task(move || async move {
8142            order2.lock().push(2);
8143        });
8144        tasks.add_task(move || async move {
8145            order3.lock().push(3);
8146        });
8147
8148        let queued = tasks.take_tasks();
8149        assert_eq!(queued.len(), 3);
8150
8151        // Execute tasks in order
8152        for task_fn in queued {
8153            futures_executor::block_on(task_fn());
8154        }
8155
8156        assert_eq!(*order.lock(), vec![1, 2, 3]);
8157    }
8158
8159    #[test]
8160    fn background_tasks_from_request_creates_new() {
8161        let ctx = test_context();
8162        let mut req = Request::new(Method::Get, "/");
8163
8164        let tasks = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
8165            .expect("extraction should succeed");
8166
8167        // Should be empty initially
8168        let inner_tasks = tasks.into_inner().take();
8169        assert!(inner_tasks.is_empty());
8170    }
8171
8172    #[test]
8173    fn background_tasks_from_request_shares_inner() {
8174        let ctx = test_context();
8175        let mut req = Request::new(Method::Get, "/");
8176
8177        // First extraction creates inner
8178        let mut tasks1 = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
8179            .expect("extraction should succeed");
8180
8181        let counter = Arc::new(AtomicU32::new(0));
8182        let counter_clone = counter.clone();
8183        tasks1.add_task(move || async move {
8184            counter_clone.store(42, Ordering::SeqCst);
8185        });
8186
8187        // Second extraction should get the same inner
8188        let tasks2 = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
8189            .expect("extraction should succeed");
8190
8191        // Inner should have the task added by tasks1
8192        let queued = tasks2.into_inner().take();
8193        assert_eq!(queued.len(), 1);
8194
8195        // Execute and verify
8196        let task_fn = queued.into_iter().next().unwrap();
8197        futures_executor::block_on(task_fn());
8198        assert_eq!(counter.load(Ordering::SeqCst), 42);
8199    }
8200
8201    #[test]
8202    fn background_tasks_debug_shows_task_count() {
8203        let mut tasks = BackgroundTasks::new();
8204
8205        let debug_empty = format!("{:?}", tasks);
8206        assert!(debug_empty.contains("task_count"));
8207        assert!(debug_empty.contains("BackgroundTasks"));
8208
8209        tasks.add_task(|| async {});
8210        tasks.add_task(|| async {});
8211
8212        let debug_with_tasks = format!("{:?}", tasks);
8213        assert!(debug_with_tasks.contains("task_count"));
8214    }
8215
8216    #[test]
8217    fn background_tasks_inner_thread_safe() {
8218        let inner = BackgroundTasksInner::new();
8219        let inner_clone = inner.clone();
8220
8221        let counter = Arc::new(AtomicU32::new(0));
8222        let counter1 = counter.clone();
8223        let counter2 = counter.clone();
8224
8225        // Push from two "threads" (simulated via sequential calls)
8226        inner.push(Box::new(move || {
8227            Box::pin(async move {
8228                counter1.fetch_add(1, Ordering::SeqCst);
8229            })
8230        }));
8231        inner_clone.push(Box::new(move || {
8232            Box::pin(async move {
8233                counter2.fetch_add(10, Ordering::SeqCst);
8234            })
8235        }));
8236
8237        // Both tasks should be in the queue
8238        let tasks = inner.take();
8239        assert_eq!(tasks.len(), 2);
8240
8241        for task_fn in tasks {
8242            futures_executor::block_on(task_fn());
8243        }
8244
8245        assert_eq!(counter.load(Ordering::SeqCst), 11);
8246    }
8247
8248    #[test]
8249    fn background_tasks_into_inner_conversion() {
8250        let mut tasks = BackgroundTasks::new();
8251
8252        tasks.add_task(|| async {});
8253        tasks.add_task(|| async {});
8254
8255        let inner = tasks.into_inner();
8256        let queued = inner.take();
8257        assert_eq!(queued.len(), 2);
8258    }
8259
8260    #[test]
8261    fn background_tasks_is_empty_and_len() {
8262        let mut tasks = BackgroundTasks::new();
8263
8264        assert!(tasks.is_empty());
8265        assert_eq!(tasks.len(), 0);
8266
8267        tasks.add_task(|| async {});
8268        assert!(!tasks.is_empty());
8269        assert_eq!(tasks.len(), 1);
8270
8271        tasks.add_task(|| async {});
8272        assert_eq!(tasks.len(), 2);
8273    }
8274
8275    #[test]
8276    fn background_tasks_inner_len_and_is_empty() {
8277        let inner = BackgroundTasksInner::new();
8278
8279        assert!(inner.is_empty());
8280        assert_eq!(inner.len(), 0);
8281
8282        inner.push(Box::new(|| Box::pin(async {})));
8283        assert!(!inner.is_empty());
8284        assert_eq!(inner.len(), 1);
8285
8286        inner.push(Box::new(|| Box::pin(async {})));
8287        assert_eq!(inner.len(), 2);
8288
8289        // Take empties it
8290        let _ = inner.take();
8291        assert!(inner.is_empty());
8292        assert_eq!(inner.len(), 0);
8293    }
8294
8295    #[test]
8296    fn background_tasks_execute_all_runs_all_tasks() {
8297        let mut tasks = BackgroundTasks::new();
8298
8299        let counter = Arc::new(AtomicU32::new(0));
8300        let c1 = counter.clone();
8301        let c2 = counter.clone();
8302        let c3 = counter.clone();
8303
8304        tasks.add_task(move || async move {
8305            c1.fetch_add(1, Ordering::SeqCst);
8306        });
8307        tasks.add_task(move || async move {
8308            c2.fetch_add(10, Ordering::SeqCst);
8309        });
8310        tasks.add_task(move || async move {
8311            c3.fetch_add(100, Ordering::SeqCst);
8312        });
8313
8314        futures_executor::block_on(tasks.execute_all());
8315        assert_eq!(counter.load(Ordering::SeqCst), 111);
8316    }
8317
8318    #[test]
8319    fn background_tasks_execute_with_context_respects_cancellation() {
8320        let ctx = test_context();
8321        let mut tasks = BackgroundTasks::new();
8322
8323        let counter = Arc::new(AtomicU32::new(0));
8324        let c1 = counter.clone();
8325        let c2 = counter.clone();
8326
8327        tasks.add_task(move || async move {
8328            c1.fetch_add(1, Ordering::SeqCst);
8329        });
8330        tasks.add_task(move || async move {
8331            c2.fetch_add(10, Ordering::SeqCst);
8332        });
8333
8334        // Execute with non-cancelled context - all tasks should run
8335        futures_executor::block_on(tasks.execute_with_context(&ctx));
8336        assert_eq!(counter.load(Ordering::SeqCst), 11);
8337    }
8338
8339    #[test]
8340    fn background_tasks_execute_with_panic_isolation_handles_closure_panic() {
8341        let mut tasks = BackgroundTasks::new();
8342
8343        let counter = Arc::new(AtomicU32::new(0));
8344        let c1 = counter.clone();
8345        let c2 = counter.clone();
8346
8347        // First task succeeds
8348        tasks.add_task(move || async move {
8349            c1.fetch_add(1, Ordering::SeqCst);
8350        });
8351
8352        // Second task panics in closure (not in async block)
8353        tasks.inner.push(Box::new(|| {
8354            panic!("intentional test panic");
8355        }));
8356
8357        // Third task should still run after second panics
8358        tasks.add_task(move || async move {
8359            c2.fetch_add(100, Ordering::SeqCst);
8360        });
8361
8362        // Execute with panic isolation
8363        futures_executor::block_on(tasks.execute_with_panic_isolation());
8364
8365        // First task ran, second panicked, third ran
8366        // Note: Due to the way the test is structured, the third task should run
8367        assert_eq!(counter.load(Ordering::SeqCst), 101);
8368    }
8369
8370    #[test]
8371    fn format_panic_message_extracts_str() {
8372        let panic_info: Box<dyn std::any::Any + Send> = Box::new("test panic message");
8373        let msg = super::format_panic_message(&panic_info);
8374        assert_eq!(msg, "test panic message");
8375    }
8376
8377    #[test]
8378    fn format_panic_message_extracts_string() {
8379        let panic_info: Box<dyn std::any::Any + Send> = Box::new(String::from("string panic"));
8380        let msg = super::format_panic_message(&panic_info);
8381        assert_eq!(msg, "string panic");
8382    }
8383
8384    #[test]
8385    fn format_panic_message_handles_unknown() {
8386        let panic_info: Box<dyn std::any::Any + Send> = Box::new(42i32);
8387        let msg = super::format_panic_message(&panic_info);
8388        assert_eq!(msg, "unknown panic");
8389    }
8390
8391    // =========================================================================
8392    // Execution Timing Tests (bd-1ktc)
8393    // These tests verify background task execution timing behavior as it would
8394    // be used by the HTTP server: response is built first, then tasks execute.
8395    // =========================================================================
8396
8397    #[test]
8398    fn background_tasks_timing_single_task_after_response() {
8399        // Simulates: handler queues task, response returned, task executes after
8400        let ctx = test_context();
8401        let mut req = Request::new(Method::Get, "/");
8402
8403        // Phase 1: Handler adds task to the request (simulating FromRequest extraction)
8404        let counter = Arc::new(AtomicU32::new(0));
8405        let counter_clone = counter.clone();
8406
8407        let mut tasks = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
8408            .expect("extraction should succeed");
8409
8410        tasks.add_task(move || async move {
8411            counter_clone.store(42, Ordering::SeqCst);
8412        });
8413
8414        // Phase 2: Response is "returned" - counter should still be 0
8415        // (In real usage, response would be sent to client here)
8416        assert_eq!(
8417            counter.load(Ordering::SeqCst),
8418            0,
8419            "task should not run before take_tasks"
8420        );
8421
8422        // Phase 3: Server takes tasks after response is "sent"
8423        // Get the inner from the request extension (simulating App::take_background_tasks)
8424        let taken_tasks = req
8425            .get_extension::<BackgroundTasksInner>()
8426            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
8427            .expect("tasks should be in extension");
8428
8429        // Counter still 0 - tasks haven't executed yet
8430        assert_eq!(
8431            counter.load(Ordering::SeqCst),
8432            0,
8433            "task should not run before execute_all"
8434        );
8435
8436        // Phase 4: Execute tasks (server does this after sending response)
8437        futures_executor::block_on(taken_tasks.execute_all());
8438
8439        // Now the counter should be set
8440        assert_eq!(
8441            counter.load(Ordering::SeqCst),
8442            42,
8443            "task should have executed"
8444        );
8445    }
8446
8447    #[test]
8448    fn background_tasks_timing_multiple_tasks_in_order() {
8449        // Verify that multiple tasks execute in the order they were queued
8450        let ctx = test_context();
8451        let mut req = Request::new(Method::Get, "/");
8452
8453        let execution_order = Arc::new(parking_lot::Mutex::new(Vec::new()));
8454        let order1 = execution_order.clone();
8455        let order2 = execution_order.clone();
8456        let order3 = execution_order.clone();
8457
8458        let mut tasks = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
8459            .expect("extraction should succeed");
8460
8461        // Queue tasks in order: 1, 2, 3
8462        tasks.add_task(move || async move {
8463            order1.lock().push(1);
8464        });
8465        tasks.add_task(move || async move {
8466            order2.lock().push(2);
8467        });
8468        tasks.add_task(move || async move {
8469            order3.lock().push(3);
8470        });
8471
8472        // Simulate response being built/sent
8473        assert!(
8474            execution_order.lock().is_empty(),
8475            "no tasks should run during response building"
8476        );
8477
8478        // Take and execute tasks
8479        let taken_tasks = req
8480            .get_extension::<BackgroundTasksInner>()
8481            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
8482            .expect("tasks should be in extension");
8483
8484        futures_executor::block_on(taken_tasks.execute_all());
8485
8486        // Verify order: 1, 2, 3
8487        assert_eq!(
8488            *execution_order.lock(),
8489            vec![1, 2, 3],
8490            "tasks should execute in queue order"
8491        );
8492    }
8493
8494    #[test]
8495    fn background_tasks_timing_tasks_can_spawn_more_tasks() {
8496        // Test that a task can add more tasks to its own BackgroundTasks
8497        // Note: In real usage, the inner is shared, so adding during execution
8498        // would affect the same queue (but those tasks wouldn't run until next iteration)
8499
8500        let mut tasks = BackgroundTasks::new();
8501
8502        let counter = Arc::new(AtomicU32::new(0));
8503        let c1 = counter.clone();
8504        let c2 = counter.clone();
8505        let c3 = counter.clone();
8506
8507        // First batch of tasks
8508        tasks.add_task(move || async move {
8509            c1.fetch_add(1, Ordering::SeqCst);
8510        });
8511        tasks.add_task(move || async move {
8512            c2.fetch_add(10, Ordering::SeqCst);
8513        });
8514
8515        // Execute first batch
8516        futures_executor::block_on(tasks.execute_all());
8517        assert_eq!(
8518            counter.load(Ordering::SeqCst),
8519            11,
8520            "first batch should complete"
8521        );
8522
8523        // Create new tasks (simulating a task that spawns more work)
8524        let mut more_tasks = BackgroundTasks::new();
8525        more_tasks.add_task(move || async move {
8526            c3.fetch_add(100, Ordering::SeqCst);
8527        });
8528
8529        // Execute second batch
8530        futures_executor::block_on(more_tasks.execute_all());
8531        assert_eq!(
8532            counter.load(Ordering::SeqCst),
8533            111,
8534            "spawned tasks should also run"
8535        );
8536    }
8537
8538    #[test]
8539    fn background_tasks_timing_independent_requests() {
8540        // Verify that tasks from different requests are independent
8541        // (simulating concurrent request handling)
8542        let ctx1 = test_context();
8543        let ctx2 = test_context();
8544
8545        let mut req1 = Request::new(Method::Get, "/request1");
8546        let mut req2 = Request::new(Method::Get, "/request2");
8547
8548        let counter1 = Arc::new(AtomicU32::new(0));
8549        let counter2 = Arc::new(AtomicU32::new(0));
8550        let c1 = counter1.clone();
8551        let c2 = counter2.clone();
8552
8553        // Request 1 adds task
8554        let mut tasks1 =
8555            futures_executor::block_on(BackgroundTasks::from_request(&ctx1, &mut req1))
8556                .expect("extraction should succeed");
8557        tasks1.add_task(move || async move {
8558            c1.store(100, Ordering::SeqCst);
8559        });
8560
8561        // Request 2 adds task
8562        let mut tasks2 =
8563            futures_executor::block_on(BackgroundTasks::from_request(&ctx2, &mut req2))
8564                .expect("extraction should succeed");
8565        tasks2.add_task(move || async move {
8566            c2.store(200, Ordering::SeqCst);
8567        });
8568
8569        // Execute request 1's tasks
8570        let taken1 = req1
8571            .get_extension::<BackgroundTasksInner>()
8572            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
8573            .expect("tasks should be in extension");
8574        futures_executor::block_on(taken1.execute_all());
8575
8576        // Only counter1 should be affected
8577        assert_eq!(
8578            counter1.load(Ordering::SeqCst),
8579            100,
8580            "request 1 task should run"
8581        );
8582        assert_eq!(
8583            counter2.load(Ordering::SeqCst),
8584            0,
8585            "request 2 task should not run yet"
8586        );
8587
8588        // Execute request 2's tasks
8589        let taken2 = req2
8590            .get_extension::<BackgroundTasksInner>()
8591            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
8592            .expect("tasks should be in extension");
8593        futures_executor::block_on(taken2.execute_all());
8594
8595        // Now both should be set
8596        assert_eq!(
8597            counter1.load(Ordering::SeqCst),
8598            100,
8599            "request 1 task unchanged"
8600        );
8601        assert_eq!(
8602            counter2.load(Ordering::SeqCst),
8603            200,
8604            "request 2 task should run"
8605        );
8606    }
8607
8608    #[test]
8609    fn background_tasks_timing_nonblocking_next_request() {
8610        // Verify that one request's tasks don't block handling of next request
8611        // This is implicit in the design: tasks are separate from response handling
8612
8613        let ctx = test_context();
8614        let mut req1 = Request::new(Method::Get, "/first");
8615        let mut req2 = Request::new(Method::Get, "/second");
8616
8617        let req1_done = Arc::new(AtomicBool::new(false));
8618        let req2_done = Arc::new(AtomicBool::new(false));
8619        let r1 = req1_done.clone();
8620        let r2 = req2_done.clone();
8621
8622        // First request adds a task
8623        let mut tasks1 = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req1))
8624            .expect("extraction should succeed");
8625        tasks1.add_task(move || async move {
8626            // Simulate some work
8627            r1.store(true, Ordering::SeqCst);
8628        });
8629
8630        // Second request can be handled immediately (before first request's tasks run)
8631        let mut tasks2 = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req2))
8632            .expect("extraction should succeed");
8633        tasks2.add_task(move || async move {
8634            r2.store(true, Ordering::SeqCst);
8635        });
8636
8637        // At this point, neither request's tasks have run
8638        assert!(
8639            !req1_done.load(Ordering::SeqCst),
8640            "req1 tasks not yet executed"
8641        );
8642        assert!(
8643            !req2_done.load(Ordering::SeqCst),
8644            "req2 tasks not yet executed"
8645        );
8646
8647        // The second request's response could be "sent" before first request's tasks run
8648        // This demonstrates non-blocking behavior
8649
8650        // Now execute both sets of tasks
8651        let taken1 = req1
8652            .get_extension::<BackgroundTasksInner>()
8653            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
8654            .expect("tasks should be in extension");
8655        let taken2 = req2
8656            .get_extension::<BackgroundTasksInner>()
8657            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
8658            .expect("tasks should be in extension");
8659
8660        futures_executor::block_on(taken1.execute_all());
8661        futures_executor::block_on(taken2.execute_all());
8662
8663        assert!(
8664            req1_done.load(Ordering::SeqCst),
8665            "req1 tasks should be done"
8666        );
8667        assert!(
8668            req2_done.load(Ordering::SeqCst),
8669            "req2 tasks should be done"
8670        );
8671    }
8672}
8673
8674// ============================================================================
8675// Header Extractor
8676// ============================================================================
8677
/// Header extractor for individual HTTP headers.
///
/// Extracts a single header value by name from the request. The header name
/// is derived from the generic type's name, converting from snake_case to
/// Header-Case (e.g., `x_request_id` -> `X-Request-Id`).
///
/// For required headers, extraction failure returns 400 Bad Request.
/// Use `Option<Header<T>>` for optional headers.
///
/// The wrapper is a struct with named fields, so the value is read through
/// the `value` field, [`Header::into_inner`], or `Deref` — not through
/// tuple-style `.0` access.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::Header;
///
/// // Extract Authorization header (required)
/// async fn protected(auth: Header<String>) -> impl IntoResponse {
///     format!("Authorized with: {}", auth.value)
/// }
///
/// // Extract optional header
/// async fn optional_header(trace_id: Option<Header<String>>) -> impl IntoResponse {
///     match trace_id {
///         Some(Header { value: id, .. }) => format!("Trace: {id}"),
///         None => "No trace".into(),
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Header<T> {
    /// The extracted header value.
    pub value: T,
    /// The original header name used for extraction.
    pub name: String,
}

impl<T> Header<T> {
    /// Create a new Header wrapper.
    ///
    /// `name` is stored as given (converted into an owned `String`);
    /// `value` is the already-parsed header value.
    #[must_use]
    pub fn new(name: impl Into<String>, value: T) -> Self {
        Self {
            value,
            name: name.into(),
        }
    }

    /// Unwrap the inner value, discarding the header name.
    #[must_use]
    pub fn into_inner(self) -> T {
        self.value
    }
}

impl<T> Deref for Header<T> {
    type Target = T;

    /// Borrow the wrapped value.
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<T> DerefMut for Header<T> {
    /// Mutably borrow the wrapped value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.value
    }
}
8743
/// Turn a `snake_case` identifier into a `Header-Case` header name.
///
/// The input is split on `_`; the first character of every segment is
/// upper-cased, the remainder of the segment is copied unchanged, and the
/// segments are joined with `-`. Empty segments (from leading, trailing or
/// doubled underscores) stay empty, so they show up as adjacent dashes.
///
/// Examples:
/// - `x_request_id` -> `X-Request-Id`
/// - `content_type` -> `Content-Type`
/// - `authorization` -> `Authorization`
#[must_use]
pub fn snake_to_header_case(name: &str) -> String {
    let mut header = String::with_capacity(name.len());

    for (index, segment) in name.split('_').enumerate() {
        // A dash goes between segments, never before the first one.
        if index > 0 {
            header.push('-');
        }

        let mut rest = segment.chars();
        if let Some(initial) = rest.next() {
            // `to_uppercase` can yield more than one char for some inputs.
            header.extend(initial.to_uppercase());
            header.push_str(rest.as_str());
        }
    }

    header
}
8767
/// Reasons a header could not be extracted from a request.
#[derive(Debug)]
pub enum HeaderExtractError {
    /// The request does not carry the required header.
    MissingHeader {
        /// Name of the header that was looked up.
        name: String,
    },
    /// The header is present but its bytes are not valid UTF-8.
    InvalidUtf8 {
        /// Name of the offending header.
        name: String,
    },
    /// The header text could not be converted to the requested type.
    ParseError {
        /// Name of the offending header.
        name: String,
        /// The raw text that failed to parse.
        value: String,
        /// Human-readable name of the type that was expected.
        expected: &'static str,
        /// Message produced by the failed conversion.
        message: String,
    },
}

impl std::fmt::Display for HeaderExtractError {
    /// Render a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingHeader { name } => write!(f, "Missing required header: {name}"),
            Self::InvalidUtf8 { name } => write!(f, "Header '{name}' contains invalid UTF-8"),
            Self::ParseError {
                name,
                value,
                expected,
                message,
            } => write!(
                f,
                "Failed to parse header '{name}' value '{value}' as {expected}: {message}"
            ),
        }
    }
}

impl std::error::Error for HeaderExtractError {}
8819
8820impl IntoResponse for HeaderExtractError {
8821    fn into_response(self) -> crate::response::Response {
8822        // Missing or invalid headers are client errors (400)
8823        let error = match &self {
8824            HeaderExtractError::MissingHeader { name } => {
8825                ValidationError::missing(crate::error::loc::header(name))
8826                    .with_msg(format!("Missing required header: {name}"))
8827            }
8828            HeaderExtractError::InvalidUtf8 { name } => {
8829                ValidationError::type_error(crate::error::loc::header(name), "string")
8830                    .with_msg(format!("Header '{name}' contains invalid UTF-8"))
8831            }
8832            HeaderExtractError::ParseError {
8833                name,
8834                value,
8835                expected,
8836                message,
8837            } => ValidationError::type_error(crate::error::loc::header(name), expected)
8838                .with_msg(format!("Failed to parse as {expected}: {message}"))
8839                .with_input(serde_json::Value::String(value.clone())),
8840        };
8841        ValidationErrors::single(error).into_response()
8842    }
8843}
8844
/// Conversion from the textual value of a header into a typed value.
pub trait FromHeaderValue: Sized {
    /// Parse `value`, returning a human-readable message on failure.
    fn from_header_value(value: &str) -> Result<Self, String>;

    /// Name of the target type, used when reporting parse errors.
    fn type_name() -> &'static str;
}

impl FromHeaderValue for String {
    /// Always succeeds: the header text is copied as-is.
    fn from_header_value(value: &str) -> Result<Self, String> {
        Ok(String::from(value))
    }

    fn type_name() -> &'static str {
        "String"
    }
}

/// Implements [`FromHeaderValue`] for types parsed with `str::parse`,
/// reporting the parse error's `Display` text on failure and using the
/// type's own name as `type_name`.
macro_rules! impl_from_header_value_via_parse {
    ($($target:ty),+ $(,)?) => {
        $(
            impl FromHeaderValue for $target {
                fn from_header_value(value: &str) -> Result<Self, String> {
                    value.parse::<$target>().map_err(|err| err.to_string())
                }

                fn type_name() -> &'static str {
                    stringify!($target)
                }
            }
        )+
    };
}

impl_from_header_value_via_parse!(i32, i64, u32, u64);

impl FromHeaderValue for bool {
    /// Accepts `true`/`1`/`yes`/`on` and `false`/`0`/`no`/`off`, ignoring
    /// ASCII case; anything else is rejected.
    fn from_header_value(value: &str) -> Result<Self, String> {
        const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
        const FALSY: [&str; 4] = ["false", "0", "no", "off"];

        let is_one_of =
            |words: &[&str]| words.iter().any(|word| value.eq_ignore_ascii_case(word));

        if is_one_of(&TRUTHY) {
            Ok(true)
        } else if is_one_of(&FALSY) {
            Ok(false)
        } else {
            Err(format!("invalid boolean: {value}"))
        }
    }

    fn type_name() -> &'static str {
        "bool"
    }
}
8917
/// Named header extractor with explicit header name.
///
/// Use this when the header name doesn't match a type name. The name is
/// supplied at the type level through a marker `N` implementing
/// [`HeaderName`]; no value of `N` is ever stored.
///
/// `Debug` and `Clone` are implemented by hand so they only require
/// `T: Debug` / `T: Clone`. The marker type `N` needs no trait impls of its
/// own beyond [`HeaderName`].
///
/// # Example
///
/// ```ignore
/// use fastapi_core::extract::NamedHeader;
///
/// async fn handler(
///     auth: NamedHeader<String, AuthorizationHeader>,
///     trace: NamedHeader<String, XRequestIdHeader>,
/// ) -> impl IntoResponse {
///     // ...
/// }
///
/// struct AuthorizationHeader;
/// impl HeaderName for AuthorizationHeader {
///     const NAME: &'static str = "Authorization";
/// }
/// ```
pub struct NamedHeader<T, N> {
    /// The extracted header value.
    pub value: T,
    _marker: std::marker::PhantomData<N>,
}

// Manual impl: `#[derive(Clone)]` would also demand `N: Clone`, which plain
// marker structs do not satisfy.
impl<T: Clone, N> Clone for NamedHeader<T, N> {
    fn clone(&self) -> Self {
        Self {
            value: self.value.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

// Manual impl: `#[derive(Debug)]` would also demand `N: Debug`. The output
// has the same shape the derive produces (`value` then `_marker`).
impl<T: std::fmt::Debug, N> std::fmt::Debug for NamedHeader<T, N> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NamedHeader")
            .field("value", &self.value)
            .field("_marker", &self._marker)
            .finish()
    }
}

/// Trait for header name markers.
pub trait HeaderName {
    /// The HTTP header name.
    const NAME: &'static str;
}

impl<T, N> NamedHeader<T, N> {
    /// Create a new named header wrapper around an already-parsed value.
    #[must_use]
    pub fn new(value: T) -> Self {
        Self {
            value,
            _marker: std::marker::PhantomData,
        }
    }

    /// Unwrap the inner value.
    #[must_use]
    pub fn into_inner(self) -> T {
        self.value
    }
}

impl<T, N> Deref for NamedHeader<T, N> {
    type Target = T;

    /// Borrow the wrapped value.
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<T, N> DerefMut for NamedHeader<T, N> {
    /// Mutably borrow the wrapped value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.value
    }
}
8982
8983impl<T, N> FromRequest for NamedHeader<T, N>
8984where
8985    T: FromHeaderValue + Send + Sync + 'static,
8986    N: HeaderName + Send + Sync + 'static,
8987{
8988    type Error = HeaderExtractError;
8989
8990    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
8991        let header_name = N::NAME;
8992
8993        let value_bytes =
8994            req.headers()
8995                .get(header_name)
8996                .ok_or_else(|| HeaderExtractError::MissingHeader {
8997                    name: header_name.to_string(),
8998                })?;
8999
9000        let value_str =
9001            std::str::from_utf8(value_bytes).map_err(|_| HeaderExtractError::InvalidUtf8 {
9002                name: header_name.to_string(),
9003            })?;
9004
9005        let value =
9006            T::from_header_value(value_str).map_err(|message| HeaderExtractError::ParseError {
9007                name: header_name.to_string(),
9008                value: value_str.to_string(),
9009                expected: T::type_name(),
9010                message,
9011            })?;
9012
9013        Ok(NamedHeader::new(value))
9014    }
9015}
9016
9017// Common header name markers
9018/// Authorization header marker.
9019pub struct Authorization;
9020impl HeaderName for Authorization {
9021    const NAME: &'static str = "authorization";
9022}
9023
9024/// Content-Type header marker.
9025pub struct ContentType;
9026impl HeaderName for ContentType {
9027    const NAME: &'static str = "content-type";
9028}
9029
9030/// Accept header marker.
9031pub struct Accept;
9032impl HeaderName for Accept {
9033    const NAME: &'static str = "accept";
9034}
9035
9036/// X-Request-Id header marker.
9037pub struct XRequestId;
9038impl HeaderName for XRequestId {
9039    const NAME: &'static str = "x-request-id";
9040}
9041
9042/// User-Agent header marker.
9043pub struct UserAgent;
9044impl HeaderName for UserAgent {
9045    const NAME: &'static str = "user-agent";
9046}
9047
9048/// Host header marker.
9049pub struct Host;
9050impl HeaderName for Host {
9051    const NAME: &'static str = "host";
9052}
9053
9054// ============================================================================
9055// OAuth2 Security Extractors
9056// ============================================================================
9057
/// OAuth2 password bearer security scheme extractor.
///
/// Holds the bearer token taken from an `Authorization` header of the form
/// `Bearer <token>`, as used by the OAuth2 password bearer flow.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::OAuth2PasswordBearer;
///
/// async fn protected_route(token: OAuth2PasswordBearer) -> impl IntoResponse {
///     // Validate the token and get user
///     let user = validate_token(&token.token).await?;
///     format!("Hello, {}!", user.name)
/// }
/// ```
///
/// # Auto-Error Behavior
///
/// When `auto_error` is `true` (default), missing or invalid tokens result
/// in a 401 Unauthorized response with a `WWW-Authenticate: Bearer` header.
///
/// When `auto_error` is `false`, use `Option<OAuth2PasswordBearer>` to handle
/// missing tokens in your handler logic.
///
/// # OpenAPI
///
/// This extractor generates the following OpenAPI security scheme:
/// ```yaml
/// securitySchemes:
///   OAuth2PasswordBearer:
///     type: oauth2
///     flows:
///       password:
///         tokenUrl: "/token"
///         scopes: {}
/// ```
#[derive(Debug, Clone)]
pub struct OAuth2PasswordBearer {
    /// The bearer token, with the `Bearer ` prefix already removed.
    pub token: String,
}

impl OAuth2PasswordBearer {
    /// Wrap an already-extracted token.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        let token: String = token.into();
        Self { token }
    }

    /// Borrow the token text.
    #[must_use]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Give up the wrapper and return the owned token.
    #[must_use]
    pub fn into_token(self) -> String {
        let Self { token } = self;
        token
    }
}

impl Deref for OAuth2PasswordBearer {
    type Target = str;

    /// Dereferences to the token text, so `&str` methods work directly.
    fn deref(&self) -> &Self::Target {
        self.token.as_str()
    }
}
9131
/// Configuration for OAuth2PasswordBearer extraction.
///
/// A plain settings struct populated through [`Default`] plus the chained
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct OAuth2PasswordBearerConfig {
    /// URL to obtain the token. Required for OpenAPI documentation.
    pub token_url: String,
    /// URL to refresh the token. Optional.
    pub refresh_url: Option<String>,
    /// OAuth2 scopes, keyed by scope name, with their descriptions.
    pub scopes: std::collections::HashMap<String, String>,
    /// Custom scheme name for OpenAPI documentation.
    pub scheme_name: Option<String>,
    /// Description for OpenAPI documentation.
    pub description: Option<String>,
    /// Whether to automatically return 401 on missing/invalid token.
    /// Default: true.
    pub auto_error: bool,
}

impl Default for OAuth2PasswordBearerConfig {
    /// Token URL `/token`, no refresh URL, no scopes, no custom scheme name
    /// or description, and `auto_error` enabled.
    fn default() -> Self {
        Self {
            token_url: String::from("/token"),
            refresh_url: None,
            scopes: std::collections::HashMap::new(),
            scheme_name: None,
            description: None,
            auto_error: true,
        }
    }
}

impl OAuth2PasswordBearerConfig {
    /// Start from the defaults, replacing only the token URL.
    #[must_use]
    pub fn new(token_url: impl Into<String>) -> Self {
        let mut config = Self::default();
        config.token_url = token_url.into();
        config
    }

    /// Set the refresh URL.
    #[must_use]
    pub fn with_refresh_url(self, url: impl Into<String>) -> Self {
        Self {
            refresh_url: Some(url.into()),
            ..self
        }
    }

    /// Add an OAuth2 scope; adding an existing scope name replaces its
    /// description.
    #[must_use]
    pub fn with_scope(mut self, scope: impl Into<String>, description: impl Into<String>) -> Self {
        let (key, text) = (scope.into(), description.into());
        self.scopes.insert(key, text);
        self
    }

    /// Set the scheme name for OpenAPI.
    #[must_use]
    pub fn with_scheme_name(self, name: impl Into<String>) -> Self {
        Self {
            scheme_name: Some(name.into()),
            ..self
        }
    }

    /// Set the description for OpenAPI.
    #[must_use]
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Set whether to auto-error on missing/invalid tokens.
    #[must_use]
    pub fn with_auto_error(self, auto_error: bool) -> Self {
        Self { auto_error, ..self }
    }
}
9210
/// Error when OAuth2 bearer token extraction fails.
///
/// Implements [`std::error::Error`] so it can be boxed, chained or passed
/// through `?` like the other extractor errors in this module.
#[derive(Debug, Clone)]
pub struct OAuth2BearerError {
    /// The kind of error that occurred.
    pub kind: OAuth2BearerErrorKind,
}

/// The specific kind of OAuth2 bearer error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OAuth2BearerErrorKind {
    /// Authorization header is missing.
    MissingHeader,
    /// Authorization header doesn't start with "Bearer ".
    InvalidScheme,
    /// Token is empty after "Bearer " prefix.
    EmptyToken,
}

impl OAuth2BearerError {
    /// Create a new missing header error.
    #[must_use]
    pub fn missing_header() -> Self {
        Self {
            kind: OAuth2BearerErrorKind::MissingHeader,
        }
    }

    /// Create a new invalid scheme error.
    #[must_use]
    pub fn invalid_scheme() -> Self {
        Self {
            kind: OAuth2BearerErrorKind::InvalidScheme,
        }
    }

    /// Create a new empty token error.
    #[must_use]
    pub fn empty_token() -> Self {
        Self {
            kind: OAuth2BearerErrorKind::EmptyToken,
        }
    }
}

impl fmt::Display for OAuth2BearerError {
    /// Write a short, fixed description for the error kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self.kind {
            OAuth2BearerErrorKind::MissingHeader => "Missing Authorization header",
            OAuth2BearerErrorKind::InvalidScheme => "Authorization header must use Bearer scheme",
            OAuth2BearerErrorKind::EmptyToken => "Bearer token is empty",
        };
        f.write_str(description)
    }
}

// Matches the sibling extractor errors, which also implement `Error`.
impl std::error::Error for OAuth2BearerError {}
9270
9271impl IntoResponse for OAuth2BearerError {
9272    fn into_response(self) -> crate::response::Response {
9273        use crate::response::{Response, ResponseBody, StatusCode};
9274
9275        let message = match self.kind {
9276            OAuth2BearerErrorKind::MissingHeader => "Not authenticated",
9277            OAuth2BearerErrorKind::InvalidScheme => "Invalid authentication credentials",
9278            OAuth2BearerErrorKind::EmptyToken => "Invalid authentication credentials",
9279        };
9280
9281        let body = serde_json::json!({
9282            "detail": message
9283        });
9284
9285        Response::with_status(StatusCode::UNAUTHORIZED)
9286            .header("www-authenticate", b"Bearer".to_vec())
9287            .header("content-type", b"application/json".to_vec())
9288            .body(ResponseBody::Bytes(body.to_string().into_bytes()))
9289    }
9290}
9291
9292impl FromRequest for OAuth2PasswordBearer {
9293    type Error = OAuth2BearerError;
9294
9295    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
9296        // Get the Authorization header
9297        let auth_header = req
9298            .headers()
9299            .get("authorization")
9300            .ok_or_else(OAuth2BearerError::missing_header)?;
9301
9302        // Convert to string
9303        let auth_str =
9304            std::str::from_utf8(auth_header).map_err(|_| OAuth2BearerError::invalid_scheme())?;
9305
9306        // Check for "Bearer " prefix (case-insensitive)
9307        const BEARER_PREFIX: &str = "Bearer ";
9308        const BEARER_PREFIX_LOWER: &str = "bearer ";
9309
9310        let token = if auth_str.starts_with(BEARER_PREFIX) {
9311            &auth_str[BEARER_PREFIX.len()..]
9312        } else if auth_str.starts_with(BEARER_PREFIX_LOWER) {
9313            &auth_str[BEARER_PREFIX_LOWER.len()..]
9314        } else {
9315            return Err(OAuth2BearerError::invalid_scheme());
9316        };
9317
9318        // Check token isn't empty
9319        let token = token.trim();
9320        if token.is_empty() {
9321            return Err(OAuth2BearerError::empty_token());
9322        }
9323
9324        Ok(OAuth2PasswordBearer::new(token))
9325    }
9326}
9327
9328// ============================================================================
9329// OAuth2 Password Request Form
9330// ============================================================================
9331
9332/// Error when OAuth2 password request form extraction fails.
9333#[derive(Debug, Clone, PartialEq, Eq)]
9334pub enum OAuth2PasswordFormError {
9335    /// Content-Type is not `application/x-www-form-urlencoded`.
9336    UnsupportedMediaType {
9337        /// The actual Content-Type, if any.
9338        actual: Option<String>,
9339    },
9340    /// Request body exceeds the configured limit.
9341    PayloadTooLarge {
9342        /// The actual body size.
9343        size: usize,
9344        /// The configured limit.
9345        limit: usize,
9346    },
9347    /// The `username` field is missing.
9348    MissingUsername,
9349    /// The `password` field is missing.
9350    MissingPassword,
9351    /// The `grant_type` must be `"password"` (strict mode).
9352    InvalidGrantType {
9353        /// The actual grant_type value.
9354        actual: String,
9355    },
9356    /// The `grant_type` field is missing (strict mode).
9357    MissingGrantType,
9358    /// Request body is not valid UTF-8.
9359    InvalidUtf8,
9360    /// Streaming bodies are not supported.
9361    StreamingNotSupported,
9362}
9363
9364impl std::fmt::Display for OAuth2PasswordFormError {
9365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9366        match self {
9367            Self::UnsupportedMediaType { actual } => {
9368                if let Some(ct) = actual {
9369                    write!(f, "Expected application/x-www-form-urlencoded, got: {ct}")
9370                } else {
9371                    write!(f, "Missing Content-Type header")
9372                }
9373            }
9374            Self::PayloadTooLarge { size, limit } => {
9375                write!(f, "Body too large: {size} > {limit}")
9376            }
9377            Self::MissingUsername => write!(f, "Missing required field: username"),
9378            Self::MissingPassword => write!(f, "Missing required field: password"),
9379            Self::InvalidGrantType { actual } => {
9380                write!(f, "grant_type must be \"password\", got: \"{actual}\"")
9381            }
9382            Self::MissingGrantType => write!(f, "Missing required field: grant_type"),
9383            Self::InvalidUtf8 => write!(f, "Invalid UTF-8 in form body"),
9384            Self::StreamingNotSupported => write!(f, "Streaming bodies not supported"),
9385        }
9386    }
9387}
9388
// Marker impl: the empty body keeps every default method of `std::error::Error`,
// so the error text comes solely from the `Display` impl.
impl std::error::Error for OAuth2PasswordFormError {}
9390
9391impl IntoResponse for OAuth2PasswordFormError {
9392    fn into_response(self) -> Response {
9393        match &self {
9394            OAuth2PasswordFormError::UnsupportedMediaType { .. } => {
9395                HttpError::unsupported_media_type().into_response()
9396            }
9397            OAuth2PasswordFormError::PayloadTooLarge { size, limit } => {
9398                HttpError::payload_too_large()
9399                    .with_detail(format!("Body {size} > {limit}"))
9400                    .into_response()
9401            }
9402            OAuth2PasswordFormError::MissingUsername => ValidationErrors::single(
9403                ValidationError::new(
9404                    crate::error::error_types::MISSING,
9405                    vec![
9406                        crate::error::LocItem::field("body"),
9407                        crate::error::LocItem::field("username"),
9408                    ],
9409                )
9410                .with_msg("Field required".to_string()),
9411            )
9412            .into_response(),
9413            OAuth2PasswordFormError::MissingPassword => ValidationErrors::single(
9414                ValidationError::new(
9415                    crate::error::error_types::MISSING,
9416                    vec![
9417                        crate::error::LocItem::field("body"),
9418                        crate::error::LocItem::field("password"),
9419                    ],
9420                )
9421                .with_msg("Field required".to_string()),
9422            )
9423            .into_response(),
9424            OAuth2PasswordFormError::InvalidGrantType { actual } => ValidationErrors::single(
9425                ValidationError::new(
9426                    crate::error::error_types::VALUE_ERROR,
9427                    vec![
9428                        crate::error::LocItem::field("body"),
9429                        crate::error::LocItem::field("grant_type"),
9430                    ],
9431                )
9432                .with_msg(format!(
9433                    "grant_type must be \"password\", got: \"{actual}\""
9434                )),
9435            )
9436            .into_response(),
9437            OAuth2PasswordFormError::MissingGrantType => ValidationErrors::single(
9438                ValidationError::new(
9439                    crate::error::error_types::MISSING,
9440                    vec![
9441                        crate::error::LocItem::field("body"),
9442                        crate::error::LocItem::field("grant_type"),
9443                    ],
9444                )
9445                .with_msg("Field required".to_string()),
9446            )
9447            .into_response(),
9448            OAuth2PasswordFormError::InvalidUtf8 => HttpError::bad_request()
9449                .with_detail("Invalid UTF-8")
9450                .into_response(),
9451            OAuth2PasswordFormError::StreamingNotSupported => {
9452                HttpError::bad_request().into_response()
9453            }
9454        }
9455    }
9456}
9457
/// OAuth2 password request form data.
///
/// Extracts the standard OAuth2 password grant fields from a
/// `application/x-www-form-urlencoded` request body, matching FastAPI's
/// `OAuth2PasswordRequestForm`.
///
/// # Fields
///
/// - `grant_type`: Optional, should be `"password"` per spec (not enforced here; use
///   [`OAuth2PasswordRequestFormStrict`] for strict validation).
/// - `username`: Required.
/// - `password`: Required.
/// - `scope`: Space-separated scopes (empty string if not provided).
/// - `client_id`: Optional.
/// - `client_secret`: Optional.
///
/// # Computed
///
/// - [`scopes()`](OAuth2PasswordRequestForm::scopes): Returns the `scope` string split
///   into a `Vec<String>`, skipping empty entries.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::OAuth2PasswordRequestForm;
///
/// async fn login(form: OAuth2PasswordRequestForm) -> Response {
///     let username = &form.username;
///     let password = &form.password;
///     let scopes = form.scopes();
///     // ... authenticate user ...
/// }
/// ```
#[derive(Debug, Clone)]
pub struct OAuth2PasswordRequestForm {
    /// The grant type. Should be `"password"` per the OAuth2 spec.
    pub grant_type: Option<String>,
    /// The username (required).
    pub username: String,
    /// The password (required).
    pub password: String,
    /// Space-separated scopes. Defaults to empty string.
    pub scope: String,
    /// Optional client ID.
    pub client_id: Option<String>,
    /// Optional client secret.
    pub client_secret: Option<String>,
}

impl OAuth2PasswordRequestForm {
    /// Parse the `scope` field into a vector of individual scope strings.
    ///
    /// The string is split on whitespace and empty segments are dropped, so
    /// leading, trailing, or repeated separators never produce `""` entries
    /// (this mirrors FastAPI, which uses Python's `scope.split()`).
    ///
    /// # Returns
    ///
    /// One owned `String` per scope, in the order they appear; an empty vector
    /// when `scope` is empty or contains only whitespace.
    #[must_use]
    pub fn scopes(&self) -> Vec<String> {
        self.scope.split_whitespace().map(String::from).collect()
    }
}
9517
9518/// Parse form body into key-value pairs and extract a field by name.
9519fn extract_form_body(
9520    ctx: &RequestContext,
9521    req: &mut Request,
9522) -> Result<QueryParams, OAuth2PasswordFormError> {
9523    // Validate Content-Type
9524    let ct = req
9525        .headers()
9526        .get("content-type")
9527        .and_then(|v| std::str::from_utf8(v).ok());
9528    let is_form = ct.is_some_and(|c| {
9529        c.to_ascii_lowercase()
9530            .starts_with("application/x-www-form-urlencoded")
9531    });
9532    if !is_form {
9533        return Err(OAuth2PasswordFormError::UnsupportedMediaType {
9534            actual: ct.map(String::from),
9535        });
9536    }
9537
9538    // Extract body
9539    let body = req.take_body();
9540    let bytes = match body {
9541        Body::Empty => Vec::new(),
9542        Body::Bytes(b) => b,
9543        Body::Stream(_) => return Err(OAuth2PasswordFormError::StreamingNotSupported),
9544    };
9545
9546    // Check size limit
9547    let limit = ctx.max_body_size();
9548    if bytes.len() > limit {
9549        return Err(OAuth2PasswordFormError::PayloadTooLarge {
9550            size: bytes.len(),
9551            limit,
9552        });
9553    }
9554
9555    // Parse form body
9556    let body_str = std::str::from_utf8(&bytes).map_err(|_| OAuth2PasswordFormError::InvalidUtf8)?;
9557    Ok(QueryParams::parse(body_str))
9558}
9559
9560impl FromRequest for OAuth2PasswordRequestForm {
9561    type Error = OAuth2PasswordFormError;
9562
9563    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
9564        let _ = ctx.checkpoint();
9565        let params = extract_form_body(ctx, req)?;
9566        let _ = ctx.checkpoint();
9567
9568        let username = params
9569            .get("username")
9570            .ok_or(OAuth2PasswordFormError::MissingUsername)?
9571            .to_string();
9572
9573        let password = params
9574            .get("password")
9575            .ok_or(OAuth2PasswordFormError::MissingPassword)?
9576            .to_string();
9577
9578        let grant_type = params.get("grant_type").map(String::from);
9579        let scope = params.get("scope").map(String::from).unwrap_or_default();
9580        let client_id = params.get("client_id").map(String::from);
9581        let client_secret = params.get("client_secret").map(String::from);
9582
9583        let _ = ctx.checkpoint();
9584        Ok(OAuth2PasswordRequestForm {
9585            grant_type,
9586            username,
9587            password,
9588            scope,
9589            client_id,
9590            client_secret,
9591        })
9592    }
9593}
9594
/// Strict variant of [`OAuth2PasswordRequestForm`].
///
/// Same as [`OAuth2PasswordRequestForm`], but requires `grant_type` to be
/// present and equal to `"password"`. Extraction returns an error if
/// `grant_type` is missing or has any other value.
///
/// The wrapper dereferences to the inner form, so its fields and methods can be
/// used directly (`form.username`) as well as through the `form` field.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::OAuth2PasswordRequestFormStrict;
///
/// async fn login(form: OAuth2PasswordRequestFormStrict) -> Response {
///     // grant_type is guaranteed to be "password"
///     let username = &form.form.username;
///     let scopes = form.form.scopes();
///     // ...
/// }
/// ```
#[derive(Debug, Clone)]
pub struct OAuth2PasswordRequestFormStrict {
    /// The validated form data. When produced by the extractor, `grant_type` is
    /// always `Some("password")`.
    ///
    /// NOTE(review): the field is public, so a value built or mutated by hand is
    /// not covered by that guarantee.
    pub form: OAuth2PasswordRequestForm,
}
9618
9619impl OAuth2PasswordRequestFormStrict {
9620    /// Get a reference to the inner form.
9621    #[must_use]
9622    pub fn inner(&self) -> &OAuth2PasswordRequestForm {
9623        &self.form
9624    }
9625
9626    /// Consume self and return the inner form.
9627    #[must_use]
9628    pub fn into_inner(self) -> OAuth2PasswordRequestForm {
9629        self.form
9630    }
9631}
9632
9633impl std::ops::Deref for OAuth2PasswordRequestFormStrict {
9634    type Target = OAuth2PasswordRequestForm;
9635
9636    fn deref(&self) -> &Self::Target {
9637        &self.form
9638    }
9639}
9640
9641impl FromRequest for OAuth2PasswordRequestFormStrict {
9642    type Error = OAuth2PasswordFormError;
9643
9644    async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
9645        let _ = ctx.checkpoint();
9646        let params = extract_form_body(ctx, req)?;
9647        let _ = ctx.checkpoint();
9648
9649        let grant_type_value = params
9650            .get("grant_type")
9651            .ok_or(OAuth2PasswordFormError::MissingGrantType)?;
9652
9653        if grant_type_value != "password" {
9654            return Err(OAuth2PasswordFormError::InvalidGrantType {
9655                actual: grant_type_value.to_string(),
9656            });
9657        }
9658
9659        let username = params
9660            .get("username")
9661            .ok_or(OAuth2PasswordFormError::MissingUsername)?
9662            .to_string();
9663
9664        let password = params
9665            .get("password")
9666            .ok_or(OAuth2PasswordFormError::MissingPassword)?
9667            .to_string();
9668
9669        let scope = params.get("scope").map(String::from).unwrap_or_default();
9670        let client_id = params.get("client_id").map(String::from);
9671        let client_secret = params.get("client_secret").map(String::from);
9672
9673        let _ = ctx.checkpoint();
9674        Ok(OAuth2PasswordRequestFormStrict {
9675            form: OAuth2PasswordRequestForm {
9676                grant_type: Some("password".to_string()),
9677                username,
9678                password,
9679                scope,
9680                client_id,
9681                client_secret,
9682            },
9683        })
9684    }
9685}
9686
#[cfg(test)]
mod oauth2_password_form_tests {
    use super::*;
    use crate::request::Method;

    type FormResult = Result<OAuth2PasswordRequestForm, OAuth2PasswordFormError>;
    type StrictResult = Result<OAuth2PasswordRequestFormStrict, OAuth2PasswordFormError>;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// POST `/token` request carrying `body` as a URL-encoded form.
    fn form_request(body: &str) -> Request {
        let mut req = Request::new(Method::Post, "/token");
        req.headers_mut().insert(
            "content-type",
            b"application/x-www-form-urlencoded".to_vec(),
        );
        req.set_body(Body::Bytes(body.as_bytes().to_vec()));
        req
    }

    /// Run the lenient extractor against an arbitrary request.
    fn extract_from(mut req: Request) -> FormResult {
        let ctx = test_context();
        futures_executor::block_on(OAuth2PasswordRequestForm::from_request(&ctx, &mut req))
    }

    /// Run the lenient extractor against a well-formed form request with `body`.
    fn extract(body: &str) -> FormResult {
        extract_from(form_request(body))
    }

    /// Run the strict extractor against a well-formed form request with `body`.
    fn extract_strict(body: &str) -> StrictResult {
        let ctx = test_context();
        let mut req = form_request(body);
        futures_executor::block_on(OAuth2PasswordRequestFormStrict::from_request(
            &ctx, &mut req,
        ))
    }

    // ---- OAuth2PasswordRequestForm tests ----

    #[test]
    fn basic_form_extraction() {
        let form = extract("username=alice&password=secret123").unwrap();

        assert_eq!(form.username, "alice");
        assert_eq!(form.password, "secret123");
        assert!(form.grant_type.is_none());
        assert_eq!(form.scope, "");
        assert!(form.client_id.is_none());
        assert!(form.client_secret.is_none());
    }

    #[test]
    fn full_form_extraction() {
        let form = extract(
            "grant_type=password&username=bob&password=s3cret&scope=read+write&client_id=myapp&client_secret=appsecret",
        )
        .unwrap();

        assert_eq!(form.grant_type.as_deref(), Some("password"));
        assert_eq!(form.username, "bob");
        assert_eq!(form.password, "s3cret");
        assert_eq!(form.scope, "read write");
        assert_eq!(form.client_id.as_deref(), Some("myapp"));
        assert_eq!(form.client_secret.as_deref(), Some("appsecret"));
    }

    #[test]
    fn scopes_parsing() {
        let form = extract("username=u&password=p&scope=read+write+admin").unwrap();

        assert_eq!(form.scopes(), vec!["read", "write", "admin"]);
    }

    #[test]
    fn empty_scope_returns_empty_vec() {
        let form = extract("username=u&password=p").unwrap();

        assert!(form.scopes().is_empty());
    }

    #[test]
    fn missing_username_error() {
        let err = extract("password=secret").unwrap_err();

        assert_eq!(err, OAuth2PasswordFormError::MissingUsername);
        assert!(err.to_string().contains("username"));
    }

    #[test]
    fn missing_password_error() {
        let err = extract("username=alice").unwrap_err();

        assert_eq!(err, OAuth2PasswordFormError::MissingPassword);
        assert!(err.to_string().contains("password"));
    }

    #[test]
    fn wrong_content_type_error() {
        let mut req = Request::new(Method::Post, "/token");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(b"username=a&password=b".to_vec()));

        let err = extract_from(req).unwrap_err();

        let OAuth2PasswordFormError::UnsupportedMediaType { actual } = &err else {
            panic!("Expected UnsupportedMediaType, got: {err:?}");
        };
        assert_eq!(actual.as_deref(), Some("application/json"));
    }

    #[test]
    fn missing_content_type_error() {
        let mut req = Request::new(Method::Post, "/token");
        req.set_body(Body::Bytes(b"username=a&password=b".to_vec()));

        let err = extract_from(req).unwrap_err();

        let OAuth2PasswordFormError::UnsupportedMediaType { actual } = &err else {
            panic!("Expected UnsupportedMediaType, got: {err:?}");
        };
        assert!(actual.is_none());
    }

    #[test]
    fn url_encoded_values() {
        let form = extract("username=user%40example.com&password=p%26ss%3Dword").unwrap();

        assert_eq!(form.username, "user@example.com");
        assert_eq!(form.password, "p&ss=word");
    }

    #[test]
    fn plus_decoded_as_space_in_scope() {
        let form = extract("username=u&password=p&scope=read+write+admin").unwrap();

        // '+' is decoded as space in application/x-www-form-urlencoded
        assert_eq!(form.scope, "read write admin");
    }

    #[test]
    fn empty_body_returns_missing_username() {
        let err = extract("").unwrap_err();
        assert_eq!(err, OAuth2PasswordFormError::MissingUsername);
    }

    #[test]
    fn streaming_not_supported_error_type() {
        // Verify the streaming error variant exists and has correct display
        let err = OAuth2PasswordFormError::StreamingNotSupported;
        assert!(err.to_string().contains("Streaming"));

        // Verify it converts to a 400 Bad Request response
        assert_eq!(err.into_response().status().as_u16(), 400);
    }

    #[test]
    fn error_display_messages() {
        // (error, substring its Display output must contain)
        let cases = [
            (OAuth2PasswordFormError::MissingUsername, "username"),
            (OAuth2PasswordFormError::MissingPassword, "password"),
            (OAuth2PasswordFormError::InvalidUtf8, "UTF-8"),
            (OAuth2PasswordFormError::StreamingNotSupported, "Streaming"),
            (
                OAuth2PasswordFormError::InvalidGrantType {
                    actual: "code".to_string(),
                },
                "code",
            ),
            (
                OAuth2PasswordFormError::PayloadTooLarge {
                    size: 100,
                    limit: 50,
                },
                "100",
            ),
        ];

        for (err, needle) in cases {
            assert!(err.to_string().contains(needle));
        }
    }

    #[test]
    fn error_into_response_status_codes() {
        // (error, expected HTTP status)
        let cases = [
            (OAuth2PasswordFormError::MissingUsername, 422),
            (OAuth2PasswordFormError::MissingPassword, 422),
            (
                OAuth2PasswordFormError::UnsupportedMediaType { actual: None },
                415,
            ),
            (
                OAuth2PasswordFormError::PayloadTooLarge {
                    size: 100,
                    limit: 50,
                },
                413,
            ),
            (
                OAuth2PasswordFormError::InvalidGrantType {
                    actual: "code".to_string(),
                },
                422,
            ),
        ];

        for (err, status) in cases {
            assert_eq!(err.into_response().status().as_u16(), status);
        }
    }

    // ---- OAuth2PasswordRequestFormStrict tests ----

    #[test]
    fn strict_accepts_password_grant_type() {
        let form = extract_strict("grant_type=password&username=alice&password=secret").unwrap();

        assert_eq!(form.form.grant_type.as_deref(), Some("password"));
        assert_eq!(form.username, "alice");
        assert_eq!(form.password, "secret");
    }

    #[test]
    fn strict_rejects_missing_grant_type() {
        let err = extract_strict("username=alice&password=secret").unwrap_err();

        assert_eq!(err, OAuth2PasswordFormError::MissingGrantType);
    }

    #[test]
    fn strict_rejects_wrong_grant_type() {
        let err = extract_strict("grant_type=authorization_code&username=alice&password=secret")
            .unwrap_err();

        let OAuth2PasswordFormError::InvalidGrantType { actual } = &err else {
            panic!("Expected InvalidGrantType, got: {err:?}");
        };
        assert_eq!(actual, "authorization_code");
    }

    #[test]
    fn strict_with_all_fields() {
        let form = extract_strict(
            "grant_type=password&username=bob&password=pw&scope=read+write&client_id=app&client_secret=sec",
        )
        .unwrap();

        assert_eq!(form.username, "bob");
        assert_eq!(form.password, "pw");
        assert_eq!(form.scope, "read write");
        assert_eq!(form.client_id.as_deref(), Some("app"));
        assert_eq!(form.client_secret.as_deref(), Some("sec"));
        assert_eq!(form.scopes(), vec!["read", "write"]);
    }

    #[test]
    fn strict_deref_to_inner() {
        let strict = extract_strict("grant_type=password&username=alice&password=pw").unwrap();

        // Deref allows accessing inner form fields directly
        let _: &str = &strict.username;
        let _: &str = &strict.password;
        assert_eq!(strict.inner().username, "alice");
    }

    #[test]
    fn strict_into_inner() {
        let strict = extract_strict("grant_type=password&username=alice&password=pw").unwrap();

        let form = strict.into_inner();
        assert_eq!(form.username, "alice");
        assert_eq!(form.grant_type.as_deref(), Some("password"));
    }

    #[test]
    fn strict_missing_username_after_grant_type() {
        let err = extract_strict("grant_type=password&password=secret").unwrap_err();
        assert_eq!(err, OAuth2PasswordFormError::MissingUsername);
    }

    #[test]
    fn strict_missing_password_after_grant_type() {
        let err = extract_strict("grant_type=password&username=alice").unwrap_err();
        assert_eq!(err, OAuth2PasswordFormError::MissingPassword);
    }
}
10046
10047// ============================================================================
10048// OAuth2 Authorization Code Bearer Extractor
10049// ============================================================================
10050
/// Settings describing an OAuth2 authorization-code flow.
///
/// These values feed the OpenAPI documentation. Token extraction itself works the
/// same way as for [`OAuth2PasswordBearer`]; only the generated security scheme
/// differs.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::OAuth2AuthorizationCodeBearerConfig;
///
/// let config = OAuth2AuthorizationCodeBearerConfig::new(
///     "https://auth.example.com/authorize",
///     "https://auth.example.com/token",
/// )
/// .with_refresh_url("https://auth.example.com/refresh")
/// .with_scope("read", "Read access")
/// .with_scope("write", "Write access");
/// ```
#[derive(Debug, Clone)]
pub struct OAuth2AuthorizationCodeBearerConfig {
    /// Authorization endpoint URL (required for OpenAPI documentation).
    pub authorization_url: String,
    /// Token endpoint URL (required for OpenAPI documentation).
    pub token_url: String,
    /// Token refresh URL, if any.
    pub refresh_url: Option<String>,
    /// Scope name mapped to its human-readable description.
    pub scopes: std::collections::HashMap<String, String>,
    /// Custom scheme name for OpenAPI documentation.
    pub scheme_name: Option<String>,
    /// Description for OpenAPI documentation.
    pub description: Option<String>,
    /// Whether to automatically return 401 on missing/invalid token.
    /// Default: true.
    pub auto_error: bool,
}

impl OAuth2AuthorizationCodeBearerConfig {
    /// Start a configuration from the two mandatory endpoint URLs.
    ///
    /// Optional settings start unset, the scope map starts empty, and
    /// `auto_error` starts as `true`.
    #[must_use]
    pub fn new(authorization_url: impl Into<String>, token_url: impl Into<String>) -> Self {
        let authorization_url = authorization_url.into();
        let token_url = token_url.into();
        Self {
            authorization_url,
            token_url,
            refresh_url: None,
            scopes: std::collections::HashMap::new(),
            scheme_name: None,
            description: None,
            auto_error: true,
        }
    }

    /// Set the refresh URL.
    #[must_use]
    pub fn with_refresh_url(self, url: impl Into<String>) -> Self {
        Self {
            refresh_url: Some(url.into()),
            ..self
        }
    }

    /// Register an OAuth2 scope; re-using a scope name replaces its description.
    #[must_use]
    pub fn with_scope(self, scope: impl Into<String>, description: impl Into<String>) -> Self {
        let mut scopes = self.scopes;
        scopes.insert(scope.into(), description.into());
        Self { scopes, ..self }
    }

    /// Set the scheme name for OpenAPI.
    #[must_use]
    pub fn with_scheme_name(self, name: impl Into<String>) -> Self {
        Self {
            scheme_name: Some(name.into()),
            ..self
        }
    }

    /// Set the description for OpenAPI.
    #[must_use]
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Set whether to auto-error on missing/invalid tokens.
    #[must_use]
    pub fn with_auto_error(self, auto_error: bool) -> Self {
        Self { auto_error, ..self }
    }
}
10139
/// Bearer-token extractor for the OAuth2 authorization-code flow.
///
/// The token is read from the `Authorization` header in the same way as
/// [`OAuth2PasswordBearer`]. What differs is the OpenAPI output: an `oauth2`
/// security scheme using the `authorizationCode` flow rather than `password`.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::OAuth2AuthorizationCodeBearer;
///
/// async fn protected(token: OAuth2AuthorizationCodeBearer) -> impl IntoResponse {
///     let access_token = token.token();
///     // Verify token with your auth provider...
/// }
/// ```
#[derive(Debug, Clone)]
pub struct OAuth2AuthorizationCodeBearer {
    /// The bearer token, without the "Bearer " prefix.
    pub token: String,
}

impl OAuth2AuthorizationCodeBearer {
    /// Wrap an already-extracted token value.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self { token }
    }

    /// Borrow the token as a string slice.
    #[must_use]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Take ownership of the token string.
    #[must_use]
    pub fn into_token(self) -> String {
        let Self { token } = self;
        token
    }
}

/// Dereferences to the token text, so `str` methods work on the extractor directly.
impl Deref for OAuth2AuthorizationCodeBearer {
    type Target = str;

    fn deref(&self) -> &str {
        self.token.as_str()
    }
}
10192
10193impl FromRequest for OAuth2AuthorizationCodeBearer {
10194    type Error = OAuth2BearerError;
10195
10196    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
10197        // Token extraction is identical to OAuth2PasswordBearer.
10198        // The difference is in the OpenAPI security scheme (authorizationCode flow).
10199        let auth_header = req
10200            .headers()
10201            .get("authorization")
10202            .ok_or_else(OAuth2BearerError::missing_header)?;
10203
10204        let auth_str =
10205            std::str::from_utf8(auth_header).map_err(|_| OAuth2BearerError::invalid_scheme())?;
10206
10207        const BEARER_PREFIX: &str = "Bearer ";
10208        const BEARER_PREFIX_LOWER: &str = "bearer ";
10209
10210        let token = if auth_str.starts_with(BEARER_PREFIX) {
10211            &auth_str[BEARER_PREFIX.len()..]
10212        } else if auth_str.starts_with(BEARER_PREFIX_LOWER) {
10213            &auth_str[BEARER_PREFIX_LOWER.len()..]
10214        } else {
10215            return Err(OAuth2BearerError::invalid_scheme());
10216        };
10217
10218        let token = token.trim();
10219        if token.is_empty() {
10220            return Err(OAuth2BearerError::empty_token());
10221        }
10222
10223        Ok(OAuth2AuthorizationCodeBearer::new(token))
10224    }
10225}
10226
#[cfg(test)]
mod oauth2_authcode_bearer_tests {
    use super::*;
    use crate::request::Method;

    /// Authorization endpoint used by the config tests.
    const AUTHORIZE_URL: &str = "https://auth.example.com/authorize";
    /// Token endpoint used by the config tests.
    const TOKEN_URL: &str = "https://auth.example.com/token";

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Run the extractor against an already prepared request.
    fn run_extractor(
        req: &mut Request,
    ) -> Result<OAuth2AuthorizationCodeBearer, OAuth2BearerError> {
        let ctx = test_context();
        futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, req))
    }

    /// Build `GET /protected` with the given raw `authorization` value and
    /// run the extractor on it.
    fn extract_from_header(
        value: &[u8],
    ) -> Result<OAuth2AuthorizationCodeBearer, OAuth2BearerError> {
        let mut req = Request::new(Method::Get, "/protected");
        req.headers_mut().insert("authorization", value.to_vec());
        run_extractor(&mut req)
    }

    // ---- Config tests ----

    #[test]
    fn config_new() {
        let config = OAuth2AuthorizationCodeBearerConfig::new(AUTHORIZE_URL, TOKEN_URL);

        assert_eq!(config.authorization_url, AUTHORIZE_URL);
        assert_eq!(config.token_url, TOKEN_URL);
        assert!(config.refresh_url.is_none());
        assert!(config.scopes.is_empty());
        assert!(config.scheme_name.is_none());
        assert!(config.description.is_none());
        assert!(config.auto_error);
    }

    #[test]
    fn config_builder() {
        let config = OAuth2AuthorizationCodeBearerConfig::new(AUTHORIZE_URL, TOKEN_URL)
            .with_refresh_url("https://auth.example.com/refresh")
            .with_scope("read", "Read access")
            .with_scope("write", "Write access")
            .with_scheme_name("MyAuth")
            .with_description("Authorization code flow")
            .with_auto_error(false);

        assert_eq!(
            config.refresh_url.as_deref(),
            Some("https://auth.example.com/refresh")
        );
        assert_eq!(config.scopes.len(), 2);
        assert_eq!(config.scopes.get("read").unwrap(), "Read access");
        assert_eq!(config.scopes.get("write").unwrap(), "Write access");
        assert_eq!(config.scheme_name.as_deref(), Some("MyAuth"));
        assert_eq!(
            config.description.as_deref(),
            Some("Authorization code flow")
        );
        assert!(!config.auto_error);
    }

    // ---- Extractor tests ----

    #[test]
    fn extracts_bearer_token() {
        let bearer = extract_from_header(b"Bearer my-access-token").unwrap();

        assert_eq!(bearer.token(), "my-access-token");
        assert_eq!(bearer.into_token(), "my-access-token");
    }

    #[test]
    fn extracts_lowercase_bearer() {
        let bearer = extract_from_header(b"bearer my-token").unwrap();

        assert_eq!(bearer.token(), "my-token");
    }

    #[test]
    fn missing_header_returns_401() {
        let mut req = Request::new(Method::Get, "/protected");

        let err = run_extractor(&mut req).unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::MissingHeader);

        let resp = err.into_response();
        assert_eq!(resp.status().as_u16(), 401);
    }

    #[test]
    fn invalid_scheme_returns_401() {
        let err = extract_from_header(b"Basic dXNlcjpwYXNz").unwrap_err();

        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn empty_token_returns_error() {
        let err = extract_from_header(b"Bearer ").unwrap_err();

        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    #[test]
    fn whitespace_only_token_returns_error() {
        let err = extract_from_header(b"Bearer   ").unwrap_err();

        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    #[test]
    fn token_trimmed() {
        let bearer = extract_from_header(b"Bearer  my-token  ").unwrap();

        assert_eq!(bearer.token(), "my-token");
    }

    #[test]
    fn deref_to_str() {
        let bearer = OAuth2AuthorizationCodeBearer::new("abc123");
        let s: &str = &bearer;
        assert_eq!(s, "abc123");
    }

    #[test]
    fn new_constructor() {
        let bearer = OAuth2AuthorizationCodeBearer::new("token-value");
        assert_eq!(bearer.token, "token-value");
        assert_eq!(bearer.token(), "token-value");
    }

    #[test]
    fn www_authenticate_header_on_error() {
        let resp = OAuth2BearerError::missing_header().into_response();

        let has_www_auth = resp.headers().iter().any(|(name, value)| {
            name.eq_ignore_ascii_case("www-authenticate") && value == b"Bearer"
        });
        assert!(
            has_www_auth,
            "Response should have WWW-Authenticate: Bearer header"
        );
    }
}
10416
10417// ============================================================================
10418// Security Scopes
10419// ============================================================================
10420
/// The OAuth2 scopes a handler requires.
///
/// `SecurityScopes` collects the scopes demanded by the security dependency
/// chain of the current handler. It is usually requested next to a bearer
/// token extractor so that the handler can check the token against the
/// required scopes.
///
/// In Python FastAPI, `SecurityScopes` is filled in automatically from the
/// `scopes` parameter of `Security(...)` dependencies. In this Rust
/// implementation the scopes are stored as a request extension (by middleware
/// or route configuration) and read back by the `FromRequest` impl.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::{SecurityScopes, BearerToken};
///
/// async fn get_admin(
///     token: BearerToken,
///     scopes: SecurityScopes,
/// ) -> impl IntoResponse {
///     // scopes.scopes() returns ["admin", "users:read"]
///     // Verify the token grants all required scopes
///     for scope in scopes.scopes() {
///         // check token has scope...
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct SecurityScopes {
    /// Required scopes in insertion order, without duplicates.
    scopes: Vec<String>,
    /// The same scopes joined with single spaces.
    scope_str: String,
}

impl SecurityScopes {
    /// Create an empty scope set (nothing is required).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Build a scope set from individual scope strings.
    ///
    /// The first occurrence of each scope is kept, later repeats are dropped,
    /// and the original order is preserved.
    #[must_use]
    pub fn from_scopes(scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let mut seen = std::collections::HashSet::new();
        let mut unique: Vec<String> = Vec::new();
        for scope in scopes {
            let scope: String = scope.into();
            // `insert` reports whether the value was new to the set.
            if seen.insert(scope.clone()) {
                unique.push(scope);
            }
        }

        let scope_str = unique.join(" ");
        Self {
            scopes: unique,
            scope_str,
        }
    }

    /// Build a scope set from a space-separated scope string.
    ///
    /// The input is split on the space character; empty segments (from
    /// repeated, leading, or trailing spaces) are skipped and duplicates are
    /// removed while keeping order.
    #[must_use]
    pub fn from_scope_str(scope_str: &str) -> Self {
        Self::from_scopes(scope_str.split(' ').filter(|segment| !segment.is_empty()))
    }

    /// The required scopes as a slice.
    #[must_use]
    pub fn scopes(&self) -> &[String] {
        self.scopes.as_slice()
    }

    /// The required scopes as one space-separated string.
    #[must_use]
    pub fn scope_str(&self) -> &str {
        self.scope_str.as_str()
    }

    /// Whether `scope` is one of the required scopes.
    #[must_use]
    pub fn contains(&self, scope: &str) -> bool {
        self.scopes
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == scope)
    }

    /// Whether no scopes are required.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.scopes.is_empty()
    }

    /// How many scopes are required.
    #[must_use]
    pub fn len(&self) -> usize {
        self.scopes.len()
    }

    /// Merge the scopes of `other` into this set.
    ///
    /// Scopes that are not yet present are appended in the order they appear
    /// in `other`; scopes already present are skipped.
    pub fn merge(&mut self, other: &SecurityScopes) {
        // Decide what to add while only borrowing `self.scopes` immutably,
        // then append in a second step.
        let additions: Vec<String> = {
            let known: std::collections::HashSet<&str> =
                self.scopes.iter().map(String::as_str).collect();
            other
                .scopes
                .iter()
                .filter(|candidate| !known.contains(candidate.as_str()))
                .cloned()
                .collect()
        };

        self.scopes.extend(additions);
        self.scope_str = self.scopes.join(" ");
    }

    /// Return a new scope set holding this set's scopes followed by any
    /// scopes of `other` that are not already present.
    #[must_use]
    pub fn merged(&self, other: &SecurityScopes) -> Self {
        let mut combined = self.clone();
        combined.merge(other);
        combined
    }
}

impl Default for SecurityScopes {
    /// The empty scope set.
    fn default() -> Self {
        Self {
            scopes: Vec::new(),
            scope_str: String::new(),
        }
    }
}

impl std::fmt::Display for SecurityScopes {
    /// Writes the space-separated scope string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.scope_str())
    }
}
10561
/// Error type declared for [`SecurityScopes`] extraction.
///
/// The `FromRequest` implementation for [`SecurityScopes`] falls back to an
/// empty scope set when the request carries no `SecurityScopes` extension, so
/// that extractor does not currently produce this error. The type serves as
/// the extractor's associated `Error`; if it is ever returned, its
/// `IntoResponse` implementation renders it as a 500 Internal Server Error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SecurityScopesError;

impl std::fmt::Display for SecurityScopesError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("No security scopes configured for this route")
    }
}

impl std::error::Error for SecurityScopesError {}
10576
10577impl IntoResponse for SecurityScopesError {
10578    fn into_response(self) -> Response {
10579        HttpError::new(crate::response::StatusCode::INTERNAL_SERVER_ERROR)
10580            .with_detail("Security scopes not configured")
10581            .into_response()
10582    }
10583}
10584
10585impl FromRequest for SecurityScopes {
10586    type Error = SecurityScopesError;
10587
10588    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
10589        // Look for SecurityScopes stored as a request extension.
10590        // Middleware or route configuration should populate this before the handler runs.
10591        if let Some(scopes) = req.get_extension::<SecurityScopes>() {
10592            Ok(scopes.clone())
10593        } else {
10594            // If no scopes extension exists, return empty scopes (no scopes required).
10595            // This is lenient — handlers that don't configure scopes get an empty set.
10596            Ok(SecurityScopes::new())
10597        }
10598    }
10599}
10600
#[cfg(test)]
mod security_scopes_tests {
    use super::*;
    use crate::request::Method;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Run the `SecurityScopes` extractor on `req` and unwrap the result.
    fn extract(req: &mut Request) -> SecurityScopes {
        let ctx = test_context();
        futures_executor::block_on(SecurityScopes::from_request(&ctx, req)).unwrap()
    }

    // ---- Construction tests ----

    #[test]
    fn new_is_empty() {
        let empty = SecurityScopes::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty.scope_str(), "");
        assert!(empty.scopes().is_empty());
    }

    #[test]
    fn default_is_empty() {
        assert!(SecurityScopes::default().is_empty());
    }

    #[test]
    fn from_scopes_preserves_order() {
        let built = SecurityScopes::from_scopes(["read", "write", "admin"]);
        assert_eq!(built.scopes(), &["read", "write", "admin"]);
        assert_eq!(built.scope_str(), "read write admin");
        assert_eq!(built.len(), 3);
    }

    #[test]
    fn from_scopes_deduplicates() {
        let built = SecurityScopes::from_scopes(["read", "write", "read", "admin", "write"]);
        assert_eq!(built.scopes(), &["read", "write", "admin"]);
        assert_eq!(built.scope_str(), "read write admin");
    }

    #[test]
    fn from_scope_str() {
        let parsed = SecurityScopes::from_scope_str("read write admin");
        assert_eq!(parsed.scopes(), &["read", "write", "admin"]);
        assert_eq!(parsed.scope_str(), "read write admin");
    }

    #[test]
    fn from_scope_str_deduplicates() {
        let parsed = SecurityScopes::from_scope_str("read write read admin");
        assert_eq!(parsed.scopes(), &["read", "write", "admin"]);
    }

    #[test]
    fn from_scope_str_ignores_empty_segments() {
        let parsed = SecurityScopes::from_scope_str("read  write   admin");
        assert_eq!(parsed.scopes(), &["read", "write", "admin"]);
    }

    #[test]
    fn from_empty_scope_str() {
        let parsed = SecurityScopes::from_scope_str("");
        assert!(parsed.is_empty());
        assert_eq!(parsed.scope_str(), "");
    }

    // ---- Query methods ----

    #[test]
    fn contains_scope() {
        let required = SecurityScopes::from_scopes(["read", "write"]);
        assert!(required.contains("read"));
        assert!(required.contains("write"));
        assert!(!required.contains("admin"));
    }

    #[test]
    fn display_format() {
        let scopes = SecurityScopes::from_scopes(["read", "write"]);
        assert_eq!(format!("{scopes}"), "read write");

        let empty = SecurityScopes::new();
        assert_eq!(format!("{empty}"), "");
    }

    // ---- Merge tests ----

    #[test]
    fn merge_appends_new_scopes() {
        let mut target = SecurityScopes::from_scopes(["read"]);
        target.merge(&SecurityScopes::from_scopes(["write", "admin"]));

        assert_eq!(target.scopes(), &["read", "write", "admin"]);
        assert_eq!(target.scope_str(), "read write admin");
    }

    #[test]
    fn merge_deduplicates() {
        let mut target = SecurityScopes::from_scopes(["read", "write"]);
        target.merge(&SecurityScopes::from_scopes(["write", "admin"]));

        assert_eq!(target.scopes(), &["read", "write", "admin"]);
    }

    #[test]
    fn merge_empty_into_nonempty() {
        let mut target = SecurityScopes::from_scopes(["read"]);
        target.merge(&SecurityScopes::new());

        assert_eq!(target.scopes(), &["read"]);
    }

    #[test]
    fn merge_nonempty_into_empty() {
        let mut target = SecurityScopes::new();
        target.merge(&SecurityScopes::from_scopes(["read", "write"]));

        assert_eq!(target.scopes(), &["read", "write"]);
    }

    #[test]
    fn merged_returns_new_instance() {
        let left = SecurityScopes::from_scopes(["read"]);
        let right = SecurityScopes::from_scopes(["write"]);
        let combined = left.merged(&right);

        assert_eq!(combined.scopes(), &["read", "write"]);
        // Original unchanged
        assert_eq!(left.scopes(), &["read"]);
    }

    // ---- FromRequest tests ----

    #[test]
    fn extract_with_extension() {
        let mut req = Request::new(Method::Get, "/protected");
        req.insert_extension(SecurityScopes::from_scopes(["admin", "users:read"]));

        let extracted = extract(&mut req);
        assert_eq!(extracted.scopes(), &["admin", "users:read"]);
        assert_eq!(extracted.scope_str(), "admin users:read");
    }

    #[test]
    fn extract_without_extension_returns_empty() {
        let mut req = Request::new(Method::Get, "/public");

        let extracted = extract(&mut req);
        assert!(extracted.is_empty());
    }

    // ---- Error tests ----

    #[test]
    fn error_display() {
        let message = SecurityScopesError.to_string();
        assert!(message.contains("security scopes"));
    }

    #[test]
    fn error_into_response_is_500() {
        let resp = SecurityScopesError.into_response();
        assert_eq!(resp.status().as_u16(), 500);
    }
}
10776
10777// ============================================================================
10778// HTTP Bearer Token Extractor
10779// ============================================================================
10780
/// Plain HTTP bearer token extractor.
///
/// Pulls a bearer token out of the `Authorization` header. Use it instead of
/// [`OAuth2PasswordBearer`] when OAuth2-specific features such as token URLs
/// and scopes are not needed.
///
/// It is the counterpart of FastAPI's `HTTPBearer` security scheme, which
/// produces an OpenAPI security scheme with `type: "http"` and
/// `scheme: "bearer"`.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::BearerToken;
///
/// async fn protected_route(token: BearerToken) -> impl IntoResponse {
///     // Validate the token
///     if verify_token(&token) {
///         format!("Token valid: {}", token.token())
///     } else {
///         // Return error response
///     }
/// }
/// ```
///
/// # Error Handling
///
/// A missing or invalid token results in a 401 Unauthorized response carrying
/// a `WWW-Authenticate: Bearer` header, following RFC 6750.
///
/// # Optional Extraction
///
/// Wrap the extractor in `Option` to make the token optional:
///
/// ```ignore
/// async fn maybe_auth(token: Option<BearerToken>) -> impl IntoResponse {
///     match token {
///         Some(t) => format!("Authenticated with: {}", t.token()),
///         None => "Anonymous access".to_string(),
///     }
/// }
/// ```
///
/// # OpenAPI
///
/// The corresponding OpenAPI security scheme is:
/// ```yaml
/// securitySchemes:
///   BearerToken:
///     type: http
///     scheme: bearer
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BearerToken {
    /// The token text, without the leading "Bearer " prefix.
    token: String,
}

impl BearerToken {
    /// Wrap `token` in a `BearerToken`.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self { token }
    }

    /// Borrow the token value.
    #[must_use]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Consume the extractor and hand back the owned token string.
    #[must_use]
    pub fn into_token(self) -> String {
        self.token
    }
}

impl Deref for BearerToken {
    type Target = str;

    /// Dereferences to the token text.
    fn deref(&self) -> &str {
        self.token.as_str()
    }
}

impl AsRef<str> for BearerToken {
    /// Views the token as a string slice.
    fn as_ref(&self) -> &str {
        self.token.as_str()
    }
}
10873
/// Reasons why bearer token extraction can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerTokenError {
    /// No Authorization header was sent.
    MissingHeader,
    /// The Authorization header does not use the Bearer scheme.
    InvalidScheme,
    /// Nothing follows the "Bearer " prefix.
    EmptyToken,
}

impl BearerTokenError {
    /// Construct the missing-header error.
    #[must_use]
    pub fn missing_header() -> Self {
        BearerTokenError::MissingHeader
    }

    /// Construct the invalid-scheme error.
    #[must_use]
    pub fn invalid_scheme() -> Self {
        BearerTokenError::InvalidScheme
    }

    /// Construct the empty-token error.
    #[must_use]
    pub fn empty_token() -> Self {
        BearerTokenError::EmptyToken
    }

    /// Human-readable description of the error.
    ///
    /// `InvalidScheme` and `EmptyToken` share the same text.
    #[must_use]
    pub fn detail(&self) -> &'static str {
        match self {
            Self::MissingHeader => "Not authenticated",
            Self::InvalidScheme | Self::EmptyToken => "Invalid authentication credentials",
        }
    }
}

impl fmt::Display for BearerTokenError {
    /// Writes a variant-specific diagnostic message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::MissingHeader => "Missing Authorization header",
            Self::InvalidScheme => "Authorization header must use Bearer scheme",
            Self::EmptyToken => "Bearer token is empty",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BearerTokenError {}
10926
10927impl IntoResponse for BearerTokenError {
10928    fn into_response(self) -> crate::response::Response {
10929        use crate::response::{Response, ResponseBody, StatusCode};
10930
10931        let body = serde_json::json!({
10932            "detail": self.detail()
10933        });
10934
10935        Response::with_status(StatusCode::UNAUTHORIZED)
10936            .header("www-authenticate", b"Bearer".to_vec())
10937            .header("content-type", b"application/json".to_vec())
10938            .body(ResponseBody::Bytes(body.to_string().into_bytes()))
10939    }
10940}
10941
10942impl FromRequest for BearerToken {
10943    type Error = BearerTokenError;
10944
10945    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
10946        // Get the Authorization header
10947        let auth_header = req
10948            .headers()
10949            .get("authorization")
10950            .ok_or(BearerTokenError::MissingHeader)?;
10951
10952        // Convert to string (invalid UTF-8 is treated as invalid scheme)
10953        let auth_str =
10954            std::str::from_utf8(auth_header).map_err(|_| BearerTokenError::InvalidScheme)?;
10955
10956        // Check for "Bearer " prefix (case-sensitive per RFC 6750, but we allow lowercase)
10957        const BEARER_PREFIX: &str = "Bearer ";
10958        const BEARER_PREFIX_LOWER: &str = "bearer ";
10959
10960        let token = if auth_str.starts_with(BEARER_PREFIX) {
10961            &auth_str[BEARER_PREFIX.len()..]
10962        } else if auth_str.starts_with(BEARER_PREFIX_LOWER) {
10963            &auth_str[BEARER_PREFIX_LOWER.len()..]
10964        } else {
10965            return Err(BearerTokenError::InvalidScheme);
10966        };
10967
10968        // Trim whitespace and check for empty token
10969        let token = token.trim();
10970        if token.is_empty() {
10971            return Err(BearerTokenError::EmptyToken);
10972        }
10973
10974        Ok(BearerToken::new(token))
10975    }
10976}
10977
10978// ============================================================================
10979// API Key Header Extractor
10980// ============================================================================
10981
/// Header name used for API key extraction when nothing else is configured.
pub const DEFAULT_API_KEY_HEADER: &str = "x-api-key";

/// Settings for API key header extraction.
#[derive(Debug, Clone)]
pub struct ApiKeyHeaderConfig {
    /// Name of the header the API key is read from (documented as
    /// case-insensitive; the lookup itself is done by the request's header
    /// map).
    header_name: String,
}

impl Default for ApiKeyHeaderConfig {
    /// Configuration that reads from [`DEFAULT_API_KEY_HEADER`].
    fn default() -> Self {
        let header_name = String::from(DEFAULT_API_KEY_HEADER);
        Self { header_name }
    }
}

impl ApiKeyHeaderConfig {
    /// Create a configuration with the default settings.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder-style setter for the header the API key is read from.
    #[must_use]
    pub fn header_name(mut self, name: impl Into<String>) -> Self {
        let replacement: String = name.into();
        self.header_name = replacement;
        self
    }

    /// The header name currently configured.
    #[must_use]
    pub fn get_header_name(&self) -> &str {
        self.header_name.as_str()
    }
}
11020
11021/// API key extracted from a request header.
11022///
11023/// Extracts an API key from a configurable header (default: `X-API-Key`).
11024/// Returns 401 Unauthorized if the header is missing or empty.
11025///
11026/// # Example
11027///
11028/// ```ignore
11029/// use fastapi_core::extract::ApiKeyHeader;
11030///
11031/// async fn protected_route(api_key: ApiKeyHeader) -> impl IntoResponse {
11032///     // Validate the API key against your database/config
11033///     if is_valid_key(api_key.key()) {
11034///         "Access granted".to_string()
11035///     } else {
11036///         // Return error response
11037///     }
11038/// }
11039/// ```
11040///
11041/// # Custom Header Name
11042///
11043/// Configure a custom header name by adding `ApiKeyHeaderConfig` to request extensions:
11044///
11045/// ```ignore
11046/// // In middleware or app setup:
11047/// req.insert_extension(ApiKeyHeaderConfig::new().header_name("Authorization"));
11048/// ```
11049///
11050/// # OpenAPI Security Scheme
11051///
11052/// This generates the following OpenAPI security scheme:
11053/// ```yaml
11054/// securitySchemes:
11055///   ApiKeyHeader:
11056///     type: apiKey
11057///     in: header
11058///     name: X-API-Key
11059/// ```
11060#[derive(Debug, Clone, PartialEq, Eq)]
11061pub struct ApiKeyHeader {
11062    /// The extracted API key value.
11063    key: String,
11064    /// The header name it was extracted from.
11065    header_name: String,
11066}
11067
11068impl ApiKeyHeader {
11069    /// Create a new ApiKeyHeader with the given key.
11070    #[must_use]
11071    pub fn new(key: impl Into<String>) -> Self {
11072        Self {
11073            key: key.into(),
11074            header_name: DEFAULT_API_KEY_HEADER.to_string(),
11075        }
11076    }
11077
11078    /// Create a new ApiKeyHeader with a custom header name.
11079    #[must_use]
11080    pub fn with_header_name(key: impl Into<String>, header_name: impl Into<String>) -> Self {
11081        Self {
11082            key: key.into(),
11083            header_name: header_name.into(),
11084        }
11085    }
11086
11087    /// Get the API key value.
11088    #[must_use]
11089    pub fn key(&self) -> &str {
11090        &self.key
11091    }
11092
11093    /// Get the header name the key was extracted from.
11094    #[must_use]
11095    pub fn header_name(&self) -> &str {
11096        &self.header_name
11097    }
11098
11099    /// Consume self and return the key string.
11100    #[must_use]
11101    pub fn into_key(self) -> String {
11102        self.key
11103    }
11104}
11105
11106impl Deref for ApiKeyHeader {
11107    type Target = str;
11108
11109    fn deref(&self) -> &Self::Target {
11110        &self.key
11111    }
11112}
11113
11114impl AsRef<str> for ApiKeyHeader {
11115    fn as_ref(&self) -> &str {
11116        &self.key
11117    }
11118}
11119
11120/// Implement SecureCompare for timing-safe API key validation.
11121impl SecureCompare for ApiKeyHeader {
11122    fn secure_eq(&self, other: &str) -> bool {
11123        constant_time_str_eq(&self.key, other)
11124    }
11125
11126    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
11127        constant_time_eq(self.key.as_bytes(), other)
11128    }
11129}
11130
/// Reasons why API key header extraction can fail.
#[derive(Debug, Clone)]
pub enum ApiKeyHeaderError {
    /// The API key header was not sent.
    MissingHeader {
        /// Name of the header that was expected.
        header_name: String,
    },
    /// The API key header was sent but is empty.
    EmptyKey {
        /// Name of the header.
        header_name: String,
    },
    /// The header value is not valid UTF-8.
    InvalidUtf8 {
        /// Name of the header.
        header_name: String,
    },
}

impl ApiKeyHeaderError {
    /// Construct the missing-header error for `header_name`.
    #[must_use]
    pub fn missing_header(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        Self::MissingHeader { header_name }
    }

    /// Construct the empty-key error for `header_name`.
    #[must_use]
    pub fn empty_key(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        Self::EmptyKey { header_name }
    }

    /// Construct the invalid-UTF-8 error for `header_name`.
    #[must_use]
    pub fn invalid_utf8(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        Self::InvalidUtf8 { header_name }
    }

    /// Human-readable description of the error, including the header name.
    #[must_use]
    pub fn detail(&self) -> String {
        let (summary, header_name) = match self {
            Self::MissingHeader { header_name } => ("Missing required header", header_name),
            Self::EmptyKey { header_name } => ("Empty API key in header", header_name),
            Self::InvalidUtf8 { header_name } => {
                ("Invalid API key encoding in header", header_name)
            }
        };
        format!("{summary}: {header_name}")
    }
}

impl fmt::Display for ApiKeyHeaderError {
    /// Writes a variant-specific diagnostic message followed by the header
    /// name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (summary, header_name) = match self {
            Self::MissingHeader { header_name } => ("Missing API key header", header_name),
            Self::EmptyKey { header_name } => ("Empty API key in header", header_name),
            Self::InvalidUtf8 { header_name } => ("Invalid UTF-8 in header", header_name),
        };
        write!(f, "{summary}: {header_name}")
    }
}

impl std::error::Error for ApiKeyHeaderError {}
11210
11211impl IntoResponse for ApiKeyHeaderError {
11212    fn into_response(self) -> crate::response::Response {
11213        use crate::response::{Response, ResponseBody, StatusCode};
11214
11215        let body = serde_json::json!({
11216            "detail": self.detail()
11217        });
11218
11219        Response::with_status(StatusCode::UNAUTHORIZED)
11220            .header("content-type", b"application/json".to_vec())
11221            .body(ResponseBody::Bytes(body.to_string().into_bytes()))
11222    }
11223}
11224
11225impl FromRequest for ApiKeyHeader {
11226    type Error = ApiKeyHeaderError;
11227
11228    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
11229        // Get config from request extensions or use default
11230        let header_name = req.get_extension::<ApiKeyHeaderConfig>().map_or_else(
11231            || DEFAULT_API_KEY_HEADER.to_string(),
11232            |c| c.get_header_name().to_string(),
11233        );
11234
11235        // Get the API key header (case-insensitive lookup)
11236        let key_header = req
11237            .headers()
11238            .get(&header_name)
11239            .ok_or_else(|| ApiKeyHeaderError::missing_header(&header_name))?;
11240
11241        // Convert to string
11242        let key_str = std::str::from_utf8(key_header)
11243            .map_err(|_| ApiKeyHeaderError::invalid_utf8(&header_name))?;
11244
11245        // Trim whitespace and check for empty key
11246        let key = key_str.trim();
11247        if key.is_empty() {
11248            return Err(ApiKeyHeaderError::empty_key(&header_name));
11249        }
11250
11251        Ok(ApiKeyHeader::with_header_name(key, header_name))
11252    }
11253}
11254
11255// ============================================================================
11256// API Key Query Parameter Extractor
11257// ============================================================================
11258
/// Query parameter consulted for the API key when no configuration overrides it.
pub const DEFAULT_API_KEY_QUERY_PARAM: &str = "api_key";

/// Settings that control where [`ApiKeyQuery`] looks for the key.
#[derive(Debug, Clone)]
pub struct ApiKeyQueryConfig {
    /// Name of the query parameter holding the API key.
    param_name: String,
}

impl Default for ApiKeyQueryConfig {
    /// Uses [`DEFAULT_API_KEY_QUERY_PARAM`] as the parameter name.
    fn default() -> Self {
        let param_name = String::from(DEFAULT_API_KEY_QUERY_PARAM);
        Self { param_name }
    }
}

impl ApiKeyQueryConfig {
    /// Start from the default configuration.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder-style setter for the query parameter name.
    #[must_use]
    pub fn param_name(self, name: impl Into<String>) -> Self {
        Self {
            param_name: name.into(),
        }
    }

    /// The query parameter name currently configured.
    #[must_use]
    pub fn get_param_name(&self) -> &str {
        self.param_name.as_str()
    }
}
11297
11298/// Extracts an API key from a query parameter.
11299///
11300/// This extractor pulls the API key from a configurable query parameter
11301/// (default: `api_key`). Returns 401 Unauthorized if missing or empty.
11302///
11303/// # Security Warning
11304///
11305/// Query parameter API keys are **less secure** than header-based keys:
11306/// - They appear in URL logs (browser history, server logs, proxies)
11307/// - They can leak via the Referer header
11308/// - They may be cached by browsers and intermediate caches
11309///
11310/// Use [`ApiKeyHeader`] for production-grade API key authentication.
11311/// Query parameter keys are primarily useful for:
11312/// - Quick testing/debugging
11313/// - Webhook callbacks where headers aren't controllable
11314/// - Legacy API compatibility
11315///
11316/// # Example
11317///
11318/// ```ignore
11319/// use fastapi_core::extract::ApiKeyQuery;
11320///
11321/// async fn webhook_handler(api_key: ApiKeyQuery) -> impl IntoResponse {
11322///     // Validate the API key
11323///     if api_key.key() == expected_key {
11324///         "Webhook received"
11325///     } else {
11326///         "Invalid API key"
11327///     }
11328/// }
11329/// ```
11330///
11331/// # Custom Parameter Name
11332///
11333/// Configure a custom parameter name by adding `ApiKeyQueryConfig` to request extensions:
11334///
11335/// ```ignore
11336/// // In middleware or app setup:
11337/// req.insert_extension(ApiKeyQueryConfig::new().param_name("token"));
11338/// // Then ?token=xxx will be used instead of ?api_key=xxx
11339/// ```
11340///
11341/// # OpenAPI Security Scheme
11342///
11343/// This generates the following OpenAPI security scheme:
11344/// ```yaml
11345/// securitySchemes:
11346///   ApiKeyQuery:
11347///     type: apiKey
11348///     in: query
11349///     name: api_key
11350/// ```
11351#[derive(Debug, Clone, PartialEq, Eq)]
11352pub struct ApiKeyQuery {
11353    /// The extracted API key value.
11354    key: String,
11355    /// The parameter name it was extracted from.
11356    param_name: String,
11357}
11358
11359impl ApiKeyQuery {
11360    /// Create a new ApiKeyQuery with the given key.
11361    #[must_use]
11362    pub fn new(key: impl Into<String>) -> Self {
11363        Self {
11364            key: key.into(),
11365            param_name: DEFAULT_API_KEY_QUERY_PARAM.to_string(),
11366        }
11367    }
11368
11369    /// Create a new ApiKeyQuery with a custom parameter name.
11370    #[must_use]
11371    pub fn with_param_name(key: impl Into<String>, param_name: impl Into<String>) -> Self {
11372        Self {
11373            key: key.into(),
11374            param_name: param_name.into(),
11375        }
11376    }
11377
11378    /// Get the API key value.
11379    #[must_use]
11380    pub fn key(&self) -> &str {
11381        &self.key
11382    }
11383
11384    /// Get the parameter name it was extracted from.
11385    #[must_use]
11386    pub fn param_name(&self) -> &str {
11387        &self.param_name
11388    }
11389
11390    /// Consume and return the key value.
11391    #[must_use]
11392    pub fn into_key(self) -> String {
11393        self.key
11394    }
11395}
11396
11397impl Deref for ApiKeyQuery {
11398    type Target = str;
11399
11400    fn deref(&self) -> &Self::Target {
11401        &self.key
11402    }
11403}
11404
11405impl AsRef<str> for ApiKeyQuery {
11406    fn as_ref(&self) -> &str {
11407        &self.key
11408    }
11409}
11410
11411/// Implement SecureCompare for timing-safe API key validation.
11412impl SecureCompare for ApiKeyQuery {
11413    fn secure_eq(&self, other: &str) -> bool {
11414        constant_time_str_eq(&self.key, other)
11415    }
11416
11417    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
11418        constant_time_eq(self.key.as_bytes(), other)
11419    }
11420}
11421
/// Reasons why extracting an API key from the query string can fail.
#[derive(Debug, Clone)]
pub enum ApiKeyQueryError {
    /// The expected query parameter was not supplied.
    MissingParam {
        /// Name of the parameter that was looked up.
        param_name: String,
    },
    /// The parameter was supplied but blank after trimming.
    EmptyKey {
        /// Name of the offending parameter.
        param_name: String,
    },
}

impl ApiKeyQueryError {
    /// Build a [`ApiKeyQueryError::MissingParam`] for `param_name`.
    #[must_use]
    pub fn missing_param(param_name: impl Into<String>) -> Self {
        let param_name = param_name.into();
        Self::MissingParam { param_name }
    }

    /// Build a [`ApiKeyQueryError::EmptyKey`] for `param_name`.
    #[must_use]
    pub fn empty_key(param_name: impl Into<String>) -> Self {
        let param_name = param_name.into();
        Self::EmptyKey { param_name }
    }

    /// Client-facing message used as the `detail` field of the error response.
    #[must_use]
    pub fn detail(&self) -> String {
        match self {
            Self::MissingParam { param_name: name } => {
                format!("API key required. Include '{name}' query parameter.")
            }
            Self::EmptyKey { param_name: name } => {
                format!("API key cannot be empty. Provide a value for '{name}'.")
            }
        }
    }
}
11467
11468impl fmt::Display for ApiKeyQueryError {
11469    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
11470        match self {
11471            Self::MissingParam { param_name } => {
11472                write!(f, "Missing API key query parameter: {param_name}")
11473            }
11474            Self::EmptyKey { param_name } => {
11475                write!(f, "Empty API key in query parameter: {param_name}")
11476            }
11477        }
11478    }
11479}
11480
11481impl std::error::Error for ApiKeyQueryError {}
11482
11483impl IntoResponse for ApiKeyQueryError {
11484    fn into_response(self) -> crate::response::Response {
11485        use crate::response::{Response, ResponseBody, StatusCode};
11486
11487        let body = serde_json::json!({
11488            "detail": self.detail()
11489        });
11490
11491        Response::with_status(StatusCode::UNAUTHORIZED)
11492            .header("content-type", b"application/json".to_vec())
11493            .body(ResponseBody::Bytes(body.to_string().into_bytes()))
11494    }
11495}
11496
11497impl FromRequest for ApiKeyQuery {
11498    type Error = ApiKeyQueryError;
11499
11500    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
11501        // Get config from request extensions or use default
11502        let param_name = req.get_extension::<ApiKeyQueryConfig>().map_or_else(
11503            || DEFAULT_API_KEY_QUERY_PARAM.to_string(),
11504            |c| c.get_param_name().to_string(),
11505        );
11506
11507        // Parse the query string if present
11508        let query_params = req.query().map(QueryParams::parse).unwrap_or_default();
11509
11510        // Get the API key parameter
11511        let key_value = query_params
11512            .get(&param_name)
11513            .ok_or_else(|| ApiKeyQueryError::missing_param(&param_name))?;
11514
11515        // Trim whitespace and check for empty key
11516        let key = key_value.trim();
11517        if key.is_empty() {
11518            return Err(ApiKeyQueryError::empty_key(&param_name));
11519        }
11520
11521        Ok(ApiKeyQuery::with_param_name(key, param_name))
11522    }
11523}
11524
11525// ============================================================================
11526// API Key Cookie Extractor
11527// ============================================================================
11528
/// Cookie consulted for the API key when no configuration overrides it.
pub const DEFAULT_API_KEY_COOKIE: &str = "api_key";

/// Settings that control where [`ApiKeyCookie`] looks for the key.
#[derive(Debug, Clone)]
pub struct ApiKeyCookieConfig {
    /// Name of the cookie holding the API key.
    cookie_name: String,
}

impl Default for ApiKeyCookieConfig {
    /// Uses [`DEFAULT_API_KEY_COOKIE`] as the cookie name.
    fn default() -> Self {
        let cookie_name = String::from(DEFAULT_API_KEY_COOKIE);
        Self { cookie_name }
    }
}

impl ApiKeyCookieConfig {
    /// Start from the default configuration.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder-style setter for the cookie name.
    #[must_use]
    pub fn cookie_name(self, name: impl Into<String>) -> Self {
        Self {
            cookie_name: name.into(),
        }
    }

    /// The cookie name currently configured.
    #[must_use]
    pub fn get_cookie_name(&self) -> &str {
        self.cookie_name.as_str()
    }
}
11567
11568/// Extracts an API key from a cookie.
11569///
11570/// This extractor pulls the API key from a configurable cookie
11571/// (default: `api_key`). Returns 401 Unauthorized if missing or empty.
11572///
11573/// # Security Considerations
11574///
11575/// Cookie-based API keys have different security characteristics than headers:
11576/// - Automatically sent by browsers (enables browser-based API access)
11577/// - Subject to CSRF attacks (use with CSRF protection middleware)
11578/// - Can be marked `HttpOnly` to prevent JavaScript access
11579/// - Can be marked `Secure` to require HTTPS
11580///
11581/// For browser-based APIs, consider pairing with CSRF protection.
11582/// For programmatic API access, prefer [`ApiKeyHeader`].
11583///
11584/// # Example
11585///
11586/// ```ignore
11587/// use fastapi_core::extract::ApiKeyCookie;
11588///
11589/// async fn protected_endpoint(api_key: ApiKeyCookie) -> impl IntoResponse {
11590///     // Validate the API key
11591///     if api_key.secure_eq(expected_key) {
11592///         "Access granted"
11593///     } else {
11594///         "Invalid API key"
11595///     }
11596/// }
11597/// ```
11598///
11599/// # Custom Cookie Name
11600///
11601/// Configure a custom cookie name by adding `ApiKeyCookieConfig` to request extensions:
11602///
11603/// ```ignore
11604/// // In middleware or app setup:
11605/// req.insert_extension(ApiKeyCookieConfig::new().cookie_name("auth_token"));
11606/// // Then the auth_token cookie will be used
11607/// ```
11608///
11609/// # OpenAPI Security Scheme
11610///
11611/// This generates the following OpenAPI security scheme:
11612/// ```yaml
11613/// securitySchemes:
11614///   ApiKeyCookie:
11615///     type: apiKey
11616///     in: cookie
11617///     name: api_key
11618/// ```
11619#[derive(Debug, Clone, PartialEq, Eq)]
11620pub struct ApiKeyCookie {
11621    /// The extracted API key value.
11622    key: String,
11623    /// The cookie name it was extracted from.
11624    cookie_name: String,
11625}
11626
11627impl ApiKeyCookie {
11628    /// Create a new ApiKeyCookie with the given key.
11629    #[must_use]
11630    pub fn new(key: impl Into<String>) -> Self {
11631        Self {
11632            key: key.into(),
11633            cookie_name: DEFAULT_API_KEY_COOKIE.to_string(),
11634        }
11635    }
11636
11637    /// Create a new ApiKeyCookie with a custom cookie name.
11638    #[must_use]
11639    pub fn with_cookie_name(key: impl Into<String>, cookie_name: impl Into<String>) -> Self {
11640        Self {
11641            key: key.into(),
11642            cookie_name: cookie_name.into(),
11643        }
11644    }
11645
11646    /// Get the API key value.
11647    #[must_use]
11648    pub fn key(&self) -> &str {
11649        &self.key
11650    }
11651
11652    /// Get the cookie name it was extracted from.
11653    #[must_use]
11654    pub fn cookie_name(&self) -> &str {
11655        &self.cookie_name
11656    }
11657
11658    /// Consume and return the key value.
11659    #[must_use]
11660    pub fn into_key(self) -> String {
11661        self.key
11662    }
11663}
11664
11665impl Deref for ApiKeyCookie {
11666    type Target = str;
11667
11668    fn deref(&self) -> &Self::Target {
11669        &self.key
11670    }
11671}
11672
11673impl AsRef<str> for ApiKeyCookie {
11674    fn as_ref(&self) -> &str {
11675        &self.key
11676    }
11677}
11678
11679/// Implement SecureCompare for timing-safe API key validation.
11680impl SecureCompare for ApiKeyCookie {
11681    fn secure_eq(&self, other: &str) -> bool {
11682        constant_time_str_eq(&self.key, other)
11683    }
11684
11685    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
11686        constant_time_eq(self.key.as_bytes(), other)
11687    }
11688}
11689
/// Reasons why extracting an API key from a cookie can fail.
#[derive(Debug, Clone)]
pub enum ApiKeyCookieError {
    /// The expected cookie was not sent.
    MissingCookie {
        /// Name of the cookie that was looked up.
        cookie_name: String,
    },
    /// The cookie was sent but blank after trimming.
    EmptyKey {
        /// Name of the offending cookie.
        cookie_name: String,
    },
}

impl ApiKeyCookieError {
    /// Build a [`ApiKeyCookieError::MissingCookie`] for `cookie_name`.
    #[must_use]
    pub fn missing_cookie(cookie_name: impl Into<String>) -> Self {
        let cookie_name = cookie_name.into();
        Self::MissingCookie { cookie_name }
    }

    /// Build a [`ApiKeyCookieError::EmptyKey`] for `cookie_name`.
    #[must_use]
    pub fn empty_key(cookie_name: impl Into<String>) -> Self {
        let cookie_name = cookie_name.into();
        Self::EmptyKey { cookie_name }
    }

    /// Client-facing message used as the `detail` field of the error response.
    #[must_use]
    pub fn detail(&self) -> String {
        match self {
            Self::MissingCookie { cookie_name: name } => {
                format!("API key required. Include '{name}' cookie.")
            }
            Self::EmptyKey { cookie_name: name } => {
                format!("API key cannot be empty. Provide a value for '{name}' cookie.")
            }
        }
    }
}
11735
11736impl fmt::Display for ApiKeyCookieError {
11737    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
11738        match self {
11739            Self::MissingCookie { cookie_name } => {
11740                write!(f, "Missing API key cookie: {cookie_name}")
11741            }
11742            Self::EmptyKey { cookie_name } => {
11743                write!(f, "Empty API key in cookie: {cookie_name}")
11744            }
11745        }
11746    }
11747}
11748
11749impl std::error::Error for ApiKeyCookieError {}
11750
11751impl IntoResponse for ApiKeyCookieError {
11752    fn into_response(self) -> crate::response::Response {
11753        use crate::response::{Response, ResponseBody, StatusCode};
11754
11755        let body = serde_json::json!({
11756            "detail": self.detail()
11757        });
11758
11759        Response::with_status(StatusCode::UNAUTHORIZED)
11760            .header("content-type", b"application/json".to_vec())
11761            .body(ResponseBody::Bytes(body.to_string().into_bytes()))
11762    }
11763}
11764
11765impl FromRequest for ApiKeyCookie {
11766    type Error = ApiKeyCookieError;
11767
11768    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
11769        // Get config from request extensions or use default
11770        let cookie_name = req.get_extension::<ApiKeyCookieConfig>().map_or_else(
11771            || DEFAULT_API_KEY_COOKIE.to_string(),
11772            |c| c.get_cookie_name().to_string(),
11773        );
11774
11775        // Parse cookies from the Cookie header
11776        let cookies = req
11777            .headers()
11778            .get("cookie")
11779            .and_then(|v| std::str::from_utf8(v).ok())
11780            .map(RequestCookies::from_header)
11781            .unwrap_or_default();
11782
11783        // Get the API key cookie
11784        let key_value = cookies
11785            .get(&cookie_name)
11786            .ok_or_else(|| ApiKeyCookieError::missing_cookie(&cookie_name))?;
11787
11788        // Trim whitespace and check for empty key
11789        let key = key_value.trim();
11790        if key.is_empty() {
11791            return Err(ApiKeyCookieError::empty_key(&cookie_name));
11792        }
11793
11794        Ok(ApiKeyCookie::with_cookie_name(key, cookie_name))
11795    }
11796}
11797
11798// ============================================================================
11799// Basic Authentication Extractor
11800// ============================================================================
11801
/// HTTP Basic authentication credentials taken from the `Authorization` header.
///
/// Implements the Basic scheme of RFC 7617: the header is expected to carry
/// `Basic <base64(username:password)>`.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::BasicAuth;
///
/// async fn protected_route(auth: BasicAuth) -> impl IntoResponse {
///     format!("Hello, {}!", auth.username())
/// }
/// ```
///
/// # Error Handling
///
/// Missing or malformed credentials produce a 401 Unauthorized response
/// carrying a `WWW-Authenticate: Basic` header, following RFC 7617.
///
/// # Optional Extraction
///
/// Wrap in `Option` to make authentication optional:
///
/// ```ignore
/// async fn maybe_auth(auth: Option<BasicAuth>) -> impl IntoResponse {
///     match auth {
///         Some(a) => format!("Hello, {}!", a.username()),
///         None => "Anonymous access".to_string(),
///     }
/// }
/// ```
///
/// # OpenAPI
///
/// The corresponding OpenAPI security scheme is:
/// ```yaml
/// securitySchemes:
///   BasicAuth:
///     type: http
///     scheme: basic
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BasicAuth {
    /// User name part of the credentials (everything before the first colon).
    username: String,
    /// Password part of the credentials (everything after the first colon).
    password: String,
}

impl BasicAuth {
    /// Build credentials from an explicit username and password.
    #[must_use]
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        let username = username.into();
        let password = password.into();
        Self { username, password }
    }

    /// The username.
    #[must_use]
    pub fn username(&self) -> &str {
        self.username.as_str()
    }

    /// The password.
    #[must_use]
    pub fn password(&self) -> &str {
        self.password.as_str()
    }

    /// Consume the credentials, yielding `(username, password)`.
    #[must_use]
    pub fn into_credentials(self) -> (String, String) {
        let Self { username, password } = self;
        (username, password)
    }
}
11880
/// Reasons why extracting HTTP Basic credentials can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BasicAuthError {
    /// No `Authorization` header was sent.
    MissingHeader,
    /// The `Authorization` header uses a scheme other than Basic.
    InvalidScheme,
    /// The credential payload is not valid base64.
    InvalidBase64,
    /// The decoded credentials lack the `:` separating username and password.
    MissingColon,
    /// The header value or the decoded credentials are not valid UTF-8.
    InvalidUtf8,
}

impl BasicAuthError {
    /// Shorthand for [`BasicAuthError::MissingHeader`].
    #[must_use]
    pub fn missing_header() -> Self {
        BasicAuthError::MissingHeader
    }

    /// Shorthand for [`BasicAuthError::InvalidScheme`].
    #[must_use]
    pub fn invalid_scheme() -> Self {
        BasicAuthError::InvalidScheme
    }

    /// Shorthand for [`BasicAuthError::InvalidBase64`].
    #[must_use]
    pub fn invalid_base64() -> Self {
        BasicAuthError::InvalidBase64
    }

    /// Shorthand for [`BasicAuthError::MissingColon`].
    #[must_use]
    pub fn missing_colon() -> Self {
        BasicAuthError::MissingColon
    }

    /// Shorthand for [`BasicAuthError::InvalidUtf8`].
    #[must_use]
    pub fn invalid_utf8() -> Self {
        BasicAuthError::InvalidUtf8
    }

    /// Client-facing message used as the `detail` field of the error response.
    ///
    /// Every malformed-credential case shares one generic message; only a
    /// missing header is reported differently.
    #[must_use]
    pub fn detail(&self) -> &'static str {
        match self {
            Self::MissingHeader => "Not authenticated",
            Self::InvalidScheme
            | Self::InvalidBase64
            | Self::MissingColon
            | Self::InvalidUtf8 => "Invalid authentication credentials",
        }
    }
}
11939
11940impl fmt::Display for BasicAuthError {
11941    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
11942        match self {
11943            Self::MissingHeader => write!(f, "Missing Authorization header"),
11944            Self::InvalidScheme => write!(f, "Authorization header must use Basic scheme"),
11945            Self::InvalidBase64 => write!(f, "Invalid base64 encoding in credentials"),
11946            Self::MissingColon => write!(f, "Credentials must contain username:password"),
11947            Self::InvalidUtf8 => write!(f, "Credentials contain invalid UTF-8"),
11948        }
11949    }
11950}
11951
11952impl std::error::Error for BasicAuthError {}
11953
11954impl IntoResponse for BasicAuthError {
11955    fn into_response(self) -> crate::response::Response {
11956        use crate::response::{Response, ResponseBody, StatusCode};
11957
11958        let body = serde_json::json!({
11959            "detail": self.detail()
11960        });
11961
11962        Response::with_status(StatusCode::UNAUTHORIZED)
11963            .header("www-authenticate", b"Basic".to_vec())
11964            .header("content-type", b"application/json".to_vec())
11965            .body(ResponseBody::Bytes(body.to_string().into_bytes()))
11966    }
11967}
11968
11969/// Decode a base64 string to bytes.
11970///
11971/// This is a minimal implementation for Basic auth credential decoding.
11972/// Supports standard base64 alphabet (A-Za-z0-9+/) with optional padding.
11973fn decode_base64(input: &str) -> Result<Vec<u8>, BasicAuthError> {
11974    const INVALID: u8 = 0xFF;
11975    const DECODE_TABLE: [u8; 256] = {
11976        let mut table = [INVALID; 256];
11977        let mut i = 0u8;
11978        // A-Z = 0-25
11979        while i < 26 {
11980            table[(b'A' + i) as usize] = i;
11981            i += 1;
11982        }
11983        // a-z = 26-51
11984        i = 0;
11985        while i < 26 {
11986            table[(b'a' + i) as usize] = 26 + i;
11987            i += 1;
11988        }
11989        // 0-9 = 52-61
11990        i = 0;
11991        while i < 10 {
11992            table[(b'0' + i) as usize] = 52 + i;
11993            i += 1;
11994        }
11995        // + = 62, / = 63
11996        table[b'+' as usize] = 62;
11997        table[b'/' as usize] = 63;
11998        table
11999    };
12000
12001    // Remove padding and whitespace
12002    let input = input.trim_end_matches('=').trim();
12003    if input.is_empty() {
12004        return Ok(Vec::new());
12005    }
12006
12007    let mut output = Vec::with_capacity((input.len() * 3) / 4);
12008    let mut buffer: u32 = 0;
12009    let mut bits_collected: u8 = 0;
12010
12011    for byte in input.bytes() {
12012        let value = DECODE_TABLE[byte as usize];
12013        if value == INVALID {
12014            return Err(BasicAuthError::InvalidBase64);
12015        }
12016
12017        buffer = (buffer << 6) | u32::from(value);
12018        bits_collected += 6;
12019
12020        if bits_collected >= 8 {
12021            bits_collected -= 8;
12022            output.push((buffer >> bits_collected) as u8);
12023            buffer &= (1 << bits_collected) - 1;
12024        }
12025    }
12026
12027    Ok(output)
12028}
12029
12030impl FromRequest for BasicAuth {
12031    type Error = BasicAuthError;
12032
12033    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
12034        // Get the Authorization header
12035        let auth_header = req
12036            .headers()
12037            .get("authorization")
12038            .ok_or(BasicAuthError::MissingHeader)?;
12039
12040        // Convert to string
12041        let auth_str = std::str::from_utf8(auth_header).map_err(|_| BasicAuthError::InvalidUtf8)?;
12042
12043        // Check for "Basic " prefix (case-insensitive per RFC 7617)
12044        const BASIC_PREFIX: &str = "Basic ";
12045        const BASIC_PREFIX_LOWER: &str = "basic ";
12046
12047        let encoded = if auth_str.starts_with(BASIC_PREFIX) {
12048            &auth_str[BASIC_PREFIX.len()..]
12049        } else if auth_str.starts_with(BASIC_PREFIX_LOWER) {
12050            &auth_str[BASIC_PREFIX_LOWER.len()..]
12051        } else {
12052            return Err(BasicAuthError::InvalidScheme);
12053        };
12054
12055        // Decode base64
12056        let decoded_bytes = decode_base64(encoded.trim())?;
12057
12058        // Convert to UTF-8 string
12059        let decoded = String::from_utf8(decoded_bytes).map_err(|_| BasicAuthError::InvalidUtf8)?;
12060
12061        // Split on first colon (password may contain colons)
12062        let colon_pos = decoded.find(':').ok_or(BasicAuthError::MissingColon)?;
12063        let (username, password_with_colon) = decoded.split_at(colon_pos);
12064        let password = &password_with_colon[1..]; // Skip the colon
12065
12066        Ok(BasicAuth::new(username, password))
12067    }
12068}
12069
12070// ============================================================================
12071// HTTP Digest Auth Extractor (Stub)
12072// ============================================================================
12073
/// HTTP Digest authentication credentials extractor (stub).
///
/// Holds the raw `Authorization: Digest ...` parameter string without
/// implementing the full Digest challenge-response protocol (RFC 7616).
/// This mirrors Python FastAPI's behavior of providing a stub for
/// Digest auth.
///
/// The credentials string contains the raw Digest parameters
/// (username, realm, nonce, uri, response, etc.) which the
/// application must validate.
#[derive(Debug, Clone)]
pub struct DigestAuth {
    /// The raw credentials string after "Digest ".
    credentials: String,
}

impl DigestAuth {
    /// Create a new DigestAuth with raw credentials.
    #[must_use]
    pub fn new(credentials: impl Into<String>) -> Self {
        Self {
            credentials: credentials.into(),
        }
    }

    /// Get the raw credentials string.
    #[must_use]
    pub fn credentials(&self) -> &str {
        &self.credentials
    }

    /// Extract a parameter value from the Digest credentials.
    ///
    /// Looks for `key="value"` or `key=value` in the credentials string.
    /// The key must appear at a word boundary (start of the string, or right
    /// after a comma, space or tab) and outside of any quoted value, so text
    /// such as `realm=x` embedded in another parameter's quoted value is
    /// never mistaken for the `realm` parameter.
    ///
    /// Quoted values are returned verbatim without the surrounding quotes;
    /// backslash escapes inside them are honoured when locating the closing
    /// quote but are not unescaped. Unquoted values run up to the next comma
    /// and are trimmed.
    ///
    /// Returns `None` when the key is not present or when its quoted value is
    /// never closed.
    #[must_use]
    pub fn param(&self, key: &str) -> Option<&str> {
        let credentials = self.credentials.as_str();
        let needle = format!("{key}=");

        let mut in_quotes = false;
        let mut escaped = false;
        let mut prev: Option<char> = None;

        // Walk char by char (never by raw byte offsets) so multibyte input
        // cannot cause a slice at a non-boundary index.
        for (idx, ch) in credentials.char_indices() {
            if in_quotes {
                // Inside a quoted value: only track where it ends.
                if escaped {
                    escaped = false;
                } else if ch == '\\' {
                    escaped = true;
                } else if ch == '"' {
                    in_quotes = false;
                }
            } else {
                let at_boundary = matches!(prev, None | Some(',' | ' ' | '\t'));
                if at_boundary && credentials[idx..].starts_with(needle.as_str()) {
                    return Self::value_at(&credentials[idx + needle.len()..]);
                }
                if ch == '"' {
                    in_quotes = true;
                }
            }
            prev = Some(ch);
        }

        None
    }

    /// Read the value that starts at the beginning of `rest` (the text right after `key=`).
    fn value_at(rest: &str) -> Option<&str> {
        match rest.strip_prefix('"') {
            Some(inner) => {
                // Quoted value: ends at the first quote that is not backslash-escaped.
                let mut escaped = false;
                for (idx, ch) in inner.char_indices() {
                    if escaped {
                        escaped = false;
                    } else if ch == '\\' {
                        escaped = true;
                    } else if ch == '"' {
                        return Some(&inner[..idx]);
                    }
                }
                None
            }
            None => {
                // Unquoted value: runs to the next comma or the end of the string.
                let end = rest.find(',').unwrap_or(rest.len());
                Some(rest[..end].trim())
            }
        }
    }
}
12147
/// Reasons why Digest credentials could not be extracted from a request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DigestAuthError {
    /// No `Authorization` header was present.
    MissingHeader,
    /// An `Authorization` header was present but is not a Digest header.
    InvalidScheme,
    /// The header bytes could not be decoded as UTF-8.
    InvalidUtf8,
}

impl DigestAuthError {
    /// Short client-facing description, used as the `detail` of the
    /// error response body.
    #[must_use]
    pub fn detail(&self) -> &'static str {
        match self {
            Self::MissingHeader => "Not authenticated",
            // Both malformed-header cases share one deliberately vague message.
            Self::InvalidScheme | Self::InvalidUtf8 => "Invalid authentication credentials",
        }
    }
}

impl fmt::Display for DigestAuthError {
    /// Writes a diagnostic message naming the specific failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::MissingHeader => "Missing Authorization header",
            Self::InvalidScheme => "Authorization header must use Digest scheme",
            Self::InvalidUtf8 => "Authorization header contains invalid UTF-8",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DigestAuthError {}
12182
12183impl IntoResponse for DigestAuthError {
12184    fn into_response(self) -> crate::response::Response {
12185        use crate::response::{Response, ResponseBody, StatusCode};
12186
12187        let body = serde_json::json!({
12188            "detail": self.detail()
12189        });
12190
12191        Response::with_status(StatusCode::UNAUTHORIZED)
12192            .header("www-authenticate", b"Digest".to_vec())
12193            .header("content-type", b"application/json".to_vec())
12194            .body(ResponseBody::Bytes(body.to_string().into_bytes()))
12195    }
12196}
12197
12198impl FromRequest for DigestAuth {
12199    type Error = DigestAuthError;
12200
12201    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
12202        let auth_header = req
12203            .headers()
12204            .get("authorization")
12205            .ok_or(DigestAuthError::MissingHeader)?;
12206
12207        let auth_str =
12208            std::str::from_utf8(auth_header).map_err(|_| DigestAuthError::InvalidUtf8)?;
12209
12210        const DIGEST_PREFIX: &str = "Digest ";
12211        const DIGEST_PREFIX_LOWER: &str = "digest ";
12212
12213        let credentials = if auth_str.starts_with(DIGEST_PREFIX) {
12214            &auth_str[DIGEST_PREFIX.len()..]
12215        } else if auth_str.starts_with(DIGEST_PREFIX_LOWER) {
12216            &auth_str[DIGEST_PREFIX_LOWER.len()..]
12217        } else {
12218            return Err(DigestAuthError::InvalidScheme);
12219        };
12220
12221        Ok(DigestAuth::new(credentials.trim()))
12222    }
12223}
12224
12225// ============================================================================
12226// Timing-Safe Comparison Utilities
12227// ============================================================================
12228
/// Compares two byte slices without short-circuiting on the first mismatch.
///
/// Every byte pair of equal-length inputs is examined no matter where (or
/// whether) the slices differ, so the work done does not reveal the position
/// of the first differing byte. This is the property needed to resist
/// [timing attacks](https://en.wikipedia.org/wiki/Timing_attack), in which an
/// attacker infers a secret by measuring how long a comparison takes.
///
/// # Security Properties
///
/// - **Full scan**: all byte pairs are visited for equal-length inputs
/// - **No early return inside the scan**: differences are OR-ed into an accumulator
/// - **Length-safe**: slices of different lengths compare as unequal
///
/// # Why not `a == b`?
///
/// A naive comparison stops at the first difference:
/// - `"secret" == "aaaaaa"` stops at the first byte
/// - `"secret" == "saaaaa"` stops at the second byte
/// - `"secret" == "seaaaa"` stops at the third byte, and so on
///
/// That lets an attacker recover a secret one character at a time.
///
/// # Warning: Length Leakage
///
/// A length mismatch returns `false` immediately, before any bytes are
/// inspected, so the *length* of a secret may be observable. For values such
/// as HMACs, whose length is fixed and known, this is typically acceptable.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::constant_time_eq;
///
/// let secret_token = b"supersecrettoken12345";
/// let user_input = b"supersecrettoken12345";
///
/// if constant_time_eq(secret_token, user_input) {
///     // Tokens match - grant access
/// } else {
///     // Tokens don't match - deny access
/// }
/// ```
///
/// # When to Use
///
/// Use this function when comparing:
/// - Authentication tokens
/// - API keys
/// - Session IDs
/// - HMAC signatures
/// - Password hashes (after hashing)
/// - Any secret value where timing attacks are a concern
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Unequal lengths can never match; this is the one permitted early exit
    // (it reveals length only, see "Length Leakage" above).
    if a.len() != b.len() {
        return false;
    }

    // OR together the XOR of each byte pair: any differing bit anywhere
    // leaves a non-zero bit in `diff`, and the loop never exits early.
    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }

    diff == 0
}
12302
12303/// Performs constant-time comparison of two strings.
12304///
12305/// This is a convenience wrapper around [`constant_time_eq`] that works with
12306/// string slices. Internally, it compares the UTF-8 byte representations.
12307///
12308/// See [`constant_time_eq`] for full documentation on timing attack prevention.
12309///
12310/// # Example
12311///
12312/// ```ignore
12313/// use fastapi_core::constant_time_str_eq;
12314///
12315/// let stored_token = "user_api_key_xyz123";
12316/// let provided_token = get_token_from_header();
12317///
12318/// if constant_time_str_eq(stored_token, &provided_token) {
12319///     // Valid token
12320/// }
12321/// ```
12322#[must_use]
12323#[inline]
12324pub fn constant_time_str_eq(a: &str, b: &str) -> bool {
12325    constant_time_eq(a.as_bytes(), b.as_bytes())
12326}
12327
/// Constant-time equality for values that may hold secrets.
///
/// Implementors expose comparison methods intended not to leak, through
/// timing, the position at which two values differ. This file implements
/// the trait for `BearerToken` as well as for plain string and byte
/// containers.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::{BearerToken, SecureCompare};
///
/// async fn validate_token(token: BearerToken) -> bool {
///     let expected = "valid_api_key_12345";
///     token.secure_eq(expected)
/// }
/// ```
pub trait SecureCompare {
    /// Constant-time comparison against a string.
    ///
    /// Returns `true` when the values are equal and `false` otherwise; the
    /// time taken must not depend on where (or whether) they differ.
    fn secure_eq(&self, other: &str) -> bool;

    /// Constant-time comparison against raw bytes.
    fn secure_eq_bytes(&self, other: &[u8]) -> bool;
}
12353
12354impl SecureCompare for BearerToken {
12355    #[inline]
12356    fn secure_eq(&self, other: &str) -> bool {
12357        constant_time_str_eq(self.token(), other)
12358    }
12359
12360    #[inline]
12361    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12362        constant_time_eq(self.token().as_bytes(), other)
12363    }
12364}
12365
12366impl SecureCompare for str {
12367    #[inline]
12368    fn secure_eq(&self, other: &str) -> bool {
12369        constant_time_str_eq(self, other)
12370    }
12371
12372    #[inline]
12373    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12374        constant_time_eq(self.as_bytes(), other)
12375    }
12376}
12377
12378impl SecureCompare for String {
12379    #[inline]
12380    fn secure_eq(&self, other: &str) -> bool {
12381        constant_time_str_eq(self, other)
12382    }
12383
12384    #[inline]
12385    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12386        constant_time_eq(self.as_bytes(), other)
12387    }
12388}
12389
12390impl SecureCompare for [u8] {
12391    #[inline]
12392    fn secure_eq(&self, other: &str) -> bool {
12393        constant_time_eq(self, other.as_bytes())
12394    }
12395
12396    #[inline]
12397    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12398        constant_time_eq(self, other)
12399    }
12400}
12401
12402impl<const N: usize> SecureCompare for [u8; N] {
12403    #[inline]
12404    fn secure_eq(&self, other: &str) -> bool {
12405        constant_time_eq(self, other.as_bytes())
12406    }
12407
12408    #[inline]
12409    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12410        constant_time_eq(self, other)
12411    }
12412}
12413
12414impl SecureCompare for Vec<u8> {
12415    #[inline]
12416    fn secure_eq(&self, other: &str) -> bool {
12417        constant_time_eq(self, other.as_bytes())
12418    }
12419
12420    #[inline]
12421    fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12422        constant_time_eq(self, other)
12423    }
12424}
12425
12426// ============================================================================
12427// Pagination Extractor and Response
12428// ============================================================================
12429
/// Page number assumed when a request does not supply one (pages are 1-indexed).
pub const DEFAULT_PAGE: u64 = 1;
/// Page size assumed when a request does not supply one.
pub const DEFAULT_PER_PAGE: u64 = 20;
/// Upper bound on the page size a client may request (to prevent abuse).
pub const MAX_PER_PAGE: u64 = 100;
12436
12437/// Pagination query parameters extractor.
12438///
12439/// Extracts common pagination parameters from the query string:
12440/// - `page`: Current page number (1-indexed, default: 1)
12441/// - `per_page` or `limit`: Items per page (default: 20, max: 100)
12442/// - `offset`: Alternative to page-based pagination (overrides page if set)
12443///
12444/// # Example
12445///
12446/// ```ignore
12447/// use fastapi_core::Pagination;
12448///
12449/// #[get("/items")]
12450/// async fn list_items(cx: &Cx, pagination: Pagination) -> impl IntoResponse {
12451///     let offset = pagination.offset();
12452///     let limit = pagination.limit();
12453///
12454///     // Fetch items from database with offset and limit
12455///     let items = db.fetch_items(offset, limit).await;
12456///
12457///     // Return paginated response
12458///     pagination.paginate(items, total_count, "/items")
12459/// }
12460/// ```
12461///
12462/// # Query String Formats
12463///
12464/// ```text
12465/// # Page-based (preferred)
12466/// ?page=2&per_page=10
12467///
12468/// # Using limit alias
12469/// ?page=2&limit=10
12470///
12471/// # Offset-based (for cursor-style pagination)
12472/// ?offset=20&limit=10
12473/// ```
12474///
12475/// # Configuration
12476///
12477/// Use [`PaginationConfig`] to customize defaults and limits:
12478///
12479/// ```ignore
12480/// use fastapi_core::{Pagination, PaginationConfig};
12481///
12482/// let config = PaginationConfig::new()
12483///     .default_per_page(50)
12484///     .max_per_page(200);
12485///
12486/// // The config can be stored in app state and used by handlers
12487/// ```
12488#[derive(Debug, Clone, Copy, PartialEq, Eq)]
12489pub struct Pagination {
12490    /// Current page (1-indexed).
12491    page: u64,
12492    /// Items per page.
12493    per_page: u64,
12494    /// Explicit offset (overrides page calculation if set).
12495    offset: Option<u64>,
12496}
12497
12498impl Default for Pagination {
12499    fn default() -> Self {
12500        Self {
12501            page: DEFAULT_PAGE,
12502            per_page: DEFAULT_PER_PAGE,
12503            offset: None,
12504        }
12505    }
12506}
12507
12508impl Pagination {
12509    /// Create pagination with specific values.
12510    #[must_use]
12511    pub fn new(page: u64, per_page: u64) -> Self {
12512        Self {
12513            page: page.max(1),
12514            per_page: per_page.clamp(1, MAX_PER_PAGE),
12515            offset: None,
12516        }
12517    }
12518
12519    /// Create pagination from offset and limit.
12520    #[must_use]
12521    pub fn from_offset(offset: u64, limit: u64) -> Self {
12522        Self {
12523            page: (offset / limit.max(1)) + 1,
12524            per_page: limit.clamp(1, MAX_PER_PAGE),
12525            offset: Some(offset),
12526        }
12527    }
12528
12529    /// Get the current page number (1-indexed).
12530    #[must_use]
12531    pub fn page(&self) -> u64 {
12532        self.page
12533    }
12534
12535    /// Get the number of items per page.
12536    #[must_use]
12537    pub fn per_page(&self) -> u64 {
12538        self.per_page
12539    }
12540
12541    /// Alias for `per_page()` - returns the page size limit.
12542    #[must_use]
12543    pub fn limit(&self) -> u64 {
12544        self.per_page
12545    }
12546
12547    /// Calculate the offset for database queries.
12548    ///
12549    /// If an explicit offset was provided, returns that.
12550    /// Otherwise, calculates from page number: `(page - 1) * per_page`.
12551    #[must_use]
12552    pub fn offset(&self) -> u64 {
12553        self.offset
12554            .unwrap_or_else(|| (self.page.saturating_sub(1)) * self.per_page)
12555    }
12556
12557    /// Calculate total number of pages given a total item count.
12558    #[must_use]
12559    pub fn total_pages(&self, total_items: u64) -> u64 {
12560        if self.per_page == 0 {
12561            return 0;
12562        }
12563        total_items.div_ceil(self.per_page)
12564    }
12565
12566    /// Check if there is a next page.
12567    #[must_use]
12568    pub fn has_next(&self, total_items: u64) -> bool {
12569        self.page < self.total_pages(total_items)
12570    }
12571
12572    /// Check if there is a previous page.
12573    #[must_use]
12574    pub fn has_prev(&self) -> bool {
12575        self.page > 1
12576    }
12577
12578    /// Create a paginated response from items.
12579    ///
12580    /// # Arguments
12581    ///
12582    /// * `items` - The items for the current page
12583    /// * `total` - Total number of items across all pages
12584    /// * `base_url` - Base URL for generating Link headers (e.g., "/api/items")
12585    #[must_use]
12586    pub fn paginate<T>(self, items: Vec<T>, total: u64, base_url: &str) -> Page<T> {
12587        Page::new(items, total, self, base_url.to_string())
12588    }
12589}
12590
12591impl FromRequest for Pagination {
12592    type Error = std::convert::Infallible;
12593
12594    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
12595        let query = req
12596            .get_extension::<QueryParams>()
12597            .cloned()
12598            .unwrap_or_default();
12599
12600        // Parse page (1-indexed, default 1)
12601        let page = query
12602            .get("page")
12603            .and_then(|v: &str| v.parse::<u64>().ok())
12604            .unwrap_or(DEFAULT_PAGE)
12605            .max(1);
12606
12607        // Parse per_page or limit (default 20, max 100)
12608        let per_page = query
12609            .get("per_page")
12610            .or_else(|| query.get("limit"))
12611            .and_then(|v: &str| v.parse::<u64>().ok())
12612            .unwrap_or(DEFAULT_PER_PAGE)
12613            .clamp(1, MAX_PER_PAGE);
12614
12615        // Parse offset (optional, overrides page if present)
12616        let offset = query
12617            .get("offset")
12618            .and_then(|v: &str| v.parse::<u64>().ok());
12619
12620        Ok(Pagination {
12621            page,
12622            per_page,
12623            offset,
12624        })
12625    }
12626}
12627
12628/// Configuration for pagination behavior.
12629///
12630/// Use this to customize default values and limits for pagination.
12631#[derive(Debug, Clone, Copy)]
12632pub struct PaginationConfig {
12633    /// Default items per page when not specified.
12634    pub default_per_page: u64,
12635    /// Maximum allowed items per page.
12636    pub max_per_page: u64,
12637    /// Default page number (usually 1).
12638    pub default_page: u64,
12639}
12640
12641impl Default for PaginationConfig {
12642    fn default() -> Self {
12643        Self {
12644            default_per_page: DEFAULT_PER_PAGE,
12645            max_per_page: MAX_PER_PAGE,
12646            default_page: DEFAULT_PAGE,
12647        }
12648    }
12649}
12650
12651impl PaginationConfig {
12652    /// Create a new pagination configuration with defaults.
12653    #[must_use]
12654    pub fn new() -> Self {
12655        Self::default()
12656    }
12657
12658    /// Set the default items per page.
12659    #[must_use]
12660    pub fn default_per_page(mut self, value: u64) -> Self {
12661        self.default_per_page = value;
12662        self
12663    }
12664
12665    /// Set the maximum items per page.
12666    #[must_use]
12667    pub fn max_per_page(mut self, value: u64) -> Self {
12668        self.max_per_page = value;
12669        self
12670    }
12671
12672    /// Set the default page number.
12673    #[must_use]
12674    pub fn default_page(mut self, value: u64) -> Self {
12675        self.default_page = value;
12676        self
12677    }
12678}
12679
12680/// Paginated response wrapper.
12681///
12682/// Wraps a collection of items with pagination metadata and generates
12683/// RFC 5988 Link headers for navigation.
12684///
12685/// # JSON Response Format
12686///
12687/// ```json
12688/// {
12689///     "items": [...],
12690///     "total": 100,
12691///     "page": 2,
12692///     "per_page": 20,
12693///     "pages": 5
12694/// }
12695/// ```
12696///
12697/// # Link Headers
12698///
12699/// When converted to a response, includes RFC 5988 Link headers:
12700///
12701/// ```text
12702/// Link: </items?page=1&per_page=20>; rel="first",
12703///       </items?page=1&per_page=20>; rel="prev",
12704///       </items?page=3&per_page=20>; rel="next",
12705///       </items?page=5&per_page=20>; rel="last"
12706/// ```
12707///
12708/// # Example
12709///
12710/// ```ignore
12711/// use fastapi_core::{Pagination, Page};
12712///
12713/// #[get("/users")]
12714/// async fn list_users(cx: &Cx, pagination: Pagination) -> impl IntoResponse {
12715///     let users = fetch_users(pagination.offset(), pagination.limit()).await;
12716///     let total = count_users().await;
12717///
12718///     pagination.paginate(users, total, "/users")
12719/// }
12720/// ```
12721#[derive(Debug, Clone)]
12722pub struct Page<T> {
12723    /// Items for the current page.
12724    pub items: Vec<T>,
12725    /// Total number of items across all pages.
12726    pub total: u64,
12727    /// Current page number (1-indexed).
12728    pub page: u64,
12729    /// Items per page.
12730    pub per_page: u64,
12731    /// Total number of pages.
12732    pub pages: u64,
12733    /// Base URL for Link header generation.
12734    base_url: String,
12735}
12736
12737impl<T> Page<T> {
12738    /// Create a new paginated response.
12739    #[must_use]
12740    pub fn new(items: Vec<T>, total: u64, pagination: Pagination, base_url: String) -> Self {
12741        let pages = pagination.total_pages(total);
12742        Self {
12743            items,
12744            total,
12745            page: pagination.page(),
12746            per_page: pagination.per_page(),
12747            pages,
12748            base_url,
12749        }
12750    }
12751
12752    /// Create a page with explicit values (for testing or manual construction).
12753    #[must_use]
12754    pub fn with_values(
12755        items: Vec<T>,
12756        total: u64,
12757        page: u64,
12758        per_page: u64,
12759        base_url: impl Into<String>,
12760    ) -> Self {
12761        let pages = if per_page > 0 {
12762            total.div_ceil(per_page)
12763        } else {
12764            0
12765        };
12766        Self {
12767            items,
12768            total,
12769            page,
12770            per_page,
12771            pages,
12772            base_url: base_url.into(),
12773        }
12774    }
12775
12776    /// Get the number of items on the current page.
12777    #[must_use]
12778    pub fn len(&self) -> usize {
12779        self.items.len()
12780    }
12781
12782    /// Check if the page is empty.
12783    #[must_use]
12784    pub fn is_empty(&self) -> bool {
12785        self.items.is_empty()
12786    }
12787
12788    /// Check if there is a next page.
12789    #[must_use]
12790    pub fn has_next(&self) -> bool {
12791        self.page < self.pages
12792    }
12793
12794    /// Check if there is a previous page.
12795    #[must_use]
12796    pub fn has_prev(&self) -> bool {
12797        self.page > 1
12798    }
12799
12800    /// Generate RFC 5988 Link header value.
12801    ///
12802    /// Returns a string with Link headers for navigation:
12803    /// - `first`: Link to the first page
12804    /// - `prev`: Link to the previous page (if applicable)
12805    /// - `next`: Link to the next page (if applicable)
12806    /// - `last`: Link to the last page
12807    #[must_use]
12808    pub fn link_header(&self) -> String {
12809        let mut links = Vec::with_capacity(4);
12810
12811        // Always include first and last
12812        links.push(format!(
12813            "<{}?page=1&per_page={}>; rel=\"first\"",
12814            self.base_url, self.per_page
12815        ));
12816
12817        // Previous page (if not on first page)
12818        if self.has_prev() {
12819            links.push(format!(
12820                "<{}?page={}&per_page={}>; rel=\"prev\"",
12821                self.base_url,
12822                self.page - 1,
12823                self.per_page
12824            ));
12825        }
12826
12827        // Next page (if not on last page)
12828        if self.has_next() {
12829            links.push(format!(
12830                "<{}?page={}&per_page={}>; rel=\"next\"",
12831                self.base_url,
12832                self.page + 1,
12833                self.per_page
12834            ));
12835        }
12836
12837        // Last page
12838        links.push(format!(
12839            "<{}?page={}&per_page={}>; rel=\"last\"",
12840            self.base_url, self.pages, self.per_page
12841        ));
12842
12843        links.join(", ")
12844    }
12845
12846    /// Map the items using a transformation function.
12847    pub fn map<U, F>(self, f: F) -> Page<U>
12848    where
12849        F: FnMut(T) -> U,
12850    {
12851        Page {
12852            items: self.items.into_iter().map(f).collect(),
12853            total: self.total,
12854            page: self.page,
12855            per_page: self.per_page,
12856            pages: self.pages,
12857            base_url: self.base_url,
12858        }
12859    }
12860}
12861
12862/// JSON representation of a paginated response.
12863#[derive(serde::Serialize)]
12864struct PageJson<'a, T: serde::Serialize> {
12865    items: &'a Vec<T>,
12866    total: u64,
12867    page: u64,
12868    per_page: u64,
12869    pages: u64,
12870}
12871
12872impl<T: serde::Serialize> IntoResponse for Page<T> {
12873    fn into_response(self) -> crate::response::Response {
12874        let json_body = PageJson {
12875            items: &self.items,
12876            total: self.total,
12877            page: self.page,
12878            per_page: self.per_page,
12879            pages: self.pages,
12880        };
12881
12882        // Serialize to JSON
12883        let Ok(body_bytes) = serde_json::to_vec(&json_body) else {
12884            // Fallback to empty error response on serialization failure
12885            return crate::response::Response::with_status(
12886                crate::response::StatusCode::INTERNAL_SERVER_ERROR,
12887            )
12888            .header("content-type", b"application/json".to_vec())
12889            .body(crate::response::ResponseBody::Bytes(
12890                b"{\"error\":\"Serialization failed\"}".to_vec(),
12891            ));
12892        };
12893
12894        // Build response with Link header
12895        let link_header = self.link_header();
12896
12897        crate::response::Response::ok()
12898            .header("content-type", b"application/json".to_vec())
12899            .header("link", link_header.into_bytes())
12900            .header("x-total-count", self.total.to_string().into_bytes())
12901            .header("x-page", self.page.to_string().into_bytes())
12902            .header("x-per-page", self.per_page.to_string().into_bytes())
12903            .header("x-total-pages", self.pages.to_string().into_bytes())
12904            .body(crate::response::ResponseBody::Bytes(body_bytes))
12905    }
12906}
12907
/// Extractor holding every value of a header that may occur more than once.
///
/// `T` is the element type of the collected values; `N` is a marker type
/// parameter that is carried only at the type level and stores no data.
#[derive(Clone, Debug)]
pub struct HeaderValues<T, N> {
    /// The collected header values.
    pub values: Vec<T>,
    _marker: std::marker::PhantomData<N>,
}

impl<T, N> HeaderValues<T, N> {
    /// Wrap an already-collected list of values.
    #[must_use]
    pub fn new(values: Vec<T>) -> Self {
        let _marker = std::marker::PhantomData;
        Self { values, _marker }
    }

    /// Whether there are no values at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// How many values were collected.
    #[must_use]
    pub fn len(&self) -> usize {
        self.values.len()
    }
}

impl<T, N> Deref for HeaderValues<T, N> {
    type Target = Vec<T>;

    /// Gives read access to the underlying vector of values.
    fn deref(&self) -> &Vec<T> {
        &self.values
    }
}
12948
12949// ============================================================================
12950// Content Negotiation
12951// ============================================================================
12952
/// A parsed media type (MIME type) with optional parameters.
///
/// Covers values such as `text/html`, `application/json`, or
/// `text/html; charset=utf-8`.
#[derive(Clone, Debug, PartialEq)]
pub struct MediaType {
    /// The primary type (e.g., "text", "application", "image").
    pub typ: String,
    /// The subtype (e.g., "html", "json", "png").
    pub subtype: String,
    /// Optional parameters (e.g., charset, boundary), in the order parsed.
    pub params: Vec<(String, String)>,
}

impl MediaType {
    /// Parse a media type string.
    ///
    /// The type and subtype are trimmed and lowercased. Parameter names are
    /// trimmed and lowercased; parameter values are trimmed and stripped of
    /// surrounding double quotes. A `q` parameter is dropped, and parameter
    /// segments without `=` are ignored.
    ///
    /// Returns `None` when there is no `/`, or when the type or subtype is
    /// empty.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let mt = MediaType::parse("text/html; charset=utf-8").unwrap();
    /// assert_eq!(mt.typ, "text");
    /// assert_eq!(mt.subtype, "html");
    /// ```
    pub fn parse(s: &str) -> Option<Self> {
        let trimmed = s.trim();
        let (essence, raw_params) = match trimmed.split_once(';') {
            Some((head, tail)) => (head, Some(tail)),
            None => (trimmed, None),
        };

        let (raw_type, raw_subtype) = essence.split_once('/')?;
        let typ = raw_type.trim().to_ascii_lowercase();
        let subtype = raw_subtype.trim().to_ascii_lowercase();
        if typ.is_empty() || subtype.is_empty() {
            return None;
        }

        let params = raw_params
            .into_iter()
            .flat_map(|raw| raw.split(';'))
            .filter_map(|piece| {
                let (name, value) = piece.trim().split_once('=')?;
                let name = name.trim().to_ascii_lowercase();
                // Skip the quality parameter, handled separately
                (name != "q").then(|| (name, value.trim().trim_matches('"').to_string()))
            })
            .collect();

        Some(Self {
            typ,
            subtype,
            params,
        })
    }

    /// Create a new media type without parameters.
    ///
    /// Both components are lowercased.
    #[must_use]
    pub fn new(typ: impl Into<String>, subtype: impl Into<String>) -> Self {
        let typ: String = typ.into();
        let subtype: String = subtype.into();
        Self {
            typ: typ.to_ascii_lowercase(),
            subtype: subtype.to_ascii_lowercase(),
            params: Vec::new(),
        }
    }

    /// Check if this media type matches another, supporting wildcards.
    ///
    /// A `*` in `other` matches any value in the same position of `self`;
    /// a `*` in `self` is not treated specially. Parameters are ignored.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let html = MediaType::new("text", "html");
    /// let any_text = MediaType::new("text", "*");
    /// let any = MediaType::new("*", "*");
    ///
    /// assert!(html.matches(&any_text));
    /// assert!(html.matches(&any));
    /// assert!(!any_text.matches(&html)); // wildcard doesn't match specific
    /// ```
    #[must_use]
    pub fn matches(&self, other: &MediaType) -> bool {
        let accepts = |pattern: &str, actual: &str| pattern == "*" || pattern == actual;
        accepts(&other.typ, &self.typ) && accepts(&other.subtype, &self.subtype)
    }

    /// Returns the media type as a string without parameters.
    #[must_use]
    pub fn essence(&self) -> String {
        [self.typ.as_str(), self.subtype.as_str()].join("/")
    }

    /// Get a parameter value by name.
    ///
    /// The requested name is lowercased before comparison with the stored
    /// names; the first match wins.
    #[must_use]
    pub fn param(&self, name: &str) -> Option<&str> {
        let wanted = name.to_ascii_lowercase();
        self.params
            .iter()
            .find_map(|(key, value)| (*key == wanted).then_some(value.as_str()))
    }
}

impl fmt::Display for MediaType {
    /// Writes `type/subtype` followed by `; name=value` for each parameter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}/{}", self.typ, self.subtype)?;
        self.params
            .iter()
            .try_for_each(|(key, value)| write!(f, "; {key}={value}"))
    }
}
13075
/// A single entry from an Accept header with quality value.
///
/// Pairs one media range (which may contain `*` wildcards) with the weight
/// the client assigned to it.
#[derive(Debug, Clone)]
pub struct AcceptItem {
    /// The media type.
    pub media_type: MediaType,
    /// Quality value (0.0 to 1.0, default 1.0).
    ///
    /// `AcceptItem::parse` clamps parsed values into this range; the field is
    /// public, so values assigned directly are not checked.
    pub quality: f32,
}
13084
13085impl AcceptItem {
13086    /// Parse a single Accept header item.
13087    ///
13088    /// # Example
13089    ///
13090    /// ```ignore
13091    /// let item = AcceptItem::parse("text/html;q=0.9").unwrap();
13092    /// assert_eq!(item.media_type.typ, "text");
13093    /// assert_eq!(item.quality, 0.9);
13094    /// ```
13095    pub fn parse(s: &str) -> Option<Self> {
13096        let s = s.trim();
13097        let mut quality = 1.0f32;
13098
13099        // Extract quality parameter
13100        let media_str = if let Some(q_pos) = s.to_ascii_lowercase().find(";q=") {
13101            let after_q = &s[q_pos + 3..];
13102            let q_end = after_q.find(';').unwrap_or(after_q.len());
13103            let q_str = &after_q[..q_end];
13104            if let Ok(q) = q_str.trim().parse::<f32>() {
13105                quality = q.clamp(0.0, 1.0);
13106            }
13107            // Remove the q parameter from the string for media type parsing
13108            let before = &s[..q_pos];
13109            let after = if q_end < after_q.len() {
13110                &after_q[q_end..]
13111            } else {
13112                ""
13113            };
13114            format!("{}{}", before, after)
13115        } else {
13116            s.to_string()
13117        };
13118
13119        let media_type = MediaType::parse(&media_str)?;
13120        Some(Self {
13121            media_type,
13122            quality,
13123        })
13124    }
13125}
13126
13127impl PartialEq for AcceptItem {
13128    fn eq(&self, other: &Self) -> bool {
13129        self.media_type == other.media_type && (self.quality - other.quality).abs() < f32::EPSILON
13130    }
13131}
13132
/// Parsed Accept header with quality-ordered media types.
///
/// This extractor parses the Accept header and provides the list of
/// acceptable media types sorted by quality (highest first).
///
/// An instance with no items is treated by `accepts`, `quality_of` and
/// `negotiate` as "no preference": everything is acceptable.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::AcceptHeader;
///
/// async fn handler(accept: AcceptHeader) -> impl IntoResponse {
///     if accept.prefers("application/json") {
///         Json(data).into_response()
///     } else if accept.prefers("text/html") {
///         Html(template).into_response()
///     } else {
///         // Default to JSON
///         Json(data).into_response()
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct AcceptHeader {
    /// Media types sorted by quality (highest first).
    ///
    /// `AcceptHeader::parse` establishes this order; the field is public, so
    /// code that edits it directly is responsible for keeping it sorted.
    pub items: Vec<AcceptItem>,
}
13159
13160impl AcceptHeader {
13161    /// Parse an Accept header value.
13162    ///
13163    /// # Example
13164    ///
13165    /// ```ignore
13166    /// let accept = AcceptHeader::parse("text/html, application/json;q=0.9, */*;q=0.1");
13167    /// assert_eq!(accept.items.len(), 3);
13168    /// assert_eq!(accept.items[0].media_type.subtype, "html"); // q=1.0
13169    /// assert_eq!(accept.items[1].media_type.subtype, "json"); // q=0.9
13170    /// ```
13171    #[must_use]
13172    pub fn parse(s: &str) -> Self {
13173        let mut items: Vec<AcceptItem> = s.split(',').filter_map(AcceptItem::parse).collect();
13174
13175        // Sort by quality descending, then by specificity
13176        items.sort_by(|a, b| {
13177            // Higher quality first
13178            let q_cmp = b
13179                .quality
13180                .partial_cmp(&a.quality)
13181                .unwrap_or(std::cmp::Ordering::Equal);
13182            if q_cmp != std::cmp::Ordering::Equal {
13183                return q_cmp;
13184            }
13185            // More specific types first (fewer wildcards)
13186            let a_wildcards =
13187                u8::from(a.media_type.typ == "*") + u8::from(a.media_type.subtype == "*");
13188            let b_wildcards =
13189                u8::from(b.media_type.typ == "*") + u8::from(b.media_type.subtype == "*");
13190            a_wildcards.cmp(&b_wildcards)
13191        });
13192
13193        Self { items }
13194    }
13195
13196    /// Create an AcceptHeader that accepts anything.
13197    #[must_use]
13198    pub fn any() -> Self {
13199        Self {
13200            items: vec![AcceptItem {
13201                media_type: MediaType::new("*", "*"),
13202                quality: 1.0,
13203            }],
13204        }
13205    }
13206
13207    /// Check if a media type is acceptable.
13208    #[must_use]
13209    pub fn accepts(&self, media_type: &str) -> bool {
13210        if self.items.is_empty() {
13211            return true; // No Accept header means accept anything
13212        }
13213
13214        let Some(mt) = MediaType::parse(media_type) else {
13215            return false;
13216        };
13217
13218        self.items
13219            .iter()
13220            .any(|item| item.quality > 0.0 && mt.matches(&item.media_type))
13221    }
13222
13223    /// Check if a media type is the preferred type.
13224    ///
13225    /// Returns true if the given media type matches the highest-quality
13226    /// acceptable type.
13227    #[must_use]
13228    pub fn prefers(&self, media_type: &str) -> bool {
13229        let Some(mt) = MediaType::parse(media_type) else {
13230            return false;
13231        };
13232
13233        self.items
13234            .first()
13235            .map(|item| mt.matches(&item.media_type))
13236            .unwrap_or(true)
13237    }
13238
13239    /// Get the quality value for a specific media type.
13240    ///
13241    /// Returns 0.0 if not acceptable, or the quality value if acceptable.
13242    #[must_use]
13243    pub fn quality_of(&self, media_type: &str) -> f32 {
13244        if self.items.is_empty() {
13245            return 1.0; // No Accept header means q=1.0 for everything
13246        }
13247
13248        let Some(mt) = MediaType::parse(media_type) else {
13249            return 0.0;
13250        };
13251
13252        self.items
13253            .iter()
13254            .find(|item| mt.matches(&item.media_type))
13255            .map(|item| item.quality)
13256            .unwrap_or(0.0)
13257    }
13258
13259    /// Negotiate the best media type from a list of available types.
13260    ///
13261    /// Returns the first available type that is acceptable, preferring
13262    /// higher quality matches.
13263    ///
13264    /// # Example
13265    ///
13266    /// ```ignore
13267    /// let accept = AcceptHeader::parse("text/html, application/json;q=0.9");
13268    /// let available = ["application/json", "text/html", "text/plain"];
13269    /// let best = accept.negotiate(&available);
13270    /// assert_eq!(best, Some("text/html"));
13271    /// ```
13272    #[must_use]
13273    pub fn negotiate<'a>(&self, available: &[&'a str]) -> Option<&'a str> {
13274        if self.items.is_empty() {
13275            return available.first().copied();
13276        }
13277
13278        // Score each available type
13279        let mut scored: Vec<(&str, f32, usize)> = available
13280            .iter()
13281            .enumerate()
13282            .filter_map(|(idx, &media_type)| {
13283                let q = self.quality_of(media_type);
13284                if q > 0.0 {
13285                    Some((media_type, q, idx))
13286                } else {
13287                    None
13288                }
13289            })
13290            .collect();
13291
13292        // Sort by quality descending, then by position in available list
13293        scored.sort_by(|a, b| {
13294            b.1.partial_cmp(&a.1)
13295                .unwrap_or(std::cmp::Ordering::Equal)
13296                .then_with(|| a.2.cmp(&b.2))
13297        });
13298
13299        scored.first().map(|(mt, _, _)| *mt)
13300    }
13301
13302    /// Check if the Accept header is empty (accepts anything).
13303    #[must_use]
13304    pub fn is_empty(&self) -> bool {
13305        self.items.is_empty()
13306    }
13307}
13308
13309impl Default for AcceptHeader {
13310    fn default() -> Self {
13311        Self::any()
13312    }
13313}
13314
13315impl FromRequest for AcceptHeader {
13316    type Error = std::convert::Infallible;
13317
13318    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
13319        let header = req
13320            .headers()
13321            .get("accept")
13322            .and_then(|v| std::str::from_utf8(v).ok())
13323            .map(Self::parse)
13324            .unwrap_or_else(Self::any);
13325        Ok(header)
13326    }
13327}
13328
/// A single entry from an Accept-Encoding header.
///
/// Pairs one content-coding name (or `*`) with the weight the client
/// assigned to it.
#[derive(Debug, Clone)]
pub struct AcceptEncodingItem {
    /// The encoding name (e.g., "gzip", "br", "deflate", "identity").
    ///
    /// `AcceptEncodingItem::parse` stores the name in ASCII lowercase.
    pub encoding: String,
    /// Quality value (0.0 to 1.0, default 1.0).
    pub quality: f32,
}
13337
13338impl AcceptEncodingItem {
13339    /// Parse a single Accept-Encoding item.
13340    pub fn parse(s: &str) -> Option<Self> {
13341        let s = s.trim();
13342        let mut quality = 1.0f32;
13343
13344        let (encoding, _) = if let Some(q_pos) = s.to_ascii_lowercase().find(";q=") {
13345            let q_str = &s[q_pos + 3..];
13346            if let Ok(q) = q_str.trim().parse::<f32>() {
13347                quality = q.clamp(0.0, 1.0);
13348            }
13349            (s[..q_pos].trim().to_ascii_lowercase(), quality)
13350        } else {
13351            (s.to_ascii_lowercase(), quality)
13352        };
13353
13354        if encoding.is_empty() {
13355            return None;
13356        }
13357
13358        Some(Self { encoding, quality })
13359    }
13360}
13361
/// Parsed Accept-Encoding header.
///
/// The derived `Default` is an instance with no items, which is what the
/// extractor produces when the request carries no usable header.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::AcceptEncodingHeader;
///
/// async fn handler(encoding: AcceptEncodingHeader) -> impl IntoResponse {
///     if encoding.accepts("br") {
///         // Use Brotli compression
///     } else if encoding.accepts("gzip") {
///         // Use gzip compression
///     }
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct AcceptEncodingHeader {
    /// Encodings sorted by quality (highest first).
    pub items: Vec<AcceptEncodingItem>,
}
13382
13383impl AcceptEncodingHeader {
13384    /// Parse an Accept-Encoding header value.
13385    #[must_use]
13386    pub fn parse(s: &str) -> Self {
13387        let mut items: Vec<AcceptEncodingItem> =
13388            s.split(',').filter_map(AcceptEncodingItem::parse).collect();
13389
13390        items.sort_by(|a, b| {
13391            b.quality
13392                .partial_cmp(&a.quality)
13393                .unwrap_or(std::cmp::Ordering::Equal)
13394        });
13395
13396        Self { items }
13397    }
13398
13399    /// Check if an encoding is acceptable.
13400    #[must_use]
13401    pub fn accepts(&self, encoding: &str) -> bool {
13402        let encoding = encoding.to_ascii_lowercase();
13403        self.items
13404            .iter()
13405            .any(|item| item.quality > 0.0 && (item.encoding == encoding || item.encoding == "*"))
13406    }
13407
13408    /// Get the preferred encoding from a list of available encodings.
13409    #[must_use]
13410    pub fn negotiate<'a>(&self, available: &[&'a str]) -> Option<&'a str> {
13411        if self.items.is_empty() {
13412            return available.first().copied();
13413        }
13414
13415        let mut best: Option<(&str, f32)> = None;
13416
13417        for &encoding in available {
13418            let enc_lower = encoding.to_ascii_lowercase();
13419            for item in &self.items {
13420                if item.quality > 0.0 && (item.encoding == enc_lower || item.encoding == "*") {
13421                    match best {
13422                        None => best = Some((encoding, item.quality)),
13423                        Some((_, q)) if item.quality > q => best = Some((encoding, item.quality)),
13424                        _ => {}
13425                    }
13426                    break;
13427                }
13428            }
13429        }
13430
13431        best.map(|(e, _)| e)
13432    }
13433}
13434
13435impl FromRequest for AcceptEncodingHeader {
13436    type Error = std::convert::Infallible;
13437
13438    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
13439        let header = req
13440            .headers()
13441            .get("accept-encoding")
13442            .and_then(|v| std::str::from_utf8(v).ok())
13443            .map(Self::parse)
13444            .unwrap_or_default();
13445        Ok(header)
13446    }
13447}
13448
/// A single entry from an Accept-Language header.
///
/// Pairs one language range (or `*`) with the weight the client assigned
/// to it.
#[derive(Debug, Clone)]
pub struct AcceptLanguageItem {
    /// The language tag (e.g., "en", "en-US", "fr-FR").
    ///
    /// `AcceptLanguageItem::parse` keeps the tag's original letter case.
    pub language: String,
    /// Quality value (0.0 to 1.0, default 1.0).
    pub quality: f32,
}
13457
13458impl AcceptLanguageItem {
13459    /// Parse a single Accept-Language item.
13460    pub fn parse(s: &str) -> Option<Self> {
13461        let s = s.trim();
13462        let mut quality = 1.0f32;
13463
13464        let (language, _) = if let Some(q_pos) = s.to_ascii_lowercase().find(";q=") {
13465            let q_str = &s[q_pos + 3..];
13466            if let Ok(q) = q_str.trim().parse::<f32>() {
13467                quality = q.clamp(0.0, 1.0);
13468            }
13469            (s[..q_pos].trim().to_string(), quality)
13470        } else {
13471            (s.to_string(), quality)
13472        };
13473
13474        if language.is_empty() {
13475            return None;
13476        }
13477
13478        Some(Self { language, quality })
13479    }
13480}
13481
/// Parsed Accept-Language header.
///
/// The derived `Default` is an instance with no items, which is what the
/// extractor produces when the request carries no usable header.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::AcceptLanguageHeader;
///
/// async fn handler(lang: AcceptLanguageHeader) -> impl IntoResponse {
///     let locale = lang.negotiate(&["en", "fr", "de"]).unwrap_or("en");
///     // Use locale for response
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct AcceptLanguageHeader {
    /// Languages sorted by quality (highest first).
    pub items: Vec<AcceptLanguageItem>,
}
13499
13500impl AcceptLanguageHeader {
13501    /// Parse an Accept-Language header value.
13502    #[must_use]
13503    pub fn parse(s: &str) -> Self {
13504        let mut items: Vec<AcceptLanguageItem> =
13505            s.split(',').filter_map(AcceptLanguageItem::parse).collect();
13506
13507        items.sort_by(|a, b| {
13508            b.quality
13509                .partial_cmp(&a.quality)
13510                .unwrap_or(std::cmp::Ordering::Equal)
13511        });
13512
13513        Self { items }
13514    }
13515
13516    /// Check if a language is acceptable.
13517    #[must_use]
13518    pub fn accepts(&self, language: &str) -> bool {
13519        let lang_lower = language.to_ascii_lowercase();
13520        self.items.iter().any(|item| {
13521            if item.quality <= 0.0 {
13522                return false;
13523            }
13524            let item_lower = item.language.to_ascii_lowercase();
13525            // Exact match or prefix match (e.g., "en" matches "en-US")
13526            item_lower == lang_lower
13527                || item_lower == "*"
13528                || lang_lower.starts_with(&format!("{}-", item_lower))
13529                || item_lower.starts_with(&format!("{}-", lang_lower))
13530        })
13531    }
13532
13533    /// Get the preferred language from a list of available languages.
13534    #[must_use]
13535    pub fn negotiate<'a>(&self, available: &[&'a str]) -> Option<&'a str> {
13536        if self.items.is_empty() {
13537            return available.first().copied();
13538        }
13539
13540        let mut best: Option<(&str, f32, bool)> = None; // (lang, quality, exact_match)
13541
13542        for &lang in available {
13543            let lang_lower = lang.to_ascii_lowercase();
13544            for item in &self.items {
13545                if item.quality <= 0.0 {
13546                    continue;
13547                }
13548                let item_lower = item.language.to_ascii_lowercase();
13549
13550                let (matches, exact) = if item_lower == lang_lower {
13551                    (true, true)
13552                } else if item_lower == "*"
13553                    || lang_lower.starts_with(&format!("{}-", item_lower))
13554                    || item_lower.starts_with(&format!("{}-", lang_lower))
13555                {
13556                    (true, false)
13557                } else {
13558                    (false, false)
13559                };
13560
13561                if matches {
13562                    match best {
13563                        None => best = Some((lang, item.quality, exact)),
13564                        Some((_, q, e))
13565                            if item.quality > q
13566                                || ((item.quality - q).abs() < f32::EPSILON && exact && !e) =>
13567                        {
13568                            best = Some((lang, item.quality, exact));
13569                        }
13570                        _ => {}
13571                    }
13572                    break;
13573                }
13574            }
13575        }
13576
13577        best.map(|(l, _, _)| l)
13578    }
13579}
13580
13581impl FromRequest for AcceptLanguageHeader {
13582    type Error = std::convert::Infallible;
13583
13584    async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
13585        let header = req
13586            .headers()
13587            .get("accept-language")
13588            .and_then(|v| std::str::from_utf8(v).ok())
13589            .map(Self::parse)
13590            .unwrap_or_default();
13591        Ok(header)
13592    }
13593}
13594
/// Error returned when content negotiation fails.
///
/// Carries both sides of the failed negotiation so they can be reported back
/// to the client; its `IntoResponse` impl produces a JSON body with status
/// `NOT_ACCEPTABLE`.
#[derive(Debug, Clone)]
pub struct NotAcceptableError {
    /// The requested media types.
    pub requested: Vec<String>,
    /// The available media types.
    pub available: Vec<String>,
}
13603
13604impl NotAcceptableError {
13605    /// Create a new NotAcceptableError.
13606    #[must_use]
13607    pub fn new(requested: Vec<String>, available: Vec<String>) -> Self {
13608        Self {
13609            requested,
13610            available,
13611        }
13612    }
13613}
13614
13615impl fmt::Display for NotAcceptableError {
13616    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13617        write!(
13618            f,
13619            "Not Acceptable: requested [{}], available [{}]",
13620            self.requested.join(", "),
13621            self.available.join(", ")
13622        )
13623    }
13624}
13625
// Empty impl: only the trait's default methods are used, so no underlying
// source error is exposed.
impl std::error::Error for NotAcceptableError {}
13627
13628impl IntoResponse for NotAcceptableError {
13629    fn into_response(self) -> Response {
13630        Response::with_status(crate::response::StatusCode::NOT_ACCEPTABLE)
13631            .header("content-type", b"application/json".to_vec())
13632            .body(ResponseBody::Bytes(
13633                serde_json::json!({
13634                    "error": "Not Acceptable",
13635                    "message": self.to_string(),
13636                    "requested": self.requested,
13637                    "available": self.available,
13638                })
13639                .to_string()
13640                .into_bytes(),
13641            ))
13642    }
13643}
13644
/// Helper to build Vary header values for content negotiation.
///
/// Header names are collected in insertion order and joined with `", "` by
/// `build`.
#[derive(Debug, Clone, Default)]
pub struct VaryBuilder {
    // Header names in the order they were added.
    headers: Vec<String>,
}
13650
13651impl VaryBuilder {
13652    /// Create a new Vary builder.
13653    #[must_use]
13654    pub fn new() -> Self {
13655        Self::default()
13656    }
13657
13658    /// Add Accept to the Vary header.
13659    #[must_use]
13660    pub fn accept(mut self) -> Self {
13661        if !self.headers.contains(&"Accept".to_string()) {
13662            self.headers.push("Accept".to_string());
13663        }
13664        self
13665    }
13666
13667    /// Add Accept-Encoding to the Vary header.
13668    #[must_use]
13669    pub fn accept_encoding(mut self) -> Self {
13670        if !self.headers.contains(&"Accept-Encoding".to_string()) {
13671            self.headers.push("Accept-Encoding".to_string());
13672        }
13673        self
13674    }
13675
13676    /// Add Accept-Language to the Vary header.
13677    #[must_use]
13678    pub fn accept_language(mut self) -> Self {
13679        if !self.headers.contains(&"Accept-Language".to_string()) {
13680            self.headers.push("Accept-Language".to_string());
13681        }
13682        self
13683    }
13684
13685    /// Add a custom header to Vary.
13686    #[must_use]
13687    pub fn header(mut self, name: impl Into<String>) -> Self {
13688        let name = name.into();
13689        if !self.headers.contains(&name) {
13690            self.headers.push(name);
13691        }
13692        self
13693    }
13694
13695    /// Build the Vary header value.
13696    #[must_use]
13697    pub fn build(&self) -> String {
13698        self.headers.join(", ")
13699    }
13700
13701    /// Check if any headers have been added.
13702    #[must_use]
13703    pub fn is_empty(&self) -> bool {
13704        self.headers.is_empty()
13705    }
13706}
13707
#[cfg(test)]
mod content_negotiation_tests {
    use super::*;

    /// True when two quality values differ by less than `f32::EPSILON`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    // ---- MediaType ----

    #[test]
    fn media_type_parse_simple() {
        let parsed = MediaType::parse("text/html").unwrap();
        assert_eq!(parsed.typ, "text");
        assert_eq!(parsed.subtype, "html");
        assert!(parsed.params.is_empty());
    }

    #[test]
    fn media_type_parse_with_params() {
        let parsed = MediaType::parse("text/html; charset=utf-8").unwrap();
        assert_eq!(parsed.typ, "text");
        assert_eq!(parsed.subtype, "html");
        assert_eq!(parsed.param("charset"), Some("utf-8"));
    }

    #[test]
    fn media_type_parse_case_insensitive() {
        let parsed = MediaType::parse("TEXT/HTML").unwrap();
        assert_eq!(parsed.typ, "text");
        assert_eq!(parsed.subtype, "html");
    }

    #[test]
    fn media_type_matches_wildcard() {
        let html = MediaType::new("text", "html");
        let patterns = [
            MediaType::new("text", "*"),
            MediaType::new("*", "*"),
            MediaType::new("text", "html"),
        ];

        for pattern in &patterns {
            assert!(html.matches(pattern));
        }
    }

    // ---- AcceptItem / AcceptHeader ----

    #[test]
    fn accept_item_parse_with_quality() {
        let entry = AcceptItem::parse("text/html;q=0.9").unwrap();
        assert_eq!(entry.media_type.typ, "text");
        assert_eq!(entry.media_type.subtype, "html");
        assert!(approx(entry.quality, 0.9));
    }

    #[test]
    fn accept_item_parse_default_quality() {
        let entry = AcceptItem::parse("application/json").unwrap();
        assert!(approx(entry.quality, 1.0));
    }

    #[test]
    fn accept_header_parse_multiple() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9, */*;q=0.1");
        let subtypes: Vec<&str> = header
            .items
            .iter()
            .map(|item| item.media_type.subtype.as_str())
            .collect();
        assert_eq!(subtypes, ["html", "json", "*"]);
    }

    #[test]
    fn accept_header_prefers() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9");
        assert!(header.prefers("text/html"));
        assert!(!header.prefers("application/json"));
    }

    #[test]
    fn accept_header_accepts() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9");
        assert!(header.accepts("text/html"));
        assert!(header.accepts("application/json"));
        assert!(!header.accepts("image/png"));
    }

    #[test]
    fn accept_header_negotiate() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9");
        let offered = ["application/json", "text/html", "text/plain"];
        assert_eq!(header.negotiate(&offered), Some("text/html"));
    }

    #[test]
    fn accept_header_negotiate_returns_best_available() {
        let header = AcceptHeader::parse("application/xml, application/json;q=0.9");
        let offered = ["application/json", "text/plain"];
        assert_eq!(header.negotiate(&offered), Some("application/json"));
    }

    #[test]
    fn accept_header_quality_of() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9, */*;q=0.1");
        assert!(approx(header.quality_of("text/html"), 1.0));
        assert!(approx(header.quality_of("application/json"), 0.9));
        assert!(approx(header.quality_of("image/png"), 0.1));
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn accept_header_empty_accepts_all() {
        let header = AcceptHeader::parse("");
        assert!(header.accepts("anything/here"));
        assert_eq!(header.quality_of("text/html"), 1.0);
    }

    // ---- Accept-Encoding ----

    #[test]
    fn accept_encoding_parse() {
        let encodings = AcceptEncodingHeader::parse("gzip, deflate, br;q=0.8");
        assert_eq!(encodings.items.len(), 3);
        assert!(encodings.accepts("gzip"));
        assert!(encodings.accepts("br"));
    }

    #[test]
    fn accept_encoding_negotiate() {
        let encodings = AcceptEncodingHeader::parse("gzip;q=0.9, br");
        let offered = ["gzip", "br", "identity"];
        assert_eq!(encodings.negotiate(&offered), Some("br"));
    }

    // ---- Accept-Language ----

    #[test]
    fn accept_language_parse() {
        let languages = AcceptLanguageHeader::parse("en-US, en;q=0.9, fr;q=0.8");
        assert_eq!(languages.items.len(), 3);
        for tag in ["en-US", "en", "fr"] {
            assert!(languages.accepts(tag));
        }
    }

    #[test]
    fn accept_language_negotiate() {
        let languages = AcceptLanguageHeader::parse("fr, en;q=0.9");
        let offered = ["en", "de", "fr"];
        assert_eq!(languages.negotiate(&offered), Some("fr"));
    }

    #[test]
    fn accept_language_prefix_match() {
        let languages = AcceptLanguageHeader::parse("en");
        assert!(languages.accepts("en-US"));
        assert!(languages.accepts("en-GB"));
    }

    // ---- VaryBuilder / NotAcceptableError ----

    #[test]
    fn vary_builder() {
        let value = VaryBuilder::new().accept().accept_encoding().build();
        assert_eq!(value, "Accept, Accept-Encoding");
    }

    #[test]
    fn vary_builder_no_duplicates() {
        let value = VaryBuilder::new().accept().accept().build();
        assert_eq!(value, "Accept");
    }

    #[test]
    fn not_acceptable_error_response() {
        let requested = vec!["image/png".to_string()];
        let available = vec!["application/json".to_string(), "text/html".to_string()];

        let response = NotAcceptableError::new(requested, available).into_response();
        assert_eq!(
            response.status(),
            crate::response::StatusCode::NOT_ACCEPTABLE
        );
    }
}
13877
13878#[cfg(test)]
13879mod header_tests {
13880    use super::*;
13881    use crate::request::Method;
13882
13883    fn test_context() -> RequestContext {
13884        let cx = asupersync::Cx::for_testing();
13885        RequestContext::new(cx, 12345)
13886    }
13887
13888    #[test]
13889    fn snake_to_header_case_simple() {
13890        assert_eq!(snake_to_header_case("authorization"), "Authorization");
13891        assert_eq!(snake_to_header_case("content_type"), "Content-Type");
13892        assert_eq!(snake_to_header_case("x_request_id"), "X-Request-Id");
13893        assert_eq!(snake_to_header_case("accept"), "Accept");
13894    }
13895
13896    #[test]
13897    fn snake_to_header_case_edge_cases() {
13898        assert_eq!(snake_to_header_case(""), "");
13899        assert_eq!(snake_to_header_case("a"), "A");
13900        assert_eq!(snake_to_header_case("a_b_c"), "A-B-C");
13901    }
13902
13903    #[test]
13904    fn header_deref() {
13905        let header = Header::new("test", "value".to_string());
13906        assert_eq!(*header, "value");
13907    }
13908
13909    #[test]
13910    fn header_into_inner() {
13911        let header = Header::new("test", 42i32);
13912        assert_eq!(header.into_inner(), 42);
13913    }
13914
13915    #[test]
13916    fn from_header_value_string() {
13917        let result = String::from_header_value("test value");
13918        assert_eq!(result.unwrap(), "test value");
13919    }
13920
13921    #[test]
13922    fn from_header_value_i32() {
13923        assert_eq!(i32::from_header_value("42").unwrap(), 42);
13924        assert_eq!(i32::from_header_value("-1").unwrap(), -1);
13925        assert!(i32::from_header_value("abc").is_err());
13926    }
13927
13928    #[test]
13929    fn from_header_value_bool() {
13930        assert!(bool::from_header_value("true").unwrap());
13931        assert!(bool::from_header_value("1").unwrap());
13932        assert!(bool::from_header_value("yes").unwrap());
13933        assert!(!bool::from_header_value("false").unwrap());
13934        assert!(!bool::from_header_value("0").unwrap());
13935        assert!(!bool::from_header_value("no").unwrap());
13936        assert!(bool::from_header_value("maybe").is_err());
13937    }
13938
13939    #[test]
13940    fn named_header_extract_success() {
13941        let ctx = test_context();
13942        let mut req = Request::new(Method::Get, "/test");
13943        req.headers_mut()
13944            .insert("authorization", b"Bearer token123".to_vec());
13945
13946        let result = futures_executor::block_on(
13947            NamedHeader::<String, Authorization>::from_request(&ctx, &mut req),
13948        );
13949        let header = result.unwrap();
13950        assert_eq!(header.value, "Bearer token123");
13951    }
13952
13953    #[test]
13954    fn named_header_extract_i32() {
13955        let ctx = test_context();
13956        let mut req = Request::new(Method::Get, "/test");
13957        req.headers_mut().insert("x-request-id", b"12345".to_vec());
13958
13959        let result = futures_executor::block_on(NamedHeader::<i32, XRequestId>::from_request(
13960            &ctx, &mut req,
13961        ));
13962        let header = result.unwrap();
13963        assert_eq!(header.value, 12345);
13964    }
13965
13966    #[test]
13967    fn named_header_missing() {
13968        let ctx = test_context();
13969        let mut req = Request::new(Method::Get, "/test");
13970        // Don't insert the header
13971
13972        let result = futures_executor::block_on(
13973            NamedHeader::<String, Authorization>::from_request(&ctx, &mut req),
13974        );
13975        assert!(matches!(
13976            result,
13977            Err(HeaderExtractError::MissingHeader { .. })
13978        ));
13979    }
13980
13981    #[test]
13982    fn named_header_parse_error() {
13983        let ctx = test_context();
13984        let mut req = Request::new(Method::Get, "/test");
13985        req.headers_mut()
13986            .insert("x-request-id", b"not-a-number".to_vec());
13987
13988        let result = futures_executor::block_on(NamedHeader::<i32, XRequestId>::from_request(
13989            &ctx, &mut req,
13990        ));
13991        assert!(matches!(result, Err(HeaderExtractError::ParseError { .. })));
13992    }
13993
13994    #[test]
13995    fn header_error_display() {
13996        let err = HeaderExtractError::MissingHeader {
13997            name: "Authorization".to_string(),
13998        };
13999        assert!(err.to_string().contains("Authorization"));
14000
14001        let err = HeaderExtractError::ParseError {
14002            name: "X-Count".to_string(),
14003            value: "abc".to_string(),
14004            expected: "i32",
14005            message: "invalid digit".to_string(),
14006        };
14007        assert!(err.to_string().contains("X-Count"));
14008        assert!(err.to_string().contains("abc"));
14009    }
14010
14011    #[test]
14012    fn optional_header_some() {
14013        let ctx = test_context();
14014        let mut req = Request::new(Method::Get, "/test");
14015        req.headers_mut()
14016            .insert("authorization", b"Bearer token".to_vec());
14017
14018        let result = futures_executor::block_on(
14019            Option::<NamedHeader<String, Authorization>>::from_request(&ctx, &mut req),
14020        );
14021        let opt = result.unwrap();
14022        assert!(opt.is_some());
14023        assert_eq!(opt.unwrap().value, "Bearer token");
14024    }
14025
14026    #[test]
14027    fn optional_header_none() {
14028        let ctx = test_context();
14029        let mut req = Request::new(Method::Get, "/test");
14030        // Don't insert the header
14031
14032        let result = futures_executor::block_on(
14033            Option::<NamedHeader<String, Authorization>>::from_request(&ctx, &mut req),
14034        );
14035        let opt = result.unwrap();
14036        assert!(opt.is_none());
14037    }
14038}
14039
#[cfg(test)]
mod oauth2_tests {
    //! Tests for the `OAuth2PasswordBearer` extractor, its error type
    //! (`OAuth2BearerError`) and its configuration builder
    //! (`OAuth2PasswordBearerConfig`).

    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Builds a request context on top of asupersync's testing capability context.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Runs the extractor against `GET /api/protected`.
    ///
    /// When `authorization` is `Some`, its bytes are installed as the raw
    /// `authorization` header value; when `None`, the request has no headers.
    fn extract_bearer(
        authorization: Option<Vec<u8>>,
    ) -> Result<OAuth2PasswordBearer, <OAuth2PasswordBearer as FromRequest>::Error> {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        if let Some(value) = authorization {
            req.headers_mut().insert("authorization", value);
        }
        futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req))
    }

    /// Converts `err` into a response and parses its body as JSON.
    ///
    /// Panics if the body is not a `Bytes` body or is not valid JSON.
    fn error_body_json(err: OAuth2BearerError) -> serde_json::Value {
        let response = err.into_response();
        let body = match response.body_ref() {
            crate::response::ResponseBody::Bytes(b) => String::from_utf8_lossy(b).to_string(),
            _ => panic!("Expected Bytes body"),
        };
        serde_json::from_str(&body).expect("Body should be valid JSON")
    }

    #[test]
    fn oauth2_extract_valid_bearer_token() {
        let bearer = extract_bearer(Some(b"Bearer mytoken123".to_vec())).unwrap();
        assert_eq!(bearer.token(), "mytoken123");
        assert_eq!(&*bearer, "mytoken123"); // Test Deref
    }

    #[test]
    fn oauth2_extract_bearer_lowercase() {
        let bearer = extract_bearer(Some(b"bearer lowercase_token".to_vec())).unwrap();
        assert_eq!(bearer.token(), "lowercase_token");
    }

    #[test]
    fn oauth2_missing_header() {
        // No authorization header
        let err = extract_bearer(None).unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::MissingHeader);
    }

    #[test]
    fn oauth2_wrong_scheme() {
        let err = extract_bearer(Some(b"Basic dXNlcjpwYXNz".to_vec())).unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn oauth2_empty_token() {
        let err = extract_bearer(Some(b"Bearer ".to_vec())).unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    #[test]
    fn oauth2_whitespace_only_token() {
        let err = extract_bearer(Some(b"Bearer    ".to_vec())).unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    #[test]
    fn oauth2_token_with_spaces_trimmed() {
        let bearer = extract_bearer(Some(b"Bearer  spaced_token  ".to_vec())).unwrap();
        assert_eq!(bearer.token(), "spaced_token");
    }

    #[test]
    fn oauth2_optional_extraction_some() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/maybe-protected");
        req.headers_mut()
            .insert("authorization", b"Bearer optional_token".to_vec());

        let result = futures_executor::block_on(Option::<OAuth2PasswordBearer>::from_request(
            &ctx, &mut req,
        ));
        let bearer = result
            .unwrap()
            .expect("a valid bearer header should extract to Some");
        assert_eq!(bearer.token(), "optional_token");
    }

    #[test]
    fn oauth2_optional_extraction_none() {
        let ctx = test_context();
        // No authorization header
        let mut req = Request::new(Method::Get, "/api/maybe-protected");

        let result = futures_executor::block_on(Option::<OAuth2PasswordBearer>::from_request(
            &ctx, &mut req,
        ));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn oauth2_error_response_401() {
        let response = OAuth2BearerError::missing_header().into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn oauth2_error_response_has_www_authenticate() {
        let response = OAuth2BearerError::missing_header().into_response();

        let www_auth = response
            .headers()
            .iter()
            .find(|(name, _)| name == "www-authenticate")
            .map(|(_, value)| String::from_utf8_lossy(value).to_string());

        assert_eq!(www_auth, Some("Bearer".to_string()));
    }

    #[test]
    fn oauth2_error_display() {
        let missing = OAuth2BearerError::missing_header().to_string();
        assert!(missing.contains("Missing"));

        let invalid = OAuth2BearerError::invalid_scheme().to_string();
        assert!(invalid.contains("Bearer"));

        let empty = OAuth2BearerError::empty_token().to_string();
        assert!(empty.contains("empty"));
    }

    #[test]
    fn oauth2_config_builder() {
        let config = OAuth2PasswordBearerConfig::new("/auth/token")
            .with_refresh_url("/auth/refresh")
            .with_scope("read", "Read access")
            .with_scope("write", "Write access")
            .with_scheme_name("MyOAuth2")
            .with_description("Custom OAuth2 scheme")
            .with_auto_error(false);

        assert_eq!(config.token_url, "/auth/token");
        assert_eq!(config.refresh_url, Some("/auth/refresh".to_string()));
        assert_eq!(config.scopes.len(), 2);
        assert_eq!(config.scopes.get("read"), Some(&"Read access".to_string()));
        assert_eq!(config.scheme_name, Some("MyOAuth2".to_string()));
        assert!(!config.auto_error);
    }

    #[test]
    fn oauth2_password_bearer_accessors() {
        let bearer = OAuth2PasswordBearer::new("test_token");
        assert_eq!(bearer.token(), "test_token");
        assert_eq!(bearer.into_token(), "test_token");
    }

    // ============================================================================
    // Additional Security Tests for fastapi_rust-3mg
    // ============================================================================

    #[test]
    fn oauth2_error_response_json_body_format() {
        // Verify the body is valid JSON with "detail" field
        let json = error_body_json(OAuth2BearerError::missing_header());
        assert!(
            json.get("detail").is_some(),
            "Response should have 'detail' field"
        );
        assert_eq!(json["detail"], "Not authenticated");
    }

    #[test]
    fn oauth2_error_invalid_scheme_json_body() {
        let json = error_body_json(OAuth2BearerError::invalid_scheme());
        assert_eq!(json["detail"], "Invalid authentication credentials");
    }

    #[test]
    fn oauth2_error_empty_token_json_body() {
        let json = error_body_json(OAuth2BearerError::empty_token());
        assert_eq!(json["detail"], "Invalid authentication credentials");
    }

    #[test]
    fn oauth2_error_response_content_type_json() {
        let response = OAuth2BearerError::missing_header().into_response();

        let content_type = response
            .headers()
            .iter()
            .find(|(name, _)| name == "content-type")
            .map(|(_, value)| String::from_utf8_lossy(value).to_string());

        assert_eq!(content_type, Some("application/json".to_string()));
    }

    #[test]
    fn oauth2_extract_token_with_special_characters() {
        // JWT-like token with special characters
        let header = b"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U".to_vec();

        let bearer = extract_bearer(Some(header)).unwrap();
        assert!(bearer.token().contains("eyJ"));
        assert!(bearer.token().contains("."));
    }

    #[test]
    fn oauth2_extract_token_with_unicode() {
        // Token with unicode characters (unusual but should work)
        let header = "Bearer tökën_with_ünïcödë".as_bytes().to_vec();

        let bearer = extract_bearer(Some(header)).unwrap();
        assert_eq!(bearer.token(), "tökën_with_ünïcödë");
    }

    #[test]
    fn oauth2_invalid_utf8_in_token() {
        // "Bearer " followed by an invalid UTF-8 sequence
        let header = vec![66, 101, 97, 114, 101, 114, 32, 0xFF, 0xFE];

        // Should return InvalidScheme because it can't be parsed as valid UTF-8
        let err = extract_bearer(Some(header)).unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn oauth2_only_bearer_prefix_no_space() {
        // "Bearer" glued directly onto the token with no separating space -
        // should be invalid scheme
        let err = extract_bearer(Some(b"Bearertoken".to_vec())).unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn oauth2_mixed_case_bearer() {
        // "BEARER" all caps - should fail (only "Bearer" and "bearer" supported)
        let err = extract_bearer(Some(b"BEARER uppercase_token".to_vec())).unwrap_err();
        // Currently the implementation only supports "Bearer " and "bearer " prefixes
        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn oauth2_extract_very_long_token() {
        // Very long token (4KB)
        let long_token = "x".repeat(4096);
        let header = format!("Bearer {long_token}").into_bytes();

        let bearer = extract_bearer(Some(header)).unwrap();
        assert_eq!(bearer.token().len(), 4096);
    }

    #[test]
    fn oauth2_config_default_values() {
        let config = OAuth2PasswordBearerConfig::default();

        assert_eq!(config.token_url, "/token");
        assert!(config.refresh_url.is_none());
        assert!(config.scopes.is_empty());
        assert!(config.scheme_name.is_none());
        assert!(config.description.is_none());
        assert!(config.auto_error); // Default should be true
    }

    #[test]
    fn oauth2_error_kind_equality() {
        // Verify error kinds implement PartialEq correctly
        assert_eq!(
            OAuth2BearerErrorKind::MissingHeader,
            OAuth2BearerErrorKind::MissingHeader
        );
        assert_eq!(
            OAuth2BearerErrorKind::InvalidScheme,
            OAuth2BearerErrorKind::InvalidScheme
        );
        assert_eq!(
            OAuth2BearerErrorKind::EmptyToken,
            OAuth2BearerErrorKind::EmptyToken
        );
        assert_ne!(
            OAuth2BearerErrorKind::MissingHeader,
            OAuth2BearerErrorKind::InvalidScheme
        );
    }

    #[test]
    fn oauth2_error_debug_format() {
        // Verify error types implement Debug
        let err = OAuth2BearerError::missing_header();
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("MissingHeader"));
    }

    #[test]
    fn oauth2_bearer_clone() {
        let bearer = OAuth2PasswordBearer::new("cloneable_token");
        let cloned = bearer.clone();
        assert_eq!(bearer.token(), cloned.token());
    }

    #[test]
    fn oauth2_config_clone() {
        let config =
            OAuth2PasswordBearerConfig::new("/auth/token").with_scope("admin", "Admin access");
        let cloned = config.clone();
        assert_eq!(config.token_url, cloned.token_url);
        assert_eq!(config.scopes.len(), cloned.scopes.len());
    }

    #[test]
    fn oauth2_all_error_responses_are_401() {
        // All OAuth2 bearer errors should result in 401 Unauthorized
        let errors = [
            OAuth2BearerError::missing_header(),
            OAuth2BearerError::invalid_scheme(),
            OAuth2BearerError::empty_token(),
        ];

        for err in errors {
            let response = err.into_response();
            assert_eq!(
                response.status().as_u16(),
                401,
                "All OAuth2 errors should be 401"
            );
        }
    }

    #[test]
    fn oauth2_all_error_responses_have_www_authenticate() {
        // All OAuth2 bearer errors should include WWW-Authenticate header
        let errors = [
            OAuth2BearerError::missing_header(),
            OAuth2BearerError::invalid_scheme(),
            OAuth2BearerError::empty_token(),
        ];

        for err in errors {
            let response = err.into_response();
            let has_www_auth = response
                .headers()
                .iter()
                .any(|(name, value)| name == "www-authenticate" && value == b"Bearer");
            assert!(
                has_www_auth,
                "All OAuth2 errors should have WWW-Authenticate: Bearer"
            );
        }
    }
}
14483
14484#[cfg(test)]
14485mod bearer_token_tests {
14486    use super::*;
14487    use crate::request::Method;
14488    use crate::response::IntoResponse;
14489
14490    fn test_context() -> RequestContext {
14491        let cx = asupersync::Cx::for_testing();
14492        RequestContext::new(cx, 12345)
14493    }
14494
14495    #[test]
14496    fn bearer_token_extract_valid_token() {
14497        let ctx = test_context();
14498        let mut req = Request::new(Method::Get, "/api/protected");
14499        req.headers_mut()
14500            .insert("authorization", b"Bearer mytoken123".to_vec());
14501
14502        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14503        let token = result.unwrap();
14504        assert_eq!(token.token(), "mytoken123");
14505        assert_eq!(&*token, "mytoken123"); // Test Deref
14506        assert_eq!(token.as_ref(), "mytoken123"); // Test AsRef
14507    }
14508
14509    #[test]
14510    fn bearer_token_extract_lowercase_bearer() {
14511        let ctx = test_context();
14512        let mut req = Request::new(Method::Get, "/api/protected");
14513        req.headers_mut()
14514            .insert("authorization", b"bearer lowercase_token".to_vec());
14515
14516        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14517        let token = result.unwrap();
14518        assert_eq!(token.token(), "lowercase_token");
14519    }
14520
14521    #[test]
14522    fn bearer_token_missing_header() {
14523        let ctx = test_context();
14524        let mut req = Request::new(Method::Get, "/api/protected");
14525        // No authorization header
14526
14527        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14528        let err = result.unwrap_err();
14529        assert_eq!(err, BearerTokenError::MissingHeader);
14530    }
14531
14532    #[test]
14533    fn bearer_token_wrong_scheme() {
14534        let ctx = test_context();
14535        let mut req = Request::new(Method::Get, "/api/protected");
14536        req.headers_mut()
14537            .insert("authorization", b"Basic dXNlcjpwYXNz".to_vec());
14538
14539        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14540        let err = result.unwrap_err();
14541        assert_eq!(err, BearerTokenError::InvalidScheme);
14542    }
14543
14544    #[test]
14545    fn bearer_token_empty_token() {
14546        let ctx = test_context();
14547        let mut req = Request::new(Method::Get, "/api/protected");
14548        req.headers_mut()
14549            .insert("authorization", b"Bearer ".to_vec());
14550
14551        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14552        let err = result.unwrap_err();
14553        assert_eq!(err, BearerTokenError::EmptyToken);
14554    }
14555
14556    #[test]
14557    fn bearer_token_whitespace_only_token() {
14558        let ctx = test_context();
14559        let mut req = Request::new(Method::Get, "/api/protected");
14560        req.headers_mut()
14561            .insert("authorization", b"Bearer    ".to_vec());
14562
14563        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14564        let err = result.unwrap_err();
14565        assert_eq!(err, BearerTokenError::EmptyToken);
14566    }
14567
14568    #[test]
14569    fn bearer_token_with_spaces_trimmed() {
14570        let ctx = test_context();
14571        let mut req = Request::new(Method::Get, "/api/protected");
14572        req.headers_mut()
14573            .insert("authorization", b"Bearer   spaced_token   ".to_vec());
14574
14575        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14576        let token = result.unwrap();
14577        assert_eq!(token.token(), "spaced_token");
14578    }
14579
14580    #[test]
14581    fn bearer_token_optional_some() {
14582        let ctx = test_context();
14583        let mut req = Request::new(Method::Get, "/api/protected");
14584        req.headers_mut()
14585            .insert("authorization", b"Bearer optional_token".to_vec());
14586
14587        let result =
14588            futures_executor::block_on(Option::<BearerToken>::from_request(&ctx, &mut req));
14589        let maybe_token = result.unwrap();
14590        assert!(maybe_token.is_some());
14591        assert_eq!(maybe_token.unwrap().token(), "optional_token");
14592    }
14593
14594    #[test]
14595    fn bearer_token_optional_none() {
14596        let ctx = test_context();
14597        let mut req = Request::new(Method::Get, "/api/protected");
14598        // No authorization header
14599
14600        let result =
14601            futures_executor::block_on(Option::<BearerToken>::from_request(&ctx, &mut req));
14602        let maybe_token = result.unwrap();
14603        assert!(maybe_token.is_none());
14604    }
14605
14606    #[test]
14607    fn bearer_token_error_response_401() {
14608        let err = BearerTokenError::missing_header();
14609        let response = err.into_response();
14610        assert_eq!(response.status().as_u16(), 401);
14611    }
14612
14613    #[test]
14614    fn bearer_token_error_has_www_authenticate() {
14615        let err = BearerTokenError::missing_header();
14616        let response = err.into_response();
14617
14618        let has_www_auth = response
14619            .headers()
14620            .iter()
14621            .any(|(name, value)| name == "www-authenticate" && value == b"Bearer");
14622        assert!(has_www_auth);
14623    }
14624
14625    #[test]
14626    fn bearer_token_error_display() {
14627        assert_eq!(
14628            BearerTokenError::missing_header().to_string(),
14629            "Missing Authorization header"
14630        );
14631        assert_eq!(
14632            BearerTokenError::invalid_scheme().to_string(),
14633            "Authorization header must use Bearer scheme"
14634        );
14635        assert_eq!(
14636            BearerTokenError::empty_token().to_string(),
14637            "Bearer token is empty"
14638        );
14639    }
14640
14641    #[test]
14642    fn bearer_token_error_detail() {
14643        assert_eq!(
14644            BearerTokenError::MissingHeader.detail(),
14645            "Not authenticated"
14646        );
14647        assert_eq!(
14648            BearerTokenError::InvalidScheme.detail(),
14649            "Invalid authentication credentials"
14650        );
14651        assert_eq!(
14652            BearerTokenError::EmptyToken.detail(),
14653            "Invalid authentication credentials"
14654        );
14655    }
14656
14657    #[test]
14658    fn bearer_token_new_and_accessors() {
14659        let token = BearerToken::new("test_token");
14660        assert_eq!(token.token(), "test_token");
14661        assert_eq!(token.clone().into_token(), "test_token");
14662    }
14663
14664    #[test]
14665    fn bearer_token_error_response_json_body() {
14666        let err = BearerTokenError::missing_header();
14667        let response = err.into_response();
14668
14669        let body_str = match response.body_ref() {
14670            crate::response::ResponseBody::Bytes(b) => String::from_utf8_lossy(b).to_string(),
14671            _ => panic!("Expected Bytes body"),
14672        };
14673        let body: serde_json::Value = serde_json::from_str(&body_str).unwrap();
14674
14675        assert_eq!(body["detail"], "Not authenticated");
14676    }
14677
14678    #[test]
14679    fn bearer_token_error_content_type_json() {
14680        let err = BearerTokenError::missing_header();
14681        let response = err.into_response();
14682
14683        let has_json_content_type = response
14684            .headers()
14685            .iter()
14686            .any(|(name, value)| name == "content-type" && value == b"application/json");
14687        assert!(has_json_content_type);
14688    }
14689
14690    #[test]
14691    fn bearer_token_special_characters() {
14692        let ctx = test_context();
14693        let mut req = Request::new(Method::Get, "/api/protected");
14694        let special_token = "abc123!@#$%^&*()_+-=[]{}|;':\",./<>?";
14695        req.headers_mut().insert(
14696            "authorization",
14697            format!("Bearer {}", special_token).into_bytes(),
14698        );
14699
14700        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14701        let token = result.unwrap();
14702        assert_eq!(token.token(), special_token);
14703    }
14704
14705    #[test]
14706    fn bearer_token_very_long_token() {
14707        let ctx = test_context();
14708        let mut req = Request::new(Method::Get, "/api/protected");
14709        let long_token = "a".repeat(10000);
14710        req.headers_mut().insert(
14711            "authorization",
14712            format!("Bearer {}", long_token).into_bytes(),
14713        );
14714
14715        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14716        let token = result.unwrap();
14717        assert_eq!(token.token(), long_token);
14718    }
14719
14720    #[test]
14721    fn bearer_token_invalid_utf8() {
14722        let ctx = test_context();
14723        let mut req = Request::new(Method::Get, "/api/protected");
14724        // Invalid UTF-8 sequence
14725        req.headers_mut().insert(
14726            "authorization",
14727            vec![0x42, 0x65, 0x61, 0x72, 0x65, 0x72, 0x20, 0xFF, 0xFE],
14728        );
14729
14730        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14731        let err = result.unwrap_err();
14732        assert_eq!(err, BearerTokenError::InvalidScheme);
14733    }
14734
14735    #[test]
14736    fn bearer_token_only_bearer_no_space() {
14737        let ctx = test_context();
14738        let mut req = Request::new(Method::Get, "/api/protected");
14739        // "Bearer" without trailing space and token
14740        req.headers_mut()
14741            .insert("authorization", b"Bearer".to_vec());
14742
14743        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
14744        let err = result.unwrap_err();
14745        assert_eq!(err, BearerTokenError::InvalidScheme);
14746    }
14747
/// A scheme spelled in all capitals (`BEARER`) is rejected as
/// `BearerTokenError::InvalidScheme`.
#[test]
fn bearer_token_mixed_case_bearer() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");

    // Only the spellings "Bearer" and "bearer" are meant to be supported, so
    // any other casing of the scheme should fail.
    let shouted_scheme: Vec<u8> = b"BEARER token".to_vec();
    request.headers_mut().insert("authorization", shouted_scheme);

    let failure =
        futures_executor::block_on(BearerToken::from_request(&ctx, &mut request)).unwrap_err();
    assert_eq!(failure, BearerTokenError::InvalidScheme);
}
14760
/// Each `BearerTokenError` constructor produces a response with status 401.
#[test]
fn bearer_token_all_errors_are_401() {
    for failure in [
        BearerTokenError::missing_header(),
        BearerTokenError::invalid_scheme(),
        BearerTokenError::empty_token(),
    ] {
        let response = failure.into_response();
        let status_code = response.status().as_u16();
        assert_eq!(status_code, 401, "All BearerToken errors should be 401");
    }
}
14778
/// Each `BearerTokenError` constructor produces a response that carries a
/// `www-authenticate` header whose value is exactly `Bearer`.
#[test]
fn bearer_token_all_errors_have_www_authenticate() {
    for failure in [
        BearerTokenError::missing_header(),
        BearerTokenError::invalid_scheme(),
        BearerTokenError::empty_token(),
    ] {
        let response = failure.into_response();
        let challenge_present = response
            .headers()
            .iter()
            .any(|(header, bytes)| header == "www-authenticate" && bytes == b"Bearer");
        assert!(
            challenge_present,
            "All BearerToken errors should have WWW-Authenticate: Bearer"
        );
    }
}
14799
/// Tokens built from the same string compare equal; tokens built from
/// different strings do not.
#[test]
fn bearer_token_equality() {
    let first = BearerToken::new("same_token");
    let twin = BearerToken::new("same_token");
    let other = BearerToken::new("different_token");

    assert_eq!(first, twin);
    assert_ne!(first, other);
}
14809
/// `BearerTokenError` variants are equal to themselves and distinct from
/// each other.
#[test]
fn bearer_token_error_equality() {
    // Two independently written values of the same variant must match.
    let same_variant_pairs = [
        (
            BearerTokenError::MissingHeader,
            BearerTokenError::MissingHeader,
        ),
        (
            BearerTokenError::InvalidScheme,
            BearerTokenError::InvalidScheme,
        ),
        (BearerTokenError::EmptyToken, BearerTokenError::EmptyToken),
    ];
    for (left, right) in same_variant_pairs {
        assert_eq!(left, right);
    }

    // Different variants must not match.
    assert_ne!(
        BearerTokenError::MissingHeader,
        BearerTokenError::InvalidScheme
    );
}
14826
/// The `Debug` rendering of a token includes the token text.
#[test]
fn bearer_token_debug() {
    let token = BearerToken::new("debug_token");
    let rendered = format!("{token:?}");
    assert!(rendered.contains("debug_token"));
}
14833
/// A cloned token compares equal to the token it was cloned from.
#[test]
fn bearer_token_clone() {
    let original = BearerToken::new("cloneable");
    let duplicate = original.clone();
    assert_eq!(original, duplicate);
}
14840}
14841
#[cfg(test)]
mod api_key_header_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Result of running the `ApiKeyHeader` extractor.
    type Extraction = Result<ApiKeyHeader, <ApiKeyHeader as FromRequest>::Error>;

    /// Creates a request context from a testing `Cx`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 54321)
    }

    /// Builds a `GET /api/protected` request carrying exactly one header.
    fn request_with_header(name: &'static str, value: &[u8]) -> Request {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut().insert(name, value.to_vec());
        req
    }

    /// Drives the `ApiKeyHeader` extractor to completion on the calling thread.
    fn extract(ctx: &RequestContext, req: &mut Request) -> Extraction {
        futures_executor::block_on(ApiKeyHeader::from_request(ctx, req))
    }

    #[test]
    fn api_key_header_extraction_default() {
        let ctx = test_context();
        let mut request = request_with_header("x-api-key", b"test_api_key_123");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "test_api_key_123");
        assert_eq!(extracted.header_name(), "x-api-key");
    }

    #[test]
    fn api_key_header_missing() {
        let ctx = test_context();
        // The request deliberately carries no API key header.
        let mut request = Request::new(Method::Get, "/api/protected");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyHeaderError::MissingHeader { .. }
        ));
    }

    #[test]
    fn api_key_header_empty() {
        let ctx = test_context();
        let mut request = request_with_header("x-api-key", b"");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyHeaderError::EmptyKey { .. }
        ));
    }

    #[test]
    fn api_key_header_whitespace_only() {
        let ctx = test_context();
        let mut request = request_with_header("x-api-key", b"   ");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyHeaderError::EmptyKey { .. }
        ));
    }

    #[test]
    fn api_key_header_trims_whitespace() {
        let ctx = test_context();
        let mut request = request_with_header("x-api-key", b"  my_key_123  ");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "my_key_123");
    }

    #[test]
    fn api_key_header_custom_header_name() {
        let ctx = test_context();
        let mut request = request_with_header("authorization", b"custom_key");
        // Point the extractor at a non-default header via a request extension.
        request.insert_extension(ApiKeyHeaderConfig::new().header_name("authorization"));

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "custom_key");
        assert_eq!(extracted.header_name(), "authorization");
    }

    #[test]
    fn api_key_header_invalid_utf8() {
        let ctx = test_context();
        // 0xFF and 0xFE never occur in well-formed UTF-8.
        let mut request = request_with_header("x-api-key", &[0xFF, 0xFE, 0x00, 0x01]);

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyHeaderError::InvalidUtf8 { .. }
        ));
    }

    #[test]
    fn api_key_header_error_response_401() {
        let response = ApiKeyHeaderError::missing_header("x-api-key").into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn api_key_header_error_response_json() {
        let response = ApiKeyHeaderError::missing_header("x-api-key").into_response();

        // The content type must match exactly, with no parameters appended.
        let is_json = response
            .headers()
            .iter()
            .any(|(header, bytes)| header == "content-type" && bytes == b"application/json");
        assert!(is_json);
    }

    #[test]
    fn api_key_header_secure_compare() {
        let secret = ApiKeyHeader::new("secret_key_123");

        // String comparison through the timing-safe helper.
        assert!(secret.secure_eq("secret_key_123"));
        assert!(!secret.secure_eq("secret_key_124"));
        assert!(!secret.secure_eq("wrong"));

        // Same helper, byte-slice flavour.
        assert!(secret.secure_eq_bytes(b"secret_key_123"));
        assert!(!secret.secure_eq_bytes(b"secret_key_124"));
    }

    #[test]
    fn api_key_header_deref_and_as_ref() {
        let wrapper = ApiKeyHeader::new("deref_test");

        // Deref coercion to &str.
        let via_deref: &str = &wrapper;
        assert_eq!(via_deref, "deref_test");

        // Explicit AsRef<str>.
        let via_as_ref: &str = wrapper.as_ref();
        assert_eq!(via_as_ref, "deref_test");
    }

    #[test]
    fn api_key_header_config_defaults() {
        let defaults = ApiKeyHeaderConfig::default();
        assert_eq!(defaults.get_header_name(), DEFAULT_API_KEY_HEADER);
    }

    #[test]
    fn api_key_header_error_display() {
        let missing = ApiKeyHeaderError::missing_header("x-api-key").to_string();
        assert!(missing.contains("x-api-key"));

        let empty = ApiKeyHeaderError::empty_key("x-api-key").to_string();
        assert!(empty.contains("Empty"));

        let bad_utf8 = ApiKeyHeaderError::invalid_utf8("x-api-key").to_string();
        assert!(bad_utf8.contains("Invalid UTF-8"));
    }

    #[test]
    fn api_key_header_equality() {
        let first = ApiKeyHeader::new("same_key");
        let twin = ApiKeyHeader::new("same_key");
        let other = ApiKeyHeader::new("different_key");

        assert_eq!(first, twin);
        assert_ne!(first, other);
    }
}
15016
#[cfg(test)]
mod api_key_query_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Result of running the `ApiKeyQuery` extractor.
    type Extraction = Result<ApiKeyQuery, <ApiKeyQuery as FromRequest>::Error>;

    /// Creates a request context from a testing `Cx`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 99999)
    }

    /// Builds a `GET /api/webhook` request whose query string is `query`.
    fn webhook_request(query: &str) -> Request {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some(query.to_string()));
        req
    }

    /// Drives the `ApiKeyQuery` extractor to completion on the calling thread.
    fn extract(ctx: &RequestContext, req: &mut Request) -> Extraction {
        futures_executor::block_on(ApiKeyQuery::from_request(ctx, req))
    }

    #[test]
    fn api_key_query_basic_extraction() {
        let ctx = test_context();
        let mut request = webhook_request("api_key=test_key_123");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "test_key_123");
        assert_eq!(extracted.param_name(), "api_key");
    }

    #[test]
    fn api_key_query_missing() {
        let ctx = test_context();
        // No query string is ever set on this request.
        let mut request = Request::new(Method::Get, "/api/webhook");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyQueryError::MissingParam { .. }
        ));
    }

    #[test]
    fn api_key_query_empty_query_string() {
        let ctx = test_context();
        let mut request = webhook_request("");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyQueryError::MissingParam { .. }
        ));
    }

    #[test]
    fn api_key_query_param_missing_but_others_present() {
        let ctx = test_context();
        let mut request = webhook_request("other_param=value");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyQueryError::MissingParam { .. }
        ));
    }

    #[test]
    fn api_key_query_empty_value() {
        let ctx = test_context();
        let mut request = webhook_request("api_key=");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyQueryError::EmptyKey { .. }
        ));
    }

    #[test]
    fn api_key_query_whitespace_only() {
        let ctx = test_context();
        let mut request = webhook_request("api_key=   ");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyQueryError::EmptyKey { .. }
        ));
    }

    #[test]
    fn api_key_query_trims_whitespace() {
        let ctx = test_context();
        let mut request = webhook_request("api_key=  my_key_123  ");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "my_key_123");
    }

    #[test]
    fn api_key_query_custom_param_name() {
        let ctx = test_context();
        let mut request = webhook_request("token=custom_key");
        // Point the extractor at a non-default parameter via a request extension.
        request.insert_extension(ApiKeyQueryConfig::new().param_name("token"));

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "custom_key");
        assert_eq!(extracted.param_name(), "token");
    }

    #[test]
    fn api_key_query_with_other_params() {
        let ctx = test_context();
        let mut request =
            webhook_request("callback=https://example.com&api_key=webhook_key&format=json");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "webhook_key");
    }

    #[test]
    fn api_key_query_url_encoded_value() {
        let ctx = test_context();
        // "key+with spaces" percent-encoded as "key%2Bwith%20spaces".
        let mut request = webhook_request("api_key=key%2Bwith%20spaces");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "key+with spaces");
    }

    #[test]
    fn api_key_query_error_response_401() {
        let response = ApiKeyQueryError::missing_param("api_key").into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn api_key_query_error_response_json() {
        let response = ApiKeyQueryError::missing_param("api_key").into_response();

        // A prefix match is used, so trailing content-type parameters are tolerated.
        let is_json = response.headers().iter().any(|(header, bytes)| {
            header == "content-type" && bytes.starts_with(b"application/json")
        });
        assert!(is_json);
    }

    #[test]
    fn api_key_query_secure_compare() {
        let secret = ApiKeyQuery::new("secret_key_123");

        // String comparison through the timing-safe helper.
        assert!(secret.secure_eq("secret_key_123"));
        assert!(!secret.secure_eq("secret_key_124"));
        assert!(!secret.secure_eq("wrong"));

        // Same helper, byte-slice flavour.
        assert!(secret.secure_eq_bytes(b"secret_key_123"));
        assert!(!secret.secure_eq_bytes(b"secret_key_124"));
    }

    #[test]
    fn api_key_query_deref_and_as_ref() {
        let wrapper = ApiKeyQuery::new("deref_test");

        // Deref coercion to &str.
        let via_deref: &str = &wrapper;
        assert_eq!(via_deref, "deref_test");

        // Explicit AsRef<str>.
        let via_as_ref: &str = wrapper.as_ref();
        assert_eq!(via_as_ref, "deref_test");
    }

    #[test]
    fn api_key_query_config_defaults() {
        let defaults = ApiKeyQueryConfig::default();
        assert_eq!(defaults.get_param_name(), DEFAULT_API_KEY_QUERY_PARAM);
    }

    #[test]
    fn api_key_query_error_display() {
        let missing = ApiKeyQueryError::missing_param("api_key").to_string();
        assert!(missing.contains("api_key"));

        let empty = ApiKeyQueryError::empty_key("api_key").to_string();
        assert!(empty.contains("Empty"));
    }

    #[test]
    fn api_key_query_equality() {
        let first = ApiKeyQuery::new("same_key");
        let twin = ApiKeyQuery::new("same_key");
        let other = ApiKeyQuery::new("different_key");

        assert_eq!(first, twin);
        assert_ne!(first, other);
    }
}
15220
#[cfg(test)]
mod api_key_cookie_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Result of running the `ApiKeyCookie` extractor.
    type Extraction = Result<ApiKeyCookie, <ApiKeyCookie as FromRequest>::Error>;

    /// Creates a request context from a testing `Cx`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 77777)
    }

    /// Builds a `GET /api/protected` request whose `cookie` header holds
    /// `cookie_header` verbatim.
    fn request_with_cookies(cookie_header: &[u8]) -> Request {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut().insert("cookie", cookie_header.to_vec());
        req
    }

    /// Drives the `ApiKeyCookie` extractor to completion on the calling thread.
    fn extract(ctx: &RequestContext, req: &mut Request) -> Extraction {
        futures_executor::block_on(ApiKeyCookie::from_request(ctx, req))
    }

    #[test]
    fn api_key_cookie_basic_extraction() {
        let ctx = test_context();
        let mut request = request_with_cookies(b"api_key=test_key_123");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "test_key_123");
        assert_eq!(extracted.cookie_name(), "api_key");
    }

    #[test]
    fn api_key_cookie_missing_header() {
        let ctx = test_context();
        // The request carries no cookie header at all.
        let mut request = Request::new(Method::Get, "/api/protected");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyCookieError::MissingCookie { .. }
        ));
    }

    #[test]
    fn api_key_cookie_other_cookies_present() {
        let ctx = test_context();
        // Cookies are present, but none of them is named api_key.
        let mut request = request_with_cookies(b"session_id=abc123; theme=dark");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyCookieError::MissingCookie { .. }
        ));
    }

    #[test]
    fn api_key_cookie_empty_value() {
        let ctx = test_context();
        let mut request = request_with_cookies(b"api_key=");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyCookieError::EmptyKey { .. }
        ));
    }

    #[test]
    fn api_key_cookie_whitespace_only() {
        let ctx = test_context();
        let mut request = request_with_cookies(b"api_key=   ");

        let outcome = extract(&ctx, &mut request);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ApiKeyCookieError::EmptyKey { .. }
        ));
    }

    #[test]
    fn api_key_cookie_trims_whitespace() {
        let ctx = test_context();
        let mut request = request_with_cookies(b"api_key=  my_key_123  ");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "my_key_123");
    }

    #[test]
    fn api_key_cookie_custom_name() {
        let ctx = test_context();
        let mut request = request_with_cookies(b"auth_token=custom_key");
        // Point the extractor at a non-default cookie via a request extension.
        request.insert_extension(ApiKeyCookieConfig::new().cookie_name("auth_token"));

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "custom_key");
        assert_eq!(extracted.cookie_name(), "auth_token");
    }

    #[test]
    fn api_key_cookie_with_multiple_cookies() {
        let ctx = test_context();
        let mut request =
            request_with_cookies(b"session_id=sess123; api_key=my_api_key; theme=dark");

        let extracted = extract(&ctx, &mut request).unwrap();
        assert_eq!(extracted.key(), "my_api_key");
    }

    #[test]
    fn api_key_cookie_error_response_401() {
        let response = ApiKeyCookieError::missing_cookie("api_key").into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn api_key_cookie_error_response_json() {
        let response = ApiKeyCookieError::missing_cookie("api_key").into_response();

        // A prefix match is used, so trailing content-type parameters are tolerated.
        let is_json = response.headers().iter().any(|(header, bytes)| {
            header == "content-type" && bytes.starts_with(b"application/json")
        });
        assert!(is_json);
    }

    #[test]
    fn api_key_cookie_secure_compare() {
        let secret = ApiKeyCookie::new("secret_key_123");

        // String comparison through the timing-safe helper.
        assert!(secret.secure_eq("secret_key_123"));
        assert!(!secret.secure_eq("secret_key_124"));
        assert!(!secret.secure_eq("wrong"));

        // Same helper, byte-slice flavour.
        assert!(secret.secure_eq_bytes(b"secret_key_123"));
        assert!(!secret.secure_eq_bytes(b"secret_key_124"));
    }

    #[test]
    fn api_key_cookie_deref_and_as_ref() {
        let wrapper = ApiKeyCookie::new("deref_test");

        // Deref coercion to &str.
        let via_deref: &str = &wrapper;
        assert_eq!(via_deref, "deref_test");

        // Explicit AsRef<str>.
        let via_as_ref: &str = wrapper.as_ref();
        assert_eq!(via_as_ref, "deref_test");
    }

    #[test]
    fn api_key_cookie_config_defaults() {
        let defaults = ApiKeyCookieConfig::default();
        assert_eq!(defaults.get_cookie_name(), DEFAULT_API_KEY_COOKIE);
    }

    #[test]
    fn api_key_cookie_error_display() {
        let missing = ApiKeyCookieError::missing_cookie("api_key").to_string();
        assert!(missing.contains("api_key"));

        let empty = ApiKeyCookieError::empty_key("api_key").to_string();
        assert!(empty.contains("Empty"));
    }

    #[test]
    fn api_key_cookie_equality() {
        let first = ApiKeyCookie::new("same_key");
        let twin = ApiKeyCookie::new("same_key");
        let other = ApiKeyCookie::new("different_key");

        assert_eq!(first, twin);
        assert_ne!(first, other);
    }
}
15406
15407#[cfg(test)]
15408mod basic_auth_tests {
15409    use super::*;
15410    use crate::request::Method;
15411    use crate::response::IntoResponse;
15412
15413    fn test_context() -> RequestContext {
15414        let cx = asupersync::Cx::for_testing();
15415        RequestContext::new(cx, 12345)
15416    }
15417
/// Encodes `username:password` as padded base64 using the standard alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`).
///
/// The result is the credential part of a Basic `Authorization` header value,
/// i.e. what follows the `Basic ` prefix. The encoder is written by hand so
/// the tests need no base64 dependency.
///
/// # Parameters
///
/// - `username`: text placed before the first colon
/// - `password`: text placed after the first colon (may itself contain colons)
///
/// # Returns
///
/// The base64 text, always a multiple of four characters long thanks to `=`
/// padding.
fn encode_basic_auth(username: &str, password: &str) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let credentials = format!("{username}:{password}");
    let bytes = credentials.as_bytes();

    // Every group of up to three input bytes becomes exactly four output
    // characters, so the final length is known before encoding starts.
    let mut output = String::with_capacity((bytes.len() + 2) / 3 * 4);

    for chunk in bytes.chunks(3) {
        // Pack the bytes big-endian into a 24-bit group; bytes missing from a
        // short final chunk stay zero.
        let group = chunk
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));

        // n input bytes carry 8n bits, which need n + 1 six-bit symbols.
        let symbols = chunk.len() + 1;
        for i in 0..symbols {
            let index = ((group >> (18 - 6 * i)) & 0x3F) as usize;
            output.push(char::from(ALPHABET[index]));
        }

        // Fill the remainder of the four-character group with padding.
        for _ in symbols..4 {
            output.push('=');
        }
    }

    output
}
15452
/// A well-formed `Basic` header yields the encoded username and password.
#[test]
fn basic_auth_extract_valid_credentials() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    let header_value = format!("Basic {}", encode_basic_auth("alice", "secret123"));
    request
        .headers_mut()
        .insert("authorization", header_value.into_bytes());

    let credentials =
        futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request)).unwrap();
    assert_eq!(credentials.username(), "alice");
    assert_eq!(credentials.password(), "secret123");
}
15466
/// The scheme name is also accepted when written in lowercase (`basic`).
#[test]
fn basic_auth_extract_lowercase_basic() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    let header_value = format!("basic {}", encode_basic_auth("bob", "pass"));
    request
        .headers_mut()
        .insert("authorization", header_value.into_bytes());

    let credentials =
        futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request)).unwrap();
    assert_eq!(credentials.username(), "bob");
    assert_eq!(credentials.password(), "pass");
}
15480
/// A request without an `authorization` header fails with
/// `BasicAuthError::MissingHeader`.
#[test]
fn basic_auth_missing_header() {
    let ctx = test_context();
    // Deliberately left without any authorization header.
    let mut request = Request::new(Method::Get, "/api/protected");

    let outcome = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request));
    assert_eq!(outcome.unwrap_err(), BasicAuthError::MissingHeader);
}
15491
/// A `Bearer` authorization header fails with `BasicAuthError::InvalidScheme`.
#[test]
fn basic_auth_wrong_scheme() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    let bearer_header: Vec<u8> = b"Bearer sometoken".to_vec();
    request.headers_mut().insert("authorization", bearer_header);

    let outcome = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request));
    assert_eq!(outcome.unwrap_err(), BasicAuthError::InvalidScheme);
}
15503
/// Credentials that are not valid base64 fail with
/// `BasicAuthError::InvalidBase64`.
#[test]
fn basic_auth_invalid_base64() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    // '!' is not part of the base64 alphabet.
    let garbage_header: Vec<u8> = b"Basic !!!invalid!!!".to_vec();
    request.headers_mut().insert("authorization", garbage_header);

    let outcome = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request));
    assert_eq!(outcome.unwrap_err(), BasicAuthError::InvalidBase64);
}
15515
/// Decoded credentials without a `:` separator fail with
/// `BasicAuthError::MissingColon`.
#[test]
fn basic_auth_missing_colon() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    // "bm9jb2xvbg==" is the base64 form of "nocolon".
    let colonless_header: Vec<u8> = b"Basic bm9jb2xvbg==".to_vec();
    request
        .headers_mut()
        .insert("authorization", colonless_header);

    let outcome = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request));
    assert_eq!(outcome.unwrap_err(), BasicAuthError::MissingColon);
}
15528
/// An empty username (credentials of the form `:password`) is accepted.
#[test]
fn basic_auth_empty_username() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    let header_value = format!("Basic {}", encode_basic_auth("", "password"));
    request
        .headers_mut()
        .insert("authorization", header_value.into_bytes());

    let credentials =
        futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request)).unwrap();
    assert_eq!(credentials.username(), "");
    assert_eq!(credentials.password(), "password");
}
15542
/// An empty password (credentials of the form `user:`) is accepted.
#[test]
fn basic_auth_empty_password() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    let header_value = format!("Basic {}", encode_basic_auth("user", ""));
    request
        .headers_mut()
        .insert("authorization", header_value.into_bytes());

    let credentials =
        futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request)).unwrap();
    assert_eq!(credentials.username(), "user");
    assert_eq!(credentials.password(), "");
}
15556
/// Only the first colon separates username from password; later colons stay
/// part of the password.
#[test]
fn basic_auth_password_with_colons() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    let header_value = format!(
        "Basic {}",
        encode_basic_auth("user", "pass:word:with:colons")
    );
    request
        .headers_mut()
        .insert("authorization", header_value.into_bytes());

    let credentials =
        futures_executor::block_on(BasicAuth::from_request(&ctx, &mut request)).unwrap();
    assert_eq!(credentials.username(), "user");
    assert_eq!(credentials.password(), "pass:word:with:colons");
}
15571
/// Extracting `Option<BasicAuth>` yields `Some` when valid credentials are
/// present.
#[test]
fn basic_auth_optional_some() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/api/protected");
    let header_value = format!("Basic {}", encode_basic_auth("optional", "user"));
    request
        .headers_mut()
        .insert("authorization", header_value.into_bytes());

    let extraction = <Option<BasicAuth> as FromRequest>::from_request(&ctx, &mut request);
    let maybe_credentials = futures_executor::block_on(extraction).unwrap();
    assert!(maybe_credentials.is_some());

    let credentials = maybe_credentials.unwrap();
    assert_eq!(credentials.username(), "optional");
}
15585
// With no authorization header on the request, the `Option<BasicAuth>`
// extractor succeeds and yields `None`.
#[test]
fn basic_auth_optional_none() {
    let ctx = test_context();
    // Deliberately left without an authorization header.
    let mut request = Request::new(Method::Get, "/api/protected");

    let extracted =
        futures_executor::block_on(<Option<BasicAuth>>::from_request(&ctx, &mut request))
            .unwrap();
    assert!(extracted.is_none());
}
15596
// A missing-header error converts into a response whose status code is 401.
#[test]
fn basic_auth_error_response_401() {
    let response = BasicAuthError::missing_header().into_response();
    let status_code = response.status().as_u16();
    assert_eq!(status_code, 401);
}
15603
// The error response must carry a `www-authenticate` header whose value is
// exactly `Basic`.
#[test]
fn basic_auth_error_has_www_authenticate() {
    let response = BasicAuthError::missing_header().into_response();

    let mut challenge_found = false;
    for (name, value) in response.headers().iter() {
        if name == "www-authenticate" && value == b"Basic" {
            challenge_found = true;
            break;
        }
    }
    assert!(challenge_found);
}
15615
// Pins the `Display` text of every `BasicAuthError` constructor.
#[test]
fn basic_auth_error_display() {
    let expectations = [
        (
            BasicAuthError::missing_header(),
            "Missing Authorization header",
        ),
        (
            BasicAuthError::invalid_scheme(),
            "Authorization header must use Basic scheme",
        ),
        (
            BasicAuthError::invalid_base64(),
            "Invalid base64 encoding in credentials",
        ),
        (
            BasicAuthError::missing_colon(),
            "Credentials must contain username:password",
        ),
        (
            BasicAuthError::invalid_utf8(),
            "Credentials contain invalid UTF-8",
        ),
    ];

    for (error, expected_text) in expectations {
        assert_eq!(error.to_string(), expected_text);
    }
}
15639
// Pins the `detail()` text per variant: only `MissingHeader` reports
// "Not authenticated"; every other variant shares one generic message.
#[test]
fn basic_auth_error_detail() {
    assert_eq!(BasicAuthError::MissingHeader.detail(), "Not authenticated");

    let generic_variants = [
        BasicAuthError::InvalidScheme,
        BasicAuthError::InvalidBase64,
        BasicAuthError::MissingColon,
        BasicAuthError::InvalidUtf8,
    ];
    for variant in generic_variants {
        assert_eq!(variant.detail(), "Invalid authentication credentials");
    }
}
15660
// `BasicAuth::new` stores both values; the borrowing accessors and the
// consuming `into_credentials` return the same pair.
#[test]
fn basic_auth_new_and_accessors() {
    let credentials = BasicAuth::new("testuser", "testpass");
    assert_eq!(credentials.username(), "testuser");
    assert_eq!(credentials.password(), "testpass");

    // Consumes the value, so this has to come after the borrowing accessors.
    let (owned_user, owned_pass) = credentials.into_credentials();
    assert_eq!(owned_user, "testuser");
    assert_eq!(owned_pass, "testpass");
}
15670
// The error response body is a bytes payload holding JSON whose `detail`
// field is "Not authenticated" for a missing header.
#[test]
fn basic_auth_error_response_json_body() {
    let response = BasicAuthError::missing_header().into_response();

    let raw_json = if let crate::response::ResponseBody::Bytes(bytes) = response.body_ref() {
        String::from_utf8_lossy(bytes).into_owned()
    } else {
        panic!("Expected Bytes body")
    };
    let parsed: serde_json::Value = serde_json::from_str(&raw_json).unwrap();

    assert_eq!(parsed["detail"], "Not authenticated");
}
15684
// The error response declares `content-type: application/json`.
#[test]
fn basic_auth_error_content_type_json() {
    let response = BasicAuthError::missing_header().into_response();

    let mut json_header_found = false;
    for (name, value) in response.headers().iter() {
        if name == "content-type" && value == b"application/json" {
            json_header_found = true;
            break;
        }
    }
    assert!(json_header_found);
}
15696
// Every `BasicAuthError` constructor must map to HTTP 401.
#[test]
fn basic_auth_all_errors_return_401() {
    let every_error = [
        BasicAuthError::missing_header(),
        BasicAuthError::invalid_scheme(),
        BasicAuthError::invalid_base64(),
        BasicAuthError::missing_colon(),
        BasicAuthError::invalid_utf8(),
    ];

    for error in every_error {
        let status_code = error.into_response().status().as_u16();
        assert_eq!(status_code, 401, "All BasicAuth errors should be 401");
    }
}
15716
// Every `BasicAuthError` response must include `www-authenticate: Basic`.
#[test]
fn basic_auth_all_errors_have_www_authenticate() {
    let every_error = [
        BasicAuthError::missing_header(),
        BasicAuthError::invalid_scheme(),
        BasicAuthError::invalid_base64(),
        BasicAuthError::missing_colon(),
        BasicAuthError::invalid_utf8(),
    ];

    for error in every_error {
        let response = error.into_response();

        let mut challenge_found = false;
        for (name, value) in response.headers().iter() {
            if name == "www-authenticate" && value == b"Basic" {
                challenge_found = true;
                break;
            }
        }

        assert!(
            challenge_found,
            "All BasicAuth errors should have WWW-Authenticate: Basic"
        );
    }
}
15739
// Equality compares the stored credentials, and a clone equals its source.
#[test]
fn basic_auth_eq_and_clone() {
    let original = BasicAuth::new("user", "pass");
    let same_credentials = BasicAuth::new("user", "pass");
    let different_user = BasicAuth::new("other", "pass");

    assert_eq!(original, same_credentials);
    assert_ne!(original, different_user);

    let duplicate = original.clone();
    assert_eq!(original, duplicate);
}
15752
// Each variant equals itself, and two distinct variants are unequal.
#[test]
fn basic_auth_error_eq() {
    let same_variant_pairs = [
        (BasicAuthError::MissingHeader, BasicAuthError::MissingHeader),
        (BasicAuthError::InvalidScheme, BasicAuthError::InvalidScheme),
        (BasicAuthError::InvalidBase64, BasicAuthError::InvalidBase64),
        (BasicAuthError::MissingColon, BasicAuthError::MissingColon),
        (BasicAuthError::InvalidUtf8, BasicAuthError::InvalidUtf8),
    ];
    for (left, right) in same_variant_pairs {
        assert_eq!(left, right);
    }

    assert_ne!(BasicAuthError::MissingHeader, BasicAuthError::InvalidScheme);
}
15762
// The `Debug` rendering contains both the username and the password.
// NOTE(review): this pins the password appearing in debug output; confirm
// that exposing it in logs is intended.
#[test]
fn basic_auth_debug() {
    let credentials = BasicAuth::new("debug_user", "debug_pass");
    let rendered = format!("{:?}", credentials);

    for fragment in ["debug_user", "debug_pass"] {
        assert!(rendered.contains(fragment));
    }
}
15770
15771    // Base64 decoder tests
// "dXNlcjpwYXNz" is the base64 form of "user:pass" and must decode to it.
#[test]
fn decode_base64_valid() {
    let decoded_bytes = decode_base64("dXNlcjpwYXNz").unwrap();
    let decoded_text = String::from_utf8(decoded_bytes).unwrap();
    assert_eq!(decoded_text, "user:pass");
}
15778
// Padded input decodes correctly: "YQ==" is "a" and "YWI=" is "ab".
#[test]
fn decode_base64_with_padding() {
    for (encoded, expected) in [("YQ==", "a"), ("YWI=", "ab")] {
        let decoded_bytes = decode_base64(encoded).unwrap();
        assert_eq!(String::from_utf8(decoded_bytes).unwrap(), expected);
    }
}
15789
// The trailing `=` padding may be left off: "YQ" and "YWI" decode to the
// same text as their padded forms.
#[test]
fn decode_base64_without_padding() {
    for (encoded, expected) in [("YQ", "a"), ("YWI", "ab")] {
        let decoded_bytes = decode_base64(encoded).unwrap();
        assert_eq!(String::from_utf8(decoded_bytes).unwrap(), expected);
    }
}
15799
// Empty input is valid and decodes to an empty byte buffer.
#[test]
fn decode_base64_empty() {
    let decoded_bytes = decode_base64("").unwrap();
    assert!(decoded_bytes.is_empty());
}
15805
// A `!` inside the input makes decoding fail with an error.
#[test]
fn decode_base64_invalid_char() {
    let outcome = decode_base64("abc!def");
    assert!(outcome.is_err());
}
15811
// Round-trips credentials whose password contains punctuation
// ("admin:p@$$w0rd!123") through `encode_basic_auth` and `decode_base64`.
#[test]
fn decode_base64_complex_password() {
    // The helper's output is handed to the decoder as-is; nothing is
    // stripped from it first.
    let encoded = encode_basic_auth("admin", "p@$$w0rd!123");
    let decoded_bytes = decode_base64(&encoded).unwrap();
    let decoded_text = String::from_utf8(decoded_bytes).unwrap();
    assert_eq!(decoded_text, "admin:p@$$w0rd!123");
}
15821}
15822
#[cfg(test)]
mod secure_compare_tests {
    use super::*;

    // ========================================================================
    // Basic constant_time_eq tests
    // ========================================================================

    // Byte-identical inputs, from empty up to a long token, compare equal.
    #[test]
    fn constant_time_eq_equal_slices() {
        let matching: [(&[u8], &[u8]); 4] = [
            (b"secret", b"secret"),
            (b"", b""),
            (b"a", b"a"),
            (
                b"a_very_long_secret_token_12345",
                b"a_very_long_secret_token_12345",
            ),
        ];
        for (left, right) in matching {
            assert!(constant_time_eq(left, right));
        }
    }

    // Same-length inputs that differ in a single byte compare unequal.
    #[test]
    fn constant_time_eq_different_slices() {
        let mismatched: [(&[u8], &[u8]); 3] =
            [(b"secret", b"secreT"), (b"aaaaaa", b"aaaaab"), (b"a", b"b")];
        for (left, right) in mismatched {
            assert!(!constant_time_eq(left, right));
        }
    }

    // Inputs of different lengths never compare equal, even when one is a
    // prefix of the other.
    #[test]
    fn constant_time_eq_different_lengths() {
        let mismatched: [(&[u8], &[u8]); 3] =
            [(b"short", b"longer"), (b"", b"a"), (b"abc", b"ab")];
        for (left, right) in mismatched {
            assert!(!constant_time_eq(left, right));
        }
    }

    // Non-text bytes (including 0 and 255) are compared like any others.
    #[test]
    fn constant_time_eq_binary_data() {
        let reference = [0u8, 1, 2, 3, 255, 254, 253];
        let identical = [0u8, 1, 2, 3, 255, 254, 253];
        let last_byte_changed = [0u8, 1, 2, 3, 255, 254, 252];

        assert!(constant_time_eq(&reference, &identical));
        assert!(!constant_time_eq(&reference, &last_byte_changed));
    }

    // All-zero buffers are equal; setting only the final byte breaks equality.
    #[test]
    fn constant_time_eq_all_zeros() {
        let zeros = [0u8; 32];
        let more_zeros = [0u8; 32];
        let mut final_byte_set = [0u8; 32];
        final_byte_set[31] = 1;

        assert!(constant_time_eq(&zeros, &more_zeros));
        assert!(!constant_time_eq(&zeros, &final_byte_set));
    }

    // All-0xFF buffers are equal; altering only the first byte breaks equality.
    #[test]
    fn constant_time_eq_all_ones() {
        let ones = [0xFFu8; 16];
        let more_ones = [0xFFu8; 16];
        let mut first_byte_changed = [0xFFu8; 16];
        first_byte_changed[0] = 0xFE;

        assert!(constant_time_eq(&ones, &more_ones));
        assert!(!constant_time_eq(&ones, &first_byte_changed));
    }

    // ========================================================================
    // constant_time_str_eq tests
    // ========================================================================

    // Identical strings compare equal, including empty and non-ASCII ones.
    #[test]
    fn constant_time_str_eq_equal() {
        for (left, right) in [("password123", "password123"), ("", ""), ("🔐", "🔐")] {
            assert!(constant_time_str_eq(left, right));
        }
    }

    // Differing strings compare unequal; the comparison is case-sensitive.
    #[test]
    fn constant_time_str_eq_different() {
        for (left, right) in [("password123", "password124"), ("case", "CASE"), ("🔐", "🔑")] {
            assert!(!constant_time_str_eq(left, right));
        }
    }

    // Multi-byte UTF-8 text: equal only when every character matches.
    #[test]
    fn constant_time_str_eq_unicode() {
        assert!(constant_time_str_eq("日本語", "日本語"));

        for (left, right) in [("日本語", "日本话"), ("café", "cafe")] {
            assert!(!constant_time_str_eq(left, right));
        }
    }

    // ========================================================================
    // SecureCompare trait tests for BearerToken
    // ========================================================================

    // `secure_eq` on a token matches only the exact, case-sensitive value.
    #[test]
    fn bearer_token_secure_eq() {
        let bearer = BearerToken::new("my_secret_token");

        assert!(bearer.secure_eq("my_secret_token"));
        // Differs only in the case of one letter.
        assert!(!bearer.secure_eq("my_secret_Token"));
        assert!(!bearer.secure_eq("wrong_token"));
    }

    // `secure_eq_bytes` on a token compares against raw bytes.
    #[test]
    fn bearer_token_secure_eq_bytes() {
        let bearer = BearerToken::new("api_key_123");

        assert!(bearer.secure_eq_bytes(b"api_key_123"));
        assert!(!bearer.secure_eq_bytes(b"api_key_124"));
    }

    // ========================================================================
    // SecureCompare trait tests for String/str
    // ========================================================================

    // `secure_eq` is callable directly on a `&str`.
    #[test]
    fn str_secure_eq() {
        let stored: &str = "hunter2";

        assert!(stored.secure_eq("hunter2"));
        assert!(!stored.secure_eq("hunter3"));
    }

    // `secure_eq` on an owned `String`; a one-character-shorter candidate
    // does not match.
    #[test]
    fn string_secure_eq() {
        let stored = String::from("password");

        assert!(stored.secure_eq("password"));
        assert!(!stored.secure_eq("passwor"));
    }

    // `secure_eq_bytes` on an owned `String` compares against raw bytes.
    #[test]
    fn string_secure_eq_bytes() {
        let stored = String::from("binary_safe");

        assert!(stored.secure_eq_bytes(b"binary_safe"));
        assert!(!stored.secure_eq_bytes(b"binary_Safe"));
    }

    // ========================================================================
    // SecureCompare trait tests for byte slices
    // ========================================================================

    // `secure_eq_bytes` on a borrowed byte slice.
    #[test]
    fn byte_slice_secure_eq() {
        let digest: &[u8] = &[0xDE, 0xAD, 0xBE, 0xEF];

        assert!(digest.secure_eq_bytes(&[0xDE, 0xAD, 0xBE, 0xEF]));
        assert!(!digest.secure_eq_bytes(&[0xDE, 0xAD, 0xBE, 0xEE]));
    }

    // `secure_eq_bytes` on a fixed-size byte array.
    #[test]
    fn byte_array_secure_eq() {
        let key_material: [u8; 4] = [1, 2, 3, 4];

        assert!(key_material.secure_eq_bytes(&[1, 2, 3, 4]));
        assert!(!key_material.secure_eq_bytes(&[1, 2, 3, 5]));
    }

    // `secure_eq` on a `Vec<u8>` compares its bytes with a string
    // (0x41 0x42 0x43 is "ABC").
    #[test]
    fn vec_secure_eq() {
        let raw_token: Vec<u8> = vec![0x41, 0x42, 0x43];

        assert!(raw_token.secure_eq("ABC"));
        assert!(!raw_token.secure_eq("ABD"));
    }

    // ========================================================================
    // Edge cases and security properties
    // ========================================================================

    // Empty equals empty; empty never equals a non-empty value.
    #[test]
    fn secure_compare_empty_values() {
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"x"));

        assert!(constant_time_str_eq("", ""));
        assert!(!constant_time_str_eq("", "x"));
    }

    // A difference in just the lowest bit of one byte is detected.
    #[test]
    fn secure_compare_single_bit_difference() {
        let all_bits_set = [0b1111_1111u8];
        let low_bit_clear = [0b1111_1110u8];

        assert!(!constant_time_eq(&all_bits_set, &low_bit_clear));
    }

    // Mismatch located in the first byte.
    #[test]
    fn secure_compare_first_byte_differs() {
        let differs = !constant_time_eq(b"Xsecret", b"Ysecret");
        assert!(differs);
    }

    // Mismatch located in the last byte.
    #[test]
    fn secure_compare_last_byte_differs() {
        let differs = !constant_time_eq(b"secretX", b"secretY");
        assert!(differs);
    }

    // Mismatch located in the middle of the input.
    #[test]
    fn secure_compare_middle_byte_differs() {
        let differs = !constant_time_eq(b"secXet", b"secYet");
        assert!(differs);
    }

    // Shows the trait being used the way a handler would validate a
    // `BearerToken` against a stored secret.
    #[test]
    fn bearer_token_integration_with_secure_compare() {
        let presented = BearerToken::new("real_api_token_xyz789");

        // The stored secret matches what the client presented.
        let stored_secret = "real_api_token_xyz789";
        assert!(presented.secure_eq(stored_secret));

        // A different token must be rejected.
        let forged_secret = "fake_api_token_abc123";
        let rejected = !presented.secure_eq(forged_secret);
        assert!(rejected);
    }

    // `BearerToken` derefs to `&str`, so `secure_eq` also works on the
    // dereferenced string.
    #[test]
    fn deref_with_secure_compare() {
        let bearer = BearerToken::new("my_token");
        // Deref coercion from `&BearerToken` to `&str`.
        let as_str: &str = &bearer;

        assert!(as_str.secure_eq("my_token"));
    }

    // ========================================================================
    // Timing verification (best-effort, not a cryptographic proof)
    // ========================================================================
    // Note: Actual timing tests are notoriously unreliable due to CPU caching,
    // branch prediction, and OS scheduling. These tests verify the algorithm
    // correctness rather than timing properties. For true timing verification,
    // use specialized tools like dudect or benchmarking with statistical analysis.

    // Checks the result for equal-length inputs that differ at the first and
    // at the last position. Only correctness is asserted here; timing is not
    // measured by this test.
    #[test]
    fn algorithm_processes_all_bytes() {
        // Equal length, mismatch at index 0.
        let early_left = b"Xsecret_token";
        let early_right = b"Ysecret_token";
        assert!(!constant_time_eq(early_left, early_right));

        // Equal length, mismatch at the final index.
        let late_left = b"secret_tokenX";
        let late_right = b"secret_tokenY";
        assert!(!constant_time_eq(late_left, late_right));
    }
}
16090
16091#[cfg(test)]
16092mod pagination_tests {
16093    use super::*;
16094    use crate::request::Method;
16095    use crate::response::IntoResponse;
16096
16097    // Helper to create a test context
16098    fn test_context() -> RequestContext {
16099        let cx = asupersync::Cx::for_testing();
16100        RequestContext::new(cx, 12345)
16101    }
16102
16103    // ========================================================================
16104    // Pagination struct tests
16105    // ========================================================================
16106
// `Pagination::default()` uses the default page and page size, reports the
// page size as its limit, and starts at offset 0.
#[test]
fn pagination_default_values() {
    let defaults = Pagination::default();

    assert_eq!(defaults.page(), DEFAULT_PAGE);
    assert_eq!(defaults.offset(), 0);
    assert_eq!(defaults.per_page(), DEFAULT_PER_PAGE);
    assert_eq!(defaults.limit(), DEFAULT_PER_PAGE);
}
16115
// Page 3 at 50 per page keeps both values and yields offset 100,
// i.e. (3 - 1) * 50.
#[test]
fn pagination_new() {
    let third_page = Pagination::new(3, 50);

    assert_eq!(third_page.page(), 3);
    assert_eq!(third_page.per_page(), 50);
    assert_eq!(third_page.offset(), 100);
}
16123
// `Pagination::new` clamps the page size: 0 is raised to 1 and 1000 is
// lowered to `MAX_PER_PAGE`.
#[test]
fn pagination_new_clamps_per_page() {
    let too_small = Pagination::new(1, 0);
    assert_eq!(too_small.per_page(), 1);

    let too_large = Pagination::new(1, 1000);
    assert_eq!(too_large.per_page(), MAX_PER_PAGE);
}
16134
// A requested page of 0 is raised to page 1.
#[test]
fn pagination_new_clamps_page() {
    let from_page_zero = Pagination::new(0, 20);
    assert_eq!(from_page_zero.page(), 1);
}
16141
// Building from offset 40 with 20 per page keeps the offset and page size
// and reports page 3 (40 / 20 + 1).
#[test]
fn pagination_from_offset() {
    let by_offset = Pagination::from_offset(40, 20);

    assert_eq!(by_offset.offset(), 40);
    assert_eq!(by_offset.per_page(), 20);
    assert_eq!(by_offset.page(), 3);
}
16149
// With 10 items per page, the page count rounds up and is 0 for no items.
#[test]
fn pagination_total_pages() {
    let ten_per_page = Pagination::new(1, 10);

    // (total item count, expected page count)
    for (total_items, expected_pages) in [(0, 0), (10, 1), (11, 2), (100, 10)] {
        assert_eq!(ten_per_page.total_pages(total_items), expected_pages);
    }
}
16158
// With 100 items at 10 per page: the first page has no previous page, the
// last page has no next page, and a middle page has both.
#[test]
fn pagination_has_next_prev() {
    // (page number, expect has_prev, expect has_next(100))
    let scenarios = [(1, false, true), (5, true, true), (10, true, false)];

    for (page_number, expect_prev, expect_next) in scenarios {
        let pagination = Pagination::new(page_number, 10);
        assert_eq!(pagination.has_prev(), expect_prev);
        assert_eq!(pagination.has_next(100), expect_next);
    }
}
16173
16174    // ========================================================================
16175    // Pagination extractor tests
16176    // ========================================================================
16177
// With no query parameters on the request, the extractor returns the
// default page and page size.
#[test]
fn pagination_extractor_default_params() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/items");

    let extraction = futures_executor::block_on(Pagination::from_request(&ctx, &mut request));
    let pagination = extraction.unwrap();
    assert_eq!(pagination.page(), DEFAULT_PAGE);
    assert_eq!(pagination.per_page(), DEFAULT_PER_PAGE);
}
16187
// `page=5` sets the page; the page size stays at its default.
#[test]
fn pagination_extractor_page_param() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/items?page=5");
    let parsed_query = QueryParams::parse("page=5");
    request.insert_extension(parsed_query);

    let extraction = futures_executor::block_on(Pagination::from_request(&ctx, &mut request));
    let pagination = extraction.unwrap();
    assert_eq!(pagination.page(), 5);
    assert_eq!(pagination.per_page(), DEFAULT_PER_PAGE);
}
16198
// `per_page=50` sets the page size; the page stays at its default.
#[test]
fn pagination_extractor_per_page_param() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/items?per_page=50");
    let parsed_query = QueryParams::parse("per_page=50");
    request.insert_extension(parsed_query);

    let extraction = futures_executor::block_on(Pagination::from_request(&ctx, &mut request));
    let pagination = extraction.unwrap();
    assert_eq!(pagination.page(), DEFAULT_PAGE);
    assert_eq!(pagination.per_page(), 50);
}
16209
// `limit=25` is accepted as another way to set the page size.
#[test]
fn pagination_extractor_limit_alias() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/items?limit=25");
    let parsed_query = QueryParams::parse("limit=25");
    request.insert_extension(parsed_query);

    let extraction = futures_executor::block_on(Pagination::from_request(&ctx, &mut request));
    let pagination = extraction.unwrap();
    assert_eq!(pagination.per_page(), 25);
}
16219
// `offset=40&limit=10` yields offset 40 with a page size of 10.
#[test]
fn pagination_extractor_offset_param() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/items?offset=40&limit=10");
    let parsed_query = QueryParams::parse("offset=40&limit=10");
    request.insert_extension(parsed_query);

    let extraction = futures_executor::block_on(Pagination::from_request(&ctx, &mut request));
    let pagination = extraction.unwrap();
    assert_eq!(pagination.offset(), 40);
    assert_eq!(pagination.per_page(), 10);
}
16230
// A requested page size of 1000 is reduced to `MAX_PER_PAGE`.
#[test]
fn pagination_extractor_clamps_max_per_page() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/items?per_page=1000");
    let parsed_query = QueryParams::parse("per_page=1000");
    request.insert_extension(parsed_query);

    let extraction = futures_executor::block_on(Pagination::from_request(&ctx, &mut request));
    let pagination = extraction.unwrap();
    assert_eq!(pagination.per_page(), MAX_PER_PAGE);
}
16240
// A non-numeric `page=abc` does not fail extraction; the default page is
// used instead.
#[test]
fn pagination_extractor_invalid_page_uses_default() {
    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/items?page=abc");
    let parsed_query = QueryParams::parse("page=abc");
    request.insert_extension(parsed_query);

    let extraction = futures_executor::block_on(Pagination::from_request(&ctx, &mut request));
    let pagination = extraction.unwrap();
    assert_eq!(pagination.page(), DEFAULT_PAGE);
}
16250
16251    // ========================================================================
16252    // Page struct tests
16253    // ========================================================================
16254
// `Page::new` copies the page number and size from the `Pagination`, keeps
// items and total as given, and reports 10 pages for 100 items at 10 each.
#[test]
fn page_new() {
    let expected_items = vec!["a", "b", "c"];
    let second_page = Pagination::new(2, 10);

    let page = Page::new(
        expected_items.clone(),
        100,
        second_page,
        "/items".to_string(),
    );

    assert_eq!(page.items, expected_items);
    assert_eq!(page.total, 100);
    assert_eq!(page.page, 2);
    assert_eq!(page.per_page, 10);
    assert_eq!(page.pages, 10);
}
16267
// `Page::with_values` stores the explicit values and reports 5 pages for
// 50 items at 10 each.
#[test]
fn page_with_values() {
    let expected_items = vec![1, 2, 3];

    let page = Page::with_values(expected_items.clone(), 50, 3, 10, "/users");

    assert_eq!(page.items, expected_items);
    assert_eq!(page.total, 50);
    assert_eq!(page.page, 3);
    assert_eq!(page.per_page, 10);
    assert_eq!(page.pages, 5);
}
16279
// `len` and `is_empty` reflect the items held by this page, not `total`.
#[test]
fn page_len_is_empty() {
    let without_items: Page<i32> = Page::with_values(vec![], 0, 1, 10, "/items");
    assert!(without_items.is_empty());
    assert_eq!(without_items.len(), 0);

    // total is 100, but only three items are on this page.
    let with_items = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
    assert!(!with_items.is_empty());
    assert_eq!(with_items.len(), 3);
}
16290
// With 100 items at 10 per page: the first page has no previous page, the
// last page has no next page, and a middle page has both.
#[test]
fn page_has_next_prev() {
    // (page number, expect has_prev, expect has_next)
    let scenarios = [(1, false, true), (5, true, true), (10, true, false)];

    for (page_number, expect_prev, expect_next) in scenarios {
        let page = Page::with_values(vec![1, 2, 3], 100, page_number, 10, "/items");
        assert_eq!(page.has_prev(), expect_prev);
        assert_eq!(page.has_next(), expect_next);
    }
}
16308
// `map` transforms each item and leaves `total` and `page` unchanged.
#[test]
fn page_map() {
    let source = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");

    let doubled = source.map(|value| value * 2);

    assert_eq!(doubled.items, vec![2, 4, 6]);
    assert_eq!(doubled.total, 100);
    assert_eq!(doubled.page, 1);
}
16318
16319    // ========================================================================
16320    // Link header tests
16321    // ========================================================================
16322
// On the first of ten pages the Link header has first/last/next relations
// but no prev relation.
#[test]
fn page_link_header_first_page() {
    let page = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
    let link = page.link_header();

    // NOTE(review): "page=1" is also a substring of "page=10", so that
    // check is weaker than it looks.
    let required = [
        "rel=\"first\"",
        "rel=\"last\"",
        "rel=\"next\"",
        "page=1",
        "page=2",  // next page
        "page=10", // last page
    ];
    for fragment in required {
        assert!(link.contains(fragment));
    }

    // The first page has nothing before it.
    assert!(!link.contains("rel=\"prev\""));
}
16336
// On page 5 of 10 the Link header has all four relations and points at
// pages 4 and 6.
#[test]
fn page_link_header_middle_page() {
    let page = Page::with_values(vec![1, 2, 3], 100, 5, 10, "/items");
    let link = page.link_header();

    let required = [
        "rel=\"first\"",
        "rel=\"last\"",
        "rel=\"next\"",
        "rel=\"prev\"",
        "page=4", // previous page
        "page=6", // next page
    ];
    for fragment in required {
        assert!(link.contains(fragment));
    }
}
16349
// On the last of ten pages the Link header has first/last/prev relations
// but no next relation.
#[test]
fn page_link_header_last_page() {
    let page = Page::with_values(vec![1, 2, 3], 100, 10, 10, "/items");
    let link = page.link_header();

    let required = [
        "rel=\"first\"",
        "rel=\"last\"",
        "rel=\"prev\"",
        "page=9", // previous page
    ];
    for fragment in required {
        assert!(link.contains(fragment));
    }

    // The last page has nothing after it.
    assert!(!link.contains("rel=\"next\""));
}
16361
// When all items fit on one page the Link header has only first and last
// relations.
#[test]
fn page_link_header_single_page() {
    let page = Page::with_values(vec![1, 2, 3], 3, 1, 10, "/items");
    let link = page.link_header();

    for present in ["rel=\"first\"", "rel=\"last\""] {
        assert!(link.contains(present));
    }
    for absent in ["rel=\"next\"", "rel=\"prev\""] {
        assert!(!link.contains(absent));
    }
}
16372
16373    // ========================================================================
16374    // IntoResponse tests
16375    // ========================================================================
16376
// Converting a page into a response yields status 200.
#[test]
fn page_into_response_status_ok() {
    let response = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items").into_response();

    let status_code = response.status().as_u16();
    assert_eq!(status_code, 200);
}
16384
// The page response has a `content-type` header equal to
// `application/json`.
#[test]
fn page_into_response_content_type() {
    let response = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items").into_response();

    let first_match = response
        .headers()
        .iter()
        .filter(|(name, _)| name == "content-type")
        .next();
    assert!(first_match.is_some());

    let header = first_match.unwrap();
    assert_eq!(header.1, b"application/json");
}
16397
// The page response has a `link` header that includes the first relation.
#[test]
fn page_into_response_has_link_header() {
    let response = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items").into_response();

    let first_match = response
        .headers()
        .iter()
        .filter(|(name, _)| name == "link")
        .next();
    assert!(first_match.is_some());

    let header = first_match.unwrap();
    let link_text = String::from_utf8_lossy(&header.1);
    assert!(link_text.contains("rel=\"first\""));
}
16409
// The page response exposes total, page, page size and page count through
// the `x-*` pagination headers.
#[test]
fn page_into_response_has_pagination_headers() {
    let response = Page::with_values(vec![1, 2, 3], 100, 2, 10, "/items").into_response();

    // Returns the first header with the given name, decoded lossily as text.
    let header_text = |wanted: &str| {
        response
            .headers()
            .iter()
            .find(|(name, _)| name == wanted)
            .map(|(_, value)| String::from_utf8_lossy(value).into_owned())
    };

    let expectations = [
        ("x-total-count", "100"),
        ("x-page", "2"),
        ("x-per-page", "10"),
        ("x-total-pages", "10"),
    ];
    for (header_name, expected_value) in expectations {
        assert_eq!(header_text(header_name), Some(expected_value.to_string()));
    }
}
16428
/// The response body is a JSON document containing the items and the paging counters.
#[test]
fn page_into_response_json_body() {
    let response = Page::with_values(vec!["a", "b", "c"], 100, 2, 10, "/items").into_response();

    // Only an in-memory bytes body is acceptable here.
    let crate::response::ResponseBody::Bytes(raw) = response.body_ref() else {
        panic!("Expected bytes body");
    };
    let text = String::from_utf8_lossy(raw).to_string();

    // Parse and verify JSON structure
    let json: serde_json::Value = serde_json::from_str(&text).unwrap();
    assert_eq!(json["items"], serde_json::json!(["a", "b", "c"]));

    let counters = [("total", 100), ("page", 2), ("per_page", 10), ("pages", 10)];
    for (field, expected) in counters {
        assert_eq!(json[field], expected);
    }
}
16447
16448    // ========================================================================
16449    // PaginationConfig tests
16450    // ========================================================================
16451
/// `PaginationConfig::default()` is populated from the module-level default constants.
#[test]
fn pagination_config_default() {
    let defaults = PaginationConfig::default();

    assert_eq!(
        (
            defaults.default_per_page,
            defaults.max_per_page,
            defaults.default_page,
        ),
        (DEFAULT_PER_PAGE, MAX_PER_PAGE, DEFAULT_PAGE),
    );
}
16459
/// Each builder method stores exactly the value it was given.
#[test]
fn pagination_config_builder() {
    let built = PaginationConfig::new()
        .default_per_page(50)
        .max_per_page(200)
        .default_page(1);

    assert_eq!(
        (built.default_per_page, built.max_per_page, built.default_page),
        (50, 200, 1),
    );
}
16471
16472    // ========================================================================
16473    // Integration tests
16474    // ========================================================================
16475
/// `Pagination::paginate` wraps the items in a page that keeps the request's page and size.
#[test]
fn pagination_paginate_helper() {
    let expected_items = vec!["item1", "item2", "item3"];

    let page = Pagination::new(2, 10).paginate(expected_items.clone(), 100, "/api/items");

    assert_eq!(page.items, expected_items);
    // 100 total entries at 10 per page gives 10 pages.
    assert_eq!(
        (page.total, page.page, page.per_page, page.pages),
        (100, 2, 10, 10),
    );
}
16489
/// Two `Pagination` values compare equal only when their page and size both match.
#[test]
fn pagination_equality() {
    let base = Pagination::new(2, 10);

    assert_eq!(base, Pagination::new(2, 10));
    assert_ne!(base, Pagination::new(3, 10));
}
16499
/// `Pagination` can be duplicated both implicitly (`Copy`) and explicitly (`Clone`),
/// and every duplicate compares equal to the original.
#[test]
// The explicit `.clone()` on a `Copy` type is the behavior under test, so the
// lint that would normally suggest removing it is silenced here.
#[allow(clippy::clone_on_copy)]
fn pagination_copy_clone() {
    let p1 = Pagination::new(2, 10);
    let p2 = p1; // Copy: `p1` remains usable after the assignment
    let p3 = p1.clone(); // Clone: goes through `Clone::clone` explicitly

    assert_eq!(p1, p2);
    assert_eq!(p1, p3);
}
16509}
16510
#[cfg(test)]
mod path_tests {
    use super::*;
    use crate::request::Method;
    use serde::Deserialize;

    /// Creates a request context from `asupersync::Cx::for_testing()`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Creates a `GET /test` request carrying `params` as its `PathParams` extension.
    fn request_with_params(params: Vec<(&str, &str)>) -> Request {
        let mut owned = Vec::with_capacity(params.len());
        for (key, value) in params {
            owned.push((key.to_string(), value.to_string()));
        }

        let mut req = Request::new(Method::Get, "/test");
        req.insert_extension(PathParams::from_pairs(owned));
        req
    }

    /// `PathParams::get` returns the stored value, or `None` for an unknown name.
    #[test]
    fn path_params_get() {
        let params = PathParams::from_pairs(vec![(String::from("id"), String::from("42"))]);

        for (key, expected) in [("id", Some("42")), ("unknown", None)] {
            assert_eq!(params.get(key), expected);
        }
    }

    /// `len` and `is_empty` reflect the number of stored pairs.
    #[test]
    fn path_params_len() {
        let none = PathParams::new();
        assert!(none.is_empty());
        assert_eq!(none.len(), 0);

        let pair = |key: &str, value: &str| (key.to_string(), value.to_string());
        let two = PathParams::from_pairs(vec![pair("a", "1"), pair("b", "2")]);
        assert!(!two.is_empty());
        assert_eq!(two.len(), 2);
    }

    /// A single parameter parses into `i64`.
    #[test]
    fn path_extract_single_i64() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("id", "42")]);

        let extracted = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        assert_eq!(extracted.unwrap().0, 42);
    }

    /// A single parameter extracts as an owned `String`.
    #[test]
    fn path_extract_single_string() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("name", "alice")]);

        let extracted = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
        assert_eq!(extracted.unwrap().0, "alice");
    }

    /// A single parameter parses into `u32`.
    #[test]
    fn path_extract_single_u32() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("count", "100")]);

        let extracted = futures_executor::block_on(Path::<u32>::from_request(&ctx, &mut req));
        assert_eq!(extracted.unwrap().0, 100);
    }

    /// Two parameters extract into a tuple in the order they were supplied.
    #[test]
    fn path_extract_tuple() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("user_id", "42"), ("post_id", "99")]);

        let extracted =
            futures_executor::block_on(Path::<(i64, i64)>::from_request(&ctx, &mut req));
        assert_eq!(extracted.unwrap().0, (42, 99));
    }

    /// Tuple elements may have different target types.
    #[test]
    fn path_extract_tuple_mixed_types() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("name", "alice"), ("id", "123")]);

        let extracted =
            futures_executor::block_on(Path::<(String, i64)>::from_request(&ctx, &mut req));
        let (name, id) = extracted.unwrap().0;
        assert_eq!(name, "alice");
        assert_eq!(id, 123);
    }

    /// Named parameters extract into a struct with matching field names.
    #[test]
    fn path_extract_struct() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct UserPath {
            user_id: i64,
            post_id: i64,
        }

        let ctx = test_context();
        let mut req = request_with_params(vec![("user_id", "42"), ("post_id", "99")]);

        let extracted = futures_executor::block_on(Path::<UserPath>::from_request(&ctx, &mut req));
        assert_eq!(
            extracted.unwrap().0,
            UserPath {
                user_id: 42,
                post_id: 99,
            },
        );
    }

    /// Extraction fails with `MissingPathParams` when the request has no `PathParams` extension.
    #[test]
    fn path_extract_missing_params() {
        let ctx = test_context();
        // Deliberately built without a PathParams extension.
        let mut req = Request::new(Method::Get, "/test");

        let outcome = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        assert!(matches!(outcome, Err(PathExtractError::MissingPathParams)));
    }

    /// A non-numeric value for an `i64` target reports `InvalidValue` naming the parameter.
    #[test]
    fn path_extract_invalid_type() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("id", "not_a_number")]);

        let outcome = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        assert!(matches!(
            outcome,
            Err(PathExtractError::InvalidValue { name, .. }) if name == "id"
        ));
    }

    /// A negative number is rejected for an unsigned target.
    #[test]
    fn path_extract_negative_for_unsigned() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("count", "-5")]);

        let outcome = futures_executor::block_on(Path::<u32>::from_request(&ctx, &mut req));
        assert!(matches!(outcome, Err(PathExtractError::InvalidValue { .. })));
    }

    /// A decimal value parses into `f64` (compared with a tolerance).
    #[test]
    fn path_extract_f64() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("price", "19.99")]);

        let extracted = futures_executor::block_on(Path::<f64>::from_request(&ctx, &mut req));
        let price = extracted.unwrap().0;
        assert!((price - 19.99).abs() < 0.001);
    }

    /// `Path` dereferences to its inner value.
    #[test]
    fn path_deref() {
        let wrapped: Path<i64> = Path(42);
        assert_eq!(*wrapped, 42);
    }

    /// `into_inner` returns the wrapped value.
    #[test]
    fn path_into_inner() {
        let inner = Path(String::from("hello")).into_inner();
        assert_eq!(inner, "hello");
    }

    /// The `Display` output of each error variant mentions its distinguishing details.
    #[test]
    fn path_error_display() {
        let missing_all = PathExtractError::MissingPathParams.to_string();
        assert!(missing_all.contains("not available"));

        let missing_one = PathExtractError::MissingParam {
            name: "user_id".to_string(),
        }
        .to_string();
        assert!(missing_one.contains("user_id"));

        let invalid = PathExtractError::InvalidValue {
            name: "id".to_string(),
            value: "abc".to_string(),
            expected: "i64",
            message: "invalid digit".to_string(),
        }
        .to_string();
        for fragment in ["id", "abc", "i64"] {
            assert!(invalid.contains(fragment));
        }
    }

    /// The literal `true` parses into `bool`.
    #[test]
    fn path_extract_bool() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("active", "true")]);

        let extracted = futures_executor::block_on(Path::<bool>::from_request(&ctx, &mut req));
        assert!(extracted.unwrap().0);
    }

    /// A one-character value parses into `char`.
    #[test]
    fn path_extract_char() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("letter", "A")]);

        let extracted = futures_executor::block_on(Path::<char>::from_request(&ctx, &mut req));
        assert_eq!(extracted.unwrap().0, 'A');
    }
}
16721
#[cfg(test)]
mod query_tests {
    use super::*;
    use crate::request::Method;
    use serde::Deserialize;

    /// Creates a request context from `asupersync::Cx::for_testing()`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Creates a `GET /test` request whose query string is `query`.
    fn request_with_query(query: &str) -> Request {
        let mut req = Request::new(Method::Get, "/test");
        req.set_query(Some(String::from(query)));
        req
    }

    /// Plain `key=value` pairs are retrievable by key; unknown keys yield `None`.
    #[test]
    fn query_params_parse() {
        let params = QueryParams::parse("a=1&b=2&c=3");

        let expectations = [
            ("a", Some("1")),
            ("b", Some("2")),
            ("c", Some("3")),
            ("d", None),
        ];
        for (key, expected) in expectations {
            assert_eq!(params.get(key), expected);
        }
    }

    /// A repeated key keeps every value; `get` returns the first one.
    #[test]
    fn query_params_multi_value() {
        let params = QueryParams::parse("tag=rust&tag=web&tag=api");

        let first = params.get("tag");
        let all = params.get_all("tag");
        assert_eq!(first, Some("rust")); // First value
        assert_eq!(all, vec!["rust", "web", "api"]);
    }

    /// Percent-encoded bytes, including multi-byte UTF-8, are decoded.
    #[test]
    fn query_params_percent_decode() {
        let params = QueryParams::parse("msg=hello%20world&name=caf%C3%A9");

        for (key, decoded) in [("msg", "hello world"), ("name", "café")] {
            assert_eq!(params.get(key), Some(decoded));
        }
    }

    /// `+` in a value decodes to a space.
    #[test]
    fn query_params_plus_as_space() {
        let decoded = QueryParams::parse("msg=hello+world");
        assert_eq!(decoded.get("msg"), Some("hello world"));
    }

    /// A key without `=` is present with an empty value.
    #[test]
    fn query_params_empty_value() {
        let params = QueryParams::parse("flag&name=alice");

        assert!(params.contains("flag"));
        for (key, expected) in [("flag", ""), ("name", "alice")] {
            assert_eq!(params.get(key), Some(expected));
        }
    }

    /// Query pairs deserialize into a struct with matching field names.
    #[test]
    fn query_extract_struct() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct SearchParams {
            q: String,
            page: i32,
        }

        let ctx = test_context();
        let mut req = request_with_query("q=rust&page=5");

        let extracted =
            futures_executor::block_on(Query::<SearchParams>::from_request(&ctx, &mut req));
        assert_eq!(
            extracted.unwrap().0,
            SearchParams {
                q: "rust".to_string(),
                page: 5,
            },
        );
    }

    /// An `Option` field is `Some` when its key is present and `None` when it is absent.
    #[test]
    fn query_extract_optional_field() {
        #[derive(Deserialize, Debug)]
        struct Params {
            required: String,
            optional: Option<i32>,
        }

        let ctx = test_context();
        // Runs one extraction against a fresh request for the given query string.
        let extract = |query: &str| {
            let mut req = request_with_query(query);
            futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req))
                .unwrap()
                .0
        };

        // With optional present
        let with_optional = extract("required=hello&optional=42");
        assert_eq!(with_optional.required, "hello");
        assert_eq!(with_optional.optional, Some(42));

        // Without optional
        let without_optional = extract("required=hello");
        assert_eq!(without_optional.required, "hello");
        assert_eq!(without_optional.optional, None);
    }

    /// A repeated key deserializes into a `Vec` field, preserving order.
    #[test]
    fn query_extract_multi_value() {
        #[derive(Deserialize, Debug)]
        struct Params {
            tags: Vec<String>,
        }

        let ctx = test_context();
        let mut req = request_with_query("tags=rust&tags=web&tags=api");

        let extracted = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let tags = extracted.unwrap().0.tags;
        assert_eq!(tags, vec!["rust", "web", "api"]);
    }

    /// A `#[serde(default)]` field falls back to its type's default when the key is absent.
    #[test]
    fn query_extract_default_value() {
        #[derive(Deserialize, Debug)]
        struct Params {
            name: String,
            #[serde(default)]
            limit: i32,
        }

        let ctx = test_context();
        let mut req = request_with_query("name=test");

        let extracted = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Params { name, limit } = extracted.unwrap().0;
        assert_eq!(name, "test");
        assert_eq!(limit, 0); // Default for i32
    }

    /// The literals `true` and `false` deserialize into `bool` fields.
    #[test]
    fn query_extract_bool() {
        #[derive(Deserialize, Debug)]
        struct Params {
            active: bool,
            archived: bool,
        }

        let ctx = test_context();
        let mut req = request_with_query("active=true&archived=false");

        let extracted = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Params { active, archived } = extracted.unwrap().0;
        assert!(active);
        assert!(!archived);
    }

    /// `1`, `yes` and `on` are all accepted as `true`.
    #[test]
    fn query_extract_bool_variants() {
        #[derive(Deserialize, Debug)]
        struct Params {
            a: bool,
            b: bool,
            c: bool,
        }

        let ctx = test_context();
        let mut req = request_with_query("a=1&b=yes&c=on");

        let extracted = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Params { a, b, c } = extracted.unwrap().0;
        for flag in [a, b, c] {
            assert!(flag);
        }
    }

    /// Extraction fails when a required field has no matching key.
    #[test]
    fn query_extract_missing_required_fails() {
        #[derive(Deserialize, Debug)]
        #[allow(dead_code)]
        struct Params {
            required: String,
        }

        let ctx = test_context();
        let mut req = request_with_query("other=value");

        let outcome = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        assert!(outcome.is_err());
    }

    /// Extraction fails when a value cannot be parsed into the field's type.
    #[test]
    fn query_extract_invalid_type_fails() {
        #[derive(Deserialize, Debug)]
        #[allow(dead_code)]
        struct Params {
            count: i32,
        }

        let ctx = test_context();
        let mut req = request_with_query("count=not_a_number");

        let outcome = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        assert!(outcome.is_err());
    }

    /// An empty query string succeeds when every field has a default.
    #[test]
    fn query_extract_empty_query() {
        #[derive(Deserialize, Debug, Default)]
        struct Params {
            #[serde(default)]
            name: String,
        }

        let ctx = test_context();
        let mut req = request_with_query("");

        let extracted = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        assert_eq!(extracted.unwrap().0.name, "");
    }

    /// A decimal value deserializes into `f64` (compared with a tolerance).
    #[test]
    fn query_extract_float() {
        #[derive(Deserialize, Debug)]
        struct Params {
            price: f64,
        }

        let ctx = test_context();
        let mut req = request_with_query("price=29.99");

        let extracted = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let price = extracted.unwrap().0.price;
        assert!((price - 29.99).abs() < 0.001);
    }

    /// `Query` dereferences to its inner value, so fields are reachable directly.
    #[test]
    fn query_deref() {
        #[derive(Deserialize, Debug)]
        struct Params {
            name: String,
        }

        let wrapped = Query(Params {
            name: String::from("test"),
        });
        assert_eq!(wrapped.name, "test");
    }

    /// `into_inner` returns the wrapped value.
    #[test]
    fn query_into_inner() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Params {
            value: i32,
        }

        let inner = Query(Params { value: 42 }).into_inner();
        assert_eq!(inner, Params { value: 42 });
    }

    /// The `Display` output of each error variant mentions its distinguishing details.
    #[test]
    fn query_error_display() {
        let missing = QueryExtractError::MissingParam {
            name: "user_id".to_string(),
        }
        .to_string();
        assert!(missing.contains("user_id"));

        let invalid = QueryExtractError::InvalidValue {
            name: "count".to_string(),
            value: "abc".to_string(),
            expected: "i32",
            message: "invalid digit".to_string(),
        }
        .to_string();
        for fragment in ["count", "abc", "i32"] {
            assert!(invalid.contains(fragment));
        }
    }

    /// `keys` yields each distinct key once, in first-seen order.
    #[test]
    fn query_params_keys() {
        let params = QueryParams::parse("a=1&b=2&a=3&c=4");
        assert_eq!(params.keys().collect::<Vec<&str>>(), vec!["a", "b", "c"]); // Unique keys in order
    }

    /// `len` and `is_empty` reflect the number of parsed pairs.
    #[test]
    fn query_params_len() {
        let three = QueryParams::parse("a=1&b=2&c=3");
        assert!(!three.is_empty());
        assert_eq!(three.len(), 3);

        let none = QueryParams::new();
        assert!(none.is_empty());
        assert_eq!(none.len(), 0);
    }
}
17011
17012// ============================================================================
17013// Optional Extraction Tests
17014// ============================================================================
17015
#[cfg(test)]
mod optional_tests {
    use super::*;
    use crate::request::Method;
    use serde::Deserialize;

    /// JSON payload shape used by the `Option<Json<T>>` tests.
    #[derive(Deserialize, PartialEq, Debug)]
    struct Data {
        value: i32,
    }

    /// Query shape used by the `Option<Query<T>>` tests that parse a page number.
    #[derive(Deserialize, PartialEq, Debug)]
    struct PageParams {
        page: i32,
    }

    /// Creates a request context from `asupersync::Cx::for_testing()`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 99999)
    }

    /// Creates a body-less `POST /test` request with the given `content-type` header.
    fn post_with_content_type(content_type: &[u8]) -> Request {
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", content_type.to_vec());
        req
    }

    /// Creates a `GET` request for `uri` whose only path parameter is `id`.
    fn get_with_id(uri: &str, id: &str) -> Request {
        let mut req = Request::new(Method::Get, uri);
        req.insert_extension(PathParams::from_pairs(vec![(
            "id".to_string(),
            id.to_string(),
        )]));
        req
    }

    /// Runs the `Option<Json<Data>>` extractor; the extraction itself must not fail.
    fn extract_json(req: &mut Request) -> Option<Json<Data>> {
        let ctx = test_context();
        let outcome = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, req));
        outcome.unwrap()
    }

    /// Runs the `Option<Path<i64>>` extractor; the extraction itself must not fail.
    fn extract_path(req: &mut Request) -> Option<Path<i64>> {
        let ctx = test_context();
        let outcome = futures_executor::block_on(Option::<Path<i64>>::from_request(&ctx, req));
        outcome.unwrap()
    }

    /// Runs the `Option<Query<PageParams>>` extractor; the extraction itself must not fail.
    fn extract_page_query(req: &mut Request) -> Option<Query<PageParams>> {
        let ctx = test_context();
        let outcome =
            futures_executor::block_on(Option::<Query<PageParams>>::from_request(&ctx, req));
        outcome.unwrap()
    }

    /// Runs the `Option<State<i32>>` extractor; the extraction itself must not fail.
    fn extract_state(req: &mut Request) -> Option<State<i32>> {
        let ctx = test_context();
        let outcome = futures_executor::block_on(Option::<State<i32>>::from_request(&ctx, req));
        outcome.unwrap()
    }

    // --- Option<Json<T>> Tests ---

    /// A well-formed JSON body with the JSON content type yields `Some`.
    #[test]
    fn optional_json_present_valid() {
        let mut req = post_with_content_type(b"application/json");
        req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));

        match extract_json(&mut req) {
            Some(Json(data)) => assert_eq!(data.value, 42),
            None => panic!("Expected Some"),
        }
    }

    /// A non-JSON content type yields `None` even though the body is valid JSON.
    #[test]
    fn optional_json_invalid_content_type_returns_none() {
        let mut req = post_with_content_type(b"text/plain");
        req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));

        assert!(extract_json(&mut req).is_none());
    }

    /// A JSON content type without any body yields `None` instead of an error.
    #[test]
    fn optional_json_missing_body_returns_none() {
        // Only the header is set; no body is attached to the request.
        let mut req = post_with_content_type(b"application/json");

        assert!(extract_json(&mut req).is_none());
    }

    /// A malformed JSON body yields `None` instead of an error.
    #[test]
    fn optional_json_malformed_returns_none() {
        let mut req = post_with_content_type(b"application/json");
        req.set_body(Body::Bytes(b"{ not valid json }".to_vec()));

        assert!(extract_json(&mut req).is_none());
    }

    // --- Option<Path<T>> Tests ---

    /// A parseable path parameter yields `Some`.
    #[test]
    fn optional_path_present_valid() {
        let mut req = get_with_id("/users/42", "42");

        match extract_path(&mut req) {
            Some(Path(id)) => assert_eq!(id, 42),
            None => panic!("Expected Some"),
        }
    }

    /// A request without a `PathParams` extension yields `None`.
    #[test]
    fn optional_path_missing_params_returns_none() {
        // Built without inserting any PathParams.
        let mut req = Request::new(Method::Get, "/users/42");

        assert!(extract_path(&mut req).is_none());
    }

    /// A path parameter that does not parse into the target type yields `None`.
    #[test]
    fn optional_path_invalid_type_returns_none() {
        let mut req = get_with_id("/users/abc", "abc");

        assert!(extract_path(&mut req).is_none());
    }

    // --- Option<Query<T>> Tests ---

    /// A parseable query string yields `Some`.
    #[test]
    fn optional_query_present_valid() {
        let mut req = Request::new(Method::Get, "/items");
        req.set_query(Some("page=5".to_string()));

        match extract_page_query(&mut req) {
            Some(Query(params)) => assert_eq!(params.page, 5),
            None => panic!("Expected Some"),
        }
    }

    /// A missing query string yields `None` when the target has a required field.
    #[test]
    fn optional_query_missing_returns_none() {
        // The field is only ever written by the deserializer, never read.
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct Params {
            required: String,
        }

        let ctx = test_context();
        // No query string is set on this request.
        let mut req = Request::new(Method::Get, "/items");

        let outcome =
            futures_executor::block_on(Option::<Query<Params>>::from_request(&ctx, &mut req));
        assert!(outcome.unwrap().is_none());
    }

    /// A query value that does not parse into the field's type yields `None`.
    #[test]
    fn optional_query_invalid_type_returns_none() {
        let mut req = Request::new(Method::Get, "/items");
        req.set_query(Some("page=abc".to_string()));

        assert!(extract_page_query(&mut req).is_none());
    }

    // --- Option<State<T>> Tests ---

    /// State of the requested type yields `Some`.
    #[test]
    fn optional_state_present() {
        let mut req = Request::new(Method::Get, "/");
        req.insert_extension(AppState::new().with(42i32));

        match extract_state(&mut req) {
            Some(State(val)) => assert_eq!(val, 42),
            None => panic!("Expected Some"),
        }
    }

    /// A request without an `AppState` extension yields `None`.
    #[test]
    fn optional_state_missing_returns_none() {
        // Built without inserting any AppState.
        let mut req = Request::new(Method::Get, "/");

        assert!(extract_state(&mut req).is_none());
    }

    /// An `AppState` holding only a different type yields `None`.
    #[test]
    fn optional_state_wrong_type_returns_none() {
        let mut req = Request::new(Method::Get, "/");
        // The state holds a String, while the extractor asks for i32.
        req.insert_extension(AppState::new().with("string".to_string()));

        assert!(extract_state(&mut req).is_none());
    }
}
17250
17251// ============================================================================
17252// Multiple Extractors Combination Tests
17253// ============================================================================
17254
17255#[cfg(test)]
17256mod combination_tests {
17257    use super::*;
17258    use crate::request::Method;
17259
17260    fn test_context() -> RequestContext {
17261        let cx = asupersync::Cx::for_testing();
17262        RequestContext::new(cx, 88888)
17263    }
17264
/// A path parameter and a query struct can both be extracted from the same request.
#[test]
fn path_and_query_together() {
    use serde::Deserialize;

    #[derive(Deserialize, PartialEq, Debug)]
    struct LimitQuery {
        limit: i32,
    }

    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/users/42");
    req.set_query(Some(String::from("limit=10")));
    req.insert_extension(PathParams::from_pairs(vec![(
        String::from("id"),
        String::from("42"),
    )]));

    // Extract path
    let user_id = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req))
        .unwrap()
        .0;
    assert_eq!(user_id, 42);

    // Extract query
    let query = futures_executor::block_on(Query::<LimitQuery>::from_request(&ctx, &mut req))
        .unwrap()
        .0;
    assert_eq!(query.limit, 10);
}
17293
/// A path parameter and a JSON body can both be extracted from the same request.
#[test]
fn json_body_and_path() {
    use serde::Deserialize;

    #[derive(Deserialize, PartialEq, Debug)]
    struct CreateItem {
        name: String,
    }

    let ctx = test_context();
    let mut req = Request::new(Method::Post, "/categories/5/items");
    req.headers_mut()
        .insert("content-type", b"application/json".to_vec());
    req.set_body(Body::Bytes(b"{\"name\": \"Widget\"}".to_vec()));
    let path_params = PathParams::from_pairs(vec![(String::from("cat_id"), String::from("5"))]);
    req.insert_extension(path_params);

    // Extract path first (doesn't consume body)
    let cat_id = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req))
        .unwrap()
        .0;
    assert_eq!(cat_id, 5);

    // Extract JSON body
    let item = futures_executor::block_on(Json::<CreateItem>::from_request(&ctx, &mut req))
        .unwrap()
        .0;
    assert_eq!(item.name, "Widget");
}
17324
17325    #[test]
17326    fn state_and_query() {
17327        use serde::Deserialize;
17328
17329        #[derive(Deserialize, PartialEq, Debug)]
17330        struct SearchParams {
17331            q: String,
17332        }
17333
17334        #[derive(Clone, PartialEq, Debug)]
17335        struct Config {
17336            max_results: i32,
17337        }
17338
17339        let ctx = test_context();
17340        let mut req = Request::new(Method::Get, "/search");
17341        req.set_query(Some("q=hello".to_string()));
17342        let app_state = AppState::new().with(Config { max_results: 100 });
17343        req.insert_extension(app_state);
17344
17345        // Extract state
17346        let state_result =
17347            futures_executor::block_on(State::<Config>::from_request(&ctx, &mut req));
17348        let State(config) = state_result.unwrap();
17349        assert_eq!(config.max_results, 100);
17350
17351        // Extract query
17352        let query_result =
17353            futures_executor::block_on(Query::<SearchParams>::from_request(&ctx, &mut req));
17354        let Query(params) = query_result.unwrap();
17355        assert_eq!(params.q, "hello");
17356    }
17357
17358    #[test]
17359    fn multiple_path_params_with_struct() {
17360        use serde::Deserialize;
17361
17362        #[derive(Deserialize, PartialEq, Debug)]
17363        struct CommentPath {
17364            post_id: i64,
17365            comment_id: i64,
17366        }
17367
17368        let ctx = test_context();
17369        let mut req = Request::new(Method::Get, "/posts/123/comments/456");
17370        req.insert_extension(PathParams::from_pairs(vec![
17371            ("post_id".to_string(), "123".to_string()),
17372            ("comment_id".to_string(), "456".to_string()),
17373        ]));
17374
17375        let result = futures_executor::block_on(Path::<CommentPath>::from_request(&ctx, &mut req));
17376        let Path(path) = result.unwrap();
17377        assert_eq!(path.post_id, 123);
17378        assert_eq!(path.comment_id, 456);
17379    }
17380
17381    #[test]
17382    fn optional_mixed_with_required() {
17383        use serde::Deserialize;
17384
17385        #[derive(Deserialize, PartialEq, Debug)]
17386        struct OptionalParams {
17387            page: Option<i32>,
17388        }
17389
17390        let ctx = test_context();
17391        let mut req = Request::new(Method::Get, "/users/42");
17392        req.insert_extension(PathParams::from_pairs(vec![(
17393            "id".to_string(),
17394            "42".to_string(),
17395        )]));
17396
17397        // Required path - should succeed
17398        let path_result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
17399        let Path(id) = path_result.unwrap();
17400        assert_eq!(id, 42);
17401
17402        // Optional query - should return default None
17403        let query_result =
17404            futures_executor::block_on(Query::<OptionalParams>::from_request(&ctx, &mut req));
17405        let Query(params) = query_result.unwrap();
17406        assert_eq!(params.page, None);
17407    }
17408
17409    #[test]
17410    fn request_context_extraction() {
17411        let ctx = test_context();
17412        let mut req = Request::new(Method::Get, "/");
17413
17414        let result = futures_executor::block_on(RequestContext::from_request(&ctx, &mut req));
17415        let extracted_ctx = result.unwrap();
17416        assert_eq!(extracted_ctx.request_id(), ctx.request_id());
17417    }
17418
17419    #[test]
17420    fn triple_extraction_path_query_state() {
17421        use serde::Deserialize;
17422
17423        #[derive(Deserialize, PartialEq, Debug)]
17424        struct QueryFilter {
17425            status: String,
17426        }
17427
17428        #[derive(Clone)]
17429        struct DbPool {
17430            connection_count: i32,
17431        }
17432
17433        let ctx = test_context();
17434        let mut req = Request::new(Method::Get, "/projects/99/tasks");
17435        req.insert_extension(PathParams::from_pairs(vec![(
17436            "project_id".to_string(),
17437            "99".to_string(),
17438        )]));
17439        req.set_query(Some("status=active".to_string()));
17440        let app_state = AppState::new().with(DbPool {
17441            connection_count: 10,
17442        });
17443        req.insert_extension(app_state);
17444
17445        // Path
17446        let Path(project_id): Path<i32> =
17447            futures_executor::block_on(Path::<i32>::from_request(&ctx, &mut req)).unwrap();
17448        assert_eq!(project_id, 99);
17449
17450        // Query
17451        let Query(filter): Query<QueryFilter> =
17452            futures_executor::block_on(Query::<QueryFilter>::from_request(&ctx, &mut req)).unwrap();
17453        assert_eq!(filter.status, "active");
17454
17455        // State
17456        let State(pool): State<DbPool> =
17457            futures_executor::block_on(State::<DbPool>::from_request(&ctx, &mut req)).unwrap();
17458        assert_eq!(pool.connection_count, 10);
17459    }
17460}
17461
17462// ============================================================================
17463// Edge Case Tests
17464// ============================================================================
17465
17466#[cfg(test)]
17467mod edge_case_tests {
17468    use super::*;
17469    use crate::request::Method;
17470
17471    fn test_context() -> RequestContext {
17472        let cx = asupersync::Cx::for_testing();
17473        RequestContext::new(cx, 77777)
17474    }
17475
17476    // --- Unicode and Special Characters ---
17477
17478    #[test]
17479    fn json_with_unicode() {
17480        use serde::Deserialize;
17481
17482        #[derive(Deserialize, PartialEq, Debug)]
17483        struct Data {
17484            name: String,
17485            emoji: String,
17486        }
17487
17488        let ctx = test_context();
17489        let mut req = Request::new(Method::Post, "/test");
17490        req.headers_mut()
17491            .insert("content-type", b"application/json".to_vec());
17492        req.set_body(Body::Bytes(
17493            r#"{"name": "日本語", "emoji": "🎉🚀"}"#.as_bytes().to_vec(),
17494        ));
17495
17496        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17497        let Json(data) = result.unwrap();
17498        assert_eq!(data.name, "日本語");
17499        assert_eq!(data.emoji, "🎉🚀");
17500    }
17501
17502    #[test]
17503    fn query_with_unicode_percent_encoded() {
17504        use serde::Deserialize;
17505
17506        #[derive(Deserialize, PartialEq, Debug)]
17507        struct Search {
17508            q: String,
17509        }
17510
17511        let ctx = test_context();
17512        let mut req = Request::new(Method::Get, "/search");
17513        // "こんにちは" (hello in Japanese), percent-encoded
17514        req.set_query(Some(
17515            "q=%E3%81%93%E3%82%93%E3%81%AB%E3%81%A1%E3%81%AF".to_string(),
17516        ));
17517
17518        let result = futures_executor::block_on(Query::<Search>::from_request(&ctx, &mut req));
17519        let Query(search) = result.unwrap();
17520        assert_eq!(search.q, "こんにちは");
17521    }
17522
17523    #[test]
17524    fn path_with_unicode() {
17525        let ctx = test_context();
17526        let mut req = Request::new(Method::Get, "/users/用户123");
17527        req.insert_extension(PathParams::from_pairs(vec![(
17528            "name".to_string(),
17529            "用户123".to_string(),
17530        )]));
17531
17532        let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
17533        let Path(name) = result.unwrap();
17534        assert_eq!(name, "用户123");
17535    }
17536
17537    // --- Boundary Values ---
17538
17539    #[test]
17540    fn path_max_i64() {
17541        let ctx = test_context();
17542        let mut req = Request::new(Method::Get, "/items/9223372036854775807");
17543        req.insert_extension(PathParams::from_pairs(vec![(
17544            "id".to_string(),
17545            "9223372036854775807".to_string(),
17546        )]));
17547
17548        let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
17549        let Path(id) = result.unwrap();
17550        assert_eq!(id, i64::MAX);
17551    }
17552
17553    #[test]
17554    fn path_min_i64() {
17555        let ctx = test_context();
17556        let mut req = Request::new(Method::Get, "/items/-9223372036854775808");
17557        req.insert_extension(PathParams::from_pairs(vec![(
17558            "id".to_string(),
17559            "-9223372036854775808".to_string(),
17560        )]));
17561
17562        let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
17563        let Path(id) = result.unwrap();
17564        assert_eq!(id, i64::MIN);
17565    }
17566
17567    #[test]
17568    fn path_overflow_i64_fails() {
17569        let ctx = test_context();
17570        let mut req = Request::new(Method::Get, "/items/9223372036854775808");
17571        req.insert_extension(PathParams::from_pairs(vec![(
17572            "id".to_string(),
17573            "9223372036854775808".to_string(), // i64::MAX + 1
17574        )]));
17575
17576        let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
17577        assert!(result.is_err());
17578    }
17579
17580    #[test]
17581    fn query_with_empty_value() {
17582        use serde::Deserialize;
17583
17584        #[derive(Deserialize, PartialEq, Debug)]
17585        struct Params {
17586            key: String,
17587        }
17588
17589        let ctx = test_context();
17590        let mut req = Request::new(Method::Get, "/test");
17591        req.set_query(Some("key=".to_string()));
17592
17593        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17594        let Query(params) = result.unwrap();
17595        assert_eq!(params.key, "");
17596    }
17597
17598    #[test]
17599    fn query_with_only_key_no_equals() {
17600        use serde::Deserialize;
17601
17602        #[derive(Deserialize, PartialEq, Debug)]
17603        struct Params {
17604            flag: Option<String>,
17605        }
17606
17607        let ctx = test_context();
17608        let mut req = Request::new(Method::Get, "/test");
17609        req.set_query(Some("flag".to_string()));
17610
17611        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17612        let Query(params) = result.unwrap();
17613        // Key without = should have empty string value
17614        assert_eq!(params.flag, Some(String::new()));
17615    }
17616
17617    #[test]
17618    fn json_empty_object() {
17619        use serde::Deserialize;
17620
17621        #[derive(Deserialize, PartialEq, Debug)]
17622        struct Empty {}
17623
17624        let ctx = test_context();
17625        let mut req = Request::new(Method::Post, "/test");
17626        req.headers_mut()
17627            .insert("content-type", b"application/json".to_vec());
17628        req.set_body(Body::Bytes(b"{}".to_vec()));
17629
17630        let result = futures_executor::block_on(Json::<Empty>::from_request(&ctx, &mut req));
17631        assert!(result.is_ok());
17632    }
17633
17634    #[test]
17635    fn json_with_null_field() {
17636        use serde::Deserialize;
17637
17638        #[derive(Deserialize, PartialEq, Debug)]
17639        struct Data {
17640            value: Option<i32>,
17641        }
17642
17643        let ctx = test_context();
17644        let mut req = Request::new(Method::Post, "/test");
17645        req.headers_mut()
17646            .insert("content-type", b"application/json".to_vec());
17647        req.set_body(Body::Bytes(b"{\"value\": null}".to_vec()));
17648
17649        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17650        let Json(data) = result.unwrap();
17651        assert_eq!(data.value, None);
17652    }
17653
17654    #[test]
17655    fn json_with_nested_objects() {
17656        use serde::Deserialize;
17657
17658        #[derive(Deserialize, PartialEq, Debug)]
17659        struct Address {
17660            city: String,
17661            zip: String,
17662        }
17663
17664        #[derive(Deserialize, PartialEq, Debug)]
17665        struct User {
17666            name: String,
17667            address: Address,
17668        }
17669
17670        let ctx = test_context();
17671        let mut req = Request::new(Method::Post, "/test");
17672        req.headers_mut()
17673            .insert("content-type", b"application/json".to_vec());
17674        req.set_body(Body::Bytes(
17675            b"{\"name\": \"Alice\", \"address\": {\"city\": \"NYC\", \"zip\": \"10001\"}}".to_vec(),
17676        ));
17677
17678        let result = futures_executor::block_on(Json::<User>::from_request(&ctx, &mut req));
17679        let Json(user) = result.unwrap();
17680        assert_eq!(user.name, "Alice");
17681        assert_eq!(user.address.city, "NYC");
17682        assert_eq!(user.address.zip, "10001");
17683    }
17684
17685    #[test]
17686    fn json_with_array() {
17687        use serde::Deserialize;
17688
17689        #[derive(Deserialize, PartialEq, Debug)]
17690        struct Data {
17691            items: Vec<i32>,
17692        }
17693
17694        let ctx = test_context();
17695        let mut req = Request::new(Method::Post, "/test");
17696        req.headers_mut()
17697            .insert("content-type", b"application/json".to_vec());
17698        req.set_body(Body::Bytes(b"{\"items\": [1, 2, 3, 4, 5]}".to_vec()));
17699
17700        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17701        let Json(data) = result.unwrap();
17702        assert_eq!(data.items, vec![1, 2, 3, 4, 5]);
17703    }
17704
17705    #[test]
17706    fn path_with_special_chars() {
17707        let ctx = test_context();
17708        let mut req = Request::new(Method::Get, "/files/my-file_v2.txt");
17709        req.insert_extension(PathParams::from_pairs(vec![(
17710            "filename".to_string(),
17711            "my-file_v2.txt".to_string(),
17712        )]));
17713
17714        let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
17715        let Path(filename) = result.unwrap();
17716        assert_eq!(filename, "my-file_v2.txt");
17717    }
17718
17719    #[test]
17720    fn query_with_special_chars_encoded() {
17721        use serde::Deserialize;
17722
17723        #[derive(Deserialize, PartialEq, Debug)]
17724        struct Params {
17725            value: String,
17726        }
17727
17728        let ctx = test_context();
17729        let mut req = Request::new(Method::Get, "/test");
17730        // Encoded: "hello world & more"
17731        req.set_query(Some("value=hello%20world%20%26%20more".to_string()));
17732
17733        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17734        let Query(params) = result.unwrap();
17735        assert_eq!(params.value, "hello world & more");
17736    }
17737
17738    #[test]
17739    fn query_multiple_values_same_key() {
17740        use serde::Deserialize;
17741
17742        #[derive(Deserialize, PartialEq, Debug)]
17743        struct Params {
17744            tags: Vec<String>,
17745        }
17746
17747        let ctx = test_context();
17748        let mut req = Request::new(Method::Get, "/test");
17749        req.set_query(Some("tags=a&tags=b&tags=c".to_string()));
17750
17751        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17752        let Query(params) = result.unwrap();
17753        assert_eq!(params.tags, vec!["a", "b", "c"]);
17754    }
17755
17756    #[test]
17757    fn path_empty_string() {
17758        let ctx = test_context();
17759        let mut req = Request::new(Method::Get, "/items//details");
17760        req.insert_extension(PathParams::from_pairs(vec![(
17761            "id".to_string(),
17762            String::new(),
17763        )]));
17764
17765        let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
17766        let Path(id) = result.unwrap();
17767        assert_eq!(id, "");
17768    }
17769
17770    #[test]
17771    fn json_with_escaped_quotes() {
17772        use serde::Deserialize;
17773
17774        #[derive(Deserialize, PartialEq, Debug)]
17775        struct Data {
17776            message: String,
17777        }
17778
17779        let ctx = test_context();
17780        let mut req = Request::new(Method::Post, "/test");
17781        req.headers_mut()
17782            .insert("content-type", b"application/json".to_vec());
17783        req.set_body(Body::Bytes(
17784            b"{\"message\": \"He said \\\"hello\\\"\"}".to_vec(),
17785        ));
17786
17787        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17788        let Json(data) = result.unwrap();
17789        assert_eq!(data.message, "He said \"hello\"");
17790    }
17791
17792    #[test]
17793    fn query_with_plus_as_space() {
17794        use serde::Deserialize;
17795
17796        #[derive(Deserialize, PartialEq, Debug)]
17797        struct Params {
17798            q: String,
17799        }
17800
17801        let ctx = test_context();
17802        let mut req = Request::new(Method::Get, "/search");
17803        req.set_query(Some("q=hello+world".to_string()));
17804
17805        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17806        let Query(params) = result.unwrap();
17807        assert_eq!(params.q, "hello world");
17808    }
17809}
17810
17811// ============================================================================
17812// Security Tests
17813// ============================================================================
17814
17815#[cfg(test)]
17816mod security_tests {
17817    use super::*;
17818    use crate::request::Method;
17819
17820    fn test_context() -> RequestContext {
17821        let cx = asupersync::Cx::for_testing();
17822        RequestContext::new(cx, 66666)
17823    }
17824
17825    #[test]
17826    fn json_payload_size_limit() {
17827        use serde::Deserialize;
17828
17829        #[derive(Deserialize)]
17830        #[allow(dead_code)]
17831        struct Data {
17832            content: String,
17833        }
17834
17835        let ctx = test_context();
17836        let mut req = Request::new(Method::Post, "/test");
17837        req.headers_mut()
17838            .insert("content-type", b"application/json".to_vec());
17839
17840        // Create payload larger than DEFAULT_JSON_LIMIT (1MB)
17841        let large_content = "x".repeat(DEFAULT_JSON_LIMIT + 100);
17842        let body = format!("{{\"content\": \"{large_content}\"}}");
17843        req.set_body(Body::Bytes(body.into_bytes()));
17844
17845        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17846        assert!(matches!(
17847            result,
17848            Err(JsonExtractError::PayloadTooLarge { .. })
17849        ));
17850    }
17851
17852    #[test]
17853    fn json_deeply_nested_object() {
17854        use serde::Deserialize;
17855
17856        // Deeply nested structure
17857        #[derive(Deserialize)]
17858        struct Level1 {
17859            #[allow(dead_code)]
17860            l2: Level2,
17861        }
17862        #[derive(Deserialize)]
17863        struct Level2 {
17864            #[allow(dead_code)]
17865            l3: Level3,
17866        }
17867        #[derive(Deserialize)]
17868        struct Level3 {
17869            #[allow(dead_code)]
17870            l4: Level4,
17871        }
17872        #[derive(Deserialize)]
17873        struct Level4 {
17874            #[allow(dead_code)]
17875            value: i32,
17876        }
17877
17878        let ctx = test_context();
17879        let mut req = Request::new(Method::Post, "/test");
17880        req.headers_mut()
17881            .insert("content-type", b"application/json".to_vec());
17882        req.set_body(Body::Bytes(
17883            b"{\"l2\":{\"l3\":{\"l4\":{\"value\":42}}}}".to_vec(),
17884        ));
17885
17886        let result = futures_executor::block_on(Json::<Level1>::from_request(&ctx, &mut req));
17887        assert!(result.is_ok());
17888    }
17889
17890    #[test]
17891    fn query_injection_attempt_escaped() {
17892        use serde::Deserialize;
17893
17894        #[derive(Deserialize, PartialEq, Debug)]
17895        struct Params {
17896            name: String,
17897        }
17898
17899        let ctx = test_context();
17900        let mut req = Request::new(Method::Get, "/test");
17901        // SQL injection attempt - should be treated as literal string
17902        req.set_query(Some(
17903            "name=Robert%27%3B%20DROP%20TABLE%20users%3B--".to_string(),
17904        ));
17905
17906        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17907        let Query(params) = result.unwrap();
17908        // The value should be preserved as-is (decoded)
17909        assert_eq!(params.name, "Robert'; DROP TABLE users;--");
17910    }
17911
17912    #[test]
17913    fn path_traversal_attempt() {
17914        let ctx = test_context();
17915        let mut req = Request::new(Method::Get, "/files/../../../etc/passwd");
17916        req.insert_extension(PathParams::from_pairs(vec![(
17917            "path".to_string(),
17918            "../../../etc/passwd".to_string(),
17919        )]));
17920
17921        let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
17922        let Path(path) = result.unwrap();
17923        // Path is extracted as-is - application must validate
17924        assert_eq!(path, "../../../etc/passwd");
17925    }
17926
17927    #[test]
17928    fn json_with_script_tag_xss() {
17929        use serde::Deserialize;
17930
17931        #[derive(Deserialize, PartialEq, Debug)]
17932        struct Data {
17933            comment: String,
17934        }
17935
17936        let ctx = test_context();
17937        let mut req = Request::new(Method::Post, "/test");
17938        req.headers_mut()
17939            .insert("content-type", b"application/json".to_vec());
17940        req.set_body(Body::Bytes(
17941            b"{\"comment\": \"<script>alert('xss')</script>\"}".to_vec(),
17942        ));
17943
17944        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17945        let Json(data) = result.unwrap();
17946        // XSS content is preserved as-is - application must sanitize on output
17947        assert_eq!(data.comment, "<script>alert('xss')</script>");
17948    }
17949
17950    #[test]
17951    fn json_content_type_case_insensitive() {
17952        use serde::Deserialize;
17953
17954        #[derive(Deserialize, PartialEq, Debug)]
17955        struct Data {
17956            value: i32,
17957        }
17958
17959        // Test various case combinations
17960        for content_type in &[
17961            "APPLICATION/JSON",
17962            "Application/Json",
17963            "application/JSON",
17964            "APPLICATION/json",
17965        ] {
17966            let ctx = test_context();
17967            let mut req = Request::new(Method::Post, "/test");
17968            req.headers_mut()
17969                .insert("content-type", content_type.as_bytes().to_vec());
17970            req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
17971
17972            let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17973            assert!(result.is_ok(), "Failed for content-type: {}", content_type);
17974        }
17975    }
17976
17977    #[test]
17978    fn json_wrong_content_type_variants() {
17979        use serde::Deserialize;
17980
17981        #[derive(Deserialize)]
17982        #[allow(dead_code)]
17983        struct Data {
17984            value: i32,
17985        }
17986
17987        // These should all be rejected
17988        for content_type in &[
17989            "text/json",
17990            "text/plain",
17991            "application/xml",
17992            "application/x-json",
17993        ] {
17994            let ctx = test_context();
17995            let mut req = Request::new(Method::Post, "/test");
17996            req.headers_mut()
17997                .insert("content-type", content_type.as_bytes().to_vec());
17998            req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
17999
18000            let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18001            assert!(
18002                matches!(result, Err(JsonExtractError::UnsupportedMediaType { .. })),
18003                "Should reject content-type: {}",
18004                content_type
18005            );
18006        }
18007    }
18008
18009    #[test]
18010    fn query_null_byte_handling() {
18011        use serde::Deserialize;
18012
18013        #[derive(Deserialize, PartialEq, Debug)]
18014        struct Params {
18015            name: String,
18016        }
18017
18018        let ctx = test_context();
18019        let mut req = Request::new(Method::Get, "/test");
18020        // Percent-encoded null byte
18021        req.set_query(Some("name=test%00value".to_string()));
18022
18023        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
18024        let Query(params) = result.unwrap();
18025        // Null byte should be decoded
18026        assert_eq!(params.name, "test\0value");
18027    }
18028
18029    #[test]
18030    fn path_with_null_byte() {
18031        let ctx = test_context();
18032        let mut req = Request::new(Method::Get, "/files/test");
18033        req.insert_extension(PathParams::from_pairs(vec![(
18034            "filename".to_string(),
18035            "test\0.txt".to_string(),
18036        )]));
18037
18038        let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
18039        let Path(filename) = result.unwrap();
18040        assert_eq!(filename, "test\0.txt");
18041    }
18042
18043    #[test]
18044    fn json_number_precision() {
18045        use serde::Deserialize;
18046
18047        #[derive(Deserialize, PartialEq, Debug)]
18048        struct Data {
18049            big_int: i64,
18050            float_val: f64,
18051        }
18052
18053        let ctx = test_context();
18054        let mut req = Request::new(Method::Post, "/test");
18055        req.headers_mut()
18056            .insert("content-type", b"application/json".to_vec());
18057        // Large number that fits in i64 but not in f64 without precision loss
18058        req.set_body(Body::Bytes(
18059            b"{\"big_int\": 9007199254740993, \"float_val\": 3.141592653589793}".to_vec(),
18060        ));
18061
18062        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18063        let Json(data) = result.unwrap();
18064        assert_eq!(data.big_int, 9007199254740993_i64);
18065        assert!((data.float_val - std::f64::consts::PI).abs() < 0.0000001);
18066    }
18067
18068    // =========================================================================
18069    // Json<T> IntoResponse tests
18070    // =========================================================================
18071
18072    #[test]
18073    fn json_into_response_serializes_struct() {
18074        use serde::Serialize;
18075
18076        #[derive(Serialize)]
18077        struct User {
18078            name: String,
18079            age: u32,
18080        }
18081
18082        let user = User {
18083            name: "Alice".to_string(),
18084            age: 30,
18085        };
18086        let json = Json(user);
18087        let response = json.into_response();
18088
18089        assert_eq!(response.status().as_u16(), 200);
18090
18091        // Check content-type header
18092        let content_type = response
18093            .headers()
18094            .iter()
18095            .find(|(name, _)| name == "content-type")
18096            .map(|(_, value)| String::from_utf8_lossy(value).to_string());
18097        assert_eq!(content_type, Some("application/json".to_string()));
18098
18099        // Check body content
18100        if let ResponseBody::Bytes(bytes) = response.body_ref() {
18101            let parsed: serde_json::Value = serde_json::from_slice(bytes).unwrap();
18102            assert_eq!(parsed["name"], "Alice");
18103            assert_eq!(parsed["age"], 30);
18104        } else {
18105            panic!("Expected Bytes body");
18106        }
18107    }
18108
18109    #[test]
18110    fn json_into_response_serializes_primitive() {
18111        let json = Json(42i32);
18112        let response = json.into_response();
18113
18114        assert_eq!(response.status().as_u16(), 200);
18115
18116        if let ResponseBody::Bytes(bytes) = response.body_ref() {
18117            let parsed: i32 = serde_json::from_slice(bytes).unwrap();
18118            assert_eq!(parsed, 42);
18119        } else {
18120            panic!("Expected Bytes body");
18121        }
18122    }
18123
18124    #[test]
18125    fn json_into_response_serializes_array() {
18126        let json = Json(vec!["a", "b", "c"]);
18127        let response = json.into_response();
18128
18129        assert_eq!(response.status().as_u16(), 200);
18130
18131        if let ResponseBody::Bytes(bytes) = response.body_ref() {
18132            let parsed: Vec<String> = serde_json::from_slice(bytes).unwrap();
18133            assert_eq!(parsed, vec!["a", "b", "c"]);
18134        } else {
18135            panic!("Expected Bytes body");
18136        }
18137    }
18138
18139    #[test]
18140    fn json_into_response_serializes_hashmap() {
18141        use std::collections::HashMap;
18142
18143        let mut map = HashMap::new();
18144        map.insert("key1", "value1");
18145        map.insert("key2", "value2");
18146
18147        let json = Json(map);
18148        let response = json.into_response();
18149
18150        assert_eq!(response.status().as_u16(), 200);
18151
18152        if let ResponseBody::Bytes(bytes) = response.body_ref() {
18153            let parsed: HashMap<String, String> = serde_json::from_slice(bytes).unwrap();
18154            assert_eq!(parsed.get("key1"), Some(&"value1".to_string()));
18155            assert_eq!(parsed.get("key2"), Some(&"value2".to_string()));
18156        } else {
18157            panic!("Expected Bytes body");
18158        }
18159    }
18160
18161    #[test]
18162    fn json_into_response_handles_null() {
18163        let json = Json(Option::<String>::None);
18164        let response = json.into_response();
18165
18166        assert_eq!(response.status().as_u16(), 200);
18167
18168        if let ResponseBody::Bytes(bytes) = response.body_ref() {
18169            let content = String::from_utf8_lossy(bytes);
18170            assert_eq!(content, "null");
18171        } else {
18172            panic!("Expected Bytes body");
18173        }
18174    }
18175}
18176
18177// ============================================================================
18178// Body Size Limit Enforcement Tests (bd-dl14)
18179// ============================================================================
18180
18181#[cfg(test)]
18182mod body_size_limit_tests {
18183    use super::*;
18184    use crate::request::{Body, Method};
18185    use crate::response::{ResponseBody, StatusCode};
18186
18187    fn test_context() -> RequestContext {
18188        let cx = asupersync::Cx::for_testing();
18189        RequestContext::new(cx, 1)
18190    }
18191
18192    fn test_context_with_limit(limit: usize) -> RequestContext {
18193        let cx = asupersync::Cx::for_testing();
18194        RequestContext::with_body_limit(cx, 1, limit)
18195    }
18196
18197    // ---- Default limit constants ----
18198
18199    #[test]
18200    fn default_constants_match_expected_values() {
18201        assert_eq!(DEFAULT_JSON_LIMIT, 1024 * 1024); // 1MB
18202        assert_eq!(DEFAULT_FORM_LIMIT, 1024 * 1024); // 1MB
18203        assert_eq!(DEFAULT_RAW_BODY_LIMIT, 2 * 1024 * 1024); // 2MB
18204        assert_eq!(crate::DEFAULT_MAX_BODY_SIZE, 1024 * 1024); // 1MB
18205    }
18206
18207    // ---- JSON body limit tests ----
18208
18209    #[test]
18210    fn json_body_under_limit_accepted() {
18211        use serde::Deserialize;
18212
18213        #[derive(Deserialize, Debug)]
18214        struct Msg {
18215            text: String,
18216        }
18217
18218        let ctx = test_context();
18219        let mut req = Request::new(Method::Post, "/api");
18220        req.headers_mut()
18221            .insert("content-type", b"application/json".to_vec());
18222        req.set_body(Body::Bytes(b"{\"text\":\"hello\"}".to_vec()));
18223
18224        let result = futures_executor::block_on(Json::<Msg>::from_request(&ctx, &mut req));
18225        assert!(result.is_ok());
18226        assert_eq!(result.unwrap().0.text, "hello");
18227    }
18228
18229    #[test]
18230    fn json_body_over_default_limit_rejected() {
18231        use serde::Deserialize;
18232
18233        #[derive(Deserialize)]
18234        #[allow(dead_code)]
18235        struct Data {
18236            content: String,
18237        }
18238
18239        let ctx = test_context();
18240        let mut req = Request::new(Method::Post, "/api");
18241        req.headers_mut()
18242            .insert("content-type", b"application/json".to_vec());
18243
18244        // Body exceeds the 1MB default limit
18245        let large = "x".repeat(crate::DEFAULT_MAX_BODY_SIZE + 1);
18246        let body = format!("{{\"content\":\"{}\"}}", large);
18247        req.set_body(Body::Bytes(body.into_bytes()));
18248
18249        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18250        assert!(matches!(
18251            result,
18252            Err(JsonExtractError::PayloadTooLarge { .. })
18253        ));
18254    }
18255
18256    #[test]
18257    fn json_body_exactly_at_limit_accepted() {
18258        use serde::Deserialize;
18259
18260        #[derive(Deserialize)]
18261        #[allow(dead_code)]
18262        struct Data {
18263            content: String,
18264        }
18265
18266        // Use a custom small limit for easier testing
18267        let ctx = test_context_with_limit(100);
18268        let mut req = Request::new(Method::Post, "/api");
18269        req.headers_mut()
18270            .insert("content-type", b"application/json".to_vec());
18271
18272        // Build a body that is exactly 100 bytes
18273        // {"content":"xxx..."} where content is padded to make total = 100
18274        let prefix = b"{\"content\":\"";
18275        let suffix = b"\"}";
18276        let content_len = 100 - prefix.len() - suffix.len();
18277        let content: String = "a".repeat(content_len);
18278        let body = format!("{{\"content\":\"{}\"}}", content);
18279        assert_eq!(body.len(), 100);
18280
18281        req.set_body(Body::Bytes(body.into_bytes()));
18282        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18283        assert!(result.is_ok(), "Body exactly at limit should be accepted");
18284    }
18285
18286    #[test]
18287    fn json_body_one_byte_over_limit_rejected() {
18288        use serde::Deserialize;
18289
18290        #[derive(Deserialize, Debug)]
18291        #[allow(dead_code)]
18292        struct Data {
18293            content: String,
18294        }
18295
18296        let ctx = test_context_with_limit(100);
18297        let mut req = Request::new(Method::Post, "/api");
18298        req.headers_mut()
18299            .insert("content-type", b"application/json".to_vec());
18300
18301        // 101 bytes body
18302        let prefix = b"{\"content\":\"";
18303        let suffix = b"\"}";
18304        let content_len = 101 - prefix.len() - suffix.len();
18305        let content: String = "a".repeat(content_len);
18306        let body = format!("{{\"content\":\"{}\"}}", content);
18307        assert_eq!(body.len(), 101);
18308
18309        req.set_body(Body::Bytes(body.into_bytes()));
18310        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18311        match result {
18312            Err(JsonExtractError::PayloadTooLarge { size, limit }) => {
18313                assert_eq!(size, 101);
18314                assert_eq!(limit, 100);
18315            }
18316            other => panic!("Expected PayloadTooLarge, got {:?}", other),
18317        }
18318    }
18319
18320    #[test]
18321    fn json_custom_body_limit_via_context() {
18322        use serde::Deserialize;
18323
18324        #[derive(Deserialize)]
18325        #[allow(dead_code)]
18326        struct Data {
18327            val: String,
18328        }
18329
18330        // Small limit: 50 bytes
18331        let ctx = test_context_with_limit(50);
18332        let mut req = Request::new(Method::Post, "/api");
18333        req.headers_mut()
18334            .insert("content-type", b"application/json".to_vec());
18335
18336        // Body exceeding 50 bytes
18337        let padding = "x".repeat(60);
18338        let body = format!("{{\"val\":\"{}\"}}", padding);
18339        assert!(
18340            body.len() > 50,
18341            "Body is {} bytes, expected > 50",
18342            body.len()
18343        );
18344        req.set_body(Body::Bytes(body.into_bytes()));
18345
18346        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18347        assert!(matches!(
18348            result,
18349            Err(JsonExtractError::PayloadTooLarge { .. })
18350        ));
18351    }
18352
18353    #[test]
18354    fn json_large_custom_limit_accepts_big_body() {
18355        use serde::Deserialize;
18356
18357        #[derive(Deserialize, Debug)]
18358        #[allow(dead_code)]
18359        struct Data {
18360            content: String,
18361        }
18362
18363        // 10MB limit
18364        let ctx = test_context_with_limit(10 * 1024 * 1024);
18365        let mut req = Request::new(Method::Post, "/api");
18366        req.headers_mut()
18367            .insert("content-type", b"application/json".to_vec());
18368
18369        // 2MB body (under the 10MB custom limit)
18370        let large = "x".repeat(2 * 1024 * 1024);
18371        let body = format!("{{\"content\":\"{}\"}}", large);
18372        req.set_body(Body::Bytes(body.into_bytes()));
18373
18374        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18375        assert!(result.is_ok(), "Body under custom limit should be accepted");
18376    }
18377
18378    #[test]
18379    fn json_empty_body_accepted() {
18380        use serde::Deserialize;
18381
18382        #[derive(Deserialize, Debug)]
18383        #[allow(dead_code)]
18384        struct Data {
18385            #[serde(default)]
18386            val: i32,
18387        }
18388
18389        let ctx = test_context_with_limit(10);
18390        let mut req = Request::new(Method::Post, "/api");
18391        req.headers_mut()
18392            .insert("content-type", b"application/json".to_vec());
18393        req.set_body(Body::Empty);
18394
18395        // Empty body parses to empty bytes, under any limit
18396        // But it will fail deserialization (not a size issue)
18397        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18398        // Empty body means 0 bytes which is under limit, but deserialization fails
18399        match result {
18400            Err(JsonExtractError::DeserializeError { .. }) => {} // expected
18401            Err(JsonExtractError::PayloadTooLarge { .. }) => {
18402                panic!("Empty body should not trigger size limit")
18403            }
18404            Ok(_) => {} // Also fine if serde accepts empty
18405            other => panic!("Unexpected result: {:?}", other),
18406        }
18407    }
18408
18409    #[test]
18410    fn json_payload_too_large_error_response_is_413() {
18411        let err = JsonExtractError::PayloadTooLarge {
18412            size: 2_000_000,
18413            limit: 1_000_000,
18414        };
18415        let response = err.into_response();
18416        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
18417        assert_eq!(response.status().as_u16(), 413);
18418    }
18419
18420    // ---- Form body limit tests ----
18421
18422    #[test]
18423    fn form_body_under_limit_accepted() {
18424        use serde::Deserialize;
18425
18426        #[derive(Deserialize, Debug)]
18427        struct Login {
18428            user: String,
18429        }
18430
18431        let ctx = test_context();
18432        let mut req = Request::new(Method::Post, "/login");
18433        req.headers_mut().insert(
18434            "content-type",
18435            b"application/x-www-form-urlencoded".to_vec(),
18436        );
18437        req.set_body(Body::Bytes(b"user=alice".to_vec()));
18438
18439        let result = futures_executor::block_on(Form::<Login>::from_request(&ctx, &mut req));
18440        assert!(result.is_ok());
18441        assert_eq!(result.unwrap().0.user, "alice");
18442    }
18443
18444    #[test]
18445    fn form_body_over_limit_rejected() {
18446        use serde::Deserialize;
18447
18448        #[derive(Deserialize)]
18449        #[allow(dead_code)]
18450        struct Data {
18451            field: String,
18452        }
18453
18454        let ctx = test_context(); // default 1MB limit
18455        let mut req = Request::new(Method::Post, "/submit");
18456        req.headers_mut().insert(
18457            "content-type",
18458            b"application/x-www-form-urlencoded".to_vec(),
18459        );
18460
18461        // Build form body larger than 1MB
18462        let large_value = "x".repeat(crate::DEFAULT_MAX_BODY_SIZE + 1);
18463        let body = format!("field={}", large_value);
18464        req.set_body(Body::Bytes(body.into_bytes()));
18465
18466        let result = futures_executor::block_on(Form::<Data>::from_request(&ctx, &mut req));
18467        assert!(matches!(
18468            result,
18469            Err(FormExtractError::PayloadTooLarge { .. })
18470        ));
18471    }
18472
18473    #[test]
18474    fn form_custom_limit_via_context() {
18475        use serde::Deserialize;
18476
18477        #[derive(Deserialize, Debug)]
18478        #[allow(dead_code)]
18479        struct Data {
18480            field: String,
18481        }
18482
18483        let ctx = test_context_with_limit(20);
18484        let mut req = Request::new(Method::Post, "/submit");
18485        req.headers_mut().insert(
18486            "content-type",
18487            b"application/x-www-form-urlencoded".to_vec(),
18488        );
18489
18490        // 30-byte body exceeds 20-byte limit
18491        let body = "field=abcdefghijklmnopqrstuv";
18492        assert!(body.len() > 20);
18493        req.set_body(Body::Bytes(body.as_bytes().to_vec()));
18494
18495        let result = futures_executor::block_on(Form::<Data>::from_request(&ctx, &mut req));
18496        match result {
18497            Err(FormExtractError::PayloadTooLarge { size, limit }) => {
18498                assert_eq!(limit, 20);
18499                assert!(size > 20);
18500            }
18501            other => panic!("Expected PayloadTooLarge, got {:?}", other),
18502        }
18503    }
18504
18505    #[test]
18506    fn form_payload_too_large_error_response_is_413() {
18507        let err = FormExtractError::PayloadTooLarge {
18508            size: 2_000_000,
18509            limit: 1_000_000,
18510        };
18511        let response = err.into_response();
18512        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
18513    }
18514
18515    // ---- Raw body (Bytes) limit tests ----
18516
18517    #[test]
18518    fn bytes_body_under_limit_accepted() {
18519        let ctx = test_context();
18520        let mut req = Request::new(Method::Post, "/upload");
18521        req.set_body(Body::Bytes(b"small payload".to_vec()));
18522
18523        let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18524        assert!(result.is_ok());
18525        assert_eq!(result.unwrap().as_slice(), b"small payload");
18526    }
18527
18528    #[test]
18529    fn bytes_body_over_default_limit_rejected() {
18530        let ctx = test_context();
18531        let mut req = Request::new(Method::Post, "/upload");
18532        let large_body = vec![0u8; DEFAULT_RAW_BODY_LIMIT + 1];
18533        req.set_body(Body::Bytes(large_body));
18534
18535        let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18536        match result {
18537            Err(RawBodyError::PayloadTooLarge { size, limit }) => {
18538                assert_eq!(size, DEFAULT_RAW_BODY_LIMIT + 1);
18539                assert_eq!(limit, DEFAULT_RAW_BODY_LIMIT);
18540            }
18541            other => panic!("Expected PayloadTooLarge, got {:?}", other),
18542        }
18543    }
18544
18545    #[test]
18546    fn bytes_custom_limit_via_extension() {
18547        let ctx = test_context();
18548        let mut req = Request::new(Method::Post, "/upload");
18549        req.insert_extension(RawBodyConfig::new().limit(50));
18550        req.set_body(Body::Bytes(vec![0u8; 80]));
18551
18552        let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18553        match result {
18554            Err(RawBodyError::PayloadTooLarge { size, limit }) => {
18555                assert_eq!(size, 80);
18556                assert_eq!(limit, 50);
18557            }
18558            other => panic!("Expected PayloadTooLarge, got {:?}", other),
18559        }
18560    }
18561
18562    #[test]
18563    fn bytes_custom_limit_accepts_body_under() {
18564        let ctx = test_context();
18565        let mut req = Request::new(Method::Post, "/upload");
18566        req.insert_extension(RawBodyConfig::new().limit(200));
18567        req.set_body(Body::Bytes(vec![0u8; 100]));
18568
18569        let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18570        assert!(result.is_ok());
18571        assert_eq!(result.unwrap().len(), 100);
18572    }
18573
18574    #[test]
18575    fn bytes_empty_body_always_accepted() {
18576        let ctx = test_context();
18577        let mut req = Request::new(Method::Post, "/upload");
18578        req.insert_extension(RawBodyConfig::new().limit(0));
18579        req.set_body(Body::Empty);
18580
18581        let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18582        assert!(result.is_ok());
18583        assert!(result.unwrap().is_empty());
18584    }
18585
18586    #[test]
18587    fn bytes_payload_too_large_error_response_is_413() {
18588        let err = RawBodyError::PayloadTooLarge {
18589            size: 5_000_000,
18590            limit: 2_000_000,
18591        };
18592        let response = err.into_response();
18593        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
18594    }
18595
18596    // ---- StringBody limit tests ----
18597
18598    #[test]
18599    fn string_body_over_limit_rejected() {
18600        let ctx = test_context();
18601        let mut req = Request::new(Method::Post, "/text");
18602        req.insert_extension(RawBodyConfig::new().limit(10));
18603        req.set_body(Body::Bytes(b"this is longer than ten bytes".to_vec()));
18604
18605        let result = futures_executor::block_on(StringBody::from_request(&ctx, &mut req));
18606        assert!(matches!(result, Err(RawBodyError::PayloadTooLarge { .. })));
18607    }
18608
18609    #[test]
18610    fn string_body_under_limit_accepted() {
18611        let ctx = test_context();
18612        let mut req = Request::new(Method::Post, "/text");
18613        req.insert_extension(RawBodyConfig::new().limit(100));
18614        req.set_body(Body::Bytes(b"short".to_vec()));
18615
18616        let result = futures_executor::block_on(StringBody::from_request(&ctx, &mut req));
18617        assert!(result.is_ok());
18618        assert_eq!(result.unwrap().as_str(), "short");
18619    }
18620
18621    // ---- RawBodyConfig tests ----
18622
18623    #[test]
18624    fn raw_body_config_default() {
18625        let config = RawBodyConfig::default();
18626        assert_eq!(config.get_limit(), DEFAULT_RAW_BODY_LIMIT);
18627        assert_eq!(config.get_limit(), 2 * 1024 * 1024);
18628    }
18629
18630    #[test]
18631    fn raw_body_config_builder() {
18632        let config = RawBodyConfig::new().limit(500);
18633        assert_eq!(config.get_limit(), 500);
18634    }
18635
18636    // ---- JsonConfig tests ----
18637
18638    #[test]
18639    fn json_config_default() {
18640        let config = JsonConfig::default();
18641        assert_eq!(config.get_limit(), DEFAULT_JSON_LIMIT);
18642    }
18643
18644    #[test]
18645    fn json_config_builder() {
18646        let config = JsonConfig::new().limit(2048);
18647        assert_eq!(config.get_limit(), 2048);
18648    }
18649
18650    // ---- FormConfig tests ----
18651
18652    #[test]
18653    fn form_config_default() {
18654        let config = FormConfig::default();
18655        assert_eq!(config.get_limit(), DEFAULT_FORM_LIMIT);
18656    }
18657
18658    #[test]
18659    fn form_config_builder() {
18660        let config = FormConfig::new().limit(4096);
18661        assert_eq!(config.get_limit(), 4096);
18662    }
18663
18664    // ---- Cross-extractor limit behavior ----
18665
18666    #[test]
18667    fn json_uses_context_body_limit_not_json_config() {
18668        // Verify Json extractor uses ctx.max_body_size(), not JsonConfig.limit
18669        use serde::Deserialize;
18670
18671        #[derive(Deserialize)]
18672        #[allow(dead_code)]
18673        struct Data {
18674            val: String,
18675        }
18676
18677        // Context with 50 byte limit
18678        let ctx = test_context_with_limit(50);
18679        let mut req = Request::new(Method::Post, "/api");
18680        req.headers_mut()
18681            .insert("content-type", b"application/json".to_vec());
18682
18683        // Body of 60 bytes
18684        let body = "{\"val\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\"}";
18685        assert!(body.len() > 50);
18686        req.set_body(Body::Bytes(body.as_bytes().to_vec()));
18687
18688        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18689        assert!(
18690            matches!(result, Err(JsonExtractError::PayloadTooLarge { .. })),
18691            "Json should use context body limit (50), not default"
18692        );
18693    }
18694
18695    #[test]
18696    fn form_uses_context_body_limit() {
18697        use serde::Deserialize;
18698
18699        #[derive(Deserialize)]
18700        #[allow(dead_code)]
18701        struct Data {
18702            val: String,
18703        }
18704
18705        // Context with 30 byte limit
18706        let ctx = test_context_with_limit(30);
18707        let mut req = Request::new(Method::Post, "/form");
18708        req.headers_mut().insert(
18709            "content-type",
18710            b"application/x-www-form-urlencoded".to_vec(),
18711        );
18712
18713        // Body of 35 bytes
18714        let body = "val=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
18715        assert!(body.len() > 30);
18716        req.set_body(Body::Bytes(body.as_bytes().to_vec()));
18717
18718        let result = futures_executor::block_on(Form::<Data>::from_request(&ctx, &mut req));
18719        assert!(
18720            matches!(result, Err(FormExtractError::PayloadTooLarge { .. })),
18721            "Form should use context body limit"
18722        );
18723    }
18724
18725    #[test]
18726    fn bytes_uses_extension_config_not_context_limit() {
18727        // Bytes uses RawBodyConfig extension, not ctx.max_body_size()
18728        let ctx = test_context_with_limit(10); // Small context limit
18729        let mut req = Request::new(Method::Post, "/upload");
18730        // RawBodyConfig with a large limit
18731        req.insert_extension(RawBodyConfig::new().limit(1000));
18732        req.set_body(Body::Bytes(vec![0u8; 500]));
18733
18734        let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18735        // Should use extension config (1000), not context limit (10)
18736        assert!(
18737            result.is_ok(),
18738            "Bytes should use RawBodyConfig from extension, not context limit"
18739        );
18740    }
18741
18742    #[test]
18743    fn bytes_without_config_uses_default_raw_limit() {
18744        // When no RawBodyConfig extension, Bytes uses DEFAULT_RAW_BODY_LIMIT
18745        let ctx = test_context_with_limit(10); // Small context limit
18746        let mut req = Request::new(Method::Post, "/upload");
18747        // No extension set, body under DEFAULT_RAW_BODY_LIMIT (2MB)
18748        req.set_body(Body::Bytes(vec![0u8; 500]));
18749
18750        let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18751        // Uses DEFAULT_RAW_BODY_LIMIT (2MB), not context limit (10)
18752        assert!(result.is_ok());
18753    }
18754
18755    // ---- Error detail tests ----
18756
18757    #[test]
18758    fn json_error_response_body_contains_size_info() {
18759        let err = JsonExtractError::PayloadTooLarge {
18760            size: 2_500_000,
18761            limit: 1_048_576,
18762        };
18763        let response = err.into_response();
18764        if let ResponseBody::Bytes(bytes) = response.body_ref() {
18765            let body = std::str::from_utf8(bytes).unwrap();
18766            assert!(
18767                body.contains("2500000"),
18768                "Error should contain actual size, got: {}",
18769                body
18770            );
18771            assert!(
18772                body.contains("1048576"),
18773                "Error should contain limit, got: {}",
18774                body
18775            );
18776        } else {
18777            panic!("Expected Bytes body");
18778        }
18779    }
18780
18781    #[test]
18782    fn form_error_response_body_contains_size_info() {
18783        let err = FormExtractError::PayloadTooLarge {
18784            size: 3_000_000,
18785            limit: 1_048_576,
18786        };
18787        let response = err.into_response();
18788        if let ResponseBody::Bytes(bytes) = response.body_ref() {
18789            let body = std::str::from_utf8(bytes).unwrap();
18790            assert!(
18791                body.contains("3000000"),
18792                "Error should contain actual size, got: {}",
18793                body
18794            );
18795            assert!(
18796                body.contains("1048576"),
18797                "Error should contain limit, got: {}",
18798                body
18799            );
18800        } else {
18801            panic!("Expected Bytes body");
18802        }
18803    }
18804
18805    #[test]
18806    fn raw_body_error_response_contains_size_info() {
18807        let err = RawBodyError::PayloadTooLarge {
18808            size: 5_000_000,
18809            limit: 2_097_152,
18810        };
18811        let response = err.into_response();
18812        if let ResponseBody::Bytes(bytes) = response.body_ref() {
18813            let body = std::str::from_utf8(bytes).unwrap();
18814            assert!(
18815                body.contains("5000000"),
18816                "Error should contain actual size, got: {}",
18817                body
18818            );
18819            assert!(
18820                body.contains("2097152"),
18821                "Error should contain limit, got: {}",
18822                body
18823            );
18824        } else {
18825            panic!("Expected Bytes body");
18826        }
18827    }
18828
18829    // ---- Streaming body rejection tests ----
18830
18831    #[test]
18832    fn json_streaming_body_rejected() {
18833        use serde::Deserialize;
18834
18835        #[derive(Deserialize)]
18836        #[allow(dead_code)]
18837        struct Data {
18838            val: i32,
18839        }
18840
18841        let ctx = test_context();
18842        let mut req = Request::new(Method::Post, "/api");
18843        req.headers_mut()
18844            .insert("content-type", b"application/json".to_vec());
18845
18846        let stream = asupersync::stream::iter(
18847            vec![Ok(b"chunk".to_vec())]
18848                .into_iter()
18849                .map(|r: Result<Vec<u8>, crate::request::RequestBodyStreamError>| r),
18850        );
18851        req.set_body(Body::streaming(stream));
18852
18853        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18854        assert!(matches!(
18855            result,
18856            Err(JsonExtractError::StreamingNotSupported)
18857        ));
18858    }
18859
18860    #[test]
18861    fn form_streaming_body_rejected() {
18862        use serde::Deserialize;
18863
18864        #[derive(Deserialize)]
18865        #[allow(dead_code)]
18866        struct Data {
18867            field: String,
18868        }
18869
18870        let ctx = test_context();
18871        let mut req = Request::new(Method::Post, "/form");
18872        req.headers_mut().insert(
18873            "content-type",
18874            b"application/x-www-form-urlencoded".to_vec(),
18875        );
18876
18877        let stream = asupersync::stream::iter(
18878            vec![Ok(b"chunk".to_vec())]
18879                .into_iter()
18880                .map(|r: Result<Vec<u8>, crate::request::RequestBodyStreamError>| r),
18881        );
18882        req.set_body(Body::streaming(stream));
18883
18884        let result = futures_executor::block_on(Form::<Data>::from_request(&ctx, &mut req));
18885        assert!(matches!(
18886            result,
18887            Err(FormExtractError::StreamingNotSupported)
18888        ));
18889    }
18890
18891    #[test]
18892    fn bytes_streaming_body_rejected() {
18893        let ctx = test_context();
18894        let mut req = Request::new(Method::Post, "/upload");
18895
18896        let stream = asupersync::stream::iter(
18897            vec![Ok(b"chunk".to_vec())]
18898                .into_iter()
18899                .map(|r: Result<Vec<u8>, crate::request::RequestBodyStreamError>| r),
18900        );
18901        req.set_body(Body::streaming(stream));
18902
18903        let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18904        assert!(matches!(result, Err(RawBodyError::StreamingNotSupported)));
18905    }
18906
18907    // ====================================================================
18908    // Digest Auth Tests
18909    // ====================================================================
18910
18911    #[test]
18912    fn digest_auth_extraction() {
18913        let ctx = test_context();
18914        let mut req = Request::new(Method::Get, "/protected");
18915        req.headers_mut().insert(
18916            "authorization",
18917            b"Digest username=\"alice\", realm=\"test\", nonce=\"abc123\"".to_vec(),
18918        );
18919        let result = futures_executor::block_on(DigestAuth::from_request(&ctx, &mut req));
18920        let auth = result.unwrap();
18921        assert!(auth.credentials().contains("username=\"alice\""));
18922    }
18923
18924    #[test]
18925    fn digest_auth_param_extraction() {
18926        let auth = DigestAuth::new("username=\"alice\", realm=\"test\", nonce=\"abc123\"");
18927        assert_eq!(auth.param("username"), Some("alice"));
18928        assert_eq!(auth.param("realm"), Some("test"));
18929        assert_eq!(auth.param("nonce"), Some("abc123"));
18930        assert_eq!(auth.param("nonexistent"), None);
18931    }
18932
18933    #[test]
18934    fn digest_auth_param_no_substring_match() {
18935        // Regression test: param("name") should NOT match inside "username"
18936        let auth = DigestAuth::new("username=\"alice\", realm=\"test\", nonce=\"abc123\"");
18937
18938        // "name" is a substring of "username" but should not be extracted
18939        assert_eq!(auth.param("name"), None);
18940
18941        // "realm" should still work correctly
18942        assert_eq!(auth.param("realm"), Some("test"));
18943
18944        // Single-char substrings should also not match incorrectly
18945        assert_eq!(auth.param("e"), None); // "e" is in "username", "realm"
18946        assert_eq!(auth.param("c"), None); // "c" is in "nonce"
18947    }
18948
18949    #[test]
18950    fn digest_auth_param_unquoted_values() {
18951        // Some Digest parameters like qop, nc can be unquoted
18952        let auth = DigestAuth::new("qop=auth, nc=00000001, cnonce=\"xyz\"");
18953        assert_eq!(auth.param("qop"), Some("auth"));
18954        assert_eq!(auth.param("nc"), Some("00000001"));
18955        assert_eq!(auth.param("cnonce"), Some("xyz"));
18956
18957        // "c" should NOT match "nc=00000001"
18958        assert_eq!(auth.param("c"), None);
18959    }
18960
18961    #[test]
18962    fn digest_auth_param_at_start() {
18963        // Parameter at the start of the credentials string
18964        let auth = DigestAuth::new("realm=\"test\", username=\"bob\"");
18965        assert_eq!(auth.param("realm"), Some("test"));
18966        assert_eq!(auth.param("username"), Some("bob"));
18967    }
18968
18969    #[test]
18970    fn digest_auth_missing_header() {
18971        let ctx = test_context();
18972        let mut req = Request::new(Method::Get, "/protected");
18973        let result = futures_executor::block_on(DigestAuth::from_request(&ctx, &mut req));
18974        assert!(matches!(result, Err(DigestAuthError::MissingHeader)));
18975    }
18976
18977    #[test]
18978    fn digest_auth_wrong_scheme() {
18979        let ctx = test_context();
18980        let mut req = Request::new(Method::Get, "/protected");
18981        req.headers_mut()
18982            .insert("authorization", b"Bearer token123".to_vec());
18983        let result = futures_executor::block_on(DigestAuth::from_request(&ctx, &mut req));
18984        assert!(matches!(result, Err(DigestAuthError::InvalidScheme)));
18985    }
18986
18987    #[test]
18988    fn digest_auth_error_response_401() {
18989        let resp = DigestAuthError::MissingHeader.into_response();
18990        assert_eq!(resp.status().as_u16(), 401);
18991        let has_www_auth = resp
18992            .headers()
18993            .iter()
18994            .any(|(n, v)| n == "www-authenticate" && v == b"Digest");
18995        assert!(has_www_auth);
18996    }
18997
18998    #[test]
18999    fn digest_auth_case_insensitive() {
19000        let ctx = test_context();
19001        let mut req = Request::new(Method::Get, "/protected");
19002        req.headers_mut()
19003            .insert("authorization", b"digest username=\"bob\"".to_vec());
19004        let result = futures_executor::block_on(DigestAuth::from_request(&ctx, &mut req));
19005        assert!(result.is_ok());
19006    }
19007}