Skip to main content

limma/
wsva.rs

1//! Weighted surrogate variable analysis. Port of limma's `wsva` (`wsva.R`).
2//!
//! `wsva` estimates surrogate variables that capture unwanted variation
//! orthogonal to a known design. It projects the expression matrix onto the
//! residual space of `design`, takes the leading left singular vectors `u` of
//! the resulting effect matrix, and maps each one back to a per-array
//! surrogate variable (`yᵀu`), rescaled to unit mean-square.
8//!
9//! Two weighting modes are supported, matching R:
10//!   * default — one SVD of the residual-effect matrix yields `n_sv` surrogate
11//!     variables at once;
12//!   * `weight_by_sd = true` — surrogate variables are extracted one at a time,
13//!     each time row-weighting the effects by their residual standard deviation
14//!     and appending the new surrogate variable to the working design so the
15//!     next one is estimated orthogonal to it.
16//!
17//! Surrogate variables, like any singular vectors, are defined only up to sign;
18//! the returned columns match R's `wsva` up to a per-column sign flip. Each
19//! output column has unit mean-square (its squared entries sum to `narrays`).
20
21use anyhow::{bail, Result};
22use ndarray::{s, Array1, Array2};
23
24use crate::linalg::{qr_full_q, svd_left};
25
26/// Compute surrogate variables for `y` (genes x arrays) given `design`
27/// (arrays x coefficients).
28///
29/// `n_sv` is the number of surrogate variables to return; it is clamped to
30/// `[1, narrays - ncol(design)]` exactly as R does. Returns an
31/// `narrays x n_sv` matrix whose columns are the surrogate variables `SV1..`.
32///
33/// Errors if `design` has the wrong number of rows or if there are no residual
34/// degrees of freedom (`narrays <= ncol(design)`).
35pub fn wsva(
36    y: &Array2<f64>,
37    design: &Array2<f64>,
38    n_sv: usize,
39    weight_by_sd: bool,
40) -> Result<Array2<f64>> {
41    let narrays = y.ncols();
42    let p = design.ncols();
43    if design.nrows() != narrays {
44        bail!(
45            "row dimension of design ({}) must match number of arrays ({})",
46            design.nrows(),
47            narrays
48        );
49    }
50    if narrays <= p {
51        bail!("No residual df");
52    }
53    let d = narrays - p;
54    let n_sv = n_sv.max(1).min(d);
55
56    // M is narrays x n_sv: column j is the (un-normalised) surrogate variable j.
57    let mut m = if weight_by_sd {
58        wsva_weight_by_sd(y, design, p, n_sv)
59    } else {
60        let q = qr_full_q(design);
61        let resid = q.slice(s![.., p..]); // narrays x d, orthonormal residual basis
62        let effects = y.dot(&resid); // genes x d
63        let (_svals, u) = svd_left(&effects, n_sv); // genes x n_sv
64        y.t().dot(&u) // narrays x n_sv  == (uᵀy)ᵀ
65    };
66
67    // Rescale each surrogate variable to unit mean-square.
68    for j in 0..n_sv {
69        let col = m.column(j);
70        let denom = (col.dot(&col) / narrays as f64).sqrt();
71        if denom > 0.0 {
72            m.column_mut(j).mapv_inplace(|v| v / denom);
73        }
74    }
75    Ok(m)
76}
77
78/// `weight.by.sd = TRUE` path: extract surrogate variables one at a time,
79/// growing the design as we go. Returns the appended columns (narrays x n_sv).
80fn wsva_weight_by_sd(y: &Array2<f64>, design: &Array2<f64>, p: usize, n_sv: usize) -> Array2<f64> {
81    let ngenes = y.nrows();
82    let mut design_cur = design.to_owned();
83    for _ in 0..n_sv {
84        let p_cur = design_cur.ncols();
85        let q = qr_full_q(&design_cur);
86        let resid = q.slice(s![.., p_cur..]);
87        let eff = y.dot(&resid); // genes x (narrays - p_cur)
88        let dcur = eff.ncols();
89
90        // s[g] = sqrt(mean_c eff[g,c]^2); row-weight the effects by s.
91        let mut s = Array1::<f64>::zeros(ngenes);
92        for g in 0..ngenes {
93            let row = eff.row(g);
94            s[g] = (row.dot(&row) / dcur as f64).sqrt();
95        }
96        let mut scaled = eff;
97        for g in 0..ngenes {
98            let sg = s[g];
99            scaled.row_mut(g).mapv_inplace(|v| v * sg);
100        }
101
102        let (_sv1, u1) = svd_left(&scaled, 1); // genes x 1
103        let mut uvec = Array1::<f64>::zeros(ngenes);
104        for g in 0..ngenes {
105            uvec[g] = u1[[g, 0]] * s[g];
106        }
107        let svcol = uvec.dot(y); // length narrays  == uᵀy
108        design_cur = append_col(&design_cur, &svcol);
109    }
110    design_cur.slice(s![.., p..]).to_owned()
111}
112
113/// Return `base` with `col` appended as a new last column.
114fn append_col(base: &Array2<f64>, col: &Array1<f64>) -> Array2<f64> {
115    let n = base.nrows();
116    let p = base.ncols();
117    let mut out = Array2::<f64>::zeros((n, p + 1));
118    out.slice_mut(s![.., ..p]).assign(base);
119    out.slice_mut(s![.., p]).assign(col);
120    out
121}
122
123#[cfg(test)]
124#[allow(clippy::excessive_precision, clippy::approx_constant)]
125mod tests {
126    use super::*;
127    use crate::linalg::qr_full_q;
128
129    fn rclose(a: f64, b: f64) -> bool {
130        (a - b).abs() <= 1e-6 * (1.0 + b.abs())
131    }
132
133    /// True if `got` matches `want` or `-want` elementwise (singular vectors are
134    /// only defined up to sign).
135    fn col_matches_up_to_sign(got: &[f64], want: &[f64]) -> bool {
136        let pos = got.iter().zip(want).all(|(g, w)| rclose(*g, *w));
137        let neg = got.iter().zip(want).all(|(g, w)| rclose(*g, -*w));
138        pos || neg
139    }
140
141    /// Build the rank-2 residual fixture from scratch/wsva_ref.R:
142    /// Y[g,k] = m[g] + b[g]*grpB[k] + a[g]*L[k] + cc[g]*L2[k].
143    fn fixture() -> (Array2<f64>, Array2<f64>) {
144        let intercept = [1.0; 6];
145        let grp_b = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
146        let l = [1.0, 0.0, -1.0, 1.0, 0.0, -1.0];
147        let l2 = [1.0, -2.0, 1.0, 1.0, -2.0, 1.0];
148        let m = [5.0, 6.0, 4.0, 7.0, 5.0, 6.0, 8.0, 4.0, 5.0, 6.0, 7.0, 5.0];
149        let b = [
150            1.0, -1.0, 2.0, 0.0, 1.0, -2.0, 1.0, 0.0, -1.0, 2.0, 1.0, 0.0,
151        ];
152        let a = [
153            3.0, 1.0, -2.0, 2.0, -3.0, 1.0, 2.0, -1.0, 3.0, -2.0, 1.0, -1.0,
154        ];
155        let cc = [
156            0.5, -0.3, 0.2, -0.4, 0.1, 0.3, -0.2, 0.4, -0.1, 0.2, -0.3, 0.1,
157        ];
158        let ngenes = 12;
159        let narrays = 6;
160        let mut y = Array2::<f64>::zeros((ngenes, narrays));
161        for g in 0..ngenes {
162            for k in 0..narrays {
163                y[[g, k]] = m[g] + b[g] * grp_b[k] + a[g] * l[k] + cc[g] * l2[k];
164            }
165        }
166        let mut design = Array2::<f64>::zeros((narrays, 2));
167        for k in 0..narrays {
168            design[[k, 0]] = intercept[k];
169            design[[k, 1]] = grp_b[k];
170        }
171        (y, design)
172    }
173
174    #[test]
175    fn svd_left_singular_values_match_r() {
176        let (y, design) = fixture();
177        let q = qr_full_q(&design);
178        let resid = q.slice(s![.., 2..]).to_owned();
179        let effects = y.dot(&resid);
180        let (svals, _u) = svd_left(&effects, 2);
181        // svd(Effects)$d[1:2] from R.
182        assert!(rclose(svals[0], 13.890894185767216), "got {}", svals[0]);
183        assert!(rclose(svals[1], 3.3050051013301838), "got {}", svals[1]);
184    }
185
186    #[test]
187    fn wsva_n_sv_1_matches_r() {
188        let (y, design) = fixture();
189        let sv = wsva(&y, &design, 1, false).unwrap();
190        assert_eq!(sv.dim(), (6, 1));
191        let want = [
192            -1.6381244655650049,
193            -0.78750534836239416,
194            0.3081793572367994,
195            -1.4144384160296206,
196            -0.56381929882700987,
197            0.53186540677218364,
198        ];
199        let got: Vec<f64> = sv.column(0).to_vec();
200        assert!(col_matches_up_to_sign(&got, &want), "got {:?}", got);
201        // Unit mean-square.
202        let ss: f64 = got.iter().map(|v| v * v).sum();
203        assert!(rclose(ss, 6.0), "sum-sq {}", ss);
204    }
205
206    #[test]
207    fn wsva_n_sv_2_matches_r() {
208        let (y, design) = fixture();
209        let sv = wsva(&y, &design, 2, false).unwrap();
210        assert_eq!(sv.dim(), (6, 2));
211        let want0 = [
212            -1.6381244655650049,
213            -0.78750534836239416,
214            0.3081793572367994,
215            -1.4144384160296206,
216            -0.56381929882700987,
217            0.53186540677218364,
218        ];
219        let want1 = [
220            -1.2187101217580685,
221            0.1181758179300028,
222            -1.1110092752986038,
223            -1.3306824964699207,
224            0.0062034432181509447,
225            -1.2229816500104556,
226        ];
227        let got0: Vec<f64> = sv.column(0).to_vec();
228        let got1: Vec<f64> = sv.column(1).to_vec();
229        assert!(col_matches_up_to_sign(&got0, &want0), "col0 {:?}", got0);
230        assert!(col_matches_up_to_sign(&got1, &want1), "col1 {:?}", got1);
231        for j in 0..2 {
232            let ss: f64 = sv.column(j).iter().map(|v| v * v).sum();
233            assert!(rclose(ss, 6.0), "col{} sum-sq {}", j, ss);
234        }
235    }
236
237    #[test]
238    fn wsva_weight_by_sd_matches_r() {
239        let (y, design) = fixture();
240        let sv = wsva(&y, &design, 2, true).unwrap();
241        assert_eq!(sv.dim(), (6, 2));
242        let want0 = [
243            -1.6498101462161809,
244            -0.64340693693813988,
245            0.3707412714969342,
246            -1.4869600811704591,
247            -0.48055687189241814,
248            0.533591336542656,
249        ];
250        let want1 = [
251            -1.3427857158177972,
252            -0.010493921001334434,
253            -0.77610404808136568,
254            -1.5817416957409842,
255            -0.24944990092452146,
256            -1.0150600280045523,
257        ];
258        let got0: Vec<f64> = sv.column(0).to_vec();
259        let got1: Vec<f64> = sv.column(1).to_vec();
260        assert!(col_matches_up_to_sign(&got0, &want0), "col0 {:?}", got0);
261        assert!(col_matches_up_to_sign(&got1, &want1), "col1 {:?}", got1);
262        for j in 0..2 {
263            let ss: f64 = sv.column(j).iter().map(|v| v * v).sum();
264            assert!(rclose(ss, 6.0), "col{} sum-sq {}", j, ss);
265        }
266    }
267
268    #[test]
269    fn wsva_no_residual_df_errors() {
270        // 6 arrays, 6-column design -> no residual df.
271        let y = Array2::<f64>::ones((4, 6));
272        let design = Array2::<f64>::eye(6);
273        assert!(wsva(&y, &design, 1, false).is_err());
274    }
275
276    #[test]
277    fn wsva_clamps_n_sv_to_residual_df() {
278        let (y, design) = fixture();
279        // d = 4; asking for 99 surrogate variables clamps to 4.
280        let sv = wsva(&y, &design, 99, false).unwrap();
281        assert_eq!(sv.ncols(), 4);
282    }
283}