1use super::Reduction;
2use alloc::vec;
3use burn::config::Config;
4use burn::module::Module;
5use burn::tensor::{Bool, Int, Tensor, backend::Backend, s};
6use burn_core as burn;
7use core::f32;
8
/// Configuration for [`RNNTLoss`] (RNN-Transducer loss).
#[derive(Config, Debug)]
pub struct RNNTLossConfig {
    /// Vocabulary index treated as the blank symbol.
    #[config(default = 0)]
    pub blank: usize,
    /// When `true`, `forward` receives raw logits and applies `log_softmax`
    /// over the vocabulary axis internally; when `false`, the inputs are
    /// assumed to already be log-probabilities.
    #[config(default = true)]
    pub logits: bool,
}
20
21impl RNNTLossConfig {
22 pub fn init(&self) -> RNNTLoss {
24 RNNTLoss {
25 blank: self.blank,
26 logits: self.logits,
27 }
28 }
29}
30
/// RNN-Transducer (RNN-T) loss module.
///
/// Construct via [`RNNTLossConfig::init`].
#[derive(Module, Clone, Debug)]
pub struct RNNTLoss {
    /// Vocabulary index of the blank symbol.
    blank: usize,
    /// Whether `forward` inputs are raw logits (log-softmax applied internally).
    logits: bool,
}
50
impl RNNTLoss {
    /// Compute the per-sample RNN-T negative log-likelihood.
    ///
    /// Shapes:
    /// - `logits`: `[batch, max_time, max_target_len + 1, vocab]` — raw logits
    ///   when `self.logits` is `true`, otherwise already log-probabilities;
    /// - `targets`: `[batch, max_target_len]` integer label indices;
    /// - `logit_lengths`: `[batch]` valid time frames per sample;
    /// - `target_lengths`: `[batch]` valid target labels per sample.
    ///
    /// Returns a `[batch]` tensor of losses.
    ///
    /// # Panics
    /// Panics when the input dimensions are inconsistent (see `check_inputs`).
    pub fn forward<B: Backend>(
        &self,
        logits: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let device = logits.device();
        let [b, max_t, max_up1, v] = logits.dims();
        // The label axis carries max_u + 1 rows (u = 0 means "no label emitted yet").
        let max_u = max_up1 - 1;

        self.check_inputs(b, v, &targets, &logit_lengths, &target_lengths, max_u);

        // Normalize over the vocabulary axis only when raw logits were supplied.
        let log_probs = if self.logits {
            let vocab_dim = 3; burn::tensor::activation::log_softmax(logits, vocab_dim)
        } else {
            logits
        };

        // lpb[b, t, u]: log-prob of blank; lpl[b, t, u]: log-prob of target label u.
        let (lpb, lpl) = self.extract_log_probs(log_probs, targets);
        // True where the label position u is reachable (u <= target_length).
        let u_mask = self.create_u_mask(&target_lengths, b, max_up1, &device);
        let neg_inf = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, &device);

        // alpha[b, u]: forward log-probability at the current time step (t = 0 here).
        // Unreachable label positions are forced to -inf.
        let mut alpha = self.init_alpha(&lpl, b, max_up1, &device);
        alpha = neg_inf.clone().mask_where(u_mask.clone(), alpha);

        let logit_lengths_exp = logit_lengths.clone().reshape([b, 1]).expand([b, max_up1]);

        // Dynamic program over time. Samples whose audio already ended
        // (t >= logit_length) keep their previous alpha, so the final gather
        // reads the value frozen at their last valid frame.
        for t in 1..max_t {
            let new = self.step_alpha(&alpha, &lpb, &lpl, t);
            let new = neg_inf.clone().mask_where(u_mask.clone(), new);

            let valid = logit_lengths_exp.clone().greater_elem(t as i64);
            alpha = alpha.mask_where(valid, new);
        }

