prefix_trie/
prefix.rs

//! Definition of the `Prefix` trait and its implementations for IP network types and plain tuples.
2
3#[cfg(feature = "cidr")]
4use cidr::{Ipv4Cidr, Ipv4Inet, Ipv6Cidr, Ipv6Inet};
5#[cfg(feature = "ipnet")]
6use ipnet::{Ipv4Net, Ipv6Net};
7#[cfg(feature = "ipnetwork")]
8use ipnetwork::{Ipv4Network, Ipv6Network};
9use num_traits::{CheckedShr, PrimInt, Unsigned, Zero};
10
11/// Trait for defining prefixes.
12pub trait Prefix: Sized + std::fmt::Debug {
13    /// How can the prefix be represented. This must be one of `u8`, `u16`, `u32`, `u64`, or `u128`.
14    type R: Unsigned + PrimInt + Zero + CheckedShr;
15
16    /// Get raw representation of the address, ignoring the prefix length. This function must return
17    /// the representation with the mask already applied.
18    fn repr(&self) -> Self::R;
19
20    /// Prefix length
21    fn prefix_len(&self) -> u8;
22
23    /// Create a new prefix from the representation and the prefix pength.
24    fn from_repr_len(repr: Self::R, len: u8) -> Self;
25
26    /// mask `self.repr()` using `self.len()`. If you can guarantee that `repr` is already masked,
27    /// them simply re-implement this function for your type.
28    fn mask(&self) -> Self::R {
29        self.repr() & mask_from_prefix_len(self.prefix_len())
30    }
31
32    /// Create a prefix that matches everything
33    fn zero() -> Self {
34        Self::from_repr_len(Self::R::zero(), 0)
35    }
36
37    /// longest common prefix
38    fn longest_common_prefix(&self, other: &Self) -> Self {
39        let a = self.mask();
40        let b = other.mask();
41        let len = ((a ^ b).leading_zeros() as u8)
42            .min(self.prefix_len())
43            .min(other.prefix_len());
44        let repr = a & mask_from_prefix_len(len);
45        Self::from_repr_len(repr, len)
46    }
47
48    /// Check if `self` contains `other` in its prefix range. This function also returns `True` if
49    /// `self` is identical to `other`.
50    fn contains(&self, other: &Self) -> bool {
51        if self.prefix_len() > other.prefix_len() {
52            return false;
53        }
54        other.repr() & mask_from_prefix_len(self.prefix_len()) == self.mask()
55    }
56
57    /// Check if a specific bit is set (counted from the left, where 0 is the first bit from the
58    /// left).
59    fn is_bit_set(&self, bit: u8) -> bool {
60        let mask = (!Self::R::zero())
61            .checked_shr(bit as u32)
62            .unwrap_or_else(Self::R::zero)
63            ^ (!Self::R::zero())
64                .checked_shr(1u32 + bit as u32)
65                .unwrap_or_else(Self::R::zero);
66        mask & self.mask() != Self::R::zero()
67    }
68
69    /// Compare two prefixes together
70    fn eq(&self, other: &Self) -> bool {
71        self.mask() == other.mask() && self.prefix_len() == other.prefix_len()
72    }
73}
74
75pub(crate) fn mask_from_prefix_len<R>(len: u8) -> R
76where
77    R: PrimInt + Zero,
78{
79    if len as u32 == R::zero().count_zeros() {
80        !R::zero()
81    } else if len == 0 {
82        R::zero()
83    } else {
84        !((!R::zero()) >> len as usize)
85    }
86}
87
#[cfg(feature = "ipnet")]
impl Prefix for Ipv4Net {
    type R = u32;

    /// The stored address as an integer, host bits included.
    fn repr(&self) -> Self::R {
        u32::from(self.addr())
    }

    /// Forwards to the inherent `Ipv4Net::prefix_len`, which method resolution prefers over
    /// this trait method.
    fn prefix_len(&self) -> u8 {
        self.prefix_len()
    }

