use crate::error::ShardexError;
/// Selects the formula used to score how alike two vectors are.
///
/// The enum is a plain, copyable tag with no payload; [`DistanceMetric::Cosine`]
/// is the value produced by [`Default::default`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DistanceMetric {
    /// Angle-based comparison. This is the default metric.
    #[default]
    Cosine,
    /// Comparison based on straight-line distance.
    Euclidean,
    /// Comparison based on the inner product of the two vectors.
    DotProduct,
}
impl DistanceMetric {
    /// Scores how alike `a` and `b` are under this metric.
    ///
    /// A successful result is always a non-NaN value in `[0.0, 1.0]`:
    /// the raw score from the metric-specific helper is clamped into that
    /// range, and a NaN raw score is replaced by the neutral value `0.5`.
    /// Two empty slices also score `0.5`.
    ///
    /// # Errors
    ///
    /// Returns [`ShardexError::InvalidDimension`] when the slices differ in
    /// length; `expected` carries `a.len()` and `actual` carries `b.len()`.
    pub fn similarity(&self, a: &[f32], b: &[f32]) -> Result<f32, ShardexError> {
        let (expected, actual) = (a.len(), b.len());
        if expected != actual {
            return Err(ShardexError::InvalidDimension { expected, actual });
        }

        // Lengths are equal at this point, so an empty `a` means both are empty.
        if expected == 0 {
            return Ok(0.5);
        }

        let raw = match *self {
            Self::Cosine => cosine_similarity(a, b),
            Self::Euclidean => euclidean_similarity(a, b),
            Self::DotProduct => dot_product_similarity(a, b),
        };

        // `clamp` passes NaN through unchanged, so NaN is handled explicitly.
        Ok(if raw.is_nan() { 0.5 } else { raw.clamp(0.0, 1.0) })
    }

