Skip to main content

limma/
logsumexp.rs

//! Overflow-safe `log(cosh(x))` and `log(exp(x) + exp(y))`.
//!
//! Pure-Rust port of limma's `logsumexp.R` ([`logcosh`] and [`logsumexp`]).
//! Both avoid floating overflow/underflow. `logcosh` uses a small-x Taylor
//! term and a large-x linear tail; `logsumexp` is tied to `logcosh` through the
//! half-difference identity
//! `log(e^x + e^y) = (x + y)/2 + log 2 + logcosh((x - y)/2)`.
7
8use std::f64::consts::LN_2;
9
/// `logcosh(x)`: evaluates `log(cosh(x))` without letting `cosh` overflow.
///
/// The computation is split on `|x|`:
/// * `|x| < 1e-4`: leading Taylor term `x^2 / 2`;
/// * `1e-4 <= |x| < 17`: direct `cosh(x).ln()`;
/// * `|x| >= 17` (and NaN, which fails both comparisons): the asymptote
///   `|x| - log 2`, because `cosh(x) ≈ e^|x| / 2` there.
pub fn logcosh(x: f64) -> f64 {
    const TAYLOR_LIMIT: f64 = 1e-4;
    const ASYMPTOTE_LIMIT: f64 = 17.0;

    let magnitude = x.abs();
    match magnitude {
        m if m < TAYLOR_LIMIT => 0.5 * x * x,
        m if m < ASYMPTOTE_LIMIT => x.cosh().ln(),
        m => m - LN_2,
    }
}
22
/// `logsumexp(x, y)`: `log(exp(x) + exp(y))` without overflow or underflow.
///
/// Mathematically this equals limma's half-difference form
/// `m + log 2 + logcosh(x - m)` with `m = (x + y) / 2`. That form is not
/// evaluated literally here for two reasons:
/// * `x + y` overflows to `+Inf` once two large finite arguments sum past
///   `f64::MAX` (e.g. `x = y = 1e308`);
/// * for a large half-difference `d`, the terms `m` and `logcosh(d) ≈ d - log 2`
///   cancel, losing accuracy whenever the result is small.
///
/// Factoring out the larger argument gives the equivalent
/// `max + ln_1p(exp(min - max))`. There `exp` only sees a non-positive argument,
/// so it cannot overflow, and `ln_1p` keeps full relative accuracy for a tiny
/// correction term.
///
/// Special values: NaN in either argument gives NaN; `+Inf` in either argument
/// gives `+Inf`; `-Inf` is the identity element, so
/// `logsumexp(-Inf, -Inf) == -Inf`.
pub fn logsumexp(x: f64, y: f64) -> f64 {
    if x.is_nan() || y.is_nan() {
        return f64::NAN;
    }
    let ma = x.max(y);
    let mi = x.min(y);
    if ma == f64::INFINITY {
        return f64::INFINITY;
    }
    if mi == f64::NEG_INFINITY {
        // Also covers both arguments being -Inf, where `mi - ma` would be NaN.
        return ma;
    }
    // `mi - ma <= 0`, so `exp` lies in [0, 1]. If the gap itself overflows to
    // -Inf (huge arguments of opposite sign), `exp` yields 0 and the result is `ma`.
    ma + (mi - ma).exp().ln_1p()
}
41
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed tolerance check: absolute slack `tol` plus relative slack `tol * |want|`.
    fn close(got: f64, want: f64, tol: f64) -> bool {
        (got - want).abs() <= tol + tol * want.abs()
    }

    #[test]
    fn logcosh_matches_r() {
        // R reference values for log(cosh(x)), spanning the small-x Taylor,
        // direct, and asymptotic regimes (including negative inputs).
        let cases: [(f64, f64); 9] = [
            (0.0, 0.0),
            (1e-5, 5e-11),
            (0.5, 0.120114506958277),
            (1.0, 0.433780830483027),
            (5.0, 4.30689821833927),
            (16.0, 15.3068528194401),
            (20.0, 19.3068528194401),
            (-3.0, 2.30932850457779),
            (-25.0, 24.3068528194401),
        ];
        for &(input, want) in cases.iter() {
            let got = logcosh(input);
            assert!(close(got, want, 1e-12), "logcosh({}): {got}", input);
        }
    }

    #[test]
    fn logsumexp_matches_r() {
        // High-precision references for log(exp(x) + exp(y)), including
        // arguments whose naive exp() would overflow or underflow.
        let cases: [(f64, f64, f64); 5] = [
            (0.0, 0.0, std::f64::consts::LN_2),
            (1.0, 2.0, 2.31326168751822),
            (-5.0, 3.0, 3.0003354063729),
            (1000.0, 1001.0, 1001.31326168752),
            (-1000.0, -1001.0, -999.686738312482),
        ];
        for &(x, y, want) in cases.iter() {
            let got = logsumexp(x, y);
            assert!(close(got, want, 1e-12), "logsumexp({x},{y})");
        }
    }

    #[test]
    fn logsumexp_special_values() {
        let pos_inf = f64::INFINITY;
        let neg_inf = f64::NEG_INFINITY;
        assert_eq!(logsumexp(pos_inf, 3.0), pos_inf);
        assert_eq!(logsumexp(5.0, neg_inf), 5.0);
        assert_eq!(logsumexp(neg_inf, neg_inf), neg_inf);
        assert!(logsumexp(f64::NAN, 1.0).is_nan());
    }
}