use std::slice::Iter;
/// A prefix tree that maps string keys to lists of values of type `T`.
///
/// Keys are stored byte by byte: each byte of a key selects one branch on the
/// path from the root to the node that holds the values for that key.
pub struct Trie<T> {
/// Node reached by the empty key; every insertion and lookup starts here.
root: TrieNode<T>,
}
/// A single node of the trie.
struct TrieNode<T> {
/// Outgoing edges as `(byte, child)` pairs, looked up by linear search on the byte.
branches: Vec<(u8, TrieNode<T>)>,
/// Values stored for the key whose byte path ends at this node.
values: Vec<T>,
}
impl<T> TrieNode<T> {
    /// Number of bytes reserved by the `branches` and `values` vectors of this
    /// node and of every node below it (based on capacity, not length).
    fn allocated_size(&self) -> usize {
        let branch_bytes = self.branches.capacity() * std::mem::size_of::<(u8, TrieNode<T>)>();
        let value_bytes = self.values.capacity() * std::mem::size_of::<T>();
        let subtree_bytes: usize = self
            .branches
            .iter()
            .map(|(_, child)| child.allocated_size())
            .sum();
        branch_bytes + subtree_bytes + value_bytes
    }

    /// Number of nodes below this one (the node itself is not counted).
    fn num_children(&self) -> usize {
        self.branches
            .iter()
            .map(|(_, child)| 1 + child.num_children())
            .sum()
    }

    /// Number of values stored in this node and in every node below it.
    fn num_values(&self) -> usize {
        self.branches
            .iter()
            .fold(self.values.len(), |total, (_, child)| {
                total + child.num_values()
            })
    }
}
impl<T> Default for Trie<T> {
fn default() -> Self {
Trie::new()
}
}
impl<T> Trie<T> {
    /// Creates an empty trie.
    pub fn new() -> Self {
        Trie {
            root: TrieNode {
                branches: Vec::new(),
                values: Vec::new(),
            },
        }
    }

    /// Follows the bytes of `key` down from the root.
    ///
    /// Returns the node the path ends at, or `None` as soon as a byte has no
    /// matching branch. The empty key resolves to the root.
    fn find_node(&self, key: &str) -> Option<&TrieNode<T>> {
        let mut node = &self.root;
        for &byte in key.as_bytes() {
            node = match node.branches.iter().find(|branch| branch.0 == byte) {
                Some((_, child)) => child,
                None => return None,
            };
        }
        Some(node)
    }

    /// Walks (creating nodes as needed) to the node for `key` and pushes
    /// `value` onto it if `predicate(&value, node)` returns `true`.
    ///
    /// The path is created even when the predicate rejects the value.
    fn insert_checked<P>(&mut self, key: &str, value: T, predicate: P)
    where
        P: Fn(&T, &TrieNode<T>) -> bool,
    {
        let mut node = &mut self.root;
        for &byte in key.as_bytes() {
            let idx = match node.branches.iter().position(|branch| branch.0 == byte) {
                Some(existing) => existing,
                None => {
                    node.branches.push((
                        byte,
                        TrieNode {
                            // Kept at capacity 1 so `allocated_size` reports the
                            // same figures as before.
                            branches: Vec::with_capacity(1),
                            values: Vec::new(),
                        },
                    ));
                    node.branches.len() - 1
                }
            };
            // `idx` is always in bounds here, so plain indexing replaces the
            // former unchecked access without any risk of panicking.
            node = &mut node.branches[idx].1;
        }
        if predicate(&value, node) {
            node.values.push(value);
        }
    }

    /// Adds `value` under `key`. Values already stored for the key are kept,
    /// so a key may hold duplicates.
    pub fn insert(&mut self, key: &str, value: T) {
        self.insert_checked(key, value, |_, _| true)
    }

    /// Returns the values stored under exactly `query`.
    ///
    /// Returns `None` when the key path does not exist or when the node it
    /// ends at holds no values.
    pub fn query_values(&self, query: impl AsRef<str>) -> Option<&Vec<T>> {
        let node = self.find_node(query.as_ref())?;
        if node.values.is_empty() {
            None
        } else {
            Some(&node.values)
        }
    }

