1use anyhow::format_err;
33use anyhow::Error as AnyhowError;
34use pdk_core::log::{debug, warn};
35use thiserror::Error;
36
37use crate::model::address_parser::{parse_address, AddressType};
38use crate::model::network_address::Address::Unknown;
39
40mod model;
41
/// Policy an [`IpFilter`] applies to its configured address list.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum FilterType {
    /// Allow-list: an IP is allowed only if some configured entry contains it.
    Allow,
    /// Block-list: an IP is allowed only if no configured entry contains it.
    Block,
}
50
51#[derive(Debug, Clone)]
53pub struct IpFilter {
54 ips: Vec<AddressType>,
56 filter_type: FilterType,
58}
59
60#[derive(Debug, Error)]
62#[non_exhaustive]
63pub enum IpFilterError {
64 #[error("Invalid IP: {0}")]
66 InvalidIp(String),
67}
68
69pub(crate) fn parse_ips<B: AsRef<str>>(ip_list: &[B]) -> Result<Vec<AddressType>, IpFilterError> {
71 debug!("Parsing {} IP addresses/ranges", ip_list.len());
72 let (parsed, errors): (Vec<_>, Vec<_>) = ip_list
73 .iter()
74 .map(|ip| parse_ip(ip.as_ref()))
75 .partition(Result::is_ok);
76 if !errors.is_empty() {
77 let concatenated_bad_ips = errors
78 .into_iter()
79 .map(|result| result.err().unwrap().to_string())
80 .reduce(|err1, err2| format!("{err1} {err2}"))
81 .unwrap_or_default();
82 warn!("Failed to parse IPs: {concatenated_bad_ips}");
83 return Err(IpFilterError::InvalidIp(concatenated_bad_ips));
84 }
85
86 debug!("Successfully parsed {} IP addresses/ranges", parsed.len());
87 Ok(parsed
88 .into_iter()
89 .map(|result| result.unwrap())
90 .collect::<Vec<AddressType>>())
91}
92
93fn parse_ip(ip: &str) -> Result<AddressType, AnyhowError> {
94 let parsed = parse_address(ip);
95 if parsed == Unknown {
96 Err(format_err!("{ip}"))
97 } else {
98 Ok(parsed)
99 }
100}
101
102impl IpFilter {
103 fn new<B: AsRef<str>>(ip_list: &[B], filter_type: FilterType) -> Result<Self, IpFilterError> {
105 debug!(
106 "Creating IP filter with {} addresses, type: {:?}",
107 ip_list.len(),
108 filter_type
109 );
110 let parsed_ips = parse_ips(ip_list)?;
111 Ok(IpFilter {
112 ips: parsed_ips,
113 filter_type,
114 })
115 }
116
117 pub fn allow<B: AsRef<str>>(ips: &[B]) -> Result<Self, IpFilterError> {
120 Self::new(ips, FilterType::Allow)
121 }
122
123 pub fn block<B: AsRef<str>>(ips: &[B]) -> Result<Self, IpFilterError> {
126 Self::new(ips, FilterType::Block)
127 }
128
129 pub fn is_allowed(&self, ip: &str) -> bool {
131 let parsed_ip = parse_ip(ip);
132 if parsed_ip.is_err() {
133 warn!("Failed to parse IP address: {ip}");
134 return false;
135 }
136 let parsed_ip = parsed_ip.unwrap();
137 let ip_in_list = self.ips.iter().any(|ip| ip.contains(&parsed_ip));
138 let allowed = match self.filter_type {
139 FilterType::Allow => ip_in_list,
140 FilterType::Block => !ip_in_list,
141 };
142 debug!(
143 "IP {} check result: allowed={}, filter_type={:?}, in_list={}",
144 ip, allowed, self.filter_type, ip_in_list
145 );
146 allowed
147 }
148}
149
// Test names follow a `given_x__when_y__then_z` convention; the double underscores are
// rejected by rustc's `non_snake_case` lint, hence the allow.
#[allow(non_snake_case)]
#[cfg(test)]
mod ip_filter_tests {
    use super::{parse_ips, IpFilter};

    #[test]
    fn test_allow_with_valid_ips() {
        let ips = vec!["192.168.1.1", "10.0.0.2"];
        let filter = IpFilter::allow(&ips).expect("Should create allow filter");
        assert!(filter.is_allowed("192.168.1.1"));
        assert!(filter.is_allowed("10.0.0.2"));
        assert!(!filter.is_allowed("127.0.0.1"));
    }

    #[test]
    fn test_allow_with_invalid_ip() {
        let ips = vec!["192.168.1.1", "bad_ip"];
        let result = IpFilter::allow(&ips);
        assert!(result.is_err());
    }

    #[test]
    fn test_block_with_valid_ips() {
        let ips = vec!["10.10.10.10"];
        let filter = IpFilter::block(&ips).expect("Should create block filter");
        assert!(!filter.is_allowed("10.10.10.10"));
        assert!(filter.is_allowed("8.8.8.8"));
    }

    #[test]
    fn test_block_with_invalid_ip() {
        let ips = vec!["not_an_ip", "192.0.2.6"];
        let result = IpFilter::block(&ips);
        assert!(result.is_err());
    }

    #[test]
    fn given_invalid_ip__when_creating_filter_with_ip_list__then_invalid_ip_prevents_creation_of_valid_ips(
    ) {
        let ips: Vec<String> = ["192.0.0.1", "invalid_ip", "8.8.8.8"]
            .iter()
            .map(|&s| String::from(s))
            .collect();

        let parsed_ips = parse_ips(&ips);

        assert!(parsed_ips.is_err())
    }

