1use crate::{
2 page::DEPTH,
3 page_id::{ChildPageIndex, PageId, ROOT_PAGE_ID},
4 trie::KeyPath,
5};
6use alloc::fmt;
7use bitvec::prelude::*;
8
/// The position of a node within a binary trie, addressed by a 256-bit key
/// path truncated to `depth` significant bits.
#[derive(Clone)]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshDeserialize, borsh::BorshSerialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TriePosition {
    // The raw key-path bits. Only the first `depth` bits are significant;
    // bits beyond `depth` may hold stale values from earlier traversal.
    path: [u8; 32],
    // Number of significant bits in `path`. 0 means the root position.
    depth: u16,
    // Cached index of the current node within its page
    // (see the free function `node_index` for the layout).
    node_index: usize,
}
22
23impl PartialEq for TriePosition {
24 fn eq(&self, other: &Self) -> bool {
25 self.path() == other.path()
26 }
27}
28
// Comparison via `path()` is a total equivalence relation, so `Eq` holds.
impl Eq for TriePosition {}
30
impl TriePosition {
    /// Create a new position at the root of the trie.
    pub fn new() -> Self {
        TriePosition {
            path: [0; 32],
            depth: 0,
            node_index: 0,
        }
    }

    /// Create a position from a key path and a depth in `1..=256`.
    ///
    /// # Panics
    ///
    /// Panics if `depth` is zero or greater than 256.
    pub fn from_path_and_depth(path: KeyPath, depth: u16) -> Self {
        assert_ne!(depth, 0, "depth must be non-zero");
        assert!(depth <= 256);
        let page_path = last_page_path(&path, depth);
        TriePosition {
            path,
            depth,
            node_index: node_index(&page_path),
        }
    }

    /// Create a position from a bit slice; the slice length becomes the depth.
    ///
    /// # Panics
    ///
    /// Panics if the slice is longer than 256 bits, or (via
    /// `from_path_and_depth`) if it is empty.
    pub fn from_bitslice(slice: &BitSlice<u8, Msb0>) -> Self {
        assert!(slice.len() <= 256);

        let mut path = [0; 32];
        path.view_bits_mut::<Msb0>()[..slice.len()].copy_from_bitslice(slice);
        Self::from_path_and_depth(path, slice.len() as u16)
    }

    /// Test helper: build a position from a string of `'0'`/`'1'` characters.
    ///
    /// # Panics
    ///
    /// Panics if the string is longer than 256 characters or contains any
    /// character other than `'0'` or `'1'`.
    #[cfg(test)]
    pub fn from_str(s: &str) -> Self {
        let mut bitvec = BitVec::<u8, Msb0>::new();
        if s.len() > 256 {
            panic!("bit string too long");
        }
        for ch in s.chars() {
            match ch {
                '0' => bitvec.push(false),
                '1' => bitvec.push(true),
                _ => panic!("invalid character in bit string"),
            }
        }
        let node_index = node_index(&bitvec);
        let depth = bitvec.len() as u16;
        // Pad out to the full 256 bits so the raw slice is exactly 32 bytes.
        bitvec.resize(256, false);
        let path = bitvec.as_raw_slice().try_into().unwrap();
        Self {
            path,
            depth,
            node_index,
        }
    }

    /// Whether this is the root position (depth zero).
    pub fn is_root(&self) -> bool {
        self.depth == 0
    }

    /// The number of significant bits in the path.
    pub fn depth(&self) -> u16 {
        self.depth
    }

    /// The significant bits of the path, as a bit slice of length `depth`.
    pub fn path(&self) -> &BitSlice<u8, Msb0> {
        &self.path.view_bits::<Msb0>()[..self.depth as usize]
    }

    /// The raw 32-byte path. Bits beyond `depth` may be stale from earlier
    /// traversal and should not be interpreted.
    pub fn raw_path(&self) -> [u8; 32] {
        self.path
    }

