ndarray_glm/response/linear.rs

//! Functions for solving linear regression

use crate::{
    error::{RegressionError, RegressionResult},
    glm::{DispersionType, Glm},
    link::Link,
    num::Float,
    response::Response,
};
use num_traits::ToPrimitive;
use std::marker::PhantomData;

/// Linear regression with constant variance (ordinary least squares).
pub struct Linear<L = link::Id>
where
    L: Link<Linear<L>>,
{
    _link: PhantomData<L>,
}

/// Allow all floating point types in the linear model.
impl<Y, L> Response<Linear<L>> for Y
where
    Y: Float + ToPrimitive + ToString,
    L: Link<Linear<L>>,
{
    fn into_float<F: Float>(self) -> RegressionResult<F> {
        F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
    }
}

impl<L> Glm for Linear<L>
where
    L: Link<Linear<L>>,
{
    type Link = L;
    const DISPERSED: DispersionType = DispersionType::FreeDispersion;

    /// Logarithm of the partition function in terms of the natural parameter,
    /// which equals the mean mu for OLS.
    fn log_partition<F: Float>(nat_par: F) -> F {
        let half = F::from(0.5).unwrap();
        half * nat_par * nat_par
    }
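
    // For an exponential family the mean is the first derivative of the log
    // partition function with respect to the natural parameter, and the
    // variance function is its second derivative. With A(eta) = 0.5 * eta^2
    // the mean is A'(eta) = eta and A''(eta) = 1, which is why `variance`
    // below is identically one.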

    /// The variance is not a function of the mean in OLS regression.
    fn variance<F: Float>(_mean: F) -> F {
        F::one()
    }

    /// The saturated model log-likelihood is 0.5*y^2 for each observation:
    /// the per-observation contribution y*eta - A(eta), evaluated at eta = y,
    /// is y^2 - 0.5*y^2 = 0.5*y^2. Note that if a sum of squares were used
    /// for the log-likelihood, this would be zero.
    fn log_like_sat<F: Float>(y: F) -> F {
        // Only for linear regression does this identity hold.
        Self::log_partition(y)
    }
}
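
// With the definitions above, the unit deviance of an observation is
// 2 * (log_like_sat(y) - (y * eta - log_partition(eta))) = (y - mu)^2 under
// the identity link, so the model deviance reduces to the residual sum of
// squares, as expected for ordinary least squares.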

pub mod link {
    //! Link functions for linear regression.
    use super::*;
    use crate::link::{Canonical, Link};

    /// The identity link function, which is canonical for linear regression.
    pub struct Id;
    /// The identity is the canonical link function.
    impl Canonical for Id {}
    impl Link<Linear> for Id {
        #[inline]
        fn func<F: Float>(y: F) -> F {
            y
        }
        #[inline]
        fn func_inv<F: Float>(lin_pred: F) -> F {
            lin_pred
        }
    }
}

#[cfg(test)]
mod tests {
    use super::Linear;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn lin_reg() -> RegressionResult<()> {
        let beta = array![0.3, 1.2, -0.5];
        let data_x = array![[-0.1, 0.2], [0.7, 0.5], [3.2, 0.1]];
        // let data_x = array![[-0.1, 0.1], [0.7, -0.7], [3.2, -3.2]];
        let data_y = array![
            beta[0] + beta[1] * data_x[[0, 0]] + beta[2] * data_x[[0, 1]],
            beta[0] + beta[1] * data_x[[1, 0]] + beta[2] * data_x[[1, 1]],
            beta[0] + beta[1] * data_x[[2, 0]] + beta[2] * data_x[[2, 1]],
        ];
        let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
        let fit = model.fit_options().max_iter(10).fit()?;
        dbg!(fit.n_iter);
        // The comparison fails at the default tolerance, so use a looser
        // epsilon.
        assert_abs_diff_eq!(beta, fit.result, epsilon = 64.0 * f64::EPSILON);
        let lr: f64 = fit.lr_test();
        dbg!(&lr);
        dbg!(&lr.sqrt());
        Ok(())
    }
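
    /// Numerically check the exponential-family identities noted above: the
    /// log partition is 0.5 * eta^2, the saturated log-likelihood equals the
    /// log partition evaluated at y, and the variance function is constant.
    /// This test is an illustrative sketch added for exposition; it assumes
    /// only the `Glm` associated functions defined in this file.
    #[test]
    fn lin_log_partition() {
        use super::link::Id;
        use crate::glm::Glm;
        let y: f64 = 1.5;
        assert_abs_diff_eq!(Linear::<Id>::log_partition(y), 0.5 * y * y);
        assert_abs_diff_eq!(Linear::<Id>::log_like_sat(y), Linear::<Id>::log_partition(y));
        assert_abs_diff_eq!(Linear::<Id>::variance(y), 1.0);
    }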
}