Skip to main content

linreg_core/
core.rs

1//! Core OLS regression implementation.
2//!
3//! This module provides the main Ordinary Least Squares regression functionality
4//! that can be used directly in Rust code. Functions accept native Rust slices
5//! and return Result types for proper error handling.
6//!
7//! # Example
8//!
9//! ```
10//! # use linreg_core::core::ols_regression;
11//! # use linreg_core::Error;
12//! let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
13//! let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
14//! let x2 = vec![2.0, 3.0, 3.5, 4.0, 4.5, 5.0];
15//! let names = vec![
16//!     "Intercept".to_string(),
17//!     "X1".to_string(),
18//!     "X2".to_string(),
19//! ];
20//!
21//! let result = ols_regression(&y, &[x1, x2], &names)?;
22//! # Ok::<(), Error>(())
23//! ```
24
25use crate::distributions::{fisher_snedecor_cdf, student_t_cdf, student_t_inverse_cdf};
26use crate::error::{Error, Result};
27use crate::linalg::{vec_dot, vec_mean, vec_sub, Matrix};
28use serde::Serialize;
29
30// ============================================================================
31// Numerical Constants
32// ============================================================================
33
/// Numerical floor used when computing standardized residuals.
///
/// In `ols_regression` this constant plays two roles: `(1 - leverage)` is
/// clamped to at least this value before its square root is taken, and a
/// standardized residual is reported as `0.0` whenever
/// `std_error * sqrt(1 - leverage)` does not exceed it. Both guards avoid
/// dividing by (near) zero, e.g. for observations whose leverage is close to 1,
/// where standardized residuals would be unreliable anyway.
const MIN_LEVERAGE_DENOM: f64 = 1e-10;
38
39// ============================================================================
40// Result Types
41// ============================================================================
42//
43// Structs containing the output of regression computations.
44
/// Variance Inflation Factor (VIF) diagnostics for a single predictor.
///
/// VIF measures how much the variance of an estimated regression coefficient
/// is inflated by multicollinearity among the predictors. One `VifResult` is
/// produced per predictor by `calculate_vif`.
///
/// # Fields
///
/// * `variable` - Name of the predictor variable
/// * `vif` - Variance Inflation Factor (VIF > 10 indicates high multicollinearity)
/// * `rsquared` - R-squared from regressing this predictor on all others
/// * `interpretation` - Human-readable interpretation of the VIF value
///
/// # Example
///
/// ```
/// # use linreg_core::core::VifResult;
/// let vif = VifResult {
///     variable: "X1".to_string(),
///     vif: 2.5,
///     rsquared: 0.6,
///     interpretation: "Low multicollinearity".to_string(),
/// };
/// assert_eq!(vif.variable, "X1");
/// ```
#[derive(Debug, Clone, Serialize)]
pub struct VifResult {
    /// Name of the predictor variable this entry describes.
    pub variable: String,
    /// Variance Inflation Factor. `calculate_vif` clamps computed values to at
    /// least 1.0 and reports `f64::INFINITY` for constant predictors or a
    /// singular correlation matrix. Values above 10 are labelled as high
    /// multicollinearity, values above 5 as moderate.
    pub vif: f64,
    /// R-squared from regressing this predictor on all others. `calculate_vif`
    /// derives it from the VIF as `1 - 1 / vif` (and uses 1.0 in the
    /// degenerate, infinite-VIF cases).
    pub rsquared: f64,
    /// Human-readable label for the VIF value (e.g. "Low multicollinearity").
    pub interpretation: String,
}
80
/// Complete output from OLS regression.
///
/// Bundles every quantity computed by `ols_regression`: coefficient estimates
/// with their inference statistics, goodness-of-fit measures, per-observation
/// diagnostics, and model-selection criteria. All per-coefficient vectors are
/// ordered intercept first, followed by the predictors in the order they were
/// supplied.
///
/// # Example
///
/// ```
/// # use linreg_core::core::ols_regression;
/// let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let names = vec!["Intercept".to_string(), "X1".to_string()];
///
/// let result = ols_regression(&y, &[x1], &names).unwrap();
/// assert!(result.r_squared > 0.0);
/// assert_eq!(result.coefficients.len(), 2); // intercept + 1 predictor
/// ```
#[derive(Debug, Clone, Serialize)]
pub struct RegressionOutput {
    /// Regression coefficients (intercept first, then one per predictor)
    pub coefficients: Vec<f64>,
    /// Standard errors of the coefficients: sqrt(diag((X'X)^-1) * mse)
    pub std_errors: Vec<f64>,
    /// t-statistics for coefficient significance tests (coefficient / std error)
    pub t_stats: Vec<f64>,
    /// Two-tailed p-values for the coefficients, using `df` degrees of freedom
    pub p_values: Vec<f64>,
    /// Lower bounds of 95% confidence intervals (alpha is fixed at 0.05)
    pub conf_int_lower: Vec<f64>,
    /// Upper bounds of 95% confidence intervals (alpha is fixed at 0.05)
    pub conf_int_upper: Vec<f64>,
    /// R-squared (coefficient of determination); NaN when the response has
    /// zero total sum of squares
    pub r_squared: f64,
    /// Adjusted R-squared (accounts for number of predictors); NaN when the
    /// response has zero total sum of squares
    pub adj_r_squared: f64,
    /// F-statistic for overall model significance: (SS_regression / k) / mse
    pub f_statistic: f64,
    /// P-value for F-statistic (upper tail, with k and df degrees of freedom)
    pub f_p_value: f64,
    /// Mean squared error of residuals: residual sum of squares divided by `df`
    pub mse: f64,
    /// Root mean squared error in the units of the response. As produced by
    /// `ols_regression` this is set equal to `std_error` (i.e. it divides by
    /// `df`, not by `n`).
    pub rmse: f64,
    /// Mean absolute error of residuals (sum of |residual| divided by `n`)
    pub mae: f64,
    /// Standard error of the regression (residual standard deviation, sqrt(mse))
    pub std_error: f64,
    /// Raw residuals (observed - predicted)
    pub residuals: Vec<f64>,
    /// Standardized residuals: residual / (std_error * sqrt(1 - leverage)),
    /// reported as 0.0 when that denominator is numerically zero
    pub standardized_residuals: Vec<f64>,
    /// Fitted/predicted values
    pub predictions: Vec<f64>,
    /// Leverage values for each observation (diagonal of hat matrix)
    pub leverage: Vec<f64>,
    /// Variance Inflation Factors for detecting multicollinearity; empty when
    /// the model has fewer than two predictors
    pub vif: Vec<VifResult>,
    /// Number of observations
    pub n: usize,
    /// Number of predictor variables (excluding intercept)
    pub k: usize,
    /// Degrees of freedom for residuals (n - k - 1)
    pub df: usize,
    /// Names of variables (including intercept); missing names are filled in
    /// as `X<index>` by `ols_regression`
    pub variable_names: Vec<String>,
    /// Log-likelihood of the model (useful for AIC/BIC calculation and model comparison)
    pub log_likelihood: f64,
    /// Akaike Information Criterion (lower = better model, penalizes complexity);
    /// computed with the R convention that counts the variance parameter
    pub aic: f64,
    /// Bayesian Information Criterion (lower = better model, penalizes complexity
    /// more heavily than AIC); computed with the R convention that counts the
    /// variance parameter
    pub bic: f64,
}
153
154// ============================================================================
155// Statistical Helper Functions
156// ============================================================================
157//
158// Utility functions for computing p-values, critical values, and leverage.
159
160/// Computes a two-tailed p-value from a t-statistic.
161///
162/// Uses the Student's t-distribution CDF to calculate the probability
163/// of observing a t-statistic as extreme as the one provided.
164///
165/// # Arguments
166///
167/// * `t` - The t-statistic value
168/// * `df` - Degrees of freedom
169///
170/// # Example
171///
172/// ```
173/// # use linreg_core::core::two_tailed_p_value;
174/// let p = two_tailed_p_value(2.0, 20.0);
175/// assert!(p > 0.0 && p < 0.1);
176/// ```
177pub fn two_tailed_p_value(t: f64, df: f64) -> f64 {
178    if t.abs() > 100.0 {
179        return 0.0;
180    }
181
182    let cdf = student_t_cdf(t, df);
183    if t >= 0.0 {
184        2.0 * (1.0 - cdf)
185    } else {
186        2.0 * cdf
187    }
188}
189
190/// Computes the critical t-value for a given significance level and degrees of freedom.
191///
192/// Returns the t-value such that the area under the t-distribution curve
193/// to the right of it equals alpha/2 (two-tailed test).
194///
195/// # Arguments
196///
197/// * `df` - Degrees of freedom
198/// * `alpha` - Significance level (typically 0.05 for 95% confidence)
199///
200/// # Example
201///
202/// ```
203/// # use linreg_core::core::t_critical_quantile;
204/// let t_crit = t_critical_quantile(20.0, 0.05);
205/// assert!(t_crit > 2.0); // approximately 2.086 for df=20, alpha=0.05
206/// ```
207pub fn t_critical_quantile(df: f64, alpha: f64) -> f64 {
208    let p = 1.0 - alpha / 2.0;
209    student_t_inverse_cdf(p, df)
210}
211
212/// Computes a p-value from an F-statistic.
213///
214/// Uses the F-distribution CDF to calculate the probability of observing
215/// an F-statistic as extreme as the one provided.
216///
217/// # Arguments
218///
219/// * `f_stat` - The F-statistic value
220/// * `df1` - Numerator degrees of freedom
221/// * `df2` - Denominator degrees of freedom
222///
223/// # Example
224///
225/// ```
226/// # use linreg_core::core::f_p_value;
227/// let p = f_p_value(5.0, 2.0, 20.0);
228/// assert!(p > 0.0 && p < 0.05);
229/// ```
230pub fn f_p_value(f_stat: f64, df1: f64, df2: f64) -> f64 {
231    if f_stat <= 0.0 {
232        return 1.0;
233    }
234    1.0 - fisher_snedecor_cdf(f_stat, df1, df2)
235}
236
237/// Computes leverage values from the design matrix and its inverse.
238///
239/// Leverage measures how far an observation's predictor values are from
240/// the center of the predictor space. High leverage points can have
241/// disproportionate influence on the regression results.
242///
243/// # Arguments
244///
245/// * `x` - Design matrix (including intercept column)
246/// * `xtx_inv` - Inverse of X'X matrix
247///
248/// # Example
249///
250/// ```
251/// # use linreg_core::core::compute_leverage;
252/// # use linreg_core::linalg::Matrix;
253/// // Design matrix with intercept: [[1, 1], [1, 2], [1, 3]]
254/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
255/// let xtx = x.transpose().matmul(&x);
256/// let xtx_inv = xtx.invert().unwrap();
257///
258/// let leverage = compute_leverage(&x, &xtx_inv);
259/// assert_eq!(leverage.len(), 3);
260/// // Leverage values should sum to the number of parameters (2)
261/// assert!((leverage.iter().sum::<f64>() - 2.0).abs() < 0.01);
262/// ```
263#[allow(clippy::needless_range_loop)]
264pub fn compute_leverage(x: &Matrix, xtx_inv: &Matrix) -> Vec<f64> {
265    let n = x.rows;
266    let mut leverage = vec![0.0; n];
267    for i in 0..n {
268        // x_row is (1, cols)
269        // temp = x_row * xtx_inv (1, cols)
270        // lev = temp * x_row^T (1, 1)
271
272        // Manual row extraction and multiplication
273        let mut row_vec = vec![0.0; x.cols];
274        for j in 0..x.cols {
275            row_vec[j] = x.get(i, j);
276        }
277
278        let mut temp_vec = vec![0.0; x.cols];
279        for c in 0..x.cols {
280            let mut sum = 0.0;
281            for k in 0..x.cols {
282                sum += row_vec[k] * xtx_inv.get(k, c);
283            }
284            temp_vec[c] = sum;
285        }
286
287        leverage[i] = vec_dot(&temp_vec, &row_vec);
288    }
289    leverage
290}
291
292// ============================================================================
293// Model Selection Criteria
294// ============================================================================
295//
296// Log-likelihood, AIC, and BIC for model comparison.
297
/// Evaluates the maximized Gaussian log-likelihood of an OLS fit.
///
/// With normally distributed errors and the maximum-likelihood variance
/// estimate `SSR / n`, the log-likelihood is
///
/// ```text
/// log L = -n/2 * ln(2π) - n/2 * ln(SSR/n) - n/2
///       = -n/2 * ln(2π*SSR/n) - n/2
/// ```
///
/// where SSR is the sum of squared residuals and n is the number of
/// observations. This is the formula used by R's `logLik.lm()`.
///
/// # Arguments
///
/// * `n` - Number of observations
/// * `_mse` - Mean squared error; accepted for call-site convenience but not
///   used, because the formula relies on `SSR / n` rather than `SSR / df`
/// * `ssr` - Sum of squared residuals
///
/// # Example
///
/// ```
/// # use linreg_core::core::log_likelihood;
/// let ll = log_likelihood(100, 4.5, 450.0);
/// assert!(ll < 0.0);  // Log-likelihood is negative for typical data
/// ```
pub fn log_likelihood(n: usize, _mse: f64, ssr: f64) -> f64 {
    let obs = n as f64;
    let half_obs = obs / 2.0;

    // 2π * σ̂², with σ̂² = SSR / n being the ML estimate of the error variance.
    let scaled_variance = 2.0 * std::f64::consts::PI * ssr / obs;

    -half_obs * scaled_variance.ln() - half_obs
}
331
/// Akaike Information Criterion (AIC), following R's convention.
///
/// AIC = 2k - 2logL
///
/// Here `k` is the number of estimated parameters and `logL` the
/// log-likelihood. Smaller values indicate a better trade-off between fit and
/// complexity.
///
/// As in R's `AIC()` for `lm` models, `k` is taken to be `n_coef + 1`: every
/// regression coefficient (intercept included) plus one for the estimated
/// error variance.
///
/// # Arguments
///
/// * `log_likelihood` - Log-likelihood of the model
/// * `n_coef` - Number of coefficients (including intercept)
///
/// # Example
///
/// ```
/// # use linreg_core::core::aic;
/// let aic_value = aic(-150.5, 3);  // 3 coefficients
/// ```
pub fn aic(log_likelihood: f64, n_coef: usize) -> f64 {
    // Coefficients plus the variance parameter.
    let param_count = (n_coef + 1) as f64;
    2.0 * param_count - 2.0 * log_likelihood
}
360
/// Bayesian Information Criterion (BIC), following R's convention.
///
/// BIC = k*ln(n) - 2logL
///
/// Here `k` is the number of estimated parameters, `n` the sample size and
/// `logL` the log-likelihood. Because the penalty grows with `ln(n)`, BIC
/// punishes extra parameters more strongly than AIC on larger samples.
///
/// As in R's `BIC()` for `lm` models, `k` is taken to be `n_coef + 1`: every
/// regression coefficient (intercept included) plus one for the estimated
/// error variance.
///
/// # Arguments
///
/// * `log_likelihood` - Log-likelihood of the model
/// * `n_coef` - Number of coefficients (including intercept)
/// * `n_obs` - Number of observations
///
/// # Example
///
/// ```
/// # use linreg_core::core::bic;
/// let bic_value = bic(-150.5, 3, 100);  // 3 coefficients, 100 obs
/// ```
pub fn bic(log_likelihood: f64, n_coef: usize, n_obs: usize) -> f64 {
    // Coefficients plus the variance parameter.
    let param_count = (n_coef + 1) as f64;
    let log_sample_size = (n_obs as f64).ln();
    param_count * log_sample_size - 2.0 * log_likelihood
}
391
/// Akaike Information Criterion using the Python/statsmodels convention.
///
/// AIC = 2k - 2logL
///
/// Unlike [`aic`], `k` here counts only the regression coefficients and does
/// NOT add one for the variance parameter, which is how statsmodels'
/// `OLS.aic` behaves.
///
/// # Arguments
///
/// * `log_likelihood` - Log-likelihood of the model
/// * `n_coef` - Number of coefficients (including intercept)
///
/// # Example
///
/// ```
/// # use linreg_core::core::aic_python;
/// let aic_value = aic_python(-150.5, 3);  // 3 coefficients
/// ```
pub fn aic_python(log_likelihood: f64, n_coef: usize) -> f64 {
    let complexity_penalty = 2.0 * n_coef as f64;
    complexity_penalty - 2.0 * log_likelihood
}
415
/// Bayesian Information Criterion using the Python/statsmodels convention.
///
/// BIC = k*ln(n) - 2logL
///
/// Unlike [`bic`], `k` here counts only the regression coefficients and does
/// NOT add one for the variance parameter, which is how statsmodels'
/// `OLS.bic` behaves.
///
/// # Arguments
///
/// * `log_likelihood` - Log-likelihood of the model
/// * `n_coef` - Number of coefficients (including intercept)
/// * `n_obs` - Number of observations
///
/// # Example
///
/// ```
/// # use linreg_core::core::bic_python;
/// let bic_value = bic_python(-150.5, 3, 100);  // 3 coefficients, 100 obs
/// ```
pub fn bic_python(log_likelihood: f64, n_coef: usize, n_obs: usize) -> f64 {
    let log_sample_size = (n_obs as f64).ln();
    let complexity_penalty = n_coef as f64 * log_sample_size;
    complexity_penalty - 2.0 * log_likelihood
}
440
441// ============================================================================
442// VIF Calculation
443// ============================================================================
444//
445// Variance Inflation Factor analysis for detecting multicollinearity.
446
447/// Calculates Variance Inflation Factors for all predictors.
448///
449/// VIF quantifies the severity of multicollinearity in a regression analysis.
450/// A VIF > 10 indicates high multicollinearity that may need to be addressed.
451///
452/// # Arguments
453///
454/// * `x_vars` - Predictor variables (each of length n)
455/// * `names` - Variable names (including intercept as first element)
456/// * `n` - Number of observations
457///
458/// # Returns
459///
460/// Vector of VIF results for each predictor (excluding intercept).
461///
462/// # Example
463///
464/// ```
465/// # use linreg_core::core::calculate_vif;
466/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
467/// let x2 = vec![2.0, 4.0, 6.0, 8.0, 10.0];
468/// let names = vec!["Intercept".to_string(), "X1".to_string(), "X2".to_string()];
469///
470/// let vif_results = calculate_vif(&[x1, x2], &names, 5);
471/// assert_eq!(vif_results.len(), 2);
472/// // Perfectly collinear variables will have VIF = infinity
473/// ```
474pub fn calculate_vif(x_vars: &[Vec<f64>], names: &[String], n: usize) -> Vec<VifResult> {
475    let k = x_vars.len();
476    if k <= 1 {
477        return vec![];
478    }
479
480    // Standardize predictors (Z-score)
481    let mut z_data = vec![0.0; n * k];
482
483    for (j, var) in x_vars.iter().enumerate() {
484        let mean = vec_mean(var);
485        let variance = var.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / ((n - 1) as f64);
486        let std_dev = variance.sqrt();
487
488        // Handle constant variables
489        if std_dev.abs() < 1e-10 {
490            return names
491                .iter()
492                .skip(1)
493                .map(|name| VifResult {
494                    variable: name.clone(),
495                    vif: f64::INFINITY,
496                    rsquared: 1.0,
497                    interpretation: "Constant variable (undefined correlation)".to_string(),
498                })
499                .collect();
500        }
501
502        for i in 0..n {
503            z_data[i * k + j] = (var[i] - mean) / std_dev;
504        }
505    }
506
507    let z = Matrix::new(n, k, z_data);
508
509    // Correlation Matrix R = (1/(n-1)) * Z^T * Z
510    let z_t = z.transpose();
511    let zt_z = z_t.matmul(&z);
512
513    // Scale by 1/(n-1)
514    let mut r_corr = zt_z; // Copy
515    let factor = 1.0 / ((n - 1) as f64);
516    for val in &mut r_corr.data {
517        *val *= factor;
518    }
519
520    // Invert R using QR on R_corr (since it's symmetric positive definite, Cholesky is better but QR works)
521    // Or just generic inversion. We implemented generic inversion for Upper Triangular.
522    // Let's use QR: A = QR => A^-1 = R^-1 Q^T
523    let (q_corr, r_corr_tri) = r_corr.qr();
524
525    let r_inv_opt = r_corr_tri.invert_upper_triangular();
526
527    let r_inv = match r_inv_opt {
528        Some(inv) => inv.matmul(&q_corr.transpose()),
529        None => {
530            return names
531                .iter()
532                .skip(1)
533                .map(|name| VifResult {
534                    variable: name.clone(),
535                    vif: f64::INFINITY,
536                    rsquared: 1.0,
537                    interpretation: "Perfect multicollinearity (singular matrix)".to_string(),
538                })
539                .collect();
540        },
541    };
542
543    // Extract diagonals
544    let mut results = vec![];
545    for j in 0..k {
546        let vif = r_inv.get(j, j);
547        let vif = if vif < 1.0 { 1.0 } else { vif };
548        let rsquared = 1.0 - 1.0 / vif;
549
550        let interpretation = if vif.is_infinite() {
551            "Perfect multicollinearity".to_string()
552        } else if vif > 10.0 {
553            "High multicollinearity - consider removing".to_string()
554        } else if vif > 5.0 {
555            "Moderate multicollinearity".to_string()
556        } else {
557            "Low multicollinearity".to_string()
558        };
559
560        results.push(VifResult {
561            variable: names[j + 1].clone(),
562            vif,
563            rsquared,
564            interpretation,
565        });
566    }
567
568    results
569}
570
571// ============================================================================
572// OLS Regression
573// ============================================================================
574//
575// Ordinary Least Squares regression implementation using QR decomposition.
576
577/// Performs Ordinary Least Squares regression using QR decomposition.
578///
579/// Uses a numerically stable QR decomposition approach to solve the normal
580/// equations. Returns comprehensive statistics including coefficients,
581/// standard errors, t-statistics, p-values, and diagnostic measures.
582///
583/// # Arguments
584///
585/// * `y` - Response variable (n observations)
586/// * `x_vars` - Predictor variables (each of length n)
587/// * `variable_names` - Names for variables (including intercept)
588///
589/// # Returns
590///
591/// A [`RegressionOutput`] containing all regression statistics and diagnostics.
592///
593/// # Errors
594///
595/// Returns [`Error::InsufficientData`] if n ≤ k + 1.
596/// Returns [`Error::SingularMatrix`] if perfect multicollinearity exists.
597/// Returns [`Error::InvalidInput`] if coefficients are NaN.
598///
599/// # Example
600///
601/// ```
602/// # use linreg_core::core::ols_regression;
603/// # use linreg_core::Error;
604/// let y = vec![2.5, 3.7, 4.2, 5.1, 6.3, 7.0];
605/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
606/// let x2 = vec![2.0, 4.0, 5.0, 4.0, 3.0, 2.0];
607/// let names = vec![
608///     "Intercept".to_string(),
609///     "Temperature".to_string(),
610///     "Pressure".to_string(),
611/// ];
612///
613/// let result = ols_regression(&y, &[x1, x2], &names)?;
614/// println!("R-squared: {}", result.r_squared);
615/// # Ok::<(), Error>(())
616/// ```
617#[allow(clippy::needless_range_loop)]
618pub fn ols_regression(
619    y: &[f64],
620    x_vars: &[Vec<f64>],
621    variable_names: &[String],
622) -> Result<RegressionOutput> {
623    let n = y.len();
624    let k = x_vars.len();
625    let p = k + 1;
626
627    // Validate inputs
628    if n <= k + 1 {
629        return Err(Error::InsufficientData {
630            required: k + 2,
631            available: n,
632        });
633    }
634
635    // Validate dimensions and finite values using shared helper
636    crate::diagnostics::validate_regression_data(y, x_vars)?;
637
638    // Prepare variable names
639    let mut names = variable_names.to_vec();
640    while names.len() <= k {
641        names.push(format!("X{}", names.len()));
642    }
643
644    // Create design matrix
645    let mut x_data = vec![0.0; n * p];
646    for (row, _yi) in y.iter().enumerate() {
647        x_data[row * p] = 1.0; // intercept
648        for (col, x_var) in x_vars.iter().enumerate() {
649            x_data[row * p + col + 1] = x_var[row];
650        }
651    }
652
653    let x_matrix = Matrix::new(n, p, x_data);
654
655    // QR Decomposition
656    let (q, r) = x_matrix.qr();
657
658    // Solve R * beta = Q^T * y
659    // extract upper p x p part of R
660    let mut r_upper = Matrix::zeros(p, p);
661    for i in 0..p {
662        for j in 0..p {
663            r_upper.set(i, j, r.get(i, j));
664        }
665    }
666
667    // Q^T * y
668    let q_t = q.transpose();
669    let qty = q_t.mul_vec(y);
670
671    // Take first p elements of qty
672    let rhs_vec = qty[0..p].to_vec();
673    let rhs_mat = Matrix::new(p, 1, rhs_vec); // column vector
674
675    // Invert R_upper
676    let r_inv = match r_upper.invert_upper_triangular() {
677        Some(inv) => inv,
678        None => return Err(Error::SingularMatrix),
679    };
680
681    let beta_mat = r_inv.matmul(&rhs_mat);
682    let beta = beta_mat.data;
683
684    // Validate coefficients
685    if beta.iter().any(|&b| b.is_nan()) {
686        return Err(Error::InvalidInput("Coefficients contain NaN".to_string()));
687    }
688
689    // Compute predictions and residuals
690    let predictions = x_matrix.mul_vec(&beta);
691    let residuals = vec_sub(y, &predictions);
692
693    // Compute sums of squares
694    let y_mean = vec_mean(y);
695    let ss_total: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
696    let ss_residual: f64 = residuals.iter().map(|&r| r.powi(2)).sum();
697    let ss_regression = ss_total - ss_residual;
698
699    // R-squared and adjusted R-squared
700    let r_squared = if ss_total == 0.0 {
701        f64::NAN
702    } else {
703        1.0 - ss_residual / ss_total
704    };
705
706    let adj_r_squared = if ss_total == 0.0 {
707        f64::NAN
708    } else {
709        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n - k - 1) as f64)
710    };
711
712    // Mean squared error and standard error
713    let df = n - k - 1;
714    let mse = ss_residual / df as f64;
715    let std_error = mse.sqrt();
716
717    // Standard errors using (X'X)^-1 = R^-1 (R')^-1
718    // xtx_inv = r_inv * r_inv^T
719    let xtx_inv = r_inv.matmul(&r_inv.transpose());
720
721    let mut std_errors = vec![0.0; k + 1];
722    for i in 0..=k {
723        std_errors[i] = (xtx_inv.get(i, i) * mse).sqrt();
724        if std_errors[i].is_nan() {
725            return Err(Error::InvalidInput(format!(
726                "Standard error for coefficient {} is NaN",
727                i
728            )));
729        }
730    }
731
732    // T-statistics and p-values
733    let t_stats: Vec<f64> = beta
734        .iter()
735        .zip(&std_errors)
736        .map(|(&b, &se)| b / se)
737        .collect();
738    let p_values: Vec<f64> = t_stats
739        .iter()
740        .map(|&t| two_tailed_p_value(t, df as f64))
741        .collect();
742
743    // Confidence intervals
744    let alpha = 0.05;
745    let t_critical = t_critical_quantile(df as f64, alpha);
746
747    let conf_int_lower: Vec<f64> = beta
748        .iter()
749        .zip(&std_errors)
750        .map(|(&b, &se)| b - t_critical * se)
751        .collect();
752    let conf_int_upper: Vec<f64> = beta
753        .iter()
754        .zip(&std_errors)
755        .map(|(&b, &se)| b + t_critical * se)
756        .collect();
757
758    // Leverage
759    let leverage = compute_leverage(&x_matrix, &xtx_inv);
760
761    // Standardized residuals
762    let residuals_vec = residuals.clone();
763    let standardized_residuals: Vec<f64> = residuals_vec
764        .iter()
765        .zip(&leverage)
766        .map(|(&r, &h)| {
767            let factor = (1.0 - h).max(MIN_LEVERAGE_DENOM).sqrt();
768            let denom = std_error * factor;
769            if denom > MIN_LEVERAGE_DENOM {
770                r / denom
771            } else {
772                0.0
773            }
774        })
775        .collect();
776
777    // F-statistic
778    let f_statistic = (ss_regression / k as f64) / mse;
779    let f_p_value = f_p_value(f_statistic, k as f64, df as f64);
780
781    // RMSE and MAE
782    let rmse = std_error;
783    let mae: f64 = residuals_vec.iter().map(|&r| r.abs()).sum::<f64>() / n as f64;
784
785    // VIF
786    let vif = calculate_vif(x_vars, &names, n);
787
788    // Model selection criteria (for model comparison)
789    let ll = log_likelihood(n, mse, ss_residual);
790    let n_coef = k + 1;  // predictors + intercept
791    let aic_val = aic(ll, n_coef);
792    let bic_val = bic(ll, n_coef, n);
793
794    Ok(RegressionOutput {
795        coefficients: beta,
796        std_errors,
797        t_stats,
798        p_values,
799        conf_int_lower,
800        conf_int_upper,
801        r_squared,
802        adj_r_squared,
803        f_statistic,
804        f_p_value,
805        mse,
806        rmse,
807        mae,
808        std_error,
809        residuals: residuals_vec,
810        standardized_residuals,
811        predictions,
812        leverage,
813        vif,
814        n,
815        k,
816        df,
817        variable_names: names,
818        log_likelihood: ll,
819        aic: aic_val,
820        bic: bic_val,
821    })
822}
823
824#[cfg(test)]
825mod tests {
826    use super::*;
827
828    #[test]
829    fn test_aic_bic_formulas_known_values() {
830        // Test formulas with simple known inputs
831        let ll = -100.0;
832        let n_coef = 3; // 3 coefficients (e.g., intercept + 2 predictors)
833        let n_obs = 100;
834
835        let aic_val = aic(ll, n_coef);
836        let bic_val = bic(ll, n_coef, n_obs);
837
838        // AIC = 2k - 2logL where k = n_coef + 1 (variance parameter)
839        // AIC = 2*4 - 2*(-100) = 8 + 200 = 208
840        assert!((aic_val - 208.0).abs() < 1e-10);
841
842        // BIC = k*ln(n) - 2logL where k = n_coef + 1
843        // BIC = 4*ln(100) - 2*(-100) = 4*4.605... + 200
844        let expected_bic = 4.0 * (100.0_f64).ln() + 200.0;
845        assert!((bic_val - expected_bic).abs() < 1e-10);
846    }
847
848    #[test]
849    fn test_bic_greater_than_aic_for_reasonable_n() {
850        // For n >= 8, ln(n) > 2, so BIC > AIC (both have -2logL term)
851        // BIC uses k*ln(n) while AIC uses 2k, so when ln(n) > 2, BIC > AIC
852        let ll = -50.0;
853        let n_coef = 2;
854
855        let aic_val = aic(ll, n_coef);
856        let bic_val = bic(ll, n_coef, 100); // n=100, ln(100) ≈ 4.6 > 2
857
858        assert!(bic_val > aic_val);
859    }
860
861    #[test]
862    fn test_log_likelihood_returns_finite() {
863        // Ensure log_likelihood returns finite values for valid inputs
864        let n = 10;
865        let mse = 4.0;
866        let ssr = mse * (n - 2) as f64;
867
868        let ll = log_likelihood(n, mse, ssr);
869        assert!(ll.is_finite());
870    }
871
872    #[test]
873    fn test_log_likelihood_increases_with_better_fit() {
874        // Lower SSR (better fit) should give higher log-likelihood
875        let n = 10;
876
877        // Worse fit (higher residuals)
878        let ll_worse = log_likelihood(n, 10.0, 80.0);
879
880        // Better fit (lower residuals)
881        let ll_better = log_likelihood(n, 2.0, 16.0);
882
883        assert!(ll_better > ll_worse);
884    }
885
886    #[test]
887    fn test_model_selection_criteria_present_in_output() {
888        // Basic sanity check that the fields are populated
889        let y = vec![2.0, 4.0, 5.0, 4.0, 5.0];
890        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
891        let names = vec!["Intercept".to_string(), "X1".to_string()];
892
893        let result = ols_regression(&y, &[x1], &names).unwrap();
894
895        // All three should be finite
896        assert!(result.log_likelihood.is_finite());
897        assert!(result.aic.is_finite());
898        assert!(result.bic.is_finite());
899
900        // AIC and BIC should be positive for typical cases
901        // (since log_likelihood is usually negative and bounded)
902        assert!(result.aic > 0.0);
903        assert!(result.bic > 0.0);
904    }
905
906    #[test]
907    fn test_regression_output_has_correct_dimensions() {
908        // Verify AIC/BIC use k = n_coef + 1 (coefficients + variance parameter)
909        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0];
910        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
911        let x2 = vec![3.0, 2.0, 4.0, 3.0, 5.0, 4.0, 6.0, 5.0];
912        let names = vec!["Intercept".into(), "X1".into(), "X2".into()];
913
914        let result = ols_regression(&y, &[x1, x2], &names).unwrap();
915
916        // n_coef = 3 (intercept + 2 predictors)
917        // k = n_coef + 1 = 4 (including variance parameter, following R convention)
918        let n_coef = 3;
919        let k = n_coef + 1; // R's convention includes variance parameter
920
921        // Verify by recalculating AIC from log_likelihood
922        let expected_aic = 2.0 * k as f64 - 2.0 * result.log_likelihood;
923        assert!((result.aic - expected_aic).abs() < 1e-10);
924
925        // Verify by recalculating BIC from log_likelihood
926        let expected_bic = k as f64 * (result.n as f64).ln() - 2.0 * result.log_likelihood;
927        assert!((result.bic - expected_bic).abs() < 1e-10);
928    }
929
930    #[test]
931    fn test_aic_python_convention() {
932        // Python's statsmodels uses k = n_coef (no variance parameter)
933        let ll = -100.0;
934        let n_coef = 3;
935
936        let aic_py = aic_python(ll, n_coef);
937        // AIC = 2k - 2logL where k = n_coef (Python convention)
938        // AIC = 2*3 - 2*(-100) = 6 + 200 = 206
939        assert!((aic_py - 206.0).abs() < 1e-10);
940    }
941
942    #[test]
943    fn test_bic_python_convention() {
944        // Python's statsmodels uses k = n_coef (no variance parameter)
945        let ll = -100.0;
946        let n_coef = 3;
947        let n_obs = 100;
948
949        let bic_py = bic_python(ll, n_coef, n_obs);
950        // BIC = k*ln(n) - 2logL where k = n_coef (Python convention)
951        // BIC = 3*ln(100) - 2*(-100) = 3*4.605... + 200
952        let expected_bic = 3.0 * (100.0_f64).ln() + 200.0;
953        assert!((bic_py - expected_bic).abs() < 1e-10);
954    }
955
956    #[test]
957    fn test_python_aic_smaller_than_r_aic() {
958        // Python convention uses k = n_coef, R uses k = n_coef + 1
959        // So Python AIC should be 2 smaller than R AIC
960        let ll = -50.0;
961        let n_coef = 2;
962
963        let aic_r = aic(ll, n_coef);
964        let aic_py = aic_python(ll, n_coef);
965
966        assert_eq!(aic_r - aic_py, 2.0);
967    }
968
969    #[test]
970    fn test_log_likelihood_formula_matches_r() {
971        // Test against R's logLik.lm() formula
972        // For a model with n=100, SSR=450, logL = -n/2 * log(2*pi*SSR/n) - n/2
973        let n = 100;
974        let ssr = 450.0;
975        let mse = ssr / (n as f64 - 2.0); // 2 parameters
976
977        let ll = log_likelihood(n, mse, ssr);
978
979        // Calculate expected value manually
980        let two_pi = 2.0 * std::f64::consts::PI;
981        let expected = -0.5 * n as f64 * (two_pi * ssr / n as f64).ln() - n as f64 / 2.0;
982
983        assert!((ll - expected).abs() < 1e-10);
984    }
985
986    #[test]
987    fn test_aic_bic_with_perfect_fit() {
988        // Perfect fit (zero residuals) - edge case
989        let n = 10;
990        let ssr = 0.001; // Very small but non-zero to avoid log(0)
991        let mse = ssr / (n as f64 - 2.0);
992
993        let ll = log_likelihood(n, mse, ssr);
994        let aic_val = aic(ll, 2);
995        let bic_val = bic(ll, 2, n);
996
997        // Perfect fit gives very high log-likelihood
998        assert!(ll > 0.0);
999        // AIC/BIC penalize complexity, so may be negative for very good fits
1000        assert!(aic_val.is_finite());
1001        assert!(bic_val.is_finite());
1002    }
1003
1004    #[test]
1005    fn test_aic_bic_model_selection() {
1006        // Simulate model comparison: simpler model vs complex model
1007        // Both models fit same data with similar R² but different complexity
1008        let n = 100;
1009
1010        // Simple model (2 params): better log-likelihood due to less penalty
1011        let ll_simple = -150.0;
1012        let aic_simple = aic(ll_simple, 2);
1013        let bic_simple = bic(ll_simple, 2, n);
1014
1015        // Complex model (5 params): slightly better fit but more parameters
1016        let ll_complex = -148.0; // Better fit (less negative)
1017        let aic_complex = aic(ll_complex, 5);
1018        let bic_complex = bic(ll_complex, 5, n);
1019
1020        // AIC might favor complex model (2*2 - 2*(-150) = 304 vs 2*6 - 2*(-148) = 308)
1021        // Actually: 4 + 300 = 304 vs 12 + 296 = 308, so simple wins
1022        assert!(aic_simple < aic_complex);
1023
1024        // BIC more heavily penalizes complexity, so simple should win
1025        assert!(bic_simple < bic_complex);
1026    }
1027
1028    #[test]
1029    fn test_log_likelihood_scale_invariance() {
1030        // Log-likelihood scales with sample size for same per-observation fit quality
1031        let ssr_per_obs = 1.0;
1032
1033        let n1 = 50;
1034        let ssr1 = ssr_per_obs * n1 as f64;
1035        let ll1 = log_likelihood(n1, ssr1 / (n1 as f64 - 2.0), ssr1);
1036
1037        let n2 = 100;
1038        let ssr2 = ssr_per_obs * n2 as f64;
1039        let ll2 = log_likelihood(n2, ssr2 / (n2 as f64 - 2.0), ssr2);
1040
1041        // The log-likelihood should become more negative with larger n for the same SSR/n ratio
1042        // because -n/2 * ln(2*pi*SSR/n) - n/2 becomes more negative as n increases
1043        assert!(ll2 < ll1);
1044
1045        // But when normalized by n, they should be similar
1046        let ll_per_obs1 = ll1 / n1 as f64;
1047        let ll_per_obs2 = ll2 / n2 as f64;
1048        assert!((ll_per_obs1 - ll_per_obs2).abs() < 0.1);
1049    }
1050
1051    #[test]
1052    fn test_regularized_regression_has_model_selection_criteria() {
1053        // Test that Ridge regression also calculates AIC/BIC/log_likelihood
1054        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
1055        let x_data = vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0];
1056        let x = crate::linalg::Matrix::new(5, 2, x_data);
1057
1058        let options = crate::regularized::ridge::RidgeFitOptions {
1059            lambda: 0.1,
1060            intercept: true,
1061            standardize: false,
1062            ..Default::default()
1063        };
1064
1065        let fit = crate::regularized::ridge::ridge_fit(&x, &y, &options).unwrap();
1066
1067        assert!(fit.log_likelihood.is_finite());
1068        assert!(fit.aic.is_finite());
1069        assert!(fit.bic.is_finite());
1070    }
1071
1072    #[test]
1073    fn test_elastic_net_regression_has_model_selection_criteria() {
1074        // Test that Elastic Net regression also calculates AIC/BIC/log_likelihood
1075        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
1076        let x_data = vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0];
1077        let x = crate::linalg::Matrix::new(5, 2, x_data);
1078
1079        let options = crate::regularized::elastic_net::ElasticNetOptions {
1080            lambda: 0.1,
1081            alpha: 0.5,
1082            intercept: true,
1083            standardize: false,
1084            ..Default::default()
1085        };
1086
1087        let fit = crate::regularized::elastic_net::elastic_net_fit(&x, &y, &options).unwrap();
1088
1089        assert!(fit.log_likelihood.is_finite());
1090        assert!(fit.aic.is_finite());
1091        assert!(fit.bic.is_finite());
1092    }
1093}