// burn_core/nn/loss/huber.rs
use crate as burn;

use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use crate::{config::Config, module::Module};

use super::Reduction;

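/// Configuration for creating a [Huber loss](HuberLoss).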
#[derive(Config, Debug)]
pub struct HuberLossConfig {
    /// The threshold at which the loss switches from quadratic to linear behaviour.
    pub delta: f32,
}

impl HuberLossConfig {
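    /// Initialize a [Huber loss](HuberLoss) module from this configuration.
    ///
    /// Panics if `delta` is negative.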
    pub fn init(&self) -> HuberLoss {
        self.assertions();
        HuberLoss {
            delta: self.delta,
            lin_bias: self.delta * self.delta * 0.5,
        }
    }

    fn assertions(&self) {
        assert!(
            self.delta >= 0.,
            "Delta for Huber loss must be a non-negative number."
        );
    }
}

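/// Huber loss module.
///
/// For each element, with residual `r = target - prediction`, the loss is
/// `0.5 * r^2` when `|r| <= delta`, and `delta * |r| - 0.5 * delta^2` otherwise,
/// so the quadratic and linear branches meet continuously at `|r| = delta`.
///
/// A minimal usage sketch (assuming tensors `predictions` and `targets` of the
/// same shape are already in scope):
///
/// ```ignore
/// let huber = HuberLossConfig::new(1.0).init();
/// let loss = huber.forward(predictions, targets, Reduction::Mean);
/// ```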
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct HuberLoss {
    /// The threshold `delta` at which the loss switches from quadratic to linear.
    pub delta: f32,
    /// Precomputed value `0.5 * delta * delta` used by the linear branch.
    pub lin_bias: f32,
}

impl ModuleDisplay for HuberLoss {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("delta", &self.delta)
            .add("lin_bias", &self.lin_bias)
            .optional()
    }
}

impl HuberLoss {
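    /// Compute the loss element-wise, then reduce it to a single scalar tensor
    /// according to `reduction`. `Reduction::Auto` behaves like `Reduction::Mean`.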
    pub fn forward<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }
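
    /// Compute the element-wise Huber loss without any reduction.
    ///
    /// The returned tensor has the same shape as the inputs.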
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let residuals = targets - predictions;
        self.forward_residuals(residuals)
    }
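
    /// Compute the element-wise Huber loss directly from precomputed residuals
    /// `r = target - prediction`.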
    pub fn forward_residuals<const D: usize, B: Backend>(
        &self,
        residuals: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let is_large = residuals.clone().abs().greater_elem(self.delta);
        // Clamping to [-delta, delta] equals `sign(r) * delta` exactly where the
        // residual lies outside that interval, which is the only place it is used.
        let softsign = residuals.clone().clamp(-self.delta, self.delta);

        // Linear branch: `sign(r) * delta * r - 0.5 * delta^2 = delta * |r| - 0.5 * delta^2`.
        let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);

        // Quadratic branch: `0.5 * r^2`.
        let inside = residuals.powf_scalar(2.).mul_scalar(0.5);
        inside.mask_where(is_large, outside)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;

    #[test]
    fn test_huber_loss() {
        let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = TensorData::from([0., 0., 0., 0., 0.]);

        let device = Default::default();

        let predict = TestTensor::<1>::from_data(predict, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let huber = HuberLossConfig::new(0.5).init();

        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = huber.forward_no_reduction(predict, targets);

        let expected = TensorData::from([0.875, 0.125, 0., 0.045, 0.375]);
        loss_no_reduction.into_data().assert_approx_eq(&expected, 7);

        let expected = TensorData::from([0.284]);
        loss.into_data().assert_approx_eq(&expected, 7);

        let expected = TensorData::from([1.42]);
        loss_sum.into_data().assert_approx_eq(&expected, 5);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_huber_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = TensorData::from([0., 0., 0., 0., 0.]);

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let loss = HuberLossConfig::new(0.5).init();
        let loss = loss.forward_no_reduction(predict.clone(), targets);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

        let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]);
        grads_predict.to_data().assert_approx_eq(&expected, 3);
    }

    #[test]
    fn display() {
        let config = HuberLossConfig::new(0.5);
        let loss = config.init();

        assert_eq!(
            alloc::format!("{}", loss),
            "HuberLoss {delta: 0.5, lin_bias: 0.125}"
        );
    }
}