Skip to main content

prefix_trie/trieview/
trie_ref.rs

1//! Concrete [`TrieRef`] cursor implementing [`TrieView`].
2//!
3//! [`TrieRef`] is a lightweight, `Copy`able immutable cursor. All set-operation views
4//! ([`IntersectionView`][super::IntersectionView], [`UnionView`][super::union::UnionView],
5//! etc.) can be composed from `TrieRef` leaves.
6
7use std::marker::PhantomData;
8
9use num_traits::Zero;
10
11use crate::{
12    Prefix,
13    {
14        allocator::Loc,
15        map::PrefixMap,
16        node::{child_cover_mask, data_cover_mask, extend_repr},
17        table::{DataIdx, Table, K},
18    },
19};
20
21use super::{AsView, TrieView, ViewIter};
22
23/// An immutable cursor implementing [`TrieView`].
24///
25/// # Invariant
26///
27/// `depth <= prefix_len < depth + K`, where `K` is the stride of a `MultiBitNode`.
28/// `depth` is always a multiple of `K`.
29/// `key` contains the accumulated bits (only the top `prefix_len` bits are significant).
30pub struct TrieRef<'a, P: Prefix, T> {
31    pub(super) table: &'a Table<T>,
32    /// Location of the `MultiBitNode` that contains this view's root position.
33    pub(super) node_loc: Loc,
34    /// Depth of `node_loc`: always a multiple of `K`.
35    pub(super) depth: u32,
36    /// Accumulated key bits (only the top `prefix_len` bits are significant).
37    pub(super) key: P::R,
38    /// Binary-tree depth of this view's root position.
39    pub(super) prefix_len: u32,
40    pub(super) _marker: PhantomData<P>,
41}
42
43// `TrieRef` holds only a shared reference (`&'a Table<T>`) and a `Loc` + key bits,
44// so it is always `Copy`/`Clone` regardless of whether `P` or `T` are.
45impl<'a, P: Prefix, T> Clone for TrieRef<'a, P, T> {
46    fn clone(&self) -> Self {
47        *self
48    }
49}
50impl<'a, P: Prefix, T> Copy for TrieRef<'a, P, T> {}
51
52impl<'a, P: Prefix, T> TrieRef<'a, P, T> {
53    /// Create a view at the root of the given table.
54    pub(crate) fn new_root(table: &'a Table<T>) -> Self {
55        Self {
56            table,
57            node_loc: Loc::root(),
58            depth: 0,
59            key: P::R::zero(),
60            prefix_len: 0,
61            _marker: PhantomData,
62        }
63    }
64}
65
66impl<'a, P: Prefix, T> TrieView<'a> for TrieRef<'a, P, T> {
67    type P = P;
68    type T = &'a T;
69
70    #[inline]
71    fn depth(&self) -> u32 {
72        self.depth
73    }
74
75    #[inline]
76    fn key(&self) -> P::R {
77        self.key
78    }
79
80    #[inline]
81    fn prefix_len(&self) -> u32 {
82        self.prefix_len
83    }
84
85    #[inline]
86    fn data_bitmap(&self) -> u32 {
87        self.table.node(self.node_loc).data_bitmap()
88            & data_cover_mask(self.depth, self.key, self.prefix_len)
89    }
90
91    #[inline]
92    fn child_bitmap(&self) -> u32 {
93        self.table.node(self.node_loc).child_bitmap()
94            & child_cover_mask(self.depth, self.key, self.prefix_len)
95    }
96
97    #[inline]
98    unsafe fn get_data(&mut self, data_bit: u32) -> &'a T {
99        let idx = DataIdx {
100            node: self.node_loc,
101            bit: data_bit,
102            depth: self.depth,
103        };
104        // Note: `resolve` re-reads the node's current AllocIdx and bitmap, which is strictly
105        // unnecessary here: TrieView iterators never mutate node structure, so `idx.node` is
106        // always a valid, stable location for the lifetime of this view. We use `resolve` for
107        // uniformity with the rest of the codebase; the compiler should inline/optimize it away.
108        //
109        // SAFETY: caller guarantees data_bit is set in data_bitmap().
110        unsafe { idx.resolve(self.table) }
111            .expect("get_data: data_bit not set in bitmap")
112            .get()
113    }
114
115    #[inline]
116    unsafe fn get_child(&mut self, child_bit: u32) -> Self {
117        // SAFETY: `self.node_loc` is maintained as a valid, live node location by the
118        // `TrieView` invariant (no structural mutations during the view's lifetime).
119        let child_loc = unsafe { self.table.child(self.node_loc, child_bit) }
120            .expect("get_child: child_bit not set in bitmap");
121        let new_key = extend_repr(self.key, self.depth, child_bit);
122        Self {
123            table: self.table,
124            node_loc: child_loc,
125            depth: self.depth + K,
126            key: new_key,
127            prefix_len: self.depth + K,
128            _marker: PhantomData,
129        }
130    }
131
132    #[inline]
133    unsafe fn reposition(&mut self, key: P::R, prefix_len: u32) {
134        let _old_prefix = self.prefix();
135        self.key = key;
136        self.prefix_len = prefix_len;
137        // ensure that we always go deeper inside.
138        debug_assert!(_old_prefix.contains(&self.prefix()));
139    }
140}
141
142impl<'a, P: Prefix, T> IntoIterator for TrieRef<'a, P, T> {
143    type Item = (P, &'a T);
144    type IntoIter = ViewIter<'a, TrieRef<'a, P, T>>;
145
146    fn into_iter(self) -> Self::IntoIter {
147        self.iter()
148    }
149}
150
151impl<'a, P: Prefix, T> AsView<'a> for &'a PrefixMap<P, T> {
152    type P = P;
153    type View = TrieRef<'a, P, T>;
154
155    fn view(self) -> TrieRef<'a, P, T> {
156        TrieRef::new_root(self.table())
157    }
158}
159
160impl<'a, P: Prefix, T> AsView<'a> for TrieRef<'a, P, T> {
161    type P = P;
162    type View = TrieRef<'a, P, T>;
163
164    fn view(self) -> TrieRef<'a, P, T> {
165        self
166    }
167}
168
#[cfg(test)]
mod tests {
    use crate::{
        trieview::{AsView, TrieView},
        Prefix, PrefixMap,
    };