        self.gather_loss(alpha, &lpb, logit_lengths, target_lengths, b)
    }

    /// Same as [`Self::forward`], then reduce the per-sample losses.
    ///
    /// # Panics
    /// Panics for reduction variants other than `Auto`, `Mean`, or `Sum`.
    pub fn forward_with_reduction<B: Backend>(
        &self,
        logits: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward(logits, targets, logit_lengths, target_lengths);
        match reduction {
            Reduction::Auto | Reduction::Mean => loss.mean(),
            Reduction::Sum => loss.sum(),
            other => panic!("{other:?} reduction is not supported"),
        }
    }

    /// Split `log_probs` into the two per-lattice-node quantities the DP needs.
    ///
    /// Returns `(lpb, lpl)`:
    /// - `lpb` `[b, max_t, max_u + 1]`: log-prob of the blank symbol at each node;
    /// - `lpl` `[b, max_t, max_u]`: log-prob of emitting the u-th target label
    ///   (`targets[b, u]`) at each node — only rows `u < max_u` exist because
    ///   no label can be emitted from the final row.
    fn extract_log_probs<B: Backend>(
        &self,
        log_probs: Tensor<B, 4>,
        targets: Tensor<B, 2, Int>,
    ) -> (Tensor<B, 3>, Tensor<B, 3>) {
        let [b, max_t, max_up1, v] = log_probs.dims();
        let max_u = max_up1 - 1;
        let vocab_dim = 3;

        // Select the blank column of the vocabulary axis.
        let lpb = log_probs
            .clone()
            .slice_dim(vocab_dim, self.blank)
            .squeeze_dim::<3>(vocab_dim);

        // Gather, for every (b, t, u), the log-prob of label targets[b, u].
        let tgt = targets
            .reshape([b, 1, max_u, 1])
            .expand([b, max_t, max_u, 1]);
        let lpl = log_probs
            .slice(s![.., .., 0..max_u, 0..v])
            .gather(vocab_dim, tgt)
            .squeeze_dim::<3>(vocab_dim);

        (lpb, lpl)
    }

    /// Initial forward variable at t = 0:
    /// `alpha[b, u] = sum_{i < u} lpl[b, 0, i]`, i.e. the log-probability of
    /// emitting the first `u` labels before the first time advance.
    fn init_alpha<B: Backend>(
        &self,
        lpl: &Tensor<B, 3>,
        b: usize,
        max_up1: usize,
        device: &B::Device,
    ) -> Tensor<B, 2> {
        // Label log-probs at the first time step: shape [b, max_u].
        let lpl_0 = lpl.clone().slice(s![.., 0..1, ..]).squeeze_dim::<2>(1);
        // Prepend a zero so alpha[.., 0] = 0 after the cumulative sum.
        let zero_col = Tensor::<B, 2>::zeros([b, 1], device);
        let prefix = Tensor::cat(vec![zero_col, lpl_0.slice(s![.., 0..(max_up1 - 1)])], 1);

        prefix.cumsum(1)
    }

    /// Boolean mask of shape `[b, max_up1]`, true where the label position `u`
    /// is valid for the sample, i.e. `u <= target_lengths[b]`.
    fn create_u_mask<B: Backend>(
        &self,
        target_lengths: &Tensor<B, 1, Int>,
        b: usize,
        max_up1: usize,
        device: &B::Device,
    ) -> Tensor<B, 2, Bool> {
        let indices = Tensor::<B, 1, Int>::arange(0..max_up1 as i64, device)
            .reshape([1, max_up1])
            .expand([b, max_up1]);
        let lengths = target_lengths.clone().reshape([b, 1]).expand([b, max_up1]);
        indices.lower_equal(lengths)
    }

    /// One time step of the forward recursion:
    ///
    /// `new[u] = logsumexp(alpha[u] + lpb[t-1, u], new[u-1] + lpl[t, u-1])`
    ///
    /// i.e. either a blank was emitted at the previous frame (stay at `u`), or
    /// label `u-1` is emitted within the current frame (advance from `u-1`).
    /// The `u` loop is inherently sequential: each entry depends on `new[u-1]`.
    fn step_alpha<B: Backend>(
        &self,
        alpha: &Tensor<B, 2>,
        lpb: &Tensor<B, 3>,
        lpl: &Tensor<B, 3>,
        t: usize,
    ) -> Tensor<B, 2> {
        let [b, max_up1] = alpha.dims();
        let device = alpha.device();

        // Blank transition from the previous frame: alpha[u] + lpb[t-1, u].
        let blank_prob = lpb
            .clone()
            .slice(s![.., (t - 1)..t, ..])
            .squeeze_dim::<2>(1);
        let from_blank = alpha.clone().add(blank_prob);

        let mut new = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, &device);
        // u = 0 can only be reached via blank.
        new = new.slice_assign(s![.., 0..1], from_blank.clone().slice(s![.., 0..1]));

        // Label log-probs for the current frame t.
        let label_prob = lpl
            .clone()
            .slice(s![.., t..(t + 1), ..])
            .squeeze_dim::<2>(1);

        for u in 1..max_up1 {
            let via_blank = from_blank.clone().slice(s![.., u..(u + 1)]);
            let via_label = new
                .clone()
                .slice(s![.., (u - 1)..u])
                .add(label_prob.clone().slice(s![.., (u - 1)..u]));
            new = new.slice_assign(s![.., u..(u + 1)], self.log_sum_exp(via_blank, via_label));
        }
        new
    }

    /// Read out the final loss. A complete alignment ends at node
    /// `(T_b - 1, U_b)` and emits one last blank there, so
    /// `loss[b] = -(alpha[b, U_b] + lpb[b, T_b - 1, U_b])`.
    fn gather_loss<B: Backend>(
        &self,
        alpha: Tensor<B, 2>,
        lpb: &Tensor<B, 3>,
        logit_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
        b: usize,
    ) -> Tensor<B, 1> {
        let device = alpha.device();
        let u_idx = target_lengths;
        // Keep all index tensors in one int dtype so they can be stacked.
        let int_dtype = u_idx.dtype();
        let t_idx = logit_lengths.sub_scalar(1).cast(int_dtype);
        let b_idx = Tensor::<B, 1, Int>::arange(0..b as i64, (&device, int_dtype));

        let alpha_tu: Tensor<B, 1> =
            alpha.gather_nd(Tensor::stack::<2>(vec![b_idx.clone(), u_idx.clone()], 1));
        let lpb_tu: Tensor<B, 1> = lpb
            .clone()
            .gather_nd(Tensor::stack::<2>(vec![b_idx, t_idx, u_idx], 1));

        alpha_tu.add(lpb_tu).neg()
    }

    /// Validate that all input dimensions agree.
    ///
    /// # Panics
    /// Panics with a descriptive message on any mismatch; the test suite
    /// matches on these message prefixes, so they must stay stable.
    fn check_inputs<B: Backend>(
        &self,
        b: usize,
        v: usize,
        targets: &Tensor<B, 2, Int>,
        logit_lengths: &Tensor<B, 1, Int>,
        target_lengths: &Tensor<B, 1, Int>,
        max_u: usize,
    ) {
        assert!(
            self.blank < v,
            "blank index {} must be less than vocab_size {}",
            self.blank,
            v
        );
        assert_eq!(
            targets.dims()[0],
            b,
            "targets batch dimension {} must equal batch_size {}",
            targets.dims()[0],
            b
        );
        assert_eq!(
            targets.dims()[1],
            max_u,
            "targets length dimension {} must equal max_target_len (max_u) {}",
            targets.dims()[1],
            max_u
        );
        assert_eq!(
            logit_lengths.dims()[0],
            b,
            "logit_lengths length {} must equal batch_size {}",
            logit_lengths.dims()[0],
            b
        );
        assert_eq!(
            target_lengths.dims()[0],
            b,
            "target_lengths length {} must equal batch_size {}",
            target_lengths.dims()[0],
            b
        );
    }

    /// Element-wise `log(exp(a) + exp(b))`, computed as
    /// `max(a, b) + log(1 + exp(-|a - b|))` for numerical stability.
    ///
    /// `-inf` operands are special-cased: masking them to 0 before the
    /// arithmetic avoids `inf - inf = NaN`, and the result falls back to the
    /// other operand (or stays `-inf` when both are `-inf`).
    fn log_sum_exp<const D: usize, B: Backend>(
        &self,
        a: Tensor<B, D>,
        b: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let a_inf = a.clone().equal_elem(f32::NEG_INFINITY);
        let b_inf = b.clone().equal_elem(f32::NEG_INFINITY);

        let a_safe = a.clone().mask_fill(a_inf.clone(), 0.0);
        let b_safe = b.clone().mask_fill(b_inf.clone(), 0.0);

        let max = a_safe.clone().max_pair(b_safe.clone());
        let result = max.add(a_safe.sub(b_safe).abs().neg().exp().add_scalar(1.0).log());

        // If a was -inf the sum is just b (and vice versa); when both were
        // -inf the second mask wins and yields -inf.
        let result = result.mask_where(a_inf, b);
        result.mask_where(b_inf, a)
    }
}
314
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{TensorData, Tolerance};
    use burn_flex::{Flex, FlexDevice};

    type B = Flex;

    /// Vocabulary size shared by the analytic uniform-probability tests.
    const NUM_LABELS: usize = 2;

    #[test]
    fn config_defaults() {
        let cfg = RNNTLossConfig::new();
        assert_eq!(cfg.blank, 0);
        assert!(cfg.logits);
    }

    #[test]
    #[should_panic(expected = "blank index")]
    fn panics_on_invalid_blank() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_blank(5).init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([1, 2, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([2], &dev),
            Tensor::<B, 1, Int>::from_data([1], &dev),
        );
    }

    #[test]
    #[should_panic(expected = "must equal batch_size")]
    fn panics_on_batch_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([3, 3], &dev),
            Tensor::<B, 1, Int>::from_data([1, 1], &dev),
        );
    }

    #[test]
    #[should_panic(expected = "logit_lengths length")]
    fn panics_on_logit_lengths_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32], [2]], &dev),
            Tensor::<B, 1, Int>::from_data([3], &dev),
            Tensor::<B, 1, Int>::from_data([1, 1], &dev),
        );
    }

    #[test]
    #[should_panic(expected = "target_lengths length")]
    fn panics_on_target_lengths_mismatch() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        rnnt.forward(
            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32], [2]], &dev),
            Tensor::<B, 1, Int>::from_data([3, 3], &dev),
            Tensor::<B, 1, Int>::from_data([1], &dev),
        );
    }

    /// With uniform log-probs the likelihood has a closed form: every
    /// alignment path emits T + U symbols, each with probability 1/v, and for
    /// T = 2, U = 1 there are exactly T paths (the single label can be emitted
    /// at either time step).
    #[test]
    fn single_token_uniform_probs() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_logits(false).init();
        let time_steps = 2;
        let target_len = 1;
        let v = NUM_LABELS as f32;
        let log_uniform = (1.0 / v).ln();

        let loss = rnnt.forward(
            Tensor::<B, 4>::full(
                [1, time_steps, target_len + 1, NUM_LABELS],
                log_uniform,
                &dev,
            ),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev),
            Tensor::<B, 1, Int>::from_data([target_len as i64], &dev),
        );
        let num_paths = time_steps as f32;
        let emissions_per_path = (time_steps + target_len) as f32;
        let total_prob = num_paths * v.powf(-emissions_per_path);
        let expected_loss = -total_prob.ln();
        loss.into_data().assert_approx_eq::<f32>(
            &TensorData::from([expected_loss]),
            Tolerance::absolute(1e-4),
        );
    }

    /// U = 0: the only alignment emits T blanks, each with probability 1/v.
    #[test]
    fn empty_target() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().with_logits(false).init();
        let time_steps = 3;
        let target_len = 0;
        let v = NUM_LABELS as f32;
        let log_uniform = (1.0 / v).ln();

        let loss = rnnt.forward(
            Tensor::<B, 4>::full([1, time_steps, 2, NUM_LABELS], log_uniform, &dev),
            Tensor::<B, 2, Int>::from_data([[1_i32]], &dev),
            Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev),
            Tensor::<B, 1, Int>::from_data([target_len as i64], &dev),
        );
        let expected_loss = -v.powf(-((time_steps + target_len) as f32)).ln();
        loss.into_data().assert_approx_eq::<f32>(
            &TensorData::from([expected_loss]),
            Tolerance::absolute(1e-4),
        );
    }

    /// Running with `logits = true` on raw logits must match applying
    /// `log_softmax` manually and running with `logits = false`.
    #[test]
    fn logits_equivalence() {
        let dev = FlexDevice;
        let [bs, time_steps, up1, vocab] = [1, 2, 3, 4];
        let num_elements = bs * time_steps * up1 * vocab;
        let target_len = up1 - 1;

        // Deterministic non-uniform values so the comparison is non-trivial.
        let data: Vec<f32> = (0..num_elements).map(|i| (i as f32 * 0.3).sin()).collect();
        let logits = Tensor::<B, 4>::from_data(
            burn_core::tensor::TensorData::new(data, [bs, time_steps, up1, vocab]),
            &dev,
        );
        let targets = Tensor::<B, 2, Int>::from_data([[1_i32, 2]], &dev);
        let logit_lengths = Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev);
        let target_lengths = Tensor::<B, 1, Int>::from_data([target_len as i64], &dev);

        let vocab_dim = 3;
        let fused = RNNTLossConfig::new().with_logits(true).init().forward(
            logits.clone(),
            targets.clone(),
            logit_lengths.clone(),
            target_lengths.clone(),
        );

        let log_probs = burn::tensor::activation::log_softmax(logits, vocab_dim);
        let manual = RNNTLossConfig::new().with_logits(false).init().forward(
            log_probs,
            targets,
            logit_lengths,
            target_lengths,
        );

        fused
            .into_data()
            .assert_approx_eq::<f32>(&manual.into_data(), Tolerance::absolute(1e-4));
    }
}
485
#[cfg(test)]
#[allow(clippy::identity_op, clippy::too_many_arguments)]
mod pytorch_comparison_tests {
    //! Loss and gradient values checked against hard-coded reference
    //! constants. NOTE(review): the module name suggests the constants were
    //! generated with a PyTorch/torchaudio RNN-T implementation, but the
    //! generating script is not part of this file — confirm before editing
    //! any expected value.
    use super::*;
    use burn::tensor::{TensorData, Tolerance};
    use burn_autodiff::Autodiff;
    use burn_flex::{Flex, FlexDevice};

