1use core::cmp::Ordering;
2use core::ops::{Add, BitAnd, BitOr, BitXor, Mul, Not, Sub};
3
4use num_traits::{One, Zero};
5
6use super::Set;
7use crate::traits::{Afi, AfiClass, PrefixSet as _};
8
9impl<A: Afi> Zero for Set<A> {
10 fn zero() -> Self {
11 Self::new()
12 }
13
14 fn is_zero(&self) -> bool {
15 self.root.is_some()
16 }
17}
18
19impl<A: Afi> One for Set<A> {
20 fn one() -> Self {
21 Self::new()
22 .insert(<A as AfiClass>::PrefixRange::ALL)
23 .clone()
24 }
25}
26
27impl<A: Afi> BitAnd for Set<A> {
28 type Output = Self;
29
30 fn bitand(self, rhs: Self) -> Self::Output {
31 match (self.root, rhs.root) {
32 (Some(r), Some(s)) => Self::Output::new_with_root(r & s).aggregate().clone(),
33 _ => Self::Output::zero(),
34 }
35 }
36}
37
38impl<A: Afi> BitOr for Set<A> {
39 type Output = Self;
40
41 fn bitor(self, rhs: Self) -> Self::Output {
42 match (&self.root, &rhs.root) {
43 (Some(r), Some(s)) => Self::Output::new_with_root(r.clone() | s.clone())
44 .aggregate()
45 .clone(),
46 (Some(_), None) => self,
47 (None, Some(_)) => rhs,
48 (None, None) => Self::Output::zero(),
49 }
50 }
51}
52
53impl<A: Afi> BitXor for Set<A> {
54 type Output = Self;
55
56 fn bitxor(self, rhs: Self) -> Self::Output {
57 (self.clone() | rhs.clone()) - (self & rhs)
58 }
59}
60
61impl<A: Afi> Not for Set<A> {
62 type Output = Self;
63
64 fn not(self) -> Self::Output {
65 Self::Output::one() - self
66 }
67}
68
69impl<A: Afi> Add for Set<A> {
70 type Output = Self;
71
72 #[allow(clippy::suspicious_arithmetic_impl)]
73 fn add(self, rhs: Self) -> Self::Output {
74 self | rhs
75 }
76}
77
78impl<A: Afi> Sub for Set<A> {
79 type Output = Self;
80
81 fn sub(self, rhs: Self) -> Self::Output {
82 match (&self.root, &rhs.root) {
83 (Some(r), Some(s)) => Self::Output::new_with_root(r.clone() - s.clone())
84 .aggregate()
85 .clone(),
86 _ => self,
87 }
88 }
89}
90
91impl<A: Afi> Mul for Set<A> {
92 type Output = Self;
93
94 #[allow(clippy::suspicious_arithmetic_impl)]
95 fn mul(self, rhs: Self) -> Self::Output {
96 self & rhs
97 }
98}
99
100impl<A: Afi> PartialEq for Set<A> {
101 fn eq(&self, other: &Self) -> bool {
102 match (&self.root, &other.root) {
103 (Some(r), Some(s)) => r.children().zip(s.children()).all(|(m, n)| m == n),
104 (None, None) => true,
105 _ => false,
106 }
107 }
108}
109
110impl<A: Afi> PartialOrd for Set<A> {
111 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
112 if self == other {
113 Some(Ordering::Equal)
114 } else if self.prefixes().all(|p| other.contains(p)) {
115 Some(Ordering::Less)
116 } else if other.prefixes().all(|p| self.contains(p)) {
117 Some(Ordering::Greater)
118 } else {
119 None
120 }
121 }
122}
123
124impl<A: Afi> Eq for Set<A> {}
125
#[cfg(test)]
mod tests {
    use core::str::FromStr;
    use std::{dbg, vec};

    use paste::paste;

    use super::super::Node;
    use super::*;
    use crate::{
        concrete::{Prefix, PrefixRange},
        error::{Error, TestResult},
        Ipv4, Ipv6,
    };

