use crate::error::{Result, ZiporaError};
/// One slot of the double array: a transition base plus a parent back-link.
///
/// For a slot that is in use, `child0` is the base that is XOR-ed with a label to
/// locate a child, and `parent` is the index of the owning state. For a free slot
/// the top bit of `parent` is set and the two fields double as `next`/`prev` links.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct DaState {
    child0: u32,
    parent: u32,
}

/// Top bit of `DaState::parent`; set while the slot is free.
const FREE_BIT: u32 = 0x8000_0000;
/// Low 31 bits of `DaState::parent`: the index payload.
const VALUE_MASK: u32 = 0x7FFF_FFFF;
/// Sentinel meaning "no state".
const NIL_STATE: u32 = 0x7FFF_FFFF;
/// Largest base value the free-base search will consider.
const MAX_STATE: u32 = 0x7FFF_FFFE;

impl DaState {
    /// A free slot with both links set to `NIL_STATE`.
    #[inline(always)]
    const fn new_free() -> Self {
        Self { child0: NIL_STATE, parent: FREE_BIT | NIL_STATE }
    }

    /// The root slot: no base assigned yet and no parent.
    #[inline(always)]
    const fn new_root() -> Self {
        Self { child0: 0, parent: NIL_STATE }
    }

    /// Raw base value (or the `next` link when the slot is free).
    #[inline(always)]
    fn child0(&self) -> u32 {
        self.child0
    }

    /// Parent index with the free bit stripped.
    #[inline(always)]
    fn parent(&self) -> u32 {
        VALUE_MASK & self.parent
    }

    /// Whether the free bit is set.
    #[inline(always)]
    fn is_free(&self) -> bool {
        (self.parent & FREE_BIT) == FREE_BIT
    }

    /// Overwrites the base.
    #[inline(always)]
    fn set_child0(&mut self, val: u32) {
        self.child0 = val;
    }

    /// Stores the parent index; this also clears the free bit.
    #[inline(always)]
    fn set_parent(&mut self, val: u32) {
        self.parent = VALUE_MASK & val;
    }

    /// Marks the slot free and records its `next`/`prev` links.
    #[inline(always)]
    fn set_free_linked(&mut self, next: u32, prev: u32) {
        *self = Self { child0: next, parent: prev | FREE_BIT };
    }

    /// Marks the slot free with both links set to `NIL_STATE`.
    #[inline(always)]
    fn set_free(&mut self) {
        self.set_free_linked(NIL_STATE, NIL_STATE)
    }

    /// `next` link of a free slot (shares storage with the base).
    #[inline(always)]
    fn free_next(&self) -> u32 {
        self.child0
    }

    /// `prev` link of a free slot (shares storage with the parent index).
    #[inline(always)]
    fn free_prev(&self) -> u32 {
        VALUE_MASK & self.parent
    }
}
/// Per-state child-list links.
///
/// Labels are stored as `label + 1` so that 0 can mean "none". `sibling` is the next
/// label in the parent's child list; `child` holds the first child label in its low
/// 15 bits and the terminal flag in its top bit.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct NInfo {
    sibling: u16,
    child: u16,
}

/// Encoded label value meaning "no child / no sibling".
const NINFO_NONE: u16 = 0;
/// Bit of `NInfo::child` that marks the state as the end of a key.
const NINFO_TERM: u16 = 0x8000;

impl NInfo {
    /// Whether the terminal flag is set.
    #[inline(always)]
    fn is_term(&self) -> bool {
        self.child & NINFO_TERM == NINFO_TERM
    }

    /// Sets the terminal flag, leaving the first-child label untouched.
    #[inline(always)]
    fn set_term(&mut self) {
        self.child = self.child | NINFO_TERM;
    }

    /// Clears the terminal flag, leaving the first-child label untouched.
    #[inline(always)]
    fn clear_term(&mut self) {
        self.child = self.child & !NINFO_TERM;
    }

    /// Encoded first-child label without the terminal flag.
    #[inline(always)]
    fn first_child(&self) -> u16 {
        !NINFO_TERM & self.child
    }

