use std::sync::Arc;
use crate::semiring::LogWeight;
/// Index of a state within a [`TransitionMatrix`].
pub type HmmStateId = u32;
/// Index of an acoustic output unit (a position in a frame's score vector).
pub type UnitId = u32;
/// Natural-log transition probability; `f32::NEG_INFINITY` encodes probability 0.
pub type TransitionLogProb = f32;

/// Sparse HMM transition structure over a fixed number of states.
///
/// All probabilities are stored in the natural-log domain. A state counts as
/// initial when its initial log-probability is greater than `-inf`.
#[derive(Clone, Debug)]
pub struct TransitionMatrix {
    /// Number of states; every per-state vector below has exactly this length.
    num_states: usize,
    /// Outgoing arcs per source state, as `(destination, log_prob)` pairs.
    transitions: Vec<Vec<(HmmStateId, TransitionLogProb)>>,
    /// Initial log-probability per state (`-inf` = not an initial state).
    initial_probs: Vec<TransitionLogProb>,
    /// Final-state flag per state.
    is_final: Vec<bool>,
}

impl TransitionMatrix {
    /// Creates a matrix with `num_states` states, no arcs, and no initial or final states.
    pub fn new(num_states: usize) -> Self {
        Self {
            num_states,
            transitions: vec![Vec::new(); num_states],
            initial_probs: vec![f32::NEG_INFINITY; num_states],
            is_final: vec![false; num_states],
        }
    }

    /// Builds a strictly left-to-right topology.
    ///
    /// Every state gets a self-loop with probability `self_loop_prob` and, except
    /// for the last state, an arc to the next state with probability
    /// `1 - self_loop_prob` (clamped at 0). State 0 is the only initial state
    /// (log-prob 0) and the last state is final. `num_states == 0` returns an
    /// empty matrix instead of panicking.
    pub fn left_to_right(num_states: usize, self_loop_prob: f32) -> Self {
        let mut tm = Self::new(num_states);
        if num_states == 0 {
            // No state 0 to mark initial and no last state to mark final.
            return tm;
        }
        tm.initial_probs[0] = 0.0;
        // Clamp the derived mass so `self_loop_prob > 1` yields a `-inf` arc rather than NaN.
        let forward_prob = (1.0 - self_loop_prob).max(0.0);
        let log_self = self_loop_prob.ln();
        let log_forward = forward_prob.ln();
        for i in 0..num_states {
            tm.add_transition(i as HmmStateId, i as HmmStateId, log_self);
            if i + 1 < num_states {
                tm.add_transition(i as HmmStateId, (i + 1) as HmmStateId, log_forward);
            }
        }
        tm.is_final[num_states - 1] = true;
        tm
    }

    /// Builds a Bakis topology (left-to-right with a one-state skip).
    ///
    /// Every state gets a self-loop (`self_prob`). It also gets an arc to the next
    /// state (`forward_prob`) when one exists. When a state two ahead exists, it
    /// gets a skip arc carrying the remaining mass `1 - self_prob - forward_prob`,
    /// clamped at 0 so inconsistent inputs produce `-inf` instead of NaN. State 0
    /// is the only initial state, and the last two states are marked final.
    /// `num_states == 0` returns an empty matrix instead of panicking.
    pub fn bakis(num_states: usize, self_prob: f32, forward_prob: f32) -> Self {
        let mut tm = Self::new(num_states);
        if num_states == 0 {
            return tm;
        }
        tm.initial_probs[0] = 0.0;
        let skip_prob = (1.0 - self_prob - forward_prob).max(0.0);
        let log_self = self_prob.ln();
        let log_forward = forward_prob.ln();
        let log_skip = skip_prob.ln();
        for i in 0..num_states {
            tm.add_transition(i as HmmStateId, i as HmmStateId, log_self);
            if i + 1 < num_states {
                tm.add_transition(i as HmmStateId, (i + 1) as HmmStateId, log_forward);
            }
            if i + 2 < num_states {
                tm.add_transition(i as HmmStateId, (i + 2) as HmmStateId, log_skip);
            }
        }
        tm.is_final[num_states - 1] = true;
        if num_states > 1 {
            tm.is_final[num_states - 2] = true;
        }
        tm
    }

