#![doc = include_str!("../README.md")]
#![warn(missing_docs)]
#![warn(missing_debug_implementations)]
use std::fmt;
use std::iter::FusedIterator;
/// Reinterprets `bytes` as a `&str` without validating it in release builds.
///
/// Debug builds still check the input with `debug_assert!`, so a broken
/// invariant panics there instead of producing an invalid `&str`.
///
/// # Safety
///
/// `bytes` must be valid UTF-8.
#[inline]
unsafe fn utf8_str_unchecked(bytes: &[u8]) -> &str {
    debug_assert!(std::str::from_utf8(bytes).is_ok());
    // SAFETY: the caller guarantees `bytes` is valid UTF-8.
    unsafe { std::str::from_utf8_unchecked(bytes) }
}
/// Converts `bytes` into a `String` without validating it in release builds.
///
/// Debug builds still check the input with `debug_assert!`.
///
/// # Safety
///
/// `bytes` must be valid UTF-8.
#[inline]
unsafe fn utf8_string_unchecked(bytes: Vec<u8>) -> String {
    debug_assert!(std::str::from_utf8(&bytes).is_ok());
    // SAFETY: the caller guarantees `bytes` is valid UTF-8.
    unsafe { String::from_utf8_unchecked(bytes) }
}
/// Returns the byte length of the UTF-8 sequence whose first byte is `b`.
///
/// Only meaningful for lead bytes: every byte below `0xC0` (ASCII, or a stray
/// continuation byte) is reported as length 1.
#[inline]
fn utf8_char_len(b: u8) -> usize {
    match b {
        0x00..=0xBF => 1,
        0xC0..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xFF => 4,
    }
}
/// Mutable radix-trie builder mapping command strings to values of type `T`.
///
/// Insert, overwrite and remove keys freely, then call
/// [`CommandTrieBuilder::build`] to freeze the contents into a compact,
/// read-only [`CommandTrie`].
#[derive(Clone)]
pub struct CommandTrieBuilder<T> {
    /// Root node; its label is always empty.
    root: BuilderNode<T>,
    /// Number of keys currently stored.
    len: usize,
}
/// One node of the mutable builder trie.
#[derive(Clone)]
struct BuilderNode<T> {
    /// Bytes on the edge leading into this node (empty only for the root).
    /// Edges are split only on `char` boundaries, so labels are valid UTF-8.
    label: Box<[u8]>,
    /// Value for the key that ends exactly at this node, if any.
    value: Option<T>,
    /// Children kept sorted by the first `char` of their label; no two
    /// children share a first `char`.
    children: Vec<BuilderNode<T>>,
}
impl<T> Default for CommandTrieBuilder<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> CommandTrieBuilder<T> {
    /// Creates an empty builder.
    #[must_use]
    pub fn new() -> Self {
        Self {
            root: BuilderNode {
                label: Box::from(&[][..]),
                value: None,
                children: Vec::new(),
            },
            len: 0,
        }
    }

    /// Returns the number of keys currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the builder holds no keys.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Removes every key, leaving the builder as if freshly created.
    pub fn clear(&mut self) {
        *self = Self::new();
    }

    /// Inserts `key` with `value`.
    ///
    /// Returns the previous value if `key` was already present (the length is
    /// then unchanged), or `None` for a new key. The empty string is a valid key.
    pub fn insert(&mut self, key: &str, value: T) -> Option<T> {
        let prev = self.root.insert(key.as_bytes(), value);
        if prev.is_none() {
            self.len += 1;
        }
        prev
    }

    /// Removes `key`, returning its value if it was present.
    ///
    /// Non-root nodes left without a value are pruned, or merged into their
    /// only child, so the trie stays path-compressed.
    pub fn remove(&mut self, key: &str) -> Option<T> {
        let v = self.root.remove(key.as_bytes())?;
        self.len -= 1;
        Some(v)
    }

    /// Returns the value stored for exactly `key`, if any.
    #[must_use]
    pub fn get(&self, key: &str) -> Option<&T> {
        let mut node = &self.root;
        let mut rem = key.as_bytes();
        loop {
            if rem.is_empty() {
                return node.value.as_ref();
            }
            // `find_child` only matches the first char; the rest of the label
            // must also be a prefix of what is left of the key.
            let child = node.find_child(rem)?;
            if !rem.starts_with(&child.label) {
                return None;
            }
            rem = &rem[child.label.len()..];
            node = child;
        }
    }

    /// Returns `true` if exactly `key` is stored.
    #[must_use]
    pub fn contains(&self, key: &str) -> bool {
        self.get(key).is_some()
    }

    /// Freezes the builder into a compact, read-only [`CommandTrie`].
    ///
    /// # Panics
    ///
    /// Panics if the trie is too large for the frozen layout's `u16` indices
    /// (node ids, label offsets/lengths, child offsets/counts).
    #[must_use]
    pub fn build(self) -> CommandTrie<T> {
        let len = self.len;
        let mut nodes: Vec<FrozenNode<T>> = Vec::new();
        let mut labels: Vec<u8> = Vec::new();
        let mut children: Vec<NodeId> = Vec::new();
        let mut child_first_bytes: Vec<u8> = Vec::new();
        // The root is visited first, so it receives id 0 (`ROOT`).
        build_visit(
            self.root,
            &mut nodes,
            &mut labels,
            &mut children,
            &mut child_first_bytes,
        );
        CommandTrie {
            nodes: nodes.into_boxed_slice(),
            labels: labels.into_boxed_slice(),
            children: children.into_boxed_slice(),
            child_first_bytes: child_first_bytes.into_boxed_slice(),
            len,
        }
    }
}
/// Appends `node` and its whole subtree to the frozen arrays in depth-first
/// pre-order, returning the id assigned to `node`.
///
/// `node`'s id is taken before any descendant is visited, so the first call
/// (on the builder root) yields id 0, i.e. `ROOT`. Each node's children get
/// one contiguous run of slots in `children` / `child_first_bytes`.
/// Recursion depth equals the depth of the subtree.
///
/// Panics (via `u16_or_panic`) if any id, offset, length or count does not
/// fit in `u16`.
fn build_visit<T>(
    node: BuilderNode<T>,
    nodes: &mut Vec<FrozenNode<T>>,
    labels: &mut Vec<u8>,
    children: &mut Vec<NodeId>,
    child_first_bytes: &mut Vec<u8>,
) -> NodeId {
    let id = u16_or_panic(nodes.len());
    let label_start = u16_or_panic(labels.len());
    let label_len = u16_or_panic(node.label.len());
    labels.extend_from_slice(&node.label);
    // The child range is not known yet; it is patched in at the end.
    nodes.push(FrozenNode {
        label_start,
        label_len,
        children_start: 0,
        children_len: 0,
        value: node.value,
    });
    let n_children = u16_or_panic(node.children.len());
    let children_start = u16_or_panic(children.len());
    // Reserve all of this node's child slots before recursing, so they stay
    // contiguous even though descendants append their own slots afterwards.
    for _ in 0..n_children {
        children.push(0);
        child_first_bytes.push(0);
    }
    // Builder children are already sorted, so slot order keeps that order.
    for (i, child) in node.children.into_iter().enumerate() {
        // Non-root labels are never empty.
        let first = child.label[0];
        let slot = children_start as usize + i;
        let child_id = build_visit(child, nodes, labels, children, child_first_bytes);
        children[slot] = child_id;
        child_first_bytes[slot] = first;
    }
    let n = &mut nodes[id as usize];
    n.children_start = children_start;
    n.children_len = n_children;
    id
}
/// Narrows `n` to `u16`, panicking if it does not fit.
///
/// The frozen trie stores node ids, label offsets/lengths and child
/// offsets/counts as `u16`; `build_visit` routes each of them through here.
#[inline]
fn u16_or_panic(n: usize) -> u16 {
    let narrowed = u16::try_from(n);
    narrowed.expect("command-trie size exceeds u16::MAX (see FrozenNode docs)")
}
impl<T> BuilderNode<T> {
    /// Returns the child whose label starts with the same `char` as `rem`.
    ///
    /// Only the first character is compared; callers must still check that
    /// `rem` continues with the rest of the label. `rem` must be non-empty.
    fn find_child(&self, rem: &[u8]) -> Option<&BuilderNode<T>> {
        let idx = self.child_index(rem).ok()?;
        Some(&self.children[idx])
    }

