use anyhow::format_err;
use anyhow::Error as AnyhowError;
use pdk_core::log::{debug, warn};
use thiserror::Error;
use crate::model::address_parser::{parse_address, AddressType};
use crate::model::network_address::Address::Unknown;
mod model;
/// Selects how an [`IpFilter`] interprets its configured address list.
#[derive(Clone, Copy, Debug)]
pub enum FilterType {
    /// Only addresses matched by the list are let through.
    Allow,
    /// Addresses matched by the list are rejected; all others are let through.
    Block,
}
/// A list of parsed addresses/ranges together with the policy
/// ([`FilterType`]) that decides whether a match means "allowed" or "blocked".
#[derive(Clone, Debug)]
pub struct IpFilter {
    /// Parsed entries the candidate address is checked against.
    ips: Vec<AddressType>,
    /// Whether a match in `ips` allows or blocks the candidate.
    filter_type: FilterType,
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum IpFilterError {
#[error("Invalid IP: {0}")]
InvalidIp(String),
}
/// Parses every entry of `ip_list` into an [`AddressType`].
///
/// All entries are attempted before the outcome is decided, so a failure
/// reports every bad entry rather than only the first one.
///
/// # Errors
///
/// Returns [`IpFilterError::InvalidIp`] when at least one entry cannot be
/// parsed. The payload lists the rejected inputs in their original order,
/// separated by single spaces. No partial result is returned in that case.
pub(crate) fn parse_ips<B: AsRef<str>>(ip_list: &[B]) -> Result<Vec<AddressType>, IpFilterError> {
    debug!("Parsing {} IP addresses/ranges", ip_list.len());

    // Single pass: successes and failures are sorted into their own vectors,
    // so nothing has to be unwrapped afterwards.
    let mut parsed = Vec::with_capacity(ip_list.len());
    let mut bad_ips: Vec<String> = Vec::new();
    for ip in ip_list {
        match parse_ip(ip.as_ref()) {
            Ok(address) => parsed.push(address),
            Err(err) => bad_ips.push(err.to_string()),
        }
    }

    if !bad_ips.is_empty() {
        // `join` builds the message with one allocation instead of
        // re-formatting the accumulated string for every element.
        let concatenated_bad_ips = bad_ips.join(" ");
        warn!("Failed to parse IPs: {concatenated_bad_ips}");
        return Err(IpFilterError::InvalidIp(concatenated_bad_ips));
    }

    debug!("Successfully parsed {} IP addresses/ranges", parsed.len());
    Ok(parsed)
}
/// Parses a single textual address via `parse_address`.
///
/// # Errors
///
/// Returns an error whose message is the input string itself when
/// `parse_address` yields `Unknown`.
fn parse_ip(ip: &str) -> Result<AddressType, AnyhowError> {
    let address = parse_address(ip);
    if address != Unknown {
        return Ok(address);
    }
    Err(format_err!("{ip}"))
}
impl IpFilter {
    /// Builds a filter of the given `filter_type` from textual entries.
    ///
    /// # Errors
    ///
    /// Propagates [`IpFilterError::InvalidIp`] from `parse_ips` when any entry
    /// of `ip_list` cannot be parsed.
    fn new<B: AsRef<str>>(ip_list: &[B], filter_type: FilterType) -> Result<Self, IpFilterError> {
        debug!(
            "Creating IP filter with {} addresses, type: {:?}",
            ip_list.len(),
            filter_type
        );
        let ips = parse_ips(ip_list)?;
        Ok(IpFilter { ips, filter_type })
    }

    /// Creates a filter that lets through only addresses matched by `ips`.
    ///
    /// # Errors
    ///
    /// Returns [`IpFilterError::InvalidIp`] when any entry cannot be parsed.
    pub fn allow<B: AsRef<str>>(ips: &[B]) -> Result<Self, IpFilterError> {
        Self::new(ips, FilterType::Allow)
    }

    /// Creates a filter that rejects addresses matched by `ips` and lets
    /// every other address through.
    ///
    /// # Errors
    ///
    /// Returns [`IpFilterError::InvalidIp`] when any entry cannot be parsed.
    pub fn block<B: AsRef<str>>(ips: &[B]) -> Result<Self, IpFilterError> {
        Self::new(ips, FilterType::Block)
    }

