use crate::StateId;
use crate::error::Result;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Read-only view of a deterministic automaton whose input symbols are bytes.
///
/// Implementors supply the four primitive operations; `accepts` and
/// `longest_prefix` are derived from them and can be overridden.
pub trait FiniteStateAutomaton {
    /// Returns the state from which every traversal begins.
    fn root(&self) -> StateId;

    /// Reports whether `state` is an accepting (final) state.
    fn is_final(&self, state: StateId) -> bool;

    /// Follows the edge labelled `symbol` out of `state`.
    ///
    /// Returns `None` when no such edge exists.
    fn transition(&self, state: StateId, symbol: u8) -> Option<StateId>;

    /// Lists edges leaving `state` as `(symbol, target)` pairs.
    fn transitions(&self, state: StateId) -> Vec<(u8, StateId)>;

    /// Returns `true` when consuming all of `input` from the root ends in a
    /// final state.
    ///
    /// A missing transition anywhere along the way rejects the input.
    fn accepts(&self, input: &[u8]) -> bool {
        let walk = input
            .iter()
            .try_fold(self.root(), |current, &byte| self.transition(current, byte));
        match walk {
            Some(end) => self.is_final(end),
            None => false,
        }
    }

    /// Returns the length of the longest prefix of `input` that ends in a
    /// final state, or `None` when no prefix (including the empty one) does.
    fn longest_prefix(&self, input: &[u8]) -> Option<usize> {
        let mut current = self.root();
        let mut best: Option<usize> = None;
        let mut consumed = 0usize;
        loop {
            // Each final state reached marks a matching prefix of `consumed` bytes.
            if self.is_final(current) {
                best = Some(consumed);
            }
            // Stop at the end of the input or at the first byte without an edge.
            let step = input
                .get(consumed)
                .and_then(|&byte| self.transition(current, byte));
            match step {
                Some(next) => {
                    current = next;
                    consumed += 1;
                }
                None => return best,
            }
        }
    }
}
/// Automata that can produce byte sequences for a given prefix.
pub trait PrefixIterable: FiniteStateAutomaton {
    /// Returns a boxed iterator of byte sequences for `prefix`.
    ///
    /// Presumably these are the stored keys that begin with `prefix`; the
    /// exact contents and ordering are up to the implementor.
    fn iter_prefix(&self, prefix: &[u8]) -> Box<dyn Iterator<Item = Vec<u8>> + '_>;

    /// Returns whatever `iter_prefix` yields for the empty prefix.
    fn iter_all(&self) -> Box<dyn Iterator<Item = Vec<u8>> + '_> {
        let no_prefix: &[u8] = &[];
        self.iter_prefix(no_prefix)
    }
}
/// A mutable set of byte-string keys layered on top of an automaton.
pub trait Trie: FiniteStateAutomaton {
    /// Adds `key` to the trie.
    ///
    /// Returns a state identifier on success — presumably the state reached
    /// at the end of `key`; confirm against the concrete implementation.
    fn insert(&mut self, key: &[u8]) -> Result<StateId>;

    /// Walks `key` from the root and returns the state it ends in.
    ///
    /// Yields `None` when a transition is missing or when the state reached
    /// is not final.
    fn lookup(&self, key: &[u8]) -> Option<StateId> {
        key.iter()
            .try_fold(self.root(), |current, &byte| self.transition(current, byte))
            .filter(|&end| self.is_final(end))
    }

    /// Returns `true` when `lookup` finds `key`.
    fn contains(&self, key: &[u8]) -> bool {
        matches!(self.lookup(key), Some(_))
    }

    /// Returns the size of the trie as reported by the implementor
    /// (presumably the number of stored keys).
    fn len(&self) -> usize;

    /// Returns `true` when `len` reports zero.
    fn is_empty(&self) -> bool {
        let count = self.len();
        count == 0
    }
}
/// Summary statistics describing a trie.
///
/// All fields start at zero; the `calculate_*` helpers fill in the derived
/// per-key metrics from the raw counters.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TrieStats {
    /// Number of states.
    pub num_states: usize,
    /// Number of keys; the divisor for both per-key metrics.
    pub num_keys: usize,
    /// Number of transitions.
    pub num_transitions: usize,
    /// Maximum depth.
    pub max_depth: usize,
    /// Average depth per key, as set by `calculate_avg_depth`.
    pub avg_depth: f64,
    /// Memory footprint in bytes (`calculate_bits_per_key` multiplies it by 8).
    pub memory_usage: usize,
    /// Bits of memory per key, as set by `calculate_bits_per_key`.
    pub bits_per_key: f64,
}

impl TrieStats {
    /// Creates a statistics record with every field zeroed.
    pub fn new() -> Self {
        Self::default()
    }

