1#![allow(non_snake_case)] use ndarray::{Array1, Array2};
6use so_linalg;
7use statrs::distribution::{ContinuousCDF, Normal};
8
9use so_core::data::DataFrame;
10use so_core::error::{Error, Result};
11use so_core::formula::Formula;
12
13use super::family::{Family, Link, is_valid_link};
14use super::results::GLMResults;
15
16fn add_intercept(X: &Array2<f64>) -> Array2<f64> {
18 let (n_samples, n_features) = X.dim();
19 let mut X_with_intercept = Array2::zeros((n_samples, n_features + 1));
20
21 for i in 0..n_samples {
23 X_with_intercept[(i, 0)] = 1.0;
24 }
25
26 for i in 0..n_samples {
28 for j in 0..n_features {
29 X_with_intercept[(i, j + 1)] = X[(i, j)];
30 }
31 }
32
33 X_with_intercept
34}
35
/// Builder for configuring and constructing a [`GLM`].
#[derive(Debug, Clone)]
pub struct GLMModelBuilder {
    /// Distribution family of the response (defaults to Gaussian).
    family: Family,
    /// Explicit link function; `None` means the family's default link
    /// is resolved at `build` time.
    link: Option<Link>,
    /// Whether to prepend an intercept column to the design matrix.
    intercept: bool,
    /// Maximum number of IRLS iterations before giving up.
    max_iter: usize,
    /// Convergence tolerance applied to both the deviance change and the
    /// L1 change in coefficients.
    tol: f64,
    /// Fixed dispersion (scale) parameter; `None` means it is estimated
    /// from the fitted model.
    scale: Option<f64>,
}
46
47impl Default for GLMModelBuilder {
48 fn default() -> Self {
49 Self {
50 family: Family::Gaussian,
51 link: None,
52 intercept: true,
53 max_iter: 100,
54 tol: 1e-6,
55 scale: None,
56 }
57 }
58}
59
60impl GLMModelBuilder {
61 pub fn new() -> Self {
63 Self::default()
64 }
65
66 pub fn family(mut self, family: Family) -> Self {
68 self.family = family;
69 self
70 }
71
72 pub fn link(mut self, link: Link) -> Self {
74 self.link = Some(link);
75 self
76 }
77
78 pub fn intercept(mut self, intercept: bool) -> Self {
80 self.intercept = intercept;
81 self
82 }
83
84 pub fn max_iter(mut self, max_iter: usize) -> Self {
86 self.max_iter = max_iter;
87 self
88 }
89
90 pub fn tol(mut self, tol: f64) -> Self {
92 self.tol = tol;
93 self
94 }
95
96 pub fn scale(mut self, scale: f64) -> Self {
98 self.scale = Some(scale);
99 self
100 }
101
102 pub fn build(self) -> GLM {
104 let link = self.link.unwrap_or_else(|| self.family.default_link());
105
106 if !is_valid_link(self.family, link) {
107 panic!(
108 "Invalid link-function combination: {} with {}",
109 self.family.name(),
110 link.name()
111 );
112 }
113
114 GLM {
115 family: self.family,
116 link,
117 intercept: self.intercept,
118 max_iter: self.max_iter,
119 tol: self.tol,
120 scale: self.scale,
121 }
122 }
123}
124
/// A configured generalized linear model, fitted via IRLS.
///
/// Construct through [`GLMModelBuilder`] (`GLM::new()`), which validates
/// the family/link combination.
#[derive(Debug, Clone)]
pub struct GLM {
    /// Distribution family of the response.
    family: Family,
    /// Link function (already validated against the family).
    link: Link,
    /// Whether an intercept column is prepended to the design matrix.
    intercept: bool,
    /// Maximum number of IRLS iterations.
    max_iter: usize,
    /// Convergence tolerance on deviance change and coefficient change.
    tol: f64,
    /// Fixed dispersion (scale); `None` means it is estimated after fitting.
    scale: Option<f64>,
}
135
136impl GLM {
137 pub fn new() -> GLMModelBuilder {
139 GLMModelBuilder::new()
140 }
141
142 pub fn fit(&self, formula: &str, data: &DataFrame) -> Result<GLMResults> {
144 let formula = Formula::parse(formula)
145 .map_err(|e| Error::FormulaError(format!("Formula parse error: {}", e)))?;
146
147 let response_var = formula
149 .response
150 .as_ref()
151 .and_then(|term| match term {
152 so_core::formula::Term::Variable(name) => Some(name.clone()),
153 _ => None,
154 })
155 .ok_or_else(|| Error::DataError("Response must be a simple variable".to_string()))?;
156
157 let y_series = data.column(&response_var).ok_or_else(|| {
158 Error::DataError(format!(
159 "Response variable '{}' not found in data",
160 response_var
161 ))
162 })?;
163
164 let y = y_series.data().to_owned();
165
166 self.family.validate_response(&y)?;
168
169 let X = formula
171 .build_matrix(data)
172 .map_err(|e| Error::DataError(format!("Design matrix error: {}", e)))?;
173
174 let X = if self.intercept {
176 add_intercept(&X)
177 } else {
178 X.clone()
179 };
180
181 self.fit_irls(&X, &y)
183 }
184
185 fn fit_irls(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<GLMResults> {
187 let n = X.nrows();
188 let p = X.ncols();
189
190 if n <= p {
191 return Err(Error::DataError(format!(
192 "Not enough observations (n={}) for p={} parameters",
193 n, p
194 )));
195 }
196
197 let mut mu = self.family.initialize(y);
199 let mut eta = mu.mapv(|mu_val| self.link.link(mu_val));
200 let mut beta = Array1::zeros(p);
201
202 let mut weights = Array1::zeros(n);
204 let mut working_response = Array1::zeros(n);
205 let mut converged = false;
206 let mut iteration = 0;
207 let mut deviance_old = f64::INFINITY;
208 let mut last_X_weighted = Array2::zeros((n, p));
209 let mut last_XtWX = Array2::zeros((p, p));
210
211 while iteration < self.max_iter {
213 for i in 0..n {
215 let mu_i = mu[i];
216 let eta_i = eta[i];
217
218 let variance = self.family.variance(mu_i);
220 let derivative = self.link.derivative(eta_i);
221 weights[i] = 1.0 / (variance * derivative.powi(2)).max(1e-10);
222
223 working_response[i] = eta_i + (y[i] - mu_i) * derivative;
225 }
226
227 let X_weighted = X.clone() * weights.mapv(|w| w.sqrt()).insert_axis(ndarray::Axis(1));
229 let z_weighted = &working_response * weights.mapv(|w| w.sqrt());
230
231 let XtWX = X_weighted.t().dot(&X_weighted);
232 let XtWz = X_weighted.t().dot(&z_weighted);
233
234 let beta_new = so_linalg::solve(&XtWX, &XtWz)
235 .map_err(|e| Error::LinearAlgebraError(format!("IRLS solve failed: {}", e)))?;
236
237 last_X_weighted = X_weighted;
239 last_XtWX = XtWX;
240
241 eta = X.dot(&beta_new);
243 mu = eta.mapv(|eta_val| self.link.inverse_link(eta_val));
244
245 let deviance = self.family.deviance(y, &mu);
247 let beta_diff = (&beta_new - &beta).mapv(|x| x.abs()).sum();
248
249 if (deviance_old - deviance).abs() < self.tol && beta_diff < self.tol {
250 converged = true;
251 break;
252 }
253
254 beta = beta_new;
255 deviance_old = deviance;
256 iteration += 1;
257 }
258
259 if !converged {
260 return Err(Error::Message(format!(
261 "IRLS did not converge after {} iterations",
262 self.max_iter
263 )));
264 }
265
266 let fitted = mu.clone();
268 let residuals = y - &fitted;
269
270 let pearson_residuals: Array1<f64> = y
272 .iter()
273 .zip(fitted.iter())
274 .map(|(&y_val, &mu_val)| {
275 let variance = self.family.variance(mu_val);
276 if variance > 0.0 {
277 (y_val - mu_val) / variance.sqrt()
278 } else {
279 0.0
280 }
281 })
282 .collect();
283
284 let hat_matrix_diag = self.compute_leverage(&last_X_weighted);
286
287 let scale = match self.scale {
289 Some(s) => s,
290 None => self.family.estimate_dispersion(y, &fitted, n, p),
291 };
292
293 let cov_matrix = self.compute_covariance(&last_XtWX, scale);
295 let std_errors: Array1<f64> = (0..p).map(|i| cov_matrix[(i, i)].sqrt()).collect();
296
297 let (z_values, p_values) = self.compute_inference(&beta, &std_errors, n - p);
299
300 let null_deviance = self.compute_null_deviance(y);
302 let residual_deviance = deviance_old;
303 let df_null = if self.intercept { n - 1 } else { n };
304 let df_residual = n - p;
305 let aic = self.compute_aic(y, &fitted, p);
306 let bic = self.compute_bic(y, &fitted, p, n);
307
308 Ok(GLMResults {
309 coefficients: beta,
310 std_errors,
311 z_values,
312 p_values,
313 fitted_values: fitted,
314 residuals,
315 pearson_residuals,
316 hat_matrix_diag,
317 scale,
318 deviance: residual_deviance,
319 null_deviance,
320 df_residual,
321 df_null,
322 aic,
323 bic,
324 converged,
325 iterations: iteration,
326 family: self.family,
327 link: self.link,
328 intercept: self.intercept,
329 n_obs: n,
330 n_params: p,
331 })
332 }
333
334 fn compute_leverage(&self, X_weighted: &Array2<f64>) -> Array1<f64> {
336 let n = X_weighted.nrows();
337 let p = X_weighted.ncols();
338
339 if n <= p {
340 return Array1::zeros(n);
341 }
342
343 let xtx = X_weighted.t().dot(X_weighted);
344 match so_linalg::inv(&xtx) {
345 Ok(xtx_inv) => {
346 let mut leverage = Array1::zeros(n);
347 for i in 0..n {
348 let xi = X_weighted.row(i);
349 leverage[i] = xi.dot(&xtx_inv.dot(&xi.t()));
350 }
351 leverage
352 }
353 Err(_) => Array1::zeros(n),
354 }
355 }
356
357 fn compute_covariance(&self, XtWX: &Array2<f64>, scale: f64) -> Array2<f64> {
359 match so_linalg::inv(XtWX) {
360 Ok(cov) => &cov * scale,
361 Err(_) => Array2::zeros((XtWX.nrows(), XtWX.ncols())),
362 }
363 }
364
365 fn compute_inference(
367 &self,
368 coefficients: &Array1<f64>,
369 std_errors: &Array1<f64>,
370 _df_residual: usize,
371 ) -> (Array1<f64>, Array1<f64>) {
372 let n_coef = coefficients.len();
373 let mut z_values = Array1::zeros(n_coef);
374 let mut p_values = Array1::zeros(n_coef);
375
376 for i in 0..n_coef {
377 let se = std_errors[i];
378 if se > 0.0 {
379 z_values[i] = coefficients[i] / se;
380
381 let z_abs = z_values[i].abs();
383 p_values[i] = 2.0 * (1.0 - Normal::new(0.0, 1.0).unwrap().cdf(z_abs));
384 } else {
385 z_values[i] = f64::NAN;
386 p_values[i] = f64::NAN;
387 }
388 }
389
390 (z_values, p_values)
391 }
392
393 fn compute_null_deviance(&self, y: &Array1<f64>) -> f64 {
395 let n = y.len();
396 let mu_null = if self.intercept {
397 let y_mean = y.mean().unwrap_or(0.0);
399 let eta_mean = self.link.link(y_mean);
400 Array1::from_elem(n, self.link.inverse_link(eta_mean))
401 } else {
402 Array1::zeros(n)
403 };
404
405 self.family.deviance(y, &mu_null)
406 }
407
408 fn compute_aic(&self, y: &Array1<f64>, fitted: &Array1<f64>, n_params: usize) -> f64 {
410 let deviance = self.family.deviance(y, fitted);
411 2.0 * n_params as f64 + deviance
412 }
413
414 fn compute_bic(
416 &self,
417 y: &Array1<f64>,
418 fitted: &Array1<f64>,
419 n_params: usize,
420 n_obs: usize,
421 ) -> f64 {
422 let deviance = self.family.deviance(y, fitted);
423 n_params as f64 * (n_obs as f64).ln() + deviance
424 }
425
426 pub fn predict(&self, results: &GLMResults, X: &Array2<f64>) -> Array1<f64> {
428 let X_with_intercept = if self.intercept {
429 add_intercept(X)
430 } else {
431 X.clone()
432 };
433
434 let linear_predictor = X_with_intercept.dot(&results.coefficients);
435 linear_predictor.mapv(|eta| self.link.inverse_link(eta))
436 }
437}