use super::{
AutoregressivePredictor, EncoderOutput, JointNetwork, Label, PredictorState, TransducerConfig,
TransducerStats, BLANK,
};
use crate::semiring::Semiring;
use crate::wfst::{StateId, VectorWfst, Wfst};
use std::cmp::Ordering;
use std::collections::HashMap;
/// One partial decoding path tracked by the decoders in this file.
#[derive(Debug, Clone)]
pub struct Hypothesis {
/// Labels appended so far; `extend`/`extend_with_lm` never append `BLANK`.
pub labels: Vec<Label>,
/// Running sum of the score deltas passed to `extend`/`extend_with_lm` (starts at 0.0).
pub score: f32,
/// Predictor state supplied to the constructor or to the most recent extension.
pub predictor_state: PredictorState,
/// Language-model state; `None` unless set by `initial_with_lm` or `extend_with_lm`.
pub lm_state: Option<StateId>,
/// Number of extensions applied since the initial hypothesis (each `extend*` call adds one).
timestep: usize,
}
impl Hypothesis {
    /// Creates the empty starting hypothesis: no labels, score 0.0, no LM state.
    pub fn initial(predictor_state: PredictorState) -> Self {
        Self {
            labels: Vec::new(),
            score: 0.0,
            predictor_state,
            lm_state: None,
            timestep: 0,
        }
    }

    /// Same as [`Hypothesis::initial`], but starts with `lm_start` as the LM state.
    pub fn initial_with_lm(predictor_state: PredictorState, lm_start: StateId) -> Self {
        Self {
            lm_state: Some(lm_start),
            ..Self::initial(predictor_state)
        }
    }

    /// Returns a new hypothesis one step further along.
    ///
    /// `label` is appended unless it is `BLANK`, `score_delta` is added to the
    /// score, the predictor state is replaced, and the LM state is carried over
    /// unchanged.
    pub fn extend(
        &self,
        label: Label,
        score_delta: f32,
        new_predictor_state: PredictorState,
    ) -> Self {
        Self {
            labels: self.labels_after(label),
            score: self.score + score_delta,
            predictor_state: new_predictor_state,
            lm_state: self.lm_state,
            timestep: self.timestep + 1,
        }
    }

    /// Like [`Hypothesis::extend`], but also moves the LM state to `new_lm_state`.
    pub fn extend_with_lm(
        &self,
        label: Label,
        score_delta: f32,
        new_predictor_state: PredictorState,
        new_lm_state: StateId,
    ) -> Self {
        Self {
            lm_state: Some(new_lm_state),
            ..self.extend(label, score_delta, new_predictor_state)
        }
    }

    /// Copies the current labels and appends `label` when it is not `BLANK`.
    fn labels_after(&self, label: Label) -> Vec<Label> {
        let emitted = if label == BLANK { None } else { Some(label) };
        self.labels.iter().copied().chain(emitted).collect()
    }
}
/// Two hypotheses are equal when their scores rank the same under [`Ord`].
///
/// Equality is derived from `cmp` so that `a == b` holds exactly when
/// `a.cmp(&b) == Ordering::Equal`, as the `Eq`/`Ord` contracts require. Only
/// the score takes part in the comparison; labels and states are ignored.
impl PartialEq for Hypothesis {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Hypothesis {}

impl PartialOrd for Hypothesis {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Reversed score order: a hypothesis with a HIGHER score compares as `Less`.
///
/// NaN scores are placed after every real score (and compare equal to each
/// other), which keeps the order total. Comparing floats with
/// `partial_cmp(..).unwrap_or(Equal)` alone would make NaN "equal" to every
/// score while those scores differ from one another.
impl Ord for Hypothesis {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .is_nan()
            .cmp(&other.score.is_nan())
            .then_with(|| {
                // Operands are swapped on purpose: larger score sorts first.
                other
                    .score
                    .partial_cmp(&self.score)
                    .unwrap_or(Ordering::Equal)
            })
    }
}
/// Offline transducer decoder offering greedy and beam-search decoding over a
/// complete `EncoderOutput`.
#[derive(Debug)]
pub struct TransducerDecoder<P: AutoregressivePredictor, J: JointNetwork> {
/// Prediction network; advanced one label at a time through `step`.
predictor: P,
/// Joint network; `forward` is called with an encoder frame and a predictor output.
joiner: J,
/// Decoding options; the methods in this file read `beam_width` and `max_symbols_per_frame`.
config: TransducerConfig,
}
impl<P: AutoregressivePredictor, J: JointNetwork> TransducerDecoder<P, J> {
    /// Creates a decoder that owns the given predictor, joiner and configuration.
    pub fn new(predictor: P, joiner: J, config: TransducerConfig) -> Self {
        Self {
            predictor,
            joiner,
            config,
        }
    }

