#![allow(dead_code)]
use crate::simd;
/// Cosine distance between `a` and `b`, computed as `1 - a·b`.
///
/// No normalization is performed here, so this equals the true cosine
/// distance only when both vectors have unit L2 norm.
/// NOTE(review): presumably callers normalize vectors beforehand — confirm,
/// since for unnormalized input the result is just `1 - a·b` and is not
/// bounded to `[0, 2]`.
///
/// Returns `f32::INFINITY` when the slices have different lengths, matching
/// the convention used by `l2_distance` / `l2_distance_squared`.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched dimensions are treated as "infinitely far", like the L2 helpers.
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    1.0 - simd::dot(a, b)
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Returns `f32::INFINITY` when the slices have different lengths.
///
/// The squared distance from `l2_distance_squared` is computed via the
/// expansion `|a|² + |b|² - 2·a·b`. That expansion can yield a tiny negative
/// value through floating-point cancellation when `a` and `b` are nearly
/// equal, and `sqrt` of a negative number is NaN. Negative values are
/// therefore clamped to `0.0` before taking the root.
#[inline]
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    let squared = l2_distance_squared(a, b);
    // A plain comparison is used instead of `squared.max(0.0)`: NaN compares
    // false here and is propagated, whereas `max` would silently turn NaN into 0.
    if squared < 0.0 {
        0.0
    } else {
        squared.sqrt()
    }
}
/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// Computed via the expansion `|a|² + |b|² - 2·a·b` so that all three terms
/// go through the `simd::dot` kernel. Returns `f32::INFINITY` when the slices
/// have different lengths.
///
/// Precision caveat: the expansion suffers from cancellation when `a` and `b`
/// are close relative to their magnitudes. The raw result can lose most of its
/// significant digits or even come out slightly negative. A squared distance
/// can never be negative, so negative results are clamped to `0.0`. NaN
/// (e.g. from NaN inputs) is passed through unchanged.
#[inline]
pub fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    let a_squared = simd::dot(a, a);
    let b_squared = simd::dot(b, b);
    let ab_dot = simd::dot(a, b);
    let raw = a_squared + b_squared - 2.0 * ab_dot;
    // `raw < 0.0` is false for NaN, so NaN propagates instead of being masked
    // (unlike `raw.max(0.0)`, which would turn NaN into 0.0).
    if raw < 0.0 {
        0.0
    } else {
        raw
    }
}
/// Inner-product "distance" between `a` and `b`: the negated dot product.
///
/// Negation makes a larger inner product (more similar) map to a smaller
/// distance, so it can be ranked the same way as the other metrics.
///
/// Returns `f32::INFINITY` when the slices have different lengths, matching
/// the convention used by `l2_distance` / `l2_distance_squared`.
#[inline]
pub fn inner_product_distance(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched dimensions are treated as "infinitely far", like the L2 helpers.
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    -simd::dot(a, b)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_distance() {
        // Identical unit vectors are at distance 0.
        let a = [1.0, 0.0];
        let b = [1.0, 0.0];
        assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-5);
        // Orthogonal unit vectors are at distance 1.
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert!((cosine_distance(&a, &b) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_l2_distance() {
        // Classic 3-4-5 right triangle.
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert!((l2_distance(&a, &b) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_l2_distance_squared() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert!((l2_distance_squared(&a, &b) - 25.0).abs() < 1e-5);
        // The squared variant must agree with squaring the plain distance.
        let a2 = [1.0, 2.0];
        let b2 = [4.0, 6.0];
        let dist = l2_distance(&a2, &b2);
        let dist_sq = l2_distance_squared(&a2, &b2);
        assert!((dist * dist - dist_sq).abs() < 1e-5);
    }

    #[test]
    fn test_l2_identical_vectors() {
        // All terms of the dot-product expansion are identical, so the result
        // must be exactly zero.
        let a = [0.5, -1.25, 3.0];
        assert_eq!(l2_distance_squared(&a, &a), 0.0);
        assert_eq!(l2_distance(&a, &a), 0.0);
    }

    #[test]
    fn test_l2_mismatched_lengths() {
        let a = [1.0, 2.0, 3.0];
        let b = [1.0, 2.0];
        assert_eq!(l2_distance(&a, &b), f32::INFINITY);
        assert_eq!(l2_distance_squared(&a, &b), f32::INFINITY);
    }

    #[test]
    fn test_inner_product_distance() {
        // 1*3 + 2*4 = 11, negated.
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        assert!((inner_product_distance(&a, &b) + 11.0).abs() < 1e-5);
    }
}