use crate::distance::DistanceMetric;
use smallvec::SmallVec;
/// A pluggable distance function over `f32` vectors.
///
/// Implementors must be `Send + Sync` so a single engine can be shared
/// between threads.
pub trait DistanceEngine: Send + Sync {
    /// Returns the distance between `a` and `b`.
    fn distance(&self, a: &[f32], b: &[f32]) -> f32;

    /// Computes the distance from `query` to each entry of `candidates`.
    ///
    /// The result has one element per candidate, in the same order. The
    /// default implementation calls [`DistanceEngine::distance`] once per
    /// candidate; implementors may override it with a batched strategy.
    fn batch_distance(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
        let mut distances = Vec::with_capacity(candidates.len());
        for candidate in candidates {
            distances.push(self.distance(query, candidate));
        }
        distances
    }

    /// Returns the metric this engine evaluates.
    fn metric(&self) -> DistanceMetric;

    /// Reports whether the engine treats its inputs as already normalized.
    ///
    /// Defaults to `false`.
    #[must_use]
    fn is_pre_normalized(&self) -> bool {
        false
    }
}
#[allow(dead_code)] #[inline]
pub(crate) fn simd_distance_for_metric(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
match metric {
DistanceMetric::Cosine => 1.0 - crate::simd_native::cosine_similarity_native(a, b),
DistanceMetric::Euclidean => crate::simd_native::euclidean_native(a, b),
DistanceMetric::DotProduct => -crate::simd_native::dot_product_native(a, b),
DistanceMetric::Hamming => crate::simd_native::hamming_distance_native(a, b),
DistanceMetric::Jaccard => 1.0 - crate::simd_native::jaccard_similarity_native(a, b),
}
}
/// Computes `engine.distance(query, c)` for every candidate `c`, in order,
/// issuing a prefetch for a later candidate before each computation.
///
/// The look-ahead is obtained from
/// `crate::simd_native::calculate_prefetch_distance(query.len())`; no prefetch
/// is issued once the look-ahead index runs past the end of `candidates`.
#[inline]
pub(crate) fn batch_distance_with_prefetch(
    engine: &impl DistanceEngine,
    query: &[f32],
    candidates: &[&[f32]],
) -> SmallVec<[f32; 32]> {
    let lookahead = crate::simd_native::calculate_prefetch_distance(query.len());
    let mut distances = SmallVec::with_capacity(candidates.len());

    for (idx, &current) in candidates.iter().enumerate() {
        // `get` yields `None` past the end, so the tail simply skips prefetching.
        if let Some(&upcoming) = candidates.get(idx + lookahead) {
            crate::simd_native::prefetch_vector(upcoming);
        }
        distances.push(engine.distance(query, current));
    }

    distances
}
/// Distance engine that evaluates a single [`DistanceMetric`] with the scalar
/// routines defined in this file.
pub struct CpuDistance {
/// Metric evaluated by this engine.
metric: DistanceMetric,
}
impl CpuDistance {
#[must_use]
pub fn new(metric: DistanceMetric) -> Self {
Self { metric }
}
}
impl DistanceEngine for CpuDistance {
fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
match self.metric {
DistanceMetric::Cosine => cosine_distance_scalar(a, b),
DistanceMetric::Euclidean => euclidean_distance_scalar(a, b),
DistanceMetric::DotProduct => dot_product_scalar(a, b),
DistanceMetric::Hamming => hamming_distance_scalar(a, b),
DistanceMetric::Jaccard => jaccard_distance_scalar(a, b),
}
}
fn metric(&self) -> DistanceMetric {
self.metric
}
}
/// Distance engine that forwards to a `crate::simd_native::DistanceEngine`
/// built once, at construction time, for a given dimension.
pub struct CachedSimdDistance {
/// Metric evaluated by this engine.
metric: DistanceMetric,
/// Kernel provider from `crate::simd_native`, created for a fixed dimension.
engine: crate::simd_native::DistanceEngine,
/// `true` only when built via `new_prenormalized` with the cosine metric.
pre_normalized: bool,
}
impl CachedSimdDistance {
    /// Creates an engine for `metric` over vectors of length `dimension`.
    ///
    /// The resulting engine does not treat its inputs as pre-normalized.
    #[must_use]
    pub fn new(metric: DistanceMetric, dimension: usize) -> Self {
        let engine = crate::simd_native::DistanceEngine::new(dimension);
        Self {
            metric,
            engine,
            pre_normalized: false,
        }
    }

