use crate::semiring::Semiring;
use crate::wfst::{StateId, VectorWfst};
use std::fmt::Debug;
/// Identifier of an output symbol (vocabulary entry) in the transducer lattice.
pub type Label = u32;
/// Index of an encoder output frame (the time axis `t`).
pub type FrameIndex = usize;
/// Label id reserved for the blank symbol. `TransducerLattice::to_wfst` emits it on
/// arcs that keep the output position `u` unchanged; non-blank labels start at 1.
pub const BLANK: Label = 0;
/// Interface of the acoustic-encoder component of a transducer model.
///
/// Implementations must be `Send + Sync + Debug`.
pub trait AcousticEncoder: Send + Sync + Debug {
/// Width of each output frame (presumably equal to `EncoderOutput::dim` — confirm).
fn output_dim(&self) -> usize;
/// Number of output frames produced for an input of length `input_length`;
/// any subsampling is implementation-defined.
fn output_length(&self, input_length: usize) -> usize;
/// Returns the embedding of frame `t` from `encoder_out`.
///
/// NOTE(review): by Rust's lifetime-elision rules the returned slice borrows
/// `self`, not `encoder_out`, so an implementation cannot simply return
/// `encoder_out.frame(t)`. Consider
/// `fn get_frame<'a>(&self, encoder_out: &'a EncoderOutput, t: FrameIndex) -> &'a [f32]`
/// (a breaking change for existing implementors).
fn get_frame(&self, encoder_out: &EncoderOutput, t: FrameIndex) -> &[f32];
}
/// Dense, row-major buffer of encoder activations.
///
/// Frame `t` occupies `data[t * dim .. (t + 1) * dim]`.
#[derive(Debug, Clone)]
pub struct EncoderOutput {
    /// Flattened frame values; `num_frames * dim` entries are expected.
    pub data: Vec<f32>,
    /// Number of frames stored in `data`.
    pub num_frames: usize,
    /// Width of a single frame.
    pub dim: usize,
}

impl EncoderOutput {
    /// Wraps a flat buffer holding `num_frames` frames of width `dim`.
    ///
    /// In debug builds this panics when `data.len() != num_frames * dim`.
    pub fn new(data: Vec<f32>, num_frames: usize, dim: usize) -> Self {
        debug_assert_eq!(data.len(), num_frames * dim);
        Self { data, num_frames, dim }
    }

    /// Borrows the `dim` values that make up frame `t`.
    ///
    /// # Panics
    /// Panics if frame `t` would extend past the end of `data`.
    #[inline]
    pub fn frame(&self, t: FrameIndex) -> &[f32] {
        // Slice off everything before the frame, then keep exactly one frame's worth.
        let tail = &self.data[t * self.dim..];
        &tail[..self.dim]
    }

    /// Number of stored frames.
    #[inline]
    pub fn len(&self) -> usize {
        self.num_frames
    }

    /// Returns `true` when no frames are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Interface of the autoregressive predictor (label-history model) of a transducer.
///
/// Implementations must be `Send + Sync + Debug`.
pub trait AutoregressivePredictor: Send + Sync + Debug {
/// Width of each predictor output vector.
fn output_dim(&self) -> usize;
/// State to start from before any token has been fed (implementation-defined).
fn initial_state(&self) -> PredictorState;
/// Feeds `token` given `state`, returning the successor state together with a
/// vector of floats (presumably the predictor output for that step — confirm).
fn step(&self, state: &PredictorState, token: Label) -> (PredictorState, Vec<f32>);
/// Returns the output vector at position `u` of `predictor_out`.
///
/// NOTE(review): by lifetime elision the returned slice borrows `self`, not
/// `predictor_out`, so returning `predictor_out.position(u)` will not compile;
/// an explicit `'a` tying the result to `predictor_out` would fix this but is a
/// breaking change for implementors.
fn get_output(&self, predictor_out: &PredictorOutput, u: usize) -> &[f32];
}
/// Recurrent state handed between predictor steps.
///
/// `num_tokens` is a token counter; both constructors start it at 0 (presumably
/// implementations of `AutoregressivePredictor::step` advance it — confirm).
#[derive(Debug, Clone, Default)]
pub struct PredictorState {
    /// Hidden-state vector.
    pub hidden: Vec<f32>,
    /// Cell-state vector.
    pub cell: Vec<f32>,
    /// Token counter, 0 for freshly constructed states.
    pub num_tokens: usize,
}

impl PredictorState {
    /// Builds a state from explicit hidden/cell vectors with a zero token count.
    pub fn new(hidden: Vec<f32>, cell: Vec<f32>) -> Self {
        let num_tokens = 0;
        Self {
            hidden,
            cell,
            num_tokens,
        }
    }

