use crate::double_array_trie_zipper::DoubleArrayTrieZipper;
use crate::iterator::{DictionaryIterator, DictionaryTermIterator};
use crate::value::DictionaryValue;
use crate::{Dictionary, DictionaryNode, MappedDictionary, MappedDictionaryNode};
use std::sync::Arc;
#[cfg(feature = "serialization")]
#[allow(unused_imports)]
use crate::serialization::serde_helpers::{
deserialize_arc_vec, deserialize_arc_vec_vec, serialize_arc_vec, serialize_arc_vec_vec,
};
/// Crate-internal alias for the shared trie storage (`DATCoreShared`) with `u8` as its first
/// type parameter; `V` is the per-term value type and defaults to `()`.
pub(crate) type DATShared<V = ()> = crate::dat_core::DATCoreShared<u8, V>;
/// Double-array trie keyed by the bytes of its terms, optionally storing a value of type `V`
/// per term (`V` defaults to `()`).
///
/// Instances are produced by `DoubleArrayTrieBuilder::build`; the methods defined on this type
/// in this file are all read-only (lookup, counting, iteration).
///
/// With the `serialization` feature the type derives serde's `Serialize`/`Deserialize`. The
/// serde bounds on `V` are emitted only when `persistent-artrie` is disabled; with that feature
/// enabled the bound is left empty.
// NOTE(review): the reason for the empty bound under `persistent-artrie` is not visible here — confirm.
#[cfg_attr(
feature = "serialization",
derive(serde::Serialize, serde::Deserialize)
)]
#[cfg_attr(
all(feature = "serialization", not(feature = "persistent-artrie")),
serde(bound(serialize = "V: serde::Serialize")),
serde(bound(deserialize = "V: serde::Deserialize<'de>"))
)]
#[cfg_attr(
all(feature = "serialization", feature = "persistent-artrie"),
serde(bound = "")
)]
#[derive(Clone, Debug)]
pub struct DoubleArrayTrie<V: DictionaryValue = ()> {
/// Reference-counted arrays (`base`, `check`, `is_final`, `edges`, `values`) that make up
/// the trie; also cloned into every node handed out by `root()`.
pub(crate) shared: DATShared<V>,
/// Free-slot list carried over from the builder at build time. Nothing in this file reads it.
#[allow(dead_code)]
#[cfg_attr(
feature = "serialization",
serde(
serialize_with = "serialize_arc_vec",
deserialize_with = "deserialize_arc_vec"
)
)]
free_list: Arc<Vec<usize>>,
/// Number of distinct terms stored.
term_count: usize,
/// Threshold carried over from the builder (clamped there to `0.0..=1.0`). Nothing in this
/// file reads it.
#[allow(dead_code)]
rebuild_threshold: f64,
}
/// Mutable builder that accumulates terms (and optional values) into double-array form and is
/// then converted into a `DoubleArrayTrie` with `build()`.
///
/// All four per-state arrays (`base`, `check`, `is_final`, `values`) are indexed by state
/// number; state 1 is the root.
pub struct DoubleArrayTrieBuilder<V: DictionaryValue = ()> {
// Per-state offset: the child reached via byte `b` lives at `base[state] + b`.
// A negative entry means the state has no children yet.
base: Vec<i32>,
// Per-slot owner: `check[slot]` is the parent state of `slot`, or negative when unowned.
check: Vec<i32>,
// Whether a term ends at the state.
is_final: Vec<bool>,
// Value attached to the term ending at the state, if any.
values: Vec<Option<V>>,
// Starts empty and is handed to the built trie unchanged; nothing in this file fills it.
free_list: Vec<usize>,
// Number of distinct terms inserted so far.
term_count: usize,
// Stored configuration value, clamped to `0.0..=1.0` by `with_rebuild_threshold`.
rebuild_threshold: f64,
}
impl<V: DictionaryValue> DoubleArrayTrieBuilder<V> {
/// Creates an empty builder.
///
/// Every per-state array starts with two slots. Slot 1 is the root state that all insertions
/// start from (base `0`, not final, no value). The rebuild threshold defaults to `0.2`.
pub fn new() -> Self {
    Self {
        base: vec![-1, 0],
        check: vec![-1; 2],
        is_final: vec![false; 2],
        values: vec![None, None],
        free_list: Vec::new(),
        term_count: 0,
        rebuild_threshold: 0.2,
    }
}
pub fn with_rebuild_threshold(mut self, threshold: f64) -> Self {
self.rebuild_threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn insert(&mut self, term: &str) -> bool {
self.insert_with_value(term, None)
}
/// Inserts `term`, optionally attaching `value` to it.
///
/// Returns `true` when the term was not present before (the term count is incremented) and
/// `false` when it already existed. For an existing term, a `Some` value replaces the stored
/// one while `None` leaves the stored value untouched.
///
/// The empty term is stored on the root state (state 1).
pub fn insert_with_value(&mut self, term: &str, value: Option<V>) -> bool {
    // Walk from the root, creating any missing edges. For the empty term the loop does not
    // run and the cursor stays on the root.
    let mut cursor: usize = 1;
    for label in term.bytes() {
        cursor = match self.transition(cursor, label) {
            Some(child) => child,
            None => self.add_transition(cursor, label),
        };
    }

    let already_present = self.is_final.get(cursor).copied().unwrap_or(false);
    if already_present {
        if value.is_some() {
            if let Some(slot) = self.values.get_mut(cursor) {
                *slot = value;
            }
        }
        return false;
    }

    // Make sure both per-state arrays reach the final state before writing to it.
    if self.is_final.len() <= cursor {
        self.is_final.resize(cursor + 1, false);
    }
    if self.values.len() <= cursor {
        self.values.resize_with(cursor + 1, || None);
    }
    self.is_final[cursor] = true;
    self.values[cursor] = value;
    self.term_count += 1;
    true
}
/// Follows the edge labelled `byte` out of `state`.
///
/// Returns the child state's index, or `None` when `state` is out of range, has no base yet,
/// or the target slot is not owned by `state`.
fn transition(&self, state: usize, byte: u8) -> Option<usize> {
    let offset = *self.base.get(state)?;
    if offset < 0 {
        return None;
    }
    let target = (offset as usize).wrapping_add(byte as usize);
    match self.check.get(target) {
        Some(&owner) if owner == state as i32 => Some(target),
        _ => None,
    }
}
fn add_transition(&mut self, state: usize, byte: u8) -> usize {
while state >= self.base.len() {
self.base.push(-1);
self.check.push(-1);
self.is_final.push(false);
self.values.push(None);
}
let next_state = if self.base[state] < 0 {
let start = (state * 31) % 1000 + byte as usize;
let base = self.find_free_base(start, &[byte]);
self.base[state] = base;
(base as usize).wrapping_add(byte as usize)
} else {
(self.base[state] as usize).wrapping_add(byte as usize)
};
while next_state >= self.check.len() {
self.base.push(-1);
self.check.push(-1);
self.is_final.push(false);
self.values.push(None);
}
if self.check[next_state] >= 0 {
let mut all_bytes = Vec::new();
let old_base = self.base[state];
for b in 0u8..=255 {
let child = (old_base as usize).wrapping_add(b as usize);
if child < self.check.len() && self.check[child] == state as i32 {
all_bytes.push(b);
}
}
all_bytes.push(byte);
let new_base = self.find_free_base(next_state + 1, &all_bytes);
for &b in &all_bytes {
if b == byte {
continue; }
let old_child = (old_base as usize).wrapping_add(b as usize);
let new_child = (new_base as usize).wrapping_add(b as usize);
while new_child >= self.check.len() {
self.base.push(-1);
self.check.push(-1);
self.is_final.push(false);
self.values.push(None);
}
self.check[new_child] = state as i32; self.base[new_child] = self.base[old_child];
self.is_final[new_child] = self.is_final[old_child];
if old_child < self.values.len() {
while new_child >= self.values.len() {
self.values.push(None);
}
self.values[new_child] = self.values[old_child].clone();
}
if self.base[old_child] >= 0 {
let child_base = self.base[old_child] as usize;
for gc_byte in 0u8..=255 {
let grandchild = child_base + (gc_byte as usize);
if grandchild < self.check.len()
&& self.check[grandchild] == old_child as i32
{
self.check[grandchild] = new_child as i32;
}
}
}
self.check[old_child] = -1;
self.base[old_child] = -1;
self.is_final[old_child] = false;
if old_child < self.values.len() {
self.values[old_child] = None;
}
}
self.base[state] = new_base;
let new_next = (new_base as usize).wrapping_add(byte as usize);
while new_next >= self.check.len() {
self.base.push(-1);
self.check.push(-1);
self.is_final.push(false);
self.values.push(None);
}
self.check[new_next] = state as i32;
new_next
} else {
self.check[next_state] = state as i32;
next_state
}
}
/// Returns the smallest base `>= start` such that, for every label in `bytes`, the slot
/// `base + label` is available.
///
/// A slot is available when it lies beyond the current end of `check` or its `check` entry is
/// negative, and it is not one of the two reserved slots (0 and 1; slot 1 is the root, whose
/// `check` entry is also negative).
///
/// Returns `0` when `bytes` is empty.
///
/// The search always ends with a verified base: once `base` reaches both `check.len()` and 2,
/// every candidate slot is past the end of the array and therefore available.
fn find_free_base(&self, start: usize, bytes: &[u8]) -> i32 {
    // Slots below this index are reserved and must never be offered.
    const FIRST_USABLE_SLOT: usize = 2;

    if bytes.is_empty() {
        return 0;
    }

    let mut base = start;
    loop {
        let fits = bytes.iter().all(|&label| {
            let slot = base + label as usize;
            slot >= FIRST_USABLE_SLOT
                && self.check.get(slot).map_or(true, |&owner| owner < 0)
        });
        if fits {
            // State indices are stored as `i32` throughout the builder, so the arrays are
            // assumed to stay well below `i32::MAX` entries.
            debug_assert!(base <= i32::MAX as usize);
            return base as i32;
        }
        base += 1;
    }
}
/// Consumes the builder and produces the finished `DoubleArrayTrie`.
///
/// Besides wrapping the arrays in `Arc`s, this precomputes for every state the list of
/// outgoing edge labels, in ascending byte order, by probing all 256 possible child slots.
pub fn build(self) -> DoubleArrayTrie<V> {
    let check = &self.check;
    let edges: Vec<Vec<u8>> = self
        .base
        .iter()
        .enumerate()
        .map(|(state, &offset)| {
            // A negative base means the state has no children.
            if offset < 0 {
                return Vec::new();
            }
            let offset = offset as usize;
            (0u8..=255)
                .filter(|&label| {
                    let slot = offset + label as usize;
                    slot < check.len() && check[slot] == state as i32
                })
                .collect()
        })
        .collect();

    let shared = DATShared {
        base: Arc::new(self.base),
        check: Arc::new(self.check),
        is_final: Arc::new(self.is_final),
        edges: Arc::new(edges),
        values: Arc::new(self.values),
    };
    DoubleArrayTrie {
        shared,
        free_list: Arc::new(self.free_list),
        term_count: self.term_count,
        rebuild_threshold: self.rebuild_threshold,
    }
}
}
impl<V: DictionaryValue> Default for DoubleArrayTrieBuilder<V> {
fn default() -> Self {
Self::new()
}
}
impl<V: DictionaryValue> DoubleArrayTrie<V> {
pub fn new() -> Self {
DoubleArrayTrieBuilder::new().build()
}
/// Builds a trie from `(term, value)` pairs.
///
/// Pairs are sorted by term before insertion. When a term occurs more than once, the value
/// from its last occurrence in `terms` is kept.
pub fn from_terms_with_values<I, S>(terms: I) -> Self
where
    I: IntoIterator<Item = (S, V)>,
    S: AsRef<str>,
{
    let mut pairs: Vec<(String, V)> = Vec::new();
    for (term, value) in terms {
        pairs.push((term.as_ref().to_string(), value));
    }

    // The sort is stable, so equal terms keep their input order and the last one is the newest.
    pairs.sort_by(|left, right| left.0.cmp(&right.0));

    // `dedup_by` passes the later element first and the retained earlier element second.
    // Swapping the values before dropping the later element makes the last value win.
    pairs.dedup_by(|later, kept| {
        let duplicate = later.0 == kept.0;
        if duplicate {
            std::mem::swap(&mut later.1, &mut kept.1);
        }
        duplicate
    });

    let mut builder = DoubleArrayTrieBuilder::new();
    for (term, value) in pairs {
        builder.insert_with_value(&term, Some(value));
    }
    builder.build()
}
/// Looks up the value stored for `term`, delegating to the shared storage's `term_value`.
pub fn get_value(&self, term: &str) -> Option<V> {
    let storage = &self.shared;
    storage.term_value(term)
}
/// Returns the number of distinct terms in the trie; always `Some`.
pub fn len(&self) -> Option<usize> {
    let count = self.term_count;
    Some(count)
}
/// Returns `true` when the trie holds no terms.
pub fn is_empty(&self) -> bool {
    matches!(self.term_count, 0)
}
/// Reports whether `term` is stored in the trie, delegating to the shared storage's
/// `contains_term`.
pub fn contains(&self, term: &str) -> bool {
    let storage = &self.shared;
    storage.contains_term(term)
}
/// Returns the number of slots in the `base` array. This includes unowned slots.
pub fn state_count(&self) -> usize {
    let base = &self.shared.base;
    base.len()
}
/// Estimates the trie's memory footprint in bytes.
///
/// The estimate counts 4 bytes per state for `base`, 4 bytes per state for `check`, one bit
/// per state (rounded up to whole bytes) for the final flags, and one byte per stored edge
/// label. Stored values and container overhead are not included.
pub fn memory_bytes(&self) -> usize {
    let states = self.state_count();
    let base_bytes = states * 4;
    let check_bytes = states * 4;
    let final_flag_bytes = (states + 7) / 8;

    let mut edge_label_bytes: usize = 0;
    for labels in self.shared.edges.iter() {
        edge_label_bytes += labels.len();
    }

    base_bytes + check_bytes + final_flag_bytes + edge_label_bytes
}
/// Returns a `DictionaryTermIterator` driven by a zipper created from this trie.
pub fn iter_terms(&self) -> DictionaryTermIterator<DoubleArrayTrieZipper<V>> {
    DictionaryTermIterator::new(DoubleArrayTrieZipper::new_from_dict(self))
}
/// Returns a `DictionaryIterator` driven by a zipper created from this trie; it yields each
/// term as raw bytes together with its value.
pub fn iter_bytes(&self) -> DictionaryIterator<DoubleArrayTrieZipper<V>> {
    DictionaryIterator::new(DoubleArrayTrieZipper::new_from_dict(self))
}
/// Iterates over `(term, value)` pairs with the term decoded as a `String`.
///
/// Decoding is lossy: invalid UTF-8 sequences are replaced rather than rejected.
pub fn iter(&self) -> impl Iterator<Item = (String, V)> + '_ {
    self.iter_bytes().map(|(raw, value)| {
        let term = String::from_utf8_lossy(&raw).into_owned();
        (term, value)
    })
}
}
impl<V: DictionaryValue> IntoIterator for &DoubleArrayTrie<V> {
type Item = (Vec<u8>, V);
type IntoIter = DictionaryIterator<DoubleArrayTrieZipper<V>>;
fn into_iter(self) -> Self::IntoIter {
self.iter_bytes()
}
}
impl<V: DictionaryValue> Default for DoubleArrayTrie<V> {
fn default() -> Self {
Self::new()
}
}
impl DoubleArrayTrie<()> {
    /// Builds a value-less trie from `terms`.
    ///
    /// Duplicate terms are collapsed, and terms are inserted in ascending order.
    pub fn from_terms<I, S>(terms: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        // An ordered set yields each distinct term once, in ascending order.
        let unique_sorted: std::collections::BTreeSet<String> = terms
            .into_iter()
            .map(|term| term.as_ref().to_string())
            .collect();

        let mut builder = DoubleArrayTrieBuilder::new();
        for term in &unique_sorted {
            builder.insert(term);
        }
        builder.build()
    }
}
/// Handle to a single state of a `DoubleArrayTrie`, used to traverse the trie edge by edge.
///
/// Each node carries its own clone of the trie's shared storage.
#[derive(Clone)]
pub struct DoubleArrayTrieNode<V: DictionaryValue = ()> {
// Index of this state in the shared arrays (the root is state 1).
state: usize,
// Clone of the owning trie's shared storage.
shared: DATShared<V>,
}
/// Byte-labelled traversal over the trie's states.
impl<V: DictionaryValue> DictionaryNode for DoubleArrayTrieNode<V> {
    type Unit = u8;

