// burn_nn/loss/rnnt.rs

use super::Reduction;
use alloc::vec;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Bool, Int, Tensor, backend::Backend, s};
use burn_core as burn;
use core::f32;

/// Configuration for [RNNTLoss](RNNTLoss).
#[derive(Config, Debug)]
pub struct RNNTLossConfig {
    /// Index of the blank label in the vocabulary. Default: `0`.
    #[config(default = 0)]
    pub blank: usize,
    /// Treat the input as raw logits, applying a log-softmax on the last dimension internally.
    /// If `false`, the input must already contain log-probabilities. Default: `true`.
    #[config(default = true)]
    pub logits: bool,
}

impl RNNTLossConfig {
    /// Initializes a [RNNTLoss](RNNTLoss) module.
    pub fn init(&self) -> RNNTLoss {
        RNNTLoss {
            blank: self.blank,
            logits: self.logits,
        }
    }
}

/// RNN Transducer (RNNT) loss, as described in
/// [Sequence Transduction with Recurrent Neural Networks](https://arxiv.org/abs/1211.3711).
///
/// Computes the negative log-likelihood over a 2D lattice of encoder time steps (T)
/// and output labels (U), marginalizing over all valid alignments.
///
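/// For a single sample with encoder length `T` and target length `U`, the quantity the
/// forward recurrence accumulates can be sketched as
///
///   loss = -ln p(y | x) = -(alpha(T - 1, U) + blank(T - 1, U))
///
/// where `alpha(t, u)` is the log-probability of having emitted the first `u` labels by
/// encoder frame `t`, and `blank(t, u)` is the blank log-probability at that lattice node.
///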
/// # Example
///
/// ```rust,ignore
/// let rnnt = RNNTLossConfig::new().init();
///
/// // logits: [B, T, U+1, V] from the joiner network
/// let loss = rnnt.forward(logits, targets, logit_lengths, target_lengths);
/// ```
#[derive(Module, Clone, Debug)]
pub struct RNNTLoss {
    blank: usize,
    logits: bool,
}

impl RNNTLoss {
    /// Computes per-sample RNNT loss (no reduction). Returns shape `[B]`.
    ///
    /// - `logits`: `[B, T, U+1, V]` — joiner output (raw logits or log-probs)
    /// - `targets`: `[B, U]` — target label indices (must not contain blank)
    /// - `logit_lengths`: `[B]` — actual encoder lengths per sample
    /// - `target_lengths`: `[B]` — actual target lengths per sample
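    ///
    /// A shape sketch with illustrative sizes (batch of 2, 10 encoder frames, up to 4 labels,
    /// vocabulary of 32 including blank); the tensors are placeholders, not meaningful data:
    ///
    /// ```rust,ignore
    /// let logits = Tensor::<B, 4>::zeros([2, 10, 4 + 1, 32], &device);
    /// // Entries past `target_lengths[b]` are padding and do not affect the loss.
    /// let targets = Tensor::<B, 2, Int>::from_data([[1_i32, 2, 3, 1], [2, 1, 3, 0]], &device);
    /// let logit_lengths = Tensor::<B, 1, Int>::from_data([10, 8], &device);
    /// let target_lengths = Tensor::<B, 1, Int>::from_data([4, 3], &device);
    ///
    /// // Per-sample negative log-likelihood, shape [2].
    /// let loss = rnnt.forward(logits, targets, logit_lengths, target_lengths);
    /// ```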
    pub fn forward<B: Backend>(
        &self,
        logits: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let device = logits.device();
        let [b, max_t, max_up1, v] = logits.dims();
        let max_u = max_up1 - 1;

        self.check_inputs(b, v, &targets, &logit_lengths, &target_lengths, max_u);

        let log_probs = if self.logits {
            let vocab_dim = 3; // last dim of [B, T, U+1, V]
            burn::tensor::activation::log_softmax(logits, vocab_dim)
        } else {
            logits
        };

        let (lpb, lpl) = self.extract_log_probs(log_probs, targets);
        let u_mask = self.create_u_mask(&target_lengths, b, max_up1, &device);
        let neg_inf = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, &device);

        // Forward pass: compute log_alpha across the (T, U) lattice
        let mut alpha = self.init_alpha(&lpl, b, max_up1, &device);
        alpha = neg_inf.clone().mask_where(u_mask.clone(), alpha);

        let logit_lengths_exp = logit_lengths.clone().reshape([b, 1]).expand([b, max_up1]);

        for t in 1..max_t {
            let new = self.step_alpha(&alpha, &lpb, &lpl, t);
            let new = neg_inf.clone().mask_where(u_mask.clone(), new);

            // Only update alpha for samples where t < logit_lengths[b]
            let valid = logit_lengths_exp.clone().greater_elem(t as i64);
            alpha = alpha.mask_where(valid, new);
        }

        self.gather_loss(alpha, &lpb, logit_lengths, target_lengths, b)
    }

    /// Computes RNNT loss with the given reduction. Returns shape `[1]`.
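    ///
    /// A minimal usage sketch (the arguments match [`Self::forward`], plus the reduction):
    ///
    /// ```rust,ignore
    /// let loss = rnnt.forward_with_reduction(
    ///     logits,
    ///     targets,
    ///     logit_lengths,
    ///     target_lengths,
    ///     Reduction::Mean,
    /// );
    /// ```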
    pub fn forward_with_reduction<B: Backend>(
        &self,
        logits: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward(logits, targets, logit_lengths, target_lengths);
        match reduction {
            Reduction::Auto | Reduction::Mean => loss.mean(),
            Reduction::Sum => loss.sum(),
            other => panic!("{other:?} reduction is not supported"),
        }
    }

    /// Gathers `log_prob_blank[B, T, U+1]` and `log_prob_label[B, T, U]` from the full
    /// log-probability tensor by indexing into the vocab dimension.
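    ///
    /// In index form: `lpb[b, t, u] = log_probs[b, t, u, blank]` and
    /// `lpl[b, t, u] = log_probs[b, t, u, targets[b, u]]`.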
    fn extract_log_probs<B: Backend>(
        &self,
        log_probs: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
    ) -> (Tensor<B, 3>, Tensor<B, 3>) {
        let [b, max_t, max_up1, v] = log_probs.dims();
        let max_u = max_up1 - 1;
        let vocab_dim = 3;

        // Blank probabilities: slice log_probs in vocab dim using the blank index
        let lpb = log_probs
            .clone()
            .slice_dim(vocab_dim, self.blank)
            .squeeze_dim::<3>(vocab_dim);

        // Label probabilities: gather target labels across vocab dim (only first U positions)
        let tgt = targets
            .reshape([b, 1, max_u, 1])
            .expand([b, max_t, max_u, 1]);
        let lpl = log_probs
            .slice(s![.., .., 0..max_u, 0..v])
            .gather(vocab_dim, tgt)
            .squeeze_dim::<3>(vocab_dim);

        (lpb, lpl)
    }

    /// Sets up log_alpha at t=0: `alpha(0,0) = 0`, then cumsum of label probs along u.
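    ///
    /// Equivalently, `alpha(0, u) = label(0, 0) + ... + label(0, u - 1)`: the only way to
    /// reach lattice node `(0, u)` is to emit the first `u` labels at the first encoder frame.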
    fn init_alpha<B: Backend>(
        &self,
        lpl: &Tensor<B, 3>,
        b: usize,
        max_up1: usize,
        device: &B::Device,
    ) -> Tensor<B, 2> {
        // Label probs at t=0
        let lpl_0 = lpl.clone().slice(s![.., 0..1, ..]).squeeze_dim::<2>(1);
        let zero_col = Tensor::<B, 2>::zeros([b, 1], device);
        let prefix = Tensor::cat(vec![zero_col, lpl_0.slice(s![.., 0..(max_up1 - 1)])], 1);

        prefix.cumsum(1)
    }

    /// Boolean mask `[B, U+1]` that is true where `u <= target_lengths[b]`.
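    ///
    /// The comparison is inclusive: for a sample with `U_b` labels, rows `u = 0..=U_b` of the
    /// lattice are reachable, so exactly `U_b + 1` entries remain unmasked.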
    fn create_u_mask<B: Backend>(
        &self,
        target_lengths: &Tensor<B, 1, Int>,
        b: usize,
        max_up1: usize,
        device: &B::Device,
    ) -> Tensor<B, 2, Bool> {
        let indices = Tensor::<B, 1, Int>::arange(0..max_up1 as i64, device)
            .reshape([1, max_up1])
            .expand([b, max_up1]);
        let lengths = target_lengths.clone().reshape([b, 1]).expand([b, max_up1]);
        indices.lower_equal(lengths)
    }

    /// One time step of the forward recurrence:
    ///
    ///   alpha(t, u) = logaddexp(
    ///       alpha(t-1, u) + blank(t-1, u),
    ///       alpha(t, u-1) + label(t, u-1),
    ///   )
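    ///
    /// with the boundary case `alpha(t, 0) = alpha(t-1, 0) + blank(t-1, 0)`, since `u = 0`
    /// can only be reached by emitting blanks.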
    fn step_alpha<B: Backend>(
        &self,
        alpha: &Tensor<B, 2>,
        lpb: &Tensor<B, 3>,
        lpl: &Tensor<B, 3>,
        t: usize,
    ) -> Tensor<B, 2> {
        let [b, max_up1] = alpha.dims();
        let device = alpha.device();

        // Blank transition: alpha(t-1, :) + blank_prob(t-1, :)
        let blank_prob = lpb
            .clone()
            .slice(s![.., (t - 1)..t, ..])
            .squeeze_dim::<2>(1);
        let from_blank = alpha.clone().add(blank_prob);

        let mut new = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, &device);
        new = new.slice_assign(s![.., 0..1], from_blank.clone().slice(s![.., 0..1]));

        // Label probs at time t
        let label_prob = lpl
            .clone()
            .slice(s![.., t..(t + 1), ..])
            .squeeze_dim::<2>(1);

        for u in 1..max_up1 {
            let via_blank = from_blank.clone().slice(s![.., u..(u + 1)]);
            let via_label = new
                .clone()
                .slice(s![.., (u - 1)..u])
                .add(label_prob.clone().slice(s![.., (u - 1)..u]));
            new = new.slice_assign(s![.., u..(u + 1)], self.log_sum_exp(via_blank, via_label));
        }
        new
    }

    /// Extracts `-(alpha(T_b - 1, U_b) + blank(T_b - 1, U_b))` for each sample `b`, where
    /// `T_b` and `U_b` are that sample's encoder and target lengths (time is zero-indexed).
    fn gather_loss<B: Backend>(
        &self,
        alpha: Tensor<B, 2>,
        lpb: &Tensor<B, 3>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
        b: usize,
    ) -> Tensor<B, 1> {
        let device = alpha.device();
        // Anchor the index dtype on `u_idx` so all three coordinate tensors share a
        // common int dtype before stacking; `cat` panics on dtype mismatch and the
        // caller's lengths may not use the device's default IntElem.
        let u_idx = target_lengths;
        let int_dtype = u_idx.dtype();
        let t_idx = logit_lengths.sub_scalar(1).cast(int_dtype);
        let b_idx = Tensor::<B, 1, Int>::arange(0..b as i64, (&device, int_dtype));

        let alpha_tu: Tensor<B, 1> =
            alpha.gather_nd(Tensor::stack::<2>(vec![b_idx.clone(), u_idx.clone()], 1));
        let lpb_tu: Tensor<B, 1> = lpb
            .clone()
            .gather_nd(Tensor::stack::<2>(vec![b_idx, t_idx, u_idx], 1));

        alpha_tu.add(lpb_tu).neg()
    }

    fn check_inputs<B: Backend>(
        &self,
        b: usize,
        v: usize,
        targets: &Tensor<B, 2, Int>,
        logit_lengths: &Tensor<B, 1, Int>,
        target_lengths: &Tensor<B, 1, Int>,
        max_u: usize,
    ) {
        assert!(
            self.blank < v,
            "blank index {} must be less than vocab_size {}",
            self.blank,
            v
        );
        assert_eq!(
            targets.dims()[0],
            b,
            "targets batch dimension {} must equal batch_size {}",
            targets.dims()[0],
            b
        );
        assert_eq!(
            targets.dims()[1],
            max_u,
            "targets length dimension {} must equal max_target_len (max_u) {}",
            targets.dims()[1],
            max_u
        );
        assert_eq!(
            logit_lengths.dims()[0],
            b,
            "logit_lengths length {} must equal batch_size {}",
            logit_lengths.dims()[0],
            b
        );
        assert_eq!(
            target_lengths.dims()[0],
            b,
            "target_lengths length {} must equal batch_size {}",
            target_lengths.dims()[0],
            b
        );
    }

    /// Numerically stable `log(exp(a) + exp(b))`, handling `-inf` inputs.
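    ///
    /// Uses the identity `log(exp(a) + exp(b)) = max(a, b) + log(1 + exp(-|a - b|))`, then
    /// patches the `-inf` cases so that `logaddexp(-inf, x) = x` and
    /// `logaddexp(-inf, -inf) = -inf`.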
    fn log_sum_exp<const D: usize, B: Backend>(
        &self,
        a: Tensor<B, D>,
        b: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let a_inf = a.clone().equal_elem(f32::NEG_INFINITY);
        let b_inf = b.clone().equal_elem(f32::NEG_INFINITY);

        // Replace -inf with 0 to prevent NaN in the subtraction (masked out below)
        let a_safe = a.clone().mask_fill(a_inf.clone(), 0.0);
        let b_safe = b.clone().mask_fill(b_inf.clone(), 0.0);

        // log(exp(a) + exp(b)) = max(a,b) + log(1 + exp(-|a-b|))
        let max = a_safe.clone().max_pair(b_safe.clone());
        let result = max.add(a_safe.sub(b_safe).abs().neg().exp().add_scalar(1.0).log());

        // If a=-inf, result is b; if b=-inf, result is a; if both -inf, stays -inf
        let result = result.mask_where(a_inf, b);
        result.mask_where(b_inf, a)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{TensorData, Tolerance};
    use burn_flex::{Flex, FlexDevice};

    type B = Flex;
    const NUM_LABELS: usize = 2; // vocab size for simple unit tests

    #[test]
    fn config_defaults() {
        let cfg = RNNTLossConfig::new();
        assert_eq!(cfg.blank, 0);
        assert!(cfg.logits);
    }

    #[test]
    #[should_panic(expected = "blank index")]
    fn panics_on_invalid_blank() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_blank(5).init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([1, 2, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([2], &dev),
            Tensor::<B, 1, Int>::from_data([1], &dev),
        );
    }

    #[test]
    #[should_panic(expected = "must equal batch_size")]
    fn panics_on_batch_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([3, 3], &dev),
            Tensor::<B, 1, Int>::from_data([1, 1], &dev),
        );
    }

    #[test]
    #[should_panic(expected = "logit_lengths length")]
    fn panics_on_logit_lengths_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32], [2]], &dev),
            Tensor::<B, 1, Int>::from_data([3], &dev),
            Tensor::<B, 1, Int>::from_data([1, 1], &dev),
        );
    }

    #[test]
    #[should_panic(expected = "target_lengths length")]
    fn panics_on_target_lengths_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32], [2]], &dev),
            Tensor::<B, 1, Int>::from_data([3, 3], &dev),
            Tensor::<B, 1, Int>::from_data([1], &dev),
        );
    }

    #[test]
    fn single_token_uniform_probs() {
        // B=1, T=2, U=1, V=2, uniform probs: P(blank) = P(label) = 1/V
        //
        // Two alignment paths (label emitted at t=0 or t=1), each with T+U emissions:
        //   total_prob = T * (1/V)^(T+U) = 2 * (1/2)^3 = 1/4
        //   loss = -ln(1/4) = 2*ln(2)
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_logits(false).init();
        let time_steps = 2;
        let target_len = 1;
        let v = NUM_LABELS as f32;
        let log_uniform = (1.0 / v).ln();

        let loss = rnnt.forward(
            Tensor::<B, 4>::full(
                [1, time_steps, target_len + 1, NUM_LABELS],
                log_uniform,
                &dev,
            ),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev),
            Tensor::<B, 1, Int>::from_data([target_len as i64], &dev),
        );
        // Each path: T-1 blanks + U labels + 1 final blank = T + U emissions
        let num_paths = time_steps as f32;
        let emissions_per_path = (time_steps + target_len) as f32;
        let total_prob = num_paths * v.powf(-emissions_per_path);
        let expected_loss = -total_prob.ln();
        loss.into_data().assert_approx_eq::<f32>(
            &TensorData::from([expected_loss]),
            Tolerance::absolute(1e-4),
        );
    }

    #[test]
    fn empty_target() {
        // B=1, T=3, U=0, V=2, uniform probs: only the all-blanks path exists.
        //
        // Single path with T emissions (T-1 blanks + 1 final blank, all at u=0):
        //   total_prob = (1/V)^T = (1/2)^3 = 1/8
        //   loss = T*ln(V) = 3*ln(2)
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_logits(false).init();
        let time_steps = 3;
        let target_len = 0;
        let v = NUM_LABELS as f32;
        let log_uniform = (1.0 / v).ln();

        let loss = rnnt.forward(
            Tensor::<B, 4>::full([1, time_steps, 2, NUM_LABELS], log_uniform, &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev),
            Tensor::<B, 1, Int>::from_data([target_len as i64], &dev),
        );
        // T + U = T emissions total for U=0
        let expected_loss = -v.powf(-((time_steps + target_len) as f32)).ln();
        loss.into_data().assert_approx_eq::<f32>(
            &TensorData::from([expected_loss]),
            Tolerance::absolute(1e-4),
        );
    }

    #[test]
    fn logits_equivalence() {
        // Verify that logits=true (internal log_softmax on raw logits)
        // gives the same loss as logits=false with external log_softmax.
        let dev = FlexDevice;
        let [bs, time_steps, up1, vocab] = [1, 2, 3, 4];
        let num_elements = bs * time_steps * up1 * vocab;
        let target_len = up1 - 1;

        let data: Vec<f32> = (0..num_elements).map(|i| (i as f32 * 0.3).sin()).collect();
        let logits = Tensor::<B, 4>::from_data(
            burn_core::tensor::TensorData::new(data, [bs, time_steps, up1, vocab]),
            &dev,
        );
        let targets = Tensor::<B, 2, Int>::from_data([[1_i32, 2]], &dev);
        let logit_lengths = Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev);
        let target_lengths = Tensor::<B, 1, Int>::from_data([target_len as i64], &dev);

        let vocab_dim = 3;
        let fused = RNNTLossConfig::new().with_logits(true).init().forward(
            logits.clone(),
            targets.clone(),
            logit_lengths.clone(),
            target_lengths.clone(),
        );

        let log_probs = burn::tensor::activation::log_softmax(logits, vocab_dim);
        let manual = RNNTLossConfig::new().with_logits(false).init().forward(
            log_probs,
            targets,
            logit_lengths,
            target_lengths,
        );

        fused
            .into_data()
            .assert_approx_eq::<f32>(&manual.into_data(), Tolerance::absolute(1e-4));
    }
}

/// Tests comparing forward loss and backward gradients against torchaudio.functional.rnnt_loss.
///
/// Logits are generated deterministically via sin((b*11+t*7+u*13+v*3)*0.1) so the same
/// values can be reproduced in a Python script for cross-checking.
#[cfg(test)]
#[allow(clippy::identity_op, clippy::too_many_arguments)]
mod pytorch_comparison_tests {
    use super::*;
    use burn::tensor::{TensorData, Tolerance};
    use burn_autodiff::Autodiff;
    use burn_flex::{Flex, FlexDevice};

    type B = Autodiff<Flex>;
    fn tol() -> Tolerance<f32> {
        Tolerance::absolute(1e-3)
    }

    /// Deterministic logits matching the Python reference generator.
    /// Uses coprime coefficients to avoid repeating patterns across dimensions.
    fn make_logits(bs: usize, t: usize, u: usize, v: usize, dev: &FlexDevice) -> Tensor<B, 4> {
        let mut data = Vec::with_capacity(bs * t * u * v);
        for bi in 0..bs {
            for ti in 0..t {
                for ui in 0..u {
                    for vi in 0..v {
                        let idx = bi * 11 + ti * 7 + ui * 13 + vi * 3;
                        data.push((idx as f32 * 0.1).sin());
                    }
                }
            }
        }
        Tensor::from_data(TensorData::new(data, [bs, t, u, v]), dev)
    }

    /// Checks that gradients along the vocab dim sum to ~0 at every (b, t, u) position.
    /// This must hold because log_softmax is applied on the last dim,
    /// and the Jacobian of log_softmax has the property that each row sums to zero.
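    ///
    /// (With `y = log_softmax(x)`, the chain rule gives
    /// `dL/dx[j] = dL/dy[j] - softmax(x)[j] * sum_i dL/dy[i]`, which sums to zero over `j`.)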
    fn check_vocab_grad_sums(grad: &[f32], bs: usize, t: usize, up1: usize, v: usize) {
        for bi in 0..bs {
            for ti in 0..t {
                for ui in 0..up1 {
                    let base = ((bi * t + ti) * up1 + ui) * v;
                    let sum: f32 = (0..v).map(|vi| grad[base + vi]).sum();
                    TensorData::from([sum])
                        .assert_approx_eq::<f32>(&TensorData::from([0.0f32]), tol());
                }
            }
        }
    }

    /// Returns the V-sized gradient slice at position (b, t, u) in a flattened [B, T, U+1, V] grad.
    fn grad_at(
        grad: &[f32],
        b: usize,
        t: usize,
        u: usize,
        max_t: usize,
        up1: usize,
        v: usize,
    ) -> &[f32] {
        let base = ((b * max_t + t) * up1 + u) * v;
        &grad[base..base + v]
    }

    /// Asserts that a gradient slice at position (b, t, u) matches expected values.
    fn assert_grad(
        grad: &[f32],
        b: usize,
        t: usize,
        u: usize,
        max_t: usize,
        up1: usize,
        v: usize,
        expected: &[f32],
    ) {
        TensorData::from(grad_at(grad, b, t, u, max_t, up1, v))
            .assert_approx_eq::<f32>(&TensorData::from(expected), tol());
    }

    #[test]
    fn basic_b1() {
        // B=1, T=4, U+1=3, V=3, targets=[1,2]
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(1, 4, 3, 3, &dev).require_grad();

        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data([[1_i32, 2]], &dev),
            Tensor::<B, 1, Int>::from_data([4_i32], &dev),
            Tensor::<B, 1, Int>::from_data([2_i32], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([4.4491f32]), tol());

        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        // Spot-check first, middle, and last (t, u) positions against torchaudio
        assert_grad(&grad, 0, 0, 0, 4, 3, 3, &[-0.2041, -0.2246, 0.4287]);
        assert_grad(&grad, 0, 2, 0, 4, 3, 3, &[0.0079, -0.0640, 0.0561]);
        assert_grad(&grad, 0, 3, 2, 4, 3, 3, &[-0.6899, 0.3231, 0.3667]);
        check_vocab_grad_sums(&grad, 1, 4, 3, 3);
    }

    #[test]
    fn batched_b2() {
        // B=2, T=5, U+1=4, V=4, targets=[[1,2,3],[2,1,3]]
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();

        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data(
                TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
                &dev,
            ),
            Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev),
            Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([7.9356f32, 7.2033]), tol());

        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        // Spot-check: first position of each sample, and last position
        assert_grad(&grad, 0, 0, 0, 5, 4, 4, &[-0.3161, -0.3113, 0.2796, 0.3479]);
        assert_grad(&grad, 1, 0, 0, 5, 4, 4, &[-0.2766, 0.2602, -0.2248, 0.2411]);
        assert_grad(&grad, 0, 4, 3, 5, 4, 4, &[-0.8216, 0.2296, 0.2786, 0.3133]);
        assert_grad(&grad, 1, 4, 3, 5, 4, 4, &[-0.7185, 0.2735, 0.2437, 0.2012]);
        check_vocab_grad_sums(&grad, 2, 5, 4, 4);
    }

    #[test]
    fn variable_lengths_b3() {
        // B=3, T=6, U+1=4, V=5
        // logit_lengths=[6,4,5], target_lengths=[3,2,1]
        // Tests that masking works correctly for variable-length sequences.
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(3, 6, 4, 5, &dev).require_grad();

        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data(
                TensorData::new(vec![1_i32, 2, 3, 4, 1, 0, 2, 0, 0], [3, 3]),
                &dev,
            ),
            Tensor::<B, 1, Int>::from_data([6_i32, 4, 5], &dev),
            Tensor::<B, 1, Int>::from_data([3_i32, 2, 1], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([10.7458f32, 8.0196, 8.3316]), tol());

        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        let stride = 4 * 5; // U+1 * V per time step
        let zeros = vec![0.0f32; 5];

        // Sample 0 (full length=6): spot-check first and last active positions
        assert_grad(
            &grad,
            0,
            0,
            0,
            6,
            4,
            5,
            &[-0.4232, -0.3114, 0.1992, 0.2478, 0.2876],
        );
        assert_grad(
            &grad,
            0,
            5,
            3,
            6,
            4,
            5,
            &[-0.8016, 0.2170, 0.2172, 0.1991, 0.1683],
        );

        // Sample 1 (logit_length=4): gradients beyond t=3 should be zero
        assert_grad(
            &grad,
            1,
            0,
            0,
            6,
            4,
            5,
            &[-0.2502, 0.2160, 0.2173, 0.2002, -0.3833],
        );
        let sample1_t4_start = 1 * 6 * stride + 4 * stride;
        for i in 0..(2 * stride) {
            // t=4 and t=5 should all be zero
            assert!(
                grad[sample1_t4_start + i].abs() < 1e-3,
                "sample 1, t>=4: grad[{}] = {} (expected 0)",
                i,
                grad[sample1_t4_start + i]
            );
        }

        // Sample 1 (target_length=2): u=3 positions should be zero within active time steps
        for ti in 0..4 {
            assert_grad(&grad, 1, ti, 3, 6, 4, 5, &zeros);
        }

        // Sample 2 (logit_length=5): t=5 should be zero
        let sample2_t5_start = 2 * 6 * stride + 5 * stride;
        for i in 0..stride {
            assert!(
                grad[sample2_t5_start + i].abs() < 1e-3,
                "sample 2, t=5: grad[{}] = {} (expected 0)",
                i,
                grad[sample2_t5_start + i]
            );
        }

        check_vocab_grad_sums(&grad, 3, 6, 4, 5);
    }

    #[test]
    fn sum_reduction() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();
        let tgt = Tensor::<B, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
            &dev,
        );
        let il = Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev);
        let tl = Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev);

        let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Sum);
        // 7.9356 + 7.2033 = 15.1389
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([15.1389f32]), tol());

        let grads = loss.backward();
        let g = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        TensorData::from(&g[..4]).assert_approx_eq::<f32>(
            &TensorData::from([-0.3161f32, -0.3113, 0.2796, 0.3479]),
            tol(),
        );
    }

    #[test]
    fn mean_reduction() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();
        let tgt = Tensor::<B, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
            &dev,
        );
        let il = Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev);
        let tl = Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev);

        let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Mean);
        // 15.1389 / 2 = 7.5694
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([7.5694f32]), tol());

        // Gradients should be half the sum-reduction gradients (mean over batch of 2)
        let grads = loss.backward();
        let g = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        TensorData::from(&g[..4]).assert_approx_eq::<f32>(
            &TensorData::from([-0.1581f32, -0.1557, 0.1398, 0.1739]),
            tol(),
        );
    }
}