use std::collections::{BTreeMap, HashMap, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;
use crate::semiring::{DivisibleSemiring, Semiring, TotallyOrderedSemiring};
use crate::wfst::{MutableWfst, StateId, WeightedTransition, Wfst, NO_STATE};
/// Options accepted by [`determinize`].
#[derive(Clone, Debug)]
pub struct DeterminizeConfig {
    /// Upper bound on the number of states the determinized machine may hold.
    ///
    /// `determinize` fails with [`DeterminizeError::StateLimitExceeded`] when the
    /// result would need more states than this. `None` disables the bound.
    pub max_states: Option<usize>,
    /// Requests epsilon removal before determinization.
    ///
    /// NOTE(review): `determinize` in this file does not read this flag; confirm
    /// where (or whether) it is honoured.
    pub remove_epsilon_first: bool,
    /// When `true`, `determinize` passes its result through `connect` with
    /// `ConnectConfig::trim()` before returning it.
    pub connect_after: bool,
}

impl Default for DeterminizeConfig {
    /// Bounded configuration: at most 1,000,000 states, both flags enabled.
    fn default() -> Self {
        DeterminizeConfig {
            connect_after: true,
            remove_epsilon_first: true,
            max_states: Some(1_000_000),
        }
    }
}

impl DeterminizeConfig {
    /// The standard configuration; identical to [`Default::default`].
    pub fn standard() -> Self {
        <Self as Default>::default()
    }

    /// Same flags as the default configuration, but without any state limit.
    pub fn unlimited() -> Self {
        Self {
            max_states: None,
            ..Self::default()
        }
    }
}
/// Failures reported by [`determinize`].
#[derive(Clone, Debug, PartialEq)]
pub enum DeterminizeError {
    /// The input machine has states but its start state is `NO_STATE`.
    NoStartState,
    /// The configured `max_states` bound was hit.
    StateLimitExceeded {
        /// The limit that was in force.
        limit: usize,
    },
    /// The input cannot be determinized.
    ///
    /// NOTE(review): no code visible in this file constructs this variant.
    NotDeterminizable {
        /// Human-readable explanation.
        reason: String,
    },
}

impl std::fmt::Display for DeterminizeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoStartState => f.write_str("WFST has no start state"),
            Self::StateLimitExceeded { limit } => {
                write!(f, "Determinization exceeded {} state limit", limit)
            }
            Self::NotDeterminizable { reason } => {
                write!(f, "WFST is not determinizable: {}", reason)
            }
        }
    }
}

