Skip to main content

prefix_trie/trieview/
trie_ref_mut.rs

1//! Concrete [`TrieRefMut`] cursor implementing [`TrieView`].
2
3use std::marker::PhantomData;
4
5use num_traits::Zero;
6
7use crate::{
8    Prefix,
9    {
10        allocator::{Loc, RawPtr},
11        node::{child_cover_mask, data_cover_mask, extend_repr},
12        table::{DataIdx, Table, K},
13        AsView, PrefixMap,
14    },
15};
16
17use super::{TrieView, ViewIter};
18
19/// A mutable cursor implementing [`TrieView`].
20///
21/// # Invariant
22///
23/// `depth <= prefix_len < depth + K`, where `K` is the stride of a `MultiBitNode`.
24/// `depth` is always a multiple of `K`.
25/// `key` contains the accumulated bits (only the top `prefix_len` bits are significant).
26pub struct TrieRefMut<'a, P: Prefix, T> {
27    pub(super) table: &'a Table<T>,
28    /// a raw pointer into the data of the table used to access mutable references to data
29    pub(super) raw: RawPtr<T>,
30    /// Location of the `MultiBitNode` that contains this view's root position.
31    pub(super) node_loc: Loc,
32    /// Depth of `node_loc`: always a multiple of `K`.
33    pub(super) depth: u32,
34    /// Accumulated key bits (only the top `prefix_len` bits are significant).
35    pub(super) key: P::R,
36    /// Binary-tree depth of this view's root position.
37    pub(super) prefix_len: u32,
38    pub(super) _marker: PhantomData<P>,
39}
40
41impl<'a, P: Prefix, T> TrieRefMut<'a, P, T> {
42    /// Create a view at the root of the given table.
43    pub(crate) fn new_root(table: &'a Table<T>, raw: RawPtr<T>) -> Self {
44        Self {
45            table,
46            raw,
47            node_loc: Loc::root(),
48            depth: 0,
49            key: P::R::zero(),
50            prefix_len: 0,
51            _marker: PhantomData,
52        }
53    }
54}
55
56impl<'a, P: Prefix, T> TrieView<'a> for TrieRefMut<'a, P, T> {
57    type P = P;
58    type T = &'a mut T;
59
60    #[inline]
61    fn depth(&self) -> u32 {
62        self.depth
63    }
64
65    #[inline]
66    fn key(&self) -> P::R {
67        self.key
68    }
69
70    #[inline]
71    fn prefix_len(&self) -> u32 {
72        self.prefix_len
73    }
74
75    #[inline]
76    fn data_bitmap(&self) -> u32 {
77        self.table.node(self.node_loc).data_bitmap()
78            & data_cover_mask(self.depth, self.key, self.prefix_len)
79    }
80
81    #[inline]
82    fn child_bitmap(&self) -> u32 {
83        self.table.node(self.node_loc).child_bitmap()
84            & child_cover_mask(self.depth, self.key, self.prefix_len)
85    }
86
87    #[inline]
88    unsafe fn get_data(&mut self, data_bit: u32) -> &'a mut T {
89        let idx = DataIdx {
90            node: self.node_loc,
91            bit: data_bit,
92            depth: self.depth,
93        };
94        // Note: `resolve` re-reads the node's current AllocIdx and bitmap, which is strictly
95        // unnecessary here: TrieView iterators never mutate node structure, so `idx.node` is
96        // always a valid, stable location for the lifetime of this view. We use `resolve` for
97        // uniformity with the rest of the codebase; the compiler should inline/optimize it away.
98        //
99        // SAFETY: caller guarantees data_bit is set in data_bitmap().
100        let r = unsafe { idx.resolve(self.table) }.expect("get_data: data_bit not set in bitmap");
101        // SAFETY: TrieRefMut was created from `&'a mut PrefixMap`, and `raw` was obtained from
102        // that same map's table. The tree is acyclic so each data slot is visited at most once
103        // per iteration, guaranteeing no two live `&mut T` references to the same slot.
104        unsafe { r.unsafe_get_mut(&mut self.raw) }
105    }
106
107    #[inline]
108    unsafe fn get_child(&mut self, child_bit: u32) -> Self {
109        // SAFETY: `self.node_loc` is valid (no structural mutations during view lifetime).
110        // The returned child shares `self.raw`; disjoint access is guaranteed by the TrieView
111        // contract (different child bits → disjoint subtrees in the acyclic tree).
112        let child_loc = unsafe { self.table.child(self.node_loc, child_bit) }
113            .expect("get_child: child_bit not set in bitmap");
114        let new_key = extend_repr(self.key, self.depth, child_bit);
115        Self {
116            table: self.table,
117            raw: self.raw,
118            node_loc: child_loc,
119            depth: self.depth + K,
120            key: new_key,
121            prefix_len: self.depth + K,
122            _marker: PhantomData,
123        }
124    }
125
126    #[inline]
127    unsafe fn reposition(&mut self, key: P::R, prefix_len: u32) {
128        let _old_prefix = self.prefix();
129        self.key = key;
130        self.prefix_len = prefix_len;
131        // ensure that we always go deeper inside.
132        debug_assert!(_old_prefix.contains(&self.prefix()));
133    }
134}
135
136impl<'a, P: Prefix, T> IntoIterator for TrieRefMut<'a, P, T> {
137    type Item = (P, &'a mut T);
138    type IntoIter = ViewIter<'a, TrieRefMut<'a, P, T>>;
139
140    fn into_iter(self) -> Self::IntoIter {
141        self.iter()
142    }
143}
144
145impl<'a, P: Prefix, T> AsView<'a> for &'a mut PrefixMap<P, T> {
146    type P = P;
147    type View = TrieRefMut<'a, P, T>;
148
149    fn view(self) -> TrieRefMut<'a, P, T> {
150        let raw = self.table_mut().raw_cells();
151        TrieRefMut::new_root(self.table(), raw)
152    }
153}
154
155impl<'a, P: Prefix, T> AsView<'a> for TrieRefMut<'a, P, T> {
156    type P = P;
157    type View = TrieRefMut<'a, P, T>;
158
159    fn view(self) -> TrieRefMut<'a, P, T> {
160        self
161    }
162}
163
#[cfg(test)]
mod tests {
    use crate::{
        trieview::{AsView, TrieView},
        Prefix, PrefixMap,
    };

