Skip to main content

prefix_trie/trieview/
union.rs

1//! Union set-operation view.
2//!
3//! [`UnionView`] yields every prefix present in **either** the left or right view.
4//! The two views are **not** aligned at construction: if one is shallower (smaller `depth`),
5//! it leads the traversal on its own, and the deeper side is incorporated only once the
6//! traversal descends to its depth.
7//!
8//! - Same depth:   `data_bitmap = L | R`,  `child_bitmap = L | R`
9//! - L shallower:  `data_bitmap = L only`, `child_bitmap = L | (1 << toward_R)`
10//! - R shallower:  mirror of above
11
12use std::marker::PhantomData;
13
14use num_traits::PrimInt;
15
16use crate::{
17    prefix::mask_from_prefix_len,
18    Prefix,
19    {
20        node::{child_bit as node_child_bit, extend_repr},
21        table::K,
22        AsView,
23    },
24};
25
26use super::{TrieView, ViewIter};
27
/// The element type produced when iterating a [`UnionView`].
///
/// Every prefix of the union is tagged with the side(s) that contain it: only the left
/// view, only the right view, or both of them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UnionItem<L, R> {
    /// The prefix exists in the left view but not in the right one.
    Left(L),
    /// The prefix exists in the right view but not in the left one.
    Right(R),
    /// The prefix exists in both views (left element first, then right element).
    Both(L, R),
}

impl<L, R> UnionItem<L, R> {
    /// Borrow the left element; `None` for [`UnionItem::Right`].
    pub fn left(&self) -> Option<&L> {
        if let UnionItem::Left(l) | UnionItem::Both(l, _) = self {
            Some(l)
        } else {
            None
        }
    }

    /// Borrow the right element; `None` for [`UnionItem::Left`].
    pub fn right(&self) -> Option<&R> {
        if let UnionItem::Right(r) | UnionItem::Both(_, r) = self {
            Some(r)
        } else {
            None
        }
    }

    /// Borrow both elements; `None` unless this is [`UnionItem::Both`].
    pub fn both(&self) -> Option<(&L, &R)> {
        if let UnionItem::Both(l, r) = self {
            Some((l, r))
        } else {
            None
        }
    }

    /// Take the left element, dropping the right one if it is present.
    pub fn into_left(self) -> Option<L> {
        self.into_both().0
    }

    /// Take the right element, dropping the left one if it is present.
    pub fn into_right(self) -> Option<R> {
        self.into_both().1
    }

