/// Log-space representation of zero: `ln(0) = -inf`.
///
/// Acts as the identity element for [`log_add`] and [`log_sum`].
pub const LOG_0: f64 = -f64::INFINITY;
/// Log-space representation of one: `ln(1) = 0`.
pub const LOG_1: f64 = 0.0_f64;
/// A large, finite negative log value (`-1e10`).
///
/// NOTE(review): not referenced in this section; presumably a finite stand-in
/// for "effectively zero" in log space — confirm against its callers.
pub const LOG_EPSILON: f64 = -1e10;
/// Computes `ln(exp(x) + exp(y))` without leaving log space.
///
/// If either argument is [`LOG_0`] (log of zero), the other argument is
/// returned unchanged. A NaN in either argument yields NaN.
///
/// The larger argument is factored out, so the only exponential taken is
/// `exp(lo - hi) <= 1`, which cannot overflow. `ln_1p` then keeps precision
/// when that term is tiny.
#[inline]
pub fn log_add(x: f64, y: f64) -> f64 {
    if x == LOG_0 {
        return y;
    }
    if y == LOG_0 {
        return x;
    }
    // ln(inf + inf) = inf; without this guard `lo - hi` below would be
    // `inf - inf = NaN`. Checked on both arguments so NaN inputs still propagate.
    if x == f64::INFINITY && y == f64::INFINITY {
        return f64::INFINITY;
    }
    let (hi, lo) = if x > y { (x, y) } else { (y, x) };
    hi + (lo - hi).exp().ln_1p()
}
/// Computes `ln(exp(x) - exp(y))` without leaving log space.
///
/// Requires `x >= y`; in release builds an `x < y` input yields NaN.
/// Subtracting [`LOG_0`] returns `x` unchanged, and `x == y` returns [`LOG_0`].
///
/// The result is `x + ln(1 - exp(d))` with `d = y - x <= 0`. Computing
/// `1 - exp(d)` directly cancels catastrophically when `d` is near zero (close
/// arguments), because `exp(d)` rounds to a value near 1. So the term is
/// evaluated as `ln(-expm1(d))` for `d > -ln 2` and as `ln_1p(-exp(d))`
/// otherwise, each of which is accurate on its own range.
///
/// # Panics
///
/// In debug builds, panics if `x >= y` does not hold (including when either
/// argument is NaN).
#[inline]
pub fn log_sub(x: f64, y: f64) -> f64 {
    debug_assert!(x >= y, "log_sub requires x >= y (x={x}, y={y})");
    if y == LOG_0 {
        return x;
    }
    if x == y {
        return LOG_0;
    }
    let d = y - x;
    if d > -std::f64::consts::LN_2 {
        // Near-equal arguments: expm1 keeps full relative precision of 1 - e^d.
        x + (-d.exp_m1()).ln()
    } else {
        // Well-separated arguments: e^d <= 1/2, so ln_1p(-e^d) is accurate.
        x + (-d.exp()).ln_1p()
    }
}
pub fn log_sum<I: IntoIterator<Item = f64>>(values: I) -> f64 {
values.into_iter().fold(LOG_0, log_add)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `a` is within `1e-9` of `b`.
    ///
    /// Exact equality is checked first so that equal infinities (such as
    /// `LOG_0` against `LOG_0`) pass; their difference would be NaN.
    fn approx(a: f64, b: f64) {
        let close = a == b || (a - b).abs() < 1e-9;
        assert!(close, "expected {b}, got {a}");
    }

    #[test]
    fn add_identities() {
        // LOG_0 is the identity for log_add, from either side.
        for (lhs, rhs) in [(LOG_0, 0.5), (0.5, LOG_0)] {
            approx(log_add(lhs, rhs), 0.5);
        }
        // ln(1 + 1) = ln 2.
        approx(log_add(LOG_1, LOG_1), 2.0_f64.ln());
    }

    #[test]
    fn add_matches_naive() {
        let x = -3.2_f64;
        let y = -1.1_f64;
        let expected = (x.exp() + y.exp()).ln();
        approx(log_add(x, y), expected);
    }

    #[test]
    fn sub_works() {
        let big = 2.0_f64;
        let small = 1.0_f64;
        let expected = (big.exp() - small.exp()).ln();
        approx(log_sub(big, small), expected);
        // Subtracting log-zero is a no-op; subtracting a value from itself gives log-zero.
        approx(log_sub(big, LOG_0), big);
        approx(log_sub(big, big), LOG_0);
    }

    #[test]
    fn sum_works() {
        // Three copies of ln(1) sum to ln(3).
        approx(log_sum([LOG_1; 3]), 3.0_f64.ln());
        // The empty sum is log-zero.
        approx(log_sum(std::iter::empty::<f64>()), LOG_0);
    }
}