// linreg_core/core.rs

//! Core OLS regression implementation.
//!
//! This module provides the main Ordinary Least Squares regression functionality
//! that can be used directly in Rust code. Functions accept native Rust slices
//! and return `Result` types for proper error handling.
//!
//! # Example
//!
//! ```
//! # use linreg_core::core::ols_regression;
//! # use linreg_core::Error;
//! let y = vec![1.1, 1.9, 3.2, 3.9, 5.1, 6.0];
//! let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
//! let x2 = vec![2.0, 3.0, 3.5, 4.0, 4.5, 5.0];
//! let names = vec![
//!     "Intercept".to_string(),
//!     "X1".to_string(),
//!     "X2".to_string(),
//! ];
//!
//! let result = ols_regression(&y, &[x1, x2], &names)?;
//! # Ok::<(), Error>(())
//! ```

use serde::Serialize;

use crate::distributions::{fisher_snedecor_cdf, student_t_cdf, student_t_inverse_cdf};
use crate::error::{Error, Result};
use crate::linalg::{Matrix, vec_dot, vec_mean, vec_sub};

// ============================================================================
// Numerical Constants
// ============================================================================

/// Minimum threshold for the standardized-residual denominator to avoid division by zero.
/// When (1 - leverage) is very small, the observation has extremely high leverage
/// and standardized residuals may be unreliable.
const MIN_LEVERAGE_DENOM: f64 = 1e-10;

// ============================================================================
// Result Types
// ============================================================================
//
// Structs containing the output of regression computations.

/// Result of VIF (Variance Inflation Factor) calculation.
///
/// VIF measures how much the variance of an estimated regression coefficient
/// increases due to multicollinearity among the predictors.
#[derive(Debug, Clone, Serialize)]
pub struct VifResult {
    /// Name of the predictor variable
    pub variable: String,
    /// Variance Inflation Factor (VIF > 10 indicates high multicollinearity)
    pub vif: f64,
    /// R-squared from regressing this predictor on all others
    pub rsquared: f64,
    /// Human-readable interpretation of the VIF value
    pub interpretation: String,
}

/// Complete output from OLS regression.
///
/// Contains all coefficients, statistics, diagnostics, and residuals from
/// an Ordinary Least Squares regression.
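///
/// # Example
///
/// A minimal sketch of consuming the output; the field invariants asserted
/// here follow from the definitions below:
///
/// ```
/// # use linreg_core::core::ols_regression;
/// # use linreg_core::Error;
/// # let y = vec![2.0, 3.9, 6.1, 8.0, 9.9, 12.1];
/// # let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// # let names = vec!["Intercept".to_string(), "X".to_string()];
/// let out = ols_regression(&y, &[x], &names)?;
/// assert_eq!(out.coefficients.len(), out.k + 1);
/// assert_eq!(out.residuals.len(), out.n);
/// assert_eq!(out.df, out.n - out.k - 1);
/// # Ok::<(), Error>(())
/// ```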
#[derive(Debug, Clone, Serialize)]
pub struct RegressionOutput {
    /// Regression coefficients (including intercept)
    pub coefficients: Vec<f64>,
    /// Standard errors of coefficients
    pub std_errors: Vec<f64>,
    /// t-statistics for coefficient significance tests
    pub t_stats: Vec<f64>,
    /// Two-tailed p-values for coefficients
    pub p_values: Vec<f64>,
    /// Lower bounds of 95% confidence intervals
    pub conf_int_lower: Vec<f64>,
    /// Upper bounds of 95% confidence intervals
    pub conf_int_upper: Vec<f64>,
    /// R-squared (coefficient of determination)
    pub r_squared: f64,
    /// Adjusted R-squared (accounts for number of predictors)
    pub adj_r_squared: f64,
    /// F-statistic for overall model significance
    pub f_statistic: f64,
    /// P-value for the F-statistic
    pub f_p_value: f64,
    /// Mean squared error of residuals
    pub mse: f64,
    /// Standard error of the regression (residual standard deviation)
    pub std_error: f64,
    /// Raw residuals (observed - predicted)
    pub residuals: Vec<f64>,
    /// Standardized residuals (accounting for leverage)
    pub standardized_residuals: Vec<f64>,
    /// Fitted/predicted values
    pub predictions: Vec<f64>,
    /// Leverage values for each observation (diagonal of the hat matrix)
    pub leverage: Vec<f64>,
    /// Variance Inflation Factors for detecting multicollinearity
    pub vif: Vec<VifResult>,
    /// Number of observations
    pub n: usize,
    /// Number of predictor variables (excluding intercept)
    pub k: usize,
    /// Degrees of freedom for residuals (n - k - 1)
    pub df: usize,
    /// Names of variables (including intercept)
    pub variable_names: Vec<String>,
}

// ============================================================================
// Statistical Helper Functions
// ============================================================================
//
// Utility functions for computing p-values, critical values, and leverage.

/// Computes a two-tailed p-value from a t-statistic.
///
/// Uses the Student's t-distribution CDF to calculate the probability
/// of observing a t-statistic as extreme as the one provided.
///
/// # Arguments
///
/// * `t` - The t-statistic value
/// * `df` - Degrees of freedom
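///
/// # Example
///
/// A small sketch; the exact value depends on the underlying CDF
/// implementation, so the doctest just brackets the expected ~0.073
/// for `t = 2.0`, `df = 10`:
///
/// ```
/// # use linreg_core::core::two_tailed_p_value;
/// let p = two_tailed_p_value(2.0, 10.0);
/// assert!(p > 0.05 && p < 0.10);
/// ```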
pub fn two_tailed_p_value(t: f64, df: f64) -> f64 {
    // Shortcut: treat |t| > 100 as a p-value of zero.
    if t.abs() > 100.0 {
        return 0.0;
    }

    let cdf = student_t_cdf(t, df);
    if t >= 0.0 { 2.0 * (1.0 - cdf) } else { 2.0 * cdf }
}

/// Computes the critical t-value for a given significance level and degrees of freedom.
///
/// Returns the t-value such that the area under the t-distribution curve
/// to the right of it equals alpha/2 (two-tailed test).
///
/// # Arguments
///
/// * `df` - Degrees of freedom
/// * `alpha` - Significance level (typically 0.05 for 95% confidence)
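///
/// # Example
///
/// A small sketch: the two-tailed 95% critical value with 10 degrees of
/// freedom is roughly 2.228, so we allow a loose tolerance:
///
/// ```
/// # use linreg_core::core::t_critical_quantile;
/// let t = t_critical_quantile(10.0, 0.05);
/// assert!((t - 2.228).abs() < 0.05);
/// ```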
pub fn t_critical_quantile(df: f64, alpha: f64) -> f64 {
    let p = 1.0 - alpha / 2.0;
    student_t_inverse_cdf(p, df)
}

/// Computes a p-value from an F-statistic.
///
/// Uses the F-distribution CDF to calculate the probability of observing
/// an F-statistic as extreme as the one provided.
///
/// # Arguments
///
/// * `f_stat` - The F-statistic value
/// * `df1` - Numerator degrees of freedom
/// * `df2` - Denominator degrees of freedom
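///
/// # Example
///
/// A small sketch: for `F = 5.0` with 2 and 10 degrees of freedom the
/// p-value is roughly 0.031:
///
/// ```
/// # use linreg_core::core::f_p_value;
/// let p = f_p_value(5.0, 2.0, 10.0);
/// assert!(p > 0.0 && p < 0.05);
/// ```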
pub fn f_p_value(f_stat: f64, df1: f64, df2: f64) -> f64 {
    if f_stat <= 0.0 {
        return 1.0;
    }
    1.0 - fisher_snedecor_cdf(f_stat, df1, df2)
}

/// Computes leverage values from the design matrix and the inverse of X'X.
///
/// Leverage measures how far an observation's predictor values are from
/// the center of the predictor space. High leverage points can have
/// disproportionate influence on the regression results.
///
/// # Arguments
///
/// * `x` - Design matrix (including intercept column)
/// * `xtx_inv` - Inverse of X'X matrix
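///
/// # Example
///
/// A sketch, marked `ignore` because it assumes `Matrix` is publicly
/// exported at `linreg_core::linalg`. It checks two standard properties:
/// leverages lie in [0, 1] and sum to the number of columns of X.
///
/// ```ignore
/// use linreg_core::core::compute_leverage;
/// use linreg_core::linalg::Matrix;
///
/// // 3 observations, intercept + one predictor (row-major storage).
/// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
/// // Invert X'X via QR, as elsewhere in this crate: A = QR => A^-1 = R^-1 Q^T.
/// let xtx = x.transpose().matmul(&x);
/// let (q, r) = xtx.qr();
/// let xtx_inv = r.invert_upper_triangular().unwrap().matmul(&q.transpose());
///
/// let h = compute_leverage(&x, &xtx_inv);
/// assert!((h.iter().sum::<f64>() - 2.0).abs() < 1e-8);
/// ```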
pub fn compute_leverage(x: &Matrix, xtx_inv: &Matrix) -> Vec<f64> {
    let n = x.rows;
    let mut leverage = vec![0.0; n];
    for i in 0..n {
        // h_i = x_i * (X'X)^-1 * x_i^T, where x_i is row i of X (1 x cols).

        // Extract row i of X.
        let mut row_vec = vec![0.0; x.cols];
        for j in 0..x.cols {
            row_vec[j] = x.get(i, j);
        }

        // temp = x_i * (X'X)^-1  (1 x cols)
        let mut temp_vec = vec![0.0; x.cols];
        for c in 0..x.cols {
            let mut sum = 0.0;
            for k in 0..x.cols {
                sum += row_vec[k] * xtx_inv.get(k, c);
            }
            temp_vec[c] = sum;
        }

        // h_i = temp . x_i
        leverage[i] = vec_dot(&temp_vec, &row_vec);
    }
    leverage
}

// ============================================================================
// VIF Calculation
// ============================================================================
//
// Variance Inflation Factor analysis for detecting multicollinearity.

/// Calculates Variance Inflation Factors for all predictors.
///
/// VIF quantifies the severity of multicollinearity in a regression analysis.
/// A VIF > 10 indicates high multicollinearity that may need to be addressed.
///
/// # Arguments
///
/// * `x_vars` - Predictor variables (each of length n)
/// * `names` - Variable names (including intercept as first element)
/// * `n` - Number of observations
///
/// # Returns
///
/// Vector of VIF results for each predictor (excluding intercept).
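///
/// # Example
///
/// A small sketch with two mildly correlated predictors (correlation 0.8,
/// so each VIF is about 1 / (1 - 0.64) ≈ 2.78):
///
/// ```
/// # use linreg_core::core::calculate_vif;
/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let x2 = vec![2.0, 1.0, 4.0, 3.0, 5.0];
/// let names = vec![
///     "Intercept".to_string(),
///     "X1".to_string(),
///     "X2".to_string(),
/// ];
/// let vifs = calculate_vif(&[x1, x2], &names, 5);
/// assert_eq!(vifs.len(), 2);
/// assert!(vifs.iter().all(|v| v.vif >= 1.0));
/// ```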
pub fn calculate_vif(x_vars: &[Vec<f64>], names: &[String], n: usize) -> Vec<VifResult> {
    let k = x_vars.len();
    if k <= 1 {
        // VIF is undefined with fewer than two predictors.
        return vec![];
    }

    // Standardize predictors (z-scores).
    let mut z_data = vec![0.0; n * k];

    for (j, var) in x_vars.iter().enumerate() {
        let mean = vec_mean(var);
        let variance = var.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / ((n - 1) as f64);
        let std_dev = variance.sqrt();

        // A constant predictor makes the correlation matrix undefined; flag every predictor.
        if std_dev.abs() < 1e-10 {
            return names.iter().skip(1).map(|name| VifResult {
                variable: name.clone(),
                vif: f64::INFINITY,
                rsquared: 1.0,
                interpretation: "Constant variable (undefined correlation)".to_string(),
            }).collect();
        }

        for i in 0..n {
            z_data[i * k + j] = (var[i] - mean) / std_dev;
        }
    }

    let z = Matrix::new(n, k, z_data);

    // Correlation matrix R = (1/(n-1)) * Z^T * Z
    let z_t = z.transpose();
    let zt_z = z_t.matmul(&z);

    // Scale in place by 1/(n-1).
    let mut r_corr = zt_z;
    let factor = 1.0 / ((n - 1) as f64);
    for val in &mut r_corr.data {
        *val *= factor;
    }

    // Invert the correlation matrix via QR: A = QR implies A^-1 = R^-1 * Q^T.
    // (Cholesky would also work since the matrix is symmetric positive
    // definite, but QR reuses the existing upper-triangular inversion.)
    let (q_corr, r_corr_tri) = r_corr.qr();

    let r_inv = match r_corr_tri.invert_upper_triangular() {
        Some(inv) => inv.matmul(&q_corr.transpose()),
        None => {
            return names.iter().skip(1).map(|name| VifResult {
                variable: name.clone(),
                vif: f64::INFINITY,
                rsquared: 1.0,
                interpretation: "Perfect multicollinearity (singular matrix)".to_string(),
            }).collect();
        }
    };

    // VIF_j is the j-th diagonal element of the inverse correlation matrix.
    let mut results = vec![];
    for j in 0..k {
        let vif = r_inv.get(j, j);
        // Numerical noise can push the diagonal slightly below 1; clamp it.
        let vif = if vif < 1.0 { 1.0 } else { vif };
        let rsquared = 1.0 - 1.0 / vif;

        let interpretation = if vif.is_infinite() {
            "Perfect multicollinearity".to_string()
        } else if vif > 10.0 {
            "High multicollinearity - consider removing".to_string()
        } else if vif > 5.0 {
            "Moderate multicollinearity".to_string()
        } else {
            "Low multicollinearity".to_string()
        };

        results.push(VifResult {
            variable: names[j + 1].clone(),
            vif,
            rsquared,
            interpretation,
        });
    }

    results
}

// ============================================================================
// OLS Regression
// ============================================================================
//
// Ordinary Least Squares regression implementation using QR decomposition.

/// Performs Ordinary Least Squares regression using QR decomposition.
///
/// Uses a numerically stable QR decomposition approach to solve the normal
/// equations. Returns comprehensive statistics including coefficients,
/// standard errors, t-statistics, p-values, and diagnostic measures.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x_vars` - Predictor variables (each of length n)
/// * `variable_names` - Names for variables (including intercept)
///
/// # Returns
///
/// A [`RegressionOutput`] containing all regression statistics and diagnostics.
///
/// # Errors
///
/// Returns [`Error::InsufficientData`] if n ≤ k + 1.
/// Returns [`Error::SingularMatrix`] if perfect multicollinearity exists.
/// Returns [`Error::InvalidInput`] if coefficients are NaN.
///
/// # Example
///
/// ```
/// # use linreg_core::core::ols_regression;
/// # use linreg_core::Error;
/// let y = vec![2.5, 3.7, 4.2, 5.1, 6.3, 7.0];
/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let x2 = vec![2.0, 4.0, 5.0, 4.0, 3.0, 2.0];
/// let names = vec![
///     "Intercept".to_string(),
///     "Temperature".to_string(),
///     "Pressure".to_string(),
/// ];
///
/// let result = ols_regression(&y, &[x1, x2], &names)?;
/// println!("R-squared: {}", result.r_squared);
/// # Ok::<(), Error>(())
/// ```
pub fn ols_regression(
    y: &[f64],
    x_vars: &[Vec<f64>],
    variable_names: &[String],
) -> Result<RegressionOutput> {
    let n = y.len();
    let k = x_vars.len();
    let p = k + 1; // columns in the design matrix (intercept + k predictors)

    // Validate inputs: need n >= k + 2 so the residual degrees of freedom are positive.
    if n <= k + 1 {
        return Err(Error::InsufficientData { required: k + 2, available: n });
    }

    // Pad variable names so there is one per column (intercept + k predictors).
    let mut names = variable_names.to_vec();
    while names.len() <= k {
        names.push(format!("X{}", names.len()));
    }

    // Create design matrix
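    // Row i is [1.0, x_1[i], ..., x_k[i]], stored row-major; the leading
    // column of ones carries the intercept.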
    let mut x_data = vec![0.0; n * p];
    for row in 0..n {
        x_data[row * p] = 1.0; // intercept
        for (col, x_var) in x_vars.iter().enumerate() {
            x_data[row * p + col + 1] = x_var[row];
        }
    }

    let x_matrix = Matrix::new(n, p, x_data);

    // QR decomposition
    let (q, r) = x_matrix.qr();

    // Solve R * beta = Q^T * y
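    // With X = QR (Q orthonormal), the normal equations X'X * b = X' * y
    // reduce to R * b = Q^T * y, solved here via the explicit inverse of
    // R's upper p x p block.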
    // Extract the upper p x p block of R.
    let mut r_upper = Matrix::zeros(p, p);
    for i in 0..p {
        for j in 0..p {
            r_upper.set(i, j, r.get(i, j));
        }
    }

    // Q^T * y
    let q_t = q.transpose();
    let qty = q_t.mul_vec(y);

    // Take the first p elements of Q^T * y.
    let rhs_vec = qty[0..p].to_vec();
    let rhs_mat = Matrix::new(p, 1, rhs_vec); // column vector

    // Invert R_upper; failure means X has linearly dependent columns.
    let r_inv = match r_upper.invert_upper_triangular() {
        Some(inv) => inv,
        None => return Err(Error::SingularMatrix),
    };

    let beta_mat = r_inv.matmul(&rhs_mat);
    let beta = beta_mat.data;

    // Validate coefficients
    if beta.iter().any(|&b| b.is_nan()) {
        return Err(Error::InvalidInput("Coefficients contain NaN".to_string()));
    }

    // Compute predictions and residuals
    let predictions = x_matrix.mul_vec(&beta);
    let residuals = vec_sub(y, &predictions);

    // Compute sums of squares
    let y_mean = vec_mean(y);
    let ss_total: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
    let ss_residual: f64 = residuals.iter().map(|&r| r.powi(2)).sum();
    let ss_regression = ss_total - ss_residual;

    // R-squared and adjusted R-squared
    let r_squared = if ss_total == 0.0 {
        f64::NAN
    } else {
        1.0 - ss_residual / ss_total
    };

    let adj_r_squared = if ss_total == 0.0 {
        f64::NAN
    } else {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n - k - 1) as f64)
    };

    // Mean squared error and standard error of the regression
    let df = n - k - 1;
    let mse = ss_residual / df as f64;
    let std_error = mse.sqrt();

    // Standard errors using (X'X)^-1 = (R'R)^-1 = R^-1 * (R^-1)^T
    // xtx_inv = r_inv * r_inv^T
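    // Var(beta_hat) = sigma^2 * (X'X)^-1; the diagonal is scaled below by
    // the MSE estimate of sigma^2.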
    let xtx_inv = r_inv.matmul(&r_inv.transpose());

    let mut std_errors = vec![0.0; k + 1];
    for i in 0..=k {
        std_errors[i] = (xtx_inv.get(i, i) * mse).sqrt();
        if std_errors[i].is_nan() {
            return Err(Error::InvalidInput(format!("Standard error for coefficient {} is NaN", i)));
        }
    }

    // T-statistics and p-values
    let t_stats: Vec<f64> = beta.iter().zip(&std_errors).map(|(&b, &se)| b / se).collect();
    let p_values: Vec<f64> = t_stats.iter().map(|&t| two_tailed_p_value(t, df as f64)).collect();

    // Confidence intervals
    let alpha = 0.05;
    let t_critical = t_critical_quantile(df as f64, alpha);

    let conf_int_lower: Vec<f64> = beta.iter().zip(&std_errors).map(|(&b, &se)| b - t_critical * se).collect();
    let conf_int_upper: Vec<f64> = beta.iter().zip(&std_errors).map(|(&b, &se)| b + t_critical * se).collect();

    // Leverage
    let leverage = compute_leverage(&x_matrix, &xtx_inv);

    // Standardized residuals
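    // r_i / (s * sqrt(1 - h_i)), where s is the residual standard error and
    // h_i the leverage of observation i.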
    let standardized_residuals: Vec<f64> = residuals.iter().zip(&leverage)
        .map(|(&r, &h)| {
            let factor = (1.0 - h).max(MIN_LEVERAGE_DENOM).sqrt();
            let denom = std_error * factor;
            if denom > MIN_LEVERAGE_DENOM { r / denom } else { 0.0 }
        })
        .collect();

    // F-statistic
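    // F = (SS_regression / k) / (SS_residual / df) = MSR / MSE, compared
    // against an F(k, df) distribution.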
    let f_statistic = (ss_regression / k as f64) / mse;
    let f_p_value = f_p_value(f_statistic, k as f64, df as f64);

    // VIF
    let vif = calculate_vif(x_vars, &names, n);

    Ok(RegressionOutput {
        coefficients: beta,
        std_errors,
        t_stats,
        p_values,
        conf_int_lower,
        conf_int_upper,
        r_squared,
        adj_r_squared,
        f_statistic,
        f_p_value,
        mse,
        std_error,
        residuals,
        standardized_residuals,
        predictions,
        leverage,
        vif,
        n,
        k,
        df,
        variable_names: names,
    })
}