Skip to main content

axum_acl/
config.rs

//! TOML configuration support for ACL rules.
//!
//! This module provides structures for loading ACL rules from TOML configuration,
//! either compiled-in at build time or read from a file at runtime.
//!
//! # Example TOML Format
//!
//! Rules select roles with `role_mask`, a bitmask whose bit meanings are defined
//! by the application (in this example bit 0 is "admin" and bit 1 is "user").
//! Unrecognized keys are ignored when parsing, so a misspelled key silently falls
//! back to that field's default.
//!
//! ```toml
//! [settings]
//! default_action = "deny"
//!
//! [[rules]]
//! role_mask = 1
//! endpoint = "*"
//! action = "allow"
//! description = "Admins have full access"
//!
//! [[rules]]
//! role_mask = 2
//! endpoint = "/api/**"
//! time = { start = 9, end = 17, days = [0,1,2,3,4] }
//! action = "allow"
//!
//! [[rules]]
//! role_mask = "*"
//! endpoint = "/blocked/**"
//! action = { type = "error", code = 403, message = "Access forbidden" }
//! ```
//!
//! # Usage
//!
//! ## Compile-time embedded config
//!
//! ```ignore
//! use axum_acl::AclTable;
//!
//! // Embed at compile time
//! const ACL_CONFIG: &str = include_str!("../acl.toml");
//!
//! let table = AclTable::from_toml(ACL_CONFIG).unwrap();
//! ```
//!
//! ## Runtime file loading
//!
//! ```ignore
//! use axum_acl::AclTable;
//!
//! let table = AclTable::from_toml_file("config/acl.toml").unwrap();
//! ```
50
51use crate::rule::{AclAction, AclRuleFilter, EndpointPattern, IpMatcher, RuleMatcher, TimeWindow};
52use crate::table::{AclRule, AclTable};
53use serde::{Deserialize, Serialize};
54use std::path::Path;
55use std::sync::Arc;
56
57/// Configuration file structure.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct AclConfig {
60    /// Global settings.
61    #[serde(default)]
62    pub settings: ConfigSettings,
63    /// List of ACL rules.
64    #[serde(default)]
65    pub rules: Vec<RuleConfig>,
66}
67
68/// Global configuration settings.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct ConfigSettings {
71    /// Default action when no rules match.
72    #[serde(default = "default_action")]
73    pub default_action: ActionConfig,
74}
75
76fn default_action() -> ActionConfig {
77    ActionConfig::Simple(SimpleAction::Deny)
78}
79
80impl Default for ConfigSettings {
81    fn default() -> Self {
82        Self {
83            default_action: default_action(),
84        }
85    }
86}
87
88/// A single rule configuration.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct RuleConfig {
91    /// Role bitmask. Use 0xFFFFFFFF for all roles, or specific bits like 0b11.
92    /// Can be decimal (e.g., 3), hex (e.g., "0x3"), or binary string.
93    #[serde(default = "default_role_mask")]
94    pub role_mask: RoleMaskConfig,
95
96    /// ID to match. Use "*" for any ID.
97    #[serde(default = "default_id")]
98    pub id: String,
99
100    /// Endpoint pattern to match.
101    /// - "*" or "any" for all endpoints
102    /// - "/path/" (trailing slash) for prefix match
103    /// - "/path/**" for glob match
104    /// - "/path" for exact match
105    #[serde(default = "default_endpoint")]
106    pub endpoint: String,
107
108    /// HTTP methods to match (optional). Empty means all methods.
109    #[serde(default)]
110    pub methods: Vec<String>,
111
112    /// Time window configuration (optional).
113    #[serde(default)]
114    pub time: Option<TimeConfig>,
115
116    /// IP address/CIDR to match (optional). "*" or omitted means any IP.
117    #[serde(default)]
118    pub ip: Option<String>,
119
120    /// Action to take when rule matches.
121    pub action: ActionConfig,
122
123    /// Optional description for logging/debugging.
124    #[serde(default)]
125    pub description: Option<String>,
126
127    /// Priority (lower = higher priority). Rules are sorted by priority.
128    #[serde(default = "default_priority")]
129    pub priority: i32,
130
131    /// Custom matcher configuration (for use with `MatcherRegistry`).
132    /// When present and a registry is provided, this is used instead of
133    /// `role_mask`/`id`/`ip`/`time` for building the matcher.
134    #[serde(default)]
135    pub matcher: Option<MatcherConfig>,
136}
137
138/// Role mask configuration - can be a number or string.
139#[derive(Debug, Clone, Serialize, Deserialize)]
140#[serde(untagged)]
141pub enum RoleMaskConfig {
142    /// Numeric role mask.
143    Number(u32),
144    /// String role mask (can be hex like "0xFF" or "*" for all).
145    String(String),
146}
147
148impl RoleMaskConfig {
149    /// Convert to u32 bitmask.
150    pub fn to_mask(&self) -> u32 {
151        match self {
152            RoleMaskConfig::Number(n) => *n,
153            RoleMaskConfig::String(s) => {
154                let s = s.trim();
155                if s == "*" || s.eq_ignore_ascii_case("all") {
156                    u32::MAX
157                } else if let Some(hex) = s.strip_prefix("0x") {
158                    u32::from_str_radix(hex, 16).unwrap_or(u32::MAX)
159                } else if let Some(bin) = s.strip_prefix("0b") {
160                    u32::from_str_radix(bin, 2).unwrap_or(u32::MAX)
161                } else {
162                    s.parse().unwrap_or(u32::MAX)
163                }
164            }
165        }
166    }
167}
168
169fn default_role_mask() -> RoleMaskConfig {
170    RoleMaskConfig::Number(u32::MAX)
171}
172
/// Serde default for `RuleConfig::id`: the `"*"` wildcard.
fn default_id() -> String {
    String::from("*")
}