    /// Builds a state whose hidden and cell vectors are `dim` zeros.
    pub fn zeros(dim: usize) -> Self {
        Self::new(vec![0.0; dim], vec![0.0; dim])
    }
}
/// Row-major buffer of predictor outputs, one `dim`-wide row per output position `u`.
///
/// Row `u` occupies `data[u * dim .. (u + 1) * dim]`.
#[derive(Debug, Clone)]
pub struct PredictorOutput {
    /// Flattened rows; `num_positions * dim` entries are expected.
    pub data: Vec<f32>,
    /// Number of rows stored in `data`.
    pub num_positions: usize,
    /// Width of a single row.
    pub dim: usize,
}

impl PredictorOutput {
    /// Wraps `data` as `num_positions` rows of `dim` values.
    ///
    /// In debug builds this panics when `data.len() != num_positions * dim`.
    pub fn new(data: Vec<f32>, num_positions: usize, dim: usize) -> Self {
        debug_assert_eq!(data.len(), num_positions * dim);
        Self {
            data,
            num_positions,
            dim,
        }
    }

    /// Borrows row `u`.
    ///
    /// # Panics
    /// Panics if row `u` would extend past the end of `data`.
    #[inline]
    pub fn position(&self, u: usize) -> &[f32] {
        let first = u * self.dim;
        let row = first..first + self.dim;
        &self.data[row]
    }

    /// Number of stored rows.
    #[inline]
    pub fn len(&self) -> usize {
        self.num_positions
    }

    /// Returns `true` when no rows are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Joint network that combines one encoder frame with one predictor output.
///
/// Implementations must be `Send + Sync + Debug`.
pub trait JointNetwork: Send + Sync + Debug {
    /// Size of the output vocabulary (presumably the length of each vector
    /// returned by [`forward`](Self::forward) — confirm with implementors).
    fn vocab_size(&self) -> usize;

    /// Combines a single encoder frame and a single predictor output into one
    /// score vector.
    fn forward(&self, encoder_frame: &[f32], predictor_output: &[f32]) -> Vec<f32>;

