axum_acl/
config.rs

1//! TOML configuration support for ACL rules.
2//!
3//! This module provides structures for loading ACL rules from TOML configuration,
4//! either compiled-in at build time or read from a file at runtime.
5//!
//! # Example TOML Format
//!
//! Roles are selected with the `role_mask` bitmask key. Keys the parser does not
//! recognize (for example `role`) are silently ignored, which leaves `role_mask`
//! at its default of "all roles" — so always spell the key `role_mask`.
//!
//! ```toml
//! [settings]
//! default_action = "deny"
//!
//! [[rules]]
//! role_mask = 1          # roles with bit 0 set
//! endpoint = "*"
//! action = "allow"
//! description = "Admins have full access"
//!
//! [[rules]]
//! role_mask = 2          # roles with bit 1 set
//! endpoint = "/api/**"
//! time = { start = 9, end = 17, days = [0,1,2,3,4] }
//! action = "allow"
//!
//! [[rules]]
//! role_mask = "*"
//! endpoint = "/blocked/**"
//! action = { type = "error", code = 403, message = "Access forbidden" }
//! ```
29//!
30//! # Usage
31//!
32//! ## Compile-time embedded config
33//!
34//! ```ignore
35//! use axum_acl::AclTable;
36//!
37//! // Embed at compile time
38//! const ACL_CONFIG: &str = include_str!("../acl.toml");
39//!
40//! let table = AclTable::from_toml(ACL_CONFIG).unwrap();
41//! ```
42//!
43//! ## Runtime file loading
44//!
45//! ```ignore
46//! use axum_acl::AclTable;
47//!
48//! let table = AclTable::from_toml_file("config/acl.toml").unwrap();
49//! ```
50
51use crate::rule::{AclAction, AclRuleFilter, EndpointPattern, IpMatcher, TimeWindow};
52use crate::table::AclTable;
53use serde::{Deserialize, Serialize};
54use std::path::Path;
55
56/// Configuration file structure.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct AclConfig {
59    /// Global settings.
60    #[serde(default)]
61    pub settings: ConfigSettings,
62    /// List of ACL rules.
63    #[serde(default)]
64    pub rules: Vec<RuleConfig>,
65}
66
67/// Global configuration settings.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ConfigSettings {
70    /// Default action when no rules match.
71    #[serde(default = "default_action")]
72    pub default_action: ActionConfig,
73}
74
75fn default_action() -> ActionConfig {
76    ActionConfig::Simple(SimpleAction::Deny)
77}
78
79impl Default for ConfigSettings {
80    fn default() -> Self {
81        Self {
82            default_action: default_action(),
83        }
84    }
85}
86
87/// A single rule configuration.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct RuleConfig {
90    /// Role bitmask. Use 0xFFFFFFFF for all roles, or specific bits like 0b11.
91    /// Can be decimal (e.g., 3), hex (e.g., "0x3"), or binary string.
92    #[serde(default = "default_role_mask")]
93    pub role_mask: RoleMaskConfig,
94
95    /// ID to match. Use "*" for any ID.
96    #[serde(default = "default_id")]
97    pub id: String,
98
99    /// Endpoint pattern to match.
100    /// - "*" or "any" for all endpoints
101    /// - "/path/" (trailing slash) for prefix match
102    /// - "/path/**" for glob match
103    /// - "/path" for exact match
104    #[serde(default = "default_endpoint")]
105    pub endpoint: String,
106
107    /// HTTP methods to match (optional). Empty means all methods.
108    #[serde(default)]
109    pub methods: Vec<String>,
110
111    /// Time window configuration (optional).
112    #[serde(default)]
113    pub time: Option<TimeConfig>,
114
115    /// IP address/CIDR to match (optional). "*" or omitted means any IP.
116    #[serde(default)]
117    pub ip: Option<String>,
118
119    /// Action to take when rule matches.
120    pub action: ActionConfig,
121
122    /// Optional description for logging/debugging.
123    #[serde(default)]
124    pub description: Option<String>,
125
126    /// Priority (lower = higher priority). Rules are sorted by priority.
127    #[serde(default = "default_priority")]
128    pub priority: i32,
129}
130
131/// Role mask configuration - can be a number or string.
132#[derive(Debug, Clone, Serialize, Deserialize)]
133#[serde(untagged)]
134pub enum RoleMaskConfig {
135    /// Numeric role mask.
136    Number(u32),
137    /// String role mask (can be hex like "0xFF" or "*" for all).
138    String(String),
139}
140
141impl RoleMaskConfig {
142    /// Convert to u32 bitmask.
143    pub fn to_mask(&self) -> u32 {
144        match self {
145            RoleMaskConfig::Number(n) => *n,
146            RoleMaskConfig::String(s) => {
147                let s = s.trim();
148                if s == "*" || s.eq_ignore_ascii_case("all") {
149                    u32::MAX
150                } else if let Some(hex) = s.strip_prefix("0x") {
151                    u32::from_str_radix(hex, 16).unwrap_or(u32::MAX)
152                } else if let Some(bin) = s.strip_prefix("0b") {
153                    u32::from_str_radix(bin, 2).unwrap_or(u32::MAX)
154                } else {
155                    s.parse().unwrap_or(u32::MAX)
156                }
157            }
158        }
159    }
160}
161
162fn default_role_mask() -> RoleMaskConfig {
163    RoleMaskConfig::Number(u32::MAX)
164}
165
/// Serde default for `RuleConfig::id`: the `"*"` wildcard.
fn default_id() -> String {
    String::from("*")
}

