use crate::float::kdtree::Axis;
use crate::traits::DistanceMetric;
/// Manhattan (L1) distance metric.
///
/// The `DistanceMetric` implementation for this type sums the absolute
/// per-axis differences between two points.
#[doc(alias = "minkowski")]
#[doc(alias = "l1")]
pub struct Manhattan {}
impl<A: Axis, const K: usize> DistanceMetric<A, K> for Manhattan {
    /// Returns the sum, over all `K` axes, of the absolute difference
    /// between the coordinates of `a` and `b`.
    #[inline]
    fn dist(a: &[A; K], b: &[A; K]) -> A {
        a.iter()
            .zip(b.iter())
            .fold(A::zero(), |sum, (&lhs, &rhs)| sum + (lhs - rhs).abs())
    }

    /// Returns the absolute difference between two single-axis coordinates.
    #[inline]
    fn dist1(a: A, b: A) -> A {
        let delta = a - b;
        delta.abs()
    }

    /// Folds a per-axis contribution `delta` into the running distance `rd`
    /// by adding the two.
    #[inline]
    fn accumulate(rd: A, delta: A) -> A {
        rd + delta
    }
}
/// Chebyshev distance metric.
///
/// The `DistanceMetric` implementation for this type returns the largest
/// absolute per-axis difference between two points.
pub struct Chebyshev {}
impl<A: Axis, const K: usize> DistanceMetric<A, K> for Chebyshev {
    /// Returns the largest absolute per-axis difference between `a` and `b`,
    /// starting the comparison from zero.
    #[inline]
    fn dist(a: &[A; K], b: &[A; K]) -> A {
        a.iter()
            .zip(b.iter())
            .fold(A::zero(), |widest, (&lhs, &rhs)| {
                let gap = (lhs - rhs).abs();
                widest.max(gap)
            })
    }

    /// Returns the absolute difference between two single-axis coordinates.
    #[inline]
    fn dist1(a: A, b: A) -> A {
        let delta = a - b;
        delta.abs()
    }

    /// Folds a per-axis contribution `delta` into the running distance `rd`
    /// by keeping the larger of the two.
    #[inline]
    fn accumulate(rd: A, delta: A) -> A {
        rd.max(delta)
    }
}
/// Squared Euclidean (L2) distance metric.
///
/// The `DistanceMetric` implementation for this type sums the squared
/// per-axis differences between two points; no square root is applied.
#[doc(alias = "minkowski")]
#[doc(alias = "l2")]
pub struct SquaredEuclidean {}
impl<A: Axis, const K: usize> DistanceMetric<A, K> for SquaredEuclidean {
    /// Returns the sum, over all `K` axes, of the squared difference between
    /// the coordinates of `a` and `b`.
    #[inline]
    fn dist(a: &[A; K], b: &[A; K]) -> A {
        a.iter()
            .zip(b.iter())
            .fold(A::zero(), |sum, (&lhs, &rhs)| {
                let delta = lhs - rhs;
                sum + delta * delta
            })
    }

    /// Returns the squared difference between two single-axis coordinates.
    #[inline]
    fn dist1(a: A, b: A) -> A {
        let delta = a - b;
        delta * delta
    }

    /// Folds a per-axis contribution `delta` into the running distance `rd`
    /// by adding the two.
    #[inline]
    fn accumulate(rd: A, delta: A) -> A {
        rd + delta
    }
}
/// Minkowski-style distance metric with a compile-time integer power `P`.
///
/// The `DistanceMetric` implementation for this type sums `|a_i - b_i|^P`
/// over all axes; no `P`-th root is applied to the sum.
///
/// `P` is validated at compile time by [`Minkowski::CHECK_P`]: the powers
/// `1` and `2` are rejected in favour of the dedicated metrics, and powers
/// that cannot be used as an exponent (`0`, or anything above `i32::MAX`)
/// are rejected as well.
#[doc(alias = "taxicab")]
#[doc(alias = "l1")]
#[doc(alias = "l2")]
#[doc(alias = "euclidean")]
pub struct Minkowski<const P: u32> {}

impl<const P: u32> Minkowski<P> {
    /// Compile-time validation of `P`.
    ///
    /// Referencing this constant from a monomorphised function forces the
    /// block below to be evaluated, turning an unsupported `P` into a
    /// compilation error.
    const CHECK_P: () = {
        if P == 0 {
            // With a zero power every per-axis term is 1 regardless of the
            // inputs, so the "distance" would be a constant.
            panic!("Minkowski<0> is not a valid distance metric. Use a power of 3 or greater.");
        }
        if P == 1 {
            panic!("Minkowski<1> is not recommended. Use `kiddo::Manhattan` metric instead.");
        }
        if P == 2 {
            panic!(
                "Minkowski<2> is not recommended. Use `kiddo::SquaredEuclidean` metric instead."
            );
        }
        if P > i32::MAX as u32 {
            // The metric raises to the power `P as i32`; a larger `P` would
            // wrap around to a negative exponent.
            panic!("Minkowski<P> requires P to be at most i32::MAX.");
        }
    };
}
impl<A: Axis, const K: usize, const P: u32> DistanceMetric<A, K> for Minkowski<P> {
    /// Returns the sum, over all `K` axes, of the absolute coordinate
    /// difference raised to the power `P`.
    #[inline]
    #[allow(clippy::let_unit_value)]
    fn dist(a: &[A; K], b: &[A; K]) -> A {
        // Touching the constant forces the compile-time check of `P`.
        let _ = Self::CHECK_P;
        let exponent = P as i32;
        a.iter()
            .zip(b.iter())
            .fold(A::zero(), |sum, (&lhs, &rhs)| {
                sum + (lhs - rhs).abs().powi(exponent)
            })
    }

    /// Returns the absolute difference of two single-axis coordinates raised
    /// to the power `P`.
    #[inline]
    #[allow(clippy::let_unit_value)]
    fn dist1(a: A, b: A) -> A {
        // Touching the constant forces the compile-time check of `P`.
        let _ = Self::CHECK_P;
        let gap = (a - b).abs();
        gap.powi(P as i32)
    }

    /// Folds a per-axis contribution `delta` into the running distance `rd`
    /// by adding the two.
    #[inline]
    fn accumulate(rd: A, delta: A) -> A {
        rd + delta
    }
}
/// Minkowski-style distance metric with a compile-time floating-point power.
///
/// `P_BITS` is the bit pattern of the `f64` power (as produced by
/// `f64::to_bits`), because a float cannot be used as a const generic
/// parameter directly. The `DistanceMetric` implementation for this type
/// sums `|a_i - b_i|^p` over all axes; no `p`-th root is applied to the sum.
///
/// The power is validated at compile time by [`MinkowskiF64::CHECK_P`].
#[doc(alias = "taxicab")]
#[doc(alias = "l1")]
#[doc(alias = "l2")]
#[doc(alias = "euclidean")]
pub struct MinkowskiF64<const P_BITS: u64> {}