    /// Applies [`forward`](Self::forward) pairwise over two batches of equal size.
    ///
    /// Element `i` of the result is `forward(encoder_frames[i], predictor_outputs[i])`.
    /// The default implementation loops over the pairs; implementors may override
    /// it with a genuinely batched kernel.
    ///
    /// # Panics
    /// Panics if `encoder_frames.len() != predictor_outputs.len()`.
    fn forward_batch(
        &self,
        encoder_frames: &[&[f32]],
        predictor_outputs: &[&[f32]],
    ) -> Vec<Vec<f32>> {
        // `zip` would silently stop at the shorter slice and drop inputs, hiding a
        // caller bug; reject mismatched batches explicitly instead.
        assert_eq!(
            encoder_frames.len(),
            predictor_outputs.len(),
            "forward_batch: encoder and predictor batch sizes differ"
        );
        encoder_frames
            .iter()
            .zip(predictor_outputs)
            .map(|(enc, pred)| self.forward(enc, pred))
            .collect()
    }
}
/// A complete transducer model made of an encoder, a predictor and a joiner.
///
/// Implementations must be `Send + Sync + Debug`.
pub trait NeuralTransducer: Send + Sync + Debug {
/// Concrete acoustic encoder type.
type Encoder: AcousticEncoder;
/// Concrete predictor type.
type Predictor: AutoregressivePredictor;
/// Concrete joint-network type.
type Joiner: JointNetwork;
/// Borrows the encoder component.
fn encoder(&self) -> &Self::Encoder;
/// Borrows the predictor component.
fn predictor(&self) -> &Self::Predictor;
/// Borrows the joiner component.
fn joiner(&self) -> &Self::Joiner;
/// Vocabulary size of the model; by default delegated to the joiner.
fn vocab_size(&self) -> usize {
self.joiner().vocab_size()
}
/// Builds a [`TransducerLattice`] from precomputed encoder and predictor outputs
/// (presumably by scoring every `(t, u)` pair with the joiner — implementation-defined).
fn build_lattice<W: Semiring + From<f64>>(
&self,
encoder_out: &EncoderOutput,
predictor_out: &PredictorOutput,
) -> TransducerLattice<W>;
}
/// Tunable parameters for transducer decoding.
///
/// These values are not read anywhere in this module; the meanings below follow
/// the field names and should be confirmed against the decoder that consumes them.
#[derive(Debug, Clone)]
pub struct TransducerConfig {
    /// Beam width for the search. Default: 10.
    pub beam_width: usize,
    /// Cap on simultaneously active hypotheses. Default: 1000.
    pub max_active: usize,
    /// Score margin used for pruning. Default: 10.0.
    pub pruning_threshold: f32,
    /// Whether the joiner should be called in batches. Default: `true`.
    pub use_batch_joiner: bool,
    /// Cap on labels emitted per encoder frame. Default: 10.
    pub max_symbols_per_frame: usize,
}

impl Default for TransducerConfig {
    fn default() -> Self {
        let beam_width = 10;
        let max_active = 1000;
        let pruning_threshold = 10.0;
        let use_batch_joiner = true;
        let max_symbols_per_frame = 10;
        Self {
            beam_width,
            max_active,
            pruning_threshold,
            use_batch_joiner,
            max_symbols_per_frame,
        }
    }
}
/// Dense score table for a transducer lattice.
///
/// Holds one `f64` per `(frame t, output position u, label)` triple, flattened
/// row-major as `(t * num_positions + u) * vocab_size + label`. Every entry starts
/// at `f64::NEG_INFINITY`, which [`TransducerLattice::to_wfst`] treats as "no arc".
/// `W` only selects the weight type of the WFST built by `to_wfst`; no `W` values
/// are stored.
#[derive(Debug, Clone)]
pub struct TransducerLattice<W: Semiring> {
    /// Number of encoder frames `T`.
    pub num_frames: usize,
    /// Number of output positions `U`.
    pub num_positions: usize,
    /// Number of labels per `(t, u)` cell, including [`BLANK`].
    pub vocab_size: usize,
    /// Flattened scores (log-probabilities per the field name); layout as above.
    pub log_probs: Vec<f64>,
    _phantom: std::marker::PhantomData<W>,
}

impl<W: Semiring> TransducerLattice<W> {
    /// Creates a `num_frames x num_positions x vocab_size` lattice with every
    /// entry set to `f64::NEG_INFINITY`.
    pub fn new(num_frames: usize, num_positions: usize, vocab_size: usize) -> Self {
        let size = num_frames * num_positions * vocab_size;
        Self {
            num_frames,
            num_positions,
            vocab_size,
            log_probs: vec![f64::NEG_INFINITY; size],
            _phantom: std::marker::PhantomData,
        }
    }

    /// Stores `log_prob` for `label` at `(t, u)`.
    ///
    /// # Panics
    /// Panics if the flattened index is past the end of the table; in debug
    /// builds also if `t`, `u` or `label` is individually out of range.
    #[inline]
    pub fn set(&mut self, t: usize, u: usize, label: Label, log_prob: f64) {
        let idx = self.index(t, u, label as usize);
        self.log_probs[idx] = log_prob;
    }

    /// Reads the score of `label` at `(t, u)`.
    ///
    /// # Panics
    /// Same conditions as [`set`](Self::set).
    #[inline]
    pub fn get(&self, t: usize, u: usize, label: Label) -> f64 {
        let idx = self.index(t, u, label as usize);
        self.log_probs[idx]
    }

