// ndarray_glm/response/logistic.rs
#[cfg(feature = "stats")]
4use crate::response::Response;
5use crate::{
6 error::{RegressionError, RegressionResult},
7 glm::{DispersionType, Glm},
8 link::Link,
9 math::prod_log,
10 num::Float,
11 response::Yval,
12};
13use ndarray::Array1;
14#[cfg(feature = "stats")]
15use statrs::distribution::Bernoulli;
16use std::marker::PhantomData;
17
/// Logistic regression response for binary outcomes, modeling each
/// observation as Bernoulli with the natural parameter equal to the
/// log-odds of success.
///
/// `L` is the link function relating the mean to the linear predictor;
/// it defaults to the canonical logit link.
pub struct Logistic<L = link::Logit>
where
    L: Link<Logistic<L>>,
{
    // Zero-sized marker tying the link type parameter to the struct;
    // no link state is stored at runtime.
    _link: PhantomData<L>,
}
25
26impl<L> Yval<Logistic<L>> for bool
28where
29 L: Link<Logistic<L>>,
30{
31 fn into_float<F: Float>(self) -> RegressionResult<F, F> {
32 Ok(if self { F::one() } else { F::zero() })
33 }
34}
35impl<L> Yval<Logistic<L>> for f32
39where
40 L: Link<Logistic<L>>,
41{
42 fn into_float<F: Float>(self) -> RegressionResult<F, F> {
43 if !(0.0..=1.0).contains(&self) {
44 return Err(RegressionError::InvalidY(self.to_string()));
45 }
46 F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
47 }
48}
49impl<L> Yval<Logistic<L>> for f64
50where
51 L: Link<Logistic<L>>,
52{
53 fn into_float<F: Float>(self) -> RegressionResult<F, F> {
54 if !(0.0..=1.0).contains(&self) {
55 return Err(RegressionError::InvalidY(self.to_string()));
56 }
57 F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string()))
58 }
59}
60
/// Random-variate support for logistic regression (requires the `stats`
/// feature): each observation is modeled as a Bernoulli draw.
#[cfg(feature = "stats")]
impl<L> Response for Logistic<L>
where
    L: Link<Logistic<L>>,
{
    type DistributionType = Bernoulli;

    /// Build the Bernoulli distribution with success probability `mu`.
    /// The dispersion `_phi` is ignored: the Bernoulli family has no
    /// free dispersion parameter.
    ///
    /// # Panics
    /// Panics if `mu` is not a probability in `[0, 1]` (or is NaN);
    /// means produced by a valid inverse link satisfy this invariant.
    fn get_distribution(mu: f64, _phi: f64) -> Self::DistributionType {
        Bernoulli::new(mu)
            .expect("logistic regression mean must be a probability in [0, 1]")
    }
}
72
73impl<L> Glm for Logistic<L>
75where
76 L: Link<Logistic<L>>,
77{
78 type Link = L;
79 const DISPERSED: DispersionType = DispersionType::NoDispersion;
80
81 fn log_partition<F: Float>(nat_par: F) -> F {
84 num_traits::Float::exp(nat_par).ln_1p()
85 }
86
87 fn variance<F: Float>(mean: F) -> F {
89 mean * (F::one() - mean)
90 }
91
92 fn log_like_natural<F>(y: F, logit_p: F) -> F
95 where
96 F: Float,
97 {
98 let (yt, xt) = if logit_p < F::zero() {
99 (y, logit_p)
100 } else {
101 (F::one() - y, -logit_p)
102 };
103 yt * xt - num_traits::Float::exp(xt).ln_1p()
104 }
105
106 fn log_like_sat<F: Float>(y: F) -> F {
109 prod_log(y) + prod_log(F::one() - y)
110 }
111}
112
pub mod link {
    //! Link functions for logistic regression.
    use super::*;
    use crate::link::{Canonical, Link, Transform};
    use crate::num::Float;

    /// The canonical logit link g(p) = ln(p / (1 - p)).
    pub struct Logit {}
    impl Canonical for Logit {}
    impl Link<Logistic<Logit>> for Logit {
        // The log-odds of the mean probability.
        fn func<F: Float>(y: F) -> F {
            num_traits::Float::ln(y / (F::one() - y))
        }
        // The standard logistic (sigmoid) function 1 / (1 + e^-x).
        fn func_inv<F: Float>(lin_pred: F) -> F {
            (F::one() + num_traits::Float::exp(-lin_pred)).recip()
        }
    }

    /// The complementary log-log link g(p) = ln(-ln(1 - p)), a
    /// non-canonical link for binary regression.
    pub struct Cloglog {}
    impl Link<Logistic<Cloglog>> for Cloglog {
        // ln(-ln(1 - y)), using ln_1p(-y) = ln(1 - y) for accuracy
        // when y is near zero.
        fn func<F: Float>(y: F) -> F {
            num_traits::Float::ln(-F::ln_1p(-y))
        }
        // Inverse link 1 - exp(-exp(x)), written via exp_m1 so results
        // near zero keep full precision.
        fn func_inv<F: Float>(lin_pred: F) -> F {
            -F::exp_m1(-num_traits::Float::exp(lin_pred))
        }
    }
    impl Transform for Cloglog {
        // Natural parameter (log-odds) as a function of the linear
        // predictor: eta = logit(1 - e^{-e^x}) = ln(e^{e^x} - 1),
        // with exp_m1 keeping precision when e^{e^x} is near 1.
        fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
            lin_pred.mapv(|x| num_traits::Float::ln(num_traits::Float::exp(x).exp_m1()))
        }
        // Derivative d(eta)/dx = e^x / (1 - e^{-e^x}), evaluated in the
        // equivalent form (-e^x) / expm1(-e^x).
        fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
            let neg_exp_lin = -lin_pred.mapv(num_traits::Float::exp);
            &neg_exp_lin / &neg_exp_lin.mapv(F::exp_m1)
        }
    }
}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::RegressionResult, model::ModelBuilder};
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fit a small logistic regression whose exact coefficients are
    /// known by construction and compare against them.
    #[test]
    fn log_reg() -> RegressionResult<(), f64> {
        let expected_beta = array![0., 1.0];
        let ln2 = f64::ln(2.);
        let covariates = array![[0.], [0.], [ln2], [ln2], [ln2]];
        let outcomes = array![true, false, true, true, false];
        let model = ModelBuilder::<Logistic>::data(&outcomes, &covariates).build()?;
        let fit = model.fit()?;
        assert_abs_diff_eq!(expected_beta, fit.result, epsilon = 0.5 * f32::EPSILON as f64);
        Ok(())
    }

    /// The cloglog link and its inverse should round-trip in both
    /// directions over a wide range of arguments.
    #[test]
    fn cloglog_closure() {
        use link::Cloglog;
        let mu_vals = array![1e-8, 0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 0.9999999];
        let mu_roundtrip = mu_vals.mapv(|mu| Cloglog::func_inv(Cloglog::func(mu)));
        assert_abs_diff_eq!(mu_vals, mu_roundtrip);
        let lin_vals = array![-10., -2., -0.1, 0.0, 0.1, 1., 2.];
        let lin_roundtrip = lin_vals.mapv(|lin| Cloglog::func(Cloglog::func_inv(lin)));
        assert_abs_diff_eq!(lin_vals, lin_roundtrip, epsilon = 1e-3 * f32::EPSILON as f64);
    }
}