    /// Creates an engine for `metric` whose inputs are declared pre-normalized.
    ///
    /// The pre-normalized flag is only set when `metric` is
    /// `DistanceMetric::Cosine`; for every other metric the result is the same
    /// as calling [`CachedSimdDistance::new`].
    #[must_use]
    pub fn new_prenormalized(metric: DistanceMetric, dimension: usize) -> Self {
        let pre_normalized = matches!(metric, DistanceMetric::Cosine);
        let engine = crate::simd_native::DistanceEngine::new(dimension);
        Self {
            metric,
            engine,
            pre_normalized,
        }
    }
}
impl DistanceEngine for CachedSimdDistance {
    /// Evaluates the configured metric through the cached kernel provider.
    ///
    /// * Cosine, pre-normalized: `1.0 - dot(a, b)`, with the dot product
    ///   clamped to `[-1.0, 1.0]`.
    /// * Cosine otherwise: `1.0 - cosine_similarity(a, b)`.
    /// * Dot product: negated.
    /// * Jaccard: `1.0 - jaccard(a, b)`.
    /// * Euclidean and Hamming: the kernel's value, unchanged.
    #[allow(clippy::inline_always)]
    #[inline(always)]
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        match self.metric {
            DistanceMetric::Cosine if self.pre_normalized => {
                // For unit-length vectors the cosine similarity equals the
                // plain dot product, so both norm computations are skipped.
                // This relies on callers honouring the pre-normalized contract.
                // The clamp absorbs rounding that pushes the dot product
                // slightly outside [-1, 1].
                1.0 - self.engine.dot_product(a, b).clamp(-1.0, 1.0)
            }
            DistanceMetric::Cosine => 1.0 - self.engine.cosine_similarity(a, b),
            // NOTE(review): this arm returns `euclidean_squared`, whereas
            // `CpuDistance` in this file returns the square-rooted distance.
            // Presumably deliberate (taking the root does not change the
            // ordering) — confirm that callers expect the squared value.
            DistanceMetric::Euclidean => self.engine.euclidean_squared(a, b),
            DistanceMetric::DotProduct => -self.engine.dot_product(a, b),
            DistanceMetric::Hamming => self.engine.hamming(a, b),
            DistanceMetric::Jaccard => 1.0 - self.engine.jaccard(a, b),
        }
    }

    /// Computes all distances via `batch_distance_with_prefetch` and converts
    /// the result into a `Vec`.
    fn batch_distance(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
        batch_distance_with_prefetch(self, query, candidates).into_vec()
    }

    /// Returns the metric supplied at construction.
    fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Returns the flag set by the constructor.
    fn is_pre_normalized(&self) -> bool {
        self.pre_normalized
    }
}
/// Scalar cosine distance: `1.0 - (a · b) / (‖a‖ ‖b‖)`.
///
/// Only the overlapping prefix of `a` and `b` is considered when their lengths
/// differ. Returns `1.0` when either input has zero norm, which includes
/// empty input.
#[inline]
fn cosine_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0_f32;
    let mut norm_a_sq = 0.0_f32;
    let mut norm_b_sq = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    // The cosine is undefined for a zero vector; report maximal dissimilarity
    // for non-negative data, matching the previous convention.
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 1.0;
    }

    // Take the roots separately. Multiplying the squared norms first can
    // overflow to infinity (large magnitudes) or underflow to zero (small
    // magnitudes) in `f32`, even though the final quotient is representable.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    1.0 - dot / denom
}
/// Scalar Euclidean (L2) distance: the square root of the summed squared
/// element-wise differences.
///
/// Only the overlapping prefix of `a` and `b` is considered when their lengths
/// differ.
#[inline]
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b.iter()).map(|(lhs, rhs)| {
        let delta = lhs - rhs;
        delta.powi(2)
    });
    let sum_of_squares: f32 = squared_diffs.sum();
    sum_of_squares.sqrt()
}
/// Scalar dot-product "distance": the **negated** dot product of `a` and `b`.
///
/// Only the overlapping prefix of `a` and `b` is considered when their lengths
/// differ.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(lhs, rhs)| lhs * rhs).sum();
    -dot
}
/// Scalar Hamming distance: the number of positions whose `f32` bit patterns
/// differ, returned as `f32`.
///
/// The comparison is bitwise, so `0.0` and `-0.0` count as different. Only the
/// overlapping prefix of `a` and `b` is considered when their lengths differ.
#[inline]
#[allow(clippy::cast_precision_loss)]
fn hamming_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut mismatches = 0_usize;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        if lhs.to_bits() != rhs.to_bits() {
            mismatches += 1;
        }
    }
    mismatches as f32
}
/// Scalar weighted Jaccard distance: `1.0 - Σ min(aᵢ, bᵢ) / Σ max(aᵢ, bᵢ)`.
///
/// Returns `1.0` when the sum of element-wise maxima is zero, which includes
/// empty input. Only the overlapping prefix of `a` and `b` is considered when
/// their lengths differ.
#[inline]
fn jaccard_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let (min_sum, max_sum) = a.iter().zip(b.iter()).fold(
        (0.0_f32, 0.0_f32),
        |(mins, maxes), (&lhs, &rhs)| (mins + lhs.min(rhs), maxes + lhs.max(rhs)),
    );

    if max_sum == 0.0 {
        return 1.0;
    }
    1.0 - (min_sum / max_sum)
}