use crate::token::TokenId;
#[derive(Debug, Clone, Copy, Default)]
pub struct SimdCapabilities {
pub avx2: bool,
pub avx512: bool,
pub neon: bool,
}
impl SimdCapabilities {
pub fn detect() -> Self {
Self {
avx2: has_avx2(),
avx512: has_avx512(),
neon: has_neon(),
}
}
pub fn best_available(&self) -> &'static str {
if self.avx512 {
"AVX-512"
} else if self.avx2 {
"AVX2"
} else if self.neon {
"NEON"
} else {
"Scalar"
}
}
}
#[cfg(target_arch = "x86_64")]
pub fn has_avx2() -> bool {
is_x86_feature_detected!("avx2")
}
#[cfg(not(target_arch = "x86_64"))]
pub fn has_avx2() -> bool {
false
}
#[cfg(target_arch = "x86_64")]
pub fn has_avx512() -> bool {
is_x86_feature_detected!("avx512f")
}
#[cfg(not(target_arch = "x86_64"))]
pub fn has_avx512() -> bool {
false
}
#[cfg(target_arch = "aarch64")]
pub fn has_neon() -> bool {
true
}
#[cfg(not(target_arch = "aarch64"))]
pub fn has_neon() -> bool {
false
}
pub fn simd_max_f32(data: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if has_avx2() {
unsafe { simd_max_f32_avx2(data) }
} else {
scalar_max_f32(data)
}
}
#[cfg(target_arch = "aarch64")]
{
unsafe { simd_max_f32_neon(data) }
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
scalar_max_f32(data)
}
}
#[cfg(not(target_arch = "aarch64"))]
fn scalar_max_f32(data: &[f32]) -> f32 {
data.iter().copied().fold(f32::NEG_INFINITY, f32::max)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_max_f32_avx2(data: &[f32]) -> f32 {
use std::arch::x86_64::*;
let mut max_vec = _mm256_set1_ps(f32::NEG_INFINITY);
let chunks = data.chunks_exact(8);
let remainder = chunks.remainder();
for chunk in chunks {
let vals = _mm256_loadu_ps(chunk.as_ptr());
max_vec = _mm256_max_ps(max_vec, vals);
}
let low = _mm256_castps256_ps128(max_vec);
let high = _mm256_extractf128_ps(max_vec, 1);
let max128 = _mm_max_ps(low, high);
let max64 = _mm_max_ps(max128, _mm_movehl_ps(max128, max128));
let max32 = _mm_max_ss(max64, _mm_shuffle_ps(max64, max64, 1));
let mut result = _mm_cvtss_f32(max32);
for &val in remainder {
result = result.max(val);
}
result
}
#[cfg(target_arch = "aarch64")]
unsafe fn simd_max_f32_neon(data: &[f32]) -> f32 {
use std::arch::aarch64::*;
let mut max_vec = vdupq_n_f32(f32::NEG_INFINITY);
let chunks = data.chunks_exact(4);
let remainder = chunks.remainder();
for chunk in chunks {
let vals = vld1q_f32(chunk.as_ptr());
max_vec = vmaxq_f32(max_vec, vals);
}
let max2 = vpmax_f32(vget_low_f32(max_vec), vget_high_f32(max_vec));
let max1 = vpmax_f32(max2, max2);
let mut result = vget_lane_f32(max1, 0);
for &val in remainder {
result = result.max(val);
}
result
}
pub fn simd_sum_f32(data: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if has_avx2() {
unsafe { simd_sum_f32_avx2(data) }
} else {
scalar_sum_f32(data)
}
}
#[cfg(target_arch = "aarch64")]
{
unsafe { simd_sum_f32_neon(data) }
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
scalar_sum_f32(data)
}
}
#[cfg(not(target_arch = "aarch64"))]
fn scalar_sum_f32(data: &[f32]) -> f32 {
data.iter().sum()
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_sum_f32_avx2(data: &[f32]) -> f32 {
use std::arch::x86_64::*;
let mut sum_vec = _mm256_setzero_ps();
let chunks = data.chunks_exact(8);
let remainder = chunks.remainder();
for chunk in chunks {
let vals = _mm256_loadu_ps(chunk.as_ptr());
sum_vec = _mm256_add_ps(sum_vec, vals);
}
let low = _mm256_castps256_ps128(sum_vec);
let high = _mm256_extractf128_ps(sum_vec, 1);
let sum128 = _mm_add_ps(low, high);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
for &val in remainder {
result += val;
}
result
}
#[cfg(target_arch = "aarch64")]
unsafe fn simd_sum_f32_neon(data: &[f32]) -> f32 {
use std::arch::aarch64::*;
let mut sum_vec = vdupq_n_f32(0.0);
let chunks = data.chunks_exact(4);
let remainder = chunks.remainder();
for chunk in chunks {
let vals = vld1q_f32(chunk.as_ptr());
sum_vec = vaddq_f32(sum_vec, vals);
}
let sum2 = vpadd_f32(vget_low_f32(sum_vec), vget_high_f32(sum_vec));
let sum1 = vpadd_f32(sum2, sum2);
let mut result = vget_lane_f32(sum1, 0);
for &val in remainder {
result += val;
}
result
}
pub fn simd_softmax(data: &mut [f32]) {
if data.is_empty() {
return;
}
let max_val = simd_max_f32(data);
#[cfg(target_arch = "x86_64")]
{
if has_avx2() {
unsafe { simd_softmax_avx2(data, max_val) };
return;
}
scalar_softmax(data, max_val);
return;
}
#[cfg(target_arch = "aarch64")]
{
unsafe { simd_softmax_neon(data, max_val) };
return;
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
scalar_softmax(data, max_val);
}
}
#[cfg(not(target_arch = "aarch64"))]
fn scalar_softmax(data: &mut [f32], max_val: f32) {
let mut sum = 0.0f32;
for val in data.iter_mut() {
*val = (*val - max_val).exp();
sum += *val;
}
if sum > 0.0 {
for val in data.iter_mut() {
*val /= sum;
}
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_softmax_avx2(data: &mut [f32], max_val: f32) {
use std::arch::x86_64::*;
let max_vec = _mm256_set1_ps(max_val);
let mut sum_vec = _mm256_setzero_ps();
let len = data.len();
let chunks = len / 8;
for i in 0..chunks {
let ptr = data.as_mut_ptr().add(i * 8);
let vals = _mm256_loadu_ps(ptr);
let shifted = _mm256_sub_ps(vals, max_vec);
let exp_vals = fast_exp_avx2(shifted);
_mm256_storeu_ps(ptr, exp_vals);
sum_vec = _mm256_add_ps(sum_vec, exp_vals);
}
let mut scalar_sum = 0.0f32;
for item in data.iter_mut().take(len).skip(chunks * 8) {
let val = (*item - max_val).exp();
*item = val;
scalar_sum += val;
}
let low = _mm256_castps256_ps128(sum_vec);
let high = _mm256_extractf128_ps(sum_vec, 1);
let sum128 = _mm_add_ps(low, high);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let total_sum = _mm_cvtss_f32(sum32) + scalar_sum;
if total_sum > 0.0 {
let inv_sum = _mm256_set1_ps(1.0 / total_sum);
for i in 0..chunks {
let ptr = data.as_mut_ptr().add(i * 8);
let vals = _mm256_loadu_ps(ptr);
let normalized = _mm256_mul_ps(vals, inv_sum);
_mm256_storeu_ps(ptr, normalized);
}
for item in data.iter_mut().take(len).skip(chunks * 8) {
*item /= total_sum;
}
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn fast_exp_avx2(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
use std::arch::x86_64::*;
let min_val = _mm256_set1_ps(-88.0);
let max_val = _mm256_set1_ps(88.0);
let x = _mm256_max_ps(_mm256_min_ps(x, max_val), min_val);
let log2e = _mm256_set1_ps(std::f32::consts::LOG2_E);
let y = _mm256_mul_ps(x, log2e);
let yi = _mm256_round_ps(y, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
let yf = _mm256_sub_ps(y, yi);
let c0 = _mm256_set1_ps(1.0);
let c1 = _mm256_set1_ps(std::f32::consts::LN_2); let c2 = _mm256_set1_ps(0.240_226_5); let c3 = _mm256_set1_ps(0.055_504_11);
let yf2 = _mm256_mul_ps(yf, yf);
let yf3 = _mm256_mul_ps(yf2, yf);
let poly = _mm256_add_ps(
c0,
_mm256_add_ps(
_mm256_mul_ps(c1, yf),
_mm256_add_ps(_mm256_mul_ps(c2, yf2), _mm256_mul_ps(c3, yf3)),
),
);
let yi_i32 = _mm256_cvtps_epi32(yi);
let shift = _mm256_add_epi32(yi_i32, _mm256_set1_epi32(127));
let scale = _mm256_castsi256_ps(_mm256_slli_epi32(shift, 23));
_mm256_mul_ps(poly, scale)
}
#[cfg(target_arch = "aarch64")]
unsafe fn simd_softmax_neon(data: &mut [f32], max_val: f32) {
use std::arch::aarch64::*;
let max_vec = vdupq_n_f32(max_val);
let mut sum_vec = vdupq_n_f32(0.0);
let len = data.len();
let chunks = len / 4;
for i in 0..chunks {
let ptr = data.as_mut_ptr().add(i * 4);
let vals = vld1q_f32(ptr);
let shifted = vsubq_f32(vals, max_vec);
let mut exp_vals = [0.0f32; 4];
vst1q_f32(exp_vals.as_mut_ptr(), shifted);
for j in 0..4 {
exp_vals[j] = exp_vals[j].exp();
}
let exp_vec = vld1q_f32(exp_vals.as_ptr());
vst1q_f32(ptr, exp_vec);
sum_vec = vaddq_f32(sum_vec, exp_vec);
}
let mut scalar_sum = 0.0f32;
for i in (chunks * 4)..len {
let val = (data[i] - max_val).exp();
data[i] = val;
scalar_sum += val;
}
let sum2 = vpadd_f32(vget_low_f32(sum_vec), vget_high_f32(sum_vec));
let sum1 = vpadd_f32(sum2, sum2);
let total_sum = vget_lane_f32(sum1, 0) + scalar_sum;
if total_sum > 0.0 {
let inv_sum = vdupq_n_f32(1.0 / total_sum);
for i in 0..chunks {
let ptr = data.as_mut_ptr().add(i * 4);
let vals = vld1q_f32(ptr);
let normalized = vmulq_f32(vals, inv_sum);
vst1q_f32(ptr, normalized);
}
for i in (chunks * 4)..len {
data[i] /= total_sum;
}
}
}
pub fn simd_top_k(data: &[f32], k: usize) -> Vec<(usize, f32)> {
if k == 0 || data.is_empty() {
return Vec::new();
}
let k = k.min(data.len());
if k <= 32 {
return partial_top_k(data, k);
}
full_top_k(data, k)
}
fn partial_top_k(data: &[f32], k: usize) -> Vec<(usize, f32)> {
use std::collections::BinaryHeap;
#[derive(PartialEq)]
struct MinHeapItem(f32, usize);
impl Eq for MinHeapItem {}
impl PartialOrd for MinHeapItem {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for MinHeapItem {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
other
.0
.total_cmp(&self.0)
.then_with(|| self.1.cmp(&other.1))
}
}
let mut heap = BinaryHeap::with_capacity(k + 1);
for (i, &val) in data.iter().enumerate() {
if heap.len() < k {
heap.push(MinHeapItem(val, i));
} else if let Some(min) = heap.peek() {
if val > min.0 {
heap.pop();
heap.push(MinHeapItem(val, i));
}
}
}
let mut result: Vec<_> = heap.into_iter().map(|item| (item.1, item.0)).collect();
result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
result
}
fn full_top_k(data: &[f32], k: usize) -> Vec<(usize, f32)> {
let mut indexed: Vec<_> = data.iter().enumerate().map(|(i, &v)| (i, v)).collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
indexed.truncate(k);
indexed
}
pub fn simd_argmax(data: &[f32]) -> Option<usize> {
if data.is_empty() {
return None;
}
let max_val = simd_max_f32(data);
data.iter().position(|&v| v == max_val)
}
pub fn simd_select_top_k_tokens(logits: &[f32], k: usize) -> Vec<(TokenId, f32)> {
let top_k = simd_top_k(logits, k);
top_k
.into_iter()
.map(|(idx, val)| (idx as TokenId, val))
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_max() {
let data = vec![1.0, 5.0, 3.0, 9.0, 2.0, 7.0, 4.0, 8.0, 6.0];
assert_eq!(simd_max_f32(&data), 9.0);
}
#[test]
fn test_simd_sum() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
assert!((simd_sum_f32(&data) - 36.0).abs() < 0.001);
}
#[test]
fn test_simd_softmax() {
let mut data = vec![1.0, 2.0, 3.0, 4.0];
simd_softmax(&mut data);
let sum: f32 = data.iter().sum();
assert!((sum - 1.0).abs() < 0.001);
assert!(data[3] > data[2]);
assert!(data[2] > data[1]);
assert!(data[1] > data[0]);
}
#[test]
fn test_simd_top_k() {
let data = vec![1.0, 5.0, 3.0, 9.0, 2.0, 7.0, 4.0, 8.0];
let top3 = simd_top_k(&data, 3);
assert_eq!(top3.len(), 3);
assert_eq!(top3[0], (3, 9.0)); assert_eq!(top3[1], (7, 8.0)); assert_eq!(top3[2], (5, 7.0)); }
#[test]
fn test_platform_detection() {
let _ = has_avx2();
let _ = has_avx512();
let _ = has_neon();
}
}