use crate::{
nibbles::Nibbles,
partial_trie::{Node, PartialTrie, WrappedNode},
utils::TrieSegment,
};
/// Iterator state for walking a trie along a query key.
///
/// Each call to `next` inspects `curr_node`, yields one [`TrieSegment`]
/// describing it, and either moves on to a child node or marks the walk as
/// finished.
#[derive(Clone, Debug, Hash)]
pub struct TriePathIter<N: PartialTrie> {
/// The node that the next call to `next` will inspect.
curr_node: WrappedNode<N>,
/// The part of the query key that has not been consumed yet. Nibbles are
/// popped from its front as the walk passes branch and extension nodes.
curr_key: Nibbles,
/// Set once the walk has finished; every later call to `next` returns `None`.
terminated: bool,
/// When `true`, the extension or leaf node at which the walk stops is still
/// yielded even if the remaining key does not match that node's nibbles.
always_include_final_node_if_possible: bool,
}
/// Walks the trie one node per call, consuming the query key on the way down.
///
/// Every call looks at the current node, produces the [`TrieSegment`] that
/// describes it (where applicable), and then either steps to the next node or
/// flags the iterator as finished.
impl<T: PartialTrie> Iterator for TriePathIter<T> {
    type Item = TrieSegment;

    fn next(&mut self) -> Option<Self::Item> {
        if self.terminated {
            return None;
        }

        match self.curr_node.as_ref() {
            // Empty and hash nodes have nothing to descend into: report and stop.
            Node::Empty => {
                self.terminated = true;
                Some(TrieSegment::Empty)
            }
            Node::Hash(_) => {
                self.terminated = true;
                Some(TrieSegment::Hash)
            }
            Node::Branch { children, .. } => {
                if self.curr_key.is_empty() {
                    // The query key ran out at this branch; nothing is yielded for it.
                    self.terminated = true;
                    None
                } else {
                    // The next nibble of the key selects which child to follow.
                    let child_idx = self.curr_key.pop_next_nibble_front();
                    let next_node = children[child_idx as usize].clone();
                    self.curr_node = next_node;
                    Some(TrieSegment::Branch(child_idx))
                }
            }
            Node::Extension { nibbles, child } => {
                let segment = TrieSegment::Extension(*nibbles);

                if self
                    .curr_key
                    .nibbles_are_identical_up_to_smallest_count(nibbles)
                {
                    // The key may be shorter than the extension, so never pop
                    // more nibbles than the key still holds.
                    pop_nibbles_clamped(&mut self.curr_key, nibbles.count);
                    self.curr_node = child.clone();
                    Some(segment)
                } else {
                    // Mismatch: the walk ends here. The extension is only
                    // reported when the caller asked for the final node.
                    self.terminated = true;
                    self.always_include_final_node_if_possible
                        .then_some(segment)
                }
            }
            Node::Leaf { nibbles, .. } => {
                // A leaf always ends the walk, whether or not it is reported.
                self.terminated = true;

                let exact_hit = self.curr_key == *nibbles;
                if exact_hit || self.always_include_final_node_if_possible {
                    Some(TrieSegment::Leaf(*nibbles))
                } else {
                    None
                }
            }
        }
    }
}
/// Pops up to `n` nibbles from the front of `nibbles` and returns them.
///
/// The requested amount is capped at `nibbles.count`, so the call never asks
/// `pop_nibbles_front` for more nibbles than are currently held.
fn pop_nibbles_clamped(nibbles: &mut Nibbles, n: usize) -> Nibbles {
    let available = nibbles.count;
    let to_pop = usize::min(available, n);

    nibbles.pop_nibbles_front(to_pop)
}
pub fn path_for_query<K, T: PartialTrie>(
trie: &Node<T>,
k: K,
always_include_final_node_if_possible: bool,
) -> TriePathIter<T>
where
K: Into<Nibbles>,
{
TriePathIter {
curr_node: trie.clone().into(),
curr_key: k.into(),
terminated: false,
always_include_final_node_if_possible,
}
}
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::path_for_query;
    use crate::{
        nibbles::Nibbles,
        testing_utils::{common_setup, handmade_trie_1},
        trie_ops::TrieOpResult,
        utils::TrieSegment,
    };

    /// Queries every key of the handmade trie with the "include final node"
    /// flag off and compares the produced path with the expected segments.
    #[test]
    fn query_iter_works_no_last_node() -> TrieOpResult<()> {
        common_setup();
        let (trie, ks) = handmade_trie_1()?;

        // Shorthand for building nibbles from a hex string literal.
        let hex = |s: &str| Nibbles::from_str(s).unwrap();

        // One expected path per key, in the same order as `ks`.
        let expected_paths = vec![
            vec![
                TrieSegment::Branch(1),
                TrieSegment::Branch(2),
                TrieSegment::Leaf(0x34.into()),
            ],
            vec![
                TrieSegment::Branch(1),
                TrieSegment::Branch(3),
                TrieSegment::Extension(0x24.into()),
            ],
            vec![
                TrieSegment::Branch(1),
                TrieSegment::Branch(3),
                TrieSegment::Extension(0x24.into()),
                TrieSegment::Branch(0),
                TrieSegment::Leaf(hex("0x0005")),
            ],
            vec![
                TrieSegment::Branch(2),
                TrieSegment::Extension(hex("0x00")),
                TrieSegment::Branch(0x1),
                TrieSegment::Leaf(Nibbles::default()),
            ],
            vec![
                TrieSegment::Branch(2),
                TrieSegment::Extension(hex("0x00")),
                TrieSegment::Branch(0x2),
                TrieSegment::Leaf(Nibbles::default()),
            ],
        ];

        // Note: `zip` stops at the shorter of the two sequences.
        for (query, expected) in ks.into_iter().zip(expected_paths) {
            let actual: Vec<_> = path_for_query(&trie.node, query, false).collect();
            assert_eq!(actual, expected)
        }

        Ok(())
    }

    /// Queries keys with the "include final node" flag on and checks the paths
    /// that stop at the extension and the paths that run through to the leaf.
    #[test]
    fn query_iter_works_with_last_node() -> TrieOpResult<()> {
        common_setup();
        let (trie, _) = handmade_trie_1()?;

        let extension_expected = vec![
            TrieSegment::Branch(1),
            TrieSegment::Branch(3),
            TrieSegment::Extension(0x24.into()),
        ];

        // These keys are all expected to yield the path ending at the extension.
        for key in [0x13, 0x132, 0x1324] {
            let path: Vec<_> = path_for_query(&trie, key, true).collect();
            assert_eq!(path, extension_expected);
        }

        // Same prefix, followed by the branch and the leaf below the extension.
        let leaf_expected: Vec<_> = extension_expected
            .iter()
            .cloned()
            .chain([
                TrieSegment::Branch(0),
                TrieSegment::Leaf(Nibbles::from_str("0x0005").unwrap()),
            ])
            .collect();

        for key in [0x13240, 0x132400] {
            let path: Vec<_> = path_for_query(&trie, key, true).collect();
            assert_eq!(path, leaf_expected);
        }

        Ok(())
    }
}