    /// Offset of the first label entry of cell `(t, u)`.
    #[inline]
    fn position_offset(&self, t: usize, u: usize) -> usize {
        // An out-of-range `u` would otherwise silently alias a cell of the next
        // frame instead of panicking, since only the flat index is bounds-checked.
        debug_assert!(
            t < self.num_frames,
            "frame {} out of range (num_frames = {})",
            t,
            self.num_frames
        );
        debug_assert!(
            u < self.num_positions,
            "position {} out of range (num_positions = {})",
            u,
            self.num_positions
        );
        (t * self.num_positions + u) * self.vocab_size
    }

    /// Flat index of `(t, u, label)` into `log_probs`.
    #[inline]
    fn index(&self, t: usize, u: usize, label: usize) -> usize {
        // An out-of-range label would otherwise alias the next position's cells.
        debug_assert!(
            label < self.vocab_size,
            "label {} out of range (vocab_size = {})",
            label,
            self.vocab_size
        );
        self.position_offset(t, u) + label
    }

    /// Borrows all `vocab_size` scores of cell `(t, u)`.
    ///
    /// # Panics
    /// Panics if the cell lies past the end of the table; in debug builds also if
    /// `t` or `u` is individually out of range.
    pub fn get_position(&self, t: usize, u: usize) -> &[f64] {
        let start = self.position_offset(t, u);
        &self.log_probs[start..start + self.vocab_size]
    }