    /// Replaces the first-child label while keeping the terminal flag.
    #[inline(always)]
    fn set_first_child(&mut self, val: u16) {
        let term_bit = self.child & NINFO_TERM;
        self.child = val | term_bit;
    }
}
/// Decodes a stored label: 0 means "none", any other value is `label + 1`.
#[inline(always)]
fn ninfo_to_label(v: u16) -> Option<u8> {
    v.checked_sub(1).map(|raw| raw as u8)
}
/// Encodes a label as `label + 1`, so the result is never 0.
#[inline(always)]
fn label_to_ninfo(label: u8) -> u16 {
    u16::from(label) + 1
}
/// Set of byte-string keys stored as a double-array trie.
///
/// `states` and `ninfos` are parallel arrays indexed by state number; state 0 is
/// the root. A transition from state `s` on byte `b` lands on `states[s].child0 ^ b`
/// and is valid when that slot's parent is `s`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DoubleArrayTrie {
// Per-state base (`child0`) and parent link; free slots carry the free bit.
states: Vec<DaState>,
// Per-state first-child / sibling labels plus the terminal flag.
ninfos: Vec<NInfo>,
// Number of keys currently stored (states whose terminal flag is set).
num_keys: usize,
// Index at which the free-base search starts scanning for a free slot.
search_head: usize,
}
impl DoubleArrayTrie {
/// Creates an empty trie with the default capacity of 256 states.
pub fn new() -> Self {
    const DEFAULT_CAPACITY: usize = 256;
    Self::with_capacity(DEFAULT_CAPACITY)
}
/// Creates an empty trie with room for at least `capacity` states (never fewer than 256).
pub fn with_capacity(capacity: usize) -> Self {
    let slots = std::cmp::max(capacity, 256);
    // Slot 0 is the root; every other slot starts out free.
    let mut states = vec![DaState::new_free(); slots];
    states[0] = DaState::new_root();
    Self {
        states,
        ninfos: vec![NInfo::default(); slots],
        num_keys: 0,
        search_head: 1,
    }
}
/// Number of keys stored in the trie.
#[inline(always)]
pub fn len(&self) -> usize {
    self.num_keys
}
/// Returns `true` when no key is stored.
#[inline(always)]
pub fn is_empty(&self) -> bool {
    0 == self.num_keys
}
/// Number of slots in the state array, free slots included.
#[inline]
pub fn total_states(&self) -> usize {
    self.states.len()
}
/// Bytes occupied by the elements of both arrays (based on length, not capacity).
#[inline]
pub fn mem_size(&self) -> usize {
    let state_bytes = self.states.len() * std::mem::size_of::<DaState>();
    let ninfo_bytes = self.ninfos.len() * std::mem::size_of::<NInfo>();
    state_bytes + ninfo_bytes
}
/// Whether `state` ends a stored key; out-of-range states are reported as non-terminal.
#[inline(always)]
pub fn is_term(&self, state: u32) -> bool {
    self.ninfos
        .get(state as usize)
        .map_or(false, |info| info.is_term())
}
/// Whether `state` is unused; out-of-range states are reported as free.
#[inline(always)]
pub fn is_free(&self, state: u32) -> bool {
    self.states
        .get(state as usize)
        .map_or(true, |slot| slot.is_free())
}
/// Follows the transition from `curr` on byte `ch`.
///
/// Returns the target state, or `NIL_STATE` when there is no such transition.
/// The target slot is accessed with a bounds check: after `shrink_to_fit` the
/// arrays may be shorter than `base ^ ch`, and an out-of-range target simply
/// means the transition does not exist.
///
/// # Panics
/// Panics if `curr` is not a valid index into the state array.
#[inline(always)]
pub fn state_move(&self, curr: u32, ch: u8) -> u32 {
    let base = self.states[curr as usize].child0();
    let next = base ^ u32::from(ch);
    match self.states.get(next as usize) {
        // The slot must be in use and owned by `curr`.
        Some(slot) if !slot.is_free() && slot.parent() == curr => next,
        _ => NIL_STATE,
    }
}
/// Inserts `key`; returns `Ok(true)` if it was not present before.
///
/// Same as `insert_with_relocate_cb` with a callback that ignores relocations.
#[inline]
pub fn insert(&mut self, key: &[u8]) -> Result<bool> {
    let ignore_moves = |_old: u32, _new: u32| {};
    self.insert_with_relocate_cb(key, ignore_moves)
}
/// Inserts `key`, reporting every state that gets moved while making room.
///
/// `on_relocate(old_pos, new_pos)` is called once per child state that a
/// relocation moved from `old_pos` to `new_pos`, so callers that keep data
/// indexed by state number can move it along.
///
/// Returns `Ok(true)` if the key was newly added and `Ok(false)` if it was
/// already present. Errors from the free-base search and from relocation are
/// propagated.
pub fn insert_with_relocate_cb(
&mut self,
key: &[u8],
mut on_relocate: impl FnMut(u32, u32),
) -> Result<bool> {
// The empty key lives on the root's terminal flag.
if key.is_empty() {
let was_new = !self.ninfos[0].is_term();
self.ninfos[0].set_term();
if was_new { self.num_keys += 1; }
return Ok(was_new);
}
let mut curr = 0u32;
for &ch in key {
let base = self.states[curr as usize].child0;
// A base of 0 means `curr` has not been given a base yet.
if base == 0 {
let new_base = self.find_free_base(&[ch])?;
self.set_base_padded(curr as usize, new_base);
let next = new_base ^ ch as u32;
self.ensure_capacity(next as usize + 1);
self.states[next as usize].child0 = 0;
self.states[next as usize].set_parent(curr);
self.add_child_link(curr as usize, ch);
curr = next;
} else {
let next = base ^ ch as u32;
self.ensure_capacity(next as usize + 1);
// Case 1: the transition already exists.
if !self.states[next as usize].is_free()
&& self.states[next as usize].parent() == curr
{
curr = next;
// Case 2: the target slot is free, so claim it.
} else if self.states[next as usize].is_free() {
self.states[next as usize].child0 = 0;
self.states[next as usize].set_parent(curr);
self.add_child_link(curr as usize, ch);
curr = next;
// Case 3: the target slot belongs to another parent.
} else {
let conflict_parent = self.states[next as usize].parent();
// Moving the other parent's children is only considered when it is a
// different state that `is_ancestor` does not report as an ancestor of `curr`.
let can_consult = (conflict_parent as usize) < self.ninfos.len()
&& conflict_parent != curr
&& !self.is_ancestor(conflict_parent, curr);
if can_consult {
// Count `curr` as if the new child were already attached.
let curr_n = self.count_children(curr) + 1;
let conf_n = self.count_children(conflict_parent);
// Move whichever side has fewer children.
if curr_n > conf_n {
let old_base_cf = self.states[conflict_parent as usize].child0();
self.relocate_existing(conflict_parent)?;
let new_base_cf = self.states[conflict_parent as usize].child0();
Self::notify_relocated(
&self.ninfos, conflict_parent as usize,
old_base_cf, new_base_cf, &mut on_relocate,
);
// The contested slot is now claimed for `curr`.
self.states[next as usize].child0 = 0;
self.states[next as usize].set_parent(curr);
self.add_child_link(curr as usize, ch);
curr = next;
continue;
}
}
// Otherwise move `curr`'s own children to a base that also fits `ch`.
let old_base = self.states[curr as usize].child0();
let new_base = self.relocate(curr, ch)?;
Self::notify_relocated_excluding(
&self.ninfos, curr as usize,
old_base, new_base, ch, &mut on_relocate,
);
let next = new_base ^ ch as u32;
self.ensure_capacity(next as usize + 1);
self.states[next as usize].child0 = 0;
self.states[next as usize].set_parent(curr);
self.add_child_link(curr as usize, ch);
curr = next;
}
}
}
// Mark the final state terminal; the key is new only if it was not already.
let was_new = !self.ninfos[curr as usize].is_term();
self.ninfos[curr as usize].set_term();
if was_new { self.num_keys += 1; }
Ok(was_new)
}
/// Reports `(old_pos, new_pos)` for every child of `parent_pos` after its base
/// changed from `old_base` to `new_base`.
///
/// The child list is walked through the sibling links stored at the new positions.
fn notify_relocated(
    ninfos: &[NInfo], parent_pos: usize,
    old_base: u32, new_base: u32,
    on_relocate: &mut impl FnMut(u32, u32),
) {
    let mut enc = ninfos[parent_pos].first_child();
    while enc != NINFO_NONE {
        let label = u32::from((enc - 1) as u8);
        let moved_to = new_base ^ label;
        on_relocate(old_base ^ label, moved_to);
        enc = ninfos
            .get(moved_to as usize)
            .map_or(NINFO_NONE, |info| info.sibling);
    }
}
/// Like `notify_relocated`, but stays silent for the child labelled `exclude_ch`.
fn notify_relocated_excluding(
    ninfos: &[NInfo], parent_pos: usize,
    old_base: u32, new_base: u32, exclude_ch: u8,
    on_relocate: &mut impl FnMut(u32, u32),
) {
    let mut enc = ninfos[parent_pos].first_child();
    while enc != NINFO_NONE {
        let label = (enc - 1) as u8;
        let moved_to = new_base ^ u32::from(label);
        if label != exclude_ch {
            on_relocate(old_base ^ u32::from(label), moved_to);
        }
        enc = ninfos
            .get(moved_to as usize)
            .map_or(NINFO_NONE, |info| info.sibling);
    }
}
/// Returns `true` if `key` is stored in the trie.
///
/// Each transition target is read with a bounds check: after `shrink_to_fit`
/// the arrays can be shorter than `base ^ ch`, and an out-of-range target
/// means the key is absent rather than being undefined behaviour.
#[inline]
pub fn contains(&self, key: &[u8]) -> bool {
    let states = self.states.as_slice();
    let mut curr = 0usize;
    for &ch in key {
        let next = (states[curr].child0 ^ u32::from(ch)) as usize;
        match states.get(next) {
            // Comparing the raw field also rejects free slots: their free bit
            // makes `parent` differ from every state index.
            Some(slot) if slot.parent == curr as u32 => curr = next,
            _ => return false,
        }
    }
    // For the empty key the loop does not run and this checks the root.
    self.ninfos[curr].is_term()
}
/// Returns the state that ends `key`, or `None` if `key` is not stored.
///
/// Each transition target is read with a bounds check: after `shrink_to_fit`
/// the arrays can be shorter than `base ^ ch`, and an out-of-range target
/// means the key is absent rather than being undefined behaviour.
#[inline]
pub fn lookup_state(&self, key: &[u8]) -> Option<u32> {
    let states = self.states.as_slice();
    let mut curr = 0usize;
    for &ch in key {
        let next = (states[curr].child0 ^ u32::from(ch)) as usize;
        match states.get(next) {
            // Comparing the raw field also rejects free slots: their free bit
            // makes `parent` differ from every state index.
            Some(slot) if slot.parent == curr as u32 => curr = next,
            _ => return None,
        }
    }
    // For the empty key the loop does not run and this checks the root.
    if self.ninfos[curr].is_term() {
        Some(curr as u32)
    } else {
        None
    }
}
/// Removes `key`; returns `true` if it was present.
///
/// After clearing the terminal flag, the now-useless tail of the branch is freed.
pub fn remove(&mut self, key: &[u8]) -> bool {
    match self.lookup_state(key) {
        Some(state) if self.ninfos[state as usize].is_term() => {
            self.ninfos[state as usize].clear_term();
            self.num_keys -= 1;
            self.prune_dead_branch(state);
            true
        }
        _ => false,
    }
}
/// Rebuilds the byte path that leads from the root to `state`.
///
/// Returns `None` when `state` is out of range or free. The state does not have
/// to be terminal.
pub fn restore_key(&self, state: u32) -> Option<Vec<u8>> {
    match self.states.get(state as usize) {
        Some(slot) if !slot.is_free() => {}
        _ => return None,
    }
    let mut bytes = Vec::new();
    let mut node = state;
    // Walk up to the root; each label is recovered as `node ^ parent_base`.
    while node != 0 {
        let up = self.states[node as usize].parent();
        let up_base = self.states[up as usize].child0();
        bytes.push((node ^ up_base) as u8);
        node = up;
    }
    bytes.reverse();
    Some(bytes)
}
/// Frees `state` and then its ancestors, for as long as each one is neither
/// terminal nor has any children left. The root is never freed.
fn prune_dead_branch(&mut self, state: u32) {
    let mut node = state;
    while node != 0 {
        let idx = node as usize;
        let info = self.ninfos[idx];
        if info.is_term() || info.first_child() != NINFO_NONE {
            break;
        }
        let up = self.states[idx].parent();
        if up as usize >= self.states.len() {
            break;
        }
        // Recover the label of this node under its parent, then unlink and free it.
        let label = (node ^ self.states[up as usize].child0()) as u8;
        self.remove_child_link(up as usize, label);
        self.ninfos[idx] = NInfo::default();
        self.states[idx].set_free();
        node = up;
    }
}
/// Unlinks `label` from the child list of `parent_pos`.
///
/// Does nothing when the label is not in the list.
fn remove_child_link(&mut self, parent_pos: usize, label: u8) {
    let target = label_to_ninfo(label);
    let base = self.states[parent_pos].child0();
    let head = self.ninfos[parent_pos].first_child();
    if head == NINFO_NONE {
        return;
    }
    let target_pos = (base ^ u32::from(label)) as usize;
    if head == target {
        // Removing the head: the parent now points at the head's successor.
        let successor = self.ninfos.get(target_pos).map_or(NINFO_NONE, |n| n.sibling);
        self.ninfos[parent_pos].set_first_child(successor);
        return;
    }
    // Otherwise find the predecessor and splice the target out.
    let mut walker = head;
    while walker != NINFO_NONE {
        let walker_pos = (base ^ u32::from((walker - 1) as u8)) as usize;
        if walker_pos >= self.ninfos.len() {
            return;
        }
        let following = self.ninfos[walker_pos].sibling;
        if following == target {
            let successor = self.ninfos.get(target_pos).map_or(NINFO_NONE, |n| n.sibling);
            self.ninfos[walker_pos].sibling = successor;
            return;
        }
        walker = following;
    }
}
/// Collects every stored key by a depth-first walk from the root.
pub fn keys(&self) -> Vec<Vec<u8>> {
    let mut out = Vec::with_capacity(self.num_keys);
    let mut scratch = Vec::new();
    self.collect_keys(0, &mut scratch, &mut out);
    out
}
/// Collects every stored key that starts with `prefix`.
///
/// Returns an empty vector when no state is reachable through `prefix`.
pub fn keys_with_prefix(&self, prefix: &[u8]) -> Vec<Vec<u8>> {
    let mut node = 0u32;
    for &ch in prefix {
        node = self.state_move(node, ch);
        if node == NIL_STATE {
            return Vec::new();
        }
    }
    let mut out = Vec::new();
    let mut path = Vec::from(prefix);
    self.collect_keys(node, &mut path, &mut out);
    out
}
/// Calls `f(label, child_state)` for each in-use child of `state`, following the
/// child list from the first child through the sibling links.
#[inline]
pub fn for_each_child(&self, state: u32, mut f: impl FnMut(u8, u32)) {
    let mut enc = self.ninfos[state as usize].first_child();
    if enc == NINFO_NONE {
        return;
    }
    let base = self.states[state as usize].child0();
    while enc != NINFO_NONE {
        let label = (enc - 1) as u8;
        let pos = (base ^ u32::from(label)) as usize;
        let in_use = self.states.get(pos).map_or(false, |s| !s.is_free());
        if in_use {
            f(label, pos as u32);
        }
        enc = self.ninfos.get(pos).map_or(NINFO_NONE, |n| n.sibling);
    }
}
/// Returns `(label, child_state)` for each in-use child of `state`, in list order.
#[inline]
fn get_children(&self, state: u32) -> Vec<(u8, u32)> {
    let mut found = Vec::new();
    let mut enc = self.ninfos[state as usize].first_child();
    if enc == NINFO_NONE {
        return found;
    }
    let base = self.states[state as usize].child0();
    while enc != NINFO_NONE {
        let label = (enc - 1) as u8;
        let pos = (base ^ u32::from(label)) as usize;
        if self.states.get(pos).map_or(false, |s| !s.is_free()) {
            found.push((label, pos as u32));
        }
        enc = self.ninfos.get(pos).map_or(NINFO_NONE, |n| n.sibling);
    }
    found
}
/// Returns the first in-use child in list order whose label is `>= symbol`.
#[inline]
fn lower_bound_child(&self, state: u32, symbol: u8) -> Option<(u8, u32)> {
    let mut enc = self.ninfos[state as usize].first_child();
    if enc == NINFO_NONE {
        return None;
    }
    let base = self.states[state as usize].child0();
    while enc != NINFO_NONE {
        let label = (enc - 1) as u8;
        let pos = (base ^ u32::from(label)) as usize;
        if label >= symbol && self.states.get(pos).map_or(false, |s| !s.is_free()) {
            return Some((label, pos as u32));
        }
        enc = self.ninfos.get(pos).map_or(NINFO_NONE, |n| n.sibling);
    }
    None
}
/// Returns the last in-use child in list order whose label is `< symbol`.
#[inline]
fn prev_child(&self, state: u32, symbol: u32) -> Option<(u8, u32)> {
    let mut enc = self.ninfos[state as usize].first_child();
    if enc == NINFO_NONE {
        return None;
    }
    let base = self.states[state as usize].child0();
    let mut best = None;
    while enc != NINFO_NONE {
        let label = (enc - 1) as u8;
        let pos = (base ^ u32::from(label)) as usize;
        let in_use = self.states.get(pos).map_or(false, |s| !s.is_free());
        if u32::from(label) < symbol && in_use {
            best = Some((label, pos as u32));
        }
        enc = self.ninfos.get(pos).map_or(NINFO_NONE, |n| n.sibling);
    }
    best
}
/// Returns the head of `state`'s child list, or `None` when the list is empty
/// or the head's slot is out of range or free.
#[inline]
fn first_child(&self, state: u32) -> Option<(u8, u32)> {
    let enc = self.ninfos[state as usize].first_child();
    if enc == NINFO_NONE {
        return None;
    }
    let label = (enc - 1) as u8;
    let pos = self.states[state as usize].child0() ^ u32::from(label);
    self.states
        .get(pos as usize)
        .filter(|slot| !slot.is_free())
        .map(|_| (label, pos))
}
/// Returns the last in-use child in `state`'s child list.
#[inline]
fn last_child(&self, state: u32) -> Option<(u8, u32)> {
    let mut enc = self.ninfos[state as usize].first_child();
    if enc == NINFO_NONE {
        return None;
    }
    let base = self.states[state as usize].child0();
    let mut tail = None;
    while enc != NINFO_NONE {
        let label = (enc - 1) as u8;
        let pos = (base ^ u32::from(label)) as usize;
        if self.states.get(pos).map_or(false, |s| !s.is_free()) {
            tail = Some((label, pos as u32));
        }
        enc = self.ninfos.get(pos).map_or(NINFO_NONE, |n| n.sibling);
    }
    tail
}
/// Calls `f` with every stored key that starts with `prefix`, without
/// allocating a vector per key.
pub fn for_each_key_with_prefix(&self, prefix: &[u8], mut f: impl FnMut(&[u8])) {
    let mut node = 0u32;
    for &ch in prefix {
        node = self.state_move(node, ch);
        if node == NIL_STATE {
            return;
        }
    }
    let mut path = Vec::from(prefix);
    self.walk_keys(node, &mut path, &mut f);
}
/// Depth-first walk below `state`: calls `f` with `path` at every terminal state.
/// `path` is extended on the way down and restored before returning.
fn walk_keys(&self, state: u32, path: &mut Vec<u8>, f: &mut impl FnMut(&[u8])) {
    let idx = state as usize;
    if idx >= self.states.len() {
        return;
    }
    if self.ninfos[idx].is_term() {
        f(path.as_slice());
    }
    let mut enc = self.ninfos[idx].first_child();
    if enc == NINFO_NONE {
        return;
    }
    let base = self.states[idx].child0();
    while enc != NINFO_NONE {
        let label = (enc - 1) as u8;
        let pos = (base ^ u32::from(label)) as usize;
        if self.states.get(pos).map_or(false, |s| !s.is_free()) {
            path.push(label);
            self.walk_keys(pos as u32, path, f);
            path.pop();
        }
        enc = self.ninfos.get(pos).map_or(NINFO_NONE, |n| n.sibling);
    }
}
/// Builds a trie from `keys` by inserting them one by one, then shrinks it.
///
/// The initial capacity is estimated from the total key length. Any insertion
/// error is returned.
pub fn build_from_sorted(keys: &[&[u8]]) -> Result<Self> {
    if keys.is_empty() {
        return Ok(Self::new());
    }
    let total_bytes = keys.iter().fold(0usize, |acc, k| acc + k.len());
    let estimated_states = std::cmp::max(total_bytes / 2, 256);
    let mut trie = Self::with_capacity(estimated_states * 3 / 2);
    for key in keys.iter().copied() {
        trie.insert(key)?;
    }
    trie.shrink_to_fit();
    Ok(trie)
}
/// Length of `state`'s child list (entries are counted whether or not their slot is in use).
#[inline]
fn count_children(&self, state: u32) -> usize {
    let mut enc = self.ninfos[state as usize].first_child();
    if enc == NINFO_NONE {
        return 0;
    }
    let base = self.states[state as usize].child0();
    let mut total = 0usize;
    while enc != NINFO_NONE {
        total += 1;
        let pos = (base ^ u32::from((enc - 1) as u8)) as usize;
        enc = self.ninfos.get(pos).map_or(NINFO_NONE, |n| n.sibling);
    }
    total
}
/// Resolves a slot conflict for `curr` on `ch` by relocating the side with
/// fewer children, and returns the base `curr` has afterwards.
///
/// NOTE(review): no call site is visible in this part of the file — confirm
/// whether this helper is still needed.
fn consult_and_relocate(&mut self, curr: u32, ch: u8) -> Result<u32> {
    let conflict_pos = self.states[curr as usize].child0() ^ u32::from(ch);
    let conflict_parent = self.states[conflict_pos as usize].parent();
    let mine = self.count_children(curr);
    let theirs = self.count_children(conflict_parent);
    if mine < theirs {
        return self.relocate(curr, ch);
    }
    self.relocate(conflict_parent, ch)?;
    Ok(self.states[curr as usize].child0())
}
/// Trims trailing unused slots and releases spare capacity.
///
/// Lookups compute `base ^ label` for arbitrary labels, so the arrays must keep
/// the whole 256-slot window behind the base of every in-use state, not only of
/// states that currently have children. To keep that requirement from pinning
/// the array length, a state with no children has its base reset to 0 (the
/// "no base yet" value that insertion already handles); its window is then the
/// first 256 slots.
pub fn shrink_to_fit(&mut self) {
    let mut max_reachable = 0usize;
    for (i, slot) in self.states.iter_mut().enumerate() {
        if slot.is_free() {
            continue;
        }
        // A childless state may still carry the base of children that were
        // removed earlier; drop it so it cannot point past the new length.
        if let Some(info) = self.ninfos.get(i) {
            if info.first_child() == NINFO_NONE {
                slot.set_child0(0);
            }
        }
        max_reachable = max_reachable.max((slot.child0() as usize | 0xFF) + 1);
    }
    let last_used = self.states.iter().rposition(|s| !s.is_free()).unwrap_or(0);
    let new_len = (last_used + 1).max(max_reachable).min(self.states.len());
    self.states.truncate(new_len);
    self.states.shrink_to_fit();
    self.ninfos.truncate(new_len);
    self.ninfos.shrink_to_fit();
    // Restart the free-slot search from the front of the shortened array.
    self.search_head = 1;
}
/// Inserts `label` into the child list of `parent_pos`, keeping the list in
/// ascending label order. Does nothing if the label is already linked.
fn add_child_link(&mut self, parent_pos: usize, label: u8) {
    let enc = label_to_ninfo(label);
    let base = self.states[parent_pos].child0();
    let head = self.ninfos[parent_pos].first_child();
    let slot = (base ^ u32::from(label)) as usize;

    if head == enc {
        return;
    }
    if head == NINFO_NONE || enc < head {
        // New head: it points at the old head (possibly none).
        if let Some(info) = self.ninfos.get_mut(slot) {
            info.sibling = head;
        }
        self.ninfos[parent_pos].set_first_child(enc);
        return;
    }

    // Walk until the successor is missing or larger, then splice in between.
    let mut walker = head;
    loop {
        let walker_pos = (base ^ u32::from((walker - 1) as u8)) as usize;
        if walker_pos >= self.ninfos.len() {
            return;
        }
        let following = self.ninfos[walker_pos].sibling;
        if following == enc {
            return;
        }
        if following == NINFO_NONE || enc < following {
            if slot < self.ninfos.len() {
                self.ninfos[slot].sibling = following;
                self.ninfos[walker_pos].sibling = enc;
            }
            return;
        }
        walker = following;
    }
}
/// Grows both arrays so that at least `required` slots exist.
///
/// Growth is to the largest of `required`, 1.5x the current length and 256;
/// new slots are free.
#[inline]
fn ensure_capacity(&mut self, required: usize) {
    let current = self.states.len();
    if required > current {
        let grown = std::cmp::max(std::cmp::max(required, current * 3 / 2), 256);
        self.states.resize(grown, DaState::new_free());
        self.ninfos.resize(grown, NInfo::default());
    }
}
/// Sets the base of `state` and makes sure the whole 256-slot window that
/// `base ^ label` can reach exists.
#[inline]
fn set_base_padded(&mut self, state: usize, base: u32) {
    self.states[state].set_child0(base);
    let window_end = (base as usize | 0xFF) + 1;
    self.ensure_capacity(window_end);
}
/// Finds a base such that `base ^ ch` is a free, non-root slot for every `ch`
/// in `children`.
///
/// The search starts at the first free slot at or after `search_head` and tries
/// successive bases. On success `search_head` is moved to the slot of the first
/// child.
///
/// # Errors
/// Fails after more than 1,000,000 attempts or once the base exceeds `MAX_STATE`.
fn find_free_base(&mut self, children: &[u8]) -> Result<u32> {
    debug_assert!(!children.is_empty());
    let lead = u32::from(children[0]);

    // Skip past slots that are already taken.
    while self.search_head < self.states.len() && !self.states[self.search_head].is_free() {
        self.search_head += 1;
    }

    let mut base = (self.search_head as u32) ^ lead;
    if base == 0 {
        base = 1;
    }
    let mut attempts = 0u32;
    loop {
        if attempts > 1_000_000 || base > MAX_STATE {
            return Err(ZiporaError::invalid_data("Double array: cannot find free base"));
        }
        attempts += 1;
        self.ensure_capacity((base as usize | 0xFF) + 1);
        // Slot 0 is the root and can never be handed out.
        let fits = children.iter().all(|&ch| {
            let pos = (base ^ u32::from(ch)) as usize;
            pos != 0 && self.states[pos].is_free()
        });
        if fits {
            self.search_head = (base ^ lead) as usize;
            return Ok(base);
        }
        base += 1;
    }
}
/// Moves all children of `state` to a new base that also has a free slot for
/// `new_ch`, and returns that base.
///
/// The slot for `new_ch` itself is not claimed here. Errors from the free-base
/// search are propagated.
fn relocate(&mut self, state: u32, new_ch: u8) -> Result<u32> {
let old_base = self.states[state as usize].child0();
let mut children_symbols = Vec::new();
// Gather the labels of the children that are currently in use.
{
let mut c = self.ninfos[state as usize].first_child();
while c != NINFO_NONE {
let label = (c - 1) as u8;
let child_pos = (old_base ^ label as u32) as usize;
if child_pos < self.states.len() && !self.states[child_pos].is_free() {
if !children_symbols.contains(&label) {
children_symbols.push(label);
}
}
c = if child_pos < self.ninfos.len() {
self.ninfos[child_pos].sibling
} else {
NINFO_NONE
};
}
}
// The new base must have room for the incoming label as well.
if !children_symbols.contains(&new_ch) {
children_symbols.push(new_ch);
}
children_symbols.sort_unstable();
let new_base = self.find_free_base(&children_symbols)?;
// The child list is rebuilt from scratch after the move.
self.ninfos[state as usize].set_first_child(NINFO_NONE);
{
for &ch in &children_symbols {
if ch == new_ch { continue; }
let old_pos = old_base ^ ch as u32;
let new_pos = new_base ^ ch as u32;
// Only move slots that are in range, in use and owned by `state`.
if old_pos as usize >= self.states.len() { continue; }
if self.states[old_pos as usize].is_free() { continue; }
if self.states[old_pos as usize].parent() != state { continue; }
self.ensure_capacity(new_pos as usize + 1);
let old_state = self.states[old_pos as usize];
self.states[new_pos as usize].child0 = old_state.child0;
self.states[new_pos as usize].set_parent(state);
let old_ninfo = self.ninfos[old_pos as usize];
// Carry over the first-child label and terminal flag; the sibling link is re-created below.
self.ninfos[new_pos as usize].child = old_ninfo.child;
self.ninfos[new_pos as usize].sibling = NINFO_NONE;
// Re-point the grandchildren at the child's new position.
{
let mut gc = old_ninfo.first_child();
if gc != NINFO_NONE {
let child_base = old_state.child0();
while gc != NINFO_NONE {
let glabel = (gc - 1) as u8;
let gpos = (child_base ^ glabel as u32) as usize;
if gpos < self.states.len() && !self.states[gpos].is_free()
&& self.states[gpos].parent() == old_pos
{
self.states[gpos].set_parent(new_pos);
}
gc = if gpos < self.ninfos.len() {
self.ninfos[gpos].sibling
} else {
NINFO_NONE
};
}
}
}
// Release the old slot.
self.ninfos[old_pos as usize] = NInfo::default();
self.states[old_pos as usize].set_free();
}
}
self.set_base_padded(state as usize, new_base);
// Rebuild the child list at the new base (the incoming label is linked by the caller).
for &ch in &children_symbols {
if ch != new_ch {
self.add_child_link(state as usize, ch);
}
}
Ok(new_base)
}
/// Moves all current children of `state` to a freshly found base and returns it.
///
/// Unlike `relocate`, no extra label is reserved. If `state` has no in-use
/// children the old base is returned unchanged. Errors from the free-base
/// search are propagated.
fn relocate_existing(&mut self, state: u32) -> Result<u32> {
let old_base = self.states[state as usize].child0();
let mut children_symbols: Vec<u8> = Vec::new();
// Gather the labels of the children that are currently in use.
let mut c = self.ninfos[state as usize].first_child();
while c != NINFO_NONE {
let label = (c - 1) as u8;
let child_pos = (old_base ^ label as u32) as usize;
if child_pos < self.states.len() && !self.states[child_pos].is_free() {
children_symbols.push(label);
}
c = if child_pos < self.ninfos.len() { self.ninfos[child_pos].sibling } else { NINFO_NONE };
}
if children_symbols.is_empty() { return Ok(old_base); }
children_symbols.sort_unstable();
let new_base = self.find_free_base(&children_symbols)?;
// The child list is rebuilt from scratch after the move.
self.ninfos[state as usize].set_first_child(NINFO_NONE);
for &ch in &children_symbols {
let old_pos = old_base ^ ch as u32;
let new_pos = new_base ^ ch as u32;
// Only move slots that are in range, in use and owned by `state`.
if old_pos as usize >= self.states.len() { continue; }
if self.states[old_pos as usize].is_free() { continue; }
if self.states[old_pos as usize].parent() != state { continue; }
self.ensure_capacity(new_pos as usize + 1);
let old_state = self.states[old_pos as usize];
self.states[new_pos as usize].child0 = old_state.child0;
self.states[new_pos as usize].set_parent(state);
let old_ninfo = self.ninfos[old_pos as usize];
// Carry over the first-child label and terminal flag; the sibling link is re-created below.
self.ninfos[new_pos as usize].child = old_ninfo.child;
self.ninfos[new_pos as usize].sibling = NINFO_NONE;
// Re-point the grandchildren at the child's new position.
{
let mut gc = old_ninfo.first_child();
if gc != NINFO_NONE {
let child_base = old_state.child0();
while gc != NINFO_NONE {
let glabel = (gc - 1) as u8;
let gpos = (child_base ^ glabel as u32) as usize;
if gpos < self.states.len() && !self.states[gpos].is_free()
&& self.states[gpos].parent() == old_pos
{
self.states[gpos].set_parent(new_pos);
}
gc = if gpos < self.ninfos.len() { self.ninfos[gpos].sibling } else { NINFO_NONE };
}
}
}
// Release the old slot.
self.ninfos[old_pos as usize] = NInfo::default();
self.states[old_pos as usize].set_free();
}
self.set_base_padded(state as usize, new_base);
// Rebuild the child list at the new base.
for &ch in &children_symbols {
self.add_child_link(state as usize, ch);
}
Ok(new_base)
}
/// Whether `ancestor` appears among the first 256 parents on the path from
/// `descendant` up to the root.
fn is_ancestor(&self, ancestor: u32, descendant: u32) -> bool {
    let mut node = descendant;
    for _ in 0..256 {
        if node == 0 {
            break;
        }
        let up = self.states[node as usize].parent();
        if up == ancestor {
            return true;
        }
        node = up;
    }
    false
}
/// Depth-first walk below `state`: pushes a copy of `path` at every terminal
/// state. `path` is extended on the way down and restored before returning.
fn collect_keys(&self, state: u32, path: &mut Vec<u8>, keys: &mut Vec<Vec<u8>>) {
    let idx = state as usize;
    if idx >= self.states.len() {
        return;
    }
    if self.ninfos[idx].is_term() {
        keys.push(path.to_vec());
    }
    let mut enc = self.ninfos[idx].first_child();
    if enc == NINFO_NONE {
        return;
    }
    let base = self.states[idx].child0();
    while enc != NINFO_NONE {
        let label = (enc - 1) as u8;
        let pos = (base ^ u32::from(label)) as usize;
        if self.states.get(pos).map_or(false, |s| !s.is_free()) {
            path.push(label);
            self.collect_keys(pos as u32, path, keys);
            path.pop();
        }
        enc = self.ninfos.get(pos).map_or(NINFO_NONE, |n| n.sibling);
    }
}
}
impl Default for DoubleArrayTrie {
fn default() -> Self { Self::new() }
}
/// Prints summary figures (key count, slot count, byte size) instead of dumping the arrays.
impl std::fmt::Debug for DoubleArrayTrie {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total_states = self.states.len();
        let mem_size = self.mem_size();
        let mut out = f.debug_struct("DoubleArrayTrie");
        out.field("num_keys", &self.num_keys);
        out.field("total_states", &total_states);
        out.field("mem_size", &mem_size);
        out.finish()
    }
}
/// Cursor that steps through the keys of a [`DoubleArrayTrie`], visiting the
/// children of each state in child-list order.
pub struct DoubleArrayTrieCursor<'a> {
// Trie being traversed.
trie: &'a DoubleArrayTrie,
// One `(state, next_symbol)` entry per level of the current path. `next_symbol`
// is the smallest label still to be tried at that level; a value above 255
// means the level is exhausted.
stack: Vec<(u32, u16)>,
// Bytes of the key the cursor currently points at.
current_key: Vec<u8>,
// True while the cursor is positioned on a key.
valid: bool,
}
impl<'a> DoubleArrayTrieCursor<'a> {
/// Creates an unpositioned cursor over `trie`; call one of the `seek_*` methods before use.
fn new(trie: &'a DoubleArrayTrie) -> Self {
    let stack = Vec::with_capacity(64);
    let current_key = Vec::with_capacity(64);
    Self { trie, stack, current_key, valid: false }
}
/// Bytes of the key the cursor currently holds.
#[inline]
pub fn key(&self) -> &[u8] {
    self.current_key.as_slice()
}
/// Whether the cursor is currently positioned on a key.
#[inline]
pub fn is_valid(&self) -> bool {
    self.valid
}
/// Positions the cursor on the first key; returns `false` if there is none.
pub fn seek_begin(&mut self) -> bool {
    self.stack.clear();
    self.current_key.clear();
    self.valid = false;
    if self.trie.states.is_empty() {
        return false;
    }
    self.stack.push((0, 0));
    // The empty key, if stored, comes before everything else.
    if self.trie.ninfos[0].is_term() {
        self.valid = true;
        true
    } else {
        self.descend_to_next_terminal()
    }
}
/// Positions the cursor on the last key; returns `false` if there is none.
pub fn seek_end(&mut self) -> bool {
    self.stack.clear();
    self.current_key.clear();
    self.valid = false;
    if self.trie.states.is_empty() {
        false
    } else {
        self.stack.push((0, 256));
        self.descend_to_rightmost_terminal(0)
    }
}
/// Positions the cursor on the first key that is `>= target`; returns `false`
/// if there is no such key.
pub fn seek_lower_bound(&mut self, target: &[u8]) -> bool {
    self.stack.clear();
    self.current_key.clear();
    self.valid = false;
    if self.trie.states.is_empty() {
        return false;
    }
    let mut curr = 0u32;
    for &ch in target {
        // Follow the exact byte while it exists; otherwise resume the forward
        // scan from the next larger label (or mark this level exhausted).
        let resume_at = match self.trie.lower_bound_child(curr, ch) {
            Some((label, child)) if label == ch => {
                self.stack.push((curr, u16::from(ch) + 1));
                self.current_key.push(ch);
                curr = child;
                continue;
            }
            Some((label, _)) => u16::from(label),
            None => 256,
        };
        self.stack.push((curr, resume_at));
        return self.advance_from_stack();
    }
    // The whole target matched a path in the trie.
    self.stack.push((curr, 0));
    if self.trie.ninfos[curr as usize].is_term() {
        self.valid = true;
        return true;
    }
    self.descend_to_next_terminal()
}
/// Advances to the following key; returns `false` (and invalidates the cursor)
/// when there is none. Does nothing on an invalid cursor.
pub fn next(&mut self) -> bool {
    if !self.valid {
        return false;
    }
    // Keys below the current state come before anything to its right.
    if let Some(top) = self.stack.last_mut() {
        if self.trie.ninfos[top.0 as usize].first_child() != NINFO_NONE {
            top.1 = 0;
            return self.descend_to_next_terminal();
        }
    }
    self.advance_from_stack()
}
/// Steps back to the preceding key; returns `false` (leaving the cursor
/// invalid) when there is none. Does nothing on an invalid cursor.
pub fn prev(&mut self) -> bool {
if !self.valid { return false; }
self.valid = false;
// Re-locate the current key's state by lookup instead of trusting the stack.
let saved_key = self.current_key.clone();
let mut state = match self.trie.lookup_state(&saved_key) {
Some(s) => s,
None => return false,
};
let mut depth = saved_key.len();
loop {
// Reached the root without finding an earlier key.
if state == 0 {
if depth == 0 {
self.current_key.clear();
self.stack.clear();
return false;
}
return false;
}
let parent = self.trie.states[state as usize].parent();
let parent_base = self.trie.states[parent as usize].child0();
// Label of `state` under its parent.
let my_symbol = state ^ parent_base;
depth -= 1;
// An earlier sibling exists: the predecessor is the last key in its subtree.
if let Some((ch, sibling)) = self.trie.prev_child(parent, my_symbol) {
self.current_key.truncate(depth);
self.current_key.push(ch);
self.stack.clear();
self.stack.push((sibling, 256));
if self.descend_to_rightmost_terminal(sibling) {
self.rebuild_stack_from_key();
return true;
}
self.current_key.truncate(depth);
}
// Otherwise the parent itself, if terminal, is the predecessor.
if self.trie.ninfos[parent as usize].is_term() {
self.current_key.truncate(depth);
self.rebuild_stack_from_key();
self.valid = true;
return true;
}
// Keep climbing.
state = parent;
}
}
/// Recomputes `stack` so that it describes the path of `current_key` from the
/// root, with every level's next symbol set just past the byte taken there.
fn rebuild_stack_from_key(&mut self) {
    self.stack.clear();
    self.stack.push((0, 0));
    let mut node = 0u32;
    for i in 0..self.current_key.len() {
        let ch = self.current_key[i];
        let child = self.trie.state_move(node, ch);
        if child == NIL_STATE {
            break;
        }
        if let Some(top) = self.stack.last_mut() {
            top.1 = u16::from(ch) + 1;
        }
        self.stack.push((child, 0));
        node = child;
    }
}
/// Follows the last child repeatedly from `state` down to a childless state,
/// extending `current_key` and `stack` on the way.
///
/// Returns `true` (and marks the cursor valid) if that final state is terminal.
/// The caller must already have pushed the stack entry for `state`.
///
/// Each level records `label + 1` as its next symbol, the same convention the
/// forward-scan code uses. Storing the bare label would make a later forward
/// scan find the same child again and report the current key a second time.
fn descend_to_rightmost_terminal(&mut self, state: u32) -> bool {
    let mut curr = state;
    loop {
        match self.trie.last_child(curr) {
            Some((ch, next_state)) => {
                let len = self.stack.len();
                self.stack[len - 1].1 = u16::from(ch) + 1;
                self.current_key.push(ch);
                self.stack.push((next_state, 256));
                curr = next_state;
            }
            None => {
                if self.trie.ninfos[curr as usize].is_term() {
                    self.valid = true;
                    return true;
                }
                return false;
            }
        }
    }
}
/// Resumes the forward scan from the top of the stack.
///
/// At each level it tries the smallest remaining label, descending until a
/// terminal state is reached; exhausted levels are popped together with their
/// key byte. Returns `false` once the stack is empty.
fn advance_from_stack(&mut self) -> bool {
    self.valid = false;
    while let Some(&(state, next_sym)) = self.stack.last() {
        // A next symbol above 255 means this level has nothing left to try.
        let found = if next_sym > 255 {
            None
        } else {
            self.trie.lower_bound_child(state, next_sym as u8)
        };
        match found {
            Some((ch, child)) => {
                if let Some(top) = self.stack.last_mut() {
                    top.1 = u16::from(ch) + 1;
                }
                self.current_key.push(ch);
                self.stack.push((child, 0));
                if self.trie.ninfos[child as usize].is_term() {
                    self.valid = true;
                    return true;
                }
            }
            None => {
                self.stack.pop();
                self.current_key.pop();
            }
        }
    }
    false
}
/// Follows first children from the top of the stack until a terminal state is
/// reached; when a state has no usable first child, falls back to
/// `advance_from_stack`.
fn descend_to_next_terminal(&mut self) -> bool {
    while let Some(&(state, _)) = self.stack.last() {
        let (ch, child) = match self.trie.first_child(state) {
            Some(hit) => hit,
            None => return self.advance_from_stack(),
        };
        if let Some(top) = self.stack.last_mut() {
            top.1 = u16::from(ch) + 1;
        }
        self.current_key.push(ch);
        self.stack.push((child, 0));
        if self.trie.ninfos[child as usize].is_term() {
            self.valid = true;
            return true;
        }
    }
    false
}
}
impl DoubleArrayTrie {
    /// Creates an unpositioned cursor over this trie.
    #[inline]
    pub fn cursor(&self) -> DoubleArrayTrieCursor<'_> {
        DoubleArrayTrieCursor::new(self)
    }

    /// Iterates over the keys in `[from, to)`: starts at the first key `>= from`
    /// and stops before the first key `>= to`.
    pub fn range<'a>(&'a self, from: &[u8], to: &[u8]) -> RangeIter<'a> {
        let mut cursor = DoubleArrayTrieCursor::new(self);
        let started = cursor.seek_lower_bound(from);
        RangeIter {
            cursor,
            upper_bound: Vec::from(to),
            started,
        }
    }
}
/// Iterator over the keys of a half-open key range, as returned by
/// `DoubleArrayTrie::range`.
pub struct RangeIter<'a> {
// Cursor positioned on the next key to yield.
cursor: DoubleArrayTrieCursor<'a>,
// Exclusive upper bound; iteration stops at the first key `>=` this.
upper_bound: Vec<u8>,
// False once the cursor has run out of keys (or never found one).
started: bool,
}
impl<'a> Iterator for RangeIter<'a> {
    type Item = Vec<u8>;

