use smallvec::SmallVec;
use super::lazy::{LazyState, LazyWfstWrapper, StateSource};
use super::{StateId, WeightedTransition, Wfst};
use crate::semiring::Semiring;
/// Lazy state source for the union of two WFSTs.
///
/// State ids are laid out as:
/// * `0` — a synthetic start state with epsilon arcs into both operands;
/// * `1..=n1` — the states of `fst1`, shifted up by one;
/// * `n1 + 1..` — the states of `fst2`, shifted up by `n1 + 1`.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct itself, so naming the type (or cloning it) imposes no extra
/// requirements beyond what each operation actually uses.
#[derive(Clone)]
pub struct UnionSource<L, W, T1, T2> {
    /// First operand.
    fst1: T1,
    /// Second operand.
    fst2: T2,
    /// Number of states in `fst1`, captured at construction time; used to
    /// offset `fst2`'s state ids.
    n1: usize,
    _phantom: std::marker::PhantomData<(L, W)>,
}
impl<L, W, T1, T2> UnionSource<L, W, T1, T2>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T1: Wfst<L, W>,
    T2: Wfst<L, W>,
{
    /// Wraps `fst1` and `fst2`, recording `fst1`'s state count so that the
    /// second operand's ids can be placed after it.
    pub fn new(fst1: T1, fst2: T2) -> Self {
        Self {
            n1: fst1.num_states(),
            fst1,
            fst2,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Splits a union state id into `(operand tag, local id)`.
    ///
    /// The tag is `0` for the synthetic start state (local id `0`), `1` for a
    /// state of `fst1`, and `2` for a state of `fst2`.
    #[inline]
    fn decode_state(&self, state: StateId) -> (u8, StateId) {
        match state {
            0 => (0, 0),
            s if (s as usize) <= self.n1 => (1, s - 1),
            s => (2, s - 1 - self.n1 as StateId),
        }
    }

    /// Maps an `fst1` id into the union's id space; the shift by one leaves
    /// room for the synthetic start state at `0`.
    #[inline]
    fn encode_fst1(&self, state: StateId) -> StateId {
        1 + state
    }

    /// Maps an `fst2` id into the union's id space, placing it after every
    /// `fst1` state.
    #[inline]
    fn encode_fst2(&self, state: StateId) -> StateId {
        self.encode_fst1(state) + self.n1 as StateId
    }
}
impl<L, W, T1, T2> StateSource<L, W> for UnionSource<L, W, T1, T2>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T1: Wfst<L, W>,
    T2: Wfst<L, W>,
{
    /// Materialises union state `state`.
    ///
    /// * The synthetic start state `0` is non-final. It gets a `W::one()`
    ///   epsilon arc to the re-encoded start of each operand whose start is not
    ///   `NO_STATE`, with `fst1`'s arc first.
    /// * Any other state mirrors the operand state it decodes to: same
    ///   finality and final weight, same labels and arc weights, and every arc
    ///   target re-encoded into the union's id space.
    fn compute_state(&self, state: StateId) -> LazyState<L, W> {
        /// Copies state `original` of `fst` as union state `state`, sending
        /// every arc target through `encode`.
        fn mirror<L, W, T>(
            fst: &T,
            original: StateId,
            state: StateId,
            encode: impl Fn(StateId) -> StateId,
        ) -> LazyState<L, W>
        where
            W: Semiring,
            L: Clone + Send + Sync,
            T: Wfst<L, W>,
        {
            let accepting = fst.is_final(original);
            let weight_if_final = fst.final_weight(original);
            let arcs: SmallVec<[WeightedTransition<L, W>; 4]> = fst
                .transitions(original)
                .iter()
                .map(|arc| WeightedTransition {
                    from: state,
                    input: arc.input.clone(),
                    output: arc.output.clone(),
                    to: encode(arc.to),
                    weight: arc.weight,
                })
                .collect();
            if accepting {
                LazyState::final_state(weight_if_final, arcs)
            } else {
                LazyState::non_final(arcs)
            }
        }

        let (operand, original) = self.decode_state(state);
        match operand {
            0 => {
                let start1 = self.fst1.start();
                let start2 = self.fst2.start();
                // Entry targets in operand order. A start-less operand yields
                // `None`, so it is never re-encoded and contributes no arc.
                let entries = [
                    (start1 != super::NO_STATE).then(|| self.encode_fst1(start1)),
                    (start2 != super::NO_STATE).then(|| self.encode_fst2(start2)),
                ];
                let arcs: SmallVec<[WeightedTransition<L, W>; 4]> = entries
                    .iter()
                    .flatten()
                    .copied()
                    .map(|target| WeightedTransition::epsilon(state, target, W::one()))
                    .collect();
                LazyState::non_final(arcs)
            }
            1 => mirror(&self.fst1, original, state, |s| self.encode_fst1(s)),
            2 => mirror(&self.fst2, original, state, |s| self.encode_fst2(s)),
            _ => unreachable!(),
        }
    }

    fn start(&self) -> StateId {
        // The synthetic start state always has id 0.
        0
    }

    fn num_states_hint(&self) -> Option<usize> {
        // Both operands plus the synthetic start state.
        let n2 = self.fst2.num_states();
        Some(self.n1 + n2 + 1)
    }
}
/// The union of two WFSTs, backed by a [`UnionSource`] inside a
/// [`LazyWfstWrapper`].
pub type UnionWfst<L, W, T1, T2> = LazyWfstWrapper<UnionSource<L, W, T1, T2>, L, W>;

/// Builds the union of `fst1` and `fst2`.
///
/// Both operands are cloned into a [`UnionSource`], so the result does not
/// borrow from the arguments. The result's start state is the synthetic state
/// `0`, which has epsilon arcs into each operand.
pub fn union<L, W, T1, T2>(fst1: &T1, fst2: &T2) -> UnionWfst<L, W, T1, T2>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T1: Wfst<L, W>,
    T2: Wfst<L, W>,
{
    LazyWfstWrapper::new(UnionSource::new(fst1.clone(), fst2.clone()))
}
/// Lazy state source for the concatenation `fst1 · fst2`.
///
/// `fst1`'s states keep their ids `0..n1`, and `fst2`'s states follow at
/// `n1..`. Trait bounds live on the `impl` blocks that need them rather than
/// on the struct itself.
#[derive(Clone)]
pub struct ConcatSource<L, W, T1, T2> {
    /// Left operand; its ids are used unchanged.
    fst1: T1,
    /// Right operand; its ids are shifted by `n1`.
    fst2: T2,
    /// Number of states in `fst1`, captured at construction time.
    n1: usize,
    _phantom: std::marker::PhantomData<(L, W)>,
}
impl<L, W, T1, T2> ConcatSource<L, W, T1, T2>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T1: Wfst<L, W>,
    T2: Wfst<L, W>,
{
    /// Wraps `fst1` followed by `fst2`, recording `fst1`'s state count so that
    /// `fst2`'s ids can be placed after it.
    pub fn new(fst1: T1, fst2: T2) -> Self {
        Self {
            n1: fst1.num_states(),
            fst1,
            fst2,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Whether `state` lies in the id range occupied by `fst1` (`0..n1`).
    #[inline]
    fn is_fst1_state(&self, state: StateId) -> bool {
        self.n1 > state as usize
    }

    /// Splits an id into `(belongs_to_fst1, local id)`. `fst2` ids are shifted
    /// back down by `n1`; `fst1` ids are returned as-is.
    #[inline]
    fn decode_state(&self, state: StateId) -> (bool, StateId) {
        let in_fst1 = self.is_fst1_state(state);
        let local = if in_fst1 { state } else { state - self.n1 as StateId };
        (in_fst1, local)
    }

    /// Maps an `fst2` id into the concatenation's id space.
    #[inline]
    fn encode_fst2(&self, state: StateId) -> StateId {
        self.n1 as StateId + state
    }
}
impl<L, W, T1, T2> StateSource<L, W> for ConcatSource<L, W, T1, T2>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T1: Wfst<L, W>,
    T2: Wfst<L, W>,
{
    /// Materialises concatenation state `state`.
    ///
    /// States of `fst2` are copied with ids shifted by `n1` and keep their
    /// finality. States of `fst1` keep their ids and arcs but are always
    /// non-final. When a state is final in `fst1` and `fst2` has a start state,
    /// it instead gets an epsilon arc to `fst2`'s shifted start, weighted by
    /// its `fst1` final weight.
    fn compute_state(&self, state: StateId) -> LazyState<L, W> {
        let (in_fst1, original) = self.decode_state(state);

        if !in_fst1 {
            let accepting = self.fst2.is_final(original);
            let weight_if_final = self.fst2.final_weight(original);
            let arcs: SmallVec<[WeightedTransition<L, W>; 4]> = self
                .fst2
                .transitions(original)
                .iter()
                .map(|arc| WeightedTransition {
                    from: state,
                    input: arc.input.clone(),
                    output: arc.output.clone(),
                    to: self.encode_fst2(arc.to),
                    weight: arc.weight,
                })
                .collect();
            return if accepting {
                LazyState::final_state(weight_if_final, arcs)
            } else {
                LazyState::non_final(arcs)
            };
        }

        let accepting = self.fst1.is_final(original);
        let weight_if_final = self.fst1.final_weight(original);
        let mut arcs: SmallVec<[WeightedTransition<L, W>; 4]> = self
            .fst1
            .transitions(original)
            .iter()
            .map(|arc| WeightedTransition {
                from: state,
                input: arc.input.clone(),
                output: arc.output.clone(),
                // fst1 ids are not shifted.
                to: arc.to,
                weight: arc.weight,
            })
            .collect();
        if accepting {
            // Bridge into fst2, carrying fst1's final weight on the epsilon.
            let start2 = self.fst2.start();
            if start2 != super::NO_STATE {
                arcs.push(WeightedTransition::epsilon(
                    state,
                    self.encode_fst2(start2),
                    weight_if_final,
                ));
            }
        }
        // fst1's finality is replaced by the bridge arc above.
        LazyState::non_final(arcs)
    }

    fn start(&self) -> StateId {
        // fst1's ids are unshifted, so its start is the concatenation's start.
        self.fst1.start()
    }

    fn num_states_hint(&self) -> Option<usize> {
        let n2 = self.fst2.num_states();
        Some(n2 + self.n1)
    }
}
/// The concatenation of two WFSTs, backed by a [`ConcatSource`] inside a
/// [`LazyWfstWrapper`].
pub type ConcatWfst<L, W, T1, T2> = LazyWfstWrapper<ConcatSource<L, W, T1, T2>, L, W>;

/// Builds the concatenation `fst1 · fst2`.
///
/// Both operands are cloned into a [`ConcatSource`], so the result does not
/// borrow from the arguments. The result starts at `fst1`'s start state.
pub fn concat<L, W, T1, T2>(fst1: &T1, fst2: &T2) -> ConcatWfst<L, W, T1, T2>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T1: Wfst<L, W>,
    T2: Wfst<L, W>,
{
    LazyWfstWrapper::new(ConcatSource::new(fst1.clone(), fst2.clone()))
}
/// Lazy state source for the Kleene closure `fst*`.
///
/// Id `0` is a synthetic, final start state. The wrapped WFST's states follow
/// at `1..=n`, shifted up by one. Trait bounds live on the `impl` blocks that
/// need them rather than on the struct itself.
#[derive(Clone)]
pub struct ClosureSource<L, W, T> {
    /// The WFST being closed over.
    fst: T,
    /// Number of states in `fst`, captured at construction time.
    n: usize,
    _phantom: std::marker::PhantomData<(L, W)>,
}
impl<L, W, T> ClosureSource<L, W, T>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    /// Wraps `fst`, recording its state count.
    pub fn new(fst: T) -> Self {
        Self {
            n: fst.num_states(),
            fst,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Whether `state` is the synthetic start state, which always has id `0`.
    #[inline]
    fn is_super_start(&self, state: StateId) -> bool {
        matches!(state, 0)
    }

    /// Maps an id of the wrapped WFST into the closure's id space.
    #[inline]
    fn encode_state(&self, state: StateId) -> StateId {
        1 + state
    }

    /// Maps a closure id other than `0` back to the wrapped WFST's id.
    #[inline]
    fn decode_state(&self, state: StateId) -> StateId {
        state - 1
    }
}
impl<L, W, T> StateSource<L, W> for ClosureSource<L, W, T>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    /// Materialises closure state `state`.
    ///
    /// The synthetic start `0` is final with weight `W::one()`. When the
    /// wrapped WFST has a start state, `0` also gets a `W::one()` epsilon arc
    /// to it. Every other state mirrors the wrapped state, with ids shifted by
    /// one. A final state stays final and gains an epsilon arc back to the
    /// wrapped start, weighted by its final weight.
    fn compute_state(&self, state: StateId) -> LazyState<L, W> {
        if !self.is_super_start(state) {
            let inner = self.decode_state(state);
            let accepting = self.fst.is_final(inner);
            let weight_if_final = self.fst.final_weight(inner);
            let mut arcs: SmallVec<[WeightedTransition<L, W>; 4]> = self
                .fst
                .transitions(inner)
                .iter()
                .map(|arc| WeightedTransition {
                    from: state,
                    input: arc.input.clone(),
                    output: arc.output.clone(),
                    to: self.encode_state(arc.to),
                    weight: arc.weight,
                })
                .collect();
            if !accepting {
                return LazyState::non_final(arcs);
            }
            // Loop back to the wrapped start so the language can repeat.
            let restart = self.fst.start();
            if restart != super::NO_STATE {
                arcs.push(WeightedTransition::epsilon(
                    state,
                    self.encode_state(restart),
                    weight_if_final,
                ));
            }
            return LazyState::final_state(weight_if_final, arcs);
        }

        // Super-start: accepts the empty string and can enter the wrapped WFST.
        let mut arcs = SmallVec::new();
        let inner_start = self.fst.start();
        if inner_start != super::NO_STATE {
            arcs.push(WeightedTransition::epsilon(
                state,
                self.encode_state(inner_start),
                W::one(),
            ));
        }
        LazyState::final_state(W::one(), arcs)
    }

    fn start(&self) -> StateId {
        // Always the synthetic super-start.
        0
    }

    fn num_states_hint(&self) -> Option<usize> {
        // Wrapped states plus the super-start.
        Some(self.n + 1)
    }
}
/// The Kleene closure of a WFST, backed by a [`ClosureSource`] inside a
/// [`LazyWfstWrapper`].
pub type ClosureWfst<L, W, T> = LazyWfstWrapper<ClosureSource<L, W, T>, L, W>;

/// Builds the Kleene closure `fst*`.
///
/// The input is cloned into a [`ClosureSource`]. The result's start state is
/// final with weight `W::one()`, so the empty string is always accepted.
pub fn closure<L, W, T>(fst: &T) -> ClosureWfst<L, W, T>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    LazyWfstWrapper::new(ClosureSource::new(fst.clone()))
}
/// Builds the positive closure `fst+`, i.e. `fst` followed by `fst*`.
///
/// Implemented as [`concat`] of `fst` with [`closure`] of `fst`, so `fst` is
/// cloned into both operands.
pub fn closure_plus<L, W, T>(fst: &T) -> ConcatWfst<L, W, T, ClosureWfst<L, W, T>>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    let star = closure(fst);
    concat(fst, &star)
}
// Tests for the lazy union / concatenation / closure sources.
// Many tests call `expand` on a state before querying `is_final` on it.
// NOTE(review): this suggests `is_final` on the lazy wrapper only reflects
// already-expanded states; confirm against `LazyWfstWrapper`.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
use crate::wfst::{LazyWfst, VectorWfst, VectorWfstBuilder};
// Builds a two-state acceptor: start state 0, a single arc 0 -> 1 labelled
// `label:label` with weight one, and state 1 final with weight one.
fn make_single_arc_fst(label: char) -> VectorWfst<char, TropicalWeight> {
VectorWfstBuilder::new()
.add_states(2)
.start(0)
.arc(0, Some(label), Some(label), 1, TropicalWeight::one())
.final_state(1, TropicalWeight::one())
.build()
}
// Union of two 2-state FSTs: the super-start 0 has one epsilon arc per
// operand, and the union has 1 + 2 + 2 = 5 states.
#[test]
fn test_union_basic() {
let fst_a = make_single_arc_fst('a');
let fst_b = make_single_arc_fst('b');
let mut u = union(&fst_a, &fst_b);
assert_eq!(u.start(), 0);
let start_trans = u.transitions_lazy(0);
assert_eq!(start_trans.len(), 2);
assert!(start_trans[0].is_epsilon());
assert!(start_trans[1].is_epsilon());
assert_eq!(u.num_states(), 5); }
// Union ids: 0 = super-start, 1..=2 = fst_a, 3..=4 = fst_b. The final states
// are the shifted copies of each operand's state 1; the super-start is not final.
#[test]
fn test_union_final_states() {
let fst_a = make_single_arc_fst('a');
let fst_b = make_single_arc_fst('b');
let mut u = union(&fst_a, &fst_b);
for i in 0..5 {
u.expand(i);
}
assert!(u.is_final(2));
assert!(u.is_final(4));
assert!(!u.is_final(0));
}
// Concatenation keeps fst_a's ids. State 0 carries the 'a' arc, and state 1
// (final in fst_a) gets the epsilon bridge into fst_b.
#[test]
fn test_concat_basic() {
let fst_a = make_single_arc_fst('a');
let fst_b = make_single_arc_fst('b');
let mut c = concat(&fst_a, &fst_b);
assert_eq!(c.start(), 0);
let s0_trans = c.transitions_lazy(0);
assert_eq!(s0_trans.len(), 1);
assert_eq!(s0_trans[0].input, Some('a'));
let s1_trans = c.transitions_lazy(1);
assert!(s1_trans.iter().any(|t| t.is_epsilon()));
}
// Only fst_b's final state (id 1 + n1 = 3) is final; fst_a's final state 1
// is demoted to non-final.
#[test]
fn test_concat_final_states() {
let fst_a = make_single_arc_fst('a');
let fst_b = make_single_arc_fst('b');
let mut c = concat(&fst_a, &fst_b);
for i in 0..4 {
c.expand(i);
}
assert!(!c.is_final(0)); assert!(!c.is_final(1)); assert!(!c.is_final(2)); assert!(c.is_final(3)); }
// The closure's super-start 0 is final and has a single epsilon arc to the
// wrapped start.
#[test]
fn test_closure_basic() {
let fst_a = make_single_arc_fst('a');
let mut k = closure(&fst_a);
assert_eq!(k.start(), 0);
k.expand(0);
assert!(k.is_final(0));
let s0_trans = k.transitions_lazy(0);
assert_eq!(s0_trans.len(), 1);
assert!(s0_trans[0].is_epsilon());
}
// Closure state 2 is fst_a's final state 1 shifted by one. It must have an
// epsilon arc back to the shifted wrapped start (id 1).
#[test]
fn test_closure_loop_back() {
let fst_a = make_single_arc_fst('a');
let mut k = closure(&fst_a);
let s2_trans = k.transitions_lazy(2);
assert!(s2_trans.iter().any(|t| t.is_epsilon() && t.to == 1));
}
// closure_plus = concat(fst, closure(fst)). Its start is fst's start (0),
// which concatenation makes non-final.
#[test]
fn test_closure_plus() {
let fst_a = make_single_arc_fst('a');
let mut kp = closure_plus(&fst_a);
assert_eq!(kp.start(), 0);
kp.expand(0);
assert!(!kp.is_final(0));
}
// An operand whose start is NO_STATE contributes no epsilon arc from the
// super-start. NOTE(review): assumes `VectorWfst::new()` has no start state.
#[test]
fn test_empty_fst_union() {
let empty: VectorWfst<char, TropicalWeight> = VectorWfst::new();
let fst_a = make_single_arc_fst('a');
let mut u = union(&empty, &fst_a);
let start_trans = u.transitions_lazy(0);
assert_eq!(start_trans.len(), 1);
}
// Swapping the operands leaves the state count, the super-start fan-out and
// the number of final states unchanged.
#[test]
fn test_union_commutativity_structure() {
let fst_a = make_single_arc_fst('a');
let fst_b = make_single_arc_fst('b');
let mut u1 = union(&fst_a, &fst_b);
let mut u2 = union(&fst_b, &fst_a);
assert_eq!(u1.num_states(), u2.num_states());
let u1_trans = u1.transitions_lazy(0);
let u2_trans = u2.transitions_lazy(0);
assert_eq!(u1_trans.len(), u2_trans.len());
let u1_finals: Vec<_> = (0..u1.num_states() as StateId)
.filter(|&s| {
u1.expand(s);
u1.is_final(s)
})
.collect();
let u2_finals: Vec<_> = (0..u2.num_states() as StateId)
.filter(|&s| {
u2.expand(s);
u2.is_final(s)
})
.collect();
assert_eq!(u1_finals.len(), u2_finals.len());
}
// (a ∪ b) ∪ c and a ∪ (b ∪ c) both have exactly three final states.
// The inner unions are fully expanded before being nested.
// NOTE(review): presumably required because the outer source reads them
// through the non-lazy `Wfst` interface; confirm.
#[test]
fn test_union_associativity_states() {
let fst_a = make_single_arc_fst('a');
let fst_b = make_single_arc_fst('b');
let fst_c = make_single_arc_fst('c');
// Expands every state of `fst` and counts the final ones.
fn count_finals<S>(fst: &mut LazyWfstWrapper<S, char, TropicalWeight>) -> usize
where
S: StateSource<char, TropicalWeight> + Send + Sync,
{
let n = fst.num_states();
(0..n as StateId)
.filter(|&s| {
fst.expand(s);
fst.is_final(s)
})
.count()
}
let mut u12 = union(&fst_a, &fst_b);
for s in 0..u12.num_states() as StateId {
u12.expand(s);
}
let mut u12_3 = union(&u12, &fst_c);
let mut u23 = union(&fst_b, &fst_c);
for s in 0..u23.num_states() as StateId {
u23.expand(s);
}
let mut u1_23 = union(&fst_a, &u23);
assert_eq!(count_finals(&mut u12_3), 3);
assert_eq!(count_finals(&mut u1_23), 3);
}
// Both bracketings of a · b · c contain exactly three non-epsilon arcs.
#[test]
fn test_concat_associativity_path_length() {
let fst_a = make_single_arc_fst('a');
let fst_b = make_single_arc_fst('b');
let fst_c = make_single_arc_fst('c');
// Counts the non-epsilon arcs over every state of `fst`.
fn count_arcs<S>(fst: &mut LazyWfstWrapper<S, char, TropicalWeight>) -> usize
where
S: StateSource<char, TropicalWeight> + Send + Sync,
{
let n = fst.num_states();
(0..n as StateId)
.flat_map(|s| fst.transitions_lazy(s).to_vec())
.filter(|t| !t.is_epsilon())
.count()
}
let mut c12 = concat(&fst_a, &fst_b);
for s in 0..c12.num_states() as StateId {
c12.expand(s);
}
let mut c12_3 = concat(&c12, &fst_c);
let mut c23 = concat(&fst_b, &fst_c);
for s in 0..c23.num_states() as StateId {
c23.expand(s);
}
let mut c1_23 = concat(&fst_a, &c23);
assert_eq!(count_arcs(&mut c12_3), 3);
assert_eq!(count_arcs(&mut c1_23), 3);
}
// The super-starts of both k = closure(a) and closure(k) are final.
// This holds for any closure, so it is only a weak idempotence check.
#[test]
fn test_closure_idempotence() {
let fst_a = make_single_arc_fst('a');
let mut k = closure(&fst_a);
for s in 0..k.num_states() as StateId {
k.expand(s);
}
let mut kk = closure(&k);
k.expand(0);
kk.expand(0);
assert!(k.is_final(0));
assert!(kk.is_final(0));
}
// Union with an empty FST: one epsilon arc leaves the super-start, and fst_a's
// final state 1 appears as union state 2.
#[test]
fn test_union_identity() {
let fst_a = make_single_arc_fst('a');
let empty: VectorWfst<char, TropicalWeight> = VectorWfst::new();
let mut u = union(&fst_a, &empty);
let start_trans = u.transitions_lazy(0);
assert_eq!(start_trans.len(), 1);
u.expand(2); assert!(u.is_final(2));
}
// c = a · a*. c's start (fst_a's state 0) is non-final, but at least one
// state of c is final.
#[test]
fn test_concat_with_closure_distributivity() {
let fst_a = make_single_arc_fst('a');
let mut k = closure(&fst_a);
for s in 0..k.num_states() as StateId {
k.expand(s);
}
let mut c = concat(&fst_a, &k);
c.expand(0);
assert!(!c.is_final(0));
let n = c.num_states();
let has_final = (0..n as StateId).any(|s| {
c.expand(s);
c.is_final(s)
});
assert!(has_final);
}
// Randomised checks over tropical WFSTs generated by
// `crate::test_utils::arb_tropical_wfst(5, 2)`.
mod property_tests {
use super::*;
use crate::test_utils::arb_tropical_wfst;
use crate::wfst::NO_STATE;
use proptest::prelude::*;
proptest! {
// A union has 1 (super-start) + |fst1| + |fst2| states.
#[test]
fn union_state_count(
fst1 in arb_tropical_wfst(5, 2),
fst2 in arb_tropical_wfst(5, 2)
) {
let u = union(&fst1, &fst2);
let expected = 1 + fst1.num_states() + fst2.num_states();
prop_assert_eq!(u.num_states(), expected);
}
// Union with an empty FST has exactly one epsilon arc out of the super-start,
// whenever the non-empty operand has a start state.
#[test]
fn union_identity_with_empty(
fst in arb_tropical_wfst(5, 2)
) {
let empty: VectorWfst<char, TropicalWeight> = VectorWfst::new();
let mut u = union(&fst, &empty);
if fst.start() != NO_STATE {
let trans = u.transitions_lazy(0);
prop_assert_eq!(trans.len(), 1, "Union with empty should have 1 epsilon");
}
}
// Concatenation adds no states: |fst1| + |fst2|.
#[test]
fn concat_state_count(
fst1 in arb_tropical_wfst(5, 2),
fst2 in arb_tropical_wfst(5, 2)
) {
let c = concat(&fst1, &fst2);
let expected = fst1.num_states() + fst2.num_states();
prop_assert_eq!(c.num_states(), expected);
}
// Closure adds exactly one state, the super-start.
#[test]
fn closure_state_count(
fst in arb_tropical_wfst(5, 2)
) {
let k = closure(&fst);
let expected = 1 + fst.num_states();
prop_assert_eq!(k.num_states(), expected);
}
// The closure's super-start is always final, so the empty string is accepted.
#[test]
fn closure_accepts_empty(
fst in arb_tropical_wfst(5, 2)
) {
let mut k = closure(&fst);
k.expand(0);
prop_assert!(k.is_final(0), "Closure super-start should be final");
}
// The start of closure_plus is fst's own start, which concatenation makes
// non-final. Inputs without states or without a start are skipped.
#[test]
fn closure_plus_start_not_final(
fst in arb_tropical_wfst(5, 2)
) {
if fst.num_states() == 0 || fst.start() == NO_STATE {
return Ok(());
}
let mut kp = closure_plus(&fst);
kp.expand(0);
prop_assert!(
!kp.is_final(0),
"Closure+ start should never be final (concat makes fst1 non-final)"
);
}
// Union neither adds nor removes final states: the count equals the sum over
// both operands (the super-start is non-final).
#[test]
fn union_preserves_finals(
fst1 in arb_tropical_wfst(5, 2),
fst2 in arb_tropical_wfst(5, 2)
) {
let mut u = union(&fst1, &fst2);
let finals1: usize = (0..fst1.num_states() as StateId)
.filter(|&s| fst1.is_final(s))
.count();
let finals2: usize = (0..fst2.num_states() as StateId)
.filter(|&s| fst2.is_final(s))
.count();
let union_finals: usize = (0..u.num_states() as StateId)
.filter(|&s| {
u.expand(s);
u.is_final(s)
})
.count();
prop_assert_eq!(union_finals, finals1 + finals2);
}
}
}
}