use crate::rule::{AclAction, AclRuleFilter, EndpointPattern, IpMatcher, RuleMatcher, TimeWindow};
use crate::table::{AclRule, AclTable};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
/// Top-level ACL configuration as deserialized from TOML.
///
/// Both sections are optional: a missing `[settings]` table yields
/// [`ConfigSettings::default`], and a missing `[[rules]]` array yields no rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AclConfig {
    /// Global options from the `[settings]` table.
    #[serde(default)]
    pub settings: ConfigSettings,
    /// Entries of the `[[rules]]` array in file order; they are re-sorted by `priority`
    /// when a table is built.
    #[serde(default)]
    pub rules: Vec<RuleConfig>,
}
/// Options from the `[settings]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigSettings {
    /// Action used by the built table when no rule matches; defaults to `deny`.
    #[serde(default = "default_action")]
    pub default_action: ActionConfig,
}
fn default_action() -> ActionConfig {
ActionConfig::Simple(SimpleAction::Deny)
}
impl Default for ConfigSettings {
fn default() -> Self {
Self {
default_action: default_action(),
}
}
}
/// One entry of the `[[rules]]` array.
///
/// Only `action` is required; every other field has a default (see each field).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleConfig {
    /// Roles the rule applies to, as a number or a string such as `"0xff"`, `"0b11"`,
    /// `"*"` or `"all"` (see `RoleMaskConfig::to_mask`). Defaults to all bits set.
    #[serde(default = "default_role_mask")]
    pub role_mask: RoleMaskConfig,
    /// Rule identifier passed to the built filter; defaults to `"*"`.
    #[serde(default = "default_id")]
    pub id: String,
    /// Endpoint pattern, interpreted by `EndpointPattern::parse` (e.g. `"/error"`,
    /// `"/api/**"`); defaults to `"*"`.
    #[serde(default = "default_endpoint")]
    pub endpoint: String,
    /// HTTP method names the rule is restricted to; empty when the key is omitted.
    #[serde(default)]
    pub methods: Vec<String>,
    /// Optional time-of-day / weekday window.
    #[serde(default)]
    pub time: Option<TimeConfig>,
    /// Optional client IP pattern such as `"192.168.1.0/24"`. Validation skips `"*"` and
    /// `"any"` (any ASCII case).
    #[serde(default)]
    pub ip: Option<String>,
    /// Action taken when the rule matches.
    pub action: ActionConfig,
    /// Free-form description attached to the built filter.
    #[serde(default)]
    pub description: Option<String>,
    /// Ordering key: rules are added to the table in ascending priority, ties keeping file
    /// order. Defaults to 100.
    #[serde(default = "default_priority")]
    pub priority: i32,
    /// Custom matcher, read only by `AclConfig::into_generic_table` (which skips rules that
    /// have none); `into_table` ignores it.
    #[serde(default)]
    pub matcher: Option<MatcherConfig>,
}
/// Role mask as written in TOML: an integer or a string form.
///
/// Untagged, so serde tries `Number` first; TOML integers that fit in `u32` become
/// `Number`, and TOML strings become `String`. String syntax is handled by `to_mask`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RoleMaskConfig {
    /// Literal bitmask.
    Number(u32),
    /// Textual form, resolved by `RoleMaskConfig::to_mask`.
    String(String),
}
impl RoleMaskConfig {
    /// Resolves this value to a role bitmask, or explains why it cannot be resolved.
    ///
    /// Accepted string forms (surrounding whitespace is ignored):
    /// - `"*"` or `"all"` (any ASCII case) → `u32::MAX`
    /// - hexadecimal with a `0x` / `0X` prefix, e.g. `"0xFF"`
    /// - binary with a `0b` / `0B` prefix, e.g. `"0b101"`
    /// - plain decimal, e.g. `"3"`
    ///
    /// # Errors
    /// Returns a human-readable message when the string matches none of these forms or the
    /// value does not fit in a `u32`.
    pub fn try_to_mask(&self) -> Result<u32, String> {
        let raw = match self {
            RoleMaskConfig::Number(n) => return Ok(*n),
            RoleMaskConfig::String(s) => s,
        };
        let s = raw.trim();
        if s == "*" || s.eq_ignore_ascii_case("all") {
            return Ok(u32::MAX);
        }
        // Prefixes are matched case-insensitively: before this, "0XFF" fell through to the
        // decimal parser, failed, and silently became "all roles".
        // `get(..2)` is `None` for short strings or a non-boundary index, so slicing at 2 is safe.
        let (digits, radix) = match s.get(..2) {
            Some(prefix) if prefix.eq_ignore_ascii_case("0x") => (&s[2..], 16),
            Some(prefix) if prefix.eq_ignore_ascii_case("0b") => (&s[2..], 2),
            _ => (s, 10),
        };
        u32::from_str_radix(digits, radix)
            .map_err(|e| format!("invalid role mask '{}': {}", raw, e))
    }