impl<const P_BITS: u64> MinkowskiF64<P_BITS> {
    /// Compile-time validation of the power encoded in `P_BITS`.
    ///
    /// Referencing this constant from a monomorphised function forces the
    /// block below to be evaluated, turning an unsupported power into a
    /// compilation error. Rejected powers are: `1.0` and `2.0` (dedicated
    /// metrics exist), anything that is not a finite positive number, and
    /// any value whose fractional part is below `f64::EPSILON` (use
    /// `Minkowski<P>` with an integer power instead).
    const CHECK_P: () = {
        let p = f64::from_bits(P_BITS);
        if (p - 1.0).abs() < f64::EPSILON {
            panic!(
                "MinkowskiF64<P=1.0> is not recommended. Use `kiddo::Manhattan` metric instead."
            );
        }
        if (p - 2.0).abs() < f64::EPSILON {
            panic!(
                "MinkowskiF64<P=2.0> is not recommended. Use `kiddo::SquaredEuclidean` metric instead."
            );
        }
        // `fract()` keeps the sign of its input, so without this guard every
        // negative power (and zero) would fall into the "basically integer"
        // branch below and be reported with a misleading suggestion, while
        // NaN and infinite powers would not be reported at all.
        if !p.is_finite() || p <= 0.0 {
            panic!("MinkowskiF64<P as F64> requires a power that is a finite, positive number.");
        }
        if p.fract() < f64::EPSILON {
            panic!(
                "MinkowskiF64<P as F64> with power that is basically integer. Consider using Minkowski<P as u32> instead.",
            );
        }
    };
}
impl<A: Axis, const K: usize, const P_BITS: u64> DistanceMetric<A, K> for MinkowskiF64<P_BITS> {
    /// Returns the sum, over all `K` axes, of the absolute coordinate
    /// difference raised to the power encoded in `P_BITS`.
    ///
    /// Each difference is converted to `f64` for the `powf` call and the
    /// result is converted back to `A`; both conversions are unwrapped.
    #[inline]
    #[allow(clippy::let_unit_value)]
    fn dist(a: &[A; K], b: &[A; K]) -> A {
        // Touching the constant forces the compile-time check of the power.
        let _ = Self::CHECK_P;
        let p = f64::from_bits(P_BITS);
        a.iter()
            .zip(b.iter())
            .fold(A::zero(), |sum, (&lhs, &rhs)| {
                let gap = (lhs - rhs).abs().to_f64().unwrap();
                sum + A::from(gap.powf(p)).unwrap()
            })
    }

    /// Returns the absolute difference of two single-axis coordinates raised
    /// to the power encoded in `P_BITS`.
    #[inline]
    #[allow(clippy::let_unit_value)]
    fn dist1(a: A, b: A) -> A {
        // Touching the constant forces the compile-time check of the power.
        let _ = Self::CHECK_P;
        let gap = (a - b).abs().to_f64().unwrap();
        let raised = gap.powf(f64::from_bits(P_BITS));
        A::from(raised).unwrap()
    }

    /// Folds a per-axis contribution `delta` into the running distance `rd`
    /// by adding the two.
    #[inline]
    fn accumulate(rd: A, delta: A) -> A {
        rd + delta
    }
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
mod common_metric_tests {
use super::*;
#[rstest]
#[case::zeros_1d([0.0f32], [0.0f32])]
#[case::normal_1d([1.0f32], [2.0f32])]
#[case::neg_1d([-1.0f32], [1.0f32])]
#[case::zeros_2d([0.0f32, 0.0f32], [0.0f32, 0.0f32])]
#[case::normal_2d([1.0f32, 2.0f32], [3.0f32, 4.0f32])]
#[case::large_2d([1e30f32, 1e30f32], [-1e30f32, -1e30f32])]
#[case::zeros_3d([0.0f32, 0.0f32, 0.0f32], [0.0f32, 0.0f32, 0.0f32])]
#[case::normal_3d([1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32])]
#[case::zeros_4d([0.0f32; 4], [0.0f32; 4])]
#[case::normal_4d([1.0f32; 4], [2.0f32; 4])]
#[case::zeros_5d([0.0f32; 5], [0.0f32; 5])]
#[case::normal_5d([1.0f32; 5], [2.0f32; 5])]
fn test_metric_non_negativity<A: Axis, const K: usize, D: DistanceMetric<A, K>>(
#[values(
Manhattan {},
SquaredEuclidean {},
Chebyshev {},
Minkowski::<3> {},
MinkowskiF64::<{ 0.5f64.to_bits() }> {}
)]
_metric: D,
#[case] a: [A; K],
#[case] b: [A; K],
) {
let distance = D::dist(&a, &b);
assert!(distance >= A::zero());
}
#[rstest]
#[case::zeros_1d([0.0f32])]
#[case::normal_1d([1.0f32])]
#[case::zeros_2d([0.0f32, 0.0f32])]
#[case::normal_2d([1.0f32, 2.0f32])]
#[case::zeros_3d([0.0f32, 0.0f32, 0.0f32])]
#[case::normal_3d([1.0f32, 2.0f32, 3.0f32])]
#[case::zeros_4d([0.0f32; 4])]
#[case::zeros_5d([0.0f32; 5])]
fn test_metric_identity<A: Axis, const K: usize, D: DistanceMetric<A, K>>(
#[values(
Manhattan {},
SquaredEuclidean {},
Chebyshev {},
Minkowski::<3> {},
MinkowskiF64::<{ 0.5f64.to_bits() }> {}
)]
_metric: D,
#[case] a: [A; K],
) {
assert_eq!(D::dist(&a, &a), A::zero());
}
#[rstest]
#[case::normal_1d([1.0f64], [2.0f64])]
#[case::neg_1d([-1.0f64], [1.0f64])]
#[case::normal_2d([1.0f64, 2.0f64], [3.0f64, 4.0f64])]
#[case::normal_3d([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64])]
#[case::normal_4d([1.0f64; 4], [2.0f64; 4])]
#[case::normal_5d([1.0f64; 5], [2.0f64; 5])]
fn test_metric_symmetry<A: Axis, const K: usize, D: DistanceMetric<A, K>>(
#[values(
Manhattan {},
SquaredEuclidean {},
Chebyshev {},
Minkowski::<3> {},
MinkowskiF64::<{ 0.5f64.to_bits() }> {}
)]
_metric: D,
#[case] a: [A; K],
#[case] b: [A; K],
) {
assert_eq!(D::dist(&a, &b), D::dist(&b, &a));
}
}
mod manhattan_tests {
    use super::*;

