// limma-rust 0.1.0
//
// Pure-Rust port of the Bioconductor limma differential-expression package.
//! Weighted surrogate variable analysis. Port of limma's `wsva` (`wsva.R`).
//!
//! `wsva` estimates surrogate variables that capture unwanted variation
//! orthogonal to a known design. It projects the expression matrix onto the
//! residual space of `design`, takes the leading left singular vectors of the
//! resulting effect matrix, and turns them back into per-array surrogate
//! variables (`uᵀy`, where `u` holds those left singular vectors), each
//! rescaled to unit mean-square.
//!
//! Two weighting modes are supported, matching R:
//!   * default — one SVD of the residual-effect matrix yields `n_sv` surrogate
//!     variables at once;
//!   * `weight_by_sd = true` — surrogate variables are extracted one at a time,
//!     each time row-weighting the effects by their residual standard deviation
//!     and appending the new surrogate variable to the working design so the
//!     next one is estimated orthogonal to it.
//!
//! Surrogate variables, like any singular vectors, are defined only up to sign;
//! the returned columns match R's `wsva` up to a per-column sign flip. Each
//! output column has unit mean-square (its squared entries sum to `narrays`).

use anyhow::{bail, Result};
use ndarray::{s, Array1, Array2};

use crate::linalg::{qr_full_q, svd_left};

/// Compute surrogate variables for `y` (genes x arrays) given `design`
/// (arrays x coefficients).
///
/// `n_sv` is the number of surrogate variables to return; it is clamped to
/// `[1, narrays - ncol(design)]` exactly as R does. `weight_by_sd` selects the
/// sequential, residual-SD-weighted extraction (R's `weight.by.sd = TRUE`).
///
/// Returns an `narrays x n_sv` matrix whose columns are the surrogate
/// variables `SV1..`. Each column is scaled so its squared entries sum to
/// `narrays`; an all-zero column is left unscaled rather than becoming NaN.
/// Columns match R's `wsva` up to a per-column sign.
///
/// # Errors
///
/// * `design` does not have one row per array (column of `y`);
/// * there are no residual degrees of freedom (`narrays <= ncol(design)`);
/// * `y` has no rows (genes);
/// * in the unweighted mode, `y` has fewer genes than the clamped `n_sv`, so
///   that many left singular vectors of the effect matrix do not exist.
pub fn wsva(
    y: &Array2<f64>,
    design: &Array2<f64>,
    n_sv: usize,
    weight_by_sd: bool,
) -> Result<Array2<f64>> {
    let ngenes = y.nrows();
    let narrays = y.ncols();
    let p = design.ncols();
    if design.nrows() != narrays {
        bail!(
            "row dimension of design ({}) must match number of arrays ({})",
            design.nrows(),
            narrays
        );
    }
    if narrays <= p {
        bail!("No residual df");
    }
    if ngenes == 0 {
        bail!("y must contain at least one gene (row)");
    }
    // d >= 1 is guaranteed by the residual-df check, so the clamp gives n_sv >= 1.
    let d = narrays - p;
    let n_sv = n_sv.max(1).min(d);
    // The unweighted path takes n_sv left singular vectors of a genes x d
    // matrix in one SVD; there are at most min(ngenes, d) of them. Fail cleanly
    // instead of relying on the SVD helper (or the rescale loop) to panic.
    if !weight_by_sd && ngenes < n_sv {
        bail!(
            "cannot estimate {} surrogate variables from only {} genes",
            n_sv,
            ngenes
        );
    }

    // M is narrays x n_sv: column j is the (un-normalised) surrogate variable j.
    let mut m = if weight_by_sd {
        wsva_weight_by_sd(y, design, p, n_sv)
    } else {
        let q = qr_full_q(design);
        let resid = q.slice(s![.., p..]); // narrays x d, orthonormal residual basis
        let effects = y.dot(&resid); // genes x d
        let (_svals, u) = svd_left(&effects, n_sv); // genes x n_sv
        y.t().dot(&u) // narrays x n_sv  == (uᵀy)ᵀ
    };

    // Rescale each surrogate variable to unit mean-square (R: SV / sqrt(colMeans(SV^2))).
    // Iterate over the columns actually present so this loop can never index out of range.
    let n = narrays as f64;
    for j in 0..m.ncols() {
        let denom = {
            let col = m.column(j);
            (col.dot(&col) / n).sqrt()
        };
        if denom > 0.0 {
            m.column_mut(j).mapv_inplace(|v| v / denom);
        }
    }
    Ok(m)
}

