#[cfg(table_format = "q16_16")]
use std::arch::x86_64::*;
use crate::fixed_point::universal::fasc::stack_evaluator::{BinaryStorage, ComputeStorage};
/// AVX2 dot product of TQ1.9 weights (stored as `i16`) with activations.
///
/// Eight elements are processed per iteration:
/// - weights are sign-extended to `i32`;
/// - `_mm256_mul_epi32` forms exact 64-bit products of the even 32-bit lanes;
/// - the odd lanes are shifted into the even slots and multiplied the same way.
///
/// The trailing `len % 8` elements are handled by a scalar loop. All products
/// and partial sums are kept in `i64`.
///
/// Only the first `weights.len()` activations are used.
///
/// Relies on `BinaryStorage` being a 32-bit signed integer under this profile
/// (eight of them are read per 256-bit load).
///
/// # Panics
///
/// Panics if `activations` is shorter than `weights`.
///
/// # Safety
///
/// The executing CPU must support AVX2 (check with
/// `is_x86_feature_detected!("avx2")` before calling).
#[cfg(table_format = "q16_16")]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn tq19_dot_avx2(
    weights: &[i16],
    activations: &[BinaryStorage],
) -> ComputeStorage {
    let n = weights.len();
    // Without this check, the vector loads below would read past the end of a
    // short `activations` slice (undefined behaviour).
    assert!(
        activations.len() >= n,
        "tq19_dot_avx2: activations ({}) shorter than weights ({})",
        activations.len(),
        n
    );
    // Trim so both slices split into identical chunk counts and remainders.
    let activations = &activations[..n];
    let w_chunks = weights.chunks_exact(8);
    let a_chunks = activations.chunks_exact(8);
    let (w_tail, a_tail) = (w_chunks.remainder(), a_chunks.remainder());

    let mut acc_even = _mm256_setzero_si256();
    let mut acc_odd = _mm256_setzero_si256();
    for (w8, a8) in w_chunks.zip(a_chunks) {
        // 8 x i16 -> 8 x i32, sign-extended.
        let w = _mm256_cvtepi16_epi32(_mm_loadu_si128(w8.as_ptr() as *const __m128i));
        let a = _mm256_loadu_si256(a8.as_ptr() as *const __m256i);
        // `_mm256_mul_epi32` multiplies the signed low 32 bits of each 64-bit
        // lane (elements 0, 2, 4, 6) into full 64-bit products.
        acc_even = _mm256_add_epi64(acc_even, _mm256_mul_epi32(w, a));
        // Move elements 1, 3, 5, 7 into the low halves and repeat.
        let w_odd = _mm256_srli_epi64(w, 32);
        let a_odd = _mm256_srli_epi64(a, 32);
        acc_odd = _mm256_add_epi64(acc_odd, _mm256_mul_epi32(w_odd, a_odd));
    }
    let simd_sum = hsum_epi64(_mm256_add_epi64(acc_even, acc_odd));

    let tail_sum: i64 = w_tail
        .iter()
        .zip(a_tail)
        .map(|(&w, &a)| (w as i64) * (a as i64))
        .sum();
    simd_sum + tail_sum
}
/// Placeholder for the AVX2 TQ1.9 dot product on non-Q16.16 profiles.
///
/// It keeps the symbol available under every `table_format`, so the signature
/// matches the Q16.16 build. Presumably callers never dispatch here outside
/// Q16.16; verify against the dispatch site.
///
/// # Panics
///
/// Always panics: reaching this function is an internal logic error.
///
/// # Safety
///
/// The body only panics. `unsafe` is kept so the signature mirrors the Q16.16 variant.
#[cfg(not(table_format = "q16_16"))]
// Under this profile nothing may reference the stub, which would trip the lint.
#[allow(dead_code)]
pub(crate) unsafe fn tq19_dot_avx2(
    _weights: &[i16],
    _activations: &[BinaryStorage],
) -> ComputeStorage {
    unreachable!("SIMD TQ1.9 dot only available on Q16.16 profile");
}
/// AVX2 trit-weighted sum: `sum(signum(trits[i]) * activations[i])`.
///
/// Each trit contributes by its sign: a positive trit adds the activation, a
/// negative trit subtracts it, and zero skips it. The canonical trit values
/// are `{-1, 0, 1}`. The vector loop and the scalar tail apply the same sign
/// rule, so the result does not depend on where an element falls relative to
/// the 8-wide chunks.
///
/// Signed products are formed in 64 bits. This keeps `-(i32::MIN)` exact;
/// negating in 32 bits (as `_mm256_sign_epi32(a, t)` would) wraps it back to
/// `i32::MIN`.
///
/// Only the first `trits.len()` activations are used.
///
/// Relies on `BinaryStorage` being a 32-bit signed integer under this profile
/// (eight of them are read per 256-bit load).
///
/// # Panics
///
/// Panics if `activations` is shorter than `trits`.
///
/// # Safety
///
/// The executing CPU must support AVX2 (check with
/// `is_x86_feature_detected!("avx2")` before calling).
#[cfg(table_format = "q16_16")]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn trit_dot_avx2(
    trits: &[i8],
    activations: &[BinaryStorage],
) -> ComputeStorage {
    let n = trits.len();
    // Without this check, the vector loads below would read past the end of a
    // short `activations` slice (undefined behaviour).
    assert!(
        activations.len() >= n,
        "trit_dot_avx2: activations ({}) shorter than trits ({})",
        activations.len(),
        n
    );
    // Trim so both slices split into identical chunk counts and remainders.
    let activations = &activations[..n];
    let t_chunks = trits.chunks_exact(8);
    let a_chunks = activations.chunks_exact(8);
    let (t_tail, a_tail) = (t_chunks.remainder(), a_chunks.remainder());

    let ones = _mm256_set1_epi32(1);
    let mut acc_even = _mm256_setzero_si256();
    let mut acc_odd = _mm256_setzero_si256();
    for (t8, a8) in t_chunks.zip(a_chunks) {
        // Load 8 trits (8 bytes) and sign-extend them to 8 x i32.
        let t = _mm256_cvtepi8_epi32(_mm_loadl_epi64(t8.as_ptr() as *const __m128i));
        // Per-lane signum(t): 1, 0 or -1.
        let sgn = _mm256_sign_epi32(ones, t);
        let a = _mm256_loadu_si256(a8.as_ptr() as *const __m256i);
        // Signed 32x32 -> 64-bit multiply of the even lanes; exact for i32::MIN.
        acc_even = _mm256_add_epi64(acc_even, _mm256_mul_epi32(a, sgn));
        // Move the odd lanes into the low halves and repeat.
        let a_odd = _mm256_srli_epi64(a, 32);
        let sgn_odd = _mm256_srli_epi64(sgn, 32);
        acc_odd = _mm256_add_epi64(acc_odd, _mm256_mul_epi32(a_odd, sgn_odd));
    }
    let simd_sum = hsum_epi64(_mm256_add_epi64(acc_even, acc_odd));

    // Same sign rule as the vector loop, so non-canonical trits behave identically.
    let tail_sum: i64 = t_tail
        .iter()
        .zip(a_tail)
        .map(|(&t, &a)| (a as i64) * i64::from(t.signum()))
        .sum();
    simd_sum + tail_sum
}
/// Placeholder for the AVX2 trit dot product on non-Q16.16 profiles.
///
/// It keeps the symbol available under every `table_format`, so the signature
/// matches the Q16.16 build. Presumably callers never dispatch here outside
/// Q16.16; verify against the dispatch site.
///
/// # Panics
///
/// Always panics: reaching this function is an internal logic error.
///
/// # Safety
///
/// The body only panics. `unsafe` is kept so the signature mirrors the Q16.16 variant.
#[cfg(not(table_format = "q16_16"))]
// Under this profile nothing may reference the stub, which would trip the lint.
#[allow(dead_code)]
pub(crate) unsafe fn trit_dot_avx2(
    _trits: &[i8],
    _activations: &[BinaryStorage],
) -> ComputeStorage {
    unreachable!("SIMD trit dot only available on Q16.16 profile");
}
/// Horizontally adds the four `i64` lanes of `v`. The lane additions wrap on
/// overflow.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[cfg(table_format = "q16_16")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum_epi64(v: __m256i) -> i64 {
    // [l0, l1, l2, l3] -> [l0 + l2, l1 + l3]
    let pairwise = _mm_add_epi64(
        _mm256_castsi256_si128(v),
        _mm256_extracti128_si256(v, 1),
    );
    // Shift lane 1 down into lane 0 (upper lane becomes zero) and add.
    let folded = _mm_add_epi64(pairwise, _mm_srli_si128(pairwise, 8));
    _mm_cvtsi128_si64(folded)
}
#[cfg(test)]
#[cfg(table_format = "q16_16")]
mod tests {
    //! Checks that the AVX2 kernels agree with straightforward scalar loops.
    //! Every test returns early on machines without AVX2.
    use super::*;

