#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "aarch64")]
use super::super::super::math::aarch64::neon::{hsum_f32, hsum_f64};

/// Lanes of `f32` in a 128-bit NEON register.
#[cfg(target_arch = "aarch64")]
const F32_LANES: usize = 4;
/// Lanes of `f64` in a 128-bit NEON register.
#[cfg(target_arch = "aarch64")]
const F64_LANES: usize = 2;
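
// `hsum_f32` / `hsum_f64` come from the crate's NEON math module. A minimal
// sketch of what they are assumed to do (reduce every lane of a 128-bit
// register to a single scalar) using the across-lane add intrinsics:
//
//     unsafe fn hsum_f32(v: float32x4_t) -> f32 { vaddvq_f32(v) }
//     unsafe fn hsum_f64(v: float64x2_t) -> f64 { vaddvq_f64(v) }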
/// NEON-accelerated softmax backward pass over `f32` data.
///
/// For each of the `outer_size` rows of length `dim_size`, computes the
/// softmax Jacobian-vector product
/// `d_input[i] = output[i] * (grad[i] - dot(grad, output))`.
///
/// # Safety
///
/// `grad` and `output` must be valid for reads of `outer_size * dim_size`
/// elements, and `d_input` must be valid for writes of the same length.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn softmax_bwd_f32(
    grad: *const f32,
    output: *const f32,
    d_input: *mut f32,
    outer_size: usize,
    dim_size: usize,
) {
    let chunks = dim_size / F32_LANES;
    let remainder = dim_size % F32_LANES;
    for o in 0..outer_size {
        let g_base = grad.add(o * dim_size);
        let o_base = output.add(o * dim_size);
        let d_base = d_input.add(o * dim_size);
        // Pass 1: accumulate dot(grad, output) four lanes at a time.
        let mut dot_acc = vdupq_n_f32(0.0);
        for i in 0..chunks {
            let offset = i * F32_LANES;
            let g = vld1q_f32(g_base.add(offset));
            let out = vld1q_f32(o_base.add(offset));
            // Fused multiply-add: dot_acc += g * out, lane-wise.
            dot_acc = vfmaq_f32(dot_acc, g, out);
        }
        // Reduce the lane accumulators, then fold in the scalar tail.
        let mut dot = hsum_f32(dot_acc);
        for i in 0..remainder {
            let offset = chunks * F32_LANES + i;
            dot += *g_base.add(offset) * *o_base.add(offset);
        }
        // Pass 2: d_input = output * (grad - dot), vectorized with a scalar tail.
        let v_dot = vdupq_n_f32(dot);
        for i in 0..chunks {
            let offset = i * F32_LANES;
            let g = vld1q_f32(g_base.add(offset));
            let out = vld1q_f32(o_base.add(offset));
            let shifted = vsubq_f32(g, v_dot);
            let result = vmulq_f32(out, shifted);
            vst1q_f32(d_base.add(offset), result);
        }
        for i in 0..remainder {
            let offset = chunks * F32_LANES + i;
            *d_base.add(offset) = *o_base.add(offset) * (*g_base.add(offset) - dot);
        }
    }
}
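
// A sketch of a safe, slice-based wrapper one might layer on top, assuming
// row-major `[outer_size, dim_size]` buffers; the wrapper name and checks are
// illustrative, not part of this module's API:
//
//     pub fn softmax_bwd_f32_checked(
//         grad: &[f32], output: &[f32], d_input: &mut [f32],
//         outer_size: usize, dim_size: usize,
//     ) {
//         let n = outer_size * dim_size;
//         assert!(grad.len() >= n && output.len() >= n && d_input.len() >= n);
//         // SAFETY: the asserts above establish the pointer/length contract.
//         unsafe {
//             softmax_bwd_f32(
//                 grad.as_ptr(), output.as_ptr(), d_input.as_mut_ptr(),
//                 outer_size, dim_size,
//             );
//         }
//     }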
/// NEON-accelerated softmax backward pass over `f64` data.
///
/// Identical to [`softmax_bwd_f32`], but with two `f64` lanes per 128-bit
/// register.
///
/// # Safety
///
/// `grad` and `output` must be valid for reads of `outer_size * dim_size`
/// elements, and `d_input` must be valid for writes of the same length.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn softmax_bwd_f64(
    grad: *const f64,
    output: *const f64,
    d_input: *mut f64,
    outer_size: usize,
    dim_size: usize,
) {
    let chunks = dim_size / F64_LANES;
    let remainder = dim_size % F64_LANES;
    for o in 0..outer_size {
        let g_base = grad.add(o * dim_size);
        let o_base = output.add(o * dim_size);
        let d_base = d_input.add(o * dim_size);
        // Pass 1: accumulate dot(grad, output) two lanes at a time.
        let mut dot_acc = vdupq_n_f64(0.0);
        for i in 0..chunks {
            let offset = i * F64_LANES;
            let g = vld1q_f64(g_base.add(offset));
            let out = vld1q_f64(o_base.add(offset));
            // Fused multiply-add: dot_acc += g * out, lane-wise.
            dot_acc = vfmaq_f64(dot_acc, g, out);
        }
        // Reduce the lane accumulators, then fold in the scalar tail.
        let mut dot = hsum_f64(dot_acc);
        for i in 0..remainder {
            let offset = chunks * F64_LANES + i;
            dot += *g_base.add(offset) * *o_base.add(offset);
        }
        // Pass 2: d_input = output * (grad - dot), vectorized with a scalar tail.
        let v_dot = vdupq_n_f64(dot);
        for i in 0..chunks {
            let offset = i * F64_LANES;
            let g = vld1q_f64(g_base.add(offset));
            let out = vld1q_f64(o_base.add(offset));
            let shifted = vsubq_f64(g, v_dot);
            let result = vmulq_f64(out, shifted);
            vst1q_f64(d_base.add(offset), result);
        }
        for i in 0..remainder {
            let offset = chunks * F64_LANES + i;
            *d_base.add(offset) = *o_base.add(offset) * (*g_base.add(offset) - dot);
        }
    }
}
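
// A minimal sanity-check sketch, assuming a standard `cargo test` setup on an
// aarch64 host: compare the NEON kernel against a direct scalar evaluation of
// `d_input = output * (grad - dot(grad, output))`. `dim = 7` is chosen
// deliberately so the non-multiple-of-4 tail path is exercised.
#[cfg(all(test, target_arch = "aarch64"))]
mod tests {
    use super::*;

    // Scalar reference implementation of the same backward pass.
    fn scalar_ref_f32(grad: &[f32], output: &[f32], d_input: &mut [f32], outer: usize, dim: usize) {
        for o in 0..outer {
            let g = &grad[o * dim..][..dim];
            let out = &output[o * dim..][..dim];
            let dot: f32 = g.iter().zip(out).map(|(a, b)| a * b).sum();
            for i in 0..dim {
                d_input[o * dim + i] = out[i] * (g[i] - dot);
            }
        }
    }

    #[test]
    fn f32_matches_scalar_reference() {
        let (outer, dim) = (3usize, 7usize);
        let grad: Vec<f32> = (0..outer * dim).map(|i| (i as f32 * 0.37).sin()).collect();
        let output: Vec<f32> = (0..outer * dim).map(|i| 1.0 / (1.0 + i as f32)).collect();
        let mut expected = vec![0.0f32; outer * dim];
        let mut got = vec![0.0f32; outer * dim];
        scalar_ref_f32(&grad, &output, &mut expected, outer, dim);
        // SAFETY: all three buffers hold exactly `outer * dim` elements.
        unsafe {
            softmax_bwd_f32(grad.as_ptr(), output.as_ptr(), got.as_mut_ptr(), outer, dim);
        }
        for (e, g) in expected.iter().zip(&got) {
            assert!((e - g).abs() <= 1e-5, "expected {e}, got {g}");
        }
    }
}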