use ndarray::{Array1, Array2, Axis};
use crate::fit::MArrayLM;
use crate::linalg::{qr_econ, solve_upper};
/// Sum-to-zero contrast matrix for `n` levels (the layout of R's `contr.sum(n)`).
///
/// Returns an `n x (n - 1)` matrix whose column `j` has `1.0` in row `j`, `-1.0` in the
/// last row and zeros elsewhere, so every column sums to zero. The array-weight
/// estimators use it as the default variance design when none is supplied.
///
/// For `n == 0` an empty `0 x 0` matrix is returned rather than underflowing `n - 1`.
pub(crate) fn contr_sum(n: usize) -> Array2<f64> {
    // Number of contrasts; saturating so that n == 0 does not wrap around.
    let ncols = n.saturating_sub(1);
    let mut z = Array2::<f64>::zeros((n, ncols));
    for j in 0..ncols {
        z[[j, j]] = 1.0;
        // ncols > 0 here implies n >= 2, so n - 1 cannot underflow.
        z[[n - 1, j]] = -1.0;
    }
    z
}
/// Solve the square linear system `a * x = b` by Gaussian elimination with partial pivoting.
///
/// `a` is treated as `k x k` with `k = a.nrows()`; neither input is modified.
/// No singularity check is made: a zero pivot produces non-finite entries in the
/// returned vector, which callers are expected to notice downstream.
pub(crate) fn solve_linear(a: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
    let n = a.nrows();
    let mut u = a.to_owned();
    let mut y = b.to_owned();

    // Forward elimination: reduce `u` to upper-triangular form, mirroring every row
    // operation on the right-hand side `y`.
    for step in 0..n {
        // Partial pivoting: the first row at or below the diagonal with the largest |entry|.
        let (pivot_row, _) = ((step + 1)..n).fold(
            (step, u[[step, step]].abs()),
            |(best_row, best_abs), row| {
                let cand = u[[row, step]].abs();
                if cand > best_abs {
                    (row, cand)
                } else {
                    (best_row, best_abs)
                }
            },
        );
        if pivot_row != step {
            for c in 0..n {
                u.swap([step, c], [pivot_row, c]);
            }
            y.swap(step, pivot_row);
        }

        let pivot = u[[step, step]];
        for row in (step + 1)..n {
            let factor = u[[row, step]] / pivot;
            // Rows that already have a zero in this column need no update.
            if factor == 0.0 {
                continue;
            }
            for c in step..n {
                let upper = u[[step, c]];
                u[[row, c]] -= factor * upper;
            }
            let y_pivot = y[step];
            y[row] -= factor * y_pivot;
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = Array1::<f64>::zeros(n);
    for i in (0..n).rev() {
        let remainder = ((i + 1)..n).fold(y[i], |acc, j| acc - u[[i, j]] * x[j]);
        x[i] = remainder / u[[i, i]];
    }
    x
}
/// Estimate array quality weights by REML-style scoring iterations, following limma's
/// `arrayWeights` REML method.
///
/// Array `i` is modelled with a variance multiplier `exp(z_i . gam)`, where `z_i` is row
/// `i` of the variance design. The returned weights are the reciprocals
/// `exp(-z_i . gam)`. `gam` is updated by scoring steps that pool all retained genes, with
/// a prior worth `prior_n` genes that pulls the weights towards 1.
///
/// # Arguments
/// * `exprs` - expression values, genes in rows and arrays in columns.
/// * `design` - `narrays x p` design matrix of the gene-wise linear models.
/// * `var_design` - `narrays x ngam` variance design; defaults to `contr_sum(narrays)`.
/// * `prior_n` - prior weight, in units of genes, shrinking the weights towards 1.
/// * `maxiter` - maximum number of scoring iterations.
/// * `tol` - iterations stop once the convergence criterion drops below this value.
///
/// # Returns
/// One weight per array. All ones are returned when there are fewer than two genes,
/// fewer than `p + 2` arrays, or fewer than two genes with non-negligible residual
/// variance. Iteration also stops early (keeping the last accepted weights) when the
/// convergence criterion is NaN or fails to decrease.
pub fn array_weights(
    exprs: &Array2<f64>,
    design: &Array2<f64>,
    var_design: Option<&Array2<f64>>,
    prior_n: f64,
    maxiter: usize,
    tol: f64,
) -> Array1<f64> {
    let narrays = exprs.ncols();
    let ngenes_all = exprs.nrows();
    let p = design.ncols();
    let mut w = Array1::<f64>::ones(narrays);
    // Need at least two genes and at least two residual degrees of freedom.
    if ngenes_all < 2 || narrays < p + 2 {
        return w;
    }
    // Variance design Z2 (default: sum-to-zero contrasts) and Z = [1 | Z2].
    let z2 = match var_design {
        Some(v) => v.to_owned(),
        None => contr_sum(narrays),
    };
    let ngam = z2.ncols();
    let nz = ngam + 1;
    let mut zmat = Array2::<f64>::ones((narrays, nz));
    for j in 0..ngam {
        for i in 0..narrays {
            zmat[[i, j + 1]] = z2[[i, j]];
        }
    }
    // Z2'Z2; multiplied by `prior_n` below to form the prior information.
    let z2tz2 = z2.t().dot(&z2);
    let dfres = (narrays - p) as f64;
    // Screen out genes whose unweighted residual variance is (near) zero. The RSS is
    // computed as |y|^2 - |Q'y|^2, which assumes `qr_econ` returns a thin Q with
    // orthonormal columns spanning the design (TODO confirm against linalg).
    let (q0, _r0) = qr_econ(design);
    let mut kept: Vec<usize> = Vec::with_capacity(ngenes_all);
    for g in 0..ngenes_all {
        let yg = exprs.row(g).to_owned();
        let cg = q0.t().dot(&yg);
        let s2 = (yg.dot(&yg) - cg.dot(&cg)) / dfres;
        if s2 >= 1e-15 {
            kept.push(g);
        }
    }
    if kept.len() < 2 {
        return w;
    }
    let y = exprs.select(Axis(0), &kept);
    let ngenes = y.nrows();
    let ngenes_f = ngenes as f64;
    let mut gam = Array1::<f64>::zeros(ngam);
    let mut convcrit_last = f64::INFINITY;
    // Number of distinct pairwise products of the p columns of Q (diagonal included).
    let p2 = p * (p + 1) / 2;
    for _iter in 1..=maxiter {
        // Weighted fit: scale each design row by sqrt(w).
        let sw: Array1<f64> = w.mapv(f64::sqrt);
        let mut xw = design.clone();
        for i in 0..narrays {
            for j in 0..p {
                xw[[i, j]] *= sw[i];
            }
        }
        let (qe, r) = qr_econ(&xw);
        // Per gene: weighted residual variance s2 and unweighted residuals y - X beta
        // (stored arrays x genes).
        let mut resid = Array2::<f64>::zeros((narrays, ngenes));
        let mut s2 = Array1::<f64>::zeros(ngenes);
        for gi in 0..ngenes {
            let yg = y.row(gi).to_owned();
            let ywg = &yg * &sw;
            let cg = qe.t().dot(&ywg);
            let beta = solve_upper(&r, &cg);
            let fitted = design.dot(&beta);
            let ss = ywg.dot(&ywg) - cg.dot(&cg);
            s2[gi] = ss / dfres;
            for i in 0..narrays {
                resid[[i, gi]] = yg[i] - fitted[i];
            }
        }
        // Q2 holds the row-wise products Q[:, a] * Q[:, a + k], grouped by offset k (k = 0
        // first), with the k > 0 products scaled by sqrt(2). The first p columns are the
        // squares of Q, so their row sum h is each array's leverage in the weighted fit.
        let mut q2 = Array2::<f64>::zeros((narrays, p2));
        let mut h = Array1::<f64>::zeros(narrays);
        for i in 0..narrays {
            let mut col = 0usize;
            for k in 0..p {
                for a in 0..(p - k) {
                    q2[[i, col]] = qe[[i, a]] * qe[[i, a + k]];
                    col += 1;
                }
            }
            for c in p..p2 {
                q2[[i, c]] *= std::f64::consts::SQRT_2;
            }
            let mut lev = 0.0;
            for c in 0..p {
                lev += q2[[i, c]];
            }
            h[i] = lev;
        }
        // Information matrix for (intercept, gam) as computed here:
        // Z' diag(1 - 2h) Z + (Q2'Z)'(Q2'Z).
        let mut info = Array2::<f64>::zeros((nz, nz));
        for a in 0..nz {
            for b in 0..nz {
                let mut acc = 0.0;
                for i in 0..narrays {
                    acc += zmat[[i, a]] * (1.0 - 2.0 * h[i]) * zmat[[i, b]];
                }
                info[[a, b]] = acc;
            }
        }
        let q2tz = q2.t().dot(&zmat);
        let gram = q2tz.t().dot(&q2tz);
        info = &info + &gram;
        // Eliminate the intercept (Schur complement) to get the information for gam alone.
        let i00 = info[[0, 0]];
        let mut info2 = Array2::<f64>::zeros((ngam, ngam));
        for a in 0..ngam {
            for b in 0..ngam {
                info2[[a, b]] = info[[a + 1, b + 1]] - info[[a + 1, 0]] * info[[0, b + 1]] / i00;
            }
        }
        // Per-array score term: mean over genes of w * resid^2 / s2, minus (1 - h).
        let mut zvec = Array1::<f64>::zeros(narrays);
        for i in 0..narrays {
            let mut acc = 0.0;
            for gi in 0..ngenes {
                acc += w[i] * resid[[i, gi]] * resid[[i, gi]] / s2[gi];
            }
            zvec[i] = acc / ngenes_f - (1.0 - h[i]);
        }
        // Scale to totals over genes and add the prior (worth `prior_n` genes) that pulls
        // the weights towards 1.
        for a in 0..ngam {
            for b in 0..ngam {
                info2[[a, b]] = ngenes_f * info2[[a, b]] + prior_n * z2tz2[[a, b]];
            }
        }
        for i in 0..narrays {
            zvec[i] = ngenes_f * zvec[i] + prior_n * (w[i] - 1.0);
        }
        // Scoring step for gam and its convergence criterion.
        let dl = z2.t().dot(&zvec);
        let gamstep = solve_linear(&info2, &dl);
        let convcrit = dl.dot(&gamstep) / (ngam as f64) / (ngenes_f + prior_n);
        // Reject the step if the criterion is undefined or did not decrease.
        if convcrit.is_nan() || convcrit >= convcrit_last {
            break;
        }
        convcrit_last = convcrit;
        gam = &gam + &gamstep;
        // Weights are the reciprocal variance multipliers exp(-Z2 gam).
        w = z2.dot(&gam).mapv(|x| (-x).exp());
        if convcrit < tol {
            break;
        }
    }
    w
}
/// Array quality weights by REML scoring when gene-by-array prior weights are supplied.
///
/// As in `array_weights`, array `i` gets a variance multiplier `exp(z_i . gam)` and the
/// returned weight is `exp(-z_i . gam)`. Each gene has its own combined weights
/// `w[i] * weights[[g, i]]`, so every gene is refitted separately in each iteration. Its
/// contributions are accumulated into the information matrix and score vector, which are
/// then averaged over `ngenes + prior_n`.
///
/// # Arguments
/// * `exprs` - expression values, genes in rows and arrays in columns.
/// * `design` - `narrays x p` design matrix.
/// * `weights` - prior weights indexed like `exprs` (`weights[[gene, array]]`).
/// * `var_design` - `narrays x ngam` variance design; defaults to `contr_sum(narrays)`.
/// * `prior_n` - prior weight, in units of genes, shrinking the weights towards 1.
/// * `maxiter` - maximum number of scoring iterations.
/// * `tol` - iterations stop once the convergence criterion drops below this value.
///
/// # Returns
/// One weight per array. All ones are returned when there are fewer than two genes or
/// fewer than `p + 2` arrays, which is the same early exit `array_weights` uses. If the
/// convergence criterion becomes NaN (for example, from a singular information matrix),
/// the step is not applied and the last finite weights are returned.
pub fn array_weights_prwts_reml(
    exprs: &Array2<f64>,
    design: &Array2<f64>,
    weights: &Array2<f64>,
    var_design: Option<&Array2<f64>>,
    prior_n: f64,
    maxiter: usize,
    tol: f64,
) -> Array1<f64> {
    let narrays = exprs.ncols();
    let ngenes = exprs.nrows();
    let p = design.ncols();
    let mut w = Array1::<f64>::ones(narrays);
    // At least two residual df are required. With narrays < p the `narrays - p` below
    // would underflow; with narrays == p every residual variance is 0/0.
    if ngenes < 2 || narrays < p + 2 {
        return w;
    }
    // Variance design Z2 (default: sum-to-zero contrasts) and Z = [1 | Z2].
    let z2 = match var_design {
        Some(v) => v.to_owned(),
        None => contr_sum(narrays),
    };
    let ngam = z2.ncols();
    let nz = ngam + 1;
    let mut zmat = Array2::<f64>::ones((narrays, nz));
    for j in 0..ngam {
        for i in 0..narrays {
            zmat[[i, j + 1]] = z2[[i, j]];
        }
    }
    let z2tz2 = z2.t().dot(&z2);
    let dfres = (narrays - p) as f64;
    let denom = ngenes as f64 + prior_n;
    // Number of distinct pairwise products of the p columns of Q (diagonal included).
    let p2 = p * (p + 1) / 2;
    let mut gam = Array1::<f64>::zeros(ngam);
    for _iter in 1..=maxiter {
        // Start from the prior: information prior_n * Z2'Z2 and score prior_n * (w - 1).
        let mut info2 = z2tz2.mapv(|x| x * prior_n);
        let mut zvec = w.mapv(|wi| prior_n * (wi - 1.0));
        for g in 0..ngenes {
            // Combined weights for this gene and the sqrt-weighted fit.
            let cw: Vec<f64> = (0..narrays).map(|i| w[i] * weights[[g, i]]).collect();
            let sw: Vec<f64> = cw.iter().map(|&v| v.sqrt()).collect();
            let mut xw = design.clone();
            for i in 0..narrays {
                for j in 0..p {
                    xw[[i, j]] *= sw[i];
                }
            }
            let yg = exprs.row(g);
            let yw: Array1<f64> = (0..narrays).map(|i| yg[i] * sw[i]).collect();
            let (qe, r) = qr_econ(&xw);
            let cg = qe.t().dot(&yw);
            let beta = solve_upper(&r, &cg);
            let fitted = design.dot(&beta);
            // Unweighted residuals; the weights are applied in the score term below.
            let resid: Vec<f64> = (0..narrays).map(|i| yg[i] - fitted[i]).collect();
            let s2 = (yw.dot(&yw) - cg.dot(&cg)) / dfres;
            // Q2: row-wise products of Q columns grouped by offset k (k = 0 first), with
            // k > 0 products scaled by sqrt(2); the first p columns sum to the leverage h.
            let mut q2 = Array2::<f64>::zeros((narrays, p2));
            let mut h = vec![0.0f64; narrays];
            for i in 0..narrays {
                let mut col = 0usize;
                for k in 0..p {
                    for a in 0..(p - k) {
                        q2[[i, col]] = qe[[i, a]] * qe[[i, a + k]];
                        col += 1;
                    }
                }
                for c in p..p2 {
                    q2[[i, c]] *= std::f64::consts::SQRT_2;
                }
                let mut lev = 0.0;
                for c in 0..p {
                    lev += q2[[i, c]];
                }
                h[i] = lev;
            }
            // Information for (intercept, gam): Z' diag(1 - 2h) Z + (Q2'Z)'(Q2'Z).
            let mut info = Array2::<f64>::zeros((nz, nz));
            for a in 0..nz {
                for b in 0..nz {
                    let mut acc = 0.0;
                    for i in 0..narrays {
                        acc += zmat[[i, a]] * (1.0 - 2.0 * h[i]) * zmat[[i, b]];
                    }
                    info[[a, b]] = acc;
                }
            }
            let q2tz = q2.t().dot(&zmat);
            let gram = q2tz.t().dot(&q2tz);
            info = &info + &gram;
            // Accumulate the Schur complement (intercept eliminated) for gam.
            let i00 = info[[0, 0]];
            for a in 0..ngam {
                for b in 0..ngam {
                    info2[[a, b]] +=
                        info[[a + 1, b + 1]] - info[[a + 1, 0]] * info[[0, b + 1]] / i00;
                }
            }
            // Only genes with non-negligible residual variance contribute to the score.
            if s2 > 1e-15 {
                for i in 0..narrays {
                    zvec[i] += cw[i] * resid[i] * resid[i] / s2 - (1.0 - h[i]);
                }
            }
        }
        info2.mapv_inplace(|x| x / denom);
        zvec.mapv_inplace(|x| x / denom);
        let dl = z2.t().dot(&zvec);
        let gamstep = solve_linear(&info2, &dl);
        // The criterion depends only on this iteration's score and step, so it can be
        // checked before the step is applied. A NaN then leaves the last finite weights
        // in place instead of writing NaN into `w`.
        let convcrit = dl.dot(&gamstep) / denom / (ngam as f64);
        if convcrit.is_nan() {
            break;
        }
        gam = &gam + &gamstep;
        // Weights are the reciprocal variance multipliers exp(-Z2 gam).
        w = z2.dot(&gam).mapv(|x| (-x).exp());
        if convcrit < tol {
            break;
        }
    }
    w
}
/// Weighted least-squares fit of a single response, as needed by the gene-by-gene update.
///
/// The fit uses the QR decomposition of `diag(sqrt(w)) * x`.
///
/// Returns `(resid, lev, s2)`:
/// * `resid` - unweighted residuals `y - x * beta` (no weight is applied),
/// * `lev` - per-observation leverage in the weighted fit, i.e. the squared row norms of
///   the first `p` columns of the Q factor,
/// * `s2` - weighted residual variance `(|sqrt(w) y|^2 - |Q' sqrt(w) y|^2) / (n - p)`.
///
/// The residual degrees of freedom are computed as `n - p` in `usize`, so callers must
/// ensure `x.nrows() > x.ncols()`.
pub(crate) fn wfit_resid_lev_s2(
    x: &Array2<f64>,
    y: &[f64],
    w: &[f64],
) -> (Vec<f64>, Vec<f64>, f64) {
    let (nobs, ncoef) = x.dim();
    let root_w: Vec<f64> = w.iter().map(|wi| wi.sqrt()).collect();

    // Row-scale the design and the response by sqrt(weight).
    let mut x_scaled = x.clone();
    for obs in 0..nobs {
        let scale = root_w[obs];
        x_scaled.row_mut(obs).mapv_inplace(|v| v * scale);
    }
    let y_scaled: Array1<f64> = (0..nobs).map(|obs| y[obs] * root_w[obs]).collect();

    let (q, r) = qr_econ(&x_scaled);
    let effects = q.t().dot(&y_scaled);
    let coef = solve_upper(&r, &effects);
    let fit = x.dot(&coef);

    let residuals: Vec<f64> = (0..nobs).map(|obs| y[obs] - fit[obs]).collect();
    let leverage: Vec<f64> = (0..nobs)
        .map(|obs| {
            (0..ncoef)
                .map(|k| {
                    let v = q[[obs, k]];
                    v * v
                })
                .sum()
        })
        .collect();

    let rss = y_scaled.dot(&y_scaled) - effects.dot(&effects);
    let df = (nobs - ncoef) as f64;
    (residuals, leverage, rss / df)
}
/// Array quality weights by the gene-by-gene update algorithm (limma's `arrayWeights`
/// "genebygene" method).
///
/// Genes are processed one at a time. Each gene is fitted with the current array weights
/// (multiplied by its prior weights, if given), and one scoring step updates the
/// log-variance parameters `gam`. The information matrix is accumulated across genes,
/// starting from `prior_n * Z2'Z2`, so each later gene moves the weights less.
///
/// A gene containing any non-finite value (NaN or ±Inf) is fitted on its finite
/// observations only. It is skipped when fewer than `max(3, nparams + 2)` of them remain.
/// Genes whose residual variance is below `1e-15` or undefined (NaN) are also skipped.
///
/// # Arguments
/// * `exprs` - expression values, genes in rows and arrays in columns.
/// * `design` - `narrays x nparams` design matrix.
/// * `weights` - optional prior weights indexed like `exprs` (`wt[[gene, array]]`).
/// * `var_design` - `narrays x ngam` variance design; defaults to `contr_sum(narrays)`.
/// * `prior_n` - prior weight, in units of genes, stabilising the early updates.
///
/// # Returns
/// One weight per array, `exp(-Z2 gam)`. All ones are returned when
/// `narrays < nparams + 2`.
pub fn array_weights_gene_by_gene(
    exprs: &Array2<f64>,
    design: &Array2<f64>,
    weights: Option<&Array2<f64>>,
    var_design: Option<&Array2<f64>>,
    prior_n: f64,
) -> Array1<f64> {
    let ngenes = exprs.nrows();
    let narrays = exprs.ncols();
    let nparams = design.ncols();
    // Each gene needs at least two residual df, the rule the partially-observed branch
    // below applies per gene. With narrays < nparams + 2 no gene could contribute. The
    // early return also stops the complete-data branch from hitting zero (or underflowing)
    // residual df in `wfit_resid_lev_s2`.
    if narrays < nparams + 2 {
        return Array1::<f64>::ones(narrays);
    }
    let z2 = match var_design {
        Some(v) => v.to_owned(),
        None => contr_sum(narrays),
    };
    let ngam = z2.ncols();
    let nz = ngam + 1;
    // Z = [1 | Z2]: intercept column plus the variance design.
    let mut zmat = Array2::<f64>::ones((narrays, nz));
    for j in 0..ngam {
        for i in 0..narrays {
            zmat[[i, j + 1]] = z2[[i, j]];
        }
    }
    let mut gam = Array1::<f64>::zeros(ngam);
    let mut aw = Array1::<f64>::ones(narrays);
    // Prior information prior_n * Z2'Z2; per-gene information is added to it below.
    let mut info2 = z2.t().dot(&z2);
    info2.mapv_inplace(|x| x * prior_n);
    for i in 0..ngenes {
        let mut w: Vec<f64> = aw.to_vec();
        if let Some(wt) = weights {
            for (j, wj) in w.iter_mut().enumerate() {
                *wj *= wt[[i, j]];
            }
        }
        let yrow: Vec<f64> = exprs.row(i).to_vec();
        // Unobserved arrays keep d = 0 and h1 = 0, so they drop out of both the
        // information and the score, matching a fit on the observed subset.
        let mut d = vec![0.0f64; narrays];
        let mut h1 = vec![0.0f64; narrays];
        let s2;
        // This test must agree with the `is_finite` filter below. Testing only for NaN would
        // send a row containing ±Inf through the complete-data fit, and the resulting NaN
        // score would corrupt every later weight.
        if yrow.iter().any(|v| !v.is_finite()) {
            let obs: Vec<usize> = (0..narrays).filter(|&j| yrow[j].is_finite()).collect();
            let nobs = obs.len();
            if nobs <= 2 || nobs < nparams + 2 {
                continue;
            }
            let mut xsub = Array2::<f64>::zeros((nobs, nparams));
            let mut ysub = vec![0.0f64; nobs];
            let mut wsub = vec![0.0f64; nobs];
            for (r, &j) in obs.iter().enumerate() {
                for c in 0..nparams {
                    xsub[[r, c]] = design[[j, c]];
                }
                ysub[r] = yrow[j];
                wsub[r] = w[j];
            }
            let (resid, lev, s2v) = wfit_resid_lev_s2(&xsub, &ysub, &wsub);
            s2 = s2v;
            for (r, &j) in obs.iter().enumerate() {
                d[j] = wsub[r] * resid[r] * resid[r];
                h1[j] = 1.0 - lev[r];
            }
        } else {
            let (resid, lev, s2v) = wfit_resid_lev_s2(design, &yrow, &w);
            s2 = s2v;
            for j in 0..narrays {
                d[j] = w[j] * resid[j] * resid[j];
                h1[j] = 1.0 - lev[j];
            }
        }
        // Skip (near-)constant genes. NaN is checked explicitly because `NaN < 1e-15` is
        // false, and a NaN variance would otherwise poison `gam`.
        if s2.is_nan() || s2 < 1e-15 {
            continue;
        }
        // Information for (intercept, gam): Z' diag(1 - h) Z.
        let mut info = Array2::<f64>::zeros((nz, nz));
        for a in 0..nz {
            for b in 0..nz {
                let mut acc = 0.0;
                for j in 0..narrays {
                    acc += zmat[[j, a]] * h1[j] * zmat[[j, b]];
                }
                info[[a, b]] = acc;
            }
        }
        // Accumulate the Schur complement (intercept eliminated) for gam.
        let i00 = info[[0, 0]];
        for a in 0..ngam {
            for b in 0..ngam {
                info2[[a, b]] += info[[a + 1, b + 1]] - info[[a + 1, 0]] * info[[0, b + 1]] / i00;
            }
        }
        // Score for this gene and a single scoring step.
        let z: Array1<f64> = (0..narrays).map(|j| d[j] / s2 - h1[j]).collect();
        let dl = z2.t().dot(&z);
        let step = solve_linear(&info2, &dl);
        gam = &gam + &step;
        aw = z2.dot(&gam).mapv(|x| (-x).exp());
    }
    aw
}
/// Fast approximate array quality weights (limma's `arrayWeightsQuick`).
///
/// For each array `j`, every gene's squared residual `(y - fitted)^2` is standardised by
/// the gene's residual variance `sigma^2` and by `1 - h_j`, where `h_j` is the leverage of
/// array `j` in the design. The weight is the reciprocal of the mean of these ratios, with
/// NaN ratios (for example, from missing values) excluded.
///
/// # Panics
/// Panics if `fit.design` is `None`.
pub fn array_weights_quick(y: &Array2<f64>, fit: &MArrayLM) -> Array1<f64> {
    let design = fit
        .design
        .as_ref()
        .expect("arrayWeightsQuick requires a design in the fit");
    let n_arrays = design.nrows();
    let n_genes = y.nrows();

    // Fitted values in gene x array layout.
    let fitted = fit.coefficients.dot(&design.t());

    // 1 - leverage of each array: leverage is the squared row norm of the design's Q factor.
    let (q, _) = qr_econ(design);
    let one_minus_h: Vec<f64> = (0..n_arrays)
        .map(|j| 1.0 - q.row(j).iter().map(|&v| v * v).sum::<f64>())
        .collect();

    (0..n_arrays)
        .map(|j| {
            let (total, used) = (0..n_genes).fold((0.0_f64, 0_usize), |(total, used), g| {
                let e = y[[g, j]] - fitted[[g, j]];
                let sigma = fit.sigma[g];
                let scaled = e * e / (sigma * sigma * one_minus_h[j]);
                if scaled.is_nan() {
                    (total, used)
                } else {
                    (total + scaled, used + 1)
                }
            });
            used as f64 / total
        })
        .collect()
}
#[cfg(test)]
// Reference values are pasted at full printed precision, which trips
// `clippy::excessive_precision`.
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;
    use ndarray::array;
    // Deterministic 12-gene x 6-array fixture: two groups of three arrays, noise whose
    // scale grows from array 1 to array 6, and gene x array prior weights in [0.5, 1.4].
    // Returns (expression matrix, two-column intercept + group design, prior weights).
    fn gbg_fixture() -> (Array2<f64>, Array2<f64>, Array2<f64>) {
        let ngenes = 12usize;
        let narrays = 6usize;
        let grp = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let scale = [0.5, 0.7, 1.0, 1.3, 1.6, 2.0];
        let mut e = Array2::<f64>::zeros((ngenes, narrays));
        let mut wt = Array2::<f64>::zeros((ngenes, narrays));
        for g in 0..ngenes {
            let gi = g as i64;
            for j in 0..narrays {
                let ji = j as i64;
                // Pseudo-noise in [-0.5, 0.5] * scale[j].
                let noise = (((gi * 7 + ji * 5) % 11) - 5) as f64 * 0.1 * scale[j];
                e[[g, j]] = 5.0 + (gi % 5) as f64 * 0.5 + grp[j] * 0.3 + noise;
                wt[[g, j]] = 0.5 + ((gi * 3 + ji * 2) % 7) as f64 * 0.15;
            }
        }
        let mut design = Array2::<f64>::zeros((narrays, 2));
        for j in 0..narrays {
            design[[j, 0]] = 1.0;
            design[[j, 1]] = grp[j];
        }
        (e, design, wt)
    }
    // Gene-by-gene algorithm against reference values (from R, per the test name) in three
    // settings: no prior weights, with prior weights, and with NaNs injected so that some
    // genes take the observed-subset path.
    #[test]
    fn array_weights_gene_by_gene_matches_r() {
        let (e, design, wt) = gbg_fixture();
        let aw = array_weights_gene_by_gene(&e, &design, None, None, 10.0);
        let want = [
            1.5261488797077414,
            1.200903015330324,
            1.2231111922725117,
            1.2216065418856867,
            0.42436321655995812,
            0.86051835803628207,
        ];
        for (g, x) in aw.iter().zip(want.iter()) {
            assert!((g - x).abs() < 1e-7, "no-weights: got {g}, want {x}");
        }
        let aww = array_weights_gene_by_gene(&e, &design, Some(&wt), None, 10.0);
        let want_w = [
            1.5273497496573842,
            1.1661798617674437,
            1.2204452025989252,
            1.2146290021966843,
            0.42565954641887849,
            0.8897574953263433,
        ];
        for (g, x) in aww.iter().zip(want_w.iter()) {
            assert!((g - x).abs() < 1e-7, "with-weights: got {g}, want {x}");
        }
        let mut ena = e.clone();
        ena[[2, 4]] = f64::NAN;
        ena[[6, 0]] = f64::NAN;
        ena[[9, 5]] = f64::NAN;
        let awn = array_weights_gene_by_gene(&ena, &design, None, None, 10.0);
        let want_na = [
            1.4797107434846923,
            1.1254112099887159,
            1.1542363030943652,
            1.2058806596955525,
            0.47074236183939411,
            0.91649393378073563,
        ];
        for (g, x) in awn.iter().zip(want_na.iter()) {
            assert!((g - x).abs() < 1e-7, "NA: got {g}, want {x}");
        }
    }
    // REML method on a fixed 12 x 6 data set with a two-group design, default variance
    // design and prior_n = 10, compared to reference values within 1e-6.
    #[test]
    fn array_weights_reml_matches_r() {
        let exprs = array![
            [4.871, 4.629, 4.697, 5.807, 4.798, 5.195],
            [6.356, 6.349, 6.764, 4.125, 3.125, 4.752],
            [4.298, 4.659, 4.508, 5.936, 4.075, 7.367],
            [8.896, 9.420, 8.915, 9.165, 9.466, 8.598],
            [6.563, 6.610, 6.813, 6.123, 6.155, 7.309],
            [4.443, 4.283, 3.851, 5.435, 5.304, 5.784],
            [7.247, 7.184, 7.620, 6.533, 7.878, 6.820],
            [7.456, 7.644, 8.368, 9.096, 7.422, 10.245],
            [7.229, 6.945, 6.986, 8.178, 7.445, 10.159],
            [5.378, 5.177, 4.919, 7.692, 6.023, 7.432],
            [8.748, 9.133, 9.280, 9.431, 10.394, 11.954],
            [6.697, 7.010, 6.719, 4.293, 3.114, 5.796],
        ];
        let design = array![
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
        ];
        let want = [
            1.611164881845,
            1.659781018122,
            1.455189349487,
            1.160250145839,
            0.462418787784,
            0.478963710208,
        ];
        let w = array_weights(&exprs, &design, None, 10.0, 50, 1e-5);
        assert_eq!(w.len(), want.len());
        for (got, exp) in w.iter().zip(want.iter()) {
            assert!(
                (got - exp).abs() < 1e-6,
                "array weight mismatch: got {got}, want {exp}"
            );
        }
    }
    // REML with prior weights on the shared fixture, for both the two-column design and an
    // intercept-only design (p = 1).
    #[test]
    fn array_weights_prwts_reml_matches_r() {
        let (e, design, wt) = gbg_fixture();
        let w = array_weights_prwts_reml(&e, &design, &wt, None, 10.0, 50, 1e-5);
        let want = [
            1.5867284550700409,
            1.1019127664080206,
            1.2065279057943283,
            1.4157105291721905,
            0.29591187727091722,
            1.1315557212822311,
        ];
        assert_eq!(w.len(), want.len());
        for (got, exp) in w.iter().zip(want.iter()) {
            assert!((got - exp).abs() < 1e-6, "p=2: got {got}, want {exp}");
        }
        let narrays = e.ncols();
        let design1 = Array2::<f64>::ones((narrays, 1));
        let w1 = array_weights_prwts_reml(&e, &design1, &wt, None, 10.0, 50, 1e-5);
        let want1 = [
            1.3749226879021179,
            1.4906361441733864,
            1.0601401145509048,
            1.0761314591200781,
            0.62056380371792619,
            0.68918376453280328,
        ];
        for (got, exp) in w1.iter().zip(want1.iter()) {
            assert!((got - exp).abs() < 1e-6, "p=1: got {got}, want {exp}");
        }
    }
    // Quick weights from an `lmfit` result on a 5-gene x 4-array example; arrays in the
    // same group share a leverage and, in this data, essentially the same weight.
    #[test]
    fn array_weights_quick_matches_r() {
        let y = array![
            [-0.59, 0.01, 0.79, 0.36],
            [0.03, -0.19, -0.23, 0.65],
            [-1.52, -0.77, 1.67, 1.81],
            [-1.36, -0.22, 0.50, 1.22],
            [1.18, -0.98, 0.16, 0.91],
        ];
        let design = array![[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]];
        let fit = crate::fit::lmfit(
            &y,
            &design,
            (0..5).map(|i| i.to_string()).collect(),
            vec!["Int".into(), "grp".into()],
        )
        .unwrap();
        let w = array_weights_quick(&y, &fit);
        let want = [
            0.759166824278653,
            0.759166824278652,
            1.46462963844169,
            1.46462963844169,
        ];
        assert_eq!(w.len(), want.len());
        for (got, exp) in w.iter().zip(want.iter()) {
            assert!((got - exp).abs() < 1e-9, "got {got}, want {exp}");
        }
    }
    // Shape and sign layout of the sum-to-zero contrasts, plus zero column sums.
    #[test]
    fn contr_sum_shape() {
        let c = contr_sum(4);
        assert_eq!(c.dim(), (4, 3));
        assert_eq!(c[[0, 0]], 1.0);
        assert_eq!(c[[1, 1]], 1.0);
        assert_eq!(c[[3, 0]], -1.0);
        assert_eq!(c[[3, 2]], -1.0);
        assert_eq!(c[[0, 1]], 0.0);
        for j in 0..3 {
            let s: f64 = (0..4).map(|i| c[[i, j]]).sum();
            assert!(s.abs() < 1e-15);
        }
    }
}