use alloc::collections::BinaryHeap;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::{self, Debug};
use crate::vector::SQVec;
// Hand-written x86_64 kernels (the `*_avx2` functions called by the dispatching
// wrappers in this file). Only compiled for x86_64 builds with the `std`
// feature, which is also the only configuration in which the wrappers perform
// runtime AVX2 detection.
#[cfg(all(target_arch = "x86_64", feature = "std"))]
mod simd_x86;
/// Square root of `x` that works with and without the `std` feature.
///
/// With `std` this is simply `f32::sqrt`. Without it, the root is
/// approximated in software: a bit-level initial estimate is refined with
/// five Newton-Raphson steps.
///
/// Special cases of the software path: negative inputs and NaN give NaN,
/// zero (either sign) and `+inf` are returned unchanged, and subnormal inputs
/// are scaled into the normal range before the estimate is taken.
#[inline]
fn sqrt_f32(x: f32) -> f32 {
    #[cfg(feature = "std")]
    {
        x.sqrt()
    }
    #[cfg(not(feature = "std"))]
    {
        if x < 0.0 || x.is_nan() {
            return f32::NAN;
        }
        if x == 0.0 || x.is_infinite() {
            return x;
        }
        if x < f32::MIN_POSITIVE {
            // Subnormal: multiply by 2^24 to reach the normal range, take the
            // root there, then divide by sqrt(2^24) = 2^12 to undo the scaling.
            return sqrt_f32(x * 16_777_216.0) / 4096.0;
        }
        // Shifting the bit pattern right by one roughly halves the exponent;
        // the added constant restores the exponent bias.
        let seed = f32::from_bits((x.to_bits() >> 1) + 0x1FC0_0000);
        // Newton-Raphson: next = (current + x / current) / 2, applied five times.
        (0..5).fold(seed, |estimate, _| 0.5 * (estimate + x / estimate))
    }
}
/// Selects which distance function `DistanceMetric::compute` applies to a pair
/// of vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
/// Cosine distance, computed via `cosine_distance` (one minus the cosine similarity).
Cosine,
/// Squared Euclidean distance, computed via `euclidean_distance_sq`.
EuclideanSq,
/// Negated dot product, so that a larger dot product yields a smaller distance.
DotProduct,
/// Manhattan distance, computed via `manhattan_distance`.
Manhattan,
}
impl DistanceMetric {
    /// Computes the distance between `a` and `b` under this metric.
    ///
    /// Unlike the free distance functions, this never panics on mismatched
    /// input: slices of different lengths yield `f32::MAX`, and so does any
    /// computation whose result is NaN.
    ///
    /// `DotProduct` returns the *negated* dot product so that, as with the
    /// other metrics, smaller values mean "closer".
    #[inline]
    pub fn compute(&self, a: &[f32], b: &[f32]) -> f32 {
        // Guard first: the underlying kernels assert on length mismatch.
        if a.len() != b.len() {
            return f32::MAX;
        }
        let raw = match *self {
            Self::Cosine => cosine_distance(a, b),
            Self::EuclideanSq => euclidean_distance_sq(a, b),
            Self::DotProduct => -dot_product(a, b),
            Self::Manhattan => manhattan_distance(a, b),
        };
        if raw.is_nan() {
            f32::MAX
        } else {
            raw
        }
    }
}
/// Formats the metric as its lowercase snake_case name
/// (`cosine`, `euclidean_sq`, `dot_product`, `manhattan`).
impl fmt::Display for DistanceMetric {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Cosine => "cosine",
            Self::EuclideanSq => "euclidean_sq",
            Self::DotProduct => "dot_product",
            Self::Manhattan => "manhattan",
        };
        f.write_str(label)
    }
}
/// Returns the dot product of `a` and `b`.
///
/// On x86_64 builds with the `std` feature, the call is forwarded to
/// `simd_x86::dot_product_avx2` when the CPU reports AVX2 at runtime; in every
/// other case the portable `dot_product_scalar` is used.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "dot_product: dimension mismatch");
#[cfg(all(target_arch = "x86_64", feature = "std"))]
{
if is_x86_feature_detected!("avx2") {
// SAFETY: AVX2 support was confirmed on the line above and the slice
// lengths were asserted equal. NOTE(review): presumably these are the
// kernel's only preconditions; confirm against `simd_x86`.
return unsafe { simd_x86::dot_product_avx2(a, b) };
}
}
dot_product_scalar(a, b)
}
/// Portable dot product: multiplies the slices element-wise and sums the
/// products in index order. Pairs are formed with `zip`, so only the common
/// prefix is used if the lengths differ.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Returns the squared Euclidean distance between `a` and `b`
/// (no square root is taken).
///
/// On x86_64 builds with the `std` feature, the call is forwarded to
/// `simd_x86::euclidean_distance_sq_avx2` when the CPU reports AVX2 at runtime;
/// in every other case the portable `euclidean_distance_sq_scalar` is used.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn euclidean_distance_sq(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(
a.len(),
b.len(),
"euclidean_distance_sq: dimension mismatch"
);
#[cfg(all(target_arch = "x86_64", feature = "std"))]
{
if is_x86_feature_detected!("avx2") {
// SAFETY: AVX2 support was confirmed on the line above and the slice
// lengths were asserted equal. NOTE(review): presumably these are the
// kernel's only preconditions; confirm against `simd_x86`.
return unsafe { simd_x86::euclidean_distance_sq_avx2(a, b) };
}
}
euclidean_distance_sq_scalar(a, b)
}
/// Portable squared Euclidean distance: sums the squared element-wise
/// differences in index order. Pairs are formed with `zip`, so only the
/// common prefix is used if the lengths differ.
#[inline]
fn euclidean_distance_sq_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum()
}
/// Returns the cosine similarity of `a` and `b`.
///
/// On x86_64 builds with the `std` feature, the call is forwarded to
/// `simd_x86::cosine_similarity_avx2` when the CPU reports AVX2 at runtime; in
/// every other case the portable `cosine_similarity_scalar` is used. The scalar
/// path returns `0.0` when the product of the norms is zero and clamps the
/// result to `[-1, 1]`. NOTE(review): presumably the AVX2 kernel matches that
/// behaviour; confirm against `simd_x86`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "cosine_similarity: dimension mismatch");
#[cfg(all(target_arch = "x86_64", feature = "std"))]
{
if is_x86_feature_detected!("avx2") {
// SAFETY: AVX2 support was confirmed on the line above and the slice
// lengths were asserted equal. NOTE(review): presumably these are the
// kernel's only preconditions; confirm against `simd_x86`.
return unsafe { simd_x86::cosine_similarity_avx2(a, b) };
}
}
cosine_similarity_scalar(a, b)
}
/// Portable cosine similarity.
///
/// Accumulates the dot product and both squared norms in a single pass, then
/// divides the dot product by the product of the norms. Returns `0.0` when
/// that product is zero (e.g. either vector is all zeros); otherwise the
/// quotient is clamped to `[-1, 1]` to absorb rounding error.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Every index of `a` is paired with the same index of `b`; slicing up
    // front keeps the out-of-bounds panic when `b` is too short.
    let b = &b[..a.len()];
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    let denom = sqrt_f32(sq_norm_a) * sqrt_f32(sq_norm_b);
    if denom == 0.0 {
        return 0.0;
    }
    (dot / denom).clamp(-1.0, 1.0)
}
/// Returns the cosine distance between `a` and `b`, defined as one minus
/// their cosine similarity.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths (the check is performed by
/// `cosine_similarity`).
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let similarity = cosine_similarity(a, b);
    1.0 - similarity
}
/// Returns the Manhattan (L1) distance between `a` and `b`: the sum of the
/// absolute element-wise differences.
///
/// On x86_64 builds with the `std` feature, the call is forwarded to
/// `simd_x86::manhattan_distance_avx2` when the CPU reports AVX2 at runtime; in
/// every other case the portable `manhattan_distance_scalar` is used.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "manhattan_distance: dimension mismatch");
#[cfg(all(target_arch = "x86_64", feature = "std"))]
{
if is_x86_feature_detected!("avx2") {
// SAFETY: AVX2 support was confirmed on the line above and the slice
// lengths were asserted equal. NOTE(review): presumably these are the
// kernel's only preconditions; confirm against `simd_x86`.
return unsafe { simd_x86::manhattan_distance_avx2(a, b) };
}
}
manhattan_distance_scalar(a, b)
}
/// Portable Manhattan distance: sums the absolute element-wise differences in
/// index order. Pairs are formed with `zip`, so only the common prefix is
/// used if the lengths differ.
#[inline]
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let gaps = a.iter().zip(b).map(|(&lhs, &rhs)| (lhs - rhs).abs());
    gaps.sum()
}
/// Returns the number of differing bits between `a` and `b`.
///
/// Unlike the `f32` distance functions this does not panic on a length
/// mismatch: both inputs are truncated to the shorter length and any excess
/// bytes are ignored.
///
/// On x86_64 builds with the `std` feature, the call is forwarded to
/// `simd_x86::hamming_distance_avx2` when the CPU reports AVX2 at runtime; in
/// every other case the portable `hamming_distance_scalar` is used.
#[inline]
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
    // Compare only the prefix both inputs share.
    let shared = usize::min(a.len(), b.len());
    let (a, b) = (&a[..shared], &b[..shared]);
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was confirmed on the line above and both
            // slices were cut to the same length. NOTE(review): presumably
            // these are the kernel's only preconditions; confirm against
            // `simd_x86`.
            return unsafe { simd_x86::hamming_distance_avx2(a, b) };
        }
    }
    hamming_distance_scalar(a, b)
}
/// Portable Hamming distance: XORs the slices byte by byte and counts the set
/// bits. Pairs are formed with `zip`, so only the common prefix is used if
/// the lengths differ.
#[inline]
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    let mut differing_bits = 0u32;
    for (&lhs, &rhs) in a.iter().zip(b) {
        differing_bits += (lhs ^ rhs).count_ones();
    }
    differing_bits
}
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of the
/// squared components.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sqrt_f32(sum_of_squares)
}
/// Scales `v` in place to unit L2 length.
///
/// * If the norm is finite and positive, every component is divided by it.
/// * If the norm is finite but zero, `v` is left unchanged.
/// * If the norm is not finite (for example the squares overflowed), the
///   components are first divided by the largest absolute value found, which
///   brings them into a range where the norm can be computed, and then
///   normalized again. Nothing is changed when that largest value is zero or
///   itself not finite.
#[inline]
pub fn l2_normalize(v: &mut [f32]) {
    /// Multiplies every component by `factor`.
    fn scale_by(values: &mut [f32], factor: f32) {
        values.iter_mut().for_each(|x| *x *= factor);
    }

    let norm = l2_norm(v);
    if norm.is_finite() {
        if norm > 0.0 {
            scale_by(v, 1.0 / norm);
        }
        return;
    }

    // Fallback for a non-finite norm. NaN components never compare greater,
    // so they are ignored when looking for the largest magnitude.
    let peak = v
        .iter()
        .map(|x| x.abs())
        .fold(0.0f32, |best, magnitude| if magnitude > best { magnitude } else { best });
    if peak == 0.0 || !peak.is_finite() {
        return;
    }
    scale_by(v, 1.0 / peak);

    let rescaled_norm = l2_norm(v);
    if rescaled_norm.is_finite() && rescaled_norm > 0.0 {
        scale_by(v, 1.0 / rescaled_norm);
    }
}
/// Returns a copy of `v` normalized with `l2_normalize`; the input is not
/// modified.
#[inline]
pub fn l2_normalized(v: &[f32]) -> Vec<f32> {
    let mut unit = Vec::from(v);
    l2_normalize(&mut unit);
    unit
}
/// Packs the signs of `v` into a bit vector, one bit per component.
///
/// A bit is set when the component is strictly greater than zero; zero,
/// negative and NaN components leave it clear. Bits are stored most
/// significant first: component `i` maps to bit `7 - (i % 8)` of byte `i / 8`.
/// The result has `ceil(v.len() / 8)` bytes, with unused trailing bits zero.
pub fn quantize_binary(v: &[f32]) -> Vec<u8> {
    let mut packed = vec![0u8; v.len().div_ceil(8)];
    // Each output byte covers one run of up to eight consecutive components.
    for (byte, lanes) in packed.iter_mut().zip(v.chunks(8)) {
        for (offset, &component) in lanes.iter().enumerate() {
            if component > 0.0 {
                *byte |= 0x80u8 >> offset;
            }
        }
    }
    packed
}
/// Quantizes `v` to one byte per component using a linear min/max mapping.
///
/// The smallest component maps to code 0 and the largest to code 255; values
/// in between are rounded to the nearest code. NaN components are ignored
/// when determining the range.
///
/// Degenerate inputs:
/// * If `max - min` is not finite (empty input, only NaNs, infinities, or an
///   overflowing range), the result has `min_val = max_val = 0.0` and all
///   codes zero.
/// * If the range is smaller than `f32::MIN_POSITIVE`, or `255 / range`
///   overflows, the codes stay zero while `min_val`/`max_val` keep the
///   observed extremes.
pub fn quantize_scalar<const N: usize>(v: &[f32; N]) -> SQVec<N> {
    // Explicit comparisons (rather than `f32::min`/`max`) so that NaN is
    // skipped and the first-seen zero keeps its sign.
    let (lowest, highest) = v.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), &x| (if x < lo { x } else { lo }, if x > hi { x } else { hi }),
    );

    let span = highest - lowest;
    if !span.is_finite() {
        return SQVec {
            min_val: 0.0,
            max_val: 0.0,
            codes: [0u8; N],
        };
    }

    let mut codes = [0u8; N];
    if span >= f32::MIN_POSITIVE {
        let codes_per_unit = 255.0 / span;
        if codes_per_unit.is_finite() {
            for (code, &x) in codes.iter_mut().zip(v.iter()) {
                // Adding 0.5 before the truncating cast rounds to nearest.
                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
                let quantized = ((x - lowest) * codes_per_unit + 0.5) as u8;
                *code = quantized;
            }
        }
    }

    SQVec {
        min_val: lowest,
        max_val: highest,
        codes,
    }
}
/// Reconstructs an `f32` array from a scalar-quantized vector.
///
/// Thin free-function wrapper that forwards to `sq.dequantize()`.
#[inline]
pub fn dequantize_scalar<const N: usize>(sq: &SQVec<N>) -> [f32; N] {
sq.dequantize()
}
/// Squared Euclidean distance between a full-precision `query` and a
/// scalar-quantized vector, computed without materializing the dequantized
/// array.
///
/// Each code is expanded on the fly as `min_val + code * (range / 255)`.
/// When the stored range is exactly zero, every component is treated as
/// `min_val`. Returns `f32::MAX` if the stored range is not finite or the
/// result is NaN.
#[inline]
pub fn sq_euclidean_distance_sq<const N: usize>(query: &[f32; N], sq: &SQVec<N>) -> f32 {
    let span = sq.max_val - sq.min_val;
    if !span.is_finite() {
        return f32::MAX;
    }

    let total = if span == 0.0 {
        // Constant vector: every component equals `min_val`.
        query
            .iter()
            .map(|&q| {
                let delta = q - sq.min_val;
                delta * delta
            })
            .sum::<f32>()
    } else {
        let step = span / 255.0;
        query
            .iter()
            .zip(sq.codes.iter())
            .fold(0.0f32, |acc, (&q, &code)| {
                let delta = q - (sq.min_val + f32::from(code) * step);
                acc + delta * delta
            })
    };

    if total.is_nan() {
        f32::MAX
    } else {
        total
    }
}
/// Dot product of a full-precision `query` with a scalar-quantized vector,
/// computed without materializing the dequantized array.
///
/// Each code is expanded on the fly as `min_val + code * (range / 255)`.
/// When the stored range is exactly zero, every component is treated as
/// `min_val`, so the result is `sum(query) * min_val`. Returns `0.0` if the
/// stored range is not finite or the result is NaN.
#[inline]
pub fn sq_dot_product<const N: usize>(query: &[f32; N], sq: &SQVec<N>) -> f32 {
    let span = sq.max_val - sq.min_val;
    if !span.is_finite() {
        return 0.0;
    }

    let total = if span == 0.0 {
        // Constant vector: factor `min_val` out of the sum.
        query.iter().sum::<f32>() * sq.min_val
    } else {
        let step = span / 255.0;
        query
            .iter()
            .zip(sq.codes.iter())
            .fold(0.0f32, |acc, (&q, &code)| {
                acc + q * (sq.min_val + f32::from(code) * step)
            })
    };

    if total.is_nan() {
        0.0
    } else {
        total
    }
}
/// A search hit: a caller-supplied key together with its distance to the
/// query.
///
/// Equality and ordering look **only** at `distance`; the key is ignored.
/// Distances are compared bit-for-bit for equality and with
/// [`f32::total_cmp`] for ordering, so the order is total even in the
/// presence of NaN (a positive NaN sorts after every other value).
#[derive(Debug, Clone)]
pub struct Neighbor<K> {
    /// Identifier of the candidate, as yielded by the input iterator.
    pub key: K,
    /// Distance between the query and the candidate.
    pub distance: f32,
}

