// linreg_core/weighted_regression/wls.rs
1//! Weighted Least Squares (WLS) regression
2//!
3//! This module provides WLS regression using the weighted least squares solver
4//! from the LOESS module. WLS is useful when:
5//! - Observations have different precision/variances (heteroscedasticity)
6//! - You want to incorporate robustness weights from a previous fit
7//! - Certain observations should be given more influence
8//!
9//! The output format matches R's `lm()` function with weights, providing:
10//! - Coefficient estimates with standard errors, t-values, and p-values
11//! - F-statistic and p-value for overall model significance
12//! - Residual standard error, R², adjusted R²
13
use crate::{
    core::{f_p_value, t_critical_quantile},
    distributions::student_t_cdf,
    error::{Error, Result},
    impl_serialization,
    linalg::Matrix,
    serialization::types::ModelType,
};
use serde::{Deserialize, Serialize};
23
/// WLS regression result
///
/// Contains the fitted coefficients and comprehensive model fit statistics
/// matching R's `summary(lm(y ~ x, weights=w))` output.
///
/// All coefficient vectors (`coefficients`, `standard_errors`, `t_statistics`,
/// `p_values`, and the confidence-interval bounds) have length `k + 1`, with
/// the intercept at index 0 followed by one entry per predictor, in the order
/// the predictors were supplied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WlsFit {
    // ============================================================
    // Coefficient Statistics (matching R's coefficients table)
    // ============================================================
    /// Coefficient values (including intercept as first element)
    pub coefficients: Vec<f64>,

    /// Standard errors of the coefficients
    pub standard_errors: Vec<f64>,

    /// t-statistics for coefficient significance tests (estimate / SE)
    pub t_statistics: Vec<f64>,

    /// Two-tailed p-values for coefficients (Student's t, `df_residuals` df)
    pub p_values: Vec<f64>,

    /// Lower bounds of 95% confidence intervals for coefficients
    pub conf_int_lower: Vec<f64>,

    /// Upper bounds of 95% confidence intervals for coefficients
    pub conf_int_upper: Vec<f64>,

    // ============================================================
    // Model Fit Statistics
    // ============================================================
    /// R-squared (coefficient of determination)
    pub r_squared: f64,

    /// Adjusted R-squared
    pub adj_r_squared: f64,

    /// F-statistic for overall model significance
    pub f_statistic: f64,

    /// p-value for F-statistic
    pub f_p_value: f64,

    /// Residual standard error (sigma-hat estimate)
    pub residual_std_error: f64,

    /// Degrees of freedom for residuals (n - k - 1); stored as `isize`
    /// so the n - p subtraction cannot underflow before validation.
    pub df_residuals: isize,

    /// Degrees of freedom for the model (k, the number of predictors)
    pub df_model: isize,

    // ============================================================
    // Predictions and Diagnostics
    // ============================================================
    /// Fitted values (predicted values)
    pub fitted_values: Vec<f64>,

    /// Residuals (y - ŷ), one per observation
    pub residuals: Vec<f64>,

    /// Mean squared error
    pub mse: f64,

    /// Root mean squared error
    pub rmse: f64,

    /// Mean absolute error
    pub mae: f64,

    // ============================================================
    // Sample Information
    // ============================================================
    /// Number of observations
    pub n: usize,

    /// Number of predictors (excluding intercept)
    pub k: usize,
}
102
103/// Perform weighted least squares regression
104///
105/// Fits a linear model using weighted least squares, where each observation
106/// can have a different weight. The output format matches R's `lm()` function
107/// with the `weights` parameter, providing comprehensive statistics including
108/// coefficient standard errors, t-statistics, p-values, and F-test.
109///
110/// # Arguments
111///
112/// * `y` - Response variable (n observations)
113/// * `x_vars` - Predictor variables (p vectors, each of length n)
114/// * `weights` - Observation weights (n weights, must be non-negative)
115///
116/// # Returns
117///
118/// `WlsFit` containing coefficients, fitted values, and comprehensive fit statistics
119///
120/// # Errors
121///
122/// - `Error::InsufficientData` if n <= k + 1
123/// - `Error::InvalidInput` if weights contain negative values or dimensions don't match
124/// - `Error::SingularMatrix` if the design matrix is singular even with weighting
125///
126/// # Example
127///
128/// ```
129/// use linreg_core::weighted_regression::{wls_regression, WlsFit};
130///
131/// let y = vec![2.0, 3.0, 4.0, 5.0, 6.0];
132/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
133/// let weights = vec![1.0, 1.0, 1.0, 1.0, 1.0]; // Equal weights = OLS
134///
135/// let fit: WlsFit = wls_regression(&y, &[x1], &weights)?;
136///
137/// // Access coefficients and statistics
138/// println!("Intercept: {} (SE: {}, t: {}, p: {})",
139///     fit.coefficients[0],
140///     fit.standard_errors[0],
141///     fit.t_statistics[0],
142///     fit.p_values[0]
143/// );
144/// println!("F-statistic: {} (p: {})", fit.f_statistic, fit.f_p_value);
145/// # Ok::<(), linreg_core::Error>(())
146/// ```
147pub fn wls_regression(
148    y: &[f64],
149    x_vars: &[Vec<f64>],
150    weights: &[f64],
151) -> Result<WlsFit> {
152    let n = y.len();
153    let k = x_vars.len();
154
155    // Validate minimum sample size
156    if n <= k + 1 {
157        return Err(Error::InsufficientData {
158            required: k + 2,
159            available: n,
160        });
161    }
162
163    // Validate dimensions
164    for (i, x_var) in x_vars.iter().enumerate() {
165        if x_var.len() != n {
166            return Err(Error::InvalidInput(format!(
167                "x[{}] has {} elements, expected {}",
168                i,
169                x_var.len(),
170                n
171            )));
172        }
173    }
174
175    if weights.len() != n {
176        return Err(Error::InvalidInput(format!(
177            "weights has {} elements, expected {}",
178            weights.len(),
179            n
180        )));
181    }
182
183    // Check for negative weights
184    for (i, &w) in weights.iter().enumerate() {
185        if w < 0.0 {
186            return Err(Error::InvalidInput(format!(
187                "weights[{}] is negative ({}), weights must be non-negative",
188                i, w
189            )));
190        }
191    }
192
193    // Check for zero total weight
194    let weight_sum: f64 = weights.iter().sum();
195    if weight_sum <= 0.0 {
196        return Err(Error::InvalidInput(
197            "Sum of weights is zero or negative".to_string()
198        ));
199    }
200
201    // Build design matrix: include intercept column
202    let mut x_data = Vec::with_capacity(n * (k + 1));
203    for i in 0..n {
204        x_data.push(1.0); // Intercept
205        for j in 0..k {
206            x_data.push(x_vars[j][i]);
207        }
208    }
209    let x = Matrix::new(n, k + 1, x_data);
210
211    // Call the WLS solver with decomposition info (single decomposition, no duplicate QR)
212    let decomp = crate::loess::wls::weighted_least_squares_with_decomposition(&x, y, weights)?;
213    let coefficients = decomp.coefficients;
214
215    // Compute fitted values
216    let fitted_values: Vec<f64> = (0..n)
217        .map(|i| {
218            let mut y_hat = coefficients[0]; // Intercept
219            for j in 0..k {
220                y_hat += coefficients[j + 1] * x_vars[j][i];
221            }
222            y_hat
223        })
224        .collect();
225
226    // Compute residuals
227    let residuals: Vec<f64> = y.iter().zip(fitted_values.iter())
228        .map(|(yi, y_hat)| yi - y_hat)
229        .collect();
230
231    // ============================================================
232    // Compute Degrees of Freedom
233    // ============================================================
234    let p = k + 1; // Number of coefficients (including intercept)
235    let df_residuals = n as isize - p as isize;
236    let df_model = k as isize;
237
238    if df_residuals <= 0 {
239        return Err(Error::InsufficientData {
240            required: p + 1,
241            available: n,
242        });
243    }
244
245    // ============================================================
246    // Compute MSE and Residual Standard Error
247    // ============================================================
248    // RSS = sum of squared residuals
249    let rss: f64 = residuals.iter().map(|r| r * r).sum();
250
251    // MSE (using n - p for unbiased estimate, like R)
252    let mse = rss / df_residuals as f64;
253
254    // Residual standard error (R's sigma-hat)
255    let residual_std_error = mse.sqrt();
256
257    // ============================================================
258    // Compute R-squared and Adjusted R-squared
259    // ============================================================
260    let ss_tot: f64 = {
261        let y_mean = y.iter().sum::<f64>() / n as f64;
262        y.iter().map(|yi| (yi - y_mean).powi(2)).sum()
263    };
264    let r_squared = if ss_tot > 0.0 {
265        1.0 - (rss / ss_tot)
266    } else {
267        0.0
268    };
269
270    let adj_r_squared = if df_residuals > 1 {
271        1.0 - ((1.0 - r_squared) * (n - 1) as f64 / df_residuals as f64)
272    } else {
273        r_squared
274    };
275
276    // ============================================================
277    // Compute Covariance Matrix of Coefficients
278    // ============================================================
279    // Uses decomposition info from the solver (no duplicate QR!)
280    let cov = if let Some(ref r_inv) = decomp.r_inv {
281        // QR path: Cov = MSE * S^-1 * R^-1 * (R^-1)' * S^-1
282        compute_covariance_from_qr(r_inv, &decomp.column_scales, mse, p)
283    } else if let Some((ref v, ref singular_values)) = decomp.svd_components {
284        // SVD path: Cov = MSE * V * diag(1/σᵢ²) * V'
285        compute_covariance_from_svd(v, singular_values, &decomp.column_scales, mse, p)
286    } else {
287        return Err(Error::SingularMatrix);
288    };
289
290    // ============================================================
291    // Extract Standard Errors (diagonal of covariance matrix)
292    // ============================================================
293    let mut standard_errors = Vec::with_capacity(p);
294    for i in 0..p {
295        let se = cov.get(i, i).sqrt();
296        standard_errors.push(se);
297    }
298
299    // ============================================================
300    // Compute t-statistics and p-values for coefficients
301    // ============================================================
302    let mut t_statistics = Vec::with_capacity(p);
303    let mut p_values = Vec::with_capacity(p);
304
305    for i in 0..p {
306        let t = coefficients[i] / standard_errors[i];
307        t_statistics.push(t);
308
309        // Two-tailed p-value using Student's t-distribution
310        let p = 2.0 * (1.0 - student_t_cdf(t.abs(), df_residuals as f64));
311        p_values.push(p);
312    }
313
314    // ============================================================
315    // Compute 95% Confidence Intervals
316    // ============================================================
317    let alpha = 0.05;
318    let t_critical = t_critical_quantile(df_residuals as f64, alpha);
319
320    let conf_int_lower: Vec<f64> = coefficients
321        .iter()
322        .zip(&standard_errors)
323        .map(|(&b, &se)| b - t_critical * se)
324        .collect();
325    let conf_int_upper: Vec<f64> = coefficients
326        .iter()
327        .zip(&standard_errors)
328        .map(|(&b, &se)| b + t_critical * se)
329        .collect();
330
331    // ============================================================
332    // Compute F-statistic and p-value for overall model
333    // ============================================================
334    // F = ((TSS - RSS) / k) / (RSS / (n - k - 1))
335    // where TSS is total sum of squares, RSS is residual sum of squares,
336    // and k is the number of predictors (excluding intercept)
337    let f_statistic = if ss_tot > rss && k > 0 {
338        ((ss_tot - rss) / k as f64) / (rss / df_residuals as f64)
339    } else {
340        0.0
341    };
342
343    let f_p_value = if f_statistic > 0.0 {
344        f_p_value(f_statistic, k as f64, df_residuals as f64)
345    } else {
346        1.0
347    };
348
349    // ============================================================
350    // Additional Error Metrics
351    // ============================================================
352    let rmse = mse.sqrt();
353    let mae = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;
354
355    Ok(WlsFit {
356        coefficients,
357        standard_errors,
358        t_statistics,
359        p_values,
360        conf_int_lower,
361        conf_int_upper,
362        r_squared,
363        adj_r_squared,
364        f_statistic,
365        f_p_value,
366        residual_std_error,
367        df_residuals,
368        df_model,
369        fitted_values,
370        residuals,
371        mse,
372        rmse,
373        mae,
374        n,
375        k,
376    })
377}
378
379/// Compute covariance matrix from QR decomposition
380///
381/// Formula: Cov(β_orig)_ij = MSE * Σ_l(R^-1_il * R^-1_jl) / (scales\[i\] * scales\[j\])
382fn compute_covariance_from_qr(
383    r_inv: &Matrix,
384    column_scales: &[f64],
385    mse: f64,
386    p: usize,
387) -> Matrix {
388    let mut cov = Matrix::zeros(p, p);
389    for i in 0..p {
390        for j in 0..p {
391            let mut sum = 0.0;
392            for l in 0..p {
393                sum += r_inv.get(i, l) * r_inv.get(j, l);
394            }
395            cov.set(i, j, mse * sum / (column_scales[i] * column_scales[j]));
396        }
397    }
398    cov
399}
400
401/// Compute covariance matrix from SVD decomposition
402///
403/// Formula: Cov(β) = MSE * V * diag(1/σᵢ²) * V'
404/// Then compensate for equilibration: divide by scales\[i\] * scales\[j\]
405fn compute_covariance_from_svd(
406    v: &Matrix,
407    singular_values: &[f64],
408    column_scales: &[f64],
409    mse: f64,
410    p: usize,
411) -> Matrix {
412    // Use same tolerance as svd_solve in linalg.rs: sigma[0] * 100 * epsilon
413    let max_sigma = singular_values.first().copied().unwrap_or(0.0);
414    let tol = if max_sigma > 0.0 {
415        max_sigma * 100.0 * f64::EPSILON
416    } else {
417        f64::EPSILON
418    };
419
420    let mut cov = Matrix::zeros(p, p);
421    for i in 0..p {
422        for j in 0..p {
423            let mut sum = 0.0;
424            for l in 0..singular_values.len().min(p) {
425                if singular_values[l] > tol {
426                    let inv_sigma_sq = 1.0 / (singular_values[l] * singular_values[l]);
427                    sum += v.get(i, l) * v.get(j, l) * inv_sigma_sq;
428                }
429            }
430            cov.set(i, j, mse * sum / (column_scales[i] * column_scales[j]));
431        }
432    }
433    cov
434}
435
436// ============================================================================
437// Model Serialization Traits
438// ============================================================================
439
440// Generate ModelSave and ModelLoad implementations using macro
441impl_serialization!(WlsFit, ModelType::WLS, "WLS");
442
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wls_equal_weights_matches_ols() {
        // Uniform weights reduce WLS to plain OLS.
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let predictor = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let uniform = vec![1.0; 5];

        let fit = wls_regression(&response, &[predictor], &uniform).unwrap();

        // y = x exactly, so intercept ≈ 0 and slope ≈ 1.
        assert!(fit.coefficients[0].abs() < 1e-10);
        assert!((fit.coefficients[1] - 1.0).abs() < 1e-10);
        assert_eq!(fit.k, 1);
        assert_eq!(fit.n, 5);

        // Coefficient-level statistics must all be present.
        assert_eq!(fit.standard_errors.len(), 2);
        assert_eq!(fit.t_statistics.len(), 2);
        assert_eq!(fit.p_values.len(), 2);
        assert!(fit.f_statistic > 0.0);
        assert!(fit.f_p_value < 0.05); // a perfect fit is highly significant
    }

    #[test]
    fn test_wls_with_weighted_data() {
        // The last observation is a gross outlier.
        let response = vec![2.0, 4.0, 6.0, 8.0, 100.0];
        let predictor = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        // Down-weighting the outlier should keep the slope near that of
        // the first four points...
        let downweighted = vec![1.0, 1.0, 1.0, 1.0, 0.01];
        let fit_down = wls_regression(&response, &[predictor.clone()], &downweighted).unwrap();

        // ...while up-weighting it should pull the slope upward.
        let upweighted = vec![1.0, 1.0, 1.0, 1.0, 10.0];
        let fit_up = wls_regression(&response, &[predictor], &upweighted).unwrap();

        assert!(fit_down.coefficients[1] < fit_up.coefficients[1]);
    }

    #[test]
    fn test_wls_negative_weight_error() {
        // Any negative weight must be rejected up front.
        let response = vec![1.0, 2.0, 3.0];
        let predictor = vec![1.0, 2.0, 3.0];
        let bad_weights = vec![1.0, -1.0, 1.0];

        assert!(wls_regression(&response, &[predictor], &bad_weights).is_err());
    }

    #[test]
    fn test_wls_multiple_predictors() {
        // The second predictor is deliberately not a linear function of the first.
        let response = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let first = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let second = vec![1.0, 4.0, 2.0, 5.0, 3.0];
        let uniform = vec![1.0; 5];

        let fit = wls_regression(&response, &[first, second], &uniform).unwrap();

        assert_eq!(fit.k, 2); // two predictors
        assert_eq!(fit.coefficients.len(), 3); // intercept + two slopes
        assert_eq!(fit.fitted_values.len(), 5);
        assert_eq!(fit.standard_errors.len(), 3);
        assert_eq!(fit.t_statistics.len(), 3);
        assert_eq!(fit.p_values.len(), 3);
    }

    #[test]
    fn test_wls_insufficient_data() {
        // n = 2 with k = 2 predictors needs at least k + 2 = 4 observations.
        let response = vec![1.0, 2.0];
        let first = vec![1.0, 2.0];
        let second = vec![0.5, 1.0];
        let uniform = vec![1.0, 1.0];

        assert!(wls_regression(&response, &[first, second], &uniform).is_err());
    }

    #[test]
    fn test_wls_statistics_completeness() {
        // Every field of WlsFit should be populated and within range.
        let response = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let predictor = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let uniform = vec![1.0; 5];

        let fit = wls_regression(&response, &[predictor], &uniform).unwrap();

        assert_eq!(fit.coefficients.len(), 2);
        assert_eq!(fit.standard_errors.len(), 2);
        assert_eq!(fit.t_statistics.len(), 2);
        assert_eq!(fit.p_values.len(), 2);
        assert!((0.0..=1.0).contains(&fit.r_squared));
        assert!((0.0..=1.0).contains(&fit.adj_r_squared));
        assert!(fit.f_statistic >= 0.0);
        assert!((0.0..=1.0).contains(&fit.f_p_value));
        assert!(fit.residual_std_error >= 0.0);
        assert_eq!(fit.df_residuals, 3); // n=5 minus p=2 coefficients
        assert_eq!(fit.df_model, 1);
        assert_eq!(fit.fitted_values.len(), 5);
        assert_eq!(fit.residuals.len(), 5);
        assert!(fit.mse >= 0.0);
        assert!(fit.rmse >= 0.0);
        assert!(fit.mae >= 0.0);
        assert_eq!(fit.n, 5);
        assert_eq!(fit.k, 1);
    }

    #[test]
    fn test_wls_zero_sum_weights_error() {
        // All-zero weights leave the problem undefined and must error.
        let response = vec![1.0, 2.0, 3.0];
        let predictor = vec![1.0, 2.0, 3.0];
        let zero_weights = vec![0.0, 0.0, 0.0];

        assert!(wls_regression(&response, &[predictor], &zero_weights).is_err());
    }

    #[test]
    fn test_wls_svd_fallback_computes_standard_errors() {
        // Perfectly collinear predictors (x2 = 2*x1) force the solver off
        // the QR fast path.
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let first = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let second = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let uniform = vec![1.0; 5];

        match wls_regression(&response, &[first, second], &uniform) {
            // If a fit is produced, the SEs must come out finite
            // (via the SVD covariance path).
            Ok(fit) => {
                for se in &fit.standard_errors {
                    assert!(se.is_finite(), "Standard error should be finite, got {}", se);
                }
            }
            // Refusing to fit perfectly collinear data is also acceptable.
            Err(_) => {}
        }
    }
}