#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Number of `i8` lanes in one 512-bit ZMM register (64 bytes, 1 byte/lane).
const I8_LANES: usize = 64;
/// Horizontal sum of the 16 `i32` lanes of `v`.
///
/// Reduces 512 → 256 → 128 bits with vector adds, then folds the last four
/// lanes down to lane 0 with an unpack and a shuffle.
///
/// # Safety
/// The caller must ensure the CPU supports AVX-512F (e.g. via
/// `is_x86_feature_detected!("avx512f")`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn hsum_epi32_512(v: __m512i) -> i32 {
    // Fold the upper 256 bits onto the lower 256 bits.
    let lo256 = _mm512_castsi512_si256(v);
    let hi256 = _mm512_extracti64x4_epi64(v, 1);
    let sum256 = _mm256_add_epi32(lo256, hi256);
    // Fold 256 → 128.
    let lo128 = _mm256_castsi256_si128(sum256);
    let hi128 = _mm256_extracti128_si256(sum256, 1);
    let sum128 = _mm_add_epi32(lo128, hi128);
    // Four lanes [a b c d]: add the upper 64-bit half → [a+c b+d . .].
    let hi64 = _mm_unpackhi_epi64(sum128, sum128);
    let sum64 = _mm_add_epi32(sum128, hi64);
    // Bring lane 1 down to lane 0 and add: lane 0 now holds the total.
    let hi32 = _mm_shuffle_epi32(sum64, 0b_00_00_00_01);
    let sum32 = _mm_add_epi32(sum64, hi32);
    _mm_cvtsi128_si32(sum32)
}
/// Dot product of two `i8` slices (given as raw pointers), accumulated in `i32`.
///
/// Processes 64 elements per iteration: each unaligned 512-bit load is
/// sign-extended to two vectors of 32 × `i16`, multiplied and pair-summed into
/// `i32` lanes with `madd`, and added into a 16-lane accumulator that is
/// horizontally reduced at the end. The tail (`len % 64` elements) is scalar.
///
/// No intermediate overflow is possible: each `madd` lane is a sum of two
/// products bounded by 128·128, which fits comfortably in `i32`. The `i32`
/// accumulator itself may wrap for extremely long inputs, exactly as the
/// equivalent scalar loop would.
///
/// # Safety
/// - The CPU must support AVX-512F and AVX-512BW.
/// - `a` and `b` must each be valid for reads of `len` consecutive `i8`s.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
pub unsafe fn i8xi8_dot_i32(a: *const i8, b: *const i8, len: usize) -> i32 {
    let chunks = len / I8_LANES;
    let mut acc = _mm512_setzero_si512();
    for i in 0..chunks {
        let offset = i * I8_LANES;
        let va = _mm512_loadu_si512(a.add(offset) as *const __m512i);
        let vb = _mm512_loadu_si512(b.add(offset) as *const __m512i);
        // Low 32 bytes: widen i8 → i16, multiply-add adjacent pairs into i32.
        let va_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(va));
        let vb_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(vb));
        acc = _mm512_add_epi32(acc, _mm512_madd_epi16(va_lo, vb_lo));
        // High 32 bytes: same treatment.
        let va_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(va, 1));
        let vb_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(vb, 1));
        acc = _mm512_add_epi32(acc, _mm512_madd_epi16(va_hi, vb_hi));
    }
    // Standard horizontal-reduce intrinsic (AVX-512F) instead of a
    // hand-rolled shuffle cascade.
    let mut result = _mm512_reduce_add_epi32(acc);
    // Scalar tail for the final len % 64 elements.
    for i in (chunks * I8_LANES)..len {
        result += (*a.add(i) as i32) * (*b.add(i) as i32);
    }
    result
}