#![forbid(unsafe_code)]
#![warn(clippy::pedantic, missing_docs)]
use num_traits::{Float, FloatConst, One, Zero};
/// Adds two values that are stored as natural logarithms, staying in log space.
///
/// For floats, `a.ln_add_exp(b)` computes `ln(exp(a) + exp(b))`. The
/// right-hand side defaults to `Self` but may be a different type (for
/// example a reference to `Self`).
pub trait LogAddExp<Rhs = Self> {
    /// Type produced by [`LogAddExp::ln_add_exp`].
    type Output;
    /// Returns `ln(exp(self) + exp(other))`.
    #[must_use]
    fn ln_add_exp(self, other: Rhs) -> Self::Output;
}
impl<T> LogAddExp for T
where
    T: Float + FloatConst,
{
    type Output = T;

    /// Computes `ln(exp(self) + exp(other))` without forming either
    /// exponential of the raw inputs.
    ///
    /// * Equal arguments short-circuit to `self + ln 2`. This also covers two
    ///   equal infinities, whose difference would otherwise be NaN.
    /// * If either argument is NaN, the NaN produced by `self - other` is
    ///   returned.
    /// * Otherwise the larger argument is factored out, giving
    ///   `larger + ln(1 + exp(-|self - other|))`. Its exponent is never
    ///   positive, so `exp` cannot overflow.
    fn ln_add_exp(self, other: Self) -> Self {
        if self == other {
            return self + Self::LN_2();
        }

        let gap = self - other;
        if gap.is_nan() {
            return gap;
        }

        // Pair the larger operand with the matching non-positive exponent.
        let (larger, exponent) = if gap > Self::zero() {
            (self, -gap)
        } else {
            (other, gap)
        };
        larger + exponent.exp().ln_1p()
    }
}
impl<T> LogAddExp<&T> for T
where
    T: Float + FloatConst,
{
    type Output = T;

    /// Handles a borrowed right-hand side. It copies the float out of the
    /// reference and defers to the by-value `LogAddExp<T>` implementation.
    fn ln_add_exp(self, other: &Self) -> T {
        let rhs = *other;
        <Self as LogAddExp<Self>>::ln_add_exp(self, rhs)
    }
}
/// Computes the log of a sum of exponentials, `ln(Σ exp(x))`, over a
/// collection of values.
pub trait LogSumExp {
    /// Type produced by [`LogSumExp::ln_sum_exp`].
    type Output;
    /// Consumes `self` and returns `ln(Σ exp(x))` over the values it yields.
    #[must_use]
    fn ln_sum_exp(self) -> Self::Output;
}
impl<T> LogSumExp for T
where
    T: Iterator,
    T::Item: Float,
{
    type Output = T::Item;

    /// Computes `ln(Σ exp(x))` over the iterator in a single pass.
    ///
    /// The fold tracks a running maximum `peak` together with
    /// `Σ exp(x - peak)`. When a new maximum arrives, the accumulated sum is
    /// rescaled to it, so no exponent evaluated is ever positive. The result
    /// is `ln(sum) + peak`.
    ///
    /// An empty iterator yields negative infinity, and any NaN in the input
    /// makes the result NaN.
    fn ln_sum_exp(self) -> Self::Output {
        let one = Self::Output::one();
        let (peak, scaled_sum) = self.fold(
            (Self::Output::neg_infinity(), Self::Output::zero()),
            |(peak, sum), x| {
                if x > peak {
                    (x, sum * (peak - x).exp() + one)
                } else if x == peak {
                    // Exact tie: exp(0) == 1. Handling it separately also keeps
                    // equal infinities from evaluating `inf - inf` (NaN).
                    (peak, sum + one)
                } else {
                    (peak, sum + (x - peak).exp())
                }
            },
        );
        scaled_sum.ln() + peak
    }
}
#[cfg(test)]
mod tests {
    use super::{LogAddExp, LogSumExp};

    // `assert_close!(a, b [, rtol = R] [, atol = A])` asserts
    // `|a - b| <= atol + rtol * |b|`. Both tolerances are optional and may be
    // given in either order. Omitted ones default to `rtol = 1e-5` and
    // `atol = 1e-8`.
    macro_rules! assert_close {
        ($a:expr, $b:expr, rtol = $rtol:expr, atol = $atol:expr) => {{
            let a = $a;
            let b = $b;
            let rtol = $rtol;
            let atol = $atol;
            assert!(
                (a - b).abs() <= atol + rtol * b.abs(),
                "assertion `left ~= right` failed (rtol = {:?}, atol = {:?})\n  left: `{:?}`,\n right: `{:?}`",
                rtol,
                atol,
                a,
                b,
            );
        }};
        ($a:expr, $b:expr, atol = $atol:expr, rtol = $rtol:expr) => {
            // Normalise to the canonical `rtol, atol` order handled above.
            assert_close!($a, $b, rtol = $rtol, atol = $atol);
        };
        ($a:expr, $b:expr, rtol = $rtol:expr) => {
            assert_close!($a, $b, rtol = $rtol, atol = 1e-8);
        };
        ($a:expr, $b:expr, atol = $atol:expr) => {
            assert_close!($a, $b, atol = $atol, rtol = 1e-5);
        };
        ($a:expr, $b:expr) => {
            assert_close!($a, $b, rtol = 1e-5);
        };
    }

