#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use std::sync::OnceLock;
use super::simd_config;
/// Parameters describing how an `f32` vector was mapped onto the int8 grid.
///
/// Quantization is symmetric around zero: a value `v` is stored as
/// `round(v * scale)` clamped to `[-127, 127]`.
#[derive(Debug, Clone, Copy)]
pub struct QuantizationParams {
    /// Multiplier taking `f32` values onto the int8 grid:
    /// `127 / max(|min_val|, |max_val|)`, or `1.0` when that magnitude is `<= 1e-10`.
    pub scale: f32,
    /// Always `0` when built by `from_vector` (symmetric quantization).
    pub zero_point: i8,
    /// Smallest finite component seen, or `0.0` if there were no finite components.
    pub min_val: f32,
    /// Largest finite component seen, or `0.0` if there were no finite components.
    pub max_val: f32,
}

impl QuantizationParams {
    /// Derives symmetric quantization parameters from the finite values of `vector`.
    ///
    /// Non-finite components (NaN, ±inf) are ignored. An empty or all-non-finite
    /// input yields `min_val == max_val == 0.0` and `scale == 1.0`. A near-zero range
    /// (max magnitude `<= 1e-10`) also falls back to `scale == 1.0` to avoid a huge or
    /// infinite multiplier.
    pub fn from_vector(vector: &[f32]) -> Self {
        let (lo, hi) = vector
            .iter()
            .copied()
            .filter(|x| x.is_finite())
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), x| {
                (lo.min(x), hi.max(x))
            });

        // The fold's seeds survive only when no finite value was seen.
        let (min_val, max_val) = if lo.is_finite() && hi.is_finite() {
            (lo, hi)
        } else {
            (0.0, 0.0)
        };

        let max_abs = min_val.abs().max(max_val.abs());
        let scale = if max_abs > 1e-10 { 127.0 / max_abs } else { 1.0 };

        Self {
            scale,
            zero_point: 0,
            min_val,
            max_val,
        }
    }
}
/// An int8-quantized copy of an `f32` vector, plus what is needed to use it.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// Quantized components; `from_f32` only ever produces values in `[-127, 127]`.
    pub data: Vec<i8>,
    /// Parameters used to produce `data`.
    pub params: QuantizationParams,
    /// Euclidean norm of the finite components of the original `f32` input
    /// (not of the quantized data).
    pub norm: f32,
}

impl QuantizedVector {
    /// Quantizes `vector` symmetrically into int8.
    ///
    /// Each finite component `v` becomes `round(v * scale)` clamped to `[-127, 127]`.
    /// Non-finite components become `0` and are excluded from `norm`.
    pub fn from_f32(vector: &[f32]) -> Self {
        let mut params = QuantizationParams::from_vector(vector);
        // Defensive: never quantize with a zero or non-finite multiplier.
        let scale_is_usable = params.scale.is_finite() && params.scale != 0.0;
        if !scale_is_usable {
            params.scale = 1.0;
        }
        let scale = params.scale;

        let finite_values = || vector.iter().copied().filter(|v| v.is_finite());
        let norm = finite_values().fold(0.0f32, |acc, v| acc + v * v).sqrt();

        let data = vector
            .iter()
            .map(|&v| {
                if v.is_finite() {
                    (v * scale).round().clamp(-127.0, 127.0) as i8
                } else {
                    0
                }
            })
            .collect::<Vec<i8>>();

        Self { data, params, norm }
    }

    /// Dequantizes back to `f32` by dividing each component by `params.scale`
    /// (or by `1.0` if the stored scale is zero or non-finite).
    pub fn to_f32(&self) -> Vec<f32> {
        let stored = self.params.scale;
        let scale = if stored != 0.0 && stored.is_finite() {
            stored
        } else {
            1.0
        };
        let mut restored = Vec::with_capacity(self.data.len());
        restored.extend(self.data.iter().map(|&q| f32::from(q) / scale));
        restored
    }

    /// Approximate `f32` dot product with `other`; see `dot_product_i8`.
    #[inline]
    pub fn dot_product(&self, other: &QuantizedVector) -> f32 {
        dot_product_i8(self, other)
    }

