limma-rust 0.1.0

Pure-Rust port of the Bioconductor limma differential-expression package
Documentation
//! `printtipWeights` (printtipWeights.R): print-tip array quality weights for
//! two-colour arrays. Each print-tip block (a contiguous run of `nspots` rows)
//! gets its own array-weight estimate via the gene-by-gene update algorithm —
//! the same machinery as `arrayWeights(method = "genebygene")` but with a
//! `contr.sum` variance design and a per-block prior — and the resulting
//! `narrays` weights are broadcast across the block's rows.
//!
//! Only the default `method = "genebygene"` is ported; the `reml` branch is out
//! of scope. The design is assumed to have full column rank (limma reports
//! non-estimable coefficients but otherwise proceeds). The `Agam.del` rank-1
//! correction is reproduced exactly as shipped, including limma's
//! `h[1:(length(narrays)-1)]` indexing quirk: `narrays` is a scalar, so the
//! index is `1:0`, which selects only `h[1]`. As a result every entry of
//! `Agam.del` equals `(h[narrays] - h[1])^2`.

use anyhow::{bail, Result};
use ndarray::{Array1, Array2};

use crate::arrayweights::{contr_sum, solve_linear, wfit_resid_lev_s2};
use crate::normwithin::PrinterLayout;

/// `printtipWeights(M, design, weights, method = "genebygene", layout)`.
///
/// * `m` — `n_probes x n_arrays` log-ratio matrix (`NA`/infinite entries are
///   dropped per spot).
/// * `design` — `n_arrays x p` design matrix (full column rank).
/// * `weights` — optional `n_probes x n_arrays` spot weights; a spot's weights
///   are rescaled to a maximum of 1 (when that maximum exceeds 1) before being
///   combined with the running variance estimate, exactly as limma does.
/// * `layout` — print-tip layout; `ngrid_r*ngrid_c*nspot_r*nspot_c` must equal
///   `n_probes`.
///
/// A spot contributes nothing to its block's estimate if fewer than two of its
/// arrays are usable. It also contributes nothing if it has no residual degrees
/// of freedom (observed arrays `<= p`).
///
/// Returns the `n_probes x n_arrays` weight matrix (each block's weights
/// broadcast across its rows), ready to pass to `lmFit`.
///
/// # Errors
///
/// * `design` does not have one row per array;
/// * there are fewer than 3 arrays, or fewer probes than arrays;
/// * the layout does not tile the rows of `m` exactly;
/// * `weights` is given with a shape different from `m`;
/// * a block's variance parameters become non-finite part-way through the
///   block. This is detected before the next spot is processed.
pub fn printtip_weights(
    m: &Array2<f64>,
    design: &Array2<f64>,
    weights: Option<&Array2<f64>>,
    layout: PrinterLayout,
) -> Result<Array2<f64>> {
    let nprobes = m.nrows();
    let narrays = m.ncols();
    let nparams = design.ncols();
    if design.nrows() != narrays {
        bail!("design row dimension must equal number of arrays");
    }
    if narrays < 3 {
        bail!("too few arrays");
    }
    if nprobes < narrays {
        bail!("too few probes");
    }
    let nspots = layout.nspot_r * layout.nspot_c;
    let ngrids = layout.ngrid_r * layout.ngrid_c;
    if ngrids * nspots != nprobes {
        bail!("printer layout information does not match M row dimension");
    }
    if let Some(w) = weights {
        if w.dim() != (nprobes, narrays) {
            bail!("dimensions of weights do not match M");
        }
    }

    let z = contr_sum(narrays); // narrays x (narrays-1)
    let ngam = narrays - 1;
    let ztz = z.t().dot(&z);
    // Evaluated in f64. A design with more columns than arrays would underflow
    // `narrays - nparams` as usize; R simply produces a negative prior.
    let prior = 10.0 * (narrays as f64 - nparams as f64) / narrays as f64;

    let mut blockw = Array2::<f64>::zeros((ngrids, narrays));
    for blk in 0..ngrids {
        let start = blk * nspots;
        // Each block starts from equal weights (gammas = 0). Its information
        // matrix is seeded with the prior.
        let mut gammas = Array1::<f64>::zeros(ngam);
        let mut zinfo = ztz.mapv(|v| v * prior);

        for s in 0..nspots {
            let i = start + s;
            if gammas.iter().any(|v| !v.is_finite()) {
                bail!("convergence problem at block {blk} spot {s}: array weights not estimable");
            }

            // w = 1/vary = exp(-Z gammas). When spot weights are supplied, w is
            // also multiplied by this spot's weights, after rescaling them so
            // that their largest finite value is at most 1.
            let zg = z.dot(&gammas);
            let mut wfull: Vec<f64> = zg.iter().map(|&v| (-v).exp()).collect();
            if let Some(wmat) = weights {
                let spot_w = wmat.row(i);
                let mx = spot_w
                    .iter()
                    .copied()
                    .filter(|v| v.is_finite())
                    .fold(f64::NEG_INFINITY, f64::max);
                let scale = if mx > 1.0 { mx } else { 1.0 };
                for (wf, &sw) in wfull.iter_mut().zip(spot_w.iter()) {
                    *wf *= sw / scale;
                }
            }

            let yfull = m.row(i);
            let obs: Vec<usize> = (0..narrays)
                .filter(|&a| yfull[a].is_finite() && wfull[a] != 0.0)
                .collect();
            let nobs = obs.len();
            // A spot is only fitted when at least two arrays are observed.
            //
            // A spot with no residual degrees of freedom (nobs <= nparams) is
            // skipped too. The leverage correction below would divide by zero,
            // and the finiteness check would discard the spot anyway. Skipping
            // it here also keeps `nobs - nparams` from underflowing when
            // nobs < nparams.
            if nobs <= 1 || nobs <= nparams {
                continue;
            }

            let xsub = Array2::from_shape_fn((nobs, nparams), |(r, c)| design[[obs[r], c]]);
            let ysub: Vec<f64> = obs.iter().map(|&a| yfull[a]).collect();
            let wsub: Vec<f64> = obs.iter().map(|&a| wfull[a]).collect();
            let (resid, lev, s2) = wfit_resid_lev_s2(&xsub, &ysub, &wsub);
            let df_resid = (nobs - nparams) as f64;

            // Spread d (= w*resid^2) and h (leverage) to full length.
            // Unobserved arrays carry h = 1, d = 0.
            let mut d = vec![0.0f64; narrays];
            let mut h = vec![1.0f64; narrays];
            for (k, &a) in obs.iter().enumerate() {
                d[a] = wsub[k] * resid[k] * resid[k];
                h[a] = lev[k];
            }

            // Agene.gam = Z' diag(1-h) Z - (1/df) * Agam.del. limma's Agam.del
            // reduces to (h[last]-h[first])^2 * ones.
            let mut agene = Array2::<f64>::zeros((ngam, ngam));
            for p in 0..ngam {
                for q in 0..ngam {
                    let mut acc = 0.0;
                    for a in 0..narrays {
                        acc += z[[a, p]] * (1.0 - h[a]) * z[[a, q]];
                    }
                    agene[[p, q]] = acc;
                }
            }
            let cdel = h[narrays - 1] - h[0];
            let del = cdel * cdel / df_resid;
            agene.mapv_inplace(|v| v - del);
            if !agene.iter().all(|v| v.is_finite()) {
                continue;
            }

            let zd: Array1<f64> = (0..narrays).map(|a| d[a] / s2 - 1.0 + h[a]).collect();

            if nobs == narrays {
                // Fully observed spot: Fisher-scoring step on gammas.
                zinfo = &zinfo + &agene;
                let zzd = z.t().dot(&zd);
                let step = solve_linear(&zinfo, &zzd);
                gammas = &gammas + &step;
            } else if nobs > 2 {
                // Partially observed spot. Update only the contrasts among the
                // observed arrays.
                zinfo = &zinfo + &agene;
                // A1 = (I - J/nobs) Z[obs,] with its last row dropped.
                let z2 = Array2::from_shape_fn((nobs, ngam), |(r, c)| z[[obs[r], c]]);
                let mut a1 = Array2::<f64>::zeros((nobs - 1, ngam));
                for r in 0..(nobs - 1) {
                    for c in 0..ngam {
                        let mut acc = 0.0;
                        for k in 0..nobs {
                            let centering = (if k == r { 1.0 } else { 0.0 }) - 1.0 / nobs as f64;
                            acc += centering * z2[[k, c]];
                        }
                        a1[[r, c]] = acc;
                    }
                }
                let ztzd = z.t().dot(&zd);
                let zzd = a1.dot(&ztzd); // length nobs-1

                // A1 Zinfo^-1 A1', built from per-column solves of
                // Zinfo x = A1'[,r].
                let mut zinv_a1t = Array2::<f64>::zeros((ngam, nobs - 1));
                for r in 0..(nobs - 1) {
                    let rhs: Array1<f64> = (0..ngam).map(|c| a1[[r, c]]).collect();
                    let sol = solve_linear(&zinfo, &rhs);
                    for c in 0..ngam {
                        zinv_a1t[[c, r]] = sol[c];
                    }
                }
                let alphas_iter = a1.dot(&zinv_a1t).dot(&zzd); // length nobs-1

                // Us (alphas_new - alphas_old) with Us = [I_{nobs-1}; -1].
                let mut usalphas = vec![0.0f64; nobs];
                for k in 0..(nobs - 1) {
                    usalphas[k] = alphas_iter[k];
                    usalphas[nobs - 1] -= alphas_iter[k];
                }
                let mut usg = z.dot(&gammas);
                for (k, &a) in obs.iter().enumerate() {
                    usg[a] += usalphas[k];
                }
                gammas = usg.iter().take(ngam).copied().collect();
            }
        }

        let zg = z.dot(&gammas);
        for a in 0..narrays {
            blockw[[blk, a]] = (-zg[a]).exp();
        }
    }

    // Every spot in a block shares that block's weights.
    Ok(Array2::from_shape_fn((nprobes, narrays), |(i, a)| {
        blockw[[i / nspots, a]]
    }))
}

