rsomics-limma-array-weights 0.1.0

Estimate per-sample (array) quality weights by REML from a log-expression matrix and a design matrix — a clean-room Rust reimplementation of limma's `arrayWeights`.
Documentation
//! REML scoring for limma's array-weight model.
//!
//! For each gene g the log-expression row follows  y_g ~ N(X b_g, s2_g * diag(exp(gamma_j))),
//! where gamma_j is a per-array log-variance offset shared across genes (sum gamma = 0) and
//! s2_g is a per-gene scale. The array weights are w_j = exp(-gamma_j), renormalised to
//! geometric mean 1. gamma is fitted by Fisher scoring on the REML deviance pooled over genes;
//! the dispersion design is the array indicator, so the score and (full) information follow
//! Smyth (2002)'s heteroscedastic REML with the residual projection's Hadamard square.

use rsomics_common::{Result, RsomicsError};

/// Result of an array-weight fit.
#[derive(Debug, Clone, PartialEq)]
pub struct Fit {
    /// Per-array weights `exp(-gamma_j)`, rescaled to geometric mean 1.
    pub weights: Vec<f64>,
    /// Number of Fisher-scoring iterations performed.
    pub iters: usize,
    /// `true` when the largest update to any `gamma_j` fell below the tolerance before the
    /// iteration limit was reached.
    pub converged: bool,
}