/// Serde default for `RuleConfig::endpoint`: the `"*"` wildcard.
fn default_endpoint() -> String {
    String::from("*")
}

/// Serde default for `RuleConfig::priority`.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
184
185/// Time window configuration.
186#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct TimeConfig {
188    /// Start hour (0-23).
189    #[serde(default)]
190    pub start: Option<u32>,
191    /// End hour (0-23).
192    #[serde(default)]
193    pub end: Option<u32>,
194    /// Days of week (0=Monday, 6=Sunday). Empty means all days.
195    #[serde(default)]
196    pub days: Vec<u32>,
197}
198
199/// Action configuration - can be simple string or complex object.
200#[derive(Debug, Clone, Serialize, Deserialize)]
201#[serde(untagged)]
202pub enum ActionConfig {
203    /// Simple action: "allow", "deny"
204    Simple(SimpleAction),
205    /// Complex action with parameters
206    Complex(ComplexAction),
207}
208
209/// Simple action types.
210#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
211#[serde(rename_all = "lowercase")]
212pub enum SimpleAction {
213    /// Allow the request.
214    Allow,
215    /// Deny with 403 Forbidden.
216    Deny,
217    /// Block (same as deny, alias).
218    Block,
219}
220
221/// Complex action with additional parameters.
222#[derive(Debug, Clone, Serialize, Deserialize)]
223#[serde(tag = "type", rename_all = "lowercase")]
224pub enum ComplexAction {
225    /// Allow the request.
226    Allow,
227    /// Deny with 403.
228    Deny,
229    /// Block (alias for deny).
230    Block,
231    /// Return a custom error response.
232    Error {
233        /// HTTP status code.
234        #[serde(default = "default_error_code")]
235        code: u16,
236        /// Error message body.
237        #[serde(default)]
238        message: Option<String>,
239    },
240    /// Reroute to a different path.
241    Reroute {
242        /// Target path to reroute to.
243        target: String,
244        /// Whether to preserve the original path as a header.
245        #[serde(default)]
246        preserve_path: bool,
247    },
248    /// Rate limit the request.
249    RateLimit {
250        /// Maximum requests per window.
251        max_requests: u32,
252        /// Window duration in seconds.
253        window_secs: u64,
254    },
255    /// Log and allow (for monitoring).
256    Log {
257        /// Log level: "trace", "debug", "info", "warn", "error"
258        #[serde(default = "default_log_level")]
259        level: String,
260        /// Custom log message.
261        #[serde(default)]
262        message: Option<String>,
263    },
264}
265
266/// Configuration for a custom matcher (used with `MatcherRegistry`).
267#[derive(Debug, Clone, Serialize, Deserialize)]
268#[serde(untagged)]
269pub enum MatcherConfig {
270    /// A named matcher reference (e.g., `"authenticated"`).
271    Named(String),
272    /// A parameterized matcher (e.g., `{ type = "scoped_role", param = "id" }`).
273    Parameterized {
274        /// The matcher type name.
275        #[serde(rename = "type")]
276        matcher_type: String,
277        /// Optional parameter name (e.g., path parameter to bind to).
278        #[serde(default)]
279        param: Option<String>,
280    },
281}
282
283/// Registry that resolves `MatcherConfig` values into `RuleMatcher<A>` instances.
284pub trait MatcherRegistry<A>: Send + Sync {
285    /// Resolve a matcher config into a boxed `RuleMatcher`.
286    fn resolve(&self, config: &MatcherConfig) -> Result<Arc<dyn RuleMatcher<A>>, ConfigError>;
287}
288
/// Serde default for the `code` of an `error` action: 403 Forbidden.
fn default_error_code() -> u16 {
    const FORBIDDEN: u16 = 403;
    FORBIDDEN
}

