/// A distance function between two `f32` vectors.
///
/// Implementors must be `Clone + Send + Sync` so a metric value can be copied
/// and shared freely across threads. Implementations are expected to return
/// non-negative values.
pub trait DistanceMetric: Clone + Send + Sync {
    /// Distance between `a` and `b`.
    ///
    /// Both slices are expected to have the same length. This is not enforced in
    /// release builds: the scalar implementations in this module pair elements
    /// with `zip`, which stops at the shorter slice.
    fn distance(&self, a: &[f32], b: &[f32]) -> f32;

    /// Capability flag; the default is `false`.
    ///
    /// NOTE(review): presumably signals that the metric may be evaluated via the
    /// expanded form `‖a‖² + ‖b‖² − 2·a·b` (only `SquaredEuclidean` opts in);
    /// confirm against the callers that read it.
    fn supports_expanded_form(&self) -> bool {
        false
    }

    /// Capability flag; the default is `false`.
    ///
    /// NOTE(review): presumably asks clustering code to re-normalise centroids
    /// (only `CosineDistance` opts in); confirm against the callers that read it.
    fn normalize_centroids(&self) -> bool {
        false
    }
}
/// Squared Euclidean distance, `Σ (aᵢ − bᵢ)²`.
///
/// Ranks pairs in the same order as `Euclidean` (the square root is monotone)
/// while skipping the `sqrt`. With the `simd` feature enabled, inputs of 32 or
/// more elements are delegated to `innr::l2_distance_squared`.
#[derive(Clone, Copy, Debug, Default)]
pub struct SquaredEuclidean;

impl DistanceMetric for SquaredEuclidean {
    /// Always `true` for this metric.
    fn supports_expanded_form(&self) -> bool {
        true
    }

    #[inline]
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        // Vectorised kernel for longer inputs when the `simd` feature is on.
        #[cfg(feature = "simd")]
        if a.len() >= 32 {
            return innr::l2_distance_squared(a, b);
        }

        a.iter()
            .zip(b)
            .map(|(&lhs, &rhs)| {
                let delta = lhs - rhs;
                delta * delta
            })
            .sum()
    }
}
/// Euclidean (L2) distance, `sqrt(Σ (aᵢ − bᵢ)²)`.
///
/// Keeps the trait defaults for `supports_expanded_form` and
/// `normalize_centroids` (both `false`). With the `simd` feature enabled,
/// inputs of 32 or more elements are delegated to `innr::l2_distance`;
/// otherwise this is the square root of `SquaredEuclidean`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Euclidean;

impl DistanceMetric for Euclidean {
    #[inline]
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        // Checked here, like every other metric in this module, so that the SIMD
        // branch (which bypasses `SquaredEuclidean`) is also covered in debug builds.
        debug_assert_eq!(a.len(), b.len());
        #[cfg(feature = "simd")]
        if a.len() >= 32 {
            return innr::l2_distance(a, b);
        }
        SquaredEuclidean.distance(a, b).sqrt()
    }
}
/// Cosine distance, `1 − (a·b) / (‖a‖·‖b‖)`.
///
/// The result lies in `[0, 2]`: `0` for vectors pointing the same way, `1` for
/// orthogonal vectors, `2` for opposite vectors. The cosine similarity is
/// clamped to `[-1, 1]` before subtracting, so floating-point rounding cannot
/// produce a negative distance.
///
/// In the scalar path, if `‖a‖·‖b‖ < 1e-9` (e.g. either input is all zeros)
/// the similarity is undefined and `0.0` is returned.
///
/// With the `simd` feature enabled, inputs of 16 or more elements use
/// `innr::cosine` for the similarity.
#[derive(Clone, Copy, Debug, Default)]
pub struct CosineDistance;

impl DistanceMetric for CosineDistance {
    /// Always `true` for this metric.
    fn normalize_centroids(&self) -> bool {
        true
    }

    #[inline]
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        #[cfg(feature = "simd")]
        if a.len() >= 16 {
            // Same clamp as the scalar path: rounding can push the similarity
            // marginally outside [-1, 1].
            return 1.0 - innr::cosine(a, b).clamp(-1.0, 1.0);
        }

        let mut dot = 0.0f32;
        let mut norm_a_sq = 0.0f32;
        let mut norm_b_sq = 0.0f32;
        for (&x, &y) in a.iter().zip(b) {
            dot += x * y;
            norm_a_sq += x * x;
            norm_b_sq += y * y;
        }

        // Take the square roots separately: `(norm_a_sq * norm_b_sq).sqrt()`
        // overflows to infinity once the product exceeds `f32::MAX` (components
        // around 1e10 suffice), which made `dot / denom` collapse to 0 and every
        // such pair report a distance of 1.0.
        let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
        if denom < 1e-9 {
            return 0.0;
        }

        let similarity = (dot / denom).clamp(-1.0, 1.0);
        1.0 - similarity
    }
}
/// Negated inner product, `−(a·b)`.
///
/// Larger dot products map to smaller (more negative) "distances", so the value
/// can be negative. With the `simd` feature enabled, inputs of 16 or more
/// elements use `innr::dot`.
#[deprecated(
    since = "0.5.1",
    note = "Violates DistanceMetric non-negativity contract. Will be removed in 0.6.0. Use CosineDistance instead."
)]
#[derive(Clone, Copy, Debug, Default)]
pub struct InnerProductDistance;