    /// Hand-computed Manhattan distances between 2D points.
    #[rstest]
    #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)]
    #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)]
    #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)]
    #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)]
    #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 4.0f32)]
    #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 4.0f32)]
    fn test_manhattan_distance_2d(
        #[case] a: [f32; 2],
        #[case] b: [f32; 2],
        #[case] expected: f32,
    ) {
        let actual = Manhattan::dist(&a, &b);
        assert_eq!(actual, expected);
    }

    /// Hand-computed Manhattan distances between 3D points.
    #[rstest]
    #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)]
    #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 3.0f64], 6.0f64)]
    #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 9.0f64)]
    fn test_manhattan_distance_3d(
        #[case] a: [f64; 3],
        #[case] b: [f64; 3],
        #[case] expected: f64,
    ) {
        let actual = Manhattan::dist(&a, &b);
        assert_eq!(actual, expected);
    }

    /// Hand-computed Manhattan distances between 1D points.
    #[rstest]
    #[case([0.0f32], [0.0f32], 0.0f32)]
    #[case([0.0f32], [5.0f32], 5.0f32)]
    #[case([5.0f32], [0.0f32], 5.0f32)]
    #[case([-3.0f32], [7.0f32], 10.0f32)]
    fn test_manhattan_distance_1d(
        #[case] a: [f32; 1],
        #[case] b: [f32; 1],
        #[case] expected: f32,
    ) {
        let actual = Manhattan::dist(&a, &b);
        assert_eq!(actual, expected);
    }

    /// Four axes, each differing by 4, give a total of 16.
    #[test]
    fn test_manhattan_distance_4d() {
        let from = [1.0f32, 2.0f32, 3.0f32, 4.0f32];
        let to = [5.0f32, 6.0f32, 7.0f32, 8.0f32];
        assert_eq!(Manhattan::dist(&from, &to), 16.0f32);
    }

    /// Five axes, each differing by 5, give a total of 25.
    #[test]
    fn test_manhattan_distance_5d() {
        let from = [0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64];
        let to = [5.0f64, 6.0f64, 7.0f64, 8.0f64, 9.0f64];
        assert_eq!(Manhattan::dist(&from, &to), 25.0f64);
    }

    /// Single-axis distances, checked against a table of `(a, b, expected)`.
    #[test]
    fn test_manhattan_dist1() {
        let table = [
            (0.0f32, 0.0f32, 0.0f32),
            (1.0f32, 0.0f32, 1.0f32),
            (0.0f32, 1.0f32, 1.0f32),
            (-2.5f32, 3.5f32, 6.0f32),
            (1000.0f32, -1000.0f32, 2000.0f32),
        ];
        for (a, b, expected) in table {
            assert_eq!(
                <Manhattan as DistanceMetric<f32, 1>>::dist1(a, b),
                expected
            );
        }
    }
}
mod squared_euclidean_tests {
    use super::*;

    /// Hand-computed squared Euclidean distances between 2D points.
    #[rstest]
    #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)]
    #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)]
    #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)]
    #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)]
    #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 8.0f32)]
    #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 8.0f32)]
    #[case([0.0f32, 0.0f32], [3.0f32, 4.0f32], 25.0f32)]
    fn test_squared_euclidean_distance_2d(
        #[case] a: [f32; 2],
        #[case] b: [f32; 2],
        #[case] expected: f32,
    ) {
        let actual = SquaredEuclidean::dist(&a, &b);
        assert_eq!(actual, expected);
    }

    /// Hand-computed squared Euclidean distances between 3D points.
    #[rstest]
    #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)]
    #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 2.0f64], 9.0f64)]
    #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 27.0f64)]
    fn test_squared_euclidean_distance_3d(
        #[case] a: [f64; 3],
        #[case] b: [f64; 3],
        #[case] expected: f64,
    ) {
        let actual = SquaredEuclidean::dist(&a, &b);
        assert_eq!(actual, expected);
    }

    /// Hand-computed squared Euclidean distances between 1D points.
    #[rstest]
    #[case([0.0f32], [0.0f32], 0.0f32)]
    #[case([0.0f32], [5.0f32], 25.0f32)]
    #[case([5.0f32], [0.0f32], 25.0f32)]
    #[case([-3.0f32], [7.0f32], 100.0f32)]
    fn test_squared_euclidean_distance_1d(
        #[case] a: [f32; 1],
        #[case] b: [f32; 1],
        #[case] expected: f32,
    ) {
        let actual = SquaredEuclidean::dist(&a, &b);
        assert_eq!(actual, expected);
    }

    /// Single-axis distances, checked against a table of `(a, b, expected)`.
    #[test]
    fn test_squared_euclidean_dist1() {
        let table = [
            (0.0f32, 0.0f32, 0.0f32),
            (1.0f32, 0.0f32, 1.0f32),
            (0.0f32, 1.0f32, 1.0f32),
            (-2.5f32, 3.5f32, 36.0f32),
            (10.0f32, -10.0f32, 400.0f32),
        ];
        for (a, b, expected) in table {
            assert_eq!(
                <SquaredEuclidean as DistanceMetric<f32, 1>>::dist1(a, b),
                expected
            );
        }
    }