    /// Appends an arc `from -> to` with `log_prob`.
    ///
    /// Arcs whose endpoints are out of range are silently ignored.
    pub fn add_transition(
        &mut self,
        from: HmmStateId,
        to: HmmStateId,
        log_prob: TransitionLogProb,
    ) {
        if (from as usize) < self.num_states && (to as usize) < self.num_states {
            self.transitions[from as usize].push((to, log_prob));
        }
    }

    /// Sets the initial log-probability of `state`; out-of-range states are ignored.
    pub fn set_initial(&mut self, state: HmmStateId, log_prob: TransitionLogProb) {
        if (state as usize) < self.num_states {
            self.initial_probs[state as usize] = log_prob;
        }
    }

    /// Sets whether `state` is final; out-of-range states are ignored.
    pub fn set_final(&mut self, state: HmmStateId, is_final: bool) {
        if (state as usize) < self.num_states {
            self.is_final[state as usize] = is_final;
        }
    }

    /// Number of states in the matrix.
    pub fn num_states(&self) -> usize {
        self.num_states
    }

    /// Outgoing arcs of `state` in insertion order (empty for out-of-range states).
    pub fn transitions_from(&self, state: HmmStateId) -> &[(HmmStateId, TransitionLogProb)] {
        if (state as usize) < self.num_states {
            &self.transitions[state as usize]
        } else {
            &[]
        }
    }

    /// Initial log-probability of `state` (`-inf` for out-of-range states).
    pub fn initial_prob(&self, state: HmmStateId) -> TransitionLogProb {
        if (state as usize) < self.num_states {
            self.initial_probs[state as usize]
        } else {
            f32::NEG_INFINITY
        }
    }

    /// Whether `state` is final (`false` for out-of-range states).
    pub fn is_final(&self, state: HmmStateId) -> bool {
        if (state as usize) < self.num_states {
            self.is_final[state as usize]
        } else {
            false
        }
    }

    /// States whose initial log-probability is greater than `-inf`, in ascending order.
    pub fn initial_states(&self) -> Vec<HmmStateId> {
        self.initial_probs
            .iter()
            .enumerate()
            .filter(|(_, &p)| p > f32::NEG_INFINITY)
            .map(|(i, _)| i as HmmStateId)
            .collect()
    }

    /// States flagged final, in ascending order.
    pub fn final_states(&self) -> Vec<HmmStateId> {
        self.is_final
            .iter()
            .enumerate()
            .filter(|(_, &f)| f)
            .map(|(i, _)| i as HmmStateId)
            .collect()
    }
}
/// A model that scores acoustic feature frames against a set of output units.
///
/// Implementations must be `Send + Sync` so they can be shared across threads.
pub trait AcousticModel: Send + Sync {
    /// Dimensionality of each input feature frame.
    fn feature_dim(&self) -> usize;

    /// Number of output units, i.e. the length of each output score vector.
    fn num_units(&self) -> usize;

    /// Scores a batch of frames.
    ///
    /// Expected to return one score vector per input frame; the default
    /// [`forward_sequence`](AcousticModel::forward_sequence) relies on this.
    fn forward(&self, frames: &[Vec<f32>]) -> Vec<Vec<f32>>;

    /// Scores frames one at a time by calling `forward` with single-frame slices.
    ///
    /// # Panics
    ///
    /// Panics if `forward` returns no output for a single-frame input.
    fn forward_sequence(&self, frames: &[Vec<f32>]) -> Vec<Vec<f32>> {
        frames
            .iter()
            .map(|frame| {
                // Move the single output out instead of deep-cloning it.
                self.forward(std::slice::from_ref(frame))
                    .into_iter()
                    .next()
                    .expect("AcousticModel::forward returned no output for a single input frame")
            })
            .collect()
    }

    /// Optional HMM topology associated with the model; `None` by default.
    fn transition_matrix(&self) -> Option<&TransitionMatrix> {
        None
    }

    /// Optional id of a blank unit; `None` by default.
    fn blank_id(&self) -> Option<UnitId> {
        None
    }

    /// Optional human-readable name for `unit`; `None` by default.
    fn unit_name(&self, unit: UnitId) -> Option<String> {
        let _ = unit;
        None
    }
}
/// Weights and penalties used when fusing acoustic and language-model scores.
#[derive(Clone, Debug)]
pub struct FusionConfig {
    /// Multiplier applied to acoustic log-probabilities.
    pub acoustic_weight: f64,
    /// Multiplier applied to language-model log-probabilities.
    pub lm_weight: f64,
    /// Per-word insertion penalty. NOTE(review): not read by the visible code — confirm its consumer.
    pub word_insertion_penalty: f64,
    /// Blank-unit penalty. NOTE(review): not read by the visible code — confirm its consumer.
    pub blank_penalty: f64,
}