/// Serde default for `RuleConfig::endpoint`: the `"*"` wildcard.
fn default_endpoint() -> String {
    String::from("*")
}

/// Serde default for `RuleConfig::priority`.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
177
178/// Time window configuration.
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct TimeConfig {
181    /// Start hour (0-23).
182    #[serde(default)]
183    pub start: Option<u32>,
184    /// End hour (0-23).
185    #[serde(default)]
186    pub end: Option<u32>,
187    /// Days of week (0=Monday, 6=Sunday). Empty means all days.
188    #[serde(default)]
189    pub days: Vec<u32>,
190}
191
192/// Action configuration - can be simple string or complex object.
193#[derive(Debug, Clone, Serialize, Deserialize)]
194#[serde(untagged)]
195pub enum ActionConfig {
196    /// Simple action: "allow", "deny"
197    Simple(SimpleAction),
198    /// Complex action with parameters
199    Complex(ComplexAction),
200}
201
202/// Simple action types.
203#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
204#[serde(rename_all = "lowercase")]
205pub enum SimpleAction {
206    /// Allow the request.
207    Allow,
208    /// Deny with 403 Forbidden.
209    Deny,
210    /// Block (same as deny, alias).
211    Block,
212}
213
214/// Complex action with additional parameters.
215#[derive(Debug, Clone, Serialize, Deserialize)]
216#[serde(tag = "type", rename_all = "lowercase")]
217pub enum ComplexAction {
218    /// Allow the request.
219    Allow,
220    /// Deny with 403.
221    Deny,
222    /// Block (alias for deny).
223    Block,
224    /// Return a custom error response.
225    Error {
226        /// HTTP status code.
227        #[serde(default = "default_error_code")]
228        code: u16,
229        /// Error message body.
230        #[serde(default)]
231        message: Option<String>,
232    },
233    /// Reroute to a different path.
234    Reroute {
235        /// Target path to reroute to.
236        target: String,
237        /// Whether to preserve the original path as a header.
238        #[serde(default)]
239        preserve_path: bool,
240    },
241    /// Rate limit the request.
242    RateLimit {
243        /// Maximum requests per window.
244        max_requests: u32,
245        /// Window duration in seconds.
246        window_secs: u64,
247    },
248    /// Log and allow (for monitoring).
249    Log {
250        /// Log level: "trace", "debug", "info", "warn", "error"
251        #[serde(default = "default_log_level")]
252        level: String,
253        /// Custom log message.
254        #[serde(default)]
255        message: Option<String>,
256    },
257}
258
/// Serde default for the `code` of an `error` action: 403 Forbidden.
fn default_error_code() -> u16 {
    const FORBIDDEN: u16 = 403;
    FORBIDDEN
}