    /// Prefix type used throughout: `(address bits, prefix length)`.
    type P = (u32, u8);

    /// Shorthand constructor for a test prefix.
    fn p(repr: u32, len: u8) -> P {
        P::from_repr_len(repr, len)
    }

    /// Build a map from `(repr, len, value)` triples, inserted in the given order.
    fn map_from(entries: &[(u32, u8, i32)]) -> PrefixMap<P, i32> {
        entries
            .iter()
            .fold(PrefixMap::new(), |mut acc, &(repr, len, val)| {
                acc.insert(p(repr, len), val);
                acc
            })
    }

    /// Copy values out of a `(prefix, &value)` stream so results compare by value.
    fn owned<'a>(entries: impl IntoIterator<Item = (P, &'a i32)>) -> Vec<(P, i32)> {
        entries.into_iter().map(|(k, v)| (k, *v)).collect()
    }

    /// Collect `iter_from(start, inclusive)` over a view of the whole map.
    fn from_root(m: &PrefixMap<P, i32>, start: P, inclusive: bool) -> Vec<(P, i32)> {
        owned(m.view().iter_from(&start, inclusive))
    }

    /// 10.0.0.0/8 with 10.1/16, 10.2/16 and 10.1.0.0/24 below it.
    fn nested_map() -> PrefixMap<P, i32> {
        map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a020000, 16, 3),
            (0x0a010000, 24, 4),
        ])
    }

    /// 10.0.0.0/8, 10.1.0.0/16, 10.2.0.0/16, 10.3.0.0/16, 10.4.0.0/16
    fn sibling_map() -> PrefixMap<P, i32> {
        map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a020000, 16, 3),
            (0x0a030000, 16, 4),
            (0x0a040000, 16, 5),
        ])
    }

    #[test]
    fn view_iter_all() {
        let m = nested_map();
        assert_eq!(owned(m.view().iter()), owned(m.iter()));
    }

    #[test]
    fn view_at_subtrie() {
        let m = nested_map();
        let got = m
            .view_at(&p(0x0a010000, 16))
            .map(|sub| owned(sub.iter()))
            .unwrap_or_default();
        assert_eq!(got, vec![(p(0x0a010000, 16), 2), (p(0x0a010000, 24), 4)]);
    }

    #[test]
    fn view_value() {
        let m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        for (prefix, expected) in [(p(0x0a010000, 16), 2), (p(0x0a000000, 8), 1)] {
            let found = m.view().find(&prefix).unwrap();
            assert_eq!(found.value(), Some(&expected));
        }
    }

    #[test]
    fn view_find_exact() {
        let m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 24, 4)]);
        let view = m.view();
        assert!(view.find_exact(&p(0x0a010000, 16)).is_none());
        assert!(view.find_exact(&p(0x0a000000, 8)).is_some());
    }

    #[test]
    fn view_find_exact_value() {
        let m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 24, 4)]);
        let view = m.view();
        assert_eq!(view.find_exact_value(&p(0x0a010000, 16)), None);
        let hit = view
            .find_exact_value(&p(0x0a010000, 24))
            .map(|(k, v)| (k, *v));
        assert_eq!(hit, Some((p(0x0a010000, 24), 4)));
    }

    #[test]
    fn view_find_lpm() {
        let m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a010100, 24, 3)]);
        let view = m.view();

        let deep = view.find_lpm(&p(0x0a010180, 25)).unwrap();
        assert_eq!(deep.prefix(), p(0x0a010100, 24));
        assert_eq!(deep.value(), Some(&3));

        let shallow = view.find_lpm(&p(0x0a020000, 16)).unwrap();
        assert_eq!(shallow.prefix(), p(0x0a000000, 8));
        assert_eq!(shallow.value(), Some(&1));

        assert!(view.find_lpm(&p(0x0b000000, 8)).is_none());
    }

    #[test]
    fn view_find_lpm_value() {
        let m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a010100, 24, 3)]);
        let hit = m
            .view()
            .find_lpm_value(&p(0x0a010180, 25))
            .map(|(k, v)| (k, *v));
        assert_eq!(hit, Some((p(0x0a010100, 24), 3)));
    }

    #[test]
    fn view_prefix_value_keys_values() {
        let m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);

        let exact = m.view().find_exact(&p(0x0a010000, 16)).unwrap();
        assert_eq!(
            exact.prefix_value().map(|(k, v)| (k, *v)),
            Some((p(0x0a010000, 16), 2))
        );

        let keys: Vec<P> = m.view().keys().collect();
        assert_eq!(keys, vec![p(0x0a000000, 8), p(0x0a010000, 16)]);

        let values: Vec<i32> = m.view().values().copied().collect();
        assert_eq!(values, vec![1, 2]);
    }

    #[test]
    fn view_prefix_reconstruction() {
        let host = p(0x0a010203, 32);
        let m = map_from(&[(0x0a010203, 32, 99)]);
        let v = m.view().find_exact(&host).unwrap();
        assert_eq!(v.prefix(), host);
        assert_eq!(v.value(), Some(&99));
    }

    #[test]
    fn view_into_iter() {
        let m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        // TrieRef: for loop via IntoIterator
        let mut from_for = Vec::new();
        for (k, v) in m.view() {
            from_for.push((k, *v));
        }
        assert_eq!(from_for, owned(m.iter()));
    }

    // -- iter_from on views --

    #[test]
    fn view_iter_from_inclusive() {
        let m = sibling_map();

        // From first entry → everything
        assert_eq!(from_root(&m, p(0x0a000000, 8), true), owned(m.iter()));

        // From a middle entry
        assert_eq!(
            from_root(&m, p(0x0a020000, 16), true),
            vec![
                (p(0x0a020000, 16), 3),
                (p(0x0a030000, 16), 4),
                (p(0x0a040000, 16), 5)
            ]
        );

        // From last entry
        assert_eq!(
            from_root(&m, p(0x0a040000, 16), true),
            vec![(p(0x0a040000, 16), 5)]
        );
    }

    #[test]
    fn view_iter_from_exclusive() {
        let m = sibling_map();

        assert_eq!(
            from_root(&m, p(0x0a020000, 16), false),
            vec![(p(0x0a030000, 16), 4), (p(0x0a040000, 16), 5)]
        );

        // Exclusive from last → empty
        assert!(from_root(&m, p(0x0a040000, 16), false).is_empty());

        // Pagination
        let page = owned(m.view().iter_from(&p(0x0a010000, 16), false).take(2));
        assert_eq!(page, vec![(p(0x0a020000, 16), 3), (p(0x0a030000, 16), 4)]);
    }

    #[test]
    fn view_iter_from_nonexistent() {
        let m = map_from(&[(0x0a000000, 8, 1), (0x0a020000, 16, 2), (0x0a040000, 16, 3)]);

        // Non-existent prefix between entries
        assert_eq!(
            from_root(&m, p(0x0a010000, 16), true),
            vec![(p(0x0a020000, 16), 2), (p(0x0a040000, 16), 3)]
        );

        // Past all entries
        assert!(from_root(&m, p(0x0b000000, 8), true).is_empty());
    }

    #[test]
    fn view_iter_from_empty() {
        let m: PrefixMap<P, i32> = PrefixMap::new();
        assert!(from_root(&m, p(0x0a000000, 8), true).is_empty());
    }

    #[test]
    fn view_iter_from_parent_child() {
        let m = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a000000, 16, 2),
            (0x0a000000, 24, 3),
            (0x0a010000, 16, 4),
        ]);

        // Exclusive from parent → children only
        assert_eq!(
            from_root(&m, p(0x0a000000, 8), false),
            vec![
                (p(0x0a000000, 16), 2),
                (p(0x0a000000, 24), 3),
                (p(0x0a010000, 16), 4)
            ]
        );
    }

    #[test]
    fn view_iter_from_subview() {
        let m = map_from(&[
            (0x0a000000, 8, 1),  // 10.0.0.0/8
            (0x0a010000, 16, 2), // 10.1.0.0/16
            (0x0a010000, 24, 3), // 10.1.0.0/24
            (0x0a020000, 16, 4), // 10.2.0.0/16
            (0x0b000000, 8, 5),  // 11.0.0.0/8  — outside sub-view
        ]);

        // Sub-view at 10.1.0.0/16 excludes 10/8, 10.2/16, 11/8
        let sub = m.view_at(&p(0x0a010000, 16)).unwrap();
        assert_eq!(
            owned(sub.iter()),
            vec![(p(0x0a010000, 16), 2), (p(0x0a010000, 24), 3)]
        );

        // iter_from exclusive skips the root of the sub-view
        assert_eq!(
            owned(sub.iter_from(&p(0x0a010000, 16), false)),
            vec![(p(0x0a010000, 24), 3)]
        );
    }

    #[test]
    fn view_iter_from_outside_subview() {
        let m = map_from(&[
            (0x0a010000, 16, 1),
            (0x0a010000, 24, 2),
            (0x0a020000, 16, 3),
        ]);
        let sub = m.view_at(&p(0x0a010000, 16)).unwrap();

        // Sub-view at 10.1.0.0/16; target before sub-view → full iter
        assert_eq!(
            owned(sub.iter_from(&p(0x09000000, 8), true)),
            owned(sub.iter())
        );

        // Sub-view at 10.1.0.0/16; target after sub-view → empty
        assert_eq!(sub.iter_from(&p(0x0a020000, 16), true).count(), 0);
    }

    #[test]
    fn view_right_at_max_prefix_len() {
        // Calling right() on a view at prefix_len == num_bits (32 for u32) must
        // not panic. step() computes bit_pos = num_bits - prefix_len - 1 which
        // underflows when prefix_len == num_bits.
        let m = map_from(&[(0x01020304, 32, 1)]);
        let leaf = m.view().find(&p(0x01020304, 32)).unwrap();
        assert_eq!(leaf.prefix_len(), 32);
        // Nothing exists below a /32: both steps must yield None rather than panic.
        assert!(leaf.right().is_none());
        assert!(leaf.left().is_none());
    }

    #[test]
    fn view_find_exact_slash32() {
        let m = map_from(&[
            (0x01020300, 32, 1),
            (0x01020301, 32, 2),
            (0x01020302, 32, 3),
            (0x01020303, 32, 4),
        ]);
        for (offset, repr) in (0x01020300..=0x01020303u32).enumerate() {
            let hit = m.view().find_exact(&p(repr, 32)).unwrap();
            assert_eq!(hit.prefix(), p(repr, 32));
            assert_eq!(hit.value(), Some(&(offset as i32 + 1)));
        }
        assert!(m.view().find_exact(&p(0x01020304, 32)).is_none());
    }

    #[test]
    fn view_find_lpm_slash32() {
        let m = map_from(&[(0x01020300, 24, 10), (0x01020304, 32, 42)]);
        let view = m.view();

        let exact = view.find_lpm(&p(0x01020304, 32)).unwrap();
        assert_eq!(exact.prefix(), p(0x01020304, 32));
        assert_eq!(exact.value(), Some(&42));

        // LPM for a /32 without an exact match should find the covering /24
        let covering = view.find_lpm(&p(0x01020305, 32)).unwrap();
        assert_eq!(covering.prefix(), p(0x01020300, 24));
        assert_eq!(covering.value(), Some(&10));
    }

    #[test]
    fn view_navigate_to_slash32() {
        let host = p(0x01020304, 32);
        let m = map_from(&[(0x01020304, 32, 1)]);
        let v = m.view().find(&host).unwrap();
        assert_eq!(v.prefix_len(), 32);
        assert_eq!(v.prefix(), host);
        assert_eq!(v.value(), Some(&1));
    }

    #[test]
    fn view_iter_at_slash32() {
        // A view navigated to a /32 should iterate only that single entry.
        let m = map_from(&[
            (0x01020300, 24, 10),
            (0x01020304, 32, 42),
            (0x01020305, 32, 43),
        ]);
        let leaf = m.view().find(&p(0x01020304, 32)).unwrap();
        assert_eq!(owned(leaf.iter()), vec![(p(0x01020304, 32), 42)]);
    }

    #[test]
    fn view_step_through_all_depths() {
        // Walk from root to a /32 via left()/right(), one bit at a time.
        // Key 0xAAAAAAAA = 1010_1010_... so we alternate right/left.
        let key = 0xAAAAAAAAu32;
        let m = map_from(&[(key, 32, 99)]);
        let leaf = (0..32u32).fold(m.view(), |v, bit| {
            if (key >> (31 - bit)) & 1 == 1 {
                v.right()
                    .unwrap_or_else(|| panic!("right() failed at bit {bit}"))
            } else {
                v.left()
                    .unwrap_or_else(|| panic!("left() failed at bit {bit}"))
            }
        });
        assert_eq!(leaf.prefix_len(), 32);
        assert_eq!(leaf.prefix(), p(key, 32));
        assert_eq!(leaf.value(), Some(&99));
        // One more step should return None
        assert!(leaf.left().is_none());
        assert!(leaf.right().is_none());
    }
}