    /// True when the AVX2 kernels may be exercised on this machine.
    fn has_avx2() -> bool {
        std::is_x86_feature_detected!("avx2")
    }

    /// Scalar reference for `tq19_dot_avx2`: widen to `i64` and sum the products.
    fn scalar_tq19(weights: &[i16], activations: &[i32]) -> i64 {
        weights
            .iter()
            .zip(activations)
            .map(|(&w, &a)| (w as i64) * (a as i64))
            .sum()
    }

    /// Scalar reference for `trit_dot_avx2` over canonical trits `{-1, 0, 1}`.
    fn scalar_trit(trits: &[i8], activations: &[i32]) -> i64 {
        trits
            .iter()
            .zip(activations)
            .map(|(&t, &a)| match t {
                1 => a as i64,
                -1 => -(a as i64),
                _ => 0,
            })
            .sum()
    }

    #[test]
    fn simd_tq19_dot_matches_scalar() {
        if !has_avx2() {
            return;
        }
        let weights: Vec<i16> = (0..16).map(|i| (i * 1000 - 8000) as i16).collect();
        let activations: Vec<i32> = (0..16).map(|i| i * 5000 + 1000).collect();
        let simd_acc = unsafe { tq19_dot_avx2(&weights, &activations) };
        assert_eq!(
            simd_acc,
            scalar_tq19(&weights, &activations),
            "SIMD and scalar must produce identical results"
        );
    }

