use super::traits::{MutableWfst, Wfst};
use super::{StateId, WeightedTransition, WfstState, NO_STATE};
use crate::semiring::Semiring;
/// A weighted finite-state transducer that keeps all of its states in a
/// single `Vec`.
///
/// State ids double as indices into that vector: every lookup in this file
/// converts the id with `state as usize` before indexing.
#[derive(Clone, Debug)]
pub struct VectorWfst<L, W: Semiring> {
/// Per-state records, addressed by state id.
states: Vec<WfstState<L, W>>,
/// Id of the start state; the constructors initialise it to `NO_STATE`.
start: StateId,
}
impl<L, W: Semiring> VectorWfst<L, W> {
    /// Creates an empty transducer: no states, and the start state set to
    /// `NO_STATE`.
    #[inline]
    pub fn new() -> Self {
        Self {
            states: Vec::new(),
            start: NO_STATE,
        }
    }

    /// Creates an empty transducer whose state vector is pre-allocated for
    /// `num_states` entries. The start state is `NO_STATE`.
    #[inline]
    pub fn with_capacity(num_states: usize) -> Self {
        Self {
            states: Vec::with_capacity(num_states),
            start: NO_STATE,
        }
    }

    /// Returns a mutable reference to the record stored at index `state`, or
    /// `None` when `state` is out of range.
    #[inline]
    pub fn state_mut(&mut self, state: StateId) -> Option<&mut WfstState<L, W>> {
        self.states.get_mut(state as usize)
    }

    /// Sorts the outgoing transitions of every state with `compare`.
    ///
    /// The comparator is only borrowed for each per-state sort, so it does
    /// not have to be `Copy`; closures that own their captured data (a lookup
    /// table, a `String`, ...) are accepted as well as plain functions.
    pub fn sort_transitions<F>(&mut self, compare: F)
    where
        F: Fn(&WeightedTransition<L, W>, &WeightedTransition<L, W>) -> std::cmp::Ordering,
    {
        for state in &mut self.states {
            // `&F` is callable whenever `F: Fn`, so the same comparator can be
            // handed to every state's sort without moving or copying it.
            state.transitions.sort_by(&compare);
        }
    }

    /// Iterates, in storage order, over the `id` of every state whose
    /// `is_final` flag is set.
    pub fn final_states(&self) -> impl Iterator<Item = StateId> + '_ {
        self.states.iter().filter(|s| s.is_final).map(|s| s.id)
    }

    /// Asks the state vector and every state's transition list to release
    /// capacity they are not using.
    pub fn shrink_to_fit(&mut self) {
        self.states.shrink_to_fit();
        for state in &mut self.states {
            state.transitions.shrink_to_fit();
        }
    }
}
impl<L, W: Semiring> Default for VectorWfst<L, W> {
fn default() -> Self {
Self::new()
}
}
/// Read-only accessors over the vector-backed storage.
///
/// Every per-state accessor tolerates an out-of-range id and answers with a
/// neutral value instead of panicking.
impl<L: Clone + Send + Sync, W: Semiring> Wfst<L, W> for VectorWfst<L, W> {
    /// Returns the stored start state id.
    #[inline]
    fn start(&self) -> StateId {
        self.start
    }

    /// Reports the `is_final` flag of `state`; unknown states are not final.
    #[inline]
    fn is_final(&self, state: StateId) -> bool {
        match self.states.get(state as usize) {
            Some(entry) => entry.is_final,
            None => false,
        }
    }

    /// Returns the stored final weight of `state`, or `W::zero()` when the
    /// state does not exist.
    #[inline]
    fn final_weight(&self, state: StateId) -> W {
        match self.states.get(state as usize) {
            Some(entry) => entry.final_weight,
            None => W::zero(),
        }
    }

    /// Returns the outgoing transitions of `state` as a slice; the slice is
    /// empty when the state does not exist.
    #[inline]
    fn transitions(&self, state: StateId) -> &[WeightedTransition<L, W>] {
        match self.states.get(state as usize) {
            Some(entry) => entry.transitions.as_slice(),
            None => &[],
        }
    }

