Skip to main content

irontide_session/
ip_filter.rs

//! IP and port filtering using sorted interval maps.
//!
//! Provides [`IpFilter`] for blocking peer connections by IP address range,
//! and [`PortFilter`] for blocking by port range. Supports the eMule `.dat`
//! and P2P plaintext blocklist file formats via [`parse_dat`] and [`parse_p2p`].
6
7use std::collections::BTreeMap;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9
10use crate::rate_limiter::is_local_network;
11
12// ── Helper traits ────────────────────────────────────────────────────
13
/// Key types with a smallest and a largest representable value.
///
/// `IntervalMap::new` uses `min_value` to seed the map with its first breakpoint.
trait Bounded {
    /// The smallest value of the type.
    fn min_value() -> Self;

    /// The largest value of the type.
    // Kept so the trait describes the full key range; no code path in this
    // module is required to call it, hence the dead-code allowance.
    #[allow(dead_code)]
    fn max_value() -> Self;
}
20
/// Key types that can step to the next value, saturating at their maximum.
///
/// Saturation is relied upon by `IntervalMap::add_rule`, which compares
/// `last.successor()` with `last` to tell whether any key follows a range.
trait Successor {
    /// The value immediately after `self`, or `self` when it is already the maximum.
    fn successor(self) -> Self;
}
25
26impl Bounded for Ipv4Addr {
27    fn min_value() -> Self {
28        Ipv4Addr::new(0, 0, 0, 0)
29    }
30    fn max_value() -> Self {
31        Ipv4Addr::new(255, 255, 255, 255)
32    }
33}
34
35impl Successor for Ipv4Addr {
36    fn successor(self) -> Self {
37        let n: u32 = self.into();
38        Ipv4Addr::from(n.saturating_add(1))
39    }
40}
41
42impl Bounded for Ipv6Addr {
43    fn min_value() -> Self {
44        Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)
45    }
46    fn max_value() -> Self {
47        Ipv6Addr::new(
48            0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
49        )
50    }
51}
52
53impl Successor for Ipv6Addr {
54    fn successor(self) -> Self {
55        let n: u128 = self.into();
56        Ipv6Addr::from(n.saturating_add(1))
57    }
58}
59
60impl Bounded for u16 {
61    fn min_value() -> Self {
62        0
63    }
64    fn max_value() -> Self {
65        u16::MAX
66    }
67}
68
69impl Successor for u16 {
70    fn successor(self) -> Self {
71        self.saturating_add(1)
72    }
73}
74
75// ── IntervalMap ──────────────────────────────────────────────────────
76
/// A sorted interval map where each entry means "from this key onward, flags
/// are this value". The entire key space defaults to flags=0 (allowed).
///
/// `add_rule` applies last-applied-wins semantics for overlapping ranges.
///
/// The `Ord + Clone + Bounded + Successor` requirements live on the `impl`
/// block that needs them rather than on the type definition itself.
#[derive(Debug, Clone)]
struct IntervalMap<K> {
    /// Sorted breakpoints: from this key onward, flags are the stored value.
    ///
    /// `new` seeds this with the minimum key, so every lookup finds a
    /// breakpoint at or below it.
    map: BTreeMap<K, u32>,
}
86
87impl<K: Ord + Clone + Bounded + Successor> IntervalMap<K> {
88    fn new() -> Self {
89        let mut map = BTreeMap::new();
90        // Entire space starts at 0 (allowed)
91        map.insert(K::min_value(), 0);
92        Self { map }
93    }
94
95    /// Set flags for the range `[first, last]`.
96    fn add_rule(&mut self, first: K, last: K, flags: u32) {
97        if first > last {
98            return;
99        }
100
101        // Save the flags that were in effect at `last.successor()` before we modify anything,
102        // so we can restore them after the range.
103        let after_key = last.clone().successor();
104        let flags_after = self.access(&after_key);
105
106        // Save the flags that were in effect just before `first`.
107        // We need this in case first == K::min_value().
108        // Actually we just need to set the breakpoint at `first` to `flags`.
109
110        // Remove all breakpoints strictly between first (exclusive) and after_key (exclusive)
111        // We collect keys to remove to avoid borrowing issues
112        let keys_to_remove: Vec<K> = self
113            .map
114            .range(first.clone()..after_key.clone())
115            .map(|(k, _)| k.clone())
116            .collect();
117        for k in keys_to_remove {
118            self.map.remove(&k);
119        }
120
121        // Set the start of our range
122        self.map.insert(first, flags);
123
124        // Restore the flags after our range (only if after_key is still in bounds)
125        if after_key > last {
126            self.map.insert(after_key, flags_after);
127        }
128
129        self.minimize();
130    }
131
132    /// Look up flags for a key. O(log n).
133    fn access(&self, key: &K) -> u32 {
134        self.map
135            .range(..=key.clone())
136            .next_back()
137            .map(|(_, &v)| v)
138            .unwrap_or(0)
139    }
140
141    /// Remove consecutive entries with the same flags.
142    fn minimize(&mut self) {
143        let mut prev_flags: Option<u32> = None;
144        let mut to_remove = Vec::new();
145
146        for (k, &flags) in &self.map {
147            if prev_flags == Some(flags) {
148                to_remove.push(k.clone());
149            }
150            prev_flags = Some(flags);
151        }
152
153        for k in to_remove {
154            self.map.remove(&k);
155        }
156    }
157
158    /// Number of breakpoints in the map.
159    fn num_ranges(&self) -> usize {
160        // Count segments with non-zero flags
161        let mut count = 0;
162        for &flags in self.map.values() {
163            if flags != 0 {
164                count += 1;
165            }
166        }
167        count
168    }
169
170    fn is_empty(&self) -> bool {
171        self.num_ranges() == 0
172    }
173}
174
175// ── IpFilter ─────────────────────────────────────────────────────────
176
177/// IP address filter supporting both IPv4 and IPv6 ranges.
178///
179/// Flags: 0 = allowed, non-zero = blocked.
180/// Local/private network addresses are always exempt from filtering.
181#[derive(Debug, Clone)]
182pub struct IpFilter {
183    v4: IntervalMap<Ipv4Addr>,
184    v6: IntervalMap<Ipv6Addr>,
185}
186
187impl IpFilter {
188    /// Create a new filter that allows everything.
189    pub fn new() -> Self {
190        Self {
191            v4: IntervalMap::new(),
192            v6: IntervalMap::new(),
193        }
194    }
195
196    /// Add a rule blocking (or allowing) a range of IP addresses.
197    ///
198    /// Both endpoints must be the same address family (both v4 or both v6).
199    /// Mixed families are silently ignored.
200    pub fn add_rule(&mut self, first: IpAddr, last: IpAddr, flags: u32) {
201        match (first, last) {
202            (IpAddr::V4(f), IpAddr::V4(l)) => self.v4.add_rule(f, l, flags),
203            (IpAddr::V6(f), IpAddr::V6(l)) => self.v6.add_rule(f, l, flags),
204            _ => {} // mixed families: ignore
205        }
206    }
207
208    /// Return the flags for an address. 0 = allowed.
209    pub fn access(&self, addr: IpAddr) -> u32 {
210        match addr {
211            IpAddr::V4(ip) => self.v4.access(&ip),
212            IpAddr::V6(ip) => self.v6.access(&ip),
213        }
214    }
215
216    /// Check if an address is blocked by the filter.
217    ///
218    /// Local/private network addresses (RFC 1918, loopback, link-local) are
219    /// always exempt and return `false` even if they fall within a blocked range.
220    pub fn is_blocked(&self, addr: IpAddr) -> bool {
221        if is_local_network(addr) {
222            return false;
223        }
224        self.access(addr) != 0
225    }
226
227    /// Total number of non-zero-flag ranges across both address families.
228    pub fn num_ranges(&self) -> usize {
229        self.v4.num_ranges() + self.v6.num_ranges()
230    }
231
232    /// True if no rules have been added.
233    pub fn is_empty(&self) -> bool {
234        self.v4.is_empty() && self.v6.is_empty()
235    }
236}
237
238impl Default for IpFilter {
239    fn default() -> Self {
240        Self::new()
241    }
242}
243
244// ── PortFilter ───────────────────────────────────────────────────────
245
246/// Port range filter.
247///
248/// Flags: 0 = allowed, non-zero = blocked.
249#[derive(Debug, Clone)]
250pub struct PortFilter {
251    ports: IntervalMap<u16>,
252}
253
254impl PortFilter {
255    /// Create a new filter that allows all ports.
256    pub fn new() -> Self {
257        Self {
258            ports: IntervalMap::new(),
259        }
260    }
261
262    /// Add a rule for a port range.
263    pub fn add_rule(&mut self, first: u16, last: u16, flags: u32) {
264        self.ports.add_rule(first, last, flags);
265    }
266
267    /// Return the flags for a port. 0 = allowed.
268    pub fn access(&self, port: u16) -> u32 {
269        self.ports.access(&port)
270    }
271
272    /// Check if a port is blocked.
273    pub fn is_blocked(&self, port: u16) -> bool {
274        self.access(port) != 0
275    }
276}
277
278impl Default for PortFilter {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
284// ── File Parsers ─────────────────────────────────────────────────────
285
286/// Errors from parsing IP filter files.
287#[derive(Debug, thiserror::Error)]
288pub enum IpFilterError {
289    /// An IP address could not be parsed.
290    #[error("invalid IP address on line {line}: {message}")]
291    InvalidAddress {
292        /// One-based line number in the filter file.
293        line: usize,
294        /// Parse error description.
295        message: String,
296    },
297
298    /// A line could not be parsed (wrong number of fields, etc.).
299    #[error("malformed line {line}: {message}")]
300    MalformedLine {
301        /// One-based line number in the filter file.
302        line: usize,
303        /// Description of the formatting problem.
304        message: String,
305    },
306}
307
308/// Parse an eMule `.dat` format blocklist.
309///
310/// Format: `first_ip - last_ip , level , description`
311/// Lines starting with `#` are comments.
312pub fn parse_dat(input: &str) -> Result<IpFilter, IpFilterError> {
313    let mut filter = IpFilter::new();
314
315    for (line_num, line) in input.lines().enumerate() {
316        let line = line.trim();
317        if line.is_empty() || line.starts_with('#') {
318            continue;
319        }
320
321        // Split on comma to get: "first_ip - last_ip", "level", "description"
322        let parts: Vec<&str> = line.splitn(3, ',').collect();
323        if parts.len() < 2 {
324            return Err(IpFilterError::MalformedLine {
325                line: line_num + 1,
326                message: "expected 'first_ip - last_ip , level , description'".into(),
327            });
328        }
329
330        // Parse IP range
331        let ip_range = parts[0].trim();
332        let ips: Vec<&str> = ip_range.splitn(2, '-').collect();
333        if ips.len() != 2 {
334            return Err(IpFilterError::MalformedLine {
335                line: line_num + 1,
336                message: "expected 'first_ip - last_ip'".into(),
337            });
338        }
339
340        let first: IpAddr = ips[0]
341            .trim()
342            .parse()
343            .map_err(
344                |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
345                    line: line_num + 1,
346                    message: e.to_string(),
347                },
348            )?;
349
350        let last: IpAddr = ips[1]
351            .trim()
352            .parse()
353            .map_err(
354                |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
355                    line: line_num + 1,
356                    message: e.to_string(),
357                },
358            )?;
359
360        // Parse level (flags)
361        let level: u32 = parts[1]
362            .trim()
363            .parse()
364            .map_err(|_| IpFilterError::MalformedLine {
365                line: line_num + 1,
366                message: "invalid level (expected integer)".into(),
367            })?;
368
369        filter.add_rule(first, last, level);
370    }
371
372    Ok(filter)
373}
374
375/// Parse a P2P plaintext format blocklist.
376///
377/// Format: `description:first_ip-last_ip`
378/// Lines starting with `#` are comments.
379pub fn parse_p2p(input: &str) -> Result<IpFilter, IpFilterError> {
380    let mut filter = IpFilter::new();
381
382    for (line_num, line) in input.lines().enumerate() {
383        let line = line.trim();
384        if line.is_empty() || line.starts_with('#') {
385            continue;
386        }
387
388        // Split on last ':' to separate description from IP range
389        let colon_pos = line
390            .rfind(':')
391            .ok_or_else(|| IpFilterError::MalformedLine {
392                line: line_num + 1,
393                message: "expected 'description:first_ip-last_ip'".into(),
394            })?;
395
396        let ip_range = &line[colon_pos + 1..];
397        let ips: Vec<&str> = ip_range.splitn(2, '-').collect();
398        if ips.len() != 2 {
399            return Err(IpFilterError::MalformedLine {
400                line: line_num + 1,
401                message: "expected 'first_ip-last_ip' after ':'".into(),
402            });
403        }
404
405        let first: IpAddr = ips[0]
406            .trim()
407            .parse()
408            .map_err(
409                |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
410                    line: line_num + 1,
411                    message: e.to_string(),
412                },
413            )?;
414
415        let last: IpAddr = ips[1]
416            .trim()
417            .parse()
418            .map_err(
419                |e: std::net::AddrParseError| IpFilterError::InvalidAddress {
420                    line: line_num + 1,
421                    message: e.to_string(),
422                },
423            )?;
424
425        // P2P format always blocks (flags=1)
426        filter.add_rule(first, last, 1);
427    }
428
429    Ok(filter)
430}
431
432// ── Tests ────────────────────────────────────────────────────────────
433
#[cfg(test)]
mod tests {
    use super::*;

