#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "x86_64")]
mod avx512;
#[cfg(target_arch = "aarch64")]
mod aarch64;
use super::{SimdLevel, detect_simd};
const SIMD_THRESHOLD: usize = 32;
#[inline]
#[inline]
pub unsafe fn softmax_f32(a: *const f32, out: *mut f32, outer_size: usize, dim_size: usize) {
    // Route to the widest SIMD kernel the CPU supports. Short rows always
    // take the scalar path: below the threshold, SIMD setup costs dominate.
    let level = detect_simd();
    if dim_size < SIMD_THRESHOLD || level == SimdLevel::Scalar {
        softmax_scalar_f32(a, out, outer_size, dim_size);
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        if level == SimdLevel::Avx512 {
            avx512::softmax_f32(a, out, outer_size, dim_size);
        } else if level == SimdLevel::Avx2Fma {
            avx2::softmax_f32(a, out, outer_size, dim_size);
        } else {
            softmax_scalar_f32(a, out, outer_size, dim_size);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if level == SimdLevel::Neon || level == SimdLevel::NeonFp16 {
            aarch64::neon::softmax_f32(a, out, outer_size, dim_size);
        } else {
            softmax_scalar_f32(a, out, outer_size, dim_size);
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    softmax_scalar_f32(a, out, outer_size, dim_size);
}
#[inline]
#[inline]
pub unsafe fn softmax_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_size: usize) {
    // Route to the widest SIMD kernel the CPU supports. Short rows always
    // take the scalar path: below the threshold, SIMD setup costs dominate.
    let level = detect_simd();
    if dim_size < SIMD_THRESHOLD || level == SimdLevel::Scalar {
        softmax_scalar_f64(a, out, outer_size, dim_size);
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        if level == SimdLevel::Avx512 {
            avx512::softmax_f64(a, out, outer_size, dim_size);
        } else if level == SimdLevel::Avx2Fma {
            avx2::softmax_f64(a, out, outer_size, dim_size);
        } else {
            softmax_scalar_f64(a, out, outer_size, dim_size);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if level == SimdLevel::Neon || level == SimdLevel::NeonFp16 {
            aarch64::neon::softmax_f64(a, out, outer_size, dim_size);
        } else {
            softmax_scalar_f64(a, out, outer_size, dim_size);
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    softmax_scalar_f64(a, out, outer_size, dim_size);
}
#[inline]
/// Scalar (non-SIMD) softmax over `outer_size` contiguous rows of `dim_size`
/// elements each, using a single-pass "online" accumulation: a running max
/// and a sum of exponentials expressed relative to that max, so accumulation
/// reads each input element once.
///
/// `-inf` inputs are treated as probability 0 and are kept out of the sum so
/// they cannot poison it with `(-inf) - (-inf) = NaN`.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * dim_size` contiguous elements.
#[inline]
pub unsafe fn softmax_scalar_f32(a: *const f32, out: *mut f32, outer_size: usize, dim_size: usize) {
    // Empty rows are a no-op; this guard also prevents the out-of-bounds
    // read of `*a.add(base)` below when `dim_size == 0`.
    if dim_size == 0 {
        return;
    }
    for o in 0..outer_size {
        let base = o * dim_size;
        let mut max_val = *a.add(base);
        // A finite first element contributes exp(0) = 1; +/-inf or NaN contribute 0.
        let mut sum = if max_val.is_finite() { 1.0f32 } else { 0.0f32 };
        for d in 1..dim_size {
            let val = *a.add(base + d);
            if val > max_val {
                // New maximum: rescale the accumulated sum into the new
                // reference frame, then add exp(0) = 1 for this element.
                // If the old max was -inf the sum is empty, so rescale by 0.
                let rescale = if max_val == f32::NEG_INFINITY {
                    0.0
                } else {
                    (max_val - val).exp()
                };
                sum = sum * rescale + 1.0;
                max_val = val;
            } else if val == f32::NEG_INFINITY {
                // Contributes exp(-inf) = 0; skipping also avoids NaN from
                // (-inf) - (-inf) when max_val is itself -inf.
            } else {
                sum += (val - max_val).exp();
            }
        }
        let inv_sum = 1.0 / sum;
        // Normalization pass. -inf maps to exactly 0 rather than relying on
        // exp(val - max_val) * inv_sum, which is NaN when max_val is also -inf.
        for d in 0..dim_size {
            let val = *a.add(base + d);
            *out.add(base + d) = if val == f32::NEG_INFINITY {
                0.0
            } else {
                (val - max_val).exp() * inv_sum
            };
        }
    }
}
#[inline]
/// Scalar (non-SIMD) softmax over `outer_size` contiguous rows of `dim_size`
/// elements each, using a single-pass "online" accumulation: a running max
/// and a sum of exponentials expressed relative to that max, so accumulation
/// reads each input element once.
///
/// `-inf` inputs are treated as probability 0 and are kept out of the sum so
/// they cannot poison it with `(-inf) - (-inf) = NaN`.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * dim_size` contiguous elements.
#[inline]
pub unsafe fn softmax_scalar_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_size: usize) {
    // Empty rows are a no-op; this guard also prevents the out-of-bounds
    // read of `*a.add(base)` below when `dim_size == 0`.
    if dim_size == 0 {
        return;
    }
    for o in 0..outer_size {
        let base = o * dim_size;
        let mut max_val = *a.add(base);
        // A finite first element contributes exp(0) = 1; +/-inf or NaN contribute 0.
        let mut sum = if max_val.is_finite() { 1.0f64 } else { 0.0f64 };
        for d in 1..dim_size {
            let val = *a.add(base + d);
            if val > max_val {
                // New maximum: rescale the accumulated sum into the new
                // reference frame, then add exp(0) = 1 for this element.
                // If the old max was -inf the sum is empty, so rescale by 0.
                let rescale = if max_val == f64::NEG_INFINITY {
                    0.0
                } else {
                    (max_val - val).exp()
                };
                sum = sum * rescale + 1.0;
                max_val = val;
            } else if val == f64::NEG_INFINITY {
                // Contributes exp(-inf) = 0; skipping also avoids NaN from
                // (-inf) - (-inf) when max_val is itself -inf.
            } else {
                sum += (val - max_val).exp();
            }
        }
        let inv_sum = 1.0 / sum;
        // Normalization pass. -inf maps to exactly 0 rather than relying on
        // exp(val - max_val) * inv_sum, which is NaN when max_val is also -inf.
        for d in 0..dim_size {
            let val = *a.add(base + d);
            *out.add(base + d) = if val == f64::NEG_INFINITY {
                0.0
            } else {
                (val - max_val).exp() * inv_sum
            };
        }
    }
}
#[cfg(feature = "f16")]
pub unsafe fn softmax_f16(
a: *const half::f16,
out: *mut half::f16,
outer_size: usize,
dim_size: usize,
) {
use super::half_convert_utils::*;
let row_len = dim_size;
let mut a_buf = vec![0.0f32; row_len];
let mut out_buf = vec![0.0f32; row_len];
for i in 0..outer_size {
let offset = i * dim_size;
convert_f16_to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), row_len);
softmax_f32(a_buf.as_ptr(), out_buf.as_mut_ptr(), 1, dim_size);
convert_f32_to_f16(out_buf.as_ptr(), out.add(offset) as *mut u16, row_len);
}
}
#[cfg(feature = "f16")]
pub unsafe fn softmax_bf16(
a: *const half::bf16,
out: *mut half::bf16,
outer_size: usize,
dim_size: usize,
) {
use super::half_convert_utils::*;
let row_len = dim_size;
let mut a_buf = vec![0.0f32; row_len];
let mut out_buf = vec![0.0f32; row_len];
for i in 0..outer_size {
let offset = i * dim_size;
convert_bf16_to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), row_len);
softmax_f32(a_buf.as_ptr(), out_buf.as_mut_ptr(), 1, dim_size);
convert_f32_to_bf16(out_buf.as_ptr(), out.add(offset) as *mut u16, row_len);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The dispatching entry point must agree with the scalar reference.
    #[test]
    fn test_softmax_f32() {
        let (rows, cols) = (4usize, 128usize);
        let data: Vec<f32> = (0..rows * cols).map(|x| (x as f32) / 100.0 - 2.5).collect();
        let mut got = vec![0.0f32; rows * cols];
        let mut want = vec![0.0f32; rows * cols];
        unsafe {
            softmax_f32(data.as_ptr(), got.as_mut_ptr(), rows, cols);
            softmax_scalar_f32(data.as_ptr(), want.as_mut_ptr(), rows, cols);
        }
        for (i, (&g, &w)) in got.iter().zip(want.iter()).enumerate() {
            let rel_err = if w.abs() > 1e-10 {
                (g - w).abs() / w.abs()
            } else {
                (g - w).abs()
            };
            assert!(
                rel_err < 1e-3,
                "mismatch at {}: {} vs {} (rel_err: {})",
                i,
                g,
                w,
                rel_err
            );
        }
    }

    /// Each output row is a probability distribution.
    #[test]
    fn test_softmax_sum_to_one() {
        let (rows, cols) = (2usize, 64usize);
        let data: Vec<f32> = (0..rows * cols).map(|x| (x as f32) * 0.1 - 3.0).collect();
        let mut got = vec![0.0f32; rows * cols];
        unsafe {
            softmax_f32(data.as_ptr(), got.as_mut_ptr(), rows, cols);
        }
        for o in 0..rows {
            let row_sum: f32 = got[o * cols..(o + 1) * cols].iter().sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-5,
                "row {} sum = {}, expected 1.0",
                o,
                row_sum
            );
        }
    }

    /// Large inputs must not overflow thanks to max subtraction.
    #[test]
    fn test_softmax_numerical_stability() {
        let cols = 64usize;
        let data: Vec<f32> = (0..cols).map(|x| 1000.0 + x as f32).collect();
        let mut got = vec![0.0f32; cols];
        unsafe {
            softmax_f32(data.as_ptr(), got.as_mut_ptr(), 1, cols);
        }
        for (i, &val) in got.iter().enumerate() {
            assert!(val.is_finite(), "non-finite value at {}: {}", i, val);
        }
        let sum: f32 = got.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum = {}, expected 1.0", sum);
    }
}