use crate::{error::Error, linalg::Matrix};
/// Result of a weighted least-squares fit, carrying the solved
/// coefficients plus whichever matrix factorization produced them.
///
/// Exactly one of `r_inv` / `svd_components` is `Some`, depending on
/// whether the QR path succeeded or the solver fell back to SVD.
pub struct WlsDecomposition {
    // Fitted coefficients, already divided by `column_scales` so they
    // are expressed in the original (un-equilibrated) column units.
    pub coefficients: Vec<f64>,
    // Euclidean norm of each column of the weighted design matrix,
    // used for equilibration (1.0 recorded for all-zero columns).
    pub column_scales: Vec<f64>,
    // Inverse of the upper-triangular R from QR; `Some` on the QR path.
    pub r_inv: Option<Matrix>,
    // `(V, singular values)` from the SVD fallback; `Some` on that path.
    pub svd_components: Option<(Matrix, Vec<f64>)>,
}
/// Solve a weighted least-squares problem and return only the fitted
/// coefficients, discarding the decomposition details.
///
/// See [`weighted_least_squares_with_decomposition`] for the full
/// contract (panics and error conditions are identical).
pub fn weighted_least_squares(x: &Matrix, y: &[f64], weights: &[f64]) -> Result<Vec<f64>, Error> {
    weighted_least_squares_with_decomposition(x, y, weights).map(|decomp| decomp.coefficients)
}
pub fn weighted_least_squares_with_decomposition(
x: &Matrix,
y: &[f64],
weights: &[f64],
) -> Result<WlsDecomposition, Error> {
let n = x.rows;
let p = x.cols;
assert_eq!(
y.len(),
n,
"Response vector length must match number of rows"
);
assert_eq!(
weights.len(),
n,
"Weight vector length must match number of rows"
);
let weight_sum: f64 = weights.iter().sum();
if weight_sum <= 0.0 {
return Err(Error::InvalidInput("All weights are zero".to_string()));
}
let mut x_weighted_data = Vec::with_capacity(n * p);
let mut y_weighted = Vec::with_capacity(n);
for i in 0..n {
let sqrt_weight = weights[i].sqrt();
y_weighted.push(y[i] * sqrt_weight);
for j in 0..p {
x_weighted_data.push(x.get(i, j) * sqrt_weight);
}
}
let mut x_weighted = Matrix::new(n, p, x_weighted_data);
let column_scales = equilibrate_columns(&mut x_weighted);
let qr_result = try_qr_solve_with_r_inv(&x_weighted, &y_weighted, p);
match qr_result {
Ok((coeffs, r_inv)) => {
let mut final_coeffs = Vec::with_capacity(p);
for j in 0..p {
final_coeffs.push(coeffs[j] / column_scales[j]);
}
Ok(WlsDecomposition {
coefficients: final_coeffs,
column_scales,
r_inv: Some(r_inv),
svd_components: None,
})
}
Err(Error::SingularMatrix) => {
let svd_result = try_svd_solve_with_components(&x_weighted, &y_weighted, p);
match svd_result {
Ok((coeffs, v, singular_values)) => {
let mut final_coeffs = Vec::with_capacity(p);
for j in 0..p {
final_coeffs.push(coeffs[j] / column_scales[j]);
}
Ok(WlsDecomposition {
coefficients: final_coeffs,
column_scales,
r_inv: None,
svd_components: Some((v, singular_values)),
})
}
Err(e) => Err(e),
}
}
Err(e) => Err(e),
}
}
fn try_qr_solve_with_r_inv(
x_weighted: &Matrix,
y_weighted: &[f64],
p: usize,
) -> Result<(Vec<f64>, Matrix), Error> {
let (q, r) = x_weighted.qr();
let mut r_upper = Matrix::zeros(p, p);
for i in 0..p {
for j in 0..p {
r_upper.set(i, j, r.get(i, j));
}
}
let q_t = q.transpose();
let qty = q_t.mul_vec(y_weighted);
let rhs_vec = qty[0..p].to_vec();
let rhs_mat = Matrix::new(p, 1, rhs_vec);
let r_inv = match r_upper.invert_upper_triangular() {
Some(inv) => inv,
None => return Err(Error::SingularMatrix),
};
let result = r_inv.matmul(&rhs_mat);
let mut coeffs = Vec::with_capacity(p);
for j in 0..p {
coeffs.push(result.get(j, 0));
}
Ok((coeffs, r_inv))
}
/// Solve the least-squares system via an SVD pseudo-inverse, returning
/// the coefficients along with V and the singular values so callers can
/// reuse the decomposition downstream.
fn try_svd_solve_with_components(
    x_weighted: &Matrix,
    y_weighted: &[f64],
    _p: usize,
) -> Result<(Vec<f64>, Matrix, Vec<f64>), Error> {
    let svd = x_weighted.svd();
    let coeffs = x_weighted.svd_solve(&svd, y_weighted);
    let v = svd.v_t.transpose();
    Ok((coeffs, v, svd.sigma))
}
pub fn equilibrate_columns(x: &mut Matrix) -> Vec<f64> {
let n = x.rows;
let p = x.cols;
let mut column_scales = Vec::with_capacity(p);
for j in 0..p {
let mut norm = 0.0;
for i in 0..n {
let val = x.get(i, j);
norm += val * val;
}
norm = norm.sqrt();
if norm > 0.0 {
for i in 0..n {
let val = x.get(i, j);
x.set(i, j, val / norm);
}
column_scales.push(norm);
} else {
column_scales.push(1.0);
}
}
column_scales
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_weighted_least_squares_simple() {
        // y = 1 + 2x fits exactly; unit weights reduce to OLS.
        let x = Matrix::new(
            5,
            2,
            vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0],
        );
        let y = vec![1.0, 3.0, 5.0, 7.0, 9.0];
        let weights = vec![1.0; 5];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!((coeffs[0] - 1.0).abs() < 1e-10);
        assert!((coeffs[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_weighted_least_squares_with_weights() {
        // The middle point carries nearly all the weight, so the fit
        // must pass very close to (1, 1).
        let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
        let y = vec![0.0, 1.0, 3.0];
        let weights = vec![0.01, 100.0, 0.01];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        let predicted_at_1 = coeffs[0] + coeffs[1] * 1.0;
        assert!((predicted_at_1 - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_weighted_least_squares_matches_ols() {
        // Uniform weights should recover the ordinary LS fit.
        let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
        let y = vec![3.1, 4.9, 7.2, 9.8];
        let weights = vec![1.0; 4];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!((coeffs[0] - 1.0).abs() < 1.0);
        assert!((coeffs[1] - 2.0).abs() < 0.5);
    }

    #[test]
    fn test_weighted_least_squares_zero_weight() {
        // The wildly off third point has zero weight and must be ignored.
        let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
        let y = vec![0.0, 1.0, 100.0];
        let weights = vec![1.0, 1.0, 0.0];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!((coeffs[0] - 0.0).abs() < 1e-10);
        assert!((coeffs[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_svd_fallback_rank_deficient() {
        // Third column is almost exactly 2× the second: near-singular.
        let x = Matrix::new(
            5,
            3,
            vec![
                1.0, 1.0, 1.0, 1.0, 2.0, 4.001, 1.0, 3.0, 6.001, 1.0, 4.0, 8.001, 1.0, 5.0, 10.001,
            ],
        );
        let y = vec![5.0, 7.0, 9.0, 11.0, 13.0];
        let weights = vec![1.0; 5];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!((coeffs[0] - 3.0).abs() < 1.0);
        assert!((coeffs[1] - 2.0).abs() < 1.0);
    }

    #[test]
    fn test_svd_fallback_perfect_collinearity() {
        // x2 = 2·x1 exactly; the solution is non-unique but the
        // predictions must still reproduce y.
        let x = Matrix::new(
            5,
            3,
            vec![
                1.0, 1.0, 2.0, 1.0, 2.0, 4.0, 1.0, 3.0, 6.0, 1.0, 4.0, 8.0, 1.0, 5.0, 10.0,
            ],
        );
        let y = vec![5.0, 7.0, 9.0, 11.0, 13.0];
        let weights = vec![1.0; 5];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!(coeffs.iter().all(|c| c.is_finite()));
        for i in 0..5 {
            let x1 = (i + 1) as f64;
            let x2 = 2.0 * x1;
            let pred = coeffs[0] + coeffs[1] * x1 + coeffs[2] * x2;
            assert!(
                (pred - y[i]).abs() < 1e-6,
                "Prediction mismatch at i={}: pred={}, y={}",
                i,
                pred,
                y[i]
            );
        }
    }

    #[test]
    fn test_svd_tolerance_matches() {
        // Well-conditioned exact fit: y = 2x.
        let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
        let y = vec![2.0, 4.0, 6.0];
        let weights = vec![1.0; 3];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!((coeffs[0] - 0.0).abs() < 1e-10);
        assert!((coeffs[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_column_equilibration_with_svd() {
        // Badly scaled second column; equilibration should keep the fit
        // numerically accurate.
        let x = Matrix::new(
            5,
            2,
            vec![
                1.0, 0.0001, 1.0, 0.0002, 1.0, 0.0003, 1.0, 0.0004, 1.0, 0.0005,
            ],
        );
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let weights = vec![1.0; 5];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        for i in 0..5 {
            let x_val = (i + 1) as f64 * 0.0001;
            let pred = coeffs[0] + coeffs[1] * x_val;
            assert!((pred - y[i]).abs() < 0.1);
        }
    }

    #[test]
    fn test_weighted_least_squares_quadratic_rank_deficient() {
        // Quartic design over a tiny x-range: nearly rank-deficient.
        let x = Matrix::new(
            4,
            4,
            vec![
                1.0, 1.0, 1.0, 1.0, 1.0, 1.01, 1.0201, 1.030301, 1.0, 1.02, 1.0404, 1.061208, 1.0,
                1.03, 1.0609, 1.092727,
            ],
        );
        let y = vec![
            2.0 + 3.0 * 1.0 + 0.5 * 1.0,
            2.0 + 3.0 * 1.01 + 0.5 * 1.0201,
            2.0 + 3.0 * 1.02 + 0.5 * 1.0404,
            2.0 + 3.0 * 1.03 + 0.5 * 1.0609,
        ];
        let weights = vec![1.0; 4];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!(coeffs[0].is_finite());
        assert!(coeffs[1].is_finite());
        assert!(coeffs[2].is_finite());
        assert!(coeffs[3].is_finite());
        for i in 0..4 {
            let xi = 1.0 + i as f64 * 0.01;
            let pred = coeffs[0] + coeffs[1] * xi + coeffs[2] * xi * xi;
            assert!((pred - y[i]).abs() < 0.5);
        }
    }

    #[test]
    fn test_pseudoinverse_minimum_neighbors() {
        // Two almost-identical rows: extremely ill-conditioned 2x2.
        let x = Matrix::new(2, 2, vec![1.0, 1.0, 1.0, 1.000001]);
        let y = vec![2.0, 2.000001];
        let weights = vec![1.0, 1.0];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!(coeffs[0].is_finite());
        assert!(coeffs[1].is_finite());
    }

    #[test]
    fn test_svd_handles_zero_variance_column() {
        // Constant third column is collinear with the intercept.
        let x = Matrix::new(
            5,
            3,
            vec![
                1.0, 1.0, 5.0, 1.0, 2.0, 5.0, 1.0, 3.0, 5.0, 1.0, 4.0, 5.0, 1.0, 5.0, 5.0,
            ],
        );
        let y = vec![5.0, 7.0, 9.0, 11.0, 13.0];
        let weights = vec![1.0; 5];
        let coeffs = weighted_least_squares(&x, &y, &weights).unwrap();
        assert!(coeffs[0].is_finite());
        assert!(coeffs[1].is_finite());
        assert!(coeffs[2].is_finite());
        for i in 0..5 {
            let pred = coeffs[0] + coeffs[1] * (i + 1) as f64 + coeffs[2] * 5.0;
            assert!((pred - y[i]).abs() < 1.0);
        }
    }

    #[test]
    fn test_weighted_least_squares_unchanged_behavior() {
        // The convenience wrapper must agree exactly with the
        // decomposition-returning entry point.
        let x = Matrix::new(
            5,
            2,
            vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0],
        );
        let y = vec![2.1, 3.9, 6.1, 7.9, 10.0];
        let w = vec![1.0, 2.0, 1.0, 2.0, 1.0];
        let coeffs = weighted_least_squares(&x, &y, &w).unwrap();
        let decomp = weighted_least_squares_with_decomposition(&x, &y, &w).unwrap();
        assert_eq!(coeffs.len(), decomp.coefficients.len());
        for i in 0..coeffs.len() {
            assert!(
                (coeffs[i] - decomp.coefficients[i]).abs() < 1e-15,
                "Coefficient {} differs: {} vs {}",
                i,
                coeffs[i],
                decomp.coefficients[i]
            );
        }
        assert!(decomp.r_inv.is_some());
        assert!(decomp.svd_components.is_none());
    }
}