    /// Binary-searches `children` by the first character of `rem`.
    ///
    /// Returns `Ok(i)` for the child sharing `rem`'s first `char`, or
    /// `Err(i)` with the insertion point that keeps `children` sorted.
    /// Panics if `rem` is empty.
    fn child_index(&self, rem: &[u8]) -> Result<usize, usize> {
        let first = rem[0];
        // ASCII fast path: a one-byte char is identified by its first byte.
        if first < 0x80 {
            return self.children.binary_search_by_key(&first, |c| c.label[0]);
        }
        // Multi-byte char: siblings may share a lead byte (e.g. "é"/"ê"), so
        // compare each label's whole encoded first char against `rem`'s.
        let needle_len = utf8_char_len(first).min(rem.len());
        let needle = &rem[..needle_len];
        self.children.binary_search_by(|c| {
            let cn = utf8_char_len(c.label[0]).min(c.label.len());
            c.label[..cn].cmp(needle)
        })
    }

    /// Inserts the key remainder `rem` (relative to this node) with `value`.
    ///
    /// Returns the value it replaced, if the key already existed. When `rem`
    /// diverges from, or ends inside, an existing child's label, that edge is
    /// split at the longest common prefix (on a `char` boundary).
    fn insert(&mut self, rem: &[u8], value: T) -> Option<T> {
        if rem.is_empty() {
            return self.value.replace(value);
        }
        match self.child_index(rem) {
            // No child shares the first char: add a new leaf in sorted position.
            Err(at) => {
                self.children.insert(
                    at,
                    BuilderNode {
                        label: Box::from(rem),
                        value: Some(value),
                        children: Vec::new(),
                    },
                );
                None
            }
            Ok(idx) => {
                let child = &mut self.children[idx];
                let common = lcp(&child.label, rem);
                // Whole label matches: continue below this child.
                if common == child.label.len() {
                    return child.insert(&rem[common..], value);
                }
                // Split the edge: `child` keeps the shared prefix, and its old
                // tail (with its value and children) becomes `existing`.
                let old_label = std::mem::replace(&mut child.label, Box::from(&rem[..common]));
                let old_value = child.value.take();
                let old_children = std::mem::take(&mut child.children);
                let existing = BuilderNode {
                    label: Box::from(&old_label[common..]),
                    value: old_value,
                    children: old_children,
                };
                if common == rem.len() {
                    // The new key ends exactly at the split point.
                    child.value = Some(value);
                    child.children = vec![existing];
                } else {
                    // Both tails continue: add the new leaf and keep the two
                    // children sorted by label bytes.
                    let new_node = BuilderNode {
                        label: Box::from(&rem[common..]),
                        value: Some(value),
                        children: Vec::new(),
                    };
                    child.children = if existing.label[..].cmp(&new_node.label[..])
                        == std::cmp::Ordering::Less
                    {
                        vec![existing, new_node]
                    } else {
                        vec![new_node, existing]
                    };
                }
                None
            }
        }
    }

    /// Removes the key remainder `rem` (relative to this node), returning its
    /// value if it was present.
    ///
    /// After a successful removal, a child left without a value is dropped
    /// when it has no children, or merged into its single child otherwise.
    fn remove(&mut self, rem: &[u8]) -> Option<T> {
        if rem.is_empty() {
            return self.value.take();
        }
        let idx = self.child_index(rem).ok()?;
        if !rem.starts_with(&self.children[idx].label) {
            return None;
        }
        let label_len = self.children[idx].label.len();
        let removed = self.children[idx].remove(&rem[label_len..])?;
        let child = &self.children[idx];
        if child.value.is_none() {
            if child.children.is_empty() {
                self.children.remove(idx);
            } else if child.children.len() == 1 {
                // Path compression: fold `child`'s label into its only child,
                // which takes its place (same first char, so order is kept).
                let mut removed_child = self.children.remove(idx);
                let mut grandchild = removed_child.children.pop().unwrap();
                let mut merged =
                    Vec::with_capacity(removed_child.label.len() + grandchild.label.len());
                merged.extend_from_slice(&removed_child.label);
                merged.extend_from_slice(&grandchild.label);
                grandchild.label = merged.into_boxed_slice();
                self.children.insert(idx, grandchild);
            }
        }
        Some(removed)
    }
}
/// Returns the length of the longest common prefix of `a` and `b`, rounded
/// down so it ends on a UTF-8 character boundary of `a`.
///
/// Splitting an edge at this length never cuts a multi-byte character.
fn lcp(a: &[u8], b: &[u8]) -> usize {
    let mut len = a
        .iter()
        .zip(b)
        .position(|(x, y)| x != y)
        .unwrap_or_else(|| a.len().min(b.len()));
    // Step back over continuation bytes (0b10xx_xxxx). `a.get(len)` is `None`
    // at the end of `a`, which is already a boundary.
    while len > 0 && matches!(a.get(len), Some(&byte) if byte & 0xC0 == 0x80) {
        len -= 1;
    }
    len
}
/// Collects `(key, value)` pairs into a builder; a repeated key keeps the
/// last value seen.
impl<K: AsRef<str>, T> FromIterator<(K, T)> for CommandTrieBuilder<T> {
    fn from_iter<I: IntoIterator<Item = (K, T)>>(iter: I) -> Self {
        let mut builder = Self::default();
        builder.extend(iter);
        builder
    }
}
/// Inserts every `(key, value)` pair, overwriting values of existing keys.
impl<K: AsRef<str>, T> Extend<(K, T)> for CommandTrieBuilder<T> {
    fn extend<I: IntoIterator<Item = (K, T)>>(&mut self, iter: I) {
        iter.into_iter().for_each(|(key, value)| {
            // Any replaced value is dropped, as with a bare `insert` call.
            let _ = self.insert(key.as_ref(), value);
        });
    }
}
/// Shows the key count followed by the full node tree.
impl<T: fmt::Debug> fmt::Debug for CommandTrieBuilder<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("CommandTrieBuilder");
        out.field("len", &self.len);
        out.field("root", &self.root);
        out.finish()
    }
}
impl<T: fmt::Debug> fmt::Debug for BuilderNode<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BuilderNode")
.field("label", &unsafe { utf8_str_unchecked(&self.label) })
.field("value", &self.value)
.field("children", &self.children)
.finish()
}
}
/// Index of a node in a frozen [`CommandTrie`]'s node array.
///
/// Being `u16`, it caps the number of nodes a frozen trie can hold
/// (`build_visit` panics via `u16_or_panic` past that).
type NodeId = u16;
/// Id of the root node: `build_visit` assigns ids in visit order and visits
/// the root first.
const ROOT: NodeId = 0;
/// Immutable, compact radix trie produced by [`CommandTrieBuilder::build`].
///
/// All nodes live in one flat array and refer to their edge labels and child
/// lists by `u16` offsets into shared buffers, so the whole trie occupies at
/// most four heap allocations however many keys it holds.
///
/// Supports exact lookup, longest-prefix matching and prefix completion.
#[derive(Clone)]
pub struct CommandTrie<T> {
    /// Nodes in depth-first pre-order; the root is id 0.
    nodes: Box<[FrozenNode<T>]>,
    /// Concatenated edge labels, sliced by each node's `label_start`/`label_len`.
    labels: Box<[u8]>,
    /// Child-id lists of all nodes, sliced by `children_start`/`children_len`.
    children: Box<[NodeId]>,
    /// Parallel to `children`: first label byte of each child, so a child
    /// lookup can binary-search plain bytes.
    child_first_bytes: Box<[u8]>,
    /// Number of stored keys.
    len: usize,
}
/// One node of a frozen [`CommandTrie`].
///
/// Size limits: every node id, label offset, label length, child offset and
/// child count must fit in a `u16`. `CommandTrieBuilder::build` panics (via
/// `u16_or_panic`) when a trie exceeds any of them.
#[derive(Clone)]
struct FrozenNode<T> {
    /// Offset of this node's edge label in `CommandTrie::labels`.
    label_start: u16,
    /// Length in bytes of this node's edge label.
    label_len: u16,
    /// Offset of this node's child list in `CommandTrie::children`.
    children_start: u16,
    /// Number of children.
    children_len: u16,
    /// Value for the key ending exactly at this node, if any.
    value: Option<T>,
}
impl<T> CommandTrie<T> {
    /// Returns the number of keys stored in the trie.
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the trie holds no keys.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the edge label leading into node `id` (empty for the root).
    #[inline]
    fn label_of(&self, id: NodeId) -> &[u8] {
        debug_assert!((id as usize) < self.nodes.len());
        // SAFETY: every id handled here is `ROOT` (a built trie always has a
        // root node) or was read from `self.children`, which only stores ids
        // of pushed nodes, so it indexes `nodes`. `build_visit` appended
        // exactly `label_start..label_start + label_len` to `labels`.
        unsafe {
            let n = self.nodes.get_unchecked(id as usize);
            let start = n.label_start as usize;
            let end = start + n.label_len as usize;
            self.labels.get_unchecked(start..end)
        }
    }