    /// Recomputes `bits_per_key` from `memory_usage` and `num_keys`.
    ///
    /// With no keys the metric is reset to `0.0` so a value computed from an
    /// earlier key count does not linger.
    pub fn calculate_bits_per_key(&mut self) {
        self.bits_per_key = if self.num_keys == 0 {
            0.0
        } else {
            // Multiply in f64: `memory_usage * 8` in usize overflows for very
            // large values, while scaling by 8.0 is exact in floating point.
            self.memory_usage as f64 * 8.0 / self.num_keys as f64
        };
    }

    /// Recomputes `avg_depth` as `total_depth / num_keys`.
    ///
    /// With no keys the metric is reset to `0.0` so a value computed from an
    /// earlier key count does not linger.
    pub fn calculate_avg_depth(&mut self, total_depth: usize) {
        self.avg_depth = if self.num_keys == 0 {
            0.0
        } else {
            total_depth as f64 / self.num_keys as f64
        };
    }
}
pub trait StatisticsProvider {
fn stats(&self) -> TrieStats;
fn memory_usage(&self) -> usize {
self.stats().memory_usage
}
fn bits_per_key(&self) -> f64 {
self.stats().bits_per_key
}
}
/// Errors reported by automaton and trie operations.
#[derive(Debug, Clone)]
pub enum FsaError {
    /// A state identifier was not valid; carries the offending identifier.
    InvalidState(StateId),
    /// An input symbol was not valid; carries the offending byte.
    InvalidSymbol(u8),
    /// Construction failed; carries a description of the failure.
    ConstructionFailed(String),
    /// The requested operation is not supported; carries its description.
    NotSupported(String),
}

impl std::fmt::Display for FsaError {
    /// Writes a one-line message made of a fixed label and the variant's payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotSupported(operation) => {
                write!(f, "Operation not supported: {}", operation)
            }
            Self::ConstructionFailed(reason) => {
                write!(f, "Trie construction failed: {}", reason)
            }
            Self::InvalidSymbol(byte) => write!(f, "Invalid symbol: {}", byte),
            Self::InvalidState(id) => write!(f, "Invalid state ID: {}", id),
        }
    }
}

// Display plus Debug is all `Error` needs; there is no underlying source.
impl std::error::Error for FsaError {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Test double: key membership lives in a hash set, while the automaton
    /// side accepts everything (every state is final, every transition
    /// succeeds).
    struct MockTrie {
        keys: HashSet<Vec<u8>>,
    }

    impl MockTrie {
        fn new() -> Self {
            MockTrie {
                keys: HashSet::new(),
            }
        }
    }

    impl FiniteStateAutomaton for MockTrie {
        fn root(&self) -> StateId {
            0
        }

        // Every state counts as final.
        fn is_final(&self, _state: StateId) -> bool {
            true
        }

        // Every symbol leads to state 1.
        fn transition(&self, _state: StateId, _symbol: u8) -> Option<StateId> {
            Some(1)
        }

        fn transitions(&self, _state: StateId) -> Vec<(u8, StateId)> {
            vec![]
        }
    }

    impl Trie for MockTrie {
        fn insert(&mut self, key: &[u8]) -> Result<StateId> {
            self.keys.insert(key.to_vec());
            Ok(1)
        }

        // Overrides the default walk so membership reflects the key set.
        fn lookup(&self, key: &[u8]) -> Option<StateId> {
            self.keys.get(key).map(|_| 1)
        }

        fn len(&self) -> usize {
            self.keys.len()
        }
    }

    #[test]
    fn test_trie_basic_operations() {
        const STORED: [&[u8]; 2] = [b"hello", b"world"];

        let mut mock = MockTrie::new();
        assert!(mock.is_empty());

        for key in STORED.iter() {
            mock.insert(key).unwrap();
        }
        assert_eq!(mock.len(), 2);
        assert!(!mock.is_empty());

        for key in STORED.iter() {
            assert!(mock.contains(key));
        }
        assert!(!mock.contains(b"foo"));
    }

    #[test]
    fn test_fsa_accepts() {
        let mock = MockTrie::new();
        assert!(mock.accepts(b"anything"));
    }

    #[test]
    fn test_trie_stats() {
        let mut summary = TrieStats {
            num_keys: 100,
            memory_usage: 1024,
            ..TrieStats::new()
        };

        summary.calculate_bits_per_key();
        assert!((summary.bits_per_key - 81.92).abs() < 0.01);

        summary.calculate_avg_depth(500);
        assert!((summary.avg_depth - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_fsa_error_display() {
        let cases = vec![
            (FsaError::InvalidState(42), "Invalid state ID: 42"),
            (FsaError::InvalidSymbol(65), "Invalid symbol: 65"),
            (
                FsaError::ConstructionFailed("test".to_string()),
                "Trie construction failed: test",
            ),
            (
                FsaError::NotSupported("test op".to_string()),
                "Operation not supported: test op",
            ),
        ];

        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }
}