    /// Test-only constructor: builds a set from string literals. Each literal
    /// is parsed first as a prefix range (e.g. `"1.0.0.0/8,8,16"`) and, if
    /// that fails, as a single prefix (e.g. `"1.0.0.0/8"`).
    ///
    /// Panics (via `unwrap`) if any literal parses as neither.
    impl<A: Afi> FromIterator<&'static str> for Set<A> {
        fn from_iter<T: IntoIterator<Item = &'static str>>(iter: T) -> Self {
            enum Insertable<A: Afi> {
                Prefix(Prefix<A>),
                Range(PrefixRange<A>),
            }
            impl<A: Afi> FromStr for Insertable<A> {
                type Err = Error;
                fn from_str(s: &str) -> Result<Self, Self::Err> {
                    s.parse()
                        .map(Insertable::Range)
                        .or_else(|_| s.parse().map(Insertable::Prefix))
                }
            }
            #[allow(clippy::from_over_into)]
            impl<A: Afi> Into<Node<A>> for Insertable<A> {
                fn into(self) -> Node<A> {
                    match self {
                        Self::Prefix(prefix) => prefix.into(),
                        Self::Range(range) => range.into(),
                    }
                }
            }
            iter.into_iter()
                .map(Insertable::from_str)
                .collect::<Result<_, _>>()
                .unwrap()
        }
    }

    // `test_exprs!` generates one `#[test]` per `name { lhs, rhs }` entry,
    // asserting that `lhs` (typed as `Set<$p>`) equals `rhs`.
    // - Bare form: instantiates every entry for both IPv4 and IPv6, with the
    //   test names prefixed by `ipv4_` / `ipv6_`.
    // - `@ipv4` / `@ipv6` forms: instantiate for a single address family.
    macro_rules! test_exprs {
        ( $($fn_id:ident {$lhs:expr, $rhs:expr});* ) => {
            test_exprs!(@ipv4 {$($fn_id {$lhs, $rhs});*});
            test_exprs!(@ipv6 {$($fn_id {$lhs, $rhs});*});
        };
        ( @ipv4 {$($fn_id:ident {$lhs:expr, $rhs:expr});*} ) => {
            paste! {
                test_exprs!($(Ipv4 => [<ipv4_ $fn_id>] {$lhs, $rhs});*);
            }
        };
        ( @ipv6 {$($fn_id:ident {$lhs:expr, $rhs:expr});*} ) => {
            paste! {
                test_exprs!($(Ipv6 => [<ipv6_ $fn_id>] {$lhs, $rhs});*);
            }
        };
        ( $($p:ty => $fn_id:ident {$lhs:expr, $rhs:expr});* ) => {
            paste! {
                $(
                    #[test]
                    fn $fn_id() -> TestResult {
                        let res: Set<$p> = dbg!($lhs);
                        assert_eq!(res, dbg!($rhs));
                        Ok(())
                    }
                )*
            }
        };
    }

    // `test_unary_op!(!a == b, ...)` checks `Set::a().not() == Set::b()` for
    // named constructors `a`/`b` (e.g. `zero`, `one`), for both families.
    macro_rules! test_unary_op {
        ( $( !$operand:ident == $expect:ident),* ) => {
            test_unary_op!(@call $(not $operand == $expect),*);
        };
        ( @call $($op:ident $operand:ident == $expect:ident),* ) => {
            paste! {
                test_exprs!($(
                    [<$op _ $operand _is_ $expect>] {
                        Set::$operand().$op(),
                        Set::$expect()
                    }
                );*);
            }
        }
    }

    // `test_binary_op!(a OP b == c, ...)` checks `Set::a().op(Set::b()) ==
    // Set::c()` for one operator per invocation, mapping the operator token to
    // its `core::ops` method name.
    macro_rules! test_binary_op {
        ( $($lhs:ident & $rhs:ident == $expect:ident),* ) => {
            test_binary_op!(@call $($lhs bitand $rhs == $expect),*);
        };
        ( $($lhs:ident | $rhs:ident == $expect:ident),* ) => {
            test_binary_op!(@call $($lhs bitor $rhs == $expect),*);
        };
        ( $($lhs:ident ^ $rhs:ident == $expect:ident),* ) => {
            test_binary_op!(@call $($lhs bitxor $rhs == $expect),*);
        };
        ( $($lhs:ident + $rhs:ident == $expect:ident),* ) => {
            test_binary_op!(@call $($lhs add $rhs == $expect),*);
        };
        ( $($lhs:ident - $rhs:ident == $expect:ident),* ) => {
            test_binary_op!(@call $($lhs sub $rhs == $expect),*);
        };
        ( $($lhs:ident * $rhs:ident == $expect:ident),* ) => {
            test_binary_op!(@call $($lhs mul $rhs == $expect),*);
        };
        ( @call $($lhs:ident $op:ident $rhs:ident == $expect:ident),* ) => {
            paste! {
                test_exprs!($(
                    [<$lhs _ $op _ $rhs _is_ $expect>] {
                        Set::$lhs().$op(Set::$rhs()),
                        Set::$expect()
                    }
                );*);
            }
        }
    }

    #[test]
    fn ipv4_zero_set_is_empty() {
        assert_eq!(Set::<Ipv4>::zero().prefixes().count(), 0);
    }

    #[test]
    fn ipv6_zero_set_is_empty() {
        assert_eq!(Set::<Ipv6>::zero().prefixes().count(), 0);
    }

    test_unary_op!(!zero == one, !one == zero);

    test_binary_op!(
        zero & zero == zero,
        zero & one == zero,
        one & zero == zero,
        one & one == one
    );

    test_binary_op!(
        zero | zero == zero,
        zero | one == one,
        one | zero == one,
        one | one == one
    );

    test_binary_op!(
        zero ^ zero == zero,
        zero ^ one == one,
        one ^ zero == one,
        one ^ one == zero
    );

    test_binary_op!(
        zero + zero == zero,
        zero + one == one,
        one + zero == one,
        one + one == one
    );

    test_binary_op!(
        zero - zero == zero,
        zero - one == zero,
        one - zero == one,
        one - one == zero
    );

    test_binary_op!(
        zero * zero == zero,
        zero * one == zero,
        one * zero == zero,
        one * one == one
    );

    // Concrete IPv4 cases. String syntax: `"addr/len"` is a single prefix;
    // `"addr/len,lo,hi"` is every prefix covered by `addr/len` whose length
    // lies in `lo..=hi`.
    test_exprs!( @ipv4 {
        intersect_disjoint_nodes {
            vec!["1.0.0.0/8,8,16"].into_iter().collect::<Set<_>>()
                & vec!["2.0.0.0/8,8,16"].into_iter().collect(),
            Set::zero()
        };
        intersect_disjoint_ranges {
            vec!["1.0.0.0/8,8,11"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/8,12,15"].into_iter().collect(),
            Set::zero()
        };
        intersect_overlapping_nodes {
            vec!["1.0.0.0/8,12,16"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/12,12,16"].into_iter().collect(),
            vec!["1.0.0.0/12,12,16"].into_iter().collect()
        };
        intersect_overlapping_ranges {
            vec!["1.0.0.0/8,8,12"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/8,12,16"].into_iter().collect(),
            vec!["1.0.0.0/8,12,12"].into_iter().collect()
        };
        intersect_overlapping_set_with_parent {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/16"].into_iter().collect(),
            vec!["1.0.0.0/16"].into_iter().collect()
        };
        intersect_overlapping_set_with_sibling {
            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/8"].into_iter().collect(),
            vec!["1.0.0.0/8"].into_iter().collect()
        };
        intersect_overlapping_set_with_child {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/8"].into_iter().collect(),
            vec!["1.0.0.0/8"].into_iter().collect()
        };
        intersect_covering_parent {
            vec!["1.0.0.0/16"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/8,16,16"].into_iter().collect(),
            vec!["1.0.0.0/16"].into_iter().collect()
        };
        intersect_covered_child {
            vec!["1.0.0.0/8,16,16"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/16"].into_iter().collect(),
            vec!["1.0.0.0/16"].into_iter().collect()
        };
        intersect_overlapping_set_with_covered_child {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                & vec!["1.0.0.0/8,16,16"].into_iter().collect(),
            vec!["1.0.0.0/16"].into_iter().collect()
        };
        union_disjoint_nodes {
            vec!["2.0.0.0/8,8,16"].into_iter().collect::<Set<_>>()
                | vec!["3.0.0.0/8,8,16"].into_iter().collect(),
            vec!["2.0.0.0/7,8,16"].into_iter().collect()
        };
        union_disjoint_ranges {
            vec!["1.0.0.0/8,8,11"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/8,12,15"].into_iter().collect(),
            vec!["1.0.0.0/8,8,15"].into_iter().collect()
        };
        union_overlapping_nodes {
            vec!["1.0.0.0/8,12,16"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/12,12,16"].into_iter().collect(),
            vec!["1.0.0.0/8,12,16"].into_iter().collect()
        };
        union_overlapping_ranges {
            vec!["1.0.0.0/8,8,12"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/8,12,16"].into_iter().collect(),
            vec!["1.0.0.0/8,8,16"].into_iter().collect()
        };
        union_overlapping_set_with_parent {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/16"].into_iter().collect(),
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect()
        };
        union_overlapping_set_with_sibling {
            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/8"].into_iter().collect(),
            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect()
        };
        union_overlapping_set_with_child {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/8"].into_iter().collect(),
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect()
        };
        union_covering_parent {
            vec!["1.0.0.0/16"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/8,16,16"].into_iter().collect(),
            vec!["1.0.0.0/8,16,16"].into_iter().collect()
        };
        union_covered_child {
            vec!["1.0.0.0/8,16,16"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/16"].into_iter().collect(),
            vec!["1.0.0.0/8,16,16"].into_iter().collect()
        };
        union_overlapping_set_with_covered_child {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                | vec!["1.0.0.0/8,16,16"].into_iter().collect(),
            vec!["1.0.0.0/8", "1.0.0.0/8,16,16"].into_iter().collect()
        };
        xor_disjoint_nodes {
            vec!["2.0.0.0/8,8,16"].into_iter().collect::<Set<_>>()
                ^ vec!["3.0.0.0/8,8,16"].into_iter().collect(),
            vec!["2.0.0.0/7,8,16"].into_iter().collect()
        };
        xor_disjoint_ranges {
            vec!["1.0.0.0/8,8,11"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/8,12,15"].into_iter().collect(),
            vec!["1.0.0.0/8,8,15"].into_iter().collect()
        };
        xor_overlapping_nodes {
            vec!["1.0.0.0/8,12,16"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/12,12,16"].into_iter().collect(),
            vec![
                "1.16.0.0/12,12,16",
                "1.32.0.0/11,12,16",
                "1.64.0.0/10,12,16",
                "1.128.0.0/9,12,16"
            ].into_iter().collect()
        };
        xor_overlapping_ranges {
            vec!["1.0.0.0/8,8,12"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/8,12,16"].into_iter().collect(),
            vec!["1.0.0.0/8,8,11", "1.0.0.0/8,13,16"].into_iter().collect()
        };
        xor_overlapping_set_with_parent {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/16"].into_iter().collect(),
            vec!["1.0.0.0/8"].into_iter().collect()
        };
        xor_overlapping_set_with_sibling {
            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/8"].into_iter().collect(),
            vec!["2.0.0.0/8"].into_iter().collect()
        };
        xor_overlapping_set_with_child {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/8"].into_iter().collect(),
            vec!["1.0.0.0/16"].into_iter().collect()
        };
        xor_covering_parent {
            vec!["1.0.0.0/16"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/8,16,16"].into_iter().collect(),
            vec![
                "1.1.0.0/16",
                "1.2.0.0/15,16,16",
                "1.4.0.0/14,16,16",
                "1.8.0.0/13,16,16",
                "1.16.0.0/12,16,16",
                "1.32.0.0/11,16,16",
                "1.64.0.0/10,16,16",
                "1.128.0.0/9,16,16",
            ].into_iter().collect()
        };
        xor_covered_child {
            vec!["1.0.0.0/8,16,16"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/16"].into_iter().collect(),
            vec![
                "1.1.0.0/16",
                "1.2.0.0/15,16,16",
                "1.4.0.0/14,16,16",
                "1.8.0.0/13,16,16",
                "1.16.0.0/12,16,16",
                "1.32.0.0/11,16,16",
                "1.64.0.0/10,16,16",
                "1.128.0.0/9,16,16",
            ].into_iter().collect()
        };
        // 1.0.0.0/8 is only on the left, 1.0.0.0/16 is on both sides, and the
        // remaining /16s under 1.0.0.0/8 are only on the right.
        xor_overlapping_set_with_covered_child {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                ^ vec!["1.0.0.0/8,16,16"].into_iter().collect(),
            vec![
                "1.0.0.0/8",
                "1.1.0.0/16",
                "1.2.0.0/15,16,16",
                "1.4.0.0/14,16,16",
                "1.8.0.0/13,16,16",
                "1.16.0.0/12,16,16",
                "1.32.0.0/11,16,16",
                "1.64.0.0/10,16,16",
                "1.128.0.0/9,16,16",
            ].into_iter().collect()
        };
        sub_disjoint_nodes {
            vec!["2.0.0.0/8,8,16"].into_iter().collect::<Set<_>>()
                - vec!["3.0.0.0/8,8,16"].into_iter().collect(),
            vec!["2.0.0.0/8,8,16"].into_iter().collect()
        };
        sub_disjoint_ranges {
            vec!["1.0.0.0/8,8,11"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/8,12,15"].into_iter().collect(),
            vec!["1.0.0.0/8,8,11"].into_iter().collect()
        };
        sub_overlapping_nodes {
            vec!["1.0.0.0/8,12,16"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/12,12,16"].into_iter().collect(),
            vec![
                "1.16.0.0/12,12,16",
                "1.32.0.0/11,12,16",
                "1.64.0.0/10,12,16",
                "1.128.0.0/9,12,16"
            ].into_iter().collect()
        };
        sub_overlapping_ranges {
            vec!["1.0.0.0/8,8,12"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/8,12,16"].into_iter().collect(),
            vec!["1.0.0.0/8,8,11"].into_iter().collect()
        };
        sub_overlapping_set_with_parent {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/16"].into_iter().collect(),
            vec!["1.0.0.0/8"].into_iter().collect()
        };
        sub_overlapping_set_with_sibling {
            vec!["1.0.0.0/8", "2.0.0.0/8"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/8"].into_iter().collect(),
            vec!["2.0.0.0/8"].into_iter().collect()
        };
        sub_overlapping_set_with_child {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/8"].into_iter().collect(),
            vec!["1.0.0.0/16"].into_iter().collect()
        };
        sub_covering_parent {
            vec!["1.0.0.0/16"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/8,16,16"].into_iter().collect(),
            Set::zero()
        };
        sub_covered_child {
            vec!["1.0.0.0/8,16,16"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/16"].into_iter().collect(),
            vec![
                "1.1.0.0/16",
                "1.2.0.0/15,16,16",
                "1.4.0.0/14,16,16",
                "1.8.0.0/13,16,16",
                "1.16.0.0/12,16,16",
                "1.32.0.0/11,16,16",
                "1.64.0.0/10,16,16",
                "1.128.0.0/9,16,16",
            ].into_iter().collect()
        };
        // Only 1.0.0.0/16 is removed; 1.0.0.0/8 (length 8) is not covered by
        // the right-hand /16 range, and a difference can never contain
        // prefixes absent from the left operand.
        sub_overlapping_set_with_covered_child {
            vec!["1.0.0.0/8", "1.0.0.0/16"].into_iter().collect::<Set<_>>()
                - vec!["1.0.0.0/8,16,16"].into_iter().collect(),
            vec!["1.0.0.0/8"].into_iter().collect()
        };
        sub_complex_deaggregation {
            vec!["2.0.0.0/8,8,10", "3.0.0.0/8,8,9"].into_iter().collect::<Set<_>>()
                - vec!["2.0.0.0/10", "3.0.0.0/8,8,10"].into_iter().collect(),
            vec![
                "2.0.0.0/8,8,9",
                "2.64.0.0/10",
                "2.128.0.0/10",
                "2.192.0.0/10",
            ].into_iter().collect()
        };
        not_singleton {
            ! vec!["1.0.0.0/8"].into_iter().collect::<Set<_>>(),
            vec![
                "0.0.0.0/0,0,7",
                "0.0.0.0/0,9,32",
                "0.0.0.0/8",
                "2.0.0.0/7,8,8",
                "4.0.0.0/6,8,8",
                "8.0.0.0/5,8,8",
                "16.0.0.0/4,8,8",
                "32.0.0.0/3,8,8",
                "64.0.0.0/2,8,8",
                "128.0.0.0/1,8,8"
            ].into_iter().collect()
        };
        not_range {
            ! vec!["1.0.0.0/8,8,16"].into_iter().collect::<Set<_>>(),
            vec![
                "0.0.0.0/0,0,7",
                "0.0.0.0/0,17,32",
                "0.0.0.0/8,8,16",
                "2.0.0.0/7,8,16",
                "4.0.0.0/6,8,16",
                "8.0.0.0/5,8,16",
                "16.0.0.0/4,8,16",
                "32.0.0.0/3,8,16",
                "64.0.0.0/2,8,16",
                "128.0.0.0/1,8,16",
            ].into_iter().collect()
        }
    });
}