krakel/
lib.rs

1/*  SPDX-License-Identifier:LGPL-2.1-only
2 *  Rust code Copyright (c) 2023 lacklustr@protonmail.com https://github.com/eadf
3 *
4 *  This file is ported from code inside of OpenCAMlib:
5 *  Copyright (c) 2010-2011 Anders Wallin (anders.e.e.wallin "at" gmail.com).
6 *  (see https://github.com/aewallin/opencamlib).
7 *
8 *  This program is free software: you can redistribute it and/or modify
9 *  it under the terms of the GNU Lesser General Public License as published by
10 *  the Free Software Foundation, either version 2.1 of the License, or
11 *  (at your option) any later version.
12 *
13 *  This program is distributed in the hope that it will be useful,
14 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 *  GNU Lesser General Public License for more details.
17 *
18 *  You should have received a copy of the GNU Lesser General Public License
19 *  along with this program. If not, see <http://www.gnu.org/licenses/>.
20*/
21
22//! # Krakel Crate
23//!
//! `krakel` is a tiny 2D kd-tree ported from [OpenCamLib](https://github.com/aewallin/opencamlib).
25//!
26
27#[cfg(not(feature = "vector-traits"))]
28use approx::UlpsEq;
29use std::fmt::Display;
30use std::{
31    fmt,
32    fmt::Debug,
33    ops::{AddAssign, DivAssign, MulAssign, SubAssign},
34};
35#[cfg(feature = "vector-traits")]
36use vector_traits::approx::UlpsEq;
37
38#[cfg(not(feature = "vector-traits"))]
39use num_traits::{FromPrimitive, Zero, real::Real};
40#[cfg(feature = "vector-traits")]
41use vector_traits::num_traits::{FromPrimitive, Zero, real::Real};
42
43mod impls;
44
45#[cfg(test)]
46mod tests;
47
/// Error type returned by the fallible operations of this crate.
#[derive(thiserror::Error, Debug)]
pub enum KrakelError {
    /// A failure described by a free-form message; its `Display` output is
    /// `Unknown error: <message>`.
    #[error("Unknown error: {0}")]
    InternalError(String),
}
53
/// Abstraction over the point types that can be stored in a [`KDTree`].
///
/// The associated scalar type must be a real number type that supports the
/// compound assignment operators as well as approximate (`UlpsEq`) and exact
/// comparison, and that can be formatted with both `Debug` and `Display`.
pub trait PointTrait: Clone + PartialEq
where
    Self::PScalar: Real
        + FromPrimitive
        + UlpsEq
        + Debug
        + Display
        + PartialEq
        + MulAssign
        + SubAssign
        + DivAssign
        + AddAssign,
{
    /// Scalar type of the coordinates.
    type PScalar;
    /// Returns the x coordinate.
    fn x(&self) -> Self::PScalar;
    /// Returns the y coordinate.
    fn y(&self) -> Self::PScalar;
    /// Overwrites the x coordinate.
    fn set_x(&mut self, x: Self::PScalar);
    /// Overwrites the y coordinate.
    fn set_y(&mut self, y: Self::PScalar);

    /// Returns the squared distance between `a` and `b`, computed from their x and y
    /// coordinates only. `b` may be of a different point type as long as it uses the
    /// same scalar type.
    ///
    /// This is an associated function (it takes no `self`); call it as
    /// `P::dist_sq(&a, &b)`.
    #[inline(always)]
    fn dist_sq<Q: PointTrait<PScalar = Self::PScalar>>(a: &Self, b: &Q) -> Self::PScalar {
        let dx: Self::PScalar = a.x() - b.x();
        let dy: Self::PScalar = a.y() - b.y();
        dx * dx + dy * dy
    }

    /// Returns the coordinate along axis `index`. The tree only calls this with
    /// `index` in `0..DIMENSION`.
    // NOTE(review): presumably `at(0) == x()` and `at(1) == y()`; the implementations
    // are not in this file, so confirm against the `impls` module.
    fn at(&self, index: u8) -> Self::PScalar;
    /// Returns a mutable reference to the coordinate along axis `index`.
    fn at_mut(&mut self, index: u8) -> &mut Self::PScalar;
    /// Number of axes the tree cycles through when splitting.
    const DIMENSION: u8;
}
85
/// Index-based coordinate access expressed in terms of the scalar type of a
/// [`PointTrait`] implementation.
///
/// Nothing in this file uses this trait; note that it indexes with `usize` whereas
/// [`PointTrait::at`] indexes with `u8`.
pub trait KDPoint<P: PointTrait> {
    /// Returns the coordinate at `index`.
    fn get_coordinate(&self, index: usize) -> P::PScalar;
    /// Sets the coordinate at `index` to `value`.
    fn set_coordinate(&mut self, index: usize, value: P::PScalar);
}
90
/// A single node of the kd-tree.
pub struct KDNode<P: PointTrait> {
    /// The point stored in this node.
    pos: P,
    /// Index of the axis this node splits on.
    dir: u8,
    /// Subtree of points whose coordinate along `dir` is strictly smaller than this
    /// node's.
    left: Option<Box<KDNode<P>>>,
    /// Subtree of points whose coordinate along `dir` is greater than or equal to this
    /// node's.
    right: Option<Box<KDNode<P>>>,
}
97
/// Axis-aligned box described by its per-axis minimum and maximum corner.
///
/// The tree keeps one as the bounding box of all inserted points, and the queries
/// work on a clone of it when deciding which subtrees to visit.
#[derive(Clone)]
struct HyperRectangle<P: PointTrait> {
    /// Corner holding the smallest coordinate along every axis.
    min: P,
    /// Corner holding the largest coordinate along every axis.
    max: P,
}
103
/// A kd-tree storing points of type `P`.
pub struct KDTree<P: PointTrait> {
    /// Root node; `None` while the tree is empty.
    root: Option<Box<KDNode<P>>>,
    /// Bounding box of every inserted point; set by the first `insert` and grown by
    /// later ones.
    rect: Option<HyperRectangle<P>>,
}
108
109impl<P: PointTrait> KDNode<P> {
110    fn recursive_insert(
111        node: &mut Option<Box<KDNode<P>>>,
112        pos: P,
113        dir: u8,
114        dim: u8,
115    ) -> Result<(), KrakelError> {
116        match node {
117            None => {
118                *node = Some(Box::new(KDNode {
119                    pos,
120                    dir,
121                    left: None,
122                    right: None,
123                }));
124            }
125            Some(current) => {
126                let new_dir = (current.dir + 1) % dim;
127                if pos.at(current.dir) < current.pos.at(current.dir) {
128                    Self::recursive_insert(&mut current.left, pos, new_dir, dim)?;
129                } else {
130                    Self::recursive_insert(&mut current.right, pos, new_dir, dim)?;
131                }
132            }
133        }
134        Ok(())
135    }
136
137    fn recursive_nearest<'a>(
138        &'a self,
139        pos: &P,
140        result: &mut Option<&'a P>,
141        result_dist_sq: &mut P::PScalar,
142        rect: &mut HyperRectangle<P>,
143    ) {
144        let dir = self.dir;
145
146        let (nearer_subtree, farther_subtree) = if pos.at(dir) <= self.pos.at(dir) {
147            (&self.left, &self.right)
148        } else {
149            (&self.right, &self.left)
150        };
151
152        let old_value = if pos.at(dir) <= self.pos.at(dir) {
153            std::mem::replace(&mut rect.max.at(dir), self.pos.at(dir))
154        } else {
155            std::mem::replace(&mut rect.min.at(dir), self.pos.at(dir))
156        };
157
158        if let Some(nearer_node) = nearer_subtree {
159            nearer_node.recursive_nearest(pos, result, result_dist_sq, rect);
160        }
161
162        if pos.at(dir) <= self.pos.at(dir) {
163            *rect.max.at_mut(dir) = old_value;
164        } else {
165            *rect.min.at_mut(dir) = old_value;
166        }
167
168        let dist_sq = PointTrait::dist_sq(&self.pos, pos);
169        if dist_sq < *result_dist_sq {
170            *result_dist_sq = dist_sq;
171            *result = Some(&self.pos);
172        }
173
174        if let Some(farther_node) = farther_subtree {
175            if KDTree::hyper_rect_dist_sq(rect, pos) < *result_dist_sq {
176                farther_node.recursive_nearest(pos, result, result_dist_sq, rect);
177            }
178        }
179    }
180
181    fn recursive_range_query<Q: PointTrait<PScalar = P::PScalar>>(
182        &self,
183        pos: &Q,
184        radius_sq: P::PScalar,
185        results: &mut Vec<P>,
186        rect: &mut HyperRectangle<P>,
187    ) {
188        let dir = self.dir;
189
190        let (nearer_subtree, farther_subtree) = if pos.at(dir) <= self.pos.at(dir) {
191            (&self.left, &self.right)
192        } else {
193            (&self.right, &self.left)
194        };
195
196        let old_value = if pos.at(dir) <= self.pos.at(dir) {
197            std::mem::replace(&mut rect.max.at(dir), self.pos.at(dir))
198        } else {
199            std::mem::replace(&mut rect.min.at(dir), self.pos.at(dir))
200        };
201
202        if let Some(nearer_node) = nearer_subtree {
203            nearer_node.recursive_range_query(pos, radius_sq, results, rect);
204        }
205
206        if pos.at(dir) <= self.pos.at(dir) {
207            *rect.max.at_mut(dir) = old_value;
208        } else {
209            *rect.min.at_mut(dir) = old_value;
210        }
211
212        let dist_sq = PointTrait::dist_sq(&self.pos, pos);
213        if dist_sq <= radius_sq {
214            results.push(self.pos.clone());
215        }
216
217        if let Some(farther_node) = farther_subtree {
218            if KDTree::hyper_rect_dist_sq(rect, pos) <= radius_sq {
219                farther_node.recursive_range_query(pos, radius_sq, results, rect);
220            }
221        }
222    }
223
224    fn recursive_closure_range_query<Q: PointTrait<PScalar = P::PScalar>, F>(
225        &self,
226        pos: &Q,
227        radius_sq: P::PScalar,
228        rect: &mut HyperRectangle<P>,
229        process: &mut F,
230    ) where
231        F: FnMut(&P),
232    {
233        let dir = self.dir;
234
235        let (nearer_subtree, farther_subtree) = if pos.at(dir) <= self.pos.at(dir) {
236            (&self.left, &self.right)
237        } else {
238            (&self.right, &self.left)
239        };
240
241        let old_value = if pos.at(dir) <= self.pos.at(dir) {
242            std::mem::replace(&mut rect.max.at(dir), self.pos.at(dir))
243        } else {
244            std::mem::replace(&mut rect.min.at(dir), self.pos.at(dir))
245        };
246
247        if let Some(nearer_node) = nearer_subtree {
248            nearer_node.recursive_closure_range_query(pos, radius_sq, rect, process);
249        }
250
251        if pos.at(dir) <= self.pos.at(dir) {
252            *rect.max.at_mut(dir) = old_value;
253        } else {
254            *rect.min.at_mut(dir) = old_value;
255        }
256
257        if PointTrait::dist_sq(&self.pos, pos) <= radius_sq {
258            process(&self.pos);
259        }
260
261        if let Some(farther_node) = farther_subtree {
262            if KDTree::hyper_rect_dist_sq(rect, pos) <= radius_sq {
263                farther_node.recursive_closure_range_query(pos, radius_sq, rect, process);
264            }
265        }
266    }
267
268    fn format_node(&self, f: &mut fmt::Formatter<'_>, depth: usize) -> fmt::Result {
269        for _ in 0..depth {
270            write!(f, " ")?;
271        }
272
273        write!(f, "d={} node at ", self.dir)?;
274        for i in 0..P::DIMENSION {
275            write!(f, "{} ", self.pos.at(i))?;
276        }
277        writeln!(f)?;
278
279        if let Some(ref left_node) = self.left {
280            left_node.format_node(f, depth + 1)?;
281        }
282
283        if let Some(ref right_node) = self.right {
284            right_node.format_node(f, depth + 1)?;
285        }
286
287        Ok(())
288    }
289}
290
291impl<P: PointTrait> KDTree<P> {
292    pub fn insert(&mut self, pos: P) -> Result<(), KrakelError> {
293        KDNode::recursive_insert(&mut self.root, pos.clone(), 0, P::DIMENSION)?;
294
295        if self.rect.is_none() {
296            self.rect = Some(HyperRectangle {
297                min: pos.clone(),
298                max: pos.clone(),
299            });
300        } else {
301            for i in 0..P::DIMENSION {
302                if pos.at(i) < self.rect.as_mut().unwrap().min.at(i) {
303                    *self.rect.as_mut().unwrap().min.at_mut(i) = pos.at(i);
304                } else if pos.at(i) > self.rect.as_mut().unwrap().max.at(i) {
305                    *self.rect.as_mut().unwrap().max.at_mut(i) = pos.at(i);
306                }
307            }
308        }
309        Ok(())
310    }
311
312    #[allow(dead_code)]
313    pub fn nearest(&self, pos: &P) -> Option<P> {
314        if let Some(root_node) = &self.root {
315            // Now that we know self.root is Some(_), it's safe to assume self.rect is Some(_) as well
316            let mut rect = self.rect.clone().unwrap();
317            let mut result: Option<&P> = self.root.as_ref().map(|node| &node.pos);
318            let mut result_dist_sq = P::dist_sq(result.as_ref().unwrap(), pos);
319
320            root_node.recursive_nearest(pos, &mut result, &mut result_dist_sq, &mut rect);
321            result.cloned()
322        } else {
323            None
324        }
325    }
326
327    #[allow(dead_code)]
328    pub fn range_query<Q: PointTrait<PScalar = P::PScalar>>(
329        &self,
330        pos: &Q,
331        radius: P::PScalar,
332    ) -> Vec<P> {
333        if let Some(root_node) = &self.root {
334            let mut results: Vec<P> = Vec::new();
335            let mut cloned_rect = self.rect.clone().unwrap();
336
337            root_node.recursive_range_query(pos, radius * radius, &mut results, &mut cloned_rect);
338            results
339        } else {
340            Vec::new()
341        }
342    }
343
344    pub fn closure_range_query<Q: PointTrait<PScalar = P::PScalar>, F>(
345        &self,
346        pos: &Q,
347        radius: P::PScalar,
348        mut process: F,
349    ) where
350        F: FnMut(&P),
351    {
352        if let Some(root_node) = &self.root {
353            let mut cloned_rect = self.rect.clone().unwrap();
354
355            root_node.recursive_closure_range_query(
356                pos,
357                radius * radius,
358                &mut cloned_rect,
359                &mut process,
360            );
361        }
362    }
363
364    fn hyper_rect_dist_sq<Q: PointTrait<PScalar = P::PScalar>>(
365        rect: &HyperRectangle<P>,
366        pos: &Q,
367    ) -> P::PScalar {
368        let mut result = P::PScalar::zero();
369        for i in 0..P::DIMENSION {
370            let pos_val = pos.at(i);
371            if pos_val < rect.min.at(i) {
372                result += Self::sq(rect.min.at(i) - pos_val);
373            } else if pos_val > rect.max.at(i) {
374                result += Self::sq(rect.max.at(i) - pos_val);
375            }
376        }
377        result
378    }
379
380    #[inline(always)]
381    fn sq(i: P::PScalar) -> P::PScalar {
382        i * i
383    }
384}