use std::collections::BTreeMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use crate::rate_limiter::is_local_network;
/// Key types with a smallest (and largest) representable value.
trait Bounded {
    /// The smallest key. `IntervalMap` always keeps a boundary at this key.
    fn min_value() -> Self;
    /// The largest key.
    // `IntervalMap` detects the top of the key space through `Successor`
    // saturating, so nothing calls this yet; it is kept for completeness.
    #[allow(dead_code)]
    fn max_value() -> Self;
}

/// Key types that can step to the next key in ascending order.
trait Successor {
    /// Returns the next key, saturating at the maximum: the result equals
    /// `self` when `self` is already the largest key.
    fn successor(self) -> Self;
}

impl Bounded for Ipv4Addr {
    fn min_value() -> Self {
        Ipv4Addr::UNSPECIFIED
    }
    fn max_value() -> Self {
        Ipv4Addr::BROADCAST
    }
}

impl Successor for Ipv4Addr {
    fn successor(self) -> Self {
        Ipv4Addr::from(u32::from(self).saturating_add(1))
    }
}

impl Bounded for Ipv6Addr {
    fn min_value() -> Self {
        Ipv6Addr::UNSPECIFIED
    }
    fn max_value() -> Self {
        Ipv6Addr::from(u128::MAX)
    }
}

impl Successor for Ipv6Addr {
    fn successor(self) -> Self {
        Ipv6Addr::from(u128::from(self).saturating_add(1))
    }
}

impl Bounded for u16 {
    fn min_value() -> Self {
        0
    }
    fn max_value() -> Self {
        u16::MAX
    }
}

impl Successor for u16 {
    fn successor(self) -> Self {
        self.saturating_add(1)
    }
}

/// Maps every key of `K`'s domain to a `u32` flag value.
///
/// The mapping is stored as sorted boundaries. An entry `(k, flags)` means
/// that keys from `k` up to, but excluding, the next entry's key carry
/// `flags`. There is always an entry at `K::min_value()`, so every key has a
/// defined value. `new` creates it, and neither `add_rule` nor `minimize`
/// ever leaves the first boundary missing.
#[derive(Debug, Clone)]
struct IntervalMap<K: Ord + Clone + Bounded + Successor> {
    map: BTreeMap<K, u32>,
}

impl<K: Ord + Clone + Bounded + Successor> IntervalMap<K> {
    /// Creates a map in which every key maps to `0`.
    fn new() -> Self {
        let mut map = BTreeMap::new();
        map.insert(K::min_value(), 0);
        Self { map }
    }

    /// Assigns `flags` to every key in the inclusive range `first..=last`,
    /// overriding earlier rules in that span.
    ///
    /// Keys outside the range keep their current value. An inverted range
    /// (`first > last`) is ignored.
    fn add_rule(&mut self, first: K, last: K, flags: u32) {
        if first > last {
            return;
        }
        // The key just past the range must keep the value it had before this
        // rule. When `last` is the largest key, `successor` saturates and
        // there is nothing beyond the range to preserve.
        let next = last.clone().successor();
        let resume = if next > last {
            let flags_after = self.access(&next);
            Some((next, flags_after))
        } else {
            None
        };
        // Every boundary inside the range is superseded. The upper bound is
        // inclusive on purpose: a half-open range ending at the (possibly
        // saturated) successor would miss a boundary sitting on the max key.
        let stale: Vec<K> = self
            .map
            .range(first.clone()..=last)
            .map(|(k, _)| k.clone())
            .collect();
        for k in stale {
            self.map.remove(&k);
        }
        self.map.insert(first, flags);
        if let Some((key, value)) = resume {
            self.map.insert(key, value);
        }
        self.minimize();
    }

    /// Returns the flags of the interval that contains `key`.
    fn access(&self, key: &K) -> u32 {
        self.map
            .range(..=key.clone())
            .next_back()
            .map_or(0, |(_, &flags)| flags)
    }

    /// Drops boundaries whose flags equal those of the preceding boundary,
    /// merging adjacent intervals that carry identical flags.
    fn minimize(&mut self) {
        let mut prev: Option<u32> = None;
        // `BTreeMap::retain` visits entries in ascending key order.
        self.map.retain(|_, flags| {
            let keep = prev != Some(*flags);
            prev = Some(*flags);
            keep
        });
    }

    /// Counts stored intervals whose flags are non-zero.
    ///
    /// Adjacent intervals with different non-zero flags count separately.
    fn num_ranges(&self) -> usize {
        self.map.values().filter(|&&flags| flags != 0).count()
    }

