armature_core/
extractors.rs

1//! Request parameter extractors
2//!
3//! This module provides types for extracting data from HTTP requests
4//! in a type-safe manner, similar to NestJS decorators.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use armature::prelude::*;
10//! use armature_core::extractors::{Body, Query, Path, Header};
11//!
12//! #[derive(Deserialize)]
13//! struct CreateUser {
14//!     name: String,
15//!     email: String,
16//! }
17//!
18//! #[derive(Deserialize)]
19//! struct UserFilters {
20//!     page: Option<u32>,
21//!     limit: Option<u32>,
22//! }
23//!
24//! // Extract body as JSON
25//! let body: Body<CreateUser> = Body::from_request(&request)?;
26//!
27//! // Extract query parameters
28//! let query: Query<UserFilters> = Query::from_request(&request)?;
29//!
30//! // Extract path parameter
31//! let id: Path<u32> = Path::from_request(&request, "id")?;
32//!
33//! // Extract header
34//! let auth: Header = Header::from_request(&request, "Authorization")?;
35//! ```
36
37use crate::{Error, HttpRequest};
38use serde::de::DeserializeOwned;
39use std::ops::Deref;
40use std::sync::Arc;
41
42/// Trait for extracting data from an HTTP request
43pub trait FromRequest: Sized {
44    /// Extract data from the request
45    fn from_request(request: &HttpRequest) -> Result<Self, Error>;
46}
47
48// ========== State Extractor ==========
49
/// Shared application-state extractor.
///
/// `State<T>` wraps an `Arc<T>` that the application stored in the request
/// extensions. It dereferences to `T`, and cloning it only bumps the `Arc`
/// reference count (`T` itself does not have to be `Clone`), so handlers can
/// hold and pass it around cheaply.
///
/// # Lookup
///
/// The [`FromRequest`] impl fetches the value with
/// `request.extensions.get_arc::<T>()`, i.e. keyed by the type `T`. The cost
/// and mechanism of that lookup belong to the extensions container.
/// NOTE(review): earlier docs described it as an O(1) `TypeId` lookup with no
/// runtime type check — confirm against `get_arc` before relying on that.
///
/// # Example
///
/// ```rust,ignore
/// use armature_core::extractors::State;
/// use std::sync::Arc;
///
/// // Define your application state
/// #[derive(Clone)]
/// struct AppState {
///     db_pool: Pool,
///     config: AppConfig,
/// }
///
/// // Register the state once at startup
/// let state = Arc::new(AppState { db_pool, config });
/// app.with_state(state);
///
/// // Extract it in a handler
/// #[get("/users")]
/// async fn list_users(state: State<AppState>) -> Result<HttpResponse, Error> {
///     let users = state.db_pool.query("SELECT * FROM users").await?;
///     HttpResponse::json(&users)
/// }
/// ```
///
/// # Notes
///
/// - `T` must be `Send + Sync + 'static`.
/// - The value is held behind an `Arc`, so clones share one allocation.
/// - Several distinct state types can be registered side by side.
#[derive(Debug)]
pub struct State<T: Send + Sync + 'static>(pub Arc<T>);

impl<T: Send + Sync + 'static> State<T> {
    /// Wrap an already shared value.
    #[inline]
    pub fn new(value: Arc<T>) -> Self {
        State(value)
    }

    /// Unwrap into the underlying `Arc<T>`.
    #[inline]
    pub fn into_inner(self) -> Arc<T> {
        let State(shared) = self;
        shared
    }
}

impl<T: Send + Sync + 'static> Clone for State<T> {
    /// Produce another handle to the same `T` (reference-count bump only).
    #[inline]
    fn clone(&self) -> Self {
        State::new(Arc::clone(&self.0))
    }
}

impl<T: Send + Sync + 'static> Deref for State<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &T {
        &*self.0
    }
}