    /// Greedy decoding: for every encoder frame, repeatedly take the
    /// highest-scoring joiner output until it is `BLANK` or
    /// `config.max_symbols_per_frame` labels were emitted for that frame.
    ///
    /// The returned score is the sum of the best log-probability of every
    /// joiner call, including the blank that ends a frame.
    ///
    /// # Panics
    ///
    /// Panics if the joiner returns an empty output.
    pub fn greedy_decode(&self, encoder_out: &EncoderOutput) -> DecodingResult {
        let mut labels = Vec::new();
        let mut score = 0.0f32;
        // Prime the predictor with label 0 (presumably the blank/start symbol --
        // confirm). The output vector is taken as returned instead of being
        // copied into a buffer sized by `output_dim()`, so a mismatch between
        // the two can no longer panic.
        let (mut predictor_state, mut predictor_out) = self
            .predictor
            .step(&self.predictor.initial_state(), 0);
        for t in 0..encoder_out.num_frames {
            let enc_frame = encoder_out.frame(t);
            let mut symbols_this_frame = 0;
            loop {
                let log_probs = self.joiner.forward(enc_frame, &predictor_out);
                let (best_label, best_prob) = log_probs
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
                    .map(|(i, &p)| (i as Label, p))
                    .expect("log_probs must not be empty");
                score += best_prob;
                if best_label == BLANK {
                    break;
                }
                labels.push(best_label);
                let (next_state, next_out) = self.predictor.step(&predictor_state, best_label);
                predictor_state = next_state;
                predictor_out = next_out;
                symbols_this_frame += 1;
                if symbols_this_frame >= self.config.max_symbols_per_frame {
                    break;
                }
            }
        }
        DecodingResult {
            labels,
            score,
            stats: TransducerStats::default(),
        }
    }

    /// Beam search in which every hypothesis is extended by exactly one symbol
    /// (blank or a label) per encoder frame.
    ///
    /// After each frame, hypotheses with identical label sequences are merged
    /// (best score wins) and only then is the beam cut to `config.beam_width`,
    /// so duplicates no longer use up beam slots. Results are returned best
    /// score first.
    pub fn beam_decode(&self, encoder_out: &EncoderOutput) -> Vec<DecodingResult> {
        let beam_width = self.config.beam_width;
        let mut hypotheses: Vec<Hypothesis> =
            vec![Hypothesis::initial(self.predictor.initial_state())];
        // Predictor outputs keyed by label sequence; filled lazily.
        let mut predictor_cache: HashMap<Vec<Label>, Vec<f32>> = HashMap::new();
        for t in 0..encoder_out.num_frames {
            let enc_frame = encoder_out.frame(t);
            let mut candidates: Vec<Hypothesis> = Vec::new();
            for hyp in &hypotheses {
                let predictor_out = predictor_cache
                    .entry(hyp.labels.clone())
                    .or_insert_with(|| self.predictor_output(hyp));
                let log_probs = self.joiner.forward(enc_frame, predictor_out);
                for (index, &log_prob) in log_probs.iter().enumerate() {
                    let label = index as Label;
                    // Blank keeps the predictor state; a real label advances it.
                    let next_state = if label == BLANK {
                        hyp.predictor_state.clone()
                    } else {
                        self.predictor.step(&hyp.predictor_state, label).0
                    };
                    candidates.push(hyp.extend(label, log_prob, next_state));
                }
            }
            hypotheses = Self::prune_beam(candidates, beam_width);
        }
        Self::into_results(hypotheses)
    }