    /// Converts the lattice into a WFST over [`Label`]s.
    ///
    /// State `t * num_positions + u` represents node `(t, u)` for
    /// `t in 0..=num_frames`. The start state is `(0, 0)`; the single final state,
    /// with weight `W::one()`, is `(num_frames, num_positions - 1)`. From every node
    /// with `t < num_frames`:
    /// * a [`BLANK`] arc goes to `(t + 1, u)`;
    /// * each non-blank label goes to `(t + 1, u + 1)`, only when `u + 1 < num_positions`.
    ///
    /// Arc weights are `W::from(-score)`; entries that are `f64::NEG_INFINITY` or
    /// NaN produce no arc. A lattice with `num_positions == 0` has no nodes and
    /// yields an FST with no states; one with `vocab_size == 0` yields the states
    /// (with start and final set) but no arcs.
    ///
    /// NOTE(review): non-blank arcs advance the frame as well as the position, so
    /// at most one label is emitted per frame. Standard RNN-T emits labels without
    /// advancing `t`; confirm this topology is intended.
    pub fn to_wfst(&self) -> VectorWfst<Label, W>
    where
        W: From<f64> + Clone,
    {
        use crate::wfst::{MutableWfst, WeightedTransition};
        let mut fst: VectorWfst<Label, W> = VectorWfst::new();
        // Without output positions there is no node (0, 0) to start from, and
        // `num_positions - 1` below would underflow.
        if self.num_positions == 0 {
            return fst;
        }
        let num_positions = self.num_positions;
        let state_of = |t: usize, u: usize| (t * num_positions + u) as StateId;
        fst.add_states((self.num_frames + 1) * num_positions);
        fst.set_start(state_of(0, 0));
        fst.set_final(state_of(self.num_frames, num_positions - 1), W::one());
        // With an empty vocabulary there is not even a blank entry to read.
        if self.vocab_size == 0 {
            return fst;
        }
        for t in 0..self.num_frames {
            for u in 0..num_positions {
                let from_state = state_of(t, u);
                let blank_prob = self.get(t, u, BLANK);
                if blank_prob > f64::NEG_INFINITY {
                    fst.add_transition(WeightedTransition {
                        from: from_state,
                        input: Some(BLANK),
                        output: Some(BLANK),
                        to: state_of(t + 1, u),
                        weight: W::from(-blank_prob),
                    });
                }
                if u + 1 < num_positions {
                    // All non-blank labels from (t, u) share the same destination.
                    let label_state = state_of(t + 1, u + 1);
                    for label in 1..self.vocab_size as Label {
                        let label_prob = self.get(t, u, label);
                        if label_prob > f64::NEG_INFINITY {
                            fst.add_transition(WeightedTransition {
                                from: from_state,
                                input: Some(label),
                                output: Some(label),
                                to: label_state,
                                weight: W::from(-label_prob),
                            });
                        }
                    }
                }
            }
        }
        fst
    }
}
/// Counters and per-component timings for transducer decoding.
///
/// This module only defines the container; presumably it is filled in per decode
/// by the caller — confirm. `Default` yields all zeros.
#[derive(Debug, Clone, Default)]
pub struct TransducerStats {
/// Number of encoder frames.
pub num_frames: usize,
/// Number of output positions.
pub num_positions: usize,
/// Number of hypotheses, as counted by whoever fills this struct.
pub num_hypotheses: usize,
/// Time spent in the encoder, in milliseconds.
pub encoder_time_ms: f64,
/// Time spent in the predictor, in milliseconds.
pub predictor_time_ms: f64,
/// Time spent in the joiner, in milliseconds.
pub joiner_time_ms: f64,
/// Time spent in the search, in milliseconds.
pub search_time_ms: f64,
}
// Unit tests for the data containers, config defaults, stats and lattice helpers
// defined in this module.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
#[test]
fn test_encoder_output_creation() {
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let num_frames = 2;
let dim = 3;
let encoder_out = EncoderOutput::new(data.clone(), num_frames, dim);
assert_eq!(encoder_out.num_frames, 2);
assert_eq!(encoder_out.dim, 3);
assert_eq!(encoder_out.len(), 2);
assert!(!encoder_out.is_empty());
}
#[test]
fn test_encoder_output_frame_access() {
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let encoder_out = EncoderOutput::new(data, 3, 2);
assert_eq!(encoder_out.frame(0), &[1.0f32, 2.0]);
assert_eq!(encoder_out.frame(1), &[3.0f32, 4.0]);
assert_eq!(encoder_out.frame(2), &[5.0f32, 6.0]);
}
#[test]
fn test_encoder_output_empty() {
let encoder_out = EncoderOutput::new(vec![], 0, 4);
assert_eq!(encoder_out.len(), 0);
assert!(encoder_out.is_empty());
}
#[test]
fn test_encoder_output_single_frame() {
let data = vec![0.1f32, 0.2, 0.3, 0.4];
let encoder_out = EncoderOutput::new(data, 1, 4);
assert_eq!(encoder_out.len(), 1);
assert_eq!(encoder_out.frame(0), &[0.1f32, 0.2, 0.3, 0.4]);
}
#[test]
fn test_predictor_state_creation() {
let hidden = vec![0.1, 0.2, 0.3];
let cell = vec![0.4, 0.5, 0.6];
let state = PredictorState::new(hidden.clone(), cell.clone());
assert_eq!(state.hidden, hidden);
assert_eq!(state.cell, cell);
assert_eq!(state.num_tokens, 0);
}
#[test]
fn test_predictor_state_zeros() {
let state = PredictorState::zeros(4);
assert_eq!(state.hidden, vec![0.0; 4]);
assert_eq!(state.cell, vec![0.0; 4]);
assert_eq!(state.num_tokens, 0);
}
#[test]
fn test_predictor_state_default() {
let state = PredictorState::default();
assert!(state.hidden.is_empty());
assert!(state.cell.is_empty());
assert_eq!(state.num_tokens, 0);
}
#[test]
fn test_predictor_output_creation() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let num_positions = 4;
let dim = 2;
let predictor_out = PredictorOutput::new(data.clone(), num_positions, dim);
assert_eq!(predictor_out.num_positions, 4);
assert_eq!(predictor_out.dim, 2);
assert_eq!(predictor_out.len(), 4);
assert!(!predictor_out.is_empty());
}
#[test]
fn test_predictor_output_position_access() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let predictor_out = PredictorOutput::new(data, 3, 2);
assert_eq!(predictor_out.position(0), &[1.0, 2.0]);
assert_eq!(predictor_out.position(1), &[3.0, 4.0]);
assert_eq!(predictor_out.position(2), &[5.0, 6.0]);
}
#[test]
fn test_predictor_output_empty() {
let predictor_out = PredictorOutput::new(vec![], 0, 4);
assert_eq!(predictor_out.len(), 0);
assert!(predictor_out.is_empty());
}
#[test]
fn test_transducer_config_default() {
let config = TransducerConfig::default();
assert_eq!(config.beam_width, 10);
assert_eq!(config.max_active, 1000);
assert!((config.pruning_threshold - 10.0).abs() < f32::EPSILON);
assert!(config.use_batch_joiner);
assert_eq!(config.max_symbols_per_frame, 10);
}
#[test]
fn test_transducer_config_custom() {
let config = TransducerConfig {
beam_width: 20,
max_active: 500,
pruning_threshold: 5.0,
use_batch_joiner: false,
max_symbols_per_frame: 5,
};
assert_eq!(config.beam_width, 20);
assert_eq!(config.max_active, 500);
assert!(!config.use_batch_joiner);
}
#[test]
fn test_transducer_lattice_creation() {
let lattice: TransducerLattice<TropicalWeight> = TransducerLattice::new(5, 3, 10);
assert_eq!(lattice.num_frames, 5);
assert_eq!(lattice.num_positions, 3);
assert_eq!(lattice.vocab_size, 10);
assert_eq!(lattice.log_probs.len(), 5 * 3 * 10);
}
// Unset cells must keep the NEG_INFINITY fill applied by `TransducerLattice::new`.
#[test]
fn test_transducer_lattice_set_get() {
let mut lattice: TransducerLattice<TropicalWeight> = TransducerLattice::new(3, 2, 5);
lattice.set(0, 0, 0, -1.0);
lattice.set(0, 0, 1, -2.0);
lattice.set(1, 1, 2, -3.0);
lattice.set(2, 0, BLANK, -0.5);
assert!((lattice.get(0, 0, 0) - (-1.0)).abs() < 1e-10);
assert!((lattice.get(0, 0, 1) - (-2.0)).abs() < 1e-10);
assert!((lattice.get(1, 1, 2) - (-3.0)).abs() < 1e-10);
assert!((lattice.get(2, 0, BLANK) - (-0.5)).abs() < 1e-10);
assert!(lattice.get(0, 1, 3) == f64::NEG_INFINITY);
}
#[test]
fn test_transducer_lattice_get_position() {
let mut lattice: TransducerLattice<TropicalWeight> = TransducerLattice::new(2, 2, 3);
lattice.set(0, 0, 0, -1.0);
lattice.set(0, 0, 1, -2.0);
lattice.set(0, 0, 2, -3.0);
let position_probs = lattice.get_position(0, 0);
assert_eq!(position_probs.len(), 3);
assert!((position_probs[0] - (-1.0)).abs() < 1e-10);
assert!((position_probs[1] - (-2.0)).abs() < 1e-10);
assert!((position_probs[2] - (-3.0)).abs() < 1e-10);
}
// Structural smoke test only: checks that states exist, the start state is valid
// and some state is final; arc topology and weights are not asserted.
#[test]
fn test_transducer_lattice_to_wfst_structure() {
use crate::wfst::Wfst;
let mut lattice: TransducerLattice<TropicalWeight> = TransducerLattice::new(2, 2, 3);
lattice.set(0, 0, BLANK, -0.1); lattice.set(0, 0, 1, -0.5);
lattice.set(1, 0, BLANK, -0.2);
lattice.set(1, 1, BLANK, -0.3);
let wfst = lattice.to_wfst();
assert!(wfst.num_states() > 0);
let start_state = wfst.start();
assert!(
wfst.is_valid_state(start_state),
"Start state should be valid"
);
let has_final = (0..wfst.num_states()).any(|s| wfst.is_final(s as u32));
assert!(has_final, "WFST should have at least one final state");
}
// A 0x0x0 lattice allocates no score entries; `to_wfst` is not exercised here.
#[test]
fn test_transducer_lattice_empty() {
let lattice: TransducerLattice<TropicalWeight> = TransducerLattice::new(0, 0, 0);
assert_eq!(lattice.num_frames, 0);
assert_eq!(lattice.num_positions, 0);
assert_eq!(lattice.vocab_size, 0);
assert!(lattice.log_probs.is_empty());
}
// Writes cells at the extremes of each axis and reads them back, checking that
// distinct (t, u, label) triples land in distinct cells of the flat buffer.
#[test]
fn test_transducer_lattice_indexing() {
let mut lattice2: TransducerLattice<TropicalWeight> = TransducerLattice::new(3, 4, 5);
lattice2.set(0, 0, 0, 1.0);
lattice2.set(0, 0, 4, 2.0);
lattice2.set(0, 1, 0, 3.0);
lattice2.set(1, 0, 0, 4.0);
lattice2.set(2, 3, 4, 5.0);
assert!((lattice2.get(0, 0, 0) - 1.0).abs() < 1e-10);
assert!((lattice2.get(0, 0, 4) - 2.0).abs() < 1e-10);
assert!((lattice2.get(0, 1, 0) - 3.0).abs() < 1e-10);
assert!((lattice2.get(1, 0, 0) - 4.0).abs() < 1e-10);
assert!((lattice2.get(2, 3, 4) - 5.0).abs() < 1e-10);
}
#[test]
fn test_transducer_stats_default() {
let stats = TransducerStats::default();
assert_eq!(stats.num_frames, 0);
assert_eq!(stats.num_positions, 0);
assert_eq!(stats.num_hypotheses, 0);
assert!((stats.encoder_time_ms - 0.0).abs() < f64::EPSILON);
assert!((stats.predictor_time_ms - 0.0).abs() < f64::EPSILON);
assert!((stats.joiner_time_ms - 0.0).abs() < f64::EPSILON);
assert!((stats.search_time_ms - 0.0).abs() < f64::EPSILON);
}
#[test]
fn test_transducer_stats_custom() {
let stats = TransducerStats {
num_frames: 100,
num_positions: 50,
num_hypotheses: 200,
encoder_time_ms: 10.5,
predictor_time_ms: 5.2,
joiner_time_ms: 15.8,
search_time_ms: 3.1,
};
assert_eq!(stats.num_frames, 100);
assert_eq!(stats.num_positions, 50);
assert_eq!(stats.num_hypotheses, 200);
assert!((stats.encoder_time_ms - 10.5).abs() < 1e-10);
}
#[test]
fn test_blank_constant() {
assert_eq!(BLANK, 0);
}
#[test]
fn test_label_type() {
let label: Label = 42;
assert_eq!(label, 42u32);
}
#[test]
fn test_frame_index_type() {
let frame: FrameIndex = 100;
assert_eq!(frame, 100usize);
}
#[test]
fn test_encoder_output_clone() {
let data = vec![1.0, 2.0, 3.0, 4.0];
let encoder_out = EncoderOutput::new(data.clone(), 2, 2);
let cloned = encoder_out.clone();
assert_eq!(encoder_out.data, cloned.data);
assert_eq!(encoder_out.num_frames, cloned.num_frames);
assert_eq!(encoder_out.dim, cloned.dim);
}
#[test]
fn test_predictor_state_clone() {
let state = PredictorState::new(vec![1.0, 2.0], vec![3.0, 4.0]);
let cloned = state.clone();
assert_eq!(state.hidden, cloned.hidden);
assert_eq!(state.cell, cloned.cell);
assert_eq!(state.num_tokens, cloned.num_tokens);
}
#[test]
fn test_predictor_output_clone() {
let data = vec![1.0, 2.0, 3.0, 4.0];
let predictor_out = PredictorOutput::new(data.clone(), 2, 2);
let cloned = predictor_out.clone();
assert_eq!(predictor_out.data, cloned.data);
assert_eq!(predictor_out.num_positions, cloned.num_positions);
assert_eq!(predictor_out.dim, cloned.dim);
}
#[test]
fn test_transducer_lattice_clone() {
let mut lattice: TransducerLattice<TropicalWeight> = TransducerLattice::new(2, 2, 3);
lattice.set(0, 0, 0, -1.0);
lattice.set(1, 1, 2, -2.0);
let cloned = lattice.clone();
assert_eq!(lattice.num_frames, cloned.num_frames);
assert_eq!(lattice.num_positions, cloned.num_positions);
assert_eq!(lattice.vocab_size, cloned.vocab_size);
assert!((lattice.get(0, 0, 0) - cloned.get(0, 0, 0)).abs() < 1e-10);
assert!((lattice.get(1, 1, 2) - cloned.get(1, 1, 2)).abs() < 1e-10);
}
}