/// `weight.by.sd = TRUE` path: extract surrogate variables one at a time,
/// growing the design as we go. Returns the appended columns (narrays x n_sv).
fn wsva_weight_by_sd(y: &Array2<f64>, design: &Array2<f64>, p: usize, n_sv: usize) -> Array2<f64> {
    let ngenes = y.nrows();
    let mut design_cur = design.to_owned();
    for _ in 0..n_sv {
        let p_cur = design_cur.ncols();
        let q = qr_full_q(&design_cur);
        let resid = q.slice(s![.., p_cur..]);
        let eff = y.dot(&resid); // genes x (narrays - p_cur)
        let dcur = eff.ncols();

        // s[g] = sqrt(mean_c eff[g,c]^2); row-weight the effects by s.
        let mut s = Array1::<f64>::zeros(ngenes);
        for g in 0..ngenes {
            let row = eff.row(g);
            s[g] = (row.dot(&row) / dcur as f64).sqrt();
        }
        let mut scaled = eff;
        for g in 0..ngenes {
            let sg = s[g];
            scaled.row_mut(g).mapv_inplace(|v| v * sg);
        }

        let (_sv1, u1) = svd_left(&scaled, 1); // genes x 1
        let mut uvec = Array1::<f64>::zeros(ngenes);
        for g in 0..ngenes {
            uvec[g] = u1[[g, 0]] * s[g];
        }
        let svcol = uvec.dot(y); // length narrays  == uᵀy
        design_cur = append_col(&design_cur, &svcol);
    }
    design_cur.slice(s![.., p..]).to_owned()
}

/// Build a new matrix equal to `base` (n x p) with `col` added as column `p`,
/// giving an n x (p + 1) result. `base` itself is not modified.
fn append_col(base: &Array2<f64>, col: &Array1<f64>) -> Array2<f64> {
    let (rows, cols) = base.dim();
    let mut widened = Array2::<f64>::zeros((rows, cols + 1));
    // Copy the existing columns first, then fill the new trailing column.
    widened.slice_mut(s![.., ..cols]).assign(base);
    widened.slice_mut(s![.., cols]).assign(col);
    widened
}

// Unit tests comparing `wsva` against reference values produced by R.
// The long literals below are pasted verbatim from R output; presumably that is
// what trips clippy's excessive_precision / approx_constant lints, hence the allow.
#[cfg(test)]
#[allow(clippy::excessive_precision, clippy::approx_constant)]
mod tests {
    // `svd_left` (used directly below) reaches this module via the glob import;
    // `qr_full_q` is additionally imported explicitly.
    use super::*;
    use crate::linalg::qr_full_q;

    /// Relative/absolute closeness with tolerance 1e-6 * (1 + |b|).
    fn rclose(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-6 * (1.0 + b.abs())
    }

    /// True if `got` matches `want` or `-want` elementwise (singular vectors are
    /// only defined up to sign).
    fn col_matches_up_to_sign(got: &[f64], want: &[f64]) -> bool {
        let pos = got.iter().zip(want).all(|(g, w)| rclose(*g, *w));
        let neg = got.iter().zip(want).all(|(g, w)| rclose(*g, -*w));
        pos || neg
    }

    /// Build the rank-2 residual fixture from scratch/wsva_ref.R:
    /// Y[g,k] = m[g] + b[g]*grpB[k] + a[g]*L[k] + cc[g]*L2[k].
    ///
    /// Returns `(y, design)` with `y` 12 genes x 6 arrays and `design` the
    /// 6 x 2 matrix [intercept, grpB]. The `l` and `l2` array patterns are the
    /// hidden factors that the surrogate variables should recover.
    fn fixture() -> (Array2<f64>, Array2<f64>) {
        let intercept = [1.0; 6];
        let grp_b = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let l = [1.0, 0.0, -1.0, 1.0, 0.0, -1.0];
        let l2 = [1.0, -2.0, 1.0, 1.0, -2.0, 1.0];
        let m = [5.0, 6.0, 4.0, 7.0, 5.0, 6.0, 8.0, 4.0, 5.0, 6.0, 7.0, 5.0];
        let b = [
            1.0, -1.0, 2.0, 0.0, 1.0, -2.0, 1.0, 0.0, -1.0, 2.0, 1.0, 0.0,
        ];
        let a = [
            3.0, 1.0, -2.0, 2.0, -3.0, 1.0, 2.0, -1.0, 3.0, -2.0, 1.0, -1.0,
        ];
        let cc = [
            0.5, -0.3, 0.2, -0.4, 0.1, 0.3, -0.2, 0.4, -0.1, 0.2, -0.3, 0.1,
        ];
        let ngenes = 12;
        let narrays = 6;
        let mut y = Array2::<f64>::zeros((ngenes, narrays));
        for g in 0..ngenes {
            for k in 0..narrays {
                y[[g, k]] = m[g] + b[g] * grp_b[k] + a[g] * l[k] + cc[g] * l2[k];
            }
        }
        let mut design = Array2::<f64>::zeros((narrays, 2));
        for k in 0..narrays {
            design[[k, 0]] = intercept[k];
            design[[k, 1]] = grp_b[k];
        }
        (y, design)
    }

