// linreg_core/core.rs
1//! Core OLS regression implementation.
2//!
3//! This module provides the main Ordinary Least Squares regression functionality
4//! that can be used directly in Rust code. Functions accept native Rust slices
5//! and return Result types for proper error handling.
6//!
7//! # Example
8//!
9//! ```
10//! # use linreg_core::core::ols_regression;
11//! # use linreg_core::Error;
12//! let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
13//! let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
14//! let x2 = vec![2.0, 3.0, 3.5, 4.0, 4.5, 5.0];
15//! let names = vec![
16//!     "Intercept".to_string(),
17//!     "X1".to_string(),
18//!     "X2".to_string(),
19//! ];
20//!
21//! let result = ols_regression(&y, &[x1, x2], &names)?;
22//! # Ok::<(), linreg_core::Error>(())
23//! ```
24
25use crate::distributions::{fisher_snedecor_cdf, student_t_cdf, student_t_inverse_cdf};
26use crate::error::{Error, Result};
27use crate::linalg::{vec_dot, vec_mean, vec_sub, Matrix};
28use crate::serialization::types::ModelType;
29// Import the macro from crate root (where #[macro_export] places it)
30use crate::impl_serialization;
31use serde::{Deserialize, Serialize};
32
33// ============================================================================
34// Numerical Constants
35// ============================================================================
36
/// Minimum threshold for the standardized-residual denominator to avoid division by zero.
/// When (1 - leverage) is very small, the observation has extremely high leverage
/// and standardized residuals may be unreliable; such residuals are reported as 0.0.
const MIN_LEVERAGE_DENOM: f64 = 1e-10;
41
42// ============================================================================
43// Result Types
44// ============================================================================
45//
46// Structs containing the output of regression computations.
47
48/// Result of VIF (Variance Inflation Factor) calculation.
49///
50/// VIF measures how much the variance of an estimated regression coefficient
51/// increases due to multicollinearity among the predictors.
52///
53/// # Fields
54///
55/// * `variable` - Name of the predictor variable
56/// * `vif` - Variance Inflation Factor (VIF > 10 indicates high multicollinearity)
57/// * `rsquared` - R-squared from regressing this predictor on all others
58/// * `interpretation` - Human-readable interpretation of the VIF value
59///
60/// # Example
61///
62/// ```
63/// # use linreg_core::core::VifResult;
64/// let vif = VifResult {
65///     variable: "X1".to_string(),
66///     vif: 2.5,
67///     rsquared: 0.6,
68///     interpretation: "Low multicollinearity".to_string(),
69/// };
70/// assert_eq!(vif.variable, "X1");
71/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VifResult {
    /// Name of the predictor variable
    pub variable: String,
    /// Variance Inflation Factor (VIF > 10 indicates high multicollinearity).
    /// Clamped to a minimum of 1.0; infinite for constant or perfectly collinear predictors.
    pub vif: f64,
    /// R-squared from regressing this predictor on all others (computed as 1 - 1/VIF)
    pub rsquared: f64,
    /// Human-readable interpretation of the VIF value
    pub interpretation: String,
}
83
84/// Complete output from OLS regression.
85///
86/// Contains all coefficients, statistics, diagnostics, and residuals from
87/// an Ordinary Least Squares regression.
88///
89/// # Example
90///
91/// ```
92/// # use linreg_core::core::ols_regression;
93/// # use linreg_core::serialization::{ModelSave, ModelLoad};
94/// # use std::fs;
95/// let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
96/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
97/// let names = vec!["Intercept".to_string(), "X1".to_string()];
98///
99/// let result = ols_regression(&y, &[x1], &names).unwrap();
100/// assert!(result.r_squared > 0.0);
101/// assert_eq!(result.coefficients.len(), 2); // intercept + 1 predictor
102///
103/// // Save and load the model
104/// # let path = "test_model.json";
105/// # result.save(path).unwrap();
106/// # let loaded = linreg_core::core::RegressionOutput::load(path).unwrap();
107/// # assert_eq!(loaded.coefficients, result.coefficients);
108/// # fs::remove_file(path).ok();
109/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionOutput {
    /// Regression coefficients (including intercept); NaN for coefficients
    /// dropped due to rank deficiency
    pub coefficients: Vec<f64>,
    /// Standard errors of coefficients (NaN for dropped coefficients)
    pub std_errors: Vec<f64>,
    /// t-statistics for coefficient significance tests
    pub t_stats: Vec<f64>,
    /// Two-tailed p-values for coefficients
    pub p_values: Vec<f64>,
    /// Lower bounds of 95% confidence intervals
    pub conf_int_lower: Vec<f64>,
    /// Upper bounds of 95% confidence intervals
    pub conf_int_upper: Vec<f64>,
    /// R-squared (coefficient of determination)
    pub r_squared: f64,
    /// Adjusted R-squared (accounts for number of predictors)
    pub adj_r_squared: f64,
    /// F-statistic for overall model significance
    pub f_statistic: f64,
    /// P-value for F-statistic
    pub f_p_value: f64,
    /// Mean squared error of residuals (SSR / residual df)
    pub mse: f64,
    /// Root mean squared error; equal to `std_error` in this implementation
    pub rmse: f64,
    /// Mean absolute error of residuals
    pub mae: f64,
    /// Standard error of the regression (residual standard deviation)
    pub std_error: f64,
    /// Raw residuals (observed - predicted)
    pub residuals: Vec<f64>,
    /// Standardized residuals (accounting for leverage)
    pub standardized_residuals: Vec<f64>,
    /// Fitted/predicted values
    pub predictions: Vec<f64>,
    /// Leverage values for each observation (diagonal of hat matrix)
    pub leverage: Vec<f64>,
    /// Variance Inflation Factors for detecting multicollinearity
    pub vif: Vec<VifResult>,
    /// Number of observations
    pub n: usize,
    /// Number of predictor variables (excluding intercept)
    pub k: usize,
    /// Degrees of freedom for residuals: n minus the rank of the design
    /// matrix (equals n - k - 1 when the design matrix is full rank)
    pub df: usize,
    /// Names of variables (including intercept)
    pub variable_names: Vec<String>,
    /// Log-likelihood of the model (useful for AIC/BIC calculation and model comparison)
    pub log_likelihood: f64,
    /// Akaike Information Criterion (lower = better model, penalizes complexity)
    pub aic: f64,
    /// Bayesian Information Criterion (lower = better model, penalizes complexity more heavily than AIC)
    pub bic: f64,
}
165
166// ============================================================================
167// Statistical Helper Functions
168// ============================================================================
169//
170// Utility functions for computing p-values, critical values, and leverage.
171
172/// Computes a two-tailed p-value from a t-statistic.
173///
174/// Uses the Student's t-distribution CDF to calculate the probability
175/// of observing a t-statistic as extreme as the one provided.
176///
177/// # Arguments
178///
179/// * `t` - The t-statistic value
180/// * `df` - Degrees of freedom
181///
182/// # Returns
183///
184/// The two-tailed p-value in the range [0, 1].
185///
186/// # Example
187///
188/// ```
189/// # use linreg_core::core::two_tailed_p_value;
190/// let p = two_tailed_p_value(2.0, 20.0);
191/// assert!(p > 0.0 && p < 0.1);
192/// ```
193pub fn two_tailed_p_value(t: f64, df: f64) -> f64 {
194    if t.abs() > 100.0 {
195        return 0.0;
196    }
197
198    let cdf = student_t_cdf(t, df);
199    if t >= 0.0 {
200        2.0 * (1.0 - cdf)
201    } else {
202        2.0 * cdf
203    }
204}
205
206/// Computes the critical t-value for a given significance level and degrees of freedom.
207///
208/// Returns the t-value such that the area under the t-distribution curve
209/// to the right of it equals alpha/2 (two-tailed test).
210///
211/// # Arguments
212///
213/// * `df` - Degrees of freedom
214/// * `alpha` - Significance level (typically 0.05 for 95% confidence)
215///
216/// # Returns
217///
218/// The critical t-value for the given significance level.
219///
220/// # Example
221///
222/// ```
223/// # use linreg_core::core::t_critical_quantile;
224/// let t_crit = t_critical_quantile(20.0, 0.05);
225/// assert!(t_crit > 2.0); // approximately 2.086 for df=20, alpha=0.05
226/// ```
227pub fn t_critical_quantile(df: f64, alpha: f64) -> f64 {
228    let p = 1.0 - alpha / 2.0;
229    student_t_inverse_cdf(p, df)
230}
231
232/// Computes a p-value from an F-statistic.
233///
234/// Uses the F-distribution CDF to calculate the probability of observing
235/// an F-statistic as extreme as the one provided.
236///
237/// # Arguments
238///
239/// * `f_stat` - The F-statistic value
240/// * `df1` - Numerator degrees of freedom
241/// * `df2` - Denominator degrees of freedom
242///
243/// # Returns
244///
245/// The p-value in the range [0, 1].
246///
247/// # Example
248///
249/// ```
250/// # use linreg_core::core::f_p_value;
251/// let p = f_p_value(5.0, 2.0, 20.0);
252/// assert!(p > 0.0 && p < 0.05);
253/// ```
254pub fn f_p_value(f_stat: f64, df1: f64, df2: f64) -> f64 {
255    if f_stat <= 0.0 {
256        return 1.0;
257    }
258    1.0 - fisher_snedecor_cdf(f_stat, df1, df2)
259}
260
261/// Computes leverage values from the design matrix and its inverse.
262///
263/// Leverage measures how far an observation's predictor values are from
264/// the center of the predictor space. High leverage points can have
265/// disproportionate influence on the regression results.
266///
267/// # Arguments
268///
269/// * `x` - Design matrix (including intercept column)
270/// * `xtx_inv` - Inverse of X'X matrix
271///
272/// # Example
273///
274/// ```
275/// # use linreg_core::core::compute_leverage;
276/// # use linreg_core::linalg::Matrix;
277/// // Design matrix with intercept: [[1, 1], [1, 2], [1, 3]]
278/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
279/// let xtx = x.transpose().matmul(&x);
280/// let xtx_inv = xtx.invert().unwrap();
281///
282/// let leverage = compute_leverage(&x, &xtx_inv);
283/// assert_eq!(leverage.len(), 3);
284/// // Leverage values should sum to the number of parameters (2)
285/// assert!((leverage.iter().sum::<f64>() - 2.0).abs() < 0.01);
286/// ```
287#[allow(clippy::needless_range_loop)]
288pub fn compute_leverage(x: &Matrix, xtx_inv: &Matrix) -> Vec<f64> {
289    let n = x.rows;
290    let mut leverage = vec![0.0; n];
291    for i in 0..n {
292        // x_row is (1, cols)
293        // temp = x_row * xtx_inv (1, cols)
294        // lev = temp * x_row^T (1, 1)
295
296        // Manual row extraction and multiplication
297        let mut row_vec = vec![0.0; x.cols];
298        for j in 0..x.cols {
299            row_vec[j] = x.get(i, j);
300        }
301
302        let mut temp_vec = vec![0.0; x.cols];
303        for c in 0..x.cols {
304            let mut sum = 0.0;
305            for k in 0..x.cols {
306                sum += row_vec[k] * xtx_inv.get(k, c);
307            }
308            temp_vec[c] = sum;
309        }
310
311        leverage[i] = vec_dot(&temp_vec, &row_vec);
312    }
313    leverage
314}
315
316// ============================================================================
317// Model Selection Criteria
318// ============================================================================
319//
320// Log-likelihood, AIC, and BIC for model comparison.
321
/// Computes the log-likelihood of the OLS model.
///
/// Assuming normally distributed errors, the concentrated log-likelihood is:
///
/// ```text
/// log L = -n/2 * ln(2π) - n/2 * ln(SSR/n) - n/2
///       = -n/2 * ln(2π*SSR/n) - n/2
/// ```
///
/// where SSR is the sum of squared residuals and n is the number of
/// observations. This matches R's `logLik.lm()` implementation.
///
/// # Arguments
///
/// * `n` - Number of observations
/// * `mse` - Mean squared error (estimate of σ², but NOT directly used in formula)
/// * `ssr` - Sum of squared residuals
///
/// # Example
///
/// ```
/// # use linreg_core::core::log_likelihood;
/// let ll = log_likelihood(100, 4.5, 450.0);
/// assert!(ll < 0.0);  // Log-likelihood is negative for typical data
/// ```
pub fn log_likelihood(n: usize, _mse: f64, ssr: f64) -> f64 {
    let obs = n as f64;
    let tau = 2.0 * std::f64::consts::PI;

    // R's logLik.lm formula: -n/2 * log(2*pi*SSR/n) - n/2.
    -0.5 * obs * (tau * ssr / obs).ln() - obs / 2.0
}
355
/// Computes the Akaike Information Criterion (AIC).
///
/// AIC = 2k - 2logL
///
/// where k is the number of estimated parameters and logL is the log-likelihood.
/// Lower AIC indicates a better model, with a penalty for additional parameters.
///
/// Note: R's AIC for lm models counts k as (n_coef + 1) where n_coef is the
/// number of coefficients (including intercept) and +1 is for the estimated
/// variance parameter. This implementation follows that convention.
///
/// # Arguments
///
/// * `log_likelihood` - Log-likelihood of the model
/// * `n_coef` - Number of coefficients (including intercept)
///
/// # Example
///
/// ```
/// # use linreg_core::core::aic;
/// let aic_value = aic(-150.5, 3);  // 3 coefficients
/// ```
pub fn aic(log_likelihood: f64, n_coef: usize) -> f64 {
    // The residual-variance estimate counts as one extra parameter (R convention).
    let param_count = (n_coef + 1) as f64;
    2.0 * param_count - 2.0 * log_likelihood
}
384
/// Computes the Bayesian Information Criterion (BIC).
///
/// BIC = k*ln(n) - 2logL
///
/// where k is the number of estimated parameters, n is the sample size, and
/// logL is the log-likelihood. BIC penalizes model complexity more heavily
/// than AIC for larger sample sizes.
///
/// Note: R's BIC for lm models counts k as (n_coef + 1) where n_coef is the
/// number of coefficients (including intercept) and +1 is for the estimated
/// variance parameter. This implementation follows that convention.
///
/// # Arguments
///
/// * `log_likelihood` - Log-likelihood of the model
/// * `n_coef` - Number of coefficients (including intercept)
/// * `n_obs` - Number of observations
///
/// # Example
///
/// ```
/// # use linreg_core::core::bic;
/// let bic_value = bic(-150.5, 3, 100);  // 3 coefficients, 100 obs
/// ```
pub fn bic(log_likelihood: f64, n_coef: usize, n_obs: usize) -> f64 {
    // The residual-variance estimate counts as one extra parameter (R convention).
    let param_count = (n_coef + 1) as f64;
    param_count * (n_obs as f64).ln() - 2.0 * log_likelihood
}
415
/// Computes AIC using Python/statsmodels convention.
///
/// AIC = 2k - 2logL
///
/// where k is the number of coefficients (NOT including variance parameter).
/// This matches Python's statsmodels OLS.aic behavior.
///
/// # Arguments
///
/// * `log_likelihood` - Log-likelihood of the model
/// * `n_coef` - Number of coefficients (including intercept)
///
/// # Example
///
/// ```
/// # use linreg_core::core::aic_python;
/// let aic_value = aic_python(-150.5, 3);  // 3 coefficients
/// ```
pub fn aic_python(log_likelihood: f64, n_coef: usize) -> f64 {
    // statsmodels counts only the regression coefficients in k.
    let param_count = n_coef as f64;
    2.0 * param_count - 2.0 * log_likelihood
}
439
/// Computes BIC using Python/statsmodels convention.
///
/// BIC = k*ln(n) - 2logL
///
/// where k is the number of coefficients (NOT including variance parameter).
/// This matches Python's statsmodels OLS.bic behavior.
///
/// # Arguments
///
/// * `log_likelihood` - Log-likelihood of the model
/// * `n_coef` - Number of coefficients (including intercept)
/// * `n_obs` - Number of observations
///
/// # Example
///
/// ```
/// # use linreg_core::core::bic_python;
/// let bic_value = bic_python(-150.5, 3, 100);  // 3 coefficients, 100 obs
/// ```
pub fn bic_python(log_likelihood: f64, n_coef: usize, n_obs: usize) -> f64 {
    // statsmodels counts only the regression coefficients in k.
    let param_count = n_coef as f64;
    param_count * (n_obs as f64).ln() - 2.0 * log_likelihood
}
464
465// ============================================================================
466// VIF Calculation
467// ============================================================================
468//
469// Variance Inflation Factor analysis for detecting multicollinearity.
470
/// Calculates Variance Inflation Factors for all predictors.
///
/// VIF quantifies the severity of multicollinearity in a regression analysis.
/// A VIF > 10 indicates high multicollinearity that may need to be addressed.
/// Each VIF is read off the diagonal of the inverse correlation matrix of the
/// standardized predictors.
///
/// # Arguments
///
/// * `x_vars` - Predictor variables (each of length n)
/// * `names` - Variable names (including intercept as first element)
/// * `n` - Number of observations
///
/// # Returns
///
/// Vector of VIF results for each predictor (excluding intercept).
/// Empty when there are fewer than two predictors.
///
/// # Example
///
/// ```
/// # use linreg_core::core::calculate_vif;
/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let x2 = vec![2.0, 4.0, 6.0, 8.0, 10.0];
/// let names = vec!["Intercept".to_string(), "X1".to_string(), "X2".to_string()];
///
/// let vif_results = calculate_vif(&[x1, x2], &names, 5);
/// assert_eq!(vif_results.len(), 2);
/// // Perfectly collinear variables will have VIF = infinity
/// ```
pub fn calculate_vif(x_vars: &[Vec<f64>], names: &[String], n: usize) -> Vec<VifResult> {
    let k = x_vars.len();
    // VIF needs at least two predictors to measure anything.
    if k <= 1 {
        return vec![];
    }

    // Standardize predictors (Z-score); stored row-major as z_data[i * k + j].
    let mut z_data = vec![0.0; n * k];

    for (j, var) in x_vars.iter().enumerate() {
        let mean = vec_mean(var);
        // Sample variance (n - 1 denominator).
        let variance = var.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / ((n - 1) as f64);
        let std_dev = variance.sqrt();

        // Handle constant variables: standardization is undefined, so every
        // predictor is reported with an infinite VIF.
        if std_dev.abs() < 1e-10 {
            return names
                .iter()
                .skip(1)
                .map(|name| VifResult {
                    variable: name.clone(),
                    vif: f64::INFINITY,
                    rsquared: 1.0,
                    interpretation: "Constant variable (undefined correlation)".to_string(),
                })
                .collect();
        }

        for i in 0..n {
            z_data[i * k + j] = (var[i] - mean) / std_dev;
        }
    }

    let z = Matrix::new(n, k, z_data);

    // Correlation Matrix R = (1/(n-1)) * Z^T * Z
    let z_t = z.transpose();
    let zt_z = z_t.matmul(&z);

    // Scale by 1/(n-1)
    let mut r_corr = zt_z; // Copy
    let factor = 1.0 / ((n - 1) as f64);
    for val in &mut r_corr.data {
        *val *= factor;
    }

    // Invert R using QR on R_corr (since it's symmetric positive definite, Cholesky is better but QR works)
    // Or just generic inversion. We implemented generic inversion for Upper Triangular.
    // Let's use QR: A = QR => A^-1 = R^-1 Q^T
    let (q_corr, r_corr_tri) = r_corr.qr();

    let r_inv_opt = r_corr_tri.invert_upper_triangular();

    let r_inv = match r_inv_opt {
        Some(inv) => inv.matmul(&q_corr.transpose()),
        None => {
            // Singular correlation matrix means perfect multicollinearity.
            return names
                .iter()
                .skip(1)
                .map(|name| VifResult {
                    variable: name.clone(),
                    vif: f64::INFINITY,
                    rsquared: 1.0,
                    interpretation: "Perfect multicollinearity (singular matrix)".to_string(),
                })
                .collect();
        },
    };

    // Extract diagonals: VIF_j is the j-th diagonal entry of R^{-1}.
    let mut results = vec![];
    for j in 0..k {
        let vif = r_inv.get(j, j);
        // Numerical noise can push the diagonal slightly below 1; clamp it.
        let vif = if vif < 1.0 { 1.0 } else { vif };
        let rsquared = 1.0 - 1.0 / vif;

        let interpretation = if vif.is_infinite() {
            "Perfect multicollinearity".to_string()
        } else if vif > 10.0 {
            "High multicollinearity - consider removing".to_string()
        } else if vif > 5.0 {
            "Moderate multicollinearity".to_string()
        } else {
            "Low multicollinearity".to_string()
        };

        // names[0] is the intercept, so predictor j maps to names[j + 1].
        results.push(VifResult {
            variable: names[j + 1].clone(),
            vif,
            rsquared,
            interpretation,
        });
    }

    results
}
594
595// ============================================================================
596// OLS Regression
597// ============================================================================
598//
599// Ordinary Least Squares regression implementation using QR decomposition.
600
601/// Performs Ordinary Least Squares regression using QR decomposition.
602///
603/// Uses a numerically stable QR decomposition approach to solve the normal
604/// equations. Returns comprehensive statistics including coefficients,
605/// standard errors, t-statistics, p-values, and diagnostic measures.
606///
607/// # Arguments
608///
609/// * `y` - Response variable (n observations)
610/// * `x_vars` - Predictor variables (each of length n)
611/// * `variable_names` - Names for variables (including intercept)
612///
613/// # Returns
614///
615/// A [`RegressionOutput`] containing all regression statistics and diagnostics.
616///
617/// # Errors
618///
619/// Returns [`Error::InsufficientData`] if n ≤ k + 1.
620/// Returns [`Error::SingularMatrix`] if perfect multicollinearity exists.
621/// Returns [`Error::InvalidInput`] if coefficients are NaN.
622///
623/// # Example
624///
625/// ```
626/// # use linreg_core::core::ols_regression;
627/// # use linreg_core::Error;
628/// let y = vec![2.5, 3.7, 4.2, 5.1, 6.3, 7.0];
629/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
630/// let x2 = vec![2.0, 4.0, 5.0, 4.0, 3.0, 2.0];
631/// let names = vec![
632///     "Intercept".to_string(),
633///     "Temperature".to_string(),
634///     "Pressure".to_string(),
635/// ];
636///
637/// let result = ols_regression(&y, &[x1, x2], &names)?;
638/// println!("R-squared: {}", result.r_squared);
639/// # Ok::<(), linreg_core::Error>(())
640/// ```
#[allow(clippy::needless_range_loop)]
pub fn ols_regression(
    y: &[f64],
    x_vars: &[Vec<f64>],
    variable_names: &[String],
) -> Result<RegressionOutput> {
    let n = y.len();
    let k = x_vars.len();
    // p = number of design-matrix columns (intercept + k predictors).
    let p = k + 1;

    // Validate inputs: need strictly more observations than parameters.
    if n <= k + 1 {
        return Err(Error::InsufficientData {
            required: k + 2,
            available: n,
        });
    }

    // Validate dimensions and finite values using shared helper
    crate::diagnostics::validate_regression_data(y, x_vars)?;

    // Prepare variable names; pad with generated "X<i>" names if too few given.
    let mut names = variable_names.to_vec();
    while names.len() <= k {
        names.push(format!("X{}", names.len()));
    }

    // Create design matrix (row-major; column 0 is the intercept).
    let mut x_data = vec![0.0; n * p];
    for (row, _yi) in y.iter().enumerate() {
        x_data[row * p] = 1.0; // intercept
        for (col, x_var) in x_vars.iter().enumerate() {
            x_data[row * p + col + 1] = x_var[row];
        }
    }

    let x_matrix = Matrix::new(n, p, x_data);

    // QR Decomposition with column pivoting (LINPACK)
    // This handles rank-deficient matrices gracefully
    let qr_result = x_matrix.qr_linpack(None);
    let rank = qr_result.rank;

    // Solve for coefficients using the pivoted QR
    // For rank-deficient cases, dropped coefficients are set to NAN
    let beta = match x_matrix.qr_solve_linpack(&qr_result, y) {
        Some(b) => b,
        None => return Err(Error::SingularMatrix),
    };

    // Identify which coefficients are active (non-NAN) vs dropped
    let active: Vec<bool> = beta.iter().map(|b| !b.is_nan()).collect();
    let n_active = active.iter().filter(|&&a| a).count();

    // For predictions, treat NAN coefficients as 0
    let beta_for_pred: Vec<f64> = beta.iter().map(|&b| if b.is_nan() { 0.0 } else { b }).collect();
    let predictions = x_matrix.mul_vec(&beta_for_pred);
    let residuals = vec_sub(y, &predictions);

    // Degrees of freedom based on rank (equals n - k - 1 when X is full rank)
    let df = n - rank;

    // Compute sums of squares
    let y_mean = vec_mean(y);
    let ss_total: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
    let ss_residual: f64 = residuals.iter().map(|&r| r.powi(2)).sum();
    let ss_regression = ss_total - ss_residual;

    // R-squared and adjusted R-squared
    // For rank-deficient cases, use rank-1 as the number of effective predictors
    let k_effective = rank.saturating_sub(1); // rank minus intercept
    let r_squared = if ss_total == 0.0 {
        f64::NAN
    } else {
        1.0 - ss_residual / ss_total
    };

    let adj_r_squared = if ss_total == 0.0 || df == 0 {
        f64::NAN
    } else {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / df as f64)
    };

    // Mean squared error and standard error
    let mse = if df > 0 { ss_residual / df as f64 } else { f64::NAN };
    let std_error = mse.sqrt();

    // Standard errors, t-stats, p-values, confidence intervals
    // For active coefficients: compute from (X'X)^{-1} of the active subset
    // For dropped coefficients: NAN (matching R's behavior)
    let mut std_errors = vec![f64::NAN; p];
    let mut t_stats = vec![f64::NAN; p];
    let mut p_values = vec![f64::NAN; p];

    // Build (X'X)^{-1} for the full model using the LINPACK R matrix
    // Extract the rank×rank upper triangular R and invert it
    let mut xtx_inv = Matrix::zeros(p, p);
    if rank > 0 && df > 0 {
        // Extract the rank×rank R sub-matrix (in pivoted order)
        let mut r_sub = Matrix::zeros(rank, rank);
        for i in 0..rank {
            for j in i..rank {
                r_sub.set(i, j, qr_result.qr.get(i, j));
            }
        }

        // Invert the R sub-matrix
        if let Some(r_inv_sub) = r_sub.invert_upper_triangular() {
            // Compute (R^{-1})(R^{-1})^T in pivoted space
            let xtx_inv_pivoted = r_inv_sub.matmul(&r_inv_sub.transpose());

            // Map back to original column order via pivot permutation
            for i in 0..rank {
                let orig_i = qr_result.pivot[i] - 1; // Convert to 0-based
                for j in 0..rank {
                    let orig_j = qr_result.pivot[j] - 1;
                    xtx_inv.set(orig_i, orig_j, xtx_inv_pivoted.get(i, j));
                }
            }

            // Compute standard errors for active coefficients
            for col in 0..p {
                if active[col] {
                    let se = (xtx_inv.get(col, col) * mse).sqrt();
                    if !se.is_nan() {
                        std_errors[col] = se;
                        t_stats[col] = beta[col] / se;
                        p_values[col] = two_tailed_p_value(t_stats[col], df as f64);
                    }
                }
            }
        }
    }

    // Confidence intervals (95%: alpha = 0.05)
    let alpha = 0.05;
    let t_critical = if df > 0 { t_critical_quantile(df as f64, alpha) } else { f64::NAN };

    let conf_int_lower: Vec<f64> = beta
        .iter()
        .zip(&std_errors)
        .map(|(&b, &se)| if b.is_nan() || se.is_nan() { f64::NAN } else { b - t_critical * se })
        .collect();
    let conf_int_upper: Vec<f64> = beta
        .iter()
        .zip(&std_errors)
        .map(|(&b, &se)| if b.is_nan() || se.is_nan() { f64::NAN } else { b + t_critical * se })
        .collect();

    // Leverage (using the rank-restricted (X'X)^{-1})
    let leverage = compute_leverage(&x_matrix, &xtx_inv);

    // Standardized residuals: r_i / (s * sqrt(1 - h_i)), guarded against a
    // near-zero denominator (extreme leverage) in which case 0.0 is reported.
    let residuals_vec = residuals.clone();
    let standardized_residuals: Vec<f64> = residuals_vec
        .iter()
        .zip(&leverage)
        .map(|(&r, &h)| {
            let factor = (1.0 - h).max(MIN_LEVERAGE_DENOM).sqrt();
            let denom = std_error * factor;
            if denom > MIN_LEVERAGE_DENOM {
                r / denom
            } else {
                0.0
            }
        })
        .collect();

    // F-statistic (uses effective number of predictors)
    let f_statistic = if k_effective > 0 && df > 0 {
        (ss_regression / k_effective as f64) / mse
    } else {
        f64::NAN
    };
    let f_p_value = if f_statistic.is_nan() {
        f64::NAN
    } else {
        f_p_value(f_statistic, k_effective as f64, df as f64)
    };

    // RMSE and MAE
    let rmse = std_error;
    let mae: f64 = residuals_vec.iter().map(|&r| r.abs()).sum::<f64>() / n as f64;

    // VIF
    let vif = calculate_vif(x_vars, &names, n);

    // Model selection criteria (for model comparison)
    let ll = log_likelihood(n, mse, ss_residual);
    let n_coef = n_active;  // only active coefficients
    let aic_val = aic(ll, n_coef);
    let bic_val = bic(ll, n_coef, n);

    Ok(RegressionOutput {
        coefficients: beta,
        std_errors,
        t_stats,
        p_values,
        conf_int_lower,
        conf_int_upper,
        r_squared,
        adj_r_squared,
        f_statistic,
        f_p_value,
        mse,
        rmse,
        mae,
        std_error,
        residuals: residuals_vec,
        standardized_residuals,
        predictions,
        leverage,
        vif,
        n,
        k,
        df,
        variable_names: names,
        log_likelihood: ll,
        aic: aic_val,
        bic: bic_val,
    })
}
863
864// ============================================================================
865// Model Serialization Traits
866// ============================================================================
867
// Generate ModelSave and ModelLoad implementations for RegressionOutput via
// the crate's shared serialization macro (tagged as an OLS model).
impl_serialization!(RegressionOutput, ModelType::OLS, "OLS");
870
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aic_bic_formulas_known_values() {
        // Verify the R-convention formulas against hand-computed values.
        let log_lik = -100.0;
        let coef_count = 3; // intercept + two predictors
        let obs_count = 100;

        // R counts the error variance as an extra parameter: k = coef_count + 1.
        // AIC = 2k - 2*logL = 2*4 + 200 = 208
        assert!((aic(log_lik, coef_count) - 208.0).abs() < 1e-10);

        // BIC = k*ln(n) - 2*logL = 4*ln(100) + 200
        let want = 4.0 * (100.0_f64).ln() + 200.0;
        assert!((bic(log_lik, coef_count, obs_count) - want).abs() < 1e-10);
    }

    #[test]
    fn test_bic_greater_than_aic_for_reasonable_n() {
        // Both criteria share the -2*logL term; only the penalties differ:
        // AIC adds 2k, BIC adds k*ln(n). Whenever ln(n) > 2 (n >= 8), BIC > AIC.
        let log_lik = -50.0;
        let coef_count = 2;

        let a = aic(log_lik, coef_count);
        let b = bic(log_lik, coef_count, 100); // ln(100) ≈ 4.6 > 2

        assert!(b > a);
    }

    #[test]
    fn test_log_likelihood_returns_finite() {
        // Valid inputs must never produce NaN or infinity.
        let obs = 10;
        let mean_sq_err = 4.0;
        let sum_sq_resid = mean_sq_err * (obs - 2) as f64;

        assert!(log_likelihood(obs, mean_sq_err, sum_sq_resid).is_finite());
    }

    #[test]
    fn test_log_likelihood_increases_with_better_fit() {
        // A smaller residual sum of squares means a better fit, which must
        // translate into a larger (less negative) log-likelihood.
        let obs = 10;

        let poor_fit = log_likelihood(obs, 10.0, 80.0); // larger residuals
        let good_fit = log_likelihood(obs, 2.0, 16.0); // smaller residuals

        assert!(good_fit > poor_fit);
    }

    #[test]
    fn test_model_selection_criteria_present_in_output() {
        // Smoke test: a full OLS fit populates all three criteria.
        let y = vec![2.0, 4.0, 5.0, 4.0, 5.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let names: Vec<String> = vec!["Intercept".into(), "X1".into()];

        let fit = ols_regression(&y, &[x1], &names).unwrap();

        assert!(fit.log_likelihood.is_finite());
        assert!(fit.aic.is_finite());
        assert!(fit.bic.is_finite());

        // For typical data the (negative, bounded) log-likelihood makes
        // both criteria come out positive.
        assert!(fit.aic > 0.0);
        assert!(fit.bic > 0.0);
    }

    #[test]
    fn test_regression_output_has_correct_dimensions() {
        // AIC/BIC must be computed with k = n_coef + 1 (variance parameter
        // included, following R's convention).
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0];
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x2 = vec![3.0, 2.0, 4.0, 3.0, 5.0, 4.0, 6.0, 5.0];
        let names: Vec<String> = ["Intercept", "X1", "X2"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let out = ols_regression(&y, &[x1, x2], &names).unwrap();

        // Three coefficients (intercept + 2 predictors) plus one variance
        // parameter.
        let k = 3 + 1;

        // Recompute both criteria from the reported log-likelihood.
        let aic_from_ll = 2.0 * k as f64 - 2.0 * out.log_likelihood;
        assert!((out.aic - aic_from_ll).abs() < 1e-10);

        let bic_from_ll = k as f64 * (out.n as f64).ln() - 2.0 * out.log_likelihood;
        assert!((out.bic - bic_from_ll).abs() < 1e-10);
    }

    #[test]
    fn test_aic_python_convention() {
        // statsmodels counts only the coefficients: k = n_coef.
        // AIC = 2*3 - 2*(-100) = 206
        assert!((aic_python(-100.0, 3) - 206.0).abs() < 1e-10);
    }

    #[test]
    fn test_bic_python_convention() {
        // statsmodels: BIC = n_coef*ln(n) - 2*logL = 3*ln(100) + 200
        let want = 3.0 * (100.0_f64).ln() + 200.0;
        assert!((bic_python(-100.0, 3, 100) - want).abs() < 1e-10);
    }

    #[test]
    fn test_python_aic_smaller_than_r_aic() {
        // R's k includes the variance parameter (k = n_coef + 1), Python's
        // does not, so the two AICs differ by exactly 2 for any inputs.
        let log_lik = -50.0;
        let coef_count = 2;

        assert_eq!(aic(log_lik, coef_count) - aic_python(log_lik, coef_count), 2.0);
    }

    #[test]
    fn test_log_likelihood_formula_matches_r() {
        // R's logLik.lm(): logL = -n/2 * ln(2*pi*SSR/n) - n/2.
        let obs = 100;
        let sum_sq_resid = 450.0;
        let mean_sq_err = sum_sq_resid / (obs as f64 - 2.0); // 2 parameters

        let got = log_likelihood(obs, mean_sq_err, sum_sq_resid);

        // Evaluate the reference formula directly.
        let n = obs as f64;
        let two_pi = 2.0 * std::f64::consts::PI;
        let expected = -0.5 * n * (two_pi * sum_sq_resid / n).ln() - n / 2.0;

        assert!((got - expected).abs() < 1e-10);
    }

    #[test]
    fn test_aic_bic_with_perfect_fit() {
        // Near-perfect fit edge case: SSR is tiny but nonzero to avoid log(0).
        let obs = 10;
        let sum_sq_resid = 0.001;
        let mean_sq_err = sum_sq_resid / (obs as f64 - 2.0);

        let ll = log_likelihood(obs, mean_sq_err, sum_sq_resid);

        // An excellent fit drives the log-likelihood above zero...
        assert!(ll > 0.0);
        // ...and the criteria stay finite (they may be negative, since the
        // fit term can dominate the complexity penalty).
        assert!(aic(ll, 2).is_finite());
        assert!(bic(ll, 2, obs).is_finite());
    }

    #[test]
    fn test_aic_bic_model_selection() {
        // Compare a parsimonious model against a slightly better-fitting but
        // more complex one fit to the same data.
        let obs = 100;

        let ll_simple = -150.0; // 2 parameters
        let ll_complex = -148.0; // 5 parameters, marginally better fit

        // With k = n_coef + 1: AIC_simple = 2*3 + 300 = 306 versus
        // AIC_complex = 2*6 + 296 = 308, so the simple model wins.
        assert!(aic(ll_simple, 2) < aic(ll_complex, 5));

        // BIC's ln(n) penalty punishes extra parameters even harder, so the
        // simple model must win here too.
        assert!(bic(ll_simple, 2, obs) < bic(ll_complex, 5, obs));
    }

    #[test]
    fn test_log_likelihood_scale_invariance() {
        // With SSR/n held fixed, logL = -n/2 * ln(2*pi*SSR/n) - n/2 grows
        // linearly more negative in n, so the per-observation value is stable.
        let per_obs_ssr = 1.0;

        let small_n = 50;
        let small_ssr = per_obs_ssr * small_n as f64;
        let ll_small = log_likelihood(small_n, small_ssr / (small_n as f64 - 2.0), small_ssr);

        let large_n = 100;
        let large_ssr = per_obs_ssr * large_n as f64;
        let ll_large = log_likelihood(large_n, large_ssr / (large_n as f64 - 2.0), large_ssr);

        // Larger sample ⇒ more negative total log-likelihood.
        assert!(ll_large < ll_small);

        // Normalizing by n brings the two within a small tolerance.
        let per_obs_small = ll_small / small_n as f64;
        let per_obs_large = ll_large / large_n as f64;
        assert!((per_obs_small - per_obs_large).abs() < 0.1);
    }

    #[test]
    fn test_regularized_regression_has_model_selection_criteria() {
        // Ridge fits must also expose log_likelihood/AIC/BIC.
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let design = crate::linalg::Matrix::new(
            5,
            2,
            vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0],
        );

        let opts = crate::regularized::ridge::RidgeFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: false,
            ..Default::default()
        };

        let fit = crate::regularized::ridge::ridge_fit(&design, &y, &opts).unwrap();

        assert!(fit.log_likelihood.is_finite());
        assert!(fit.aic.is_finite());
        assert!(fit.bic.is_finite());
    }

    #[test]
    fn test_elastic_net_regression_has_model_selection_criteria() {
        // Elastic Net fits must also expose log_likelihood/AIC/BIC.
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let design = crate::linalg::Matrix::new(
            5,
            2,
            vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0],
        );

        let opts = crate::regularized::elastic_net::ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.5,
            intercept: true,
            standardize: false,
            ..Default::default()
        };

        let fit = crate::regularized::elastic_net::elastic_net_fit(&design, &y, &opts).unwrap();

        assert!(fit.log_likelihood.is_finite());
        assert!(fit.aic.is_finite());
        assert!(fit.bic.is_finite());
    }
}