1use anyhow::{bail, Result};
8use ndarray::{Array1, Array2, Axis};
9
10use crate::linalg::{eigh, matrix_rank, qr_full_q};
11
/// Result of [`duplicate_correlation`].
///
/// When the input has no duplicates (or every block holds a single array) the
/// consensus is `0.0` and every per-gene entry is `0.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct DupCorOutput {
    /// `tanh` of the trimmed mean of the finite entries of `atanh_correlations`.
    pub consensus_correlation: f64,
    /// Per-gene correlations on the Fisher (`atanh`) scale; `NaN` for genes whose
    /// correlation could not be estimated.
    pub atanh_correlations: Vec<f64>,
}
21
22pub fn unwrapdups(m: &Array2<f64>, ndups: usize, spacing: usize) -> Array2<f64> {
26 if ndups == 1 {
27 return m.clone();
28 }
29 let nspots = m.nrows();
30 let nslides = m.ncols();
31 let ngroups = nspots / ndups / spacing;
32 let mut out = Array2::<f64>::zeros((spacing * ngroups, ndups * nslides));
33 for s in 0..spacing {
34 for d in 0..ndups {
35 for g in 0..ngroups {
36 let src_row = s + spacing * d + spacing * ndups * g;
37 for sl in 0..nslides {
38 out[[s + spacing * g, d + ndups * sl]] = m[[src_row, sl]];
39 }
40 }
41 }
42 }
43 out
44}
45
46fn block_indicator(groups: &[i64]) -> (Array2<f64>, usize) {
49 let mut levels: Vec<i64> = groups.to_vec();
50 levels.sort_unstable();
51 levels.dedup();
52 let n = groups.len();
53 let k = levels.len();
54 let mut z = Array2::<f64>::zeros((n, k));
55 for (r, &g) in groups.iter().enumerate() {
56 let c = levels.binary_search(&g).unwrap();
57 z[[r, c]] = 1.0;
58 }
59 (z, k)
60}
61
62pub(crate) fn kron_rows(design: &Array2<f64>, ndups: usize) -> Array2<f64> {
65 let (nr, nc) = design.dim();
66 let mut out = Array2::<f64>::zeros((nr * ndups, nc));
67 for sl in 0..nr {
68 for d in 0..ndups {
69 out.row_mut(sl * ndups + d).assign(&design.row(sl));
70 }
71 }
72 out
73}
74
75pub fn avedups(
81 x: &Array2<f64>,
82 ndups: usize,
83 spacing: usize,
84 weights: Option<&Array2<f64>>,
85) -> Array2<f64> {
86 if ndups == 1 {
87 return x.clone();
88 }
89 let nspots = x.nrows();
90 let nslides = x.ncols();
91 let ngroups = nspots / ndups / spacing;
92 let mut out = Array2::<f64>::zeros((spacing * ngroups, nslides));
93 for s in 0..spacing {
94 for g in 0..ngroups {
95 let rr = s + spacing * g;
96 for sl in 0..nslides {
97 let mut num = 0.0;
98 let mut den = 0.0;
99 let mut cnt = 0usize;
100 for d in 0..ndups {
101 let v = x[[s + spacing * d + spacing * ndups * g, sl]];
102 match weights {
103 None => {
104 if v.is_finite() {
105 num += v;
106 cnt += 1;
107 }
108 }
109 Some(w) => {
110 let mut wt = w[[s + spacing * d + spacing * ndups * g, sl]];
111 if wt.is_nan() || v.is_nan() || wt < 0.0 {
112 wt = 0.0;
113 }
114 if v.is_finite() {
115 num += wt * v;
116 }
117 den += wt;
118 }
119 }
120 }
121 out[[rr, sl]] = if weights.is_some() {
122 num / den } else if cnt > 0 {
124 num / cnt as f64
125 } else {
126 f64::NAN
127 };
128 }
129 }
130 }
131 out
132}
133
/// Keeps one entry per gene from a list laid out with `ndups` duplicate spots:
/// the first duplicate (`d = 0`) of every offset `s < spacing` in every complete
/// group, in group-major order.
///
/// When `ndups <= 1` the whole list is returned. Entries that do not complete a
/// group are dropped.
///
/// # Panics
/// Panics on division by zero if `spacing == 0` while `ndups > 1`.
pub fn uniquegenelist<T: Clone>(genelist: &[T], ndups: usize, spacing: usize) -> Vec<T> {
    if ndups > 1 {
        let ngroups = genelist.len() / ndups / spacing;
        (0..ngroups)
            .flat_map(|group| {
                (0..spacing).map(move |offset| genelist[offset + spacing * ndups * group].clone())
            })
            .collect()
    } else {
        genelist.to_vec()
    }
}
148
149struct Mm2Prep {
153 m: Array2<f64>,
154 d: Vec<f64>,
155 refine: bool,
156}
157
158fn mm2_prep(x: &Array2<f64>, z: &Array2<f64>) -> Option<Mm2Prep> {
163 let n = x.nrows();
164 let p = x.ncols();
165 if matrix_rank(x) < p {
166 return None;
167 }
168 let mq = n - p;
169 if mq == 0 {
170 return None;
171 }
172 let q = qr_full_q(x);
173 let q2 = q.slice(ndarray::s![.., p..n]).to_owned();
174 let qtz = q2.t().dot(z);
175 let s_mat = qtz.dot(&qtz.t());
176 let (evals, evecs) = eigh(&s_mat);
177 let d: Vec<f64> = evals.iter().map(|&e| e.max(0.0)).collect();
178 let w = q2.dot(&evecs);
179 let m = w.t().to_owned();
180 let nnz = d.iter().filter(|&&v| v.abs() > 1e-15).count();
181 let refine = mq > 2 && nnz > 1 && sample_var(&d) > 1e-15;
182 Some(Mm2Prep { m, d, refine })
183}
184
185fn mm2_varcomp(prep: &Mm2Prep, y: &[f64]) -> Option<(f64, f64)> {
188 let yv = Array1::from(y.to_vec());
189 let uqy = prep.m.dot(&yv);
190 let dy: Vec<f64> = uqy.iter().map(|&u| u * u).collect();
191 let (c0, c1, fitted) = ols2(&prep.d, &dy)?;
192 if !prep.refine {
193 return Some((c0, c1));
194 }
195 let start = if fitted.iter().all(|&f| f >= 0.0) {
196 (c0, c1)
197 } else {
198 (mean(&dy), 0.0)
199 };
200 glmgam_fit2(&prep.d, &dy, start, 1e-6, 20)
201}
202
/// Simple linear regression `dy ≈ c0 + c1 * d` through the normal equations.
///
/// Returns `(c0, c1, fitted)`, where `fitted[k] = c0 + c1 * d[k]`, or `None` when
/// the normal-equation determinant is (near) zero, i.e. `d` has no spread.
fn ols2(d: &[f64], dy: &[f64]) -> Option<(f64, f64, Vec<f64>)> {
    let count = d.len() as f64;
    let sum_d: f64 = d.iter().sum();
    let sum_dd: f64 = d.iter().map(|&x| x * x).sum();
    let sum_y: f64 = dy.iter().sum();
    let sum_dy: f64 = d.iter().zip(dy.iter()).map(|(&x, &y)| x * y).sum();
    let determinant = count * sum_dd - sum_d * sum_d;
    let singular = determinant.abs() < 1e-300;
    if singular {
        None
    } else {
        let intercept = (sum_dd * sum_y - sum_d * sum_dy) / determinant;
        let slope = (count * sum_dy - sum_d * sum_y) / determinant;
        let fitted = d.iter().map(|&x| intercept + slope * x).collect();
        Some((intercept, slope, fitted))
    }
}
220
221fn glmgam_fit2(
226 d: &[f64],
227 dy: &[f64],
228 start: (f64, f64),
229 tol: f64,
230 maxit: usize,
231) -> Option<(f64, f64)> {
232 let (mut b0, mut b1) = start;
233 let mu_of = |b0: f64, b1: f64| -> Vec<f64> { d.iter().map(|&v| b0 + b1 * v).collect() };
234 let mut mu = mu_of(b0, b1);
235 if mu.iter().any(|&m| m < 0.0) {
236 return None;
237 }
238 let mut dev = deviance_gamma(dy, &mu);
239 let mut lambda = 0.0_f64;
240 let mut iter = 0usize;
241 loop {
242 iter += 1;
243 let mut v: Vec<f64> = mu.iter().map(|&m| m * m).collect();
245 let vmax = v.iter().cloned().fold(0.0_f64, f64::max);
246 let vfloor = vmax / 1e3;
247 for vi in v.iter_mut() {
248 *vi = vi.max(vfloor);
249 }
250 let a00: f64 = v.iter().map(|&vi| 1.0 / vi).sum();
252 let a01: f64 = d.iter().zip(&v).map(|(&dk, &vi)| dk / vi).sum();
253 let a11: f64 = d.iter().zip(&v).map(|(&dk, &vi)| dk * dk / vi).sum();
254 let maxinfo = a00.max(a11);
255 if iter == 1 {
256 lambda = ((a00 + a11) / 2.0).abs() / 2.0;
257 }
258 let dl0: f64 = dy
260 .iter()
261 .zip(&mu)
262 .zip(&v)
263 .map(|((&yk, &mk), &vi)| (yk - mk) / vi)
264 .sum();
265 let dl1: f64 = d
266 .iter()
267 .zip(dy.iter().zip(&mu).zip(&v))
268 .map(|(&dk, ((&yk, &mk), &vi))| dk * (yk - mk) / vi)
269 .sum();
270 let (b0_old, b1_old, dev_old) = (b0, b1, dev);
271 let mut lev = 0usize;
272 let mut dbeta;
273 loop {
274 lev += 1;
275 let det = (a00 + lambda) * (a11 + lambda) - a01 * a01;
276 let db0 = ((a11 + lambda) * dl0 - a01 * dl1) / det;
277 let db1 = ((a00 + lambda) * dl1 - a01 * dl0) / det;
278 dbeta = (db0, db1);
279 b0 = b0_old + db0;
280 b1 = b1_old + db1;
281 mu = mu_of(b0, b1);
282 dev = deviance_gamma(dy, &mu);
283 let max_mu = mu.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
284 if dev <= dev_old || dev / max_mu < 1e-15 {
285 break;
286 }
287 if lambda / maxinfo > 1e15 {
288 b0 = b0_old;
289 b1 = b1_old;
290 break;
291 }
292 lambda *= 2.0;
293 }
294 if lambda / maxinfo > 1e15 {
295 break;
296 }
297 if lev == 1 {
298 lambda /= 10.0;
299 }
300 let max_mu = mu.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
301 if dl0 * dbeta.0 + dl1 * dbeta.1 < tol || dev / max_mu < 1e-15 {
302 break;
303 }
304 if iter > maxit {
305 break;
306 }
307 }
308 Some((b0, b1))
309}
310
/// Gamma-family deviance `2 Σ [(y - mu)/mu - ln(y/mu)]`.
///
/// Returns `+∞` if any mean is negative. Pairs where both `y` and `mu` are below
/// `1e-15` contribute nothing.
fn deviance_gamma(y: &[f64], mu: &[f64]) -> f64 {
    if mu.iter().any(|&m| m < 0.0) {
        return f64::INFINITY;
    }
    let total = y
        .iter()
        .zip(mu)
        .filter(|&(&yk, &mk)| !(yk < 1e-15 && mk < 1e-15))
        .map(|(&yk, &mk)| (yk - mk) / mk - (yk / mk).ln())
        .fold(0.0, |acc, term| acc + term);
    2.0 * total
}
326
/// Arithmetic mean of `v`; `NaN` for an empty slice (0 / 0).
fn mean(v: &[f64]) -> f64 {
    let total: f64 = v.iter().copied().sum();
    total / (v.len() as f64)
}
330
331fn sample_var(v: &[f64]) -> f64 {
332 let n = v.len();
333 if n < 2 {
334 return 0.0;
335 }
336 let m = mean(v);
337 v.iter().map(|&x| (x - m) * (x - m)).sum::<f64>() / (n - 1) as f64
338}
339
/// Returns `tanh` of the trimmed mean of the finite values in `arho`. This maps a
/// trimmed mean of atanh-scale correlations back to the correlation scale.
///
/// `floor(n * trim)` values are dropped from each end of the sorted finite values,
/// where `n` is their count. Following R's `mean(x, trim = ...)`, a `trim` of 0.5
/// or more uses the median. Trimming would otherwise leave an empty (`NaN`) or
/// inverted (panicking) range. A negative or `NaN` trim means no trimming.
/// Returns `NaN` when there are no finite values.
fn trimmed_atanh_mean(arho: &[f64], trim: f64) -> f64 {
    let mut x: Vec<f64> = arho.iter().copied().filter(|v| v.is_finite()).collect();
    let n = x.len();
    if n == 0 {
        return f64::NAN;
    }
    x.sort_by(f64::total_cmp);
    let centre = if trim >= 0.5 {
        let half = n / 2;
        if n % 2 == 1 {
            x[half]
        } else {
            (x[half - 1] + x[half]) / 2.0
        }
    } else {
        // trim < 0.5 guarantees 2 * g < n, so `kept` is non-empty. Negative and
        // NaN trims saturate to 0 in the cast.
        let g = (n as f64 * trim).floor() as usize;
        let kept = &x[g..n - g];
        kept.iter().sum::<f64>() / kept.len() as f64
    };
    centre.tanh()
}
352
/// Number of distinct values in `groups`.
fn unique_count(groups: &[i64]) -> usize {
    groups
        .iter()
        .collect::<std::collections::HashSet<_>>()
        .len()
}
359
360pub fn duplicate_correlation(
364 exprs: &Array2<f64>,
365 design: &Array2<f64>,
366 ndups: usize,
367 spacing: usize,
368 block: Option<&[i64]>,
369 trim: f64,
370) -> Result<DupCorOutput> {
371 let narrays = exprs.ncols();
372 if design.nrows() != narrays {
373 bail!("number of rows of design does not match number of arrays");
374 }
375 let nbeta = design.ncols();
376
377 let zero_result = |ngenes: usize| DupCorOutput {
378 consensus_correlation: 0.0,
379 atanh_correlations: vec![0.0; ngenes],
380 };
381
382 let (m_mat, design2, groups, rhomin) = if let Some(block) = block {
383 if block.len() != narrays {
384 bail!("length of block does not match number of arrays");
385 }
386 let max_block = {
387 let mut counts = std::collections::HashMap::<i64, usize>::new();
388 for &b in block {
389 *counts.entry(b).or_insert(0) += 1;
390 }
391 counts.values().copied().max().unwrap_or(0)
392 };
393 if max_block == 1 {
394 return Ok(zero_result(exprs.nrows()));
395 }
396 let rhomin = 1.0 / (1.0 - max_block as f64) + 0.01;
397 (exprs.clone(), design.clone(), block.to_vec(), rhomin)
398 } else {
399 if ndups < 2 {
400 return Ok(zero_result(exprs.nrows()));
401 }
402 let m_mat = unwrapdups(exprs, ndups, spacing);
403 let design2 = kron_rows(design, ndups);
404 let groups: Vec<i64> = (0..narrays)
405 .flat_map(|a| std::iter::repeat_n(a as i64, ndups))
406 .collect();
407 let rhomin = 1.0 / (1.0 - ndups as f64) + 0.01;
408 (m_mat, design2, groups, rhomin)
409 };
410
411 let ngenes = m_mat.nrows();
412 let ncols = m_mat.ncols();
413 let (full_z, _) = block_indicator(&groups);
414 let full_prep = mm2_prep(&design2, &full_z);
415
416 let mut rho = vec![f64::NAN; ngenes];
417 for i in 0..ngenes {
418 let yrow: Vec<f64> = m_mat.row(i).to_vec();
419 let obs: Vec<usize> = (0..ncols).filter(|&k| yrow[k].is_finite()).collect();
420 let nobs = obs.len();
421 let groups_o: Vec<i64> = obs.iter().map(|&k| groups[k]).collect();
422 let nblocks = unique_count(&groups_o);
423 if !(nobs > nbeta + 2 && nblocks > 1 && nblocks < nobs - 1) {
424 continue;
425 }
426 let varcomp = if nobs == ncols {
427 full_prep.as_ref().and_then(|p| mm2_varcomp(p, &yrow))
428 } else {
429 let ysub: Vec<f64> = obs.iter().map(|&k| m_mat[[i, k]]).collect();
430 let xsub = design2.select(Axis(0), &obs);
431 let (zsub, _) = block_indicator(&groups_o);
432 mm2_prep(&xsub, &zsub)
433 .as_ref()
434 .and_then(|p| mm2_varcomp(p, &ysub))
435 };
436 if let Some((res, blk)) = varcomp {
437 rho[i] = blk / (res + blk);
438 }
439 }
440
441 let rhomax = 0.99;
444 let min_incl0 = rho
445 .iter()
446 .copied()
447 .filter(|v| v.is_finite())
448 .fold(0.0_f64, f64::min);
449 if min_incl0 < rhomin {
450 for r in rho.iter_mut() {
451 if r.is_finite() && *r < rhomin {
452 *r = rhomin;
453 }
454 }
455 }
456 let max_incl0 = rho
457 .iter()
458 .copied()
459 .filter(|v| v.is_finite())
460 .fold(0.0_f64, f64::max);
461 if max_incl0 > rhomax {
462 for r in rho.iter_mut() {
463 if r.is_finite() && *r > rhomax {
464 *r = rhomax;
465 }
466 }
467 }
468
469 let arho: Vec<f64> = rho.iter().map(|&r| r.atanh()).collect();
470 let consensus = trimmed_atanh_mean(&arho, trim);
471 Ok(DupCorOutput {
472 consensus_correlation: consensus,
473 atanh_correlations: arho,
474 })
475}
476
// Regression tests. The `*_matches_r` names indicate the expected values were
// produced with R; they are not re-derived here.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Element-wise absolute-tolerance comparison of two equal-length slices.
    fn assert_vec_close(got: &[f64], want: &[f64], tol: f64) {
        assert_eq!(got.len(), want.len());
        for (a, b) in got.iter().zip(want) {
            assert!((a - b).abs() < tol, "got {a} want {b}");
        }
    }

    // Ten genes on eight arrays forming four blocks of two arrays, two-group design.
    #[test]
    fn block_correlation_matches_r() {
        let exprs = array![
            [5.10, 5.30, 6.20, 6.00, 7.10, 7.40, 4.10, 4.30],
            [2.30, 2.10, 3.80, 3.50, 2.90, 3.10, 5.50, 5.20],
            [7.70, 7.90, 8.10, 8.40, 6.90, 6.70, 7.20, 7.50],
            [1.10, 1.40, 0.90, 1.20, 2.10, 1.90, 3.10, 2.80],
            [9.30, 9.10, 8.80, 9.00, 9.50, 9.70, 8.20, 8.40],
            [4.40, 4.10, 5.20, 5.50, 4.90, 4.60, 6.10, 6.40],
            [6.60, 6.40, 6.90, 7.10, 5.80, 5.50, 6.20, 6.50],
            [3.30, 3.60, 3.10, 2.80, 4.40, 4.70, 2.90, 2.60],
            [8.10, 8.40, 7.60, 7.90, 8.80, 8.50, 7.10, 7.40],
            [0.50, 0.80, 1.20, 0.90, 0.30, 0.60, 1.50, 1.20],
        ];
        let design = array![
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
        ];
        let block = [1, 1, 2, 2, 3, 3, 4, 4];
        // ndups/spacing (2, 1) are not used on the `block` path.
        let out = duplicate_correlation(&exprs, &design, 2, 1, Some(&block), 0.15).unwrap();
        let want_atanh = [
            2.630357195885,
            2.382400165790,
            1.025085579690,
            1.249128991481,
            1.897744594586,
            1.824607098861,
            1.216131458151,
            1.828923672433,
            1.600469062091,
            1.188743200584,
        ];
        assert_vec_close(&out.atanh_correlations, &want_atanh, 1e-6);
        assert!((out.consensus_correlation - 0.928654049014294).abs() < 1e-6);
    }

    // Twelve spots = six genes spotted twice on adjacent rows (ndups 2, spacing 1).
    #[test]
    fn ndups_correlation_matches_r() {
        let exprs = array![
            [5.1, 4.8, 6.2, 5.5],
            [5.3, 5.0, 6.0, 5.7],
            [2.3, 3.1, 2.8, 3.5],
            [2.1, 2.9, 3.0, 3.3],
            [7.7, 7.2, 8.1, 6.9],
            [7.9, 7.4, 7.8, 7.1],
            [1.1, 0.9, 1.4, 1.2],
            [1.3, 1.1, 1.2, 1.0],
            [9.3, 9.1, 8.8, 9.5],
            [9.1, 8.9, 9.0, 9.3],
            [4.4, 4.9, 5.2, 4.1],
            [4.6, 4.7, 5.0, 4.3],
        ];
        let design = array![[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]];
        let out = duplicate_correlation(&exprs, &design, 2, 1, None, 0.15).unwrap();
        let want_atanh = [
            1.070033081748,
            1.551171004306,
            1.544437802637,
            0.346573590280,
            0.990500734433,
            1.556757654605,
        ];
        assert_vec_close(&out.atanh_correlations, &want_atanh, 1e-6);
        assert!((out.consensus_correlation - 0.826369823215032).abs() < 1e-6);
    }

    // Rows 0/1 and 2/3 are duplicate pairs; each pair becomes one row with the
    // two duplicates of every array in adjacent columns.
    #[test]
    fn unwrapdups_pairs_consecutive_rows() {
        let m = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let u = unwrapdups(&m, 2, 1);
        assert_eq!(u, array![[1.0, 3.0, 2.0, 4.0], [5.0, 7.0, 6.0, 8.0]]);
    }

    // Short alias so NaN entries fit inside `array!` literals.
    fn nan() -> f64 {
        f64::NAN
    }

    // Shape check plus element-wise comparison in which NaN matches NaN.
    fn assert_mat_close(got: &Array2<f64>, want: &Array2<f64>) {
        assert_eq!(got.dim(), want.dim());
        for (a, b) in got.iter().zip(want.iter()) {
            let ok = (a.is_nan() && b.is_nan()) || (a - b).abs() < 1e-12;
            assert!(ok, "got {a} want {b}");
        }
    }

    // Shared 8x3 fixture: values with NaN gaps and weights including 0 and a
    // negative entry.
    fn avedups_data() -> (Array2<f64>, Array2<f64>) {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, nan(), 15.0],
            [16.0, 17.0, 18.0],
            [nan(), nan(), 21.0],
            [22.0, 23.0, 24.0],
        ];
        let w = array![
            [1.0, 2.0, 0.5],
            [3.0, 1.0, 2.0],
            [0.0, 1.5, 1.0],
            [2.0, 2.0, 2.0],
            [1.0, 1.0, -1.0],
            [4.0, 1.0, 1.0],
            [1.0, 1.0, 1.0],
            [2.0, 0.0, 3.0],
        ];
        (x, w)
    }

    // Unweighted averaging under three (ndups, spacing) layouts.
    #[test]
    fn avedups_unweighted_matches_r() {
        let (x, _) = avedups_data();
        assert_mat_close(
            &avedups(&x, 2, 1, None),
            &array![
                [2.5, 3.5, 4.5],
                [8.5, 9.5, 10.5],
                [14.5, 17.0, 16.5],
                [22.0, 23.0, 22.5],
            ],
        );
        assert_mat_close(
            &avedups(&x, 2, 2, None),
            &array![
                [4.0, 5.0, 6.0],
                [7.0, 8.0, 9.0],
                [13.0, nan(), 18.0],
                [19.0, 20.0, 21.0],
            ],
        );
        assert_mat_close(
            &avedups(&x, 4, 1, None),
            &array![[5.5, 6.5, 7.5], [17.0, 20.0, 19.5]],
        );
    }

    // Weighted averaging; includes a group whose weights sum to zero (NaN result).
    #[test]
    fn avedups_weighted_matches_r() {
        let (x, w) = avedups_data();
        assert_mat_close(
            &avedups(&x, 2, 1, Some(&w)),
            &array![
                [3.25, 3.0, 5.4],
                [10.0, 9.71428571428571, 11.0],
                [15.4, 17.0, 18.0],
                [22.0, nan(), 23.25],
            ],
        );
        assert_mat_close(
            &avedups(&x, 2, 2, Some(&w)),
            &array![
                [1.0, 4.57142857142857, 7.0],
                [6.4, 9.0, 9.0],
                [13.0, nan(), 21.0],
                [18.0, 17.0, 22.5],
            ],
        );
    }

    // Keeps the first duplicate of each gene; ndups == 1 returns the list unchanged.
    #[test]
    fn uniquegenelist_matches_r() {
        let g: Vec<i32> = (1..=8).collect();
        assert_eq!(uniquegenelist(&g, 2, 1), vec![1, 3, 5, 7]);
        assert_eq!(uniquegenelist(&g, 2, 2), vec![1, 2, 5, 6]);
        assert_eq!(uniquegenelist(&g, 4, 1), vec![1, 5]);
        assert_eq!(uniquegenelist(&g, 1, 1), g);
    }
}