    #[test]
    fn simd_trit_dot_matches_scalar() {
        if !has_avx2() {
            return;
        }
        let trits: Vec<i8> = vec![1, -1, 0, 1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, -1];
        let activations: Vec<i32> = (0..16).map(|i| i * 3000 + 500).collect();
        let simd_acc = unsafe { trit_dot_avx2(&trits, &activations) };
        assert_eq!(
            simd_acc,
            scalar_trit(&trits, &activations),
            "SIMD trit dot must match scalar"
        );
    }

    #[test]
    fn simd_handles_remainder() {
        if !has_avx2() {
            return;
        }
        let weights: Vec<i16> = (0..13).map(|i| (i * 500) as i16).collect();
        let activations: Vec<i32> = (0..13).map(|i| i * 1000 + 100).collect();
        let simd_acc = unsafe { tq19_dot_avx2(&weights, &activations) };
        assert_eq!(
            simd_acc,
            scalar_tq19(&weights, &activations),
            "remainder handling must be correct"
        );
    }

    #[test]
    fn simd_trit_handles_remainder() {
        if !has_avx2() {
            return;
        }
        // 13 = one full 8-lane chunk plus a 5-element scalar tail.
        let trits: Vec<i8> = vec![1, -1, 0, 1, 1, -1, 0, 1, -1, 1, 0, -1, 1];
        let activations: Vec<i32> = (0..13).map(|i| i * 3000 - 20000).collect();
        let simd_acc = unsafe { trit_dot_avx2(&trits, &activations) };
        assert_eq!(
            simd_acc,
            scalar_trit(&trits, &activations),
            "trit remainder handling must be correct"
        );
    }

    #[test]
    fn simd_handles_empty_input() {
        if !has_avx2() {
            return;
        }
        let no_weights: [i16; 0] = [];
        let no_trits: [i8; 0] = [];
        let no_activations: [i32; 0] = [];
        assert_eq!(unsafe { tq19_dot_avx2(&no_weights, &no_activations) }, 0);
        assert_eq!(unsafe { trit_dot_avx2(&no_trits, &no_activations) }, 0);
    }

    #[test]
    fn simd_ignores_surplus_activations() {
        if !has_avx2() {
            return;
        }
        // Only the first `weights.len()` / `trits.len()` activations take part.
        let weights: Vec<i16> = (0..9).map(|i| (i * 700 - 3000) as i16).collect();
        let trits: Vec<i8> = vec![1, 0, -1, 1, -1, 0, 1, 1, -1];
        let activations: Vec<i32> = (0..12).map(|i| i * 2500 - 9000).collect();
        let tq19 = unsafe { tq19_dot_avx2(&weights, &activations) };
        let trit = unsafe { trit_dot_avx2(&trits, &activations) };
        assert_eq!(tq19, scalar_tq19(&weights, &activations[..9]));
        assert_eq!(trit, scalar_trit(&trits, &activations[..9]));
    }
}