1use burn_core as burn;
2use core::f32::consts::PI;
3
4use burn::tensor::cast::ToElement;
5
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::{config::Config, module::Module};
10
11use super::Reduction;
12
/// Configuration used to build a [`PoissonNllLoss`].
///
/// All fields have defaults, so `PoissonNllLossConfig::new()` can be used directly and
/// individual values overridden through the generated `with_*` setters.
#[derive(Config, Debug)]
pub struct PoissonNllLossConfig {
    /// Whether `predictions` are expressed in log-space. Defaults to `true`.
    ///
    /// - `true`: the element-wise loss is `exp(predictions) - targets * predictions`.
    /// - `false`: the element-wise loss is
    ///   `predictions - targets * log(predictions + eps)`.
    #[config(default = true)]
    pub log_input: bool,
    /// Whether to add the Stirling approximation term
    /// `targets * log(targets) - targets + 0.5 * log(2 * PI * targets)` to the loss.
    /// Defaults to `false`.
    ///
    /// The term only contributes for elements whose target is greater than `1`.
    #[config(default = false)]
    pub full: bool,
    /// Small constant added to `predictions` before taking the logarithm when
    /// `log_input` is `false`. Defaults to `1e-8`.
    ///
    /// Must be strictly positive; `init` panics otherwise.
    #[config(default = 1e-8)]
    pub eps: f64,
}
47
48impl PoissonNllLossConfig {
49 pub fn init(&self) -> PoissonNllLoss {
54 self.assertions();
55 PoissonNllLoss {
56 log_input: self.log_input,
57 full: self.full,
58 eps: self.eps,
59 }
60 }
61
62 fn assertions(&self) {
67 assert!(
68 self.eps > 0.,
69 "eps for PoissonNllLoss must be a positive number."
70 );
71 }
72}
73
/// Poisson negative log-likelihood (NLL) loss.
///
/// With `log_input` enabled the element-wise loss is
/// `exp(predictions) - targets * predictions`; otherwise it is
/// `predictions - targets * log(predictions + eps)`. When `full` is enabled, the Stirling
/// approximation term `targets * log(targets) - targets + 0.5 * log(2 * PI * targets)` is
/// added for elements whose target is greater than `1`.
///
/// Instances are normally created through [`PoissonNllLossConfig::init`], which validates
/// `eps`.
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct PoissonNllLoss {
    /// Whether `predictions` are interpreted as log-space values.
    pub log_input: bool,
    /// Whether the Stirling approximation term is added to the loss.
    pub full: bool,
    /// Constant added to `predictions` before the logarithm when `log_input` is `false`.
    pub eps: f64,
}
97
98impl ModuleDisplay for PoissonNllLoss {
99 fn custom_settings(&self) -> Option<DisplaySettings> {
100 DisplaySettings::new()
101 .with_new_line_after_attribute(false)
102 .optional()
103 }
104
105 fn custom_content(&self, content: Content) -> Option<Content> {
106 content
107 .add("log_input", &self.log_input)
108 .add("full", &self.full)
109 .add("eps", &self.eps)
110 .optional()
111 }
112}
113
114impl PoissonNllLoss {
115 pub fn forward<const D: usize, B: Backend>(
133 &self,
134 predictions: Tensor<B, D>,
135 targets: Tensor<B, D>,
136 reduction: Reduction,
137 ) -> Tensor<B, 1> {
138 let loss = self.forward_no_reduction(predictions, targets);
139 match reduction {
140 Reduction::Mean | Reduction::Auto => loss.mean(),
141 Reduction::Sum => loss.sum(),
142 other => panic!("{other:?} reduction is not supported"),
143 }
144 }
145
146 pub fn forward_no_reduction<const D: usize, B: Backend>(
162 &self,
163 predictions: Tensor<B, D>,
164 targets: Tensor<B, D>,
165 ) -> Tensor<B, D> {
166 self.assertions(&predictions, &targets);
167 let mut loss;
168 if self.log_input {
169 loss = predictions.clone().exp() - targets.clone() * predictions;
170 } else {
171 loss = predictions.clone() - targets.clone() * (predictions + self.eps).log();
172 }
173 if self.full {
174 let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone()
175 + (targets.clone() * 2. * PI).log() * 0.5;
176 loss = loss
177 + log_stirling_term
178 .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like());
179 }
180 loss
181 }
182
183 fn assertions<const D: usize, B: Backend>(
190 &self,
191 predictions: &Tensor<B, D>,
192 targets: &Tensor<B, D>,
193 ) {
194 let predictions_dims = predictions.dims();
195 let targets_dims = targets.dims();
196 assert!(
197 predictions_dims == targets_dims,
198 "Shape of targets ({targets_dims:?}) should correspond to outer shape of predictions ({predictions_dims:?})."
199 );
200 assert!(
201 targets
202 .clone()
203 .greater_equal_elem(0.)
204 .all()
205 .into_scalar()
206 .to_bool(),
207 "All the values of `targets` must be non-negative."
208 );
209 if !self.log_input {
210 assert!(
211 predictions
212 .clone()
213 .greater_equal_elem(0.)
214 .all()
215 .into_scalar()
216 .to_bool(),
217 "When `log_input` is `false`, all the values of `predictions` must be non-negative."
218 );
219 }
220 }
221}
222
#[cfg(test)]
mod tests {
    #![allow(clippy::approx_constant)]

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    type FT = FloatElem<TestBackend>;

