1use chrono::{Datelike, NaiveTime, Utc};
13use ipnetwork::IpNetwork;
14use std::net::IpAddr;
15
/// Per-request facts that ACL rules are evaluated against.
#[derive(Debug, Clone)]
pub struct RequestContext<'a> {
    /// Role bit set held by the caller: bit `n` set means the caller holds role `n`.
    pub roles: u32,
    /// Address the request came from.
    pub ip: IpAddr,
    /// Caller identifier, compared against a rule's `id` (unless the rule uses `"*"`).
    pub id: &'a str,
}

impl<'a> RequestContext<'a> {
    /// Bundles role bits, client address and caller id into a context.
    pub fn new(roles: u32, ip: IpAddr, id: &'a str) -> Self {
        RequestContext { roles, ip, id }
    }
}
33
/// Action attached to an ACL rule. How each action is enforced is up to the caller;
/// this type only carries the data.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum AclAction {
    /// Allow the request. This is the default action.
    #[default]
    Allow,
    /// Deny the request.
    Deny,
    /// Deny with an explicit code and optional message.
    Error {
        /// Code to report (presumably an HTTP status — confirm with the enforcing code).
        code: u16,
        /// Optional explanation.
        message: Option<String>,
    },
    /// Reroute the request to `target`.
    Reroute {
        /// Reroute destination.
        target: String,
        /// Whether the original request path is meant to be kept.
        preserve_path: bool,
    },
    /// Rate-limit the caller.
    RateLimit {
        /// Request budget per window.
        max_requests: u32,
        /// Window length (the field name indicates seconds).
        window_secs: u64,
    },
    /// Log the request. `is_allow` classifies this as allowing.
    Log {
        /// Free-form log level name.
        level: String,
        /// Optional message.
        message: Option<String>,
    },
}

impl AclAction {
    /// Shorthand for `AclAction::Deny`.
    pub fn deny() -> Self {
        AclAction::Deny
    }

    /// Shorthand for `AclAction::Allow`.
    pub fn allow() -> Self {
        AclAction::Allow
    }

    /// Builds an `Error` action. `message` accepts a `String` or an `Option<String>`.
    pub fn error(code: u16, message: impl Into<Option<String>>) -> Self {
        let message: Option<String> = message.into();
        AclAction::Error { code, message }
    }

    /// Builds a reroute to `target` with `preserve_path` set to `false`.
    pub fn reroute(target: impl Into<String>) -> Self {
        Self::build_reroute(target.into(), false)
    }

    /// Builds a reroute to `target` with `preserve_path` set to `true`.
    pub fn reroute_with_preserve(target: impl Into<String>) -> Self {
        Self::build_reroute(target.into(), true)
    }

    /// Shared constructor behind the two reroute helpers.
    fn build_reroute(target: String, preserve_path: bool) -> Self {
        AclAction::Reroute {
            target,
            preserve_path,
        }
    }

    /// `true` for `Allow` and `Log`; `false` for every other variant.
    pub fn is_allow(&self) -> bool {
        match self {
            AclAction::Allow | AclAction::Log { .. } => true,
            AclAction::Deny
            | AclAction::Error { .. }
            | AclAction::Reroute { .. }
            | AclAction::RateLimit { .. } => false,
        }
    }

