1use crate::coords::Coordinates;
4use crate::distance::Proximity;
5use crate::lp::Minkowski;
6use crate::knn::{ExactNeighbors, NearestNeighbors, Neighborhood};
7use crate::util::Ordered;
8
9use num_traits::Signed;
10
11use alloc::boxed::Box;
12use alloc::vec::Vec;
13
14#[derive(Debug)]
16struct KdNode<T> {
17 item: T,
19 left: Option<Box<Self>>,
21 right: Option<Box<Self>>,
23}
24
25impl<T: Coordinates> KdNode<T> {
26 fn new(item: T) -> Self {
28 Self {
29 item,
30 left: None,
31 right: None,
32 }
33 }
34
35 fn balanced<I: IntoIterator<Item = T>>(items: I) -> Option<Self> {
37 let mut nodes: Vec<_> = items
38 .into_iter()
39 .map(Self::new)
40 .map(Box::new)
41 .map(Some)
42 .collect();
43
44 Self::balanced_recursive(&mut nodes, 0)
45 .map(|node| *node)
46 }
47
48 fn balanced_recursive(nodes: &mut [Option<Box<Self>>], level: usize) -> Option<Box<Self>> {
50 if nodes.is_empty() {
51 return None;
52 }
53
54 nodes.sort_unstable_by_key(|x| Ordered::new(x.as_ref().unwrap().item.coord(level)));
55
56 let (left, right) = nodes.split_at_mut(nodes.len() / 2);
57 let (node, right) = right.split_first_mut().unwrap();
58 let mut node = node.take().unwrap();
59
60 let next = (level + 1) % node.item.dims();
61 node.left = Self::balanced_recursive(left, next);
62 node.right = Self::balanced_recursive(right, next);
63
64 Some(node)
65 }
66
67 fn push(&mut self, item: T, level: usize) {
69 let next = (level + 1) % item.dims();
70
71 if item.coord(level) <= self.item.coord(level) {
72 if let Some(left) = &mut self.left {
73 left.push(item, next);
74 } else {
75 self.left = Some(Box::new(Self::new(item)));
76 }
77 } else {
78 if let Some(right) = &mut self.right {
79 right.push(item, next);
80 } else {
81 self.right = Some(Box::new(Self::new(item)));
82 }
83 }
84 }
85}
86
87pub trait KdProximity<V: ?Sized = Self>
89where
90 Self: Coordinates<Value = V::Value>,
91 Self: Proximity<V>,
92 Self::Value: PartialOrd<Self::Distance>,
93 V: Coordinates,
94{}
95
96impl<K, V> KdProximity<V> for K
98where
99 K: Coordinates<Value = V::Value>,
100 K: Proximity<V>,
101 K::Value: PartialOrd<K::Distance>,
102 V: Coordinates,
103{}
104
105trait KdSearch<K, V, N>: Copy
106where
107 K: KdProximity<V>,
108 K::Value: PartialOrd<K::Distance>,
109 V: Coordinates + Copy,
110 N: Neighborhood<K, V>,
111{
112 fn item(self) -> V;
114
115 fn left(self) -> Option<Self>;
117
118 fn right(self) -> Option<Self>;
120
121 fn search(self, level: usize, neighborhood: &mut N) {
123 let item = self.item();
124 neighborhood.consider(item);
125
126 let target = neighborhood.target();
127
128 let bound = target.coord(level) - item.coord(level);
129 let (near, far) = if bound.is_negative() {
130 (self.left(), self.right())
131 } else {
132 (self.right(), self.left())
133 };
134
135 let next = (level + 1) % self.item().dims();
136
137 if let Some(near) = near {
138 near.search(next, neighborhood);
139 }
140
141 if let Some(far) = far {
142 if neighborhood.contains(bound.abs()) {
143 far.search(next, neighborhood);
144 }
145 }
146 }
147}
148
149impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a KdNode<V>
150where
151 K: KdProximity<&'a V>,
152 K::Value: PartialOrd<K::Distance>,
153 V: Coordinates,
154 N: Neighborhood<K, &'a V>,
155{
156 fn item(self) -> &'a V {
157 &self.item
158 }
159
160 fn left(self) -> Option<Self> {
161 self.left.as_deref()
162 }
163
164 fn right(self) -> Option<Self> {
165 self.right.as_deref()
166 }
167}
168
169#[derive(Debug)]
171pub struct KdTree<T> {
172 root: Option<KdNode<T>>,
173}
174
175impl<T: Coordinates> KdTree<T> {
176 pub fn new() -> Self {
178 Self { root: None }
179 }
180
181 pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
183 Self {
184 root: KdNode::balanced(items),
185 }
186 }
187
188 pub fn iter(&self) -> Iter<'_, T> {
190 self.into_iter()
191 }
192
193 pub fn balance(&mut self) {
195 let mut nodes = Vec::new();
196 if let Some(root) = self.root.take() {
197 nodes.push(Some(Box::new(root)));
198 }
199
200 let mut i = 0;
201 while i < nodes.len() {
202 let node = nodes[i].as_mut().unwrap();
203 let inside = node.left.take();
204 let outside = node.right.take();
205 if inside.is_some() {
206 nodes.push(inside);
207 }
208 if outside.is_some() {
209 nodes.push(outside);
210 }
211
212 i += 1;
213 }
214
215 self.root = KdNode::balanced_recursive(&mut nodes, 0)
216 .map(|node| *node);
217 }
218
219 pub fn push(&mut self, item: T) {
224 if let Some(root) = &mut self.root {
225 root.push(item, 0);
226 } else {
227 self.root = Some(KdNode::new(item));
228 }
229 }
230}
231
232impl<T: Coordinates> Default for KdTree<T> {
233 fn default() -> Self {
234 Self::new()
235 }
236}
237
238impl<T: Coordinates> Extend<T> for KdTree<T> {
239 fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) {
240 if self.root.is_some() {
241 for item in items {
242 self.push(item);
243 }
244 } else {
245 self.root = KdNode::balanced(items);
246 }
247 }
248}
249
250impl<T: Coordinates> FromIterator<T> for KdTree<T> {
251 fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
252 Self::balanced(items)
253 }
254}
255
256#[derive(Debug)]
258pub struct IntoIter<T> {
259 stack: Vec<KdNode<T>>,
260}
261
262impl<T> IntoIter<T> {
263 fn new(node: Option<KdNode<T>>) -> Self {
264 Self {
265 stack: node.into_iter().collect(),
266 }
267 }
268}
269
270impl<T> Iterator for IntoIter<T> {
271 type Item = T;
272
273 fn next(&mut self) -> Option<Self::Item> {
274 self.stack.pop().map(|node| {
275 if let Some(left) = node.left {
276 self.stack.push(*left);
277 }
278 if let Some(right) = node.right {
279 self.stack.push(*right);
280 }
281 node.item
282 })
283 }
284}
285
286impl<T> IntoIterator for KdTree<T> {
287 type Item = T;
288 type IntoIter = IntoIter<T>;
289
290 fn into_iter(self) -> Self::IntoIter {
291 IntoIter::new(self.root)
292 }
293}
294
295#[derive(Debug)]
297pub struct Iter<'a, T> {
298 stack: Vec<&'a KdNode<T>>,
299}
300
301impl<'a, T> Iter<'a, T> {
302 fn new(node: &'a Option<KdNode<T>>) -> Self {
303 Self {
304 stack: node.as_ref().into_iter().collect(),
305 }
306 }
307}
308
309impl<'a, T> Iterator for Iter<'a, T> {
310 type Item = &'a T;
311
312 fn next(&mut self) -> Option<Self::Item> {
313 self.stack.pop().map(|node| {
314 if let Some(left) = &node.left {
315 self.stack.push(left);
316 }
317 if let Some(right) = &node.right {
318 self.stack.push(right);
319 }
320 &node.item
321 })
322 }
323}
324
325impl<'a, T> IntoIterator for &'a KdTree<T> {
326 type Item = &'a T;
327 type IntoIter = Iter<'a, T>;
328
329 fn into_iter(self) -> Self::IntoIter {
330 Iter::new(&self.root)
331 }
332}
333
334impl<K, V> NearestNeighbors<K, V> for KdTree<V>
335where
336 K: KdProximity<V>,
337 K::Value: PartialOrd<K::Distance>,
338 V: Coordinates,
339{
340 fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
341 where
342 K: 'k,
343 V: 'v,
344 N: Neighborhood<&'k K, &'v V>,
345 {
346 if let Some(root) = &self.root {
347 root.search(0, &mut neighborhood);
348 }
349 neighborhood
350 }
351}
352
353impl<K, V> ExactNeighbors<K, V> for KdTree<V>
355where
356 K: KdProximity<V> + Minkowski<V>,
357 K::Value: PartialOrd<K::Distance>,
358 V: Coordinates,
359{}
360
361#[derive(Debug)]
363struct FlatKdNode<T> {
364 item: T,
366 left_len: usize,
368}
369
370impl<T: Coordinates> FlatKdNode<T> {
371 fn new(item: T) -> Self {
373 Self {
374 item,
375 left_len: 0,
376 }
377 }
378
379 fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> {
381 let mut nodes: Vec<_> = items
382 .into_iter()
383 .map(Self::new)
384 .collect();
385
386 Self::balance_recursive(&mut nodes, 0);
387
388 nodes
389 }
390
391 fn balance_recursive(nodes: &mut [Self], level: usize) {
393 if !nodes.is_empty() {
394 nodes.sort_unstable_by_key(|x| Ordered::new(x.item.coord(level)));
395
396 let mid = nodes.len() / 2;
397 nodes.swap(0, mid);
398
399 let (node, children) = nodes.split_first_mut().unwrap();
400 let (left, right) = children.split_at_mut(mid);
401 node.left_len = left.len();
402
403 let next = (level + 1) % node.item.dims();
404 Self::balance_recursive(left, next);
405 Self::balance_recursive(right, next);
406 }
407 }
408}
409
410impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a [FlatKdNode<V>]
411where
412 K: KdProximity<&'a V>,
413 K::Value: PartialOrd<K::Distance>,
414 V: Coordinates,
415 N: Neighborhood<K, &'a V>,
416{
417 fn item(self) -> &'a V {
418 &self[0].item
419 }
420
421 fn left(self) -> Option<Self> {
422 let end = self[0].left_len + 1;
423 if end > 1 {
424 Some(&self[1..end])
425 } else {
426 None
427 }
428 }
429
430 fn right(self) -> Option<Self> {
431 let start = self[0].left_len + 1;
432 if start < self.len() {
433 Some(&self[start..])
434 } else {
435 None
436 }
437 }
438}
439
440#[derive(Debug)]
447pub struct FlatKdTree<T> {
448 nodes: Vec<FlatKdNode<T>>,
449}
450
451impl<T: Coordinates> FlatKdTree<T> {
452 pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
454 Self {
455 nodes: FlatKdNode::balanced(items),
456 }
457 }
458
459 pub fn iter(&self) -> FlatIter<'_, T> {
461 self.into_iter()
462 }
463}
464
465impl<T: Coordinates> FromIterator<T> for FlatKdTree<T> {
466 fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
467 Self::balanced(items)
468 }
469}
470
471#[derive(Debug)]
473pub struct FlatIntoIter<T>(alloc::vec::IntoIter<FlatKdNode<T>>);
474
475impl<T> Iterator for FlatIntoIter<T> {
476 type Item = T;
477
478 fn next(&mut self) -> Option<Self::Item> {
479 self.0.next().map(|n| n.item)
480 }
481}
482
483impl<T> IntoIterator for FlatKdTree<T> {
484 type Item = T;
485 type IntoIter = FlatIntoIter<T>;
486
487 fn into_iter(self) -> Self::IntoIter {
488 FlatIntoIter(self.nodes.into_iter())
489 }
490}
491
492#[derive(Debug)]
494pub struct FlatIter<'a, T>(core::slice::Iter<'a, FlatKdNode<T>>);
495
496impl<'a, T> Iterator for FlatIter<'a, T> {
497 type Item = &'a T;
498
499 fn next(&mut self) -> Option<Self::Item> {
500 self.0.next().map(|n| &n.item)
501 }
502}
503
504impl<'a, T> IntoIterator for &'a FlatKdTree<T> {
505 type Item = &'a T;
506 type IntoIter = FlatIter<'a, T>;
507
508 fn into_iter(self) -> Self::IntoIter {
509 FlatIter(self.nodes.iter())
510 }
511}
512
513impl<K, V> NearestNeighbors<K, V> for FlatKdTree<V>
514where
515 K: KdProximity<V>,
516 K::Value: PartialOrd<K::Distance>,
517 V: Coordinates,
518{
519 fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
520 where
521 K: 'k,
522 V: 'v,
523 N: Neighborhood<&'k K, &'v V>,
524 {
525 if !self.nodes.is_empty() {
526 self.nodes.as_slice().search(0, &mut neighborhood);
527 }
528 neighborhood
529 }
530}
531
532impl<K, V> ExactNeighbors<K, V> for FlatKdTree<V>
534where
535 K: KdProximity<V> + Minkowski<V>,
536 K::Value: PartialOrd<K::Distance>,
537 V: Coordinates,
538{}
539
#[cfg(test)]
mod tests {
    use super::*;

    use crate::knn::tests::test_exact_neighbors;

    /// A tree built balanced from the start.
    #[test]
    fn test_kd_tree() {
        test_exact_neighbors(KdTree::from_iter);
    }

    /// A tree built one `push` at a time, without rebalancing.
    #[test]
    fn test_unbalanced_kd_tree() {
        test_exact_neighbors(|points| {
            let mut tree = KdTree::new();
            for point in points {
                tree.push(point);
            }
            tree
        });
    }

    /// A tree built by pushing and then rebalanced in place with `KdTree::balance`, which was
    /// otherwise untested.
    #[test]
    fn test_rebalanced_kd_tree() {
        test_exact_neighbors(|points| {
            let mut tree = KdTree::new();
            for point in points {
                tree.push(point);
            }
            tree.balance();
            tree
        });
    }

    /// The flat, array-backed layout.
    #[test]
    fn test_flat_kd_tree() {
        test_exact_neighbors(FlatKdTree::from_iter);
    }
}