    /// Pins the three pairwise distances of a right-angled unit triangle.
    #[test]
    fn test_squared_euclidean_triangle_inequality_property() {
        let origin = [0.0f32, 0.0f32];
        let right = [1.0f32, 0.0f32];
        let corner = [1.0f32, 1.0f32];

        let origin_to_right = SquaredEuclidean::dist(&origin, &right);
        let origin_to_corner = SquaredEuclidean::dist(&origin, &corner);
        let right_to_corner = SquaredEuclidean::dist(&right, &corner);

        assert_eq!(origin_to_right, 1.0f32);
        assert_eq!(right_to_corner, 1.0f32);
        assert_eq!(origin_to_corner, 2.0f32);
    }
}
mod chebyshev_tests {
use super::*;
#[rstest]
#[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 1.0f32)] #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 2.0f32)] #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 2.0f32)] #[case([0.0f32, 0.0f32], [2.0f32, 1.0f32], 2.0f32)] #[case([0.0f32, 0.0f32], [1.0f32, 2.0f32], 2.0f32)] fn test_chebyshev_distance_2d(
#[case] a: [f32; 2],
#[case] b: [f32; 2],
#[case] expected: f32,
) {
assert_eq!(Chebyshev::dist(&a, &b), expected);
}
#[rstest]
#[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 3.0f64], 3.0f64)] #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 3.0f64)] fn test_chebyshev_distance_3d(
#[case] a: [f64; 3],
#[case] b: [f64; 3],
#[case] expected: f64,
) {
assert_eq!(Chebyshev::dist(&a, &b), expected);
}
#[rstest]
#[case([0.0f32], [0.0f32], 0.0f32)] #[case([0.0f32], [5.0f32], 5.0f32)] #[case([5.0f32], [0.0f32], 5.0f32)] #[case([-3.0f32], [7.0f32], 10.0f32)] fn test_chebyshev_distance_1d(
#[case] a: [f32; 1],
#[case] b: [f32; 1],
#[case] expected: f32,
) {
assert_eq!(Chebyshev::dist(&a, &b), expected);
}
#[test]
fn test_chebyshev_distance_4d() {
let a = [1.0f32, 2.0f32, 3.0f32, 4.0f32];
let b = [5.0f32, 6.0f32, 7.0f32, 8.0f32];
let expected = 4.0f32; assert_eq!(Chebyshev::dist(&a, &b), expected);
}
#[test]
fn test_chebyshev_distance_5d() {
let a = [0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64];
let b = [5.0f64, 6.0f64, 7.0f64, 8.0f64, 9.0f64];
let expected = 5.0f64; assert_eq!(Chebyshev::dist(&a, &b), expected);
}
#[rstest]
#[case(0.0f32, 0.0f32, 0.0f32)] #[case(1.0f32, 0.0f32, 1.0f32)] #[case(0.0f32, 1.0f32, 1.0f32)] #[case(-2.5f32, 3.5f32, 6.0f32)] #[case(1000.0f32, -1000.0f32, 2000.0f32)] fn test_chebyshev_dist1(#[case] a: f32, #[case] b: f32, #[case] expected: f32) {
assert_eq!(<Chebyshev as DistanceMetric<f32, 1>>::dist1(a, b), expected);
}
#[test]
fn test_chebyshev_symmetry() {
let a = [1.0f64, 2.0f64, 3.0f64];
let b = [4.0f64, 5.0f64, 6.0f64];
assert_eq!(Chebyshev::dist(&a, &b), Chebyshev::dist(&b, &a));
}
#[test]
fn test_chebyshev_identity() {
let a = [1.0f32, 2.0f32, 3.0f32];
assert_eq!(Chebyshev::dist(&a, &a), 0.0f32);
}
#[test]
fn test_chebyshev_non_negativity() {
let a = [1.0f32, 2.0f32];
let b = [3.0f32, 4.0f32];
let distance = Chebyshev::dist(&a, &b);
assert!(distance >= 0.0f32);
}
#[test]
fn test_chebyshev_max_property() {
let a = [0.0, 0.0];
let b = [3.0, 1.0];
let result = Chebyshev::dist(&a, &b);
assert_eq!(result, 3.0);
assert_ne!(result, 4.0);
assert_ne!(result, (10.0_f64).sqrt());
}
}
mod minkowski_tests {
use super::*;
#[rstest]
#[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)] #[case([0.0f32, 0.0f32], [2.0f32, 0.0f32], 8.0f32)] #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 16.0f32)] fn test_minkowski_3_distance_2d(
#[case] a: [f32; 2],
#[case] b: [f32; 2],
#[case] expected: f32,
) {
assert_eq!(Minkowski::<3>::dist(&a, &b), expected);
}
#[rstest]
#[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 1.0f64, 1.0f64], 3.0f64)]
#[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 81.0f64)] fn test_minkowski_3_distance_3d(
#[case] a: [f64; 3],
#[case] b: [f64; 3],
#[case] expected: f64,
) {
assert_eq!(Minkowski::<3>::dist(&a, &b), expected);
}
#[rstest]
#[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)] #[case([0.0f32, 0.0f32], [4.0f32, 9.0f32], 5.0f32)] #[case([1.0f32, 1.0f32], [5.0f32, 10.0f32], 5.0f32)] fn test_minkowski_05_distance_2d(
#[case] a: [f32; 2],
#[case] b: [f32; 2],
#[case] expected: f32,
) {
assert_eq!(MinkowskiF64::<{ 0.5f64.to_bits() }>::dist(&a, &b), expected);
}
#[rstest]
#[case([0.0f32], [0.0f32], 0.0f32)]
#[case([0.0f32], [2.0f32], 8.0f32)]
#[case([-2.0f32], [2.0f32], 64.0f32)] fn test_minkowski_3_distance_1d(
#[case] a: [f32; 1],
#[case] b: [f32; 1],
#[case] expected: f32,
) {
assert_eq!(Minkowski::<3>::dist(&a, &b), expected);
}
#[rstest]
#[case(0.0f32, 0.0f32, 0.0f32)]
#[case(1.0f32, 0.0f32, 1.0f32)]
#[case(0.0f32, 2.0f32, 8.0f32)]
#[case(-2.5f32, 3.5f32, 216.0f32)] fn test_minkowski_3_dist1(#[case] a: f32, #[case] b: f32, #[case] expected: f32) {
assert_eq!(
<Minkowski<3> as DistanceMetric<f32, 1>>::dist1(a, b),
expected
);
}
#[rstest]
#[case(0.0f32, 0.0f32, 0.0f32)]
#[case(1.0f32, 0.0f32, 1.0f32)]
#[case(0.0f32, 4.0f32, 2.0f32)]
#[case(10.0f32, 35.0f32, 5.0f32)] fn test_minkowski_05_dist1(#[case] a: f32, #[case] b: f32, #[case] expected: f32) {
assert_eq!(
<MinkowskiF64<{ 0.5f64.to_bits() }> as DistanceMetric<f32, 1>>::dist1(a, b),
expected
);
}
}
#[cfg(feature = "f16")]
mod f16_tests {
    use super::*;
    use half::f16;

    /// Manhattan distance from (0, 0) to (1, 1) with `f16` coordinates is 2.
    #[test]
    fn test_manhattan_f16() {
        let origin = [f16::from_f32(0.0); 2];
        let unit_corner = [f16::from_f32(1.0); 2];
        assert_eq!(Manhattan::dist(&origin, &unit_corner), f16::from_f32(2.0));
    }

