use core::iter;
use crate::trie;
use alloc::vec::{IntoIter, Vec};
pub use crate::trie::Nibble;
mod tests;
/// Configuration passed to [`start_branch_search`].
///
/// `K` and `P` are the types of the `key_before` and `prefix` fields.
/// [`start_branch_search`] requires both to be iterators yielding [`Nibble`]s.
#[derive(Debug, Clone)]
pub struct Config<K, P> {
/// Key, as nibbles, from which the search starts. Storage keys injected during the first
/// round are expected (checked with a `debug_assert!`) to be greater than or equal to it.
pub key_before: K,
/// If `true`, a result equal to `key_before` is acceptable. If `false`, a branch that is
/// equal to `key_before` is rejected during the later rounds of the search.
pub or_equal: bool,
/// Nibbles that a result must start with. A storage key that doesn't start with this prefix
/// ends the search without any result.
pub prefix: P,
/// If `true`, the first storage key injected into the search is returned as the result
/// (provided it starts with `prefix`), and no further round is performed.
pub no_branch_search: bool,
}
pub fn start_branch_search(
config: Config<impl Iterator<Item = Nibble>, impl Iterator<Item = Nibble>>,
) -> NextKey {
NextKey {
prefix: config.prefix.collect(),
key_before: config.key_before.collect(),
or_equal: config.or_equal,
inner: NextKeyInner::FirstFound {
no_branch_search: config.no_branch_search,
},
}
}
/// Progress of a branch search, as returned by [`NextKey::inject`].
pub enum BranchSearch {
/// The search isn't finished. The contained [`NextKey`] describes the next storage lookup to
/// perform, whose outcome must be passed to [`NextKey::inject`].
NextKey(NextKey),
/// The search is finished.
Found {
/// Key, as nibbles, of the branch trie node that was found, or `None` if the search ended
/// without any matching node.
branch_trie_node_key: Option<BranchTrieNodeKeyIter>,
},
}
/// Iterator over the nibbles of the key found by a branch search.
///
/// See [`BranchSearch::Found::branch_trie_node_key`].
pub struct BranchTrieNodeKeyIter {
// Owning iterator over the vector of nibbles making up the key.
inner: IntoIter<Nibble>,
}
impl Iterator for BranchTrieNodeKeyIter {
    type Item = Nibble;

    /// Yields the next nibble of the key by forwarding to the wrapped vector iterator.
    fn next(&mut self) -> Option<Nibble> {
        Iterator::next(&mut self.inner)
    }

