use std::iter::zip;
use crate::math::{point::Point, FloatNumber};
/// Selects how the distance between two points is computed.
///
/// The chosen variant is applied by [`DistanceMetric::measure`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DistanceMetric {
    /// Square root of the sum of squared per-coordinate differences.
    ///
    /// This is the variant returned by `DistanceMetric::default()`.
    #[default]
    Euclidean,
    /// Sum of squared per-coordinate differences, without the final square root.
    SquaredEuclidean,
}
impl DistanceMetric {
    /// Measures the distance between two points using this metric.
    ///
    /// # Type Parameters
    /// * `T` - The floating point type of the coordinates.
    /// * `N` - The number of coordinates in each point.
    ///
    /// # Arguments
    /// * `point1` - The first point.
    /// * `point2` - The second point.
    ///
    /// # Returns
    /// The sum of squared coordinate differences for
    /// `DistanceMetric::SquaredEuclidean`, or the square root of that sum for
    /// `DistanceMetric::Euclidean`.
    #[inline]
    #[must_use]
    pub fn measure<T, const N: usize>(&self, point1: &Point<T, N>, point2: &Point<T, N>) -> T
    where
        T: FloatNumber,
    {
        // Both variants start from the same squared sum; only the final step differs.
        let squared = squared_euclidean(point1, point2);
        match *self {
            Self::Euclidean => squared.sqrt(),
            Self::SquaredEuclidean => squared,
        }
    }
}
/// Computes the sum of squared per-coordinate differences between two points.
///
/// # Type Parameters
/// * `T` - The floating point type of the coordinates.
/// * `N` - The number of coordinates in each point.
///
/// # Arguments
/// * `point1` - The first point.
/// * `point2` - The second point.
///
/// # Returns
/// The accumulated value of `(a - b) * (a - b)` over the paired coordinates.
#[inline]
#[must_use]
fn squared_euclidean<T, const N: usize>(point1: &Point<T, N>, point2: &Point<T, N>) -> T
where
    T: FloatNumber,
{
    zip(point1.iter(), point2.iter())
        // Pair up coordinates and take their difference...
        .map(|(lhs, rhs)| *lhs - *rhs)
        // ...then square each difference before accumulating.
        .map(|delta| delta * delta)
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit offsets on both axes give a Euclidean distance of sqrt(2).
    #[test]
    fn test_euclidean_distance() {
        // Arrange
        let from = [0.0, 1.0];
        let to = [1.0, 0.0];
        let expected = 2.0_f32.sqrt();

        // Act
        let actual = DistanceMetric::Euclidean.measure(&from, &to);

        // Assert
        assert_eq!(actual, expected);
    }

    /// The squared metric skips the square root, so the same points give 2.
    #[test]
    fn test_square_euclidean_distance() {
        // Arrange
        let from = [0.0, 1.0];
        let to = [1.0, 0.0];

        // Act
        let actual = DistanceMetric::SquaredEuclidean.measure(&from, &to);

        // Assert
        assert_eq!(actual, 2.0);
    }

    /// A point measured against itself is at distance zero under both metrics.
    #[test]
    fn test_identical_points() {
        let same = [1.0, 2.0];
        for metric in [DistanceMetric::Euclidean, DistanceMetric::SquaredEuclidean] {
            assert_eq!(metric.measure(&same, &same), 0.0);
        }
    }

    /// A NaN coordinate makes the measured distance NaN.
    #[test]
    fn test_with_nan_values() {
        // Arrange
        let with_nan = [0.0, f32::NAN];
        let origin = [0.0, 0.0];

        // Act
        let actual = DistanceMetric::Euclidean.measure(&with_nan, &origin);

        // Assert
        assert!(actual.is_nan());
    }

    /// Three coordinates, each differing by 3: squared sum 27, distance sqrt(27).
    #[test]
    fn test_three_dimensional_points() {
        // Arrange
        let from = [1.0, 2.0, 3.0];
        let to = [4.0, 5.0, 6.0];
        let expected_squared = 27.0_f32;

        // Act
        let euclidean = DistanceMetric::Euclidean.measure(&from, &to);
        let squared = DistanceMetric::SquaredEuclidean.measure(&from, &to);

        // Assert
        assert_eq!(euclidean, expected_squared.sqrt());
        assert_eq!(squared, expected_squared);
    }
}