use chrono::{Datelike, NaiveTime, Utc};
use http::Method;
use ipnetwork::IpNetwork;
use std::collections::HashMap;
use std::net::IpAddr;
/// Borrowed description of a caller, checked by `AclRuleFilter::matches`.
#[derive(Clone, Debug)]
pub struct RequestContext<'a> {
    /// Role bitmask of the caller; a rule matches when it shares at least one bit.
    pub roles: u32,
    /// Client IP address.
    pub ip: IpAddr,
    /// Caller identifier, compared against a rule's `id` (unless the rule uses `"*"`).
    pub id: &'a str,
}

impl<'a> RequestContext<'a> {
    /// Bundles the caller's role mask, IP address and id into a context.
    pub fn new(roles: u32, ip: IpAddr, id: &'a str) -> Self {
        RequestContext { id, ip, roles }
    }
}
/// Identity data checked by the `RuleMatcher<BitmaskAuth>` implementation of
/// `AclRuleFilter`: a role bitmask plus the caller's id.
#[derive(Clone, Debug)]
pub struct BitmaskAuth {
    /// Role bitmask; a rule applies when its `role_mask` shares at least one bit with this.
    pub roles: u32,
    /// Caller id; compared against a rule's `id` unless the rule uses `"*"`.
    pub id: String,
}
/// Per-request data inspected by a `RuleMatcher`.
#[derive(Clone, Debug)]
pub struct RequestMeta {
    /// HTTP method of the request.
    pub method: Method,
    /// Request path.
    pub path: String,
    /// Named path parameters (presumably captured from an `EndpointPattern`; confirm with the caller).
    pub path_params: HashMap<String, String>,
    /// Client IP address.
    pub ip: IpAddr,
}
/// A rule that can decide whether it applies to a request, and carries an action.
///
/// `A` is the identity/authentication type the rule inspects (e.g. `BitmaskAuth`).
pub trait RuleMatcher<A>: std::fmt::Debug + Send + Sync {
    /// Returns `true` when this rule applies to `auth` making the request described by `meta`.
    fn matches(&self, auth: &A, meta: &RequestMeta) -> bool;

    /// The action associated with this rule.
    fn action(&self) -> &AclAction;

    /// Optional human-readable description; the default implementation has none.
    fn description(&self) -> Option<&str> {
        None
    }
}
/// The effect associated with an ACL rule. Defaults to [`AclAction::Allow`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum AclAction {
    /// Allow-class action.
    #[default]
    Allow,
    /// Deny-class action.
    Deny,
    /// Deny-class action carrying a status `code` and an optional `message`.
    Error {
        code: u16,
        message: Option<String>,
    },
    /// Reroute to `target`; `preserve_path` is a flag for the consumer
    /// (presumably "append the original path" — confirm with the router).
    Reroute {
        target: String,
        preserve_path: bool,
    },
    /// Rate-limit parameters; neither allow- nor deny-class.
    RateLimit {
        max_requests: u32,
        window_secs: u64,
    },
    /// Logging action; counted as allow-class by [`AclAction::is_allow`].
    Log {
        level: String,
        message: Option<String>,
    },
}

impl AclAction {
    /// Shorthand for [`AclAction::Deny`].
    pub fn deny() -> Self {
        AclAction::Deny
    }

    /// Shorthand for [`AclAction::Allow`].
    pub fn allow() -> Self {
        AclAction::Allow
    }

    /// Builds an [`AclAction::Error`]; `message` accepts `None`, `Some(String)` or a `String`.
    pub fn error(code: u16, message: impl Into<Option<String>>) -> Self {
        let message = message.into();
        AclAction::Error { code, message }
    }

    /// Builds an [`AclAction::Reroute`] with `preserve_path == false`.
    pub fn reroute(target: impl Into<String>) -> Self {
        Self::reroute_to(target.into(), false)
    }

    /// Builds an [`AclAction::Reroute`] with `preserve_path == true`.
    pub fn reroute_with_preserve(target: impl Into<String>) -> Self {
        Self::reroute_to(target.into(), true)
    }

