acap/
kd.rs

1//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree).
2
3use crate::coords::Coordinates;
4use crate::distance::Proximity;
5use crate::lp::Minkowski;
6use crate::knn::{ExactNeighbors, NearestNeighbors, Neighborhood};
7use crate::util::Ordered;
8
9use num_traits::Signed;
10
11use alloc::boxed::Box;
12use alloc::vec::Vec;
13
14/// A node in a k-d tree.
15#[derive(Debug)]
16struct KdNode<T> {
17    /// The item stored in this node.
18    item: T,
19    /// The left subtree, if any.
20    left: Option<Box<Self>>,
21    /// The right subtree, if any.
22    right: Option<Box<Self>>,
23}
24
25impl<T: Coordinates> KdNode<T> {
26    /// Create a new KdNode.
27    fn new(item: T) -> Self {
28        Self {
29            item,
30            left: None,
31            right: None,
32        }
33    }
34
35    /// Create a balanced tree.
36    fn balanced<I: IntoIterator<Item = T>>(items: I) -> Option<Self> {
37        let mut nodes: Vec<_> = items
38            .into_iter()
39            .map(Self::new)
40            .map(Box::new)
41            .map(Some)
42            .collect();
43
44        Self::balanced_recursive(&mut nodes, 0)
45            .map(|node| *node)
46    }
47
48    /// Create a balanced subtree.
49    fn balanced_recursive(nodes: &mut [Option<Box<Self>>], level: usize) -> Option<Box<Self>> {
50        if nodes.is_empty() {
51            return None;
52        }
53
54        nodes.sort_unstable_by_key(|x| Ordered::new(x.as_ref().unwrap().item.coord(level)));
55
56        let (left, right) = nodes.split_at_mut(nodes.len() / 2);
57        let (node, right) = right.split_first_mut().unwrap();
58        let mut node = node.take().unwrap();
59
60        let next = (level + 1) % node.item.dims();
61        node.left = Self::balanced_recursive(left, next);
62        node.right = Self::balanced_recursive(right, next);
63
64        Some(node)
65    }
66
67    /// Push a new item into this subtree.
68    fn push(&mut self, item: T, level: usize) {
69        let next = (level + 1) % item.dims();
70
71        if item.coord(level) <= self.item.coord(level) {
72            if let Some(left) = &mut self.left {
73                left.push(item, next);
74            } else {
75                self.left = Some(Box::new(Self::new(item)));
76            }
77        } else {
78            if let Some(right) = &mut self.right {
79                right.push(item, next);
80            } else {
81                self.right = Some(Box::new(Self::new(item)));
82            }
83        }
84    }
85}
86
87/// Marker trait for [`Proximity`] implementations that are compatible with k-d trees.
88pub trait KdProximity<V: ?Sized = Self>
89where
90    Self: Coordinates<Value = V::Value>,
91    Self: Proximity<V>,
92    Self::Value: PartialOrd<Self::Distance>,
93    V: Coordinates,
94{}
95
96/// Blanket [`KdProximity`] implementation.
97impl<K, V> KdProximity<V> for K
98where
99    K: Coordinates<Value = V::Value>,
100    K: Proximity<V>,
101    K::Value: PartialOrd<K::Distance>,
102    V: Coordinates,
103{}
104
105trait KdSearch<K, V, N>: Copy
106where
107    K: KdProximity<V>,
108    K::Value: PartialOrd<K::Distance>,
109    V: Coordinates + Copy,
110    N: Neighborhood<K, V>,
111{
112    /// Get this node's item.
113    fn item(self) -> V;
114
115    /// Get the left subtree.
116    fn left(self) -> Option<Self>;
117
118    /// Get the right subtree.
119    fn right(self) -> Option<Self>;
120
121    /// Recursively search for nearest neighbors.
122    fn search(self, level: usize, neighborhood: &mut N) {
123        let item = self.item();
124        neighborhood.consider(item);
125
126        let target = neighborhood.target();
127
128        let bound = target.coord(level) - item.coord(level);
129        let (near, far) = if bound.is_negative() {
130            (self.left(), self.right())
131        } else {
132            (self.right(), self.left())
133        };
134
135        let next = (level + 1) % self.item().dims();
136
137        if let Some(near) = near {
138            near.search(next, neighborhood);
139        }
140
141        if let Some(far) = far {
142            if neighborhood.contains(bound.abs()) {
143                far.search(next, neighborhood);
144            }
145        }
146    }
147}
148
149impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a KdNode<V>
150where
151    K: KdProximity<&'a V>,
152    K::Value: PartialOrd<K::Distance>,
153    V: Coordinates,
154    N: Neighborhood<K, &'a V>,
155{
156    fn item(self) -> &'a V {
157        &self.item
158    }
159
160    fn left(self) -> Option<Self> {
161        self.left.as_deref()
162    }
163
164    fn right(self) -> Option<Self> {
165        self.right.as_deref()
166    }
167}
168
169/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree).
170#[derive(Debug)]
171pub struct KdTree<T> {
172    root: Option<KdNode<T>>,
173}
174
175impl<T: Coordinates> KdTree<T> {
176    /// Create an empty tree.
177    pub fn new() -> Self {
178        Self { root: None }
179    }
180
181    /// Create a balanced tree out of a sequence of items.
182    pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
183        Self {
184            root: KdNode::balanced(items),
185        }
186    }
187
188    /// Iterate over the items stored in this tree.
189    pub fn iter(&self) -> Iter<'_, T> {
190        self.into_iter()
191    }
192
193    /// Rebalance this k-d tree.
194    pub fn balance(&mut self) {
195        let mut nodes = Vec::new();
196        if let Some(root) = self.root.take() {
197            nodes.push(Some(Box::new(root)));
198        }
199
200        let mut i = 0;
201        while i < nodes.len() {
202            let node = nodes[i].as_mut().unwrap();
203            let inside = node.left.take();
204            let outside = node.right.take();
205            if inside.is_some() {
206                nodes.push(inside);
207            }
208            if outside.is_some() {
209                nodes.push(outside);
210            }
211
212            i += 1;
213        }
214
215        self.root = KdNode::balanced_recursive(&mut nodes, 0)
216            .map(|node| *node);
217    }
218
219    /// Push a new item into the tree.
220    ///
221    /// Inserting elements individually tends to unbalance the tree.  Use [`KdTree::balanced()`] if
222    /// possible to create a balanced tree from a batch of items.
223    pub fn push(&mut self, item: T) {
224        if let Some(root) = &mut self.root {
225            root.push(item, 0);
226        } else {
227            self.root = Some(KdNode::new(item));
228        }
229    }
230}
231
232impl<T: Coordinates> Default for KdTree<T> {
233    fn default() -> Self {
234        Self::new()
235    }
236}
237
238impl<T: Coordinates> Extend<T> for KdTree<T> {
239    fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) {
240        if self.root.is_some() {
241            for item in items {
242                self.push(item);
243            }
244        } else {
245            self.root = KdNode::balanced(items);
246        }
247    }
248}
249
250impl<T: Coordinates> FromIterator<T> for KdTree<T> {
251    fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
252        Self::balanced(items)
253    }
254}
255
256/// An iterator that moves values out of a k-d tree.
257#[derive(Debug)]
258pub struct IntoIter<T> {
259    stack: Vec<KdNode<T>>,
260}
261
262impl<T> IntoIter<T> {
263    fn new(node: Option<KdNode<T>>) -> Self {
264        Self {
265            stack: node.into_iter().collect(),
266        }
267    }
268}
269
270impl<T> Iterator for IntoIter<T> {
271    type Item = T;
272
273    fn next(&mut self) -> Option<Self::Item> {
274        self.stack.pop().map(|node| {
275            if let Some(left) = node.left {
276                self.stack.push(*left);
277            }
278            if let Some(right) = node.right {
279                self.stack.push(*right);
280            }
281            node.item
282        })
283    }
284}
285
286impl<T> IntoIterator for KdTree<T> {
287    type Item = T;
288    type IntoIter = IntoIter<T>;
289
290    fn into_iter(self) -> Self::IntoIter {
291        IntoIter::new(self.root)
292    }
293}
294
295/// An iterator over the values in a k-d tree.
296#[derive(Debug)]
297pub struct Iter<'a, T> {
298    stack: Vec<&'a KdNode<T>>,
299}
300
301impl<'a, T> Iter<'a, T> {
302    fn new(node: &'a Option<KdNode<T>>) -> Self {
303        Self {
304            stack: node.as_ref().into_iter().collect(),
305        }
306    }
307}
308
309impl<'a, T> Iterator for Iter<'a, T> {
310    type Item = &'a T;
311
312    fn next(&mut self) -> Option<Self::Item> {
313        self.stack.pop().map(|node| {
314            if let Some(left) = &node.left {
315                self.stack.push(left);
316            }
317            if let Some(right) = &node.right {
318                self.stack.push(right);
319            }
320            &node.item
321        })
322    }
323}
324
325impl<'a, T> IntoIterator for &'a KdTree<T> {
326    type Item = &'a T;
327    type IntoIter = Iter<'a, T>;
328
329    fn into_iter(self) -> Self::IntoIter {
330        Iter::new(&self.root)
331    }
332}
333
334impl<K, V> NearestNeighbors<K, V> for KdTree<V>
335where
336    K: KdProximity<V>,
337    K::Value: PartialOrd<K::Distance>,
338    V: Coordinates,
339{
340    fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
341    where
342        K: 'k,
343        V: 'v,
344        N: Neighborhood<&'k K, &'v V>,
345    {
346        if let Some(root) = &self.root {
347            root.search(0, &mut neighborhood);
348        }
349        neighborhood
350    }
351}
352
353/// k-d trees are exact for [Minkowski] distances.
354impl<K, V> ExactNeighbors<K, V> for KdTree<V>
355where
356    K: KdProximity<V> + Minkowski<V>,
357    K::Value: PartialOrd<K::Distance>,
358    V: Coordinates,
359{}
360
361/// A node in a flat k-d tree.
362#[derive(Debug)]
363struct FlatKdNode<T> {
364    /// The item stored in this node.
365    item: T,
366    /// The size of the left subtree.
367    left_len: usize,
368}
369
370impl<T: Coordinates> FlatKdNode<T> {
371    /// Create a new FlatKdNode.
372    fn new(item: T) -> Self {
373        Self {
374            item,
375            left_len: 0,
376        }
377    }
378
379    /// Create a balanced tree.
380    fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> {
381        let mut nodes: Vec<_> = items
382            .into_iter()
383            .map(Self::new)
384            .collect();
385
386        Self::balance_recursive(&mut nodes, 0);
387
388        nodes
389    }
390
391    /// Create a balanced subtree.
392    fn balance_recursive(nodes: &mut [Self], level: usize) {
393        if !nodes.is_empty() {
394            nodes.sort_unstable_by_key(|x| Ordered::new(x.item.coord(level)));
395
396            let mid = nodes.len() / 2;
397            nodes.swap(0, mid);
398
399            let (node, children) = nodes.split_first_mut().unwrap();
400            let (left, right) = children.split_at_mut(mid);
401            node.left_len = left.len();
402
403            let next = (level + 1) % node.item.dims();
404            Self::balance_recursive(left, next);
405            Self::balance_recursive(right, next);
406        }
407    }
408}
409
410impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a [FlatKdNode<V>]
411where
412    K: KdProximity<&'a V>,
413    K::Value: PartialOrd<K::Distance>,
414    V: Coordinates,
415    N: Neighborhood<K, &'a V>,
416{
417    fn item(self) -> &'a V {
418        &self[0].item
419    }
420
421    fn left(self) -> Option<Self> {
422        let end = self[0].left_len + 1;
423        if end > 1 {
424            Some(&self[1..end])
425        } else {
426            None
427        }
428    }
429
430    fn right(self) -> Option<Self> {
431        let start = self[0].left_len + 1;
432        if start < self.len() {
433            Some(&self[start..])
434        } else {
435            None
436        }
437    }
438}
439
440/// A [k-d tree] stored as a flat array.
441///
442/// A FlatKdTree is always balanced and usually more efficient than a [`KdTree`], but doesn't
443/// support dynamic updates.
444///
445/// [k-d tree]: https://en.wikipedia.org/wiki/K-d_tree
446#[derive(Debug)]
447pub struct FlatKdTree<T> {
448    nodes: Vec<FlatKdNode<T>>,
449}
450
451impl<T: Coordinates> FlatKdTree<T> {
452    /// Create a balanced tree out of a sequence of items.
453    pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
454        Self {
455            nodes: FlatKdNode::balanced(items),
456        }
457    }
458
459    /// Iterate over the items stored in this tree.
460    pub fn iter(&self) -> FlatIter<'_, T> {
461        self.into_iter()
462    }
463}
464
465impl<T: Coordinates> FromIterator<T> for FlatKdTree<T> {
466    fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
467        Self::balanced(items)
468    }
469}
470
471/// An iterator that moves values out of a flat k-d tree.
472#[derive(Debug)]
473pub struct FlatIntoIter<T>(alloc::vec::IntoIter<FlatKdNode<T>>);
474
475impl<T> Iterator for FlatIntoIter<T> {
476    type Item = T;
477
478    fn next(&mut self) -> Option<Self::Item> {
479        self.0.next().map(|n| n.item)
480    }
481}
482
483impl<T> IntoIterator for FlatKdTree<T> {
484    type Item = T;
485    type IntoIter = FlatIntoIter<T>;
486
487    fn into_iter(self) -> Self::IntoIter {
488        FlatIntoIter(self.nodes.into_iter())
489    }
490}
491
492/// An iterator over the values in a flat k-d tree.
493#[derive(Debug)]
494pub struct FlatIter<'a, T>(core::slice::Iter<'a, FlatKdNode<T>>);
495
496impl<'a, T> Iterator for FlatIter<'a, T> {
497    type Item = &'a T;
498
499    fn next(&mut self) -> Option<Self::Item> {
500        self.0.next().map(|n| &n.item)
501    }
502}
503
504impl<'a, T> IntoIterator for &'a FlatKdTree<T> {
505    type Item = &'a T;
506    type IntoIter = FlatIter<'a, T>;
507
508    fn into_iter(self) -> Self::IntoIter {
509        FlatIter(self.nodes.iter())
510    }
511}
512
513impl<K, V> NearestNeighbors<K, V> for FlatKdTree<V>
514where
515    K: KdProximity<V>,
516    K::Value: PartialOrd<K::Distance>,
517    V: Coordinates,
518{
519    fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
520    where
521        K: 'k,
522        V: 'v,
523        N: Neighborhood<&'k K, &'v V>,
524    {
525        if !self.nodes.is_empty() {
526            self.nodes.as_slice().search(0, &mut neighborhood);
527        }
528        neighborhood
529    }
530}
531
532/// k-d trees are exact for [Minkowski] distances.
533impl<K, V> ExactNeighbors<K, V> for FlatKdTree<V>
534where
535    K: KdProximity<V> + Minkowski<V>,
536    K::Value: PartialOrd<K::Distance>,
537    V: Coordinates,
538{}
539
#[cfg(test)]
mod tests {
    use super::*;

    use crate::knn::tests::test_exact_neighbors;

    /// A pointer-based tree built in one batch through `FromIterator`.
    #[test]
    fn test_kd_tree() {
        test_exact_neighbors(KdTree::from_iter);
    }

    /// A pointer-based tree grown one `push()` at a time, so it is never rebalanced.
    #[test]
    fn test_unbalanced_kd_tree() {
        test_exact_neighbors(|points| {
            let mut unbalanced = KdTree::new();
            for p in points {
                unbalanced.push(p);
            }
            unbalanced
        });
    }

    /// A flat tree built in one batch through `FromIterator`.
    #[test]
    fn test_flat_kd_tree() {
        test_exact_neighbors(FlatKdTree::from_iter);
    }
}
566}