    /// Returns the ids of node `id`'s children, sorted by first label char.
    #[inline]
    fn children_of(&self, id: NodeId) -> &[NodeId] {
        debug_assert!((id as usize) < self.nodes.len());
        // SAFETY: `id` is in bounds (see `label_of`); `build_visit` reserved
        // `children_start..children_start + children_len` in `children` for
        // this node's child ids.
        unsafe {
            let n = self.nodes.get_unchecked(id as usize);
            let start = n.children_start as usize;
            let end = start + n.children_len as usize;
            self.children.get_unchecked(start..end)
        }
    }

    /// Returns the value stored at node `id`, if any.
    #[inline]
    fn value_of(&self, id: NodeId) -> Option<&T> {
        debug_assert!((id as usize) < self.nodes.len());
        // SAFETY: `id` is in bounds (see `label_of`).
        unsafe { self.nodes.get_unchecked(id as usize).value.as_ref() }
    }

    /// Finds the child of `parent` whose label starts with the same `char`
    /// as `rem`.
    ///
    /// `rem` must be non-empty and start on a UTF-8 character boundary. Only
    /// the first character is compared; callers verify the rest of the label.
    #[inline]
    fn find_child(&self, parent: NodeId, rem: &[u8]) -> Option<NodeId> {
        debug_assert!(!rem.is_empty());
        // SAFETY: `parent` is a valid id (see `label_of`), so its child range
        // lies within both `children` and the equally long
        // `child_first_bytes`; every caller checks that `rem` is non-empty.
        unsafe {
            let n = self.nodes.get_unchecked(parent as usize);
            let start = n.children_start as usize;
            let end = start + n.children_len as usize;
            let first = *rem.get_unchecked(0);
            let slab = self.child_first_bytes.get_unchecked(start..end);
            let idx = slab.binary_search(&first).ok()?;
            // ASCII first bytes are unique among siblings.
            if first < 0x80 {
                return Some(*self.children.get_unchecked(start + idx));
            }
            self.find_child_multibyte(start, slab, idx, first, rem)
        }
    }

    /// Slow path of `find_child` for a multi-byte first character.
    ///
    /// Several siblings may share the lead byte `first`; `idx` is any slot in
    /// `slab` holding it. Scans that whole run of equal lead bytes and returns
    /// the child whose first encoded character equals `rem`'s.
    #[cold]
    #[inline(never)]
    fn find_child_multibyte(
        &self,
        start: usize,
        slab: &[u8],
        idx: usize,
        first: u8,
        rem: &[u8],
    ) -> Option<NodeId> {
        let clen = utf8_char_len(first);
        debug_assert!(rem.len() >= clen);
        // SAFETY: `rem` starts with a complete UTF-8 character whose lead byte
        // is `first`, so it has at least `clen` bytes. `lo` and `i` stay within
        // `slab`, whose slots map to `children[start..start + slab.len()]`.
        // The label slice is taken only after checking `lbl.len() >= clen`.
        unsafe {
            let needle = rem.get_unchecked(..clen);
            let mut lo = idx;
            while lo > 0 && *slab.get_unchecked(lo - 1) == first {
                lo -= 1;
            }
            let mut i = lo;
            while i < slab.len() && *slab.get_unchecked(i) == first {
                let child = *self.children.get_unchecked(start + i);
                let lbl = self.label_of(child);
                if lbl.len() >= clen && lbl.get_unchecked(..clen) == needle {
                    return Some(child);
                }
                i += 1;
            }
            None
        }
    }

    /// Returns the value stored for exactly `key`, if any.
    #[must_use]
    pub fn get(&self, key: &str) -> Option<&T> {
        let mut node = ROOT;
        let mut rem = key.as_bytes();
        loop {
            if rem.is_empty() {
                return self.value_of(node);
            }
            let child = self.find_child(node, rem)?;
            let lbl = self.label_of(child);
            if !rem.starts_with(lbl) {
                return None;
            }
            rem = &rem[lbl.len()..];
            node = child;
        }
    }

    /// Returns `true` if exactly `key` is stored.
    #[must_use]
    pub fn contains(&self, key: &str) -> bool {
        self.get(key).is_some()
    }

