// burn_nn/loss/ctc.rs

#![allow(clippy::excessive_precision)]

use super::Reduction;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Int, Tensor, backend::Backend};
use burn_core as burn;

/// Configuration for the [CTC Loss](CTCLoss) module.
#[derive(Config, Debug)]
pub struct CTCLossConfig {
    /// The index number used to represent the blank label. Default value is `0`.
    #[config(default = 0)]
    pub blank: usize,
    /// Whether to zero infinite losses and the associated gradients. Default value is `false`.
    #[config(default = false)]
    pub zero_infinity: bool,
}

impl CTCLossConfig {
    /// Initialize a new [CTC Loss](CTCLoss) module
    pub fn init(&self) -> CTCLoss {
        CTCLoss {
            blank: self.blank,
            zero_infinity: self.zero_infinity,
        }
    }
}

/// Computes the Connectionist Temporal Classification (CTC) loss.
///
/// Calculates the loss between a continuous (unsegmented) time series and a target sequence.
/// CTC sums over the probability of all possible alignments of the input to the target,
/// producing a loss value that is differentiable with respect to each input node.
///
/// The input to this loss is expected to be **log-probabilities** (e.g., via `log_softmax`),
/// not raw logits.
///
/// # References
///
/// - [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](https://www.cs.toronto.edu/~graves/icml_2006.pdf)
///
/// # Example
///
/// ```rust,ignore
/// use burn::tensor::{Tensor, Int};
/// use burn::tensor::activation::log_softmax;
/// use burn::nn::loss::{CTCLossConfig, CTCLoss};
///
/// let device = Default::default();
///
/// // Initialize CTC Loss with default configuration
/// let ctc_loss = CTCLossConfig::new().init();
///
/// // Initialize CTC Loss with custom configuration
/// let ctc_loss = CTCLossConfig::new()
///     .with_blank(1)
///     .with_zero_infinity(true)
///     .init();
///
/// // Prepare inputs (Logits shape: [Time, Batch, Class])
/// // In your actual code, the logits would be the output of your model
/// let logits = Tensor::<B, 3>::ones([10, 2, 5], &device);
/// let log_probs = log_softmax(logits, 2);
///
/// // Targets shape: [Batch, Max_Target_Len]
/// // Note: targets must not contain the blank index (here `1`, per the custom configuration).
/// let targets = Tensor::<B, 2, Int>::from_data([[0, 2], [3, 4]], &device);
///
/// // Lengths shape: [Batch]
/// let input_lengths = Tensor::<B, 1, Int>::from_data([10, 8], &device);
/// let target_lengths = Tensor::<B, 1, Int>::from_data([2, 2], &device);
///
/// // Compute loss
/// let loss = ctc_loss.forward(log_probs, targets, input_lengths, target_lengths);
/// ```
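///
/// To reduce the per-sample losses to a single scalar, see
/// [`CTCLoss::forward_with_reduction`].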
#[derive(Module, Clone, Debug)]
pub struct CTCLoss {
    blank: usize,
    zero_infinity: bool,
}

impl CTCLoss {
    /// Computes the CTC loss for the input log-probabilities and targets with no reduction applied.
    ///
    /// # Arguments
    ///
    /// - `log_probs`: The log-probabilities of the outputs (e.g., from `log_softmax`).
    /// - `targets`: A 2D tensor containing the target class indices. These indices should not
    ///   include the blank index used in CTC loss. The targets are padded to the length of the longest sequence.
    /// - `input_lengths`: A 1D tensor containing the actual length of the input sequence for each
    ///   batch element. This allows retrieving the actual sequence of log-probabilities from
    ///   `log_probs` when the batch contains sequences of varying lengths.
    /// - `target_lengths`: A 1D tensor containing the actual length of each target sequence
    ///   in `targets`.
    ///
    /// # Returns
    ///
    /// - A 1D tensor of shape `[batch_size]` containing the loss for each sample.
    ///
    /// # Shapes
    ///
    /// - `log_probs`: `[time_steps, batch_size, num_classes]` where `num_classes` includes blank.
    /// - `targets`: `[batch_size, max_target_length]`
    /// - `input_lengths`: `[batch_size]`
    /// - `target_lengths`: `[batch_size]`
    pub fn forward<B: Backend>(
        &self,
        log_probs: Tensor<B, 3>,
        targets: Tensor<B, 2, Int>,
        input_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let [max_input_length, batch_size, num_classes] = log_probs.dims();
        let max_target_len = targets.dims()[1];
        let input_lengths_len = input_lengths.dims()[0];
        let target_lengths_len = target_lengths.dims()[0];
        self.assertions(
            batch_size,
            num_classes,
            targets.clone(),
            input_lengths_len,
            target_lengths_len,
        );
        self.length_assertions(
            input_lengths.clone(),
            target_lengths.clone(),
            max_target_len,
            max_input_length,
        );

        let mut loss = burn::tensor::module::ctc_loss(
            log_probs,
            targets,
            input_lengths,
            target_lengths,
            self.blank,
        );

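        // When enabled, replace infinite losses with zeros. Because the masked
        // positions take their value (and hence their gradient) from the zero
        // tensor, the associated gradients are zeroed as well.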
        if self.zero_infinity {
            let inf_mask = loss.clone().is_inf();
            loss = loss.clone().mask_where(inf_mask, loss.clone().zeros_like());
        }

        loss
    }

    /// Computes the CTC loss for the input log-probabilities and targets with reduction.
    ///
    /// # Arguments
    ///
    /// - `log_probs`: The log-probabilities of the outputs (e.g., from `log_softmax`).
    /// - `targets`: A 2D tensor containing the target class indices. These indices should not
    ///   include the blank index used in CTC loss. The targets are padded to the length of the longest sequence.
    /// - `input_lengths`: A 1D tensor containing the actual length of the input sequence for each
    ///   batch element. This allows retrieving the actual sequence of log-probabilities from
    ///   `log_probs` when the batch contains sequences of varying lengths.
    /// - `target_lengths`: A 1D tensor containing the actual length of each target sequence
    ///   in `targets`.
    /// - `reduction`: The reduction strategy to apply to the loss tensor containing the CTC loss values for
    ///   each sample (e.g., mean, sum). For the mean reduction strategy, the output losses are divided
    ///   by the target lengths and then the mean over the batch is taken. This follows PyTorch's behavior.
    ///
    /// # Returns
    ///
    /// - A 1D tensor of shape `[1]` containing the reduced loss value.
    ///
    /// # Shapes
    ///
    /// - `log_probs`: `[time_steps, batch_size, num_classes]` where `num_classes` includes blank.
    /// - `targets`: `[batch_size, max_target_length]`
    /// - `input_lengths`: `[batch_size]`
    /// - `target_lengths`: `[batch_size]`
    ///
    /// # Panics
    /// - If `reduction` is not one of `Reduction::Auto`, `Reduction::Mean`, or `Reduction::Sum`.
    /// - If `blank` index is greater than or equal to `num_classes`.
    /// - If the batch dimensions of `log_probs`, `targets`, `input_lengths`, and `target_lengths` do not match.
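    ///
    /// # Example
    ///
    /// A minimal sketch following the module-level example above (tensor
    /// preparation elided; shapes as documented under "Shapes"):
    ///
    /// ```rust,ignore
    /// use burn::nn::loss::Reduction;
    ///
    /// // `log_probs`, `targets`, `input_lengths`, and `target_lengths` are
    /// // prepared as in the `CTCLoss` example.
    /// let loss = ctc_loss.forward_with_reduction(
    ///     log_probs,
    ///     targets,
    ///     input_lengths,
    ///     target_lengths,
    ///     Reduction::Mean, // per-sample losses divided by target lengths, then averaged
    /// );
    /// ```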
    pub fn forward_with_reduction<B: Backend>(
        &self,
        log_probs: Tensor<B, 3>,
        targets: Tensor<B, 2, Int>,
        input_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let ctc_loss_tensor =
            self.forward(log_probs, targets, input_lengths, target_lengths.clone());

        match reduction {
            Reduction::Auto | Reduction::Mean => {
                // Following PyTorch's behavior where the output losses are divided
                // by the target lengths and then the mean over the batch is taken
                let target_lengths_float = target_lengths.float();
                ctc_loss_tensor.div(target_lengths_float).mean()
            }
            Reduction::Sum => ctc_loss_tensor.sum(),
            other => panic!("{other:?} reduction is not supported"),
        }
    }

    /// Checks the per-element length invariants required by the alpha
    /// recursion. These require reading the length tensors from the device,
    /// so the checks are gated behind `cfg(debug_assertions)` to avoid the
    /// device-to-host sync in release builds.
    ///
    /// Validated:
    /// - `target_lengths[i] >= 0`
    /// - `target_lengths[i] <= max_target_len`
    /// - `input_lengths[i] >= target_lengths[i]`
    /// - `input_lengths[i] <= max_input_length`
    #[allow(unused_variables)]
    fn length_assertions<B: Backend>(
        &self,
        input_lengths: Tensor<B, 1, Int>,
        target_lengths: Tensor<B, 1, Int>,
        max_target_len: usize,
        max_input_length: usize,
    ) {
        #[cfg(debug_assertions)]
        {
            let target_lengths_data = target_lengths.into_data();
            let input_lengths_data = input_lengths.into_data();
            let target_iter = target_lengths_data.iter::<i64>();
            let input_iter = input_lengths_data.iter::<i64>();
            for (i, (tl, il)) in target_iter.zip(input_iter).enumerate() {
                assert!(tl >= 0, "target_lengths[{i}] = {tl} must be non-negative");
                assert!(
                    tl as usize <= max_target_len,
                    "target_lengths[{i}] = {tl} exceeds the targets tensor width {max_target_len}"
                );
                assert!(
                    il >= tl,
                    "input_lengths[{i}] = {il} must be >= target_lengths[{i}] = {tl} \
                     (no valid CTC alignment otherwise)"
                );
                assert!(
                    il as usize <= max_input_length,
                    "input_lengths[{i}] = {il} exceeds the log_probs time dimension \
                     {max_input_length}"
                );
            }
        }
    }

    fn assertions<B: Backend>(
        &self,
        batch_size: usize,
        num_classes: usize,
        targets: Tensor<B, 2, Int>,
        input_lengths_len: usize,
        target_lengths_len: usize,
    ) {
        assert!(
            self.blank < num_classes,
            "blank index {} must be less than num_classes {}",
            self.blank,
            num_classes
        );
        assert_eq!(
            targets.dims()[0],
            batch_size,
            "targets batch dimension {} must equal batch_size {}",
            targets.dims()[0],
            batch_size
        );
        assert_eq!(
            input_lengths_len, batch_size,
            "input_lengths length {} must equal batch_size {}",
            input_lengths_len, batch_size
        );
        assert_eq!(
            target_lengths_len, batch_size,
            "target_lengths length {} must equal batch_size {}",
            target_lengths_len, batch_size
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_flex::{Flex, FlexDevice};

    type TestBackend = Flex;

    fn assert_approx_equal(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "Length mismatch: actual {} vs expected {}",
            actual.len(),
            expected.len()
        );
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "Mismatch at index {}: expected {:.6}, got {:.6} (diff: {:.6})",
                i,
                e,
                a,
                (a - e).abs()
            );
        }
    }

    // ---------------------------------------------------------------
    // Assertions
    // ---------------------------------------------------------------

    #[test]
    #[should_panic(expected = "blank index")]
    fn test_ctc_loss_panics_invalid_blank_index() {
        let device = FlexDevice;
        // blank=5 is out of bounds for num_classes=3
        let ctc = CTCLossConfig::new().with_blank(5).init();

        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 1, 3], &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1], &device);

        ctc.forward(log_probs, targets, input_lengths, target_lengths);
    }

    #[test]
    #[should_panic(expected = "must equal batch_size")]
    fn test_ctc_loss_panics_mismatched_batch_size() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        // Logits batch size = 2
        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);
        // Targets batch size = 1 (Mismatch)
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2, 2], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1, 1], &device);

        ctc.forward(log_probs, targets, input_lengths, target_lengths);
    }

    #[test]
    #[should_panic(expected = "input_lengths length")]
    fn test_ctc_loss_panics_input_lengths_mismatch() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        // Logits batch size = 2
        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1], [2]], &device);

        // Input lengths size = 1 (Mismatch)
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1, 1], &device);

        ctc.forward(log_probs, targets, input_lengths, target_lengths);
    }

    #[test]
    #[should_panic(expected = "target_lengths length")]
    fn test_ctc_loss_panics_target_lengths_mismatch() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        // Logits batch size = 2
        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1], [2]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2, 2], &device);

        // Target lengths size = 1 (Mismatch)
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1], &device);

        ctc.forward(log_probs, targets, input_lengths, target_lengths);
    }

    // ---------------------------------------------------------------
    // Edge Case & Config Tests
    // ---------------------------------------------------------------

    #[test]
    fn test_ctc_loss_repeated_labels_minimum_input_length() {
        // T=3, N=1, C=2, blank=0, target=[1, 1], uniform P = 1/2.
        //
        // The minimum T for target [1, 1] is 3: the only valid path is (1, 0, 1).
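        // (CTC requires a blank between consecutive repeated labels, hence T >= 3.)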
        // prob = (1/2)^3 = 1/8
        // Loss = -ln(1/8) = 3 * ln(2)
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let log_probs = Tensor::<TestBackend, 3>::full([3, 1, 2], 0.5_f32.ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i32, 1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([3_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        let expected = 3.0 * 2.0_f32.ln();
        assert_approx_equal(&loss_data, &[expected], 1e-3);
    }

    #[test]
    fn test_ctc_loss_custom_blank_uniform() {
        // T=3, N=1, C=3, blank=2, target=[0, 1], uniform P = 1/3.
        //
        // Two distinct labels, 3 classes, 3 time steps, just with
        // blank=2 instead of 0.
        // 5 valid paths → total = 5/27
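        // (The five alignments, with blank = 2: [0,0,1], [0,1,1], [0,1,2],
        //  [0,2,1], and [2,0,1].)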
        // Loss = -ln(5/27)
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_blank(2).init();

        let log_probs = Tensor::<TestBackend, 3>::full([3, 1, 3], (1.0_f32 / 3.0).ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[0_i32, 1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([3_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        let expected = -(5.0_f32 / 27.0).ln();
        assert_approx_equal(&loss_data, &[expected], 1e-3);
    }

    // ---------------------------------------------------------------
    // zero_infinity tests
    // ---------------------------------------------------------------

    #[test]
    fn test_ctc_loss_zero_infinity_produces_inf_when_disabled() {
        // T=2, N=1, C=3, blank=0, target=[1, 1], input_length=2
        // Target [1, 1] requires at least 3 time steps → no valid paths → loss = +inf
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_zero_infinity(false).init();

        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 3], (1.0_f32 / 3.0).ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i32, 1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        assert!(
            loss_data[0].is_infinite() && loss_data[0] > 0.0,
            "Expected +inf, got {}",
            loss_data[0]
        );
    }

    #[test]
    fn test_ctc_loss_zero_infinity_masks_inf_when_enabled() {
        // Same inputs as above, but zero_infinity=true → loss should be 0.0
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_zero_infinity(true).init();

        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 3], (1.0_f32 / 3.0).ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i32, 1]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        assert_approx_equal(&loss_data, &[0.0], 1e-6);
    }

    #[test]
    fn test_ctc_loss_zero_infinity_does_not_affect_finite_loss() {
        // Verify that zero_infinity=true does not change a finite loss value.
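        // T=2, N=1, C=2, blank=0, target=[1], uniform P = 1/2: the valid
        // alignments are (1,1), (1,0), and (0,1), so prob = 3/4 and
        // loss = -ln(0.75).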
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_zero_infinity(true).init();

        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 2], 0.5_f32.ln(), &device);
        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i32]], &device);
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i32], &device);
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1_i32], &device);

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.into_data().to_vec::<f32>().unwrap();
        let expected = -(0.75_f32).ln();
        assert_approx_equal(&loss_data, &[expected], 1e-3);
    }
}

#[cfg(test)]
mod pytorch_comparison_tests {
    use super::*;
    use burn::tensor::activation::log_softmax;
    use burn_autodiff::Autodiff;
    use burn_core::tensor::TensorData;
    use burn_flex::{Flex, FlexDevice};

    type InnerBackend = Flex;
    type TestBackend = Autodiff<InnerBackend>;

    fn assert_approx_equal(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "Length mismatch: actual {} vs expected {}",
            actual.len(),
            expected.len()
        );
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "Mismatch at index {}: expected {:.6}, got {:.6} (diff: {:.6})",
                i,
                e,
                a,
                (a - e).abs()
            );
        }
    }

    /// Deterministic logits: sin((t*7 + n*13 + c*3) * 0.1).
    fn generate_logits(
        t_size: usize,
        n_size: usize,
        c_size: usize,
        device: &FlexDevice,
    ) -> Tensor<TestBackend, 3> {
        let mut data = Vec::with_capacity(t_size * n_size * c_size);
        for t in 0..t_size {
            for n in 0..n_size {
                for c in 0..c_size {
                    data.push(((t * 7 + n * 13 + c * 3) as f32 * 0.1).sin());
                }
            }
        }
        Tensor::<TestBackend, 3>::from_data(TensorData::new(data, [t_size, n_size, c_size]), device)
    }

    /// Runs a CTC forward + backward test and asserts against expected values from PyTorch.
    ///
    /// This helper performs the following steps:
    /// 1. Generates deterministic logits using a sine-wave formula.
    /// 2. Computes the CTC loss (forward pass).
    /// 3. Asserts the computed loss matches `expected_losses`.
    /// 4. Backpropagates the sum of the loss.
    /// 5. Asserts the resulting gradients w.r.t. logits match `expected_grad_flat`.
    ///
    /// # Arguments
    ///
    /// - `expected_losses`: per-sample loss values from PyTorch (reduction='none').
    /// - `expected_grad_flat`: flattened gradient of sum(loss) w.r.t. logits.
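    ///
    /// The reference values can presumably be reproduced (assumed setup, not
    /// verified here) with PyTorch's `torch.nn.functional.ctc_loss(log_probs,
    /// targets, input_lengths, target_lengths, blank=blank, reduction='none')`
    /// on logits built with the same sine formula, followed by
    /// `loss.sum().backward()` for the gradients.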
    #[allow(clippy::too_many_arguments)]
    fn run_comparison(
        label: &str,
        t_size: usize,
        n_size: usize,
        c_size: usize,
        targets_flat: Vec<i64>,
        target_shape: [usize; 2],
        input_lengths: Vec<i64>,
        target_lengths: Vec<i64>,
        blank: usize,
        expected_losses: &[f32],
        expected_grad_flat: &[f32],
        loss_tol: f32,
        grad_tol: f32,
    ) {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().with_blank(blank).init();

        let logits = generate_logits(t_size, n_size, c_size, &device).require_grad();
        let log_probs = log_softmax(logits.clone(), 2);

        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::new(targets_flat, target_shape),
            &device,
        );
        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data(
            TensorData::new(input_lengths, [n_size]),
            &device,
        );
        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data(
            TensorData::new(target_lengths, [n_size]),
            &device,
        );

        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);
        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();

        println!("=== {} ===", label);
        println!("  Loss: {:?}", loss_data);
        assert_approx_equal(&loss_data, expected_losses, loss_tol);

        let loss_sum = loss.sum();
        let grads = loss_sum.backward();
        let logits_grad = logits.grad(&grads).unwrap();
        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();
        assert_approx_equal(&grad_data, expected_grad_flat, grad_tol);
    }

    #[test]
    fn test_ctc_loss_uniform_input_lengths() {
        // T=5, N=3, C=4, all input_lengths = 5
        // Expected losses and gradient from PyTorch
        let expected_losses = [3.5236570835113525_f32, 3.495313882827759, 4.262677192687988];
        let expected_grad_flat = [
            -0.1679008007_f32,
            -0.4595540464,
            0.2795598209,
            0.3478950262,
            -0.3913056254,
            -0.0832268298,
            0.2535884976,
            0.2209439576,
            -0.0502742566,
            0.2766197622,
            0.2054125518,
            -0.4317580462,
            -0.0544800088,
            -0.3144550920,
            0.0847885981,
            0.2841464877,
            -0.1844545156,
            -0.2063435912,
            0.2222184092,
            0.1685796976,
            0.0278018005,
            0.2657383382,
            -0.0336986706,
            -0.2598414719,
            -0.0482986756,
            -0.0098767160,
            -0.1533526182,
            0.2115280181,
            -0.1380317956,
            -0.2198686600,
            0.2042596638,
            0.1536407918,
            0.0534787849,
            0.1819230020,
            -0.2805589139,
            0.0451571345,
            -0.0895631388,
            0.1996460557,
            -0.2741115987,
            0.1640286744,
            -0.2200077325,
            -0.1693530381,
            0.2101601064,
            0.1792006642,
            0.0398471877,
            -0.1131042913,
            -0.2363226712,
            0.3095797896,
            -0.2163617164,
            0.2740726173,
            -0.2124865055,
            0.1547756046,
            -0.4312027395,
            -0.0446923785,
            0.2330704331,
            0.2428246588,
            -0.0050083841,
            -0.6256869435,
            0.2689785957,
            0.3617166877,
        ];
        run_comparison(
            "T=5, N=3, C=4 (uniform input lengths)",
            5,
            3,
            4,
            vec![1, 2, 0, 1, 0, 0, 3, 2, 1],
            [3, 3],
            vec![5, 5, 5],
            vec![2, 1, 3],
            0,
            &expected_losses,
            &expected_grad_flat,
            1e-3,
            1e-3,
        );
    }

    #[test]
    fn test_ctc_loss_repeated_labels() {
        // T=8, N=4, C=6, includes consecutive repeated label [1,1,2]
        // Expected losses and gradient from PyTorch
        let expected_losses = [
            8.84203052520752_f32,
            9.023029327392578,
            9.398024559020996,
            9.008068084716797,
        ];
        let expected_grad_flat = [
            -0.2766432464,
            -0.5202965736,
            0.1523768753,
            0.1896236390,
            0.2200277001,
            0.2349116206,
            -0.1854365915,
            0.2031330466,
            -0.4260218740,
            0.1678018719,
            0.1360142529,
            0.1045092493,
            -0.6603536606,
            0.2278252542,
            0.1691786796,
            0.1262856424,
            0.0972681716,
            0.0397959016,
            -0.0894432291,
            -0.5457318425,
            0.1490373611,
            0.1462858170,
            0.1569476575,
            0.1829041988,
            -0.2842915654,
            -0.4220107496,
            0.1822281033,
            0.1889107376,
            0.1791101843,
            0.1560532600,
            -0.1155678406,
            0.2295538932,
            -0.2645366490,
            -0.0288553704,
            0.1027252972,
            0.0766806602,
            -0.5448347330,
            0.2031028718,
            0.1589304954,
            0.1322451383,
            0.1189499870,
            -0.0683937520,
            -0.0873993114,
            -0.3051757514,
            -0.2355299890,
            0.1586059481,
            0.2018169016,
            0.2676822543,
            -0.3225219846,
            -0.2611543834,
            0.1922984123,
            0.1632783115,
            0.1297036558,
            0.0983960181,
            -0.1507159024,
            0.2256962359,
            -0.1040333956,
            -0.1514528394,
            0.0985243544,
            0.0819815546,
            -0.2940836251,
            0.1586865336,
            0.1468491107,
            0.1485087872,
            0.1639631987,
            -0.3239239752,
            -0.0767390430,
            -0.0434846729,
            -0.4023587406,
            -0.0052628326,
            0.2273432612,
            0.3005020022,
            -0.2598774135,
            -0.2188862711,
            0.1678501070,
            0.1352078766,
            0.1002781317,
            0.0754275694,
            -0.1502914876,
            0.1930875033,
            -0.0709601715,
            -0.2219523191,
            0.1243555173,
            0.1257609427,
            -0.0574148744,
            0.1152269915,
            0.1307857931,
            0.1599020809,
            0.2068412602,
            -0.5553412437,
            -0.0536844917,
            0.0758557543,
            -0.2106334567,
            -0.2509877980,
            0.1757438034,
            0.2637061775,
            -0.1759711355,
            -0.2431350052,
            0.1071053818,
            0.1259848624,
            0.1004033238,
            0.0856125653,
            -0.1173698306,
            0.1213828772,
            -0.1768893301,
            -0.2070008069,
            0.1709136516,
            0.2089634240,
            0.0153109450,
            0.0967332721,
            0.1268781722,
            0.1706230640,
            0.2291058898,
            -0.6386513710,
            -0.0536664203,
            0.1378114969,
            0.0360041447,
            -0.2989685237,
            -0.0084722806,
            0.1872915775,
            -0.1523490399,
            -0.2111770809,
            -0.0390694551,
            0.1366800815,
            0.1302325875,
            0.1356829405,
            -0.0982905105,
            -0.0127884001,
            -0.3586881459,
            -0.0259541404,
            0.2114149332,
            0.2843062580,
            -0.0324133746,
            0.1084750593,
            0.1447229236,
            0.1862253845,
            0.2259712219,
            -0.6329812407,
            -0.1173689738,
            0.1914442331,
            0.1654772907,
            -0.1376858056,
            -0.2194855511,
            0.1176188141,
            -0.1529908478,
            -0.0606661662,
            -0.3384291232,
            0.1524862647,
            0.1777049750,
            0.2218948901,
            -0.0923086405,
            -0.2855934799,
            -0.3215619624,
            0.1726681292,
            0.2303666323,
            0.2964293361,
            -0.2508065701,
            0.1479703039,
            0.1753441393,
            0.1917535067,
            0.1919818372,
            -0.4562432170,
            -0.2350299209,
            0.2257601619,
            0.1863904297,
            0.0388212129,
            -0.2966264784,
            0.0806845874,
            -0.1992894858,
            0.1068909168,
            -0.5761897564,
            0.1624972969,
            0.2155302167,
            0.2905607820,
            -0.1168124676,
            -0.6870660186,
            0.1488010883,
            0.1881926507,
            0.2230074406,
            0.2438773215,
            -0.5771554708,
            0.1980127096,
            0.1924194694,
            0.1714663208,
            0.1415647417,
            -0.1263078004,
            -0.3408652246,
            0.2292248607,
            0.1707807332,
            0.1269564927,
            -0.2634142637,
            0.0773174241,
        ];
        run_comparison(
            "T=8, N=4, C=6 (repeated labels)",
            8,
            4,
            6,
            vec![1, 1, 2, 0, 2, 3, 2, 1, 5, 0, 0, 0, 1, 2, 3, 4],
            [4, 4],
            vec![8, 8, 8, 8],
            vec![3, 4, 1, 4],
            0,
            &expected_losses,
            &expected_grad_flat,
            1e-3,
            1e-3,
        );
    }

    #[test]
    fn test_ctc_loss_long_sequence() {
        // T=10, N=2, C=8
        // Expected losses and gradient from PyTorch
        let expected_losses = [12.629399299621582, 12.298524856567383];
        let expected_grad_flat = [
            -0.2570972741,
            -0.6013792753,
            0.1061997041,
            0.1321590245,
            0.1533492655,
            0.1637226790,
            0.1598964781,
            0.1431493312,
            -0.2540431321,
            0.1788398325,
            -0.4038805366,
            0.1477340311,
            0.1197479516,
            0.0920107216,
            0.0686140805,
            0.0509770736,
            -0.1364373565,
            -0.3724762201,
            0.1489177048,
            -0.0966964588,
            0.1463697106,
            0.1275274903,
            0.1033692732,
            0.0794258416,
            -0.1771971881,
            0.2073454857,
            -0.3109439015,
            0.1249521226,
            -0.0101635465,
            0.0692621097,
            0.0533472970,
            0.0433975980,
            -0.1398337185,
            -0.0874802172,
            0.1705365479,
            -0.2174201906,
            0.1150254831,
            0.0460043959,
            0.0647982135,
            0.0483694859,
            -0.2332949787,
            0.1969220787,
            -0.1270586401,
            0.1098557115,
            -0.1364655048,
            0.0715296715,
            0.0553609394,
            0.0631506816,
            -0.2169117928,
            0.0929956511,
            0.1624538749,
            -0.2009791434,
            0.0904926360,
            -0.0248185843,
            0.0532633252,
            0.0435040221,
            -0.2313277274,
            0.1497355998,
            -0.0024202778,
            0.1029939279,
            -0.2776987851,
            0.0963881761,
            0.0351882279,
            0.1271408647,
            -0.2590557337,
            0.1577988416,
            0.1429322213,
            -0.1401246637,
            0.0866033062,
            -0.1151762009,
            0.0683368817,
            0.0586853735,
            -0.1322475076,
            0.0806737095,
            0.0528722852,
            0.0920089707,
            -0.3037962914,
            0.1280544847,
            -0.1391123086,
            0.2215466499,
            -0.1918463260,
            0.1376975775,
            0.1160097718,
            -0.0549413785,
            0.0970225409,
            -0.2708687484,
            0.1147320047,
            0.0521945432,
            -0.0504456684,
            -0.0012221609,
            0.0644332916,
            0.0818370953,
            -0.1036835983,
            0.1512031406,
            -0.4072600305,
            0.2651379406,
            -0.0681083873,
            0.0860663429,
            0.0810486302,
            0.0434282124,
            0.1056238264,
            -0.2994530201,
            0.1729898751,
            -0.1215954795,
            -0.0481944978,
            -0.1697723418,
            0.0725984722,
            0.0692019314,
            0.0859903544,
            0.1680216491,
            -0.4071443677,
            0.2292988002,
            -0.0205532499,
            0.0566616580,
            0.0326749459,
            0.0861379728,
            0.1142501161,
            -0.0448331088,
            0.2054910213,
            -0.4298293889,
            -0.0647637174,
            -0.4240962267,
            0.1013666242,
            -0.0110451467,
            0.1519176364,
            0.1661346704,
            -0.0719586164,
            0.1524447650,
            -0.0496110357,
            0.0562372655,
            -0.1889088154,
            0.1013496071,
            0.1339637935,
            0.1694275290,
            0.2007708699,
            -0.4232292175,
            -0.0401752405,
            -0.2951072752,
            0.1443216652,
            -0.2857291698,
            0.1489982456,
            0.1327733696,
            0.1096193567,
            0.0852990299,
            -0.0413062274,
            0.0820900649,
            -0.7903561592,
            0.1329460591,
            0.1535883099,
            0.1631743014,
            0.1585651338,
            0.1412984729,
            -0.1033771932,
            0.1799504310,
            0.1697744429,
            -0.5749052763,
            0.1189445183,
            0.0911802500,
            0.0679325759,
            0.0505003072,
        ];
        run_comparison(
            "T=10, N=2, C=8",
            10,
            2,
            8,
            vec![1, 3, 5, 7, 2, 2, 4, 6, 1, 3],
            [2, 5],
            vec![10, 10],
            vec![5, 5],
            0,
            &expected_losses,
            &expected_grad_flat,
            1e-3,
            1e-3,
        );
    }

    #[test]
    fn test_ctc_loss_mixed_input_lengths() {
        // T=12, N=3, C=5, input_lengths=[12, 7, 10]
        // Expected losses and gradient from PyTorch
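        // Time steps beyond a sequence's input length receive zero gradient,
        // hence the blocks of zeros in the flattened gradient below.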
        let expected_losses = [10.595505714416504, 6.8078508377075195, 7.705057144165039];
        let expected_grad_flat = [
            -0.4790987670,
            -0.2554937005,
            0.1991624236,
            0.2478453964,
            0.2875846624,
            -0.3495813310,
            0.2268397957,
            0.2150714993,
            -0.2442178279,
            0.1518878639,
            -0.2764556706,
            0.2474014312,
            -0.2137086987,
            0.1371368915,
            0.1056260392,
            -0.2729502618,
            -0.3609606028,
            0.2159237266,
            0.2238420397,
            0.1941450834,
            -0.2953839302,
            0.1920599341,
            0.1974952668,
            -0.2054278404,
            0.1112565696,
            -0.1719199270,
            0.2299505472,
            -0.2864859998,
            0.1497263014,
            0.0787290633,
            -0.2035763413,
            -0.3042884767,
            0.2126964629,
            0.1810975969,
            0.1140707731,
            -0.2759391963,
            0.0975771844,
            0.1823379993,
            -0.1112988219,
            0.1073228419,
            -0.1336459517,
            0.1869296581,
            -0.1996247321,
            0.1846873760,
            -0.0383463502,
            -0.2254105806,
            -0.1834360659,
            0.1925925612,
            0.1462381780,
            0.0700158924,
            -0.2259973884,
            -0.0393539183,
            0.1802661419,
            -0.0571591072,
            0.1422442794,
            -0.0609069727,
            0.1089282706,
            -0.0313654318,
            0.2186669111,
            -0.2353227735,
            -0.2840364873,
            -0.0632198900,
            0.1755636632,
            0.1377806067,
            0.0339120962,
            -0.1904856712,
            -0.2139032930,
            0.1827126741,
            0.0056131603,
            0.2160631120,
            -0.0243270602,
            -0.0070458520,
            0.1070247591,
            0.2239368409,
            -0.2995886803,
            -0.2955487072,
            0.0309870224,
            0.1654911339,
            0.1581364125,
            -0.0590658709,
            -0.2191396207,
            -0.3791662455,
            0.1803640425,
            0.1225430891,
            0.2953987718,
            -0.0436352938,
            -0.1575258970,
            0.1785279512,
            0.1756918877,
            -0.1530586481,
            -0.1834939867,
            0.0909025446,
            0.1423641294,
            0.1959712654,
            -0.2457439601,
            -0.3619639874,
            -0.3929221630,
            0.1820438206,
            0.2454170734,
            0.3274252713,
            -0.0628800318,
            -0.2567180395,
            0.2112283260,
            0.0507859327,
            0.0575838275,
            -0.0587697029,
            0.1174769849,
            0.0783569664,
            0.2290501744,
            -0.3661144078,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            -0.0725664943,
            -0.1532069892,
            0.2162397504,
            -0.1248963475,
            0.1344300956,
            -0.0362483934,
            0.1295878887,
            -0.0502482466,
            0.2470482886,
            -0.2901395261,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            -0.1349253207,
            0.0867646411,
            0.1998746395,
            -0.2658679783,
            0.1141540110,
            -0.0705668628,
            0.1519546807,
            -0.2509805560,
            0.2475892603,
            -0.0779965296,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            -0.2338010073,
            0.2471641302,
            0.1834627241,
            -0.3026831448,
            0.1058573127,
            -0.1155209392,
            0.1921830922,
            -0.4129956067,
            0.2229512781,
            0.1133821756,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            -0.2636392713,
            0.2323469073,
            -0.2913427949,
            0.1800564528,
            0.1425786912,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
            0.0000000000,
        ];
        run_comparison(
            "T=12, N=3, C=5 (mixed input lengths)",
            12,
            3,
            5,
            vec![1, 4, 2, 0, 3, 1, 0, 0, 2, 4, 1, 3],
            [3, 4],
            vec![12, 7, 10],
            vec![3, 2, 4],
            0,
            &expected_losses,
            &expected_grad_flat,
            1e-3,
            1e-3,
        );
    }

    #[test]
    fn test_ctc_loss_sum_reduction() {
        // Same inputs as test_ctc_loss_uniform_input_lengths, sum reduction
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let logits = generate_logits(5, 3, 4, &device).require_grad();
        let log_probs = log_softmax(logits.clone(), 2);
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3]),
            &device,
        );
        let il = Tensor::<TestBackend, 1, Int>::from_data([5_i32, 5, 5], &device);
        let tl = Tensor::<TestBackend, 1, Int>::from_data([2_i32, 1, 3], &device);

        let loss = ctc.forward_with_reduction(log_probs, targets, il, tl, Reduction::Sum);
        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();

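        // Sanity check: the sum equals the three per-sample losses from
        // `test_ctc_loss_uniform_input_lengths` added together:
        // 3.5236570835 + 3.4953138828 + 4.2626771927 ≈ 11.2816481.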
        let expected_sum = 11.2816486359_f32; // Expected value from PyTorch
        assert_approx_equal(&loss_data, &[expected_sum], 1e-3);

        let grads = loss.backward();
        let logits_grad = logits.grad(&grads).unwrap();
        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();
        // Expected gradient from PyTorch
        let expected_grad = [
            -0.1679008007_f32,
            -0.4595540464,
            0.2795598209,
            0.3478950262,
            -0.3913056254,
            -0.0832268298,
            0.2535884976,
            0.2209439576,
            -0.0502742566,
            0.2766197622,
            0.2054125518,
            -0.4317580462,
            -0.0544800088,
            -0.3144550920,
            0.0847885981,
            0.2841464877,
            -0.1844545156,
            -0.2063435912,
            0.2222184092,
            0.1685796976,
            0.0278018005,
            0.2657383382,
            -0.0336986706,
            -0.2598414719,
            -0.0482986756,
            -0.0098767160,
            -0.1533526182,
            0.2115280181,
            -0.1380317956,
            -0.2198686600,
            0.2042596638,
            0.1536407918,
            0.0534787849,
            0.1819230020,
            -0.2805589139,
            0.0451571345,
            -0.0895631388,
            0.1996460557,
            -0.2741115987,
            0.1640286744,
            -0.2200077325,
            -0.1693530381,
            0.2101601064,
            0.1792006642,
            0.0398471877,
            -0.1131042913,
            -0.2363226712,
            0.3095797896,
            -0.2163617164,
            0.2740726173,
            -0.2124865055,
            0.1547756046,
            -0.4312027395,
            -0.0446923785,
            0.2330704331,
            0.2428246588,
            -0.0050083841,
            -0.6256869435,
            0.2689785957,
            0.3617166877,
        ];
        assert_approx_equal(&grad_data, &expected_grad, 1e-3);
    }

    #[test]
    fn test_ctc_loss_mean_reduction() {
        let device = FlexDevice;
        let ctc = CTCLossConfig::new().init();

        let logits = generate_logits(5, 3, 4, &device).require_grad();
        let log_probs = log_softmax(logits.clone(), 2);
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::new(vec![1_i32, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3]),
            &device,
        );
        let il = Tensor::<TestBackend, 1, Int>::from_data([5_i32, 5, 5], &device);
        let tl = Tensor::<TestBackend, 1, Int>::from_data([2_i32, 1, 3], &device);

        let loss = ctc.forward_with_reduction(log_probs, targets, il, tl, Reduction::Mean);
        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();

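        // Sanity check: per-sample losses divided by target lengths, then
        // averaged: (3.5236570835 / 2 + 3.4953138828 / 1 + 4.2626771927 / 3) / 3
        // ≈ 2.2260116.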
        let expected_mean = 2.2260115147_f32; // Expected value from PyTorch
        assert_approx_equal(&loss_data, &[expected_mean], 1e-3);

        let grads = loss.backward();
        let logits_grad = logits.grad(&grads).unwrap();
        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();
        // Expected gradient from PyTorch
        let expected_grad = [
            -0.0279834662_f32,
            -0.0765923411,
            0.0465933047,
            0.0579825081,
            -0.1304352134,
            -0.0277422778,
            0.0845294967,
            0.0736479908,
            -0.0055860290,
            0.0307355281,
            0.0228236169,
            -0.0479731150,
            -0.0090800021,
            -0.0524091832,
            0.0141314333,
            0.0473577492,
            -0.0614848398,
            -0.0687812045,
            0.0740728080,
            0.0561932363,
            0.0030890885,
            0.0295264814,
            -0.0037442972,
            -0.0288712755,
            -0.0080497796,
            -0.0016461194,
            -0.0255587716,
            0.0352546684,
            -0.0460105985,
            -0.0732895583,
            0.0680865571,
            0.0512135960,
            0.0059420872,
            0.0202136654,
            -0.0311732125,
            0.0050174589,
            -0.0149271907,
            0.0332743451,
            -0.0456852652,
            0.0273381118,
            -0.0733359158,
            -0.0564510152,
            0.0700533763,
            0.0597335547,
            0.0044274656,
            -0.0125671430,
            -0.0262580756,
            0.0343977548,
            -0.0360602848,
            0.0456787720,
            -0.0354144201,
            0.0257959347,
            -0.1437342465,
            -0.0148974592,
            0.0776901469,
            0.0809415579,
            -0.0005564869,
            -0.0695207715,
            0.0298865121,
            0.0401907414,
        ];
        assert_approx_equal(&grad_data, &expected_grad, 1e-3);
    }
}