use anyhow::{bail, Result};
use ndarray::{s, Array1, Array2};
use crate::linalg::{qr_full_q, svd_left};
/// Computes surrogate-variable columns for `y` given a `design` matrix.
///
/// `y` is laid out with one row per gene and one column per array; `design` has one row
/// per array. The requested `n_sv` is raised to at least 1 and capped at the residual
/// degrees of freedom (`y.ncols() - design.ncols()`).
///
/// With `weight_by_sd == false`, `y` is multiplied by the columns of `qr_full_q(design)`
/// from index `design.ncols()` onwards, `svd_left` is applied to that product, and the
/// result is `yᵀ · u`. With `weight_by_sd == true` the work is delegated to
/// [`wsva_weight_by_sd`].
///
/// Each of the first `n_sv` columns of the result is finally divided by its root mean
/// square (taken over the number of arrays), so a non-zero column ends up with a sum of
/// squares equal to the number of arrays. Columns whose root mean square is not strictly
/// positive are left untouched.
///
/// # Errors
///
/// * `design.nrows()` differs from `y.ncols()`.
/// * `y.ncols() <= design.ncols()`, i.e. there are no residual degrees of freedom.
pub fn wsva(
    y: &Array2<f64>,
    design: &Array2<f64>,
    n_sv: usize,
    weight_by_sd: bool,
) -> Result<Array2<f64>> {
    let n_arrays = y.ncols();
    let n_coef = design.ncols();
    if design.nrows() != n_arrays {
        bail!(
            "row dimension of design ({}) must match number of arrays ({})",
            design.nrows(),
            n_arrays
        );
    }
    if n_arrays <= n_coef {
        bail!("No residual df");
    }

    // The guard above guarantees `residual_df >= 1`, so the clamp bounds are ordered.
    let residual_df = n_arrays - n_coef;
    let n_sv = n_sv.clamp(1, residual_df);

    let mut sv = if weight_by_sd {
        wsva_weight_by_sd(y, design, n_coef, n_sv)
    } else {
        let q_full = qr_full_q(design);
        // Keep only the trailing columns of Q, past the first `n_coef`.
        let residual_basis = q_full.slice(s![.., n_coef..]);
        let effects = y.dot(&residual_basis);
        let (_, left) = svd_left(&effects, n_sv);
        y.t().dot(&left)
    };

    // Rescale every surrogate variable to unit mean square.
    let n_arrays_f = n_arrays as f64;
    for j in 0..n_sv {
        let mut column = sv.column_mut(j);
        let rms = (column.dot(&column) / n_arrays_f).sqrt();
        if rms > 0.0 {
            column.mapv_inplace(|v| v / rms);
        }
    }
    Ok(sv)
}
/// Builds `n_sv` surrogate-variable columns one at a time, weighting each row of the
/// effects matrix by its own root mean square.
///
/// Every pass works on the design matrix extended with the columns found so far:
///
/// 1. `y` is multiplied by the columns of `qr_full_q(current design)` that lie beyond the
///    current design's column count.
/// 2. The root mean square of each row of that product is computed, and the row is
///    multiplied by it.
/// 3. `svd_left(.., 1)` is applied; column 0 of its second return value (presumably the
///    leading left singular vector — confirm against `crate::linalg`) is multiplied
///    element-wise by the same row weights.
/// 4. That weighted vector is multiplied into `y`, giving one value per column of `y`,
///    which is appended to the design as a new column.
///
/// `p` is the column index from which the result is sliced out of the extended design;
/// the caller passes `design.ncols()`, so the return value holds exactly the appended
/// columns. No rescaling of the columns is done here.
fn wsva_weight_by_sd(y: &Array2<f64>, design: &Array2<f64>, p: usize, n_sv: usize) -> Array2<f64> {
    let n_genes = y.nrows();
    let mut augmented = design.to_owned();
    for _ in 0..n_sv {
        let q_full = qr_full_q(&augmented);
        let residual_basis = q_full.slice(s![.., augmented.ncols()..]);
        let mut effects = y.dot(&residual_basis);
        let n_effects = effects.ncols() as f64;

        // Per-row root mean square of the effects.
        let mut row_rms = Array1::<f64>::zeros(n_genes);
        for (g, rms) in row_rms.iter_mut().enumerate() {
            let row = effects.row(g);
            *rms = (row.dot(&row) / n_effects).sqrt();
        }
        // Weight each row by its own root mean square before the decomposition.
        for (g, &rms) in row_rms.iter().enumerate() {
            effects.row_mut(g).mapv_inplace(|v| v * rms);
        }

        let (_, left) = svd_left(&effects, 1);
        // Apply the same row weights to the first column of `left`.
        let mut loadings = Array1::<f64>::zeros(n_genes);
        for (g, loading) in loadings.iter_mut().enumerate() {
            *loading = left[[g, 0]] * row_rms[g];
        }

        let surrogate = loadings.dot(y);
        augmented = append_col(&augmented, &surrogate);
    }
    augmented.slice(s![.., p..]).to_owned()
}
/// Returns a new matrix equal to `base` with `col` added as one extra rightmost column.
///
/// The output is allocated as zeros with one more column than `base`; the leading
/// columns are copied from `base` and the last one is assigned from `col`.
fn append_col(base: &Array2<f64>, col: &Array1<f64>) -> Array2<f64> {
    let (n_rows, n_cols) = base.dim();
    let mut widened = Array2::<f64>::zeros((n_rows, n_cols + 1));
    widened.slice_mut(s![.., ..n_cols]).assign(base);
    widened.column_mut(n_cols).assign(col);
    widened
}
#[cfg(test)]
// The reference vectors below are written out at full printed precision, which is what
// these two lints would otherwise complain about.
#[allow(clippy::excessive_precision, clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::linalg::qr_full_q;

    /// First surrogate variable expected from the unweighted path on [`fixture`].
    /// Presumably generated with R, going by the `*_matches_r` test names.
    const UNWEIGHTED_SV1: [f64; 6] = [
        -1.6381244655650049,
        -0.78750534836239416,
        0.3081793572367994,
        -1.4144384160296206,
        -0.56381929882700987,
        0.53186540677218364,
    ];

    /// Second surrogate variable expected from the unweighted path on [`fixture`].
    const UNWEIGHTED_SV2: [f64; 6] = [
        -1.2187101217580685,
        0.1181758179300028,
        -1.1110092752986038,
        -1.3306824964699207,
        0.0062034432181509447,
        -1.2229816500104556,
    ];

    /// First surrogate variable expected from the `weight_by_sd` path on [`fixture`].
    const WEIGHTED_SV1: [f64; 6] = [
        -1.6498101462161809,
        -0.64340693693813988,
        0.3707412714969342,
        -1.4869600811704591,
        -0.48055687189241814,
        0.533591336542656,
    ];

    /// Second surrogate variable expected from the `weight_by_sd` path on [`fixture`].
    const WEIGHTED_SV2: [f64; 6] = [
        -1.3427857158177972,
        -0.010493921001334434,
        -0.77610404808136568,
        -1.5817416957409842,
        -0.24944990092452146,
        -1.0150600280045523,
    ];

    /// Mixed absolute/relative comparison: `|a - b| <= 1e-6 * (1 + |b|)`.
    fn rclose(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-6 * (1.0 + b.abs())
    }

    /// True when `got` equals `want` or `-want` element-wise (within [`rclose`]).
    ///
    /// Slices of different lengths never match; without this guard `zip` would silently
    /// truncate and an empty `got` would pass vacuously.
    fn col_matches_up_to_sign(got: &[f64], want: &[f64]) -> bool {
        if got.len() != want.len() {
            return false;
        }
        let matches_with =
            |sign: f64| got.iter().zip(want).all(|(g, w)| rclose(*g, sign * *w));
        matches_with(1.0) || matches_with(-1.0)
    }

    /// Asserts that every column of `sv` has a sum of squares equal to its row count,
    /// i.e. unit mean square.
    fn assert_unit_mean_square(sv: &Array2<f64>) {
        let target = sv.nrows() as f64;
        for j in 0..sv.ncols() {
            let ss: f64 = sv.column(j).iter().map(|v| v * v).sum();
            assert!(rclose(ss, target), "col{} sum-sq {}", j, ss);
        }
    }

    /// A 12-gene by 6-array expression matrix plus a two-column design
    /// (intercept and group indicator).
    ///
    /// Each entry of `y` is `m + b * group + a * l + cc * l2`, where `l` and `l2` are
    /// two array-level patterns that are not part of the design.
    fn fixture() -> (Array2<f64>, Array2<f64>) {
        let intercept = [1.0; 6];
        let grp_b = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let l = [1.0, 0.0, -1.0, 1.0, 0.0, -1.0];
        let l2 = [1.0, -2.0, 1.0, 1.0, -2.0, 1.0];
        let m = [5.0, 6.0, 4.0, 7.0, 5.0, 6.0, 8.0, 4.0, 5.0, 6.0, 7.0, 5.0];
        let b = [
            1.0, -1.0, 2.0, 0.0, 1.0, -2.0, 1.0, 0.0, -1.0, 2.0, 1.0, 0.0,
        ];
        let a = [
            3.0, 1.0, -2.0, 2.0, -3.0, 1.0, 2.0, -1.0, 3.0, -2.0, 1.0, -1.0,
        ];
        let cc = [
            0.5, -0.3, 0.2, -0.4, 0.1, 0.3, -0.2, 0.4, -0.1, 0.2, -0.3, 0.1,
        ];
        let ngenes = 12;
        let narrays = 6;
        let y = Array2::from_shape_fn((ngenes, narrays), |(g, k)| {
            m[g] + b[g] * grp_b[k] + a[g] * l[k] + cc[g] * l2[k]
        });
        let design = Array2::from_shape_fn((narrays, 2), |(k, j)| {
            if j == 0 {
                intercept[k]
            } else {
                grp_b[k]
            }
        });
        (y, design)
    }

    #[test]
    fn svd_left_singular_values_match_r() {
        let (y, design) = fixture();
        let q = qr_full_q(&design);
        let resid = q.slice(s![.., 2..]).to_owned();
        let effects = y.dot(&resid);
        let (svals, _u) = svd_left(&effects, 2);
        assert!(rclose(svals[0], 13.890894185767216), "got {}", svals[0]);
        assert!(rclose(svals[1], 3.3050051013301838), "got {}", svals[1]);
    }

    #[test]
    fn wsva_n_sv_1_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 1, false).unwrap();
        assert_eq!(sv.dim(), (6, 1));
        let got: Vec<f64> = sv.column(0).to_vec();
        assert!(col_matches_up_to_sign(&got, &UNWEIGHTED_SV1), "got {:?}", got);
        assert_unit_mean_square(&sv);
    }

    #[test]
    fn wsva_n_sv_2_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 2, false).unwrap();
        assert_eq!(sv.dim(), (6, 2));
        let got0: Vec<f64> = sv.column(0).to_vec();
        let got1: Vec<f64> = sv.column(1).to_vec();
        assert!(col_matches_up_to_sign(&got0, &UNWEIGHTED_SV1), "col0 {:?}", got0);
        assert!(col_matches_up_to_sign(&got1, &UNWEIGHTED_SV2), "col1 {:?}", got1);
        assert_unit_mean_square(&sv);
    }

    #[test]
    fn wsva_weight_by_sd_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 2, true).unwrap();
        assert_eq!(sv.dim(), (6, 2));
        let got0: Vec<f64> = sv.column(0).to_vec();
        let got1: Vec<f64> = sv.column(1).to_vec();
        assert!(col_matches_up_to_sign(&got0, &WEIGHTED_SV1), "col0 {:?}", got0);
        assert!(col_matches_up_to_sign(&got1, &WEIGHTED_SV2), "col1 {:?}", got1);
        assert_unit_mean_square(&sv);
    }

    #[test]
    fn wsva_no_residual_df_errors() {
        let y = Array2::<f64>::ones((4, 6));
        let design = Array2::<f64>::eye(6);
        assert!(wsva(&y, &design, 1, false).is_err());
    }

    #[test]
    fn wsva_design_row_mismatch_errors() {
        // Six arrays in `y` but only five rows in the design.
        let (y, _) = fixture();
        let design = Array2::<f64>::ones((5, 2));
        assert!(wsva(&y, &design, 1, false).is_err());
    }

    #[test]
    fn wsva_clamps_n_sv_to_residual_df() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 99, false).unwrap();
        assert_eq!(sv.ncols(), 4);
    }

    #[test]
    fn wsva_raises_zero_n_sv_to_one() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 0, false).unwrap();
        assert_eq!(sv.ncols(), 1);
    }

    #[test]
    fn col_matches_up_to_sign_rejects_length_mismatch() {
        assert!(!col_matches_up_to_sign(&[], &[1.0]));
        assert!(!col_matches_up_to_sign(&[1.0], &[1.0, 2.0]));
        assert!(col_matches_up_to_sign(&[1.0, -2.0], &[-1.0, 2.0]));
    }
}