use crate::bloom_filter::BloomFilter;
use crate::node_signature::NodeSignature;
use crate::value::DictionaryValue;
use crate::CharUnit;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use std::collections::HashMap;
/// A single node of the DAWG arena owned by `DawgCore`.
///
/// Nodes refer to their children by index into `DawgCore::nodes`. After
/// minimization one node may be the target of several edges, which is why
/// `DawgCore` clones a node whose `ref_count` exceeds 1 before mutating it.
#[derive(Clone, Debug)]
#[cfg_attr(
feature = "serialization",
derive(serde::Serialize, serde::Deserialize)
)]
#[cfg_attr(
all(feature = "serialization", not(feature = "persistent-artrie")),
serde(bound(serialize = "U: serde::Serialize, V: serde::Serialize")),
serde(bound(deserialize = "U: serde::Deserialize<'de>, V: serde::Deserialize<'de>"))
)]
#[cfg_attr(
all(feature = "serialization", feature = "persistent-artrie"),
serde(bound(serialize = "U: serde::Serialize, V: serde::Serialize")),
serde(bound(deserialize = "U: serde::de::DeserializeOwned, V: serde::de::DeserializeOwned"))
)]
pub struct DawgNode<U: CharUnit, V: DictionaryValue> {
/// Outgoing `(label, child index)` pairs. Expected to be ordered by label:
/// `DawgCore::insert_edge_sorted` binary-searches this list and
/// `DawgCore::nodes_structurally_equal` compares two lists pairwise.
pub(crate) edges: SmallVec<[(U, usize); 4]>,
/// `true` if a stored term ends at this node.
pub(crate) is_final: bool,
/// Number of edges (from any node in the arena) targeting this node, plus one
/// for the root, as maintained by `DawgCore`; values above 1 mean the node is
/// shared and must be cloned before being modified.
pub(crate) ref_count: usize,
/// Value attached to the term ending here; lookups only read it when
/// `is_final` is set.
pub(crate) value: Option<V>,
}
impl<U: CharUnit, V: DictionaryValue> DawgNode<U, V> {
    /// Creates a node with no outgoing edges, no value and a reference count of 0.
    pub fn new(is_final: bool) -> Self {
        Self::new_with_value(is_final, None)
    }

    /// Creates a node with no outgoing edges and a reference count of 0 that holds
    /// `value` (which may itself be `None`).
    pub fn new_with_value(is_final: bool, value: Option<V>) -> Self {
        Self {
            value,
            is_final,
            ref_count: 0,
            edges: SmallVec::new(),
        }
    }
}
/// Arena-backed DAWG (directed acyclic word graph) over units of type `U`, with an
/// optional value of type `V` per stored term.
///
/// Nodes live in `nodes` and refer to each other by index; index 0 is the root.
/// After minimization several parents may share a node, so mutating operations
/// first clone shared nodes on the affected path (see `DawgCore::make_path_unique`).
#[derive(Debug)]
#[cfg_attr(
    feature = "serialization",
    derive(serde::Serialize, serde::Deserialize)
)]
#[cfg_attr(
    all(feature = "serialization", not(feature = "persistent-artrie")),
    serde(bound(serialize = "U: serde::Serialize, V: serde::Serialize")),
    serde(bound(deserialize = "U: serde::Deserialize<'de>, V: serde::Deserialize<'de>"))
)]
#[cfg_attr(
    all(feature = "serialization", feature = "persistent-artrie"),
    serde(bound(serialize = "U: serde::Serialize, V: serde::Serialize")),
    serde(bound(deserialize = "U: serde::de::DeserializeOwned, V: serde::de::DeserializeOwned"))
)]
pub struct DawgCore<U: CharUnit, V: DictionaryValue> {
    /// Node arena; index 0 is always the root.
    pub(crate) nodes: Vec<DawgNode<U, V>>,
    /// Number of distinct terms currently stored.
    pub(crate) term_count: usize,
    /// Set when `remove_units` leaves unlinked nodes in `nodes`; cleared by
    /// minimization and compaction.
    pub(crate) needs_compaction: bool,
    /// Cleared whenever the structure changes. NOTE(review): not read or populated
    /// in this module — confirm its users elsewhere. Not serialized.
    #[cfg_attr(feature = "serialization", serde(skip))]
    pub(crate) suffix_cache: FxHashMap<u64, usize>,
    /// Node count recorded after the last minimization; baseline for automatic
    /// minimization. Not serialized (deserializes as 0).
    #[cfg_attr(feature = "serialization", serde(skip))]
    pub(crate) last_minimized_node_count: usize,
    /// Growth factor that triggers automatic minimization; an infinite value
    /// disables it. Not serialized: deserialization restores `f32::INFINITY` (the
    /// `DawgCore::new` default) instead of `f32::default()` (0.0), which would make
    /// every successful insertion run a full minimization pass.
    #[cfg_attr(
        feature = "serialization",
        serde(skip, default = "default_auto_minimize_threshold")
    )]
    pub(crate) auto_minimize_threshold: f32,
    /// Optional bloom filter fed by `bloom_insert` and rebuilt by `compact`.
    /// Not serialized (deserializes as `None`).
    #[cfg_attr(feature = "serialization", serde(skip))]
    pub(crate) bloom_filter: Option<BloomFilter>,
}

