use crate::semiring::Semiring;
use crate::transducer::{DenseFsa, Label};
use crate::wfst::{MutableWfst, StateId, VectorWfst, WeightedTransition, Wfst};
use std::collections::HashMap;
/// Tuning parameters for pruned composition.
///
/// In the code visible in this file, `pruned_compose` reads only `beam` and
/// `compute_gradients`; the other fields are carried on the struct but are not
/// consulted there.
#[derive(Debug, Clone)]
pub struct PrunedCompositionConfig {
    /// Width of the pruning beam: hypotheses scoring more than this below the
    /// best score of their frame are dropped.
    pub beam: f64,
    /// Intended cap on the number of composed states (not enforced by
    /// `pruned_compose` as written here).
    pub max_states: usize,
    /// Intended cap on the number of composed arcs (not enforced by
    /// `pruned_compose` as written here).
    pub max_arcs: usize,
    /// Optional secondary beam; `None` by default. Not read in this file.
    pub output_beam: Option<f64>,
    /// When `true`, `pruned_compose` records an `ArcInfo` for every arc it adds.
    pub compute_gradients: bool,
    /// Posterior floor, presumably for discarding negligible arcs. Not read in
    /// this file — confirm against other users of the config.
    pub min_arc_posterior: f64,
}

impl Default for PrunedCompositionConfig {
    /// Stock settings: beam of 10, 10 000 states, 50 000 arcs, no output beam,
    /// gradient bookkeeping enabled, posterior floor of 1e-10.
    fn default() -> Self {
        PrunedCompositionConfig {
            compute_gradients: true,
            output_beam: None,
            min_arc_posterior: 1e-10,
            beam: 10.0,
            max_arcs: 50_000,
            max_states: 10_000,
        }
    }
}
/// Output of `pruned_compose`: the composed graph together with the
/// bookkeeping needed to score it and to derive gradients from it.
#[derive(Debug)]
pub struct PrunedComposition<W: Semiring> {
/// The composed transducer.
pub wfst: VectorWfst<Label, W>,
/// Maps each composed state id to the `(frame index, sparse state id)` pair
/// it was created for.
pub state_map: HashMap<StateId, (usize, StateId)>,
/// Log-domain forward score of every composed state, indexed by state id.
/// States start at `-inf` and accumulate incoming path scores via `log_add`.
pub forward_scores: Vec<f64>,
/// Log-domain backward scores indexed by state id; `None` until
/// `compute_backward` has been run.
pub backward_scores: Option<Vec<f64>>,
/// One record per composed arc; only populated when the composition was run
/// with `compute_gradients` enabled.
pub arc_info: Vec<ArcInfo>,
/// Counters collected while composing.
pub stats: PruningStats,
}
/// Per-arc record kept by `pruned_compose` so that arc posteriors can be
/// computed later without walking the composed graph again.
#[derive(Debug, Clone)]
pub struct ArcInfo {
/// Composed state the arc leaves.
pub from_state: StateId,
/// Composed state the arc enters.
pub to_state: StateId,
/// Index of the frame whose scores this arc consumed.
pub time: usize,
/// Input label of the sparse transition the arc was built from.
pub label: Label,
/// Score read from the dense frame for `label`.
pub acoustic_score: f64,
/// Weight of the sparse transition, converted to `f64`.
pub lm_score: f64,
/// Combined arc score, `acoustic_score - lm_score`.
pub arc_score: f64,
}
/// Counters filled in by `pruned_compose`. All fields default to zero.
#[derive(Debug, Clone, Default)]
pub struct PruningStats {
/// Running total, over all frames, of the states that passed the beam
/// filter and were expanded.
pub states_before: usize,
/// Number of composed states recorded by `pruned_compose` (the size of its
/// state table).
pub states_after: usize,
/// Number of arcs added to the composed graph.
pub arcs_before: usize,
/// Arc count stored by `pruned_compose` once composition has run.
pub arcs_after: usize,
/// `states_after / states_before`; stays at `0.0` when `states_before` is zero.
pub avg_beam_utilization: f64,
}
/// Builds the beam-pruned, frame-synchronous composition of `dense` with `sparse`.
///
/// Composed states are `(frame, sparse state)` pairs. Starting from
/// `(0, sparse.start())`, every frame `t` expands the states that were reached
/// at frame `t` and whose forward score is within `config.beam` of the best
/// score recorded for that frame. For each outgoing sparse transition that
/// carries an input label, the arc score is
/// `frame_scores[label] - f64::from(weight)`; the arc is kept unless the
/// resulting path score falls more than `config.beam` below the best score seen
/// so far for frame `t + 1`. Transitions without an input label, labels outside
/// the frame's score range, and `-inf` frame scores are skipped.
///
/// After the last frame, composed states at frame `num_frames` whose sparse
/// state is final receive that state's final weight.
///
/// States are expanded in the order they were created, so the result does not
/// depend on `HashMap` iteration order.
///
/// Only `config.beam` and `config.compute_gradients` are consulted here;
/// `max_states`, `max_arcs`, `output_beam` and `min_arc_posterior` are not
/// enforced by this function.
///
/// # Returns
///
/// A [`PrunedComposition`] holding the composed graph, the composed-state to
/// `(frame, sparse state)` map, per-state forward scores, the per-arc records
/// (empty unless `config.compute_gradients` is set) and pruning statistics.
/// `backward_scores` is left as `None`.
pub fn pruned_compose<W>(
    dense: &DenseFsa<W>,
    sparse: &VectorWfst<Label, W>,
    config: &PrunedCompositionConfig,
) -> PrunedComposition<W>
where
    W: Semiring + From<f64> + Into<f64> + Clone,
{
    let num_frames = dense.num_frames;
    let mut fst: VectorWfst<Label, W> = VectorWfst::new();
    let mut state_map: HashMap<(usize, StateId), StateId> = HashMap::new();
    let mut reverse_map: HashMap<StateId, (usize, StateId)> = HashMap::new();
    let mut arc_info: Vec<ArcInfo> = Vec::new();
    let mut forward_scores: Vec<f64> = Vec::new();
    let mut stats = PruningStats::default();

    // Returns the composed state for `(t, s)`, creating it on first sight.
    // The flag is `true` when this call created the state.
    let get_or_create_state = |map: &mut HashMap<(usize, StateId), StateId>,
                               rev_map: &mut HashMap<StateId, (usize, StateId)>,
                               scores: &mut Vec<f64>,
                               fst: &mut VectorWfst<Label, W>,
                               t: usize,
                               s: StateId|
     -> (StateId, bool) {
        let mut created = false;
        let id = *map.entry((t, s)).or_insert_with(|| {
            created = true;
            let id = fst.add_state();
            scores.push(f64::NEG_INFINITY);
            rev_map.insert(id, (t, s));
            id
        });
        (id, created)
    };

    let sparse_start = sparse.start();
    let (start_state, _) = get_or_create_state(
        &mut state_map,
        &mut reverse_map,
        &mut forward_scores,
        &mut fst,
        0,
        sparse_start,
    );
    forward_scores[start_state as usize] = 0.0;
    fst.set_start(start_state);

    // `(sparse state, composed state)` pairs that live at the current frame,
    // in creation order. Keeping this list avoids rescanning every composed
    // state on each frame and makes the expansion order deterministic.
    let mut frontier: Vec<(StateId, StateId)> = vec![(sparse_start, start_state)];

    // Best single-path score seen per frame; drives both beam checks below.
    let mut best_scores: Vec<f64> = vec![f64::NEG_INFINITY; num_frames + 1];
    best_scores[0] = 0.0;

    for t in 0..num_frames {
        let frame_scores = dense.frame_scores(t);
        let beam_threshold = best_scores[t] - config.beam;

        // Snapshot the survivors first: `forward_scores` is mutated while the
        // arcs leaving them are added.
        let active_states: Vec<(StateId, StateId, f64)> = frontier
            .iter()
            .map(|&(sparse_s, composed_s)| {
                (sparse_s, composed_s, forward_scores[composed_s as usize])
            })
            .filter(|(_, _, score)| *score >= beam_threshold)
            .collect();
        stats.states_before += active_states.len();

        let mut next_frontier: Vec<(StateId, StateId)> = Vec::new();
        for (sparse_state, composed_from, from_score) in active_states {
            for tr in sparse.transitions(sparse_state) {
                // Transitions without an input label consume no frame; skip them.
                let label = match tr.input {
                    Some(l) => l,
                    None => continue,
                };
                let acoustic_score = if (label as usize) < frame_scores.len() {
                    frame_scores[label as usize] as f64
                } else {
                    continue;
                };
                if acoustic_score <= f64::NEG_INFINITY {
                    continue;
                }
                let lm_score: f64 = tr.weight.clone().into();
                let arc_score = acoustic_score - lm_score;
                let new_score = from_score + arc_score;

                // Online beam check against the best score seen so far at t + 1.
                if new_score < best_scores[t + 1] - config.beam {
                    continue;
                }
                if new_score > best_scores[t + 1] {
                    best_scores[t + 1] = new_score;
                }

                let (composed_to, created) = get_or_create_state(
                    &mut state_map,
                    &mut reverse_map,
                    &mut forward_scores,
                    &mut fst,
                    t + 1,
                    tr.to,
                );
                if created {
                    next_frontier.push((tr.to, composed_to));
                }

                let old_score = forward_scores[composed_to as usize];
                forward_scores[composed_to as usize] = log_add(old_score, new_score);

                // Graph weights are costs, hence the sign flip.
                fst.add_transition(WeightedTransition {
                    from: composed_from,
                    input: Some(label),
                    output: tr.output,
                    to: composed_to,
                    weight: W::from(-arc_score),
                });
                stats.arcs_before += 1;

                if config.compute_gradients {
                    arc_info.push(ArcInfo {
                        from_state: composed_from,
                        to_state: composed_to,
                        time: t,
                        label,
                        acoustic_score,
                        lm_score,
                        arc_score,
                    });
                }
            }
        }
        frontier = next_frontier;
    }

    // Final counters. No arc is removed once added, so the post-composition arc
    // count equals the number added, whether or not `arc_info` was recorded.
    stats.states_after = state_map.len();
    stats.arcs_after = stats.arcs_before;

    // After the loop, `frontier` holds exactly the states at frame `num_frames`.
    for &(sparse_s, composed_s) in &frontier {
        if sparse.is_final(sparse_s) {
            let final_weight: f64 = sparse.final_weight(sparse_s).into();
            fst.set_final(composed_s, W::from(final_weight));
        }
    }

    if stats.states_before > 0 {
        stats.avg_beam_utilization = stats.states_after as f64 / stats.states_before as f64;
    }

    PrunedComposition {
        wfst: fst,
        state_map: reverse_map,
        forward_scores,
        backward_scores: None,
        arc_info,
        stats,
    }
}
impl<W: Semiring + From<f64> + Into<f64> + Clone> PrunedComposition<W> {
    /// Total log-domain score of the composed graph.
    ///
    /// Log-sums, over every final state, the state's forward score minus its
    /// final weight (final weights are costs). Returns `-inf` when the graph has
    /// no final state.
    pub fn forward_score(&self) -> f64 {
        let mut total = f64::NEG_INFINITY;
        for state in 0..self.wfst.num_states() {
            let state_id = state as StateId;
            if self.wfst.is_final(state_id) {
                let final_weight: f64 = self.wfst.final_weight(state_id).into();
                let state_score = self.forward_scores[state];
                total = log_add(total, state_score - final_weight);
            }
        }
        total
    }