    /// Returns how many states are currently stored.
    #[inline]
    fn num_states(&self) -> usize {
        self.states.len()
    }
}
/// Mutating operations on the vector-backed storage.
///
/// Apart from `set_start`, operations that name a state silently do nothing
/// when that state does not exist.
impl<L: Clone + Send + Sync, W: Semiring> MutableWfst<L, W> for VectorWfst<L, W> {
    /// Appends a fresh state and returns its id, which is the number of
    /// states that existed before the call (cast to `StateId`).
    fn add_state(&mut self) -> StateId {
        let next_id = self.states.len() as StateId;
        self.states.push(WfstState::new(next_id));
        next_id
    }

    /// Records `state` as the start state.
    ///
    /// The range check is a `debug_assert!`, so it only fires in builds with
    /// debug assertions enabled; otherwise the id is stored unchecked.
    #[inline]
    fn set_start(&mut self, state: StateId) {
        debug_assert!(
            self.states.len() > state as usize,
            "Invalid start state: {}",
            state
        );
        self.start = state;
    }

    /// Sets the final weight of `state`.
    ///
    /// A weight for which `is_zero()` holds clears the final flag and stores
    /// `W::zero()`; any other weight marks the state final and is stored
    /// as given.
    fn set_final(&mut self, state: StateId, weight: W) {
        let entry = match self.states.get_mut(state as usize) {
            Some(entry) => entry,
            None => return,
        };
        let marks_final = !weight.is_zero();
        entry.is_final = marks_final;
        entry.final_weight = if marks_final { weight } else { W::zero() };
    }

    /// Appends `transition` to the transition list of its `from` state.
    fn add_transition(&mut self, transition: WeightedTransition<L, W>) {
        let source = match self.states.get_mut(transition.from as usize) {
            Some(source) => source,
            None => return,
        };
        source.transitions.push(transition);
    }

    /// Reserves room for `additional` more states in the state vector.
    fn reserve_states(&mut self, additional: usize) {
        self.states.reserve(additional);
    }

    /// Reserves room for `additional` more transitions leaving `state`.
    fn reserve_transitions(&mut self, state: StateId, additional: usize) {
        let entry = match self.states.get_mut(state as usize) {
            Some(entry) => entry,
            None => return,
        };
        entry.transitions.reserve(additional);
    }

    /// Removes every transition leaving `state`.
    fn clear_transitions(&mut self, state: StateId) {
        let entry = match self.states.get_mut(state as usize) {
            Some(entry) => entry,
            None => return,
        };
        entry.transitions.clear();
    }
}
/// Chaining builder that assembles a [`VectorWfst`].
///
/// Each configuration method takes the builder by value and returns it, and
/// `build` hands back the finished transducer.
#[derive(Clone, Debug)]
pub struct VectorWfstBuilder<L, W: Semiring> {
/// The transducer under construction.
fst: VectorWfst<L, W>,
}
impl<L: Clone + Send + Sync, W: Semiring> VectorWfstBuilder<L, W> {
    /// Starts a builder around an empty transducer.
    pub fn new() -> Self {
        let fst = VectorWfst::new();
        Self { fst }
    }

    /// Starts a builder around an empty transducer pre-allocated for
    /// `num_states` states.
    pub fn with_capacity(num_states: usize) -> Self {
        let fst = VectorWfst::with_capacity(num_states);
        Self { fst }
    }

    /// Forwards `count` to `add_states` on the transducer being built.
    pub fn add_states(self, count: usize) -> Self {
        let Self { mut fst } = self;
        fst.add_states(count);
        Self { fst }
    }

    /// Sets the start state via `set_start`.
    pub fn start(self, state: StateId) -> Self {
        let Self { mut fst } = self;
        fst.set_start(state);
        Self { fst }
    }

    /// Sets the final weight of `state` via `set_final`.
    pub fn final_state(self, state: StateId, weight: W) -> Self {
        let Self { mut fst } = self;
        fst.set_final(state, weight);
        Self { fst }
    }

