use super::JointNetwork;
use std::fmt::Debug;
/// Joint network with a single tanh hidden layer over the concatenated encoder
/// frame and predictor output, followed by a linear projection to `vocab_size`
/// log-probabilities.
///
/// Weight matrices are flat, row-major buffers with one row per output unit:
/// `w_enc` is `hidden_dim × enc_dim`, `w_pred` is `hidden_dim × pred_dim` and
/// `w_out` is `vocab_size × hidden_dim`.
#[derive(Debug, Clone)]
pub struct FeedForwardJoiner {
    /// Number of output classes.
    pub vocab_size: usize,
    /// Width of the hidden layer.
    pub hidden_dim: usize,
    /// Encoder-to-hidden weights (`hidden_dim * enc_dim` entries).
    pub w_enc: Vec<f32>,
    /// Predictor-to-hidden weights (`hidden_dim * pred_dim` entries).
    pub w_pred: Vec<f32>,
    /// Hidden-layer bias (`hidden_dim` entries).
    pub b_hidden: Vec<f32>,
    /// Hidden-to-output weights (`vocab_size * hidden_dim` entries).
    pub w_out: Vec<f32>,
    /// Output bias (`vocab_size` entries).
    pub b_out: Vec<f32>,
    /// Length of an encoder frame.
    pub enc_dim: usize,
    /// Length of a predictor output.
    pub pred_dim: usize,
}

impl FeedForwardJoiner {
    /// Creates a joiner of the given dimensions with every weight and bias set to zero.
    pub fn new(vocab_size: usize, enc_dim: usize, pred_dim: usize, hidden_dim: usize) -> Self {
        let zeros = |len: usize| vec![0.0_f32; len];
        Self::from_weights(
            vocab_size,
            enc_dim,
            pred_dim,
            hidden_dim,
            zeros(hidden_dim * enc_dim),
            zeros(hidden_dim * pred_dim),
            zeros(hidden_dim),
            zeros(vocab_size * hidden_dim),
            zeros(vocab_size),
        )
    }

    /// Creates a joiner from explicit parameters, taking ownership of each buffer.
    ///
    /// In debug builds every buffer's length is checked against the dimensions
    /// (see the struct documentation for the expected sizes); release builds
    /// trust the caller.
    pub fn from_weights(
        vocab_size: usize,
        enc_dim: usize,
        pred_dim: usize,
        hidden_dim: usize,
        w_enc: Vec<f32>,
        w_pred: Vec<f32>,
        b_hidden: Vec<f32>,
        w_out: Vec<f32>,
        b_out: Vec<f32>,
    ) -> Self {
        debug_assert_eq!(w_enc.len(), hidden_dim * enc_dim);
        debug_assert_eq!(w_pred.len(), hidden_dim * pred_dim);
        debug_assert_eq!(b_hidden.len(), hidden_dim);
        debug_assert_eq!(w_out.len(), vocab_size * hidden_dim);
        debug_assert_eq!(b_out.len(), vocab_size);

        Self {
            vocab_size,
            enc_dim,
            pred_dim,
            hidden_dim,
            w_enc,
            w_pred,
            b_hidden,
            w_out,
            b_out,
        }
    }
}
impl JointNetwork for FeedForwardJoiner {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Scores one (encoder frame, predictor output) pair.
    ///
    /// Computes
    /// `log_softmax(w_out · tanh(w_enc · enc + w_pred · pred + b_hidden) + b_out)`,
    /// reading every weight matrix row-major with one row per output unit.
    ///
    /// `encoder_frame` must hold `enc_dim` values and `predictor_output` must hold
    /// `pred_dim` values; this is checked in debug builds. Each dot product is
    /// restricted to a single weight row, so a mis-sized input in a release build
    /// only shortens the dot product instead of reading weights that belong to an
    /// adjacent row.
    fn forward(&self, encoder_frame: &[f32], predictor_output: &[f32]) -> Vec<f32> {
        debug_assert_eq!(encoder_frame.len(), self.enc_dim);
        debug_assert_eq!(predictor_output.len(), self.pred_dim);

        // `init + Σ w·x`, accumulated left to right (same order as a plain loop).
        let affine = |init: f32, weights: &[f32], inputs: &[f32]| -> f32 {
            weights
                .iter()
                .zip(inputs)
                .fold(init, |acc, (&w, &x)| acc + w * x)
        };

        // Hidden unit i: bias, then encoder contribution, then predictor contribution.
        let hidden: Vec<f32> = self
            .b_hidden
            .iter()
            .enumerate()
            .map(|(i, &bias)| {
                let enc_row = &self.w_enc[i * self.enc_dim..(i + 1) * self.enc_dim];
                let pred_row = &self.w_pred[i * self.pred_dim..(i + 1) * self.pred_dim];
                affine(affine(bias, enc_row, encoder_frame), pred_row, predictor_output).tanh()
            })
            .collect();

        let logits: Vec<f32> = self
            .b_out
            .iter()
            .enumerate()
            .map(|(i, &bias)| {
                let out_row = &self.w_out[i * self.hidden_dim..(i + 1) * self.hidden_dim];
                affine(bias, out_row, &hidden)
            })
            .collect();

        log_softmax(&logits)
    }

