1mod kd_features;
2mod kd_leaf;
3pub use kd_features::KdFeature;
4pub use kd_leaf::KdLeaf;
5
6#[cfg(all(feature = "pure", not(feature = "core")))]
7use crate::{
8 serde::{self, Deserialize, Serialize},
9 BasicFloat,
10};
11use crate::{SearchBy, TreeHeapElement, TreeKnnResult, TreeRadiusResult, TreeResult, TreeSearch};
12#[cfg(all(feature = "core", not(feature = "pure")))]
13use f3l_core::{
14 rayon,
15 serde::{self, Deserialize, Serialize},
16 BasicFloat,
17};
18use std::{cmp::Reverse, collections::BinaryHeap, ops::Index};
19
20#[derive(Debug, Clone, Default, Serialize, Deserialize)]
48#[serde(crate = "self::serde")]
49pub struct KdTree<'a, T: BasicFloat, P>
50where
51 P: Index<usize, Output = T> + Clone + Copy,
52{
53 pub dim: usize,
54 pub ignores: Vec<usize>,
55 pub enable_ignore: bool,
56 #[serde(skip_serializing)]
57 #[serde(skip_deserializing)]
58 pub root: Option<Box<KdLeaf>>,
59 #[serde(skip_serializing)]
60 #[serde(skip_deserializing)]
61 pub data: Option<&'a [P]>,
62}
63
64impl<'a, T: BasicFloat, P> KdTree<'a, T, P>
65where
66 P: Index<usize, Output = T> + Clone + Copy + Send + Sync,
67{
68 pub fn new(dim: usize) -> Self {
69 Self {
70 root: None,
71 dim,
72 data: None,
73 ignores: vec![],
74 enable_ignore: false,
75 }
76 }
77
78 pub fn with_data(dim: usize, data: &'a [P]) -> Self {
79 Self {
80 root: None,
81 dim,
82 data: Some(data),
83 ignores: vec![],
84 enable_ignore: false,
85 }
86 }
87
88 pub fn clear(&mut self) {
89 self.root = None;
91 }
92
93 pub fn set_data(&mut self, data: &'a [P]) {
94 self.clear();
95 self.data = Some(data);
96 }
97
98 pub fn build(&mut self) {
99 if let Some(d) = self.data {
100 let n = d.len();
101 self.root = Some(self.build_recursive(&mut (0..n).collect::<Vec<usize>>(), d));
102 }
103 }
104
105 fn build_recursive(&self, indices: &mut [usize], data: &[P]) -> Box<KdLeaf> {
106 let mut node = Box::<KdLeaf>::default();
107 if indices.len() == 1 {
108 node.feature = KdFeature::Leaf(indices[0]);
109 return node;
110 }
111 let (split, index) = self.mean_split(indices, data);
112
113 let mut data_l = indices[..index].to_owned();
114 let mut data_r = indices[index..].to_owned();
115
116 (node.left, node.right) = rayon::join(
117 || Some(self.build_recursive(&mut data_l, data)),
118 || Some(self.build_recursive(&mut data_r, data)),
119 );
120 node.feature = split;
121
122 node
123 }
124
125 fn mean_split(&self, indices: &mut [usize], data: &[P]) -> (KdFeature, usize) {
126 let factor = T::from(1.0f32 / indices.len() as f32).unwrap();
128 let mut mean = vec![T::zero(); self.dim];
129 indices.iter().for_each(|&i| {
130 (0..self.dim).for_each(|j| {
131 mean[j] += data[i][j] * factor;
132 })
133 });
134
135 let mut var = vec![T::zero(); self.dim];
137 indices.iter().for_each(|&i| {
138 (0..self.dim).for_each(|j| {
139 let dist = data[i][j] - mean[j];
140 var[j] += dist * dist;
141 })
142 });
143 let mut split_dim = 0;
145 (1..self.dim).for_each(|i| {
146 if var[i] > var[split_dim] {
147 split_dim = i;
148 }
149 });
150
151 let split_val = mean[split_dim];
152 let (lim1, lim2) = self.plane_split(indices, split_dim, split_val, data);
153
154 let mut index: usize;
155 let mid = indices.len() / 2;
156 if lim1 > mid {
157 index = lim1;
158 } else if lim2 < mid {
159 index = lim2;
160 } else {
161 index = mid;
162 }
163 if lim1 == indices.len() || lim2 == 0 {
164 index = mid;
165 }
166
167 (
168 KdFeature::Split((split_dim, split_val.to_f32().unwrap())),
169 index,
170 )
171 }
172
173 fn plane_split(&self, indices: &mut [usize], split_dim: usize, split_val: T, data: &[P]) -> (usize, usize) {
174 let mut left = 0;
175 let mut right = indices.len() - 1;
176
177 loop {
178 while left <= right && data[indices[left]][split_dim] < split_val {
179 left += 1;
180 }
181 while left < right && data[indices[right]][split_dim] >= split_val {
182 right -= 1;
183 }
184 if left >= right {
185 break;
186 }
187 indices.swap(left, right);
188 left += 1;
189 right -= 1;
190 }
191 let lim1 = left;
192 right = indices.len() - 1;
193 loop {
194 while left <= right && data[indices[left]][split_dim] <= split_val {
195 left += 1;
196 }
197 while left < right && data[indices[right]][split_dim] > split_val {
198 right -= 1;
199 }
200 if left >= right {
201 break;
202 }
203 indices.swap(left, right);
204 left += 1;
205 right -= 1;
206 }
207 (lim1, left)
208 }
209
210 pub fn search<R: TreeResult>(&self, data: P, by: SearchBy, result: &mut R) {
211 let mut search_queue =
212 BinaryHeap::with_capacity(30);
213
214 if self.root.is_none() {
215 return;
216 }
217 if let Some(root) = &self.root {
218 self.search_(
219 result,
220 root,
221 &data,
222 by,
223 if result.is_farthest() { f32::MAX } else { 0.0 },
224 &mut search_queue,
225 );
226
227 while let Some(Reverse(node)) = search_queue.pop() {
228 self.search_(result, node.raw, &data, by, node.order, &mut search_queue)
229 }
230 };
231 }
232
233 fn search_<R: TreeResult>(
234 &self,
235 result: &mut R,
236 node: &'a KdLeaf,
237 data: &P,
238 by: SearchBy,
239 min_dist: f32,
240 queue: &mut BinaryHeap<Reverse<TreeHeapElement<&'a Box<KdLeaf>, f32>>>,
242 ) {
243 let is_farthest = result.is_farthest();
244 if match is_farthest {
245 true => result.worst() > min_dist,
246 false => result.worst() < min_dist,
247 } {
248 return;
249 }
250 let p = data;
252
253 let near;
254 let far;
255
256 let d: T;
257 match node.feature {
258 KdFeature::Leaf(leaf) => {
259 if self.enable_ignore && self.ignores.contains(&leaf) {
260 return;
261 }
262 let dist = distance(self.data.unwrap()[leaf], *p, self.dim);
263 result.add(leaf, dist.to_f32().unwrap());
264 return;
265 }
266 KdFeature::Split((sp_dim, sp_val)) => {
267 d = p[sp_dim] - T::from(sp_val).unwrap();
268 if d < T::zero() {
269 near = &node.left;
270 far = &node.right;
271 } else {
272 near = &node.right;
273 far = &node.left;
274 }
275 }
276 };
277 let (near, far) = if is_farthest {
278 (far, near)
279 } else {
280 (near, far)
281 };
282
283 if let Some(far) = far {
284 let add_far = match by {
285 SearchBy::Count(_) => {
286 if !result.is_full() {
287 true
288 } else {
289 match is_farthest {
290 true => d * d > T::from(result.worst() + f32::EPSILON).unwrap(),
291 false => d * d < T::from(result.worst() + f32::EPSILON).unwrap(),
292 }
293 }
294 }
295 SearchBy::Radius(r) => d * d <= T::from(r).unwrap(),
296 };
297 if add_far {
298 let node = TreeHeapElement {
299 raw: far,
300 order: min_dist + (d * d).to_f32().unwrap(),
301 };
302 queue.push(Reverse(node));
303 }
304 }
305
306 if let Some(near) = near {
307 self.search_(result, near, data, by, min_dist, queue);
308 }
309 }
310}
311
312#[inline]
313fn distance<T: BasicFloat, P>(a: P, b: P, dim: usize) -> T
314where
315 P: Index<usize, Output = T> + Copy,
316{
317 (0..dim).fold(T::zero(), |acc, i| acc + (a[i] - b[i]).powi(2))
318}
319
320impl<'a, T: BasicFloat, P> TreeSearch<P> for KdTree<'a, T, P>
321where
322 P: Send + Sync + Clone + Copy + Index<usize, Output = T>,
323{
324 fn search_knn(&self, point: &P, k: usize) -> Vec<(P, f32)> {
325 if self.data.is_none() {
326 return vec![];
327 }
328 let by = if k == 0 {
329 SearchBy::Count(1)
330 } else {
331 SearchBy::Count(k)
332 };
333 let mut result = TreeKnnResult::new(k);
334 self.search(*point, by, &mut result);
335 result
336 .result()
337 .iter()
338 .map(|&(i, d)| (self.data.unwrap()[i], d.sqrt()))
339 .collect::<Vec<(P, f32)>>()
340 }
341
342 fn search_radius(&self, point: &P, radius: f32) -> Vec<P> {
343 if self.data.is_none() {
344 return vec![];
345 }
346 let by = if radius == 0.0 {
347 SearchBy::Count(1)
348 } else {
349 SearchBy::Radius(radius * radius)
350 };
351 let mut result = TreeRadiusResult::new(radius * radius);
352 self.search(*point, by, &mut result);
353 result.data.iter().map(|&i| self.data.unwrap()[i]).collect()
354 }
355
356 fn search_knn_ids(&self, point: &P, k: usize) -> Vec<usize> {
357 let by = if k == 0 {
358 SearchBy::Count(1)
359 } else {
360 SearchBy::Count(k)
361 };
362 let mut result = TreeKnnResult::new(k);
363 self.search(*point, by, &mut result);
364 result.data.iter().map(|&(i, _)| i).collect()
365 }
366
367 fn search_radius_ids(&self, point: &P, radius: f32) -> Vec<usize> {
368 let by = if radius == 0.0 {
369 SearchBy::Count(1)
370 } else {
371 SearchBy::Radius(radius * radius)
372 };
373 let mut result = TreeRadiusResult::new(radius * radius);
374 self.search(*point, by, &mut result);
375 result.data
376 }
377
378 fn add_ignore(&mut self, idx: usize) {
379 self.ignores.push(idx);
380 }
381
382 fn add_ignores(&mut self, idx: &[usize]) {
383 idx.iter().for_each(|&i| self.ignores.push(i));
384 }
385
386 fn set_ignore(&mut self, enable: bool) {
387 self.enable_ignore = enable;
388 }
389}