impl<T: Send + Sync + 'static> AsRef<T> for State<T> {
    #[inline]
    fn as_ref(&self) -> &T {
        // Reuse the `Deref` impl above: `**self` is the shared `T`.
        &**self
    }
}
130
131impl<T: Send + Sync + 'static> FromRequest for State<T> {
132    /// Extract state from request extensions.
133    ///
134    /// # Errors
135    ///
136    /// Returns `Error::ProviderNotFound` if state of type `T` was not
137    /// registered in the application.
138    #[inline]
139    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
140        request.extensions.get_arc::<T>().map(State).ok_or_else(|| {
141            Error::ProviderNotFound(format!(
142                "State<{}> not found in request extensions. \
143                     Did you forget to register it with `app.with_state()`?",
144                std::any::type_name::<T>()
145            ))
146        })
147    }
148}
149
150/// Trait for extracting named parameters from a request
151pub trait FromRequestNamed: Sized {
152    /// Extract a named parameter from the request
153    fn from_request(request: &HttpRequest, name: &str) -> Result<Self, Error>;
154}
155
156// ========== Body Extractor ==========
157
/// JSON request-body extractor.
///
/// Holds the value produced by deserializing the request body and
/// dereferences to it.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct CreateUser {
///     name: String,
///     email: String,
/// }
///
/// let body: Body<CreateUser> = Body::from_request(&request)?;
/// println!("Creating user: {}", body.name);
/// ```
#[derive(Debug, Clone)]
pub struct Body<T>(pub T);

impl<T> Body<T> {
    /// Wrap an already deserialized value.
    pub fn new(value: T) -> Self {
        Body(value)
    }

    /// Unwrap and return the deserialized value.
    pub fn into_inner(self) -> T {
        let Body(inner) = self;
        inner
    }
}

impl<T> Deref for Body<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Body(inner) = self;
        inner
    }
}
194
195impl<T: DeserializeOwned> FromRequest for Body<T> {
196    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
197        let value: T = request.json()?;
198        Ok(Body(value))
199    }
200}
201
202// ========== Query Extractor ==========
203
/// Query-string extractor.
///
/// Holds the query parameters deserialized into `T` and dereferences to it.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: Option<u32>,
///     limit: Option<u32>,
///     sort: Option<String>,
/// }
///
/// let query: Query<Pagination> = Query::from_request(&request)?;
/// let page = query.page.unwrap_or(1);
/// ```
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Wrap an already deserialized value.
    pub fn new(value: T) -> Self {
        Query(value)
    }

    /// Unwrap and return the deserialized value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }
}

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Query(inner) = self;
        inner
    }
}
241
242impl<T: DeserializeOwned> FromRequest for Query<T> {
243    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
244        // Build a query string from params and deserialize
245        let query_string: String = request
246            .query_params
247            .iter()
248            .map(|(k, v)| format!("{}={}", k, v))
249            .collect::<Vec<_>>()
250            .join("&");
251
252        let value: T = serde_urlencoded::from_str(&query_string)
253            .map_err(|e| Error::Validation(format!("Invalid query parameters: {}", e)))?;
254
255        Ok(Query(value))
256    }
257}
258
259// ========== Path Extractor ==========
260
/// Single path-parameter extractor.
///
/// Holds one named path parameter parsed into `T` and dereferences to it.
///
/// # Example
///
/// ```rust,ignore
/// // For route /users/:id
/// let id: Path<u32> = Path::from_request(&request, "id")?;
/// println!("User ID: {}", *id);
/// ```
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Wrap an already parsed value.
    pub fn new(value: T) -> Self {
        Path(value)
    }

    /// Unwrap and return the parsed value.
    pub fn into_inner(self) -> T {
        let Path(inner) = self;
        inner
    }
}