    /// Returns `true` when no key maps to non-zero flags.
    fn is_empty(&self) -> bool {
        self.num_ranges() == 0
    }
}
/// Address filter covering both the IPv4 and IPv6 address spaces.
///
/// Every address maps to a `u32` flag value, which starts at `0`. A non-zero
/// value counts as blocked in [`IpFilter::is_blocked`], unless
/// `is_local_network` reports the address as local.
#[derive(Debug, Clone)]
pub struct IpFilter {
    // Rules for IPv4 addresses.
    v4: IntervalMap<Ipv4Addr>,
    // Rules for IPv6 addresses.
    v6: IntervalMap<Ipv6Addr>,
}

impl IpFilter {
    /// Creates a filter in which every address maps to `0`.
    pub fn new() -> Self {
        let v4 = IntervalMap::new();
        let v6 = IntervalMap::new();
        IpFilter { v4, v6 }
    }

    /// Sets `flags` for every address in `first..=last`, overriding earlier
    /// rules in that span.
    ///
    /// Both endpoints must belong to the same address family. A mixed
    /// IPv4/IPv6 pair is ignored without reporting an error, and so is an
    /// inverted range.
    pub fn add_rule(&mut self, first: IpAddr, last: IpAddr, flags: u32) {
        match (first, last) {
            (IpAddr::V4(lo), IpAddr::V4(hi)) => self.v4.add_rule(lo, hi, flags),
            (IpAddr::V6(lo), IpAddr::V6(hi)) => self.v6.add_rule(lo, hi, flags),
            // A range cannot span two address families; drop the rule.
            (IpAddr::V4(_), IpAddr::V6(_)) | (IpAddr::V6(_), IpAddr::V4(_)) => {}
        }
    }

    /// Returns the raw flags stored for `addr`.
    ///
    /// The local-network exemption applied by `is_blocked` does not apply here.
    pub fn access(&self, addr: IpAddr) -> u32 {
        match addr {
            IpAddr::V4(a) => self.v4.access(&a),
            IpAddr::V6(a) => self.v6.access(&a),
        }
    }

    /// Returns `true` when `addr` has non-zero flags and is not a
    /// local-network address (as decided by `is_local_network`).
    pub fn is_blocked(&self, addr: IpAddr) -> bool {
        !is_local_network(addr) && self.access(addr) != 0
    }

    /// Total number of non-zero intervals across both address families.
    pub fn num_ranges(&self) -> usize {
        let v4_count = self.v4.num_ranges();
        let v6_count = self.v6.num_ranges();
        v4_count + v6_count
    }

    /// Returns `true` when neither address family has any non-zero interval.
    pub fn is_empty(&self) -> bool {
        self.num_ranges() == 0
    }
}

impl Default for IpFilter {
    fn default() -> Self {
        IpFilter::new()
    }
}
/// Port filter mapping every `u16` port to a `u32` flag value, which starts
/// at `0`. Any non-zero value counts as blocked.
#[derive(Debug, Clone)]
pub struct PortFilter {
    // Rules keyed by port number.
    ports: IntervalMap<u16>,
}

impl PortFilter {
    /// Creates a filter in which every port maps to `0` (nothing blocked).
    pub fn new() -> Self {
        PortFilter {
            ports: IntervalMap::new(),
        }
    }

    /// Sets `flags` for every port in `first..=last`, overriding earlier
    /// rules in that span. An inverted range (`first > last`) is ignored.
    pub fn add_rule(&mut self, first: u16, last: u16, flags: u32) {
        let ports = &mut self.ports;
        ports.add_rule(first, last, flags);
    }

    /// Returns the raw flags stored for `port`.
    pub fn access(&self, port: u16) -> u32 {
        let ports = &self.ports;
        ports.access(&port)
    }

    /// Returns `true` when `port` has non-zero flags.
    pub fn is_blocked(&self, port: u16) -> bool {
        self.access(port) > 0
    }
}

impl Default for PortFilter {
    fn default() -> Self {
        PortFilter::new()
    }
}
/// Errors produced while parsing a textual IP blocklist.
///
/// The parsers in this module report 1-based line numbers.
#[derive(Debug, thiserror::Error)]
pub enum IpFilterError {
    /// An address field on the given line could not be parsed.
    #[error("invalid IP address on line {line}: {message}")]
    InvalidAddress { line: usize, message: String },