    /// Returns an iterator over the values stored under exactly `query`;
    /// the iterator is empty when `query_values` would return `None`.
    pub fn query(&'_ self, query: impl AsRef<str>) -> OptionalIter<'_, T> {
        match self.query_values(query) {
            Some(values) => OptionalIter::Found(values.iter()),
            None => OptionalIter::Empty,
        }
    }

    /// Returns an iterator over the values of every key that starts with
    /// `prefix`, beginning with the values stored under `prefix` itself and
    /// then descending depth-first through the branches in stored order.
    pub fn prefix_query(&'_ self, prefix: &str) -> PrefixIter<'_, T> {
        match self.find_node(prefix) {
            Some(node) => PrefixIter {
                current_iter: Some(node.values.iter()),
                stack: vec![(0, node)],
            },
            None => PrefixIter {
                current_iter: None,
                stack: vec![],
            },
        }
    }

    /// Bytes reserved by the vectors of all nodes (capacity-based).
    pub fn allocated_size(&self) -> usize {
        self.root.allocated_size()
    }

    /// Total number of values stored in the trie.
    pub fn num_entries(&self) -> usize {
        self.root.num_values()
    }

    /// Number of nodes in the trie, not counting the root.
    pub fn num_nodes(&self) -> usize {
        self.root.num_children()
    }

    /// Calls `callback(key, value)` for every stored value, visiting nodes
    /// depth-first with branches in stored order.
    pub fn scan<C>(&self, callback: &mut C)
    where
        C: FnMut(&str, &T),
    {
        let mut prefix = Vec::new();
        Self::scan_node(&mut prefix, &self.root, callback);
    }

    /// Recursive worker for `scan`; `prefix` holds the bytes on the path from
    /// the root to `node`.
    fn scan_node<C>(prefix: &mut Vec<u8>, node: &TrieNode<T>, callback: &mut C)
    where
        C: FnMut(&str, &T),
    {
        // Only validate the path as UTF-8 when there is something to report;
        // nodes in the middle of a multi-byte character never hold values.
        if !node.values.is_empty() {
            if let Ok(prefix_str) = std::str::from_utf8(prefix.as_slice()) {
                for value in node.values.iter() {
                    callback(prefix_str, value);
                }
            }
        }
        for (byte, child) in &node.branches {
            prefix.push(*byte);
            Self::scan_node(prefix, child, callback);
            let _ = prefix.pop();
        }
    }
}
impl<T: PartialEq> Trie<T> {
    /// Adds `value` under `key` unless an equal value is already stored for
    /// that key. The key's node path is created either way.
    pub fn insert_unique(&mut self, key: &str, value: T) {
        self.insert_checked(key, value, |candidate, target| {
            let already_present = target.values.iter().any(|existing| existing == candidate);
            !already_present
        });
    }
}
/// Iterator returned by `Trie::query`: either the values of a matching key or nothing.
pub enum OptionalIter<'a, T> {
/// The key was found; wraps the slice iterator over its values.
Found(Iter<'a, T>),
/// No values were found; iteration yields nothing.
Empty,
}
/// Yields the wrapped values for `Found` and nothing for `Empty`.
impl<'a, T> Iterator for OptionalIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Exhaustive on purpose: a new variant must be handled explicitly.
        match self {
            OptionalIter::Found(iter) => iter.next(),
            OptionalIter::Empty => None,
        }
    }

