#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "x86_64")]
mod avx512;
#[cfg(target_arch = "aarch64")]
mod aarch64;
use super::{SimdLevel, detect_simd};
/// Minimum row length (`dim_size`) at which `softmax_bwd_f32` / `softmax_bwd_f64`
/// consider a SIMD kernel; rows shorter than this always take the scalar kernel.
const SIMD_THRESHOLD: usize = 32;
/// Softmax backward pass for `f32`, dispatching on the SIMD level reported by
/// `detect_simd()`.
///
/// The buffers hold `outer_size` rows of `dim_size` contiguous elements; row `o` starts
/// at element `o * dim_size`. With `y = output` and `g = grad`, each row receives
/// `d_input[i] = y[i] * (g[i] - sum_j g[j] * y[j])`.
///
/// Kernel selection:
/// - `dim_size < SIMD_THRESHOLD` or level `Scalar`: `softmax_bwd_scalar_f32`.
/// - x86_64: `Avx512` uses `avx512::softmax_bwd_f32`, `Avx2Fma` uses
///   `avx2::softmax_bwd_f32`.
/// - aarch64: `Neon` / `NeonFp16` use `aarch64::neon::softmax_bwd_f32`.
/// - Anything else falls back to the scalar kernel.
///
/// The SIMD kernels may not reproduce the scalar kernel's rounding bit-for-bit.
///
/// # Safety
///
/// `grad` and `output` must be properly aligned and valid for reads of
/// `outer_size * dim_size` elements, and `d_input` must be valid for writes of the same
/// count. Whether `d_input` may overlap the inputs depends on the SIMD kernels, which
/// are not visible here — treat the buffers as non-overlapping unless confirmed.
#[inline]
pub unsafe fn softmax_bwd_f32(
    grad: *const f32,
    output: *const f32,
    d_input: *mut f32,
    outer_size: usize,
    dim_size: usize,
) {
    let level = detect_simd();

    // Rows below the threshold, and hosts without SIMD, always use the scalar kernel.
    let scalar_only = dim_size < SIMD_THRESHOLD || level == SimdLevel::Scalar;
    if scalar_only {
        return softmax_bwd_scalar_f32(grad, output, d_input, outer_size, dim_size);
    }

    #[cfg(target_arch = "x86_64")]
    {
        match level {
            SimdLevel::Avx512 => {
                return avx512::softmax_bwd_f32(grad, output, d_input, outer_size, dim_size);
            }
            SimdLevel::Avx2Fma => {
                return avx2::softmax_bwd_f32(grad, output, d_input, outer_size, dim_size);
            }
            _ => {}
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if matches!(level, SimdLevel::Neon | SimdLevel::NeonFp16) {
            return aarch64::neon::softmax_bwd_f32(grad, output, d_input, outer_size, dim_size);
        }
    }

    // No SIMD kernel matched this level / architecture.
    softmax_bwd_scalar_f32(grad, output, d_input, outer_size, dim_size);
}
/// Softmax backward pass for `f64`, dispatching on the SIMD level reported by
/// `detect_simd()`.
///
/// The buffers hold `outer_size` rows of `dim_size` contiguous elements; row `o` starts
/// at element `o * dim_size`. With `y = output` and `g = grad`, each row receives
/// `d_input[i] = y[i] * (g[i] - sum_j g[j] * y[j])`.
///
/// Kernel selection:
/// - `dim_size < SIMD_THRESHOLD` or level `Scalar`: `softmax_bwd_scalar_f64`.
/// - x86_64: `Avx512` uses `avx512::softmax_bwd_f64`, `Avx2Fma` uses
///   `avx2::softmax_bwd_f64`.
/// - aarch64: `Neon` / `NeonFp16` use `aarch64::neon::softmax_bwd_f64`.
/// - Anything else falls back to the scalar kernel.
///
/// The SIMD kernels may not reproduce the scalar kernel's rounding bit-for-bit.
///
/// # Safety
///
/// `grad` and `output` must be properly aligned and valid for reads of
/// `outer_size * dim_size` elements, and `d_input` must be valid for writes of the same
/// count. Whether `d_input` may overlap the inputs depends on the SIMD kernels, which
/// are not visible here — treat the buffers as non-overlapping unless confirmed.
#[inline]
pub unsafe fn softmax_bwd_f64(
    grad: *const f64,
    output: *const f64,
    d_input: *mut f64,
    outer_size: usize,
    dim_size: usize,
) {
    let level = detect_simd();

    // Rows below the threshold, and hosts without SIMD, always use the scalar kernel.
    let scalar_only = dim_size < SIMD_THRESHOLD || level == SimdLevel::Scalar;
    if scalar_only {
        return softmax_bwd_scalar_f64(grad, output, d_input, outer_size, dim_size);
    }

    #[cfg(target_arch = "x86_64")]
    {
        match level {
            SimdLevel::Avx512 => {
                return avx512::softmax_bwd_f64(grad, output, d_input, outer_size, dim_size);
            }
            SimdLevel::Avx2Fma => {
                return avx2::softmax_bwd_f64(grad, output, d_input, outer_size, dim_size);
            }
            _ => {}
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if matches!(level, SimdLevel::Neon | SimdLevel::NeonFp16) {
            return aarch64::neon::softmax_bwd_f64(grad, output, d_input, outer_size, dim_size);
        }
    }

    // No SIMD kernel matched this level / architecture.
    softmax_bwd_scalar_f64(grad, output, d_input, outer_size, dim_size);
}
/// Portable reference kernel for the `f32` softmax backward pass.
///
/// Treats the buffers as `outer_size` rows of `dim_size` contiguous elements.
/// For each row it:
/// 1. accumulates `dot = sum_j grad[j] * output[j]` in `f32`, left to right;
/// 2. writes `d_input[i] = output[i] * (grad[i] - dot)`.
///
/// # Safety
///
/// `grad` and `output` must be properly aligned and valid for reads of
/// `outer_size * dim_size` elements, and `d_input` must be valid for writes of the same
/// count.
///
/// `d_input` may be exactly the same pointer as `grad` or `output` (in-place update).
/// This works because a row's `dot` is finished before any of that row is written, and
/// each element's inputs are read immediately before that element is overwritten.
/// Partially overlapping buffers are not supported.
#[inline]
pub unsafe fn softmax_bwd_scalar_f32(
    grad: *const f32,
    output: *const f32,
    d_input: *mut f32,
    outer_size: usize,
    dim_size: usize,
) {
    let mut row_start = 0usize;
    for _ in 0..outer_size {
        let row_end = row_start + dim_size;

        // Pass 1: dot product of the upstream gradient with the softmax output.
        let mut dot = 0.0f32;
        for k in row_start..row_end {
            dot += *grad.add(k) * *output.add(k);
        }

        // Pass 2: project the gradient onto the softmax Jacobian, element by element.
        for k in row_start..row_end {
            let y = *output.add(k);
            let g = *grad.add(k);
            *d_input.add(k) = y * (g - dot);
        }

        row_start = row_end;
    }
}
/// Portable reference kernel for the `f64` softmax backward pass.
///
/// Treats the buffers as `outer_size` rows of `dim_size` contiguous elements.
/// For each row it:
/// 1. accumulates `dot = sum_j grad[j] * output[j]` in `f64`, left to right;
/// 2. writes `d_input[i] = output[i] * (grad[i] - dot)`.
///
/// # Safety
///
/// `grad` and `output` must be properly aligned and valid for reads of
/// `outer_size * dim_size` elements, and `d_input` must be valid for writes of the same
/// count.
///
/// `d_input` may be exactly the same pointer as `grad` or `output` (in-place update).
/// This works because a row's `dot` is finished before any of that row is written, and
/// each element's inputs are read immediately before that element is overwritten.
/// Partially overlapping buffers are not supported.
#[inline]
pub unsafe fn softmax_bwd_scalar_f64(
    grad: *const f64,
    output: *const f64,
    d_input: *mut f64,
    outer_size: usize,
    dim_size: usize,
) {
    let mut row_start = 0usize;
    for _ in 0..outer_size {
        let row_end = row_start + dim_size;

        // Pass 1: dot product of the upstream gradient with the softmax output.
        let mut dot = 0.0f64;
        for k in row_start..row_end {
            dot += *grad.add(k) * *output.add(k);
        }

        // Pass 2: project the gradient onto the softmax Jacobian, element by element.
        for k in row_start..row_end {
            let y = *output.add(k);
            let g = *grad.add(k);
            *d_input.add(k) = y * (g - dot);
        }

        row_start = row_end;
    }
}
/// Softmax backward pass for `half::f16` buffers, computed in `f32`.
///
/// The buffers hold `outer_size` rows of `dim_size` contiguous elements; row `i` starts
/// at element `i * dim_size`. For each row, this function:
/// 1. widens `grad` and `output` into `f32` scratch buffers;
/// 2. runs `softmax_bwd_f32` on that single row;
/// 3. narrows the result back into `d_input`.
///
/// The three scratch buffers are allocated once and reused for every row. Because
/// `softmax_bwd_f32` is called once per row, its SIMD detection and dispatch also run
/// once per row.
///
/// # Safety
///
/// `grad` and `output` must be valid for reads of `outer_size * dim_size` elements, and
/// `d_input` must be valid for writes of the same count. The pointers are handed to the
/// `half_convert_utils` helpers as `u16` bit patterns.
#[cfg(feature = "f16")]
pub unsafe fn softmax_bwd_f16(
    grad: *const half::f16,
    output: *const half::f16,
    d_input: *mut half::f16,
    outer_size: usize,
    dim_size: usize,
) {
    use super::half_convert_utils::*;

    // f32 scratch for one row, reused across all rows.
    let mut grad_f32 = vec![0.0f32; dim_size];
    let mut output_f32 = vec![0.0f32; dim_size];
    let mut d_input_f32 = vec![0.0f32; dim_size];

    for row_start in (0..outer_size).map(|row| row * dim_size) {
        let grad_bits = grad.add(row_start).cast::<u16>();
        let output_bits = output.add(row_start).cast::<u16>();
        let d_input_bits = d_input.add(row_start).cast::<u16>();

        convert_f16_to_f32(grad_bits, grad_f32.as_mut_ptr(), dim_size);
        convert_f16_to_f32(output_bits, output_f32.as_mut_ptr(), dim_size);
        softmax_bwd_f32(
            grad_f32.as_ptr(),
            output_f32.as_ptr(),
            d_input_f32.as_mut_ptr(),
            1,
            dim_size,
        );
        convert_f32_to_f16(d_input_f32.as_ptr(), d_input_bits, dim_size);
    }
}
/// Softmax backward pass for `half::bf16` buffers, computed in `f32`.
///
/// The buffers hold `outer_size` rows of `dim_size` contiguous elements; row `i` starts
/// at element `i * dim_size`. For each row, this function:
/// 1. widens `grad` and `output` into `f32` scratch buffers;
/// 2. runs `softmax_bwd_f32` on that single row;
/// 3. narrows the result back into `d_input`.
///
/// The three scratch buffers are allocated once and reused for every row. Because
/// `softmax_bwd_f32` is called once per row, its SIMD detection and dispatch also run
/// once per row.
///
/// # Safety
///
/// `grad` and `output` must be valid for reads of `outer_size * dim_size` elements, and
/// `d_input` must be valid for writes of the same count. The pointers are handed to the
/// `half_convert_utils` helpers as `u16` bit patterns.
#[cfg(feature = "f16")]
pub unsafe fn softmax_bwd_bf16(
    grad: *const half::bf16,
    output: *const half::bf16,
    d_input: *mut half::bf16,
    outer_size: usize,
    dim_size: usize,
) {
    use super::half_convert_utils::*;

    // f32 scratch for one row, reused across all rows.
    let mut grad_f32 = vec![0.0f32; dim_size];
    let mut output_f32 = vec![0.0f32; dim_size];
    let mut d_input_f32 = vec![0.0f32; dim_size];

    for row_start in (0..outer_size).map(|row| row * dim_size) {
        let grad_bits = grad.add(row_start).cast::<u16>();
        let output_bits = output.add(row_start).cast::<u16>();
        let d_input_bits = d_input.add(row_start).cast::<u16>();

        convert_bf16_to_f32(grad_bits, grad_f32.as_mut_ptr(), dim_size);
        convert_bf16_to_f32(output_bits, output_f32.as_mut_ptr(), dim_size);
        softmax_bwd_f32(
            grad_f32.as_ptr(),
            output_f32.as_ptr(),
            d_input_f32.as_mut_ptr(),
            1,
            dim_size,
        );
        convert_f32_to_bf16(d_input_f32.as_ptr(), d_input_bits, dim_size);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Row lengths at or above `SIMD_THRESHOLD`, so the dispatchers may select a SIMD
    /// kernel. 37 and 133 are not multiples of typical vector widths, so they also cover
    /// the kernels' remainder (tail) handling; 128 is the size the original test used.
    const SIMD_DIM_SIZES: [usize; 5] = [SIMD_THRESHOLD, 37, 64, 128, 133];

    /// Builds `outer_size` rows of `dim_size` elements.
    ///
    /// Every row of `output` is the softmax of the logits `0.1 * d - 5.0`. `grad` is a ramp
    /// over the whole buffer, so each row sees different upstream gradients.
    fn make_inputs(outer_size: usize, dim_size: usize) -> (Vec<f64>, Vec<f64>) {
        let logit = |d: usize| (d as f64) * 0.1 - 5.0;
        let sum: f64 = (0..dim_size).map(|d| logit(d).exp()).sum();
        let output = (0..outer_size * dim_size)
            .map(|i| logit(i % dim_size).exp() / sum)
            .collect();
        let grad = (0..outer_size * dim_size)
            .map(|x| (x as f64) / 100.0 - 2.5)
            .collect();
        (output, grad)
    }

    /// Compares dispatched results against the scalar reference, row by row.
    ///
    /// A SIMD kernel may sum a row's dot product in a different order, or with FMA.
    /// That perturbs `dot` by a small multiple of `sum_j |g_j * y_j|`, and each element
    /// sees the perturbation scaled by its own `y_i`.
    ///
    /// The tolerance is therefore a relative term plus a term proportional to
    /// `y_i * sum_j |g_j * y_j|`. Unlike a pure relative check, this stays sound for
    /// elements where `g_i - dot` nearly cancels.
    fn assert_close_to_reference(
        dim_size: usize,
        actual: &[f64],
        expected: &[f64],
        output: &[f64],
        grad: &[f64],
        (rtol, dot_tol): (f64, f64),
    ) {
        assert!(dim_size > 0, "row comparison needs non-empty rows");
        assert_eq!(actual.len(), expected.len());
        for row_start in (0..expected.len()).step_by(dim_size) {
            let row = row_start..row_start + dim_size;
            let dot_scale: f64 = row.clone().map(|i| (grad[i] * output[i]).abs()).sum();
            for i in row {
                let err = (actual[i] - expected[i]).abs();
                let tol = rtol * expected[i].abs() + dot_tol * output[i] * dot_scale;
                assert!(
                    err <= tol,
                    "dim_size {dim_size}, index {i}: {} vs {} (err {err}, tol {tol})",
                    actual[i],
                    expected[i],
                );
            }
        }
    }

    #[test]
    fn test_softmax_bwd_f32() {
        // softmax([1, 2, 3]) with upstream gradient e_0, so d_i = y_i * (delta_i0 - y_0):
        // y0*(1 - y0) = 0.08192507, -y1*y0 = -0.02203304, -y2*y0 = -0.05989202.
        let output = [0.09003057f32, 0.24472847, 0.66524096];
        let grad = [1.0f32, 0.0, 0.0];
        let mut d_input = [0.0f32; 3];
        unsafe {
            softmax_bwd_f32(grad.as_ptr(), output.as_ptr(), d_input.as_mut_ptr(), 1, 3);
        }
        assert!((d_input[0] - 0.08192507).abs() < 1e-6);
        assert!((d_input[1] - (-0.02203304)).abs() < 1e-6);
        assert!((d_input[2] - (-0.05989202)).abs() < 1e-6);
        let sum: f32 = d_input.iter().sum();
        assert!(sum.abs() < 1e-6, "gradients should sum to 0, got {sum}");
    }

    #[test]
    fn test_softmax_bwd_f64() {
        // Same case as the f32 test, at double precision.
        let output = [0.09003057317038046f64, 0.24472847105479764, 0.6652409557748219];
        let grad = [1.0f64, 0.0, 0.0];
        let mut d_input = [0.0f64; 3];
        unsafe {
            softmax_bwd_f64(grad.as_ptr(), output.as_ptr(), d_input.as_mut_ptr(), 1, 3);
        }
        let expected = [0.081925069065, -0.02203304452, -0.05989202454];
        for (got, want) in d_input.iter().zip(expected) {
            assert!((got - want).abs() < 1e-9, "got {got}, want {want}");
        }
        let sum: f64 = d_input.iter().sum();
        assert!(sum.abs() < 1e-12, "gradients should sum to 0, got {sum}");
    }

    #[test]
    fn test_softmax_bwd_simd() {
        let outer_size = 4;
        let widen = |v: &[f32]| -> Vec<f64> { v.iter().map(|&x| f64::from(x)).collect() };
        for dim_size in SIMD_DIM_SIZES {
            let (output, grad) = make_inputs(outer_size, dim_size);
            let output: Vec<f32> = output.iter().map(|&v| v as f32).collect();
            let grad: Vec<f32> = grad.iter().map(|&v| v as f32).collect();
            let mut d_input_simd = vec![0.0f32; outer_size * dim_size];
            let mut d_input_ref = vec![0.0f32; outer_size * dim_size];
            unsafe {
                softmax_bwd_f32(
                    grad.as_ptr(),
                    output.as_ptr(),
                    d_input_simd.as_mut_ptr(),
                    outer_size,
                    dim_size,
                );
                softmax_bwd_scalar_f32(
                    grad.as_ptr(),
                    output.as_ptr(),
                    d_input_ref.as_mut_ptr(),
                    outer_size,
                    dim_size,
                );
            }
            assert_close_to_reference(
                dim_size,
                &widen(&d_input_simd),
                &widen(&d_input_ref),
                &widen(&output),
                &widen(&grad),
                (1e-3, 1e-4),
            );
        }
    }

    #[test]
    fn test_softmax_bwd_simd_f64() {
        let outer_size = 4;
        for dim_size in SIMD_DIM_SIZES {
            let (output, grad) = make_inputs(outer_size, dim_size);
            let mut d_input_simd = vec![0.0f64; outer_size * dim_size];
            let mut d_input_ref = vec![0.0f64; outer_size * dim_size];
            unsafe {
                softmax_bwd_f64(
                    grad.as_ptr(),
                    output.as_ptr(),
                    d_input_simd.as_mut_ptr(),
                    outer_size,
                    dim_size,
                );
                softmax_bwd_scalar_f64(
                    grad.as_ptr(),
                    output.as_ptr(),
                    d_input_ref.as_mut_ptr(),
                    outer_size,
                    dim_size,
                );
            }
            assert_close_to_reference(
                dim_size,
                &d_input_simd,
                &d_input_ref,
                &output,
                &grad,
                (1e-9, 1e-12),
            );
        }
    }

    #[test]
    fn test_softmax_bwd_empty_shapes() {
        // No rows, or zero-length rows: both take the scalar path and touch no memory.
        let empty: [f32; 0] = [];
        let mut out: [f32; 0] = [];
        unsafe {
            softmax_bwd_f32(empty.as_ptr(), empty.as_ptr(), out.as_mut_ptr(), 0, 0);
            softmax_bwd_f32(empty.as_ptr(), empty.as_ptr(), out.as_mut_ptr(), 4, 0);
        }
        assert!(out.is_empty());
    }
}