    // Test 1: IntervalMap: a fresh map reports "allowed" (0) for every key.
    #[test]
    fn interval_map_empty_returns_zero() {
        let map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        assert_eq!(map.access(&Ipv4Addr::new(0, 0, 0, 0)), 0);
        assert_eq!(map.access(&Ipv4Addr::new(192, 168, 1, 1)), 0);
        assert_eq!(map.access(&Ipv4Addr::new(255, 255, 255, 255)), 0);
    }

    // Test 2: IntervalMap: one range; lookups inside, on both edges, and outside.
    #[test]
    fn interval_map_single_range() {
        let mut map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        map.add_rule(Ipv4Addr::new(10, 0, 0, 0), Ipv4Addr::new(10, 0, 0, 255), 1);

        // Inside range (both endpoints are inclusive)
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 0)), 1);
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 128)), 1);
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 255)), 1);

        // Outside range, including the addresses adjacent to each endpoint
        assert_eq!(map.access(&Ipv4Addr::new(9, 255, 255, 255)), 0);
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 1, 0)), 0);
        assert_eq!(map.access(&Ipv4Addr::new(192, 168, 1, 1)), 0);
    }

    // Test 3: IntervalMap: overlapping ranges — the rule applied last wins.
    #[test]
    fn interval_map_overlapping_last_wins() {
        let mut map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        // Block 10.0.0.0 - 10.0.0.255 with flags=1
        map.add_rule(Ipv4Addr::new(10, 0, 0, 0), Ipv4Addr::new(10, 0, 0, 255), 1);
        // Allow 10.0.0.100 - 10.0.0.200 with flags=0 (override)
        map.add_rule(
            Ipv4Addr::new(10, 0, 0, 100),
            Ipv4Addr::new(10, 0, 0, 200),
            0,
        );

        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 50)), 1); // still blocked
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 100)), 0); // allowed (override)
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 150)), 0); // allowed (override)
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 200)), 0); // allowed (override)
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 201)), 1); // blocked again
        assert_eq!(map.access(&Ipv4Addr::new(10, 0, 0, 255)), 1); // blocked
    }

    // Test 4: IpFilter IPv4: block a /24, verify addresses inside and outside.
    #[test]
    fn ip_filter_v4_block_range() {
        let mut filter = IpFilter::new();
        filter.add_rule(
            IpAddr::V4(Ipv4Addr::new(203, 0, 113, 0)),
            IpAddr::V4(Ipv4Addr::new(203, 0, 113, 255)),
            1,
        );

        // Inside blocked range (public IPs, not local)
        assert!(filter.is_blocked("203.0.113.0".parse().unwrap()));
        assert!(filter.is_blocked("203.0.113.128".parse().unwrap()));
        assert!(filter.is_blocked("203.0.113.255".parse().unwrap()));

        // Outside
        assert!(!filter.is_blocked("203.0.112.255".parse().unwrap()));
        assert!(!filter.is_blocked("203.0.114.0".parse().unwrap()));
        assert!(!filter.is_blocked("8.8.8.8".parse().unwrap()));
    }

    // Test 5: IpFilter IPv6: block a range, verify addresses inside and outside.
    #[test]
    fn ip_filter_v6_block_range() {
        let mut filter = IpFilter::new();
        filter.add_rule(
            IpAddr::V6("2001:db8::0".parse().unwrap()),
            IpAddr::V6("2001:db8::ffff".parse().unwrap()),
            1,
        );

        assert!(filter.is_blocked("2001:db8::1".parse().unwrap()));
        assert!(filter.is_blocked("2001:db8::ff".parse().unwrap()));
        assert!(!filter.is_blocked("2001:db9::1".parse().unwrap()));
    }

    // Test 6: Local network exemption: a blocked range doesn't affect RFC 1918/loopback.
    #[test]
    fn ip_filter_local_network_exempt() {
        let mut filter = IpFilter::new();
        // Block everything
        filter.add_rule(
            IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
            IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)),
            1,
        );
        filter.add_rule(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
            IpAddr::V6(Ipv6Addr::new(
                0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
            )),
            1,
        );

        // Local IPs are exempt
        assert!(!filter.is_blocked("127.0.0.1".parse().unwrap()));
        assert!(!filter.is_blocked("192.168.1.1".parse().unwrap()));
        assert!(!filter.is_blocked("10.0.0.1".parse().unwrap()));
        assert!(!filter.is_blocked("172.16.0.1".parse().unwrap()));
        assert!(!filter.is_blocked("::1".parse().unwrap()));

        // But the raw access() still shows blocked
        assert_eq!(filter.access("127.0.0.1".parse().unwrap()), 1);

        // Public IPs are blocked
        assert!(filter.is_blocked("8.8.8.8".parse().unwrap()));
        assert!(filter.is_blocked("2001:db8::1".parse().unwrap()));
    }

    // Test 7: num_ranges/is_empty track blocked ranges as rules are added and
    // later overridden.
    #[test]
    fn ip_filter_num_ranges() {
        let mut filter = IpFilter::new();
        assert_eq!(filter.num_ranges(), 0);
        assert!(filter.is_empty());

        filter.add_rule(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 255)),
            1,
        );
        assert_eq!(filter.num_ranges(), 1);
        assert!(!filter.is_empty());

        filter.add_rule(
            IpAddr::V4(Ipv4Addr::new(172, 16, 0, 0)),
            IpAddr::V4(Ipv4Addr::new(172, 16, 255, 255)),
            1,
        );
        assert_eq!(filter.num_ranges(), 2);

        // Re-adding the first range with flags=0 allows all of it again
        filter.add_rule(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 255)),
            0,
        );
        // First range is now allowed, so num_ranges drops
        assert_eq!(filter.num_ranges(), 1);
    }

    // Test 8: parse_dat: valid lines and comments; a malformed line is an error.
    #[test]
    fn parse_dat_valid() {
        let input = "\
# This is a comment
203.0.113.0 - 203.0.113.255 , 128 , Test range
198.51.100.0 - 198.51.100.255 , 1 , Another range
";
        let filter = parse_dat(input).unwrap();
        assert!(filter.is_blocked("203.0.113.50".parse().unwrap()));
        assert!(filter.is_blocked("198.51.100.1".parse().unwrap()));
        assert!(!filter.is_blocked("8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn parse_dat_malformed() {
        let input = "this is not a valid line";
        let err = parse_dat(input).unwrap_err();
        assert!(matches!(err, IpFilterError::MalformedLine { line: 1, .. }));
    }

    // Test 9: parse_p2p: valid lines and comments; an invalid IP is an error.
    #[test]
    fn parse_p2p_valid() {
        let input = "\
# P2P blocklist
Some Bad Range:203.0.113.0-203.0.113.255
Another Range:198.51.100.0-198.51.100.255
";
        let filter = parse_p2p(input).unwrap();
        assert!(filter.is_blocked("203.0.113.50".parse().unwrap()));
        assert!(filter.is_blocked("198.51.100.1".parse().unwrap()));
        assert!(!filter.is_blocked("8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn parse_p2p_invalid_ip() {
        let input = "Bad Range:999.999.999.999-203.0.113.255";
        let err = parse_p2p(input).unwrap_err();
        assert!(matches!(err, IpFilterError::InvalidAddress { line: 1, .. }));
    }

    // Test 10: PortFilter: block a port range, verify both edges and outside.
    #[test]
    fn port_filter_block_range() {
        let mut filter = PortFilter::new();
        filter.add_rule(6881, 6889, 1);

        assert!(filter.is_blocked(6881));
        assert!(filter.is_blocked(6885));
        assert!(filter.is_blocked(6889));
        assert!(!filter.is_blocked(6880));
        assert!(!filter.is_blocked(6890));
        assert!(!filter.is_blocked(80));
    }
}