1use anyhow::{bail, Result};
12use ndarray::{Array1, Array2};
13
14use crate::linalg::qr_econ;
15
16pub enum BeadDispersion<'a> {
20 Stdev(&'a Array2<f64>),
22 Stderr(&'a Array2<f64>),
24}
25
26pub struct BeadCountWeights {
28 pub weights: Array2<f64>,
31 pub var_biological: Array1<f64>,
33 pub var_technical: Array2<f64>,
35 pub cv_array: Array1<f64>,
37 pub cv_constant: f64,
39}
40
41#[allow(clippy::too_many_arguments)]
51pub fn bead_count_weights(
52 e_norm: &Array2<f64>,
53 e_raw: &Array2<f64>,
54 design: &Array2<f64>,
55 nbeads: &Array2<f64>,
56 dispersion: BeadDispersion,
57 array_cv: bool,
58 scale: bool,
59) -> Result<BeadCountWeights> {
60 let p = e_norm.nrows();
61 let a = e_norm.ncols();
62 if e_raw.dim() != (p, a) {
63 bail!("dimensions don't match");
64 }
65 if design.nrows() != a {
66 bail!("row dimension of design doesn't match column dimension of data");
67 }
68 if nbeads.dim() != (p, a) {
69 bail!("dimensions don't match");
70 }
71
72 let bead_stdev = match dispersion {
73 BeadDispersion::Stdev(s) => {
74 if s.dim() != (p, a) {
75 bail!("dimensions don't match");
76 }
77 s.to_owned()
78 }
79 BeadDispersion::Stderr(se) => {
80 if se.dim() != (p, a) {
81 bail!("dimensions don't match");
82 }
83 Array2::from_shape_fn((p, a), |(i, j)| se[[i, j]] * nbeads[[i, j]].sqrt())
84 }
85 };
86
87 let sqrt_cv =
89 Array2::from_shape_fn((p, a), |(i, j)| (bead_stdev[[i, j]] / e_raw[[i, j]]).sqrt());
90
91 let cv_array: Array1<f64> = (0..a)
93 .map(|j| {
94 let col: Vec<f64> = (0..p).map(|i| sqrt_cv[[i, j]]).collect();
95 let m = trimmed_mean_narm(&col, 0.125);
96 m * m
97 })
98 .collect();
99 let pooled: Vec<f64> = sqrt_cv.iter().copied().collect();
100 let cm = trimmed_mean_narm(&pooled, 0.125);
101 let cv_constant = cm * cm;
102
103 let cv = if array_cv {
105 Array2::from_shape_fn((p, a), |(_, j)| cv_array[j])
106 } else {
107 Array2::from_elem((p, a), cv_constant)
108 };
109
110 let ln2_sq = std::f64::consts::LN_2 * std::f64::consts::LN_2;
113 let var_technical = Array2::from_shape_fn((p, a), |(i, j)| {
114 (cv[[i, j]] * cv[[i, j]] / nbeads[[i, j]] + 1.0).ln() / ln2_sq
115 });
116
117 let (q, _r) = qr_econ(design); let k = q.ncols();
120 let h: Vec<f64> = (0..a)
121 .map(|i| (0..k).map(|c| q[[i, c]] * q[[i, c]]).sum())
122 .collect();
123 let mut r2 = Array2::<f64>::zeros((p, a));
124 for g in 0..p {
125 let y: Vec<f64> = (0..a).map(|j| e_norm[[g, j]]).collect();
126 let qty: Vec<f64> = (0..k)
127 .map(|c| (0..a).map(|j| q[[j, c]] * y[j]).sum())
128 .collect();
129 for j in 0..a {
130 let fitted: f64 = (0..k).map(|c| q[[j, c]] * qty[c]).sum();
131 let res = y[j] - fitted;
132 r2[[g, j]] = res * res;
133 }
134 }
135
136 let var_biological = ilmn_biological_variance(&var_technical, &r2, &h)?;
137
138 let mut weights = Array2::from_shape_fn((p, a), |(g, j)| {
140 1.0 / (var_technical[[g, j]] + var_biological[g])
141 });
142 for g in 0..p {
143 let rm: f64 = (0..a).map(|j| weights[[g, j]]).sum::<f64>() / a as f64;
144 for j in 0..a {
145 weights[[g, j]] /= rm;
146 }
147 }
148 if scale {
149 for g in 0..p {
150 let rs: f64 = (0..a).map(|j| 1.0 / var_technical[[g, j]]).sum::<f64>() / a as f64;
151 for j in 0..a {
152 weights[[g, j]] *= rs;
153 }
154 }
155 }
156
157 Ok(BeadCountWeights {
158 weights,
159 var_biological,
160 var_technical,
161 cv_array,
162 cv_constant,
163 })
164}
165
166fn ilmn_biological_variance(tv: &Array2<f64>, r2: &Array2<f64>, h: &[f64]) -> Result<Array1<f64>> {
170 let p = tv.nrows();
171 let a = tv.ncols();
172 if tv.iter().any(|&v| v < 0.0) || r2.iter().any(|&v| v < 0.0) {
173 bail!("negative variances not allowed");
174 }
175
176 let mut bv = vec![0.0f64; p];
177 let mut fvec = vec![0.0f64; p];
178 for g in 0..p {
179 let mut acc = 0.0;
180 for j in 0..a {
181 let t = tv[[g, j]];
182 acc += r2[[g, j]] / (2.0 * t * t) - (1.0 - h[j]) / (2.0 * t);
183 }
184 fvec[g] = acc / a as f64;
185 }
186 let mut active: Vec<bool> = (0..p).map(|g| fvec[g] > 0.0).collect();
187
188 let mut iter = 0u32;
189 while active.iter().any(|&x| x) {
190 iter += 1;
191 if iter > 200 {
192 break; }
194 for g in 0..p {
195 if !active[g] {
196 continue;
197 }
198 let mut fd = 0.0;
199 for j in 0..a {
200 let denom = bv[g] + tv[[g, j]];
201 fd += r2[[g, j]] / denom.powi(3) - (1.0 - h[j]) / (2.0 * denom * denom);
202 }
203 let fdash = -(fd / a as f64);
204 let step = -fvec[g] / fdash;
205 bv[g] += step;
206 let mut nf = 0.0;
207 for j in 0..a {
208 let denom = bv[g] + tv[[g, j]];
209 nf += r2[[g, j]] / (2.0 * denom * denom) - (1.0 - h[j]) / (2.0 * denom);
210 }
211 fvec[g] = nf / a as f64;
212 active[g] = step > 1e-5;
213 }
214 }
215
216 Ok(Array1::from(bv))
217}
218
/// Trimmed mean of the non-NaN values of `xs`.
///
/// With `n` non-NaN values, `floor(trim * n)` of the smallest and of the
/// largest are discarded before averaging. If `trim >= 0.5` the median is
/// returned. Without this, the kept slice would be empty (NaN) or inverted
/// (panic). A non-positive or NaN `trim` gives the plain mean. Returns NaN
/// when no non-NaN values remain. These rules match R's
/// `mean(x, trim = trim, na.rm = TRUE)`.
fn trimmed_mean_narm(xs: &[f64], trim: f64) -> f64 {
    let mut v: Vec<f64> = xs.iter().copied().filter(|x| !x.is_nan()).collect();
    let n = v.len();
    if n == 0 {
        return f64::NAN;
    }
    v.sort_by(|a, b| a.partial_cmp(b).expect("NaNs were filtered out above"));

    if trim >= 0.5 {
        let mid = n / 2;
        return if n % 2 == 1 {
            v[mid]
        } else {
            (v[mid - 1] + v[mid]) / 2.0
        };
    }

    // For 0 < trim < 0.5, lo < n / 2, so the kept slice is never empty.
    let lo = if trim > 0.0 {
        (trim * n as f64).floor() as usize
    } else {
        0
    };
    let kept = &v[lo..n - lo];
    kept.iter().sum::<f64>() / kept.len() as f64
}
230
#[cfg(test)]
// Expected values are pasted at full printed precision (17 significant digits).
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;

    /// Closeness with tolerance `1e-7 * (1 + |b|)`.
    fn rclose(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-7 * (1.0 + b.abs())
    }

    /// Deterministic 8 × 4 fixture: `(e_norm, e_raw, nbeads, bead_stdev)`.
    fn fixture() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>) {
        let (p, a) = (8usize, 4usize);
        let mut enorm = Array2::zeros((p, a));
        let mut eraw = Array2::zeros((p, a));
        let mut nbeads = Array2::zeros((p, a));
        let mut bstdev = Array2::zeros((p, a));
        for g0 in 0..p {
            for j0 in 0..a {
                // Signed indices so that `% 5 - 2` below can go negative.
                let (g, j) = (g0 as i64, j0 as i64);
                enorm[[g0, j0]] =
                    5.0 + (g % 4) as f64 * 0.3 + ((g * 3 + j * 2) % 5 - 2) as f64 * 0.15;
                eraw[[g0, j0]] =
                    50.0 + (g % 5) as f64 * 10.0 + j as f64 * 5.0 + ((g * 2 + j) % 7) as f64;
                nbeads[[g0, j0]] = 20.0 + ((g + j) % 10) as f64;
                bstdev[[g0, j0]] = 5.0 + (g % 3) as f64 * 1.5 + (j % 2) as f64 * 0.5;
            }
        }
        (enorm, eraw, nbeads, bstdev)
    }

    /// Intercept plus an indicator for the last two of four columns.
    fn design4() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]).unwrap()
    }

    /// Asserts every entry of an 8 × 4 matrix is `rclose` to `exp`.
    fn assert_mat(out: &Array2<f64>, exp: &[[f64; 4]; 8], label: &str) {
        for g in 0..8 {
            for j in 0..4 {
                assert!(
                    rclose(out[[g, j]], exp[g][j]),
                    "{label}[{g},{j}]: {} vs {}",
                    out[[g, j]],
                    exp[g][j]
                );
            }
        }
    }

    #[test]
    fn bead_count_weights_array_cv() {
        let (e, x, nb, sd) = fixture();
        let out = bead_count_weights(
            &e,
            &x,
            &design4(),
            &nb,
            BeadDispersion::Stdev(&sd),
            true,
            false,
        )
        .unwrap();
        let w = [
            [0.99782734558068886, 0.99858220483100923, 1.0017640081022272, 1.0018264414860747],
            [0.99795142109299873, 0.99864815134429352, 1.001677550700262, 1.001722876862446],
            [0.99686487564372694, 0.99790796071791255, 1.0025886936831236, 1.0026384699552373],
            [0.99868108910280529, 0.99911248623705162, 1.0010967180710966, 1.0011097065890469],
            [0.99718642392248746, 0.99808778549730581, 1.0023570293212056, 1.0023687612590013],
            [0.99835495001703034, 0.99887346538654376, 1.0013868359620184, 1.0013847486344079],
            [0.99842913135349864, 0.99891590226037741, 1.0013325480313711, 1.0013224183547531],
            [0.99880592784455446, 0.99954887746182364, 1.0033290822752707, 0.99831611241835094],
        ];
        assert_mat(&out.weights, &w, "bcw A");

        let bio = [
            0.072508375912570672,
            0.072343975538660149,
            0.04434429071263505,
            0.10062096340379205,
            0.044396127280804984,
            0.072621855179473602,
            0.072490525871691081,
            0.044404095553014722,
        ];
        for (g, &b) in bio.iter().enumerate() {
            assert!(rclose(out.var_biological[g], b), "bio[{g}]");
        }
        let cva = [
            0.091824134833759868,
            0.091067645980364265,
            0.078902066838947571,
            0.080362689765811415,
        ];
        for (j, &c) in cva.iter().enumerate() {
            assert!(rclose(out.cv_array[j], c), "cv.array[{j}]");
        }
        assert!(rclose(out.cv_constant, 0.085513043675327097), "cv.constant");

        assert!(rclose(out.var_technical[[0, 0]], 0.00087728608895408207));
        assert!(rclose(out.var_technical[[7, 3]], 0.00067198240487293759));
    }

    #[test]
    fn bead_count_weights_constant_cv() {
        let (e, x, nb, sd) = fixture();
        let out = bead_count_weights(
            &e,
            &x,
            &design4(),
            &nb,
            BeadDispersion::Stdev(&sd),
            false,
            false,
        )
        .unwrap();
        let w = [
            [0.99930095702677646, 0.99979572350253321, 1.0002459440283038, 1.0006573754423866],
            [0.99936238568389368, 0.99981251521384185, 1.000223863553922, 1.0006012355483425],
            [0.99905212502205476, 0.99971950893033579, 1.0003320697666715, 1.0008962962809378],
            [0.99961286553655182, 0.99988496416876738, 1.0001354289590132, 1.0003667413356676],
            [0.99919714223094125, 0.9997600888508188, 1.0002803012872656, 1.0007624676309743],
            [0.99954310018584613, 0.9998629549019481, 1.0001593024646029, 1.0004346424476027],
            [0.99957600343922848, 0.99987231147702482, 1.0001476147611972, 1.0004040703225496],
            [1.0007641431549719, 1.0012122896222646, 1.0016298942856547, 0.9963936729371089],
        ];
        assert_mat(&out.weights, &w, "bcw B");
    }

    #[test]
    fn bead_count_weights_scaled() {
        let (e, x, nb, sd) = fixture();
        let out = bead_count_weights(
            &e,
            &x,
            &design4(),
            &nb,
            BeadDispersion::Stdev(&sd),
            true,
            true,
        )
        .unwrap();
        let w = [
            [1438.3935054717908, 1439.4816542863277, 1444.0683046535266, 1444.1583039647787],
            [1505.0568251182424, 1506.1075963259523, 1510.6763740042441, 1510.7447324917509],
            [1569.8302304364456, 1571.4728467251539, 1578.843912040946, 1578.9222980876727],
            [1639.2234167489569, 1639.9315069412025, 1643.1884017819241, 1643.2097209828235],
            [1703.2035875341246, 1704.7431213976697, 1712.0350291317538, 1712.0550673894586],
            [1771.7107857931012, 1772.6309587964511, 1777.0912619753831, 1777.0875577408913],
            [1838.3587171634942, 1839.2549846218658, 1843.7046362595136, 1843.685984981493],
            [1719.8292154910105, 1721.108489493894, 1727.617568482578, 1718.9858095232457],
        ];
        assert_mat(&out.weights, &w, "bcw C");
    }

    /// Supplying `sd / sqrt(nbeads)` as standard errors must reproduce the
    /// weights obtained from the standard deviations directly.
    #[test]
    fn bead_count_weights_stderr_path() {
        let (e, x, nb, sd) = fixture();
        let stderr = Array2::from_shape_fn((8, 4), |(i, j)| sd[[i, j]] / nb[[i, j]].sqrt());
        let out = bead_count_weights(
            &e,
            &x,
            &design4(),
            &nb,
            BeadDispersion::Stderr(&stderr),
            true,
            false,
        )
        .unwrap();
        let ref_out = bead_count_weights(
            &e,
            &x,
            &design4(),
            &nb,
            BeadDispersion::Stdev(&sd),
            true,
            false,
        )
        .unwrap();
        for g in 0..8 {
            for j in 0..4 {
                assert!(
                    rclose(out.weights[[g, j]], ref_out.weights[[g, j]]),
                    "stderr[{g},{j}]"
                );
            }
        }
    }

    /// Every shape check rejects its input before any computation happens.
    #[test]
    fn bead_count_weights_rejects_mismatched_dimensions() {
        let (e, x, nb, sd) = fixture();
        let d = design4();
        let short = Array2::<f64>::zeros((7, 4));
        let bad_design = Array2::<f64>::ones((3, 2));

        let stdev = BeadDispersion::Stdev(&sd);
        assert!(bead_count_weights(&e, &short, &d, &nb, stdev, true, false).is_err());
        assert!(bead_count_weights(&e, &x, &bad_design, &nb, stdev, true, false).is_err());
        assert!(bead_count_weights(&e, &x, &d, &short, stdev, true, false).is_err());
        assert!(bead_count_weights(
            &e,
            &x,
            &d,
            &nb,
            BeadDispersion::Stdev(&short),
            true,
            false
        )
        .is_err());
        assert!(bead_count_weights(
            &e,
            &x,
            &d,
            &nb,
            BeadDispersion::Stderr(&short),
            true,
            false
        )
        .is_err());
    }

    #[test]
    fn biological_variance_rejects_negative_input() {
        let tv = Array2::from_elem((2, 3), 0.1);
        let mut r2 = Array2::from_elem((2, 3), 0.2);
        r2[[1, 2]] = -1.0;
        let h = [0.5; 3];
        assert!(ilmn_biological_variance(&tv, &r2, &h).is_err());
    }

    /// With zero residuals the score at zero is negative, so every row keeps 0.
    #[test]
    fn biological_variance_is_zero_when_score_not_positive() {
        let tv = Array2::from_elem((2, 3), 0.1);
        let r2 = Array2::<f64>::zeros((2, 3));
        let h = [0.5; 3];
        let bv = ilmn_biological_variance(&tv, &r2, &h).unwrap();
        assert_eq!(bv.len(), 2);
        assert!(bv.iter().all(|&b| b == 0.0));
    }

    #[test]
    fn trimmed_mean_drops_nan_and_trims_tails() {
        let xs = [1.0, f64::NAN, 2.0, 3.0, 100.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert_eq!(trimmed_mean_narm(&xs, 0.125), 5.0);
        assert!(trimmed_mean_narm(&[], 0.125).is_nan());
    }
}