use crate::error::{Result, ZiporaError};
use crate::fsa::cache::{FsaCache, FsaCacheConfig};
use crate::fsa::traits::{FiniteStateAutomaton, Trie, TrieStats, StatisticsProvider};
use crate::StateId;
use crate::memory::SecureMemoryPool;
use crate::succinct::rank_select::{RankSelectInterleaved256, RankSelectOps};
use crate::succinct::BitVector;
use std::collections::HashMap;
use std::sync::Arc;
/// Construction options for `NestedTrieDawg`.
#[derive(Debug, Clone)]
pub struct DawgConfig {
/// When true, `with_config` selects `TerminalStrategy::RankSelect` (a rank/select index is
/// built over the terminal flags after `build_from_keys`); otherwise `TerminalStrategy::BitVector`.
pub use_rank_select: bool,
/// When true, an `FsaCache` is created and every newly added state is recorded in it.
pub enable_cache: bool,
/// Configuration handed to `FsaCache::with_config` when `enable_cache` is set.
pub cache_config: FsaCacheConfig,
/// Upper bound on the number of states; adding a state beyond it fails with
/// "Maximum states exceeded". The bound applies to the trie *before* minimization.
pub max_states: usize,
/// NOTE(review): not read by any code visible in this file — confirm whether it has an effect.
pub compressed_storage: bool,
}
impl Default for DawgConfig {
    /// Balanced defaults: rank/select terminals, caching on, compressed storage on,
    /// and a limit of one million states.
    fn default() -> Self {
        let cache_config = FsaCacheConfig::default();
        Self {
            max_states: 1_000_000,
            cache_config,
            compressed_storage: true,
            enable_cache: true,
            use_rank_select: true,
        }
    }
}
impl DawgConfig {
    /// Preset for constrained memory: the `memory_efficient` cache profile and a
    /// 100 000-state limit; all remaining fields come from `Default`.
    pub fn memory_efficient() -> Self {
        let cache_config = FsaCacheConfig::memory_efficient();
        Self {
            cache_config,
            compressed_storage: true,
            max_states: 100_000,
            ..Self::default()
        }
    }

    /// Preset for throughput: the `large` cache profile, rank/select terminals and a
    /// ten-million-state limit; all remaining fields come from `Default`.
    pub fn performance_optimized() -> Self {
        let cache_config = FsaCacheConfig::large();
        Self {
            cache_config,
            use_rank_select: true,
            max_states: 10_000_000,
            ..Self::default()
        }
    }
}
/// How terminal (accepting) flags are indexed after construction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TerminalStrategy {
/// `build_from_keys` builds a `RankSelectInterleaved256` over the per-state terminal flags.
RankSelect,
/// Selected when `DawgConfig::use_rank_select` is false; `build_from_keys` builds no
/// auxiliary terminal index in this case.
BitVector,
/// Not selected by `with_config` in this file.
Inline,
}
/// One automaton state packed into two 32-bit words.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DawgState {
    /// Child base index (the builder in this file always stores 0).
    pub child_base: u32,
    /// Packed word: bit 31 = terminal, bit 30 = final, bits 0..=23 = parent state id.
    pub flags_and_parent: u32,
}

impl DawgState {
    const TERMINAL_BIT: u32 = 0x8000_0000;
    const FINAL_BIT: u32 = 0x4000_0000;
    const PARENT_MASK: u32 = 0x00FF_FFFF;

    /// Packs a state record.
    ///
    /// Only the low 24 bits of `parent` are stored; larger ids are silently truncated.
    pub fn new(child_base: u32, parent: u32, is_terminal: bool, is_final: bool) -> Self {
        let mut packed = Self {
            child_base,
            flags_and_parent: parent & Self::PARENT_MASK,
        };
        packed.set_terminal(is_terminal);
        packed.set_final(is_final);
        packed
    }

    /// Parent id held in the low 24 bits.
    pub fn parent(&self) -> u32 {
        self.flags_and_parent & Self::PARENT_MASK
    }

    /// Whether the terminal (accepting) bit is set.
    #[inline]
    pub fn is_terminal(&self) -> bool {
        (self.flags_and_parent & Self::TERMINAL_BIT) != 0
    }