    /// Panics if `Ipv4Net::new` rejects `len`.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        Self::new(repr.into(), len).unwrap()
    }

    /// The network address, i.e. the representation with its host bits cleared.
    fn mask(&self) -> Self::R {
        u32::from(self.network())
    }

    /// The match-everything prefix, obtained through `Default`.
    fn zero() -> Self {
        Self::default()
    }

    /// Same algorithm as the trait default, working on the raw (unmasked) addresses; bits past
    /// either prefix length are cut off by the `min` calls.
    fn longest_common_prefix(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.repr(), other.repr());
        let agreeing_bits = (lhs ^ rhs).leading_zeros() as u8;
        let len = agreeing_bits.min(self.prefix_len()).min(other.prefix_len());
        let network = lhs & mask_from_prefix_len::<u32>(len);
        Self::new(network.into(), len).unwrap()
    }

    /// Forwards to the inherent `Ipv4Net::contains`.
    fn contains(&self, other: &Self) -> bool {
        self.contains(other)
    }
}
126
#[cfg(feature = "ipnet")]
impl Prefix for Ipv6Net {
    type R = u128;

    /// The stored address as an integer, host bits included.
    fn repr(&self) -> Self::R {
        u128::from(self.addr())
    }

    /// Forwards to the inherent `Ipv6Net::prefix_len`, which method resolution prefers over
    /// this trait method.
    fn prefix_len(&self) -> u8 {
        self.prefix_len()
    }

    /// Panics if `Ipv6Net::new` rejects `len`.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        Self::new(repr.into(), len).unwrap()
    }

    /// The network address, i.e. the representation with its host bits cleared.
    fn mask(&self) -> Self::R {
        u128::from(self.network())
    }

    /// The match-everything prefix, obtained through `Default`.
    fn zero() -> Self {
        Self::default()
    }

    /// Same algorithm as the trait default, working on the raw (unmasked) addresses; bits past
    /// either prefix length are cut off by the `min` calls.
    fn longest_common_prefix(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.repr(), other.repr());
        let agreeing_bits = (lhs ^ rhs).leading_zeros() as u8;
        let len = agreeing_bits.min(self.prefix_len()).min(other.prefix_len());
        let network = lhs & mask_from_prefix_len::<u128>(len);
        Self::new(network.into(), len).unwrap()
    }

    /// Forwards to the inherent `Ipv6Net::contains`.
    fn contains(&self, other: &Self) -> bool {
        self.contains(other)
    }
}
165
#[cfg(feature = "ipnetwork")]
impl Prefix for Ipv4Network {
    type R = u32;

    /// The stored address (`ip()`) as an integer, host bits included.
    fn repr(&self) -> Self::R {
        u32::from(self.ip())
    }

    /// `ipnetwork` calls the prefix length `prefix()`.
    fn prefix_len(&self) -> u8 {
        self.prefix()
    }

    /// Panics if `Ipv4Network::new` rejects `len`.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        Self::new(repr.into(), len).unwrap()
    }

    /// The network address, i.e. the representation with its host bits cleared.
    fn mask(&self) -> Self::R {
        u32::from(self.network())
    }
}
186
#[cfg(feature = "ipnetwork")]
impl Prefix for Ipv6Network {
    type R = u128;

    /// The stored address (`ip()`) as an integer, host bits included.
    fn repr(&self) -> Self::R {
        u128::from(self.ip())
    }

    /// `ipnetwork` calls the prefix length `prefix()`.
    fn prefix_len(&self) -> u8 {
        self.prefix()
    }

    /// Panics if `Ipv6Network::new` rejects `len`.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        Self::new(repr.into(), len).unwrap()
    }

    /// The network address, i.e. the representation with its host bits cleared.
    fn mask(&self) -> Self::R {
        u128::from(self.network())
    }
}
207
#[cfg(feature = "cidr")]
impl Prefix for Ipv4Cidr {
    type R = u32;

    /// The first address of the network; a CIDR value carries no host bits.
    fn repr(&self) -> Self::R {
        u32::from(self.first_address())
    }

    fn prefix_len(&self) -> u8 {
        self.network_length()
    }

    /// Clears the host bits before construction, since CIDR values are always stored masked.
    /// Panics if `Ipv4Cidr::new` still rejects the input.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        let network = repr & mask_from_prefix_len::<Self::R>(len);
        Self::new(network.into(), len).unwrap()
    }

    /// Same value as `repr`, because the representation is already masked.
    fn mask(&self) -> Self::R {
        u32::from(self.first_address())
    }
}
229
#[cfg(feature = "cidr")]
impl Prefix for Ipv6Cidr {
    type R = u128;