/// Value given to `DawgCore::auto_minimize_threshold` when deserializing, matching
/// the threshold used by `DawgCore::new` (automatic minimization disabled).
#[cfg(feature = "serialization")]
fn default_auto_minimize_threshold() -> f32 {
    f32::INFINITY
}
impl<U: CharUnit, V: DictionaryValue> DawgCore<U, V> {
    /// Creates an empty DAWG (root node only) with automatic minimization disabled
    /// and no bloom filter.
    pub fn new() -> Self {
        Self::with_config(f32::INFINITY, None)
    }

    /// Creates an empty DAWG without a bloom filter whose automatic minimization
    /// triggers once the node count exceeds `threshold` times the node count
    /// recorded after the last minimization (see `check_and_auto_minimize`).
    pub fn with_auto_minimize_threshold(threshold: f32) -> Self {
        Self::with_config(threshold, None)
    }

    /// Creates an empty DAWG holding only the non-final root node at index 0.
    ///
    /// * `auto_minimize_threshold` – growth factor checked by
    ///   `check_and_auto_minimize`; an infinite value disables automatic
    ///   minimization.
    /// * `bloom_filter_capacity` – when `Some`, a `BloomFilter` is created with this
    ///   capacity; when `None`, no filter is kept and `bloom_might_contain` always
    ///   answers `true`.
    pub fn with_config(auto_minimize_threshold: f32, bloom_filter_capacity: Option<usize>) -> Self {
        DawgCore {
            nodes: vec![DawgNode::new(false)],
            term_count: 0,
            needs_compaction: false,
            suffix_cache: FxHashMap::default(),
            last_minimized_node_count: 1,
            auto_minimize_threshold,
            bloom_filter: bloom_filter_capacity.map(BloomFilter::new),
        }
    }

    /// Number of distinct terms currently stored.
    #[inline]
    pub fn term_count(&self) -> usize {
        self.term_count
    }

    /// Number of nodes in the arena, including nodes unlinked by `remove_units`
    /// that have not yet been reclaimed by `minimize_incremental` or `compact`.
    #[inline]
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// `true` after a successful `remove_units` until the next
    /// `minimize_incremental` or `compact` drops the unlinked nodes.
    #[inline]
    pub fn needs_compaction(&self) -> bool {
        self.needs_compaction
    }

    /// Inserts the term spelled by `units`.
    ///
    /// Returns `true` if the term was added, `false` if it was already present (in
    /// which case nothing is modified).
    ///
    /// The longest existing prefix is followed and every node on it is made
    /// unshared via `make_path_unique`, so the new edges / final flag cannot leak
    /// into other terms that share those nodes after minimization. The remaining
    /// units get fresh nodes. Afterwards reference counts are rebuilt (one pass over
    /// the whole arena) and `check_and_auto_minimize` may minimize the DAWG. The
    /// bloom filter is not touched; callers use `bloom_insert` for that.
    pub fn insert_units(&mut self, units: &[U]) -> bool {
        let mut node_idx = 0;
        // (parent, label, child) triples for the longest prefix already stored.
        let mut path: Vec<(usize, U, usize)> = Vec::new();
        for &unit in units {
            if let Some(&child_idx) = self.nodes[node_idx]
                .edges
                .iter()
                .find(|(u, _)| *u == unit)
                .map(|(_, idx)| idx)
            {
                path.push((node_idx, unit, child_idx));
                node_idx = child_idx;
            } else {
                break;
            }
        }
        let path_len = path.len();
        if path_len == units.len() && self.nodes[node_idx].is_final {
            return false;
        }
        let unique_path = self.make_path_unique(&path);
        // Cloning may have redirected edges, so continue from the unshared path end.
        node_idx = unique_path.last().map(|(_, _, child)| *child).unwrap_or(0);
        for (i, &unit) in units.iter().enumerate().skip(path_len) {
            let new_idx = self.nodes.len();
            let is_final = i == units.len() - 1;
            let mut new_node = DawgNode::new(is_final);
            new_node.ref_count = 1;
            self.nodes.push(new_node);
            self.insert_edge_sorted(node_idx, unit, new_idx);
            node_idx = new_idx;
        }
        if path_len == units.len() {
            // The whole term was an existing non-final prefix; just mark it.
            self.nodes[node_idx].is_final = true;
        }
        self.term_count += 1;
        self.recompute_ref_counts();
        self.check_and_auto_minimize();
        true
    }

