use std::net::IpAddr;
use ipnet::IpNet;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use crate::WafDecision;
/// Selects how [`IpFilter`] treats client addresses.
///
/// Serialized in `snake_case`: `"allow"`, `"deny"`, `"off"`.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum IpFilterMode {
    /// Only addresses matching the allow list pass; everything else is blocked.
    Allow,
    /// Addresses matching the deny list are blocked unless the allow list also matches them.
    Deny,
    /// Filtering disabled: `IpFilter::check` never produces a decision. This is the default.
    #[default]
    Off,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpFilterConfig {
#[serde(default)]
pub mode: IpFilterMode,
#[serde(default)]
pub allow_list: Vec<String>,
#[serde(default)]
pub deny_list: Vec<String>,
}
impl Default for IpFilterConfig {
fn default() -> Self {
Self {
mode: IpFilterMode::Off,
allow_list: Vec::new(),
deny_list: Vec::new(),
}
}
}
/// Parsed form of an [`IpFilterConfig`], stored behind the filter's lock so
/// request-time checks never re-parse strings.
struct ParsedLists {
    /// Active filtering mode.
    mode: IpFilterMode,
    /// Networks parsed from `allow_list`; blank/invalid entries are already dropped.
    allow: Vec<IpNet>,
    /// Networks parsed from `deny_list`; blank/invalid entries are already dropped.
    deny: Vec<IpNet>,
}
/// Reloadable IP allow/deny filter used by the WAF.
///
/// The parsed lists live behind a `parking_lot::RwLock`: `check` takes a
/// shared read lock, while `reload` replaces the lists under the write lock.
pub struct IpFilter {
    // Current mode and parsed networks; replaced wholesale on reload.
    inner: RwLock<ParsedLists>,
}
/// Parses config strings into networks, preserving input order.
///
/// Each entry is trimmed first; blank entries are ignored silently. An entry
/// is tried as CIDR notation, then as a bare address (which becomes a
/// single-host `/32` or `/128` network). Anything else is logged with a
/// warning and skipped, so one typo never disables the whole list.
fn parse_networks(raw: &[String]) -> Vec<IpNet> {
    let mut networks = Vec::with_capacity(raw.len());
    for entry in raw {
        let entry = entry.trim();
        if entry.is_empty() {
            continue;
        }
        if let Ok(net) = entry.parse::<IpNet>() {
            networks.push(net);
            continue;
        }
        match entry.parse::<IpAddr>() {
            Ok(IpAddr::V4(v4)) => networks.push(IpNet::V4(ipnet::Ipv4Net::from(v4))),
            Ok(IpAddr::V6(v6)) => networks.push(IpNet::V6(ipnet::Ipv6Net::from(v6))),
            Err(_) => {
                tracing::warn!(input = entry, "invalid IP/CIDR in WAF ip_filter config");
            }
        }
    }
    networks
}
/// Returns `true` if `ip` falls inside any network in `nets`.
///
/// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) is also tested in its plain
/// IPv4 form. A dual-stack (IPv6) listener may report IPv4 clients in that
/// mapped form. Without this step, an IPv4 deny entry such as `10.0.0.1` would
/// not match `::ffff:10.0.0.1`, bypassing the deny list, and an IPv4 allow
/// entry would wrongly block that client. The address is always matched as-is
/// first, so entries written in mapped form keep working.
fn ip_in_list(ip: IpAddr, nets: &[IpNet]) -> bool {
    // Only `::ffff:0:0/96` addresses are unwrapped; deprecated IPv4-compatible
    // (`::a.b.c.d`) addresses are intentionally left as IPv6.
    let mapped_v4 = match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4),
        IpAddr::V4(_) => None,
    };
    nets.iter().any(|net| {
        net.contains(&ip) || matches!(mapped_v4, Some(v4) if net.contains(&v4))
    })
}
impl IpFilter {
pub fn new(config: IpFilterConfig) -> Self {
let parsed = ParsedLists {
allow: parse_networks(&config.allow_list),
deny: parse_networks(&config.deny_list),
mode: config.mode,
};
Self {
inner: RwLock::new(parsed),
}
}
pub fn reload(&self, config: IpFilterConfig) {
let mut guard = self.inner.write();
guard.allow = parse_networks(&config.allow_list);
guard.deny = parse_networks(&config.deny_list);
guard.mode = config.mode;
}
pub fn check(&self, ip: IpAddr) -> Option<WafDecision> {
let guard = self.inner.read();
match guard.mode {
IpFilterMode::Off => None,
IpFilterMode::Allow => {
if ip_in_list(ip, &guard.allow) {
None
} else {
Some(WafDecision::Block {
status: 403,
reason: format!("IP {ip} not in allow list"),
rule: "ip_filter_allow".into(),
})
}
}
IpFilterMode::Deny => {
if ip_in_list(ip, &guard.allow) {
return None;
}
if ip_in_list(ip, &guard.deny) {
Some(WafDecision::Block {
status: 403,
reason: format!("IP {ip} is in deny list"),
rule: "ip_filter_deny".into(),
})
} else {
None
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts borrowed string literals into the owned list type the config uses.
    fn strings(list: &[&str]) -> Vec<String> {
        list.iter().map(|s| (*s).to_owned()).collect()
    }

    /// Builds a filter directly from a mode and literal allow/deny lists.
    fn filter(mode: IpFilterMode, allow: &[&str], deny: &[&str]) -> IpFilter {
        IpFilter::new(IpFilterConfig {
            mode,
            allow_list: strings(allow),
            deny_list: strings(deny),
        })
    }

    /// True when `check` yields a block decision for `addr`.
    fn blocks(f: &IpFilter, addr: &str) -> bool {
        let ip: std::net::IpAddr = addr.parse().expect("test address must be valid");
        f.check(ip).is_some()
    }

    #[test]
    fn off_mode_allows_everything() {
        let f = IpFilter::new(IpFilterConfig::default());
        assert!(!blocks(&f, "1.2.3.4"));
        assert!(!blocks(&f, "::1"));
    }

    #[test]
    fn deny_mode_blocks_listed_ip() {
        let f = filter(IpFilterMode::Deny, &[], &["10.0.0.1"]);
        assert!(blocks(&f, "10.0.0.1"));
        assert!(!blocks(&f, "10.0.0.2"));
    }

    #[test]
    fn deny_mode_blocks_cidr_range() {
        let f = filter(IpFilterMode::Deny, &[], &["192.168.1.0/24"]);
        assert!(blocks(&f, "192.168.1.50"));
        assert!(blocks(&f, "192.168.1.255"));
        assert!(!blocks(&f, "192.168.2.1"));
    }

    #[test]
    fn deny_mode_allow_overrides_deny() {
        let f = filter(IpFilterMode::Deny, &["10.0.0.5"], &["10.0.0.0/24"]);
        assert!(!blocks(&f, "10.0.0.5"));
        assert!(blocks(&f, "10.0.0.6"));
    }

    #[test]
    fn allow_mode_blocks_unlisted() {
        let f = filter(IpFilterMode::Allow, &["203.0.113.0/24"], &[]);
        assert!(!blocks(&f, "203.0.113.10"));
        assert!(blocks(&f, "198.51.100.1"));
    }

    #[test]
    fn ipv6_single_ip() {
        let f = filter(IpFilterMode::Deny, &[], &["2001:db8::1"]);
        assert!(blocks(&f, "2001:db8::1"));
        assert!(!blocks(&f, "2001:db8::2"));
    }

    #[test]
    fn ipv6_cidr() {
        let f = filter(IpFilterMode::Deny, &[], &["fd00::/8"]);
        assert!(blocks(&f, "fd12:3456:789a::1"));
        assert!(!blocks(&f, "2001:db8::1"));
    }

    #[test]
    fn empty_lists_deny_mode_allows_all() {
        let f = filter(IpFilterMode::Deny, &[], &[]);
        assert!(!blocks(&f, "1.2.3.4"));
    }

    #[test]
    fn empty_allow_list_blocks_all() {
        let f = filter(IpFilterMode::Allow, &[], &[]);
        assert!(blocks(&f, "1.2.3.4"));
    }

    #[test]
    fn hot_reload() {
        let f = filter(IpFilterMode::Deny, &[], &["10.0.0.1"]);
        assert!(blocks(&f, "10.0.0.1"));
        f.reload(IpFilterConfig {
            mode: IpFilterMode::Off,
            allow_list: Vec::new(),
            deny_list: strings(&["10.0.0.1"]),
        });
        assert!(!blocks(&f, "10.0.0.1"));
    }

    #[test]
    fn invalid_entries_are_skipped() {
        let f = filter(IpFilterMode::Deny, &[], &["not-an-ip", "10.0.0.1"]);
        assert!(blocks(&f, "10.0.0.1"));
        assert!(!blocks(&f, "10.0.0.2"));
    }

    #[test]
    fn multiple_cidrs_in_deny() {
        let f = filter(IpFilterMode::Deny, &[], &["10.0.0.0/8", "172.16.0.0/12"]);
        assert!(blocks(&f, "10.1.2.3"));
        assert!(blocks(&f, "172.20.1.1"));
        assert!(!blocks(&f, "8.8.8.8"));
    }
}