    type B = Autodiff<Flex>;

    /// Shared absolute tolerance for comparing against the reference values.
    fn tol() -> Tolerance<f32> {
        Tolerance::absolute(1e-3)
    }

    /// Deterministic pseudo-random logits of shape `[bs, t, u, v]`, built from
    /// a fixed sine pattern so tests are reproducible without an RNG.
    fn make_logits(bs: usize, t: usize, u: usize, v: usize, dev: &FlexDevice) -> Tensor<B, 4> {
        let mut data = Vec::with_capacity(bs * t * u * v);
        for bi in 0..bs {
            for ti in 0..t {
                for ui in 0..u {
                    for vi in 0..v {
                        let idx = bi * 11 + ti * 7 + ui * 13 + vi * 3;
                        data.push((idx as f32 * 0.1).sin());
                    }
                }
            }
        }
        Tensor::from_data(TensorData::new(data, [bs, t, u, v]), dev)
    }

    /// Assert that the gradient sums to ~0 over the vocabulary axis at every
    /// lattice node (expected when the loss is taken through a softmax).
    fn check_vocab_grad_sums(grad: &[f32], bs: usize, t: usize, up1: usize, v: usize) {
        for bi in 0..bs {
            for ti in 0..t {
                for ui in 0..up1 {
                    let base = ((bi * t + ti) * up1 + ui) * v;
                    let sum: f32 = (0..v).map(|vi| grad[base + vi]).sum();
                    TensorData::from([sum])
                        .assert_approx_eq::<f32>(&TensorData::from([0.0f32]), tol());
                }
            }
        }
    }