    /// Applies [`forward`](Self::forward) to each paired frame/output.
    ///
    /// Both slices must have the same length (checked in debug builds); in a
    /// release build any surplus entries of the longer slice are ignored.
    fn forward_batch(
        &self,
        encoder_frames: &[&[f32]],
        predictor_outputs: &[&[f32]],
    ) -> Vec<Vec<f32>> {
        debug_assert_eq!(encoder_frames.len(), predictor_outputs.len());
        encoder_frames
            .iter()
            .zip(predictor_outputs)
            .map(|(enc, pred)| self.forward(enc, pred))
            .collect()
    }
}
/// Joint network that factorises the output distribution into a blank decision
/// and a distribution over the non-blank tokens.
///
/// `P(blank) = σ(w_blank · enc + b_blank)` depends only on the encoder frame;
/// `P(token k) = (1 − P(blank)) · softmax(w_vocab · pred + b_vocab)[k]` depends
/// only on the predictor output. In the output of `forward`, index 0 is blank
/// and indices `1..vocab_size` are the non-blank tokens.
#[derive(Debug, Clone)]
pub struct FactorizedJoiner {
    /// Total number of output classes, including blank.
    pub vocab_size: usize,
    /// Length of an encoder frame.
    pub enc_dim: usize,
    /// Length of a predictor output.
    pub pred_dim: usize,
    /// Blank-logit weights (`enc_dim` entries).
    pub w_blank: Vec<f32>,
    /// Blank-logit bias.
    pub b_blank: f32,
    /// Non-blank projection, row-major `(vocab_size - 1) × pred_dim`.
    pub w_vocab: Vec<f32>,
    /// Non-blank bias (`vocab_size - 1` entries).
    pub b_vocab: Vec<f32>,
}

impl FactorizedJoiner {
    /// Creates a joiner of the given dimensions with every weight and bias set to zero.
    ///
    /// # Panics
    /// Panics if `vocab_size` is 0: the output always contains the blank class.
    pub fn new(vocab_size: usize, enc_dim: usize, pred_dim: usize) -> Self {
        assert!(
            vocab_size >= 1,
            "FactorizedJoiner requires vocab_size >= 1 (blank plus non-blank tokens)"
        );
        let non_blank = vocab_size - 1;
        Self {
            vocab_size,
            enc_dim,
            pred_dim,
            w_blank: vec![0.0; enc_dim],
            b_blank: 0.0,
            w_vocab: vec![0.0; non_blank * pred_dim],
            b_vocab: vec![0.0; non_blank],
        }
    }

