use super::Reduction;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Tensor, backend::Backend};
use burn_core as burn;
/// Configuration for creating an [LpLoss] instance.
#[derive(Config, Debug)]
pub struct LpLossConfig {
    /// The order `p` of the norm. Must be strictly positive (validated in `init`).
    pub p: f64,
}
impl LpLossConfig {
    /// Convenience constructor for the L1 loss (p = 1), i.e. mean absolute error.
    pub fn l1() -> LpLoss {
        LpLoss { p: 1.0 }
    }

    /// Convenience constructor for the L2 loss (p = 2), i.e. mean squared error.
    pub fn l2() -> LpLoss {
        LpLoss { p: 2.0 }
    }

    /// Initialize an [LpLoss] module from this configuration.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not strictly positive.
    pub fn init(&self) -> LpLoss {
        self.assertions();
        let Self { p } = *self;
        LpLoss { p }
    }

    /// Validate the configuration invariants, panicking on violation.
    fn assertions(&self) {
        let p_is_valid = self.p > 0.0;
        assert!(p_is_valid, "The order of the norm p must be positive.")
    }
}
/// Lp-norm loss module: computes `|predictions - targets|^p` element-wise,
/// optionally reduced by mean or sum.
///
/// Generalizes the L1 loss (p = 1) and the L2 loss (p = 2) to any positive order p.
#[derive(Module, Clone, Debug)]
pub struct LpLoss {
    /// The order of the norm; expected to be strictly positive
    /// (enforced by [LpLossConfig::init], not re-checked here).
    pub p: f64,
}
impl LpLoss {
    /// Compute the Lp loss reduced to a single scalar.
    ///
    /// # Shapes
    ///
    /// - predictions: `[...dims]`
    /// - targets: `[...dims]` (same shape as predictions)
    /// - output: `[1]`
    ///
    /// # Panics
    ///
    /// Panics if `reduction` is a variant other than `Mean`, `Auto`, or `Sum`.
    pub fn forward<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let unreduced_loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            // `Auto` defaults to the mean, matching the other loss modules.
            Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),
            Reduction::Sum => unreduced_loss.sum(),
            other => panic!("{other:?} reduction is not supported"),
        }
    }

    /// Compute the element-wise Lp loss `|predictions - targets|^p` with no reduction.
    ///
    /// # Shapes
    ///
    /// - predictions: `[...dims]`
    /// - targets: `[...dims]`
    /// - output: `[...dims]` (same shape as the inputs)
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let error = predictions.sub(targets);
        // Fast paths: |e|^1 == abs(e) and |e|^2 == e*e, avoiding the generic
        // (and more expensive) powf for the two most common orders.
        if self.p == 1.0 {
            error.abs()
        } else if self.p == 2.0 {
            error.clone().mul(error)
        } else {
            error.abs().powf_scalar(self.p)
        }
    }

    /// Compute the Lp loss mean-reduced over the given dimensions only.
    ///
    /// The reduced dimensions are kept with size 1 (as produced by `mean_dims`),
    /// so the output rank matches the input rank. Passing an empty `dims` slice
    /// returns the unreduced loss unchanged.
    pub fn forward_reduce_dims<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        dims: &[usize],
    ) -> Tensor<B, D> {
        let error = self.forward_no_reduction(predictions, targets);
        let mut sorted_dims = dims.to_vec();
        // `sort_unstable` over `sort`: equivalent result for primitive keys,
        // faster and allocation-free (clippy: stable_sort_primitive).
        sorted_dims.sort_unstable();
        error.mean_dims(sorted_dims.as_slice())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering the constructors, the three forward variants,
    //! special orders (p = 1, 2, fractional, > 2), and degenerate inputs.
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    // Float element type of the test backend, for approximate comparisons.
    type FT = FloatElem<TestBackend>;

    // `l1()` must be equivalent to configuring p = 1 explicitly.
    #[test]
    fn test_lp_loss_l1_constructor() {
        let loss_func_l1 = LpLossConfig::l1();
        let loss_func_p1 = LpLossConfig::new(1.0).init();
        assert_eq!(loss_func_l1.p, 1.0);
        assert_eq!(loss_func_l1.p, loss_func_p1.p);
    }

    // `l2()` must be equivalent to configuring p = 2 explicitly.
    #[test]
    fn test_lp_loss_l2_constructor() {
        let loss_func_l2 = LpLossConfig::l2();
        let loss_func_p2 = LpLossConfig::new(2.0).init();
        assert_eq!(loss_func_l2.p, 2.0);
        assert_eq!(loss_func_l2.p, loss_func_p2.p);
    }

    // p = 1: errors are [[-1, 1], [0, 2]], so the unreduced loss is |e|,
    // the mean is 4/4 = 1 and the sum is 4.
    #[test]
    fn test_lp_loss_l1() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );
        let loss_func = LpLossConfig::l1();
        let loss_no_reduction =
            loss_func.forward_no_reduction(predictions.clone(), targets.clone());
        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
        let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]);
        loss_no_reduction.into_data().assert_eq(&expected, false);
        let expected = TensorData::from([1.0]);
        loss_auto.into_data().assert_eq(&expected, false);
        let expected = TensorData::from([4.0]);
        loss_sum.into_data().assert_eq(&expected, false);
    }

    // p = 2: same errors as above, so the unreduced loss is e^2,
    // the mean is 6/4 = 1.5 and the sum is 6.
    #[test]
    fn test_lp_loss_l2() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );
        let loss_func = LpLossConfig::l2();
        let loss_no_reduction =
            loss_func.forward_no_reduction(predictions.clone(), targets.clone());
        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
        let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]);
        loss_no_reduction.into_data().assert_eq(&expected, false);
        let expected = TensorData::from([1.5]);
        loss_auto.into_data().assert_eq(&expected, false);
        let expected = TensorData::from([6.0]);
        loss_sum.into_data().assert_eq(&expected, false);
    }

    // Fractional order 0 < p < 1: errors are [[-1, 1], [0, 4]],
    // so |e|^0.5 = [[1, 1], [0, 2]] (exercises the generic powf path).
    #[test]
    fn test_lp_loss_p_half() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 0.0]]),
            &device,
        );
        let loss_func = LpLossConfig::new(0.5).init();
        let loss_no_reduction =
            loss_func.forward_no_reduction(predictions.clone(), targets.clone());
        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
        let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]);
        loss_no_reduction.into_data().assert_eq(&expected, false);
        let expected = TensorData::from([1.0]);
        loss_auto.into_data().assert_eq(&expected, false);
        let expected = TensorData::from([4.0]);
        loss_sum.into_data().assert_eq(&expected, false);
    }

    // p = 3: errors are [[-1, 1], [0, 2]], so |e|^3 = [[1, 1], [0, 8]],
    // the mean is 10/4 = 2.5 and the sum is 10.
    #[test]
    fn test_lp_loss_p3() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );
        let loss_func = LpLossConfig::new(3.0).init();
        let loss_no_reduction =
            loss_func.forward_no_reduction(predictions.clone(), targets.clone());
        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);
        let expected = TensorData::from([[1.0, 1.0], [0.0, 8.0]]);
        loss_no_reduction.into_data().assert_eq(&expected, false);
        let expected = TensorData::from([2.5]);
        loss_auto.into_data().assert_eq(&expected, false);
        let expected = TensorData::from([10.0]);
        loss_sum.into_data().assert_eq(&expected, false);
    }

    // Identical predictions and targets must yield exactly zero for any p.
    #[test]
    fn test_lp_loss_zero_error() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let targets = predictions.clone();
        let loss_func_l1 = LpLossConfig::l1();
        let loss_func_l2 = LpLossConfig::l2();
        let l1_loss = loss_func_l1.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let l2_loss = loss_func_l2.forward(predictions, targets, Reduction::Auto);
        let expected = TensorData::from([0.0]);
        l1_loss.into_data().assert_eq(&expected, false);
        l2_loss.into_data().assert_eq(&expected, false);
    }

    // All-negative errors: abs() must make the loss positive, and the `l1()`
    // shortcut must agree with an explicitly configured p = 1.
    #[test]
    fn test_lp_loss_negative_errors() {
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 1>::from_data(TensorData::from([1.0, 2.0, 3.0]), &device);
        let targets =
            Tensor::<TestBackend, 1>::from_data(TensorData::from([3.0, 4.0, 5.0]), &device);
        let loss_func_l1 = LpLossConfig::l1();
        let loss_func_p1 = LpLossConfig::new(1.0).init();
        let loss_no_reduction_l1 =
            loss_func_l1.forward_no_reduction(predictions.clone(), targets.clone());
        let loss_no_reduction_p1 = loss_func_p1.forward_no_reduction(predictions, targets);
        let expected = TensorData::from([2.0, 2.0, 2.0]);
        loss_no_reduction_l1.into_data().assert_eq(&expected, false);
        loss_no_reduction_p1.into_data().assert_eq(&expected, false);
    }

    // Rank-3 input: squared errors sum to 7 over 8 elements, mean = 0.875.
    #[test]
    fn test_lp_loss_3d_tensor() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.0, 2.0], [3.0, 5.0]], [[4.0, 6.0], [7.0, 10.0]]]),
            &device,
        );
        let loss_func_l2 = LpLossConfig::l2();
        let loss_func_p2 = LpLossConfig::new(2.0).init();
        let loss_l2 = loss_func_l2.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_p2 = loss_func_p2.forward(predictions, targets, Reduction::Auto);
        let expected = TensorData::from([0.875]);
        loss_l2.into_data().assert_eq(&expected, false);
        loss_p2.into_data().assert_eq(&expected, false);
    }

    // Configuration validation: p < 0 must be rejected at init time.
    #[test]
    #[should_panic(expected = "The order of the norm p must be positive.")]
    fn test_lp_loss_negative_p_panics() {
        let _ = LpLossConfig::new(-1.0).init();
    }

    // Configuration validation: p = 0 must be rejected at init time.
    #[test]
    #[should_panic(expected = "The order of the norm p must be positive.")]
    fn test_lp_loss_zero_p_panics() {
        let _ = LpLossConfig::new(0.0).init();
    }

    // Fractional p > 1: errors are [-1, 4], so |e|^1.5 = [1, 8].
    #[test]
    fn test_lp_loss_fractional_p() {
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 1>::from_data(TensorData::from([0.0, 4.0]), &device);
        let targets = Tensor::<TestBackend, 1>::from_data(TensorData::from([1.0, 0.0]), &device);
        let loss_func = LpLossConfig::new(1.5).init();
        let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets);
        let expected = TensorData::from([1.0, 8.0]);
        loss_no_reduction.into_data().assert_eq(&expected, false);
    }

    // Reduce over the last dim only: squared errors [[1, 0, 9], [9, 0, 0]]
    // mean along dim 1 to [[10/3], [3]]; the reduced dim is kept with size 1.
    #[test]
    fn test_forward_reduce_dims_single_dim() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]),
            &device,
        );
        let loss_func_l2 = LpLossConfig::l2();
        let loss_func_p2 = LpLossConfig::new(2.0).init();
        let loss_l2 = loss_func_l2.forward_reduce_dims(predictions.clone(), targets.clone(), &[1]);
        let loss_p2 = loss_func_p2.forward_reduce_dims(predictions, targets, &[1]);
        let expected = TensorData::from([[10.0 / 3.0], [3.0]]);
        loss_l2
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
        loss_p2
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Reduce over the first dim: columns of squared errors average to [5, 0, 4.5].
    #[test]
    fn test_forward_reduce_dims_first_dim() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]),
            &device,
        );
        let loss_func = LpLossConfig::l2();
        let loss = loss_func.forward_reduce_dims(predictions, targets, &[0]);
        let expected = TensorData::from([[5.0, 0.0, 4.5]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Reduce over two dims at once: each outer slice averages to 5/4 = 1.25.
    #[test]
    fn test_forward_reduce_dims_multiple_dims() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.0, 2.0], [3.0, 6.0]], [[4.0, 6.0], [7.0, 10.0]]]),
            &device,
        );
        let loss_func = LpLossConfig::l2();
        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]);
        let expected = TensorData::from([[[1.25]], [[1.25]]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Reducing over every dim matches the global mean (1.5), but keeps rank 2.
    #[test]
    fn test_forward_reduce_dims_all_dims() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );
        let loss_func = LpLossConfig::l2();
        let loss = loss_func.forward_reduce_dims(predictions, targets, &[0, 1]);
        let expected = TensorData::from([[1.5]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Typical image-batch use case (NCHW): reduce everything except the batch
    // dim to get one per-sample loss ([1.25, 0.5] here).
    #[test]
    fn test_forward_reduce_dims_image_batch() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([
                [[[1.0, 2.0], [3.0, 4.0]]],
                [[[5.0, 6.0], [7.0, 8.0]]],
            ]),
            &device,
        );
        let targets = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([
                [[[0.0, 2.0], [3.0, 6.0]]],
                [[[5.0, 5.0], [7.0, 7.0]]],
            ]),
            &device,
        );
        let loss_func = LpLossConfig::l2();
        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2, 3]);
        let expected = TensorData::from([[[[1.25]]], [[[0.5]]]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // forward_reduce_dims with p = 1: row-wise mean of |e| is [4/3, 2].
    #[test]
    fn test_forward_reduce_dims_with_p1() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.0, 5.0, 3.0], [1.0, 5.0, 9.0]]),
            &device,
        );
        let loss_func = LpLossConfig::l1();
        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1]);
        let expected = TensorData::from([[4.0 / 3.0], [2.0]]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // An empty dims slice must be a no-op, identical to forward_no_reduction.
    #[test]
    fn test_forward_reduce_dims_empty_dims() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.0, 2.0], [3.0, 6.0]]),
            &device,
        );
        let loss_func = LpLossConfig::l2();
        let loss_reduce_dims =
            loss_func.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);
        let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets);
        loss_reduce_dims
            .into_data()
            .assert_eq(&loss_no_reduction.into_data(), true);
    }

    // Zero error with dim reduction: every reduced cell must be exactly zero.
    #[test]
    fn test_forward_reduce_dims_zero_error() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
            &device,
        );
        let targets = predictions.clone();
        let loss_func = LpLossConfig::l2();
        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]);
        let expected = TensorData::from([[[0.0]], [[0.0]]]);
        loss.into_data().assert_eq(&expected, false);
    }
}