1#![allow(non_snake_case)] use ndarray::Array1;
6use serde::{Deserialize, Serialize};
7use so_core::error::Result;
8use statrs::distribution::{Continuous, Normal};
9
/// Response distribution families handled by this module.
///
/// Each variant selects a variance function, a unit deviance, a default link
/// and a response-validation rule through the methods on `Family`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Family {
    /// Gaussian family: constant variance function, identity default link,
    /// no restriction on the response values.
    Gaussian,
    /// Binomial family: variance `mu * (1 - mu)`, logit default link,
    /// responses must lie in `[0, 1]`.
    Binomial,
    /// Poisson family: variance `mu`, log default link,
    /// responses must be non-negative.
    Poisson,
    /// Gamma family: variance `mu^2`, inverse default link,
    /// responses must be strictly positive.
    Gamma,
    /// Inverse Gaussian family: variance `mu^3`, inverse-square default link,
    /// responses must be strictly positive.
    InverseGaussian,
}
24
25impl Family {
26 pub fn default_link(&self) -> Link {
28 match self {
29 Family::Gaussian => Link::Identity,
30 Family::Binomial => Link::Logit,
31 Family::Poisson => Link::Log,
32 Family::Gamma => Link::Inverse,
33 Family::InverseGaussian => Link::InverseSquare,
34 }
35 }
36
37 pub fn variance(&self, mu: f64) -> f64 {
39 match self {
40 Family::Gaussian => 1.0,
41 Family::Binomial => mu * (1.0 - mu),
42 Family::Poisson => mu,
43 Family::Gamma => mu.powi(2),
44 Family::InverseGaussian => mu.powi(3),
45 }
46 }
47
48 pub fn unit_deviance(&self, y: f64, mu: f64) -> f64 {
50 match self {
51 Family::Gaussian => (y - mu).powi(2),
52 Family::Binomial => {
53 if y == 0.0 {
54 2.0 * (1.0 - mu).ln().max(-100.0)
55 } else if y == 1.0 {
56 2.0 * mu.ln().max(-100.0)
57 } else {
58 2.0 * (y * (y / mu).ln().max(-100.0)
60 + (1.0 - y) * ((1.0 - y) / (1.0 - mu)).ln().max(-100.0))
61 }
62 }
63 Family::Poisson => {
64 if mu == 0.0 {
65 if y == 0.0 { 0.0 } else { 2.0 * y }
66 } else {
67 2.0 * (y * (y / mu).ln().max(-100.0) - (y - mu))
68 }
69 }
70 Family::Gamma => 2.0 * ((y - mu) / mu - (y / mu).ln()),
71 Family::InverseGaussian => (y - mu).powi(2) / (mu.powi(2) * y),
72 }
73 }
74
75 pub fn deviance(&self, y: &Array1<f64>, mu: &Array1<f64>) -> f64 {
77 y.iter()
78 .zip(mu.iter())
79 .map(|(&y_val, &mu_val)| self.unit_deviance(y_val, mu_val))
80 .sum()
81 }
82
83 pub fn initialize(&self, y: &Array1<f64>) -> Array1<f64> {
85 match self {
86 Family::Gaussian => y.clone(),
87 Family::Binomial => {
88 y.mapv(|y_val| {
90 let clipped = y_val.max(0.0001).min(0.9999);
91 (clipped / (1.0 - clipped)).ln()
92 })
93 }
94 Family::Poisson => {
95 y.mapv(|y_val| (y_val + 0.5).ln())
97 }
98 Family::Gamma => {
99 y.mapv(|y_val| y_val.max(1e-8).ln())
101 }
102 Family::InverseGaussian => {
103 y.mapv(|y_val| y_val.max(1e-8).ln())
105 }
106 }
107 }
108
109 pub fn validate_response(&self, y: &Array1<f64>) -> Result<()> {
111 match self {
112 Family::Gaussian => Ok(()), Family::Binomial => {
114 for &val in y {
116 if !(0.0..=1.0).contains(&val) {
117 return Err(so_core::error::Error::DataError(format!(
118 "Binomial response must be in [0, 1], got {}",
119 val
120 )));
121 }
122 }
123 Ok(())
124 }
125 Family::Poisson => {
126 for &val in y {
128 if val < 0.0 {
129 return Err(so_core::error::Error::DataError(format!(
130 "Poisson response must be non-negative, got {}",
131 val
132 )));
133 }
134 }
135 Ok(())
136 }
137 Family::Gamma | Family::InverseGaussian => {
138 for &val in y {
140 if val <= 0.0 {
141 return Err(so_core::error::Error::DataError(format!(
142 "{} response must be positive, got {}",
143 match self {
144 Family::Gamma => "Gamma",
145 Family::InverseGaussian => "Inverse Gaussian",
146 _ => unreachable!(),
147 },
148 val
149 )));
150 }
151 }
152 Ok(())
153 }
154 }
155 }
156
157 pub fn name(&self) -> &'static str {
159 match self {
160 Family::Gaussian => "Gaussian",
161 Family::Binomial => "Binomial",
162 Family::Poisson => "Poisson",
163 Family::Gamma => "Gamma",
164 Family::InverseGaussian => "Inverse Gaussian",
165 }
166 }
167
168 pub fn estimate_dispersion(
170 &self,
171 y: &Array1<f64>,
172 mu: &Array1<f64>,
173 n: usize,
174 p: usize,
175 ) -> f64 {
176 let pearson_residuals: f64 = y
177 .iter()
178 .zip(mu.iter())
179 .map(|(&y_val, &mu_val)| {
180 let variance = self.variance(mu_val);
181 if variance > 0.0 {
182 (y_val - mu_val).powi(2) / variance
183 } else {
184 0.0
185 }
186 })
187 .sum();
188
189 pearson_residuals / (n - p) as f64
190 }
191}
192
/// Link functions relating the mean `mu` to the linear predictor `eta`.
///
/// The formulas below are the ones implemented by `Link::link`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Link {
    /// `eta = mu`.
    Identity,
    /// `eta = ln(mu / (1 - mu))`.
    Logit,
    /// `eta = sqrt(2) * erf_inv(2 * mu - 1)`, i.e. the standard normal quantile of `mu`.
    Probit,
    /// `eta = ln(-ln(1 - mu))` (complementary log-log).
    Cloglog,
    /// `eta = ln(mu)`.
    Log,
    /// `eta = 1 / mu`.
    Inverse,
    /// `eta = 1 / mu^2`.
    InverseSquare,
    /// `eta = sqrt(mu)`.
    Sqrt,
}
213
214impl Link {
215 pub fn link(&self, mu: f64) -> f64 {
217 match self {
218 Link::Identity => mu,
219 Link::Logit => (mu / (1.0 - mu)).ln(),
220 Link::Probit => {
221 if mu <= 0.0 || mu >= 1.0 {
223 f64::NAN
224 } else {
225 statrs::function::erf::erf_inv(2.0 * mu - 1.0) * 2.0f64.sqrt()
226 }
227 }
228 Link::Cloglog => (-(1.0 - mu).ln()).ln(),
229 Link::Log => mu.ln(),
230 Link::Inverse => 1.0 / mu,
231 Link::InverseSquare => 1.0 / mu.powi(2),
232 Link::Sqrt => mu.sqrt(),
233 }
234 }
235
236 pub fn inverse_link(&self, eta: f64) -> f64 {
238 match self {
239 Link::Identity => eta,
240 Link::Logit => 1.0 / (1.0 + (-eta).exp()),
241 Link::Probit => 0.5 * (1.0 + statrs::function::erf::erf(eta / 2.0f64.sqrt())),
242 Link::Cloglog => 1.0 - (-eta.exp()).exp(),
243 Link::Log => eta.exp(),
244 Link::Inverse => 1.0 / eta,
245 Link::InverseSquare => 1.0 / eta.sqrt(),
246 Link::Sqrt => eta.powi(2),
247 }
248 }
249
250 pub fn derivative(&self, eta: f64) -> f64 {
252 match self {
253 Link::Identity => 1.0,
254 Link::Logit => {
255 let mu = self.inverse_link(eta);
256 mu * (1.0 - mu)
257 }
258 Link::Probit => {
259 Normal::new(0.0, 1.0).unwrap().pdf(eta)
261 }
262 Link::Cloglog => {
263 let mu = self.inverse_link(eta);
264 (1.0 - mu) * (-(1.0 - mu).ln())
265 }
266 Link::Log => eta.exp(), Link::Inverse => -1.0 / eta.powi(2),
268 Link::InverseSquare => -0.5 / eta.powf(-1.5),
269 Link::Sqrt => 2.0 * eta,
270 }
271 }
272
273 pub fn name(&self) -> &'static str {
275 match self {
276 Link::Identity => "identity",
277 Link::Logit => "logit",
278 Link::Probit => "probit",
279 Link::Cloglog => "cloglog",
280 Link::Log => "log",
281 Link::Inverse => "inverse",
282 Link::InverseSquare => "inverse square",
283 Link::Sqrt => "sqrt",
284 }
285 }
286}
287
288pub fn is_valid_link(family: Family, link: Link) -> bool {
290 match family {
291 Family::Gaussian => matches!(link, Link::Identity | Link::Log | Link::Inverse),
292 Family::Binomial => matches!(link, Link::Logit | Link::Probit | Link::Cloglog | Link::Log),
293 Family::Poisson => matches!(link, Link::Log | Link::Identity | Link::Sqrt),
294 Family::Gamma => matches!(link, Link::Inverse | Link::Log | Link::Identity),
295 Family::InverseGaussian => matches!(
296 link,
297 Link::InverseSquare | Link::Inverse | Link::Log | Link::Identity
298 ),
299 }
300}