// limma-rust 0.1.0
//
// Pure-Rust port of the Bioconductor limma differential-expression package.
//! Overflow-safe `log(cosh(x))` and `log(exp(x) + exp(y))`.
//!
//! Pure-Rust port of limma's `logsumexp.R` ([`logcosh`] and [`logsumexp`]).
//! Both avoid floating overflow/underflow: `logcosh` uses a small-x Taylor
//! term and a large-x linear tail, and `logsumexp` reduces to `logcosh` of the
//! half-difference.

use std::f64::consts::LN_2;

/// Computes `log(cosh(x))` without overflowing for large `|x|`.
///
/// Three regimes, chosen by `|x|`:
/// * `|x| < 1e-4`: the leading Taylor term `x^2 / 2`;
/// * `1e-4 <= |x| < 17`: the direct formula `ln(cosh(x))`;
/// * `|x| >= 17` (and NaN): the asymptote `|x| - log 2`. The neglected
///   term `log1p(exp(-2|x|))` is below `exp(-34)` there. The tail also keeps
///   `cosh` from ever seeing arguments large enough to overflow.
pub fn logcosh(x: f64) -> f64 {
    const TAYLOR_CUTOFF: f64 = 1e-4;
    const ASYMPTOTE_CUTOFF: f64 = 17.0;

    let magnitude = x.abs();
    match magnitude {
        m if m < TAYLOR_CUTOFF => 0.5 * x * x,
        m if m < ASYMPTOTE_CUTOFF => x.cosh().ln(),
        // NaN fails both guards above and lands here, yielding NaN.
        m => m - LN_2,
    }
}

/// `logsumexp(x, y)`: `log(exp(x) + exp(y))` without overflow or underflow.
///
/// Uses the identity `log(e^x + e^y) = m + log(cosh(d)) + log 2`. Here
/// `m = (x + y) / 2` is the midpoint and `d = (max - min) / 2` is the
/// half-difference. The only transcendental evaluated is therefore [`logcosh`],
/// which is itself overflow-safe.
///
/// Special values: NaN in either argument gives NaN. `+Inf` in either
/// argument gives `+Inf`. `-Inf` is the identity element, so
/// `logsumexp(-Inf, y) == y` (including `y == -Inf`).
pub fn logsumexp(x: f64, y: f64) -> f64 {
    if x.is_nan() || y.is_nan() {
        return f64::NAN;
    }
    // Neither argument is NaN past this point, so max/min are well defined.
    let ma = x.max(y);
    let mi = x.min(y);
    if ma == f64::INFINITY {
        return f64::INFINITY;
    }
    if mi == f64::NEG_INFINITY {
        return ma;
    }
    // Halve before adding. `(x + y) / 2.0` overflows to +/-Inf when both
    // inputs are near `f64::MAX` in magnitude, even though the true result is
    // finite. For normal-range inputs the two forms round identically.
    let m = x / 2.0 + y / 2.0;
    // `ma - m` is the non-negative half-difference `(ma - mi) / 2`. Its
    // magnitude never exceeds `f64::MAX`.
    m + logcosh(ma - m) + LN_2
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed absolute/relative comparison: `|got - want| <= tol * (1 + |want|)`.
    fn approx_eq(got: f64, want: f64, tol: f64) -> bool {
        (got - want).abs() <= tol + tol * want.abs()
    }

    #[test]
    fn logcosh_matches_r() {
        // Expected values are R's log(cosh(x)). Inputs cover the small-|x|
        // Taylor branch, the direct branch, and the linear tail, for both signs.
        let table: [(f64, f64); 9] = [
            (0.0, 0.0),
            (1e-5, 5e-11),
            (0.5, 0.120114506958277),
            (1.0, 0.433780830483027),
            (5.0, 4.30689821833927),
            (16.0, 15.3068528194401),
            (20.0, 19.3068528194401),
            (-3.0, 2.30932850457779),
            (-25.0, 24.3068528194401),
        ];
        for &(input, expected) in table.iter() {
            let got = logcosh(input);
            assert!(approx_eq(got, expected, 1e-12), "logcosh({}): {got}", input);
        }
    }

    #[test]
    fn logsumexp_matches_r() {
        // Rows are (x, y, log(exp(x) + exp(y))) evaluated in high precision.
        // The last two rows overflow/underflow a naive exp-then-log evaluation.
        const TOL: f64 = 1e-12;
        let rows = [
            (0.0, 0.0, std::f64::consts::LN_2),
            (1.0, 2.0, 2.31326168751822),
            (-5.0, 3.0, 3.0003354063729),
            (1000.0, 1001.0, 1001.31326168752),
            (-1000.0, -1001.0, -999.686738312482),
        ];
        for &(x, y, want) in rows.iter() {
            assert!(approx_eq(logsumexp(x, y), want, TOL), "logsumexp({x},{y})");
        }
    }

    #[test]
    fn logsumexp_special_values() {
        let pos_inf = f64::INFINITY;
        let neg_inf = f64::NEG_INFINITY;

        // +Inf absorbs everything.
        assert_eq!(logsumexp(pos_inf, 3.0), pos_inf);
        // -Inf is the identity element, including against itself.
        assert_eq!(logsumexp(5.0, neg_inf), 5.0);
        assert_eq!(logsumexp(neg_inf, neg_inf), neg_inf);
        // NaN propagates.
        assert!(logsumexp(f64::NAN, 1.0).is_nan());
    }
}