1use crate::rule::{AclAction, AclRuleFilter, EndpointPattern, IpMatcher, TimeWindow};
52use crate::table::AclTable;
53use serde::{Deserialize, Serialize};
54use std::path::Path;
55
/// Top-level ACL configuration as deserialized from TOML.
///
/// Both sections are optional: a missing `[settings]` table falls back to
/// `ConfigSettings::default()` and a missing `[[rules]]` array yields no rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AclConfig {
    /// Global settings (currently only the fallback action).
    #[serde(default)]
    pub settings: ConfigSettings,
    /// Rules in file order; `AclConfig::into_table` re-sorts them by `priority`.
    #[serde(default)]
    pub rules: Vec<RuleConfig>,
}
66
/// The `[settings]` table of an ACL configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigSettings {
    /// Action handed to `AclTable::builder().default_action(..)`; presumably applied when no
    /// rule matches a request. Defaults to `"deny"` when omitted.
    #[serde(default = "default_action")]
    pub default_action: ActionConfig,
}
74
75fn default_action() -> ActionConfig {
76 ActionConfig::Simple(SimpleAction::Deny)
77}
78
79impl Default for ConfigSettings {
80 fn default() -> Self {
81 Self {
82 default_action: default_action(),
83 }
84 }
85}
86
/// One `[[rules]]` entry. Only `action` is required; every other field has a default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleConfig {
    /// Role bitmask, resolved via `RoleMaskConfig::to_mask`. Defaults to every bit set.
    #[serde(default = "default_role_mask")]
    pub role_mask: RoleMaskConfig,

    /// Identifier passed to `AclRuleFilter::id`. Defaults to `"*"`.
    #[serde(default = "default_id")]
    pub id: String,

    /// Endpoint pattern, parsed with `EndpointPattern::parse`; exact paths and patterns
    /// (e.g. `/api/**`) are registered with the table separately. Defaults to `"*"`.
    #[serde(default = "default_endpoint")]
    pub endpoint: String,

    /// Method names for the rule.
    /// NOTE(review): this field is not read anywhere in this module, so it currently has no
    /// effect on the generated filter — confirm whether method matching is implemented.
    #[serde(default)]
    pub methods: Vec<String>,

    /// Optional time window; `None` means the rule applies at any time.
    #[serde(default)]
    pub time: Option<TimeConfig>,

    /// Optional client IP pattern, parsed by `IpMatcher::parse`. `"*"` and `"any"`
    /// (any case) are accepted without being parsed.
    #[serde(default)]
    pub ip: Option<String>,

    /// What to do when the rule matches. Required.
    pub action: ActionConfig,

    /// Free-form description attached to the filter when present.
    #[serde(default)]
    pub description: Option<String>,

    /// Ordering key: rules are added to the table in ascending priority (ties keep file
    /// order). Defaults to 100.
    #[serde(default = "default_priority")]
    pub priority: i32,
}
130
/// A role mask written either as a TOML integer or as a string.
///
/// Untagged: serde tries `Number` first, then `String`, so variant order matters.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RoleMaskConfig {
    /// Literal bitmask, e.g. `role_mask = 3`.
    Number(u32),
    /// Textual mask: `"*"`, `"all"`, `"0x.."` hex, `"0b.."` binary or decimal.
    String(String),
}
140
141impl RoleMaskConfig {
142 pub fn to_mask(&self) -> u32 {
144 match self {
145 RoleMaskConfig::Number(n) => *n,
146 RoleMaskConfig::String(s) => {
147 let s = s.trim();
148 if s == "*" || s.eq_ignore_ascii_case("all") {
149 u32::MAX
150 } else if let Some(hex) = s.strip_prefix("0x") {
151 u32::from_str_radix(hex, 16).unwrap_or(u32::MAX)
152 } else if let Some(bin) = s.strip_prefix("0b") {
153 u32::from_str_radix(bin, 2).unwrap_or(u32::MAX)
154 } else {
155 s.parse().unwrap_or(u32::MAX)
156 }
157 }
158 }
159 }
160}
161
162fn default_role_mask() -> RoleMaskConfig {
163 RoleMaskConfig::Number(u32::MAX)
164}
165
/// Serde default for a rule's `id`: the wildcard `"*"`.
fn default_id() -> String {
    String::from("*")
}
169
/// Serde default for a rule's `endpoint`: the wildcard `"*"`.
fn default_endpoint() -> String {
    "*".to_owned()
}
173
/// Serde default for a rule's `priority`.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
177
/// Time restriction for a rule, in whole hours and weekday numbers.
///
/// An inline table with no fields set is treated as "any time" when converted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeConfig {
    /// Start hour (0-23, checked during validation).
    #[serde(default)]
    pub start: Option<u32>,
    /// End hour (0-23, checked during validation).
    #[serde(default)]
    pub end: Option<u32>,
    /// Weekday numbers (0-6, checked during validation).
    /// NOTE(review): which day 0 denotes is defined by `TimeWindow`, not visible here.
    #[serde(default)]
    pub days: Vec<u32>,
}
191
/// A rule action, written either as a bare string (`action = "allow"`) or as an inline table
/// carrying a `type` tag (`action = { type = "error", code = 418 }`).
///
/// Untagged: serde tries `Simple` before `Complex`, so variant order matters.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ActionConfig {
    /// Bare-string form.
    Simple(SimpleAction),
    /// Tagged-table form with optional parameters.
    Complex(ComplexAction),
}
201
/// Actions expressible as a bare lowercase string: `"allow"`, `"deny"`, `"block"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SimpleAction {
    /// Converted to `AclAction::Allow`.
    Allow,
    /// Converted to `AclAction::Deny`.
    Deny,
    /// Also converted to `AclAction::Deny`; kept as a separate spelling.
    Block,
}
213
/// Actions written as an inline table whose `type` key names the variant.
///
/// `rename_all = "lowercase"` applies to variant names only, so the tags are `allow`, `deny`,
/// `block`, `error`, `reroute`, `ratelimit` (no underscore) and `log`. Field names stay as
/// written (e.g. `preserve_path`, `max_requests`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ComplexAction {
    Allow,
    Deny,
    /// Converted to the same `AclAction::Deny` as `Deny`.
    Block,
    /// Respond with an error status.
    Error {
        /// Status code; defaults to 403.
        #[serde(default = "default_error_code")]
        code: u16,
        /// Optional response message.
        #[serde(default)]
        message: Option<String>,
    },
    /// Redirect the request elsewhere.
    Reroute {
        /// Destination path or URL (required).
        target: String,
        /// Passed through to `AclAction::Reroute`; defaults to `false`.
        #[serde(default)]
        preserve_path: bool,
    },
    /// Rate limiting; both fields are required.
    RateLimit {
        max_requests: u32,
        window_secs: u64,
    },
    /// Log the request.
    Log {
        /// Log level name; defaults to `"info"`.
        #[serde(default = "default_log_level")]
        level: String,
        /// Optional log message.
        #[serde(default)]
        message: Option<String>,
    },
}
258
/// Serde default for `ComplexAction::Error::code`: HTTP 403 Forbidden.
fn default_error_code() -> u16 {
    const FORBIDDEN: u16 = 403;
    FORBIDDEN
}
262
/// Serde default for `ComplexAction::Log::level`.
fn default_log_level() -> String {
    String::from("info")
}
266
/// Errors produced while loading an ACL configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The text is not valid TOML or does not match the configuration schema
    /// (e.g. a rule without `action`).
    #[error("Failed to parse TOML: {0}")]
    TomlParse(#[from] toml::de::Error),

    /// The configuration file could not be read.
    #[error("Failed to read config file: {0}")]
    FileRead(#[from] std::io::Error),

    /// A semantic check failed; the message names the offending rule by 0-based index.
    #[error("Invalid configuration: {0}")]
    Invalid(String),

    /// An IP pattern was rejected by `IpMatcher::parse`: (pattern, parser message).
    #[error("Invalid IP pattern '{0}': {1}")]
    InvalidIp(String, String),

    /// Invalid action configuration.
    /// NOTE(review): not constructed anywhere in this module.
    #[error("Invalid action configuration: {0}")]
    InvalidAction(String),
}
290
291impl AclConfig {
292 pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
311 let config: AclConfig = toml::from_str(toml_str)?;
312 config.validate()?;
313 Ok(config)
314 }
315
316 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
325 let contents = std::fs::read_to_string(path)?;
326 Self::from_toml(&contents)
327 }
328
329 fn validate(&self) -> Result<(), ConfigError> {
331 for (i, rule) in self.rules.iter().enumerate() {
332 if let Some(ref ip) = rule.ip {
334 if ip != "*" && !ip.eq_ignore_ascii_case("any") {
335 IpMatcher::parse(ip)
336 .map_err(|e| ConfigError::InvalidIp(ip.clone(), e))?;
337 }
338 }
339
340 if let Some(ref time) = rule.time {
342 if let Some(start) = time.start {
343 if start > 23 {
344 return Err(ConfigError::Invalid(format!(
345 "Rule {}: start hour {} is invalid (must be 0-23)",
346 i, start
347 )));
348 }
349 }
350 if let Some(end) = time.end {
351 if end > 23 {
352 return Err(ConfigError::Invalid(format!(
353 "Rule {}: end hour {} is invalid (must be 0-23)",
354 i, end
355 )));
356 }
357 }
358 for &day in &time.days {
359 if day > 6 {
360 return Err(ConfigError::Invalid(format!(
361 "Rule {}: day {} is invalid (must be 0-6)",
362 i, day
363 )));
364 }
365 }
366 }
367 }
368 Ok(())
369 }
370
371 pub fn into_table(self) -> AclTable {
388 let default_action = action_config_to_action(&self.settings.default_action);
389
390 let mut rules: Vec<(i32, RuleConfig)> = self
392 .rules
393 .into_iter()
394 .map(|r| (r.priority, r))
395 .collect();
396 rules.sort_by_key(|(p, _)| *p);
397
398 let mut builder = AclTable::builder().default_action(default_action);
400
401 for (_, rule_config) in rules {
402 let endpoint = EndpointPattern::parse(&rule_config.endpoint);
403 let filter = rule_config_to_filter(rule_config);
404
405 match endpoint {
407 EndpointPattern::Exact(path) => {
408 builder = builder.add_exact(path, filter);
409 }
410 pattern => {
411 builder = builder.add_pattern(pattern, filter);
412 }
413 }
414 }
415
416 builder.build()
417 }
418}
419
420fn action_config_to_action(config: &ActionConfig) -> AclAction {
422 match config {
423 ActionConfig::Simple(simple) => match simple {
424 SimpleAction::Allow => AclAction::Allow,
425 SimpleAction::Deny | SimpleAction::Block => AclAction::Deny,
426 },
427 ActionConfig::Complex(complex) => match complex {
428 ComplexAction::Allow => AclAction::Allow,
429 ComplexAction::Deny | ComplexAction::Block => AclAction::Deny,
430 ComplexAction::Error { code, message } => AclAction::Error {
431 code: *code,
432 message: message.clone(),
433 },
434 ComplexAction::Reroute {
435 target,
436 preserve_path,
437 } => AclAction::Reroute {
438 target: target.clone(),
439 preserve_path: *preserve_path,
440 },
441 ComplexAction::RateLimit {
442 max_requests,
443 window_secs,
444 } => AclAction::RateLimit {
445 max_requests: *max_requests,
446 window_secs: *window_secs,
447 },
448 ComplexAction::Log { level, message } => AclAction::Log {
449 level: level.clone(),
450 message: message.clone(),
451 },
452 },
453 }
454}
455
456fn rule_config_to_filter(config: RuleConfig) -> AclRuleFilter {
458 let time = config.time.map(|t| {
459 if t.start.is_none() && t.end.is_none() && t.days.is_empty() {
460 TimeWindow::any()
461 } else {
462 TimeWindow {
463 start: t.start.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
464 end: t.end.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
465 days: t.days,
466 }
467 }
468 }).unwrap_or_else(TimeWindow::any);
469
470 let ip = config
471 .ip
472 .map(|s| IpMatcher::parse(&s).unwrap_or(IpMatcher::Any))
473 .unwrap_or(IpMatcher::Any);
474
475 let action = action_config_to_action(&config.action);
476
477 let mut filter = AclRuleFilter::new()
478 .id(config.id)
479 .role_mask(config.role_mask.to_mask())
480 .time(time)
481 .ip(ip)
482 .action(action);
483
484 if let Some(desc) = config.description {
485 filter = filter.description(desc);
486 }
487
488 filter
489}
490
491impl AclTable {
492 pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
519 let config = AclConfig::from_toml(toml_str)?;
520 Ok(config.into_table())
521 }
522
523 pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
532 let config = AclConfig::from_file(path)?;
533 Ok(config.into_table())
534 }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;
    use crate::rule::RequestContext;

    #[test]
    fn test_parse_simple_config() {
        let toml = r#"
[settings]
default_action = "deny"

[[rules]]
role_mask = 1
endpoint = "*"
action = "allow"
description = "Admin access"

[[rules]]
role_mask = 2
endpoint = "/api/**"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules.len(), 2);

        let table = config.into_table();
        assert!(!table.pattern_rules().is_empty());
    }

    #[test]
    fn test_parse_complex_actions() {
        let toml = r#"
[[rules]]
endpoint = "/error"
action = { type = "error", code = 418, message = "I'm a teapot" }

[[rules]]
endpoint = "/redirect"
action = { type = "reroute", target = "/new-path", preserve_path = true }
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules.len(), 2);

        let table = config.into_table();
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        let ctx = RequestContext::new(u32::MAX, ip, "*");
        let action = table.evaluate("/error", &ctx);
        match action {
            AclAction::Error { code, message } => {
                assert_eq!(code, 418);
                assert_eq!(message.as_deref(), Some("I'm a teapot"));
            }
            _ => panic!("Expected Error action"),
        }
    }

    #[test]
    fn test_parse_time_config() {
        let toml = r#"
[[rules]]
role_mask = 2
endpoint = "/api/**"
time = { start = 9, end = 17, days = [0, 1, 2, 3, 4] }
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        let table = config.into_table();

        assert!(!table.pattern_rules().is_empty());
    }

    #[test]
    fn test_parse_ip_config() {
        let toml = r#"
[[rules]]
endpoint = "/internal/**"
ip = "192.168.1.0/24"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        let table = config.into_table();

        let (_, filters) = &table.pattern_rules()[0];
        match &filters[0].ip {
            IpMatcher::Network(_) => {}
            _ => panic!("Expected Network IP matcher"),
        }
    }

    #[test]
    fn test_role_mask_formats() {
        let toml = r#"
[[rules]]
role_mask = 3
endpoint = "/decimal"
action = "allow"

[[rules]]
role_mask = "0xFF"
endpoint = "/hex"
action = "allow"

[[rules]]
role_mask = "*"
endpoint = "/all"
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules[0].role_mask.to_mask(), 3);
        assert_eq!(config.rules[1].role_mask.to_mask(), 0xFF);
        assert_eq!(config.rules[2].role_mask.to_mask(), u32::MAX);
    }

    // Covers the remaining string spellings accepted by `RoleMaskConfig::to_mask`:
    // case-insensitive `all`, `0b` binary, and surrounding whitespace.
    #[test]
    fn test_role_mask_alternate_spellings() {
        let toml = r#"
[[rules]]
role_mask = "ALL"
action = "allow"

[[rules]]
role_mask = "0b101"
action = "allow"

[[rules]]
role_mask = " 7 "
action = "allow"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules[0].role_mask.to_mask(), u32::MAX);
        assert_eq!(config.rules[1].role_mask.to_mask(), 0b101);
        assert_eq!(config.rules[2].role_mask.to_mask(), 7);
    }

    // Every field except `action` has a serde default; pin them.
    #[test]
    fn test_rule_defaults() {
        let toml = r#"
[[rules]]
action = "deny"
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        assert!(matches!(
            config.settings.default_action,
            ActionConfig::Simple(SimpleAction::Deny)
        ));

        let rule = &config.rules[0];
        assert_eq!(rule.role_mask.to_mask(), u32::MAX);
        assert_eq!(rule.id, "*");
        assert_eq!(rule.endpoint, "*");
        assert!(rule.methods.is_empty());
        assert!(rule.time.is_none());
        assert!(rule.ip.is_none());
        assert!(rule.description.is_none());
        assert_eq!(rule.priority, 100);
    }

    #[test]
    fn test_complex_action_defaults() {
        let toml = r#"
[[rules]]
endpoint = "/error"
action = { type = "error" }

[[rules]]
endpoint = "/log"
action = { type = "log" }
"#;

        let config = AclConfig::from_toml(toml).unwrap();
        match &config.rules[0].action {
            ActionConfig::Complex(ComplexAction::Error { code, message }) => {
                assert_eq!(*code, 403);
                assert!(message.is_none());
            }
            other => panic!("Expected Error action, got {:?}", other),
        }
        match &config.rules[1].action {
            ActionConfig::Complex(ComplexAction::Log { level, message }) => {
                assert_eq!(level, "info");
                assert!(message.is_none());
            }
            other => panic!("Expected Log action, got {:?}", other),
        }
    }

    #[test]
    fn test_rejects_out_of_range_time() {
        for time in &["{ start = 24 }", "{ end = 24 }", "{ days = [7] }"] {
            let toml = format!("[[rules]]\ntime = {}\naction = \"allow\"\n", time);
            assert!(
                matches!(AclConfig::from_toml(&toml), Err(ConfigError::Invalid(_))),
                "expected a validation error for time = {}",
                time
            );
        }
    }

    #[test]
    fn test_missing_action_is_parse_error() {
        let toml = r#"
[[rules]]
endpoint = "/no-action"
"#;

        assert!(matches!(
            AclConfig::from_toml(toml),
            Err(ConfigError::TomlParse(_))
        ));
    }
}