    /// Inserts the term spelled by `units` with `value`, overwriting the value of an
    /// existing term.
    ///
    /// Returns `true` if the term was newly added, `false` if it already existed
    /// (its value is replaced in that case). The path is unshared via
    /// `make_path_unique` before any node is modified. Automatic minimization is
    /// only checked when a term is added.
    pub fn insert_units_with_value(&mut self, units: &[U], value: V) -> bool {
        let mut node_idx = 0;
        let mut path: Vec<(usize, U, usize)> = Vec::new();
        for &unit in units {
            if let Some(&child_idx) = self.nodes[node_idx]
                .edges
                .iter()
                .find(|(u, _)| *u == unit)
                .map(|(_, idx)| idx)
            {
                path.push((node_idx, unit, child_idx));
                node_idx = child_idx;
            } else {
                break;
            }
        }
        let path_len = path.len();
        if path_len == units.len() {
            // The full path exists: unshare it before touching the terminal node so
            // the value cannot appear on other terms sharing these nodes.
            let unique_path = self.make_path_unique(&path);
            node_idx = unique_path.last().map(|(_, _, child)| *child).unwrap_or(0);
            if self.nodes[node_idx].is_final {
                self.nodes[node_idx].value = Some(value);
                self.recompute_ref_counts();
                return false;
            } else {
                self.nodes[node_idx].is_final = true;
                self.nodes[node_idx].value = Some(value);
                self.term_count += 1;
                self.recompute_ref_counts();
                self.check_and_auto_minimize();
                return true;
            }
        }
        let unique_path = self.make_path_unique(&path);
        node_idx = unique_path.last().map(|(_, _, child)| *child).unwrap_or(0);
        for (i, &unit) in units.iter().enumerate().skip(path_len) {
            let new_idx = self.nodes.len();
            let is_final = i == units.len() - 1;
            let mut new_node = if is_final {
                DawgNode::new_with_value(true, Some(value.clone()))
            } else {
                DawgNode::new(false)
            };
            new_node.ref_count = 1;
            self.nodes.push(new_node);
            self.insert_edge_sorted(node_idx, unit, new_idx);
            node_idx = new_idx;
        }
        self.term_count += 1;
        self.recompute_ref_counts();
        self.check_and_auto_minimize();
        true
    }

    /// Returns `true` if the exact term spelled by `units` is stored.
    pub fn contains_units(&self, units: &[U]) -> bool {
        let mut node_idx = 0;
        for &unit in units {
            match self.nodes[node_idx]
                .edges
                .iter()
                .find(|(u, _)| *u == unit)
                .map(|(_, idx)| idx)
            {
                Some(&child_idx) => node_idx = child_idx,
                None => return false,
            }
        }
        self.nodes[node_idx].is_final
    }

    /// Returns a clone of the value stored for `units`, or `None` if the term is
    /// absent or was stored without a value.
    pub fn get_value_for_units(&self, units: &[U]) -> Option<V> {
        let mut node_idx = 0;
        for &unit in units {
            match self.nodes[node_idx]
                .edges
                .iter()
                .find(|(u, _)| *u == unit)
                .map(|(_, idx)| idx)
            {
                Some(&child_idx) => node_idx = child_idx,
                None => return None,
            }
        }
        if self.nodes[node_idx].is_final {
            self.nodes[node_idx].value.clone()
        } else {
            None
        }
    }

    /// Removes the term spelled by `units`; returns `false` if it was not stored.
    ///
    /// The path is unshared first, the terminal node is unmarked and its value
    /// dropped, and trailing nodes that are neither final nor have outgoing edges
    /// are unlinked from their parents. Unlinked nodes stay in `nodes` until
    /// `minimize_incremental` or `compact` reclaims them, hence `needs_compaction`
    /// is set. The bloom filter is not updated.
    pub fn remove_units(&mut self, units: &[U]) -> bool {
        let mut node_idx = 0;
        let mut path: Vec<(usize, U, usize)> = Vec::new();
        for &unit in units {
            if let Some(&child_idx) = self.nodes[node_idx]
                .edges
                .iter()
                .find(|(u, _)| *u == unit)
                .map(|(_, idx)| idx)
            {
                path.push((node_idx, unit, child_idx));
                node_idx = child_idx;
            } else {
                return false;
            }
        }
        if !self.nodes[node_idx].is_final {
            return false;
        }
        let unique_path = self.make_path_unique(&path);
        node_idx = unique_path.last().map(|(_, _, child)| *child).unwrap_or(0);
        self.nodes[node_idx].is_final = false;
        self.nodes[node_idx].value = None;
        self.term_count -= 1;
        // Walk back towards the root, unlinking nodes that no longer lead anywhere.
        for (parent_idx, label, child_idx) in unique_path.iter().rev() {
            let child = &self.nodes[*child_idx];
            if !child.is_final && child.edges.is_empty() {
                self.nodes[*parent_idx].edges.retain(|(u, _)| *u != *label);
            } else {
                break;
            }
        }
        self.suffix_cache.clear();
        self.needs_compaction = true;
        self.recompute_ref_counts();
        true
    }

