use serde::{Deserialize, Serialize};
use crate::error::RagError;
/// Metric used to compare two `f32` vectors of equal length.
///
/// `Cosine` is the `Default` variant. `Cosine` and `DotProduct` are the
/// variants reported as similarities by `Distance::is_similarity`; the rest
/// are reported as distances.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Distance {
/// Dot product divided by the product of the two L2 norms, clamped to
/// `[-1, 1]`; evaluates to `0.0` when either norm is (near) zero.
#[default]
Cosine,
/// Square root of the sum of squared component differences.
Euclidean,
/// Sum of the component-wise products.
DotProduct,
/// Arc-cosine of the cosine value divided by π, i.e. a value in `[0, 1]`.
Angular,
/// Number of differing bits between the raw `f32` bit patterns of
/// corresponding components, divided by the number of components.
Hamming,
}
impl Distance {
    /// Returns `true` for [`Distance::Cosine`] and [`Distance::DotProduct`],
    /// the metrics whose raw value [`Distance::to_score`] passes through
    /// unchanged.
    #[inline]
    pub fn is_similarity(self) -> bool {
        matches!(self, Self::Cosine | Self::DotProduct)
    }

    /// Returns `true` for every metric that is not a similarity
    /// (the exact negation of [`Distance::is_similarity`]).
    #[inline]
    pub fn is_distance(self) -> bool {
        !self.is_similarity()
    }

    /// Checks that every component of both vectors is finite.
    ///
    /// Only finiteness is checked here; lengths are not compared.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::NonFinite`] if either slice contains a NaN or an
    /// infinity.
    #[inline]
    pub fn validate(a: &[f32], b: &[f32]) -> Result<(), RagError> {
        let all_finite = |v: &[f32]| v.iter().all(|x| x.is_finite());
        if all_finite(a) && all_finite(b) {
            Ok(())
        } else {
            Err(RagError::NonFinite)
        }
    }

    /// Evaluates this metric on `a` and `b`.
    ///
    /// # Errors
    ///
    /// * [`RagError::DimensionMismatch`] if `a` is empty or the two slices
    ///   have different lengths.
    /// * [`RagError::NonFinite`] if any input component is not finite, or if
    ///   the computed value itself is not finite.
    pub fn compute(self, a: &[f32], b: &[f32]) -> Result<f32, RagError> {
        if a.is_empty() {
            // The empty operand is the offending one. Reporting it as `got: 0`
            // avoids nonsensical payloads such as "expected 1, got 1" when `b`
            // happens to have exactly one component. `max(1)` keeps the
            // both-empty case at "expected 1, got 0".
            return Err(RagError::DimensionMismatch {
                expected: b.len().max(1),
                got: 0,
            });
        }
        if a.len() != b.len() {
            return Err(RagError::DimensionMismatch {
                expected: a.len(),
                got: b.len(),
            });
        }
        Self::validate(a, b)?;

        let value = match self {
            Self::Cosine => cosine(a, b),
            Self::Euclidean => euclidean(a, b),
            Self::DotProduct => dot(a, b),
            Self::Angular => angular(a, b),
            Self::Hamming => hamming(a, b),
        };

        // Finite inputs can still produce a non-finite result (e.g. overflow),
        // so the output is checked as well.
        if value.is_finite() {
            Ok(value)
        } else {
            Err(RagError::NonFinite)
        }
    }