impl<K> PartialEq for Neighbor<K> {
    fn eq(&self, other: &Self) -> bool {
        // Bitwise comparison keeps `Eq` consistent with `total_cmp`, which
        // returns `Equal` only for identical bit patterns.
        self.distance.to_bits() == other.distance.to_bits()
    }
}

impl<K> Eq for Neighbor<K> {}

impl<K> PartialOrd for Neighbor<K> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<K> Ord for Neighbor<K> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.distance.total_cmp(&other.distance)
    }
}

/// Returns the `k` candidates closest to `query`, nearest first.
///
/// `iter` yields `(key, vector)` pairs and `distance_fn(query, vector)` is
/// evaluated exactly once per candidate, in iteration order. Fewer than `k`
/// results are returned when the iterator is shorter than `k`. When `k` is
/// zero the iterator is not consumed and an empty vector is returned.
///
/// `k` may be arbitrarily large (including `usize::MAX` to mean "all"):
/// memory use is bounded by the number of candidates actually seen.
pub fn nearest_k<K, I, F>(iter: I, query: &[f32], k: usize, distance_fn: F) -> Vec<Neighbor<K>>
where
    I: Iterator<Item = (K, Vec<f32>)>,
    F: Fn(&[f32], &[f32]) -> f32,
{
    collect_nearest(
        iter.map(|(key, vec)| (key, distance_fn(query, &vec))),
        k,
    )
}

