ndarray_glm/response/
binomial.rs1#[cfg(feature = "stats")]
3use crate::response::Response;
4use crate::{
5 error::{RegressionError, RegressionResult},
6 glm::{DispersionType, Glm},
7 math::prod_log,
8 num::Float,
9 response::Yval,
10};
11#[cfg(feature = "stats")]
12use statrs::distribution::Binomial as BinDist;
13
/// Integer type used both for binomial response values and for the trial count `N`.
type BinDom = u16;

/// Binomial response family with a fixed, compile-time number of trials `N`.
///
/// Each observation is a count of successes out of `N` trials. The type is a
/// zero-sized marker, so it is cheap to copy and has a trivial default value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Binomial<const N: BinDom>;
21
22impl<const N: BinDom> Yval<Binomial<N>> for BinDom {
23 fn into_float<F: Float>(self) -> RegressionResult<F, F> {
24 F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
25 }
26}
27
#[cfg(feature = "stats")]
impl<const N: BinDom> Response for Binomial<N> {
    type DistributionType = BinDist;

    /// Builds the `statrs` binomial distribution with `N` trials whose mean
    /// count is `mu`, i.e. success probability `p = mu / N`.
    ///
    /// `_phi` is ignored because the binomial family has no dispersion parameter.
    ///
    /// # Panics
    /// Panics if `mu / N` is not a valid probability (for example, when `mu`
    /// is negative, greater than `N`, or NaN), since `statrs` rejects such a
    /// parameter.
    fn get_distribution(mu: f64, _phi: f64) -> Self::DistributionType {
        // `u16 -> f64` is lossless, so the infallible `From` conversion is used
        // rather than a fallible cast followed by `unwrap`.
        let p = mu / f64::from(N);
        BinDist::new(p, N.into())
            .expect("binomial mean must lie in [0, N] to give a valid success probability")
    }
}
39
40impl<const N: BinDom> Glm for Binomial<N> {
41 type Link = link::Logit;
43 const DISPERSED: DispersionType = DispersionType::NoDispersion;
44
45 fn log_partition<F: Float>(nat_par: F) -> F {
48 let n: F = F::from(N).unwrap();
49 n * num_traits::Float::exp(nat_par).ln_1p()
50 }
51
52 fn variance<F: Float>(mean: F) -> F {
53 let n_float: F = F::from(N).unwrap();
54 mean * (n_float - mean) / n_float
55 }
56
57 fn log_like_sat<F: Float>(y: F) -> F {
58 let n: F = F::from(N).unwrap();
59 prod_log(y) + prod_log(n - y) - prod_log(n)
60 }
61}
62
63pub mod link {
64 use super::*;
65 use crate::link::{Canonical, Link};
66 use num_traits::Float;
67
68 pub struct Logit {}
69 impl Canonical for Logit {}
70 impl<const N: BinDom> Link<Binomial<N>> for Logit {
71 fn func<F: Float>(y: F) -> F {
72 let n_float: F = F::from(N).unwrap();
73 Float::ln(y / (n_float - y))
74 }
75 fn func_inv<F: Float>(lin_pred: F) -> F {
76 let n_float: F = F::from(N).unwrap();
77 n_float / (F::one() + (-lin_pred).exp())
78 }
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::Binomial;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fits binomial counts with N = 12 trials where the exact MLE is known.
    ///
    /// The x = 0 group has mean count 6 = 12/2 (logit 0), and the x = ln 2
    /// group has mean count 8 = 12 * 2/3 (logit ln 2). So the coefficients
    /// `[0, 1]` reproduce both group means exactly. The builder presumably adds
    /// an intercept, since `data_x` has one column and `beta` has two entries.
    #[test]
    fn bin_reg() -> RegressionResult<(), f64> {
        const N: u16 = 12;
        let ln2 = f64::ln(2.);
        let beta = array![0., 1.];
        let data_x = array![[0.], [0.], [ln2], [ln2], [ln2]];
        let data_y = array![5, 7, 9, 6, 9];
        let model = ModelBuilder::<Binomial<N>>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        assert_abs_diff_eq!(beta, fit.result, epsilon = 0.05 * f32::EPSILON as f64);
        Ok(())
    }
}