    /// Whether the final bit is set.
    #[inline]
    pub fn is_final(&self) -> bool {
        (self.flags_and_parent & Self::FINAL_BIT) != 0
    }

    /// The top byte of the packed word (0x80 = terminal, 0x40 = final).
    pub fn flags(&self) -> u8 {
        (self.flags_and_parent >> 24) as u8
    }

    /// Sets or clears the terminal bit, leaving every other bit untouched.
    pub fn set_terminal(&mut self, is_terminal: bool) {
        self.write_bit(Self::TERMINAL_BIT, is_terminal);
    }

    /// Sets or clears the final bit, leaving every other bit untouched.
    pub fn set_final(&mut self, is_final: bool) {
        self.write_bit(Self::FINAL_BIT, is_final);
    }

    /// Sets (`on == true`) or clears the bits selected by `mask`.
    fn write_bit(&mut self, mask: u32, on: bool) {
        if on {
            self.flags_and_parent |= mask;
        } else {
            self.flags_and_parent &= !mask;
        }
    }
}
/// Hash-map based edge store: `(source state, input byte) -> target state`.
#[derive(Debug, Clone)]
pub struct TransitionTable {
    /// Every edge keyed by its source state and input byte.
    pub sparse_table: HashMap<(u32, u8), u32>,
    /// State count given to `new`; recorded but not consulted by the methods below.
    pub num_states: u32,
}

impl TransitionTable {
    /// Creates an empty table. The dense-layout flag is accepted but ignored: only the
    /// sparse hash-map representation exists.
    pub fn new(num_states: u32, _use_dense: bool) -> Self {
        Self {
            num_states,
            sparse_table: HashMap::new(),
        }
    }

    /// Inserts or overwrites the edge `from_state --symbol--> to_state`. Never fails.
    #[inline]
    pub fn add_transition(&mut self, from_state: u32, symbol: u8, to_state: u32) -> Result<()> {
        let key = (from_state, symbol);
        self.sparse_table.insert(key, to_state);
        Ok(())
    }

    /// Target of the edge labelled `symbol` leaving `from_state`, if any.
    #[inline]
    pub fn get_transition(&self, from_state: u32, symbol: u8) -> Option<u32> {
        let key = (from_state, symbol);
        self.sparse_table.get(&key).copied()
    }

    /// All edges leaving `from_state`, ordered by symbol.
    ///
    /// Cost is linear in the total number of edges because the whole map is scanned.
    pub fn get_outgoing_transitions(&self, from_state: u32) -> Vec<(u8, u32)> {
        let mut edges = Vec::new();
        for (&(source, symbol), &target) in &self.sparse_table {
            if source == from_state {
                edges.push((symbol, target));
            }
        }
        // Symbols are unique per source (they are part of the map key), so an unstable sort
        // still yields a deterministic order.
        edges.sort_unstable_by_key(|&(symbol, _)| symbol);
        edges
    }

    /// Payload-only size estimate: key + value bytes per entry, excluding hash-table overhead.
    #[inline]
    pub fn memory_usage(&self) -> usize {
        let entry_bytes = std::mem::size_of::<(u32, u8)>() + std::mem::size_of::<u32>();
        entry_bytes * self.sparse_table.len()
    }
}
/// Byte-keyed automaton: keys are inserted into a trie, after which states with identical
/// signatures are merged (see `convert_to_dawg`) to form a DAWG.
pub struct NestedTrieDawg {
/// Options supplied at construction.
config: DawgConfig,
/// State records indexed by state id.
states: Vec<DawgState>,
/// `(state, byte) -> state` edges.
transitions: TransitionTable,
/// Chosen from `config.use_rank_select` in `with_config`.
terminal_strategy: TerminalStrategy,
/// Terminal flag per state; populated by `build_terminal_rank_select`.
terminal_bits: Option<BitVector>,
/// Rank/select index over `terminal_bits`; consulted by `state_to_word_id` when present.
terminal_rank_select: Option<RankSelectInterleaved256>,
/// Present when `config.enable_cache` is set; `add_state` records every new state in it.
cache: Option<FsaCache>,
/// Id of the start state.
root_state: u32,
/// Key count reported by `len()` and the statistics.
num_keys: usize,
/// Created in `with_config`; NOTE(review): not otherwise used by the code visible here.
memory_pool: Option<Arc<SecureMemoryPool>>,
}
impl NestedTrieDawg {
    /// Creates an empty automaton using `DawgConfig::default()`.
    ///
    /// # Errors
    /// Propagates failures from creating the state cache or the secure memory pool.
    pub fn new() -> Result<Self> {
        Self::with_config(DawgConfig::default())
    }

