limma-rust 0.1.0

Pure-Rust port of the Bioconductor limma differential-expression package
Documentation
//! Weighted median.
//!
//! Pure-Rust port of limma's `weightedmedian.R` ([`weighted_median`]): the
//! median of the discrete distribution that places probability proportional to
//! `w` on each `x`.

/// `weighted.median(x, w, na.rm)`: the median of the discrete distribution
/// that places probability proportional to `w[i]` on `x[i]`.
///
/// `w = None` weights every point equally. With `na_rm`, NaN `x` (and their
/// weights) are dropped first. Otherwise NaN `x` sort after every number, as
/// with R's `order()`, and the result is NaN when the median lands on one.
/// Zero-weight observations are ignored. When the cumulative probability is
/// exactly one half at some value, the midpoint of that value and the next
/// one is returned.
///
/// Returns NaN if all remaining weights are zero, including when no
/// observations remain.
///
/// # Panics
///
/// - If `w` and `x` have different lengths.
/// - If any weight is NaN or negative. This is checked on all of `w` before
///   NaN `x` are dropped.
/// - If the remaining weights do not sum to a finite value (e.g. an infinite
///   weight), since the cumulative probabilities would then be 0 or NaN.
pub fn weighted_median(x: &[f64], w: Option<&[f64]>, na_rm: bool) -> f64 {
    let ones;
    let w = match w {
        Some(w) => {
            assert_eq!(w.len(), x.len(), "'x' and 'w' must have the same length");
            // Validate every weight up front, before `na_rm` filtering, so a
            // bad weight is rejected even when its observation would be dropped.
            for &wi in w {
                assert!(!wi.is_nan(), "NA weights not allowed");
                assert!(wi >= 0.0, "Negative weights not allowed");
            }
            w
        }
        None => {
            ones = vec![1.0; x.len()];
            &ones
        }
    };

    // Pair up, optionally dropping NaN observations.
    let mut pairs: Vec<(f64, f64)> = x
        .iter()
        .zip(w)
        .filter(|&(&xi, _)| !(na_rm && xi.is_nan()))
        .map(|(&xi, &wi)| (xi, wi))
        .collect();

    let total: f64 = pairs.iter().map(|&(_, wi)| wi).sum();
    if total == 0.0 {
        return f64::NAN; // "All weights are zero"
    }
    // An infinite total (an infinite weight, or finite weights that overflow)
    // would turn the cumulative probabilities below into 0 or inf/inf = NaN.
    assert!(total.is_finite(), "Sum of weights must be finite");

    // Drop zero weights and sort by x ascending. `sort_by` is stable, so ties
    // keep their input order like R's order(). NaN x can only be present
    // here when `!na_rm`. Instead of panicking, they compare greater than
    // every number, whatever their sign bit, mirroring order()'s NA-last
    // placement.
    pairs.retain(|&(_, wi)| wi != 0.0);
    pairs.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
    });

    // Cumulative probabilities of the sorted observations.
    let mut cum = 0.0;
    let p: Vec<f64> = pairs
        .iter()
        .map(|&(_, wi)| {
            cum += wi;
            cum / total
        })
        .collect();
    // All weights are finite and non-negative, so `p` is non-decreasing and its
    // last entry is ~1. The count of entries below one half is therefore the
    // first index at or above it, and is a valid index into `pairs`.
    let n = p.partition_point(|&pi| pi < 0.5);
    if p[n] > 0.5 {
        pairs[n].0
    } else {
        // p[n] == 0.5 exactly: any point between this value and the next is a
        // median, so take the midpoint. n + 1 exists because p[n] < p[last].
        (pairs[n].0 + pairs[n + 1].0) / 2.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `actual` is within 1e-12 (absolute plus relative to `expected`)
    /// of `expected`, reporting both values on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= 1e-12 + 1e-12 * expected.abs(),
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn weighted_median_matches_r() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let w = [1.0, 1.0, 1.0, 1.0, 1.0];
        // Reference: weighted.median(x, w) in limma 3.68.3.
        assert_close(weighted_median(&x, Some(&w), false), 3.0);

        let w2 = [0.1, 0.1, 0.1, 0.1, 5.0];
        assert_close(weighted_median(&x, Some(&w2), false), 5.0);

        let x3 = [1.0, 2.0, 3.0, 4.0];
        let w3 = [1.0, 1.0, 1.0, 1.0];
        assert_close(weighted_median(&x3, Some(&w3), false), 2.5);

        let x4 = [10.0, 20.0, 30.0];
        let w4 = [1.0, 2.0, 1.0];
        assert_close(weighted_median(&x4, Some(&w4), false), 20.0);
    }

    #[test]
    fn weighted_median_unweighted_default() {
        let x = [5.0, 1.0, 3.0, 2.0, 4.0];
        assert_close(weighted_median(&x, None, false), 3.0);
    }

    #[test]
    fn weighted_median_na_rm() {
        let x = [1.0, 2.0, f64::NAN, 4.0, 5.0];
        let w = [1.0, 1.0, 1.0, 1.0, 1.0];
        assert_close(weighted_median(&x, Some(&w), true), 3.0);
    }

    #[test]
    fn weighted_median_all_zero_weights_is_nan() {
        let x = [1.0, 2.0, 3.0];
        let w = [0.0, 0.0, 0.0];
        assert!(weighted_median(&x, Some(&w), false).is_nan());
    }

    #[test]
    fn weighted_median_empty_is_nan() {
        assert!(weighted_median(&[], None, false).is_nan());
    }

    #[test]
    fn weighted_median_ignores_zero_weights() {
        // Without dropping the zero weight, the tie at p = 0.5 would average
        // 1.0 and 2.0; with it dropped, the midpoint is between 1.0 and 3.0.
        let x = [1.0, 2.0, 3.0];
        let w = [1.0, 0.0, 1.0];
        assert_close(weighted_median(&x, Some(&w), false), 2.0);
    }

    #[test]
    #[should_panic(expected = "'x' and 'w' must have the same length")]
    fn weighted_median_length_mismatch_panics() {
        let _ = weighted_median(&[1.0, 2.0], Some(&[1.0]), false);
    }

    #[test]
    #[should_panic(expected = "Negative weights not allowed")]
    fn weighted_median_negative_weight_panics() {
        let _ = weighted_median(&[1.0, 2.0], Some(&[1.0, -1.0]), false);
    }

    #[test]
    #[should_panic(expected = "NA weights not allowed")]
    fn weighted_median_nan_weight_panics() {
        let _ = weighted_median(&[1.0, 2.0], Some(&[1.0, f64::NAN]), false);
    }
}