/// Serde default for the `level` of a `log` action.
fn default_log_level() -> String {
    String::from("info")
}
296
297/// Error type for configuration parsing.
298#[derive(Debug, thiserror::Error)]
299pub enum ConfigError {
300    /// TOML parsing error.
301    #[error("Failed to parse TOML: {0}")]
302    TomlParse(#[from] toml::de::Error),
303
304    /// File I/O error.
305    #[error("Failed to read config file: {0}")]
306    FileRead(#[from] std::io::Error),
307
308    /// Invalid configuration.
309    #[error("Invalid configuration: {0}")]
310    Invalid(String),
311
312    /// Invalid IP pattern.
313    #[error("Invalid IP pattern '{0}': {1}")]
314    InvalidIp(String, String),
315
316    /// Invalid action configuration.
317    #[error("Invalid action configuration: {0}")]
318    InvalidAction(String),
319}
320
321impl AclConfig {
322    /// Parse configuration from a TOML string.
323    ///
324    /// # Example
325    /// ```
326    /// use axum_acl::TomlConfig;
327    ///
328    /// let toml = r#"
329    /// [settings]
330    /// default_action = "deny"
331    ///
332    /// [[rules]]
333    /// role = "admin"
334    /// endpoint = "*"
335    /// action = "allow"
336    /// "#;
337    ///
338    /// let config = TomlConfig::from_toml(toml).unwrap();
339    /// ```
340    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
341        let config: AclConfig = toml::from_str(toml_str)?;
342        config.validate()?;
343        Ok(config)
344    }
345
346    /// Load configuration from a TOML file.
347    ///
348    /// # Example
349    /// ```ignore
350    /// use axum_acl::AclConfig;
351    ///
352    /// let config = AclConfig::from_file("config/acl.toml").unwrap();
353    /// ```
354    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
355        let contents = std::fs::read_to_string(path)?;
356        Self::from_toml(&contents)
357    }
358
359    /// Validate the configuration.
360    fn validate(&self) -> Result<(), ConfigError> {
361        for (i, rule) in self.rules.iter().enumerate() {
362            // Validate IP pattern if provided
363            if let Some(ref ip) = rule.ip {
364                if ip != "*" && !ip.eq_ignore_ascii_case("any") {
365                    IpMatcher::parse(ip)
366                        .map_err(|e| ConfigError::InvalidIp(ip.clone(), e))?;
367                }
368            }
369
370            // Validate time config
371            if let Some(ref time) = rule.time {
372                if let Some(start) = time.start {
373                    if start > 23 {
374                        return Err(ConfigError::Invalid(format!(
375                            "Rule {}: start hour {} is invalid (must be 0-23)",
376                            i, start
377                        )));
378                    }
379                }
380                if let Some(end) = time.end {
381                    if end > 23 {
382                        return Err(ConfigError::Invalid(format!(
383                            "Rule {}: end hour {} is invalid (must be 0-23)",
384                            i, end
385                        )));
386                    }
387                }
388                for &day in &time.days {
389                    if day > 6 {
390                        return Err(ConfigError::Invalid(format!(
391                            "Rule {}: day {} is invalid (must be 0-6)",
392                            i, day
393                        )));
394                    }
395                }
396            }
397        }
398        Ok(())
399    }
400
401    /// Convert configuration to an AclTable.
402    ///
403    /// # Example
404    /// ```
405    /// use axum_acl::TomlConfig;
406    ///
407    /// let toml = r#"
408    /// [[rules]]
409    /// role_mask = 1
410    /// endpoint = "*"
411    /// action = "allow"
412    /// "#;
413    ///
414    /// let config = TomlConfig::from_toml(toml).unwrap();
415    /// let table = config.into_table();
416    /// ```
417    pub fn into_table(self) -> AclTable {
418        let default_action = action_config_to_action(&self.settings.default_action);
419
420        // Sort rules by priority (lower = higher priority)
421        let mut rules: Vec<(i32, RuleConfig)> = self
422            .rules
423            .into_iter()
424            .map(|r| (r.priority, r))
425            .collect();
426        rules.sort_by_key(|(p, _)| *p);
427
428        // Build the table using the builder
429        let mut builder = AclTable::builder().default_action(default_action);
430
431        for (_, rule_config) in rules {
432            let endpoint = EndpointPattern::parse(&rule_config.endpoint);
433            let filter = rule_config_to_filter(rule_config);
434
435            // Add to appropriate collection based on endpoint type
436            match endpoint {
437                EndpointPattern::Exact(path) => {
438                    builder = builder.add_exact(path, filter);
439                }
440                pattern => {
441                    builder = builder.add_pattern(pattern, filter);
442                }
443            }
444        }
445
446        builder.build()
447    }
448}
449
450/// Convert ActionConfig to AclAction.
451fn action_config_to_action(config: &ActionConfig) -> AclAction {
452    match config {
453        ActionConfig::Simple(simple) => match simple {
454            SimpleAction::Allow => AclAction::Allow,
455            SimpleAction::Deny | SimpleAction::Block => AclAction::Deny,
456        },
457        ActionConfig::Complex(complex) => match complex {
458            ComplexAction::Allow => AclAction::Allow,
459            ComplexAction::Deny | ComplexAction::Block => AclAction::Deny,
460            ComplexAction::Error { code, message } => AclAction::Error {
461                code: *code,
462                message: message.clone(),
463            },
464            ComplexAction::Reroute {
465                target,
466                preserve_path,
467            } => AclAction::Reroute {
468                target: target.clone(),
469                preserve_path: *preserve_path,
470            },
471            ComplexAction::RateLimit {
472                max_requests,
473                window_secs,
474            } => AclAction::RateLimit {
475                max_requests: *max_requests,
476                window_secs: *window_secs,
477            },
478            ComplexAction::Log { level, message } => AclAction::Log {
479                level: level.clone(),
480                message: message.clone(),
481            },
482        },
483    }
484}
485
486/// Convert RuleConfig to AclRuleFilter.
487fn rule_config_to_filter(config: RuleConfig) -> AclRuleFilter {
488    let time = config.time.map(|t| {
489        if t.start.is_none() && t.end.is_none() && t.days.is_empty() {
490            TimeWindow::any()
491        } else {
492            TimeWindow {
493                start: t.start.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
494                end: t.end.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
495                days: t.days,
496            }
497        }
498    }).unwrap_or_else(TimeWindow::any);
499
500    let ip = config
501        .ip
502        .map(|s| IpMatcher::parse(&s).unwrap_or(IpMatcher::Any))
503        .unwrap_or(IpMatcher::Any);
504
505    let action = action_config_to_action(&config.action);
506
507    let methods: Vec<http::Method> = config
508        .methods
509        .iter()
510        .filter_map(|m| m.parse::<http::Method>().ok())
511        .collect();
512
513    let mut filter = AclRuleFilter::new()
514        .id(config.id)
515        .role_mask(config.role_mask.to_mask())
516        .methods(methods)
517        .time(time)
518        .ip(ip)
519        .action(action);
520
521    if let Some(desc) = config.description {
522        filter = filter.description(desc);
523    }
524
525    filter
526}
527
528impl AclConfig {
529    /// Convert configuration into a generic `AclTable<A>` using a `MatcherRegistry`.
530    ///
531    /// Rules with a `matcher` field are resolved via the registry.
532    /// Rules without a `matcher` field are skipped (they only work with the
533    /// bitmask-based `into_table()` method).
534    pub fn into_generic_table<A: Send + Sync + 'static>(
535        self,
536        registry: &dyn MatcherRegistry<A>,
537    ) -> Result<AclTable<A>, ConfigError> {
538        let default_action = action_config_to_action(&self.settings.default_action);
539
540        let mut rules: Vec<(i32, RuleConfig)> = self
541            .rules
542            .into_iter()
543            .map(|r| (r.priority, r))
544            .collect();
545        rules.sort_by_key(|(p, _)| *p);
546
547        let mut exact_rules: std::collections::HashMap<String, Vec<AclRule<A>>> =
548            std::collections::HashMap::new();
549        let mut pattern_rules: Vec<(EndpointPattern, Vec<AclRule<A>>)> = Vec::new();
550
551        for (_, rule_config) in rules {
552            let Some(ref matcher_config) = rule_config.matcher else {
553                continue;
554            };
555
556            let matcher = registry.resolve(matcher_config)?;
557
558            let methods: Vec<http::Method> = rule_config
559                .methods
560                .iter()
561                .filter_map(|m| m.parse::<http::Method>().ok())
562                .collect();
563
564            let action = action_config_to_action(&rule_config.action);
565
566            let rule = if methods.is_empty() {
567                AclRule::from_matcher(matcher)
568            } else {
569                AclRule::from_matcher_with_methods(matcher, methods, action)
570            };
571
572            let endpoint = EndpointPattern::parse(&rule_config.endpoint);
573            match endpoint {
574                EndpointPattern::Exact(path) => {
575                    exact_rules.entry(path).or_default().push(rule);
576                }
577                pattern => {
578                    let mut found = false;
579                    for (existing, rules) in &mut pattern_rules {
580                        let is_match = match (existing, &pattern) {
581                            (EndpointPattern::Any, EndpointPattern::Any) => true,
582                            (EndpointPattern::Prefix(a), EndpointPattern::Prefix(b)) => a == b,
583                            (EndpointPattern::Glob(a), EndpointPattern::Glob(b)) => a == b,
584                            (EndpointPattern::Exact(a), EndpointPattern::Exact(b)) => a == b,
585                            _ => false,
586                        };
587                        if is_match {
588                            rules.push(rule.clone());
589                            found = true;
590                            break;
591                        }
592                    }
593                    if !found {
594                        pattern_rules.push((pattern, vec![rule]));
595                    }
596                }
597            }
598        }
599
600        Ok(AclTable {
601            exact_rules,
602            pattern_rules,
603            default_action,
604        })
605    }
606}
607
608impl AclTable {
609    /// Create an AclTable from a TOML configuration string.
610    ///
611    /// This is the recommended way to load embedded configuration.
612    ///
613    /// # Example
614    /// ```
615    /// use axum_acl::AclTable;
616    ///
617    /// // Compile-time embedded config
618    /// const CONFIG: &str = r#"
619    /// [settings]
620    /// default_action = "deny"
621    ///
622    /// [[rules]]
623    /// role_mask = 1
624    /// endpoint = "*"
625    /// action = "allow"
626    ///
627    /// [[rules]]
628    /// role_mask = 2
629    /// endpoint = "/api/**"
630    /// action = "allow"
631    /// "#;
632    ///
633    /// let table = AclTable::from_toml(CONFIG).unwrap();
634    /// ```
635    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
636        let config = AclConfig::from_toml(toml_str)?;
637        Ok(config.into_table())
638    }
639
640    /// Create an AclTable from a TOML configuration file.
641    ///
642    /// # Example
643    /// ```ignore
644    /// use axum_acl::AclTable;
645    ///
646    /// let table = AclTable::from_toml_file("config/acl.toml").unwrap();
647    /// ```
648    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
649        let config = AclConfig::from_file(path)?;
650        Ok(config.into_table())
651    }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;
    use crate::rule::RequestContext;

    /// Two rules on wildcard endpoints parse and land in the pattern (non-exact) index.
    #[test]
    fn test_parse_simple_config() {
        let toml = r#"
[settings]
default_action = "deny"

[[rules]]
role_mask = 1
endpoint = "*"
action = "allow"
description = "Admin access"

[[rules]]
role_mask = 2
endpoint = "/api/**"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules.len(), 2);

        let table = config.into_table();
        // Check the table has pattern rules (since endpoints use * and **)
        assert!(!table.pattern_rules().is_empty());
    }

    /// Table-form actions parse, and an `error` action is returned on evaluation.
    #[test]
    fn test_parse_complex_actions() {
        let toml = r#"
[[rules]]
endpoint = "/error"
action = { type = "error", code = 418, message = "I'm a teapot" }

[[rules]]
endpoint = "/redirect"
action = { type = "reroute", target = "/new-path", preserve_path = true }
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules.len(), 2);

        let table = config.into_table();
        // Check error action is returned for /error path
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        let ctx = RequestContext::new(u32::MAX, ip, "*");
        let action = table.evaluate("/error", &ctx);
        match action {
            AclAction::Error { code, message } => {
                assert_eq!(code, 418);
                assert_eq!(message.as_deref(), Some("I'm a teapot"));
            }
            _ => panic!("Expected Error action"),
        }
    }

    /// A rule with a time window parses into the pattern index.
    #[test]
    fn test_parse_time_config() {
        let toml = r#"
[[rules]]
role_mask = 2
endpoint = "/api/**"
time = { start = 9, end = 17, days = [0, 1, 2, 3, 4] }
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        let table = config.into_table();

        // The table should have pattern rules with time config
        assert!(!table.pattern_rules().is_empty());
    }

    /// A CIDR-restricted rule allows addresses inside the range and falls back to
    /// the default deny outside it.
    #[test]
    fn test_parse_ip_config() {
        let toml = r#"
[[rules]]
endpoint = "/internal/**"
ip = "192.168.1.0/24"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        let table = config.into_table();

        // Verify IP filtering works via evaluation
        let internal_ip: IpAddr = "192.168.1.50".parse().unwrap();
        let external_ip: IpAddr = "10.0.0.1".parse().unwrap();
        let ctx_internal = RequestContext::new(u32::MAX, internal_ip, "*");
        let ctx_external = RequestContext::new(u32::MAX, external_ip, "*");

        assert_eq!(table.evaluate("/internal/foo", &ctx_internal), AclAction::Allow);
        assert_eq!(table.evaluate("/internal/foo", &ctx_external), AclAction::Deny);
    }

    /// Decimal, hex and wildcard role masks resolve to the expected bitmasks.
    #[test]
    fn test_role_mask_formats() {
        let toml = r#"
[[rules]]
role_mask = 3
endpoint = "/decimal"
action = "allow"

[[rules]]
role_mask = "0xFF"
endpoint = "/hex"
action = "allow"

[[rules]]
role_mask = "*"
endpoint = "/all"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules[0].role_mask.to_mask(), 3);
        assert_eq!(config.rules[1].role_mask.to_mask(), 0xFF);
        assert_eq!(config.rules[2].role_mask.to_mask(), u32::MAX);
    }

    /// The remaining string forms: binary, decimal text and a padded, mixed-case "all".
    #[test]
    fn test_role_mask_string_forms() {
        let mask = |s: &str| RoleMaskConfig::String(s.to_string()).to_mask();
        assert_eq!(mask("0b101"), 5);
        assert_eq!(mask("12"), 12);
        assert_eq!(mask(" ALL "), u32::MAX);
    }

    /// An empty document is a valid config: no rules, and the default action is deny.
    #[test]
    fn test_empty_config_uses_defaults() {
        let config = AclConfig::from_toml("").unwrap();
        assert!(config.rules.is_empty());
        assert!(matches!(
            config.settings.default_action,
            ActionConfig::Simple(SimpleAction::Deny)
        ));
    }

    /// Every optional rule key falls back to its documented default.
    #[test]
    fn test_rule_field_defaults() {
        let config = AclConfig::from_toml("[[rules]]\naction = \"allow\"\n").unwrap();
        let rule = &config.rules[0];
        assert_eq!(rule.role_mask.to_mask(), u32::MAX);
        assert_eq!(rule.id, "*");
        assert_eq!(rule.endpoint, "*");
        assert!(rule.methods.is_empty());
        assert_eq!(rule.priority, 100);
        assert!(rule.ip.is_none());
        assert!(rule.time.is_none());
        assert!(rule.matcher.is_none());
    }

    /// Out-of-range hours and days are rejected during validation.
    #[test]
    fn test_rejects_out_of_range_time() {
        for time in ["{ start = 24 }", "{ end = 24 }", "{ days = [7] }"] {
            let toml = format!("[[rules]]\ntime = {}\naction = \"allow\"\n", time);
            assert!(
                matches!(AclConfig::from_toml(&toml), Err(ConfigError::Invalid(_))),
                "expected rejection for time = {}",
                time
            );
        }
    }
}