use smallvec::SmallVec;
use super::{StateId, WeightedTransition};
use crate::semiring::Semiring;
/// A single state of a weighted finite-state transducer (WFST).
///
/// `L` is the label type carried by the transitions and `W` is the weight
/// type, which must implement [`Semiring`]. A state owns its outgoing
/// transitions together with its finality information.
#[derive(Clone, Debug)]
pub struct WfstState<L, W: Semiring> {
/// Identifier of this state. `add_transition` checks (in debug builds) that
/// a transition's `from` field equals this value.
pub id: StateId,
/// `true` when the state is marked final; set by `final_state` / `set_final`
/// and cleared by `new` / `clear_final`.
pub is_final: bool,
/// Weight associated with finality. `new` and `clear_final` store
/// `W::zero()` here; `final_state` and `set_final` store the caller's weight.
pub final_weight: W,
/// Outgoing transitions of this state, kept in a `SmallVec` with an inline
/// capacity of four elements.
pub transitions: SmallVec<[WeightedTransition<L, W>; 4]>,
}
/// Construction, mutation and basic inspection of a [`WfstState`].
impl<L, W: Semiring> WfstState<L, W> {
    /// Builds a non-final state with the given `id`.
    ///
    /// The final weight is initialised to `W::zero()` and the transition list
    /// starts out empty.
    #[inline]
    pub fn new(id: StateId) -> Self {
        let final_weight = W::zero();
        let transitions = SmallVec::new();
        Self {
            id,
            is_final: false,
            final_weight,
            transitions,
        }
    }

    /// Builds a final state with the given `id` and final `weight`.
    ///
    /// The transition list starts out empty.
    #[inline]
    pub fn final_state(id: StateId, weight: W) -> Self {
        let transitions = SmallVec::new();
        Self {
            id,
            is_final: true,
            final_weight: weight,
            transitions,
        }
    }

    /// Appends an already-built `transition` to this state.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics when
    /// `transition.from` differs from this state's `id`.
    #[inline]
    pub fn add_transition(&mut self, transition: WeightedTransition<L, W>) {
        debug_assert_eq!(transition.from, self.id, "Transition source must match state ID");
        self.transitions.push(transition);
    }

    /// Creates a transition starting at this state and appends it.
    ///
    /// The transition is constructed with `WeightedTransition::new`, passing
    /// this state's `id` followed by `input`, `output`, `to` and `weight`.
    #[inline]
    pub fn add_arc(&mut self, input: Option<L>, output: Option<L>, to: StateId, weight: W) {
        let arc = WeightedTransition::new(self.id, input, output, to, weight);
        self.transitions.push(arc);
    }

    /// Marks this state as final and records `weight` as its final weight.
    #[inline]
    pub fn set_final(&mut self, weight: W) {
        self.is_final = true;
        self.final_weight = weight;
    }

    /// Marks this state as non-final and resets its final weight to
    /// `W::zero()`.
    #[inline]
    pub fn clear_final(&mut self) {
        self.is_final = false;
        self.final_weight = W::zero();
    }

    /// Returns how many outgoing transitions this state holds.
    #[inline]
    pub fn num_transitions(&self) -> usize {
        self.transitions.len()
    }

    /// Returns `true` when this state holds at least one outgoing transition.
    #[inline]
    pub fn has_transitions(&self) -> bool {
        !self.transitions.is_empty()
    }

    /// Returns an iterator over all outgoing transitions, in storage order.
    #[inline]
    pub fn iter_transitions(&self) -> impl Iterator<Item = &WeightedTransition<L, W>> {
        self.transitions.iter()
    }

    /// Reserves room for at least `additional` more transitions.
    #[inline]
    pub fn reserve_transitions(&mut self, additional: usize) {
        self.transitions.reserve(additional);
    }
}
/// Filtered views over the outgoing transitions of a [`WfstState`].
impl<L: Clone, W: Semiring> WfstState<L, W> {
    /// Returns an iterator over the transitions whose `input` field equals
    /// `input`.
    ///
    /// The returned iterator borrows both `self` and `input` for `'a`.
    pub fn transitions_by_input<'a>(
        &'a self,
        input: &'a Option<L>,
    ) -> impl Iterator<Item = &'a WeightedTransition<L, W>>
    where
        L: PartialEq,
    {
        self.transitions
            .iter()
            .filter(move |candidate| candidate.input == *input)
    }

    /// Returns an iterator over the transitions for which
    /// `is_epsilon_input()` reports `true`.
    pub fn epsilon_transitions(&self) -> impl Iterator<Item = &WeightedTransition<L, W>> {
        self.transitions
            .iter()
            .filter(|arc| arc.is_epsilon_input())
    }

    /// Returns an iterator over the transitions for which
    /// `is_epsilon_input()` reports `false`.
    pub fn labeled_transitions(&self) -> impl Iterator<Item = &WeightedTransition<L, W>> {
        self.transitions
            .iter()
            .filter(|arc| !arc.is_epsilon_input())
    }
}
/// The default state is the non-final, transition-free state with id `0`.
impl<L, W: Semiring> Default for WfstState<L, W> {
    /// Equivalent to `WfstState::new(0)`.
    fn default() -> Self {
        let initial_id = 0;
        Self::new(initial_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;

    /// State type exercised throughout these tests.
    type CharState = WfstState<char, TropicalWeight>;

    /// A freshly created state is non-final, has a zero final weight and no
    /// transitions.
    #[test]
    fn test_state_creation() {
        let fresh = CharState::new(0);

        assert_eq!(fresh.id, 0);
        assert!(!fresh.is_final);
        assert!(fresh.final_weight.is_zero());
        assert!(fresh.transitions.is_empty());
    }

    /// `final_state` records the id, the final flag and the supplied weight.
    #[test]
    fn test_final_state() {
        let weight = TropicalWeight::new(0.5);
        let accepting = CharState::final_state(1, weight);

        assert_eq!(accepting.id, 1);
        assert!(accepting.is_final);
        assert_eq!(accepting.final_weight.value(), 0.5);
    }

    /// Arcs added through `add_arc` are counted and reported as present.
    #[test]
    fn test_add_transitions() {
        let mut source = CharState::new(0);

        source.add_arc(Some('a'), Some('b'), 1, TropicalWeight::new(1.0));
        source.add_arc(Some('c'), Some('d'), 2, TropicalWeight::new(2.0));

        assert_eq!(source.num_transitions(), 2);
        assert!(source.has_transitions());
    }

    /// One arc without labels and two labeled arcs are split correctly by the
    /// epsilon / labeled filters.
    #[test]
    fn test_transition_filtering() {
        let mut source = CharState::new(0);

        source.add_arc(Some('a'), Some('a'), 1, TropicalWeight::one());
        source.add_arc(None, None, 2, TropicalWeight::one());
        source.add_arc(Some('b'), Some('b'), 3, TropicalWeight::one());

        assert_eq!(source.epsilon_transitions().count(), 1);
        assert_eq!(source.labeled_transitions().count(), 2);
    }
}