    #[test]
    fn given_several_invalid_ips__when_parsing__then_error_lists_all_of_them_in_order() {
        let ips = vec!["bad_ip", "192.168.1.1", "not_an_ip"];
        let error = parse_ips(&ips).expect_err("Invalid entries must be rejected");
        assert_eq!(error.to_string(), "Invalid IP: bad_ip not_an_ip");
    }

    #[test]
    fn given_valid_ips__when_parsing__then_returns_parsed_list() {
        let ips = vec!["192.168.1.1", "10.0.0.1", "::1"];
        let result = parse_ips(&ips);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 3);
    }

    #[test]
    fn given_empty_list__when_parsing__then_returns_empty_list() {
        let ips: Vec<String> = vec![];
        let result = parse_ips(&ips);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn given_cidr_ranges__when_parsing__then_returns_parsed_list() {
        let ips = vec!["192.168.0.0/24", "10.0.0.0/8", "2001:db8::/32"];
        let result = parse_ips(&ips);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 3);
    }

    #[test]
    fn given_unparsable_ip__when_checking__then_rejected_by_both_filter_types() {
        let allow_filter = IpFilter::allow(&["192.168.1.1"]).unwrap();
        let block_filter = IpFilter::block(&["192.168.1.1"]).unwrap();
        assert!(!allow_filter.is_allowed("not_an_ip"));
        assert!(!block_filter.is_allowed("not_an_ip"));
    }

    mod ipv4 {
        use crate::IpFilter;

        const ALLOWED_IP: &str = "192.0.0.2";
        const BLOCKED_IP: &str = "192.0.0.1";

        #[test]
        fn given_valid_ipv4__when_creating_blocking_filter__then_ip_gets_blocked() {
            let filter = IpFilter::block(&[BLOCKED_IP]).unwrap();
            assert!(!filter.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_valid_ipv4__when_creating_blocking_filter__then_other_valid_ips_doesnt_get_blocked(
        ) {
            let filter = IpFilter::block(&[BLOCKED_IP]).unwrap();
            assert!(filter.is_allowed(ALLOWED_IP));
        }

        #[test]
        fn given_empty_blocking_filter__then_all_ips_allowed() {
            let filter = IpFilter::block(&[] as &[&str]).unwrap();
            assert!(filter.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_valid_ipv4__when_creating_allowing_filter__then_ip_is_allowed() {
            let filter = IpFilter::allow(&[ALLOWED_IP]).unwrap();
            assert!(filter.is_allowed(ALLOWED_IP));
        }

        #[test]
        fn given_valid_ipv4__when_creating_allowing_filter__then_all_other_ips_are_blocked() {
            let filter = IpFilter::allow(&[ALLOWED_IP]).unwrap();
            assert!(!filter.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_empty_allow_filter__then_no_ip_is_allowed() {
            let filter = IpFilter::allow(&[] as &[&str]).unwrap();
            assert!(!filter.is_allowed(ALLOWED_IP));
        }
    }

    mod ipv6 {
        use crate::IpFilter;

        const ALLOWED_IP: &str = "2001:db8:0:0:0:0:A:0";
        const BLOCKED_IP: &str = "2001:db8:0:0:0:0:A:A";

        #[test]
        fn given_valid_ipv6__when_creating_blocking_filter__then_ip_gets_blocked() {
            let filter = IpFilter::block(&[BLOCKED_IP]).unwrap();
            assert!(!filter.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_valid_ipv6__when_creating_blocking_filter__then_other_valid_ips_dont_get_blocked()
        {
            let filter = IpFilter::block(&[BLOCKED_IP]).unwrap();
            assert!(filter.is_allowed(ALLOWED_IP));
        }

        #[test]
        fn given_empty_blocking_filter__then_all_ips_allowed() {
            let filter = IpFilter::block(&[] as &[&str]).unwrap();
            assert!(filter.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_valid_ipv6__when_creating_allowing_filter__then_ip_is_allowed() {
            let filter = IpFilter::allow(&[ALLOWED_IP]).unwrap();
            assert!(filter.is_allowed(ALLOWED_IP));
        }

        #[test]
        fn given_valid_ipv6__when_creating_allowing_filter__then_all_other_ips_are_blocked() {
            let filter = IpFilter::allow(&[ALLOWED_IP]).unwrap();
            assert!(!filter.is_allowed(BLOCKED_IP));
        }

        #[test]
        fn given_empty_allow_filter__then_no_ip_is_allowed() {
            let filter = IpFilter::allow(&[] as &[&str]).unwrap();
            assert!(!filter.is_allowed(ALLOWED_IP));
        }
    }

    mod cidr_tests {
        use crate::IpFilter;

        #[test]
        fn given_block_filter__when_filtering_31_bit_mask_ipv4__then_two_addresses_blocked() {
            let filter = IpFilter::block(&["192.168.0.0/31"]).unwrap();

            assert!(!filter.is_allowed("192.168.0.0"));
            assert!(!filter.is_allowed("192.168.0.1"));
            assert!(filter.is_allowed("192.168.0.2"));
        }

        #[test]
        fn given_block_filter__when_filtering_128_bit_mask_ipv6__then_one_addresses_blocked() {
            let filter = IpFilter::block(&["2001:db8:0:0:0:0:A:A/128"]).unwrap();

            assert!(!filter.is_allowed("2001:db8:0:0:0:0:A:A"));
            assert!(filter.is_allowed("2001:db8:0:0:0:0:A:B"));
            assert!(filter.is_allowed("2001:db8:0:0:0:0:A:9"));
        }
    }
}