use crate::{
page::DEPTH,
page_id::{ChildPageIndex, PageId, ROOT_PAGE_ID},
trie::KeyPath,
};
use alloc::fmt;
use bitvec::prelude::*;
/// A position in the binary trie, tracked both as a bit path from the root and as the index
/// of the node inside the page that contains it.
#[derive(Clone)]
#[cfg_attr(
feature = "borsh",
derive(borsh::BorshDeserialize, borsh::BorshSerialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TriePosition {
// Bit path from the root, read most-significant-bit first. Only the first `depth` bits are
// meaningful: `path()` and equality ignore everything after them.
path: [u8; 32],
// Number of meaningful bits in `path`; 0 denotes the root.
depth: u16,
// Index of the node within its page, as produced by the free function `node_index`;
// 0 at the root.
node_index: usize,
}
// Two positions are equal when their bit paths up to `depth` are identical. Bits stored past
// `depth` and the cached node index take no part in the comparison.
impl PartialEq for TriePosition {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.path(), other.path());
        lhs == rhs
    }
}

impl Eq for TriePosition {}
impl TriePosition {
pub fn new() -> Self {
TriePosition {
path: [0; 32],
depth: 0,
node_index: 0,
}
}
pub fn from_path_and_depth(path: KeyPath, depth: u16) -> Self {
assert_ne!(depth, 0, "depth must be non-zero");
assert!(depth <= 256);
let page_path = last_page_path(&path, depth);
TriePosition {
path,
depth,
node_index: node_index(&page_path),
}
}
/// Create a position from a bit slice holding the path from the root.
///
/// An empty slice yields the root position, the same value as [`TriePosition::new`].
///
/// # Panics
///
/// Panics if `slice` is longer than 256 bits.
pub fn from_bitslice(slice: &BitSlice<u8, Msb0>) -> Self {
    assert!(slice.len() <= 256);
    // A zero-length path is the root. `from_path_and_depth` rejects a depth of 0, so the
    // empty case has to be answered here rather than forwarded.
    if slice.is_empty() {
        return Self::new();
    }

    let mut path = [0; 32];
    path.view_bits_mut::<Msb0>()[..slice.len()].copy_from_bitslice(slice);
    Self::from_path_and_depth(path, slice.len() as u16)
}
/// Parse a position from a string of `'0'` and `'1'` characters, one character per bit,
/// starting at the root. An empty string yields the root position.
///
/// # Panics
///
/// Panics if the string is longer than 256 bytes or contains any other character.
#[cfg(test)]
pub fn from_str(s: &str) -> Self {
    if s.len() > 256 {
        panic!("bit string too long");
    }
    let mut bitvec = BitVec::<u8, Msb0>::new();
    for ch in s.chars() {
        match ch {
            '0' => bitvec.push(false),
            '1' => bitvec.push(true),
            _ => panic!("invalid character in bit string"),
        }
    }
    let depth = bitvec.len() as u16;
    bitvec.resize(256, false);
    // Unwrap: resized to 256 bits, i.e. exactly 32 bytes, above.
    let path: [u8; 32] = bitvec.as_raw_slice().try_into().unwrap();
    // The node index must come from the bits of the *last* page only, exactly as
    // `from_path_and_depth` does. Feeding the whole bit string to `node_index` would use
    // the first `DEPTH` bits instead and give a wrong index for longer paths.
    let node_index = if depth == 0 {
        0
    } else {
        node_index(last_page_path(&path, depth))
    };
    Self {
        path,
        depth,
        node_index,
    }
}
/// Whether this position is the root, i.e. its depth is zero.
pub fn is_root(&self) -> bool {
    matches!(self.depth, 0)
}
pub fn depth(&self) -> u16 {
self.depth
}
/// The bit path from the root to this position: the first `depth` bits of the stored key.
pub fn path(&self) -> &BitSlice<u8, Msb0> {
    let len = usize::from(self.depth);
    let bits = self.path.view_bits::<Msb0>();
    &bits[..len]
}
pub fn raw_path(&self) -> [u8; 32] {
self.path
}
/// Move one bit deeper, to the left child when `bit` is `false` and to the right child
/// when it is `true`.
///
/// # Panics
///
/// Panics if the position is already 256 bits deep.
pub fn down(&mut self, bit: bool) {
    assert_ne!(self.depth, 256, "can't descend past 256 bits");
    let bit_pos = self.depth as usize;
    self.node_index = if bit_pos % DEPTH == 0 {
        // The step enters a fresh page, whose first layer holds indices 0 and 1.
        usize::from(bit)
    } else {
        let children = self.child_node_indices();
        if bit {
            children.right()
        } else {
            children.left()
        }
    };
    self.path.view_bits_mut::<Msb0>().set(bit_pos, bit);
    self.depth += 1;
}
/// Move `d` bits back towards the root.
///
/// Reaching depth 0 resets the position to a fresh root. Otherwise the stored path bits
/// are left untouched and only the depth and node index are updated.
///
/// # Panics
///
/// Panics if `d` is greater than the current depth.
pub fn up(&mut self, d: u16) {
    let old_depth = self.depth;
    let new_depth = match old_depth.checked_sub(d) {
        Some(depth) => depth,
        None => panic!("can't move up by {} bits from depth {}", d, old_depth),
    };
    if new_depth == 0 {
        *self = TriePosition::new();
        return;
    }

    self.depth = new_depth;
    // Number of pages touched by a path of the given length (ceiling division by DEPTH).
    let pages_before = (old_depth as usize + DEPTH - 1) / DEPTH;
    let pages_after = (new_depth as usize + DEPTH - 1) / DEPTH;
    if pages_before != pages_after {
        // Landed in a different page: recompute the index from that page's path bits.
        self.node_index = node_index(last_page_path(&self.path, new_depth));
    } else {
        // Still in the same page: step to the parent once per bit.
        self.node_index = (0..d).fold(self.node_index, |index, _| parent_node_index(index));
    }
}
/// Move to the sibling node by flipping the last bit of the path.
///
/// # Panics
///
/// Panics if the position is at the root.
pub fn sibling(&mut self) {
    assert_ne!(self.depth, 0, "can't move to sibling of root node");
    let last = self.depth as usize - 1;
    let bits = self.path.view_bits_mut::<Msb0>();
    let flipped = !bits[last];
    bits.set(last, flipped);
    self.node_index = sibling_index(self.node_index);
}
/// Return the final bit of the path: `false` for a left child, `true` for a right child.
///
/// # Panics
///
/// Panics if the position is at the root.
pub fn peek_last_bit(&self) -> bool {
    assert_ne!(self.depth, 0, "can't peek at root node");
    let bits = self.path.view_bits::<Msb0>();
    // Unwrap: the index is in range whenever `depth` does not exceed 256.
    let last_bit = *bits.get(self.depth as usize - 1).unwrap();
    last_bit
}
/// Get the ID of the page this position lands in, or `None` at the root.
///
/// Starting from `ROOT_PAGE_ID`, descends into one child page for every full `DEPTH`-bit
/// chunk of the path that lies entirely above the position's own page.
pub fn page_id(&self) -> Option<PageId> {
    if self.is_root() {
        return None;
    }

    // A position that sits exactly on a page's last layer still belongs to that page,
    // hence the `- 1` before dividing.
    let pages_to_descend = (self.depth as usize - 1) / DEPTH;
    let mut page_id = ROOT_PAGE_ID;
    for chunk in self.path().chunks_exact(DEPTH).take(pages_to_descend) {
        // Unwrap: relies on a `DEPTH`-bit value always being a valid child page index.
        let child_index = ChildPageIndex::new(chunk.load_be::<u8>()).unwrap();
        // Unwrap: relies on the position never being deeper than the page tree allows.
        page_id = page_id.child_page_id(child_index).unwrap();
    }
    Some(page_id)
}
/// Get the child page index derived from this node's offset within the bottom layer of
/// its page (node index minus 62).
///
/// # Panics
///
/// Panics if the node index is below 62, or if the offset is rejected by
/// `ChildPageIndex::new`.
pub fn child_page_index(&self) -> ChildPageIndex {
    assert!(self.node_index >= 62);
    let slot = bottom_node_index(self.node_index);
    ChildPageIndex::new(slot).unwrap()
}
/// Get the child page index derived from the sibling node's offset within the bottom layer
/// of the page (sibling node index minus 62).
///
/// # Panics
///
/// Panics if the node index is below 62, or if the offset is rejected by
/// `ChildPageIndex::new`.
pub fn sibling_child_page_index(&self) -> ChildPageIndex {
    // Same precondition as `child_page_index`. Siblings share a layer, so checking this
    // node is enough, and it stops a bad index from reaching the `u8` subtraction.
    assert!(
        self.node_index >= 62,
        "sibling child page index is only defined for bottom-layer nodes"
    );
    ChildPageIndex::new(bottom_node_index(sibling_index(self.node_index))).unwrap()
}
pub fn child_node_indices(&self) -> ChildNodeIndices {
let depth = self.depth_in_page();
if depth == 0 || depth > DEPTH - 1 {
panic!("{depth} out of bounds 1..={}", DEPTH - 1);
}
let left = self.node_index * 2 + 2;
ChildNodeIndices(left)
}
/// Get the in-page index of this node's sibling.
pub fn sibling_index(&self) -> usize {
    let current = self.node_index;
    sibling_index(current)
}
pub fn node_index(&self) -> usize {
self.node_index
}
/// Get the number of bits traversed inside the current page.
///
/// Returns 0 at the root and a value in `1..=DEPTH` everywhere else.
pub fn depth_in_page(&self) -> usize {
    match usize::from(self.depth) {
        0 => 0,
        depth => (depth - 1) % DEPTH + 1,
    }
}
/// Whether the node index is 0 or 1, the two slots on the first layer of a page.
///
/// The root also has node index 0, so this returns `true` there as well.
pub fn is_first_layer_in_page(&self) -> bool {
    matches!(self.node_index, 0 | 1)
}
/// Apply `crate::update::shared_bits` to the paths of the two positions.
///
/// NOTE(review): presumably this is the length of their common prefix; confirm against
/// `shared_bits`.
pub fn shared_depth(&self, other: &Self) -> usize {
    let (ours, theirs) = (self.path(), other.path());
    crate::update::shared_bits(ours, theirs)
}
/// Whether `path` starts with this position's bit path, i.e. whether the key lies in the
/// subtrie rooted at this position. Always `true` at the root.
pub fn subtrie_contains(&self, path: &crate::trie::KeyPath) -> bool {
    let key_bits = path.view_bits::<Msb0>();
    key_bits.starts_with(self.path())
}
}
// Extract the bits of `path` that fall inside the last page reached after `depth` bits:
// from the largest multiple of `DEPTH` strictly below `depth`, up to `depth` itself.
//
// Panics when `depth` is 0.
fn last_page_path(path: &[u8; 32], depth: u16) -> &BitSlice<u8, Msb0> {
    let end = match depth {
        0 => panic!(),
        d => d as usize,
    };
    let start = ((end - 1) / DEPTH) * DEPTH;
    &path.view_bits::<Msb0>()[start..end]
}
// Convert a bit path within a page into the index of the node it leads to.
//
// An empty path maps to 0. Only the first `DEPTH` bits are used; anything after is ignored.
fn node_index(page_path: &BitSlice<u8, Msb0>) -> usize {
    match page_path.len().min(DEPTH) {
        0 => 0,
        depth => {
            // Nodes on layer `depth` start at index 2^depth - 2 and are ordered by their
            // path read as a big-endian integer.
            let layer_start = (1usize << depth) - 2;
            layer_start + page_path[..depth].load_be::<usize>()
        }
    }
}
// Convert the index of a node at or past slot 62 of a page into its offset from slot 62.
//
// Accepts exactly the inputs for which the plain `node_index as u8 - 62` computation is
// lossless (62..=255) and returns the same value for them. Anything else panics, instead of
// being truncated by the cast or wrapping around in release builds.
fn bottom_node_index(node_index: usize) -> u8 {
    match node_index {
        62..=255 => (node_index - 62) as u8,
        _ => panic!(
            "node index {} is outside the range 62..=255 accepted for bottom-layer nodes",
            node_index
        ),
    }
}
// Given a node index, get the index of its sibling. Siblings occupy an even/odd pair of
// adjacent slots, so flipping the lowest bit switches between them.
fn sibling_index(node_index: usize) -> usize {
    node_index ^ 1
}
// Convert a node index into the index of its parent within the same page.
//
// Not valid for indices 0 and 1: the subtraction underflows for them.
fn parent_node_index(node_index: usize) -> usize {
    (node_index - 2) >> 1
}
// Debug output is `TriePosition(root)` at the root and otherwise `TriePosition(<path>)`,
// where `<path>` is the bit path rendered through its `Display` implementation.
impl fmt::Debug for TriePosition {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_root() {
            return f.write_str("TriePosition(root)");
        }
        write!(f, "TriePosition({})", self.path())
    }
}
/// The indices of a node's two children, stored as the index of the left child. The right
/// child always sits in the slot immediately after it.
#[derive(Debug, Clone, Copy)]
pub struct ChildNodeIndices(usize);