impl Default for FusionConfig {
    /// Acoustic weight 1.0, LM weight 0.5, and both penalties 0.0.
    fn default() -> Self {
        let no_penalty = 0.0;
        FusionConfig {
            lm_weight: 0.5,
            acoustic_weight: 1.0,
            blank_penalty: no_penalty,
            word_insertion_penalty: no_penalty,
        }
    }
}
/// Pairs a shared acoustic model with a shared language model and a [`FusionConfig`].
pub struct AcousticLanguageModel<A: AcousticModel, L> {
    acoustic: Arc<A>,
    language: Arc<L>,
    config: FusionConfig,
}

// Implemented by hand: `#[derive(Clone)]` would require `A: Clone` and
// `L: Clone`, yet cloning only needs to bump the two `Arc` reference counts.
impl<A: AcousticModel, L> Clone for AcousticLanguageModel<A, L> {
    fn clone(&self) -> Self {
        Self {
            acoustic: Arc::clone(&self.acoustic),
            language: Arc::clone(&self.language),
            config: self.config.clone(),
        }
    }
}

impl<A: AcousticModel, L> AcousticLanguageModel<A, L> {
    /// Creates a fused model from shared acoustic and language components.
    pub fn new(acoustic: Arc<A>, language: Arc<L>, config: FusionConfig) -> Self {
        Self {
            acoustic,
            language,
            config,
        }
    }

    /// The acoustic component.
    pub fn acoustic(&self) -> &A {
        &self.acoustic
    }

    /// The language component.
    pub fn language(&self) -> &L {
        &self.language
    }

    /// Current fusion configuration.
    pub fn config(&self) -> &FusionConfig {
        &self.config
    }

    /// Mutable access to the fusion configuration.
    pub fn config_mut(&mut self) -> &mut FusionConfig {
        &mut self.config
    }

    /// Runs the acoustic model's `forward` on `frames`.
    pub fn acoustic_forward(&self, frames: &[Vec<f32>]) -> Vec<Vec<f32>> {
        self.acoustic.forward(frames)
    }

    /// Scales an acoustic log-probability by `config.acoustic_weight`.
    pub fn weight_acoustic(&self, log_prob: f64) -> f64 {
        self.config.acoustic_weight * log_prob
    }

    /// Scales a language-model log-probability by `config.lm_weight`.
    pub fn weight_lm(&self, log_prob: f64) -> f64 {
        self.config.lm_weight * log_prob
    }

    /// Weighted sum `acoustic_weight * am + lm_weight * lm`.
    ///
    /// The penalty fields of the config are not applied here.
    pub fn combine_scores(&self, acoustic_log_prob: f64, lm_log_prob: f64) -> f64 {
        self.weight_acoustic(acoustic_log_prob) + self.weight_lm(lm_log_prob)
    }

    /// Wraps the negated [`combine_scores`](Self::combine_scores) result in a `LogWeight`.
    pub fn to_log_weight(&self, acoustic_log_prob: f64, lm_log_prob: f64) -> LogWeight {
        let combined = self.combine_scores(acoustic_log_prob, lm_log_prob);
        LogWeight::new(-combined)
    }
}
/// Per-unit scores for a single frame.
#[derive(Clone, Debug)]
pub struct FramePosterior {
    /// Index of this frame within its sequence.
    pub frame_idx: usize,
    /// Score per unit, indexed by `UnitId`; read as log-probabilities by `log_prob`.
    pub log_probs: Vec<f32>,
    /// Unit ids cached by the last `compute_top_k` call, best first.
    pub top_k_units: Option<Vec<UnitId>>,
}

/// Total order on scores in which NaN ranks below every number.
///
/// `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaNs appear
/// (NaN would be "equal" to both 1.0 and 2.0). That makes `max_by` results
/// arbitrary and allows `slice::sort_by` to panic on recent Rust versions.
/// Non-NaN values compare exactly as `partial_cmp` does (so `-0.0 == 0.0`).
fn cmp_score_nan_lowest(a: f32, b: f32) -> std::cmp::Ordering {
    match (a.is_nan(), b.is_nan()) {
        (false, false) => a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal),
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Less,
        (false, true) => std::cmp::Ordering::Greater,
    }
}

