1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
//! Overflow-safe `log(cosh(x))` and `log(exp(x) + exp(y))`.
//!
//! Pure-Rust port of limma's `logsumexp.R` ([`logcosh`] and [`logsumexp`]).
//! Both avoid floating-point overflow and underflow: `logcosh` uses a
//! small-`x` Taylor term and a large-`x` linear asymptote, and `logsumexp`
//! reduces to `logcosh` of the half-difference of its two arguments.
use std::f64::consts::LN_2;
/// `logcosh(x)`: `log(cosh(x))` without overflow.
///
/// Evaluated piecewise on `|x|`:
/// * `|x| < 1e-4`: the leading Taylor term `0.5 * x^2`;
/// * `1e-4 <= |x| < 17`: `ln(cosh(x))` directly (`cosh` cannot overflow here);
/// * `|x| >= 17`: the asymptote `|x| - ln 2`; the neglected term
///   `ln(1 + exp(-2|x|))` is below `2e-15` in this range.
///
/// The result is even in `x`. NaN propagates, and `±Inf` gives `+Inf`.
pub fn logcosh(x: f64) -> f64 {
    let ax = x.abs();
    if ax < 1e-4 {
        0.5 * x * x
    } else if ax < 17.0 {
        x.cosh().ln()
    } else {
        ax - LN_2
    }
}

/// `logsumexp(x, y)`: `log(exp(x) + exp(y))` without overflow or underflow.
///
/// With `hi = max(x, y)`, `lo = min(x, y)` and the half-gap `h = (hi - lo) / 2`,
/// this uses the identity
/// `log(exp(x) + exp(y)) = hi + logcosh(h) - (h - ln 2)`.
/// Here `logcosh(h) - (h - ln 2) = ln(1 + exp(-2h))`, a correction between
/// 0 and `ln 2`. This is the same quantity as limma's
/// `m + logcosh(x - m) + log(2)` with midpoint `m`. The correction is added
/// to the larger argument instead of the midpoint, which avoids catastrophic
/// cancellation when the arguments are far apart.
///
/// Propagates NaN; `+Inf` in either argument gives `+Inf`; `-Inf` is the
/// identity element.
pub fn logsumexp(x: f64, y: f64) -> f64 {
    if x.is_nan() || y.is_nan() {
        return f64::NAN;
    }
    let hi = x.max(y);
    let lo = x.min(y);
    // Infinite arguments would turn the correction below into `inf - inf`.
    if hi == f64::INFINITY {
        return f64::INFINITY;
    }
    if lo == f64::NEG_INFINITY {
        return hi;
    }
    // Halve each operand before subtracting so the half-gap stays finite for
    // all finite inputs (e.g. hi = f64::MAX, lo = -f64::MAX).
    let h = 0.5 * hi - 0.5 * lo;
    // ln(1 + exp(-2h)). For h >= 17, `logcosh` returns exactly `h - LN_2`, so
    // the correction is exactly 0 rather than a rounding residue of large terms.
    let correction = logcosh(h) - (h - LN_2);
    hi + correction
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed absolute/relative comparison: `|got - want| <= tol + tol * |want|`.
    fn close(got: f64, want: f64, tol: f64) -> bool {
        let err = (got - want).abs();
        err <= tol + tol * want.abs()
    }

    #[test]
    fn logcosh_matches_r() {
        // (x, log(cosh(x)) from R), covering the Taylor, direct and asymptotic
        // branches, including negative inputs.
        let table: [(f64, f64); 9] = [
            (0.0, 0.0),
            (1e-5, 5e-11),
            (0.5, 0.120114506958277),
            (1.0, 0.433780830483027),
            (5.0, 4.30689821833927),
            (16.0, 15.3068528194401),
            (20.0, 19.3068528194401),
            (-3.0, 2.30932850457779),
            (-25.0, 24.3068528194401),
        ];
        for &(x, want) in table.iter() {
            let got = logcosh(x);
            assert!(close(got, want, 1e-12), "logcosh({}): {got}", x);
        }
    }

    #[test]
    fn logsumexp_matches_r() {
        // (x, y, log(exp(x) + exp(y))) evaluated in high precision.
        let table: [(f64, f64, f64); 5] = [
            (0.0, 0.0, std::f64::consts::LN_2),
            (1.0, 2.0, 2.31326168751822),
            (-5.0, 3.0, 3.0003354063729),
            (1000.0, 1001.0, 1001.31326168752),
            (-1000.0, -1001.0, -999.686738312482),
        ];
        for &(x, y, want) in &table {
            let got = logsumexp(x, y);
            assert!(close(got, want, 1e-12), "logsumexp({x},{y})");
        }
    }

    #[test]
    fn logsumexp_special_values() {
        let pos_inf = f64::INFINITY;
        let neg_inf = f64::NEG_INFINITY;
        assert_eq!(logsumexp(pos_inf, 3.0), pos_inf);
        assert_eq!(logsumexp(5.0, neg_inf), 5.0);
        assert_eq!(logsumexp(neg_inf, neg_inf), neg_inf);
        assert!(logsumexp(f64::NAN, 1.0).is_nan());
    }
}