ndarray_glm/response/
logistic.rs1use crate::{
4 error::{RegressionError, RegressionResult},
5 glm::{DispersionType, Glm},
6 link::Link,
7 math::prod_log,
8 num::Float,
9 response::Response,
10};
11use ndarray::Array1;
12use std::marker::PhantomData;
13
14pub struct Logistic<L = link::Logit>
16where
17 L: Link<Logistic<L>>,
18{
19 _link: PhantomData<L>,
20}
21
22impl<L> Response<Logistic<L>> for bool
24where
25 L: Link<Logistic<L>>,
26{
27 fn into_float<F: Float>(self) -> RegressionResult<F> {
28 Ok(if self { F::one() } else { F::zero() })
29 }
30}
31impl<L> Response<Logistic<L>> for f32
35where
36 L: Link<Logistic<L>>,
37{
38 fn into_float<F: Float>(self) -> RegressionResult<F> {
39 if !(0.0..=1.0).contains(&self) {
40 return Err(RegressionError::InvalidY(self.to_string()));
41 }
42 F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
43 }
44}
45impl<L> Response<Logistic<L>> for f64
46where
47 L: Link<Logistic<L>>,
48{
49 fn into_float<F: Float>(self) -> RegressionResult<F> {
50 if !(0.0..=1.0).contains(&self) {
51 return Err(RegressionError::InvalidY(self.to_string()));
52 }
53 F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
54 }
55}
56
57impl<L> Glm for Logistic<L>
59where
60 L: Link<Logistic<L>>,
61{
62 type Link = L;
63 const DISPERSED: DispersionType = DispersionType::NoDispersion;
64
65 fn log_partition<F: Float>(nat_par: F) -> F {
68 num_traits::Float::exp(nat_par).ln_1p()
69 }
70
71 fn variance<F: Float>(mean: F) -> F {
73 mean * (F::one() - mean)
74 }
75
76 fn log_like_natural<F>(y: F, logit_p: F) -> F
79 where
80 F: Float,
81 {
82 let (yt, xt) = if logit_p < F::zero() {
83 (y, logit_p)
84 } else {
85 (F::one() - y, -logit_p)
86 };
87 yt * xt - num_traits::Float::exp(xt).ln_1p()
88 }
89
90 fn log_like_sat<F: Float>(y: F) -> F {
93 prod_log(y) + prod_log(F::one() - y)
94 }
95}
96
97pub mod link {
98 use super::*;
100 use crate::link::{Canonical, Link, Transform};
101 use crate::num::Float;
102
103 pub struct Logit {}
106 impl Canonical for Logit {}
107 impl Link<Logistic<Logit>> for Logit {
108 fn func<F: Float>(y: F) -> F {
109 num_traits::Float::ln(y / (F::one() - y))
110 }
111 fn func_inv<F: Float>(lin_pred: F) -> F {
112 (F::one() + num_traits::Float::exp(-lin_pred)).recip()
113 }
114 }
115
116 pub struct Cloglog {}
120 impl Link<Logistic<Cloglog>> for Cloglog {
121 fn func<F: Float>(y: F) -> F {
122 num_traits::Float::ln(-F::ln_1p(-y))
123 }
124 fn func_inv<F: Float>(lin_pred: F) -> F {
126 -F::exp_m1(-num_traits::Float::exp(lin_pred))
127 }
128 }
129 impl Transform for Cloglog {
130 fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
131 lin_pred.mapv(|x| num_traits::Float::ln(num_traits::Float::exp(x).exp_m1()))
132 }
133 fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
134 let neg_exp_lin = -lin_pred.mapv(num_traits::Float::exp);
135 &neg_exp_lin / &neg_exp_lin.mapv(F::exp_m1)
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fits a tiny dataset whose maximum-likelihood solution is known exactly.
    ///
    /// At x = 0, one of two outcomes is true, so p = 1/2 and the intercept is 0.
    /// At x = ln 2, two of three outcomes are true, so p = 2/3 and
    /// logit(p) = ln 2, which makes the slope 1.
    #[test]
    fn log_reg() -> RegressionResult<()> {
        let expected = array![0., 1.0];
        let ln2 = f64::ln(2.);
        let x = array![[0.], [0.], [ln2], [ln2], [ln2]];
        let y = array![true, false, true, true, false];
        // Keep the model in its own binding: the fit result may borrow it.
        let model = ModelBuilder::<Logistic>::data(&y, &x).build()?;
        let fit = model.fit()?;
        assert_abs_diff_eq!(expected, fit.result, epsilon = 0.5 * f32::EPSILON as f64);
        Ok(())
    }

    /// `Cloglog::func` and `Cloglog::func_inv` must invert each other in both directions.
    #[test]
    fn cloglog_closure() {
        use link::Cloglog;
        let means = array![1e-8, 0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 0.9999999];
        let means_round_trip = means.mapv(|mu| Cloglog::func_inv(Cloglog::func(mu)));
        assert_abs_diff_eq!(means, means_round_trip);

        let lin_preds = array![-10., -2., -0.1, 0.0, 0.1, 1., 2.];
        let lin_round_trip = lin_preds.mapv(|lin| Cloglog::func(Cloglog::func_inv(lin)));
        assert_abs_diff_eq!(lin_preds, lin_round_trip, epsilon = 1e-3 * f32::EPSILON as f64);
    }
}