axum_acl/extractor.rs
1//! Role and ID extraction from HTTP requests.
2//!
3//! This module provides traits for extracting user identity from requests:
4//! - [`RoleExtractor`]: Extract roles as a `u32` bitmask (up to 32 roles)
5//! - [`IdExtractor`]: Extract user/resource ID as a `String`
6//!
7//! ## Custom Role Translation
8//!
//! If your system uses a different role scheme (e.g., string roles or enums),
//! implement [`RoleExtractor`] and translate your roles into a `u32` bitmask
//! inside it, for example with a helper like this:
11//!
12//! ```
13//! use axum_acl::{RoleExtractor, RoleExtractionResult};
14//! use http::Request;
15//!
16//! // Your role enum
17//! enum Role { Admin, User, Guest }
18//!
19//! // Define bit positions
20//! const ROLE_ADMIN: u32 = 1 << 0;
21//! const ROLE_USER: u32 = 1 << 1;
22//! const ROLE_GUEST: u32 = 1 << 2;
23//!
24//! fn roles_to_mask(roles: &[Role]) -> u32 {
25//! roles.iter().fold(0u32, |mask, role| {
26//! mask | match role {
27//! Role::Admin => ROLE_ADMIN,
28//! Role::User => ROLE_USER,
29//! Role::Guest => ROLE_GUEST,
30//! }
31//! })
32//! }
33//! ```
34//!
35//! ## Path Parameter ID Matching
36//!
37//! For paths like `/api/boat/{id}/size`, the `{id}` is matched against
38//! the user's ID from `IdExtractor`. This enables ownership-based access:
39//!
40//! ```text
41//! Rule: /api/boat/{id}/** role=USER, id={id} -> allow
42//! User: id="boat-123", roles=USER
43//! Path: /api/boat/boat-123/size -> ALLOWED (id matches)
44//! Path: /api/boat/boat-456/size -> DENIED (id doesn't match)
45//! ```
46
47use http::Request;
48use std::sync::Arc;
49
/// Result of role extraction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RoleExtractionResult {
    /// Roles were successfully extracted (u32 bitmask).
    Roles(u32),
    /// No role could be extracted (user is anonymous/guest).
    Anonymous,
    /// An error occurred during extraction.
    Error(String),
}

impl RoleExtractionResult {
    /// Get the roles bitmask, or `default` when no roles were extracted.
    ///
    /// Both [`RoleExtractionResult::Anonymous`] and
    /// [`RoleExtractionResult::Error`] yield `default`. The error message is
    /// discarded, so match on the variant directly if you need it.
    #[must_use]
    pub fn roles_or(&self, default: u32) -> u32 {
        match self {
            Self::Roles(roles) => *roles,
            Self::Anonymous | Self::Error(_) => default,
        }
    }