    /// Creates an empty automaton from `config`.
    ///
    /// The terminal strategy is `RankSelect` when `config.use_rank_select` is set and
    /// `BitVector` otherwise; a state cache is created only when `config.enable_cache` is set.
    ///
    /// # Errors
    /// Propagates failures from `FsaCache::with_config` and `SecureMemoryPool::new`.
    pub fn with_config(config: DawgConfig) -> Result<Self> {
        let cache = if config.enable_cache {
            Some(FsaCache::with_config(config.cache_config.clone())?)
        } else {
            None
        };
        let memory_pool = Some(SecureMemoryPool::new(
            crate::memory::SecurePoolConfig::small_secure(),
        )?);
        let transitions = TransitionTable::new(config.max_states as u32, false);
        let terminal_strategy = if config.use_rank_select {
            TerminalStrategy::RankSelect
        } else {
            TerminalStrategy::BitVector
        };
        Ok(Self {
            config,
            states: Vec::new(),
            transitions,
            terminal_strategy,
            terminal_bits: None,
            terminal_rank_select: None,
            cache,
            root_state: 0,
            num_keys: 0,
            memory_pool,
        })
    }

    /// Replaces the contents with `keys` and merges equivalent states into a DAWG.
    ///
    /// Keys are first inserted into a plain trie. `convert_to_dawg` then merges states with
    /// identical right languages. Duplicate keys are counted once.
    ///
    /// # Errors
    /// Fails with `invalid_data` if the intermediate trie needs more than
    /// `config.max_states` states. Propagates cache and rank/select construction errors.
    pub fn build_from_keys<I, K>(&mut self, keys: I) -> Result<()>
    where
        I: IntoIterator<Item = K>,
        K: AsRef<[u8]>,
    {
        self.clear();
        self.root_state = self.add_state(0, false, false)?;
        for key in keys {
            self.insert_key(key.as_ref())?;
        }
        self.convert_to_dawg()?;
        if self.terminal_strategy == TerminalStrategy::RankSelect {
            self.build_terminal_rank_select()?;
        }
        Ok(())
    }

    /// Adds `key`, counting it only if it was not already accepted.
    ///
    /// After minimization several paths can share a state. Writing through a shared state
    /// (adding an edge or setting its terminal bit) would add keys to every other path
    /// through it, so shared states on `key`'s path are cloned first (copy-on-write).
    /// A newly added key invalidates the terminal rank/select index, which no longer
    /// matches `states`; `state_to_word_id` then falls back to a linear count.
    fn insert_key(&mut self, key: &[u8]) -> Result<()> {
        if self.states.is_empty() {
            // Without a root, the first `add_state` below would return id 0 == `root_state`
            // and create a self-loop on the root.
            self.root_state = self.add_state(0, false, false)?;
        }
        if let Some(end) = self.follow_path(key) {
            if self.states.get(end as usize).map_or(false, |s| s.is_terminal()) {
                return Ok(()); // already present: nothing changes, not double-counted
            }
        }

        let mut in_degree = self.shared_state_in_degrees();
        let mut current = self.root_state;
        for &symbol in key {
            current = match self.transitions.get_transition(current, symbol) {
                None => {
                    let created = self.add_state(current, false, false)?;
                    self.transitions.add_transition(current, symbol, created)?;
                    if let Some(degrees) = in_degree.as_mut() {
                        degrees.push(1);
                    }
                    created
                }
                Some(next) => match in_degree.as_mut() {
                    Some(degrees) if degrees.get(next as usize).map_or(false, |&d| d > 1) => {
                        self.clone_shared_state(current, symbol, next, degrees)?
                    }
                    _ => next,
                },
            };
        }

        if let Some(end_state) = self.states.get_mut(current as usize) {
            if !end_state.is_terminal() {
                end_state.set_terminal(true);
                self.num_keys += 1;
                self.terminal_bits = None;
                self.terminal_rank_select = None;
            }
        }
        Ok(())
    }

