Skip to main content

prefix_trie/trieview/
intersection.rs

1//! Intersection set-operation view.
2//!
3//! [`IntersectionView`] yields every prefix present in **both** the left and right views.
4//! Both `data_bitmap()` and `child_bitmap()` are the AND of the two sides' bitmaps, so
5//! only prefixes/children present in both views are visited.
6//!
7//! `IntersectionView::new` returns `None` for structurally disjoint sub-tries, so every
8//! live `IntersectionView` has aligned, overlapping views.
9
10use std::marker::PhantomData;
11
12use crate::{prefix::mask_from_prefix_len, AsView, Prefix};
13
14use super::{TrieView, ViewIter};
15
/// An immutable view over the intersection of two [`TrieView`]s.
///
/// Returned as `Option<IntersectionView<'_, L, R>>` by [`TrieView::intersection`].
/// `None` means the two sub-tries are disjoint (no common prefixes possible).
///
/// A live `IntersectionView` can be iterated directly (implements [`IntoIterator`])
/// or composed with further set operations before iterating.
///
/// The position reported by the view (`depth`, `key`, `prefix_len`) is taken from the
/// left operand; the data and child bitmaps are the bitwise AND of both operands.
#[derive(Clone)]
pub struct IntersectionView<'a, L, R> {
    /// Left operand; also the source of the position this view reports.
    left: L,
    /// Right operand; `new` aligns it with `left` before the view is constructed.
    right: R,
    /// Ties the view to the lifetime `'a` without storing any data.
    _phantom: PhantomData<&'a ()>,
}
29
30impl<'a, L, R> IntersectionView<'a, L, R>
31where
32    L: TrieView<'a>,
33    R: TrieView<'a, P = L::P>,
34{
35    /// Construct an `IntersectionView`, aligning the two views to the same depth.
36    ///
37    /// Returns `None` when:
38    /// - The key prefixes diverge at the shallower prefix_len (disjoint sub-tries), or
39    /// - The deeper side has no matching sub-trie at the shallower side's key.
40    pub(crate) fn new(left: L, right: R) -> Option<Self> {
41        let (left, right) = align(left, right)?;
42        Some(Self {
43            left,
44            right,
45            _phantom: PhantomData,
46        })
47    }
48}
49
50/// Align two views to the same depth by navigating the shallower one toward the deeper one.
51///
52/// Returns `None` if the key prefixes diverge (disjoint sub-tries).
53fn align<'a, L, R>(left: L, right: R) -> Option<(L, R)>
54where
55    L: TrieView<'a>,
56    R: TrieView<'a, P = L::P>,
57{
58    // Check key agreement at the shallower prefix_len.
59    let min_prefix_len = left.prefix_len().min(right.prefix_len());
60    let mask = mask_from_prefix_len(min_prefix_len as u8);
61    if left.key() & mask != right.key() & mask {
62        return None; // diverging keys -> disjoint sub-tries
63    }
64
65    // Navigate the shallower side toward the deeper one.
66    if left.depth() < right.depth() {
67        let left = left.navigate_to(right.key(), right.prefix_len())?;
68        Some((left, right))
69    } else if right.depth() < left.depth() {
70        let right = right.navigate_to(left.key(), left.prefix_len())?;
71        Some((left, right))
72    } else if left.prefix_len() < right.prefix_len() {
73        let left = left.navigate_to(right.key(), right.prefix_len())?;
74        Some((left, right))
75    } else if right.prefix_len() < left.prefix_len() {
76        let right = right.navigate_to(left.key(), left.prefix_len())?;
77        Some((left, right))
78    } else {
79        Some((left, right))
80    }
81}
82
83impl<'a, L, R> TrieView<'a> for IntersectionView<'a, L, R>
84where
85    L: TrieView<'a>,
86    R: TrieView<'a, P = L::P>,
87{
88    type P = L::P;
89    type T = (L::T, R::T);
90
91    #[inline]
92    fn depth(&self) -> u32 {
93        self.left.depth()
94    }
95
96    #[inline]
97    fn key(&self) -> <L::P as Prefix>::R {
98        self.left.key()
99    }
100
101    #[inline]
102    fn prefix_len(&self) -> u32 {
103        self.left.prefix_len()
104    }
105
106    /// Only bits present in **both** data bitmaps.
107    #[inline]
108    fn data_bitmap(&self) -> u32 {
109        self.left.data_bitmap() & self.right.data_bitmap()
110    }
111
112    /// Only children present in **both** child bitmaps.
113    #[inline]
114    fn child_bitmap(&self) -> u32 {
115        self.left.child_bitmap() & self.right.child_bitmap()
116    }
117
118    #[inline]
119    unsafe fn get_data(&mut self, data_bit: u32) -> (L::T, R::T) {
120        // SAFETY: caller guarantees data_bit is set in data_bitmap() and called at most once.
121        unsafe { (self.left.get_data(data_bit), self.right.get_data(data_bit)) }
122    }
123
124    #[inline]
125    unsafe fn get_child(&mut self, child_bit: u32) -> Self {
126        // SAFETY: caller guarantees child_bit is set in child_bitmap() and called at most once.
127        unsafe {
128            Self {
129                left: self.left.get_child(child_bit),
130                right: self.right.get_child(child_bit),
131                _phantom: PhantomData,
132            }
133        }
134    }
135
136    #[inline]
137    unsafe fn reposition(&mut self, key: <L::P as Prefix>::R, prefix_len: u32) {
138        // SAFETY: caller ensures non-overlapping use with existing cursors.
139        unsafe {
140            self.left.reposition(key, prefix_len);
141            self.right.reposition(key, prefix_len);
142        }
143    }
144}
145
146impl<'a, L, R> IntoIterator for IntersectionView<'a, L, R>
147where
148    L: TrieView<'a>,
149    R: TrieView<'a, P = L::P>,
150{
151    type Item = (L::P, (L::T, R::T));
152    type IntoIter = ViewIter<'a, IntersectionView<'a, L, R>>;
153
154    fn into_iter(self) -> Self::IntoIter {
155        self.iter()
156    }
157}
158
159impl<'a, L, R> AsView<'a> for IntersectionView<'a, L, R>
160where
161    L: TrieView<'a>,
162    R: TrieView<'a, P = L::P>,
163{
164    type P = L::P;
165    type View = Self;
166
167    fn view(self) -> Self {
168        self
169    }
170}
171
#[cfg(test)]
mod tests {
    // Unit tests for intersection views, driven through `PrefixMap` views keyed by
    // `(u32, u8)` prefixes.
    use crate::{
        Prefix,
        {
            trieview::{AsView, TrieView},
            PrefixMap,
        },
    };

