use super::Reduction;
use burn::config::Config;
use burn::module::Module;
use burn::tensor::{Tensor, backend::Backend};
use burn_core as burn;
/// Configuration for creating a [smooth L1 loss](SmoothL1Loss) module.
#[derive(Config, Debug)]
pub struct SmoothL1LossConfig {
    /// Threshold at which the loss switches from the quadratic region
    /// (`|error| < beta`) to the linear region; must be strictly positive.
    #[config(default = 1.0)]
    pub beta: f32,
}
impl SmoothL1LossConfig {
    /// Initialize a [`SmoothL1Loss`] module from this configuration.
    ///
    /// # Panics
    ///
    /// Panics if `beta` is not strictly positive.
    pub fn init(&self) -> SmoothL1Loss {
        self.assertions();
        SmoothL1Loss { beta: self.beta }
    }

    /// Validate the configuration invariants.
    fn assertions(&self) {
        // `beta` divides the squared error in the quadratic branch of the
        // loss, so zero (or negative/NaN) values are rejected up front.
        if !(self.beta > 0.0) {
            panic!("The parameter beta must be positive.")
        }
    }
}
/// Smooth L1 loss module; construct it with [`SmoothL1LossConfig::init`].
#[derive(Module, Clone, Debug)]
pub struct SmoothL1Loss {
    /// Transition point between the quadratic (`|error| < beta`) and
    /// linear (`|error| >= beta`) regions of the loss.
    pub beta: f32,
}
impl SmoothL1Loss {
    /// Compute the element-wise smooth L1 loss between predictions and targets.
    ///
    /// For each error `e = prediction - target`, the result is
    /// `e^2 / (2 * beta)` where `|e| < beta`, and `|e| - beta / 2` otherwise.
    pub fn forward<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let diff = predictions.sub(targets);
        let diff_abs = diff.clone().abs();
        // Linear branch: |e| - beta / 2.
        let linear = diff_abs.clone().sub_scalar(0.5 * self.beta);
        // Quadratic branch: e^2 / (2 * beta).
        let quadratic = diff.square().mul_scalar(0.5).div_scalar(self.beta);
        // Use the quadratic branch wherever |e| < beta, the linear one elsewhere.
        let quadratic_region = diff_abs.lower_elem(self.beta);
        linear.mask_where(quadratic_region, quadratic)
    }

    /// Compute the smooth L1 loss reduced to a single-element tensor.
    ///
    /// `Reduction::Mean` and `Reduction::Auto` average the element-wise loss,
    /// while `Reduction::Sum` sums it; any other reduction panics.
    pub fn forward_with_reduction<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward(predictions, targets);
        match reduction {
            Reduction::Sum => loss.sum(),
            Reduction::Mean | Reduction::Auto => loss.mean(),
            other => panic!("{other:?} reduction is not supported"),
        }
    }

    /// Compute the smooth L1 loss averaged over the given dimensions.
    ///
    /// The reduced dimensions are kept with size 1 (the output rank stays `D`);
    /// an empty `dims` slice leaves the element-wise loss unreduced.
    pub fn forward_reduce_dims<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        dims: &[usize],
    ) -> Tensor<B, D> {
        let loss = self.forward(predictions, targets);
        // Normalize the requested dimensions to ascending order so callers
        // may pass them in any order.
        let mut ordered = dims.to_vec();
        ordered.sort_unstable();
        loss.mean_dims(ordered.as_slice())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// Loss module with the default `beta = 1.0`.
    fn default_loss() -> SmoothL1Loss {
        SmoothL1LossConfig::new().init()
    }

    #[test]
    fn test_smooth_l1_config_default_beta() {
        assert_eq!(default_loss().beta, 1.0);
    }

    #[test]
    fn test_smooth_l1_config_custom_beta() {
        let custom = SmoothL1LossConfig::new().with_beta(2.5).init();
        assert_eq!(custom.beta, 2.5);
    }

    #[test]
    #[should_panic(expected = "The parameter beta must be positive")]
    fn test_smooth_l1_config_beta_zero_panics() {
        SmoothL1LossConfig::new().with_beta(0.0).init();
    }

    #[test]
    #[should_panic(expected = "The parameter beta must be positive")]
    fn test_smooth_l1_config_beta_negative_panics() {
        SmoothL1LossConfig::new().with_beta(-1.0).init();
    }

    #[test]
    fn test_smooth_l1_forward_l2_region() {
        // Every |error| < beta, so all elements take the quadratic branch.
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.5]]), &device);
        let targets =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
        let result = default_loss().forward(predictions, targets);
        result
            .into_data()
            .assert_eq(&TensorData::from([[0.0_f32, 0.125]]), false);
    }

    #[test]
    fn test_smooth_l1_forward_l1_region() {
        // |error| = 2 >= beta, so the second element takes the linear branch.
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 2.0]]), &device);
        let targets =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
        let result = default_loss().forward(predictions, targets);
        result
            .into_data()
            .assert_eq(&TensorData::from([[0.0_f32, 1.5]]), false);
    }

    #[test]
    fn test_smooth_l1_forward_zero_error() {
        // Identical predictions and targets must yield an all-zero loss.
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[1.0_f32, 2.0, 3.0]]), &device);
        let targets = predictions.clone();
        let result = default_loss().forward(predictions, targets);
        result
            .into_data()
            .assert_eq(&TensorData::from([[0.0_f32, 0.0, 0.0]]), false);
    }

    #[test]
    fn test_smooth_l1_forward_negative_errors() {
        // The loss depends only on |error|: |-3| - 0.5 = 2.5.
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 1>::from_data(TensorData::from([-3.0_f32]), &device);
        let targets = Tensor::<TestBackend, 1>::zeros([1], &device);
        let result = default_loss().forward(predictions, targets);
        result
            .into_data()
            .assert_eq(&TensorData::from([2.5_f32]), false);
    }

    #[test]
    fn test_smooth_l1_forward_mixed_regions() {
        // One quadratic element (0.5) and two linear elements (1.5, 3.0).
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 1>::from_data(TensorData::from([0.5_f32, 1.5, 3.0]), &device);
        let targets = Tensor::<TestBackend, 1>::zeros([3], &device);
        let result = default_loss().forward(predictions, targets);
        result
            .into_data()
            .assert_eq(&TensorData::from([0.125_f32, 1.0, 2.5]), false);
    }

    #[test]
    fn test_smooth_l1_custom_beta_values() {
        // With beta = 0.5: 0.25^2 / (2 * 0.5) = 0.0625 and 1.0 - 0.25 = 0.75.
        let device = Default::default();
        let custom = SmoothL1LossConfig::new().with_beta(0.5).init();
        let predictions =
            Tensor::<TestBackend, 1>::from_data(TensorData::from([0.25_f32, 1.0]), &device);
        let targets = Tensor::<TestBackend, 1>::zeros([2], &device);
        let result = custom.forward(predictions, targets);
        result
            .into_data()
            .assert_eq(&TensorData::from([0.0625_f32, 0.75]), false);
    }

    #[test]
    fn test_smooth_l1_reduction_mean() {
        // (0.125 + 1.5) / 2 = 0.8125.
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5_f32, 2.0]]), &device);
        let targets =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
        let result = default_loss().forward_with_reduction(predictions, targets, Reduction::Mean);
        result
            .into_data()
            .assert_eq(&TensorData::from([0.8125_f32]), false);
    }

    #[test]
    fn test_smooth_l1_reduction_sum() {
        // 0.125 + 1.5 = 1.625.
        let device = Default::default();
        let predictions =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5_f32, 2.0]]), &device);
        let targets =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);
        let result = default_loss().forward_with_reduction(predictions, targets, Reduction::Sum);
        result
            .into_data()
            .assert_eq(&TensorData::from([1.625_f32]), false);
    }

    #[test]
    fn test_smooth_l1_reduction_auto_equals_mean() {
        let device = Default::default();
        let loss = default_loss();
        let predictions =
            Tensor::<TestBackend, 1>::from_data(TensorData::from([2.0_f32]), &device);
        let targets = Tensor::<TestBackend, 1>::zeros([1], &device);
        let mean_out =
            loss.forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean);
        let auto_out = loss.forward_with_reduction(predictions, targets, Reduction::Auto);
        mean_out.into_data().assert_eq(&auto_out.into_data(), false);
    }

    #[test]
    fn test_smooth_l1_forward_reduce_dims_single_dim() {
        // beta = 2: row one has losses [0, 0.25, 3] -> mean 3.25 / 3; row two is zero.
        let device = Default::default();
        let custom = SmoothL1LossConfig::new().with_beta(2.0).init();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.0_f32, 1.0, 4.0], [5.0_f32, 5.0, 5.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.0_f32, 0.0, 0.0], [5.0_f32, 5.0, 5.0]]),
            &device,
        );
        let result = custom.forward_reduce_dims(predictions, targets, &[1]);
        let expected = TensorData::from([[3.25_f32 / 3.0], [0.0]]);
        result
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_smooth_l1_forward_reduce_dims_image_batch() {
        // Reduce everything except the batch dimension; reduced dims keep size 1.
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([
                [[[0.5_f32, 2.0], [0.0, 3.0]]],
                [[[1.0_f32, 0.0], [0.5, 1.5]]],
            ]),
            &device,
        );
        let targets = Tensor::<TestBackend, 4>::zeros([2, 1, 2, 2], &device);
        let result = default_loss().forward_reduce_dims(predictions, targets, &[1, 2, 3]);
        result
            .into_data()
            .assert_eq(&TensorData::from([[[[1.03125_f32]]], [[[0.40625_f32]]]]), false);
    }

    #[test]
    fn test_smooth_l1_forward_reduce_dims_unsorted() {
        // The dimension order passed by the caller must not matter.
        let device = Default::default();
        let loss = default_loss();
        let predictions = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[1.0_f32, 2.0], [3.0, 4.0]], [[5.0_f32, 6.0], [7.0, 8.0]]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 3>::zeros([2, 2, 2], &device);
        let unsorted = loss.forward_reduce_dims(predictions.clone(), targets.clone(), &[2, 1]);
        let sorted = loss.forward_reduce_dims(predictions, targets, &[1, 2]);
        unsorted.into_data().assert_eq(&sorted.into_data(), false);
    }

    #[test]
    fn test_smooth_l1_forward_reduce_dims_empty_dims() {
        // An empty dims slice is a no-op reduction: equals the raw forward output.
        let device = Default::default();
        let loss = default_loss();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.5_f32, 2.0], [0.0, 3.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::zeros([2, 2], &device);
        let reduced = loss.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);
        let unreduced = loss.forward(predictions, targets);
        reduced.into_data().assert_eq(&unreduced.into_data(), false);
    }
}