    /// Follows `key` from the root. Returns `None` as soon as a symbol has no transition.
    fn follow_path(&self, key: &[u8]) -> Option<u32> {
        key.iter().try_fold(self.root_state, |state, &symbol| {
            self.transitions.get_transition(state, symbol)
        })
    }

    /// Per-state in-degrees when some state may be shared; `None` when the graph is a tree.
    ///
    /// This impl keeps two invariants: the root has no incoming edge, and every other state
    /// has at least one. Under them, fewer edges than states means each non-root state has
    /// exactly one parent. Plain trie building (`build_from_keys`) therefore skips the
    /// full-table scan.
    fn shared_state_in_degrees(&self) -> Option<Vec<u32>> {
        if self.transitions.sparse_table.len() < self.states.len() {
            return None;
        }
        let mut in_degree = vec![0u32; self.states.len()];
        for &target in self.transitions.sparse_table.values() {
            if let Some(slot) = in_degree.get_mut(target as usize) {
                *slot += 1;
            }
        }
        Some(in_degree)
    }

    /// Redirects `parent --symbol--> shared` to a fresh copy of `shared`, which receives the
    /// same flags and outgoing edges. Returns the copy and keeps `in_degree` in sync.
    fn clone_shared_state(
        &mut self,
        parent: u32,
        symbol: u8,
        shared: u32,
        in_degree: &mut Vec<u32>,
    ) -> Result<u32> {
        let original = self.states[shared as usize];
        let copy = self.add_state(parent, original.is_terminal(), original.is_final())?;
        in_degree.push(1);
        for (out_symbol, target) in self.transitions.get_outgoing_transitions(shared) {
            self.transitions.add_transition(copy, out_symbol, target)?;
            if let Some(d) = in_degree.get_mut(target as usize) {
                *d += 1;
            }
        }
        self.transitions.add_transition(parent, symbol, copy)?;
        if let Some(d) = in_degree.get_mut(shared as usize) {
            *d = d.saturating_sub(1);
        }
        Ok(copy)
    }

    /// Appends a state and returns its id.
    ///
    /// # Errors
    /// `invalid_data("Maximum states exceeded")` once `config.max_states` states exist, or
    /// when the id would not fit in `u32`. The comparison is done in `usize`, so a large
    /// `max_states` is not truncated by a cast.
    fn add_state(&mut self, parent: u32, is_terminal: bool, is_final: bool) -> Result<u32> {
        let next_index = self.states.len();
        if next_index >= self.config.max_states || next_index > u32::MAX as usize {
            return Err(ZiporaError::invalid_data("Maximum states exceeded"));
        }
        let state_id = next_index as u32;
        let state = DawgState::new(0, parent, is_terminal, is_final);
        self.states.push(state);
        if let Some(ref mut cache) = self.cache {
            cache.cache_state(parent, 0, is_terminal)?;
        }
        Ok(state_id)
    }

