f3l_search_tree/
kdtree.rs

1mod kd_features;
2mod kd_leaf;
3pub use kd_features::KdFeature;
4pub use kd_leaf::KdLeaf;
5
6#[cfg(all(feature = "pure", not(feature = "core")))]
7use crate::{
8    serde::{self, Deserialize, Serialize},
9    BasicFloat,
10};
11use crate::{SearchBy, TreeHeapElement, TreeKnnResult, TreeRadiusResult, TreeResult, TreeSearch};
12#[cfg(all(feature = "core", not(feature = "pure")))]
13use f3l_core::{
14    rayon,
15    serde::{self, Deserialize, Serialize},
16    BasicFloat,
17};
18use std::{cmp::Reverse, collections::BinaryHeap, ops::Index};
19
20/// KD-Tree Implement
21///
22/// Use for any dimension of data.
23/// Allow type which implement `Into<[T; D]>`
24/// See more in `tests`.
25///
26/// `let mut tree = KdTree::<f32, 1>::new();`
27/// Input:
28/// * element type (like f32 or f64.. )
29/// * Dimension: usize
30///
31/// # Examples
32/// ```
33/// use approx::assert_relative_eq;
34/// use f3l_core::glam::{Vec2, Vec3};
35/// use f3l_search_tree::*;
36///
37/// let data = (0..10).map(|i| [i as f32]).collect::<Vec<_>>();
38/// let mut tree = KdTree::with_data(&data);
39/// tree.build();
40/// let result = tree.search_knn(&[5.1f32], 1);
41/// let nearest_data = result[0].0[0];
42/// let nearest_distance = result[0].1;
43///
44/// assert_relative_eq!(nearest_data, 5f32);
45/// assert_relative_eq!(nearest_distance, 0.1f32);
46/// ```
47#[derive(Debug, Clone, Default, Serialize, Deserialize)]
48#[serde(crate = "self::serde")]
49pub struct KdTree<'a, T: BasicFloat, P>
50where
51    P: Index<usize, Output = T> + Clone + Copy,
52{
53    pub dim: usize,
54    pub ignores: Vec<usize>,
55    pub enable_ignore: bool,
56    #[serde(skip_serializing)]
57    #[serde(skip_deserializing)]
58    pub root: Option<Box<KdLeaf>>,
59    #[serde(skip_serializing)]
60    #[serde(skip_deserializing)]
61    pub data: Option<&'a [P]>,
62}
63
64impl<'a, T: BasicFloat, P> KdTree<'a, T, P>
65where
66    P: Index<usize, Output = T> + Clone + Copy + Send + Sync,
67{
68    pub fn new(dim: usize) -> Self {
69        Self {
70            root: None,
71            dim,
72            data: None,
73            ignores: vec![],
74            enable_ignore: false,
75        }
76    }
77
78    pub fn with_data(dim: usize, data: &'a [P]) -> Self {
79        Self {
80            root: None,
81            dim,
82            data: Some(data),
83            ignores: vec![],
84            enable_ignore: false,
85        }
86    }
87
88    pub fn clear(&mut self) {
89        // self.data.clear();
90        self.root = None;
91    }
92
93    pub fn set_data(&mut self, data: &'a [P]) {
94        self.clear();
95        self.data = Some(data);
96    }
97
98    pub fn build(&mut self) {
99        if let Some(d) = self.data {
100            let n = d.len();
101            self.root = Some(self.build_recursive(&mut (0..n).collect::<Vec<usize>>(), d));
102        }
103    }
104
105    fn build_recursive(&self, indices: &mut [usize], data: &[P]) -> Box<KdLeaf> {
106        let mut node = Box::<KdLeaf>::default();
107        if indices.len() == 1 {
108            node.feature = KdFeature::Leaf(indices[0]);
109            return node;
110        }
111        let (split, index) = self.mean_split(indices, data);
112
113        let mut data_l = indices[..index].to_owned();
114        let mut data_r = indices[index..].to_owned();
115
116        (node.left, node.right) = rayon::join(
117            || Some(self.build_recursive(&mut data_l, data)),
118            || Some(self.build_recursive(&mut data_r, data)),
119        );
120        node.feature = split;
121
122        node
123    }
124
125    fn mean_split(&self, indices: &mut [usize], data: &[P]) -> (KdFeature, usize) {
126        // Compute mean value per dimension
127        let factor = T::from(1.0f32 / indices.len() as f32).unwrap();
128        let mut mean = vec![T::zero(); self.dim];
129        indices.iter().for_each(|&i| {
130            (0..self.dim).for_each(|j| {
131                mean[j] += data[i][j] * factor;
132            })
133        });
134
135        // Compute variance per dimension
136        let mut var = vec![T::zero(); self.dim];
137        indices.iter().for_each(|&i| {
138            (0..self.dim).for_each(|j| {
139                let dist = data[i][j] - mean[j];
140                var[j] += dist * dist;
141            })
142        });
143        // Choose the max variance dimension
144        let mut split_dim = 0;
145        (1..self.dim).for_each(|i| {
146            if var[i] > var[split_dim] {
147                split_dim = i;
148            }
149        });
150
151        let split_val = mean[split_dim];
152        let (lim1, lim2) = self.plane_split(indices, split_dim, split_val, data);
153
154        let mut index: usize;
155        let mid = indices.len() / 2;
156        if lim1 > mid {
157            index = lim1;
158        } else if lim2 < mid {
159            index = lim2;
160        } else {
161            index = mid;
162        }
163        if lim1 == indices.len() || lim2 == 0 {
164            index = mid;
165        }
166
167        (
168            KdFeature::Split((split_dim, split_val.to_f32().unwrap())),
169            index,
170        )
171    }
172
173    fn plane_split(&self, indices: &mut [usize], split_dim: usize, split_val: T, data: &[P]) -> (usize, usize) {
174        let mut left = 0;
175        let mut right = indices.len() - 1;
176
177        loop {
178            while left <= right && data[indices[left]][split_dim] < split_val {
179                left += 1;
180            }
181            while left < right && data[indices[right]][split_dim] >= split_val {
182                right -= 1;
183            }
184            if left >= right {
185                break;
186            }
187            indices.swap(left, right);
188            left += 1;
189            right -= 1;
190        }
191        let lim1 = left;
192        right = indices.len() - 1;
193        loop {
194            while left <= right && data[indices[left]][split_dim] <= split_val {
195                left += 1;
196            }
197            while left < right && data[indices[right]][split_dim] > split_val {
198                right -= 1;
199            }
200            if left >= right {
201                break;
202            }
203            indices.swap(left, right);
204            left += 1;
205            right -= 1;
206        }
207        (lim1, left)
208    }
209
210    pub fn search<R: TreeResult>(&self, data: P, by: SearchBy, result: &mut R) {
211        let mut search_queue =
212            BinaryHeap::with_capacity(30);
213
214        if self.root.is_none() {
215            return;
216        }
217        if let Some(root) = &self.root {
218            self.search_(
219                result,
220                root,
221                &data,
222                by,
223                if result.is_farthest() { f32::MAX } else { 0.0 },
224                &mut search_queue,
225            );
226
227            while let Some(Reverse(node)) = search_queue.pop() {
228                self.search_(result, node.raw, &data, by, node.order, &mut search_queue)
229            }
230        };
231    }
232
233    fn search_<R: TreeResult>(
234        &self,
235        result: &mut R,
236        node: &'a KdLeaf,
237        data: &P,
238        by: SearchBy,
239        min_dist: f32,
240        // queue: &mut BinaryHeap<SearchQueue<TreeHeapElement<&'a Box<KdLeaf>, f32>>>,
241        queue: &mut BinaryHeap<Reverse<TreeHeapElement<&'a Box<KdLeaf>, f32>>>,
242    ) {
243        let is_farthest = result.is_farthest();
244        if match is_farthest {
245            true => result.worst() > min_dist,
246            false => result.worst() < min_dist,
247        } {
248            return;
249        }
250        // let p: [T; D] = (*data).into();
251        let p = data;
252
253        let near;
254        let far;
255
256        let d: T;
257        match node.feature {
258            KdFeature::Leaf(leaf) => {
259                if self.enable_ignore && self.ignores.contains(&leaf) {
260                    return;
261                }
262                let dist = distance(self.data.unwrap()[leaf], *p, self.dim);
263                result.add(leaf, dist.to_f32().unwrap());
264                return;
265            }
266            KdFeature::Split((sp_dim, sp_val)) => {
267                d = p[sp_dim] - T::from(sp_val).unwrap();
268                if d < T::zero() {
269                    near = &node.left;
270                    far = &node.right;
271                } else {
272                    near = &node.right;
273                    far = &node.left;
274                }
275            }
276        };
277        let (near, far) = if is_farthest {
278            (far, near)
279        } else {
280            (near, far)
281        };
282
283        if let Some(far) = far {
284            let add_far = match by {
285                SearchBy::Count(_) => {
286                    if !result.is_full() {
287                        true
288                    } else {
289                        match is_farthest {
290                            true => d * d > T::from(result.worst() + f32::EPSILON).unwrap(),
291                            false => d * d < T::from(result.worst() + f32::EPSILON).unwrap(),
292                        }
293                    }
294                }
295                SearchBy::Radius(r) => d * d <= T::from(r).unwrap(),
296            };
297            if add_far {
298                let node = TreeHeapElement {
299                    raw: far,
300                    order: min_dist + (d * d).to_f32().unwrap(),
301                };
302                queue.push(Reverse(node));
303            }
304        }
305
306        if let Some(near) = near {
307            self.search_(result, near, data, by, min_dist, queue);
308        }
309    }
310}
311
312#[inline]
313fn distance<T: BasicFloat, P>(a: P, b: P, dim: usize) -> T
314where
315    P: Index<usize, Output = T> + Copy,
316{
317    (0..dim).fold(T::zero(), |acc, i| acc + (a[i] - b[i]).powi(2))
318}
319
320impl<'a, T: BasicFloat, P> TreeSearch<P> for KdTree<'a, T, P>
321where
322    P: Send + Sync + Clone + Copy + Index<usize, Output = T>,
323{
324    fn search_knn(&self, point: &P, k: usize) -> Vec<(P, f32)> {
325        if self.data.is_none() {
326            return vec![];
327        }
328        let by = if k == 0 {
329            SearchBy::Count(1)
330        } else {
331            SearchBy::Count(k)
332        };
333        let mut result = TreeKnnResult::new(k);
334        self.search(*point, by, &mut result);
335        result
336            .result()
337            .iter()
338            .map(|&(i, d)| (self.data.unwrap()[i], d.sqrt()))
339            .collect::<Vec<(P, f32)>>()
340    }
341
342    fn search_radius(&self, point: &P, radius: f32) -> Vec<P> {
343        if self.data.is_none() {
344            return vec![];
345        }
346        let by = if radius == 0.0 {
347            SearchBy::Count(1)
348        } else {
349            SearchBy::Radius(radius * radius)
350        };
351        let mut result = TreeRadiusResult::new(radius * radius);
352        self.search(*point, by, &mut result);
353        result.data.iter().map(|&i| self.data.unwrap()[i]).collect()
354    }
355
356    fn search_knn_ids(&self, point: &P, k: usize) -> Vec<usize> {
357        let by = if k == 0 {
358            SearchBy::Count(1)
359        } else {
360            SearchBy::Count(k)
361        };
362        let mut result = TreeKnnResult::new(k);
363        self.search(*point, by, &mut result);
364        result.data.iter().map(|&(i, _)| i).collect()
365    }
366
367    fn search_radius_ids(&self, point: &P, radius: f32) -> Vec<usize> {
368        let by = if radius == 0.0 {
369            SearchBy::Count(1)
370        } else {
371            SearchBy::Radius(radius * radius)
372        };
373        let mut result = TreeRadiusResult::new(radius * radius);
374        self.search(*point, by, &mut result);
375        result.data
376    }
377
378    fn add_ignore(&mut self, idx: usize) {
379        self.ignores.push(idx);
380    }
381
382    fn add_ignores(&mut self, idx: &[usize]) {
383        idx.iter().for_each(|&i| self.ignores.push(i));
384    }
385
386    fn set_ignore(&mut self, enable: bool) {
387        self.enable_ignore = enable;
388    }
389}