use anyhow::{bail, Result};
use ndarray::{Array1, Array2, Axis};
use crate::linalg::{eigh, matrix_rank, qr_full_q};
/// Result of [`duplicate_correlation`].
#[derive(Debug, Clone, PartialEq)]
pub struct DupCorOutput {
    /// Consensus correlation: `tanh` of the trimmed mean of the finite
    /// per-gene `atanh` correlations. It is `0.0` when there is nothing to
    /// correlate, and NaN when no gene could be fitted.
    pub consensus_correlation: f64,
    /// Per-gene correlations on the Fisher `atanh` scale. The value is NaN for
    /// genes that were skipped or whose variance-component fit failed.
    pub atanh_correlations: Vec<f64>,
}
/// Rearranges a spots × arrays matrix so that the `ndups` within-array
/// duplicates of each gene sit side by side on one row.
///
/// Duplicates of a gene are `spacing` rows apart in `m`. With
/// `ngroups = nrows / ndups / spacing` (integer division), the output has
/// `spacing * ngroups` rows and `ndups * ncols` columns. Output row
/// `s + spacing * g`, column `d + ndups * a` holds input row
/// `s + spacing * d + spacing * ndups * g`, column `a`.
///
/// When `ndups == 1`, a copy of `m` is returned unchanged.
///
/// # Panics
/// Panics (division by zero) if `ndups == 0`, or if `spacing == 0` while
/// `ndups != 1`.
pub fn unwrapdups(m: &Array2<f64>, ndups: usize, spacing: usize) -> Array2<f64> {
    if ndups == 1 {
        return m.clone();
    }
    let (nspots, nslides) = m.dim();
    let ngroups = nspots / ndups / spacing;
    Array2::from_shape_fn((spacing * ngroups, ndups * nslides), |(row, col)| {
        // Decompose the output coordinates back into (offset, group) and
        // (duplicate, array), then locate the matching input spot.
        let (offset, group) = (row % spacing, row / spacing);
        let (dup, slide) = (col % ndups, col / ndups);
        m[[offset + spacing * dup + spacing * ndups * group, slide]]
    })
}
/// Builds the 0/1 indicator matrix for the group labels in `groups`.
///
/// Columns correspond to the distinct labels in ascending order. Entry
/// `(r, c)` is `1.0` exactly when observation `r` carries the `c`-th label.
/// Also returns the number of distinct labels, which is the column count.
fn block_indicator(groups: &[i64]) -> (Array2<f64>, usize) {
    let mut labels = groups.to_vec();
    labels.sort_unstable();
    labels.dedup();
    let nlabels = labels.len();
    let indicator = Array2::from_shape_fn((groups.len(), nlabels), |(row, col)| {
        if groups[row] == labels[col] {
            1.0
        } else {
            0.0
        }
    });
    (indicator, nlabels)
}
/// Repeats every row of `design` `ndups` times in a row, giving a
/// `(nrows * ndups) × ncols` matrix.
///
/// This equals the Kronecker product of `design` with a column of `ndups`
/// ones. Output row `r` is input row `r / ndups`.
pub(crate) fn kron_rows(design: &Array2<f64>, ndups: usize) -> Array2<f64> {
    let (nrows, ncols) = design.dim();
    Array2::from_shape_fn((nrows * ndups, ncols), |(row, col)| design[[row / ndups, col]])
}
/// Averages the `ndups` within-array duplicate spots of each gene.
///
/// Duplicates of a gene are `spacing` rows apart. The result has
/// `spacing * (nrows / ndups / spacing)` rows (integer division) and the same
/// number of columns as `x`. When `ndups == 1`, a copy of `x` is returned.
///
/// * **Unweighted:** the mean of the non-NaN duplicates, matching R's
///   `colMeans(na.rm = TRUE)`. It is NaN when every duplicate is NaN.
///   Infinite values are kept and propagate into the mean.
/// * **Weighted:** `Σ w·x / Σ w`, matching R's
///   `colSums(w * x, na.rm = TRUE) / colSums(w)`.
///   - A weight is set to zero when it is NaN or negative, or when its value
///     is NaN.
///   - NaN products such as `0 · ∞` are left out of the numerator.
///   - The result is NaN when every effective weight is zero.
///
/// # Panics
/// Panics if `ndups` is zero, or if `spacing` is zero while `ndups != 1`
/// (both cause a division by zero). Also panics if `weights` does not have
/// the same shape as `x`.
pub fn avedups(
    x: &Array2<f64>,
    ndups: usize,
    spacing: usize,
    weights: Option<&Array2<f64>>,
) -> Array2<f64> {
    if ndups == 1 {
        return x.clone();
    }
    if let Some(w) = weights {
        assert_eq!(w.dim(), x.dim(), "weights must have the same shape as x");
    }
    let nslides = x.ncols();
    let ngroups = x.nrows() / ndups / spacing;
    let mut out = Array2::<f64>::zeros((spacing * ngroups, nslides));
    for s in 0..spacing {
        for g in 0..ngroups {
            let rr = s + spacing * g;
            for sl in 0..nslides {
                // Input rows holding the duplicates of output row `rr`.
                let src_rows = (0..ndups).map(|d| s + spacing * d + spacing * ndups * g);
                out[[rr, sl]] = match weights {
                    None => {
                        let (sum, cnt) = src_rows
                            .map(|row| x[[row, sl]])
                            .filter(|v| !v.is_nan())
                            .fold((0.0, 0usize), |(acc, n), v| (acc + v, n + 1));
                        if cnt > 0 {
                            sum / cnt as f64
                        } else {
                            f64::NAN
                        }
                    }
                    Some(w) => {
                        let (num, den) = src_rows.fold((0.0, 0.0), |(num, den), row| {
                            let v = x[[row, sl]];
                            let raw = w[[row, sl]];
                            let wt = if raw.is_nan() || v.is_nan() || raw < 0.0 {
                                0.0
                            } else {
                                raw
                            };
                            // Every weight counts in the denominator. Its
                            // product only counts in the numerator when it is
                            // not NaN, so an infinite value with a positive
                            // weight propagates instead of being dropped.
                            let prod = wt * v;
                            let num = if prod.is_nan() { num } else { num + prod };
                            (num, den + wt)
                        });
                        num / den
                    }
                };
            }
        }
    }
    out
}
/// Keeps one entry per gene from a list in which every gene appears `ndups`
/// times, with the copies `spacing` positions apart.
///
/// The first copy of each gene is kept, in the same row order that
/// `unwrapdups` produces. The result has
/// `spacing * (len / ndups / spacing)` entries (integer division, so an
/// incomplete trailing group is dropped). When `ndups <= 1`, a copy of the
/// whole list is returned.
///
/// # Panics
/// Panics (division by zero) if `spacing == 0` while `ndups > 1`.
pub fn uniquegenelist<T: Clone>(genelist: &[T], ndups: usize, spacing: usize) -> Vec<T> {
    if ndups < 2 {
        return genelist.to_vec();
    }
    let ngroups = genelist.len() / ndups / spacing;
    let stride = ndups * spacing;
    // Each group starts with `spacing` consecutive first copies.
    (0..ngroups)
        .flat_map(|g| genelist[g * stride..g * stride + spacing].iter().cloned())
        .collect()
}
/// Design-level quantities produced by `mm2_prep` and consumed by
/// `mm2_varcomp`.
///
/// They depend only on the design and the group-indicator matrix, so one
/// value can be reused for every response vector that shares them.
struct Mm2Prep {
    // Matrix applied to a response `y` as `m.dot(y)`: the transpose of
    // `q2 · evecs`, where `q2` holds the trailing columns of the full Q
    // factor and `evecs` the eigenvectors computed in `mm2_prep`.
    m: Array2<f64>,
    // The matching eigenvalues, clamped below at zero.
    d: Vec<f64>,
    // Whether `mm2_varcomp` should refine the least-squares estimate with
    // the damped gamma-GLM fit.
    refine: bool,
}
/// Precomputes the residual-space eigen-decomposition used by `mm2_varcomp`.
///
/// The last `n - p` columns of `qr_full_q(x)` (presumably an orthonormal
/// basis of the residual space of `x`) are used to project the group
/// indicator `z`. The function then eigen-decomposes `Q2ᵀ Z Zᵀ Q2`. Negative
/// eigenvalues are clamped to zero.
///
/// Returns `None` when `matrix_rank(x)` is below the column count of `x`, or
/// when `x` has no residual degrees of freedom.
///
/// The returned `refine` flag is set when all three hold:
/// - there are more than two residual degrees of freedom;
/// - more than one eigenvalue exceeds 1e-15 in magnitude;
/// - the eigenvalues' sample variance exceeds 1e-15.
///
/// NOTE(review): rank-deficient designs are rejected outright rather than
/// fitted at reduced rank; confirm this is intended.
fn mm2_prep(x: &Array2<f64>, z: &Array2<f64>) -> Option<Mm2Prep> {
    let (nobs, ncoef) = x.dim();
    if matrix_rank(x) < ncoef {
        return None;
    }
    let resid_df = nobs - ncoef;
    if resid_df == 0 {
        return None;
    }
    let q_resid = qr_full_q(x).slice(ndarray::s![.., ncoef..nobs]).to_owned();
    let z_proj = q_resid.t().dot(z);
    let (eigvals, eigvecs) = eigh(&z_proj.dot(&z_proj.t()));
    let d: Vec<f64> = eigvals.iter().map(|&e| e.max(0.0)).collect();
    let m = q_resid.dot(&eigvecs).t().to_owned();
    let nonzero = d.iter().filter(|&&v| v.abs() > 1e-15).count();
    let refine = resid_df > 2 && nonzero > 1 && sample_var(&d) > 1e-15;
    Some(Mm2Prep { m, d, refine })
}
/// Estimates the two variance components `(residual, block)` for one
/// response vector `y`, using the design-level quantities in `prep`.
///
/// 1. The squared projections `dy = (m · y)²` are regressed on `[1, d]` by
///    least squares.
/// 2. When `prep.refine` is set, that estimate is refined with
///    `glmgam_fit2` (tolerance 1e-6, at most 20 iterations).
///    - The refinement starts from the least-squares coefficients when all
///      their fitted values are non-negative.
///    - Otherwise it starts from `(mean(dy), 0)`.
///
/// Returns `None` if the least-squares system is singular or the GLM fit
/// fails.
fn mm2_varcomp(prep: &Mm2Prep, y: &[f64]) -> Option<(f64, f64)> {
    let projected = prep.m.dot(&Array1::from(y.to_vec()));
    let dy: Vec<f64> = projected.iter().map(|&u| u * u).collect();
    let (intercept, slope, fitted) = ols2(&prep.d, &dy)?;
    if prep.refine {
        let start = if fitted.iter().all(|&f| f >= 0.0) {
            (intercept, slope)
        } else {
            (mean(&dy), 0.0)
        };
        glmgam_fit2(&prep.d, &dy, start, 1e-6, 20)
    } else {
        Some((intercept, slope))
    }
}
/// Fits `dy ≈ c0 + c1 · d` by ordinary least squares using the 2×2 normal
/// equations.
///
/// Returns `(c0, c1, fitted)`, or `None` when the normal-equation
/// determinant is below 1e-300 in magnitude. That happens, for example, when
/// all `d` are equal or the input is empty.
fn ols2(d: &[f64], dy: &[f64]) -> Option<(f64, f64, Vec<f64>)> {
    // Normal equations: [[n, Σd], [Σd, Σd²]] · [c0, c1] = [Σy, Σd·y].
    let count = d.len() as f64;
    let sum_d: f64 = d.iter().sum();
    let sum_dd: f64 = d.iter().map(|&v| v * v).sum();
    let sum_y: f64 = dy.iter().sum();
    let sum_dy: f64 = d.iter().zip(dy).map(|(&v, &w)| v * w).sum();
    let det = count * sum_dd - sum_d * sum_d;
    if det.abs() < 1e-300 {
        return None;
    }
    // Cramer's rule.
    let intercept = (sum_dd * sum_y - sum_d * sum_dy) / det;
    let slope = (count * sum_dy - sum_d * sum_y) / det;
    let fitted = d.iter().map(|&v| intercept + slope * v).collect::<Vec<f64>>();
    Some((intercept, slope, fitted))
}
/// Fits the two-coefficient gamma GLM with identity link,
/// `E[dy] = b0 + b1 · d`, by Fisher scoring with Levenberg-style damping.
///
/// NOTE(review): the structure appears to mirror statmod's `glmgam.fit`
/// (damping schedule, variance floor, stopping rules). Confirm against that
/// routine before changing any constant.
///
/// * `d` — covariate value for each observation.
/// * `dy` — response for each observation.
/// * `start` — starting coefficients `(b0, b1)`.
/// * `tol` — convergence tolerance on `dlᵀ · dbeta` (score times step).
/// * `maxit` — iteration cap; the outer loop stops once `iter > maxit`.
///
/// Returns `None` only when `start` gives a negative fitted mean. Otherwise
/// returns the final `(b0, b1)`. If damping grows too large
/// (`lambda / maxinfo > 1e15`), the coefficients from before the rejected
/// step are returned.
fn glmgam_fit2(
    d: &[f64],
    dy: &[f64],
    start: (f64, f64),
    tol: f64,
    maxit: usize,
) -> Option<(f64, f64)> {
    let (mut b0, mut b1) = start;
    // Fitted means for the given coefficients.
    let mu_of = |b0: f64, b1: f64| -> Vec<f64> { d.iter().map(|&v| b0 + b1 * v).collect() };
    let mut mu = mu_of(b0, b1);
    if mu.iter().any(|&m| m < 0.0) {
        return None;
    }
    let mut dev = deviance_gamma(dy, &mu);
    let mut lambda = 0.0_f64;
    let mut iter = 0usize;
    loop {
        iter += 1;
        // Gamma variance function mu², floored at 1/1000 of its maximum so
        // that tiny fitted means do not dominate the weights.
        let mut v: Vec<f64> = mu.iter().map(|&m| m * m).collect();
        let vmax = v.iter().cloned().fold(0.0_f64, f64::max);
        let vfloor = vmax / 1e3;
        for vi in v.iter_mut() {
            *vi = vi.max(vfloor);
        }
        // Fisher information XᵀV⁻¹X for X = [1, d], stored as its three
        // distinct entries.
        let a00: f64 = v.iter().map(|&vi| 1.0 / vi).sum();
        let a01: f64 = d.iter().zip(&v).map(|(&dk, &vi)| dk / vi).sum();
        let a11: f64 = d.iter().zip(&v).map(|(&dk, &vi)| dk * dk / vi).sum();
        let maxinfo = a00.max(a11);
        // Initial damping: |mean of the information diagonal| / 2 (the
        // number of coefficients).
        if iter == 1 {
            lambda = ((a00 + a11) / 2.0).abs() / 2.0;
        }
        // Score vector Xᵀ (dy - mu) / v.
        let dl0: f64 = dy
            .iter()
            .zip(&mu)
            .zip(&v)
            .map(|((&yk, &mk), &vi)| (yk - mk) / vi)
            .sum();
        let dl1: f64 = d
            .iter()
            .zip(dy.iter().zip(&mu).zip(&v))
            .map(|(&dk, ((&yk, &mk), &vi))| dk * (yk - mk) / vi)
            .sum();
        let (b0_old, b1_old, dev_old) = (b0, b1, dev);
        let mut lev = 0usize;
        let mut dbeta;
        // Inner loop: try damped steps, doubling lambda until the deviance
        // does not increase.
        loop {
            lev += 1;
            // Solve (I + lambda·Id) · dbeta = dl with the explicit 2×2
            // inverse (Cramer's rule).
            let det = (a00 + lambda) * (a11 + lambda) - a01 * a01;
            let db0 = ((a11 + lambda) * dl0 - a01 * dl1) / det;
            let db1 = ((a00 + lambda) * dl1 - a01 * dl0) / det;
            dbeta = (db0, db1);
            b0 = b0_old + db0;
            b1 = b1_old + db1;
            mu = mu_of(b0, b1);
            dev = deviance_gamma(dy, &mu);
            // Accept the step if the deviance did not rise, or if it is
            // negligible relative to the largest fitted mean.
            let max_mu = mu.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            if dev <= dev_old || dev / max_mu < 1e-15 {
                break;
            }
            // Too much damping: restore the previous coefficients and give
            // up. `mu`/`dev` keep the rejected trial values, but only the
            // coefficients are returned.
            if lambda / maxinfo > 1e15 {
                b0 = b0_old;
                b1 = b1_old;
                break;
            }
            lambda *= 2.0;
        }
        // Leave the outer loop too if damping ran away.
        if lambda / maxinfo > 1e15 {
            break;
        }
        // The first trial step succeeded, so relax the damping.
        if lev == 1 {
            lambda /= 10.0;
        }
        // Converged: small score·step product, or negligible deviance.
        let max_mu = mu.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        if dl0 * dbeta.0 + dl1 * dbeta.1 < tol || dev / max_mu < 1e-15 {
            break;
        }
        if iter > maxit {
            break;
        }
    }
    Some((b0, b1))
}
/// Gamma deviance `2 · Σ [(y − μ)/μ − ln(y/μ)]` of responses `y` against
/// fitted means `mu`.
///
/// Pairs where both `y` and `μ` are below 1e-15 contribute nothing. Any
/// negative `μ` makes the deviance `+∞`.
fn deviance_gamma(y: &[f64], mu: &[f64]) -> f64 {
    let infeasible = mu.iter().any(|&m| m < 0.0);
    if infeasible {
        return f64::INFINITY;
    }
    let half = y
        .iter()
        .zip(mu)
        // Written as a negated conjunction (not De Morgan) so that NaN
        // inputs are still summed.
        .filter(|&(&yk, &mk)| !(yk < 1e-15 && mk < 1e-15))
        .fold(0.0, |acc, (&yk, &mk)| acc + ((yk - mk) / mk - (yk / mk).ln()));
    2.0 * half
}
/// Arithmetic mean of `v`. Returns NaN for an empty slice (0 / 0).
fn mean(v: &[f64]) -> f64 {
    let total: f64 = v.iter().sum();
    total / v.len() as f64
}
/// Sample variance of `v` with denominator `n − 1`. Returns `0.0` when `v`
/// has fewer than two elements.
fn sample_var(v: &[f64]) -> f64 {
    match v.len() {
        0 | 1 => 0.0,
        n => {
            let centre = v.iter().sum::<f64>() / n as f64;
            let sum_sq: f64 = v
                .iter()
                .map(|&x| {
                    let dev = x - centre;
                    dev * dev
                })
                .sum();
            sum_sq / (n - 1) as f64
        }
    }
}
/// Returns `tanh` of the trimmed mean of the finite entries of `arho`.
///
/// This follows R's `mean(x, trim = trim, na.rm = TRUE)`. After sorting,
/// `floor(n * trim)` values are dropped from each end.
/// - A non-positive `trim` means no trimming.
/// - As in R, `trim >= 0.5` falls back to the median. Without this, the kept
///   slice would be empty (giving NaN), or for `trim > 0.5` its start would
///   exceed its end and indexing would panic.
///
/// Returns NaN when `arho` has no finite entries.
fn trimmed_atanh_mean(arho: &[f64], trim: f64) -> f64 {
    let mut x: Vec<f64> = arho.iter().copied().filter(|v| v.is_finite()).collect();
    let n = x.len();
    if n == 0 {
        return f64::NAN;
    }
    x.sort_by(f64::total_cmp);
    let centre = if trim >= 0.5 {
        let mid = n / 2;
        if n % 2 == 1 {
            x[mid]
        } else {
            0.5 * (x[mid - 1] + x[mid])
        }
    } else {
        // For trim < 0.5, floor(n * trim) <= (n - 1) / 2 already. The `min`
        // only guards against floating-point edge cases so the slice is
        // never empty.
        let g = ((n as f64 * trim).floor() as usize).min((n - 1) / 2);
        let kept = &x[g..n - g];
        kept.iter().sum::<f64>() / kept.len() as f64
    };
    centre.tanh()
}
/// Number of distinct labels in `groups`.
fn unique_count(groups: &[i64]) -> usize {
    groups
        .iter()
        .collect::<std::collections::BTreeSet<_>>()
        .len()
}
/// Estimates, for every gene, the correlation between duplicate spots (or
/// between arrays in the same block), plus a robust consensus value.
///
/// Modelled on limma's `duplicateCorrelation` (per this module's R-comparison
/// tests). For each gene, the finite observations are passed to
/// `mm2_prep`/`mm2_varcomp` to estimate two variance components,
/// `(res, blk)`. The per-gene correlation is `blk / (res + blk)`.
///
/// Estimates are clamped to `[rhomin, 0.99]`, where
/// `rhomin = 1 / (1 - k) + 0.01` and `k` is `ndups` (or the largest block
/// size when `block` is given). They are then reported on the Fisher
/// `atanh` scale.
///
/// # Arguments
/// * `exprs` — expression matrix, spots × arrays.
/// * `design` — design matrix with one row per array (column of `exprs`).
/// * `ndups`, `spacing` — within-array duplicate layout (see
///   [`unwrapdups`]). Ignored when `block` is given.
/// * `block` — optional block label for each array. When present, the
///   correlation is between arrays in the same block instead of between
///   duplicate spots.
/// * `trim` — fraction trimmed from each end when averaging the `atanh`
///   correlations into the consensus.
///
/// # Returns
/// Genes that are skipped (too few finite observations or groups) or whose
/// fit fails get NaN. If there is nothing to correlate, every correlation
/// and the consensus are `0.0`. That happens when `ndups < 2` without
/// `block`, or when every block has size 1.
///
/// # Errors
/// Returns an error when any of the following holds:
/// - `design` does not have one row per array;
/// - `block` has the wrong length;
/// - without `block`, `spacing` is zero;
/// - without `block`, the number of rows of `exprs` is not a multiple of
///   `ndups * spacing`.
pub fn duplicate_correlation(
    exprs: &Array2<f64>,
    design: &Array2<f64>,
    ndups: usize,
    spacing: usize,
    block: Option<&[i64]>,
    trim: f64,
) -> Result<DupCorOutput> {
    let narrays = exprs.ncols();
    if design.nrows() != narrays {
        bail!("number of rows of design does not match number of arrays");
    }
    let nbeta = design.ncols();
    let zero_result = |ngenes: usize| DupCorOutput {
        consensus_correlation: 0.0,
        atanh_correlations: vec![0.0; ngenes],
    };
    // Lay the data out so that each column of `m_mat` is one observation and
    // `groups[k]` labels the random-effect group of column k.
    let (m_mat, design2, groups, rhomin) = if let Some(block) = block {
        if block.len() != narrays {
            bail!("length of block does not match number of arrays");
        }
        let max_block = {
            let mut counts = std::collections::HashMap::<i64, usize>::new();
            for &b in block {
                *counts.entry(b).or_insert(0) += 1;
            }
            counts.values().copied().max().unwrap_or(0)
        };
        if max_block == 1 {
            return Ok(zero_result(exprs.nrows()));
        }
        let rhomin = 1.0 / (1.0 - max_block as f64) + 0.01;
        (exprs.clone(), design.clone(), block.to_vec(), rhomin)
    } else {
        if ndups < 2 {
            return Ok(zero_result(exprs.nrows()));
        }
        // `unwrapdups` would divide by zero, or silently drop trailing rows
        // that do not form a complete duplicate group. Reject both up front.
        if spacing == 0 {
            bail!("spacing must be positive");
        }
        let nrows = exprs.nrows();
        if nrows % ndups != 0 || (nrows / ndups) % spacing != 0 {
            bail!("number of rows of exprs is not a multiple of ndups * spacing");
        }
        let m_mat = unwrapdups(exprs, ndups, spacing);
        let design2 = kron_rows(design, ndups);
        let groups: Vec<i64> = (0..narrays)
            .flat_map(|a| std::iter::repeat_n(a as i64, ndups))
            .collect();
        let rhomin = 1.0 / (1.0 - ndups as f64) + 0.01;
        (m_mat, design2, groups, rhomin)
    };
    let ngenes = m_mat.nrows();
    let ncols = m_mat.ncols();
    let (full_z, _) = block_indicator(&groups);
    // Genes with no missing values share one design/indicator pair, so it is
    // prepared once.
    let full_prep = mm2_prep(&design2, &full_z);
    let mut rho = vec![f64::NAN; ngenes];
    for i in 0..ngenes {
        let yrow: Vec<f64> = m_mat.row(i).to_vec();
        let obs: Vec<usize> = (0..ncols).filter(|&k| yrow[k].is_finite()).collect();
        let nobs = obs.len();
        let groups_o: Vec<i64> = obs.iter().map(|&k| groups[k]).collect();
        let nblocks = unique_count(&groups_o);
        // Require spare residual degrees of freedom, and at least two groups
        // but not (nearly) one group per observation. `nobs - 1` cannot
        // underflow because the first test guarantees nobs > 2.
        if !(nobs > nbeta + 2 && nblocks > 1 && nblocks < nobs - 1) {
            continue;
        }
        let varcomp = if nobs == ncols {
            full_prep.as_ref().and_then(|p| mm2_varcomp(p, &yrow))
        } else {
            let ysub: Vec<f64> = obs.iter().map(|&k| yrow[k]).collect();
            let xsub = design2.select(Axis(0), &obs);
            let (zsub, _) = block_indicator(&groups_o);
            mm2_prep(&xsub, &zsub)
                .as_ref()
                .and_then(|p| mm2_varcomp(p, &ysub))
        };
        if let Some((res, blk)) = varcomp {
            rho[i] = blk / (res + blk);
        }
    }
    // Clamp to [rhomin, rhomax]. NaN (gene skipped or fit failed) compares
    // false and is left alone. ±Inf, which arises when res + blk == 0, is
    // clamped too, instead of turning into NaN under `atanh`.
    let rhomax = 0.99;
    for r in rho.iter_mut() {
        if *r < rhomin {
            *r = rhomin;
        }
        if *r > rhomax {
            *r = rhomax;
        }
    }
    let arho: Vec<f64> = rho.iter().map(|&r| r.atanh()).collect();
    let consensus = trimmed_atanh_mean(&arho, trim);
    Ok(DupCorOutput {
        consensus_correlation: consensus,
        atanh_correlations: arho,
    })
}
// Unit tests comparing this module against fixed reference values. Per the
// test names, the expected numbers presumably come from the corresponding R
// functions — TODO confirm their provenance before editing them.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    // Asserts equal length and element-wise agreement within `tol`.
    fn assert_vec_close(got: &[f64], want: &[f64], tol: f64) {
        assert_eq!(got.len(), want.len());
        for (a, b) in got.iter().zip(want) {
            assert!((a - b).abs() < tol, "got {a} want {b}");
        }
    }
    // Block mode: 10 genes × 8 arrays, 2 groups in the design, arrays paired
    // into 4 blocks. `ndups`/`spacing` are passed but unused in block mode.
    #[test]
    fn block_correlation_matches_r() {
        let exprs = array![
            [5.10, 5.30, 6.20, 6.00, 7.10, 7.40, 4.10, 4.30],
            [2.30, 2.10, 3.80, 3.50, 2.90, 3.10, 5.50, 5.20],
            [7.70, 7.90, 8.10, 8.40, 6.90, 6.70, 7.20, 7.50],
            [1.10, 1.40, 0.90, 1.20, 2.10, 1.90, 3.10, 2.80],
            [9.30, 9.10, 8.80, 9.00, 9.50, 9.70, 8.20, 8.40],
            [4.40, 4.10, 5.20, 5.50, 4.90, 4.60, 6.10, 6.40],
            [6.60, 6.40, 6.90, 7.10, 5.80, 5.50, 6.20, 6.50],
            [3.30, 3.60, 3.10, 2.80, 4.40, 4.70, 2.90, 2.60],
            [8.10, 8.40, 7.60, 7.90, 8.80, 8.50, 7.10, 7.40],
            [0.50, 0.80, 1.20, 0.90, 0.30, 0.60, 1.50, 1.20],
        ];
        let design = array![
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
        ];
        let block = [1, 1, 2, 2, 3, 3, 4, 4];
        let out = duplicate_correlation(&exprs, &design, 2, 1, Some(&block), 0.15).unwrap();
        let want_atanh = [
            2.630357195885,
            2.382400165790,
            1.025085579690,
            1.249128991481,
            1.897744594586,
            1.824607098861,
            1.216131458151,
            1.828923672433,
            1.600469062091,
            1.188743200584,
        ];
        assert_vec_close(&out.atanh_correlations, &want_atanh, 1e-6);
        assert!((out.consensus_correlation - 0.928654049014294).abs() < 1e-6);
    }
    // Duplicate-spot mode: 12 spots = 6 genes × 2 adjacent duplicates,
    // 4 arrays.
    #[test]
    fn ndups_correlation_matches_r() {
        let exprs = array![
            [5.1, 4.8, 6.2, 5.5],
            [5.3, 5.0, 6.0, 5.7],
            [2.3, 3.1, 2.8, 3.5],
            [2.1, 2.9, 3.0, 3.3],
            [7.7, 7.2, 8.1, 6.9],
            [7.9, 7.4, 7.8, 7.1],
            [1.1, 0.9, 1.4, 1.2],
            [1.3, 1.1, 1.2, 1.0],
            [9.3, 9.1, 8.8, 9.5],
            [9.1, 8.9, 9.0, 9.3],
            [4.4, 4.9, 5.2, 4.1],
            [4.6, 4.7, 5.0, 4.3],
        ];
        let design = array![[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]];
        let out = duplicate_correlation(&exprs, &design, 2, 1, None, 0.15).unwrap();
        let want_atanh = [
            1.070033081748,
            1.551171004306,
            1.544437802637,
            0.346573590280,
            0.990500734433,
            1.556757654605,
        ];
        assert_vec_close(&out.atanh_correlations, &want_atanh, 1e-6);
        assert!((out.consensus_correlation - 0.826369823215032).abs() < 1e-6);
    }
    // Adjacent rows are merged; columns interleave as (dup, array).
    #[test]
    fn unwrapdups_pairs_consecutive_rows() {
        let m = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let u = unwrapdups(&m, 2, 1);
        assert_eq!(u, array![[1.0, 3.0, 2.0, 4.0], [5.0, 7.0, 6.0, 8.0]]);
    }
    fn nan() -> f64 {
        f64::NAN
    }
    // Matrix comparison that treats NaN == NaN as a match.
    fn assert_mat_close(got: &Array2<f64>, want: &Array2<f64>) {
        assert_eq!(got.dim(), want.dim());
        for (a, b) in got.iter().zip(want.iter()) {
            let ok = (a.is_nan() && b.is_nan()) || (a - b).abs() < 1e-12;
            assert!(ok, "got {a} want {b}");
        }
    }
    // Shared fixture. It includes NaN values, a zero weight and a negative
    // weight to exercise the weight-zeroing rules.
    fn avedups_data() -> (Array2<f64>, Array2<f64>) {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, nan(), 15.0],
            [16.0, 17.0, 18.0],
            [nan(), nan(), 21.0],
            [22.0, 23.0, 24.0],
        ];
        let w = array![
            [1.0, 2.0, 0.5],
            [3.0, 1.0, 2.0],
            [0.0, 1.5, 1.0],
            [2.0, 2.0, 2.0],
            [1.0, 1.0, -1.0],
            [4.0, 1.0, 1.0],
            [1.0, 1.0, 1.0],
            [2.0, 0.0, 3.0],
        ];
        (x, w)
    }
    #[test]
    fn avedups_unweighted_matches_r() {
        let (x, _) = avedups_data();
        assert_mat_close(
            &avedups(&x, 2, 1, None),
            &array![
                [2.5, 3.5, 4.5],
                [8.5, 9.5, 10.5],
                [14.5, 17.0, 16.5],
                [22.0, 23.0, 22.5],
            ],
        );
        assert_mat_close(
            &avedups(&x, 2, 2, None),
            &array![
                [4.0, 5.0, 6.0],
                [7.0, 8.0, 9.0],
                [13.0, nan(), 18.0],
                [19.0, 20.0, 21.0],
            ],
        );
        assert_mat_close(
            &avedups(&x, 4, 1, None),
            &array![[5.5, 6.5, 7.5], [17.0, 20.0, 19.5]],
        );
    }
    // NaN expected where every effective weight in the group is zero.
    #[test]
    fn avedups_weighted_matches_r() {
        let (x, w) = avedups_data();
        assert_mat_close(
            &avedups(&x, 2, 1, Some(&w)),
            &array![
                [3.25, 3.0, 5.4],
                [10.0, 9.71428571428571, 11.0],
                [15.4, 17.0, 18.0],
                [22.0, nan(), 23.25],
            ],
        );
        assert_mat_close(
            &avedups(&x, 2, 2, Some(&w)),
            &array![
                [1.0, 4.57142857142857, 7.0],
                [6.4, 9.0, 9.0],
                [13.0, nan(), 21.0],
                [18.0, 17.0, 22.5],
            ],
        );
    }
    #[test]
    fn uniquegenelist_matches_r() {
        let g: Vec<i32> = (1..=8).collect();
        assert_eq!(uniquegenelist(&g, 2, 1), vec![1, 3, 5, 7]);
        assert_eq!(uniquegenelist(&g, 2, 2), vec![1, 2, 5, 6]);
        assert_eq!(uniquegenelist(&g, 4, 1), vec![1, 5]);
        assert_eq!(uniquegenelist(&g, 1, 1), g);
    }
}