    /// The first address of the network; a CIDR value carries no host bits.
    fn repr(&self) -> Self::R {
        u128::from(self.first_address())
    }

    fn prefix_len(&self) -> u8 {
        self.network_length()
    }

    /// Clears the host bits before construction, since CIDR values are always stored masked.
    /// Panics if `Ipv6Cidr::new` still rejects the input.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        let network = repr & mask_from_prefix_len::<Self::R>(len);
        Self::new(network.into(), len).unwrap()
    }

    /// Same value as `repr`, because the representation is already masked.
    fn mask(&self) -> Self::R {
        u128::from(self.first_address())
    }
}
251
#[cfg(feature = "cidr")]
impl Prefix for Ipv4Inet {
    type R = u32;

    /// The host address as an integer; unlike `Ipv4Cidr`, host bits are kept.
    fn repr(&self) -> Self::R {
        u32::from(self.address())
    }

    fn prefix_len(&self) -> u8 {
        self.network_length()
    }

    /// Host bits in `repr` are preserved. Panics if `Ipv4Inet::new` rejects `len`.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        Self::new(repr.into(), len).unwrap()
    }

    /// Masked representation, taken from the first address of the enclosing network.
    fn mask(&self) -> Self::R {
        let network = self.network();
        u32::from(network.first_address())
    }
}
272
#[cfg(feature = "cidr")]
impl Prefix for Ipv6Inet {
    type R = u128;

    /// The host address as an integer; unlike `Ipv6Cidr`, host bits are kept.
    fn repr(&self) -> Self::R {
        u128::from(self.address())
    }

    fn prefix_len(&self) -> u8 {
        self.network_length()
    }

    /// Host bits in `repr` are preserved. Panics if `Ipv6Inet::new` rejects `len`.
    fn from_repr_len(repr: Self::R, len: u8) -> Self {
        Self::new(repr.into(), len).unwrap()
    }

    /// Masked representation, taken from the first address of the enclosing network.
    fn mask(&self) -> Self::R {
        let network = self.network();
        u128::from(network.first_address())
    }
}
293
294impl<R> Prefix for (R, u8)
295where
296    R: Unsigned + PrimInt + Zero + CheckedShr + std::fmt::Debug,
297{
298    type R = R;
299
300    fn repr(&self) -> R {
301        self.0
302    }
303
304    fn prefix_len(&self) -> u8 {
305        self.1
306    }
307
308    fn from_repr_len(repr: R, len: u8) -> Self {
309        (repr, len)
310    }
311}
312
// Unit tests. The whole module is gated on the `ipnet` feature because the `pfx!` helper parses
// `Ipv4Net` and the generic suite is always instantiated for `Ipv4Net`/`Ipv6Net`.
#[cfg(test)]
#[cfg(feature = "ipnet")]
mod test {
    use super::*;

    // Parse a string literal such as "10.0.0.0/8" into an `Ipv4Net`, panicking if it is malformed.
    macro_rules! pfx {
        ($p:literal) => {
            $p.parse::<Ipv4Net>().unwrap()
        };
    }

    #[test]
    fn mask_from_len() {
        assert_eq!(mask_from_prefix_len::<u8>(3), 0b11100000);
        assert_eq!(mask_from_prefix_len::<u8>(5), 0b11111000);
        assert_eq!(mask_from_prefix_len::<u8>(8), 0b11111111);
        assert_eq!(mask_from_prefix_len::<u8>(0), 0b00000000);

        assert_eq!(mask_from_prefix_len::<u32>(0), 0x00000000);
        assert_eq!(mask_from_prefix_len::<u32>(8), 0xff000000);
        assert_eq!(mask_from_prefix_len::<u32>(16), 0xffff0000);
        assert_eq!(mask_from_prefix_len::<u32>(24), 0xffffff00);
        assert_eq!(mask_from_prefix_len::<u32>(32), 0xffffffff);
    }