    /// Beam search with shallow fusion of a WFST language model.
    ///
    /// Per frame each hypothesis is extended by blank (acoustic score only) and
    /// by every labelled LM transition leaving its LM state, scored as
    /// `acoustic + lm_weight * transition_weight`. Transitions without an input
    /// label, with label 0, or with a label outside the joiner output are
    /// skipped. After the last frame, `lm_weight * final_weight` is added for
    /// hypotheses whose LM state is final. Results are returned best score first.
    ///
    /// # Panics
    ///
    /// Panics if a hypothesis has no LM state or if the joiner output is too
    /// short to contain the `BLANK` entry.
    pub fn beam_decode_with_lm<W>(
        &self,
        encoder_out: &EncoderOutput,
        lm: &VectorWfst<Label, W>,
        lm_weight: f32,
    ) -> Vec<DecodingResult>
    where
        W: Semiring + Into<f32> + Clone,
    {
        let beam_width = self.config.beam_width;
        let lm_start = lm.start();
        let mut hypotheses: Vec<Hypothesis> = vec![Hypothesis::initial_with_lm(
            self.predictor.initial_state(),
            lm_start,
        )];
        for t in 0..encoder_out.num_frames {
            let enc_frame = encoder_out.frame(t);
            let mut candidates: Vec<Hypothesis> = Vec::new();
            for hyp in &hypotheses {
                let predictor_out = self.predictor_output(hyp);
                let log_probs = self.joiner.forward(enc_frame, &predictor_out);
                let lm_state = hyp.lm_state.expect("LM state must exist");
                let blank_prob = log_probs[BLANK as usize];
                candidates.push(hyp.extend(BLANK, blank_prob, hyp.predictor_state.clone()));
                for tr in lm.transitions(lm_state) {
                    let label = match tr.input {
                        Some(l) => l,
                        None => continue,
                    };
                    // NOTE(review): literal 0 is presumably BLANK/epsilon -- confirm.
                    if label == 0 || label as usize >= log_probs.len() {
                        continue;
                    }
                    let acoustic_prob = log_probs[label as usize];
                    let lm_prob: f32 = tr.weight.clone().into();
                    let combined_prob = acoustic_prob + lm_weight * lm_prob;
                    let (new_pred_state, _) = self.predictor.step(&hyp.predictor_state, label);
                    candidates.push(hyp.extend_with_lm(
                        label,
                        combined_prob,
                        new_pred_state,
                        tr.to,
                    ));
                }
            }
            hypotheses = Self::prune_beam(candidates, beam_width);
        }
        for hyp in &mut hypotheses {
            if let Some(lm_state) = hyp.lm_state {
                if lm.is_final(lm_state) {
                    let final_weight: f32 = lm.final_weight(lm_state).into();
                    hyp.score += lm_weight * final_weight;
                }
            }
        }
        // Final weights may change the ranking, so restore best-first order.
        hypotheses.sort_by(|a, b| Self::best_first(a, b));
        Self::into_results(hypotheses)
    }

    /// Predictor output used to score the extensions of `hyp`.
    ///
    /// An empty hypothesis steps the initial state with label 0; otherwise the
    /// stored state is stepped with the last emitted label.
    ///
    /// NOTE(review): for a non-empty hypothesis, `predictor_state` was itself
    /// produced by stepping on that last label (see the expansion loops), so
    /// the label appears to be fed twice, whereas `greedy_decode` reuses the
    /// output of the step that produced the state. Confirm against the
    /// `AutoregressivePredictor::step` contract before changing this.
    fn predictor_output(&self, hyp: &Hypothesis) -> Vec<f32> {
        let (_, out) = match hyp.labels.last() {
            None => self.predictor.step(&self.predictor.initial_state(), 0),
            Some(&last_label) => self.predictor.step(&hyp.predictor_state, last_label),
        };
        out
    }

    /// Total order for ranking: higher score first, NaN scores last, ties
    /// broken by label sequence so the outcome is deterministic.
    fn best_first(a: &Hypothesis, b: &Hypothesis) -> Ordering {
        a.score
            .is_nan()
            .cmp(&b.score.is_nan())
            .then_with(|| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal))
            .then_with(|| a.labels.cmp(&b.labels))
    }

    /// Merges duplicate label sequences, ranks best-first and keeps at most
    /// `beam_width` hypotheses.
    fn prune_beam(candidates: Vec<Hypothesis>, beam_width: usize) -> Vec<Hypothesis> {
        let mut survivors = merge_hypotheses(candidates);
        survivors.sort_by(|a, b| Self::best_first(a, b));
        survivors.truncate(beam_width);
        survivors
    }

