1use super::Reduction;
2use burn::config::Config;
3use burn::module::Module;
4use burn::tensor::{Tensor, backend::Backend};
5use burn_core as burn;
6
7#[derive(Config, Debug)]
24pub struct LpLossConfig {
25 pub p: f64,
33}
34
35impl LpLossConfig {
36 pub fn init(&self) -> LpLoss {
42 self.assertions();
43 LpLoss { p: self.p }
44 }
45
46 pub fn l1() -> LpLoss {
51 LpLoss { p: 1.0 }
52 }
53
54 pub fn l2() -> LpLoss {
59 LpLoss { p: 2.0 }
60 }
61
62 fn assertions(&self) {
63 assert!(self.p > 0.0, "The order of the norm p must be positive.")
64 }
65}
66
67#[derive(Module, Clone, Debug)]
114pub struct LpLoss {
115 pub p: f64,
118}
119
120impl LpLoss {
121 pub fn forward<const D: usize, B: Backend>(
141 &self,
142 predictions: Tensor<B, D>,
143 targets: Tensor<B, D>,
144 reduction: Reduction,
145 ) -> Tensor<B, 1> {
146 let unreduced_loss = self.forward_no_reduction(predictions, targets);
147
148 match reduction {
149 Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),
150 Reduction::Sum => unreduced_loss.sum(),
151 other => panic!("{other:?} reduction is not supported"),
152 }
153 }
154
155 pub fn forward_no_reduction<const D: usize, B: Backend>(
173 &self,
174 predictions: Tensor<B, D>,
175 targets: Tensor<B, D>,
176 ) -> Tensor<B, D> {
177 let error = predictions.sub(targets);
178
179 if self.p == 1.0 {
181 error.abs()
183 } else if self.p == 2.0 {
184 error.clone().mul(error)
186 } else {
187 error.abs().powf_scalar(self.p)
188 }
189 }
190
191 pub fn forward_reduce_dims<const D: usize, B: Backend>(
220 &self,
221 predictions: Tensor<B, D>,
222 targets: Tensor<B, D>,
223 dims: &[usize],
224 ) -> Tensor<B, D> {
225 let error = self.forward_no_reduction(predictions, targets);
226
227 let mut sorted_dims = dims.to_vec();
229 sorted_dims.sort();
230
231 error.mean_dims(sorted_dims.as_slice())
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Builds a rank-`D` tensor from literal data on the default test device.
    fn tensor<const D: usize>(data: impl Into<TensorData>) -> Tensor<TestBackend, D> {
        Tensor::from_data(data, &Default::default())
    }

    /// Asserts value equality without requiring matching dtypes.
    ///
    /// Only used where every intermediate value is exactly representable (subtraction,
    /// absolute value, squaring and averaging of small integers).
    fn assert_exact<const D: usize>(
        actual: Tensor<TestBackend, D>,
        expected: impl Into<TensorData>,
    ) {
        actual.into_data().assert_eq(&expected.into(), false);
    }

    /// Asserts equality up to the default floating point tolerance.
    ///
    /// Used wherever the result may carry rounding error, e.g. after `powf_scalar` or a
    /// division that is not exactly representable.
    fn assert_close<const D: usize>(
        actual: Tensor<TestBackend, D>,
        expected: impl Into<TensorData>,
    ) {
        actual
            .into_data()
            .assert_approx_eq::<FT>(&expected.into(), Tolerance::default());
    }

    #[test]
    fn test_lp_loss_l1_constructor() {
        let shortcut = LpLossConfig::l1();
        let explicit = LpLossConfig::new(1.0).init();

        assert_eq!(shortcut.p, 1.0);
        assert_eq!(shortcut.p, explicit.p);
    }

    #[test]
    fn test_lp_loss_l2_constructor() {
        let shortcut = LpLossConfig::l2();
        let explicit = LpLossConfig::new(2.0).init();

        assert_eq!(shortcut.p, 2.0);
        assert_eq!(shortcut.p, explicit.p);
    }

    #[test]
    fn test_lp_loss_l1() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 2.0]]);
        let loss_func = LpLossConfig::l1();

        // |[[-1, 1], [0, 2]]| = [[1, 1], [0, 2]]
        assert_exact(
            loss_func.forward_no_reduction(predictions.clone(), targets.clone()),
            [[1.0, 1.0], [0.0, 2.0]],
        );
        // mean = (1 + 1 + 0 + 2) / 4
        assert_exact(
            loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto),
            [1.0],
        );
        // sum = 1 + 1 + 0 + 2
        assert_exact(
            loss_func.forward(predictions, targets, Reduction::Sum),
            [4.0],
        );
    }

    #[test]
    fn test_lp_loss_l2() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 2.0]]);
        let loss_func = LpLossConfig::l2();

        // [[-1, 1], [0, 2]]^2 = [[1, 1], [0, 4]]
        assert_exact(
            loss_func.forward_no_reduction(predictions.clone(), targets.clone()),
            [[1.0, 1.0], [0.0, 4.0]],
        );
        // mean = (1 + 1 + 0 + 4) / 4
        assert_exact(
            loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto),
            [1.5],
        );
        // sum = 1 + 1 + 0 + 4
        assert_exact(
            loss_func.forward(predictions, targets, Reduction::Sum),
            [6.0],
        );
    }

    #[test]
    fn test_lp_loss_p_half() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 0.0]]);
        let loss_func = LpLossConfig::new(0.5).init();

        // The general branch goes through `powf_scalar`, whose result is not guaranteed to be
        // bit-exact on every backend, so compare with a tolerance.
        // |[[-1, 1], [0, 4]]|^0.5 = [[1, 1], [0, 2]]
        assert_close(
            loss_func.forward_no_reduction(predictions.clone(), targets.clone()),
            [[1.0, 1.0], [0.0, 2.0]],
        );
        // mean = (1 + 1 + 0 + 2) / 4
        assert_close(
            loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto),
            [1.0],
        );
        // sum = 1 + 1 + 0 + 2
        assert_close(
            loss_func.forward(predictions, targets, Reduction::Sum),
            [4.0],
        );
    }

    #[test]
    fn test_lp_loss_p3() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 2.0]]);
        let loss_func = LpLossConfig::new(3.0).init();

        // `powf_scalar` result: compare with a tolerance rather than bit-exactly.
        // |[[-1, 1], [0, 2]]|^3 = [[1, 1], [0, 8]]
        assert_close(
            loss_func.forward_no_reduction(predictions.clone(), targets.clone()),
            [[1.0, 1.0], [0.0, 8.0]],
        );
        // mean = (1 + 1 + 0 + 8) / 4
        assert_close(
            loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto),
            [2.5],
        );
        // sum = 1 + 1 + 0 + 8
        assert_close(
            loss_func.forward(predictions, targets, Reduction::Sum),
            [10.0],
        );
    }

    #[test]
    fn test_lp_loss_zero_error() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = predictions.clone();

        let l1_loss =
            LpLossConfig::l1().forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let l2_loss = LpLossConfig::l2().forward(predictions, targets, Reduction::Auto);

        // Identical inputs must yield a loss of exactly zero.
        assert_exact(l1_loss, [0.0]);
        assert_exact(l2_loss, [0.0]);
    }

    #[test]
    fn test_lp_loss_negative_errors() {
        let predictions = tensor::<1>([1.0, 2.0, 3.0]);
        let targets = tensor::<1>([3.0, 4.0, 5.0]);

        let from_shortcut =
            LpLossConfig::l1().forward_no_reduction(predictions.clone(), targets.clone());
        let from_config = LpLossConfig::new(1.0)
            .init()
            .forward_no_reduction(predictions, targets);

        // Every error is -2; the loss must report its magnitude.
        assert_exact(from_shortcut, [2.0, 2.0, 2.0]);
        assert_exact(from_config, [2.0, 2.0, 2.0]);
    }

    #[test]
    fn test_lp_loss_3d_tensor() {
        let predictions =
            tensor::<3>([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
        let targets = tensor::<3>([[[0.0, 2.0], [3.0, 5.0]], [[4.0, 6.0], [7.0, 10.0]]]);

        let from_shortcut =
            LpLossConfig::l2().forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let from_config = LpLossConfig::new(2.0)
            .init()
            .forward(predictions, targets, Reduction::Auto);

        // Squared errors: [1, 0, 0, 1, 1, 0, 0, 4] -> mean = 7 / 8
        assert_exact(from_shortcut, [0.875]);
        assert_exact(from_config, [0.875]);
    }

    #[test]
    #[should_panic(expected = "The order of the norm p must be positive.")]
    fn test_lp_loss_negative_p_panics() {
        let _ = LpLossConfig::new(-1.0).init();
    }

    #[test]
    #[should_panic(expected = "The order of the norm p must be positive.")]
    fn test_lp_loss_zero_p_panics() {
        let _ = LpLossConfig::new(0.0).init();
    }

    #[test]
    fn test_lp_loss_fractional_p() {
        let predictions = tensor::<1>([0.0, 4.0]);
        let targets = tensor::<1>([1.0, 0.0]);
        let loss_func = LpLossConfig::new(1.5).init();

        // `powf_scalar` result: compare with a tolerance rather than bit-exactly.
        // |[-1, 4]|^1.5 = [1, 8]
        assert_close(
            loss_func.forward_no_reduction(predictions, targets),
            [1.0, 8.0],
        );
    }

    #[test]
    fn test_forward_reduce_dims_single_dim() {
        let predictions = tensor::<2>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let targets = tensor::<2>([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]);

        let from_shortcut =
            LpLossConfig::l2().forward_reduce_dims(predictions.clone(), targets.clone(), &[1]);
        let from_config = LpLossConfig::new(2.0)
            .init()
            .forward_reduce_dims(predictions, targets, &[1]);

        // Squared errors: [[1, 0, 9], [9, 0, 0]] -> row means [10/3, 3]
        assert_close(from_shortcut, [[10.0 / 3.0], [3.0]]);
        assert_close(from_config, [[10.0 / 3.0], [3.0]]);
    }

    #[test]
    fn test_forward_reduce_dims_first_dim() {
        let predictions = tensor::<2>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let targets = tensor::<2>([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]);

        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[0]);

        // Squared errors: [[1, 0, 9], [9, 0, 0]] -> column means [5, 0, 4.5]
        assert_close(loss, [[5.0, 0.0, 4.5]]);
    }

    #[test]
    fn test_forward_reduce_dims_multiple_dims() {
        let predictions =
            tensor::<3>([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
        let targets = tensor::<3>([[[0.0, 2.0], [3.0, 6.0]], [[4.0, 6.0], [7.0, 10.0]]]);

        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[1, 2]);

        // Squared errors per batch entry: [1, 0, 0, 4] and [1, 0, 0, 4] -> means 1.25 each
        assert_close(loss, [[[1.25]], [[1.25]]]);
    }

    #[test]
    fn test_forward_reduce_dims_all_dims() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[2.0, 1.0], [3.0, 2.0]]);

        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[0, 1]);

        // Squared errors: [[1, 1], [0, 4]] -> overall mean 1.5, rank is preserved
        assert_close(loss, [[1.5]]);
    }

    #[test]
    fn test_forward_reduce_dims_image_batch() {
        // Shape [2, 1, 2, 2]: two single-channel 2x2 images.
        let predictions = tensor::<4>([
            [[[1.0, 2.0], [3.0, 4.0]]],
            [[[5.0, 6.0], [7.0, 8.0]]],
        ]);
        let targets = tensor::<4>([
            [[[0.0, 2.0], [3.0, 6.0]]],
            [[[5.0, 5.0], [7.0, 7.0]]],
        ]);

        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[1, 2, 3]);

        // Squared errors: [1, 0, 0, 4] -> 1.25 and [0, 1, 0, 1] -> 0.5
        assert_close(loss, [[[[1.25]]], [[[0.5]]]]);
    }

    #[test]
    fn test_forward_reduce_dims_with_p1() {
        let predictions = tensor::<2>([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let targets = tensor::<2>([[0.0, 5.0, 3.0], [1.0, 5.0, 9.0]]);

        let loss = LpLossConfig::l1().forward_reduce_dims(predictions, targets, &[1]);

        // Absolute errors: [[1, 3, 0], [3, 0, 3]] -> row means [4/3, 2]
        assert_close(loss, [[4.0 / 3.0], [2.0]]);
    }

    #[test]
    fn test_forward_reduce_dims_empty_dims() {
        let predictions = tensor::<2>([[1.0, 2.0], [3.0, 4.0]]);
        let targets = tensor::<2>([[0.0, 2.0], [3.0, 6.0]]);
        let loss_func = LpLossConfig::l2();

        let reduced = loss_func.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);
        let unreduced = loss_func.forward_no_reduction(predictions, targets);

        // With no dimensions to reduce both paths must agree, including the dtype.
        reduced
            .into_data()
            .assert_eq(&unreduced.into_data(), true);
    }

    #[test]
    fn test_forward_reduce_dims_zero_error() {
        let predictions =
            tensor::<3>([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
        let targets = predictions.clone();

        let loss = LpLossConfig::l2().forward_reduce_dims(predictions, targets, &[1, 2]);

        // Identical inputs must yield exactly zero for every batch entry.
        assert_exact(loss, [[[0.0]], [[0.0]]]);
    }
}