    /// `true` for `Deny` and `Error`; `false` for every other variant.
    ///
    /// Note that `Reroute` and `RateLimit` are neither "allow" nor "deny" by these predicates.
    pub fn is_deny(&self) -> bool {
        match self {
            AclAction::Deny | AclAction::Error { .. } => true,
            AclAction::Allow
            | AclAction::Log { .. }
            | AclAction::Reroute { .. }
            | AclAction::RateLimit { .. } => false,
        }
    }
}
117
118#[derive(Debug, Clone)]
123pub struct AclRuleFilter {
124 pub id: String,
126 pub role_mask: u32,
128 pub time: TimeWindow,
130 pub ip: IpMatcher,
132 pub action: AclAction,
134 pub description: Option<String>,
136}
137
138impl AclRuleFilter {
139 pub fn new() -> Self {
141 Self {
142 id: "*".to_string(),
143 role_mask: u32::MAX, time: TimeWindow::default(),
145 ip: IpMatcher::Any,
146 action: AclAction::Allow,
147 description: None,
148 }
149 }
150
151 pub fn id(mut self, id: impl Into<String>) -> Self {
153 self.id = id.into();
154 self
155 }
156
157 pub fn role_mask(mut self, mask: u32) -> Self {
159 self.role_mask = mask;
160 self
161 }
162
163 pub fn role(mut self, role_id: u8) -> Self {
165 self.role_mask = 1 << role_id;
166 self
167 }
168
169 pub fn add_role(mut self, role_id: u8) -> Self {
171 self.role_mask |= 1 << role_id;
172 self
173 }
174
175 pub fn time(mut self, window: TimeWindow) -> Self {
177 self.time = window;
178 self
179 }
180
181 pub fn ip(mut self, matcher: IpMatcher) -> Self {
183 self.ip = matcher;
184 self
185 }
186
187 pub fn action(mut self, action: AclAction) -> Self {
189 self.action = action;
190 self
191 }
192
193 pub fn description(mut self, desc: impl Into<String>) -> Self {
195 self.description = Some(desc.into());
196 self
197 }
198
199 #[inline]
203 pub fn matches(&self, ctx: &RequestContext) -> bool {
204 (self.id == "*" || self.id == ctx.id)
206 && (self.role_mask & ctx.roles) != 0
208 && self.ip.matches(&ctx.ip)
210 && self.time.matches_now()
212 }
213}
214
215impl Default for AclRuleFilter {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
221#[derive(Debug, Clone, Default)]
226pub struct TimeWindow {
227 pub start: Option<NaiveTime>,
229 pub end: Option<NaiveTime>,
231 pub days: Vec<u32>,
234}
235
236impl TimeWindow {
237 pub fn any() -> Self {
239 Self::default()
240 }
241
242 pub fn hours(start_hour: u32, end_hour: u32) -> Self {
252 Self {
253 start: Some(NaiveTime::from_hms_opt(start_hour, 0, 0).unwrap_or_default()),
254 end: Some(NaiveTime::from_hms_opt(end_hour, 0, 0).unwrap_or_default()),
255 days: Vec::new(),
256 }
257 }
258
259 pub fn hours_on_days(start_hour: u32, end_hour: u32, days: Vec<u32>) -> Self {
274 Self {
275 start: Some(NaiveTime::from_hms_opt(start_hour, 0, 0).unwrap_or_default()),
276 end: Some(NaiveTime::from_hms_opt(end_hour, 0, 0).unwrap_or_default()),
277 days,
278 }
279 }
280
281 pub fn matches_now(&self) -> bool {
283 let now = Utc::now();
284 let current_time = now.time();
285 let current_day = now.weekday().num_days_from_monday();
286
287 if !self.days.is_empty() && !self.days.contains(¤t_day) {
289 return false;
290 }
291
292 match (&self.start, &self.end) {
294 (Some(start), Some(end)) => {
295 if start <= end {
296 current_time >= *start && current_time <= *end
298 } else {
299 current_time >= *start || current_time <= *end
301 }
302 }
303 (Some(start), None) => current_time >= *start,
304 (None, Some(end)) => current_time <= *end,
305 (None, None) => true,
306 }
307 }
308}
309
310#[derive(Debug, Clone, Default)]
312pub enum IpMatcher {
313 #[default]
315 Any,
316 Single(IpAddr),
318 Network(IpNetwork),
320 List(Vec<IpMatcher>),
322}
323
324impl IpMatcher {
325 pub fn any() -> Self {
327 Self::Any
328 }
329
330 pub fn single(ip: IpAddr) -> Self {
340 Self::Single(ip)
341 }
342
343 pub fn cidr(network: IpNetwork) -> Self {
352 Self::Network(network)
353 }
354
355 pub fn parse(s: &str) -> Result<Self, String> {
371 let s = s.trim();
372 if s == "*" || s.eq_ignore_ascii_case("any") {
373 return Ok(Self::Any);
374 }
375
376 if s.contains('/') {
378 return s
379 .parse::<IpNetwork>()
380 .map(Self::Network)
381 .map_err(|e| format!("Invalid CIDR: {}", e));
382 }
383
384 s.parse::<IpAddr>()
386 .map(Self::Single)
387 .map_err(|e| format!("Invalid IP address: {}", e))
388 }
389
390 pub fn matches(&self, ip: &IpAddr) -> bool {
392 match self {
393 Self::Any => true,
394 Self::Single(addr) => addr == ip,
395 Self::Network(network) => network.contains(*ip),
396 Self::List(matchers) => matchers.iter().any(|m| m.matches(ip)),
397 }
398 }
399}
400
/// Request-path condition of an ACL rule.
///
/// Glob patterns are matched segment by segment. Segments are `/`-separated, and empty
/// segments are ignored, so leading and trailing slashes do not matter. Within a glob:
/// - `*` matches exactly one segment;
/// - `**` matches zero or more segments;
/// - `{id}` matches one segment and can be bound to a caller id;
/// - any other `{name}` segment matches one segment.
#[derive(Debug, Clone, Default)]
pub enum EndpointPattern {
    /// Matches every path. This is the default.
    #[default]
    Any,
    /// Matches only the identical path string.
    Exact(String),
    /// Matches any path starting with the given string.
    Prefix(String),
    /// Segment-wise glob pattern (see the type-level docs).
    Glob(String),
}

impl EndpointPattern {
    /// Pattern matching every path.
    pub fn any() -> Self {
        Self::Any
    }

    /// Pattern matching exactly `path`.
    pub fn exact(path: impl Into<String>) -> Self {
        Self::Exact(path.into())
    }

    /// Pattern matching every path that starts with `path`.
    pub fn prefix(path: impl Into<String>) -> Self {
        Self::Prefix(path.into())
    }

    /// Segment-wise glob pattern (`*`, `**`, `{id}`, `{name}`).
    pub fn glob(pattern: impl Into<String>) -> Self {
        Self::Glob(pattern.into())
    }

    /// Infers the pattern kind from its text (after trimming):
    /// - `"*"` or `"any"` (case-insensitive) gives `Any`;
    /// - text containing `*` or a `{placeholder}` segment gives `Glob`;
    /// - text ending in `/` gives `Prefix`;
    /// - anything else gives `Exact`.
    pub fn parse(s: &str) -> Self {
        let s = s.trim();
        if s == "*" || s.eq_ignore_ascii_case("any") {
            return Self::Any;
        }

        // Placeholders only mean something to the glob matcher. As `Exact`/`Prefix`,
        // "/users/{id}" would only ever match that literal text.
        if s.contains('*') || s.split('/').any(Self::is_placeholder) {
            return Self::Glob(s.to_string());
        }

        if s.ends_with('/') {
            return Self::Prefix(s.to_string());
        }

        Self::Exact(s.to_string())
    }

    /// Returns `true` if `path` matches. In globs, `{id}` acts as a one-segment wildcard.
    pub fn matches(&self, path: &str) -> bool {
        self.matches_with_id(path, None)
    }

    /// Like [`EndpointPattern::matches`], but when `user_id` is `Some`, every `{id}` segment of
    /// a glob must equal it exactly. Non-glob patterns ignore `user_id`.
    pub fn matches_with_id(&self, path: &str, user_id: Option<&str>) -> bool {
        match self {
            Self::Any => true,
            Self::Exact(p) => p == path,
            Self::Prefix(prefix) => path.starts_with(prefix),
            Self::Glob(pattern) => Self::glob_matches_with_id(pattern, path, user_id),
        }
    }

    fn glob_matches_with_id(pattern: &str, path: &str, user_id: Option<&str>) -> bool {
        let pattern_parts = Self::segments(pattern);
        let path_parts = Self::segments(path);
        Self::glob_match_parts_with_id(&pattern_parts, &path_parts, user_id)
    }

    /// Recursive segment matcher behind [`EndpointPattern::matches_with_id`].
    fn glob_match_parts_with_id(pattern: &[&str], path: &[&str], user_id: Option<&str>) -> bool {
        let (first_pattern, rest_pattern) = match pattern.split_first() {
            Some((&first, rest)) => (first, rest),
            None => return path.is_empty(),
        };

        if first_pattern == "**" {
            // A trailing `**` swallows the rest; otherwise try every possible number of
            // skipped segments.
            return rest_pattern.is_empty()
                || (0..=path.len()).any(|skip| {
                    Self::glob_match_parts_with_id(rest_pattern, &path[skip..], user_id)
                });
        }

        let (first_path, rest_path) = match path.split_first() {
            Some((&first, rest)) => (first, rest),
            None => return false,
        };

        let segment_matches = if first_pattern == "{id}" {
            // Bound placeholder: must equal the caller id when one is supplied.
            match user_id {
                Some(id) => first_path == id,
                None => true,
            }
        } else {
            first_pattern == "*"
                || Self::is_placeholder(first_pattern)
                || first_pattern == first_path
        };

        segment_matches && Self::glob_match_parts_with_id(rest_pattern, rest_path, user_id)
    }

    /// Returns the path segment bound to `{id}` when `path` matches this glob pattern.
    ///
    /// Returns `None` in three cases:
    /// - the pattern is not a glob;
    /// - the pattern has no `{id}`;
    /// - the pattern does not match `path` (same rules as [`EndpointPattern::matches`]).
    ///
    /// If `{id}` occurs more than once, the first occurrence wins.
    pub fn extract_id(&self, path: &str) -> Option<String> {
        match self {
            Self::Glob(pattern) => Self::extract_id_from_glob(pattern, path),
            Self::Any | Self::Exact(_) | Self::Prefix(_) => None,
        }
    }

    fn extract_id_from_glob(pattern: &str, path: &str) -> Option<String> {
        let pattern_parts = Self::segments(pattern);
        let path_parts = Self::segments(path);
        Self::extract_id_from_parts(&pattern_parts, &path_parts)
    }

    fn extract_id_from_parts(pattern: &[&str], path: &[&str]) -> Option<String> {
        Self::capture_id(pattern, path).flatten().map(String::from)
    }

    /// Matches like [`Self::glob_match_parts_with_id`] with no caller id, while recording the
    /// segment bound to the first `{id}`. Returns one of:
    /// - `None` when there is no match;
    /// - `Some(None)` when the path matches but binds no `{id}`;
    /// - `Some(Some(segment))` when the path matches and `{id}` is bound to `segment`.
    ///
    /// Recursing (rather than indexing pattern and path in lockstep) keeps later segments
    /// aligned after a `**`.
    fn capture_id<'p>(pattern: &[&str], path: &[&'p str]) -> Option<Option<&'p str>> {
        let (first_pattern, rest_pattern) = match pattern.split_first() {
            Some((&first, rest)) => (first, rest),
            None => return if path.is_empty() { Some(None) } else { None },
        };

        if first_pattern == "**" {
            return (0..=path.len())
                .find_map(|skip| Self::capture_id(rest_pattern, &path[skip..]));
        }

        let (&first_path, rest_path) = path.split_first()?;

        if first_pattern == "{id}" {
            // The remainder must still match; the first `{id}` binding wins.
            Self::capture_id(rest_pattern, rest_path).map(|_| Some(first_path))
        } else if first_pattern == "*"
            || Self::is_placeholder(first_pattern)
            || first_pattern == first_path
        {
            Self::capture_id(rest_pattern, rest_path)
        } else {
            None
        }
    }

    /// Splits a path or pattern into its non-empty `/`-separated segments.
    fn segments(s: &str) -> Vec<&str> {
        s.split('/').filter(|seg| !seg.is_empty()).collect()
    }

    /// `true` for a `{name}` placeholder segment.
    fn is_placeholder(segment: &str) -> bool {
        segment.starts_with('{') && segment.ends_with('}')
    }
}
605
// Unit tests for the matchers and rule filter defined above. Filter tests use
// `AclRuleFilter::new()`, whose default `TimeWindow` has no bounds, so they do not
// depend on the wall-clock time.
#[cfg(test)]
mod tests {
    use super::*;