    /// Split into `(left, right)`; each half is `Some` exactly when that side is present.
    pub fn into_both(self) -> (Option<L>, Option<R>) {
        match self {
            UnionItem::Both(l, r) => (Some(l), Some(r)),
            UnionItem::Left(l) => (Some(l), None),
            UnionItem::Right(r) => (None, Some(r)),
        }
    }
}
92
93/// An immutable view over the union of two [`TrieView`]s.
94///
95/// Returned by [`TrieView::union`]. The two views are **not** aligned at
96/// construction: if one is shallower (smaller `depth`) than the other, it
97/// continues iterating on its own, incorporating the deeper view only once the
98/// traversal reaches the deeper view's multi-bit node depth.
99///
100/// Yields every prefix present in **either** sub-trie in lexicographic order,
101/// with [`UnionItem::Left`], [`UnionItem::Right`], or [`UnionItem::Both`] indicating
102/// membership.
103#[derive(Clone)]
104pub struct UnionView<'a, L, R>
105where
106    L: TrieView<'a>,
107    R: TrieView<'a, P = L::P>,
108{
109    left: Option<L>,
110    right: Option<R>,
111    depth: u32,
112    key: <<L as TrieView<'a>>::P as Prefix>::R,
113    prefix_len: u32,
114    _phantom: PhantomData<&'a ()>,
115}
116
117impl<'a, L, R> UnionView<'a, L, R>
118where
119    L: TrieView<'a>,
120    R: TrieView<'a, P = L::P>,
121{
122    /// Construct a `UnionView` from two views without aligning their depths.
123    pub(crate) fn new(left: L, right: R) -> Self {
124        let (key, prefix_len) = common_prefix::<L::P>(
125            left.key(),
126            left.prefix_len(),
127            right.key(),
128            right.prefix_len(),
129        );
130        let depth = (prefix_len / K) * K;
131        Self {
132            left: Some(left),
133            right: Some(right),
134            depth,
135            key,
136            prefix_len,
137            _phantom: PhantomData,
138        }
139    }
140}
141
142impl<'a, L, R> TrieView<'a> for UnionView<'a, L, R>
143where
144    L: TrieView<'a>,
145    R: TrieView<'a, P = L::P>,
146{
147    type P = L::P;
148    type T = UnionItem<L::T, R::T>;
149
150    #[inline]
151    fn depth(&self) -> u32 {
152        self.depth
153    }
154
155    #[inline]
156    fn key(&self) -> <L::P as Prefix>::R {
157        self.key
158    }
159
160    #[inline]
161    fn prefix_len(&self) -> u32 {
162        self.prefix_len
163    }
164
165    #[inline]
166    fn data_bitmap(&self) -> u32 {
167        side_data_bitmap(&self.left, self.depth) | side_data_bitmap(&self.right, self.depth)
168    }
169
170    #[inline]
171    fn child_bitmap(&self) -> u32 {
172        side_child_bitmap(&self.left, self.depth) | side_child_bitmap(&self.right, self.depth)
173    }
174
175    #[inline]
176    unsafe fn get_data(&mut self, data_bit: u32) -> UnionItem<L::T, R::T> {
177        match (self.left.as_mut(), self.right.as_mut()) {
178            (Some(l), Some(r)) => {
179                if l.depth() == self.depth && r.depth() == self.depth {
180                    let in_l = (l.data_bitmap() >> data_bit) & 1 == 1;
181                    let in_r = (r.data_bitmap() >> data_bit) & 1 == 1;
182                    match (in_l, in_r) {
183                        (true, true) => UnionItem::Both(l.get_data(data_bit), r.get_data(data_bit)),
184                        (true, false) => UnionItem::Left(l.get_data(data_bit)),
185                        (false, true) => UnionItem::Right(r.get_data(data_bit)),
186                        (false, false) => unreachable!("get_data on bit absent from data_bitmap"),
187                    }
188                } else if l.depth() == self.depth {
189                    UnionItem::Left(l.get_data(data_bit))
190                } else if r.depth() == self.depth {
191                    UnionItem::Right(r.get_data(data_bit))
192                } else {
193                    unreachable!("get_data on virtual UnionView root")
194                }
195            }
196            (Some(l), None) => UnionItem::Left(l.get_data(data_bit)),
197            (None, Some(r)) => UnionItem::Right(r.get_data(data_bit)),
198            (None, None) => unreachable!("get_data on empty UnionView"),
199        }
200    }
201
202    unsafe fn get_child(&mut self, child_bit: u32) -> Self {
203        let new_depth = self.depth + K;
204        let new_key = extend_repr(self.key, self.depth, child_bit);
205        UnionView {
206            left: take_child(&mut self.left, self.depth, child_bit),
207            right: take_child(&mut self.right, self.depth, child_bit),
208            depth: new_depth,
209            key: new_key,
210            prefix_len: new_depth,
211            _phantom: PhantomData,
212        }
213    }
214
215    unsafe fn reposition(&mut self, key: <L::P as Prefix>::R, prefix_len: u32) {
216        reposition_side(&mut self.left, self.depth, key, prefix_len);
217        reposition_side(&mut self.right, self.depth, key, prefix_len);
218        self.key = key;
219        self.prefix_len = prefix_len;
220    }
221}
222
223fn common_prefix<P: Prefix>(
224    left_key: P::R,
225    left_len: u32,
226    right_key: P::R,
227    right_len: u32,
228) -> (P::R, u32) {
229    let max_len = left_len.min(right_len);
230    let diff = (left_key & mask_from_prefix_len(max_len as u8))
231        ^ (right_key & mask_from_prefix_len(max_len as u8));
232    let len = diff.leading_zeros().min(max_len);
233    (left_key & mask_from_prefix_len(len as u8), len)
234}
235
236fn paths_overlap<P: Prefix>(
237    left_key: P::R,
238    left_len: u32,
239    right_key: P::R,
240    right_len: u32,
241) -> bool {
242    let min_len = left_len.min(right_len);
243    let mask = mask_from_prefix_len(min_len as u8);
244    left_key & mask == right_key & mask
245}
246
247fn side_data_bitmap<'a, V: TrieView<'a>>(side: &Option<V>, depth: u32) -> u32 {
248    match side {
249        Some(view) if view.depth() == depth => view.data_bitmap(),
250        _ => 0,
251    }
252}
253
254fn side_child_bitmap<'a, V: TrieView<'a>>(side: &Option<V>, depth: u32) -> u32 {
255    match side {
256        Some(view) if view.depth() == depth => view.child_bitmap(),
257        Some(view) if view.depth() > depth => 1 << node_child_bit(depth, view.key()),
258        _ => 0,
259    }
260}
261
262unsafe fn take_child<'a, V: TrieView<'a>>(
263    side: &mut Option<V>,
264    depth: u32,
265    child_bit: u32,
266) -> Option<V> {
267    let view = side.as_mut()?;
268    if view.depth() == depth {
269        if (view.child_bitmap() >> child_bit) & 1 == 1 {
270            Some(view.get_child(child_bit))
271        } else {
272            None
273        }
274    } else if view.depth() > depth {
275        if child_bit == node_child_bit(depth, view.key()) {
276            side.take()
277        } else {
278            None
279        }
280    } else {
281        None
282    }
283}
284
285unsafe fn reposition_side<'a, V: TrieView<'a>>(
286    side: &mut Option<V>,
287    union_depth: u32,
288    key: <<V as TrieView<'a>>::P as Prefix>::R,
289    prefix_len: u32,
290) {
291    let Some(view) = side.as_mut() else {
292        return;
293    };
294    if !paths_overlap::<V::P>(view.key(), view.prefix_len(), key, prefix_len) {
295        *side = None;
296    } else if view.depth() == union_depth && prefix_len >= view.prefix_len() {
297        view.reposition(key, prefix_len);
298    }
299}
300
301impl<'a, L, R> IntoIterator for UnionView<'a, L, R>
302where
303    L: TrieView<'a>,
304    R: TrieView<'a, P = L::P>,
305{
306    type Item = (L::P, UnionItem<L::T, R::T>);
307    type IntoIter = ViewIter<'a, UnionView<'a, L, R>>;
308
309    fn into_iter(self) -> Self::IntoIter {
310        self.iter()
311    }
312}
313
314impl<'a, L, R> AsView<'a> for UnionView<'a, L, R>
315where
316    L: TrieView<'a>,
317    R: TrieView<'a, P = L::P>,
318{
319    type P = L::P;
320    type View = Self;
321
322    fn view(self) -> Self {
323        self
324    }
325}
326
#[cfg(test)]
mod tests {
    use crate::{
        trieview::{AsView, TrieView},
        Prefix, PrefixMap,
    };

