#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "x86_64")]
mod avx512;
#[cfg(target_arch = "aarch64")]
mod aarch64;
use super::{SimdLevel, detect_simd};
// Rows shorter than this do not amortize SIMD setup/horizontal-reduction
// overhead; the dispatchers below fall back to the scalar kernels for them.
const SIMD_THRESHOLD: usize = 32;
/// Log-sum-exp reduction over the innermost axis of a row-major
/// `[outer_size, reduce_size]` f32 buffer, dispatching to the best SIMD
/// kernel detected at runtime.
///
/// # Safety
/// `a` must be valid for reads of `outer_size * reduce_size` elements and
/// `out` must be valid for writes of `outer_size` elements.
#[inline]
pub unsafe fn logsumexp_f32(a: *const f32, out: *mut f32, reduce_size: usize, outer_size: usize) {
    let level = detect_simd();
    // No usable SIMD extension, or rows too short to pay for SIMD setup.
    if level == SimdLevel::Scalar || reduce_size < SIMD_THRESHOLD {
        logsumexp_scalar_f32(a, out, reduce_size, outer_size);
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        if level == SimdLevel::Avx512 {
            avx512::logsumexp_f32(a, out, reduce_size, outer_size);
        } else if level == SimdLevel::Avx2Fma {
            avx2::logsumexp_f32(a, out, reduce_size, outer_size);
        } else {
            logsumexp_scalar_f32(a, out, reduce_size, outer_size);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if matches!(level, SimdLevel::Neon | SimdLevel::NeonFp16) {
            aarch64::neon::logsumexp_f32(a, out, reduce_size, outer_size);
        } else {
            logsumexp_scalar_f32(a, out, reduce_size, outer_size);
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    logsumexp_scalar_f32(a, out, reduce_size, outer_size);
}
/// Log-sum-exp reduction over the innermost axis of a row-major
/// `[outer_size, reduce_size]` f64 buffer, dispatching to the best SIMD
/// kernel detected at runtime.
///
/// # Safety
/// `a` must be valid for reads of `outer_size * reduce_size` elements and
/// `out` must be valid for writes of `outer_size` elements.
#[inline]
pub unsafe fn logsumexp_f64(a: *const f64, out: *mut f64, reduce_size: usize, outer_size: usize) {
    let level = detect_simd();
    // No usable SIMD extension, or rows too short to pay for SIMD setup.
    if level == SimdLevel::Scalar || reduce_size < SIMD_THRESHOLD {
        logsumexp_scalar_f64(a, out, reduce_size, outer_size);
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        if level == SimdLevel::Avx512 {
            avx512::logsumexp_f64(a, out, reduce_size, outer_size);
        } else if level == SimdLevel::Avx2Fma {
            avx2::logsumexp_f64(a, out, reduce_size, outer_size);
        } else {
            logsumexp_scalar_f64(a, out, reduce_size, outer_size);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if matches!(level, SimdLevel::Neon | SimdLevel::NeonFp16) {
            aarch64::neon::logsumexp_f64(a, out, reduce_size, outer_size);
        } else {
            logsumexp_scalar_f64(a, out, reduce_size, outer_size);
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    logsumexp_scalar_f64(a, out, reduce_size, outer_size);
}
/// Scalar (non-SIMD) reference log-sum-exp over the innermost axis.
///
/// Treats `a` as a row-major `[outer_size, reduce_size]` buffer and writes
/// `out[o] = log(sum_i exp(a[o * reduce_size + i]))` for each row, using the
/// max-shift trick for numerical stability.
///
/// An empty reduction axis yields `-inf` (log of an empty sum) instead of
/// reading out of bounds; a row whose maximum is non-finite yields that
/// maximum directly (`-inf`, `+inf`, or NaN) instead of the NaN the shifted
/// form would produce from `inf - inf`.
///
/// # Safety
/// `a` must be valid for reads of `outer_size * reduce_size` elements and
/// `out` must be valid for writes of `outer_size` elements. When
/// `reduce_size == 0`, `a` is never dereferenced.
#[inline]
pub unsafe fn logsumexp_scalar_f32(
    a: *const f32,
    out: *mut f32,
    reduce_size: usize,
    outer_size: usize,
) {
    for o in 0..outer_size {
        // Empty reduction: log(0) = -inf. Also avoids the out-of-bounds read
        // of element 0 that the unguarded max loop below would perform.
        if reduce_size == 0 {
            *out.add(o) = f32::NEG_INFINITY;
            continue;
        }
        let base = o * reduce_size;
        // Pass 1: row maximum — the shift that keeps exp() from overflowing.
        let mut max_val = *a.add(base);
        for i in 1..reduce_size {
            let val = *a.add(base + i);
            if val > max_val {
                max_val = val;
            }
        }
        // A non-finite max makes (val - max_val) evaluate to NaN (inf - inf)
        // even when the mathematically correct result is the max itself.
        if !max_val.is_finite() {
            *out.add(o) = max_val;
            continue;
        }
        // Pass 2: sum of shifted exponentials; every term is <= 1.
        let mut sum = 0.0f32;
        for i in 0..reduce_size {
            let val = *a.add(base + i);
            sum += (val - max_val).exp();
        }
        *out.add(o) = max_val + sum.ln();
    }
}
/// Scalar (non-SIMD) reference log-sum-exp over the innermost axis (f64).
///
/// Treats `a` as a row-major `[outer_size, reduce_size]` buffer and writes
/// `out[o] = log(sum_i exp(a[o * reduce_size + i]))` for each row, using the
/// max-shift trick for numerical stability.
///
/// An empty reduction axis yields `-inf` (log of an empty sum) instead of
/// reading out of bounds; a row whose maximum is non-finite yields that
/// maximum directly (`-inf`, `+inf`, or NaN) instead of the NaN the shifted
/// form would produce from `inf - inf`.
///
/// # Safety
/// `a` must be valid for reads of `outer_size * reduce_size` elements and
/// `out` must be valid for writes of `outer_size` elements. When
/// `reduce_size == 0`, `a` is never dereferenced.
#[inline]
pub unsafe fn logsumexp_scalar_f64(
    a: *const f64,
    out: *mut f64,
    reduce_size: usize,
    outer_size: usize,
) {
    for o in 0..outer_size {
        // Empty reduction: log(0) = -inf. Also avoids the out-of-bounds read
        // of element 0 that the unguarded max loop below would perform.
        if reduce_size == 0 {
            *out.add(o) = f64::NEG_INFINITY;
            continue;
        }
        let base = o * reduce_size;
        // Pass 1: row maximum — the shift that keeps exp() from overflowing.
        let mut max_val = *a.add(base);
        for i in 1..reduce_size {
            let val = *a.add(base + i);
            if val > max_val {
                max_val = val;
            }
        }
        // A non-finite max makes (val - max_val) evaluate to NaN (inf - inf)
        // even when the mathematically correct result is the max itself.
        if !max_val.is_finite() {
            *out.add(o) = max_val;
            continue;
        }
        // Pass 2: sum of shifted exponentials; every term is <= 1.
        let mut sum = 0.0f64;
        for i in 0..reduce_size {
            let val = *a.add(base + i);
            sum += (val - max_val).exp();
        }
        *out.add(o) = max_val + sum.ln();
    }
}
/// Half-precision log-sum-exp: widen the input to f32, run the f32 kernel,
/// then narrow the per-row results back to f16.
///
/// # Safety
/// `a` must be valid for reads of `outer_size * reduce_size` elements and
/// `out` must be valid for writes of `outer_size` elements.
#[cfg(feature = "f16")]
pub unsafe fn logsumexp_f16(
    a: *const half::f16,
    out: *mut half::f16,
    reduce_size: usize,
    outer_size: usize,
) {
    use super::half_convert_utils::*;
    let total = outer_size * reduce_size;
    // Stage the whole input as f32 so the SIMD f32 kernels can be reused.
    let mut widened = vec![0.0f32; total];
    let mut reduced = vec![0.0f32; outer_size];
    convert_f16_to_f32(a as *const u16, widened.as_mut_ptr(), total);
    logsumexp_f32(widened.as_ptr(), reduced.as_mut_ptr(), reduce_size, outer_size);
    convert_f32_to_f16(reduced.as_ptr(), out as *mut u16, outer_size);
}
/// bfloat16 log-sum-exp: widen the input to f32, run the f32 kernel, then
/// narrow the per-row results back to bf16.
///
/// # Safety
/// `a` must be valid for reads of `outer_size * reduce_size` elements and
/// `out` must be valid for writes of `outer_size` elements.
#[cfg(feature = "f16")]
pub unsafe fn logsumexp_bf16(
    a: *const half::bf16,
    out: *mut half::bf16,
    reduce_size: usize,
    outer_size: usize,
) {
    use super::half_convert_utils::*;
    let total = outer_size * reduce_size;
    // Stage the whole input as f32 so the SIMD f32 kernels can be reused.
    let mut widened = vec![0.0f32; total];
    let mut reduced = vec![0.0f32; outer_size];
    convert_bf16_to_f32(a as *const u16, widened.as_mut_ptr(), total);
    logsumexp_f32(widened.as_ptr(), reduced.as_mut_ptr(), reduce_size, outer_size);
    convert_f32_to_bf16(reduced.as_ptr(), out as *mut u16, outer_size);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The dispatched (possibly SIMD) result must agree with the scalar
    /// reference to within ~1e-4 relative error.
    #[test]
    fn test_logsumexp_f32() {
        const REDUCE: usize = 128;
        const OUTER: usize = 4;
        let data: Vec<f32> = (0..(OUTER * REDUCE))
            .map(|i| (i as f32) / 100.0 - 2.5)
            .collect();
        let mut got = vec![0.0f32; OUTER];
        let mut want = vec![0.0f32; OUTER];
        unsafe {
            logsumexp_f32(data.as_ptr(), got.as_mut_ptr(), REDUCE, OUTER);
            logsumexp_scalar_f32(data.as_ptr(), want.as_mut_ptr(), REDUCE, OUTER);
        }
        for (i, (&g, &w)) in got.iter().zip(want.iter()).enumerate() {
            let rel_err = if w.abs() > 1e-10 {
                (g - w).abs() / w.abs()
            } else {
                (g - w).abs()
            };
            assert!(
                rel_err < 1e-4,
                "mismatch at {}: {} vs {} (rel_err: {})",
                i,
                g,
                w,
                rel_err
            );
        }
    }

    /// Large inputs (around 1000) must not overflow exp(); the max-shift
    /// trick should keep the result finite and close to the shifted formula.
    #[test]
    fn test_logsumexp_numerical_stability() {
        const REDUCE: usize = 64;
        let data: Vec<f32> = (0..REDUCE).map(|i| 1000.0 + i as f32).collect();
        let mut got = vec![0.0f32; 1];
        unsafe {
            logsumexp_f32(data.as_ptr(), got.as_mut_ptr(), REDUCE, 1);
        }
        assert!(got[0].is_finite(), "non-finite value: {}", got[0]);
        let max_val = 1063.0f32;
        let shifted_sum: f32 = (0..REDUCE)
            .map(|i| ((1000.0 + i as f32) - max_val).exp())
            .sum();
        let expected = max_val + shifted_sum.ln();
        assert!(
            (got[0] - expected).abs() < 0.5,
            "result {} vs expected {}",
            got[0],
            expected
        );
    }

    /// logsumexp of a single element is the element itself.
    #[test]
    fn test_logsumexp_single_element() {
        let data = [5.0f32];
        let mut got = [0.0f32];
        unsafe {
            logsumexp_scalar_f32(data.as_ptr(), got.as_mut_ptr(), 1, 1);
        }
        assert!((got[0] - 5.0).abs() < 1e-6);
    }
}