use std::{
borrow::Borrow, cmp::Ordering, error::Error, fmt::Debug, iter::Peekable, marker::PhantomData, ops::{Deref, DerefMut, Index, IndexMut}, slice::GetDisjointMutError, vec::IntoIter
};
use list::{DisjointMutIndices, Node, NodeIndex, Slots};
use crate::list::SlotsIterMut;
mod list;
/// A [`Trie`] whose keys are sequences of `char`, e.g. the characters of a `&str`.
pub type StrTrie<V> = Trie<char, V>;
/// A prefix tree mapping sequences of `K` (e.g. the `char`s of a string) to values of type `V`.
///
/// Nodes live in a slot storage and refer to their children through [`NodeIndex`] handles.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct Trie<K, V> {
    // Slot storage holding every node, including the root at `NodeIndex::ROOT`.
    node: Slots<K, V>,
    // Number of keys that currently have a value (incremented on insert of a new key,
    // decremented when a value is removed).
    size: usize,
}
/// Returned when a key walk cannot follow the next key item from the current node.
#[derive(Debug)]
struct WalkFailure<'a, I> {
    // Deepest node the walk reached before it got stuck.
    root: NodeIndex,
    // The key iterator; its next item is the one that could not be matched (it was only
    // peeked, not consumed), so callers can continue from it (e.g. `insert`).
    remainder: &'a mut I,
}
/// The nodes handed to the walk visitor during a successful key walk, in visiting order.
#[derive(Debug)]
struct WalkTrajectory {
    // Begins with the root and ends with the node the key resolved to.
    path: Vec<NodeIndex>,
}
impl<K, V> Default for Trie<K, V> {
fn default() -> Self {
Self::new()
}
}
/// Explicit stack driving a depth-first walk over the trie.
#[derive(Debug)]
struct WalkCtx {
    // Paths still to explore; each path runs from the walk's start node to the node it ends
    // with, and the last element is the node that will be visited when the path is popped.
    pending: Vec<Vec<NodeIndex>>,
}
/// Iterator adapter that repeatedly drives a [`WalkCtx`] over a borrowed trie, yielding
/// `(node, path)` for every node that holds a value.
struct WalkIter<'a, K, V> {
    trie: &'a Trie<K, V>,
    context: WalkCtx,
}
impl WalkCtx {
    /// Advances the depth-first walk until it reaches a node that holds a value.
    ///
    /// Returns that node together with the path leading to it (the path ends with the node
    /// itself), or `None` once every pending path has been explored. Children are pushed in
    /// reverse so that the first subkey of a node is explored first.
    pub fn drive<K, V>(&mut self, trie: &Trie<K, V>) -> Option<(NodeIndex, Vec<NodeIndex>)> {
        loop {
            let path = self.pending.pop()?;
            let tip = *path.last().expect("Pending list was empty.");
            let node = &trie.node[tip];
            self.pending.extend(node.subkeys().rev().map(|&child| {
                let mut branch = path.clone();
                branch.push(child);
                branch
            }));
            if node.value().is_some() {
                return Some((tip, path));
            }
        }
    }
}
impl<K, V> Iterator for WalkIter<'_, K, V> {
type Item = (NodeIndex, Vec<NodeIndex>);
fn next(&mut self) -> Option<Self::Item> {
self.context.drive(self.trie)
}
}
impl<K, V> Trie<K, V> {
    /// Collects the keys stored on `nodes`, in order, into `J`.
    ///
    /// Nodes without a key (presumably only the root) contribute nothing.
    fn collect_path_keys<'a, J>(&'a self, nodes: &[NodeIndex]) -> J
    where
        J: FromIterator<&'a K>,
    {
        nodes
            .iter()
            .map(|f| self.node[*f].key())
            .filter_map(<Option<K>>::as_ref)
            .collect()
    }
    /// Walks `remainder` from the root and returns the node it resolves to.
    ///
    /// See `internal_walk_with_fn_cmp` for the failure contract.
    fn internal_walk_with_index<'b, I, B>(
        &self,
        remainder: &'b mut Peekable<I>,
    ) -> Result<NodeIndex, WalkFailure<'b, Peekable<I>>>
    where
        I: Iterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        self.internal_walk_with_fn(remainder, |_| {})
    }
    /// Walks `remainder` from the root and records every visited node.
    fn internal_walk_with_path<'b, I, B>(
        &self,
        remainder: &'b mut Peekable<I>,
    ) -> Result<WalkTrajectory, WalkFailure<'b, Peekable<I>>>
    where
        I: Iterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        self.internal_walk_with_path_fn(remainder, |_| {})
    }
    /// Like `internal_walk_with_path`, but also forwards each visited node to `functor`.
    fn internal_walk_with_path_fn<'b, I, F, B>(
        &self,
        remainder: &'b mut Peekable<I>,
        mut functor: F,
    ) -> Result<WalkTrajectory, WalkFailure<'b, Peekable<I>>>
    where
        I: Iterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
        F: FnMut(NodeIndex),
    {
        let mut trajectory_path = vec![];
        self.internal_walk_with_fn(remainder, |node| {
            trajectory_path.push(node);
            functor(node);
        })?;
        Ok(WalkTrajectory {
            path: trajectory_path,
        })
    }
    /// Pushes the key of `node` onto `array` if the node has one.
    fn fold_in_key_optionally<'a>(&'a self, node: NodeIndex, array: &mut Vec<&'a K>) {
        if let Some(inner) = self.node[node].key() {
            array.push(inner);
        }
    }
    /// Walks `remainder` from the root, returning the collected keys of the visited nodes
    /// together with the node the walk resolved to.
    fn internal_walk_collect_key<'a, I, J, B>(
        &self,
        remainder: &'a mut Peekable<I>,
    ) -> Result<(J, NodeIndex), WalkFailure<'a, Peekable<I>>>
    where
        I: Iterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
        for<'b> J: FromIterator<&'b K>,
    {
        let mut collector = vec![];
        let index = self.internal_walk_with_fn(remainder, |nk| {
            self.fold_in_key_optionally(nk, &mut collector)
        })?;
        Ok((collector.into_iter().collect::<J>(), index))
    }
    /// Starts a depth-first walk whose only pending path is `[root]`.
    fn perform_walk(&self, root: NodeIndex) -> WalkCtx {
        WalkCtx {
            pending: vec![vec![root]],
        }
    }
    /// Walks `remainder` from the root, comparing node keys with `Ord`.
    fn internal_walk_with_fn<'b, I, F, B>(
        &self,
        remainder: &'b mut Peekable<I>,
        visitor_fn: F,
    ) -> Result<NodeIndex, WalkFailure<'b, Peekable<I>>>
    where
        I: Iterator<Item = B>,
        B: Borrow<K>,
        K: Ord,
        F: FnMut(NodeIndex),
    {
        self.internal_walk_with_fn_cmp(remainder, visitor_fn, |a, b| a.cmp(b.borrow()))
    }
    /// Core key walk: starting at the root, follows one child per item of `remainder`,
    /// selecting the child with `cmp_fn`.
    ///
    /// `visitor_fn` is called exactly once for every node on the walked path, beginning with
    /// the root and ending with the deepest node reached (previously the root was reported
    /// twice, which duplicated it at the head of every recorded trajectory).
    ///
    /// On success, returns the node reached after consuming all of `remainder`. On failure,
    /// returns a [`WalkFailure`] whose `root` is the deepest node reached and whose
    /// `remainder` still yields the unmatched item, because it was only peeked.
    fn internal_walk_with_fn_cmp<'a, F, C, A, B>(
        &self,
        remainder: &'a mut Peekable<A>,
        mut visitor_fn: F,
        mut cmp_fn: C,
    ) -> Result<NodeIndex, WalkFailure<'a, Peekable<A>>>
    where
        A: Iterator<Item = B>,
        F: FnMut(NodeIndex),
        for<'b> C: FnMut(&'b K, &'b B) -> Ordering,
    {
        let mut end = NodeIndex::ROOT;
        while let Some(current) = remainder.peek() {
            visitor_fn(end);
            match self.node[end].get_with(&self.node, |k| cmp_fn(k, current)) {
                Some(&next) => end = next,
                None => {
                    return Err(WalkFailure {
                        root: end,
                        remainder,
                    })
                }
            }
            // Only consume the item once it has been matched, so a failure leaves it in place.
            remainder.next();
        }
        visitor_fn(end);
        Ok(end)
    }
}
impl<K, V> Trie<K, V> {
    /// Creates an empty trie.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }
    /// Returns the capacity reported by the underlying node storage.
    pub fn capacity(&self) -> usize {
        self.node.capacity()
    }
    /// Creates an empty trie, forwarding `slots` as the initial capacity of the node storage.
    pub fn with_capacity(slots: usize) -> Self {
        Self {
            node: Slots::with_capacity(slots),
            size: 0,
        }
    }
    /// Resolves `key` to the node at the end of its path, whether or not that node holds a
    /// value. Returns `None` if the path does not exist.
    fn lookup_key<I, B>(&self, key: I) -> Option<NodeIndex>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        self.internal_walk_with_index(&mut key.into_iter().peekable())
            .ok()
    }
    /// Returns `true` if `key` spells a path that exists in the trie, whether or not a value
    /// is stored at its end. The empty key is always a prefix.
    pub fn is_prefix<I, B>(&self, key: I) -> bool
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        self.internal_walk_with_index(&mut key.into_iter().peekable())
            .is_ok()
    }
    /// Scores every stored key with `distance` and returns the key with the highest score.
    ///
    /// Keys are scored in depth-first walk order and the first key reaching the maximum wins
    /// ties. Returns `None` when the trie holds no values.
    pub fn search_with_score_fn<F, B>(&self, mut distance: F) -> Option<Vec<&K>>
    where
        K: Ord,
        F: FnMut(&[&K]) -> B,
        B: Ord,
    {
        let full_walk = WalkIter {
            context: self.perform_walk(NodeIndex::ROOT),
            trie: self,
        };
        let mut best: Option<(B, Vec<&K>)> = None;
        for (_, path) in full_walk {
            let collected = self.collect_path_keys::<Vec<_>>(&path);
            let score = distance(&collected);
            if best.as_ref().is_none_or(|(top, _)| *top < score) {
                best = Some((score, collected));
            }
        }
        best.map(|(_, candidate)| candidate)
    }
    /// Returns an iterator over every stored key starting with `key`, with its value.
    ///
    /// Each yielded key is the complete key (prefix included), collected into `J`; the
    /// iterator implements [`Iterator`] when `J: FromIterator<&K>`. If `key` is not a path in
    /// the trie the iterator is empty.
    pub fn completions<'a, I, B, J>(&'a self, key: I) -> CompletionIter<'a, K, V, J>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        let mut prefix = vec![];
        let end = self
            .internal_walk_with_fn(&mut key.into_iter().peekable(), |nk| {
                self.fold_in_key_optionally(nk, &mut prefix)
            })
            .ok();
        match end {
            Some(end) => CompletionIter {
                trie: self,
                beginning: prefix,
                inner: self.perform_walk(end),
                _transform: PhantomData,
            },
            None => CompletionIter {
                trie: self,
                beginning: vec![],
                inner: WalkCtx { pending: vec![] },
                _transform: PhantomData,
            },
        }
    }
    /// Returns an iterator over the suffixes that follow `key` in stored keys, with their
    /// values. The iterator implements [`Iterator`] when `J: FromIterator<&K>`. If `key` is
    /// not a path in the trie the iterator is empty.
    pub fn postfix_search<'a, I, B, J>(&'a self, key: I) -> PostfixIter<'a, K, V, J>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        let inner = match self.internal_walk_with_index(&mut key.into_iter().peekable()) {
            Ok(start) => self.perform_walk(start),
            Err(_) => WalkCtx { pending: vec![] },
        };
        PostfixIter {
            trie: self,
            inner,
            _transform: PhantomData,
        }
    }
    /// Iterates over references to all stored values in node-slot order (not key order).
    pub fn values(&self) -> ValueIterRef<'_, K, V> {
        ValueIterRef {
            values: self.node.slot_iter(),
        }
    }
    /// Iterates over mutable references to all stored values in node-slot order.
    pub fn values_mut(&mut self) -> ValueIterMut<'_, K, V> {
        ValueIterMut {
            values: self.node.slot_iter_mut(),
        }
    }
    /// Drains the node storage, yielding every `(key, value)` pair.
    ///
    /// Values come out in node-slot order, so the keys are reordered to match: previously the
    /// keys were produced in depth-first order and paired with the wrong values whenever keys
    /// had not been inserted in sorted order. The length is reset to zero.
    pub fn drain<J>(&mut self) -> Drain<'_, K, V, J>
    where
        for<'b> J: FromIterator<&'b K>,
    {
        let walk = WalkIter {
            context: self.perform_walk(NodeIndex::ROOT),
            trie: &*self,
        };
        let mut keyed: Vec<(usize, J)> = walk
            .map(|(index, path)| (index.position(), self.collect_path_keys::<J>(&path)))
            .collect();
        // NOTE(review): assumes `drain_slots` yields slots in ascending `NodeIndex::position`
        // order — confirm against the `list` module.
        keyed.sort_unstable_by_key(|(position, _)| *position);
        let keys: Vec<J> = keyed.into_iter().map(|(_, key)| key).collect();
        self.size = 0;
        Drain {
            key_iter: keys.into_iter(),
            inner: self.node.drain_slots(),
        }
    }
    /// Consumes the trie, returning its values in node-slot order.
    pub fn into_values(mut self) -> ValueIter<V> {
        let values = self
            .node
            .drain_slots()
            .flatten()
            .filter_map(Node::into_value)
            .collect::<Vec<V>>();
        ValueIter {
            inner: values.into_iter(),
        }
    }
    /// Returns the longest stored key that is a prefix of `key` (the empty key included),
    /// together with its value.
    pub fn longest_prefix_entry<I, J, B>(&self, key: I) -> Option<(J, &V)>
    where
        I: IntoIterator<Item = B>,
        B: Borrow<K>,
        for<'b> J: FromIterator<&'b K>,
        K: Ord,
    {
        let mut last_productive = NodeIndex::ROOT;
        // Length of the leading run of `collector` that forms the best key. The root carries
        // the empty key, so the fallback length is zero (the old `pos + 1` slice returned one
        // key too many in that case and panicked when nothing had been collected).
        let mut productive_len = 0;
        let mut collector = vec![];
        self.find_longest_prefix_fn(key, |nk, k| {
            collector.push(k);
            if self.node[nk].value().is_some() {
                last_productive = nk;
                productive_len = collector.len();
            }
        })?;
        let value = self.node[last_productive].value().as_ref()?;
        let key: J = collector[..productive_len].iter().copied().collect();
        Some((key, value))
    }
    /// Returns the value of the longest stored key that is a prefix of `key` (the empty key
    /// included).
    pub fn longest_prefix<I, B>(&self, key: I) -> Option<&V>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        let mut last_productive = NodeIndex::ROOT;
        self.find_longest_prefix_fn(key, |nk, _| {
            if self.node[nk].value().is_some() {
                last_productive = nk;
            }
        })?;
        self.node[last_productive].value().as_ref()
    }
    /// Walks `key` as far as the trie allows, calling `functor` for every keyed node on the
    /// way, and returns the deepest node reached. Always returns `Some`.
    fn find_longest_prefix_fn<'a, I, F, B>(&'a self, key: I, mut functor: F) -> Option<NodeIndex>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
        F: FnMut(NodeIndex, &'a K),
    {
        let deepest = match self.internal_walk_with_fn(&mut key.into_iter().peekable(), |nk| {
            if let Some(inner) = self.node[nk].key() {
                functor(nk, inner);
            }
        }) {
            Ok(end) => end,
            Err(WalkFailure { root, .. }) => root,
        };
        Some(deepest)
    }
    /// Iterates over all stored keys in depth-first order.
    pub fn keys<'a, J>(&'a self) -> KeyIter<'a, K, V, J>
    where
        J: FromIterator<&'a K>,
    {
        KeyIter {
            inner: WalkIter {
                context: self.perform_walk(NodeIndex::ROOT),
                trie: self,
            },
            _type: PhantomData,
        }
    }
    /// Consumes the trie, yielding `(key, value)` pairs in depth-first order.
    pub fn into_entries<J>(self) -> IntoEntryIter<K, V, J>
    where
        for<'a> J: FromIterator<&'a K>,
    {
        IntoEntryIter {
            inner: self.perform_walk(NodeIndex::ROOT),
            trie: self,
            _type: PhantomData,
        }
    }
    /// Iterates over `(key, &value)` pairs in depth-first order.
    pub fn entries<'a, J>(&'a self) -> EntryIterRef<'a, K, V, J>
    where
        J: FromIterator<&'a K>,
    {
        EntryIterRef {
            inner: WalkIter {
                context: self.perform_walk(NodeIndex::ROOT),
                trie: self,
            },
            _type: PhantomData,
        }
    }
    /// Iterates over `(key, value handle)` pairs in depth-first order.
    pub fn entries_mut<J>(&mut self) -> EntryIterMut<'_, K, V, J>
    where
        for<'a> J: FromIterator<&'a K>,
    {
        EntryIterMut {
            inner: self.perform_walk(NodeIndex::ROOT),
            trie: self,
            _type: PhantomData,
        }
    }
    /// Returns the value stored under `key`, or `None` if the key is absent or is only a
    /// prefix of other keys.
    pub fn get<I, B>(&self, key: I) -> Option<&V>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        self.node[self.lookup_key(key)?].value().as_ref()
    }
    /// Resolves every key to its node, rejecting missing paths and keys that resolve to the
    /// same node.
    fn try_get_key_map<I, B, const N: usize>(
        &self,
        keys: [I; N],
    ) -> Result<[NodeIndex; N], TryGetKeyMapError>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        let mut array = [NodeIndex::ROOT; N];
        for (slot, key) in keys.into_iter().enumerate() {
            let candidate = self
                .lookup_key(key)
                .ok_or(TryGetKeyMapError::KeyDoesNotExist)?;
            // Only compare against the entries filled so far: the unfilled tail still holds
            // the ROOT placeholder, which made the empty key look like a duplicate.
            if array[..slot].contains(&candidate) {
                return Err(TryGetKeyMapError::Duplicate(candidate.position()));
            }
            array[slot] = candidate;
        }
        Ok(array)
    }
    /// Returns mutable access to the nodes of `N` distinct keys at once.
    ///
    /// # Panics
    ///
    /// Panics if a key does not exist in the trie or two keys resolve to the same node.
    pub fn get_disjoint_mut<I, B, const N: usize>(
        &mut self,
        keys: [I; N],
    ) -> DisjointMutIndices<'_, K, V, N>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        self.try_get_disjoint_mut(keys)
            .expect("Keys were overlapping or did not exist in the trie.")
    }
    /// Fallible form of [`Self::get_disjoint_mut`].
    ///
    /// # Errors
    ///
    /// [`GetDisjointMutError::IndexOutOfBounds`] if a key is not a path in the trie, and
    /// [`GetDisjointMutError::OverlappingIndices`] if two keys resolve to the same node.
    pub fn try_get_disjoint_mut<I, B, const N: usize>(
        &mut self,
        keys: [I; N],
    ) -> Result<DisjointMutIndices<K, V, N>, GetDisjointMutError>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        let array = self.try_get_key_map(keys).map_err(|error| match error {
            TryGetKeyMapError::KeyDoesNotExist => GetDisjointMutError::IndexOutOfBounds,
            TryGetKeyMapError::Duplicate(_) => GetDisjointMutError::OverlappingIndices,
        })?;
        self.node.get_disjoint_mut(array)
    }
    /// Returns a mutable reference to the value stored under `key`, if any.
    pub fn get_mut<I, B>(&mut self, key: I) -> Option<&mut V>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        let index = self.lookup_key(key)?;
        self.node[index].value_mut().as_mut()
    }
    /// Returns the number of stored values.
    pub fn len(&self) -> usize {
        self.size
    }
    /// Returns `true` if no values are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Removes every node and value.
    pub fn clear(&mut self) {
        self.node.clear();
        self.size = 0;
    }
    /// Returns the stored key (collected into `J`) and its value. Returns `None` if `key` is
    /// absent or is only a prefix of other keys (this used to panic).
    pub fn get_key_value<I, B, J>(&self, key: I) -> Option<(J, &V)>
    where
        I: IntoIterator<Item = B>,
        for<'a> J: FromIterator<&'a K>,
        K: Ord,
        B: Borrow<K>,
    {
        let (key, index) = self
            .internal_walk_collect_key(&mut key.into_iter().peekable())
            .ok()?;
        let value = self.node[index].value().as_ref()?;
        Some((key, value))
    }
    /// Mutable form of [`Self::get_key_value`]. Returns `None` for absent or prefix-only keys.
    pub fn get_key_value_mut<I, B, J>(&mut self, key: I) -> Option<(J, &mut V)>
    where
        I: IntoIterator<Item = B>,
        for<'a> J: FromIterator<&'a K>,
        K: Ord,
        B: Borrow<K>,
    {
        let (key, index) = self
            .internal_walk_collect_key(&mut key.into_iter().peekable())
            .ok()?;
        let value = self.node[index].value_mut().as_mut()?;
        Some((key, value))
    }
    /// Removes `key`, returning the reconstructed key and its value if one was stored.
    pub fn remove_entry<I, B, J>(&mut self, key: I) -> Option<(J, V)>
    where
        I: IntoIterator<Item = B>,
        J: FromIterator<K>,
        K: Ord + Clone,
        B: Borrow<K>,
    {
        let mut collected = vec![];
        let trajectory = self
            .internal_walk_with_path_fn(key.into_iter().peekable().by_ref(), |nk| {
                self.fold_in_key_optionally(nk, &mut collected)
            })
            .ok()?;
        // Clone the keys before the mutable removal ends the shared borrow they hold.
        let reconstruction = collected.into_iter().cloned().collect::<J>();
        let value = self.remove_post_walk(&trajectory.path)?;
        Some((reconstruction, value))
    }
    /// Forwards `space` to the node storage's `reserve`.
    pub fn reserve(&mut self, space: usize) {
        self.node.reserve(space);
    }
    /// Removes `master`, returning its value if one was stored. Nodes left without value and
    /// children are pruned.
    pub fn remove<I, B>(&mut self, master: I) -> Option<V>
    where
        I: IntoIterator<Item = B>,
        K: Ord,
        B: Borrow<K>,
    {
        let trajectory = self
            .internal_walk_with_path(master.into_iter().peekable().by_ref())
            .ok()?;
        self.remove_post_walk(&trajectory.path)
    }
    /// Removes `source` and, recursively, all of its descendants from the node storage.
    /// Does not unlink `source` from its parent.
    fn detach_node(&mut self, source: NodeIndex) {
        let keys = self.node[source].subkeys().copied().collect::<Vec<_>>();
        for i in keys {
            self.detach_node(i);
        }
        self.node.remove(source);
    }
    /// Takes the value at the end of `path` and prunes now-empty nodes bottom-up.
    ///
    /// `path` must start at the root. The root itself is never detached, so the trie stays
    /// usable after its last key is removed.
    fn remove_post_walk(&mut self, path: &[NodeIndex]) -> Option<V>
    where
        K: Ord,
    {
        let internal_value = self.node[*path.last()?].value_mut().take();
        if internal_value.is_some() {
            self.size -= 1;
        }
        for pair in path.windows(2).rev() {
            let (parent, child) = (pair[0], pair[1]);
            if child == NodeIndex::ROOT {
                break;
            }
            let node = &self.node[child];
            if node.sub_key_len() != 0 || node.value().is_some() {
                break;
            }
            self.detach_node(child);
            self.node[parent].remove_subkey(child);
        }
        internal_value
    }
    /// Returns `true` if a value is stored under `master`.
    pub fn contains_key<I>(&self, master: I) -> bool
    where
        I: IntoIterator<Item = K>,
        K: Ord,
    {
        self.get(master).is_some()
    }
    /// Returns `true` if any stored value equals `value` (linear scan).
    pub fn contains_value(&self, value: &V) -> bool
    where
        V: Eq,
    {
        self.node
            .iter()
            .any(|(_, v)| v.value().as_ref().is_some_and(|inner| inner == value))
    }
    /// Stores `value` under `master`, creating missing nodes, and returns the previous value.
    pub fn insert<I>(&mut self, master: I, value: V) -> Option<V>
    where
        I: IntoIterator<Item = K>,
        K: Ord,
    {
        match self.internal_walk_with_index(master.into_iter().peekable().by_ref()) {
            Ok(v) => {
                let current = self.node[v].value_mut().take();
                if current.is_none() {
                    self.size += 1;
                }
                *self.node[v].value_mut() = Some(value);
                current
            }
            Err(WalkFailure {
                mut root,
                remainder,
            }) => {
                // The walk stopped at `root`; the remaining items form a fresh chain below it.
                let mut nk = root;
                for item in remainder {
                    nk = self.node.insert(Node::keyed(item));
                    self.node.insert_subkey(root, nk);
                    root = nk;
                }
                *self.node[nk].value_mut() = Some(value);
                self.size += 1;
                None
            }
        }
    }
    /// Defragments and then shrinks the node storage.
    pub fn shrink_to_fit(&mut self) {
        self.node.defragment();
        self.node.shrink_to_fit();
    }
}
/// Iterator returned by `Trie::completions`: yields `(full key, &value)` for every stored key
/// extending the searched prefix.
#[derive(Debug)]
pub struct CompletionIter<'a, K, V, J> {
    trie: &'a Trie<K, V>,
    // Keys of the nodes along the searched prefix, collected while walking to its end node.
    beginning: Vec<&'a K>,
    // Depth-first walk starting at the prefix's end node.
    inner: WalkCtx,
    _transform: PhantomData<J>,
}
/// Iterator returned by `Trie::postfix_search`: yields `(suffix, &value)` for every stored
/// key extending the searched prefix.
#[derive(Debug)]
pub struct PostfixIter<'a, K, V, J> {
    trie: &'a Trie<K, V>,
    // Depth-first walk starting at the prefix's end node.
    inner: WalkCtx,
    _transform: PhantomData<J>,
}
impl<I, K, V> Index<I> for Trie<K, V>
where
    I: IntoIterator<Item = K>,
    K: Ord,
{
    type Output = V;
    /// Returns the value stored under `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` has no value in the trie.
    fn index(&self, index: I) -> &Self::Output {
        match self.get(index) {
            Some(value) => value,
            None => panic!("Could not find index in Trie!"),
        }
    }
}
impl<'a, K, V, J> Iterator for CompletionIter<'a, K, V, J>
where
    J: FromIterator<&'a K>,
{
    type Item = (J, &'a V);
    /// Yields the next completion as the full key (prefix included) and its value.
    fn next(&mut self) -> Option<Self::Item> {
        let trie = self.trie;
        let (node_index, walked) = self.inner.drive(trie)?;
        let full_key: J = assemble_completion(trie, &self.beginning, walked);
        let value = trie.node[node_index]
            .value()
            .as_ref()
            .expect("Iteration erroneously terminated on non-leaf node.");
        Some((full_key, value))
    }
}
impl<'a, K, V, J> Iterator for PostfixIter<'a, K, V, J>
where
    J: FromIterator<&'a K>,
{
    type Item = (J, &'a V);
    /// Yields the next stored key below the searched prefix as `(suffix, &value)`.
    ///
    /// The suffix excludes the prefix. Every walked path begins with the node the search
    /// started from, which belongs to the prefix, so that *node* is skipped. Previously the
    /// first *key* was skipped instead, which dropped a real character whenever the search
    /// started at the keyless root (e.g. an empty prefix).
    fn next(&mut self) -> Option<Self::Item> {
        let trie = self.trie;
        let (current, path) = self.inner.drive(trie)?;
        let key: J = path
            .iter()
            .skip(1)
            .filter_map(|nk| trie.node[*nk].key().as_ref())
            .collect();
        let value = trie.node[current].value().as_ref().expect("Walk terminated on a non-leaf node. This means that the internal structure of the Trie is corrupted or there is an error in the WalkCtx implementation.");
        Some((key, value))
    }
}
/// Builds the full key for a completion.
///
/// `beginning` holds the keys along the searched prefix. `walked` is the walk path from the
/// prefix's end node down to the completion. The end node appears in both, so the last
/// element of `beginning` is dropped before the walked keys are appended.
fn assemble_completion<'a, K, V, J>(
    trie: &'a Trie<K, V>,
    beginning: &[&'a K],
    walked: Vec<NodeIndex>,
) -> J
where
    J: FromIterator<&'a K>,
{
    let prefix = &beginning[..beginning.len().saturating_sub(1)];
    let tail = walked
        .into_iter()
        .filter_map(|nk| trie.node[nk].key().as_ref());
    prefix.iter().copied().chain(tail).collect()
}
impl<K, V, I> IndexMut<I> for Trie<K, V>
where
    K: Ord,
    I: IntoIterator<Item = K>,
{
    /// Returns a mutable reference to the value stored under `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` has no value in the trie.
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        match self.get_mut(index) {
            Some(slot) => slot,
            None => panic!("Invalid trie index"),
        }
    }
}
/// Iterator returned by `Trie::entries_mut`: depth-first `(key, ValueSlot)` pairs.
pub struct EntryIterMut<'a, K, V, J> {
    trie: &'a mut Trie<K, V>,
    inner: WalkCtx,
    _type: PhantomData<J>,
}
/// Iterator returned by `Trie::into_entries`: depth-first `(key, value)` pairs from an
/// owned trie.
pub struct IntoEntryIter<K, V, J> {
    trie: Trie<K, V>,
    inner: WalkCtx,
    _type: PhantomData<J>,
}
/// Iterator returned by `Trie::entries`: depth-first `(key, &value)` pairs.
pub struct EntryIterRef<'a, K, V, J> {
    inner: WalkIter<'a, K, V>,
    _type: PhantomData<J>,
}
/// Iterator returned by `Trie::keys`: depth-first keys collected into `J`.
pub struct KeyIter<'a, K, V, J> {
    inner: WalkIter<'a, K, V>,
    _type: PhantomData<J>,
}
/// Iterator returned by `Trie::into_values`: owned values in node-slot order.
pub struct ValueIter<V> {
    inner: IntoIter<V>,
}
/// Iterator returned by `Trie::values`: `&V` in node-slot order.
pub struct ValueIterRef<'a, K, V> {
    // Raw node slots; empty slots are `None` and are skipped.
    values: std::slice::Iter<'a, Option<Node<K, V>>>,
}
/// Iterator returned by `Trie::values_mut`: `&mut V` in node-slot order.
pub struct ValueIterMut<'a, K, V> {
    values: SlotsIterMut<'a, K, V>,
}
/// `&str` conveniences for tries keyed by `char`; each forwards `key.chars()` to the generic
/// method of the same base name.
impl<V> Trie<char, V> {
    /// See [`Trie::insert`].
    pub fn insert_str(&mut self, key: &str, value: V) -> Option<V> {
        Self::insert(self, key.chars(), value)
    }
    /// See [`Trie::remove`].
    pub fn remove_str(&mut self, key: &str) -> Option<V> {
        Self::remove(self, key.chars())
    }
    /// See [`Trie::contains_key`].
    pub fn contains_key_str(&self, key: &str) -> bool {
        Self::contains_key(self, key.chars())
    }
    /// See [`Trie::is_prefix`].
    pub fn is_prefix_str(&self, key: &str) -> bool {
        Self::is_prefix(self, key.chars())
    }
    /// See [`Trie::completions`]; keys are collected into `String`.
    pub fn completions_str(&self, key: &str) -> CompletionIter<'_, char, V, String> {
        Self::completions(self, key.chars())
    }
    /// See [`Trie::postfix_search`]; suffixes are collected into `String`.
    pub fn postfix_search_str(&self, key: &str) -> PostfixIter<'_, char, V, String> {
        Self::postfix_search(self, key.chars())
    }
    /// See [`Trie::longest_prefix_entry`].
    pub fn longest_prefix_entry_str(&self, key: &str) -> Option<(String, &V)> {
        Self::longest_prefix_entry(self, key.chars())
    }
    /// See [`Trie::longest_prefix`].
    pub fn longest_prefix_str(&self, key: &str) -> Option<&V> {
        Self::longest_prefix(self, key.chars())
    }
    /// See [`Trie::keys`].
    pub fn keys_str(&self) -> KeyIter<'_, char, V, String> {
        Self::keys(self)
    }
    /// See [`Trie::into_entries`].
    pub fn into_entries_str(self) -> IntoEntryIter<char, V, String> {
        Self::into_entries(self)
    }
    /// See [`Trie::entries`].
    pub fn entries_str(&self) -> EntryIterRef<'_, char, V, String> {
        Self::entries(self)
    }
    /// See [`Trie::entries_mut`].
    pub fn entries_mut_str(&mut self) -> EntryIterMut<'_, char, V, String> {
        Self::entries_mut(self)
    }
    /// See [`Trie::get`].
    pub fn get_str(&self, key: &str) -> Option<&V> {
        Self::get(self, key.chars())
    }
    /// See [`Trie::get_mut`].
    pub fn get_mut_str(&mut self, key: &str) -> Option<&mut V> {
        Self::get_mut(self, key.chars())
    }
    /// See [`Trie::get_disjoint_mut`].
    pub fn get_disjoint_mut_str<const N: usize>(
        &mut self,
        key: [&str; N],
    ) -> DisjointMutIndices<'_, char, V, N> {
        Self::get_disjoint_mut(self, key.map(str::chars))
    }
    /// See [`Trie::try_get_disjoint_mut`].
    pub fn try_get_disjoint_mut_str<const N: usize>(
        &mut self,
        key: [&str; N],
    ) -> Result<DisjointMutIndices<'_, char, V, N>, GetDisjointMutError> {
        Self::try_get_disjoint_mut(self, key.map(str::chars))
    }
    /// See [`Trie::get_key_value`].
    pub fn get_key_value_str(&self, key: &str) -> Option<(String, &V)> {
        Self::get_key_value(self, key.chars())
    }
    /// See [`Trie::get_key_value_mut`].
    pub fn get_key_value_mut_str(&mut self, key: &str) -> Option<(String, &mut V)> {
        Self::get_key_value_mut(self, key.chars())
    }
    /// See [`Trie::remove_entry`].
    pub fn remove_entry_str(&mut self, key: &str) -> Option<(String, V)> {
        Self::remove_entry(self, key.chars())
    }
}
impl<K, V, J> Iterator for IntoEntryIter<K, V, J>
where
    for<'b> J: FromIterator<&'b K>,
{
    type Item = (J, V);
    /// Moves the next value (in depth-first order) out of the owned trie, paired with its key.
    fn next(&mut self) -> Option<Self::Item> {
        let (index, path) = self.inner.drive(&self.trie)?;
        let key = self.trie.collect_path_keys::<J>(&path);
        let value = self.trie.node[index]
            .value_mut()
            .take()
            .expect("Iteration terminated erroneously on a non-leaf node.");
        Some((key, value))
    }
}
impl<'a, K, V, J> Iterator for EntryIterMut<'a, K, V, J>
where
    for<'b> J: FromIterator<&'b K>,
{
    /// The full key of a valued node and a mutable handle to that node's value.
    type Item = (J, ValueSlot<'a, K, V>);
    /// Advances the depth-first walk to the next node that holds a value.
    fn next(&mut self) -> Option<Self::Item> {
        let (current, path) = { self.inner.drive(self.trie)? };
        // Reads the key of every node on the path from the walk's start to `current`,
        // ancestors included.
        let calced = self.trie.collect_path_keys::<J>(&path);
        // The raw-pointer round trip detaches the `&mut Node` from the `&mut self` borrow so
        // that it can live for `'a`; `drive` returns each node at most once.
        // NOTE(review): soundness is doubtful. If an earlier yielded node is an ancestor of a
        // later one (e.g. keys "he" and "hello"), later calls read that node's key via
        // `collect_path_keys` (and `drive` indexes `trie.node`) while the caller may still
        // hold its `ValueSlot`. Re-indexing through `IndexMut` may also invalidate earlier
        // handles, depending on how `Slots` implements it. Worth auditing under Miri.
        let candidate = unsafe { &mut *(&mut self.trie.node[current] as *mut Node<K, V>) };
        Some((calced, ValueSlot { node: candidate }))
    }
}
/// Mutable handle to a node's value, yielded by `EntryIterMut`; dereferences to `V`.
pub struct ValueSlot<'a, K, V> {
    node: &'a mut Node<K, V>,
}
impl<K, V> Deref for ValueSlot<'_, K, V> {
    type Target = V;
    /// Borrows the node's value.
    ///
    /// # Panics
    ///
    /// Panics if the node holds no value.
    fn deref(&self) -> &Self::Target {
        match self.node.value() {
            Some(value) => value,
            None => panic!("Tried to read node value but was None."),
        }
    }
}
impl<K, V> DerefMut for ValueSlot<'_, K, V> {
    /// Mutably borrows the node's value via `Node::value_mut_unchecked`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.node.value_mut_unchecked()
    }
}
impl<K, V> PartialEq for Trie<K, V>
where
    K: PartialEq,
    V: PartialEq,
{
    /// Two tries are equal when their node storages compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.node == other.node
    }
}
impl<'a, K, V, J> Iterator for EntryIterRef<'a, K, V, J>
where
    J: FromIterator<&'a K>,
{
    type Item = (J, &'a V);
    /// Yields the next `(key, &value)` pair in depth-first order.
    fn next(&mut self) -> Option<Self::Item> {
        let trie = self.inner.trie;
        self.inner.next().map(|(node_index, path)| {
            let key: J = trie.collect_path_keys(&path);
            let value = trie.node[node_index]
                .value()
                .as_ref()
                .expect("Walk terminated on non-leaf node.");
            (key, value)
        })
    }
}
impl<'a, K, V, J> Iterator for KeyIter<'a, K, V, J>
where
    J: FromIterator<&'a K>,
{
    type Item = J;
    /// Yields the next stored key in depth-first order.
    fn next(&mut self) -> Option<Self::Item> {
        let trie = self.inner.trie;
        self.inner
            .next()
            .map(|(_, path)| trie.collect_path_keys(&path))
    }
}
impl<V> Iterator for ValueIter<V> {
type Item = V;
fn next(&mut self) -> Option<Self::Item> {
self.inner.next()
}
}
impl<'a, K, V> Iterator for ValueIterMut<'a, K, V> {
    type Item = &'a mut V;
    /// Skips nodes without a value and yields the next value mutably.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(value) = self.values.next()?.value_mut().as_mut() {
                return Some(value);
            }
        }
    }
}
/// Returns the value held by an occupied slot's node, or `None` for an empty slot or a
/// node without a value.
fn map_node_to_value<K, V>(option: Option<&Node<K, V>>) -> Option<&V> {
    option.and_then(|node| node.value().as_ref())
}
impl<'a, K, V> Iterator for ValueIterRef<'a, K, V> {
    type Item = &'a V;
    /// Skips empty slots and value-less nodes and yields the next value.
    fn next(&mut self) -> Option<Self::Item> {
        self.values
            .find_map(|slot| map_node_to_value(slot.as_ref()))
    }
}
impl<KP, K, V> Extend<(KP, V)> for Trie<K, V>
where
    K: Ord,
    KP: IntoIterator<Item = K>,
{
    /// Inserts every pair in order; later pairs overwrite earlier values for the same key.
    #[inline]
    fn extend<T: IntoIterator<Item = (KP, V)>>(&mut self, iter: T) {
        iter.into_iter().for_each(|(key, value)| {
            self.insert(key, value);
        });
    }
}
impl<'a, KP, K, V> Extend<(KP, &'a V)> for Trie<K, V>
where
    K: Ord,
    V: Clone,
    KP: IntoIterator<Item = K>,
{
    /// Inserts a clone of every borrowed value under its key, in order.
    #[inline]
    fn extend<T: IntoIterator<Item = (KP, &'a V)>>(&mut self, iter: T) {
        iter.into_iter().for_each(|(key, value)| {
            self.insert(key, V::clone(value));
        });
    }
}
impl<KP, K, V> FromIterator<(KP, V)> for Trie<K, V>
where
K: Ord,
KP: IntoIterator<Item = K>,
{
fn from_iter<T: IntoIterator<Item = (KP, V)>>(iter: T) -> Self {
let mut map = Trie::new();
map.extend(iter);
map
}
}
impl<KP, K, V, const N: usize> From<[(KP, V); N]> for Trie<K, V>
where
    K: Ord,
    KP: IntoIterator<Item = K>,
{
    /// Builds a trie from an array of `(key, value)` pairs.
    fn from(arr: [(KP, V); N]) -> Self {
        IntoIterator::into_iter(arr).collect()
    }
}
/// Iterator returned by `Trie::drain`: yields `(key, value)` pairs while draining the node
/// storage.
pub struct Drain<'a, K, V, J> {
    // Keys to pair with the drained values; expected to follow the same order as the
    // valued nodes produced by `inner`.
    key_iter: std::vec::IntoIter<J>,
    // Drained node slots; empty slots and value-less nodes are skipped.
    inner: std::vec::Drain<'a, Option<Node<K, V>>>,
}
impl<K, V, J> Iterator for Drain<'_, K, V, J> {
type Item = (J, V);
fn next(&mut self) -> Option<Self::Item> {
let mut current = None;
while current.is_none() {
if let Some(node) = self.inner.next()? {
current = node.into_value();
}
}
Some((self.key_iter.next()?, current.expect("The current iterating node was supposedly filled, but turned out to be None.")))
}
}
/// Reasons why a set of keys could not be resolved to distinct trie nodes.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum TryGetKeyMapError {
    /// One of the keys is not a path in the trie.
    KeyDoesNotExist,
    /// Two keys resolved to the same node; holds that node's position.
    Duplicate(usize),
}
impl std::fmt::Display for TryGetKeyMapError {
    /// Renders the variant name, e.g. `KeyDoesNotExist` or `Duplicate(3)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Duplicate(index) => write!(f, "Duplicate({})", index),
            Self::KeyDoesNotExist => f.write_str("KeyDoesNotExist"),
        }
    }
}
impl std::fmt::Debug for TryGetKeyMapError {
    /// Debug output is identical to the `Display` rendering.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
impl Error for TryGetKeyMapError {
    /// Human-readable explanation of each variant.
    fn description(&self) -> &str {
        match self {
            Self::KeyDoesNotExist => "Failed to get one of the keys from the slotmap",
            Self::Duplicate(_) => "There was a duplicate key, they were not disjoint.",
        }
    }
}
#[cfg(feature = "arbitrary")]
mod arbitrary_trie {
    use core::f32;
    use arbitrary::{Arbitrary, Unstructured};
    use crate::Trie;

