use std::collections::HashSet as StdHashSet;
use std::error::Error;
use std::fs::File;
use std::io::{BufRead, BufReader};
use hashbrown::HashMap;
use smallvec::SmallVec;
use super::builder::{BuilderCore, BuilderError, IntoWord};
use super::char_trait::DawgChar;
use super::children::DawgNode;
use super::node_arena::NodeArena;
/// Bookkeeping that is only needed once an `OwnedDawg` is mutated.
///
/// It is built on demand by `MutationState::initialize` (see
/// `OwnedDawg::ensure_mutation_state`), so a DAWG that is only read never
/// constructs it.
struct MutationState<C: DawgChar + 'static> {
/// Maps each registered node to its reference count: one per incoming child
/// edge, plus one for the node that is currently the root. The same map is
/// used by `canonicalize_node` to find an already-existing equal node.
register: HashMap<&'static DawgNode<'static, C>, usize>,
/// Slots of released nodes. `alloc_node` overwrites one of these before it
/// asks the arena for a new node.
free_list: Vec<*mut DawgNode<'static, C>>,
}
/// A DAWG that owns the arena its nodes are allocated in.
///
/// Words can be looked up through `root`/`contains` and, after construction,
/// added or removed with `add_word`/`remove_word`.
pub struct OwnedDawg<C: DawgChar + 'static> {
/// Backing storage for every node. Nodes are typed with a `'static` lifetime,
/// but they are only valid for as long as this arena is alive.
arena: NodeArena<DawgNode<'static, C>>,
/// Pointer to the current root node; replaced by every successful mutation.
root: *const DawgNode<'static, C>,
/// Reference counts and free list; `None` until the first mutating call.
mutation_state: Option<MutationState<C>>,
}
// SAFETY: within this file the raw pointers held by `OwnedDawg` are only
// dereferenced from `&self` methods (which read) or `&mut self` methods (which
// mutate), so the borrow checker already keeps writers exclusive.
// NOTE(review): this additionally assumes that `DawgNode<'static, C>` and
// `NodeArena` are themselves safe to send and share for every `C: DawgChar`
// — confirm against those types.
unsafe impl<C: DawgChar + 'static> Send for OwnedDawg<C> {}
unsafe impl<C: DawgChar + 'static> Sync for OwnedDawg<C> {}
impl<C: DawgChar + 'static> OwnedDawg<C> {
    /// Borrows the current root node.
    ///
    /// The returned reference is tied to the borrow of `self`, so it cannot
    /// outlive the DAWG or survive a call to a `&mut self` method.
    pub fn root(&self) -> &DawgNode<'_, C> {
        let root_ptr = self.root;
        // SAFETY: `self.root` always comes from a node reference handed out by
        // the builder or by `canonicalize_node`. Presumably those nodes live in
        // `self.arena` and stay valid while `self` is borrowed — confirm
        // against `NodeArena`.
        unsafe { &*root_ptr }
    }

    /// Returns the length reported by the backing arena.
    pub fn node_count(&self) -> usize {
        let arena = &self.arena;
        arena.len()
    }
}
impl<C: DawgChar + 'static> MutationState<C> {
    /// Builds the reference-count register for the graph reachable from `root`.
    ///
    /// Every reachable node gets an entry. Its count is the number of child
    /// edges pointing at it, and the root receives one additional reference so
    /// that it is never released while it is the root. The free list starts
    /// out empty.
    fn initialize(root: *const DawgNode<'static, C>) -> Self {
        // SAFETY: the caller passes `OwnedDawg::root`; see the note on
        // `OwnedDawg::root` about the validity of that pointer.
        let root_node: &'static DawgNode<'static, C> = unsafe { &*root };

        let mut refcounts: HashMap<&'static DawgNode<'static, C>, usize> = HashMap::new();
        // Shared nodes are reachable through several parents; expand each
        // address only once so its outgoing edges are not counted twice.
        let mut seen = StdHashSet::new();
        let mut pending = vec![root_node];

        while let Some(current) = pending.pop() {
            let first_visit = seen.insert(current as *const DawgNode<'static, C>);
            if first_visit {
                refcounts.entry(current).or_insert(0);
                for (_, child) in current.children() {
                    *refcounts.entry(child).or_insert(0) += 1;
                    pending.push(child);
                }
            }
        }

        // The root has no parent edge; this is the reference held by the DAWG.
        *refcounts.entry(root_node).or_insert(0) += 1;

        Self {
            register: refcounts,
            free_list: Vec::new(),
        }
    }
}
impl<C: DawgChar + 'static> OwnedDawg<C> {
/// Creates the mutation bookkeeping on first use; later calls do nothing.
fn ensure_mutation_state(&mut self) {
    // Copy the pointer out first so the closure does not need to borrow `self`.
    let root = self.root;
    self.mutation_state
        .get_or_insert_with(|| MutationState::initialize(root));
}
/// Stores `node` and returns a reference to its final location.
///
/// The node is first marked with `set_canonical`. A slot from `free_list` is
/// reused when one is available; otherwise the node is allocated in `arena`.
fn alloc_node(
    arena: &NodeArena<DawgNode<'static, C>>,
    free_list: &mut Vec<*mut DawgNode<'static, C>>,
    mut node: DawgNode<'static, C>,
) -> &'static DawgNode<'static, C> {
    node.set_canonical();
    match free_list.pop() {
        // SAFETY: slots on the free list were pushed by
        // `decrement_refcount_cascade` after their node was unregistered, so
        // nothing in the graph refers to them any more.
        Some(recycled) => unsafe {
            std::ptr::write(recycled, node);
            &*recycled
        },
        None => {
            let arena_ptr: *const NodeArena<DawgNode<'static, C>> = arena;
            // SAFETY: the `'static` lifetime is a stand-in for "as long as the
            // owning `OwnedDawg` lives"; the reference never escapes it with a
            // longer lifetime than a borrow of that owner.
            let extended: &'static NodeArena<DawgNode<'static, C>> = unsafe { &*arena_ptr };
            extended.alloc(node)
        }
    }
}
fn canonicalize_node(
arena: &NodeArena<DawgNode<'static, C>>,
state: &mut MutationState<C>,
node: DawgNode<'static, C>,
) -> &'static DawgNode<'static, C> {
if let Some((&existing, _)) = state.register.get_key_value(&node) {
existing
} else {
let child_refs: SmallVec<[&'static DawgNode<'static, C>; 8]> =
(0..node.child_count())
.map(|i| {
let (_, child) = node.children_ref().get(i).unwrap();
child
})
.collect();
let allocated = Self::alloc_node(arena, &mut state.free_list, node);
state.register.insert(allocated, 0);
for child in child_refs {
*state.register.get_mut(child).expect("child not in register") += 1;
}
allocated
}
}
/// Drops one reference to the node at `node_ptr` and releases it when the
/// count reaches zero.
///
/// Releasing a node removes it from the register, recursively drops one
/// reference from each of its children, overwrites the slot with an empty
/// non-word node and pushes the slot onto the free list.
///
/// # Panics
///
/// Panics if the node (or any child reached during the cascade) is not in the
/// register.
fn decrement_refcount_cascade(state: &mut MutationState<C>, node_ptr: *const DawgNode<'static, C>) {
let node = unsafe { &*node_ptr };
let rc = state.register.get_mut(node).expect("node not in register");
*rc -= 1;
if *rc == 0 {
// Copy the child pointers out before the slot is overwritten below.
let children: SmallVec<[*const DawgNode<'static, C>; 8]> = node
.children()
.map(|(_, child)| child as *const _)
.collect();
// Unregister while the slot still holds the original node: the register is
// looked up through the node reference itself.
state.register.remove(node);
for child_ptr in children {
Self::decrement_refcount_cascade(state, child_ptr);
}
let slot = node_ptr as *mut DawgNode<'static, C>;
// NOTE(review): `ptr::write` does not drop the value it overwrites. If
// `DawgNode` owns heap memory for its children this leaks it — confirm
// against `DawgNode` and how `NodeArena` drops its contents.
unsafe {
std::ptr::write(slot, DawgNode::new(false));
}
state.free_list.push(slot);
}
}
/// Adds `word` to the DAWG.
///
/// Returns `true` when the word was inserted, and `false` when `word` is empty
/// or already present (the DAWG is left untouched in both `false` cases).
///
/// Existing nodes are never modified in place. The nodes along the word's path
/// are rebuilt bottom-up, each one passed through `canonicalize_node` so that
/// an equal existing node is reused; then the root is swapped and the old root
/// is released.
///
/// # Panics
///
/// Panics if the reference-count register is inconsistent with the graph.
pub fn add_word(&mut self, word: impl IntoWord<C>) -> bool {
let word = word.collect_word();
if word.is_empty() {
return false;
}
self.ensure_mutation_state();
let root = unsafe { &*self.root };
// `path_nodes[i]` is the node reached after consuming the first `i` symbols.
let mut path_nodes: SmallVec<[*const DawgNode<'static, C>; 32]> = SmallVec::new();
path_nodes.push(self.root);
let mut current = root;
// Number of leading symbols of `word` that already have a path in the graph.
let mut prefix_len = 0;
for &ch in word.iter() {
if let Some(child) = current.get(ch) {
path_nodes.push(child as *const _);
current = child;
prefix_len += 1;
} else {
break;
}
}
// The whole word is already on a path that ends in a word node.
if prefix_len == word.len() && current.is_word() {
return false; }
// Borrow the two fields separately so `self.root` stays accessible below.
let Self {
arena,
mutation_state,
..
} = self;
let state = mutation_state.as_mut().unwrap();
// The most recently rebuilt node; it becomes the child of the next level up.
let mut updated_child: &'static DawgNode<'static, C>;
if prefix_len == word.len() {
// The path exists but does not end in a word: copy the terminal node with
// the word flag set.
let terminal = unsafe { &*path_nodes[prefix_len] };
let new_terminal =
DawgNode::with_children(true, terminal.children_ref().clone());
updated_child = Self::canonicalize_node(arena, state, new_terminal);
} else {
// Build the missing suffix from the end: a word leaf first, then one
// single-child node per remaining symbol after `word[prefix_len]`.
let leaf = DawgNode::new(true);
updated_child = Self::canonicalize_node(arena, state, leaf);
for i in (prefix_len + 1..word.len()).rev() {
let mut intermediate = DawgNode::new(false);
intermediate.insert(word[i], updated_child);
updated_child = Self::canonicalize_node(arena, state, intermediate);
}
}
// Deepest existing node that needs a rebuilt copy. `prefix_len - 1` cannot
// underflow: that branch implies `prefix_len == word.len()`, which is >= 1.
let start_level = if prefix_len < word.len() {
prefix_len
} else {
prefix_len - 1
};
for level in (0..=start_level).rev() {
let old_node = unsafe { &*path_nodes[level] };
let ch = word[level];
// Only the node where the existing path ended gains a new edge; every node
// above it has its existing edge redirected to the rebuilt child.
let new_children = if level == prefix_len && prefix_len < word.len() {
old_node.children_ref().with_added_child(ch, updated_child)
} else {
old_node.children_ref().with_replaced_child(ch, updated_child)
};
let new_node = DawgNode::with_children(old_node.is_word(), new_children);
updated_child = Self::canonicalize_node(arena, state, new_node);
}
let old_root = self.root;
self.root = updated_child as *const _;
let state = self.mutation_state.as_mut().unwrap();
// Take the root reference on the new root before releasing the old one, so
// nodes shared between both graphs never drop to zero in between.
*state.register.get_mut(updated_child).expect("new root not in register") += 1;
Self::decrement_refcount_cascade(state, old_root);
true
}
/// Removes `word` from the DAWG.
///
/// Returns `true` when the word was present and has been removed, and `false`
/// when `word` is empty or not contained (the DAWG is left untouched in both
/// `false` cases).
///
/// The path of the word is rebuilt bottom-up. Nodes that end up with no
/// children and no word flag are dropped from their parent instead of being
/// rebuilt. If nothing remains, the root becomes an empty non-word node.
///
/// # Panics
///
/// Panics if the reference-count register is inconsistent with the graph.
pub fn remove_word(&mut self, word: impl IntoWord<C>) -> bool {
let word = word.collect_word();
if word.is_empty() {
return false;
}
self.ensure_mutation_state();
let root = unsafe { &*self.root };
// `path_nodes[i]` is the node reached after consuming the first `i` symbols.
let mut path_nodes: SmallVec<[*const DawgNode<'static, C>; 32]> = SmallVec::new();
path_nodes.push(self.root);
let mut current = root;
for &ch in word.iter() {
if let Some(child) = current.get(ch) {
path_nodes.push(child as *const _);
current = child;
} else {
// No path for the word at all.
return false; }
}
// The path exists but does not end in a word (e.g. a proper prefix).
if !current.is_word() {
return false; }
// Borrow the two fields separately so `self.root` stays accessible below.
let Self {
arena,
mutation_state,
..
} = self;
let state = mutation_state.as_mut().unwrap();
let terminal = unsafe { &*path_nodes[word.len()] };
let new_terminal = DawgNode::with_children(false, terminal.children_ref().clone());
// `None` means "the node at this position disappears"; `Some` carries its
// rebuilt replacement.
let mut updated_child: Option<&'static DawgNode<'static, C>> =
if new_terminal.child_count() == 0 {
None } else {
Some(Self::canonicalize_node(arena, state, new_terminal))
};
for level in (0..word.len()).rev() {
let old_node = unsafe { &*path_nodes[level] };
let ch = word[level];
let new_node = match updated_child {
None => {
let new_children = old_node.children_ref().without_child(ch);
DawgNode::with_children(old_node.is_word(), new_children)
}
Some(child) => {
let new_children =
old_node.children_ref().with_replaced_child(ch, child);
DawgNode::with_children(old_node.is_word(), new_children)
}
};
// A node that neither ends a word nor leads anywhere is pruned as well.
if new_node.child_count() == 0 && !new_node.is_word() {
updated_child = None; } else {
updated_child = Some(Self::canonicalize_node(arena, state, new_node));
}
}
let old_root = self.root;
let new_root_ref = match updated_child {
Some(node) => node,
None => {
// Everything was pruned: the DAWG is now empty but still needs a root.
let empty = DawgNode::new(false);
Self::canonicalize_node(arena, state, empty)
}
};
self.root = new_root_ref as *const _;
let state = self.mutation_state.as_mut().unwrap();
// Take the root reference on the new root before releasing the old one, so
// nodes shared between both graphs never drop to zero in between.
*state.register.get_mut(new_root_ref).expect("new root not in register") += 1;
Self::decrement_refcount_cascade(state, old_root);
true
}
/// Reports whether `word` is stored in the DAWG.
///
/// Follows one edge per symbol starting at the root. The result is `false` as
/// soon as an edge is missing; otherwise it is the word flag of the node the
/// walk ends on (for an empty `word` that node is the root itself).
pub fn contains(&self, word: impl IntoWord<C>) -> bool {
    let symbols = word.collect_word();
    symbols
        .iter()
        .try_fold(self.root(), |node, &symbol| node.get(symbol))
        .is_some_and(|end| end.is_word())
}
}
/// Debug output shows only the node count; the graph itself is not printed.
impl<C: DawgChar + 'static> std::fmt::Debug for OwnedDawg<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let node_count = self.node_count();
        let mut out = f.debug_struct("OwnedDawg");
        out.field("node_count", &node_count);
        out.finish()
    }
}
/// Creates a `BuilderCore` that allocates its nodes in `arena`.
///
/// # Safety
///
/// The borrow of `arena` is extended to `'static`. The caller must make sure
/// the returned builder is dropped before `arena` is moved or dropped, and
/// that node references obtained from the builder are not used after the
/// arena is gone.
unsafe fn make_builder<C: DawgChar + 'static>(
    arena: &NodeArena<DawgNode<'static, C>>,
) -> BuilderCore<'static, C, NodeArena<DawgNode<'static, C>>> {
    let raw: *const NodeArena<DawgNode<'static, C>> = arena;
    let extended: &'static NodeArena<DawgNode<'static, C>> = &*raw;
    BuilderCore::new(extended)
}
/// Builds an [`OwnedDawg`] from an iterator of words.
///
/// The returned DAWG has no mutation bookkeeping yet; it is created on the
/// first call to a mutating method.
///
/// # Errors
///
/// Returns the builder's error as soon as a word is rejected. The
/// `sorted_input_required` test shows that out-of-order input is one such case.
pub fn build_owned_dawg<C, W>(
    words: impl IntoIterator<Item = W>,
) -> Result<OwnedDawg<C>, BuilderError<C>>
where
    C: DawgChar + 'static,
    W: IntoWord<C>,
{
    let arena: NodeArena<DawgNode<'static, C>> = NodeArena::new();

    // The builder holds a lifetime-extended reference to `arena`, so it lives in
    // its own scope and is gone before `arena` is moved into the result.
    let root = {
        // SAFETY: `builder` is dropped at the end of this block, before `arena`
        // is moved. The root pointer it yields is stored next to the arena in
        // the same `OwnedDawg`. NOTE(review): this relies on `NodeArena`
        // keeping node addresses stable when the arena value itself is moved
        // — confirm.
        let mut builder = unsafe { make_builder(&arena) };
        for word in words {
            builder.add_word(word)?;
        }
        builder.build() as *const DawgNode<'static, C>
    };

    Ok(OwnedDawg {
        arena,
        root,
        mutation_state: None,
    })
}
/// Builds an [`OwnedDawg<char>`] from a text file with one word per line.
///
/// Trailing whitespace (including the line terminator) is stripped from every
/// line, and lines that are empty afterwards are skipped.
///
/// NOTE(review): lines starting with `#` are passed to the builder like any
/// other word, while the `all_words_from_file` test filters them out of the
/// same kind of dictionary — confirm whether comment lines should be skipped
/// here too.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read, or if the builder
/// rejects a word.
pub fn build_owned_dawg_from_file(
    filename: &str,
) -> Result<OwnedDawg<char>, Box<dyn Error>> {
    let arena: NodeArena<DawgNode<'static, char>> = NodeArena::new();
    let mut reader = BufReader::new(File::open(filename)?);
    // One buffer reused for every line instead of allocating per line.
    let mut line = String::with_capacity(80);

    // The builder holds a lifetime-extended reference to `arena`, so it lives in
    // its own scope and is gone before `arena` is moved into the result.
    let root = {
        // SAFETY: `builder` is dropped at the end of this block, before `arena`
        // is moved. The root pointer it yields is stored next to the arena in
        // the same `OwnedDawg`. NOTE(review): this relies on `NodeArena`
        // keeping node addresses stable when the arena value itself is moved
        // — confirm.
        let mut builder = unsafe { make_builder(&arena) };
        // `read_line` returns 0 only at end of file.
        while reader.read_line(&mut line)? != 0 {
            let word = line.trim_end();
            if !word.is_empty() {
                builder.add_word(word)?;
            }
            line.clear();
        }
        builder.build() as *const DawgNode<'static, char>
    };

    Ok(OwnedDawg {
        arena,
        root,
        mutation_state: None,
    })
}
// Unit tests for `OwnedDawg`: construction, lookup, incremental add/remove,
// suffix sharing after mutation, and reuse of released node slots.
#[cfg(test)]
mod test {
use super::*;
// Walks `word` from `root` edge by edge and reports whether the walk ends on a
// word node.
fn is_word(root: &DawgNode<'_, char>, word: &str) -> bool {
word.chars()
.try_fold(root, |n, ch| n.get(ch))
.is_some_and(|n| n.is_word())
}
// Built words are found; an absent word and a proper prefix are not.
#[test]
fn basic_word_lookup() {
let dawg = build_owned_dawg(["BAKE", "CAKE", "FAKE", "LAKE"]).unwrap();
let root = dawg.root();
assert!(is_word(root, "BAKE"));
assert!(is_word(root, "CAKE"));
assert!(!is_word(root, "MAKE"));
assert!(!is_word(root, "BAK"));
}
// The bulk builder rejects input that is not in sorted order.
#[test]
fn sorted_input_required() {
let res = build_owned_dawg(["ZULU", "ALFA"]);
assert!(res.is_err());
}
// Words with a common suffix end up pointing at the very same suffix node.
#[test]
fn suffix_sharing() {
let dawg = build_owned_dawg([
"ASUFFIX",
"BSUFFIX",
"CDESUFFIX",
"FFFFFFFSUFFIX",
"INBETWEEN",
"JSUFFIX",
"XXSUFFIX",
])
.unwrap();
let root = dawg.root();
let suffix_node = root.get('A').unwrap().get('S').unwrap();
for prefix_char in ['B', 'J'] {
let node = root.get(prefix_char).unwrap().get('S').unwrap();
assert!(std::ptr::addr_eq(node, suffix_node));
}
}
// The DAWG is generic over the symbol type; here the symbols are `u8`.
#[test]
fn generic_u8() {
let dawg: OwnedDawg<u8> =
build_owned_dawg([vec![1, 2, 3], vec![1, 2, 4], vec![2, 3, 4]]).unwrap();
let root = dawg.root();
assert!(root
.get(1)
.and_then(|n| n.get(2))
.and_then(|n| n.get(3))
.is_some_and(|n| n.is_word()));
assert!(root
.get(1)
.and_then(|n| n.get(2))
.and_then(|n| n.get(5))
.is_none());
}
// A non-empty DAWG reports a non-zero node count.
#[test]
fn node_count() {
let dawg = build_owned_dawg(["ABC", "ABD"]).unwrap();
assert!(dawg.node_count() > 0);
}
// Builds from a dictionary file and checks every word is found. Requires
// `../dict-sv.txt` to exist relative to the test's working directory.
#[test]
fn all_words_from_file() {
use std::fs::File;
use std::io::{BufRead, BufReader};
let dict_filename = "../dict-sv.txt";
let file = File::open(dict_filename).unwrap();
let words: Vec<String> = BufReader::new(file)
.lines()
.map(|l| l.unwrap())
.filter(|w| !w.is_empty() && !w.starts_with('#'))
.collect();
let dawg = build_owned_dawg(&words).unwrap();
let root = dawg.root();
for word in &words {
assert!(is_word(root, word), "{}", word);
}
assert!(!is_word(root, "URSINN"));
assert!(!is_word(root, "ÅTMINSTON"));
}
// Compile-time check that the manual `Send`/`Sync` impls apply.
#[test]
fn owned_dawg_is_send_and_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<OwnedDawg<char>>();
}
// `contains` agrees with the manual walk used by `is_word`.
#[test]
fn contains_basic() {
let dawg = build_owned_dawg(["BAKE", "CAKE", "FAKE"]).unwrap();
assert!(dawg.contains("BAKE"));
assert!(dawg.contains("CAKE"));
assert!(!dawg.contains("MAKE"));
assert!(!dawg.contains("BAK"));
}
// Adding to a DAWG built from no words; prefixes of the new word are not words.
#[test]
fn add_word_to_empty() {
let mut dawg = build_owned_dawg::<char, &str>([]).unwrap();
assert!(dawg.add_word("HELLO"));
assert!(dawg.contains("HELLO"));
assert!(!dawg.contains("HELL"));
}
// Adding a word that is already present reports `false`.
#[test]
fn add_word_returns_false_for_duplicate() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE"]).unwrap();
assert!(!dawg.add_word("BAKE"));
assert!(!dawg.add_word("CAKE"));
}
// Adding a word leaves the previously built words intact.
#[test]
fn add_word_preserves_existing() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE"]).unwrap();
dawg.add_word("FAKE");
assert!(dawg.contains("BAKE"));
assert!(dawg.contains("CAKE"));
assert!(dawg.contains("FAKE"));
}
// The new word's path already exists; only the word flag has to be set.
#[test]
fn add_prefix_of_existing() {
let mut dawg = build_owned_dawg(["CART"]).unwrap();
assert!(dawg.add_word("CAR"));
assert!(dawg.contains("CAR"));
assert!(dawg.contains("CART"));
}
// The new word continues past the end of an existing word.
#[test]
fn add_extension_of_existing() {
let mut dawg = build_owned_dawg(["CAR"]).unwrap();
assert!(dawg.add_word("CART"));
assert!(dawg.contains("CAR"));
assert!(dawg.contains("CART"));
}
// Incremental adds do not need sorted order, unlike the bulk builder.
#[test]
fn add_multiple_words() {
let mut dawg = build_owned_dawg::<char, &str>([]).unwrap();
for word in ["FAKE", "CAKE", "BAKE", "LAKE", "MAKE"] {
assert!(dawg.add_word(word));
}
for word in ["FAKE", "CAKE", "BAKE", "LAKE", "MAKE"] {
assert!(dawg.contains(word));
}
assert!(!dawg.contains("SAKE"));
}
// Removing one word leaves the others that share its suffix.
#[test]
fn remove_word_basic() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE", "FAKE"]).unwrap();
assert!(dawg.remove_word("BAKE"));
assert!(!dawg.contains("BAKE"));
assert!(dawg.contains("CAKE"));
assert!(dawg.contains("FAKE"));
}
// Removing an absent word, or a prefix that is not itself a word, reports `false`.
#[test]
fn remove_word_returns_false_for_missing() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE"]).unwrap();
assert!(!dawg.remove_word("FAKE"));
assert!(!dawg.remove_word("BAK"));
}
// A second removal of the same word reports `false`.
#[test]
fn remove_word_returns_false_for_already_removed() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE"]).unwrap();
assert!(dawg.remove_word("BAKE"));
assert!(!dawg.remove_word("BAKE"));
}
// Removing the middle word keeps its neighbours.
#[test]
fn remove_preserves_existing() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE", "FAKE"]).unwrap();
dawg.remove_word("CAKE");
assert!(dawg.contains("BAKE"));
assert!(!dawg.contains("CAKE"));
assert!(dawg.contains("FAKE"));
}
// Removing the only word leaves an empty, non-word root.
#[test]
fn remove_last_word() {
let mut dawg = build_owned_dawg(["HELLO"]).unwrap();
assert!(dawg.remove_word("HELLO"));
assert!(!dawg.contains("HELLO"));
assert!(!dawg.root().is_word());
assert_eq!(dawg.root().child_count(), 0);
}
// Removing every word one by one ends with a childless root.
#[test]
fn remove_all_words() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE", "FAKE"]).unwrap();
assert!(dawg.remove_word("BAKE"));
assert!(dawg.remove_word("CAKE"));
assert!(dawg.remove_word("FAKE"));
assert!(!dawg.contains("BAKE"));
assert!(!dawg.contains("CAKE"));
assert!(!dawg.contains("FAKE"));
assert_eq!(dawg.root().child_count(), 0);
}
// Removing a word that is a prefix of another keeps the longer word.
#[test]
fn remove_prefix_keeps_extension() {
let mut dawg = build_owned_dawg(["CAR", "CART"]).unwrap();
assert!(dawg.remove_word("CAR"));
assert!(!dawg.contains("CAR"));
assert!(dawg.contains("CART"));
}
// Removing the longer word keeps its prefix word.
#[test]
fn remove_extension_keeps_prefix() {
let mut dawg = build_owned_dawg(["CAR", "CART"]).unwrap();
assert!(dawg.remove_word("CART"));
assert!(dawg.contains("CAR"));
assert!(!dawg.contains("CART"));
}
// Adds and removes can be freely mixed.
#[test]
fn interleaved_add_remove() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE"]).unwrap();
dawg.add_word("FAKE");
dawg.remove_word("BAKE");
dawg.add_word("LAKE");
dawg.add_word("MAKE");
dawg.remove_word("CAKE");
assert!(!dawg.contains("BAKE"));
assert!(!dawg.contains("CAKE"));
assert!(dawg.contains("FAKE"));
assert!(dawg.contains("LAKE"));
assert!(dawg.contains("MAKE"));
}
// An added word reuses the existing shared suffix node (same address).
#[test]
fn add_maintains_suffix_sharing() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE"]).unwrap();
dawg.add_word("FAKE");
let root = dawg.root();
let bake_a = root.get('B').unwrap().get('A').unwrap();
let cake_a = root.get('C').unwrap().get('A').unwrap();
let fake_a = root.get('F').unwrap().get('A').unwrap();
assert!(std::ptr::addr_eq(bake_a, cake_a));
assert!(std::ptr::addr_eq(cake_a, fake_a));
}
// After a removal the remaining words still share their suffix node.
#[test]
fn remove_does_not_break_sharing() {
let mut dawg = build_owned_dawg(["BAKE", "CAKE", "FAKE"]).unwrap();
dawg.remove_word("FAKE");
let root = dawg.root();
let bake_a = root.get('B').unwrap().get('A').unwrap();
let cake_a = root.get('C').unwrap().get('A').unwrap();
assert!(std::ptr::addr_eq(bake_a, cake_a));
}
// Suffix sharing also works for nodes that came from the bulk builder.
#[test]
fn add_word_to_prebuilt_dawg_shares_suffix() {
let mut dawg = build_owned_dawg([
"ASUFFIX",
"BSUFFIX",
])
.unwrap();
dawg.add_word("CSUFFIX");
let root = dawg.root();
let a_suffix = root.get('A').unwrap().get('S').unwrap();
let b_suffix = root.get('B').unwrap().get('S').unwrap();
let c_suffix = root.get('C').unwrap().get('S').unwrap();
assert!(std::ptr::addr_eq(a_suffix, b_suffix));
assert!(std::ptr::addr_eq(b_suffix, c_suffix));
}
// Incremental construction registers as many nodes as a bulk build of the
// same word list.
#[test]
fn minimality_matches_fresh_build() {
let mut dawg_inc = build_owned_dawg::<char, &str>([]).unwrap();
let words = ["BAKE", "BAKED", "CAKE", "CAKED", "FAKE", "FAKED"];
for word in words {
dawg_inc.add_word(word);
}
let dawg_fresh = build_owned_dawg(words).unwrap();
let inc_state = dawg_inc.mutation_state.as_ref().unwrap();
let fresh_register = {
let mut dawg = dawg_fresh;
dawg.ensure_mutation_state();
dawg.mutation_state.unwrap().register.len()
};
assert_eq!(inc_state.register.len(), fresh_register);
}
// Adding and then removing the same words leaves a childless root.
#[test]
fn add_then_remove_returns_to_empty() {
let mut dawg = build_owned_dawg::<char, &str>([]).unwrap();
dawg.add_word("HELLO");
dawg.add_word("WORLD");
dawg.remove_word("HELLO");
dawg.remove_word("WORLD");
assert!(!dawg.contains("HELLO"));
assert!(!dawg.contains("WORLD"));
assert_eq!(dawg.root().child_count(), 0);
}
// Mutation works for `u8` symbols, with words given as `Vec` or slice.
#[test]
fn generic_u8_add_remove() {
let mut dawg: OwnedDawg<u8> =
build_owned_dawg([vec![1, 2, 3], vec![1, 2, 4]]).unwrap();
dawg.add_word(vec![2, 3, 4]);
assert!(dawg.contains([2u8, 3, 4].as_slice()));
assert!(dawg.contains([1u8, 2, 3].as_slice()));
dawg.remove_word([1u8, 2, 3].as_slice());
assert!(!dawg.contains([1u8, 2, 3].as_slice()));
assert!(dawg.contains([1u8, 2, 4].as_slice()));
}
// Released nodes are recorded on the free list.
#[test]
fn remove_populates_free_list() {
let mut dawg = build_owned_dawg(["HELLO"]).unwrap();
dawg.remove_word("HELLO");
let state = dawg.mutation_state.as_ref().unwrap();
assert!(
!state.free_list.is_empty(),
"free-list should have entries after removing the only word"
);
}
// A later add takes slots from the free list.
#[test]
fn add_after_remove_consumes_free_list() {
let mut dawg = build_owned_dawg(["HELLO"]).unwrap();
dawg.remove_word("HELLO");
let free_before = dawg.mutation_state.as_ref().unwrap().free_list.len();
assert!(free_before > 0);
dawg.add_word("WORLD");
let free_after = dawg.mutation_state.as_ref().unwrap().free_list.len();
assert!(
free_after < free_before,
"free-list should shrink after add (was {free_before}, now {free_after})"
);
assert!(dawg.contains("WORLD"));
}
// With enough free slots available, an add does not allocate in the arena.
#[test]
fn arena_does_not_grow_when_free_list_has_nodes() {
let mut dawg = build_owned_dawg(["ABCDEFGH", "XY"]).unwrap();
dawg.remove_word("ABCDEFGH");
let free_count = dawg.mutation_state.as_ref().unwrap().free_list.len();
let arena_before = dawg.node_count();
assert!(
free_count >= 3,
"need at least 3 free slots, got {free_count}"
);
dawg.add_word("ZW");
assert!(dawg.contains("ZW"));
assert!(dawg.contains("XY"));
let arena_after = dawg.node_count();
assert_eq!(
arena_before, arena_after,
"arena should not grow when free-list provides all nodes \
(before={arena_before}, after={arena_after})"
);
}
// Words stored in recycled slots are found, and the removed ones stay gone.
#[test]
fn free_list_reuse_produces_correct_dawg() {
let mut dawg = build_owned_dawg(["ALPHA", "BRAVO", "CHARLIE"]).unwrap();
dawg.remove_word("ALPHA");
dawg.remove_word("BRAVO");
dawg.remove_word("CHARLIE");
let free_after_remove = dawg.mutation_state.as_ref().unwrap().free_list.len();
assert!(free_after_remove > 0);
dawg.add_word("DELTA");
dawg.add_word("ECHO");
dawg.add_word("FOXTROT");
assert!(dawg.contains("DELTA"));
assert!(dawg.contains("ECHO"));
assert!(dawg.contains("FOXTROT"));
assert!(!dawg.contains("ALPHA"));
assert!(!dawg.contains("BRAVO"));
assert!(!dawg.contains("CHARLIE"));
let free_after_add = dawg.mutation_state.as_ref().unwrap().free_list.len();
assert!(
free_after_add < free_after_remove,
"free-list should shrink after adding words (was {free_after_remove}, now {free_after_add})"
);
}
// Repeating the same add/remove cycle reaches a steady state: the arena has
// the same size after ten cycles as after five.
#[test]
fn repeated_add_remove_cycles_reuse_nodes() {
let mut dawg = build_owned_dawg::<char, &str>([]).unwrap();
for _ in 0..5 {
dawg.add_word("TESTING");
dawg.remove_word("TESTING");
}
let arena_after_cycles = dawg.node_count();
for _ in 0..5 {
dawg.add_word("TESTING");
dawg.remove_word("TESTING");
}
let arena_after_more_cycles = dawg.node_count();
assert_eq!(
arena_after_cycles, arena_after_more_cycles,
"arena should not grow when repeatedly adding/removing the same word \
(after 5 cycles: {arena_after_cycles}, after 10: {arena_after_more_cycles})"
);
}
}