impl std::error::Error for DeterminizeError {}
/// A weighted subset of input states: each member `StateId` maps to the residual
/// weight carried for it. Being a `BTreeMap`, it iterates in ascending `StateId` order.
type WeightedSubset<W> = BTreeMap<StateId, W>;
/// Flattens a weighted subset into a list of `(state, weight)` pairs, in the
/// subset's own (ascending state) order, so it can be used as a lookup key.
fn subset_key<W: Semiring + Clone>(subset: &WeightedSubset<W>) -> Vec<(StateId, W)> {
    let mut key = Vec::with_capacity(subset.len());
    for (state, weight) in subset {
        key.push((*state, weight.clone()));
    }
    key
}
/// Returns a clone of the smallest weight in `subset` according to `total_cmp`,
/// or `W::zero()` when the subset is empty.
///
/// When several weights compare equal, the one met first in iteration order wins.
fn min_weight<W: TotallyOrderedSemiring + Clone>(subset: &WeightedSubset<W>) -> W {
    let mut smallest: Option<&W> = None;
    for candidate in subset.values() {
        let replaces = match smallest {
            None => true,
            // Only a strictly smaller candidate displaces the current minimum.
            Some(current) => current.total_cmp(candidate) == std::cmp::Ordering::Greater,
        };
        if replaces {
            smallest = Some(candidate);
        }
    }
    match smallest {
        Some(weight) => weight.clone(),
        None => W::zero(),
    }
}
/// Determinizes `fst` with a weighted subset construction.
///
/// Every state of the result stands for a *weighted subset* of input states: a set
/// of input states, each paired with a residual weight. For each input label leaving
/// a subset, the smallest combined weight (per `total_cmp`) becomes the weight of the
/// single outgoing transition, and the remaining weights are divided by it to form
/// the residuals of the destination subset. Subsets with identical states and
/// residuals share one result state.
///
/// Behaviour worth knowing:
/// - Transitions whose input label is `None` (epsilon) are not followed; they are
///   dropped. `config.remove_epsilon_first` is not consulted here.
///   NOTE(review): presumably callers are expected to remove epsilons beforehand;
///   confirm.
/// - When several arcs share an input label, the merged transition carries the first
///   non-`None` output label met while scanning the subset in ascending state order;
///   other output labels are discarded.
/// - Labels are expanded in ascending `Ord` order, so the numbering of the result's
///   states is reproducible from run to run.
/// - An input with zero states yields `Ok(F::default())`.
/// - If `config.connect_after` is set, the result is passed through `connect` with
///   `ConnectConfig::trim()` before being returned.
///
/// # Errors
/// - [`DeterminizeError::NoStartState`] if the input has states but no start state.
/// - [`DeterminizeError::StateLimitExceeded`] if the result would need more than
///   `config.max_states` states.
pub fn determinize<L, W, F>(fst: &F, config: DeterminizeConfig) -> Result<F, DeterminizeError>
where
    L: Clone + Eq + Hash + Ord + Debug,
    W: DivisibleSemiring + TotallyOrderedSemiring + Clone + Debug + Hash + Eq,
    F: MutableWfst<L, W> + Wfst<L, W> + Default,
{
    if fst.num_states() == 0 {
        return Ok(F::default());
    }
    let start = fst.start();
    if start == NO_STATE {
        return Err(DeterminizeError::NoStartState);
    }

    let mut result = F::default();
    // Maps a normalized weighted subset to the result state that represents it.
    let mut subset_to_state: HashMap<Vec<(StateId, W)>, StateId> = HashMap::new();
    // Result states whose outgoing transitions still have to be built.
    let mut queue: VecDeque<(StateId, WeightedSubset<W>)> = VecDeque::new();

    // The start subset is the input start state with the multiplicative identity.
    let mut initial_subset: WeightedSubset<W> = BTreeMap::new();
    initial_subset.insert(start, W::one());
    ensure_room_for_state(result.num_states(), config.max_states)?;
    let initial_state = result.add_state();
    result.set_start(initial_state);
    subset_to_state.insert(subset_key(&initial_subset), initial_state);
    queue.push_back((initial_state, initial_subset));

    while let Some((output_state, subset)) = queue.pop_front() {
        // Final weight of the subset: sum over its final members of
        // residual (x) final weight.
        let mut final_weight = W::zero();
        for (&state, residual) in &subset {
            if fst.is_final(state) {
                let fw = fst.final_weight(state);
                final_weight = final_weight.plus(&residual.times(&fw));
            }
        }
        if !final_weight.is_zero() {
            result.set_final(output_state, final_weight);
        }

        // Group every non-epsilon arc leaving the subset by input label. A BTreeMap
        // keeps the expansion order (and therefore the state numbering of the
        // result) deterministic.
        let mut label_to_targets: BTreeMap<Option<L>, Vec<(StateId, W, Option<L>)>> =
            BTreeMap::new();
        for (&state, residual) in &subset {
            for trans in fst.transitions(state) {
                if trans.input.is_none() {
                    // Epsilon arcs are not followed by this construction.
                    continue;
                }
                let combined = residual.times(&trans.weight);
                label_to_targets
                    .entry(trans.input.clone())
                    .or_default()
                    .push((trans.to, combined, trans.output.clone()));
            }
        }

        for (input_label, targets) in label_to_targets {
            // Merge arcs that reach the same destination and pick the output label.
            // Every group holds at least one arc, so the subset is never empty.
            let mut target_subset: WeightedSubset<W> = BTreeMap::new();
            let mut output_label: Option<L> = None;
            for (target_state, weight, out) in &targets {
                target_subset
                    .entry(*target_state)
                    .and_modify(|w| *w = w.plus(weight))
                    .or_insert_with(|| weight.clone());
                if output_label.is_none() {
                    output_label = out.clone();
                }
            }

            // Push the smallest weight onto the transition and keep the quotients
            // as residuals.
            // NOTE(review): when `divide` yields `None` the weight is kept
            // unnormalized; confirm this fallback is intended.
            let min_w = min_weight(&target_subset);
            let normalized_subset: WeightedSubset<W> = target_subset
                .iter()
                .map(|(&state, weight)| {
                    let residual = weight.divide(&min_w).unwrap_or_else(|| weight.clone());
                    (state, residual)
                })
                .collect();

            let normalized_key = subset_key(&normalized_subset);
            let target_output_state = if let Some(&existing) = subset_to_state.get(&normalized_key)
            {
                existing
            } else {
                // Enforce the limit before the state is created so the result never
                // grows past it.
                ensure_room_for_state(result.num_states(), config.max_states)?;
                let new_state = result.add_state();
                subset_to_state.insert(normalized_key, new_state);
                queue.push_back((new_state, normalized_subset));
                new_state
            };

            let trans = WeightedTransition {
                from: output_state,
                to: target_output_state,
                input: input_label,
                output: output_label,
                weight: min_w,
            };
            result.add_transition(trans);
        }
    }

    if config.connect_after {
        use crate::algorithms::{connect, ConnectConfig};
        connect(&mut result, ConnectConfig::trim());
    }
    Ok(result)
}