#[cfg(test)]
// The expected weights are kept exactly as the reference printed them (17
// significant digits). Some of these literals carry more digits than an f64
// can represent exactly.
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;

    fn rclose(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-7 * (1.0 + b.abs())
    }

    /// 12x4 M matrix (with per-array scale heterogeneity) and a 12x4 spot-weight
    /// matrix, matching `scratch/printtipweights_ref.R`.
    fn fixture() -> (Array2<f64>, Array2<f64>) {
        let scale = [1.0, 1.5, 0.7, 2.0];
        let (nprobe, narray) = (12usize, 4usize);
        let mut m = Array2::zeros((nprobe, narray));
        let mut w = Array2::zeros((nprobe, narray));
        for g0 in 0..nprobe {
            for j0 in 0..narray {
                let (gi, ji) = (g0 as i64, j0 as i64);
                m[[g0, j0]] = 3.0
                    + (gi % 4) as f64 * 0.5
                    + ((gi * 5 + ji * 3) % 7 - 3) as f64 * 0.2 * scale[j0];
                w[[g0, j0]] = 0.5 + ((gi * 2 + ji * 5) % 6) as f64 * 0.2;
            }
        }
        (m, w)
    }

    fn design4() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]).unwrap()
    }

    const LAYOUT: PrinterLayout = PrinterLayout {
        ngrid_r: 1,
        ngrid_c: 2,
        nspot_r: 2,
        nspot_c: 3,
    };

    /// Assert that each of the 12 rows equals its block's expected weights:
    /// rows 0..6 -> `blocks[0]`, rows 6..12 -> `blocks[1]`.
    fn assert_blocks(out: &Array2<f64>, blocks: &[[f64; 4]; 2], label: &str) {
        for i in 0..12 {
            let exp = blocks[i / 6];
            for a in 0..4 {
                assert!(
                    rclose(out[[i, a]], exp[a]),
                    "{label}[{i},{a}]: {} vs {}",
                    out[[i, a]],
                    exp[a]
                );
            }
        }
    }

    #[test]
    fn printtip_weights_no_weights_clean() {
        let (m, _w) = fixture();
        let out = printtip_weights(&m, &design4(), None, LAYOUT).unwrap();
        let blocks = [
            [
                1.0279718633477748,
                1.0279718633477748,
                0.97278927143328675,
                0.97278927143328664,
            ],
            [
                1.0117348683445897,
                1.0117348683445895,
                0.98840124155867992,
                0.98840124155868003,
            ],
        ];
        assert_blocks(&out, &blocks, "ptw A");
    }

    #[test]
    fn printtip_weights_no_weights_na_branch() {
        let (m, _w) = fixture();
        let mut mna = m.clone();
        mna[[7, 1]] = f64::NAN; // M[8,2] in 1-based -> sum(obs)=3 branch in block 2.
        let out = printtip_weights(&mna, &design4(), None, LAYOUT).unwrap();
        let blocks = [
            [
                1.0279718633477748,
                1.0279718633477748,
                0.97278927143328675,
                0.97278927143328664,
            ],
            [
                0.99170352734697842,
                0.99170352734697809,
                1.0085121081848278,
                1.0082196729118387,
            ],
        ];
        assert_blocks(&out, &blocks, "ptw B");
    }

    #[test]
    fn printtip_weights_with_spot_weights() {
        let (m, w) = fixture();
        let out = printtip_weights(&m, &design4(), Some(&w), LAYOUT).unwrap();
        let blocks = [
            [
                1.032303811727709,
                1.013658601320103,
                0.96283233132495361,
                0.99254474410404492,
            ],
            [
                1.0258412521778075,
                0.99993246962846849,
                0.99494360599235865,
                0.97982993672697472,
            ],
        ];
        assert_blocks(&out, &blocks, "ptw C");
    }

    /// Spot weights of exactly 1 need no rescaling and leave the variance-based
    /// weights unchanged, so the result must match the unweighted call.
    #[test]
    fn printtip_weights_unit_spot_weights_match_unweighted() {
        let (m, _w) = fixture();
        let ones = Array2::<f64>::ones((12, 4));
        let plain = printtip_weights(&m, &design4(), None, LAYOUT).unwrap();
        let unit = printtip_weights(&m, &design4(), Some(&ones), LAYOUT).unwrap();
        assert_eq!(plain, unit);
    }

    #[test]
    fn printtip_weights_rejects_invalid_inputs() {
        let (m, _w) = fixture();

        // The design must have one row per array.
        let short_design = Array2::<f64>::ones((3, 1));
        assert!(printtip_weights(&m, &short_design, None, LAYOUT).is_err());

        // At least three arrays are required.
        let two_arrays = Array2::<f64>::zeros((12, 2));
        let design2 = Array2::<f64>::ones((2, 1));
        assert!(printtip_weights(&two_arrays, &design2, None, LAYOUT).is_err());

        // There must be at least as many probes as arrays.
        let two_probes = Array2::<f64>::zeros((2, 4));
        assert!(printtip_weights(&two_probes, &design4(), None, LAYOUT).is_err());

        // The layout must tile the rows of M exactly. Here it describes only
        // 6 spots for 12 rows.
        let one_block = PrinterLayout {
            ngrid_r: 1,
            ngrid_c: 1,
            nspot_r: 2,
            nspot_c: 3,
        };
        assert!(printtip_weights(&m, &design4(), None, one_block).is_err());

        // Spot weights must have the same shape as M.
        let narrow_w = Array2::<f64>::ones((12, 3));
        assert!(printtip_weights(&m, &design4(), Some(&narrow_w), LAYOUT).is_err());
    }
}