linreg_core/
core.rs

//! Core OLS regression implementation.
//!
//! This module provides the main Ordinary Least Squares regression functionality
//! that can be used directly in Rust code. Functions accept native Rust slices
//! and return Result types for proper error handling.
//!
//! # Example
//!
//! ```
//! # use linreg_core::core::ols_regression;
//! # use linreg_core::Error;
//! let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
//! let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
//! let x2 = vec![2.0, 3.0, 3.5, 4.0, 4.5, 5.0];
//! let names = vec![
//!     "Intercept".to_string(),
//!     "X1".to_string(),
//!     "X2".to_string(),
//! ];
//!
//! let result = ols_regression(&y, &[x1, x2], &names)?;
//! # Ok::<(), Error>(())
//! ```
24
25use crate::error::{Error, Result};
26use crate::linalg::{Matrix, vec_mean, vec_sub, vec_dot};
27use serde::Serialize;
28use crate::distributions::{student_t_cdf, student_t_inverse_cdf, fisher_snedecor_cdf};
29
30// ============================================================================
31// Numerical Constants
32// ============================================================================
33
/// Minimum threshold for standardized residual denominator to avoid division by zero.
/// When (1 - leverage) is very small, the observation has extremely high leverage
/// and standardized residuals may be unreliable.
///
/// `ols_regression` uses this value twice: as a floor for `1 - leverage` before
/// taking the square root, and as the smallest denominator
/// (`std_error * sqrt(1 - leverage)`) for which a standardized residual is
/// computed; below it the standardized residual is reported as `0.0`.
const MIN_LEVERAGE_DENOM: f64 = 1e-10;
38
39// ============================================================================
40// Result Types
41// ============================================================================
42//
43// Structs containing the output of regression computations.
44
45/// Result of VIF (Variance Inflation Factor) calculation.
46///
47/// VIF measures how much the variance of an estimated regression coefficient
48/// increases due to multicollinearity among the predictors.
49///
50/// # Fields
51///
52/// * `variable` - Name of the predictor variable
53/// * `vif` - Variance Inflation Factor (VIF > 10 indicates high multicollinearity)
54/// * `rsquared` - R-squared from regressing this predictor on all others
55/// * `interpretation` - Human-readable interpretation of the VIF value
56///
57/// # Example
58///
59/// ```
60/// # use linreg_core::core::VifResult;
61/// let vif = VifResult {
62///     variable: "X1".to_string(),
63///     vif: 2.5,
64///     rsquared: 0.6,
65///     interpretation: "Low multicollinearity".to_string(),
66/// };
67/// assert_eq!(vif.variable, "X1");
68/// ```
69#[derive(Debug, Clone, Serialize)]
70pub struct VifResult {
71    /// Name of the predictor variable
72    pub variable: String,
73    /// Variance Inflation Factor (VIF > 10 indicates high multicollinearity)
74    pub vif: f64,
75    /// R-squared from regressing this predictor on all others
76    pub rsquared: f64,
77    /// Human-readable interpretation of the VIF value
78    pub interpretation: String,
79}
80
/// Complete output from OLS regression.
///
/// Contains all coefficients, statistics, diagnostics, and residuals from
/// an Ordinary Least Squares regression. All per-coefficient vectors
/// (`coefficients`, `std_errors`, `t_stats`, `p_values`, `conf_int_lower`,
/// `conf_int_upper`) have `k + 1` entries with the intercept first, followed
/// by the predictors in the order they were supplied.
///
/// # Example
///
/// ```
/// # use linreg_core::core::ols_regression;
/// let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let names = vec!["Intercept".to_string(), "X1".to_string()];
///
/// let result = ols_regression(&y, &[x1], &names).unwrap();
/// assert!(result.r_squared > 0.0);
/// assert_eq!(result.coefficients.len(), 2); // intercept + 1 predictor
/// ```
#[derive(Debug, Clone, Serialize)]
pub struct RegressionOutput {
    /// Regression coefficients (including intercept)
    pub coefficients: Vec<f64>,
    /// Standard errors of coefficients: `sqrt(diag((X'X)^-1) * mse)`
    pub std_errors: Vec<f64>,
    /// t-statistics for coefficient significance tests (coefficient / standard error)
    pub t_stats: Vec<f64>,
    /// Two-tailed p-values for coefficients, using `df` degrees of freedom
    pub p_values: Vec<f64>,
    /// Lower bounds of 95% confidence intervals
    pub conf_int_lower: Vec<f64>,
    /// Upper bounds of 95% confidence intervals
    pub conf_int_upper: Vec<f64>,
    /// R-squared (coefficient of determination); NaN when the response has
    /// zero total sum of squares
    pub r_squared: f64,
    /// Adjusted R-squared (accounts for number of predictors); NaN when the
    /// response has zero total sum of squares
    pub adj_r_squared: f64,
    /// F-statistic for overall model significance
    pub f_statistic: f64,
    /// P-value for F-statistic
    pub f_p_value: f64,
    /// Mean squared error of residuals: residual sum of squares divided by `df`
    pub mse: f64,
    /// Root mean squared error (prediction error in original units);
    /// computed as `sqrt(mse)`, so it equals `std_error`
    pub rmse: f64,
    /// Mean absolute error of residuals (sum of |residual| divided by `n`)
    pub mae: f64,
    /// Standard error of the regression (residual standard deviation), `sqrt(mse)`
    pub std_error: f64,
    /// Raw residuals (observed - predicted)
    pub residuals: Vec<f64>,
    /// Standardized residuals (accounting for leverage):
    /// `residual / (std_error * sqrt(1 - leverage))`, or `0.0` when that
    /// denominator is too close to zero
    pub standardized_residuals: Vec<f64>,
    /// Fitted/predicted values
    pub predictions: Vec<f64>,
    /// Leverage values for each observation (diagonal of hat matrix)
    pub leverage: Vec<f64>,
    /// Variance Inflation Factors for detecting multicollinearity
    /// (empty when the model has fewer than two predictors)
    pub vif: Vec<VifResult>,
    /// Number of observations
    pub n: usize,
    /// Number of predictor variables (excluding intercept)
    pub k: usize,
    /// Degrees of freedom for residuals (n - k - 1)
    pub df: usize,
    /// Names of variables (including intercept); missing names are filled in
    /// as `X<index>`
    pub variable_names: Vec<String>,
}
147
148// ============================================================================
149// Statistical Helper Functions
150// ============================================================================
151//
152// Utility functions for computing p-values, critical values, and leverage.
153
154/// Computes a two-tailed p-value from a t-statistic.
155///
156/// Uses the Student's t-distribution CDF to calculate the probability
157/// of observing a t-statistic as extreme as the one provided.
158///
159/// # Arguments
160///
161/// * `t` - The t-statistic value
162/// * `df` - Degrees of freedom
163///
164/// # Example
165///
166/// ```
167/// # use linreg_core::core::two_tailed_p_value;
168/// let p = two_tailed_p_value(2.0, 20.0);
169/// assert!(p > 0.0 && p < 0.1);
170/// ```
171pub fn two_tailed_p_value(t: f64, df: f64) -> f64 {
172    if t.abs() > 100.0 {
173        return 0.0;
174    }
175
176    let cdf = student_t_cdf(t, df);
177    if t >= 0.0 { 2.0 * (1.0 - cdf) } else { 2.0 * cdf }
178}
179
180/// Computes the critical t-value for a given significance level and degrees of freedom.
181///
182/// Returns the t-value such that the area under the t-distribution curve
183/// to the right of it equals alpha/2 (two-tailed test).
184///
185/// # Arguments
186///
187/// * `df` - Degrees of freedom
188/// * `alpha` - Significance level (typically 0.05 for 95% confidence)
189///
190/// # Example
191///
192/// ```
193/// # use linreg_core::core::t_critical_quantile;
194/// let t_crit = t_critical_quantile(20.0, 0.05);
195/// assert!(t_crit > 2.0); // approximately 2.086 for df=20, alpha=0.05
196/// ```
197pub fn t_critical_quantile(df: f64, alpha: f64) -> f64 {
198    let p = 1.0 - alpha / 2.0;
199    student_t_inverse_cdf(p, df)
200}
201
202/// Computes a p-value from an F-statistic.
203///
204/// Uses the F-distribution CDF to calculate the probability of observing
205/// an F-statistic as extreme as the one provided.
206///
207/// # Arguments
208///
209/// * `f_stat` - The F-statistic value
210/// * `df1` - Numerator degrees of freedom
211/// * `df2` - Denominator degrees of freedom
212///
213/// # Example
214///
215/// ```
216/// # use linreg_core::core::f_p_value;
217/// let p = f_p_value(5.0, 2.0, 20.0);
218/// assert!(p > 0.0 && p < 0.05);
219/// ```
220pub fn f_p_value(f_stat: f64, df1: f64, df2: f64) -> f64 {
221    if f_stat <= 0.0 {
222        return 1.0;
223    }
224    1.0 - fisher_snedecor_cdf(f_stat, df1, df2)
225}
226
227/// Computes leverage values from the design matrix and its inverse.
228///
229/// Leverage measures how far an observation's predictor values are from
230/// the center of the predictor space. High leverage points can have
231/// disproportionate influence on the regression results.
232///
233/// # Arguments
234///
235/// * `x` - Design matrix (including intercept column)
236/// * `xtx_inv` - Inverse of X'X matrix
237#[allow(clippy::needless_range_loop)]
238pub fn compute_leverage(x: &Matrix, xtx_inv: &Matrix) -> Vec<f64> {
239    let n = x.rows;
240    let mut leverage = vec![0.0; n];
241    for i in 0..n {
242        // x_row is (1, cols)
243        // temp = x_row * xtx_inv (1, cols)
244        // lev = temp * x_row^T (1, 1)
245        
246        // Manual row extraction and multiplication
247        let mut row_vec = vec![0.0; x.cols];
248        for j in 0..x.cols {
249            row_vec[j] = x.get(i, j);
250        }
251        
252        let mut temp_vec = vec![0.0; x.cols];
253        for c in 0..x.cols {
254            let mut sum = 0.0;
255            for k in 0..x.cols {
256                sum += row_vec[k] * xtx_inv.get(k, c);
257            }
258            temp_vec[c] = sum;
259        }
260        
261        leverage[i] = vec_dot(&temp_vec, &row_vec);
262    }
263    leverage
264}
265
266// ============================================================================
267// VIF Calculation
268// ============================================================================
269//
270// Variance Inflation Factor analysis for detecting multicollinearity.
271
272/// Calculates Variance Inflation Factors for all predictors.
273///
274/// VIF quantifies the severity of multicollinearity in a regression analysis.
275/// A VIF > 10 indicates high multicollinearity that may need to be addressed.
276///
277/// # Arguments
278///
279/// * `x_vars` - Predictor variables (each of length n)
280/// * `names` - Variable names (including intercept as first element)
281/// * `n` - Number of observations
282///
283/// # Returns
284///
285/// Vector of VIF results for each predictor (excluding intercept).
286///
287/// # Example
288///
289/// ```
290/// # use linreg_core::core::calculate_vif;
291/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
292/// let x2 = vec![2.0, 4.0, 6.0, 8.0, 10.0];
293/// let names = vec!["Intercept".to_string(), "X1".to_string(), "X2".to_string()];
294///
295/// let vif_results = calculate_vif(&[x1, x2], &names, 5);
296/// assert_eq!(vif_results.len(), 2);
297/// // Perfectly collinear variables will have VIF = infinity
298/// ```
299pub fn calculate_vif(x_vars: &[Vec<f64>], names: &[String], n: usize) -> Vec<VifResult> {
300    let k = x_vars.len();
301    if k <= 1 {
302        return vec![];
303    }
304
305    // Standardize predictors (Z-score)
306    let mut z_data = vec![0.0; n * k];
307
308    for (j, var) in x_vars.iter().enumerate() {
309        let mean = vec_mean(var);
310        let variance = var.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / ((n - 1) as f64);
311        let std_dev = variance.sqrt();
312
313        // Handle constant variables
314        if std_dev.abs() < 1e-10 {
315            return names.iter().skip(1).map(|name| VifResult {
316                variable: name.clone(),
317                vif: f64::INFINITY,
318                rsquared: 1.0,
319                interpretation: "Constant variable (undefined correlation)".to_string()
320            }).collect();
321        }
322
323        for i in 0..n {
324            z_data[i * k + j] = (var[i] - mean) / std_dev;
325        }
326    }
327
328    let z = Matrix::new(n, k, z_data);
329
330    // Correlation Matrix R = (1/(n-1)) * Z^T * Z
331    let z_t = z.transpose();
332    let zt_z = z_t.matmul(&z);
333    
334    // Scale by 1/(n-1)
335    let mut r_corr = zt_z; // Copy
336    let factor = 1.0 / ((n - 1) as f64);
337    for val in &mut r_corr.data {
338        *val *= factor;
339    }
340
341    // Invert R using QR on R_corr (since it's symmetric positive definite, Cholesky is better but QR works)
342    // Or just generic inversion. We implemented generic inversion for Upper Triangular.
343    // Let's use QR: A = QR => A^-1 = R^-1 Q^T
344    let (q_corr, r_corr_tri) = r_corr.qr();
345    
346    let r_inv_opt = r_corr_tri.invert_upper_triangular();
347    
348    let r_inv = match r_inv_opt {
349        Some(inv) => inv.matmul(&q_corr.transpose()),
350        None => {
351             return names.iter().skip(1).map(|name| VifResult {
352                variable: name.clone(),
353                vif: f64::INFINITY,
354                rsquared: 1.0,
355                interpretation: "Perfect multicollinearity (singular matrix)".to_string()
356            }).collect();
357        }
358    };
359
360    // Extract diagonals
361    let mut results = vec![];
362    for j in 0..k {
363        let vif = r_inv.get(j, j);
364        let vif = if vif < 1.0 { 1.0 } else { vif };
365        let rsquared = 1.0 - 1.0 / vif;
366
367        let interpretation = if vif.is_infinite() {
368            "Perfect multicollinearity".to_string()
369        } else if vif > 10.0 {
370            "High multicollinearity - consider removing".to_string()
371        } else if vif > 5.0 {
372            "Moderate multicollinearity".to_string()
373        } else {
374            "Low multicollinearity".to_string()
375        };
376
377        results.push(VifResult {
378            variable: names[j + 1].clone(),
379            vif,
380            rsquared,
381            interpretation,
382        });
383    }
384
385    results
386}
387
388// ============================================================================
389// OLS Regression
390// ============================================================================
391//
392// Ordinary Least Squares regression implementation using QR decomposition.
393
394/// Performs Ordinary Least Squares regression using QR decomposition.
395///
396/// Uses a numerically stable QR decomposition approach to solve the normal
397/// equations. Returns comprehensive statistics including coefficients,
398/// standard errors, t-statistics, p-values, and diagnostic measures.
399///
400/// # Arguments
401///
402/// * `y` - Response variable (n observations)
403/// * `x_vars` - Predictor variables (each of length n)
404/// * `variable_names` - Names for variables (including intercept)
405///
406/// # Returns
407///
408/// A [`RegressionOutput`] containing all regression statistics and diagnostics.
409///
410/// # Errors
411///
412/// Returns [`Error::InsufficientData`] if n ≤ k + 1.
413/// Returns [`Error::SingularMatrix`] if perfect multicollinearity exists.
414/// Returns [`Error::InvalidInput`] if coefficients are NaN.
415///
416/// # Example
417///
418/// ```
419/// # use linreg_core::core::ols_regression;
420/// # use linreg_core::Error;
421/// let y = vec![2.5, 3.7, 4.2, 5.1, 6.3, 7.0];
422/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
423/// let x2 = vec![2.0, 4.0, 5.0, 4.0, 3.0, 2.0];
424/// let names = vec![
425///     "Intercept".to_string(),
426///     "Temperature".to_string(),
427///     "Pressure".to_string(),
428/// ];
429///
430/// let result = ols_regression(&y, &[x1, x2], &names)?;
431/// println!("R-squared: {}", result.r_squared);
432/// # Ok::<(), Error>(())
433/// ```
434#[allow(clippy::needless_range_loop)]
435pub fn ols_regression(
436    y: &[f64],
437    x_vars: &[Vec<f64>],
438    variable_names: &[String],
439) -> Result<RegressionOutput> {
440    let n = y.len();
441    let k = x_vars.len();
442    let p = k + 1;
443
444    // Validate inputs
445    if n <= k + 1 {
446        return Err(Error::InsufficientData { required: k + 2, available: n });
447    }
448
449    // Validate dimensions and finite values using shared helper
450    crate::diagnostics::validate_regression_data(y, x_vars)?;
451
452    // Prepare variable names
453    let mut names = variable_names.to_vec();
454    while names.len() <= k {
455        names.push(format!("X{}", names.len()));
456    }
457
458    // Create design matrix
459    let mut x_data = vec![0.0; n * p];
460    for (row, _yi) in y.iter().enumerate() {
461        x_data[row * p] = 1.0;  // intercept
462        for (col, x_var) in x_vars.iter().enumerate() {
463            x_data[row * p + col + 1] = x_var[row];
464        }
465    }
466
467    let x_matrix = Matrix::new(n, p, x_data);
468
469    // QR Decomposition
470    let (q, r) = x_matrix.qr();
471
472    // Solve R * beta = Q^T * y
473    // extract upper p x p part of R
474    let mut r_upper = Matrix::zeros(p, p);
475    for i in 0..p {
476        for j in 0..p {
477            r_upper.set(i, j, r.get(i, j));
478        }
479    }
480
481    // Q^T * y
482    let q_t = q.transpose();
483    let qty = q_t.mul_vec(y);
484    
485    // Take first p elements of qty
486    let rhs_vec = qty[0..p].to_vec();
487    let rhs_mat = Matrix::new(p, 1, rhs_vec); // column vector
488
489    // Invert R_upper
490    let r_inv = match r_upper.invert_upper_triangular() {
491        Some(inv) => inv,
492        None => return Err(Error::SingularMatrix),
493    };
494
495    let beta_mat = r_inv.matmul(&rhs_mat);
496    let beta = beta_mat.data;
497
498    // Validate coefficients
499    if beta.iter().any(|&b| b.is_nan()) {
500        return Err(Error::InvalidInput("Coefficients contain NaN".to_string()));
501    }
502
503    // Compute predictions and residuals
504    let predictions = x_matrix.mul_vec(&beta);
505    let residuals = vec_sub(y, &predictions);
506
507    // Compute sums of squares
508    let y_mean = vec_mean(y);
509    let ss_total: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
510    let ss_residual: f64 = residuals.iter().map(|&r| r.powi(2)).sum();
511    let ss_regression = ss_total - ss_residual;
512
513    // R-squared and adjusted R-squared
514    let r_squared = if ss_total == 0.0 {
515        f64::NAN
516    } else {
517        1.0 - ss_residual / ss_total
518    };
519
520    let adj_r_squared = if ss_total == 0.0 {
521        f64::NAN
522    } else {
523        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n - k - 1) as f64)
524    };
525
526    // Mean squared error and standard error
527    let df = n - k - 1;
528    let mse = ss_residual / df as f64;
529    let std_error = mse.sqrt();
530
531    // Standard errors using (X'X)^-1 = R^-1 (R')^-1
532    // xtx_inv = r_inv * r_inv^T
533    let xtx_inv = r_inv.matmul(&r_inv.transpose());
534
535    let mut std_errors = vec![0.0; k + 1];
536    for i in 0..=k {
537        std_errors[i] = (xtx_inv.get(i, i) * mse).sqrt();
538        if std_errors[i].is_nan() {
539            return Err(Error::InvalidInput(format!("Standard error for coefficient {} is NaN", i)));
540        }
541    }
542
543    // T-statistics and p-values
544    let t_stats: Vec<f64> = beta.iter().zip(&std_errors).map(|(&b, &se)| b / se).collect();
545    let p_values: Vec<f64> = t_stats.iter().map(|&t| two_tailed_p_value(t, df as f64)).collect();
546
547    // Confidence intervals
548    let alpha = 0.05;
549    let t_critical = t_critical_quantile(df as f64, alpha);
550
551    let conf_int_lower: Vec<f64> = beta.iter().zip(&std_errors).map(|(&b, &se)| b - t_critical * se).collect();
552    let conf_int_upper: Vec<f64> = beta.iter().zip(&std_errors).map(|(&b, &se)| b + t_critical * se).collect();
553
554    // Leverage
555    let leverage = compute_leverage(&x_matrix, &xtx_inv);
556
557    // Standardized residuals
558    let residuals_vec = residuals.clone();
559    let standardized_residuals: Vec<f64> = residuals_vec.iter().zip(&leverage)
560        .map(|(&r, &h)| {
561            let factor = (1.0 - h).max(MIN_LEVERAGE_DENOM).sqrt();
562            let denom = std_error * factor;
563            if denom > MIN_LEVERAGE_DENOM { r / denom } else { 0.0 }
564        })
565        .collect();
566
567    // F-statistic
568    let f_statistic = (ss_regression / k as f64) / mse;
569    let f_p_value = f_p_value(f_statistic, k as f64, df as f64);
570
571    // RMSE and MAE
572    let rmse = std_error;
573    let mae: f64 = residuals_vec.iter().map(|&r| r.abs()).sum::<f64>() / n as f64;
574
575    // VIF
576    let vif = calculate_vif(x_vars, &names, n);
577
578    Ok(RegressionOutput {
579        coefficients: beta,
580        std_errors,
581        t_stats,
582        p_values,
583        conf_int_lower,
584        conf_int_upper,
585        r_squared,
586        adj_r_squared,
587        f_statistic,
588        f_p_value,
589        mse,
590        rmse,
591        mae,
592        std_error,
593        residuals: residuals_vec,
594        standardized_residuals,
595        predictions,
596        leverage,
597        vif,
598        n,
599        k,
600        df,
601        variable_names: names,
602    })
603}