    /// Approximate cosine similarity with `other`; see `cosine_similarity_i8`.
    #[inline]
    pub fn cosine_similarity(&self, other: &QuantizedVector) -> f32 {
        cosine_similarity_i8(self, other)
    }
}
/// Approximate `f32` dot product reconstructed from two quantized vectors.
///
/// The integer dot product of the quantized data is divided by
/// `a.params.scale * b.params.scale`, undoing the `v * scale` applied at quantization.
///
/// Returns `0.0` if the vectors differ in length, or if the scale product is zero or
/// non-finite.
///
/// # Panics
///
/// Panics if either vector contains `-128`. This check runs in release builds too and
/// costs a full pass over both buffers; `dot_product_i8_trusted` only debug-asserts it.
#[inline]
pub fn dot_product_i8(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
    // The AVX2 / AVX-512 kernels rely on sign tricks that wrap for -128.
    assert!(
        !a.data.contains(&i8::MIN),
        "QuantizedVector a contains -128, which violates the [-127, 127] VNNI invariant"
    );
    assert!(
        !b.data.contains(&i8::MIN),
        "QuantizedVector b contains -128, which violates the [-127, 127] VNNI invariant"
    );

    if a.data.len() != b.data.len() {
        return 0.0;
    }

    let scale_product = a.params.scale * b.params.scale;
    if scale_product.is_finite() && scale_product != 0.0 {
        dot_product_i8_raw(&a.data, &b.data) / scale_product
    } else {
        0.0
    }
}
/// Like `dot_product_i8`, but the "no `-128`" precondition is only a `debug_assert!`.
///
/// For crate-internal callers that already guarantee the `[-127, 127]` range, e.g.
/// vectors produced by `QuantizedVector::from_f32`, which clamps to that range.
/// Returns `0.0` on length mismatch or a zero / non-finite scale product.
#[inline]
pub(crate) fn dot_product_i8_trusted(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
    let (lhs, rhs) = (&a.data, &b.data);
    if lhs.len() != rhs.len() {
        return 0.0;
    }

    let scale_product = a.params.scale * b.params.scale;
    let usable = scale_product != 0.0 && scale_product.is_finite();
    if !usable {
        return 0.0;
    }

    debug_assert!(!lhs.contains(&i8::MIN));
    debug_assert!(!rhs.contains(&i8::MIN));
    dot_product_i8_raw(lhs, rhs) / scale_product
}
/// Approximate cosine similarity: `dot_product_i8(a, b) / (a.norm * b.norm)`.
///
/// Returns `0.0` (without inspecting the data) when the norm product is zero or
/// non-finite. The result is not clamped, so quantization error can push it slightly
/// outside `[-1, 1]`.
///
/// # Panics
///
/// Via `dot_product_i8`, panics if the norms are usable and either vector contains `-128`.
#[inline]
pub fn cosine_similarity_i8(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
    let norm_product = a.norm * b.norm;
    if norm_product.is_finite() && norm_product != 0.0 {
        dot_product_i8(a, b) / norm_product
    } else {
        0.0
    }
}
/// Cosine similarity built on `dot_product_i8_trusted` (no release-mode `-128` scan).
///
/// Returns `0.0` when `a.norm * b.norm` is zero or non-finite.
#[inline]
pub(crate) fn cosine_similarity_i8_trusted(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
    let norm_product = a.norm * b.norm;
    let usable = norm_product != 0.0 && norm_product.is_finite();
    if usable {
        dot_product_i8_trusted(a, b) / norm_product
    } else {
        0.0
    }
}
/// Exact integer dot product of two i8 slices using NEON, returned as `f32`.
///
/// Work is split into three phases:
/// 1. The main loop handles 64 elements per iteration: four 16-lane vectors, each
///    feeding its own accumulator.
/// 2. Any remaining whole 16-lane vectors follow.
/// 3. A scalar loop handles the final remainder of fewer than 16 elements.
///
/// Each 16-byte vector is split into low/high halves and multiplied with `vmull_s8`,
/// which widens to `i16`. No product can overflow, since the largest magnitude is
/// `-128 * -128 = 16384`. `vpadalq_s16` then adds adjacent `i16` products into the
/// `i32` accumulator lanes. Unlike the x86 kernels, this path is therefore also exact
/// for `-128`. Accumulation is in `i32`, and there is no overflow guard for extremely
/// long inputs.
///
/// # Safety
///
/// - `a.len()` must equal `b.len()`. The vector loads from `b` use offsets derived
///   from `a.len()`, and this is only checked by a `debug_assert_eq!`.
/// - The CPU must support NEON. This function has no `#[target_feature]` attribute and
///   relies on NEON being enabled for the compilation target.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_i8_neon_unrolled(a: &[i8], b: &[i8]) -> f32 {
    const SIMD_WIDTH: usize = 16;
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = a.len();
    debug_assert_eq!(n, b.len());
    let chunks = n / CHUNK_SIZE;
    // Four independent accumulators, presumably to expose instruction-level parallelism.
    let mut sum0 = vdupq_n_s32(0);
    let mut sum1 = vdupq_n_s32(0);
    let mut sum2 = vdupq_n_s32(0);
    let mut sum3 = vdupq_n_s32(0);
    for i in 0..chunks {
        let base = i * CHUNK_SIZE;
        // Vector 0 of this chunk: widen-multiply each half, then pairwise-accumulate.
        let a0 = vld1q_s8(a.as_ptr().add(base));
        let b0 = vld1q_s8(b.as_ptr().add(base));
        let a0_lo = vget_low_s8(a0);
        let a0_hi = vget_high_s8(a0);
        let b0_lo = vget_low_s8(b0);
        let b0_hi = vget_high_s8(b0);
        let prod0_lo = vmull_s8(a0_lo, b0_lo);
        let prod0_hi = vmull_s8(a0_hi, b0_hi);
        sum0 = vpadalq_s16(sum0, prod0_lo);
        sum0 = vpadalq_s16(sum0, prod0_hi);
        // Vector 1.
        let a1 = vld1q_s8(a.as_ptr().add(base + SIMD_WIDTH));
        let b1 = vld1q_s8(b.as_ptr().add(base + SIMD_WIDTH));
        let a1_lo = vget_low_s8(a1);
        let a1_hi = vget_high_s8(a1);
        let b1_lo = vget_low_s8(b1);
        let b1_hi = vget_high_s8(b1);
        let prod1_lo = vmull_s8(a1_lo, b1_lo);
        let prod1_hi = vmull_s8(a1_hi, b1_hi);
        sum1 = vpadalq_s16(sum1, prod1_lo);
        sum1 = vpadalq_s16(sum1, prod1_hi);
        // Vector 2.
        let a2 = vld1q_s8(a.as_ptr().add(base + SIMD_WIDTH * 2));
        let b2 = vld1q_s8(b.as_ptr().add(base + SIMD_WIDTH * 2));
        let a2_lo = vget_low_s8(a2);
        let a2_hi = vget_high_s8(a2);
        let b2_lo = vget_low_s8(b2);
        let b2_hi = vget_high_s8(b2);
        let prod2_lo = vmull_s8(a2_lo, b2_lo);
        let prod2_hi = vmull_s8(a2_hi, b2_hi);
        sum2 = vpadalq_s16(sum2, prod2_lo);
        sum2 = vpadalq_s16(sum2, prod2_hi);
        // Vector 3.
        let a3 = vld1q_s8(a.as_ptr().add(base + SIMD_WIDTH * 3));
        let b3 = vld1q_s8(b.as_ptr().add(base + SIMD_WIDTH * 3));
        let a3_lo = vget_low_s8(a3);
        let a3_hi = vget_high_s8(a3);
        let b3_lo = vget_low_s8(b3);
        let b3_hi = vget_high_s8(b3);
        let prod3_lo = vmull_s8(a3_lo, b3_lo);
        let prod3_hi = vmull_s8(a3_hi, b3_hi);
        sum3 = vpadalq_s16(sum3, prod3_lo);
        sum3 = vpadalq_s16(sum3, prod3_hi);
    }
    // Merge the four accumulators into one vector.
    let sum01 = vaddq_s32(sum0, sum1);
    let sum23 = vaddq_s32(sum2, sum3);
    let mut sum_vec = vaddq_s32(sum01, sum23);
    // Remaining whole 16-element vectors after the unrolled section.
    let tail_start = chunks * CHUNK_SIZE;
    let tail_chunks = (n - tail_start) / SIMD_WIDTH;
    for j in 0..tail_chunks {
        let base = tail_start + j * SIMD_WIDTH;
        let at = vld1q_s8(a.as_ptr().add(base));
        let bt = vld1q_s8(b.as_ptr().add(base));
        let at_lo = vget_low_s8(at);
        let at_hi = vget_high_s8(at);
        let bt_lo = vget_low_s8(bt);
        let bt_hi = vget_high_s8(bt);
        let pt_lo = vmull_s8(at_lo, bt_lo);
        let pt_hi = vmull_s8(at_hi, bt_hi);
        sum_vec = vpadalq_s16(sum_vec, pt_lo);
        sum_vec = vpadalq_s16(sum_vec, pt_hi);
    }
    // Horizontal reduction of the four i32 lanes.
    let sum = vgetq_lane_s32(sum_vec, 0)
        + vgetq_lane_s32(sum_vec, 1)
        + vgetq_lane_s32(sum_vec, 2)
        + vgetq_lane_s32(sum_vec, 3);
    // Scalar remainder (fewer than 16 elements).
    let remainder_start = tail_start + tail_chunks * SIMD_WIDTH;
    let remainder: i32 = a[remainder_start..]
        .iter()
        .zip(b[remainder_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();
    (sum + remainder) as f32
}
/// 512-bit emulation of the AVX2 byte-sign operation, applied as `sign(b, a)`.
///
/// Per byte lane, the result is:
/// - `b` where `a > 0`,
/// - `-b` where `a < 0`,
/// - `0` where `a == 0`.
///
/// AVX-512 has no 512-bit form of `vpsignb`, so the result is built from a negation and
/// two mask blends. Note that `-b` wraps for `b == -128` (it stays `-128`); this is one
/// reason callers require inputs in `[-127, 127]`.
///
/// # Safety
///
/// The CPU must support AVX-512F and AVX-512BW.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
#[inline]
unsafe fn mm512_sign_epi8(b: __m512i, a: __m512i) -> __m512i {
    let zero = _mm512_setzero_si512();
    // Wrapping byte negation of `b`.
    let neg_b = _mm512_sub_epi8(zero, b);
    let mask_neg = _mm512_cmplt_epi8_mask(a, zero);
    let mask_zero = _mm512_cmpeq_epi8_mask(a, zero);
    // Take `-b` where `a` is negative, `b` elsewhere...
    let result = _mm512_mask_blend_epi8(mask_neg, b, neg_b);
    // ...then force lanes where `a == 0` to zero.
    _mm512_mask_blend_epi8(mask_zero, result, zero)
}
/// Exact integer dot product of two i8 slices using AVX-512 VNNI, returned as `f32`.
///
/// `_mm512_dpbusd_epi32` multiplies unsigned bytes by signed bytes and accumulates
/// groups of four products into `i32` lanes. Each signed-by-signed product `a * b` is
/// rewritten as `|a| * sign(b, a)`, so the unsigned operand is `|a|`. The main loop
/// handles 256 elements per iteration (four 64-byte vectors, four accumulators). The
/// remaining fewer-than-256 elements are summed with scalar code.
///
/// # Safety
///
/// - The CPU must support AVX-512F, AVX-512BW and AVX-512 VNNI.
/// - `a.len()` must equal `b.len()`. The vector loads from `b` use offsets derived
///   from `a.len()`, and this is only checked by a `debug_assert_eq!`.
/// - Neither slice may contain `-128`. `sign(b, a)` negates `b`, which wraps for
///   `-128`, so such inputs produce wrong results (checked only by `debug_assert!`).
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f", enable = "avx512vnni", enable = "avx512bw")]
unsafe fn dot_product_i8_avx512vnni(a: &[i8], b: &[i8]) -> f32 {
    const SIMD_WIDTH: usize = 64;
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    let chunks = n / CHUNK_SIZE;
    let mut sum0 = _mm512_setzero_si512();
    let mut sum1 = _mm512_setzero_si512();
    let mut sum2 = _mm512_setzero_si512();
    let mut sum3 = _mm512_setzero_si512();
    for i in 0..chunks {
        let base = i * CHUNK_SIZE;
        // |a| becomes the unsigned operand; b takes a's sign, so |a| * sign(b, a) == a * b.
        let a0 = _mm512_loadu_si512(a.as_ptr().add(base) as *const __m512i);
        let b0 = _mm512_loadu_si512(b.as_ptr().add(base) as *const __m512i);
        let a0_abs = _mm512_abs_epi8(a0);
        let b0_signed = mm512_sign_epi8(b0, a0);
        sum0 = _mm512_dpbusd_epi32(sum0, a0_abs, b0_signed);
        let a1 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH) as *const __m512i);
        let b1 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH) as *const __m512i);
        let a1_abs = _mm512_abs_epi8(a1);
        let b1_signed = mm512_sign_epi8(b1, a1);
        sum1 = _mm512_dpbusd_epi32(sum1, a1_abs, b1_signed);
        let a2 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m512i);
        let b2 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m512i);
        let a2_abs = _mm512_abs_epi8(a2);
        let b2_signed = mm512_sign_epi8(b2, a2);
        sum2 = _mm512_dpbusd_epi32(sum2, a2_abs, b2_signed);
        let a3 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m512i);
        let b3 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m512i);
        let a3_abs = _mm512_abs_epi8(a3);
        let b3_signed = mm512_sign_epi8(b3, a3);
        sum3 = _mm512_dpbusd_epi32(sum3, a3_abs, b3_signed);
    }
    // Merge accumulators and reduce all sixteen i32 lanes.
    let sum01 = _mm512_add_epi32(sum0, sum1);
    let sum23 = _mm512_add_epi32(sum2, sum3);
    let sum_vec = _mm512_add_epi32(sum01, sum23);
    let sum = _mm512_reduce_add_epi32(sum_vec);
    // Scalar remainder (fewer than CHUNK_SIZE elements).
    let remainder_start = chunks * CHUNK_SIZE;
    let remainder: i32 = a[remainder_start..]
        .iter()
        .zip(b[remainder_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();
    (sum + remainder) as f32
}
/// Exact integer dot product of two i8 slices using AVX2, returned as `f32`.
///
/// `_mm256_maddubs_epi16` multiplies unsigned bytes by signed bytes and adds adjacent
/// products into saturating `i16` lanes. Each product `a * b` is rewritten as
/// `|a| * sign(b, a)`. With inputs in `[-127, 127]`, a pair sum is at most
/// `2 * 127 * 127 = 32258 < i16::MAX`, so the saturation never triggers.
/// `_mm256_madd_epi16(.., ones)` then widens adjacent `i16` pairs into `i32` lanes.
/// The main loop handles 128 elements per iteration (four 32-byte vectors, four
/// accumulators). The fewer-than-128 remaining elements are summed with scalar code.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `a.len()` must equal `b.len()`. The vector loads from `b` use offsets derived
///   from `a.len()`, and this is only checked by a `debug_assert_eq!`.
/// - Neither slice may contain `-128`. `sign(b, a)` negates `b`, which wraps for
///   `-128`, and `-128 * -128` pairs can saturate `maddubs`, so such inputs produce
///   wrong results (checked only by `debug_assert!`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2_unrolled(a: &[i8], b: &[i8]) -> f32 {
    const SIMD_WIDTH: usize = 32;
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    let chunks = n / CHUNK_SIZE;
    let mut sum0 = _mm256_setzero_si256();
    let mut sum1 = _mm256_setzero_si256();
    let mut sum2 = _mm256_setzero_si256();
    let mut sum3 = _mm256_setzero_si256();
    // Multiplying by 1 in `madd_epi16` just widens-and-adds adjacent i16 lanes into i32.
    let ones = _mm256_set1_epi16(1);
    for i in 0..chunks {
        let base = i * CHUNK_SIZE;
        let a0 = _mm256_loadu_si256(a.as_ptr().add(base) as *const __m256i);
        let b0 = _mm256_loadu_si256(b.as_ptr().add(base) as *const __m256i);
        let prod0 = _mm256_maddubs_epi16(_mm256_abs_epi8(a0), _mm256_sign_epi8(b0, a0));
        let prod0_32 = _mm256_madd_epi16(prod0, ones);
        sum0 = _mm256_add_epi32(sum0, prod0_32);
        let a1 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH) as *const __m256i);
        let b1 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH) as *const __m256i);
        let prod1 = _mm256_maddubs_epi16(_mm256_abs_epi8(a1), _mm256_sign_epi8(b1, a1));
        let prod1_32 = _mm256_madd_epi16(prod1, ones);
        sum1 = _mm256_add_epi32(sum1, prod1_32);
        let a2 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m256i);
        let b2 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m256i);
        let prod2 = _mm256_maddubs_epi16(_mm256_abs_epi8(a2), _mm256_sign_epi8(b2, a2));
        let prod2_32 = _mm256_madd_epi16(prod2, ones);
        sum2 = _mm256_add_epi32(sum2, prod2_32);
        let a3 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m256i);
        let b3 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m256i);
        let prod3 = _mm256_maddubs_epi16(_mm256_abs_epi8(a3), _mm256_sign_epi8(b3, a3));
        let prod3_32 = _mm256_madd_epi16(prod3, ones);
        sum3 = _mm256_add_epi32(sum3, prod3_32);
    }
    let sum01 = _mm256_add_epi32(sum0, sum1);
    let sum23 = _mm256_add_epi32(sum2, sum3);
    let sum_vec = _mm256_add_epi32(sum01, sum23);
    // Horizontal reduction: fold 256 -> 128 bits, then 4 lanes -> 2 -> 1.
    let sum128_lo = _mm256_castsi256_si128(sum_vec);
    let sum128_hi = _mm256_extracti128_si256(sum_vec, 1);
    let sum128 = _mm_add_epi32(sum128_lo, sum128_hi);
    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let sum = _mm_cvtsi128_si32(sum32);
    // Scalar remainder (fewer than CHUNK_SIZE elements).
    let remainder_start = chunks * CHUNK_SIZE;
    let remainder: i32 = a[remainder_start..]
        .iter()
        .zip(b[remainder_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();
    (sum + remainder) as f32
}
/// Signature shared by every i8 dot-product kernel: two quantized buffers in, their
/// integer dot product (converted to `f32`) out.
pub type I8DotKernel = fn(&[i8], &[i8]) -> f32;

/// Kernel chosen on first use by `resolve_i8_dot_kernel`, cached for the process lifetime.
static I8_DOT_KERNEL: OnceLock<I8DotKernel> = OnceLock::new();

/// Returns the i8 dot-product kernel selected for this process.
///
/// The first call runs `resolve_i8_dot_kernel`. `OnceLock` guarantees it runs at most
/// once, even under concurrent first calls. Later calls return the cached function
/// pointer. Callers are expected to pass equal-length slices; `dot_product_i8_raw`
/// checks this before invoking the kernel.
#[inline]
pub fn resolved_i8_dot_kernel() -> I8DotKernel {
    let cached: &I8DotKernel = I8_DOT_KERNEL.get_or_init(resolve_i8_dot_kernel);
    *cached
}
/// Picks the i8 dot-product kernel for this build and CPU; falls back to scalar.
///
/// Each SIMD kernel is selected only when the crate's `simd_config()` flag is set *and*
/// the CPU reports the required features at runtime. The kernels are handed out as safe
/// `fn` pointers through the public `resolved_i8_dot_kernel`. If a flag were ever set on
/// a CPU lacking the instructions (e.g. a forced override), running them would be
/// undefined behaviour; the runtime check makes the dispatch sound on its own.
/// AVX-512BW is checked as well because the AVX-512 path uses byte-mask compares and
/// blends.
fn resolve_i8_dot_kernel() -> I8DotKernel {
    // Only read on architectures that have SIMD kernels.
    #[cfg_attr(
        not(any(target_arch = "x86_64", target_arch = "aarch64")),
        allow(unused_variables)
    )]
    let config = simd_config();

    #[cfg(target_arch = "aarch64")]
    {
        if config.neon_enabled && std::arch::is_aarch64_feature_detected!("neon") {
            return dot_product_i8_neon_kernel;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        #[cfg(feature = "avx512")]
        {
            if config.avx512vnni_enabled
                && std::arch::is_x86_feature_detected!("avx512f")
                && std::arch::is_x86_feature_detected!("avx512bw")
                && std::arch::is_x86_feature_detected!("avx512vnni")
            {
                return dot_product_i8_avx512vnni_kernel;
            }
        }
        if config.avx2_enabled && std::arch::is_x86_feature_detected!("avx2") {
            return dot_product_i8_avx2_kernel;
        }
    }

    dot_product_i8_scalar_kernel
}
/// Safe `I8DotKernel` entry point for the NEON kernel.
///
/// This function is reachable through the public `resolved_i8_dot_kernel` fn pointer,
/// so it must be sound for *any* pair of slices. The unrolled kernel loads from `b` at
/// offsets derived from `a.len()`, so both slices are first truncated to their common
/// length. This matches the scalar kernel, which `zip`s the slices and ignores the
/// longer tail.
#[cfg(target_arch = "aarch64")]
fn dot_product_i8_neon_kernel(a: &[i8], b: &[i8]) -> f32 {
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    // SAFETY: `a` and `b` now have equal length, so every vector load is in bounds;
    // the dispatcher only selects this kernel when NEON is reported as enabled.
    unsafe { dot_product_i8_neon_unrolled(a, b) }
}
/// Safe `I8DotKernel` entry point for the AVX-512 VNNI kernel.
///
/// This function is reachable through the public `resolved_i8_dot_kernel` fn pointer,
/// so it must be sound for *any* pair of slices. The SIMD kernel loads from `b` at
/// offsets derived from `a.len()`, so both slices are first truncated to their common
/// length, matching the scalar kernel's `zip` semantics. Inputs must not contain
/// `-128`; this is checked only in debug builds.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
fn dot_product_i8_avx512vnni_kernel(a: &[i8], b: &[i8]) -> f32 {
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    // SAFETY: `a` and `b` now have equal length, so every vector load is in bounds;
    // the dispatcher only selects this kernel when AVX-512 VNNI support is reported.
    unsafe { dot_product_i8_avx512vnni(a, b) }
}
/// Safe `I8DotKernel` entry point for the AVX2 kernel.
///
/// This function is reachable through the public `resolved_i8_dot_kernel` fn pointer,
/// so it must be sound for *any* pair of slices. The SIMD kernel loads from `b` at
/// offsets derived from `a.len()`, so both slices are first truncated to their common
/// length, matching the scalar kernel's `zip` semantics. Inputs must not contain
/// `-128`; this is checked only in debug builds.
#[cfg(target_arch = "x86_64")]
fn dot_product_i8_avx2_kernel(a: &[i8], b: &[i8]) -> f32 {
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    // SAFETY: `a` and `b` now have equal length, so every vector load is in bounds;
    // the dispatcher only selects this kernel when AVX2 support is reported.
    unsafe { dot_product_i8_avx2_unrolled(a, b) }
}
/// Portable fallback kernel: exact integer dot product of `a` and `b`, as `f32`.
///
/// Elements are paired with `zip`, so if the slices differ in length the tail of the
/// longer one is ignored. Products and the running total are `i32`. Unlike the x86
/// SIMD kernels, this one is exact for `-128` as well.
fn dot_product_i8_scalar_kernel(a: &[i8], b: &[i8]) -> f32 {
    let total: i32 = a
        .iter()
        .zip(b)
        .fold(0i32, |acc, (&x, &y)| acc + i32::from(x) * i32::from(y));
    total as f32
}
/// Integer dot product of two quantized buffers via the runtime-selected kernel.
///
/// Returns `0.0` instead of panicking when the slices differ in length. The result is
/// the raw integer sum converted to `f32`; `dot_product_i8` divides it by the product
/// of the two quantization scales.
#[inline]
pub fn dot_product_i8_raw(a: &[i8], b: &[i8]) -> f32 {
    if a.len() == b.len() {
        let kernel = resolved_i8_dot_kernel();
        kernel(a, b)
    } else {
        0.0
    }
}
#[cfg(test)]
mod simd_parity_tests {
    use super::*;

