use crate::la::{ColF, MatF, col_zeros};
use anyhow::Result;
use faer::linalg::solvers::{SolveLstsq, Svd};
/// Result of an IRWLS fit as produced by [`irwls`].
///
/// [`irwls`] in this module fills only `est`; every jackknife-related field
/// is left as `None` by it.
#[derive(Debug)]
pub struct IrwlsResult {
    /// Coefficient estimates from the final iteration, one row per column of
    /// the design matrix.
    #[allow(dead_code)]
    pub est: ColF,
    /// Presumably jackknife standard errors of `est` (TODO confirm); not set
    /// by [`irwls`].
    #[allow(dead_code)]
    pub jknife_se: Option<ColF>,
    /// Presumably jackknife variances of `est` (TODO confirm); not set by
    /// [`irwls`].
    #[allow(dead_code)]
    pub jknife_var: Option<ColF>,
    /// Presumably the jackknife covariance matrix of `est` (TODO confirm);
    /// not set by [`irwls`].
    pub jknife_cov: Option<MatF>,
    /// Presumably per-block delete-one estimates from a jackknife (TODO
    /// confirm); not set by [`irwls`].
    pub delete_values: Option<MatF>,
}
/// Iteratively reweighted least squares (IRWLS) fit of `y ≈ x · est`.
///
/// Each of the `n_iter` iterations:
/// 1. scales row `i` of `x` and `y` by `sqrt(weights[i])`,
/// 2. solves the scaled least-squares problem through an SVD of the scaled
///    design matrix, and
/// 3. overwrites `weights[i]` with `1 / max(fitted_i, 1e-9)^2`, where
///    `fitted_i` is row `i` of the *unscaled* product `x · est`.
///
/// The returned `est` is the solution of the last iteration. On return,
/// `weights` holds the weights derived from that final estimate; they are not
/// used for a further refit. The caller-supplied weights are therefore used
/// directly only in the first iteration.
///
/// # Arguments
/// * `x` - `n × p` design matrix.
/// * `y` - response column with `n` rows.
/// * `weights` - initial weights with `n` rows (finite, non-negative);
///   overwritten in place as described above.
/// * `n_iter` - number of solve/reweight iterations; must be at least 1.
///
/// # Returns
/// An [`IrwlsResult`] whose `est` has `p` rows; all jackknife fields are `None`.
///
/// # Errors
/// Returns an error if `n_iter` is 0, if `y` or `weights` do not have `n`
/// rows, if any initial weight is negative or non-finite, or if the SVD fails.
#[allow(dead_code)]
pub fn irwls(x: &MatF, y: &ColF, weights: &mut ColF, n_iter: usize) -> Result<IrwlsResult> {
    // Floor applied to fitted values before squaring, so zero, negative or NaN
    // fits (`f64::max` returns the non-NaN operand) cannot yield inf/NaN weights.
    const MIN_FITTED: f64 = 1e-9;

    let n = x.nrows();
    let p = x.ncols();

    // Validate up front: the indexing below would otherwise panic on a shape
    // mismatch, and a bad weight would silently turn into NaN via `sqrt`.
    // Zero iterations would return a meaningless all-zero estimate.
    anyhow::ensure!(n_iter > 0, "irwls: n_iter must be at least 1");
    anyhow::ensure!(
        y.nrows() == n,
        "irwls: y has {} rows but x has {}",
        y.nrows(),
        n
    );
    anyhow::ensure!(
        weights.nrows() == n,
        "irwls: weights has {} rows but x has {}",
        weights.nrows(),
        n
    );
    for i in 0..n {
        let w = weights[(i, 0)];
        anyhow::ensure!(
            w.is_finite() && w >= 0.0,
            "irwls: weight at row {} must be finite and non-negative, got {}",
            i,
            w
        );
    }

    // Scratch buffers for the row-scaled system, reused across iterations.
    let mut xw = MatF::zeros(n, p);
    let mut yw = col_zeros(n);
    let mut est = col_zeros(p);

    for _ in 0..n_iter {
        for i in 0..n {
            let sw = weights[(i, 0)].sqrt();
            for j in 0..p {
                xw[(i, j)] = x[(i, j)] * sw;
            }
            yw[(i, 0)] = y[(i, 0)] * sw;
        }

        let svd = Svd::new(xw.as_ref())
            .map_err(|err| anyhow::anyhow!("svd for irwls: {:?}", err))?;
        est = svd.solve_lstsq(yw.as_ref());

        // Reweight from the fitted values of the unscaled design matrix.
        for i in 0..n {
            let fitted: f64 = (0..p).map(|j| x[(i, j)] * est[(j, 0)]).sum();
            let f = fitted.max(MIN_FITTED);
            weights[(i, 0)] = 1.0 / (f * f);
        }
    }

    Ok(IrwlsResult {
        est,
        jknife_se: None,
        jknife_var: None,
        jknife_cov: None,
        delete_values: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::la::{col_from_vec, col_ones};
    use faer::mat;

    /// Asserts that `actual` lies within `tol` of `expected`, labelling any
    /// failure with `what`.
    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
        assert!(
            (actual - expected).abs() < tol,
            "{}={} (expected {})",
            what,
            actual,
            expected
        );
    }

    #[test]
    fn test_irwls_trivial() {
        // y = 1 + 2x exactly, so every weighting recovers the same line.
        let x = mat![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
        let y = col_from_vec(vec![1.0, 3.0, 5.0, 7.0]);
        let mut w = col_ones(4);
        let res = irwls(&x, &y, &mut w, 2).unwrap();
        assert_close(res.est[(0, 0)], 1.0, 1e-6, "intercept");
        assert_close(res.est[(1, 0)], 2.0, 1e-6, "slope");
    }

    #[test]
    fn test_irwls_1d_constant() {
        let x = mat![[1.0], [1.0], [1.0], [1.0]];
        let y = col_from_vec(vec![1.0, 1.0, 1.0, 1.0]);
        let mut w = col_ones(4);
        let res = irwls(&x, &y, &mut w, 2).unwrap();
        assert_close(res.est[(0, 0)], 1.0, 1e-9, "coef");
    }

    #[test]
    fn test_irwls_2d_sum() {
        // y = 1 + x exactly.
        let x = mat![[1.0, 1.0], [1.0, 4.0], [1.0, 3.0], [1.0, 2.0]];
        let y = col_from_vec(vec![2.0, 5.0, 4.0, 3.0]);
        let mut w = col_ones(4);
        let res = irwls(&x, &y, &mut w, 2).unwrap();
        assert_close(res.est[(0, 0)], 1.0, 1e-6, "coef0");
        assert_close(res.est[(1, 0)], 1.0, 1e-6, "coef1");
    }

    #[test]
    fn test_irwls_nonuniform_weights() {
        let x = mat![[1.0], [1.0], [1.0], [1.0]];
        let y = col_from_vec(vec![1.0, 1.0, 1.0, 1.0]);
        let mut w = col_from_vec(vec![1.0, 0.0, 0.0, 1.0]);
        let res = irwls(&x, &y, &mut w, 2).unwrap();
        assert_close(res.est[(0, 0)], 1.0, 1e-6, "coef");
    }

    #[test]
    fn test_irwls_zero_weights_exclude_rows_and_weights_are_updated() {
        // With a single iteration, the zero initial weights drop rows 1 and 2,
        // so the estimate is the mean of y over rows 0 and 3 only (= 1), not
        // the unweighted mean (= 4). A constant y could not detect this.
        let x = mat![[1.0], [1.0], [1.0], [1.0]];
        let y = col_from_vec(vec![1.0, 5.0, 9.0, 1.0]);
        let mut w = col_from_vec(vec![1.0, 0.0, 0.0, 1.0]);
        let res = irwls(&x, &y, &mut w, 1).unwrap();
        assert_close(res.est[(0, 0)], 1.0, 1e-9, "coef");
        // Every fitted value is 1, so each weight becomes 1 / 1^2 = 1.
        for i in 0..4 {
            assert_close(w[(i, 0)], 1.0, 1e-9, "weight");
        }
    }
}