use alloc::collections::BinaryHeap;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::{self, Debug};
use crate::vector::SQVec;
/// Square root of `x` that works with and without the `std` feature.
///
/// With `std` enabled this is simply `f32::sqrt`. Without it, the root is
/// approximated with a bit-level initial guess refined by five Newton-Raphson
/// steps.
///
/// Special cases of the fallback: negative or NaN input yields NaN; zero
/// (either sign) and positive infinity are returned unchanged.
#[inline]
fn sqrt_f32(x: f32) -> f32 {
    #[cfg(feature = "std")]
    {
        x.sqrt()
    }
    #[cfg(not(feature = "std"))]
    {
        // 2^24 lifts every subnormal into the normal range without rounding.
        const SUBNORMAL_SCALE: f32 = 16_777_216.0;
        // sqrt(2^24) = 2^12, so the root of the scaled value is undone with 2^-12.
        const SUBNORMAL_UNSCALE: f32 = 1.0 / 4096.0;

        if x < 0.0 || x.is_nan() {
            return f32::NAN;
        }
        if x == 0.0 || x.is_infinite() {
            return x;
        }

        // The bit-shift seed below halves the biased exponent, which is only
        // meaningful for normal floats. Subnormals have no exponent to halve,
        // leaving the seed orders of magnitude too large for five Newton steps
        // to recover from, so rescale them first.
        let (scaled, unscale) = if x < f32::MIN_POSITIVE {
            (x * SUBNORMAL_SCALE, SUBNORMAL_UNSCALE)
        } else {
            (x, 1.0)
        };

        // Halve the bit pattern and add back half of the pattern of 1.0
        // (0x3F80_0000 >> 1 == 0x1FC0_0000) to get a seed close to the root.
        let seed_bits = (scaled.to_bits() >> 1) + 0x1FC0_0000;
        let mut guess = f32::from_bits(seed_bits);

        // Newton-Raphson: g <- (g + v / g) / 2.
        for _ in 0..5 {
            guess = 0.5 * (guess + scaled / guess);
        }
        guess * unscale
    }
}
/// Distance measures that can be selected at runtime and evaluated through
/// [`DistanceMetric::compute`]. Smaller results mean "closer" for every variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
/// One minus the cosine similarity of the two vectors.
Cosine,
/// Squared Euclidean distance (no square root is taken).
EuclideanSq,
/// Negated dot product, so that a larger dot product yields a smaller distance.
DotProduct,
/// Manhattan distance: the sum of absolute coordinate differences.
Manhattan,
}
impl DistanceMetric {
#[inline]
pub fn compute(&self, a: &[f32], b: &[f32]) -> f32 {
match self {
Self::Cosine => cosine_distance(a, b),
Self::EuclideanSq => euclidean_distance_sq(a, b),
Self::DotProduct => -dot_product(a, b),
Self::Manhattan => manhattan_distance(a, b),
}
}
}
impl fmt::Display for DistanceMetric {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Cosine => f.write_str("cosine"),
Self::EuclideanSq => f.write_str("euclidean_sq"),
Self::DotProduct => f.write_str("dot_product"),
Self::Manhattan => f.write_str("manhattan"),
}
}
}
/// Returns the dot product of `a` and `b`: the sum of element-wise products.
///
/// # Panics
///
/// Panics if the slices have different lengths.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a, len_b,
        "dot_product: dimension mismatch ({} vs {})",
        len_a, len_b
    );
    a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
}
/// Returns the squared Euclidean distance between `a` and `b`: the sum of
/// squared coordinate differences, without taking a square root.
///
/// # Panics
///
/// Panics if the slices have different lengths.
#[inline]
pub fn euclidean_distance_sq(a: &[f32], b: &[f32]) -> f32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a, len_b,
        "euclidean_distance_sq: dimension mismatch ({} vs {})",
        len_a, len_b
    );
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum()
}
/// Returns the cosine similarity of `a` and `b`: their dot product divided by
/// the product of their Euclidean norms.
///
/// If that product of norms is exactly zero (for example, one input is the
/// zero vector) the result is `0.0` instead of a division by zero.
///
/// # Panics
///
/// Panics if the slices have different lengths.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a, len_b,
        "cosine_similarity: dimension mismatch ({} vs {})",
        len_a, len_b
    );

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let denom = sqrt_f32(norm_sq_a) * sqrt_f32(norm_sq_b);
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
/// Returns the cosine distance between `a` and `b`, defined as one minus
/// their cosine similarity.
///
/// # Panics
///
/// Panics if the slices have different lengths.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let similarity = cosine_similarity(a, b);
    1.0 - similarity
}
/// Returns the Manhattan (L1) distance between `a` and `b`: the sum of the
/// absolute coordinate differences.
///
/// # Panics
///
/// Panics if the slices have different lengths.
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a, len_b,
        "manhattan_distance: dimension mismatch ({} vs {})",
        len_a, len_b
    );
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta.abs()
        })
        .sum()
}
/// Returns the Hamming distance between two byte strings: the number of bit
/// positions in which they differ.
///
/// # Panics
///
/// Panics if the slices have different lengths.
#[inline]
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a, len_b,
        "hamming_distance: length mismatch ({} vs {})",
        len_a, len_b
    );
    // XOR leaves a 1 in every differing bit position; count them per byte.
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| (lhs ^ rhs).count_ones())
        .sum()
}
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of its
/// squared components.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sqrt_f32(sum_of_squares)
}
/// Scales `v` in place so that its Euclidean norm becomes one.
///
/// The slice is left untouched when its norm is not strictly positive
/// (zero vector, or a NaN norm).
#[inline]
pub fn l2_normalize(v: &mut [f32]) {
    let magnitude = l2_norm(v);
    // Nothing sensible to scale by: bail out and keep the input as it is.
    if magnitude <= 0.0 || magnitude.is_nan() {
        return;
    }
    // Multiply by the reciprocal once instead of dividing every component.
    let reciprocal = 1.0 / magnitude;
    v.iter_mut()
        .for_each(|component| *component *= reciprocal);
}
/// Returns a newly allocated copy of `v` normalized with [`l2_normalize`];
/// the input slice is not modified.
#[inline]
pub fn l2_normalized(v: &[f32]) -> Vec<f32> {
    let mut unit = Vec::from(v);
    l2_normalize(unit.as_mut_slice());
    unit
}
/// Packs the sign pattern of `v` into bytes, one bit per component.
///
/// A bit is set when the component is strictly greater than zero. Bits are
/// filled most-significant first, so component `0` lands in bit 7 of byte `0`.
/// The output holds `ceil(v.len() / 8)` bytes; unused trailing bits are zero.
pub fn quantize_binary(v: &[f32]) -> Vec<u8> {
    v.chunks(8)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .fold(0u8, |packed, (position, &value)| {
                    if value > 0.0 {
                        // 0x80 is bit 7; shifting right walks towards bit 0.
                        packed | (0x80u8 >> position)
                    } else {
                        packed
                    }
                })
        })
        .collect()
}
/// Quantizes `v` to one byte per component.
///
/// The smallest and largest components become `min_val` and `max_val`, and
/// every component is mapped linearly onto `0..=255` within that range,
/// rounding to the nearest code. When all components are equal (the range is
/// not strictly positive) every code is left at zero.
pub fn quantize_scalar<const N: usize>(v: &[f32; N]) -> SQVec<N> {
    // Single pass for both extremes; comparisons with NaN are false, so NaN
    // components never replace the running bounds.
    let (lowest, highest) = v.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lowest, highest), &x| {
            (
                if x < lowest { x } else { lowest },
                if x > highest { x } else { highest },
            )
        },
    );

    let mut codes = [0u8; N];
    let span = highest - lowest;
    if span > 0.0 {
        let codes_per_unit = 255.0 / span;
        for (code, &x) in codes.iter_mut().zip(v.iter()) {
            // Adding 0.5 before the truncating cast rounds to the nearest code.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let quantized = ((x - lowest) * codes_per_unit + 0.5) as u8;
            *code = quantized;
        }
    }

    SQVec {
        min_val: lowest,
        max_val: highest,
        codes,
    }
}
/// Free-function wrapper that forwards to `sq.dequantize()` and returns its
/// `[f32; N]` result unchanged.
///
/// NOTE(review): presumably this is the inverse of `quantize_scalar` (up to
/// quantization error); the exact reconstruction is defined by `SQVec` in
/// `crate::vector` — confirm there.
#[inline]
pub fn dequantize_scalar<const N: usize>(sq: &SQVec<N>) -> [f32; N] {
sq.dequantize()
}
/// Squared Euclidean distance between a full-precision `query` and a
/// scalar-quantized vector, computed without materializing the dequantized
/// array.
///
/// Each code is expanded on the fly as `min_val + code * (max_val - min_val) / 255`.
/// When `max_val == min_val` every component is taken to be `min_val`.
#[inline]
pub fn sq_euclidean_distance_sq<const N: usize>(query: &[f32; N], sq: &SQVec<N>) -> f32 {
    let base = sq.min_val;
    let span = sq.max_val - base;

    // Degenerate range: the stored vector is constant, no codes needed.
    if span == 0.0 {
        return query
            .iter()
            .map(|&q| {
                let residual = q - base;
                residual * residual
            })
            .sum();
    }

    let step = span / 255.0;
    query
        .iter()
        .zip(sq.codes.iter())
        .fold(0.0f32, |acc, (&q, &code)| {
            let reconstructed = base + f32::from(code) * step;
            let residual = q - reconstructed;
            acc + residual * residual
        })
}
/// Dot product of a full-precision `query` with a scalar-quantized vector,
/// computed without materializing the dequantized array.
///
/// Each code is expanded on the fly as `min_val + code * (max_val - min_val) / 255`.
/// When `max_val == min_val` the result is the sum of the query components
/// multiplied by `min_val`.
#[inline]
pub fn sq_dot_product<const N: usize>(query: &[f32; N], sq: &SQVec<N>) -> f32 {
    let base = sq.min_val;
    let span = sq.max_val - base;

    // Degenerate range: the stored vector is constant, so factor it out.
    if span == 0.0 {
        let query_total: f32 = query.iter().sum();
        return query_total * base;
    }

    let step = span / 255.0;
    query
        .iter()
        .zip(sq.codes.iter())
        .fold(0.0f32, |acc, (&q, &code)| {
            let reconstructed = base + f32::from(code) * step;
            acc + q * reconstructed
        })
}
/// One result of a nearest-neighbour search: the key that was yielded
/// alongside a vector, and that vector's distance from the query.
///
/// All comparisons (`==`, `<`, sorting, heap ordering) look only at
/// `distance`; `key` is ignored. The ordering is total: ordinary distances
/// compare numerically (`-0.0` equals `0.0`), while NaN distances compare
/// equal to each other and greater than every non-NaN distance.
#[derive(Debug, Clone)]
pub struct Neighbor<K> {
    /// Caller-supplied key that accompanied the vector.
    pub key: K,
    /// Value returned by the distance function for this vector.
    pub distance: f32,
}

