#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use super::super::super::math::aarch64::neon::{
exp_f32, exp_f64, hmax_f32, hmax_f64, hsum_f32, hsum_f64,
};
// Lane counts of one 128-bit (16-byte) NEON Q register, i.e. how many elements a
// single `vld1q_f32` / `vld1q_f64` load brings in. Used as the pointer stride of the
// vectorized loops below.
const F32_LANES: usize = 16 / std::mem::size_of::<f32>();
const F64_LANES: usize = 16 / std::mem::size_of::<f64>();
/// Computes `ln(sum(exp(x)))` over each contiguous row of an `f32` buffer using NEON.
///
/// The input is read as `outer_size` consecutive rows of `reduce_size` elements each.
/// Row `o` starts at `a + o * reduce_size`, and its result is written to `out[o]`.
/// Each row is reduced with the max-shift form `max + ln(sum(exp(x - max)))`, so large
/// finite inputs do not overflow `exp`.
///
/// Special cases:
/// - `reduce_size == 0` yields `-inf` (the log of an empty sum).
/// - If a row's maximum is `+inf` or `-inf`, the result is that maximum, or NaN if the
///   row contains a NaN. The shift is skipped here because `inf - inf` would be NaN.
///
/// # Safety
///
/// - `a` must be valid for reads of `outer_size * reduce_size` `f32` values.
/// - `out` must be valid for writes of `outer_size` `f32` values.
/// - The executing CPU must support NEON (`#[target_feature(enable = "neon")]`).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn logsumexp_f32(a: *const f32, out: *mut f32, reduce_size: usize, outer_size: usize) {
    let chunks = reduce_size / F32_LANES;
    // First element index handled by the scalar tail loops.
    let tail_start = chunks * F32_LANES;

    for o in 0..outer_size {
        let base = a.add(o * reduce_size);

        // Pass 1: row maximum (vector body, then scalar tail).
        let mut max_acc = vdupq_n_f32(f32::NEG_INFINITY);
        for i in 0..chunks {
            max_acc = vmaxq_f32(max_acc, vld1q_f32(base.add(i * F32_LANES)));
        }
        let mut max_val = hmax_f32(max_acc);
        for i in tail_start..reduce_size {
            let val = *base.add(i);
            if val > max_val {
                max_val = val;
            }
        }

        // An infinite maximum cannot be used as the shift (`inf - inf` is NaN).
        // - max == -inf: every non-NaN element is -inf, so the sum of exps is 0 and
        //   ln(0) == -inf.
        // - max == +inf: the sum of exps is +inf.
        // In both cases the answer is the maximum itself, unless the row holds a NaN.
        // The whole row is scanned so this does not depend on hmax_f32's NaN handling.
        if max_val.is_infinite() {
            let mut result = max_val;
            for i in 0..reduce_size {
                let val = *base.add(i);
                if val.is_nan() {
                    result = val;
                    break;
                }
            }
            *out.add(o) = result;
            continue;
        }

        // Pass 2: sum of exp(x - max). The vector body uses the project's exp_f32.
        // The tail uses std's f32::exp (NOTE(review): the two may differ slightly
        // in accuracy).
        let v_max = vdupq_n_f32(max_val);
        let mut sum_acc = vdupq_n_f32(0.0);
        for i in 0..chunks {
            let shifted = vsubq_f32(vld1q_f32(base.add(i * F32_LANES)), v_max);
            sum_acc = vaddq_f32(sum_acc, exp_f32(shifted));
        }
        let mut sum = hsum_f32(sum_acc);
        for i in tail_start..reduce_size {
            sum += (*base.add(i) - max_val).exp();
        }

        *out.add(o) = max_val + sum.ln();
    }
}
/// Computes `ln(sum(exp(x)))` over each contiguous row of an `f64` buffer using NEON.
///
/// The input is read as `outer_size` consecutive rows of `reduce_size` elements each.
/// Row `o` starts at `a + o * reduce_size`, and its result is written to `out[o]`.
/// Each row is reduced with the max-shift form `max + ln(sum(exp(x - max)))`, so large
/// finite inputs do not overflow `exp`.
///
/// Special cases:
/// - `reduce_size == 0` yields `-inf` (the log of an empty sum).
/// - If a row's maximum is `+inf` or `-inf`, the result is that maximum, or NaN if the
///   row contains a NaN. The shift is skipped here because `inf - inf` would be NaN.
///
/// # Safety
///
/// - `a` must be valid for reads of `outer_size * reduce_size` `f64` values.
/// - `out` must be valid for writes of `outer_size` `f64` values.
/// - The executing CPU must support NEON (`#[target_feature(enable = "neon")]`).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn logsumexp_f64(a: *const f64, out: *mut f64, reduce_size: usize, outer_size: usize) {
    let chunks = reduce_size / F64_LANES;
    // First element index handled by the scalar tail loops.
    let tail_start = chunks * F64_LANES;

    for o in 0..outer_size {
        let base = a.add(o * reduce_size);

        // Pass 1: row maximum (vector body, then scalar tail).
        let mut max_acc = vdupq_n_f64(f64::NEG_INFINITY);
        for i in 0..chunks {
            max_acc = vmaxq_f64(max_acc, vld1q_f64(base.add(i * F64_LANES)));
        }
        let mut max_val = hmax_f64(max_acc);
        for i in tail_start..reduce_size {
            let val = *base.add(i);
            if val > max_val {
                max_val = val;
            }
        }

        // An infinite maximum cannot be used as the shift (`inf - inf` is NaN).
        // - max == -inf: every non-NaN element is -inf, so the sum of exps is 0 and
        //   ln(0) == -inf.
        // - max == +inf: the sum of exps is +inf.
        // In both cases the answer is the maximum itself, unless the row holds a NaN.
        // The whole row is scanned so this does not depend on hmax_f64's NaN handling.
        if max_val.is_infinite() {
            let mut result = max_val;
            for i in 0..reduce_size {
                let val = *base.add(i);
                if val.is_nan() {
                    result = val;
                    break;
                }
            }
            *out.add(o) = result;
            continue;
        }

        // Pass 2: sum of exp(x - max). The vector body uses the project's exp_f64.
        // The tail uses std's f64::exp (NOTE(review): the two may differ slightly
        // in accuracy).
        let v_max = vdupq_n_f64(max_val);
        let mut sum_acc = vdupq_n_f64(0.0);
        for i in 0..chunks {
            let shifted = vsubq_f64(vld1q_f64(base.add(i * F64_LANES)), v_max);
            sum_acc = vaddq_f64(sum_acc, exp_f64(shifted));
        }
        let mut sum = hsum_f64(sum_acc);
        for i in tail_start..reduce_size {
            sum += (*base.add(i) - max_val).exp();
        }

        *out.add(o) = max_val + sum.ln();
    }
}