// burn_nn/loss/poisson.rs — Poisson negative log likelihood loss.

1use burn_core as burn;
2use core::f32::consts::PI;
3
4use burn::tensor::cast::ToElement;
5
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::{config::Config, module::Module};
10
11use super::Reduction;
12
/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance.
///
/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss
/// behavior, such as whether the input is in log-space, whether to include the Stirling
/// approximation term, and a small epsilon value to avoid numerical instability.
#[derive(Config, Debug)]
pub struct PoissonNllLossConfig {
    /// If `true`, the predictions are expected to be in log-space.
    ///
    /// When `log_input` is `true`, the loss is computed as:
    /// ```text
    /// L(predictions, target) = exp(predictions) - target * predictions
    /// ```
    /// When `log_input` is `false`, the loss is computed as:
    /// ```text
    /// L(predictions, target) = predictions - target * log(predictions + eps)
    /// ```
    #[config(default = true)]
    pub log_input: bool,
    /// Whether to compute the full loss, including the Stirling approximation term.
    ///
    /// When `full` is `true`, the Stirling approximation term is added to the loss:
    /// ```text
    /// target * log(target) - target + 0.5 * log(2 * PI * target)
    /// ```
    /// The term contributes only where `target > 1`; elements with `target <= 1`
    /// receive a zero contribution.
    #[config(default = false)]
    pub full: bool,
    /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.
    ///
    /// This epsilon value is added to the predictions to ensure numerical stability
    /// when computing the logarithm. Must be strictly positive; `init` panics otherwise.
    #[config(default = 1e-8)]
    pub eps: f64,
}
47
48impl PoissonNllLossConfig {
49    /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration.
50    ///
51    /// # Panics
52    /// - Panics if `eps` is not a positive number.
53    pub fn init(&self) -> PoissonNllLoss {
54        self.assertions();
55        PoissonNllLoss {
56            log_input: self.log_input,
57            full: self.full,
58            eps: self.eps,
59        }
60    }
61
62    /// Validates the configuration parameters.
63    ///
64    /// # Panics
65    /// - Panics if `eps` is not a positive number.
66    fn assertions(&self) {
67        assert!(
68            self.eps > 0.,
69            "eps for PoissonNllLoss must be a positive number."
70        );
71    }
72}
73
/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target.
///
/// This loss function is used when the target values are assumed to follow a Poisson distribution.
/// The loss is defined as:
/// ```text
/// target ~ Poisson(input)
/// L(predictions, target) = predictions - target * log(predictions) + log(target!)
/// ```
/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula.
/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss.
///
/// Construct via [PoissonNllLossConfig::init].
///
/// For more details, see:
/// <https://en.wikipedia.org/wiki/Poisson_regression#Maximum_likelihood-based_parameter_estimation>
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct PoissonNllLoss {
    /// If `true`, the predictions are expected to be in log-space.
    pub log_input: bool,
    /// Whether to compute the full loss, including the Stirling approximation term.
    pub full: bool,
    /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.
    pub eps: f64,
}
97
98impl ModuleDisplay for PoissonNllLoss {
99    fn custom_settings(&self) -> Option<DisplaySettings> {
100        DisplaySettings::new()
101            .with_new_line_after_attribute(false)
102            .optional()
103    }
104
105    fn custom_content(&self, content: Content) -> Option<Content> {
106        content
107            .add("log_input", &self.log_input)
108            .add("full", &self.full)
109            .add("eps", &self.eps)
110            .optional()
111    }
112}
113
114impl PoissonNllLoss {
115    /// Computes the loss element-wise for the given predictions and targets, then reduces
116    /// the result to a single loss value.
117    ///
118    /// # Arguments
119    /// - `predictions`: The predicted values.
120    /// - `targets`: The target values.
121    /// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`.
122    ///
123    /// # Shapes
124    /// - `predictions`: `[...dims]`
125    /// - `targets`: `[...dims]`
126    /// - `output`: `[1]`
127    ///
128    /// # Panics
129    /// - Panics if the shapes of `predictions` and `targets` do not match.
130    /// - Panics if any target value is negative.
131    /// - Panics if `log_input` is `false` and any prediction value is negative.
132    pub fn forward<const D: usize, B: Backend>(
133        &self,
134        predictions: Tensor<B, D>,
135        targets: Tensor<B, D>,
136        reduction: Reduction,
137    ) -> Tensor<B, 1> {
138        let loss = self.forward_no_reduction(predictions, targets);
139        match reduction {
140            Reduction::Mean | Reduction::Auto => loss.mean(),
141            Reduction::Sum => loss.sum(),
142            other => panic!("{other:?} reduction is not supported"),
143        }
144    }
145
146    /// Computes the loss element-wise for the given predictions and targets without reduction.
147    ///
148    /// # Arguments
149    /// - `predictions`: The predicted values.
150    /// - `targets`: The target values.
151    ///
152    /// # Shapes
153    /// - `predictions`: `[...dims]`
154    /// - `targets`: `[...dims]`
155    /// - `output`: `[...dims]`
156    ///
157    /// # Panics
158    /// - Panics if the shapes of `predictions` and `targets` do not match.
159    /// - Panics if any target value is negative.
160    /// - Panics if `log_input` is `false` and any prediction value is negative.
161    pub fn forward_no_reduction<const D: usize, B: Backend>(
162        &self,
163        predictions: Tensor<B, D>,
164        targets: Tensor<B, D>,
165    ) -> Tensor<B, D> {
166        self.assertions(&predictions, &targets);
167        let mut loss;
168        if self.log_input {
169            loss = predictions.clone().exp() - targets.clone() * predictions;
170        } else {
171            loss = predictions.clone() - targets.clone() * (predictions + self.eps).log();
172        }
173        if self.full {
174            let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone()
175                + (targets.clone() * 2. * PI).log() * 0.5;
176            loss = loss
177                + log_stirling_term
178                    .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like());
179        }
180        loss
181    }
182
183    /// Validates the input tensors for the loss computation.
184    ///
185    /// # Panics
186    /// - Panics if the shapes of `predictions` and `targets` do not match.
187    /// - Panics if any target value is negative.
188    /// - Panics if `log_input` is `false` and any prediction value is negative.
189    fn assertions<const D: usize, B: Backend>(
190        &self,
191        predictions: &Tensor<B, D>,
192        targets: &Tensor<B, D>,
193    ) {
194        let predictions_dims = predictions.dims();
195        let targets_dims = targets.dims();
196        assert!(
197            predictions_dims == targets_dims,
198            "Shape of targets ({targets_dims:?}) should correspond to outer shape of predictions ({predictions_dims:?})."
199        );
200        assert!(
201            targets
202                .clone()
203                .greater_equal_elem(0.)
204                .all()
205                .into_scalar()
206                .to_bool(),
207            "All the values of `targets` must be non-negative."
208        );
209        if !self.log_input {
210            assert!(
211                predictions
212                    .clone()
213                    .greater_equal_elem(0.)
214                    .all()
215                    .into_scalar()
216                    .to_bool(),
217                "When `log_input` is `false`, all the values of `predictions` must be non-negative."
218            );
219        }
220    }
221}
222
#[cfg(test)]
mod tests {
    #![allow(clippy::approx_constant)]

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Default config (log-space input, no Stirling term):
    /// loss = exp(p) - t * p. E.g. exp(0) - 1*0 = 1, exp(-40) - 2.5*(-40) ≈ 100.
    #[test]
    fn test_poisson_nll_loss() {
        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().init();

        let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
        let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);

