axum_acl/
extractor.rs

1//! Role and ID extraction from HTTP requests.
2//!
3//! This module provides traits for extracting user identity from requests:
4//! - [`RoleExtractor`]: Extract roles as a `u32` bitmask (up to 32 roles)
5//! - [`IdExtractor`]: Extract user/resource ID as a `String`
6//!
7//! ## Custom Role Translation
8//!
//! If your system uses a different role scheme (e.g., string roles or enums),
//! implement `RoleExtractor` and translate your roles into a `u32` bitmask.
//! The snippet below shows just the translation step:
11//!
12//! ```
13//! use axum_acl::{RoleExtractor, RoleExtractionResult};
14//! use http::Request;
15//!
16//! // Your role enum
17//! enum Role { Admin, User, Guest }
18//!
19//! // Define bit positions
20//! const ROLE_ADMIN: u32 = 1 << 0;
21//! const ROLE_USER: u32 = 1 << 1;
22//! const ROLE_GUEST: u32 = 1 << 2;
23//!
24//! fn roles_to_mask(roles: &[Role]) -> u32 {
25//!     roles.iter().fold(0u32, |mask, role| {
26//!         mask | match role {
27//!             Role::Admin => ROLE_ADMIN,
28//!             Role::User => ROLE_USER,
29//!             Role::Guest => ROLE_GUEST,
30//!         }
31//!     })
32//! }
33//! ```
34//!
35//! ## Path Parameter ID Matching
36//!
37//! For paths like `/api/boat/{id}/size`, the `{id}` is matched against
38//! the user's ID from `IdExtractor`. This enables ownership-based access:
39//!
40//! ```text
41//! Rule: /api/boat/{id}/**  role=USER, id={id}  -> allow
42//! User: id="boat-123", roles=USER
43//! Path: /api/boat/boat-123/size  -> ALLOWED (id matches)
44//! Path: /api/boat/boat-456/size  -> DENIED (id doesn't match)
45//! ```
46
47use http::Request;
48use std::sync::Arc;
49
/// Outcome of asking a [`RoleExtractor`] for the caller's roles.
#[derive(Debug, Clone)]
pub enum RoleExtractionResult {
    /// The caller's roles, encoded as a `u32` bitmask (one bit per role).
    Roles(u32),
    /// The request carried no role information (anonymous/guest caller).
    Anonymous,
    /// Extraction failed; the payload describes what went wrong.
    Error(String),
}

impl RoleExtractionResult {
    /// Return the extracted bitmask, or `default` when no roles are available.
    ///
    /// Both `Anonymous` and `Error` fall back to `default`, so an extraction
    /// error is treated exactly like an anonymous caller here.
    pub fn roles_or(&self, default: u32) -> u32 {
        if let Self::Roles(mask) = self {
            *mask
        } else {
            default
        }
    }