    /// Dimensions chosen to hit every code path:
    /// - empty input;
    /// - scalar-only remainders;
    /// - exact multiples of each kernel's unrolled width (NEON 64, AVX2 128, AVX-512 256);
    /// - "full chunks + leftover" mixes, including NEON's 16-wide tail loop.
    const DIMS: [usize; 16] = [
        0, 1, 7, 16, 63, 64, 65, 100, 128, 129, 255, 256, 300, 384, 768, 1000,
    ];

    /// Deterministic pseudo-random vector with components in `[-1, 1]`.
    fn gen_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    /// Two quantized vectors of length `dim` derived from distinct seeds.
    fn quantized_pair(dim: usize, seed_a: u64, seed_b: u64) -> (QuantizedVector, QuantizedVector) {
        (
            QuantizedVector::from_f32(&gen_vec(dim, seed_a + dim as u64)),
            QuantizedVector::from_f32(&gen_vec(dim, seed_b + dim as u64)),
        )
    }

    /// Reference result. Sums here stay below 2^24, so the f32 conversion is exact and
    /// kernels can be compared with `assert_eq!`.
    fn scalar_dot(a: &[i8], b: &[i8]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| x as i32 * y as i32)
            .sum::<i32>() as f32
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_i8_neon_scalar_parity() {
        for dim in DIMS {
            let (a_q, b_q) = quantized_pair(dim, 200, 300);
            // SAFETY: both slices have length `dim`; this test targets aarch64 where the
            // kernel is compiled with NEON available.
            let neon = unsafe { dot_product_i8_neon_unrolled(&a_q.data, &b_q.data) };
            let scalar = scalar_dot(&a_q.data, &b_q.data);
            assert_eq!(neon, scalar, "NEON vs scalar i8 dot product dim={dim}");
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_i8_avx2_scalar_parity() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        for dim in DIMS {
            let (a_q, b_q) = quantized_pair(dim, 400, 500);
            // SAFETY: AVX2 support was checked above and both slices have length `dim`.
            let avx2 = unsafe { dot_product_i8_avx2_unrolled(&a_q.data, &b_q.data) };
            let scalar = scalar_dot(&a_q.data, &b_q.data);
            assert_eq!(avx2, scalar, "AVX2 vs scalar i8 dot product dim={dim}");
        }
    }

    #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
    #[test]
    fn test_i8_avx512vnni_scalar_parity() {
        let supported = std::arch::is_x86_feature_detected!("avx512f")
            && std::arch::is_x86_feature_detected!("avx512bw")
            && std::arch::is_x86_feature_detected!("avx512vnni");
        if !supported {
            return;
        }
        for dim in DIMS {
            let (a_q, b_q) = quantized_pair(dim, 600, 700);
            // SAFETY: AVX-512F/BW/VNNI support was checked above; both slices have length `dim`.
            let avx512 = unsafe { dot_product_i8_avx512vnni(&a_q.data, &b_q.data) };
            let scalar = scalar_dot(&a_q.data, &b_q.data);
            assert_eq!(avx512, scalar, "AVX-512 VNNI vs scalar i8 dot product dim={dim}");
        }
    }

    /// Whatever kernel the dispatcher picks on this machine must match the reference.
    #[test]
    fn test_resolved_kernel_scalar_parity() {
        let kernel = resolved_i8_dot_kernel();
        for dim in DIMS {
            let (a_q, b_q) = quantized_pair(dim, 800, 900);
            let fast = kernel(&a_q.data, &b_q.data);
            let scalar = scalar_dot(&a_q.data, &b_q.data);
            assert_eq!(fast, scalar, "resolved kernel vs scalar i8 dot product dim={dim}");
        }
    }

    /// ±127 everywhere maximises every pairwise partial sum. The saturating AVX2
    /// `maddubs` path must still be exact at the edge of the allowed range.
    #[test]
    fn test_resolved_kernel_extreme_values() {
        let kernel = resolved_i8_dot_kernel();
        for dim in [1usize, 64, 128, 129, 256, 1000] {
            let pos = vec![127i8; dim];
            let neg = vec![-127i8; dim];
            let expected = (127 * 127 * dim as i32) as f32;
            assert_eq!(kernel(&pos, &pos), expected, "pos*pos dim={dim}");
            assert_eq!(kernel(&neg, &neg), expected, "neg*neg dim={dim}");
            assert_eq!(kernel(&pos, &neg), -expected, "pos*neg dim={dim}");
        }
    }

    #[test]
    fn test_raw_rejects_mismatched_lengths() {
        assert_eq!(dot_product_i8_raw(&[1, 2, 3], &[1, 2]), 0.0);
        assert_eq!(dot_product_i8_raw(&[], &[]), 0.0);
    }
}