1#![allow(clippy::excessive_precision)]
2
3use super::Reduction;
4use burn::config::Config;
5use burn::module::Module;
6use burn::tensor::{Int, Tensor, backend::Backend};
7use burn_core as burn;
8
/// Configuration for creating a [`CTCLoss`] module.
#[derive(Config, Debug)]
pub struct CTCLossConfig {
    /// Index of the blank label within the class dimension. Defaults to `0`.
    #[config(default = 0)]
    pub blank: usize,
    /// When `true`, infinite losses (produced when no valid CTC alignment
    /// exists for a sample) are replaced with zero in [`CTCLoss::forward`].
    /// Defaults to `false`.
    #[config(default = false)]
    pub zero_infinity: bool,
}
19
20impl CTCLossConfig {
21 pub fn init(&self) -> CTCLoss {
23 CTCLoss {
24 blank: self.blank,
25 zero_infinity: self.zero_infinity,
26 }
27 }
28}
29
/// Connectionist temporal classification (CTC) loss module.
///
/// Construct it via [`CTCLossConfig::init`].
#[derive(Module, Clone, Debug)]
pub struct CTCLoss {
    // Index of the blank label within the class dimension.
    blank: usize,
    // Whether +inf losses (infeasible alignments) are masked to zero.
    zero_infinity: bool,
}
82
83impl CTCLoss {
84 pub fn forward<B: Backend>(
108 &self,
109 log_probs: Tensor<B, 3>,
110 targets: Tensor<B, 2, Int>,
111 input_lengths: Tensor<B, 1, Int>,
112 target_lengths: Tensor<B, 1, Int>,
113 ) -> Tensor<B, 1> {
114 let [max_input_length, batch_size, num_classes] = log_probs.dims();
115 let max_target_len = targets.dims()[1];
116 let input_lengths_len = input_lengths.dims()[0];
117 let target_lengths_len = target_lengths.dims()[0];
118 self.assertions(
119 batch_size,
120 num_classes,
121 targets.clone(),
122 input_lengths_len,
123 target_lengths_len,
124 );
125 self.length_assertions(
126 input_lengths.clone(),
127 target_lengths.clone(),
128 max_target_len,
129 max_input_length,
130 );
131
132 let mut loss = burn::tensor::module::ctc_loss(
133 log_probs,
134 targets,
135 input_lengths,
136 target_lengths,
137 self.blank,
138 );
139
140 if self.zero_infinity {
141 let inf_mask = loss.clone().is_inf();
142 loss = loss.clone().mask_where(inf_mask, loss.clone().zeros_like());
143 }
144
145 loss
146 }
147
148 pub fn forward_with_reduction<B: Backend>(
180 &self,
181 log_probs: Tensor<B, 3>,
182 targets: Tensor<B, 2, Int>,
183 input_lengths: Tensor<B, 1, Int>,
184 target_lengths: Tensor<B, 1, Int>,
185 reduction: Reduction,
186 ) -> Tensor<B, 1> {
187 let ctc_loss_tensor =
188 self.forward(log_probs, targets, input_lengths, target_lengths.clone());
189
190 match reduction {
191 Reduction::Auto | Reduction::Mean => {
192 let target_lengths_float = target_lengths.float();
195 ctc_loss_tensor.div(target_lengths_float).mean()
196 }
197 Reduction::Sum => ctc_loss_tensor.sum(),
198 other => panic!("{other:?} reduction is not supported"),
199 }
200 }
201
202 #[allow(unused_variables)]
213 fn length_assertions<B: Backend>(
214 &self,
215 input_lengths: Tensor<B, 1, Int>,
216 target_lengths: Tensor<B, 1, Int>,
217 max_target_len: usize,
218 max_input_length: usize,
219 ) {
220 #[cfg(debug_assertions)]
221 {
222 let target_lengths_data = target_lengths.into_data();
223 let input_lengths_data = input_lengths.into_data();
224 let target_iter = target_lengths_data.iter::<i64>();
225 let input_iter = input_lengths_data.iter::<i64>();
226 for (i, (tl, il)) in target_iter.zip(input_iter).enumerate() {
227 assert!(tl >= 0, "target_lengths[{i}] = {tl} must be non-negative");
228 assert!(
229 tl as usize <= max_target_len,
230 "target_lengths[{i}] = {tl} exceeds the targets tensor width {max_target_len}"
231 );
232 assert!(
233 il >= tl,
234 "input_lengths[{i}] = {il} must be >= target_lengths[{i}] = {tl} \
235 (no valid CTC alignment otherwise)"
236 );
237 assert!(
238 il as usize <= max_input_length,
239 "input_lengths[{i}] = {il} exceeds the log_probs time dimension \
240 {max_input_length}"
241 );
242 }
243 }
244 }
245
246 fn assertions<B: Backend>(
247 &self,
248 batch_size: usize,
249 num_classes: usize,
250 targets: Tensor<B, 2, Int>,
251 input_lengths_len: usize,
252 target_lengths_len: usize,
253 ) {
254 assert!(
255 self.blank < num_classes,
256 "blank index {} must be less than num_classes {}",
257 self.blank,
258 num_classes
259 );
260 assert_eq!(
261 targets.dims()[0],
262 batch_size,
263 "targets batch dimension {} must equal batch_size {}",
264 targets.dims()[0],
265 batch_size
266 );
267 assert_eq!(
268 input_lengths_len, batch_size,
269 "input_lengths length {} must equal batch_size {}",
270 input_lengths_len, batch_size
271 );
272 assert_eq!(
273 target_lengths_len, batch_size,
274 "target_lengths length {} must equal batch_size {}",
275 target_lengths_len, batch_size
276 );
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use burn_flex::{Flex, FlexDevice};

    type TestBackend = Flex;

    // Asserts element-wise approximate equality of two f32 slices: panics on a
    // length mismatch or on the first pair whose absolute difference is not
    // strictly below `tol`.
    fn assert_approx_equal(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "Length mismatch: actual {} vs expected {}",
            actual.len(),
            expected.len()
        );
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "Mismatch at index {}: expected {:.6}, got {:.6} (diff: {:.6})",
                i,
                e,
                a,
                (a - e).abs()
            );
        }
    }

    // blank = 5 with only num_classes = 3 must trip the blank-index assertion.
    #[test]
    #[should_panic(expected = "blank index")]
    fn test_ctc_loss_panics_invalid_blank_index() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_blank(5).init();

        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 1, 3], &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1], &device);

        ctc.forward(log_probs, targets, input_lengths, target_lengths);
    }

    // log_probs declares batch_size = 2 but targets only has one row.
    #[test]
    #[should_panic(expected = "must equal batch_size")]
    fn test_ctc_loss_panics_mismatched_batch_size() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2, 2], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1, 1], &device);

        ctc.forward(log_probs, targets, input_lengths, target_lengths);
    }

    // input_lengths has one entry but batch_size is 2.
    #[test]
    #[should_panic(expected = "input_lengths length")]
    fn test_ctc_loss_panics_input_lengths_mismatch() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1], [2]], &device);

        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1, 1], &device);

        ctc.forward(log_probs, targets, input_lengths, target_lengths);
    }

    // target_lengths has one entry but batch_size is 2.
    #[test]
    #[should_panic(expected = "target_lengths length")]
    fn test_ctc_loss_panics_target_lengths_mismatch() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1], [2]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2, 2], &device);

        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1], &device);

        ctc.forward(log_probs, targets, input_lengths, target_lengths);
    }

    // Target [1, 1] has a repeated label, so the only valid alignment of
    // length T = 3 is "1, blank, 1". With uniform probability 1/2 per class
    // the path probability is (1/2)^3, hence loss = -ln(1/8) = 3 ln 2.
    #[test]
    fn test_ctc_loss_repeated_labels_minimum_input_length() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let log_probs = Tensor::<TestBackend, 3>::full([3, 1, 2], 0.5_f32.ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i32, 1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([3_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        let expected = 3.0 * 2.0_f32.ln();
        assert_approx_equal(&loss_data, &[expected], 1e-3);
    }

    // With blank = 2, target [0, 1] and T = 3, exactly five label paths
    // collapse to "01": 001, 011, b01, 0b1, 01b. Each has probability
    // (1/3)^3 under the uniform distribution, so loss = -ln(5/27).
    #[test]
    fn test_ctc_loss_custom_blank_uniform() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_blank(2).init();

        let log_probs = Tensor::<TestBackend, 3>::full([3, 1, 3], (1.0_f32 / 3.0).ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[0_i32, 1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([3_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        let expected = -(5.0_f32 / 27.0).ln();
        assert_approx_equal(&loss_data, &[expected], 1e-3);
    }

    // Target [1, 1] needs at least 3 frames (label, blank, label) but only 2
    // are provided: no alignment exists, so the raw loss must be +inf when
    // zero_infinity is off.
    #[test]
    fn test_ctc_loss_zero_infinity_produces_inf_when_disabled() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_zero_infinity(false).init();

        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 3], (1.0_f32 / 3.0).ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i32, 1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        assert!(
            loss_data[0].is_infinite() && loss_data[0] > 0.0,
            "Expected +inf, got {}",
            loss_data[0]
        );
    }

    // Same infeasible setup as above, but zero_infinity replaces +inf with 0.
    #[test]
    fn test_ctc_loss_zero_infinity_masks_inf_when_enabled() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_zero_infinity(true).init();

        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 3], (1.0_f32 / 3.0).ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i32, 1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        assert_approx_equal(&loss_data, &[0.0], 1e-6);
    }

    // Feasible case with zero_infinity on: target [1], T = 2, uniform 1/2.
    // Paths collapsing to "1" are 11, b1, 1b, giving p = 3/4, so the finite
    // loss -ln(0.75) must pass through the mask untouched.
    #[test]
    fn test_ctc_loss_zero_infinity_does_not_affect_finite_loss() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_zero_infinity(true).init();

        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 2], 0.5_f32.ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i32]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        let expected = -(0.75_f32).ln();
        assert_approx_equal(&loss_data, &[expected], 1e-3);
    }
}
481
// Numerical parity tests. The expected losses and gradients below are
// reference tables; from the setup they correspond to running the same
// deterministic logits through a reference CTC implementation (presumably
// PyTorch's `torch.nn.CTCLoss` with `reduction="none"` plus autograd —
// TODO confirm provenance) and recording losses and d(loss)/d(logits).
#[cfg(test)]
mod pytorch_comparison_tests {
    use super::*;
    use burn::tensor::activation::log_softmax;
    use burn_autodiff::Autodiff;
    use burn_core::tensor::TensorData;
    use burn_flex::{Flex, FlexDevice};

