linreg_core/core.rs
//! Core OLS regression implementation.
//!
//! This module provides the main Ordinary Least Squares regression functionality
//! that can be used directly in Rust code. Functions accept native Rust slices
//! and return `Result` types for proper error handling.
//!
//! # Example
//!
//! ```
//! # use linreg_core::core::ols_regression;
//! # use linreg_core::Error;
//! let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
//! let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
//! let x2 = vec![2.0, 3.0, 3.5, 4.0, 4.5, 5.0];
//! let names = vec![
//!     "Intercept".to_string(),
//!     "X1".to_string(),
//!     "X2".to_string(),
//! ];
//!
//! let result = ols_regression(&y, &[x1, x2], &names)?;
//! # Ok::<(), Error>(())
//! ```

use serde::Serialize;

use crate::distributions::{fisher_snedecor_cdf, student_t_cdf, student_t_inverse_cdf};
use crate::error::{Error, Result};
use crate::linalg::{vec_dot, vec_mean, vec_sub, Matrix};

// ============================================================================
// Numerical Constants
// ============================================================================

/// Minimum threshold for the standardized residual denominator to avoid division by zero.
/// When (1 - leverage) is very small, the observation has extremely high leverage
/// and standardized residuals may be unreliable.
const MIN_LEVERAGE_DENOM: f64 = 1e-10;

// ============================================================================
// Result Types
// ============================================================================
//
// Structs containing the output of regression computations.

/// Result of VIF (Variance Inflation Factor) calculation.
///
/// VIF measures how much the variance of an estimated regression coefficient
/// increases due to multicollinearity among the predictors.
#[derive(Debug, Clone, Serialize)]
pub struct VifResult {
    /// Name of the predictor variable
    pub variable: String,
    /// Variance Inflation Factor (VIF > 10 indicates high multicollinearity)
    pub vif: f64,
    /// R-squared from regressing this predictor on all others
    pub rsquared: f64,
    /// Human-readable interpretation of the VIF value
    pub interpretation: String,
}

/// Complete output from OLS regression.
///
/// Contains all coefficients, statistics, diagnostics, and residuals from
/// an Ordinary Least Squares regression.
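///
/// # Example
///
/// A sketch of inspecting the fitted output (the data here is illustrative):
///
/// ```
/// # use linreg_core::core::ols_regression;
/// # use linreg_core::Error;
/// let y = vec![2.0, 4.1, 5.9, 8.2, 9.8, 12.1];
/// let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let names = vec!["Intercept".to_string(), "X".to_string()];
///
/// let out = ols_regression(&y, &[x], &names)?;
/// for (name, p) in out.variable_names.iter().zip(&out.p_values) {
///     println!("{}: p = {:.4}", name, p);
/// }
/// # Ok::<(), Error>(())
/// ```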
#[derive(Debug, Clone, Serialize)]
pub struct RegressionOutput {
    /// Regression coefficients (including intercept)
    pub coefficients: Vec<f64>,
    /// Standard errors of coefficients
    pub std_errors: Vec<f64>,
    /// t-statistics for coefficient significance tests
    pub t_stats: Vec<f64>,
    /// Two-tailed p-values for coefficients
    pub p_values: Vec<f64>,
    /// Lower bounds of 95% confidence intervals
    pub conf_int_lower: Vec<f64>,
    /// Upper bounds of 95% confidence intervals
    pub conf_int_upper: Vec<f64>,
    /// R-squared (coefficient of determination)
    pub r_squared: f64,
    /// Adjusted R-squared (accounts for number of predictors)
    pub adj_r_squared: f64,
    /// F-statistic for overall model significance
    pub f_statistic: f64,
    /// P-value for F-statistic
    pub f_p_value: f64,
    /// Mean squared error of residuals
    pub mse: f64,
    /// Standard error of the regression (residual standard deviation)
    pub std_error: f64,
    /// Raw residuals (observed - predicted)
    pub residuals: Vec<f64>,
    /// Standardized residuals (accounting for leverage)
    pub standardized_residuals: Vec<f64>,
    /// Fitted/predicted values
    pub predictions: Vec<f64>,
    /// Leverage values for each observation (diagonal of hat matrix)
    pub leverage: Vec<f64>,
    /// Variance Inflation Factors for detecting multicollinearity
    pub vif: Vec<VifResult>,
    /// Number of observations
    pub n: usize,
    /// Number of predictor variables (excluding intercept)
    pub k: usize,
    /// Degrees of freedom for residuals (n - k - 1)
    pub df: usize,
    /// Names of variables (including intercept)
    pub variable_names: Vec<String>,
}

// ============================================================================
// Statistical Helper Functions
// ============================================================================
//
// Utility functions for computing p-values, critical values, and leverage.

/// Computes a two-tailed p-value from a t-statistic.
///
/// Uses the Student's t-distribution CDF to calculate the probability
/// of observing a t-statistic as extreme as the one provided.
///
/// # Arguments
///
/// * `t` - The t-statistic value
/// * `df` - Degrees of freedom
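///
/// # Example
///
/// A small sketch with illustrative values:
///
/// ```
/// # use linreg_core::core::two_tailed_p_value;
/// let p = two_tailed_p_value(2.0, 10.0);
/// assert!(p > 0.0 && p < 1.0);
/// // A t-statistic of 0 carries no evidence against the null.
/// assert!((two_tailed_p_value(0.0, 10.0) - 1.0).abs() < 1e-6);
/// ```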
pub fn two_tailed_p_value(t: f64, df: f64) -> f64 {
    // Very large |t| values are treated as overwhelmingly significant;
    // short-circuit rather than evaluate the CDF far out in the tail.
    if t.abs() > 100.0 {
        return 0.0;
    }

    let cdf = student_t_cdf(t, df);
    if t >= 0.0 { 2.0 * (1.0 - cdf) } else { 2.0 * cdf }
}

/// Computes the critical t-value for a given significance level and degrees of freedom.
///
/// Returns the t-value such that the area under the t-distribution curve
/// to the right of it equals alpha/2 (two-tailed test).
///
/// # Arguments
///
/// * `df` - Degrees of freedom
/// * `alpha` - Significance level (typically 0.05 for 95% confidence)
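///
/// # Example
///
/// A small sketch: for large `df`, the two-tailed 5% critical value should be
/// close to the normal value of 1.96.
///
/// ```
/// # use linreg_core::core::t_critical_quantile;
/// let t = t_critical_quantile(1000.0, 0.05);
/// assert!(t > 1.9 && t < 2.0);
/// ```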
pub fn t_critical_quantile(df: f64, alpha: f64) -> f64 {
    let p = 1.0 - alpha / 2.0;
    student_t_inverse_cdf(p, df)
}

/// Computes a p-value from an F-statistic.
///
/// Uses the F-distribution CDF to calculate the probability of observing
/// an F-statistic as extreme as the one provided.
///
/// # Arguments
///
/// * `f_stat` - The F-statistic value
/// * `df1` - Numerator degrees of freedom
/// * `df2` - Denominator degrees of freedom
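///
/// # Example
///
/// A small sketch: larger F-statistics should yield smaller p-values.
///
/// ```
/// # use linreg_core::core::f_p_value;
/// let p_small = f_p_value(0.5, 2.0, 10.0);
/// let p_large = f_p_value(50.0, 2.0, 10.0);
/// assert!(p_large < p_small);
/// assert_eq!(f_p_value(0.0, 2.0, 10.0), 1.0);
/// ```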
pub fn f_p_value(f_stat: f64, df1: f64, df2: f64) -> f64 {
    // A non-positive F-statistic carries no evidence against the null.
    if f_stat <= 0.0 {
        return 1.0;
    }
    1.0 - fisher_snedecor_cdf(f_stat, df1, df2)
}

/// Computes leverage values from the design matrix and its inverse.
///
/// Leverage measures how far an observation's predictor values are from
/// the center of the predictor space. High leverage points can have
/// disproportionate influence on the regression results.
///
/// # Arguments
///
/// * `x` - Design matrix (including intercept column)
/// * `xtx_inv` - Inverse of X'X matrix
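///
/// # Example
///
/// A sketch, marked `ignore` because it assumes `linreg_core::linalg` is
/// publicly exported. Leverages always sum to the number of model columns.
///
/// ```ignore
/// use linreg_core::core::compute_leverage;
/// use linreg_core::linalg::Matrix;
///
/// // Intercept plus one predictor taking the values 0, 1, 2.
/// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
/// // (X'X)^-1 worked out by hand for this design.
/// let xtx_inv = Matrix::new(2, 2, vec![5.0 / 6.0, -0.5, -0.5, 0.5]);
/// let h = compute_leverage(&x, &xtx_inv);
/// assert!((h.iter().sum::<f64>() - 2.0).abs() < 1e-12);
/// ```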
pub fn compute_leverage(x: &Matrix, xtx_inv: &Matrix) -> Vec<f64> {
    let n = x.rows;
    let mut leverage = vec![0.0; n];
    for i in 0..n {
        // h_i = x_i * (X'X)^-1 * x_i', where x_i is the i-th row of X.
        // Extract the row manually, multiply by xtx_inv, then dot with the row.
        let mut row_vec = vec![0.0; x.cols];
        for j in 0..x.cols {
            row_vec[j] = x.get(i, j);
        }

        // temp = x_i * (X'X)^-1, a 1 x cols row vector.
        let mut temp_vec = vec![0.0; x.cols];
        for c in 0..x.cols {
            let mut sum = 0.0;
            for k in 0..x.cols {
                sum += row_vec[k] * xtx_inv.get(k, c);
            }
            temp_vec[c] = sum;
        }

        leverage[i] = vec_dot(&temp_vec, &row_vec);
    }
    leverage
}

// ============================================================================
// VIF Calculation
// ============================================================================
//
// Variance Inflation Factor analysis for detecting multicollinearity.

/// Calculates Variance Inflation Factors for all predictors.
///
/// VIF quantifies the severity of multicollinearity in a regression analysis.
/// A VIF > 10 indicates high multicollinearity that may need to be addressed.
///
/// # Arguments
///
/// * `x_vars` - Predictor variables (each of length n)
/// * `names` - Variable names (including intercept as first element)
/// * `n` - Number of observations
///
/// # Returns
///
/// Vector of VIF results for each predictor (excluding intercept).
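///
/// # Example
///
/// A small sketch with two moderately correlated predictors (values are
/// illustrative):
///
/// ```
/// # use linreg_core::core::calculate_vif;
/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let x2 = vec![2.0, 1.0, 4.0, 3.0, 5.0];
/// let names = vec![
///     "Intercept".to_string(),
///     "X1".to_string(),
///     "X2".to_string(),
/// ];
///
/// let vifs = calculate_vif(&[x1, x2], &names, 5);
/// assert_eq!(vifs.len(), 2);
/// // With r = 0.8 between the predictors, VIF = 1 / (1 - 0.64) ≈ 2.78.
/// assert!((vifs[0].vif - 2.78).abs() < 0.01);
/// ```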
pub fn calculate_vif(x_vars: &[Vec<f64>], names: &[String], n: usize) -> Vec<VifResult> {
    let k = x_vars.len();
    // VIF is only meaningful with at least two predictors.
    if k <= 1 {
        return vec![];
    }

    // Standardize predictors (Z-score)
    let mut z_data = vec![0.0; n * k];

    for (j, var) in x_vars.iter().enumerate() {
        let mean = vec_mean(var);
        let variance = var.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / ((n - 1) as f64);
        let std_dev = variance.sqrt();

        // A constant predictor has undefined correlations with the others.
        if std_dev.abs() < 1e-10 {
            return names
                .iter()
                .skip(1)
                .map(|name| VifResult {
                    variable: name.clone(),
                    vif: f64::INFINITY,
                    rsquared: 1.0,
                    interpretation: "Constant variable (undefined correlation)".to_string(),
                })
                .collect();
        }

        for i in 0..n {
            z_data[i * k + j] = (var[i] - mean) / std_dev;
        }
    }

    let z = Matrix::new(n, k, z_data);

    // Correlation matrix R = (1/(n-1)) * Z^T * Z
    let z_t = z.transpose();
    let mut r_corr = z_t.matmul(&z);
    let factor = 1.0 / ((n - 1) as f64);
    for val in &mut r_corr.data {
        *val *= factor;
    }

    // Invert the correlation matrix via QR: A = QR implies A^-1 = R^-1 * Q^T.
    // (Cholesky would be cheaper for a symmetric positive-definite matrix,
    // but QR reuses the existing upper-triangular inversion.)
    let (q_corr, r_corr_tri) = r_corr.qr();

    let r_inv = match r_corr_tri.invert_upper_triangular() {
        Some(inv) => inv.matmul(&q_corr.transpose()),
        None => {
            return names
                .iter()
                .skip(1)
                .map(|name| VifResult {
                    variable: name.clone(),
                    vif: f64::INFINITY,
                    rsquared: 1.0,
                    interpretation: "Perfect multicollinearity (singular matrix)".to_string(),
                })
                .collect();
        }
    };

    // The j-th diagonal element of R^-1 is the VIF for predictor j.
    let mut results = vec![];
    for j in 0..k {
        let vif = r_inv.get(j, j);
        // Numerical noise can push the diagonal slightly below 1; clamp it.
        let vif = if vif < 1.0 { 1.0 } else { vif };
        let rsquared = 1.0 - 1.0 / vif;

        let interpretation = if vif.is_infinite() {
            "Perfect multicollinearity".to_string()
        } else if vif > 10.0 {
            "High multicollinearity - consider removing".to_string()
        } else if vif > 5.0 {
            "Moderate multicollinearity".to_string()
        } else {
            "Low multicollinearity".to_string()
        };

        results.push(VifResult {
            variable: names[j + 1].clone(),
            vif,
            rsquared,
            interpretation,
        });
    }

    results
}

// ============================================================================
// OLS Regression
// ============================================================================
//
// Ordinary Least Squares regression implementation using QR decomposition.

/// Performs Ordinary Least Squares regression using QR decomposition.
///
/// Uses a numerically stable QR decomposition approach to solve the normal
/// equations. Returns comprehensive statistics including coefficients,
/// standard errors, t-statistics, p-values, and diagnostic measures.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x_vars` - Predictor variables (each of length n)
/// * `variable_names` - Names for variables (including intercept)
///
/// # Returns
///
/// A [`RegressionOutput`] containing all regression statistics and diagnostics.
///
/// # Errors
///
/// Returns [`Error::InsufficientData`] if n ≤ k + 1.
/// Returns [`Error::SingularMatrix`] if perfect multicollinearity exists.
/// Returns [`Error::InvalidInput`] if predictor lengths do not match `y` or if
/// computed coefficients are NaN.
///
/// # Example
///
/// ```
/// # use linreg_core::core::ols_regression;
/// # use linreg_core::Error;
/// let y = vec![2.5, 3.7, 4.2, 5.1, 6.3, 7.0];
/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let x2 = vec![2.0, 4.0, 5.0, 4.0, 3.0, 2.0];
/// let names = vec![
///     "Intercept".to_string(),
///     "Temperature".to_string(),
///     "Pressure".to_string(),
/// ];
///
/// let result = ols_regression(&y, &[x1, x2], &names)?;
/// println!("R-squared: {}", result.r_squared);
/// # Ok::<(), Error>(())
/// ```
pub fn ols_regression(
    y: &[f64],
    x_vars: &[Vec<f64>],
    variable_names: &[String],
) -> Result<RegressionOutput> {
    let n = y.len();
    let k = x_vars.len();
    let p = k + 1; // number of model columns, including the intercept

    // Validate inputs: we need more observations than parameters.
    if n <= k + 1 {
        return Err(Error::InsufficientData { required: k + 2, available: n });
    }
    // Every predictor must have one value per observation.
    if x_vars.iter().any(|x_var| x_var.len() != n) {
        return Err(Error::InvalidInput(
            "All predictors must have the same length as y".to_string(),
        ));
    }

    // Pad variable names up to k + 1 entries (intercept plus one per predictor).
    let mut names = variable_names.to_vec();
    while names.len() <= k {
        names.push(format!("X{}", names.len()));
    }

    // Build the design matrix with an intercept column followed by predictors.
    let mut x_data = vec![0.0; n * p];
    for row in 0..n {
        x_data[row * p] = 1.0; // intercept
        for (col, x_var) in x_vars.iter().enumerate() {
            x_data[row * p + col + 1] = x_var[row];
        }
    }

    let x_matrix = Matrix::new(n, p, x_data);

    // QR decomposition: X = QR.
    let (q, r) = x_matrix.qr();

    // Solve R * beta = Q^T * y. First extract the upper p x p block of R.
    let mut r_upper = Matrix::zeros(p, p);
    for i in 0..p {
        for j in 0..p {
            r_upper.set(i, j, r.get(i, j));
        }
    }

    // Q^T * y, keeping only the first p elements.
    let q_t = q.transpose();
    let qty = q_t.mul_vec(y);
    let rhs_vec = qty[0..p].to_vec();
    let rhs_mat = Matrix::new(p, 1, rhs_vec); // column vector

    // Invert R_upper; failure indicates perfect multicollinearity.
    let r_inv = match r_upper.invert_upper_triangular() {
        Some(inv) => inv,
        None => return Err(Error::SingularMatrix),
    };

    let beta_mat = r_inv.matmul(&rhs_mat);
    let beta = beta_mat.data;

    // Validate coefficients
    if beta.iter().any(|&b| b.is_nan()) {
        return Err(Error::InvalidInput("Coefficients contain NaN".to_string()));
    }

    // Compute predictions and residuals
    let predictions = x_matrix.mul_vec(&beta);
    let residuals = vec_sub(y, &predictions);

    // Compute sums of squares
    let y_mean = vec_mean(y);
    let ss_total: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
    let ss_residual: f64 = residuals.iter().map(|&r| r.powi(2)).sum();
    let ss_regression = ss_total - ss_residual;

    // R-squared and adjusted R-squared (undefined when y is constant)
    let r_squared = if ss_total == 0.0 {
        f64::NAN
    } else {
        1.0 - ss_residual / ss_total
    };

    let adj_r_squared = if ss_total == 0.0 {
        f64::NAN
    } else {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n - k - 1) as f64)
    };

    // Mean squared error and standard error of the regression
    let df = n - k - 1;
    let mse = ss_residual / df as f64;
    let std_error = mse.sqrt();

    // Coefficient covariance: (X'X)^-1 = R^-1 * (R^-1)^T
    let xtx_inv = r_inv.matmul(&r_inv.transpose());

    let mut std_errors = vec![0.0; p];
    for i in 0..p {
        std_errors[i] = (xtx_inv.get(i, i) * mse).sqrt();
        if std_errors[i].is_nan() {
            return Err(Error::InvalidInput(format!(
                "Standard error for coefficient {} is NaN",
                i
            )));
        }
    }

    // t-statistics and two-tailed p-values
    let t_stats: Vec<f64> = beta.iter().zip(&std_errors).map(|(&b, &se)| b / se).collect();
    let p_values: Vec<f64> = t_stats.iter().map(|&t| two_tailed_p_value(t, df as f64)).collect();

    // 95% confidence intervals
    let alpha = 0.05;
    let t_critical = t_critical_quantile(df as f64, alpha);

    let conf_int_lower: Vec<f64> =
        beta.iter().zip(&std_errors).map(|(&b, &se)| b - t_critical * se).collect();
    let conf_int_upper: Vec<f64> =
        beta.iter().zip(&std_errors).map(|(&b, &se)| b + t_critical * se).collect();

    // Leverage (diagonal of the hat matrix)
    let leverage = compute_leverage(&x_matrix, &xtx_inv);

    // Standardized residuals: r_i / (s * sqrt(1 - h_i)), guarded against
    // near-unit leverage where the denominator would vanish.
    let standardized_residuals: Vec<f64> = residuals
        .iter()
        .zip(&leverage)
        .map(|(&r, &h)| {
            let factor = (1.0 - h).max(MIN_LEVERAGE_DENOM).sqrt();
            let denom = std_error * factor;
            if denom > MIN_LEVERAGE_DENOM { r / denom } else { 0.0 }
        })
        .collect();

    // F-statistic for overall model significance
    let f_statistic = (ss_regression / k as f64) / mse;
    let f_p_value = f_p_value(f_statistic, k as f64, df as f64);

    // Variance Inflation Factors
    let vif = calculate_vif(x_vars, &names, n);

    Ok(RegressionOutput {
        coefficients: beta,
        std_errors,
        t_stats,
        p_values,
        conf_int_lower,
        conf_int_upper,
        r_squared,
        adj_r_squared,
        f_statistic,
        f_p_value,
        mse,
        std_error,
        residuals,
        standardized_residuals,
        predictions,
        leverage,
        vif,
        n,
        k,
        df,
        variable_names: names,
    })
}
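
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // A minimal smoke-test sketch: a near-linear dataset (y ≈ 2 + 3x plus
    // small noise) should recover the intercept and slope and report a high
    // R-squared. The tolerances are illustrative, not calibrated bounds.
    #[test]
    fn recovers_simple_linear_fit() {
        let y = vec![5.1, 7.9, 11.2, 13.8, 17.1, 19.9];
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let names = vec!["Intercept".to_string(), "X".to_string()];

        let out = ols_regression(&y, &[x], &names).expect("regression should succeed");

        assert!((out.coefficients[0] - 2.0).abs() < 0.2);
        assert!((out.coefficients[1] - 3.0).abs() < 0.1);
        assert!(out.r_squared > 0.99);
        assert_eq!(out.df, 4); // n - k - 1 = 6 - 1 - 1
    }
}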