    /// Shared constructor for the two reroute variants.
    fn reroute_to(target: String, preserve_path: bool) -> Self {
        AclAction::Reroute {
            target,
            preserve_path,
        }
    }

    /// `true` for `Allow` and `Log`.
    pub fn is_allow(&self) -> bool {
        matches!(self, AclAction::Log { .. } | AclAction::Allow)
    }

    /// `true` for `Deny` and `Error`.
    pub fn is_deny(&self) -> bool {
        matches!(self, AclAction::Error { .. } | AclAction::Deny)
    }
}
/// A single ACL rule: which caller (id / roles), which HTTP methods, which
/// client IPs and which time window it applies to, plus the [`AclAction`] to take.
///
/// Built fluently starting from [`AclRuleFilter::new`], which matches every id,
/// every role, every method, every IP and every time, with action `Allow`.
#[derive(Debug, Clone)]
pub struct AclRuleFilter {
    /// Caller id the rule applies to, or `"*"` for any caller.
    pub id: String,
    /// Role bitmask; the rule applies when it shares at least one bit with the caller's roles.
    pub role_mask: u32,
    /// Allowed HTTP methods; empty means any method. Only checked by the
    /// `RuleMatcher<BitmaskAuth>` implementation (a `RequestContext` has no method).
    pub methods: Vec<Method>,
    /// Time window in which the rule is active.
    pub time: TimeWindow,
    /// Client IP restriction.
    pub ip: IpMatcher,
    /// Action associated with the rule.
    pub action: AclAction,
    /// Optional human-readable description.
    pub description: Option<String>,
}

impl AclRuleFilter {
    /// Creates a catch-all rule: id `"*"`, all role bits set, no method
    /// restriction, default (unbounded) time window, any IP, action `Allow`.
    pub fn new() -> Self {
        Self {
            id: "*".to_string(),
            role_mask: u32::MAX,
            methods: Vec::new(),
            time: TimeWindow::default(),
            ip: IpMatcher::Any,
            action: AclAction::Allow,
            description: None,
        }
    }

    /// Restricts the rule to the caller with this id (`"*"` means any caller).
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.id = id.into();
        self
    }

    /// Replaces the role bitmask.
    pub fn role_mask(mut self, mask: u32) -> Self {
        self.role_mask = mask;
        self
    }

    /// Replaces the role bitmask with the single bit for `role_id`.
    ///
    /// # Panics
    /// If `role_id >= 32` (a `u32` mask has no bit for it).
    pub fn role(mut self, role_id: u8) -> Self {
        self.role_mask = Self::role_bit(role_id);
        self
    }

    /// ORs the bit for `role_id` into the role bitmask.
    ///
    /// Note: [`AclRuleFilter::new`] starts with every bit set, so this only has
    /// an effect after the mask was narrowed with [`AclRuleFilter::role`] or
    /// [`AclRuleFilter::role_mask`].
    ///
    /// # Panics
    /// If `role_id >= 32` (a `u32` mask has no bit for it).
    pub fn add_role(mut self, role_id: u8) -> Self {
        self.role_mask |= Self::role_bit(role_id);
        self
    }

    /// Returns the mask bit for `role_id`, rejecting ids that do not fit in a `u32`.
    fn role_bit(role_id: u8) -> u32 {
        // A plain `1 << role_id` panics for ids >= 32 in debug builds but masks the
        // shift amount in release builds (role 32 would silently become role 0),
        // which would change who a rule applies to. Fail loudly in both.
        assert!(
            u32::from(role_id) < u32::BITS,
            "role id {} is out of range: a u32 role mask holds ids 0..={}",
            role_id,
            u32::BITS - 1
        );
        1 << role_id
    }

    /// Replaces the allowed HTTP methods (empty = any method).
    pub fn methods(mut self, methods: Vec<Method>) -> Self {
        self.methods = methods;
        self
    }

    /// Adds one allowed HTTP method.
    pub fn method(mut self, method: Method) -> Self {
        self.methods.push(method);
        self
    }

    /// Sets the time window in which the rule is active.
    pub fn time(mut self, window: TimeWindow) -> Self {
        self.time = window;
        self
    }

    /// Sets the client IP restriction.
    pub fn ip(mut self, matcher: IpMatcher) -> Self {
        self.ip = matcher;
        self
    }

    /// Sets the action associated with the rule.
    pub fn action(mut self, action: AclAction) -> Self {
        self.action = action;
        self
    }

    /// Sets the human-readable description.
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Checks the rule's id, roles, IP and time window against `ctx`.
    ///
    /// `methods` is NOT checked here, because a `RequestContext` carries no
    /// HTTP method; use the `RuleMatcher` implementation for method filtering.
    #[inline]
    pub fn matches(&self, ctx: &RequestContext) -> bool {
        (self.id == "*" || self.id == ctx.id)
            && (self.role_mask & ctx.roles) != 0
            && self.ip.matches(&ctx.ip)
            && self.time.matches_now()
    }
}