    /// Applies `update_fn` to the value of an existing term, or stores
    /// `default_value` otherwise.
    ///
    /// If the term exists and has a value, `update_fn` is called on it. If the term
    /// exists without a value, or does not exist, `default_value` is stored and
    /// `update_fn` is not called. Returns `true` only if the term was newly added.
    pub fn update_or_insert_units<F>(&mut self, units: &[U], default_value: V, update_fn: F) -> bool
    where
        F: FnOnce(&mut V),
    {
        let mut node_idx = 0;
        let mut path: Vec<(usize, U, usize)> = Vec::new();
        for &unit in units {
            if let Some(&child_idx) = self.nodes[node_idx]
                .edges
                .iter()
                .find(|(u, _)| *u == unit)
                .map(|(_, idx)| idx)
            {
                path.push((node_idx, unit, child_idx));
                node_idx = child_idx;
            } else {
                break;
            }
        }
        let path_len = path.len();
        if path_len == units.len() {
            let unique_path = self.make_path_unique(&path);
            node_idx = unique_path.last().map(|(_, _, child)| *child).unwrap_or(0);
            if self.nodes[node_idx].is_final {
                if let Some(ref mut existing_value) = self.nodes[node_idx].value {
                    update_fn(existing_value);
                } else {
                    self.nodes[node_idx].value = Some(default_value);
                }
                self.recompute_ref_counts();
                return false;
            } else {
                self.nodes[node_idx].is_final = true;
                self.nodes[node_idx].value = Some(default_value);
                self.term_count += 1;
                self.recompute_ref_counts();
                self.check_and_auto_minimize();
                return true;
            }
        }
        let unique_path = self.make_path_unique(&path);
        node_idx = unique_path.last().map(|(_, _, child)| *child).unwrap_or(0);
        for (i, &unit) in units.iter().enumerate().skip(path_len) {
            let new_idx = self.nodes.len();
            let is_final = i == units.len() - 1;
            let mut new_node = if is_final {
                DawgNode::new_with_value(true, Some(default_value.clone()))
            } else {
                DawgNode::new(false)
            };
            new_node.ref_count = 1;
            self.nodes.push(new_node);
            self.insert_edge_sorted(node_idx, unit, new_idx);
            node_idx = new_idx;
        }
        self.term_count += 1;
        self.recompute_ref_counts();
        self.check_and_auto_minimize();
        true
    }

    /// Adds `term` to the bloom filter, if one is configured; otherwise a no-op.
    pub fn bloom_insert(&mut self, term: &str) {
        if let Some(ref mut bloom) = self.bloom_filter {
            bloom.insert(term);
        }
    }

    /// Asks the bloom filter whether `term` may be present. Without a filter the
    /// answer is always `true` (no term can be ruled out).
    #[inline]
    pub fn bloom_might_contain(&self, term: &str) -> bool {
        match &self.bloom_filter {
            Some(bloom) => bloom.might_contain(term),
            None => true,
        }
    }

    /// Rebuilds the DAWG from scratch and re-minimizes it.
    ///
    /// All `(term, value)` entries are extracted, the arena is reset to a single
    /// root, the entries are re-inserted in sorted order and the result is
    /// minimized. This reclaims nodes unlinked by `remove_units` and clones created
    /// by copy-on-write insertions. The auto-minimize threshold is kept; a
    /// configured bloom filter is recreated and refilled with every stored term.
    ///
    /// Returns how many nodes fewer the arena holds afterwards than before
    /// (`0` if it grew).
    pub fn compact(&mut self) -> usize {
        let old_node_count = self.nodes.len();
        let mut entries = self.extract_all_entries();
        // NOTE(review): assumes `BloomFilter::capacity()` reports roughly 10x the
        // capacity passed to `BloomFilter::new` (e.g. a bit count) — confirm in
        // bloom_filter.rs, otherwise the filter shrinks on every compaction.
        let bloom_capacity = self.bloom_filter.as_ref().map(|b| b.capacity() / 10);

        // Reset to a lone root; `auto_minimize_threshold` is left untouched.
        self.nodes = vec![DawgNode::new(false)];
        self.term_count = 0;
        self.needs_compaction = false;
        self.suffix_cache.clear();
        self.last_minimized_node_count = 1;
        self.bloom_filter = bloom_capacity.map(BloomFilter::new);

        // Sorted insertion yields a trie whose edge lists are ordered by label.
        entries.sort_by(|(left, _), (right, _)| left.cmp(right));
        for (term, value) in entries {
            self.insert_direct_with_value(&term, value);
            if let Some(ref mut bloom) = self.bloom_filter {
                let term_str = U::to_string(&term);
                bloom.insert(&term_str);
            }
        }

        // `minimize_incremental` shrinks `self.nodes` itself, so the final length
        // already reflects both rebuild and merge savings; adding its return value
        // on top would count merged nodes twice.
        self.minimize_incremental();
        old_node_count.saturating_sub(self.nodes.len())
    }