    /// Adds a transition from `from` to `to` by forwarding all arguments to
    /// `add_arc` on the transducer being built.
    pub fn arc(
        self,
        from: StateId,
        input: Option<L>,
        output: Option<L>,
        to: StateId,
        weight: W,
    ) -> Self {
        let Self { mut fst } = self;
        fst.add_arc(from, input, output, to, weight);
        Self { fst }
    }

    /// Adds a transition from `from` to `to` by forwarding to `add_epsilon`.
    pub fn epsilon(self, from: StateId, to: StateId, weight: W) -> Self {
        let Self { mut fst } = self;
        fst.add_epsilon(from, to, weight);
        Self { fst }
    }

    /// Consumes the builder and returns the assembled transducer.
    pub fn build(self) -> VectorWfst<L, W> {
        let Self { fst } = self;
        fst
    }
}
impl<L: Clone + Send + Sync, W: Semiring> Default for VectorWfstBuilder<L, W> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;

    /// Concrete transducer type exercised by every test in this module.
    type CharFst = VectorWfst<char, TropicalWeight>;
    /// Builder for [`CharFst`].
    type CharFstBuilder = VectorWfstBuilder<char, TropicalWeight>;

    /// A freshly created transducer has no states and no start state.
    #[test]
    fn test_empty_wfst() {
        let empty = CharFst::new();
        assert!(empty.is_empty());
        assert_eq!(empty.num_states(), 0);
        assert_eq!(empty.start(), NO_STATE);
    }

    /// State ids are handed out consecutively starting at zero.
    #[test]
    fn test_add_states() {
        let mut fst = CharFst::new();
        let first = fst.add_state();
        let second = fst.add_state();
        assert_eq!(first, 0);
        assert_eq!(second, 1);
        assert_eq!(fst.num_states(), 2);
    }

    /// Start and final markers are stored on the states they were given for.
    #[test]
    fn test_start_and_final() {
        let mut fst = CharFst::new();
        let origin = fst.add_state();
        let accept = fst.add_state();
        fst.set_start(origin);
        fst.set_final(accept, TropicalWeight::new(0.5));

        assert_eq!(fst.start(), origin);
        assert!(!fst.is_final(origin));
        assert!(fst.is_final(accept));
        assert_eq!(fst.final_weight(accept).value(), 0.5);
    }

    /// Arcs and epsilon arcs are filed under their source state.
    #[test]
    fn test_transitions() {
        let mut fst = CharFst::new();
        let a = fst.add_state();
        let b = fst.add_state();
        let c = fst.add_state();

        fst.add_arc(a, Some('a'), Some('x'), b, TropicalWeight::new(1.0));
        fst.add_arc(a, Some('b'), Some('y'), c, TropicalWeight::new(2.0));
        fst.add_epsilon(b, c, TropicalWeight::new(0.5));

        assert_eq!(fst.transitions(a).len(), 2);
        assert_eq!(fst.transitions(b).len(), 1);
        assert_eq!(fst.transitions(c).len(), 0);
    }

    /// The builder produces the same structure as the equivalent direct calls.
    #[test]
    fn test_builder() {
        let builder = CharFstBuilder::new()
            .add_states(3)
            .start(0)
            .arc(0, Some('a'), Some('b'), 1, TropicalWeight::new(1.0))
            .arc(1, Some('c'), Some('d'), 2, TropicalWeight::new(2.0))
            .final_state(2, TropicalWeight::one());
        let built: CharFst = builder.build();

        assert_eq!(built.num_states(), 3);
        assert_eq!(built.start(), 0);
        assert!(built.is_final(2));
        assert_eq!(built.transitions(0).len(), 1);
        assert_eq!(built.transitions(1).len(), 1);
    }

    /// `final_states` yields exactly the states that were marked final.
    #[test]
    fn test_final_states_iterator() {
        let built: CharFst = CharFstBuilder::new()
            .add_states(4)
            .start(0)
            .final_state(1, TropicalWeight::one())
            .final_state(3, TropicalWeight::one())
            .build();

        let accepting = built.final_states().collect::<Vec<_>>();
        assert_eq!(accepting, vec![1, 3]);
    }
}