// microsandbox_network/policy/engine.rs

//! Policy evaluation engine.
//!
//! Evaluates rules against parsed packet headers using first-match-wins semantics.
//! Domain-based rules are resolved via the DNS pin set, which maps destination IPs
//! back to the domain names they were resolved from.

use std::{
    collections::{HashMap, HashSet},
    net::IpAddr,
    sync::{Arc, PoisonError, RwLock, RwLockReadGuard},
};

use crate::packet::{IpProtocol, ParsedFrame};

use super::{
    destination::{matches_cidr, matches_group},
    types::{Action, Destination, Direction, NetworkPolicy, Protocol},
};

19//--------------------------------------------------------------------------------------------------
20// Types
21//--------------------------------------------------------------------------------------------------
22
/// Maps resolved IP addresses back to domain names.
///
/// Populated by the DNS interceptor when it resolves A/AAAA records.
/// Used by the policy engine to match domain-based rules against destination IPs.
#[derive(Debug)]
pub struct DnsPinSet {
    /// IP → set of (lowercased) domain names that resolved to it.
    ip_to_domains: HashMap<IpAddr, HashSet<String>>,
}

32/// Policy evaluation engine.
33///
34/// Evaluates `NetworkPolicy` rules against parsed frames, using first-match-wins
35/// semantics. Domain-based rules check the `DnsPinSet` to see if the destination
36/// IP was resolved from a matching domain.
37pub struct PolicyEngine {
38    policy: NetworkPolicy,
39    pin_set: Arc<RwLock<DnsPinSet>>,
40}
41
42//--------------------------------------------------------------------------------------------------
43// Methods
44//--------------------------------------------------------------------------------------------------
45
46impl DnsPinSet {
47    /// Creates an empty pin set.
48    pub fn new() -> Self {
49        Self {
50            ip_to_domains: HashMap::new(),
51        }
52    }
53
54    /// Records that `domain` resolved to `ip`.
55    pub fn pin(&mut self, domain: &str, ip: IpAddr) {
56        self.ip_to_domains
57            .entry(ip)
58            .or_default()
59            .insert(domain.to_lowercase());
60    }
61
62    /// Returns the set of domains that resolved to `ip`, if any.
63    pub fn lookup(&self, ip: IpAddr) -> Option<&HashSet<String>> {
64        self.ip_to_domains.get(&ip)
65    }
66
67    /// Removes all entries for an IP.
68    pub fn remove_ip(&mut self, ip: &IpAddr) {
69        self.ip_to_domains.remove(ip);
70    }
71}
72
73impl Default for DnsPinSet {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl PolicyEngine {
80    /// Creates a new policy engine with the given policy and pin set.
81    pub fn new(policy: NetworkPolicy, pin_set: Arc<RwLock<DnsPinSet>>) -> Self {
82        Self { policy, pin_set }
83    }
84
85    /// Evaluates a parsed frame against the policy.
86    ///
87    /// Returns the action to take (Allow or Deny). Uses first-match-wins:
88    /// the first rule whose direction, destination, protocol, and ports all
89    /// match determines the action. If no rule matches, the default action
90    /// is returned.
91    pub fn evaluate(&self, frame: &ParsedFrame<'_>, direction: Direction) -> Action {
92        let dst_ip = match frame.dst_ip() {
93            Some(ip) => ip,
94            None => return self.policy.default_action,
95        };
96
97        let protocol = frame.protocol();
98        let dst_port = frame.dst_port();
99
100        for rule in &self.policy.rules {
101            if rule.direction != direction {
102                continue;
103            }
104
105            if !self.matches_destination(&rule.destination, dst_ip) {
106                continue;
107            }
108
109            if let Some(ref rule_proto) = rule.protocol
110                && !matches_protocol(rule_proto, protocol)
111            {
112                continue;
113            }
114
115            if let Some(ref port_range) = rule.ports {
116                match dst_port {
117                    Some(port) if port_range.contains(port) => {}
118                    _ => continue,
119                }
120            }
121
122            return rule.action;
123        }
124
125        self.policy.default_action
126    }
127
128    /// Checks if a destination IP matches a rule's destination spec.
129    fn matches_destination(&self, destination: &Destination, ip: IpAddr) -> bool {
130        match destination {
131            Destination::Any => true,
132            Destination::Cidr(network) => matches_cidr(network, ip),
133            Destination::Group(group) => matches_group(*group, ip),
134            Destination::Domain(domain) => self.ip_matches_domain(ip, domain),
135            Destination::DomainSuffix(suffix) => self.ip_matches_domain_suffix(ip, suffix),
136        }
137    }
138
139    /// Checks if an IP was resolved from the given domain via intercepted DNS.
140    fn ip_matches_domain(&self, ip: IpAddr, domain: &str) -> bool {
141        let pin_set = match self.pin_set.read() {
142            Ok(ps) => ps,
143            Err(_) => return false,
144        };
145        match pin_set.lookup(ip) {
146            Some(domains) => domains.contains(&domain.to_lowercase()),
147            None => false,
148        }
149    }
150
151    /// Checks if an IP was resolved from a domain matching the suffix.
152    fn ip_matches_domain_suffix(&self, ip: IpAddr, suffix: &str) -> bool {
153        let pin_set = match self.pin_set.read() {
154            Ok(ps) => ps,
155            Err(_) => return false,
156        };
157        let suffix_lower = suffix.to_lowercase();
158        match pin_set.lookup(ip) {
159            Some(domains) => domains.iter().any(|d| d.ends_with(&suffix_lower)),
160            None => false,
161        }
162    }
163}
164
//--------------------------------------------------------------------------------------------------
// Functions
//--------------------------------------------------------------------------------------------------

169/// Checks if a parsed protocol matches a rule's protocol filter.
170fn matches_protocol(rule_proto: &Protocol, frame_proto: Option<IpProtocol>) -> bool {
171    let frame_proto = match frame_proto {
172        Some(p) => p,
173        None => return false,
174    };
175
176    matches!(
177        (rule_proto, frame_proto),
178        (Protocol::Tcp, IpProtocol::Tcp)
179            | (Protocol::Udp, IpProtocol::Udp)
180            | (Protocol::Icmpv4, IpProtocol::Icmpv4)
181            | (Protocol::Icmpv6, IpProtocol::Icmpv6)
182    )
183}
184
185//--------------------------------------------------------------------------------------------------
186// Tests
187//--------------------------------------------------------------------------------------------------
188
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;
    use crate::policy::{DestinationGroup, PortRange, Rule};

    /// Builds an Ethernet/IPv4/UDP frame from 10.0.0.1:50000 to `dst_ip:dst_port`.
    fn build_udp_frame(dst_ip: [u8; 4], dst_port: u16) -> Vec<u8> {
        use etherparse::PacketBuilder;
        let builder = PacketBuilder::ethernet2(
            [0x02, 0x00, 0x00, 0x00, 0x00, 0x01],
            [0x02, 0x00, 0x00, 0x00, 0x00, 0x02],
        )
        .ipv4([10, 0, 0, 1], dst_ip, 64)
        .udp(50000, dst_port);
        let mut buf = Vec::new();
        builder.write(&mut buf, &[]).unwrap();
        buf
    }

    /// Builds an Ethernet/IPv4/TCP frame from 10.0.0.1:50000 to `dst_ip:dst_port`.
    fn build_tcp_frame(dst_ip: [u8; 4], dst_port: u16) -> Vec<u8> {
        use etherparse::PacketBuilder;
        let builder = PacketBuilder::ethernet2(
            [0x02, 0x00, 0x00, 0x00, 0x00, 0x01],
            [0x02, 0x00, 0x00, 0x00, 0x00, 0x02],
        )
        .ipv4([10, 0, 0, 1], dst_ip, 64)
        .tcp(50000, dst_port, 0, 65535);
        let mut buf = Vec::new();
        builder.write(&mut buf, &[]).unwrap();
        buf
    }

    /// Builds an engine with an empty pin set.
    fn make_engine(policy: NetworkPolicy) -> PolicyEngine {
        PolicyEngine::new(policy, Arc::new(RwLock::new(DnsPinSet::new())))
    }

    #[test]
    fn test_allow_all() {
        let engine = make_engine(NetworkPolicy::allow_all());
        let frame_data = build_tcp_frame([93, 184, 216, 34], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);
    }

    #[test]
    fn test_deny_all() {
        let engine = make_engine(NetworkPolicy::none());
        let frame_data = build_tcp_frame([93, 184, 216, 34], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Deny);
    }

    #[test]
    fn test_deny_private_networks() {
        let policy = NetworkPolicy {
            default_action: Action::Allow,
            rules: vec![Rule::deny_outbound(Destination::Group(
                DestinationGroup::Private,
            ))],
        };
        let engine = make_engine(policy);

        // Private → denied.
        let frame_data = build_tcp_frame([10, 0, 0, 1], 80);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Deny);

        // Public → allowed (default).
        let frame_data = build_tcp_frame([93, 184, 216, 34], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);
    }

    #[test]
    fn test_cidr_rule() {
        let policy = NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule::allow_outbound(Destination::Cidr(
                "93.184.216.0/24".parse().unwrap(),
            ))],
        };
        let engine = make_engine(policy);

        let frame_data = build_tcp_frame([93, 184, 216, 34], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);

        let frame_data = build_tcp_frame([8, 8, 8, 8], 53);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Deny);
    }

    #[test]
    fn test_port_range() {
        let policy = NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule {
                direction: Direction::Outbound,
                destination: Destination::Any,
                protocol: Some(Protocol::Tcp),
                ports: Some(PortRange::range(80, 443)),
                action: Action::Allow,
            }],
        };
        let engine = make_engine(policy);

        let frame_data = build_tcp_frame([8, 8, 8, 8], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);

        let frame_data = build_tcp_frame([8, 8, 8, 8], 22);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Deny);
    }

    #[test]
    fn test_protocol_filter() {
        let policy = NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule {
                direction: Direction::Outbound,
                destination: Destination::Any,
                protocol: Some(Protocol::Tcp),
                ports: None,
                action: Action::Allow,
            }],
        };
        let engine = make_engine(policy);

        // TCP → allowed.
        let frame_data = build_tcp_frame([8, 8, 8, 8], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);

        // UDP → denied (protocol mismatch).
        let frame_data = build_udp_frame([8, 8, 8, 8], 53);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Deny);
    }

    #[test]
    fn test_direction_filter() {
        let policy = NetworkPolicy {
            default_action: Action::Allow,
            rules: vec![Rule::deny_outbound(Destination::Group(
                DestinationGroup::Loopback,
            ))],
        };
        let engine = make_engine(policy);

        let frame_data = build_tcp_frame([127, 0, 0, 1], 80);
        let frame = ParsedFrame::parse(&frame_data).unwrap();

        // Outbound → denied.
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Deny);

        // Inbound → allowed (rule is outbound-only).
        assert_eq!(engine.evaluate(&frame, Direction::Inbound), Action::Allow);
    }

    #[test]
    fn test_domain_rule_with_pin_set() {
        let pin_set = Arc::new(RwLock::new(DnsPinSet::new()));
        pin_set
            .write()
            .unwrap()
            .pin("example.com", IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)));

        let policy = NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule::allow_outbound(Destination::Domain(
                "example.com".to_string(),
            ))],
        };
        let engine = PolicyEngine::new(policy, pin_set);

        // Pinned IP → allowed.
        let frame_data = build_tcp_frame([93, 184, 216, 34], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);

        // Unpinned IP → denied.
        let frame_data = build_tcp_frame([8, 8, 8, 8], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Deny);
    }

    #[test]
    fn test_domain_rule_is_case_insensitive() {
        let pin_set = Arc::new(RwLock::new(DnsPinSet::new()));
        pin_set
            .write()
            .unwrap()
            .pin("Example.COM", IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)));

        let policy = NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule::allow_outbound(Destination::Domain(
                "EXAMPLE.com".to_string(),
            ))],
        };
        let engine = PolicyEngine::new(policy, pin_set);

        let frame_data = build_tcp_frame([93, 184, 216, 34], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);
    }

    #[test]
    fn test_domain_suffix_rule_with_pin_set() {
        let pin_set = Arc::new(RwLock::new(DnsPinSet::new()));
        pin_set
            .write()
            .unwrap()
            .pin("api.example.com", IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)));

        let policy = NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule::allow_outbound(Destination::DomainSuffix(
                "example.com".into(),
            ))],
        };
        let engine = PolicyEngine::new(policy, pin_set);

        // IP pinned to a subdomain of the suffix → allowed.
        let frame_data = build_tcp_frame([93, 184, 216, 34], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);

        // Unpinned IP → denied.
        let frame_data = build_tcp_frame([8, 8, 8, 8], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Deny);
    }

    #[test]
    fn test_pin_set_lookup_and_remove() {
        let ip = IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34));
        let mut pins = DnsPinSet::default();
        assert!(pins.lookup(ip).is_none());

        pins.pin("example.com", ip);
        pins.pin("www.example.com", ip);
        pins.pin("EXAMPLE.com", ip);
        assert_eq!(pins.lookup(ip).map(HashSet::len), Some(2));

        pins.remove_ip(&ip);
        assert!(pins.lookup(ip).is_none());
    }

    #[test]
    fn test_first_match_wins() {
        let policy = NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![
                // First rule: allow port 443 to anywhere.
                Rule {
                    direction: Direction::Outbound,
                    destination: Destination::Any,
                    protocol: Some(Protocol::Tcp),
                    ports: Some(PortRange::single(443)),
                    action: Action::Allow,
                },
                // Second rule: deny all TCP (should not be reached for port 443).
                Rule {
                    direction: Direction::Outbound,
                    destination: Destination::Any,
                    protocol: Some(Protocol::Tcp),
                    ports: None,
                    action: Action::Deny,
                },
            ],
        };
        let engine = make_engine(policy);

        let frame_data = build_tcp_frame([8, 8, 8, 8], 443);
        let frame = ParsedFrame::parse(&frame_data).unwrap();
        assert_eq!(engine.evaluate(&frame, Direction::Outbound), Action::Allow);
    }
}