use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, VectorWfst, WeightedTransition, Wfst, NO_STATE};
#[derive(Debug, Clone)]
pub struct SubsequentialTransducer<L, W>
where
L: Clone + Eq + Hash,
W: Semiring,
{
wfst: VectorWfst<L, W>,
final_outputs: HashMap<StateId, Vec<L>>,
piece_id: usize,
}
impl<L, W> SubsequentialTransducer<L, W>
where
L: Clone + Eq + Hash + Send + Sync + 'static,
W: Semiring,
{
pub fn from_wfst(wfst: VectorWfst<L, W>) -> Option<Self> {
if !Self::is_subsequential(&wfst) {
return None;
}
let final_outputs = HashMap::new();
Some(Self {
wfst,
final_outputs,
piece_id: 0,
})
}
fn is_subsequential(wfst: &VectorWfst<L, W>) -> bool {
for state_id in 0..wfst.num_states() as StateId {
let mut seen_inputs: HashMap<Option<L>, bool> = HashMap::new();
for trans in wfst.transitions(state_id) {
if seen_inputs.contains_key(&trans.input) {
return false; }
seen_inputs.insert(trans.input.clone(), true);
}
}
true
}
pub fn apply(&self, input: &[L]) -> Option<(Vec<L>, W)> {
let start = self.wfst.start();
if start == NO_STATE {
return None;
}
let mut state = start;
let mut output = Vec::new();
let mut weight = W::one();
for symbol in input {
let mut found = false;
for trans in self.wfst.transitions(state) {
if trans.input.as_ref() == Some(symbol) {
if let Some(out) = &trans.output {
output.push(out.clone());
}
weight = weight.times(&trans.weight);
state = trans.to;
found = true;
break;
}
}
if !found {
return None; }
}
if self.wfst.is_final(state) {
let final_weight = self.wfst.final_weight(state);
weight = weight.times(&final_weight);
if let Some(final_out) = self.final_outputs.get(&state) {
output.extend(final_out.iter().cloned());
}
Some((output, weight))
} else {
None
}
}
pub fn wfst(&self) -> &VectorWfst<L, W> {
&self.wfst
}
pub fn piece_id(&self) -> usize {
self.piece_id
}
pub fn set_final_output(&mut self, state: StateId, output: Vec<L>) {
self.final_outputs.insert(state, output);
}
}
#[derive(Debug, Clone)]
pub struct PiecewiseSubsequential<L, W>
where
L: Clone + Eq + Hash,
W: Semiring,
{
pieces: Vec<SubsequentialTransducer<L, W>>,
stats: DecompositionStats,
}
#[derive(Debug, Clone, Default)]
pub struct DecompositionStats {
pub num_pieces: usize,
pub total_states: usize,
pub total_transitions: usize,
pub is_exact: bool,
}
impl<L, W> PiecewiseSubsequential<L, W>
where
L: Clone + Eq + Hash + Send + Sync + 'static,
W: Semiring,
{
pub fn new(pieces: Vec<SubsequentialTransducer<L, W>>) -> Self {
let stats = DecompositionStats {
num_pieces: pieces.len(),
total_states: pieces.iter().map(|p| p.wfst.num_states()).sum(),
total_transitions: pieces
.iter()
.map(|p| {
(0..p.wfst.num_states() as StateId)
.map(|s| p.wfst.transitions(s).len())
.sum::<usize>()
})
.sum(),
is_exact: true, };
Self { pieces, stats }
}
pub fn decompose(wfst: &VectorWfst<L, W>) -> Self
where
W: Clone,
{
if SubsequentialTransducer::is_subsequential(wfst) {
let piece = SubsequentialTransducer {
wfst: wfst.clone(),
final_outputs: HashMap::new(),
piece_id: 0,
};
return Self::new(vec![piece]);
}
let pieces = Self::build_pieces(wfst);
let stats = DecompositionStats {
num_pieces: pieces.len(),
total_states: pieces.iter().map(|p| p.wfst.num_states()).sum(),
total_transitions: pieces
.iter()
.map(|p| {
(0..p.wfst.num_states() as StateId)
.map(|s| p.wfst.transitions(s).len())
.sum::<usize>()
})
.sum(),
is_exact: true,
};
Self { pieces, stats }
}
fn build_pieces(wfst: &VectorWfst<L, W>) -> Vec<SubsequentialTransducer<L, W>>
where
W: Clone,
{
let start = wfst.start();
if start == NO_STATE {
return vec![];
}
let ambiguity_points = Self::find_ambiguity_points(wfst);
if ambiguity_points.is_empty() {
let piece = SubsequentialTransducer {
wfst: wfst.clone(),
final_outputs: HashMap::new(),
piece_id: 0,
};
return vec![piece];
}
let mut pieces = Vec::new();
let mut visited_paths: HashMap<Vec<(StateId, Option<L>)>, bool> = HashMap::new();
let mut queue = VecDeque::new();
queue.push_back((start, Vec::new(), 0usize));
while let Some((state, path, piece_idx)) = queue.pop_front() {
if visited_paths.contains_key(&path) {
continue;
}
visited_paths.insert(path.clone(), true);
let transitions: Vec<_> = wfst.transitions(state).iter().collect();
let mut by_input: HashMap<Option<&L>, Vec<&WeightedTransition<L, W>>> = HashMap::new();
for trans in &transitions {
by_input
.entry(trans.input.as_ref())
.or_default()
.push(trans);
}
for (input, trans_list) in by_input {
if trans_list.len() > 1 {
for (i, trans) in trans_list.iter().enumerate() {
let new_piece_idx = if i == 0 { piece_idx } else { pieces.len() + i };
let mut new_path = path.clone();
new_path.push((state, input.cloned()));
queue.push_back((trans.to, new_path, new_piece_idx));
}
} else if let Some(trans) = trans_list.first() {
let mut new_path = path.clone();
new_path.push((state, input.cloned()));
queue.push_back((trans.to, new_path, piece_idx));
}
}
}
if pieces.is_empty() {
let piece = SubsequentialTransducer {
wfst: wfst.clone(),
final_outputs: HashMap::new(),
piece_id: 0,
};
pieces.push(piece);
}
let max_ambiguity = ambiguity_points
.iter()
.map(|(_, count)| *count)
.max()
.unwrap_or(1);
let mut result = Vec::new();
for i in 0..max_ambiguity {
let mut piece_wfst = VectorWfst::new();
let piece_start = piece_wfst.add_state();
piece_wfst.set_start(piece_start);
Self::copy_with_disambiguation(wfst, &mut piece_wfst, &ambiguity_points, i);
result.push(SubsequentialTransducer {
wfst: piece_wfst,
final_outputs: HashMap::new(),
piece_id: i,
});
}
if result.is_empty() {
let piece = SubsequentialTransducer {
wfst: wfst.clone(),
final_outputs: HashMap::new(),
piece_id: 0,
};
result.push(piece);
}
result
}
fn find_ambiguity_points(wfst: &VectorWfst<L, W>) -> Vec<(StateId, usize)> {
let mut ambiguous = Vec::new();
for state_id in 0..wfst.num_states() as StateId {
let mut input_counts: HashMap<Option<&L>, usize> = HashMap::new();
for trans in wfst.transitions(state_id) {
*input_counts.entry(trans.input.as_ref()).or_insert(0) += 1;
}
let max_count = input_counts.values().max().copied().unwrap_or(1);
if max_count > 1 {
ambiguous.push((state_id, max_count));
}
}
ambiguous
}
fn copy_with_disambiguation(
source: &VectorWfst<L, W>,
target: &mut VectorWfst<L, W>,
ambiguity_points: &[(StateId, usize)],
alternative_idx: usize,
) where
W: Clone,
{
let start = source.start();
if start == NO_STATE {
return;
}
let mut state_map: HashMap<StateId, StateId> = HashMap::new();
let target_start = if target.start() == NO_STATE {
let s = target.add_state();
target.set_start(s);
s
} else {
target.start()
};
state_map.insert(start, target_start);
let mut queue = VecDeque::new();
queue.push_back(start);
let ambiguous_states: HashMap<StateId, usize> = ambiguity_points.iter().copied().collect();
while let Some(src_state) = queue.pop_front() {
let tgt_state = *state_map.get(&src_state).expect("State should be mapped");
if source.is_final(src_state) {
target.set_final(tgt_state, source.final_weight(src_state));
}
let mut by_input: HashMap<Option<&L>, Vec<&WeightedTransition<L, W>>> = HashMap::new();
for trans in source.transitions(src_state) {
by_input
.entry(trans.input.as_ref())
.or_default()
.push(trans);
}
for (_, trans_list) in by_input {
let trans = if trans_list.len() > 1 && ambiguous_states.contains_key(&src_state) {
trans_list.get(alternative_idx % trans_list.len())
} else {
trans_list.first()
};
if let Some(trans) = trans {
let to_state = *state_map.entry(trans.to).or_insert_with(|| {
let new_state = target.add_state();
queue.push_back(trans.to);
new_state
});
target.add_transition(WeightedTransition::new(
tgt_state,
trans.input.clone(),
trans.output.clone(),
to_state,
trans.weight.clone(),
));
}
}
}
}
pub fn apply(&self, input: &[L]) -> Vec<(Vec<L>, W)> {
let mut results = Vec::new();
for piece in &self.pieces {
if let Some(result) = piece.apply(input) {
results.push(result);
}
}
results
}
pub fn apply_unique(&self, input: &[L]) -> Vec<(Vec<L>, W)>
where
L: Ord,
{
let mut results = self.apply(input);
results.sort_by(|a, b| a.0.cmp(&b.0));
results.dedup_by(|a, b| a.0 == b.0);
results
}
pub fn num_pieces(&self) -> usize {
self.pieces.len()
}
pub fn stats(&self) -> &DecompositionStats {
&self.stats
}
pub fn pieces(&self) -> &[SubsequentialTransducer<L, W>] {
&self.pieces
}
pub fn pieces_mut(&mut self) -> &mut Vec<SubsequentialTransducer<L, W>> {
&mut self.pieces
}
pub fn is_trivial(&self) -> bool {
self.pieces.len() == 1
}
pub fn degree(&self) -> usize {
self.pieces.len()
}
}
#[derive(Debug, Clone)]
pub struct PiecewiseBuilder<L, W>
where
L: Clone + Eq + Hash,
W: Semiring,
{
pieces: Vec<SubsequentialTransducer<L, W>>,
}
impl<L, W> PiecewiseBuilder<L, W>
where
L: Clone + Eq + Hash + Send + Sync + 'static,
W: Semiring,
{
pub fn new() -> Self {
Self { pieces: Vec::new() }
}
pub fn add_piece(mut self, piece: SubsequentialTransducer<L, W>) -> Self {
self.pieces.push(piece);
self
}
pub fn add_wfst(mut self, wfst: VectorWfst<L, W>) -> Option<Self> {
let piece = SubsequentialTransducer::from_wfst(wfst)?;
self.pieces.push(piece);
Some(self)
}
pub fn build(self) -> PiecewiseSubsequential<L, W> {
PiecewiseSubsequential::new(self.pieces)
}
}
impl<L, W> Default for PiecewiseBuilder<L, W>
where
L: Clone + Eq + Hash + Send + Sync + 'static,
W: Semiring,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
fn make_simple_fst() -> VectorWfst<char, TropicalWeight> {
let mut fst = VectorWfst::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
let s2 = fst.add_state();
fst.set_start(s0);
fst.set_final(s2, TropicalWeight::one());
fst.add_transition(WeightedTransition::new(
s0,
Some('a'),
Some('A'),
s1,
TropicalWeight::one(),
));
fst.add_transition(WeightedTransition::new(
s1,
Some('b'),
Some('B'),
s2,
TropicalWeight::one(),
));
fst
}
fn make_ambiguous_fst() -> VectorWfst<char, TropicalWeight> {
let mut fst = VectorWfst::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
let s2 = fst.add_state();
let s3 = fst.add_state();
fst.set_start(s0);
fst.set_final(s2, TropicalWeight::one());
fst.set_final(s3, TropicalWeight::one());
fst.add_transition(WeightedTransition::new(
s0,
Some('a'),
Some('X'),
s1,
TropicalWeight::new(1.0),
));
fst.add_transition(WeightedTransition::new(
s0,
Some('a'),
Some('Y'),
s1,
TropicalWeight::new(2.0),
));
fst.add_transition(WeightedTransition::new(
s1,
Some('b'),
Some('B'),
s2,
TropicalWeight::one(),
));
fst.add_transition(WeightedTransition::new(
s1,
Some('c'),
Some('C'),
s3,
TropicalWeight::one(),
));
fst
}
#[test]
fn test_subsequential_check() {
let simple = make_simple_fst();
assert!(SubsequentialTransducer::<char, TropicalWeight>::is_subsequential(&simple));
let ambiguous = make_ambiguous_fst();
assert!(!SubsequentialTransducer::<char, TropicalWeight>::is_subsequential(&ambiguous));
}
#[test]
fn test_subsequential_from_wfst() {
let simple = make_simple_fst();
let subseq = SubsequentialTransducer::from_wfst(simple);
assert!(subseq.is_some());
let ambiguous = make_ambiguous_fst();
let subseq = SubsequentialTransducer::from_wfst(ambiguous);
assert!(subseq.is_none());
}
#[test]
fn test_subsequential_apply() {
let fst = make_simple_fst();
let subseq = SubsequentialTransducer::from_wfst(fst).expect("Should be subsequential");
let result = subseq.apply(&['a', 'b']);
assert!(result.is_some());
let (output, _weight) = result.expect("subsequential/mod.rs: required value was None/Err");
assert_eq!(output, vec!['A', 'B']);
}
#[test]
fn test_subsequential_apply_not_accepted() {
let fst = make_simple_fst();
let subseq = SubsequentialTransducer::from_wfst(fst).expect("Should be subsequential");
let result = subseq.apply(&['a']);
assert!(result.is_none());
let result = subseq.apply(&['x']);
assert!(result.is_none());
}
#[test]
fn test_decompose_subsequential() {
let fst = make_simple_fst();
let piecewise = PiecewiseSubsequential::decompose(&fst);
assert!(piecewise.is_trivial());
assert_eq!(piecewise.num_pieces(), 1);
}
#[test]
fn test_decompose_ambiguous() {
let fst = make_ambiguous_fst();
let piecewise = PiecewiseSubsequential::decompose(&fst);
assert!(piecewise.num_pieces() >= 1);
}
#[test]
fn test_piecewise_apply() {
let fst = make_simple_fst();
let piecewise = PiecewiseSubsequential::decompose(&fst);
let results = piecewise.apply(&['a', 'b']);
assert!(!results.is_empty());
assert_eq!(results[0].0, vec!['A', 'B']);
}
#[test]
fn test_piecewise_stats() {
let fst = make_simple_fst();
let piecewise = PiecewiseSubsequential::decompose(&fst);
let stats = piecewise.stats();
assert_eq!(stats.num_pieces, 1);
assert!(stats.total_states > 0);
}
#[test]
fn test_builder() {
let fst = make_simple_fst();
let piece = SubsequentialTransducer::from_wfst(fst).expect("Should be subsequential");
let piecewise = PiecewiseBuilder::new().add_piece(piece).build();
assert_eq!(piecewise.num_pieces(), 1);
}
#[test]
fn test_builder_add_wfst() {
let fst = make_simple_fst();
let builder = PiecewiseBuilder::<char, TropicalWeight>::new().add_wfst(fst);
assert!(builder.is_some());
let piecewise = builder
.expect("subsequential/mod.rs: required value was None/Err")
.build();
assert_eq!(piecewise.num_pieces(), 1);
}
#[test]
fn test_builder_reject_ambiguous() {
let fst = make_ambiguous_fst();
let builder = PiecewiseBuilder::<char, TropicalWeight>::new().add_wfst(fst);
assert!(builder.is_none()); }
#[test]
fn test_degree() {
let fst = make_simple_fst();
let piecewise = PiecewiseSubsequential::decompose(&fst);
assert_eq!(piecewise.degree(), 1);
}
#[test]
fn test_piece_id() {
let fst = make_simple_fst();
let piece = SubsequentialTransducer::from_wfst(fst).expect("Should be subsequential");
assert_eq!(piece.piece_id(), 0);
}
#[test]
fn test_empty_fst() {
let fst: VectorWfst<char, TropicalWeight> = VectorWfst::new();
let piecewise = PiecewiseSubsequential::decompose(&fst);
let results = piecewise.apply(&['a']);
assert!(results.is_empty());
}
#[test]
fn test_set_final_output() {
let fst = make_simple_fst();
let mut subseq = SubsequentialTransducer::from_wfst(fst).expect("Should be subsequential");
subseq.set_final_output(2, vec!['!']);
let result = subseq.apply(&['a', 'b']);
assert!(result.is_some());
let (output, _) = result.expect("subsequential/mod.rs: required value was None/Err");
assert_eq!(output, vec!['A', 'B', '!']);
}
#[test]
fn test_apply_unique() {
let fst = make_simple_fst();
let piecewise = PiecewiseSubsequential::decompose(&fst);
let results = piecewise.apply_unique(&['a', 'b']);
assert!(!results.is_empty());
let mut seen = std::collections::HashSet::new();
for (output, _) in &results {
assert!(seen.insert(output.clone()), "Duplicate output found");
}
}
#[test]
fn test_find_ambiguity_points() {
let fst = make_ambiguous_fst();
let ambiguous = PiecewiseSubsequential::<char, TropicalWeight>::find_ambiguity_points(&fst);
assert!(!ambiguous.is_empty());
assert!(ambiguous
.iter()
.any(|(state, count)| *state == 0 && *count == 2));
}
#[test]
fn test_decomposition_stats() {
let fst = make_ambiguous_fst();
let piecewise = PiecewiseSubsequential::decompose(&fst);
let stats = piecewise.stats();
assert!(stats.num_pieces >= 1);
assert!(stats.total_states > 0);
assert!(stats.total_transitions > 0);
}
}