/// Fixed-dimension variant of [`nearest_k`] for candidates stored as
/// `[f32; N]` arrays.
///
/// Behaves exactly like [`nearest_k`]; the query and each candidate are
/// passed to `distance_fn` as slices.
pub fn nearest_k_fixed<K, I, F, const N: usize>(
    iter: I,
    query: &[f32; N],
    k: usize,
    distance_fn: F,
) -> Vec<Neighbor<K>>
where
    I: Iterator<Item = (K, [f32; N])>,
    F: Fn(&[f32], &[f32]) -> f32,
{
    let query = query.as_slice();
    collect_nearest(
        iter.map(|(key, vec)| (key, distance_fn(query, vec.as_slice()))),
        k,
    )
}

/// Keeps the `k` smallest-distance entries of `scored` and returns them in
/// ascending distance order. Shared implementation of the `nearest_k*`
/// functions.
fn collect_nearest<K, S>(scored: S, k: usize) -> Vec<Neighbor<K>>
where
    S: Iterator<Item = (K, f32)>,
{
    if k == 0 {
        return Vec::new();
    }

    // The heap never holds more than `k` entries (the current worst is popped
    // before a better one is pushed), and it cannot usefully be pre-sized
    // beyond what the iterator promises to yield. Reserving `k + 1` up front,
    // as was done previously, overflowed for `k == usize::MAX` and aborted
    // with a capacity overflow or a huge allocation for any very large `k`.
    let capacity = k.min(scored.size_hint().0);
    let mut heap: BinaryHeap<Neighbor<K>> = BinaryHeap::with_capacity(capacity);

    for (key, distance) in scored {
        if heap.len() < k {
            heap.push(Neighbor { key, distance });
        } else if heap
            .peek()
            .is_some_and(|worst| distance.total_cmp(&worst.distance).is_lt())
        {
            // Max-heap: the root is the farthest neighbour kept so far.
            heap.pop();
            heap.push(Neighbor { key, distance });
        }
    }

    let mut results: Vec<Neighbor<K>> = heap.into_vec();
    results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
    results
}
/// Serializes `values` into `dest` as consecutive little-endian `f32`s.
///
/// Writes `min(dest.len() / 4, values.len())` floats. Any remaining bytes of
/// `dest` (including a trailing partial 4-byte group) are left untouched, and
/// values that do not fit are silently dropped.
#[inline]
pub fn write_f32_le(dest: &mut [u8], values: &[f32]) {
    // `chunks_exact_mut` yields only whole 4-byte groups and `zip` stops at
    // the shorter side, which together give the count described above on
    // every target endianness.
    for (slot, value) in dest.chunks_exact_mut(4).zip(values) {
        slot.copy_from_slice(&value.to_le_bytes());
    }
}
/// Decodes `src` as consecutive little-endian `f32`s.
///
/// Returns `src.len() / 4` values; up to three trailing bytes that do not
/// form a complete float are ignored.
#[inline]
pub fn read_f32_le(src: &[u8]) -> Vec<f32> {
    // `chunks_exact` yields only whole 4-byte groups, so indexing 0..=3 is
    // always in bounds; its exact size hint lets `collect` allocate once.
    src.chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect()
}
// Unit tests: parity between the dispatching public functions and their scalar
// references, known-value and edge-case checks, panic behaviour on dimension
// mismatch, `DistanceMetric::compute` sentinels, and byte (de)serialization.
#[cfg(test)]
#[allow(
clippy::float_cmp,
clippy::cast_precision_loss,
clippy::cast_possible_truncation
)]
mod tests {
use super::*;
// Vector lengths exercised by the f32 parity tests. NOTE(review): presumably
// chosen to sit on and around SIMD lane-width boundaries (8, 16, 32) — confirm.
const DIMS: &[usize] = &[1, 7, 8, 15, 16, 31, 32, 128, 384, 768];
// Builds two deterministic, different vectors of length `dim`.
fn make_vecs(dim: usize) -> (Vec<f32>, Vec<f32>) {
let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1 - 5.0).collect();
let b: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.2 + 1.0).collect();
(a, b)
}
// Asserts |actual - expected| < tol * max(|expected|, 1): a relative tolerance
// for large magnitudes that degrades to an absolute one near zero.
fn assert_close(actual: f32, expected: f32, tol: f32, label: &str, dim: usize) {
let diff = (actual - expected).abs();
let scale = expected.abs().max(1.0);
assert!(
diff < tol * scale,
"{label} dim={dim}: expected={expected}, actual={actual}, diff={diff}"
);
}
// --- Parity: public (possibly SIMD-dispatched) entry point vs scalar reference ---
#[test]
fn dot_product_matches_scalar() {
for &dim in DIMS {
let (a, b) = make_vecs(dim);
let scalar = dot_product_scalar(&a, &b);
let result = dot_product(&a, &b);
assert_close(result, scalar, 1e-5, "dot_product", dim);
}
}
#[test]
fn euclidean_distance_sq_matches_scalar() {
for &dim in DIMS {
let (a, b) = make_vecs(dim);
let scalar = euclidean_distance_sq_scalar(&a, &b);
let result = euclidean_distance_sq(&a, &b);
assert_close(result, scalar, 1e-5, "euclidean_distance_sq", dim);
}
}
#[test]
fn cosine_similarity_matches_scalar() {
for &dim in DIMS {
let (a, b) = make_vecs(dim);
let scalar = cosine_similarity_scalar(&a, &b);
let result = cosine_similarity(&a, &b);
assert_close(result, scalar, 1e-5, "cosine_similarity", dim);
}
}
#[test]
fn manhattan_distance_matches_scalar() {
for &dim in DIMS {
let (a, b) = make_vecs(dim);
let scalar = manhattan_distance_scalar(&a, &b);
let result = manhattan_distance(&a, &b);
assert_close(result, scalar, 1e-5, "manhattan_distance", dim);
}
}
// Integer result, so the comparison with the scalar reference is exact.
#[test]
fn hamming_distance_matches_scalar() {
for dim in [1usize, 7, 8, 15, 16, 31, 32, 64, 128, 256] {
let a: Vec<u8> = (0..dim).map(|i| (i * 37 + 13) as u8).collect();
let b: Vec<u8> = (0..dim).map(|i| (i * 53 + 7) as u8).collect();
let scalar = hamming_distance_scalar(&a, &b);
let result = hamming_distance(&a, &b);
assert_eq!(
result, scalar,
"hamming_distance dim={dim}: scalar={scalar}, simd={result}"
);
}
}
// --- Known values and edge cases ---
#[test]
fn dot_product_zero_vectors() {
let a = vec![0.0f32; 128];
let b = vec![0.0f32; 128];
assert_eq!(dot_product(&a, &b), 0.0);
}
// A zero-norm operand must yield similarity 0.0 rather than NaN.
#[test]
fn cosine_similarity_zero_vector() {
let a = vec![0.0f32; 32];
let b = vec![1.0f32; 32];
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
#[test]
fn cosine_similarity_identical() {
let a: Vec<f32> = (0..64).map(|i| (i as f32) * 0.3 + 0.1).collect();
let result = cosine_similarity(&a, &a);
assert!(
(result - 1.0).abs() < 1e-6,
"identical vectors: sim={result}"
);
}
#[test]
fn cosine_similarity_opposite() {
let a: Vec<f32> = (0..64).map(|i| (i as f32) * 0.3 + 0.1).collect();
let b: Vec<f32> = a.iter().map(|x| -x).collect();
let result = cosine_similarity(&a, &b);
assert!(
(result - (-1.0)).abs() < 1e-6,
"opposite vectors: sim={result}"
);
}
// All bits differ: 32 bytes * 8 bits.
#[test]
fn hamming_distance_known_pattern() {
let a = vec![0xFF_u8; 32];
let b = vec![0x00_u8; 32];
assert_eq!(hamming_distance(&a, &b), 32 * 8);
}
#[test]
fn hamming_distance_identical() {
let a = vec![0xAB_u8; 64];
assert_eq!(hamming_distance(&a, &a), 0);
}
#[test]
fn euclidean_distance_sq_identical() {
let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
assert_eq!(euclidean_distance_sq(&a, &a), 0.0);
}
#[test]
fn manhattan_distance_identical() {
let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
assert_eq!(manhattan_distance(&a, &a), 0.0);
}
// --- The free functions panic on mismatched lengths ---
#[test]
#[should_panic(expected = "dimension mismatch")]
fn dot_product_dimension_mismatch_panics() {
let a = vec![1.0f32; 10];
let b = vec![1.0f32; 11];
dot_product(&a, &b);
}
#[test]
#[should_panic(expected = "dimension mismatch")]
fn euclidean_dimension_mismatch_panics() {
let a = vec![1.0f32; 10];
let b = vec![1.0f32; 11];
euclidean_distance_sq(&a, &b);
}
// --- `DistanceMetric::compute` reports f32::MAX instead of NaN or a panic ---
#[test]
fn distance_metric_nan_returns_max() {
let a = [1.0f32, f32::NAN, 3.0];
let b = [4.0f32, 5.0, 6.0];
let d = DistanceMetric::EuclideanSq.compute(&a, &b);
assert_eq!(d, f32::MAX);
}
#[test]
fn distance_metric_mismatch_returns_max() {
let a = [1.0f32, 2.0];
let b = [1.0f32, 2.0, 3.0];
let d = DistanceMetric::Cosine.compute(&a, &b);
assert_eq!(d, f32::MAX);
}
// --- Serialization round trip ---
#[test]
fn write_read_f32_le_roundtrip() {
let values: Vec<f32> = (0..100).map(|i| (i as f32) * 0.123 - 6.0).collect();
let mut buf = vec![0u8; values.len() * 4];
write_f32_le(&mut buf, &values);
let decoded = read_f32_le(&buf);
assert_eq!(decoded, values);
}
}