use std::collections::{HashMap, HashSet};
use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst, NO_STATE};
/// Numeric identifier of a [`Chain`]; [`chain_factor`] assigns these sequentially, starting at 0.
pub type ChainId = u32;
/// A linear run of transitions lifted out of a transducer.
///
/// `states` holds the visited states in order, while `input_labels` and
/// `output_labels` hold one entry per transition. A `None` label is treated
/// as epsilon by the factoring pass.
#[derive(Clone, Debug)]
pub struct Chain<L: Clone, W: Semiring> {
    /// Identifier of this chain.
    pub id: ChainId,
    /// States along the chain, from first to last.
    pub states: Vec<StateId>,
    /// Input-side label of each transition.
    pub input_labels: Vec<Option<L>>,
    /// Output-side label of each transition.
    pub output_labels: Vec<Option<L>>,
    /// Weight attached to the chain; `W::one()` for a freshly created chain.
    pub weight: W,
}

impl<L: Clone, W: Semiring + Clone> Chain<L, W> {
    /// Creates a chain with the given id, no states, no labels and weight `W::one()`.
    pub fn new(id: ChainId) -> Self {
        Chain {
            id,
            weight: W::one(),
            states: Vec::new(),
            input_labels: Vec::new(),
            output_labels: Vec::new(),
        }
    }

    /// Number of transitions, measured as the number of input-label slots.
    pub fn len(&self) -> usize {
        self.input_labels.len()
    }

    /// `true` when the chain has no transitions.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// First state of the chain, or `None` when no state has been recorded.
    pub fn start_state(&self) -> Option<StateId> {
        self.states.iter().next().copied()
    }

    /// Last state of the chain, or `None` when no state has been recorded.
    pub fn end_state(&self) -> Option<StateId> {
        self.states.iter().next_back().copied()
    }
}
/// Settings that decide which chains [`chain_factor`] replaces.
#[derive(Clone, Debug)]
pub struct ChainFactorConfig {
    /// Chains with fewer transitions than this are left untouched.
    pub min_chain_length: usize,
    /// When `false`, any chain carrying a `None` (epsilon) label is skipped.
    pub factor_epsilon_chains: bool,
    /// Upper bound on the number of chains factored; `None` means no limit.
    pub max_chains: Option<usize>,
}

