1#![allow(dead_code)]
4
/// A point in the plane tagged with an identifier.
///
/// The type is `Copy`, so it is passed around and stored by value.
#[derive(Clone, Copy, Debug)]
#[allow(dead_code)]
pub struct KdPoint2 {
    /// Horizontal coordinate.
    pub x: f32,
    /// Vertical coordinate.
    pub y: f32,
    /// Identifier carried alongside the coordinates.
    pub id: usize,
}
15
16impl KdPoint2 {
17 #[allow(dead_code)]
18 pub fn new(x: f32, y: f32, id: usize) -> Self {
19 Self { x, y, id }
20 }
21
22 fn dist_sq(self, other: [f32; 2]) -> f32 {
23 let dx = self.x - other[0];
24 let dy = self.y - other[1];
25 dx * dx + dy * dy
26 }
27}
28
29#[derive(Debug, Clone)]
31#[allow(dead_code)]
32enum KdNode2 {
33 Leaf(KdPoint2),
34 Split {
35 axis: usize,
36 split_val: f32,
37 left: Box<KdNode2>,
38 right: Box<KdNode2>,
39 },
40}
41
42#[derive(Debug, Clone, Default)]
44#[allow(dead_code)]
45pub struct KdTree2D {
46 root: Option<Box<KdNode2>>,
47 count: usize,
48}
49
50fn build_node(points: &mut [KdPoint2], depth: usize) -> Option<Box<KdNode2>> {
51 if points.is_empty() {
52 return None;
53 }
54 if points.len() == 1 {
55 return Some(Box::new(KdNode2::Leaf(points[0])));
56 }
57 let axis = depth % 2;
58 if axis == 0 {
59 points.sort_by(|a, b| a.x.partial_cmp(&b.x).unwrap_or(std::cmp::Ordering::Equal));
60 } else {
61 points.sort_by(|a, b| a.y.partial_cmp(&b.y).unwrap_or(std::cmp::Ordering::Equal));
62 }
63 let mid = points.len() / 2;
64 let split_val = if axis == 0 {
65 points[mid].x
66 } else {
67 points[mid].y
68 };
69 let left = build_node(&mut points[..mid], depth + 1);
70 let right = build_node(&mut points[mid..], depth + 1);
71 let node = match (left, right) {
72 (Some(l), Some(r)) => KdNode2::Split {
73 axis,
74 split_val,
75 left: l,
76 right: r,
77 },
78 (Some(l), None) => *l,
79 (None, Some(r)) => *r,
80 (None, None) => unreachable!(),
81 };
82 Some(Box::new(node))
83}
84
85fn nn_search(node: &KdNode2, query: [f32; 2], best: &mut Option<(f32, KdPoint2)>) {
86 match node {
87 KdNode2::Leaf(p) => {
88 let d = p.dist_sq(query);
89 if best.is_none_or(|(bd, _)| d < bd) {
90 *best = Some((d, *p));
91 }
92 }
93 KdNode2::Split {
94 axis,
95 split_val,
96 left,
97 right,
98 } => {
99 let qval = if *axis == 0 { query[0] } else { query[1] };
100 let (near, far) = if qval < *split_val {
101 (left.as_ref(), right.as_ref())
102 } else {
103 (right.as_ref(), left.as_ref())
104 };
105 nn_search(near, query, best);
106 let plane_dist = (qval - split_val) * (qval - split_val);
107 if best.is_none_or(|(bd, _)| plane_dist < bd) {
108 nn_search(far, query, best);
109 }
110 }
111 }
112}
113
114impl KdTree2D {
115 #[allow(dead_code)]
117 pub fn build(points: &[KdPoint2]) -> Self {
118 let mut pts = points.to_vec();
119 let n = pts.len();
120 Self {
121 root: build_node(&mut pts, 0),
122 count: n,
123 }
124 }
125
126 #[allow(dead_code)]
128 pub fn nearest(&self, query: [f32; 2]) -> Option<(KdPoint2, f32)> {
129 let root = self.root.as_ref()?;
130 let mut best = None;
131 nn_search(root, query, &mut best);
132 best.map(|(d, p)| (p, d))
133 }
134
135 #[allow(dead_code)]
137 pub fn len(&self) -> usize {
138 self.count
139 }
140
141 #[allow(dead_code)]
143 pub fn is_empty(&self) -> bool {
144 self.count == 0
145 }
146}
147
148#[allow(dead_code)]
150pub fn kd2_build(xys: &[[f32; 2]]) -> KdTree2D {
151 let pts: Vec<KdPoint2> = xys
152 .iter()
153 .enumerate()
154 .map(|(i, &[x, y])| KdPoint2::new(x, y, i))
155 .collect();
156 KdTree2D::build(&pts)
157}
158
159#[allow(dead_code)]
161pub fn kd2_nn_dist_sq(tree: &KdTree2D, query: [f32; 2]) -> Option<f32> {
162 tree.nearest(query).map(|(_, d)| d)
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-point fixture shared by several tests; ids follow input order.
    fn sample_tree() -> KdTree2D {
        let coords: [[f32; 2]; 4] = [[0.0, 0.0], [3.0, 0.0], [1.0, 2.0], [5.0, 5.0]];
        kd2_build(&coords)
    }

    /// Querying a tree built from no points yields nothing.
    #[test]
    fn empty_tree_returns_none() {
        let tree = KdTree2D::build(&[]);
        let hit = tree.nearest([0.0, 0.0]);
        assert!(hit.is_none());
    }

    /// With one stored point, that point is always the answer.
    #[test]
    fn single_point_is_nearest() {
        let tree = kd2_build(&[[2.0, 3.0]]);
        let (point, _) = tree.nearest([0.0, 0.0]).expect("should succeed");
        assert_eq!(point.id, 0);
    }

    /// A query on top of a stored point finds it at (near) zero distance.
    #[test]
    fn nearest_to_origin() {
        let tree = sample_tree();
        let (point, dist_sq) = tree.nearest([0.0, 0.0]).expect("should succeed");
        assert_eq!(point.id, 0);
        assert!(dist_sq < 1e-5);
    }

    /// The outlying fixture point is found when queried directly.
    #[test]
    fn nearest_to_far_point() {
        let tree = sample_tree();
        let (point, _) = tree.nearest([5.0, 5.0]).expect("should succeed");
        assert_eq!(point.id, 3);
    }

    /// `len` reports the number of input points.
    #[test]
    fn tree_len_matches_input() {
        let tree = sample_tree();
        assert_eq!(tree.len(), 4);
    }

    /// A populated tree is not empty.
    #[test]
    fn is_empty_false_for_nonempty() {
        let tree = sample_tree();
        assert!(!tree.is_empty());
    }

    /// A tree built from an empty slice is empty.
    #[test]
    fn is_empty_true_for_empty() {
        let tree = KdTree2D::build(&[]);
        assert!(tree.is_empty());
    }

    /// The distance helper reports (near) zero for an exact hit.
    #[test]
    fn nn_dist_sq_is_zero_for_exact_match() {
        let tree = kd2_build(&[[1.0, 1.0], [2.0, 2.0]]);
        let dist_sq = kd2_nn_dist_sq(&tree, [1.0, 1.0]).expect("should succeed");
        assert!(dist_sq < 1e-6);
    }

    /// Of two candidates, the closer one is returned.
    #[test]
    fn nearest_among_two_picks_closer() {
        let tree = kd2_build(&[[0.0, 0.0], [10.0, 0.0]]);
        let (point, _) = tree.nearest([3.0, 0.0]).expect("should succeed");
        assert_eq!(point.id, 0);
    }

    /// Building from raw pairs keeps every input point.
    #[test]
    fn build_from_raw_xy_assigns_ids() {
        let tree = kd2_build(&[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]);
        assert_eq!(tree.len(), 3);
    }
}