1#[cfg(feature = "cidr")]
4use cidr::{Ipv4Cidr, Ipv4Inet, Ipv6Cidr, Ipv6Inet};
5#[cfg(feature = "ipnet")]
6use ipnet::{Ipv4Net, Ipv6Net};
7#[cfg(feature = "ipnetwork")]
8use ipnetwork::{Ipv4Network, Ipv6Network};
9use num_traits::{CheckedShr, PrimInt, Unsigned, Zero};
10
11pub trait Prefix: Sized + std::fmt::Debug {
13 type R: Unsigned + PrimInt + Zero + CheckedShr;
15
16 fn repr(&self) -> Self::R;
19
20 fn prefix_len(&self) -> u8;
22
23 fn from_repr_len(repr: Self::R, len: u8) -> Self;
25
26 fn mask(&self) -> Self::R {
29 self.repr() & mask_from_prefix_len(self.prefix_len())
30 }
31
32 fn zero() -> Self {
34 Self::from_repr_len(Self::R::zero(), 0)
35 }
36
37 fn longest_common_prefix(&self, other: &Self) -> Self {
39 let a = self.mask();
40 let b = other.mask();
41 let len = ((a ^ b).leading_zeros() as u8)
42 .min(self.prefix_len())
43 .min(other.prefix_len());
44 let repr = a & mask_from_prefix_len(len);
45 Self::from_repr_len(repr, len)
46 }
47
48 fn contains(&self, other: &Self) -> bool {
51 if self.prefix_len() > other.prefix_len() {
52 return false;
53 }
54 other.repr() & mask_from_prefix_len(self.prefix_len()) == self.mask()
55 }
56
57 fn is_bit_set(&self, bit: u8) -> bool {
60 let mask = (!Self::R::zero())
61 .checked_shr(bit as u32)
62 .unwrap_or_else(Self::R::zero)
63 ^ (!Self::R::zero())
64 .checked_shr(1u32 + bit as u32)
65 .unwrap_or_else(Self::R::zero);
66 mask & self.mask() != Self::R::zero()
67 }
68
69 fn eq(&self, other: &Self) -> bool {
71 self.mask() == other.mask() && self.prefix_len() == other.prefix_len()
72 }
73}
74
75pub(crate) fn mask_from_prefix_len<R>(len: u8) -> R
76where
77 R: PrimInt + Zero,
78{
79 if len as u32 == R::zero().count_zeros() {
80 !R::zero()
81 } else if len == 0 {
82 R::zero()
83 } else {
84 !((!R::zero()) >> len as usize)
85 }
86}
87
#[cfg(feature = "ipnet")]
impl Prefix for Ipv4Net {
    type R = u32;

    /// The stored address as an integer, host bits included.
    fn repr(&self) -> u32 {
        u32::from(self.addr())
    }

    /// Forwards to ipnet's inherent `prefix_len`; in method-call syntax inherent methods take
    /// precedence over trait methods, so this is not a recursive call.
    fn prefix_len(&self) -> u8 {
        self.prefix_len()
    }

    /// Panics (via `unwrap`) if ipnet rejects `len`.
    fn from_repr_len(repr: u32, len: u8) -> Self {
        let addr = repr.into();
        Ipv4Net::new(addr, len).unwrap()
    }

    /// The network address (host bits cleared) as an integer.
    fn mask(&self) -> u32 {
        u32::from(self.network())
    }

    /// ipnet's `Default` value; the generic `zero` test checks it equals
    /// `from_repr_len(0, 0)`.
    fn zero() -> Self {
        Self::default()
    }

    /// Computed from the raw addresses: the length is capped at both prefix lengths, so host
    /// bits can neither extend it nor survive the final masking.
    fn longest_common_prefix(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.repr(), other.repr());
        let agreeing = (lhs ^ rhs).leading_zeros() as u8;
        let len = agreeing.min(self.prefix_len()).min(other.prefix_len());
        let network = lhs & mask_from_prefix_len::<u32>(len);
        Ipv4Net::new(network.into(), len).unwrap()
    }

    /// Forwards to ipnet's inherent `contains` (which accepts `&Ipv4Net`); it takes
    /// precedence over this trait method in method-call syntax.
    fn contains(&self, other: &Self) -> bool {
        self.contains(other)
    }
}
126
#[cfg(feature = "ipnet")]
impl Prefix for Ipv6Net {
    type R = u128;