impl<T> Deref for Path<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Path(inner) = self;
        inner
    }
}
292
293impl<T: std::str::FromStr> FromRequestNamed for Path<T>
294where
295    T::Err: std::fmt::Display,
296{
297    fn from_request(request: &HttpRequest, name: &str) -> Result<Self, Error> {
298        let value_str = request
299            .param(name)
300            .ok_or_else(|| Error::Validation(format!("Missing path parameter: {}", name)))?;
301
302        let value: T = value_str.parse().map_err(|e: T::Err| {
303            Error::Validation(format!("Invalid path parameter '{}': {}", name, e))
304        })?;
305
306        Ok(Path(value))
307    }
308}
309
310// ========== PathParams Extractor ==========
311
/// Extractor for every path parameter at once.
///
/// Holds the full set of route parameters deserialized into `T` and
/// dereferences to it.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct UserParams {
///     user_id: u32,
///     post_id: u32,
/// }
///
/// // For route /users/:user_id/posts/:post_id
/// let params: PathParams<UserParams> = PathParams::from_request(&request)?;
/// ```
#[derive(Debug, Clone)]
pub struct PathParams<T>(pub T);

impl<T> PathParams<T> {
    /// Wrap an already deserialized value.
    pub fn new(value: T) -> Self {
        PathParams(value)
    }

    /// Unwrap and return the deserialized value.
    pub fn into_inner(self) -> T {
        let PathParams(inner) = self;
        inner
    }
}

impl<T> Deref for PathParams<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let PathParams(inner) = self;
        inner
    }
}
348
349impl<T: DeserializeOwned> FromRequest for PathParams<T> {
350    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
351        // Build a query string from path params and deserialize
352        let params_string: String = request
353            .path_params
354            .iter()
355            .map(|(k, v)| format!("{}={}", k, v))
356            .collect::<Vec<_>>()
357            .join("&");
358
359        let value: T = serde_urlencoded::from_str(&params_string)
360            .map_err(|e| Error::Validation(format!("Invalid path parameters: {}", e)))?;
361
362        Ok(PathParams(value))
363    }
364}
365
366// ========== Header Extractor ==========
367
368/// Extracts a header value by name
369///
370/// # Example
371///
372/// ```rust,ignore
373/// let auth: Header = Header::from_request(&request, "Authorization")?;
374/// println!("Auth: {}", auth.value());
375///
376/// // Or as optional
377/// let custom: Option<Header> = Header::optional(&request, "X-Custom-Header");
378/// ```
379#[derive(Debug, Clone)]
380pub struct Header {
381    name: String,
382    value: String,
383}
384
385impl Header {
386    /// Create a new Header
387    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
388        Self {
389            name: name.into(),
390            value: value.into(),
391        }
392    }
393
394    /// Get the header name
395    pub fn name(&self) -> &str {
396        &self.name
397    }
398
399    /// Get the header value
400    pub fn value(&self) -> &str {
401        &self.value
402    }
403
404    /// Get the header value, consuming self
405    pub fn into_value(self) -> String {
406        self.value
407    }
408
409    /// Extract a header, returning None if not present
410    pub fn optional(request: &HttpRequest, name: &str) -> Option<Self> {
411        request
412            .headers
413            .get(name)
414            .or_else(|| request.headers.get(&name.to_lowercase()))
415            .map(|v| Header::new(name, v.clone()))
416    }
417}
418
419impl FromRequestNamed for Header {
420    fn from_request(request: &HttpRequest, name: &str) -> Result<Self, Error> {
421        let value = request
422            .headers
423            .get(name)
424            .or_else(|| request.headers.get(&name.to_lowercase()))
425            .ok_or_else(|| Error::Validation(format!("Missing header: {}", name)))?;
426
427        Ok(Header::new(name, value.clone()))
428    }
429}
430
431impl Deref for Header {
432    type Target = str;
433
434    fn deref(&self) -> &Self::Target {
435        &self.value
436    }
437}
438
439// ========== Headers Extractor ==========
440
/// Extracts all request headers as a map.
///
/// [`Headers::get`] and [`Headers::contains`] match names case-insensitively,
/// as HTTP requires for field names (RFC 9110 §5.1). Direct map access through
/// `Deref` or the public field remains exact-match.
#[derive(Debug, Clone)]
pub struct Headers(pub std::collections::HashMap<String, String>);

