use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::hash::Hash;
use crate::semiring::{DivisibleSemiring, QuantizableSemiring, Semiring};
use crate::wfst::{MutableWfst, StateId, WeightedTransition, Wfst, NO_STATE};
/// Default tolerance passed to `QuantizableSemiring::quantize` when weights are
/// compared during minimization.
const MINIMIZE_EPSILON: f64 = 1e-10;

/// A weight reduced to an integer key so that it can be hashed and compared
/// exactly when grouping states.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct QuantizedWeight {
    /// Integer key produced by `QuantizableSemiring::quantize(epsilon)`.
    quantized: i64,
}

impl QuantizedWeight {
    /// Quantizes `weight` using the tolerance `epsilon`.
    fn from_weight<W: QuantizableSemiring>(weight: &W, epsilon: f64) -> Self {
        let quantized = weight.quantize(epsilon);
        Self { quantized }
    }
}
/// Options controlling [`minimize`].
#[derive(Clone, Debug)]
pub struct MinimizeConfig {
    /// Run weight pushing on the working copy before partitioning states.
    pub push_weights: bool,
    /// Direction passed to weight pushing; only consulted when `push_weights` is true.
    pub push_direction: crate::algorithms::PushDirection,
    /// Trim the working copy with `connect` (trim mode) before minimizing.
    pub connect_first: bool,
    /// Tolerance used to quantize weights before they are compared.
    pub weight_epsilon: f64,
}

impl Default for MinimizeConfig {
    /// Trim first, push weights forward, and use `MINIMIZE_EPSILON` as tolerance.
    fn default() -> Self {
        Self {
            connect_first: true,
            push_weights: true,
            push_direction: crate::algorithms::PushDirection::Forward,
            weight_epsilon: MINIMIZE_EPSILON,
        }
    }
}

impl MinimizeConfig {
    /// The standard configuration; identical to [`MinimizeConfig::default`].
    pub fn standard() -> Self {
        Self::default()
    }

    /// Like the default configuration, but without weight pushing.
    pub fn no_push() -> Self {
        Self {
            push_weights: false,
            ..Self::default()
        }
    }

    /// The default configuration with a custom weight-quantization tolerance.
    pub fn with_epsilon(epsilon: f64) -> Self {
        let mut config = Self::default();
        config.weight_epsilon = epsilon;
        config
    }
}
/// Errors reported by [`minimize`].
#[derive(Clone, Debug, PartialEq)]
pub enum MinimizeError {
    /// The input has states but no start state.
    NoStartState,
    /// The input failed the determinism check that minimization requires.
    NotDeterministic,
    /// Weight pushing failed; carries the underlying error rendered as text.
    PushError(String),
}

impl std::fmt::Display for MinimizeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fixed_message = match self {
            Self::NoStartState => "WFST has no start state",
            Self::NotDeterministic => "WFST must be deterministic before minimization",
            // The only variant with a payload formats it directly.
            Self::PushError(msg) => return write!(f, "Weight pushing failed: {}", msg),
        };
        f.write_str(fixed_message)
    }
}

impl std::error::Error for MinimizeError {}
/// Hashable summary of one state, used as the grouping key during partition
/// refinement: the quantized final weight (`None` for non-final states) and the
/// outgoing transitions as `(input, output, quantized weight, target block)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct StateSignature<L: Ord + Hash> {
    final_weight: Option<QuantizedWeight>,
    transitions: Vec<(Option<L>, Option<L>, QuantizedWeight, usize)>,
}

