axum_acl/
table.rs

//! ACL table for storing and evaluating rules.
//!
//! The [`AclTable`] is the central data structure that holds all ACL rules
//! and provides methods to evaluate requests against them.
//!
//! Exact endpoints are looked up in a `HashMap` (O(1) on average); all other
//! endpoint patterns are scanned in insertion order. Role/time/ip/id filters
//! are applied to the rules of whichever endpoint matched.
7
8use crate::rule::{AclAction, AclRuleFilter, EndpointPattern, RequestContext};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// A table containing ACL rules for evaluation.
13///
14/// Uses a 5-tuple system: (endpoint, role, time, ip, id)
15/// - Endpoint is used as HashMap key for O(1) lookup
16/// - Role, time, ip, id are filters applied after endpoint match
17///
18/// # Example
19/// ```
20/// use axum_acl::{AclTable, AclRuleFilter, AclAction};
21///
22/// let table = AclTable::builder()
23///     .default_action(AclAction::Deny)
24///     // Exact endpoint match
25///     .add_exact("/api/users", AclRuleFilter::new()
26///         .role_mask(0b11)  // roles 0 and 1
27///         .action(AclAction::Allow))
28///     // Prefix match for /admin/*
29///     .add_prefix("/admin/", AclRuleFilter::new()
30///         .role_mask(0b1)   // role 0 only (admin)
31///         .action(AclAction::Allow))
32///     .build();
33/// ```
34#[derive(Debug, Clone)]
35pub struct AclTable {
36    /// O(1) lookup for exact endpoint matches.
37    pub(crate) exact_rules: HashMap<String, Vec<AclRuleFilter>>,
38    /// Fallback for prefix/glob/any patterns (checked in order).
39    pub(crate) pattern_rules: Vec<(EndpointPattern, Vec<AclRuleFilter>)>,
40    /// Default action when no rules match.
41    pub(crate) default_action: AclAction,
42}
43
44impl Default for AclTable {
45    fn default() -> Self {
46        Self {
47            exact_rules: HashMap::new(),
48            pattern_rules: Vec::new(),
49            default_action: AclAction::Deny,
50        }
51    }
52}
53
54impl AclTable {
55    /// Create a new empty ACL table with deny as default action.
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// Create a builder for constructing an ACL table.
61    pub fn builder() -> AclTableBuilder {
62        AclTableBuilder::new()
63    }
64
65    /// Get the exact rules map.
66    pub fn exact_rules(&self) -> &HashMap<String, Vec<AclRuleFilter>> {
67        &self.exact_rules
68    }
69
70    /// Get the pattern rules.
71    pub fn pattern_rules(&self) -> &[(EndpointPattern, Vec<AclRuleFilter>)] {
72        &self.pattern_rules
73    }
74
75    /// Get the default action when no rules match.
76    pub fn default_action(&self) -> AclAction {
77        self.default_action.clone()
78    }
79
80    /// Evaluate the ACL rules for a given request context.
81    ///
82    /// Lookup order:
83    /// 1. Exact endpoint match in HashMap (O(1))
84    /// 2. Pattern rules (prefix/glob/any) in order
85    ///
86    /// For each endpoint match, filters are checked: id → roles → ip → time
87    ///
88    /// # Example
89    /// ```
90    /// use axum_acl::{AclTable, AclRuleFilter, AclAction, RequestContext};
91    /// use std::net::IpAddr;
92    ///
93    /// let table = AclTable::builder()
94    ///     .add_exact("/api/users", AclRuleFilter::new()
95    ///         .role_mask(0b11)
96    ///         .action(AclAction::Allow))
97    ///     .build();
98    ///
99    /// let ip: IpAddr = "127.0.0.1".parse().unwrap();
100    /// let ctx = RequestContext::new(0b01, ip, "user123");
101    /// let action = table.evaluate("/api/users", &ctx);
102    /// assert_eq!(action, AclAction::Allow);
103    /// ```
104    pub fn evaluate(&self, path: &str, ctx: &RequestContext) -> AclAction {
105        self.evaluate_with_match(path, ctx).0
106    }
107
108    /// Evaluate the ACL rules and return both the action and match info.
109    ///
110    /// Returns `(action, Some((endpoint, filter_index)))` if matched,
111    /// or `(default_action, None)` if no rules matched.
112    pub fn evaluate_with_match(&self, path: &str, ctx: &RequestContext) -> (AclAction, Option<(String, usize)>) {
113        // 1. Try exact endpoint match first (O(1))
114        if let Some(filters) = self.exact_rules.get(path) {
115            for (idx, filter) in filters.iter().enumerate() {
116                if filter.matches(ctx) {
117                    tracing::debug!(
118                        endpoint = path,
119                        filter_index = idx,
120                        filter_description = ?filter.description,
121                        roles = ctx.roles,
122                        id = ctx.id,
123                        ip = %ctx.ip,
124                        action = ?filter.action,
125                        "ACL exact match"
126                    );
127                    return (filter.action.clone(), Some((path.to_string(), idx)));
128                }
129            }
130        }
131
132        // 2. Try pattern rules (prefix/glob/any)
133        for (pattern, filters) in &self.pattern_rules {
134            if pattern.matches(path) {
135                for (idx, filter) in filters.iter().enumerate() {
136                    if filter.matches(ctx) {
137                        tracing::debug!(
138                            endpoint = ?pattern,
139                            filter_index = idx,
140                            filter_description = ?filter.description,
141                            roles = ctx.roles,
142                            id = ctx.id,
143                            ip = %ctx.ip,
144                            action = ?filter.action,
145                            "ACL pattern match"
146                        );
147                        return (filter.action.clone(), Some((format!("{:?}", pattern), idx)));
148                    }
149                }
150            }
151        }
152
153        tracing::debug!(
154            path = path,
155            roles = ctx.roles,
156            id = ctx.id,
157            ip = %ctx.ip,
158            action = ?self.default_action,
159            "No ACL rule matched, using default action"
160        );
161        (self.default_action.clone(), None)
162    }
163
164    /// Check if access is allowed for the given context.
165    pub fn is_allowed(&self, path: &str, ctx: &RequestContext) -> bool {
166        self.evaluate(path, ctx) == AclAction::Allow
167    }
168}
169
170/// Builder for constructing an [`AclTable`].
171#[derive(Debug, Default)]
172pub struct AclTableBuilder {
173    exact_rules: HashMap<String, Vec<AclRuleFilter>>,
174    pattern_rules: Vec<(EndpointPattern, Vec<AclRuleFilter>)>,
175    default_action: AclAction,
176}
177
178impl AclTableBuilder {
179    /// Create a new builder.
180    pub fn new() -> Self {
181        Self::default()
182    }
183
184    /// Set the default action when no rules match.
185    ///
186    /// The default is `AclAction::Deny`.
187    pub fn default_action(mut self, action: AclAction) -> Self {
188        self.default_action = action;
189        self
190    }
191
192    /// Add a filter for an exact endpoint match (O(1) lookup).
193    pub fn add_exact(mut self, endpoint: impl Into<String>, filter: AclRuleFilter) -> Self {
194        let endpoint = endpoint.into();
195        self.exact_rules
196            .entry(endpoint)
197            .or_default()
198            .push(filter);
199        self
200    }
201
202    /// Add multiple filters for an exact endpoint.
203    pub fn add_exact_filters(
204        mut self,
205        endpoint: impl Into<String>,
206        filters: impl IntoIterator<Item = AclRuleFilter>,
207    ) -> Self {
208        let endpoint = endpoint.into();
209        self.exact_rules
210            .entry(endpoint)
211            .or_default()
212            .extend(filters);
213        self
214    }
215
216    /// Add a filter for a prefix endpoint match.
217    pub fn add_prefix(self, prefix: impl Into<String>, filter: AclRuleFilter) -> Self {
218        let pattern = EndpointPattern::Prefix(prefix.into());
219        self.add_pattern(pattern, filter)
220    }
221
222    /// Add a filter for a glob endpoint match.
223    pub fn add_glob(self, glob: impl Into<String>, filter: AclRuleFilter) -> Self {
224        let pattern = EndpointPattern::Glob(glob.into());
225        self.add_pattern(pattern, filter)
226    }
227
228    /// Add a filter that matches any endpoint.
229    pub fn add_any(self, filter: AclRuleFilter) -> Self {
230        self.add_pattern(EndpointPattern::Any, filter)
231    }
232
233    /// Add a filter for a custom endpoint pattern.
234    pub fn add_pattern(mut self, pattern: EndpointPattern, filter: AclRuleFilter) -> Self {
235        // Check if this pattern already exists
236        for (existing_pattern, filters) in &mut self.pattern_rules {
237            let is_match = match (existing_pattern, &pattern) {
238                (EndpointPattern::Any, EndpointPattern::Any) => true,
239                (EndpointPattern::Prefix(a), EndpointPattern::Prefix(b)) => a == b,
240                (EndpointPattern::Glob(a), EndpointPattern::Glob(b)) => a == b,
241                (EndpointPattern::Exact(a), EndpointPattern::Exact(b)) => a == b,
242                _ => false,
243            };
244            if is_match {
245                filters.push(filter);
246                return self;
247            }
248        }
249        // New pattern
250        self.pattern_rules.push((pattern, vec![filter]));
251        self
252    }
253
254    /// Build the ACL table.
255    pub fn build(self) -> AclTable {
256        AclTable {
257            exact_rules: self.exact_rules,
258            pattern_rules: self.pattern_rules,
259            default_action: self.default_action,
260        }
261    }
262
263    /// Build the ACL table wrapped in an Arc for sharing.
264    pub fn build_shared(self) -> Arc<AclTable> {
265        Arc::new(self.build())
266    }
267}
268
269/// Rule entry for providers: endpoint pattern + filter.
270#[derive(Debug, Clone)]
271pub struct RuleEntry {
272    /// The endpoint pattern.
273    pub pattern: EndpointPattern,
274    /// The filter for this endpoint.
275    pub filter: AclRuleFilter,
276}
277
278impl RuleEntry {
279    /// Create a new rule entry.
280    pub fn new(pattern: EndpointPattern, filter: AclRuleFilter) -> Self {
281        Self { pattern, filter }
282    }
283
284    /// Create an exact endpoint rule.
285    pub fn exact(endpoint: impl Into<String>, filter: AclRuleFilter) -> Self {
286        Self::new(EndpointPattern::Exact(endpoint.into()), filter)
287    }
288
289    /// Create a prefix endpoint rule.
290    pub fn prefix(prefix: impl Into<String>, filter: AclRuleFilter) -> Self {
291        Self::new(EndpointPattern::Prefix(prefix.into()), filter)
292    }
293
294    /// Create a glob endpoint rule.
295    pub fn glob(glob: impl Into<String>, filter: AclRuleFilter) -> Self {
296        Self::new(EndpointPattern::Glob(glob.into()), filter)
297    }
298
299    /// Create an any endpoint rule.
300    pub fn any(filter: AclRuleFilter) -> Self {
301        Self::new(EndpointPattern::Any, filter)
302    }
303}
304
305/// Trait for types that can provide ACL rules.
306///
307/// Implement this trait to load rules from external sources like databases,
308/// configuration files, or remote services.
309///
310/// # Example
311/// ```
312/// use axum_acl::{AclRuleProvider, RuleEntry, AclRuleFilter, AclAction, EndpointPattern};
313///
314/// struct ConfigRuleProvider {
315///     config_path: String,
316/// }
317///
318/// impl AclRuleProvider for ConfigRuleProvider {
319///     type Error = std::io::Error;
320///
321///     fn load_rules(&self) -> Result<Vec<RuleEntry>, Self::Error> {
322///         // Load rules from config file
323///         Ok(vec![
324///             RuleEntry::any(AclRuleFilter::new()
325///                 .role_mask(0b1)  // admin role
326///                 .action(AclAction::Allow))
327///         ])
328///     }
329/// }
330/// ```
331pub trait AclRuleProvider: Send + Sync {
332    /// Error type for rule loading failures.
333    type Error: std::error::Error + Send + Sync + 'static;
334
335    /// Load rules from the provider.
336    fn load_rules(&self) -> Result<Vec<RuleEntry>, Self::Error>;
337}
338
339/// A simple rule provider that returns a static list of rules.
340#[derive(Debug, Clone)]
341pub struct StaticRuleProvider {
342    rules: Vec<RuleEntry>,
343}
344
345impl StaticRuleProvider {
346    /// Create a new static rule provider.
347    pub fn new(rules: Vec<RuleEntry>) -> Self {
348        Self { rules }
349    }
350}
351
352impl AclRuleProvider for StaticRuleProvider {
353    type Error = std::convert::Infallible;
354
355    fn load_rules(&self) -> Result<Vec<RuleEntry>, Self::Error> {
356        Ok(self.rules.clone())
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    const ROLE_ADMIN: u32 = 0b001;
    const ROLE_USER: u32 = 0b010;
    const ROLE_GUEST: u32 = 0b100;

    /// Client address used by every request context (127.0.0.1).
    const LOOPBACK: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);

    #[test]
    fn test_table_evaluation() {
        // Admin can access anything
        let admin_everywhere = AclRuleFilter::new()
            .role_mask(ROLE_ADMIN)
            .action(AclAction::Allow);
        // User can access /api/
        let user_api = AclRuleFilter::new()
            .role_mask(ROLE_USER)
            .action(AclAction::Allow);

        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_any(admin_everywhere)
            .add_prefix("/api/", user_api)
            .build();

        // Admin can access anything
        let admin_ctx = RequestContext::new(ROLE_ADMIN, LOOPBACK, "admin1");
        assert!(table.is_allowed("/admin/dashboard", &admin_ctx));
        assert!(table.is_allowed("/api/users", &admin_ctx));

        // User can only access /api/
        let user_ctx = RequestContext::new(ROLE_USER, LOOPBACK, "user1");
        assert!(table.is_allowed("/api/users", &user_ctx));
        assert!(!table.is_allowed("/admin/dashboard", &user_ctx));

        // Guest is denied (default action)
        let guest_ctx = RequestContext::new(ROLE_GUEST, LOOPBACK, "guest1");
        assert!(!table.is_allowed("/api/users", &guest_ctx));
    }

    #[test]
    fn test_exact_before_pattern() {
        // Exact match takes priority over patterns
        let allow_public = AclRuleFilter::new()
            .role_mask(u32::MAX)
            .action(AclAction::Allow);
        let deny_rest = AclRuleFilter::new()
            .role_mask(u32::MAX)
            .action(AclAction::Deny);

        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            // Exact match for /public
            .add_exact("/public", allow_public)
            // Deny everything else
            .add_any(deny_rest)
            .build();

        let ctx = RequestContext::new(0b1, LOOPBACK, "anyone");

        assert!(table.is_allowed("/public", &ctx));
        assert!(!table.is_allowed("/private", &ctx));
    }

    #[test]
    fn test_role_bitmask() {
        // admin OR user
        let admin_or_user = AclRuleFilter::new()
            .role_mask(ROLE_ADMIN | ROLE_USER)
            .action(AclAction::Allow);

        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_exact("/shared", admin_or_user)
            .build();

        // Admin can access
        let admin = RequestContext::new(ROLE_ADMIN, LOOPBACK, "a");
        assert!(table.is_allowed("/shared", &admin));
        // User can access
        let user = RequestContext::new(ROLE_USER, LOOPBACK, "u");
        assert!(table.is_allowed("/shared", &user));
        // Guest cannot
        let guest = RequestContext::new(ROLE_GUEST, LOOPBACK, "g");
        assert!(!table.is_allowed("/shared", &guest));
        // User+Admin can access (has overlap)
        let both = RequestContext::new(ROLE_ADMIN | ROLE_USER, LOOPBACK, "au");
        assert!(table.is_allowed("/shared", &both));
    }
}
442}