ndarray_glm/response/
poisson.rs

//! Model for Poisson regression
2
3use crate::{
4    error::{RegressionError, RegressionResult},
5    glm::{DispersionType, Glm},
6    link::Link,
7    math::prod_log,
8    num::Float,
9    response::Response,
10};
11use num_traits::{ToPrimitive, Unsigned};
12use std::marker::PhantomData;
13
14/// Poisson regression over an unsigned integer type.
15pub struct Poisson<L = link::Log>
16where
17    L: Link<Poisson<L>>,
18{
19    _link: PhantomData<L>,
20}
21
22/// Poisson variables can be any unsigned integer.
23impl<U, L> Response<Poisson<L>> for U
24where
25    U: Unsigned + ToPrimitive + ToString + Copy,
26    L: Link<Poisson<L>>,
27{
28    fn into_float<F: Float>(self) -> RegressionResult<F> {
29        F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
30    }
31}
// TODO: A floating-point response for Poisson might also be doable.
33
34impl<L> Glm for Poisson<L>
35where
36    L: Link<Poisson<L>>,
37{
38    type Link = L;
39    const DISPERSED: DispersionType = DispersionType::NoDispersion;
40
41    /// The logarithm of the partition function for Poisson is the exponential of the natural
42    /// parameter, which is the logarithm of the mean.
43    fn log_partition<F: Float>(nat_par: F) -> F {
44        num_traits::Float::exp(nat_par)
45    }
46
47    /// The variance of a Poisson variable is equal to its mean.
48    fn variance<F: Float>(mean: F) -> F {
49        mean
50    }
51
52    /// The saturation likelihood of the Poisson distribution is non-trivial.
53    /// It is equal to y * (log(y) - 1).
54    fn log_like_sat<F: Float>(y: F) -> F {
55        prod_log(y) - y
56    }
57}
58
59pub mod link {
60    //! Link functions for Poisson regression
61    use super::Poisson;
62    use crate::{
63        link::{Canonical, Link},
64        num::Float,
65    };
66
67    /// The canonical link function of the Poisson response is the logarithm.
68    pub struct Log {}
69    impl Canonical for Log {}
70    impl Link<Poisson<Log>> for Log {
71        fn func<F: Float>(y: F) -> F {
72            num_traits::Float::ln(y)
73        }
74        fn func_inv<F: Float>(lin_pred: F) -> F {
75            num_traits::Float::exp(lin_pred)
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1};

    /// Fits a Poisson regression whose MLE is known in closed form.
    ///
    /// The rows fall into three covariate groups: (1, 0) with mean count 2, (1, 1)
    /// with mean count 1, and (0, 1) with mean count 1/2. Three groups are exactly
    /// saturated by three parameters. Assuming the coefficient vector is
    /// [intercept, x1, x2] (an intercept is presumably prepended by the builder),
    /// the solution is beta = [0, ln 2, -ln 2].
    #[test]
    fn poisson_reg() -> RegressionResult<()> {
        let ln2 = f64::ln(2.);
        let beta = array![0., ln2, -ln2];
        let data_x = array![[1., 0.], [1., 1.], [0., 1.], [0., 1.]];
        let data_y: Array1<u32> = array![2, 1, 0, 1];
        let model = ModelBuilder::<Poisson>::data(&data_y, &data_x).build()?;
        let fit = model.fit_options().max_iter(10).fit()?;
        assert_abs_diff_eq!(beta, fit.result, epsilon = f32::EPSILON as f64);
        Ok(())
    }

    /// The log link should map a mean to its logarithm, and its inverse should undo it.
    #[test]
    fn log_link_round_trip() {
        use crate::link::Link;
        for &mu in &[0.25_f64, 1., 2., 10.] {
            let eta = <super::link::Log as Link<Poisson>>::func(mu);
            assert_abs_diff_eq!(eta, mu.ln(), epsilon = 1e-12);
            let back = <super::link::Log as Link<Poisson>>::func_inv(eta);
            assert_abs_diff_eq!(back, mu, epsilon = 1e-12);
        }
    }
}
99}