/// Snapshot of the SIMD instruction-set extensions relevant to the distance kernels.
///
/// Every flag is `true` when the corresponding extension is reported as usable.
/// The [`Default`] value has all flags cleared, i.e. "no SIMD support known",
/// and `PartialEq`/`Eq` allow two snapshots to be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CpuFeatures {
    /// x86-64 AVX-512 Foundation (`avx512f`).
    pub avx512f: bool,
    /// x86-64 AVX-512 Byte and Word instructions (`avx512bw`).
    pub avx512bw: bool,
    /// x86-64 AVX-512 Vector Length extensions (`avx512vl`).
    pub avx512vl: bool,
    /// x86-64 AVX-512 Vector Byte Manipulation Instructions (`avx512vbmi`).
    pub avx512vbmi: bool,
    /// x86-64 AVX2 (256-bit integer and float SIMD).
    pub avx2: bool,
    /// x86-64 SSE4.1 (128-bit SIMD).
    pub sse41: bool,
    /// AArch64 NEON / Advanced SIMD (128-bit SIMD).
    pub neon: bool,
    /// AArch64 Scalable Vector Extension.
    pub sve: bool,
}
impl CpuFeatures {
    /// Feature set with every flag cleared; used as the base for struct-update
    /// syntax so each architecture only spells out the flags it can probe.
    const NONE: Self = Self {
        avx512f: false,
        avx512bw: false,
        avx512vl: false,
        avx512vbmi: false,
        avx2: false,
        sse41: false,
        neon: false,
        sve: false,
    };

    /// Probes the CPU the program is running on.
    ///
    /// * On `x86_64` the AVX-512, AVX2 and SSE4.1 flags come from
    ///   `is_x86_feature_detected!`; the ARM flags stay `false`.
    /// * On `aarch64` `neon` and `sve` come from
    ///   `std::arch::is_aarch64_feature_detected!` instead of being hard-coded,
    ///   so `sve` reflects the actual hardware; the x86 flags stay `false`.
    /// * On every other architecture all flags are `false`.
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            Self {
                avx512f: is_x86_feature_detected!("avx512f"),
                avx512bw: is_x86_feature_detected!("avx512bw"),
                avx512vl: is_x86_feature_detected!("avx512vl"),
                avx512vbmi: is_x86_feature_detected!("avx512vbmi"),
                avx2: is_x86_feature_detected!("avx2"),
                sse41: is_x86_feature_detected!("sse4.1"),
                ..Self::NONE
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self {
                neon: std::arch::is_aarch64_feature_detected!("neon"),
                sve: std::arch::is_aarch64_feature_detected!("sve"),
                ..Self::NONE
            }
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self::NONE
        }
    }

    /// Returns the widest SIMD level these flags allow.
    ///
    /// Priority order: AVX-512 (needs both `avx512f` and `avx512bw`), then AVX2,
    /// SSE4.1, NEON, and finally [`SimdLevel::Scalar`] when nothing is set.
    pub fn best_simd_level(&self) -> SimdLevel {
        if self.avx512f && self.avx512bw {
            SimdLevel::Avx512
        } else if self.avx2 {
            SimdLevel::Avx2
        } else if self.sse41 {
            SimdLevel::Sse41
        } else if self.neon {
            SimdLevel::Neon
        } else {
            SimdLevel::Scalar
        }
    }
}
/// Instruction-set tier a kernel executes with, ordered here from widest to narrowest.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// 512-bit AVX-512 registers.
    Avx512,
    /// 256-bit AVX2 registers.
    Avx2,
    /// 128-bit SSE4.1 registers.
    Sse41,
    /// 128-bit AArch64 NEON registers.
    Neon,
    /// No SIMD: one element at a time.
    Scalar,
}

impl SimdLevel {
    /// Register width in bytes for this level.
    ///
    /// [`SimdLevel::Scalar`] reports `1`, so the value is always non-zero.
    pub fn width_bytes(&self) -> usize {
        match self {
            SimdLevel::Avx512 => 64,
            SimdLevel::Avx2 => 32,
            SimdLevel::Sse41 => 16,
            SimdLevel::Neon => 16,
            SimdLevel::Scalar => 1,
        }
    }

