use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, VectorWfst, WeightedTransition, Wfst};
/// Hyper-parameters for [`lfmmi_loss`].
#[derive(Debug, Clone)]
pub struct LfMmiConfig {
    /// Additive smoothing constant `c` for forward-backward arc scores: when positive, an arc's
    /// log-score `s` is replaced by `ln(exp(s) + c)`. Disabled when `<= 0`.
    pub leaky_hmm_coefficient: f64,
    /// Weight of the `0.5 * sum(score^2)` penalty on the acoustic scores. Disabled when `<= 0`.
    pub l2_regularize: f64,
    /// Weight of the cross-entropy term in the total loss. Disabled when `<= 0`.
    pub xent_regularize: f64,
    /// Currently not read by `lfmmi_loss` or its helpers.
    pub use_chain_topology: bool,
    /// Currently not read by `lfmmi_loss` or its helpers.
    pub subsampling_factor: usize,
}

impl Default for LfMmiConfig {
    fn default() -> Self {
        LfMmiConfig {
            use_chain_topology: true,
            subsampling_factor: 3,
            leaky_hmm_coefficient: 0.1,
            xent_regularize: 0.1,
            l2_regularize: 0.0001,
        }
    }
}
/// Outcome of [`lfmmi_loss`] for a single utterance.
#[derive(Clone, Debug)]
pub struct LfMmiResult {
    /// Total loss: MMI term + `xent_regularize` * cross-entropy term + L2 term.
    pub loss: f64,
    /// Log-sum path score of the numerator graph (forward-backward total).
    pub numerator_log_prob: f64,
    /// Log-sum path score of the denominator graph (forward-backward total).
    pub denominator_log_prob: f64,
    /// Unweighted cross-entropy term; `0.0` when cross-entropy regularization is disabled.
    pub xent_loss: f64,
    /// Per-frame, per-pdf gradient with respect to the acoustic scores.
    pub gradients: LfMmiGradients,
}
/// Dense `num_frames x num_pdfs` matrix of per-frame, per-pdf values (posteriors or
/// gradients), stored row-major in a single `Vec`.
#[derive(Debug, Clone)]
pub struct LfMmiGradients {
    /// Number of rows (frames).
    pub num_frames: usize,
    /// Number of columns (pdfs).
    pub num_pdfs: usize,
    /// Row-major storage; entry `(t, pdf)` lives at `t * num_pdfs + pdf`.
    pub data: Vec<f64>,
}

impl LfMmiGradients {
    /// Creates a zero-filled `num_frames x num_pdfs` matrix.
    pub fn new(num_frames: usize, num_pdfs: usize) -> Self {
        Self {
            num_frames,
            num_pdfs,
            data: vec![0.0; num_frames * num_pdfs],
        }
    }

    /// Flat index of entry `(t, pdf)`.
    ///
    /// Because storage is flat, an out-of-range `pdf` would silently alias into row `t + 1`.
    /// Debug builds reject it here. An out-of-range `t` is still caught by `Vec` indexing.
    #[inline]
    fn index(&self, t: usize, pdf: usize) -> usize {
        debug_assert!(
            pdf < self.num_pdfs,
            "pdf index {} out of range (num_pdfs = {})",
            pdf,
            self.num_pdfs
        );
        t * self.num_pdfs + pdf
    }

    /// Returns the value stored at `(t, pdf)`.
    ///
    /// # Panics
    /// Panics if the entry is outside the matrix (the `pdf` check is debug-only).
    #[inline]
    pub fn get(&self, t: usize, pdf: usize) -> f64 {
        self.data[self.index(t, pdf)]
    }

    /// Overwrites the value at `(t, pdf)`.
    #[inline]
    pub fn set(&mut self, t: usize, pdf: usize, value: f64) {
        let i = self.index(t, pdf);
        self.data[i] = value;
    }