    /// Returns `true` when a term ends at this state; out-of-range states are never final.
    fn is_final(&self) -> bool {
        self.shared
            .is_final
            .get(self.state)
            .copied()
            .unwrap_or(false)
    }

    /// Follows the edge labelled `label`, returning the child node if that edge exists.
    fn transition(&self, label: u8) -> Option<Self> {
        let offset = *self.shared.base.get(self.state)?;
        if offset < 0 {
            // No base assigned: the state has no children.
            return None;
        }
        let target = (offset as usize).wrapping_add(label as usize);
        let owner = *self.shared.check.get(target)?;
        if owner != self.state as i32 {
            return None;
        }
        Some(DoubleArrayTrieNode {
            state: target,
            shared: self.shared.clone(),
        })
    }

    /// Iterates over `(label, child)` pairs using the precomputed edge-label list.
    fn edges(&self) -> Box<dyn Iterator<Item = (u8, Self)> + '_> {
        let labels = match self.shared.edges.get(self.state) {
            Some(labels) => labels,
            None => return Box::new(std::iter::empty()),
        };
        let offset = self.shared.base[self.state];
        if offset < 0 {
            return Box::new(std::iter::empty());
        }
        let offset = offset as usize;

        let mut children: Vec<(u8, Self)> = Vec::with_capacity(labels.len());
        for &label in labels.iter() {
            let child = DoubleArrayTrieNode {
                state: offset + label as usize,
                shared: self.shared.clone(),
            };
            children.push((label, child));
        }
        Box::new(children.into_iter())
    }

    /// Returns the number of outgoing edges (`Some(0)` for out-of-range states).
    fn edge_count(&self) -> Option<usize> {
        let count = self
            .shared
            .edges
            .get(self.state)
            .map_or(0, |labels| labels.len());
        Some(count)
    }
}
/// `Dictionary` view of the trie: root access, term count and membership.
impl<V: DictionaryValue> Dictionary for DoubleArrayTrie<V> {
    type Node = DoubleArrayTrieNode<V>;

