use std::fmt;
/// Metric used to compare two `f32` vectors.
///
/// [`DistanceMetric::Cosine`] is the default. Use [`compute_distance`] to
/// evaluate a metric on a concrete pair of vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DistanceMetric {
    /// Cosine distance; [`compute_distance`] maps a cosine similarity `s`
    /// to `(1 - s) / 2`.
    #[default]
    Cosine,
    /// Euclidean (L2) distance.
    Euclidean,
    /// Negated dot product, so a larger dot product yields a smaller distance.
    DotProduct,
    /// Manhattan (L1) distance.
    Manhattan,
}

impl DistanceMetric {
    /// Returns the canonical lowercase name of this metric
    /// (`"cosine"`, `"euclidean"`, `"dot_product"` or `"manhattan"`).
    ///
    /// The string is `'static`, so it may outlive `self`. The `Display`
    /// implementation writes exactly this text.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            DistanceMetric::Cosine => "cosine",
            DistanceMetric::Euclidean => "euclidean",
            DistanceMetric::DotProduct => "dot_product",
            DistanceMetric::Manhattan => "manhattan",
        }
    }
}

impl fmt::Display for DistanceMetric {
    /// Writes [`DistanceMetric::as_str`].
    ///
    /// Width, fill, alignment and precision flags (e.g. `{:>10}`) are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to `as_str` so the names live in exactly one place, and use
        // `pad` (not `write_str`) so formatting flags are respected.
        f.pad(self.as_str())
    }
}
/// Computes the distance between `a` and `b` under the chosen `metric`.
///
/// Per variant:
/// * `Euclidean` — result of `euclidean_distance(a, b)`.
/// * `Manhattan` — result of `manhattan_distance(a, b)`.
/// * `DotProduct` — the negation of `dot_product(a, b)`.
/// * `Cosine` — `(1 - cosine_similarity(a, b)) / 2`.
///
/// NOTE(review): length/zero-norm handling is whatever the underlying helpers
/// do; this function performs no validation of its own.
pub fn compute_distance(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    match metric {
        DistanceMetric::Euclidean => euclidean_distance(a, b),
        DistanceMetric::Manhattan => manhattan_distance(a, b),
        // Negated so that a larger dot product ranks as a smaller distance.
        DistanceMetric::DotProduct => -dot_product(a, b),
        // Rescale similarity `s` into a distance via (1 - s) / 2.
        DistanceMetric::Cosine => (1.0 - cosine_similarity(a, b)) / 2.0,
    }
}
pub use crate::hnsw::distance_functions::{
cosine_similarity, dot_product, euclidean_distance, manhattan_distance,
};
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for float comparisons. The distance helpers live in
    /// another module and their exact floating-point rounding is not pinned
    /// down here, so exact `==` on their results would be brittle.
    const EPS: f32 = 1e-6;

    /// Asserts that `actual` is within [`EPS`] of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() <= EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_distance_metric_display() {
        assert_eq!(DistanceMetric::Cosine.to_string(), "cosine");
        assert_eq!(DistanceMetric::Euclidean.to_string(), "euclidean");
        assert_eq!(DistanceMetric::DotProduct.to_string(), "dot_product");
        assert_eq!(DistanceMetric::Manhattan.to_string(), "manhattan");
    }

    #[test]
    fn test_as_str_matches_display() {
        // `as_str` and `Display` must never disagree on a metric's name.
        let all = [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];
        for metric in all.iter() {
            assert_eq!(metric.as_str(), metric.to_string());
        }
    }

    #[test]
    fn test_distance_metric_default() {
        let metric = DistanceMetric::default();
        assert_eq!(metric, DistanceMetric::Cosine);
    }

    #[test]
    fn test_distance_metric_equality() {
        assert_eq!(DistanceMetric::Cosine, DistanceMetric::Cosine);
        assert_ne!(DistanceMetric::Cosine, DistanceMetric::Euclidean);
        assert_ne!(DistanceMetric::Euclidean, DistanceMetric::Manhattan);
    }

    #[test]
    fn test_compute_distance_cosine() {
        // Orthogonal vectors: similarity 0 -> (1 - 0) / 2.
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert_close(compute_distance(DistanceMetric::Cosine, &a, &b), 0.5);
    }

    #[test]
    fn test_compute_distance_euclidean() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_close(compute_distance(DistanceMetric::Euclidean, &a, &b), 5.0);
    }

    #[test]
    fn test_compute_distance_dot_product() {
        // Dot product is 1; the metric negates it.
        let a = [1.0, 0.0];
        let b = [1.0, 0.0];
        assert_close(compute_distance(DistanceMetric::DotProduct, &a, &b), -1.0);
    }

    #[test]
    fn test_compute_distance_manhattan() {
        let a = [1.0, 2.0];
        let b = [4.0, 0.0];
        assert_close(compute_distance(DistanceMetric::Manhattan, &a, &b), 5.0);
    }

    #[test]
    fn test_all_metrics_identical_vectors() {
        let a = [1.0, 0.0];
        let b = [1.0, 0.0];
        assert_close(compute_distance(DistanceMetric::Cosine, &a, &b), 0.0);
        assert_close(compute_distance(DistanceMetric::Euclidean, &a, &b), 0.0);
        assert_close(compute_distance(DistanceMetric::Manhattan, &a, &b), 0.0);
        assert_close(compute_distance(DistanceMetric::DotProduct, &a, &b), -1.0);
    }
}