burn_core/nn/loss/huber.rs

use crate as burn;

use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;
use crate::{config::Config, module::Module};

use super::Reduction;

/// Configuration to create a [Huber loss](HuberLoss).
#[derive(Config, Debug)]
pub struct HuberLossConfig {
    /// The bound where the Huber loss function changes from quadratic to linear behaviour.
    pub delta: f32,
}

impl HuberLossConfig {
    /// Initialize a [Huber loss](HuberLoss) module.
    pub fn init(&self) -> HuberLoss {
        self.assertions();
        HuberLoss {
            delta: self.delta,
            lin_bias: self.delta * self.delta * 0.5,
        }
    }

    fn assertions(&self) {
        assert!(
            self.delta >= 0., // This also rejects NaN, for which the comparison is false
            "Delta for Huber loss must be a non-negative number."
        );
    }
}

/// Calculate the Huber loss between the inputs and the target.
///
/// The loss for each element of the residuals `r = targets - predictions` is given by
///
/// ```text
/// L(r) = 0.5 * r^2                  if |r| <= d
/// L(r) = 0.5 * d^2 + d * (|r| - d)  if |r| >  d
/// ```
///
/// where `d` is the configured `delta`. In particular, this is equal to the
/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,
/// but behaves linearly instead of quadratically for large residuals.
///
/// This loss function is less sensitive to outliers than the mean squared error loss.
///
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
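///
/// # Example
///
/// A minimal usage sketch, mirroring the tests below; `B` stands in for
/// whichever backend type is in scope:
///
/// ```ignore
/// let device = Default::default();
/// let huber = HuberLossConfig::new(0.5).init();
/// let predictions = Tensor::<B, 1>::from_data(TensorData::from([-2., -0.5, 0., 0.3, 1.]), &device);
/// let targets = Tensor::<B, 1>::from_data(TensorData::from([0., 0., 0., 0., 0.]), &device);
/// // Mean-reduced loss; with delta = 0.5 this comes out to roughly 0.284.
/// let loss = huber.forward(predictions, targets, Reduction::Auto);
/// ```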
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct HuberLoss {
    /// The bound where the Huber loss function changes from quadratic to linear behaviour.
    pub delta: f32,
    /// Precomputed value for the linear bias.
    pub lin_bias: f32, // delta * delta * 0.5 precomputed
}

impl ModuleDisplay for HuberLoss {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("delta", &self.delta)
            .add("lin_bias", &self.lin_bias)
            .optional()
    }
}

impl HuberLoss {
    /// Compute the loss element-wise for the predictions and targets, then reduce
    /// to a single loss value.
    ///
    /// `Reduction::Auto` behaves as `Reduction::Mean`.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[1\]
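    ///
    /// # Example
    ///
    /// A brief sketch of the reduction variants, reusing the tensors from the
    /// type-level example above:
    ///
    /// ```ignore
    /// let mean = huber.forward(predictions.clone(), targets.clone(), Reduction::Auto);
    /// let sum = huber.forward(predictions, targets, Reduction::Sum);
    /// ```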
    pub fn forward<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }

    /// Compute the loss element-wise for the predictions and targets.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[...dims\]
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let residuals = targets - predictions;
        self.forward_residuals(residuals)
    }

    /// Compute the loss element-wise for the given residuals.
    ///
    /// # Shapes
    ///
    /// - residuals: \[...dims\]
    /// - output: \[...dims\]
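    ///
    /// # Example
    ///
    /// A sketch for when the residuals `targets - predictions` are already at hand:
    ///
    /// ```ignore
    /// let residuals = targets - predictions;
    /// let loss = huber.forward_residuals(residuals); // same shape as `residuals`
    /// ```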
    pub fn forward_residuals<const D: usize, B: Backend>(
        &self,
        residuals: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let is_large = residuals.clone().abs().greater_elem(self.delta);
        // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
        // `sign()` function, in general, suffers from a jump at 0.
        // Instead the following tensor implements `delta * sign(r)` for values outside
        // the bound:
        let softsign = residuals.clone().clamp(-self.delta, self.delta);

        // 0.5 * d^2 + d * (|r| - d) =
        // d * |r| - 0.5 * d^2
        // Moreover |r| = sign(r) * r
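        // Illustrative check (values from the tests below): with d = 0.5 and
        // r = -2, d * |r| - 0.5 * d^2 = 0.5 * 2 - 0.125 = 0.875, which matches
        // 0.5 * d^2 + d * (|r| - d) = 0.125 + 0.5 * 1.5.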
        let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);

        let inside = residuals.powi_scalar(2).mul_scalar(0.5);
        inside.mask_where(is_large, outside)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_huber_loss() {
        let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = TensorData::from([0., 0., 0., 0., 0.]);

        let device = Default::default();

        let predict = TestTensor::<1>::from_data(predict, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let huber = HuberLossConfig::new(0.5).init();

        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = huber.forward_no_reduction(predict, targets);

        let expected = TensorData::from([0.875, 0.125, 0., 0.045, 0.375]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([0.284]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([1.42]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_huber_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = TensorData::from([0., 0., 0., 0., 0.]);

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let loss = HuberLossConfig::new(0.5).init();
        let loss = loss.forward_no_reduction(predict.clone(), targets);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

        let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]);
        grads_predict
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = HuberLossConfig::new(0.5);
        let loss = config.init();

        assert_eq!(
            alloc::format!("{loss}"),
            "HuberLoss {delta: 0.5, lin_bias: 0.125}"
        );
    }
}