    /// Yields an owned copy of the cursor's key, then advances the cursor.
    /// Returns `None` once the cursor is exhausted or the key reaches the upper bound.
    fn next(&mut self) -> Option<Self::Item> {
        if !(self.started && self.cursor.is_valid()) {
            return None;
        }
        let current = self.cursor.key();
        if current >= self.upper_bound.as_slice() {
            return None;
        }
        let item = current.to_vec();
        self.started = self.cursor.next();
        Some(item)
    }
}
/// Value types usable in the trie map.
///
/// `EMPTY` is a reserved sentinel meaning "no value stored"; it cannot itself be
/// stored as a real value.
pub trait MapValue: Copy + PartialEq {
    const EMPTY: Self;
}

/// Implements `MapValue` for each listed type with the given sentinel.
macro_rules! impl_map_value_sentinel {
    ($($ty:ty => $empty:expr),* $(,)?) => {
        $(impl MapValue for $ty { const EMPTY: Self = $empty; })*
    };
}

impl_map_value_sentinel! {
    i32 => i32::MIN,
    u32 => u32::MAX,
    i64 => i64::MIN,
    u64 => u64::MAX,
    usize => usize::MAX,
}
/// Map from byte-string keys to `V`, built on a [`DoubleArrayTrie`].
///
/// Values are stored in a vector indexed by the trie state that ends the key;
/// `V::EMPTY` marks "no value".
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DoubleArrayTrieMap<V: MapValue> {
// Key set and state layout.
trie: DoubleArrayTrie,
// Value per trie state; grown on demand, so it may be shorter than the state array.
values: Vec<V>,
}
impl<V: MapValue> DoubleArrayTrieMap<V> {
/// Creates an empty map with a default-capacity trie and no value storage yet.
pub fn new() -> Self {
    let trie = DoubleArrayTrie::new();
    Self { trie, values: Vec::new() }
}
/// Creates an empty map whose trie and value storage are pre-sized for `cap` states.
pub fn with_capacity(cap: usize) -> Self {
    let trie = DoubleArrayTrie::with_capacity(cap);
    let values = Vec::with_capacity(cap);
    Self { trie, values }
}
/// Stores `value` under `key`.
///
/// Returns the previous value if the slot held one (anything other than
/// `V::EMPTY`). Errors from the trie insertion are propagated.
pub fn insert(&mut self, key: &[u8], value: V) -> Result<Option<V>> {
    let values = &mut self.values;
    // Values are indexed by state number, so they follow states that get relocated.
    self.trie.insert_with_relocate_cb(key, |old_pos, new_pos| {
        let (from, to) = (old_pos as usize, new_pos as usize);
        if to >= values.len() {
            let grown = std::cmp::max(to + 1, values.len() * 2);
            values.resize(grown, V::EMPTY);
        }
        if from < values.len() {
            values[to] = std::mem::replace(&mut values[from], V::EMPTY);
        }
    })?;

    let idx = self
        .trie
        .lookup_state(key)
        .ok_or_else(|| ZiporaError::invalid_state("insert succeeded but lookup failed"))?
        as usize;
    if idx >= self.values.len() {
        let grown = (idx + 1).max(self.values.len() * 2).max(256);
        self.values.resize(grown, V::EMPTY);
    }
    let prev = std::mem::replace(&mut self.values[idx], value);
    if prev != V::EMPTY {
        Ok(Some(prev))
    } else {
        Ok(None)
    }
}
/// Looks up the value stored for `key`.
///
/// Returns `None` when the key is not in the trie, when its state index lies
/// beyond the value table, or when the slot holds `V::EMPTY`.
#[inline]
pub fn get(&self, key: &[u8]) -> Option<V> {
    let state = self.trie.lookup_state(key)?;
    // Checked access replaces the former `unsafe { get_unchecked }`; the
    // bounds test it relied on is exactly what `slice::get` performs.
    self.values
        .get(state as usize)
        .copied()
        .filter(|stored| *stored != V::EMPTY)
}
/// Reports whether `key` is present in the underlying trie.
///
/// This delegates to the trie only; the value table is not consulted.
#[inline]
pub fn contains(&self, key: &[u8]) -> bool { self.trie.contains(key) }
/// Returns the length reported by the underlying trie.
#[inline]
pub fn len(&self) -> usize { self.trie.len() }
/// Returns whether the underlying trie reports itself as empty.
#[inline]
pub fn is_empty(&self) -> bool { self.trie.is_empty() }
/// Returns owned copies of the keys, as produced by the underlying trie.
pub fn keys(&self) -> Vec<Vec<u8>> { self.trie.keys() }
/// Returns owned copies of the keys for `prefix`, delegating to the trie's
/// own `keys_with_prefix`.
pub fn keys_with_prefix(&self, prefix: &[u8]) -> Vec<Vec<u8>> {
self.trie.keys_with_prefix(prefix)
}
/// Collects every `(key, value)` pair whose key starts with `prefix`.
///
/// The prefix is followed byte by byte from state 0; if any transition is
/// missing the result is empty. Otherwise the subtree below the reached
/// state is gathered with `collect_entries`.
pub fn entries_with_prefix(&self, prefix: &[u8]) -> Vec<(Vec<u8>, V)> {
    let reached = prefix.iter().try_fold(0u32, |node, &byte| {
        let next = self.trie.state_move(node, byte);
        (next != NIL_STATE).then_some(next)
    });
    match reached {
        None => Vec::new(),
        Some(node) => {
            let mut found = Vec::new();
            let mut key_buf = prefix.to_vec();
            self.collect_entries(node, &mut key_buf, &mut found);
            found
        }
    }
}
/// Returns only the values of the entries found by `entries_with_prefix`,
/// in the same order.
pub fn values_with_prefix(&self, prefix: &[u8]) -> Vec<V> {
    let mut only_values = Vec::new();
    for (_key, value) in self.entries_with_prefix(prefix) {
        only_values.push(value);
    }
    only_values
}
/// Depth-first walk that appends every stored `(key, value)` at or below
/// `state` to `entries`.
///
/// `path` must hold the key bytes leading to `state`; it is extended while
/// descending and restored before returning.
fn collect_entries(&self, state: u32, path: &mut Vec<u8>, entries: &mut Vec<(Vec<u8>, V)>) {
    let trie = &self.trie;
    let here = state as usize;
    let state_count = trie.states.len();
    if here >= state_count {
        return;
    }

    // Emit this node when it is terminal and its slot holds a real value.
    if trie.ninfos[here].is_term() {
        match self.values.get(here) {
            Some(&stored) if stored != V::EMPTY => entries.push((path.clone(), stored)),
            _ => {}
        }
    }

    let mut code = trie.ninfos[here].first_child();
    if code == NINFO_NONE {
        return;
    }
    let base = trie.states[here].child0();
    loop {
        // Child codes are stored as label + 1 so that 0 can mean "none".
        let label = (code - 1) as u8;
        let slot = (base ^ u32::from(label)) as usize;
        let live = slot < state_count && !trie.states[slot].is_free();
        if live {
            path.push(label);
            self.collect_entries(slot as u32, path, entries);
            path.pop();
        }
        // Follow the sibling chain; an out-of-range slot terminates it.
        if slot >= trie.ninfos.len() {
            break;
        }
        code = trie.ninfos[slot].sibling;
        if code == NINFO_NONE {
            break;
        }
    }
}
/// Calls `f` with every stored value whose key starts with `prefix`.
///
/// Nothing is called when the prefix cannot be followed from state 0.
pub fn for_each_value_with_prefix(&self, prefix: &[u8], mut f: impl FnMut(V)) {
    let mut node = 0u32;
    for &byte in prefix {
        node = self.trie.state_move(node, byte);
        if node == NIL_STATE {
            return;
        }
    }
    self.walk_values_dfs(node, &mut f);
}
/// Depth-first walk that passes every stored value at or below `state` to `f`.
fn walk_values_dfs(&self, state: u32, f: &mut impl FnMut(V)) {
    let trie = &self.trie;
    let here = state as usize;
    let state_count = trie.states.len();
    if here >= state_count {
        return;
    }

    // Report this node when it is terminal and its slot holds a real value.
    if trie.ninfos[here].is_term() {
        match self.values.get(here) {
            Some(&stored) if stored != V::EMPTY => f(stored),
            _ => {}
        }
    }

    let mut code = trie.ninfos[here].first_child();
    if code == NINFO_NONE {
        return;
    }
    let base = trie.states[here].child0();
    loop {
        // Child codes are stored as label + 1 so that 0 can mean "none".
        let label = (code - 1) as u8;
        let slot = (base ^ u32::from(label)) as usize;
        if slot < state_count && !trie.states[slot].is_free() {
            self.walk_values_dfs(slot as u32, f);
        }
        // Follow the sibling chain; an out-of-range slot terminates it.
        if slot >= trie.ninfos.len() {
            break;
        }
        code = trie.ninfos[slot].sibling;
        if code == NINFO_NONE {
            break;
        }
    }
}
/// Removes `key` and returns the value that was stored for it.
///
/// Returns `None` when the key is not in the trie, or when it is present but
/// its slot is missing or holds `V::EMPTY`. The key is removed from the trie
/// in every case where the lookup succeeds.
pub fn remove(&mut self, key: &[u8]) -> Option<V> {
    let slot = self.trie.lookup_state(key)? as usize;
    self.trie.remove(key);
    // Clear the slot (if it exists) and hand back what it contained.
    let cell = self.values.get_mut(slot)?;
    let old = std::mem::replace(cell, V::EMPTY);
    (old != V::EMPTY).then_some(old)
}
/// Returns a lazy iterator over the `(key, value)` pairs whose key starts
/// with `prefix`.
///
/// If the prefix cannot be followed from state 0 the iterator starts with an
/// empty stack and therefore yields nothing.
pub fn iter_prefix(&self, prefix: &[u8]) -> PrefixIterator<'_, V> {
    let mut node = 0u32;
    for &byte in prefix {
        node = self.trie.state_move(node, byte);
        if node == NIL_STATE {
            return PrefixIterator { trie: self, stack: Vec::new(), path: Vec::new() };
        }
    }

