use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;
use crate::algorithms::{determinize, minimize, DeterminizeConfig, MinimizeConfig};
use crate::composition::{compose, materialize};
use crate::semiring::{
DivisibleSemiring, NumericalWeight, QuantizableSemiring, Semiring, TotallyOrderedSemiring,
};
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst, NO_STATE};
use super::context::PhoneId;
use super::ngram::{NgramTransducer, WordId};
/// Non-phone symbols that can be attached to a [`LexiconEntry`].
///
/// NOTE(review): nothing in this file reads `LexiconEntry::auxiliaries`; the
/// lexicon transducer built by `CascadeBuilder` emits none of these symbols.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AuxiliarySymbol {
    /// Word-boundary marker (meaning is defined by consumers; unused here).
    WordBoundary,
    /// Indexed disambiguation symbol — presumably used to separate
    /// homophones / prefix pronunciations; confirm with consumers.
    Disambiguation(u32),
    /// Epsilon (empty) symbol.
    Epsilon,
}
/// One pronunciation in the lexicon: a word, its phone sequence, a weight,
/// and optional auxiliary symbols.
#[derive(Clone, Debug)]
pub struct LexiconEntry<W: Semiring> {
    /// Word emitted when this pronunciation is consumed.
    pub word: WordId,
    /// Phone sequence of the pronunciation (entries with no phones are
    /// skipped when the lexicon transducer is built).
    pub phones: Vec<PhoneId>,
    /// Weight of this pronunciation.
    pub weight: W,
    /// Auxiliary symbols; empty unless set via `with_auxiliaries`.
    pub auxiliaries: Vec<AuxiliarySymbol>,
}
impl<W: Semiring> LexiconEntry<W> {
pub fn new(word: WordId, phones: Vec<PhoneId>, weight: W) -> Self {
Self {
word,
phones,
weight,
auxiliaries: Vec::new(),
}
}
pub fn with_auxiliaries(mut self, aux: Vec<AuxiliarySymbol>) -> Self {
self.auxiliaries = aux;
self
}
}
/// Options controlling how `CascadeBuilder` assembles the cascade.
///
/// NOTE(review): in the code visible here only `incremental_det` and
/// `minimize` are read (by `build_optimized`); `lazy`, `max_homophony` and
/// `word_boundaries` are stored but never consulted.
#[derive(Clone, Debug)]
pub struct CascadeConfig {
    /// Determinize LG and then CLG in `build_optimized`.
    pub incremental_det: bool,
    /// Minimize the final transducer in `build_optimized`.
    pub minimize: bool,
    /// Not read in this file.
    pub lazy: bool,
    /// Not read in this file.
    pub max_homophony: u32,
    /// Not read in this file.
    pub word_boundaries: bool,
}