    /// Resolves this value to a role bitmask, falling back to `u32::MAX` when the string
    /// form cannot be parsed (see [`RoleMaskConfig::try_to_mask`] for the accepted syntax).
    ///
    /// NOTE: the fallback makes a malformed mask match *every* role. Prefer
    /// [`RoleMaskConfig::try_to_mask`] wherever an error can be reported.
    pub fn to_mask(&self) -> u32 {
        self.try_to_mask().unwrap_or(u32::MAX)
    }
}
/// Serde default for [`RuleConfig::role_mask`]: every bit set, the same value that the
/// `"*"` / `"all"` string forms resolve to.
fn default_role_mask() -> RoleMaskConfig {
    let every_role = u32::MAX;
    RoleMaskConfig::Number(every_role)
}
/// Serde default for [`RuleConfig::id`]: the wildcard string `"*"`.
fn default_id() -> String {
    String::from("*")
}

/// Serde default for [`RuleConfig::endpoint`]: the wildcard pattern `"*"`.
fn default_endpoint() -> String {
    String::from("*")
}

/// Serde default for [`RuleConfig::priority`]; rules are added to the table in ascending
/// priority order.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
/// Time-of-day / weekday restriction for a rule.
///
/// Validation enforces whole hours 0-23 for `start`/`end` and weekday numbers 0-6 for
/// `days`. A table with no fields set is converted to `TimeWindow::any()`.
/// NOTE(review): which weekday `0` denotes and whether `end` is inclusive are decided by
/// `TimeWindow`, not here — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeConfig {
    /// Start hour (0-23); converted to `HH:00:00`.
    #[serde(default)]
    pub start: Option<u32>,
    /// End hour (0-23); converted to `HH:00:00`.
    #[serde(default)]
    pub end: Option<u32>,
    /// Weekday numbers (0-6); empty when the key is omitted.
    #[serde(default)]
    pub days: Vec<u32>,
}
/// A rule action, written either as a bare string (`action = "allow"`) or as an inline
/// table with a `type` key (`action = { type = "error", code = 418 }`).
///
/// Untagged: serde tries the bare-string form first, then the tagged table form.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ActionConfig {
    /// Bare-string form.
    Simple(SimpleAction),
    /// Table form selected by its `type` key.
    Complex(ComplexAction),
}
/// Bare-string actions: `"allow"`, `"deny"` or `"block"` (lower-case).
///
/// `block` is converted to the same runtime action as `deny`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SimpleAction {
    Allow,
    Deny,
    Block,
}
/// Table-form actions, selected by the `type` key.
///
/// Tag values are the lower-cased variant names: `allow`, `deny`, `block`, `error`,
/// `reroute`, `ratelimit` and `log`. Note that the tag is `ratelimit`, not `rate_limit`.
/// Field names inside the variants are not renamed.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ComplexAction {
    Allow,
    /// Converted to the same runtime action as `Block`.
    Deny,
    Block,
    /// Respond with an error status.
    Error {
        /// Status code; defaults to 403.
        #[serde(default = "default_error_code")]
        code: u16,
        /// Optional message carried into the runtime action.
        #[serde(default)]
        message: Option<String>,
    },
    /// Redirect the request to `target`.
    Reroute {
        target: String,
        /// Defaults to `false`; semantics are defined by the runtime `AclAction::Reroute`.
        #[serde(default)]
        preserve_path: bool,
    },
    /// Rate-limit the request; presumably at most `max_requests` per `window_secs` seconds
    /// (confirm against the runtime `AclAction::RateLimit`).
    RateLimit {
        max_requests: u32,
        window_secs: u64,
    },
    /// Log the request.
    Log {
        /// Level name; defaults to `"info"`.
        #[serde(default = "default_log_level")]
        level: String,
        #[serde(default)]
        message: Option<String>,
    },
}
/// Reference to a custom matcher, resolved by a [`MatcherRegistry`].
///
/// Written either as a bare name (`matcher = "name"`) or as a table with a `type` key and an
/// optional `param` string (`matcher = { type = "name", param = "..." }`). How either form
/// is interpreted is entirely up to the registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MatcherConfig {
    /// Bare matcher name.
    Named(String),
    /// Matcher type plus an optional parameter string.
    Parameterized {
        #[serde(rename = "type")]
        matcher_type: String,
        #[serde(default)]
        param: Option<String>,
    },
}
/// Resolves [`MatcherConfig`] entries into concrete matchers for
/// `AclConfig::into_generic_table`.
///
/// `A` is the matcher type parameter, as in `RuleMatcher<A>`. Implementations must be
/// `Send + Sync`.
pub trait MatcherRegistry<A>: Send + Sync {
    /// Returns the matcher described by `config`.
    ///
    /// # Errors
    /// Implementations should return a [`ConfigError`] when `config` cannot be resolved;
    /// `into_generic_table` propagates it unchanged and stops building.
    fn resolve(&self, config: &MatcherConfig) -> Result<Arc<dyn RuleMatcher<A>>, ConfigError>;
}
/// Serde default for the `code` of an `error` action: 403 (Forbidden).
fn default_error_code() -> u16 {
    const FORBIDDEN: u16 = 403;
    FORBIDDEN
}

