use super::Reduction;
use alloc::vec;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Bool, Int, Tensor, backend::Backend, s};
use burn_core as burn;
use core::f32;
/// Configuration for creating an [RNN-Transducer loss](RNNTLoss) module.
#[derive(Config, Debug)]
pub struct RNNTLossConfig {
    /// Index of the blank symbol within the vocabulary dimension.
    #[config(default = 0)]
    pub blank: usize,
    /// If `true` (the default), `forward` receives raw logits and applies
    /// `log_softmax` over the vocabulary dimension itself; if `false`, the
    /// input is expected to already contain log-probabilities.
    #[config(default = true)]
    pub logits: bool,
}
impl RNNTLossConfig {
    /// Build an [`RNNTLoss`] module from this configuration.
    pub fn init(&self) -> RNNTLoss {
        let blank = self.blank;
        let logits = self.logits;
        RNNTLoss { blank, logits }
    }
}
/// RNN-Transducer (RNNT) loss module.
///
/// Computes the per-sample negative log-likelihood of target label sequences
/// over the transducer output lattice via a forward ("alpha") dynamic program.
/// Create instances with [`RNNTLossConfig::init`].
#[derive(Module, Clone, Debug)]
pub struct RNNTLoss {
    /// Index of the blank symbol in the vocabulary dimension.
    blank: usize,
    /// When `true`, `forward` applies `log_softmax` to its input first.
    logits: bool,
}
impl RNNTLoss {
    /// Compute the RNNT loss for a batch.
    ///
    /// # Shapes
    /// - `logits`: `[batch, max_t, max_u + 1, vocab]` — joint-network output for
    ///   every (time step, label position) lattice node.
    /// - `targets`: `[batch, max_u]` — integer label indices (padded).
    /// - `logit_lengths`: `[batch]` — valid time steps per sample.
    /// - `target_lengths`: `[batch]` — valid target labels per sample.
    ///
    /// Returns the per-sample negative log-likelihood, shape `[batch]`.
    ///
    /// # Panics
    /// Panics if any of the dimension checks in `check_inputs` fail.
    pub fn forward<B: Backend>(
        &self,
        logits: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let device = logits.device();
        let [b, max_t, max_up1, v] = logits.dims();
        // The label axis has max_u + 1 positions: u = 0 means "no label emitted yet".
        let max_u = max_up1 - 1;
        self.check_inputs(b, v, &targets, &logit_lengths, &target_lengths, max_u);
        // Normalize over the vocabulary axis unless the caller already passed
        // log-probabilities (self.logits == false).
        let log_probs = if self.logits {
            let vocab_dim = 3;
            burn::tensor::activation::log_softmax(logits, vocab_dim)
        } else {
            logits
        };
        // lpb[b, t, u] = log P(blank at (t, u)); lpl[b, t, u] = log P(emitting
        // the next target label at (t, u)).
        let (lpb, lpl) = self.extract_log_probs(log_probs, targets);
        // Lattice rows with u > target_len are invalid for that sample; the mask
        // is true on the valid rows (u <= target_len).
        let u_mask = self.create_u_mask(&target_lengths, b, max_up1, &device);
        let neg_inf = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, &device);
        // alpha[b, u] holds the forward log-probability for the current time
        // step; initialized with the t = 0 row (cumulative label emissions).
        let mut alpha = self.init_alpha(&lpl, b, max_up1, &device);
        // Keep alpha on valid rows, force -inf elsewhere (mask_where selects
        // the second operand where the mask is true).
        alpha = neg_inf.clone().mask_where(u_mask.clone(), alpha);
        let logit_lengths_exp = logit_lengths.clone().reshape([b, 1]).expand([b, max_up1]);
        // Forward recursion over time. A sample stops being updated once
        // t >= its logit length, freezing its alpha at t = T_b - 1, which is
        // exactly the row gather_loss reads later.
        for t in 1..max_t {
            let new = self.step_alpha(&alpha, &lpb, &lpl, t);
            let new = neg_inf.clone().mask_where(u_mask.clone(), new);
            let valid = logit_lengths_exp.clone().greater_elem(t as i64);
            alpha = alpha.mask_where(valid, new);
        }
        self.gather_loss(alpha, &lpb, logit_lengths, target_lengths, b)
    }
    /// Same as [`Self::forward`] but reduces the per-sample losses.
    ///
    /// # Panics
    /// Panics for reductions other than `Auto`, `Mean`, or `Sum`.
    pub fn forward_with_reduction<B: Backend>(
        &self,
        logits: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward(logits, targets, logit_lengths, target_lengths);
        match reduction {
            Reduction::Auto | Reduction::Mean => loss.mean(),
            Reduction::Sum => loss.sum(),
            other => panic!("{other:?} reduction is not supported"),
        }
    }
    /// Extract the two per-node log-probabilities the recursion needs.
    ///
    /// Returns `(lpb, lpl)` where:
    /// - `lpb[b, t, u]` is the blank log-probability, shape `[b, max_t, max_u + 1]`;
    /// - `lpl[b, t, u]` is the log-probability of target label `targets[b, u]`,
    ///   shape `[b, max_t, max_u]` (only label positions 0..max_u can emit).
    fn extract_log_probs<B: Backend>(
        &self,
        log_probs: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
    ) -> (Tensor<B, 3>, Tensor<B, 3>) {
        let [b, max_t, max_up1, v] = log_probs.dims();
        let max_u = max_up1 - 1;
        let vocab_dim = 3;
        // Select the blank channel along the vocabulary axis.
        let lpb = log_probs
            .clone()
            .slice_dim(vocab_dim, self.blank)
            .squeeze_dim::<3>(vocab_dim);
        // Broadcast targets to every time step so gather can pick, at each
        // (t, u), the log-prob of the label that advances u -> u + 1.
        let tgt = targets
            .reshape([b, 1, max_u, 1])
            .expand([b, max_t, max_u, 1]);
        let lpl = log_probs
            .slice(s![.., .., 0..max_u, 0..v])
            .gather(vocab_dim, tgt)
            .squeeze_dim::<3>(vocab_dim);
        (lpb, lpl)
    }
    /// Initial alpha row for t = 0.
    ///
    /// At the first time step only label emissions are possible, so
    /// `alpha[b, u] = sum_{u' < u} lpl[b, 0, u']` with `alpha[b, 0] = 0`,
    /// implemented as a cumulative sum over a zero-prefixed row.
    fn init_alpha<B: Backend>(
        &self,
        lpl: &Tensor<B, 3>,
        b: usize,
        max_up1: usize,
        device: &B::Device,
    ) -> Tensor<B, 2> {
        let lpl_0 = lpl.clone().slice(s![.., 0..1, ..]).squeeze_dim::<2>(1);
        let zero_col = Tensor::<B, 2>::zeros([b, 1], device);
        let prefix = Tensor::cat(vec![zero_col, lpl_0.slice(s![.., 0..(max_up1 - 1)])], 1);
        prefix.cumsum(1)
    }
    /// Boolean mask over the label axis, true where `u <= target_lengths[b]`
    /// (i.e. the lattice node exists for that sample), shape `[b, max_up1]`.
    fn create_u_mask<B: Backend>(
        &self,
        target_lengths: &Tensor<B, 1, Int>,
        b: usize,
        max_up1: usize,
        device: &B::Device,
    ) -> Tensor<B, 2, Bool> {
        let indices = Tensor::<B, 1, Int>::arange(0..max_up1 as i64, device)
            .reshape([1, max_up1])
            .expand([b, max_up1]);
        let lengths = target_lengths.clone().reshape([b, 1]).expand([b, max_up1]);
        indices.lower_equal(lengths)
    }
    /// One time step of the forward recursion:
    ///
    /// `alpha(t, u) = logsumexp(alpha(t-1, u) + lpb(t-1, u),
    ///                          alpha(t, u-1) + lpl(t, u-1))`
    ///
    /// The first term advances time via a blank; the second emits a label at
    /// the current time step. The label term depends on `new[u - 1]`, hence
    /// the sequential loop over `u`.
    fn step_alpha<B: Backend>(
        &self,
        alpha: &Tensor<B, 2>,
        lpb: &Tensor<B, 3>,
        lpl: &Tensor<B, 3>,
        t: usize,
    ) -> Tensor<B, 2> {
        let [b, max_up1] = alpha.dims();
        let device = alpha.device();
        // Blank emissions at the previous time step move (t-1, u) -> (t, u).
        let blank_prob = lpb
            .clone()
            .slice(s![.., (t - 1)..t, ..])
            .squeeze_dim::<2>(1);
        let from_blank = alpha.clone().add(blank_prob);
        let mut new = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, &device);
        // u = 0 can only be reached via a blank.
        new = new.slice_assign(s![.., 0..1], from_blank.clone().slice(s![.., 0..1]));
        // Label emissions at the current time step move (t, u-1) -> (t, u).
        let label_prob = lpl
            .clone()
            .slice(s![.., t..(t + 1), ..])
            .squeeze_dim::<2>(1);
        for u in 1..max_up1 {
            let via_blank = from_blank.clone().slice(s![.., u..(u + 1)]);
            let via_label = new
                .clone()
                .slice(s![.., (u - 1)..u])
                .add(label_prob.clone().slice(s![.., (u - 1)..u]));
            new = new.slice_assign(s![.., u..(u + 1)], self.log_sum_exp(via_blank, via_label));
        }
        new
    }
    /// Final loss per sample: `-(alpha[b, U_b] + lpb[b, T_b - 1, U_b])`,
    /// i.e. the forward probability of the terminal node times the final
    /// blank that closes the alignment. `alpha` already holds each sample's
    /// row for its own last valid time step (see the freeze in `forward`).
    fn gather_loss<B: Backend>(
        &self,
        alpha: Tensor<B, 2>,
        lpb: &Tensor<B, 3>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
        b: usize,
    ) -> Tensor<B, 1> {
        let device = alpha.device();
        let u_idx = target_lengths;
        // Match the integer dtype of target_lengths for all index tensors.
        let int_dtype = u_idx.dtype();
        let t_idx = logit_lengths.sub_scalar(1).cast(int_dtype);
        let b_idx = Tensor::<B, 1, Int>::arange(0..b as i64, (&device, int_dtype));
        let alpha_tu: Tensor<B, 1> =
            alpha.gather_nd(Tensor::stack::<2>(vec![b_idx.clone(), u_idx.clone()], 1));
        let lpb_tu: Tensor<B, 1> = lpb
            .clone()
            .gather_nd(Tensor::stack::<2>(vec![b_idx, t_idx, u_idx], 1));
        alpha_tu.add(lpb_tu).neg()
    }
    /// Validate input dimensions against the logits shape.
    ///
    /// # Panics
    /// Panics with a descriptive message when `blank >= v` or when any of the
    /// batch / length dimensions disagree.
    fn check_inputs<B: Backend>(
        &self,
        b: usize,
        v: usize,
        targets: &Tensor<B, 2, Int>,
        logit_lengths: &Tensor<B, 1, Int>,
        target_lengths: &Tensor<B, 1, Int>,
        max_u: usize,
    ) {
        assert!(
            self.blank < v,
            "blank index {} must be less than vocab_size {}",
            self.blank,
            v
        );
        assert_eq!(
            targets.dims()[0],
            b,
            "targets batch dimension {} must equal batch_size {}",
            targets.dims()[0],
            b
        );
        assert_eq!(
            targets.dims()[1],
            max_u,
            "targets length dimension {} must equal max_target_len (max_u) {}",
            targets.dims()[1],
            max_u
        );
        assert_eq!(
            logit_lengths.dims()[0],
            b,
            "logit_lengths length {} must equal batch_size {}",
            logit_lengths.dims()[0],
            b
        );
        assert_eq!(
            target_lengths.dims()[0],
            b,
            "target_lengths length {} must equal batch_size {}",
            target_lengths.dims()[0],
            b
        );
    }
    /// Numerically stable element-wise `log(exp(a) + exp(b))`:
    /// `max(a, b) + ln(1 + exp(-|a - b|))`.
    ///
    /// -inf inputs are first replaced by 0 so the subtraction cannot produce
    /// NaN (inf - inf), then the final masks restore the exact identities
    /// LSE(-inf, b) = b and LSE(a, -inf) = a (and -inf when both are -inf).
    fn log_sum_exp<const D: usize, B: Backend>(
        &self,
        a: Tensor<B, D>,
        b: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let a_inf = a.clone().equal_elem(f32::NEG_INFINITY);
        let b_inf = b.clone().equal_elem(f32::NEG_INFINITY);
        let a_safe = a.clone().mask_fill(a_inf.clone(), 0.0);
        let b_safe = b.clone().mask_fill(b_inf.clone(), 0.0);
        let max = a_safe.clone().max_pair(b_safe.clone());
        let result = max.add(a_safe.sub(b_safe).abs().neg().exp().add_scalar(1.0).log());
        let result = result.mask_where(a_inf, b);
        result.mask_where(b_inf, a)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{TensorData, Tolerance};
    use burn_flex::{Flex, FlexDevice};
    type B = Flex;
    /// Vocabulary size used by the analytic (uniform-probability) tests.
    const NUM_LABELS: usize = 2;
    // Config derive should supply blank = 0 and logits = true by default.
    #[test]
    fn config_defaults() {
        let cfg = RNNTLossConfig::new();
        assert_eq!(cfg.blank, 0);
        assert!(cfg.logits);
    }
    // blank index 5 >= vocab_size 3 must trip the check_inputs assertion.
    #[test]
    #[should_panic(expected = "blank index")]
    fn panics_on_invalid_blank() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_blank(5).init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([1, 2, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([2], &dev),
            Tensor::<B, 1, Int>::from_data([1], &dev),
        );
    }
    // targets has batch 1 while logits has batch 2.
    #[test]
    #[should_panic(expected = "must equal batch_size")]
    fn panics_on_batch_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([3, 3], &dev),
            Tensor::<B, 1, Int>::from_data([1, 1], &dev),
        );
    }
    // logit_lengths has 1 entry for a batch of 2.
    #[test]
    #[should_panic(expected = "logit_lengths length")]
    fn panics_on_logit_lengths_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32], [2]], &dev),
            Tensor::<B, 1, Int>::from_data([3], &dev),
            Tensor::<B, 1, Int>::from_data([1, 1], &dev),
        );
    }
    // target_lengths has 1 entry for a batch of 2.
    #[test]
    #[should_panic(expected = "target_lengths length")]
    fn panics_on_target_lengths_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32], [2]], &dev),
            Tensor::<B, 1, Int>::from_data([3, 3], &dev),
            Tensor::<B, 1, Int>::from_data([1], &dev),
        );
    }
    // With uniform log-probs over v symbols, every alignment path emits
    // (time_steps + target_len) symbols, each with probability 1/v. For a
    // single-label target the label can be emitted at any of the time_steps
    // positions, so total P = time_steps * v^-(time_steps + target_len) and
    // the expected loss is -ln(P); compare against the DP result.
    #[test]
    fn single_token_uniform_probs() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_logits(false).init();
        let time_steps = 2;
        let target_len = 1;
        let v = NUM_LABELS as f32;
        let log_uniform = (1.0 / v).ln();
        let loss = rnnt.forward(
            Tensor::<B, 4>::full(
                [1, time_steps, target_len + 1, NUM_LABELS],
                log_uniform,
                &dev,
            ),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev),
            Tensor::<B, 1, Int>::from_data([target_len as i64], &dev),
        );
        let num_paths = time_steps as f32;
        let emissions_per_path = (time_steps + target_len) as f32;
        let total_prob = num_paths * v.powf(-emissions_per_path);
        let expected_loss = -total_prob.ln();
        loss.into_data().assert_approx_eq::<f32>(
            &TensorData::from([expected_loss]),
            Tolerance::absolute(1e-4),
        );
    }
    // An empty target admits exactly one path: all blanks. With uniform
    // log-probs the loss is -ln(v^-time_steps).
    #[test]
    fn empty_target() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_logits(false).init();
        let time_steps = 3;
        let target_len = 0;
        let v = NUM_LABELS as f32;
        let log_uniform = (1.0 / v).ln();
        let loss = rnnt.forward(
            Tensor::<B, 4>::full([1, time_steps, 2, NUM_LABELS], log_uniform, &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev),
            Tensor::<B, 1, Int>::from_data([target_len as i64], &dev),
        );
        let expected_loss = -v.powf(-((time_steps + target_len) as f32)).ln();
        loss.into_data().assert_approx_eq::<f32>(
            &TensorData::from([expected_loss]),
            Tolerance::absolute(1e-4),
        );
    }
    // logits = true (internal log_softmax) must match applying log_softmax
    // manually and running with logits = false.
    #[test]
    fn logits_equivalence() {
        let dev = FlexDevice;
        let [bs, time_steps, up1, vocab] = [1, 2, 3, 4];
        let num_elements = bs * time_steps * up1 * vocab;
        let target_len = up1 - 1;
        // Deterministic, non-uniform values so the softmax actually matters.
        let data: Vec<f32> = (0..num_elements).map(|i| (i as f32 * 0.3).sin()).collect();
        let logits = Tensor::<B, 4>::from_data(
            burn_core::tensor::TensorData::new(data, [bs, time_steps, up1, vocab]),
            &dev,
        );
        let targets = Tensor::<B, 2, Int>::from_data([[1_i32, 2]], &dev);
        let logit_lengths = Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev);
        let target_lengths = Tensor::<B, 1, Int>::from_data([target_len as i64], &dev);
        let vocab_dim = 3;
        let fused = RNNTLossConfig::new().with_logits(true).init().forward(
            logits.clone(),
            targets.clone(),
            logit_lengths.clone(),
            target_lengths.clone(),
        );
        let log_probs = burn::tensor::activation::log_softmax(logits, vocab_dim);
        let manual = RNNTLossConfig::new().with_logits(false).init().forward(
            log_probs,
            targets,
            logit_lengths,
            target_lengths,
        );
        fused
            .into_data()
            .assert_approx_eq::<f32>(&manual.into_data(), Tolerance::absolute(1e-4));
    }
}
// Tests comparing loss values and gradients against fixed reference numbers
// (presumably generated with PyTorch/torchaudio's rnnt_loss on the same
// deterministic inputs — the generating script is not part of this file).
#[cfg(test)]
#[allow(clippy::identity_op, clippy::too_many_arguments)]
mod pytorch_comparison_tests {
    use super::*;
    use burn::tensor::{TensorData, Tolerance};
    use burn_autodiff::Autodiff;
    use burn_flex::{Flex, FlexDevice};
    type B = Autodiff<Flex>;
    fn tol() -> Tolerance<f32> {
        Tolerance::absolute(1e-3)
    }
    /// Deterministic pseudo-random logits of shape [bs, t, u, v], so the same
    /// values can be reproduced in whatever produced the reference numbers.
    fn make_logits(bs: usize, t: usize, u: usize, v: usize, dev: &FlexDevice) -> Tensor<B, 4> {
        let mut data = Vec::with_capacity(bs * t * u * v);
        for bi in 0..bs {
            for ti in 0..t {
                for ui in 0..u {
                    for vi in 0..v {
                        let idx = bi * 11 + ti * 7 + ui * 13 + vi * 3;
                        data.push((idx as f32 * 0.1).sin());
                    }
                }
            }
        }
        Tensor::from_data(TensorData::new(data, [bs, t, u, v]), dev)
    }
    /// Sanity check: the gradient of loss w.r.t. the logits sums to zero over
    /// the vocabulary axis at every lattice node (a consequence of the
    /// softmax parameterization), so each length-v row of the flat gradient
    /// buffer must sum to ~0.
    fn check_vocab_grad_sums(grad: &[f32], bs: usize, t: usize, up1: usize, v: usize) {
        for bi in 0..bs {
            for ti in 0..t {
                for ui in 0..up1 {
                    let base = ((bi * t + ti) * up1 + ui) * v;
                    let sum: f32 = (0..v).map(|vi| grad[base + vi]).sum();
                    TensorData::from([sum])
                        .assert_approx_eq::<f32>(&TensorData::from([0.0f32]), tol());
                }
            }
        }
    }
    /// Slice of the flat row-major gradient buffer at node (b, t, u):
    /// the v gradient entries over the vocabulary axis.
    fn grad_at(
        grad: &[f32],
        b: usize,
        t: usize,
        u: usize,
        max_t: usize,
        up1: usize,
        v: usize,
    ) -> &[f32] {
        let base = ((b * max_t + t) * up1 + u) * v;
        &grad[base..base + v]
    }
    /// Assert the vocabulary-axis gradient row at (b, t, u) matches
    /// `expected` within `tol()`.
    fn assert_grad(
        grad: &[f32],
        b: usize,
        t: usize,
        u: usize,
        max_t: usize,
        up1: usize,
        v: usize,
        expected: &[f32],
    ) {
        TensorData::from(grad_at(grad, b, t, u, max_t, up1, v))
            .assert_approx_eq::<f32>(&TensorData::from(expected), tol());
    }
    // Batch of 1, full lengths: loss and spot-checked gradient rows against
    // reference values.
    #[test]
    fn basic_b1() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(1, 4, 3, 3, &dev).require_grad();
        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data([[1_i32, 2]], &dev),
            Tensor::<B, 1, Int>::from_data([4_i32], &dev),
            Tensor::<B, 1, Int>::from_data([2_i32], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([4.4491f32]), tol());
        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        assert_grad(&grad, 0, 0, 0, 4, 3, 3, &[-0.2041, -0.2246, 0.4287]);
        assert_grad(&grad, 0, 2, 0, 4, 3, 3, &[0.0079, -0.0640, 0.0561]);
        assert_grad(&grad, 0, 3, 2, 4, 3, 3, &[-0.6899, 0.3231, 0.3667]);
        check_vocab_grad_sums(&grad, 1, 4, 3, 3);
    }
    // Batch of 2 with equal lengths: per-sample losses plus first/last-node
    // gradient rows for each sample.
    #[test]
    fn batched_b2() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();
        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data(
                TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
                &dev,
            ),
            Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev),
            Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([7.9356f32, 7.2033]), tol());
        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        assert_grad(&grad, 0, 0, 0, 5, 4, 4, &[-0.3161, -0.3113, 0.2796, 0.3479]);
        assert_grad(&grad, 1, 0, 0, 5, 4, 4, &[-0.2766, 0.2602, -0.2248, 0.2411]);
        assert_grad(&grad, 0, 4, 3, 5, 4, 4, &[-0.8216, 0.2296, 0.2786, 0.3133]);
        assert_grad(&grad, 1, 4, 3, 5, 4, 4, &[-0.7185, 0.2735, 0.2437, 0.2012]);
        check_vocab_grad_sums(&grad, 2, 5, 4, 4);
    }
    // Batch of 3 with ragged logit/target lengths: checks reference losses,
    // selected gradient rows, and that padded regions (t beyond a sample's
    // logit length, u beyond its target length) receive zero gradient.
    #[test]
    fn variable_lengths_b3() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(3, 6, 4, 5, &dev).require_grad();
        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data(
                TensorData::new(vec![1_i32, 2, 3, 4, 1, 0, 2, 0, 0], [3, 3]),
                &dev,
            ),
            Tensor::<B, 1, Int>::from_data([6_i32, 4, 5], &dev),
            Tensor::<B, 1, Int>::from_data([3_i32, 2, 1], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([10.7458f32, 8.0196, 8.3316]), tol());
        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        // stride = up1 * v: flat elements per (b, t) slice.
        let stride = 4 * 5;
        let zeros = vec![0.0f32; 5];
        assert_grad(
            &grad,
            0,
            0,
            0,
            6,
            4,
            5,
            &[-0.4232, -0.3114, 0.1992, 0.2478, 0.2876],
        );
        assert_grad(
            &grad,
            0,
            5,
            3,
            6,
            4,
            5,
            &[-0.8016, 0.2170, 0.2172, 0.1991, 0.1683],
        );
        assert_grad(
            &grad,
            1,
            0,
            0,
            6,
            4,
            5,
            &[-0.2502, 0.2160, 0.2173, 0.2002, -0.3833],
        );
        // Sample 1 has logit length 4 in a max_t of 6: t = 4 and 5 are padding.
        let sample1_t4_start = 1 * 6 * stride + 4 * stride;
        for i in 0..(2 * stride) {
            assert!(
                grad[sample1_t4_start + i].abs() < 1e-3,
                "sample 1, t>=4: grad[{}] = {} (expected 0)",
                i,
                grad[sample1_t4_start + i]
            );
        }
        // Sample 1 has target length 2: label row u = 3 is padding everywhere.
        for ti in 0..4 {
            assert_grad(&grad, 1, ti, 3, 6, 4, 5, &zeros);
        }
        // Sample 2 has logit length 5: t = 5 is padding.
        let sample2_t5_start = 2 * 6 * stride + 5 * stride;
        for i in 0..stride {
            assert!(
                grad[sample2_t5_start + i].abs() < 1e-3,
                "sample 2, t=5: grad[{}] = {} (expected 0)",
                i,
                grad[sample2_t5_start + i]
            );
        }
        check_vocab_grad_sums(&grad, 3, 6, 4, 5);
    }
    // Sum reduction: scalar loss equals the sum of batched_b2's per-sample
    // losses, and per-element gradients are unchanged (d(sum)/d(loss_i) = 1).
    #[test]
    fn sum_reduction() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();
        let tgt = Tensor::<B, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
            &dev,
        );
        let il = Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev);
        let tl = Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev);
        let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Sum);
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([15.1389f32]), tol());
        let grads = loss.backward();
        let g = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        TensorData::from(&g[..4]).assert_approx_eq::<f32>(
            &TensorData::from([-0.3161f32, -0.3113, 0.2796, 0.3479]),
            tol(),
        );
    }
    // Mean reduction: half the sum-reduction loss for batch 2, and gradients
    // scaled by 1/2 accordingly.
    #[test]
    fn mean_reduction() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();
        let tgt = Tensor::<B, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
            &dev,
        );
        let il = Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev);
        let tl = Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev);
        let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Mean);
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([7.5694f32]), tol());
        let grads = loss.backward();
        let g = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        TensorData::from(&g[..4]).assert_approx_eq::<f32>(
            &TensorData::from([-0.1581f32, -0.1557, 0.1398, 0.1739]),
            tol(),
        );
    }
}