1use crate::{ItemAndDistance, KdPoint};
2
3pub fn kd_nearest<'a, T: KdPoint>(
4 kdtree: &'a [T],
5 query: &impl KdPoint<Scalar = T::Scalar, Dim = T::Dim>,
6) -> ItemAndDistance<'a, T, T::Scalar> {
7 kd_nearest_by(kdtree, query, |item, k| item.at(k))
8}
9
10pub fn kd_nearest_by<'a, T, P: KdPoint>(
11 kdtree: &'a [T],
12 query: &P,
13 get: impl Fn(&T, usize) -> P::Scalar + Copy,
14) -> ItemAndDistance<'a, T, P::Scalar> {
15 fn distance_squared<P: KdPoint, T>(
16 p1: &P,
17 p2: &T,
18 get: impl Fn(&T, usize) -> P::Scalar,
19 ) -> P::Scalar {
20 let mut squared_distance = <P::Scalar as num_traits::Zero>::zero();
21 for i in 0..P::dim() {
22 let diff = p1.at(i) - get(p2, i);
23 squared_distance += diff * diff;
24 }
25 squared_distance
26 }
27 fn recurse<'a, T, Q: KdPoint>(
28 nearest: &mut ItemAndDistance<'a, T, Q::Scalar>,
29 kdtree: &'a [T],
30 get: impl Fn(&T, usize) -> Q::Scalar + Copy,
31 query: &Q,
32 axis: usize,
33 ) {
34 let mid_idx = kdtree.len() / 2;
35 let item = &kdtree[mid_idx];
36 let squared_distance = distance_squared(query, item, get);
37 if squared_distance < nearest.squared_distance {
38 nearest.item = item;
39 nearest.squared_distance = squared_distance;
40 use num_traits::Zero;
41 if nearest.squared_distance.is_zero() {
42 return;
43 }
44 }
45 let mid_pos = get(item, axis);
46 let [branch1, branch2] = if query.at(axis) < mid_pos {
47 [&kdtree[..mid_idx], &kdtree[mid_idx + 1..]]
48 } else {
49 [&kdtree[mid_idx + 1..], &kdtree[..mid_idx]]
50 };
51 if !branch1.is_empty() {
52 recurse(nearest, branch1, get, query, (axis + 1) % Q::dim());
53 }
54 if !branch2.is_empty() {
55 let diff = query.at(axis) - mid_pos;
56 if diff * diff < nearest.squared_distance {
57 recurse(nearest, branch2, get, query, (axis + 1) % Q::dim());
58 }
59 }
60 }
61 assert!(!kdtree.is_empty());
62 let mut nearest = ItemAndDistance {
63 item: &kdtree[0],
64 squared_distance: distance_squared(query, &kdtree[0], get),
65 };
66 recurse(&mut nearest, kdtree, get, query, 0);
67 nearest
68}
69
70#[allow(dead_code)]
71pub fn kd_nearest_with<T, Scalar>(
72 kdtree: &[T],
73 dim: usize,
74 kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
75) -> ItemAndDistance<'_, T, Scalar>
76where
77 Scalar: num_traits::NumAssign + Copy + PartialOrd,
78{
79 fn squared_distance<T, Scalar: num_traits::NumAssign + Copy>(
80 item: &T,
81 dim: usize,
82 kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
83 ) -> Scalar {
84 let mut squared_distance = Scalar::zero();
85 for k in 0..dim {
86 let diff = kd_difference(item, k);
87 squared_distance += diff * diff;
88 }
89 squared_distance
90 }
91 fn recurse<'a, T, Scalar>(
92 nearest: &mut ItemAndDistance<'a, T, Scalar>,
93 kdtree: &'a [T],
94 axis: usize,
95 dim: usize,
96 kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
97 ) where
98 Scalar: num_traits::NumAssign + Copy + PartialOrd,
99 {
100 let mid_idx = kdtree.len() / 2;
101 let mid = &kdtree[mid_idx];
102 let squared_distance = squared_distance(mid, dim, kd_difference);
103 if squared_distance < nearest.squared_distance {
104 *nearest = ItemAndDistance {
105 item: mid,
106 squared_distance,
107 };
108 if nearest.squared_distance.is_zero() {
109 return;
110 }
111 }
112 let [branch1, branch2] = if kd_difference(mid, axis) < Scalar::zero() {
113 [&kdtree[..mid_idx], &kdtree[mid_idx + 1..]]
114 } else {
115 [&kdtree[mid_idx + 1..], &kdtree[..mid_idx]]
116 };
117 if !branch1.is_empty() {
118 recurse(nearest, branch1, (axis + 1) % dim, dim, kd_difference);
119 }
120 if !branch2.is_empty() {
121 let diff = kd_difference(mid, axis);
122 if diff * diff < nearest.squared_distance {
123 recurse(nearest, branch2, (axis + 1) % dim, dim, kd_difference);
124 }
125 }
126 }
127 assert!(!kdtree.is_empty());
128 let mut nearest = ItemAndDistance {
129 item: &kdtree[0],
130 squared_distance: squared_distance(&kdtree[0], dim, kd_difference),
131 };
132 recurse(&mut nearest, kdtree, 0, dim, kd_difference);
133 nearest
134}