use smallvec::SmallVec;
use super::lazy::{LazyState, LazyWfstWrapper, StateSource};
use super::{MutableWfst, StateId, VectorWfst, WeightedTransition, Wfst, NO_STATE};
use crate::semiring::Semiring;
/// State source that exposes the wrapped transducer `fst` with the input and
/// output labels of every transition exchanged.
///
/// The label swap itself is performed state by state in the `StateSource`
/// implementation for this type; the struct only stores the transducer.
#[derive(Clone)]
pub struct InvertSource<L, W, T>
where
W: Semiring,
T: Wfst<L, W>,
{
// The transducer whose transitions are read and relabelled.
fst: T,
// Zero-sized marker that ties the otherwise-unused `L` and `W` parameters to the struct.
_phantom: std::marker::PhantomData<(L, W)>,
}
impl<L, W, T> InvertSource<L, W, T>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    /// Creates a source that serves the states of `fst` with swapped labels.
    ///
    /// Ownership of `fst` moves into the source; nothing is computed until a
    /// state is requested.
    pub fn new(fst: T) -> Self {
        // The marker's type parameters are inferred from the field it fills.
        let _phantom = std::marker::PhantomData;
        Self { fst, _phantom }
    }
}
impl<L, W, T> StateSource<L, W> for InvertSource<L, W, T>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    /// Produces the inverted view of `state`.
    ///
    /// Each transition of the wrapped transducer is copied with its input and
    /// output labels exchanged; endpoints and weights are carried over as-is.
    /// The state is final exactly when the wrapped state is final, and then
    /// keeps the wrapped state's final weight.
    fn compute_state(&self, state: StateId) -> LazyState<L, W> {
        let accepting = self.fst.is_final(state);
        let accept_weight = self.fst.final_weight(state);

        let swapped = self
            .fst
            .transitions(state)
            .iter()
            .map(|arc| WeightedTransition {
                from: arc.from,
                // The label swap: what was emitted is now consumed, and vice versa.
                input: arc.output.clone(),
                output: arc.input.clone(),
                to: arc.to,
                weight: arc.weight,
            })
            .collect::<SmallVec<[WeightedTransition<L, W>; 4]>>();

        if !accepting {
            return LazyState::non_final(swapped);
        }
        LazyState::final_state(accept_weight, swapped)
    }

    /// Returns the start state of the wrapped transducer.
    fn start(&self) -> StateId {
        self.fst.start()
    }

    /// Reports the wrapped transducer's state count as the size hint.
    fn num_states_hint(&self) -> Option<usize> {
        let count = self.fst.num_states();
        Some(count)
    }
}
/// A `LazyWfstWrapper` whose states are supplied by an [`InvertSource`] over `T`.
pub type InvertWfst<L, W, T> = LazyWfstWrapper<InvertSource<L, W, T>, L, W>;
/// Returns a lazy transducer that mirrors `fst` with input and output labels
/// exchanged on every transition.
///
/// `fst` is cloned into the returned wrapper, so the argument itself is left
/// untouched.
pub fn invert<L, W, T>(fst: &T) -> InvertWfst<L, W, T>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    LazyWfstWrapper::new(InvertSource::new(fst.clone()))
}
/// State source that exposes the wrapped transducer `fst` with one label of
/// each transition copied onto both sides.
///
/// The const parameter `INPUT` picks the side that is kept: in the
/// `StateSource` implementation for this type, `true` duplicates the input
/// label and `false` duplicates the output label.
#[derive(Clone)]
pub struct ProjectSource<L, W, T, const INPUT: bool>
where
W: Semiring,
T: Wfst<L, W>,
{
// The transducer whose transitions are read and relabelled.
fst: T,
// Zero-sized marker that ties the otherwise-unused `L` and `W` parameters to the struct.
_phantom: std::marker::PhantomData<(L, W)>,
}
impl<L, W, T, const INPUT: bool> ProjectSource<L, W, T, INPUT>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    /// Creates a source that serves the states of `fst` projected onto the
    /// side selected by `INPUT`.
    ///
    /// Ownership of `fst` moves into the source; nothing is computed until a
    /// state is requested.
    pub fn new(fst: T) -> Self {
        // The marker's type parameters are inferred from the field it fills.
        let _phantom = std::marker::PhantomData;
        Self { fst, _phantom }
    }
}
impl<L, W, T, const INPUT: bool> StateSource<L, W> for ProjectSource<L, W, T, INPUT>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    /// Produces the projected view of `state`.
    ///
    /// Each transition of the wrapped transducer is copied with one of its
    /// labels placed on both the input and the output side: the input label
    /// when `INPUT` is `true`, the output label otherwise. Endpoints, weights,
    /// finality, and final weight are carried over unchanged.
    fn compute_state(&self, state: StateId) -> LazyState<L, W> {
        let accepting = self.fst.is_final(state);
        let accept_weight = self.fst.final_weight(state);

        let projected = self
            .fst
            .transitions(state)
            .iter()
            .map(|arc| {
                // `INPUT` is a const generic, so the kept side is fixed per instantiation.
                let kept = if INPUT { &arc.input } else { &arc.output };
                WeightedTransition {
                    from: arc.from,
                    input: kept.clone(),
                    output: kept.clone(),
                    to: arc.to,
                    weight: arc.weight,
                }
            })
            .collect::<SmallVec<[WeightedTransition<L, W>; 4]>>();

        if !accepting {
            return LazyState::non_final(projected);
        }
        LazyState::final_state(accept_weight, projected)
    }

    /// Returns the start state of the wrapped transducer.
    fn start(&self) -> StateId {
        self.fst.start()
    }

    /// Reports the wrapped transducer's state count as the size hint.
    fn num_states_hint(&self) -> Option<usize> {
        let count = self.fst.num_states();
        Some(count)
    }
}
/// A `LazyWfstWrapper` over a [`ProjectSource`] with `INPUT = true`, i.e. the input label is kept on both sides.
pub type ProjectInputWfst<L, W, T> = LazyWfstWrapper<ProjectSource<L, W, T, true>, L, W>;
/// A `LazyWfstWrapper` over a [`ProjectSource`] with `INPUT = false`, i.e. the output label is kept on both sides.
pub type ProjectOutputWfst<L, W, T> = LazyWfstWrapper<ProjectSource<L, W, T, false>, L, W>;
/// Returns a lazy transducer that mirrors `fst` with every transition's input
/// label copied onto its output side.
///
/// `fst` is cloned into the returned wrapper, so the argument itself is left
/// untouched.
pub fn project_input<L, W, T>(fst: &T) -> ProjectInputWfst<L, W, T>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    LazyWfstWrapper::new(ProjectSource::<L, W, T, true>::new(fst.clone()))
}
/// Returns a lazy transducer that mirrors `fst` with every transition's output
/// label copied onto its input side.
///
/// `fst` is cloned into the returned wrapper, so the argument itself is left
/// untouched.
pub fn project_output<L, W, T>(fst: &T) -> ProjectOutputWfst<L, W, T>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    LazyWfstWrapper::new(ProjectSource::<L, W, T, false>::new(fst.clone()))
}
/// Builds an eager transducer in which every transition of `fst` points the
/// opposite way.
///
/// Layout of the result when `fst` has `n > 0` states:
/// - state `0` is a newly added start state;
/// - original state `s` is represented by state `s + 1`;
/// - each original final state `s` is reached from state `0` through an
///   epsilon transition carrying the final weight of `s`;
/// - each original transition `p -> q` becomes `q + 1 -> p + 1` with the same
///   input label, output label, and weight;
/// - if `fst.start()` is not `NO_STATE`, its image `start + 1` is marked final
///   with weight `W::one()`; otherwise no state is marked final.
///
/// When `fst` has no states, `VectorWfst::new()` is returned.
pub fn reverse<L, W, T>(fst: &T) -> VectorWfst<L, W>
where
    W: Semiring,
    L: Clone + Send + Sync,
    T: Wfst<L, W>,
{
    let state_count = fst.num_states();
    if state_count == 0 {
        return VectorWfst::new();
    }

    // One extra slot for the new start state placed in front of the originals.
    let mut reversed: VectorWfst<L, W> = VectorWfst::with_capacity(state_count + 1);
    reversed.add_state();
    reversed.set_start(0);
    for _ in 1..=state_count {
        reversed.add_state();
    }

    for source in 0..state_count as StateId {
        // Image of `source` in the result; also the target of its flipped arcs.
        let shifted_source = source + 1;

        if fst.is_final(source) {
            reversed.add_epsilon(0, shifted_source, fst.final_weight(source));
        }

        for arc in fst.transitions(source) {
            reversed.add_transition(WeightedTransition {
                from: arc.to + 1,
                input: arc.input.clone(),
                output: arc.output.clone(),
                to: shifted_source,
                weight: arc.weight,
            });
        }
    }

    let old_start = fst.start();
    if old_start == NO_STATE {
        return reversed;
    }
    reversed.set_final(old_start + 1, W::one());
    reversed
}
// Unit and property tests for `invert`, `project_input`, `project_output`, and `reverse`.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
use crate::wfst::{LazyWfst, VectorWfstBuilder};
// Fixture: three states with start 0, an arc 0 -> 1 labelled a:x with weight 1.0,
// an arc 1 -> 2 labelled b:y with weight 2.0, and state 2 final.
fn make_transducer() -> VectorWfst<char, TropicalWeight> {
VectorWfstBuilder::new()
.add_states(3)
.start(0)
.arc(0, Some('a'), Some('x'), 1, TropicalWeight::new(1.0))
.arc(1, Some('b'), Some('y'), 2, TropicalWeight::new(2.0))
.final_state(2, TropicalWeight::one())
.build()
}
// Inverting swaps the a:x arc out of state 0 into x:a.
#[test]
fn test_invert_basic() {
let fst = make_transducer();
let mut inv = invert(&fst);
let trans = inv.transitions_lazy(0);
assert_eq!(trans.len(), 1);
assert_eq!(trans[0].input, Some('x')); assert_eq!(trans[0].output, Some('a')); }
// Start state, state count, and finality of state 2 match the original.
#[test]
fn test_invert_preserves_structure() {
let fst = make_transducer();
let mut inv = invert(&fst);
assert_eq!(inv.start(), fst.start());
assert_eq!(inv.num_states(), fst.num_states());
// State 2 is expanded explicitly before its finality is queried.
inv.expand(2);
assert!(inv.is_final(2));
}
// Inverting twice restores the labels of the first arc out of state 0.
#[test]
fn test_double_invert() {
let fst = make_transducer();
let mut inv1 = invert(&fst);
// Expand every state of the first inversion before wrapping it again.
for state in 0..fst.num_states() as StateId {
inv1.expand(state);
}
let mut inv2 = invert(&inv1);
let orig_trans = fst.transitions(0);
let double_inv_trans = inv2.transitions_lazy(0);
assert_eq!(orig_trans[0].input, double_inv_trans[0].input);
assert_eq!(orig_trans[0].output, double_inv_trans[0].output);
}
// Input projection turns a:x into a:a.
#[test]
fn test_project_input() {
let fst = make_transducer();
let mut pin = project_input(&fst);
let trans = pin.transitions_lazy(0);
assert_eq!(trans[0].input, Some('a'));
assert_eq!(trans[0].output, Some('a')); }
// Output projection turns a:x into x:x.
#[test]
fn test_project_output() {
let fst = make_transducer();
let mut pout = project_output(&fst);
let trans = pout.transitions_lazy(0);
assert_eq!(trans[0].input, Some('x')); assert_eq!(trans[0].output, Some('x'));
}
// Projection keeps the start state and the finality of state 2.
#[test]
fn test_project_preserves_structure() {
let fst = make_transducer();
let mut pin = project_input(&fst);
assert_eq!(pin.start(), fst.start());
pin.expand(2);
assert!(pin.is_final(2));
}
// The reversed start state 0 has a single epsilon arc to state 3,
// the shifted image of the original final state 2.
#[test]
fn test_reverse_basic() {
let fst = make_transducer();
let rev = reverse(&fst);
assert_eq!(rev.start(), 0);
let s0_trans = rev.transitions(0);
assert_eq!(s0_trans.len(), 1);
assert!(s0_trans[0].is_epsilon());
assert_eq!(s0_trans[0].to, 3); }
// The image of the original start (state 1) is final; the image of the
// original final state (state 3) is not.
#[test]
fn test_reverse_final_state() {
let fst = make_transducer();
let rev = reverse(&fst);
assert!(rev.is_final(1));
assert!(!rev.is_final(3));
}
// Arcs are flipped: 1 -b-> 2 becomes 3 -b-> 2 and 0 -a-> 1 becomes 2 -a-> 1
// in the shifted numbering.
#[test]
fn test_reverse_transition_direction() {
let fst = make_transducer();
let rev = reverse(&fst);
let s3_trans = rev.transitions(3);
assert_eq!(s3_trans.len(), 1);
assert_eq!(s3_trans[0].input, Some('b'));
assert_eq!(s3_trans[0].to, 2);
let s2_trans = rev.transitions(2);
assert_eq!(s2_trans.len(), 1);
assert_eq!(s2_trans[0].input, Some('a'));
assert_eq!(s2_trans[0].to, 1); }
// Reversing twice yields at least as many arcs (epsilons included) as the original.
#[test]
fn test_double_reverse() {
let fst = make_transducer();
let rev1 = reverse(&fst);
let rev2 = reverse(&rev1);
let original_arcs: usize = (0..fst.num_states())
.map(|s| fst.transitions(s as crate::wfst::StateId).len())
.sum();
// No-op: the `_` pattern does not bind, so `rev1` is neither moved nor dropped here.
let _ = rev1;
let final_arcs: usize = (0..rev2.num_states())
.map(|s| rev2.transitions(s as crate::wfst::StateId).len())
.sum();
assert!(
final_arcs >= original_arcs,
"double-reverse should preserve at least the original arc count ({} vs {})",
final_arcs,
original_arcs
);
}
// Reversing a transducer with no states yields an empty transducer.
#[test]
fn test_reverse_empty_fst() {
let fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
let rev = reverse(&fst);
assert!(rev.is_empty());
}
// A single state that is both start and final maps to final state 1 behind the new start 0.
#[test]
fn test_reverse_single_state_fst() {
let fst: VectorWfst<char, TropicalWeight> = VectorWfstBuilder::new()
.add_states(1)
.start(0)
.final_state(0, TropicalWeight::one())
.build();
let rev = reverse(&fst);
assert_eq!(rev.start(), 0);
assert!(rev.is_final(1));
}
// An epsilon arc is still reported as epsilon after inversion.
#[test]
fn test_invert_epsilon_transitions() {
let fst: VectorWfst<char, TropicalWeight> = VectorWfstBuilder::new()
.add_states(2)
.start(0)
.epsilon(0, 1, TropicalWeight::one())
.final_state(1, TropicalWeight::one())
.build();
let mut inv = invert(&fst);
let trans = inv.transitions_lazy(0);
assert!(trans[0].is_epsilon());
}
// Inverting twice reproduces arc counts, labels, and destinations for every state.
#[test]
fn test_invert_involution() {
let fst = make_transducer();
let mut inv1 = invert(&fst);
for s in 0..fst.num_states() as StateId {
inv1.expand(s);
}
let mut inv2 = invert(&inv1);
for s in 0..fst.num_states() as StateId {
let orig = fst.transitions(s);
let double = inv2.transitions_lazy(s);
assert_eq!(
orig.len(),
double.len(),
"State {} transition count mismatch",
s
);
for (o, d) in orig.iter().zip(double.iter()) {
assert_eq!(o.input, d.input, "State {} input mismatch", s);
assert_eq!(o.output, d.output, "State {} output mismatch", s);
assert_eq!(o.to, d.to, "State {} destination mismatch", s);
}
}
}
// Applying input projection twice gives the same labels as applying it once.
#[test]
fn test_project_input_idempotence() {
let fst = make_transducer();
let mut p1 = project_input(&fst);
for s in 0..fst.num_states() as StateId {
p1.expand(s);
}
let mut p2 = project_input(&p1);
for s in 0..fst.num_states() as StateId {
let once = p1.transitions_lazy(s);
let twice = p2.transitions_lazy(s);
assert_eq!(once.len(), twice.len());
for (o, t) in once.iter().zip(twice.iter()) {
assert_eq!(o.input, t.input);
assert_eq!(o.output, t.output);
}
}
}
// Applying output projection twice gives the same labels as applying it once.
#[test]
fn test_project_output_idempotence() {
let fst = make_transducer();
let mut p1 = project_output(&fst);
for s in 0..fst.num_states() as StateId {
p1.expand(s);
}
let mut p2 = project_output(&p1);
for s in 0..fst.num_states() as StateId {
let once = p1.transitions_lazy(s);
let twice = p2.transitions_lazy(s);
assert_eq!(once.len(), twice.len());
for (o, t) in once.iter().zip(twice.iter()) {
assert_eq!(o.input, t.input);
assert_eq!(o.output, t.output);
}
}
}
// Double reversal preserves the count and the sorted label pairs of non-epsilon arcs.
#[test]
fn test_reverse_involution_structure() {
let fst = make_transducer();
let rev1 = reverse(&fst);
let rev2 = reverse(&rev1);
// Counts the non-epsilon arcs over all states.
let count_arcs = |f: &VectorWfst<char, TropicalWeight>| {
(0..f.num_states() as StateId)
.flat_map(|s| f.transitions(s).to_vec())
.filter(|t| !t.is_epsilon())
.count()
};
assert_eq!(count_arcs(&fst), count_arcs(&rev2));
// Collects the (input, output) pairs of non-epsilon arcs, sorted so state numbering does not matter.
let collect_labels = |f: &VectorWfst<char, TropicalWeight>| {
let mut labels: Vec<_> = (0..f.num_states() as StateId)
.flat_map(|s| f.transitions(s).to_vec())
.filter(|t| !t.is_epsilon())
.map(|t| (t.input, t.output))
.collect();
labels.sort();
labels
};
assert_eq!(collect_labels(&fst), collect_labels(&rev2));
}
// project_input(invert(fst)) and invert(project_output(fst)) agree on arc counts and input labels.
#[test]
fn test_invert_project_commutes() {
let fst = make_transducer();
let mut inv = invert(&fst);
for s in 0..fst.num_states() as StateId {
inv.expand(s);
}
let mut pinv = project_input(&inv);
let mut pout = project_output(&fst);
for s in 0..fst.num_states() as StateId {
pout.expand(s);
}
let mut invp = invert(&pout);
for s in 0..fst.num_states() as StateId {
let t1 = pinv.transitions_lazy(s);
let t2 = invp.transitions_lazy(s);
assert_eq!(t1.len(), t2.len());
for (a, b) in t1.iter().zip(t2.iter()) {
assert_eq!(a.input, b.input);
}
}
}
// Inversion leaves arc weights and the final weight of state 2 unchanged.
#[test]
fn test_invert_preserves_path_weight() {
let fst = make_transducer();
let mut inv = invert(&fst);
let t0 = inv.transitions_lazy(0);
assert_eq!(t0[0].weight, TropicalWeight::new(1.0));
let t1 = inv.transitions_lazy(1);
assert_eq!(t1[0].weight, TropicalWeight::new(2.0));
inv.expand(2);
assert_eq!(inv.final_weight(2), TropicalWeight::one());
}
// Both projections leave arc weights unchanged.
#[test]
fn test_project_preserves_path_weight() {
let fst = make_transducer();
let mut pin = project_input(&fst);
let mut pout = project_output(&fst);
assert_eq!(pin.transitions_lazy(0)[0].weight, TropicalWeight::new(1.0));
assert_eq!(pout.transitions_lazy(0)[0].weight, TropicalWeight::new(1.0));
assert_eq!(pin.transitions_lazy(1)[0].weight, TropicalWeight::new(2.0));
assert_eq!(pout.transitions_lazy(1)[0].weight, TropicalWeight::new(2.0));
}
// The numeric sum of non-epsilon arc weights is the same before and after reversal.
#[test]
fn test_reverse_preserves_total_weight() {
let fst = make_transducer();
let rev = reverse(&fst);
let sum_weights = |f: &VectorWfst<char, TropicalWeight>| {
(0..f.num_states() as StateId)
.flat_map(|s| f.transitions(s).to_vec())
.filter(|t| !t.is_epsilon())
.map(|t| t.weight.0.into_inner())
.sum::<f64>()
};
assert!((sum_weights(&fst) - sum_weights(&rev)).abs() < 1e-10);
}
// Property tests over transducers generated by `arb_tropical_wfst`.
mod property_tests {
use super::*;
use crate::test_utils::arb_tropical_wfst;
use proptest::prelude::*;
proptest! {
// Inversion reports the same number of states as its input.
#[test]
fn invert_preserves_states(
fst in arb_tropical_wfst(6, 2)
) {
let inv = invert(&fst);
prop_assert_eq!(inv.num_states(), fst.num_states());
}
// Inverting twice reproduces arc counts and labels for every state.
#[test]
fn invert_is_involution(
fst in arb_tropical_wfst(5, 2)
) {
// Nothing to compare for a transducer without states.
if fst.num_states() == 0 {
return Ok(());
}
let mut inv1 = invert(&fst);
for s in 0..fst.num_states() as StateId {
inv1.expand(s);
}
let mut inv2 = invert(&inv1);
for s in 0..fst.num_states() as StateId {
let orig = fst.transitions(s);
let double = inv2.transitions_lazy(s);
prop_assert_eq!(orig.len(), double.len(), "State {} arc count", s);
for (o, d) in orig.iter().zip(double.iter()) {
prop_assert_eq!(o.input, d.input, "State {} input label", s);
prop_assert_eq!(o.output, d.output, "State {} output label", s);
}
}
}
// Input projection reports the same number of states as its input.
#[test]
fn project_input_preserves_states(
fst in arb_tropical_wfst(6, 2)
) {
let pin = project_input(&fst);
prop_assert_eq!(pin.num_states(), fst.num_states());
}
// Output projection reports the same number of states as its input.
#[test]
fn project_output_preserves_states(
fst in arb_tropical_wfst(6, 2)
) {
let pout = project_output(&fst);
prop_assert_eq!(pout.num_states(), fst.num_states());
}
// Projecting onto the input twice gives the same labels as projecting once.
#[test]
fn project_input_idempotent(
fst in arb_tropical_wfst(5, 2)
) {
if fst.num_states() == 0 {
return Ok(());
}
let mut p1 = project_input(&fst);
for s in 0..fst.num_states() as StateId {
p1.expand(s);
}
let mut p2 = project_input(&p1);
for s in 0..fst.num_states() as StateId {
let t1 = p1.transitions_lazy(s);
let t2 = p2.transitions_lazy(s);
prop_assert_eq!(t1.len(), t2.len());
for (a, b) in t1.iter().zip(t2.iter()) {
prop_assert_eq!(a.input, b.input);
prop_assert_eq!(a.output, b.output);
}
}
}
// Projecting onto the output twice gives the same labels as projecting once.
#[test]
fn project_output_idempotent(
fst in arb_tropical_wfst(5, 2)
) {
if fst.num_states() == 0 {
return Ok(());
}
let mut p1 = project_output(&fst);
for s in 0..fst.num_states() as StateId {
p1.expand(s);
}
let mut p2 = project_output(&p1);
for s in 0..fst.num_states() as StateId {
let t1 = p1.transitions_lazy(s);
let t2 = p2.transitions_lazy(s);
prop_assert_eq!(t1.len(), t2.len());
for (a, b) in t1.iter().zip(t2.iter()) {
prop_assert_eq!(a.input, b.input);
prop_assert_eq!(a.output, b.output);
}
}
}
// Reversal adds exactly one state, except that an empty input stays empty.
#[test]
fn reverse_state_count(
fst in arb_tropical_wfst(6, 2)
) {
if fst.num_states() == 0 {
let rev = reverse(&fst);
prop_assert!(rev.is_empty());
return Ok(());
}
let rev = reverse(&fst);
prop_assert_eq!(rev.num_states(), fst.num_states() + 1);
}
// Reversal keeps the number of non-epsilon arcs.
#[test]
fn reverse_preserves_arc_count(
fst in arb_tropical_wfst(6, 2)
) {
let rev = reverse(&fst);
let count_non_eps = |f: &VectorWfst<char, TropicalWeight>| {
(0..f.num_states() as StateId)
.flat_map(|s| f.transitions(s).to_vec())
.filter(|t| !t.is_epsilon())
.count()
};
prop_assert_eq!(count_non_eps(&fst), count_non_eps(&rev));
}
// Double reversal keeps the number of non-epsilon arcs.
#[test]
fn reverse_double_arc_count(
fst in arb_tropical_wfst(5, 2)
) {
let rev1 = reverse(&fst);
let rev2 = reverse(&rev1);
let count_non_eps = |f: &VectorWfst<char, TropicalWeight>| {
(0..f.num_states() as StateId)
.flat_map(|s| f.transitions(s).to_vec())
.filter(|t| !t.is_epsilon())
.count()
};
prop_assert_eq!(count_non_eps(&fst), count_non_eps(&rev2));
}
// Inversion keeps each arc's weight, compared pairwise within a 1e-10 tolerance.
#[test]
fn invert_preserves_weights(
fst in arb_tropical_wfst(5, 2)
) {
let mut inv = invert(&fst);
for s in 0..fst.num_states() as StateId {
let orig = fst.transitions(s);
let inverted = inv.transitions_lazy(s);
for (o, i) in orig.iter().zip(inverted.iter()) {
prop_assert!(
o.weight.approx_eq(&i.weight, 1e-10),
"Weight mismatch: {:?} vs {:?}", o.weight, i.weight
);
}
}
}
}
}
}