1use std::collections::HashMap;
4use std::sync::Arc;
5
6use arc_swap::ArcSwap;
7use tracing::debug;
8
9use crate::error::RulesError;
10#[cfg(feature = "geoip")]
11use crate::matcher::GeoipMatcher;
12use crate::matcher::{CidrMatcher, DomainMatcher, KeywordMatcher};
13use crate::rule::{Action, EngineRule, MatchContext, ParsedRule};
14
15#[derive(Debug)]
17pub enum MatchDecision<'a> {
18 Matched(&'a Action),
20 NeedIp,
22}
23
24#[derive(Debug)]
26struct CompiledRuleSet {
27 domain_matcher: DomainMatcher,
28 keyword_matcher: Option<KeywordMatcher>,
29 cidr_matcher: CidrMatcher,
30 dst_ports: Vec<u16>,
32 src_cidr_matcher: CidrMatcher,
34}
35
36impl CompiledRuleSet {
37 fn compile(rules: Vec<ParsedRule>) -> Self {
39 let mut domain_matcher = DomainMatcher::new();
40 let mut keywords = Vec::new();
41 let mut v4_cidrs = Vec::new();
42 let mut v6_cidrs = Vec::new();
43 let mut dst_ports = Vec::new();
44 let mut src_v4_cidrs = Vec::new();
45 let mut src_v6_cidrs = Vec::new();
46
47 for rule in rules {
48 match rule {
49 ParsedRule::Domain(d) => domain_matcher.add_exact(&d),
50 ParsedRule::DomainSuffix(d) => domain_matcher.add_suffix(&d),
51 ParsedRule::DomainKeyword(k) => keywords.push(k),
52 ParsedRule::IpCidr(net) => match net {
53 ipnet::IpNet::V4(v4) => v4_cidrs.push(v4),
54 ipnet::IpNet::V6(v6) => v6_cidrs.push(v6),
55 },
56 ParsedRule::DstPort(port) => dst_ports.push(port),
57 ParsedRule::SrcIpCidr(net) => match net {
58 ipnet::IpNet::V4(v4) => src_v4_cidrs.push(v4),
59 ipnet::IpNet::V6(v6) => src_v6_cidrs.push(v6),
60 },
61 }
62 }
63
64 Self {
65 domain_matcher,
66 keyword_matcher: KeywordMatcher::new(keywords),
67 cidr_matcher: CidrMatcher::new(v4_cidrs, v6_cidrs),
68 dst_ports,
69 src_cidr_matcher: CidrMatcher::new(src_v4_cidrs, src_v6_cidrs),
70 }
71 }
72
73 fn matches(&self, ctx: &MatchContext) -> bool {
75 if let Some(domain) = ctx.domain {
77 if self.domain_matcher.matches(domain) {
78 return true;
79 }
80 if let Some(ref kw) = self.keyword_matcher
81 && kw.matches(domain)
82 {
83 return true;
84 }
85 }
86
87 if let Some(ip) = ctx.dest_ip
89 && self.cidr_matcher.contains(ip)
90 {
91 return true;
92 }
93
94 if !self.dst_ports.is_empty() && self.dst_ports.contains(&ctx.dest_port) {
96 return true;
97 }
98
99 if self.src_cidr_matcher.contains(ctx.src_ip) {
101 return true;
102 }
103
104 false
105 }
106}
107
108pub struct RuleEngine {
112 compiled_sets: HashMap<String, CompiledRuleSet>,
113 rules: Vec<EngineRule>,
114 final_action: Action,
115 #[cfg(feature = "geoip")]
116 geoip: Option<Arc<GeoipMatcher>>,
117}
118
119impl RuleEngine {
120 pub fn match_request(&self, ctx: &MatchContext) -> &Action {
125 for rule in &self.rules {
126 match rule {
127 EngineRule::RuleSet { name, action } => {
128 if let Some(compiled) = self.compiled_sets.get(name)
129 && compiled.matches(ctx)
130 {
131 return action;
132 }
133 }
134 EngineRule::GeoIp { code, action } => {
135 #[cfg(feature = "geoip")]
136 if let Some(ref geoip) = self.geoip {
137 if let Some(ip) = ctx.dest_ip
141 && geoip.matches(ip, code)
142 {
143 return action;
144 }
145 }
146 #[cfg(not(feature = "geoip"))]
147 {
148 let _ = (code, action);
149 }
151 }
152 EngineRule::Inline { rule, action } => {
153 if inline_matches(rule, ctx) {
154 return action;
155 }
156 }
157 EngineRule::Final { action } => {
158 return action;
159 }
160 }
161 }
162 &self.final_action
163 }
164
165 pub fn match_request_lazy_ip(&self, ctx: &MatchContext) -> MatchDecision<'_> {
168 for rule in &self.rules {
169 match rule {
170 EngineRule::RuleSet { name, action } => {
171 if let Some(compiled) = self.compiled_sets.get(name) {
172 if let Some(domain) = ctx.domain
174 && (compiled.domain_matcher.matches(domain)
175 || compiled
176 .keyword_matcher
177 .as_ref()
178 .is_some_and(|kw| kw.matches(domain)))
179 {
180 return MatchDecision::Matched(action);
181 }
182
183 if !compiled.dst_ports.is_empty()
185 && compiled.dst_ports.contains(&ctx.dest_port)
186 {
187 return MatchDecision::Matched(action);
188 }
189 if compiled.src_cidr_matcher.contains(ctx.src_ip) {
190 return MatchDecision::Matched(action);
191 }
192
193 if ctx.dest_ip.is_none() && !compiled.cidr_matcher.is_empty() {
195 return MatchDecision::NeedIp;
196 }
197 if let Some(ip) = ctx.dest_ip
198 && compiled.cidr_matcher.contains(ip)
199 {
200 return MatchDecision::Matched(action);
201 }
202 }
203 }
204 EngineRule::GeoIp { code, action } => {
205 #[cfg(feature = "geoip")]
206 if let Some(ref geoip) = self.geoip {
207 if ctx.dest_ip.is_none() {
208 return MatchDecision::NeedIp;
209 }
210 if let Some(ip) = ctx.dest_ip
211 && geoip.matches(ip, code)
212 {
213 return MatchDecision::Matched(action);
214 }
215 }
216 #[cfg(not(feature = "geoip"))]
217 {
218 let _ = (code, action);
219 }
221 }
222 EngineRule::Inline { rule, action } => match rule {
223 ParsedRule::Domain(d) => {
224 if ctx
225 .domain
226 .is_some_and(|domain| domain.eq_ignore_ascii_case(d))
227 {
228 return MatchDecision::Matched(action);
229 }
230 }
231 ParsedRule::DomainSuffix(s) => {
232 if ctx.domain.is_some_and(|domain| {
233 let lower = domain.to_ascii_lowercase();
234 let suffix_lower = s.to_ascii_lowercase();
235 lower == suffix_lower || lower.ends_with(&format!(".{suffix_lower}"))
236 }) {
237 return MatchDecision::Matched(action);
238 }
239 }
240 ParsedRule::DomainKeyword(k) => {
241 if ctx.domain.is_some_and(|domain| {
242 domain
243 .to_ascii_lowercase()
244 .contains(&k.to_ascii_lowercase())
245 }) {
246 return MatchDecision::Matched(action);
247 }
248 }
249 ParsedRule::IpCidr(net) => {
250 if ctx.dest_ip.is_none() {
251 return MatchDecision::NeedIp;
252 }
253 if ctx.dest_ip.is_some_and(|ip| net.contains(&ip)) {
254 return MatchDecision::Matched(action);
255 }
256 }
257 ParsedRule::DstPort(port) => {
258 if ctx.dest_port == *port {
259 return MatchDecision::Matched(action);
260 }
261 }
262 ParsedRule::SrcIpCidr(net) => {
263 if net.contains(&ctx.src_ip) {
264 return MatchDecision::Matched(action);
265 }
266 }
267 },
268 EngineRule::Final { action } => {
269 return MatchDecision::Matched(action);
270 }
271 }
272 }
273 MatchDecision::Matched(&self.final_action)
274 }
275
276 pub fn has_ip_rules(&self) -> bool {
278 let has_cidr = self
279 .compiled_sets
280 .values()
281 .any(|cs| !cs.cidr_matcher.is_empty());
282 let has_geoip = {
283 #[cfg(feature = "geoip")]
284 {
285 self.geoip.is_some()
286 && self
287 .rules
288 .iter()
289 .any(|r| matches!(r, EngineRule::GeoIp { .. }))
290 }
291 #[cfg(not(feature = "geoip"))]
292 {
293 false
294 }
295 };
296 has_cidr || has_geoip
297 }
298
299 pub fn rule_set_count(&self) -> usize {
301 self.compiled_sets.len()
302 }
303
304 pub fn rule_count(&self) -> usize {
306 self.rules.len()
307 }
308}
309
310fn inline_matches(rule: &ParsedRule, ctx: &MatchContext) -> bool {
312 match rule {
313 ParsedRule::Domain(d) => ctx
314 .domain
315 .is_some_and(|domain| domain.eq_ignore_ascii_case(d)),
316 ParsedRule::DomainSuffix(s) => ctx.domain.is_some_and(|domain| {
317 let lower = domain.to_ascii_lowercase();
318 let suffix_lower = s.to_ascii_lowercase();
319 lower == suffix_lower || lower.ends_with(&format!(".{suffix_lower}"))
320 }),
321 ParsedRule::DomainKeyword(k) => ctx.domain.is_some_and(|domain| {
322 domain
323 .to_ascii_lowercase()
324 .contains(&k.to_ascii_lowercase())
325 }),
326 ParsedRule::IpCidr(net) => ctx.dest_ip.is_some_and(|ip| net.contains(&ip)),
327 ParsedRule::DstPort(port) => ctx.dest_port == *port,
328 ParsedRule::SrcIpCidr(net) => net.contains(&ctx.src_ip),
329 }
330}
331
332impl std::fmt::Debug for RuleEngine {
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 f.debug_struct("RuleEngine")
335 .field("rule_sets", &self.compiled_sets.len())
336 .field("rules", &self.rules.len())
337 .field("final_action", &self.final_action)
338 .finish()
339 }
340}
341
342#[derive(Debug)]
346pub struct RuleEngineBuilder {
347 rule_sets: HashMap<String, Vec<ParsedRule>>,
348 rules: Vec<EngineRule>,
349 final_action: Option<Action>,
350 #[cfg(feature = "geoip")]
351 geoip: Option<Arc<GeoipMatcher>>,
352}
353
354impl RuleEngineBuilder {
355 pub fn new() -> Self {
357 Self {
358 rule_sets: HashMap::new(),
359 rules: Vec::new(),
360 final_action: None,
361 #[cfg(feature = "geoip")]
362 geoip: None,
363 }
364 }
365
366 #[cfg(feature = "geoip")]
368 pub fn set_geoip(&mut self, matcher: Arc<GeoipMatcher>) -> &mut Self {
369 self.geoip = Some(matcher);
370 self
371 }
372
373 pub fn add_rule_set(&mut self, name: impl Into<String>, rules: Vec<ParsedRule>) -> &mut Self {
375 self.rule_sets.insert(name.into(), rules);
376 self
377 }
378
379 pub fn add_rule_set_rule(
381 &mut self,
382 rule_set_name: impl Into<String>,
383 action: Action,
384 ) -> &mut Self {
385 self.rules.push(EngineRule::RuleSet {
386 name: rule_set_name.into(),
387 action,
388 });
389 self
390 }
391
392 pub fn add_geoip_rule(&mut self, code: impl Into<String>, action: Action) -> &mut Self {
394 self.rules.push(EngineRule::GeoIp {
395 code: code.into(),
396 action,
397 });
398 self
399 }
400
401 pub fn add_inline_rule(&mut self, rule: ParsedRule, action: Action) -> &mut Self {
403 self.rules.push(EngineRule::Inline { rule, action });
404 self
405 }
406
407 pub fn set_final(&mut self, action: Action) -> &mut Self {
409 self.final_action = Some(action.clone());
410 self.rules.push(EngineRule::Final { action });
411 self
412 }
413
414 pub fn build(self) -> Result<RuleEngine, RulesError> {
416 let final_action = self.final_action.ok_or(RulesError::NoFinalRule)?;
417
418 for rule in &self.rules {
420 if let EngineRule::RuleSet { name, .. } = rule
421 && !self.rule_sets.contains_key(name)
422 {
423 return Err(RulesError::UnknownRuleSet(name.clone()));
424 }
425 }
426
427 {
429 let has_geoip_rules = self
430 .rules
431 .iter()
432 .any(|r| matches!(r, EngineRule::GeoIp { .. }));
433 if has_geoip_rules {
434 #[cfg(feature = "geoip")]
435 if self.geoip.is_none() {
436 tracing::warn!(
437 "GEOIP rules are configured but no GeoIP database is loaded; they will never match"
438 );
439 }
440 #[cfg(not(feature = "geoip"))]
441 tracing::warn!(
442 "GEOIP rules are configured but the 'geoip' feature is not enabled; they will never match"
443 );
444 }
445 }
446
447 let compiled_sets: HashMap<String, CompiledRuleSet> = self
449 .rule_sets
450 .into_iter()
451 .map(|(name, rules)| {
452 let count = rules.len();
453 let compiled = CompiledRuleSet::compile(rules);
454 debug!(
455 name = %name,
456 rules = count,
457 domains = compiled.domain_matcher.len(),
458 keywords = compiled.keyword_matcher.as_ref().map_or(0, |k| k.len()),
459 cidrs = compiled.cidr_matcher.len(),
460 "compiled rule-set"
461 );
462 (name, compiled)
463 })
464 .collect();
465
466 Ok(RuleEngine {
467 compiled_sets,
468 rules: self.rules,
469 final_action,
470 #[cfg(feature = "geoip")]
471 geoip: self.geoip,
472 })
473 }
474}
475
476impl Default for RuleEngineBuilder {
477 fn default() -> Self {
478 Self::new()
479 }
480}
481
482pub struct HotRuleEngine {
489 inner: ArcSwap<RuleEngine>,
490}
491
492impl HotRuleEngine {
493 pub fn new(engine: RuleEngine) -> Self {
495 Self {
496 inner: ArcSwap::new(Arc::new(engine)),
497 }
498 }
499
500 pub fn match_request(&self, ctx: &MatchContext) -> Action {
505 let engine = self.inner.load();
506 engine.match_request(ctx).clone()
507 }
508
509 pub fn has_ip_rules(&self) -> bool {
511 self.inner.load().has_ip_rules()
512 }
513
514 pub fn update(&self, engine: RuleEngine) {
516 self.inner.store(Arc::new(engine));
517 }
518
519 pub fn match_request_lazy_ip(&self, ctx: &MatchContext) -> Option<Action> {
525 let engine = self.inner.load();
526 match engine.match_request_lazy_ip(ctx) {
527 MatchDecision::Matched(action) => Some(action.clone()),
528 MatchDecision::NeedIp => None,
529 }
530 }
531
532 pub fn rule_set_count(&self) -> usize {
534 self.inner.load().rule_set_count()
535 }
536
537 pub fn rule_count(&self) -> usize {
539 self.inner.load().rule_count()
540 }
541}
542
543impl std::fmt::Debug for HotRuleEngine {
544 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
545 f.debug_struct("HotRuleEngine")
546 .field("inner", &*self.inner.load())
547 .finish()
548 }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Request for `domain` with an explicit port and source address and no
    /// resolved destination IP.
    fn ctx_with(domain: &str, dest_port: u16, src_ip: Ipv4Addr) -> MatchContext<'_> {
        MatchContext {
            domain: Some(domain),
            dest_ip: None,
            dest_port,
            src_ip: IpAddr::V4(src_ip),
        }
    }

    /// Request for `domain` on port 443 from localhost, destination IP unknown.
    fn ctx_domain(domain: &str) -> MatchContext<'_> {
        ctx_with(domain, 443, Ipv4Addr::LOCALHOST)
    }

    /// Request for a bare destination `ip` on port 443 from localhost.
    fn ctx_ip(ip: IpAddr) -> MatchContext<'static> {
        MatchContext {
            domain: None,
            dest_ip: Some(ip),
            dest_port: 443,
            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
        }
    }

    #[test]
    fn rule_set_domain_match() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_rule_set(
            "ads",
            vec![
                ParsedRule::DomainSuffix("ad.example.com".into()),
                ParsedRule::DomainKeyword("ads".into()),
            ],
        );
        builder.add_rule_set_rule("ads", Action::Reject);
        builder.set_final(Action::Direct);

        let engine = builder.build().unwrap();

        assert_eq!(
            engine.match_request(&ctx_domain("tracker.ad.example.com")),
            &Action::Reject
        );
        assert_eq!(
            engine.match_request(&ctx_domain("someads.com")),
            &Action::Reject
        );
        assert_eq!(
            engine.match_request(&ctx_domain("clean.example.com")),
            &Action::Direct
        );
    }

    #[test]
    fn rule_set_ip_match() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_rule_set(
            "private",
            vec![ParsedRule::IpCidr("192.168.0.0/16".parse().unwrap())],
        );
        builder.add_rule_set_rule("private", Action::Outbound("vpn".into()));
        builder.set_final(Action::Direct);

        let engine = builder.build().unwrap();

        assert_eq!(
            engine.match_request(&ctx_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)))),
            &Action::Outbound("vpn".into())
        );
        assert_eq!(
            engine.match_request(&ctx_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))),
            &Action::Direct
        );
    }

    #[test]
    fn rule_order_matters() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_rule_set("block", vec![ParsedRule::Domain("example.com".into())]);
        builder.add_rule_set(
            "allow",
            vec![ParsedRule::DomainSuffix("example.com".into())],
        );
        builder.add_rule_set_rule("block", Action::Reject);
        builder.add_rule_set_rule("allow", Action::Direct);
        builder.set_final(Action::Outbound("proxy".into()));

        let engine = builder.build().unwrap();

        // Both sets match the apex domain; the earlier rule wins.
        assert_eq!(
            engine.match_request(&ctx_domain("example.com")),
            &Action::Reject
        );
        // Only the suffix set matches a subdomain.
        assert_eq!(
            engine.match_request(&ctx_domain("sub.example.com")),
            &Action::Direct
        );
    }

    #[test]
    fn final_action_catch_all() {
        let mut builder = RuleEngineBuilder::new();
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        assert_eq!(
            engine.match_request(&ctx_domain("anything.com")),
            &Action::Direct
        );
    }

    #[test]
    fn no_final_rule_error() {
        let builder = RuleEngineBuilder::new();
        assert!(matches!(builder.build(), Err(RulesError::NoFinalRule)));
    }

    #[test]
    fn unknown_rule_set_error() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_rule_set_rule("nonexistent", Action::Reject);
        builder.set_final(Action::Direct);
        assert!(matches!(
            builder.build(),
            Err(RulesError::UnknownRuleSet(name)) if name == "nonexistent"
        ));
    }

    #[test]
    fn inline_rule_match() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_inline_rule(ParsedRule::Domain("blocked.com".into()), Action::Reject);
        builder.add_inline_rule(
            ParsedRule::IpCidr("10.0.0.0/8".parse().unwrap()),
            Action::Outbound("internal".into()),
        );
        builder.set_final(Action::Direct);

        let engine = builder.build().unwrap();

        assert_eq!(
            engine.match_request(&ctx_domain("blocked.com")),
            &Action::Reject
        );
        assert_eq!(
            engine.match_request(&ctx_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))),
            &Action::Outbound("internal".into())
        );
        assert_eq!(
            engine.match_request(&ctx_domain("allowed.com")),
            &Action::Direct
        );
    }

    #[test]
    fn engine_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<RuleEngine>();
    }

    #[test]
    fn lazy_match_domain_before_ip_rule() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_inline_rule(ParsedRule::Domain("example.com".into()), Action::Reject);
        builder.add_inline_rule(
            ParsedRule::IpCidr("10.0.0.0/8".parse().unwrap()),
            Action::Direct,
        );
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        let ctx = ctx_domain("example.com");
        assert!(
            matches!(
                engine.match_request_lazy_ip(&ctx),
                MatchDecision::Matched(Action::Reject)
            ),
            "should not require IP for domain match"
        );
    }

    #[test]
    fn lazy_match_needs_ip_when_ip_rule_first() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_inline_rule(
            ParsedRule::IpCidr("10.0.0.0/8".parse().unwrap()),
            Action::Reject,
        );
        builder.add_inline_rule(ParsedRule::Domain("example.com".into()), Action::Direct);
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        let ctx = ctx_domain("example.com");
        assert!(
            matches!(engine.match_request_lazy_ip(&ctx), MatchDecision::NeedIp),
            "should require IP before evaluating later rules"
        );
    }

    #[test]
    fn geoip_skipped_when_dest_ip_none() {
        let mut builder = RuleEngineBuilder::new();
        // No GeoIP database is configured, so the rule can never match.
        builder.add_geoip_rule("CN", Action::Reject);
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        let ctx = ctx_domain("example.cn");
        assert_eq!(engine.match_request(&ctx), &Action::Direct);
    }

    #[test]
    fn inline_dst_port_match() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_inline_rule(ParsedRule::DstPort(80), Action::Reject);
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        let ctx = ctx_with("example.com", 80, Ipv4Addr::LOCALHOST);
        assert_eq!(engine.match_request(&ctx), &Action::Reject);

        let ctx_miss = ctx_with("example.com", 443, Ipv4Addr::LOCALHOST);
        assert_eq!(engine.match_request(&ctx_miss), &Action::Direct);
    }

    #[test]
    fn inline_src_ip_cidr_match() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_inline_rule(
            ParsedRule::SrcIpCidr("10.0.0.0/8".parse().unwrap()),
            Action::Reject,
        );
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        let ctx_match = ctx_with("example.com", 443, Ipv4Addr::new(10, 1, 2, 3));
        assert_eq!(engine.match_request(&ctx_match), &Action::Reject);

        let ctx_miss = ctx_with("example.com", 443, Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(engine.match_request(&ctx_miss), &Action::Direct);
    }

    #[test]
    fn rule_set_with_dst_port_and_src_cidr() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_rule_set(
            "mixed",
            vec![
                ParsedRule::DomainSuffix("example.com".into()),
                ParsedRule::DstPort(8080),
                ParsedRule::SrcIpCidr("172.16.0.0/12".parse().unwrap()),
            ],
        );
        builder.add_rule_set_rule("mixed", Action::Outbound("proxy".into()));
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        // Matches on destination port only.
        let ctx_port = ctx_with("other.com", 8080, Ipv4Addr::LOCALHOST);
        assert_eq!(
            engine.match_request(&ctx_port),
            &Action::Outbound("proxy".into())
        );

        // Matches on source address only.
        let ctx_src = ctx_with("other.com", 443, Ipv4Addr::new(172, 16, 5, 1));
        assert_eq!(
            engine.match_request(&ctx_src),
            &Action::Outbound("proxy".into())
        );

        // Nothing in the set matches.
        let ctx_miss = ctx_with("other.com", 443, Ipv4Addr::LOCALHOST);
        assert_eq!(engine.match_request(&ctx_miss), &Action::Direct);
    }

    #[test]
    fn lazy_match_dst_port_no_dns_needed() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_inline_rule(ParsedRule::DstPort(80), Action::Reject);
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        let ctx = ctx_with("example.com", 80, Ipv4Addr::LOCALHOST);
        assert!(
            matches!(
                engine.match_request_lazy_ip(&ctx),
                MatchDecision::Matched(Action::Reject)
            ),
            "DST-PORT should not require IP resolution"
        );
    }

    #[test]
    fn lazy_match_src_ip_cidr_no_dns_needed() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_inline_rule(
            ParsedRule::SrcIpCidr("10.0.0.0/8".parse().unwrap()),
            Action::Reject,
        );
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        let ctx = ctx_with("example.com", 443, Ipv4Addr::new(10, 0, 0, 1));
        assert!(
            matches!(
                engine.match_request_lazy_ip(&ctx),
                MatchDecision::Matched(Action::Reject)
            ),
            "SRC-IP-CIDR should not require IP resolution"
        );
    }

    #[test]
    fn lazy_match_rule_set_port_before_cidr() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_rule_set(
            "mixed",
            vec![
                ParsedRule::IpCidr("10.0.0.0/8".parse().unwrap()),
                ParsedRule::DstPort(8080),
            ],
        );
        builder.add_rule_set_rule("mixed", Action::Reject);
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        // The port rule decides the request before the CIDR rule can ask for an IP.
        let ctx = ctx_with("example.com", 8080, Ipv4Addr::LOCALHOST);
        assert!(
            matches!(
                engine.match_request_lazy_ip(&ctx),
                MatchDecision::Matched(Action::Reject)
            ),
            "DST-PORT in rule-set should match before CIDR triggers NeedIp"
        );

        // Nothing IP-independent matches, so the CIDR rule needs the address.
        let ctx_miss = ctx_with("other.com", 443, Ipv4Addr::LOCALHOST);
        assert!(
            matches!(engine.match_request_lazy_ip(&ctx_miss), MatchDecision::NeedIp),
            "should return NeedIp for unmatched port with CIDR"
        );
    }

    #[test]
    fn domain_suffix_leading_dot_in_rule_set() {
        let mut builder = RuleEngineBuilder::new();
        builder.add_rule_set(
            "test",
            vec![ParsedRule::DomainSuffix(".example.com".into())],
        );
        builder.add_rule_set_rule("test", Action::Reject);
        builder.set_final(Action::Direct);
        let engine = builder.build().unwrap();

        assert_eq!(
            engine.match_request(&ctx_domain("example.com")),
            &Action::Reject
        );
        assert_eq!(
            engine.match_request(&ctx_domain("sub.example.com")),
            &Action::Reject
        );
        assert_eq!(
            engine.match_request(&ctx_domain("other.com")),
            &Action::Direct
        );
    }

    #[test]
    fn hot_engine_update_swaps_rules() {
        let mut builder = RuleEngineBuilder::new();
        builder.set_final(Action::Direct);
        let hot = HotRuleEngine::new(builder.build().unwrap());
        assert_eq!(hot.match_request(&ctx_domain("example.com")), Action::Direct);

        let mut builder = RuleEngineBuilder::new();
        builder.set_final(Action::Reject);
        hot.update(builder.build().unwrap());
        assert_eq!(hot.match_request(&ctx_domain("example.com")), Action::Reject);
    }
}