use core::f32::consts::PI;

use burn_tensor::cast::ToElement;

use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;
use crate::{config::Config, module::Module};

use super::Reduction;

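/// Configuration for creating a [Poisson negative log likelihood loss](PoissonNllLoss).
///
/// The loss is computed element-wise as
///
/// - `exp(predictions) - targets * predictions` when `log_input` is `true`
///   (predictions are interpreted as log-rates), or
/// - `predictions - targets * ln(predictions + eps)` when `log_input` is `false`
///   (predictions are interpreted as rates).
///
/// When `full` is `true`, the Stirling approximation term
/// `targets * ln(targets) - targets + 0.5 * ln(2 * PI * targets)` is added for
/// entries where `targets > 1`.
///
/// A minimal usage sketch (values are illustrative only):
///
/// ```ignore
/// // `log_input` defaults to `true`, so predictions are log-rates;
/// // `full` adds the Stirling approximation term.
/// let loss = PoissonNllLossConfig::new().with_full(true).init();
/// ```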
#[derive(Config, Debug)]
pub struct PoissonNllLossConfig {
    /// If `true`, the predictions are treated as log-rates; otherwise as rates.
    #[config(default = true)]
    pub log_input: bool,
    /// If `true`, adds the Stirling approximation term to the loss.
    #[config(default = false)]
    pub full: bool,
    /// Small value added to the predictions to avoid taking the logarithm of zero
    /// when `log_input` is `false`.
    #[config(default = 1e-8)]
    pub eps: f64,
}

impl PoissonNllLossConfig {
    /// Initializes a [PoissonNllLoss] module from this configuration.
    ///
    /// # Panics
    ///
    /// Panics if `eps` is not a positive number.
    pub fn init(&self) -> PoissonNllLoss {
        self.assertions();
        PoissonNllLoss {
            log_input: self.log_input,
            full: self.full,
            eps: self.eps,
        }
    }

    fn assertions(&self) {
        assert!(
            self.eps > 0.,
            "eps for PoissonNllLoss must be a positive number."
        );
    }
}

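/// Poisson negative log likelihood loss. Measures how well predicted event rates
/// (or log-rates) explain observed count targets. Should be created with
/// [PoissonNllLossConfig].
///
/// A minimal sketch of the forward pass (backend `B`, device, and tensor values
/// are illustrative only):
///
/// ```ignore
/// let device = Default::default();
/// let predictions = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);
/// let targets = Tensor::<B, 1>::from_floats([1.0, 0.0, 2.0], &device);
///
/// let loss = PoissonNllLossConfig::new().init();
/// // Element-wise loss, same shape as the inputs.
/// let per_element = loss.forward_no_reduction(predictions.clone(), targets.clone());
/// // Rank-1, single-element loss after reduction.
/// let mean = loss.forward(predictions, targets, Reduction::Mean);
/// ```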
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct PoissonNllLoss {
    /// If `true`, the predictions are treated as log-rates; otherwise as rates.
    pub log_input: bool,
    /// If `true`, the Stirling approximation term is added to the loss.
    pub full: bool,
    /// Small value added to the predictions to avoid taking the logarithm of zero
    /// when `log_input` is `false`.
    pub eps: f64,
}

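// Custom display settings so the module renders on a single line, e.g.
// `PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}` (see the `display` test below).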
impl ModuleDisplay for PoissonNllLoss {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("log_input", &self.log_input)
            .add("full", &self.full)
            .add("eps", &self.eps)
            .optional()
    }
}

impl PoissonNllLoss {
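    /// Computes the loss element-wise with [forward_no_reduction](Self::forward_no_reduction)
    /// and applies the requested reduction.
    ///
    /// `Reduction::Mean` and `Reduction::Auto` return the mean of the element-wise loss;
    /// `Reduction::Sum` returns its sum. The result is a rank-1 tensor with a single element.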
    pub fn forward<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }

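    /// Computes the element-wise Poisson negative log likelihood, without reduction.
    ///
    /// With `log_input = true` the loss is `exp(predictions) - targets * predictions`;
    /// with `log_input = false` it is `predictions - targets * ln(predictions + eps)`.
    /// When `full = true`, the Stirling approximation term
    /// `targets * ln(targets) - targets + 0.5 * ln(2 * PI * targets)` is added wherever
    /// `targets > 1`.
    ///
    /// # Shapes
    ///
    /// - predictions: `[...dims]`
    /// - targets: `[...dims]` (must match the shape of predictions)
    /// - output: `[...dims]`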
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        self.assertions(&predictions, &targets);
        // Base term of the Poisson negative log likelihood.
        let mut loss = if self.log_input {
            // Predictions are log-rates: exp(predictions) - targets * predictions.
            predictions.clone().exp() - targets.clone() * predictions
        } else {
            // Predictions are rates: predictions - targets * ln(predictions + eps).
            predictions.clone() - targets.clone() * (predictions + self.eps).log()
        };
        if self.full {
            // Stirling approximation of ln(targets!), applied only where targets > 1.
            let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone()
                + (targets.clone() * 2. * PI).log() * 0.5;
            loss = loss
                + log_stirling_term
                    .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like());
        }
        loss
    }

    fn assertions<const D: usize, B: Backend>(
        &self,
        predictions: &Tensor<B, D>,
        targets: &Tensor<B, D>,
    ) {
        let predictions_dims = predictions.dims();
        let targets_dims = targets.dims();
        assert!(
            predictions_dims == targets_dims,
            "Shape of targets ({:?}) should match the shape of predictions ({:?}).",
            targets_dims,
            predictions_dims
        );
        assert!(
            targets
                .clone()
                .greater_equal_elem(0.)
                .all()
                .into_scalar()
                .to_bool(),
            "All the values of `targets` must be non-negative."
        );
        if !self.log_input {
            assert!(
                predictions
                    .clone()
                    .greater_equal_elem(0.)
                    .all()
                    .into_scalar()
                    .to_bool(),
                "When `log_input` is `false`, all the values of `predictions` must be non-negative."
            );
        }
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::approx_constant)]

    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_poisson_nll_loss() {
        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().init();

        let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
        let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);

        let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([21.0321]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([126.1929]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_poisson_nll_loss_no_log_input() {
        let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]);
        let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().with_log_input(false).init();

        let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone());

        let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_poisson_nll_loss_full() {
        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().with_full(true).init();

        let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
        let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);

        let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([21.9920]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([131.9518]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_poisson_nll_loss_gradients() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);

        let device = Default::default();

        let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad();
        let predictions2 = predictions1.clone();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().with_full(false).init();
        let poisson_full = PoissonNllLossConfig::new().with_full(true).init();

        let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum);
        let loss_full_sum =
            poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum);

        let grads = loss_sum.backward();
        let grads_full = loss_full_sum.backward();

        let grads_predictions1 = predictions1.grad(&grads).unwrap();
        let grads_predictions2 = predictions2.grad(&grads_full).unwrap();

        let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]);

        grads_predictions1
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
        grads_predictions2
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    #[should_panic = "eps for PoissonNllLoss must be a positive number."]
    fn test_negative_eps() {
        let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();
    }

    #[test]
    #[should_panic = "All the values of `targets` must be non-negative."]
    fn test_targets_with_negative_values() {
        let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]);
        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().init();

        let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
    }

    #[test]
    #[should_panic = "Shape of targets"]
    fn test_shape_tensors() {
        let predictions = TensorData::from([0., 1., 2.]);
        let targets = TensorData::from([0., 1.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().init();

        let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
    }

    #[test]
    #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."]
    fn test_exp_predictions_non_negative() {
        let predictions = TensorData::from([0.3, -0.1, 0.4]);
        let targets = TensorData::from([0., 1., 0.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().with_log_input(false).init();

        let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
    }

    #[test]
    fn display() {
        let config = PoissonNllLossConfig::new();
        let loss = config.init();

        assert_eq!(
            alloc::format!("{}", loss),
            "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}"
        );
    }
}