    /// Reinterprets an arbitrary `u32` as an `f32`, takes its magnitude and divides it by
    /// `f32::MAX`.
    ///
    /// NOTE(review): some bit patterns decode to NaN or infinity, so the result is not always
    /// in `[0, 1]`.
    fn random_norm_float(u: &mut Unstructured<'_>) -> arbitrary::Result<f32> {
        let bits = u32::arbitrary(u)?;
        let magnitude = f32::from_bits(bits).abs();
        Ok(magnitude / f32::MAX)
    }

    impl<'a, K, V> Arbitrary<'a> for Trie<K, V>
    where
        K: Arbitrary<'a> + Ord + Clone,
        V: Arbitrary<'a>,
    {
        /// Inserts a random number of random entries, then removes each inserted key with a
        /// probability driven by a random threshold.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let mut trie = Trie::<K, V>::new();
            let entry_count = usize::from(u16::arbitrary(u)?);
            let mut inserted: Vec<Vec<K>> = Vec::new();
            for _ in 0..entry_count {
                let key: Vec<K> = Vec::arbitrary(u)?;
                let value = V::arbitrary(u)?;
                trie.insert(key.iter().cloned(), value);
                inserted.push(key);
            }
            let threshold = random_norm_float(u)?;
            for key in inserted {
                if random_norm_float(u)? < threshold {
                    trie.remove(key);
                }
            }
            Ok(trie)
        }
    }
}
#[cfg(test)]
mod tests {
use std::{cmp::Reverse, time::Duration};
use super::Trie;
#[test]
#[cfg(feature = "arbitrary")]
pub fn test_arbitrary_crate() {
use arbitrary::{Arbitrary, Unstructured};
use rand::Rng;
let val: [u8; 12] = rand::rng().random();
let trie: Trie<char, usize> = Trie::arbitrary(&mut Unstructured::new(&val)).unwrap();
let trie2: Trie<char, usize> = Trie::arbitrary(&mut Unstructured::new(&val)).unwrap();
assert_eq!(trie, trie2);
}
#[test]
pub fn test_root_len_mod() {
let mut trie = Trie::<char, usize>::new();
assert_eq!(trie.len(), 0);
trie.insert_str("", 3);
assert_eq!(trie.len(), 1);
trie.insert_str("", 4);
assert_eq!(trie.len(), 1);
trie.remove_str("");
assert_eq!(trie.len(), 0);
}
#[test]
pub fn test_size_duplicates() {
let mut trie = Trie::<char, usize>::from([
("hello".chars(), 4),
("hey".chars(), 3),
("hamburger".chars(), 10),
]);
assert_eq!(trie.len(), 3);
trie.insert("hello".chars(), 4);
assert_eq!(trie.len(), 3);
trie.insert("heman".chars(), 10);
assert_eq!(trie.len(), 4);
trie.remove("heman".chars());
assert_eq!(trie.len(), 3);
trie.remove("heman".chars());
assert_eq!(trie.len(), 3);
trie.remove("albert".chars());
assert_eq!(trie.len(), 3);
}
#[test]
#[cfg(all(feature = "arbitrary", feature = "serde"))]
pub fn arbtest_arb_impl() {
use arbitrary::Arbitrary;
arbtest::arbtest(|u| {
let standard = Trie::<char, usize>::arbitrary(u)?;
let reversed: Trie<char, usize> =
serde_json::from_slice(&serde_json::to_vec(&standard).unwrap()).unwrap();
assert_eq!(standard, reversed);
Ok(())
})
.budget(Duration::from_secs(5));
}
#[test]
pub fn test_delete_root_vector() {
let mut root = Trie::<char, usize>::from([("hello".chars(), 4)]);
root.insert([], 5);
root.remove::<[char; 0], char>([]);
root.remove::<[char; 0], char>([]);
}
#[test]
#[cfg(all(feature = "rkyv", feature = "serde"))]
pub fn arbtest_arb_impl_rkyv() {
use arbitrary::Arbitrary;
use rkyv::rancor::Error;
arbtest::arbtest(|u| {
let standard = Trie::<char, u64>::arbitrary(u)?;
let archived = rkyv::to_bytes::<Error>(&standard).unwrap();
let archived: Trie<char, u64> = rkyv::from_bytes::<_, Error>(&archived).unwrap();
assert_eq!(standard, archived);
Ok(())
})
.budget(Duration::from_secs(5));
}
#[test]
#[cfg(feature = "serde")]
pub fn test_serde_serialize_map() {
let tree = Trie::<char, usize>::from([("hello".chars(), 4), ("hey".chars(), 5)]);
let val = serde_json::to_vec(&tree).unwrap();
let tree2: Trie<char, usize> = serde_json::from_slice(&val).unwrap();
assert_eq!(tree, tree2);
}
#[test]
pub fn trie_get_kv_properly() {
let trie: Trie<char, i32> = Trie::from([(['h', 'e', 'l', 'l', 'o'], 4)]);
assert_eq!(
trie.get_key_value::<_, _, String>("hello".chars()),
Some(("hello".to_string(), &4))
);
}
#[test]
pub fn test_entry_mut() {
let mut tree = Trie::<char, usize>::from([("hello".chars(), 4), ("bye".chars(), 3)]);
let mut key_iter = tree.entries_mut::<String>();
assert_eq!(*key_iter.next().unwrap().1, 3);
assert_eq!(*key_iter.next().unwrap().1, 4);
assert!(key_iter.next().is_none());
let mut key_iter = tree.entries_mut::<String>();
*key_iter.next().unwrap().1 = 5;
let mut key_iter = tree.entries_mut::<String>();
assert_eq!(*key_iter.next().unwrap().1, 5);
assert_eq!(*key_iter.next().unwrap().1, 4);
assert!(key_iter.next().is_none());
}
#[test]
pub fn trie_from_tuples() {
let trie: Trie<char, i32> = Trie::from([("hello".chars(), 4)]);
assert_eq!(trie.get("hello".chars()), Some(&4));
}
#[test]
pub fn trie_disjoint_mut() {
let mut tree = Trie::<char, usize>::new();
tree.insert("hello".chars(), 3);
tree.insert("world".chars(), 4);
let mut keys = tree.get_disjoint_mut(["hello".chars(), "world".chars()]);
*keys[0].as_mut().unwrap() = 4;
*keys[1].as_mut().unwrap() = 2;
assert_eq!(tree.get("hello".chars()), Some(&4));
assert_eq!(tree.get("world".chars()), Some(&2));
}
#[test]
pub fn test_trie_equality_simple() {
let trie_1 = Trie::from([("hello".chars(), 3), ("good".chars(), 2)]);
let trie_2 = Trie::from([("hello".chars(), 3), ("good".chars(), 2)]);
let trie_3 = Trie::from([("hello".chars(), 3), ("good".chars(), 3)]);
assert_eq!(trie_1, trie_2);
assert_ne!(trie_1, trie_3);
}
#[test]
pub fn test_into_values() {
let trie: Trie<char, usize> = Trie::from([("".chars(), 1), ("tra".chars(), 2)]);
let mut values = trie.into_values();
assert_eq!(values.next(), Some(1));
assert_eq!(values.next(), Some(2));
}
#[test]
pub fn test_removal_subword() {
let mut trie_1 = Trie::from([("hello".chars(), 3), ("good".chars(), 2)]);
trie_1.insert("go".chars(), 8);
assert_eq!(trie_1.remove("go".chars()), Some(8));
assert!(trie_1.contains_key("good".chars()));
}
#[test]
pub fn test_trie_equality_complex() {
let mut trie_1 = Trie::from([("hello".chars(), 3), ("good".chars(), 2)]);
let mut trie_2 = Trie::from([("hello".chars(), 3), ("good".chars(), 2)]);
trie_2.remove("hello".chars());
trie_2.insert("hey".chars(), 12);
trie_2.insert("hello".chars(), 3);
trie_1.remove("good".chars());
trie_1.insert("go".chars(), 8);
trie_1.insert("hey".chars(), 12);
assert!(trie_1.contains_key("hey".chars()));
trie_1.insert("good".chars(), 2);
trie_1.remove("go".chars());
assert!(trie_1.contains_key("good".chars()));
assert_eq!(trie_1, trie_2);
}
#[test]
pub fn trie_test_defragment() {
let mut trie = Trie::<char, usize>::new();
trie.insert(['a'], 1);
trie.insert(['d'], 2);
trie.insert(['b'], 3);
trie.insert(['b', 'c'], 4);
assert_eq!(trie.get(['a']), Some(&1));
assert_eq!(trie.get(['d']), Some(&2));
assert_eq!(trie.get(['b']), Some(&3));
assert_eq!(trie.get(['b', 'c']), Some(&4));
trie.remove(['d']);
assert_eq!(trie.get(['a']), Some(&1));
assert_eq!(trie.get(['d']), None);
assert_eq!(trie.get(['b']), Some(&3));
assert_eq!(trie.get(['b', 'c']), Some(&4));
trie.shrink_to_fit();
assert_eq!(trie.get(['a']), Some(&1));
assert_eq!(trie.get(['d']), None);
assert_eq!(trie.get(['b']), Some(&3));
assert_eq!(trie.get(['b', 'c']), Some(&4));
}
#[test]
pub fn trie_into_values() {
let mut tree = Trie::<char, usize>::new();
tree.insert("hello".chars(), 1);
tree.insert("bye".chars(), 2);
let mut values = tree.into_values();
assert_eq!(values.next(), Some(1));
assert_eq!(values.next(), Some(2));
assert_eq!(values.next(), None);
}
#[test]
pub fn path_traversal_test() {
let mut tree = Trie::<char, &str>::new();
tree.insert(['t', 'e', 's', 't'], "sample_1");
tree.insert("tea".chars(), "Sample_2");
}
#[test]
pub fn basic_trie_insert() {
let mut tree: Trie<char, &str> = Trie::new();
tree.insert("test".chars(), "sample_1");
tree.insert("testttt".chars(), "sample_2");
assert!(tree.contains_key("test".chars()));
assert_eq!(*tree.get("test".chars()).unwrap(), "sample_1");
}
#[test]
pub fn basic_trie_insert_multi_keys() {
let mut tree: Trie<char, &str> = Trie::new();
tree.insert("test".chars(), "sample_1");
println!("INSERTING TEA...");
tree.insert("tea".chars(), "sample_3");
for (key, value) in tree.node.iter() {
println!("({key:?}) -> {value:?}");
}
println!("\n\n\nStarting Tea...");
assert!(tree.contains_key("tea".chars()));
assert_eq!(*tree.get("test".chars()).unwrap(), "sample_1");
assert_eq!(*tree.get("tea".chars()).unwrap(), "sample_3");
}
#[test]
pub fn basic_trie_deletion() {
let mut tree: Trie<char, &str> = Trie::new();
tree.insert("test".chars(), "sample_1");
assert_eq!(tree.len(), 1);
assert!(tree.contains_key("test".chars()));
println!("Helo");
tree.remove("test".chars());
println!("Deleetign");
assert!(!tree.contains_key("test".chars()));
assert_eq!(tree.len(), 0);
}
/// Deleting one of two keys that share a prefix leaves the other key in place.
#[test]
pub fn trie_deletion_multikey() {
    let mut trie: Trie<char, &str> = Trie::new();
    for (word, sample) in [("test", "sample_1"), ("tea", "sample_2")] {
        trie.insert(word.chars(), sample);
    }
    for word in ["test", "tea"] {
        assert!(trie.contains_key(word.chars()));
    }
    trie.remove("tea".chars());
    assert!(!trie.contains_key("tea".chars()));
    assert!(trie.contains_key("test".chars()));
}
/// Property test: any freshly inserted 32-char key is immediately reported as present.
#[test]
pub fn test_arbitrary_insert() {
    let mut trie: Trie<char, String> = Trie::new();
    arbtest::arbtest(|u| {
        let key = String::from_iter(u.arbitrary::<[char; 32]>().unwrap());
        let payload = String::from_iter(u.arbitrary::<[char; 32]>().unwrap());
        trie.insert(key.chars(), payload);
        assert!(trie.contains_key(key.chars()));
        Ok(())
    });
}
/// `completions` returns every stored key under the given prefix: two for "he",
/// all keys (in order) for the empty prefix, and just the key itself for a full key.
#[test]
pub fn completions_test() {
    let mut tree = Trie::<char, ()>::new();
    for word in ["hello", "hey", "james"] {
        tree.insert(word.chars(), ());
    }

    let mut he_matches: Vec<_> = tree.completions::<_, _, String>("he".chars()).collect();
    he_matches.sort();
    assert_eq!(he_matches.len(), 2);
    assert_eq!(he_matches[0].0, "hello");
    assert_eq!(he_matches[1].0, "hey");

    let everything: Vec<String> = tree
        .completions::<_, char, String>([])
        .map(|(word, _)| word)
        .collect();
    assert_eq!(everything, ["hello", "hey", "james"]);

    let exact: Vec<String> = tree
        .completions::<_, _, String>(['h', 'e', 'l', 'l', 'o'])
        .map(|(word, _)| word)
        .collect();
    assert_eq!(exact, ["hello"]);
}
/// Property test: whatever the trie's state, `remove` succeeds exactly when
/// `contains_key` reported the key, both before and after inserting it.
#[test]
pub fn test_arbitrary_deletion() {
    let mut trie: Trie<char, String> = Trie::new();
    arbtest::arbtest(|u| {
        let key = String::from_iter(u.arbitrary::<[char; 4]>().unwrap());
        let payload = String::from_iter(u.arbitrary::<[char; 4]>().unwrap());

        let was_present = trie.contains_key(key.chars());
        assert_eq!(was_present, trie.remove(key.chars()).is_some());

        trie.insert(key.chars(), payload);

        let now_present = trie.contains_key(key.chars());
        assert_eq!(now_present, trie.remove(key.chars()).is_some());
        Ok(())
    })
    .budget(Duration::from_secs(3));
}
/// With nested keys "he" ⊂ "hel" ⊂ "hello", a query equal to or extending
/// "hello" resolves to the deepest stored key.
#[test]
pub fn test_longest_prefix() {
    let mut trie = Trie::<char, usize>::new();
    for (key, rank) in [("he", 1), ("hel", 2), ("hello", 3)] {
        trie.insert(key.chars(), rank);
    }
    for query in ["hello", "hellothere"] {
        assert_eq!(
            trie.longest_prefix_entry(query.chars()),
            Some(("hello".to_string(), &3))
        );
    }
}
/// Writes made through the mutable entry iterator are visible through the shared one.
#[test]
pub fn test_entry_iter_mut_soundness() {
    let mut trie = Trie::from([("hello".chars(), 42), ("hey".chars(), 55)]);
    for (_, mut slot) in trie.entries_mut::<String>() {
        *slot = 23;
    }
    for (_, observed) in trie.entries::<String>() {
        assert_eq!(*observed, 23);
    }
}
/// Models a bit-prefix routing table: each address byte is routed by the
/// most specific (longest) stored bit prefix it starts with.
#[test]
pub fn test_longest_prefix_routing_table() {
    #[derive(PartialEq, Eq, Debug)]
    enum MatchRule {
        Forward(u8),
        Delete,
    }

    /// Expands a byte into its eight bits, most significant bit first.
    fn convert_to_bits(val: u8) -> [bool; 8] {
        std::array::from_fn(|idx| ((val >> (7 - idx)) & 1) == 1)
    }

    let mut trie = Trie::<bool, MatchRule>::new();
    trie.insert([true, false, true], MatchRule::Forward(3));
    trie.insert([true, false, true, false], MatchRule::Forward(4));
    trie.insert(
        [true, false, true, false, false, false, true],
        MatchRule::Delete,
    );

    let expectations = [
        (0b1010_0000u8, MatchRule::Forward(4)),
        (0b1011_0000u8, MatchRule::Forward(3)),
        (0b1010_0010u8, MatchRule::Delete),
    ];
    for (address, expected) in expectations {
        assert_eq!(
            trie.longest_prefix(convert_to_bits(address)),
            Some(&expected)
        );
    }
}
/// `search_with_score_fn` returns the key path with the highest score: scoring by
/// length picks the longest key, and reversing that score picks the shortest.
#[test]
pub fn test_scoring_fn() {
    let mut trie = Trie::<char, ()>::new();
    for word in ["hello", "he", "hema"] {
        trie.insert(word.chars(), ());
    }

    let longest = trie.search_with_score_fn(|path| path.len());
    assert_eq!(longest, Some(vec![&'h', &'e', &'l', &'l', &'o']));

    let shortest = trie.search_with_score_fn(|path| Reverse(path.len()));
    assert_eq!(shortest, Some(vec![&'h', &'e']));
}
/// Using negated edit distance as the score turns `search_with_score_fn` into a
/// nearest-key lookup: an exact key matches itself, and a misspelling maps to the
/// closest stored key.
#[test]
pub fn test_scoring_fn_levinstein() {
    let mut trie = Trie::<char, ()>::new();
    for word in ["hello", "he", "hema", "racecar", "nissan"] {
        trie.insert(word.chars(), ());
    }

    let closest = |target| {
        trie.search_with_score_fn(|path| {
            let candidate: String = path.iter().copied().collect();
            Reverse(edit_distance::edit_distance(target, &candidate))
        })
        .map(|chars| chars.into_iter().copied().collect::<String>())
    };

    assert_eq!(closest("hello"), Some("hello".to_string()));
    assert_eq!(closest("niwan"), Some("nissan".to_string()));
}
}