    /// Move down one level, appending `bit` to the path.
    ///
    /// # Panics
    ///
    /// Panics when already at the maximum depth of 256.
    pub fn down(&mut self, bit: bool) {
        assert_ne!(self.depth, 256, "can't descend past 256 bits");
        if self.depth as usize % DEPTH == 0 {
            // Crossing into a fresh page: the first layer's nodes have
            // indices 0 (left) and 1 (right).
            self.node_index = bit as usize;
        } else {
            let children = self.child_node_indices();
            self.node_index = if bit {
                children.right()
            } else {
                children.left()
            };
        }
        self.path
            .view_bits_mut::<Msb0>()
            .set(self.depth as usize, bit);
        self.depth += 1;
    }

    /// Move up by `d` levels.
    ///
    /// # Panics
    ///
    /// Panics if `d` is greater than the current depth.
    pub fn up(&mut self, d: u16) {
        let prev_depth = self.depth;
        let Some(new_depth) = self.depth.checked_sub(d) else {
            panic!("can't move up by {} bits from depth {}", d, prev_depth)
        };
        if new_depth == 0 {
            *self = TriePosition::new();
            return;
        }

        self.depth = new_depth;
        // Number of pages traversed to reach each depth (ceiling division by DEPTH).
        let prev_page_depth = (prev_depth as usize + DEPTH - 1) / DEPTH;
        let new_page_depth = (self.depth as usize + DEPTH - 1) / DEPTH;
        if prev_page_depth == new_page_depth {
            // Still within the same page: step the cached node index up one
            // layer at a time.
            for _ in 0..d {
                self.node_index = parent_node_index(self.node_index);
            }
        } else {
            // Crossed a page boundary: recompute the node index from the
            // path's final in-page segment.
            let path = last_page_path(&self.path, self.depth);
            self.node_index = node_index(path);
        }
    }

    /// Move to the sibling node by flipping the last bit of the path.
    ///
    /// # Panics
    ///
    /// Panics at the root, which has no sibling.
    pub fn sibling(&mut self) {
        assert_ne!(self.depth, 0, "can't move to sibling of root node");
        let bits = self.path.view_bits_mut::<Msb0>();
        let i = self.depth as usize - 1;
        bits.set(i, !bits[i]);
        self.node_index = sibling_index(self.node_index);
    }

    /// The last bit of the path.
    ///
    /// # Panics
    ///
    /// Panics at the root, which has no last bit.
    pub fn peek_last_bit(&self) -> bool {
        assert_ne!(self.depth, 0, "can't peek at root node");
        let this_bit_idx = self.depth as usize - 1;
        let bit = *self.path.view_bits::<Msb0>().get(this_bit_idx).unwrap();
        bit
    }

    /// The ID of the page containing the current node, or `None` at the root.
    pub fn page_id(&self) -> Option<PageId> {
        if self.is_root() {
            return None;
        }

        let mut page_id = ROOT_PAGE_ID;
        for (i, chunk) in self.path().chunks_exact(DEPTH).enumerate() {
            // If the position ends exactly at this chunk boundary, the node
            // lies within the page reached so far; don't descend further.
            if (i + 1) * DEPTH == self.depth as usize {
                return Some(page_id);
            }

            let child_index = ChildPageIndex::new(chunk.load_be::<u8>()).unwrap();

            page_id = page_id.child_page_id(child_index).unwrap();
        }

        Some(page_id)
    }

    /// The index of the child page the traversal continues into.
    ///
    /// # Panics
    ///
    /// Panics unless the node is in the bottom layer of its page
    /// (node index 62 or greater).
    pub fn child_page_index(&self) -> ChildPageIndex {
        assert!(self.node_index >= 62);
        ChildPageIndex::new(bottom_node_index(self.node_index)).unwrap()
    }

    /// The child page index corresponding to the sibling node.
    /// Like `child_page_index`, only meaningful in the bottom layer of a page.
    pub fn sibling_child_page_index(&self) -> ChildPageIndex {
        ChildPageIndex::new(bottom_node_index(sibling_index(self.node_index))).unwrap()
    }

    /// The in-page indices of this node's two children.
    ///
    /// # Panics
    ///
    /// Panics at the root or in a page's bottom layer, where the children
    /// reside in a child page instead.
    pub fn child_node_indices(&self) -> ChildNodeIndices {
        let depth = self.depth_in_page();
        if depth == 0 || depth > DEPTH - 1 {
            panic!("{depth} out of bounds 1..={}", DEPTH - 1);
        }
        // Children of node `n` occupy indices `2n + 2` and `2n + 3`.
        let left = self.node_index * 2 + 2;
        ChildNodeIndices(left)
    }