/// Estimate per-array quality weights by REML Fisher scoring.
///
/// `y` is [gene][array]; `x` is [array][coef].
///
/// * `prior_n` — strength of a prior toward equal weights, expressed as a number of
///   pseudo-genes asserting `gamma_j = 0`; `0.0` disables it.
/// * `maxiter` — maximum number of Fisher-scoring iterations.
/// * `tol` — the iteration stops once no `gamma_j` changes by `tol` or more in one step.
///
/// Genes whose residual scale is zero or non-finite are left out of the pooled score and
/// information. Zero scale covers rows the design fits exactly, such as constant rows; non-finite
/// scale covers rows containing NaN/Inf. Their standardised residuals would be 0/0 or NaN.
///
/// Returns the weights (geometric mean 1), the iteration count and the convergence flag.
///
/// # Errors
/// `RsomicsError::InvalidInput` in any of these cases:
/// - `x` has no rows, or its rows differ in length.
/// - The design leaves no residual degrees of freedom.
/// - A row of `y` does not have one value per array.
/// - `prior_n` is negative or NaN.
/// - `X'WX` or the information matrix is singular.
/// - A scoring update is not finite.
pub fn array_weights(
    y: &[Vec<f64>],
    x: &[Vec<f64>],
    prior_n: f64,
    maxiter: usize,
    tol: f64,
) -> Result<Fit> {
    let n = x.len();
    let p = x
        .first()
        .map(Vec::len)
        .ok_or_else(|| RsomicsError::InvalidInput("design matrix has no rows".into()))?;
    if x.iter().any(|row| row.len() != p) {
        return Err(RsomicsError::InvalidInput(format!(
            "design matrix is ragged: every row must have {p} coefficients"
        )));
    }
    if n <= p {
        return Err(RsomicsError::InvalidInput(format!(
            "design is saturated: {n} arrays, {p} coefficients leave no residual df"
        )));
    }
    if let Some(g) = y.iter().position(|row| row.len() != n) {
        return Err(RsomicsError::InvalidInput(format!(
            "expression row {g} has {} values but the design has {n} arrays",
            y[g].len()
        )));
    }
    if prior_n.is_nan() || prior_n < 0.0 {
        return Err(RsomicsError::InvalidInput(format!(
            "prior_n must be non-negative, got {prior_n}"
        )));
    }
    let df_resid = (n - p) as f64;

    let mut gamma = vec![0.0f64; n];
    let mut converged = false;
    let mut iters = 0;
    // Per-gene buffer of weighted squared residuals, reused across genes and iterations.
    let mut rw2 = vec![0.0f64; n];

    for it in 1..=maxiter {
        iters = it;
        let w: Vec<f64> = gamma.iter().map(|g| (-g).exp()).collect();

        let hat = weighted_hat(x, &w)?;
        let h: Vec<f64> = (0..n).map(|i| hat[i][i]).collect();

        // a = W^{1/2} (I - H) W^{-1/2}: the symmetric residual projection in the weighted metric.
        let sw: Vec<f64> = w.iter().map(|wj| wj.sqrt()).collect();
        let mut a = vec![vec![0.0f64; n]; n];
        for i in 0..n {
            for k in 0..n {
                let proj = if i == k { 1.0 } else { 0.0 } - hat[i][k];
                a[i][k] = sw[i] * proj / sw[k];
            }
        }

        // Pool per-array deviance components over genes. For gene g the standardised squared
        // residual w_j r_gj^2, divided by the gene's REML scale s2_g, has expectation (1 - h_j).
        let mut std_sumsq = vec![0.0f64; n];
        let mut n_informative = 0usize;
        for row in y {
            let r = residual(row, &hat);
            let mut rss = 0.0;
            for ((slot, &wj), &rj) in rw2.iter_mut().zip(&w).zip(&r) {
                let v = wj * rj * rj;
                *slot = v;
                rss += v;
            }
            let s2 = rss / df_resid;
            // A zero scale would make every standardised residual 0/0. A NaN/Inf scale comes from
            // NaN/Inf in the row. Either way the resulting NaN would poison every array's score.
            if s2 <= 0.0 || !s2.is_finite() {
                continue;
            }
            n_informative += 1;
            for (acc, &v) in std_sumsq.iter_mut().zip(&rw2) {
                *acc += v / s2;
            }
        }
        let ng = n_informative as f64;

        let mut score = vec![0.0f64; n];
        let mut info = vec![vec![0.0f64; n]; n];
        for j in 0..n {
            for (info_jk, &a_jk) in info[j].iter_mut().zip(&a[j]) {
                *info_jk = 0.5 * ng * a_jk * a_jk;
            }
            // Prior: prior_n pseudo-genes asserting unit array weight (gamma_j = 0), entering
            // as a gamma-deviance term that pulls extreme weights back toward equality.
            let e = (-gamma[j]).exp();
            score[j] = 0.5 * (std_sumsq[j] - ng * (1.0 - h[j])) + 0.5 * prior_n * (e - 1.0);
            info[j][j] += 0.5 * prior_n * e;
        }

        // gamma lives on the sum-to-zero hyperplane; a small constant ridge lifts the
        // information off its null space so the dense solve is well posed.
        let ridge = mean_diag(&info) * 1e-4;
        for row in &mut info {
            for v in row.iter_mut() {
                *v += ridge;
            }
        }

        // The inverter's own message names the design matrix, which would mislead here.
        let mut step = solve(&info, &score).map_err(|_| {
            RsomicsError::InvalidInput("array-weight information matrix is singular".into())
        })?;
        // A NaN step would otherwise slip past the convergence test (f64::max ignores NaN)
        // and be reported as converged.
        if step.iter().any(|s| !s.is_finite()) {
            return Err(RsomicsError::InvalidInput(
                "array-weight scoring update is not finite".into(),
            ));
        }
        let mean_step = step.iter().sum::<f64>() / n as f64;
        for s in &mut step {
            *s = (*s - mean_step).clamp(-1.0, 1.0);
        }

        let mut max_abs = 0.0f64;
        for j in 0..n {
            gamma[j] += step[j];
            max_abs = max_abs.max(step[j].abs());
        }
        let mean_gamma = gamma.iter().sum::<f64>() / n as f64;
        for g in &mut gamma {
            *g -= mean_gamma;
        }

        if max_abs < tol {
            converged = true;
            break;
        }
    }

    let mut weights: Vec<f64> = gamma.iter().map(|g| (-g).exp()).collect();
    let log_mean = weights.iter().map(|w| w.ln()).sum::<f64>() / n as f64;
    let scale = log_mean.exp();
    for w in &mut weights {
        *w /= scale;
    }

    Ok(Fit {
        weights,
        iters,
        converged,
    })
}