    // `repr` keeps the host bits (10.1.x.x) while `mask` clears everything past /8.
    #[test]
    fn prefix_mask() {
        let addr = pfx!("10.1.0.0/8");
        assert_eq!(Prefix::prefix_len(&addr), 8);
        assert_eq!(Prefix::repr(&addr), (10 << 24) + (1 << 16));
        assert_eq!(Prefix::mask(&addr), 10u32 << 24);
    }

    // "larger"/"smaller" refer to the prefix length: `larger` (/9) is the more specific network.
    // The `_c` variants carry host bits, which must not change the containment result.
    #[test]
    fn contains() {
        let larger = pfx!("10.128.0.0/9");
        let smaller = pfx!("10.0.0.0/8");
        let larger_c = pfx!("10.130.2.5/9");
        let smaller_c = pfx!("10.25.2.8/8");
        assert!(smaller.contains(&larger));
        assert!(smaller.contains(&larger_c));
        assert!(smaller_c.contains(&larger));
        assert!(smaller_c.contains(&larger_c));
        assert!(!larger.contains(&smaller));
        assert!(!larger.contains(&smaller_c));
        assert!(!larger_c.contains(&smaller));
        assert!(!larger_c.contains(&smaller_c));
        assert!(smaller.contains(&smaller));
        assert!(smaller.contains(&smaller_c));
        assert!(smaller_c.contains(&smaller));
        assert!(smaller_c.contains(&smaller_c));
    }

    #[test]
    fn longest_common_prefix() {
        // Checks the result in both argument orders, since the operation must be symmetric.
        macro_rules! assert_lcp {
            ($a:literal, $b:literal, $c:literal) => {
                assert_eq!(pfx!($a).longest_common_prefix(&pfx!($b)), pfx!($c));
                assert_eq!(pfx!($b).longest_common_prefix(&pfx!($a)), pfx!($c));
            };
        }
        assert_lcp!("1.2.3.4/24", "1.3.3.4/24", "1.2.0.0/15");
        assert_lcp!("1.2.3.4/24", "1.1.3.4/24", "1.0.0.0/14");
        assert_lcp!("1.2.3.4/24", "1.2.3.4/30", "1.2.3.0/24");
    }

    // Bits are counted from the left; bits past the prefix length read as unset even if the
    // stored address has them set (last assertion).
    #[test]
    fn is_bit_set() {
        assert!(pfx!("255.0.0.0/8").is_bit_set(0));
        assert!(pfx!("255.0.0.0/8").is_bit_set(7));
        assert!(!pfx!("255.0.0.0/8").is_bit_set(8));
        assert!(!pfx!("255.255.0.0/8").is_bit_set(8));
    }

    // Generic suite, instantiated below for every `Prefix` implementation that is compiled in.
    #[generic_tests::define]
    mod t {
        use num_traits::NumCast;

        use super::*;

        // Build a `P` from a 32-bit pattern. The pattern is cast into `P::R` unchanged (so it
        // sits in the low 32 bits), and `len` is offset by `width(R) - 32` so that the prefix
        // covers the same bits of the pattern regardless of the width of `R`.
        fn new<P: Prefix>(repr: u32, len: u8) -> P {
            let repr = <<P as Prefix>::R as NumCast>::from(repr).unwrap();
            let num_zeros = <<P as Prefix>::R as Zero>::zero().count_zeros() as u8;
            let len = len + (num_zeros - 32);
            P::from_repr_len(repr, len)
        }

        #[test]
        fn repr_len<P: Prefix>() {
            for x in [0x01000000u32, 0x010f0000u32, 0xffff0000u32] {
                let repr = <<P as Prefix>::R as NumCast>::from(x).unwrap();
                let num_zeros = <<P as Prefix>::R as Zero>::zero().count_zeros() as u8;
                let len = 16 + (num_zeros - 32);
                let prefix = P::from_repr_len(repr, len);
                assert!(prefix.repr() == repr);
                assert!(prefix.prefix_len() == len);
            }
        }