    /// Merges structurally equivalent nodes so that shared suffixes are stored once.
    ///
    /// Nodes reachable from the root are visited in post-order (every child before
    /// any of its parents), so when a node is compared with the earlier nodes of the
    /// same signature, its children already map to their canonical
    /// representatives. Descending index order is not topological once
    /// copy-on-write clones, appended at the end of `nodes`, point back to
    /// lower-indexed children; with that order equivalent parents could stay
    /// unmerged until another pass.
    ///
    /// Nodes carrying a value are never merged (see `nodes_structurally_equal`).
    /// Afterwards unreachable nodes are dropped, indices are compacted, the suffix
    /// cache is cleared, `needs_compaction` is reset and reference counts are
    /// recomputed.
    ///
    /// Returns the reduction in `node_count()` caused by this call.
    pub fn minimize_incremental(&mut self) -> usize {
        let initial_count = self.nodes.len();
        let signatures = self.compute_signatures();
        let mut sig_to_canonical: HashMap<NodeSignature, Vec<usize>> = HashMap::new();
        let mut node_mapping: Vec<usize> = (0..self.nodes.len()).collect();
        for node_idx in self.reachable_post_order() {
            let candidates = sig_to_canonical.entry(signatures[node_idx]).or_default();
            let canonical = candidates.iter().copied().find(|&candidate| {
                self.nodes_structurally_equal(node_idx, candidate, &node_mapping)
            });
            match canonical {
                Some(candidate) => node_mapping[node_idx] = candidate,
                // No equivalent node seen yet: this node becomes a canonical one.
                None => candidates.push(node_idx),
            }
        }
        for node in &mut self.nodes {
            for (_, target_idx) in &mut node.edges {
                *target_idx = node_mapping[*target_idx];
            }
        }
        let reachable = self.find_reachable_nodes();
        if reachable.len() < self.nodes.len() {
            self.compact_with_reachable(&reachable);
        }
        self.suffix_cache.clear();
        self.needs_compaction = false;
        self.last_minimized_node_count = self.nodes.len();
        self.recompute_ref_counts();
        initial_count.saturating_sub(self.nodes.len())
    }

    /// Returns the indices of all nodes reachable from the root, ordered so that
    /// every node appears after all of its children (DFS post-order).
    ///
    /// Uses an explicit stack, so long terms cannot overflow the call stack.
    fn reachable_post_order(&self) -> Vec<usize> {
        let mut order = Vec::with_capacity(self.nodes.len());
        if self.nodes.is_empty() {
            return order;
        }
        let mut visited = vec![false; self.nodes.len()];
        // Each frame is (node index, position of the next edge to explore).
        let mut stack: Vec<(usize, usize)> = vec![(0, 0)];
        visited[0] = true;
        while let Some(frame) = stack.last_mut() {
            let (node_idx, edge_pos) = *frame;
            match self.nodes[node_idx].edges.get(edge_pos) {
                Some(&(_, child_idx)) => {
                    frame.1 += 1;
                    if !visited[child_idx] {
                        visited[child_idx] = true;
                        stack.push((child_idx, 0));
                    }
                }
                None => {
                    // All children finished: emit the node after them.
                    order.push(node_idx);
                    stack.pop();
                }
            }
        }
        order
    }

    /// Adds the edge `label -> target_idx` to `node_idx`, or retargets the existing
    /// edge with that label, keeping the edge list ordered by label.
    ///
    /// Relies on the list already being sorted (it binary-searches it).
    #[inline]
    pub(crate) fn insert_edge_sorted(&mut self, node_idx: usize, label: U, target_idx: usize) {
        let edges = &mut self.nodes[node_idx].edges;
        match edges.binary_search_by_key(&label, |(l, _)| *l) {
            Ok(pos) => {
                edges[pos] = (label, target_idx);
            }
            Err(pos) => {
                edges.insert(pos, (label, target_idx));
            }
        }
    }

    /// Re-walks the labels of `path` from the root, cloning every shared node on the
    /// way (see `ensure_unique_child`), and returns the resulting
    /// `(parent, label, child)` triples.
    ///
    /// Only the labels of `path` are used; indices are re-resolved at each step
    /// because cloning redirects edges.
    ///
    /// # Panics
    /// If a label of `path` is missing from the DAWG.
    fn make_path_unique(&mut self, path: &[(usize, U, usize)]) -> Vec<(usize, U, usize)> {
        let mut unique_path = Vec::with_capacity(path.len());
        let mut parent_idx = 0;
        for (_, label, _) in path {
            let child_idx = self.nodes[parent_idx]
                .edges
                .iter()
                .find(|(edge_label, _)| edge_label == label)
                .map(|(_, target)| *target)
                .expect("path labels must exist while making path unique");
            let unique_child = self.ensure_unique_child(parent_idx, *label, child_idx);
            unique_path.push((parent_idx, *label, unique_child));
            parent_idx = unique_child;
        }
        unique_path
    }

    /// Returns an index for the child reached from `parent_idx` via `label` that is
    /// referenced only by that edge.
    ///
    /// If `child_idx` has `ref_count <= 1` it is returned unchanged. Otherwise the
    /// node is cloned (the clone's edges point to the same grandchildren, whose
    /// counts are incremented), the original's count is decremented, and the
    /// parent's `label -> child_idx` edge is redirected to the clone.
    ///
    /// # Panics
    /// If `parent_idx` has no `label -> child_idx` edge.
    fn ensure_unique_child(&mut self, parent_idx: usize, label: U, child_idx: usize) -> usize {
        if self.nodes[child_idx].ref_count <= 1 {
            return child_idx;
        }
        let new_idx = self.nodes.len();
        let mut cloned = self.nodes[child_idx].clone();
        cloned.ref_count = 1;
        let cloned_child_targets: Vec<usize> =
            cloned.edges.iter().map(|(_, target)| *target).collect();
        self.nodes.push(cloned);
        for target in cloned_child_targets {
            self.nodes[target].ref_count += 1;
        }
        self.nodes[child_idx].ref_count -= 1;
        let edge_pos = self.nodes[parent_idx]
            .edges
            .iter()
            .position(|(edge_label, target)| *edge_label == label && *target == child_idx)
            .expect("path edge must exist while cloning shared DAWG node");
        self.nodes[parent_idx].edges[edge_pos].1 = new_idx;
        new_idx
    }

