use std::collections::{HashSet, VecDeque};
use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, Wfst, NO_STATE};
/// Options selecting which states [`connect`] keeps.
///
/// A state is *accessible* when it can be reached from the start state and
/// *coaccessible* when a final state can be reached from it. States that are
/// both are always kept; states that are neither are always removed. The two
/// flags decide what happens to states that satisfy only one of the criteria.
///
/// The default value has both flags `false`, i.e. a full trim.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ConnectConfig {
    /// Keep states that are accessible but cannot reach any final state.
    pub keep_non_coaccessible: bool,
    /// Keep states that can reach a final state but are not accessible.
    pub keep_non_accessible: bool,
}
impl ConnectConfig {
    /// Full trim: keep only states that are both accessible and coaccessible.
    pub fn trim() -> Self {
        Self {
            keep_non_coaccessible: false,
            keep_non_accessible: false,
        }
    }

    /// Keep every accessible state, including those that cannot reach a
    /// final state; non-accessible states are still removed.
    pub fn accessible_only() -> Self {
        Self {
            keep_non_coaccessible: true,
            ..Self::default()
        }
    }

    /// Keep every coaccessible state, including those that are not reachable
    /// from the start state; non-coaccessible states are still removed.
    pub fn coaccessible_only() -> Self {
        Self {
            keep_non_accessible: true,
            ..Self::default()
        }
    }
}
/// Trims `fst` in place according to `config` and returns how many states
/// were removed.
///
/// Each state is classified as accessible (in the set returned by
/// [`compute_accessible`]) and/or coaccessible (in the set returned by
/// [`compute_coaccessible`]):
///
/// * accessible and coaccessible: always kept;
/// * accessible only: kept iff `config.keep_non_coaccessible`;
/// * coaccessible only: kept iff `config.keep_non_accessible`;
/// * neither: always removed.
///
/// This function does not delete or renumber states. A removed state has its
/// outgoing transitions cleared and its final weight set to `W::zero()`, and
/// kept states lose any transition whose destination was removed.
///
/// When the FST has no start state, no state is accessible, so states survive
/// only if they are coaccessible and `config.keep_non_accessible` is set.
///
/// Returns `0` without touching `fst` when it has no states or when every
/// state is kept.
///
/// NOTE(review): this function never changes the start-state designation, so
/// the start may end up pointing at a state that was cleared — confirm callers
/// are fine with that.
pub fn connect<L, W, F>(fst: &mut F, config: ConnectConfig) -> usize
where
    L: Clone,
    W: Semiring,
    F: MutableWfst<L, W> + Wfst<L, W>,
{
    let n = fst.num_states();
    if n == 0 {
        return 0;
    }

    // No special case for a missing start state: `compute_accessible` yields
    // an empty set then, and the classification below honours `config`.
    let accessible = compute_accessible(fst);
    let coaccessible = compute_coaccessible(fst);

    let keep: HashSet<StateId> = (0..n)
        .map(|state| state as StateId)
        .filter(
            |id| match (accessible.contains(id), coaccessible.contains(id)) {
                (true, true) => true,
                (true, false) => config.keep_non_coaccessible,
                (false, true) => config.keep_non_accessible,
                (false, false) => false,
            },
        )
        .collect();

    let removed = n - keep.len();
    if removed == 0 {
        return 0;
    }

    for state_id in (0..n).map(|state| state as StateId) {
        if keep.contains(&state_id) {
            // Collect the surviving transitions first: the borrow of `fst`
            // must end before the state's transition list is mutated.
            let surviving: Vec<_> = fst
                .transitions(state_id)
                .iter()
                .filter(|t| keep.contains(&t.to))
                .cloned()
                .collect();
            fst.clear_transitions(state_id);
            for trans in surviving {
                fst.add_transition(trans);
            }
        } else {
            fst.clear_transitions(state_id);
            fst.set_final(state_id, W::zero());
        }
    }

    removed
}
/// Returns the set of states reachable from the start state of `fst` by
/// following transitions (the start state itself included).
///
/// The set is empty when the FST has no start state (`NO_STATE`).
pub fn compute_accessible<L, W, F>(fst: &F) -> HashSet<StateId>
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let mut visited = HashSet::new();
    let root = fst.start();

    if root != NO_STATE {
        // Breadth-first walk over outgoing transitions.
        let mut frontier = VecDeque::new();
        visited.insert(root);
        frontier.push_back(root);

        while let Some(current) = frontier.pop_front() {
            for arc in fst.transitions(current) {
                // `insert` returns true only the first time a state is seen.
                if visited.insert(arc.to) {
                    frontier.push_back(arc.to);
                }
            }
        }
    }

    visited
}
/// Returns the set of states from which some final state can be reached
/// (final states themselves included).
///
/// Builds a reverse adjacency list and walks it breadth-first, starting from
/// every final state. Transition destinations are used as indices into that
/// list, so a destination outside `0..fst.num_states()` makes this panic.
pub fn compute_coaccessible<L, W, F>(fst: &F) -> HashSet<StateId>
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let state_count = fst.num_states();

    // predecessors[t] holds one entry per transition entering state `t`.
    let mut predecessors: Vec<Vec<StateId>> = vec![Vec::new(); state_count];
    for source in (0..state_count).map(|s| s as StateId) {
        for arc in fst.transitions(source) {
            predecessors[arc.to as usize].push(source);
        }
    }

    // Seed the backwards search with every final state.
    let mut reached = HashSet::new();
    let mut pending = VecDeque::new();
    for candidate in (0..state_count).map(|s| s as StateId) {
        if fst.is_final(candidate) {
            reached.insert(candidate);
            pending.push_back(candidate);
        }
    }

    while let Some(current) = pending.pop_front() {
        for &source in &predecessors[current as usize] {
            // `insert` returns true only the first time a state is seen.
            if reached.insert(source) {
                pending.push_back(source);
            }
        }
    }

    reached
}
/// Reports whether every state of `fst` is both accessible and coaccessible.
///
/// An FST with no states counts as connected. An FST that has states but no
/// start state is not connected, because none of its states is accessible.
pub fn is_connected<L, W, F>(fst: &F) -> bool
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let total = fst.num_states();
    if total == 0 {
        return true;
    }

    let reachable = compute_accessible(fst);
    let productive = compute_coaccessible(fst);

    (0..total)
        .map(|state| state as StateId)
        .all(|id| reachable.contains(&id) && productive.contains(&id))
}
/// Counts the states of `fst` that are both accessible and coaccessible.
pub fn count_useful_states<L, W, F>(fst: &F) -> usize
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let reachable = compute_accessible(fst);
    let productive = compute_coaccessible(fst);

    reachable
        .iter()
        .filter(|id| productive.contains(*id))
        .count()
}
// Unit and property tests for `connect` and its helper functions.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
use crate::wfst::{MutableWfst, VectorWfst, VectorWfstBuilder};
// Property-based tests run against generated tropical WFSTs.
mod property_tests {
use super::*;
use crate::test_utils::arb_tropical_wfst;
use proptest::prelude::*;
proptest! {
// `connect` never reports more removed states than exist, and never
// increases the number of useful states.
#[test]
fn connect_reduces_or_maintains_states(
fst in arb_tropical_wfst(10, 3)
) {
let original_states = fst.num_states();
let useful_before = count_useful_states(&fst);
let mut connected_fst = fst.clone();
let removed = connect(&mut connected_fst, ConnectConfig::trim());
prop_assert!(
removed <= original_states,
"Removed {} states from {} total",
removed,
original_states
);
let useful_after = count_useful_states(&connected_fst);
prop_assert!(
useful_after <= useful_before,
"Useful count increased from {} to {}",
useful_before,
useful_after
);
}
// After a full trim, every accessible state that still has outgoing
// transitions or is final must also be coaccessible.
#[test]
fn connect_all_useful(
fst in arb_tropical_wfst(8, 3)
) {
let mut connected_fst = fst.clone();
connect(&mut connected_fst, ConnectConfig::trim());
let accessible = compute_accessible(&connected_fst);
let coaccessible = compute_coaccessible(&connected_fst);
for state in 0..connected_fst.num_states() {
let state_id = state as StateId;
// Despite its name, this flag is also true for final states without transitions.
let has_transitions = !connected_fst.transitions(state_id).is_empty()
|| connected_fst.is_final(state_id);
if has_transitions && accessible.contains(&state_id) {
prop_assert!(
coaccessible.contains(&state_id),
"State {} is accessible but not coaccessible",
state_id
);
}
}
}
// Trimming an already trimmed FST leaves the useful-state count unchanged.
#[test]
fn connect_idempotent(
fst in arb_tropical_wfst(8, 3)
) {
let useful_before = count_useful_states(&fst);
// Nothing to compare when the input has no useful states.
if useful_before == 0 {
return Ok(());
}
let mut fst1 = fst.clone();
let _removed1 = connect(&mut fst1, ConnectConfig::trim());
let useful_after_first = count_useful_states(&fst1);
if useful_after_first == 0 {
return Ok(());
}
let mut fst2 = fst1.clone();
let _removed2 = connect(&mut fst2, ConnectConfig::trim());
let useful_after_second = count_useful_states(&fst2);
prop_assert_eq!(
useful_after_first,
useful_after_second,
"Useful count changed from {} to {} after second connect",
useful_after_first,
useful_after_second
);
}
// `count_useful_states` equals the size of the intersection of the
// accessible and coaccessible sets.
#[test]
fn accessible_coaccessible_consistent(
fst in arb_tropical_wfst(6, 2)
) {
let accessible = compute_accessible(&fst);
let coaccessible = compute_coaccessible(&fst);
let useful = count_useful_states(&fst);
let intersection_count = accessible.intersection(&coaccessible).count();
prop_assert_eq!(
useful,
intersection_count,
"count_useful_states {} != intersection count {}",
useful,
intersection_count
);
}
// When useful states remain after a trim, every accessible state that
// has transitions or is final must be coaccessible.
#[test]
fn is_connected_after_connect(
fst in arb_tropical_wfst(6, 2)
) {
let mut connected_fst = fst.clone();
connect(&mut connected_fst, ConnectConfig::trim());
if count_useful_states(&connected_fst) > 0 {
let accessible = compute_accessible(&connected_fst);
let coaccessible = compute_coaccessible(&connected_fst);
for state in accessible.iter() {
if !connected_fst.transitions(*state).is_empty()
|| connected_fst.is_final(*state)
{
prop_assert!(
coaccessible.contains(state),
"Accessible state {} is not coaccessible after connect",
state
);
}
}
}
}
}
}
// Fixture: three-state chain 0 -a-> 1 -b-> 2 with start 0 and final state 2.
fn build_connected_fst() -> VectorWfst<char, TropicalWeight> {
VectorWfstBuilder::new()
.add_states(3)
.start(0)
.arc(0, Some('a'), Some('a'), 1, TropicalWeight::one())
.arc(1, Some('b'), Some('b'), 2, TropicalWeight::one())
.final_state(2, TropicalWeight::one())
.build()
}
// Fixture: the chain 0 -> 1 -> 2 plus state 3, which has an arc into 2 but
// no incoming arcs, so it is coaccessible but not accessible.
fn build_with_unreachable() -> VectorWfst<char, TropicalWeight> {
let mut fst = VectorWfst::new();
fst.add_states(4);
fst.set_start(0);
fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::one());
fst.add_arc(1, Some('b'), Some('b'), 2, TropicalWeight::one());
fst.set_final(2, TropicalWeight::one());
fst.add_arc(3, Some('c'), Some('c'), 2, TropicalWeight::one());
fst
}
// Fixture: the chain 0 -> 1 -> 2 plus an arc 0 -x-> 3, where state 3 has no
// outgoing arcs and is not final, so it is accessible but not coaccessible.
fn build_with_dead_end() -> VectorWfst<char, TropicalWeight> {
let mut fst = VectorWfst::new();
fst.add_states(4);
fst.set_start(0);
fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::one());
fst.add_arc(1, Some('b'), Some('b'), 2, TropicalWeight::one());
fst.set_final(2, TropicalWeight::one());
fst.add_arc(0, Some('x'), Some('x'), 3, TropicalWeight::one());
fst
}
// A freshly constructed FST is expected to report zero removed states.
#[test]
fn test_connect_empty() {
let mut fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
let removed = connect(&mut fst, ConnectConfig::trim());
assert_eq!(removed, 0);
}
// A fully connected FST is left untouched.
#[test]
fn test_connect_already_connected() {
let mut fst = build_connected_fst();
assert!(is_connected(&fst));
let removed = connect(&mut fst, ConnectConfig::trim());
assert_eq!(removed, 0);
assert!(is_connected(&fst));
}
// Trimming removes the unreachable state 3 by clearing its transitions.
#[test]
fn test_connect_removes_unreachable() {
let mut fst = build_with_unreachable();
assert!(!is_connected(&fst));
assert_eq!(count_useful_states(&fst), 3);
let removed = connect(&mut fst, ConnectConfig::trim());
assert_eq!(removed, 1);
assert!(fst.transitions(3).is_empty());
}
// Trimming removes the dead-end state 3 and the arc from 0 that led to it.
#[test]
fn test_connect_removes_dead_end() {
let mut fst = build_with_dead_end();
assert!(!is_connected(&fst));
assert_eq!(count_useful_states(&fst), 3);
let removed = connect(&mut fst, ConnectConfig::trim());
assert_eq!(removed, 1);
let trans_from_0: Vec<_> = fst.transitions(0).iter().map(|t| t.to).collect();
assert!(!trans_from_0.contains(&3));
}
// States 0, 1 and 2 are reachable from the start; state 3 is not.
#[test]
fn test_compute_accessible() {
let fst = build_with_unreachable();
let accessible = compute_accessible(&fst);
assert!(accessible.contains(&0));
assert!(accessible.contains(&1));
assert!(accessible.contains(&2));
assert!(!accessible.contains(&3)); }
// States 0, 1 and 2 can reach the final state; the dead end 3 cannot.
#[test]
fn test_compute_coaccessible() {
let fst = build_with_dead_end();
let coaccessible = compute_coaccessible(&fst);
assert!(coaccessible.contains(&0));
assert!(coaccessible.contains(&1));
assert!(coaccessible.contains(&2));
assert!(!coaccessible.contains(&3)); }
// Only the plain chain fixture is connected.
#[test]
fn test_is_connected() {
let connected = build_connected_fst();
assert!(is_connected(&connected));
let with_unreachable = build_with_unreachable();
assert!(!is_connected(&with_unreachable));
let with_dead_end = build_with_dead_end();
assert!(!is_connected(&with_dead_end));
}
// Each fixture has exactly three useful states (0, 1 and 2).
#[test]
fn test_count_useful_states() {
let connected = build_connected_fst();
assert_eq!(count_useful_states(&connected), 3);
let with_unreachable = build_with_unreachable();
assert_eq!(count_useful_states(&with_unreachable), 3);
let with_dead_end = build_with_dead_end();
assert_eq!(count_useful_states(&with_dead_end), 3);
}
// With `accessible_only`, the accessible dead-end state 3 is kept, so
// nothing is removed.
#[test]
fn test_connect_config_accessible_only() {
let mut fst = build_with_dead_end();
let removed = connect(&mut fst, ConnectConfig::accessible_only());
assert_eq!(removed, 0);
// NOTE(review): state 3 has no outgoing transitions in this fixture, so the
// first disjunct already holds and this assertion cannot fail.
assert!(fst.transitions(3).is_empty() || fst.transitions(0).iter().any(|t| t.to == 3));
}
// With `coaccessible_only`, the coaccessible but unreachable state 3 is
// kept, so nothing is removed.
#[test]
fn test_connect_config_coaccessible_only() {
let mut fst = build_with_unreachable();
let removed = connect(&mut fst, ConnectConfig::coaccessible_only());
assert_eq!(removed, 0); }
}