    use super::UnionItem;

    type P = (u32, u8);

    /// Build a prefix from its raw representation and length.
    fn p(repr: u32, len: u8) -> P {
        P::from_repr_len(repr, len)
    }

    /// Build a map from `(repr, len, value)` triples.
    fn map_from(entries: &[(u32, u8, i32)]) -> PrefixMap<P, i32> {
        let mut m = PrefixMap::new();
        for &(repr, len, val) in entries {
            m.insert(p(repr, len), val);
        }
        m
    }

    /// Flatten union items into `(prefix, left value, right value)` triples so they can be
    /// compared against plain literals.
    fn collect_union<'a>(
        iter: impl Iterator<Item = (P, UnionItem<&'a i32, &'a i32>)>,
    ) -> Vec<(P, Option<i32>, Option<i32>)> {
        iter.map(|(p, item)| match item {
            UnionItem::Left(l) => (p, Some(*l), None),
            UnionItem::Right(r) => (p, None, Some(*r)),
            UnionItem::Both(l, r) => (p, Some(*l), Some(*r)),
        })
        .collect()
    }

    /// Left-hand fixture shared by the "large" tests.
    fn large_left() -> PrefixMap<P, i32> {
        map_from(&[
            (0x01000000, 8, 1),
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 11),
            (0x0a020000, 16, 12),
            (0x0a010100, 24, 13),
            (0x64000000, 8, 100),
            (0x64010000, 16, 101),
            (0xc0a80000, 16, 200), // 192.168.0.0/16
        ])
    }

    /// Right-hand fixture shared by the "large" tests.
    fn large_right() -> PrefixMap<P, i32> {
        map_from(&[
            (0x0a000000, 8, 20),  // overlaps left
            (0x0a010000, 16, 21), // overlaps left
            (0x0a030000, 16, 22), // new in right
            (0x0b000000, 8, 30),
            (0x0b010000, 16, 31),
            (0x64000000, 8, 110),  // overlaps left
            (0xc0a80100, 24, 210), // 192.168.1.0/24 -> new in right
        ])
    }

    // -- Same-depth cases ------------------------------------------------------

    #[test]
    fn union_disjoint() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let b = map_from(&[(0x0b000000, 8, 10), (0x0b010000, 16, 20)]);
        let got = collect_union(a.view().union(b.view()).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x0a000000, 8), Some(1), None),
                (p(0x0a010000, 16), Some(2), None),
                (p(0x0b000000, 8), None, Some(10)),
                (p(0x0b010000, 16), None, Some(20)),
            ]
        );
    }

    #[test]
    fn union_overlapping() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0b000000, 8, 20)]);
        let got = collect_union(a.view().union(b.view()).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x0a000000, 8), Some(1), Some(10)),
                (p(0x0a010000, 16), Some(2), None),
                (p(0x0b000000, 8), None, Some(20)),
            ]
        );
    }

    #[test]
    fn union_identical() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0a010000, 16, 20)]);
        let got = collect_union(a.view().union(b.view()).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x0a000000, 8), Some(1), Some(10)),
                (p(0x0a010000, 16), Some(2), Some(20)),
            ]
        );
    }

    #[test]
    fn union_one_empty() {
        let a = map_from(&[(0x0a000000, 8, 1)]);
        let b: PrefixMap<P, i32> = PrefixMap::new();
        let got = collect_union(a.view().union(b.view()).into_iter());
        assert_eq!(got, vec![(p(0x0a000000, 8), Some(1), None)]);
    }

    #[test]
    fn union_both_empty() {
        let a: PrefixMap<P, i32> = PrefixMap::new();
        let b: PrefixMap<P, i32> = PrefixMap::new();
        assert!(a.view().union(b.view()).into_iter().next().is_none());
    }

    #[test]
    fn union_into_iter_for_loop() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0b000000, 8, 20)]);
        let mut count = 0;
        for (_p, _item) in a.view().union(b.view()) {
            count += 1;
        }
        assert_eq!(count, 3);
    }

    /// Larger same-depth test: many entries spread across the address space,
    /// covering Left-only, Right-only, and Both across multiple subtries and levels.
    #[test]
    fn union_large_same_depth() {
        let a = large_left();
        let b = large_right();
        let got = collect_union(a.view().union(b.view()).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x01000000, 8), Some(1), None),        // Left
                (p(0x0a000000, 8), Some(10), Some(20)),   // Both
                (p(0x0a010000, 16), Some(11), Some(21)),  // Both
                (p(0x0a010100, 24), Some(13), None),      // Left
                (p(0x0a020000, 16), Some(12), None),      // Left
                (p(0x0a030000, 16), None, Some(22)),      // Right
                (p(0x0b000000, 8), None, Some(30)),       // Right
                (p(0x0b010000, 16), None, Some(31)),      // Right
                (p(0x64000000, 8), Some(100), Some(110)), // Both
                (p(0x64010000, 16), Some(101), None),     // Left
                (p(0xc0a80000, 16), Some(200), None),     // Left
                (p(0xc0a80100, 24), None, Some(210)),     // Right
            ]
        );
    }

    /// Same data as above, but both sides are restricted to 0.0.0.0/1, which excludes
    /// 192.168.x.x.
    #[test]
    fn union_large_same_depth_view_at() {
        let a = large_left();
        let b = large_right();
        let got = collect_union(
            a.view_at(&p(0x00000000, 1))
                .unwrap()
                .union(b.view_at(&p(0x00000000, 1)).unwrap())
                .into_iter(),
        );
        let want = vec![
            (p(0x01000000, 8), Some(1), None),        // Left
            (p(0x0a000000, 8), Some(10), Some(20)),   // Both
            (p(0x0a010000, 16), Some(11), Some(21)),  // Both
            (p(0x0a010100, 24), Some(13), None),      // Left
            (p(0x0a020000, 16), Some(12), None),      // Left
            (p(0x0a030000, 16), None, Some(22)),      // Right
            (p(0x0b000000, 8), None, Some(30)),       // Right
            (p(0x0b010000, 16), None, Some(31)),      // Right
            (p(0x64000000, 8), Some(100), Some(110)), // Both
            (p(0x64010000, 16), Some(101), None),     // Left
        ];
        assert_eq!(got, want);
    }

    /// Left is restricted to 0.0.0.0/1 and right to 10.0.0.0/8. Right contributes only
    /// entries under 10/8, so every other prefix is Left-only.
    #[test]
    fn union_large_different_depth() {
        let a = large_left();
        let b = large_right();
        let got = collect_union(
            a.view_at(&p(0x00000000, 1))
                .unwrap()
                .union(b.view_at(&p(0x0a000000, 8)).unwrap())
                .into_iter(),
        );
        let want = vec![
            (p(0x01000000, 8), Some(1), None),       // Left
            (p(0x0a000000, 8), Some(10), Some(20)),  // Both
            (p(0x0a010000, 16), Some(11), Some(21)), // Both
            (p(0x0a010100, 24), Some(13), None),     // Left
            (p(0x0a020000, 16), Some(12), None),     // Left
            (p(0x0a030000, 16), None, Some(22)),     // Right
            (p(0x64000000, 8), Some(100), None),     // Left (right view excludes 100/8)
            (p(0x64010000, 16), Some(101), None),    // Left
        ];
        assert_eq!(got, want);
    }

    // -- find / find_lpm on a union --------------------------------------------

    #[test]
    fn union_find_then_iter() {
        let a = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a010100, 24, 3),
            (0x0b000000, 8, 4),
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0a020000, 16, 30),
            (0x0c000000, 8, 40),
        ]);
        // find 10.1.0.0/16 on the union -> a has {10.1/16, 10.1.1/24}, b has {10.1/16}
        let got = collect_union(
            a.view()
                .union(b.view())
                .find(&p(0x0a010000, 16))
                .unwrap()
                .into_iter(),
        );
        assert_eq!(
            got,
            vec![
                (p(0x0a010000, 16), Some(2), Some(20)),
                (p(0x0a010100, 24), Some(3), None),
            ]
        );
    }

    #[test]
    fn union_find_exact_and_value() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0b000000, 8, 20)]);
        let u = a.view().union(b.view());

        // find_exact on a prefix present in both
        let v = u.clone().find_exact(&p(0x0a000000, 8)).unwrap();
        assert!(matches!(v.value().unwrap(), UnionItem::Both(l, r) if *l == 1 && *r == 10));

        // find_exact on a prefix present only in a
        let v2 = u.clone().find_exact(&p(0x0a010000, 16)).unwrap();
        assert!(matches!(v2.value().unwrap(), UnionItem::Left(l) if *l == 2));

        // find_exact on a prefix present only in b
        let v3 = u.clone().find_exact(&p(0x0b000000, 8)).unwrap();
        assert!(matches!(v3.value().unwrap(), UnionItem::Right(r) if *r == 20));

        // find_exact on a prefix in neither
        assert!(u.find_exact(&p(0x0c000000, 8)).is_none());
    }

    #[test]
    fn union_find_lpm_value_keys_values() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let b = map_from(&[(0x0a010100, 24, 30), (0x0b000000, 8, 40)]);
        let u = a.view().union(b.view());

        let lpm = u.clone().find_lpm(&p(0x0a010180, 25)).unwrap();
        assert_eq!(lpm.prefix(), p(0x0a010100, 24));
        assert!(matches!(lpm.value().unwrap(), UnionItem::Right(r) if *r == 30));

        let got = u
            .clone()
            .find_lpm_value(&p(0x0a010180, 25))
            .map(|(prefix, value)| (prefix, value.into_both()));
        assert!(matches!(
            got,
            Some((prefix, (None, Some(r)))) if prefix == p(0x0a010100, 24) && *r == 30
        ));

        assert_eq!(
            u.clone().keys().collect::<Vec<_>>(),
            vec![
                p(0x0a000000, 8),
                p(0x0a010000, 16),
                p(0x0a010100, 24),
                p(0x0b000000, 8),
            ]
        );
        assert_eq!(u.values().count(), 4);
    }

    #[test]
    fn union_mut_find_lpm_value_does_not_require_clone() {
        let mut a = map_from(&[(0x0a000000, 8, 1), (0x0a010100, 24, 3)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0a010000, 16, 20)]);

        let got = (&mut a)
            .view()
            .union(b.view())
            .find_lpm_value(&p(0x0a010180, 25))
            .map(|(prefix, item)| match item {
                UnionItem::Left(l) => {
                    *l += 10;
                    (prefix, *l)
                }
                other => panic!("expected left-only LPM, got {other:?}"),
            });

        assert_eq!(got, Some((p(0x0a010100, 24), 13)));
        assert_eq!(a.get(&p(0x0a010100, 24)), Some(&13));
    }

    // -- Depth-difference cases ------------------------------------------------
    //
    // With K=5, nodes live at depths 0, 5, 10, 15, 20, 25, 30.
    // `view_at(&p(addr, len))` lands in the node at depth = floor(len / K) * K.
    //   e.g. len=8  -> depth 5    (5 ≤ 8 < 10)
    //        len=16 -> depth 15   (15 ≤ 16 < 20)
    //
    // Tests cover:
    //   (a) "going toward"    : child_bit == toward_deeper: deeper view is carried along
    //   (b) "not going toward": child_bit != toward_deeper: deeper view is dropped
    //   (c) shallower has NO child in the direction of the deeper view

    /// L deeper (depth 5), R shallower (depth 0).
    /// R leads at depth 0; the child toward 10.x carries L along and merges at depth 5.
    /// Children of R toward 9.x and 11.x are NOT toward L -> only R data there.
    #[test]
    fn union_l_deeper_r_shallower_going_toward_and_not() {
        let a = map_from(&[
            (0x0a000000, 8, 1),  // 10.0.0.0/8
            (0x0a010000, 16, 2), // 10.1.0.0/16
            (0x0a020000, 16, 3), // 10.2.0.0/16
        ]);
        let b = map_from(&[
            (0x09000000, 8, 90),  // 9.0.0.0/8  -> NOT toward a_sub
            (0x0a000000, 8, 10),  // 10.0.0.0/8 -> toward a_sub; merged
            (0x0a010000, 16, 20), // 10.1.0.0/16
            (0x0b000000, 8, 30),  // 11.0.0.0/8 -> NOT toward a_sub
        ]);
        let a_sub = a.view_at(&p(0x0a000000, 8)).unwrap(); // depth 5
        let b_root = b.view(); // depth 0

        let got = collect_union(a_sub.union(b_root).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x09000000, 8), None, Some(90)), // not toward a_sub -> b only
                (p(0x0a000000, 8), Some(1), Some(10)), // toward a_sub, merged -> Both
                (p(0x0a010000, 16), Some(2), Some(20)), // merged
                (p(0x0a020000, 16), Some(3), None), // a only
                (p(0x0b000000, 8), None, Some(30)), // not toward a_sub -> b only
            ]
        );
    }

    /// Mirror: L shallower (depth 0), R deeper (depth 5).
    /// L leads; child toward 10.x carries R; children toward 9.x and 11.x drop R.
    #[test]
    fn union_r_deeper_l_shallower_going_toward_and_not() {
        let a = map_from(&[
            (0x09000000, 8, 90),
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0b000000, 8, 30),
        ]);
        let b = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a020000, 16, 3)]);
        let a_root = a.view();
        let b_sub = b.view_at(&p(0x0a000000, 8)).unwrap(); // depth 5

        let got = collect_union(a_root.union(b_sub).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x09000000, 8), Some(90), None), // not toward b_sub -> a only
                (p(0x0a000000, 8), Some(10), Some(1)), // toward b_sub -> Both
                (p(0x0a010000, 16), Some(20), Some(2)),
                (p(0x0a020000, 16), None, Some(3)), // b only
                (p(0x0b000000, 8), Some(30), None), // not toward b_sub -> a only
            ]
        );
    }

    /// The shallower side has **no** child in the direction of the deeper side.
    /// child_bitmap() forcibly adds the bit toward the deeper view (via `| (1 << toward)`),
    /// so the deeper view's entries still appear even though the shallower has no child there.
    #[test]
    fn union_shallower_has_no_child_toward_deeper() {
        // a_root has entries only in 9.x and 11.x -> no 10.x child at all.
        // b_sub is positioned at the 10.x subtrie (depth 5).
        let a = map_from(&[
            (0x09000000, 8, 1), // 9.0.0.0/8
            (0x0b000000, 8, 2), // 11.0.0.0/8
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),  // 10.0.0.0/8
            (0x0a010000, 16, 20), // 10.1.0.0/16
            (0x0a010100, 24, 30), // 10.1.1.0/24
        ]);
        let a_root = a.view();
        let b_sub = b.view_at(&p(0x0a000000, 8)).unwrap(); // depth 5

        let got = collect_union(a_root.union(b_sub).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x09000000, 8), Some(1), None),   // a only
                (p(0x0a000000, 8), None, Some(10)),  // b only -> a had no child here
                (p(0x0a010000, 16), None, Some(20)), // b only
                (p(0x0a010100, 24), None, Some(30)), // b only
                (p(0x0b000000, 8), Some(2), None),   // a only
            ]
        );
    }

    /// L has entries in multiple children; only one child is "toward" R.
    /// Tests that going-toward carries R, not-going-toward drops R.
    #[test]
    fn union_shallower_multiple_children_only_one_toward_deeper() {
        // a_root: entries in 10.x, 11.x, 12.x.
        // b_sub: in 10.x subtrie (depth 5).
        //   child toward 10.x -> merge; children toward 11.x and 12.x -> a only.
        let a = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0b000000, 8, 3),
            (0x0c000000, 8, 4),
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a020000, 16, 20), // 10.2/16 -> only in b
        ]);
        let a_root = a.view();
        let b_sub = b.view_at(&p(0x0a000000, 8)).unwrap(); // depth 5

        let got = collect_union(a_root.union(b_sub).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x0a000000, 8), Some(1), Some(10)), // toward b_sub -> merged
                (p(0x0a010000, 16), Some(2), None),    // a only (b has nothing here)
                (p(0x0a020000, 16), None, Some(20)),   // b only
                (p(0x0b000000, 8), Some(3), None),     // NOT toward b_sub -> a only
                (p(0x0c000000, 8), Some(4), None),     // NOT toward b_sub -> a only
            ]
        );
    }

    /// Three levels of depth difference (depth 15 vs depth 0 with K=5).
    /// Requires three successive get_child calls to align the two sides.
    /// Entries in b's intermediate nodes (depth 5) appear as Right before alignment.
    #[test]
    fn union_multi_level_depth_difference() {
        // a_sub at depth 15 (since K=5 and 15 ≤ 16 < 20).
        // b_root at depth 0.
        // Hops needed: depth 0 -> 5 -> 10 -> 15 (three hops).
        let a = map_from(&[
            (0x0a010000, 16, 1), // 10.1.0.0/16
            (0x0a010100, 24, 2), // 10.1.1.0/24
            (0x0a010200, 24, 3), // 10.1.2.0/24
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),  // 10.0.0.0/8  -> in depth-5 node, appears before alignment
            (0x0a010000, 16, 20), // 10.1.0.0/16 -> aligned at depth 15
            (0x0b000000, 8, 30),  // 11.0.0.0/8
        ]);
        let a_sub = a.view_at(&p(0x0a010000, 16)).unwrap(); // depth 15
        let b_root = b.view(); // depth 0

        let got = collect_union(a_sub.union(b_root).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x0a000000, 8), None, Some(10)),     // b only (in depth-5 node)
                (p(0x0a010000, 16), Some(1), Some(20)), // both  (aligned at depth 15)
                (p(0x0a010100, 24), Some(2), None),     // a only
                (p(0x0a010200, 24), Some(3), None),     // a only
                (p(0x0b000000, 8), None, Some(30)),     // b only
            ]
        );
    }

    /// Both sides are deeper than the node where their common prefix ends.
    /// 10.1/16 and 10.2/16 share only 14 bits, so the union is rooted at depth 10 with no
    /// data of its own. Each child carries exactly one side.
    #[test]
    fn union_both_sides_deeper_than_common_root() {
        let a = map_from(&[
            (0x0a010000, 16, 1), // 10.1.0.0/16
            (0x0a010100, 24, 2), // 10.1.1.0/24
        ]);
        let b = map_from(&[(0x0a020000, 16, 3)]); // 10.2.0.0/16
        let a_sub = a.view_at(&p(0x0a010000, 16)).unwrap(); // depth 15
        let b_sub = b.view_at(&p(0x0a020000, 16)).unwrap(); // depth 15

        let got = collect_union(a_sub.union(b_sub).into_iter());
        assert_eq!(
            got,
            vec![
                (p(0x0a010000, 16), Some(1), None), // a only
                (p(0x0a010100, 24), Some(2), None), // a only
                (p(0x0a020000, 16), None, Some(3)), // b only
            ]
        );
    }

    // -- Composition -----------------------------------------------------------

    #[test]
    fn union_composed_with_intersection() {
        // (a ∪ b) ∩ c -> UnionView implements TrieView so it composes.
        let a = map_from(&[(0x0a000000, 8, 1), (0x0b000000, 8, 2)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0c000000, 8, 20)]);
        let c = map_from(&[(0x0a000000, 8, 100)]);

        let got: Vec<_> = a
            .view()
            .union(&b)
            .intersection(&c)
            .unwrap()
            .into_iter()
            .map(|(p, (u, r))| (p, u, *r))
            .collect();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].0, p(0x0a000000, 8));
        assert_eq!(got[0].2, 100);
        assert!(matches!(got[0].1, UnionItem::Both(l, r) if *l == 1 && *r == 10));
    }

    #[test]
    fn union_composed_union_of_unions() {
        // (a ∪ b) ∪ c -> UnionView<UnionView<..>, ..> works end-to-end.
        let a = map_from(&[(0x0a000000, 8, 1)]);
        let b = map_from(&[(0x0b000000, 8, 2)]);
        let c = map_from(&[(0x0c000000, 8, 3), (0x0a000000, 8, 10)]);

        let count = a.view().union(b.view()).union(c.view()).into_iter().count();
        // 10/8 (in a and c -> Both in ab∪c), 11/8 (b only), 12/8 (c only)
        assert_eq!(count, 3);
    }

    // -- iter_from on union views -----------------------------------------------

    #[test]
    fn union_iter_from_inclusive() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a020000, 16, 3)]);
        let b = map_from(&[(0x0a010000, 16, 20), (0x0a030000, 16, 40)]);

        // Full union: 10/8(L), 10.1/16(B), 10.2/16(L), 10.3/16(R)
        let u = a.view().union(b.view());
        let from: Vec<_> = u
            .iter_from(&p(0x0a010000, 16), true)
            .map(|(p, _)| p)
            .collect();
        assert_eq!(
            from,
            vec![p(0x0a010000, 16), p(0x0a020000, 16), p(0x0a030000, 16)]
        );
    }

    #[test]
    fn union_iter_from_exclusive() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a020000, 16, 3)]);
        let b = map_from(&[(0x0a010000, 16, 20), (0x0a030000, 16, 40)]);

        let u = a.view().union(b.view());
        let from: Vec<_> = u
            .iter_from(&p(0x0a010000, 16), false)
            .map(|(p, _)| p)
            .collect();
        assert_eq!(from, vec![p(0x0a020000, 16), p(0x0a030000, 16)]);
    }

    #[test]
    fn union_iter_from_subview() {
        let a = map_from(&[
            (0x0a000000, 8, 1), // excluded by sub-view
            (0x0a020000, 16, 2),
            (0x0a030000, 16, 3),
            (0x0b000000, 8, 4), // excluded by sub-view
        ]);
        let b = map_from(&[
            (0x0a020000, 16, 20),
            (0x0a030000, 16, 40),
            (0x0b000000, 8, 50), // excluded by sub-view
        ]);

        // Sub-view at 10.2.0.0/15 covers 10.2.x.x–10.3.x.x, excludes 10/8, 11/8
        let u = a
            .view_at(&p(0x0a020000, 15))
            .unwrap()
            .union(b.view_at(&p(0x0a020000, 15)).unwrap());

        // Full union of the sub-views: 10.2/16(B), 10.3/16(B)
        let all: Vec<_> = u.clone().iter().map(|(p, _)| p).collect();
        assert_eq!(all, vec![p(0x0a020000, 16), p(0x0a030000, 16)]);

        // iter_from exclusive from 10.2/16 → only 10.3/16
        let from: Vec<_> = u
            .iter_from(&p(0x0a020000, 16), false)
            .map(|(p, _)| p)
            .collect();
        assert_eq!(from, vec![p(0x0a030000, 16)]);
    }
}