impl<L: Ord + Hash + Clone> StateSignature<L> {
    /// An empty signature: non-final with no transitions.
    fn new() -> Self {
        let transitions = Vec::new();
        Self {
            transitions,
            final_weight: None,
        }
    }
}
/// Minimizes a deterministic weighted FST by merging equivalent states.
///
/// Works on a clone of `fst`:
/// 1. if `config.connect_first`, trims it with `connect` (trim mode);
/// 2. if `config.push_weights`, pushes weights in `config.push_direction`;
/// 3. partitions states by quantized final weight and outgoing-transition
///    signatures (tolerance `config.weight_epsilon`) until stable;
/// 4. builds a new FST with one state per partition block.
///
/// An input with no states yields an empty FST. The same applies when trimming
/// removes every state.
///
/// # Errors
///
/// * [`MinimizeError::NoStartState`] if `fst` has states but no start state.
/// * [`MinimizeError::NotDeterministic`] if `fst` fails the determinism check.
/// * [`MinimizeError::PushError`] if weight pushing fails.
pub fn minimize<L, W, F>(fst: &F, config: MinimizeConfig) -> Result<F, MinimizeError>
where
    L: Clone + Eq + Hash + Ord + Debug,
    W: DivisibleSemiring + QuantizableSemiring + PartialOrd + Clone + Debug,
    F: MutableWfst<L, W> + Wfst<L, W> + Default + Clone,
{
    if fst.num_states() == 0 {
        return Ok(F::default());
    }
    if fst.start() == NO_STATE {
        return Err(MinimizeError::NoStartState);
    }
    if !super::determinize::is_deterministic(fst) {
        return Err(MinimizeError::NotDeterministic);
    }

    let mut working = fst.clone();

    if config.connect_first {
        use crate::algorithms::{connect, ConnectConfig};
        connect(&mut working, ConnectConfig::trim());
        // Trimming can remove every state, for example when no final state is
        // reachable. The empty FST is already minimal. Returning here also
        // avoids running weight pushing on an automaton with no start state.
        if working.num_states() == 0 {
            return Ok(F::default());
        }
    }

    if config.push_weights {
        use crate::algorithms::{push_weights, PushConfig, ShortestDistanceConfig};
        let push_config = PushConfig {
            direction: config.push_direction.clone(),
            remove_non_coaccessible: false,
            distance_config: ShortestDistanceConfig::default(),
        };
        push_weights(&mut working, push_config)
            .map_err(|e| MinimizeError::PushError(e.to_string()))?;
    }

    let partitions = compute_partitions(&working, config.weight_epsilon)?;
    build_minimized(&working, &partitions)
}
/// Computes the coarsest stable partition of the states of `fst`.
///
/// Returns `p` with `p[s]` the block id of state `s`. Block ids are dense
/// (`0..k`) and numbered in order of first appearance. Two states share a
/// block when:
/// * they have the same quantized final weight, or are both non-final; and
/// * they have the same multiset of outgoing
///   `(input, output, quantized weight, target block)` tuples.
///
/// Uses Moore-style refinement:
/// * Start from a partition by final weight.
/// * Repeatedly re-key every state by its signature under the current partition.
/// * Each signature contains the final weight and the current target blocks, so
///   every round refines the previous one.
/// * The first round that does not increase the block count is the fixed point.
///
/// Weights are compared after quantization with `epsilon`. Currently this
/// function never returns `Err`.
fn compute_partitions<L, W, F>(fst: &F, epsilon: f64) -> Result<Vec<usize>, MinimizeError>
where
    L: Clone + Eq + Hash + Ord + Debug,
    W: QuantizableSemiring + Clone + Debug,
    F: Wfst<L, W>,
{
    let n = fst.num_states();
    if n == 0 {
        return Ok(Vec::new());
    }

    // Quantized final weight of a state, or `None` when it is not final.
    let final_key = |state_id: StateId| -> Option<QuantizedWeight> {
        if fst.is_final(state_id) {
            Some(QuantizedWeight::from_weight(&fst.final_weight(state_id), epsilon))
        } else {
            None
        }
    };

    // Initial partition: group states by final weight only.
    let mut initial_blocks: HashMap<Option<QuantizedWeight>, usize> = HashMap::new();
    let mut partition: Vec<usize> = Vec::with_capacity(n);
    for state in 0..n {
        let next_id = initial_blocks.len();
        let block = *initial_blocks
            .entry(final_key(state as StateId))
            .or_insert(next_id);
        partition.push(block);
    }
    let mut num_partitions = initial_blocks.len();

    loop {
        let mut signature_to_partition: HashMap<StateSignature<L>, usize> = HashMap::new();
        let mut new_partition: Vec<usize> = Vec::with_capacity(n);

        for state in 0..n {
            let state_id = state as StateId;

            let mut transitions: Vec<(Option<L>, Option<L>, QuantizedWeight, usize)> =
                Vec::new();
            for trans in fst.transitions(state_id) {
                transitions.push((
                    trans.input.clone(),
                    trans.output.clone(),
                    QuantizedWeight::from_weight(&trans.weight, epsilon),
                    partition[trans.to as usize],
                ));
            }
            // Sort on the complete tuple, weight included, so the signature is a
            // canonical form of the transition multiset. Otherwise, ties on
            // (input, output, target) would keep insertion order. Equivalent states
            // whose parallel arcs were stored in different orders would then get
            // different signatures. This can happen for non-deterministic input,
            // e.g. via `estimate_reduction`.
            transitions.sort_unstable_by(|a, b| {
                a.0.cmp(&b.0)
                    .then_with(|| a.1.cmp(&b.1))
                    .then_with(|| a.3.cmp(&b.3))
                    .then_with(|| a.2.quantized.cmp(&b.2.quantized))
            });

            let mut sig = StateSignature::new();
            sig.final_weight = final_key(state_id);
            sig.transitions = transitions;

            let next_id = signature_to_partition.len();
            new_partition.push(*signature_to_partition.entry(sig).or_insert(next_id));
        }

        // Each round refines the previous partition, so an unchanged count
        // means the partition itself is unchanged.
        let new_num_partitions = signature_to_partition.len();
        if new_num_partitions <= num_partitions {
            break;
        }
        partition = new_partition;
        num_partitions = new_num_partitions;
    }

    Ok(partition)
}
/// Builds the quotient FST with one state per block of `partitions`.
///
/// `partitions[s]` is the block id of original state `s`, and new state `b`
/// stands for block `b`. Each block copies the following from its
/// lowest-numbered member (its representative), with targets remapped to
/// block ids:
/// * the final weight;
/// * the outgoing transitions.
///
/// Parallel transitions of a representative with the same
/// (input, output, target block) are emitted once; the first one wins. The
/// start state maps to the block of the original start state.
///
/// This is only meaningful when all members of a block are equivalent, as
/// produced by `compute_partitions`.
fn build_minimized<L, W, F>(fst: &F, partitions: &[usize]) -> Result<F, MinimizeError>
where
    L: Clone + Eq + Hash + Ord + Debug,
    W: Semiring + Clone + Debug,
    F: MutableWfst<L, W> + Wfst<L, W> + Default,
{
    let n = fst.num_states();
    if n == 0 {
        return Ok(F::default());
    }
    let num_new_states = partitions.iter().max().map(|&m| m + 1).unwrap_or(0);

    // representative[b] = lowest-numbered original state in block b. A dense
    // Vec keeps construction order deterministic, unlike iterating a HashMap.
    let mut representative: Vec<Option<StateId>> = vec![None; num_new_states];
    for (state, &block) in partitions[..n].iter().enumerate() {
        if representative[block].is_none() {
            representative[block] = Some(state as StateId);
        }
    }

    let mut result = F::default();
    for _ in 0..num_new_states {
        result.add_state();
    }

    let old_start = fst.start();
    if old_start != NO_STATE {
        result.set_start(partitions[old_start as usize] as StateId);
    }

    let mut added_transitions: HashSet<(usize, Option<L>, Option<L>, usize)> = HashSet::new();
    let blocks_with_rep = representative
        .iter()
        .enumerate()
        .filter_map(|(block, rep)| rep.map(|r| (block, r)));
    for (block, rep) in blocks_with_rep {
        let new_state = block as StateId;
        if fst.is_final(rep) {
            result.set_final(new_state, fst.final_weight(rep));
        }
        for trans in fst.transitions(rep) {
            let target_partition = partitions[trans.to as usize];
            let key = (
                block,
                trans.input.clone(),
                trans.output.clone(),
                target_partition,
            );
            // `insert` returns false when the key was already present.
            if added_transitions.insert(key) {
                result.add_transition(WeightedTransition {
                    from: new_state,
                    to: target_partition as StateId,
                    input: trans.input.clone(),
                    output: trans.output.clone(),
                    weight: trans.weight.clone(),
                });
            }
        }
    }
    Ok(result)
}
/// Estimates how many states minimization could remove from `fst`, using the
/// default quantization tolerance `MINIMIZE_EPSILON`.
///
/// See [`estimate_reduction_with_epsilon`] for details.
pub fn estimate_reduction<L, W, F>(fst: &F) -> usize
where
    L: Clone + Eq + Hash + Ord + Debug,
    W: QuantizableSemiring + Clone + Debug,
    F: Wfst<L, W>,
{
    let default_epsilon = MINIMIZE_EPSILON;
    estimate_reduction_with_epsilon(fst, default_epsilon)
}
/// Estimates how many states minimization could remove from `fst`.
///
/// Returns the state count of `fst` minus the number of equivalence blocks
/// found by partition refinement, with weights quantized using `epsilon`.
/// Returns 0 for an empty FST or if partitioning fails.
///
/// The estimate is computed on `fst` exactly as given. Unlike [`minimize`]
/// under its default configuration, no trimming or weight pushing is applied,
/// so the figure can differ from the reduction `minimize` actually achieves.
pub fn estimate_reduction_with_epsilon<L, W, F>(fst: &F, epsilon: f64) -> usize
where
    L: Clone + Eq + Hash + Ord + Debug,
    W: QuantizableSemiring + Clone + Debug,
    F: Wfst<L, W>,
{
    let original_count = fst.num_states();
    if original_count == 0 {
        return 0;
    }
    match compute_partitions(fst, epsilon) {
        Ok(blocks) => {
            let minimized_count = blocks
                .iter()
                .copied()
                .max()
                .map_or(0, |highest| highest + 1);
            original_count.saturating_sub(minimized_count)
        }
        Err(_) => 0,
    }
}
// Unit tests and property-based tests for `minimize` and `estimate_reduction`.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
use crate::wfst::{VectorWfst, VectorWfstBuilder};
// Property tests over FSTs produced by `arb_deterministic_wfst_tropical` from
// `crate::test_utils`; the meaning of its numeric arguments is defined there.
mod property_tests {
use super::*;
use crate::test_utils::arb_deterministic_wfst_tropical;
use proptest::prelude::*;
proptest! {
// Minimization must never increase the state count. `Err` results are
// tolerated and not inspected.
#[test]
fn minimize_reduces_or_maintains_states(
fst in arb_deterministic_wfst_tropical(8, 3)
) {
if fst.num_states() == 0 {
return Ok(());
}
let original_states = fst.num_states();
let result = minimize(&fst, MinimizeConfig::standard());
if let Ok(min_fst) = result {
prop_assert!(
min_fst.num_states() <= original_states,
"Minimization increased states from {} to {}",
original_states,
min_fst.num_states()
);
}
}
// Minimizing an already-minimized FST must not change its state count.
#[test]
fn minimize_idempotent(
fst in arb_deterministic_wfst_tropical(6, 2)
) {
if fst.num_states() == 0 {
return Ok(());
}
let result1 = minimize(&fst, MinimizeConfig::standard());
if let Ok(min1) = result1 {
let result2 = minimize(&min1, MinimizeConfig::standard());
if let Ok(min2) = result2 {
prop_assert_eq!(
min1.num_states(),
min2.num_states(),
"Minimization not idempotent: first pass {} states, second pass {} states",
min1.num_states(),
min2.num_states()
);
}
}
}
// The minimized FST must still pass the determinism check.
#[test]
fn minimize_preserves_determinism(
fst in arb_deterministic_wfst_tropical(8, 3)
) {
if fst.num_states() == 0 {
return Ok(());
}
let result = minimize(&fst, MinimizeConfig::standard());
if let Ok(min_fst) = result {
prop_assert!(
super::super::super::determinize::is_deterministic(&min_fst),
"Minimized FST should be deterministic"
);
}
}
// The estimated reduction must not exceed the state count. On these inputs,
// `minimize` may fail only with `PushError`.
#[test]
fn estimate_reduction_bounds(
fst in arb_deterministic_wfst_tropical(6, 2)
) {
if fst.num_states() <= 1 {
return Ok(());
}
let estimated = estimate_reduction(&fst);
let original_states = fst.num_states();
prop_assert!(
estimated <= original_states,
"Estimated reduction {} exceeds state count {}",
estimated,
original_states
);
let result = minimize(&fst, MinimizeConfig::standard());
prop_assert!(
result.is_ok() || matches!(result, Err(MinimizeError::PushError(_))),
"Minimize failed unexpectedly: {:?}",
result
);
}
}
}
// 5 states: 0 -a-> 1 -b-> 3 and 0 -c-> 2 -b-> 4, with every arc weight 1.0.
// States 3 and 4 are final with weight one and have no arcs, so they are
// equivalent. States 1 and 2 each have a single 'b' arc of equal weight into
// that pair, so they are equivalent too.
fn build_redundant_fst() -> VectorWfst<char, TropicalWeight> {
let mut fst = VectorWfst::new();
fst.add_states(5);
fst.set_start(0);
fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0));
fst.add_arc(0, Some('c'), Some('c'), 2, TropicalWeight::new(1.0));
fst.add_arc(1, Some('b'), Some('b'), 3, TropicalWeight::new(1.0));
fst.add_arc(2, Some('b'), Some('b'), 4, TropicalWeight::new(1.0));
fst.set_final(3, TropicalWeight::one());
fst.set_final(4, TropicalWeight::one());
fst
}
// Linear chain 0 -a-> 1 -b-> 2 with 2 final; no two states are equivalent.
fn build_minimal_fst() -> VectorWfst<char, TropicalWeight> {
VectorWfstBuilder::new()
.add_states(3)
.start(0)
.arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0))
.arc(1, Some('b'), Some('b'), 2, TropicalWeight::new(2.0))
.final_state(2, TropicalWeight::one())
.build()
}
// Chain 0 -a-> 1 -b-> 2 that branches to 3 (on 'c') and 4 (on 'd'). The two
// leaves are final with weight one and have no arcs, so 3 and 4 are equivalent.
fn build_chain_with_equiv_states() -> VectorWfst<char, TropicalWeight> {
let mut fst = VectorWfst::new();
fst.add_states(5);
fst.set_start(0);
fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0));
fst.add_arc(1, Some('b'), Some('b'), 2, TropicalWeight::new(1.0));
fst.add_arc(2, Some('c'), Some('c'), 3, TropicalWeight::new(1.0));
fst.add_arc(2, Some('d'), Some('d'), 4, TropicalWeight::new(1.0));
fst.set_final(3, TropicalWeight::one());
fst.set_final(4, TropicalWeight::one());
fst
}
// An FST with no states minimizes to an FST with no states.
#[test]
fn test_minimize_empty() {
let fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
let result = minimize(&fst, MinimizeConfig::standard())
.expect("algorithms/minimize.rs: required value was None/Err");
assert_eq!(result.num_states(), 0);
}
#[test]
fn test_minimize_already_minimal() {
let fst = build_minimal_fst();
let result = minimize(&fst, MinimizeConfig::standard())
.expect("algorithms/minimize.rs: required value was None/Err");
assert!(result.num_states() <= fst.num_states());
}
// The redundant FST has mergeable states, so the count must strictly drop.
#[test]
fn test_minimize_redundant() {
let fst = build_redundant_fst();
let initial_states = fst.num_states();
let result = minimize(&fst, MinimizeConfig::standard())
.expect("algorithms/minimize.rs: required value was None/Err");
assert!(
result.num_states() < initial_states,
"Expected fewer than {} states, got {}",
initial_states,
result.num_states()
);
}
// State 0 has two arcs on input 'a', so the determinism check must reject the
// input. (The second arc and `set_final(1, ..)` share one source line below.)
#[test]
fn test_minimize_non_deterministic_fails() {
let mut fst = VectorWfst::new();
fst.add_states(3);
fst.set_start(0);
fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0));
fst.add_arc(0, Some('a'), Some('a'), 2, TropicalWeight::new(2.0)); fst.set_final(1, TropicalWeight::one());
fst.set_final(2, TropicalWeight::one());
let result = minimize(&fst, MinimizeConfig::standard());
assert!(matches!(result, Err(MinimizeError::NotDeterministic)));
}
#[test]
fn test_minimize_chain_equiv() {
let fst = build_chain_with_equiv_states();
let initial_states = fst.num_states();
let result = minimize(&fst, MinimizeConfig::standard())
.expect("algorithms/minimize.rs: required value was None/Err");
assert!(result.num_states() > 0);
assert!(result.num_states() <= initial_states);
}
// The redundant FST contains at least one mergeable pair of states.
#[test]
fn test_estimate_reduction() {
let redundant = build_redundant_fst();
let reduction = estimate_reduction(&redundant);
assert!(reduction >= 1, "Expected reduction >= 1, got {}", reduction);
}
#[test]
fn test_minimize_preserves_determinism() {
let fst = build_redundant_fst();
assert!(super::super::determinize::is_deterministic(&fst));
let result = minimize(&fst, MinimizeConfig::standard())
.expect("algorithms/minimize.rs: required value was None/Err");
assert!(super::super::determinize::is_deterministic(&result));
}
// Minimization without weight pushing must still succeed and not grow the FST.
#[test]
fn test_minimize_no_push_config() {
let fst = build_minimal_fst();
let result = minimize(&fst, MinimizeConfig::no_push())
.expect("algorithms/minimize.rs: required value was None/Err");
assert!(result.num_states() <= fst.num_states());
}
}