        // `repr` must return host bits unchanged, except for types that always store a masked
        // address (the `cidr` CIDR types).
        #[test]
        fn keep_host_addr<P: Prefix + 'static>() {
            // `prefix_is_masked` is only reassigned when the `cidr` feature is enabled, so these
            // lints fire in the other feature configurations.
            #[allow(unused_mut)]
            #[allow(unused_assignments)]
            let mut prefix_is_masked = false;
            #[cfg(feature = "cidr")]
            {
                let p_id = std::any::TypeId::of::<P>();
                // Ipv4Cidr and Ipv6Cidr addresses are always masked.
                prefix_is_masked = p_id == std::any::TypeId::of::<cidr::Ipv4Cidr>()
                    || p_id == std::any::TypeId::of::<cidr::Ipv6Cidr>();
            }
            let mask = 0xffff0000u32;
            for mut x in [0x01001234u32, 0x010fabcdu32, 0xffff5678u32] {
                let prefix: P = new(x, 16);
                if prefix_is_masked {
                    x &= mask;
                }
                assert_eq!(<u32 as NumCast>::from(prefix.repr()), Some(x));
            }
        }

        #[test]
        fn mask<P: Prefix>() {
            let mask = 0xffff0000u32;
            for x in [0x01001234u32, 0x010fabcdu32, 0xffff5678u32] {
                let prefix: P = new(x, 16);
                assert_eq!(<u32 as NumCast>::from(prefix.mask()), Some(x & mask));
            }
        }

        #[test]
        fn zero<P: Prefix>() {
            let prefix = P::from_repr_len(P::R::zero(), 0);
            assert!(P::zero().eq(&prefix));
        }

        #[test]
        fn longest_common_prefix<P: Prefix>() {
            // Each case is ((a, a_len), (b, b_len), (expected, expected_len)).
            for ((a, al), (b, bl), (c, cl)) in [
                ((0x01020304, 24), (0x01030304, 24), (0x01020000, 15)),
                ((0x12345678, 24), (0x12345678, 16), (0x12340000, 16)),
            ] {
                let a: P = new(a, al);
                let b: P = new(b, bl);
                let c: P = new(c, cl);
                let lcp = a.longest_common_prefix(&b);
                assert!(lcp.repr() == c.repr());
                assert!(lcp.prefix_len() == c.prefix_len());
            }
        }

        #[test]
        fn contains<P: Prefix>() {
            assert!(new::<P>(0x01020000, 16).contains(&new(0x0102ffff, 24)));
            assert!(new::<P>(0x01020304, 16).contains(&new(0x0102ffff, 24)));
            assert!(new::<P>(0x01020304, 16).contains(&new(0x0102ffff, 16)));
            assert!(!new::<P>(0x01020304, 24).contains(&new(0x0102ffff, 16)));
        }

        #[test]
        fn is_bit_set<P: Prefix>() {
            let x = 0x12345678u32;
            let num_zeros = <<P as Prefix>::R as Zero>::zero().count_zeros() as u8;
            let offset = num_zeros - 32;
            let p: P = new(x, 16);
            // Scans 64 positions: for 32-bit types this also probes indices past the width of
            // `R`, which must read as unset just like bits past the /16 prefix.
            for i in 0..64 {
                let j = i + offset;
                if i >= 16 {
                    assert!(!p.is_bit_set(j))
                } else {
                    let mask = 0x80000000u32 >> i;
                    assert_eq!(p.is_bit_set(j), x & mask != 0)
                }
            }
        }

        #[instantiate_tests(<Ipv4Net>)]
        mod ipv4net {}

        #[instantiate_tests(<Ipv6Net>)]
        mod ipv6net {}

        #[cfg(feature = "ipnetwork")]
        #[instantiate_tests(<Ipv4Network>)]
        mod ipv4network {}

        #[cfg(feature = "ipnetwork")]
        #[instantiate_tests(<Ipv6Network>)]
        mod ipv6network {}

        #[cfg(feature = "cidr")]
        #[instantiate_tests(<Ipv4Cidr>)]
        mod ipv4cidr {}

        #[cfg(feature = "cidr")]
        #[instantiate_tests(<Ipv4Inet>)]
        mod ipv4inet {}

        #[cfg(feature = "cidr")]
        #[instantiate_tests(<Ipv6Cidr>)]
        mod ipv6cidr {}

        #[cfg(feature = "cidr")]
        #[instantiate_tests(<Ipv6Inet>)]
        mod ipv6inet {}

        #[instantiate_tests(<(u32, u8)>)]
        mod u32_u8 {}

        #[instantiate_tests(<(u64, u8)>)]
        mod u64_u8 {}
    }
}