use crate::rule::{
AclAction, AclRuleFilter, BitmaskAuth, EndpointPattern, RequestContext, RequestMeta,
RuleMatcher,
};
use std::collections::HashMap;
use std::sync::Arc;
/// A single ACL rule wrapping a shared, type-erased [`RuleMatcher`].
///
/// The matcher decides whether the rule applies to `(auth, meta)` and supplies the
/// [`AclAction`] returned when it does. The matcher is held in an `Arc`, so cloning
/// a rule shares the matcher rather than copying it.
pub struct AclRule<A> {
/// Type-erased matcher; its `action()` is the rule's result when `matches()` is true.
pub(crate) matcher: Arc<dyn RuleMatcher<A>>,
}
impl<A> Clone for AclRule<A> {
fn clone(&self) -> Self {
Self {
matcher: self.matcher.clone(),
}
}
}
impl<A> std::fmt::Debug for AclRule<A> {
    /// Formats as `AclRule { matcher: .. }`, delegating to the matcher's own `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AclRule");
        out.field("matcher", &self.matcher);
        out.finish()
    }
}
impl<A> AclRule<A> {
pub fn from_matcher(matcher: Arc<dyn RuleMatcher<A>>) -> Self {
Self { matcher }
}
pub fn from_matcher_with_methods(
matcher: Arc<dyn RuleMatcher<A>>,
methods: Vec<http::Method>,
action: AclAction,
) -> Self
where
A: Send + Sync + 'static,
{
Self {
matcher: Arc::new(MethodFilterMatcher {
inner: matcher,
methods,
action,
}),
}
}
pub fn action(&self) -> &AclAction {
self.matcher.action()
}
pub fn description(&self) -> Option<&str> {
self.matcher.description()
}
}
/// Decorator created by [`AclRule::from_matcher_with_methods`]: restricts an inner
/// matcher to a set of HTTP methods and replaces its action.
struct MethodFilterMatcher<A> {
/// The wrapped matcher; consulted only when the method check passes.
inner: Arc<dyn RuleMatcher<A>>,
/// Allowed methods; an empty list allows every method.
methods: Vec<http::Method>,
/// Action reported for this rule instead of `inner`'s action.
action: AclAction,
}
impl<A> std::fmt::Debug for MethodFilterMatcher<A> {
    /// Formats every field, including `action`.
    ///
    /// `action` overrides the inner matcher's action, so leaving it out (as the
    /// previous implementation did) made printed rules show only the inner
    /// matcher, whose own action is never reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MethodFilterMatcher")
            .field("methods", &self.methods)
            .field("action", &self.action)
            .field("inner", &self.inner)
            .finish()
    }
}
impl<A> RuleMatcher<A> for MethodFilterMatcher<A>
where
    A: Send + Sync,
{
    /// Matches only when the request method is allowed (an empty list allows all)
    /// and the inner matcher also matches.
    ///
    /// The inner matcher is not consulted when the method check fails.
    fn matches(&self, auth: &A, meta: &RequestMeta) -> bool {
        let method_allowed = self.methods.is_empty() || self.methods.contains(&meta.method);
        if !method_allowed {
            return false;
        }
        self.inner.matches(auth, meta)
    }

    /// The override action supplied at construction, not the inner matcher's.
    fn action(&self) -> &AclAction {
        &self.action
    }

    /// Delegates to the inner matcher's description.
    fn description(&self) -> Option<&str> {
        let inner = &self.inner;
        inner.description()
    }
}
/// An access-control table evaluated first-match-wins.
///
/// Evaluation order:
/// 1. The rules in `exact_rules` for the request's exact path.
/// 2. Each `pattern_rules` group, in order.
/// 3. `default_action`, if nothing matched.
///
/// `A` is the authentication type passed to each rule's matcher; it defaults to
/// [`BitmaskAuth`].
pub struct AclTable<A = BitmaskAuth> {
/// Rules keyed by exact request path; each list is tried in insertion order.
pub(crate) exact_rules: HashMap<String, Vec<AclRule<A>>>,
/// Pattern groups tried in order after the exact rules; rules within a group are tried in order.
pub(crate) pattern_rules: Vec<(EndpointPattern, Vec<AclRule<A>>)>,
/// Action returned when no rule matches.
pub(crate) default_action: AclAction,
}
impl<A> Clone for AclTable<A> {
fn clone(&self) -> Self {
Self {
exact_rules: self.exact_rules.clone(),
pattern_rules: self.pattern_rules.clone(),
default_action: self.default_action.clone(),
}
}
}
impl<A> std::fmt::Debug for AclTable<A> {
    /// Summarises the table as rule counts plus the default action.
    ///
    /// This deliberately does not require `A: Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let exact_count = self.exact_rules.len();
        let pattern_count = self.pattern_rules.len();
        f.debug_struct("AclTable")
            .field("exact_rules_count", &exact_count)
            .field("pattern_rules_count", &pattern_count)
            .field("default_action", &self.default_action)
            .finish()
    }
}
impl<A> Default for AclTable<A> {
fn default() -> Self {
Self {
exact_rules: HashMap::new(),
pattern_rules: Vec::new(),
default_action: AclAction::Deny,
}
}
}
impl<A> AclTable<A> {
    /// Rules keyed by exact request path.
    pub fn exact_rules(&self) -> &HashMap<String, Vec<AclRule<A>>> {
        &self.exact_rules
    }