impl Default for AclRuleFilter {
    /// Same as [`AclRuleFilter::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl RuleMatcher<BitmaskAuth> for AclRuleFilter {
    /// Applies the method filter (empty list = any method) first, then the id,
    /// role-overlap, client-IP and time-window checks.
    fn matches(&self, auth: &BitmaskAuth, meta: &RequestMeta) -> bool {
        let method_allowed = self.methods.is_empty() || self.methods.contains(&meta.method);
        if !method_allowed {
            return false;
        }
        let id_allowed = self.id == "*" || self.id == auth.id;
        let shares_role = (self.role_mask & auth.roles) != 0;
        // The clock-reading time check stays last so it only runs when everything else passed.
        id_allowed && shares_role && self.ip.matches(&meta.ip) && self.time.matches_now()
    }

    /// The action configured on this filter.
    fn action(&self) -> &AclAction {
        &self.action
    }

    /// The filter's description, if one was set.
    fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }
}
/// A recurring time-of-day window, optionally limited to certain weekdays.
///
/// [`TimeWindow::matches_now`] compares against the current UTC time.
/// Weekdays use chrono's `num_days_from_monday` numbering (0 = Monday … 6 = Sunday).
/// The default value has no bounds and no day restriction, so it always matches.
#[derive(Debug, Clone, Default)]
pub struct TimeWindow {
    /// Inclusive lower bound; `None` means no lower bound.
    pub start: Option<NaiveTime>,
    /// Inclusive upper bound; `None` means no upper bound.
    pub end: Option<NaiveTime>,
    /// Allowed weekdays (0 = Monday … 6 = Sunday); empty means every day.
    pub days: Vec<u32>,
}

impl TimeWindow {
    /// A window that always matches.
    pub fn any() -> Self {
        Self::default()
    }

    /// Window from `start_hour`:00 to `end_hour`:00 (both inclusive), every day.
    ///
    /// If `start_hour > end_hour` the window wraps past midnight.
    /// An hour outside `0..=23` leaves that bound open. For example, `hours(0, 24)` covers the
    /// whole day and `hours(22, 24)` runs from 22:00 until midnight. Previously an invalid
    /// hour became `00:00`, which made `hours(0, 24)` match only the instant of midnight.
    pub fn hours(start_hour: u32, end_hour: u32) -> Self {
        Self {
            start: NaiveTime::from_hms_opt(start_hour, 0, 0),
            end: NaiveTime::from_hms_opt(end_hour, 0, 0),
            days: Vec::new(),
        }
    }

    /// Like [`TimeWindow::hours`], restricted to the given weekdays (0 = Monday … 6 = Sunday).
    pub fn hours_on_days(start_hour: u32, end_hour: u32, days: Vec<u32>) -> Self {
        Self {
            days,
            ..Self::hours(start_hour, end_hour)
        }
    }

    /// Returns `true` if the current UTC time and weekday fall inside the window.
    pub fn matches_now(&self) -> bool {
        let now = Utc::now();
        self.matches_at(now.time(), now.weekday().num_days_from_monday())
    }