impl ChildNodeIndices {
    /// Build the pair from the index of the left child.
    pub fn from_left(left: usize) -> Self {
        Self(left)
    }

    /// Whether the left child index is 0, which marks the children as living in the next
    /// page rather than the current one.
    pub fn in_next_page(&self) -> bool {
        matches!(self.0, 0)
    }

    /// Index of the left child.
    pub fn left(&self) -> usize {
        let Self(left) = *self;
        left
    }

    /// Index of the right child, one past the left child.
    pub fn right(&self) -> usize {
        self.left() + 1
    }
}
#[cfg(test)]
mod tests {
    use super::TriePosition;

    // A 255-bit position must be able to take one more step down to the maximum depth of
    // 256, and that step must be recorded in the path.
    #[test]
    fn path_can_go_deeper_255_bit() {
        let mut p = TriePosition::from_str(
            "1010101010101010101010101010101010101010101010101010101010101010\
             1010101010101010101010101010101010101010101010101010101010101010\
             1010101010101010101010101010101010101010101010101010101010101010\
             101010101010101010101010101010101010101010101010101010101010101",
        );
        assert_eq!(p.depth as usize, 255);
        p.down(false);
        // Check the outcome of the step, not only the absence of a panic.
        assert_eq!(p.depth as usize, 256);
        assert!(!p.peek_last_bit());
        assert_eq!(p.path().len(), 256);
    }
}