/// Fails with `StateLimitExceeded` when a machine that already holds `current`
/// states has no room for one more under `max_states`.
fn ensure_room_for_state(
    current: usize,
    max_states: Option<usize>,
) -> Result<(), DeterminizeError> {
    match max_states {
        Some(limit) if current >= limit => Err(DeterminizeError::StateLimitExceeded { limit }),
        _ => Ok(()),
    }
}
/// Reports whether `fst` is input-deterministic.
///
/// Returns `false` as soon as any state in `0..num_states()` has a transition with
/// a `None` (epsilon) input label, or two transitions with the same input label.
/// Output labels and weights are not examined. A machine with no states, or whose
/// start state is `NO_STATE`, is reported as deterministic without inspecting it.
pub fn is_deterministic<L, W, F>(fst: &F) -> bool
where
    L: Clone + Eq + Hash,
    W: Semiring,
    F: Wfst<L, W>,
{
    let state_count = fst.num_states();
    if state_count == 0 || fst.start() == NO_STATE {
        return true;
    }
    (0..state_count).all(|index| {
        let mut seen: std::collections::HashSet<&L> = std::collections::HashSet::new();
        for arc in fst.transitions(index as StateId) {
            match arc.input.as_ref() {
                // First time this label leaves the state: still deterministic.
                Some(label) if seen.insert(label) => {}
                // Epsilon input or a repeated label.
                _ => return false,
            }
        }
        true
    })
}
/// Returns the largest number of transitions that leave a single state with the
/// same input label, taken over all states in `0..num_states()`.
///
/// Epsilon (`None`) inputs are counted as a label of their own. The result is `0`
/// when the machine has no states or no transitions.
pub fn non_determinism_degree<L, W, F>(fst: &F) -> usize
where
    L: Clone + Eq + Hash,
    W: Semiring,
    F: Wfst<L, W>,
{
    (0..fst.num_states())
        .map(|index| {
            // Number of outgoing arcs per input label for this state.
            let mut fan_out: HashMap<Option<&L>, usize> = HashMap::new();
            for arc in fst.transitions(index as StateId) {
                *fan_out.entry(arc.input.as_ref()).or_default() += 1;
            }
            fan_out.values().copied().max().unwrap_or(0)
        })
        .max()
        .unwrap_or(0)
}
// Unit and property tests for `determinize`, `is_deterministic` and
// `non_determinism_degree`, all over `TropicalWeight` machines.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
use crate::wfst::{VectorWfst, VectorWfstBuilder};
// Property tests driven by machines from `arb_deterministic_wfst_tropical`.
// Each property is only asserted when `determinize` returns `Ok`; errors are ignored.
mod property_tests {
use super::*;
use crate::test_utils::arb_deterministic_wfst_tropical;
use proptest::prelude::*;
proptest! {
// Whatever `determinize` returns successfully must pass `is_deterministic`.
#[test]
fn determinize_produces_deterministic(
fst in arb_deterministic_wfst_tropical(8, 3)
) {
let result = determinize(&fst, DeterminizeConfig::standard());
if let Ok(det_fst) = result {
prop_assert!(
is_deterministic(&det_fst),
"Determinized FST should be deterministic"
);
}
}
// Determinizing an already deterministic, non-empty machine may add at most
// two states.
#[test]
fn determinize_already_deterministic(
fst in arb_deterministic_wfst_tropical(8, 3)
) {
if fst.num_states() == 0 {
return Ok(());
}
prop_assert!(is_deterministic(&fst), "Test FST should be deterministic");
let result = determinize(&fst, DeterminizeConfig::standard());
if let Ok(det_fst) = result {
prop_assert!(
det_fst.num_states() <= fst.num_states() + 2,
"Determinizing deterministic FST grew from {} to {} states",
fst.num_states(),
det_fst.num_states()
);
}
}
// Applying `determinize` twice keeps the result deterministic and may add at
// most one state on the second pass.
#[test]
fn determinize_idempotent(
fst in arb_deterministic_wfst_tropical(6, 2)
) {
if fst.num_states() == 0 {
return Ok(());
}
let det1 = determinize(&fst, DeterminizeConfig::standard());
if let Ok(det1_fst) = det1 {
let det2 = determinize(&det1_fst, DeterminizeConfig::standard());
if let Ok(det2_fst) = det2 {
prop_assert!(is_deterministic(&det1_fst));
prop_assert!(is_deterministic(&det2_fst));
prop_assert!(
det2_fst.num_states() <= det1_fst.num_states() + 1,
"det(det(F)) has {} states, det(F) has {}",
det2_fst.num_states(),
det1_fst.num_states()
);
}
}
}
// A generated (deterministic) non-empty machine never has two arcs sharing an
// input label at one state, so its degree is at most 1.
#[test]
fn non_determinism_degree_deterministic(
fst in arb_deterministic_wfst_tropical(8, 3)
) {
if fst.num_states() == 0 {
return Ok(());
}
let degree = non_determinism_degree(&fst);
prop_assert!(
degree <= 1,
"Deterministic FST should have degree 0 or 1, got {}",
degree
);
}
}
}
// Linear machine 0 -a/1.0-> 1 -b/2.0-> 2, with state 2 final.
fn build_deterministic_fst() -> VectorWfst<char, TropicalWeight> {
VectorWfstBuilder::new()
.add_states(3)
.start(0)
.arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0))
.arc(1, Some('b'), Some('b'), 2, TropicalWeight::new(2.0))
.final_state(2, TropicalWeight::one())
.build()
}
// State 0 has two 'a' arcs (to 1 with weight 1.0, to 2 with weight 2.0); state 1
// continues on 'b' and state 2 on 'c', both into final state 3.
fn build_non_deterministic_fst() -> VectorWfst<char, TropicalWeight> {
let mut fst = VectorWfst::new();
fst.add_states(4);
fst.set_start(0);
fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0));
fst.add_arc(0, Some('a'), Some('a'), 2, TropicalWeight::new(2.0));
fst.add_arc(1, Some('b'), Some('b'), 3, TropicalWeight::new(1.0));
fst.add_arc(2, Some('c'), Some('c'), 3, TropicalWeight::new(1.0));
fst.set_final(3, TropicalWeight::one());
fst
}
// Same shape as `build_non_deterministic_fst`, except that both branches continue
// on 'b', so the two paths spell the same string "ab".
fn build_diamond_non_det() -> VectorWfst<char, TropicalWeight> {
let mut fst = VectorWfst::new();
fst.add_states(4);
fst.set_start(0);
fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0));
fst.add_arc(0, Some('a'), Some('a'), 2, TropicalWeight::new(2.0));
fst.add_arc(1, Some('b'), Some('b'), 3, TropicalWeight::new(1.0));
fst.add_arc(2, Some('b'), Some('b'), 3, TropicalWeight::new(1.0));
fst.set_final(3, TropicalWeight::one());
fst
}
// The linear fixture is recognised as deterministic.
#[test]
fn test_is_deterministic_true() {
let fst = build_deterministic_fst();
assert!(is_deterministic(&fst));
}
// Two 'a' arcs out of state 0 make the fixture non-deterministic.
#[test]
fn test_is_deterministic_false() {
let fst = build_non_deterministic_fst();
assert!(!is_deterministic(&fst));
}
// Degree is 1 for the linear fixture and 2 for the fixture with a duplicated 'a'.
#[test]
fn test_non_determinism_degree() {
let det_fst = build_deterministic_fst();
assert_eq!(non_determinism_degree(&det_fst), 1);
let nondet_fst = build_non_deterministic_fst();
assert_eq!(non_determinism_degree(&nondet_fst), 2);
}
// Determinizing the linear fixture stays deterministic and does not grow it.
#[test]
fn test_determinize_already_deterministic() {
let fst = build_deterministic_fst();
let result = determinize(&fst, DeterminizeConfig::standard())
.expect("algorithms/determinize.rs: required value was None/Err");
assert!(is_deterministic(&result));
assert!(result.num_states() <= 3);
}
// The duplicated-'a' fixture becomes deterministic after `determinize`.
#[test]
fn test_determinize_simple_non_det() {
let fst = build_non_deterministic_fst();
assert!(!is_deterministic(&fst));
let result = determinize(&fst, DeterminizeConfig::standard())
.expect("algorithms/determinize.rs: required value was None/Err");
assert!(is_deterministic(&result));
}
// The diamond fixture becomes deterministic without gaining states.
#[test]
fn test_determinize_diamond() {
let fst = build_diamond_non_det();
assert!(!is_deterministic(&fst));
let result = determinize(&fst, DeterminizeConfig::standard())
.expect("algorithms/determinize.rs: required value was None/Err");
assert!(is_deterministic(&result));
assert!(result.num_states() <= fst.num_states());
}
// A machine with no states determinizes to a machine with no states.
#[test]
fn test_determinize_empty() {
let fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
let result = determinize(&fst, DeterminizeConfig::standard())
.expect("algorithms/determinize.rs: required value was None/Err");
assert_eq!(result.num_states(), 0);
}
// Two 'a' arcs with weights 1.0 and 3.0 lead to two final states; the merged
// transition out of the start state must carry the smaller weight, 1.0, and the
// result must have exactly two states.
#[test]
fn test_determinize_weight_preservation() {
let mut fst = VectorWfst::new();
fst.add_states(3);
fst.set_start(0);
fst.add_arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0));
fst.add_arc(0, Some('a'), Some('a'), 2, TropicalWeight::new(3.0));
fst.set_final(1, TropicalWeight::one()); fst.set_final(2, TropicalWeight::one());
let result = determinize(&fst, DeterminizeConfig::standard())
.expect("algorithms/determinize.rs: required value was None/Err");
assert!(is_deterministic(&result));
assert_eq!(result.num_states(), 2);
let start = result.start();
let trans: Vec<_> = result.transitions(start).to_vec();
assert_eq!(trans.len(), 1);
assert_eq!(trans[0].weight.value(), 1.0);
}
// With `max_states` set to 1 the fixture needs more states than allowed, so
// `determinize` must report `StateLimitExceeded`.
#[test]
fn test_determinize_state_limit() {
let fst = build_non_deterministic_fst();
let config = DeterminizeConfig {
max_states: Some(1), remove_epsilon_first: false,
connect_after: false,
};
let result = determinize(&fst, config);
assert!(matches!(
result,
Err(DeterminizeError::StateLimitExceeded { .. })
));
}
}