    /// Number of `f32` values processed per step at this level.
    ///
    /// Always at least `1`: the scalar level handles one `f32` at a time, whereas
    /// the plain `width_bytes() / 4` quotient would be `0` for it and break any
    /// caller that strides or divides by the lane count.
    pub fn f32_elements(&self) -> usize {
        (self.width_bytes() / std::mem::size_of::<f32>()).max(1)
    }

    /// Number of `i8` values processed per step at this level (one per byte of width).
    pub fn i8_elements(&self) -> usize {
        self.width_bytes()
    }
}
/// Common interface of the distance / similarity kernels in this file.
///
/// The `Send + Sync` supertraits mean every implementor must be shareable
/// across threads, which also makes `Box<dyn DistanceKernel>` `Send + Sync`.
pub trait DistanceKernel: Send + Sync {
/// Squared Euclidean (L2) distance between `a` and `b`, i.e. the sum of
/// squared element-wise differences.
///
/// The implementations in this file `debug_assert` that both slices have the
/// same length.
fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32;
/// Dot product of `a` and `b`.
///
/// The implementations in this file `debug_assert` that both slices have the
/// same length.
fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32;
/// Dot product of two `i8` vectors, with the result returned as `i32`.
///
/// The implementations in this file `debug_assert` that both slices have the
/// same length.
fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32;
/// Squared L2 distance from `query` to each vector packed in `vectors`.
///
/// The implementations in this file treat `vectors` as consecutive rows of
/// `dim` values, process `vectors.len() / dim` rows (so `dim` must be
/// non-zero), and store the result for row `i` in `out[i]`.
fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]);
/// Dot product of `query` with each vector packed in `vectors`.
///
/// Uses the same `vectors` / `dim` / `out` layout as `l2_squared_batch_f32`
/// in the implementations in this file.
fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]);
/// The SIMD level this kernel reports for itself.
fn simd_level(&self) -> SimdLevel;
}
/// Portable kernel written with plain iterators; needs no CPU features and
/// serves as the fallback for the SIMD kernels.
pub struct ScalarKernel;

impl DistanceKernel for ScalarKernel {
    /// Sum of squared element-wise differences over the zipped slices.
    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let squared_gaps = a.iter().zip(b).map(|(&lhs, &rhs)| {
            let gap = lhs - rhs;
            gap * gap
        });
        squared_gaps.sum()
    }

    /// Sum of element-wise products over the zipped slices.
    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
        products.sum()
    }

    /// Integer dot product; every operand is widened to `i32` before multiplying.
    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
        debug_assert_eq!(a.len(), b.len());
        let products = a
            .iter()
            .zip(b)
            .map(|(&lhs, &rhs)| i32::from(lhs) * i32::from(rhs));
        products.sum()
    }

    /// Writes the squared L2 distance from `query` to row `i` of `vectors` into `out[i]`.
    ///
    /// `vectors` is read as `vectors.len() / dim` consecutive rows of `dim` values.
    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        let rows = vectors.len() / dim;
        debug_assert!(out.len() >= rows);
        for (row_idx, row) in vectors.chunks_exact(dim).enumerate() {
            out[row_idx] = self.l2_squared_f32(query, row);
        }
    }

    /// Writes the dot product of `query` with row `i` of `vectors` into `out[i]`.
    ///
    /// `vectors` is read as `vectors.len() / dim` consecutive rows of `dim` values.
    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        let rows = vectors.len() / dim;
        debug_assert!(out.len() >= rows);
        for (row_idx, row) in vectors.chunks_exact(dim).enumerate() {
            out[row_idx] = self.dot_f32(query, row);
        }
    }

    /// Always [`SimdLevel::Scalar`].
    fn simd_level(&self) -> SimdLevel {
        SimdLevel::Scalar
    }
}
/// Distance kernel for x86-64 CPUs built on 256-bit AVX2 intrinsics.
///
/// The `f32` paths also use fused multiply-add, so they take the SIMD route
/// only when both AVX2 and FMA are detected at call time; `dot_i8` needs AVX2
/// alone. When the required features are missing the call is forwarded to
/// [`ScalarKernel`].
#[cfg(target_arch = "x86_64")]
pub struct Avx2Kernel;