    /// Get the roles bitmask, or `0` (no roles) when no roles were extracted.
    ///
    /// Equivalent to `self.roles_or(0)`, so errors are treated like anonymous.
    #[must_use]
    pub fn roles_or_none(&self) -> u32 {
        self.roles_or(0)
    }
}
76
77/// Trait for extracting roles from HTTP requests.
78///
79/// Implement this trait to customize how roles are determined from incoming
80/// requests. This allows integration with various authentication systems.
81///
82/// Roles are represented as `u32` bitmasks, allowing multiple roles per user.
83///
84/// The trait is synchronous because role extraction typically involves
85/// reading headers or request extensions, which doesn't require async.
86///
87/// # Example
88/// ```
89/// use axum_acl::{RoleExtractor, RoleExtractionResult};
90/// use http::Request;
91///
92/// const ROLE_ADMIN: u32 = 0b001;
93/// const ROLE_USER: u32 = 0b010;
94///
95/// /// Extract roles from a custom header as a bitmask.
96/// struct CustomRolesExtractor;
97///
98/// impl<B> RoleExtractor<B> for CustomRolesExtractor {
99/// fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
100/// match request.headers().get("X-Roles") {
101/// Some(value) => {
102/// match value.to_str() {
103/// Ok(s) => {
104/// // Parse comma-separated role names to bitmask
105/// let mut mask = 0u32;
106/// for role in s.split(',') {
107/// match role.trim() {
108/// "admin" => mask |= ROLE_ADMIN,
109/// "user" => mask |= ROLE_USER,
110/// _ => {}
111/// }
112/// }
113/// RoleExtractionResult::Roles(mask)
114/// }
115/// Err(_) => RoleExtractionResult::Anonymous,
116/// }
117/// }
118/// None => RoleExtractionResult::Anonymous,
119/// }
120/// }
121/// }
122/// ```
123pub trait RoleExtractor<B>: Send + Sync {
124 /// Extract the roles bitmask from an HTTP request.
125 fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult;
126}
127
128// Implement for Arc<T> where T: RoleExtractor
129impl<B, T: RoleExtractor<B>> RoleExtractor<B> for Arc<T> {
130 fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
131 (**self).extract_roles(request)
132 }
133}
134
135// Implement for Box<T> where T: RoleExtractor
136impl<B, T: RoleExtractor<B> + ?Sized> RoleExtractor<B> for Box<T> {
137 fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
138 (**self).extract_roles(request)
139 }
140}
141
142/// Extract roles bitmask from an HTTP header.
143///
144/// The header value is parsed as a u32 bitmask directly, or you can use
145/// a custom parser function to convert header values to bitmasks.
146///
147/// # Example
148/// ```
149/// use axum_acl::HeaderRoleExtractor;
150///
151/// // Extract roles bitmask directly from X-Roles header (as decimal or hex)
152/// let extractor = HeaderRoleExtractor::new("X-Roles");
153///
154/// // With a custom default roles bitmask for missing headers
155/// let extractor = HeaderRoleExtractor::new("X-Roles")
156/// .with_default_roles(0b100); // guest role
157/// ```
158#[derive(Debug, Clone)]
159pub struct HeaderRoleExtractor {
160 header_name: String,
161 default_roles: u32,
162}
163
164impl HeaderRoleExtractor {
165 /// Create a new header role extractor.
166 pub fn new(header_name: impl Into<String>) -> Self {
167 Self {
168 header_name: header_name.into(),
169 default_roles: 0,
170 }
171 }
172
173 /// Set default roles bitmask to use when the header is missing.
174 pub fn with_default_roles(mut self, roles: u32) -> Self {
175 self.default_roles = roles;
176 self
177 }
178}
179
180impl<B> RoleExtractor<B> for HeaderRoleExtractor {
181 fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
182 match request.headers().get(&self.header_name) {
183 Some(value) => match value.to_str() {
184 Ok(s) if !s.is_empty() => {
185 // Try parsing as decimal first, then hex (with 0x prefix)
186 let trimmed = s.trim();
187 if let Ok(roles) = trimmed.parse::<u32>() {
188 RoleExtractionResult::Roles(roles)
189 } else if let Some(hex) = trimmed.strip_prefix("0x") {
190 u32::from_str_radix(hex, 16)
191 .map(RoleExtractionResult::Roles)
192 .unwrap_or_else(|_| {
193 if self.default_roles != 0 {
194 RoleExtractionResult::Roles(self.default_roles)
195 } else {
196 RoleExtractionResult::Anonymous
197 }
198 })
199 } else if self.default_roles != 0 {
200 RoleExtractionResult::Roles(self.default_roles)
201 } else {
202 RoleExtractionResult::Anonymous
203 }
204 }
205 _ => {
206 if self.default_roles != 0 {
207 RoleExtractionResult::Roles(self.default_roles)
208 } else {
209 RoleExtractionResult::Anonymous
210 }
211 }
212 },
213 None => {
214 if self.default_roles != 0 {
215 RoleExtractionResult::Roles(self.default_roles)
216 } else {
217 RoleExtractionResult::Anonymous
218 }
219 }
220 }
221 }
222}
223
224/// Extract roles from a request extension.
225///
226/// This extractor looks for roles that were set by a previous middleware
227/// (e.g., an authentication middleware) as a request extension.
228///
229/// # Example
230/// ```
231/// use axum_acl::ExtensionRoleExtractor;
232///
233/// // The authentication middleware should insert a Roles struct into extensions
234/// #[derive(Clone)]
235/// struct UserRoles(u32);
236///
237/// let extractor = ExtensionRoleExtractor::<UserRoles>::new(|roles| roles.0);
238/// ```
239pub struct ExtensionRoleExtractor<T> {
240 extract_fn: Box<dyn Fn(&T) -> u32 + Send + Sync>,
241}
242
243impl<T> ExtensionRoleExtractor<T> {
244 /// Create a new extension role extractor.
245 ///
246 /// The `extract_fn` converts the extension type to a roles bitmask.
247 pub fn new<F>(extract_fn: F) -> Self
248 where
249 F: Fn(&T) -> u32 + Send + Sync + 'static,
250 {
251 Self {
252 extract_fn: Box::new(extract_fn),
253 }
254 }
255}
256
257impl<T> std::fmt::Debug for ExtensionRoleExtractor<T> {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("ExtensionRoleExtractor")
260 .field("type", &std::any::type_name::<T>())
261 .finish()
262 }
263}
264
265impl<B, T: Clone + Send + Sync + 'static> RoleExtractor<B> for ExtensionRoleExtractor<T> {
266 fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
267 match request.extensions().get::<T>() {
268 Some(ext) => RoleExtractionResult::Roles((self.extract_fn)(ext)),
269 None => RoleExtractionResult::Anonymous,
270 }
271 }
272}
273
274/// A roles extractor that always returns fixed roles.
275///
276/// Useful for testing or for routes that should always use specific roles.
277#[derive(Debug, Clone)]
278pub struct FixedRoleExtractor {
279 roles: u32,
280}
281
282impl FixedRoleExtractor {
283 /// Create a new fixed roles extractor.
284 pub fn new(roles: u32) -> Self {
285 Self { roles }
286 }
287}
288
289impl<B> RoleExtractor<B> for FixedRoleExtractor {
290 fn extract_roles(&self, _request: &Request<B>) -> RoleExtractionResult {
291 RoleExtractionResult::Roles(self.roles)
292 }
293}
294
295/// A role extractor that always returns anonymous (no roles).
296#[derive(Debug, Clone, Default)]
297pub struct AnonymousRoleExtractor;
298
299impl AnonymousRoleExtractor {
300 /// Create a new anonymous role extractor.
301 pub fn new() -> Self {
302 Self
303 }
304}
305
306impl<B> RoleExtractor<B> for AnonymousRoleExtractor {
307 fn extract_roles(&self, _request: &Request<B>) -> RoleExtractionResult {
308 RoleExtractionResult::Anonymous
309 }
310}
311
312/// A composite extractor that tries multiple extractors in order.
313///
314/// Returns the first successful roles extraction, or anonymous if all fail.
315/// Roles from multiple extractors are NOT combined - only the first match is used.
316pub struct ChainedRoleExtractor<B> {
317 extractors: Vec<Box<dyn RoleExtractor<B>>>,
318}
319
320// ============================================================================
321// ID Extraction
322// ============================================================================
323
/// Result of ID extraction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IdExtractionResult {
    /// ID was successfully extracted.
    Id(String),
    /// No ID could be extracted (anonymous user).
    Anonymous,
    /// An error occurred during extraction.
    Error(String),
}