impl<K> PartialEq for Neighbor<K> {
    /// Equality is defined through [`Ord::cmp`] so that `Eq` and `Ord` agree
    /// (in particular, a NaN distance equals itself).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == core::cmp::Ordering::Equal
    }
}

impl<K> Eq for Neighbor<K> {}

impl<K> PartialOrd for Neighbor<K> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<K> Ord for Neighbor<K> {
    /// Orders by distance, placing NaN after every other value.
    ///
    /// Treating NaN as "equal to everything" would not be transitive and
    /// would violate the total-order contract that `BinaryHeap` and the
    /// sorting routines rely on.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.distance
            .partial_cmp(&other.distance)
            // `None` means at least one side is NaN: rank by "is NaN" so that
            // NaN > non-NaN and NaN == NaN.
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
    }
}

/// Keeps the `k` smallest-distance entries of `scored` and returns them
/// sorted by ascending distance (NaN distances last).
///
/// Shared implementation of [`nearest_k`] and [`nearest_k_fixed`]. When `k`
/// is zero the iterator is not advanced at all.
fn collect_nearest<K, S>(scored: S, k: usize) -> Vec<Neighbor<K>>
where
    S: Iterator<Item = (K, f32)>,
{
    if k == 0 {
        return Vec::new();
    }

    // The heap never holds more than `k` entries. Bounding the reservation by
    // the iterator's lower size hint avoids overflow and huge allocations
    // when a caller passes a very large `k` (e.g. `usize::MAX` for "all").
    let reserve = k.min(scored.size_hint().0);
    let mut heap: BinaryHeap<Neighbor<K>> = BinaryHeap::with_capacity(reserve);

    for (key, distance) in scored {
        let candidate = Neighbor { key, distance };
        if heap.len() < k {
            heap.push(candidate);
        } else if heap.peek().is_some_and(|worst| candidate < *worst) {
            // Max-heap: the top is the current worst of the best `k`.
            // Comparing through `Ord` lets a real distance evict a NaN.
            heap.pop();
            heap.push(candidate);
        }
    }

    let mut results = heap.into_vec();
    // Stable sort using the total order defined on `Neighbor`.
    results.sort();
    results
}

