1use std::f64::consts::LN_2;
9
/// Computes `ln(cosh(x))` with full relative precision and without overflow.
///
/// * For `|x| < 17` the identity `cosh(x) - 1 = 2·sinh²(x/2)` is fed to
///   `ln_1p`. This stays accurate as `x → 0`, where forming `cosh(x)` first
///   would round away most significant digits of `cosh(x) - 1`.
/// * For `|x| ≥ 17`, `ln(cosh(x)) = |x| - ln 2 + ln(1 + e^(-2|x|))`. The last
///   term is at most `e^-34 ≈ 1.7e-15`, below half an ulp of the result, so it
///   is dropped. This branch never overflows (`cosh` itself would overflow
///   near `|x| ≈ 710`).
///
/// The function is even. NaN propagates, and `±∞` maps to `+∞`.
pub fn logcosh(x: f64) -> f64 {
    let ax = x.abs();
    if ax < 17.0 {
        // 2·sinh²(ax/2) equals cosh(ax) - 1 but is computed without cancellation.
        let s = (0.5 * ax).sinh();
        (2.0 * s * s).ln_1p()
    } else {
        // NaN also lands here (every comparison with NaN is false) and propagates.
        ax - LN_2
    }
}
22
/// Computes `ln(e^x + e^y)` without overflowing or underflowing the exponentials.
///
/// Evaluated as `max + ln_1p(e^(min - max))`. The exponent is `≤ 0`, so `exp`
/// lies in `(0, 1]` and cannot overflow. When the gap is large the correction
/// underflows cleanly to `0` and the result is exactly `max`. No large
/// intermediate terms are added and then cancelled, so small corrections are
/// never lost to rounding.
///
/// Special values:
/// * either argument NaN → NaN
/// * either argument `+∞` → `+∞`
/// * one argument `-∞` → the other argument (both `-∞` → `-∞`)
pub fn logsumexp(x: f64, y: f64) -> f64 {
    if x.is_nan() || y.is_nan() {
        return f64::NAN;
    }
    let ma = x.max(y);
    let mi = x.min(y);
    // These guards avoid `∞ - ∞ = NaN` in the subtraction below.
    if ma == f64::INFINITY {
        return f64::INFINITY;
    }
    if mi == f64::NEG_INFINITY {
        return ma;
    }
    // `mi - ma` may round to -∞ for extreme finite inputs (e.g. MAX and -MAX);
    // exp(-∞) == 0 then correctly yields `ma`.
    ma + (mi - ma).exp().ln_1p()
}
41
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed tolerance check: behaves as an absolute tolerance `tol` near zero
    /// and as a relative tolerance `tol` when `|b|` is large.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() <= tol + tol * b.abs()
    }

    /// Compares `logcosh` against tabulated reference values.
    /// NOTE(review): the test name suggests these were produced in R — confirm.
    #[test]
    fn logcosh_matches_r() {
        // (input, expected ln(cosh(input)))
        let cases = [
            (0.0, 0.0),
            (1e-5, 5e-11),
            (0.5, 0.120114506958277),
            (1.0, 0.433780830483027),
            (5.0, 4.30689821833927),
            (16.0, 15.3068528194401),
            (20.0, 19.3068528194401),
            (-3.0, 2.30932850457779),
            (-25.0, 24.3068528194401),
        ];
        for (x, want) in cases {
            let got = logcosh(x);
            assert!(close(got, want, 1e-12), "logcosh({x}): {got}");
        }
    }

    /// Compares `logsumexp` against tabulated reference values.
    /// The cases include inputs around ±1000, where a naive `exp` would
    /// overflow or underflow.
    #[test]
    fn logsumexp_matches_r() {
        // (x, y, expected ln(e^x + e^y))
        let cases = [
            (0.0, 0.0, std::f64::consts::LN_2),
            (1.0, 2.0, 2.31326168751822),
            (-5.0, 3.0, 3.0003354063729),
            (1000.0, 1001.0, 1001.31326168752),
            (-1000.0, -1001.0, -999.686738312482),
        ];
        for (x, y, want) in cases {
            assert!(close(logsumexp(x, y), want, 1e-12), "logsumexp({x},{y})");
        }
    }

    /// Checks the infinity and NaN handling of `logsumexp`.
    #[test]
    fn logsumexp_special_values() {
        assert_eq!(logsumexp(f64::INFINITY, 3.0), f64::INFINITY);
        assert_eq!(logsumexp(5.0, f64::NEG_INFINITY), 5.0);
        assert_eq!(
            logsumexp(f64::NEG_INFINITY, f64::NEG_INFINITY),
            f64::NEG_INFINITY
        );
        assert!(logsumexp(f64::NAN, 1.0).is_nan());
    }
}