use core::f64::consts::{LN_2, LOG2_E};
// Private home of the supertrait that closes `LogAddExpArg`: because this
// module is not public, code outside the crate cannot name `Sealed` and so
// cannot implement `LogAddExpArg` for additional types.
mod sealed {
    /// Marker implemented only for the primitive floats listed below.
    pub trait Sealed {}

    impl Sealed for f64 {}
    impl Sealed for f32 {}
}
/// Floating-point element types accepted by [`logaddexp`] and [`logaddexp2`].
///
/// The supertrait lives in a private module, so the only implementors are
/// `f32` and `f64`.
pub trait LogAddExpArg: sealed::Sealed {
    /// Base-2 variant: `log2(2^self + 2^other)`.
    fn npy_logaddexp2(self, other: Self) -> Self;

    /// Natural-log variant: `ln(exp(self) + exp(other))`.
    fn npy_logaddexp(self, other: Self) -> Self;

    /// The zero value of the implementing type (`0.0` for both floats).
    const ZERO: Self;
}
impl LogAddExpArg for f32 {
const ZERO: Self = 0.0;
#[inline]
#[allow(clippy::cast_possible_truncation)]
fn npy_logaddexp(self, other: Self) -> Self {
f64::from(self).npy_logaddexp(f64::from(other)) as f32
}
#[inline]
#[allow(clippy::cast_possible_truncation)]
fn npy_logaddexp2(self, other: Self) -> Self {
f64::from(self).npy_logaddexp2(f64::from(other)) as f32
}
}
impl LogAddExpArg for f64 {
    const ZERO: Self = 0.0;

    /// `ln(exp(self) + exp(other))`, evaluated so that the exponential is only
    /// ever taken of a non-positive difference.
    #[inline]
    fn npy_logaddexp(self, other: Self) -> Self {
        combine_log_sum(self, other, LN_2, |gap| gap.exp().ln_1p())
    }

    /// `log2(2^self + 2^other)`; `log2(1 + t)` is formed as `ln_1p(t) * log2(e)`.
    #[inline]
    fn npy_logaddexp2(self, other: Self) -> Self {
        combine_log_sum(self, other, 1.0, |gap| gap.exp2().ln_1p() * LOG2_E)
    }
}

/// Shared kernel for the `f64` log-add-exp variants.
///
/// For the larger input `hi` and `gap = lo - hi` (never positive), the result
/// is `hi + log1p_of_pow(gap)`, where `log1p_of_pow(gap)` is `log_b(1 + b^gap)`
/// for the caller's base `b`. When the inputs compare equal the result is
/// `x + equal_offset`, where `equal_offset` is `log_b(2)`.
#[inline]
#[allow(clippy::float_cmp)] // exact equality is intended: it catches equal same-sign infinities, whose difference would be NaN
fn combine_log_sum(x: f64, y: f64, equal_offset: f64, log1p_of_pow: impl Fn(f64) -> f64) -> f64 {
    if x == y {
        return x + equal_offset;
    }
    let diff = x - y;
    match diff.partial_cmp(&0.0) {
        Some(core::cmp::Ordering::Greater) => x + log1p_of_pow(-diff),
        Some(_) => y + log1p_of_pow(diff),
        // `diff` is NaN; after the equality check above that means an input
        // was NaN, and the NaN is propagated unchanged.
        None => diff,
    }
}
/// Computes `ln(exp(x) + exp(y))` for `f32` or `f64` inputs without forming
/// `exp(x)` or `exp(y)` directly, so large arguments do not overflow.
///
/// Delegates to [`LogAddExpArg::npy_logaddexp`]; `f32` inputs are evaluated
/// in `f64` and rounded back. Equal infinities of the same sign come back
/// unchanged, and a NaN in either argument yields NaN.
#[inline]
#[must_use]
pub fn logaddexp<T: LogAddExpArg>(x: T, y: T) -> T {
    x.npy_logaddexp(y)
}
/// Computes `log2(2^x + 2^y)` for `f32` or `f64` inputs without forming
/// `2^x` or `2^y` directly, so large arguments do not overflow.
///
/// Delegates to [`LogAddExpArg::npy_logaddexp2`]; `f32` inputs are evaluated
/// in `f64` and rounded back. Equal infinities of the same sign come back
/// unchanged, and a NaN in either argument yields NaN.
#[inline]
#[must_use]
pub fn logaddexp2<T: LogAddExpArg>(x: T, y: T) -> T {
    x.npy_logaddexp2(y)
}
#[cfg(test)]
mod tests {
    use crate::np_assert_allclose;
    use super::{logaddexp, logaddexp2};

    // Since logaddexp(ln a, ln b) = ln(a + b) (and likewise in base 2), every
    // pair (X[i], Y[i]) below should combine to the log of Z[i] = 6.
    const X: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
    const Y: [f64; 5] = [5.0, 4.0, 3.0, 2.0, 1.0];
    const Z: [f64; 5] = [6.0; 5];

    #[test]
    fn test_logaddexp_f32() {
        let xf = X.map(f64::ln);
        let yf = Y.map(f64::ln);
        let zf = Z.map(f64::ln);
        #[allow(clippy::cast_possible_truncation)] // inputs are deliberately rounded to f32
        let zr: [f32; 5] = std::array::from_fn(|i| logaddexp(xf[i] as f32, yf[i] as f32));
        np_assert_allclose!(zr.map(f64::from), zf, atol = 1.5e-7);
    }

    #[test]
    fn test_logaddexp_f64() {
        let xf = X.map(f64::ln);
        let yf = Y.map(f64::ln);
        let zf = Z.map(f64::ln);
        let zr: [f64; 5] = std::array::from_fn(|i| logaddexp(xf[i], yf[i]));
        np_assert_allclose!(zr, zf, atol = 1e-15);
    }

    #[test]
    fn test_logaddexp2_f32() {
        let xf = X.map(f64::log2);
        let yf = Y.map(f64::log2);
        let zf = Z.map(f64::log2);
        #[allow(clippy::cast_possible_truncation)] // inputs are deliberately rounded to f32
        let zr: [f32; 5] = std::array::from_fn(|i| logaddexp2(xf[i] as f32, yf[i] as f32));
        np_assert_allclose!(zr.map(f64::from), zf, atol = 1.5e-7);
    }

    #[test]
    fn test_logaddexp2_f64() {
        let xf = X.map(f64::log2);
        let yf = Y.map(f64::log2);
        let zf = Z.map(f64::log2);
        let zr: [f64; 5] = std::array::from_fn(|i| logaddexp2(xf[i], yf[i]));
        np_assert_allclose!(zr, zf, atol = 1e-15);
    }
}