    /// Return the extracted bitmask, or `0` (no roles at all) otherwise.
    ///
    /// Like `roles_or`, this treats `Error` the same as `Anonymous`.
    pub fn roles_or_none(&self) -> u32 {
        self.roles_or(0)
    }
}
76
77/// Trait for extracting roles from HTTP requests.
78///
79/// Implement this trait to customize how roles are determined from incoming
80/// requests. This allows integration with various authentication systems.
81///
82/// Roles are represented as `u32` bitmasks, allowing multiple roles per user.
83///
84/// The trait is synchronous because role extraction typically involves
85/// reading headers or request extensions, which doesn't require async.
86///
87/// # Example
88/// ```
89/// use axum_acl::{RoleExtractor, RoleExtractionResult};
90/// use http::Request;
91///
92/// const ROLE_ADMIN: u32 = 0b001;
93/// const ROLE_USER: u32 = 0b010;
94///
95/// /// Extract roles from a custom header as a bitmask.
96/// struct CustomRolesExtractor;
97///
98/// impl<B> RoleExtractor<B> for CustomRolesExtractor {
99///     fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
100///         match request.headers().get("X-Roles") {
101///             Some(value) => {
102///                 match value.to_str() {
103///                     Ok(s) => {
104///                         // Parse comma-separated role names to bitmask
105///                         let mut mask = 0u32;
106///                         for role in s.split(',') {
107///                             match role.trim() {
108///                                 "admin" => mask |= ROLE_ADMIN,
109///                                 "user" => mask |= ROLE_USER,
110///                                 _ => {}
111///                             }
112///                         }
113///                         RoleExtractionResult::Roles(mask)
114///                     }
115///                     Err(_) => RoleExtractionResult::Anonymous,
116///                 }
117///             }
118///             None => RoleExtractionResult::Anonymous,
119///         }
120///     }
121/// }
122/// ```
123pub trait RoleExtractor<B>: Send + Sync {
124    /// Extract the roles bitmask from an HTTP request.
125    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult;
126}
127
128// Implement for Arc<T> where T: RoleExtractor
129impl<B, T: RoleExtractor<B>> RoleExtractor<B> for Arc<T> {
130    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
131        (**self).extract_roles(request)
132    }
133}
134
135// Implement for Box<T> where T: RoleExtractor
136impl<B, T: RoleExtractor<B> + ?Sized> RoleExtractor<B> for Box<T> {
137    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
138        (**self).extract_roles(request)
139    }
140}
141
142/// Extract roles bitmask from an HTTP header.
143///
144/// The header value is parsed as a u32 bitmask directly, or you can use
145/// a custom parser function to convert header values to bitmasks.
146///
147/// # Example
148/// ```
149/// use axum_acl::HeaderRoleExtractor;
150///
151/// // Extract roles bitmask directly from X-Roles header (as decimal or hex)
152/// let extractor = HeaderRoleExtractor::new("X-Roles");
153///
154/// // With a custom default roles bitmask for missing headers
155/// let extractor = HeaderRoleExtractor::new("X-Roles")
156///     .with_default_roles(0b100);  // guest role
157/// ```
158#[derive(Debug, Clone)]
159pub struct HeaderRoleExtractor {
160    header_name: String,
161    default_roles: u32,
162}
163
164impl HeaderRoleExtractor {
165    /// Create a new header role extractor.
166    pub fn new(header_name: impl Into<String>) -> Self {
167        Self {
168            header_name: header_name.into(),
169            default_roles: 0,
170        }
171    }
172
173    /// Set default roles bitmask to use when the header is missing.
174    pub fn with_default_roles(mut self, roles: u32) -> Self {
175        self.default_roles = roles;
176        self
177    }
178}
179
180impl<B> RoleExtractor<B> for HeaderRoleExtractor {
181    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
182        match request.headers().get(&self.header_name) {
183            Some(value) => match value.to_str() {
184                Ok(s) if !s.is_empty() => {
185                    // Try parsing as decimal first, then hex (with 0x prefix)
186                    let trimmed = s.trim();
187                    if let Ok(roles) = trimmed.parse::<u32>() {
188                        RoleExtractionResult::Roles(roles)
189                    } else if let Some(hex) = trimmed.strip_prefix("0x") {
190                        u32::from_str_radix(hex, 16)
191                            .map(RoleExtractionResult::Roles)
192                            .unwrap_or_else(|_| {
193                                if self.default_roles != 0 {
194                                    RoleExtractionResult::Roles(self.default_roles)
195                                } else {
196                                    RoleExtractionResult::Anonymous
197                                }
198                            })
199                    } else if self.default_roles != 0 {
200                        RoleExtractionResult::Roles(self.default_roles)
201                    } else {
202                        RoleExtractionResult::Anonymous
203                    }
204                }
205                _ => {
206                    if self.default_roles != 0 {
207                        RoleExtractionResult::Roles(self.default_roles)
208                    } else {
209                        RoleExtractionResult::Anonymous
210                    }
211                }
212            },
213            None => {
214                if self.default_roles != 0 {
215                    RoleExtractionResult::Roles(self.default_roles)
216                } else {
217                    RoleExtractionResult::Anonymous
218                }
219            }
220        }
221    }
222}
223
224/// Extract roles from a request extension.
225///
226/// This extractor looks for roles that were set by a previous middleware
227/// (e.g., an authentication middleware) as a request extension.
228///
229/// # Example
230/// ```
231/// use axum_acl::ExtensionRoleExtractor;
232///
233/// // The authentication middleware should insert a Roles struct into extensions
234/// #[derive(Clone)]
235/// struct UserRoles(u32);
236///
237/// let extractor = ExtensionRoleExtractor::<UserRoles>::new(|roles| roles.0);
238/// ```
239pub struct ExtensionRoleExtractor<T> {
240    extract_fn: Box<dyn Fn(&T) -> u32 + Send + Sync>,
241}
242
243impl<T> ExtensionRoleExtractor<T> {
244    /// Create a new extension role extractor.
245    ///
246    /// The `extract_fn` converts the extension type to a roles bitmask.
247    pub fn new<F>(extract_fn: F) -> Self
248    where
249        F: Fn(&T) -> u32 + Send + Sync + 'static,
250    {
251        Self {
252            extract_fn: Box::new(extract_fn),
253        }
254    }
255}
256
257impl<T> std::fmt::Debug for ExtensionRoleExtractor<T> {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        f.debug_struct("ExtensionRoleExtractor")
260            .field("type", &std::any::type_name::<T>())
261            .finish()
262    }
263}
264
265impl<B, T: Clone + Send + Sync + 'static> RoleExtractor<B> for ExtensionRoleExtractor<T> {
266    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
267        match request.extensions().get::<T>() {
268            Some(ext) => RoleExtractionResult::Roles((self.extract_fn)(ext)),
269            None => RoleExtractionResult::Anonymous,
270        }
271    }
272}
273
274/// A roles extractor that always returns fixed roles.
275///
276/// Useful for testing or for routes that should always use specific roles.
277#[derive(Debug, Clone)]
278pub struct FixedRoleExtractor {
279    roles: u32,
280}
281
282impl FixedRoleExtractor {
283    /// Create a new fixed roles extractor.
284    pub fn new(roles: u32) -> Self {
285        Self { roles }
286    }
287}
288
289impl<B> RoleExtractor<B> for FixedRoleExtractor {
290    fn extract_roles(&self, _request: &Request<B>) -> RoleExtractionResult {
291        RoleExtractionResult::Roles(self.roles)
292    }
293}
294
295/// A role extractor that always returns anonymous (no roles).
296#[derive(Debug, Clone, Default)]
297pub struct AnonymousRoleExtractor;
298
299impl AnonymousRoleExtractor {
300    /// Create a new anonymous role extractor.
301    pub fn new() -> Self {
302        Self
303    }
304}
305
306impl<B> RoleExtractor<B> for AnonymousRoleExtractor {
307    fn extract_roles(&self, _request: &Request<B>) -> RoleExtractionResult {
308        RoleExtractionResult::Anonymous
309    }
310}
311
312/// A composite extractor that tries multiple extractors in order.
313///
314/// Returns the first successful roles extraction, or anonymous if all fail.
315/// Roles from multiple extractors are NOT combined - only the first match is used.
316pub struct ChainedRoleExtractor<B> {
317    extractors: Vec<Box<dyn RoleExtractor<B>>>,
318}
319
320// ============================================================================
321// ID Extraction
322// ============================================================================
323
/// Outcome of asking an [`IdExtractor`] for the caller's (or resource's) ID.
#[derive(Debug, Clone)]
pub enum IdExtractionResult {
    /// The extracted ID.
    Id(String),
    /// The request carried no ID (anonymous caller).
    Anonymous,
    /// Extraction failed; the payload describes what went wrong.
    Error(String),
}