#[cfg(target_arch = "x86_64")]
impl DistanceKernel for Avx2Kernel {
    /// Squared L2 distance, eight `f32` lanes at a time with a scalar tail.
    ///
    /// Slices are expected to have equal length (checked in debug builds). In
    /// release builds a mismatched call is bounded by the shorter slice, the
    /// same way the scalar kernel's `zip` behaves, so no memory outside either
    /// slice is ever read.
    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The running CPU must support AVX2 and FMA.
        #[target_feature(enable = "avx2,fma")]
        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
            use std::arch::x86_64::*;
            // Bound every access by the shorter slice; the length check in the
            // caller is debug-only and must not be what keeps the loads in bounds.
            let n = a.len().min(b.len());
            let chunks = n / 8;
            // SAFETY: `i * 8 + 8 <= chunks * 8 <= n <= a.len(), b.len()`, so each
            // unaligned 8-lane load stays inside its slice; the intrinsics are
            // covered by this function's target features.
            unsafe {
                let mut sum = _mm256_setzero_ps();
                for i in 0..chunks {
                    let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
                    let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
                    let diff = _mm256_sub_ps(va, vb);
                    sum = _mm256_fmadd_ps(diff, diff, sum);
                }
                // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
                let sum128 =
                    _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
                let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
                let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
                let mut result = _mm_cvtss_f32(sum32);
                // Fewer than eight trailing elements are handled one by one.
                for (x, y) in a[chunks * 8..n].iter().zip(&b[chunks * 8..n]) {
                    let diff = x - y;
                    result += diff * diff;
                }
                result
            }
        }

        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: both features `inner` is compiled for were just detected.
            unsafe { inner(a, b) }
        } else {
            ScalarKernel.l2_squared_f32(a, b)
        }
    }

    /// Dot product, eight `f32` lanes at a time with a scalar tail.
    ///
    /// Length handling is the same as in `l2_squared_f32`.
    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The running CPU must support AVX2 and FMA.
        #[target_feature(enable = "avx2,fma")]
        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
            use std::arch::x86_64::*;
            let n = a.len().min(b.len());
            let chunks = n / 8;
            // SAFETY: `i * 8 + 8 <= n <= a.len(), b.len()` for every load below.
            unsafe {
                let mut sum = _mm256_setzero_ps();
                for i in 0..chunks {
                    let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
                    let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
                    sum = _mm256_fmadd_ps(va, vb, sum);
                }
                // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
                let sum128 =
                    _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
                let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
                let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
                let mut result = _mm_cvtss_f32(sum32);
                for (x, y) in a[chunks * 8..n].iter().zip(&b[chunks * 8..n]) {
                    result += x * y;
                }
                result
            }
        }

        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: both features `inner` is compiled for were just detected.
            unsafe { inner(a, b) }
        } else {
            ScalarKernel.dot_f32(a, b)
        }
    }

    /// Integer dot product, 32 `i8` values per iteration with a scalar tail.
    ///
    /// Length handling is the same as in `l2_squared_f32`.
    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The running CPU must support AVX2.
        #[target_feature(enable = "avx2")]
        unsafe fn inner(a: &[i8], b: &[i8]) -> i32 {
            use std::arch::x86_64::*;
            let n = a.len().min(b.len());
            let chunks = n / 32;
            // SAFETY: `i * 32 + 32 <= n <= a.len(), b.len()`, so each unaligned
            // 32-byte load stays inside its slice.
            unsafe {
                let mut sum = _mm256_setzero_si256();
                for i in 0..chunks {
                    let va = _mm256_loadu_si256(a.as_ptr().add(i * 32) as *const __m256i);
                    let vb = _mm256_loadu_si256(b.as_ptr().add(i * 32) as *const __m256i);
                    // Sign-extend each 16-byte half to i16 so `madd` can multiply
                    // and pairwise-add into i32 lanes.
                    let a_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 0));
                    let b_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 0));
                    let a_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
                    let b_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
                    let prod_lo = _mm256_madd_epi16(a_lo, b_lo);
                    let prod_hi = _mm256_madd_epi16(a_hi, b_hi);
                    sum = _mm256_add_epi32(sum, prod_lo);
                    sum = _mm256_add_epi32(sum, prod_hi);
                }
                // Horizontal reduction of the eight i32 lanes.
                let sum128 = _mm_add_epi32(
                    _mm256_extracti128_si256(sum, 0),
                    _mm256_extracti128_si256(sum, 1),
                );
                let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
                let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
                let mut result = _mm_cvtsi128_si32(sum32);
                for (&x, &y) in a[chunks * 32..n].iter().zip(&b[chunks * 32..n]) {
                    result += x as i32 * y as i32;
                }
                result
            }
        }

        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2, the only feature `inner` is compiled for, was just detected.
            unsafe { inner(a, b) }
        } else {
            ScalarKernel.dot_i8(a, b)
        }
    }

    /// Writes the squared L2 distance from `query` to row `i` of `vectors` into `out[i]`.
    ///
    /// `vectors` is read as `vectors.len() / dim` consecutive rows of `dim` values.
    ///
    /// # Panics
    /// Panics if `dim` is zero or if `out` has fewer slots than there are rows.
    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        assert!(dim > 0, "dim must be non-zero");
        let n = vectors.len() / dim;
        assert!(out.len() >= n, "out has {} slots but {} vectors were supplied", out.len(), n);
        for (slot, row) in out.iter_mut().zip(vectors.chunks_exact(dim)) {
            *slot = self.l2_squared_f32(query, row);
        }
    }

    /// Writes the dot product of `query` with row `i` of `vectors` into `out[i]`.
    ///
    /// `vectors` is read as `vectors.len() / dim` consecutive rows of `dim` values.
    ///
    /// # Panics
    /// Panics if `dim` is zero or if `out` has fewer slots than there are rows.
    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        assert!(dim > 0, "dim must be non-zero");
        let n = vectors.len() / dim;
        assert!(out.len() >= n, "out has {} slots but {} vectors were supplied", out.len(), n);
        for (slot, row) in out.iter_mut().zip(vectors.chunks_exact(dim)) {
            *slot = self.dot_f32(query, row);
        }
    }

    /// Always [`SimdLevel::Avx2`], even when a call falls back to scalar code.
    fn simd_level(&self) -> SimdLevel {
        SimdLevel::Avx2
    }
}
/// Distance kernel for AArch64 built on 128-bit NEON intrinsics.
#[cfg(target_arch = "aarch64")]
pub struct NeonKernel;