    /// Recomputes every node's `ref_count` from scratch: one for the root plus one
    /// per incoming edge.
    ///
    /// Edges from every node in the arena are counted, including nodes no longer
    /// reachable from the root; this can only overestimate sharing, which causes
    /// extra cloning but never an unsafe in-place mutation. Targets outside the
    /// arena are ignored.
    pub(crate) fn recompute_ref_counts(&mut self) {
        for node in &mut self.nodes {
            node.ref_count = 0;
        }
        if self.nodes.is_empty() {
            return;
        }
        self.nodes[0].ref_count = 1;
        let targets: Vec<usize> = self
            .nodes
            .iter()
            .flat_map(|node| node.edges.iter().map(|(_, target)| *target))
            .collect();
        for target in targets {
            if let Some(node) = self.nodes.get_mut(target) {
                node.ref_count += 1;
            }
        }
    }

    /// Runs `minimize_incremental` when the arena has grown beyond
    /// `last_minimized_node_count * auto_minimize_threshold` nodes.
    ///
    /// An infinite threshold disables this check. NOTE: a NaN threshold casts to 0
    /// and therefore minimizes on every call.
    pub(crate) fn check_and_auto_minimize(&mut self) {
        let current_nodes = self.nodes.len();
        let threshold_nodes =
            (self.last_minimized_node_count as f32 * self.auto_minimize_threshold) as usize;
        if current_nodes > threshold_nodes && !self.auto_minimize_threshold.is_infinite() {
            self.minimize_incremental();
        }
    }

    /// Returns `true` if nodes `idx1` and `idx2` can be merged under `node_mapping`.
    ///
    /// Both nodes must have the same finality, carry no value, and have the same
    /// labels in the same positions with targets that map to the same canonical
    /// node. Edge lists are compared pairwise, so they must be ordered consistently.
    pub(crate) fn nodes_structurally_equal(
        &self,
        idx1: usize,
        idx2: usize,
        node_mapping: &[usize],
    ) -> bool {
        let node1 = &self.nodes[idx1];
        let node2 = &self.nodes[idx2];
        if node1.is_final != node2.is_final {
            return false;
        }
        // Valued nodes are never merged, so each value keeps its own node.
        if node1.value.is_some() || node2.value.is_some() {
            return false;
        }
        if node1.edges.len() != node2.edges.len() {
            return false;
        }
        for i in 0..node1.edges.len() {
            let (label1, target1) = node1.edges[i];
            let (label2, target2) = node2.edges[i];
            if label1 != label2 {
                return false;
            }
            if node_mapping[target1] != node_mapping[target2] {
                return false;
            }
        }
        true
    }

    /// Computes a `NodeSignature` for every node reachable from the root;
    /// unreachable nodes keep `NodeSignature::zero()`.
    pub(crate) fn compute_signatures(&self) -> Vec<NodeSignature> {
        let mut signatures = vec![NodeSignature::zero(); self.nodes.len()];
        let mut visited = vec![false; self.nodes.len()];
        self.compute_signatures_dfs(0, &mut signatures, &mut visited);
        signatures
    }

    /// Recursive post-order helper for `compute_signatures`: fills the children's
    /// signatures first, then derives `signatures[node_idx]` from the node's
    /// finality and its `(label, child signature)` pairs.
    ///
    /// Recursion depth equals the longest path below `node_idx`.
    pub(crate) fn compute_signatures_dfs(
        &self,
        node_idx: usize,
        signatures: &mut [NodeSignature],
        visited: &mut [bool],
    ) {
        if visited[node_idx] {
            return;
        }
        visited[node_idx] = true;
        let node = &self.nodes[node_idx];
        for (_, child_idx) in &node.edges {
            self.compute_signatures_dfs(*child_idx, signatures, visited);
        }
        let edge_iter = node
            .edges
            .iter()
            .map(|(label, child_idx)| (*label, signatures[*child_idx]));
        signatures[node_idx] = NodeSignature::compute(node.is_final, edge_iter);
    }

    /// Returns the indices of all nodes reachable from the root, in ascending order.
    pub(crate) fn find_reachable_nodes(&self) -> Vec<usize> {
        let mut reachable = Vec::new();
        let mut visited = vec![false; self.nodes.len()];
        self.find_reachable_dfs(0, &mut visited);
        for (idx, &is_reachable) in visited.iter().enumerate() {
            if is_reachable {
                reachable.push(idx);
            }
        }
        reachable
    }

    /// Marks `node_idx` and every node below it as visited (recursive DFS).
    pub(crate) fn find_reachable_dfs(&self, node_idx: usize, visited: &mut [bool]) {
        if visited[node_idx] {
            return;
        }
        visited[node_idx] = true;
        for (_, child_idx) in &self.nodes[node_idx].edges {
            self.find_reachable_dfs(*child_idx, visited);
        }
    }