    /// Decides whether `ip` passes this filter.
    ///
    /// An input that cannot be parsed is always rejected (`false`), whatever
    /// the filter type. Otherwise the result is "matched by the list" for an
    /// allow filter and "not matched by the list" for a block filter.
    pub fn is_allowed(&self, ip: &str) -> bool {
        let parsed_ip = match parse_ip(ip) {
            Ok(address) => address,
            Err(_) => {
                warn!("Failed to parse IP address: {ip}");
                return false;
            }
        };

        let ip_in_list = self.ips.iter().any(|entry| entry.contains(&parsed_ip));
        let allowed = match self.filter_type {
            FilterType::Allow => ip_in_list,
            FilterType::Block => !ip_in_list,
        };
        debug!(
            "IP {} check result: allowed={}, filter_type={:?}, in_list={}",
            ip, allowed, self.filter_type, ip_in_list
        );
        allowed
    }
}
// The `given__when__then` test names below use double underscores, which is
// presumably why the `non_snake_case` lint is silenced for this module.
#[allow(non_snake_case)]
#[cfg(test)]
mod ip_filter_tests {
    use super::{parse_ips, IpFilter};

    /// An allow filter admits exactly the listed addresses.
    #[test]
    fn test_allow_with_valid_ips() {
        let listed = ["192.168.1.1", "10.0.0.2"];
        let allow_list = IpFilter::allow(&listed).expect("Should create allow filter");
        assert!(allow_list.is_allowed("192.168.1.1"));
        assert!(allow_list.is_allowed("10.0.0.2"));
        assert!(!allow_list.is_allowed("127.0.0.1"));
    }

    /// A single bad entry makes allow-filter construction fail.
    #[test]
    fn test_allow_with_invalid_ip() {
        let entries = ["192.168.1.1", "bad_ip"];
        assert!(IpFilter::allow(&entries).is_err());
    }

    /// A block filter rejects the listed address and admits others.
    #[test]
    fn test_block_with_valid_ips() {
        let listed = ["10.10.10.10"];
        let block_list = IpFilter::block(&listed).expect("Should create block filter");
        assert!(!block_list.is_allowed("10.10.10.10"));
        assert!(block_list.is_allowed("8.8.8.8"));
    }

    /// A single bad entry makes block-filter construction fail.
    #[test]
    fn test_block_with_invalid_ip() {
        let entries = ["not_an_ip", "192.0.2.6"];
        assert!(IpFilter::block(&entries).is_err());
    }

    /// Parsing owned strings fails as a whole when one entry is invalid.
    #[test]
    fn given_invalid_ip__when_creating_filter_with_ip_list__then_invalid_ip_prevents_creation_of_valid_ips(
    ) {
        let entries: Vec<String> = ["192.0.0.1", "invalid_ip", "8.8.8.8"]
            .iter()
            .map(|entry| entry.to_string())
            .collect();
        assert!(parse_ips(&entries).is_err())
    }

    /// Mixed IPv4/IPv6 single addresses all parse.
    #[test]
    fn given_valid_ips__when_parsing__then_returns_parsed_list() {
        let entries = ["192.168.1.1", "10.0.0.1", "::1"];
        let parsed_len = parse_ips(&entries).map(|list| list.len());
        assert_eq!(parsed_len.ok(), Some(3));
    }

    /// An empty input yields an empty, successful result.
    #[test]
    fn given_empty_list__when_parsing__then_returns_empty_list() {
        let entries: Vec<String> = Vec::new();
        let parsed_is_empty = parse_ips(&entries).map(|list| list.is_empty());
        assert_eq!(parsed_is_empty.ok(), Some(true));
    }

    /// CIDR notation is accepted for both IPv4 and IPv6.
    #[test]
    fn given_cidr_ranges__when_parsing__then_returns_parsed_list() {
        let entries = ["192.168.0.0/24", "10.0.0.0/8", "2001:db8::/32"];
        let parsed_len = parse_ips(&entries).map(|list| list.len());
        assert_eq!(parsed_len.ok(), Some(3));
    }

    /// Allow/block behavior for single IPv4 addresses.
    mod ipv4 {
        use crate::IpFilter;

        const ALLOWED_IP: &str = "192.0.0.2";
        const BLOCKED_IP: &str = "192.0.0.1";

