use crate::error::{Result, ZiporaError};
/// One node of the double array, packed into two 32-bit words (8 bytes total).
///
/// - `child0`: the low 31 bits hold the base index of this node's children (`NIL_STATE`
///   when it has none); the high bit (`TERM_BIT`) marks the node as the end of a stored key.
/// - `parent`: the low 31 bits hold the index of the parent node; the high bit (`FREE_BIT`)
///   marks the slot as unused.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct DaState {
    child0: u32,
    parent: u32,
}

/// Terminal flag, kept in the high bit of `child0`.
const TERM_BIT: u32 = 0x8000_0000;
/// Free-slot flag, kept in the high bit of `parent`.
const FREE_BIT: u32 = 0x8000_0000;
/// Selects the 31-bit index payload of either word.
const VALUE_MASK: u32 = 0x7FFF_FFFF;
/// Sentinel index meaning "no state" (every payload bit set).
const NIL_STATE: u32 = 0x7FFF_FFFF;
/// Upper limit used by the slot allocator; one below `NIL_STATE` so it never equals the sentinel.
const MAX_STATE: u32 = 0x7FFF_FFFE;

impl DaState {
    /// A vacant slot: no children, no parent, free flag raised.
    #[inline(always)]
    const fn new_free() -> Self {
        DaState {
            child0: NIL_STATE,
            parent: FREE_BIT | NIL_STATE,
        }
    }

    /// The root node: no children yet, parent index 0, not free, not terminal.
    #[inline(always)]
    const fn new_root() -> Self {
        DaState {
            parent: 0,
            child0: NIL_STATE,
        }
    }

    /// Base index of the children, with the terminal flag stripped.
    #[inline(always)]
    fn child0(&self) -> u32 {
        VALUE_MASK & self.child0
    }

    /// Parent index, with the free flag stripped.
    #[inline(always)]
    fn parent(&self) -> u32 {
        VALUE_MASK & self.parent
    }

    /// Whether this node ends a stored key.
    #[inline(always)]
    fn is_term(&self) -> bool {
        (self.child0 & TERM_BIT) == TERM_BIT
    }

    /// Whether this slot is unused.
    #[inline(always)]
    fn is_free(&self) -> bool {
        (self.parent & FREE_BIT) == FREE_BIT
    }

    /// Marks this node as the end of a key; the base index is untouched.
    #[inline(always)]
    fn set_term_bit(&mut self) {
        self.child0 |= TERM_BIT;
    }

    /// Unmarks this node as a key end; the base index is untouched.
    #[inline(always)]
    fn clear_term_bit(&mut self) {
        // `VALUE_MASK` is exactly `!TERM_BIT`, so this drops only the terminal flag.
        self.child0 &= VALUE_MASK;
    }

    /// Replaces the base index while preserving the terminal flag.
    #[inline(always)]
    fn set_child0(&mut self, val: u32) {
        let term = self.child0 & TERM_BIT;
        self.child0 = term | (val & VALUE_MASK);
    }

    /// Stores a parent index; writing a parent claims the slot, so the free flag is dropped.
    #[inline(always)]
    fn set_parent(&mut self, val: u32) {
        self.parent = val & VALUE_MASK;
    }