    /// Computes log-domain backward scores for every state and stores them in
    /// `backward_scores`.
    ///
    /// Final states are seeded with the negated final weight. States are then
    /// visited from the highest id to the lowest, so the result is only correct
    /// when every arc leads to a higher-numbered state. Graphs built by
    /// `pruned_compose` satisfy this because states for frame `t + 1` are
    /// created only after all states for frame `t` exist.
    pub fn compute_backward(&mut self) {
        let num_states = self.wfst.num_states();
        let mut backward = vec![f64::NEG_INFINITY; num_states];

        for state in 0..num_states {
            let state_id = state as StateId;
            if self.wfst.is_final(state_id) {
                let final_weight: f64 = self.wfst.final_weight(state_id).into();
                backward[state] = -final_weight;
            }
        }

        for state in (0..num_states).rev() {
            let state_id = state as StateId;
            for tr in self.wfst.transitions(state_id) {
                let next_state = tr.to as usize;
                // Successors that cannot reach a final state contribute nothing.
                if backward[next_state] > f64::NEG_INFINITY {
                    let weight: f64 = tr.weight.clone().into();
                    let new_backward = -weight + backward[next_state];
                    backward[state] = log_add(backward[state], new_backward);
                }
            }
        }

        self.backward_scores = Some(backward);
    }