    /// The stored address as an integer, host bits included.
    fn repr(&self) -> u128 {
        u128::from(self.addr())
    }

    /// Forwards to ipnet's inherent `prefix_len`; in method-call syntax inherent methods take
    /// precedence over trait methods, so this is not a recursive call.
    fn prefix_len(&self) -> u8 {
        self.prefix_len()
    }

    /// Panics (via `unwrap`) if ipnet rejects `len`.
    fn from_repr_len(repr: u128, len: u8) -> Self {
        let addr = repr.into();
        Ipv6Net::new(addr, len).unwrap()
    }

    /// The network address (host bits cleared) as an integer.
    fn mask(&self) -> u128 {
        u128::from(self.network())
    }

    /// ipnet's `Default` value; the generic `zero` test checks it equals
    /// `from_repr_len(0, 0)`.
    fn zero() -> Self {
        Self::default()
    }

    /// Computed from the raw addresses: the length is capped at both prefix lengths, so host
    /// bits can neither extend it nor survive the final masking.
    fn longest_common_prefix(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.repr(), other.repr());
        let agreeing = (lhs ^ rhs).leading_zeros() as u8;
        let len = agreeing.min(self.prefix_len()).min(other.prefix_len());
        let network = lhs & mask_from_prefix_len::<u128>(len);
        Ipv6Net::new(network.into(), len).unwrap()
    }

    /// Forwards to ipnet's inherent `contains` (which accepts `&Ipv6Net`); it takes
    /// precedence over this trait method in method-call syntax.
    fn contains(&self, other: &Self) -> bool {
        self.contains(other)
    }
}
165
#[cfg(feature = "ipnetwork")]
impl Prefix for Ipv4Network {
    type R = u32;

    /// The stored address as an integer; host bits are kept.
    fn repr(&self) -> u32 {
        u32::from(self.ip())
    }

    /// ipnetwork names the prefix length `prefix`.
    fn prefix_len(&self) -> u8 {
        self.prefix()
    }

    /// Panics (via `unwrap`) if ipnetwork rejects `len`.
    fn from_repr_len(repr: u32, len: u8) -> Self {
        let addr = repr.into();
        Ipv4Network::new(addr, len).unwrap()
    }

    /// The network address (host bits cleared) as an integer.
    fn mask(&self) -> u32 {
        u32::from(self.network())
    }
}
186
#[cfg(feature = "ipnetwork")]
impl Prefix for Ipv6Network {
    type R = u128;

    /// The stored address as an integer; host bits are kept.
    fn repr(&self) -> u128 {
        u128::from(self.ip())
    }

    /// ipnetwork names the prefix length `prefix`.
    fn prefix_len(&self) -> u8 {
        self.prefix()
    }

    /// Panics (via `unwrap`) if ipnetwork rejects `len`.
    fn from_repr_len(repr: u128, len: u8) -> Self {
        let addr = repr.into();
        Ipv6Network::new(addr, len).unwrap()
    }

    /// The network address (host bits cleared) as an integer.
    fn mask(&self) -> u128 {
        u128::from(self.network())
    }
}
207
#[cfg(feature = "cidr")]
impl Prefix for Ipv4Cidr {
    type R = u32;

    /// The first address of the network; an `Ipv4Cidr` carries no host bits.
    fn repr(&self) -> Self::R {
        u32::from(self.first_address())
    }

    fn prefix_len(&self) -> u8 {
        self.network_length()
    }

    /// Host bits are cleared before construction (`Ipv4Cidr::new` is expected to reject an
    /// address that still has them); panics (via `unwrap`) if construction fails.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        let network = repr & mask_from_prefix_len::<u32>(len);
        Self::new(network.into(), len).unwrap()
    }

    /// Identical to [`Prefix::repr`]: both are the network's first address.
    fn mask(&self) -> Self::R {
        self.repr()
    }
}
229
#[cfg(feature = "cidr")]
impl Prefix for Ipv6Cidr {
    type R = u128;

    /// The first address of the network; an `Ipv6Cidr` carries no host bits.
    fn repr(&self) -> Self::R {
        u128::from(self.first_address())
    }

    fn prefix_len(&self) -> u8 {
        self.network_length()
    }