    /// Returns `true` if `time` on weekday `day` (0 = Monday … 6 = Sunday) falls inside the window.
    ///
    /// Both bounds are inclusive. When `start > end` the window wraps past midnight.
    pub fn matches_at(&self, time: NaiveTime, day: u32) -> bool {
        if !self.days.is_empty() && !self.days.contains(&day) {
            return false;
        }
        match (self.start, self.end) {
            (Some(start), Some(end)) if start <= end => time >= start && time <= end,
            // Wrapping window, e.g. 22:00–06:00.
            (Some(start), Some(end)) => time >= start || time <= end,
            (Some(start), None) => time >= start,
            (None, Some(end)) => time <= end,
            (None, None) => true,
        }
    }
}
/// Restricts which client IP addresses a rule applies to.
///
/// NOTE: `IpAddr` equality treats an IPv4 address and its IPv4-mapped IPv6
/// form as different addresses.
#[derive(Clone, Debug, Default)]
pub enum IpMatcher {
    /// Matches every address (the default).
    #[default]
    Any,
    /// Matches exactly one address.
    Single(IpAddr),
    /// Matches addresses contained in the network.
    Network(IpNetwork),
    /// Matches when any nested matcher matches (an empty list matches nothing).
    List(Vec<IpMatcher>),
}

impl IpMatcher {
    /// Matcher that accepts every address.
    pub fn any() -> Self {
        IpMatcher::Any
    }

    /// Matcher for exactly `ip`.
    pub fn single(ip: IpAddr) -> Self {
        IpMatcher::Single(ip)
    }

    /// Matcher for every address inside `network`.
    pub fn cidr(network: IpNetwork) -> Self {
        IpMatcher::Network(network)
    }

    /// Parses a matcher from text (surrounding whitespace is ignored).
    ///
    /// `*` or `any` (case-insensitive) gives `Any`. Text containing `/` is
    /// parsed as CIDR. Anything else is parsed as a single IP address. `List`
    /// values cannot be produced by this parser.
    ///
    /// # Errors
    /// A message prefixed with `Invalid CIDR:` or `Invalid IP address:` when parsing fails.
    pub fn parse(s: &str) -> Result<Self, String> {
        let text = s.trim();
        if text == "*" || text.eq_ignore_ascii_case("any") {
            Ok(IpMatcher::Any)
        } else if text.contains('/') {
            text.parse::<IpNetwork>()
                .map_err(|err| format!("Invalid CIDR: {}", err))
                .map(IpMatcher::Network)
        } else {
            text.parse::<IpAddr>()
                .map_err(|err| format!("Invalid IP address: {}", err))
                .map(IpMatcher::Single)
        }
    }

    /// Returns `true` if `ip` is accepted by this matcher.
    pub fn matches(&self, ip: &IpAddr) -> bool {
        match self {
            IpMatcher::Any => true,
            IpMatcher::Single(expected) => *expected == *ip,
            IpMatcher::Network(net) => net.contains(*ip),
            IpMatcher::List(items) => items.iter().any(|item| item.matches(ip)),
        }
    }
}
/// Describes which request paths a rule applies to.
///
/// Glob patterns are matched segment by segment. Empty segments produced by
/// leading, trailing or repeated `/` are ignored on both sides. Segment forms:
/// - `*` matches exactly one segment;
/// - `**` matches zero or more segments;
/// - `{name}` matches exactly one segment and captures it under `name`;
///   `{id}` must additionally equal the caller id given to
///   [`EndpointPattern::matches_with_id`], when one is given;
/// - anything else must match literally.
#[derive(Debug, Clone, Default)]
pub enum EndpointPattern {
    /// Matches every path.
    #[default]
    Any,
    /// Matches only the identical path string.
    Exact(String),
    /// Matches any path starting with the given string.
    Prefix(String),
    /// Segment-wise glob pattern (see the type-level docs).
    Glob(String),
}

impl EndpointPattern {
    /// Pattern matching every path.
    pub fn any() -> Self {
        Self::Any
    }

    /// Pattern matching exactly `path`.
    pub fn exact(path: impl Into<String>) -> Self {
        Self::Exact(path.into())
    }

