use burn_core as burn;
use core::f32::consts::PI;
use burn::tensor::cast::ToElement;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::{config::Config, module::Module};
use super::Reduction;
/// Configuration used to build a [`PoissonNllLoss`] module.
#[derive(Config, Debug)]
pub struct PoissonNllLossConfig {
    /// When `true`, predictions are treated as log-rates and the loss is
    /// `exp(predictions) - targets * predictions`.
    /// When `false`, predictions are treated as rates and the loss is
    /// `predictions - targets * log(predictions + eps)`.
    #[config(default = true)]
    pub log_input: bool,
    /// When `true`, the Stirling approximation term is added to the loss for targets > 1.
    #[config(default = false)]
    pub full: bool,
    /// Small constant added inside the logarithm when `log_input` is `false`.
    /// Must be strictly positive.
    #[config(default = 1e-8)]
    pub eps: f64,
}
impl PoissonNllLossConfig {
    /// Builds a [`PoissonNllLoss`] module carrying the same settings as this config.
    ///
    /// # Panics
    ///
    /// Panics if `eps` is not strictly greater than zero.
    pub fn init(&self) -> PoissonNllLoss {
        self.assertions();
        // All fields are `Copy`, so destructuring the borrowed config copies them out.
        let Self {
            log_input,
            full,
            eps,
        } = *self;
        PoissonNllLoss {
            log_input,
            full,
            eps,
        }
    }