    /// Slice of the flattened gradient for lattice node `(b, t, u)`:
    /// `v` consecutive values, one per vocabulary entry.
    fn grad_at(
        grad: &[f32],
        b: usize,
        t: usize,
        u: usize,
        max_t: usize,
        up1: usize,
        v: usize,
    ) -> &[f32] {
        let base = ((b * max_t + t) * up1 + u) * v;
        &grad[base..base + v]
    }

    /// Compare the gradient at node `(b, t, u)` against `expected`.
    fn assert_grad(
        grad: &[f32],
        b: usize,
        t: usize,
        u: usize,
        max_t: usize,
        up1: usize,
        v: usize,
        expected: &[f32],
    ) {
        TensorData::from(grad_at(grad, b, t, u, max_t, up1, v))
            .assert_approx_eq::<f32>(&TensorData::from(expected), tol());
    }

    #[test]
    fn basic_b1() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(1, 4, 3, 3, &dev).require_grad();

        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data([[1_i32, 2]], &dev),
            Tensor::<B, 1, Int>::from_data([4_i32], &dev),
            Tensor::<B, 1, Int>::from_data([2_i32], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([4.4491f32]), tol());

        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        // Spot-check the corners of the lattice, then the softmax-sum property.
        assert_grad(&grad, 0, 0, 0, 4, 3, 3, &[-0.2041, -0.2246, 0.4287]);
        assert_grad(&grad, 0, 2, 0, 4, 3, 3, &[0.0079, -0.0640, 0.0561]);
        assert_grad(&grad, 0, 3, 2, 4, 3, 3, &[-0.6899, 0.3231, 0.3667]);
        check_vocab_grad_sums(&grad, 1, 4, 3, 3);
    }

    #[test]
    fn batched_b2() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();

        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data(
                TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
                &dev,
            ),
            Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev),
            Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([7.9356f32, 7.2033]), tol());

        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        assert_grad(&grad, 0, 0, 0, 5, 4, 4, &[-0.3161, -0.3113, 0.2796, 0.3479]);
        assert_grad(&grad, 1, 0, 0, 5, 4, 4, &[-0.2766, 0.2602, -0.2248, 0.2411]);
        assert_grad(&grad, 0, 4, 3, 5, 4, 4, &[-0.8216, 0.2296, 0.2786, 0.3133]);
        assert_grad(&grad, 1, 4, 3, 5, 4, 4, &[-0.7185, 0.2735, 0.2437, 0.2012]);
        check_vocab_grad_sums(&grad, 2, 5, 4, 4);
    }

    #[test]
    fn variable_lengths_b3() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(3, 6, 4, 5, &dev).require_grad();

        // Per-sample lengths: (T=6, U=3), (T=4, U=2), (T=5, U=1) — padding
        // beyond each sample's lengths must receive zero gradient.
        let loss = rnnt.forward(
            logits.clone(),
            Tensor::<B, 2, Int>::from_data(
                TensorData::new(vec![1_i32, 2, 3, 4, 1, 0, 2, 0, 0], [3, 3]),
                &dev,
            ),
            Tensor::<B, 1, Int>::from_data([6_i32, 4, 5], &dev),
            Tensor::<B, 1, Int>::from_data([3_i32, 2, 1], &dev),
        );
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([10.7458f32, 8.0196, 8.3316]), tol());

        let grads = loss.sum().backward();
        let grad = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        let stride = 4 * 5;
        let zeros = vec![0.0f32; 5];

        assert_grad(
            &grad,
            0,
            0,
            0,
            6,
            4,
            5,
            &[-0.4232, -0.3114, 0.1992, 0.2478, 0.2876],
        );
        assert_grad(
            &grad,
            0,
            5,
            3,
            6,
            4,
            5,
            &[-0.8016, 0.2170, 0.2172, 0.1991, 0.1683],
        );

        assert_grad(
            &grad,
            1,
            0,
            0,
            6,
            4,
            5,
            &[-0.2502, 0.2160, 0.2173, 0.2002, -0.3833],
        );
        // Sample 1 has logit_length 4: frames t >= 4 are padding.
        let sample1_t4_start = 1 * 6 * stride + 4 * stride;
        for i in 0..(2 * stride) {
            assert!(
                grad[sample1_t4_start + i].abs() < 1e-3,
                "sample 1, t>=4: grad[{}] = {} (expected 0)",
                i,
                grad[sample1_t4_start + i]
            );
        }

        // Sample 1 has target_length 2: label row u = 3 is unreachable.
        for ti in 0..4 {
            assert_grad(&grad, 1, ti, 3, 6, 4, 5, &zeros);
        }

        // Sample 2 has logit_length 5: frame t = 5 is padding.
        let sample2_t5_start = 2 * 6 * stride + 5 * stride;
        for i in 0..stride {
            assert!(
                grad[sample2_t5_start + i].abs() < 1e-3,
                "sample 2, t=5: grad[{}] = {} (expected 0)",
                i,
                grad[sample2_t5_start + i]
            );
        }

        check_vocab_grad_sums(&grad, 3, 6, 4, 5);
    }

    #[test]
    fn sum_reduction() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();
        let tgt = Tensor::<B, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
            &dev,
        );
        let il = Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev);
        let tl = Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev);

        // Sum of the two per-sample losses from `batched_b2`; the gradient for
        // the first node matches the unreduced case.
        let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Sum);
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([15.1389f32]), tol());

        let grads = loss.backward();
        let g = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        TensorData::from(&g[..4]).assert_approx_eq::<f32>(
            &TensorData::from([-0.3161f32, -0.3113, 0.2796, 0.3479]),
            tol(),
        );
    }

    #[test]
    fn mean_reduction() {
        let dev = FlexDevice;
        let rnnt = RNNTLossConfig::new().init();
        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();
        let tgt = Tensor::<B, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 3, 2, 1, 3], [2, 3]),
            &dev,
        );
        let il = Tensor::<B, 1, Int>::from_data([5_i32, 5], &dev);
        let tl = Tensor::<B, 1, Int>::from_data([3_i32, 3], &dev);

        // Mean halves both the loss and the gradient relative to `Sum`.
        let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Mean);
        loss.clone()
            .into_data()
            .assert_approx_eq::<f32>(&TensorData::from([7.5694f32]), tol());

        let grads = loss.backward();
        let g = logits
            .grad(&grads)
            .unwrap()
            .into_data()
            .to_vec::<f32>()
            .unwrap();
        TensorData::from(&g[..4]).assert_approx_eq::<f32>(
            &TensorData::from([-0.1581f32, -0.1557, 0.1398, 0.1739]),
            tol(),
        );
    }
}