    // Prefix type shared by every test: a `u32` representation plus a `u8` length.
    type P = (u32, u8);

    /// Shorthand for building a prefix from its raw representation and length.
    fn p(repr: u32, len: u8) -> P {
        P::from_repr_len(repr, len)
    }

    /// Build a map by inserting each `(repr, len, value)` triple in order.
    fn map_from(entries: &[(u32, u8, i32)]) -> PrefixMap<P, i32> {
        let mut m = PrefixMap::new();
        for &(repr, len, val) in entries {
            m.insert(p(repr, len), val);
        }
        m
    }

    // Only prefixes stored in both maps are yielded, each paired with both values.
    #[test]
    fn intersection_basic() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0b000000, 8, 9)]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0c000000, 8, 99),
        ]);
        let got: Vec<_> = a
            .view()
            .intersection(b.view())
            .unwrap()
            .into_iter()
            .map(|(p, (l, r))| (p, (*l, *r)))
            .collect();
        assert_eq!(
            got,
            vec![(p(0x0a000000, 8), (1, 10)), (p(0x0a010000, 16), (2, 20))]
        );
    }

    #[test]
    fn intersection_no_common_entries() {
        // 10.0.0.0/8 and 11.0.0.0/8 share no prefixes. Both root views cover the entire
        // trie, so the intersection is Some; is_non_empty() may be true (shared child
        // subtrees exist at the bitmap level), but iterating yields nothing.
        let a = map_from(&[(0x0a000000, 8, 1)]);
        let b = map_from(&[(0x0b000000, 8, 2)]);
        let isect = a.view().intersection(b.view()).unwrap();
        assert!(isect.into_iter().next().is_none());
    }

    #[test]
    fn intersection_disjoint_subviews_is_none() {
        // Viewing at sub-prefixes that are structurally disjoint -> None.
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let b = map_from(&[(0x0b000000, 8, 10), (0x0b010000, 16, 20)]);
        let va = a.view_at(&p(0x0a000000, 8)).unwrap();
        let vb = b.view_at(&p(0x0b000000, 8)).unwrap();
        assert!(va.intersection(vb).is_none());
    }

    // The view is consumed directly by a `for` loop, exercising its `IntoIterator` impl.
    #[test]
    fn intersection_into_iter_for_loop() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0a010000, 16, 20)]);
        let mut count = 0;
        if let Some(isect) = a.view().intersection(b.view()) {
            for (_prefix, (_l, _r)) in isect {
                count += 1;
            }
        }
        assert_eq!(count, 2);
    }

    #[test]
    fn intersection_composed() {
        // (a ∩ b) ∩ c -> tests that IntersectionView itself implements TrieView
        // and can be fed into another intersection.
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0b000000, 8, 3)]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0c000000, 8, 30),
        ]);
        let c = map_from(&[(0x0a000000, 8, 100), (0x0b000000, 8, 200)]);

        // a ∩ b gives {10.0.0.0/8, 10.1.0.0/16}; ∩ c keeps only {10.0.0.0/8}
        let ab = a.view().intersection(b.view()).unwrap();
        // Items nest as ((a, b), c); the value from `b` is ignored in the comparison.
        let got: Vec<_> = ab
            .intersection(c.view())
            .unwrap()
            .into_iter()
            .map(|(p, ((l, _m), r))| (p, (*l, *r)))
            .collect();
        assert_eq!(got, vec![(p(0x0a000000, 8), (1, 100))]);
    }

    #[test]
    fn intersection_find_then_iter() {
        // Build maps with many entries across two sub-tries; intersect then find a sub-prefix.
        let a = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a010100, 24, 3),
            (0x0a020000, 16, 4),
            (0x0b000000, 8, 5),
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0a010100, 24, 30),
            (0x0a030000, 16, 40),
            (0x0c000000, 8, 50),
        ]);

        // Intersect: common = {10.0.0.0/8, 10.1.0.0/16, 10.1.1.0/24}
        let isect = a.view().intersection(b.view()).unwrap();

        // find(10.1.0.0/16) on the intersection -> should yield 10.1.0.0/16 and 10.1.1.0/24
        let sub: Vec<_> = isect
            .find(&p(0x0a010000, 16))
            .unwrap()
            .into_iter()
            .map(|(p, (l, r))| (p, (*l, *r)))
            .collect();
        assert_eq!(
            sub,
            vec![(p(0x0a010000, 16), (2, 20)), (p(0x0a010100, 24), (3, 30))]
        );
    }

    // `find_exact` succeeds only for a prefix stored on both sides, and `value()` then
    // returns the pair of values.
    #[test]
    fn intersection_find_exact_and_value() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a010100, 24, 3)]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0a020000, 16, 40), // not in a
        ]);

        let isect = a.view().intersection(b.view()).unwrap();

        // find_exact on a prefix present in both
        let v = isect.clone().find_exact(&p(0x0a010000, 16)).unwrap();
        let (l, r) = v.value().unwrap();
        assert_eq!((*l, *r), (2, 20));

        // find_exact on a prefix present only in a (not in b) -> None
        assert!(isect.find_exact(&p(0x0a010100, 24)).is_none());
    }

    // With a mutable left view, the longest-prefix-match value can be updated in place
    // through the intersection; the write is then visible in the underlying map.
    #[test]
    fn intersection_mut_find_lpm_value_does_not_require_clone() {
        let mut a = map_from(&[(0x0a000000, 8, 1), (0x0a010100, 24, 3)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0a010100, 24, 30)]);

        let got = (&mut a)
            .view()
            .intersection(b.view())
            .unwrap()
            .find_lpm_value(&p(0x0a010180, 25))
            .map(|(prefix, (left, right))| {
                *left += *right;
                (prefix, *left, *right)
            });

        assert_eq!(got, Some((p(0x0a010100, 24), 33, 30)));
        assert_eq!(a.get(&p(0x0a010100, 24)), Some(&33));
    }

    // -- iter_from on intersection views ----------------------------------------

    // Inclusive start: the starting prefix itself is part of the output.
    #[test]
    fn intersection_iter_from_inclusive() {
        let a = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a020000, 16, 3),
            (0x0a030000, 16, 4),
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a020000, 16, 30),
            (0x0a030000, 16, 40),
        ]);

        // Intersection: 10/8, 10.2/16, 10.3/16
        let isect = a.view().intersection(b.view()).unwrap();
        let from: Vec<_> = isect
            .iter_from(&p(0x0a020000, 16), true)
            .map(|(p, (l, r))| (p, (*l, *r)))
            .collect();
        assert_eq!(
            from,
            vec![(p(0x0a020000, 16), (3, 30)), (p(0x0a030000, 16), (4, 40))]
        );
    }

    // Exclusive start: the starting prefix is skipped and iteration resumes after it.
    #[test]
    fn intersection_iter_from_exclusive() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a020000, 16, 3)]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0a020000, 16, 30),
        ]);

        let isect = a.view().intersection(b.view()).unwrap();
        let from: Vec<_> = isect
            .iter_from(&p(0x0a000000, 8), false)
            .map(|(p, (l, r))| (p, (*l, *r)))
            .collect();
        assert_eq!(
            from,
            vec![(p(0x0a010000, 16), (2, 20)), (p(0x0a020000, 16), (3, 30))]
        );
    }

    // `iter_from` on an intersection built from two sub-views stays within the sub-view.
    #[test]
    fn intersection_iter_from_subview() {
        let a = map_from(&[
            (0x0a000000, 8, 1), // excluded by sub-view
            (0x0a020000, 16, 2),
            (0x0a030000, 16, 3),
            (0x0b000000, 8, 4), // excluded by sub-view
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10), // excluded by sub-view
            (0x0a020000, 16, 20),
            (0x0a030000, 16, 30),
        ]);

        // Sub-view at 10.2.0.0/15 covers 10.2–10.3, excludes 10/8, 11/8
        // Intersection: 10.2/16, 10.3/16
        let isect = a
            .view_at(&p(0x0a020000, 15))
            .unwrap()
            .intersection(b.view_at(&p(0x0a020000, 15)).unwrap())
            .unwrap();

        let all: Vec<_> = isect
            .clone()
            .iter()
            .map(|(p, (l, r))| (p, (*l, *r)))
            .collect();
        assert_eq!(
            all,
            vec![(p(0x0a020000, 16), (2, 20)), (p(0x0a030000, 16), (3, 30))]
        );

        // iter_from exclusive at 10.2/16 → only 10.3/16
        let from: Vec<_> = isect
            .iter_from(&p(0x0a020000, 16), false)
            .map(|(p, (l, r))| (p, (*l, *r)))
            .collect();
        assert_eq!(from, vec![(p(0x0a030000, 16), (3, 30))]);
    }
}