    /// Pattern rule groups, in evaluation order.
    pub fn pattern_rules(&self) -> &[(EndpointPattern, Vec<AclRule<A>>)] {
        self.pattern_rules.as_slice()
    }

    /// A clone of the action used when no rule matches.
    pub fn default_action(&self) -> AclAction {
        self.default_action.clone()
    }

    /// Evaluates a request and returns only the resulting action.
    pub fn evaluate_request(&self, auth: &A, meta: &RequestMeta) -> AclAction {
        let (action, _matched) = self.evaluate_request_with_match(auth, meta);
        action
    }

    /// Evaluates a request first-match-wins.
    ///
    /// Evaluation order:
    /// 1. The rules registered for exactly `meta.path`, in insertion order.
    /// 2. Each pattern group, in order, whose pattern matches `meta.path`. Its rules
    ///    see a copy of `meta` whose `path_params` are replaced by the pattern's
    ///    extracted named parameters.
    /// 3. Otherwise, `default_action`.
    ///
    /// Returns the action. When a rule matched, it also returns `(key, index)`:
    /// - `key` is the exact path, or the pattern's `Debug` rendering;
    /// - `index` is the rule's position within its list or group.
    ///
    /// Each outcome is logged at debug level.
    pub fn evaluate_request_with_match(
        &self,
        auth: &A,
        meta: &RequestMeta,
    ) -> (AclAction, Option<(String, usize)>) {
        let exact_hit = self.exact_rules.get(&meta.path).and_then(|rules| {
            rules
                .iter()
                .enumerate()
                .find(|(_, rule)| rule.matcher.matches(auth, meta))
        });
        if let Some((idx, rule)) = exact_hit {
            tracing::debug!(
                endpoint = %meta.path,
                filter_index = idx,
                filter_description = ?rule.description(),
                ip = %meta.ip,
                action = ?rule.action(),
                "ACL exact match"
            );
            return (rule.action().clone(), Some((meta.path.clone(), idx)));
        }

        let candidates = self
            .pattern_rules
            .iter()
            .filter(|(pattern, _)| pattern.matches(&meta.path));
        for (pattern, rules) in candidates {
            // Path params depend on the pattern, so each matching group gets its own copy.
            let mut scoped = meta.clone();
            scoped.path_params = pattern.extract_named_params(&meta.path);
            let hit = rules
                .iter()
                .enumerate()
                .find(|(_, rule)| rule.matcher.matches(auth, &scoped));
            if let Some((idx, rule)) = hit {
                tracing::debug!(
                    endpoint = ?pattern,
                    filter_index = idx,
                    filter_description = ?rule.description(),
                    ip = %meta.ip,
                    action = ?rule.action(),
                    "ACL pattern match"
                );
                return (
                    rule.action().clone(),
                    Some((format!("{:?}", pattern), idx)),
                );
            }
        }

        tracing::debug!(
            path = %meta.path,
            ip = %meta.ip,
            action = ?self.default_action,
            "No ACL rule matched, using default action"
        );
        (self.default_action.clone(), None)
    }

    /// `true` when [`evaluate_request`](Self::evaluate_request) yields `Allow`.
    pub fn is_request_allowed(&self, auth: &A, meta: &RequestMeta) -> bool {
        let action = self.evaluate_request(auth, meta);
        action == AclAction::Allow
    }
}
impl AclTable<BitmaskAuth> {
    /// An empty table with the `Deny` default.
    pub fn new() -> Self {
        Default::default()
    }