#[cfg(target_arch = "aarch64")]
impl DistanceKernel for NeonKernel {
    /// Squared L2 distance, four `f32` lanes at a time with a scalar tail.
    ///
    /// Slices are expected to have equal length (checked in debug builds). In
    /// release builds a mismatched call is bounded by the shorter slice, so no
    /// memory outside either slice is ever read.
    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The running CPU must support NEON.
        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
            use std::arch::aarch64::*;
            // Bound every access by the shorter slice; the length check in the
            // caller is debug-only and must not be what keeps the loads in bounds.
            let n = a.len().min(b.len());
            let chunks = n / 4;
            // SAFETY: `i * 4 + 4 <= chunks * 4 <= n <= a.len(), b.len()`, so each
            // 4-lane load stays inside its slice.
            unsafe {
                let mut sum = vdupq_n_f32(0.0);
                for i in 0..chunks {
                    let va = vld1q_f32(a.as_ptr().add(i * 4));
                    let vb = vld1q_f32(b.as_ptr().add(i * 4));
                    let diff = vsubq_f32(va, vb);
                    sum = vfmaq_f32(sum, diff, diff);
                }
                // Add the four lanes together, then finish the tail one by one.
                let mut result = vaddvq_f32(sum);
                for (x, y) in a[chunks * 4..n].iter().zip(&b[chunks * 4..n]) {
                    let diff = x - y;
                    result += diff * diff;
                }
                result
            }
        }

        // SAFETY: `inner` only needs NEON and bounds its own memory accesses.
        unsafe { inner(a, b) }
    }

    /// Dot product, four `f32` lanes at a time with a scalar tail.
    ///
    /// Length handling is the same as in `l2_squared_f32`.
    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The running CPU must support NEON.
        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
            use std::arch::aarch64::*;
            let n = a.len().min(b.len());
            let chunks = n / 4;
            // SAFETY: `i * 4 + 4 <= n <= a.len(), b.len()` for every load below.
            unsafe {
                let mut sum = vdupq_n_f32(0.0);
                for i in 0..chunks {
                    let va = vld1q_f32(a.as_ptr().add(i * 4));
                    let vb = vld1q_f32(b.as_ptr().add(i * 4));
                    sum = vfmaq_f32(sum, va, vb);
                }
                let mut result = vaddvq_f32(sum);
                for (x, y) in a[chunks * 4..n].iter().zip(&b[chunks * 4..n]) {
                    result += x * y;
                }
                result
            }
        }

        // SAFETY: `inner` only needs NEON and bounds its own memory accesses.
        unsafe { inner(a, b) }
    }

    /// Integer dot product, 16 `i8` values per iteration with a scalar tail.
    ///
    /// Length handling is the same as in `l2_squared_f32`.
    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The running CPU must support NEON.
        unsafe fn inner(a: &[i8], b: &[i8]) -> i32 {
            use std::arch::aarch64::*;
            let n = a.len().min(b.len());
            let chunks = n / 16;
            // SAFETY: `i * 16 + 16 <= n <= a.len(), b.len()`, so each 16-byte
            // load stays inside its slice.
            unsafe {
                let mut sum = vdupq_n_s32(0);
                for i in 0..chunks {
                    let va = vld1q_s8(a.as_ptr().add(i * 16));
                    let vb = vld1q_s8(b.as_ptr().add(i * 16));
                    // Widen i8 -> i16, then multiply i16 pairs into i32 lanes.
                    let a_lo = vmovl_s8(vget_low_s8(va));
                    let b_lo = vmovl_s8(vget_low_s8(vb));
                    let a_hi = vmovl_s8(vget_high_s8(va));
                    let b_hi = vmovl_s8(vget_high_s8(vb));
                    let prod_lo = vmull_s16(vget_low_s16(a_lo), vget_low_s16(b_lo));
                    let prod_hi = vmull_s16(vget_high_s16(a_lo), vget_high_s16(b_lo));
                    sum = vaddq_s32(sum, prod_lo);
                    sum = vaddq_s32(sum, prod_hi);
                    let prod_lo2 = vmull_s16(vget_low_s16(a_hi), vget_low_s16(b_hi));
                    let prod_hi2 = vmull_s16(vget_high_s16(a_hi), vget_high_s16(b_hi));
                    sum = vaddq_s32(sum, prod_lo2);
                    sum = vaddq_s32(sum, prod_hi2);
                }
                let mut result = vaddvq_s32(sum);
                for (&x, &y) in a[chunks * 16..n].iter().zip(&b[chunks * 16..n]) {
                    result += x as i32 * y as i32;
                }
                result
            }
        }

        // SAFETY: `inner` only needs NEON and bounds its own memory accesses.
        unsafe { inner(a, b) }
    }

    /// Writes the squared L2 distance from `query` to row `i` of `vectors` into `out[i]`.
    ///
    /// `vectors` is read as `vectors.len() / dim` consecutive rows of `dim` values.
    ///
    /// # Panics
    /// Panics if `dim` is zero or if `out` has fewer slots than there are rows.
    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        assert!(dim > 0, "dim must be non-zero");
        let n = vectors.len() / dim;
        assert!(out.len() >= n, "out has {} slots but {} vectors were supplied", out.len(), n);
        for (slot, row) in out.iter_mut().zip(vectors.chunks_exact(dim)) {
            *slot = self.l2_squared_f32(query, row);
        }
    }

    /// Writes the dot product of `query` with row `i` of `vectors` into `out[i]`.
    ///
    /// `vectors` is read as `vectors.len() / dim` consecutive rows of `dim` values.
    ///
    /// # Panics
    /// Panics if `dim` is zero or if `out` has fewer slots than there are rows.
    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        assert!(dim > 0, "dim must be non-zero");
        let n = vectors.len() / dim;
        assert!(out.len() >= n, "out has {} slots but {} vectors were supplied", out.len(), n);
        for (slot, row) in out.iter_mut().zip(vectors.chunks_exact(dim)) {
            *slot = self.dot_f32(query, row);
        }
    }

    /// Always [`SimdLevel::Neon`].
    fn simd_level(&self) -> SimdLevel {
        SimdLevel::Neon
    }
}
/// Picks a [`DistanceKernel`] implementation based on detected CPU features.
pub struct KernelDispatcher {
    /// Feature snapshot taken when the dispatcher was created.
    features: CpuFeatures,
}