    type InnerBackend = Flex;
    // Autodiff wrapper so gradients w.r.t. the raw logits can be computed.
    type TestBackend = Autodiff<InnerBackend>;

    // Asserts element-wise approximate equality of two f32 slices (same
    // helper as in `tests`; duplicated because test modules are independent).
    fn assert_approx_equal(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "Length mismatch: actual {} vs expected {}",
            actual.len(),
            expected.len()
        );
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "Mismatch at index {}: expected {:.6}, got {:.6} (diff: {:.6})",
                i,
                e,
                a,
                (a - e).abs()
            );
        }
    }

    // Builds a deterministic [T, N, C] logits tensor with entries
    // sin(0.1 * (7t + 13n + 3c)), so the identical tensor can be rebuilt in
    // an external reference implementation to regenerate the tables below.
    fn generate_logits(
        t_size: usize,
        n_size: usize,
        c_size: usize,
        device: &FlexDevice,
    ) -> Tensor<TestBackend, 3> {
        let mut data = Vec::with_capacity(t_size * n_size * c_size);
        for t in 0..t_size {
            for n in 0..n_size {
                for c in 0..c_size {
                    data.push(((t * 7 + n * 13 + c * 3) as f32 * 0.1).sin());
                }
            }
        }
        Tensor::<TestBackend, 3>::from_data(TensorData::new(data, [t_size, n_size, c_size]), device)
    }

    // Shared driver: builds logits, runs forward, checks per-sample losses,
    // then backprops the summed loss and checks the flattened logit gradient
    // (row-major [T, N, C] order) against the reference table.
    #[allow(clippy::too_many_arguments)]
    fn run_comparison(
        label: &str,
        t_size: usize,
        n_size: usize,
        c_size: usize,
        targets_flat: Vec<i64>,
        target_shape: [usize; 2],
        input_lengths: Vec<i64>,
        target_lengths: Vec<i64>,
        blank: usize,
        expected_losses: &[f32],
        expected_grad_flat: &[f32],
        loss_tol: f32,
        grad_tol: f32,
    ) {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_blank(blank).init();

        // Gradients are taken w.r.t. the raw logits, before log_softmax.
        let logits = generate_logits(t_size, n_size, c_size, &device).require_grad();
        let log_probs = log_softmax(logits.clone(), 2);

        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::new(targets_flat, target_shape),
            &device,
        );
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data(
            TensorData::new(input_lengths, [n_size]),
            &device,
        );
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data(
            TensorData::new(target_lengths, [n_size]),
            &device,
        );

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();

        println!("=== {} ===", label);
        println!("  Loss: {:?}", loss_data);
        assert_approx_equal(&loss_data, expected_losses, loss_tol);

        // Sum reduction before backward: each sample contributes its raw
        // gradient, matching how the reference tables were produced.
        let loss_sum = loss.sum();
        let grads = loss_sum.backward();
        let logits_grad = logits.grad(&grads).unwrap();
        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();
        assert_approx_equal(&grad_data, expected_grad_flat, grad_tol);
    }

    // T=5, N=3, C=4, all samples use the full 5 time steps.
    #[test]
    fn test_ctc_loss_uniform_input_lengths() {
        let expected_losses = [3.5236570835113525_f32, 3.495313882827759, 4.262677192687988];
        // Flattened [5, 3, 4] reference gradient.
        let expected_grad_flat = [
            -0.1679008007_f32,
            -0.4595540464,
            0.2795598209,
            0.3478950262,
            -0.3913056254,
            -0.0832268298,
            0.2535884976,
            0.2209439576,
            -0.0502742566,
            0.2766197622,
            0.2054125518,
            -0.4317580462,
            -0.0544800088,
            -0.3144550920,
            0.0847885981,
            0.2841464877,
            -0.1844545156,
            -0.2063435912,
            0.2222184092,
            0.1685796976,
            0.0278018005,
            0.2657383382,
            -0.0336986706,
            -0.2598414719,
            -0.0482986756,
            -0.0098767160,
            -0.1533526182,
            0.2115280181,
            -0.1380317956,
            -0.2198686600,
            0.2042596638,
            0.1536407918,
            0.0534787849,
            0.1819230020,
            -0.2805589139,
            0.0451571345,
            -0.0895631388,
            0.1996460557,
            -0.2741115987,
            0.1640286744,
            -0.2200077325,
            -0.1693530381,
            0.2101601064,
            0.1792006642,
            0.0398471877,
            -0.1131042913,
            -0.2363226712,
            0.3095797896,
            -0.2163617164,
            0.2740726173,
            -0.2124865055,
            0.1547756046,
            -0.4312027395,
            -0.0446923785,
            0.2330704331,
            0.2428246588,
            -0.0050083841,
            -0.6256869435,
            0.2689785957,
            0.3617166877,
        ];
        run_comparison(
            "T=5, N=3, C=4 (uniform input lengths)",
            5,
            3,
            4,
            vec![1, 2, 0, 1, 0, 0, 3, 2, 1],
            [3, 3],
            vec![5, 5, 5],
            vec![2, 1, 3],
            0,
            &expected_losses,
            &expected_grad_flat,
            1e-3,
            1e-3,
        );
    }

    // T=8, N=4, C=6; targets include repeated labels (e.g. [1, 1, 2]), which
    // exercise the mandatory-blank transitions of the CTC recursion.
    #[test]
    fn test_ctc_loss_repeated_labels() {
        let expected_losses = [
            8.84203052520752_f32,
            9.023029327392578,
            9.398024559020996,
            9.008068084716797,
        ];
        // Flattened [8, 4, 6] reference gradient.
        let expected_grad_flat = [
            -0.2766432464,
            -0.5202965736,
            0.1523768753,
            0.1896236390,
            0.2200277001,
            0.2349116206,
            -0.1854365915,
            0.2031330466,
            -0.4260218740,
            0.1678018719,
            0.1360142529,
            0.1045092493,
            -0.6603536606,
            0.2278252542,
            0.1691786796,
            0.1262856424,
            0.0972681716,
            0.0397959016,
            -0.0894432291,
            -0.5457318425,
            0.1490373611,
            0.1462858170,
            0.1569476575,
            0.1829041988,
            -0.2842915654,
            -0.4220107496,
            0.1822281033,
            0.1889107376,
            0.1791101843,
            0.1560532600,
            -0.1155678406,
            0.2295538932,
            -0.2645366490,
            -0.0288553704,
            0.1027252972,
            0.0766806602,
            -0.5448347330,
            0.2031028718,
            0.1589304954,
            0.1322451383,
            0.1189499870,
            -0.0683937520,
            -0.0873993114,
            -0.3051757514,
            -0.2355299890,
            0.1586059481,
            0.2018169016,
            0.2676822543,
            -0.3225219846,
            -0.2611543834,
            0.1922984123,
            0.1632783115,
            0.1297036558,
            0.0983960181,
            -0.1507159024,
            0.2256962359,
            -0.1040333956,
            -0.1514528394,
            0.0985243544,
            0.0819815546,
            -0.2940836251,
            0.1586865336,
            0.1468491107,
            0.1485087872,
            0.1639631987,
            -0.3239239752,
            -0.0767390430,
            -0.0434846729,
            -0.4023587406,
            -0.0052628326,
            0.2273432612,
            0.3005020022,
            -0.2598774135,
            -0.2188862711,
            0.1678501070,
            0.1352078766,
            0.1002781317,
            0.0754275694,
            -0.1502914876,
            0.1930875033,
            -0.0709601715,
            -0.2219523191,
            0.1243555173,
            0.1257609427,
            -0.0574148744,
            0.1152269915,
            0.1307857931,
            0.1599020809,
            0.2068412602,
            -0.5553412437,
            -0.0536844917,
            0.0758557543,
            -0.2106334567,
            -0.2509877980,
            0.1757438034,
            0.2637061775,
            -0.1759711355,
            -0.2431350052,
            0.1071053818,
            0.1259848624,
            0.1004033238,
            0.0856125653,
            -0.1173698306,
            0.1213828772,
            -0.1768893301,
            -0.2070008069,
            0.1709136516,
            0.2089634240,
            0.0153109450,
            0.0967332721,
            0.1268781722,
            0.1706230640,
            0.2291058898,
            -0.6386513710,
            -0.0536664203,
            0.1378114969,
            0.0360041447,
            -0.2989685237,
            -0.0084722806,
            0.1872915775,
            -0.1523490399,
            -0.2111770809,
            -0.0390694551,
            0.1366800815,
            0.1302325875,
            0.1356829405,
            -0.0982905105,
            -0.0127884001,
            -0.3586881459,
            -0.0259541404,
            0.2114149332,
            0.2843062580,
            -0.0324133746,
            0.1084750593,
            0.1447229236,
            0.1862253845,
            0.2259712219,
            -0.6329812407,
            -0.1173689738,
            0.1914442331,
            0.1654772907,
            -0.1376858056,
            -0.2194855511,
            0.1176188141,
            -0.1529908478,
            -0.0606661662,
            -0.3384291232,
            0.1524862647,
            0.1777049750,
            0.2218948901,
            -0.0923086405,
            -0.2855934799,
            -0.3215619624,
            0.1726681292,
            0.2303666323,
            0.2964293361,
            -0.2508065701,
            0.1479703039,
            0.1753441393,
            0.1917535067,
            0.1919818372,
            -0.4562432170,
            -0.2350299209,
            0.2257601619,
            0.1863904297,
            0.0388212129,
            -0.2966264784,
            0.0806845874,
            -0.1992894858,
            0.1068909168,
            -0.5761897564,
            0.1624972969,
            0.2155302167,
            0.2905607820,
            -0.1168124676,
            -0.6870660186,
            0.1488010883,
            0.1881926507,
            0.2230074406,
            0.2438773215,
            -0.5771554708,
            0.1980127096,
            0.1924194694,
            0.1714663208,
            0.1415647417,
            -0.1263078004,
            -0.3408652246,
            0.2292248607,
            0.1707807332,
            0.1269564927,
            -0.2634142637,
            0.0773174241,
        ];
        run_comparison(
            "T=8, N=4, C=6 (repeated labels)",
            8,
            4,
            6,
            vec![1, 1, 2, 0, 2, 3, 2, 1, 5, 0, 0, 0, 1, 2, 3, 4],
            [4, 4],
            vec![8, 8, 8, 8],
            vec![3, 4, 1, 4],
            0,
            &expected_losses,
            &expected_grad_flat,
            1e-3,
            1e-3,
        );
    }

    // T=10, N=2, C=8: longer sequence with a wider class dimension.
    #[test]
    fn test_ctc_loss_long_sequence() {
        let expected_losses = [12.629399299621582, 12.298524856567383];
        // Flattened [10, 2, 8] reference gradient.
        let expected_grad_flat = [
            -0.2570972741,
            -0.6013792753,
            0.1061997041,
            0.1321590245,
            0.1533492655,
            0.1637226790,
            0.1598964781,
            0.1431493312,
            -0.2540431321,
            0.1788398325,
            -0.4038805366,
            0.1477340311,
            0.1197479516,
            0.0920107216,
            0.0686140805,
            0.0509770736,
            -0.1364373565,
            -0.3724762201,
            0.1489177048,
            -0.0966964588,
            0.1463697106,
            0.1275274903,
            0.1033692732,
            0.0794258416,
            -0.1771971881,
            0.2073454857,
            -0.3109439015,
            0.1249521226,
            -0.0101635465,
            0.0692621097,
            0.0533472970,
            0.0433975980,
            -0.1398337185,
            -0.0874802172,
            0.1705365479,
            -0.2174201906,
            0.1150254831,
            0.0460043959,
            0.0647982135,
            0.0483694859,
            -0.2332949787,
            0.1969220787,
            -0.1270586401,
            0.1098557115,
            -0.1364655048,
            0.0715296715,
            0.0553609394,
            0.0631506816,
            -0.2169117928,
            0.0929956511,
            0.1624538749,
            -0.2009791434,
            0.0904926360,
            -0.0248185843,
            0.0532633252,
            0.0435040221,
            -0.2313277274,
            0.1497355998,
            -0.0024202778,
            0.1029939279,
            -0.2776987851,
            0.0963881761,
            0.0351882279,
            0.1271408647,
            -0.2590557337,
            0.1577988416,
            0.1429322213,
            -0.1401246637,
            0.0866033062,
            -0.1151762009,
            0.0683368817,
            0.0586853735,
            -0.1322475076,
            0.0806737095,
            0.0528722852,
            0.0920089707,
            -0.3037962914,
            0.1280544847,
            -0.1391123086,
            0.2215466499,
            -0.1918463260,
            0.1376975775,
            0.1160097718,
            -0.0549413785,
            0.0970225409,
            -0.2708687484,
            0.1147320047,
            0.0521945432,
            -0.0504456684,
            -0.0012221609,
            0.0644332916,
            0.0818370953,
            -0.1036835983,
            0.1512031406,
            -0.4072600305,
            0.2651379406,
            -0.0681083873,
            0.0860663429,
            0.0810486302,
            0.0434282124,
            0.1056238264,
            -0.2994530201,
            0.1729898751,
            -0.1215954795,
            -0.0481944978,
            -0.1697723418,
            0.0725984722,
            0.0692019314,
            0.0859903544,
            0.1680216491,
            -0.4071443677,
            0.2292988002,
            -0.0205532499,
            0.0566616580,
            0.0326749459,
            0.0861379728,
            0.1142501161,
            -0.0448331088,
            0.2054910213,
            -0.4298293889,
            -0.0647637174,
            -0.4240962267,
            0.1013666242,
            -0.0110451467,
            0.1519176364,
            0.1661346704,
            -0.0719586164,
            0.1524447650,
            -0.0496110357,
            0.0562372655,
            -0.1889088154,
            0.1013496071,
            0.1339637935,
            0.1694275290,
            0.2007708699,
            -0.4232292175,
            -0.0401752405,
            -0.2951072752,
            0.1443216652,
            -0.2857291698,
            0.1489982456,
            0.1327733696,
            0.1096193567,
            0.0852990299,
            -0.0413062274,
            0.0820900649,
            -0.7903561592,
            0.1329460591,
            0.1535883099,
            0.1631743014,
            0.1585651338,
            0.1412984729,
            -0.1033771932,
            0.1799504310,
            0.1697744429,
            -0.5749052763,
            0.1189445183,
            0.0911802500,
            0.0679325759,
            0.0505003072,
        ];
        run_comparison(
            "T=10, N=2, C=8",
            10,
            2,
            8,
            vec![1, 3, 5, 7, 2, 2, 4, 6, 1, 3],
            [2, 5],
            vec![10, 10],
            vec![5, 5],
            0,
            &expected_losses,
            &expected_grad_flat,
            1e-3,
            1e-3,
        );
    }

    // T=12, N=3, C=5 with input lengths [12, 7, 10]: time steps beyond a
    // sample's input length receive no gradient, hence the zero rows in the
    // table below for samples 1 (t >= 7) and 2 (t >= 10).
    #[test]
    fn test_ctc_loss_mixed_input_lengths() {
        let expected_losses = [10.595505714416504, 6.8078508377075195, 7.705057144165039];
        // Flattened [12, 3, 5] reference gradient.
        let expected_grad_flat = [
            -0.4790987670,
            -0.2554937005,
            0.1991624236,
            0.2478453964,
            0.2875846624,
            -0.3495813310,
            0.2268397957,
            0.2150714993,
            -0.2442178279,
            0.1518878639,
            -0.2764556706,
            0.2474014312,
            -0.2137086987,
            0.1371368915,
            0.1056260392,
            -0.2729502618,
            -0.3609606028,
            0.2159237266,
            0.2238420397,
            0.1941450834,
            -0.2953839302,
            0.1920599341,
            0.1974952668,
            -0.2054278404,
            0.1112565696,
            -0.1719199270,
            0.2299505472,
            -0.2864859998,
            0.1497263014,
            0.0787290633,
            -0.2035763413,
            -0.3042884767,
            0.2126964629,
            0.1810975969,
            0.1140707731,
            -0.2759391963,
            0.0975771844,
            0.1823379993,
            -0.1112988219,
            0.1073228419,
            -0.1336459517,
            0.1869296581,
            -0.1996247321,
            0.1846873760,
            -0.0383463502,
            -0.2254105806,
            -0.1834360659,
            0.1925925612,
            0.1462381780,
            0.0700158924,
            -0.2259973884,
            -0.0393539183,
            0.1802661419,
            -0.0571591072,
            0.1422442794,
            -0.0609069727,
            0.1089282706,
            -0.0313654318,
            0.2186669111,
            -0.2353227735,
            -0.2840364873,
            -0.0632198900,
            0.1755636632,
            0.1377806067,
            0.0339120962,
            -0.1904856712,
            -0.2139032930,
            0.1827126741,
            0.0056131603,
            0.2160631120,
            -0.0243270602,
            -0.0070458520,
            0.1070247591,
            0.2239368409,
            -0.2995886803,
            -0.2955487072,
            0.0309870224,
            0.1654911339,
            0.1581364125,
            -0.0590658709,
            -0.2191396207,
            -0.3791662455,
            0.1803640425,
            0.1225430891,
            0.2953987718,
            -0.0436352938,
            -0.1575258970,
            0.1785279512,
            0.1756918877,
            -0.1530586481,
            -0.1834939867,
            0.0909025446,
            0.1423641294,
            0.1959712654,
            -0.2457439601,
            -0.3619639874,
            -0.3929221630,
            0.1820438206,
            0.2454170734,
            0.3274252713,
            -0.0628800318,
            -0.2567180395,
            0.2112283260,
            0.0507859327,
            0.0575838275,
            -0.0587697029,
            0.1174769849,
            0.0783569664,
            0.2290501744,
            -0.3661144078,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            -0.0725664943,
            -0.1532069892,
            0.2162397504,
            -0.1248963475,
            0.1344300956,
            -0.0362483934,
            0.1295878887,
            -0.0502482466,
            0.2470482886,
            -0.2901395261,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            -0.1349253207,
            0.0867646411,
            0.1998746395,
            -0.2658679783,
            0.1141540110,
            -0.0705668628,
            0.1519546807,
            -0.2509805560,
            0.2475892603,
            -0.0779965296,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            -0.2338010073,
            0.2471641302,
            0.1834627241,
            -0.3026831448,
            0.1058573127,
            -0.1155209392,
            0.1921830922,
            -0.4129956067,
            0.2229512781,
            0.1133821756,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            -0.2636392713,
            0.2323469073,
            -0.2913427949,
            0.1800564528,
            0.1425786912,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
        ];
        run_comparison(
            "T=12, N=3, C=5 (mixed input lengths)",
            12,
            3,
            5,
            vec![1, 4, 2, 0, 3, 1, 0, 0, 2, 4, 1, 3],
            [3, 4],
            vec![12, 7, 10],
            vec![3, 2, 4],
            0,
            &expected_losses,
            &expected_grad_flat,
            1e-3,
            1e-3,
        );
    }

    // Sum reduction over the T=5/N=3/C=4 setup: the expected sum is the sum
    // of the per-sample losses from `test_ctc_loss_uniform_input_lengths`,
    // and the gradient table is identical to that test's (d(sum)/d(logits)
    // equals the unreduced backward pass).
    #[test]
    fn test_ctc_loss_sum_reduction() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let logits = generate_logits(5, 3, 4, &device).require_grad();
        let log_probs = log_softmax(logits.clone(), 2);
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3]),
            &device,
        );
        let il = Tensor::<TestBackend, 1, Int>::from_data([5_i32, 5, 5], &device);
        let tl = Tensor::<TestBackend, 1, Int>::from_data([2_i32, 1, 3], &device);

        let loss = ctc.forward_with_reduction(log_probs, targets, il, tl, Reduction::Sum);
        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();

        // 3.5236570835 + 3.4953138828 + 4.2626771927
        let expected_sum = 11.2816486359_f32;
        assert_approx_equal(&loss_data, &[expected_sum], 1e-3);

        let grads = loss.backward();
        let logits_grad = logits.grad(&grads).unwrap();
        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();
        let expected_grad = [
            -0.1679008007_f32,
            -0.4595540464,
            0.2795598209,
            0.3478950262,
            -0.3913056254,
            -0.0832268298,
            0.2535884976,
            0.2209439576,
            -0.0502742566,
            0.2766197622,
            0.2054125518,
            -0.4317580462,
            -0.0544800088,
            -0.3144550920,
            0.0847885981,
            0.2841464877,
            -0.1844545156,
            -0.2063435912,
            0.2222184092,
            0.1685796976,
            0.0278018005,
            0.2657383382,
            -0.0336986706,
            -0.2598414719,
            -0.0482986756,
            -0.0098767160,
            -0.1533526182,
            0.2115280181,
            -0.1380317956,
            -0.2198686600,
            0.2042596638,
            0.1536407918,
            0.0534787849,
            0.1819230020,
            -0.2805589139,
            0.0451571345,
            -0.0895631388,
            0.1996460557,
            -0.2741115987,
            0.1640286744,
            -0.2200077325,
            -0.1693530381,
            0.2101601064,
            0.1792006642,
            0.0398471877,
            -0.1131042913,
            -0.2363226712,
            0.3095797896,
            -0.2163617164,
            0.2740726173,
            -0.2124865055,
            0.1547756046,
            -0.4312027395,
            -0.0446923785,
            0.2330704331,
            0.2428246588,
            -0.0050083841,
            -0.6256869435,
            0.2689785957,
            0.3617166877,
        ];
        assert_approx_equal(&grad_data, &expected_grad, 1e-3);
    }

    // Mean reduction over the same setup: mean(loss_i / target_len_i)
    // = (3.5237/2 + 3.4953/1 + 4.2627/3) / 3; the gradient is the per-sample
    // gradient scaled by 1 / (batch_size * target_len_i).
    #[test]
    fn test_ctc_loss_mean_reduction() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let logits = generate_logits(5, 3, 4, &device).require_grad();
        let log_probs = log_softmax(logits.clone(), 2);
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3]),
            &device,
        );
        let il = Tensor::<TestBackend, 1, Int>::from_data([5_i32, 5, 5], &device);
        let tl = Tensor::<TestBackend, 1, Int>::from_data([2_i32, 1, 3], &device);

        let loss = ctc.forward_with_reduction(log_probs, targets, il, tl, Reduction::Mean);
        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();

        let expected_mean = 2.2260115147_f32;
        assert_approx_equal(&loss_data, &[expected_mean], 1e-3);

        let grads = loss.backward();
        let logits_grad = logits.grad(&grads).unwrap();
        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();
        let expected_grad = [
            -0.0279834662_f32,
            -0.0765923411,
            0.0465933047,
            0.0579825081,
            -0.1304352134,
            -0.0277422778,
            0.0845294967,
            0.0736479908,
            -0.0055860290,
            0.0307355281,
            0.0228236169,
            -0.0479731150,
            -0.0090800021,
            -0.0524091832,
            0.0141314333,
            0.0473577492,
            -0.0614848398,
            -0.0687812045,
            0.0740728080,
            0.0561932363,
            0.0030890885,
            0.0295264814,
            -0.0037442972,
            -0.0288712755,
            -0.0080497796,
            -0.0016461194,
            -0.0255587716,
            0.0352546684,
            -0.0460105985,
            -0.0732895583,
            0.0680865571,
            0.0512135960,
            0.0059420872,
            0.0202136654,
            -0.0311732125,
            0.0050174589,
            -0.0149271907,
            0.0332743451,
            -0.0456852652,
            0.0273381118,
            -0.0733359158,
            -0.0564510152,
            0.0700533763,
            0.0597335547,
            0.0044274656,
            -0.0125671430,
            -0.0262580756,
            0.0343977548,
            -0.0360602848,
            0.0456787720,
            -0.0354144201,
            0.0257959347,
            -0.1437342465,
            -0.0148974592,
            0.0776901469,
            0.0809415579,
            -0.0005564869,
            -0.0695207715,
            0.0298865121,
            0.0401907414,
        ];
        assert_approx_equal(&grad_data, &expected_grad, 1e-3);
    }
}