use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, VectorWfst};
/// Integer identifier used for CTC labels on the arcs built in this file.
pub type CtcLabel = u32;
/// Label value treated as the CTC blank: arcs reading it are given an epsilon (`None`) output by the builders below.
pub const BLANK: CtcLabel = 0;
/// Size statistics recorded by a CTC topology builder at construction time.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct CtcTopologyInfo {
    /// Number of states the builder created.
    pub num_states: usize,
    /// Number of arcs the builder added, epsilon arcs included.
    pub num_arcs: usize,
    /// Number of labels the topology was built for, blank included.
    pub vocab_size: usize,
    /// Flag set to `true` by the `selfless_*` builders and by `minimal_ctc`,
    /// and to `false` by `correct_ctc` and `compact_ctc`.
    pub selfless: bool,
}
/// A CTC topology transducer paired with the statistics recorded when it was built.
#[derive(Debug, Clone)]
pub struct CtcTopology<W: Semiring> {
    /// The transducer itself, with [`CtcLabel`] symbols and weights in `W`.
    fst: VectorWfst<CtcLabel, W>,
    /// Statistics captured by the builder; mutation through `fst_mut` does not
    /// touch this field.
    info: CtcTopologyInfo,
}
impl<W: Semiring> CtcTopology<W> {
    /// Borrows the underlying transducer.
    #[inline]
    pub fn fst(&self) -> &VectorWfst<CtcLabel, W> {
        let Self { fst, .. } = self;
        fst
    }

    /// Mutably borrows the underlying transducer.
    ///
    /// Only the transducer is exposed; the stored [`CtcTopologyInfo`] is left
    /// untouched, so it can go stale if arcs or states are edited here.
    #[inline]
    pub fn fst_mut(&mut self) -> &mut VectorWfst<CtcLabel, W> {
        let Self { fst, .. } = self;
        fst
    }

    /// Consumes the topology and returns the transducer, discarding the statistics.
    #[inline]
    pub fn into_fst(self) -> VectorWfst<CtcLabel, W> {
        let Self { fst, .. } = self;
        fst
    }

    /// Returns a copy of the statistics recorded at build time.
    #[inline]
    pub fn info(&self) -> CtcTopologyInfo {
        let Self { info, .. } = self;
        *info
    }

