Skip to main content

ndarray_glm/response/
exponential.rs

//! Exponential response, where y is drawn from the exponential distribution.
2
3#[cfg(feature = "stats")]
4use crate::response::Response;
5use crate::{
6    error::{RegressionError, RegressionResult},
7    glm::{DispersionType, Glm},
8    link::Link,
9    num::Float,
10    response::Yval,
11};
12use ndarray::Array1;
13#[cfg(feature = "stats")]
14use statrs::distribution::Exp;
15use std::marker::PhantomData;
16
17/// Exponential regression
18pub struct Exponential<L = link::NegRec>
19where
20    L: Link<Exponential<L>>,
21{
22    _link: PhantomData<L>,
23}
24
25// Allow floats for the domain. We can't use num_traits::Float because of the
26// possibility of conflicting implementations upstream, so manually implement
27// for f32 and f64. Note that for exponential regression, y=0 is invalid.
28impl<L> Yval<Exponential<L>> for f32
29where
30    L: Link<Exponential<L>>,
31{
32    fn into_float<F: Float>(self) -> RegressionResult<F, F> {
33        if self <= 0. {
34            return Err(RegressionError::InvalidY(self.to_string()));
35        }
36        F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
37    }
38}
39impl<L> Yval<Exponential<L>> for f64
40where
41    L: Link<Exponential<L>>,
42{
43    fn into_float<F: Float>(self) -> RegressionResult<F, F> {
44        if self <= 0. {
45            return Err(RegressionError::InvalidY(self.to_string()));
46        }
47        F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
48    }
49}
50
#[cfg(feature = "stats")]
impl<L> Response for Exponential<L>
where
    L: Link<Exponential<L>>,
{
    type DistributionType = Exp;

    /// Build the exponential distribution with mean `mu`.
    ///
    /// `_phi` is ignored: this response is declared `NoDispersion`, so the mean alone
    /// determines the distribution.
    fn get_distribution(mu: f64, _phi: f64) -> Self::DistributionType {
        // NOTE: Negative mu is a realistic concern for exponential regression, since the canonical
        // link function does not prevent them. Without complicating the return type, either with
        // dynamic dispatch or an enum between Exp and Dirac that would have to forward every
        // Distribution<f64> and ContinuousCDF<f64> calls, the simplest way to ensure mu > 0 is
        // to clamp at the lowest positive normal value (~2.2e-308).
        //
        // mu must also be clamped from above: an infinite mean (e.g. from exp() overflow under the
        // log link) would give a rate of exactly 0, which Exp::new rejects. 1/f64::MAX is a tiny
        // but strictly positive rate. f64::max/min return the non-NaN operand, so a NaN mu
        // becomes f64::MIN_POSITIVE, matching the previous behavior.
        let mean = mu.max(f64::MIN_POSITIVE).min(f64::MAX);
        // statrs parameterizes by rate = 1/mean (mean = 1/rate).
        Exp::new(mean.recip())
            .expect("rate is finite and strictly positive once mu is clamped to [MIN_POSITIVE, MAX]")
    }
}
68
69/// Implementation of GLM functionality for exponential regression.
70impl<L> Glm for Exponential<L>
71where
72    L: Link<Exponential<L>>,
73{
74    type Link = L;
75    const DISPERSED: DispersionType = DispersionType::NoDispersion;
76
77    /// The log-partition function $`A(\eta)`$ for the exponential family, expressed in terms
78    /// of the canonical natural parameter $`\eta = -1/\mu`$:
79    ///
80    /// ```math
81    /// A(\eta) = -\ln(-\eta)
82    /// ```
83    fn log_partition<F: Float>(nat_par: F) -> F {
84        -num_traits::Float::ln(-nat_par)
85    }
86
87    /// The variance function $`V(\mu) = \mu^2`$, equal to $`A''(\eta)`$ evaluated at
88    /// $`\eta = -1/\mu`$.
89    fn variance<F: Float>(mean: F) -> F {
90        mean * mean
91    }
92
93    /// The saturated likelihood is -1 - log(y). This shows part of why exponential regression
94    /// can't deal with y=0.
95    fn log_like_sat<F: Float>(y: F) -> F {
96        -(F::one() + num_traits::Float::ln(y))
97    }
98}
99
100pub mod link {
101    //! Link functions for exponential regression
102    use super::*;
103    use crate::link::{Canonical, Link, Transform};
104    use crate::num::Float;
105
106    /// The canonical link function for exponential regression is the negative reciprocal
107    /// $`\eta = -1/mu`$. This fails to prevent negative predicted y-values.
108    pub struct NegRec {}
109    impl Canonical for NegRec {}
110    impl Link<Exponential<NegRec>> for NegRec {
111        fn func<F: Float>(y: F) -> F {
112            -num_traits::Float::recip(y)
113        }
114        fn func_inv<F: Float>(lin_pred: F) -> F {
115            -num_traits::Float::recip(lin_pred)
116        }
117    }
118
119    /// The log link $`g(\mu) = \log(\mu)`$ avoids linear predictors that give negative
120    /// expectations.
121    pub struct Log {}
122    impl Link<Exponential<Log>> for Log {
123        fn func<F: Float>(y: F) -> F {
124            num_traits::Float::ln(y)
125        }
126        fn func_inv<F: Float>(lin_pred: F) -> F {
127            num_traits::Float::exp(lin_pred)
128        }
129    }
130    impl Transform for Log {
131        fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
132            lin_pred.mapv(|x| -num_traits::Float::exp(-x))
133        }
134        fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
135            lin_pred.mapv(|x| num_traits::Float::exp(-x))
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Canonical-link fit on two groups whose MLE is known in closed form.
    ///
    /// With the NegRec link the MLE satisfies β₀ = -1/ȳ₀ and β₀ + β₁ = -1/ȳ₁, where ȳ₀ and ȳ₁
    /// are the within-group sample means. Group means of 2 and 4 give β = [-0.5, 0.25], both
    /// exactly representable in f64.
    #[test]
    fn exp_ex() -> RegressionResult<(), f64> {
        // x = 0: y ∈ {1, 3}    → ȳ₀ = 2, β₀ = -1/2 = -0.5
        // x = 1: y ∈ {2, 4, 6} → ȳ₁ = 4, β₁ = -1/4 - β₀ = 0.25
        let covariates = array![[0.], [0.], [1.0], [1.0], [1.0]];
        let responses = array![1.0, 3.0, 2.0, 4.0, 6.0];
        let expected = array![-0.5, 0.25];

        let model = ModelBuilder::<Exponential>::data(&responses, &covariates).build()?;
        let fitted = model.fit()?;
        assert_abs_diff_eq!(expected, fitted.result, epsilon = 0.5 * f32::EPSILON as f64);
        // Only checks that the covariance can be computed.
        let _covariance = fitted.covariance()?;
        Ok(())
    }

    /// Same grouped data as `exp_ex`, fit with the Log link.
    ///
    /// With g(μ) = log(μ) the MLE satisfies β₀ = log(ȳ₀) and β₁ = log(ȳ₁/ȳ₀). With ȳ₀ = 2 and
    /// ȳ₁ = 4 this gives β = [ln 2, ln 2].
    #[test]
    fn exp_log_link_ex() -> RegressionResult<(), f64> {
        // x = 0: y ∈ {1, 3}    → ȳ₀ = 2, β₀ = ln(2)
        // x = 1: y ∈ {2, 4, 6} → ȳ₁ = 4, β₀ + β₁ = ln(4) so β₁ = ln(2)
        let covariates = array![[0.], [0.], [1.0], [1.0], [1.0]];
        let responses = array![1.0, 3.0, 2.0, 4.0, 6.0];
        let log_two = 2f64.ln();
        let expected = array![log_two, log_two];

        let model =
            ModelBuilder::<Exponential<link::Log>>::data(&responses, &covariates).build()?;
        let fitted = model.fit()?;
        assert_abs_diff_eq!(expected, fitted.result, epsilon = 0.5 * f32::EPSILON as f64);
        // Only checks that the covariance can be computed.
        let _covariance = fitted.covariance()?;
        Ok(())
    }

    /// The negative reciprocal link must be its own inverse on both the linear-predictor side
    /// and the mean side.
    #[test]
    fn neg_rec_closure() {
        use super::link::NegRec;
        use crate::link::TestLink;
        // Positive linear predictors are poor values in practice (they imply negative means),
        // but the canonical transformation should still be closed on them.
        let lin_preds = array![-360., -12., -5., -1.0, -0.002, 0., 0.5, 20.];
        NegRec::check_closure(&lin_preds);
        let means = array![1e-5, 0.25, 0.8, 2.5, 10., 256.];
        NegRec::check_closure_y(&means);
    }

    /// The log link and its inverse must round-trip on both sides.
    #[test]
    fn log_closure() {
        use crate::link::TestLink;
        use link::Log;
        let means = array![1e-8, 0.01, 0.1, 0.3, 0.9, 1.8, 4.2, 148.];
        Log::check_closure_y(&means);
        let lin_preds = array![1e-8, 0.002, 0.5, 2.4, 15., 120.];
        Log::check_closure(&lin_preds);
    }

    /// The Log transform's natural parameter (and its derivative) must agree with composing the
    /// canonical link with the inverse log link:
    /// nat_param(ω) = -exp(-ω) = NegRec(exp(ω)) = -1/exp(ω).
    #[test]
    fn log_nat_par() {
        use crate::link::TestLink;
        use link::Log;
        let lin_preds = array![-10., -2., -0.5, 0.0, 0.5, 2., 10.];
        Log::check_nat_par::<Exponential<link::NegRec>>(&lin_preds);
        Log::check_nat_par_d(&lin_preds);
    }
}