    /// Replaces the arena with clones of the `reachable` nodes, renumbered in the
    /// order given, and rewrites their edges accordingly; then recomputes
    /// reference counts.
    ///
    /// `reachable` must list every target of the nodes it contains (e.g. the output
    /// of `find_reachable_nodes`, which keeps the root at index 0); an edge to an
    /// omitted node would be rewritten to `usize::MAX`.
    pub(crate) fn compact_with_reachable(&mut self, reachable: &[usize]) {
        let mut old_to_new = vec![usize::MAX; self.nodes.len()];
        for (new_idx, &old_idx) in reachable.iter().enumerate() {
            old_to_new[old_idx] = new_idx;
        }
        let new_nodes: Vec<DawgNode<U, V>> = reachable
            .iter()
            .map(|&old_idx| {
                let mut node = self.nodes[old_idx].clone();
                for (_, target) in &mut node.edges {
                    *target = old_to_new[*target];
                }
                node
            })
            .collect();
        self.nodes = new_nodes;
        self.recompute_ref_counts();
    }

    /// Returns every stored `(term, value)` pair in depth-first edge order
    /// (lexicographic when edge lists are sorted by label).
    pub(crate) fn extract_all_entries(&self) -> Vec<(Vec<U>, Option<V>)> {
        let mut entries = Vec::new();
        let mut current_term = Vec::new();
        self.dfs_collect_entries(0, &mut current_term, &mut entries);
        entries
    }

    /// Recursive helper for `extract_all_entries`; `current_term` holds the units on
    /// the path from the root to `node_idx` and is restored before returning.
    fn dfs_collect_entries(
        &self,
        node_idx: usize,
        current_term: &mut Vec<U>,
        entries: &mut Vec<(Vec<U>, Option<V>)>,
    ) {
        let node = &self.nodes[node_idx];
        if node.is_final {
            entries.push((current_term.clone(), node.value.clone()));
        }
        for (unit, child_idx) in &node.edges {
            current_term.push(*unit);
            self.dfs_collect_entries(*child_idx, current_term, entries);
            current_term.pop();
        }
    }