    /// Returns the slot to the vacant state (clears base, terminal flag and parent).
    #[inline(always)]
    fn set_free(&mut self) {
        *self = Self::new_free();
    }
}
pub struct DoubleArrayTrie {
states: Vec<DaState>,
num_keys: usize,
search_head: usize,
}
impl DoubleArrayTrie {
pub fn new() -> Self {
Self::with_capacity(256)
}
pub fn with_capacity(capacity: usize) -> Self {
let cap = capacity.max(2); let mut states = Vec::with_capacity(cap);
states.push(DaState::new_root());
states.resize(cap, DaState::new_free());
Self {
states,
num_keys: 0,
search_head: 1,
}
}
#[inline(always)]
pub fn len(&self) -> usize { self.num_keys }
#[inline(always)]
pub fn is_empty(&self) -> bool { self.num_keys == 0 }
#[inline]
pub fn total_states(&self) -> usize { self.states.len() }
#[inline]
pub fn mem_size(&self) -> usize {
self.states.len() * std::mem::size_of::<DaState>()
}
#[inline(always)]
pub fn is_term(&self, state: u32) -> bool {
(state as usize) < self.states.len() && self.states[state as usize].is_term()
}
#[inline(always)]
pub fn is_free(&self, state: u32) -> bool {
(state as usize) >= self.states.len() || self.states[state as usize].is_free()
}
#[inline(always)]
pub fn state_move(&self, curr: u32, ch: u8) -> u32 {
let base = self.states[curr as usize].child0();
if base == NIL_STATE { return NIL_STATE; }
let next = base as usize + ch as usize;
if next >= self.states.len() { return NIL_STATE; }
if self.states[next].is_free() { return NIL_STATE; }
if self.states[next].parent() == curr {
next as u32
} else {
NIL_STATE
}
}
pub fn insert(&mut self, key: &[u8]) -> Result<bool> {
if key.is_empty() {
let was_new = !self.states[0].is_term();
self.states[0].set_term_bit();
if was_new { self.num_keys += 1; }
return Ok(was_new);
}
let mut curr = 0u32;
for &ch in key {
let base = self.states[curr as usize].child0();
if base == NIL_STATE {
let new_base = self.find_free_base(&[ch])?;
self.states[curr as usize].set_child0(new_base);
let next = new_base + ch as u32;
self.ensure_capacity(next as usize + 1);
self.states[next as usize].set_parent(curr);
curr = next;
} else {
let next = base + ch as u32;
self.ensure_capacity(next as usize + 1);
if !self.states[next as usize].is_free()
&& self.states[next as usize].parent() == curr
{
curr = next;
} else if self.states[next as usize].is_free() {
self.states[next as usize].set_parent(curr);
curr = next;
} else {
let new_base = self.relocate(curr, ch)?;
let next = new_base + ch as u32;
self.ensure_capacity(next as usize + 1);
self.states[next as usize].set_parent(curr);
curr = next;
}
}
}
let was_new = !self.states[curr as usize].is_term();
self.states[curr as usize].set_term_bit();
if was_new { self.num_keys += 1; }
Ok(was_new)
}
#[inline]
pub fn contains(&self, key: &[u8]) -> bool {
let states = self.states.as_slice();
let len = states.len();
if key.is_empty() {
return states[0].is_term();
}
let mut curr = 0usize;
for &ch in key {
let base = states[curr].child0();
if base == NIL_STATE { return false; }
let next = base as usize + ch as usize;
if next >= len { return false; }
if states[next].parent != curr as u32 { return false; }
curr = next;
}
states[curr].is_term()
}
#[inline]
pub fn lookup_state(&self, key: &[u8]) -> Option<u32> {
let states = self.states.as_slice();
let len = states.len();
if key.is_empty() {
return if states[0].is_term() { Some(0) } else { None };
}
let mut curr = 0usize;
for &ch in key {
let base = states[curr].child0();
if base == NIL_STATE { return None; }
let next = base as usize + ch as usize;
if next >= len { return None; }
if states[next].parent != curr as u32 { return None; }
curr = next;
}
if states[curr].is_term() { Some(curr as u32) } else { None }
}
pub fn remove(&mut self, key: &[u8]) -> bool {
if let Some(state) = self.lookup_state(key) {
if self.states[state as usize].is_term() {
self.states[state as usize].clear_term_bit();
self.num_keys -= 1;
return true;
}
}
false
}
pub fn restore_key(&self, state: u32) -> Option<Vec<u8>> {
if state as usize >= self.states.len() { return None; }
if self.states[state as usize].is_free() { return None; }
let mut symbols = Vec::new();
let mut curr = state;
while curr != 0 {
let parent = self.states[curr as usize].parent();
let parent_base = self.states[parent as usize].child0();
if curr < parent_base { return None; }
let symbol = (curr - parent_base) as u8;
symbols.push(symbol);
curr = parent;
}
symbols.reverse();
Some(symbols)
}
pub fn keys(&self) -> Vec<Vec<u8>> {
let mut result = Vec::with_capacity(self.num_keys);
let mut path = Vec::new();
self.collect_keys(0, &mut path, &mut result);
result
}
pub fn keys_with_prefix(&self, prefix: &[u8]) -> Vec<Vec<u8>> {
let mut curr = 0u32;
for &ch in prefix {
let next = self.state_move(curr, ch);
if next == NIL_STATE { return Vec::new(); }
curr = next;
}
let mut result = Vec::new();
let mut path = prefix.to_vec();
self.collect_keys(curr, &mut path, &mut result);
result
}
#[inline]
pub fn for_each_child(&self, state: u32, mut f: impl FnMut(u8, u32)) {
let base = self.states[state as usize].child0();
if base == NIL_STATE { return; }
for ch in 0u16..=255u16 {
let next = base as usize + ch as usize;
if next >= self.states.len() { break; }
if !self.states[next].is_free() && self.states[next].parent() == state {
f(ch as u8, next as u32);
}
}
}
pub fn build_from_sorted(keys: &[&[u8]]) -> Result<Self> {
if keys.is_empty() { return Ok(Self::new()); }
let total_bytes: usize = keys.iter().map(|k| k.len()).sum();
let estimated_states = (total_bytes / 2).max(256);
let mut trie = Self::with_capacity(estimated_states * 3 / 2);
for &key in keys {
trie.insert(key)?;
}
trie.shrink_to_fit();
Ok(trie)
}
pub fn shrink_to_fit(&mut self) {
let last_used = self.states.iter().rposition(|s| !s.is_free()).unwrap_or(0);
let new_len = (last_used + 257).min(self.states.len());
self.states.truncate(new_len);
self.states.shrink_to_fit();
}
#[inline]
fn ensure_capacity(&mut self, required: usize) {
if required <= self.states.len() { return; }
let new_len = required.max(self.states.len() * 3 / 2).max(256);
self.states.resize(new_len, DaState::new_free());
}
fn find_free_base(&mut self, children: &[u8]) -> Result<u32> {
debug_assert!(!children.is_empty());
let min_ch = *children.iter().min().unwrap() as u32;
let max_ch = *children.iter().max().unwrap() as u32;
let single_child = children.len() == 1;
let mut base = if self.search_head as u32 > min_ch {
self.search_head as u32 - min_ch
} else {
1
};
let mut attempts = 0u32;
loop {
if attempts > 1_000_000 || base > MAX_STATE {
return Err(ZiporaError::invalid_data("Double array: cannot find free base"));
}
attempts += 1;
let max_pos = base + max_ch;
self.ensure_capacity(max_pos as usize + 1);
if single_child {
let pos = (base + min_ch) as usize;
if pos > 0 && self.states[pos].is_free() {
if base as usize > self.search_head {
self.search_head += ((base as usize - self.search_head) >> 4).max(1);
}
return Ok(base);
}
base += 1;
continue;
}
let all_free = children.iter().all(|&ch| {
let pos = (base + ch as u32) as usize;
pos > 0 && self.states[pos].is_free()
});
if all_free {
if base as usize > self.search_head {
self.search_head += ((base as usize - self.search_head) >> 4).max(1);
}
return Ok(base);
}
base += 1;
}
}
fn relocate(&mut self, state: u32, new_ch: u8) -> Result<u32> {
let old_base = self.states[state as usize].child0();
let mut children_symbols = Vec::new();
if old_base != NIL_STATE {
for ch in 0u16..=255u16 {
let pos = old_base as usize + ch as usize;
if pos >= self.states.len() { break; }
if !self.states[pos].is_free() && self.states[pos].parent() == state {
children_symbols.push(ch as u8);
}
}
}
if !children_symbols.contains(&new_ch) {
children_symbols.push(new_ch);
}
children_symbols.sort_unstable();
let new_base = self.find_free_base(&children_symbols)?;
if old_base != NIL_STATE {
for &ch in &children_symbols {
if ch == new_ch { continue; } let old_pos = old_base + ch as u32;
let new_pos = new_base + ch as u32;
if old_pos as usize >= self.states.len() { continue; }
if self.states[old_pos as usize].is_free() { continue; }
if self.states[old_pos as usize].parent() != state { continue; }
self.ensure_capacity(new_pos as usize + 1);
let old_state = self.states[old_pos as usize];
self.states[new_pos as usize].child0 = old_state.child0;
self.states[new_pos as usize].set_parent(state);
let child_base = old_state.child0();
if child_base != NIL_STATE {
for gch in 0u16..=255u16 {
let gpos = child_base as usize + gch as usize;
if gpos >= self.states.len() { break; }
if !self.states[gpos].is_free()
&& self.states[gpos].parent() == old_pos
{
self.states[gpos].set_parent(new_pos);
}
}
}
self.states[old_pos as usize].set_free();
}
}
self.states[state as usize].set_child0(new_base);
Ok(new_base)
}
fn collect_keys(&self, state: u32, path: &mut Vec<u8>, keys: &mut Vec<Vec<u8>>) {
if state as usize >= self.states.len() { return; }
if self.states[state as usize].is_term() {
keys.push(path.clone());
}
let base = self.states[state as usize].child0();
if base == NIL_STATE { return; }
for ch in 0u16..=255u16 {
let next = base as usize + ch as usize;
if next >= self.states.len() { break; }
if !self.states[next].is_free() && self.states[next].parent() == state {
path.push(ch as u8);
self.collect_keys(next as u32, path, keys);
path.pop();
}
}
}
}
impl Default for DoubleArrayTrie {
fn default() -> Self { Self::new() }
}
impl std::fmt::Debug for DoubleArrayTrie {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DoubleArrayTrie")
.field("num_keys", &self.num_keys)
.field("total_states", &self.states.len())
.field("mem_size", &self.mem_size())
.finish()
}
}
pub struct DoubleArrayTrieMap<V: Copy> {
trie: DoubleArrayTrie,
values: Vec<Option<V>>,
}
impl<V: Copy> DoubleArrayTrieMap<V> {
pub fn new() -> Self {
Self { trie: DoubleArrayTrie::new(), values: Vec::new() }
}
pub fn with_capacity(cap: usize) -> Self {
Self { trie: DoubleArrayTrie::with_capacity(cap), values: Vec::with_capacity(cap) }
}
pub fn insert(&mut self, key: &[u8], value: V) -> Result<Option<V>> {
self.trie.insert(key)?;
let state = self.trie.lookup_state(key)
.ok_or_else(|| ZiporaError::invalid_state("insert succeeded but lookup failed"))?;
let idx = state as usize;
if idx >= self.values.len() {
self.values.resize(idx + 1, None);
}
let prev = self.values[idx];
self.values[idx] = Some(value);
Ok(prev)
}
#[inline]
pub fn get(&self, key: &[u8]) -> Option<V> {
let state = self.trie.lookup_state(key)?;
self.values.get(state as usize).and_then(|v| *v)
}
#[inline]
pub fn contains(&self, key: &[u8]) -> bool { self.trie.contains(key) }
#[inline]
pub fn len(&self) -> usize { self.trie.len() }
#[inline]
pub fn is_empty(&self) -> bool { self.trie.is_empty() }
pub fn remove(&mut self, key: &[u8]) -> Option<V> {
let state = self.trie.lookup_state(key)?;
let prev = self.values.get(state as usize).and_then(|v| *v);
self.trie.remove(key);
if let Some(slot) = self.values.get_mut(state as usize) {
*slot = None;
}
prev
}
}
impl<V: Copy> Default for DoubleArrayTrieMap<V> {
fn default() -> Self { Self::new() }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserted keys are found; proper prefixes and extensions of them are not.
    #[test]
    fn test_basic_insert_contains() {
        let mut t = DoubleArrayTrie::new();
        assert!(t.insert(b"hello").unwrap());
        assert!(t.insert(b"help").unwrap());
        assert!(t.insert(b"world").unwrap());
        assert_eq!(t.len(), 3);
        assert!(t.contains(b"hello"));
        assert!(t.contains(b"help"));
        assert!(t.contains(b"world"));
        assert!(!t.contains(b"hel"));
        assert!(!t.contains(b"hell"));
        assert!(!t.contains(b"worlds"));
    }