/// Returns the `k` entries of `iter` closest to `query`, sorted by ascending
/// distance.
///
/// `distance_fn(query, vector)` is called once per yielded entry, in
/// iteration order. Fewer than `k` results are returned when the iterator
/// yields fewer than `k` entries. Entries whose distance is NaN rank last and
/// are only returned when fewer than `k` other entries exist. `k == 0`
/// returns an empty vector without advancing the iterator.
pub fn nearest_k<K, I, F>(iter: I, query: &[f32], k: usize, distance_fn: F) -> Vec<Neighbor<K>>
where
    I: Iterator<Item = (K, Vec<f32>)>,
    F: Fn(&[f32], &[f32]) -> f32,
{
    collect_nearest(
        iter.map(|(key, vec)| (key, distance_fn(query, &vec))),
        k,
    )
}

/// Fixed-dimension variant of [`nearest_k`]: vectors and query are `[f32; N]`
/// arrays and are passed to `distance_fn` as slices.
///
/// Ordering, NaN handling and the `k == 0` behaviour are the same as for
/// [`nearest_k`].
pub fn nearest_k_fixed<K, I, F, const N: usize>(
    iter: I,
    query: &[f32; N],
    k: usize,
    distance_fn: F,
) -> Vec<Neighbor<K>>
where
    I: Iterator<Item = (K, [f32; N])>,
    F: Fn(&[f32], &[f32]) -> f32,
{
    collect_nearest(
        iter.map(|(key, vec)| (key, distance_fn(query.as_slice(), vec.as_slice()))),
        k,
    )
}
/// Serializes `values` into `dest` as consecutive little-endian 4-byte floats.
///
/// # Panics
///
/// Panics unless `dest` is exactly `values.len() * 4` bytes long.
#[inline]
pub fn write_f32_le(dest: &mut [u8], values: &[f32]) {
    let expected_len = values.len() * 4;
    assert_eq!(
        dest.len(),
        expected_len,
        "write_f32_le: buffer size {} != expected {}",
        dest.len(),
        expected_len
    );
    // Walk the output in 4-byte windows, one per value.
    for (slot, value) in dest.chunks_exact_mut(4).zip(values) {
        slot.copy_from_slice(&value.to_le_bytes());
    }
}
/// Decodes `src` as a sequence of little-endian 4-byte floats.
///
/// # Panics
///
/// Panics if the length of `src` is not a multiple of four.
#[inline]
pub fn read_f32_le(src: &[u8]) -> Vec<f32> {
    let byte_len = src.len();
    assert_eq!(
        byte_len % 4,
        0,
        "read_f32_le: byte length {} is not a multiple of 4",
        byte_len
    );
    src.chunks_exact(4)
        .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        .collect()
}