#![allow(dead_code)]
use faer::RowRef;
use num_traits::Float;
use std::iter::Sum;
use std::sync::OnceLock;
use wide::{f32x4, f32x8, f64x2, f64x4};
#[cfg(feature = "quantised")]
use half::bf16;
#[cfg(feature = "quantised")]
use num_traits::{FromPrimitive, ToPrimitive};
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
use std::arch::aarch64::*;
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
use std::arch::x86_64::*;
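/// Distance metric used for nearest-neighbour search; `Euclidean` is the
/// default.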
#[derive(Clone, Debug, Copy, PartialEq, Default)]
pub enum Dist {
#[default]
Euclidean,
Cosine,
}
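/// Parses a distance name (case-insensitively) into a [`Dist`], returning
/// `None` for unrecognised input.
///
/// ```ignore
/// assert_eq!(parse_ann_dist("Cosine"), Some(Dist::Cosine));
/// assert_eq!(parse_ann_dist("manhattan"), None);
/// ```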
pub fn parse_ann_dist(s: &str) -> Option<Dist> {
match s.to_lowercase().as_str() {
"euclidean" => Some(Dist::Euclidean),
"cosine" => Some(Dist::Cosine),
_ => None,
}
}
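/// Widest SIMD register class available at runtime. `Sse` doubles as the
/// generic 128-bit level: on aarch64 the NEON kernels are dispatched through
/// it (see [`detect_simd_level`]).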
#[derive(Clone, Copy, Debug)]
pub enum SimdLevel {
Scalar,
Sse,
Avx2,
Avx512,
}
static SIMD_LEVEL: OnceLock<SimdLevel> = OnceLock::new();
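/// Detects the widest supported SIMD level once and caches it in
/// [`SIMD_LEVEL`]; subsequent calls are a cheap read of the cached value.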
pub fn detect_simd_level() -> SimdLevel {
*SIMD_LEVEL.get_or_init(|| {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
return SimdLevel::Avx512;
}
if is_x86_feature_detected!("avx2") {
return SimdLevel::Avx2;
}
if is_x86_feature_detected!("sse4.1") {
return SimdLevel::Sse;
}
return SimdLevel::Scalar;
}
#[cfg(target_arch = "aarch64")]
{
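// NEON (128-bit) is mandatory on aarch64, so reuse the 128-bit `Sse` level.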
SimdLevel::Sse
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
SimdLevel::Scalar
}
})
}
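/// Runtime-dispatched SIMD kernels over slices, implemented for `f32` and
/// `f64`; every method selects the widest kernel reported by
/// [`detect_simd_level`].
///
/// Contract: both slices are assumed to have the same length. The
/// vectorised paths index the second slice by the first slice's length, so
/// mismatched lengths are out of bounds inside the `unsafe` blocks (the
/// scalar fallbacks merely truncate via `zip`). Note that `euclidean_simd`
/// returns the *squared* Euclidean distance, which preserves ordering for
/// nearest-neighbour search.
///
/// A minimal usage sketch (ignored doctest: the import path depends on the
/// crate layout):
///
/// ```ignore
/// let a = [1.0_f32, 2.0, 3.0];
/// let b = [4.0_f32, 6.0, 8.0];
/// // (1-4)^2 + (2-6)^2 + (3-8)^2 = 9 + 16 + 25 = 50
/// assert_eq!(f32::euclidean_simd(&a, &b), 50.0);
/// ```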
pub trait SimdDistance: Sized + Copy {
fn euclidean_simd(a: &[Self], b: &[Self]) -> Self;
fn dot_simd(a: &[Self], b: &[Self]) -> Self;
fn subtract_simd(a: &[Self], b: &[Self]) -> Vec<Self>;
fn add_simd(a: &[Self], b: &[Self]) -> Vec<Self>;
fn add_assign_simd(dst: &mut [Self], src: &[Self]);
fn calculate_l2_norm(vec: &[Self]) -> Self;
fn calculate_l1_norm(vec: &[Self]) -> Self;
}
#[inline(always)]
fn euclidean_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(&x, &y)| {
let d = x - y;
d * d
})
.sum()
}
#[inline(always)]
fn euclidean_f32_sse(a: &[f32], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 4;
let mut acc = f32x4::ZERO;
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f32x4::from(*(a_ptr.add(offset) as *const [f32; 4]));
let vb = f32x4::from(*(b_ptr.add(offset) as *const [f32; 4]));
let diff = va - vb;
acc += diff * diff;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 4)..len {
let diff = a[i] - b[i];
sum += diff * diff;
}
sum
}
#[inline(always)]
fn euclidean_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 8;
let mut acc = f32x8::ZERO;
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 8;
let va = f32x8::from(*(a_ptr.add(offset) as *const [f32; 8]));
let vb = f32x8::from(*(b_ptr.add(offset) as *const [f32; 8]));
let diff = va - vb;
acc += diff * diff;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 8)..len {
let diff = a[i] - b[i];
sum += diff * diff;
}
sum
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn euclidean_f32_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let va = _mm512_loadu_ps(a.as_ptr().add(i * 16));
let vb = _mm512_loadu_ps(b.as_ptr().add(i * 16));
let diff = _mm512_sub_ps(va, vb);
acc = _mm512_fmadd_ps(diff, diff, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
let diff = a[i] - b[i];
sum += diff * diff;
}
sum
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn euclidean_f32_avx512(a: &[f32], b: &[f32]) -> f32 {
euclidean_f32_avx2(a, b)
}
#[inline(always)]
fn euclidean_f64_scalar(a: &[f64], b: &[f64]) -> f64 {
a.iter()
.zip(b.iter())
.map(|(&x, &y)| {
let d = x - y;
d * d
})
.sum()
}
#[inline(always)]
fn euclidean_f64_sse(a: &[f64], b: &[f64]) -> f64 {
let len = a.len();
let chunks = len / 2;
let mut acc = f64x2::ZERO;
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 2;
let va = f64x2::from(*(a_ptr.add(offset) as *const [f64; 2]));
let vb = f64x2::from(*(b_ptr.add(offset) as *const [f64; 2]));
let diff = va - vb;
acc += diff * diff;
}
}
let mut sum = acc.reduce_add();
if len % 2 == 1 {
let diff = a[len - 1] - b[len - 1];
sum += diff * diff;
}
sum
}
#[inline(always)]
fn euclidean_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
let len = a.len();
let chunks = len / 4;
let mut acc = f64x4::ZERO;
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f64x4::from(*(a_ptr.add(offset) as *const [f64; 4]));
let vb = f64x4::from(*(b_ptr.add(offset) as *const [f64; 4]));
let diff = va - vb;
acc += diff * diff;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 4)..len {
let diff = a[i] - b[i];
sum += diff * diff;
}
sum
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn euclidean_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm512_setzero_pd();
for i in 0..chunks {
let va = _mm512_loadu_pd(a.as_ptr().add(i * 8));
let vb = _mm512_loadu_pd(b.as_ptr().add(i * 8));
let diff = _mm512_sub_pd(va, vb);
acc = _mm512_fmadd_pd(diff, diff, acc);
}
let mut sum = _mm512_reduce_add_pd(acc);
for i in (chunks * 8)..len {
let diff = a[i] - b[i];
sum += diff * diff;
}
sum
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn euclidean_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
euclidean_f64_avx2(a, b)
}
#[inline(always)]
fn dot_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
#[inline(always)]
fn dot_f32_sse(a: &[f32], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 4;
let mut acc = f32x4::ZERO;
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f32x4::from(*(a_ptr.add(offset) as *const [f32; 4]));
let vb = f32x4::from(*(b_ptr.add(offset) as *const [f32; 4]));
acc += va * vb;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 4)..len {
sum += a[i] * b[i];
}
sum
}
#[inline(always)]
fn dot_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 8;
let mut acc = f32x8::ZERO;
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 8;
let va = f32x8::from(*(a_ptr.add(offset) as *const [f32; 8]));
let vb = f32x8::from(*(b_ptr.add(offset) as *const [f32; 8]));
acc += va * vb;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 8)..len {
sum += a[i] * b[i];
}
sum
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn dot_f32_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let va = _mm512_loadu_ps(a.as_ptr().add(i * 16));
let vb = _mm512_loadu_ps(b.as_ptr().add(i * 16));
acc = _mm512_fmadd_ps(va, vb, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
sum += a[i] * b[i];
}
sum
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn dot_f32_avx512(a: &[f32], b: &[f32]) -> f32 {
dot_f32_avx2(a, b)
}
#[inline(always)]
fn dot_f64_scalar(a: &[f64], b: &[f64]) -> f64 {
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
#[inline(always)]
fn dot_f64_sse(a: &[f64], b: &[f64]) -> f64 {
let len = a.len();
let chunks = len / 2;
let mut acc = f64x2::ZERO;
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 2;
let va = f64x2::from(*(a_ptr.add(offset) as *const [f64; 2]));
let vb = f64x2::from(*(b_ptr.add(offset) as *const [f64; 2]));
acc += va * vb;
}
}
let mut sum = acc.reduce_add();
if len % 2 == 1 {
sum += a[len - 1] * b[len - 1];
}
sum
}
#[inline(always)]
fn dot_f64_avx2(a: &[f64], b: &[f64]) -> f64 {
let len = a.len();
let chunks = len / 4;
let mut acc = f64x4::ZERO;
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f64x4::from(*(a_ptr.add(offset) as *const [f64; 4]));
let vb = f64x4::from(*(b_ptr.add(offset) as *const [f64; 4]));
acc += va * vb;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 4)..len {
sum += a[i] * b[i];
}
sum
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm512_setzero_pd();
for i in 0..chunks {
let va = _mm512_loadu_pd(a.as_ptr().add(i * 8));
let vb = _mm512_loadu_pd(b.as_ptr().add(i * 8));
acc = _mm512_fmadd_pd(va, vb, acc);
}
let mut sum = _mm512_reduce_add_pd(acc);
for i in (chunks * 8)..len {
sum += a[i] * b[i];
}
sum
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn dot_f64_avx512(a: &[f64], b: &[f64]) -> f64 {
dot_f64_avx2(a, b)
}
#[inline(always)]
fn subtract_f32_scalar(a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect()
}
#[inline(always)]
fn subtract_f32_sse(a: &[f32], b: &[f32]) -> Vec<f32> {
let len = a.len();
let chunks = len / 4;
let mut result = Vec::with_capacity(len);
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let result_ptr: *mut f32 = result.as_mut_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f32x4::from(*(a_ptr.add(offset) as *const [f32; 4]));
let vb = f32x4::from(*(b_ptr.add(offset) as *const [f32; 4]));
let diff = va - vb;
*(result_ptr.add(offset) as *mut [f32; 4]) = diff.into();
}
for i in (chunks * 4)..len {
*result_ptr.add(i) = a[i] - b[i];
}
result.set_len(len);
}
result
}
#[inline(always)]
fn subtract_f32_avx2(a: &[f32], b: &[f32]) -> Vec<f32> {
let len = a.len();
let chunks = len / 8;
let mut result = Vec::with_capacity(len);
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let result_ptr: *mut f32 = result.as_mut_ptr();
for i in 0..chunks {
let offset = i * 8;
let va = f32x8::from(*(a_ptr.add(offset) as *const [f32; 8]));
let vb = f32x8::from(*(b_ptr.add(offset) as *const [f32; 8]));
let diff = va - vb;
*(result_ptr.add(offset) as *mut [f32; 8]) = diff.into();
}
for i in (chunks * 8)..len {
*result_ptr.add(i) = a[i] - b[i];
}
result.set_len(len);
}
result
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn subtract_f32_avx512(a: &[f32], b: &[f32]) -> Vec<f32> {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 16;
let mut result = Vec::with_capacity(len);
unsafe {
let result_ptr: *mut f32 = result.as_mut_ptr();
for i in 0..chunks {
let va = _mm512_loadu_ps(a.as_ptr().add(i * 16));
let vb = _mm512_loadu_ps(b.as_ptr().add(i * 16));
let diff = _mm512_sub_ps(va, vb);
_mm512_storeu_ps(result_ptr.add(i * 16), diff);
}
for i in (chunks * 16)..len {
*result_ptr.add(i) = a[i] - b[i];
}
result.set_len(len);
}
result
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn subtract_f32_avx512(a: &[f32], b: &[f32]) -> Vec<f32> {
subtract_f32_avx2(a, b)
}
#[inline(always)]
fn subtract_f64_scalar(a: &[f64], b: &[f64]) -> Vec<f64> {
a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect()
}
#[inline(always)]
fn subtract_f64_sse(a: &[f64], b: &[f64]) -> Vec<f64> {
let len = a.len();
let chunks = len / 2;
let mut result = Vec::with_capacity(len);
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let result_ptr: *mut f64 = result.as_mut_ptr();
for i in 0..chunks {
let offset = i * 2;
let va = f64x2::from(*(a_ptr.add(offset) as *const [f64; 2]));
let vb = f64x2::from(*(b_ptr.add(offset) as *const [f64; 2]));
let diff = va - vb;
*(result_ptr.add(offset) as *mut [f64; 2]) = diff.into();
}
if len % 2 == 1 {
*result_ptr.add(len - 1) = a[len - 1] - b[len - 1];
}
result.set_len(len);
}
result
}
#[inline(always)]
fn subtract_f64_avx2(a: &[f64], b: &[f64]) -> Vec<f64> {
let len = a.len();
let chunks = len / 4;
let mut result = Vec::with_capacity(len);
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let result_ptr: *mut f64 = result.as_mut_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f64x4::from(*(a_ptr.add(offset) as *const [f64; 4]));
let vb = f64x4::from(*(b_ptr.add(offset) as *const [f64; 4]));
let diff = va - vb;
*(result_ptr.add(offset) as *mut [f64; 4]) = diff.into();
}
for i in (chunks * 4)..len {
*result_ptr.add(i) = a[i] - b[i];
}
result.set_len(len);
}
result
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn subtract_f64_avx512(a: &[f64], b: &[f64]) -> Vec<f64> {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 8;
let mut result = Vec::with_capacity(len);
unsafe {
let result_ptr: *mut f64 = result.as_mut_ptr();
for i in 0..chunks {
let va = _mm512_loadu_pd(a.as_ptr().add(i * 8));
let vb = _mm512_loadu_pd(b.as_ptr().add(i * 8));
let diff = _mm512_sub_pd(va, vb);
_mm512_storeu_pd(result_ptr.add(i * 8), diff);
}
for i in (chunks * 8)..len {
*result_ptr.add(i) = a[i] - b[i];
}
result.set_len(len);
}
result
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn subtract_f64_avx512(a: &[f64], b: &[f64]) -> Vec<f64> {
subtract_f64_avx2(a, b)
}
#[inline(always)]
fn add_f32_scalar(a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect()
}
#[inline(always)]
fn add_f32_sse(a: &[f32], b: &[f32]) -> Vec<f32> {
let len = a.len();
let chunks = len / 4;
let mut result = Vec::with_capacity(len);
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let result_ptr: *mut f32 = result.as_mut_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f32x4::from(*(a_ptr.add(offset) as *const [f32; 4]));
let vb = f32x4::from(*(b_ptr.add(offset) as *const [f32; 4]));
let sum = va + vb;
*(result_ptr.add(offset) as *mut [f32; 4]) = sum.into();
}
for i in (chunks * 4)..len {
*result_ptr.add(i) = a[i] + b[i];
}
result.set_len(len);
}
result
}
#[inline(always)]
fn add_f32_avx2(a: &[f32], b: &[f32]) -> Vec<f32> {
let len = a.len();
let chunks = len / 8;
let mut result = Vec::with_capacity(len);
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let result_ptr: *mut f32 = result.as_mut_ptr();
for i in 0..chunks {
let offset = i * 8;
let va = f32x8::from(*(a_ptr.add(offset) as *const [f32; 8]));
let vb = f32x8::from(*(b_ptr.add(offset) as *const [f32; 8]));
let sum = va + vb;
*(result_ptr.add(offset) as *mut [f32; 8]) = sum.into();
}
for i in (chunks * 8)..len {
*result_ptr.add(i) = a[i] + b[i];
}
result.set_len(len);
}
result
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn add_f32_avx512(a: &[f32], b: &[f32]) -> Vec<f32> {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 16;
let mut result = Vec::with_capacity(len);
unsafe {
let result_ptr: *mut f32 = result.as_mut_ptr();
for i in 0..chunks {
let va = _mm512_loadu_ps(a.as_ptr().add(i * 16));
let vb = _mm512_loadu_ps(b.as_ptr().add(i * 16));
let sum = _mm512_add_ps(va, vb);
_mm512_storeu_ps(result_ptr.add(i * 16), sum);
}
for i in (chunks * 16)..len {
*result_ptr.add(i) = a[i] + b[i];
}
result.set_len(len);
}
result
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn add_f32_avx512(a: &[f32], b: &[f32]) -> Vec<f32> {
add_f32_avx2(a, b)
}
#[inline(always)]
fn add_f64_scalar(a: &[f64], b: &[f64]) -> Vec<f64> {
a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect()
}
#[inline(always)]
fn add_f64_sse(a: &[f64], b: &[f64]) -> Vec<f64> {
let len = a.len();
let chunks = len / 2;
let mut result = Vec::with_capacity(len);
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let result_ptr: *mut f64 = result.as_mut_ptr();
for i in 0..chunks {
let offset = i * 2;
let va = f64x2::from(*(a_ptr.add(offset) as *const [f64; 2]));
let vb = f64x2::from(*(b_ptr.add(offset) as *const [f64; 2]));
let sum = va + vb;
*(result_ptr.add(offset) as *mut [f64; 2]) = sum.into();
}
if len % 2 == 1 {
*result_ptr.add(len - 1) = a[len - 1] + b[len - 1];
}
result.set_len(len);
}
result
}
#[inline(always)]
fn add_f64_avx2(a: &[f64], b: &[f64]) -> Vec<f64> {
let len = a.len();
let chunks = len / 4;
let mut result = Vec::with_capacity(len);
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let result_ptr: *mut f64 = result.as_mut_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f64x4::from(*(a_ptr.add(offset) as *const [f64; 4]));
let vb = f64x4::from(*(b_ptr.add(offset) as *const [f64; 4]));
let sum = va + vb;
*(result_ptr.add(offset) as *mut [f64; 4]) = sum.into();
}
for i in (chunks * 4)..len {
*result_ptr.add(i) = a[i] + b[i];
}
result.set_len(len);
}
result
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn add_f64_avx512(a: &[f64], b: &[f64]) -> Vec<f64> {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 8;
let mut result = Vec::with_capacity(len);
unsafe {
let result_ptr: *mut f64 = result.as_mut_ptr();
for i in 0..chunks {
let va = _mm512_loadu_pd(a.as_ptr().add(i * 8));
let vb = _mm512_loadu_pd(b.as_ptr().add(i * 8));
let sum = _mm512_add_pd(va, vb);
_mm512_storeu_pd(result_ptr.add(i * 8), sum);
}
for i in (chunks * 8)..len {
*result_ptr.add(i) = a[i] + b[i];
}
result.set_len(len);
}
result
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn add_f64_avx512(a: &[f64], b: &[f64]) -> Vec<f64> {
add_f64_avx2(a, b)
}
#[inline(always)]
fn add_assign_f32_scalar(dst: &mut [f32], src: &[f32]) {
dst.iter_mut().zip(src.iter()).for_each(|(d, &s)| *d += s);
}
#[inline(always)]
fn add_assign_f32_sse(dst: &mut [f32], src: &[f32]) {
let len = dst.len();
let chunks = len / 4;
unsafe {
let dst_ptr = dst.as_mut_ptr();
let src_ptr = src.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let vd = f32x4::from(*(dst_ptr.add(offset) as *const [f32; 4]));
let vs = f32x4::from(*(src_ptr.add(offset) as *const [f32; 4]));
let sum = vd + vs;
*(dst_ptr.add(offset) as *mut [f32; 4]) = sum.into();
}
for i in (chunks * 4)..len {
*dst_ptr.add(i) += *src_ptr.add(i);
}
}
}
#[inline(always)]
fn add_assign_f32_avx2(dst: &mut [f32], src: &[f32]) {
let len = dst.len();
let chunks = len / 8;
unsafe {
let dst_ptr = dst.as_mut_ptr();
let src_ptr = src.as_ptr();
for i in 0..chunks {
let offset = i * 8;
let vd = f32x8::from(*(dst_ptr.add(offset) as *const [f32; 8]));
let vs = f32x8::from(*(src_ptr.add(offset) as *const [f32; 8]));
let sum = vd + vs;
*(dst_ptr.add(offset) as *mut [f32; 8]) = sum.into();
}
for i in (chunks * 8)..len {
*dst_ptr.add(i) += *src_ptr.add(i);
}
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn add_assign_f32_avx512(dst: &mut [f32], src: &[f32]) {
use std::arch::x86_64::*;
let len = dst.len();
let chunks = len / 16;
unsafe {
let dst_ptr = dst.as_mut_ptr();
let src_ptr = src.as_ptr();
for i in 0..chunks {
let offset = i * 16;
let vd = _mm512_loadu_ps(dst_ptr.add(offset));
let vs = _mm512_loadu_ps(src_ptr.add(offset));
let sum = _mm512_add_ps(vd, vs);
_mm512_storeu_ps(dst_ptr.add(offset), sum);
}
for i in (chunks * 16)..len {
*dst_ptr.add(i) += *src_ptr.add(i);
}
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn add_assign_f32_avx512(dst: &mut [f32], src: &[f32]) {
add_assign_f32_avx2(dst, src);
}
#[inline(always)]
fn add_assign_f64_scalar(dst: &mut [f64], src: &[f64]) {
dst.iter_mut().zip(src.iter()).for_each(|(d, &s)| *d += s);
}
#[inline(always)]
fn add_assign_f64_sse(dst: &mut [f64], src: &[f64]) {
let len = dst.len();
let chunks = len / 2;
unsafe {
let dst_ptr = dst.as_mut_ptr();
let src_ptr = src.as_ptr();
for i in 0..chunks {
let offset = i * 2;
let vd = f64x2::from(*(dst_ptr.add(offset) as *const [f64; 2]));
let vs = f64x2::from(*(src_ptr.add(offset) as *const [f64; 2]));
let sum = vd + vs;
*(dst_ptr.add(offset) as *mut [f64; 2]) = sum.into();
}
if len % 2 == 1 {
*dst_ptr.add(len - 1) += *src_ptr.add(len - 1);
}
}
}
#[inline(always)]
fn add_assign_f64_avx2(dst: &mut [f64], src: &[f64]) {
let len = dst.len();
let chunks = len / 4;
unsafe {
let dst_ptr = dst.as_mut_ptr();
let src_ptr = src.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let vd = f64x4::from(*(dst_ptr.add(offset) as *const [f64; 4]));
let vs = f64x4::from(*(src_ptr.add(offset) as *const [f64; 4]));
let sum = vd + vs;
*(dst_ptr.add(offset) as *mut [f64; 4]) = sum.into();
}
for i in (chunks * 4)..len {
*dst_ptr.add(i) += *src_ptr.add(i);
}
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn add_assign_f64_avx512(dst: &mut [f64], src: &[f64]) {
use std::arch::x86_64::*;
let len = dst.len();
let chunks = len / 8;
unsafe {
let dst_ptr = dst.as_mut_ptr();
let src_ptr = src.as_ptr();
for i in 0..chunks {
let offset = i * 8;
let vd = _mm512_loadu_pd(dst_ptr.add(offset));
let vs = _mm512_loadu_pd(src_ptr.add(offset));
let sum = _mm512_add_pd(vd, vs);
_mm512_storeu_pd(dst_ptr.add(offset), sum);
}
for i in (chunks * 8)..len {
*dst_ptr.add(i) += *src_ptr.add(i);
}
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn add_assign_f64_avx512(dst: &mut [f64], src: &[f64]) {
add_assign_f64_avx2(dst, src);
}
#[inline(always)]
fn compute_l2_norm_f32_scalar(vec: &[f32]) -> f32 {
let mut sum = 0.0_f32;
for &x in vec {
sum += x * x;
}
sum.sqrt()
}
#[inline(always)]
fn compute_l2_norm_f32_sse(vec: &[f32]) -> f32 {
let len = vec.len();
let chunks = len / 4;
let mut acc = f32x4::ZERO;
unsafe {
let vec_ptr = vec.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let v_chunk = f32x4::from(*(vec_ptr.add(offset) as *const [f32; 4]));
acc += v_chunk * v_chunk;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 4)..len {
sum += vec[i] * vec[i];
}
sum.sqrt()
}
#[inline(always)]
fn compute_l2_norm_f32_avx2(vec: &[f32]) -> f32 {
let len = vec.len();
let chunks = len / 8;
let mut acc = f32x8::ZERO;
unsafe {
let vec_ptr = vec.as_ptr();
for i in 0..chunks {
let offset = i * 8;
let v_chunk = f32x8::from(*(vec_ptr.add(offset) as *const [f32; 8]));
acc += v_chunk * v_chunk;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 8)..len {
sum += vec[i] * vec[i];
}
sum.sqrt()
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn compute_l2_norm_f32_avx512(vec: &[f32]) -> f32 {
use std::arch::x86_64::*;
let len = vec.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let v_chunk = _mm512_loadu_ps(vec.as_ptr().add(i * 16));
acc = _mm512_fmadd_ps(v_chunk, v_chunk, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
sum += vec[i] * vec[i];
}
sum.sqrt()
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn compute_l2_norm_f32_avx512(vec: &[f32]) -> f32 {
compute_l2_norm_f32_avx2(vec)
}
#[inline(always)]
fn compute_l2_norm_f64_scalar(vec: &[f64]) -> f64 {
let mut sum = 0.0_f64;
for &x in vec {
sum += x * x;
}
sum.sqrt()
}
#[inline(always)]
fn compute_l2_norm_f64_sse(vec: &[f64]) -> f64 {
let len = vec.len();
let chunks = len / 2;
let mut acc = f64x2::ZERO;
unsafe {
let vec_ptr = vec.as_ptr();
for i in 0..chunks {
let offset = i * 2;
let v_chunk = f64x2::from(*(vec_ptr.add(offset) as *const [f64; 2]));
acc += v_chunk * v_chunk;
}
}
let mut sum = acc.reduce_add();
if len % 2 == 1 {
sum += vec[len - 1] * vec[len - 1];
}
sum.sqrt()
}
#[inline(always)]
fn compute_l2_norm_f64_avx2(vec: &[f64]) -> f64 {
let len = vec.len();
let chunks = len / 4;
let mut acc = f64x4::ZERO;
unsafe {
let vec_ptr = vec.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let v_chunk = f64x4::from(*(vec_ptr.add(offset) as *const [f64; 4]));
acc += v_chunk * v_chunk;
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 4)..len {
sum += vec[i] * vec[i];
}
sum.sqrt()
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn compute_l2_norm_f64_avx512(vec: &[f64]) -> f64 {
use std::arch::x86_64::*;
let len = vec.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm512_setzero_pd();
for i in 0..chunks {
let v_chunk = _mm512_loadu_pd(vec.as_ptr().add(i * 8));
acc = _mm512_fmadd_pd(v_chunk, v_chunk, acc);
}
let mut sum = _mm512_reduce_add_pd(acc);
for i in (chunks * 8)..len {
sum += vec[i] * vec[i];
}
sum.sqrt()
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn compute_l2_norm_f64_avx512(vec: &[f64]) -> f64 {
compute_l2_norm_f64_avx2(vec)
}
#[inline(always)]
fn compute_l1_norm_f32_scalar(a: &[f32]) -> f32 {
a.iter().map(|&x| x.abs()).sum()
}
#[inline(always)]
fn compute_l1_norm_f32_sse(a: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 4;
let mut acc = f32x4::ZERO;
unsafe {
let a_ptr = a.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f32x4::from(*(a_ptr.add(offset) as *const [f32; 4]));
acc += va.abs();
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 4)..len {
sum += a[i].abs();
}
sum
}
#[inline(always)]
fn compute_l1_norm_f32_avx2(a: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 8;
let mut acc = f32x8::ZERO;
unsafe {
let a_ptr = a.as_ptr();
for i in 0..chunks {
let offset = i * 8;
let va = f32x8::from(*(a_ptr.add(offset) as *const [f32; 8]));
acc += va.abs();
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 8)..len {
sum += a[i].abs();
}
sum
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn compute_l1_norm_f32_avx512(a: &[f32]) -> f32 {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let va = _mm512_loadu_ps(a.as_ptr().add(i * 16));
let abs_va = _mm512_abs_ps(va);
acc = _mm512_add_ps(acc, abs_va);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
sum += a[i].abs();
}
sum
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn compute_l1_norm_f32_avx512(a: &[f32]) -> f32 {
compute_l1_norm_f32_avx2(a)
}
#[inline(always)]
fn compute_l1_norm_f64_scalar(a: &[f64]) -> f64 {
a.iter().map(|&x| x.abs()).sum()
}
#[inline(always)]
fn compute_l1_norm_f64_sse(a: &[f64]) -> f64 {
let len = a.len();
let chunks = len / 2;
let mut acc = f64x2::ZERO;
unsafe {
let a_ptr = a.as_ptr();
for i in 0..chunks {
let offset = i * 2;
let va = f64x2::from(*(a_ptr.add(offset) as *const [f64; 2]));
acc += va.abs();
}
}
let mut sum = acc.reduce_add();
if len % 2 == 1 {
sum += a[len - 1].abs();
}
sum
}
#[inline(always)]
fn compute_l1_norm_f64_avx2(a: &[f64]) -> f64 {
let len = a.len();
let chunks = len / 4;
let mut acc = f64x4::ZERO;
unsafe {
let a_ptr = a.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = f64x4::from(*(a_ptr.add(offset) as *const [f64; 4]));
acc += va.abs();
}
}
let mut sum = acc.reduce_add();
for i in (chunks * 4)..len {
sum += a[i].abs();
}
sum
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn compute_l1_norm_f64_avx512(a: &[f64]) -> f64 {
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm512_setzero_pd();
for i in 0..chunks {
let va = _mm512_loadu_pd(a.as_ptr().add(i * 8));
let abs_va = _mm512_abs_pd(va);
acc = _mm512_add_pd(acc, abs_va);
}
let mut sum = _mm512_reduce_add_pd(acc);
for i in (chunks * 8)..len {
sum += a[i].abs();
}
sum
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn compute_l1_norm_f64_avx512(a: &[f64]) -> f64 {
compute_l1_norm_f64_avx2(a)
}
impl SimdDistance for f32 {
#[inline]
fn euclidean_simd(a: &[f32], b: &[f32]) -> f32 {
match detect_simd_level() {
SimdLevel::Avx512 => euclidean_f32_avx512(a, b),
SimdLevel::Avx2 => euclidean_f32_avx2(a, b),
SimdLevel::Sse => euclidean_f32_sse(a, b),
SimdLevel::Scalar => euclidean_f32_scalar(a, b),
}
}
#[inline]
fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
match detect_simd_level() {
SimdLevel::Avx512 => dot_f32_avx512(a, b),
SimdLevel::Avx2 => dot_f32_avx2(a, b),
SimdLevel::Sse => dot_f32_sse(a, b),
SimdLevel::Scalar => dot_f32_scalar(a, b),
}
}
#[inline]
fn subtract_simd(a: &[f32], b: &[f32]) -> Vec<f32> {
match detect_simd_level() {
SimdLevel::Avx512 => subtract_f32_avx512(a, b),
SimdLevel::Avx2 => subtract_f32_avx2(a, b),
SimdLevel::Sse => subtract_f32_sse(a, b),
SimdLevel::Scalar => subtract_f32_scalar(a, b),
}
}
#[inline]
fn add_simd(a: &[Self], b: &[Self]) -> Vec<Self> {
match detect_simd_level() {
SimdLevel::Avx512 => add_f32_avx512(a, b),
SimdLevel::Avx2 => add_f32_avx2(a, b),
SimdLevel::Sse => add_f32_sse(a, b),
SimdLevel::Scalar => add_f32_scalar(a, b),
}
}
#[inline]
fn add_assign_simd(dst: &mut [Self], src: &[Self]) {
match detect_simd_level() {
SimdLevel::Avx512 => add_assign_f32_avx512(dst, src),
SimdLevel::Avx2 => add_assign_f32_avx2(dst, src),
SimdLevel::Sse => add_assign_f32_sse(dst, src),
SimdLevel::Scalar => add_assign_f32_scalar(dst, src),
}
}
#[inline]
fn calculate_l2_norm(vec: &[Self]) -> Self {
match detect_simd_level() {
SimdLevel::Avx512 => compute_l2_norm_f32_avx512(vec),
SimdLevel::Avx2 => compute_l2_norm_f32_avx2(vec),
SimdLevel::Sse => compute_l2_norm_f32_sse(vec),
SimdLevel::Scalar => compute_l2_norm_f32_scalar(vec),
}
}
#[inline]
fn calculate_l1_norm(vec: &[Self]) -> Self {
match detect_simd_level() {
SimdLevel::Avx512 => compute_l1_norm_f32_avx512(vec),
SimdLevel::Avx2 => compute_l1_norm_f32_avx2(vec),
SimdLevel::Sse => compute_l1_norm_f32_sse(vec),
SimdLevel::Scalar => compute_l1_norm_f32_scalar(vec),
}
}
}
impl SimdDistance for f64 {
#[inline]
fn euclidean_simd(a: &[f64], b: &[f64]) -> f64 {
match detect_simd_level() {
SimdLevel::Avx512 => euclidean_f64_avx512(a, b),
SimdLevel::Avx2 => euclidean_f64_avx2(a, b),
SimdLevel::Sse => euclidean_f64_sse(a, b),
SimdLevel::Scalar => euclidean_f64_scalar(a, b),
}
}
#[inline]
fn dot_simd(a: &[f64], b: &[f64]) -> f64 {
match detect_simd_level() {
SimdLevel::Avx512 => dot_f64_avx512(a, b),
SimdLevel::Avx2 => dot_f64_avx2(a, b),
SimdLevel::Sse => dot_f64_sse(a, b),
SimdLevel::Scalar => dot_f64_scalar(a, b),
}
}
#[inline]
fn subtract_simd(a: &[f64], b: &[f64]) -> Vec<f64> {
match detect_simd_level() {
SimdLevel::Avx512 => subtract_f64_avx512(a, b),
SimdLevel::Avx2 => subtract_f64_avx2(a, b),
SimdLevel::Sse => subtract_f64_sse(a, b),
SimdLevel::Scalar => subtract_f64_scalar(a, b),
}
}
#[inline]
fn add_simd(a: &[Self], b: &[Self]) -> Vec<Self> {
match detect_simd_level() {
SimdLevel::Avx512 => add_f64_avx512(a, b),
SimdLevel::Avx2 => add_f64_avx2(a, b),
SimdLevel::Sse => add_f64_sse(a, b),
SimdLevel::Scalar => add_f64_scalar(a, b),
}
}
#[inline]
fn add_assign_simd(dst: &mut [Self], src: &[Self]) {
match detect_simd_level() {
SimdLevel::Avx512 => add_assign_f64_avx512(dst, src),
SimdLevel::Avx2 => add_assign_f64_avx2(dst, src),
SimdLevel::Sse => add_assign_f64_sse(dst, src),
SimdLevel::Scalar => add_assign_f64_scalar(dst, src),
}
}
#[inline]
fn calculate_l2_norm(vec: &[Self]) -> Self {
match detect_simd_level() {
SimdLevel::Avx512 => compute_l2_norm_f64_avx512(vec),
SimdLevel::Avx2 => compute_l2_norm_f64_avx2(vec),
SimdLevel::Sse => compute_l2_norm_f64_sse(vec),
SimdLevel::Scalar => compute_l2_norm_f64_scalar(vec),
}
}
#[inline]
fn calculate_l1_norm(vec: &[Self]) -> Self {
match detect_simd_level() {
SimdLevel::Avx512 => compute_l1_norm_f64_avx512(vec),
SimdLevel::Avx2 => compute_l1_norm_f64_avx2(vec),
SimdLevel::Sse => compute_l1_norm_f64_sse(vec),
SimdLevel::Scalar => compute_l1_norm_f64_scalar(vec),
}
}
}
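/// Distance computations over a flat, row-major vector store: vector `i`
/// occupies `vectors_flat()[i * dim()..(i + 1) * dim()]` and `norms()[i]`
/// caches its L2 norm.
///
/// The default methods assume in-range indices and nonzero cached norms for
/// the cosine variants (there is no zero-norm guard here). As with
/// `euclidean_simd`, the Euclidean methods return the *squared* distance;
/// the cosine methods return `1 - cosine_similarity`, which lies in
/// `[0, 2]`.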
pub trait VectorDistance<T>
where
T: Float + Sum + SimdDistance,
{
fn vectors_flat(&self) -> &[T];
fn dim(&self) -> usize;
fn norms(&self) -> &[T];
#[inline(always)]
fn euclidean_distance(&self, i: usize, j: usize) -> T {
let start_i = i * self.dim();
let start_j = j * self.dim();
let vec_i = &self.vectors_flat()[start_i..start_i + self.dim()];
let vec_j = &self.vectors_flat()[start_j..start_j + self.dim()];
T::euclidean_simd(vec_i, vec_j)
}
#[inline(always)]
fn euclidean_distance_to_query(&self, internal_idx: usize, query: &[T]) -> T {
let start = internal_idx * self.dim();
let vec = &self.vectors_flat()[start..start + self.dim()];
T::euclidean_simd(vec, query)
}
#[inline(always)]
fn cosine_distance(&self, i: usize, j: usize) -> T {
let start_i = i * self.dim();
let start_j = j * self.dim();
let vec_i = &self.vectors_flat()[start_i..start_i + self.dim()];
let vec_j = &self.vectors_flat()[start_j..start_j + self.dim()];
let dot = T::dot_simd(vec_i, vec_j);
T::one() - (dot / (self.norms()[i] * self.norms()[j]))
}
#[inline(always)]
fn cosine_distance_to_query(&self, internal_idx: usize, query: &[T], query_norm: T) -> T {
let start = internal_idx * self.dim();
let vec = &self.vectors_flat()[start..start + self.dim()];
let dot = T::dot_simd(vec, query);
T::one() - (dot / (query_norm * self.norms()[internal_idx]))
}
}
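// bf16 -> f32 widening kernels. A bf16 is exactly the upper 16 bits of an
// f32, so widening is a zero-extend to 32 bits followed by a left shift of
// 16 (scalar equivalent: `f32::from_bits((bits as u32) << 16)`). The SSE
// path relies on `_mm_cvtepu16_epi32`, an SSE4.1 instruction, which matches
// the `sse4.1` requirement of the `SimdLevel::Sse` dispatch level.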
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn bf16x4_to_f32x4_sse(ptr: *const bf16) -> __m128 {
let raw = _mm_loadl_epi64(ptr as *const __m128i);
let extended = _mm_cvtepu16_epi32(raw);
let shifted = _mm_slli_epi32(extended, 16);
_mm_castsi128_ps(shifted)
}
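// Horizontal sum of an `__m128`: two shuffle/add steps, then extract lane 0.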
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum_f32_sse(v: __m128) -> f32 {
let shuf = _mm_movehdup_ps(v);
let sums = _mm_add_ps(v, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let sums2 = _mm_add_ss(sums, shuf2);
_mm_cvtss_f32(sums2)
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn bf16x8_to_f32x8_avx2(ptr: *const bf16) -> __m256 {
let raw = _mm_loadu_si128(ptr as *const __m128i);
let extended = _mm256_cvtepu16_epi32(raw);
let shifted = _mm256_slli_epi32(extended, 16);
_mm256_castsi256_ps(shifted)
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum_f32_avx2(v: __m256) -> f32 {
let low = _mm256_castps256_ps128(v);
let high = _mm256_extractf128_ps(v, 1);
let sum128 = _mm_add_ps(low, high);
hsum_f32_sse(sum128)
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
target_feature = "avx512f"
))]
#[inline(always)]
unsafe fn bf16x16_to_f32x16_avx512(ptr: *const bf16) -> __m512 {
let raw = _mm256_loadu_si256(ptr as *const __m256i);
let extended = _mm512_cvtepu16_epi32(raw);
let shifted = _mm512_slli_epi32(extended, 16);
_mm512_castsi512_ps(shifted)
}
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
#[inline(always)]
unsafe fn bf16x4_to_f32x4_neon(ptr: *const bf16) -> float32x4_t {
let raw = vld1_u16(ptr as *const u16);
let extended = vmovl_u16(raw);
let shifted = vshlq_n_u32(extended, 16);
vreinterpretq_f32_u32(shifted)
}
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
#[inline(always)]
unsafe fn hsum_f32_neon(v: float32x4_t) -> f32 {
vaddvq_f32(v)
}
#[cfg(feature = "quantised")]
#[inline(always)]
fn euclidean_bf16_scalar(a: &[bf16], b: &[bf16]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| {
let d = f32::from(*x) - f32::from(*y);
d * d
})
.sum()
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn euclidean_bf16_sse(a: &[bf16], b: &[bf16]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = _mm_setzero_ps();
for i in 0..chunks {
let va = bf16x4_to_f32x4_sse(a.as_ptr().add(i * 4));
let vb = bf16x4_to_f32x4_sse(b.as_ptr().add(i * 4));
let diff = _mm_sub_ps(va, vb);
acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
}
let mut sum = hsum_f32_sse(acc);
for i in (chunks * 4)..len {
let diff = a[i].to_f32() - b[i].to_f32();
sum += diff * diff;
}
sum
}
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn euclidean_bf16_avx2(a: &[bf16], b: &[bf16]) -> f32 {
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm256_setzero_ps();
for i in 0..chunks {
let va = bf16x8_to_f32x8_avx2(a.as_ptr().add(i * 8));
let vb = bf16x8_to_f32x8_avx2(b.as_ptr().add(i * 8));
let diff = _mm256_sub_ps(va, vb);
acc = _mm256_fmadd_ps(diff, diff, acc);
}
let mut sum = hsum_f32_avx2(acc);
for i in (chunks * 8)..len {
let diff = a[i].to_f32() - b[i].to_f32();
sum += diff * diff;
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
target_feature = "avx512f"
))]
#[inline(always)]
fn euclidean_bf16_avx512(a: &[bf16], b: &[bf16]) -> f32 {
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let va = bf16x16_to_f32x16_avx512(a.as_ptr().add(i * 16));
let vb = bf16x16_to_f32x16_avx512(b.as_ptr().add(i * 16));
let diff = _mm512_sub_ps(va, vb);
acc = _mm512_fmadd_ps(diff, diff, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
let diff = a[i].to_f32() - b[i].to_f32();
sum += diff * diff;
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
not(target_feature = "avx512f")
))]
#[inline(always)]
fn euclidean_bf16_avx512(a: &[bf16], b: &[bf16]) -> f32 {
euclidean_bf16_avx2(a, b)
}
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
#[inline(always)]
fn euclidean_bf16_neon(a: &[bf16], b: &[bf16]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = vdupq_n_f32(0.0);
for i in 0..chunks {
let va = bf16x4_to_f32x4_neon(a.as_ptr().add(i * 4));
let vb = bf16x4_to_f32x4_neon(b.as_ptr().add(i * 4));
let diff = vsubq_f32(va, vb);
acc = vfmaq_f32(acc, diff, diff);
}
let mut sum = hsum_f32_neon(acc);
for i in (chunks * 4)..len {
let diff = a[i].to_f32() - b[i].to_f32();
sum += diff * diff;
}
sum
}
}
#[cfg(feature = "quantised")]
#[inline]
pub fn euclidean_bf16_simd(a: &[bf16], b: &[bf16]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
match crate::detect_simd_level() {
crate::SimdLevel::Avx512 => euclidean_bf16_avx512(a, b),
crate::SimdLevel::Avx2 => euclidean_bf16_avx2(a, b),
crate::SimdLevel::Sse => euclidean_bf16_sse(a, b),
crate::SimdLevel::Scalar => euclidean_bf16_scalar(a, b),
}
}
#[cfg(target_arch = "aarch64")]
{
euclidean_bf16_neon(a, b)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
euclidean_bf16_scalar(a, b)
}
}
#[cfg(feature = "quantised")]
#[inline(always)]
fn euclidean_bf16_f32_scalar(a: &[bf16], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| {
let d = f32::from(*x) - y;
d * d
})
.sum()
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn euclidean_bf16_f32_sse(a: &[bf16], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = _mm_setzero_ps();
for i in 0..chunks {
let va = bf16x4_to_f32x4_sse(a.as_ptr().add(i * 4));
let vb = _mm_loadu_ps(b.as_ptr().add(i * 4));
let diff = _mm_sub_ps(va, vb);
acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
}
let mut sum = hsum_f32_sse(acc);
for i in (chunks * 4)..len {
let diff = a[i].to_f32() - b[i];
sum += diff * diff;
}
sum
}
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn euclidean_bf16_f32_avx2(a: &[bf16], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm256_setzero_ps();
for i in 0..chunks {
let va = bf16x8_to_f32x8_avx2(a.as_ptr().add(i * 8));
let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
let diff = _mm256_sub_ps(va, vb);
acc = _mm256_fmadd_ps(diff, diff, acc);
}
let mut sum = hsum_f32_avx2(acc);
for i in (chunks * 8)..len {
let diff = a[i].to_f32() - b[i];
sum += diff * diff;
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
target_feature = "avx512f"
))]
#[inline(always)]
fn euclidean_bf16_f32_avx512(a: &[bf16], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let va = bf16x16_to_f32x16_avx512(a.as_ptr().add(i * 16));
let vb = _mm512_loadu_ps(b.as_ptr().add(i * 16));
let diff = _mm512_sub_ps(va, vb);
acc = _mm512_fmadd_ps(diff, diff, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
let diff = a[i].to_f32() - b[i];
sum += diff * diff;
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
not(target_feature = "avx512f")
))]
#[inline(always)]
fn euclidean_bf16_f32_avx512(a: &[bf16], b: &[f32]) -> f32 {
euclidean_bf16_f32_avx2(a, b)
}
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
#[inline(always)]
fn euclidean_bf16_f32_neon(a: &[bf16], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = vdupq_n_f32(0.0);
for i in 0..chunks {
let va = bf16x4_to_f32x4_neon(a.as_ptr().add(i * 4));
let vb = vld1q_f32(b.as_ptr().add(i * 4));
let diff = vsubq_f32(va, vb);
acc = vfmaq_f32(acc, diff, diff);
}
let mut sum = hsum_f32_neon(acc);
for i in (chunks * 4)..len {
let diff = a[i].to_f32() - b[i];
sum += diff * diff;
}
sum
}
}
#[cfg(feature = "quantised")]
#[inline(always)]
fn euclidean_bf16_f64_scalar(a: &[bf16], b: &[f64]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| {
let d = f32::from(*x) - (*y as f32);
d * d
})
.sum()
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn euclidean_bf16_f64_sse(a: &[bf16], b: &[f64]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = _mm_setzero_ps();
for i in 0..chunks {
let offset = i * 4;
let va = bf16x4_to_f32x4_sse(a.as_ptr().add(offset));
let b_lo = _mm_loadu_pd(b.as_ptr().add(offset));
let b_hi = _mm_loadu_pd(b.as_ptr().add(offset + 2));
let b_lo_f32 = _mm_cvtpd_ps(b_lo);
let b_hi_f32 = _mm_cvtpd_ps(b_hi);
let vb = _mm_movelh_ps(b_lo_f32, b_hi_f32);
let diff = _mm_sub_ps(va, vb);
acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
}
let mut sum = hsum_f32_sse(acc);
for i in (chunks * 4)..len {
let diff = a[i].to_f32() - (b[i] as f32);
sum += diff * diff;
}
sum
}
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn euclidean_bf16_f64_avx2(a: &[bf16], b: &[f64]) -> f32 {
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm256_setzero_ps();
for i in 0..chunks {
let offset = i * 8;
let va = bf16x8_to_f32x8_avx2(a.as_ptr().add(offset));
let b_lo = _mm256_loadu_pd(b.as_ptr().add(offset));
let b_hi = _mm256_loadu_pd(b.as_ptr().add(offset + 4));
let b_lo_f32 = _mm256_cvtpd_ps(b_lo);
let b_hi_f32 = _mm256_cvtpd_ps(b_hi);
let vb = _mm256_insertf128_ps(_mm256_castps128_ps256(b_lo_f32), b_hi_f32, 1);
let diff = _mm256_sub_ps(va, vb);
acc = _mm256_fmadd_ps(diff, diff, acc);
}
let mut sum = hsum_f32_avx2(acc);
for i in (chunks * 8)..len {
let diff = a[i].to_f32() - (b[i] as f32);
sum += diff * diff;
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
target_feature = "avx512f"
))]
#[inline(always)]
fn euclidean_bf16_f64_avx512(a: &[bf16], b: &[f64]) -> f32 {
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let offset = i * 16;
let va = bf16x16_to_f32x16_avx512(a.as_ptr().add(offset));
let b_lo = _mm512_loadu_pd(b.as_ptr().add(offset));
let b_hi = _mm512_loadu_pd(b.as_ptr().add(offset + 8));
let b_lo_f32 = _mm512_cvtpd_ps(b_lo);
let b_hi_f32 = _mm512_cvtpd_ps(b_hi);
// NB: `_mm512_insertf32x8` below is an AVX512DQ instruction, so this path
// additionally assumes AVX512DQ whenever it is compiled with `avx512f`.
let vb = _mm512_insertf32x8(_mm512_castps256_ps512(b_lo_f32), b_hi_f32, 1);
let diff = _mm512_sub_ps(va, vb);
acc = _mm512_fmadd_ps(diff, diff, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
let diff = a[i].to_f32() - (b[i] as f32);
sum += diff * diff;
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
not(target_feature = "avx512f")
))]
#[inline(always)]
fn euclidean_bf16_f64_avx512(a: &[bf16], b: &[f64]) -> f32 {
euclidean_bf16_f64_avx2(a, b)
}
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
#[inline(always)]
fn euclidean_bf16_f64_neon(a: &[bf16], b: &[f64]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = vdupq_n_f32(0.0);
for i in 0..chunks {
let offset = i * 4;
let va = bf16x4_to_f32x4_neon(a.as_ptr().add(offset));
let b_lo = vld1q_f64(b.as_ptr().add(offset));
let b_hi = vld1q_f64(b.as_ptr().add(offset + 2));
let b_lo_f32 = vcvt_f32_f64(b_lo);
let b_hi_f32 = vcvt_f32_f64(b_hi);
let vb = vcombine_f32(b_lo_f32, b_hi_f32);
let diff = vsubq_f32(va, vb);
acc = vfmaq_f32(acc, diff, diff);
}
let mut sum = hsum_f32_neon(acc);
for i in (chunks * 4)..len {
let diff = a[i].to_f32() - (b[i] as f32);
sum += diff * diff;
}
sum
}
}
#[cfg(feature = "quantised")]
#[inline]
pub fn euclidean_bf16_f32_simd(a: &[bf16], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
match detect_simd_level() {
SimdLevel::Avx512 => euclidean_bf16_f32_avx512(a, b),
SimdLevel::Avx2 => euclidean_bf16_f32_avx2(a, b),
SimdLevel::Sse => euclidean_bf16_f32_sse(a, b),
SimdLevel::Scalar => euclidean_bf16_f32_scalar(a, b),
}
}
#[cfg(target_arch = "aarch64")]
{
euclidean_bf16_f32_neon(a, b)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
euclidean_bf16_f32_scalar(a, b)
}
}
#[cfg(feature = "quantised")]
#[inline]
pub fn euclidean_bf16_f64_simd(a: &[bf16], b: &[f64]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
match detect_simd_level() {
SimdLevel::Avx512 => euclidean_bf16_f64_avx512(a, b),
SimdLevel::Avx2 => euclidean_bf16_f64_avx2(a, b),
SimdLevel::Sse => euclidean_bf16_f64_sse(a, b),
SimdLevel::Scalar => euclidean_bf16_f64_scalar(a, b),
}
}
#[cfg(target_arch = "aarch64")]
{
euclidean_bf16_f64_neon(a, b)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
euclidean_bf16_f64_scalar(a, b)
}
}
#[cfg(feature = "quantised")]
#[inline(always)]
fn dot_bf16_scalar(a: &[bf16], b: &[bf16]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| f32::from(*x) * f32::from(*y))
.sum()
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn dot_bf16_sse(a: &[bf16], b: &[bf16]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = _mm_setzero_ps();
for i in 0..chunks {
let va = bf16x4_to_f32x4_sse(a.as_ptr().add(i * 4));
let vb = bf16x4_to_f32x4_sse(b.as_ptr().add(i * 4));
acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
}
let mut sum = hsum_f32_sse(acc);
for i in (chunks * 4)..len {
sum += a[i].to_f32() * b[i].to_f32();
}
sum
}
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn dot_bf16_avx2(a: &[bf16], b: &[bf16]) -> f32 {
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm256_setzero_ps();
for i in 0..chunks {
let va = bf16x8_to_f32x8_avx2(a.as_ptr().add(i * 8));
let vb = bf16x8_to_f32x8_avx2(b.as_ptr().add(i * 8));
acc = _mm256_fmadd_ps(va, vb, acc);
}
let mut sum = hsum_f32_avx2(acc);
for i in (chunks * 8)..len {
sum += a[i].to_f32() * b[i].to_f32();
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
target_feature = "avx512f"
))]
#[inline(always)]
fn dot_bf16_avx512(a: &[bf16], b: &[bf16]) -> f32 {
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let va = bf16x16_to_f32x16_avx512(a.as_ptr().add(i * 16));
let vb = bf16x16_to_f32x16_avx512(b.as_ptr().add(i * 16));
acc = _mm512_fmadd_ps(va, vb, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
sum += a[i].to_f32() * b[i].to_f32();
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
not(target_feature = "avx512f")
))]
#[inline(always)]
fn dot_bf16_avx512(a: &[bf16], b: &[bf16]) -> f32 {
dot_bf16_avx2(a, b)
}
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
#[inline(always)]
fn dot_bf16_neon(a: &[bf16], b: &[bf16]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = vdupq_n_f32(0.0);
for i in 0..chunks {
let va = bf16x4_to_f32x4_neon(a.as_ptr().add(i * 4));
let vb = bf16x4_to_f32x4_neon(b.as_ptr().add(i * 4));
acc = vfmaq_f32(acc, va, vb);
}
let mut sum = hsum_f32_neon(acc);
for i in (chunks * 4)..len {
sum += a[i].to_f32() * b[i].to_f32();
}
sum
}
}
#[cfg(feature = "quantised")]
#[inline]
pub fn dot_bf16_simd(a: &[bf16], b: &[bf16]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
match crate::detect_simd_level() {
crate::SimdLevel::Avx512 => dot_bf16_avx512(a, b),
crate::SimdLevel::Avx2 => dot_bf16_avx2(a, b),
crate::SimdLevel::Sse => dot_bf16_sse(a, b),
crate::SimdLevel::Scalar => dot_bf16_scalar(a, b),
}
}
#[cfg(target_arch = "aarch64")]
{
dot_bf16_neon(a, b)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
dot_bf16_scalar(a, b)
}
}
#[cfg(feature = "quantised")]
#[inline(always)]
fn dot_bf16_f32_scalar(a: &[bf16], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| f32::from(*x) * y).sum()
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn dot_bf16_f32_sse(a: &[bf16], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = _mm_setzero_ps();
for i in 0..chunks {
let va = bf16x4_to_f32x4_sse(a.as_ptr().add(i * 4));
let vb = _mm_loadu_ps(b.as_ptr().add(i * 4));
acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
}
let mut sum = hsum_f32_sse(acc);
for i in (chunks * 4)..len {
sum += a[i].to_f32() * b[i];
}
sum
}
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn dot_bf16_f32_avx2(a: &[bf16], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm256_setzero_ps();
for i in 0..chunks {
let va = bf16x8_to_f32x8_avx2(a.as_ptr().add(i * 8));
let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
acc = _mm256_fmadd_ps(va, vb, acc);
}
let mut sum = hsum_f32_avx2(acc);
for i in (chunks * 8)..len {
sum += a[i].to_f32() * b[i];
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
target_feature = "avx512f"
))]
#[inline(always)]
fn dot_bf16_f32_avx512(a: &[bf16], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let va = bf16x16_to_f32x16_avx512(a.as_ptr().add(i * 16));
let vb = _mm512_loadu_ps(b.as_ptr().add(i * 16));
acc = _mm512_fmadd_ps(va, vb, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
sum += a[i].to_f32() * b[i];
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
not(target_feature = "avx512f")
))]
#[inline(always)]
fn dot_bf16_f32_avx512(a: &[bf16], b: &[f32]) -> f32 {
dot_bf16_f32_avx2(a, b)
}
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
#[inline(always)]
fn dot_bf16_f32_neon(a: &[bf16], b: &[f32]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = vdupq_n_f32(0.0);
for i in 0..chunks {
let va = bf16x4_to_f32x4_neon(a.as_ptr().add(i * 4));
let vb = vld1q_f32(b.as_ptr().add(i * 4));
acc = vfmaq_f32(acc, va, vb);
}
let mut sum = hsum_f32_neon(acc);
for i in (chunks * 4)..len {
sum += a[i].to_f32() * b[i];
}
sum
}
}
#[cfg(feature = "quantised")]
#[inline(always)]
fn dot_bf16_f64_scalar(a: &[bf16], b: &[f64]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| f32::from(*x) * (*y as f32))
.sum()
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn dot_bf16_f64_sse(a: &[bf16], b: &[f64]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = _mm_setzero_ps();
for i in 0..chunks {
let offset = i * 4;
let va = bf16x4_to_f32x4_sse(a.as_ptr().add(offset));
let b_lo = _mm_loadu_pd(b.as_ptr().add(offset));
let b_hi = _mm_loadu_pd(b.as_ptr().add(offset + 2));
let b_lo_f32 = _mm_cvtpd_ps(b_lo);
let b_hi_f32 = _mm_cvtpd_ps(b_hi);
let vb = _mm_movelh_ps(b_lo_f32, b_hi_f32);
acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
}
let mut sum = hsum_f32_sse(acc);
for i in (chunks * 4)..len {
sum += a[i].to_f32() * (b[i] as f32);
}
sum
}
}
#[cfg(all(feature = "quantised", target_arch = "x86_64"))]
#[inline(always)]
fn dot_bf16_f64_avx2(a: &[bf16], b: &[f64]) -> f32 {
let len = a.len();
let chunks = len / 8;
unsafe {
let mut acc = _mm256_setzero_ps();
for i in 0..chunks {
let offset = i * 8;
let va = bf16x8_to_f32x8_avx2(a.as_ptr().add(offset));
let b_lo = _mm256_loadu_pd(b.as_ptr().add(offset));
let b_hi = _mm256_loadu_pd(b.as_ptr().add(offset + 4));
let b_lo_f32 = _mm256_cvtpd_ps(b_lo);
let b_hi_f32 = _mm256_cvtpd_ps(b_hi);
let vb = _mm256_insertf128_ps(_mm256_castps128_ps256(b_lo_f32), b_hi_f32, 1);
acc = _mm256_fmadd_ps(va, vb, acc);
}
let mut sum = hsum_f32_avx2(acc);
for i in (chunks * 8)..len {
sum += a[i].to_f32() * (b[i] as f32);
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
target_feature = "avx512f"
))]
#[inline(always)]
fn dot_bf16_f64_avx512(a: &[bf16], b: &[f64]) -> f32 {
let len = a.len();
let chunks = len / 16;
unsafe {
let mut acc = _mm512_setzero_ps();
for i in 0..chunks {
let offset = i * 16;
let va = bf16x16_to_f32x16_avx512(a.as_ptr().add(offset));
let b_lo = _mm512_loadu_pd(b.as_ptr().add(offset));
let b_hi = _mm512_loadu_pd(b.as_ptr().add(offset + 8));
let b_lo_f32 = _mm512_cvtpd_ps(b_lo);
let b_hi_f32 = _mm512_cvtpd_ps(b_hi);
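// NB: `_mm512_insertf32x8` below is an AVX512DQ instruction, so this path
// additionally assumes AVX512DQ whenever it is compiled with `avx512f`.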
let vb = _mm512_insertf32x8(_mm512_castps256_ps512(b_lo_f32), b_hi_f32, 1);
acc = _mm512_fmadd_ps(va, vb, acc);
}
let mut sum = _mm512_reduce_add_ps(acc);
for i in (chunks * 16)..len {
sum += a[i].to_f32() * (b[i] as f32);
}
sum
}
}
#[cfg(all(
feature = "quantised",
target_arch = "x86_64",
not(target_feature = "avx512f")
))]
#[inline(always)]
fn dot_bf16_f64_avx512(a: &[bf16], b: &[f64]) -> f32 {
dot_bf16_f64_avx2(a, b)
}
#[cfg(all(feature = "quantised", target_arch = "aarch64"))]
#[inline(always)]
fn dot_bf16_f64_neon(a: &[bf16], b: &[f64]) -> f32 {
let len = a.len();
let chunks = len / 4;
unsafe {
let mut acc = vdupq_n_f32(0.0);
for i in 0..chunks {
let offset = i * 4;
let va = bf16x4_to_f32x4_neon(a.as_ptr().add(offset));
let b_lo = vld1q_f64(b.as_ptr().add(offset));
let b_hi = vld1q_f64(b.as_ptr().add(offset + 2));
let b_lo_f32 = vcvt_f32_f64(b_lo);
let b_hi_f32 = vcvt_f32_f64(b_hi);
let vb = vcombine_f32(b_lo_f32, b_hi_f32);
acc = vfmaq_f32(acc, va, vb);
}
let mut sum = hsum_f32_neon(acc);
for i in (chunks * 4)..len {
sum += a[i].to_f32() * (b[i] as f32);
}
sum
}
}
#[cfg(feature = "quantised")]
#[inline]
pub fn dot_bf16_f32_simd(a: &[bf16], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
match detect_simd_level() {
SimdLevel::Avx512 => dot_bf16_f32_avx512(a, b),
SimdLevel::Avx2 => dot_bf16_f32_avx2(a, b),
SimdLevel::Sse => dot_bf16_f32_sse(a, b),
SimdLevel::Scalar => dot_bf16_f32_scalar(a, b),
}
}
#[cfg(target_arch = "aarch64")]
{
dot_bf16_f32_neon(a, b)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
dot_bf16_f32_scalar(a, b)
}
}
#[cfg(feature = "quantised")]
#[inline]
pub fn dot_bf16_f64_simd(a: &[bf16], b: &[f64]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
match detect_simd_level() {
SimdLevel::Avx512 => dot_bf16_f64_avx512(a, b),
SimdLevel::Avx2 => dot_bf16_f64_avx2(a, b),
SimdLevel::Sse => dot_bf16_f64_sse(a, b),
SimdLevel::Scalar => dot_bf16_f64_scalar(a, b),
}
}
#[cfg(target_arch = "aarch64")]
{
dot_bf16_f64_neon(a, b)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
dot_bf16_f64_scalar(a, b)
}
}
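/// Query element types that can be compared against stored bf16 vectors
/// without materialising a converted copy: `f32` is used directly, `f64` is
/// narrowed to f32 lane-wise inside the kernels.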
#[cfg(feature = "quantised")]
pub trait Bf16Compatible: Float + ToPrimitive {
fn euclidean_bf16_dispatch(a: &[bf16], b: &[Self]) -> f32;
fn dot_bf16_dispatch(a: &[bf16], b: &[Self]) -> f32;
}
#[cfg(feature = "quantised")]
impl Bf16Compatible for f32 {
#[inline(always)]
fn euclidean_bf16_dispatch(a: &[bf16], b: &[Self]) -> f32 {
euclidean_bf16_f32_simd(a, b)
}
#[inline(always)]
fn dot_bf16_dispatch(a: &[bf16], b: &[Self]) -> f32 {
dot_bf16_f32_simd(a, b)
}
}
#[cfg(feature = "quantised")]
impl Bf16Compatible for f64 {
#[inline(always)]
fn euclidean_bf16_dispatch(a: &[bf16], b: &[Self]) -> f32 {
euclidean_bf16_f64_simd(a, b)
}
#[inline(always)]
fn dot_bf16_dispatch(a: &[bf16], b: &[Self]) -> f32 {
dot_bf16_f64_simd(a, b)
}
}
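/// bf16 counterpart of [`VectorDistance`]: vectors are stored as bf16 and
/// decoded on the fly. The `*_to_query` methods take an `f32`/`f64` query
/// via [`Bf16Compatible`]; the `*_dual_bf16` methods take a bf16 query. All
/// accumulation happens in f32 and is converted to `T` at the end.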
#[cfg(feature = "quantised")]
pub trait VectorDistanceBf16<T>
where
T: Float + Sum + FromPrimitive + ToPrimitive + SimdDistance,
{
fn vectors_flat(&self) -> &[bf16];
fn dim(&self) -> usize;
fn norms(&self) -> &[T];
#[inline(always)]
fn euclidean_distance_bf16(&self, i: usize, j: usize) -> T {
let start_i = i * self.dim();
let start_j = j * self.dim();
let vec_i = &self.vectors_flat()[start_i..start_i + self.dim()];
let vec_j = &self.vectors_flat()[start_j..start_j + self.dim()];
let result = euclidean_bf16_simd(vec_i, vec_j);
T::from_f32(result).unwrap()
}
#[inline(always)]
fn euclidean_distance_to_query_bf16<Q>(&self, internal_idx: usize, query: &[Q]) -> T
where
Q: Bf16Compatible,
{
let start = internal_idx * self.dim();
let vec = &self.vectors_flat()[start..start + self.dim()];
T::from_f32(Q::euclidean_bf16_dispatch(vec, query)).unwrap()
}
#[inline(always)]
fn euclidean_distance_to_query_dual_bf16(&self, internal_idx: usize, query: &[bf16]) -> T {
let start = internal_idx * self.dim();
let vec = &self.vectors_flat()[start..start + self.dim()];
let result = euclidean_bf16_simd(vec, query);
T::from_f32(result).unwrap()
}
#[inline(always)]
fn cosine_distance_bf16(&self, i: usize, j: usize) -> T {
let start_i = i * self.dim();
let start_j = j * self.dim();
let vec_i = &self.vectors_flat()[start_i..start_i + self.dim()];
let vec_j = &self.vectors_flat()[start_j..start_j + self.dim()];
let dot = dot_bf16_simd(vec_i, vec_j);
let norm_i = self.norms()[i].to_f32().unwrap();
let norm_j = self.norms()[j].to_f32().unwrap();
let dist = 1.0 - (dot / (norm_i * norm_j));
T::from_f32(dist).unwrap()
}
#[inline(always)]
fn cosine_distance_to_query_bf16<Q>(&self, internal_idx: usize, query: &[Q], query_norm: T) -> T
where
Q: Bf16Compatible,
{
let start = internal_idx * self.dim();
let vec = &self.vectors_flat()[start..start + self.dim()];
let dot = Q::dot_bf16_dispatch(vec, query);
let dist = 1.0
- (dot / (query_norm.to_f32().unwrap() * self.norms()[internal_idx].to_f32().unwrap()));
T::from_f32(dist).unwrap()
}
#[inline(always)]
fn cosine_distance_to_query_dual_bf16(
&self,
internal_idx: usize,
query: &[bf16],
query_norm: bf16,
) -> T {
let start = internal_idx * self.dim();
let vec = &self.vectors_flat()[start..start + self.dim()];
let dot = dot_bf16_simd(vec, query);
let norm_internal = self.norms()[internal_idx].to_f32().unwrap();
let dist = 1.0 - (dot / (query_norm.to_f32() * norm_internal));
T::from_f32(dist).unwrap()
}
}
#[cfg(feature = "quantised")]
pub trait VectorDistanceSq8<T>
where
T: Float + FromPrimitive + ToPrimitive,
{
fn vectors_flat_quantised(&self) -> &[i8];
fn norms_quantised(&self) -> &[i32];
fn dim(&self) -> usize;
#[inline(always)]
fn euclidean_distance_i8(&self, internal_idx: usize, query_i8: &[i8]) -> T {
let start = internal_idx * self.dim();
unsafe {
let db_vec = self
.vectors_flat_quantised()
.get_unchecked(start..start + self.dim());
let sum: i32 = query_i8
.iter()
.zip(db_vec.iter())
.map(|(&q, &d)| {
let diff = q as i32 - d as i32;
diff * diff
})
.sum();
T::from_i32(sum).unwrap()
}
}
#[inline(always)]
fn cosine_distance_i8(&self, vec_idx: usize, query_i8: &[i8], query_norm_sq: i32) -> T {
let start = vec_idx * self.dim();
unsafe {
let db_vec = self
.vectors_flat_quantised()
.get_unchecked(start..start + self.dim());
let dot: i32 = query_i8
.iter()
.zip(db_vec.iter())
.map(|(&q, &d)| q as i32 * d as i32)
.sum();
let db_norm_sq: i32 = self.norms_quantised()[vec_idx];
let query_norm = T::from_i32(query_norm_sq).unwrap().sqrt();
let db_norm = T::from_i32(db_norm_sq).unwrap().sqrt();
if query_norm > T::zero() && db_norm > T::zero() {
T::one() - T::from_i32(dot).unwrap() / (query_norm * db_norm)
} else {
T::one()
}
}
}
}
#[cfg(feature = "quantised")]
pub trait VectorDistanceAdc<T>
where
T: Float + FromPrimitive + ToPrimitive + Sum + SimdDistance,
{
fn codebook_m(&self) -> usize;
fn codebook_n_centroids(&self) -> usize;
fn codebook_subvec_dim(&self) -> usize;
fn centroids(&self) -> &[T];
fn dim(&self) -> usize;
fn codebooks(&self) -> &[Vec<T>];
fn quantised_codes(&self) -> &[u8];
fn build_lookup_tables_residual(&self, query_vec: &[T], cluster_idx: usize) -> Vec<T> {
let m = self.codebook_m();
let subvec_dim = self.codebook_subvec_dim();
let n_cents = self.codebook_n_centroids();
let centroid = &self.centroids()[cluster_idx * self.dim()..(cluster_idx + 1) * self.dim()];
let query_residual = T::subtract_simd(query_vec, centroid);
self.build_lookup_tables_impl(&query_residual, m, subvec_dim, n_cents)
}
fn build_lookup_tables_direct(&self, query_vec: &[T]) -> Vec<T> {
let m = self.codebook_m();
let subvec_dim = self.codebook_subvec_dim();
let n_cents = self.codebook_n_centroids();
self.build_lookup_tables_impl(query_vec, m, subvec_dim, n_cents)
}
fn build_lookup_tables_impl(
&self,
query_vec: &[T],
m: usize,
subvec_dim: usize,
n_cents: usize,
) -> Vec<T> {
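// Table layout: `m` contiguous blocks of `n_cents` entries, where entry
// `subspace * n_cents + centroid_idx` holds the (squared) distance from
// the query sub-vector to that PQ centroid.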
let mut table = vec![T::zero(); m * n_cents];
for subspace in 0..m {
let query_sub = &query_vec[subspace * subvec_dim..(subspace + 1) * subvec_dim];
let table_offset = subspace * n_cents;
for centroid_idx in 0..n_cents {
let centroid_start = centroid_idx * subvec_dim;
let pq_centroid =
&self.codebooks()[subspace][centroid_start..centroid_start + subvec_dim];
let dist = T::euclidean_simd(query_sub, pq_centroid);
table[table_offset + centroid_idx] = dist;
}
}
table
}
#[inline(always)]
fn compute_distance_adc(&self, vec_idx: usize, lookup_table: &[T]) -> T {
let m = self.codebook_m();
let n_cents = self.codebook_n_centroids();
let codes_start = vec_idx * m;
let codes = &self.quantised_codes()[codes_start..codes_start + m];
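// The common code counts (m = 8, 16, 32) are specialised with constant
// loop bounds so the compiler can fully unroll the gather; any other m
// falls through to the generic iterator path below.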
match m {
8 => {
let mut sum = T::zero();
for i in 0..8 {
let code = unsafe { *codes.get_unchecked(i) } as usize;
let offset = i * n_cents + code;
sum = sum + unsafe { *lookup_table.get_unchecked(offset) };
}
sum
}
16 => {
let mut sum = T::zero();
for i in 0..16 {
let code = unsafe { *codes.get_unchecked(i) } as usize;
let offset = i * n_cents + code;
sum = sum + unsafe { *lookup_table.get_unchecked(offset) };
}
sum
}
32 => {
let mut sum = T::zero();
for i in 0..32 {
let code = unsafe { *codes.get_unchecked(i) } as usize;
let offset = i * n_cents + code;
sum = sum + unsafe { *lookup_table.get_unchecked(offset) };
}
sum
}
_ => {
codes
.iter()
.enumerate()
.map(|(subspace, &code)| {
let offset = subspace * n_cents + (code as usize);
lookup_table[offset]
})
.fold(T::zero(), |acc, x| acc + x)
}
}
}
}
#[inline(always)]
pub fn euclidean_distance_static<T>(a: &[T], b: &[T]) -> T
where
T: Float + SimdDistance,
{
assert!(a.len() == b.len(), "Vectors a and b must have the same length");
T::euclidean_simd(a, b)
}
#[inline(always)]
pub fn cosine_distance_static<T>(a: &[T], b: &[T]) -> T
where
T: Float + SimdDistance,
{
assert!(a.len() == b.len(), "Vectors a and b must have the same length");
let dot: T = T::dot_simd(a, b);
let norm_a = T::calculate_l2_norm(a);
let norm_b = T::calculate_l2_norm(b);
T::one() - (dot / (norm_a * norm_b))
}
pub fn cosine_distance_static_norm<T>(a: &[T], b: &[T], norm_a: &T, norm_b: &T) -> T
where
T: Float + SimdDistance,
{
assert!(a.len() == b.len(), "Vectors a and b must have the same length");
let dot: T = T::dot_simd(a, b);
T::one() - (dot / (*norm_a * *norm_b))
}
#[inline(always)]
pub fn normalise_vector<T>(vec: &mut [T])
where
T: Float + Sum + SimdDistance,
{
let norm = compute_l2_norm(vec);
if norm > T::zero() {
vec.iter_mut().for_each(|v| *v = *v / norm);
}
}
#[inline(always)]
pub fn compute_l2_norm<T>(vec: &[T]) -> T
where
T: Float + SimdDistance,
{
T::calculate_l2_norm(vec)
}
#[inline(always)]
pub fn compute_norm_row<T>(row: RowRef<T>) -> T
where
T: Float + SimdDistance,
{
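// Fast path: a contiguous row (unit column stride) can be viewed directly
// as a slice; strided rows are copied into a temporary buffer first.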
if row.col_stride() == 1 {
let slice = unsafe { std::slice::from_raw_parts(row.as_ptr(), row.ncols()) };
return T::calculate_l2_norm(slice);
}
let vec: Vec<T> = row.iter().cloned().collect();
T::calculate_l2_norm(&vec)
}
#[inline(always)]
pub fn compute_l1_norm<T>(vec: &[T]) -> T
where
T: Float + SimdDistance,
{
T::calculate_l1_norm(vec)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
struct TestVectors {
data: Vec<f32>,
dim: usize,
norms: Vec<f32>,
}
impl VectorDistance<f32> for TestVectors {
fn vectors_flat(&self) -> &[f32] {
&self.data
}
fn dim(&self) -> usize {
self.dim
}
fn norms(&self) -> &[f32] {
&self.norms
}
}
#[test]
fn test_parse_ann_dist_euclidean() {
assert_eq!(parse_ann_dist("euclidean"), Some(Dist::Euclidean));
assert_eq!(parse_ann_dist("Euclidean"), Some(Dist::Euclidean));
assert_eq!(parse_ann_dist("EUCLIDEAN"), Some(Dist::Euclidean));
}
#[test]
fn test_parse_ann_dist_cosine() {
assert_eq!(parse_ann_dist("cosine"), Some(Dist::Cosine));
assert_eq!(parse_ann_dist("Cosine"), Some(Dist::Cosine));
assert_eq!(parse_ann_dist("COSINE"), Some(Dist::Cosine));
}
#[test]
fn test_parse_ann_dist_invalid() {
assert_eq!(parse_ann_dist("manhattan"), None);
assert_eq!(parse_ann_dist(""), None);
assert_eq!(parse_ann_dist("cosine "), None); assert_eq!(parse_ann_dist(" euclidean"), None); }
#[test]
fn test_euclidean_distance_basic() {
let data = vec![
1.0, 0.0, 0.0,
0.0, 1.0, 0.0,
1.0, 1.0, 0.0,
];
let vecs = TestVectors {
data,
dim: 3,
norms: vec![],
};
let dist_01 = vecs.euclidean_distance(0, 1);
assert_relative_eq!(dist_01, 2.0, epsilon = 1e-6);
let dist_02 = vecs.euclidean_distance(0, 2);
assert_relative_eq!(dist_02, 1.0, epsilon = 1e-6);
let dist_00 = vecs.euclidean_distance(0, 0);
assert_relative_eq!(dist_00, 0.0, epsilon = 1e-6);
}
#[test]
fn test_euclidean_distance_symmetry() {
let data = vec![2.0, 3.0, 5.0, 1.0, 4.0, 2.0];
let vecs = TestVectors {
data,
dim: 3,
norms: vec![],
};
let dist_01 = vecs.euclidean_distance(0, 1);
let dist_10 = vecs.euclidean_distance(1, 0);
assert_relative_eq!(dist_01, dist_10, epsilon = 1e-6);
}
#[test]
fn test_euclidean_distance_unrolled() {
let data = vec![
1.0, 2.0, 3.0, 4.0, 5.0,
5.0, 4.0, 3.0, 2.0, 1.0,
];
let vecs = TestVectors {
data,
dim: 5,
norms: vec![],
};
let dist = vecs.euclidean_distance(0, 1);
assert_relative_eq!(dist, 40.0, epsilon = 1e-6);
}
#[test]
fn test_cosine_distance_basic() {
let data = vec![
1.0, 0.0, 0.0,
0.0, 1.0, 0.0,
1.0, 1.0, 0.0,
];
let norm0 = (1.0_f32 * 1.0 + 0.0 * 0.0 + 0.0 * 0.0).sqrt();
let norm1 = (0.0_f32 * 0.0 + 1.0 * 1.0 + 0.0 * 0.0).sqrt();
let norm2 = (1.0_f32 * 1.0 + 1.0 * 1.0 + 0.0 * 0.0).sqrt();
let vecs = TestVectors {
data,
dim: 3,
norms: vec![norm0, norm1, norm2],
};
let dist_01 = vecs.cosine_distance(0, 1);
assert_relative_eq!(dist_01, 1.0, epsilon = 1e-6);
let dist_02 = vecs.cosine_distance(0, 2);
assert_relative_eq!(dist_02, 1.0 - 1.0 / 2.0_f32.sqrt(), epsilon = 1e-5);
let dist_00 = vecs.cosine_distance(0, 0);
assert_relative_eq!(dist_00, 0.0, epsilon = 1e-6);
}
#[test]
fn test_cosine_norm() {
let data_1 = vec![1.5, 2.5, 2.0];
let data_2 = vec![2.5, 0.5, 1.0];
let data_3 = vec![1.0, 0.0, 1.0];
let norm_1 = data_1.iter().map(|&x| x * x).sum::<f64>().sqrt();
let norm_2 = data_2.iter().map(|&x| x * x).sum::<f64>().sqrt();
let norm_3 = data_3.iter().map(|&x| x * x).sum::<f64>().sqrt();
assert_relative_eq!(norm_1, compute_l2_norm(&data_1), epsilon = 1e-5);
assert_relative_eq!(norm_2, compute_l2_norm(&data_2), epsilon = 1e-5);
assert_relative_eq!(norm_3, compute_l2_norm(&data_3), epsilon = 1e-5);
}
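// Hedged companion to the L2 check above, assuming `calculate_l1_norm`
// returns the sum of absolute values.
#[test]
fn test_l1_norm() {
let data = vec![1.0_f32, -2.0, 3.0];
assert_relative_eq!(compute_l1_norm(&data), 6.0, epsilon = 1e-6);
}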
#[test]
fn test_cosine_distance_symmetry() {
let data = vec![2.0, 3.0, 5.0, 1.0, 4.0, 2.0];
let norm0 = (2.0_f32 * 2.0 + 3.0 * 3.0 + 5.0 * 5.0).sqrt();
let norm1 = (1.0_f32 * 1.0 + 4.0 * 4.0 + 2.0 * 2.0).sqrt();
let vecs = TestVectors {
data,
dim: 3,
norms: vec![norm0, norm1],
};
let dist_01 = vecs.cosine_distance(0, 1);
let dist_10 = vecs.cosine_distance(1, 0);
assert_relative_eq!(dist_01, dist_10, epsilon = 1e-6);
}
#[test]
fn test_cosine_distance_unrolled() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 4.0, 3.0, 2.0, 1.0];
let norm0 = (1.0_f32 + 4.0 + 9.0 + 16.0 + 25.0).sqrt();
let norm1 = (25.0_f32 + 16.0 + 9.0 + 4.0 + 1.0).sqrt();
let vecs = TestVectors {
data,
dim: 5,
norms: vec![norm0, norm1],
};
let dist = vecs.cosine_distance(0, 1);
let expected = 1.0 - (35.0 / (norm0 * norm1));
assert_relative_eq!(dist, expected, epsilon = 1e-5);
}
#[test]
fn test_parallel_vectors() {
let data = vec![
1.0, 2.0, 3.0,
2.0, 4.0, 6.0,
];
let norm0 = (1.0_f32 + 4.0 + 9.0).sqrt();
let norm1 = (4.0_f32 + 16.0 + 36.0).sqrt();
let vecs = TestVectors {
data,
dim: 3,
norms: vec![norm0, norm1],
};
let dist = vecs.cosine_distance(0, 1);
assert_relative_eq!(dist, 0.0, epsilon = 1e-5);
}
#[test]
fn test_opposite_vectors() {
let data = vec![
1.0, 2.0, 3.0,
-1.0, -2.0, -3.0,
];
let norm0 = (1.0_f32 + 4.0 + 9.0).sqrt();
let norm1 = (1.0_f32 + 4.0 + 9.0).sqrt();
let vecs = TestVectors {
data,
dim: 3,
norms: vec![norm0, norm1],
};
let dist = vecs.cosine_distance(0, 1);
assert_relative_eq!(dist, 2.0, epsilon = 1e-5);
}
#[test]
fn test_large_dimension() {
let dim = 100;
let mut data = Vec::with_capacity(dim * 2);
for i in 0..dim {
data.push(i as f32);
}
for i in 0..dim {
data.push((dim - i) as f32);
}
let norm0 = data[0..dim].iter().map(|x| x * x).sum::<f32>().sqrt();
let norm1 = data[dim..].iter().map(|x| x * x).sum::<f32>().sqrt();
let vecs = TestVectors {
data,
dim,
norms: vec![norm0, norm1],
};
let dist_01 = vecs.euclidean_distance(0, 1);
let dist_10 = vecs.euclidean_distance(1, 0);
assert_relative_eq!(dist_01, dist_10, epsilon = 1e-3);
let cos_01 = vecs.cosine_distance(0, 1);
let cos_10 = vecs.cosine_distance(1, 0);
assert_relative_eq!(cos_01, cos_10, epsilon = 1e-5);
}
#[test]
fn test_euclidean_distance_to_query() {
let data = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
let vecs = TestVectors {
data,
dim: 3,
norms: vec![],
};
let query = vec![1.0, 1.0, 0.0];
let dist_0 = vecs.euclidean_distance_to_query(0, &query);
assert_relative_eq!(dist_0, 1.0, epsilon = 1e-6);
let dist_1 = vecs.euclidean_distance_to_query(1, &query);
assert_relative_eq!(dist_1, 1.0, epsilon = 1e-6);
}
#[test]
fn test_cosine_distance_to_query() {
let data = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
let norm0 = 1.0;
let norm1 = 1.0;
let norm2 = 2.0_f32.sqrt();
let vecs = TestVectors {
data,
dim: 3,
norms: vec![norm0, norm1, norm2],
};
let query = vec![1.0, 1.0, 0.0];
let query_norm = 2.0_f32.sqrt();
let dist_0 = vecs.cosine_distance_to_query(0, &query, query_norm);
assert_relative_eq!(dist_0, 1.0 - 1.0 / 2.0_f32.sqrt(), epsilon = 1e-6);
let dist_2 = vecs.cosine_distance_to_query(2, &query, query_norm);
assert_relative_eq!(dist_2, 0.0, epsilon = 1e-6);
}
#[test]
fn test_euclidean_distance_static() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![4.0, 5.0, 6.0];
let dist = euclidean_distance_static(&a, &b);
assert_relative_eq!(dist, 27.0, epsilon = 1e-6);
let dist_self = euclidean_distance_static(&a, &a);
assert_relative_eq!(dist_self, 0.0, epsilon = 1e-6);
}
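// Hedged check of the raw dot-product kernel that the cosine paths build on.
#[test]
fn test_dot_simd_f32() {
let a = vec![1.0_f32, 2.0, 3.0, 4.0];
let b = vec![4.0_f32, 3.0, 2.0, 1.0];
// 4 + 6 + 6 + 4 = 20
assert_relative_eq!(f32::dot_simd(&a, &b), 20.0, epsilon = 1e-6);
}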
#[test]
fn test_cosine_distance_static() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0];
let dist = cosine_distance_static(&a, &b);
assert_relative_eq!(dist, 1.0, epsilon = 1e-6);
let c = vec![2.0, 0.0, 0.0];
let dist_parallel = cosine_distance_static(&a, &c);
assert_relative_eq!(dist_parallel, 0.0, epsilon = 1e-6);
}
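// Hedged sketch: `cosine_distance_static_norm` with precomputed norms
// should agree with `cosine_distance_static`, which recomputes them.
#[test]
fn test_cosine_distance_static_norm() {
let a = vec![1.0_f32, 2.0, 2.0];
let b = vec![2.0_f32, 1.0, 2.0];
let norm_a = compute_l2_norm(&a); // sqrt(9) = 3
let norm_b = compute_l2_norm(&b); // sqrt(9) = 3
let dist = cosine_distance_static_norm(&a, &b, &norm_a, &norm_b);
assert_relative_eq!(dist, cosine_distance_static(&a, &b), epsilon = 1e-6);
// dot = 2 + 2 + 4 = 8, so the distance is 1 - 8/9.
assert_relative_eq!(dist, 1.0 - 8.0 / 9.0, epsilon = 1e-6);
}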
#[test]
fn test_normalise_vector() {
let mut vec = vec![3.0, 4.0, 0.0];
normalise_vector(&mut vec);
assert_relative_eq!(vec[0], 0.6, epsilon = 1e-6);
assert_relative_eq!(vec[1], 0.8, epsilon = 1e-6);
assert_relative_eq!(vec[2], 0.0, epsilon = 1e-6);
let norm = vec.iter().map(|&x| x * x).sum::<f32>().sqrt();
assert_relative_eq!(norm, 1.0, epsilon = 1e-6);
}
#[test]
fn test_normalise_vector_zero() {
let mut vec = vec![0.0, 0.0, 0.0];
normalise_vector(&mut vec);
assert_relative_eq!(vec[0], 0.0, epsilon = 1e-6);
assert_relative_eq!(vec[1], 0.0, epsilon = 1e-6);
assert_relative_eq!(vec[2], 0.0, epsilon = 1e-6);
}
#[test]
fn test_add_simd_f32_basic() {
let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b: Vec<f32> = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let result = f32::add_simd(&a, &b);
assert_eq!(result.len(), 5);
for i in 0..5 {
assert_relative_eq!(result[i], 6.0, epsilon = 1e-6);
}
}
#[test]
fn test_add_simd_f32_zeros() {
let a: Vec<f32> = vec![1.0, -2.0, 3.0];
let b: Vec<f32> = vec![0.0, 0.0, 0.0];
let result = f32::add_simd(&a, &b);
assert_relative_eq!(result[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(result[1], -2.0, epsilon = 1e-6);
assert_relative_eq!(result[2], 3.0, epsilon = 1e-6);
}
#[test]
fn test_add_simd_f32_negatives() {
let a: Vec<f32> = vec![1.0, 2.0, 3.0];
let b: Vec<f32> = vec![-1.0, -2.0, -3.0];
let result = f32::add_simd(&a, &b);
for val in &result {
assert_relative_eq!(*val, 0.0, epsilon = 1e-6);
}
}
#[test]
fn test_add_simd_f32_large_dimension() {
let dim = 128;
let a: Vec<f32> = (0..dim).map(|i| i as f32).collect();
let b: Vec<f32> = (0..dim).map(|i| (dim - i) as f32).collect();
let result = f32::add_simd(&a, &b);
assert_eq!(result.len(), dim);
for val in &result {
assert_relative_eq!(*val, dim as f32, epsilon = 1e-6);
}
}
#[test]
fn test_add_simd_f64_basic() {
let a: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b: Vec<f64> = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let result = f64::add_simd(&a, &b);
assert_eq!(result.len(), 5);
for i in 0..5 {
assert_relative_eq!(result[i], 6.0, epsilon = 1e-10);
}
}
#[test]
fn test_add_assign_simd_f32_basic() {
let mut dst: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let src: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0, 50.0];
f32::add_assign_simd(&mut dst, &src);
assert_relative_eq!(dst[0], 11.0, epsilon = 1e-6);
assert_relative_eq!(dst[1], 22.0, epsilon = 1e-6);
assert_relative_eq!(dst[2], 33.0, epsilon = 1e-6);
assert_relative_eq!(dst[3], 44.0, epsilon = 1e-6);
assert_relative_eq!(dst[4], 55.0, epsilon = 1e-6);
}
#[test]
fn test_add_assign_simd_f32_zeros() {
let mut dst: Vec<f32> = vec![1.0, 2.0, 3.0];
let src: Vec<f32> = vec![0.0, 0.0, 0.0];
f32::add_assign_simd(&mut dst, &src);
assert_relative_eq!(dst[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(dst[1], 2.0, epsilon = 1e-6);
assert_relative_eq!(dst[2], 3.0, epsilon = 1e-6);
}
#[test]
fn test_add_assign_simd_f32_accumulate() {
let mut dst: Vec<f32> = vec![0.0; 8];
let src: Vec<f32> = vec![1.0; 8];
for _ in 0..100 {
f32::add_assign_simd(&mut dst, &src);
}
for val in &dst {
assert_relative_eq!(*val, 100.0, epsilon = 1e-4);
}
}
#[test]
fn test_add_assign_simd_f32_large_dimension() {
let dim = 128;
let mut dst: Vec<f32> = vec![1.0; dim];
let src: Vec<f32> = vec![2.0; dim];
f32::add_assign_simd(&mut dst, &src);
for val in &dst {
assert_relative_eq!(*val, 3.0, epsilon = 1e-6);
}
}
#[test]
fn test_add_assign_simd_f64_basic() {
let mut dst: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let src: Vec<f64> = vec![10.0, 20.0, 30.0, 40.0, 50.0];
f64::add_assign_simd(&mut dst, &src);
assert_relative_eq!(dst[0], 11.0, epsilon = 1e-10);
assert_relative_eq!(dst[1], 22.0, epsilon = 1e-10);
assert_relative_eq!(dst[2], 33.0, epsilon = 1e-10);
assert_relative_eq!(dst[3], 44.0, epsilon = 1e-10);
assert_relative_eq!(dst[4], 55.0, epsilon = 1e-10);
}
#[test]
fn test_add_assign_simd_f64_accumulate() {
let mut dst: Vec<f64> = vec![0.0; 8];
let src: Vec<f64> = vec![1.0; 8];
for _ in 0..100 {
f64::add_assign_simd(&mut dst, &src);
}
for val in &dst {
assert_relative_eq!(*val, 100.0, epsilon = 1e-10);
}
}
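// Hedged sketch: `subtract_simd` is exercised indirectly by the ADC
// residual path; the expectations here assume it computes the
// element-wise difference a - b, mirroring `add_simd`.
#[test]
fn test_subtract_simd_f32_basic() {
let a: Vec<f32> = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let b: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let result = f32::subtract_simd(&a, &b);
assert_eq!(result.len(), 5);
let expected = [4.0, 2.0, 0.0, -2.0, -4.0];
for i in 0..5 {
assert_relative_eq!(result[i], expected[i], epsilon = 1e-6);
}
}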
#[cfg(feature = "quantised")]
mod quantised_tests {
use super::*;
struct TestVectorsSq8 {
data: Vec<i8>,
norms: Vec<i32>,
dim: usize,
}
impl VectorDistanceSq8<f32> for TestVectorsSq8 {
fn vectors_flat_quantised(&self) -> &[i8] {
&self.data
}
fn norms_quantised(&self) -> &[i32] {
&self.norms
}
fn dim(&self) -> usize {
self.dim
}
}
#[test]
fn test_euclidean_distance_i8() {
let data = vec![127, 0, 0, 0, 127, 0];
let vecs = TestVectorsSq8 {
data,
norms: vec![],
dim: 3,
};
let query = vec![127, 127, 0];
let dist = vecs.euclidean_distance_i8(0, &query);
assert_relative_eq!(dist, 16129.0, epsilon = 1e-3);
}
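// Hedged sketches of the bf16 paths. The inputs are small integers, which
// bf16 represents exactly; as with the f32 path, the euclidean kernels are
// assumed to return the squared distance.
#[test]
fn test_bf16_kernels_match_f32_reference() {
let a_f32 = [1.0_f32, 2.0, 3.0, 4.0];
let a: Vec<bf16> = a_f32.iter().map(|&x| bf16::from_f32(x)).collect();
let b = vec![4.0_f32, 3.0, 2.0, 1.0];
// dot = 4 + 6 + 6 + 4 = 20
assert_relative_eq!(dot_bf16_f32_simd(&a, &b), 20.0, epsilon = 1e-3);
let b_f64: Vec<f64> = b.iter().map(|&x| x as f64).collect();
assert_relative_eq!(dot_bf16_f64_simd(&a, &b_f64), 20.0, epsilon = 1e-3);
// squared distance = 9 + 1 + 1 + 9 = 20
assert_relative_eq!(f32::euclidean_bf16_dispatch(&a, &b), 20.0, epsilon = 1e-3);
}
// A minimal mock of `VectorDistanceBf16` over two orthogonal unit vectors;
// the struct and values are illustrative, not part of the library API.
struct TestVectorsBf16 {
data: Vec<bf16>,
norms: Vec<f32>,
dim: usize,
}
impl VectorDistanceBf16<f32> for TestVectorsBf16 {
fn vectors_flat(&self) -> &[bf16] {
&self.data
}
fn dim(&self) -> usize {
self.dim
}
fn norms(&self) -> &[f32] {
&self.norms
}
}
#[test]
fn test_bf16_trait_distances() {
let data: Vec<bf16> = [1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0]
.iter()
.map(|&x| bf16::from_f32(x))
.collect();
let vecs = TestVectorsBf16 {
data,
norms: vec![1.0, 1.0],
dim: 3,
};
// Orthogonal unit vectors: squared euclidean 2, cosine distance 1.
assert_relative_eq!(vecs.euclidean_distance_bf16(0, 1), 2.0, epsilon = 1e-3);
assert_relative_eq!(vecs.cosine_distance_bf16(0, 1), 1.0, epsilon = 1e-3);
let query = vec![1.0_f32, 0.0, 0.0];
assert_relative_eq!(
vecs.euclidean_distance_to_query_bf16(0, &query),
0.0,
epsilon = 1e-3
);
}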
#[test]
fn test_cosine_distance_i8() {
let data = vec![127, 0, 0, 0, 127, 0, 127, 127, 0];
let norm0 = 127 * 127;
let norm1 = 127 * 127;
let norm2 = 127 * 127 + 127 * 127;
let vecs = TestVectorsSq8 {
data,
norms: vec![norm0, norm1, norm2],
dim: 3,
};
let query = vec![127, 127, 0];
let query_norm_sq = 127 * 127 + 127 * 127;
let dist_0 = vecs.cosine_distance_i8(0, &query, query_norm_sq);
assert_relative_eq!(dist_0, 1.0 - 1.0 / 2.0_f32.sqrt(), epsilon = 1e-5);
let dist_2 = vecs.cosine_distance_i8(2, &query, query_norm_sq);
assert_relative_eq!(dist_2, 0.0, epsilon = 1e-5);
}
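// Hedged sketch of the ADC path: a minimal mock of `VectorDistanceAdc`
// with m = 2 subspaces, 2 centroids each, and subvec_dim = 2. The
// expectations assume `euclidean_simd` returns the squared distance (as
// the scalar f32 kernel does), so each table entry is the squared
// distance from a query sub-vector to a PQ centroid.
struct TestVectorsAdc {
codebooks: Vec<Vec<f32>>,
codes: Vec<u8>,
centroids: Vec<f32>,
}
impl VectorDistanceAdc<f32> for TestVectorsAdc {
fn codebook_m(&self) -> usize {
2
}
fn codebook_n_centroids(&self) -> usize {
2
}
fn codebook_subvec_dim(&self) -> usize {
2
}
fn centroids(&self) -> &[f32] {
&self.centroids
}
fn dim(&self) -> usize {
4
}
fn codebooks(&self) -> &[Vec<f32>] {
&self.codebooks
}
fn quantised_codes(&self) -> &[u8] {
&self.codes
}
}
#[test]
fn test_adc_lookup_table_distance() {
let vecs = TestVectorsAdc {
// Subspace 0: centroid 0 = [0, 0], centroid 1 = [1, 1].
// Subspace 1: centroid 0 = [2, 2], centroid 1 = [3, 3].
codebooks: vec![vec![0.0, 0.0, 1.0, 1.0], vec![2.0, 2.0, 3.0, 3.0]],
// One stored vector, decoding to [1, 1, 2, 2].
codes: vec![1, 0],
// A single zero coarse centroid for the residual path.
centroids: vec![0.0; 4],
};
let query = vec![1.0_f32, 1.0, 3.0, 3.0];
let table = vecs.build_lookup_tables_direct(&query);
assert_eq!(table.len(), 4);
// Subspace 0: [1,1] vs [1,1] -> 0; subspace 1: [3,3] vs [2,2] -> 2.
let dist = vecs.compute_distance_adc(0, &table);
assert_relative_eq!(dist, 2.0, epsilon = 1e-6);
// With a zero coarse centroid the residual path reduces to the direct one.
let table_res = vecs.build_lookup_tables_residual(&query, 0);
assert_eq!(table, table_res);
}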
}
}