    /// Creates a joiner from explicit parameters, taking ownership of each buffer.
    ///
    /// In debug builds `vocab_size >= 1` and every buffer length are checked
    /// against the dimensions; release builds trust the caller.
    pub fn from_weights(
        vocab_size: usize,
        enc_dim: usize,
        pred_dim: usize,
        w_blank: Vec<f32>,
        b_blank: f32,
        w_vocab: Vec<f32>,
        b_vocab: Vec<f32>,
    ) -> Self {
        // Checked first so the length checks below cannot underflow.
        debug_assert!(
            vocab_size >= 1,
            "FactorizedJoiner requires vocab_size >= 1 (blank plus non-blank tokens)"
        );
        debug_assert_eq!(w_blank.len(), enc_dim);
        debug_assert_eq!(w_vocab.len(), (vocab_size - 1) * pred_dim);
        debug_assert_eq!(b_vocab.len(), vocab_size - 1);
        Self {
            vocab_size,
            enc_dim,
            pred_dim,
            w_blank,
            b_blank,
            w_vocab,
            b_vocab,
        }
    }

    /// Blank logit `b_blank + w_blank · encoder_frame`, accumulated left to right.
    fn blank_logit(&self, encoder_frame: &[f32]) -> f32 {
        self.w_blank
            .iter()
            .zip(encoder_frame)
            .fold(self.b_blank, |acc, (&w, &x)| acc + w * x)
    }

    /// Log-softmax over the non-blank tokens, computed from the predictor output only.
    fn vocab_log_probs(&self, predictor_output: &[f32]) -> Vec<f32> {
        let logits: Vec<f32> = self
            .b_vocab
            .iter()
            .enumerate()
            .map(|(i, &bias)| {
                let row = &self.w_vocab[i * self.pred_dim..(i + 1) * self.pred_dim];
                row.iter()
                    .zip(predictor_output)
                    .fold(bias, |acc, (&w, &x)| acc + w * x)
            })
            .collect();
        log_softmax(&logits)
    }

    /// Numerically safe `ln σ(z)`.
    ///
    /// `σ(z)` is evaluated directly (never as `1 − σ(−z)`), so there is no
    /// cancellation. It only underflows to 0 once `e^{−z}` overflows
    /// (z below roughly −88 in f32); there `ln σ(z) = z − ln(1 + e^z)` equals `z`
    /// to f32 precision, so `z` itself is returned.
    fn log_sigmoid(z: f32) -> f32 {
        let p = sigmoid(z);
        if p > 0.0 {
            p.ln()
        } else {
            z
        }
    }
}

impl JointNetwork for FactorizedJoiner {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Returns `vocab_size` log-probabilities: blank first, then the non-blank tokens.
    ///
    /// `encoder_frame` must hold `enc_dim` values and `predictor_output` must hold
    /// `pred_dim` values (checked in debug builds).
    fn forward(&self, encoder_frame: &[f32], predictor_output: &[f32]) -> Vec<f32> {
        debug_assert_eq!(encoder_frame.len(), self.enc_dim);
        debug_assert_eq!(predictor_output.len(), self.pred_dim);

        let z = self.blank_logit(encoder_frame);
        let blank_log = Self::log_sigmoid(z);
        // ln(1 − σ(z)) = ln σ(−z). Forming 1 − σ(z) in f32 instead rounds to exactly
        // 0 once z exceeds about 17, which would make every non-blank token -inf.
        let non_blank_log = Self::log_sigmoid(-z);

        let mut result = Vec::with_capacity(self.vocab_size);
        result.push(blank_log);
        result.extend(
            self.vocab_log_probs(predictor_output)
                .into_iter()
                .map(|lp| non_blank_log + lp),
        );
        result
    }
}
/// Parameter-free joint network: adds the encoder frame and predictor output
/// element-wise and normalises the sum with a log-softmax.
///
/// NOTE(review): presumably both inputs are already projected to `vocab_size`
/// logits upstream; `forward` does not check this — confirm against callers.
#[derive(Debug, Clone)]
pub struct AdditiveJoiner {
    /// Number of output classes reported by `vocab_size()`.
    pub vocab_size: usize,
}