    /// Returns a node positioned on the root state (state 1).
    fn root(&self) -> Self::Node {
        let shared = self.shared.clone();
        DoubleArrayTrieNode { state: 1, shared }
    }

    /// Returns the number of distinct terms; always `Some`.
    fn len(&self) -> Option<usize> {
        let count = self.term_count;
        Some(count)
    }

    /// Reports whether `term` is stored, via the shared storage's `contains_term`.
    fn contains(&self, term: &str) -> bool {
        self.shared.contains_term(term)
    }
}
/// Value access for a trie node.
impl<V: DictionaryValue> MappedDictionaryNode for DoubleArrayTrieNode<V> {
    type Value = V;

    /// Returns a clone of the value stored at this state, or `None` when the state is out of
    /// range or holds no value.
    fn value(&self) -> Option<Self::Value> {
        self.shared
            .values
            .get(self.state)
            .and_then(|slot| slot.clone())
    }
}
/// Value lookups for the trie.
impl<V: DictionaryValue> MappedDictionary for DoubleArrayTrie<V> {
    type Value = V;

    /// Returns the value stored for `term`, if any.
    fn get_value(&self, term: &str) -> Option<Self::Value> {
        self.shared.term_value(term)
    }

    /// Returns `true` when `term` has a stored value and `predicate` accepts it.
    fn contains_with_value<F>(&self, term: &str, predicate: F) -> bool
    where
        F: Fn(&Self::Value) -> bool,
    {
        self.shared
            .term_value(term)
            .map_or(false, |value| predicate(&value))
    }
}
// Unit tests for construction, lookup, traversal and value handling of `DoubleArrayTrie`.
#[cfg(test)]
mod tests {
use super::*;
// A freshly created trie reports zero terms and is empty.
#[test]
fn test_empty_dat() {
let dat: DoubleArrayTrie<()> = DoubleArrayTrie::new();
assert_eq!(dat.len(), Some(0));
assert!(dat.is_empty());
}
// A single term is found; a longer string and a proper prefix of it are not.
#[test]
fn test_single_term() {
let dat = DoubleArrayTrie::from_terms(vec!["test"]);
assert_eq!(dat.len(), Some(1));
assert!(dat.contains("test"));
assert!(!dat.contains("testing"));
assert!(!dat.contains("tes"));
}
// Terms sharing a common prefix are all stored and distinguished from near misses.
#[test]
fn test_multiple_terms() {
let dat = DoubleArrayTrie::from_terms(vec!["test", "testing", "tested", "tester"]);
assert_eq!(dat.len(), Some(4));
assert!(dat.contains("test"));
assert!(dat.contains("testing"));
assert!(dat.contains("tested"));
assert!(dat.contains("tester"));
assert!(!dat.contains("tes"));
assert!(!dat.contains("tests"));
}
// Terms that differ only in their first byte are all stored.
#[test]
fn test_prefix_sharing() {
let dat = DoubleArrayTrie::from_terms(vec!["test", "best", "rest"]);
assert_eq!(dat.len(), Some(3));
assert!(dat.contains("test"));
assert!(dat.contains("best"));
assert!(dat.contains("rest"));
}
// The memory estimate stays below 12 bytes per state for a small word list.
#[test]
fn test_memory_efficiency() {
let dat =
DoubleArrayTrie::from_terms(vec!["band", "banana", "bandana", "can", "cane", "candy"]);
let memory = dat.memory_bytes();
let state_count = dat.state_count();
println!("DAT memory: {} bytes for {} states", memory, state_count);
println!(
" Approximately {} bytes/state",
memory / state_count.max(1)
);
assert!(memory < state_count * 12);
}
// Walks "test" byte by byte through the `Dictionary`/`DictionaryNode` API; only the last
// node is final.
#[test]
fn test_dictionary_trait() {
let dat = DoubleArrayTrie::from_terms(vec!["test", "testing"]);
let root = dat.root();
assert!(!root.is_final());
let t_node = root.transition(b't').expect("Should have 't' edge");
assert!(!t_node.is_final());
let e_node = t_node.transition(b'e').expect("Should have 'e' edge");
assert!(!e_node.is_final());
let s_node = e_node.transition(b's').expect("Should have 's' edge");
assert!(!s_node.is_final());
let final_node = s_node.transition(b't').expect("Should have 't' edge");
assert!(final_node.is_final()); }
// `edges()` on the node after 'a' lists exactly the three child labels.
#[test]
fn test_edge_iteration() {
let dat = DoubleArrayTrie::from_terms(vec!["ab", "ac", "ad"]);
let root = dat.root();
let a_node = root.transition(b'a').expect("Should have 'a' edge");
let edges: Vec<u8> = a_node.edges().map(|(label, _)| label).collect();
assert!(edges.contains(&b'b'));
assert!(edges.contains(&b'c'));
assert!(edges.contains(&b'd'));
assert_eq!(edges.len(), 3);
}
// The builder reports `true` for new terms and `false` for a repeated one, and the
// repeated term is counted once.
#[test]
fn test_incremental_construction() {
let mut builder: DoubleArrayTrieBuilder<()> = DoubleArrayTrieBuilder::new();
assert!(builder.insert("hello"));
assert!(builder.insert("world"));
assert!(builder.insert("test"));
assert!(!builder.insert("test"));
let dat = builder.build();
assert_eq!(dat.len(), Some(3));
assert!(dat.contains("hello"));
assert!(dat.contains("world"));
assert!(dat.contains("test"));
}
// Values given at construction are returned by `get_value`; unknown terms give `None`.
#[test]
fn test_mapped_dictionary_with_values() {
let terms = vec![("apple", 1), ("application", 2), ("apply", 3)];
let dict = DoubleArrayTrie::from_terms_with_values(terms);
assert_eq!(dict.get_value("apple"), Some(1));
assert_eq!(dict.get_value("application"), Some(2));
assert_eq!(dict.get_value("apply"), Some(3));
assert_eq!(dict.get_value("apricot"), None);
}
// `contains_with_value` applies the predicate to the stored value and is `false` for
// missing terms.
#[test]
fn test_mapped_dictionary_contains_with_value() {
let dict = DoubleArrayTrie::from_terms_with_values(vec![("test", 42), ("testing", 100)]);
assert!(dict.contains_with_value("test", |v| *v == 42));
assert!(dict.contains_with_value("testing", |v| *v > 50));
assert!(!dict.contains_with_value("test", |v| *v > 50));
assert!(!dict.contains_with_value("missing", |v| *v == 42));
}
// Values are reachable through node traversal as well as through term lookup.
#[test]
fn test_mapped_dictionary_node_value() {
use crate::MappedDictionaryNode;
let dict = DoubleArrayTrie::from_terms_with_values(vec![("cat", 1), ("catch", 2)]);
let root = dict.root();
let c = root.transition(b'c').unwrap();
let a = c.transition(b'a').unwrap();
let t = a.transition(b't').unwrap();
assert!(t.is_final());
assert_eq!(t.value(), Some(1));
let c2 = t.transition(b'c').unwrap();
let h = c2.transition(b'h').unwrap();
assert!(h.is_final());
assert_eq!(h.value(), Some(2));
}
// The default value type `()` still works; terms inserted without a value yield `None`.
#[test]
fn test_backward_compatibility_without_values() {
let dict: DoubleArrayTrie = DoubleArrayTrie::from_terms(vec!["test", "testing"]);
assert!(dict.contains("test"));
assert_eq!(dict.len(), Some(2));
assert_eq!(dict.get_value("test"), None);
}
// Values inserted through the builder survive `build()`.
#[test]
fn test_builder_with_values() {
let mut builder: DoubleArrayTrieBuilder<i32> = DoubleArrayTrieBuilder::new();
builder.insert_with_value("hello", Some(10));
builder.insert_with_value("world", Some(20));
builder.insert_with_value("test", Some(30));
let dat = builder.build();
assert_eq!(dat.len(), Some(3));
assert_eq!(dat.get_value("hello"), Some(10));
assert_eq!(dat.get_value("world"), Some(20));
assert_eq!(dat.get_value("test"), Some(30));
}
// The empty string is a valid term and can carry a value.
#[test]
fn test_empty_string_with_value() {
let mut builder: DoubleArrayTrieBuilder<i32> = DoubleArrayTrieBuilder::new();
builder.insert_with_value("", Some(42));
let dat = builder.build();
assert_eq!(dat.get_value(""), Some(42));
}
// Re-inserting a term with a new value returns `false`, keeps the count at one and
// replaces the stored value.
#[test]
fn test_duplicate_update_value() {
let mut builder: DoubleArrayTrieBuilder<i32> = DoubleArrayTrieBuilder::new();
assert!(builder.insert_with_value("test", Some(10)));
assert!(!builder.insert_with_value("test", Some(20)));
let dat = builder.build();
assert_eq!(dat.len(), Some(1));
assert_eq!(dat.get_value("test"), Some(20)); }
// Non-`Copy` value types such as `String` are supported.
#[test]
fn test_string_values() {
let dict = DoubleArrayTrie::from_terms_with_values(vec![
("hello", "greeting".to_string()),
("world", "noun".to_string()),
("test", "verb".to_string()),
]);
assert_eq!(dict.get_value("hello"), Some("greeting".to_string()));
assert_eq!(dict.get_value("world"), Some("noun".to_string()));
assert_eq!(dict.get_value("test"), Some("verb".to_string()));
}
}