use ndarray::{Array1, Array2, ArrayView2};
use crate::faer_ndarray::{
FaerArrayView, factorize_symmetricwith_fallback, fast_ab, fast_xt_diag_x, fast_xt_diag_y,
};
use crate::matrix::FactorizedSystem;
use faer::Side;
/// Relative ridge factor used by `OrthogonalReparam::build_unconditional`: it is multiplied by
/// the largest diagonal entry of the weighted primary Gram matrix to size the ridge that is
/// added to every diagonal entry of that matrix before it is factorized.
pub const ORTHOGONAL_PROJECTION_RELATIVE_RIDGE: f64 = 1.0e-10;
/// Absolute lower bound on that ridge; the ridge actually applied is the larger of the scaled
/// relative value and this floor, so it stays positive even when the Gram diagonal is all zero.
pub const ORTHOGONAL_PROJECTION_RIDGE_FLOOR: f64 = 1.0e-12;
/// Reparameterization that replaces a confound design block by its residual after a ridged,
/// row-weighted projection onto a primary design block.
///
/// Built by `build_unconditional`; it stores the projection coefficients (the "shear") and the
/// residualized confound block, and converts coefficient vectors between the original and the
/// reparameterized bases.
#[derive(Debug, Clone)]
pub struct OrthogonalReparam {
/// Projection coefficients with shape `(primary columns, confound columns)`. All zeros when
/// either block has no columns.
shear: Array2<f64>,
/// Confound block minus the primary block mapped through `shear`. A plain copy of the input
/// confound block when either block has no columns.
confound_orthogonal: Array2<f64>,
}
impl OrthogonalReparam {
    /// Residualizes `confound` against `primary` under the diagonal row metric `w_metric`.
    ///
    /// Steps, in order:
    /// 1. Validate that both blocks have the same number of rows, that `w_metric` has one entry
    ///    per row, and that every metric entry is finite and non-negative.
    /// 2. If either block has zero columns, return immediately with an all-zero shear of shape
    ///    `(primary cols, confound cols)` and an unmodified copy of `confound`.
    /// 3. Form `G = fast_xt_diag_x(primary, w_metric)` and
    ///    `X = fast_xt_diag_y(primary, w_metric, confound)` (presumably `MᵀWM` and `MᵀWC`;
    ///    the helpers are defined elsewhere in the crate).
    /// 4. Add `eps = max(max_i G[i, i] * ORTHOGONAL_PROJECTION_RELATIVE_RIDGE,
    ///    ORTHOGONAL_PROJECTION_RIDGE_FLOOR)` to every diagonal entry of `G`, factorize it, and
    ///    solve for the shear `S` with `X` as the right-hand side.
    /// 5. Store `confound - fast_ab(primary, S)` (presumably `confound - primary · S`).
    ///
    /// Because the ridge is part of the system being solved, `S` is the solution of the ridged
    /// system rather than of the un-ridged one.
    ///
    /// # Errors
    ///
    /// Returns a message prefixed with `orthogonal_reparam:` when the shapes disagree, the
    /// metric contains a negative or non-finite entry, the factorization or the solve fails, or
    /// the resulting shear / residual contains a non-finite entry.
    pub fn build_unconditional(
        primary: ArrayView2<f64>,
        confound: ArrayView2<f64>,
        w_metric: &Array1<f64>,
    ) -> Result<Self, String> {
        let (row_count, primary_width) = primary.dim();
        let (confound_rows, confound_width) = confound.dim();

        // Shape checks come first, in the same order callers already observe.
        if row_count != confound_rows {
            return Err(format!(
                "orthogonal_reparam: primary rows ({n}) != confound rows ({})",
                confound_rows,
                n = row_count
            ));
        }
        let metric_len = w_metric.len();
        if metric_len != row_count {
            return Err(format!(
                "orthogonal_reparam: row metric length ({}) != design rows ({n})",
                metric_len,
                n = row_count
            ));
        }

        let metric_is_valid = w_metric.iter().all(|&w| w.is_finite() && w >= 0.0);
        if !metric_is_valid {
            return Err(String::from(
                "orthogonal_reparam: row metric must be finite and non-negative",
            ));
        }

        // Nothing to project when one of the blocks is empty: zero shear, raw confound.
        if primary_width == 0 || confound_width == 0 {
            return Ok(Self {
                shear: Array2::<f64>::zeros((primary_width, confound_width)),
                confound_orthogonal: confound.to_owned(),
            });
        }

        let mut gram = fast_xt_diag_x(&primary, w_metric);

        // The ridge is sized from the largest diagonal entry, never below the absolute floor.
        let mut largest_diagonal = 0.0_f64;
        for k in 0..primary_width {
            largest_diagonal = largest_diagonal.max(gram[[k, k]]);
        }
        let ridge = f64::max(
            largest_diagonal * ORTHOGONAL_PROJECTION_RELATIVE_RIDGE,
            ORTHOGONAL_PROJECTION_RIDGE_FLOOR,
        );
        (0..primary_width).for_each(|k| gram[[k, k]] += ridge);

        let confound_owned = confound.to_owned();
        let cross = fast_xt_diag_y(&primary, w_metric, &confound_owned);

        let ridged_gram = FaerArrayView::new(&gram);
        let factor = factorize_symmetricwith_fallback(ridged_gram.as_ref(), Side::Lower)
            .map_err(|e| {
                format!("orthogonal_reparam: weighted primary Gram factorization failed: {e:?}")
            })?;

        let shear = match factor.solvemulti(&cross) {
            Ok(solution) => solution,
            Err(e) => {
                return Err(format!("orthogonal_reparam: projection solve failed: {e}"));
            }
        };

        let confound_orthogonal = &confound - &fast_ab(&primary, &shear);

        // Reject the result if either stored matrix picked up a NaN or an infinity.
        let all_finite = shear
            .iter()
            .chain(confound_orthogonal.iter())
            .all(|v| v.is_finite());
        if !all_finite {
            return Err(String::from(
                "orthogonal_reparam: reparameterization produced non-finite entries",
            ));
        }

        Ok(Self {
            shear,
            confound_orthogonal,
        })
    }