impl Headers {
    /// Get a header value by name, ignoring ASCII case.
    ///
    /// Tries the exact key, then the lowercase key (both O(1)), then falls back
    /// to a linear case-insensitive scan for any other spelling.
    pub fn get(&self, name: &str) -> Option<&String> {
        self.0
            .get(name)
            .or_else(|| self.0.get(&name.to_lowercase()))
            .or_else(|| {
                self.0
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
    }

    /// Check whether a header exists, ignoring ASCII case.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }

    /// Iterate over all headers (names as stored).
    pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
        self.0.iter()
    }
}
463
464impl FromRequest for Headers {
465    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
466        Ok(Headers(request.headers.clone()))
467    }
468}
469
470impl Deref for Headers {
471    type Target = std::collections::HashMap<String, String>;
472
473    fn deref(&self) -> &Self::Target {
474        &self.0
475    }
476}
477
478// ========== RawBody Extractor ==========
479
/// Extracts the raw request body as bytes.
///
/// # Example
///
/// ```rust,ignore
/// let raw: RawBody = RawBody::from_request(&request)?;
/// println!("Body length: {} bytes", raw.len());
/// ```
#[derive(Debug, Clone)]
pub struct RawBody(pub Vec<u8>);

impl RawBody {
    /// Create a new `RawBody` from owned bytes.
    pub fn new(data: Vec<u8>) -> Self {
        Self(data)
    }

    /// Get the body length in bytes.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Check whether the body is empty.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Decode as UTF-8, replacing invalid sequences with U+FFFD.
    pub fn to_string_lossy(&self) -> String {
        // `into_owned` keeps the buffer `from_utf8_lossy` already allocated for
        // invalid input; `Cow::to_string()` would copy it a second time.
        String::from_utf8_lossy(&self.0).into_owned()
    }

    /// Decode as UTF-8.
    ///
    /// # Errors
    ///
    /// Returns the `FromUtf8Error` if the body is not valid UTF-8.
    pub fn to_string(&self) -> Result<String, std::string::FromUtf8Error> {
        String::from_utf8(self.0.clone())
    }

    /// Get the inner bytes, consuming `self`.
    pub fn into_inner(self) -> Vec<u8> {
        self.0
    }
}
522
523impl FromRequest for RawBody {
524    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
525        Ok(RawBody(request.body.clone()))
526    }
527}
528
529impl Deref for RawBody {
530    type Target = [u8];
531
532    fn deref(&self) -> &Self::Target {
533        &self.0
534    }
535}
536
537// ========== Form Extractor ==========
538
/// URL-encoded form extractor (`application/x-www-form-urlencoded`).
///
/// Holds the decoded form deserialized into `T` and dereferences to it.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct LoginForm {
///     username: String,
///     password: String,
/// }
///
/// let form: Form<LoginForm> = Form::from_request(&request)?;
/// ```
#[derive(Debug, Clone)]
pub struct Form<T>(pub T);

impl<T> Form<T> {
    /// Wrap an already deserialized value.
    pub fn new(value: T) -> Self {
        Form(value)
    }

    /// Unwrap and return the deserialized value.
    pub fn into_inner(self) -> T {
        let Form(inner) = self;
        inner
    }
}

impl<T> Deref for Form<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Form(inner) = self;
        inner
    }
}
574
575impl<T: DeserializeOwned> FromRequest for Form<T> {
576    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
577        let value: T = request.form()?;
578        Ok(Form(value))
579    }
580}
581
582// ========== ContentType Extractor ==========
583
/// Extracts the raw `Content-Type` header value (empty when the header is absent).
///
/// The `is_*` helpers look for the media type as a substring of the value.
/// Media types are case-insensitive (RFC 9110 §8.3.1), so the comparison
/// ignores ASCII case. For example, `Application/JSON; charset=UTF-8` is
/// recognized as JSON. The stored value itself is never modified.
#[derive(Debug, Clone)]
pub struct ContentType(pub String);

impl ContentType {
    /// True when the lowercase `needle` occurs in the value, ignoring ASCII case.
    fn contains_ignore_case(&self, needle: &str) -> bool {
        self.0.to_ascii_lowercase().contains(needle)
    }

    /// Check if the content type is JSON.
    pub fn is_json(&self) -> bool {
        self.contains_ignore_case("application/json")
    }

