#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Number of `i16` lanes held by one 256-bit AVX2 register (256 bits / 16 bits).
const I16_PER_VEC: usize = 256 / 16;
/// Adds `w` into `acc` element-wise with signed 16-bit saturation, i.e.
/// `acc[i] = acc[i].saturating_add(w[i])` for every `i in 0..acc.len()`.
///
/// Only the first `acc.len()` elements of `w` are read; a longer `w` is fine.
///
/// # Safety
///
/// The CPU executing the call must support AVX2 (the function is compiled with
/// `#[target_feature(enable = "avx2")]`).
///
/// # Panics
///
/// Panics if `w.len() < acc.len()`. The check runs before `acc` is modified.
#[target_feature(enable = "avx2")]
pub unsafe fn vec_add_i16(acc: &mut [i16], w: &[i16]) {
    let len = acc.len();
    // The vector loop reads `w` through raw pointers, so a short `w` would be an
    // out-of-bounds read (UB) instead of a clean panic; reject it up front.
    assert!(w.len() >= len, "vec_add_i16: `w` is shorter than `acc`");
    let chunks = len / I16_PER_VEC;
    for c in 0..chunks {
        let off = c * I16_PER_VEC;
        // SAFETY: `off + I16_PER_VEC <= len`, and both `acc` and `w` hold at least
        // `len` elements; the unaligned load/store intrinsics need no alignment.
        let a = _mm256_loadu_si256(acc.as_ptr().add(off) as *const __m256i);
        let b = _mm256_loadu_si256(w.as_ptr().add(off) as *const __m256i);
        let sum = _mm256_adds_epi16(a, b);
        _mm256_storeu_si256(acc.as_mut_ptr().add(off) as *mut __m256i, sum);
    }
    // Scalar tail with the same saturating semantics as `_mm256_adds_epi16`.
    let tail = chunks * I16_PER_VEC;
    for (x, &y) in acc[tail..].iter_mut().zip(&w[tail..len]) {
        *x = x.saturating_add(y);
    }
}
/// Subtracts `w` from `acc` element-wise with signed 16-bit saturation, i.e.
/// `acc[i] = acc[i].saturating_sub(w[i])` for every `i in 0..acc.len()`.
///
/// Only the first `acc.len()` elements of `w` are read; a longer `w` is fine.
///
/// # Safety
///
/// The CPU executing the call must support AVX2 (the function is compiled with
/// `#[target_feature(enable = "avx2")]`).
///
/// # Panics
///
/// Panics if `w.len() < acc.len()`. The check runs before `acc` is modified.
#[target_feature(enable = "avx2")]
pub unsafe fn vec_sub_i16(acc: &mut [i16], w: &[i16]) {
    let len = acc.len();
    // The vector loop reads `w` through raw pointers, so a short `w` would be an
    // out-of-bounds read (UB) instead of a clean panic; reject it up front.
    assert!(w.len() >= len, "vec_sub_i16: `w` is shorter than `acc`");
    let chunks = len / I16_PER_VEC;
    for c in 0..chunks {
        let off = c * I16_PER_VEC;
        // SAFETY: `off + I16_PER_VEC <= len`, and both `acc` and `w` hold at least
        // `len` elements; the unaligned load/store intrinsics need no alignment.
        let a = _mm256_loadu_si256(acc.as_ptr().add(off) as *const __m256i);
        let b = _mm256_loadu_si256(w.as_ptr().add(off) as *const __m256i);
        let diff = _mm256_subs_epi16(a, b);
        _mm256_storeu_si256(acc.as_mut_ptr().add(off) as *mut __m256i, diff);
    }
    // Scalar tail with the same saturating semantics as `_mm256_subs_epi16`.
    let tail = chunks * I16_PER_VEC;
    for (x, &y) in acc[tail..].iter_mut().zip(&w[tail..len]) {
        *x = x.saturating_sub(y);
    }
}
/// Writes `out[i] = inp[i].clamp(0, 127)` for every `i in 0..inp.len()`.
///
/// Elements of `out` at or beyond `inp.len()` are left untouched, so a longer
/// `out` is fine.
///
/// # Safety
///
/// The CPU executing the call must support AVX2 (the function is compiled with
/// `#[target_feature(enable = "avx2")]`).
///
/// # Panics
///
/// Panics if `out.len() < inp.len()`. The check runs before `out` is modified.
#[target_feature(enable = "avx2")]
pub unsafe fn vec_clipped_relu(out: &mut [i16], inp: &[i16]) {
    let len = inp.len();
    // The vector loop stores into `out` through raw pointers, so a short `out`
    // would be an out-of-bounds write (UB); reject it up front.
    assert!(out.len() >= len, "vec_clipped_relu: `out` is shorter than `inp`");
    let chunks = len / I16_PER_VEC;
    let zero = _mm256_setzero_si256();
    let max127 = _mm256_set1_epi16(127);
    for c in 0..chunks {
        let off = c * I16_PER_VEC;
        // SAFETY: `off + I16_PER_VEC <= len`, and both `inp` and `out` hold at
        // least `len` elements; the unaligned load/store intrinsics need no alignment.
        let v = _mm256_loadu_si256(inp.as_ptr().add(off) as *const __m256i);
        // Signed max with 0, then signed min with 127.
        let clamped = _mm256_min_epi16(_mm256_max_epi16(v, zero), max127);
        _mm256_storeu_si256(out.as_mut_ptr().add(off) as *mut __m256i, clamped);
    }
    let tail = chunks * I16_PER_VEC;
    for (o, &x) in out[tail..len].iter_mut().zip(&inp[tail..]) {
        *o = x.clamp(0, 127);
    }
}
/// Horizontal sum of the eight `i32` lanes of `v`, using wrapping
/// (two's-complement) 32-bit addition; the total is returned as a scalar.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum_i32(v: __m256i) -> i32 {
    // 8 lanes -> 4: add the upper 128-bit half onto the lower half.
    let quad = _mm_add_epi32(
        _mm256_castsi256_si128(v),
        _mm256_extracti128_si256::<1>(v),
    );
    // 4 lanes -> 2: add lanes [2, 3] onto lanes [0, 1].
    let pair = _mm_add_epi32(quad, _mm_unpackhi_epi64(quad, quad));
    // 2 lanes -> 1: add lane 1 onto lane 0.
    let single = _mm_add_epi32(pair, _mm_shuffle_epi32::<0b01>(pair));
    _mm_cvtsi128_si32(single)
}
/// Dot product `sum(a[i] * b[i])` over `i in 0..a.len()`, computed in `i32`
/// with wrapping (modulo 2^32) arithmetic.
///
/// The vector path accumulates `_mm256_madd_epi16` results with wrapping 32-bit
/// adds, so the scalar tail wraps too. The result is therefore the exact dot
/// product reduced modulo 2^32, and it is the same in debug and release builds.
///
/// Only the first `a.len()` elements of `b` are read; a longer `b` is fine.
///
/// # Safety
///
/// The CPU executing the call must support AVX2 (the function is compiled with
/// `#[target_feature(enable = "avx2")]`).
///
/// # Panics
///
/// Panics if `b.len() < a.len()`.
#[target_feature(enable = "avx2")]
pub unsafe fn dot_i16_i32(a: &[i16], b: &[i16]) -> i32 {
    let len = a.len();
    // The vector loop reads `b` through raw pointers, so a short `b` would be an
    // out-of-bounds read (UB); reject it up front.
    assert!(b.len() >= len, "dot_i16_i32: `b` is shorter than `a`");
    let chunks = len / I16_PER_VEC;
    let mut acc = _mm256_setzero_si256();
    for c in 0..chunks {
        let off = c * I16_PER_VEC;
        // SAFETY: `off + I16_PER_VEC <= len`, and both `a` and `b` hold at least
        // `len` elements; the unaligned load intrinsic needs no alignment.
        let va = _mm256_loadu_si256(a.as_ptr().add(off) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(off) as *const __m256i);
        // 16 x (i16 * i16) products, summed in adjacent pairs into 8 x i32.
        let prod = _mm256_madd_epi16(va, vb);
        acc = _mm256_add_epi32(acc, prod);
    }
    let mut result = hsum_i32(acc);
    let tail = chunks * I16_PER_VEC;
    for (&x, &y) in a[tail..].iter().zip(&b[tail..len]) {
        // A single i16*i16 product always fits in i32; only the running sum wraps.
        result = result.wrapping_add(i32::from(x) * i32::from(y));
    }
    result
}
/// Returns four `i64` lanes whose total is `sum(a[j]^2 * b[j])` over the eight
/// `i16` values packed in `a16` / `b16`, with no intermediate truncation.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn screlu_terms_i64(a16: __m128i, b16: __m128i) -> __m256i {
    let a32 = _mm256_cvtepi16_epi32(a16);
    let b32 = _mm256_cvtepi16_epi32(b16);
    // |a| <= 2^15, so a^2 <= 2^30 and the 32-bit square is exact.
    let sq = _mm256_mullo_epi32(a32, a32);
    // a^2 * b needs up to 46 bits. `_mm256_mul_epi32` multiplies the signed low
    // dword of each 64-bit lane into a full 64-bit product: once for the even
    // dwords, then again after shifting the odd dwords down into the low half.
    let even = _mm256_mul_epi32(sq, b32);
    let odd = _mm256_mul_epi32(_mm256_srli_epi64::<32>(sq), _mm256_srli_epi64::<32>(b32));
    _mm256_add_epi64(even, odd)
}