impl Default for CascadeConfig {
    /// Determinization and minimization enabled, `lazy` disabled,
    /// `max_homophony` = 10, `word_boundaries` enabled.
    fn default() -> Self {
        let optimize = true;
        CascadeConfig {
            word_boundaries: true,
            max_homophony: 10,
            lazy: false,
            minimize: optimize,
            incremental_det: optimize,
        }
    }
}
/// Output of `CascadeBuilder::build` / `build_optimized`: the compiled
/// transducer together with the configuration and build statistics.
pub struct AsrCascade<W: Semiring> {
    /// The composed (and, in `build_optimized`, optionally determinized and
    /// minimized) transducer, with `PhoneId` labels.
    pub fst: VectorWfst<PhoneId, W>,
    /// The configuration the builder held when it was built.
    pub config: CascadeConfig,
    /// State/arc counts recorded during the build.
    pub stats: CascadeStats,
}
/// State/arc counts recorded at each stage of a cascade build.
/// All fields are zero in the `Default` value.
#[derive(Clone, Debug, Default)]
pub struct CascadeStats {
    /// States in the grammar G (1 when no grammar is set: the builder stands
    /// in a one-state trivial acceptor).
    pub g_states: usize,
    /// States in L∘G, or in L alone when no grammar composition took place.
    pub lg_states: usize,
    /// States of LG after the determinization step of `build_optimized`.
    pub det_lg_states: usize,
    /// States after composing the context transducer C (unchanged from the
    /// previous stage when no context transducer is set).
    pub clg_states: usize,
    /// States of CLG after the determinization step of `build_optimized`.
    pub det_clg_states: usize,
    /// States in the final transducer stored in `AsrCascade::fst`.
    pub final_states: usize,
    /// Total number of arcs in the final transducer.
    pub final_arcs: usize,
}
/// Builder that composes an H∘C∘L∘G cascade from an optional grammar (G), a
/// lexicon (L) assembled from [`LexiconEntry`] values, an optional
/// context-dependency transducer (C) and an optional HMM transducer (H).
pub struct CascadeBuilder<W: Semiring> {
    /// Build options.
    config: CascadeConfig,
    /// Grammar over word labels; `None` means "no grammar restriction".
    grammar: Option<VectorWfst<WordId, W>>,
    /// Pronunciations used to build L.
    lexicon: Vec<LexiconEntry<W>>,
    /// Optional context-dependency transducer composed on the left of LG.
    context: Option<VectorWfst<PhoneId, W>>,
    /// Optional HMM transducer composed on the left of CLG.
    hmm: Option<VectorWfst<PhoneId, W>>,
    // NOTE(review): the fields above already mention `W`, so this marker
    // looks redundant; kept to avoid changing the type's layout/API.
    _weight: PhantomData<W>,
}
impl<W> CascadeBuilder<W>
where
    W: Semiring + Clone,
{
    /// Creates a builder with the default [`CascadeConfig`], an empty lexicon,
    /// and no grammar, context-dependency or HMM transducer.
    pub fn new() -> Self {
        Self {
            config: CascadeConfig::default(),
            grammar: None,
            lexicon: Vec::new(),
            context: None,
            hmm: None,
            _weight: PhantomData,
        }
    }

    /// Replaces the build configuration.
    pub fn config(mut self, config: CascadeConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the grammar transducer G (word labels).
    pub fn grammar(mut self, grammar: VectorWfst<WordId, W>) -> Self {
        self.grammar = Some(grammar);
        self
    }

    /// Uses the transducer of an n-gram model as the grammar G.
    pub fn grammar_from_ngram(self, ngram: NgramTransducer<W>) -> Self {
        self.grammar(ngram.fst)
    }

    /// Appends a pronunciation to the lexicon L.
    ///
    /// Unlike the other setters this takes `&mut self` rather than consuming
    /// and returning the builder.
    pub fn add_lexicon_entry(&mut self, entry: LexiconEntry<W>) {
        self.lexicon.push(entry);
    }

    /// Sets the context-dependency transducer C (phone labels).
    pub fn context_dependency(mut self, context: VectorWfst<PhoneId, W>) -> Self {
        self.context = Some(context);
        self
    }

    /// Sets the HMM transducer H (phone labels).
    pub fn hmm(mut self, hmm: VectorWfst<PhoneId, W>) -> Self {
        self.hmm = Some(hmm);
        self
    }

    /// Builds the lexicon transducer L as a loop through a single start state.
    ///
    /// The start state is also final with weight one. Each pronunciation
    /// becomes a chain of phone-input arcs leaving the start state. The first
    /// arc outputs the word and carries the entry weight, the remaining arcs
    /// output ε with weight one, and the last arc returns to the start state,
    /// so words can be concatenated. Entries with no phones are skipped.
    /// Auxiliary symbols are not emitted.
    fn build_lexicon_for_composition(&self) -> VectorWfst<u32, W> {
        let mut fst: VectorWfst<u32, W> = VectorWfst::new();
        let start = fst.add_state();
        fst.set_start(start);
        fst.set_final(start, W::one());
        for entry in &self.lexicon {
            // Nothing to consume: an empty pronunciation cannot be a path.
            if entry.phones.is_empty() {
                continue;
            }
            let last = entry.phones.len() - 1;
            let mut current = start;
            for (i, &phone) in entry.phones.iter().enumerate() {
                // Word label and pronunciation weight ride on the first arc.
                let (output, weight) = if i == 0 {
                    (Some(entry.word as u32), entry.weight.clone())
                } else {
                    (None, W::one())
                };
                // The final phone closes the loop directly; this also avoids
                // an extra state plus ε:ε arc for single-phone words.
                let next = if i == last { start } else { fst.add_state() };
                fst.add_arc(current, Some(phone as u32), output, next, weight);
                current = next;
            }
        }
        fst
    }

    /// Composes the cascade H∘C∘L∘G without any determinization or
    /// minimization.
    ///
    /// A supplied grammar is always composed with L, whatever its size: a
    /// single-state grammar (e.g. a weighted unigram word loop) still
    /// restricts and weights word sequences. L is used unrestricted only
    /// when no grammar was set. C and H are composed only when set.
    pub fn build(self) -> AsrCascade<W> {
        let mut stats = CascadeStats::default();
        let l = self.build_lexicon_for_composition();

        let lg: VectorWfst<u32, W> = match self.grammar {
            Some(grammar) => {
                stats.g_states = grammar.num_states();
                materialize(compose(l, grammar))
            }
            None => {
                // Report the one-state trivial acceptor that "no grammar"
                // stands for, as earlier versions did.
                stats.g_states = 1;
                l
            }
        };
        stats.lg_states = lg.num_states();
        // This path never determinizes, so det(LG) is LG itself.
        stats.det_lg_states = stats.lg_states;

        let clg: VectorWfst<u32, W> = match self.context {
            Some(context) => {
                let context_u32: VectorWfst<u32, W> = convert_to_u32(&context);
                materialize(compose(context_u32, lg))
            }
            None => lg,
        };
        stats.clg_states = clg.num_states();
        // Likewise det(CLG) is CLG itself here.
        stats.det_clg_states = stats.clg_states;

        let hclg: VectorWfst<u32, W> = match self.hmm {
            Some(hmm) => {
                let hmm_u32: VectorWfst<u32, W> = convert_to_u32(&hmm);
                materialize(compose(hmm_u32, clg))
            }
            None => clg,
        };
        stats.final_states = hclg.num_states();
        stats.final_arcs = (0..hclg.num_states() as StateId)
            .map(|s| hclg.transitions(s).len())
            .sum::<usize>();

        let result_fst: VectorWfst<PhoneId, W> = convert_label_type(&hclg);
        AsrCascade {
            fst: result_fst,
            config: self.config,
            stats,
        }
    }
}
/// Copies `fst` into a fresh `VectorWfst` with `u32` labels.
///
/// States are recreated one-for-one before anything else, so source state
/// ids are reused as-is (this assumes `add_state` hands out sequential ids
/// starting at 0). The start state is copied when set. Final weights and
/// arcs, including their weights, are copied unchanged.
fn convert_to_u32<W: Semiring + Clone>(fst: &VectorWfst<PhoneId, W>) -> VectorWfst<u32, W> {
    let n = fst.num_states();
    let mut out: VectorWfst<u32, W> = VectorWfst::new();
    (0..n).for_each(|_| {
        out.add_state();
    });

    let initial = fst.start();
    if initial != NO_STATE {
        out.set_start(initial);
    }

    let ids = 0..n as StateId;
    for q in ids.clone().filter(|&q| fst.is_final(q)) {
        out.set_final(q, fst.final_weight(q));
    }
    for q in ids {
        for arc in fst.transitions(q) {
            out.add_arc(arc.from, arc.input, arc.output, arc.to, arc.weight.clone());
        }
    }
    out
}
/// Copies a `u32`-labelled `fst` into a fresh `VectorWfst` with `PhoneId`
/// labels. This is the inverse direction of `convert_to_u32`.
///
/// States are recreated one-for-one first, so ids carry over (this assumes
/// sequential ids from `add_state`). Then the start state (if set), the
/// final weights and every arc are copied verbatim.
fn convert_label_type<W: Semiring + Clone>(fst: &VectorWfst<u32, W>) -> VectorWfst<PhoneId, W> {
    let mut converted: VectorWfst<PhoneId, W> = VectorWfst::new();
    let count = fst.num_states();
    let mut created = 0;
    while created < count {
        converted.add_state();
        created += 1;
    }

    let first = fst.start();
    if first != NO_STATE {
        converted.set_start(first);
    }

    let end = count as StateId;
    for state in (0..end).filter(|&s| fst.is_final(s)) {
        let w = fst.final_weight(state);
        converted.set_final(state, w);
    }
    for state in 0..end {
        for t in fst.transitions(state) {
            let w = t.weight.clone();
            converted.add_arc(t.from, t.input, t.output, t.to, w);
        }
    }
    converted
}
impl<W> CascadeBuilder<W>
where
    W: DivisibleSemiring
        + NumericalWeight
        + TotallyOrderedSemiring
        + QuantizableSemiring
        + PartialOrd
        + Clone
        + Debug
        + Hash
        + Eq,
{
    /// Composes H∘C∘L∘G with optional intermediate optimization.
    ///
    /// - LG and then CLG are determinized when `config.incremental_det` is set.
    /// - The final transducer is minimized when `config.minimize` is set.
    /// - A supplied grammar is always composed with L, whatever its size, so
    ///   single-state grammars (e.g. weighted unigram word loops) are not
    ///   dropped. L is used unrestricted only when no grammar was set.
    ///
    /// Optimization is best-effort: if `determinize` or `minimize` returns
    /// an error, the un-optimized machine is kept and the error is discarded.
    ///
    /// NOTE(review): the lexicon transducer carries no disambiguation
    /// symbols, so LG with homophones may not be determinizable; what happens
    /// then depends on `determinize` — confirm.
    pub fn build_optimized(self) -> AsrCascade<W> {
        let mut stats = CascadeStats::default();
        let l = self.build_lexicon_for_composition();

        let lg: VectorWfst<u32, W> = match self.grammar {
            Some(grammar) => {
                stats.g_states = grammar.num_states();
                materialize(compose(l, grammar))
            }
            None => {
                // Report the one-state trivial acceptor that "no grammar"
                // stands for, as earlier versions did.
                stats.g_states = 1;
                l
            }
        };
        stats.lg_states = lg.num_states();

        let det_lg: VectorWfst<u32, W> = if self.config.incremental_det {
            determinize(&lg, DeterminizeConfig::standard()).unwrap_or(lg)
        } else {
            lg
        };
        stats.det_lg_states = det_lg.num_states();

        let clg: VectorWfst<u32, W> = match self.context {
            Some(context) => {
                let context_u32: VectorWfst<u32, W> = convert_to_u32(&context);
                materialize(compose(context_u32, det_lg))
            }
            None => det_lg,
        };
        stats.clg_states = clg.num_states();

        let det_clg: VectorWfst<u32, W> = if self.config.incremental_det {
            determinize(&clg, DeterminizeConfig::standard()).unwrap_or(clg)
        } else {
            clg
        };
        stats.det_clg_states = det_clg.num_states();

        let hclg: VectorWfst<u32, W> = match self.hmm {
            Some(hmm) => {
                let hmm_u32: VectorWfst<u32, W> = convert_to_u32(&hmm);
                materialize(compose(hmm_u32, det_clg))
            }
            None => det_clg,
        };

        let minimized: VectorWfst<u32, W> = if self.config.minimize {
            minimize(&hclg, MinimizeConfig::default()).unwrap_or(hclg)
        } else {
            hclg
        };
        stats.final_states = minimized.num_states();
        stats.final_arcs = (0..minimized.num_states() as StateId)
            .map(|s| minimized.transitions(s).len())
            .sum::<usize>();

        let result_fst: VectorWfst<PhoneId, W> = convert_label_type(&minimized);
        AsrCascade {
            fst: result_fst,
            config: self.config,
            stats,
        }
    }
}
impl<W: Semiring + Clone> Default for CascadeBuilder<W> {
fn default() -> Self {
Self::new()
}
}
impl<W: Semiring> AsrCascade<W> {
pub fn as_fst(&self) -> &VectorWfst<PhoneId, W> {
&self.fst
}
pub fn statistics(&self) -> &CascadeStats {
&self.stats
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::LogWeight;
    use crate::wfst::NO_STATE;

    /// Shorthand for a zero-weight lexicon entry, the common case below.
    fn zero_weight_entry(word: WordId, phones: Vec<PhoneId>) -> LexiconEntry<LogWeight> {
        LexiconEntry::new(word, phones, LogWeight::new(0.0))
    }

    #[test]
    fn test_lexicon_entry() {
        let lex = LexiconEntry::new(1, vec![10, 11, 12], LogWeight::new(0.5));
        assert_eq!(lex.word, 1);
        assert_eq!(lex.phones.len(), 3);
    }

    #[test]
    fn test_cascade_config_default() {
        let CascadeConfig {
            incremental_det,
            minimize,
            lazy,
            ..
        } = CascadeConfig::default();
        assert!(incremental_det);
        assert!(minimize);
        assert!(!lazy);
    }

    #[test]
    fn test_cascade_builder_minimal() {
        let cascade = CascadeBuilder::<LogWeight>::new().build();
        assert!(cascade.fst.num_states() > 0);
        assert_ne!(cascade.fst.start(), NO_STATE);
    }

    #[test]
    fn test_cascade_builder_with_lexicon() {
        let mut builder = CascadeBuilder::<LogWeight>::new();
        let words = vec![(1, vec![10, 11, 12]), (2, vec![20, 21, 22, 23])];
        for (word, phones) in words {
            builder.add_lexicon_entry(zero_weight_entry(word, phones));
        }
        let cascade = builder.build();
        assert!(cascade.fst.num_states() > 1);
    }

    #[test]
    fn test_cascade_stats() {
        let mut builder = CascadeBuilder::<LogWeight>::new();
        builder.add_lexicon_entry(zero_weight_entry(1, vec![10, 11]));
        let stats = builder.build().stats;
        assert!(stats.final_states > 0);
    }

    #[test]
    fn test_auxiliary_symbols() {
        let aux = vec![
            AuxiliarySymbol::WordBoundary,
            AuxiliarySymbol::Disambiguation(0),
        ];
        let entry = zero_weight_entry(1, vec![10]).with_auxiliaries(aux);
        assert_eq!(entry.auxiliaries.len(), 2);
    }
}
// Property-based tests (proptest) for the cascade types and builder.
// Several properties take an unused `_seed` input and simply re-check a
// fixed fact on every generated case.
#[cfg(test)]
mod property_tests {
    use super::*;
    use crate::semiring::LogWeight;
    use crate::wfst::{Wfst, NO_STATE};
    use proptest::prelude::*;
    // Equality behaviour of the derived `PartialEq` on `AuxiliarySymbol`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        #[test]
        fn aux_symbol_reflexive(idx in 0u32..100) {
            let sym = AuxiliarySymbol::Disambiguation(idx);
            prop_assert_eq!(sym, sym);
        }
        // `a` and `b` come from disjoint ranges, so the indices always differ.
        #[test]
        fn different_disambiguation(a in 0u32..100, b in 100u32..200) {
            prop_assert_ne!(
                AuxiliarySymbol::Disambiguation(a),
                AuxiliarySymbol::Disambiguation(b)
            );
        }
        #[test]
        fn word_boundary_ne_epsilon(_seed in any::<u64>()) {
            prop_assert_ne!(AuxiliarySymbol::WordBoundary, AuxiliarySymbol::Epsilon);
        }
        #[test]
        fn word_boundary_ne_disambiguation(idx in 0u32..100) {
            prop_assert_ne!(
                AuxiliarySymbol::WordBoundary,
                AuxiliarySymbol::Disambiguation(idx)
            );
        }
    }
    // `LexiconEntry` constructors keep their inputs unchanged.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]
        #[test]
        fn lexicon_preserves_word(word_id in 0u32..1000) {
            let entry = LexiconEntry::new(
                word_id,
                vec![1, 2, 3],
                LogWeight::new(1.0),
            );
            prop_assert_eq!(entry.word, word_id);
        }
        #[test]
        fn lexicon_preserves_phones(phones in prop::collection::vec(0u32..100, 1..10)) {
            let entry = LexiconEntry::new(
                1,
                phones.clone(),
                LogWeight::new(1.0),
            );
            prop_assert_eq!(entry.phones, phones);
        }
        // `new` starts with no auxiliary symbols.
        #[test]
        fn lexicon_empty_auxiliaries(word_id in 0u32..100) {
            let entry = LexiconEntry::new(
                word_id,
                vec![1],
                LogWeight::new(1.0),
            );
            prop_assert!(entry.auxiliaries.is_empty());
        }
        #[test]
        fn with_auxiliaries_sets(num_aux in 0usize..5) {
            let aux: Vec<_> = (0..num_aux)
                .map(|i| AuxiliarySymbol::Disambiguation(i as u32))
                .collect();
            let entry = LexiconEntry::new(1, vec![1], LogWeight::new(1.0))
                .with_auxiliaries(aux.clone());
            prop_assert_eq!(entry.auxiliaries.len(), num_aux);
        }
    }
    // Pins each `CascadeConfig::default()` field value.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]
        #[test]
        fn default_config_det(_seed in any::<u64>()) {
            let config = CascadeConfig::default();
            prop_assert!(config.incremental_det);
        }
        #[test]
        fn default_config_min(_seed in any::<u64>()) {
            let config = CascadeConfig::default();
            prop_assert!(config.minimize);
        }
        #[test]
        fn default_config_not_lazy(_seed in any::<u64>()) {
            let config = CascadeConfig::default();
            prop_assert!(!config.lazy);
        }
        #[test]
        fn default_config_word_boundaries(_seed in any::<u64>()) {
            let config = CascadeConfig::default();
            prop_assert!(config.word_boundaries);
        }
        #[test]
        fn default_config_homophony(_seed in any::<u64>()) {
            let config = CascadeConfig::default();
            prop_assert_eq!(config.max_homophony, 10);
        }
    }
    // The derived `Default` for `CascadeStats` zeroes every counter.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]
        #[test]
        fn default_cascade_stats(_seed in any::<u64>()) {
            let stats = CascadeStats::default();
            prop_assert_eq!(stats.g_states, 0);
            prop_assert_eq!(stats.lg_states, 0);
            prop_assert_eq!(stats.det_lg_states, 0);
            prop_assert_eq!(stats.clg_states, 0);
            prop_assert_eq!(stats.det_clg_states, 0);
            prop_assert_eq!(stats.final_states, 0);
            prop_assert_eq!(stats.final_arcs, 0);
        }
    }
    // `CascadeBuilder::build` yields a non-empty machine for various inputs
    // and carries the configuration through unchanged.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(25))]
        #[test]
        fn empty_builder_valid(_seed in any::<u64>()) {
            let builder = CascadeBuilder::<LogWeight>::new();
            let cascade = builder.build();
            prop_assert!(cascade.fst.num_states() > 0);
            prop_assert!(cascade.fst.start() != NO_STATE);
        }
        #[test]
        fn builder_with_entries(num_entries in 1usize..5) {
            let mut builder = CascadeBuilder::<LogWeight>::new();
            for i in 0..num_entries {
                builder.add_lexicon_entry(LexiconEntry::new(
                    i as u32,
                    vec![(i * 10) as u32, (i * 10 + 1) as u32],
                    LogWeight::new(0.0),
                ));
            }
            let cascade = builder.build();
            prop_assert!(cascade.fst.num_states() >= 1);
        }
        #[test]
        fn builder_config_sets(lazy in any::<bool>(), minimize in any::<bool>()) {
            let config = CascadeConfig {
                lazy,
                minimize,
                ..Default::default()
            };
            let builder = CascadeBuilder::<LogWeight>::new().config(config);
            let cascade = builder.build();
            prop_assert_eq!(cascade.config.lazy, lazy);
            prop_assert_eq!(cascade.config.minimize, minimize);
        }
        // Single-state, arc-less grammar that is final with weight one.
        #[test]
        fn builder_grammar(_seed in any::<u64>()) {
            let mut g = VectorWfst::<WordId, LogWeight>::new();
            let s = g.add_state();
            g.set_start(s);
            g.set_final(s, LogWeight::one());
            let builder = CascadeBuilder::<LogWeight>::new().grammar(g);
            let cascade = builder.build();
            prop_assert!(cascade.fst.num_states() >= 1);
        }
        #[test]
        fn builder_context(_seed in any::<u64>()) {
            let mut c = VectorWfst::<PhoneId, LogWeight>::new();
            let s = c.add_state();
            c.set_start(s);
            c.set_final(s, LogWeight::one());
            let builder = CascadeBuilder::<LogWeight>::new().context_dependency(c);
            let cascade = builder.build();
            prop_assert!(cascade.fst.num_states() >= 1);
        }
        #[test]
        fn builder_hmm(_seed in any::<u64>()) {
            let mut h = VectorWfst::<PhoneId, LogWeight>::new();
            let s = h.add_state();
            h.set_start(s);
            h.set_final(s, LogWeight::one());
            let builder = CascadeBuilder::<LogWeight>::new().hmm(h);
            let cascade = builder.build();
            prop_assert!(cascade.fst.num_states() >= 1);
        }
    }
    // Accessors and recorded statistics agree with the stored transducer.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn as_fst_returns_fst(_seed in any::<u64>()) {
            let builder = CascadeBuilder::<LogWeight>::new();
            let cascade = builder.build();
            let fst = cascade.as_fst();
            prop_assert_eq!(fst.num_states(), cascade.fst.num_states());
        }
        #[test]
        fn statistics_returns_stats(_seed in any::<u64>()) {
            let builder = CascadeBuilder::<LogWeight>::new();
            let cascade = builder.build();
            let stats = cascade.statistics();
            prop_assert_eq!(stats.final_states, cascade.stats.final_states);
        }
        #[test]
        fn stats_match_fst(_seed in any::<u64>()) {
            let builder = CascadeBuilder::<LogWeight>::new();
            let cascade = builder.build();
            prop_assert_eq!(cascade.stats.final_states, cascade.fst.num_states());
        }
        // Recount arcs on the returned FST and compare with `final_arcs`.
        #[test]
        fn stats_arcs_match(_seed in any::<u64>()) {
            let mut builder = CascadeBuilder::<LogWeight>::new();
            builder.add_lexicon_entry(LexiconEntry::new(
                1,
                vec![10, 11, 12],
                LogWeight::new(0.0),
            ));
            let cascade = builder.build();
            let actual_arcs: usize = (0..cascade.fst.num_states() as StateId)
                .map(|s| cascade.fst.transitions(s).len())
                .sum();
            prop_assert_eq!(cascade.stats.final_arcs, actual_arcs);
        }
    }
    // Lexicon edge cases: empty lexicon, single word, two homophones.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn empty_lexicon_minimal(_seed in any::<u64>()) {
            let builder = CascadeBuilder::<LogWeight>::new();
            let cascade = builder.build();
            prop_assert!(cascade.fst.num_states() >= 1);
        }
        #[test]
        fn single_word_lexicon(
            word_id in 0u32..100,
            phones in prop::collection::vec(0u32..50, 1..5)
        ) {
            let mut builder = CascadeBuilder::<LogWeight>::new();
            builder.add_lexicon_entry(LexiconEntry::new(
                word_id,
                phones.clone(),
                LogWeight::new(0.0),
            ));
            let cascade = builder.build();
            prop_assert!(cascade.fst.num_states() >= 1);
        }
        // Two distinct words (disjoint id ranges) sharing one pronunciation.
        #[test]
        fn homophone_lexicon(
            word1 in 0u32..50,
            word2 in 50u32..100
        ) {
            let phones = vec![10, 11, 12];
            let mut builder = CascadeBuilder::<LogWeight>::new();
            builder.add_lexicon_entry(LexiconEntry::new(
                word1,
                phones.clone(),
                LogWeight::new(0.0),
            ));
            builder.add_lexicon_entry(LexiconEntry::new(
                word2,
                phones.clone(),
                LogWeight::new(0.0),
            ));
            let cascade = builder.build();
            prop_assert!(cascade.fst.num_states() >= 1);
        }
    }
}