    /// Converts a raw metric value into a score: similarity values are
    /// returned as-is, every other metric's value is negated.
    #[inline]
    pub fn to_score(self, value: f32) -> f32 {
        if self.is_similarity() {
            value
        } else {
            -value
        }
    }
}
/// Dot product of `a` and `b` over their common prefix.
///
/// Products and the running sum are kept in `f64`: the product of two `f32`
/// values is exact in `f64`, and the wider accumulator prevents a partial sum
/// from overflowing to ±inf when the final result is representable. The
/// result is narrowed to `f32` at the end (out-of-range values become ±inf).
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let acc: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    acc as f32
}
/// Euclidean (L2) norm of `v`; `0.0` for an empty slice.
///
/// Squares are accumulated in `f64` so that components whose square would
/// overflow or underflow `f32` still yield the correct norm whenever the norm
/// itself fits in `f32`. The result is narrowed to `f32` at the end.
#[inline]
fn norm(v: &[f32]) -> f32 {
    let sum_sq: f64 = v
        .iter()
        .map(|&x| {
            let x = f64::from(x);
            x * x
        })
        .sum();
    sum_sq.sqrt() as f32
}
/// Cosine of the angle between `a` and `b`, clamped to `[-1, 1]`.
///
/// Returns `0.0` when either vector's norm is below `1e-12`.
///
/// All intermediate sums are computed locally in `f64`. With `f32`
/// intermediates, a squared component above `f32::MAX` turns a norm into
/// `inf`, which silently yields `0.0` (finite dot product) or `NaN`
/// (overflowed dot product) for perfectly valid inputs.
#[inline]
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    // Norms below this are treated as the zero vector.
    const MIN_NORM: f64 = 1e-12;

    let l2 = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt()
    };

    let (na, nb) = (l2(a), l2(b));
    if na < MIN_NORM || nb < MIN_NORM {
        return 0.0;
    }

    let ab: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();

    // Clamp away rounding excursions just outside [-1, 1] before narrowing.
    (ab / (na * nb)).clamp(-1.0, 1.0) as f32
}
/// Euclidean (L2) distance between `a` and `b` over their common prefix.
///
/// Differences and their squares are accumulated in `f64` so that large
/// components do not overflow to `inf` while the true distance still fits in
/// `f32`. The result is narrowed to `f32` at the end (out-of-range values
/// become `inf`).
#[inline]
fn euclidean(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let d = f64::from(x) - f64::from(y);
            d * d
        })
        .sum();
    sum_sq.sqrt() as f32
}
/// Angular distance: the arc-cosine of the cosine value of `a` and `b`,
/// divided by π, giving a value in `[0, 1]`.
#[inline]
fn angular(a: &[f32], b: &[f32]) -> f32 {
    use std::f32::consts::PI;

    // Re-clamp so `acos` always receives an argument inside its domain.
    let theta = cosine(a, b).clamp(-1.0, 1.0).acos();
    theta / PI
}
/// Bit-level Hamming measure: counts the bits that differ between the raw
/// `f32` bit patterns of corresponding components, then divides by `a.len()`.
///
/// An empty `a` yields `NaN` (`0 / 0`).
#[inline]
fn hamming(a: &[f32], b: &[f32]) -> f32 {
    let mut differing_bits: u64 = 0;
    for (lhs, rhs) in a.iter().zip(b) {
        let diff = lhs.to_bits() ^ rhs.to_bits();
        differing_bits += u64::from(diff.count_ones());
    }
    differing_bits as f32 / a.len() as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    // Identical unit-length vectors have cosine 1.
    #[test]
    fn cosine_identical_unit_vectors_is_one() {
        let unit = [0.6_f32, 0.8];
        let sim = Distance::Cosine
            .compute(&unit, &unit)
            .expect("finite inputs must succeed");
        assert!((sim - 1.0).abs() < 1e-6, "got {d}", d = sim);
    }

    // A vector is at Euclidean distance 0 from itself.
    #[test]
    fn euclidean_zero_for_identical() {
        let v = [1.0_f32, 2.0, 3.0];
        let dist = Distance::Euclidean.compute(&v, &v).expect("compute");
        assert!(dist.abs() < 1e-6);
    }

    // Opposite vectors sit at the upper end of the angular range.
    #[test]
    fn angular_range_zero_to_one() {
        let east = [1.0_f32, 0.0];
        let west = [-1.0_f32, 0.0];
        let dist = Distance::Angular.compute(&east, &west).expect("compute");
        assert!((dist - 1.0).abs() < 1e-5, "got {d}", d = dist);
    }

    // No bits differ between a vector and itself.
    #[test]
    fn hamming_zero_for_identical() {
        let v = [1.0_f32, 2.0, 3.0];
        let dist = Distance::Hamming.compute(&v, &v).expect("compute");
        assert_eq!(dist, 0.0);
    }

    // NaN components are rejected before any metric runs.
    #[test]
    fn nan_rejected() {
        let outcome = Distance::Cosine.compute(&[1.0, f32::NAN], &[1.0, 2.0]);
        assert!(matches!(outcome, Err(RagError::NonFinite)));
    }

    // Infinite components are rejected before any metric runs.
    #[test]
    fn inf_rejected() {
        let outcome = Distance::Euclidean.compute(&[f32::INFINITY, 2.0], &[1.0, 2.0]);
        assert!(matches!(outcome, Err(RagError::NonFinite)));
    }

    // Slices of different lengths are rejected.
    #[test]
    fn dim_mismatch_rejected() {
        let outcome = Distance::DotProduct.compute(&[1.0], &[1.0, 2.0]);
        assert!(matches!(outcome, Err(RagError::DimensionMismatch { .. })));
    }

    // Cosine and DotProduct are similarities; the rest are distances.
    #[test]
    fn is_similarity_classifies_correctly() {
        for metric in [Distance::Cosine, Distance::DotProduct] {
            assert!(metric.is_similarity());
        }
        assert!(!Distance::Euclidean.is_similarity());
        for metric in [Distance::Euclidean, Distance::Angular, Distance::Hamming] {
            assert!(metric.is_distance());
        }
    }
}