    /// Squared Euclidean distance from (0, 0) to (1, 1) with `f16`
    /// coordinates is 2.
    #[test]
    fn test_squared_euclidean_f16() {
        let origin = [f16::from_f32(0.0); 2];
        let unit_corner = [f16::from_f32(1.0); 2];
        assert_eq!(
            SquaredEuclidean::dist(&origin, &unit_corner),
            f16::from_f32(2.0)
        );
    }
}
mod integration_tests {
use super::*;
use crate::KdTree;
use rand::prelude::*;
use rand_distr::Normal;
use rstest::rstest;
/// Selects which point set `DataScenario::get` produces for the shared
/// query helpers in this module.
#[derive(Debug, Clone, Copy)]
enum DataScenario {
// Fixed, hand-written points; the helpers additionally compare item ids
// for this scenario.
NoTies,
// Fixed, hand-written points that include repeated or mirrored coordinates.
Ties,
// 2000 points drawn from a seeded normal distribution.
Gaussian,
}
/// Selects which tree implementation the shared query helpers build.
#[derive(Debug, Clone, Copy)]
enum TreeType {
// `crate::float::kdtree::KdTree`, filled point by point with `add`.
Mutable,
// `crate::immutable::float::kdtree::ImmutableKdTree`, built with `new_from_slice`.
Immutable,
}
impl DataScenario {
fn get(&self, dim: usize) -> Vec<Vec<f64>> {
match (self, dim) {
(DataScenario::NoTies, 1) => vec![
vec![1.0],
vec![2.0],
vec![4.0],
vec![7.0],
vec![-9.0],
vec![16.0],
],
(DataScenario::NoTies, 2) => vec![
vec![0.0, 0.0],
vec![1.1, 0.1],
vec![2.3, 0.4],
vec![3.6, 0.9],
vec![5.0, 1.6],
vec![6.5, 2.5],
],
(DataScenario::NoTies, 3) => vec![
vec![0.0, 0.0, 0.0],
vec![1.1, 0.1, 0.01],
vec![2.3, 0.4, 0.08],
vec![-3.6, -0.9, -0.27],
vec![5.0, 1.6, 0.64],
vec![6.5, 2.5, 1.25],
],
(DataScenario::NoTies, 4) => vec![
vec![0.0, 0.0, 0.0, 1000.0],
vec![1.1, 0.1, 0.01, 1000.001],
vec![2.3, 0.4, 0.08, 1000.008],
vec![3.6, 0.9, 0.27, 1000.027],
vec![5.0, 1.6, 0.64, 1000.256],
vec![6.5, 2.5, 1.25, 1000.625],
],
(DataScenario::Ties, 1) => vec![
vec![0.0],
vec![1.0],
vec![1.0],
vec![2.0],
vec![2.0],
vec![3.0],
],
(DataScenario::Ties, 2) => vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![-1.0, 0.0],
vec![0.0, -1.0],
vec![1.0, 1.0],
],
(DataScenario::Ties, 3) => vec![
vec![0.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0],
vec![-1.0, 0.0, 0.0],
vec![0.0, -1.0, 0.0],
],
(DataScenario::Ties, 4) => vec![
vec![0.0, 0.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 1.0],
vec![-1.0, 0.0, 0.0, 0.0],
],
(DataScenario::Gaussian, d) => {
let mut rng = StdRng::seed_from_u64(8757);
let normal = Normal::new(1.0, 10.0).unwrap();
let n_samples = 2000;
let mut data = vec![vec![0.0; d]; n_samples];
for i in 0..n_samples {
for j in 0..d {
data[i][j] = normal.sample(&mut rng);
}
}
data
}
_ => panic!("Unsupported dimension {} for scenario {:?}", dim, self),
}
}
}
/// Checks `nearest_n` on the selected tree type against a brute-force scan.
///
/// The scenario's points are zero-padded to six axes, the first point is
/// used as the query, and the `n` nearest results from the tree must match
/// the brute-force distances sorted in ascending order. For
/// `DataScenario::NoTies` the item ids must match as well.
fn run_nearest_n_test_helper<D: DistanceMetric<f64, 6>>(
    dim: usize,
    tree_type: TreeType,
    scenario: DataScenario,
    n: usize,
) {
    let data = scenario.get(dim);
    let query_point = &data[0];

    // Zero-pad every row up to the six axes of the tree.
    let points: Vec<[f64; 6]> = data
        .iter()
        .map(|row| {
            let mut padded = [0.0; 6];
            padded[..row.len()].copy_from_slice(row);
            padded
        })
        .collect();

    let mut query_arr = [0.0; 6];
    for (slot, &val) in query_arr.iter_mut().zip(query_point.iter()) {
        *slot = val;
    }

    // Brute-force reference: distance from the query to every point.
    let mut expected: Vec<(usize, f64)> = points
        .iter()
        .enumerate()
        .map(|(idx, pt)| (idx, D::dist(&query_arr, pt)))
        .collect();
    expected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

    println!(
        "Query: {:?}, TreeType: {:?}, Scenario: {:?}, dim={}, n={}",
        query_point, tree_type, scenario, dim, n
    );

    let results = match tree_type {
        TreeType::Mutable => {
            let mut tree: crate::float::kdtree::KdTree<f64, u64, 6, 2048, u32> =
                crate::float::kdtree::KdTree::new();
            for (id, pt) in points.iter().enumerate() {
                tree.add(pt, id as u64);
            }
            tree.nearest_n::<D>(&query_arr, n)
        }
        TreeType::Immutable => {
            let tree: crate::immutable::float::kdtree::ImmutableKdTree<f64, u64, 6, 2048> =
                crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points);
            tree.nearest_n::<D>(&query_arr, std::num::NonZero::new(n).unwrap())
        }
    };
    println!("Results (len: {}):", results.len());

    assert_eq!(results[0].item, 0, "First result should be the query point");
    assert_eq!(
        results[0].distance, 0.0,
        "First result distance should be 0.0"
    );

    for (rank, result) in results.iter().enumerate() {
        let wanted = expected[rank].1;
        assert_eq!(
            result.distance, wanted,
            "Distance at index {} should be {}, but was {}",
            rank, wanted, result.distance
        );
    }

    // Item ids are only comparable when the scenario has no tied distances.
    if matches!(scenario, DataScenario::NoTies) {
        for (rank, result) in results.iter().enumerate() {
            let expected_id = expected[rank].0;
            assert_eq!(
                result.item, expected_id as u64,
                "Result {}: item ID mismatch. Expected {}, got {}",
                rank, expected_id, result.item
            );
        }
    }
}
/// Runs `run_nearest_n_test_helper` with the `Chebyshev` metric, parametrised
/// by rstest over the listed tree types, scenarios, result counts and
/// dimensions.
#[rstest]
fn test_nearest_n_chebyshev(
#[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType,
#[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)]
scenario: DataScenario,
#[values(1, 2, 3, 4, 5, 6)] n: usize,
#[values(1, 2, 3, 4)] dim: usize,
) {
run_nearest_n_test_helper::<Chebyshev>(dim, tree_type, scenario, n);
}
/// Runs `run_nearest_n_test_helper` with the `SquaredEuclidean` metric,
/// parametrised by rstest over the listed tree types, scenarios, result
/// counts and dimensions.
#[rstest]
fn test_nearest_n_squared_euclidean(
#[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType,
#[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)]
scenario: DataScenario,
#[values(1, 2, 3, 4, 5, 6)] n: usize,
#[values(1, 2, 3, 4)] dim: usize,
) {
run_nearest_n_test_helper::<SquaredEuclidean>(dim, tree_type, scenario, n);
}
/// Runs `run_nearest_n_test_helper` with the `Manhattan` metric, parametrised
/// by rstest over the listed tree types, scenarios, result counts and
/// dimensions.
#[rstest]
fn test_nearest_n_manhattan(
#[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType,
#[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)]
scenario: DataScenario,
#[values(1, 2, 3, 4, 5, 6)] n: usize,
#[values(1, 2, 3, 4)] dim: usize,
) {
run_nearest_n_test_helper::<Manhattan>(dim, tree_type, scenario, n);
}
/// Runs `run_nearest_n_test_helper` with `Minkowski<3>` and `Minkowski<4>`.
///
/// The power is a const generic, so the runtime value `p` is dispatched to
/// the matching monomorphised helper.
#[rstest]
fn test_nearest_n_minkowski(
    #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType,
    #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)]
    scenario: DataScenario,
    #[values(1, 2, 3, 4, 5, 6)] n: usize,
    #[values(1, 2, 3, 4)] dim: usize,
    #[values(3, 4)] p: u32,
) {
    if p == 3 {
        run_nearest_n_test_helper::<Minkowski<3>>(dim, tree_type, scenario, n);
    } else if p == 4 {
        run_nearest_n_test_helper::<Minkowski<4>>(dim, tree_type, scenario, n);
    } else {
        unreachable!()
    }
}
/// Runs `run_nearest_n_test_helper` with `MinkowskiF64` for the powers 0.5
/// and 1.5.
///
/// The power is a const generic, so the runtime value `p` is dispatched to
/// the matching monomorphised helper.
#[rstest]
fn test_nearest_n_minkowski_f64(
    #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType,
    #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)]
    scenario: DataScenario,
    #[values(1, 2, 3, 4, 5, 6)] n: usize,
    #[values(1, 2, 3, 4)] dim: usize,
    #[values(0.5, 1.5)] p: f64,
) {
    let is_power = |target: f64| (p - target).abs() < f64::EPSILON;

    if is_power(0.5) {
        run_nearest_n_test_helper::<MinkowskiF64<{ 0.5f64.to_bits() }>>(
            dim, tree_type, scenario, n,
        );
        return;
    }
    if is_power(1.5) {
        run_nearest_n_test_helper::<MinkowskiF64<{ 1.5f64.to_bits() }>>(
            dim, tree_type, scenario, n,
        );
        return;
    }
    unreachable!()
}
/// Manhattan `nearest_n` around the origin on a small hand-written point set.
#[test]
fn test_nearest_n_manhattan_distance() {
    let mut tree: KdTree<f32, 2> = KdTree::new();
    let entries = [
        ([0.0f32, 0.0f32], 0),
        ([1.0f32, 0.0f32], 1),
        ([0.0f32, 1.0f32], 2),
        ([2.0f32, 0.0f32], 3),
        ([0.0f32, 2.0f32], 4),
        ([3.0f32, 3.0f32], 5),
    ];
    for (coords, id) in entries {
        tree.add(&coords, id);
    }

    let origin = [0.0f32, 0.0f32];
    let nearest = tree.nearest_n::<Manhattan>(&origin, 4);
    assert_eq!(nearest.len(), 4);

    // The first three results have a fixed expected item and distance.
    let leading: [(u64, f32); 3] = [(0, 0.0), (1, 1.0), (2, 1.0)];
    for (rank, (item, distance)) in leading.into_iter().enumerate() {
        assert_eq!(nearest[rank].item, item);
        assert_eq!(nearest[rank].distance, distance);
    }

    // Items 3 and 4 are both at distance 2, so either may come fourth.
    assert!(matches!(nearest[3].item, 3 | 4));
    assert_eq!(nearest[3].distance, 2.0);
}
/// Squared Euclidean `nearest_n` around the origin on a small hand-written
/// point set.
#[test]
fn test_nearest_n_squared_euclidean_distance() {
    let mut tree: KdTree<f64, 2> = KdTree::new();
    let entries = [
        ([0.0, 0.0], 0),
        ([1.0, 0.0], 1),
        ([0.0, 1.0], 2),
        ([1.0, 1.0], 3),
        ([2.0, 0.0], 4),
        ([0.0, 2.0], 5),
        ([3.0, 4.0], 6),
    ];
    for (coords, id) in entries {
        tree.add(&coords, id);
    }

    let origin = [0.0, 0.0];
    let nearest = tree.nearest_n::<SquaredEuclidean>(&origin, 5);
    assert_eq!(nearest.len(), 5);

    // Expected (item, distance) for every rank, in order.
    let wanted: [(u64, f64); 5] = [(0, 0.0), (1, 1.0), (2, 1.0), (3, 2.0), (4, 4.0)];
    for (rank, (item, distance)) in wanted.into_iter().enumerate() {
        assert_eq!(nearest[rank].item, item);
        assert_eq!(nearest[rank].distance, distance);
    }

    assert!(nearest[4].distance > nearest[3].distance);
}
/// Queries the same tree with Manhattan and squared Euclidean metrics and
/// pins the distances each metric reports, both through `nearest_n` and by
/// calling `dist` directly.
#[test]
fn test_nearest_n_different_metrics_produce_different_orderings() {
    let mut tree: KdTree<f32, 2> = KdTree::new();
    let entries = [
        ([0.0, 0.0], 0),
        ([2.0, 1.0], 1),
        ([1.0, 2.0], 2),
        ([3.0, 0.0], 3),
        ([0.0, 3.0], 4),
    ];
    for (coords, id) in entries {
        tree.add(&coords, id);
    }

    let origin = [0.0, 0.0];
    let by_manhattan = tree.nearest_n::<Manhattan>(&origin, 3);
    let by_euclidean = tree.nearest_n::<SquaredEuclidean>(&origin, 3);

    // Both metrics find the query point itself first.
    assert_eq!(by_manhattan[0].item, 0);
    assert_eq!(by_euclidean[0].item, 0);

    assert_eq!(by_manhattan[1].distance, 3.0);
    assert_eq!(by_manhattan[2].distance, 3.0);
    assert_eq!(by_euclidean[1].distance, 5.0);
    assert_eq!(by_euclidean[2].distance, 5.0);

    let runner_ups: Vec<u64> = by_euclidean[1..3].iter().map(|nn| nn.item).collect();
    assert!(runner_ups.contains(&1) || runner_ups.contains(&2));

    // Direct metric calls: (point, Manhattan distance, squared Euclidean distance).
    let direct = [
        ([2.0f32, 1.0f32], 3.0f32, 5.0f32),
        ([1.0f32, 2.0f32], 3.0f32, 5.0f32),
        ([3.0f32, 0.0f32], 3.0f32, 9.0f32),
    ];
    for (point, manhattan, _) in direct {
        assert_eq!(Manhattan::dist(&origin, &point), manhattan);
    }
    for (point, _, euclidean) in direct {
        assert_eq!(SquaredEuclidean::dist(&origin, &point), euclidean);
    }
}
/// Manhattan `nearest_n` in 3D: the query point and its three unit-distance
/// neighbours are returned, the two farther points are not.
#[test]
fn test_nearest_n_3d_different_metrics() {
    let mut tree: KdTree<f64, 3> = KdTree::new();
    let entries = [
        ([1.0, 1.0, 1.0], 0),
        ([2.0, 1.0, 1.0], 1),
        ([1.0, 2.0, 1.0], 2),
        ([1.0, 1.0, 2.0], 3),
        ([3.0, 1.0, 1.0], 4),
        ([0.0, 0.0, 0.0], 5),
    ];
    for (coords, id) in entries {
        tree.add(&coords, id);
    }

    let centre = [1.0, 1.0, 1.0];
    let nearest = tree.nearest_n::<Manhattan>(&centre, 4);
    assert_eq!(nearest.len(), 4);
    assert_eq!(nearest[0].item, 0);
    assert_eq!(nearest[0].distance, 0.0);

    // The order among the three tied neighbours is not fixed.
    let neighbours: Vec<u64> = nearest.iter().skip(1).take(3).map(|nn| nn.item).collect();
    for wanted in [1u64, 2, 3] {
        assert!(neighbours.contains(&wanted));
    }
    for neighbour in nearest.iter().skip(1).take(3) {
        assert_eq!(neighbour.distance, 1.0);
    }

    let returned: Vec<u64> = nearest.iter().map(|nn| nn.item).collect();
    for unwanted in [4u64, 5] {
        assert!(!returned.contains(&unwanted));
    }
}
/// Squared Euclidean `nearest_n` on a 10x10 integer grid, queried at (5, 5).
#[test]
fn test_nearest_n_large_scale() {
    let mut tree: KdTree<f32, 2> = KdTree::new();
    // Ids run row by row, so the point (x, y) gets the id `x * 10 + y`.
    for id in 0u64..100 {
        let (x, y) = (id / 10, id % 10);
        tree.add(&[x as f32, y as f32], id);
    }

    let centre = [5.0f32, 5.0f32];
    let nearest = tree.nearest_n::<SquaredEuclidean>(&centre, 10);
    assert_eq!(nearest.len(), 10);
    assert_eq!(nearest[0].item, 55);
    assert_eq!(nearest[0].distance, 0.0);

    // Results must be ordered by ascending distance.
    for pair in nearest.windows(2) {
        assert!(pair[1].distance >= pair[0].distance);
    }

    // The point itself, its four direct neighbours, then a diagonal one.
    let expected_distances = [0.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32];
    for (result, expected_dist) in nearest.iter().zip(expected_distances) {
        assert_eq!(result.distance, expected_dist);
    }
}
/// Chebyshev `nearest_n` around the origin: the three points at distance 1
/// (including the diagonal one) must all be among the closest results.
#[test]
fn test_nearest_n_chebyshev_distance() {
    let mut tree: KdTree<f32, 2> = KdTree::new();
    let entries = [
        ([0.0f32, 0.0f32], 0),
        ([1.0f32, 0.0f32], 1),
        ([0.0f32, 1.0f32], 2),
        ([2.0f32, 0.0f32], 3),
        ([0.0f32, 2.0f32], 4),
        ([1.0f32, 1.0f32], 5),
    ];
    for (coords, id) in entries {
        tree.add(&coords, id);
    }

    let origin = [0.0f32, 0.0f32];
    let nearest = tree.nearest_n::<Chebyshev>(&origin, 5);
    assert_eq!(nearest.len(), 5);
    assert_eq!(nearest[0].item, 0);
    assert_eq!(nearest[0].distance, 0.0);

    // Collect the ids of the non-query results that sit at distance 1.
    let at_unit_distance: Vec<u64> = nearest
        .iter()
        .skip(1)
        .take(4)
        .filter(|r| (r.distance - 1.0).abs() < 0.001)
        .map(|nn| nn.item)
        .collect();
    for wanted in [1u64, 2, 5] {
        assert!(at_unit_distance.contains(&wanted));
    }
}
/// Checks `within_exclusive` on the selected tree type against a brute-force
/// scan.
///
/// The scenario's points are zero-padded to six axes and the first point is
/// used as the query. The tree's results, sorted by distance, must have the
/// same length and distances (within `1e-10`) as the brute-force selection
/// of points inside `radius`; `inclusive` decides whether points exactly at
/// `radius` count. For `DataScenario::NoTies` the item ids must match too.
fn run_within_test_helper<D: DistanceMetric<f64, 6>>(
    dim: usize,
    tree_type: TreeType,
    scenario: DataScenario,
    radius: f64,
    inclusive: bool,
) {
    let data = scenario.get(dim);
    let query_point = &data[0];

    // Zero-pad every row up to the six axes of the tree.
    let points: Vec<[f64; 6]> = data
        .iter()
        .map(|row| {
            let mut padded = [0.0; 6];
            padded[..row.len()].copy_from_slice(row);
            padded
        })
        .collect();

    let mut query_arr = [0.0; 6];
    for (slot, &val) in query_arr.iter_mut().zip(query_point.iter()) {
        *slot = val;
    }

    // Brute-force reference: keep every point inside the radius.
    let is_inside = |dist: f64| {
        if inclusive {
            dist <= radius
        } else {
            dist < radius
        }
    };
    let mut expected: Vec<(usize, f64)> = points
        .iter()
        .enumerate()
        .filter_map(|(idx, pt)| {
            let dist = D::dist(&query_arr, pt);
            is_inside(dist).then_some((idx, dist))
        })
        .collect();
    expected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

    println!(
        "Within Query: TreeType: {:?}, Scenario: {:?}, dim={}, radius={}, inclusive={}",
        tree_type, scenario, dim, radius, inclusive
    );

    let mut results = match tree_type {
        TreeType::Mutable => {
            let mut tree: crate::float::kdtree::KdTree<f64, u64, 6, 2048, u32> =
                crate::float::kdtree::KdTree::new();
            for (id, pt) in points.iter().enumerate() {
                tree.add(pt, id as u64);
            }
            tree.within_exclusive::<D>(&query_arr, radius, inclusive)
        }
        TreeType::Immutable => {
            let tree: crate::immutable::float::kdtree::ImmutableKdTree<f64, u64, 6, 2048> =
                crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points);
            tree.within_exclusive::<D>(&query_arr, radius, inclusive)
        }
    };
    // Sort the tree's output so it can be compared position by position.
    results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

    println!(
        "Results (len: {}), Expected (len: {})",
        results.len(),
        expected.len()
    );
    assert_eq!(
        results.len(),
        expected.len(),
        "Result count mismatch. Expected {}, got {}",
        expected.len(),
        results.len()
    );

    for (rank, result) in results.iter().enumerate() {
        let wanted = expected[rank].1;
        assert!(
            (result.distance - wanted).abs() < 1e-10,
            "Distance at index {} should be {}, but was {}",
            rank,
            wanted,
            result.distance
        );
    }

    // Item ids are only comparable when the scenario has no tied distances.
    if matches!(scenario, DataScenario::NoTies) {
        for (rank, result) in results.iter().enumerate() {
            let expected_id = expected[rank].0;
            assert_eq!(
                result.item, expected_id as u64,
                "Result {}: item ID mismatch. Expected {}, got {}",
                rank, expected_id, result.item
            );
        }
    }
}
/// Runs `run_within_test_helper` with the `Chebyshev` metric, parametrised by
/// rstest over the listed tree types, scenarios, radii, dimensions and
/// inclusiveness flags.
#[rstest]
fn test_within_chebyshev(
#[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType,
#[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)]
scenario: DataScenario,
#[values(0.1, 0.5, 1.0, 2.0)] radius: f64,
#[values(1, 2, 3, 4)] dim: usize,
#[values(true, false)] inclusive: bool,
) {
run_within_test_helper::<Chebyshev>(dim, tree_type, scenario, radius, inclusive);
}
/// Runs `run_within_test_helper` with the `SquaredEuclidean` metric,
/// parametrised by rstest over the listed tree types, scenarios, radii,
/// dimensions and inclusiveness flags.
#[rstest]
fn test_within_squared_euclidean(
#[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType,
#[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)]
scenario: DataScenario,
#[values(0.1, 0.5, 1.0, 2.0)] radius: f64,
#[values(1, 2, 3, 4)] dim: usize,
#[values(true, false)] inclusive: bool,
) {
run_within_test_helper::<SquaredEuclidean>(dim, tree_type, scenario, radius, inclusive);
}
/// Runs `run_within_test_helper` with the `Manhattan` metric, parametrised by
/// rstest over the listed tree types, scenarios, radii, dimensions and
/// inclusiveness flags.
#[rstest]
fn test_within_manhattan(
#[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType,
#[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)]
scenario: DataScenario,
#[values(0.1, 0.5, 1.0, 2.0)] radius: f64,
#[values(1, 2, 3, 4)] dim: usize,
#[values(true, false)] inclusive: bool,
) {
run_within_test_helper::<Manhattan>(dim, tree_type, scenario, radius, inclusive);
}
/// A point lying exactly on the radius is returned only when the query is
/// inclusive, for both `within_exclusive` and `nearest_n_within_exclusive`.
#[rstest]
#[case(true, 1)]
#[case(false, 0)]
fn test_within_boundary_inclusiveness(
    #[case] inclusive: bool,
    #[case] expected_len: usize,
) {
    let mut tree: KdTree<f64, 2> = KdTree::new();
    tree.add(&[1.0, 0.0], 1);
    tree.add(&[2.0, 0.0], 2);

    let query = [0.0, 0.0];
    // Item 1 sits exactly at this squared distance from the query.
    let radius = 1.0;

    let unbounded = tree.within_exclusive::<SquaredEuclidean>(&query, radius, inclusive);
    assert_eq!(unbounded.len(), expected_len);
    if expected_len > 0 {
        let boundary_hit = &unbounded[0];
        assert_eq!(boundary_hit.item, 1);
        assert_eq!(boundary_hit.distance, 1.0);
    }

    let max_qty = std::num::NonZero::new(10).unwrap();
    let bounded = tree.nearest_n_within_exclusive::<SquaredEuclidean>(
        &query, radius, max_qty, true, inclusive,
    );
    assert_eq!(bounded.len(), expected_len);
}
/// Queries the same tree with Chebyshev and Manhattan metrics and pins the
/// closest non-query result each metric reports.
#[test]
fn test_chebyshev_vs_manhattan_ordering() {
    let mut tree: KdTree<f32, 2> = KdTree::new();
    let entries = [
        ([0.0f32, 0.0f32], 0),
        ([3.0f32, 1.0f32], 1),
        ([1.0f32, 3.0f32], 2),
        ([2.0f32, 2.0f32], 3),
        ([4.0f32, 0.5f32], 4),
    ];
    for (coords, id) in entries {
        tree.add(&coords, id);
    }

    let origin = [0.0f32, 0.0f32];
    let by_chebyshev = tree.nearest_n::<Chebyshev>(&origin, 4);
    let by_manhattan = tree.nearest_n::<Manhattan>(&origin, 4);

    // Both metrics find the query point itself first.
    assert_eq!(by_chebyshev[0].item, 0);
    assert_eq!(by_manhattan[0].item, 0);

    // Under Chebyshev the diagonal point (2, 2) is the closest neighbour.
    assert_eq!(by_chebyshev[1].item, 3);
    assert_eq!(by_chebyshev[1].distance, 2.0);

    let manhattan_neighbours: Vec<u64> = by_manhattan
        .iter()
        .skip(1)
        .take(3)
        .map(|nn| nn.item)
        .collect();
    assert!(manhattan_neighbours.contains(&1) || manhattan_neighbours.contains(&2));

    assert_eq!(by_chebyshev[1].distance, 2.0);
    assert_eq!(by_manhattan[1].distance, 4.0);
}
}
}