kd_tree_rs/
lib.rs

1//! # KdTree
2//!
3//! A simple k-d tree implementation in Rust.
4//!
5//! Data structure for efficiently finding points in a k-dimensional space.
6//!
//! This implementation of a KD Tree in Rust is still under development.
8//! Below is a list of features that are currently implemented and features that are planned to be implemented.
9//!
10//! * [x] Build Tree
11//! * [x] Find All Points Within A Radius
12//! * [x] Find Nearest Neighbor
13//! * [x] Insert New Point
14//! * [x] Find **N** Nearest Neighbors
15//! * [ ] Delete Point
16//! * [ ] Re-Balance Tree
17//! * [ ] Serialize Tree
18//! * [ ] Publish Crate
19//! * [ ] Add **K** dimensions **(Currently only 2D)**
20//! * [x] Add Examples
21//!
22//! This was developed initially as a way to learn Rust and to implement a KD Tree for a boids simulation although the
23//! simulation is in progress. I plan to continue to work on this project as I learn more about Rust and as I have time.
24//!
25//! ## Usage
26//!
//! [`KdNode`](enum.KdNode.html) is the main data structure for the KD Tree. It is a generic enum
//! parameterized by the coordinate type of the points it stores.
//!
//! [`Point`](point/struct.Point.html) is a struct that contains the x and y coordinates of a point in 2D space.
30//!
31//! The type of the x and y coordinates can be any type that can implement the [`KDT`](trait.KDT.html) trait.
32//! This trait is implemented for all types that implement the following traits:
33//! [`PartialEq`](https://doc.rust-lang.org/std/cmp/trait.PartialEq.html),
34//! [`PartialOrd`](https://doc.rust-lang.org/std/cmp/trait.PartialOrd.html),
35//! [`Into<f64>`](https://doc.rust-lang.org/std/convert/trait.Into.html),
36//! [`Copy`](https://doc.rust-lang.org/std/marker/trait.Copy.html),
37//! [`Add`](https://doc.rust-lang.org/std/ops/trait.Add.html),
38//! [`Sub`](https://doc.rust-lang.org/std/ops/trait.Sub.html),
39//! [`Mul`](https://doc.rust-lang.org/std/ops/trait.Mul.html).
40//!
41//!
42//! ```rust
43//! extern crate kd_tree_rs;
44//!
45//! use kd_tree_rs::KdNode;
46//! use kd_tree_rs::KdNode::Empty;
47//! use kd_tree_rs::point::Point;
48//!
49//! fn main() {
50//!    let mut node: KdNode<i32> = KdNode::new();
51//!
52//!    node.insert(1, 1);
53//!    node.insert(2, 2);
54//!
55//!    assert_eq!(node.nearest_neighbor(Point{x: 1, y: 1}, 1.0), vec![Point{x: 1, y: 1}]);
56//! }
57//! ```
58//!
59//! ## References
60//!
61//! * [KD Tree](https://en.wikipedia.org/wiki/K-d_tree)
62//! * [KD Tree Visualization](https://www.cs.usfca.edu/~galles/visualization/KDTree.html)
63//! * [KD Tree Nearest Neighbor](https://www.cs.cmu.edu/~ckingsf/bioinfo-lectures/kdtrees.pdf)
64//! * [Proof for neighborhood computation in expected logarithmic time - Martin Skrodzki](https://arxiv.org/pdf/1903.04936.pdf)
65//! * [Introduction to a KD Tree](https://yasenh.github.io/post/kd-tree/)
66
67
68extern crate core;
69
70pub mod dim;
71pub mod point;
72mod tests;
73
74pub use crate::dim::Dim;
75pub use crate::point::Point;
76pub use crate::KdNode::{Empty, Node};
77use crate::point::distance;
78use std::cmp::Ordering;
79use std::ops::{Add, Mul, Sub};
80
/// Coordinate type usable inside a k-d tree.
///
/// A `KDT` value can be compared, copied, combined with `+`, `-` and `*`
/// (each yielding the same type again) and converted into an `f64`.
///
/// The arithmetic supertraits pin `Output = Self` so that a plain `T: KDT`
/// bound is enough to do arithmetic on coordinates; this matches exactly the
/// set of types covered by the blanket implementation below.
pub trait KDT:
    PartialEq
    + PartialOrd
    + Copy
    + Mul<Output = Self>
    + Sub<Output = Self>
    + Add<Output = Self>
    + Into<f64>
{
}

