use std::sync::OnceLock;
/// Portable reference implementation of the unsigned 8-bit dot product.
///
/// Multiplies `a[i]` by `b[i]` as `u32` and accumulates the products.
/// The accumulator wraps modulo 2^32 in every build profile, so the result
/// does not depend on whether overflow checks are enabled.
///
/// Both slices are expected to have the same length; this is checked in
/// debug builds only. If they differ in a release build, only the common
/// prefix is used.
#[inline]
pub fn dot_u8_scalar(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b).fold(0u32, |acc, (&x, &y)| {
        // A single product is at most 255 * 255 and cannot overflow `u32`;
        // only the running sum can, and it wraps explicitly.
        acc.wrapping_add(u32::from(x) * u32::from(y))
    })
}
/// x86-64 SIMD kernels for the unsigned 8-bit dot product.
///
/// Every kernel returns the dot product modulo 2^32 over the common prefix
/// of its two inputs.
#[cfg(target_arch = "x86_64")]
mod x86 {
    use std::arch::x86_64::*;

    /// AVX2 kernel: processes 32 bytes per iteration, then a scalar tail.
    ///
    /// Bytes are zero-extended to 16 bits, multiplied pairwise and summed
    /// into 32-bit lanes. Lane arithmetic wraps, and the tail wraps as well,
    /// so the result is the dot product modulo 2^32.
    ///
    /// The slices are expected to have equal length (checked in debug
    /// builds). In release builds a mismatch is handled by using only the
    /// common prefix, so no load ever goes past the shorter slice.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[target_feature(enable = "avx2")]
    pub unsafe fn dot_u8_avx2(a: &[u8], b: &[u8]) -> u32 {
        debug_assert_eq!(a.len(), b.len());
        // Clamp to the shorter slice: the vector loads below are unchecked,
        // so `n` must be a valid bound for BOTH inputs.
        let n = a.len().min(b.len());
        let mut acc = _mm256_setzero_si256();
        let mut i = 0usize;
        while i + 32 <= n {
            // SAFETY: `i + 32 <= n <= a.len(), b.len()`, so both unaligned
            // 32-byte loads stay inside their slices.
            let av = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
            let bv = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i);
            // Widen each 16-byte half to sixteen 16-bit lanes (zero-extended).
            let a_lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(av));
            let a_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(av, 1));
            let b_lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(bv));
            let b_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(bv, 1));
            // Multiply 16-bit lanes and add adjacent pairs into 32-bit lanes.
            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a_lo, b_lo));
            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a_hi, b_hi));
            i += 32;
        }
        // Horizontal reduction: fold 8 lanes to 4, then two pairwise adds
        // leave the total in lane 0.
        let lo128 = _mm256_castsi256_si128(acc);
        let hi128 = _mm256_extracti128_si256(acc, 1);
        let mut sum128 = _mm_add_epi32(lo128, hi128);
        sum128 = _mm_hadd_epi32(sum128, sum128);
        sum128 = _mm_hadd_epi32(sum128, sum128);
        let mut result = _mm_cvtsi128_si32(sum128) as u32;
        // Scalar tail for the last `n % 32` elements. Wrapping keeps the
        // overflow behaviour identical to the vector lanes above.
        for (&x, &y) in a[i..n].iter().zip(&b[i..n]) {
            result = result.wrapping_add(u32::from(x) * u32::from(y));
        }
        result
    }

    /// AVX-512 VNNI kernel: processes 64 bytes per iteration, then a scalar tail.
    ///
    /// `_mm512_dpbusd_epi32` multiplies unsigned bytes by *signed* bytes, so
    /// `b` is biased by flipping its top bit, which reinterprets each byte
    /// as `b - 128`. The accumulated value is therefore
    /// `sum(a * b) - 128 * sum(a)`; the kernel tracks `sum(a)` separately and
    /// adds `128 * sum(a)` back at the end. The result is the dot product
    /// modulo 2^32.
    ///
    /// The slices are expected to have equal length (checked in debug
    /// builds). In release builds a mismatch is handled by using only the
    /// common prefix, so no load ever goes past the shorter slice.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX-512F, AVX-512BW and AVX-512 VNNI.
    #[target_feature(enable = "avx512f,avx512bw,avx512vnni")]
    pub unsafe fn dot_u8_avx512_vnni(a: &[u8], b: &[u8]) -> u32 {
        debug_assert_eq!(a.len(), b.len());
        // Clamp to the shorter slice: the vector loads below are unchecked,
        // so `n` must be a valid bound for BOTH inputs.
        let n = a.len().min(b.len());
        let mut acc_dot = _mm512_setzero_si512();
        let mut acc_suma = _mm512_setzero_si512();
        let sign_flip = _mm512_set1_epi8(0x80u8 as i8);
        let zeros = _mm512_setzero_si512();
        let mut i = 0usize;
        while i + 64 <= n {
            // SAFETY: `i + 64 <= n <= a.len(), b.len()`, so both unaligned
            // 64-byte loads stay inside their slices.
            let av = _mm512_loadu_si512(a.as_ptr().add(i) as *const __m512i);
            let bv = _mm512_loadu_si512(b.as_ptr().add(i) as *const __m512i);
            // Flip the top bit so each byte of `b` reads as `b - 128` when signed.
            let b_biased = _mm512_xor_si512(bv, sign_flip);
            acc_dot = _mm512_dpbusd_epi32(acc_dot, av, b_biased);
            // Sum of absolute differences against zero is the byte sum of
            // `a`, accumulated in 64-bit lanes.
            acc_suma = _mm512_add_epi64(acc_suma, _mm512_sad_epu8(av, zeros));
            i += 64;
        }
        let biased_dot = _mm512_reduce_add_epi32(acc_dot);
        let sum_a = _mm512_reduce_add_epi64(acc_suma);
        // Undo the bias. Only the low 32 bits are kept, so wrapping 64-bit
        // arithmetic yields the same result as exact arithmetic would.
        let mut result = i64::from(biased_dot).wrapping_add(sum_a.wrapping_mul(128)) as u32;
        // Scalar tail for the last `n % 64` elements. Wrapping keeps the
        // overflow behaviour identical to the vector lanes above.
        for (&x, &y) in a[i..n].iter().zip(&b[i..n]) {
            result = result.wrapping_add(u32::from(x) * u32::from(y));
        }
        result
    }
}
/// Common signature of every dot-product backend: two byte slices in, a
/// `u32` result out.
type DotU8Fn = fn(&[u8], &[u8]) -> u32;
/// Backend cached for the whole process. It is filled in by `select_backend`
/// the first time `dot_u8` runs and reused on every later call.
static DISPATCH: OnceLock<DotU8Fn> = OnceLock::new();
/// Picks the dot-product backend for the current CPU.
///
/// On x86-64 the AVX-512 VNNI kernel is preferred when AVX-512F, AVX-512BW
/// and AVX-512 VNNI are all detected at runtime, then the AVX2 kernel.
/// Everything else, including other architectures, gets `dot_u8_scalar`.
fn select_backend() -> DotU8Fn {
    #[cfg(target_arch = "x86_64")]
    {
        /// Safe entry point for the AVX-512 VNNI kernel.
        fn via_avx512_vnni(a: &[u8], b: &[u8]) -> u32 {
            // SAFETY: this wrapper is only returned after the CPU features
            // the kernel is compiled for have been detected at runtime.
            unsafe { x86::dot_u8_avx512_vnni(a, b) }
        }

        /// Safe entry point for the AVX2 kernel.
        fn via_avx2(a: &[u8], b: &[u8]) -> u32 {
            // SAFETY: this wrapper is only returned after AVX2 has been
            // detected at runtime.
            unsafe { x86::dot_u8_avx2(a, b) }
        }

        let has_avx512_vnni = is_x86_feature_detected!("avx512f")
            && is_x86_feature_detected!("avx512bw")
            && is_x86_feature_detected!("avx512vnni");
        if has_avx512_vnni {
            return via_avx512_vnni;
        }

        let has_avx2 = is_x86_feature_detected!("avx2");
        if has_avx2 {
            return via_avx2;
        }
    }
    dot_u8_scalar
}
/// Dot product of two `u8` slices using the fastest backend available on
/// this CPU.
///
/// The backend is selected once, on first use, and cached in `DISPATCH`.
///
/// Both slices are expected to have the same length; this is checked in
/// debug builds. In release builds only the common prefix is used. Clamping
/// here matters for soundness: the SIMD backends perform unchecked loads, so
/// they must never be handed slices of different lengths from safe code.
#[inline]
pub fn dot_u8(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());
    let common = a.len().min(b.len());
    let backend = *DISPATCH.get_or_init(select_backend);
    backend(&a[..common], &b[..common])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Input lengths under test: values around the 16-, 32- and 64-byte
    /// boundaries plus a few larger sizes.
    const SIZES: &[usize] = &[
        0, 1, 7, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 1024, 4096, 4097,
    ];

    /// Length of the scratch buffers; covers the largest entry in `SIZES`.
    const SCRATCH_LEN: usize = 4097;

    /// Allocates a pair of zeroed scratch buffers that each test slices into.
    fn scratch() -> (Vec<u8>, Vec<u8>) {
        (vec![0u8; SCRATCH_LEN], vec![0u8; SCRATCH_LEN])
    }

    /// Fills `buf` with pseudo-random bytes from a linear congruential
    /// generator, advancing `seed` once per byte written.
    fn fill_random(buf: &mut [u8], seed: &mut u32) {
        for byte in buf {
            let next = seed.wrapping_mul(1103515245).wrapping_add(12345);
            *seed = next;
            *byte = (next >> 16) as u8;
        }
    }

    /// Asserts that every backend usable on this machine, and the public
    /// dispatcher, agree with the scalar reference for `a` and `b`.
    fn check_all_backends(a: &[u8], b: &[u8], case: &str) {
        let len = a.len();
        let expected = dot_u8_scalar(a, b);

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                let actual = unsafe { x86::dot_u8_avx2(a, b) };
                assert_eq!(actual, expected, "avx2 [{case}] n={}", len);
            }

            let vnni_available = is_x86_feature_detected!("avx512f")
                && is_x86_feature_detected!("avx512bw")
                && is_x86_feature_detected!("avx512vnni");
            if vnni_available {
                let actual = unsafe { x86::dot_u8_avx512_vnni(a, b) };
                assert_eq!(actual, expected, "avx512_vnni [{case}] n={}", len);
            }
        }

        assert_eq!(dot_u8(a, b), expected, "dispatch [{case}] n={}", len);
    }

    /// Pseudo-random data for every size, repeated with four different seeds.
    #[test]
    fn random_inputs_across_sizes_and_seeds() {
        let (mut a, mut b) = scratch();
        for seed_idx in 0..4u32 {
            // The generator state carries over from one size to the next.
            let mut seed = 0xC0FFEE_u32.wrapping_add(seed_idx.wrapping_mul(7919));
            for &n in SIZES {
                let (lhs, rhs) = (&mut a[..n], &mut b[..n]);
                fill_random(lhs, &mut seed);
                fill_random(rhs, &mut seed);
                check_all_backends(lhs, rhs, "random");
            }
        }
    }

    /// Extreme byte values: all-max, max against zero, and alternating patterns.
    #[test]
    fn boundary_values() {
        // (fill for a, fill for b, case label)
        const CONSTANT_CASES: [(u8, u8, &str); 3] = [
            (u8::MAX, u8::MAX, "max*max"),
            (u8::MAX, 0, "max*0"),
            (0, u8::MAX, "0*max"),
        ];

        let (mut a, mut b) = scratch();
        for &n in SIZES {
            let (lhs, rhs) = (&mut a[..n], &mut b[..n]);

            for &(fill_a, fill_b, label) in CONSTANT_CASES.iter() {
                lhs.fill(fill_a);
                rhs.fill(fill_b);
                check_all_backends(lhs, rhs, label);
            }

            // Interleave so that every product has exactly one zero factor.
            for (idx, (x, y)) in lhs.iter_mut().zip(rhs.iter_mut()).enumerate() {
                let even = idx & 1 == 0;
                *x = if even { 0 } else { u8::MAX };
                *y = if even { u8::MAX } else { 0 };
            }
            check_all_backends(lhs, rhs, "alt 0/max");
        }
    }

    /// One operand random, the other all zeros, in both arrangements.
    #[test]
    fn one_sided_zeros() {
        let (mut a, mut b) = scratch();
        for &n in SIZES {
            let (lhs, rhs) = (&mut a[..n], &mut b[..n]);
            // The seed restarts for every size.
            let mut seed = 0xDEAD_BEEF_u32;

            fill_random(lhs, &mut seed);
            rhs.fill(0);
            check_all_backends(lhs, rhs, "b=0");

            lhs.fill(0);
            fill_random(rhs, &mut seed);
            check_all_backends(lhs, rhs, "a=0");
        }
    }

    /// All-ones inputs: the dot product must equal the length.
    #[test]
    fn all_ones_pattern() {
        let (mut a, mut b) = scratch();
        for &n in SIZES {
            let (lhs, rhs) = (&mut a[..n], &mut b[..n]);
            lhs.fill(1);
            rhs.fill(1);
            check_all_backends(lhs, rhs, "1*1");
            assert_eq!(dot_u8_scalar(lhs, rhs), n as u32);
        }
    }
}