    let node_idx = node as usize;
    let mut next_sibling = NINFO_NONE;
    if node_idx < self.trie.ninfos.len() {
        next_sibling = self.trie.ninfos[node_idx].first_child();
    }

    // Seed the traversal with the node reached by the prefix itself.
    let root = PrefixFrame {
        state: node,
        next_sibling,
        checked_terminal: false,
        depth: prefix.len(),
    };
    PrefixIterator {
        trie: self,
        stack: vec![root],
        path: prefix.to_vec(),
    }
}
/// Returns a lazy iterator over entries whose key is within `max_dist` edits
/// of `query`, using the row recurrence implemented by `compute_row`.
///
/// The traversal starts at state 0 with the initial row `0..=query.len()`.
pub fn iter_fuzzy(&self, query: &[u8], max_dist: usize) -> FuzzyIterator<'_, V> {
    let query_bytes = query.to_vec();
    let initial_row: Vec<usize> = (0..=query_bytes.len()).collect();

    let next_sibling = if self.trie.states.is_empty() {
        NINFO_NONE
    } else {
        self.trie.ninfos[0].first_child()
    };
    let root = FuzzyFrame {
        state: 0,
        next_sibling,
        checked_terminal: false,
        depth: 0,
    };

    FuzzyIterator {
        trie: self,
        stack: vec![root],
        path: Vec::new(),
        query: query_bytes,
        max_dist,
        dp_columns: vec![initial_row],
    }
}
}
// One level of the explicit DFS stack used by `PrefixIterator`.
struct PrefixFrame {
// Trie state this frame is visiting.
state: u32,
// Code (label + 1) of the next child to descend into; `NINFO_NONE` when exhausted.
next_sibling: u16,
// Set once the frame's own terminal/value check has been performed.
checked_terminal: bool,
// Number of key bytes on the path leading to `state`.
depth: usize,
}
/// Lazy depth-first iterator over `(key, value)` pairs of a
/// [`DoubleArrayTrieMap`], created by `iter_prefix`.
pub struct PrefixIterator<'a, V: MapValue> {
// Map being traversed.
trie: &'a DoubleArrayTrieMap<V>,
// Explicit DFS stack; an empty stack means iteration is finished.
stack: Vec<PrefixFrame>,
// Key bytes of the path to the most recently pushed frame.
path: Vec<u8>,
}
impl<'a, V: MapValue> Iterator for PrefixIterator<'a, V> {
type Item = (Vec<u8>, V);
/// Advances the explicit depth-first traversal and returns the next stored
/// `(key, value)` pair, or `None` once the stack is empty.
fn next(&mut self) -> Option<Self::Item> {
loop {
// An empty stack ends iteration via `?`.
let frame = self.stack.last_mut()?;
let state = frame.state;
// First visit of this frame: emit the node itself if it is terminal and has a value.
if !frame.checked_terminal {
frame.checked_terminal = true;
let state_idx = state as usize;
if state_idx < self.trie.trie.ninfos.len()
&& self.trie.trie.ninfos[state_idx].is_term()
&& let Some(&val) = self.trie.values.get(state_idx)
&& val != V::EMPTY
{
return Some((self.path[..frame.depth].to_vec(), val));
}
}
// No children left: this frame is done.
if frame.next_sibling == NINFO_NONE {
self.stack.pop();
continue;
}
// Child codes are stored as label + 1 so that 0 can mean "none".
let label = (frame.next_sibling - 1) as u8;
let base = self.trie.trie.states[state as usize].child0();
let child_pos = (base ^ label as u32) as usize;
// Advance the parent's cursor to the following sibling before descending.
frame.next_sibling = if child_pos < self.trie.trie.ninfos.len() {
self.trie.trie.ninfos[child_pos].sibling
} else {
NINFO_NONE
};
let parent_depth = frame.depth;
// Skip child slots that are out of range or marked free.
if child_pos >= self.trie.trie.states.len()
|| self.trie.trie.states[child_pos].is_free()
{
continue;
}
let child_depth = parent_depth + 1;
// Drop bytes left over from a previously visited sibling subtree, then add this label.
self.path.truncate(parent_depth);
self.path.push(label);
let first_child = if child_pos < self.trie.trie.ninfos.len() {
self.trie.trie.ninfos[child_pos].first_child()
} else {
NINFO_NONE
};
self.stack.push(PrefixFrame {
state: child_pos as u32,
next_sibling: first_child,
checked_terminal: false,
depth: child_depth,
});
}
}
}
// One level of the explicit DFS stack used by `FuzzyIterator`.
struct FuzzyFrame {
// Trie state this frame is visiting.
state: u32,
// Code (label + 1) of the next child to descend into; `NINFO_NONE` when exhausted.
next_sibling: u16,
// Set once the frame's own terminal/value check has been performed.
checked_terminal: bool,
// Number of key bytes on the path leading to `state`; also indexes `dp_columns`.
depth: usize,
}
/// Lazy depth-first iterator over `(key, value)` pairs of a
/// [`DoubleArrayTrieMap`] whose keys are within `max_dist` of `query`,
/// created by `iter_fuzzy`.
pub struct FuzzyIterator<'a, V: MapValue> {
// Map being traversed.
trie: &'a DoubleArrayTrieMap<V>,
// Explicit DFS stack; an empty stack means iteration is finished.
stack: Vec<FuzzyFrame>,
// Key bytes of the path to the most recently pushed frame.
path: Vec<u8>,
// Owned copy of the query bytes.
query: Vec<u8>,
// Largest distance a key may have and still be yielded.
max_dist: usize,
// One distance row per depth on the current path; index 0 is the row for the empty path.
dp_columns: Vec<Vec<usize>>,
}
impl<'a, V: MapValue> FuzzyIterator<'a, V> {
    /// Builds the next distance row after appending byte `c` to the path.
    ///
    /// `prev_row` is the row for the path without `c` and must have
    /// `query.len() + 1` entries; the returned row has the same length.
    fn compute_row(prev_row: &[usize], query: &[u8], c: u8) -> Vec<usize> {
        let mut row = Vec::with_capacity(query.len() + 1);
        row.push(prev_row[0] + 1);
        for (i, &q) in query.iter().enumerate() {
            // Three candidates: step from the previous row, step within this
            // row, or the diagonal with a mismatch cost of 0 or 1.
            let from_prev = prev_row[i + 1] + 1;
            let from_left = row[i] + 1;
            let from_diag = prev_row[i] + usize::from(q != c);
            row.push(from_prev.min(from_left).min(from_diag));
        }
        row
    }
}
impl<'a, V: MapValue> Iterator for FuzzyIterator<'a, V> {
type Item = (Vec<u8>, V);
/// Advances the explicit depth-first traversal and returns the next stored
/// `(key, value)` pair whose final row entry is within `max_dist`, or `None`
/// once the stack is empty.
fn next(&mut self) -> Option<Self::Item> {
loop {
// An empty stack ends iteration via `?`.
let frame = self.stack.last_mut()?;
let state = frame.state;
let depth = frame.depth;
// First visit of this frame: emit the node if it is terminal, has a value,
// and the last entry of its row (distance to the whole query) is small enough.
if !frame.checked_terminal {
frame.checked_terminal = true;
let state_idx = state as usize;
if depth < self.dp_columns.len()
&& self.dp_columns[depth][self.query.len()] <= self.max_dist
&& state_idx < self.trie.trie.ninfos.len()
&& self.trie.trie.ninfos[state_idx].is_term()
&& let Some(&val) = self.trie.values.get(state_idx)
&& val != V::EMPTY
{
return Some((self.path[..depth].to_vec(), val));
}
}
// No children left: pop the frame and discard rows deeper than the new top.
if frame.next_sibling == NINFO_NONE {
self.stack.pop();
if let Some(parent) = self.stack.last() {
self.dp_columns.truncate(parent.depth + 1);
} else {
self.dp_columns.truncate(1); }
continue;
}
// Child codes are stored as label + 1 so that 0 can mean "none".
let label = (frame.next_sibling - 1) as u8;
let base = self.trie.trie.states[state as usize].child0();
let child_pos = (base ^ label as u32) as usize;
// Advance the parent's cursor to the following sibling before descending.
frame.next_sibling = if child_pos < self.trie.trie.ninfos.len() {
self.trie.trie.ninfos[child_pos].sibling
} else {
NINFO_NONE
};
let parent_depth = frame.depth;
// Skip child slots that are out of range or marked free.
if child_pos >= self.trie.trie.states.len()
|| self.trie.trie.states[child_pos].is_free()
{
continue;
}
let prev_row = &self.dp_columns[parent_depth];
let new_row = Self::compute_row(prev_row, &self.query, label);
// Prune: do not descend when every entry of the new row already exceeds the limit.
if *new_row.iter().min().unwrap_or(&usize::MAX) > self.max_dist {
continue;
}
let child_depth = parent_depth + 1;
// Drop bytes and rows left over from a previously visited sibling subtree.
self.path.truncate(parent_depth);
self.path.push(label);
self.dp_columns.truncate(child_depth);
self.dp_columns.push(new_row);
let first_child = if child_pos < self.trie.trie.ninfos.len() {
self.trie.trie.ninfos[child_pos].first_child()
} else {
NINFO_NONE
};
self.stack.push(FuzzyFrame {
state: child_pos as u32,
next_sibling: first_child,
checked_terminal: false,
depth: child_depth,
});
}
}
}
impl<V: MapValue> std::fmt::Debug for DoubleArrayTrieMap<V> {
    /// Prints summary statistics of the underlying trie instead of its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let trie = &self.trie;
        let mut summary = f.debug_struct("DoubleArrayTrieMap");
        summary.field("num_keys", &trie.len());
        summary.field("total_states", &trie.total_states());
        summary.field("mem_size", &trie.mem_size());
        summary.finish()
    }
}
impl<V: MapValue> Default for DoubleArrayTrieMap<V> {
/// Equivalent to `DoubleArrayTrieMap::new()`.
fn default() -> Self { Self::new() }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_insert_contains() {
    // Inserted keys are found; their proper prefixes and extensions are not.
    let present: [&[u8]; 3] = [b"hello", b"help", b"world"];
    let absent: [&[u8]; 3] = [b"hel", b"hell", b"worlds"];

    let mut trie = DoubleArrayTrie::new();
    for key in present {
        assert!(trie.insert(key).unwrap());
    }
    assert_eq!(trie.len(), 3);
    for key in present {
        assert!(trie.contains(key));
    }
    for key in absent {
        assert!(!trie.contains(key));
    }
}
#[test]
fn test_duplicate_insert() {
    // Only the first insertion of a key reports it as new.
    let mut trie = DoubleArrayTrie::new();
    let outcomes: Vec<bool> = (0..3).map(|_| trie.insert(b"abc").unwrap()).collect();
    assert_eq!(outcomes, [true, false, false]);
    assert_eq!(trie.len(), 1);
}
#[test]
fn test_empty_key() {
    // The empty byte string is a valid key alongside ordinary keys.
    let mut trie = DoubleArrayTrie::new();
    assert!(trie.insert(b"").unwrap());
    assert!(trie.contains(b""));
    assert!(trie.insert(b"a").unwrap());
    assert_eq!(trie.len(), 2);

    let mut listed = trie.keys();
    listed.sort();
    let expected: Vec<Vec<u8>> = vec![Vec::new(), vec![b'a']];
    assert_eq!(listed, expected);
}
#[test]
fn test_remove() {
    // Removing one key leaves the other intact; removing an unknown key reports false.
    let mut trie = DoubleArrayTrie::new();
    for key in [b"hello", b"world"] {
        trie.insert(key).unwrap();
    }
    assert_eq!(trie.len(), 2);

    let removed = trie.remove(b"hello");
    assert!(removed);
    assert_eq!(trie.len(), 1);
    assert!(!trie.contains(b"hello"));
    assert!(trie.contains(b"world"));

    let removed_missing = trie.remove(b"missing");
    assert!(!removed_missing);
}
#[test]
fn test_restore_key() {
    // A key's state can be turned back into the original key bytes.
    let keys: [&[u8]; 2] = [b"hello", b"world"];
    let mut trie = DoubleArrayTrie::new();
    for key in keys {
        trie.insert(key).unwrap();
    }
    for key in keys {
        let state = trie.lookup_state(key).unwrap();
        assert_eq!(trie.restore_key(state).unwrap(), key);
    }
}
#[test]
fn test_keys() {
    // `keys()` returns every inserted key; sorted, they appear in byte order.
    let mut trie = DoubleArrayTrie::new();
    let inserted: [&[u8]; 3] = [b"apple", b"app", b"banana"];
    for key in inserted {
        trie.insert(key).unwrap();
    }

    let mut listed = trie.keys();
    listed.sort();
    assert_eq!(listed.len(), 3);
    let expected: [&[u8]; 3] = [b"app", b"apple", b"banana"];
    for (got, want) in listed.iter().zip(expected) {
        assert_eq!(got.as_slice(), want);
    }
}
#[test]
fn test_keys_with_prefix() {
    // Prefix queries count the prefix itself when it is a stored key.
    let mut trie = DoubleArrayTrie::new();
    let inserted: [&[u8]; 6] = [b"", b"a", b"ab", b"abc", b"abd", b"b"];
    for key in inserted {
        trie.insert(key).unwrap();
    }

    let cases: [(&[u8], usize); 4] = [(b"", 6), (b"a", 4), (b"ab", 3), (b"xyz", 0)];
    for (prefix, expected_count) in cases {
        let matched = trie.keys_with_prefix(prefix);
        assert_eq!(matched.len(), expected_count);
    }
}
#[test]
fn test_many_inserts() {
    // A thousand zero-padded keys are all retained; the next one is absent.
    let mut trie = DoubleArrayTrie::new();
    (0..1000).for_each(|i| {
        let key = format!("key_{:04}", i);
        trie.insert(key.as_bytes()).unwrap();
    });
    assert_eq!(trie.len(), 1000);

    for key in [b"key_0000", b"key_0500", b"key_0999"] {
        assert!(trie.contains(key));
    }
    assert!(!trie.contains(b"key_1000"));
}
#[test]
fn test_state_move() {
    // Walking the inserted key byte by byte never hits NIL and ends on a terminal state.
    let mut trie = DoubleArrayTrie::new();
    trie.insert(b"abc").unwrap();

    let mut state = 0u32;
    for &byte in b"abc" {
        state = trie.state_move(state, byte);
        assert_ne!(state, NIL_STATE);
    }
    assert!(trie.is_term(state));
    assert_eq!(trie.state_move(0, b'z'), NIL_STATE);
}
#[test]
fn test_build_from_sorted() {
    // Bulk construction from sorted keys keeps every key.
    let keys: Vec<&[u8]> = vec![b"apple", b"application", b"apply", b"banana", b"band"];
    let trie = DoubleArrayTrie::build_from_sorted(&keys).unwrap();
    assert_eq!(trie.len(), 5);
    keys.iter().for_each(|key| {
        assert!(trie.contains(key), "missing: {:?}", std::str::from_utf8(key));
    });
}
#[test]
fn test_for_each_child() {
    // The root has the single child 'a', which in turn has children b, c and d.
    let mut trie = DoubleArrayTrie::new();
    for key in [b"ab", b"ac", b"ad"] {
        trie.insert(key).unwrap();
    }

    let mut under_root = Vec::new();
    trie.for_each_child(0, |label, _| under_root.push(label));
    assert_eq!(under_root, vec![b'a']);

    let a_state = trie.state_move(0, b'a');
    let mut under_a = Vec::new();
    trie.for_each_child(a_state, |label, _| under_a.push(label));
    under_a.sort();
    assert_eq!(under_a, vec![b'b', b'c', b'd']);
}
#[test]
fn test_da_trie_map() {
    // Basic map round trip: insert, overwrite, and remove.
    let mut map = DoubleArrayTrieMap::<u32>::new();
    let seed: [(&[u8], u32); 3] = [(b"hello", 42), (b"world", 100), (b"help", 7)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }
    for (key, value) in seed {
        assert_eq!(map.get(key), Some(value));
    }
    assert_eq!(map.get(b"missing"), None);

    let replaced = map.insert(b"hello", 99).unwrap();
    assert_eq!(replaced, Some(42));
    assert_eq!(map.get(b"hello"), Some(99));

    let taken = map.remove(b"world");
    assert_eq!(taken, Some(100));
    assert_eq!(map.get(b"world"), None);
    assert_eq!(map.len(), 2);
}
#[test]
fn test_mem_size() {
    // A fresh trie reports 256 slots at 12 bytes each (8-byte state + 4-byte node info).
    let trie = DoubleArrayTrie::new();
    let expected_bytes = 256 * 12;
    assert_eq!(trie.mem_size(), expected_bytes);
    assert_eq!(std::mem::size_of::<DaState>(), 8);
    assert_eq!(std::mem::size_of::<NInfo>(), 4);
}
#[test]
fn test_shared_prefixes() {
    // Keys that share long prefixes stay distinct, and bare prefixes are not keys.
    let words: [&[u8]; 8] = [
        b"test", b"testing", b"tested", b"tester", b"tests", b"tea", b"team", b"tear",
    ];
    let mut trie = DoubleArrayTrie::new();
    for word in words {
        trie.insert(word).unwrap();
    }
    assert_eq!(trie.len(), 8);

    let stored: [&[u8]; 4] = [b"test", b"testing", b"tea", b"team"];
    for word in stored {
        assert!(trie.contains(word));
    }
    let not_stored: [&[u8]; 2] = [b"te", b"testi"];
    for word in not_stored {
        assert!(!trie.contains(word));
    }
}
#[test]
fn test_long_keys() {
    // A 1000-byte key survives insertion, lookup and key restoration.
    let long_key = "a".repeat(1000);
    let bytes = long_key.as_bytes();

    let mut trie = DoubleArrayTrie::new();
    trie.insert(bytes).unwrap();
    assert!(trie.contains(bytes));
    assert_eq!(trie.len(), 1);

    let state = trie.lookup_state(bytes).unwrap();
    let restored = trie.restore_key(state).unwrap();
    assert_eq!(restored, bytes);
}
#[test]
fn test_binary_keys() {
    // Keys may contain arbitrary bytes, including 0x00 and 0xFF.
    let keys: [&[u8]; 3] = [&[0x00, 0xFF, 0x80], &[0x00, 0xFF, 0x81], &[0xFF, 0x00, 0x01]];
    let mut trie = DoubleArrayTrie::new();
    for key in keys {
        trie.insert(key).unwrap();
    }
    assert_eq!(trie.len(), 3);
    assert!(trie.contains(keys[0]));
    assert!(trie.contains(keys[2]));
    assert!(!trie.contains(&[0x00, 0xFF]));
}
#[test]
fn test_relocation_stress() {
    // 128 single-byte keys under the root; all must remain reachable afterwards.
    let mut trie = DoubleArrayTrie::new();
    (0u8..=127u8).for_each(|ch| {
        trie.insert(&[ch]).unwrap();
    });
    assert_eq!(trie.len(), 128);
    (0u8..=127u8).for_each(|ch| {
        assert!(trie.contains(&[ch]), "missing single-byte key {}", ch);
    });
}
#[test]
fn test_shrink_to_fit() {
    // Shrinking an over-allocated trie reduces its state count without losing keys.
    let mut trie = DoubleArrayTrie::with_capacity(10000);
    assert!(trie.total_states() >= 10000);

    let keys = [b"hello", b"world"];
    for key in keys {
        trie.insert(key).unwrap();
    }
    trie.shrink_to_fit();

    assert!(trie.total_states() < 1000);
    for key in keys {
        assert!(trie.contains(key));
    }
}
#[test]
fn test_small_capacity_no_oob() {
    // Tiny requested capacities are raised to at least 256 states, so lookups stay in bounds.
    {
        let trie = DoubleArrayTrie::with_capacity(2);
        assert!(trie.total_states() >= 256, "minimum capacity must be 256");
        let probes: [&[u8]; 3] = [b"hello", b"\xff", b"\x00"];
        for probe in probes {
            assert!(!trie.contains(probe));
        }
    }
    {
        let mut trie = DoubleArrayTrie::with_capacity(1);
        trie.insert(b"test").unwrap();
        assert!(trie.contains(b"test"));
        assert!(!trie.contains(b"other"));
    }
    {
        let trie = DoubleArrayTrie::with_capacity(0);
        assert!(trie.total_states() >= 256);
        assert!(!trie.contains(b"anything"));
    }
}
#[test]
fn test_remove_and_reinsert() {
    // A removed key can be inserted again and is reported as new.
    let key = b"abc";
    let mut trie = DoubleArrayTrie::new();

    trie.insert(key).unwrap();
    assert!(trie.contains(key));

    trie.remove(key);
    assert!(!trie.contains(key));
    assert_eq!(trie.len(), 0);

    let newly_added = trie.insert(key).unwrap();
    assert!(newly_added);
    assert!(trie.contains(key));
    assert_eq!(trie.len(), 1);
}
#[test]
fn test_lookup_state_consistency() {
    // Distinct keys must resolve to pairwise distinct states.
    let keys: Vec<&[u8]> = vec![b"alpha", b"beta", b"gamma", b"delta"];
    let mut trie = DoubleArrayTrie::new();
    for &key in &keys {
        trie.insert(key).unwrap();
    }

    let states: Vec<u32> = keys
        .iter()
        .map(|k| trie.lookup_state(k).unwrap())
        .collect();
    for (i, first) in states.iter().enumerate() {
        for (j, second) in states.iter().enumerate().skip(i + 1) {
            assert_ne!(first, second,
                "states for {:?} and {:?} should differ",
                std::str::from_utf8(keys[i]).unwrap(),
                std::str::from_utf8(keys[j]).unwrap());
        }
    }
}
#[test]
fn test_performance_5000_terms() {
    // Inserts 5000 keys, looks each one up, probes 5000 absent keys, and in
    // optimized builds additionally checks coarse timing budgets.
    let terms: Vec<String> = (0..5000)
        .map(|i| format!("term_{:06}_{}", i, ["alpha", "beta", "gamma", "delta"][i % 4]))
        .collect();

    let start = std::time::Instant::now();
    let mut t = DoubleArrayTrie::new();
    for term in &terms {
        t.insert(term.as_bytes()).unwrap();
    }
    let insert_time = start.elapsed();
    assert_eq!(t.len(), 5000);

    let start = std::time::Instant::now();
    for term in &terms {
        assert!(t.contains(term.as_bytes()));
    }
    let lookup_time = start.elapsed();

    let start = std::time::Instant::now();
    for i in 0..5000 {
        let miss = format!("miss_{:06}", i);
        assert!(!t.contains(miss.as_bytes()));
    }
    let miss_time = start.elapsed();

    // `cfg!` (rather than `#[cfg]`) keeps the timing variables referenced in
    // every profile, so debug builds no longer emit unused-variable warnings.
    // The branch is a compile-time constant, so timing is still only
    // asserted when debug assertions are off.
    if !cfg!(debug_assertions) {
        eprintln!("DoubleArrayTrie 5000 terms: insert={:?}, lookup_hit={:?}, lookup_miss={:?}",
            insert_time, lookup_time, miss_time);
        eprintln!("Memory: {} bytes ({} bytes/key), {} states",
            t.mem_size(), t.mem_size() / 5000, t.total_states());
        assert!(insert_time.as_millis() < 50,
            "Insert too slow: {:?}", insert_time);
        assert!(lookup_time.as_millis() < 10,
            "Lookup too slow: {:?}", lookup_time);
    }
}
#[test]
fn test_entries_with_prefix() {
    // Prefix queries on the map return the matching keys together with their values.
    let mut map = DoubleArrayTrieMap::<u32>::new();
    let seed: [(&[u8], u32); 4] = [(b"apple", 1), (b"app", 2), (b"application", 3), (b"banana", 4)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let app_entries = map.entries_with_prefix(b"app");
    assert_eq!(app_entries.len(), 3);
    let app_values: Vec<u32> = app_entries.iter().map(|entry| entry.1).collect();
    for wanted in [1u32, 2, 3] {
        assert!(app_values.contains(&wanted));
    }

    let banana_entries = map.entries_with_prefix(b"ban");
    assert_eq!(banana_entries.len(), 1);
    assert_eq!(banana_entries[0].1, 4);

    let unmatched = map.entries_with_prefix(b"xyz");
    assert_eq!(unmatched.len(), 0);
}
#[test]
fn test_values_with_prefix() {
    // Only values of keys under the prefix are returned.
    let mut map = DoubleArrayTrieMap::<u32>::new();
    let seed: [(&[u8], u32); 3] = [(b"test_a", 10), (b"test_b", 20), (b"other", 30)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let under_test = map.values_with_prefix(b"test_");
    assert_eq!(under_test.len(), 2);
    for wanted in [10u32, 20] {
        assert!(under_test.contains(&wanted));
    }
}
#[test]
fn test_for_each_key_with_prefix() {
    // The callback sees each key under the prefix; the empty prefix visits every key.
    let mut trie = DoubleArrayTrie::new();
    let inserted: [&[u8]; 3] = [b"hello", b"help", b"world"];
    for key in inserted {
        trie.insert(key).unwrap();
    }

    let mut under_hel = Vec::new();
    trie.for_each_key_with_prefix(b"hel", |key| {
        under_hel.push(key.to_vec());
    });
    under_hel.sort();
    assert_eq!(under_hel.len(), 2);
    assert_eq!(under_hel[0], b"hello");
    assert_eq!(under_hel[1], b"help");

    let mut everything = Vec::new();
    trie.for_each_key_with_prefix(b"", |key| everything.push(key.to_vec()));
    assert_eq!(everything.len(), 3);
}
#[test]
fn test_prefix_iterator_matches_entries_with_prefix() {
    // The lazy iterator and the eager collector must agree for every prefix.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let words: &[&[u8]] = &[b"apple", b"app", b"application", b"banana", b"band", b"bar"];
    for (index, &word) in words.iter().enumerate() {
        map.insert(word, index as i32 + 1).unwrap();
    }

    let prefixes: Vec<&[u8]> = vec![b"app", b"ban", b"b", b"apple", b"z", b""];
    for &prefix in &prefixes {
        let mut from_iter: Vec<_> = map.iter_prefix(prefix).collect();
        let mut from_collect = map.entries_with_prefix(prefix);
        from_iter.sort_by(|x, y| x.0.cmp(&y.0));
        from_collect.sort_by(|x, y| x.0.cmp(&y.0));
        assert_eq!(from_iter, from_collect, "mismatch for prefix {:?}", std::str::from_utf8(prefix));
    }
}
#[test]
fn test_prefix_iterator_empty_trie() {
    // An empty map yields nothing for any prefix.
    let map = DoubleArrayTrieMap::<i32>::new();
    let yielded = map.iter_prefix(b"any").count();
    assert_eq!(yielded, 0);
}
#[test]
fn test_prefix_iterator_nonexistent_prefix() {
    // A prefix that matches no key yields nothing.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    map.insert(b"hello", 1).unwrap();
    let yielded = map.iter_prefix(b"xyz").count();
    assert_eq!(yielded, 0);
}
#[test]
fn test_prefix_iterator_empty_prefix_yields_all() {
    // The empty prefix matches every stored entry.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 3] = [(b"a", 1), (b"b", 2), (b"c", 3)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }
    assert_eq!(map.iter_prefix(b"").count(), 3);
}
#[test]
fn test_prefix_iterator_drop_early() {
    // The iterator can be abandoned after a few items.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    (0..100).for_each(|i| {
        let key = format!("key{:03}", i);
        map.insert(key.as_bytes(), i).unwrap();
    });
    let taken: Vec<_> = map.iter_prefix(b"key").take(5).collect();
    assert_eq!(taken.len(), 5);
}
#[test]
fn test_fuzzy_iterator_exact_match() {
    // With a distance budget of 0 only the exact key is returned.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 4] = [(b"cat", 1), (b"car", 2), (b"cap", 3), (b"dog", 4)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let hits: Vec<_> = map.iter_fuzzy(b"cat", 0).collect();
    assert_eq!(hits.len(), 1);
    let (key, value) = &hits[0];
    assert_eq!(*key, b"cat");
    assert_eq!(*value, 1);
}
#[test]
fn test_fuzzy_iterator_distance_1() {
    // One substitution away from "cat" is accepted; "dog" is too far.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 5] = [(b"cat", 1), (b"car", 2), (b"cap", 3), (b"bat", 4), (b"dog", 5)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let mut hits: Vec<_> = map.iter_fuzzy(b"cat", 1).collect();
    hits.sort_by(|x, y| x.0.cmp(&y.0));
    let found: Vec<&[u8]> = hits.iter().map(|hit| hit.0.as_slice()).collect();

    let expected: [&[u8]; 4] = [b"cat", b"car", b"cap", b"bat"];
    for key in expected {
        assert!(found.contains(&key));
    }
    assert!(!found.contains(&b"dog".as_slice()));
}
#[test]
fn test_fuzzy_iterator_empty_query() {
    // Against an empty query the distance equals the key length.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 3] = [(b"a", 1), (b"ab", 2), (b"abc", 3)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let hits: Vec<_> = map.iter_fuzzy(b"", 1).collect();
    let has_key = |wanted: &[u8]| hits.iter().any(|(k, _)| k.as_slice() == wanted);
    assert!(has_key(b"a"));
    assert!(!has_key(b"abc"));
}
#[test]
fn test_fuzzy_iterator_empty_trie() {
    // An empty map yields nothing regardless of the distance budget.
    let map = DoubleArrayTrieMap::<i32>::new();
    let yielded = map.iter_fuzzy(b"test", 2).count();
    assert_eq!(yielded, 0);
}
#[test]
fn test_fuzzy_iterator_all_within_distance() {
    // A generous budget accepts a key that shares no byte with the query.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    map.insert(b"cat", 1).unwrap();

    let hits: Vec<_> = map.iter_fuzzy(b"xyz", 10).collect();
    assert_eq!(hits.len(), 1);
    let (key, _) = &hits[0];
    assert_eq!(*key, b"cat");
}
#[test]
fn test_prefix_iterator_prefix_longer_than_keys() {
    // A prefix longer than every stored key matches nothing.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 2] = [(b"app", 1), (b"apple", 2)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }
    let yielded = map.iter_prefix(b"application123").count();
    assert_eq!(yielded, 0);
}
#[test]
fn test_prefix_iterator_single_key_trie() {
    // With one key, both the full key and a proper prefix find exactly that entry.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    map.insert(b"hello", 42).unwrap();

    let prefixes: [&[u8]; 2] = [b"hello", b"hel"];
    for prefix in prefixes {
        let hits: Vec<_> = map.iter_prefix(prefix).collect();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0], (b"hello".to_vec(), 42));
    }
    assert_eq!(map.iter_prefix(b"world").count(), 0);
}
#[test]
fn test_prefix_iterator_deeply_nested() {
    // A chain of keys where each extends the previous one by a byte.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 5] = [(b"a", 1), (b"ab", 2), (b"abc", 3), (b"abcd", 4), (b"abcde", 5)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let expectations: [(&[u8], usize); 7] = [
        (b"", 5),
        (b"a", 5),
        (b"ab", 4),
        (b"abc", 3),
        (b"abcd", 2),
        (b"abcde", 1),
        (b"abcdef", 0),
    ];
    for (prefix, expected_count) in expectations {
        assert_eq!(map.iter_prefix(prefix).count(), expected_count);
    }
}
#[test]
fn test_prefix_iterator_large_trie() {
    // Lazy and eager prefix enumeration agree on a 1000-key map.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    (0..1000).for_each(|i| {
        let key = format!("key{:04}", i);
        map.insert(key.as_bytes(), i).unwrap();
    });

    let prefixes: [&[u8]; 6] = [b"key", b"key0", b"key00", b"key000", b"key0001", b""];
    for prefix in prefixes {
        let mut from_iter: Vec<_> = map.iter_prefix(prefix).collect();
        let mut from_collect = map.entries_with_prefix(prefix);
        from_iter.sort_by(|x, y| x.0.cmp(&y.0));
        from_collect.sort_by(|x, y| x.0.cmp(&y.0));
        assert_eq!(from_iter, from_collect, "mismatch for prefix {:?}", std::str::from_utf8(prefix));
    }
}
#[test]
fn test_prefix_iterator_binary_keys() {
    // Prefix iteration works for keys made of arbitrary bytes.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 4] = [
        (&[0x00, 0x01, 0x02], 1),
        (&[0x00, 0x01, 0xFF], 2),
        (&[0x00, 0xFF], 3),
        (&[0xFF, 0x00], 4),
    ];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let under_zero: Vec<_> = map.iter_prefix(&[0x00]).collect();
    assert_eq!(under_zero.len(), 3);
    let under_zero_one: Vec<_> = map.iter_prefix(&[0x00, 0x01]).collect();
    assert_eq!(under_zero_one.len(), 2);
}
#[test]
fn test_fuzzy_iterator_edit_distance_verification() {
    // Reference distance computed row by row: unit cost per step, 0/1 cost on the diagonal.
    fn edit_distance(a: &[u8], b: &[u8]) -> usize {
        let mut prev: Vec<usize> = (0..=b.len()).collect();
        for (i, &x) in a.iter().enumerate() {
            let mut cur = Vec::with_capacity(b.len() + 1);
            cur.push(i + 1);
            for (j, &y) in b.iter().enumerate() {
                let diagonal = prev[j] + usize::from(x != y);
                let best = (prev[j + 1] + 1).min(cur[j] + 1).min(diagonal);
                cur.push(best);
            }
            prev = cur;
        }
        prev[b.len()]
    }

    let mut map = DoubleArrayTrieMap::<i32>::new();
    let words: [&[u8]; 10] = [
        b"cat", b"car", b"cap", b"bat", b"hat", b"cart", b"ca", b"c", b"cats", b"dog",
    ];
    for (index, &word) in words.iter().enumerate() {
        map.insert(word, index as i32).unwrap();
    }

    let query = b"cat";
    for max_dist in 0..=3 {
        let results: Vec<_> = map.iter_fuzzy(query, max_dist).collect();

        // Soundness: nothing returned is farther away than allowed.
        for (key, _) in &results {
            let dist = edit_distance(key, query);
            assert!(dist <= max_dist,
                "key {:?} has distance {} from {:?}, exceeds max_dist {}",
                std::str::from_utf8(key), dist, std::str::from_utf8(query), max_dist);
        }

        // Completeness: every word within range is returned.
        for &word in &words {
            let dist = edit_distance(word, query);
            if dist > max_dist {
                continue;
            }
            assert!(results.iter().any(|(k, _)| k == word),
                "key {:?} at distance {} missing from results (max_dist={})",
                std::str::from_utf8(word), dist, max_dist);
        }
    }
}
#[test]
fn test_fuzzy_iterator_insertion_deletion_substitution() {
    // Each single-edit variant of "cat" must be found with a budget of 1.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 10] = [
        (b"cat", 1),
        (b"at", 2),
        (b"ct", 3),
        (b"ca", 4),
        (b"cats", 5),
        (b"scat", 6),
        (b"caat", 7),
        (b"bat", 8),
        (b"cot", 9),
        (b"cab", 10),
    ];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let keys: Vec<Vec<u8>> = map.iter_fuzzy(b"cat", 1).map(|(k, _)| k).collect();
    let expected: [(&[u8], &str); 8] = [
        (b"cat", "exact match"),
        (b"at", "deletion of c"),
        (b"ct", "deletion of a"),
        (b"ca", "deletion of t"),
        (b"cats", "insertion of s"),
        (b"bat", "substitution c->b"),
        (b"cot", "substitution a->o"),
        (b"cab", "substitution t->b"),
    ];
    for (key, why) in expected {
        assert!(keys.contains(&key.to_vec()), "{}", why);
    }
}
#[test]
fn test_fuzzy_iterator_root_terminal() {
    // The empty key lives on the root state and must take part in fuzzy search.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 3] = [(b"", 1), (b"a", 2), (b"ab", 3)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let exact: Vec<_> = map.iter_fuzzy(b"", 0).collect();
    assert_eq!(exact.len(), 1);
    assert_eq!(exact[0].0, b"".to_vec());

    let near: Vec<_> = map.iter_fuzzy(b"", 1).collect();
    let has_key = |wanted: &[u8]| near.iter().any(|(k, _)| k.as_slice() == wanted);
    assert!(near.iter().any(|(k, _)| k.is_empty()));
    assert!(has_key(b"a"));
    assert!(!has_key(b"ab"));
}
#[test]
fn test_fuzzy_iterator_long_query_short_keys() {
    // Short keys only match a long query when the budget covers the length gap.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 2] = [(b"a", 1), (b"ab", 2)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let tight: Vec<_> = map.iter_fuzzy(b"abcdefgh", 1).collect();
    assert_eq!(tight.len(), 0);

    let loose: Vec<_> = map.iter_fuzzy(b"abcdefgh", 7).collect();
    assert!(loose.iter().any(|(k, _)| k == b"a"));
}
#[test]
fn test_fuzzy_iterator_binary_keys() {
    // Fuzzy search works on arbitrary byte values.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let seed: [(&[u8], i32); 3] = [(&[0x00, 0x01], 1), (&[0x00, 0x02], 2), (&[0xFF, 0xFE], 3)];
    for (key, value) in seed {
        map.insert(key, value).unwrap();
    }

    let hits: Vec<_> = map.iter_fuzzy(&[0x00, 0x01], 1).collect();
    let has_key = |wanted: &[u8]| hits.iter().any(|(k, _)| k.as_slice() == wanted);
    assert!(has_key(&[0x00, 0x01]));
    assert!(has_key(&[0x00, 0x02]));
    assert!(!has_key(&[0xFF, 0xFE]));
}
#[test]
fn test_fuzzy_iterator_pruning_completeness() {
    // Pruning must neither drop in-range keys nor let out-of-range keys through.
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let words: Vec<String> = (0..100).map(|i| format!("word{:02}", i)).collect();
    for (index, word) in words.iter().enumerate() {
        map.insert(word.as_bytes(), index as i32).unwrap();
    }

    // Reference distance computed row by row: unit cost per step, 0/1 cost on the diagonal.
    fn edit_distance(a: &[u8], b: &[u8]) -> usize {
        let mut prev: Vec<usize> = (0..=b.len()).collect();
        for (i, &x) in a.iter().enumerate() {
            let mut cur = Vec::with_capacity(b.len() + 1);
            cur.push(i + 1);
            for (j, &y) in b.iter().enumerate() {
                let diagonal = prev[j] + usize::from(x != y);
                let best = (prev[j + 1] + 1).min(cur[j] + 1).min(diagonal);
                cur.push(best);
            }
            prev = cur;
        }
        prev[b.len()]
    }

    let query = b"word50";
    for max_dist in 0..=2 {
        let results: Vec<_> = map.iter_fuzzy(query, max_dist).collect();

        for (key, _) in &results {
            let d = edit_distance(key, query);
            assert!(d <= max_dist, "spurious result {:?} at distance {}",
                std::str::from_utf8(key), d);
        }

        for w in &words {
            let d = edit_distance(w.as_bytes(), query);
            if d > max_dist {
                continue;
            }
            assert!(results.iter().any(|(k, _)| k == w.as_bytes()),
                "missing {:?} at distance {} (max_dist={})", w, d, max_dist);
        }
    }
}
#[test]
fn test_cursor_seek_begin_end() {
    // The cursor can jump to the smallest and the largest key.
    let mut trie = DoubleArrayTrie::new();
    for word in ["apple", "banana", "cherry"] {
        trie.insert(word.as_bytes()).unwrap();
    }

    let mut cursor = trie.cursor();
    let at_first = cursor.seek_begin();
    assert!(at_first);
    assert_eq!(cursor.key(), b"apple");

    let at_last = cursor.seek_end();
    assert!(at_last);
    assert_eq!(cursor.key(), b"cherry");
}
#[test]
fn test_cursor_next_prev() {
    // Stepping forward visits keys in ascending order; stepping back, in descending order.
    let words = ["apple", "banana", "cherry", "date", "elderberry"];
    let mut trie = DoubleArrayTrie::new();
    for word in &words {
        trie.insert(word.as_bytes()).unwrap();
    }

    let mut cursor = trie.cursor();
    cursor.seek_begin();
    assert_eq!(cursor.key(), b"apple");
    for expected in &words[1..] {
        assert!(cursor.next());
        assert_eq!(cursor.key(), expected.as_bytes());
    }
    assert!(!cursor.next());

    cursor.seek_end();
    assert_eq!(cursor.key(), b"elderberry");
    for expected in words[..words.len() - 1].iter().rev() {
        assert!(cursor.prev());
        assert_eq!(cursor.key(), expected.as_bytes());
    }
    assert!(!cursor.prev());
}
#[test]
fn test_cursor_seek_lower_bound() {
    // Lower-bound seeks land on the first key that is >= the probe.
    let mut trie = DoubleArrayTrie::new();
    for word in ["apple", "banana", "cherry", "date", "elderberry"] {
        trie.insert(word.as_bytes()).unwrap();
    }

    let mut cursor = trie.cursor();
    let landing: [(&[u8], &[u8]); 3] = [(b"cherry", b"cherry"), (b"c", b"cherry"), (b"a", b"apple")];
    for (probe, expected) in landing {
        assert!(cursor.seek_lower_bound(probe));
        assert_eq!(cursor.key(), expected);
    }

    // Past the last key there is nothing to land on.
    assert!(!cursor.seek_lower_bound(b"z"));

    assert!(cursor.seek_lower_bound(b"cat"));
    assert_eq!(cursor.key(), b"cherry");
}
#[test]
fn test_range() {
    // Ranges include the lower bound and exclude the upper bound.
    let mut trie = DoubleArrayTrie::new();
    for word in ["apple", "banana", "cherry", "date", "elderberry", "fig"] {
        trie.insert(word.as_bytes()).unwrap();
    }

    let b_to_e: Vec<Vec<u8>> = trie.range(b"b", b"e").collect();
    let expected_b_to_e: Vec<Vec<u8>> = vec![b"banana".to_vec(), b"cherry".to_vec(), b"date".to_vec()];
    assert_eq!(b_to_e, expected_b_to_e);

    let everything: Vec<Vec<u8>> = trie.range(b"a", b"z").collect();
    assert_eq!(everything.len(), 6);

    let nothing: Vec<Vec<u8>> = trie.range(b"d", b"d").collect();
    assert_eq!(nothing.len(), 0);

    let middle: Vec<Vec<u8>> = trie.range(b"cherry", b"elderberry").collect();
    let expected_middle: Vec<Vec<u8>> = vec![b"cherry".to_vec(), b"date".to_vec()];
    assert_eq!(middle, expected_middle);
}
#[test]
fn test_cursor_empty_trie() {
    // Every seek fails on an empty trie.
    let trie = DoubleArrayTrie::new();
    let mut cursor = trie.cursor();
    let found_first = cursor.seek_begin();
    assert!(!found_first);
    let found_last = cursor.seek_end();
    assert!(!found_last);
    let found_bound = cursor.seek_lower_bound(b"anything");
    assert!(!found_bound);
}
#[test]
fn test_cursor_single_key() {
    // With one key, begin and end coincide and there is nowhere to step.
    let only: &[u8] = b"only";
    let mut trie = DoubleArrayTrie::new();
    trie.insert(only).unwrap();

    let mut cursor = trie.cursor();
    assert!(cursor.seek_begin());
    assert_eq!(cursor.key(), only);
    assert!(!cursor.next());

    assert!(cursor.seek_end());
    assert_eq!(cursor.key(), only);
    assert!(!cursor.prev());
}
/// The empty key is a valid entry and must be visited first, ahead of the
/// one-byte keys "a" and "b".
#[test]
fn test_cursor_with_empty_key() {
    let mut trie = DoubleArrayTrie::new();
    let stored: [&[u8]; 3] = [b"", b"a", b"b"];
    for key in stored.iter() {
        trie.insert(*key).unwrap();
    }

    let mut cursor = trie.cursor();
    assert!(cursor.seek_begin());
    assert_eq!(cursor.key(), b"");

    // The remaining keys are reached by stepping forward one at a time.
    for expected in stored[1..].iter() {
        assert!(cursor.next());
        assert_eq!(cursor.key(), *expected);
    }
}
/// A full forward walk with the cursor must yield exactly the keys returned
/// by `keys()`, once those are sorted.
#[test]
fn test_cursor_full_traversal_matches_keys() {
    let mut trie = DoubleArrayTrie::new();
    let greek = [
        "alpha", "beta", "gamma", "delta", "epsilon",
        "zeta", "eta", "theta", "iota", "kappa",
    ];
    for letter in greek.iter() {
        trie.insert(letter.as_bytes()).unwrap();
    }

    // Walk from the first key until the cursor reports there is nothing left.
    let mut walked = Vec::new();
    let mut cursor = trie.cursor();
    let mut has_key = cursor.seek_begin();
    while has_key {
        walked.push(cursor.key().to_vec());
        has_key = cursor.next();
    }

    // `keys()` is sorted here before comparing, so its own order is not relied on.
    let mut expected = trie.keys();
    expected.sort();
    assert_eq!(walked, expected,
        "Cursor traversal must match sorted keys()");
}
/// Degenerate range bounds: an inverted interval yields nothing, while an
/// interval from the empty key up to `0xff` covers all four keys.
#[test]
fn test_range_empty_bounds() {
    let mut trie = DoubleArrayTrie::new();
    let letters = ["a", "b", "c", "d"];
    for letter in letters.iter() {
        trie.insert(letter.as_bytes()).unwrap();
    }

    // Lower bound greater than upper bound.
    let inverted = trie.range(b"z", b"a").count();
    assert_eq!(inverted, 0);

    // Widest possible single-byte interval.
    let widest = trie.range(b"", b"\xff").count();
    assert_eq!(widest, 4);
}
/// Seeking to the last key exactly succeeds and leaves nothing after it;
/// seeking to a probe that extends the last key finds nothing.
#[test]
fn test_cursor_seek_lower_bound_exact_last() {
    let mut trie = DoubleArrayTrie::new();
    let stored: [&[u8]; 2] = [b"aaa", b"zzz"];
    for key in stored.iter() {
        trie.insert(*key).unwrap();
    }

    let mut cursor = trie.cursor();

    let hit_last = cursor.seek_lower_bound(b"zzz");
    assert!(hit_last);
    assert_eq!(cursor.key(), b"zzz");

    let moved_past_last = cursor.next();
    assert!(!moved_past_last);

    // "zzzz" sorts after "zzz", so no key can satisfy the bound.
    let hit_beyond = cursor.seek_lower_bound(b"zzzz");
    assert!(!hit_beyond);
}
/// Stepping backwards from the first key must fail and report `false`.
#[test]
fn test_cursor_prev_from_begin() {
    let mut t = DoubleArrayTrie::new();
    t.insert(b"first").unwrap();
    t.insert(b"second").unwrap();

    let mut c = t.cursor();
    // Check the seek result instead of discarding it, so a failed seek is
    // reported here rather than as a confusing key mismatch below.
    assert!(c.seek_begin());
    assert_eq!(c.key(), b"first");

    // There is nothing before the first key.
    assert!(!c.prev());
}
/// Mixes forward and backward steps on a five-key trie and checks the key
/// after every move, including that each move itself reports success.
#[test]
fn test_cursor_interleaved_next_prev() {
    let mut t = DoubleArrayTrie::new();
    for w in &["a", "b", "c", "d", "e"] {
        t.insert(w.as_bytes()).unwrap();
    }

    let mut c = t.cursor();
    // Every step below stays strictly inside the key range, so each one must
    // return `true`; asserting it pinpoints the failing move directly.
    assert!(c.seek_begin());
    assert_eq!(c.key(), b"a");
    assert!(c.next());
    assert_eq!(c.key(), b"b");
    assert!(c.next());
    assert_eq!(c.key(), b"c");
    // Reverse direction once, then resume going forward.
    assert!(c.prev());
    assert_eq!(c.key(), b"b");
    assert!(c.next());
    assert_eq!(c.key(), b"c");
    assert!(c.next());
    assert_eq!(c.key(), b"d");
}
/// Inserts 200 zero-padded keys and checks that a forward walk is strictly
/// ascending and that a backward walk visits the same keys in reverse.
#[test]
fn test_cursor_many_keys_sorted() {
    let mut trie = DoubleArrayTrie::new();
    for n in 0..200u32 {
        let key = format!("k{:04}", n);
        trie.insert(key.as_bytes()).unwrap();
    }

    // Forward pass: first key, then `next` until exhausted.
    let mut forward_cursor = trie.cursor();
    let mut forward = Vec::new();
    let mut more = forward_cursor.seek_begin();
    while more {
        forward.push(forward_cursor.key().to_vec());
        more = forward_cursor.next();
    }
    assert_eq!(forward.len(), 200);

    // Each adjacent pair must be strictly increasing; the reported index is
    // that of the second element of the pair.
    for (offset, pair) in forward.windows(2).enumerate() {
        let i = offset + 1;
        assert!(pair[0] < pair[1], "Not sorted at {}: {:?} >= {:?}",
            i, String::from_utf8_lossy(&pair[0]), String::from_utf8_lossy(&pair[1]));
    }

    // Backward pass: last key, then `prev` until exhausted.
    let mut backward_cursor = trie.cursor();
    let mut backward = Vec::new();
    let mut more = backward_cursor.seek_end();
    while more {
        backward.push(backward_cursor.key().to_vec());
        more = backward_cursor.prev();
    }
    assert_eq!(backward.len(), 200);

    backward.reverse();
    assert_eq!(forward, backward, "Forward and backward traversals must match");
}
/// A range whose bounds are the only two stored keys returns just the lower
/// one, because the upper bound is excluded.
#[test]
fn test_range_single_element() {
    let mut trie = DoubleArrayTrie::new();
    let stored: [&[u8]; 2] = [b"hello", b"world"];
    for key in stored.iter() {
        trie.insert(*key).unwrap();
    }

    let hits = trie.range(b"hello", b"world").collect::<Vec<_>>();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0], b"hello");
}
/// Lower-bound seeks among keys that differ only in their final byte.
#[test]
fn test_seek_lower_bound_between_shared_prefix() {
    let mut trie = DoubleArrayTrie::new();
    let stored: [&[u8]; 3] = [b"abc", b"abd", b"abe"];
    for key in stored.iter() {
        trie.insert(*key).unwrap();
    }

    // Each row is (probe, key the cursor must land on).
    let cases: [(&[u8], &[u8]); 3] = [
        (b"abd", b"abd"),  // exact match
        (b"abcc", b"abd"), // longer than "abc", so it sorts just after it
        (b"ab", b"abc"),   // the shared prefix sorts before all three keys
    ];

    let mut cursor = trie.cursor();
    for &(probe, landing) in cases.iter() {
        assert!(cursor.seek_lower_bound(probe));
        assert_eq!(cursor.key(), landing);
    }
}
/// Populates the trie with `term_0000`..`term_0999` and checks the number of
/// keys returned for several prefixes, as well as for `keys()` itself.
#[test]
fn test_keys_with_prefix_1000_terms() {
    let mut trie = DoubleArrayTrie::new();
    for n in 0..1000u32 {
        let term = format!("term_{:04}", n);
        trie.insert(term.as_bytes()).unwrap();
    }
    assert_eq!(trie.len(), 1000);

    // Two fixed leading digits leave 100 possible suffixes.
    let hundred_00 = trie.keys_with_prefix(b"term_00").len();
    assert_eq!(hundred_00, 100,
        "keys_with_prefix('term_00') should return 100 (term_0000..term_0099), got {}",
        hundred_00);

    let hundred_01 = trie.keys_with_prefix(b"term_01").len();
    assert_eq!(hundred_01, 100,
        "keys_with_prefix('term_01') should return 100, got {}",
        hundred_01);

    // The common stem matches every key.
    let with_stem = trie.keys_with_prefix(b"term_").len();
    assert_eq!(with_stem, 1000,
        "keys_with_prefix('term_') should return 1000, got {}",
        with_stem);

    let total = trie.keys().len();
    assert_eq!(total, 1000,
        "keys() should return 1000, got {}",
        total);
}
/// Basic map behaviour: `insert` returns the previous value (if any) and
/// `get` reflects the most recent value for each key.
#[test]
fn test_map_values_basic() {
    let mut map = DoubleArrayTrieMap::<i32>::new();

    // First insertion of a key has no previous value.
    let before_hello = map.insert(b"hello", 42).unwrap();
    assert_eq!(before_hello, None);
    assert_eq!(map.get(b"hello"), Some(42));
    assert_eq!(map.get(b"world"), None);

    map.insert(b"world", 100).unwrap();
    assert_eq!(map.get(b"world"), Some(100));

    // Overwriting hands back the old value and stores the new one.
    let replaced = map.insert(b"hello", 99).unwrap();
    assert_eq!(replaced, Some(42));
    assert_eq!(map.get(b"hello"), Some(99));
}
/// Stores 500 key/value pairs and reads every one of them back.
#[test]
fn test_map_values_many() {
    let mut map = DoubleArrayTrieMap::<i32>::new();

    for value in 0..500i32 {
        let key = format!("key_{:04}", value);
        map.insert(key.as_bytes(), value).unwrap();
    }

    for value in 0..500i32 {
        let key = format!("key_{:04}", value);
        let stored = map.get(key.as_bytes());
        assert_eq!(stored, Some(value),
            "value mismatch for key_{:04}", value);
    }

    assert_eq!(map.len(), 500);
}
/// `contains` and `get` must agree for stored keys, and a proper prefix of a
/// stored key must not be reported as present.
#[test]
fn test_map_values_with_contains() {
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let entries: [(&[u8], i32); 2] = [(b"abc", 1), (b"abd", 2)];
    for &(key, value) in entries.iter() {
        map.insert(key, value).unwrap();
    }

    for &(key, _) in entries.iter() {
        assert!(map.contains(key));
    }
    // "ab" is only a shared prefix, never inserted as a key.
    assert!(!map.contains(b"ab"));

    for &(key, value) in entries.iter() {
        assert_eq!(map.get(key), Some(value));
    }
}
/// After inserting 1000 terms, every one of them must be found by `contains`.
#[test]
fn test_consult_many_inserts() {
    let mut trie = DoubleArrayTrie::new();

    for n in 0..1000u32 {
        let term = format!("term_{:04}", n);
        trie.insert(term.as_bytes()).unwrap();
    }
    assert_eq!(trie.len(), 1000);

    for n in 0..1000u32 {
        let term = format!("term_{:04}", n);
        let present = trie.contains(term.as_bytes());
        assert!(present,
            "missing term_{:04}", n);
    }
}
/// `values_with_prefix` returns the values of all keys under a prefix; the
/// results are sorted before comparison so their order is not relied on.
#[test]
fn test_map_prefix_value_iteration() {
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let entries: [(&[u8], i32); 4] = [
        (b"app", 1),
        (b"apple", 2),
        (b"application", 3),
        (b"banana", 4),
    ];
    for &(key, value) in entries.iter() {
        map.insert(key, value).unwrap();
    }

    // Three keys start with "app" (including "app" itself).
    let mut under_app: Vec<i32> = map.values_with_prefix(b"app");
    under_app.sort_unstable();
    assert_eq!(under_app, vec![1, 2, 3]);

    // The empty prefix matches every key.
    let mut everything: Vec<i32> = map.values_with_prefix(b"");
    everything.sort_unstable();
    assert_eq!(everything, vec![1, 2, 3, 4]);

    // A prefix no key starts with yields nothing.
    let unmatched: Vec<i32> = map.values_with_prefix(b"xyz");
    assert!(unmatched.is_empty());
}
/// The callback form `for_each_value_with_prefix` must deliver the same
/// values as the collecting form: one call per key under the prefix.
#[test]
fn test_map_for_each_value_with_prefix() {
    let mut map = DoubleArrayTrieMap::<i32>::new();
    let entries: [(&[u8], i32); 4] = [
        (b"app", 1),
        (b"apple", 2),
        (b"application", 3),
        (b"banana", 4),
    ];
    for &(key, value) in entries.iter() {
        map.insert(key, value).unwrap();
    }

    // Values are gathered through the callback and sorted before comparing.
    let mut under_app = Vec::new();
    map.for_each_value_with_prefix(b"app", |value| under_app.push(value));
    under_app.sort_unstable();
    assert_eq!(under_app, vec![1, 2, 3]);

    // The empty prefix visits every entry.
    let mut everything = Vec::new();
    map.for_each_value_with_prefix(b"", |value| everything.push(value));
    everything.sort_unstable();
    assert_eq!(everything, vec![1, 2, 3, 4]);

    // A prefix without matches never invokes the callback.
    let mut unmatched = Vec::new();
    map.for_each_value_with_prefix(b"xyz", |value| unmatched.push(value));
    assert!(unmatched.is_empty());
}
/// Larger-scale map check: 5000 six-digit terms are inserted, read back, and
/// a three-digit prefix must select exactly 1000 of them.
#[test]
fn test_map_value_performance() {
    let mut map = DoubleArrayTrieMap::<i32>::new();

    for value in 0..5000i32 {
        let term = format!("term_{:06}", value);
        map.insert(term.as_bytes(), value).unwrap();
    }
    assert_eq!(map.len(), 5000);

    for value in 0..5000i32 {
        let term = format!("term_{:06}", value);
        assert_eq!(map.get(term.as_bytes()), Some(value));
    }

    // "term_001" covers term_001000..term_001999.
    let matched = map.values_with_prefix(b"term_001").len();
    assert_eq!(matched, 1000, "prefix 'term_001' should yield 1000 values, got {}", matched);
}
}
/// Regression coverage for prefix enumeration over a 1000-key set trie.
#[cfg(test)]
mod prefix_regression_tests {
    use super::*;