    /// Starts a builder for a bitmask-authenticated table.
    pub fn builder() -> AclTableBuilder<BitmaskAuth> {
        AclTableBuilder::<BitmaskAuth>::new()
    }

    /// Evaluates `path` for `ctx` and returns only the resulting action.
    pub fn evaluate(&self, path: &str, ctx: &RequestContext) -> AclAction {
        let (action, _matched) = self.evaluate_with_match(path, ctx);
        action
    }

    /// Adapter from the `(path, RequestContext)` API to
    /// [`evaluate_request_with_match`](Self::evaluate_request_with_match).
    ///
    /// The request is always evaluated as an HTTP `GET` with no path params, so
    /// method-restricted rules that exclude `GET` never match through this entry point.
    pub fn evaluate_with_match(
        &self,
        path: &str,
        ctx: &RequestContext,
    ) -> (AclAction, Option<(String, usize)>) {
        let request = RequestMeta {
            method: http::Method::GET,
            path: path.to_owned(),
            path_params: HashMap::new(),
            ip: ctx.ip,
        };
        let identity = BitmaskAuth {
            roles: ctx.roles,
            id: ctx.id.to_string(),
        };
        self.evaluate_request_with_match(&identity, &request)
    }

    /// `true` when [`evaluate`](Self::evaluate) yields `Allow`.
    pub fn is_allowed(&self, path: &str, ctx: &RequestContext) -> bool {
        let action = self.evaluate(path, ctx);
        action == AclAction::Allow
    }
}
/// Fluent builder for [`AclTable`]; its fields mirror the table's and are moved
/// into it unchanged by `build()`.
pub struct AclTableBuilder<A = BitmaskAuth> {
/// Rules keyed by exact request path, in insertion order per path.
exact_rules: HashMap<String, Vec<AclRule<A>>>,
/// Pattern groups in the order they were first registered.
pattern_rules: Vec<(EndpointPattern, Vec<AclRule<A>>)>,
/// Fallback action for the built table.
default_action: AclAction,
}
impl<A> Default for AclTableBuilder<A> {
fn default() -> Self {
Self {
exact_rules: HashMap::new(),
pattern_rules: Vec::new(),
default_action: AclAction::Deny,
}
}
}
impl<A> std::fmt::Debug for AclTableBuilder<A> {
    /// Summarises the builder the same way `AclTable`'s `Debug` does: rule counts
    /// plus the default action.
    ///
    /// This lets a half-built table be inspected without requiring `A: Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AclTableBuilder")
            .field("exact_rules_count", &self.exact_rules.len())
            .field("pattern_rules_count", &self.pattern_rules.len())
            .field("default_action", &self.default_action)
            .finish()
    }
}
impl<A: 'static> AclTableBuilder<A> {
pub fn new() -> Self {
Self::default()
}
pub fn default_action(mut self, action: AclAction) -> Self {
self.default_action = action;
self
}
pub fn add_exact_matcher(
mut self,
endpoint: impl Into<String>,
matcher: impl RuleMatcher<A> + 'static,
) -> Self {
let rule = AclRule {
matcher: Arc::new(matcher),
};
self.exact_rules.entry(endpoint.into()).or_default().push(rule);
self
}
pub fn add_pattern_matcher(
mut self,
pattern: EndpointPattern,
matcher: impl RuleMatcher<A> + 'static,
) -> Self {
let rule = AclRule {
matcher: Arc::new(matcher),
};
for (existing_pattern, rules) in &mut self.pattern_rules {
let is_match = match (existing_pattern, &pattern) {
(EndpointPattern::Any, EndpointPattern::Any) => true,
(EndpointPattern::Prefix(a), EndpointPattern::Prefix(b)) => a == b,
(EndpointPattern::Glob(a), EndpointPattern::Glob(b)) => a == b,
(EndpointPattern::Exact(a), EndpointPattern::Exact(b)) => a == b,
_ => false,
};
if is_match {
rules.push(rule);
return self;
}
}
self.pattern_rules.push((pattern, vec![rule]));
self
}
pub fn add_any_matcher(self, matcher: impl RuleMatcher<A> + 'static) -> Self {
self.add_pattern_matcher(EndpointPattern::Any, matcher)
}
pub fn build(self) -> AclTable<A> {
AclTable {
exact_rules: self.exact_rules,
pattern_rules: self.pattern_rules,
default_action: self.default_action,
}
}
pub fn build_shared(self) -> Arc<AclTable<A>> {
Arc::new(self.build())
}
}
impl AclTableBuilder<BitmaskAuth> {
    /// Appends a bitmask filter for exactly `endpoint`.
    pub fn add_exact(self, endpoint: impl Into<String>, filter: AclRuleFilter) -> Self {
        self.add_exact_matcher(endpoint, filter)
    }

