// so_models/glm/model.rs
1//! Generalized Linear Model implementation using IRLS (Iteratively Reweighted Least Squares)
2
3#![allow(non_snake_case)] // Allow mathematical notation (X, W, etc.)
4
5use ndarray::{Array1, Array2};
6use so_linalg;
7use statrs::distribution::{ContinuousCDF, Normal};
8
9use so_core::data::DataFrame;
10use so_core::error::{Error, Result};
11use so_core::formula::Formula;
12
13use super::family::{Family, Link, is_valid_link};
14use super::results::GLMResults;
15
16/// Add intercept column (column of ones) to design matrix
17fn add_intercept(X: &Array2<f64>) -> Array2<f64> {
18    let (n_samples, n_features) = X.dim();
19    let mut X_with_intercept = Array2::zeros((n_samples, n_features + 1));
20
21    // First column is intercept (ones)
22    for i in 0..n_samples {
23        X_with_intercept[(i, 0)] = 1.0;
24    }
25
26    // Copy original features
27    for i in 0..n_samples {
28        for j in 0..n_features {
29            X_with_intercept[(i, j + 1)] = X[(i, j)];
30        }
31    }
32
33    X_with_intercept
34}
35
/// GLM model configuration and builder
///
/// Collects fitting options before constructing an immutable [`GLM`] via
/// [`GLMModelBuilder::build`].
#[derive(Debug, Clone)]
pub struct GLMModelBuilder {
    /// Distribution family of the response.
    family: Family,
    /// Link function; `None` means "use the family's default link".
    link: Option<Link>,
    /// Whether to prepend an intercept column to the design matrix.
    intercept: bool,
    /// Maximum number of IRLS iterations.
    max_iter: usize,
    /// Convergence tolerance (applied to deviance and coefficient change).
    tol: f64,
    /// Fixed dispersion (scale); `None` means "estimate from the data".
    scale: Option<f64>,
}
46
47impl Default for GLMModelBuilder {
48    fn default() -> Self {
49        Self {
50            family: Family::Gaussian,
51            link: None,
52            intercept: true,
53            max_iter: 100,
54            tol: 1e-6,
55            scale: None,
56        }
57    }
58}
59
60impl GLMModelBuilder {
61    /// Create a new GLM builder with default settings
62    pub fn new() -> Self {
63        Self::default()
64    }
65
66    /// Set the distribution family
67    pub fn family(mut self, family: Family) -> Self {
68        self.family = family;
69        self
70    }
71
72    /// Set the link function (if None, uses family's default)
73    pub fn link(mut self, link: Link) -> Self {
74        self.link = Some(link);
75        self
76    }
77
78    /// Set whether to include an intercept
79    pub fn intercept(mut self, intercept: bool) -> Self {
80        self.intercept = intercept;
81        self
82    }
83
84    /// Set maximum number of IRLS iterations
85    pub fn max_iter(mut self, max_iter: usize) -> Self {
86        self.max_iter = max_iter;
87        self
88    }
89
90    /// Set convergence tolerance
91    pub fn tol(mut self, tol: f64) -> Self {
92        self.tol = tol;
93        self
94    }
95
96    /// Set fixed scale parameter (dispersion)
97    pub fn scale(mut self, scale: f64) -> Self {
98        self.scale = Some(scale);
99        self
100    }
101
102    /// Build the GLM model with current configuration
103    pub fn build(self) -> GLM {
104        let link = self.link.unwrap_or_else(|| self.family.default_link());
105
106        if !is_valid_link(self.family, link) {
107            panic!(
108                "Invalid link-function combination: {} with {}",
109                self.family.name(),
110                link.name()
111            );
112        }
113
114        GLM {
115            family: self.family,
116            link,
117            intercept: self.intercept,
118            max_iter: self.max_iter,
119            tol: self.tol,
120            scale: self.scale,
121        }
122    }
123}
124
/// Generalized Linear Model
///
/// Construct via the builder returned by [`GLM::new`]; fit with [`GLM::fit`].
#[derive(Debug, Clone)]
pub struct GLM {
    /// Distribution family of the response.
    family: Family,
    /// Link function mapping the mean to the linear predictor.
    link: Link,
    /// Whether an intercept column is prepended to the design matrix.
    intercept: bool,
    /// Maximum number of IRLS iterations.
    max_iter: usize,
    /// Convergence tolerance on deviance and coefficient change.
    tol: f64,
    /// Fixed dispersion (scale); estimated from the data when `None`.
    scale: Option<f64>,
}
135
136impl GLM {
137    /// Create a new GLM builder
138    pub fn new() -> GLMModelBuilder {
139        GLMModelBuilder::new()
140    }
141
142    /// Fit the GLM using formula and data
143    pub fn fit(&self, formula: &str, data: &DataFrame) -> Result<GLMResults> {
144        let formula = Formula::parse(formula)
145            .map_err(|e| Error::FormulaError(format!("Formula parse error: {}", e)))?;
146
147        // Extract response and predictors
148        let response_var = formula
149            .response
150            .as_ref()
151            .and_then(|term| match term {
152                so_core::formula::Term::Variable(name) => Some(name.clone()),
153                _ => None,
154            })
155            .ok_or_else(|| Error::DataError("Response must be a simple variable".to_string()))?;
156
157        let y_series = data.column(&response_var).ok_or_else(|| {
158            Error::DataError(format!(
159                "Response variable '{}' not found in data",
160                response_var
161            ))
162        })?;
163
164        let y = y_series.data().to_owned();
165
166        // Validate response values for the family
167        self.family.validate_response(&y)?;
168
169        // Build design matrix
170        let X = formula
171            .build_matrix(data)
172            .map_err(|e| Error::DataError(format!("Design matrix error: {}", e)))?;
173
174        // Add intercept if requested
175        let X = if self.intercept {
176            add_intercept(&X)
177        } else {
178            X.clone()
179        };
180
181        // Fit using IRLS
182        self.fit_irls(&X, &y)
183    }
184
185    /// Fit using IRLS algorithm
186    fn fit_irls(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<GLMResults> {
187        let n = X.nrows();
188        let p = X.ncols();
189
190        if n <= p {
191            return Err(Error::DataError(format!(
192                "Not enough observations (n={}) for p={} parameters",
193                n, p
194            )));
195        }
196
197        // Initialize parameters
198        let mut mu = self.family.initialize(y);
199        let mut eta = mu.mapv(|mu_val| self.link.link(mu_val));
200        let mut beta = Array1::zeros(p);
201
202        // Working variables for IRLS
203        let mut weights = Array1::zeros(n);
204        let mut working_response = Array1::zeros(n);
205        let mut converged = false;
206        let mut iteration = 0;
207        let mut deviance_old = f64::INFINITY;
208        let mut last_X_weighted = Array2::zeros((n, p));
209        let mut last_XtWX = Array2::zeros((p, p));
210
211        // IRLS iterations
212        while iteration < self.max_iter {
213            // Compute working weights and response
214            for i in 0..n {
215                let mu_i = mu[i];
216                let eta_i = eta[i];
217
218                // Weight: w = 1 / (V(μ) * (g'(μ))^2)
219                let variance = self.family.variance(mu_i);
220                let derivative = self.link.derivative(eta_i);
221                weights[i] = 1.0 / (variance * derivative.powi(2)).max(1e-10);
222
223                // Working response: z = η + (y - μ) * g'(μ)
224                working_response[i] = eta_i + (y[i] - mu_i) * derivative;
225            }
226
227            // Solve weighted least squares: β = (X'WX)^-1 X'Wz
228            let X_weighted = X.clone() * weights.mapv(|w| w.sqrt()).insert_axis(ndarray::Axis(1));
229            let z_weighted = &working_response * weights.mapv(|w| w.sqrt());
230
231            let XtWX = X_weighted.t().dot(&X_weighted);
232            let XtWz = X_weighted.t().dot(&z_weighted);
233
234            let beta_new = so_linalg::solve(&XtWX, &XtWz)
235                .map_err(|e| Error::LinearAlgebraError(format!("IRLS solve failed: {}", e)))?;
236
237            // Save the last weighted matrices for standard error computation
238            last_X_weighted = X_weighted;
239            last_XtWX = XtWX;
240
241            // Update parameters
242            eta = X.dot(&beta_new);
243            mu = eta.mapv(|eta_val| self.link.inverse_link(eta_val));
244
245            // Check convergence
246            let deviance = self.family.deviance(y, &mu);
247            let beta_diff = (&beta_new - &beta).mapv(|x| x.abs()).sum();
248
249            if (deviance_old - deviance).abs() < self.tol && beta_diff < self.tol {
250                converged = true;
251                break;
252            }
253
254            beta = beta_new;
255            deviance_old = deviance;
256            iteration += 1;
257        }
258
259        if !converged {
260            return Err(Error::Message(format!(
261                "IRLS did not converge after {} iterations",
262                self.max_iter
263            )));
264        }
265
266        // Compute final predictions and residuals
267        let fitted = mu.clone();
268        let residuals = y - &fitted;
269
270        // Compute Pearson residuals
271        let pearson_residuals: Array1<f64> = y
272            .iter()
273            .zip(fitted.iter())
274            .map(|(&y_val, &mu_val)| {
275                let variance = self.family.variance(mu_val);
276                if variance > 0.0 {
277                    (y_val - mu_val) / variance.sqrt()
278                } else {
279                    0.0
280                }
281            })
282            .collect();
283
284        // Compute leverage and Cook's distance (simplified)
285        let hat_matrix_diag = self.compute_leverage(&last_X_weighted);
286
287        // Estimate or use provided scale
288        let scale = match self.scale {
289            Some(s) => s,
290            None => self.family.estimate_dispersion(y, &fitted, n, p),
291        };
292
293        // Compute standard errors
294        let cov_matrix = self.compute_covariance(&last_XtWX, scale);
295        let std_errors: Array1<f64> = (0..p).map(|i| cov_matrix[(i, i)].sqrt()).collect();
296
297        // Compute z/t-values and p-values
298        let (z_values, p_values) = self.compute_inference(&beta, &std_errors, n - p);
299
300        // Compute model statistics
301        let null_deviance = self.compute_null_deviance(y);
302        let residual_deviance = deviance_old;
303        let df_null = if self.intercept { n - 1 } else { n };
304        let df_residual = n - p;
305        let aic = self.compute_aic(y, &fitted, p);
306        let bic = self.compute_bic(y, &fitted, p, n);
307
308        Ok(GLMResults {
309            coefficients: beta,
310            std_errors,
311            z_values,
312            p_values,
313            fitted_values: fitted,
314            residuals,
315            pearson_residuals,
316            hat_matrix_diag,
317            scale,
318            deviance: residual_deviance,
319            null_deviance,
320            df_residual,
321            df_null,
322            aic,
323            bic,
324            converged,
325            iterations: iteration,
326            family: self.family,
327            link: self.link,
328            intercept: self.intercept,
329            n_obs: n,
330            n_params: p,
331        })
332    }
333
334    /// Compute leverage (diagonal of hat matrix)
335    fn compute_leverage(&self, X_weighted: &Array2<f64>) -> Array1<f64> {
336        let n = X_weighted.nrows();
337        let p = X_weighted.ncols();
338
339        if n <= p {
340            return Array1::zeros(n);
341        }
342
343        let xtx = X_weighted.t().dot(X_weighted);
344        match so_linalg::inv(&xtx) {
345            Ok(xtx_inv) => {
346                let mut leverage = Array1::zeros(n);
347                for i in 0..n {
348                    let xi = X_weighted.row(i);
349                    leverage[i] = xi.dot(&xtx_inv.dot(&xi.t()));
350                }
351                leverage
352            }
353            Err(_) => Array1::zeros(n),
354        }
355    }
356
357    /// Compute covariance matrix of coefficients
358    fn compute_covariance(&self, XtWX: &Array2<f64>, scale: f64) -> Array2<f64> {
359        match so_linalg::inv(XtWX) {
360            Ok(cov) => &cov * scale,
361            Err(_) => Array2::zeros((XtWX.nrows(), XtWX.ncols())),
362        }
363    }
364
365    /// Compute z/t-values and p-values for coefficients
366    fn compute_inference(
367        &self,
368        coefficients: &Array1<f64>,
369        std_errors: &Array1<f64>,
370        _df_residual: usize,
371    ) -> (Array1<f64>, Array1<f64>) {
372        let n_coef = coefficients.len();
373        let mut z_values = Array1::zeros(n_coef);
374        let mut p_values = Array1::zeros(n_coef);
375
376        for i in 0..n_coef {
377            let se = std_errors[i];
378            if se > 0.0 {
379                z_values[i] = coefficients[i] / se;
380
381                // Use normal distribution for p-values in GLM
382                let z_abs = z_values[i].abs();
383                p_values[i] = 2.0 * (1.0 - Normal::new(0.0, 1.0).unwrap().cdf(z_abs));
384            } else {
385                z_values[i] = f64::NAN;
386                p_values[i] = f64::NAN;
387            }
388        }
389
390        (z_values, p_values)
391    }
392
393    /// Compute null deviance (intercept-only model)
394    fn compute_null_deviance(&self, y: &Array1<f64>) -> f64 {
395        let n = y.len();
396        let mu_null = if self.intercept {
397            // Compute mean response (on link scale, then transform)
398            let y_mean = y.mean().unwrap_or(0.0);
399            let eta_mean = self.link.link(y_mean);
400            Array1::from_elem(n, self.link.inverse_link(eta_mean))
401        } else {
402            Array1::zeros(n)
403        };
404
405        self.family.deviance(y, &mu_null)
406    }
407
408    /// Compute Akaike Information Criterion
409    fn compute_aic(&self, y: &Array1<f64>, fitted: &Array1<f64>, n_params: usize) -> f64 {
410        let deviance = self.family.deviance(y, fitted);
411        2.0 * n_params as f64 + deviance
412    }
413
414    /// Compute Bayesian Information Criterion
415    fn compute_bic(
416        &self,
417        y: &Array1<f64>,
418        fitted: &Array1<f64>,
419        n_params: usize,
420        n_obs: usize,
421    ) -> f64 {
422        let deviance = self.family.deviance(y, fitted);
423        n_params as f64 * (n_obs as f64).ln() + deviance
424    }
425
426    /// Predict on new data
427    pub fn predict(&self, results: &GLMResults, X: &Array2<f64>) -> Array1<f64> {
428        let X_with_intercept = if self.intercept {
429            add_intercept(X)
430        } else {
431            X.clone()
432        };
433
434        let linear_predictor = X_with_intercept.dot(&results.coefficients);
435        linear_predictor.mapv(|eta| self.link.inverse_link(eta))
436    }
437}