/// Serde default for the `level` of a `log` action.
fn default_log_level() -> String {
    String::from("info")
}
/// Errors produced while loading or validating an ACL configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The text is not valid TOML or does not match the configuration schema.
    #[error("Failed to parse TOML: {0}")]
    TomlParse(#[from] toml::de::Error),
    /// The configuration file could not be read.
    #[error("Failed to read config file: {0}")]
    FileRead(#[from] std::io::Error),
    /// A semantic check failed (e.g. an hour or weekday out of range); the message names
    /// the offending rule.
    #[error("Invalid configuration: {0}")]
    Invalid(String),
    /// An IP pattern was rejected by `IpMatcher::parse`: (pattern, parser message).
    #[error("Invalid IP pattern '{0}': {1}")]
    InvalidIp(String, String),
    /// An action's parameters are unusable.
    #[error("Invalid action configuration: {0}")]
    InvalidAction(String),
}
impl AclConfig {
pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
let config: AclConfig = toml::from_str(toml_str)?;
config.validate()?;
Ok(config)
}
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
let contents = std::fs::read_to_string(path)?;
Self::from_toml(&contents)
}
fn validate(&self) -> Result<(), ConfigError> {
for (i, rule) in self.rules.iter().enumerate() {
if let Some(ref ip) = rule.ip {
if ip != "*" && !ip.eq_ignore_ascii_case("any") {
IpMatcher::parse(ip)
.map_err(|e| ConfigError::InvalidIp(ip.clone(), e))?;
}
}
if let Some(ref time) = rule.time {
if let Some(start) = time.start {
if start > 23 {
return Err(ConfigError::Invalid(format!(
"Rule {}: start hour {} is invalid (must be 0-23)",
i, start
)));
}
}
if let Some(end) = time.end {
if end > 23 {
return Err(ConfigError::Invalid(format!(
"Rule {}: end hour {} is invalid (must be 0-23)",
i, end
)));
}
}
for &day in &time.days {
if day > 6 {
return Err(ConfigError::Invalid(format!(
"Rule {}: day {} is invalid (must be 0-6)",
i, day
)));
}
}
}
}
Ok(())
}
pub fn into_table(self) -> AclTable {
let default_action = action_config_to_action(&self.settings.default_action);
let mut rules: Vec<(i32, RuleConfig)> = self
.rules
.into_iter()
.map(|r| (r.priority, r))
.collect();
rules.sort_by_key(|(p, _)| *p);
let mut builder = AclTable::builder().default_action(default_action);
for (_, rule_config) in rules {
let endpoint = EndpointPattern::parse(&rule_config.endpoint);
let filter = rule_config_to_filter(rule_config);
match endpoint {
EndpointPattern::Exact(path) => {
builder = builder.add_exact(path, filter);
}
pattern => {
builder = builder.add_pattern(pattern, filter);
}
}
}
builder.build()
}
}
fn action_config_to_action(config: &ActionConfig) -> AclAction {
match config {
ActionConfig::Simple(simple) => match simple {
SimpleAction::Allow => AclAction::Allow,
SimpleAction::Deny | SimpleAction::Block => AclAction::Deny,
},
ActionConfig::Complex(complex) => match complex {
ComplexAction::Allow => AclAction::Allow,
ComplexAction::Deny | ComplexAction::Block => AclAction::Deny,
ComplexAction::Error { code, message } => AclAction::Error {
code: *code,
message: message.clone(),
},
ComplexAction::Reroute {
target,
preserve_path,
} => AclAction::Reroute {
target: target.clone(),
preserve_path: *preserve_path,
},
ComplexAction::RateLimit {
max_requests,
window_secs,
} => AclAction::RateLimit {
max_requests: *max_requests,
window_secs: *window_secs,
},
ComplexAction::Log { level, message } => AclAction::Log {
level: level.clone(),
message: message.clone(),
},
},
}
}
/// Converts one rule's configuration into an [`AclRuleFilter`].
///
/// Each field is normalized as follows:
/// - `time`: absent, or present with no fields set, becomes `TimeWindow::any()`. Hours become
///   `HH:00:00`; an hour `chrono` rejects becomes `None`.
/// - `ip`: absent becomes `IpMatcher::Any`. NOTE(review): a pattern that fails to parse also
///   becomes `Any` (fail-open). `AclConfig::from_toml` rejects such patterns up front, but
///   configs built in code are not validated.
/// - `methods`: each entry is trimmed and upper-cased before parsing. `http::Method`
///   parsing is case-sensitive, so `"get"` would otherwise become an extension method that
///   never matches `GET` requests. Entries that still fail to parse are dropped.
/// - `role_mask`: resolved with `RoleMaskConfig::to_mask`.
fn rule_config_to_filter(config: RuleConfig) -> AclRuleFilter {
    let time = config
        .time
        .map(|t| {
            if t.start.is_none() && t.end.is_none() && t.days.is_empty() {
                TimeWindow::any()
            } else {
                TimeWindow {
                    start: t.start.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
                    end: t.end.and_then(|h| chrono::NaiveTime::from_hms_opt(h, 0, 0)),
                    days: t.days,
                }
            }
        })
        .unwrap_or_else(TimeWindow::any);
    let ip = config
        .ip
        .and_then(|s| IpMatcher::parse(&s).ok())
        .unwrap_or(IpMatcher::Any);
    let action = action_config_to_action(&config.action);
    let methods: Vec<http::Method> = config
        .methods
        .iter()
        .filter_map(|m| m.trim().to_ascii_uppercase().parse::<http::Method>().ok())
        .collect();
    let mut filter = AclRuleFilter::new()
        .id(config.id)
        .role_mask(config.role_mask.to_mask())
        .methods(methods)
        .time(time)
        .ip(ip)
        .action(action);
    if let Some(desc) = config.description {
        filter = filter.description(desc);
    }
    filter
}
impl AclConfig {
    /// Builds an [`AclTable<A>`] from rules that name a custom matcher, resolved via `registry`.
    ///
    /// Behavior:
    /// - Rules are processed in ascending `priority` order. The sort is stable, so equal
    ///   priorities keep file order.
    /// - Rules without a `matcher` are skipped.
    /// - Only `matcher`, `methods`, `action` and `endpoint` are read. `role_mask`, `time`,
    ///   `ip`, `id` and `description` are ignored here.
    /// - Method names are trimmed and upper-cased before parsing, as `http::Method` is
    ///   case-sensitive; unparsable entries are dropped.
    /// - Exact endpoint paths are grouped per path. All other patterns are grouped with
    ///   earlier identical patterns, in first-seen order.
    ///
    /// # Errors
    /// Propagates the first error returned by `registry.resolve`.
    pub fn into_generic_table<A: Send + Sync + 'static>(
        self,
        registry: &dyn MatcherRegistry<A>,
    ) -> Result<AclTable<A>, ConfigError> {
        let default_action = action_config_to_action(&self.settings.default_action);
        let mut rules = self.rules;
        rules.sort_by_key(|rule| rule.priority);

        let mut exact_rules: std::collections::HashMap<String, Vec<AclRule<A>>> =
            std::collections::HashMap::new();
        let mut pattern_rules: Vec<(EndpointPattern, Vec<AclRule<A>>)> = Vec::new();

        for rule_config in rules {
            let Some(matcher_config) = rule_config.matcher.as_ref() else {
                continue;
            };
            let matcher = registry.resolve(matcher_config)?;
            let methods: Vec<http::Method> = rule_config
                .methods
                .iter()
                .filter_map(|m| m.trim().to_ascii_uppercase().parse::<http::Method>().ok())
                .collect();
            // NOTE(review): the configured `action` is only attached when methods are listed;
            // `from_matcher` takes no action argument — confirm this is intended.
            let rule = if methods.is_empty() {
                AclRule::from_matcher(matcher)
            } else {
                let action = action_config_to_action(&rule_config.action);
                AclRule::from_matcher_with_methods(matcher, methods, action)
            };
            match EndpointPattern::parse(&rule_config.endpoint) {
                EndpointPattern::Exact(path) => {
                    exact_rules.entry(path).or_default().push(rule);
                }
                pattern => {
                    let slot = pattern_rules
                        .iter()
                        .position(|(existing, _)| Self::same_pattern(existing, &pattern));
                    match slot {
                        Some(idx) => pattern_rules[idx].1.push(rule),
                        None => pattern_rules.push((pattern, vec![rule])),
                    }
                }
            }
        }

        Ok(AclTable {
            exact_rules,
            pattern_rules,
            default_action,
        })
    }

    /// Returns `true` when two endpoint patterns are the same pattern and so share a bucket.
    /// Variant pairs not listed here never compare equal.
    fn same_pattern(a: &EndpointPattern, b: &EndpointPattern) -> bool {
        match (a, b) {
            (EndpointPattern::Any, EndpointPattern::Any) => true,
            (EndpointPattern::Prefix(x), EndpointPattern::Prefix(y)) => x == y,
            (EndpointPattern::Glob(x), EndpointPattern::Glob(y)) => x == y,
            (EndpointPattern::Exact(x), EndpointPattern::Exact(y)) => x == y,
            _ => false,
        }
    }
}
impl AclTable {
    /// Parses and validates TOML text with [`AclConfig::from_toml`], then compiles it with
    /// [`AclConfig::into_table`].
    ///
    /// # Errors
    /// Any error reported by `AclConfig::from_toml`.
    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
        AclConfig::from_toml(toml_str).map(AclConfig::into_table)
    }

    /// Loads a TOML file with [`AclConfig::from_file`], then compiles it with
    /// [`AclConfig::into_table`].
    ///
    /// # Errors
    /// Any error reported by `AclConfig::from_file`.
    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
        AclConfig::from_file(path).map(AclConfig::into_table)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;
    use crate::rule::RequestContext;

    #[test]
    fn test_parse_simple_config() {
        let toml = r#"
[settings]
default_action = "deny"
[[rules]]
role_mask = 1
endpoint = "*"
action = "allow"
description = "Admin access"
[[rules]]
role_mask = 2
endpoint = "/api/**"
action = "allow"
"#;
        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules.len(), 2);
        let table = config.into_table();
        assert!(!table.pattern_rules().is_empty());
    }

    #[test]
    fn test_parse_complex_actions() {
        let toml = r#"
[[rules]]
endpoint = "/error"
action = { type = "error", code = 418, message = "I'm a teapot" }
[[rules]]
endpoint = "/redirect"
action = { type = "reroute", target = "/new-path", preserve_path = true }
"#;
        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules.len(), 2);
        let table = config.into_table();
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        let ctx = RequestContext::new(u32::MAX, ip, "*");
        let action = table.evaluate("/error", &ctx);
        match action {
            AclAction::Error { code, message } => {
                assert_eq!(code, 418);
                assert_eq!(message.as_deref(), Some("I'm a teapot"));
            }
            _ => panic!("Expected Error action"),
        }
    }

    #[test]
    fn test_parse_time_config() {
        let toml = r#"
[[rules]]
role_mask = 2
endpoint = "/api/**"
time = { start = 9, end = 17, days = [0, 1, 2, 3, 4] }
action = "allow"
"#;
        let config = AclConfig::from_toml(toml).unwrap();
        let table = config.into_table();
        assert!(!table.pattern_rules().is_empty());
    }

    #[test]
    fn test_parse_ip_config() {
        let toml = r#"
[[rules]]
endpoint = "/internal/**"
ip = "192.168.1.0/24"
action = "allow"
"#;
        let config = AclConfig::from_toml(toml).unwrap();
        let table = config.into_table();
        let internal_ip: IpAddr = "192.168.1.50".parse().unwrap();
        let external_ip: IpAddr = "10.0.0.1".parse().unwrap();
        let ctx_internal = RequestContext::new(u32::MAX, internal_ip, "*");
        let ctx_external = RequestContext::new(u32::MAX, external_ip, "*");
        assert_eq!(table.evaluate("/internal/foo", &ctx_internal), AclAction::Allow);
        assert_eq!(table.evaluate("/internal/foo", &ctx_external), AclAction::Deny);
    }

    #[test]
    fn test_role_mask_formats() {
        let toml = r#"
[[rules]]
role_mask = 3
endpoint = "/decimal"
action = "allow"
[[rules]]
role_mask = "0xFF"
endpoint = "/hex"
action = "allow"
[[rules]]
role_mask = "*"
endpoint = "/all"
action = "allow"
"#;
        let config = AclConfig::from_toml(toml).unwrap();
        assert_eq!(config.rules[0].role_mask.to_mask(), 3);
        assert_eq!(config.rules[1].role_mask.to_mask(), 0xFF);
        assert_eq!(config.rules[2].role_mask.to_mask(), u32::MAX);
    }

    /// An empty document is valid: no rules, and the default action is a bare `deny`.
    #[test]
    fn test_empty_config_uses_defaults() {
        let config = AclConfig::from_toml("").unwrap();
        assert!(config.rules.is_empty());
        assert!(matches!(
            config.settings.default_action,
            ActionConfig::Simple(SimpleAction::Deny)
        ));
    }

    /// A rule that only sets `action` picks up every per-field serde default.
    #[test]
    fn test_rule_field_defaults() {
        let toml = r#"
[[rules]]
action = "allow"
"#;
        let config = AclConfig::from_toml(toml).unwrap();
        let rule = &config.rules[0];
        assert_eq!(rule.id, "*");
        assert_eq!(rule.endpoint, "*");
        assert_eq!(rule.priority, 100);
        assert_eq!(rule.role_mask.to_mask(), u32::MAX);
        assert!(rule.methods.is_empty());
        assert!(rule.time.is_none());
        assert!(rule.ip.is_none());
        assert!(rule.matcher.is_none());
    }

    /// Hours above 23 are rejected during validation.
    #[test]
    fn test_rejects_out_of_range_hour() {
        let toml = r#"
[[rules]]
endpoint = "/api/**"
time = { start = 24 }
action = "allow"
"#;
        assert!(matches!(
            AclConfig::from_toml(toml),
            Err(ConfigError::Invalid(_))
        ));
    }

    /// Weekday numbers above 6 are rejected during validation.
    #[test]
    fn test_rejects_out_of_range_day() {
        let toml = r#"
[[rules]]
endpoint = "/api/**"
time = { days = [7] }
action = "allow"
"#;
        assert!(matches!(
            AclConfig::from_toml(toml),
            Err(ConfigError::Invalid(_))
        ));
    }
}