    /// Finds the longest stored key that is a prefix of `input`.
    ///
    /// Returns that prefix (as a slice of `input`) with its value, or `None`
    /// if no stored key prefixes `input`. No word boundary is required: with
    /// `git` stored, `"github"` matches `"git"`.
    #[must_use]
    pub fn longest_prefix_match<'a>(&self, input: &'a str) -> Option<(&'a str, &T)> {
        let bytes = input.as_bytes();
        let mut node = ROOT;
        let mut consumed = 0usize;
        let mut best: Option<(usize, &T)> = None;
        loop {
            if let Some(v) = self.value_of(node) {
                best = Some((consumed, v));
            }
            let rem = &bytes[consumed..];
            if rem.is_empty() {
                break;
            }
            let Some(child) = self.find_child(node, rem) else {
                break;
            };
            let lbl = self.label_of(child);
            if !rem.starts_with(lbl) {
                break;
            }
            consumed += lbl.len();
            node = child;
        }
        // `consumed` only advances by whole labels, so it is a char boundary.
        best.map(|(n, v)| (&input[..n], v))
    }

    /// Returns `true` if at least one stored key starts with `prefix`.
    #[must_use]
    pub fn contains_prefix(&self, prefix: &str) -> bool {
        match self.descend_to_node(prefix.as_bytes()) {
            Some(node) => self.value_of(node).is_some() || !self.children_of(node).is_empty(),
            None => false,
        }
    }

    /// Returns the node whose subtree holds exactly the keys starting with
    /// `rem`, or `None` if `rem` leaves the trie.
    ///
    /// When `rem` ends inside an edge, the node below that edge is returned.
    fn descend_to_node(&self, mut rem: &[u8]) -> Option<NodeId> {
        let mut node = ROOT;
        while !rem.is_empty() {
            let child = self.find_child(node, rem)?;
            let lbl = self.label_of(child);
            if rem.len() >= lbl.len() {
                if !rem.starts_with(lbl) {
                    return None;
                }
                rem = &rem[lbl.len()..];
                node = child;
            } else {
                if !lbl.starts_with(rem) {
                    return None;
                }
                node = child;
                break;
            }
        }
        Some(node)
    }

    /// Like `descend_to_node`, but also returns the full key bytes of the
    /// node reached (which start with `rem`).
    fn descend_to_prefix(&self, mut rem: &[u8]) -> Option<(NodeId, Vec<u8>)> {
        let mut node = ROOT;
        let mut path: Vec<u8> = Vec::with_capacity(rem.len());
        while !rem.is_empty() {
            let child = self.find_child(node, rem)?;
            let lbl = self.label_of(child);
            if rem.len() >= lbl.len() {
                if !rem.starts_with(lbl) {
                    return None;
                }
                path.extend_from_slice(lbl);
                rem = &rem[lbl.len()..];
                node = child;
            } else {
                if !lbl.starts_with(rem) {
                    return None;
                }
                path.extend_from_slice(lbl);
                node = child;
                break;
            }
        }
        Some((node, path))
    }

    /// Iterates over all `(key, value)` pairs in lexicographic key order,
    /// allocating a `String` per key.
    #[must_use]
    pub fn iter(&self) -> Iter<'_, T> {
        Iter::new(self, ROOT, Vec::new())
    }

    /// Calls `f` for every `(key, value)` pair in the same order as
    /// [`iter`](Self::iter), reusing one key buffer instead of allocating.
    pub fn for_each(&self, mut f: impl FnMut(&str, &T)) {
        let mut buf = Vec::new();
        for_each_descendants(self, ROOT, &mut buf, &mut f);
    }

    /// Returns a view of every key starting with `prefix`, or `None` if none does.
    ///
    /// The view's common prefix extends past `prefix` as far as all matching
    /// keys agree: with `command` and `commit` stored, `subtrie("c")` has
    /// common prefix `"comm"`.
    #[must_use]
    pub fn subtrie<'a>(&'a self, prefix: &str) -> Option<SubTrie<'a, T>> {
        let (mut node, mut path) = self.descend_to_prefix(prefix.as_bytes())?;
        // Only the root of an empty trie can have neither value nor children.
        if self.value_of(node).is_none() && self.children_of(node).is_empty() {
            return None;
        }
        // Follow the chain while every match goes through a single child.
        loop {
            let kids = self.children_of(node);
            if self.value_of(node).is_none() && kids.len() == 1 {
                let child = kids[0];
                path.extend_from_slice(self.label_of(child));
                node = child;
            } else {
                break;
            }
        }
        Some(SubTrie {
            trie: self,
            node,
            query_len: prefix.len(),
            common_prefix: path,
        })
    }

    /// Collects every `(key, value)` pair whose key starts with `prefix`, in
    /// lexicographic key order.
    #[must_use]
    pub fn completions<'a>(&'a self, prefix: &str) -> Vec<(String, &'a T)> {
        match self.subtrie(prefix) {
            Some(sub) => sub.into_iter().collect(),
            None => Vec::new(),
        }
    }

    /// Counts the keys starting with `prefix` without building key strings.
    #[must_use]
    pub fn count_completions(&self, prefix: &str) -> usize {
        match self.descend_to_node(prefix.as_bytes()) {
            Some(node) => count_values(self, node),
            None => 0,
        }
    }

    /// Returns the longest string that every key starting with `prefix`
    /// begins with (always an extension of `prefix`), or `None` if no key
    /// starts with `prefix`.
    ///
    /// This is the owned form of [`SubTrie::common_prefix`].
    #[must_use]
    pub fn completion_prefix(&self, prefix: &str) -> Option<String> {
        let sub = self.subtrie(prefix)?;
        // SAFETY: `common_prefix` is a concatenation of whole edge labels,
        // and labels are only split on `char` boundaries.
        Some(unsafe { utf8_string_unchecked(sub.common_prefix) })
    }

    /// Calls `f` for every key starting with `prefix`, in lexicographic key
    /// order, reusing one key buffer.
    pub fn for_each_completion(&self, prefix: &str, mut f: impl FnMut(&str, &T)) {
        if let Some(sub) = self.subtrie(prefix) {
            sub.for_each(&mut f);
        }
    }
}
/// Summarises the trie by its sizes; stored values are not printed.
impl<T: fmt::Debug> fmt::Debug for CommandTrie<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let node_count = self.nodes.len();
        let label_bytes = self.labels.len();
        let edge_count = self.children.len();
        f.debug_struct("CommandTrie")
            .field("len", &self.len)
            .field("nodes", &node_count)
            .field("labels_bytes", &label_bytes)
            .field("children_edges", &edge_count)
            .finish_non_exhaustive()
    }
}
impl<T> Default for CommandTrie<T> {
fn default() -> Self {
CommandTrieBuilder::new().build()
}
}
impl<'a, T> IntoIterator for &'a CommandTrie<T> {
type Item = (String, &'a T);
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Borrowed view of every key in a [`CommandTrie`] that starts with a query
/// prefix, returned by [`CommandTrie::subtrie`].
///
/// `CommandTrie::subtrie` only creates a view when at least one key matches.
#[derive(Clone)]
pub struct SubTrie<'a, T> {
    /// Trie being viewed.
    trie: &'a CommandTrie<T>,
    /// Deepest node shared by all matching keys.
    node: NodeId,
    /// Byte length of the query prefix.
    query_len: usize,
    /// Full key bytes of `node`; starts with the query prefix.
    common_prefix: Vec<u8>,
}
impl<'a, T> SubTrie<'a, T> {
    /// Returns the longest prefix shared by every key in this view.
    ///
    /// It always starts with the query and may extend past it.
    #[must_use]
    pub fn common_prefix(&self) -> &str {
        // SAFETY: `common_prefix` is a concatenation of whole edge labels,
        // and labels are only split on `char` boundaries.
        unsafe { utf8_str_unchecked(&self.common_prefix) }
    }

    /// Returns the part of [`common_prefix`](Self::common_prefix) after the
    /// query: text that can be appended to the query without excluding any
    /// matching key. Empty if the query already is the common prefix.
    #[must_use]
    pub fn extension(&self) -> &str {
        // SAFETY: `common_prefix` starts with the query's bytes and
        // `query_len` is the byte length of that `&str`, so the split point
        // is a `char` boundary of valid UTF-8.
        unsafe { utf8_str_unchecked(&self.common_prefix[self.query_len..]) }
    }

    /// Returns `true` if exactly one key matches the query.
    #[must_use]
    pub fn is_unique(&self) -> bool {
        self.trie.value_of(self.node).is_some() && self.trie.children_of(self.node).is_empty()
    }

    /// Returns the value for the key equal to
    /// [`common_prefix`](Self::common_prefix), if that string is itself a key.
    #[must_use]
    pub fn value(&self) -> Option<&'a T> {
        self.trie.value_of(self.node)
    }

