// linreg_core/weighted_regression/wls.rs
1//! Weighted Least Squares (WLS) regression
2//!
3//! This module provides WLS regression using the weighted least squares solver
4//! from the LOESS module. WLS is useful when:
5//! - Observations have different precision/variances (heteroscedasticity)
6//! - You want to incorporate robustness weights from a previous fit
7//! - Certain observations should be given more influence
8//!
9//! The output format matches R's `lm()` function with weights, providing:
10//! - Coefficient estimates with standard errors, t-values, and p-values
11//! - F-statistic and p-value for overall model significance
12//! - Residual standard error, R², adjusted R²
13
14use crate::{
15 core::{f_p_value, t_critical_quantile},
16 distributions::student_t_cdf,
17 error::{Error, Result},
18 linalg::Matrix,
19 serialization::types::ModelType,
20 impl_serialization,
21};
22use serde::{Deserialize, Serialize};
23
/// WLS regression result
///
/// Contains the fitted coefficients and comprehensive model fit statistics
/// matching R's `summary(lm(y ~ x, weights=w))` output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WlsFit {
    // ============================================================
    // Coefficient Statistics (matching R's coefficients table)
    // ============================================================
    /// Coefficient values (including intercept as first element)
    pub coefficients: Vec<f64>,

    /// Standard errors of the coefficients (same order as `coefficients`)
    pub standard_errors: Vec<f64>,

    /// t-statistics for coefficient significance tests (coefficient / standard error)
    pub t_statistics: Vec<f64>,

    /// Two-tailed p-values for coefficients (Student's t with `df_residuals` df)
    pub p_values: Vec<f64>,

    /// Lower bounds of 95% confidence intervals for coefficients
    pub conf_int_lower: Vec<f64>,

    /// Upper bounds of 95% confidence intervals for coefficients
    pub conf_int_upper: Vec<f64>,

    // ============================================================
    // Model Fit Statistics
    // ============================================================
    /// R-squared (coefficient of determination)
    pub r_squared: f64,

    /// Adjusted R-squared (penalized for the number of predictors)
    pub adj_r_squared: f64,

    /// F-statistic for overall model significance
    pub f_statistic: f64,

    /// p-value for F-statistic
    pub f_p_value: f64,

    /// Residual standard error (sigma-hat estimate; square root of `mse`)
    pub residual_std_error: f64,

    /// Degrees of freedom for residuals (n minus number of coefficients)
    pub df_residuals: isize,

    /// Degrees of freedom for the model (number of predictors, excluding intercept)
    pub df_model: isize,

    // ============================================================
    // Predictions and Diagnostics
    // ============================================================
    /// Fitted values (predicted values, one per observation)
    pub fitted_values: Vec<f64>,

    /// Residuals (y - ŷ, one per observation)
    pub residuals: Vec<f64>,

    /// Mean squared error (RSS / `df_residuals`)
    pub mse: f64,

    /// Root mean squared error (square root of `mse`)
    pub rmse: f64,

    /// Mean absolute error of the residuals
    pub mae: f64,

    // ============================================================
    // Sample Information
    // ============================================================
    /// Number of observations
    pub n: usize,

    /// Number of predictors (excluding intercept)
    pub k: usize,
}
102
103/// Perform weighted least squares regression
104///
105/// Fits a linear model using weighted least squares, where each observation
106/// can have a different weight. The output format matches R's `lm()` function
107/// with the `weights` parameter, providing comprehensive statistics including
108/// coefficient standard errors, t-statistics, p-values, and F-test.
109///
110/// # Arguments
111///
112/// * `y` - Response variable (n observations)
113/// * `x_vars` - Predictor variables (p vectors, each of length n)
114/// * `weights` - Observation weights (n weights, must be non-negative)
115///
116/// # Returns
117///
118/// `WlsFit` containing coefficients, fitted values, and comprehensive fit statistics
119///
120/// # Errors
121///
122/// - `Error::InsufficientData` if n <= k + 1
123/// - `Error::InvalidInput` if weights contain negative values or dimensions don't match
124/// - `Error::SingularMatrix` if the design matrix is singular even with weighting
125///
126/// # Example
127///
128/// ```
129/// use linreg_core::weighted_regression::{wls_regression, WlsFit};
130///
131/// let y = vec![2.0, 3.0, 4.0, 5.0, 6.0];
132/// let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
133/// let weights = vec![1.0, 1.0, 1.0, 1.0, 1.0]; // Equal weights = OLS
134///
135/// let fit: WlsFit = wls_regression(&y, &[x1], &weights)?;
136///
137/// // Access coefficients and statistics
138/// println!("Intercept: {} (SE: {}, t: {}, p: {})",
139/// fit.coefficients[0],
140/// fit.standard_errors[0],
141/// fit.t_statistics[0],
142/// fit.p_values[0]
143/// );
144/// println!("F-statistic: {} (p: {})", fit.f_statistic, fit.f_p_value);
145/// # Ok::<(), linreg_core::Error>(())
146/// ```
147pub fn wls_regression(
148 y: &[f64],
149 x_vars: &[Vec<f64>],
150 weights: &[f64],
151) -> Result<WlsFit> {
152 let n = y.len();
153 let k = x_vars.len();
154
155 // Validate minimum sample size
156 if n <= k + 1 {
157 return Err(Error::InsufficientData {
158 required: k + 2,
159 available: n,
160 });
161 }
162
163 // Validate dimensions
164 for (i, x_var) in x_vars.iter().enumerate() {
165 if x_var.len() != n {
166 return Err(Error::InvalidInput(format!(
167 "x[{}] has {} elements, expected {}",
168 i,
169 x_var.len(),
170 n
171 )));
172 }
173 }
174
175 if weights.len() != n {
176 return Err(Error::InvalidInput(format!(
177 "weights has {} elements, expected {}",
178 weights.len(),
179 n
180 )));
181 }
182
183 // Check for negative weights
184 for (i, &w) in weights.iter().enumerate() {
185 if w < 0.0 {
186 return Err(Error::InvalidInput(format!(
187 "weights[{}] is negative ({}), weights must be non-negative",
188 i, w
189 )));
190 }
191 }
192
193 // Check for zero total weight
194 let weight_sum: f64 = weights.iter().sum();
195 if weight_sum <= 0.0 {
196 return Err(Error::InvalidInput(
197 "Sum of weights is zero or negative".to_string()
198 ));
199 }
200
201 // Build design matrix: include intercept column
202 let mut x_data = Vec::with_capacity(n * (k + 1));
203 for i in 0..n {
204 x_data.push(1.0); // Intercept
205 for j in 0..k {
206 x_data.push(x_vars[j][i]);
207 }
208 }
209 let x = Matrix::new(n, k + 1, x_data);
210
211 // Call the WLS solver with decomposition info (single decomposition, no duplicate QR)
212 let decomp = crate::loess::wls::weighted_least_squares_with_decomposition(&x, y, weights)?;
213 let coefficients = decomp.coefficients;
214
215 // Compute fitted values
216 let fitted_values: Vec<f64> = (0..n)
217 .map(|i| {
218 let mut y_hat = coefficients[0]; // Intercept
219 for j in 0..k {
220 y_hat += coefficients[j + 1] * x_vars[j][i];
221 }
222 y_hat
223 })
224 .collect();
225
226 // Compute residuals
227 let residuals: Vec<f64> = y.iter().zip(fitted_values.iter())
228 .map(|(yi, y_hat)| yi - y_hat)
229 .collect();
230
231 // ============================================================
232 // Compute Degrees of Freedom
233 // ============================================================
234 let p = k + 1; // Number of coefficients (including intercept)
235 let df_residuals = n as isize - p as isize;
236 let df_model = k as isize;
237
238 if df_residuals <= 0 {
239 return Err(Error::InsufficientData {
240 required: p + 1,
241 available: n,
242 });
243 }
244
245 // ============================================================
246 // Compute MSE and Residual Standard Error
247 // ============================================================
248 // RSS = sum of squared residuals
249 let rss: f64 = residuals.iter().map(|r| r * r).sum();
250
251 // MSE (using n - p for unbiased estimate, like R)
252 let mse = rss / df_residuals as f64;
253
254 // Residual standard error (R's sigma-hat)
255 let residual_std_error = mse.sqrt();
256
257 // ============================================================
258 // Compute R-squared and Adjusted R-squared
259 // ============================================================
260 let ss_tot: f64 = {
261 let y_mean = y.iter().sum::<f64>() / n as f64;
262 y.iter().map(|yi| (yi - y_mean).powi(2)).sum()
263 };
264 let r_squared = if ss_tot > 0.0 {
265 1.0 - (rss / ss_tot)
266 } else {
267 0.0
268 };
269
270 let adj_r_squared = if df_residuals > 1 {
271 1.0 - ((1.0 - r_squared) * (n - 1) as f64 / df_residuals as f64)
272 } else {
273 r_squared
274 };
275
276 // ============================================================
277 // Compute Covariance Matrix of Coefficients
278 // ============================================================
279 // Uses decomposition info from the solver (no duplicate QR!)
280 let cov = if let Some(ref r_inv) = decomp.r_inv {
281 // QR path: Cov = MSE * S^-1 * R^-1 * (R^-1)' * S^-1
282 compute_covariance_from_qr(r_inv, &decomp.column_scales, mse, p)
283 } else if let Some((ref v, ref singular_values)) = decomp.svd_components {
284 // SVD path: Cov = MSE * V * diag(1/σᵢ²) * V'
285 compute_covariance_from_svd(v, singular_values, &decomp.column_scales, mse, p)
286 } else {
287 return Err(Error::SingularMatrix);
288 };
289
290 // ============================================================
291 // Extract Standard Errors (diagonal of covariance matrix)
292 // ============================================================
293 let mut standard_errors = Vec::with_capacity(p);
294 for i in 0..p {
295 let se = cov.get(i, i).sqrt();
296 standard_errors.push(se);
297 }
298
299 // ============================================================
300 // Compute t-statistics and p-values for coefficients
301 // ============================================================
302 let mut t_statistics = Vec::with_capacity(p);
303 let mut p_values = Vec::with_capacity(p);
304
305 for i in 0..p {
306 let t = coefficients[i] / standard_errors[i];
307 t_statistics.push(t);
308
309 // Two-tailed p-value using Student's t-distribution
310 let p = 2.0 * (1.0 - student_t_cdf(t.abs(), df_residuals as f64));
311 p_values.push(p);
312 }
313
314 // ============================================================
315 // Compute 95% Confidence Intervals
316 // ============================================================
317 let alpha = 0.05;
318 let t_critical = t_critical_quantile(df_residuals as f64, alpha);
319
320 let conf_int_lower: Vec<f64> = coefficients
321 .iter()
322 .zip(&standard_errors)
323 .map(|(&b, &se)| b - t_critical * se)
324 .collect();
325 let conf_int_upper: Vec<f64> = coefficients
326 .iter()
327 .zip(&standard_errors)
328 .map(|(&b, &se)| b + t_critical * se)
329 .collect();
330
331 // ============================================================
332 // Compute F-statistic and p-value for overall model
333 // ============================================================
334 // F = ((TSS - RSS) / k) / (RSS / (n - k - 1))
335 // where TSS is total sum of squares, RSS is residual sum of squares,
336 // and k is the number of predictors (excluding intercept)
337 let f_statistic = if ss_tot > rss && k > 0 {
338 ((ss_tot - rss) / k as f64) / (rss / df_residuals as f64)
339 } else {
340 0.0
341 };
342
343 let f_p_value = if f_statistic > 0.0 {
344 f_p_value(f_statistic, k as f64, df_residuals as f64)
345 } else {
346 1.0
347 };
348
349 // ============================================================
350 // Additional Error Metrics
351 // ============================================================
352 let rmse = mse.sqrt();
353 let mae = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;
354
355 Ok(WlsFit {
356 coefficients,
357 standard_errors,
358 t_statistics,
359 p_values,
360 conf_int_lower,
361 conf_int_upper,
362 r_squared,
363 adj_r_squared,
364 f_statistic,
365 f_p_value,
366 residual_std_error,
367 df_residuals,
368 df_model,
369 fitted_values,
370 residuals,
371 mse,
372 rmse,
373 mae,
374 n,
375 k,
376 })
377}
378
379/// Compute covariance matrix from QR decomposition
380///
381/// Formula: Cov(β_orig)_ij = MSE * Σ_l(R^-1_il * R^-1_jl) / (scales\[i\] * scales\[j\])
382fn compute_covariance_from_qr(
383 r_inv: &Matrix,
384 column_scales: &[f64],
385 mse: f64,
386 p: usize,
387) -> Matrix {
388 let mut cov = Matrix::zeros(p, p);
389 for i in 0..p {
390 for j in 0..p {
391 let mut sum = 0.0;
392 for l in 0..p {
393 sum += r_inv.get(i, l) * r_inv.get(j, l);
394 }
395 cov.set(i, j, mse * sum / (column_scales[i] * column_scales[j]));
396 }
397 }
398 cov
399}
400
401/// Compute covariance matrix from SVD decomposition
402///
403/// Formula: Cov(β) = MSE * V * diag(1/σᵢ²) * V'
404/// Then compensate for equilibration: divide by scales\[i\] * scales\[j\]
405fn compute_covariance_from_svd(
406 v: &Matrix,
407 singular_values: &[f64],
408 column_scales: &[f64],
409 mse: f64,
410 p: usize,
411) -> Matrix {
412 // Use same tolerance as svd_solve in linalg.rs: sigma[0] * 100 * epsilon
413 let max_sigma = singular_values.first().copied().unwrap_or(0.0);
414 let tol = if max_sigma > 0.0 {
415 max_sigma * 100.0 * f64::EPSILON
416 } else {
417 f64::EPSILON
418 };
419
420 let mut cov = Matrix::zeros(p, p);
421 for i in 0..p {
422 for j in 0..p {
423 let mut sum = 0.0;
424 for l in 0..singular_values.len().min(p) {
425 if singular_values[l] > tol {
426 let inv_sigma_sq = 1.0 / (singular_values[l] * singular_values[l]);
427 sum += v.get(i, l) * v.get(j, l) * inv_sigma_sq;
428 }
429 }
430 cov.set(i, j, mse * sum / (column_scales[i] * column_scales[j]));
431 }
432 }
433 cov
434}
435
436// ============================================================================
437// Model Serialization Traits
438// ============================================================================
439
440// Generate ModelSave and ModelLoad implementations using macro
441impl_serialization!(WlsFit, ModelType::WLS, "WLS");
442
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wls_equal_weights_matches_ols() {
        // With uniform weights, WLS reduces to ordinary least squares.
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let predictor = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let fit = wls_regression(&response, &[predictor], &[1.0; 5]).unwrap();

        // Perfect line y = x: intercept ~0, slope ~1.
        assert!(fit.coefficients[0].abs() < 1e-10);
        assert!((fit.coefficients[1] - 1.0).abs() < 1e-10);
        assert_eq!(fit.k, 1);
        assert_eq!(fit.n, 5);

        // Inference statistics are populated and the fit is significant.
        assert_eq!(fit.standard_errors.len(), 2);
        assert_eq!(fit.t_statistics.len(), 2);
        assert_eq!(fit.p_values.len(), 2);
        assert!(fit.f_statistic > 0.0);
        assert!(fit.f_p_value < 0.05);
    }

    #[test]
    fn test_wls_with_weighted_data() {
        // The last observation is a gross outlier.
        let response = vec![2.0, 4.0, 6.0, 8.0, 100.0];
        let predictor = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        // Down-weighting the outlier should nearly ignore it...
        let downweighted = wls_regression(
            &response,
            &[predictor.clone()],
            &[1.0, 1.0, 1.0, 1.0, 0.01],
        )
        .unwrap();

        // ...while up-weighting it should pull the slope toward it.
        let upweighted = wls_regression(
            &response,
            &[predictor],
            &[1.0, 1.0, 1.0, 1.0, 10.0],
        )
        .unwrap();

        // Slope of the down-weighted fit (~2, from the first four points)
        // must be smaller than the slope of the up-weighted fit.
        assert!(downweighted.coefficients[1] < upweighted.coefficients[1]);
    }

    #[test]
    fn test_wls_negative_weight_error() {
        // Any negative weight must be rejected with an error.
        let result = wls_regression(
            &[1.0, 2.0, 3.0],
            &[vec![1.0, 2.0, 3.0]],
            &[1.0, -1.0, 1.0],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_wls_multiple_predictors() {
        // Two predictors that are not collinear with one another.
        let response = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let first = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let second = vec![1.0, 4.0, 2.0, 5.0, 3.0]; // independent of `first`

        let fit = wls_regression(&response, &[first, second], &[1.0; 5]).unwrap();

        assert_eq!(fit.k, 2); // two predictors
        assert_eq!(fit.coefficients.len(), 3); // intercept + two slopes
        assert_eq!(fit.fitted_values.len(), 5);
        assert_eq!(fit.standard_errors.len(), 3);
        assert_eq!(fit.t_statistics.len(), 3);
        assert_eq!(fit.p_values.len(), 3);
    }

    #[test]
    fn test_wls_insufficient_data() {
        // n=2 with k=2 predictors: need at least k+2=4 observations.
        let result = wls_regression(
            &[1.0, 2.0],
            &[vec![1.0, 2.0], vec![0.5, 1.0]],
            &[1.0, 1.0],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_wls_statistics_completeness() {
        // Every field of WlsFit should come back populated and sensible.
        let fit = wls_regression(
            &[2.0, 4.0, 6.0, 8.0, 10.0],
            &[vec![1.0, 2.0, 3.0, 4.0, 5.0]],
            &[1.0; 5],
        )
        .unwrap();

        assert_eq!(fit.coefficients.len(), 2);
        assert_eq!(fit.standard_errors.len(), 2);
        assert_eq!(fit.t_statistics.len(), 2);
        assert_eq!(fit.p_values.len(), 2);
        assert!((0.0..=1.0).contains(&fit.r_squared));
        assert!((0.0..=1.0).contains(&fit.adj_r_squared));
        assert!(fit.f_statistic >= 0.0);
        assert!((0.0..=1.0).contains(&fit.f_p_value));
        assert!(fit.residual_std_error >= 0.0);
        assert_eq!(fit.df_residuals, 3); // n=5 minus p=2 coefficients
        assert_eq!(fit.df_model, 1);
        assert_eq!(fit.fitted_values.len(), 5);
        assert_eq!(fit.residuals.len(), 5);
        assert!(fit.mse >= 0.0);
        assert!(fit.rmse >= 0.0);
        assert!(fit.mae >= 0.0);
        assert_eq!(fit.n, 5);
        assert_eq!(fit.k, 1);
    }

    #[test]
    fn test_wls_zero_sum_weights_error() {
        // All-zero weights leave nothing to fit against.
        let result = wls_regression(
            &[1.0, 2.0, 3.0],
            &[vec![1.0, 2.0, 3.0]],
            &[0.0, 0.0, 0.0],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_wls_svd_fallback_computes_standard_errors() {
        // Perfectly collinear predictors (x2 = 2*x1) force the SVD fallback.
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let base = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let doubled: Vec<f64> = base.iter().map(|v| v * 2.0).collect();

        // Rank-deficient input may either fail gracefully or succeed; if it
        // succeeds, the SVD covariance path must yield finite standard errors
        // (previously coefficients succeeded but SEs did not).
        if let Ok(fit) = wls_regression(&response, &[base, doubled], &[1.0; 5]) {
            for se in &fit.standard_errors {
                assert!(se.is_finite(), "Standard error should be finite, got {}", se);
            }
        }
    }
}