        #[test]
        fn given_valid_ipv4__when_creating_blocking_filter__then_ip_gets_blocked() {
            let block_list = IpFilter::block(&[BLOCKED_IP]).unwrap();
            assert!(!block_list.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_valid_ipv4__when_creating_blocking_filter__then_other_valid_ips_doesnt_get_blocked(
        ) {
            let block_list = IpFilter::block(&[BLOCKED_IP]).unwrap();
            assert!(block_list.is_allowed(ALLOWED_IP));
        }

        #[test]
        fn given_empty_blocking_filter__then_all_ips_allowed() {
            let no_ips: [&str; 0] = [];
            let block_list = IpFilter::block(&no_ips).unwrap();
            assert!(block_list.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_valid_ipv4__when_creating_allowing_filter__then_ip_is_allowed() {
            let allow_list = IpFilter::allow(&[ALLOWED_IP]).unwrap();
            assert!(allow_list.is_allowed(ALLOWED_IP));
        }

        #[test]
        fn given_valid_ipv4__when_creating_allowing_filter__then_all_other_ips_are_blocked() {
            let allow_list = IpFilter::allow(&[ALLOWED_IP]).unwrap();
            assert!(!allow_list.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_empty_allow_filter__then_no_ip_is_allowed() {
            let no_ips: [&str; 0] = [];
            let allow_list = IpFilter::allow(&no_ips).unwrap();
            assert!(!allow_list.is_allowed(ALLOWED_IP));
        }
    }

    /// Allow/block behavior for single IPv6 addresses.
    mod ipv6 {
        use crate::IpFilter;

        const ALLOWED_IP: &str = "2001:db8:0:0:0:0:A:0";
        const BLOCKED_IP: &str = "2001:db8:0:0:0:0:A:A";

        #[test]
        fn given_valid_ipv6__when_creating_blocking_filter__then_ip_gets_blocked() {
            let block_list = IpFilter::block(&[BLOCKED_IP]).unwrap();
            assert!(!block_list.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_valid_ipv6__when_creating_blocking_filter__then_other_valid_ips_dont_get_blocked()
        {
            let block_list = IpFilter::block(&[BLOCKED_IP]).unwrap();
            assert!(block_list.is_allowed(ALLOWED_IP));
        }

        #[test]
        fn given_empty_blocking_filter__then_all_ips_allowed() {
            let no_ips: [&str; 0] = [];
            let block_list = IpFilter::block(&no_ips).unwrap();
            assert!(block_list.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_valid_ipv6__when_creating_allowing_filter__then_ip_is_allowed() {
            let allow_list = IpFilter::allow(&[ALLOWED_IP]).unwrap();
            assert!(allow_list.is_allowed(ALLOWED_IP));
        }

        #[test]
        fn given_valid_ipv6__when_creating_allowing_filter__then_all_other_ips_are_blocked() {
            let allow_list = IpFilter::allow(&[ALLOWED_IP]).unwrap();
            assert!(!allow_list.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_empty_allow_filter__then_no_ip_is_allowed() {
            let no_ips: [&str; 0] = [];
            let allow_list = IpFilter::allow(&no_ips).unwrap();
            assert!(!allow_list.is_allowed(ALLOWED_IP));
        }
    }

    /// Range (CIDR) entries match every address inside the range and nothing
    /// outside it.
    mod cidr_tests {
        use crate::IpFilter;

        #[test]
        fn given_block_filter__when_filtering_31_bit_mask_ipv4__then_two_addresses_blocked() {
            let block_list = IpFilter::block(&["192.168.0.0/31"]).unwrap();
            assert!(!block_list.is_allowed("192.168.0.0"));
            assert!(!block_list.is_allowed("192.168.0.1"));
            assert!(block_list.is_allowed("192.168.0.2"));
        }

        #[test]
        fn given_block_filter__when_filtering_128_bit_mask_ipv6__then_one_addresses_blocked() {
            let block_list = IpFilter::block(&["2001:db8:0:0:0:0:A:A/128"]).unwrap();
            assert!(!block_list.is_allowed("2001:db8:0:0:0:0:A:A"));
            assert!(block_list.is_allowed("2001:db8:0:0:0:0:A:B"));
            assert!(block_list.is_allowed("2001:db8:0:0:0:0:A:9"));
        }
    }
}