    /// Forwards the exact remaining length so callers such as `collect`
    /// can reserve the right amount of space up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            OptionalIter::Found(iter) => iter.size_hint(),
            OptionalIter::Empty => (0, Some(0)),
        }
    }
}
/// Iterator returned by `Trie::prefix_query`; walks a subtree depth-first and
/// yields the values of every node it visits.
pub struct PrefixIter<'a, T> {
/// Values of the node currently being drained, or `None` when there is nothing left to drain.
current_iter: Option<Iter<'a, T>>,
/// Path of nodes being traversed; each entry pairs a node with the index of
/// its next branch to descend into.
stack: Vec<(usize, &'a TrieNode<T>)>,
}
/// Depth-first traversal: a node's own values come first, then each branch in
/// stored order.
impl<'a, T> Iterator for PrefixIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Drain the values of the node visited most recently.
            let pending = match self.current_iter.as_mut() {
                Some(values) => values.next(),
                None => None,
            };
            if pending.is_some() {
                return pending;
            }
            self.current_iter = None;

            // Nothing left to drain: advance the traversal by one step.
            // An empty stack means the whole subtree has been visited.
            let (cursor, node) = self.stack.last_mut()?;
            match node.branches.get(*cursor) {
                Some((_, child)) => {
                    *cursor += 1;
                    if !child.values.is_empty() {
                        self.current_iter = Some(child.values.iter());
                    }
                    self.stack.push((0, child));
                }
                None => {
                    // All branches of this node are done; go back up.
                    self.stack.pop();
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::idb::trie::Trie;

    /// Exact-key lookups return the stored values in insertion order and
    /// report unknown keys as absent.
    #[test]
    fn insert_and_query_works() {
        let mut trie = Trie::new();
        trie.insert("abc", 42);
        trie.insert("abc", 0);
        trie.insert("test", 1);

        let mut iter = trie.query("abc");
        assert_eq!(iter.next(), Some(&42));
        assert_eq!(iter.next(), Some(&0));
        assert_eq!(iter.next(), None);
        assert_eq!(trie.query_values("abc").unwrap()[0], 42);
        assert_eq!(trie.query_values("abc").unwrap()[1], 0);

        assert_eq!(trie.query("test").next(), Some(&1));
        assert_eq!(trie.query_values("test").unwrap()[0], 1);

        assert!(trie.query("fail").next().is_none());
        assert!(trie.query_values("fail").is_none());

        // A path that exists only as a prefix of other keys holds no values.
        assert!(trie.query("ab").next().is_none());
        assert!(trie.query_values("ab").is_none());
    }

    /// Prefix lookups yield the values of every key below the prefix,
    /// depth-first in insertion order.
    #[test]
    fn prefix_search_works() {
        let mut trie = Trie::new();
        trie.insert("abc", 1);
        trie.insert("abc", 2);
        trie.insert("abcd", 3);
        trie.insert("abcde", 4);
        trie.insert("abcdf", 5);
        trie.insert("abd", 6);

        let results: Vec<&i32> = trie.prefix_query("").collect();
        assert_eq!(results, vec![&1, &2, &3, &4, &5, &6]);
        let results: Vec<&i32> = trie.prefix_query("a").collect();
        assert_eq!(results, vec![&1, &2, &3, &4, &5, &6]);
        let results: Vec<&i32> = trie.prefix_query("ab").collect();
        assert_eq!(results, vec![&1, &2, &3, &4, &5, &6]);
        let results: Vec<&i32> = trie.prefix_query("abc").collect();
        assert_eq!(results, vec![&1, &2, &3, &4, &5]);
        let results: Vec<&i32> = trie.prefix_query("abcd").collect();
        assert_eq!(results, vec![&3, &4, &5]);
        let results: Vec<&i32> = trie.prefix_query("abcdf").collect();
        assert_eq!(results, vec![&5]);
        let results: Vec<&i32> = trie.prefix_query("unknown").collect();
        assert!(results.is_empty());
    }

    /// `scan` reports every value together with its full key.
    #[test]
    fn scan_works() {
        let mut trie = Trie::new();
        trie.insert("A", 1);
        trie.insert("B", 2);
        trie.insert("AB", 4);
        trie.insert("ABC", 9);

        let mut buffer = String::new();
        trie.scan(&mut |prefix, item| buffer.push_str(&format!("{}{}", prefix, item)));
        assert_eq!(buffer, "A1AB4ABC9B2");
    }

    /// `insert_unique` skips values that are already stored under the key.
    #[test]
    fn insert_unique_skips_duplicates() {
        let mut trie = Trie::new();
        trie.insert_unique("abc", 1);
        trie.insert_unique("abc", 1);
        trie.insert_unique("abc", 2);
        trie.insert_unique("abd", 1);

        let results: Vec<&i32> = trie.query("abc").collect();
        assert_eq!(results, vec![&1, &2]);
        let results: Vec<&i32> = trie.query("abd").collect();
        assert_eq!(results, vec![&1]);
    }

    /// The counting helpers reflect the number of values and nodes.
    #[test]
    fn statistics_work() {
        let mut trie = Trie::new();
        assert_eq!(trie.num_entries(), 0);
        assert_eq!(trie.num_nodes(), 0);
        assert_eq!(trie.allocated_size(), 0);

        trie.insert("abc", 1);
        trie.insert("abd", 2);
        trie.insert("abc", 3);

        assert_eq!(trie.num_entries(), 3);
        // Nodes: a, b, c, d (the root is not counted).
        assert_eq!(trie.num_nodes(), 4);
        assert!(trie.allocated_size() > 0);
    }
}