// burn_core/nn/loss/huber.rs
use crate as burn;

use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;
use crate::{config::Config, module::Module};

use super::Reduction;
9
/// Configuration for creating a [Huber loss](HuberLoss).
#[derive(Config, Debug)]
pub struct HuberLossConfig {
    /// The bound where the loss function changes from quadratic to linear behaviour.
    /// Must be non-negative (enforced by [`HuberLossConfig::init`]).
    pub delta: f32,
}
16
17impl HuberLossConfig {
18 pub fn init(&self) -> HuberLoss {
20 self.assertions();
21 HuberLoss {
22 delta: self.delta,
23 lin_bias: self.delta * self.delta * 0.5,
24 }
25 }
26
27 fn assertions(&self) {
28 assert!(
29 self.delta >= 0., "Delta for Huber loss must be a non-negative number."
31 );
32 }
33}
34
/// Calculate the Huber loss between the inputs and the target.
///
/// For each residual `r = target - prediction`, the element-wise loss is
/// `0.5 * r^2` when `|r| <= delta`, and `delta * |r| - 0.5 * delta^2` otherwise
/// (see `forward_residuals`).
///
/// Should be created using [HuberLossConfig].
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct HuberLoss {
    /// The bound where the loss changes from quadratic to linear behaviour.
    pub delta: f32,
    /// Precomputed `0.5 * delta^2`, the constant offset of the linear branch.
    pub lin_bias: f32,
}
59
60impl ModuleDisplay for HuberLoss {
61 fn custom_settings(&self) -> Option<DisplaySettings> {
62 DisplaySettings::new()
63 .with_new_line_after_attribute(false)
64 .optional()
65 }
66
67 fn custom_content(&self, content: Content) -> Option<Content> {
68 content
69 .add("delta", &self.delta)
70 .add("lin_bias", &self.lin_bias)
71 .optional()
72 }
73}
74
75impl HuberLoss {
76 pub fn forward<const D: usize, B: Backend>(
87 &self,
88 predictions: Tensor<B, D>,
89 targets: Tensor<B, D>,
90 reduction: Reduction,
91 ) -> Tensor<B, 1> {
92 let loss = self.forward_no_reduction(predictions, targets);
93 match reduction {
94 Reduction::Mean | Reduction::Auto => loss.mean(),
95 Reduction::Sum => loss.sum(),
96 }
97 }
98 pub fn forward_no_reduction<const D: usize, B: Backend>(
106 &self,
107 predictions: Tensor<B, D>,
108 targets: Tensor<B, D>,
109 ) -> Tensor<B, D> {
110 let residuals = targets - predictions;
111 self.forward_residuals(residuals)
112 }
113 pub fn forward_residuals<const D: usize, B: Backend>(
120 &self,
121 residuals: Tensor<B, D>,
122 ) -> Tensor<B, D> {
123 let is_large = residuals.clone().abs().greater_elem(self.delta);
124 let softsign = residuals.clone().clamp(-self.delta, self.delta);
129
130 let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);
134
135 let inside = residuals.powi_scalar(2).mul_scalar(0.5);
136 inside.mask_where(is_large, outside)
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Forward pass with delta = 0.5 against hand-computed values.
    ///
    /// With targets all zero, the residuals equal -predictions, so e.g.
    /// r = -2 (|r| > 0.5, linear): 0.5 * 2 - 0.125 = 0.875, and
    /// r = 0.3 (|r| <= 0.5, quadratic): 0.5 * 0.09 = 0.045.
    /// Mean = 1.42 / 5 = 0.284, sum = 1.42.
    #[test]
    fn test_huber_loss() {
        let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = TensorData::from([0., 0., 0., 0., 0.]);

        let device = Default::default();

        let predict = TestTensor::<1>::from_data(predict, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let huber = HuberLossConfig::new(0.5).init();

        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = huber.forward_no_reduction(predict, targets);

        let expected = TensorData::from([0.875, 0.125, 0., 0.045, 0.375]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([0.284]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([1.42]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Gradient of the unreduced loss w.r.t. the predictions.
    ///
    /// d(loss)/d(prediction) is -r clamped to [-delta, delta]: the linear branch
    /// saturates at +/-0.5 while the quadratic branch yields -r (e.g. 0.3).
    #[cfg(feature = "std")]
    #[test]
    fn test_huber_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = TensorData::from([0., 0., 0., 0., 0.]);

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let loss = HuberLossConfig::new(0.5).init();
        let loss = loss.forward_no_reduction(predict.clone(), targets);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

        let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]);
        grads_predict
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Custom display: single-line output exposing delta and lin_bias.
    #[test]
    fn display() {
        let config = HuberLossConfig::new(0.5);
        let loss = config.init();

        assert_eq!(
            alloc::format!("{loss}"),
            "HuberLoss {delta: 0.5, lin_bias: 0.125}"
        );
    }
}