    /// Adds `value` to the entry at `(t, pdf)`.
    #[inline]
    pub fn add(&mut self, t: usize, pdf: usize, value: f64) {
        let i = self.index(t, pdf);
        self.data[i] += value;
    }
}
/// Computes the lattice-free MMI loss for one utterance.
///
/// `acoustic_scores[t][pdf]` is the network's score for `pdf` at frame `t`. The number of pdfs
/// is taken from the first frame. Both graphs are scored with forward-backward
/// (`compute_graph_score`), and the MMI term is `-(log Z_num - log Z_den)`.
///
/// The returned `loss` is `mmi + xent_regularize * xent + l2`:
/// - The cross-entropy term is computed only when `config.xent_regularize > 0`.
/// - The L2 term (`0.5 * l2_regularize * sum(score^2)`) is computed only when
///   `config.l2_regularize > 0`.
///
/// `gradients[t][pdf]` is the derivative of the MMI and L2 terms with respect to
/// `acoustic_scores[t][pdf]`:
/// `den_posterior - num_posterior + l2_regularize * score`.
/// The L2 part was previously missing, so the gradients did not match the reported loss.
/// The cross-entropy term's derivative is NOT included.
pub fn lfmmi_loss<W>(
    acoustic_scores: &[Vec<f64>],
    numerator_graph: &VectorWfst<u32, W>,
    denominator_graph: &VectorWfst<u32, W>,
    config: &LfMmiConfig,
) -> LfMmiResult
where
    W: Semiring + From<f64> + Into<f64> + Clone,
{
    let num_frames = acoustic_scores.len();
    let num_pdfs = acoustic_scores.first().map_or(0, |v| v.len());
    let (num_log_prob, num_posteriors) =
        compute_graph_score(acoustic_scores, numerator_graph, config);
    let (den_log_prob, den_posteriors) =
        compute_graph_score(acoustic_scores, denominator_graph, config);
    let mmi_loss = -(num_log_prob - den_log_prob);

    let l2_enabled = config.l2_regularize > 0.0;
    let mut gradients = LfMmiGradients::new(num_frames, num_pdfs);
    for (t, frame) in acoustic_scores.iter().enumerate() {
        for pdf in 0..num_pdfs {
            let mut grad = den_posteriors.get(t, pdf) - num_posteriors.get(t, pdf);
            if l2_enabled {
                // d/dx (0.5 * l2 * x^2) = l2 * x. A short (ragged) frame contributes nothing
                // here rather than panicking on the index.
                grad += config.l2_regularize * frame.get(pdf).copied().unwrap_or(0.0);
            }
            gradients.set(t, pdf, grad);
        }
    }

    let xent_loss = if config.xent_regularize > 0.0 {
        compute_xent_loss(acoustic_scores, &num_posteriors, num_frames, num_pdfs)
    } else {
        0.0
    };
    let l2_loss = if l2_enabled {
        compute_l2_loss(acoustic_scores, config.l2_regularize)
    } else {
        0.0
    };
    let total_loss = mmi_loss + config.xent_regularize * xent_loss + l2_loss;
    LfMmiResult {
        loss: total_loss,
        numerator_log_prob: num_log_prob,
        denominator_log_prob: den_log_prob,
        xent_loss,
        gradients,
    }
}
/// Runs forward-backward over `graph` for one utterance.
///
/// Every arc consumes exactly one frame. An arc's raw log-score at frame `t` is
/// `acoustic_scores[t][pdf] - cost`, where `pdf` is the arc's input label and `cost` is the arc
/// weight converted to `f64` (weights are treated as negative-log costs).
///
/// When `config.leaky_hmm_coefficient` (`c`) is positive, BOTH the forward and the backward
/// recursion use the smoothed score `ln(exp(raw) + c)`. Previously only the forward pass
/// did, which made alpha and beta inconsistent.
/// NOTE(review): this per-arc additive smoothing is not Kaldi's leaky-HMM construction.
///
/// Returns `(log Z, dlogZ)`:
/// - `log Z` is the log-sum over all complete paths of `num_frames` arcs from the start state
///   to a final state, including the final cost.
/// - `dlogZ[t][pdf]` is `d log Z / d acoustic_scores[t][pdf]`. Without smoothing this is the
///   per-frame pdf occupancy.
///
/// Other behaviour:
/// - Arcs whose input label is `>= num_pdfs` are skipped.
/// - An epsilon input label is treated as pdf 0. NOTE(review): such arcs still consume a frame.
/// - An empty graph, an out-of-range start state, or a graph with no complete path yields
///   `NEG_INFINITY` and all-zero posteriors.
fn compute_graph_score<W>(
    acoustic_scores: &[Vec<f64>],
    graph: &VectorWfst<u32, W>,
    config: &LfMmiConfig,
) -> (f64, LfMmiGradients)
where
    W: Semiring + From<f64> + Into<f64> + Clone,
{
    let num_frames = acoustic_scores.len();
    let num_pdfs = acoustic_scores.first().map_or(0, |v| v.len());
    let num_states = graph.num_states();
    let mut frame_posteriors = LfMmiGradients::new(num_frames, num_pdfs);

    let start = graph.start() as usize;
    if start >= num_states {
        // Empty graph or invalid start state: no path exists (and indexing would panic).
        return (f64::NEG_INFINITY, frame_posteriors);
    }

    // Raw log-score of consuming frame `t` with `pdf` on an arc whose cost is `weight`.
    let raw_arc_score = |t: usize, pdf: usize, weight: &W| -> f64 {
        let transition_cost: f64 = weight.clone().into();
        acoustic_scores[t][pdf] - transition_cost
    };
    // ln(c) is loop-invariant; `None` disables smoothing.
    let leaky_log = if config.leaky_hmm_coefficient > 0.0 {
        Some(config.leaky_hmm_coefficient.ln())
    } else {
        None
    };
    // Arc score used by BOTH recursions so alpha and beta describe the same path weights.
    let smoothed = |raw: f64| -> f64 {
        match leaky_log {
            Some(leak) => log_add(raw, leak),
            None => raw,
        }
    };

    // Forward pass: alpha[t][s] = log-sum over partial paths from the start state that
    // reach `s` after consuming frames 0..t.
    let mut alpha = vec![vec![f64::NEG_INFINITY; num_states]; num_frames + 1];
    alpha[0][start] = 0.0;
    for t in 0..num_frames {
        for s in 0..num_states {
            let prefix = alpha[t][s];
            if prefix <= f64::NEG_INFINITY {
                continue;
            }
            for tr in graph.transitions(s as StateId) {
                let pdf = tr.input.unwrap_or(0) as usize;
                if pdf >= num_pdfs {
                    continue;
                }
                let next = tr.to as usize;
                let candidate = prefix + smoothed(raw_arc_score(t, pdf, &tr.weight));
                alpha[t + 1][next] = log_add(alpha[t + 1][next], candidate);
            }
        }
    }

    // Backward pass: beta[t][s] = log-sum over suffix paths from `s` that consume frames
    // t..num_frames and end in a final state, including the final cost.
    let mut beta = vec![vec![f64::NEG_INFINITY; num_states]; num_frames + 1];
    for s in 0..num_states {
        let state = s as StateId;
        if graph.is_final(state) {
            let final_weight: f64 = graph.final_weight(state).into();
            beta[num_frames][s] = -final_weight;
        }
    }
    for t in (0..num_frames).rev() {
        for s in 0..num_states {
            for tr in graph.transitions(s as StateId) {
                let pdf = tr.input.unwrap_or(0) as usize;
                if pdf >= num_pdfs {
                    continue;
                }
                let next = tr.to as usize;
                let suffix = beta[t + 1][next];
                if suffix <= f64::NEG_INFINITY {
                    continue;
                }
                let candidate = smoothed(raw_arc_score(t, pdf, &tr.weight)) + suffix;
                beta[t][s] = log_add(beta[t][s], candidate);
            }
        }
    }

    let total_log_prob = alpha[num_frames]
        .iter()
        .enumerate()
        .filter(|(s, _)| graph.is_final(*s as StateId))
        .map(|(s, &a)| {
            let final_weight: f64 = graph.final_weight(s as StateId).into();
            a - final_weight
        })
        .fold(f64::NEG_INFINITY, log_add);
    if total_log_prob <= f64::NEG_INFINITY {
        // No complete path: every posterior is zero.
        return (total_log_prob, frame_posteriors);
    }

    // Accumulate d log Z / d acoustic. The arc's smoothed score has derivative
    // exp(raw - smoothed) w.r.t. the acoustic score, so the chain rule leaves the RAW arc
    // score between alpha and beta (both of which include smoothing).
    for t in 0..num_frames {
        for s in 0..num_states {
            let prefix = alpha[t][s];
            if prefix <= f64::NEG_INFINITY {
                continue;
            }
            for tr in graph.transitions(s as StateId) {
                let pdf = tr.input.unwrap_or(0) as usize;
                if pdf >= num_pdfs {
                    continue;
                }
                let next = tr.to as usize;
                let suffix = beta[t + 1][next];
                if suffix <= f64::NEG_INFINITY {
                    continue;
                }
                let posterior =
                    (prefix + raw_arc_score(t, pdf, &tr.weight) + suffix - total_log_prob).exp();
                frame_posteriors.add(t, pdf, posterior);
            }
        }
    }
    (total_log_prob, frame_posteriors)
}
/// Frame-averaged cross-entropy between `posteriors` (soft targets) and `acoustic_scores`:
/// `-(1 / num_frames) * sum_{t, pdf} posterior[t][pdf] * acoustic_scores[t][pdf]`.
///
/// NOTE(review): assumes the acoustic scores are log-probabilities (e.g. log-softmax outputs).
/// Posteriors at or below `1e-10` are skipped.
///
/// Returns `0.0` for zero frames, instead of the `NaN` that `0.0 / 0.0` would produce.
fn compute_xent_loss(
    acoustic_scores: &[Vec<f64>],
    posteriors: &LfMmiGradients,
    num_frames: usize,
    num_pdfs: usize,
) -> f64 {
    if num_frames == 0 {
        return 0.0;
    }
    let mut loss = 0.0;
    for t in 0..num_frames {
        for pdf in 0..num_pdfs {
            let posterior = posteriors.get(t, pdf);
            if posterior > 1e-10 {
                let log_prob = acoustic_scores[t][pdf];
                loss -= posterior * log_prob;
            }
        }
    }
    loss / num_frames as f64
}
/// Returns `0.5 * l2_weight * sum(score^2)` over every entry of every frame.
///
/// Entries are summed frame by frame, left to right.
fn compute_l2_loss(acoustic_scores: &[Vec<f64>], l2_weight: f64) -> f64 {
    let sum_of_squares = acoustic_scores
        .iter()
        .flat_map(|frame| frame.iter())
        .fold(0.0, |acc, &x| acc + x * x);
    0.5 * l2_weight * sum_of_squares
}
/// Builds a linear numerator graph for `transcript` (a sequence of phone ids).
///
/// Each phone expands to `hmm_topo.num_states_for_phone(phone)` HMM states. For every HMM
/// state the graph gets:
/// - a self-loop on the current state, and
/// - an arc to a freshly added next state.
///
/// Both arcs carry the state's pdf id as input and output label. The last state created is
/// marked final with weight `W::one()`. `_pdf_to_phone` is currently unused.
///
/// Arc weights are costs (negative log-probabilities), because `compute_graph_score` scores an
/// arc as `acoustic - weight`. Storing `ln(p)` (which is `<= 0`) would turn every transition
/// into a bonus rather than a penalty.
pub fn build_numerator_graph<W>(
    transcript: &[u32],
    _pdf_to_phone: &[u32],
    hmm_topo: &HmmTopology,
) -> VectorWfst<u32, W>
where
    W: Semiring + From<f64>,
{
    let mut fst: VectorWfst<u32, W> = VectorWfst::new();
    let mut current_state = fst.add_state();
    fst.set_start(current_state);
    for &phone in transcript {
        let num_hmm_states = hmm_topo.num_states_for_phone(phone);
        for hmm_state in 0..num_hmm_states {
            let pdf = hmm_topo.pdf_for_state(phone, hmm_state);
            let next_state = fst.add_state();
            // Cost of staying in this HMM state for another frame.
            fst.add_transition(WeightedTransition {
                from: current_state,
                input: Some(pdf),
                output: Some(pdf),
                to: current_state,
                weight: W::from(-hmm_topo.self_loop_prob(phone, hmm_state).ln()),
            });
            // Cost of emitting this pdf once more and advancing to the next HMM state.
            fst.add_transition(WeightedTransition {
                from: current_state,
                input: Some(pdf),
                output: Some(pdf),
                to: next_state,
                weight: W::from(-hmm_topo.forward_prob(phone, hmm_state).ln()),
            });
            current_state = next_state;
        }
    }
    fst.set_final(current_state, W::one());
    fst
}
/// Builds the denominator graph.
///
/// When `phone_lm` is given, a clone of it is returned unchanged; `num_phones` and `hmm_topo`
/// are then ignored.
///
/// Otherwise the result is a single state that is both start and final (weight `W::one()`).
/// It carries one self-loop per (phone, HMM state) pair, for phones `0..num_phones`. Each loop
/// is labelled with that pair's pdf id on input and output and has weight `W::from(0.0)`.
pub fn build_denominator_graph<W>(
    num_phones: usize,
    hmm_topo: &HmmTopology,
    phone_lm: Option<&VectorWfst<u32, W>>,
) -> VectorWfst<u32, W>
where
    W: Semiring + From<f64> + Clone,
{
    match phone_lm {
        Some(lm) => lm.clone(),
        None => {
            let mut fst: VectorWfst<u32, W> = VectorWfst::new();
            let hub = fst.add_state();
            fst.set_start(hub);
            fst.set_final(hub, W::one());
            let phone_states = (0..num_phones as u32).flat_map(|phone| {
                (0..hmm_topo.num_states_for_phone(phone)).map(move |hmm_state| (phone, hmm_state))
            });
            for (phone, hmm_state) in phone_states {
                let pdf = hmm_topo.pdf_for_state(phone, hmm_state);
                fst.add_transition(WeightedTransition {
                    from: hub,
                    input: Some(pdf),
                    output: Some(pdf),
                    to: hub,
                    weight: W::from(0.0),
                });
            }
            fst
        }
    }
}
/// Uniform HMM topology: every phone has `states_per_phone` states, and every state has the
/// same self-loop and forward transition probabilities.
#[derive(Debug, Clone)]
pub struct HmmTopology {
    /// Number of HMM states per phone.
    pub states_per_phone: usize,
    /// Probability of staying in an HMM state for another frame.
    pub self_loop_prob: f64,
    /// Probability of advancing to the next HMM state.
    pub forward_prob: f64,
    /// Number of phones.
    pub num_phones: usize,
}