impl FramePosterior {
    /// Creates a frame with no cached top-k list.
    pub fn new(frame_idx: usize, log_probs: Vec<f32>) -> Self {
        Self {
            frame_idx,
            log_probs,
            top_k_units: None,
        }
    }

    /// Unit with the highest score, or `None` when the frame has no units.
    ///
    /// NaN scores rank below every number. On ties the last maximal unit wins,
    /// following the `Iterator::max_by` convention.
    pub fn best_unit(&self) -> Option<UnitId> {
        self.log_probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| cmp_score_nan_lowest(**a, **b))
            .map(|(i, _)| i as UnitId)
    }

    /// Score of `unit`, or `-inf` when `unit` is out of range.
    pub fn log_prob(&self, unit: UnitId) -> f32 {
        self.log_probs
            .get(unit as usize)
            .copied()
            .unwrap_or(f32::NEG_INFINITY)
    }

    /// Caches the `k` best-scoring unit ids (best first) in `top_k_units`.
    ///
    /// Ties keep the lower unit id first and NaN scores rank last. If `k`
    /// exceeds the number of units, all units are returned.
    pub fn compute_top_k(&mut self, k: usize) {
        let mut ranked: Vec<(usize, f32)> = self.log_probs.iter().copied().enumerate().collect();
        // Higher score first, then lower index. The index tiebreak makes this a
        // strict order, so unstable selection and sorting are deterministic and
        // match a stable descending sort.
        let by_rank = |a: &(usize, f32), b: &(usize, f32)| {
            cmp_score_nan_lowest(b.1, a.1).then_with(|| a.0.cmp(&b.0))
        };
        let k = k.min(ranked.len());
        if k == 0 {
            ranked.clear();
        } else if k < ranked.len() {
            // Move the k best into ranked[..k] in O(n), then sort only those.
            ranked.select_nth_unstable_by(k - 1, by_rank);
            ranked.truncate(k);
        }
        ranked.sort_unstable_by(by_rank);
        self.top_k_units = Some(ranked.into_iter().map(|(i, _)| i as UnitId).collect());
    }
}
/// Frame-by-frame scores for a whole utterance.
#[derive(Clone, Debug)]
pub struct PosteriorSequence {
    /// One entry per frame, in input order.
    pub frames: Vec<FramePosterior>,
    /// Length of the first frame's score vector (0 when there are no frames).
    pub num_units: usize,
}

impl PosteriorSequence {
    /// Wraps raw per-frame score vectors, numbering frames from 0.
    ///
    /// `num_units` is taken from the first frame only; other frames are not checked.
    pub fn from_raw(posteriors: Vec<Vec<f32>>) -> Self {
        let num_units = match posteriors.first() {
            Some(first) => first.len(),
            None => 0,
        };
        let mut frames = Vec::with_capacity(posteriors.len());
        for (frame_idx, scores) in posteriors.into_iter().enumerate() {
            frames.push(FramePosterior::new(frame_idx, scores));
        }
        PosteriorSequence { frames, num_units }
    }

    /// Number of frames.
    pub fn len(&self) -> usize {
        self.frames.len()
    }

    /// `true` when there are no frames.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Frame at `idx`, if any.
    pub fn frame(&self, idx: usize) -> Option<&FramePosterior> {
        self.frames.get(idx)
    }

    /// Best unit of every frame, in frame order.
    ///
    /// Frames with no units contribute nothing, so the path can be shorter than `len()`.
    pub fn greedy_path(&self) -> Vec<UnitId> {
        let mut path = Vec::new();
        for frame in &self.frames {
            if let Some(unit) = frame.best_unit() {
                path.push(unit);
            }
        }
        path
    }

