1use anyhow::{bail, Result};
22use ndarray::{s, Array1, Array2};
23
24use crate::linalg::{qr_full_q, svd_left};
25
26pub fn wsva(
36 y: &Array2<f64>,
37 design: &Array2<f64>,
38 n_sv: usize,
39 weight_by_sd: bool,
40) -> Result<Array2<f64>> {
41 let narrays = y.ncols();
42 let p = design.ncols();
43 if design.nrows() != narrays {
44 bail!(
45 "row dimension of design ({}) must match number of arrays ({})",
46 design.nrows(),
47 narrays
48 );
49 }
50 if narrays <= p {
51 bail!("No residual df");
52 }
53 let d = narrays - p;
54 let n_sv = n_sv.max(1).min(d);
55
56 let mut m = if weight_by_sd {
58 wsva_weight_by_sd(y, design, p, n_sv)
59 } else {
60 let q = qr_full_q(design);
61 let resid = q.slice(s![.., p..]); let effects = y.dot(&resid); let (_svals, u) = svd_left(&effects, n_sv); y.t().dot(&u) };
66
67 for j in 0..n_sv {
69 let col = m.column(j);
70 let denom = (col.dot(&col) / narrays as f64).sqrt();
71 if denom > 0.0 {
72 m.column_mut(j).mapv_inplace(|v| v / denom);
73 }
74 }
75 Ok(m)
76}
77
78fn wsva_weight_by_sd(y: &Array2<f64>, design: &Array2<f64>, p: usize, n_sv: usize) -> Array2<f64> {
81 let ngenes = y.nrows();
82 let mut design_cur = design.to_owned();
83 for _ in 0..n_sv {
84 let p_cur = design_cur.ncols();
85 let q = qr_full_q(&design_cur);
86 let resid = q.slice(s![.., p_cur..]);
87 let eff = y.dot(&resid); let dcur = eff.ncols();
89
90 let mut s = Array1::<f64>::zeros(ngenes);
92 for g in 0..ngenes {
93 let row = eff.row(g);
94 s[g] = (row.dot(&row) / dcur as f64).sqrt();
95 }
96 let mut scaled = eff;
97 for g in 0..ngenes {
98 let sg = s[g];
99 scaled.row_mut(g).mapv_inplace(|v| v * sg);
100 }
101
102 let (_sv1, u1) = svd_left(&scaled, 1); let mut uvec = Array1::<f64>::zeros(ngenes);
104 for g in 0..ngenes {
105 uvec[g] = u1[[g, 0]] * s[g];
106 }
107 let svcol = uvec.dot(y); design_cur = append_col(&design_cur, &svcol);
109 }
110 design_cur.slice(s![.., p..]).to_owned()
111}
112
113fn append_col(base: &Array2<f64>, col: &Array1<f64>) -> Array2<f64> {
115 let n = base.nrows();
116 let p = base.ncols();
117 let mut out = Array2::<f64>::zeros((n, p + 1));
118 out.slice_mut(s![.., ..p]).assign(base);
119 out.slice_mut(s![.., p]).assign(col);
120 out
121}
122
#[cfg(test)]
// The reference literals below are kept verbatim, so the precision-related
// lints are silenced for this module.
#[allow(clippy::excessive_precision, clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::linalg::qr_full_q;

    /// Expected first surrogate variable without SD weighting.
    /// NOTE(review): presumably generated with R, going by the test names.
    const UNWEIGHTED_SV1: [f64; 6] = [
        -1.6381244655650049,
        -0.78750534836239416,
        0.3081793572367994,
        -1.4144384160296206,
        -0.56381929882700987,
        0.53186540677218364,
    ];

    /// Expected second surrogate variable without SD weighting.
    const UNWEIGHTED_SV2: [f64; 6] = [
        -1.2187101217580685,
        0.1181758179300028,
        -1.1110092752986038,
        -1.3306824964699207,
        0.0062034432181509447,
        -1.2229816500104556,
    ];

    /// Expected first surrogate variable with SD weighting.
    const WEIGHTED_SV1: [f64; 6] = [
        -1.6498101462161809,
        -0.64340693693813988,
        0.3707412714969342,
        -1.4869600811704591,
        -0.48055687189241814,
        0.533591336542656,
    ];

    /// Expected second surrogate variable with SD weighting.
    const WEIGHTED_SV2: [f64; 6] = [
        -1.3427857158177972,
        -0.010493921001334434,
        -0.77610404808136568,
        -1.5817416957409842,
        -0.24944990092452146,
        -1.0150600280045523,
    ];

    /// Relative closeness check: tolerance scales with the magnitude of `b`.
    fn rclose(a: f64, b: f64) -> bool {
        let tolerance = 1e-6 * (1.0 + b.abs());
        (a - b).abs() <= tolerance
    }

    /// True when `got` agrees element-wise with `want` or with `-want`;
    /// singular vectors are only determined up to sign.
    fn col_matches_up_to_sign(got: &[f64], want: &[f64]) -> bool {
        let agrees = |sign: f64| got.iter().zip(want).all(|(g, w)| rclose(*g, sign * *w));
        agrees(1.0) || agrees(-1.0)
    }

    /// Sum of squared entries of column `j` of `sv`.
    fn sum_sq(sv: &Array2<f64>, j: usize) -> f64 {
        sv.column(j).iter().map(|v| v * v).sum()
    }

    /// A 12-gene × 6-array expression matrix together with a two-column
    /// design (intercept + group indicator). Each gene is a baseline plus
    /// group, linear and quadratic within-group terms.
    fn fixture() -> (Array2<f64>, Array2<f64>) {
        let intercept = [1.0; 6];
        let grp_b = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let l = [1.0, 0.0, -1.0, 1.0, 0.0, -1.0];
        let l2 = [1.0, -2.0, 1.0, 1.0, -2.0, 1.0];
        let m = [5.0, 6.0, 4.0, 7.0, 5.0, 6.0, 8.0, 4.0, 5.0, 6.0, 7.0, 5.0];
        let b = [
            1.0, -1.0, 2.0, 0.0, 1.0, -2.0, 1.0, 0.0, -1.0, 2.0, 1.0, 0.0,
        ];
        let a = [
            3.0, 1.0, -2.0, 2.0, -3.0, 1.0, 2.0, -1.0, 3.0, -2.0, 1.0, -1.0,
        ];
        let cc = [
            0.5, -0.3, 0.2, -0.4, 0.1, 0.3, -0.2, 0.4, -0.1, 0.2, -0.3, 0.1,
        ];
        let ngenes = 12;
        let narrays = 6;

        let y = Array2::<f64>::from_shape_fn((ngenes, narrays), |(g, k)| {
            m[g] + b[g] * grp_b[k] + a[g] * l[k] + cc[g] * l2[k]
        });
        let design = Array2::<f64>::from_shape_fn((narrays, 2), |(k, j)| {
            if j == 0 {
                intercept[k]
            } else {
                grp_b[k]
            }
        });
        (y, design)
    }

    #[test]
    fn svd_left_singular_values_match_r() {
        let (y, design) = fixture();
        let q = qr_full_q(&design);
        let resid = q.slice(s![.., 2..]).to_owned();
        let effects = y.dot(&resid);
        let (svals, _u) = svd_left(&effects, 2);
        assert!(rclose(svals[0], 13.890894185767216), "got {}", svals[0]);
        assert!(rclose(svals[1], 3.3050051013301838), "got {}", svals[1]);
    }

    #[test]
    fn wsva_n_sv_1_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 1, false).unwrap();
        assert_eq!(sv.dim(), (6, 1));

        let got: Vec<f64> = sv.column(0).to_vec();
        assert!(col_matches_up_to_sign(&got, &UNWEIGHTED_SV1), "got {:?}", got);

        // Unit mean square over six arrays means the squares sum to six.
        let ss = sum_sq(&sv, 0);
        assert!(rclose(ss, 6.0), "sum-sq {}", ss);
    }

    #[test]
    fn wsva_n_sv_2_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 2, false).unwrap();
        assert_eq!(sv.dim(), (6, 2));

        let got0: Vec<f64> = sv.column(0).to_vec();
        let got1: Vec<f64> = sv.column(1).to_vec();
        assert!(col_matches_up_to_sign(&got0, &UNWEIGHTED_SV1), "col0 {:?}", got0);
        assert!(col_matches_up_to_sign(&got1, &UNWEIGHTED_SV2), "col1 {:?}", got1);

        for j in 0..2 {
            let ss = sum_sq(&sv, j);
            assert!(rclose(ss, 6.0), "col{} sum-sq {}", j, ss);
        }
    }

    #[test]
    fn wsva_weight_by_sd_matches_r() {
        let (y, design) = fixture();
        let sv = wsva(&y, &design, 2, true).unwrap();
        assert_eq!(sv.dim(), (6, 2));

        let got0: Vec<f64> = sv.column(0).to_vec();
        let got1: Vec<f64> = sv.column(1).to_vec();
        assert!(col_matches_up_to_sign(&got0, &WEIGHTED_SV1), "col0 {:?}", got0);
        assert!(col_matches_up_to_sign(&got1, &WEIGHTED_SV2), "col1 {:?}", got1);

        for j in 0..2 {
            let ss = sum_sq(&sv, j);
            assert!(rclose(ss, 6.0), "col{} sum-sq {}", j, ss);
        }
    }

    #[test]
    fn wsva_no_residual_df_errors() {
        // Six arrays against a six-column design leaves nothing to estimate from.
        let y = Array2::<f64>::ones((4, 6));
        let design = Array2::<f64>::eye(6);
        assert!(wsva(&y, &design, 1, false).is_err());
    }

    #[test]
    fn wsva_clamps_n_sv_to_residual_df() {
        let (y, design) = fixture();
        // Six arrays minus two design columns leaves four residual df.
        let sv = wsva(&y, &design, 99, false).unwrap();
        assert_eq!(sv.ncols(), 4);
    }
}