use common::DistanceMetric;
/// Scores how close `a` is to `b` under the chosen `metric`.
///
/// Despite the name, every metric yields a *similarity-style* score where a larger value means
/// the vectors are closer:
///
/// * [`DistanceMetric::Cosine`] delegates to [`cosine_similarity`].
/// * [`DistanceMetric::Euclidean`] delegates to [`negative_euclidean`] (the L2 distance negated,
///   so `0.0` is the best possible score).
/// * [`DistanceMetric::DotProduct`] delegates to [`dot_product`].
///
/// `a` and `b` are expected to have the same length.
pub fn calculate_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    let (lhs, rhs) = (a, b);
    match metric {
        DistanceMetric::DotProduct => dot_product(lhs, rhs),
        DistanceMetric::Euclidean => negative_euclidean(lhs, rhs),
        DistanceMetric::Cosine => cosine_similarity(lhs, rhs),
    }
}
/// Cosine similarity of `a` and `b`: `dot(a, b) / (‖a‖ · ‖b‖)`.
///
/// The result lies in `[-1.0, 1.0]`:
/// * `1.0` for vectors pointing the same way,
/// * `0.0` for orthogonal vectors,
/// * `-1.0` for opposite vectors.
///
/// It does not depend on the magnitude of either input.
///
/// If either vector has zero norm the angle is undefined and `0.0` is returned.
///
/// `a` and `b` are expected to have the same length. This is checked in debug builds. In
/// release builds the longer slice is truncated to the shorter one's length (via `zip`).
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "cosine_similarity: vector length mismatch"
    );

    // Accumulate the dot product and both squared norms in a single pass.
    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let norm_a = norm_a.sqrt();
    let norm_b = norm_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        // Direction is undefined for a zero vector; report "no similarity".
        return 0.0;
    }

    // f32 rounding in the accumulation can push the ratio marginally outside [-1, 1]
    // (e.g. just above 1.0 for near-identical vectors); clamp so callers always get a
    // valid cosine.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// Negated Euclidean (L2) distance between `a` and `b`: `-sqrt(Σ (aᵢ - bᵢ)²)`.
///
/// The sign is flipped so that, like the other metrics here, a larger value means "closer".
/// Identical vectors score `-0.0`, which compares equal to `0.0`. Every other pair
/// scores negative.
///
/// `a` and `b` are expected to have the same length. This is checked in debug builds. In
/// release builds the longer slice is truncated to the shorter one's length (via `zip`).
#[inline]
pub fn negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "negative_euclidean: vector length mismatch"
    );

    let squared_distance: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let d = x - y;
            d * d
        })
        .sum();
    -squared_distance.sqrt()
}
/// Inner (dot) product of `a` and `b`: `Σ aᵢ · bᵢ`.
///
/// Unlike cosine similarity the result is unbounded and scales with the magnitude of both
/// inputs. For unit-length vectors it coincides with cosine similarity. An empty pair of
/// slices yields `0.0`.
///
/// `a` and `b` are expected to have the same length. This is checked in debug builds. In
/// release builds the longer slice is truncated to the shorter one's length (via `zip`).
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "dot_product: vector length mismatch");
    a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::FRAC_1_SQRT_2;

    /// Absolute tolerance used when comparing `f32` results.
    const EPSILON: f32 = 1e-6;

    /// Asserts `actual` is within [`EPSILON`] of `expected`, reporting both values on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        assert_close(cosine_similarity(&a, &b), 1.0);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert_close(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0, 0.0];
        let b = [-1.0, 0.0];
        assert_close(cosine_similarity(&a, &b), -1.0);
    }

    #[test]
    fn test_cosine_similarity_normalized() {
        // 45 degrees between the x-axis and the unit diagonal.
        let a = [1.0, 0.0];
        let b = [FRAC_1_SQRT_2, FRAC_1_SQRT_2];
        assert_close(cosine_similarity(&a, &b), FRAC_1_SQRT_2);
    }

    #[test]
    fn test_cosine_similarity_scale_invariant() {
        // Same 45-degree angle as above, but neither vector is unit length.
        let a = [2.0, 0.0];
        let b = [3.0, 3.0];
        assert_close(cosine_similarity(&a, &b), FRAC_1_SQRT_2);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // A zero-norm input has no direction; the implementation reports 0.0.
        let zero = [0.0, 0.0];
        let v = [1.0, 2.0];
        assert_eq!(cosine_similarity(&zero, &v), 0.0);
        assert_eq!(cosine_similarity(&v, &zero), 0.0);
        assert_eq!(cosine_similarity(&zero, &zero), 0.0);
    }

    #[test]
    fn test_euclidean_zero_distance() {
        let a = [1.0, 2.0, 3.0];
        let b = [1.0, 2.0, 3.0];
        assert_close(negative_euclidean(&a, &b), 0.0);
    }

    #[test]
    fn test_euclidean_known_distance() {
        // Classic 3-4-5 right triangle.
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_close(negative_euclidean(&a, &b), -5.0);
    }

    #[test]
    fn test_dot_product() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        assert_close(dot_product(&a, &b), 32.0);
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert_close(dot_product(&a, &b), 0.0);
    }

    #[test]
    fn test_calculate_distance_dispatch() {
        // Inputs chosen so every metric produces a different score (cosine 0.6,
        // euclidean -4.0, dot 9.0); a mis-wired match arm would therefore be caught.
        let a = [3.0, 0.0];
        let b = [3.0, 4.0];
        assert_close(calculate_distance(&a, &b, DistanceMetric::Cosine), 0.6);
        assert_close(calculate_distance(&a, &b, DistanceMetric::Euclidean), -4.0);
        assert_close(calculate_distance(&a, &b, DistanceMetric::DotProduct), 9.0);
    }
}