ip/concrete/prefix/set/
ops.rs

1use core::cmp::Ordering;
2use core::ops::{Add, BitAnd, BitOr, BitXor, Mul, Not, Sub};
3
4use num_traits::{One, Zero};
5
6use super::Set;
7use crate::traits::{Afi, AfiClass, PrefixSet as _};
8
9impl<A: Afi> Zero for Set<A> {
10    fn zero() -> Self {
11        Self::new()
12    }
13
14    fn is_zero(&self) -> bool {
15        self.root.is_some()
16    }
17}
18
19impl<A: Afi> One for Set<A> {
20    fn one() -> Self {
21        Self::new()
22            .insert(<A as AfiClass>::PrefixRange::ALL)
23            .clone()
24    }
25}
26
27impl<A: Afi> BitAnd for Set<A> {
28    type Output = Self;
29
30    fn bitand(self, rhs: Self) -> Self::Output {
31        match (self.root, rhs.root) {
32            (Some(r), Some(s)) => Self::Output::new_with_root(r & s).aggregate().clone(),
33            _ => Self::Output::zero(),
34        }
35    }
36}
37
38impl<A: Afi> BitOr for Set<A> {
39    type Output = Self;
40
41    fn bitor(self, rhs: Self) -> Self::Output {
42        match (&self.root, &rhs.root) {
43            (Some(r), Some(s)) => Self::Output::new_with_root(r.clone() | s.clone())
44                .aggregate()
45                .clone(),
46            (Some(_), None) => self,
47            (None, Some(_)) => rhs,
48            (None, None) => Self::Output::zero(),
49        }
50    }
51}
52
53impl<A: Afi> BitXor for Set<A> {
54    type Output = Self;
55
56    fn bitxor(self, rhs: Self) -> Self::Output {
57        (self.clone() | rhs.clone()) - (self & rhs)
58    }
59}
60
61impl<A: Afi> Not for Set<A> {
62    type Output = Self;
63
64    fn not(self) -> Self::Output {
65        Self::Output::one() - self
66    }
67}
68
69impl<A: Afi> Add for Set<A> {
70    type Output = Self;
71
72    #[allow(clippy::suspicious_arithmetic_impl)]
73    fn add(self, rhs: Self) -> Self::Output {
74        self | rhs
75    }
76}
77
78impl<A: Afi> Sub for Set<A> {
79    type Output = Self;
80
81    fn sub(self, rhs: Self) -> Self::Output {
82        match (&self.root, &rhs.root) {
83            (Some(r), Some(s)) => Self::Output::new_with_root(r.clone() - s.clone())
84                .aggregate()
85                .clone(),
86            _ => self,
87        }
88    }
89}
90
91impl<A: Afi> Mul for Set<A> {
92    type Output = Self;
93
94    #[allow(clippy::suspicious_arithmetic_impl)]
95    fn mul(self, rhs: Self) -> Self::Output {
96        self & rhs
97    }
98}
99
100impl<A: Afi> PartialEq for Set<A> {
101    fn eq(&self, other: &Self) -> bool {
102        match (&self.root, &other.root) {
103            (Some(r), Some(s)) => r.children().zip(s.children()).all(|(m, n)| m == n),
104            (None, None) => true,
105            _ => false,
106        }
107    }
108}
109
110impl<A: Afi> PartialOrd for Set<A> {
111    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
112        if self == other {
113            Some(Ordering::Equal)
114        } else if self.prefixes().all(|p| other.contains(p)) {
115            Some(Ordering::Less)
116        } else if other.prefixes().all(|p| self.contains(p)) {
117            Some(Ordering::Greater)
118        } else {
119            None
120        }
121    }
122}
123
// Marker impl: declares the `PartialEq` relation on `Set` to be a full
// equivalence relation.
impl<A: Afi> Eq for Set<A> {}
125
126#[cfg(test)]
127mod tests {
128    use core::str::FromStr;
129    use std::{dbg, vec};
130
131    use paste::paste;
132
133    use super::super::Node;
134    use super::*;
135    use crate::{
136        concrete::{Prefix, PrefixRange},
137        error::{Error, TestResult},
138        Ipv4, Ipv6,
139    };
140
141    impl<A: Afi> FromIterator<&'static str> for Set<A> {
142        fn from_iter<T: IntoIterator<Item = &'static str>>(iter: T) -> Self {
143            enum Insertable<A: Afi> {
144                Prefix(Prefix<A>),
145                Range(PrefixRange<A>),
146            }
147            impl<A: Afi> FromStr for Insertable<A> {
148                type Err = Error;
149                fn from_str(s: &str) -> Result<Self, Self::Err> {
150                    s.parse()
151                        .map(Insertable::Range)
152                        .or_else(|_| s.parse().map(Insertable::Prefix))
153                }
154            }
155            #[allow(clippy::from_over_into)]
156            impl<A: Afi> Into<Node<A>> for Insertable<A> {
157                fn into(self) -> Node<A> {
158                    match self {
159                        Self::Prefix(prefix) => prefix.into(),
160                        Self::Range(range) => range.into(),
161                    }
162                }
163            }
164            iter.into_iter()
165                .map(Insertable::from_str)
166                .collect::<Result<_, _>>()
167                .unwrap()
168        }
169    }
170
171    macro_rules! test_exprs {
172        ( $($fn_id:ident {$lhs:expr, $rhs:expr});* ) => {
173            test_exprs!(@ipv4 {$($fn_id {$lhs, $rhs});*});
174            test_exprs!(@ipv6 {$($fn_id {$lhs, $rhs});*});
175        };
176        ( @ipv4 {$($fn_id:ident {$lhs:expr, $rhs:expr});*} ) => {
177            paste! {
178                test_exprs!($(Ipv4 => [<ipv4_ $fn_id>] {$lhs, $rhs});*);
179            }
180        };
181        ( @ipv6 {$($fn_id:ident {$lhs:expr, $rhs:expr});*} ) => {
182            paste! {
183                test_exprs!($(Ipv6 => [<ipv6_ $fn_id>] {$lhs, $rhs});*);
184            }
185        };
186        ( $($p:ty => $fn_id:ident {$lhs:expr, $rhs:expr});* ) => {
187            paste! {
188                $(
189                    #[test]
190                    fn $fn_id() -> TestResult {
191                        let res: Set<$p> = dbg!($lhs);
192                        assert_eq!(res, dbg!($rhs));
193                        Ok(())
194                    }
195                )*
196            }
197        };
198    }
199
200    macro_rules! test_unary_op {
201        ( $( !$operand:ident == $expect:ident),* ) => {
202            test_unary_op!(@call $(not $operand == $expect),*);
203        };
204        ( @call $($op:ident $operand:ident == $expect:ident),* ) => {
205            paste! {
206                test_exprs!($(
207                    [<$op _ $operand _is_ $expect>] {
208                        Set::$operand().$op(),
209                        Set::$expect()
210                    }
211                );*);
212            }
213        }
214    }
215
216    macro_rules! test_binary_op {
217        ( $($lhs:ident & $rhs:ident == $expect:ident),* ) => {
218            test_binary_op!(@call $($lhs bitand $rhs == $expect),*);
219        };
220        ( $($lhs:ident | $rhs:ident == $expect:ident),* ) => {
221            test_binary_op!(@call $($lhs bitor $rhs == $expect),*);
222        };
223        ( $($lhs:ident ^ $rhs:ident == $expect:ident),* ) => {
224            test_binary_op!(@call $($lhs bitxor $rhs == $expect),*);
225        };
226        ( $($lhs:ident + $rhs:ident == $expect:ident),* ) => {
227            test_binary_op!(@call $($lhs add $rhs == $expect),*);
228        };
229        ( $($lhs:ident - $rhs:ident == $expect:ident),* ) => {
230            test_binary_op!(@call $($lhs sub $rhs == $expect),*);
231        };
232        ( $($lhs:ident * $rhs:ident == $expect:ident),* ) => {
233            test_binary_op!(@call $($lhs mul $rhs == $expect),*);
234        };
235        ( @call $($lhs:ident $op:ident $rhs:ident == $expect:ident),* ) => {
236            paste! {
237                test_exprs!($(
238                    [<$lhs _ $op _ $rhs _is_ $expect>] {
239                        Set::$lhs().$op(Set::$rhs()),
240                        Set::$expect()
241                    }
242                );*);
243            }
244        }
245    }
246
247    #[test]
248    fn ipv4_zero_set_is_empty() {
249        assert_eq!(Set::<Ipv4>::zero().prefixes().count(), 0);
250    }
251
252    #[test]
253    fn ipv6_zero_set_is_empty() {
254        assert_eq!(Set::<Ipv6>::zero().prefixes().count(), 0);
255    }
256
    // Truth tables over the two identity sets (`zero` and `one`). Each line
    // expands, for both IPv4 and IPv6, into one generated test.

    // Complement swaps the two identities.
    test_unary_op!(!zero == one, !one == zero);

    // Intersection (`&`).
    test_binary_op!(
        zero & zero == zero,
        zero & one == zero,
        one & zero == zero,
        one & one == one
    );

    // Union (`|`).
    test_binary_op!(
        zero | zero == zero,
        zero | one == one,
        one | zero == one,
        one | one == one
    );

    // Symmetric difference (`^`).
    test_binary_op!(
        zero ^ zero == zero,
        zero ^ one == one,
        one ^ zero == one,
        one ^ one == zero
    );

    // Addition is implemented as union, so this mirrors the `|` table.
    test_binary_op!(
        zero + zero == zero,
        zero + one == one,
        one + zero == one,
        one + one == one
    );

    // Difference (`-`).
    test_binary_op!(
        zero - zero == zero,
        zero - one == zero,
        one - zero == one,
        one - one == zero
    );

    // Multiplication is implemented as intersection, so this mirrors the
    // `&` table.
    test_binary_op!(
        zero * zero == zero,
        zero * one == zero,
        one * zero == zero,
        one * one == one
    );