/// Weighted least-squares hat matrix `H = X (X'WX)^{-1} X' W` for the design `x`
/// ([array][coef]) and per-array weights `w`. `H y` is the WLS fit of `y`.
///
/// Propagates the inverter's error when `X'WX` is singular.
fn weighted_hat(x: &[Vec<f64>], w: &[f64]) -> Result<Vec<Vec<f64>>> {
    let ncoef = x[0].len();
    // Slice up front so a short weight vector still panics instead of being truncated.
    let w = &w[..x.len()];

    // Normal-equations matrix X'WX, accumulated one array (row of X) at a time.
    let mut gram = vec![vec![0.0f64; ncoef]; ncoef];
    for (xrow, &wi) in x.iter().zip(w) {
        let xrow = &xrow[..ncoef];
        for (gram_row, &xa) in gram.iter_mut().zip(xrow) {
            let scaled = xa * wi;
            for (cell, &xb) in gram_row.iter_mut().zip(xrow) {
                *cell += scaled * xb;
            }
        }
    }
    let gram_inv = invert(&gram)?;

    // Row i of H: the coefficient-space vector x_i' (X'WX)^{-1}, dotted with every row x_k
    // and scaled by that array's weight w_k.
    let hat: Vec<Vec<f64>> = x
        .iter()
        .map(|xi| {
            let lifted: Vec<f64> = (0..ncoef)
                .map(|a| (0..ncoef).fold(0.0, |acc, b| acc + xi[b] * gram_inv[b][a]))
                .collect();
            x.iter()
                .zip(w)
                .map(|(xk, &wk)| {
                    lifted
                        .iter()
                        .zip(xk)
                        .fold(0.0, |acc, (&la, &xka)| acc + la * xka)
                        * wk
                })
                .collect::<Vec<f64>>()
        })
        .collect();
    Ok(hat)
}

/// Residual vector `y - H y` for a square projection matrix `hat` whose size matches `y`.
fn residual(y: &[f64], hat: &[Vec<f64>]) -> Vec<f64> {
    let len = y.len();
    hat[..len]
        .iter()
        .zip(y)
        .map(|(hat_row, &yi)| {
            let fitted = hat_row[..len]
                .iter()
                .zip(y)
                .fold(0.0, |acc, (&hik, &yk)| acc + hik * yk);
            yi - fitted
        })
        .collect()
}

/// Mean of the diagonal entries of a square matrix (trace / size); NaN for an empty matrix.
fn mean_diag(m: &[Vec<f64>]) -> f64 {
    let trace: f64 = m.iter().enumerate().map(|(i, row)| row[i]).sum();
    trace / m.len() as f64
}

/// Inverts a small dense square matrix by Gauss–Jordan elimination with partial pivoting.
///
/// # Errors
/// Returns `RsomicsError::InvalidInput` in two cases:
/// - Any entry of `m` is NaN or infinite. Without this check such input would pass the pivot
///   test and come back as an `Ok` matrix full of NaN.
/// - A pivot's magnitude falls below the absolute threshold `1e-12`, i.e. the matrix is
///   singular to working precision.
///
/// NOTE: the singular-case message names the design matrix even when this is reached through
/// `solve`.
fn invert(m: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    // Absolute pivot magnitude below which the matrix is treated as singular.
    const PIVOT_TOL: f64 = 1e-12;

    // `NaN < PIVOT_TOL` is false, so non-finite entries must be rejected explicitly.
    if m.iter().flatten().any(|v| !v.is_finite()) {
        return Err(RsomicsError::InvalidInput(
            "matrix to invert has non-finite entries".into(),
        ));
    }

    let n = m.len();
    let mut a: Vec<Vec<f64>> = m.to_vec();
    let mut inv = vec![vec![0.0f64; n]; n];
    for (i, row) in inv.iter_mut().enumerate() {
        row[i] = 1.0;
    }
    for col in 0..n {
        // Partial pivoting: use the largest-magnitude entry at or below the diagonal.
        let mut pivot = col;
        for r in (col + 1)..n {
            if a[r][col].abs() > a[pivot][col].abs() {
                pivot = r;
            }
        }
        if a[pivot][col].abs() < PIVOT_TOL {
            return Err(RsomicsError::InvalidInput(
                "design matrix is rank deficient".into(),
            ));
        }
        a.swap(col, pivot);
        inv.swap(col, pivot);
        // Normalise the pivot row, then clear this column from every other row.
        let d = a[col][col];
        for j in 0..n {
            a[col][j] /= d;
            inv[col][j] /= d;
        }
        for r in 0..n {
            if r == col {
                continue;
            }
            let f = a[r][col];
            if f == 0.0 {
                continue;
            }
            for j in 0..n {
                a[r][j] -= f * a[col][j];
                inv[r][j] -= f * inv[col][j];
            }
        }
    }
    Ok(inv)
}

/// Solves the square linear system `m · x = b` by forming `m^{-1}` with `invert` and
/// multiplying it into `b`.
///
/// Propagates `invert`'s error when `m` is singular.
fn solve(m: &[Vec<f64>], b: &[f64]) -> Result<Vec<f64>> {
    let m_inv = invert(m)?;
    let solution: Vec<f64> = (0..b.len())
        .map(|row| {
            b.iter()
                .enumerate()
                .fold(0.0, |acc, (col, &bc)| acc + m_inv[row][col] * bc)
        })
        .collect();
    Ok(solution)
}