    /// The given line does not have the layout the parser expects.
    #[error("malformed line {line}: {message}")]
    MalformedLine { line: usize, message: String },
}
pub fn parse_dat(input: &str) -> Result<IpFilter, IpFilterError> {
let mut filter = IpFilter::new();
for (line_num, line) in input.lines().enumerate() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let parts: Vec<&str> = line.splitn(3, ',').collect();
if parts.len() < 2 {
return Err(IpFilterError::MalformedLine {
line: line_num + 1,
message: "expected 'first_ip - last_ip , level , description'".into(),
});
}
let ip_range = parts[0].trim();
let ips: Vec<&str> = ip_range.splitn(2, '-').collect();
if ips.len() != 2 {
return Err(IpFilterError::MalformedLine {
line: line_num + 1,
message: "expected 'first_ip - last_ip'".into(),
});
}
let first: IpAddr = ips[0]
.trim()
.parse()
.map_err(
|e: std::net::AddrParseError| IpFilterError::InvalidAddress {
line: line_num + 1,
message: e.to_string(),
},
)?;
let last: IpAddr = ips[1]
.trim()
.parse()
.map_err(
|e: std::net::AddrParseError| IpFilterError::InvalidAddress {
line: line_num + 1,
message: e.to_string(),
},
)?;
let level: u32 = parts[1]
.trim()
.parse()
.map_err(|_| IpFilterError::MalformedLine {
line: line_num + 1,
message: "invalid level (expected integer)".into(),
})?;
filter.add_rule(first, last, level);
}
Ok(filter)
}
pub fn parse_p2p(input: &str) -> Result<IpFilter, IpFilterError> {
let mut filter = IpFilter::new();
for (line_num, line) in input.lines().enumerate() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let colon_pos = line
.rfind(':')
.ok_or_else(|| IpFilterError::MalformedLine {
line: line_num + 1,
message: "expected 'description:first_ip-last_ip'".into(),
})?;
let ip_range = &line[colon_pos + 1..];
let ips: Vec<&str> = ip_range.splitn(2, '-').collect();
if ips.len() != 2 {
return Err(IpFilterError::MalformedLine {
line: line_num + 1,
message: "expected 'first_ip-last_ip' after ':'".into(),
});
}
let first: IpAddr = ips[0]
.trim()
.parse()
.map_err(
|e: std::net::AddrParseError| IpFilterError::InvalidAddress {
line: line_num + 1,
message: e.to_string(),
},
)?;
let last: IpAddr = ips[1]
.trim()
.parse()
.map_err(
|e: std::net::AddrParseError| IpFilterError::InvalidAddress {
line: line_num + 1,
message: e.to_string(),
},
)?;
filter.add_rule(first, last, 1);
}
Ok(filter)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an IPv4 address from its four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> Ipv4Addr {
        Ipv4Addr::new(a, b, c, d)
    }

    /// Parses an address literal, panicking on malformed test input.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn interval_map_empty_returns_zero() {
        let map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        for addr in [v4(0, 0, 0, 0), v4(192, 168, 1, 1), v4(255, 255, 255, 255)].iter() {
            assert_eq!(map.access(addr), 0);
        }
    }

    #[test]
    fn interval_map_single_range() {
        let mut map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        map.add_rule(v4(10, 0, 0, 0), v4(10, 0, 0, 255), 1);
        for addr in [v4(10, 0, 0, 0), v4(10, 0, 0, 128), v4(10, 0, 0, 255)].iter() {
            assert_eq!(map.access(addr), 1);
        }
        for addr in [v4(9, 255, 255, 255), v4(10, 0, 1, 0), v4(192, 168, 1, 1)].iter() {
            assert_eq!(map.access(addr), 0);
        }
    }

    #[test]
    fn interval_map_overlapping_last_wins() {
        let mut map: IntervalMap<Ipv4Addr> = IntervalMap::new();
        map.add_rule(v4(10, 0, 0, 0), v4(10, 0, 0, 255), 1);
        map.add_rule(v4(10, 0, 0, 100), v4(10, 0, 0, 200), 0);
        // (last octet of 10.0.0.x, expected flags)
        let expected = [(50, 1), (100, 0), (150, 0), (200, 0), (201, 1), (255, 1)];
        for &(last_octet, flags) in expected.iter() {
            assert_eq!(map.access(&v4(10, 0, 0, last_octet)), flags);
        }
    }

    #[test]
    fn ip_filter_v4_block_range() {
        let mut filter = IpFilter::new();
        filter.add_rule(
            IpAddr::V4(v4(203, 0, 113, 0)),
            IpAddr::V4(v4(203, 0, 113, 255)),
            1,
        );
        for text in ["203.0.113.0", "203.0.113.128", "203.0.113.255"].iter() {
            assert!(filter.is_blocked(ip(text)));
        }
        for text in ["203.0.112.255", "203.0.114.0", "8.8.8.8"].iter() {
            assert!(!filter.is_blocked(ip(text)));
        }
    }

    #[test]
    fn ip_filter_v6_block_range() {
        let mut filter = IpFilter::new();
        filter.add_rule(ip("2001:db8::0"), ip("2001:db8::ffff"), 1);
        assert!(filter.is_blocked(ip("2001:db8::1")));
        assert!(filter.is_blocked(ip("2001:db8::ff")));
        assert!(!filter.is_blocked(ip("2001:db9::1")));
    }

    #[test]
    fn ip_filter_local_network_exempt() {
        let mut filter = IpFilter::new();
        filter.add_rule(
            IpAddr::V4(Ipv4Addr::UNSPECIFIED),
            IpAddr::V4(Ipv4Addr::BROADCAST),
            1,
        );
        filter.add_rule(
            IpAddr::V6(Ipv6Addr::UNSPECIFIED),
            IpAddr::V6(Ipv6Addr::from(u128::MAX)),
            1,
        );
        for text in ["127.0.0.1", "192.168.1.1", "10.0.0.1", "172.16.0.1", "::1"].iter() {
            assert!(!filter.is_blocked(ip(text)));
        }
        // The exemption lives in `is_blocked`; the stored flags are unchanged.
        assert_eq!(filter.access(ip("127.0.0.1")), 1);
        assert!(filter.is_blocked(ip("8.8.8.8")));
        assert!(filter.is_blocked(ip("2001:db8::1")));
    }

    #[test]
    fn ip_filter_num_ranges() {
        let mut filter = IpFilter::new();
        assert_eq!(filter.num_ranges(), 0);
        assert!(filter.is_empty());

        let (ten_lo, ten_hi) = (IpAddr::V4(v4(10, 0, 0, 0)), IpAddr::V4(v4(10, 0, 0, 255)));
        filter.add_rule(ten_lo, ten_hi, 1);
        assert_eq!(filter.num_ranges(), 1);
        assert!(!filter.is_empty());

        filter.add_rule(
            IpAddr::V4(v4(172, 16, 0, 0)),
            IpAddr::V4(v4(172, 16, 255, 255)),
            1,
        );
        assert_eq!(filter.num_ranges(), 2);

        // Re-adding the first range with flags 0 clears it again.
        filter.add_rule(ten_lo, ten_hi, 0);
        assert_eq!(filter.num_ranges(), 1);
    }

    #[test]
    fn parse_dat_valid() {
        let input = "\
# This is a comment
203.0.113.0 - 203.0.113.255 , 128 , Test range
198.51.100.0 - 198.51.100.255 , 1 , Another range
";
        let filter = parse_dat(input).unwrap();
        assert!(filter.is_blocked(ip("203.0.113.50")));
        assert!(filter.is_blocked(ip("198.51.100.1")));
        assert!(!filter.is_blocked(ip("8.8.8.8")));
    }

    #[test]
    fn parse_dat_malformed() {
        let err = parse_dat("this is not a valid line").unwrap_err();
        assert!(matches!(err, IpFilterError::MalformedLine { line: 1, .. }));
    }

    #[test]
    fn parse_p2p_valid() {
        let input = "\
# P2P blocklist
Some Bad Range:203.0.113.0-203.0.113.255
Another Range:198.51.100.0-198.51.100.255
";
        let filter = parse_p2p(input).unwrap();
        assert!(filter.is_blocked(ip("203.0.113.50")));
        assert!(filter.is_blocked(ip("198.51.100.1")));
        assert!(!filter.is_blocked(ip("8.8.8.8")));
    }

    #[test]
    fn parse_p2p_invalid_ip() {
        let err = parse_p2p("Bad Range:999.999.999.999-203.0.113.255").unwrap_err();
        assert!(matches!(err, IpFilterError::InvalidAddress { line: 1, .. }));
    }

    #[test]
    fn port_filter_block_range() {
        let mut filter = PortFilter::new();
        filter.add_rule(6881, 6889, 1);
        for port in [6881, 6885, 6889].iter().copied() {
            assert!(filter.is_blocked(port));
        }
        for port in [6880, 6890, 80].iter().copied() {
            assert!(!filter.is_blocked(port));
        }
    }
}