impl KernelDispatcher {
    /// Creates a dispatcher from a fresh [`CpuFeatures::detect`] snapshot.
    pub fn new() -> Self {
        Self {
            features: CpuFeatures::detect(),
        }
    }

    /// Returns the fastest kernel implemented for this CPU.
    ///
    /// That is `Avx2Kernel` on x86-64 with AVX2, `NeonKernel` on AArch64 with
    /// NEON, and [`ScalarKernel`] otherwise.
    pub fn best_kernel(&self) -> Box<dyn DistanceKernel> {
        #[cfg(target_arch = "x86_64")]
        {
            if self.features.avx2 {
                return Box::new(Avx2Kernel);
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            if self.features.neon {
                return Box::new(NeonKernel);
            }
        }
        Box::new(ScalarKernel)
    }

    /// Returns the closest implemented kernel that does not exceed `level`.
    ///
    /// * `Avx512` and `Avx2` map to `Avx2Kernel` when AVX2 is available. No
    ///   AVX-512 kernel exists, so an AVX-512 request degrades to AVX2 rather
    ///   than all the way to scalar code; this keeps
    ///   `kernel_for_level(features.best_simd_level())` in line with
    ///   `best_kernel` on AVX-512 machines.
    /// * `Neon` maps to `NeonKernel` when NEON is available.
    /// * Everything else, including `Sse41` (no SSE kernel exists) and any level
    ///   the CPU lacks, maps to [`ScalarKernel`].
    ///
    /// Callers can inspect the result's `simd_level()` to see what was chosen.
    pub fn kernel_for_level(&self, level: SimdLevel) -> Box<dyn DistanceKernel> {
        match level {
            #[cfg(target_arch = "x86_64")]
            SimdLevel::Avx512 | SimdLevel::Avx2 if self.features.avx2 => Box::new(Avx2Kernel),
            #[cfg(target_arch = "aarch64")]
            SimdLevel::Neon if self.features.neon => Box::new(NeonKernel),
            _ => Box::new(ScalarKernel),
        }
    }

    /// Returns a copy of the feature snapshot this dispatcher uses.
    pub fn features(&self) -> CpuFeatures {
        self.features
    }

    /// One-line summary: the best detected SIMD level plus the `avx2` and `neon` flags.
    pub fn description(&self) -> String {
        format!(
            "SIMD: {:?}, Features: avx2={}, neon={}",
            self.features.best_simd_level(),
            self.features.avx2,
            self.features.neon,
        )
    }
}

impl Default for KernelDispatcher {
    /// Same as [`KernelDispatcher::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Brute-force top-k search over a flat buffer of vectors, using one distance kernel.
pub struct ScanOps {
    /// Kernel that evaluates every query/vector pair.
    kernel: Box<dyn DistanceKernel>,
}

impl ScanOps {
    /// Creates a scanner backed by `KernelDispatcher::new().best_kernel()`.
    pub fn new() -> Self {
        Self {
            kernel: KernelDispatcher::new().best_kernel(),
        }
    }