    // A single-address matcher accepts only that exact address.
    #[test]
    fn test_ip_matcher_single() {
        let ip: IpAddr = "192.168.1.1".parse().unwrap();
        let matcher = IpMatcher::single(ip);
        assert!(matcher.matches(&ip));
        assert!(!matcher.matches(&"192.168.1.2".parse().unwrap()));
    }

    // A /24 network accepts addresses inside it (including the broadcast address) only.
    #[test]
    fn test_ip_matcher_cidr() {
        let matcher = IpMatcher::cidr("192.168.1.0/24".parse().unwrap());
        assert!(matcher.matches(&"192.168.1.1".parse().unwrap()));
        assert!(matcher.matches(&"192.168.1.255".parse().unwrap()));
        assert!(!matcher.matches(&"192.168.2.1".parse().unwrap()));
    }

    // Exact patterns compare whole strings, so a trailing slash is a different path.
    #[test]
    fn test_endpoint_exact() {
        let pattern = EndpointPattern::exact("/api/users");
        assert!(pattern.matches("/api/users"));
        assert!(!pattern.matches("/api/users/"));
        assert!(!pattern.matches("/api/users/1"));
    }

    // Prefix patterns are plain `starts_with` checks.
    #[test]
    fn test_endpoint_prefix() {
        let pattern = EndpointPattern::prefix("/api/");
        assert!(pattern.matches("/api/users"));
        assert!(pattern.matches("/api/users/1"));
        assert!(!pattern.matches("/admin/users"));
    }