    /// Converts hypotheses into results with default statistics, preserving order.
    fn into_results(hypotheses: Vec<Hypothesis>) -> Vec<DecodingResult> {
        hypotheses
            .into_iter()
            .map(|hyp| DecodingResult {
                labels: hyp.labels,
                score: hyp.score,
                stats: TransducerStats::default(),
            })
            .collect()
    }
}
/// Collapses hypotheses that share the same label sequence, keeping the one
/// with the strictly highest score (the earlier one wins ties).
///
/// The output order is deterministic: each surviving hypothesis sits at the
/// position where its label sequence first appeared in the input. Collecting
/// straight out of a `HashMap` would instead return them in the map's
/// iteration order, which differs from run to run.
fn merge_hypotheses(hypotheses: Vec<Hypothesis>) -> Vec<Hypothesis> {
    // Maps a label sequence to its slot in `merged`.
    let mut slot_of: HashMap<Vec<Label>, usize> = HashMap::with_capacity(hypotheses.len());
    let mut merged: Vec<Hypothesis> = Vec::with_capacity(hypotheses.len());
    for hyp in hypotheses {
        match slot_of.get(&hyp.labels).copied() {
            Some(slot) => {
                if hyp.score > merged[slot].score {
                    // Move the better hypothesis in; no clone needed.
                    merged[slot] = hyp;
                }
            }
            None => {
                slot_of.insert(hyp.labels.clone(), merged.len());
                merged.push(hyp);
            }
        }
    }
    merged
}
/// Outcome of a decoding pass: the emitted labels, their score and statistics.
#[derive(Debug, Clone)]
pub struct DecodingResult {
/// Decoded label sequence (the decoders in this file never store `BLANK` here).
pub labels: Vec<Label>,
/// Accumulated score of the sequence as computed by the decoder that produced it.
pub score: f32,
/// Decoding statistics; most code paths in this file fill it with `TransducerStats::default()`.
pub stats: TransducerStats,
}
/// Frame-by-frame beam decoder that keeps its beam between calls and reports
/// labels once every live hypothesis agrees on them.
#[derive(Debug)]
pub struct StreamingTransducerDecoder<P: AutoregressivePredictor, J: JointNetwork> {
/// Prediction network; advanced one label at a time through `step`.
predictor: P,
/// Joint network; `forward` is called with an encoder frame and a predictor output.
joiner: J,
/// Decoding options; this decoder reads `beam_width`.
config: TransducerConfig,
/// Current beam; starts as a single initial hypothesis and is replaced by every `process_frame` call.
hypotheses: Vec<Hypothesis>,
/// Number of `process_frame` calls since construction or the last `reset`.
frames_processed: usize,
/// Labels already returned by `process_frame` (the prefix shared by all hypotheses).
finalized: Vec<Label>,
}
impl<P: AutoregressivePredictor, J: JointNetwork> StreamingTransducerDecoder<P, J> {
    /// Creates a streaming decoder whose beam holds a single empty hypothesis.
    pub fn new(predictor: P, joiner: J, config: TransducerConfig) -> Self {
        let initial_hyp = Hypothesis::initial(predictor.initial_state());
        Self {
            predictor,
            joiner,
            config,
            hypotheses: vec![initial_hyp],
            frames_processed: 0,
            finalized: Vec::new(),
        }
    }

    /// Consumes one encoder frame and returns the labels that became final.
    ///
    /// Every hypothesis in the beam is extended by each joiner output (blank or
    /// label). Duplicate label sequences are merged first and the beam is then
    /// cut to `config.beam_width`, so duplicates do not use up beam slots. A
    /// label is final once it belongs to the prefix shared by all surviving
    /// hypotheses; only labels not reported by earlier calls are returned.
    pub fn process_frame(&mut self, enc_frame: &[f32]) -> Vec<Label> {
        let beam_width = self.config.beam_width;
        let mut candidates: Vec<Hypothesis> = Vec::new();
        for hyp in &self.hypotheses {
            let predictor_out = self.predictor_output(hyp);
            let log_probs = self.joiner.forward(enc_frame, &predictor_out);
            for (index, &log_prob) in log_probs.iter().enumerate() {
                let label = index as Label;
                // Blank keeps the predictor state; a real label advances it.
                let next_state = if label == BLANK {
                    hyp.predictor_state.clone()
                } else {
                    self.predictor.step(&hyp.predictor_state, label).0
                };
                candidates.push(hyp.extend(label, log_prob, next_state));
            }
        }
        self.hypotheses = Self::prune_beam(candidates, beam_width);
        self.frames_processed += 1;

        // Length of the label prefix shared by every surviving hypothesis.
        let stable_len = match self.hypotheses.split_first() {
            Some((first, rest)) => rest.iter().fold(first.labels.len(), |acc, other| {
                common_prefix_len(&first.labels, &other.labels).min(acc)
            }),
            None => return Vec::new(),
        };
        if stable_len <= self.finalized.len() {
            return Vec::new();
        }
        let new_labels = self.hypotheses[0].labels[self.finalized.len()..stable_len].to_vec();
        self.finalized.extend_from_slice(&new_labels);
        new_labels
    }