    type P = (u32, u8);

    fn p(repr: u32, len: u8) -> P {
        P::from_repr_len(repr, len)
    }

    fn map_from(entries: &[(u32, u8, i32)]) -> PrefixMap<P, i32> {
        let mut m = PrefixMap::new();
        for &(repr, len, val) in entries {
            m.insert(p(repr, len), val);
        }
        m
    }

    #[test]
    fn view_mut_iter_all() {
        let mut m = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a020000, 16, 3),
            (0x0a010000, 24, 4),
        ]);
        let expected: Vec<(P, i32)> = m.iter_mut().map(|(p, v)| (p, *v)).collect();
        let from_view: Vec<(P, i32)> = (&mut m).view().iter().map(|(p, v)| (p, *v)).collect();
        assert_eq!(from_view, expected);
    }

    #[test]
    fn view_mut_at_subtrie() {
        let mut m = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a020000, 16, 3),
            (0x0a010000, 24, 4),
        ]);
        // Go through `&mut m` so the sub-view is a `TrieRefMut`. Mutate through its iterator
        // to prove the subtrie hands out `&mut` references.
        let got: Vec<_> = (&mut m)
            .view()
            .find(&p(0x0a010000, 16))
            .map(|v| {
                v.iter()
                    .map(|(p, x)| {
                        *x += 100;
                        (p, *x)
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        assert_eq!(got, vec![(p(0x0a010000, 16), 102), (p(0x0a010000, 24), 104)]);
        // Entries inside the subtrie were updated in the map...
        assert_eq!(m.get(&p(0x0a010000, 16)), Some(&102));
        assert_eq!(m.get(&p(0x0a010000, 24)), Some(&104));
        // ...while entries outside it are untouched.
        assert_eq!(m.get(&p(0x0a000000, 8)), Some(&1));
        assert_eq!(m.get(&p(0x0a020000, 16)), Some(&3));
    }

    #[test]
    fn view_mut_value() {
        let mut m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let v = (&mut m).view().find(&p(0x0a010000, 16)).unwrap();
        assert_eq!(v.value(), Some(&mut 2));
        let v2 = (&mut m).view().find(&p(0x0a000000, 8)).unwrap();
        assert_eq!(v2.value(), Some(&mut 1));
    }

    #[test]
    fn view_mut_find_exact() {
        let mut m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 24, 4)]);
        assert!((&mut m).view().find_exact(&p(0x0a010000, 16)).is_none());
        assert!((&mut m).view().find_exact(&p(0x0a000000, 8)).is_some());
    }

    #[test]
    fn view_mut_find_exact_value() {
        let mut m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 24, 4)]);
        assert_eq!((&mut m).view().find_exact_value(&p(0x0a010000, 16)), None);
        let got = (&mut m)
            .view()
            .find_exact_value(&p(0x0a010000, 24))
            .map(|(p, v)| {
                *v += 10;
                (p, *v)
            });
        assert_eq!(got, Some((p(0x0a010000, 24), 14)));
        assert_eq!(m.get(&p(0x0a010000, 24)), Some(&14));
    }

    #[test]
    fn view_mut_find_lpm_value() {
        let mut m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a010100, 24, 3)]);
        let got = (&mut m)
            .view()
            .find_lpm_value(&p(0x0a010180, 25))
            .map(|(p, v)| {
                *v += 10;
                (p, *v)
            });
        assert_eq!(got, Some((p(0x0a010100, 24), 13)));
        assert_eq!(m.get(&p(0x0a010100, 24)), Some(&13));
    }

    #[test]
    fn view_mut_prefix_value_keys_values() {
        let mut m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let got = (&mut m)
            .view()
            .find_exact(&p(0x0a010000, 16))
            .unwrap()
            .prefix_value()
            .map(|(p, v)| {
                *v += 10;
                (p, *v)
            });
        assert_eq!(got, Some((p(0x0a010000, 16), 12)));
        assert_eq!(m.get(&p(0x0a010000, 16)), Some(&12));
        assert_eq!(
            m.view().keys().collect::<Vec<_>>(),
            vec![p(0x0a000000, 8), p(0x0a010000, 16)]
        );
        assert_eq!(m.view().values().copied().collect::<Vec<_>>(), vec![1, 12]);
    }

    #[test]
    fn view_mut_prefix_reconstruction() {
        let mut m = map_from(&[(0x0a010203, 32, 99)]);
        let v = (&mut m).view().find_exact(&p(0x0a010203, 32)).unwrap();
        assert_eq!(v.prefix(), p(0x0a010203, 32));
        assert_eq!(v.value(), Some(&mut 99));
    }

    #[test]
    fn view_mut_into_iter() {
        let mut m = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        // `TrieRefMut` implements `IntoIterator`, so it can drive a `for` loop directly.
        let mut from_for: Vec<(P, i32)> = Vec::new();
        for (prefix, value) in (&mut m).view() {
            from_for.push((prefix, *value));
        }
        let expected: Vec<(P, i32)> = m.iter().map(|(p, v)| (p, *v)).collect();
        assert_eq!(from_for, expected);
    }
}