impl Default for HmmTopology {
    /// Three states per phone, no phones, and 0.5 / 0.5 transition probabilities.
    fn default() -> Self {
        HmmTopology::new(0, 3)
    }
}

impl HmmTopology {
    /// Creates a topology with `num_phones` phones of `states_per_phone` states each, using
    /// 0.5 for both the self-loop and the forward probability.
    pub fn new(num_phones: usize, states_per_phone: usize) -> Self {
        let half = 0.5;
        Self {
            num_phones,
            states_per_phone,
            self_loop_prob: half,
            forward_prob: half,
        }
    }

    /// Number of HMM states for `phone`; the same for every phone.
    pub fn num_states_for_phone(&self, _phone: u32) -> usize {
        self.states_per_phone
    }

    /// Pdf id of `hmm_state` within `phone`: `phone * states_per_phone + hmm_state`.
    pub fn pdf_for_state(&self, phone: u32, hmm_state: usize) -> u32 {
        let stride = self.states_per_phone as u32;
        let offset = hmm_state as u32;
        phone * stride + offset
    }

    /// Self-loop probability; identical for every phone and state.
    pub fn self_loop_prob(&self, _phone: u32, _hmm_state: usize) -> f64 {
        self.self_loop_prob
    }

    /// Forward (advance) probability; identical for every phone and state.
    pub fn forward_prob(&self, _phone: u32, _hmm_state: usize) -> f64 {
        self.forward_prob
    }
}
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// `NEG_INFINITY` (the log of zero) acts as the identity. The smaller term is folded in with
/// `ln_1p`, so a much smaller operand still contributes. The previous
/// `(1.0 + exp(lo - hi)).ln()` rounded `1.0 + tiny` to exactly `1.0` and dropped it.
#[inline]
fn log_add(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;

    #[test]
    fn test_hmm_topology() {
        let topology = HmmTopology::new(40, 3);
        assert_eq!(topology.num_states_for_phone(0), 3);
        // (phone, hmm_state, expected pdf id)
        let cases: [(u32, usize, u32); 3] = [(0, 0, 0), (0, 1, 1), (1, 0, 3)];
        for (phone, hmm_state, pdf) in cases.iter().copied() {
            assert_eq!(topology.pdf_for_state(phone, hmm_state), pdf);
        }
    }

    #[test]
    fn test_denominator_graph() {
        let topology = HmmTopology::new(10, 3);
        let graph: VectorWfst<u32, TropicalWeight> =
            build_denominator_graph(10, &topology, None);
        assert_eq!(graph.num_states(), 1);
        assert!(graph.is_final(0));
        // One self-loop per (phone, HMM state) pair: 10 phones x 3 states.
        assert_eq!(graph.transitions(0).len(), 10 * 3);
    }

    #[test]
    fn test_lfmmi_gradients() {
        let close = |x: f64, y: f64| (x - y).abs() < 1e-10;
        let mut g = LfMmiGradients::new(10, 100);
        g.set(0, 50, 0.5);
        assert!(close(g.get(0, 50), 0.5));
        g.add(0, 50, 0.3);
        assert!(close(g.get(0, 50), 0.8));
    }
}