impl IdExtractionResult {
    /// Return the extracted ID, or `default` when none is available.
    ///
    /// Both `Anonymous` and `Error` fall back to `default`.
    pub fn id_or(&self, default: impl Into<String>) -> String {
        if let Self::Id(id) = self {
            id.clone()
        } else {
            default.into()
        }
    }

    /// Return the extracted ID, or the string `"*"` when none is available.
    ///
    /// NOTE(review): whether `"*"` is treated as a wildcard that matches any
    /// ID depends on the rule matcher, which is not in this file. Confirm that
    /// anonymous callers cannot gain access through this fallback.
    pub fn id_or_wildcard(&self) -> String {
        self.id_or("*")
    }
}
350
351/// Trait for extracting user/resource ID from HTTP requests.
352///
353/// Implement this trait to customize how user IDs are determined from
354/// incoming requests. The ID is used for:
355/// - Matching against `{id}` path parameters
356/// - Direct ID matching in ACL rules
357///
358/// # Example: JWT User ID
359/// ```
360/// use axum_acl::{IdExtractor, IdExtractionResult};
361/// use http::Request;
362///
363/// struct JwtIdExtractor;
364///
365/// impl<B> IdExtractor<B> for JwtIdExtractor {
366///     fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
367///         // In practice, you'd decode the JWT and extract the user ID
368///         if let Some(auth) = request.headers().get("Authorization") {
369///             if let Ok(s) = auth.to_str() {
370///                 // Simplified: extract user ID from token
371///                 if s.starts_with("Bearer ") {
372///                     return IdExtractionResult::Id("user-123".to_string());
373///                 }
374///             }
375///         }
376///         IdExtractionResult::Anonymous
377///     }
378/// }
379/// ```
380///
381/// # Example: Path-based Resource ID
382/// ```
383/// use axum_acl::{IdExtractor, IdExtractionResult};
384/// use http::Request;
385///
386/// /// Extract resource ID from path like /api/boat/{id}/...
387/// struct PathIdExtractor {
388///     prefix: String,  // e.g., "/api/boat/"
389/// }
390///
391/// impl<B> IdExtractor<B> for PathIdExtractor {
392///     fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
393///         let path = request.uri().path();
394///         if let Some(rest) = path.strip_prefix(&self.prefix) {
395///             // Get the next path segment as the ID
396///             if let Some(id) = rest.split('/').next() {
397///                 if !id.is_empty() {
398///                     return IdExtractionResult::Id(id.to_string());
399///                 }
400///             }
401///         }
402///         IdExtractionResult::Anonymous
403///     }
404/// }
405/// ```
406pub trait IdExtractor<B>: Send + Sync {
407    /// Extract the user/resource ID from an HTTP request.
408    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult;
409}
410
411// Implement for Arc<T> where T: IdExtractor
412impl<B, T: IdExtractor<B>> IdExtractor<B> for Arc<T> {
413    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
414        (**self).extract_id(request)
415    }
416}
417
418// Implement for Box<T> where T: IdExtractor
419impl<B, T: IdExtractor<B> + ?Sized> IdExtractor<B> for Box<T> {
420    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
421        (**self).extract_id(request)
422    }
423}
424
425/// Extract ID from an HTTP header.
426///
427/// # Example
428/// ```
429/// use axum_acl::HeaderIdExtractor;
430///
431/// // Extract user ID from X-User-Id header
432/// let extractor = HeaderIdExtractor::new("X-User-Id");
433/// ```
434#[derive(Debug, Clone)]
435pub struct HeaderIdExtractor {
436    header_name: String,
437}
438
439impl HeaderIdExtractor {
440    /// Create a new header ID extractor.
441    pub fn new(header_name: impl Into<String>) -> Self {
442        Self {
443            header_name: header_name.into(),
444        }
445    }
446}
447
448impl<B> IdExtractor<B> for HeaderIdExtractor {
449    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
450        match request.headers().get(&self.header_name) {
451            Some(value) => match value.to_str() {
452                Ok(s) if !s.is_empty() => IdExtractionResult::Id(s.trim().to_string()),
453                _ => IdExtractionResult::Anonymous,
454            },
455            None => IdExtractionResult::Anonymous,
456        }
457    }
458}
459
460/// Extract ID from a request extension.
461///
462/// This extractor looks for an ID that was set by a previous middleware
463/// (e.g., an authentication middleware) as a request extension.
464///
465/// # Example
466/// ```
467/// use axum_acl::ExtensionIdExtractor;
468///
469/// // The authentication middleware should insert a User struct into extensions
470/// #[derive(Clone)]
471/// struct AuthenticatedUser {
472///     id: String,
473///     name: String,
474/// }
475///
476/// let extractor = ExtensionIdExtractor::<AuthenticatedUser>::new(|user| user.id.clone());
477/// ```
478pub struct ExtensionIdExtractor<T> {
479    extract_fn: Box<dyn Fn(&T) -> String + Send + Sync>,
480}
481
482impl<T> ExtensionIdExtractor<T> {
483    /// Create a new extension ID extractor.
484    ///
485    /// The `extract_fn` converts the extension type to an ID string.
486    pub fn new<F>(extract_fn: F) -> Self
487    where
488        F: Fn(&T) -> String + Send + Sync + 'static,
489    {
490        Self {
491            extract_fn: Box::new(extract_fn),
492        }
493    }
494}
495
496impl<T> std::fmt::Debug for ExtensionIdExtractor<T> {
497    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498        f.debug_struct("ExtensionIdExtractor")
499            .field("type", &std::any::type_name::<T>())
500            .finish()
501    }
502}
503
504impl<B, T: Clone + Send + Sync + 'static> IdExtractor<B> for ExtensionIdExtractor<T> {
505    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
506        match request.extensions().get::<T>() {
507            Some(ext) => IdExtractionResult::Id((self.extract_fn)(ext)),
508            None => IdExtractionResult::Anonymous,
509        }
510    }
511}
512
513/// An ID extractor that always returns a fixed ID.
514///
515/// Useful for testing.
516#[derive(Debug, Clone)]
517pub struct FixedIdExtractor {
518    id: String,
519}
520
521impl FixedIdExtractor {
522    /// Create a new fixed ID extractor.
523    pub fn new(id: impl Into<String>) -> Self {
524        Self { id: id.into() }
525    }
526}
527
528impl<B> IdExtractor<B> for FixedIdExtractor {
529    fn extract_id(&self, _request: &Request<B>) -> IdExtractionResult {
530        IdExtractionResult::Id(self.id.clone())
531    }
532}
533
534/// An ID extractor that always returns anonymous (no ID).
535#[derive(Debug, Clone, Default)]
536pub struct AnonymousIdExtractor;
537
538impl AnonymousIdExtractor {
539    /// Create a new anonymous ID extractor.
540    pub fn new() -> Self {
541        Self
542    }
543}
544
545impl<B> IdExtractor<B> for AnonymousIdExtractor {
546    fn extract_id(&self, _request: &Request<B>) -> IdExtractionResult {
547        IdExtractionResult::Anonymous
548    }
549}
550
551impl<B> ChainedRoleExtractor<B> {
552    /// Create a new chained role extractor.
553    pub fn new() -> Self {
554        Self {
555            extractors: Vec::new(),
556        }
557    }
558
559    /// Add an extractor to the chain.
560    pub fn push<E: RoleExtractor<B> + 'static>(mut self, extractor: E) -> Self {
561        self.extractors.push(Box::new(extractor));
562        self
563    }
564}
565
566impl<B> Default for ChainedRoleExtractor<B> {
567    fn default() -> Self {
568        Self::new()
569    }
570}
571
572impl<B> std::fmt::Debug for ChainedRoleExtractor<B> {
573    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
574        f.debug_struct("ChainedRoleExtractor")
575            .field("extractors_count", &self.extractors.len())
576            .finish()
577    }
578}
579
580impl<B> RoleExtractor<B> for ChainedRoleExtractor<B>
581where
582    B: Send + Sync,
583{
584    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
585        for extractor in &self.extractors {
586            match extractor.extract_roles(request) {
587                RoleExtractionResult::Roles(roles) => return RoleExtractionResult::Roles(roles),
588                RoleExtractionResult::Error(e) => {
589                    tracing::warn!(error = %e, "Role extractor failed, trying next");
590                }
591                RoleExtractionResult::Anonymous => continue,
592            }
593        }
594        RoleExtractionResult::Anonymous
595    }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;

    /// Unwrap a `Roles` result, failing the test on any other variant.
    fn expect_roles(result: RoleExtractionResult) -> u32 {
        if let RoleExtractionResult::Roles(roles) = result {
            roles
        } else {
            panic!("Expected Roles")
        }
    }

    /// Build a body-less request, optionally carrying an `X-Roles` header.
    fn request_with_roles(value: Option<&str>) -> Request<()> {
        let builder = Request::builder();
        let builder = match value {
            Some(v) => builder.header("X-Roles", v),
            None => builder,
        };
        builder.body(()).unwrap()
    }

    #[test]
    fn test_header_extractor_decimal() {
        // "5" == 0b101, i.e. roles 0 and 2.
        let req = request_with_roles(Some("5"));
        let roles = expect_roles(HeaderRoleExtractor::new("X-Roles").extract_roles(&req));
        assert_eq!(roles, 5);
    }

    #[test]
    fn test_header_extractor_hex() {
        // "0x1F" == 0b11111, i.e. roles 0 through 4.
        let req = request_with_roles(Some("0x1F"));
        let roles = expect_roles(HeaderRoleExtractor::new("X-Roles").extract_roles(&req));
        assert_eq!(roles, 0x1F);
    }

    #[test]
    fn test_header_extractor_missing() {
        let req = request_with_roles(None);
        let result = HeaderRoleExtractor::new("X-Roles").extract_roles(&req);
        if !matches!(result, RoleExtractionResult::Anonymous) {
            panic!("Expected Anonymous");
        }
    }

    #[test]
    fn test_header_extractor_default() {
        // Guest role used as the fallback for a missing header.
        let extractor = HeaderRoleExtractor::new("X-Roles").with_default_roles(0b100);
        let req = request_with_roles(None);
        assert_eq!(expect_roles(extractor.extract_roles(&req)), 0b100);
    }

    #[test]
    fn test_fixed_extractor() {
        // admin + user
        let extractor = FixedRoleExtractor::new(0b11);
        let req = request_with_roles(None);
        assert_eq!(expect_roles(extractor.extract_roles(&req)), 0b11);
    }
}