    /// Returns the vocabulary size (blank included) the topology was built for.
    #[inline]
    pub fn vocab_size(&self) -> usize {
        self.info().vocab_size
    }
}
/// Builds the fully connected CTC topology for `vocab_size` labels.
///
/// The transducer has one state per label, with state 0 as the start state and
/// every state final with weight `W::one()`. Every state has one outgoing arc
/// per label: the arc for label `l` leads to state `l`, reads `l`, and outputs
/// `l` unless `l` is [`BLANK`], in which case it outputs epsilon (`None`). All
/// arcs carry weight `W::one()`, for `vocab_size * vocab_size` arcs in total.
/// The returned info has `selfless` set to `false`.
///
/// # Panics
///
/// Panics if `vocab_size` is 0, or if it is too large to be represented as a
/// [`CtcLabel`].
pub fn correct_ctc<W: Semiring>(vocab_size: usize) -> CtcTopology<W> {
    assert!(vocab_size >= 1, "vocab_size must be at least 1 (for blank)");
    // A plain `as` cast would silently truncate an oversized vocabulary and
    // build fewer arcs than the reported statistics claim; fail loudly instead.
    let label_count = <CtcLabel as core::convert::TryFrom<usize>>::try_from(vocab_size)
        .expect("vocab_size must fit in CtcLabel");
    // NOTE(review): assumes StateId is at least as wide as CtcLabel — confirm.
    let state_count = vocab_size as StateId;
    let num_arcs = vocab_size * vocab_size;

    let mut fst = VectorWfst::with_capacity(vocab_size);
    for _ in 0..vocab_size {
        fst.add_state();
    }
    fst.set_start(0);
    for state in 0..state_count {
        fst.set_final(state, W::one());
    }
    for state in 0..state_count {
        fst.reserve_transitions(state, vocab_size);
    }

    // Full fan-out: every state reaches every label's state.
    for from in 0..state_count {
        for label in 0..label_count {
            let output = match label {
                BLANK => None,
                other => Some(other),
            };
            fst.add_arc(from, Some(label), output, label as StateId, W::one());
        }
    }

    CtcTopology {
        fst,
        info: CtcTopologyInfo {
            num_states: vocab_size,
            num_arcs,
            vocab_size,
            selfless: false,
        },
    }
}
/// Builds the compact CTC topology for `vocab_size` labels.
///
/// The transducer has one state per label, with state 0 as the start state and
/// every state final with weight `W::one()`. State 0 has one arc per label
/// leading to that label's state; the arc outputs the label it reads, except
/// that [`BLANK`] outputs epsilon (`None`). Every other state `s` has a
/// self-loop that reads and outputs label `s`, plus an epsilon arc back to
/// state 0. All arcs carry weight `W::one()`, for `3 * vocab_size - 2` arcs in
/// total. The returned info has `selfless` set to `false`.
///
/// # Panics
///
/// Panics if `vocab_size` is 0, or if it is too large to be represented as a
/// [`CtcLabel`].
pub fn compact_ctc<W: Semiring>(vocab_size: usize) -> CtcTopology<W> {
    assert!(vocab_size >= 1, "vocab_size must be at least 1 (for blank)");
    // A plain `as` cast would silently truncate an oversized vocabulary and
    // build fewer arcs than the reported statistics claim; fail loudly instead.
    let label_count = <CtcLabel as core::convert::TryFrom<usize>>::try_from(vocab_size)
        .expect("vocab_size must fit in CtcLabel");
    // NOTE(review): assumes StateId is at least as wide as CtcLabel — confirm.
    let state_count = vocab_size as StateId;
    let num_arcs = 3 * vocab_size - 2;

    let mut fst = VectorWfst::with_capacity(vocab_size);
    for _ in 0..vocab_size {
        fst.add_state();
    }
    fst.set_start(0);
    for state in 0..state_count {
        fst.set_final(state, W::one());
    }

    // State 0 fans out to every label; the others only loop and back off.
    fst.reserve_transitions(0, vocab_size);
    for state in 1..state_count {
        fst.reserve_transitions(state, 2);
    }

    for label in 0..label_count {
        let output = match label {
            BLANK => None,
            other => Some(other),
        };
        fst.add_arc(0, Some(label), output, label as StateId, W::one());
    }
    for state in 1..state_count {
        // Lossless: `state` is below `vocab_size`, which fits in `CtcLabel`.
        let label = state as CtcLabel;
        fst.add_arc(state, Some(label), Some(label), state, W::one());
        fst.add_epsilon(state, 0, W::one());
    }

    CtcTopology {
        fst,
        info: CtcTopologyInfo {
            num_states: vocab_size,
            num_arcs,
            vocab_size,
            selfless: false,
        },
    }
}
/// Builds the single-state CTC topology for `vocab_size` labels.
///
/// The transducer has exactly one state, which is both the start state and
/// final with weight `W::one()`. It carries one self-loop per label; each loop
/// outputs the label it reads, except that [`BLANK`] outputs epsilon (`None`).
/// All arcs carry weight `W::one()`, for `vocab_size` arcs in total. The
/// returned info has `selfless` set to `true`.
///
/// # Panics
///
/// Panics if `vocab_size` is 0, or if it is too large to be represented as a
/// [`CtcLabel`].
pub fn minimal_ctc<W: Semiring>(vocab_size: usize) -> CtcTopology<W> {
    assert!(vocab_size >= 1, "vocab_size must be at least 1 (for blank)");
    // A plain `as` cast would silently truncate an oversized vocabulary and
    // build fewer arcs than the reported statistics claim; fail loudly instead.
    let label_count = <CtcLabel as core::convert::TryFrom<usize>>::try_from(vocab_size)
        .expect("vocab_size must fit in CtcLabel");

    let mut fst = VectorWfst::with_capacity(1);
    let state = fst.add_state();
    fst.set_start(state);
    fst.set_final(state, W::one());
    fst.reserve_transitions(state, vocab_size);

    for label in 0..label_count {
        let output = match label {
            BLANK => None,
            other => Some(other),
        };
        fst.add_arc(state, Some(label), output, state, W::one());
    }

    CtcTopology {
        fst,
        info: CtcTopologyInfo {
            num_states: 1,
            num_arcs: vocab_size,
            vocab_size,
            selfless: true,
        },
    }
}
/// Builds the fully connected CTC topology without self-loops on non-blank states.
///
/// The transducer has one state per label, with state 0 as the start state and
/// every state final with weight `W::one()`. Every state has an arc for each
/// label leading to that label's state, except that a non-zero state has no
/// arc back to itself; state 0 keeps its blank self-loop. Arcs output the
/// label they read, except that [`BLANK`] outputs epsilon (`None`). All arcs
/// carry weight `W::one()`, for `vocab_size * vocab_size - (vocab_size - 1)`
/// arcs in total. The returned info has `selfless` set to `true`.
///
/// # Panics
///
/// Panics if `vocab_size` is 0, or if it is too large to be represented as a
/// [`CtcLabel`].
pub fn selfless_correct_ctc<W: Semiring>(vocab_size: usize) -> CtcTopology<W> {
    assert!(vocab_size >= 1, "vocab_size must be at least 1 (for blank)");
    // A plain `as` cast would silently truncate an oversized vocabulary and
    // build fewer arcs than the reported statistics claim; fail loudly instead.
    let label_count = <CtcLabel as core::convert::TryFrom<usize>>::try_from(vocab_size)
        .expect("vocab_size must fit in CtcLabel");
    // NOTE(review): assumes StateId is at least as wide as CtcLabel — confirm.
    let state_count = vocab_size as StateId;
    let num_arcs = vocab_size * vocab_size - (vocab_size - 1);

    let mut fst = VectorWfst::with_capacity(vocab_size);
    for _ in 0..vocab_size {
        fst.add_state();
    }
    fst.set_start(0);
    for state in 0..state_count {
        fst.set_final(state, W::one());
    }
    for state in 0..state_count {
        // Non-zero states lose their own self-loop, hence one arc fewer.
        let num_trans = if state == 0 { vocab_size } else { vocab_size - 1 };
        fst.reserve_transitions(state, num_trans);
    }

    for from in 0..state_count {
        for label in 0..label_count {
            let to = label as StateId;
            if from != 0 && from == to {
                continue;
            }
            let output = match label {
                BLANK => None,
                other => Some(other),
            };
            fst.add_arc(from, Some(label), output, to, W::one());
        }
    }

    CtcTopology {
        fst,
        info: CtcTopologyInfo {
            num_states: vocab_size,
            num_arcs,
            vocab_size,
            selfless: true,
        },
    }
}
/// Builds the compact CTC topology without self-loops on non-blank states.
///
/// The transducer has one state per label, with state 0 as the start state and
/// every state final with weight `W::one()`. State 0 has one arc per label
/// leading to that label's state; the arc outputs the label it reads, except
/// that [`BLANK`] outputs epsilon (`None`). Every other state has a single
/// epsilon arc back to state 0 and nothing else. All arcs carry weight
/// `W::one()`, for `2 * vocab_size - 1` arcs in total. The returned info has
/// `selfless` set to `true`.
///
/// # Panics
///
/// Panics if `vocab_size` is 0, or if it is too large to be represented as a
/// [`CtcLabel`].
pub fn selfless_compact_ctc<W: Semiring>(vocab_size: usize) -> CtcTopology<W> {
    assert!(vocab_size >= 1, "vocab_size must be at least 1 (for blank)");
    // A plain `as` cast would silently truncate an oversized vocabulary and
    // build fewer arcs than the reported statistics claim; fail loudly instead.
    let label_count = <CtcLabel as core::convert::TryFrom<usize>>::try_from(vocab_size)
        .expect("vocab_size must fit in CtcLabel");
    // NOTE(review): assumes StateId is at least as wide as CtcLabel — confirm.
    let state_count = vocab_size as StateId;
    let num_arcs = 2 * vocab_size - 1;

    let mut fst = VectorWfst::with_capacity(vocab_size);
    for _ in 0..vocab_size {
        fst.add_state();
    }
    fst.set_start(0);
    for state in 0..state_count {
        fst.set_final(state, W::one());
    }

    // State 0 fans out to every label; the others only back off.
    fst.reserve_transitions(0, vocab_size);
    for state in 1..state_count {
        fst.reserve_transitions(state, 1);
    }

    for label in 0..label_count {
        let output = match label {
            BLANK => None,
            other => Some(other),
        };
        fst.add_arc(0, Some(label), output, label as StateId, W::one());
    }
    for state in 1..state_count {
        fst.add_epsilon(state, 0, W::one());
    }

    CtcTopology {
        fst,
        info: CtcTopologyInfo {
            num_states: vocab_size,
            num_arcs,
            vocab_size,
            selfless: true,
        },
    }
}
/// Example-based tests for the CTC topology builders.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::{LogWeight, TropicalWeight};
    use crate::wfst::Wfst;

    /// Five labels give five final states, each with a full fan-out of five arcs.
    #[test]
    fn test_correct_ctc_structure() {
        let topology = correct_ctc::<LogWeight>(5);
        let graph = topology.fst();

        assert_eq!(graph.num_states(), 5);
        assert_eq!(graph.start(), 0);
        for state in 0..5 {
            assert!(graph.is_final(state));
        }
        assert_eq!(graph.total_transitions(), 25);
        for state in 0..5 {
            let arcs = graph.transitions(state);
            assert_eq!(arcs.len(), 5);
        }
    }

    /// Every state has a blank arc that outputs epsilon and leads to state 0.
    #[test]
    fn test_correct_ctc_blank_epsilon() {
        let topology = correct_ctc::<LogWeight>(3);
        let graph = topology.fst();

        for state in 0..3 {
            let blank = graph
                .transitions(state)
                .iter()
                .find(|arc| arc.input == Some(BLANK))
                .expect("Should have blank arc");
            assert_eq!(blank.output, None);
            assert_eq!(blank.to, 0);
        }
    }

    /// State 0 fans out to all labels; the remaining states have exactly two arcs.
    #[test]
    fn test_compact_ctc_structure() {
        let topology = compact_ctc::<LogWeight>(5);
        let graph = topology.fst();

        assert_eq!(graph.num_states(), 5);
        assert_eq!(graph.total_transitions(), 13);
        assert_eq!(graph.transitions(0).len(), 5);
        for state in 1..5 {
            let arcs = graph.transitions(state);
            assert_eq!(arcs.len(), 2);
        }
    }

    /// Each non-zero state has an epsilon arc whose destination is state 0.
    #[test]
    fn test_compact_ctc_back_off() {
        let topology = compact_ctc::<LogWeight>(4);
        let graph = topology.fst();

        for state in 1..4 {
            let back_off = graph
                .transitions(state)
                .iter()
                .find(|arc| arc.is_epsilon())
                .expect("Should have epsilon arc");
            assert_eq!(back_off.to, 0);
        }
    }

    /// The minimal topology is one start/final state whose arcs all loop back to it.
    #[test]
    fn test_minimal_ctc_structure() {
        let topology = minimal_ctc::<LogWeight>(10);
        let graph = topology.fst();

        assert_eq!(graph.num_states(), 1);
        assert_eq!(graph.total_transitions(), 10);
        assert_eq!(graph.start(), 0);
        assert!(graph.is_final(0));
        for arc in graph.transitions(0) {
            assert_eq!(arc.to, 0);
        }
    }

    /// Non-zero states carry no self-loop, while state 0 keeps its blank self-loop.
    #[test]
    fn test_selfless_correct_ctc_no_self_loops() {
        let topology = selfless_correct_ctc::<LogWeight>(4);
        let graph = topology.fst();

        for state in 1..4 {
            for arc in graph.transitions(state) {
                assert!(
                    arc.to != state || arc.input == Some(BLANK),
                    "State {} should not have non-blank self-loop",
                    state
                );
            }
        }

        let has_blank_self_loop = graph
            .transitions(0)
            .iter()
            .any(|arc| arc.input == Some(BLANK) && arc.to == 0);
        assert!(has_blank_self_loop);
    }

    /// Non-zero states have a single arc: an epsilon back to state 0.
    #[test]
    fn test_selfless_compact_ctc_no_self_loops() {
        let topology = selfless_compact_ctc::<LogWeight>(4);
        let graph = topology.fst();

        for state in 1..4 {
            let arcs = graph.transitions(state);
            assert_eq!(arcs.len(), 1);
            let only = &arcs[0];
            assert!(only.is_epsilon());
            assert_eq!(only.to, 0);
        }
    }

    /// Reported arc counts match both the closed-form formulas and the built transducers.
    #[test]
    fn test_topology_arc_counts() {
        for n in [5, 10, 50, 100] {
            let correct = correct_ctc::<TropicalWeight>(n);
            let compact = compact_ctc::<TropicalWeight>(n);
            let minimal = minimal_ctc::<TropicalWeight>(n);
            let selfless_c = selfless_correct_ctc::<TropicalWeight>(n);
            let selfless_k = selfless_compact_ctc::<TropicalWeight>(n);

            // Closed-form expectations.
            assert_eq!(correct.info().num_arcs, n * n);
            assert_eq!(compact.info().num_arcs, 3 * n - 2);
            assert_eq!(minimal.info().num_arcs, n);
            assert_eq!(selfless_c.info().num_arcs, n * n - (n - 1));
            assert_eq!(selfless_k.info().num_arcs, 2 * n - 1);

            // The statistics must agree with what was actually built.
            assert_eq!(correct.fst().total_transitions(), correct.info().num_arcs);
            assert_eq!(compact.fst().total_transitions(), compact.info().num_arcs);
            assert_eq!(minimal.fst().total_transitions(), minimal.info().num_arcs);
            assert_eq!(
                selfless_c.fst().total_transitions(),
                selfless_c.info().num_arcs
            );
            assert_eq!(
                selfless_k.fst().total_transitions(),
                selfless_k.info().num_arcs
            );
        }
    }

    /// Arc counts and their ratios for a 1000-label vocabulary.
    #[test]
    fn test_large_vocabulary() {
        let correct = correct_ctc::<LogWeight>(1000);
        let compact = compact_ctc::<LogWeight>(1000);
        let minimal = minimal_ctc::<LogWeight>(1000);

        let correct_arcs = correct.info().num_arcs;
        let compact_arcs = compact.info().num_arcs;
        let minimal_arcs = minimal.info().num_arcs;

        assert_eq!(correct_arcs, 1_000_000);
        assert_eq!(compact_arcs, 2998);
        assert_eq!(minimal_arcs, 1000);
        assert!(correct_arcs / compact_arcs > 300);
        assert_eq!(correct_arcs / minimal_arcs, 1000);
    }

    /// The recorded statistics agree with the transducer, and `selfless` is set per builder.
    #[test]
    fn test_info_consistency() {
        let topology = correct_ctc::<LogWeight>(10);
        let stats = topology.info();

        assert_eq!(stats.num_states, topology.fst().num_states());
        assert_eq!(stats.num_arcs, topology.fst().total_transitions());
        assert_eq!(stats.vocab_size, 10);
        assert!(!stats.selfless);

        let selfless = selfless_correct_ctc::<LogWeight>(10);
        assert!(selfless.info().selfless);
    }

    /// An empty vocabulary is rejected with the documented message.
    #[test]
    #[should_panic(expected = "vocab_size must be at least 1")]
    fn test_empty_vocabulary_panics() {
        let _topology = correct_ctc::<LogWeight>(0);
    }
}
/// Property-based tests for the CTC topology builders, driven by `proptest`.
#[cfg(test)]
mod property_tests {
use super::*;
use crate::semiring::{LogWeight, TropicalWeight};
use crate::wfst::Wfst;
use proptest::prelude::*;
// Every property below runs with 50 generated cases.
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
// --- Size properties: reported statistics match the formulas and the built transducer. ---
// correct: n states, n * n arcs.
#[test]
fn correct_ctc_size(n in 1usize..50) {
let ctc = correct_ctc::<LogWeight>(n);
prop_assert_eq!(ctc.info().num_states, n);
prop_assert_eq!(ctc.info().num_arcs, n * n);
prop_assert_eq!(ctc.fst().num_states(), n);
prop_assert_eq!(ctc.fst().total_transitions(), n * n);
}
// compact: n states, 3n - 2 arcs.
#[test]
fn compact_ctc_size(n in 1usize..50) {
let ctc = compact_ctc::<LogWeight>(n);
prop_assert_eq!(ctc.info().num_states, n);
prop_assert_eq!(ctc.info().num_arcs, 3 * n - 2);
prop_assert_eq!(ctc.fst().num_states(), n);
prop_assert_eq!(ctc.fst().total_transitions(), 3 * n - 2);
}
// minimal: 1 state, n arcs.
#[test]
fn minimal_ctc_size(n in 1usize..100) {
let ctc = minimal_ctc::<LogWeight>(n);
prop_assert_eq!(ctc.info().num_states, 1);
prop_assert_eq!(ctc.info().num_arcs, n);
prop_assert_eq!(ctc.fst().num_states(), 1);
prop_assert_eq!(ctc.fst().total_transitions(), n);
}
// selfless correct: n states, n * n - (n - 1) arcs.
#[test]
fn selfless_correct_ctc_size(n in 1usize..50) {
let ctc = selfless_correct_ctc::<LogWeight>(n);
prop_assert_eq!(ctc.info().num_states, n);
let expected_arcs = n * n - (n - 1);
prop_assert_eq!(ctc.info().num_arcs, expected_arcs);
prop_assert_eq!(ctc.fst().total_transitions(), expected_arcs);
}
// selfless compact: n states, 2n - 1 arcs.
#[test]
fn selfless_compact_ctc_size(n in 1usize..50) {
let ctc = selfless_compact_ctc::<LogWeight>(n);
prop_assert_eq!(ctc.info().num_states, n);
prop_assert_eq!(ctc.info().num_arcs, 2 * n - 1);
prop_assert_eq!(ctc.fst().total_transitions(), 2 * n - 1);
}
// --- Start/final properties: state 0 is the start and every state is final. ---
#[test]
fn correct_ctc_start_final(n in 1usize..20) {
let ctc = correct_ctc::<LogWeight>(n);
let fst = ctc.fst();
prop_assert_eq!(fst.start(), 0);
for s in 0..n as StateId {
prop_assert!(fst.is_final(s), "State {} should be final", s);
}
}
#[test]
fn compact_ctc_start_final(n in 1usize..20) {
let ctc = compact_ctc::<LogWeight>(n);
let fst = ctc.fst();
prop_assert_eq!(fst.start(), 0);
for s in 0..n as StateId {
prop_assert!(fst.is_final(s), "State {} should be final", s);
}
}
#[test]
fn minimal_ctc_start_final(n in 1usize..100) {
let ctc = minimal_ctc::<LogWeight>(n);
let fst = ctc.fst();
prop_assert_eq!(fst.start(), 0);
prop_assert!(fst.is_final(0));
}
// --- Output-label properties: blank arcs output epsilon, other arcs output their input. ---
#[test]
fn correct_ctc_blank_epsilon(n in 2usize..20) {
let ctc = correct_ctc::<LogWeight>(n);
let fst = ctc.fst();
for s in 0..n as StateId {
for t in fst.transitions(s) {
if t.input == Some(BLANK) {
prop_assert_eq!(t.output, None, "Blank should output epsilon");
} else {
prop_assert_eq!(t.output, t.input, "Non-blank should output itself");
}
}
}
}
// Only the arcs leaving state 0 are inspected here.
#[test]
fn compact_ctc_blank_epsilon(n in 2usize..20) {
let ctc = compact_ctc::<LogWeight>(n);
let fst = ctc.fst();
for t in fst.transitions(0) {
if t.input == Some(BLANK) {
prop_assert_eq!(t.output, None, "Blank should output epsilon");
} else if t.input.is_some() {
prop_assert_eq!(t.output, t.input, "Non-blank should output itself");
}
}
}
#[test]
fn minimal_ctc_blank_epsilon(n in 1usize..50) {
let ctc = minimal_ctc::<LogWeight>(n);
let fst = ctc.fst();
for t in fst.transitions(0) {
if t.input == Some(BLANK) {
prop_assert_eq!(t.output, None, "Blank should output epsilon");
} else {
prop_assert_eq!(t.output, t.input, "Non-blank should output itself");
}
}
}
// --- Selfless properties: no self-loops on non-zero states, and the `selfless` flag. ---
#[test]
fn selfless_correct_no_self_loops(n in 2usize..20) {
let ctc = selfless_correct_ctc::<LogWeight>(n);
let fst = ctc.fst();
for s in 1..n as StateId {
for t in fst.transitions(s) {
if t.to == s {
prop_assert_eq!(t.input, Some(BLANK),
"State {} has non-blank self-loop with label {:?}", s, t.input);
}
}
}
prop_assert!(ctc.info().selfless);
}
#[test]
fn selfless_compact_no_self_loops(n in 2usize..20) {
let ctc = selfless_compact_ctc::<LogWeight>(n);
let fst = ctc.fst();
for s in 1..n as StateId {
for t in fst.transitions(s) {
prop_assert!(t.to != s || t.is_epsilon(),
"State {} has non-epsilon self-loop", s);
}
}
prop_assert!(ctc.info().selfless);
}
// The minimal builder reports `selfless: true`.
#[test]
fn minimal_ctc_selfless(n in 1usize..50) {
let ctc = minimal_ctc::<LogWeight>(n);
prop_assert!(ctc.info().selfless);
}
// The correct and compact builders report `selfless: false`.
#[test]
fn standard_topologies_not_selfless(n in 2usize..20) {
let correct = correct_ctc::<LogWeight>(n);
let compact = compact_ctc::<LogWeight>(n);
prop_assert!(!correct.info().selfless);
prop_assert!(!compact.info().selfless);
}
// Each selfless variant has exactly n - 1 fewer arcs than its standard counterpart.
#[test]
fn selfless_correct_arc_difference(n in 2usize..50) {
let correct = correct_ctc::<LogWeight>(n);
let selfless = selfless_correct_ctc::<LogWeight>(n);
let diff = correct.info().num_arcs - selfless.info().num_arcs;
prop_assert_eq!(diff, n - 1, "Should have N-1 fewer arcs");
}
#[test]
fn selfless_compact_arc_difference(n in 2usize..50) {
let compact = compact_ctc::<LogWeight>(n);
let selfless = selfless_compact_ctc::<LogWeight>(n);
let diff = compact.info().num_arcs - selfless.info().num_arcs;
prop_assert_eq!(diff, n - 1, "Should have N-1 fewer arcs");
}
// --- Back-off properties: non-zero states of the compact variants return to state 0 via epsilon. ---
#[test]
fn compact_ctc_back_off(n in 2usize..20) {
let ctc = compact_ctc::<LogWeight>(n);
let fst = ctc.fst();
for s in 1..n as StateId {
let has_eps_to_blank = fst.transitions(s)
.iter()
.any(|t| t.is_epsilon() && t.to == 0);
prop_assert!(has_eps_to_blank,
"State {} should have epsilon transition to blank", s);
}
}
#[test]
fn selfless_compact_ctc_only_back_off(n in 2usize..20) {
let ctc = selfless_compact_ctc::<LogWeight>(n);
let fst = ctc.fst();
for s in 1..n as StateId {
prop_assert_eq!(fst.transitions(s).len(), 1,
"Non-blank state {} should have exactly 1 transition", s);
let t = &fst.transitions(s)[0];
prop_assert!(t.is_epsilon(), "Should be epsilon transition");
prop_assert_eq!(t.to, 0, "Should go to blank state");
}
}
// --- Connectivity properties of the correct and minimal topologies. ---
// Every state of the correct topology reaches every state directly.
#[test]
fn correct_ctc_complete_graph(n in 2usize..15) {
let ctc = correct_ctc::<LogWeight>(n);
let fst = ctc.fst();
for from in 0..n as StateId {
let destinations: std::collections::HashSet<_> = fst.transitions(from)
.iter()
.map(|t| t.to)
.collect();
for to in 0..n as StateId {
prop_assert!(destinations.contains(&to),
"State {} should have transition to state {}", from, to);
}
}
}
#[test]
fn correct_ctc_outdegree(n in 1usize..20) {
let ctc = correct_ctc::<LogWeight>(n);
let fst = ctc.fst();
for s in 0..n as StateId {
prop_assert_eq!(fst.transitions(s).len(), n,
"State {} should have {} transitions", s, n);
}
}
#[test]
fn minimal_ctc_all_self_loops(n in 1usize..50) {
let ctc = minimal_ctc::<LogWeight>(n);
let fst = ctc.fst();
for t in fst.transitions(0) {
prop_assert_eq!(t.to, 0, "All transitions should go to state 0");
}
}
// Every label in 0..n appears as an input on some arc of the single state.
#[test]
fn minimal_ctc_all_labels(n in 1usize..50) {
let ctc = minimal_ctc::<LogWeight>(n);
let fst = ctc.fst();
let labels: std::collections::HashSet<_> = fst.transitions(0)
.iter()
.filter_map(|t| t.input)
.collect();
for label in 0..n as CtcLabel {
prop_assert!(labels.contains(&label), "Should have arc for label {}", label);
}
}
// --- Accessor properties. ---
#[test]
fn vocab_size_matches(n in 1usize..100) {
prop_assert_eq!(correct_ctc::<LogWeight>(n).vocab_size(), n);
prop_assert_eq!(compact_ctc::<LogWeight>(n).vocab_size(), n);
prop_assert_eq!(minimal_ctc::<LogWeight>(n).vocab_size(), n);
prop_assert_eq!(selfless_correct_ctc::<LogWeight>(n).vocab_size(), n);
prop_assert_eq!(selfless_compact_ctc::<LogWeight>(n).vocab_size(), n);
}
#[test]
fn info_vocab_size_consistent(n in 1usize..50) {
let ctc = correct_ctc::<LogWeight>(n);
prop_assert_eq!(ctc.info().vocab_size, ctc.vocab_size());
}
// --- Weight properties: arc and final weights of the correct topology are all `one()`. ---
#[test]
fn all_unit_weights(n in 1usize..20) {
let ctc = correct_ctc::<LogWeight>(n);
let fst = ctc.fst();
for s in 0..n as StateId {
for t in fst.transitions(s) {
prop_assert_eq!(t.weight, LogWeight::one(),
"Transition weight should be one");
}
}
}
#[test]
fn all_final_weights_one(n in 1usize..20) {
let ctc = correct_ctc::<LogWeight>(n);
let fst = ctc.fst();
for s in 0..n as StateId {
if fst.is_final(s) {
let w = fst.final_weight(s);
prop_assert_eq!(w, LogWeight::one(),
"Final weight should be one");
}
}
}
// --- Relative size properties; each lower bound on n is where the strict inequality starts to hold. ---
// n < 3n - 2 requires n >= 2.
#[test]
fn minimal_smaller_than_compact(n in 2usize..50) {
let compact = compact_ctc::<LogWeight>(n);
let minimal = minimal_ctc::<LogWeight>(n);
prop_assert!(minimal.info().num_arcs < compact.info().num_arcs,
"Minimal ({}) should be smaller than Compact ({})",
minimal.info().num_arcs, compact.info().num_arcs);
}
// 3n - 2 < n * n requires n >= 3 (the two are equal at n = 2).
#[test]
fn compact_smaller_than_correct(n in 3usize..50) {
let correct = correct_ctc::<LogWeight>(n);
let compact = compact_ctc::<LogWeight>(n);
prop_assert!(compact.info().num_arcs < correct.info().num_arcs,
"Compact ({}) should be smaller than Correct ({})",
compact.info().num_arcs, correct.info().num_arcs);
}
// Full chain: minimal < selfless compact < compact < selfless correct < correct.
#[test]
fn topology_size_ordering(n in 4usize..30) {
let correct = correct_ctc::<LogWeight>(n);
let selfless_c = selfless_correct_ctc::<LogWeight>(n);
let compact = compact_ctc::<LogWeight>(n);
let selfless_k = selfless_compact_ctc::<LogWeight>(n);
let minimal = minimal_ctc::<LogWeight>(n);
prop_assert!(minimal.info().num_arcs < selfless_k.info().num_arcs);
prop_assert!(selfless_k.info().num_arcs < compact.info().num_arcs);
prop_assert!(compact.info().num_arcs < selfless_c.info().num_arcs);
prop_assert!(selfless_c.info().num_arcs < correct.info().num_arcs);
}
// --- Miscellaneous: other semirings and the owning/mutable accessors. ---
#[test]
fn works_with_tropical(n in 1usize..20) {
let ctc = correct_ctc::<TropicalWeight>(n);
prop_assert_eq!(ctc.info().num_states, n);
prop_assert_eq!(ctc.fst().num_states(), n);
}
#[test]
fn into_fst_preserves_structure(n in 1usize..20) {
let ctc = correct_ctc::<LogWeight>(n);
let info = ctc.info();
let fst = ctc.into_fst();
prop_assert_eq!(fst.num_states(), info.num_states);
prop_assert_eq!(fst.total_transitions(), info.num_arcs);
}
// Adds one extra arc through `fst_mut` and checks the transducer's own arc count grows by one.
#[test]
fn fst_mut_allows_modification(n in 2usize..10) {
let mut ctc = correct_ctc::<LogWeight>(n);
let original_arcs = ctc.fst().total_transitions();
ctc.fst_mut().add_arc(0, Some(0), None, 0, LogWeight::new(1.0));
prop_assert_eq!(ctc.fst().total_transitions(), original_arcs + 1);
}
}
}