use std::fmt::Debug;
use super::{StateId, WeightedTransition, WfstState};
use crate::semiring::Semiring;
/// Read-only view of a weighted finite-state transducer.
///
/// `L` is the label type carried by transitions and `W` is the weight type,
/// which must implement [`Semiring`]. Implementors provide the five required
/// accessors; every other method has a default built on top of them.
pub trait Wfst<L, W: Semiring>: Clone + Send + Sync {
    /// Returns the identifier of the start state.
    fn start(&self) -> StateId;

    /// Reports whether `state` is a final state.
    fn is_final(&self, state: StateId) -> bool;

    /// Returns the final weight associated with `state`.
    ///
    /// NOTE(review): what is returned for a non-final state is left to the
    /// implementor; nothing in this trait pins it down.
    fn final_weight(&self, state: StateId) -> W;

    /// Borrows the transitions stored for `state`.
    fn transitions(&self, state: StateId) -> &[WeightedTransition<L, W>];

    /// Returns the number of states.
    fn num_states(&self) -> usize;

    /// Returns `true` when `state`, read as an index, is strictly below
    /// [`num_states`](Self::num_states).
    #[inline]
    fn is_valid_state(&self, state: StateId) -> bool {
        let index = state as usize;
        index < self.num_states()
    }

    /// Returns the length of the slice produced by
    /// [`transitions`](Self::transitions) for `state`.
    #[inline]
    fn num_transitions(&self, state: StateId) -> usize {
        let arcs = self.transitions(state);
        arcs.len()
    }

    /// Sums [`num_transitions`](Self::num_transitions) over the ids
    /// `0..num_states()`.
    fn total_transitions(&self) -> usize {
        // The state count is converted to `StateId` so the range yields ids
        // that can be handed straight to `num_transitions`.
        let end = self.num_states() as StateId;
        (0..end).fold(0, |acc, id| acc + self.num_transitions(id))
    }

    /// Returns `true` when there are no states at all.
    #[inline]
    fn is_empty(&self) -> bool {
        self.num_states() == 0
    }

    /// Builds an owned [`WfstState`] for `state`.
    ///
    /// Returns `None` when [`is_valid_state`](Self::is_valid_state) rejects
    /// the id. Otherwise the result is created with `WfstState::final_state`
    /// (carrying [`final_weight`](Self::final_weight)) if the state is final,
    /// or with `WfstState::new` if it is not, and its `transitions` field is
    /// replaced by clones of the state's transitions.
    fn state(&self, state: StateId) -> Option<WfstState<L, W>>
    where
        L: Clone,
    {
        self.is_valid_state(state).then(|| {
            let mut snapshot: WfstState<L, W> = if self.is_final(state) {
                WfstState::final_state(state, self.final_weight(state))
            } else {
                WfstState::new(state)
            };
            snapshot.transitions = self.transitions(state).iter().cloned().collect();
            snapshot
        })
    }
}
/// Mutation API layered on top of [`Wfst`].
///
/// Implementors provide the primitive operations (adding a state, setting the
/// start/final information, adding or clearing transitions, reserving
/// capacity); the remaining methods are conveniences built from them.
pub trait MutableWfst<L, W: Semiring>: Wfst<L, W> {
    /// Adds one state and returns its identifier.
    fn add_state(&mut self) -> StateId;

    /// Adds `count` states by calling [`add_state`](Self::add_state) `count`
    /// times and returns the identifier of the first one added.
    ///
    /// When `count` is zero nothing is added and the returned value is
    /// `num_states() as StateId`, i.e. the id that follows the existing dense
    /// range `0..num_states()` accepted by `Wfst::is_valid_state`.
    /// NOTE(review): this assumes implementors hand out ids densely; an
    /// implementor with a different id scheme should override this method.
    fn add_states(&mut self, count: usize) -> StateId {
        // A request for zero states must leave the machine unchanged.
        if count == 0 {
            return self.num_states() as StateId;
        }
        let first = self.add_state();
        for _ in 1..count {
            self.add_state();
        }
        first
    }

    /// Sets the start state.
    fn set_start(&mut self, state: StateId);

    /// Sets the final weight of `state`.
    fn set_final(&mut self, state: StateId, weight: W);

    /// Resets the final weight of `state` to `W::zero()` via
    /// [`set_final`](Self::set_final).
    ///
    /// NOTE(review): presumably a zero final weight makes the state non-final;
    /// that depends on the implementor's `is_final` — confirm.
    fn clear_final(&mut self, state: StateId) {
        self.set_final(state, W::zero());
    }

