1use std::collections::BTreeMap;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9
10use crate::rate_limiter::is_local_network;
11
/// A key type with a smallest and a largest value.
///
/// `IntervalMap` seeds itself with a boundary at `min_value()`, so every key
/// has an entry at or below it from which to read its flags.
trait Bounded {
    /// The smallest value of the type.
    fn min_value() -> Self;
    /// The largest value of the type.
    // Part of the trait's contract, but not every user of the trait calls it,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn max_value() -> Self;
}
20
/// A key type whose values have a "next" value in order.
trait Successor {
    /// Returns the next larger value.
    ///
    /// The implementations in this module saturate: the largest value is
    /// returned unchanged. `IntervalMap::add_rule` relies on this to notice a
    /// range that extends to the end of the key space.
    fn successor(self) -> Self;
}
25
26impl Bounded for Ipv4Addr {
27 fn min_value() -> Self {
28 Ipv4Addr::new(0, 0, 0, 0)
29 }
30 fn max_value() -> Self {
31 Ipv4Addr::new(255, 255, 255, 255)
32 }
33}
34
35impl Successor for Ipv4Addr {
36 fn successor(self) -> Self {
37 let n: u32 = self.into();
38 Ipv4Addr::from(n.saturating_add(1))
39 }
40}
41
42impl Bounded for Ipv6Addr {
43 fn min_value() -> Self {
44 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)
45 }
46 fn max_value() -> Self {
47 Ipv6Addr::new(
48 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
49 )
50 }
51}
52
53impl Successor for Ipv6Addr {
54 fn successor(self) -> Self {
55 let n: u128 = self.into();
56 Ipv6Addr::from(n.saturating_add(1))
57 }
58}
59
60impl Bounded for u16 {
61 fn min_value() -> Self {
62 0
63 }
64 fn max_value() -> Self {
65 u16::MAX
66 }
67}
68
69impl Successor for u16 {
70 fn successor(self) -> Self {
71 self.saturating_add(1)
72 }
73}
74
/// A piecewise-constant map from keys of type `K` to `u32` flags.
///
/// It is stored as sorted boundaries: an entry `(k, f)` means every key from
/// `k` up to, but not including, the next entry's key has flags `f`. Adjacent
/// entries are kept with different flags, so each entry marks a real change.
///
/// The trait bounds live on the `impl` block that needs them rather than on
/// the type itself.
#[derive(Debug, Clone)]
struct IntervalMap<K> {
    /// Boundary key -> flags in effect from that key onward.
    map: BTreeMap<K, u32>,
}
86
87impl<K: Ord + Clone + Bounded + Successor> IntervalMap<K> {
88 fn new() -> Self {
89 let mut map = BTreeMap::new();
90 map.insert(K::min_value(), 0);
92 Self { map }
93 }
94
95 fn add_rule(&mut self, first: K, last: K, flags: u32) {
97 if first > last {
98 return;
99 }
100
101 let after_key = last.clone().successor();
104 let flags_after = self.access(&after_key);
105
106 let keys_to_remove: Vec<K> = self
113 .map
114 .range(first.clone()..after_key.clone())
115 .map(|(k, _)| k.clone())
116 .collect();
117 for k in keys_to_remove {
118 self.map.remove(&k);
119 }
120
121 self.map.insert(first, flags);
123
124 if after_key > last {
126 self.map.insert(after_key, flags_after);
127 }
128
129 self.minimize();
130 }
131
132 fn access(&self, key: &K) -> u32 {
134 self.map
135 .range(..=key.clone())
136 .next_back()
137 .map(|(_, &v)| v)
138 .unwrap_or(0)
139 }
140
141 fn minimize(&mut self) {
143 let mut prev_flags: Option<u32> = None;
144 let mut to_remove = Vec::new();
145
146 for (k, &flags) in &self.map {
147 if prev_flags == Some(flags) {
148 to_remove.push(k.clone());
149 }
150 prev_flags = Some(flags);
151 }
152
153 for k in to_remove {
154 self.map.remove(&k);
155 }
156 }
157
158 fn num_ranges(&self) -> usize {
160 let mut count = 0;
162 for &flags in self.map.values() {
163 if flags != 0 {
164 count += 1;
165 }
166 }
167 count
168 }
169
170 fn is_empty(&self) -> bool {
171 self.num_ranges() == 0
172 }
173}
174
/// IP blocklist holding separate rule sets for IPv4 and IPv6 addresses.
///
/// Each rule assigns `u32` flags to an inclusive address range; where rules
/// overlap, the one added later wins.
#[derive(Debug, Clone)]
pub struct IpFilter {
    /// Rules for IPv4 addresses.
    v4: IntervalMap<Ipv4Addr>,
    /// Rules for IPv6 addresses.
    v6: IntervalMap<Ipv6Addr>,
}
186
187impl IpFilter {
188 pub fn new() -> Self {
190 Self {
191 v4: IntervalMap::new(),
192 v6: IntervalMap::new(),
193 }
194 }
195
196 pub fn add_rule(&mut self, first: IpAddr, last: IpAddr, flags: u32) {
201 match (first, last) {
202 (IpAddr::V4(f), IpAddr::V4(l)) => self.v4.add_rule(f, l, flags),
203 (IpAddr::V6(f), IpAddr::V6(l)) => self.v6.add_rule(f, l, flags),
204 _ => {} }
206 }
207
208 pub fn access(&self, addr: IpAddr) -> u32 {
210 match addr {
211 IpAddr::V4(ip) => self.v4.access(&ip),
212 IpAddr::V6(ip) => self.v6.access(&ip),
213 }
214 }
215
216 pub fn is_blocked(&self, addr: IpAddr) -> bool {
221 if is_local_network(addr) {
222 return false;
223 }
224 self.access(addr) != 0
225 }
226
227 pub fn num_ranges(&self) -> usize {
229 self.v4.num_ranges() + self.v6.num_ranges()
230 }
231
232 pub fn is_empty(&self) -> bool {
234 self.v4.is_empty() && self.v6.is_empty()
235 }
236}
237
238impl Default for IpFilter {
239 fn default() -> Self {
240 Self::new()
241 }
242}
243
/// Port blocklist: assigns `u32` flags to inclusive ranges of ports.
///
/// Where rules overlap, the one added later wins.
#[derive(Debug, Clone)]
pub struct PortFilter {
    /// Per-port flags; `0` means the port is not blocked.
    ports: IntervalMap<u16>,
}
253
254impl PortFilter {
255 pub fn new() -> Self {
257 Self {
258 ports: IntervalMap::new(),
259 }
260 }
261
262 pub fn add_rule(&mut self, first: u16, last: u16, flags: u32) {
264 self.ports.add_rule(first, last, flags);
265 }
266
267 pub fn access(&self, port: u16) -> u32 {
269 self.ports.access(&port)
270 }
271
272 pub fn is_blocked(&self, port: u16) -> bool {
274 self.access(port) != 0
275 }
276}
277
278impl Default for PortFilter {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
/// Errors produced while parsing a blocklist with [`parse_dat`] or
/// [`parse_p2p`].
#[derive(Debug, thiserror::Error)]
pub enum IpFilterError {
    /// A range endpoint could not be parsed as an IP address.
    #[error("invalid IP address on line {line}: {message}")]
    InvalidAddress {
        /// 1-based line number in the input.
        line: usize,
        /// The address parser's description of the failure.
        message: String,
    },

    /// A line did not have the expected shape: a separator was missing or
    /// the level was not an integer.
    #[error("malformed line {line}: {message}")]
    MalformedLine {
        /// 1-based line number in the input.
        line: usize,
        /// Human-readable explanation of what was expected.
        message: String,
    },
}
307
308pub fn parse_dat(input: &str) -> Result<IpFilter, IpFilterError> {
313 let mut filter = IpFilter::new();
314
315 for (line_num, line) in input.lines().enumerate() {
316 let line = line.trim();
317 if line.is_empty() || line.starts_with('#') {
318 continue;
319 }
320
321 let parts: Vec<&str> = line.splitn(3, ',').collect();
323 if parts.len() < 2 {
324 return Err(IpFilterError::MalformedLine {
325 line: line_num + 1,
326 message: "expected 'first_ip - last_ip , level , description'".into(),
327 });
328 }
329
330 let ip_range = parts[0].trim();
332 let ips: Vec<&str> = ip_range.splitn(2, '-').collect();
333 if ips.len() != 2 {
334 return Err(IpFilterError::MalformedLine {
335 line: line_num + 1,
336 message: "expected 'first_ip - last_ip'".into(),
337 });
338 }
339
340 let first: IpAddr = ips[0]
341 .trim()
342 .parse()
343 .map_err(
344 |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
345 line: line_num + 1,
346 message: e.to_string(),
347 },
348 )?;
349
350 let last: IpAddr = ips[1]
351 .trim()
352 .parse()
353 .map_err(
354 |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
355 line: line_num + 1,
356 message: e.to_string(),
357 },
358 )?;
359
360 let level: u32 = parts[1]
362 .trim()
363 .parse()
364 .map_err(|_| IpFilterError::MalformedLine {
365 line: line_num + 1,
366 message: "invalid level (expected integer)".into(),
367 })?;
368
369 filter.add_rule(first, last, level);
370 }
371
372 Ok(filter)
373}
374
375pub fn parse_p2p(input: &str) -> Result<IpFilter, IpFilterError> {
380 let mut filter = IpFilter::new();
381
382 for (line_num, line) in input.lines().enumerate() {
383 let line = line.trim();
384 if line.is_empty() || line.starts_with('#') {
385 continue;
386 }
387
388 let colon_pos = line
390 .rfind(':')
391 .ok_or_else(|| IpFilterError::MalformedLine {
392 line: line_num + 1,
393 message: "expected 'description:first_ip-last_ip'".into(),
394 })?;
395
396 let ip_range = &line[colon_pos + 1..];
397 let ips: Vec<&str> = ip_range.splitn(2, '-').collect();
398 if ips.len() != 2 {
399 return Err(IpFilterError::MalformedLine {
400 line: line_num + 1,
401 message: "expected 'first_ip-last_ip' after ':'".into(),
402 });
403 }
404
405 let first: IpAddr = ips[0]
406 .trim()
407 .parse()
408 .map_err(
409 |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
410 line: line_num + 1,
411 message: e.to_string(),
412 },
413 )?;
414
415 let last: IpAddr = ips[1]
416 .trim()
417 .parse()
418 .map_err(
419 |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
420 line: line_num + 1,
421 message: e.to_string(),
422 },
423 )?;
424
425 filter.add_rule(first, last, 1);
427 }
428
429 Ok(filter)
430}
431
/// Unit tests for the interval map, the IP and port filters, and both
/// blocklist parsers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an address literal, panicking on malformed test input.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn interval_map_empty_returns_zero() {
        let map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        let probes: [[u8; 4]; 3] = [[0, 0, 0, 0], [192, 168, 1, 1], [255, 255, 255, 255]];
        for octets in probes {
            assert_eq!(map.access(&Ipv4Addr::from(octets)), 0);
        }
    }

    #[test]
    fn interval_map_single_range() {
        let mut map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        map.add_rule(Ipv4Addr::new(10, 0, 0, 0), Ipv4Addr::new(10, 0, 0, 255), 1);

        let cases: [([u8; 4], u32); 6] = [
            // Both endpoints and a point inside the range.
            ([10, 0, 0, 0], 1),
            ([10, 0, 0, 128], 1),
            ([10, 0, 0, 255], 1),
            // Just below, just above, and far outside the range.
            ([9, 255, 255, 255], 0),
            ([10, 0, 1, 0], 0),
            ([192, 168, 1, 1], 0),
        ];
        for (octets, expected) in cases {
            assert_eq!(map.access(&Ipv4Addr::from(octets)), expected);
        }
    }

    #[test]
    fn interval_map_overlapping_last_wins() {
        let mut map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        map.add_rule(Ipv4Addr::new(10, 0, 0, 0), Ipv4Addr::new(10, 0, 0, 255), 1);
        // A later rule punches a flags-0 hole into the middle of the range.
        map.add_rule(
            Ipv4Addr::new(10, 0, 0, 100),
            Ipv4Addr::new(10, 0, 0, 200),
            0,
        );

        let cases: [(u8, u32); 6] = [(50, 1), (100, 0), (150, 0), (200, 0), (201, 1), (255, 1)];
        for (last_octet, expected) in cases {
            assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, last_octet)), expected);
        }
    }

    #[test]
    fn ip_filter_v4_block_range() {
        let mut filter = IpFilter::new();
        filter.add_rule(ip("203.0.113.0"), ip("203.0.113.255"), 1);

        for blocked in ["203.0.113.0", "203.0.113.128", "203.0.113.255"] {
            assert!(filter.is_blocked(ip(blocked)));
        }
        for allowed in ["203.0.112.255", "203.0.114.0", "8.8.8.8"] {
            assert!(!filter.is_blocked(ip(allowed)));
        }
    }

    #[test]
    fn ip_filter_v6_block_range() {
        let mut filter = IpFilter::new();
        filter.add_rule(ip("2001:db8::0"), ip("2001:db8::ffff"), 1);

        for blocked in ["2001:db8::1", "2001:db8::ff"] {
            assert!(filter.is_blocked(ip(blocked)));
        }
        assert!(!filter.is_blocked(ip("2001:db9::1")));
    }

    #[test]
    fn ip_filter_local_network_exempt() {
        let mut filter = IpFilter::new();
        // Block the entire address space of both families.
        filter.add_rule(ip("0.0.0.0"), ip("255.255.255.255"), 1);
        filter.add_rule(ip("::"), ip("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"), 1);

        for local in ["127.0.0.1", "192.168.1.1", "10.0.0.1", "172.16.0.1", "::1"] {
            assert!(!filter.is_blocked(ip(local)));
        }

        // The exemption lives in `is_blocked` only; the raw rule still applies.
        assert_eq!(filter.access(ip("127.0.0.1")), 1);

        for remote in ["8.8.8.8", "2001:db8::1"] {
            assert!(filter.is_blocked(ip(remote)));
        }
    }

    #[test]
    fn ip_filter_num_ranges() {
        let mut filter = IpFilter::new();
        assert_eq!(filter.num_ranges(), 0);
        assert!(filter.is_empty());

        filter.add_rule(ip("10.0.0.0"), ip("10.0.0.255"), 1);
        assert_eq!(filter.num_ranges(), 1);
        assert!(!filter.is_empty());

        filter.add_rule(ip("172.16.0.0"), ip("172.16.255.255"), 1);
        assert_eq!(filter.num_ranges(), 2);

        // Resetting the first range to flags 0 leaves only the second.
        filter.add_rule(ip("10.0.0.0"), ip("10.0.0.255"), 0);
        assert_eq!(filter.num_ranges(), 1);
    }

    #[test]
    fn parse_dat_valid() {
        let input = "\
# This is a comment
203.0.113.0 - 203.0.113.255 , 128 , Test range
198.51.100.0 - 198.51.100.255 , 1 , Another range
";
        let filter = parse_dat(input).unwrap();
        assert!(filter.is_blocked(ip("203.0.113.50")));
        assert!(filter.is_blocked(ip("198.51.100.1")));
        assert!(!filter.is_blocked(ip("8.8.8.8")));
    }

    #[test]
    fn parse_dat_malformed() {
        let err = parse_dat("this is not a valid line").unwrap_err();
        assert!(matches!(err, IpFilterError::MalformedLine { line: 1, .. }));
    }

    #[test]
    fn parse_p2p_valid() {
        let input = "\
# P2P blocklist
Some Bad Range:203.0.113.0-203.0.113.255
Another Range:198.51.100.0-198.51.100.255
";
        let filter = parse_p2p(input).unwrap();
        assert!(filter.is_blocked(ip("203.0.113.50")));
        assert!(filter.is_blocked(ip("198.51.100.1")));
        assert!(!filter.is_blocked(ip("8.8.8.8")));
    }

    #[test]
    fn parse_p2p_invalid_ip() {
        let err = parse_p2p("Bad Range:999.999.999.999-203.0.113.255").unwrap_err();
        assert!(matches!(err, IpFilterError::InvalidAddress { line: 1, .. }));
    }

    #[test]
    fn port_filter_block_range() {
        let mut filter = PortFilter::new();
        filter.add_rule(6881, 6889, 1);

        for port in [6881, 6885, 6889] {
            assert!(filter.is_blocked(port));
        }
        for port in [6880, 6890, 80] {
            assert!(!filter.is_blocked(port));
        }
    }
}