Skip to main content

u_analytics/
regression.rs

1//! Regression analysis.
2//!
3//! Simple and multiple linear regression with OLS, R², coefficient testing,
4//! ANOVA, VIF, and residual diagnostics.
5//!
6//! # Examples
7//!
8//! ```
9//! use u_analytics::regression::simple_linear_regression;
10//!
11//! let x = [1.0, 2.0, 3.0, 4.0, 5.0];
12//! let y = [2.1, 3.9, 6.1, 7.9, 10.1];
13//! let result = simple_linear_regression(&x, &y).unwrap();
14//! assert!((result.slope - 2.0).abs() < 0.1);
15//! assert!((result.intercept - 0.1).abs() < 0.2);
16//! assert!(result.r_squared > 0.99);
17//! ```
18
19use u_numflow::matrix::Matrix;
20use u_numflow::special;
21use u_numflow::stats;
22
/// Result of a simple linear regression: y = intercept + slope · x.
///
/// Produced by `simple_linear_regression`. All inference statistics use
/// n − 2 residual degrees of freedom.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleRegressionResult {
    /// Slope coefficient (β₁).
    pub slope: f64,
    /// Intercept (β₀).
    pub intercept: f64,
    /// Coefficient of determination (R²).
    pub r_squared: f64,
    /// Adjusted R² = 1 - (1-R²)(n-1)/(n-2).
    pub adjusted_r_squared: f64,
    /// Standard error of the slope: √(MSE / Σ(xᵢ - x̄)²).
    pub slope_se: f64,
    /// Standard error of the intercept: √(MSE · (1/n + x̄² / Σ(xᵢ - x̄)²)).
    pub intercept_se: f64,
    /// t-statistic for slope (H₀: β₁ = 0).
    pub slope_t: f64,
    /// t-statistic for intercept (H₀: β₀ = 0).
    pub intercept_t: f64,
    /// Two-sided p-value for the slope (Student's t, n − 2 df).
    pub slope_p: f64,
    /// Two-sided p-value for the intercept (Student's t, n − 2 df).
    pub intercept_p: f64,
    /// Residual standard error (√(SSE/(n-2))).
    pub residual_se: f64,
    /// F-statistic (= t² for simple regression).
    pub f_statistic: f64,
    /// Upper-tail p-value of the F-statistic under F(1, n − 2).
    pub f_p_value: f64,
    /// Residuals (yᵢ - ŷᵢ).
    pub residuals: Vec<f64>,
    /// Fitted values (ŷᵢ).
    pub fitted: Vec<f64>,
    /// Sample size.
    pub n: usize,
}
59
/// Result of a multiple linear regression: y = Xβ + ε.
///
/// Produced by `multiple_linear_regression`. Inference statistics use
/// n − p − 1 residual degrees of freedom.
#[derive(Debug, Clone, PartialEq)]
pub struct MultipleRegressionResult {
    /// Coefficient vector [β₀, β₁, ..., βₚ] (intercept first).
    pub coefficients: Vec<f64>,
    /// Standard errors of coefficients: √(MSE · [(X'X)⁻¹]ⱼⱼ), same order as `coefficients`.
    pub std_errors: Vec<f64>,
    /// t-statistics for each coefficient (H₀: βⱼ = 0).
    pub t_statistics: Vec<f64>,
    /// Two-sided p-values for each coefficient (Student's t, n − p − 1 df).
    pub p_values: Vec<f64>,
    /// Coefficient of determination (R²).
    pub r_squared: f64,
    /// Adjusted R² = 1 - (1-R²)(n-1)/(n-p-1).
    pub adjusted_r_squared: f64,
    /// F-statistic for overall significance: (SS_reg / p) / MSE.
    pub f_statistic: f64,
    /// Upper-tail p-value of the F-statistic under F(p, n − p − 1).
    pub f_p_value: f64,
    /// Residual standard error: √(SSE / (n − p − 1)).
    pub residual_se: f64,
    /// Residuals (yᵢ - ŷᵢ).
    pub residuals: Vec<f64>,
    /// Fitted values (ŷᵢ).
    pub fitted: Vec<f64>,
    /// VIF (Variance Inflation Factor) for each predictor (excludes intercept).
    ///
    /// `1.0` for every predictor when p < 2, `f64::INFINITY` under perfect
    /// multicollinearity, and `f64::NAN` when the auxiliary regression fails.
    pub vif: Vec<f64>,
    /// Sample size.
    pub n: usize,
    /// Number of predictors (excluding intercept).
    pub p: usize,
}
92
93// ---------------------------------------------------------------------------
94// Simple Linear Regression
95// ---------------------------------------------------------------------------
96
97/// Computes simple linear regression (OLS closed-form).
98///
99/// # Algorithm
100///
101/// β₁ = cov(x,y) / var(x)
102/// β₀ = ȳ - β₁·x̄
103///
104/// # Returns
105///
106/// `None` if fewer than 3 observations, slices differ in length, x has zero
107/// variance, or inputs contain non-finite values.
108///
109/// # References
110///
111/// Draper & Smith (1998). "Applied Regression Analysis", 3rd edition.
112///
113/// # Examples
114///
115/// ```
116/// use u_analytics::regression::simple_linear_regression;
117///
118/// let x = [1.0, 2.0, 3.0, 4.0, 5.0];
119/// let y = [2.0, 4.0, 6.0, 8.0, 10.0];
120/// let r = simple_linear_regression(&x, &y).unwrap();
121/// assert!((r.slope - 2.0).abs() < 1e-10);
122/// assert!((r.intercept).abs() < 1e-10);
123/// assert!((r.r_squared - 1.0).abs() < 1e-10);
124/// ```
125pub fn simple_linear_regression(x: &[f64], y: &[f64]) -> Option<SimpleRegressionResult> {
126    let n = x.len();
127    if n < 3 || n != y.len() {
128        return None;
129    }
130    if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
131        return None;
132    }
133
134    let x_mean = stats::mean(x)?;
135    let y_mean = stats::mean(y)?;
136    let x_var = stats::variance(x)?;
137    let cov = stats::covariance(x, y)?;
138
139    if x_var < 1e-300 {
140        return None; // zero variance in x
141    }
142
143    let slope = cov / x_var;
144    let intercept = y_mean - slope * x_mean;
145
146    // Fitted values and residuals
147    let fitted: Vec<f64> = x.iter().map(|&xi| intercept + slope * xi).collect();
148    let residuals: Vec<f64> = y
149        .iter()
150        .zip(fitted.iter())
151        .map(|(&yi, &fi)| yi - fi)
152        .collect();
153
154    // Sum of squares
155    let ss_res: f64 = residuals.iter().map(|r| r * r).sum();
156    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
157
158    let nf = n as f64;
159    let df_res = nf - 2.0;
160
161    let r_squared = if ss_tot > 1e-300 {
162        1.0 - ss_res / ss_tot
163    } else {
164        1.0
165    };
166    let adjusted_r_squared = 1.0 - (1.0 - r_squared) * (nf - 1.0) / df_res;
167
168    // Residual standard error
169    let mse = ss_res / df_res;
170    let residual_se = mse.sqrt();
171
172    // Standard errors of coefficients
173    let ss_x: f64 = x.iter().map(|&xi| (xi - x_mean).powi(2)).sum();
174    let slope_se = (mse / ss_x).sqrt();
175    let intercept_se = (mse * (1.0 / nf + x_mean * x_mean / ss_x)).sqrt();
176
177    // t-statistics
178    let slope_t = if slope_se > 1e-300 {
179        slope / slope_se
180    } else {
181        f64::INFINITY
182    };
183    let intercept_t = if intercept_se > 1e-300 {
184        intercept / intercept_se
185    } else {
186        f64::INFINITY
187    };
188
189    // p-values via t-distribution
190    let slope_p = 2.0 * (1.0 - special::t_distribution_cdf(slope_t.abs(), df_res));
191    let intercept_p = 2.0 * (1.0 - special::t_distribution_cdf(intercept_t.abs(), df_res));
192
193    // F-statistic (= t² for simple regression)
194    let f_statistic = slope_t * slope_t;
195    let f_p_value = if f_statistic.is_infinite() {
196        0.0
197    } else {
198        1.0 - special::f_distribution_cdf(f_statistic, 1.0, df_res)
199    };
200
201    Some(SimpleRegressionResult {
202        slope,
203        intercept,
204        r_squared,
205        adjusted_r_squared,
206        slope_se,
207        intercept_se,
208        slope_t,
209        intercept_t,
210        slope_p,
211        intercept_p,
212        residual_se,
213        f_statistic,
214        f_p_value,
215        residuals,
216        fitted,
217        n,
218    })
219}
220
221// ---------------------------------------------------------------------------
222// Multiple Linear Regression
223// ---------------------------------------------------------------------------
224
225/// Computes multiple linear regression via OLS (Cholesky solve).
226///
227/// # Arguments
228///
229/// * `predictors` — Slice of predictor variable slices. Each inner slice is
230///   one predictor's observations. All must have the same length.
231/// * `y` — Response variable observations.
232///
233/// # Algorithm
234///
235/// Constructs the design matrix X = [1 | x₁ | x₂ | ... | xₚ] (intercept column prepended),
236/// then solves the normal equations X'Xβ = X'y via Cholesky decomposition.
237///
238/// # Returns
239///
240/// `None` if n < p+2, predictor lengths differ, inputs contain non-finite
241/// values, or the system is singular.
242///
243/// # References
244///
245/// Draper & Smith (1998). "Applied Regression Analysis", 3rd edition.
246/// Montgomery, Peck & Vining (2012). "Introduction to Linear Regression Analysis", 5th edition.
247///
248/// # Examples
249///
250/// ```
251/// use u_analytics::regression::multiple_linear_regression;
252///
253/// let x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
254/// let x2 = [2.0, 1.0, 3.0, 2.0, 4.0, 3.0, 5.0, 4.0];
255/// let y  = [5.1, 5.0, 9.2, 8.9, 13.1, 12.0, 17.2, 15.9];
256/// let result = multiple_linear_regression(&[&x1, &x2], &y).unwrap();
257/// assert!(result.r_squared > 0.95);
258/// assert_eq!(result.coefficients.len(), 3); // intercept + 2 predictors
259/// ```
260pub fn multiple_linear_regression(
261    predictors: &[&[f64]],
262    y: &[f64],
263) -> Option<MultipleRegressionResult> {
264    let p = predictors.len(); // number of predictors
265    let n = y.len();
266
267    if p == 0 || n < p + 2 {
268        return None;
269    }
270
271    // Validate lengths and finite values
272    for pred in predictors {
273        if pred.len() != n {
274            return None;
275        }
276        if pred.iter().any(|v| !v.is_finite()) {
277            return None;
278        }
279    }
280    if y.iter().any(|v| !v.is_finite()) {
281        return None;
282    }
283
284    let ncols = p + 1; // intercept + predictors
285
286    // Build design matrix X (n × ncols, row-major)
287    let mut x_data = Vec::with_capacity(n * ncols);
288    for i in 0..n {
289        x_data.push(1.0); // intercept
290        for pred in predictors {
291            x_data.push(pred[i]);
292        }
293    }
294    let x_mat = Matrix::new(n, ncols, x_data).ok()?;
295
296    // X'X (ncols × ncols)
297    let xt = x_mat.transpose();
298    let xtx = xt.mul_mat(&x_mat).ok()?;
299
300    // X'y (ncols × 1)
301    let xty = xt.mul_vec(y).ok()?;
302
303    // Solve via Cholesky: β = (X'X)⁻¹ X'y
304    let coefficients = xtx.cholesky_solve(&xty).ok()?;
305
306    // Fitted values and residuals
307    let fitted = x_mat.mul_vec(&coefficients).ok()?;
308    let residuals: Vec<f64> = y
309        .iter()
310        .zip(fitted.iter())
311        .map(|(&yi, &fi)| yi - fi)
312        .collect();
313
314    // R² and adjusted R²
315    let y_mean = stats::mean(y)?;
316    let ss_res: f64 = residuals.iter().map(|r| r * r).sum();
317    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
318
319    let nf = n as f64;
320    let pf = p as f64;
321    let df_res = nf - pf - 1.0;
322
323    let r_squared = if ss_tot > 1e-300 {
324        1.0 - ss_res / ss_tot
325    } else {
326        1.0
327    };
328    let adjusted_r_squared = 1.0 - (1.0 - r_squared) * (nf - 1.0) / df_res;
329
330    // Residual standard error
331    let mse = ss_res / df_res;
332    let residual_se = mse.sqrt();
333
334    // Standard errors: SE(β) = sqrt(diag((X'X)⁻¹) · MSE)
335    let xtx_inv = xtx.inverse().ok()?;
336    let mut std_errors = Vec::with_capacity(ncols);
337    let mut t_statistics = Vec::with_capacity(ncols);
338    let mut p_values = Vec::with_capacity(ncols);
339
340    for (j, &coeff_j) in coefficients.iter().enumerate() {
341        let se = (xtx_inv.get(j, j) * mse).sqrt();
342        std_errors.push(se);
343        let t = if se > 1e-300 {
344            coeff_j / se
345        } else {
346            f64::INFINITY
347        };
348        t_statistics.push(t);
349        let pv = 2.0 * (1.0 - special::t_distribution_cdf(t.abs(), df_res));
350        p_values.push(pv);
351    }
352
353    // F-statistic: (SS_reg / p) / (SS_res / (n-p-1))
354    let ss_reg = ss_tot - ss_res;
355    let f_statistic = if pf > 0.0 && mse > 1e-300 {
356        (ss_reg / pf) / mse
357    } else {
358        0.0
359    };
360    let f_p_value = if f_statistic.is_infinite() || f_statistic.is_nan() {
361        0.0
362    } else {
363        1.0 - special::f_distribution_cdf(f_statistic, pf, df_res)
364    };
365
366    // VIF for each predictor: VIF_j = 1/(1 - R²_j) where R²_j is from
367    // regressing x_j on all other predictors
368    let vif = compute_vif(predictors);
369
370    Some(MultipleRegressionResult {
371        coefficients,
372        std_errors,
373        t_statistics,
374        p_values,
375        r_squared,
376        adjusted_r_squared,
377        f_statistic,
378        f_p_value,
379        residual_se,
380        residuals,
381        fitted,
382        vif,
383        n,
384        p,
385    })
386}
387
388/// Computes VIF for each predictor by regressing each on all others.
389fn compute_vif(predictors: &[&[f64]]) -> Vec<f64> {
390    let p = predictors.len();
391    if p < 2 {
392        return vec![1.0; p];
393    }
394
395    let mut vif = Vec::with_capacity(p);
396    for j in 0..p {
397        // Regress x_j on all other predictors
398        let y_j = predictors[j];
399        let others: Vec<&[f64]> = predictors
400            .iter()
401            .enumerate()
402            .filter(|&(i, _)| i != j)
403            .map(|(_, v)| *v)
404            .collect();
405
406        if let Some(result) = multiple_linear_regression(&others, y_j) {
407            let r2 = result.r_squared;
408            if r2 < 1.0 - 1e-15 {
409                vif.push(1.0 / (1.0 - r2));
410            } else {
411                vif.push(f64::INFINITY); // perfect multicollinearity
412            }
413        } else {
414            vif.push(f64::NAN);
415        }
416    }
417    vif
418}
419
420/// Predicts y values given new x data and a simple regression result.
421///
422/// # Examples
423///
424/// ```
425/// use u_analytics::regression::{simple_linear_regression, predict_simple};
426///
427/// let x = [1.0, 2.0, 3.0, 4.0, 5.0];
428/// let y = [2.0, 4.0, 6.0, 8.0, 10.0];
429/// let model = simple_linear_regression(&x, &y).unwrap();
430/// let pred = predict_simple(&model, &[6.0, 7.0]);
431/// assert!((pred[0] - 12.0).abs() < 1e-10);
432/// assert!((pred[1] - 14.0).abs() < 1e-10);
433/// ```
434pub fn predict_simple(model: &SimpleRegressionResult, x_new: &[f64]) -> Vec<f64> {
435    x_new
436        .iter()
437        .map(|&xi| model.intercept + model.slope * xi)
438        .collect()
439}
440
441/// Predicts y values given new predictor data and a multiple regression result.
442///
443/// # Arguments
444///
445/// * `model` — Multiple regression result.
446/// * `predictors_new` — Slice of predictor slices (same order as training).
447///
448/// Returns `None` if predictor count doesn't match or lengths differ.
449pub fn predict_multiple(
450    model: &MultipleRegressionResult,
451    predictors_new: &[&[f64]],
452) -> Option<Vec<f64>> {
453    if predictors_new.len() != model.p {
454        return None;
455    }
456    let n = predictors_new.first()?.len();
457    for pred in predictors_new {
458        if pred.len() != n {
459            return None;
460        }
461    }
462
463    let mut result = Vec::with_capacity(n);
464    for i in 0..n {
465        let mut yi = model.coefficients[0]; // intercept
466        for (j, pred) in predictors_new.iter().enumerate() {
467            yi += model.coefficients[j + 1] * pred[i];
468        }
469        result.push(yi);
470    }
471    Some(result)
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Simple regression tests
    // -----------------------------------------------------------------------

    #[test]
    fn simple_perfect_fit() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [3.0, 5.0, 7.0, 9.0, 11.0]; // y = 1 + 2x
        let r = simple_linear_regression(&x, &y).expect("should compute");
        assert!((r.slope - 2.0).abs() < 1e-10);
        assert!((r.intercept - 1.0).abs() < 1e-10);
        assert!((r.r_squared - 1.0).abs() < 1e-10);
    }

    #[test]
    fn simple_with_noise() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [2.1, 3.9, 6.1, 7.9, 10.1]; // y ≈ 2x + 0.1
        let r = simple_linear_regression(&x, &y).expect("should compute");
        assert!((r.slope - 2.0).abs() < 0.1);
        assert!(r.r_squared > 0.99);
        assert_eq!(r.residuals.len(), 5);
        assert_eq!(r.fitted.len(), 5);
    }

    #[test]
    fn simple_residuals_sum_to_zero() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [2.1, 3.9, 6.1, 7.9, 10.1];
        let r = simple_linear_regression(&x, &y).expect("should compute");
        let sum: f64 = r.residuals.iter().sum();
        assert!(sum.abs() < 1e-10, "residuals sum = {sum}");
    }

    #[test]
    fn simple_f_equals_t_squared() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [2.1, 3.9, 6.1, 7.9, 10.1];
        let r = simple_linear_regression(&x, &y).expect("should compute");
        assert!(
            (r.f_statistic - r.slope_t * r.slope_t).abs() < 1e-8,
            "F = {}, t² = {}",
            r.f_statistic,
            r.slope_t * r.slope_t
        );
    }

    #[test]
    fn simple_significant_slope() {
        let x: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let y: Vec<f64> = x.iter().map(|&xi| 3.0 + 2.0 * xi).collect();
        let r = simple_linear_regression(&x, &y).expect("should compute");
        assert!(r.slope_p < 1e-10, "slope p = {}", r.slope_p);
        assert!(r.f_p_value < 1e-10, "F p = {}", r.f_p_value);
    }

    #[test]
    fn simple_negative_slope() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [10.0, 8.0, 6.0, 4.0, 2.0]; // y = 12 - 2x
        let r = simple_linear_regression(&x, &y).expect("should compute");
        assert!((r.slope + 2.0).abs() < 1e-10);
        assert!((r.intercept - 12.0).abs() < 1e-10);
    }

    #[test]
    fn simple_edge_cases() {
        assert!(simple_linear_regression(&[1.0, 2.0], &[3.0, 4.0]).is_none()); // n < 3
        assert!(simple_linear_regression(&[1.0, 2.0, 3.0], &[4.0, 5.0]).is_none()); // mismatch
        assert!(simple_linear_regression(&[5.0, 5.0, 5.0], &[1.0, 2.0, 3.0]).is_none()); // zero var
        assert!(simple_linear_regression(&[1.0, f64::NAN, 3.0], &[4.0, 5.0, 6.0]).is_none());
    }

    /// Verifies exact OLS numeric reference from the spec.
    ///
    /// x = [1,2,3,4,5], y = [2,4,5,4,5]
    /// x̄=3, ȳ=4
    /// Σ(xᵢ-x̄)(yᵢ-ȳ) = 6, Σ(xᵢ-x̄)² = 10
    /// β̂₁ = 6/10 = 0.6, β̂₀ = 4 - 0.6·3 = 2.2
    /// SS_tot = 6, SS_res = 2.4, R² = 1 - 2.4/6 = 0.6
    ///
    /// Reference: Draper & Smith (1998), Applied Regression Analysis, 3rd ed.
    #[test]
    fn simple_numeric_reference_ols() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [2.0, 4.0, 5.0, 4.0, 5.0];
        let r = simple_linear_regression(&x, &y).expect("should compute");

        assert!(
            (r.slope - 0.6).abs() < 1e-10,
            "β̂₁ expected 0.6, got {}",
            r.slope
        );
        assert!(
            (r.intercept - 2.2).abs() < 1e-10,
            "β̂₀ expected 2.2, got {}",
            r.intercept
        );
        // R² is exact for this reference data; hold it to the same
        // tolerance as the coefficients and fitted values.
        assert!(
            (r.r_squared - 0.6).abs() < 1e-10,
            "R² expected 0.6, got {}",
            r.r_squared
        );

        // Verify fitted values
        // ŷ₁=2.8, ŷ₂=3.4, ŷ₃=4.0, ŷ₄=4.6, ŷ₅=5.2
        let expected_fitted = [2.8, 3.4, 4.0, 4.6, 5.2];
        for (i, (&fi, &ef)) in r.fitted.iter().zip(expected_fitted.iter()).enumerate() {
            assert!(
                (fi - ef).abs() < 1e-10,
                "ŷ_{} expected {}, got {}",
                i + 1,
                ef,
                fi
            );
        }
    }

    #[test]
    fn simple_adjusted_r_squared() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [2.1, 3.9, 6.1, 7.9, 10.1];
        let r = simple_linear_regression(&x, &y).expect("should compute");
        // Near-perfect (but noisy) fit: adjusted R² is penalized for the extra
        // parameter, so it is never above R², yet stays close to it.
        assert!(r.adjusted_r_squared <= r.r_squared);
        assert!(r.adjusted_r_squared > 0.98);
    }

    // -----------------------------------------------------------------------
    // Multiple regression tests
    // -----------------------------------------------------------------------

    #[test]
    fn multiple_perfect_fit() {
        // y = 1 + 2*x1 + 3*x2
        let x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let x2 = [2.0, 1.0, 3.0, 2.0, 4.0, 3.0, 5.0, 4.0, 6.0, 5.0];
        let y: Vec<f64> = x1
            .iter()
            .zip(x2.iter())
            .map(|(&a, &b)| 1.0 + 2.0 * a + 3.0 * b)
            .collect();
        let r = multiple_linear_regression(&[&x1, &x2], &y).expect("should compute");

        assert!(
            (r.coefficients[0] - 1.0).abs() < 1e-8,
            "β₀ = {}",
            r.coefficients[0]
        );
        assert!(
            (r.coefficients[1] - 2.0).abs() < 1e-8,
            "β₁ = {}",
            r.coefficients[1]
        );
        assert!(
            (r.coefficients[2] - 3.0).abs() < 1e-8,
            "β₂ = {}",
            r.coefficients[2]
        );
        assert!((r.r_squared - 1.0).abs() < 1e-8);
    }

    #[test]
    fn multiple_with_noise() {
        let x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x2 = [2.0, 1.0, 3.0, 2.0, 4.0, 3.0, 5.0, 4.0];
        let y = [5.1, 5.0, 9.2, 8.9, 13.1, 12.0, 17.2, 15.9];
        let r = multiple_linear_regression(&[&x1, &x2], &y).expect("should compute");
        assert!(r.r_squared > 0.95);
        assert_eq!(r.coefficients.len(), 3);
        assert_eq!(r.std_errors.len(), 3);
        assert_eq!(r.residuals.len(), 8);
        assert_eq!(r.vif.len(), 2);
    }

    #[test]
    fn multiple_residuals_sum_to_zero() {
        let x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x2 = [2.0, 1.0, 3.0, 2.0, 4.0, 3.0, 5.0, 4.0];
        let y = [5.1, 5.0, 9.2, 8.9, 13.1, 12.0, 17.2, 15.9];
        let r = multiple_linear_regression(&[&x1, &x2], &y).expect("should compute");
        let sum: f64 = r.residuals.iter().sum();
        assert!(sum.abs() < 1e-8, "residuals sum = {sum}");
    }

    #[test]
    fn multiple_single_predictor_matches_simple() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let y: Vec<f64> = x
            .iter()
            .map(|&xi| 3.0 + 2.5 * xi + 0.1 * (xi - 5.0))
            .collect();

        let simple = simple_linear_regression(&x, &y).expect("simple");
        let multi = multiple_linear_regression(&[&x], &y).expect("multiple");

        assert!(
            (simple.slope - multi.coefficients[1]).abs() < 1e-8,
            "slope: {} vs {}",
            simple.slope,
            multi.coefficients[1]
        );
        assert!(
            (simple.intercept - multi.coefficients[0]).abs() < 1e-8,
            "intercept: {} vs {}",
            simple.intercept,
            multi.coefficients[0]
        );
        assert!((simple.r_squared - multi.r_squared).abs() < 1e-8);
    }

    #[test]
    fn multiple_edge_cases() {
        let x1 = [1.0, 2.0];
        let y = [3.0, 4.0];
        assert!(multiple_linear_regression(&[&x1], &y).is_none()); // n < p+2

        let x2 = [1.0, 2.0, 3.0];
        let y2 = [4.0, 5.0];
        assert!(multiple_linear_regression(&[&x2], &y2).is_none()); // length mismatch
    }

    #[test]
    fn multiple_vif_independent_predictors() {
        // Independent predictors → VIF ≈ 1
        let x1 = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0];
        let x2 = [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let y: Vec<f64> = x1
            .iter()
            .zip(x2.iter())
            .map(|(&a, &b)| 1.0 + 2.0 * a + 3.0 * b)
            .collect();
        let r = multiple_linear_regression(&[&x1, &x2], &y).expect("should compute");
        for (i, &v) in r.vif.iter().enumerate() {
            assert!(
                v < 2.0,
                "VIF[{i}] = {v}, expected near 1.0 for independent predictors"
            );
        }
    }

    #[test]
    fn multiple_vif_correlated_predictors() {
        // Highly but not perfectly correlated predictors → higher VIF
        let x1: Vec<f64> = (0..20).map(|i| i as f64).collect();
        // Add small perturbation to break perfect collinearity
        let noise = [
            0.1, -0.2, 0.3, -0.1, 0.2, -0.3, 0.1, -0.1, 0.2, -0.2, 0.3, -0.1, 0.1, -0.2, 0.3, -0.3,
            0.1, -0.1, 0.2, -0.2,
        ];
        let x2: Vec<f64> = x1
            .iter()
            .zip(noise.iter())
            .map(|(&v, &n)| v * 0.9 + 1.0 + n)
            .collect();
        let y: Vec<f64> = x1
            .iter()
            .zip(x2.iter())
            .map(|(&a, &b)| 1.0 + a + b)
            .collect();
        let r = multiple_linear_regression(&[&x1, &x2], &y).expect("should compute");
        // VIF > 5 for highly correlated predictors
        assert!(
            r.vif[0] > 5.0,
            "VIF[0] = {}, expected > 5.0 for correlated predictors",
            r.vif[0]
        );
    }

    // -----------------------------------------------------------------------
    // Prediction tests
    // -----------------------------------------------------------------------

    #[test]
    fn predict_simple_basic() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [3.0, 5.0, 7.0, 9.0, 11.0]; // y = 1 + 2x
        let model = simple_linear_regression(&x, &y).expect("should compute");
        let pred = predict_simple(&model, &[6.0, 7.0]);
        assert!((pred[0] - 13.0).abs() < 1e-10);
        assert!((pred[1] - 15.0).abs() < 1e-10);
    }

    #[test]
    fn predict_multiple_basic() {
        let x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let x2 = [2.0, 1.0, 3.0, 2.0, 4.0, 3.0, 5.0, 4.0, 6.0, 5.0];
        let y: Vec<f64> = x1
            .iter()
            .zip(x2.iter())
            .map(|(&a, &b)| 1.0 + 2.0 * a + 3.0 * b)
            .collect();
        let model = multiple_linear_regression(&[&x1, &x2], &y).expect("should compute");

        let new_x1 = [11.0];
        let new_x2 = [6.0];
        let pred = predict_multiple(&model, &[&new_x1, &new_x2]).expect("should predict");
        let expected = 1.0 + 2.0 * 11.0 + 3.0 * 6.0;
        assert!(
            (pred[0] - expected).abs() < 1e-6,
            "pred = {}, expected = {}",
            pred[0],
            expected
        );
    }

    #[test]
    fn predict_multiple_wrong_predictors() {
        let x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let x2 = [2.0, 1.0, 3.0, 2.0, 4.0, 3.0, 5.0, 4.0, 6.0, 5.0];
        let y: Vec<f64> = x1.iter().zip(x2.iter()).map(|(&a, &b)| a + b).collect();
        let model = multiple_linear_regression(&[&x1, &x2], &y).expect("should compute");

        // Wrong number of predictors
        assert!(predict_multiple(&model, &[&[1.0]]).is_none());
    }
}
795
// Property-based tests (proptest) for invariants that must hold on arbitrary
// finite inputs. Inputs that the regressions reject (`None`) are skipped.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // R² must lie in [0, 1] up to a small numerical slack for an OLS fit
        // with intercept. x and y are generated with equal lengths (5..=30).
        #[test]
        fn simple_r_squared_bounded(
            data in proptest::collection::vec(-1e3_f64..1e3, 5..=30)
                .prop_flat_map(|x| {
                    let n = x.len();
                    (Just(x), proptest::collection::vec(-1e3_f64..1e3, n..=n))
                })
        ) {
            let (x, y) = data;
            if let Some(r) = simple_linear_regression(&x, &y) {
                prop_assert!(r.r_squared >= -0.01 && r.r_squared <= 1.01,
                    "R² = {}", r.r_squared);
            }
        }

        // Normal-equation check: residuals are orthogonal to x. The dot product
        // is normalized by both vector norms so the tolerance is scale-free.
        // Near-zero norms (exact fits, all-zero x) are skipped.
        #[test]
        fn simple_residuals_orthogonal_to_x(
            data in proptest::collection::vec(-1e3_f64..1e3, 5..=30)
                .prop_flat_map(|x| {
                    let n = x.len();
                    (Just(x), proptest::collection::vec(-1e3_f64..1e3, n..=n))
                })
        ) {
            let (x, y) = data;
            if let Some(r) = simple_linear_regression(&x, &y) {
                // Σ(xᵢ · eᵢ) should be near zero (OLS normal equation)
                let dot: f64 = x.iter().zip(r.residuals.iter()).map(|(&xi, &ei)| xi * ei).sum();
                let norm = r.residuals.iter().map(|e| e * e).sum::<f64>().sqrt();
                let x_norm = x.iter().map(|xi| xi * xi).sum::<f64>().sqrt();
                if norm > 1e-10 && x_norm > 1e-10 {
                    prop_assert!((dot / (norm * x_norm)).abs() < 1e-6,
                        "residuals not orthogonal to x: dot={dot}");
                }
            }
        }

        // Same R² bound for the two-predictor model. The three independently
        // generated vectors are truncated to their common length.
        #[test]
        fn multiple_r_squared_bounded(
            x1 in proptest::collection::vec(-1e3_f64..1e3, 8..=20),
            x2_seed in proptest::collection::vec(-1e3_f64..1e3, 8..=20),
            y_seed in proptest::collection::vec(-1e3_f64..1e3, 8..=20),
        ) {
            let n = x1.len().min(x2_seed.len()).min(y_seed.len());
            let x2 = &x2_seed[..n];
            let y = &y_seed[..n];
            let x1 = &x1[..n];
            if let Some(r) = multiple_linear_regression(&[x1, x2], y) {
                prop_assert!(r.r_squared >= -0.01 && r.r_squared <= 1.01,
                    "R² = {}", r.r_squared);
            }
        }
    }
}