        let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        // Auto behaves as Mean: 126.1929 / 6 ≈ 21.0321.
        let expected = TensorData::from([21.0321]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([126.1929]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// `log_input = false`: loss = p - t * log(p + eps).
    /// First element: 0 - 2 * log(1e-8) ≈ 36.84136 (eps prevents log(0)).
    #[test]
    fn test_poisson_nll_loss_no_log_input() {
        let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]);
        let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().with_log_input(false).init();

        let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone());

        let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// `full = true` adds the Stirling term t*log(t) - t + 0.5*log(2*PI*t) where t > 1.
    /// E.g. second element: 1 + (4.5*ln 4.5 - 4.5 + 0.5*ln(2π*4.5)) ≈ 4.9393;
    /// elements with t <= 1 are unchanged from the non-full test.
    #[test]
    fn test_poisson_nll_loss_full() {
        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().with_full(true).init();

        let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
        let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);

        let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([21.9920]);
        loss.into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([131.9518]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// d/dp (exp(p) - t*p) = exp(p) - t. The Stirling term does not depend on
    /// predictions, so the `full` variant must produce identical gradients.
    #[cfg(feature = "std")]
    #[test]
    fn test_poisson_nll_loss_gradients() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);

        let device = Default::default();

        let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad();
        let predictions2 = predictions1.clone();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().with_full(false).init();
        let poisson_full = PoissonNllLossConfig::new().with_full(true).init();

        let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum);
        let loss_full_sum =
            poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum);

        let grads = loss_sum.backward();
        let grads_full = loss_full_sum.backward();

        let grads_predictions1 = predictions1.grad(&grads).unwrap();
        let grads_predictions2 = predictions2.grad(&grads_full).unwrap();

        let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]);

        grads_predictions1
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
        grads_predictions2
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // `eps = 0` is rejected (would allow log(0) when `log_input` is false).
    #[test]
    #[should_panic = "eps for PoissonNllLoss must be a positive number."]
    fn test_negative_eps() {
        let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();
    }

    // A single negative target (-0.42) must trigger the validation panic.
    #[test]
    #[should_panic = "All the values of `targets` must be non-negative."]
    fn test_targets_with_negative_values() {
        let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]);
        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().init();

        let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
    }

    // Mismatched shapes ([3] vs [2]) must trigger the shape panic.
    #[test]
    #[should_panic = "Shape of targets"]
    fn test_shape_tensors() {
        let predictions = TensorData::from([0., 1., 2.]);
        let targets = TensorData::from([0., 1.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().init();

        let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
    }

    // With `log_input = false`, a negative prediction (-0.1) must panic.
    #[test]
    #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."]
    fn test_exp_predictions_non_negative() {
        let predictions = TensorData::from([0.3, -0.1, 0.4]);
        let targets = TensorData::from([0., 1., 0.]);

        let device = Default::default();

        let predictions = TestTensor::<1>::from_data(predictions, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let poisson = PoissonNllLossConfig::new().with_log_input(false).init();

        let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
    }

    // Custom ModuleDisplay output: single line listing the config fields.
    #[test]
    fn display() {
        let config = PoissonNllLossConfig::new();
        let loss = config.init();

        assert_eq!(
            alloc::format!("{loss}"),
            "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}"
        );
    }
}