1use crate::rule::{AclAction, AclRuleFilter, EndpointPattern, IpMatcher, RuleMatcher, TimeWindow};
52use crate::table::{AclRule, AclTable};
53use serde::{Deserialize, Serialize};
54use std::path::Path;
55use std::sync::Arc;
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct AclConfig {
60 #[serde(default)]
62 pub settings: ConfigSettings,
63 #[serde(default)]
65 pub rules: Vec<RuleConfig>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct ConfigSettings {
71 #[serde(default = "default_action")]
73 pub default_action: ActionConfig,
74}
75
76fn default_action() -> ActionConfig {
77 ActionConfig::Simple(SimpleAction::Deny)
78}
79
80impl Default for ConfigSettings {
81 fn default() -> Self {
82 Self {
83 default_action: default_action(),
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct RuleConfig {
91 #[serde(default = "default_role_mask")]
94 pub role_mask: RoleMaskConfig,
95
96 #[serde(default = "default_id")]
98 pub id: String,
99
100 #[serde(default = "default_endpoint")]
106 pub endpoint: String,
107
108 #[serde(default)]
110 pub methods: Vec<String>,
111
112 #[serde(default)]
114 pub time: Option<TimeConfig>,
115
116 #[serde(default)]
118 pub ip: Option<String>,
119
120 pub action: ActionConfig,
122
123 #[serde(default)]
125 pub description: Option<String>,
126
127 #[serde(default = "default_priority")]
129 pub priority: i32,
130
131 #[serde(default)]
135 pub matcher: Option<MatcherConfig>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140#[serde(untagged)]
141pub enum RoleMaskConfig {
142 Number(u32),
144 String(String),
146}
147
148impl RoleMaskConfig {
149 pub fn to_mask(&self) -> u32 {
151 match self {
152 RoleMaskConfig::Number(n) => *n,
153 RoleMaskConfig::String(s) => {
154 let s = s.trim();
155 if s == "*" || s.eq_ignore_ascii_case("all") {
156 u32::MAX
157 } else if let Some(hex) = s.strip_prefix("0x") {
158 u32::from_str_radix(hex, 16).unwrap_or(u32::MAX)
159 } else if let Some(bin) = s.strip_prefix("0b") {
160 u32::from_str_radix(bin, 2).unwrap_or(u32::MAX)
161 } else {
162 s.parse().unwrap_or(u32::MAX)
163 }
164 }
165 }
166 }
167}
168
169fn default_role_mask() -> RoleMaskConfig {
170 RoleMaskConfig::Number(u32::MAX)
171}
172
/// Serde default for `RuleConfig::id`: the wildcard `"*"`.
fn default_id() -> String {
    String::from("*")
}

/// Serde default for `RuleConfig::endpoint`: the wildcard `"*"`.
fn default_endpoint() -> String {
    String::from("*")
}

/// Serde default for `RuleConfig::priority`.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct TimeConfig {
188 #[serde(default)]
190 pub start: Option<u32>,
191 #[serde(default)]
193 pub end: Option<u32>,
194 #[serde(default)]
196 pub days: Vec<u32>,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
201#[serde(untagged)]
202pub enum ActionConfig {
203 Simple(SimpleAction),
205 Complex(ComplexAction),
207}
208
209#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
211#[serde(rename_all = "lowercase")]
212pub enum SimpleAction {
213 Allow,
215 Deny,
217 Block,
219}
220
221#[derive(Debug, Clone, Serialize, Deserialize)]
223#[serde(tag = "type", rename_all = "lowercase")]
224pub enum ComplexAction {
225 Allow,
227 Deny,
229 Block,
231 Error {
233 #[serde(default = "default_error_code")]
235 code: u16,
236 #[serde(default)]
238 message: Option<String>,
239 },
240 Reroute {
242 target: String,
244 #[serde(default)]
246 preserve_path: bool,
247 },
248 RateLimit {
250 max_requests: u32,
252 window_secs: u64,
254 },
255 Log {
257 #[serde(default = "default_log_level")]
259 level: String,
260 #[serde(default)]
262 message: Option<String>,
263 },
264}
265
266#[derive(Debug, Clone, Serialize, Deserialize)]
268#[serde(untagged)]
269pub enum MatcherConfig {
270 Named(String),
272 Parameterized {
274 #[serde(rename = "type")]
276 matcher_type: String,
277 #[serde(default)]
279 param: Option<String>,
280 },
281}
282
283pub trait MatcherRegistry<A>: Send + Sync {
285 fn resolve(&self, config: &MatcherConfig) -> Result<Arc<dyn RuleMatcher<A>>, ConfigError>;
287}
288
/// Serde default for `ComplexAction::Error::code`.
fn default_error_code() -> u16 {
    const FORBIDDEN: u16 = 403;
    FORBIDDEN
}

/// Serde default for `ComplexAction::Log::level`.
fn default_log_level() -> String {
    String::from("info")
}
296
297#[derive(Debug, thiserror::Error)]
299pub enum ConfigError {
300 #[error("Failed to parse TOML: {0}")]
302 TomlParse(#[from] toml::de::Error),
303
304 #[error("Failed to read config file: {0}")]
306 FileRead(#[from] std::io::Error),
307
308 #[error("Invalid configuration: {0}")]
310 Invalid(String),
311
312 #[error("Invalid IP pattern '{0}': {1}")]
314 InvalidIp(String, String),
315
316 #[error("Invalid action configuration: {0}")]
318 InvalidAction(String),
319}
320
321impl AclConfig {
322 pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
341 let config: AclConfig = toml::from_str(toml_str)?;
342 config.validate()?;
343 Ok(config)
344 }
345
346 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
355 let contents = std::fs::read_to_string(path)?;
356 Self::from_toml(&contents)
357 }
358
359 fn validate(&self) -> Result<(), ConfigError> {
361 for (i, rule) in self.rules.iter().enumerate() {
362 if let Some(ref ip) = rule.ip {
364 if ip != "*" && !ip.eq_ignore_ascii_case("any") {
365 IpMatcher::parse(ip)
366 .map_err(|e| ConfigError::InvalidIp(ip.clone(), e))?;
367 }
368 }
369
370 if let Some(ref time) = rule.time {
372 if let Some(start) = time.start {
373 if start > 23 {
374 return Err(ConfigError::Invalid(format!(
375 "Rule {}: start hour {} is invalid (must be 0-23)",
376 i, start
377 )));
378 }
379 }
380 if let Some(end) = time.end {
381 if end > 23 {
382 return Err(ConfigError::Invalid(format!(
383 "Rule {}: end hour {} is invalid (must be 0-23)",
384 i, end
385 )));
386 }
387 }
388 for &day in &time.days {
389 if day > 6 {
390 return Err(ConfigError::Invalid(format!(
391 "Rule {}: day {} is invalid (must be 0-6)",
392 i, day
393 )));
394 }
395 }
396 }
397 }
398 Ok(())
399 }
400
401 pub fn into_table(self) -> AclTable {
418 let default_action = action_config_to_action(&self.settings.default_action);
419
420 let mut rules: Vec<(i32, RuleConfig)> = self
422 .rules
423 .into_iter()
424 .map(|r| (r.priority, r))
425 .collect();
426 rules.sort_by_key(|(p, _)| *p);
427
428 let mut builder = AclTable::builder().default_action(default_action);
430
431 for (_, rule_config) in rules {
432 let endpoint = EndpointPattern::parse(&rule_config.endpoint);
433 let filter = rule_config_to_filter(rule_config);
434
435 match endpoint {
437 EndpointPattern::Exact(path) => {
438 builder = builder.add_exact(path, filter);
439 }
440 pattern => {
441 builder = builder.add_pattern(pattern, filter);
442 }
443 }
444 }
445
446 builder.build()
447 }
448}
449
450fn action_config_to_action(config: &ActionConfig) -> AclAction {
452 match config {
453 ActionConfig::Simple(simple) => match simple {
454 SimpleAction::Allow => AclAction::Allow,
455 SimpleAction::Deny | SimpleAction::Block => AclAction::Deny,
456 },
457 ActionConfig::Complex(complex) => match complex {
458 ComplexAction::Allow => AclAction::Allow,
459 ComplexAction::Deny | ComplexAction::Block => AclAction::Deny,
460 ComplexAction::Error { code, message } => AclAction::Error {
461 code: *code,
462 message: message.clone(),
463 },
464 ComplexAction::Reroute {
465 target,
466 preserve_path,
467 } => AclAction::Reroute {
468 target: target.clone(),
469 preserve_path: *preserve_path,
470 },
471 ComplexAction::RateLimit {
472 max_requests,
473 window_secs,
474 } => AclAction::RateLimit {
475 max_requests: *max_requests,
476 window_secs: *window_secs,
477 },
478 ComplexAction::Log { level, message } => AclAction::Log {
479 level: level.clone(),
480 message: message.clone(),
481 },
482 },
483 }
484}
485
486fn rule_config_to_filter(config: RuleConfig) -> AclRuleFilter {
488 let time = config.time.map(|t| {
489 if t.start.is_none() && t.end.is_none() && t.days.is_empty() {
490 TimeWindow::any()
491 } else {
492 TimeWindow {
493 start: t.start.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
494 end: t.end.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
495 days: t.days,
496 }
497 }
498 }).unwrap_or_else(TimeWindow::any);
499
500 let ip = config
501 .ip
502 .map(|s| IpMatcher::parse(&s).unwrap_or(IpMatcher::Any))
503 .unwrap_or(IpMatcher::Any);
504
505 let action = action_config_to_action(&config.action);
506
507 let methods: Vec<http::Method> = config
508 .methods
509 .iter()
510 .filter_map(|m| m.parse::<http::Method>().ok())
511 .collect();
512
513 let mut filter = AclRuleFilter::new()
514 .id(config.id)
515 .role_mask(config.role_mask.to_mask())
516 .methods(methods)
517 .time(time)
518 .ip(ip)
519 .action(action);
520
521 if let Some(desc) = config.description {
522 filter = filter.description(desc);
523 }
524
525 filter
526}
527
528impl AclConfig {
529 pub fn into_generic_table<A: Send + Sync + 'static>(
535 self,
536 registry: &dyn MatcherRegistry<A>,
537 ) -> Result<AclTable<A>, ConfigError> {
538 let default_action = action_config_to_action(&self.settings.default_action);
539
540 let mut rules: Vec<(i32, RuleConfig)> = self
541 .rules
542 .into_iter()
543 .map(|r| (r.priority, r))
544 .collect();
545 rules.sort_by_key(|(p, _)| *p);
546
547 let mut exact_rules: std::collections::HashMap<String, Vec<AclRule<A>>> =
548 std::collections::HashMap::new();
549 let mut pattern_rules: Vec<(EndpointPattern, Vec<AclRule<A>>)> = Vec::new();
550
551 for (_, rule_config) in rules {
552 let Some(ref matcher_config) = rule_config.matcher else {
553 continue;
554 };
555
556 let matcher = registry.resolve(matcher_config)?;
557
558 let methods: Vec<http::Method> = rule_config
559 .methods
560 .iter()
561 .filter_map(|m| m.parse::<http::Method>().ok())
562 .collect();
563
564 let action = action_config_to_action(&rule_config.action);
565
566 let rule = if methods.is_empty() {
567 AclRule::from_matcher(matcher)
568 } else {
569 AclRule::from_matcher_with_methods(matcher, methods, action)
570 };
571
572 let endpoint = EndpointPattern::parse(&rule_config.endpoint);
573 match endpoint {
574 EndpointPattern::Exact(path) => {
575 exact_rules.entry(path).or_default().push(rule);
576 }
577 pattern => {
578 let mut found = false;
579 for (existing, rules) in &mut pattern_rules {
580 let is_match = match (existing, &pattern) {
581 (EndpointPattern::Any, EndpointPattern::Any) => true,
582 (EndpointPattern::Prefix(a), EndpointPattern::Prefix(b)) => a == b,
583 (EndpointPattern::Glob(a), EndpointPattern::Glob(b)) => a == b,
584 (EndpointPattern::Exact(a), EndpointPattern::Exact(b)) => a == b,
585 _ => false,
586 };
587 if is_match {
588 rules.push(rule.clone());
589 found = true;
590 break;
591 }
592 }
593 if !found {
594 pattern_rules.push((pattern, vec![rule]));
595 }
596 }
597 }
598 }
599
600 Ok(AclTable {
601 exact_rules,
602 pattern_rules,
603 default_action,
604 })
605 }
606}
607
608impl AclTable {
609 pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
636 let config = AclConfig::from_toml(toml_str)?;
637 Ok(config.into_table())
638 }
639
640 pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
649 let config = AclConfig::from_file(path)?;
650 Ok(config.into_table())
651 }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::RequestContext;
    use std::net::IpAddr;

    /// Parses an IP literal used by a test; a malformed literal is a test bug.
    fn addr(text: &str) -> IpAddr {
        text.parse().expect("test IP literal must be valid")
    }

    /// Two pattern rules (`*` and `/api/**`) parse and produce pattern entries.
    #[test]
    fn test_parse_simple_config() {
        let source = r#"
[settings]
default_action = "deny"

[[rules]]
role_mask = 1
endpoint = "*"
action = "allow"
description = "Admin access"

[[rules]]
role_mask = 2
endpoint = "/api/**"
action = "allow"
"#;

        let parsed = AclConfig::from_toml(source).expect("simple config should parse");
        assert_eq!(parsed.rules.len(), 2);

        let table = parsed.into_table();
        assert!(!table.pattern_rules().is_empty());
    }

    /// Table-form actions parse, and an `error` action keeps its code and message.
    #[test]
    fn test_parse_complex_actions() {
        let source = r#"
[[rules]]
endpoint = "/error"
action = { type = "error", code = 418, message = "I'm a teapot" }

[[rules]]
endpoint = "/redirect"
action = { type = "reroute", target = "/new-path", preserve_path = true }
"#;

        let parsed = AclConfig::from_toml(source).expect("complex actions should parse");
        assert_eq!(parsed.rules.len(), 2);

        let table = parsed.into_table();
        let ctx = RequestContext::new(u32::MAX, addr("127.0.0.1"), "*");
        let AclAction::Error { code, message } = table.evaluate("/error", &ctx) else {
            panic!("Expected Error action");
        };
        assert_eq!(code, 418);
        assert_eq!(message.as_deref(), Some("I'm a teapot"));
    }

    /// A rule with a full `time` table parses and builds.
    #[test]
    fn test_parse_time_config() {
        let source = r#"
[[rules]]
role_mask = 2
endpoint = "/api/**"
time = { start = 9, end = 17, days = [0, 1, 2, 3, 4] }
action = "allow"
"#;

        let table = AclConfig::from_toml(source)
            .expect("time config should parse")
            .into_table();
        assert!(!table.pattern_rules().is_empty());
    }

    /// A CIDR-restricted allow rule admits addresses inside the range only.
    #[test]
    fn test_parse_ip_config() {
        let source = r#"
[[rules]]
endpoint = "/internal/**"
ip = "192.168.1.0/24"
action = "allow"
"#;

        let table = AclConfig::from_toml(source)
            .expect("ip config should parse")
            .into_table();

        let inside = RequestContext::new(u32::MAX, addr("192.168.1.50"), "*");
        let outside = RequestContext::new(u32::MAX, addr("10.0.0.1"), "*");

        assert_eq!(table.evaluate("/internal/foo", &inside), AclAction::Allow);
        assert_eq!(table.evaluate("/internal/foo", &outside), AclAction::Deny);
    }

    /// Decimal, hex-string and wildcard role masks resolve to the expected bits.
    #[test]
    fn test_role_mask_formats() {
        let source = r#"
[[rules]]
role_mask = 3
endpoint = "/decimal"
action = "allow"

[[rules]]
role_mask = "0xFF"
endpoint = "/hex"
action = "allow"

[[rules]]
role_mask = "*"
endpoint = "/all"
action = "allow"
"#;

        let parsed = AclConfig::from_toml(source).expect("role masks should parse");
        let expected = [3, 0xFF, u32::MAX];
        assert_eq!(parsed.rules.len(), expected.len());
        for (rule, want) in parsed.rules.iter().zip(expected) {
            assert_eq!(rule.role_mask.to_mask(), want);
        }
    }
}