kd_tree_rs/lib.rs
1//! # KdTree
2//!
3//! A simple k-d tree implementation in Rust.
4//!
5//! Data structure for efficiently finding points in a k-dimensional space.
6//!
//! This is a work-in-progress implementation of a KD Tree in Rust.
8//! Below is a list of features that are currently implemented and features that are planned to be implemented.
9//!
10//! * [x] Build Tree
11//! * [x] Find All Points Within A Radius
12//! * [x] Find Nearest Neighbor
13//! * [x] Insert New Point
14//! * [x] Find **N** Nearest Neighbors
15//! * [ ] Delete Point
16//! * [ ] Re-Balance Tree
17//! * [ ] Serialize Tree
18//! * [ ] Publish Crate
19//! * [ ] Add **K** dimensions **(Currently only 2D)**
20//! * [x] Add Examples
21//!
//! This was initially developed as a way to learn Rust and to provide a KD Tree for a boids simulation (the
//! simulation itself is still in progress). I plan to continue working on this project as I learn more about Rust
//! and as time allows.
24//!
25//! ## Usage
26//!
//! [`KdNode`] is the main data structure for the KD Tree. It is a generic enum whose nodes each store a
//! [`Point`], a struct that contains the x and y coordinates of a point in 2D space.
30//!
//! The type of the x and y coordinates can be any type that implements the [`KDT`] trait.
//! [`KDT`] is implemented automatically for every type that implements all of the following traits
//! (the arithmetic traits with `Output = Self`):
33//! [`PartialEq`](https://doc.rust-lang.org/std/cmp/trait.PartialEq.html),
34//! [`PartialOrd`](https://doc.rust-lang.org/std/cmp/trait.PartialOrd.html),
35//! [`Into<f64>`](https://doc.rust-lang.org/std/convert/trait.Into.html),
36//! [`Copy`](https://doc.rust-lang.org/std/marker/trait.Copy.html),
37//! [`Add`](https://doc.rust-lang.org/std/ops/trait.Add.html),
38//! [`Sub`](https://doc.rust-lang.org/std/ops/trait.Sub.html),
39//! [`Mul`](https://doc.rust-lang.org/std/ops/trait.Mul.html).
40//!
41//!
42//! ```rust
43//! extern crate kd_tree_rs;
44//!
45//! use kd_tree_rs::KdNode;
46//! use kd_tree_rs::KdNode::Empty;
47//! use kd_tree_rs::point::Point;
48//!
49//! fn main() {
50//! let mut node: KdNode<i32> = KdNode::new();
51//!
52//! node.insert(1, 1);
53//! node.insert(2, 2);
54//!
55//! assert_eq!(node.nearest_neighbor(Point{x: 1, y: 1}, 1.0), vec![Point{x: 1, y: 1}]);
56//! }
57//! ```
58//!
59//! ## References
60//!
61//! * [KD Tree](https://en.wikipedia.org/wiki/K-d_tree)
62//! * [KD Tree Visualization](https://www.cs.usfca.edu/~galles/visualization/KDTree.html)
63//! * [KD Tree Nearest Neighbor](https://www.cs.cmu.edu/~ckingsf/bioinfo-lectures/kdtrees.pdf)
64//! * [Proof for neighborhood computation in expected logarithmic time - Martin Skrodzki](https://arxiv.org/pdf/1903.04936.pdf)
65//! * [Introduction to a KD Tree](https://yasenh.github.io/post/kd-tree/)
66
67
68extern crate core;
69
70pub mod dim;
71pub mod point;
72mod tests;
73
74pub use crate::dim::Dim;
75pub use crate::point::Point;
76pub use crate::KdNode::{Empty, Node};
77use crate::point::distance;
78use std::cmp::Ordering;
79use std::ops::{Add, Mul, Sub};
80
81pub trait KDT: PartialEq + PartialOrd + Copy + Mul + Sub + Add + Into<f64> {}
82impl<T> KDT for T where
83 T: PartialEq
84 + PartialOrd
85 + Copy
86 + Mul<Output = T>
87 + Sub<Output = T>
88 + Add<Output = T>
89 + Into<f64>
90{
91}
92
93#[derive(Debug, PartialEq)]
94pub enum KdNode<T: KDT> {
95 Empty,
96 Node {
97 point: Point<T>,
98 dim: Dim,
99 left: Box<KdNode<T>>,
100 right: Box<KdNode<T>>,
101 },
102}
103
104impl<T: KDT + Mul<Output = T> + Sub<Output = T> + Add<Output = T> + std::fmt::Debug> KdNode<T> {
105 /// Create a new empty tree
106 pub fn new() -> Self {
107 Empty
108 }
109
110 /// Insert a new item into the tree
111 ///
112 /// This should used sparingly as it can unbalance the tree
113 /// and reduce performance. If there is a large change to the dataset
114 /// it is better to create a new tree. This effect has not been tested
115 /// though and could be totally fine in terms of performance for a large
116 /// number of inserts. A good rule of thumb may be if the tree size is
117 /// going to increase by more than 10% it may be better to create a new
118 /// tree.
119 pub fn insert(&mut self, x: T, y: T) -> &Self {
120 self.insert_point(Point{ x, y })
121 }
122
123 /// Insert a new item into the tree
124 ///
125 /// This is the same as `insert` but takes a `Point` instead of `x` and `y`
126 pub fn insert_point(&mut self, item: Point<T>) -> &Self {
127 self._insert(item, 0)
128 }
129
130 fn _insert(&mut self, item: Point<T>, depth: usize) -> &Self {
131 *self = match self {
132 Empty => Node {
133 point: item,
134 dim: Dim::from_depth(depth),
135 left: Box::new(Empty),
136 right: Box::new(Empty),
137 },
138 Node {
139 point, left, right, ..
140 } => {
141 let next_depth: usize = depth + 1;
142 if point.gt(&item, &Dim::from_depth(next_depth)) {
143 right._insert(item, next_depth);
144 } else {
145 left._insert(item, next_depth);
146 }
147 return self;
148 }
149 };
150
151 self
152 }
153
154 /// Find the nearest neighbors to the origin point
155 ///
156 /// This will return a vector of points that are within the radius of the origin point.
157 /// The radius is inclusive so if a point is exactly on the radius it will be included.
158 ///
159 pub fn nearest_neighbor<'a>(&self, origin: Point<T>, radius: f64) -> Vec<Point<T>> {
160 assert!(radius >= 0.0, "Radius must be positive");
161
162 let mut best_queue: Vec<(&KdNode<T>, f64)> = Vec::new();
163 let mut parent_queue: Vec<&KdNode<T>> = self.drill_down(origin);
164 let deepest: &KdNode<T> = parent_queue.get(0).unwrap();
165
166 deepest._nearest_neighbor(origin, radius, &mut best_queue, &mut parent_queue, None);
167
168 best_queue.retain(|(_, dist)| *dist <= radius);
169 return best_queue
170 .iter()
171 .map(|(node, _)| match node {
172 Node { point, .. } => point.clone(),
173 _ => panic!("Empty node in best queue"),
174 })
175 .collect();
176 }
177
178 /// Insert a new item into the tree
179 ///
180 /// This is the same as `insert` but takes a `Point` instead of `x` and `y`
181 pub fn nearest_neighbor_x_y<'a>(&self, x: T, y: T, radius: f64) -> Vec<Point<T>> {
182 self.nearest_neighbor(Point { x, y }, radius)
183 }
184
185 /// Find the nearest neighbors to the origin point
186 ///
187 /// This will return a vector of points that are within the radius of the origin point.
188 /// This is the same as `nearest_neighbor` but will only return the `max` number of points.
189 pub fn n_nearest_neighbor<'a>(&self, origin: Point<T>, max: usize) -> Vec<Point<T>> {
190 let mut best_queue: Vec<(&KdNode<T>, f64)> = Vec::new();
191 let mut parent_queue: Vec<&KdNode<T>> = self.drill_down(origin);
192 let deepest: &KdNode<T> = parent_queue.get(0).unwrap();
193
194 // TODO Should use just an option instead of `f64::MAX`.
195 deepest._nearest_neighbor(origin, f64::MAX, &mut best_queue, &mut parent_queue, Some(max));
196
197 return best_queue
198 .iter()
199 .map(|(node, _)| match node {
200 Node { point, .. } => point.clone(),
201 _ => panic!("Empty node in best queue"),
202 })
203 .collect();
204 }
205
206 /// Find the nearest neighbors to the origin point
207 ///
208 /// This is a recursive function that will recursively work its way up the tree
209 /// collecting all neighbours within the radius provided.
210 fn _nearest_neighbor<'a>(
211 &'a self,
212 origin: Point<T>,
213 radius: f64,
214 best_queue: &mut Vec<(&'a KdNode<T>, f64)>,
215 parent_queue: &mut Vec<&'a KdNode<T>>,
216 max: Option<usize>,
217 ) -> Vec<(&KdNode<T>, f64)> {
218 let parent = parent_queue.pop();
219 if parent.is_none() {
220 return best_queue.clone();
221 }
222
223 match parent.unwrap() {
224 Empty => {}
225 Node {
226 left,
227 right,
228 point,
229 ..
230 } => {
231 if let Some(max) = max {
232 if best_queue.len() >= max {
233 return vec![];
234 }
235 }
236
237 // Add node point if in range.
238 let dis = distance(&origin, point);
239 if dis <= radius {
240 KdNode::insert_sorted(best_queue, (parent.unwrap(), distance(&origin, point)));
241 }
242
243 for side_node in [left, right] {
244 if !best_queue.iter()
245 .find(|(a, _)| *a == side_node.as_ref())
246 .is_some()
247 {
248 // Check if the radius actually overlaps the node children.
249 match side_node.as_ref() {
250 Node { point, dim, .. } => {
251 if !point.in_radius(&origin, dim, radius) {
252 continue;
253 }
254 },
255 _ => {}
256 }
257
258 parent_queue.push(side_node.as_ref());
259 let temp =
260 side_node._nearest_neighbor(origin, radius, best_queue, parent_queue, max);
261 for (node, dist) in temp {
262 if dist <= radius {
263 if let Some(max) = max {
264 if best_queue.len() >= max {
265 return vec![];
266 }
267 }
268 KdNode::insert_sorted(best_queue, (node, dist));
269 }
270 }
271 }
272 }
273
274 parent.unwrap()._nearest_neighbor(origin, radius, best_queue, parent_queue, max);
275 }
276 }
277
278 best_queue.clone()
279 }
280
281 /// Drill down the tree to find appropriate node and return the parents.
282 fn drill_down(&self, origin: Point<T>) -> Vec<&KdNode<T>> {
283 let mut parents: Vec<&KdNode<T>> = Vec::new();
284 let mut best_node: &KdNode<T> = self;
285 while let Node {
286 point,
287 left,
288 right,
289 dim,
290 } = best_node
291 {
292 parents.push(best_node);
293 if *left == Box::new(Empty) && *right == Box::new(Empty) {
294 break;
295 }
296
297 match point.cmp(&origin, dim) {
298 Ordering::Less => best_node = left,
299 _ => best_node = right,
300 }
301 }
302 return parents;
303 }
304
305 /// Insert a point into a sorted list if it is not already in the list.
306 fn insert_sorted<'a>(
307 points: &mut Vec<(&'a KdNode<T>, f64)>,
308 point: (&'a KdNode<T>, f64),
309 ) -> () {
310 let mut index: usize = 0;
311 for (i, (node, dist)) in points.iter().enumerate() {
312 if *dist < point.1 {
313 index = i + 1;
314 }
315 if *node == point.0 && *node != &Empty {
316 return;
317 }
318 }
319 points.insert(index, point);
320 }
321
322 pub fn build(points: Vec<Point<T>>) -> Self {
323 KdNode::_build(points, 1)
324 }
325
326 fn _build(points: Vec<Point<T>>, depth: usize) -> Self {
327 // Increment the dimension
328 let next_depth: usize = depth + 1;
329
330 // End recursion if there are one or no points
331 if points.is_empty() {
332 return Empty;
333 } else if points.len() == 1 {
334 return Node {
335 point: points[0].clone(),
336 dim: Dim::from_depth(next_depth),
337 left: Box::new(Empty),
338 right: Box::new(Empty),
339 };
340 }
341
342 // Choose axis
343 let axis = Dim::from_depth(next_depth);
344
345 // Get Median
346 let (median, left, right): (Point<T>, Vec<Point<T>>, Vec<Point<T>>) =
347 KdNode::split_on_median(points, &axis);
348
349 Node {
350 point: median,
351 dim: axis,
352 left: Box::from(Self::_build(left, next_depth)),
353 right: Box::from(Self::_build(right, next_depth)),
354 }
355 }
356
357 /// Split the points into two vectors based on the median
358 ///
359 /// The median is chosen based on the axis and returned along with
360 /// two separate vectors of points, the left and right of the median.
361 fn split_on_median(
362 mut points: Vec<Point<T>>,
363 axis: &Dim,
364 ) -> (Point<T>, Vec<Point<T>>, Vec<Point<T>>) {
365 points.sort_by(|a: &Point<T>, b: &Point<T>| a.cmp(&b, &axis));
366 let median_index: usize = if points.len() % 2 == 0 {
367 points.len() / 2 - 1
368 } else {
369 points.len() / 2
370 };
371 let median: Point<T> = points[median_index];
372 let right: Vec<Point<T>> = points.drain(..median_index).collect();
373 let left: Vec<Point<T>> = points.drain(1..).collect();
374 (median, left, right)
375 }
376}