// ndarray_glm/response/linear.rs
use crate::{
    error::{RegressionError, RegressionResult},
    glm::{DispersionType, Glm},
    link::Link,
    num::Float,
    response::Response,
};
use num_traits::ToPrimitive;
use std::marker::PhantomData;

13pub struct Linear<L = link::Id>
15where
16 L: Link<Linear<L>>,
17{
18 _link: PhantomData<L>,
19}
20
21impl<Y, L> Response<Linear<L>> for Y
23where
24 Y: Float + ToPrimitive + ToString,
25 L: Link<Linear<L>>,
26{
27 fn into_float<F: Float>(self) -> RegressionResult<F> {
28 F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
29 }
30}
31
32impl<L> Glm for Linear<L>
33where
34 L: Link<Linear<L>>,
35{
36 type Link = L;
37 const DISPERSED: DispersionType = DispersionType::FreeDispersion;
38
39 fn log_partition<F: Float>(nat_par: F) -> F {
42 let half = F::from(0.5).unwrap();
43 half * nat_par * nat_par
44 }
45
46 fn variance<F: Float>(_mean: F) -> F {
48 F::one()
49 }
50
51 fn log_like_sat<F: Float>(y: F) -> F {
55 Self::log_partition(y)
57 }
58}
59
60pub mod link {
61 use super::*;
63 use crate::link::{Canonical, Link};
64
65 pub struct Id;
67 impl Canonical for Id {}
69 impl Link<Linear> for Id {
70 #[inline]
71 fn func<F: Float>(y: F) -> F {
72 y
73 }
74 #[inline]
75 fn func_inv<F: Float>(lin_pred: F) -> F {
76 lin_pred
77 }
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::Linear;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fits noiseless linear data and checks the generating coefficients are
    /// recovered.
    #[test]
    fn lin_reg() -> RegressionResult<()> {
        // Intercept first, then one coefficient per column of `data_x`.
        let beta = array![0.3, 1.2, -0.5];
        let data_x = array![[-0.1, 0.2], [0.7, 0.5], [3.2, 0.1]];
        // Exact (noise-free) response for a given row of `data_x`.
        let exact_y =
            |row: usize| beta[0] + beta[1] * data_x[[row, 0]] + beta[2] * data_x[[row, 1]];
        let data_y = array![exact_y(0), exact_y(1), exact_y(2)];

        // `model` stays bound to a local because the fit borrows from it.
        let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
        let fit = model.fit_options().max_iter(10).fit()?;
        dbg!(fit.n_iter);
        // Explicit tolerance of 64 machine epsilons.
        assert_abs_diff_eq!(beta, fit.result, epsilon = 64.0 * f64::EPSILON);

        let lr: f64 = fit.lr_test();
        dbg!(&lr);
        dbg!(&lr.sqrt());
        Ok(())
    }
}