    /// Appends several filters for exactly `endpoint`, preserving iteration order.
    ///
    /// This is equivalent to calling [`add_exact`](Self::add_exact) once per filter,
    /// but the endpoint is converted and looked up only once. An empty `filters`
    /// still creates an (empty) entry for `endpoint`.
    pub fn add_exact_filters(
        mut self,
        endpoint: impl Into<String>,
        filters: impl IntoIterator<Item = AclRuleFilter>,
    ) -> Self {
        let endpoint: String = endpoint.into();
        // Extend straight from the lazy iterator instead of collecting into a
        // temporary Vec first (clippy::needless_collect).
        let rules = filters.into_iter().map(|filter| -> AclRule<BitmaskAuth> {
            AclRule {
                matcher: Arc::new(filter),
            }
        });
        self.exact_rules.entry(endpoint).or_default().extend(rules);
        self
    }

    /// Appends a bitmask filter for every path matching the `prefix` pattern.
    pub fn add_prefix(self, prefix: impl Into<String>, filter: AclRuleFilter) -> Self {
        let pattern = EndpointPattern::Prefix(prefix.into());
        self.add_pattern_matcher(pattern, filter)
    }

    /// Appends a bitmask filter for every path matching the `glob` pattern.
    pub fn add_glob(self, glob: impl Into<String>, filter: AclRuleFilter) -> Self {
        let pattern = EndpointPattern::Glob(glob.into());
        self.add_pattern_matcher(pattern, filter)
    }

    /// Appends a bitmask filter under `EndpointPattern::Any`.
    pub fn add_any(self, filter: AclRuleFilter) -> Self {
        self.add_pattern_matcher(EndpointPattern::Any, filter)
    }

    /// Appends a bitmask filter under an arbitrary pattern.
    pub fn add_pattern(self, pattern: EndpointPattern, filter: AclRuleFilter) -> Self {
        self.add_pattern_matcher(pattern, filter)
    }
}
/// A rule as supplied by an [`AclRuleProvider`]: an endpoint pattern paired with
/// the bitmask filter to apply to it.
#[derive(Debug, Clone)]
pub struct RuleEntry {
/// Which endpoints the filter applies to.
pub pattern: EndpointPattern,
/// The bitmask filter (and its action) for those endpoints.
pub filter: AclRuleFilter,
}
impl RuleEntry {
pub fn new(pattern: EndpointPattern, filter: AclRuleFilter) -> Self {
Self { pattern, filter }
}
pub fn exact(endpoint: impl Into<String>, filter: AclRuleFilter) -> Self {
Self::new(EndpointPattern::Exact(endpoint.into()), filter)
}
pub fn prefix(prefix: impl Into<String>, filter: AclRuleFilter) -> Self {
Self::new(EndpointPattern::Prefix(prefix.into()), filter)
}
pub fn glob(glob: impl Into<String>, filter: AclRuleFilter) -> Self {
Self::new(EndpointPattern::Glob(glob.into()), filter)
}
pub fn any(filter: AclRuleFilter) -> Self {
Self::new(EndpointPattern::Any, filter)
}
}
/// Source of ACL rule entries (endpoint pattern + bitmask filter).
///
/// Implementors must be `Send + Sync`. How the returned entries are turned into an
/// [`AclTable`] is not shown in this file — NOTE(review): confirm against the caller.
pub trait AclRuleProvider: Send + Sync {
/// Error reported when loading fails; must be a thread-safe, `'static` std error.
type Error: std::error::Error + Send + Sync + 'static;
/// Loads the current set of rule entries.
fn load_rules(&self) -> Result<Vec<RuleEntry>, Self::Error>;
}
/// Provider backed by a fixed, in-memory list of rules.
#[derive(Debug, Clone)]
pub struct StaticRuleProvider {
/// Rules returned (cloned) by every `load_rules` call.
rules: Vec<RuleEntry>,
}
impl StaticRuleProvider {
pub fn new(rules: Vec<RuleEntry>) -> Self {
Self { rules }
}
}
impl AclRuleProvider for StaticRuleProvider {
    /// Loading from memory cannot fail.
    type Error = std::convert::Infallible;