    /// Gradient of the total log score with respect to the dense frame scores,
    /// scaled by `output_grad`.
    ///
    /// Each recorded arc contributes its posterior
    /// `exp(forward[from] + arc_score + backward[to] - total)` to the cell
    /// `(arc.time, arc.label)`. Backward scores are computed on first use and
    /// cached in `backward_scores`.
    ///
    /// The returned matrix is sized from the recorded arcs (largest frame index
    /// + 1 by largest label + 1), so it is empty when `arc_info` is empty — for
    /// example when the composition was built with `compute_gradients` off.
    ///
    /// When no final state is reachable (`forward_score()` is `-inf`) the
    /// gradient is all zeros: every posterior would otherwise evaluate
    /// `-inf - (-inf)` and come out as NaN.
    pub fn backward(&mut self, output_grad: f64) -> DenseGradient {
        if self.backward_scores.is_none() {
            self.compute_backward();
        }
        let backward = self.backward_scores.as_ref().expect("backward computed");
        let total_log_prob = self.forward_score();

        let num_frames = self.arc_info.iter().map(|a| a.time + 1).max().unwrap_or(0);
        let vocab_size = self
            .arc_info
            .iter()
            .map(|a| a.label as usize + 1)
            .max()
            .unwrap_or(0);
        let mut gradients = DenseGradient::new(num_frames, vocab_size);

        // No accepting path: there is no probability mass to distribute.
        if total_log_prob == f64::NEG_INFINITY {
            return gradients;
        }

        for arc in &self.arc_info {
            let from_score = self.forward_scores[arc.from_state as usize];
            let to_backward = backward[arc.to_state as usize];
            let arc_posterior = (from_score + arc.arc_score + to_backward - total_log_prob).exp();
            gradients.add(arc.time, arc.label as usize, output_grad * arc_posterior);
        }
        gradients
    }
}
/// Dense `num_frames x vocab_size` matrix of gradient values, stored row-major
/// (one row per frame) in `data`.
#[derive(Debug, Clone)]
pub struct DenseGradient {
    /// Number of rows (frames).
    pub num_frames: usize,
    /// Number of columns (labels) per frame.
    pub vocab_size: usize,
    /// Row-major storage of length `num_frames * vocab_size`.
    pub data: Vec<f64>,
}

