/// Computes `log2(2^a + 2^b)` without overflowing for large inputs.
///
/// When the smaller operand is more than 50 below the larger one, its
/// contribution `log2(1 + 2^diff)` (below about 1.3e-15) is dropped and the
/// larger operand is returned unchanged.
///
/// Special values: `-inf` contributes nothing (`2^-inf = 0`), `+inf` in either
/// argument yields `+inf`, and `NaN` in either argument yields `NaN`.
#[inline(always)]
pub fn fast_log2sumexp2(a: f64, b: f64) -> f64 {
    let (min, max) = if a < b { (a, b) } else { (b, a) };
    // min == -inf: that term is zero, so the answer is `max`.
    // min == +inf: both operands are +inf (or `max` is NaN); returning here
    // avoids computing `inf - inf = NaN` below.
    if min.is_infinite() {
        return max;
    }
    let diff = min - max;
    if diff < -50.0 {
        max
    } else {
        // log2(1 + 2^diff) via ln_1p, which stays accurate when 2^diff is small.
        max + diff.exp2().ln_1p() * std::f64::consts::LOG2_E
    }
}
/// Computes `log2(2^a + 2^b + 2^c)` without overflowing for large inputs.
///
/// Terms more than 50 below the maximum are flushed to zero by `fast_exp2`.
/// This introduces an absolute error below about 3e-15.
///
/// Special values: `-inf` terms contribute nothing, the result is `-inf` when
/// all three are `-inf`, and any `+inf` yields `+inf`. Any `NaN` yields `NaN`;
/// on its own, `f64::max` would silently ignore a NaN.
#[inline(always)]
pub fn fast_log2sumexp2_3(a: f64, b: f64, c: f64) -> f64 {
    if a.is_nan() || b.is_nan() || c.is_nan() {
        return f64::NAN;
    }
    let max = a.max(b).max(c);
    // -inf: every term is 2^-inf = 0. +inf: the sum diverges. In both cases the
    // answer is `max`, and returning early avoids `inf - inf = NaN` below.
    if max.is_infinite() {
        return max;
    }
    // The term for `max` itself is 2^0 = 1, so `sum >= 1` and its log2 is finite.
    let sum = fast_exp2(a - max) + fast_exp2(b - max) + fast_exp2(c - max);
    max + sum.log2()
}
/// Returns `2^x`, flushing the result to exactly `0.0` when `x < -50`.
///
/// The cutoff lets log-sum-exp callers skip terms that are negligible next to
/// a leading term of `1.0`. `2^-50` is roughly 8.9e-16, a few ulp of `1.0`.
/// `NaN` input yields `NaN`.
#[inline(always)]
fn fast_exp2(x: f64) -> f64 {
    if x < -50.0 {
        0.0
    } else {
        // `exp2` is the dedicated base-2 exponential; prefer it to `2.0.powf(x)`.
        x.exp2()
    }
}
/// Computes `log2(Σ 2^x)` over `values`, shifting by the maximum first so that
/// large inputs do not overflow.
///
/// Special values:
/// - `NaN` if any value is `NaN`.
/// - Otherwise `+inf` if any value is `+inf`.
/// - `-inf` for an empty slice, or when every value is `-inf` (the sum is 0).
pub fn log2sumexp2(values: &[f64]) -> f64 {
    // Unlike `f64::max`, this fold keeps a NaN once it has been seen, so NaN
    // inputs propagate instead of being silently skipped.
    let max = values
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, |m, x| if x.is_nan() || x > m { x } else { m });
    // Covers the empty slice / all `-inf` (sum is 0), any `+inf` (the sum
    // diverges, and `inf - inf` below would give NaN), and NaN.
    if !max.is_finite() {
        return max;
    }
    // The maximum contributes 2^0 = 1, so `sum >= 1` and its log2 is finite.
    let sum: f64 = values.iter().map(|&x| (x - max).exp2()).sum();
    max + sum.log2()
}
/// Sums the `log2` sizes of `labels` as looked up in `log2_sizes`.
///
/// If `log2_sizes` holds the base-2 log of each label's size, this is
/// `log2` of the product of those sizes. A label with no entry in
/// `log2_sizes` contributes `0.0`, i.e. it is treated as having size 1.
#[inline]
pub fn log2_prod<L, I>(labels: I, log2_sizes: &std::collections::HashMap<L, f64>) -> f64
where
    L: std::hash::Hash + Eq,
    I: IntoIterator<Item = L>,
{
    let log2_size_of = |label: L| match log2_sizes.get(&label) {
        Some(&log2_size) => log2_size,
        None => 0.0,
    };
    labels.into_iter().map(log2_size_of).sum::<f64>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Default absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` is strictly within `tol` of `expected`, reporting
    /// both values on failure (a bare `assert!` on `abs() < tol` does not).
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_fast_log2sumexp2_equal() {
        assert_close(fast_log2sumexp2(10.0, 10.0), 11.0, TOL);
    }

    #[test]
    fn test_fast_log2sumexp2_different() {
        assert_close(fast_log2sumexp2(3.0, 5.0), 40_f64.log2(), TOL);
    }

    #[test]
    fn test_fast_log2sumexp2_neg_inf() {
        assert_close(fast_log2sumexp2(f64::NEG_INFINITY, 5.0), 5.0, TOL);
    }

    #[test]
    fn test_fast_log2sumexp2_large_difference() {
        assert_close(fast_log2sumexp2(-100.0, 10.0), 10.0, TOL);
    }

    #[test]
    fn test_fast_log2sumexp2_reversed_order() {
        let result1 = fast_log2sumexp2(3.0, 5.0);
        let result2 = fast_log2sumexp2(5.0, 3.0);
        assert_close(result1, result2, TOL);
    }

    #[test]
    fn test_fast_log2sumexp2_drops_term_below_cutoff() {
        // Without truncation the result would be log2(1 + 2^-51) ≈ 6.4e-16,
        // which a tolerance-based check could not distinguish from 0.
        assert_eq!(fast_log2sumexp2(0.0, -51.0), 0.0);
        assert_eq!(fast_log2sumexp2(-51.0, 0.0), 0.0);
    }

    #[test]
    fn test_fast_log2sumexp2_3() {
        let result = fast_log2sumexp2_3(10.0, 10.0, 10.0);
        assert_close(result, (3.0 * 2_f64.powi(10)).log2(), TOL);
    }

    #[test]
    fn test_fast_log2sumexp2_3_all_neg_inf() {
        let result = fast_log2sumexp2_3(f64::NEG_INFINITY, f64::NEG_INFINITY, f64::NEG_INFINITY);
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_fast_log2sumexp2_3_one_neg_inf() {
        let result = fast_log2sumexp2_3(f64::NEG_INFINITY, 10.0, 10.0);
        assert_close(result, (2.0 * 2_f64.powi(10)).log2(), TOL);
    }

    #[test]
    fn test_fast_log2sumexp2_3_large_difference() {
        assert_close(fast_log2sumexp2_3(-100.0, -100.0, 10.0), 10.0, TOL);
    }

    #[test]
    fn test_fast_log2sumexp2_3_drops_terms_below_cutoff() {
        // Terms more than 50 below the max must vanish exactly, in every
        // argument position; without truncation the result would be ~1e-15.
        assert_eq!(fast_log2sumexp2_3(0.0, -51.0, -51.0), 0.0);
        assert_eq!(fast_log2sumexp2_3(-51.0, 0.0, -51.0), 0.0);
        assert_eq!(fast_log2sumexp2_3(-51.0, -51.0, 0.0), 0.0);
    }

    #[test]
    fn test_fast_log2sumexp2_3_all_same() {
        let result = fast_log2sumexp2_3(-100.0, -100.0, -100.0);
        assert_close(result, -100.0 + 3_f64.log2(), TOL);
    }

    #[test]
    fn test_log2sumexp2_vec() {
        assert_close(log2sumexp2(&[10.0, 10.0, 10.0, 10.0]), 12.0, TOL);
    }

    #[test]
    fn test_log2sumexp2_empty() {
        assert_eq!(log2sumexp2(&[]), f64::NEG_INFINITY);
    }

    #[test]
    fn test_log2sumexp2_all_neg_inf() {
        let result = log2sumexp2(&[f64::NEG_INFINITY, f64::NEG_INFINITY]);
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_log2sumexp2_single() {
        assert_close(log2sumexp2(&[10.0]), 10.0, TOL);
    }

    #[test]
    fn test_log2sumexp2_large_values() {
        assert_close(log2sumexp2(&[500.0, 500.0]), 501.0, TOL);
    }

    #[test]
    fn test_log2sumexp2_mixed_magnitudes() {
        // Exact value: 10 + log2(1 + 2^-20); 1 + 2^-20 is representable in f64.
        let expected = 10.0 + (1.0 + 2_f64.powi(-20)).log2();
        assert_close(log2sumexp2(&[-10.0, 10.0]), expected, TOL);
    }

    #[test]
    fn test_log2sumexp2_consistency_with_fast() {
        let result_vec = log2sumexp2(&[5.0, 10.0]);
        let result_fast = fast_log2sumexp2(5.0, 10.0);
        assert_close(result_vec, result_fast, TOL);
    }

    #[test]
    fn test_log2sumexp2_many_small_values() {
        assert_close(log2sumexp2(&[10.0; 8]), 13.0, TOL);
    }

    #[test]
    fn test_log2_prod() {
        let mut log2_sizes: HashMap<char, f64> = HashMap::new();
        log2_sizes.insert('i', 2.0);
        log2_sizes.insert('j', 3.0);
        log2_sizes.insert('k', 4.0);
        let result = log2_prod(['i', 'j', 'k'].iter().copied(), &log2_sizes);
        assert_close(result, 9.0, TOL);
    }

    #[test]
    fn test_log2_prod_missing_label() {
        let mut log2_sizes: HashMap<char, f64> = HashMap::new();
        log2_sizes.insert('i', 2.0);
        let result = log2_prod(['i', 'j'].iter().copied(), &log2_sizes);
        assert_close(result, 2.0, TOL);
    }

    #[test]
    fn test_log2_prod_empty() {
        let log2_sizes: HashMap<char, f64> = HashMap::new();
        let result = log2_prod(std::iter::empty::<char>(), &log2_sizes);
        assert_close(result, 0.0, TOL);
    }
}