    /// Inserts `term_0000`..`term_0999`, verifies that all of them are
    /// present, and checks that the prefix `term_00` selects exactly 100 keys.
    #[test]
    fn test_1000_terms_prefix() {
        let mut t = DoubleArrayTrie::new();
        for i in 0..1000u32 {
            let term = format!("term_{:04}", i);
            // The previous `assert!(inserted || !inserted)` was a tautology and
            // could never fail. Only an insert error is treated as fatal here.
            // NOTE(review): if `insert` is documented to return `true` for a
            // newly added key, assert that instead of dropping the flag.
            t.insert(term.as_bytes())
                .unwrap_or_else(|e| panic!("insert failed for {}: {:?}", term, e));
        }
        assert_eq!(t.len(), 1000, "expected 1000 keys, got {}", t.len());

        // Collect every term that cannot be found so the failure message can
        // list the first few of them.
        let missing: Vec<u32> = (0..1000u32)
            .filter(|i| !t.contains(format!("term_{:04}", i).as_bytes()))
            .collect();
        assert!(missing.is_empty(), "missing {} terms: {:?}", missing.len(), &missing[..missing.len().min(20)]);

        let result = t.keys_with_prefix(b"term_00");
        assert_eq!(result.len(), 100, "prefix 'term_00' returned {} (expected 100)", result.len());
    }
}
/// Regression and edge-case tests for the trie set and map: prefix queries,
/// the empty key, removal, and growth after `shrink_to_fit`.
#[cfg(test)]
mod map_prefix_regression_tests {
    use super::*;