    // `*` matches one segment; `**` matches any number of trailing segments.
    #[test]
    fn test_endpoint_glob() {
        let pattern = EndpointPattern::glob("/api/*/users");
        assert!(pattern.matches("/api/v1/users"));
        assert!(pattern.matches("/api/v2/users"));
        assert!(!pattern.matches("/api/v1/posts"));

        let pattern = EndpointPattern::glob("/api/**");
        assert!(pattern.matches("/api/users"));
        assert!(pattern.matches("/api/v1/users/1"));
    }

    // `{id}` is a wildcard without a caller id and an equality check with one.
    #[test]
    fn test_endpoint_glob_with_id() {
        let pattern = EndpointPattern::glob("/api/boat/{id}/details");

        // No caller id supplied: any segment is accepted in the `{id}` position.
        assert!(pattern.matches("/api/boat/boat-123/details"));
        assert!(pattern.matches("/api/boat/anything/details"));

        // Caller id equals the `{id}` segment.
        assert!(pattern.matches_with_id("/api/boat/boat-123/details", Some("boat-123")));

        // Caller id differs from the `{id}` segment.
        assert!(!pattern.matches_with_id("/api/boat/boat-456/details", Some("boat-123")));

        // `{id}` combined with a trailing `**`.
        let pattern = EndpointPattern::glob("/api/user/{id}/**");
        assert!(pattern.matches_with_id("/api/user/user-1/profile", Some("user-1")));
        assert!(pattern.matches_with_id("/api/user/user-1/boats/123", Some("user-1")));
        assert!(!pattern.matches_with_id("/api/user/user-2/profile", Some("user-1")));
    }