    /// Builds a rank-1 tensor on the default test device from literal values.
    fn tensor_1d<const N: usize>(values: [f64; N]) -> TestTensor<1> {
        TestTensor::<1>::from_data(TensorData::from(values), &Default::default())
    }

    /// Checks that `actual` approximately equals `expected` under the default tolerance.
    fn assert_close<const N: usize>(actual: TestTensor<1>, expected: [f64; N]) {
        actual
            .into_data()
            .assert_approx_eq::<FT>(&TensorData::from(expected), Tolerance::default());
    }

    #[test]
    fn test_poisson_nll_loss() {
        let predictions = tensor_1d([0., 0., -40., 1., 2., 3.]);
        let targets = tensor_1d([1., 4.5, 2.5, 0., 0., 2.]);
        let criterion = PoissonNllLossConfig::new().init();

        let summed = criterion.forward(predictions.clone(), targets.clone(), Reduction::Sum);
        let averaged = criterion.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let elementwise = criterion.forward_no_reduction(predictions, targets);

        assert_close(
            elementwise,
            [1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855],
        );
        assert_close(averaged, [21.0321]);
        assert_close(summed, [126.1929]);
    }

    #[test]
    fn test_poisson_nll_loss_no_log_input() {
        let predictions = tensor_1d([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]);
        let targets = tensor_1d([2., 3., 1., 4.5, 0., 0., 2.]);
        let criterion = PoissonNllLossConfig::new().with_log_input(false).init();

        let elementwise = criterion.forward_no_reduction(predictions, targets);

        assert_close(
            elementwise,
            [36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855],
        );
    }

    #[test]
    fn test_poisson_nll_loss_full() {
        let predictions = tensor_1d([0., 0., -40., 1., 2., 3.]);
        let targets = tensor_1d([1., 4.5, 2.5, 0., 0., 2.]);
        let criterion = PoissonNllLossConfig::new().with_full(true).init();

        let summed = criterion.forward(predictions.clone(), targets.clone(), Reduction::Sum);
        let averaged = criterion.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let elementwise = criterion.forward_no_reduction(predictions, targets);

        assert_close(
            elementwise,
            [1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373],
        );
        assert_close(averaged, [21.9920]);
        assert_close(summed, [131.9518]);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_poisson_nll_loss_gradients() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let device = Default::default();

        let predictions1 =
            TestAutodiffTensor::from_data(TensorData::from([0., 0., -40., 1., 2., 3.]), &device)
                .require_grad();
        let predictions2 = predictions1.clone();
        let targets =
            TestAutodiffTensor::from_data(TensorData::from([1., 4.5, 2.5, 0., 0., 2.]), &device);

        let criterion = PoissonNllLossConfig::new().with_full(false).init();
        let criterion_full = PoissonNllLossConfig::new().with_full(true).init();

        let summed = criterion.forward(predictions1.clone(), targets.clone(), Reduction::Sum);
        let summed_full =
            criterion_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum);

        let grads = summed.backward();
        let grads_full = summed_full.backward();

        let grads_predictions1 = predictions1.grad(&grads).unwrap();
        let grads_predictions2 = predictions2.grad(&grads_full).unwrap();

        // The Stirling term does not depend on the predictions, so both gradients agree.
        let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]);

        grads_predictions1
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
        grads_predictions2
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    #[should_panic = "eps for PoissonNllLoss must be a positive number."]
    fn test_negative_eps() {
        let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();
    }

    #[test]
    #[should_panic = "All the values of `targets` must be non-negative."]
    fn test_targets_with_negative_values() {
        let predictions = tensor_1d([0., 0., -40., 1., 2., 3., 4.]);
        let targets = tensor_1d([1., 4.5, 2.5, 0., 0., 2., -0.42]);
        let criterion = PoissonNllLossConfig::new().init();

        let _loss = criterion.forward(predictions, targets, Reduction::Auto);
    }

    #[test]
    #[should_panic = "Shape of targets"]
    fn test_shape_tensors() {
        let predictions = tensor_1d([0., 1., 2.]);
        let targets = tensor_1d([0., 1.]);
        let criterion = PoissonNllLossConfig::new().init();

        let _loss = criterion.forward_no_reduction(predictions, targets);
    }

    #[test]
    #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."]
    fn test_exp_predictions_non_negative() {
        let predictions = tensor_1d([0.3, -0.1, 0.4]);
        let targets = tensor_1d([0., 1., 0.]);
        let criterion = PoissonNllLossConfig::new().with_log_input(false).init();

        let _loss = criterion.forward_no_reduction(predictions, targets);
    }

    #[test]
    fn display() {
        let loss = PoissonNllLossConfig::new().init();
        let rendered = alloc::format!("{loss}");

        assert_eq!(
            rendered,
            "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}"
        );
    }
}