    /// Validates the configuration before a module is built from it.
    fn assertions(&self) {
        // `NaN > 0.` is false, so a NaN epsilon is rejected as well.
        let eps_is_positive = self.eps > 0.;
        assert!(
            eps_is_positive,
            "eps for PoissonNllLoss must be a positive number."
        );
    }
}
/// Poisson negative log-likelihood loss.
///
/// Build it through [`PoissonNllLossConfig::init`].
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct PoissonNllLoss {
    /// Whether predictions are log-rates (`true`) or rates (`false`).
    pub log_input: bool,
    /// Whether the Stirling approximation term is added for targets > 1.
    pub full: bool,
    /// Constant added inside the logarithm when `log_input` is `false`.
    pub eps: f64,
}
impl ModuleDisplay for PoissonNllLoss {
    /// Prints every attribute on a single line.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Lists the three hyper-parameters in declaration order.
    fn custom_content(&self, content: Content) -> Option<Content> {
        let content = content.add("log_input", &self.log_input);
        let content = content.add("full", &self.full);
        let content = content.add("eps", &self.eps);
        content.optional()
    }
}
impl PoissonNllLoss {
    /// Computes the Poisson negative log-likelihood loss and reduces it to a single value.
    ///
    /// `Reduction::Auto` is treated the same as `Reduction::Mean`.
    ///
    /// # Shapes
    ///
    /// - predictions: `[...dims]`
    /// - targets: `[...dims]` (must be identical to the predictions shape)
    /// - output: `[1]`
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as
    /// [`forward_no_reduction`](Self::forward_no_reduction).
    pub fn forward<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }

    /// Computes the element-wise Poisson negative log-likelihood loss without reduction.
    ///
    /// - If `log_input` is `true`, the loss is `exp(predictions) - targets * predictions`.
    /// - Otherwise it is `predictions - targets * log(predictions + eps)`.
    ///
    /// When `full` is `true`, the Stirling term
    /// `targets * log(targets) - targets + 0.5 * log(2π * targets)`
    /// is added for elements where `targets > 1`. It contributes zero elsewhere.
    ///
    /// # Panics
    ///
    /// - If `predictions` and `targets` do not have the same shape.
    /// - If any target is negative.
    /// - If `log_input` is `false` and any prediction is negative.
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        self.assertions(&predictions, &targets);

        let loss = if self.log_input {
            predictions.clone().exp() - targets.clone() * predictions
        } else {
            // `eps` keeps the logarithm finite when a predicted rate is exactly zero.
            predictions.clone() - targets.clone() * (predictions + self.eps).log()
        };

        if self.full {
            let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone()
                + (targets.clone() * 2. * PI).log() * 0.5;
            // Zeroing the term for targets <= 1 also discards the NaN / -inf that
            // log(0) produces for zero targets.
            let small_targets = targets.clone().lower_equal_elem(1);
            loss + log_stirling_term.mask_where(small_targets, targets.zeros_like())
        } else {
            loss
        }
    }

    /// Checks the runtime preconditions of the loss.
    ///
    /// Reading back the `all()` results synchronizes with the backend.
    fn assertions<const D: usize, B: Backend>(
        &self,
        predictions: &Tensor<B, D>,
        targets: &Tensor<B, D>,
    ) {
        let predictions_dims = predictions.dims();
        let targets_dims = targets.dims();
        assert!(
            predictions_dims == targets_dims,
            "Shape of targets ({targets_dims:?}) should correspond to shape of predictions ({predictions_dims:?})."
        );
        assert!(
            targets
                .clone()
                .greater_equal_elem(0.)
                .all()
                .into_scalar()
                .to_bool(),
            "All the values of `targets` must be non-negative."
        );
        if !self.log_input {
            assert!(
                predictions
                    .clone()
                    .greater_equal_elem(0.)
                    .all()
                    .into_scalar()
                    .to_bool(),
                "When `log_input` is `false`, all the values of `predictions` must be non-negative."
            );
        }
    }
}
#[cfg(test)]
mod tests {
// Literals such as 2.71828 and 7.38905 (approximations of e and e^2) would otherwise
// trip clippy's approx_constant lint.
#![allow(clippy::approx_constant)]
use super::*;
use crate::TestBackend;
use burn::tensor::TensorData;
type TestTensor<const D: usize> = Tensor<TestBackend, D>;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
// Default config (log_input = true): loss = exp(pred) - target * pred.
// Checks the unreduced loss as well as the Auto (mean) and Sum reductions.
#[test]
fn test_poisson_nll_loss() {
let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().init();
let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);
let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([21.0321]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([126.1929]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
// log_input = false: loss = pred - target * log(pred + eps).
// The first element (pred = 0) exercises the eps guard: 0 - 2 * ln(1e-8) ≈ 36.84.
#[test]
fn test_poisson_nll_loss_no_log_input() {
let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]);
let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().with_log_input(false).init();
let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone());
let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
// full = true: same inputs as test_poisson_nll_loss, plus the Stirling term.
// Only elements with target > 1 (4.5, 2.5, 2.0) differ from the non-full expectations.
#[test]
fn test_poisson_nll_loss_full() {
let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().with_full(true).init();
let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);
let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]);
loss_no_reduction
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([21.9920]);
loss.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
let expected = TensorData::from([131.9518]);
loss_sum
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
// d(loss)/d(pred) = exp(pred) - target. The Stirling term depends only on targets,
// so the gradients with and without `full` must be identical.
#[cfg(feature = "std")]
#[test]
fn test_poisson_nll_loss_gradients() {
type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;
let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
let device = Default::default();
let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad();
let predictions2 = predictions1.clone();
let targets = TestAutodiffTensor::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().with_full(false).init();
let poisson_full = PoissonNllLossConfig::new().with_full(true).init();
let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum);
let loss_full_sum =
poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum);
let grads = loss_sum.backward();
let grads_full = loss_full_sum.backward();
let grads_predictions1 = predictions1.grad(&grads).unwrap();
let grads_predictions2 = predictions2.grad(&grads_full).unwrap();
let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]);
grads_predictions1
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
grads_predictions2
.into_data()
.assert_approx_eq::<FT>(&expected, Tolerance::default());
}
// Despite the name, this uses eps = 0, the boundary value rejected by the `eps > 0` check.
#[test]
#[should_panic = "eps for PoissonNllLoss must be a positive number."]
fn test_negative_eps() {
let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();
}
// A single negative target (-0.42) must trigger the targets assertion.
#[test]
#[should_panic = "All the values of `targets` must be non-negative."]
fn test_targets_with_negative_values() {
let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]);
let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().init();
let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
}
// Mismatched shapes ([3] vs [2]); only the message prefix is matched.
#[test]
#[should_panic = "Shape of targets"]
fn test_shape_tensors() {
let predictions = TensorData::from([0., 1., 2.]);
let targets = TensorData::from([0., 1.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().init();
let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
}
// With log_input = false, a negative prediction (-0.1) must be rejected.
#[test]
#[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."]
fn test_exp_predictions_non_negative() {
let predictions = TensorData::from([0.3, -0.1, 0.4]);
let targets = TensorData::from([0., 1., 0.]);
let device = Default::default();
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let poisson = PoissonNllLossConfig::new().with_log_input(false).init();
let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
}
// The custom display prints all attributes on one line with the default config values.
#[test]
fn display() {
let config = PoissonNllLossConfig::new();
let loss = config.init();
assert_eq!(
alloc::format!("{loss}"),
"PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}"
);
}
}