    /// Creates a scanner backed by the given kernel.
    pub fn with_kernel(kernel: Box<dyn DistanceKernel>) -> Self {
        Self { kernel }
    }

    /// Returns the `k` vectors closest to `query` by Euclidean distance.
    ///
    /// `vectors` is read as `vectors.len() / dim` consecutive rows of `dim`
    /// values. Each result is `(row index, distance)`, where the distance is the
    /// square root of the kernel's squared L2 value. Results are ordered by
    /// ascending distance, with ties broken by ascending row index. Fewer than
    /// `k` results are returned when there are fewer rows, and an empty vector
    /// when `dim` or `k` is zero.
    pub fn top_k_l2(
        &self,
        query: &[f32],
        vectors: &[f32],
        dim: usize,
        k: usize,
    ) -> Vec<(u32, f32)> {
        // A zero dimension has no rows to rank (and would divide by zero below).
        if dim == 0 || k == 0 {
            return Vec::new();
        }
        let n = vectors.len() / dim;
        let mut distances = vec![0.0f32; n];
        self.kernel
            .l2_squared_batch_f32(query, vectors, dim, &mut distances);
        Self::k_best(n, k, |&a, &b| distances[a].total_cmp(&distances[b]))
            .into_iter()
            // NOTE: the index is narrowed to the `u32` the return type requires.
            .map(|i| (i as u32, distances[i].sqrt()))
            .collect()
    }