    /// Returns the only matching value when [`is_unique`](Self::is_unique)
    /// holds, otherwise `None`.
    #[must_use]
    pub fn unique_value(&self) -> Option<&'a T> {
        if self.is_unique() {
            self.value()
        } else {
            None
        }
    }

    /// Returns the number of matching keys (always at least 1). Walks the
    /// whole subtree on each call.
    #[must_use]
    pub fn len(&self) -> usize {
        count_values(self.trie, self.node)
    }

    /// Always `false`: a view is only created when some key matches.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        false
    }

    /// Iterates over the matching `(key, value)` pairs in lexicographic key
    /// order, allocating a `String` per key.
    #[must_use]
    pub fn iter(&self) -> Iter<'a, T> {
        Iter::new(self.trie, self.node, self.common_prefix.clone())
    }

    /// Calls `f` for each matching pair in the same order as
    /// [`iter`](Self::iter), reusing one key buffer.
    pub fn for_each(&self, mut f: impl FnMut(&str, &T)) {
        let mut buf = self.common_prefix.clone();
        // The start node's own value comes before its descendants.
        if let Some(v) = self.trie.value_of(self.node) {
            // SAFETY: `buf` is a copy of `common_prefix` (see `common_prefix`).
            f(unsafe { utf8_str_unchecked(&buf) }, v);
        }
        for &child in self.trie.children_of(self.node) {
            for_each_descendants(self.trie, child, &mut buf, &mut f);
        }
    }
}
impl<'a, T> IntoIterator for &SubTrie<'a, T> {
type Item = (String, &'a T);
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, T> IntoIterator for SubTrie<'a, T> {
type Item = (String, &'a T);
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
Iter::new(self.trie, self.node, self.common_prefix)
}
}
/// Shows the common prefix and whether the match is unique; no `T: Debug` needed.
impl<T> fmt::Debug for SubTrie<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let prefix = self.common_prefix();
        let unique = self.is_unique();
        f.debug_struct("SubTrie")
            .field("common_prefix", &prefix)
            .field("is_unique", &unique)
            .finish()
    }
}
/// Iterator over `(key, value)` pairs of a [`CommandTrie`] or [`SubTrie`], in
/// lexicographic key order.
///
/// Each item allocates a fresh `String` for its key; the `for_each` methods
/// avoid that by lending `&str` keys from a reused buffer.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Iter<'a, T> {
    /// Trie being traversed.
    trie: &'a CommandTrie<T>,
    /// Pending traversal steps; the top of the stack is processed next.
    stack: Vec<Frame>,
    /// Key bytes of the node most recently entered.
    path: Vec<u8>,
    /// Start node, while its own value has not been yielded yet.
    pending_root: Option<NodeId>,
    /// Number of items still to be yielded.
    remaining: usize,
}
/// One step of `Iter`'s explicit depth-first traversal.
enum Frame {
    /// Visit a node: append its label to the path, schedule its children,
    /// and yield its value if it has one.
    Enter(NodeId),
    /// Leave a node: drop this many label bytes from the end of the path.
    Exit(u16),
}
impl<'a, T> Iter<'a, T> {
fn new(trie: &'a CommandTrie<T>, root: NodeId, initial_path: Vec<u8>) -> Self {
let mut stack = Vec::new();
let kids = trie.children_of(root);
for &child in kids.iter().rev() {
stack.push(Frame::Enter(child));
}
let pending_root = if trie.value_of(root).is_some() {
Some(root)
} else {
None
};
let remaining = count_values(trie, root);
Self {
trie,
stack,
path: initial_path,
pending_root,
remaining,
}
}
}
impl<'a, T> Iterator for Iter<'a, T> {
    type Item = (String, &'a T);

    /// Yields the next `(key, value)` pair in lexicographic key order.
    fn next(&mut self) -> Option<Self::Item> {
        // The start node's own value comes before all of its descendants.
        if let Some(id) = self.pending_root.take() {
            let v = self.trie.value_of(id).expect("pending_root has a value");
            self.remaining -= 1;
            // SAFETY: `path` is a concatenation of whole labels (valid UTF-8).
            return Some((unsafe { utf8_string_unchecked(self.path.clone()) }, v));
        }
        while let Some(frame) = self.stack.pop() {
            match frame {
                // Leaving a subtree: remove its label bytes from the path.
                Frame::Exit(n) => {
                    let new_len = self.path.len() - n as usize;
                    self.path.truncate(new_len);
                }
                Frame::Enter(node) => {
                    let lbl = self.trie.label_of(node);
                    self.path.extend_from_slice(lbl);
                    // Pushed below the children so it pops after the whole
                    // subtree; cannot fail since label lengths are stored as u16.
                    self.stack
                        .push(Frame::Exit(u16::try_from(lbl.len()).unwrap()));
                    // Reverse order so the smallest child is popped first.
                    for &child in self.trie.children_of(node).iter().rev() {
                        self.stack.push(Frame::Enter(child));
                    }
                    if let Some(v) = self.trie.value_of(node) {
                        self.remaining -= 1;
                        // SAFETY: `path` is a concatenation of whole labels.
                        return Some((unsafe { utf8_string_unchecked(self.path.clone()) }, v));
                    }
                }
            }
        }
        None
    }

    /// Exact: `remaining` counts the values not yet yielded.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
impl<T> ExactSizeIterator for Iter<'_, T> {
    /// Number of items still to be yielded (the exact `size_hint` bound).
    #[inline]
    fn len(&self) -> usize {
        let (lower, _) = self.size_hint();
        lower
    }
}
// Once `next` returns `None`, `pending_root` is `None` and the frame stack is
// empty, so every later call returns `None` as well.
impl<T> FusedIterator for Iter<'_, T> {}
/// Shows only the traversal stack depth; no `T: Debug` needed.
impl<T> fmt::Debug for Iter<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let frames = self.stack.len();
        let mut dbg = f.debug_struct("Iter");
        dbg.field("remaining_frames", &frames);
        dbg.finish()
    }
}
/// Appends `node`'s label to `buf`, calls `f` for every value in `node`'s
/// subtree in lexicographic key order, then restores `buf`'s original length.
///
/// `buf` should hold the key bytes leading up to `node` on entry. Recursion
/// depth equals the depth of the subtree.
fn for_each_descendants<T>(
    trie: &CommandTrie<T>,
    node: NodeId,
    buf: &mut Vec<u8>,
    f: &mut impl FnMut(&str, &T),
) {
    let restore_to = buf.len();
    buf.extend_from_slice(trie.label_of(node));
    if let Some(value) = trie.value_of(node) {
        // SAFETY: `buf` is a concatenation of whole labels (valid UTF-8).
        let key = unsafe { utf8_str_unchecked(buf) };
        f(key, value);
    }
    let kids = trie.children_of(node);
    for &kid in kids {
        for_each_descendants(trie, kid, buf, f);
    }
    buf.truncate(restore_to);
}
/// Counts the values stored in the subtree rooted at `node` (inclusive).
///
/// Uses an explicit work stack instead of recursion, so very deep tries
/// (e.g. keys `a`, `aa`, `aaa`, …) cannot overflow the call stack. It is
/// called from `Iter::new`, `SubTrie::len` and `count_completions`.
fn count_values<T>(trie: &CommandTrie<T>, node: NodeId) -> usize {
    let mut pending = vec![node];
    let mut total = 0;
    while let Some(id) = pending.pop() {
        if trie.value_of(id).is_some() {
            total += 1;
        }
        pending.extend_from_slice(trie.children_of(id));
    }
    total
}
#[cfg(test)]
mod tests {
use super::*;
fn build_from<'a, I: IntoIterator<Item = (&'a str, i32)>>(items: I) -> CommandTrie<i32> {
let mut b = CommandTrieBuilder::new();
for (k, v) in items {
b.insert(k, v);
}
b.build()
}
#[test]
fn builder_insert_overwrite_remove() {
let mut b = CommandTrieBuilder::new();
assert_eq!(b.insert("commit", 1), None);
assert_eq!(b.insert("commit", 2), Some(1));
assert_eq!(b.get("commit"), Some(&2));
assert!(b.contains("commit"));
assert_eq!(b.remove("commit"), Some(2));
assert_eq!(b.remove("commit"), None);
assert!(b.is_empty());
}
#[test]
fn builder_remove_prunes_and_merges() {
let mut b = CommandTrieBuilder::new();
b.insert("command", 1);
b.insert("commit", 2);
b.insert("comm", 3);
assert_eq!(b.remove("comm"), Some(3));
assert_eq!(b.remove("commit"), Some(2));
assert_eq!(b.get("command"), Some(&1));
assert_eq!(b.len(), 1);
}
#[test]
fn builder_from_iter_and_extend() {
let b: CommandTrieBuilder<i32> = [("a", 1), ("ab", 2), ("abc", 3)].into_iter().collect();
assert_eq!(b.len(), 3);
assert_eq!(b.get("ab"), Some(&2));
}
#[test]
fn builder_get_diverges_mid_edge() {
let mut b = CommandTrieBuilder::new();
b.insert("command", 1);
assert_eq!(b.get("comx"), None);
assert_eq!(b.get("c"), None);
}
#[test]
fn frozen_get() {
let t = build_from([("commit", 1), ("command", 2)]);
assert_eq!(t.get("commit"), Some(&1));
assert_eq!(t.get("command"), Some(&2));
assert_eq!(t.get("comm"), None);
assert_eq!(t.get("commits"), None);
assert_eq!(t.get(""), None);
assert!(t.contains("commit"));
assert!(!t.contains("comm"));
}
#[test]
fn frozen_query_accepts_non_ascii_keys() {
let t = build_from([("commit", 1), ("command", 2), ("config", 3)]);
assert_eq!(t.get("café"), None);
assert!(!t.contains("café"));
assert_eq!(t.get("commité"), None);
assert_eq!(t.get("comméand"), None);
assert_eq!(t.get("🦀"), None);
assert_eq!(t.get("π"), None);
assert!(!t.contains_prefix("café"));
assert_eq!(t.count_completions("café"), 0);
assert!(t.completions("café").is_empty());
assert_eq!(t.completion_prefix("café"), None);
assert!(t.subtrie("café").is_none());
assert_eq!(t.longest_prefix_match("café"), None);
assert_eq!(t.longest_prefix_match("commit é"), Some(("commit", &1)),);
}
#[test]
fn frozen_empty_trie() {
let t = CommandTrieBuilder::<i32>::new().build();
assert_eq!(t.len(), 0);
assert!(t.is_empty());
assert_eq!(t.get(""), None);
assert_eq!(t.get("anything"), None);
assert!(!t.contains_prefix(""));
assert!(t.subtrie("").is_none());
assert_eq!(t.completion_prefix(""), None);
assert_eq!(t.count_completions(""), 0);
assert!(t.completions("").is_empty());
assert_eq!(t.iter().count(), 0);
}
#[test]
fn frozen_longest_prefix_match() {
let t = build_from([("git", 1), ("git-status", 2)]);
assert_eq!(
t.longest_prefix_match("git-status --short"),
Some(("git-status", &2))
);
assert_eq!(t.longest_prefix_match("git foo"), Some(("git", &1)));
assert_eq!(t.longest_prefix_match("git"), Some(("git", &1)));
assert_eq!(t.longest_prefix_match("gi"), None);
assert_eq!(t.longest_prefix_match("zzz"), None);
assert_eq!(t.longest_prefix_match(""), None);
let (matched, _) = t.longest_prefix_match("git-status xyz").unwrap();
assert_eq!(matched.len(), 10);
}
#[test]
fn frozen_contains_prefix() {
let t = build_from([("commit", 1)]);
assert!(t.contains_prefix(""));
assert!(t.contains_prefix("c"));
assert!(t.contains_prefix("comm"));
assert!(t.contains_prefix("commit"));
assert!(!t.contains_prefix("commits"));
assert!(!t.contains_prefix("d"));
}
#[test]
fn frozen_completions() {
let t = build_from([("commit", 1), ("command", 2), ("config", 3), ("clone", 4)]);
let mut got = t.completions("comm");
got.sort_by(|a, b| a.0.cmp(&b.0));
assert_eq!(
got,
vec![("command".to_string(), &2), ("commit".to_string(), &1)]
);
assert_eq!(t.completions("").len(), 4);
assert!(t.completions("xyz").is_empty());
}
#[test]
fn frozen_completions_prefix_ends_mid_edge() {
let t = build_from([("command", 1)]);
let got = t.completions("co");
assert_eq!(got, vec![("command".to_string(), &1)]);
}
#[test]
fn frozen_completion_prefix_extends_past_query() {
let t = build_from([("command", 1), ("commit", 2)]);
assert_eq!(t.completion_prefix("c").as_deref(), Some("comm"));
let t = build_from([("command", 1)]);
assert_eq!(t.completion_prefix("").as_deref(), Some("command"));
assert_eq!(t.completion_prefix("c").as_deref(), Some("command"));
assert_eq!(t.completion_prefix("commits"), None);
}
#[test]
fn frozen_subtrie_views() {
let t = build_from([("commit", 1), ("command", 2), ("config", 3)]);
let sub = t.subtrie("comm").unwrap();
assert_eq!(sub.common_prefix(), "comm");
assert_eq!(sub.len(), 2);
assert!(!sub.is_empty());
let mut via_iter: Vec<(String, i32)> = sub.iter().map(|(k, v)| (k, *v)).collect();
via_iter.sort();
let mut via_for_each: Vec<(String, i32)> = Vec::new();
sub.for_each(|k, v| via_for_each.push((k.to_string(), *v)));
via_for_each.sort();
assert_eq!(via_iter, via_for_each);
assert_eq!(
via_iter,
vec![("command".to_string(), 2), ("commit".to_string(), 1)]
);
let owned: Vec<_> = sub.into_iter().collect();
assert_eq!(owned.len(), 2);
}
#[test]
fn frozen_subtrie_on_exact_leaf() {
let t = build_from([("commit", 1), ("command", 2)]);
let sub = t.subtrie("commit").unwrap();
assert_eq!(sub.common_prefix(), "commit");
assert_eq!(sub.len(), 1);
}
#[test]
fn frozen_subtrie_extension_and_is_unique() {
let t = build_from([("commit", 1), ("command", 2), ("config", 3), ("clone", 4)]);
let sub = t.subtrie("c").unwrap();
assert_eq!(sub.common_prefix(), "c");
assert_eq!(sub.extension(), "");
assert!(!sub.is_unique());
let sub = t.subtrie("co").unwrap();
assert_eq!(sub.common_prefix(), "co");
assert_eq!(sub.extension(), "");
assert!(!sub.is_unique());
let sub = t.subtrie("comma").unwrap();
assert_eq!(sub.common_prefix(), "command");
assert_eq!(sub.extension(), "nd");
assert!(sub.is_unique());
let sub = t.subtrie("clone").unwrap();
assert_eq!(sub.extension(), "");
assert!(sub.is_unique());
let t2 = build_from([("git", 1), ("github", 2)]);
let sub = t2.subtrie("gi").unwrap();
assert_eq!(sub.common_prefix(), "git");
assert_eq!(sub.extension(), "t");
assert!(!sub.is_unique()); }
#[test]
fn frozen_subtrie_value_and_unique_value() {
let t = build_from([("commit", 1), ("command", 2), ("config", 3)]);
let sub = t.subtrie("com").unwrap();
assert_eq!(sub.value(), None);
assert_eq!(sub.unique_value(), None);
let sub = t.subtrie("commi").unwrap();
assert_eq!(sub.common_prefix(), "commit");
assert_eq!(sub.value(), Some(&1));
assert_eq!(sub.unique_value(), Some(&1));
let t2 = build_from([("git", 10), ("github", 20)]);
let sub = t2.subtrie("gi").unwrap();
assert_eq!(sub.common_prefix(), "git");
assert_eq!(sub.value(), Some(&10));
assert_eq!(sub.unique_value(), None);
}
#[test]
fn frozen_iter_alphabetical() {
let t = build_from([("commit", 1), ("command", 2), ("config", 3), ("clone", 4)]);
let got: Vec<_> = t.iter().map(|(k, v)| (k, *v)).collect();
assert_eq!(
got,
vec![
("clone".to_string(), 4),
("command".to_string(), 2),
("commit".to_string(), 1),
("config".to_string(), 3),
]
);
let mut n = 0;
for _ in &t {
n += 1;
}
assert_eq!(n, 4);
}
#[test]
fn frozen_for_each_no_alloc() {
let t = build_from([("a", 1), ("ab", 2), ("abc", 3)]);
let mut got: Vec<(String, i32)> = Vec::new();
t.for_each(|k, v| got.push((k.to_string(), *v)));
got.sort();
assert_eq!(
got,
vec![
("a".to_string(), 1),
("ab".to_string(), 2),
("abc".to_string(), 3),
]
);
}
#[test]
fn frozen_count_and_for_each_completion() {
let t = build_from([("commit", 1), ("command", 2), ("config", 3), ("clone", 4)]);
assert_eq!(t.count_completions("c"), 4);
assert_eq!(t.count_completions("comm"), 2);
assert_eq!(t.count_completions("commit"), 1);
assert_eq!(t.count_completions("z"), 0);
let mut got: Vec<String> = Vec::new();
t.for_each_completion("comm", |k, _| got.push(k.to_string()));
got.sort();
assert_eq!(got, vec!["command".to_string(), "commit".to_string()]);
}
#[test]
fn build_packs_into_four_allocations() {
let t = build_from([
("add", 0),
("alias", 1),
("branch", 2),
("checkout", 3),
("cherry-pick", 4),
("clean", 5),
("clone", 6),
("commit", 7),
("command", 8),
("config", 9),
]);
for (i, k) in [
"add",
"alias",
"branch",
"checkout",
"cherry-pick",
"clean",
"clone",
"commit",
"command",
"config",
]
.iter()
.enumerate()
{
assert_eq!(t.get(k), Some(&(i as i32)));
}
assert_eq!(t.child_first_bytes.len(), t.children.len());
for (i, &child) in t.children.iter().enumerate() {
let lbl0 = t.label_of(child)[0];
assert_eq!(t.child_first_bytes[i], lbl0);
}
for id in 0..t.nodes.len() as NodeId {
let kids = t.children_of(id);
for w in kids.windows(2) {
let a = t.label_of(w[0])[0];
let b = t.label_of(w[1])[0];
assert!(a < b, "siblings not sorted at node {id}: {a} >= {b}");
}
}
}
const _: fn() = || {
fn assert_send_sync<X: Send + Sync>() {}
assert_send_sync::<CommandTrieBuilder<i32>>();
assert_send_sync::<CommandTrie<i32>>();
assert_send_sync::<SubTrie<'static, i32>>();
assert_send_sync::<Iter<'static, i32>>();
};
#[test]
fn builder_default_and_clear() {
let mut b: CommandTrieBuilder<i32> = CommandTrieBuilder::default();
assert!(b.is_empty());
b.insert("a", 1);
b.insert("ab", 2);
assert_eq!(b.len(), 2);
b.clear();
assert!(b.is_empty());
assert_eq!(b.get("a"), None);
}
#[test]
fn trie_default_is_empty() {
let t: CommandTrie<i32> = CommandTrie::default();
assert!(t.is_empty());
assert_eq!(t.get("anything"), None);
assert_eq!(t.iter().count(), 0);
}
#[test]
fn debug_impls_render() {
let mut b = CommandTrieBuilder::new();
b.insert("commit", 1);
b.insert("command", 2);
let s = format!("{b:?}");
assert!(s.contains("CommandTrieBuilder"));
assert!(s.contains("BuilderNode"));
let t = b.build();
let s = format!("{t:?}");
assert!(s.contains("CommandTrie"));
assert!(s.contains("len"));
let sub = t.subtrie("comm").unwrap();
let s = format!("{sub:?}");
assert!(s.contains("SubTrie"));
assert!(s.contains("comm"));
let it = t.iter();
let s = format!("{it:?}");
assert!(s.contains("Iter"));
}
#[test]
fn subtrie_ref_into_iter() {
let t = build_from([("commit", 1), ("command", 2)]);
let sub = t.subtrie("comm").unwrap();
let from_ref: Vec<_> = (&sub).into_iter().collect();
assert_eq!(from_ref.len(), 2);
assert_eq!(sub.len(), 2);
}
#[test]
fn subtrie_for_each_emits_starting_node_value() {
let t = build_from([("git", 10), ("github", 20)]);
let sub = t.subtrie("git").unwrap();
let mut got: Vec<(String, i32)> = Vec::new();
sub.for_each(|k, v| got.push((k.to_string(), *v)));
got.sort();
assert_eq!(
got,
vec![("git".to_string(), 10), ("github".to_string(), 20)]
);
}
#[test]
fn insert_32k_entries_no_panic() {
const PREFIXES: &[&str] = &[
"git-",
"cargo-",
"docker-",
"kubectl-",
"npm-",
"pip-",
"rustup-",
"systemctl-",
"journalctl-",
"ip-",
"nmcli-",
"brew-",
];
const STEMS: &[&str] = &[
"list", "get", "set", "show", "describe", "create", "delete", "update", "apply",
"watch", "rollout", "exec", "logs", "status", "info", "config", "scale", "patch",
"expose", "annotate",
];
const BUCKETS: u32 = 134;
const N: u32 = PREFIXES.len() as u32 * STEMS.len() as u32 * BUCKETS;
const _: () = assert!(N >= 32_000, "test corpus must hit the documented ~32k cap");
fn key(n: u32, buf: &mut String) {
use std::fmt::Write;
buf.clear();
let p = (n as usize) % PREFIXES.len();
let s = ((n as usize) / PREFIXES.len()) % STEMS.len();
let bucket = (n as usize) / (PREFIXES.len() * STEMS.len());
buf.push_str(PREFIXES[p]);
buf.push_str(STEMS[s]);
buf.push('-');
write!(buf, "{bucket:03}").unwrap();
}
let mut b: CommandTrieBuilder<u32> = CommandTrieBuilder::new();
let mut buf = String::new();
for i in 0..N {
key(i, &mut buf);
b.insert(&buf, i);
}
assert_eq!(b.len(), N as usize);
let t = b.build();
assert_eq!(t.len(), N as usize);
for &i in &[0u32, 1, 11, 12, 239, 240, 1023, 12345, N / 2, N - 1] {
key(i, &mut buf);
assert_eq!(t.get(&buf), Some(&i), "lookup failed for {i}");
}
}
#[test]
fn iter_is_fused() {
fn assert_fused<I: FusedIterator>(_: &I) {}
let t = CommandTrieBuilder::<i32>::new().build();
let it = t.iter();
assert_fused(&it);
}
#[test]
fn iter_is_exact_size() {
fn assert_exact<I: ExactSizeIterator>(_: &I) {}
let t = build_from([("commit", 1), ("command", 2), ("config", 3), ("clone", 4)]);
let it = t.iter();
assert_exact(&it);
assert_eq!(it.len(), 4);
assert_eq!(it.size_hint(), (4, Some(4)));
let mut it = t.iter();
for expected in (0..4).rev() {
assert!(it.next().is_some());
assert_eq!(it.len(), expected);
assert_eq!(it.size_hint(), (expected, Some(expected)));
}
assert!(it.next().is_none());
assert_eq!(it.len(), 0);
let t2 = build_from([("git", 10), ("github", 20), ("gitlab", 30)]);
let sub = t2.subtrie("git").unwrap();
let it = sub.iter();
assert_eq!(it.len(), 3);
let collected: Vec<_> = sub.into_iter().collect();
assert_eq!(collected.len(), 3);
let empty: CommandTrie<i32> = CommandTrieBuilder::new().build();
assert_eq!(empty.iter().len(), 0);
}
#[test]
fn fuzz_against_btreemap() {
    use std::collections::BTreeMap;

    // Deterministic 64-bit LCG (Knuth's MMIX constants): every run replays the
    // same operation sequence, so a failing `op` index is reproducible.
    let mut state: u64 = 0x_dead_beef_cafe_f00d;
    let mut rand = || {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        state
    };
    // A small key pool (including the empty key) where many keys are prefixes
    // of one another, so operations keep hitting shared prefixes.
    let keys = [
        "", "a", "ab", "abc", "abd", "abcd", "abce", "b", "ba", "bar", "baz", "c", "co", "com",
        "comm", "command", "commit", "config", "x", "xy", "xyz",
    ];
    // Prefixes to probe completions with; "abz" and "z" match no key.
    let probe_prefixes = [
        "", "a", "ab", "abc", "abz", "b", "c", "comm", "com", "x", "z",
    ];
    let mut builder: CommandTrieBuilder<i32> = CommandTrieBuilder::new();
    let mut model: BTreeMap<String, i32> = BTreeMap::new();
    for op in 0..500 {
        let r = rand();
        let key = keys[(r as usize) % keys.len()];
        // About one op in four is a removal; the rest insert or overwrite.
        if (r >> 32) % 4 == 0 {
            let b_prev = builder.remove(key);
            let m_prev = model.remove(key);
            assert_eq!(b_prev, m_prev, "remove({key:?}) at op {op}");
        } else {
            // Truncating cast is intentional: any i32 payload will do.
            let v = (r >> 8) as i32;
            let b_prev = builder.insert(key, v);
            let m_prev = model.insert(key.to_string(), v);
            assert_eq!(b_prev, m_prev, "insert({key:?}, {v}) at op {op}");
        }
        assert_eq!(builder.len(), model.len(), "len at op {op}");
        assert_eq!(builder.is_empty(), model.is_empty(), "is_empty at op {op}");

        // The mutable builder must agree with the model on its own lookup path...
        for k in &keys {
            assert_eq!(builder.get(k), model.get(*k), "builder.get({k:?}) at op {op}");
            assert_eq!(
                builder.contains(k),
                model.contains_key(*k),
                "builder.contains({k:?}) at op {op}"
            );
        }

        // ...and so must the frozen trie built from a snapshot of it.
        let trie = builder.clone().build();
        assert_eq!(trie.len(), model.len(), "trie.len at op {op}");
        for k in &keys {
            assert_eq!(trie.get(k), model.get(*k), "get({k:?}) at op {op}");
            assert_eq!(
                trie.contains(k),
                model.contains_key(*k),
                "contains({k:?}) at op {op}"
            );
        }
        for pfx in &probe_prefixes {
            // The trie's completion order is not asserted here, so sort it.
            let mut from_trie: Vec<String> =
                trie.completions(pfx).into_iter().map(|(k, _)| k).collect();
            from_trie.sort();
            // BTreeMap keys already iterate in sorted order.
            let from_model: Vec<String> = model
                .keys()
                .filter(|k| k.starts_with(pfx))
                .cloned()
                .collect();
            assert_eq!(from_trie, from_model, "completions({pfx:?}) at op {op}");
            assert_eq!(
                trie.count_completions(pfx),
                from_model.len(),
                "count_completions({pfx:?}) at op {op}"
            );
            if let Some(cp) = trie.completion_prefix(pfx) {
                // The completion prefix extends the probe and prefixes every match.
                assert!(cp.starts_with(pfx), "completion_prefix({pfx:?}) at op {op}");
                for k in &from_model {
                    assert!(
                        k.starts_with(&cp),
                        "completion_prefix({pfx:?}) = {cp:?} vs {k:?} at op {op}"
                    );
                }
            } else {
                assert!(
                    from_model.is_empty(),
                    "completion_prefix({pfx:?}) is None at op {op}"
                );
            }
        }
        let from_trie: Vec<(String, i32)> = trie.iter().map(|(k, v)| (k, *v)).collect();
        let from_model: Vec<(String, i32)> =
            model.iter().map(|(k, v)| (k.clone(), *v)).collect();
        assert_eq!(from_trie, from_model, "iter at op {op}");
    }
}
#[test]
fn utf8_basic_insert_get() {
    // Keys with 2-byte (é, ü, ï, π) and 4-byte (🦀) code points.
    let entries = [
        ("café", 1),
        ("über", 2),
        ("naïve", 3),
        ("naïveté", 4),
        ("🦀", 5),
        ("π", 6),
    ];
    let trie = build_from(entries);
    for &(key, want) in &entries {
        assert_eq!(trie.get(key), Some(&want));
    }
    // Near misses: an ASCII look-alike and a strict prefix of a stored key.
    for absent in ["cafe", "naïv"] {
        assert_eq!(trie.get(absent), None);
    }
}
#[test]
fn utf8_shared_first_byte_siblings() {
    // 'é' (C3 A9), 'ê' (C3 AA) and 'è' (C3 A8) all encode with the lead byte
    // 0xC3, so these keys collide on their first byte.
    let entries = [("éa", 1), ("êb", 2), ("èc", 3), ("ad", 4)];
    let trie = build_from(entries);
    for &(key, want) in &entries {
        assert_eq!(trie.get(key), Some(&want));
    }
    assert_eq!(trie.get("éb"), None);
    assert_eq!(trie.get("ê"), None);
    for present in ["é", "ê"] {
        assert!(trie.contains_prefix(present));
    }
    assert!(!trie.contains_prefix("ë"));
}
#[test]
fn utf8_split_at_shared_codepoint() {
    // Both keys share the two-byte code point 'é' and then diverge.
    let trie = build_from([("éa", 1), ("éb", 2)]);
    for (key, want) in [("éa", 1), ("éb", 2)] {
        assert_eq!(trie.get(key), Some(&want));
    }
    let sub = trie
        .subtrie("é")
        .expect("prefix 'é' should exist");
    assert_eq!(sub.common_prefix(), "é");
    assert_eq!(sub.len(), 2);
}
#[test]
fn utf8_sort_order_matches_btreemap() {
    use std::collections::BTreeMap;
    // A mix of ASCII, 2-byte and 4-byte code points. Trie iteration must match
    // `BTreeMap<&str, _>`, whose ordering is byte-lexicographic.
    let pairs = [
        ("apple", 1),
        ("café", 2),
        ("cab", 3),
        ("über", 4),
        ("ünder", 5),
        ("naïve", 6),
        ("naive", 7),
        ("🦀rust", 8),
        ("🦀", 9),
        ("π", 10),
        ("zoo", 11),
    ];
    let expected: Vec<(String, i32)> = pairs
        .iter()
        .copied()
        .collect::<BTreeMap<&str, i32>>()
        .into_iter()
        .map(|(k, v)| (k.to_string(), v))
        .collect();
    let actual: Vec<(String, i32)> = build_from(pairs.iter().copied())
        .iter()
        .map(|(k, &v)| (k, v))
        .collect();
    assert_eq!(actual, expected);
}
#[test]
fn utf8_completion_prefix_extends_through_char() {
    let trie = build_from([("naïveté", 1), ("zzz", 2)]);
    // Both a bare ASCII prefix and one ending in the multi-byte 'ï' should
    // extend all the way to the single matching key.
    for typed in ["n", "naï"] {
        assert_eq!(trie.completion_prefix(typed).as_deref(), Some("naïveté"));
    }
}
#[test]
fn utf8_longest_prefix_match() {
    let trie = build_from([("café", 1), ("ca", 2)]);
    // (input, expected longest stored key that prefixes it, its value)
    let cases = [
        ("café au lait", ("café", 1)),
        ("cab", ("ca", 2)),
        ("caf", ("ca", 2)),
    ];
    for (input, (key, value)) in cases {
        assert_eq!(trie.longest_prefix_match(input), Some((key, &value)));
    }
}
#[test]
fn utf8_remove_and_reinsert() {
    let mut b = CommandTrieBuilder::new();
    assert_eq!(b.insert("café", 1), None);
    assert_eq!(b.insert("naïve", 2), None);
    assert_eq!(b.len(), 2);

    // Removing one non-ASCII key must not disturb its sibling.
    assert_eq!(b.remove("café"), Some(1));
    assert_eq!(b.len(), 1);
    assert_eq!(b.get("café"), None);
    assert_eq!(b.get("naïve"), Some(&2));

    // A second removal of the same key is a no-op.
    assert_eq!(b.remove("café"), None);
    assert_eq!(b.len(), 1);

    // Re-inserting a removed key is a fresh insert, not an overwrite.
    assert_eq!(b.insert("café", 11), None);
    assert_eq!(b.len(), 2);

    let t = b.build();
    assert_eq!(t.len(), 2);
    assert_eq!(t.get("café"), Some(&11));
    assert_eq!(t.get("naïve"), Some(&2));
}
#[test]
fn utf8_iter_roundtrip_emoji_heavy() {
    let keys = ["🦀", "🦀rust", "🦀🦀", "🔥", "🔥fire", "ascii"];
    let mut builder = CommandTrieBuilder::new();
    // Each key's value is its position in `keys`.
    for (value, k) in (0..).zip(keys) {
        builder.insert(k, value);
    }
    let trie = builder.build();
    for (value, k) in (0..).zip(keys) {
        assert_eq!(trie.get(k), Some(&value), "lookup {k}");
    }
    // Iteration must yield every key exactly once, in sorted order.
    let mut expected: Vec<String> = keys.iter().copied().map(String::from).collect();
    expected.sort();
    let collected: Vec<String> = trie.iter().map(|(k, _)| k).collect();
    assert_eq!(collected, expected);
}
#[test]
fn utf8_three_byte_char_paths() {
    // '中' (E4 B8 AD) and '間' (E9 96 93) are three-byte code points.
    let entries = [("中", 1), ("中a", 2), ("中b", 3), ("間", 4)];
    let mut builder = CommandTrieBuilder::new();
    for &(key, value) in &entries {
        builder.insert(key, value);
    }
    let trie = builder.build();
    for &(key, value) in &entries {
        assert_eq!(trie.get(key), Some(&value));
    }
    assert_eq!(trie.get("中c"), None);
    assert_eq!(trie.longest_prefix_match("中ax"), Some(("中a", &2)));
}
#[test]
fn utf8_lcp_backs_off_into_codepoint_boundary() {
    // 'é' (C3 A9) and 'è' (C3 A8) share the lead byte 0xC3, so the raw byte-wise
    // longest common prefix of the two keys is "X\xC3", which is not valid UTF-8.
    // Everything observable must stop at the code-point boundary "X".
    let mut b = CommandTrieBuilder::new();
    b.insert("Xé", 1);
    b.insert("Xè", 2);
    let t = b.build();
    assert_eq!(t.get("Xé"), Some(&1));
    assert_eq!(t.get("Xè"), Some(&2));
    assert_eq!(t.get("X"), None);

    // Both keys are completions of "X". The common completion prefix must be a
    // `String` that prefixes both keys, and "X" is the only such valid UTF-8.
    assert!(t.contains_prefix("X"));
    assert_eq!(t.count_completions("X"), 2);
    assert_eq!(t.completion_prefix("X").as_deref(), Some("X"));

    // Byte order: 0xA8 ('è') sorts before 0xA9 ('é').
    let keys: Vec<String> = t.iter().map(|(k, _)| k).collect();
    assert_eq!(keys, vec!["Xè".to_string(), "Xé".to_string()]);
}
}