300
301    test_exprs!( @ipv4 {
302        intersect_disjoint_nodes {
303            vec!["1.0.0.0/8,8,16"].into_iter().collect::<Set<_>>()
304                & vec!["2.0.0.0/8,8,16"].into_iter().collect(),
305            Set::zero()
306        };
307        intersect_disjoint_ranges {
308            vec!["1.0.0.0/8,8,11"].into_iter().collect::<Set<_>>()
309                & vec!["1.0.0.0/8,12,15"].into_iter().collect(),
310            Set::zero()
311        };
312        intersect_overlapping_nodes {
313            vec!["1.0.0.0/8,12,16"].into_iter().collect::<Set<_>>()
314                & vec!["1.0.0.0/12,12,16"].into_iter().collect(),
315            vec!["1.0.0.0/12,12,16"].into_iter().collect()
316        };
317        intersect_overlapping_ranges {
318            vec!["1.0.0.0/8,8,12"].into_iter().collect::<Set<_>>()
319                & vec!["1.0.0.0/8,12,16"].into_iter().collect(),
320            vec!["1.0.0.0/8,12,12"].into_iter().collect()
321        };
322        intersect_overlapping_set_with_parent {
323            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
324                & vec!["1.0.0.0/16"].into_iter().collect(),
325            vec!["1.0.0.0/16"].into_iter().collect()
326        };
327        intersect_overlapping_set_with_sibling {
328            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect::<Set<_>>()
329                & vec!["1.0.0.0/8"].into_iter().collect(),
330            vec!["1.0.0.0/8"].into_iter().collect()
331        };
332        intersect_overlapping_set_with_child {
333            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
334                & vec!["1.0.0.0/8"].into_iter().collect(),
335            vec!["1.0.0.0/8"].into_iter().collect()
336        };
337        intersect_covering_parent {
338            vec!["1.0.0.0/16"].into_iter().collect::<Set<_>>()
339                & vec!["1.0.0.0/8,16,16"].into_iter().collect(),
340            vec!["1.0.0.0/16"].into_iter().collect()
341        };
342        intersect_covered_child {
343            vec!["1.0.0.0/8,16,16"].into_iter().collect::<Set<_>>()
344                & vec!["1.0.0.0/16"].into_iter().collect(),
345            vec!["1.0.0.0/16"].into_iter().collect()
346        };
347        intersect_overlapping_set_with_covered_child {
348            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
349                & vec!["1.0.0.0/8,16,16"].into_iter().collect(),
350            vec!["1.0.0.0/16"].into_iter().collect()
351        };
352        union_disjoint_nodes {
353            vec!["2.0.0.0/8,8,16"].into_iter().collect::<Set<_>>()
354                | vec!["3.0.0.0/8,8,16"].into_iter().collect(),
355            vec!["2.0.0.0/7,8,16"].into_iter().collect()
356        };
357        union_disjoint_ranges {
358            vec!["1.0.0.0/8,8,11"].into_iter().collect::<Set<_>>()
359                | vec!["1.0.0.0/8,12,15"].into_iter().collect(),
360            vec!["1.0.0.0/8,8,15"].into_iter().collect()
361        };
362        union_overlapping_nodes {
363            vec!["1.0.0.0/8,12,16"].into_iter().collect::<Set<_>>()
364                | vec!["1.0.0.0/12,12,16"].into_iter().collect(),
365            vec!["1.0.0.0/8,12,16"].into_iter().collect()
366        };
367        union_overlapping_ranges {
368            vec!["1.0.0.0/8,8,12"].into_iter().collect::<Set<_>>()
369                | vec!["1.0.0.0/8,12,16"].into_iter().collect(),
370            vec!["1.0.0.0/8,8,16"].into_iter().collect()
371        };
372        union_overlapping_set_with_parent {
373            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
374                | vec!["1.0.0.0/16"].into_iter().collect(),
375            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect()
376        };
377        union_overlapping_set_with_sibling {
378            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect::<Set<_>>()
379                | vec!["1.0.0.0/8"].into_iter().collect(),
380            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect()
381        };
382        union_overlapping_set_with_child {
383            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
384                | vec!["1.0.0.0/8"].into_iter().collect(),
385            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect()
386        };
387        union_covering_parent {
388            vec!["1.0.0.0/16"].into_iter().collect::<Set<_>>()
389                | vec!["1.0.0.0/8,16,16"].into_iter().collect(),
390            vec!["1.0.0.0/8,16,16"].into_iter().collect()
391        };
392        union_covered_child {
393            vec!["1.0.0.0/8,16,16"].into_iter().collect::<Set<_>>()
394                | vec!["1.0.0.0/16"].into_iter().collect(),
395            vec!["1.0.0.0/8,16,16"].into_iter().collect()
396        };
397        union_overlapping_set_with_covered_child {
398            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
399                | vec!["1.0.0.0/8,16,16"].into_iter().collect(),
400            vec!["1.0.0.0/8", "1.0.0.0/8,16,16"].into_iter().collect()
401        };
402        xor_disjoint_nodes {
403            vec!["2.0.0.0/8,8,16"].into_iter().collect::<Set<_>>()
404                ^ vec!["3.0.0.0/8,8,16"].into_iter().collect(),
405            vec!["2.0.0.0/7,8,16"].into_iter().collect()
406        };
407        xor_disjoint_ranges {
408            vec!["1.0.0.0/8,8,11"].into_iter().collect::<Set<_>>()
409                ^ vec!["1.0.0.0/8,12,15"].into_iter().collect(),
410            vec!["1.0.0.0/8,8,15"].into_iter().collect()
411        };
412        xor_overlapping_nodes {
413            vec!["1.0.0.0/8,12,16"].into_iter().collect::<Set<_>>()
414                ^ vec!["1.0.0.0/12,12,16"].into_iter().collect(),
415            vec![
416                "1.16.0.0/12,12,16",
417                "1.32.0.0/11,12,16",
418                "1.64.0.0/10,12,16",
419                "1.128.0.0/9,12,16"
420            ].into_iter().collect()
421        };
422        xor_overlapping_ranges {
423            vec!["1.0.0.0/8,8,12"].into_iter().collect::<Set<_>>()
424                ^ vec!["1.0.0.0/8,12,16"].into_iter().collect(),
425            vec!["1.0.0.0/8,8,11", "1.0.0.0/8,13,16"].into_iter().collect()
426        };
427        xor_overlapping_set_with_parent {
428            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
429                ^ vec!["1.0.0.0/16"].into_iter().collect(),
430            vec!["1.0.0.0/8"].into_iter().collect()
431        };
432        xor_overlapping_set_with_sibling {
433            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect::<Set<_>>()
434                ^ vec!["1.0.0.0/8"].into_iter().collect(),
435            vec!["2.0.0.0/8"].into_iter().collect()
436        };
437        xor_overlapping_set_with_child {
438            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
439                ^ vec!["1.0.0.0/8"].into_iter().collect(),
440            vec!["1.0.0.0/16"].into_iter().collect()
441        };
442        xor_covering_parent {
443            vec!["1.0.0.0/16"].into_iter().collect::<Set<_>>()
444                ^ vec!["1.0.0.0/8,16,16"].into_iter().collect(),
445            vec![
446                "1.1.0.0/16",
447                "1.2.0.0/15,16,16",
448                "1.4.0.0/14,16,16",
449                "1.8.0.0/13,16,16",
450                "1.16.0.0/12,16,16",
451                "1.32.0.0/11,16,16",
452                "1.64.0.0/10,16,16",
453                "1.128.0.0/9,16,16",
454            ].into_iter().collect()
455        };
456        xor_covered_child {
457            vec!["1.0.0.0/8,16,16"].into_iter().collect::<Set<_>>()
458                ^ vec!["1.0.0.0/16"].into_iter().collect(),
459            vec![
460                "1.1.0.0/16",
461                "1.2.0.0/15,16,16",
462                "1.4.0.0/14,16,16",
463                "1.8.0.0/13,16,16",
464                "1.16.0.0/12,16,16",
465                "1.32.0.0/11,16,16",
466                "1.64.0.0/10,16,16",
467                "1.128.0.0/9,16,16",
468            ].into_iter().collect()
469        };
470        xor_overlapping_set_with_covered_child {
471            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
472                ^ vec!["1.0.0.0/8,16,16"].into_iter().collect(),
473            vec!["1.0.0.0/8"].into_iter().collect()
474        };
475        sub_disjoint_nodes {
476            vec!["2.0.0.0/8,8,16"].into_iter().collect::<Set<_>>()
477                - vec!["3.0.0.0/8,8,16"].into_iter().collect(),
478            vec!["2.0.0.0/8,8,16"].into_iter().collect()
479        };
480        sub_disjoint_ranges {
481            vec!["1.0.0.0/8,8,11"].into_iter().collect::<Set<_>>()
482                - vec!["1.0.0.0/8,12,15"].into_iter().collect(),
483            vec!["1.0.0.0/8,8,11"].into_iter().collect()
484        };
485        sub_overlapping_nodes {
486            vec!["1.0.0.0/8,12,16"].into_iter().collect::<Set<_>>()
487                - vec!["1.0.0.0/12,12,16"].into_iter().collect(),
488            vec![
489                "1.16.0.0/12,12,16",
490                "1.32.0.0/11,12,16",
491                "1.64.0.0/10,12,16",
492                "1.128.0.0/9,12,16"
493            ].into_iter().collect()
494        };
495        sub_overlapping_ranges {
496            vec!["1.0.0.0/8,8,12"].into_iter().collect::<Set<_>>()
497                - vec!["1.0.0.0/8,12,16"].into_iter().collect(),
498            vec!["1.0.0.0/8,8,11"].into_iter().collect()
499        };
500        sub_overlapping_set_with_parent {
501            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
502                - vec!["1.0.0.0/16"].into_iter().collect(),
503            vec!["1.0.0.0/8"].into_iter().collect()
504        };
505        sub_overlapping_set_with_sibling {
506            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect::<Set<_>>()
507                - vec!["1.0.0.0/8"].into_iter().collect(),
508            vec!["2.0.0.0/8"].into_iter().collect()
509        };
510        sub_overlapping_set_with_child {
511            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
512                - vec!["1.0.0.0/8"].into_iter().collect(),
513            vec!["1.0.0.0/16"].into_iter().collect()
514        };
515        sub_covering_parent {
516            vec!["1.0.0.0/16"].into_iter().collect::<Set<_>>()
517                - vec!["1.0.0.0/8,16,16"].into_iter().collect(),
518            Set::zero()
519        };
520        sub_covered_child {
521            vec!["1.0.0.0/8,16,16"].into_iter().collect::<Set<_>>()
522                - vec!["1.0.0.0/16"].into_iter().collect(),
523            vec![
524                "1.1.0.0/16",
525                "1.2.0.0/15,16,16",
526                "1.4.0.0/14,16,16",
527                "1.8.0.0/13,16,16",
528                "1.16.0.0/12,16,16",
529                "1.32.0.0/11,16,16",
530                "1.64.0.0/10,16,16",
531                "1.128.0.0/9,16,16",
532            ].into_iter().collect()
533        };
534        sub_overlapping_set_with_covered_child {
535            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
536                - vec!["1.0.0.0/8,16,16"].into_iter().collect(),
537            vec![
538                "1.0.0.0/8",
539                "1.1.0.0/16",
540                "1.2.0.0/15,16,16",
541                "1.4.0.0/14,16,16",
542                "1.8.0.0/13,16,16",
543                "1.16.0.0/12,16,16",
544                "1.32.0.0/11,16,16",
545                "1.64.0.0/10,16,16",
546                "1.128.0.0/9,16,16",
547            ].into_iter().collect()
548        };
549        sub_complex_deaggregation {
550            vec!["2.0.0.0/8,8,10", "3.0.0.0/8,8,9"].into_iter().collect::<Set<_>>()
551                - vec!["2.0.0.0/10", "3.0.0.0/8,8,10"].into_iter().collect(),
552            vec![
553                "2.0.0.0/8,8,9",
554                "2.64.0.0/10",
555                "2.128.0.0/10",
556                "2.192.0.0/10",
557            ].into_iter().collect()
558        };
559        not_singleton {
560            ! vec!["1.0.0.0/8"].into_iter().collect::<Set<_>>(),
561            vec![
562                "0.0.0.0/0,0,7",
563                "0.0.0.0/0,9,32",
564                "0.0.0.0/8",
565                "2.0.0.0/7,8,8",
566                "4.0.0.0/6,8,8",
567                "8.0.0.0/5,8,8",
568                "16.0.0.0/4,8,8",
569                "32.0.0.0/3,8,8",
570                "64.0.0.0/2,8,8",
571                "128.0.0.0/1,8,8"
572            ].into_iter().collect()
573        };
574        not_range {
575            ! vec!["1.0.0.0/8,8,16"].into_iter().collect::<Set<_>>(),
576            vec![
577                "0.0.0.0/0,0,7",
578                "0.0.0.0/0,17,32",
579                "0.0.0.0/8,8,16",
580                "2.0.0.0/7,8,16",
581                "4.0.0.0/6,8,16",
582                "8.0.0.0/5,8,16",
583                "16.0.0.0/4,8,16",
584                "32.0.0.0/3,8,16",
585                "64.0.0.0/2,8,16",
586                "128.0.0.0/1,8,16",
587            ].into_iter().collect()
588        }
589    });
590}