Skip to main content

so_models/regression/
linear.rs

1//! Linear regression models
2//!
3//! This module implements various linear regression models:
4//! - Ordinary Least Squares (OLS)
5//! - Ridge Regression (L2 regularization)
6//! - Lasso Regression (L1 regularization)
7//! - Elastic Net (L1 + L2 regularization)
8
9#![allow(non_snake_case)] // Allow mathematical notation (X, y, etc.)
10
11use ndarray::{Array1, Array2};
12use serde::{Deserialize, Serialize};
13use so_core::data::DataFrame;
14use so_core::error::{Error, Result};
15use so_core::formula::Formula;
16use so_linalg;
17
18// ============================================================================
19// Model Results
20// ============================================================================
21
/// Results from fitting a linear regression model.
///
/// Produced by `OLS::fit` / `OLS::fit_formula` and `Ridge::fit`. The ridge
/// fit leaves every inference field (`std_errors`, `t_values`, `p_values`,
/// `f_statistic`, `f_p_value`, `aic`, `bic`, `log_likelihood`) as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearRegressionResults {
    /// Estimated coefficients (including intercept if present)
    pub coefficients: Array1<f64>,
    /// Standard errors of coefficients (`None` when inference was skipped)
    pub std_errors: Option<Array1<f64>>,
    /// t-statistics for coefficients (`None` when inference was skipped)
    pub t_values: Option<Array1<f64>>,
    /// p-values for coefficients (`None` when inference was skipped)
    pub p_values: Option<Array1<f64>>,
    /// Residual sum of squares
    pub rss: f64,
    /// Total sum of squares (about the mean of y)
    pub tss: f64,
    /// R-squared (coefficient of determination); 1.0 by convention when tss == 0
    pub r_squared: f64,
    /// Adjusted R-squared
    pub r_squared_adj: f64,
    /// Residual standard error, sqrt(rss / df_residual)
    pub sigma: f64,
    /// Residual degrees of freedom (n - p)
    pub df_residual: usize,
    /// Degrees of freedom of model (p, minus one if an intercept is present)
    pub df_model: usize,
    /// F-statistic (`None` for no-intercept or ridge fits)
    pub f_statistic: Option<f64>,
    /// F-statistic p-value
    pub f_p_value: Option<f64>,
    /// Akaike Information Criterion
    pub aic: Option<f64>,
    /// Bayesian Information Criterion
    pub bic: Option<f64>,
    /// Log-likelihood
    pub log_likelihood: Option<f64>,
}
58
59impl LinearRegressionResults {
60    /// Create a summary string similar to R's summary()
61    pub fn summary(&self, feature_names: &[String]) -> String {
62        let n_coef = self.coefficients.len();
63        let intercept_included = feature_names.first().is_some_and(|n| n == "(Intercept)");
64
65        let mut summary = String::new();
66        summary.push_str("Linear Regression Results\n");
67        summary.push_str("========================\n");
68        summary.push_str(&format!(
69            "R-squared: {:.4}, Adjusted R-squared: {:.4}\n",
70            self.r_squared, self.r_squared_adj
71        ));
72        summary.push_str(&format!(
73            "F-statistic: {:.2}, p-value: {:.4e}\n",
74            self.f_statistic.unwrap_or(f64::NAN),
75            self.f_p_value.unwrap_or(f64::NAN)
76        ));
77        summary.push_str(&format!(
78            "Residual Std. Error: {:.4} (df = {})\n",
79            self.sigma, self.df_residual
80        ));
81
82        summary.push_str("\nCoefficients:\n");
83        summary.push_str("              Estimate Std. Error t value Pr(>|t|)\n");
84
85        for i in 0..n_coef {
86            let name = if i == 0 && intercept_included {
87                "(Intercept)".to_string()
88            } else if intercept_included {
89                feature_names
90                    .get(i)
91                    .cloned()
92                    .unwrap_or_else(|| format!("x{}", i))
93            } else {
94                feature_names
95                    .get(i)
96                    .cloned()
97                    .unwrap_or_else(|| format!("x{}", i))
98            };
99
100            let coef = self.coefficients[i];
101            let se = self.std_errors.as_ref().map_or(f64::NAN, |se| se[i]);
102            let t = self.t_values.as_ref().map_or(f64::NAN, |t| t[i]);
103            let p = self.p_values.as_ref().map_or(f64::NAN, |p| p[i]);
104
105            let significance = if p < 0.001 {
106                "***"
107            } else if p < 0.01 {
108                "**"
109            } else if p < 0.05 {
110                "*"
111            } else if p < 0.1 {
112                "."
113            } else {
114                ""
115            };
116
117            summary.push_str(&format!(
118                "{:15} {:8.4} {:8.4} {:7.3} {:8.4} {}\n",
119                name, coef, se, t, p, significance
120            ));
121        }
122
123        summary.push_str("\nSignif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n");
124        summary
125    }
126
127    /// Predict using fitted model
128    pub fn predict(&self, X: &Array2<f64>) -> Array1<f64> {
129        X.dot(&self.coefficients)
130    }
131}
132
133// ============================================================================
134// OLS Regression
135// ============================================================================
136
/// Ordinary Least Squares regression.
///
/// Solves the normal equations (X'X)β = X'y; see `OLS::fit`. Note that
/// `fit` does not add an intercept column to `X` itself — the flag only
/// influences the model degrees of freedom and the F-test.
#[derive(Debug, Clone)]
pub struct OLS {
    /// Whether to include intercept
    pub intercept: bool,
}
143
144impl OLS {
145    /// Create a new OLS model
146    pub fn new() -> Self {
147        Self { intercept: true }
148    }
149
150    /// Create OLS model without intercept
151    pub fn no_intercept() -> Self {
152        Self { intercept: false }
153    }
154
155    /// Fit model using formula and DataFrame
156    pub fn fit_formula(
157        &self,
158        formula: &Formula,
159        df: &DataFrame,
160    ) -> Result<LinearRegressionResults> {
161        let X = formula.build_matrix(df)?;
162        let y = formula
163            .response_vector(df)?
164            .ok_or_else(|| Error::Message("Formula must include response variable".to_string()))?;
165
166        self.fit(&X, &y)
167    }
168
169    /// Fit model with design matrix X and response y
170    pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<LinearRegressionResults> {
171        let n = X.shape()[0];
172        let p = X.shape()[1];
173
174        if n <= p {
175            return Err(Error::Message(
176                "Not enough observations for OLS estimation".to_string(),
177            ));
178        }
179
180        if y.len() != n {
181            return Err(Error::DimensionMismatch(format!(
182                "X has {} rows, y has {} elements",
183                n,
184                y.len()
185            )));
186        }
187
188        // Compute coefficients using normal equations or QR decomposition
189        let coefficients = self.solve_normal_equations(X, y)?;
190
191        // Compute residuals and statistics
192        let y_hat = X.dot(&coefficients);
193        let residuals = y - &y_hat;
194        let rss = residuals.dot(&residuals);
195
196        let y_mean = y.mean().unwrap_or(0.0);
197        let tss = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>();
198
199        let df_residual = n - p;
200        let df_model = p - if self.intercept { 1 } else { 0 };
201        let sigma = (rss / df_residual as f64).sqrt();
202
203        let r_squared = if tss == 0.0 { 1.0 } else { 1.0 - rss / tss };
204        let r_squared_adj = 1.0 - (1.0 - r_squared) * (n as f64 - 1.0) / (df_residual as f64);
205
206        // Compute standard errors if we have enough data
207        let (std_errors, t_values, p_values) = if n > p + 1 {
208            self.compute_inference(X, &residuals, sigma, df_residual, &coefficients)
209        } else {
210            (None, None, None)
211        };
212
213        // Compute F-statistic
214        let (f_statistic, f_p_value) = if self.intercept && df_model > 0 && df_residual > 0 {
215            // Handle edge case where r_squared is exactly 1.0
216            if (1.0 - r_squared).abs() < f64::EPSILON {
217                // Perfect fit, F-statistic is infinite, p-value is 0
218                (Some(f64::INFINITY), Some(0.0))
219            } else if r_squared.abs() < f64::EPSILON {
220                // No relationship, F-statistic is 0
221                (Some(0.0), Some(1.0))
222            } else {
223                let f_stat =
224                    (r_squared / df_model as f64) / ((1.0 - r_squared) / df_residual as f64);
225
226                // Check for invalid F-statistic
227                if f_stat.is_nan() || f_stat.is_infinite() {
228                    (Some(f_stat), None)
229                } else {
230                    let x =
231                        df_model as f64 * f_stat / (df_residual as f64 + df_model as f64 * f_stat);
232                    // Ensure x is in [0, 1] for beta_reg
233                    let x_clamped = x.clamp(0.0, 1.0);
234                    // Handle edge case where x_clamped might be exactly 0 or 1 due to floating point
235                    let x_safe = if x_clamped <= 0.0 {
236                        f64::MIN_POSITIVE
237                    } else if x_clamped >= 1.0 {
238                        1.0 - f64::EPSILON
239                    } else {
240                        x_clamped
241                    };
242
243                    // beta_reg returns f64, may panic if x_safe is not in (0, 1)
244                    // x_safe is guaranteed to be in (0, 1) by construction
245                    let beta_val = statrs::function::beta::beta_reg(
246                        df_model as f64 / 2.0,
247                        df_residual as f64 / 2.0,
248                        x_safe,
249                    );
250                    (Some(f_stat), Some(1.0 - beta_val))
251                }
252            }
253        } else {
254            (None, None)
255        };
256
257        // Compute information criteria
258        let (aic, bic, log_likelihood) = self.compute_information_criteria(n, rss, p);
259
260        Ok(LinearRegressionResults {
261            coefficients,
262            std_errors,
263            t_values,
264            p_values,
265            rss,
266            tss,
267            r_squared,
268            r_squared_adj,
269            sigma,
270            df_residual,
271            df_model,
272            f_statistic,
273            f_p_value,
274            aic,
275            bic,
276            log_likelihood,
277        })
278    }
279
280    fn solve_normal_equations(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<Array1<f64>> {
281        // Solve normal equations: (X'X)β = X'y
282        let xtx = X.t().dot(X);
283        let xty = X.t().dot(y);
284
285        // Use so-linalg solver
286        so_linalg::solve(&xtx, &xty).map_err(|e| {
287            Error::LinearAlgebraError(format!("Solving normal equations failed: {}", e))
288        })
289    }
290
291    fn compute_inference(
292        &self,
293        X: &Array2<f64>,
294        _residuals: &Array1<f64>,
295        sigma: f64,
296        df_residual: usize,
297        coefficients: &Array1<f64>,
298    ) -> (
299        Option<Array1<f64>>,
300        Option<Array1<f64>>,
301        Option<Array1<f64>>,
302    ) {
303        let n = X.shape()[0];
304        let p = X.shape()[1];
305
306        if n <= p + 1 {
307            return (None, None, None);
308        }
309
310        // Compute covariance matrix: sigma^2 * (X'X)^-1
311        let xtx = X.t().dot(X);
312        let xtx_inv = match so_linalg::inv(&xtx) {
313            Ok(inv) => inv,
314            Err(_) => return (None, None, None),
315        };
316
317        let cov_matrix = &xtx_inv * sigma.powi(2);
318
319        // Standard errors are sqrt of diagonal elements
320        let std_errors: Array1<f64> = (0..p).map(|i| cov_matrix[(i, i)].sqrt()).collect();
321
322        // t-values = coefficients / std_errors
323        let t_values: Array1<f64> = coefficients
324            .iter()
325            .zip(std_errors.iter())
326            .map(|(&coef, &se)| coef / se)
327            .collect();
328
329        // p-values from t-distribution
330        let p_values: Array1<f64> = t_values
331            .iter()
332            .map(|&t: &f64| {
333                let t_abs = t.abs();
334                2.0 * (1.0
335                    - statrs::function::gamma::gamma_ur(
336                        df_residual as f64 / 2.0,
337                        df_residual as f64 / (df_residual as f64 + t_abs * t_abs),
338                    ))
339            })
340            .collect();
341
342        (Some(std_errors), Some(t_values), Some(p_values))
343    }
344
345    fn compute_information_criteria(
346        &self,
347        n: usize,
348        rss: f64,
349        p: usize,
350    ) -> (Option<f64>, Option<f64>, Option<f64>) {
351        if n <= p {
352            return (None, None, None);
353        }
354
355        let log_likelihood = -0.5 * n as f64 * (2.0 * std::f64::consts::PI * rss / n as f64).ln();
356        let aic = -2.0 * log_likelihood + 2.0 * p as f64;
357        let bic = -2.0 * log_likelihood + (n as f64).ln() * p as f64;
358
359        (Some(aic), Some(bic), Some(log_likelihood))
360    }
361}
362
363// ============================================================================
364// Ridge Regression
365// ============================================================================
366
/// Ridge regression (L2 regularization).
///
/// Penalizes coefficient magnitude; the intercept (column 0 of X, when
/// `intercept` is true) is never penalized. See `Ridge::fit`.
#[derive(Debug, Clone)]
pub struct Ridge {
    /// Regularization parameter (lambda); larger values shrink coefficients harder
    pub alpha: f64,
    /// Whether to include intercept (column 0 of X; left unpenalized and unscaled)
    pub intercept: bool,
    /// Whether to standardize features (zero mean, unit std) before solving
    pub standardize: bool,
}
377
378impl Ridge {
379    /// Create a new Ridge regression model
380    pub fn new(alpha: f64) -> Self {
381        Self {
382            alpha,
383            intercept: true,
384            standardize: true,
385        }
386    }
387
388    /// Fit model with design matrix X and response y
389    pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<LinearRegressionResults> {
390        let n = X.shape()[0];
391        let p = X.shape()[1];
392
393        if n <= p && self.alpha == 0.0 {
394            return Err(Error::Message(
395                "Not enough observations for OLS estimation".to_string(),
396            ));
397        }
398
399        // Standardize features if requested
400        let (X_std, x_mean, x_std) = if self.standardize {
401            self.standardize_features(X, self.intercept)
402        } else {
403            (X.clone(), Array1::zeros(p), Array1::ones(p))
404        };
405
406        // Center response
407        let y_mean = y.mean().unwrap_or(0.0);
408        let y_centered = y - y_mean;
409
410        // Solve ridge regression: (X'X + alpha*I)^-1 X'y
411        let xtx = X_std.t().dot(&X_std);
412        let mut xtx_regularized = xtx.clone();
413
414        // Add regularization to diagonal (skip intercept if present)
415        let start_idx = if self.intercept { 1 } else { 0 };
416        for i in start_idx..p {
417            xtx_regularized[(i, i)] += self.alpha;
418        }
419
420        let xty = X_std.t().dot(&y_centered);
421        let coefficients_std = match so_linalg::solve(&xtx_regularized, &xty) {
422            Ok(coef) => coef,
423            Err(e) => {
424                return Err(Error::LinearAlgebraError(format!(
425                    "Ridge regression solve failed: {}",
426                    e
427                )));
428            }
429        };
430
431        // Unstandardize coefficients
432        let coefficients =
433            self.unstandardize_coefficients(&coefficients_std, &x_mean, &x_std, y_mean);
434
435        // For simplicity, we'll return basic results without inference
436        // (Ridge doesn't have straightforward standard errors)
437        let y_hat = X.dot(&coefficients);
438        let residuals = y - &y_hat;
439        let rss = residuals.dot(&residuals);
440        let tss = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>();
441        let r_squared = if tss == 0.0 { 1.0 } else { 1.0 - rss / tss };
442
443        Ok(LinearRegressionResults {
444            coefficients,
445            std_errors: None,
446            t_values: None,
447            p_values: None,
448            rss,
449            tss,
450            r_squared,
451            r_squared_adj: r_squared, // Simplified
452            sigma: (rss / (n - p) as f64).sqrt(),
453            df_residual: n - p,
454            df_model: p - if self.intercept { 1 } else { 0 },
455            f_statistic: None,
456            f_p_value: None,
457            aic: None,
458            bic: None,
459            log_likelihood: None,
460        })
461    }
462
463    fn standardize_features(
464        &self,
465        X: &Array2<f64>,
466        skip_intercept: bool,
467    ) -> (Array2<f64>, Array1<f64>, Array1<f64>) {
468        let n = X.shape()[0] as f64;
469        let p = X.shape()[1];
470
471        let mut x_mean = Array1::zeros(p);
472        let mut x_std = Array1::ones(p);
473        let mut X_std = X.clone();
474
475        let start_idx = if skip_intercept { 1 } else { 0 };
476
477        for j in start_idx..p {
478            let col = X.column(j);
479            let mean = col.mean().unwrap_or(0.0);
480            let variance = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
481            let std = variance.sqrt();
482
483            x_mean[j] = mean;
484            x_std[j] = if std == 0.0 { 1.0 } else { std };
485
486            // Standardize column
487            for i in 0..X_std.shape()[0] {
488                X_std[(i, j)] = (X_std[(i, j)] - mean) / x_std[j];
489            }
490        }
491
492        (X_std, x_mean, x_std)
493    }
494
495    fn unstandardize_coefficients(
496        &self,
497        coefficients_std: &Array1<f64>,
498        x_mean: &Array1<f64>,
499        x_std: &Array1<f64>,
500        y_mean: f64,
501    ) -> Array1<f64> {
502        let p = coefficients_std.len();
503        let mut coefficients = coefficients_std.clone();
504
505        // Adjust intercept
506        if self.intercept {
507            let mut intercept_adjustment = 0.0;
508            for j in 1..p {
509                intercept_adjustment += coefficients[j] * x_mean[j] / x_std[j];
510            }
511            // coefficients[0] is the intercept from standardized data with centered y
512            // Original intercept = y_mean + coefficients[0] - intercept_adjustment
513            coefficients[0] = y_mean + coefficients[0] - intercept_adjustment;
514        }
515
516        // Unstandardize other coefficients
517        for j in 1..p {
518            coefficients[j] /= x_std[j];
519        }
520
521        coefficients
522    }
523}
524
525// ============================================================================
526// Model Builder Interface
527// ============================================================================
528
/// Builder for linear regression models with chainable API.
///
/// Construct via `LinearModelBuilder::formula` or
/// `LinearModelBuilder::matrix`, adjust options with `no_intercept` /
/// `no_standardize`, then fit with `ols()` or `ridge(alpha)`.
pub struct LinearModelBuilder<'a> {
    // Formula + DataFrame data source (alternative to X/y below).
    formula: Option<&'a Formula>,
    df: Option<&'a DataFrame>,
    // Explicit design matrix / response data source.
    X: Option<Array2<f64>>,
    y: Option<Array1<f64>>,
    // Include an intercept term (default: true).
    intercept: bool,
    // Standardize features before fitting (default: true; used by ridge only).
    standardize: bool,
}
538
539impl<'a> LinearModelBuilder<'a> {
540    /// Start building a model from formula and DataFrame
541    pub fn formula(formula: &'a Formula, df: &'a DataFrame) -> Self {
542        Self {
543            formula: Some(formula),
544            df: Some(df),
545            X: None,
546            y: None,
547            intercept: true,
548            standardize: true,
549        }
550    }
551
552    /// Start building a model from design matrix and response
553    pub fn matrix(X: Array2<f64>, y: Array1<f64>) -> Self {
554        Self {
555            formula: None,
556            df: None,
557            X: Some(X),
558            y: Some(y),
559            intercept: true,
560            standardize: true,
561        }
562    }
563
564    /// Disable intercept
565    pub fn no_intercept(mut self) -> Self {
566        self.intercept = false;
567        self
568    }
569
570    /// Disable feature standardization
571    pub fn no_standardize(mut self) -> Self {
572        self.standardize = false;
573        self
574    }
575
576    /// Fit OLS model
577    pub fn ols(self) -> Result<LinearRegressionResults> {
578        let (X, y) = self.prepare_data()?;
579
580        let mut model = OLS::new();
581        if !self.intercept {
582            model = OLS::no_intercept();
583        }
584
585        model.fit(&X, &y)
586    }
587
588    /// Fit Ridge regression model
589    pub fn ridge(self, alpha: f64) -> Result<LinearRegressionResults> {
590        let (X, y) = self.prepare_data()?;
591
592        let mut model = Ridge::new(alpha);
593        model.intercept = self.intercept;
594        model.standardize = self.standardize;
595
596        model.fit(&X, &y)
597    }
598
599    fn prepare_data(&self) -> Result<(Array2<f64>, Array1<f64>)> {
600        if let (Some(formula), Some(df)) = (self.formula, self.df) {
601            let X = formula.build_matrix(df)?;
602            let y = formula.response_vector(df)?.ok_or_else(|| {
603                Error::Message("Formula must include response variable".to_string())
604            })?;
605            Ok((X, y))
606        } else if let (Some(X), Some(y)) = (&self.X, &self.y) {
607            Ok((X.clone(), y.clone()))
608        } else {
609            Err(Error::Message(
610                "No data provided for model fitting".to_string(),
611            ))
612        }
613    }
614}
615
616// ============================================================================
617// Tests
618// ============================================================================
619
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2};
    use so_core::data::{DataFrame, Series};
    use so_core::formula::Formula;
    use std::collections::HashMap;

    #[test]
    fn test_ols_basic() {
        // Simple linear relationship: y = 2 + 3*x (exact fit).
        let X = arr2(&[[1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0], [1.0, 5.0]]);
        let y = arr1(&[5.0, 8.0, 11.0, 14.0, 17.0]); // 2 + 3*x

        let model = OLS::new();
        let results = model.fit(&X, &y).unwrap();

        // Coefficients should be close to [2, 3].
        assert!((results.coefficients[0] - 2.0).abs() < 0.001);
        assert!((results.coefficients[1] - 3.0).abs() < 0.001);
        assert!(results.r_squared > 0.99);
    }

    #[test]
    fn test_ridge_basic() {
        // Same data as the OLS test.
        let X = arr2(&[[1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0], [1.0, 5.0]]);
        let y = arr1(&[5.0, 8.0, 11.0, 14.0, 17.0]);

        let model = Ridge::new(0.1); // Small regularization
        let results = model.fit(&X, &y).unwrap();

        // With small alpha and standardization the fit should be near OLS,
        // but the penalty shrinks coefficients, so tolerance is wider.
        let intercept_diff = (results.coefficients[0] - 2.0).abs();
        let slope_diff = (results.coefficients[1] - 3.0).abs();

        assert!(
            intercept_diff < 0.25,
            "Intercept coefficient {} not close to 2.0 (diff={})",
            results.coefficients[0],
            intercept_diff
        );
        assert!(
            slope_diff < 0.25,
            "Slope coefficient {} not close to 3.0 (diff={})",
            results.coefficients[1],
            slope_diff
        );
    }

    #[test]
    fn test_model_builder() {
        let mut columns = HashMap::new();
        columns.insert("y".to_string(), Series::new("y", arr1(&[1.0, 2.0, 3.0])));
        columns.insert("x".to_string(), Series::new("x", arr1(&[1.0, 2.0, 3.0])));

        let df = DataFrame::from_series(columns).unwrap();
        let formula = Formula::parse("y ~ x").unwrap();

        let results = LinearModelBuilder::formula(&formula, &df).ols().unwrap();

        assert_eq!(results.coefficients.len(), 2); // intercept + x
        assert!((0.0..=1.0).contains(&results.r_squared));
    }
}
700}