// Source: ndarray_glm/response/linear.rs
//! Functions for solving linear regression

#[cfg(feature = "stats")]
use crate::response::Response;
use crate::{
    error::{RegressionError, RegressionResult},
    glm::{DispersionType, Glm},
    link::Link,
    num::Float,
    response::Yval,
};
use num_traits::ToPrimitive;
#[cfg(feature = "stats")]
use statrs::distribution::Normal;
use std::marker::PhantomData;

17/// Linear regression with constant variance (Ordinary least squares).
18pub struct Linear<L = link::Id>
19where
20    L: Link<Linear<L>>,
21{
22    _link: PhantomData<L>,
23}
25/// Allow all floating point types in the linear model.
26impl<Y, L> Yval<Linear<L>> for Y
27where
28    Y: Float + ToPrimitive + ToString,
29    L: Link<Linear<L>>,
30{
31    fn into_float<F: Float>(self) -> RegressionResult<F, F> {
32        F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
33    }
34}
#[cfg(feature = "stats")]
impl<L> Response for Linear<L>
where
    L: Link<Linear<L>>,
{
    type DistributionType = Normal;

    /// Response distribution for a fitted observation: normal with mean `mu`
    /// and variance `phi` (the dispersion parameter).
    fn get_distribution(mu: f64, phi: f64) -> Self::DistributionType {
        // phi is the variance, so the standard deviation is its square root.
        let sigma = phi.sqrt();
        // TODO: return an error instead of panicking. A sigma of zero is
        // possible (e.g. in an underspecified model), and the statrs errors
        // aren't unified so we can't implement a simple #[from] for our error
        // enum; each implementation would need its own map_err. Until then,
        // state the invariant explicitly rather than using a bare unwrap().
        Normal::new(mu, sigma)
            .expect("normal response distribution requires finite mu and sigma > 0")
    }
}
55impl<L> Glm for Linear<L>
56where
57    L: Link<Linear<L>>,
58{
59    type Link = L;
60    const DISPERSED: DispersionType = DispersionType::FreeDispersion;
61
62    /// Logarithm of the partition function in terms of the natural parameter,
63    /// which is mu for OLS.
64    fn log_partition<F: Float>(nat_par: F) -> F {
65        let half = F::from(0.5).unwrap();
66        half * nat_par * nat_par
67    }
68
69    /// variance is not a function of the mean in OLS regression.
70    fn variance<F: Float>(_mean: F) -> F {
71        F::one()
72    }
73
74    /// The saturated model likelihood is 0.5*y^2 for each observation. Note
75    /// that if a sum of squares were used for the log-likelihood, this would be
76    /// zero.
77    fn log_like_sat<F: Float>(y: F) -> F {
78        // Only for linear regression does this identity hold.
79        Self::log_partition(y)
80    }
81}
83pub mod link {
84    //! Link functions for linear regression.
85    use super::*;
86    use crate::link::{Canonical, Link};
87
88    /// The identity link function, which is canonical for linear regression.
89    pub struct Id;
90    /// The identity is the canonical link function.
91    impl Canonical for Id {}
92    impl Link<Linear> for Id {
93        #[inline]
94        fn func<F: Float>(y: F) -> F {
95            y
96        }
97        #[inline]
98        fn func_inv<F: Float>(lin_pred: F) -> F {
99            lin_pred
100        }
101    }
102}
#[cfg(test)]
mod tests {
    use super::Linear;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fit data generated exactly from a known linear relationship and check
    /// that the recovered coefficients match the generating ones.
    #[test]
    fn lin_reg() -> RegressionResult<(), f64> {
        let beta = array![0.3, 1.2, -0.5];
        let data_x = array![[-0.1, 0.2], [0.7, 0.5], [3.2, 0.1]];
        // Responses built exactly from beta (intercept + two covariates), so
        // the fit should recover beta to near machine precision.
        let data_y = array![
            beta[0] + beta[1] * data_x[[0, 0]] + beta[2] * data_x[[0, 1]],
            beta[0] + beta[1] * data_x[[1, 0]] + beta[2] * data_x[[1, 1]],
            beta[0] + beta[1] * data_x[[2, 0]] + beta[2] * data_x[[2, 1]],
        ];
        let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
        let fit = model.fit_options().max_iter(10).fit()?;
        // The default comparison tolerance is too tight for this fit; allow
        // a small number of ULPs.
        assert_abs_diff_eq!(beta, fit.result, epsilon = 64.0 * f64::EPSILON);
        // Smoke-check that the likelihood-ratio test runs on a linear fit;
        // its exact value is not pinned here.
        let _lr: f64 = fit.lr_test();
        Ok(())
    }
}
131}