    /// Forwards the size hint of the wrapped vector iterator unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        Iterator::size_hint(&self.inner)
    }
}
// `size_hint` is forwarded from a `vec::IntoIter`, which itself is an `ExactSizeIterator`, so
// the hint reported by `BranchTrieNodeKeyIter` is exact.
impl ExactSizeIterator for BranchTrieNodeKeyIter {}
/// A branch search that is waiting for the outcome of a storage lookup.
///
/// The lookup to perform is described by [`NextKey::key_before`], [`NextKey::or_equal`] and
/// [`NextKey::prefix`]. Its outcome must be provided with [`NextKey::inject`].
pub struct NextKey {
// Which round the search is in, together with the state specific to that round.
inner: NextKeyInner,
// `Config::key_before`, collected into nibbles.
key_before: Vec<Nibble>,
// `Config::prefix`, collected into nibbles.
prefix: Vec<Nibble>,
// `Config::or_equal`.
or_equal: bool,
}
/// Round-specific state of a [`NextKey`].
enum NextKeyInner {
/// No storage key has been injected yet.
FirstFound {
/// Copy of `Config::no_branch_search`. If `true`, the first injected key ends the search.
no_branch_search: bool,
},
/// At least one storage key has been injected and the search continues.
FurtherRound {
/// Best candidate so far for the key of the branch node. `NextKey::inject` only enters
/// this state with a key that is non-empty and has at least one nibble that isn't the
/// maximum nibble.
current_found_branch: Vec<Nibble>,
},
}
impl NextKey {
pub fn key_before(&self) -> impl Iterator<Item = u8> {
trie::nibbles_to_bytes_suffix_extend(match &self.inner {
NextKeyInner::FirstFound { .. } => either::Left(self.key_before.iter().copied()),
NextKeyInner::FurtherRound {
current_found_branch,
} => {
let num_f_nibbles_to_pop = current_found_branch
.iter()
.rev()
.take_while(|n| **n == Nibble::max())
.count();
debug_assert!(num_f_nibbles_to_pop < current_found_branch.len());
let len = current_found_branch.len();
let extra_nibble = current_found_branch[len - num_f_nibbles_to_pop - 1]
.checked_add(1)
.unwrap_or_else(|| unreachable!());
either::Right(
current_found_branch
.iter()
.take(len - num_f_nibbles_to_pop - 1)
.copied()
.chain(iter::once(extra_nibble)),
)
}
})
}
pub fn or_equal(&self) -> bool {
match self.inner {
NextKeyInner::FirstFound { .. } => {
self.or_equal || (self.key_before.len() % 2 != 0)
}
NextKeyInner::FurtherRound { .. } => true,
}
}
pub fn prefix(&self) -> impl Iterator<Item = u8> {
trie::nibbles_to_bytes_truncate(self.prefix.iter().copied())
}
pub fn inject(
mut self,
storage_trie_node_key: Option<impl Iterator<Item = u8>>,
) -> BranchSearch {
match (self.inner, storage_trie_node_key) {
(NextKeyInner::FirstFound { .. }, None) => BranchSearch::Found {
branch_trie_node_key: None,
},
(
NextKeyInner::FirstFound {
no_branch_search: true,
..
},
Some(storage_trie_node_key),
) => {
let storage_trie_node_key =
trie::bytes_to_nibbles(storage_trie_node_key).collect::<Vec<_>>();
debug_assert!(storage_trie_node_key >= self.key_before);
if !storage_trie_node_key.starts_with(&self.prefix) {
return BranchSearch::Found {
branch_trie_node_key: None,
};
}
BranchSearch::Found {
branch_trie_node_key: Some(BranchTrieNodeKeyIter {
inner: storage_trie_node_key.into_iter(),
}),
}
}
(
NextKeyInner::FirstFound {
no_branch_search: false,
..
},
Some(storage_trie_node_key),
) => {
let storage_trie_node_key =
trie::bytes_to_nibbles(storage_trie_node_key).collect::<Vec<_>>();
debug_assert!(storage_trie_node_key >= self.key_before);
if !storage_trie_node_key.starts_with(&self.prefix) {
return BranchSearch::Found {
branch_trie_node_key: None,
};
}
if storage_trie_node_key.is_empty()
|| storage_trie_node_key.iter().all(|n| *n == Nibble::max())
{
return BranchSearch::Found {
branch_trie_node_key: Some(BranchTrieNodeKeyIter {
inner: storage_trie_node_key.into_iter(),
}),
};
}
self.inner = NextKeyInner::FurtherRound {
current_found_branch: storage_trie_node_key,
};
BranchSearch::NextKey(self)
}
(
NextKeyInner::FurtherRound {
mut current_found_branch,
},
Some(storage_trie_node_key),
) => {
let storage_trie_node_key = trie::bytes_to_nibbles(storage_trie_node_key);
let num_common = storage_trie_node_key
.zip(current_found_branch.iter())
.take_while(|(a, b)| a == *b)
.count();
debug_assert!(num_common < current_found_branch.len());
if !current_found_branch[..num_common].starts_with(&self.prefix)
|| ¤t_found_branch[..num_common] < &self.key_before
|| (!self.or_equal && current_found_branch[..num_common] == self.key_before)
{
return BranchSearch::Found {
branch_trie_node_key: Some(BranchTrieNodeKeyIter {
inner: current_found_branch.into_iter(),
}),
};
}
current_found_branch.truncate(num_common);
if current_found_branch.is_empty()
|| current_found_branch.iter().all(|n| *n == Nibble::max())
{
return BranchSearch::Found {
branch_trie_node_key: Some(BranchTrieNodeKeyIter {
inner: current_found_branch.into_iter(),
}),
};
}
self.inner = NextKeyInner::FurtherRound {
current_found_branch,
};
BranchSearch::NextKey(self)
}
(
NextKeyInner::FurtherRound {
current_found_branch,
},
None,
) => BranchSearch::Found {
branch_trie_node_key: Some(BranchTrieNodeKeyIter {
inner: current_found_branch.into_iter(),
}),
},
}
}
}