1use std::net::IpAddr;
9
10use ipnet::IpNet;
11use parking_lot::RwLock;
12use serde::{Deserialize, Serialize};
13
14use crate::WafDecision;
15
16#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
18#[serde(rename_all = "snake_case")]
19pub enum IpFilterMode {
20 Allow,
22 Deny,
24 #[default]
26 Off,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct IpFilterConfig {
32 #[serde(default)]
33 pub mode: IpFilterMode,
34 #[serde(default)]
36 pub allow_list: Vec<String>,
37 #[serde(default)]
39 pub deny_list: Vec<String>,
40}
41
42impl Default for IpFilterConfig {
43 fn default() -> Self {
44 Self {
45 mode: IpFilterMode::Off,
46 allow_list: Vec::new(),
47 deny_list: Vec::new(),
48 }
49 }
50}
51
52struct ParsedLists {
53 allow: Vec<IpNet>,
54 deny: Vec<IpNet>,
55 mode: IpFilterMode,
56}
57
58pub struct IpFilter {
60 inner: RwLock<ParsedLists>,
61}
62
63fn parse_networks(raw: &[String]) -> Vec<IpNet> {
64 raw.iter()
65 .filter_map(|s| {
66 let trimmed = s.trim();
67 if trimmed.is_empty() {
68 return None;
69 }
70 trimmed
72 .parse::<IpNet>()
73 .ok()
74 .or_else(|| {
75 trimmed.parse::<IpAddr>().ok().map(|addr| match addr {
76 IpAddr::V4(v4) => IpNet::V4(ipnet::Ipv4Net::from(v4)),
77 IpAddr::V6(v6) => IpNet::V6(ipnet::Ipv6Net::from(v6)),
78 })
79 })
80 .or_else(|| {
81 tracing::warn!(input = trimmed, "invalid IP/CIDR in WAF ip_filter config");
82 None
83 })
84 })
85 .collect()
86}
87
88fn ip_in_list(ip: IpAddr, nets: &[IpNet]) -> bool {
89 nets.iter().any(|net| net.contains(&ip))
90}
91
92impl IpFilter {
93 pub fn new(config: IpFilterConfig) -> Self {
94 let parsed = ParsedLists {
95 allow: parse_networks(&config.allow_list),
96 deny: parse_networks(&config.deny_list),
97 mode: config.mode,
98 };
99 Self {
100 inner: RwLock::new(parsed),
101 }
102 }
103
104 pub fn reload(&self, config: IpFilterConfig) {
106 let mut guard = self.inner.write();
107 guard.allow = parse_networks(&config.allow_list);
108 guard.deny = parse_networks(&config.deny_list);
109 guard.mode = config.mode;
110 }
111
112 pub fn check(&self, ip: IpAddr) -> Option<WafDecision> {
114 let guard = self.inner.read();
115
116 match guard.mode {
117 IpFilterMode::Off => None,
118 IpFilterMode::Allow => {
119 if ip_in_list(ip, &guard.allow) {
121 None
122 } else {
123 Some(WafDecision::Block {
124 status: 403,
125 reason: format!("IP {ip} not in allow list"),
126 rule: "ip_filter_allow".into(),
127 })
128 }
129 }
130 IpFilterMode::Deny => {
131 if ip_in_list(ip, &guard.allow) {
133 return None;
134 }
135 if ip_in_list(ip, &guard.deny) {
136 Some(WafDecision::Block {
137 status: 403,
138 reason: format!("IP {ip} is in deny list"),
139 rule: "ip_filter_deny".into(),
140 })
141 } else {
142 None
143 }
144 }
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned list form used by the config.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a filter for `mode` with the given allow and deny entries.
    fn build(mode: IpFilterMode, allow: &[&str], deny: &[&str]) -> IpFilter {
        IpFilter::new(IpFilterConfig {
            mode,
            allow_list: owned(allow),
            deny_list: owned(deny),
        })
    }

    /// Returns `true` when the filter produces a decision for `ip`.
    fn is_blocked(filter: &IpFilter, ip: &str) -> bool {
        let addr: IpAddr = ip.parse().unwrap();
        filter.check(addr).is_some()
    }

    #[test]
    fn off_mode_allows_everything() {
        let filter = IpFilter::new(IpFilterConfig::default());
        assert!(!is_blocked(&filter, "1.2.3.4"));
        assert!(!is_blocked(&filter, "::1"));
    }

    #[test]
    fn deny_mode_blocks_listed_ip() {
        let filter = build(IpFilterMode::Deny, &[], &["10.0.0.1"]);
        assert!(is_blocked(&filter, "10.0.0.1"));
        assert!(!is_blocked(&filter, "10.0.0.2"));
    }

    #[test]
    fn deny_mode_blocks_cidr_range() {
        let filter = build(IpFilterMode::Deny, &[], &["192.168.1.0/24"]);
        assert!(is_blocked(&filter, "192.168.1.50"));
        assert!(is_blocked(&filter, "192.168.1.255"));
        assert!(!is_blocked(&filter, "192.168.2.1"));
    }

    #[test]
    fn deny_mode_allow_overrides_deny() {
        let filter = build(IpFilterMode::Deny, &["10.0.0.5"], &["10.0.0.0/24"]);
        // Explicitly allowed address inside the denied range passes.
        assert!(!is_blocked(&filter, "10.0.0.5"));
        // Its neighbour in the same range is still denied.
        assert!(is_blocked(&filter, "10.0.0.6"));
    }

    #[test]
    fn allow_mode_blocks_unlisted() {
        let filter = build(IpFilterMode::Allow, &["203.0.113.0/24"], &[]);
        assert!(!is_blocked(&filter, "203.0.113.10"));
        assert!(is_blocked(&filter, "198.51.100.1"));
    }

    #[test]
    fn ipv6_single_ip() {
        let filter = build(IpFilterMode::Deny, &[], &["2001:db8::1"]);
        assert!(is_blocked(&filter, "2001:db8::1"));
        assert!(!is_blocked(&filter, "2001:db8::2"));
    }

    #[test]
    fn ipv6_cidr() {
        let filter = build(IpFilterMode::Deny, &[], &["fd00::/8"]);
        assert!(is_blocked(&filter, "fd12:3456:789a::1"));
        assert!(!is_blocked(&filter, "2001:db8::1"));
    }

    #[test]
    fn empty_lists_deny_mode_allows_all() {
        let filter = build(IpFilterMode::Deny, &[], &[]);
        assert!(!is_blocked(&filter, "1.2.3.4"));
    }

    #[test]
    fn empty_allow_list_blocks_all() {
        let filter = build(IpFilterMode::Allow, &[], &[]);
        assert!(is_blocked(&filter, "1.2.3.4"));
    }

    #[test]
    fn hot_reload() {
        let filter = build(IpFilterMode::Deny, &[], &["10.0.0.1"]);
        assert!(is_blocked(&filter, "10.0.0.1"));

        // Same deny list, but filtering switched off.
        filter.reload(IpFilterConfig {
            mode: IpFilterMode::Off,
            allow_list: Vec::new(),
            deny_list: owned(&["10.0.0.1"]),
        });
        assert!(!is_blocked(&filter, "10.0.0.1"));
    }

    #[test]
    fn invalid_entries_are_skipped() {
        let filter = build(IpFilterMode::Deny, &[], &["not-an-ip", "10.0.0.1"]);
        // The valid entry still takes effect despite the invalid one.
        assert!(is_blocked(&filter, "10.0.0.1"));
        assert!(!is_blocked(&filter, "10.0.0.2"));
    }

    #[test]
    fn multiple_cidrs_in_deny() {
        let filter = build(IpFilterMode::Deny, &[], &["10.0.0.0/8", "172.16.0.0/12"]);
        assert!(is_blocked(&filter, "10.1.2.3"));
        assert!(is_blocked(&filter, "172.20.1.1"));
        assert!(!is_blocked(&filter, "8.8.8.8"));
    }
}