linreg_core/diagnostics/
helpers.rs

// ============================================================================
// Diagnostic Test Helper Functions
// ============================================================================

//! Shared helper functions for diagnostic tests.
//!
//! This module provides utility functions used across multiple diagnostic tests,
//! including:
//!
//! - Data validation (dimension and finite value checks)
//! - P-value computation for common statistical distributions
//! - OLS fitting with numerical stability safeguards
//! - Residual sum of squares calculation
//!
//! # OLS Fitting Strategy
//!
//! The [`fit_ols`] function uses a robust two-stage approach:
//!
//! 1. First attempts standard QR decomposition OLS
//! 2. Falls back to ridge regression (λ = 0.0001) if numerical issues occur
//!
//! This ensures diagnostic tests work correctly even with multicollinear data.
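//!
//! A minimal sketch of the fallback behavior (hypothetical data; row-major
//! [`Matrix::new`] as in the examples below):
//!
//! ```ignore
//! use linreg_core::linalg::Matrix;
//!
//! // The third column is exactly twice the second, so the design matrix is
//! // rank-deficient: QR fails and fit_ols falls back to the ridge solve.
//! let y = vec![1.0, 2.0, 3.0, 4.0];
//! let x = Matrix::new(4, 3, vec![
//!     1.0, 1.0, 2.0,
//!     1.0, 2.0, 4.0,
//!     1.0, 3.0, 6.0,
//!     1.0, 4.0, 8.0,
//! ]);
//! let beta = fit_ols(&y, &x).unwrap();
//! assert_eq!(beta.len(), 3);
//! ```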

use crate::error::{Error, Result};
use crate::linalg::{Matrix, vec_sub};
use crate::distributions::{student_t_cdf, fisher_snedecor_cdf, chi_squared_survival};

/// Validates regression input data for dimensions and finite values.
///
/// This is a high-performance validation function that checks:
/// 1. All predictor variables have the same length as the response
/// 2. Response variable contains no NaN or infinite values
/// 3. All predictor variables contain no NaN or infinite values
///
/// Uses explicit loops for maximum performance (no closure overhead).
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x_vars` - Predictor variables (each expected to have length n)
///
/// # Returns
///
/// `Ok(())` if all validations pass, otherwise an error indicating the specific issue.
///
/// # Errors
///
/// * [`Error::DimensionMismatch`] - if any `x_var` has a different length than `y`
/// * [`Error::InvalidInput`] - if `y` or `x_vars` contain NaN or infinite values
///
/// # Examples
///
/// ```ignore
/// use linreg_core::diagnostics::helpers::validate_regression_data;
///
/// let y = vec![1.0, 2.0, 3.0];
/// let x1 = vec![1.0, 2.0, 3.0];
/// let x2 = vec![2.0, 4.0, 6.0];
///
/// validate_regression_data(&y, &[x1, x2])?;
/// # Ok::<(), linreg_core::Error>(())
/// ```
pub fn validate_regression_data(y: &[f64], x_vars: &[Vec<f64>]) -> Result<()> {
    let n = y.len();

    // Validate all x_vars have the same length as y
    for (i, x_var) in x_vars.iter().enumerate() {
        if x_var.len() != n {
            return Err(Error::DimensionMismatch(
                format!("X{} has {} observations but y has {}", i + 1, x_var.len(), n)
            ));
        }
    }

    // Validate y contains no NaN or infinite values
    for (i, &yi) in y.iter().enumerate() {
        if !yi.is_finite() {
            return Err(Error::InvalidInput(format!(
                "y contains non-finite value at index {}: {}", i, yi
            )));
        }
    }

    // Validate x_vars contain no NaN or infinite values
    for (var_idx, x_var) in x_vars.iter().enumerate() {
        for (i, &xi) in x_var.iter().enumerate() {
            if !xi.is_finite() {
                return Err(Error::InvalidInput(format!(
                    "X{} contains non-finite value at index {}: {}", var_idx + 1, i, xi
                )));
            }
        }
    }

    Ok(())
}

/// Computes a two-tailed p-value from a t-statistic.
///
/// This function calculates the probability of observing a t-statistic as extreme
/// as the one provided, assuming a two-tailed test. It uses the Student's t
/// cumulative distribution function.
///
/// # Arguments
///
/// * `t` - The t-statistic value
/// * `df` - Degrees of freedom (must be positive)
///
/// # Returns
///
/// The two-tailed p-value in the range `[0, 1]`. For extreme values (`|t| > 100`),
/// returns `0.0` to avoid numerical underflow.
///
/// # Examples
///
/// ```
/// use linreg_core::diagnostics::two_tailed_p_value;
///
/// // t = 2.0 with 10 degrees of freedom
/// let p = two_tailed_p_value(2.0, 10.0);
/// assert!(p > 0.05 && p < 0.10);
/// ```
pub fn two_tailed_p_value(t: f64, df: f64) -> f64 {
    if t.abs() > 100.0 {
        return 0.0;
    }

    let cdf = student_t_cdf(t, df);
    if t >= 0.0 { 2.0 * (1.0 - cdf) } else { 2.0 * cdf }
}

/// Computes a p-value from an F-statistic.
///
/// Calculates the upper-tail probability of observing an F-statistic as extreme
/// as the one provided, using the Fisher-Snedecor (F) distribution.
///
/// # Arguments
///
/// * `f_stat` - The F-statistic value (must be non-negative)
/// * `df1` - Numerator degrees of freedom
/// * `df2` - Denominator degrees of freedom
///
/// # Returns
///
/// The p-value (upper tail probability) in the range `[0, 1]`. Returns `1.0` for
/// non-positive F-statistics.
///
/// # Examples
///
/// ```
/// use linreg_core::diagnostics::f_p_value;
///
/// // F = 5.0 with df1 = 2, df2 = 10
/// let p = f_p_value(5.0, 2.0, 10.0);
/// assert!(p > 0.0 && p < 0.05);
/// ```
pub fn f_p_value(f_stat: f64, df1: f64, df2: f64) -> f64 {
    if f_stat <= 0.0 {
        return 1.0;
    }
    1.0 - fisher_snedecor_cdf(f_stat, df1, df2)
}

/// Computes a p-value from a chi-squared statistic (upper tail).
///
/// Calculates the probability of observing a chi-squared statistic as extreme
/// as the one provided, using the chi-squared distribution.
///
/// # Arguments
///
/// * `stat` - The chi-squared statistic value (must be non-negative)
/// * `df` - Degrees of freedom
///
/// # Returns
///
/// The upper-tail p-value in the range `[0, 1]`.
///
/// # Examples
///
/// ```ignore
/// use linreg_core::diagnostics::helpers::chi_squared_p_value;
///
/// // chi-squared = 10.0 with df = 5
/// let p = chi_squared_p_value(10.0, 5.0);
/// assert!(p > 0.0 && p < 1.0);
/// ```
pub fn chi_squared_p_value(stat: f64, df: f64) -> f64 {
    chi_squared_survival(stat, df)
}

/// Computes the residual sum of squares (RSS) from a fitted model.
///
/// The RSS is the sum of squared differences between observed and predicted
/// values: `RSS = Σ(yᵢ - ŷᵢ)²`, where `ŷᵢ = Xᵢβ`.
///
/// This is a measure of model fit - lower values indicate better fit. The RSS
/// is used in many diagnostic tests including the Rainbow test and likelihood
/// ratio tests.
///
/// # Arguments
///
/// * `y` - Observed response values (n observations)
/// * `x` - Design matrix (n × p)
/// * `beta` - Coefficient vector (p elements)
///
/// # Returns
///
/// The residual sum of squares as a non-negative value.
///
/// # Examples
///
/// ```ignore
/// use linreg_core::diagnostics::helpers::compute_rss;
/// use linreg_core::linalg::Matrix;
///
/// let y = vec![2.0, 4.0, 6.0];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
/// let beta = vec![0.0, 2.0];  // y = 2*x
/// let rss = compute_rss(&y, &x, &beta).unwrap();
/// assert_eq!(rss, 0.0);  // Perfect fit
/// ```
pub fn compute_rss(y: &[f64], x: &Matrix, beta: &[f64]) -> Result<f64> {
    // predictions = x * beta
    let predictions = x.mul_vec(beta);
    let residuals = vec_sub(y, &predictions);
    Ok(residuals.iter().map(|&r| r * r).sum())
}

/// Fits an OLS regression model and returns the coefficient estimates.
///
/// This function provides a robust OLS fitting procedure that first attempts
/// standard QR decomposition, then falls back to ridge regression if numerical
/// instability is detected (e.g., due to multicollinearity).
///
/// The ridge fallback uses a very small regularization parameter (λ = 0.0001)
/// to maintain numerical stability while minimizing distortion of the estimates.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n × p, should include intercept column if needed)
///
/// # Returns
///
/// A vector of coefficient estimates (p elements).
///
/// # Errors
///
/// * [`Error::SingularMatrix`] - if the design matrix is singular and ridge
///   regression also fails
///
/// # Examples
///
/// ```ignore
/// use linreg_core::diagnostics::helpers::fit_ols;
/// use linreg_core::linalg::Matrix;
///
/// let y = vec![2.1, 4.0, 5.9];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
/// let beta = fit_ols(&y, &x).unwrap();
/// assert_eq!(beta.len(), 2);  // Intercept and slope
/// ```
pub fn fit_ols(y: &[f64], x: &Matrix) -> Result<Vec<f64>> {
    // First try standard QR decomposition OLS; if it fails (e.g. due to
    // multicollinearity), fall back to ridge regression with a very small
    // lambda to minimize distortion while maintaining stability.
    try_fit_ols_qr(y, x).or_else(|_| fit_ols_ridge(y, x, 0.0001))
}

/// Standard QR decomposition OLS solver.
///
/// Solves the least-squares problem `Xβ ≈ y` via QR decomposition. This is the
/// preferred method for OLS estimation because it avoids forming the normal
/// equations and is therefore more numerically stable.
///
/// The algorithm computes `X = QR`, where `Q` is orthogonal and `R` is upper
/// triangular, then solves `Rβ = Qᵀy` by inverting the upper-triangular `R`.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n × p)
///
/// # Returns
///
/// A vector of coefficient estimates (p elements).
///
/// # Errors
///
/// * [`Error::SingularMatrix`] - if the design matrix is singular (p > n or
///   R is not invertible)
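///
/// # Examples
///
/// A minimal sketch for this private helper (hypothetical data, row-major
/// [`Matrix::new`] as in the examples above):
///
/// ```ignore
/// let y = vec![2.0, 4.0, 6.0];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
/// let beta = try_fit_ols_qr(&y, &x).unwrap();
/// // y = 2*x exactly, so the recovered intercept and slope should be ≈ 0 and ≈ 2
/// assert!(beta[0].abs() < 1e-8 && (beta[1] - 2.0).abs() < 1e-8);
/// ```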
fn try_fit_ols_qr(y: &[f64], x: &Matrix) -> Result<Vec<f64>> {
    let p = x.cols;
    let n = x.rows;

    // When p > n the system is underdetermined (more predictors than
    // observations); report it as singular so the caller can fall back
    // to ridge regression.
    if p > n {
        return Err(Error::SingularMatrix);
    }

    let (q, r) = x.qr();

    // Q^T * y
    let qty = q.transpose().mul_vec(y);

    // Take the first p elements, corresponding to the p rows of R
    let rhs_vec = qty[0..p].to_vec();
    let rhs_mat = Matrix::new(p, 1, rhs_vec);

    // Extract the upper p × p block of R
    let mut r_upper = Matrix::zeros(p, p);
    for i in 0..p {
        for j in 0..p {
            r_upper.set(i, j, r.get(i, j));
        }
    }

    // Solve R * beta = Q^T * y; a non-invertible R means X is rank-deficient
    match r_upper.invert_upper_triangular() {
        Some(r_inv) => Ok(r_inv.matmul(&rhs_mat).data),
        None => Err(Error::SingularMatrix),
    }
}

/// Ridge regression fallback for multicollinear data.
///
/// Solves the ridge regression problem: `(X'X + λI)β = X'y`. This adds a small
/// positive constant to the diagonal of `X'X`, ensuring invertibility even when
/// the design matrix is rank-deficient.
///
/// Ridge regression is used as a fallback when standard QR decomposition fails
/// due to multicollinearity or numerical singularity.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n × p)
/// * `lambda` - Regularization parameter (small positive value, e.g., 0.0001)
///
/// # Returns
///
/// A vector of ridge-regularized coefficient estimates (p elements).
///
/// # Errors
///
/// * [`Error::SingularMatrix`] - if the ridge-adjusted matrix is still singular
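///
/// # Examples
///
/// A minimal sketch for this private helper (hypothetical, exactly collinear
/// data that plain OLS cannot handle):
///
/// ```ignore
/// let y = vec![1.0, 2.0, 3.0];
/// // The second column is twice the first, so X'X alone is singular.
/// let x = Matrix::new(3, 2, vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0]);
/// let beta = fit_ols_ridge(&y, &x, 0.0001).unwrap();
/// assert_eq!(beta.len(), 2);
/// ```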
fn fit_ols_ridge(y: &[f64], x: &Matrix, lambda: f64) -> Result<Vec<f64>> {
    let p = x.cols;

    // Solve: (X'X + lambda*I) * beta = X'y
    let xt = x.transpose();
    let xtx = xt.matmul(x);

    // Add ridge penalty to the diagonal
    let mut xtx_ridge_data = xtx.data.clone();
    for i in 0..p {
        xtx_ridge_data[i * p + i] += lambda;
    }
    let xtx_ridge = Matrix::new(p, p, xtx_ridge_data);

    // X'y
    let xty = xt.mul_vec(y);

    // Invert and solve for the coefficient vector
    let xtx_inv = xtx_ridge.invert().ok_or(Error::SingularMatrix)?;
    Ok(xtx_inv.mul_vec(&xty))
}
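
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity-check sketch; assumes the row-major `Matrix::new`
    // constructor used in the doc examples above.
    #[test]
    fn fit_ols_recovers_exact_line() {
        let y = vec![2.0, 4.0, 6.0];
        let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
        let beta = fit_ols(&y, &x).unwrap();
        assert!(beta[0].abs() < 1e-8);
        assert!((beta[1] - 2.0).abs() < 1e-8);
    }

    #[test]
    fn validate_regression_data_rejects_nan() {
        let y = vec![1.0, f64::NAN];
        assert!(validate_regression_data(&y, &[]).is_err());
    }
}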