//! Lp loss (`burn_nn/loss/lp_loss.rs`): element-wise `|error|^p` loss with
//! configurable exponent `p` and mean/sum/per-dimension reductions.
use super::Reduction;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Tensor, backend::Backend};
use burn_core as burn;
6
7/// Configuration for the [Lp Loss](LpLoss) module.
8///
9/// # Example
10///
11/// ```ignore
12/// use burn_nn::loss::{LpLossConfig, Reduction};
13///
14/// // Create L1 loss (MAE when using mean reduction)
15/// let l1_loss = LpLossConfig::l1();
16///
17/// // Create L2 loss (MSE when using mean reduction)
18/// let l2_loss = LpLossConfig::l2();
19///
20/// // Create custom Lp loss with p=3
21/// let l3_loss = LpLossConfig::new(3.0).init();
22/// ```
23#[derive(Config, Debug)]
24pub struct LpLossConfig {
25    /// The exponent `p` determining the type of error measurement.
26    ///
27    /// Common values:
28    /// - `p = 1.0`: L1 loss (MAE with mean reduction) - robust to outliers
29    /// - `p = 2.0`: L2 loss (MSE with mean reduction) - standard choice, differentiable everywhere
30    /// - `p > 2.0`: Increasingly sensitive to large errors (outliers)
31    /// - `0 < p < 1`: More robust to outliers than L1 (quasi-norm)
32    pub p: f64,
33}
34
35impl LpLossConfig {
36    /// Initializes a [Lp Loss](LpLoss) module.
37    ///
38    /// # Panics
39    ///
40    /// Panics if `p <= 0`.
41    pub fn init(&self) -> LpLoss {
42        self.assertions();
43        LpLoss { p: self.p }
44    }
45
46    /// Creates L1 loss (p=1).
47    ///
48    /// When used with `Reduction::Mean`, this computes Mean Absolute Error (MAE).
49    /// When used with `Reduction::Sum`, this computes Sum of Absolute Errors (SAE).
50    pub fn l1() -> LpLoss {
51        LpLoss { p: 1.0 }
52    }
53
54    /// Creates L2 loss (p=2).
55    ///
56    /// When used with `Reduction::Mean`, this computes Mean Squared Error (MSE).
57    /// When used with `Reduction::Sum`, this computes Sum of Squared Errors (SSE).
58    pub fn l2() -> LpLoss {
59        LpLoss { p: 2.0 }
60    }
61
62    fn assertions(&self) {
63        assert!(self.p > 0.0, "The order of the norm p must be positive.")
64    }
65}
66
67/// Computes the Lp Loss between predictions and targets.
68///
69/// This loss function computes the element-wise p-th power of absolute errors,
70/// then reduces them via mean or sum.
71///
72/// # Mathematical Definition
73///
74/// For predictions `ŷ` and targets `y`, the element-wise loss is:
75///
76/// ```text
77/// Lᵢ = |ŷᵢ - yᵢ|ᵖ
78/// ```
79///
80/// With mean reduction (default), the final loss is:
81///
82/// ```text
83/// L = (1/n) × Σᵢ |ŷᵢ - yᵢ|ᵖ
84/// ```
85///
86/// # Notes
87///
88/// - This implementation computes `|error|^p`, **not** the Lp norm `(Σ|error|^p)^(1/p)`.
89/// - The `p = 1` case uses an optimized `abs()` operation.
90/// - The `p = 2` case uses an optimized computation `error * error` instead of `powf`.
91///
92/// # Example
93///
94/// ```ignore
95/// use burn_nn::loss::{LpLossConfig, Reduction};
96/// use burn::tensor::Tensor;
97///
98/// // Create L2 loss
99/// let l2_loss = LpLossConfig::l2();
100///
101/// let predictions: Tensor<Backend, 2> = /* model output */;
102/// let targets: Tensor<Backend, 2> = /* ground truth */;
103///
104/// // Compute loss with mean reduction (MSE)
105/// let mse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Mean);
106///
107/// // Compute loss with sum reduction (SSE)
108/// let sse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Sum);
109///
110/// // Compute loss with no reduction
111/// let unreduced_l2_loss = l2_loss.forward_no_reduction(predictions, targets);
112/// ```
113#[derive(Module, Clone, Debug)]
114pub struct LpLoss {
115    /// The order of the norm (e.g., 1 for L1, 2 for L2).
116    /// Equivalently, the exponent `p` for computing `|error|^p`.
117    pub p: f64,
118}
119
120impl LpLoss {
121    /// Computes the element-wise loss `|error|^p` with reduction.
122    ///
123    /// # Arguments
124    ///
125    /// * `predictions` - The model's predicted values.
126    /// * `targets` - The ground truth target values.
127    /// * `reduction` - Specifies how to reduce the element-wise losses:
128    ///   - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses.
129    ///   - `Reduction::Sum`: Returns the sum of all element-wise losses.
130    ///
131    /// # Returns
132    ///
133    /// A scalar tensor containing the reduced loss value.
134    ///
135    /// # Shapes
136    ///
137    /// - predictions: `[...dims]` - Any shape
138    /// - targets: `[...dims]` - Must match predictions shape
139    /// - output: `[1]` - Scalar loss value
140    pub fn forward<const D: usize, B: Backend>(
141        &self,
142        predictions: Tensor<B, D>,
143        targets: Tensor<B, D>,
144        reduction: Reduction,
145    ) -> Tensor<B, 1> {
146        let unreduced_loss = self.forward_no_reduction(predictions, targets);
147
148        match reduction {
149            Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),
150            Reduction::Sum => unreduced_loss.sum(),
151            other => panic!("{other:?} reduction is not supported"),
152        }
153    }
154
155    /// Computes the element-wise loss `|error|^p` without reduction.
156    ///
157    /// # Arguments
158    ///
159    /// * `predictions` - The model's predicted values.
160    /// * `targets` - The ground truth target values.
161    ///
162    /// # Returns
163    ///
164    /// A tensor of the same shape as the inputs, containing `|prediction - target|^p`
165    /// for each element.
166    ///
167    /// # Shapes
168    ///
169    /// - predictions: `[...dims]` - Any shape
170    /// - targets: `[...dims]` - Must match predictions shape
171    /// - output: `[...dims]` - Same shape as inputs
172    pub fn forward_no_reduction<const D: usize, B: Backend>(
173        &self,
174        predictions: Tensor<B, D>,
175        targets: Tensor<B, D>,
176    ) -> Tensor<B, D> {
177        let error = predictions.sub(targets);
178
179        // Use simplified/optimized expressions for common cases (p = 1, p = 2)
180        if self.p == 1.0 {
181            // L1 loss
182            error.abs()
183        } else if self.p == 2.0 {
184            // L2 loss
185            error.clone().mul(error)
186        } else {
187            error.abs().powf_scalar(self.p)
188        }
189    }
190
191    /// Computes the element-wise loss `|error|^p` with reduction over specified dimensions.
192    ///
193    /// Calculates element-wise `|predictions - targets|^p`, then takes the mean
194    /// over the specified dimensions. Useful for per-sample or per-channel losses (e.g., when
195    /// working with images).
196    ///
197    /// Dimensions can be provided in any order. They are sorted internally and
198    /// reduced from highest to lowest to ensure indices remain valid.
199    ///
200    /// # Arguments
201    ///
202    /// * `predictions` - The model's predicted values.
203    /// * `targets` - The ground truth target values.
204    /// * `dims` - Dimensions to reduce over.
205    ///
206    /// # Returns
207    ///
208    /// A tensor with the specified dimensions reduced to size 1.
209    ///
210    /// # Example
211    ///
212    /// ```ignore
213    /// // Image tensor: [batch, C, H, W]
214    /// let l2_loss = LpLossConfig::l2();
215    ///
216    /// // Per-image MSE for PSNR: reduce over C, H, W → [batch, 1, 1, 1]
217    /// let mse_per_image = l2_loss.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
218    /// ```
219    pub fn forward_reduce_dims<const D: usize, B: Backend>(
220        &self,
221        predictions: Tensor<B, D>,
222        targets: Tensor<B, D>,
223        dims: &[usize],
224    ) -> Tensor<B, D> {
225        let error = self.forward_no_reduction(predictions, targets);
226
227        // Sort the dimensions to ascending order
228        let mut sorted_dims = dims.to_vec();
229        sorted_dims.sort();
230
231        // Reduce over specified dimensions
232        error.mean_dims(sorted_dims.as_slice())
233    }
234}
235
236#[cfg(test)]
237mod tests {
238    use super::*;
239    use crate::TestBackend;
240    use burn::tensor::TensorData;
241    use burn::tensor::{Tolerance, ops::FloatElem};
242    type FT = FloatElem<TestBackend>;
243
244    #[test]
245    fn test_lp_loss_l1_constructor() {
246        let loss_func_l1 = LpLossConfig::l1();
247        let loss_func_p1 = LpLossConfig::new(1.0).init();
248        assert_eq!(loss_func_l1.p, 1.0);
249        assert_eq!(loss_func_l1.p, loss_func_p1.p);
250    }
251
252    #[test]
253    fn test_lp_loss_l2_constructor() {
254        let loss_func_l2 = LpLossConfig::l2();
255        let loss_func_p2 = LpLossConfig::new(2.0).init();
256        assert_eq!(loss_func_l2.p, 2.0);
257        assert_eq!(loss_func_l2.p, loss_func_p2.p);
258    }
259
260    #[test]
261    fn test_lp_loss_l1() {
262        let device = Default::default();
263        let predictions = Tensor::<TestBackend, 2>::from_data(
264            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
265            &device,
266        );
267
268        let targets = Tensor::<TestBackend, 2>::from_data(
269            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
270            &device,
271        );
272
273        let loss_func = LpLossConfig::l1();
274        let loss_no_reduction =
275            loss_func.forward_no_reduction(predictions.clone(), targets.clone());
276        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
277        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
278
279        let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]);
280        loss_no_reduction.into_data().assert_eq(&expected, false);
281
282        let expected = TensorData::from([1.0]);
283        loss_auto.into_data().assert_eq(&expected, false);
284
285        let expected = TensorData::from([4.0]);
286        loss_sum.into_data().assert_eq(&expected, false);
287    }
288
289    #[test]
290    fn test_lp_loss_l2() {
291        let device = Default::default();
292        let predictions = Tensor::<TestBackend, 2>::from_data(
293            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
294            &device,
295        );
296
297        let targets = Tensor::<TestBackend, 2>::from_data(
298            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
299            &device,
300        );
301
302        let loss_func = LpLossConfig::l2();
303        let loss_no_reduction =
304            loss_func.forward_no_reduction(predictions.clone(), targets.clone());
305        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
306        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
307
308        let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]);
309        loss_no_reduction.into_data().assert_eq(&expected, false);
310
311        let expected = TensorData::from([1.5]);
312        loss_auto.into_data().assert_eq(&expected, false);
313
314        let expected = TensorData::from([6.0]);
315        loss_sum.into_data().assert_eq(&expected, false);
316    }
317
318    #[test]
319    fn test_lp_loss_p_half() {
320        // L0.5 quasi-norm: more robust to outliers than L1
321        let device = Default::default();
322        let predictions = Tensor::<TestBackend, 2>::from_data(
323            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
324            &device,
325        );
326
327        let targets = Tensor::<TestBackend, 2>::from_data(
328            TensorData::from([[2.0, 1.0], [3.0, 0.0]]),
329            &device,
330        );
331
332        let loss_func = LpLossConfig::new(0.5).init();
333        let loss_no_reduction =
334            loss_func.forward_no_reduction(predictions.clone(), targets.clone());
335        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
336        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
337
338        // |1-2|^0.5 = 1, |2-1|^0.5 = 1, |3-3|^0.5 = 0, |4-0|^0.5 = 2
339        let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]);
340        loss_no_reduction.into_data().assert_eq(&expected, false);
341
342        let expected = TensorData::from([1.0]);
343        loss_auto.into_data().assert_eq(&expected, false);
344
345        let expected = TensorData::from([4.0]);
346        loss_sum.into_data().assert_eq(&expected, false);
347    }
348
349    #[test]
350    fn test_lp_loss_p3() {
351        // L3 norm: more sensitive to outliers than L2
352        let device = Default::default();
353        let predictions = Tensor::<TestBackend, 2>::from_data(
354            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
355            &device,
356        );
357
358        let targets = Tensor::<TestBackend, 2>::from_data(
359            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
360            &device,
361        );
362
363        let loss_func = LpLossConfig::new(3.0).init();
364        let loss_no_reduction =
365            loss_func.forward_no_reduction(predictions.clone(), targets.clone());
366        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
367        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
368
369        // |1-2|^3 = 1, |2-1|^3 = 1, |3-3|^3 = 0, |4-2|^3 = 8
370        let expected = TensorData::from([[1.0, 1.0], [0.0, 8.0]]);
371        loss_no_reduction.into_data().assert_eq(&expected, false);
372
373        let expected = TensorData::from([2.5]);
374        loss_auto.into_data().assert_eq(&expected, false);
375
376        let expected = TensorData::from([10.0]);
377        loss_sum.into_data().assert_eq(&expected, false);
378    }
379
380    #[test]
381    fn test_lp_loss_zero_error() {
382        // Test when predictions exactly match targets
383        let device = Default::default();
384        let predictions = Tensor::<TestBackend, 2>::from_data(
385            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
386            &device,
387        );
388
389        let targets = predictions.clone();
390
391        let loss_func_l1 = LpLossConfig::l1();
392        let loss_func_l2 = LpLossConfig::l2();
393
394        let l1_loss = loss_func_l1.forward(predictions.clone(), targets.clone(), Reduction::Auto);
395        let l2_loss = loss_func_l2.forward(predictions, targets, Reduction::Auto);
396
397        let expected = TensorData::from([0.0]);
398        l1_loss.into_data().assert_eq(&expected, false);
399        l2_loss.into_data().assert_eq(&expected, false);
400    }
401
402    #[test]
403    fn test_lp_loss_negative_errors() {
404        // Test that negative errors are handled correctly (absolute value)
405        let device = Default::default();
406        let predictions =
407            Tensor::<TestBackend, 1>::from_data(TensorData::from([1.0, 2.0, 3.0]), &device);
408        let targets =
409            Tensor::<TestBackend, 1>::from_data(TensorData::from([3.0, 4.0, 5.0]), &device);
410        let loss_func_l1 = LpLossConfig::l1();
411        let loss_func_p1 = LpLossConfig::new(1.0).init();
412
413        let loss_no_reduction_l1 =
414            loss_func_l1.forward_no_reduction(predictions.clone(), targets.clone());
415        let loss_no_reduction_p1 = loss_func_p1.forward_no_reduction(predictions, targets);
416
417        // All errors are negative: 1-3=-2, 2-4=-2, 3-5=-2, but |error| = 2
418        let expected = TensorData::from([2.0, 2.0, 2.0]);
419        loss_no_reduction_l1.into_data().assert_eq(&expected, false);
420        loss_no_reduction_p1.into_data().assert_eq(&expected, false);
421    }
422
423    #[test]
424    fn test_lp_loss_3d_tensor() {
425        let device = Default::default();
426        let predictions = Tensor::<TestBackend, 3>::from_data(
427            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
428            &device,
429        );
430        let targets = Tensor::<TestBackend, 3>::from_data(
431            TensorData::from([[[0.0, 2.0], [3.0, 5.0]], [[4.0, 6.0], [7.0, 10.0]]]),
432            &device,
433        );
434        let loss_func_l2 = LpLossConfig::l2();
435        let loss_func_p2 = LpLossConfig::new(2.0).init();
436
437        let loss_l2 = loss_func_l2.forward(predictions.clone(), targets.clone(), Reduction::Auto);
438        let loss_p2 = loss_func_p2.forward(predictions, targets, Reduction::Auto);
439
440        // Errors: 1, 0, 0, -1, 1, 0, 0, -2
441        // Squared: 1, 0, 0, 1, 1, 0, 0, 4
442        // Mean: 7/8 = 0.875
443        let expected = TensorData::from([0.875]);
444        loss_l2.into_data().assert_eq(&expected, false);
445        loss_p2.into_data().assert_eq(&expected, false);
446    }
447
448    #[test]
449    #[should_panic(expected = "The order of the norm p must be positive.")]
450    fn test_lp_loss_negative_p_panics() {
451        let _ = LpLossConfig::new(-1.0).init();
452    }
453
454    #[test]
455    #[should_panic(expected = "The order of the norm p must be positive.")]
456    fn test_lp_loss_zero_p_panics() {
457        let _ = LpLossConfig::new(0.0).init();
458    }
459
460    #[test]
461    fn test_lp_loss_fractional_p() {
462        // Test p = 1.5
463        let device = Default::default();
464        let predictions =
465            Tensor::<TestBackend, 1>::from_data(TensorData::from([0.0, 4.0]), &device);
466
467        let targets = Tensor::<TestBackend, 1>::from_data(TensorData::from([1.0, 0.0]), &device);
468
469        let loss_func = LpLossConfig::new(1.5).init();
470        let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets);
471
472        // |0-1|^1.5 = 1, |4-0|^1.5 = 8
473        let expected = TensorData::from([1.0, 8.0]);
474        loss_no_reduction.into_data().assert_eq(&expected, false);
475    }
476
477    #[test]
478    fn test_forward_reduce_dims_single_dim() {
479        let device = Default::default();
480        // Shape: [2, 3]
481        let predictions = Tensor::<TestBackend, 2>::from_data(
482            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
483            &device,
484        );
485        let targets = Tensor::<TestBackend, 2>::from_data(
486            TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]),
487            &device,
488        );
489        let loss_func_l2 = LpLossConfig::l2();
490        let loss_func_p2 = LpLossConfig::new(2.0).init();
491
492        // Reduce over dim 1 -> should give [2, 1] shape
493        let loss_l2 = loss_func_l2.forward_reduce_dims(predictions.clone(), targets.clone(), &[1]);
494        let loss_p2 = loss_func_p2.forward_reduce_dims(predictions, targets, &[1]);
495
496        // Errors row 0: [1, 0, -3] -> squared: [1, 0, 9] -> mean: 10/3
497        // Errors row 1: [3, 0, 0] -> squared: [9, 0, 0] -> mean: 3
498        let expected = TensorData::from([[10.0 / 3.0], [3.0]]);
499        loss_l2
500            .into_data()
501            .assert_approx_eq::<FT>(&expected, Tolerance::default());
502        loss_p2
503            .into_data()
504            .assert_approx_eq::<FT>(&expected, Tolerance::default());
505    }
506
507    #[test]
508    fn test_forward_reduce_dims_first_dim() {
509        let device = Default::default();
510        // Shape: [2, 3]
511        let predictions = Tensor::<TestBackend, 2>::from_data(
512            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
513            &device,
514        );
515        let targets = Tensor::<TestBackend, 2>::from_data(
516            TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]),
517            &device,
518        );
519        let loss_func = LpLossConfig::l2();
520
521        // Reduce over dim 0 -> should give [1, 3] shape
522        let loss = loss_func.forward_reduce_dims(predictions, targets, &[0]);
523
524        // Squared errors: [[1, 0, 9], [9, 0, 0]]
525        // Mean over dim 0: [5, 0, 4.5]
526        let expected = TensorData::from([[5.0, 0.0, 4.5]]);
527        loss.into_data()
528            .assert_approx_eq::<FT>(&expected, Tolerance::default());
529    }
530
531    #[test]
532    fn test_forward_reduce_dims_multiple_dims() {
533        let device = Default::default();
534        // Shape: [2, 2, 2]
535        let predictions = Tensor::<TestBackend, 3>::from_data(
536            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
537            &device,
538        );
539        let targets = Tensor::<TestBackend, 3>::from_data(
540            TensorData::from([[[0.0, 2.0], [3.0, 6.0]], [[4.0, 6.0], [7.0, 10.0]]]),
541            &device,
542        );
543        let loss_func = LpLossConfig::l2();
544
545        // Reduce over dims 1 and 2 -> should give [2, 1, 1] shape
546        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]);
547
548        // Batch 0 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25
549        // Batch 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25
550        let expected = TensorData::from([[[1.25]], [[1.25]]]);
551        loss.into_data()
552            .assert_approx_eq::<FT>(&expected, Tolerance::default());
553    }
554
555    #[test]
556    fn test_forward_reduce_dims_all_dims() {
557        let device = Default::default();
558        // Shape: [2, 2]
559        let predictions = Tensor::<TestBackend, 2>::from_data(
560            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
561            &device,
562        );
563        let targets = Tensor::<TestBackend, 2>::from_data(
564            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
565            &device,
566        );
567        let loss_func = LpLossConfig::l2();
568
569        // Reduce over all dims -> should give [1, 1] shape
570        let loss = loss_func.forward_reduce_dims(predictions, targets, &[0, 1]);
571
572        // Errors: [[-1, 1], [0, 2]] -> squared: [[1, 1], [0, 4]] -> mean: 1.5
573        let expected = TensorData::from([[1.5]]);
574        loss.into_data()
575            .assert_approx_eq::<FT>(&expected, Tolerance::default());
576    }
577
578    #[test]
579    fn test_forward_reduce_dims_image_batch() {
580        // Simulate per-image loss for [batch, C, H, W] tensor (common use case for PSNR)
581        let device = Default::default();
582        // Shape: [2, 1, 2, 2] (batch=2, C=1, H=2, W=2)
583        let predictions = Tensor::<TestBackend, 4>::from_data(
584            TensorData::from([
585                [[[1.0, 2.0], [3.0, 4.0]]], // Image 1
586                [[[5.0, 6.0], [7.0, 8.0]]], // Image 2
587            ]),
588            &device,
589        );
590        let targets = Tensor::<TestBackend, 4>::from_data(
591            TensorData::from([
592                [[[0.0, 2.0], [3.0, 6.0]]], // Target 1
593                [[[5.0, 5.0], [7.0, 7.0]]], // Target 2
594            ]),
595            &device,
596        );
597        let loss_func = LpLossConfig::l2();
598
599        // Reduce over C, H, W (dims 1, 2, 3) to get per-image MSE
600        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
601
602        // Image 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 1.25
603        // Image 2 errors: [[0, 1], [0, 1]] -> squared: [[0, 1], [0, 1]] -> mean: 0.5
604        let expected = TensorData::from([[[[1.25]]], [[[0.5]]]]);
605        loss.into_data()
606            .assert_approx_eq::<FT>(&expected, Tolerance::default());
607    }
608
609    #[test]
610    fn test_forward_reduce_dims_with_p1() {
611        let device = Default::default();
612        // Shape: [2, 3]
613        let predictions = Tensor::<TestBackend, 2>::from_data(
614            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
615            &device,
616        );
617        let targets = Tensor::<TestBackend, 2>::from_data(
618            TensorData::from([[0.0, 5.0, 3.0], [1.0, 5.0, 9.0]]),
619            &device,
620        );
621        let loss_func = LpLossConfig::l1();
622
623        // Reduce over dim 1 -> should give [2, 1] shape
624        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1]);
625
626        // Abs errors row 0: [1, 3, 0] -> mean: 4/3
627        // Abs errors row 1: [3, 0, 3] -> mean: 2
628        let expected = TensorData::from([[4.0 / 3.0], [2.0]]);
629        loss.into_data()
630            .assert_approx_eq::<FT>(&expected, Tolerance::default());
631    }
632
633    #[test]
634    fn test_forward_reduce_dims_empty_dims() {
635        // Reducing over no dimensions should return the unreduced loss
636        let device = Default::default();
637        let predictions = Tensor::<TestBackend, 2>::from_data(
638            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
639            &device,
640        );
641        let targets = Tensor::<TestBackend, 2>::from_data(
642            TensorData::from([[0.0, 2.0], [3.0, 6.0]]),
643            &device,
644        );
645        let loss_func = LpLossConfig::l2();
646        let loss_reduce_dims =
647            loss_func.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);
648        let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets);
649
650        // Should be equivalent
651        loss_reduce_dims
652            .into_data()
653            .assert_eq(&loss_no_reduction.into_data(), true);
654    }
655
656    #[test]
657    fn test_forward_reduce_dims_zero_error() {
658        let device = Default::default();
659        // Shape: [2, 2, 2]
660        let predictions = Tensor::<TestBackend, 3>::from_data(
661            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
662            &device,
663        );
664        let targets = predictions.clone();
665        let loss_func = LpLossConfig::l2();
666        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]);
667
668        // All zeros, reduced to shape: [2, 1, 1]
669        let expected = TensorData::from([[[0.0]], [[0.0]]]);
670        loss.into_data().assert_eq(&expected, false);
671    }
672}