    /// Re-inserting an existing key reports `false` and leaves the count unchanged.
    #[test]
    fn test_duplicate_insert() {
        let mut t = DoubleArrayTrie::new();
        assert!(t.insert(b"abc").unwrap());
        assert!(!t.insert(b"abc").unwrap());
        assert!(!t.insert(b"abc").unwrap());
        assert_eq!(t.len(), 1);
    }

    /// The empty key is stored on the root and enumerated like any other key.
    #[test]
    fn test_empty_key() {
        let mut t = DoubleArrayTrie::new();
        assert!(t.insert(b"").unwrap());
        assert!(t.contains(b""));
        assert!(t.insert(b"a").unwrap());
        assert_eq!(t.len(), 2);
        let mut keys = t.keys();
        keys.sort();
        assert_eq!(keys, vec![vec![], vec![b'a']]);
    }

    /// Removing a key affects only that key; removing a missing key reports `false`.
    #[test]
    fn test_remove() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"hello").unwrap();
        t.insert(b"world").unwrap();
        assert_eq!(t.len(), 2);
        assert!(t.remove(b"hello"));
        assert_eq!(t.len(), 1);
        assert!(!t.contains(b"hello"));
        assert!(t.contains(b"world"));
        assert!(!t.remove(b"missing"));
    }

    /// A terminal state can be mapped back to its key.
    #[test]
    fn test_restore_key() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"hello").unwrap();
        t.insert(b"world").unwrap();
        let state = t.lookup_state(b"hello").unwrap();
        assert_eq!(t.restore_key(state).unwrap(), b"hello");
        let state2 = t.lookup_state(b"world").unwrap();
        assert_eq!(t.restore_key(state2).unwrap(), b"world");
    }

    /// `keys` returns every stored key, including keys that are prefixes of others.
    #[test]
    fn test_keys() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"apple").unwrap();
        t.insert(b"app").unwrap();
        t.insert(b"banana").unwrap();
        let mut keys = t.keys();
        keys.sort();
        assert_eq!(keys.len(), 3);
        assert_eq!(keys[0], b"app");
        assert_eq!(keys[1], b"apple");
        assert_eq!(keys[2], b"banana");
    }

    /// Prefix enumeration includes the prefix itself and handles empty/missing prefixes.
    #[test]
    fn test_keys_with_prefix() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"").unwrap();
        t.insert(b"a").unwrap();
        t.insert(b"ab").unwrap();
        t.insert(b"abc").unwrap();
        t.insert(b"abd").unwrap();
        t.insert(b"b").unwrap();
        let all = t.keys_with_prefix(b"");
        assert_eq!(all.len(), 6);
        let a = t.keys_with_prefix(b"a");
        assert_eq!(a.len(), 4);
        let ab = t.keys_with_prefix(b"ab");
        assert_eq!(ab.len(), 3);
        let none = t.keys_with_prefix(b"xyz");
        assert_eq!(none.len(), 0);
    }

    /// A thousand keys sharing a common prefix remain retrievable.
    #[test]
    fn test_many_inserts() {
        let mut t = DoubleArrayTrie::new();
        for i in 0..1000 {
            t.insert(format!("key_{:04}", i).as_bytes()).unwrap();
        }
        assert_eq!(t.len(), 1000);
        assert!(t.contains(b"key_0000"));
        assert!(t.contains(b"key_0500"));
        assert!(t.contains(b"key_0999"));
        assert!(!t.contains(b"key_1000"));
    }

    /// Manual traversal with `state_move` reaches the terminal state and rejects bad edges.
    #[test]
    fn test_state_move() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"abc").unwrap();
        let s1 = t.state_move(0, b'a');
        assert_ne!(s1, NIL_STATE);
        let s2 = t.state_move(s1, b'b');
        assert_ne!(s2, NIL_STATE);
        let s3 = t.state_move(s2, b'c');
        assert_ne!(s3, NIL_STATE);
        assert!(t.is_term(s3));
        assert_eq!(t.state_move(0, b'z'), NIL_STATE);
    }

    /// Bulk construction keeps every input key.
    #[test]
    fn test_build_from_sorted() {
        let keys: Vec<&[u8]> = vec![b"apple", b"application", b"apply", b"banana", b"band"];
        let t = DoubleArrayTrie::build_from_sorted(&keys).unwrap();
        assert_eq!(t.len(), 5);
        for key in &keys {
            assert!(t.contains(key), "missing: {:?}", std::str::from_utf8(key));
        }
    }

    /// Child enumeration reports exactly the outgoing edges of a state.
    #[test]
    fn test_for_each_child() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"ab").unwrap();
        t.insert(b"ac").unwrap();
        t.insert(b"ad").unwrap();
        let mut root_children = Vec::new();
        t.for_each_child(0, |ch, _| root_children.push(ch));
        assert_eq!(root_children, vec![b'a']);
        let a_state = t.state_move(0, b'a');
        let mut a_children = Vec::new();
        t.for_each_child(a_state, |ch, _| a_children.push(ch));
        a_children.sort();
        assert_eq!(a_children, vec![b'b', b'c', b'd']);
    }

    /// Map insert/get/overwrite/remove round-trip.
    #[test]
    fn test_da_trie_map() {
        let mut map = DoubleArrayTrieMap::<u32>::new();
        map.insert(b"hello", 42).unwrap();
        map.insert(b"world", 100).unwrap();
        map.insert(b"help", 7).unwrap();
        assert_eq!(map.get(b"hello"), Some(42));
        assert_eq!(map.get(b"world"), Some(100));
        assert_eq!(map.get(b"help"), Some(7));
        assert_eq!(map.get(b"missing"), None);
        let prev = map.insert(b"hello", 99).unwrap();
        assert_eq!(prev, Some(42));
        assert_eq!(map.get(b"hello"), Some(99));
        let removed = map.remove(b"world");
        assert_eq!(removed, Some(100));
        assert_eq!(map.get(b"world"), None);
        assert_eq!(map.len(), 2);
    }

    /// Each state is 8 bytes and a fresh trie holds 256 states.
    #[test]
    fn test_mem_size() {
        let t = DoubleArrayTrie::new();
        assert_eq!(t.mem_size(), 256 * 8);
        assert_eq!(std::mem::size_of::<DaState>(), 8);
    }

    /// Keys with heavily shared prefixes stay distinct from their common prefixes.
    #[test]
    fn test_shared_prefixes() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"test").unwrap();
        t.insert(b"testing").unwrap();
        t.insert(b"tested").unwrap();
        t.insert(b"tester").unwrap();
        t.insert(b"tests").unwrap();
        t.insert(b"tea").unwrap();
        t.insert(b"team").unwrap();
        t.insert(b"tear").unwrap();
        assert_eq!(t.len(), 8);
        assert!(t.contains(b"test"));
        assert!(t.contains(b"testing"));
        assert!(t.contains(b"tea"));
        assert!(t.contains(b"team"));
        assert!(!t.contains(b"te"));
        assert!(!t.contains(b"testi"));
    }

    /// A 1000-byte key can be stored, found, and restored.
    #[test]
    fn test_long_keys() {
        let mut t = DoubleArrayTrie::new();
        let long_key = "a".repeat(1000);
        t.insert(long_key.as_bytes()).unwrap();
        assert!(t.contains(long_key.as_bytes()));
        assert_eq!(t.len(), 1);
        let state = t.lookup_state(long_key.as_bytes()).unwrap();
        let restored = t.restore_key(state).unwrap();
        assert_eq!(restored, long_key.as_bytes());
    }

    /// Arbitrary byte values (including 0x00 and 0xFF) work as key bytes.
    #[test]
    fn test_binary_keys() {
        let mut t = DoubleArrayTrie::new();
        t.insert(&[0x00, 0xFF, 0x80]).unwrap();
        t.insert(&[0x00, 0xFF, 0x81]).unwrap();
        t.insert(&[0xFF, 0x00, 0x01]).unwrap();
        assert_eq!(t.len(), 3);
        assert!(t.contains(&[0x00, 0xFF, 0x80]));
        assert!(t.contains(&[0xFF, 0x00, 0x01]));
        assert!(!t.contains(&[0x00, 0xFF]));
    }

    /// Many root children force repeated relocation without losing keys.
    #[test]
    fn test_relocation_stress() {
        let mut t = DoubleArrayTrie::new();
        for ch in 0u8..=127u8 {
            let key = [ch];
            t.insert(&key).unwrap();
        }
        assert_eq!(t.len(), 128);
        for ch in 0u8..=127u8 {
            assert!(t.contains(&[ch]), "missing single-byte key {}", ch);
        }
    }

    /// Shrinking drops unused capacity but keeps every key reachable.
    #[test]
    fn test_shrink_to_fit() {
        let mut t = DoubleArrayTrie::with_capacity(10000);
        assert!(t.total_states() >= 10000);
        t.insert(b"hello").unwrap();
        t.insert(b"world").unwrap();
        t.shrink_to_fit();
        assert!(t.total_states() < 1000);
        assert!(t.contains(b"hello"));
        assert!(t.contains(b"world"));
    }

    /// A removed key can be inserted again and counts as new.
    #[test]
    fn test_remove_and_reinsert() {
        let mut t = DoubleArrayTrie::new();
        t.insert(b"abc").unwrap();
        assert!(t.contains(b"abc"));
        t.remove(b"abc");
        assert!(!t.contains(b"abc"));
        assert_eq!(t.len(), 0);
        assert!(t.insert(b"abc").unwrap());
        assert!(t.contains(b"abc"));
        assert_eq!(t.len(), 1);
    }

    /// Distinct keys resolve to distinct terminal states.
    #[test]
    fn test_lookup_state_consistency() {
        let mut t = DoubleArrayTrie::new();
        let keys: Vec<&[u8]> = vec![b"alpha", b"beta", b"gamma", b"delta"];
        for &key in &keys {
            t.insert(key).unwrap();
        }
        let states: Vec<u32> = keys
            .iter()
            .map(|k| t.lookup_state(k).unwrap())
            .collect();
        for i in 0..states.len() {
            for j in (i + 1)..states.len() {
                assert_ne!(
                    states[i],
                    states[j],
                    "states for {:?} and {:?} should differ",
                    std::str::from_utf8(keys[i]).unwrap(),
                    std::str::from_utf8(keys[j]).unwrap()
                );
            }
        }
    }

    /// Smoke benchmark: 5000 inserts, hits and misses.
    ///
    /// Timings are reported in every build profile so the measurements are always used
    /// (no unused-variable warnings in debug). Only optimized builds enforce the limits.
    #[test]
    fn test_performance_5000_terms() {
        let terms: Vec<String> = (0..5000)
            .map(|i| format!("term_{:06}_{}", i, ["alpha", "beta", "gamma", "delta"][i % 4]))
            .collect();
        let start = std::time::Instant::now();
        let mut t = DoubleArrayTrie::new();
        for term in &terms {
            t.insert(term.as_bytes()).unwrap();
        }
        let insert_time = start.elapsed();
        assert_eq!(t.len(), 5000);
        let start = std::time::Instant::now();
        for term in &terms {
            assert!(t.contains(term.as_bytes()));
        }
        let lookup_time = start.elapsed();
        let start = std::time::Instant::now();
        for i in 0..5000 {
            let miss = format!("miss_{:06}", i);
            assert!(!t.contains(miss.as_bytes()));
        }
        let miss_time = start.elapsed();
        eprintln!(
            "DoubleArrayTrie 5000 terms: insert={:?}, lookup_hit={:?}, lookup_miss={:?}",
            insert_time, lookup_time, miss_time
        );
        eprintln!(
            "Memory: {} bytes ({} bytes/key), {} states",
            t.mem_size(),
            t.mem_size() / 5000,
            t.total_states()
        );
        // Wall-clock limits are only meaningful for optimized builds.
        #[cfg(not(debug_assertions))]
        {
            assert!(insert_time.as_millis() < 50, "Insert too slow: {:?}", insert_time);
            assert!(lookup_time.as_millis() < 10, "Lookup too slow: {:?}", lookup_time);
        }
    }
}