    /// Merges states that accept the same suffix language, then renumbers densely.
    ///
    /// Trie states are created parent-first, so every child id is larger than its parent's.
    /// Walking ids downward therefore canonicalizes children before their parents. A
    /// state's signature refers to its children's *canonical* ids, so whole equivalent
    /// subtrees collapse, not just leaves.
    fn convert_to_dawg(&mut self) -> Result<()> {
        let num_states = self.states.len();
        // One pass over the table. Querying it per state (`get_outgoing_transitions`)
        // would scan every edge for every state: O(states * transitions).
        let mut outgoing: Vec<Vec<(u8, u32)>> = vec![Vec::new(); num_states];
        for (&(from, symbol), &to) in &self.transitions.sparse_table {
            if let Some(edges) = outgoing.get_mut(from as usize) {
                edges.push((symbol, to));
            }
        }

        let mut state_signatures: HashMap<Vec<u8>, u32> = HashMap::new();
        let mut state_mapping: HashMap<u32, u32> = HashMap::with_capacity(num_states);
        for state_id in (0..num_states as u32).rev() {
            let edges = &mut outgoing[state_id as usize];
            edges.sort_unstable_by_key(|&(symbol, _)| symbol);
            let signature = self.compute_state_signature(state_id, edges, &state_mapping);
            let representative = *state_signatures.entry(signature).or_insert(state_id);
            state_mapping.insert(state_id, representative);
        }
        self.remap_transitions(&state_mapping)?;
        self.compact_states(&state_mapping)?;
        Ok(())
    }

    /// Encodes the right language of `state_id` from its symbol-sorted `edges`.
    ///
    /// Layout: the flag byte, then 5 bytes per edge (symbol, canonical target as u32 LE).
    /// The fixed width makes the encoding unambiguous. A target not yet canonicalized keeps
    /// its raw id; no other state can map to it yet, so this never produces a false merge.
    fn compute_state_signature(
        &self,
        state_id: u32,
        edges: &[(u8, u32)],
        state_mapping: &HashMap<u32, u32>,
    ) -> Vec<u8> {
        let mut signature = Vec::with_capacity(1 + edges.len() * 5);
        signature.push(self.states[state_id as usize].flags());
        for &(symbol, target) in edges {
            let canonical = state_mapping.get(&target).copied().unwrap_or(target);
            signature.push(symbol);
            signature.extend_from_slice(&canonical.to_le_bytes());
        }
        signature
    }

    /// Rewrites every edge so both endpoints are their canonical representatives.
    fn remap_transitions(&mut self, state_mapping: &HashMap<u32, u32>) -> Result<()> {
        let old_transitions = std::mem::replace(
            &mut self.transitions,
            TransitionTable::new(self.states.len() as u32, false),
        );
        for (&(from_state, symbol), &to_state) in &old_transitions.sparse_table {
            if let (Some(&new_from), Some(&new_to)) =
                (state_mapping.get(&from_state), state_mapping.get(&to_state))
            {
                self.transitions.add_transition(new_from, symbol, new_to)?;
            }
        }
        Ok(())
    }

    /// Keeps only representative states (mapped to themselves) and renumbers them densely
    /// in id order, rewriting edges and the root accordingly.
    ///
    /// `parent` fields keep their pre-compaction ids. After merging, a state can have
    /// several parents, so those fields are not meaningful afterwards.
    fn compact_states(&mut self, state_mapping: &HashMap<u32, u32>) -> Result<()> {
        let mut new_states = Vec::new();
        let mut compaction_mapping: HashMap<u32, u32> = HashMap::new();
        for old_id in 0..self.states.len() as u32 {
            if state_mapping.get(&old_id) == Some(&old_id) {
                compaction_mapping.insert(old_id, new_states.len() as u32);
                new_states.push(self.states[old_id as usize]);
            }
        }
        let old_transitions = std::mem::replace(
            &mut self.transitions,
            TransitionTable::new(new_states.len() as u32, false),
        );
        for (&(from_state, symbol), &to_state) in &old_transitions.sparse_table {
            if let (Some(&compact_from), Some(&compact_to)) =
                (compaction_mapping.get(&from_state), compaction_mapping.get(&to_state))
            {
                self.transitions.add_transition(compact_from, symbol, compact_to)?;
            }
        }
        // Resolve the root through its representative first, so a merged root is not lost
        // to the fallback.
        let canonical_root = state_mapping
            .get(&self.root_state)
            .copied()
            .unwrap_or(self.root_state);
        self.root_state = compaction_mapping.get(&canonical_root).copied().unwrap_or(0);
        self.states = new_states;
        Ok(())
    }

