1use crate::rule::{AclAction, AclRuleFilter, EndpointPattern, RequestContext};
9use std::collections::HashMap;
10use std::sync::Arc;
11
/// Access-control table mapping request paths to ordered lists of filters.
///
/// `evaluate_with_match` consults `exact_rules` for the literal path first,
/// then walks `pattern_rules` in order, and returns `default_action` when no
/// filter matches.
#[derive(Debug, Clone)]
pub struct AclTable {
    /// Filters keyed by the exact request path; within one path the first
    /// matching filter wins.
    pub(crate) exact_rules: HashMap<String, Vec<AclRuleFilter>>,
    /// Pattern/filter-list pairs, tried in vector order after the exact rules.
    pub(crate) pattern_rules: Vec<(EndpointPattern, Vec<AclRuleFilter>)>,
    /// Action returned when neither exact nor pattern rules produce a match.
    pub(crate) default_action: AclAction,
}
43
44impl Default for AclTable {
45 fn default() -> Self {
46 Self {
47 exact_rules: HashMap::new(),
48 pattern_rules: Vec::new(),
49 default_action: AclAction::Deny,
50 }
51 }
52}
53
54impl AclTable {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn builder() -> AclTableBuilder {
62 AclTableBuilder::new()
63 }
64
65 pub fn exact_rules(&self) -> &HashMap<String, Vec<AclRuleFilter>> {
67 &self.exact_rules
68 }
69
70 pub fn pattern_rules(&self) -> &[(EndpointPattern, Vec<AclRuleFilter>)] {
72 &self.pattern_rules
73 }
74
75 pub fn default_action(&self) -> AclAction {
77 self.default_action.clone()
78 }
79
80 pub fn evaluate(&self, path: &str, ctx: &RequestContext) -> AclAction {
105 self.evaluate_with_match(path, ctx).0
106 }
107
108 pub fn evaluate_with_match(&self, path: &str, ctx: &RequestContext) -> (AclAction, Option<(String, usize)>) {
113 if let Some(filters) = self.exact_rules.get(path) {
115 for (idx, filter) in filters.iter().enumerate() {
116 if filter.matches(ctx) {
117 tracing::debug!(
118 endpoint = path,
119 filter_index = idx,
120 filter_description = ?filter.description,
121 roles = ctx.roles,
122 id = ctx.id,
123 ip = %ctx.ip,
124 action = ?filter.action,
125 "ACL exact match"
126 );
127 return (filter.action.clone(), Some((path.to_string(), idx)));
128 }
129 }
130 }
131
132 for (pattern, filters) in &self.pattern_rules {
134 if pattern.matches(path) {
135 for (idx, filter) in filters.iter().enumerate() {
136 if filter.matches(ctx) {
137 tracing::debug!(
138 endpoint = ?pattern,
139 filter_index = idx,
140 filter_description = ?filter.description,
141 roles = ctx.roles,
142 id = ctx.id,
143 ip = %ctx.ip,
144 action = ?filter.action,
145 "ACL pattern match"
146 );
147 return (filter.action.clone(), Some((format!("{:?}", pattern), idx)));
148 }
149 }
150 }
151 }
152
153 tracing::debug!(
154 path = path,
155 roles = ctx.roles,
156 id = ctx.id,
157 ip = %ctx.ip,
158 action = ?self.default_action,
159 "No ACL rule matched, using default action"
160 );
161 (self.default_action.clone(), None)
162 }
163
164 pub fn is_allowed(&self, path: &str, ctx: &RequestContext) -> bool {
166 self.evaluate(path, ctx) == AclAction::Allow
167 }
168}
169
170#[derive(Debug, Default)]
172pub struct AclTableBuilder {
173 exact_rules: HashMap<String, Vec<AclRuleFilter>>,
174 pattern_rules: Vec<(EndpointPattern, Vec<AclRuleFilter>)>,
175 default_action: AclAction,
176}
177
178impl AclTableBuilder {
179 pub fn new() -> Self {
181 Self::default()
182 }
183
184 pub fn default_action(mut self, action: AclAction) -> Self {
188 self.default_action = action;
189 self
190 }
191
192 pub fn add_exact(mut self, endpoint: impl Into<String>, filter: AclRuleFilter) -> Self {
194 let endpoint = endpoint.into();
195 self.exact_rules
196 .entry(endpoint)
197 .or_default()
198 .push(filter);
199 self
200 }
201
202 pub fn add_exact_filters(
204 mut self,
205 endpoint: impl Into<String>,
206 filters: impl IntoIterator<Item = AclRuleFilter>,
207 ) -> Self {
208 let endpoint = endpoint.into();
209 self.exact_rules
210 .entry(endpoint)
211 .or_default()
212 .extend(filters);
213 self
214 }
215
216 pub fn add_prefix(self, prefix: impl Into<String>, filter: AclRuleFilter) -> Self {
218 let pattern = EndpointPattern::Prefix(prefix.into());
219 self.add_pattern(pattern, filter)
220 }
221
222 pub fn add_glob(self, glob: impl Into<String>, filter: AclRuleFilter) -> Self {
224 let pattern = EndpointPattern::Glob(glob.into());
225 self.add_pattern(pattern, filter)
226 }
227
228 pub fn add_any(self, filter: AclRuleFilter) -> Self {
230 self.add_pattern(EndpointPattern::Any, filter)
231 }
232
233 pub fn add_pattern(mut self, pattern: EndpointPattern, filter: AclRuleFilter) -> Self {
235 for (existing_pattern, filters) in &mut self.pattern_rules {
237 let is_match = match (existing_pattern, &pattern) {
238 (EndpointPattern::Any, EndpointPattern::Any) => true,
239 (EndpointPattern::Prefix(a), EndpointPattern::Prefix(b)) => a == b,
240 (EndpointPattern::Glob(a), EndpointPattern::Glob(b)) => a == b,
241 (EndpointPattern::Exact(a), EndpointPattern::Exact(b)) => a == b,
242 _ => false,
243 };
244 if is_match {
245 filters.push(filter);
246 return self;
247 }
248 }
249 self.pattern_rules.push((pattern, vec![filter]));
251 self
252 }
253
254 pub fn build(self) -> AclTable {
256 AclTable {
257 exact_rules: self.exact_rules,
258 pattern_rules: self.pattern_rules,
259 default_action: self.default_action,
260 }
261 }
262
263 pub fn build_shared(self) -> Arc<AclTable> {
265 Arc::new(self.build())
266 }
267}
268
/// A single ACL rule: an endpoint pattern paired with the filter that applies
/// to requests on matching endpoints.
#[derive(Debug, Clone)]
pub struct RuleEntry {
    /// Which endpoint(s) this rule covers.
    pub pattern: EndpointPattern,
    /// Request conditions and the resulting action for this rule.
    pub filter: AclRuleFilter,
}
277
278impl RuleEntry {
279 pub fn new(pattern: EndpointPattern, filter: AclRuleFilter) -> Self {
281 Self { pattern, filter }
282 }
283
284 pub fn exact(endpoint: impl Into<String>, filter: AclRuleFilter) -> Self {
286 Self::new(EndpointPattern::Exact(endpoint.into()), filter)
287 }
288
289 pub fn prefix(prefix: impl Into<String>, filter: AclRuleFilter) -> Self {
291 Self::new(EndpointPattern::Prefix(prefix.into()), filter)
292 }
293
294 pub fn glob(glob: impl Into<String>, filter: AclRuleFilter) -> Self {
296 Self::new(EndpointPattern::Glob(glob.into()), filter)
297 }
298
299 pub fn any(filter: AclRuleFilter) -> Self {
301 Self::new(EndpointPattern::Any, filter)
302 }
303}
304
/// A source of [`RuleEntry`] values.
///
/// Implementors must be `Send + Sync`, which allows one provider to be shared
/// across threads. `StaticRuleProvider` is an in-memory implementation.
pub trait AclRuleProvider: Send + Sync {
    /// Error produced when the rules cannot be loaded.
    ///
    /// The bounds let the error cross thread boundaries and be converted into
    /// `Box<dyn std::error::Error + Send + Sync>`.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Returns the provider's complete list of rule entries.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the rules could not be produced; the meaning
    /// of a failure is implementation-specific.
    fn load_rules(&self) -> Result<Vec<RuleEntry>, Self::Error>;
}
338
/// Rule provider backed by a fixed, in-memory list of rules.
#[derive(Debug, Clone)]
pub struct StaticRuleProvider {
    /// Rules handed out (as a clone) by every `load_rules` call.
    rules: Vec<RuleEntry>,
}
344
345impl StaticRuleProvider {
346 pub fn new(rules: Vec<RuleEntry>) -> Self {
348 Self { rules }
349 }
350}
351
352impl AclRuleProvider for StaticRuleProvider {
353 type Error = std::convert::Infallible;
354
355 fn load_rules(&self) -> Result<Vec<RuleEntry>, Self::Error> {
356 Ok(self.rules.clone())
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;

    const ROLE_ADMIN: u32 = 0b001;
    const ROLE_USER: u32 = 0b010;
    const ROLE_GUEST: u32 = 0b100;

    /// Loopback address used for every request context in these tests.
    fn localhost() -> IpAddr {
        "127.0.0.1".parse().unwrap()
    }

    /// Shorthand for a filter with the given role mask and action.
    fn rule(mask: u32, action: AclAction) -> AclRuleFilter {
        AclRuleFilter::new().role_mask(mask).action(action)
    }

    /// Shorthand for a request from `id`, carrying role bits `roles`, from localhost.
    fn request(roles: u32, id: &'static str) -> RequestContext {
        RequestContext::new(roles, localhost(), id)
    }

    #[test]
    fn test_table_evaluation() {
        // Admins may reach anything via the catch-all; users only `/api/`.
        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_any(rule(ROLE_ADMIN, AclAction::Allow))
            .add_prefix("/api/", rule(ROLE_USER, AclAction::Allow))
            .build();

        let admin = request(ROLE_ADMIN, "admin1");
        assert!(table.is_allowed("/admin/dashboard", &admin));
        assert!(table.is_allowed("/api/users", &admin));

        let user = request(ROLE_USER, "user1");
        assert!(table.is_allowed("/api/users", &user));
        assert!(!table.is_allowed("/admin/dashboard", &user));

        let guest = request(ROLE_GUEST, "guest1");
        assert!(!table.is_allowed("/api/users", &guest));
    }

    #[test]
    fn test_exact_before_pattern() {
        // The exact allow for `/public` must win over the catch-all deny.
        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_exact("/public", rule(u32::MAX, AclAction::Allow))
            .add_any(rule(u32::MAX, AclAction::Deny))
            .build();

        let anyone = request(0b1, "anyone");
        assert!(table.is_allowed("/public", &anyone));
        assert!(!table.is_allowed("/private", &anyone));
    }

    #[test]
    fn test_role_bitmask() {
        let table = AclTable::builder()
            .default_action(AclAction::Deny)
            .add_exact("/shared", rule(ROLE_ADMIN | ROLE_USER, AclAction::Allow))
            .build();

        // (role bits, caller id, expected allow?)
        let cases = [
            (ROLE_ADMIN, "a", true),
            (ROLE_USER, "u", true),
            (ROLE_GUEST, "g", false),
            (ROLE_ADMIN | ROLE_USER, "au", true),
        ];
        for (roles, id, expected) in cases.iter().copied() {
            assert_eq!(
                table.is_allowed("/shared", &request(roles, id)),
                expected,
                "roles = {:#b}",
                roles
            );
        }
    }
}