use super::{build_target_graph, Label, TransducerLattice, BLANK};
use crate::semiring::Semiring;
use crate::wfst::{VectorWfst, Wfst};
#[derive(Debug, Clone)]
pub struct TransducerLossResult {
pub loss: f64,
pub gradients: TransducerGradients,
pub forward_scores: Vec<f64>,
pub backward_scores: Vec<f64>,
}
#[derive(Debug, Clone)]
pub struct TransducerGradients {
pub num_frames: usize,
pub num_positions: usize,
pub vocab_size: usize,
pub data: Vec<f64>,
}
impl TransducerGradients {
pub fn new(num_frames: usize, num_positions: usize, vocab_size: usize) -> Self {
let size = num_frames * num_positions * vocab_size;
Self {
num_frames,
num_positions,
vocab_size,
data: vec![0.0; size],
}
}
#[inline]
pub fn get(&self, t: usize, u: usize, label: Label) -> f64 {
let idx = (t * self.num_positions + u) * self.vocab_size + label as usize;
self.data[idx]
}
#[inline]
pub fn set(&mut self, t: usize, u: usize, label: Label, value: f64) {
let idx = (t * self.num_positions + u) * self.vocab_size + label as usize;
self.data[idx] = value;
}
#[inline]
pub fn add(&mut self, t: usize, u: usize, label: Label, value: f64) {
let idx = (t * self.num_positions + u) * self.vocab_size + label as usize;
self.data[idx] += value;
}
}
pub fn transducer_loss<W>(lattice: &TransducerLattice<W>, targets: &[Label]) -> TransducerLossResult
where
W: Semiring + From<f64> + Into<f64>,
{
let t_len = lattice.num_frames;
let u_len = targets.len() + 1;
let mut alpha = vec![vec![f64::NEG_INFINITY; u_len]; t_len + 1];
alpha[0][0] = 0.0;
for t in 0..t_len {
for u in 0..u_len {
if alpha[t][u] <= f64::NEG_INFINITY {
continue;
}
let blank_prob = lattice.get(t, u, BLANK);
let new_alpha = alpha[t][u] + blank_prob;
alpha[t + 1][u] = log_add(alpha[t + 1][u], new_alpha);
if u < targets.len() {
let label = targets[u];
let label_prob = lattice.get(t, u, label);
let new_alpha = alpha[t][u] + label_prob;
alpha[t + 1][u + 1] = log_add(alpha[t + 1][u + 1], new_alpha);
}
}
}
let total_log_prob = alpha[t_len][u_len - 1];
let mut beta = vec![vec![f64::NEG_INFINITY; u_len]; t_len + 1];
beta[t_len][u_len - 1] = 0.0;
for t in (0..t_len).rev() {
for u in (0..u_len).rev() {
if beta[t + 1][u] > f64::NEG_INFINITY {
let blank_prob = lattice.get(t, u, BLANK);
let new_beta = blank_prob + beta[t + 1][u];
beta[t][u] = log_add(beta[t][u], new_beta);
}
if u < targets.len() && beta[t + 1][u + 1] > f64::NEG_INFINITY {
let label = targets[u];
let label_prob = lattice.get(t, u, label);
let new_beta = label_prob + beta[t + 1][u + 1];
beta[t][u] = log_add(beta[t][u], new_beta);
}
}
}
let mut gradients = TransducerGradients::new(t_len, u_len, lattice.vocab_size);
for t in 0..t_len {
for u in 0..u_len {
if alpha[t][u] <= f64::NEG_INFINITY {
continue;
}
if beta[t + 1][u] > f64::NEG_INFINITY {
let blank_prob = lattice.get(t, u, BLANK);
let posterior = (alpha[t][u] + blank_prob + beta[t + 1][u] - total_log_prob).exp();
gradients.set(t, u, BLANK, -posterior);
}
if u < targets.len() && beta[t + 1][u + 1] > f64::NEG_INFINITY {
let label = targets[u];
let label_prob = lattice.get(t, u, label);
let posterior =
(alpha[t][u] + label_prob + beta[t + 1][u + 1] - total_log_prob).exp();
gradients.set(t, u, label, -posterior);
}
}
}
let loss = -total_log_prob;
let forward_scores: Vec<f64> = alpha.into_iter().flatten().collect();
let backward_scores: Vec<f64> = beta.into_iter().flatten().collect();
TransducerLossResult {
loss,
gradients,
forward_scores,
backward_scores,
}
}
pub fn transducer_loss_with_lm<W>(
lattice: &TransducerLattice<W>,
targets: &[Label],
lm: &VectorWfst<Label, W>,
lm_weight: f64,
) -> TransducerLossResult
where
W: Semiring + From<f64> + Into<f64> + Clone,
{
let _target_graph: VectorWfst<Label, W> = build_target_graph(targets);
let mut result = transducer_loss(lattice, targets);
let lm_score = compute_lm_score(lm, targets);
result.loss -= lm_weight * lm_score;
result
}
fn compute_lm_score<W>(lm: &VectorWfst<Label, W>, targets: &[Label]) -> f64
where
W: Semiring + Into<f64> + Clone,
{
let mut score = 0.0f64;
let mut state = lm.start();
for &label in targets {
let mut found = false;
for tr in lm.transitions(state) {
if tr.input == Some(label) {
let weight: f64 = tr.weight.clone().into();
score += weight;
state = tr.to;
found = true;
break;
}
}
if !found {
for tr in lm.transitions(state) {
if tr.input.is_none() {
let backoff_weight: f64 = tr.weight.clone().into();
score += backoff_weight;
state = tr.to;
for tr2 in lm.transitions(state) {
if tr2.input == Some(label) {
let weight: f64 = tr2.weight.clone().into();
score += weight;
state = tr2.to;
found = true;
break;
}
}
break;
}
}
}
if !found {
score += -10.0; }
}
if lm.is_final(state) {
let final_weight: f64 = lm.final_weight(state).into();
score += final_weight;
}
score
}
pub fn transducer_loss_batch<W>(
lattices: &[TransducerLattice<W>],
targets_batch: &[Vec<Label>],
) -> Vec<TransducerLossResult>
where
W: Semiring + From<f64> + Into<f64>,
{
lattices
.iter()
.zip(targets_batch.iter())
.map(|(lattice, targets)| transducer_loss(lattice, targets))
.collect()
}
#[inline]
fn log_add(a: f64, b: f64) -> f64 {
if a == f64::NEG_INFINITY {
b
} else if b == f64::NEG_INFINITY {
a
} else if a > b {
a + (1.0 + (b - a).exp()).ln()
} else {
b + (1.0 + (a - b).exp()).ln()
}
}
#[derive(Debug, Clone)]
pub struct TransducerLossConfig {
pub regularization: f64,
pub normalize_by_length: bool,
pub label_smoothing: f64,
pub lm_weight: f64,
}
impl Default for TransducerLossConfig {
fn default() -> Self {
Self {
regularization: 0.0,
normalize_by_length: true,
label_smoothing: 0.0,
lm_weight: 0.0,
}
}
}
pub fn factorized_transducer_loss<W>(
blank_logits: &[f64], vocab_logits: &[Vec<f64>], targets: &[Label],
) -> TransducerLossResult
where
W: Semiring + From<f64> + Into<f64>,
{
let t_len = blank_logits.len();
let u_len = targets.len() + 1;
let vocab_size = vocab_logits.first().map_or(1, |v| v.len()) + 1;
let mut lattice: TransducerLattice<W> = TransducerLattice::new(t_len, u_len, vocab_size);
for t in 0..t_len {
for u in 0..u_len {
lattice.set(t, u, BLANK, blank_logits[t]);
if u < vocab_logits.len() {
for (v, &log_prob) in vocab_logits[u].iter().enumerate() {
lattice.set(t, u, (v + 1) as Label, log_prob);
}
}
}
}
transducer_loss(&lattice, targets)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
#[test]
fn test_log_add() {
assert!((log_add(0.0, 0.0) - 0.693).abs() < 0.01); assert!((log_add(f64::NEG_INFINITY, 0.0) - 0.0).abs() < 0.001);
assert!((log_add(0.0, f64::NEG_INFINITY) - 0.0).abs() < 0.001);
}
#[test]
fn test_transducer_loss_simple() {
let mut lattice: TransducerLattice<TropicalWeight> = TransducerLattice::new(2, 2, 3);
lattice.set(0, 0, BLANK, -1.5); lattice.set(0, 0, 1, -2.0); lattice.set(0, 0, 2, -3.0); lattice.set(1, 0, BLANK, -1.2);
lattice.set(1, 0, 1, -1.8);
lattice.set(1, 1, BLANK, -1.0);
let targets = vec![1];
let result = transducer_loss(&lattice, &targets);
assert!(
result.loss > 0.0,
"Loss should be positive, got {}",
result.loss
);
assert!(result.loss.is_finite());
}
#[test]
fn test_transducer_gradients() {
let mut grads = TransducerGradients::new(2, 2, 3);
grads.set(0, 0, 1, 0.5);
assert!((grads.get(0, 0, 1) - 0.5).abs() < 1e-6);
grads.add(0, 0, 1, 0.3);
assert!((grads.get(0, 0, 1) - 0.8).abs() < 1e-6);
}
}