    /// Host bits are cleared before construction (`Ipv6Cidr::new` is expected to reject an
    /// address that still has them); panics (via `unwrap`) if construction fails.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        let network = repr & mask_from_prefix_len::<u128>(len);
        Self::new(network.into(), len).unwrap()
    }

    /// Identical to [`Prefix::repr`]: both are the network's first address.
    fn mask(&self) -> Self::R {
        self.repr()
    }
}
251
#[cfg(feature = "cidr")]
impl Prefix for Ipv4Inet {
    type R = u32;

    /// The full interface address as an integer, host bits included.
    fn repr(&self) -> Self::R {
        u32::from(self.address())
    }

    fn prefix_len(&self) -> u8 {
        self.network_length()
    }

    /// Keeps any host bits in `repr`; panics (via `unwrap`) if `Ipv4Inet::new` rejects `len`.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        let addr = repr.into();
        Self::new(addr, len).unwrap()
    }

    /// First address of the enclosing network, i.e. the address with host bits cleared.
    fn mask(&self) -> Self::R {
        let network = self.network();
        u32::from(network.first_address())
    }
}
272
#[cfg(feature = "cidr")]
impl Prefix for Ipv6Inet {
    type R = u128;

    /// The full interface address as an integer, host bits included.
    fn repr(&self) -> Self::R {
        u128::from(self.address())
    }

    fn prefix_len(&self) -> u8 {
        self.network_length()
    }

    /// Keeps any host bits in `repr`; panics (via `unwrap`) if `Ipv6Inet::new` rejects `len`.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        let addr = repr.into();
        Self::new(addr, len).unwrap()
    }

    /// First address of the enclosing network, i.e. the address with host bits cleared.
    fn mask(&self) -> Self::R {
        let network = self.network();
        u128::from(network.first_address())
    }
}
293
294impl<R> Prefix for (R, u8)
295where
296 R: Unsigned + PrimInt + Zero + CheckedShr + std::fmt::Debug,
297{
298 type R = R;
299
300 fn repr(&self) -> R {
301 self.0
302 }
303
304 fn prefix_len(&self) -> u8 {
305 self.1
306 }
307
308 fn from_repr_len(repr: R, len: u8) -> Self {
309 (repr, len)
310 }
311}
312
/// Unit tests for [`Prefix`] and [`mask_from_prefix_len`].
///
/// The whole module requires the `ipnet` feature: the non-generic tests build `Ipv4Net`
/// values through `pfx!`, and the generic tests are always instantiated for the ipnet types.
#[cfg(test)]
#[cfg(feature = "ipnet")]
mod test {
    use super::*;

    /// Parse a string literal into an `Ipv4Net`, panicking on malformed input.
    macro_rules! pfx {
        ($p:literal) => {
            $p.parse::<Ipv4Net>().unwrap()
        };
    }

    #[test]
    fn mask_from_len() {
        // Masks are left-aligned: `len` ones followed by zeros, including both extremes.
        assert_eq!(mask_from_prefix_len::<u8>(3), 0b11100000);
        assert_eq!(mask_from_prefix_len::<u8>(5), 0b11111000);
        assert_eq!(mask_from_prefix_len::<u8>(8), 0b11111111);
        assert_eq!(mask_from_prefix_len::<u8>(0), 0b00000000);

        assert_eq!(mask_from_prefix_len::<u32>(0), 0x00000000);
        assert_eq!(mask_from_prefix_len::<u32>(8), 0xff000000);
        assert_eq!(mask_from_prefix_len::<u32>(16), 0xffff0000);
        assert_eq!(mask_from_prefix_len::<u32>(24), 0xffffff00);
        assert_eq!(mask_from_prefix_len::<u32>(32), 0xffffffff);
    }

    #[test]
    fn prefix_mask() {
        // For `Ipv4Net`, `repr` keeps the host bits (the `.1`), `mask` clears them.
        let addr = pfx!("10.1.0.0/8");
        assert_eq!(Prefix::prefix_len(&addr), 8);
        assert_eq!(Prefix::repr(&addr), (10 << 24) + (1 << 16));
        assert_eq!(Prefix::mask(&addr), 10u32 << 24);
    }