    /// Returns the `k` vectors with the largest dot product against `query`.
    ///
    /// `vectors` is read as `vectors.len() / dim` consecutive rows of `dim`
    /// values. Each result is `(row index, score)`. Results are ordered by
    /// descending score, with ties broken by ascending row index. Fewer than `k`
    /// results are returned when there are fewer rows, and an empty vector when
    /// `dim` or `k` is zero.
    pub fn top_k_dot(
        &self,
        query: &[f32],
        vectors: &[f32],
        dim: usize,
        k: usize,
    ) -> Vec<(u32, f32)> {
        // A zero dimension has no rows to rank (and would divide by zero below).
        if dim == 0 || k == 0 {
            return Vec::new();
        }
        let n = vectors.len() / dim;
        let mut scores = vec![0.0f32; n];
        self.kernel.dot_batch_f32(query, vectors, dim, &mut scores);
        Self::k_best(n, k, |&a, &b| scores[b].total_cmp(&scores[a]))
            .into_iter()
            // NOTE: the index is narrowed to the `u32` the return type requires.
            .map(|i| (i as u32, scores[i]))
            .collect()
    }

    /// The SIMD level reported by the underlying kernel.
    pub fn simd_level(&self) -> SimdLevel {
        self.kernel.simd_level()
    }

    /// Returns the first `k` of the indices `0..n` in `by_score` order, sorted.
    ///
    /// Ties under `by_score` are broken by ascending index, which makes the
    /// ordering total. The output therefore equals a stable full sort followed
    /// by `take(k)`, but only the `k` selected entries are actually sorted.
    fn k_best(
        n: usize,
        k: usize,
        by_score: impl Fn(&usize, &usize) -> std::cmp::Ordering,
    ) -> Vec<usize> {
        if k == 0 {
            return Vec::new();
        }
        let ranked = |a: &usize, b: &usize| by_score(a, b).then_with(|| a.cmp(b));
        let mut indices: Vec<usize> = (0..n).collect();
        if k < n {
            // Partition so the k best entries occupy the front, then drop the rest.
            indices.select_nth_unstable_by(k - 1, &ranked);
            indices.truncate(k);
        }
        indices.sort_unstable_by(&ranked);
        indices
    }
}

impl Default for ScanOps {
    /// Same as [`ScanOps::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar squared L2: only the last coordinate differs, by exactly 1.
    #[test]
    fn test_scalar_l2() {
        let kernel = ScalarKernel;
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 2.0, 3.0, 5.0];
        let dist = kernel.l2_squared_f32(&a, &b);
        assert!((dist - 1.0).abs() < 1e-6);
    }