    /// Returns the best hypothesis currently in the beam.
    ///
    /// If the beam is empty, the labels finalized so far are returned with a
    /// score of 0.0. In both cases `stats.num_frames` reports the number of
    /// frames processed.
    pub fn finalize(&self) -> DecodingResult {
        let stats = TransducerStats {
            num_frames: self.frames_processed,
            ..Default::default()
        };
        match self
            .hypotheses
            .iter()
            .min_by(|a, b| Self::best_first(a, b))
        {
            Some(best) => DecodingResult {
                labels: best.labels.clone(),
                score: best.score,
                stats,
            },
            None => DecodingResult {
                labels: self.finalized.clone(),
                score: 0.0,
                stats,
            },
        }
    }

    /// Discards all decoding progress and restores the freshly constructed state.
    pub fn reset(&mut self) {
        self.hypotheses = vec![Hypothesis::initial(self.predictor.initial_state())];
        self.frames_processed = 0;
        self.finalized.clear();
    }

    /// Predictor output used to score the extensions of `hyp`.
    ///
    /// An empty hypothesis steps the initial state with label 0; otherwise the
    /// stored state is stepped with the last emitted label.
    ///
    /// NOTE(review): for a non-empty hypothesis, `predictor_state` was itself
    /// produced by stepping on that last label, so the label appears to be fed
    /// twice. Confirm against the `AutoregressivePredictor::step` contract
    /// before changing this.
    fn predictor_output(&self, hyp: &Hypothesis) -> Vec<f32> {
        let (_, out) = match hyp.labels.last() {
            None => self.predictor.step(&self.predictor.initial_state(), 0),
            Some(&last_label) => self.predictor.step(&hyp.predictor_state, last_label),
        };
        out
    }

    /// Total order for ranking: higher score first, NaN scores last, ties
    /// broken by label sequence so the outcome is deterministic.
    fn best_first(a: &Hypothesis, b: &Hypothesis) -> Ordering {
        a.score
            .is_nan()
            .cmp(&b.score.is_nan())
            .then_with(|| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal))
            .then_with(|| a.labels.cmp(&b.labels))
    }

    /// Merges duplicate label sequences, ranks best-first and keeps at most
    /// `beam_width` hypotheses.
    fn prune_beam(candidates: Vec<Hypothesis>, beam_width: usize) -> Vec<Hypothesis> {
        let mut survivors = merge_hypotheses(candidates);
        survivors.sort_by(|a, b| Self::best_first(a, b));
        survivors.truncate(beam_width);
        survivors
    }
}
/// Counts how many leading labels `a` and `b` have in common.
fn common_prefix_len(a: &[Label], b: &[Label]) -> usize {
    // Index of the first mismatch, or the shorter length when one slice is a
    // prefix of the other.
    a.iter()
        .zip(b)
        .position(|(left, right)| left != right)
        .unwrap_or_else(|| a.len().min(b.len()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a label-free hypothesis that carries only the given score.
    fn hypothesis_with_score(score: f32) -> Hypothesis {
        Hypothesis {
            labels: vec![],
            score,
            predictor_state: PredictorState::default(),
            lm_state: None,
            timestep: 0,
        }
    }

    #[test]
    fn test_hypothesis_ordering() {
        let better = hypothesis_with_score(-1.0);
        let worse = hypothesis_with_score(-2.0);
        // The ordering is reversed: the higher score compares as smaller.
        assert!(better < worse);
    }

    #[test]
    fn test_common_prefix_len() {
        let check = |a: &[Label], b: &[Label], expected: usize| {
            assert_eq!(common_prefix_len(a, b), expected);
        };
        check(&[1, 2, 3], &[1, 2, 4], 2);
        check(&[1, 2, 3], &[1, 2, 3], 3);
        check(&[1, 2, 3], &[4, 5, 6], 0);
        check(&[], &[1, 2, 3], 0);
    }
}