    /// Builds the terminal bit vector (one bit per state, in id order) and its rank/select index.
    fn build_terminal_rank_select(&mut self) -> Result<()> {
        let mut terminal_bits = BitVector::new();
        for state in &self.states {
            terminal_bits.push(state.is_terminal())?;
        }
        self.terminal_rank_select = Some(RankSelectInterleaved256::new(terminal_bits.clone())?);
        self.terminal_bits = Some(terminal_bits);
        Ok(())
    }

    /// Removes all states, edges, keys and terminal indexes, and clears the cache.
    pub fn clear(&mut self) {
        self.states.clear();
        self.transitions = TransitionTable::new(self.config.max_states as u32, false);
        self.terminal_bits = None;
        self.terminal_rank_select = None;
        self.root_state = 0;
        self.num_keys = 0;
        if let Some(ref mut cache) = self.cache {
            cache.clear();
        }
    }

    /// Index of `state` among terminal states (the count of terminal states with a smaller
    /// id). Returns `None` if `state` is out of range or not terminal.
    ///
    /// After minimization several keys can end in the same terminal state, so this is a
    /// per-terminal-state index, not a unique per-key id. The rank/select path assumes
    /// `rank1(i)` counts set bits before position `i`, matching the fallback — confirm
    /// against `RankSelectOps`.
    #[inline]
    pub fn state_to_word_id(&self, state: u32) -> Option<usize> {
        let idx = state as usize;
        if idx >= self.states.len() || !self.states[idx].is_terminal() {
            return None;
        }
        if let Some(ref rs) = self.terminal_rank_select {
            Some(rs.rank1(idx))
        } else {
            Some(self.states[..=idx].iter().filter(|s| s.is_terminal()).count() - 1)
        }
    }

    /// Number of states currently stored.
    #[inline]
    pub fn total_states(&self) -> usize {
        self.states.len()
    }

    /// Size and efficiency figures.
    ///
    /// `memory_usage` is an estimate: state records, plus transition payload bytes, plus the
    /// shallow size of the terminal index structs (their heap buffers are not counted).
    pub fn statistics(&self) -> DawgStats {
        let transition_memory = self.transitions.memory_usage();
        let state_memory = self.states.len() * std::mem::size_of::<DawgState>();
        let terminal_memory = self
            .terminal_bits
            .as_ref()
            .map(|_bv| std::mem::size_of::<BitVector>())
            .unwrap_or(0)
            + self
                .terminal_rank_select
                .as_ref()
                .map(|_rs| std::mem::size_of::<RankSelectInterleaved256>())
                .unwrap_or(0);
        DawgStats {
            num_states: self.states.len(),
            num_transitions: self.transitions.sparse_table.len(),
            num_keys: self.num_keys,
            memory_usage: state_memory + transition_memory + terminal_memory,
            compression_ratio: if self.num_keys > 0 {
                self.states.len() as f64 / self.num_keys as f64
            } else {
                0.0
            },
            cache_hit_ratio: self
                .cache
                .as_ref()
                .map(|c| c.stats().hit_ratio())
                .unwrap_or(0.0),
        }
    }
}
/// Snapshot of size and efficiency figures produced by `NestedTrieDawg::statistics`.
#[derive(Debug, Clone)]
pub struct DawgStats {
/// Number of stored states.
pub num_states: usize,
/// Number of `(state, byte)` edges.
pub num_transitions: usize,
/// Key count tracked by the builder.
pub num_keys: usize,
/// Estimated bytes: state records + transition payload (no hash-table overhead) + the
/// shallow size of the terminal index structs.
pub memory_usage: usize,
/// States per key (`num_states / num_keys`), or 0.0 when there are no keys; lower means
/// more sharing.
pub compression_ratio: f64,
/// Hit ratio reported by the state cache, or 0.0 when caching is disabled.
pub cache_hit_ratio: f64,
}
impl FiniteStateAutomaton for NestedTrieDawg {
    /// Start state of the automaton.
    fn root(&self) -> StateId {
        self.root_state as StateId
    }

    /// Whether `state` accepts. Backed by the terminal bit (not the state's final bit);
    /// ids outside the state table never accept.
    fn is_final(&self, state: StateId) -> bool {
        match self.states.get(state as usize) {
            Some(record) => record.is_terminal(),
            None => false,
        }
    }

