linreg_core/diagnostics/helpers.rs

// ============================================================================
// Diagnostic Test Helper Functions
// ============================================================================

//! Shared helper functions for diagnostic tests.
//!
//! This module provides utility functions used across multiple diagnostic tests,
//! including:
//!
//! - Data validation (dimension and finite value checks)
//! - P-value computation for common statistical distributions
//! - OLS fitting with numerical stability safeguards
//! - Residual sum of squares calculation
//!
//! # OLS Fitting Strategy
//!
//! The [`fit_ols`] function uses a robust two-stage approach:
//! 1. First attempts standard QR decomposition OLS
//! 2. Falls back to ridge regression (λ = 0.0001) if numerical issues occur
//!
//! This ensures diagnostic tests work correctly even with multicollinear data.
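//!
//! # Example
//!
//! A minimal usage sketch (illustrative only, not compiled as a doctest):
//!
//! ```ignore
//! use linreg_core::diagnostics::helpers::fit_ols;
//! use linreg_core::linalg::Matrix;
//!
//! let y = vec![1.0, 2.0, 3.0];
//! let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
//! // Falls back to ridge automatically if QR detects a singular system.
//! let beta = fit_ols(&y, &x)?;
//! # Ok::<(), linreg_core::Error>(())
//! ```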

use crate::distributions::{chi_squared_survival, fisher_snedecor_cdf, student_t_cdf};
use crate::error::{Error, Result};
use crate::linalg::{vec_sub, Matrix};

/// Validates regression input data for dimensions and finite values.
///
/// This validation function checks:
/// 1. All predictor variables have the same length as the response
/// 2. The response variable contains no NaN or infinite values
/// 3. All predictor variables contain no NaN or infinite values
///
/// Explicit indexed loops are used so the function can fail fast and report
/// the exact offending variable and index.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x_vars` - Predictor variables (each expected to have length n)
///
/// # Returns
///
/// `Ok(())` if all validations pass, otherwise an error indicating the specific issue.
///
/// # Errors
///
/// * [`Error::DimensionMismatch`] - if any x_var has a different length than y
/// * [`Error::InvalidInput`] - if y or x_vars contain NaN or infinite values
///
/// # Examples
///
/// ```ignore
/// use linreg_core::diagnostics::helpers::validate_regression_data;
///
/// let y = vec![1.0, 2.0, 3.0];
/// let x1 = vec![1.0, 2.0, 3.0];
/// let x2 = vec![2.0, 4.0, 6.0];
///
/// validate_regression_data(&y, &[x1, x2])?;
/// # Ok::<(), linreg_core::Error>(())
/// ```
pub fn validate_regression_data(y: &[f64], x_vars: &[Vec<f64>]) -> Result<()> {
    let n = y.len();

    // Validate all x_vars have the same length as y
    for (i, x_var) in x_vars.iter().enumerate() {
        if x_var.len() != n {
            return Err(Error::DimensionMismatch(format!(
                "X{} has {} observations but y has {}",
                i + 1,
                x_var.len(),
                n
            )));
        }
    }

    // Validate y contains no NaN or infinite values
    for (i, &yi) in y.iter().enumerate() {
        if !yi.is_finite() {
            return Err(Error::InvalidInput(format!(
                "y contains non-finite value at index {}: {}",
                i, yi
            )));
        }
    }

    // Validate x_vars contain no NaN or infinite values
    for (var_idx, x_var) in x_vars.iter().enumerate() {
        for (i, &xi) in x_var.iter().enumerate() {
            if !xi.is_finite() {
                return Err(Error::InvalidInput(format!(
                    "X{} contains non-finite value at index {}: {}",
                    var_idx + 1,
                    i,
                    xi
                )));
            }
        }
    }

    Ok(())
}

/// Computes a two-tailed p-value from a t-statistic.
///
/// This function calculates the probability of observing a t-statistic at least
/// as extreme as the one provided, assuming a two-tailed test. It evaluates the
/// Student's t cumulative distribution function `F` and returns
/// `2 * min(F(t), 1 - F(t))`.
///
/// # Arguments
///
/// * `t` - The t-statistic value
/// * `df` - Degrees of freedom (must be positive)
///
/// # Returns
///
/// The two-tailed p-value in the range `[0, 1]`. For extreme values (`|t| > 100`),
/// returns `0.0` to avoid numerical underflow.
///
/// # Examples
///
/// ```
/// use linreg_core::diagnostics::two_tailed_p_value;
///
/// // t = 2.0 with 10 degrees of freedom
/// let p = two_tailed_p_value(2.0, 10.0);
/// assert!(p > 0.05 && p < 0.10);
/// ```
pub fn two_tailed_p_value(t: f64, df: f64) -> f64 {
    if t.abs() > 100.0 {
        return 0.0;
    }

    let cdf = student_t_cdf(t, df);
    if t >= 0.0 {
        2.0 * (1.0 - cdf)
    } else {
        2.0 * cdf
    }
}

/// Computes a p-value from an F-statistic.
///
/// Calculates the upper-tail probability of observing an F-statistic as extreme
/// as the one provided, using the Fisher-Snedecor (F) distribution.
///
/// # Arguments
///
/// * `f_stat` - The F-statistic value (must be non-negative)
/// * `df1` - Numerator degrees of freedom
/// * `df2` - Denominator degrees of freedom
///
/// # Returns
///
/// The p-value (upper tail probability) in the range `[0, 1]`. Returns `1.0` for
/// non-positive F-statistics.
///
/// # Examples
///
/// ```
/// use linreg_core::diagnostics::f_p_value;
///
/// // F = 5.0 with df1 = 2, df2 = 10
/// let p = f_p_value(5.0, 2.0, 10.0);
/// assert!(p > 0.0 && p < 0.05);
/// ```
pub fn f_p_value(f_stat: f64, df1: f64, df2: f64) -> f64 {
    if f_stat <= 0.0 {
        return 1.0;
    }
    1.0 - fisher_snedecor_cdf(f_stat, df1, df2)
}

/// Computes a p-value from a chi-squared statistic (upper tail).
///
/// Calculates the probability of observing a chi-squared statistic as extreme
/// as the one provided, using the chi-squared distribution.
///
/// # Arguments
///
/// * `stat` - The chi-squared statistic value (must be non-negative)
/// * `df` - Degrees of freedom
///
/// # Returns
///
/// The upper-tail p-value in the range `[0, 1]`.
///
/// # Examples
///
/// ```ignore
/// use linreg_core::diagnostics::helpers::chi_squared_p_value;
///
/// // chi-squared = 10.0 with df = 5
/// let p = chi_squared_p_value(10.0, 5.0);
/// assert!(p > 0.0 && p < 1.0);
/// ```
pub fn chi_squared_p_value(stat: f64, df: f64) -> f64 {
    chi_squared_survival(stat, df)
}

/// Computes the residual sum of squares (RSS) from a fitted model.
///
/// The RSS is the sum of squared differences between observed and predicted
/// values: `RSS = Σ(yᵢ - ŷᵢ)²`, where `ŷᵢ = Xᵢβ`.
///
/// This is a measure of model fit: lower values indicate better fit. The RSS
/// is used in many diagnostic tests, including the Rainbow test and likelihood
/// ratio tests.
///
/// # Arguments
///
/// * `y` - Observed response values (n observations)
/// * `x` - Design matrix (n × p)
/// * `beta` - Coefficient vector (p elements)
///
/// # Returns
///
/// The residual sum of squares as a non-negative value.
///
/// # Examples
///
/// ```ignore
/// use linreg_core::diagnostics::helpers::compute_rss;
/// use linreg_core::linalg::Matrix;
///
/// let y = vec![2.0, 4.0, 6.0];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
/// let beta = vec![0.0, 2.0];  // y = 2*x
/// let rss = compute_rss(&y, &x, &beta).unwrap();
/// assert_eq!(rss, 0.0);  // Perfect fit
/// ```
pub fn compute_rss(y: &[f64], x: &Matrix, beta: &[f64]) -> Result<f64> {
    // predictions = x * beta
    let predictions = x.mul_vec(beta);
    let residuals = vec_sub(y, &predictions);
    Ok(residuals.iter().map(|&r| r * r).sum())
}

/// Fits an OLS regression model and returns the coefficient estimates.
///
/// This function provides a robust OLS fitting procedure that first attempts
/// standard QR decomposition, then falls back to ridge regression if numerical
/// instability is detected (e.g., due to multicollinearity).
///
/// The ridge fallback uses a very small regularization parameter (λ = 0.0001)
/// to maintain numerical stability while minimizing distortion of the estimates.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n × p, should include an intercept column if needed)
///
/// # Returns
///
/// A vector of coefficient estimates (p elements).
///
/// # Errors
///
/// * [`Error::SingularMatrix`] - if the design matrix is singular and ridge
///   regression also fails
///
/// # Examples
///
/// ```ignore
/// use linreg_core::diagnostics::helpers::fit_ols;
/// use linreg_core::linalg::Matrix;
///
/// let y = vec![2.1, 4.0, 5.9];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
/// let beta = fit_ols(&y, &x).unwrap();
/// assert_eq!(beta.len(), 2);  // Intercept and slope
/// ```
pub fn fit_ols(y: &[f64], x: &Matrix) -> Result<Vec<f64>> {
    // First try standard QR decomposition OLS.
    if let Ok(beta) = try_fit_ols_qr(y, x) {
        return Ok(beta);
    }

    // If QR fails due to multicollinearity, fall back to ridge regression.
    // A very small lambda minimizes distortion while maintaining stability.
    fit_ols_ridge(y, x, 0.0001)
}

/// Standard QR decomposition OLS solver.
///
/// Solves the least-squares problem `min ‖Xβ - y‖²` via QR decomposition. This
/// is the preferred method for OLS estimation because it avoids forming the
/// normal equations, which squares the condition number of `X`.
///
/// The algorithm computes `X = QR`, where Q is orthogonal and R is upper
/// triangular, then solves `Rβ = Qᵀy` by inverting the upper-triangular factor.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n × p)
///
/// # Returns
///
/// A vector of coefficient estimates (p elements).
///
/// # Errors
///
/// * [`Error::SingularMatrix`] - if the system is underdetermined (p > n) or
///   R is not invertible
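///
/// # Examples
///
/// Illustrative sketch only (this is a private helper, so the example is not
/// compiled as a doctest):
///
/// ```ignore
/// let y = vec![2.1, 4.0, 5.9];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
/// let beta = try_fit_ols_qr(&y, &x).unwrap();
/// assert_eq!(beta.len(), 2);  // Intercept and slope
/// ```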
fn try_fit_ols_qr(y: &[f64], x: &Matrix) -> Result<Vec<f64>> {
    let p = x.cols;
    let n = x.rows;

    // When p > n the system is underdetermined (more predictors than
    // observations). Return an error so the caller can fall back to
    // ridge regression.
    if p > n {
        return Err(Error::SingularMatrix);
    }

    let (q, r) = x.qr();

    // Q^T * y
    let qty = q.transpose().mul_vec(y);

    // Take the first p elements (the rest correspond to the residual space)
    let rhs_vec = qty[0..p].to_vec();
    let rhs_mat = Matrix::new(p, 1, rhs_vec);

    // Extract the upper p x p block of R
    let mut r_upper = Matrix::zeros(p, p);
    for i in 0..p {
        for j in 0..p {
            r_upper.set(i, j, r.get(i, j));
        }
    }

    // beta = R^{-1} * (Q^T y)
    match r_upper.invert_upper_triangular() {
        Some(r_inv) => Ok(r_inv.matmul(&rhs_mat).data),
        None => Err(Error::SingularMatrix),
    }
}

/// Ridge regression fallback for multicollinear data.
///
/// Solves the ridge regression problem: `(X'X + λI)β = X'y`. This adds a small
/// positive constant to the diagonal of `X'X`, ensuring invertibility even when
/// the design matrix is rank-deficient.
///
/// Ridge regression is used as a fallback when standard QR decomposition fails
/// due to multicollinearity or numerical singularity.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n × p)
/// * `lambda` - Regularization parameter (small positive value, e.g., 0.0001)
///
/// # Returns
///
/// A vector of ridge-regularized coefficient estimates (p elements).
///
/// # Errors
///
/// * [`Error::SingularMatrix`] - if the ridge-adjusted matrix is still singular
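///
/// # Examples
///
/// Illustrative sketch only (this is a private helper, so the example is not
/// compiled as a doctest):
///
/// ```ignore
/// // Two identical predictor columns: X'X is singular, but the ridge term
/// // keeps (X'X + lambda*I) invertible.
/// let y = vec![1.0, 2.0, 3.0];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
/// let beta = fit_ols_ridge(&y, &x, 0.0001).unwrap();
/// assert_eq!(beta.len(), 2);
/// ```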
fn fit_ols_ridge(y: &[f64], x: &Matrix, lambda: f64) -> Result<Vec<f64>> {
    let p = x.cols;

    // Solve: (X'X + lambda*I) * beta = X'y
    let xt = x.transpose();
    let xtx = xt.matmul(x);

    // Add the ridge penalty to the diagonal
    let mut xtx_ridge_data = xtx.data.clone();
    for i in 0..p {
        xtx_ridge_data[i * p + i] += lambda;
    }
    let xtx_ridge = Matrix::new(p, p, xtx_ridge_data);

    // X'y
    let xty = xt.mul_vec(y);

    // Invert and solve for beta
    let xtx_inv = xtx_ridge.invert().ok_or(Error::SingularMatrix)?;
    let beta = xtx_inv.mul_vec(&xty);
    Ok(beta)
}
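
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sketch of the fallback behaviour described in the module docs:
    // a duplicated predictor column makes the design matrix rank-deficient, so
    // (assuming `invert_upper_triangular` detects the singular R factor) the
    // QR path fails and `fit_ols` falls back to ridge regression. Either way,
    // a full-length coefficient vector should come back.
    #[test]
    fn fit_ols_handles_collinear_predictors() {
        let y = vec![1.0, 2.0, 3.0, 4.0];
        // 4 x 3 design matrix (row-major): intercept plus two identical columns.
        let x = Matrix::new(
            4,
            3,
            vec![
                1.0, 1.0, 1.0, //
                1.0, 2.0, 2.0, //
                1.0, 3.0, 3.0, //
                1.0, 4.0, 4.0, //
            ],
        );

        let beta = fit_ols(&y, &x).expect("ridge fallback should succeed");
        assert_eq!(beta.len(), 3);
    }
}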