    /// Borrowed view of the shear matrix, shape `(primary cols, confound cols)`.
    #[inline]
    pub fn shear(&self) -> ArrayView2<'_, f64> {
        self.shear.view()
    }

    /// Borrowed view of the residualized confound block stored at construction time.
    #[inline]
    pub fn reparameterized_confound(&self) -> ArrayView2<'_, f64> {
        self.confound_orthogonal.view()
    }

    /// Number of primary columns, i.e. the row count of the shear matrix.
    #[inline]
    pub fn primary_cols(&self) -> usize {
        self.shear.dim().0
    }

    /// Number of confound columns, i.e. the column count of the shear matrix.
    #[inline]
    pub fn confound_cols(&self) -> usize {
        self.shear.dim().1
    }

    /// Maps coefficients from the reparameterized basis back to the original one.
    ///
    /// Returns `(beta_m_reparam - shear · beta_c, beta_c)`; the confound coefficients are
    /// returned as an unchanged copy.
    ///
    /// # Errors
    ///
    /// Returns an error when `beta_m_reparam` does not have `primary_cols()` entries or
    /// `beta_c` does not have `confound_cols()` entries; the primary length is checked first.
    pub fn recover_original(
        &self,
        beta_m_reparam: &Array1<f64>,
        beta_c: &Array1<f64>,
    ) -> Result<(Array1<f64>, Array1<f64>), String> {
        let (primary_width, confound_width) = self.shear.dim();

        let primary_len = beta_m_reparam.len();
        if primary_len != primary_width {
            return Err(format!(
                "orthogonal_reparam: reparameterized primary coeffs length ({}) != p_m ({p_m})",
                primary_len,
                p_m = primary_width
            ));
        }
        let confound_len = beta_c.len();
        if confound_len != confound_width {
            return Err(format!(
                "orthogonal_reparam: confound coeffs length ({}) != p_c ({p_c})",
                confound_len,
                p_c = confound_width
            ));
        }

        let correction = self.shear.dot(beta_c);
        Ok((beta_m_reparam - &correction, beta_c.clone()))
    }

    /// Maps original-basis coefficients to the reparameterized primary coefficients.
    ///
    /// Returns `beta_m + shear · beta_c`, the inverse of the primary part of
    /// `recover_original`.
    ///
    /// # Errors
    ///
    /// Returns an error when `beta_m` does not have `primary_cols()` entries or `beta_c` does
    /// not have `confound_cols()` entries; the primary length is checked first.
    pub fn to_reparameterized(
        &self,
        beta_m: &Array1<f64>,
        beta_c: &Array1<f64>,
    ) -> Result<Array1<f64>, String> {
        let (primary_width, confound_width) = self.shear.dim();

        let primary_len = beta_m.len();
        if primary_len != primary_width {
            return Err(format!(
                "orthogonal_reparam: primary coeffs length ({}) != p_m ({p_m})",
                primary_len,
                p_m = primary_width
            ));
        }
        let confound_len = beta_c.len();
        if confound_len != confound_width {
            return Err(format!(
                "orthogonal_reparam: confound coeffs length ({}) != p_c ({p_c})",
                confound_len,
                p_c = confound_width
            ));
        }

        let correction = self.shear.dot(beta_c);
        Ok(beta_m + &correction)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    /// Largest absolute value yielded by `values`; `0.0` for an empty iterator.
    fn peak_abs<'a>(values: impl Iterator<Item = &'a f64>) -> f64 {
        values.fold(0.0_f64, |acc, v| acc.max(v.abs()))
    }

    /// Row metric consisting of `len` ones.
    fn unit_metric(len: usize) -> Array1<f64> {
        Array1::<f64>::from_elem(len, 1.0)
    }

    /// The residualized confound block must have (near-)zero weighted cross product with the
    /// primary block.
    #[test]
    fn orthogonalized_confound_is_w_orthogonal_to_primary() {
        let n: usize = 50;
        let grid = |i: usize| i as f64 / n as f64;
        let m = Array2::from_shape_fn((n, 3), |(i, j)| {
            let t = grid(i);
            match j {
                0 => 1.0,
                1 => t,
                _ => (t * 6.0).sin(),
            }
        });
        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
            let t = grid(i);
            if j == 0 {
                t + 0.01 * (t * 13.0).cos()
            } else {
                (t * 3.0).cos()
            }
        });
        let w = unit_metric(n);

        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");
        let c_tilde = reparam.reparameterized_confound().to_owned();

        let cross = fast_xt_diag_y(&m, &w, &c_tilde);
        let worst = peak_abs(cross.iter());
        assert!(
            worst < 1e-8,
            "MᵀW C̃ not orthogonal: max |entry| = {max_abs:e}",
            max_abs = worst
        );
    }

    /// Mapping coefficients back to the original basis must keep the linear predictor, leave
    /// the confound coefficients untouched, and be undone by the forward map.
    #[test]
    fn coefficient_round_trip_is_exact() {
        let n: usize = 40;
        let grid = |i: usize| i as f64 / n as f64;
        let m = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                1.0
            } else {
                (grid(i) * 4.0).sin()
            }
        });
        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
            let t = grid(i);
            if j == 0 {
                t
            } else {
                (t * 2.0).cos()
            }
        });
        let w = unit_metric(n);

        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        let beta_m_reparam = Array1::from_vec(vec![0.7, -1.3]);
        let beta_c = Array1::from_vec(vec![2.1, 0.4]);

        // Predictor in the reparameterized basis.
        let c_tilde = reparam.reparameterized_confound().to_owned();
        let eta_reparam = m.dot(&beta_m_reparam) + c_tilde.dot(&beta_c);

        // Predictor in the original basis after mapping the coefficients back.
        let (beta_m, beta_c_out) = reparam
            .recover_original(&beta_m_reparam, &beta_c)
            .expect("recover should succeed");
        let eta_original = m.dot(&beta_m) + c.dot(&beta_c_out);

        let predictor_gap = peak_abs((&eta_reparam - &eta_original).iter());
        assert!(
            predictor_gap < 1e-10,
            "predictor changed under round-trip: max |Δη| = {max_diff:e}",
            max_diff = predictor_gap
        );

        let confound_gap = peak_abs((&beta_c_out - &beta_c).iter());
        assert!(
            confound_gap == 0.0,
            "confound coeffs changed: {cdiff:e}",
            cdiff = confound_gap
        );

        let back = reparam
            .to_reparameterized(&beta_m, &beta_c)
            .expect("forward should succeed");
        let forward_gap = peak_abs((&back - &beta_m_reparam).iter());
        assert!(
            forward_gap < 1e-10,
            "forward/inverse mismatch: {fdiff:e}",
            fdiff = forward_gap
        );
    }

    /// A confound column that was already residualized against the primary block must come
    /// back essentially unchanged, with a near-zero shear.
    #[test]
    fn absent_confound_leaves_design_and_predictions_unchanged() {
        let n: usize = 30;
        let grid = |i: usize| i as f64 / (n as f64 - 1.0);
        let m = Array2::from_shape_fn((n, 2), |(i, j)| if j == 0 { 1.0 } else { grid(i) });
        let quad = Array1::from_shape_fn(n, |i| {
            let t = grid(i);
            t * t
        });
        let w = unit_metric(n);

        // Residualize the quadratic against the primary block by hand (no ridge here).
        let gram = fast_xt_diag_x(&m, &w);
        let cross = m.t().dot(&quad);
        let gview = FaerArrayView::new(&gram);
        let factor =
            factorize_symmetricwith_fallback(gview.as_ref(), Side::Lower).expect("factor");
        let b = FactorizedSystem::solve(&factor, &cross).expect("solve");
        let resid = &quad - &m.dot(&b);

        let mut c = Array2::<f64>::zeros((n, 1));
        c.column_mut(0).assign(&resid);

        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        let shear_peak = peak_abs(reparam.shear().iter());
        assert!(
            shear_peak < 1e-8,
            "expected ~zero shear, got {shear_max:e}",
            shear_max = shear_peak
        );

        let c_tilde = reparam.reparameterized_confound().to_owned();
        let drift = peak_abs((&c_tilde - &c).iter());
        assert!(
            drift < 1e-8,
            "orthogonalized design drifted from raw when confound absent: {design_diff:e}",
            design_diff = drift
        );
    }

    /// With a zero-column primary block the confound block must be returned bit-for-bit.
    #[test]
    fn empty_primary_returns_raw_confound() {
        let n: usize = 8;
        let m = Array2::<f64>::zeros((n, 0));
        let c = Array2::from_shape_fn((n, 2), |(i, j)| if j == 0 { i as f64 } else { 1.0 });
        let w = unit_metric(n);

        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");
        let c_tilde = reparam.reparameterized_confound().to_owned();

        let diff = peak_abs((&c_tilde - &c).iter());
        assert!(diff == 0.0, "empty primary must return raw confound");
    }
}