    #[test]
    fn contains() {
        // `larger`/`smaller` refer to the prefix length (/9 vs /8), not the network size.
        // The `_c` variants carry host bits, which must not affect containment. Method-call
        // syntax here resolves to ipnet's inherent `contains`, which takes precedence over
        // `Prefix::contains`.
        let larger = pfx!("10.128.0.0/9");
        let smaller = pfx!("10.0.0.0/8");
        let larger_c = pfx!("10.130.2.5/9");
        let smaller_c = pfx!("10.25.2.8/8");
        assert!(smaller.contains(&larger));
        assert!(smaller.contains(&larger_c));
        assert!(smaller_c.contains(&larger));
        assert!(smaller_c.contains(&larger_c));
        assert!(!larger.contains(&smaller));
        assert!(!larger.contains(&smaller_c));
        assert!(!larger_c.contains(&smaller));
        assert!(!larger_c.contains(&smaller_c));
        assert!(smaller.contains(&smaller));
        assert!(smaller.contains(&smaller_c));
        assert!(smaller_c.contains(&smaller));
        assert!(smaller_c.contains(&smaller_c));
    }

    #[test]
    fn longest_common_prefix() {
        // Checks both argument orders: the result must be symmetric.
        macro_rules! assert_lcp {
            ($a:literal, $b:literal, $c:literal) => {
                assert_eq!(pfx!($a).longest_common_prefix(&pfx!($b)), pfx!($c));
                assert_eq!(pfx!($b).longest_common_prefix(&pfx!($a)), pfx!($c));
            };
        }
        // Addresses diverge inside the prefix: length is the number of agreeing leading bits.
        assert_lcp!("1.2.3.4/24", "1.3.3.4/24", "1.2.0.0/15");
        assert_lcp!("1.2.3.4/24", "1.1.3.4/24", "1.0.0.0/14");
        // Identical addresses: length is capped at the shorter prefix, host bits cleared.
        assert_lcp!("1.2.3.4/24", "1.2.3.4/30", "1.2.3.0/24");
    }

    #[test]
    fn is_bit_set() {
        // Bit 0 is the most significant bit; bits past the prefix length read as unset even
        // when the stored address has them set (last case).
        assert!(pfx!("255.0.0.0/8").is_bit_set(0));
        assert!(pfx!("255.0.0.0/8").is_bit_set(7));
        assert!(!pfx!("255.0.0.0/8").is_bit_set(8));
        assert!(!pfx!("255.255.0.0/8").is_bit_set(8));
    }

    /// Tests generic over every `Prefix` implementor. Each `#[instantiate_tests(<T>)]` module
    /// at the bottom is expected (per the `generic_tests` crate) to run them with `P = T`.
    /// Inputs are written as 32-bit values; `new` adapts them to wider representations.
    #[generic_tests::define]
    mod t {
        use num_traits::NumCast;

        use super::*;

        /// Build a `P` from a 32-bit value and a length relative to those 32 bits.
        ///
        /// For a wider `R` (e.g. `u128`), the value lands in the least significant 32 bits and
        /// the length is offset by `width - 32`, so the prefix still ends at the same bit of
        /// the 32-bit window.
        fn new<P: Prefix>(repr: u32, len: u8) -> P {
            let repr = <<P as Prefix>::R as NumCast>::from(repr).unwrap();
            let num_zeros = <<P as Prefix>::R as Zero>::zero().count_zeros() as u8;
            let len = len + (num_zeros - 32);
            P::from_repr_len(repr, len)
        }

        #[test]
        fn repr_len<P: Prefix>() {
            // Values without host bits round-trip through `from_repr_len` for every implementor.
            for x in [0x01000000u32, 0x010f0000u32, 0xffff0000u32] {
                let repr = <<P as Prefix>::R as NumCast>::from(x).unwrap();
                let num_zeros = <<P as Prefix>::R as Zero>::zero().count_zeros() as u8;
                let len = 16 + (num_zeros - 32);
                let prefix = P::from_repr_len(repr, len);
                assert!(prefix.repr() == repr);
                assert!(prefix.prefix_len() == len);
            }
        }

        #[test]
        fn keep_host_addr<P: Prefix + 'static>() {
            // Most implementors keep host bits in `repr`; the cidr `Ipv4Cidr`/`Ipv6Cidr` types
            // clear them, so the expected value is masked for those.
            // The `allow`s are needed when the `cidr` feature is off and the block below is
            // compiled out, leaving `prefix_is_masked` never reassigned.
            #[allow(unused_mut)]
            #[allow(unused_assignments)]
            let mut prefix_is_masked = false;
            #[cfg(feature = "cidr")]
            {
                let p_id = std::any::TypeId::of::<P>();
                prefix_is_masked = p_id == std::any::TypeId::of::<cidr::Ipv4Cidr>()
                    || p_id == std::any::TypeId::of::<cidr::Ipv6Cidr>();
            }
            let mask = 0xffff0000u32;
            for mut x in [0x01001234u32, 0x010fabcdu32, 0xffff5678u32] {
                let prefix: P = new(x, 16);
                if prefix_is_masked {
                    x &= mask;
                }
                assert_eq!(<u32 as NumCast>::from(prefix.repr()), Some(x));
            }
        }