    /// Pattern matching every path that starts with `path`.
    pub fn prefix(path: impl Into<String>) -> Self {
        Self::Prefix(path.into())
    }

    /// Segment-wise glob pattern.
    pub fn glob(pattern: impl Into<String>) -> Self {
        Self::Glob(pattern.into())
    }

    /// Parses a pattern from configuration text (surrounding whitespace is ignored).
    ///
    /// - `*` or `any` (case-insensitive) gives `Any`.
    /// - Text containing `*` or a `{name}` segment gives `Glob`.
    /// - Text ending in `/` gives `Prefix`.
    /// - Anything else gives `Exact`.
    pub fn parse(s: &str) -> Self {
        let s = s.trim();
        if s == "*" || s.eq_ignore_ascii_case("any") {
            return Self::Any;
        }
        // Placeholders are only understood by glob matching. Treating
        // `/users/{id}` as Exact would make it match only that literal text.
        let has_placeholder = s.split('/').any(|seg| Self::placeholder_name(seg).is_some());
        if s.contains('*') || has_placeholder {
            return Self::Glob(s.to_string());
        }
        if s.ends_with('/') {
            return Self::Prefix(s.to_string());
        }
        Self::Exact(s.to_string())
    }

    /// Returns `true` if `path` matches; `{id}` segments accept any value.
    pub fn matches(&self, path: &str) -> bool {
        self.matches_with_id(path, None)
    }

    /// Returns `true` if `path` matches. For glob patterns with `Some(user_id)`,
    /// every `{id}` segment must equal `user_id`.
    pub fn matches_with_id(&self, path: &str, user_id: Option<&str>) -> bool {
        match self {
            Self::Any => true,
            Self::Exact(p) => p == path,
            Self::Prefix(prefix) => path.starts_with(prefix),
            Self::Glob(pattern) => Self::glob_captures(pattern, path, user_id).is_some(),
        }
    }

    /// Returns the value of the first `{id}` segment if `path` fully matches this
    /// glob pattern. Returns `None` on mismatch and for non-glob patterns.
    pub fn extract_id(&self, path: &str) -> Option<String> {
        match self {
            Self::Glob(pattern) => Self::glob_captures(pattern, path, None)?
                .into_iter()
                .find(|&(name, _)| name == "id")
                .map(|(_, value)| value.to_string()),
            Self::Any | Self::Exact(_) | Self::Prefix(_) => None,
        }
    }

    /// Returns every `{name}` capture if `path` fully matches this glob pattern.
    ///
    /// The map is empty when the path does not match, and for non-glob patterns.
    /// If a name occurs more than once, the last capture wins.
    pub fn extract_named_params(&self, path: &str) -> HashMap<String, String> {
        match self {
            Self::Glob(pattern) => Self::glob_captures(pattern, path, None)
                .unwrap_or_default()
                .into_iter()
                .map(|(name, value)| (name.to_string(), value.to_string()))
                .collect(),
            Self::Any | Self::Exact(_) | Self::Prefix(_) => HashMap::new(),
        }
    }

    /// Splits `s` on `/`, dropping empty segments.
    fn segments(s: &str) -> Vec<&str> {
        s.split('/').filter(|seg| !seg.is_empty()).collect()
    }

    /// Returns `name` for a `{name}` segment, or `None` for any other segment.
    fn placeholder_name(segment: &str) -> Option<&str> {
        segment.strip_prefix('{')?.strip_suffix('}')
    }

    /// Matches `path` against glob `pattern`. On success, returns the
    /// `(name, value)` placeholder captures in pattern order.
    fn glob_captures<'p, 'q>(
        pattern: &'p str,
        path: &'q str,
        user_id: Option<&str>,
    ) -> Option<Vec<(&'p str, &'q str)>> {
        let pattern_parts = Self::segments(pattern);
        let path_parts = Self::segments(path);
        let mut captures = Vec::new();
        if Self::match_parts(&pattern_parts, &path_parts, user_id, &mut captures) {
            Some(captures)
        } else {
            None
        }
    }