    #[test]
    fn test_assert_close_tolerance_forms() {
        // Exercise every `assert_close!` arm, including both tolerance orders.
        assert_close!(1.0_f64, 1.0 + 1e-12, rtol = 1e-9, atol = 1e-12);
        assert_close!(1.0_f64, 1.0 + 1e-12, atol = 1e-12, rtol = 1e-9);
        assert_close!(1.0_f64, 1.0 + 1e-7, rtol = 1e-6);
        assert_close!(0.0_f64, 1e-9, atol = 1e-8);
        assert_close!(2.0_f64, 2.0);
    }

    #[test]
    #[should_panic(expected = "left ~= right")]
    fn test_assert_close_rejects_distant_values() {
        assert_close!(1.0_f64, 1.1);
    }

    // Exact float comparisons below are intentional: they check infinities,
    // which `assert_close!` cannot handle (`inf - inf` is NaN).
    #[allow(clippy::float_cmp)]
    #[test]
    fn test_ln_add_exp() {
        assert_close!(f64::ln_add_exp(1.0, 1.0), 1.0 + 2_f64.ln());
        assert_close!(1.0.ln_add_exp(2.0), (1_f64.exp() + 2_f64.exp()).ln());
        assert_close!(f64::ln_add_exp(0.0, &0.0), 2_f64.ln());
        assert_close!(2_f64.ln().ln_add_exp(&0.0), 3_f64.ln());
        assert!(f64::NAN.ln_add_exp(&1.0).is_nan());
        assert!(1.0.ln_add_exp(f64::NAN).is_nan());
        assert_eq!(f64::INFINITY.ln_add_exp(&0.0), f64::INFINITY);
        assert_eq!(1.0.ln_add_exp(f64::INFINITY), f64::INFINITY);
        assert_eq!(f64::INFINITY.ln_add_exp(f64::INFINITY), f64::INFINITY);
        assert_eq!(f64::NEG_INFINITY.ln_add_exp(f64::INFINITY), f64::INFINITY);
        assert_eq!(f64::INFINITY.ln_add_exp(f64::NEG_INFINITY), f64::INFINITY);
        assert_eq!(
            f64::NEG_INFINITY.ln_add_exp(f64::NEG_INFINITY),
            f64::NEG_INFINITY
        );
    }

    #[test]
    fn test_ln_add_exp_stable_for_large_values() {
        // ln(e^1000 + 2 e^1000) = 1000 + ln 3; a naive exp() would overflow.
        let actual = 1000.0_f64.ln_add_exp(1000.0 + 2_f64.ln());
        assert_close!(actual, 1000.0 + 3_f64.ln());
        assert!(actual.is_finite());
    }

    // Exact float comparisons below are intentional: they check infinities,
    // which `assert_close!` cannot handle (`inf - inf` is NaN).
    #[allow(clippy::float_cmp)]
    #[test]
    fn test_ln_sum_exp() {
        let raw = (1..10).map(|n| f64::from(n).ln());
        let binary = raw.clone().reduce(f64::ln_add_exp).unwrap();
        let expected: u32 = (1..10).sum();
        assert_close!(binary, f64::from(expected).ln());
        let actual = raw.ln_sum_exp();
        assert_close!(actual, binary);
        assert_eq!(<[f64; 0]>::into_iter([]).ln_sum_exp(), f64::NEG_INFINITY);
        assert_eq!(
            [f64::NEG_INFINITY; 2].into_iter().ln_sum_exp(),
            f64::NEG_INFINITY
        );
        assert_eq!([f64::INFINITY; 2].into_iter().ln_sum_exp(), f64::INFINITY);
        assert_eq!(
            [f64::NEG_INFINITY, f64::INFINITY].into_iter().ln_sum_exp(),
            f64::INFINITY
        );
        assert!([f64::NAN, 1.0].into_iter().ln_sum_exp().is_nan());
        assert!([1.0, f64::NAN].into_iter().ln_sum_exp().is_nan());
    }

    #[test]
    fn test_ln_sum_exp_single() {
        assert_close!([3.5_f64].into_iter().ln_sum_exp(), 3.5);
    }

    #[test]
    fn test_ln_sum_exp_stable_for_large_values() {
        let actual = [1000.0_f64, 1000.0, 1000.0].into_iter().ln_sum_exp();
        assert_close!(actual, 1000.0 + 3_f64.ln());
        assert!(actual.is_finite());
    }

    #[test]
    fn test_ln_sum_exp_order_independent() {
        let ascending = [-5.0_f64, 0.0, 2.0, 7.5].into_iter().ln_sum_exp();
        let descending = [7.5_f64, 2.0, 0.0, -5.0].into_iter().ln_sum_exp();
        assert_close!(ascending, descending);
    }

    #[test]
    fn test_ln_sum_exp_no_clone() {
        let mut values = vec![0.0_f64, 1.0, 2.0];
        let actual = values.drain(..).ln_sum_exp();
        assert_close!(actual, [0.0_f64, 1.0, 2.0].into_iter().ln_sum_exp());
    }
}