        #[test]
        fn mask<P: Prefix>() {
            // `mask` always clears host bits, whatever `repr` keeps.
            let mask = 0xffff0000u32;
            for x in [0x01001234u32, 0x010fabcdu32, 0xffff5678u32] {
                let prefix: P = new(x, 16);
                assert_eq!(<u32 as NumCast>::from(prefix.mask()), Some(x & mask));
            }
        }

        #[test]
        fn zero<P: Prefix>() {
            // `Prefix::zero` (default or overridden) must equal the all-zero /0 prefix.
            let prefix = P::from_repr_len(P::R::zero(), 0);
            assert!(P::zero().eq(&prefix));
        }

        #[test]
        fn longest_common_prefix<P: Prefix>() {
            // Cases: addresses diverging at bit 15; identical addresses with different lengths
            // (result capped at the shorter length).
            for ((a, al), (b, bl), (c, cl)) in [
                ((0x01020304, 24), (0x01030304, 24), (0x01020000, 15)),
                ((0x12345678, 24), (0x12345678, 16), (0x12340000, 16)),
            ] {
                let a: P = new(a, al);
                let b: P = new(b, bl);
                let c: P = new(c, cl);
                let lcp = a.longest_common_prefix(&b);
                assert!(lcp.repr() == c.repr());
                assert!(lcp.prefix_len() == c.prefix_len());
            }
        }

        #[test]
        fn contains<P: Prefix>() {
            // Host bits of the outer prefix are irrelevant; equal lengths can contain; a longer
            // prefix never contains a shorter one.
            assert!(new::<P>(0x01020000, 16).contains(&new(0x0102ffff, 24)));
            assert!(new::<P>(0x01020304, 16).contains(&new(0x0102ffff, 24)));
            assert!(new::<P>(0x01020304, 16).contains(&new(0x0102ffff, 16)));
            assert!(!new::<P>(0x01020304, 24).contains(&new(0x0102ffff, 16)));
        }

        #[test]
        fn is_bit_set<P: Prefix>() {
            // Walk 64 bit positions starting at the 32-bit window (shifted by `offset` for
            // wider types). Only the first 16 (the prefix) may be set; everything after,
            // including indices past the type's width for `u32`, must read as unset.
            let x = 0x12345678u32;
            let num_zeros = <<P as Prefix>::R as Zero>::zero().count_zeros() as u8;
            let offset = num_zeros - 32;
            let p: P = new(x, 16);
            for i in 0..64 {
                let j = i + offset;
                if i >= 16 {
                    assert!(!p.is_bit_set(j))
                } else {
                    let mask = 0x80000000u32 >> i;
                    assert_eq!(p.is_bit_set(j), x & mask != 0)
                }
            }
        }

        #[instantiate_tests(<Ipv4Net>)]
        mod ipv4net {}

        #[instantiate_tests(<Ipv6Net>)]
        mod ipv6net {}

        // The following instantiations additionally require their own crate features.
        #[cfg(feature = "ipnetwork")]
        #[instantiate_tests(<Ipv4Network>)]
        mod ipv4network {}

        #[cfg(feature = "ipnetwork")]
        #[instantiate_tests(<Ipv6Network>)]
        mod ipv6network {}

        #[cfg(feature = "cidr")]
        #[instantiate_tests(<Ipv4Cidr>)]
        mod ipv4cidr {}

        #[cfg(feature = "cidr")]
        #[instantiate_tests(<Ipv4Inet>)]
        mod ipv4inet {}

        #[cfg(feature = "cidr")]
        #[instantiate_tests(<Ipv6Cidr>)]
        mod ipv6cidr {}

        #[cfg(feature = "cidr")]
        #[instantiate_tests(<Ipv6Inet>)]
        mod ipv6inet {}

        // Plain `(integer, length)` tuples, including a width (`u64`) with no IP counterpart.
        #[instantiate_tests(<(u32, u8)>)]
        mod u32_u8 {}

        #[instantiate_tests(<(u64, u8)>)]
        mod u64_u8 {}
    }
}