ndarray_glm/response/
inverse_gaussian.rs1use crate::{
4 error::{RegressionError, RegressionResult},
5 glm::{DispersionType, Glm},
6 link::Link,
7 num::Float,
8 response::Yval,
9};
10use ndarray::Array1;
11use std::marker::PhantomData;
12
13pub struct InvGaussian<L = link::NegRecSq>
15where
16 L: Link<InvGaussian<L>>,
17{
18 _link: PhantomData<L>,
19}
20
21impl<L> Yval<InvGaussian<L>> for f32
23where
24 L: Link<InvGaussian<L>>,
25{
26 fn into_float<F: Float>(self) -> RegressionResult<F, F> {
27 if self <= 0. {
28 return Err(RegressionError::InvalidY(self.to_string()));
29 }
30 F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
31 }
32}
33impl<L> Yval<InvGaussian<L>> for f64
34where
35 L: Link<InvGaussian<L>>,
36{
37 fn into_float<F: Float>(self) -> RegressionResult<F, F> {
38 if self <= 0. {
39 return Err(RegressionError::InvalidY(self.to_string()));
40 }
41 F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
42 }
43}
44
45impl<L> Glm for InvGaussian<L>
50where
51 L: Link<InvGaussian<L>>,
52{
53 type Link = L;
54 const DISPERSED: DispersionType = DispersionType::FreeDispersion;
55
56 fn log_partition<F: Float>(nat_par: F) -> F {
63 -num_traits::Float::sqrt(-F::two() * nat_par)
64 }
65
66 fn variance<F: Float>(mean: F) -> F {
69 mean * mean * mean
70 }
71
72 fn log_like_sat<F: Float>(y: F) -> F {
75 F::half() * num_traits::Float::recip(y)
76 }
77}
78
79pub mod link {
80 use super::*;
82 use crate::link::{Canonical, Link, Transform};
83 use crate::num::Float;
84
85 pub struct NegRecSq {}
88 impl Canonical for NegRecSq {}
89 impl Link<InvGaussian<NegRecSq>> for NegRecSq {
90 fn func<F: Float>(y: F) -> F {
91 -num_traits::Float::recip(F::two() * y * y)
92 }
93 fn func_inv<F: Float>(lin_pred: F) -> F {
94 num_traits::Float::recip(num_traits::Float::sqrt(-F::two() * lin_pred))
95 }
96 }
97
98 pub struct Log {}
101 impl Link<InvGaussian<Log>> for Log {
102 fn func<F: Float>(y: F) -> F {
103 num_traits::Float::ln(y)
104 }
105 fn func_inv<F: Float>(lin_pred: F) -> F {
106 num_traits::Float::exp(lin_pred)
107 }
108 }
109 impl Transform for Log {
110 fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
111 lin_pred.mapv(|x| -F::half() * num_traits::Float::exp(-F::two() * x))
113 }
114 fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
115 lin_pred.mapv(|x| num_traits::Float::exp(-F::two() * x))
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1, Array2};

    /// Absolute tolerance used when comparing fitted coefficients.
    const TOL: f64 = 0.5 * f32::EPSILON as f64;

    /// Two groups: x = 0 with y ∈ {1, 3} (mean 2) and x = 1 with
    /// y ∈ {2, 4, 6} (mean 4). With one indicator covariate (plus, judging by
    /// the two-element beta, an intercept) the fit reproduces each group mean.
    fn two_group_data() -> (Array1<f64>, Array2<f64>) {
        let y = array![1.0, 3.0, 2.0, 4.0, 6.0];
        let x = array![[0.], [0.], [1.0], [1.0], [1.0]];
        (y, x)
    }

    #[test]
    fn ig_ex() -> RegressionResult<(), f64> {
        // Canonical link θ = −1/(2μ²): μ = 2 → −0.125, μ = 4 → −0.03125,
        // so the slope is −0.03125 − (−0.125) = 0.09375.
        let expected = array![-0.125, 0.09375];
        let (y, x) = two_group_data();
        let fit = ModelBuilder::<InvGaussian>::data(&y, &x).build()?.fit()?;
        assert_abs_diff_eq!(expected, fit.result, epsilon = TOL);
        let _cov = fit.covariance()?;
        Ok(())
    }

    #[test]
    fn ig_log_link_ex() -> RegressionResult<(), f64> {
        // Log link: intercept ln 2, slope ln 4 − ln 2 = ln 2.
        let ln2 = f64::ln(2.);
        let expected = array![ln2, ln2];
        let (y, x) = two_group_data();
        let fit = ModelBuilder::<InvGaussian<link::Log>>::data(&y, &x)
            .build()?
            .fit()?;
        assert_abs_diff_eq!(expected, fit.result, epsilon = TOL);
        let _cov = fit.covariance()?;
        Ok(())
    }

    #[test]
    fn neg_rec_sq_closure() {
        use super::link::NegRecSq;
        use crate::link::TestLink;
        // Linear-predictor values must be negative for the inverse link to be real.
        let lin_vals = array![-360., -12., -5., -1.0, -1e-4];
        NegRecSq::check_closure(&lin_vals);
        let mean_vals = array![1e-5, 0.25, 0.8, 2.5, 10., 256.];
        NegRecSq::check_closure_y(&mean_vals);
    }

    #[test]
    fn log_closure() {
        use crate::link::TestLink;
        use link::Log;
        let mean_vals = array![1e-8, 0.01, 0.1, 0.3, 0.9, 1.8, 4.2, 148.];
        Log::check_closure_y(&mean_vals);
        let lin_vals = array![1e-8, 0.002, 0.5, 2.4, 15., 120.];
        Log::check_closure(&lin_vals);
    }

    #[test]
    fn log_nat_par() {
        use crate::link::TestLink;
        use link::Log;
        let lin_vals = array![-10., -2., -0.5, 0.0, 0.5, 2., 10.];
        // Compared against the canonical-link model's natural parameter.
        Log::check_nat_par::<InvGaussian<link::NegRecSq>>(&lin_vals);
        Log::check_nat_par_d(&lin_vals);
    }
}