    /// Recursive, backtracking segment matcher that also records placeholder captures.
    /// On failure `captures` may hold partial entries; callers discard them.
    fn match_parts<'p, 'q>(
        pattern: &[&'p str],
        path: &[&'q str],
        user_id: Option<&str>,
        captures: &mut Vec<(&'p str, &'q str)>,
    ) -> bool {
        let (head, rest_pattern) = match pattern.split_first() {
            Some((&head, rest)) => (head, rest),
            None => return path.is_empty(),
        };

        if head == "**" {
            if rest_pattern.is_empty() {
                return true;
            }
            let mark = captures.len();
            // Try letting `**` swallow 0, 1, 2, … segments (shortest first).
            for skip in 0..=path.len() {
                if Self::match_parts(rest_pattern, &path[skip..], user_id, captures) {
                    return true;
                }
                // Forget captures recorded by the failed attempt before retrying.
                captures.truncate(mark);
            }
            return false;
        }

        let (segment, rest_path) = match path.split_first() {
            Some((&segment, rest)) => (segment, rest),
            None => return false,
        };

        match Self::placeholder_name(head) {
            Some(name) => {
                // `{id}` must equal the caller's id when one is supplied.
                if name == "id" && matches!(user_id, Some(id) if id != segment) {
                    return false;
                }
                captures.push((name, segment));
            }
            None => {
                if head != "*" && head != segment {
                    return false;
                }
            }
        }
        Self::match_parts(rest_pattern, rest_path, user_id, captures)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal used as test input.
    fn ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    /// Request metadata for `path` with the given method, from client 10.0.0.1.
    fn meta(method: Method, path: &str) -> RequestMeta {
        RequestMeta {
            method,
            path: path.to_string(),
            path_params: HashMap::new(),
            ip: ip("10.0.0.1"),
        }
    }

    /// Authentication with the given role mask and the wildcard id.
    fn wildcard_auth(roles: u32) -> BitmaskAuth {
        BitmaskAuth {
            roles,
            id: "*".to_string(),
        }
    }

    /// Borrowed view of a captured path parameter, for concise assertions.
    fn param<'m>(params: &'m HashMap<String, String>, name: &str) -> Option<&'m str> {
        params.get(name).map(String::as_str)
    }

    #[test]
    fn test_ip_matcher_single() {
        let addr = ip("192.168.1.1");
        let matcher = IpMatcher::single(addr);
        assert!(matcher.matches(&addr));
        assert!(!matcher.matches(&ip("192.168.1.2")));
    }

    #[test]
    fn test_ip_matcher_cidr() {
        let matcher = IpMatcher::cidr("192.168.1.0/24".parse().unwrap());
        for inside in ["192.168.1.1", "192.168.1.255"] {
            assert!(matcher.matches(&ip(inside)));
        }
        assert!(!matcher.matches(&ip("192.168.2.1")));
    }

    #[test]
    fn test_endpoint_exact() {
        let pattern = EndpointPattern::exact("/api/users");
        assert!(pattern.matches("/api/users"));
        for other in ["/api/users/", "/api/users/1"] {
            assert!(!pattern.matches(other));
        }
    }

    #[test]
    fn test_endpoint_prefix() {
        let pattern = EndpointPattern::prefix("/api/");
        assert!(pattern.matches("/api/users"));
        assert!(pattern.matches("/api/users/1"));
        assert!(!pattern.matches("/admin/users"));
    }

    #[test]
    fn test_endpoint_glob() {
        let single = EndpointPattern::glob("/api/*/users");
        assert!(single.matches("/api/v1/users"));
        assert!(single.matches("/api/v2/users"));
        assert!(!single.matches("/api/v1/posts"));

        let deep = EndpointPattern::glob("/api/**");
        assert!(deep.matches("/api/users"));
        assert!(deep.matches("/api/v1/users/1"));
    }

    #[test]
    fn test_endpoint_glob_with_id() {
        let boat = EndpointPattern::glob("/api/boat/{id}/details");
        assert!(boat.matches("/api/boat/boat-123/details"));
        assert!(boat.matches("/api/boat/anything/details"));
        assert!(boat.matches_with_id("/api/boat/boat-123/details", Some("boat-123")));
        assert!(!boat.matches_with_id("/api/boat/boat-456/details", Some("boat-123")));

        let user = EndpointPattern::glob("/api/user/{id}/**");
        assert!(user.matches_with_id("/api/user/user-1/profile", Some("user-1")));
        assert!(user.matches_with_id("/api/user/user-1/boats/123", Some("user-1")));
        assert!(!user.matches_with_id("/api/user/user-2/profile", Some("user-1")));
    }

    #[test]
    fn test_extract_id_from_path() {
        let boat = EndpointPattern::glob("/api/boat/{id}/details");
        assert_eq!(boat.extract_id("/api/boat/boat-123/details").as_deref(), Some("boat-123"));
        assert_eq!(boat.extract_id("/api/boat/xyz/details").as_deref(), Some("xyz"));
        assert_eq!(boat.extract_id("/api/wrong/path"), None);

        let users = EndpointPattern::glob("/users/{id}");
        assert_eq!(users.extract_id("/users/123").as_deref(), Some("123"));
        assert_eq!(users.extract_id("/users/"), None);
    }

    #[test]
    fn test_filter_matches() {
        let filter = AclRuleFilter::new().role_mask(0b001).ip(IpMatcher::any());
        let addr = ip("10.0.0.1");
        let ctx_for = |roles| RequestContext::new(roles, addr, "*");
        assert!(filter.matches(&ctx_for(0b001)));
        assert!(!filter.matches(&ctx_for(0b010)));
        assert!(filter.matches(&ctx_for(0b011)));
    }

    #[test]
    fn test_filter_id_match() {
        let filter = AclRuleFilter::new().id("user123").role_mask(u32::MAX);
        let addr = ip("10.0.0.1");
        assert!(filter.matches(&RequestContext::new(0b1, addr, "user123")));
        assert!(!filter.matches(&RequestContext::new(0b1, addr, "user456")));
    }

    #[test]
    fn test_filter_wildcard_id() {
        let filter = AclRuleFilter::new().id("*").role_mask(u32::MAX);
        assert!(filter.matches(&RequestContext::new(0b1, ip("10.0.0.1"), "anyone")));
    }

    #[test]
    fn test_extract_named_params() {
        let params = EndpointPattern::glob("/api/{resource}/{id}/details")
            .extract_named_params("/api/boat/123/details");
        assert_eq!(param(&params, "resource"), Some("boat"));
        assert_eq!(param(&params, "id"), Some("123"));

        let params = EndpointPattern::glob("/api/groups/{group_id}/factions/{faction_id}")
            .extract_named_params("/api/groups/abc-123/factions/def-456");
        assert_eq!(param(&params, "group_id"), Some("abc-123"));
        assert_eq!(param(&params, "faction_id"), Some("def-456"));

        let params = EndpointPattern::exact("/api/users").extract_named_params("/api/users");
        assert!(params.is_empty());
    }

    #[test]
    fn test_rule_matcher_bitmask_auth() {
        let filter = AclRuleFilter::new().role_mask(0b001).action(AclAction::Allow);
        let request = meta(Method::GET, "/api/users");
        assert!(RuleMatcher::matches(&filter, &wildcard_auth(0b001), &request));
        assert!(!RuleMatcher::matches(&filter, &wildcard_auth(0b010), &request));
    }

    #[test]
    fn test_rule_matcher_method_filtering() {
        let post_only = AclRuleFilter::new()
            .role_mask(u32::MAX)
            .method(Method::POST)
            .action(AclAction::Allow);
        let auth = wildcard_auth(0b1);
        let post = meta(Method::POST, "/api/users");
        let get = meta(Method::GET, "/api/users");
        assert!(RuleMatcher::matches(&post_only, &auth, &post));
        assert!(!RuleMatcher::matches(&post_only, &auth, &get));

        let any_method = AclRuleFilter::new()
            .role_mask(u32::MAX)
            .action(AclAction::Allow);
        assert!(RuleMatcher::matches(&any_method, &auth, &get));
        assert!(RuleMatcher::matches(&any_method, &auth, &post));
    }
}