    /// Adds a single transition.
    fn add_transition(&mut self, transition: WeightedTransition<L, W>);

    /// Builds a transition with `WeightedTransition::new` from the given
    /// parts and forwards it to [`add_transition`](Self::add_transition).
    #[inline]
    fn add_arc(
        &mut self,
        from: StateId,
        input: Option<L>,
        output: Option<L>,
        to: StateId,
        weight: W,
    ) {
        self.add_transition(WeightedTransition::new(from, input, output, to, weight));
    }

    /// Builds a transition with `WeightedTransition::epsilon` and forwards it
    /// to [`add_transition`](Self::add_transition).
    #[inline]
    fn add_epsilon(&mut self, from: StateId, to: StateId, weight: W) {
        self.add_transition(WeightedTransition::epsilon(from, to, weight));
    }

    /// Reserves room for `additional` more states.
    fn reserve_states(&mut self, additional: usize);

    /// Reserves room for `additional` more transitions on `state`.
    fn reserve_transitions(&mut self, state: StateId, additional: usize);

    /// Removes the transitions of `state`.
    fn clear_transitions(&mut self, state: StateId);

    /// Clears the transitions of `state`, then passes every element of
    /// `transitions` to [`add_transition`](Self::add_transition) in order.
    ///
    /// NOTE(review): where each transition ends up is decided by
    /// `add_transition`, not by the `state` argument — presumably callers are
    /// expected to pass transitions that originate at `state`; confirm.
    fn set_transitions(&mut self, state: StateId, transitions: Vec<WeightedTransition<L, W>>)
    where
        L: Clone,
    {
        self.clear_transitions(state);
        transitions
            .into_iter()
            .for_each(|transition| self.add_transition(transition));
    }
}
/// Selects a caching strategy; read and written through
/// `LazyWfst::cache_policy` / `LazyWfst::set_cache_policy`.
///
/// The enum only names the strategy. How (and whether) each variant is
/// enforced is decided by the `LazyWfst` implementor, which is not visible
/// from this definition.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum CachePolicy {
    /// The default variant. Presumably "keep everything that was computed" —
    /// confirm against implementors.
    #[default]
    CacheAll,
    /// Presumably least-recently-used eviction bounded by `max_states` —
    /// confirm against implementors.
    Lru { max_states: usize },
    /// Presumably "retain nothing" — confirm against implementors.
    NoCache,
}
/// Extension of [`Wfst`] exposing an expansion/caching API.
///
/// Only signatures are declared here; the precise semantics of expansion and
/// of each [`CachePolicy`] are defined by the implementor.
pub trait LazyWfst<L, W: Semiring>: Wfst<L, W> {
/// Reports whether `state` has been expanded.
fn is_expanded(&self, state: StateId) -> bool;
/// Expands `state`.
///
/// NOTE(review): presumably this materializes the state's transitions so
/// that later reads are cheap — confirm against implementors.
fn expand(&mut self, state: StateId);
/// Returns the transitions of `state` through a mutable borrow of `self`.
///
/// NOTE(review): presumably expands `state` first when it has not been
/// expanded yet, unlike `Wfst::transitions`, which only takes `&self` —
/// confirm against implementors.
fn transitions_lazy(&mut self, state: StateId) -> &[WeightedTransition<L, W>];
/// Returns the cache policy currently in effect.
fn cache_policy(&self) -> CachePolicy;
/// Replaces the cache policy.
///
/// NOTE(review): whether already-cached data is trimmed right away to fit
/// the new policy is not visible here — confirm.
fn set_cache_policy(&mut self, policy: CachePolicy);
/// Returns a count of computed states.
///
/// NOTE(review): unclear from the signature whether this counts states
/// currently held in the cache or every state computed so far — confirm.
fn computed_states(&self) -> usize;
/// Clears the cache.
fn clear_cache(&mut self);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `CachePolicy::default()` must yield the `CacheAll` variant.
    #[test]
    fn test_cache_policy_default() {
        assert_eq!(CachePolicy::default(), CachePolicy::CacheAll);
    }

    /// An `Lru` policy must carry the `max_states` value it was built with.
    #[test]
    fn test_cache_policy_lru() {
        match (CachePolicy::Lru { max_states: 1000 }) {
            CachePolicy::Lru { max_states } => assert_eq!(max_states, 1000),
            _ => panic!("Expected Lru policy"),
        }
    }
}