    /// Builds a map from 1000 pre-sorted terms in a trie created with
    /// `with_capacity`, then checks lookups and a prefix query.
    #[test]
    fn test_map_1000_terms_prefix_fresh_trie() {
        let mut entries: Vec<(String, u32)> = (0..1000u32)
            .map(|i| (format!("term_{:04}", i), i))
            .collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(&b.0));

        let mut trie = DoubleArrayTrieMap::with_capacity(entries.len());
        for (term, id) in &entries {
            trie.insert(term.as_bytes(), *id).expect("insert failed");
        }
        assert_eq!(trie.len(), 1000);

        for i in 0..1000u32 {
            let term = format!("term_{:04}", i);
            assert_eq!(trie.get(term.as_bytes()), Some(i), "get failed for {}", term);
        }

        let result = trie.values_with_prefix(b"term_00");
        assert_eq!(result.len(), 100, "values_with_prefix 'term_00' returned {} (expected 100)", result.len());
    }

    /// Every possible one-byte key can be stored and found again.
    #[test]
    fn test_all_256_single_byte_keys() {
        let mut t = DoubleArrayTrie::new();
        for b in 0u8..=255 {
            t.insert(&[b]).unwrap();
        }
        assert_eq!(t.len(), 256);
        for b in 0u8..=255 {
            assert!(t.contains(&[b]), "missing single-byte key 0x{:02x}", b);
        }
    }

    /// Inserting after `shrink_to_fit` must still work and keep old keys.
    #[test]
    fn test_insert_after_shrink() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"hello").unwrap();
        t.insert(b"world").unwrap();
        t.shrink_to_fit();
        t.insert(b"foo").unwrap();
        t.insert(b"bar").unwrap();
        assert_eq!(t.len(), 4);
        assert!(t.contains(b"hello"));
        assert!(t.contains(b"world"));
        assert!(t.contains(b"foo"));
        assert!(t.contains(b"bar"));
    }

    /// The empty key holds a value that survives a later insertion.
    #[test]
    fn test_map_values_empty_key() {
        let mut m = DoubleArrayTrieMap::<i32>::new();
        m.insert(b"", 42).unwrap();
        assert_eq!(m.get(b""), Some(42));
        m.insert(b"a", 1).unwrap();
        assert_eq!(m.get(b""), Some(42));
        assert_eq!(m.get(b"a"), Some(1));
    }

    /// A cursor created after a removal must skip the removed key.
    #[test]
    fn test_cursor_after_remove() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"a").unwrap();
        t.insert(b"b").unwrap();
        t.insert(b"c").unwrap();
        // The key was just inserted, so the removal must report success.
        assert!(t.remove(b"b"));

        let mut c = t.cursor();
        // Check the seek result rather than discarding it, so a failed seek
        // is reported here and not as a key mismatch on the next line.
        assert!(c.seek_begin());
        assert_eq!(c.key(), b"a");
        assert!(c.next());
        assert_eq!(c.key(), b"c");
        assert!(!c.next());
    }

    /// Prefix and full key listings on an empty trie are empty.
    #[test]
    fn test_keys_with_prefix_empty_trie() {
        let t = DoubleArrayTrie::new();
        assert_eq!(t.keys_with_prefix(b"").len(), 0);
        assert_eq!(t.keys_with_prefix(b"anything").len(), 0);
        assert_eq!(t.keys().len(), 0);
    }

    /// Same empty-key scenario as above with `u32` values, additionally
    /// checking that the empty key counts towards `len()`.
    #[test]
    fn test_map_empty_key() {
        let mut map = DoubleArrayTrieMap::<u32>::new();
        map.insert(b"", 99).unwrap();
        assert_eq!(map.get(b""), Some(99));
        assert_eq!(map.len(), 1);
        map.insert(b"x", 1).unwrap();
        assert_eq!(map.get(b""), Some(99));
        assert_eq!(map.get(b"x"), Some(1));
    }

    /// Removing every key empties the trie, and the same keys can then be
    /// inserted again.
    #[test]
    fn test_remove_all_then_reinsert() {
        let mut t = DoubleArrayTrie::new();
        for i in 0..50u32 {
            t.insert(format!("k{}", i).as_bytes()).unwrap();
        }
        assert_eq!(t.len(), 50);

        for i in 0..50u32 {
            assert!(t.remove(format!("k{}", i).as_bytes()));
        }
        assert_eq!(t.len(), 0);
        assert!(t.is_empty());

        for i in 0..50u32 {
            t.insert(format!("k{}", i).as_bytes()).unwrap();
        }
        assert_eq!(t.len(), 50);
        for i in 0..50u32 {
            assert!(t.contains(format!("k{}", i).as_bytes()));
        }
    }

    /// Range iteration must not yield keys that have been removed.
    #[test]
    fn test_range_after_remove() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"a").unwrap();
        t.insert(b"b").unwrap();
        t.insert(b"c").unwrap();
        t.insert(b"d").unwrap();
        // Both keys exist, so each removal must report success.
        assert!(t.remove(b"b"));
        assert!(t.remove(b"c"));

        let range: Vec<Vec<u8>> = t.range(b"a", b"z").collect();
        assert_eq!(range.len(), 2);
        assert_eq!(range[0], b"a");
        assert_eq!(range[1], b"d");
    }

    /// An empty-prefix query on an empty map returns no values.
    #[test]
    fn test_map_values_empty_prefix_empty() {
        let m = DoubleArrayTrieMap::<i32>::new();
        let vals = m.values_with_prefix(b"");
        assert_eq!(vals.len(), 0);
    }
}