use linfa::Float;
use ndarray::Zip;
use ndarray::{Array1, ArrayView1};
use crate::error::{LinearError, Result};
/// A Tweedie-family distribution identified by its `power` parameter.
///
/// Instances are created with `TweedieDistribution::new`, which derives
/// `lower_bound` and `inclusive` from `power`.
#[derive(Debug, Clone, PartialEq)]
pub struct TweedieDistribution<F: Float> {
/// Exponent applied to the predictions in `unit_variance`; it also selects
/// which formula `unit_deviance` uses.
power: F,
/// Bound that `in_range` compares every target value against.
lower_bound: F,
/// When `true`, `in_range` accepts values equal to `lower_bound` (`>=`);
/// otherwise the comparison is strict (`>`).
inclusive: bool,
}
impl<F: Float> TweedieDistribution<F> {
/// Creates a Tweedie distribution for the given `power`.
///
/// The power determines which target values [`in_range`](Self::in_range)
/// accepts:
///
/// - `power <= 0`: any value strictly greater than negative infinity,
/// - `1 <= power < 2`: values greater than or equal to zero,
/// - `power >= 2`: values strictly greater than zero.
///
/// # Errors
///
/// Returns [`LinearError::InvalidTweediePower`] when `0 < power < 1` or when
/// `power` is NaN.
pub fn new(power: F) -> Result<Self, F> {
    let zero = F::zero();

    let (lower_bound, inclusive) = if power <= zero {
        (F::neg_infinity(), false)
    } else if power >= F::cast(2.0) {
        (zero, false)
    } else if power >= F::one() {
        (zero, true)
    } else {
        // Reached for `0 < power < 1` and for NaN: every comparison above is
        // false for NaN, so it is rejected here instead of panicking.
        return Err(LinearError::InvalidTweediePower(power));
    };

    Ok(Self {
        power,
        lower_bound,
        inclusive,
    })
}
/// Checks whether every element of `y` lies above the distribution's lower
/// bound.
///
/// The bound itself is accepted only when the distribution was built with an
/// inclusive bound. An empty `y` yields `true`.
pub fn in_range(&self, y: &ArrayView1<F>) -> bool {
    let bound = self.lower_bound;
    let inclusive = self.inclusive;
    y.iter().all(|&value| {
        if inclusive {
            value >= bound
        } else {
            value > bound
        }
    })
}
/// Returns `ypred` with every element raised to `self.power`.
fn unit_variance(&self, ypred: ArrayView1<F>) -> Array1<F> {
    let exponent = self.power;
    ypred.map(|&mu| mu.powf(exponent))
}
fn unit_deviance(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Result<Array1<F>, F> {
match self.power {
power if power.is_negative() => {
let mut left = y.mapv(|x| x.max(F::zero()));
left.mapv_inplace(|x| {
x.powf(F::cast(2.) - self.power)
/ ((F::one() - self.power) * (F::cast(2.) - self.power))
});
let middle =
&y * &ypred.mapv(|x| x.powf(F::cast(1.) - self.power) / (F::cast(1.) - power));
let right =
ypred.mapv(|x| x.powf(F::cast(2.) - self.power) / (F::cast(2.) - self.power));
Ok((left - middle + right).mapv(|x| F::cast(2.) * x))
}
power if power == F::zero() => Ok((&y - &ypred).mapv(|x| x * x)),
power if power < F::one() => Err(LinearError::InvalidTweediePower(power)),
power if (power - F::one()).abs() < F::cast(1e-6) => {
let mut div = &y / &ypred;
Zip::from(&mut div).and(y).for_each(|y, &x| {
if x == F::zero() {
*y = F::zero();
} else {
*y = F::cast(2.) * (x * y.ln());
}
});
Ok(div - y + ypred)
}
power if (power - F::cast(2.)).abs() < F::cast(1e-6) => {
let mut temp = (&ypred / &y).mapv(|x| x.ln()) + (&y / &ypred);
temp.mapv_inplace(|x| x - F::one());
Ok(temp.mapv(|x| F::cast(2.) * x))
}
power => {
let left = y.mapv(|x| {
x.powf(F::cast(2.) - power) / ((F::one() - power) * (F::cast(2.) - power))
});
let middle = &y * &ypred.mapv(|x| x.powf(F::one() - power) / (F::one() - power));
let right = ypred.mapv(|x| x.powf(F::cast(2.) - power) / (F::cast(2.) - power));
Ok((left - middle + right).mapv(|x| F::cast(2.) * x))
}
}
}
/// Computes `-2 * (y - ypred) / unit_variance(ypred)` element-wise.
fn unit_deviance_derivative(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Array1<F> {
    let residual = &y - &ypred;
    let variance = self.unit_variance(ypred);
    let scale = F::cast(-2.);
    (residual / &variance).mapv(|ratio| scale * ratio)
}
/// Returns the sum of the unit deviances of all `(y, ypred)` pairs.
///
/// # Errors
///
/// Propagates any error returned by `unit_deviance`.
pub fn deviance(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Result<F, F> {
    self.unit_deviance(y, ypred)
        .map(|per_sample| per_sample.sum())
}
pub fn deviance_derivative(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Array1<F> {
self.unit_deviance_derivative(y, ypred)
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::array;
// Compile-time check: the distribution must be `Send + Sync + Sized + Unpin`.
#[test]
fn autotraits() {
    fn assert_auto_traits<T>()
    where
        T: Send + Sync + Sized + Unpin,
    {
    }

    assert_auto_traits::<TweedieDistribution<f64>>();
}
// A power strictly between 0 and 1 must be rejected by the constructor.
#[test]
fn test_distribution_error() {
    assert!(TweedieDistribution::new(0.2).is_err());
}
// Generates one `#[test]` per entry of the form
// `name: (distribution, values, expected),` asserting that
// `distribution.in_range(values)` equals `expected`.
macro_rules! test_bounds {
    ($($test_name:ident: ($distribution:expr, $values:expr, $in_range:expr),)*) => {
        $(
            #[test]
            #[allow(clippy::bool_assert_comparison)]
            fn $test_name() {
                let distribution = $distribution;
                let values = $values;
                assert_eq!(distribution.in_range(&values.view()), $in_range);
            }
        )*
    };
}
// Range checks per power. Each entry is `(distribution, values, expected in_range result)`.
test_bounds! {
// power = 0: negative values are accepted.
test_bounds_normal: (TweedieDistribution::new(0.).unwrap(), array![-1., 0., 1.], true),
// 1 <= power < 2: zero is accepted, negative values are not.
test_bounds_poisson1: (TweedieDistribution::new(1.).unwrap(), array![-1., 0., 1.], false),
test_bounds_poisson2: (TweedieDistribution::new(1.).unwrap(), array![0., 1., 2.], true),
test_bounds_tweedie1: (TweedieDistribution::new(1.5).unwrap(), array![-1., 0., 1.], false),
test_bounds_tweedie2: (TweedieDistribution::new(1.5).unwrap(), array![0., 1., 4.], true),
// power >= 2: only strictly positive values are accepted.
test_bounds_gamma1: (TweedieDistribution::new(2.).unwrap(), array![-1., 0., 1.], false),
test_bounds_gamma2: (TweedieDistribution::new(2.).unwrap(), array![0., 1., 2.], false),
test_bounds_gamma3: (TweedieDistribution::new(2.).unwrap(), array![1., 2., 3.], true),
test_bounds_inverse_gaussian: (TweedieDistribution::new(3.).unwrap(), array![-1., 0., 1.], false),
test_bounds_tweedie3: (TweedieDistribution::new(3.5).unwrap(), array![-1., 0., 1.], false),
}
// Generates one `#[test]` per entry of the form `name: (distribution, values),`
// asserting that the deviance of `values` against itself is zero (within 1e-9).
macro_rules! test_deviance {
    ($($test_name:ident: ($distribution:expr, $values:expr),)*) => {
        $(
            #[test]
            fn $test_name() {
                let distribution = $distribution;
                let values = $values;
                let total = distribution
                    .deviance(values.view(), values.view())
                    .unwrap();
                assert_abs_diff_eq!(total, 0.0, epsilon = 1e-9);
            }
        )*
    }
}
// Deviance of a sample against itself, one case per power.
// Only the Normal case (power = 0) includes negative values.
test_deviance! {
test_deviance_normal: (TweedieDistribution::new(0.).unwrap(), array![-1.5, -0.1, 0.1, 2.5]),
test_deviance_poisson: (TweedieDistribution::new(1.).unwrap(), array![0.1, 1.5]),
test_deviance_gamma: (TweedieDistribution::new(2.).unwrap(), array![0.1, 1.5]),
test_deviance_inverse_gaussian: (TweedieDistribution::new(3.).unwrap(), array![0.1, 1.5]),
test_deviance_tweedie1: (TweedieDistribution::new(-2.5).unwrap(), array![0.1, 1.5]),
test_deviance_tweedie2: (TweedieDistribution::new(-1.).unwrap(), array![0.1, 1.5]),
test_deviance_tweedie3: (TweedieDistribution::new(1.5).unwrap(), array![0.1, 1.5]),
test_deviance_tweedie4: (TweedieDistribution::new(2.5).unwrap(), array![0.1, 1.5]),
test_deviance_tweedie5: (TweedieDistribution::new(-4.).unwrap(), array![0.1, 1.5]),
}
// Generates one `#[test]` per entry of the form
// `name: { dist: .., y: .., ypred: .., expected: .., },` comparing
// `dist.deviance_derivative(y, ypred)` with `expected` (absolute tolerance 1e-6).
// Both arrays are printed so they show up in the output of a failing test.
macro_rules! test_deviance_derivative {
    ($($test_name:ident: {dist: $distribution:expr, y: $observed:expr, ypred: $predicted:expr, expected: $reference:expr,},)*) => {
        $(
            #[test]
            fn $test_name() {
                let gradient =
                    $distribution.deviance_derivative($observed.view(), $predicted.view());
                let reference = $reference;
                println!("{:?}", reference);
                println!("{:?}", gradient);
                assert_abs_diff_eq!(gradient, reference, epsilon = 1e-6);
            }
        )*
    };
}
// Derivative checks, one case per power. Every case reuses the same `y` and
// `ypred` samples; only the power and the expected values differ.
test_deviance_derivative! {
// power = 0: the variance term is ypred^0 = 1, so expected = -2 * (y - ypred).
test_derivative_normal: {
dist: TweedieDistribution::new(0.).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
1.58345008, 1.05778984, 1.13608912, 1.85119328, 0.14207212, 0.1742586, 0.04043679,
1.66523969, 1.5563135, 1.7400243
],
},
// power = 1.
test_derivative_poisson: {
dist: TweedieDistribution::new(1.).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
0.91318817, 0.64596835, 0.72628385, 0.99317135, 0.15996728, 0.15469509, 0.047503,
0.78629364, 0.72886335, 1.05654829
],
},
// power = 2.
test_derivative_gamma: {
dist: TweedieDistribution::new(2.).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
0.52664283, 0.39447827, 0.46430181, 0.53283973, 0.18011648, 0.13732793, 0.05580401,
0.37127249, 0.34134625, 0.64153949
],
},
// power = 3.
test_derivative_inverse_gaussian: {
dist: TweedieDistribution::new(3.).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
0.30371908, 0.24089896, 0.29682082, 0.28587029, 0.20280364, 0.12191052, 0.06555559,
0.17530761, 0.1598616, 0.38954482
],
},
// power = -2.5.
test_derivative_tweedie1: {
dist: TweedieDistribution::new(-2.5).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
6.26923606,
3.62969199,
3.47678178,
8.78052969,
0.10560953,
0.23468666,
0.02703435,
10.86942904,
10.36870504,
6.05647896
],
},
// power = -1.
test_derivative_tweedie2: {
dist: TweedieDistribution::new(-1.).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
2.74567086, 1.73215816, 1.77712679, 3.45047865, 0.12617885, 0.1962962, 0.03442171,
3.52670184, 3.32313557, 2.86563764
],
},
// power = 1.5.
test_derivative_tweedie3: {
dist: TweedieDistribution::new(1.5).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
0.69348684, 0.50479746, 0.58070208, 0.72746214, 0.16974317, 0.14575307, 0.05148648,
0.54030473, 0.49879331, 0.8232967
],
},
// power = 2.5.
test_derivative_tweedie4: {
dist: TweedieDistribution::new(2.5).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
0.39993934, 0.3082684, 0.37123368, 0.39028586, 0.19112372, 0.1293898, 0.06048359,
0.25512133, 0.23359829, 0.49990837
],
},
// power = -4.
test_derivative_tweedie5: {
dist: TweedieDistribution::new(-4.).unwrap(),
y: array![
0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
1.28521452, 1.35710428, 0.77688304
],
ypred: array![
1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
2.11783437, 2.13526103, 1.64689519
],
expected: array![
1.43146513e+01,
7.60592435e+00,
6.80199725e+00,
2.23440599e+01,
8.83933634e-02,
2.80585306e-01,
2.12324135e-02,
3.34999932e+01,
3.23519886e+01,
1.28002707e+01
],
},
}
}