use anyhow::Result;
use ndarray::{Array1, Array2};
use ndarray_linalg::LeastSquaresSvd;
/// Result returned by [`irwls`].
///
/// `irwls` itself only fills in `est`; it always leaves the jackknife-related fields as
/// `None`. Presumably they are populated by a separate jackknife step elsewhere
/// (TODO: confirm against callers).
#[derive(Debug, Clone)]
pub struct IrwlsResult {
    /// Coefficient estimates, one per column of the design matrix passed to `irwls`.
    pub est: Array1<f64>,
    /// Jackknife standard errors of `est` (per the field name); `None` from `irwls`.
    pub jknife_se: Option<Array1<f64>>,
    /// Jackknife variances of `est` (per the field name); `None` from `irwls`.
    // Nothing in the visible code reads this field. The allow keeps it part of the
    // result type without tripping the `dead_code` lint.
    #[allow(dead_code)]
    pub jknife_var: Option<Array1<f64>>,
    /// Jackknife covariance matrix of `est` (per the field name); `None` from `irwls`.
    pub jknife_cov: Option<Array2<f64>>,
    /// Jackknife delete values (per the field name; layout not established here,
    /// presumably one row per jackknife block — confirm); `None` from `irwls`.
    pub delete_values: Option<Array2<f64>>,
}
pub fn irwls(
x: &Array2<f64>,
y: &Array1<f64>,
weights: &mut Array1<f64>,
n_iter: usize,
) -> Result<IrwlsResult> {
let n = x.nrows();
let p = x.ncols();
let mut xw = Array2::<f64>::zeros((n, p));
let mut yw = Array1::<f64>::zeros(n);
let mut est = Array1::<f64>::zeros(p);
for _ in 0..n_iter {
for i in 0..n {
let sw = weights[i].sqrt();
xw.row_mut(i).assign(&(&x.row(i) * sw));
yw[i] = y[i] * sw;
}
let result = xw.least_squares(&yw)?;
est.assign(&result.solution);
let fitted = x.dot(&est);
for i in 0..n {
let f = fitted[i].max(1e-9); weights[i] = 1.0 / (f * f);
}
}
Ok(IrwlsResult {
est,
jknife_se: None,
jknife_var: None,
jknife_cov: None,
delete_values: None,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Asserts that `actual` is within `tol` of `expected`, naming the quantity on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
        assert!(
            (actual - expected).abs() < tol,
            "{}: expected {}, got {}",
            what,
            expected,
            actual
        );
    }

    #[test]
    fn test_irwls_trivial() {
        // Exact line y = 1 + 2x: every weighted fit recovers it, whatever the weights.
        let x = array![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
        let y = array![1.0, 3.0, 5.0, 7.0];
        let mut w = Array1::<f64>::ones(4);
        let res = irwls(&x, &y, &mut w, 2).expect("irwls should succeed");
        assert_close(res.est[0], 1.0, 1e-6, "intercept");
        assert_close(res.est[1], 2.0, 1e-6, "slope");
    }

    #[test]
    fn test_irwls_updates_weights_from_fit() {
        // After the final fit the weights become 1 / fitted², with fitted = [1, 3, 5, 7].
        let x = array![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
        let y = array![1.0, 3.0, 5.0, 7.0];
        let mut w = Array1::<f64>::ones(4);
        irwls(&x, &y, &mut w, 1).expect("irwls should succeed");
        let expected = [1.0, 1.0 / 9.0, 1.0 / 25.0, 1.0 / 49.0];
        for (i, (&got, &want)) in w.iter().zip(expected.iter()).enumerate() {
            assert_close(got, want, 1e-6, &format!("w[{}]", i));
        }
    }

    #[test]
    fn test_irwls_1d_constant() {
        let x = array![[1.0], [1.0], [1.0], [1.0]];
        let y = array![1.0, 1.0, 1.0, 1.0];
        let mut w = Array1::<f64>::ones(4);
        let res = irwls(&x, &y, &mut w, 2).expect("irwls should succeed");
        assert_close(res.est[0], 1.0, 1e-9, "coef");
    }

    #[test]
    fn test_irwls_2d_sum() {
        // y = 1 + x exactly, with the rows given out of order.
        let x = array![[1.0, 1.0], [1.0, 4.0], [1.0, 3.0], [1.0, 2.0]];
        let y = array![2.0, 5.0, 4.0, 3.0];
        let mut w = Array1::<f64>::ones(4);
        let res = irwls(&x, &y, &mut w, 2).expect("irwls should succeed");
        assert_close(res.est[0], 1.0, 1e-6, "coef0");
        assert_close(res.est[1], 1.0, 1e-6, "coef1");
    }

    #[test]
    fn test_irwls_nonuniform_weights() {
        // The two middle observations carry zero weight. With a single iteration the
        // intercept-only fit must use only the outer points (mean 1), not the unweighted
        // mean of all four (3). The data are chosen so ignoring weights would fail.
        let x = array![[1.0], [1.0], [1.0], [1.0]];
        let y = array![1.0, 5.0, 5.0, 1.0];
        let mut w = array![1.0, 0.0, 0.0, 1.0];
        let res = irwls(&x, &y, &mut w, 1).expect("irwls should succeed");
        assert_close(res.est[0], 1.0, 1e-6, "coef");
    }
}