    // `extract_id` returns the segment in the `{id}` position, or `None` when absent.
    #[test]
    fn test_extract_id_from_path() {
        let pattern = EndpointPattern::glob("/api/boat/{id}/details");
        assert_eq!(pattern.extract_id("/api/boat/boat-123/details"), Some("boat-123".to_string()));
        assert_eq!(pattern.extract_id("/api/boat/xyz/details"), Some("xyz".to_string()));
        assert_eq!(pattern.extract_id("/api/wrong/path"), None);

        let pattern = EndpointPattern::glob("/users/{id}");
        assert_eq!(pattern.extract_id("/users/123"), Some("123".to_string()));
        assert_eq!(pattern.extract_id("/users/"), None);
    }

    // Role matching succeeds when the caller's role bits intersect the rule's mask.
    #[test]
    fn test_filter_matches() {
        let filter = AclRuleFilter::new()
            .role_mask(0b001)
            .ip(IpMatcher::any());

        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // Caller holds exactly the required role.
        let ctx = RequestContext::new(0b001, ip, "*");
        assert!(filter.matches(&ctx));

        // Caller holds only a different role.
        let ctx = RequestContext::new(0b010, ip, "*");
        assert!(!filter.matches(&ctx));

        // Caller holds the required role among others.
        let ctx = RequestContext::new(0b011, ip, "*");
        assert!(filter.matches(&ctx));
    }

    // A rule with a concrete id applies only to that caller.
    #[test]
    fn test_filter_id_match() {
        let filter = AclRuleFilter::new()
            .id("user123")
            .role_mask(u32::MAX);

        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // Same caller id.
        let ctx = RequestContext::new(0b1, ip, "user123");
        assert!(filter.matches(&ctx));

        // Different caller id.
        let ctx = RequestContext::new(0b1, ip, "user456");
        assert!(!filter.matches(&ctx));
    }

    // A rule id of `"*"` applies to every caller.
    #[test]
    fn test_filter_wildcard_id() {
        let filter = AclRuleFilter::new()
            .id("*")
            .role_mask(u32::MAX);

        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // Any caller id is accepted.
        let ctx = RequestContext::new(0b1, ip, "anyone");
        assert!(filter.matches(&ctx));
    }
}
737}