    /// Returns the short, lowercase identifier of this metric.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Cosine => "cosine",
            Self::Euclidean => "euclidean",
            Self::DotProduct => "dot_product",
        }
    }

    /// Returns a one-sentence, human-readable description of this metric.
    pub fn description(&self) -> &'static str {
        match *self {
            Self::Cosine => "Measures angle between vectors, ideal for high-dimensional text embeddings",
            Self::Euclidean => "Measures straight-line distance, ideal for geometric and spatial data",
            Self::DotProduct => "Measures projection magnitude, ideal for normalized vectors",
        }
    }

    /// Reports whether this metric is flagged as preferring normalized input.
    ///
    /// `true` for `Cosine` and `DotProduct`, `false` for `Euclidean`.
    pub fn prefers_normalized(&self) -> bool {
        match *self {
            Self::Cosine | Self::DotProduct => true,
            Self::Euclidean => false,
        }
    }
}
/// Computes the cosine similarity of `a` and `b`, rescaled from `[-1, 1]`
/// to `[0, 1]` (opposite → `0.0`, orthogonal → `0.5`, same direction → `1.0`).
///
/// If either vector has zero magnitude the angle is undefined and the neutral
/// value `0.5` is returned.
///
/// All accumulation is done in `f64`. Every product of two `f32` values is
/// exactly representable in `f64`, so squares of very large components no
/// longer overflow to infinity (which used to produce `inf / inf = NaN`) and
/// squares of very small components no longer underflow to zero (which used
/// to make a non-zero vector look like a zero vector).
///
/// The dot product pairs elements up to the shorter slice, while each norm is
/// taken over its whole slice. The result may be NaN when the inputs contain
/// NaN or infinite components.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean length of a slice, accumulated in f64.
    let l2_norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| {
                let x = f64::from(x);
                x * x
            })
            .sum::<f64>()
            .sqrt()
    };

    let dot_product: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = l2_norm(a);
    let norm_b = l2_norm(b);

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.5;
    }

    // Rounding can push the ratio marginally outside [-1, 1]; clamp it back.
    let cosine = (dot_product / (norm_a * norm_b)).clamp(-1.0, 1.0);

    // The rescaled value lies in [0, 1], so narrowing to f32 only rounds.
    ((cosine + 1.0) / 2.0) as f32
}
/// Converts the Euclidean distance `d` between `a` and `b` into a similarity
/// score via `1 / (1 + d)`.
///
/// Identical vectors score `1.0`, and the score falls toward `0.0` as the
/// distance grows. Elements are paired up to the length of the shorter slice.
fn euclidean_similarity(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| {
            let gap = p - q;
            gap * gap
        })
        .sum::<f32>();

    (1.0 + sum_of_squares.sqrt()).recip()
}
/// Squashes the dot product of `a` and `b` through the logistic function
/// `1 / (1 + e^(-x))`.
///
/// A dot product of zero scores `0.5`; positive products score above it and
/// negative products below it. Elements are paired up to the length of the
/// shorter slice.
fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
    let inner = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum::<f32>();
    let decay = (-inner).exp();
    (1.0 + decay).recip()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every metric variant, for checks that must hold regardless of metric.
    const ALL_METRICS: [DistanceMetric; 3] = [
        DistanceMetric::Cosine,
        DistanceMetric::Euclidean,
        DistanceMetric::DotProduct,
    ];

    /// Absolute tolerance used when comparing scores to an exact expectation.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    #[test]
    fn test_distance_metric_default() {
        assert_eq!(DistanceMetric::default(), DistanceMetric::Cosine);
    }

    #[test]
    fn test_distance_metric_names() {
        let expected_names = [
            (DistanceMetric::Cosine, "cosine"),
            (DistanceMetric::Euclidean, "euclidean"),
            (DistanceMetric::DotProduct, "dot_product"),
        ];
        for (metric, expected) in expected_names {
            assert_eq!(metric.name(), expected);
        }
    }

    #[test]
    fn test_distance_metric_descriptions() {
        for metric in ALL_METRICS {
            let text = metric.description();
            assert!(!text.is_empty());
            assert!(text.len() > 10);
        }
    }

    #[test]
    fn test_prefers_normalized() {
        assert!(DistanceMetric::Cosine.prefers_normalized());
        assert!(DistanceMetric::DotProduct.prefers_normalized());
        assert!(!DistanceMetric::Euclidean.prefers_normalized());
    }

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let v = [1.0, 2.0, 3.0];
        let w = [1.0, 2.0, 3.0];
        let score = DistanceMetric::Cosine.similarity(&v, &w).unwrap();
        assert_close(score, 1.0);
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let v = [1.0, 0.0, 0.0];
        let w = [-1.0, 0.0, 0.0];
        let score = DistanceMetric::Cosine.similarity(&v, &w).unwrap();
        assert_close(score, 0.0);
    }

    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let v = [1.0, 0.0, 0.0];
        let w = [0.0, 1.0, 0.0];
        let score = DistanceMetric::Cosine.similarity(&v, &w).unwrap();
        assert_close(score, 0.5);
    }

    #[test]
    fn test_cosine_similarity_zero_vectors() {
        // A zero-magnitude vector has no direction, so the neutral score applies.
        let zero = [0.0, 0.0, 0.0];
        let other = [1.0, 2.0, 3.0];
        let score = DistanceMetric::Cosine.similarity(&zero, &other).unwrap();
        assert_close(score, 0.5);
    }

    #[test]
    fn test_euclidean_similarity_identical_vectors() {
        let v = [1.0, 2.0, 3.0];
        let w = [1.0, 2.0, 3.0];
        let score = DistanceMetric::Euclidean.similarity(&v, &w).unwrap();
        assert_close(score, 1.0);
    }

    #[test]
    fn test_euclidean_similarity_different_vectors() {
        // A 3-4-5 triangle: the distance is exactly 5.
        let v = [0.0, 0.0];
        let w = [3.0, 4.0];
        let score = DistanceMetric::Euclidean.similarity(&v, &w).unwrap();
        assert_close(score, 1.0 / (1.0 + 5.0));
    }

    #[test]
    fn test_dot_product_similarity_positive_correlation() {
        let v = [1.0, 1.0, 1.0];
        let w = [2.0, 2.0, 2.0];
        let score = DistanceMetric::DotProduct.similarity(&v, &w).unwrap();
        assert!(score > 0.9);
    }

    #[test]
    fn test_dot_product_similarity_negative_correlation() {
        let v = [1.0, 1.0, 1.0];
        let w = [-1.0, -1.0, -1.0];
        let score = DistanceMetric::DotProduct.similarity(&v, &w).unwrap();
        assert!(score < 0.1);
    }

    #[test]
    fn test_dimension_validation() {
        let three = [1.0, 2.0, 3.0];
        let two = [1.0, 2.0];
        for metric in ALL_METRICS {
            let err = metric.similarity(&three, &two).unwrap_err();
            assert!(matches!(err, ShardexError::InvalidDimension { .. }));
        }
    }

    #[test]
    fn test_empty_vectors() {
        let empty: [f32; 0] = [];
        for metric in ALL_METRICS {
            let score = metric.similarity(&empty, &empty).unwrap();
            assert_close(score, 0.5);
        }
    }

    #[test]
    fn test_similarity_range_bounds() {
        // Identical, opposite, orthogonal, arbitrary, and zero-vector pairs.
        let cases: [([f32; 3], [f32; 3]); 5] = [
            ([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]),
            ([1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]),
            ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]),
            ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]),
            ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ];

        for metric in ALL_METRICS {
            for (lhs, rhs) in &cases {
                let similarity = metric.similarity(lhs, rhs).unwrap();
                assert!(
                    similarity >= 0.0,
                    "Similarity {} should be >= 0.0 for metric {:?}",
                    similarity,
                    metric
                );
                assert!(
                    similarity <= 1.0,
                    "Similarity {} should be <= 1.0 for metric {:?}",
                    similarity,
                    metric
                );
                assert!(
                    !similarity.is_nan(),
                    "Similarity should not be NaN for metric {:?}",
                    metric
                );
            }
        }
    }

    #[test]
    fn test_cosine_similarity_normalization() {
        // Opposite directions map to the bottom of the range.
        let opposite = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert_close(opposite, 0.0);

        // Orthogonal directions map to the midpoint.
        let orthogonal = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert_close(orthogonal, 0.5);
    }

    #[test]
    fn test_euclidean_similarity_behavior() {
        let origin = [0.0, 0.0];
        let near = [1.0, 0.0];
        let distant = [10.0, 0.0];

        let near_score = euclidean_similarity(&origin, &near);
        let distant_score = euclidean_similarity(&origin, &distant);

        assert!(near_score > distant_score);
        assert!(near_score > 0.4);
        assert!(distant_score < 0.1);
    }

    #[test]
    fn test_dot_product_similarity_sigmoid() {
        let base = [1.0, 0.0];
        let aligned = [2.0, 0.0];
        let orthogonal = [0.0, 1.0];
        let opposed = [-1.0, 0.0];

        let pos_sim = dot_product_similarity(&base, &aligned);
        let zero_sim = dot_product_similarity(&base, &orthogonal);
        let neg_sim = dot_product_similarity(&base, &opposed);

        // The scores must be strictly ordered by dot product.
        assert!(pos_sim > zero_sim);
        assert!(zero_sim > neg_sim);

        assert!(pos_sim > 0.8);
        assert!((zero_sim - 0.5).abs() < 0.1);
        assert!(neg_sim < 0.4);
    }
}