impl IdExtractionResult {
    /// Get the ID, or `default` when no ID was extracted.
    ///
    /// Both [`IdExtractionResult::Anonymous`] and [`IdExtractionResult::Error`]
    /// yield `default`; the error message is discarded.
    #[must_use]
    pub fn id_or(&self, default: impl Into<String>) -> String {
        match self {
            Self::Id(id) => id.clone(),
            Self::Anonymous | Self::Error(_) => default.into(),
        }
    }

    /// Get the ID, or the literal `"*"` when no ID was extracted.
    ///
    /// Like [`IdExtractionResult::id_or`], this also applies to
    /// [`IdExtractionResult::Error`].
    #[must_use]
    pub fn id_or_wildcard(&self) -> String {
        // NOTE(review): if the ACL matcher treats "*" as "matches any ID",
        // confirm that anonymous/failed extractions cannot satisfy ID-scoped
        // rules through this fallback.
        self.id_or("*")
    }
}
350
351/// Trait for extracting user/resource ID from HTTP requests.
352///
353/// Implement this trait to customize how user IDs are determined from
354/// incoming requests. The ID is used for:
355/// - Matching against `{id}` path parameters
356/// - Direct ID matching in ACL rules
357///
358/// # Example: JWT User ID
359/// ```
360/// use axum_acl::{IdExtractor, IdExtractionResult};
361/// use http::Request;
362///
363/// struct JwtIdExtractor;
364///
365/// impl<B> IdExtractor<B> for JwtIdExtractor {
366/// fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
367/// // In practice, you'd decode the JWT and extract the user ID
368/// if let Some(auth) = request.headers().get("Authorization") {
369/// if let Ok(s) = auth.to_str() {
370/// // Simplified: extract user ID from token
371/// if s.starts_with("Bearer ") {
372/// return IdExtractionResult::Id("user-123".to_string());
373/// }
374/// }
375/// }
376/// IdExtractionResult::Anonymous
377/// }
378/// }
379/// ```
380///
381/// # Example: Path-based Resource ID
382/// ```
383/// use axum_acl::{IdExtractor, IdExtractionResult};
384/// use http::Request;
385///
386/// /// Extract resource ID from path like /api/boat/{id}/...
387/// struct PathIdExtractor {
388/// prefix: String, // e.g., "/api/boat/"
389/// }
390///
391/// impl<B> IdExtractor<B> for PathIdExtractor {
392/// fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
393/// let path = request.uri().path();
394/// if let Some(rest) = path.strip_prefix(&self.prefix) {
395/// // Get the next path segment as the ID
396/// if let Some(id) = rest.split('/').next() {
397/// if !id.is_empty() {
398/// return IdExtractionResult::Id(id.to_string());
399/// }
400/// }
401/// }
402/// IdExtractionResult::Anonymous
403/// }
404/// }
405/// ```
406pub trait IdExtractor<B>: Send + Sync {
407 /// Extract the user/resource ID from an HTTP request.
408 fn extract_id(&self, request: &Request<B>) -> IdExtractionResult;
409}
410
411// Implement for Arc<T> where T: IdExtractor
412impl<B, T: IdExtractor<B>> IdExtractor<B> for Arc<T> {
413 fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
414 (**self).extract_id(request)
415 }
416}
417
418// Implement for Box<T> where T: IdExtractor
419impl<B, T: IdExtractor<B> + ?Sized> IdExtractor<B> for Box<T> {
420 fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
421 (**self).extract_id(request)
422 }
423}
424
425/// Extract ID from an HTTP header.
426///
427/// # Example
428/// ```
429/// use axum_acl::HeaderIdExtractor;
430///
431/// // Extract user ID from X-User-Id header
432/// let extractor = HeaderIdExtractor::new("X-User-Id");
433/// ```
434#[derive(Debug, Clone)]
435pub struct HeaderIdExtractor {
436 header_name: String,
437}
438
439impl HeaderIdExtractor {
440 /// Create a new header ID extractor.
441 pub fn new(header_name: impl Into<String>) -> Self {
442 Self {
443 header_name: header_name.into(),
444 }
445 }
446}
447
448impl<B> IdExtractor<B> for HeaderIdExtractor {
449 fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
450 match request.headers().get(&self.header_name) {
451 Some(value) => match value.to_str() {
452 Ok(s) if !s.is_empty() => IdExtractionResult::Id(s.trim().to_string()),
453 _ => IdExtractionResult::Anonymous,
454 },
455 None => IdExtractionResult::Anonymous,
456 }
457 }
458}
459
460/// Extract ID from a request extension.
461///
462/// This extractor looks for an ID that was set by a previous middleware
463/// (e.g., an authentication middleware) as a request extension.
464///
465/// # Example
466/// ```
467/// use axum_acl::ExtensionIdExtractor;
468///
469/// // The authentication middleware should insert a User struct into extensions
470/// #[derive(Clone)]
471/// struct AuthenticatedUser {
472/// id: String,
473/// name: String,
474/// }
475///
476/// let extractor = ExtensionIdExtractor::<AuthenticatedUser>::new(|user| user.id.clone());
477/// ```
478pub struct ExtensionIdExtractor<T> {
479 extract_fn: Box<dyn Fn(&T) -> String + Send + Sync>,
480}
481
482impl<T> ExtensionIdExtractor<T> {
483 /// Create a new extension ID extractor.
484 ///
485 /// The `extract_fn` converts the extension type to an ID string.
486 pub fn new<F>(extract_fn: F) -> Self
487 where
488 F: Fn(&T) -> String + Send + Sync + 'static,
489 {
490 Self {
491 extract_fn: Box::new(extract_fn),
492 }
493 }
494}
495
496impl<T> std::fmt::Debug for ExtensionIdExtractor<T> {
497 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498 f.debug_struct("ExtensionIdExtractor")
499 .field("type", &std::any::type_name::<T>())
500 .finish()
501 }
502}
503
504impl<B, T: Clone + Send + Sync + 'static> IdExtractor<B> for ExtensionIdExtractor<T> {
505 fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
506 match request.extensions().get::<T>() {
507 Some(ext) => IdExtractionResult::Id((self.extract_fn)(ext)),
508 None => IdExtractionResult::Anonymous,
509 }
510 }
511}
512
513/// An ID extractor that always returns a fixed ID.
514///
515/// Useful for testing.
516#[derive(Debug, Clone)]
517pub struct FixedIdExtractor {
518 id: String,
519}
520
521impl FixedIdExtractor {
522 /// Create a new fixed ID extractor.
523 pub fn new(id: impl Into<String>) -> Self {
524 Self { id: id.into() }
525 }
526}
527
528impl<B> IdExtractor<B> for FixedIdExtractor {
529 fn extract_id(&self, _request: &Request<B>) -> IdExtractionResult {
530 IdExtractionResult::Id(self.id.clone())
531 }
532}
533
534/// An ID extractor that always returns anonymous (no ID).
535#[derive(Debug, Clone, Default)]
536pub struct AnonymousIdExtractor;
537
538impl AnonymousIdExtractor {
539 /// Create a new anonymous ID extractor.
540 pub fn new() -> Self {
541 Self
542 }
543}
544
545impl<B> IdExtractor<B> for AnonymousIdExtractor {
546 fn extract_id(&self, _request: &Request<B>) -> IdExtractionResult {
547 IdExtractionResult::Anonymous
548 }
549}
550
551impl<B> ChainedRoleExtractor<B> {
552 /// Create a new chained role extractor.
553 pub fn new() -> Self {
554 Self {
555 extractors: Vec::new(),
556 }
557 }
558
559 /// Add an extractor to the chain.
560 pub fn push<E: RoleExtractor<B> + 'static>(mut self, extractor: E) -> Self {
561 self.extractors.push(Box::new(extractor));
562 self
563 }
564}
565
566impl<B> Default for ChainedRoleExtractor<B> {
567 fn default() -> Self {
568 Self::new()
569 }
570}
571
572impl<B> std::fmt::Debug for ChainedRoleExtractor<B> {
573 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
574 f.debug_struct("ChainedRoleExtractor")
575 .field("extractors_count", &self.extractors.len())
576 .finish()
577 }
578}
579
580impl<B> RoleExtractor<B> for ChainedRoleExtractor<B>
581where
582 B: Send + Sync,
583{
584 fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
585 for extractor in &self.extractors {
586 match extractor.extract_roles(request) {
587 RoleExtractionResult::Roles(roles) => return RoleExtractionResult::Roles(roles),
588 RoleExtractionResult::Error(e) => {
589 tracing::warn!(error = %e, "Role extractor failed, trying next");
590 }
591 RoleExtractionResult::Anonymous => continue,
592 }
593 }
594 RoleExtractionResult::Anonymous
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;

    /// Request with no headers and no extensions.
    fn empty_request() -> Request<()> {
        Request::builder().body(()).unwrap()
    }

    /// Request carrying a single header `name: value`.
    fn request_with_header(name: &str, value: &str) -> Request<()> {
        Request::builder().header(name, value).body(()).unwrap()
    }

    #[test]
    fn test_header_extractor_decimal() {
        let extractor = HeaderRoleExtractor::new("X-Roles");
        // "5" == 0b101 = roles 0 and 2
        let result = extractor.extract_roles(&request_with_header("X-Roles", "5"));
        assert!(matches!(result, RoleExtractionResult::Roles(5)), "got {:?}", result);
    }

    #[test]
    fn test_header_extractor_hex() {
        let extractor = HeaderRoleExtractor::new("X-Roles");
        // 0x1F == 0b11111 = roles 0-4
        let result = extractor.extract_roles(&request_with_header("X-Roles", "0x1F"));
        assert!(matches!(result, RoleExtractionResult::Roles(0x1F)), "got {:?}", result);
    }

    #[test]
    fn test_header_extractor_trims_whitespace() {
        let extractor = HeaderRoleExtractor::new("X-Roles");
        let result = extractor.extract_roles(&request_with_header("X-Roles", " 7 "));
        assert!(matches!(result, RoleExtractionResult::Roles(7)), "got {:?}", result);
    }

    #[test]
    fn test_header_extractor_missing() {
        let extractor = HeaderRoleExtractor::new("X-Roles");
        let result = extractor.extract_roles(&empty_request());
        assert!(matches!(result, RoleExtractionResult::Anonymous), "got {:?}", result);
    }

    #[test]
    fn test_header_extractor_default() {
        let extractor = HeaderRoleExtractor::new("X-Roles").with_default_roles(0b100); // guest role
        let result = extractor.extract_roles(&empty_request());
        assert!(matches!(result, RoleExtractionResult::Roles(0b100)), "got {:?}", result);
    }

    #[test]
    fn test_header_extractor_unparsable_falls_back() {
        let plain = HeaderRoleExtractor::new("X-Roles");
        let with_default = HeaderRoleExtractor::new("X-Roles").with_default_roles(0b100);

        for &bad in &["admin", "0xZZ", "0x", "-1"] {
            let req = request_with_header("X-Roles", bad);

            let result = plain.extract_roles(&req);
            assert!(
                matches!(result, RoleExtractionResult::Anonymous),
                "{:?} -> {:?}",
                bad,
                result
            );

            let result = with_default.extract_roles(&req);
            assert!(
                matches!(result, RoleExtractionResult::Roles(0b100)),
                "{:?} -> {:?}",
                bad,
                result
            );
        }
    }

    #[test]
    fn test_fixed_extractor() {
        let extractor = FixedRoleExtractor::new(0b11); // admin + user
        let result = extractor.extract_roles(&empty_request());
        assert!(matches!(result, RoleExtractionResult::Roles(0b11)), "got {:?}", result);
    }

    #[test]
    fn test_anonymous_role_extractor() {
        let result = AnonymousRoleExtractor::new().extract_roles(&empty_request());
        assert!(matches!(result, RoleExtractionResult::Anonymous), "got {:?}", result);
    }

    #[test]
    fn test_extension_role_extractor() {
        #[derive(Clone)]
        struct UserRoles(u32);

        let extractor = ExtensionRoleExtractor::<UserRoles>::new(|roles| roles.0);

        let result = extractor.extract_roles(&empty_request());
        assert!(matches!(result, RoleExtractionResult::Anonymous), "got {:?}", result);

        let mut req = empty_request();
        req.extensions_mut().insert(UserRoles(0b110));
        let result = extractor.extract_roles(&req);
        assert!(matches!(result, RoleExtractionResult::Roles(0b110)), "got {:?}", result);
    }

    #[test]
    fn test_chained_extractor_first_roles_win() {
        let chain = ChainedRoleExtractor::<()>::new()
            .push(AnonymousRoleExtractor::new())
            .push(FixedRoleExtractor::new(0b01))
            .push(FixedRoleExtractor::new(0b10));

        // Roles are not combined: only the first extractor with roles counts.
        let result = chain.extract_roles(&empty_request());
        assert!(matches!(result, RoleExtractionResult::Roles(0b01)), "got {:?}", result);
    }

    #[test]
    fn test_chained_extractor_all_anonymous() {
        let empty = ChainedRoleExtractor::<()>::new();
        let result = empty.extract_roles(&empty_request());
        assert!(matches!(result, RoleExtractionResult::Anonymous), "got {:?}", result);

        let chain = ChainedRoleExtractor::<()>::new()
            .push(AnonymousRoleExtractor::new())
            .push(HeaderRoleExtractor::new("X-Roles"));
        let result = chain.extract_roles(&empty_request());
        assert!(matches!(result, RoleExtractionResult::Anonymous), "got {:?}", result);
    }

    #[test]
    fn test_header_id_extractor() {
        let extractor = HeaderIdExtractor::new("X-User-Id");

        let result = extractor.extract_id(&request_with_header("X-User-Id", "  user-42 "));
        assert!(
            matches!(&result, IdExtractionResult::Id(id) if id == "user-42"),
            "got {:?}",
            result
        );

        let result = extractor.extract_id(&empty_request());
        assert!(matches!(result, IdExtractionResult::Anonymous), "got {:?}", result);
    }

    #[test]
    fn test_extension_id_extractor() {
        #[derive(Clone)]
        struct AuthenticatedUser {
            id: String,
        }

        let extractor =
            ExtensionIdExtractor::<AuthenticatedUser>::new(|user| user.id.clone());

        let result = extractor.extract_id(&empty_request());
        assert!(matches!(result, IdExtractionResult::Anonymous), "got {:?}", result);

        let mut req = empty_request();
        req.extensions_mut().insert(AuthenticatedUser {
            id: "boat-123".to_string(),
        });
        let result = extractor.extract_id(&req);
        assert!(
            matches!(&result, IdExtractionResult::Id(id) if id == "boat-123"),
            "got {:?}",
            result
        );
    }

    #[test]
    fn test_fixed_and_anonymous_id_extractors() {
        let req = empty_request();

        let fixed = FixedIdExtractor::new("user-1").extract_id(&req);
        assert!(
            matches!(&fixed, IdExtractionResult::Id(id) if id == "user-1"),
            "got {:?}",
            fixed
        );

        let anon = AnonymousIdExtractor::new().extract_id(&req);
        assert!(matches!(anon, IdExtractionResult::Anonymous), "got {:?}", anon);
    }

    #[test]
    fn test_result_fallback_helpers() {
        assert_eq!(RoleExtractionResult::Roles(3).roles_or(9), 3);
        assert_eq!(RoleExtractionResult::Anonymous.roles_or(9), 9);
        assert_eq!(RoleExtractionResult::Error("boom".to_string()).roles_or_none(), 0);

        assert_eq!(IdExtractionResult::Id("u".to_string()).id_or_wildcard(), "u");
        assert_eq!(IdExtractionResult::Anonymous.id_or("guest"), "guest");
        assert_eq!(IdExtractionResult::Error("boom".to_string()).id_or_wildcard(), "*");
    }
}