1use std::{
7 collections::{HashMap, HashSet},
8 net::IpAddr,
9 sync::{Arc, RwLock},
10};
11
12use crate::packet::{IpProtocol, ParsedFrame};
13
14use super::{
15 destination::{matches_cidr, matches_group},
16 types::{Action, Destination, Direction, NetworkPolicy, Protocol},
17};
18
/// Maps IP addresses to the set of domain names pinned to them.
///
/// Names are stored lowercased (see [`DnsPinSet::pin`]), so lookups by
/// callers are expected to compare against lowercased names.
///
/// `Debug` and `Clone` are derived so the set can be logged and snapshotted.
/// `Default` is provided by a hand-written impl rather than derived.
#[derive(Debug, Clone)]
pub struct DnsPinSet {
    /// Pinned domain names, keyed by the IP address they were pinned to.
    ip_to_domains: HashMap<IpAddr, HashSet<String>>,
}
31
32pub struct PolicyEngine {
38 policy: NetworkPolicy,
39 pin_set: Arc<RwLock<DnsPinSet>>,
40}
41
42impl DnsPinSet {
47 pub fn new() -> Self {
49 Self {
50 ip_to_domains: HashMap::new(),
51 }
52 }
53
54 pub fn pin(&mut self, domain: &str, ip: IpAddr) {
56 self.ip_to_domains
57 .entry(ip)
58 .or_default()
59 .insert(domain.to_lowercase());
60 }
61
62 pub fn lookup(&self, ip: IpAddr) -> Option<&HashSet<String>> {
64 self.ip_to_domains.get(&ip)
65 }
66
67 pub fn remove_ip(&mut self, ip: &IpAddr) {
69 self.ip_to_domains.remove(ip);
70 }
71}
72
73impl Default for DnsPinSet {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl PolicyEngine {
80 pub fn new(policy: NetworkPolicy, pin_set: Arc<RwLock<DnsPinSet>>) -> Self {
82 Self { policy, pin_set }
83 }
84
85 pub fn evaluate(&self, frame: &ParsedFrame<'_>, direction: Direction) -> Action {
92 let dst_ip = match frame.dst_ip() {
93 Some(ip) => ip,
94 None => return self.policy.default_action,
95 };
96
97 let protocol = frame.protocol();
98 let dst_port = frame.dst_port();
99
100 for rule in &self.policy.rules {
101 if rule.direction != direction {
102 continue;
103 }
104
105 if !self.matches_destination(&rule.destination, dst_ip) {
106 continue;
107 }
108
109 if let Some(ref rule_proto) = rule.protocol
110 && !matches_protocol(rule_proto, protocol)
111 {
112 continue;
113 }
114
115 if let Some(ref port_range) = rule.ports {
116 match dst_port {
117 Some(port) if port_range.contains(port) => {}
118 _ => continue,
119 }
120 }
121
122 return rule.action;
123 }
124
125 self.policy.default_action
126 }
127
128 fn matches_destination(&self, destination: &Destination, ip: IpAddr) -> bool {
130 match destination {
131 Destination::Any => true,
132 Destination::Cidr(network) => matches_cidr(network, ip),
133 Destination::Group(group) => matches_group(*group, ip),
134 Destination::Domain(domain) => self.ip_matches_domain(ip, domain),
135 Destination::DomainSuffix(suffix) => self.ip_matches_domain_suffix(ip, suffix),
136 }
137 }
138
139 fn ip_matches_domain(&self, ip: IpAddr, domain: &str) -> bool {
141 let pin_set = match self.pin_set.read() {
142 Ok(ps) => ps,
143 Err(_) => return false,
144 };
145 match pin_set.lookup(ip) {
146 Some(domains) => domains.contains(&domain.to_lowercase()),
147 None => false,
148 }
149 }
150
151 fn ip_matches_domain_suffix(&self, ip: IpAddr, suffix: &str) -> bool {
153 let pin_set = match self.pin_set.read() {
154 Ok(ps) => ps,
155 Err(_) => return false,
156 };
157 let suffix_lower = suffix.to_lowercase();
158 match pin_set.lookup(ip) {
159 Some(domains) => domains.iter().any(|d| d.ends_with(&suffix_lower)),
160 None => false,
161 }
162 }
163}
164
165fn matches_protocol(rule_proto: &Protocol, frame_proto: Option<IpProtocol>) -> bool {
171 let frame_proto = match frame_proto {
172 Some(p) => p,
173 None => return false,
174 };
175
176 matches!(
177 (rule_proto, frame_proto),
178 (Protocol::Tcp, IpProtocol::Tcp)
179 | (Protocol::Udp, IpProtocol::Udp)
180 | (Protocol::Icmpv4, IpProtocol::Icmpv4)
181 | (Protocol::Icmpv6, IpProtocol::Icmpv6)
182 )
183}
184
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;
    use crate::policy::{DestinationGroup, PortRange, Rule};

    // Fixed addressing shared by every synthetic frame; only the destination
    // IP and port vary between tests.
    const SRC_MAC: [u8; 6] = [0x02, 0x00, 0x00, 0x00, 0x00, 0x01];
    const DST_MAC: [u8; 6] = [0x02, 0x00, 0x00, 0x00, 0x00, 0x02];
    const SRC_IP: [u8; 4] = [10, 0, 0, 1];
    const SRC_PORT: u16 = 50000;
    const TTL: u8 = 64;

    // Destinations used across the tests.
    const PUBLIC_HOST: [u8; 4] = [93, 184, 216, 34];
    const RESOLVER: [u8; 4] = [8, 8, 8, 8];

    /// Serializes an Ethernet/IPv4/UDP frame with an empty payload.
    fn build_udp_frame(dst_ip: [u8; 4], dst_port: u16) -> Vec<u8> {
        let mut bytes = Vec::new();
        etherparse::PacketBuilder::ethernet2(SRC_MAC, DST_MAC)
            .ipv4(SRC_IP, dst_ip, TTL)
            .udp(SRC_PORT, dst_port)
            .write(&mut bytes, &[])
            .unwrap();
        bytes
    }

    /// Serializes an Ethernet/IPv4/TCP frame with an empty payload.
    fn build_tcp_frame(dst_ip: [u8; 4], dst_port: u16) -> Vec<u8> {
        let mut bytes = Vec::new();
        etherparse::PacketBuilder::ethernet2(SRC_MAC, DST_MAC)
            .ipv4(SRC_IP, dst_ip, TTL)
            .tcp(SRC_PORT, dst_port, 0, 65535)
            .write(&mut bytes, &[])
            .unwrap();
        bytes
    }

    /// Engine over `policy` with an empty pin set.
    fn make_engine(policy: NetworkPolicy) -> PolicyEngine {
        PolicyEngine::new(policy, Arc::new(RwLock::new(DnsPinSet::new())))
    }

    /// Builds, parses and evaluates a TCP frame in one step.
    fn tcp_verdict(
        engine: &PolicyEngine,
        dst_ip: [u8; 4],
        dst_port: u16,
        direction: Direction,
    ) -> Action {
        let bytes = build_tcp_frame(dst_ip, dst_port);
        let frame = ParsedFrame::parse(&bytes).unwrap();
        engine.evaluate(&frame, direction)
    }

    /// Builds, parses and evaluates a UDP frame in one step.
    fn udp_verdict(
        engine: &PolicyEngine,
        dst_ip: [u8; 4],
        dst_port: u16,
        direction: Direction,
    ) -> Action {
        let bytes = build_udp_frame(dst_ip, dst_port);
        let frame = ParsedFrame::parse(&bytes).unwrap();
        engine.evaluate(&frame, direction)
    }

    #[test]
    fn test_allow_all() {
        let engine = make_engine(NetworkPolicy::allow_all());
        assert_eq!(
            tcp_verdict(&engine, PUBLIC_HOST, 443, Direction::Outbound),
            Action::Allow
        );
    }

    #[test]
    fn test_deny_all() {
        let engine = make_engine(NetworkPolicy::none());
        assert_eq!(
            tcp_verdict(&engine, PUBLIC_HOST, 443, Direction::Outbound),
            Action::Deny
        );
    }

    #[test]
    fn test_deny_private_networks() {
        let engine = make_engine(NetworkPolicy {
            default_action: Action::Allow,
            rules: vec![Rule::deny_outbound(Destination::Group(
                DestinationGroup::Private,
            ))],
        });

        // A private destination hits the deny rule.
        assert_eq!(
            tcp_verdict(&engine, [10, 0, 0, 1], 80, Direction::Outbound),
            Action::Deny
        );
        // A public destination falls through to the default.
        assert_eq!(
            tcp_verdict(&engine, PUBLIC_HOST, 443, Direction::Outbound),
            Action::Allow
        );
    }

    #[test]
    fn test_cidr_rule() {
        let engine = make_engine(NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule::allow_outbound(Destination::Cidr(
                "93.184.216.0/24".parse().unwrap(),
            ))],
        });

        assert_eq!(
            tcp_verdict(&engine, PUBLIC_HOST, 443, Direction::Outbound),
            Action::Allow
        );
        assert_eq!(
            tcp_verdict(&engine, RESOLVER, 53, Direction::Outbound),
            Action::Deny
        );
    }

    #[test]
    fn test_port_range() {
        let engine = make_engine(NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule {
                direction: Direction::Outbound,
                destination: Destination::Any,
                protocol: Some(Protocol::Tcp),
                ports: Some(PortRange::range(80, 443)),
                action: Action::Allow,
            }],
        });

        assert_eq!(
            tcp_verdict(&engine, RESOLVER, 443, Direction::Outbound),
            Action::Allow
        );
        assert_eq!(
            tcp_verdict(&engine, RESOLVER, 22, Direction::Outbound),
            Action::Deny
        );
    }

    #[test]
    fn test_protocol_filter() {
        let engine = make_engine(NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![Rule {
                direction: Direction::Outbound,
                destination: Destination::Any,
                protocol: Some(Protocol::Tcp),
                ports: None,
                action: Action::Allow,
            }],
        });

        // TCP matches the rule.
        assert_eq!(
            tcp_verdict(&engine, RESOLVER, 443, Direction::Outbound),
            Action::Allow
        );
        // UDP does not, so the default applies.
        assert_eq!(
            udp_verdict(&engine, RESOLVER, 53, Direction::Outbound),
            Action::Deny
        );
    }

    #[test]
    fn test_direction_filter() {
        let engine = make_engine(NetworkPolicy {
            default_action: Action::Allow,
            rules: vec![Rule::deny_outbound(Destination::Group(
                DestinationGroup::Loopback,
            ))],
        });

        // The outbound rule applies to outbound evaluation only.
        assert_eq!(
            tcp_verdict(&engine, [127, 0, 0, 1], 80, Direction::Outbound),
            Action::Deny
        );
        assert_eq!(
            tcp_verdict(&engine, [127, 0, 0, 1], 80, Direction::Inbound),
            Action::Allow
        );
    }

    #[test]
    fn test_domain_rule_with_pin_set() {
        let pins = Arc::new(RwLock::new(DnsPinSet::new()));
        pins.write()
            .unwrap()
            .pin("example.com", IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)));

        let engine = PolicyEngine::new(
            NetworkPolicy {
                default_action: Action::Deny,
                rules: vec![Rule::allow_outbound(Destination::Domain(
                    "example.com".to_string(),
                ))],
            },
            pins,
        );

        // The pinned IP resolves to the allowed domain.
        assert_eq!(
            tcp_verdict(&engine, PUBLIC_HOST, 443, Direction::Outbound),
            Action::Allow
        );
        // An IP with no pin cannot match a domain rule.
        assert_eq!(
            tcp_verdict(&engine, RESOLVER, 443, Direction::Outbound),
            Action::Deny
        );
    }

    #[test]
    fn test_first_match_wins() {
        let allow_https = Rule {
            direction: Direction::Outbound,
            destination: Destination::Any,
            protocol: Some(Protocol::Tcp),
            ports: Some(PortRange::single(443)),
            action: Action::Allow,
        };
        let deny_all_tcp = Rule {
            direction: Direction::Outbound,
            destination: Destination::Any,
            protocol: Some(Protocol::Tcp),
            ports: None,
            action: Action::Deny,
        };
        let engine = make_engine(NetworkPolicy {
            default_action: Action::Deny,
            rules: vec![allow_https, deny_all_tcp],
        });

        // Both rules match port 443; the earlier allow rule decides.
        assert_eq!(
            tcp_verdict(&engine, RESOLVER, 443, Direction::Outbound),
            Action::Allow
        );
    }
}