    /// Check if the content type is URL-encoded form data.
    pub fn is_form(&self) -> bool {
        self.contains_ignore_case("application/x-www-form-urlencoded")
    }

    /// Check if the content type is multipart form data.
    pub fn is_multipart(&self) -> bool {
        self.contains_ignore_case("multipart/form-data")
    }

    /// Get the inner value, consuming `self`.
    pub fn into_inner(self) -> String {
        self.0
    }
}
609
610impl FromRequest for ContentType {
611    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
612        let value = request
613            .headers
614            .get("Content-Type")
615            .or_else(|| request.headers.get("content-type"))
616            .cloned()
617            .unwrap_or_default();
618
619        Ok(ContentType(value))
620    }
621}
622
623impl Deref for ContentType {
624    type Target = str;
625
626    fn deref(&self) -> &Self::Target {
627        &self.0
628    }
629}
630
631// ========== Method Extractor ==========
632
/// HTTP request-method extractor.
///
/// Holds the method string exactly as stored on the request. The `is_*`
/// helpers compare it for exact equality with the upper-case method name.
#[derive(Debug, Clone)]
pub struct Method(pub String);

impl Method {
    /// Exact, case-sensitive comparison against a method token.
    fn is_named(&self, token: &str) -> bool {
        self.0.as_str() == token
    }

    /// True for `GET`.
    pub fn is_get(&self) -> bool {
        self.is_named("GET")
    }

    /// True for `POST`.
    pub fn is_post(&self) -> bool {
        self.is_named("POST")
    }

    /// True for `PUT`.
    pub fn is_put(&self) -> bool {
        self.is_named("PUT")
    }

    /// True for `DELETE`.
    pub fn is_delete(&self) -> bool {
        self.is_named("DELETE")
    }