/// Serde default for the `level` of a `log` action.
fn default_log_level() -> String {
    String::from("info")
}
266
267/// Error type for configuration parsing.
268#[derive(Debug, thiserror::Error)]
269pub enum ConfigError {
270    /// TOML parsing error.
271    #[error("Failed to parse TOML: {0}")]
272    TomlParse(#[from] toml::de::Error),
273
274    /// File I/O error.
275    #[error("Failed to read config file: {0}")]
276    FileRead(#[from] std::io::Error),
277
278    /// Invalid configuration.
279    #[error("Invalid configuration: {0}")]
280    Invalid(String),
281
282    /// Invalid IP pattern.
283    #[error("Invalid IP pattern '{0}': {1}")]
284    InvalidIp(String, String),
285
286    /// Invalid action configuration.
287    #[error("Invalid action configuration: {0}")]
288    InvalidAction(String),
289}
290
291impl AclConfig {
292    /// Parse configuration from a TOML string.
293    ///
294    /// # Example
295    /// ```
296    /// use axum_acl::TomlConfig;
297    ///
298    /// let toml = r#"
299    /// [settings]
300    /// default_action = "deny"
301    ///
302    /// [[rules]]
303    /// role = "admin"
304    /// endpoint = "*"
305    /// action = "allow"
306    /// "#;
307    ///
308    /// let config = TomlConfig::from_toml(toml).unwrap();
309    /// ```
310    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
311        let config: AclConfig = toml::from_str(toml_str)?;
312        config.validate()?;
313        Ok(config)
314    }
315
316    /// Load configuration from a TOML file.
317    ///
318    /// # Example
319    /// ```ignore
320    /// use axum_acl::AclConfig;
321    ///
322    /// let config = AclConfig::from_file("config/acl.toml").unwrap();
323    /// ```
324    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
325        let contents = std::fs::read_to_string(path)?;
326        Self::from_toml(&contents)
327    }
328
329    /// Validate the configuration.
330    fn validate(&self) -> Result<(), ConfigError> {
331        for (i, rule) in self.rules.iter().enumerate() {
332            // Validate IP pattern if provided
333            if let Some(ref ip) = rule.ip {
334                if ip != "*" && !ip.eq_ignore_ascii_case("any") {
335                    IpMatcher::parse(ip)
336                        .map_err(|e| ConfigError::InvalidIp(ip.clone(), e))?;
337                }
338            }
339
340            // Validate time config
341            if let Some(ref time) = rule.time {
342                if let Some(start) = time.start {
343                    if start > 23 {
344                        return Err(ConfigError::Invalid(format!(
345                            "Rule {}: start hour {} is invalid (must be 0-23)",
346                            i, start
347                        )));
348                    }
349                }
350                if let Some(end) = time.end {
351                    if end > 23 {
352                        return Err(ConfigError::Invalid(format!(
353                            "Rule {}: end hour {} is invalid (must be 0-23)",
354                            i, end
355                        )));
356                    }
357                }
358                for &day in &time.days {
359                    if day > 6 {
360                        return Err(ConfigError::Invalid(format!(
361                            "Rule {}: day {} is invalid (must be 0-6)",
362                            i, day
363                        )));
364                    }
365                }
366            }
367        }
368        Ok(())
369    }
370
371    /// Convert configuration to an AclTable.
372    ///
373    /// # Example
374    /// ```
375    /// use axum_acl::TomlConfig;
376    ///
377    /// let toml = r#"
378    /// [[rules]]
379    /// role_mask = 1
380    /// endpoint = "*"
381    /// action = "allow"
382    /// "#;
383    ///
384    /// let config = TomlConfig::from_toml(toml).unwrap();
385    /// let table = config.into_table();
386    /// ```
387    pub fn into_table(self) -> AclTable {
388        let default_action = action_config_to_action(&self.settings.default_action);
389
390        // Sort rules by priority (lower = higher priority)
391        let mut rules: Vec<(i32, RuleConfig)> = self
392            .rules
393            .into_iter()
394            .map(|r| (r.priority, r))
395            .collect();
396        rules.sort_by_key(|(p, _)| *p);
397
398        // Build the table using the builder
399        let mut builder = AclTable::builder().default_action(default_action);
400
401        for (_, rule_config) in rules {
402            let endpoint = EndpointPattern::parse(&rule_config.endpoint);
403            let filter = rule_config_to_filter(rule_config);
404
405            // Add to appropriate collection based on endpoint type
406            match endpoint {
407                EndpointPattern::Exact(path) => {
408                    builder = builder.add_exact(path, filter);
409                }
410                pattern => {
411                    builder = builder.add_pattern(pattern, filter);
412                }
413            }
414        }
415
416        builder.build()
417    }
418}
419
420/// Convert ActionConfig to AclAction.
421fn action_config_to_action(config: &ActionConfig) -> AclAction {
422    match config {
423        ActionConfig::Simple(simple) => match simple {
424            SimpleAction::Allow => AclAction::Allow,
425            SimpleAction::Deny | SimpleAction::Block => AclAction::Deny,
426        },
427        ActionConfig::Complex(complex) => match complex {
428            ComplexAction::Allow => AclAction::Allow,
429            ComplexAction::Deny | ComplexAction::Block => AclAction::Deny,
430            ComplexAction::Error { code, message } => AclAction::Error {
431                code: *code,
432                message: message.clone(),
433            },
434            ComplexAction::Reroute {
435                target,
436                preserve_path,
437            } => AclAction::Reroute {
438                target: target.clone(),
439                preserve_path: *preserve_path,
440            },
441            ComplexAction::RateLimit {
442                max_requests,
443                window_secs,
444            } => AclAction::RateLimit {
445                max_requests: *max_requests,
446                window_secs: *window_secs,
447            },
448            ComplexAction::Log { level, message } => AclAction::Log {
449                level: level.clone(),
450                message: message.clone(),
451            },
452        },
453    }
454}
455
456/// Convert RuleConfig to AclRuleFilter.
457fn rule_config_to_filter(config: RuleConfig) -> AclRuleFilter {
458    let time = config.time.map(|t| {
459        if t.start.is_none() && t.end.is_none() && t.days.is_empty() {
460            TimeWindow::any()
461        } else {
462            TimeWindow {
463                start: t.start.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
464                end: t.end.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
465                days: t.days,
466            }
467        }
468    }).unwrap_or_else(TimeWindow::any);
469
470    let ip = config
471        .ip
472        .map(|s| IpMatcher::parse(&s).unwrap_or(IpMatcher::Any))
473        .unwrap_or(IpMatcher::Any);
474
475    let action = action_config_to_action(&config.action);
476
477    let mut filter = AclRuleFilter::new()
478        .id(config.id)
479        .role_mask(config.role_mask.to_mask())
480        .time(time)
481        .ip(ip)
482        .action(action);
483
484    if let Some(desc) = config.description {
485        filter = filter.description(desc);
486    }
487
488    filter
489}
490
491impl AclTable {
492    /// Create an AclTable from a TOML configuration string.
493    ///
494    /// This is the recommended way to load embedded configuration.
495    ///
496    /// # Example
497    /// ```
498    /// use axum_acl::AclTable;
499    ///
500    /// // Compile-time embedded config
501    /// const CONFIG: &str = r#"
502    /// [settings]
503    /// default_action = "deny"
504    ///
505    /// [[rules]]
506    /// role_mask = 1
507    /// endpoint = "*"
508    /// action = "allow"
509    ///
510    /// [[rules]]
511    /// role_mask = 2
512    /// endpoint = "/api/**"
513    /// action = "allow"
514    /// "#;
515    ///
516    /// let table = AclTable::from_toml(CONFIG).unwrap();
517    /// ```
518    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
519        let config = AclConfig::from_toml(toml_str)?;
520        Ok(config.into_table())
521    }
522
523    /// Create an AclTable from a TOML configuration file.
524    ///
525    /// # Example
526    /// ```ignore
527    /// use axum_acl::AclTable;
528    ///
529    /// let table = AclTable::from_toml_file("config/acl.toml").unwrap();
530    /// ```
531    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
532        let config = AclConfig::from_file(path)?;
533        Ok(config.into_table())
534    }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::RequestContext;
    use std::net::IpAddr;

    /// Context that matches every role and ID, from a fixed loopback address.
    fn any_ctx() -> RequestContext {
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        RequestContext::new(u32::MAX, ip, "*")
    }

    #[test]
    fn test_parse_simple_config() {
        let toml = r#"
[settings]
default_action = "deny"

[[rules]]
role_mask = 1
endpoint = "*"
action = "allow"
description = "Admin access"

[[rules]]
role_mask = 2
endpoint = "/api/**"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules.len(), 2);

        let table = config.into_table();
        // Both endpoints are wildcard/glob patterns, so they land in pattern_rules.
        assert!(!table.pattern_rules().is_empty());
    }

    #[test]
    fn test_defaults_applied() {
        let toml = r#"
[[rules]]
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert!(matches!(
            config.settings.default_action,
            ActionConfig::Simple(SimpleAction::Deny)
        ));

        let rule = &config.rules[0];
        assert_eq!(rule.id, "*");
        assert_eq!(rule.endpoint, "*");
        assert_eq!(rule.priority, 100);
        assert_eq!(rule.role_mask.to_mask(), u32::MAX);
        assert!(rule.methods.is_empty());
        assert!(rule.time.is_none());
        assert!(rule.ip.is_none());
        assert!(rule.description.is_none());
    }

    #[test]
    fn test_missing_action_is_parse_error() {
        let toml = r#"
[[rules]]
endpoint = "/no-action"
"#;

        assert!(matches!(
            AclConfig::from_toml(toml),
            Err(ConfigError::TomlParse(_))
        ));
    }

    #[test]
    fn test_parse_complex_actions() {
        let toml = r#"
[[rules]]
endpoint = "/error"
action = { type = "error", code = 418, message = "I'm a teapot" }

[[rules]]
endpoint = "/redirect"
action = { type = "reroute", target = "/new-path", preserve_path = true }
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules.len(), 2);

        let table = config.into_table();
        let ctx = any_ctx();

        match table.evaluate("/error", &ctx) {
            AclAction::Error { code, message } => {
                assert_eq!(code, 418);
                assert_eq!(message.as_deref(), Some("I'm a teapot"));
            }
            _ => panic!("Expected Error action"),
        }

        match table.evaluate("/redirect", &ctx) {
            AclAction::Reroute {
                target,
                preserve_path,
            } => {
                assert_eq!(target, "/new-path");
                assert!(preserve_path);
            }
            _ => panic!("Expected Reroute action"),
        }
    }

    #[test]
    fn test_parse_time_config() {
        let toml = r#"
[[rules]]
role_mask = 2
endpoint = "/api/**"
time = { start = 9, end = 17, days = [0, 1, 2, 3, 4] }
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        let table = config.into_table();

        // The table should have pattern rules with time config
        assert!(!table.pattern_rules().is_empty());
    }

    #[test]
    fn test_invalid_time_rejected() {
        let bad_hour = r#"
[[rules]]
time = { start = 24 }
action = "allow"
"#;
        assert!(matches!(
            AclConfig::from_toml(bad_hour),
            Err(ConfigError::Invalid(_))
        ));

        let bad_day = r#"
[[rules]]
time = { days = [7] }
action = "allow"
"#;
        assert!(matches!(
            AclConfig::from_toml(bad_day),
            Err(ConfigError::Invalid(_))
        ));
    }

    #[test]
    fn test_parse_ip_config() {
        let toml = r#"
[[rules]]
endpoint = "/internal/**"
ip = "192.168.1.0/24"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        let table = config.into_table();

        // Check the filter has correct IP matcher
        let (_, filters) = &table.pattern_rules()[0];
        match &filters[0].ip {
            IpMatcher::Network(_) => {}
            _ => panic!("Expected Network IP matcher"),
        }
    }

    #[test]
    fn test_invalid_ip_rejected() {
        let toml = r#"
[[rules]]
ip = "not-an-ip"
action = "allow"
"#;

        assert!(matches!(
            AclConfig::from_toml(toml),
            Err(ConfigError::InvalidIp(pattern, _)) if pattern == "not-an-ip"
        ));
    }

    #[test]
    fn test_role_mask_formats() {
        let toml = r#"
[[rules]]
role_mask = 3
endpoint = "/decimal"
action = "allow"

[[rules]]
role_mask = "0xFF"
endpoint = "/hex"
action = "allow"

[[rules]]
role_mask = "*"
endpoint = "/all"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules[0].role_mask.to_mask(), 3);
        assert_eq!(config.rules[1].role_mask.to_mask(), 0xFF);
        assert_eq!(config.rules[2].role_mask.to_mask(), u32::MAX);
    }
}