impl AdditiveJoiner {
    /// Creates a joiner reporting `vocab_size` output classes.
    pub fn new(vocab_size: usize) -> Self {
        Self { vocab_size }
    }
}

impl JointNetwork for AdditiveJoiner {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Returns `log_softmax(encoder_frame + predictor_output)`.
    ///
    /// Both inputs must have the same length (checked in debug builds); in a
    /// release build the element-wise pairing stops at the shorter input.
    fn forward(&self, encoder_frame: &[f32], predictor_output: &[f32]) -> Vec<f32> {
        debug_assert_eq!(
            encoder_frame.len(),
            predictor_output.len(),
            "AdditiveJoiner inputs must have equal length"
        );
        let logits: Vec<f32> = encoder_frame
            .iter()
            .zip(predictor_output)
            .map(|(&e, &p)| e + p)
            .collect();
        log_softmax(&logits)
    }
}
/// Numerically stable log-softmax: `logits[i] − ln(Σ_j exp(logits[j]))`.
///
/// The largest logit is subtracted before exponentiating, so every `exp`
/// argument is at most 0 for finite inputs and cannot overflow. An empty slice
/// yields an empty vector.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    match logits {
        [] => Vec::new(),
        _ => {
            let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut shifted_total = 0.0_f32;
            for &value in logits {
                shifted_total += (value - peak).exp();
            }
            let log_normaliser = peak + shifted_total.ln();
            logits.iter().map(|&value| value - log_normaliser).collect()
        }
    }
}
/// Logistic function `σ(x) = 1 / (1 + e^{−x})`.
///
/// For finite `x` the result is never NaN: it saturates to exactly 0.0 when
/// `e^{−x}` overflows to infinity and to 1.0 once `e^{−x}` is below f32 precision.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `log_probs` exponentiates to a distribution summing to one.
    fn assert_normalised(log_probs: &[f32]) {
        let sum: f32 = log_probs.iter().map(|x| x.exp()).sum();
        assert!((sum - 1.0).abs() < 1e-5, "probabilities sum to {}", sum);
    }

    #[test]
    fn test_log_softmax() {
        let logits = vec![1.0, 2.0, 3.0];
        let result = log_softmax(&logits);
        assert_normalised(&result);
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }

    #[test]
    fn test_log_softmax_empty() {
        assert!(log_softmax(&[]).is_empty());
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_feedforward_joiner() {
        let joiner = FeedForwardJoiner::new(10, 256, 256, 128);
        let enc = vec![0.1; 256];
        let pred = vec![0.2; 256];
        let result = joiner.forward(&enc, &pred);
        assert_eq!(result.len(), 10);
        assert_normalised(&result);
    }

    #[test]
    fn test_feedforward_forward_batch_matches_forward() {
        let joiner = FeedForwardJoiner::new(5, 4, 3, 2);
        let enc = vec![0.1; 4];
        let pred = vec![0.2; 3];
        let batch = joiner.forward_batch(&[&enc[..], &enc[..]], &[&pred[..], &pred[..]]);
        assert_eq!(batch.len(), 2);
        let single = joiner.forward(&enc, &pred);
        for row in &batch {
            assert_eq!(row, &single);
        }
    }

    #[test]
    fn test_factorized_joiner() {
        let joiner = FactorizedJoiner::new(10, 256, 256);
        let enc = vec![0.1; 256];
        let pred = vec![0.2; 256];
        let result = joiner.forward(&enc, &pred);
        assert_eq!(result.len(), 10);
        for &lp in &result {
            assert!(lp <= 0.0);
            assert!(lp.is_finite());
        }
        assert_normalised(&result);
    }

    #[test]
    fn test_additive_joiner() {
        let joiner = AdditiveJoiner::new(3);
        let result = joiner.forward(&[1.0, 2.0, 3.0], &[0.5, 0.5, 0.5]);
        assert_eq!(result.len(), joiner.vocab_size());
        assert_normalised(&result);
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }
}