1use std::collections::BTreeMap;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9
10use crate::rate_limiter::is_local_network;
11
/// Key types with a smallest and a largest value.
///
/// [`IntervalMap`] anchors its first breakpoint at [`Bounded::min_value`] so
/// that every key is covered by some entry.
trait Bounded {
    /// Smallest value of the type.
    fn min_value() -> Self;
    /// Largest value of the type.
    #[allow(dead_code)]
    fn max_value() -> Self;
}

/// Key types whose values have an immediate successor.
trait Successor {
    /// Returns the next larger value, saturating: the successor of the
    /// largest value is that value itself.
    fn successor(self) -> Self;
}

impl Bounded for Ipv4Addr {
    fn min_value() -> Self {
        Self::UNSPECIFIED
    }

    fn max_value() -> Self {
        Self::BROADCAST
    }
}

impl Successor for Ipv4Addr {
    fn successor(self) -> Self {
        Self::from(u32::from(self).saturating_add(1))
    }
}

impl Bounded for Ipv6Addr {
    fn min_value() -> Self {
        Self::UNSPECIFIED
    }

    fn max_value() -> Self {
        // ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff
        Self::from(u128::MAX)
    }
}

impl Successor for Ipv6Addr {
    fn successor(self) -> Self {
        Self::from(u128::from(self).saturating_add(1))
    }
}

impl Bounded for u16 {
    fn min_value() -> Self {
        0
    }

    fn max_value() -> Self {
        Self::MAX
    }
}

impl Successor for u16 {
    fn successor(self) -> Self {
        self.saturating_add(1)
    }
}

/// A total map from keys of type `K` to `u32` flags, stored as breakpoints.
///
/// An entry `(k, f)` means "every key from `k` up to, but not including, the
/// next entry's key has flags `f`"; the last entry extends to the maximum key.
///
/// Invariants, established by [`IntervalMap::new`] and preserved by
/// [`IntervalMap::add_rule`]:
/// - there is always an entry at `K::min_value()`;
/// - no two consecutive entries have equal flags, so each entry starts a
///   maximal run and the representation is unique.
#[derive(Debug, Clone)]
struct IntervalMap<K> {
    map: BTreeMap<K, u32>,
}

impl<K: Ord + Clone + Bounded + Successor> IntervalMap<K> {
    /// Creates a map in which every key has flags `0`.
    fn new() -> Self {
        Self {
            map: BTreeMap::from([(K::min_value(), 0)]),
        }
    }

    /// Sets the flags of every key in the inclusive range `first..=last` to
    /// `flags`, overriding earlier rules there. Does nothing if `first > last`.
    ///
    /// Only the breakpoints at the two edges of the range can become
    /// redundant, so the map is kept minimal without rescanning it: the cost
    /// is O((m + 1) log n) for `m` breakpoints inside the range.
    fn add_rule(&mut self, first: K, last: K, flags: u32) {
        if first > last {
            return;
        }

        // `successor` saturates, so `after_key == last` exactly when the range
        // runs to the maximum key; then nothing lies after it to restore.
        let after_key = last.clone().successor();
        let flags_after = (after_key > last).then(|| self.access(&after_key));

        // Breakpoints below `first` are not touched below, so read this now.
        let extends_previous_run = self
            .map
            .range(..first.clone())
            .next_back()
            .is_some_and(|(_, &prev)| prev == flags);

        // Remove every breakpoint inside the range. The bound is inclusive of
        // `last` so that a breakpoint sitting on the maximum key is removed
        // too (with `first..after_key` it would survive when `after_key`
        // saturates to `last`).
        let covered: Vec<K> = self
            .map
            .range(first.clone()..=last)
            .map(|(key, _)| key.clone())
            .collect();
        for key in &covered {
            self.map.remove(key);
        }

        // Start the new run at `first`, unless the run just before it already
        // carries these flags.
        if !extends_previous_run {
            self.map.insert(first, flags);
        }

        // Keys past the range keep their old flags: restore a breakpoint there,
        // or drop one that would merely repeat `flags`.
        match flags_after {
            Some(after) if after != flags => {
                self.map.insert(after_key, after);
            }
            Some(_) => {
                self.map.remove(&after_key);
            }
            None => {}
        }
    }

    /// Returns the flags in effect for `key`.
    fn access(&self, key: &K) -> u32 {
        // The governing breakpoint is the greatest one at or below `key`; the
        // entry at `K::min_value()` guarantees one exists.
        self.map
            .range(..=key.clone())
            .next_back()
            .map_or(0, |(_, &flags)| flags)
    }