    /// Returns a fresh copy of the configured rules on every call.
    fn load_rules(&self) -> Result<Vec<RuleEntry>, Self::Error> {
        let snapshot = self.rules.to_vec();
        Ok(snapshot)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `AclTable` evaluation order, bitmask matching, generic auth
    //! types, builder pattern grouping, method filtering and the static provider.
    use super::*;
    use std::net::IpAddr;

    const ROLE_ADMIN: u32 = 0b001;
    const ROLE_USER: u32 = 0b010;
    const ROLE_GUEST: u32 = 0b100;

    #[test]
    fn test_table_evaluation() {
        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_any(
                AclRuleFilter::new()
                    .role_mask(ROLE_ADMIN)
                    .action(AclAction::Allow),
            )
            .add_prefix(
                "/api/",
                AclRuleFilter::new()
                    .role_mask(ROLE_USER)
                    .action(AclAction::Allow),
            )
            .build();
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        let admin_ctx = RequestContext::new(ROLE_ADMIN, ip, "admin1");
        assert!(table.is_allowed("/admin/dashboard", &admin_ctx));
        assert!(table.is_allowed("/api/users", &admin_ctx));
        let user_ctx = RequestContext::new(ROLE_USER, ip, "user1");
        assert!(table.is_allowed("/api/users", &user_ctx));
        assert!(!table.is_allowed("/admin/dashboard", &user_ctx));
        let guest_ctx = RequestContext::new(ROLE_GUEST, ip, "guest1");
        assert!(!table.is_allowed("/api/users", &guest_ctx));
    }

    #[test]
    fn test_exact_before_pattern() {
        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_exact(
                "/public",
                AclRuleFilter::new()
                    .role_mask(u32::MAX)
                    .action(AclAction::Allow),
            )
            .add_any(
                AclRuleFilter::new()
                    .role_mask(u32::MAX)
                    .action(AclAction::Deny),
            )
            .build();
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        let ctx = RequestContext::new(0b1, ip, "anyone");
        assert!(table.is_allowed("/public", &ctx));
        assert!(!table.is_allowed("/private", &ctx));
    }

    #[test]
    fn test_role_bitmask() {
        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_exact(
                "/shared",
                AclRuleFilter::new()
                    .role_mask(ROLE_ADMIN | ROLE_USER)
                    .action(AclAction::Allow),
            )
            .build();
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        assert!(table.is_allowed("/shared", &RequestContext::new(ROLE_ADMIN, ip, "a")));
        assert!(table.is_allowed("/shared", &RequestContext::new(ROLE_USER, ip, "u")));
        assert!(!table.is_allowed("/shared", &RequestContext::new(ROLE_GUEST, ip, "g")));
        assert!(table.is_allowed(
            "/shared",
            &RequestContext::new(ROLE_ADMIN | ROLE_USER, ip, "au")
        ));
    }

    #[test]
    fn test_generic_table_custom_auth() {
        #[derive(Debug, Clone)]
        struct CustomAuth {
            role: String,
        }
        #[derive(Debug)]
        struct RequireRole {
            role: String,
            action: AclAction,
        }
        impl RuleMatcher<CustomAuth> for RequireRole {
            fn matches(&self, auth: &CustomAuth, _meta: &RequestMeta) -> bool {
                auth.role == self.role
            }
            fn action(&self) -> &AclAction {
                &self.action
            }
        }
        let table: AclTable<CustomAuth> = AclTableBuilder::new()
            .default_action(AclAction::Deny)
            .add_exact_matcher(
                "/admin",
                RequireRole {
                    role: "admin".to_string(),
                    action: AclAction::Allow,
                },
            )
            .build();
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        let meta = RequestMeta {
            method: http::Method::GET,
            path: "/admin".to_string(),
            path_params: HashMap::new(),
            ip,
        };
        let admin = CustomAuth {
            role: "admin".to_string(),
        };
        assert!(table.is_request_allowed(&admin, &meta));
        let user = CustomAuth {
            role: "user".to_string(),
        };
        assert!(!table.is_request_allowed(&user, &meta));
    }

    /// A second rule for an already-registered pattern joins the earlier group
    /// and is evaluated after that group's existing rules.
    #[test]
    fn test_same_pattern_rules_are_grouped() {
        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_prefix(
                "/api/",
                AclRuleFilter::new()
                    .role_mask(ROLE_USER)
                    .action(AclAction::Allow),
            )
            .add_any(
                AclRuleFilter::new()
                    .role_mask(ROLE_ADMIN)
                    .action(AclAction::Allow),
            )
            .add_prefix(
                "/api/",
                AclRuleFilter::new()
                    .role_mask(ROLE_GUEST)
                    .action(AclAction::Allow),
            )
            .build();
        assert_eq!(table.pattern_rules().len(), 2);
        assert_eq!(table.pattern_rules()[0].1.len(), 2);
        assert_eq!(table.pattern_rules()[1].1.len(), 1);

        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        let guest_ctx = RequestContext::new(ROLE_GUEST, ip, "guest1");
        let (action, matched) = table.evaluate_with_match("/api/users", &guest_ctx);
        assert_eq!(action, AclAction::Allow);
        assert_eq!(matched.map(|(_, idx)| idx), Some(1));
    }

    /// `add_exact_filters` keeps the filters in iteration order (first match wins).
    #[test]
    fn test_add_exact_filters_preserves_order() {
        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_exact_filters(
                "/mixed",
                vec![
                    AclRuleFilter::new()
                        .role_mask(ROLE_GUEST)
                        .action(AclAction::Deny),
                    AclRuleFilter::new()
                        .role_mask(u32::MAX)
                        .action(AclAction::Allow),
                ],
            )
            .build();
        assert_eq!(table.exact_rules()["/mixed"].len(), 2);
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        assert!(!table.is_allowed("/mixed", &RequestContext::new(ROLE_GUEST, ip, "g")));
        assert!(table.is_allowed("/mixed", &RequestContext::new(ROLE_USER, ip, "u")));
    }

    /// Method filtering gates the inner matcher and overrides its action.
    #[test]
    fn test_method_filter_restricts_methods_and_overrides_action() {
        #[derive(Debug)]
        struct AlwaysMatch {
            action: AclAction,
        }
        impl RuleMatcher<BitmaskAuth> for AlwaysMatch {
            fn matches(&self, _auth: &BitmaskAuth, _meta: &RequestMeta) -> bool {
                true
            }
            fn action(&self) -> &AclAction {
                &self.action
            }
        }

        let inner: Arc<dyn RuleMatcher<BitmaskAuth>> = Arc::new(AlwaysMatch {
            action: AclAction::Deny,
        });
        let post_only = AclRule::from_matcher_with_methods(
            Arc::clone(&inner),
            vec![http::Method::POST],
            AclAction::Allow,
        );
        let any_method =
            AclRule::from_matcher_with_methods(inner, Vec::new(), AclAction::Allow);
        assert_eq!(post_only.action(), &AclAction::Allow);

        let auth = BitmaskAuth {
            roles: ROLE_USER,
            id: "u".to_string(),
        };
        let mut meta = RequestMeta {
            method: http::Method::POST,
            path: "/submit".to_string(),
            path_params: HashMap::new(),
            ip: "127.0.0.1".parse().unwrap(),
        };
        assert!(post_only.matcher.matches(&auth, &meta));
        meta.method = http::Method::GET;
        assert!(!post_only.matcher.matches(&auth, &meta));
        assert!(any_method.matcher.matches(&auth, &meta));
    }

    #[test]
    fn test_static_rule_provider_returns_rules() {
        let provider = StaticRuleProvider::new(vec![
            RuleEntry::exact(
                "/health",
                AclRuleFilter::new()
                    .role_mask(u32::MAX)
                    .action(AclAction::Allow),
            ),
            RuleEntry::any(
                AclRuleFilter::new()
                    .role_mask(ROLE_ADMIN)
                    .action(AclAction::Allow),
            ),
        ]);
        let rules = provider.load_rules().unwrap();
        assert_eq!(rules.len(), 2);
        assert!(matches!(rules[0].pattern, EndpointPattern::Exact(ref p) if p == "/health"));
        assert!(matches!(rules[1].pattern, EndpointPattern::Any));
    }
}