    /// True for `PATCH`.
    pub fn is_patch(&self) -> bool {
        self.is_named("PATCH")
    }
}
663
664impl FromRequest for Method {
665    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
666        Ok(Method(request.method.clone()))
667    }
668}
669
670impl Deref for Method {
671    type Target = str;
672
673    fn deref(&self) -> &Self::Target {
674        &self.0
675    }
676}
677
678// ========== Extension: FromRequest for primitives ==========
679
680impl FromRequest for HttpRequest {
681    fn from_request(request: &HttpRequest) -> Result<Self, Error> {
682        Ok(request.clone())
683    }
684}
685
686// ========== Helper Macros ==========
687
/// Deserialize the request body (as JSON) into `$type`.
///
/// Expands to a `Result<$type, Error>` produced by the [`Body`] extractor.
///
/// # Example
///
/// ```rust,ignore
/// let user: CreateUser = body!(request, CreateUser)?;
/// // The type argument alone is enough; the binding needs no annotation.
/// let user = body!(request, CreateUser)?;
/// ```
#[macro_export]
macro_rules! body {
    ($request:expr, $type:ty) => {
        <$crate::extractors::Body<$type> as $crate::extractors::FromRequest>::from_request(
            &$request,
        )
        .map($crate::extractors::Body::into_inner)
    };
}
706
/// Deserialize the request's query parameters into `$type`.
///
/// Expands to a `Result<$type, Error>` produced by the [`Query`] extractor.
///
/// # Example
///
/// ```rust,ignore
/// let filters = query!(request, UserFilters)?;
/// ```
#[macro_export]
macro_rules! query {
    ($request:expr, $type:ty) => {
        <$crate::extractors::Query<$type> as $crate::extractors::FromRequest>::from_request(
            &$request,
        )
        .map($crate::extractors::Query::into_inner)
    };
}
723
/// Parse the path parameter `$name` into `$type`.
///
/// Expands to a `Result<$type, Error>` produced by the [`Path`] extractor.
///
/// # Example
///
/// ```rust,ignore
/// let id: u32 = path!(request, "id", u32)?;
/// ```
#[macro_export]
macro_rules! path {
    ($request:expr, $name:expr, $type:ty) => {
        <$crate::extractors::Path<$type> as $crate::extractors::FromRequestNamed>::from_request(
            &$request, $name,
        )
        .map($crate::extractors::Path::into_inner)
    };
}
740
/// Read the header `$name` as an owned `String`.
///
/// Expands to a `Result<String, Error>` produced by the [`Header`] extractor.
///
/// # Example
///
/// ```rust,ignore
/// let auth: String = header!(request, "Authorization")?;
/// ```
#[macro_export]
macro_rules! header {
    ($request:expr, $name:expr) => {
        <$crate::extractors::Header as $crate::extractors::FromRequestNamed>::from_request(
            &$request, $name,
        )
        .map($crate::extractors::Header::into_value)
    };
}
757
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// `GET /users/123` carrying one path parameter (`id`), two query
    /// parameters (`page`, `limit`) and two headers; the shared fixture.
    fn create_request() -> HttpRequest {
        let mut request = HttpRequest::new("GET".to_string(), "/users/123".to_string());
        request
            .path_params
            .insert("id".to_string(), "123".to_string());

        for (key, value) in [("page", "1"), ("limit", "10")] {
            request
                .query_params
                .insert(key.to_string(), value.to_string());
        }

        for (key, value) in [
            ("Authorization", "Bearer token123"),
            ("Content-Type", "application/json"),
        ] {
            request.headers.insert(key.to_string(), value.to_string());
        }

        request
    }

    #[test]
    fn test_path_extraction() {
        let id = Path::<u32>::from_request(&create_request(), "id").unwrap();
        assert_eq!(id.into_inner(), 123);
    }

    #[test]
    fn test_path_missing() {
        let outcome = Path::<u32>::from_request(&create_request(), "missing");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_header_extraction() {
        let request = create_request();
        let auth = Header::from_request(&request, "Authorization").unwrap();
        assert_eq!(auth.value(), "Bearer token123");
    }

    #[test]
    fn test_header_optional() {
        let request = create_request();
        assert!(Header::optional(&request, "Authorization").is_some());
        assert!(Header::optional(&request, "X-Missing").is_none());
    }

    #[test]
    fn test_headers_extraction() {
        let headers = Headers::from_request(&create_request()).unwrap();
        for present in ["Authorization", "Content-Type"] {
            assert!(headers.contains(present), "expected header {}", present);
        }
        assert!(!headers.contains("X-Missing"));
    }

    #[test]
    fn test_query_extraction() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Pagination {
            page: u32,
            limit: u32,
        }

        let Query(pagination) = Query::<Pagination>::from_request(&create_request()).unwrap();
        assert_eq!(pagination, Pagination { page: 1, limit: 10 });
    }

    #[test]
    fn test_body_extraction() {
        #[derive(Debug, Deserialize)]
        struct CreateUser {
            name: String,
            email: String,
        }

        let payload = serde_json::json!({
            "name": "Test",
            "email": "test@example.com"
        });
        let mut request = create_request();
        request.body = serde_json::to_vec(&payload).unwrap();

        let user = Body::<CreateUser>::from_request(&request)
            .unwrap()
            .into_inner();
        assert_eq!(user.name, "Test");
        assert_eq!(user.email, "test@example.com");
    }

    #[test]
    fn test_raw_body() {
        let mut request = create_request();
        request.body = b"raw content".to_vec();

        let raw = RawBody::from_request(&request).unwrap();
        assert_eq!(raw.len(), 11);
        assert_eq!(raw.to_string_lossy(), "raw content");
    }

    #[test]
    fn test_content_type() {
        let ct = ContentType::from_request(&create_request()).unwrap();
        assert!(ct.is_json());
        assert!(!(ct.is_form() || ct.is_multipart()));
    }

    #[test]
    fn test_method() {
        let method = Method::from_request(&create_request()).unwrap();
        assert!(method.is_get());
        assert!(!method.is_post());
    }

    #[test]
    fn test_request_extraction() {
        let original = create_request();
        let copy = HttpRequest::from_request(&original).unwrap();
        assert_eq!(copy.method, original.method);
        assert_eq!(copy.path, original.path);
    }
}