    /// Target of the edge labelled `symbol` leaving `state`, if any.
    fn transition(&self, state: StateId, symbol: u8) -> Option<StateId> {
        let target = self.transitions.get_transition(state as u32, symbol)?;
        Some(target as StateId)
    }

    /// All edges leaving `state`, ordered by symbol.
    fn transitions(&self, state: StateId) -> Vec<(u8, StateId)> {
        let mut edges = Vec::new();
        for (symbol, target) in self.transitions.get_outgoing_transitions(state as u32) {
            edges.push((symbol, target as StateId));
        }
        edges
    }
}
impl Trie for NestedTrieDawg {
fn insert(&mut self, key: &[u8]) -> Result<StateId> {
let _old_len = self.num_keys;
self.insert_key(key)?;
let mut current_state = self.root_state;
for &symbol in key {
if let Some(next_state) = self.transitions.get_transition(current_state, symbol) {
current_state = next_state;
} else {
return Err(ZiporaError::invalid_data("Failed to find inserted key"));
}
}
Ok(current_state as StateId)
}
fn len(&self) -> usize {
self.num_keys
}
fn contains(&self, key: &[u8]) -> bool {
let mut current_state = self.root_state;
for &symbol in key {
if let Some(next_state) = self.transitions.get_transition(current_state, symbol) {
current_state = next_state;
} else {
return false;
}
}
(current_state as usize) < self.states.len() && self.states[current_state as usize].is_terminal()
}
fn is_empty(&self) -> bool {
self.num_keys == 0
}
}
impl StatisticsProvider for NestedTrieDawg {
fn stats(&self) -> TrieStats {
let dawg_stats = self.statistics();
TrieStats {
num_states: dawg_stats.num_states,
num_keys: dawg_stats.num_keys,
num_transitions: dawg_stats.num_transitions,
max_depth: 0, avg_depth: 0.0, memory_usage: dawg_stats.memory_usage,
bits_per_key: if dawg_stats.num_keys > 0 {
(dawg_stats.memory_usage * 8) as f64 / dawg_stats.num_keys as f64
} else {
0.0
},
}
}
}
// Unit tests for the packed state record, the transition table and the trie/DAWG builder.
#[cfg(test)]
mod tests {
use super::*;
// Packing: child base, 24-bit parent and the terminal/final flag bits.
#[test]
fn test_dawg_state_creation() {
let state = DawgState::new(100, 50, true, false);
assert_eq!(state.child_base, 100);
assert_eq!(state.parent(), 50);
assert!(state.is_terminal());
assert!(!state.is_final());
}
// Flag setters flip only their own bit.
#[test]
fn test_dawg_state_flags() {
let mut state = DawgState::new(100, 50, false, false);
assert!(!state.is_terminal());
assert!(!state.is_final());
state.set_terminal(true);
assert!(state.is_terminal());
state.set_final(true);
assert!(state.is_final());
}
// `TransitionTable::new` ignores its dense flag (`_use_dense`), so this exercises the same
// hash-map storage as the sparse test below.
#[test]
fn test_transition_table_dense() {
let mut table = TransitionTable::new(10, true);
table.add_transition(0, b'a', 1).unwrap();
table.add_transition(0, b'b', 2).unwrap();
assert_eq!(table.get_transition(0, b'a'), Some(1));
assert_eq!(table.get_transition(0, b'b'), Some(2));
assert_eq!(table.get_transition(0, b'c'), None);
let transitions = table.get_outgoing_transitions(0);
assert_eq!(transitions.len(), 2);
assert!(transitions.contains(&(b'a', 1)));
assert!(transitions.contains(&(b'b', 2)));
}
#[test]
fn test_transition_table_sparse() {
let mut table = TransitionTable::new(10, false);
table.add_transition(0, b'a', 1).unwrap();
table.add_transition(0, b'b', 2).unwrap();
assert_eq!(table.get_transition(0, b'a'), Some(1));
assert_eq!(table.get_transition(0, b'b'), Some(2));
assert_eq!(table.get_transition(0, b'c'), None);
let transitions = table.get_outgoing_transitions(0);
assert_eq!(transitions.len(), 2);
assert!(transitions.contains(&(b'a', 1)));
assert!(transitions.contains(&(b'b', 2)));
}
// Membership for keys that are prefixes of each other, plus absent keys and non-key prefixes.
#[test]
fn test_nested_trie_dawg_basic() {
let mut dawg = NestedTrieDawg::new().unwrap();
let keys = vec!["cat".as_bytes(), "car".as_bytes(), "card".as_bytes(), "care".as_bytes(), "careful".as_bytes()];
dawg.build_from_keys(keys).unwrap();
assert!(dawg.contains(b"cat"));
assert!(dawg.contains(b"car"));
assert!(dawg.contains(b"card"));
assert!(dawg.contains(b"care"));
assert!(dawg.contains(b"careful"));
assert!(!dawg.contains(b"dog"));
assert!(!dawg.contains(b"ca"));
let stats = dawg.statistics();
assert_eq!(stats.num_keys, 5);
assert!(stats.num_states > 0);
assert!(stats.memory_usage > 0);
}
// Keys sharing a long common prefix are all still found.
#[test]
fn test_nested_trie_dawg_prefix_search() {
let mut dawg = NestedTrieDawg::new().unwrap();
let keys = vec!["computer".as_bytes(), "computation".as_bytes(), "compute".as_bytes(), "computing".as_bytes()];
dawg.build_from_keys(keys).unwrap();
assert!(dawg.contains(b"computer"));
assert!(dawg.contains(b"computation"));
assert!(dawg.contains(b"compute"));
assert!(dawg.contains(b"computing"));
}
// NOTE(review): `longest_prefix` is not defined in this file; presumably a provided method
// of the `Trie` trait — confirm. Expected: length of the longest stored key that prefixes the input.
#[test]
fn test_nested_trie_dawg_longest_prefix() {
let mut dawg = NestedTrieDawg::new().unwrap();
let keys = vec!["app".as_bytes(), "apple".as_bytes(), "application".as_bytes()];
dawg.build_from_keys(keys).unwrap();
assert_eq!(dawg.longest_prefix(b"app"), Some(3));
assert_eq!(dawg.longest_prefix(b"apple"), Some(5));
assert_eq!(dawg.longest_prefix(b"application"), Some(11));
assert_eq!(dawg.longest_prefix(b"applications"), Some(11));
assert_eq!(dawg.longest_prefix(b"ap"), None);
}
// Keys with common suffixes should need fewer states than an unshared trie.
#[test]
fn test_dawg_compression() {
let mut dawg = NestedTrieDawg::new().unwrap();
let keys = vec![
"ending".as_bytes(), "reading".as_bytes(), "heading".as_bytes(), "leading".as_bytes(),
"sending".as_bytes(), "bending".as_bytes(), "pending".as_bytes(), "mending".as_bytes()
];
dawg.build_from_keys(keys).unwrap();
let stats = dawg.statistics();
assert_eq!(stats.num_keys, 8);
// Upper bounds only: fewer than 7 states per key, and compression ratio below 8.
assert!(stats.num_states < stats.num_keys * 7); assert!(stats.compression_ratio < 8.0); }
// Presets: the memory-efficient preset allows fewer states than the performance preset.
#[test]
fn test_dawg_config_variants() {
let memory_config = DawgConfig::memory_efficient();
let performance_config = DawgConfig::performance_optimized();
assert!(memory_config.max_states < performance_config.max_states);
assert!(memory_config.compressed_storage);
}
// `clear` resets the key count back to empty.
#[test]
fn test_dawg_empty_and_clear() {
let mut dawg = NestedTrieDawg::new().unwrap();
assert!(dawg.is_empty());
assert_eq!(dawg.len(), 0);
dawg.build_from_keys(vec![b"test"]).unwrap();
assert!(!dawg.is_empty());
assert_eq!(dawg.len(), 1);
dawg.clear();
assert!(dawg.is_empty());
assert_eq!(dawg.len(), 0);
}
}