nomt_core/trie_pos.rs

use crate::{
    page::DEPTH,
    page_id::{ChildPageIndex, PageId, ROOT_PAGE_ID},
    trie::KeyPath,
};
use alloc::fmt;
use bitvec::prelude::*;

/// Encapsulates logic for moving around in paged storage for a binary trie.
#[derive(Clone)]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshDeserialize, borsh::BorshSerialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TriePosition {
    // The key path traversed so far. Bits beyond `depth` are irrelevant.
    path: [u8; 32],
    // The number of bits of `path` that have been traversed.
    depth: u16,
    // The index of the current node within its page.
    node_index: usize,
}

// Positions are equal when their traversed paths are equal; `node_index` is
// derived from the path, so it does not need to be compared.
impl PartialEq for TriePosition {
    fn eq(&self, other: &Self) -> bool {
        self.path() == other.path()
    }
}

impl Eq for TriePosition {}

impl TriePosition {
    /// Create a new `TriePosition` at the root.
    pub fn new() -> Self {
        TriePosition {
            path: [0; 32],
            depth: 0,
            node_index: 0,
        }
    }

    /// Create a new `TriePosition` based on the first `depth` bits of `path`.
    ///
    /// Panics if `depth` is zero or greater than 256.
    pub fn from_path_and_depth(path: KeyPath, depth: u16) -> Self {
        assert_ne!(depth, 0, "depth must be non-zero");
        assert!(depth <= 256);
        let page_path = last_page_path(&path, depth);
        TriePosition {
            path,
            depth,
            node_index: node_index(&page_path),
        }
    }

    /// Create a new `TriePosition` based on a bitslice.
    ///
    /// Panics if the slice is empty or longer than 256 bits.
    pub fn from_bitslice(slice: &BitSlice<u8, Msb0>) -> Self {
        assert!(slice.len() <= 256);

        let mut path = [0; 32];
        path.view_bits_mut::<Msb0>()[..slice.len()].copy_from_bitslice(slice);
        Self::from_path_and_depth(path, slice.len() as u16)
    }

    /// Parse a `TriePosition` from a bit string.
    #[cfg(test)]
    pub fn from_str(s: &str) -> Self {
        let mut bitvec = BitVec::<u8, Msb0>::new();
        if s.len() > 256 {
            panic!("bit string too long");
        }
        for ch in s.chars() {
            match ch {
                '0' => bitvec.push(false),
                '1' => bitvec.push(true),
                _ => panic!("invalid character in bit string"),
            }
        }
        let depth = bitvec.len() as u16;
        bitvec.resize(256, false);
        // Unwrap: resized to 256 bits, or 32 bytes, above.
        let path: [u8; 32] = bitvec.as_raw_slice().try_into().unwrap();
        // The node index is relative to the deepest page, so compute it from the
        // last page's portion of the path, matching `from_path_and_depth`.
        let node_index = if depth == 0 {
            0
        } else {
            node_index(last_page_path(&path, depth))
        };
        Self {
            path,
            depth,
            node_index,
        }
    }

    /// Whether the position is at the root.
    pub fn is_root(&self) -> bool {
        self.depth == 0
    }

    /// Get the current `depth` of the position.
    pub fn depth(&self) -> u16 {
        self.depth
    }

    /// Get the path to the current position.
    pub fn path(&self) -> &BitSlice<u8, Msb0> {
        &self.path.view_bits::<Msb0>()[..self.depth as usize]
    }

    /// Get the raw key at the current position.
    ///
    /// Note that bits beyond `depth` may still be set if `up` has been called.
    pub fn raw_path(&self) -> [u8; 32] {
        self.path
    }

    /// Move the position down by 1, towards either the left or right child.
    ///
    /// Panics if the position is already at the maximum depth of 256.
    pub fn down(&mut self, bit: bool) {
        assert_ne!(self.depth, 256, "can't descend past 256 bits");
        if self.depth as usize % DEPTH == 0 {
            // Crossing into a new page: the two top-layer nodes sit at indices 0 and 1.
            self.node_index = bit as usize;
        } else {
            let children = self.child_node_indices();
            self.node_index = if bit {
                children.right()
            } else {
                children.left()
            };
        }
        self.path
            .view_bits_mut::<Msb0>()
            .set(self.depth as usize, bit);
        self.depth += 1;
    }
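
    // A worked sketch of the intra-page layout relied on above, assuming the
    // 6-bit page depth used throughout this module: entering a page places the
    // position at node index 0 or 1 (the two top-layer nodes), and each further
    // descent maps node `i` to children `2 * i + 2` (left) and `2 * i + 3`
    // (right), so a page holds 2 + 4 + 8 + 16 + 32 + 64 = 126 nodes at indices
    // 0..=125.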

    /// Move the position up by `d` bits.
    ///
    /// Panics if `d` is greater than the current depth.
    pub fn up(&mut self, d: u16) {
        let prev_depth = self.depth;
        let Some(new_depth) = self.depth.checked_sub(d) else {
            panic!("can't move up by {} bits from depth {}", d, prev_depth)
        };
        if new_depth == 0 {
            *self = TriePosition::new();
            return;
        }

        self.depth = new_depth;
        let prev_page_depth = (prev_depth as usize + DEPTH - 1) / DEPTH;
        let new_page_depth = (self.depth as usize + DEPTH - 1) / DEPTH;
        if prev_page_depth == new_page_depth {
            // Still within the same page: step up the in-page binary tree.
            for _ in 0..d {
                self.node_index = parent_node_index(self.node_index);
            }
        } else {
            // Crossed one or more page boundaries: recompute the index from the
            // portion of the path that lies within the new deepest page.
            let path = last_page_path(&self.path, self.depth);
            self.node_index = node_index(path);
        }
    }
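
    // For example (a sketch with the 6-bit page depth): moving up from depth 8
    // to depth 6 leaves the page covering depths 7..=12, so the node index is
    // recomputed from the first six path bits and lands on a bottom-layer slot
    // (index 62..=125) of the root page.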

    /// Move the position to the sibling node.
    ///
    /// Panics if at the root.
    pub fn sibling(&mut self) {
        assert_ne!(self.depth, 0, "can't move to sibling of root node");
        let bits = self.path.view_bits_mut::<Msb0>();
        let i = self.depth as usize - 1;
        bits.set(i, !bits[i]);
        self.node_index = sibling_index(self.node_index);
    }

    /// Peek at the last bit of the path.
    ///
    /// Panics if at the root.
    pub fn peek_last_bit(&self) -> bool {
        assert_ne!(self.depth, 0, "can't peek at root node");
        let this_bit_idx = self.depth as usize - 1;
        // Unwrap: depth != 0, checked above.
        *self.path.view_bits::<Msb0>().get(this_bit_idx).unwrap()
    }

    /// Get the page ID this position lands in. Returns `None` at the root.
    pub fn page_id(&self) -> Option<PageId> {
        if self.is_root() {
            return None;
        }

        let mut page_id = ROOT_PAGE_ID;
        for (i, chunk) in self.path().chunks_exact(DEPTH).enumerate() {
            if (i + 1) * DEPTH == self.depth as usize {
                // The final `DEPTH` bits index a node within the page reached so
                // far; they do not select another child page.
                return Some(page_id);
            }

            // UNWRAP: 6 bits never overflows child page index
            let child_index = ChildPageIndex::new(chunk.load_be::<u8>()).unwrap();

            // UNWRAP: trie position never overflows page tree.
            page_id = page_id.child_page_id(child_index).unwrap();
        }

        Some(page_id)
    }
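
    // A sketch of the lookup above, assuming the 6-bit page depth: a position at
    // depth 14 yields two full chunks (bits 0..6 and 6..12), each selecting a
    // child page below the root, and the remaining two bits only position the
    // node inside that page. A position at depth 12, by contrast, returns early
    // on its second chunk, since those six bits index a node inside the page
    // selected by the first chunk.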

    /// Get the child page index, relative to the current page,
    /// where the children of the current node are stored.
    ///
    /// Panics if the position is not in the last layer of the page.
    pub fn child_page_index(&self) -> ChildPageIndex {
        // 62 is the index of the first bottom-layer node in a page.
        assert!(self.node_index >= 62);
        ChildPageIndex::new(bottom_node_index(self.node_index)).unwrap()
    }

    /// Get the child page index, relative to the current page,
    /// where the children of the sibling node are stored.
    ///
    /// Panics if the position is not in the last layer of the page.
    pub fn sibling_child_page_index(&self) -> ChildPageIndex {
        ChildPageIndex::new(bottom_node_index(sibling_index(self.node_index))).unwrap()
    }

    /// Get the indices within the current page at which the children of the
    /// current node are stored.
    ///
    /// Panics if the depth within the current page is not in the range `1..=DEPTH - 1`.
    pub fn child_node_indices(&self) -> ChildNodeIndices {
        let depth = self.depth_in_page();
        if depth == 0 || depth > DEPTH - 1 {
            panic!("{depth} out of bounds 1..={}", DEPTH - 1);
        }
        let left = self.node_index * 2 + 2;
        ChildNodeIndices(left)
    }

    /// Get the index of the sibling node within a page.
    pub fn sibling_index(&self) -> usize {
        sibling_index(self.node_index)
    }

    /// Get the index of the current node within a page.
    pub fn node_index(&self) -> usize {
        self.node_index
    }

    /// Get the number of bits traversed in the current page.
    ///
    /// Note that at least 1 bit is traversed within every page, so the return
    /// value is between 1 and `DEPTH`, except at the root, which has traversed
    /// 0 bits.
    pub fn depth_in_page(&self) -> usize {
        if self.depth == 0 {
            0
        } else {
            self.depth as usize - ((self.depth as usize - 1) / DEPTH) * DEPTH
        }
    }
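
    // Worked examples for the formula above, assuming the 6-bit page depth:
    // depth 6 gives 6 - ((6 - 1) / 6) * 6 = 6, the last layer of the root page,
    // while depth 7 gives 7 - ((7 - 1) / 6) * 6 = 1, the first layer of a child page.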

    /// Fast path for checking whether this is in the first layer of the page.
    pub fn is_first_layer_in_page(&self) -> bool {
        (self.node_index & !1) == 0
    }

    /// Get the number of shared bits between this position and `other`.
    ///
    /// This is essentially the depth of a hypothetical internal node which both positions would
    /// descend from.
    pub fn shared_depth(&self, other: &Self) -> usize {
        crate::update::shared_bits(self.path(), other.path())
    }

    /// Whether the sub-trie indicated by this position would contain
    /// a given key-path.
    pub fn subtrie_contains(&self, path: &crate::trie::KeyPath) -> bool {
        path.view_bits::<Msb0>().starts_with(self.path())
    }
}

// Extract the portion of the key path within the last (deepest) page.
// Panics if `depth` is zero.
fn last_page_path(path: &[u8; 32], depth: u16) -> &BitSlice<u8, Msb0> {
    if depth == 0 {
        panic!();
    }
    let prev_page_end = ((depth as usize - 1) / DEPTH) * DEPTH;
    &path.view_bits::<Msb0>()[prev_page_end..depth as usize]
}

// Transform a bit-path to an index in a page.
//
// The expected length of the page path is between 1 and `DEPTH`, inclusive. A length of 0 returns
// 0 and all bits beyond `DEPTH` are ignored.
fn node_index(page_path: &BitSlice<u8, Msb0>) -> usize {
    let depth = core::cmp::min(DEPTH, page_path.len());

    if depth == 0 {
        0
    } else {
        // each node is stored at (2^depth - 2) + as_uint(path)
        (1 << depth) - 2 + page_path[..depth].load_be::<usize>()
    }
}
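
// Worked examples for the formula above, assuming `DEPTH` is 6: the page path
// `101` (depth 3) is stored at (1 << 3) - 2 + 0b101 = 6 + 5 = 11, and a full
// six-bit path of zeroes lands on the first bottom-layer slot at
// (1 << 6) - 2 + 0 = 62.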

// Map a bottom-layer node index to its 6-bit value within the page; 62 is the
// index of the first bottom-layer node.
fn bottom_node_index(node_index: usize) -> u8 {
    node_index as u8 - 62
}

/// Given a node index, get the index of the sibling.
fn sibling_index(node_index: usize) -> usize {
    if node_index % 2 == 0 {
        node_index + 1
    } else {
        node_index - 1
    }
}

// Transform a node index to the index where the parent node is stored.
// It does not check for overflow of the maximum valid node index, and it
// underflows (panicking when overflow checks are enabled) if the provided
// `node_index` is one of the first two nodes in a page, i.e. node index 0 or 1.
fn parent_node_index(node_index: usize) -> usize {
    (node_index - 2) / 2
}
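
// For example, nodes 4 and 5 are both children of node 1: (4 - 2) / 2 and
// (5 - 2) / 2 are each 1, mirroring the `2 * i + 2` / `2 * i + 3` child layout
// used by `child_node_indices`.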

impl fmt::Debug for TriePosition {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.depth == 0 {
            write!(f, "TriePosition(root)")
        } else {
            write!(f, "TriePosition({})", self.path())
        }
    }
}

/// A helper type representing two child node indices within a page.
#[derive(Debug, Clone, Copy)]
pub struct ChildNodeIndices(usize);

impl ChildNodeIndices {
    /// Create from a left child index.
    pub fn from_left(left: usize) -> Self {
        ChildNodeIndices(left)
    }

    /// Whether these indices are at the top of a page, i.e. the children live
    /// in a child page rather than further down the current one.
    pub fn in_next_page(&self) -> bool {
        self.0 == 0
    }

    /// Get the index of the left child.
    pub fn left(&self) -> usize {
        self.0
    }
    /// Get the index of the right child.
    pub fn right(&self) -> usize {
        self.0 + 1
    }
}

#[cfg(test)]
mod tests {
    use super::TriePosition;

    #[test]
    fn path_can_go_deeper_than_255_bits() {
        let mut p = TriePosition::from_str(
            "1010101010101010101010101010101010101010101010101010101010101010\
            1010101010101010101010101010101010101010101010101010101010101010\
            1010101010101010101010101010101010101010101010101010101010101010\
            101010101010101010101010101010101010101010101010101010101010101",
        );
        assert_eq!(p.depth(), 255);
        p.down(false);
    }
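
    #[test]
    fn down_and_up_roundtrip_across_page_boundary() {
        // A minimal sketch of typical movement (not part of the original suite),
        // assuming the 6-bit page depth used throughout this module: start from
        // an 8-bit position, move up across the page boundary, then walk back
        // down to the same position.
        let mut p = TriePosition::from_str("10110010");
        assert_eq!(p.depth(), 8);
        assert_eq!(p.depth_in_page(), 2);

        p.up(2);
        assert_eq!(p, TriePosition::from_str("101100"));
        assert_eq!(p.depth_in_page(), 6);

        p.down(true);
        p.down(false);
        assert_eq!(p, TriePosition::from_str("10110010"));
        assert!(!p.peek_last_bit());
    }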
}