impl DenseGradient {
    /// Creates a zero-filled gradient matrix of the given shape.
    pub fn new(num_frames: usize, vocab_size: usize) -> Self {
        Self {
            num_frames,
            vocab_size,
            data: vec![0.0; num_frames * vocab_size],
        }
    }

    /// Flattens `(t, v)` into an index into `data`.
    ///
    /// # Panics
    ///
    /// Panics when `t` or `v` lies outside the matrix shape. Without the column
    /// check, `v >= vocab_size` would silently address a cell of a later frame.
    #[inline]
    fn offset(&self, t: usize, v: usize) -> usize {
        assert!(
            t < self.num_frames && v < self.vocab_size,
            "DenseGradient index ({}, {}) out of range for shape ({}, {})",
            t,
            v,
            self.num_frames,
            self.vocab_size
        );
        t * self.vocab_size + v
    }

    /// Returns the value at frame `t`, label `v`.
    ///
    /// # Panics
    ///
    /// Panics when `(t, v)` is outside the matrix shape.
    #[inline]
    pub fn get(&self, t: usize, v: usize) -> f64 {
        self.data[self.offset(t, v)]
    }

    /// Overwrites the value at frame `t`, label `v`.
    ///
    /// # Panics
    ///
    /// Panics when `(t, v)` is outside the matrix shape.
    #[inline]
    pub fn set(&mut self, t: usize, v: usize, value: f64) {
        let at = self.offset(t, v);
        self.data[at] = value;
    }

    /// Adds `value` to the entry at frame `t`, label `v`.
    ///
    /// # Panics
    ///
    /// Panics when `(t, v)` is outside the matrix shape.
    #[inline]
    pub fn add(&mut self, t: usize, v: usize, value: f64) {
        let at = self.offset(t, v);
        self.data[at] += value;
    }
}
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// `-inf` acts as the identity element (log of zero). The larger operand is
/// factored out so the exponential never overflows, and `ln_1p` keeps the
/// contribution of the smaller operand even when `exp(lo - hi)` is below
/// machine epsilon. NaN inputs propagate to the result.
#[inline]
fn log_add(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    if a == b {
        // ln(2 * exp(a)). Also covers `+inf + +inf`, where `a - b` would be NaN.
        return a + std::f64::consts::LN_2;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing gradient values.
    const TOLERANCE: f64 = 1e-10;

    /// The default configuration exposes the documented beam and state cap.
    #[test]
    fn test_pruned_composition_config() {
        let PrunedCompositionConfig {
            beam, max_states, ..
        } = PrunedCompositionConfig::default();
        assert_eq!(beam, 10.0);
        assert_eq!(max_states, 10000);
    }

    /// `set` overwrites a cell and `add` accumulates into it.
    #[test]
    fn test_dense_gradient() {
        let (frame, label) = (0, 50);
        let mut gradient = DenseGradient::new(10, 100);

        gradient.set(frame, label, 0.5);
        let after_set = gradient.get(frame, label);
        assert!((after_set - 0.5).abs() < TOLERANCE);

        gradient.add(frame, label, 0.3);
        let after_add = gradient.get(frame, label);
        assert!((after_add - 0.8).abs() < TOLERANCE);
    }

    /// Statistics are plain data: fields read back exactly as written.
    #[test]
    fn test_pruning_stats() {
        let recorded = PruningStats {
            avg_beam_utilization: 0.1,
            arcs_after: 500,
            arcs_before: 5000,
            states_after: 100,
            states_before: 1000,
        };
        assert_eq!(recorded.states_after, 100);
    }
}