    /// Runs `compute_top_k(k)` on every frame.
    pub fn compute_all_top_k(&mut self, k: usize) {
        self.frames
            .iter_mut()
            .for_each(|frame| frame.compute_top_k(k));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transition_matrix_left_to_right() {
        let tm = TransitionMatrix::left_to_right(3, 0.5);
        assert_eq!(tm.num_states(), 3);
        assert_eq!(tm.initial_states(), vec![0]);
        assert_eq!(tm.final_states(), vec![2]);
        // State 0 has a self-loop and an arc to state 1.
        assert_eq!(tm.transitions_from(0).len(), 2);
    }

    #[test]
    fn test_transition_matrix_bakis() {
        let tm = TransitionMatrix::bakis(4, 0.3, 0.4);
        assert_eq!(tm.num_states(), 4);
        // State 0 has a self-loop, a forward arc and a skip arc.
        assert_eq!(tm.transitions_from(0).len(), 3);
        assert_eq!(tm.final_states(), vec![2, 3]);
    }

    #[test]
    fn test_frame_posterior() {
        let mut fp = FramePosterior::new(0, vec![-1.0, -0.5, -2.0, -0.1]);
        assert_eq!(fp.best_unit(), Some(3));
        assert!((fp.log_prob(1) - (-0.5)).abs() < 1e-6);
        assert_eq!(fp.log_prob(99), f32::NEG_INFINITY);
        fp.compute_top_k(2);
        // Best unit (index 3, -0.1) first, then the runner-up (index 1, -0.5).
        assert_eq!(fp.top_k_units.as_deref(), Some(&[3, 1][..]));
    }

    #[test]
    fn test_posterior_sequence() {
        let posteriors = vec![
            vec![-1.0, -0.5, -2.0],
            vec![-0.1, -0.8, -1.5],
            vec![-0.3, -0.2, -0.9],
        ];
        let seq = PosteriorSequence::from_raw(posteriors);
        assert_eq!(seq.len(), 3);
        assert_eq!(seq.num_units, 3);
        assert_eq!(seq.greedy_path(), vec![1, 0, 1]);
    }

    #[test]
    fn test_fusion_config_default() {
        let config = FusionConfig::default();
        assert!((config.acoustic_weight - 1.0).abs() < 1e-6);
        assert!((config.lm_weight - 0.5).abs() < 1e-6);
        assert!((config.word_insertion_penalty - 0.0).abs() < 1e-6);
        assert!((config.blank_penalty - 0.0).abs() < 1e-6);
    }

    /// Test double that emits a uniform log-distribution over `num_units` for every frame.
    struct MockAcousticModel {
        feature_dim: usize,
        num_units: usize,
    }

    impl AcousticModel for MockAcousticModel {
        fn feature_dim(&self) -> usize {
            self.feature_dim
        }

        fn num_units(&self) -> usize {
            self.num_units
        }

        fn forward(&self, frames: &[Vec<f32>]) -> Vec<Vec<f32>> {
            // log(1 / N) = -ln(N): the negation must sit outside `ln`, since the
            // log of a negative number is NaN.
            let log_prob = -(self.num_units as f32).ln();
            frames
                .iter()
                .map(|_| vec![log_prob; self.num_units])
                .collect()
        }
    }

    #[test]
    fn test_acoustic_model_trait() {
        let model = MockAcousticModel {
            feature_dim: 40,
            num_units: 100,
        };
        assert_eq!(model.feature_dim(), 40);
        assert_eq!(model.num_units(), 100);
        let frames = vec![vec![0.0f32; 40]; 5];
        let posteriors = model.forward(&frames);
        assert_eq!(posteriors.len(), 5);
        assert_eq!(posteriors[0].len(), 100);
        // The mock must emit a proper log-distribution.
        assert!(posteriors.iter().flatten().all(|p| p.is_finite()));
        let total: f32 = posteriors[0].iter().map(|p| p.exp()).sum();
        assert!((total - 1.0).abs() < 1e-4);
        assert_eq!(model.forward_sequence(&frames), posteriors);
    }

    #[test]
    fn test_acoustic_language_model() {
        let acoustic = Arc::new(MockAcousticModel {
            feature_dim: 40,
            num_units: 100,
        });
        let language: Arc<()> = Arc::new(());
        let config = FusionConfig {
            acoustic_weight: 1.0,
            lm_weight: 0.5,
            ..Default::default()
        };
        let alm = AcousticLanguageModel::new(acoustic, language, config);
        let am_score = -2.0;
        let lm_score = -1.0;
        // 1.0 * -2.0 + 0.5 * -1.0 = -2.5
        let combined = alm.combine_scores(am_score, lm_score);
        assert!((combined - (-2.5)).abs() < 1e-6);
        let lw = alm.to_log_weight(am_score, lm_score);
        assert!((lw.value() - 2.5).abs() < 1e-6);
    }
}