/// Blanket implementation: every type that satisfies the bounds is a `KDT`.
impl<T> KDT for T where
    T: PartialEq
        + PartialOrd
        + Copy
        + Mul<Output = T>
        + Sub<Output = T>
        + Add<Output = T>
        + Into<f64>
{
}
92
93#[derive(Debug, PartialEq)]
94pub enum KdNode<T: KDT> {
95    Empty,
96    Node {
97        point: Point<T>,
98        dim: Dim,
99        left: Box<KdNode<T>>,
100        right: Box<KdNode<T>>,
101    },
102}
103
104impl<T: KDT + Mul<Output = T> + Sub<Output = T> + Add<Output = T> + std::fmt::Debug> KdNode<T> {
105    /// Create a new empty tree
106    pub fn new() -> Self {
107        Empty
108    }
109
110    /// Insert a new item into the tree
111    ///
112    /// This should used sparingly as it can unbalance the tree
113    /// and reduce performance. If there is a large change to the dataset
114    /// it is better to create a new tree. This effect has not been tested
115    /// though and could be totally fine in terms of performance for a large
116    /// number of inserts. A good rule of thumb may be if the tree size is
117    /// going to increase by more than 10% it may be better to create a new
118    /// tree.
119    pub fn insert(&mut self, x: T, y: T) -> &Self {
120        self.insert_point(Point{ x, y })
121    }
122
123    /// Insert a new item into the tree
124    ///
125    /// This is the same as `insert` but takes a `Point` instead of `x` and `y`
126    pub fn insert_point(&mut self, item: Point<T>) -> &Self {
127        self._insert(item, 0)
128    }
129
130    fn _insert(&mut self, item: Point<T>, depth: usize) -> &Self {
131        *self = match self {
132            Empty => Node {
133                point: item,
134                dim: Dim::from_depth(depth),
135                left: Box::new(Empty),
136                right: Box::new(Empty),
137            },
138            Node {
139                point, left, right, ..
140            } => {
141                let next_depth: usize = depth + 1;
142                if point.gt(&item, &Dim::from_depth(next_depth)) {
143                    right._insert(item, next_depth);
144                } else {
145                    left._insert(item, next_depth);
146                }
147                return self;
148            }
149        };
150
151        self
152    }
153
154    /// Find the nearest neighbors to the origin point
155    ///
156    /// This will return a vector of points that are within the radius of the origin point.
157    /// The radius is inclusive so if a point is exactly on the radius it will be included.
158    ///
159    pub fn nearest_neighbor<'a>(&self, origin: Point<T>, radius: f64) -> Vec<Point<T>> {
160        assert!(radius >= 0.0, "Radius must be positive");
161
162        let mut best_queue: Vec<(&KdNode<T>, f64)> = Vec::new();
163        let mut parent_queue: Vec<&KdNode<T>> = self.drill_down(origin);
164        let deepest: &KdNode<T> = parent_queue.get(0).unwrap();
165
166        deepest._nearest_neighbor(origin, radius, &mut best_queue, &mut parent_queue, None);
167
168        best_queue.retain(|(_, dist)| *dist <= radius);
169        return best_queue
170            .iter()
171            .map(|(node, _)| match node {
172                Node { point, .. } => point.clone(),
173                _ => panic!("Empty node in best queue"),
174            })
175            .collect();
176    }
177
178    /// Insert a new item into the tree
179    ///
180    /// This is the same as `insert` but takes a `Point` instead of `x` and `y`
181    pub fn nearest_neighbor_x_y<'a>(&self, x: T, y: T, radius: f64) -> Vec<Point<T>> {
182        self.nearest_neighbor(Point { x, y }, radius)
183    }
184
185    /// Find the nearest neighbors to the origin point
186    ///
187    /// This will return a vector of points that are within the radius of the origin point.
188    /// This is the same as `nearest_neighbor` but will only return the `max` number of points.
189    pub fn n_nearest_neighbor<'a>(&self, origin: Point<T>, max: usize) -> Vec<Point<T>> {
190        let mut best_queue: Vec<(&KdNode<T>, f64)> = Vec::new();
191        let mut parent_queue: Vec<&KdNode<T>> = self.drill_down(origin);
192        let deepest: &KdNode<T> = parent_queue.get(0).unwrap();
193
194        // TODO Should use just an option instead of `f64::MAX`.
195        deepest._nearest_neighbor(origin, f64::MAX, &mut best_queue, &mut parent_queue, Some(max));
196
197        return best_queue
198            .iter()
199            .map(|(node, _)| match node {
200                Node { point, .. } => point.clone(),
201                _ => panic!("Empty node in best queue"),
202            })
203            .collect();
204    }
205
206    /// Find the nearest neighbors to the origin point
207    ///
208    /// This is a recursive function that will recursively work its way up the tree
209    /// collecting all neighbours within the radius provided.
210    fn _nearest_neighbor<'a>(
211        &'a self,
212        origin: Point<T>,
213        radius: f64,
214        best_queue: &mut Vec<(&'a KdNode<T>, f64)>,
215        parent_queue: &mut Vec<&'a KdNode<T>>,
216        max: Option<usize>,
217    ) -> Vec<(&KdNode<T>, f64)> {
218        let parent = parent_queue.pop();
219        if parent.is_none() {
220            return best_queue.clone();
221        }
222
223        match parent.unwrap() {
224            Empty => {}
225            Node {
226                left,
227                right,
228                point,
229                ..
230            } => {
231                if let Some(max) = max {
232                    if best_queue.len() >= max {
233                        return vec![];
234                    }
235                }
236
237                // Add node point if in range.
238                let dis = distance(&origin, point);
239                if dis <= radius {
240                    KdNode::insert_sorted(best_queue, (parent.unwrap(), distance(&origin, point)));
241                }
242
243                for side_node in [left, right] {
244                    if !best_queue.iter()
245                        .find(|(a, _)| *a == side_node.as_ref())
246                        .is_some()
247                    {
248                        // Check if the radius actually overlaps the node children.
249                        match side_node.as_ref() {
250                            Node { point, dim, .. } => {
251                                if !point.in_radius(&origin, dim, radius) {
252                                    continue;
253                                }
254                            },
255                            _ => {}
256                        }
257
258                        parent_queue.push(side_node.as_ref());
259                        let temp =
260                            side_node._nearest_neighbor(origin, radius, best_queue, parent_queue, max);
261                        for (node, dist) in temp {
262                            if dist <= radius {
263                                if let Some(max) = max {
264                                    if best_queue.len() >= max {
265                                        return vec![];
266                                    }
267                                }
268                                KdNode::insert_sorted(best_queue, (node, dist));
269                            }
270                        }
271                    }
272                }
273
274                parent.unwrap()._nearest_neighbor(origin, radius, best_queue, parent_queue, max);
275            }
276        }
277
278        best_queue.clone()
279    }
280
281    /// Drill down the tree to find appropriate node and return the parents.
282    fn drill_down(&self, origin: Point<T>) -> Vec<&KdNode<T>> {
283        let mut parents: Vec<&KdNode<T>> = Vec::new();
284        let mut best_node: &KdNode<T> = self;
285        while let Node {
286            point,
287            left,
288            right,
289            dim,
290        } = best_node
291        {
292            parents.push(best_node);
293            if *left == Box::new(Empty) && *right == Box::new(Empty) {
294                break;
295            }
296
297            match point.cmp(&origin, dim) {
298                Ordering::Less => best_node = left,
299                _ => best_node = right,
300            }
301        }
302        return parents;
303    }
304
305    /// Insert a point into a sorted list if it is not already in the list.
306    fn insert_sorted<'a>(
307        points: &mut Vec<(&'a KdNode<T>, f64)>,
308        point: (&'a KdNode<T>, f64),
309    ) -> () {
310        let mut index: usize = 0;
311        for (i, (node, dist)) in points.iter().enumerate() {
312            if *dist < point.1 {
313                index = i + 1;
314            }
315            if *node == point.0 && *node != &Empty {
316                return;
317            }
318        }
319        points.insert(index, point);
320    }
321
322    pub fn build(points: Vec<Point<T>>) -> Self {
323        KdNode::_build(points, 1)
324    }
325
326    fn _build(points: Vec<Point<T>>, depth: usize) -> Self {
327        // Increment the dimension
328        let next_depth: usize = depth + 1;
329
330        // End recursion if there are one or no points
331        if points.is_empty() {
332            return Empty;
333        } else if points.len() == 1 {
334            return Node {
335                point: points[0].clone(),
336                dim: Dim::from_depth(next_depth),
337                left: Box::new(Empty),
338                right: Box::new(Empty),
339            };
340        }
341
342        // Choose axis
343        let axis = Dim::from_depth(next_depth);
344
345        // Get Median
346        let (median, left, right): (Point<T>, Vec<Point<T>>, Vec<Point<T>>) =
347            KdNode::split_on_median(points, &axis);
348
349        Node {
350            point: median,
351            dim: axis,
352            left: Box::from(Self::_build(left, next_depth)),
353            right: Box::from(Self::_build(right, next_depth)),
354        }
355    }
356
357    /// Split the points into two vectors based on the median
358    ///
359    /// The median is chosen based on the axis and returned along with
360    /// two separate vectors of points, the left and right of the median.
361    fn split_on_median(
362        mut points: Vec<Point<T>>,
363        axis: &Dim,
364    ) -> (Point<T>, Vec<Point<T>>, Vec<Point<T>>) {
365        points.sort_by(|a: &Point<T>, b: &Point<T>| a.cmp(&b, &axis));
366        let median_index: usize = if points.len() % 2 == 0 {
367            points.len() / 2 - 1
368        } else {
369            points.len() / 2
370        };
371        let median: Point<T> = points[median_index];
372        let right: Vec<Point<T>> = points.drain(..median_index).collect();
373        let left: Vec<Point<T>> = points.drain(1..).collect();
374        (median, left, right)
375    }
376}