impl Default for ChainFactorConfig {
    /// Minimum length 2, epsilon chains allowed, no cap on the chain count.
    fn default() -> Self {
        ChainFactorConfig {
            max_chains: None,
            factor_epsilon_chains: true,
            min_chain_length: 2,
        }
    }
}
/// Everything produced by one run of [`chain_factor`].
#[derive(Clone, Debug)]
pub struct ChainFactorResult<L: Clone, W: Semiring> {
/// The rebuilt transducer; a plain copy of the input when no chain was factored.
pub fst: VectorWfst<L, W>,
/// The chains that were factored out, keyed by their id.
pub chains: HashMap<ChainId, Chain<L, W>>,
/// Counters describing what the run found and changed.
pub stats: ChainFactorStats,
}
/// Counters filled in by [`chain_factor`]; all start at zero.
#[derive(Clone, Debug, Default)]
pub struct ChainFactorStats {
/// Number of `(entry, exit)` pairs reported by [`find_chains`].
pub chains_found: usize,
/// Number of chains that passed every filter and were replaced.
pub chains_factored: usize,
/// Sum over factored chains of their state count minus two (the interior states).
pub states_removed: usize,
/// Sum over factored chains of their transition count minus one.
pub transitions_removed: usize,
/// Sum of [`compute_chain_gain`] over the factored chains.
pub total_gain: i64,
}
/// Locates linear chains in `fst` and reports each one as an `(entry, exit)` pair.
///
/// A state counts as *interior* when it has exactly one incoming and one
/// outgoing transition and is neither the start state nor final. For every
/// non-interior state, each outgoing transition that lands on an unvisited
/// interior state is followed until the first state that is not interior
/// (or was already visited). `entry` is the non-interior state the walk
/// started from and `exit` is the state where the walk stopped.
///
/// Returns an empty vector for an FST without states.
pub fn find_chains<L, W>(fst: &VectorWfst<L, W>) -> Vec<(StateId, StateId)>
where
    L: Clone + Eq + std::hash::Hash + Send + Sync,
    W: Semiring + Clone,
{
    let num_states = fst.num_states();
    if num_states == 0 {
        return Vec::new();
    }

    // Degree tables indexed by state id.
    let mut in_degree = vec![0usize; num_states];
    let mut out_degree = vec![0usize; num_states];
    for state in 0..num_states as StateId {
        let arcs = fst.transitions(state);
        out_degree[state as usize] = arcs.len();
        for arc in arcs {
            // Targets outside the state table are ignored instead of indexed.
            if (arc.to as usize) < num_states {
                in_degree[arc.to as usize] += 1;
            }
        }
    }

    let chain_candidates: HashSet<StateId> = (0..num_states as StateId)
        .filter(|&s| {
            let is_start = fst.start() == s;
            let is_final = fst.is_final(s);
            in_degree[s as usize] == 1 && out_degree[s as usize] == 1 && !is_start && !is_final
        })
        .collect();

    let mut chains = Vec::new();
    let mut visited = HashSet::new();
    for start in 0..num_states as StateId {
        // Chains are only entered from states that are not interior themselves.
        if chain_candidates.contains(&start) {
            continue;
        }
        for arc in fst.transitions(start) {
            let mut current = arc.to;
            if !chain_candidates.contains(&current) || visited.contains(&current) {
                continue;
            }
            let chain_start = current;
            // `insert` returns false for an already visited state, which ends the walk.
            while chain_candidates.contains(&current) && visited.insert(current) {
                let arcs = fst.transitions(current);
                if arcs.len() == 1 {
                    current = arcs[0].to;
                } else {
                    break;
                }
            }
            let chain_end = current;
            // A walk that comes straight back to its first interior state is dropped.
            if chain_start != chain_end {
                chains.push((start, chain_end));
            }
        }
    }
    chains
}
/// Scores a chain as `(#input labels) - (#output labels) - 1`.
///
/// Only `Some` labels are counted on either side; `None` entries contribute
/// nothing. An entirely empty chain therefore scores `-1`.
pub fn compute_chain_gain<L, W>(chain: &Chain<L, W>) -> i64
where
    L: Clone,
    W: Semiring,
{
    /// Counts the labels that are actually present in `side`.
    fn present<T>(side: &[Option<T>]) -> i64 {
        side.iter().flatten().count() as i64
    }

    present(&chain.input_labels) - present(&chain.output_labels) - 1
}
pub fn chain_factor<L, W>(
fst: &VectorWfst<L, W>,
config: &ChainFactorConfig,
) -> ChainFactorResult<L, W>
where
L: Clone + Eq + std::hash::Hash + Default + Send + Sync,
W: Semiring + Clone,
{
let mut stats = ChainFactorStats::default();
let mut chains: HashMap<ChainId, Chain<L, W>> = HashMap::new();
let mut next_chain_id: ChainId = 0;
let chain_endpoints = find_chains(fst);
stats.chains_found = chain_endpoints.len();
if chain_endpoints.is_empty() || fst.num_states() == 0 {
return ChainFactorResult {
fst: clone_fst(fst),
chains,
stats,
};
}
let mut chain_states_to_remove: HashSet<StateId> = HashSet::new();
let mut chain_replacements: Vec<(StateId, StateId, Chain<L, W>)> = Vec::new();
for (chain_entry, chain_exit) in &chain_endpoints {
if let Some(chain) = extract_chain(fst, *chain_entry, *chain_exit, next_chain_id) {
if chain.len() < config.min_chain_length {
continue;
}
if !config.factor_epsilon_chains {
let has_epsilon = chain.input_labels.iter().any(|l| l.is_none())
|| chain.output_labels.iter().any(|l| l.is_none());
if has_epsilon {
continue;
}
}
let gain = compute_chain_gain(&chain);
if gain <= 0 {
continue;
}
if let Some(max) = config.max_chains {
if chains.len() >= max {
break;
}
}
for &state in chain
.states
.iter()
.skip(1)
.take(chain.states.len().saturating_sub(2))
{
chain_states_to_remove.insert(state);
}
stats.chains_factored += 1;
stats.total_gain += gain;
stats.states_removed += chain.states.len().saturating_sub(2);
stats.transitions_removed += chain.len().saturating_sub(1);
chain_replacements.push((*chain_entry, *chain_exit, chain.clone()));
chains.insert(next_chain_id, chain);
next_chain_id += 1;
}
}
let result_fst = build_factored_fst(fst, &chain_states_to_remove, &chain_replacements);
ChainFactorResult {
fst: result_fst,
chains,
stats,
}
}
/// Reconstructs the chain that runs from `entry` to `exit`.
///
/// `entry` may have several outgoing transitions, so each of them (self-loops
/// excluded) is tried in order. A candidate path is accepted only if it is
/// strictly linear: every state between `entry` and `exit` must be in range,
/// non-final, not the start state, and have exactly one outgoing transition.
/// The first arc whose path reaches `exit` wins; its states, labels and the
/// product of its transition weights are returned as a [`Chain`].
///
/// Returns `None` when no outgoing transition of `entry` leads linearly to
/// `exit`, or when a path would visit more states than the FST contains.
/// When `entry == exit` the result is a chain holding only `entry`, with no
/// transitions and weight `W::one()`.
///
/// NOTE(review): the in-degree of interior states is not re-checked here, and
/// two distinct chains sharing the same `entry` and `exit` both resolve to the
/// first one. Confirm whether callers should pass the first interior state.
fn extract_chain<L, W>(
    fst: &VectorWfst<L, W>,
    entry: StateId,
    exit: StateId,
    chain_id: ChainId,
) -> Option<Chain<L, W>>
where
    L: Clone + Send + Sync,
    W: Semiring + Clone,
{
    if entry == exit {
        let mut chain: Chain<L, W> = Chain::new(chain_id);
        chain.states.push(entry);
        chain.weight = W::one();
        return Some(chain);
    }

    let num_states = fst.num_states();

    for first in fst.transitions(entry).iter() {
        // A self-loop on the entry state can never start a chain.
        if first.to == entry {
            continue;
        }

        let mut chain: Chain<L, W> = Chain::new(chain_id);
        chain.states.push(entry);
        chain.states.push(first.to);
        chain.input_labels.push(first.input.clone());
        chain.output_labels.push(first.output.clone());
        let mut weight = W::one().times(&first.weight);
        let mut current = first.to;

        // The length bound stops a walk that has run into a cycle.
        while current != exit && chain.states.len() <= num_states {
            let in_range = (current as usize) < num_states;
            if !in_range || fst.is_final(current) || fst.start() == current {
                break;
            }
            let arcs = fst.transitions(current);
            if arcs.len() != 1 {
                // Branching or dead-end state: this is not a linear run.
                break;
            }
            let arc = &arcs[0];
            chain.states.push(arc.to);
            chain.input_labels.push(arc.input.clone());
            chain.output_labels.push(arc.output.clone());
            weight = weight.times(&arc.weight);
            current = arc.to;
        }

        if current == exit && chain.states.len() <= num_states {
            chain.weight = weight;
            return Some(chain);
        }
    }

    None
}
/// Rebuilds `fst` without `states_to_remove`, swapping each chain for one arc.
///
/// Surviving states are renumbered densely in their original order. A
/// transition is copied only when both endpoints survive and its
/// `(source, target)` pair is not a step of one of the replaced chains.
/// Final weights and the start state carry over to the renumbered states.
/// Finally, each chain whose entry and exit both survive contributes a single
/// arc from entry to exit that carries the chain's first input and output
/// label (or no label when the chain has none) and the chain's weight.
///
/// When there is nothing to remove and nothing to replace the result is a
/// plain copy of `fst`.
fn build_factored_fst<L, W>(
    fst: &VectorWfst<L, W>,
    states_to_remove: &HashSet<StateId>,
    chain_replacements: &[(StateId, StateId, Chain<L, W>)],
) -> VectorWfst<L, W>
where
    L: Clone + Default + Send + Sync,
    W: Semiring + Clone,
{
    if states_to_remove.is_empty() && chain_replacements.is_empty() {
        return clone_fst(fst);
    }

    let mut rebuilt: VectorWfst<L, W> = VectorWfst::new();

    // Old id -> new id, `None` for states that are dropped.
    let remap: Vec<Option<StateId>> = (0..fst.num_states() as StateId)
        .map(|old| {
            if states_to_remove.contains(&old) {
                None
            } else {
                Some(rebuilt.add_state())
            }
        })
        .collect();
    let renumber = |old: StateId| -> Option<StateId> { remap.get(old as usize).copied().flatten() };

    let old_start = fst.start();
    if old_start != NO_STATE {
        if let Some(new_start) = renumber(old_start) {
            rebuilt.set_start(new_start);
        }
    }

    // Every consecutive state pair of a replaced chain; arcs between such pairs are dropped.
    let mut replaced_steps: HashSet<(StateId, StateId)> = HashSet::new();
    for (_, _, chain) in chain_replacements {
        replaced_steps.extend(chain.states.windows(2).map(|step| (step[0], step[1])));
    }

    for (index, slot) in remap.iter().enumerate() {
        let new_source = match *slot {
            Some(id) => id,
            None => continue,
        };
        let old_source = index as StateId;

        for arc in fst.transitions(old_source) {
            if replaced_steps.contains(&(old_source, arc.to)) {
                continue;
            }
            // Arcs into removed (or unknown) states have no new target and are skipped.
            if let Some(new_target) = renumber(arc.to) {
                rebuilt.add_arc(
                    new_source,
                    arc.input.clone(),
                    arc.output.clone(),
                    new_target,
                    arc.weight.clone(),
                );
            }
        }

        if fst.is_final(old_source) {
            let final_weight = fst.final_weight(old_source);
            rebuilt.set_final(new_source, final_weight.clone());
        }
    }

    for (entry, exit, chain) in chain_replacements {
        let (new_entry, new_exit) = match (renumber(*entry), renumber(*exit)) {
            (Some(from), Some(to)) => (from, to),
            _ => continue,
        };
        let input = chain.input_labels.first().and_then(|label| label.clone());
        let output = chain.output_labels.first().and_then(|label| label.clone());
        rebuilt.add_arc(new_entry, input, output, new_exit, chain.weight.clone());
    }

    rebuilt
}
/// Produces a state-for-state copy of `fst`.
///
/// The copy gets the same number of states, the same start state (when one
/// is set), every transition with cloned labels and weight, and the final
/// weight of every final state.
fn clone_fst<L, W>(fst: &VectorWfst<L, W>) -> VectorWfst<L, W>
where
    L: Clone + Send + Sync,
    W: Semiring + Clone,
{
    let state_count = fst.num_states();
    let mut copy: VectorWfst<L, W> = VectorWfst::new();
    (0..state_count).for_each(|_| {
        copy.add_state();
    });

    let origin = fst.start();
    if origin != NO_STATE {
        copy.set_start(origin);
    }

    for sid in (0..state_count).map(|index| index as StateId) {
        for arc in fst.transitions(sid) {
            let (input, output) = (arc.input.clone(), arc.output.clone());
            copy.add_arc(sid, input, output, arc.to, arc.weight.clone());
        }
        if fst.is_final(sid) {
            let final_weight = fst.final_weight(sid);
            copy.set_final(sid, final_weight.clone());
        }
    }

    copy
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::LogWeight;

    /// The default configuration: minimum length 2, epsilon chains allowed, no cap.
    #[test]
    fn test_chain_config_default() {
        let ChainFactorConfig {
            min_chain_length,
            factor_epsilon_chains,
            max_chains,
        } = ChainFactorConfig::default();
        assert_eq!(min_chain_length, 2);
        assert!(factor_epsilon_chains);
        assert!(max_chains.is_none());
    }

    /// A freshly created chain has no transitions and no endpoints.
    #[test]
    fn test_empty_chain() {
        let fresh = Chain::<u32, LogWeight>::new(0);
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
        assert!(fresh.start_state().is_none());
        assert!(fresh.end_state().is_none());
    }

    /// Length and endpoints follow the populated vectors.
    #[test]
    fn test_chain_with_states() {
        let mut populated = Chain::<u32, LogWeight>::new(1);
        populated.states = vec![0, 1, 2];
        populated.input_labels = vec![Some(10), Some(11)];
        populated.output_labels = vec![Some(20), Some(21)];
        assert!(!populated.is_empty());
        assert_eq!(populated.len(), 2);
        assert_eq!(populated.start_state(), Some(0));
        assert_eq!(populated.end_state(), Some(2));
    }

    /// Three inputs against one output give a gain of 3 - 1 - 1.
    #[test]
    fn test_compute_chain_gain() {
        let mut subject = Chain::<u32, LogWeight>::new(0);
        subject.input_labels = (1..=3).map(Some).collect();
        subject.output_labels = vec![Some(10)];
        assert_eq!(compute_chain_gain(&subject), 1);
    }

    /// One input against three outputs gives a gain of 1 - 3 - 1.
    #[test]
    fn test_compute_chain_gain_negative() {
        let mut subject = Chain::<u32, LogWeight>::new(0);
        subject.input_labels = vec![Some(1)];
        subject.output_labels = vec![Some(10), Some(20), Some(30)];
        assert_eq!(compute_chain_gain(&subject), -3);
    }

    /// An FST without states contains no chains.
    #[test]
    fn test_find_chains_empty_fst() {
        let empty = VectorWfst::<u32, LogWeight>::new();
        assert!(find_chains(&empty).is_empty());
    }

    /// Factoring an FST without states finds and records nothing.
    #[test]
    fn test_chain_factor_empty_fst() {
        let empty = VectorWfst::<u32, LogWeight>::new();
        let outcome = chain_factor(&empty, &ChainFactorConfig::default());
        assert_eq!(outcome.stats.chains_found, 0);
        assert!(outcome.chains.is_empty());
    }

    /// A four-state line (labels 1, 2, 3) must be detected as a chain.
    #[test]
    fn test_chain_factor_simple_fst() {
        let mut line: VectorWfst<u32, LogWeight> = VectorWfst::new();
        let ids: Vec<StateId> = (0..4).map(|_| line.add_state()).collect();
        line.set_start(ids[0]);
        line.set_final(ids[3], LogWeight::one());
        for (label, step) in (1u32..).zip(ids.windows(2)) {
            line.add_arc(step[0], Some(label), Some(label), step[1], LogWeight::one());
        }
        let outcome = chain_factor(&line, &ChainFactorConfig::default());
        assert!(outcome.stats.chains_found > 0, "expected at least one chain");
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use crate::semiring::LogWeight;
    use crate::wfst::Wfst;
    use proptest::prelude::*;

    /// Concrete chain type exercised by every property below.
    type TestChain = Chain<u32, LogWeight>;
    /// Concrete transducer type exercised by every property below.
    type TestFst = VectorWfst<u32, LogWeight>;

    /// Adds `count` fresh states to `fst` and returns their ids in creation order.
    fn grow(fst: &mut TestFst, count: usize) -> Vec<StateId> {
        (0..count).map(|_| fst.add_state()).collect()
    }

    /// Connects consecutive states with unit-weight arcs labelled by their position.
    fn link_in_order(fst: &mut TestFst, states: &[StateId]) {
        for (position, step) in states.windows(2).enumerate() {
            let label = Some(position as u32);
            fst.add_arc(step[0], label, label, step[1], LogWeight::one());
        }
    }

    /// Total number of transitions stored in `fst`.
    fn arc_total(fst: &TestFst) -> usize {
        (0..fst.num_states() as StateId)
            .map(|state| fst.transitions(state).len())
            .sum()
    }

    // Defaults of `ChainFactorConfig`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        #[test]
        fn default_config_min_length(_seed in any::<u64>()) {
            prop_assert_eq!(ChainFactorConfig::default().min_chain_length, 2);
        }

        #[test]
        fn default_config_epsilon(_seed in any::<u64>()) {
            prop_assert!(ChainFactorConfig::default().factor_epsilon_chains);
        }

        #[test]
        fn default_config_no_max(_seed in any::<u64>()) {
            prop_assert!(ChainFactorConfig::default().max_chains.is_none());
        }
    }

    // Construction and accessors of `Chain`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn new_chain_empty(id in 0u32..1000) {
            let fresh = TestChain::new(id);
            prop_assert!(fresh.is_empty());
            prop_assert_eq!(fresh.len(), 0);
            prop_assert_eq!(fresh.id, id);
        }

        #[test]
        fn new_chain_weight(id in 0u32..1000) {
            let fresh = TestChain::new(id);
            prop_assert_eq!(fresh.weight, LogWeight::one());
        }

        #[test]
        fn chain_length_is_input_labels(
            id in 0u32..100,
            num_labels in 0usize..10
        ) {
            let mut subject = TestChain::new(id);
            subject.input_labels = (0..num_labels as u32).map(Some).collect();
            prop_assert_eq!(subject.len(), num_labels);
        }

        #[test]
        fn chain_is_empty_no_labels(id in 0u32..100) {
            prop_assert!(TestChain::new(id).is_empty());
        }

        #[test]
        fn chain_not_empty_with_labels(id in 0u32..100, num_labels in 1usize..10) {
            let mut subject = TestChain::new(id);
            subject.input_labels = (0..num_labels as u32).map(Some).collect();
            prop_assert!(!subject.is_empty());
        }

        #[test]
        fn chain_start_state(states in prop::collection::vec(0u32..100, 1..5)) {
            let expected_first = states[0];
            let mut subject = TestChain::new(0);
            subject.states = states;
            prop_assert_eq!(subject.start_state(), Some(expected_first));
        }

        #[test]
        fn chain_end_state(states in prop::collection::vec(0u32..100, 1..5)) {
            let expected_last = *states.last().expect("asr/factoring.rs: required value was None/Err");
            let mut subject = TestChain::new(0);
            subject.states = states;
            prop_assert_eq!(subject.end_state(), Some(expected_last));
        }

        #[test]
        fn chain_empty_states_none(id in 0u32..100) {
            let fresh = TestChain::new(id);
            prop_assert!(fresh.start_state().is_none());
            prop_assert!(fresh.end_state().is_none());
        }
    }

    // Defaults of `ChainFactorStats`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        #[test]
        fn default_stats_zeros(_seed in any::<u64>()) {
            let zeroed = ChainFactorStats::default();
            prop_assert_eq!(zeroed.chains_found, 0);
            prop_assert_eq!(zeroed.chains_factored, 0);
            prop_assert_eq!(zeroed.states_removed, 0);
            prop_assert_eq!(zeroed.transitions_removed, 0);
            prop_assert_eq!(zeroed.total_gain, 0);
        }
    }

    // `compute_chain_gain`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn gain_formula(
            num_inputs in 0usize..10,
            num_outputs in 0usize..10
        ) {
            let mut subject = TestChain::new(0);
            subject.input_labels = (0..num_inputs as u32).map(Some).collect();
            subject.output_labels = (0..num_outputs as u32).map(Some).collect();
            let expected = num_inputs as i64 - num_outputs as i64 - 1;
            prop_assert_eq!(compute_chain_gain(&subject), expected);
        }

        #[test]
        fn empty_chain_gain(_seed in any::<u64>()) {
            prop_assert_eq!(compute_chain_gain(&TestChain::new(0)), -1);
        }

        #[test]
        fn positive_gain_condition(extra in 2usize..10) {
            let mut subject = TestChain::new(0);
            subject.input_labels = (0..extra as u32).map(Some).collect();
            subject.output_labels = Vec::new();
            prop_assert!(compute_chain_gain(&subject) > 0);
        }

        #[test]
        fn negative_gain_condition(extra in 2usize..10) {
            let mut subject = TestChain::new(0);
            subject.input_labels = Vec::new();
            subject.output_labels = (0..extra as u32).map(Some).collect();
            prop_assert!(compute_chain_gain(&subject) < 0);
        }

        #[test]
        fn none_labels_not_counted(num_some in 0usize..5, num_none in 0usize..5) {
            let mut subject = TestChain::new(0);
            subject.input_labels = (0..num_some as u32)
                .map(Some)
                .chain(std::iter::repeat(None).take(num_none))
                .collect();
            let expected = num_some as i64 - 1;
            prop_assert_eq!(compute_chain_gain(&subject), expected);
        }
    }

    // `find_chains`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(30))]

        #[test]
        fn empty_fst_no_chains(_seed in any::<u64>()) {
            prop_assert!(find_chains(&TestFst::new()).is_empty());
        }

        #[test]
        fn single_state_no_chains(_seed in any::<u64>()) {
            let mut fst = TestFst::new();
            let only = fst.add_state();
            fst.set_start(only);
            fst.set_final(only, LogWeight::one());
            prop_assert!(find_chains(&fst).is_empty());
        }

        #[test]
        fn all_final_limited_chains(num_states in 2usize..5) {
            let mut fst = TestFst::new();
            let states = grow(&mut fst, num_states);
            for &state in &states {
                fst.set_final(state, LogWeight::one());
            }
            fst.set_start(states[0]);
            link_in_order(&mut fst, &states);
            prop_assert!(find_chains(&fst).len() <= num_states);
        }
    }

    // `chain_factor`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]

        #[test]
        fn factor_empty_fst(_seed in any::<u64>()) {
            let outcome = chain_factor(&TestFst::new(), &ChainFactorConfig::default());
            prop_assert_eq!(outcome.stats.chains_found, 0);
            prop_assert!(outcome.chains.is_empty());
        }

        #[test]
        fn factor_preserves_structure(num_states in 1usize..5) {
            let mut fst = TestFst::new();
            let states = grow(&mut fst, num_states);
            if !states.is_empty() {
                let last = *states.last().expect("asr/factoring.rs: required value was None/Err");
                fst.set_start(states[0]);
                fst.set_final(last, LogWeight::one());
                link_in_order(&mut fst, &states);
            }
            let outcome = chain_factor(&fst, &ChainFactorConfig::default());
            prop_assert_eq!(outcome.fst.num_states(), fst.num_states());
        }

        #[test]
        fn factor_preserves_start(_seed in any::<u64>()) {
            let mut fst = TestFst::new();
            let only = fst.add_state();
            fst.set_start(only);
            fst.set_final(only, LogWeight::one());
            let outcome = chain_factor(&fst, &ChainFactorConfig::default());
            prop_assert!(outcome.fst.start() != NO_STATE);
        }

        #[test]
        fn factor_stats_bounded(num_states in 0usize..5) {
            let mut fst = TestFst::new();
            grow(&mut fst, num_states);
            if num_states > 0 {
                fst.set_start(0);
            }
            let outcome = chain_factor(&fst, &ChainFactorConfig::default());
            prop_assert!(outcome.stats.chains_found <= num_states);
        }
    }

    // `clone_fst`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]

        #[test]
        fn clone_preserves_states(num_states in 0usize..10) {
            let mut fst = TestFst::new();
            grow(&mut fst, num_states);
            prop_assert_eq!(clone_fst(&fst).num_states(), num_states);
        }

        #[test]
        fn clone_preserves_start(num_states in 1usize..10, start_idx in 0usize..10) {
            let mut fst = TestFst::new();
            grow(&mut fst, num_states);
            let start = (start_idx % num_states) as StateId;
            fst.set_start(start);
            prop_assert_eq!(clone_fst(&fst).start(), start);
        }

        #[test]
        fn clone_preserves_finals(num_states in 1usize..5) {
            let mut fst = TestFst::new();
            for index in 0..num_states {
                let state = fst.add_state();
                if index % 2 == 0 {
                    fst.set_final(state, LogWeight::new(index as f64));
                }
            }
            let copy = clone_fst(&fst);
            for state in 0..num_states as StateId {
                prop_assert_eq!(copy.is_final(state), fst.is_final(state));
            }
        }

        #[test]
        fn clone_preserves_arcs(num_states in 2usize..5) {
            let mut fst = TestFst::new();
            let states = grow(&mut fst, num_states);
            fst.set_start(states[0]);
            link_in_order(&mut fst, &states);
            let copy = clone_fst(&fst);
            prop_assert_eq!(arc_total(&copy), arc_total(&fst));
        }
    }
}