use std::sync::OnceLock;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CosineAccumulators {
pub dot_ab: u32,
pub norm_a_sq: u32,
pub norm_b_sq: u32,
}
#[inline]
pub fn cosine_u8_accum_scalar(a: &[u8], b: &[u8]) -> CosineAccumulators {
debug_assert_eq!(a.len(), b.len());
let (mut dot_ab, mut norm_a_sq, mut norm_b_sq) = (0u32, 0u32, 0u32);
for (&x, &y) in a.iter().zip(b.iter()) {
let (xu, yu) = (x as u32, y as u32);
dot_ab += xu * yu;
norm_a_sq += xu * xu;
norm_b_sq += yu * yu;
}
CosineAccumulators {
dot_ab,
norm_a_sq,
norm_b_sq,
}
}
#[inline]
fn normalize(acc: CosineAccumulators) -> f32 {
let na = (acc.norm_a_sq as f32).sqrt();
let nb = (acc.norm_b_sq as f32).sqrt();
let denom = na * nb;
if denom == 0.0 {
return 0.0;
}
1.0 - acc.dot_ab as f32 / denom
}
#[inline]
pub fn cosine_u8_scalar(a: &[u8], b: &[u8]) -> f32 {
normalize(cosine_u8_accum_scalar(a, b))
}
#[cfg(target_arch = "x86_64")]
mod x86 {
use super::CosineAccumulators;
use std::arch::x86_64::*;
#[inline(always)]
unsafe fn hsum_epi32_avx2(v: __m256i) -> u32 {
let lo128 = _mm256_castsi256_si128(v);
let hi128 = _mm256_extracti128_si256(v, 1);
let mut sum128 = _mm_add_epi32(lo128, hi128);
sum128 = _mm_hadd_epi32(sum128, sum128);
sum128 = _mm_hadd_epi32(sum128, sum128);
_mm_cvtsi128_si32(sum128) as u32
}
#[target_feature(enable = "avx2")]
pub unsafe fn cosine_u8_accum_avx2(a: &[u8], b: &[u8]) -> CosineAccumulators {
debug_assert_eq!(a.len(), b.len());
let n = a.len();
let mut acc_dot = _mm256_setzero_si256();
let mut acc_na = _mm256_setzero_si256();
let mut acc_nb = _mm256_setzero_si256();
let mut i = 0usize;
while i + 32 <= n {
let av = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
let bv = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i);
let a_lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(av));
let a_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(av, 1));
let b_lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(bv));
let b_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(bv, 1));
acc_dot = _mm256_add_epi32(acc_dot, _mm256_madd_epi16(a_lo, b_lo));
acc_dot = _mm256_add_epi32(acc_dot, _mm256_madd_epi16(a_hi, b_hi));
acc_na = _mm256_add_epi32(acc_na, _mm256_madd_epi16(a_lo, a_lo));
acc_na = _mm256_add_epi32(acc_na, _mm256_madd_epi16(a_hi, a_hi));
acc_nb = _mm256_add_epi32(acc_nb, _mm256_madd_epi16(b_lo, b_lo));
acc_nb = _mm256_add_epi32(acc_nb, _mm256_madd_epi16(b_hi, b_hi));
i += 32;
}
let mut dot_ab = hsum_epi32_avx2(acc_dot);
let mut norm_a_sq = hsum_epi32_avx2(acc_na);
let mut norm_b_sq = hsum_epi32_avx2(acc_nb);
while i < n {
let (xu, yu) = (a[i] as u32, b[i] as u32);
dot_ab += xu * yu;
norm_a_sq += xu * xu;
norm_b_sq += yu * yu;
i += 1;
}
CosineAccumulators {
dot_ab,
norm_a_sq,
norm_b_sq,
}
}
#[target_feature(enable = "avx512f,avx512bw,avx512vnni")]
pub unsafe fn cosine_u8_accum_avx512_vnni(a: &[u8], b: &[u8]) -> CosineAccumulators {
debug_assert_eq!(a.len(), b.len());
let n = a.len();
let zeros = _mm512_setzero_si512();
let mut acc_dot = _mm512_setzero_si512();
let mut acc_na = _mm512_setzero_si512();
let mut acc_nb = _mm512_setzero_si512();
let mut i = 0usize;
while i + 64 <= n {
let av = _mm512_loadu_si512(a.as_ptr().add(i) as *const __m512i);
let bv = _mm512_loadu_si512(b.as_ptr().add(i) as *const __m512i);
let a_lo = _mm512_unpacklo_epi8(av, zeros);
let a_hi = _mm512_unpackhi_epi8(av, zeros);
let b_lo = _mm512_unpacklo_epi8(bv, zeros);
let b_hi = _mm512_unpackhi_epi8(bv, zeros);
acc_dot = _mm512_dpwssd_epi32(acc_dot, a_lo, b_lo);
acc_dot = _mm512_dpwssd_epi32(acc_dot, a_hi, b_hi);
acc_na = _mm512_dpwssd_epi32(acc_na, a_lo, a_lo);
acc_na = _mm512_dpwssd_epi32(acc_na, a_hi, a_hi);
acc_nb = _mm512_dpwssd_epi32(acc_nb, b_lo, b_lo);
acc_nb = _mm512_dpwssd_epi32(acc_nb, b_hi, b_hi);
i += 64;
}
let mut dot_ab = _mm512_reduce_add_epi32(acc_dot) as u32;
let mut norm_a_sq = _mm512_reduce_add_epi32(acc_na) as u32;
let mut norm_b_sq = _mm512_reduce_add_epi32(acc_nb) as u32;
while i < n {
let (xu, yu) = (a[i] as u32, b[i] as u32);
dot_ab += xu * yu;
norm_a_sq += xu * xu;
norm_b_sq += yu * yu;
i += 1;
}
CosineAccumulators {
dot_ab,
norm_a_sq,
norm_b_sq,
}
}
}
type CosineU8AccumFn = fn(&[u8], &[u8]) -> CosineAccumulators;
static DISPATCH: OnceLock<CosineU8AccumFn> = OnceLock::new();
fn select_backend() -> CosineU8AccumFn {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f")
&& is_x86_feature_detected!("avx512bw")
&& is_x86_feature_detected!("avx512vnni")
{
return |a, b| unsafe { x86::cosine_u8_accum_avx512_vnni(a, b) };
}
if is_x86_feature_detected!("avx2") {
return |a, b| unsafe { x86::cosine_u8_accum_avx2(a, b) };
}
}
cosine_u8_accum_scalar
}
#[inline]
fn cosine_u8_accum(a: &[u8], b: &[u8]) -> CosineAccumulators {
(DISPATCH.get_or_init(select_backend))(a, b)
}
#[inline]
pub fn cosine_u8(a: &[u8], b: &[u8]) -> f32 {
normalize(cosine_u8_accum(a, b))
}
#[cfg(test)]
mod tests {
use super::*;
fn fill_random(buf: &mut [u8], seed: &mut u32) {
for slot in buf.iter_mut() {
*seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
*slot = (*seed >> 16) as u8;
}
}
const SIZES: &[usize] = &[
0, 1, 7, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 1024, 4096, 4097,
];
fn check_all_backends_accum(a: &[u8], b: &[u8], case: &str) {
let reference = cosine_u8_accum_scalar(a, b);
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
let got = unsafe { x86::cosine_u8_accum_avx2(a, b) };
assert_eq!(got, reference, "avx2 [{case}] n={}", a.len());
}
if is_x86_feature_detected!("avx512f")
&& is_x86_feature_detected!("avx512bw")
&& is_x86_feature_detected!("avx512vnni")
{
let got = unsafe { x86::cosine_u8_accum_avx512_vnni(a, b) };
assert_eq!(got, reference, "avx512_vnni [{case}] n={}", a.len());
}
}
let dispatched = cosine_u8_accum(a, b);
assert_eq!(dispatched, reference, "dispatch [{case}] n={}", a.len());
}
#[test]
fn random_inputs_across_sizes_and_seeds() {
let mut a = vec![0u8; 4097];
let mut b = vec![0u8; 4097];
for seed_idx in 0..4u32 {
let mut seed = 0xC0FFEE_u32.wrapping_add(seed_idx.wrapping_mul(7919));
for &n in SIZES {
fill_random(&mut a[..n], &mut seed);
fill_random(&mut b[..n], &mut seed);
check_all_backends_accum(&a[..n], &b[..n], "random");
}
}
}
#[test]
fn boundary_values() {
let mut a = vec![0u8; 4097];
let mut b = vec![0u8; 4097];
for &n in SIZES {
a[..n].fill(u8::MAX);
b[..n].fill(u8::MAX);
check_all_backends_accum(&a[..n], &b[..n], "max-max");
a[..n].fill(u8::MAX);
b[..n].fill(0);
check_all_backends_accum(&a[..n], &b[..n], "max-0");
a[..n].fill(0);
b[..n].fill(u8::MAX);
check_all_backends_accum(&a[..n], &b[..n], "0-max");
a[..n].fill(0);
b[..n].fill(0);
check_all_backends_accum(&a[..n], &b[..n], "0-0");
for i in 0..n {
a[i] = if i & 1 == 0 { 0 } else { u8::MAX };
b[i] = if i & 1 == 0 { u8::MAX } else { 0 };
}
check_all_backends_accum(&a[..n], &b[..n], "alt 0/max");
}
}
#[test]
fn one_sided_zeros() {
let mut a = vec![0u8; 4097];
let mut b = vec![0u8; 4097];
for &n in SIZES {
let mut seed = 0xDEAD_BEEF_u32;
fill_random(&mut a[..n], &mut seed);
b[..n].fill(0);
check_all_backends_accum(&a[..n], &b[..n], "b=0");
a[..n].fill(0);
fill_random(&mut b[..n], &mut seed);
check_all_backends_accum(&a[..n], &b[..n], "a=0");
}
}
#[test]
fn cosine_known_values() {
let v = [10u8, 20, 30, 40];
assert_eq!(cosine_u8(&v, &v), 0.0);
let a = [255u8, 255, 0, 0];
let b = [0u8, 0, 255, 255];
assert_eq!(cosine_u8(&a, &b), 1.0);
assert_eq!(cosine_u8(&[0, 0], &[0, 0]), 0.0);
assert_eq!(cosine_u8(&[0, 0], &[1, 2]), 0.0);
}
#[test]
fn cosine_distance_symmetry() {
let mut a = vec![0u8; 256];
let mut b = vec![0u8; 256];
let mut seed = 0xBEEF_u32;
fill_random(&mut a, &mut seed);
fill_random(&mut b, &mut seed);
let d1 = cosine_u8(&a, &b);
let d2 = cosine_u8(&b, &a);
assert_eq!(d1, d2);
}
#[test]
fn cosine_distance_range() {
let mut a = vec![0u8; 1024];
let mut b = vec![0u8; 1024];
for seed_idx in 0..8u32 {
let mut seed = 0xABCD_u32.wrapping_add(seed_idx.wrapping_mul(31));
fill_random(&mut a, &mut seed);
fill_random(&mut b, &mut seed);
let d = cosine_u8(&a, &b);
assert!(
(0.0..=1.0).contains(&d),
"cosine distance {d} out of [0,1] range"
);
}
}
}