/// Computes `sum(a[i] * a[i] * b[i])` for `i in 0..a.len()`, exactly, in `i64`.
///
/// The square is formed in 32-bit lanes, and the multiply by `b[i]` is a signed
/// 32x32->64-bit product, so no term is truncated. This matches the scalar
/// formula for every `i16` input. Each term has magnitude at most 2^45, so the
/// sum cannot overflow for any `a.len()` below 2^18.
///
/// Only the first `a.len()` elements of `b` are read; a longer `b` is fine.
///
/// NOTE(review): no clamping of `a` happens here despite the "screlu" name —
/// presumably the caller clamps `a` beforehand; confirm against call sites.
///
/// # Safety
///
/// The CPU executing the call must support AVX2 (the function is compiled with
/// `#[target_feature(enable = "avx2")]`).
///
/// # Panics
///
/// Panics if `b.len() < a.len()`.
#[target_feature(enable = "avx2")]
pub unsafe fn dot_screlu_i64(a: &[i16], b: &[i16]) -> i64 {
    let len = a.len();
    // The vector loop reads `b` through raw pointers, so a short `b` would be an
    // out-of-bounds read (UB); reject it up front.
    assert!(b.len() >= len, "dot_screlu_i64: `b` is shorter than `a`");
    let chunks = len / I16_PER_VEC;
    // Two independent accumulators, one per 128-bit half of each chunk.
    let mut acc_lo = _mm256_setzero_si256();
    let mut acc_hi = _mm256_setzero_si256();
    for c in 0..chunks {
        let off = c * I16_PER_VEC;
        // SAFETY: `off + I16_PER_VEC <= len`, and both `a` and `b` hold at least
        // `len` elements; the unaligned load intrinsic needs no alignment.
        let va = _mm256_loadu_si256(a.as_ptr().add(off) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(off) as *const __m256i);
        let lo = screlu_terms_i64(_mm256_castsi256_si128(va), _mm256_castsi256_si128(vb));
        let hi = screlu_terms_i64(
            _mm256_extracti128_si256::<1>(va),
            _mm256_extracti128_si256::<1>(vb),
        );
        acc_lo = _mm256_add_epi64(acc_lo, lo);
        acc_hi = _mm256_add_epi64(acc_hi, hi);
    }
    let combined = _mm256_add_epi64(acc_lo, acc_hi);
    let mut lanes = [0i64; 4];
    _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, combined);
    let mut result = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    let tail = chunks * I16_PER_VEC;
    for (&x, &y) in a[tail..].iter().zip(&b[tail..len]) {
        let x = i64::from(x);
        result += x * x * i64::from(y);
    }
    result
}