#[allow(deprecated)]
impl DistanceMetric for InnerProductDistance {
    #[inline]
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(feature = "simd")]
        if a.len() >= 16 {
            return -innr::dot(a, b);
        }

        let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
        -products.sum::<f32>()
    }
}
/// Weighted sum of two metrics: `weight_a · A(x, y) + weight_b · B(x, y)`.
///
/// Weights are stored exactly as given; they are not validated or normalised,
/// so negative weights can produce negative distances.
#[deprecated(
    since = "0.5.1",
    note = "No algorithm tests exist for this type. Will be removed in 0.6.0. If you use it, please open an issue."
)]
#[derive(Clone, Debug)]
pub struct CompositeDistance<A: DistanceMetric, B: DistanceMetric> {
    // First component metric, scaled by `weight_a`.
    a: A,
    // Second component metric, scaled by `weight_b`.
    b: B,
    weight_a: f32,
    weight_b: f32,
}

#[allow(deprecated)]
impl<A: DistanceMetric, B: DistanceMetric> CompositeDistance<A, B> {
    /// Combines metric `a` (scaled by `weight_a`) with metric `b` (scaled by `weight_b`).
    pub fn new(a: A, b: B, weight_a: f32, weight_b: f32) -> Self {
        Self {
            weight_a,
            weight_b,
            a,
            b,
        }
    }
}

#[allow(deprecated)]
impl<A: DistanceMetric, B: DistanceMetric> DistanceMetric for CompositeDistance<A, B> {
    #[inline]
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        let first = self.weight_a * self.a.distance(a, b);
        let second = self.weight_b * self.b.distance(a, b);
        first + second
    }
}
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every comparison in this module.
    const TOL: f32 = 1e-6;

    /// Asserts that `actual` lies strictly within `TOL` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn squared_euclidean_basic() {
        let (a, b): ([f32; 2], [f32; 2]) = ([1.0, 0.0], [0.0, 1.0]);
        assert_close(SquaredEuclidean.distance(&a, &b), 2.0);
    }

    #[test]
    fn euclidean_basic() {
        let (a, b): ([f32; 2], [f32; 2]) = ([3.0, 0.0], [0.0, 4.0]);
        assert_close(Euclidean.distance(&a, &b), 5.0);
    }

    #[test]
    fn cosine_identical_vectors() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        assert_close(CosineDistance.distance(&a, &a), 0.0);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let (a, b): ([f32; 2], [f32; 2]) = ([1.0, 0.0], [0.0, 1.0]);
        assert_close(CosineDistance.distance(&a, &b), 1.0);
    }

    #[test]
    fn inner_product_basic() {
        let (a, b): ([f32; 2], [f32; 2]) = ([1.0, 2.0], [3.0, 4.0]);
        assert_close(InnerProductDistance.distance(&a, &b), -11.0);
    }

    #[test]
    fn self_distance_is_zero() {
        let v: [f32; 3] = [1.0, 2.0, 3.0];
        assert_close(SquaredEuclidean.distance(&v, &v), 0.0);
        assert_close(Euclidean.distance(&v, &v), 0.0);
        assert_close(CosineDistance.distance(&v, &v), 0.0);
    }

    #[test]
    fn composite_weighted_combination() {
        let metric = CompositeDistance::new(SquaredEuclidean, Euclidean, 0.5, 0.5);
        // 0.5 * 25 (squared) + 0.5 * 5 (plain) = 15
        let origin: [f32; 2] = [0.0, 0.0];
        let point: [f32; 2] = [3.0, 4.0];
        assert_close(metric.distance(&origin, &point), 15.0);
    }

    #[test]
    fn composite_weight_one_zero_degenerates_to_a() {
        let metric = CompositeDistance::new(SquaredEuclidean, Euclidean, 1.0, 0.0);
        let (a, b): ([f32; 2], [f32; 2]) = ([1.0, 2.0], [4.0, 6.0]);
        assert_close(metric.distance(&a, &b), SquaredEuclidean.distance(&a, &b));
    }

    #[test]
    fn composite_weight_zero_one_degenerates_to_b() {
        let metric = CompositeDistance::new(SquaredEuclidean, Euclidean, 0.0, 1.0);
        let (a, b): ([f32; 2], [f32; 2]) = ([1.0, 2.0], [4.0, 6.0]);
        assert_close(metric.distance(&a, &b), Euclidean.distance(&a, &b));
    }

    #[test]
    fn composite_self_distance_zero() {
        let metric = CompositeDistance::new(SquaredEuclidean, Euclidean, 0.7, 0.3);
        let v: [f32; 3] = [1.0, 2.0, 3.0];
        assert_close(metric.distance(&v, &v), 0.0);
    }

    #[test]
    fn composite_symmetry() {
        let metric = CompositeDistance::new(SquaredEuclidean, CosineDistance, 0.6, 0.4);
        let (a, b): ([f32; 3], [f32; 3]) = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        let forward = metric.distance(&a, &b);
        let backward = metric.distance(&b, &a);
        assert_close(forward, backward);
    }

    #[test]
    fn composite_equal_weights() {
        let metric = CompositeDistance::new(SquaredEuclidean, Euclidean, 1.0, 1.0);
        // 1 * 4 (squared) + 1 * 2 (plain) = 6
        let a = [0.0f32];
        let b = [2.0f32];
        assert_close(metric.distance(&a, &b), 6.0);
    }
}