    // Checks the linalg helpers on the residual-effect matrix: the two leading
    // singular values must agree with R's svd().
    #[test]
    fn svd_left_singular_values_match_r() {
        let (y, design) = fixture();
        let q = qr_full_q(&design);
        let resid = q.slice(s![.., 2..]).to_owned();
        let effects = y.dot(&resid);
        let (svals, _u) = svd_left(&effects, 2);
        // svd(Effects)$d[1:2] from R.
        assert!(rclose(svals[0], 13.890894185767216), "got {}", svals[0]);
        assert!(rclose(svals[1], 3.3050051013301838), "got {}", svals[1]);
    }

    // Unweighted mode, a single surrogate variable (R: wsva(y, design, n.sv = 1)).
    #[test]
    fn wsva_n_sv_1_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 1, false).unwrap();
        assert_eq!(sv.dim(), (6, 1));
        let want = [
            -1.6381244655650049,
            -0.78750534836239416,
            0.3081793572367994,
            -1.4144384160296206,
            -0.56381929882700987,
            0.53186540677218364,
        ];
        let got: Vec<f64> = sv.column(0).to_vec();
        assert!(col_matches_up_to_sign(&got, &want), "got {:?}", got);
        // Unit mean-square.
        let ss: f64 = got.iter().map(|v| v * v).sum();
        assert!(rclose(ss, 6.0), "sum-sq {}", ss);
    }

    // Unweighted mode, two surrogate variables from one SVD; the first column
    // must equal the n_sv = 1 result above.
    #[test]
    fn wsva_n_sv_2_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 2, false).unwrap();
        assert_eq!(sv.dim(), (6, 2));
        let want0 = [
            -1.6381244655650049,
            -0.78750534836239416,
            0.3081793572367994,
            -1.4144384160296206,
            -0.56381929882700987,
            0.53186540677218364,
        ];
        let want1 = [
            -1.2187101217580685,
            0.1181758179300028,
            -1.1110092752986038,
            -1.3306824964699207,
            0.0062034432181509447,
            -1.2229816500104556,
        ];
        let got0: Vec<f64> = sv.column(0).to_vec();
        let got1: Vec<f64> = sv.column(1).to_vec();
        assert!(col_matches_up_to_sign(&got0, &want0), "col0 {:?}", got0);
        assert!(col_matches_up_to_sign(&got1, &want1), "col1 {:?}", got1);
        for j in 0..2 {
            let ss: f64 = sv.column(j).iter().map(|v| v * v).sum();
            assert!(rclose(ss, 6.0), "col{} sum-sq {}", j, ss);
        }
    }

    // Sequential SD-weighted mode (R: weight.by.sd = TRUE), two surrogate variables.
    #[test]
    fn wsva_weight_by_sd_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 2, true).unwrap();
        assert_eq!(sv.dim(), (6, 2));
        let want0 = [
            -1.6498101462161809,
            -0.64340693693813988,
            0.3707412714969342,
            -1.4869600811704591,
            -0.48055687189241814,
            0.533591336542656,
        ];
        let want1 = [
            -1.3427857158177972,
            -0.010493921001334434,
            -0.77610404808136568,
            -1.5817416957409842,
            -0.24944990092452146,
            -1.0150600280045523,
        ];
        let got0: Vec<f64> = sv.column(0).to_vec();
        let got1: Vec<f64> = sv.column(1).to_vec();
        assert!(col_matches_up_to_sign(&got0, &want0), "col0 {:?}", got0);
        assert!(col_matches_up_to_sign(&got1, &want1), "col1 {:?}", got1);
        for j in 0..2 {
            let ss: f64 = sv.column(j).iter().map(|v| v * v).sum();
            assert!(rclose(ss, 6.0), "col{} sum-sq {}", j, ss);
        }
    }

    #[test]
    fn wsva_no_residual_df_errors() {
        // 6 arrays, 6-column design -> no residual df.
        let y = Array2::<f64>::ones((4, 6));
        let design = Array2::<f64>::eye(6);
        assert!(wsva(&y, &design, 1, false).is_err());
    }

    #[test]
    fn wsva_clamps_n_sv_to_residual_df() {
        let (y, design) = fixture();
        // d = 4; asking for 99 surrogate variables clamps to 4.
        let sv = wsva(&y, &design, 99, false).unwrap();
        assert_eq!(sv.ncols(), 4);
    }
}