Skip to main content

trojan_rules/
engine.rs

1//! Rule engine: compiles rule-sets and matches requests.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use arc_swap::ArcSwap;
7use tracing::debug;
8
9use crate::error::RulesError;
10#[cfg(feature = "geoip")]
11use crate::matcher::GeoipMatcher;
12use crate::matcher::{CidrMatcher, DomainMatcher, KeywordMatcher};
13use crate::rule::{Action, EngineRule, MatchContext, ParsedRule};
14
15/// Result of a lazy match attempt that may require DNS resolution.
16#[derive(Debug)]
17pub enum MatchDecision<'a> {
18    /// A rule matched without needing an IP address.
19    Matched(&'a Action),
20    /// An IP-based rule appeared before any match; resolve and retry.
21    NeedIp,
22}
23
24/// A compiled rule-set ready for matching.
25#[derive(Debug)]
26struct CompiledRuleSet {
27    domain_matcher: DomainMatcher,
28    keyword_matcher: Option<KeywordMatcher>,
29    cidr_matcher: CidrMatcher,
30    /// Destination ports to match.
31    dst_ports: Vec<u16>,
32    /// Source IP CIDR matcher.
33    src_cidr_matcher: CidrMatcher,
34}
35
36impl CompiledRuleSet {
37    /// Compile a list of parsed rules into optimized matchers.
38    fn compile(rules: Vec<ParsedRule>) -> Self {
39        let mut domain_matcher = DomainMatcher::new();
40        let mut keywords = Vec::new();
41        let mut v4_cidrs = Vec::new();
42        let mut v6_cidrs = Vec::new();
43        let mut dst_ports = Vec::new();
44        let mut src_v4_cidrs = Vec::new();
45        let mut src_v6_cidrs = Vec::new();
46
47        for rule in rules {
48            match rule {
49                ParsedRule::Domain(d) => domain_matcher.add_exact(&d),
50                ParsedRule::DomainSuffix(d) => domain_matcher.add_suffix(&d),
51                ParsedRule::DomainKeyword(k) => keywords.push(k),
52                ParsedRule::IpCidr(net) => match net {
53                    ipnet::IpNet::V4(v4) => v4_cidrs.push(v4),
54                    ipnet::IpNet::V6(v6) => v6_cidrs.push(v6),
55                },
56                ParsedRule::DstPort(port) => dst_ports.push(port),
57                ParsedRule::SrcIpCidr(net) => match net {
58                    ipnet::IpNet::V4(v4) => src_v4_cidrs.push(v4),
59                    ipnet::IpNet::V6(v6) => src_v6_cidrs.push(v6),
60                },
61            }
62        }
63
64        Self {
65            domain_matcher,
66            keyword_matcher: KeywordMatcher::new(keywords),
67            cidr_matcher: CidrMatcher::new(v4_cidrs, v6_cidrs),
68            dst_ports,
69            src_cidr_matcher: CidrMatcher::new(src_v4_cidrs, src_v6_cidrs),
70        }
71    }
72
73    /// Check if a request context matches this rule-set.
74    fn matches(&self, ctx: &MatchContext) -> bool {
75        // Try domain matches first
76        if let Some(domain) = ctx.domain {
77            if self.domain_matcher.matches(domain) {
78                return true;
79            }
80            if let Some(ref kw) = self.keyword_matcher
81                && kw.matches(domain)
82            {
83                return true;
84            }
85        }
86
87        // Try IP matches
88        if let Some(ip) = ctx.dest_ip
89            && self.cidr_matcher.contains(ip)
90        {
91            return true;
92        }
93
94        // Try port matches
95        if !self.dst_ports.is_empty() && self.dst_ports.contains(&ctx.dest_port) {
96            return true;
97        }
98
99        // Try source IP matches
100        if self.src_cidr_matcher.contains(ctx.src_ip) {
101            return true;
102        }
103
104        false
105    }
106}
107
108/// The rule engine: holds compiled rule-sets and an ordered list of rules.
109///
110/// Send + Sync, designed to be shared via `Arc<RuleEngine>`.
111pub struct RuleEngine {
112    compiled_sets: HashMap<String, CompiledRuleSet>,
113    rules: Vec<EngineRule>,
114    final_action: Action,
115    #[cfg(feature = "geoip")]
116    geoip: Option<Arc<GeoipMatcher>>,
117}
118
119impl RuleEngine {
120    /// Match a request against the rules and return the action to take.
121    ///
122    /// Rules are evaluated in order; the first match wins.
123    /// If no rule matches, the FINAL action is returned.
124    pub fn match_request(&self, ctx: &MatchContext) -> &Action {
125        for rule in &self.rules {
126            match rule {
127                EngineRule::RuleSet { name, action } => {
128                    if let Some(compiled) = self.compiled_sets.get(name)
129                        && compiled.matches(ctx)
130                    {
131                        return action;
132                    }
133                }
134                EngineRule::GeoIp { code, action } => {
135                    #[cfg(feature = "geoip")]
136                    if let Some(ref geoip) = self.geoip {
137                        // Only match on dest_ip; skip when no resolved IP is
138                        // available.  Falling back to src_ip would mis-route
139                        // domain requests based on the *client's* country.
140                        if let Some(ip) = ctx.dest_ip
141                            && geoip.matches(ip, code)
142                        {
143                            return action;
144                        }
145                    }
146                    #[cfg(not(feature = "geoip"))]
147                    {
148                        let _ = (code, action);
149                        // GEOIP matching requires the "geoip" feature. Skip.
150                    }
151                }
152                EngineRule::Inline { rule, action } => {
153                    if inline_matches(rule, ctx) {
154                        return action;
155                    }
156                }
157                EngineRule::Final { action } => {
158                    return action;
159                }
160            }
161        }
162        &self.final_action
163    }
164
165    /// Try to match without an IP; if an IP-based rule appears before any
166    /// match, returns `NeedIp` so callers can resolve and retry.
167    pub fn match_request_lazy_ip(&self, ctx: &MatchContext) -> MatchDecision<'_> {
168        for rule in &self.rules {
169            match rule {
170                EngineRule::RuleSet { name, action } => {
171                    if let Some(compiled) = self.compiled_sets.get(name) {
172                        // Domain matchers don't need IP.
173                        if let Some(domain) = ctx.domain
174                            && (compiled.domain_matcher.matches(domain)
175                                || compiled
176                                    .keyword_matcher
177                                    .as_ref()
178                                    .is_some_and(|kw| kw.matches(domain)))
179                        {
180                            return MatchDecision::Matched(action);
181                        }
182
183                        // DST-PORT and SRC-IP-CIDR don't need DNS.
184                        if !compiled.dst_ports.is_empty()
185                            && compiled.dst_ports.contains(&ctx.dest_port)
186                        {
187                            return MatchDecision::Matched(action);
188                        }
189                        if compiled.src_cidr_matcher.contains(ctx.src_ip) {
190                            return MatchDecision::Matched(action);
191                        }
192
193                        // Dest-IP CIDR requires resolved IP.
194                        if ctx.dest_ip.is_none() && !compiled.cidr_matcher.is_empty() {
195                            return MatchDecision::NeedIp;
196                        }
197                        if let Some(ip) = ctx.dest_ip
198                            && compiled.cidr_matcher.contains(ip)
199                        {
200                            return MatchDecision::Matched(action);
201                        }
202                    }
203                }
204                EngineRule::GeoIp { code, action } => {
205                    #[cfg(feature = "geoip")]
206                    if let Some(ref geoip) = self.geoip {
207                        if ctx.dest_ip.is_none() {
208                            return MatchDecision::NeedIp;
209                        }
210                        if let Some(ip) = ctx.dest_ip
211                            && geoip.matches(ip, code)
212                        {
213                            return MatchDecision::Matched(action);
214                        }
215                    }
216                    #[cfg(not(feature = "geoip"))]
217                    {
218                        let _ = (code, action);
219                        // GEOIP matching requires the "geoip" feature. Skip.
220                    }
221                }
222                EngineRule::Inline { rule, action } => match rule {
223                    ParsedRule::Domain(d) => {
224                        if ctx
225                            .domain
226                            .is_some_and(|domain| domain.eq_ignore_ascii_case(d))
227                        {
228                            return MatchDecision::Matched(action);
229                        }
230                    }
231                    ParsedRule::DomainSuffix(s) => {
232                        if ctx.domain.is_some_and(|domain| {
233                            let lower = domain.to_ascii_lowercase();
234                            let suffix_lower = s.to_ascii_lowercase();
235                            lower == suffix_lower || lower.ends_with(&format!(".{suffix_lower}"))
236                        }) {
237                            return MatchDecision::Matched(action);
238                        }
239                    }
240                    ParsedRule::DomainKeyword(k) => {
241                        if ctx.domain.is_some_and(|domain| {
242                            domain
243                                .to_ascii_lowercase()
244                                .contains(&k.to_ascii_lowercase())
245                        }) {
246                            return MatchDecision::Matched(action);
247                        }
248                    }
249                    ParsedRule::IpCidr(net) => {
250                        if ctx.dest_ip.is_none() {
251                            return MatchDecision::NeedIp;
252                        }
253                        if ctx.dest_ip.is_some_and(|ip| net.contains(&ip)) {
254                            return MatchDecision::Matched(action);
255                        }
256                    }
257                    ParsedRule::DstPort(port) => {
258                        if ctx.dest_port == *port {
259                            return MatchDecision::Matched(action);
260                        }
261                    }
262                    ParsedRule::SrcIpCidr(net) => {
263                        if net.contains(&ctx.src_ip) {
264                            return MatchDecision::Matched(action);
265                        }
266                    }
267                },
268                EngineRule::Final { action } => {
269                    return MatchDecision::Matched(action);
270                }
271            }
272        }
273        MatchDecision::Matched(&self.final_action)
274    }
275
276    /// Returns true if this engine has IP-based rules that may require DNS resolution.
277    pub fn has_ip_rules(&self) -> bool {
278        let has_cidr = self
279            .compiled_sets
280            .values()
281            .any(|cs| !cs.cidr_matcher.is_empty());
282        let has_geoip = {
283            #[cfg(feature = "geoip")]
284            {
285                self.geoip.is_some()
286                    && self
287                        .rules
288                        .iter()
289                        .any(|r| matches!(r, EngineRule::GeoIp { .. }))
290            }
291            #[cfg(not(feature = "geoip"))]
292            {
293                false
294            }
295        };
296        has_cidr || has_geoip
297    }
298
299    /// Number of compiled rule-sets.
300    pub fn rule_set_count(&self) -> usize {
301        self.compiled_sets.len()
302    }
303
304    /// Number of engine rules (including FINAL).
305    pub fn rule_count(&self) -> usize {
306        self.rules.len()
307    }
308}
309
310/// Check if a single inline rule matches the context.
311fn inline_matches(rule: &ParsedRule, ctx: &MatchContext) -> bool {
312    match rule {
313        ParsedRule::Domain(d) => ctx
314            .domain
315            .is_some_and(|domain| domain.eq_ignore_ascii_case(d)),
316        ParsedRule::DomainSuffix(s) => ctx.domain.is_some_and(|domain| {
317            let lower = domain.to_ascii_lowercase();
318            let suffix_lower = s.to_ascii_lowercase();
319            lower == suffix_lower || lower.ends_with(&format!(".{suffix_lower}"))
320        }),
321        ParsedRule::DomainKeyword(k) => ctx.domain.is_some_and(|domain| {
322            domain
323                .to_ascii_lowercase()
324                .contains(&k.to_ascii_lowercase())
325        }),
326        ParsedRule::IpCidr(net) => ctx.dest_ip.is_some_and(|ip| net.contains(&ip)),
327        ParsedRule::DstPort(port) => ctx.dest_port == *port,
328        ParsedRule::SrcIpCidr(net) => net.contains(&ctx.src_ip),
329    }
330}
331
332impl std::fmt::Debug for RuleEngine {
333    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334        f.debug_struct("RuleEngine")
335            .field("rule_sets", &self.compiled_sets.len())
336            .field("rules", &self.rules.len())
337            .field("final_action", &self.final_action)
338            .finish()
339    }
340}
341
342// ── Builder ──
343
344/// Builder for constructing a `RuleEngine`.
345#[derive(Debug)]
346pub struct RuleEngineBuilder {
347    rule_sets: HashMap<String, Vec<ParsedRule>>,
348    rules: Vec<EngineRule>,
349    final_action: Option<Action>,
350    #[cfg(feature = "geoip")]
351    geoip: Option<Arc<GeoipMatcher>>,
352}
353
354impl RuleEngineBuilder {
355    /// Create a new builder.
356    pub fn new() -> Self {
357        Self {
358            rule_sets: HashMap::new(),
359            rules: Vec::new(),
360            final_action: None,
361            #[cfg(feature = "geoip")]
362            geoip: None,
363        }
364    }
365
366    /// Set the GeoIP matcher for GEOIP rule matching.
367    #[cfg(feature = "geoip")]
368    pub fn set_geoip(&mut self, matcher: Arc<GeoipMatcher>) -> &mut Self {
369        self.geoip = Some(matcher);
370        self
371    }
372
373    /// Add a named rule-set (parsed rules).
374    pub fn add_rule_set(&mut self, name: impl Into<String>, rules: Vec<ParsedRule>) -> &mut Self {
375        self.rule_sets.insert(name.into(), rules);
376        self
377    }
378
379    /// Add a rule that references a named rule-set.
380    pub fn add_rule_set_rule(
381        &mut self,
382        rule_set_name: impl Into<String>,
383        action: Action,
384    ) -> &mut Self {
385        self.rules.push(EngineRule::RuleSet {
386            name: rule_set_name.into(),
387            action,
388        });
389        self
390    }
391
392    /// Add a GEOIP rule.
393    pub fn add_geoip_rule(&mut self, code: impl Into<String>, action: Action) -> &mut Self {
394        self.rules.push(EngineRule::GeoIp {
395            code: code.into(),
396            action,
397        });
398        self
399    }
400
401    /// Add an inline rule (single rule type + value).
402    pub fn add_inline_rule(&mut self, rule: ParsedRule, action: Action) -> &mut Self {
403        self.rules.push(EngineRule::Inline { rule, action });
404        self
405    }
406
407    /// Set the FINAL (catch-all) action.
408    pub fn set_final(&mut self, action: Action) -> &mut Self {
409        self.final_action = Some(action.clone());
410        self.rules.push(EngineRule::Final { action });
411        self
412    }
413
414    /// Build the rule engine.
415    pub fn build(self) -> Result<RuleEngine, RulesError> {
416        let final_action = self.final_action.ok_or(RulesError::NoFinalRule)?;
417
418        // Validate that all rule-set references exist
419        for rule in &self.rules {
420            if let EngineRule::RuleSet { name, .. } = rule
421                && !self.rule_sets.contains_key(name)
422            {
423                return Err(RulesError::UnknownRuleSet(name.clone()));
424            }
425        }
426
427        // Warn when GEOIP rules are present but cannot be evaluated
428        {
429            let has_geoip_rules = self
430                .rules
431                .iter()
432                .any(|r| matches!(r, EngineRule::GeoIp { .. }));
433            if has_geoip_rules {
434                #[cfg(feature = "geoip")]
435                if self.geoip.is_none() {
436                    tracing::warn!(
437                        "GEOIP rules are configured but no GeoIP database is loaded; they will never match"
438                    );
439                }
440                #[cfg(not(feature = "geoip"))]
441                tracing::warn!(
442                    "GEOIP rules are configured but the 'geoip' feature is not enabled; they will never match"
443                );
444            }
445        }
446
447        // Compile rule-sets
448        let compiled_sets: HashMap<String, CompiledRuleSet> = self
449            .rule_sets
450            .into_iter()
451            .map(|(name, rules)| {
452                let count = rules.len();
453                let compiled = CompiledRuleSet::compile(rules);
454                debug!(
455                    name = %name,
456                    rules = count,
457                    domains = compiled.domain_matcher.len(),
458                    keywords = compiled.keyword_matcher.as_ref().map_or(0, |k| k.len()),
459                    cidrs = compiled.cidr_matcher.len(),
460                    "compiled rule-set"
461                );
462                (name, compiled)
463            })
464            .collect();
465
466        Ok(RuleEngine {
467            compiled_sets,
468            rules: self.rules,
469            final_action,
470            #[cfg(feature = "geoip")]
471            geoip: self.geoip,
472        })
473    }
474}
475
476impl Default for RuleEngineBuilder {
477    fn default() -> Self {
478        Self::new()
479    }
480}
481
482// ── Hot-reloadable engine ──
483
484/// A hot-reloadable wrapper around `RuleEngine`.
485///
486/// Uses `ArcSwap` for lock-free reads and atomic replacement.
487/// All reads go through `arc_swap::Guard` which is wait-free.
488pub struct HotRuleEngine {
489    inner: ArcSwap<RuleEngine>,
490}
491
492impl HotRuleEngine {
493    /// Create a new hot-reloadable engine with the given initial engine.
494    pub fn new(engine: RuleEngine) -> Self {
495        Self {
496            inner: ArcSwap::new(Arc::new(engine)),
497        }
498    }
499
500    /// Match a request against the current rules.
501    ///
502    /// Returns an owned `Action` (cloned from the engine) so the caller
503    /// does not hold a borrow on the engine across await points.
504    pub fn match_request(&self, ctx: &MatchContext) -> Action {
505        let engine = self.inner.load();
506        engine.match_request(ctx).clone()
507    }
508
509    /// Returns true if the current engine has IP-based rules.
510    pub fn has_ip_rules(&self) -> bool {
511        self.inner.load().has_ip_rules()
512    }
513
514    /// Atomically replace the engine with a new one.
515    pub fn update(&self, engine: RuleEngine) {
516        self.inner.store(Arc::new(engine));
517    }
518
519    /// Lazy IP matching: returns `Some(action)` if a rule matched without DNS,
520    /// `None` if an IP-based rule appeared first and DNS resolution is needed.
521    ///
522    /// The caller should resolve DNS and then call `match_request()` with the
523    /// resolved IP when `None` is returned.
524    pub fn match_request_lazy_ip(&self, ctx: &MatchContext) -> Option<Action> {
525        let engine = self.inner.load();
526        match engine.match_request_lazy_ip(ctx) {
527            MatchDecision::Matched(action) => Some(action.clone()),
528            MatchDecision::NeedIp => None,
529        }
530    }
531
532    /// Number of compiled rule-sets in the current engine.
533    pub fn rule_set_count(&self) -> usize {
534        self.inner.load().rule_set_count()
535    }
536
537    /// Number of engine rules in the current engine.
538    pub fn rule_count(&self) -> usize {
539        self.inner.load().rule_count()
540    }
541}
542
543impl std::fmt::Debug for HotRuleEngine {
544    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
545        f.debug_struct("HotRuleEngine")
546            .field("inner", &*self.inner.load())
547            .finish()
548    }
549}
550
551#[cfg(test)]
552mod tests {
553    use super::*;
554    use std::net::{IpAddr, Ipv4Addr};
555
556    fn ctx_domain(domain: &str) -> MatchContext<'_> {
557        MatchContext {
558            domain: Some(domain),
559            dest_ip: None,
560            dest_port: 443,
561            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
562        }
563    }
564
565    fn ctx_ip(ip: IpAddr) -> MatchContext<'static> {
566        MatchContext {
567            domain: None,
568            dest_ip: Some(ip),
569            dest_port: 443,
570            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
571        }
572    }
573
574    #[test]
575    fn rule_set_domain_match() {
576        let mut builder = RuleEngineBuilder::new();
577        builder.add_rule_set(
578            "ads",
579            vec![
580                ParsedRule::DomainSuffix("ad.example.com".into()),
581                ParsedRule::DomainKeyword("ads".into()),
582            ],
583        );
584        builder.add_rule_set_rule("ads", Action::Reject);
585        builder.set_final(Action::Direct);
586
587        let engine = builder.build().unwrap();
588
589        assert_eq!(
590            engine.match_request(&ctx_domain("tracker.ad.example.com")),
591            &Action::Reject
592        );
593        assert_eq!(
594            engine.match_request(&ctx_domain("someads.com")),
595            &Action::Reject
596        );
597        assert_eq!(
598            engine.match_request(&ctx_domain("clean.example.com")),
599            &Action::Direct
600        );
601    }
602
603    #[test]
604    fn rule_set_ip_match() {
605        let mut builder = RuleEngineBuilder::new();
606        builder.add_rule_set(
607            "private",
608            vec![ParsedRule::IpCidr("192.168.0.0/16".parse().unwrap())],
609        );
610        builder.add_rule_set_rule("private", Action::Outbound("vpn".into()));
611        builder.set_final(Action::Direct);
612
613        let engine = builder.build().unwrap();
614
615        assert_eq!(
616            engine.match_request(&ctx_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)))),
617            &Action::Outbound("vpn".into())
618        );
619        assert_eq!(
620            engine.match_request(&ctx_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))),
621            &Action::Direct
622        );
623    }
624
625    #[test]
626    fn rule_order_matters() {
627        let mut builder = RuleEngineBuilder::new();
628        builder.add_rule_set("block", vec![ParsedRule::Domain("example.com".into())]);
629        builder.add_rule_set(
630            "allow",
631            vec![ParsedRule::DomainSuffix("example.com".into())],
632        );
633        builder.add_rule_set_rule("block", Action::Reject);
634        builder.add_rule_set_rule("allow", Action::Direct);
635        builder.set_final(Action::Outbound("proxy".into()));
636
637        let engine = builder.build().unwrap();
638
639        // "example.com" matches "block" first → REJECT
640        assert_eq!(
641            engine.match_request(&ctx_domain("example.com")),
642            &Action::Reject
643        );
644        // "sub.example.com" doesn't match "block" (exact), matches "allow" (suffix) → DIRECT
645        assert_eq!(
646            engine.match_request(&ctx_domain("sub.example.com")),
647            &Action::Direct
648        );
649    }
650
651    #[test]
652    fn final_action_catch_all() {
653        let mut builder = RuleEngineBuilder::new();
654        builder.set_final(Action::Direct);
655        let engine = builder.build().unwrap();
656
657        assert_eq!(
658            engine.match_request(&ctx_domain("anything.com")),
659            &Action::Direct
660        );
661    }
662
663    #[test]
664    fn no_final_rule_error() {
665        let builder = RuleEngineBuilder::new();
666        builder.build().unwrap_err();
667    }
668
669    #[test]
670    fn unknown_rule_set_error() {
671        let mut builder = RuleEngineBuilder::new();
672        builder.add_rule_set_rule("nonexistent", Action::Reject);
673        builder.set_final(Action::Direct);
674        builder.build().unwrap_err();
675    }
676
677    #[test]
678    fn inline_rule_match() {
679        let mut builder = RuleEngineBuilder::new();
680        builder.add_inline_rule(ParsedRule::Domain("blocked.com".into()), Action::Reject);
681        builder.add_inline_rule(
682            ParsedRule::IpCidr("10.0.0.0/8".parse().unwrap()),
683            Action::Outbound("internal".into()),
684        );
685        builder.set_final(Action::Direct);
686
687        let engine = builder.build().unwrap();
688
689        assert_eq!(
690            engine.match_request(&ctx_domain("blocked.com")),
691            &Action::Reject
692        );
693        assert_eq!(
694            engine.match_request(&ctx_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))),
695            &Action::Outbound("internal".into())
696        );
697        assert_eq!(
698            engine.match_request(&ctx_domain("allowed.com")),
699            &Action::Direct
700        );
701    }
702
703    #[test]
704    fn engine_send_sync() {
705        fn assert_send_sync<T: Send + Sync>() {}
706        assert_send_sync::<RuleEngine>();
707    }
708
709    #[test]
710    fn lazy_match_domain_before_ip_rule() {
711        let mut builder = RuleEngineBuilder::new();
712        builder.add_inline_rule(ParsedRule::Domain("example.com".into()), Action::Reject);
713        builder.add_inline_rule(
714            ParsedRule::IpCidr("10.0.0.0/8".parse().unwrap()),
715            Action::Direct,
716        );
717        builder.set_final(Action::Direct);
718        let engine = builder.build().unwrap();
719
720        let ctx = ctx_domain("example.com");
721        match engine.match_request_lazy_ip(&ctx) {
722            MatchDecision::Matched(action) => assert_eq!(action, &Action::Reject),
723            MatchDecision::NeedIp => panic!("should not require IP for domain match"),
724        }
725    }
726
727    #[test]
728    fn lazy_match_needs_ip_when_ip_rule_first() {
729        let mut builder = RuleEngineBuilder::new();
730        builder.add_inline_rule(
731            ParsedRule::IpCidr("10.0.0.0/8".parse().unwrap()),
732            Action::Reject,
733        );
734        builder.add_inline_rule(ParsedRule::Domain("example.com".into()), Action::Direct);
735        builder.set_final(Action::Direct);
736        let engine = builder.build().unwrap();
737
738        let ctx = ctx_domain("example.com");
739        match engine.match_request_lazy_ip(&ctx) {
740            MatchDecision::Matched(_) => panic!("should require IP before evaluating later rules"),
741            MatchDecision::NeedIp => {}
742        }
743    }
744
745    #[test]
746    fn geoip_skipped_when_dest_ip_none() {
747        // GEOIP rule should not match when dest_ip is None (domain-only request)
748        let mut builder = RuleEngineBuilder::new();
749        builder.add_geoip_rule("CN", Action::Reject);
750        builder.set_final(Action::Direct);
751        let engine = builder.build().unwrap();
752
753        // domain-only context: dest_ip is None
754        let ctx = ctx_domain("example.cn");
755        assert_eq!(engine.match_request(&ctx), &Action::Direct);
756    }
757
758    #[test]
759    fn inline_dst_port_match() {
760        let mut builder = RuleEngineBuilder::new();
761        builder.add_inline_rule(ParsedRule::DstPort(80), Action::Reject);
762        builder.set_final(Action::Direct);
763        let engine = builder.build().unwrap();
764
765        let ctx = MatchContext {
766            domain: Some("example.com"),
767            dest_ip: None,
768            dest_port: 80,
769            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
770        };
771        assert_eq!(engine.match_request(&ctx), &Action::Reject);
772
773        let ctx_miss = MatchContext {
774            domain: Some("example.com"),
775            dest_ip: None,
776            dest_port: 443,
777            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
778        };
779        assert_eq!(engine.match_request(&ctx_miss), &Action::Direct);
780    }
781
782    #[test]
783    fn inline_src_ip_cidr_match() {
784        let mut builder = RuleEngineBuilder::new();
785        builder.add_inline_rule(
786            ParsedRule::SrcIpCidr("10.0.0.0/8".parse().unwrap()),
787            Action::Reject,
788        );
789        builder.set_final(Action::Direct);
790        let engine = builder.build().unwrap();
791
792        let ctx_match = MatchContext {
793            domain: Some("example.com"),
794            dest_ip: None,
795            dest_port: 443,
796            src_ip: IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3)),
797        };
798        assert_eq!(engine.match_request(&ctx_match), &Action::Reject);
799
800        let ctx_miss = MatchContext {
801            domain: Some("example.com"),
802            dest_ip: None,
803            dest_port: 443,
804            src_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
805        };
806        assert_eq!(engine.match_request(&ctx_miss), &Action::Direct);
807    }
808
809    #[test]
810    fn rule_set_with_dst_port_and_src_cidr() {
811        let mut builder = RuleEngineBuilder::new();
812        builder.add_rule_set(
813            "mixed",
814            vec![
815                ParsedRule::DomainSuffix("example.com".into()),
816                ParsedRule::DstPort(8080),
817                ParsedRule::SrcIpCidr("172.16.0.0/12".parse().unwrap()),
818            ],
819        );
820        builder.add_rule_set_rule("mixed", Action::Outbound("proxy".into()));
821        builder.set_final(Action::Direct);
822        let engine = builder.build().unwrap();
823
824        // Match via dst port
825        let ctx_port = MatchContext {
826            domain: Some("other.com"),
827            dest_ip: None,
828            dest_port: 8080,
829            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
830        };
831        assert_eq!(
832            engine.match_request(&ctx_port),
833            &Action::Outbound("proxy".into())
834        );
835
836        // Match via src ip
837        let ctx_src = MatchContext {
838            domain: Some("other.com"),
839            dest_ip: None,
840            dest_port: 443,
841            src_ip: IpAddr::V4(Ipv4Addr::new(172, 16, 5, 1)),
842        };
843        assert_eq!(
844            engine.match_request(&ctx_src),
845            &Action::Outbound("proxy".into())
846        );
847
848        // No match
849        let ctx_miss = MatchContext {
850            domain: Some("other.com"),
851            dest_ip: None,
852            dest_port: 443,
853            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
854        };
855        assert_eq!(engine.match_request(&ctx_miss), &Action::Direct);
856    }
857
858    #[test]
859    fn lazy_match_dst_port_no_dns_needed() {
860        // DST-PORT rules should not require DNS resolution
861        let mut builder = RuleEngineBuilder::new();
862        builder.add_inline_rule(ParsedRule::DstPort(80), Action::Reject);
863        builder.set_final(Action::Direct);
864        let engine = builder.build().unwrap();
865
866        let ctx = MatchContext {
867            domain: Some("example.com"),
868            dest_ip: None,
869            dest_port: 80,
870            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
871        };
872        match engine.match_request_lazy_ip(&ctx) {
873            MatchDecision::Matched(action) => assert_eq!(action, &Action::Reject),
874            MatchDecision::NeedIp => panic!("DST-PORT should not require IP resolution"),
875        }
876    }
877
878    #[test]
879    fn lazy_match_src_ip_cidr_no_dns_needed() {
880        // SRC-IP-CIDR rules should not require DNS resolution
881        let mut builder = RuleEngineBuilder::new();
882        builder.add_inline_rule(
883            ParsedRule::SrcIpCidr("10.0.0.0/8".parse().unwrap()),
884            Action::Reject,
885        );
886        builder.set_final(Action::Direct);
887        let engine = builder.build().unwrap();
888
889        let ctx = MatchContext {
890            domain: Some("example.com"),
891            dest_ip: None,
892            dest_port: 443,
893            src_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
894        };
895        match engine.match_request_lazy_ip(&ctx) {
896            MatchDecision::Matched(action) => assert_eq!(action, &Action::Reject),
897            MatchDecision::NeedIp => panic!("SRC-IP-CIDR should not require IP resolution"),
898        }
899    }
900
901    #[test]
902    fn lazy_match_rule_set_port_before_cidr() {
903        // A rule-set containing both DST-PORT and IP-CIDR should match on
904        // port without returning NeedIp, even when dest_ip is None.
905        let mut builder = RuleEngineBuilder::new();
906        builder.add_rule_set(
907            "mixed",
908            vec![
909                ParsedRule::IpCidr("10.0.0.0/8".parse().unwrap()),
910                ParsedRule::DstPort(8080),
911            ],
912        );
913        builder.add_rule_set_rule("mixed", Action::Reject);
914        builder.set_final(Action::Direct);
915        let engine = builder.build().unwrap();
916
917        // Port 8080, no dest_ip → should match on port, not return NeedIp
918        let ctx = MatchContext {
919            domain: Some("example.com"),
920            dest_ip: None,
921            dest_port: 8080,
922            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
923        };
924        match engine.match_request_lazy_ip(&ctx) {
925            MatchDecision::Matched(action) => assert_eq!(action, &Action::Reject),
926            MatchDecision::NeedIp => {
927                panic!("DST-PORT in rule-set should match before CIDR triggers NeedIp")
928            }
929        }
930
931        // Port 443 (no match), no dest_ip, cidr exists → should return NeedIp
932        let ctx_miss = MatchContext {
933            domain: Some("other.com"),
934            dest_ip: None,
935            dest_port: 443,
936            src_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
937        };
938        match engine.match_request_lazy_ip(&ctx_miss) {
939            MatchDecision::Matched(_) => {
940                panic!("should return NeedIp for unmatched port with CIDR")
941            }
942            MatchDecision::NeedIp => {}
943        }
944    }
945
946    #[test]
947    fn domain_suffix_leading_dot_in_rule_set() {
948        // Rule-sets compiled from parsers that produce leading-dot suffixes
949        // should still match correctly after normalization.
950        let mut builder = RuleEngineBuilder::new();
951        builder.add_rule_set(
952            "test",
953            vec![ParsedRule::DomainSuffix(".example.com".into())],
954        );
955        builder.add_rule_set_rule("test", Action::Reject);
956        builder.set_final(Action::Direct);
957        let engine = builder.build().unwrap();
958
959        assert_eq!(
960            engine.match_request(&ctx_domain("example.com")),
961            &Action::Reject
962        );
963        assert_eq!(
964            engine.match_request(&ctx_domain("sub.example.com")),
965            &Action::Reject
966        );
967        assert_eq!(
968            engine.match_request(&ctx_domain("other.com")),
969            &Action::Direct
970        );
971    }
972}