    /// Inserts `units` as a plain trie path, without copy-on-write handling.
    ///
    /// Missing nodes are appended and linked through `insert_edge_sorted`, so edge
    /// lists stay ordered by label whatever order terms arrive in. Both
    /// `insert_edge_sorted` (binary search) and `nodes_structurally_equal` (pairwise
    /// comparison) depend on that order. The terminal node is marked final and its
    /// value overwritten with `value`; `term_count` grows only if the term was not
    /// already present.
    ///
    /// Reference counts are not updated (callers such as `compact` follow up with
    /// `minimize_incremental`, which recomputes them). Shared nodes are not cloned,
    /// so this must only be used while no node is shared, e.g. on a freshly reset
    /// arena.
    pub(crate) fn insert_direct_with_value(&mut self, units: &[U], value: Option<V>) {
        let mut node_idx = 0;
        for &unit in units {
            if let Some(&child_idx) = self.nodes[node_idx]
                .edges
                .iter()
                .find(|(u, _)| *u == unit)
                .map(|(_, idx)| idx)
            {
                node_idx = child_idx;
            } else {
                let new_idx = self.nodes.len();
                self.nodes.push(DawgNode::new(false));
                self.insert_edge_sorted(node_idx, unit, new_idx);
                node_idx = new_idx;
            }
        }
        if !self.nodes[node_idx].is_final {
            self.term_count += 1;
        }
        self.nodes[node_idx].is_final = true;
        self.nodes[node_idx].value = value;
    }
}
impl<U: CharUnit, V: DictionaryValue> Default for DawgCore<U, V> {
    /// Same as `DawgCore::new`: an empty DAWG with automatic minimization disabled
    /// and no bloom filter.
    fn default() -> Self {
        Self::with_config(f32::INFINITY, None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dawg_core_insert_bytes() {
        let mut core: DawgCore<u8, ()> = DawgCore::new();
        assert!(core.insert_units(b"hello"));
        assert!(core.insert_units(b"world"));
        assert!(!core.insert_units(b"hello"));
        assert_eq!(core.term_count(), 2);
        assert!(core.contains_units(b"hello"));
        assert!(core.contains_units(b"world"));
        assert!(!core.contains_units(b"foo"));
    }

    #[test]
    fn test_dawg_core_insert_chars() {
        let mut core: DawgCore<char, ()> = DawgCore::new();
        let hello: Vec<char> = "hello".chars().collect();
        let world: Vec<char> = "world".chars().collect();
        let cafe: Vec<char> = "café".chars().collect();
        assert!(core.insert_units(&hello));
        assert!(core.insert_units(&world));
        assert!(core.insert_units(&cafe));
        assert!(!core.insert_units(&hello));
        assert_eq!(core.term_count(), 3);
        assert!(core.contains_units(&hello));
        assert!(core.contains_units(&cafe));
    }

    #[test]
    fn test_dawg_core_with_values() {
        let mut core: DawgCore<u8, u32> = DawgCore::new();
        assert!(core.insert_units_with_value(b"key1", 42));
        assert!(core.insert_units_with_value(b"key2", 100));
        assert!(!core.insert_units_with_value(b"key1", 999));
        assert_eq!(core.get_value_for_units(b"key1"), Some(999));
        assert_eq!(core.get_value_for_units(b"key2"), Some(100));
        assert_eq!(core.get_value_for_units(b"unknown"), None);
    }

    #[test]
    fn test_dawg_core_remove() {
        let mut core: DawgCore<u8, ()> = DawgCore::new();
        core.insert_units(b"test");
        core.insert_units(b"testing");
        core.insert_units(b"tested");
        assert!(core.remove_units(b"testing"));
        assert_eq!(core.term_count(), 2);
        assert!(!core.remove_units(b"testing"));
        assert!(core.contains_units(b"test"));
        assert!(!core.contains_units(b"testing"));
    }

    #[test]
    fn test_dawg_core_minimize() {
        let mut core: DawgCore<u8, ()> = DawgCore::new();
        core.insert_units(b"zebra");
        core.insert_units(b"apple");
        core.insert_units(b"banana");
        core.insert_units(b"apricot");
        let nodes_before = core.node_count();
        let merged = core.minimize_incremental();
        let nodes_after = core.node_count();
        assert_eq!(nodes_after, nodes_before - merged);
        assert!(core.contains_units(b"zebra"));
        assert!(core.contains_units(b"apple"));
        assert!(core.contains_units(b"banana"));
        assert!(core.contains_units(b"apricot"));
    }

    #[test]
    fn test_dawg_core_compact() {
        let mut core: DawgCore<u8, ()> = DawgCore::new();
        core.insert_units(b"test");
        core.insert_units(b"testing");
        core.insert_units(b"tested");
        core.remove_units(b"testing");
        let _removed = core.compact();
        assert_eq!(core.term_count(), 2);
        assert!(core.contains_units(b"test"));
        assert!(core.contains_units(b"tested"));
        assert!(!core.contains_units(b"testing"));
    }

    /// "ab" and "cb" share their "b" suffix after minimization; extending one of
    /// them must not extend the other.
    #[test]
    fn test_dawg_core_insert_after_minimize_does_not_leak_into_shared_suffix() {
        let mut core: DawgCore<u8, ()> = DawgCore::new();
        assert!(core.insert_units(b"ab"));
        assert!(core.insert_units(b"cb"));
        core.minimize_incremental();
        assert!(core.insert_units(b"abx"));
        assert!(core.contains_units(b"ab"));
        assert!(core.contains_units(b"cb"));
        assert!(core.contains_units(b"abx"));
        assert!(!core.contains_units(b"cbx"));
        assert_eq!(core.term_count(), 3);
    }

    /// Removing a term whose nodes are shared must leave the sibling term intact.
    #[test]
    fn test_dawg_core_remove_after_minimize_keeps_sibling_terms() {
        let mut core: DawgCore<u8, ()> = DawgCore::new();
        core.insert_units(b"ab");
        core.insert_units(b"cb");
        core.minimize_incremental();
        assert!(core.remove_units(b"ab"));
        assert!(core.needs_compaction());
        assert!(!core.contains_units(b"ab"));
        assert!(core.contains_units(b"cb"));
        assert_eq!(core.term_count(), 1);
    }

    #[test]
    fn test_dawg_core_values_survive_minimize_and_compact() {
        let mut core: DawgCore<u8, u32> = DawgCore::new();
        core.insert_units_with_value(b"ab", 1);
        core.insert_units_with_value(b"cb", 2);
        core.insert_units_with_value(b"db", 3);
        core.minimize_incremental();
        assert!(!core.insert_units_with_value(b"ab", 10));
        assert!(core.remove_units(b"db"));
        core.compact();
        assert_eq!(core.term_count(), 2);
        assert_eq!(core.get_value_for_units(b"ab"), Some(10));
        assert_eq!(core.get_value_for_units(b"cb"), Some(2));
        assert_eq!(core.get_value_for_units(b"db"), None);
    }

    #[test]
    fn test_dawg_core_update_or_insert() {
        let mut core: DawgCore<u8, u32> = DawgCore::new();
        assert!(core.update_or_insert_units(b"k", 1, |v| *v += 10));
        assert_eq!(core.get_value_for_units(b"k"), Some(1));
        assert!(!core.update_or_insert_units(b"k", 1, |v| *v += 10));
        assert_eq!(core.get_value_for_units(b"k"), Some(11));
        assert_eq!(core.term_count(), 1);
    }

    #[test]
    fn test_dawg_core_empty_term() {
        let mut core: DawgCore<u8, ()> = DawgCore::new();
        assert!(!core.contains_units(b""));
        assert!(core.insert_units(b""));
        assert!(!core.insert_units(b""));
        assert!(core.contains_units(b""));
        assert!(core.remove_units(b""));
        assert!(!core.contains_units(b""));
        assert_eq!(core.term_count(), 0);
    }

    #[test]
    fn test_dawg_core_bloom_disabled_by_default() {
        let core: DawgCore<u8, ()> = DawgCore::new();
        assert!(core.bloom_might_contain("anything"));
    }
}