    /// Scalar dot product: 1 + 4 + 9 + 16 = 30.
    #[test]
    fn test_scalar_dot() {
        let kernel = ScalarKernel;
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 2.0, 3.0, 4.0];
        let dot = kernel.dot_f32(&a, &b);
        assert!((dot - 30.0).abs() < 1e-6);
    }

    /// Scalar integer dot product: 1 + 4 + 9 + 16 = 30.
    #[test]
    fn test_scalar_dot_i8() {
        let kernel = ScalarKernel;
        let a: Vec<i8> = vec![1, 2, 3, 4];
        let b: Vec<i8> = vec![1, 2, 3, 4];
        let dot = kernel.dot_i8(&a, &b);
        assert_eq!(dot, 30);
    }

    /// Whatever kernel the dispatcher picks must agree on a SIMD-friendly length.
    #[test]
    fn test_dispatcher() {
        let dispatcher = KernelDispatcher::new();
        let kernel = dispatcher.best_kernel();
        let a = vec![1.0f32; 128];
        let b = vec![2.0f32; 128];
        let l2 = kernel.l2_squared_f32(&a, &b);
        assert!((l2 - 128.0).abs() < 1e-4);
        let dot = kernel.dot_f32(&a, &b);
        assert!((dot - 256.0).abs() < 1e-4);
    }

    /// Lengths that are not a multiple of the SIMD width exercise the scalar
    /// remainder loops of the dispatched kernel. Small integer-valued inputs keep
    /// every summation order exact, so results must equal the scalar kernel's.
    #[test]
    fn test_dispatcher_remainder_lengths() {
        let kernel = KernelDispatcher::new().best_kernel();
        for len in [0usize, 1, 3, 7, 8, 9, 15, 17, 31, 33, 67] {
            let a: Vec<f32> = (0..len).map(|i| (i % 5) as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| (i % 3) as f32).collect();
            assert_eq!(
                kernel.l2_squared_f32(&a, &b),
                ScalarKernel.l2_squared_f32(&a, &b),
                "l2, len = {len}"
            );
            assert_eq!(
                kernel.dot_f32(&a, &b),
                ScalarKernel.dot_f32(&a, &b),
                "dot, len = {len}"
            );
        }
    }

    /// The dispatched kernel's integer dot product must match the scalar one,
    /// including negative values and lengths with a remainder.
    #[test]
    fn test_dispatcher_dot_i8() {
        let kernel = KernelDispatcher::new().best_kernel();
        for len in [0usize, 1, 15, 16, 31, 32, 33, 70] {
            let a: Vec<i8> = (0..len).map(|i| (i % 7) as i8 - 3).collect();
            let b: Vec<i8> = (0..len).map(|i| (i % 11) as i8 - 5).collect();
            assert_eq!(
                kernel.dot_i8(&a, &b),
                ScalarKernel.dot_i8(&a, &b),
                "dot_i8, len = {len}"
            );
        }
    }

    /// Nearest neighbour of the query is the identical first row, then the
    /// half-way third row.
    #[test]
    fn test_scan_ops() {
        let ops = ScanOps::new();
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let vectors = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.5, 0.5, 0.0, 0.0,
        ];
        let top2 = ops.top_k_l2(&query, &vectors, 4, 2);
        assert_eq!(top2.len(), 2);
        assert_eq!(top2[0].0, 0);
        assert_eq!(top2[1].0, 2);
    }

    /// Dot-product search returns the highest scores first.
    #[test]
    fn test_scan_ops_dot() {
        let ops = ScanOps::new();
        let query = vec![1.0, 0.0];
        let vectors = vec![1.0, 0.0, 3.0, 0.0, 2.0, 0.0];
        let top2 = ops.top_k_dot(&query, &vectors, 2, 2);
        assert_eq!(top2, vec![(1, 3.0), (2, 2.0)]);
    }

    /// Detection must run everywhere and yield a level with a non-zero width.
    #[test]
    fn test_cpu_features() {
        let features = CpuFeatures::detect();
        let level = features.best_simd_level();
        println!("Detected SIMD level: {:?}", level);
        assert!(level.width_bytes() > 0);
    }
}