    /// The in-page index of the sibling node.
    pub fn sibling_index(&self) -> usize {
        sibling_index(self.node_index)
    }

    /// The in-page index of the current node.
    pub fn node_index(&self) -> usize {
        self.node_index
    }

    /// The depth within the current page: 0 at the root, otherwise in
    /// `1..=DEPTH`.
    pub fn depth_in_page(&self) -> usize {
        if self.depth == 0 {
            0
        } else {
            self.depth as usize - ((self.depth as usize - 1) / DEPTH) * DEPTH
        }
    }

    /// Whether the node sits in the first layer of its page
    /// (node index 0 or 1).
    pub fn is_first_layer_in_page(&self) -> bool {
        self.node_index & !1 == 0
    }

    /// The number of leading path bits shared between `self` and `other`
    /// (delegates to `crate::update::shared_bits`).
    pub fn shared_depth(&self, other: &Self) -> usize {
        crate::update::shared_bits(self.path(), other.path())
    }

    /// Whether `path` lies within the subtrie rooted at this position, i.e.
    /// whether it starts with this position's significant path bits.
    pub fn subtrie_contains(&self, path: &crate::trie::KeyPath) -> bool {
        path.view_bits::<Msb0>()
            .starts_with(&self.path.view_bits::<Msb0>()[..self.depth as usize])
    }
}
274
275fn last_page_path(path: &[u8; 32], depth: u16) -> &BitSlice<u8, Msb0> {
277 if depth == 0 {
278 panic!();
279 }
280 let prev_page_end = ((depth as usize - 1) / DEPTH) * DEPTH;
281 &path.view_bits::<Msb0>()[prev_page_end..depth as usize]
282}
283
284fn node_index(page_path: &BitSlice<u8, Msb0>) -> usize {
289 let depth = core::cmp::min(DEPTH, page_path.len());
290
291 if depth == 0 {
292 0
293 } else {
294 (1 << depth) - 2 + page_path[..depth].load_be::<usize>()
296 }
297}
298
// Map a bottom-layer node index (62 and up) to its offset within that layer.
fn bottom_node_index(node_index: usize) -> u8 {
    let raw = node_index as u8;
    raw - 62
}
302
// Compute the in-page index of a node's sibling.
//
// Siblings form (even, even + 1) pairs, so toggling the low bit swaps them.
fn sibling_index(node_index: usize) -> usize {
    node_index ^ 1
}
311
// Compute the in-page index of a node's parent.
//
// Inverse of the child mapping `left = parent * 2 + 2`, `right = left + 1`.
fn parent_node_index(node_index: usize) -> usize {
    let child_offset = node_index - 2;
    child_offset / 2
}
319
320impl fmt::Debug for TriePosition {
321 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322 if self.depth == 0 {
323 write!(f, "TriePosition(root)")
324 } else {
325 write!(f, "TriePosition({})", self.path(),)
326 }
327 }
328}
329
/// The in-page indices of a node's two children, stored as the left child's
/// index; the right child is always at `left + 1`.
#[derive(Debug, Clone, Copy)]
pub struct ChildNodeIndices(usize);
333
334impl ChildNodeIndices {
335 pub fn from_left(left: usize) -> Self {
337 ChildNodeIndices(left)
338 }
339
340 pub fn in_next_page(&self) -> bool {
342 self.0 == 0
343 }
344
345 pub fn left(&self) -> usize {
347 self.0
348 }
349 pub fn right(&self) -> usize {
351 self.0 + 1
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::TriePosition;

    // Regression test: descending from depth 255 must succeed; only the
    // 256-bit cap itself is disallowed by `down`.
    #[test]
    fn path_can_go_deeper_255_bit() {
        // 255 alternating bits (three 64-bit runs plus one 63-bit run).
        let mut p = TriePosition::from_str(
            "1010101010101010101010101010101010101010101010101010101010101010\
            1010101010101010101010101010101010101010101010101010101010101010\
            1010101010101010101010101010101010101010101010101010101010101010\
            101010101010101010101010101010101010101010101010101010101010101",
        );
        assert_eq!(p.depth as usize, 255);
        p.down(false);
    }
}