1#[cfg(not(feature = "vector-traits"))]
28use approx::UlpsEq;
29use std::fmt::Display;
30use std::{
31 fmt,
32 fmt::Debug,
33 ops::{AddAssign, DivAssign, MulAssign, SubAssign},
34};
35#[cfg(feature = "vector-traits")]
36use vector_traits::approx::UlpsEq;
37
38#[cfg(not(feature = "vector-traits"))]
39use num_traits::{FromPrimitive, Zero, real::Real};
40#[cfg(feature = "vector-traits")]
41use vector_traits::num_traits::{FromPrimitive, Zero, real::Real};
42
43mod impls;
44
45#[cfg(test)]
46mod tests;
47
48#[derive(thiserror::Error, Debug)]
49pub enum KrakelError {
50 #[error("Unknown error: {0}")]
51 InternalError(String),
52}
53
54pub trait PointTrait: Clone + PartialEq
55where
56 Self::PScalar: Real
57 + FromPrimitive
58 + UlpsEq
59 + Debug
60 + Display
61 + PartialEq
62 + MulAssign
63 + SubAssign
64 + DivAssign
65 + AddAssign,
66{
67 type PScalar;
68 fn x(&self) -> Self::PScalar;
69 fn y(&self) -> Self::PScalar;
70 fn set_x(&mut self, x: Self::PScalar);
71 fn set_y(&mut self, y: Self::PScalar);
72
73 #[inline(always)]
75 fn dist_sq<Q: PointTrait<PScalar = Self::PScalar>>(a: &Self, b: &Q) -> Self::PScalar {
76 let dx: Self::PScalar = a.x() - b.x();
77 let dy: Self::PScalar = a.y() - b.y();
78 dx * dx + dy * dy
79 }
80
81 fn at(&self, index: u8) -> Self::PScalar;
82 fn at_mut(&mut self, index: u8) -> &mut Self::PScalar;
83 const DIMENSION: u8;
84}
85
86pub trait KDPoint<P: PointTrait> {
87 fn get_coordinate(&self, index: usize) -> P::PScalar;
88 fn set_coordinate(&mut self, index: usize, value: P::PScalar);
89}
90
91pub struct KDNode<P: PointTrait> {
92 pos: P,
93 dir: u8,
94 left: Option<Box<KDNode<P>>>,
95 right: Option<Box<KDNode<P>>>,
96}
97
98#[derive(Clone)]
99struct HyperRectangle<P: PointTrait> {
100 min: P,
101 max: P,
102}
103
104pub struct KDTree<P: PointTrait> {
105 root: Option<Box<KDNode<P>>>,
106 rect: Option<HyperRectangle<P>>,
107}
108
109impl<P: PointTrait> KDNode<P> {
110 fn recursive_insert(
111 node: &mut Option<Box<KDNode<P>>>,
112 pos: P,
113 dir: u8,
114 dim: u8,
115 ) -> Result<(), KrakelError> {
116 match node {
117 None => {
118 *node = Some(Box::new(KDNode {
119 pos,
120 dir,
121 left: None,
122 right: None,
123 }));
124 }
125 Some(current) => {
126 let new_dir = (current.dir + 1) % dim;
127 if pos.at(current.dir) < current.pos.at(current.dir) {
128 Self::recursive_insert(&mut current.left, pos, new_dir, dim)?;
129 } else {
130 Self::recursive_insert(&mut current.right, pos, new_dir, dim)?;
131 }
132 }
133 }
134 Ok(())
135 }
136
137 fn recursive_nearest<'a>(
138 &'a self,
139 pos: &P,
140 result: &mut Option<&'a P>,
141 result_dist_sq: &mut P::PScalar,
142 rect: &mut HyperRectangle<P>,
143 ) {
144 let dir = self.dir;
145
146 let (nearer_subtree, farther_subtree) = if pos.at(dir) <= self.pos.at(dir) {
147 (&self.left, &self.right)
148 } else {
149 (&self.right, &self.left)
150 };
151
152 let old_value = if pos.at(dir) <= self.pos.at(dir) {
153 std::mem::replace(&mut rect.max.at(dir), self.pos.at(dir))
154 } else {
155 std::mem::replace(&mut rect.min.at(dir), self.pos.at(dir))
156 };
157
158 if let Some(nearer_node) = nearer_subtree {
159 nearer_node.recursive_nearest(pos, result, result_dist_sq, rect);
160 }
161
162 if pos.at(dir) <= self.pos.at(dir) {
163 *rect.max.at_mut(dir) = old_value;
164 } else {
165 *rect.min.at_mut(dir) = old_value;
166 }
167
168 let dist_sq = PointTrait::dist_sq(&self.pos, pos);
169 if dist_sq < *result_dist_sq {
170 *result_dist_sq = dist_sq;
171 *result = Some(&self.pos);
172 }
173
174 if let Some(farther_node) = farther_subtree {
175 if KDTree::hyper_rect_dist_sq(rect, pos) < *result_dist_sq {
176 farther_node.recursive_nearest(pos, result, result_dist_sq, rect);
177 }
178 }
179 }
180
181 fn recursive_range_query<Q: PointTrait<PScalar = P::PScalar>>(
182 &self,
183 pos: &Q,
184 radius_sq: P::PScalar,
185 results: &mut Vec<P>,
186 rect: &mut HyperRectangle<P>,
187 ) {
188 let dir = self.dir;
189
190 let (nearer_subtree, farther_subtree) = if pos.at(dir) <= self.pos.at(dir) {
191 (&self.left, &self.right)
192 } else {
193 (&self.right, &self.left)
194 };
195
196 let old_value = if pos.at(dir) <= self.pos.at(dir) {
197 std::mem::replace(&mut rect.max.at(dir), self.pos.at(dir))
198 } else {
199 std::mem::replace(&mut rect.min.at(dir), self.pos.at(dir))
200 };
201
202 if let Some(nearer_node) = nearer_subtree {
203 nearer_node.recursive_range_query(pos, radius_sq, results, rect);
204 }
205
206 if pos.at(dir) <= self.pos.at(dir) {
207 *rect.max.at_mut(dir) = old_value;
208 } else {
209 *rect.min.at_mut(dir) = old_value;
210 }
211
212 let dist_sq = PointTrait::dist_sq(&self.pos, pos);
213 if dist_sq <= radius_sq {
214 results.push(self.pos.clone());
215 }
216
217 if let Some(farther_node) = farther_subtree {
218 if KDTree::hyper_rect_dist_sq(rect, pos) <= radius_sq {
219 farther_node.recursive_range_query(pos, radius_sq, results, rect);
220 }
221 }
222 }
223
224 fn recursive_closure_range_query<Q: PointTrait<PScalar = P::PScalar>, F>(
225 &self,
226 pos: &Q,
227 radius_sq: P::PScalar,
228 rect: &mut HyperRectangle<P>,
229 process: &mut F,
230 ) where
231 F: FnMut(&P),
232 {
233 let dir = self.dir;
234
235 let (nearer_subtree, farther_subtree) = if pos.at(dir) <= self.pos.at(dir) {
236 (&self.left, &self.right)
237 } else {
238 (&self.right, &self.left)
239 };
240
241 let old_value = if pos.at(dir) <= self.pos.at(dir) {
242 std::mem::replace(&mut rect.max.at(dir), self.pos.at(dir))
243 } else {
244 std::mem::replace(&mut rect.min.at(dir), self.pos.at(dir))
245 };
246
247 if let Some(nearer_node) = nearer_subtree {
248 nearer_node.recursive_closure_range_query(pos, radius_sq, rect, process);
249 }
250
251 if pos.at(dir) <= self.pos.at(dir) {
252 *rect.max.at_mut(dir) = old_value;
253 } else {
254 *rect.min.at_mut(dir) = old_value;
255 }
256
257 if PointTrait::dist_sq(&self.pos, pos) <= radius_sq {
258 process(&self.pos);
259 }
260
261 if let Some(farther_node) = farther_subtree {
262 if KDTree::hyper_rect_dist_sq(rect, pos) <= radius_sq {
263 farther_node.recursive_closure_range_query(pos, radius_sq, rect, process);
264 }
265 }
266 }
267
268 fn format_node(&self, f: &mut fmt::Formatter<'_>, depth: usize) -> fmt::Result {
269 for _ in 0..depth {
270 write!(f, " ")?;
271 }
272
273 write!(f, "d={} node at ", self.dir)?;
274 for i in 0..P::DIMENSION {
275 write!(f, "{} ", self.pos.at(i))?;
276 }
277 writeln!(f)?;
278
279 if let Some(ref left_node) = self.left {
280 left_node.format_node(f, depth + 1)?;
281 }
282
283 if let Some(ref right_node) = self.right {
284 right_node.format_node(f, depth + 1)?;
285 }
286
287 Ok(())
288 }
289}
290
291impl<P: PointTrait> KDTree<P> {
292 pub fn insert(&mut self, pos: P) -> Result<(), KrakelError> {
293 KDNode::recursive_insert(&mut self.root, pos.clone(), 0, P::DIMENSION)?;
294
295 if self.rect.is_none() {
296 self.rect = Some(HyperRectangle {
297 min: pos.clone(),
298 max: pos.clone(),
299 });
300 } else {
301 for i in 0..P::DIMENSION {
302 if pos.at(i) < self.rect.as_mut().unwrap().min.at(i) {
303 *self.rect.as_mut().unwrap().min.at_mut(i) = pos.at(i);
304 } else if pos.at(i) > self.rect.as_mut().unwrap().max.at(i) {
305 *self.rect.as_mut().unwrap().max.at_mut(i) = pos.at(i);
306 }
307 }
308 }
309 Ok(())
310 }
311
312 #[allow(dead_code)]
313 pub fn nearest(&self, pos: &P) -> Option<P> {
314 if let Some(root_node) = &self.root {
315 let mut rect = self.rect.clone().unwrap();
317 let mut result: Option<&P> = self.root.as_ref().map(|node| &node.pos);
318 let mut result_dist_sq = P::dist_sq(result.as_ref().unwrap(), pos);
319
320 root_node.recursive_nearest(pos, &mut result, &mut result_dist_sq, &mut rect);
321 result.cloned()
322 } else {
323 None
324 }
325 }
326
327 #[allow(dead_code)]
328 pub fn range_query<Q: PointTrait<PScalar = P::PScalar>>(
329 &self,
330 pos: &Q,
331 radius: P::PScalar,
332 ) -> Vec<P> {
333 if let Some(root_node) = &self.root {
334 let mut results: Vec<P> = Vec::new();
335 let mut cloned_rect = self.rect.clone().unwrap();
336
337 root_node.recursive_range_query(pos, radius * radius, &mut results, &mut cloned_rect);
338 results
339 } else {
340 Vec::new()
341 }
342 }
343
344 pub fn closure_range_query<Q: PointTrait<PScalar = P::PScalar>, F>(
345 &self,
346 pos: &Q,
347 radius: P::PScalar,
348 mut process: F,
349 ) where
350 F: FnMut(&P),
351 {
352 if let Some(root_node) = &self.root {
353 let mut cloned_rect = self.rect.clone().unwrap();
354
355 root_node.recursive_closure_range_query(
356 pos,
357 radius * radius,
358 &mut cloned_rect,
359 &mut process,
360 );
361 }
362 }
363
364 fn hyper_rect_dist_sq<Q: PointTrait<PScalar = P::PScalar>>(
365 rect: &HyperRectangle<P>,
366 pos: &Q,
367 ) -> P::PScalar {
368 let mut result = P::PScalar::zero();
369 for i in 0..P::DIMENSION {
370 let pos_val = pos.at(i);
371 if pos_val < rect.min.at(i) {
372 result += Self::sq(rect.min.at(i) - pos_val);
373 } else if pos_val > rect.max.at(i) {
374 result += Self::sq(rect.max.at(i) - pos_val);
375 }
376 }
377 result
378 }
379
380 #[inline(always)]
381 fn sq(i: P::PScalar) -> P::PScalar {
382 i * i
383 }
384}