    /// Number of maximal runs of keys whose flags are non-zero.
    fn num_ranges(&self) -> usize {
        self.map.values().filter(|&&flags| flags != 0).count()
    }

    /// Returns `true` when every key has flags `0`.
    fn is_empty(&self) -> bool {
        self.map.values().all(|&flags| flags == 0)
    }
}
177
178#[derive(Debug, Clone)]
185pub struct IpFilter {
186 v4: IntervalMap<Ipv4Addr>,
187 v6: IntervalMap<Ipv6Addr>,
188 pub enabled: bool,
196}
197
198impl IpFilter {
199 #[must_use]
201 pub fn new() -> Self {
202 Self {
203 v4: IntervalMap::new(),
204 v6: IntervalMap::new(),
205 enabled: true,
206 }
207 }
208
209 pub fn add_rule(&mut self, first: IpAddr, last: IpAddr, flags: u32) {
214 match (first, last) {
215 (IpAddr::V4(f), IpAddr::V4(l)) => self.v4.add_rule(f, l, flags),
216 (IpAddr::V6(f), IpAddr::V6(l)) => self.v6.add_rule(f, l, flags),
217 _ => {} }
219 }
220
221 #[must_use]
223 pub fn access(&self, addr: IpAddr) -> u32 {
224 match addr {
225 IpAddr::V4(ip) => self.v4.access(&ip),
226 IpAddr::V6(ip) => self.v6.access(&ip),
227 }
228 }
229
230 #[must_use]
235 pub fn is_blocked(&self, addr: IpAddr) -> bool {
236 if !self.enabled {
237 return false;
238 }
239 if is_local_network(addr) {
240 return false;
241 }
242 self.access(addr) != 0
243 }
244
245 #[must_use]
247 pub fn num_ranges(&self) -> usize {
248 self.v4.num_ranges() + self.v6.num_ranges()
249 }
250
251 #[must_use]
253 pub fn is_empty(&self) -> bool {
254 self.v4.is_empty() && self.v6.is_empty()
255 }
256}
257
258impl Default for IpFilter {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264#[derive(Debug, Clone)]
270pub struct PortFilter {
271 ports: IntervalMap<u16>,
272}
273
274impl PortFilter {
275 #[must_use]
277 pub fn new() -> Self {
278 Self {
279 ports: IntervalMap::new(),
280 }
281 }
282
283 pub fn add_rule(&mut self, first: u16, last: u16, flags: u32) {
285 self.ports.add_rule(first, last, flags);
286 }
287
288 #[must_use]
290 pub fn access(&self, port: u16) -> u32 {
291 self.ports.access(&port)
292 }
293
294 #[must_use]
296 pub fn is_blocked(&self, port: u16) -> bool {
297 self.access(port) != 0
298 }
299}
300
301impl Default for PortFilter {
302 fn default() -> Self {
303 Self::new()
304 }
305}
306
307#[derive(Debug, thiserror::Error)]
311pub enum IpFilterError {
312 #[error("invalid IP address on line {line}: {message}")]
314 InvalidAddress {
315 line: usize,
317 message: String,
319 },
320
321 #[error("malformed line {line}: {message}")]
323 MalformedLine {
324 line: usize,
326 message: String,
328 },
329}
330
331pub fn parse_dat(input: &str) -> Result<IpFilter, IpFilterError> {
340 let mut filter = IpFilter::new();
341
342 for (line_num, line) in input.lines().enumerate() {
343 let line = line.trim();
344 if line.is_empty() || line.starts_with('#') {
345 continue;
346 }
347
348 let parts: Vec<&str> = line.splitn(3, ',').collect();
350 if parts.len() < 2 {
351 return Err(IpFilterError::MalformedLine {
352 line: line_num + 1,
353 message: "expected 'first_ip - last_ip , level , description'".into(),
354 });
355 }
356
357 let ip_range = parts[0].trim();
359 let ips: Vec<&str> = ip_range.splitn(2, '-').collect();
360 if ips.len() != 2 {
361 return Err(IpFilterError::MalformedLine {
362 line: line_num + 1,
363 message: "expected 'first_ip - last_ip'".into(),
364 });
365 }
366
367 let first: IpAddr = ips[0]
368 .trim()
369 .parse()
370 .map_err(
371 |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
372 line: line_num + 1,
373 message: e.to_string(),
374 },
375 )?;
376
377 let last: IpAddr = ips[1]
378 .trim()
379 .parse()
380 .map_err(
381 |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
382 line: line_num + 1,
383 message: e.to_string(),
384 },
385 )?;
386
387 let level: u32 = parts[1]
389 .trim()
390 .parse()
391 .map_err(|_| IpFilterError::MalformedLine {
392 line: line_num + 1,
393 message: "invalid level (expected integer)".into(),
394 })?;
395
396 filter.add_rule(first, last, level);
397 }
398
399 Ok(filter)
400}
401
402pub fn parse_p2p(input: &str) -> Result<IpFilter, IpFilterError> {
411 let mut filter = IpFilter::new();
412
413 for (line_num, line) in input.lines().enumerate() {
414 let line = line.trim();
415 if line.is_empty() || line.starts_with('#') {
416 continue;
417 }
418
419 let colon_pos = line
421 .rfind(':')
422 .ok_or_else(|| IpFilterError::MalformedLine {
423 line: line_num + 1,
424 message: "expected 'description:first_ip-last_ip'".into(),
425 })?;
426
427 let ip_range = &line[colon_pos + 1..];
428 let ips: Vec<&str> = ip_range.splitn(2, '-').collect();
429 if ips.len() != 2 {
430 return Err(IpFilterError::MalformedLine {
431 line: line_num + 1,
432 message: "expected 'first_ip-last_ip' after ':'".into(),
433 });
434 }
435
436 let first: IpAddr = ips[0]
437 .trim()
438 .parse()
439 .map_err(
440 |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
441 line: line_num + 1,
442 message: e.to_string(),
443 },
444 )?;
445
446 let last: IpAddr = ips[1]
447 .trim()
448 .parse()
449 .map_err(
450 |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
451 line: line_num + 1,
452 message: e.to_string(),
453 },
454 )?;
455
456 filter.add_rule(first, last, 1);
458 }
459
460 Ok(filter)
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an address literal; test inputs are known to be valid.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// Builds an IPv4 [`IpAddr`] from its four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[test]
    fn interval_map_empty_returns_zero() {
        let empty: IntervalMap<Ipv4Addr> = IntervalMap::new();
        for probe in [
            Ipv4Addr::UNSPECIFIED,
            Ipv4Addr::new(192, 168, 1, 1),
            Ipv4Addr::BROADCAST,
        ] {
            assert_eq!(empty.access(&probe), 0);
        }
    }

    #[test]
    fn interval_map_single_range() {
        let mut ranges: IntervalMap<Ipv4Addr> = IntervalMap::new();
        ranges.add_rule(Ipv4Addr::new(10, 0, 0, 0), Ipv4Addr::new(10, 0, 0, 255), 1);

        let expectations = [
            // Inside the rule, including both endpoints.
            (Ipv4Addr::new(10, 0, 0, 0), 1),
            (Ipv4Addr::new(10, 0, 0, 128), 1),
            (Ipv4Addr::new(10, 0, 0, 255), 1),
            // Just outside either end, and far away.
            (Ipv4Addr::new(9, 255, 255, 255), 0),
            (Ipv4Addr::new(10, 0, 1, 0), 0),
            (Ipv4Addr::new(192, 168, 1, 1), 0),
        ];
        for (probe, expected) in expectations {
            assert_eq!(ranges.access(&probe), expected);
        }
    }

    #[test]
    fn interval_map_overlapping_last_wins() {
        let mut ranges: IntervalMap<Ipv4Addr> = IntervalMap::new();
        ranges.add_rule(Ipv4Addr::new(10, 0, 0, 0), Ipv4Addr::new(10, 0, 0, 255), 1);
        // Punch an allowed hole into the middle of the blocked range.
        ranges.add_rule(
            Ipv4Addr::new(10, 0, 0, 100),
            Ipv4Addr::new(10, 0, 0, 200),
            0,
        );

        for (last_octet, expected) in [(50, 1), (100, 0), (150, 0), (200, 0), (201, 1), (255, 1)] {
            assert_eq!(ranges.access(&Ipv4Addr::new(10, 0, 0, last_octet)), expected);
        }
    }

    #[test]
    fn ip_filter_v4_block_range() {
        let mut filter = IpFilter::new();
        filter.add_rule(v4(203, 0, 113, 0), v4(203, 0, 113, 255), 1);

        for inside in ["203.0.113.0", "203.0.113.128", "203.0.113.255"] {
            assert!(filter.is_blocked(addr(inside)));
        }
        for outside in ["203.0.112.255", "203.0.114.0", "8.8.8.8"] {
            assert!(!filter.is_blocked(addr(outside)));
        }
    }

    #[test]
    fn ip_filter_v6_block_range() {
        let mut filter = IpFilter::new();
        filter.add_rule(addr("2001:db8::0"), addr("2001:db8::ffff"), 1);

        assert!(filter.is_blocked(addr("2001:db8::1")));
        assert!(filter.is_blocked(addr("2001:db8::ff")));
        assert!(!filter.is_blocked(addr("2001:db9::1")));
    }

    #[test]
    fn ip_filter_local_network_exempt() {
        let mut filter = IpFilter::new();
        // Block the entire address space of both families.
        filter.add_rule(
            IpAddr::V4(Ipv4Addr::UNSPECIFIED),
            IpAddr::V4(Ipv4Addr::BROADCAST),
            1,
        );
        filter.add_rule(
            IpAddr::V6(Ipv6Addr::UNSPECIFIED),
            IpAddr::V6(Ipv6Addr::from(u128::MAX)),
            1,
        );

        for local in ["127.0.0.1", "192.168.1.1", "10.0.0.1", "172.16.0.1", "::1"] {
            assert!(!filter.is_blocked(addr(local)));
        }
        // The exemption lives in `is_blocked`; the rule itself still matches.
        assert_eq!(filter.access(addr("127.0.0.1")), 1);

        assert!(filter.is_blocked(addr("8.8.8.8")));
        assert!(filter.is_blocked(addr("2001:db8::1")));
    }

    #[test]
    fn ip_filter_num_ranges() {
        let mut filter = IpFilter::new();
        assert_eq!(filter.num_ranges(), 0);
        assert!(filter.is_empty());

        filter.add_rule(v4(10, 0, 0, 0), v4(10, 0, 0, 255), 1);
        assert_eq!(filter.num_ranges(), 1);
        assert!(!filter.is_empty());

        filter.add_rule(v4(172, 16, 0, 0), v4(172, 16, 255, 255), 1);
        assert_eq!(filter.num_ranges(), 2);

        // Re-allowing the first range removes it from the count.
        filter.add_rule(v4(10, 0, 0, 0), v4(10, 0, 0, 255), 0);
        assert_eq!(filter.num_ranges(), 1);
    }

    #[test]
    fn parse_dat_valid() {
        let input = "\
# This is a comment
203.0.113.0 - 203.0.113.255 , 128 , Test range
198.51.100.0 - 198.51.100.255 , 1 , Another range
";
        let filter = parse_dat(input).unwrap();
        assert!(filter.is_blocked(addr("203.0.113.50")));
        assert!(filter.is_blocked(addr("198.51.100.1")));
        assert!(!filter.is_blocked(addr("8.8.8.8")));
    }

    #[test]
    fn parse_dat_malformed() {
        let result = parse_dat("this is not a valid line");
        assert!(matches!(
            result,
            Err(IpFilterError::MalformedLine { line: 1, .. })
        ));
    }

    #[test]
    fn parse_p2p_valid() {
        let input = "\
# P2P blocklist
Some Bad Range:203.0.113.0-203.0.113.255
Another Range:198.51.100.0-198.51.100.255
";
        let filter = parse_p2p(input).unwrap();
        assert!(filter.is_blocked(addr("203.0.113.50")));
        assert!(filter.is_blocked(addr("198.51.100.1")));
        assert!(!filter.is_blocked(addr("8.8.8.8")));
    }

    #[test]
    fn parse_p2p_invalid_ip() {
        let result = parse_p2p("Bad Range:999.999.999.999-203.0.113.255");
        assert!(matches!(
            result,
            Err(IpFilterError::InvalidAddress { line: 1, .. })
        ));
    }

    #[test]
    fn port_filter_block_range() {
        let mut filter = PortFilter::new();
        filter.add_rule(6881, 6889, 1);

        for port in [6881, 6885, 6889] {
            assert!(filter.is_blocked(port));
        }
        for port in [6880, 6890, 80] {
            assert!(!filter.is_blocked(port));
        }
    }

    #[test]
    fn ip_filter_set_enabled_short_circuits_is_blocked() {
        let mut filter = IpFilter::new();
        filter.add_rule(addr("203.0.113.0"), addr("203.0.113.255"), 1);
        let blocked_ip = addr("203.0.113.42");

        // Filters start out enabled.
        assert!(filter.enabled);
        assert!(filter.is_blocked(blocked_ip));

        filter.enabled = false;
        assert!(
            !filter.is_blocked(blocked_ip),
            "disabled filter must short-circuit even for IPs in blocked range"
        );

        filter.enabled = true;
        assert!(filter.is_blocked(blocked_ip));
    }
}