1#![deny(unsafe_code)]
2
3mod bound;
4mod point;
5mod traits;
6mod util;
7
8use core::cmp::Ord;
9use core::marker::Copy;
10use core::ops::Sub;
11use std::collections::VecDeque;
12
13use bound::Bound;
14
15use crate::point::Point;
16use crate::traits::Epsilon;
17use crate::traits::Mean;
18
19pub type QuadTree<T> = Tree<T, 2>;
20pub type OcTree<T> = Tree<T, 3>;
21
22pub struct Tree<T, const N: usize> {
23 points: Vec<Point<T, N>>,
26
27 splits: Vec<usize>,
32
33 depth: u32,
35}
36
37impl<T, const N: usize> Tree<T, N> {
38 fn uninit(points: Vec<Point<T, N>>, depth: u32) -> Self
39 where
40 T: Ord + Copy,
41 {
42 let num_splits = util::num_divs::<N>().pow(depth + 1);
46 let splits = Vec::with_capacity(num_splits);
47
48 Self {
49 points,
50 splits,
51 depth,
52 }
53 }
54
55 pub fn new(points: Vec<Point<T, N>>, depth: u32) -> Self
56 where
57 T: Mean + Epsilon + Sub<Output = T> + Ord,
58 {
59 let mut tree = Self::uninit(points, depth);
60
61 tree.build();
62
63 tree
64 }
65
66 fn build(&mut self)
67 where
68 T: Ord + Mean + Epsilon + Sub<Output = T>,
69 {
70 let n = self.points.len();
71
72 let Some(bound) = Bound::from_points(&self.points) else {
73 return;
74 };
75
76 let mut keys = vec![0; n];
77 let mut split_queue = VecDeque::with_capacity(n);
78 let mut bound_queue = VecDeque::with_capacity(n);
79
80 let mut swaps = Vec::with_capacity(n);
82
83 let mut splits = vec![0; util::num_divs::<N>()];
85
86 split_queue.push_back(0);
87 bound_queue.push_back(bound);
88
89 for d in 0..self.depth {
90 let regions = util::num_divs::<N>().pow(d);
92
93 for _ in 0..regions {
94 let Some(lo) = split_queue.pop_front() else {
95 unreachable!()
96 };
97
98 let Some(bound) = bound_queue.pop_front() else {
99 unreachable!()
100 };
101
102 self.splits.push(lo);
103
104 let hi = split_queue.front().copied().unwrap_or(n);
107 let hi = if hi < lo { n } else { hi };
108
109 let points = &mut self.points[lo..hi];
110 let keys = &mut keys[lo..hi];
111
112 let mid = bound.center();
113 Self::sort_layer(mid, points, keys, &mut swaps, &mut splits);
114
115 split_queue.extend(splits.iter().copied().map(|s| s + lo));
116 splits.fill(0);
117
118 let Some(bounds) = bound.split() else {
120 continue;
121 };
122
123 bound_queue.extend(bounds);
124 }
125 }
126
127 self.splits.extend(split_queue);
128 }
129
130 fn sort_layer(
131 mid: Point<T, N>,
132 points: &mut [Point<T, N>],
133 keys: &mut [usize],
134 swaps: &mut Vec<usize>,
135 splits: &mut [usize],
136 ) where
137 T: Ord,
138 {
139 debug_assert_eq!(points.len(), keys.len());
140 let n = points.len();
141
142 for i in 0..N {
144 for (j, p) in points.iter().enumerate() {
146 if p.0[i] >= mid.0[i] {
147 keys[j] |= 1 << i;
148 }
149 }
150 }
151
152 swaps.extend(0..n);
154 util::argsort(keys, swaps);
155 util::sort_by_argsort(points, swaps);
156
157 Self::compute_splits(keys, splits);
158
159 keys.fill(0);
160 swaps.clear();
161 }
162
163 fn compute_splits(keys: &[usize], splits: &mut [usize]) {
164 for &k in keys {
165 splits[k] += 1; }
167
168 for i in 1..util::num_divs::<N>() {
170 splits[i] += splits[i - 1];
171 }
172
173 splits.rotate_right(1);
174 splits[0] = 0;
175 }
176}
177
/// Tests for a single `Tree::sort_layer` pass over a 2-D point set.
#[cfg(test)]
mod test_sort_layer {
    use std::collections::VecDeque;
    use std::fmt::Debug;

    use crate::Tree;
    use crate::bound::Bound;
    use crate::point::Point;
    use crate::traits::Mean;
    use crate::util;

    /// Runs one `sort_layer` pass over `points` around the centre of their
    /// bound.
    ///
    /// Returns that centre together with the resulting split offsets, each
    /// shifted by `lo` the same way `Tree::build` shifts them.
    fn sort_layer_wrapper<T: Copy + Ord + Mean + Debug, const N: usize>(
        points: &mut [Point<T, N>],
        lo: usize,
    ) -> (Point<T, N>, VecDeque<usize>) {
        let len = points.len();

        let Some(bound) = Bound::from_points(points) else {
            panic!("Provide at least one point")
        };
        let mid = bound.center();

        // Scratch buffers, set up the same way `Tree::build` does.
        let mut keys = vec![0; len];
        let mut swaps = Vec::with_capacity(len);
        let mut splits = vec![0; util::num_divs::<N>()];

        Tree::sort_layer(mid.clone(), points, &mut keys, &mut swaps, &mut splits);

        let mut split_queue = VecDeque::with_capacity(len);
        for &s in &splits {
            split_queue.push_back(s + lo);
        }

        (mid, split_queue)
    }

    /// Sorting the whole slice: offsets are relative to index 0.
    #[test]
    fn no_offset() {
        let mut points = [[0, 2], [2, 2], [2, 0], [0, 0]].map(Into::into).to_vec();
        let exp_points = &[[0, 0], [2, 0], [0, 2], [2, 2]].map(Into::into);

        let n = points.len();

        let (mid, mut split_queue) = sort_layer_wrapper(&mut points, 0);

        assert_eq!(points, exp_points);
        assert_eq!(split_queue, [0, 1, 2, 3]);

        let starts = split_queue.make_contiguous();
        let Ok([nw_start, ne_start, sw_start, se_start]) =
            TryInto::<[usize; 4]>::try_into(starts)
        else {
            unreachable!()
        };

        assert_eq!(mid, (1, 1).into());

        // Each quadrant runs from its own start to the start of the next one.
        let nw = &points[nw_start..ne_start];
        let ne = &points[ne_start..sw_start];
        let sw = &points[sw_start..se_start];
        let se = &points[se_start..n];

        for p in nw {
            assert!(p.0[0] < mid.0[0], "{p:?} not in NW");
            assert!(p.0[1] < mid.0[1], "{p:?} not in NW");
        }

        for p in ne {
            assert!(p.0[0] >= mid.0[0], "{p:?} not in NE");
            assert!(p.0[1] < mid.0[1], "{p:?} not in NE");
        }

        for p in sw {
            assert!(p.0[0] < mid.0[0], "{p:?} not in SW");
            assert!(p.0[1] >= mid.0[1], "{p:?} not in SW");
        }

        for p in se {
            assert!(p.0[0] >= mid.0[0], "{p:?} not in SE");
            assert!(p.0[1] >= mid.0[1], "{p:?} not in SE");
        }
    }

    /// Sorting only the tail of the slice: the leading points stay untouched
    /// and the offsets are shifted by the start of the tail.
    #[test]
    fn with_offset() {
        let mut points = [[0, 0], [0, 0], [0, 0], [0, 2], [2, 2], [2, 0], [0, 0]]
            .map(Into::into)
            .to_vec();
        let exp_points = &[[0, 0], [0, 0], [0, 0], [0, 0], [2, 0], [0, 2], [2, 2]].map(Into::into);

        let n = points.len();

        let lo = 3;
        let (mid, mut split_queue) = sort_layer_wrapper(&mut points[lo..], lo);

        assert_eq!(points, exp_points);
        assert_eq!(split_queue, [3, 4, 5, 6]);

        let starts = split_queue.make_contiguous();
        let Ok([nw_start, ne_start, sw_start, se_start]) =
            TryInto::<[usize; 4]>::try_into(starts)
        else {
            unreachable!()
        };

        assert_eq!(mid, (1, 1).into());

        // Each quadrant runs from its own start to the start of the next one.
        let nw = &points[nw_start..ne_start];
        let ne = &points[ne_start..sw_start];
        let sw = &points[sw_start..se_start];
        let se = &points[se_start..n];

        for p in nw {
            assert!(p.0[0] < mid.0[0], "{p:?} not in NW");
            assert!(p.0[1] < mid.0[1], "{p:?} not in NW");
        }

        for p in ne {
            assert!(p.0[0] >= mid.0[0], "{p:?} not in NE");
            assert!(p.0[1] < mid.0[1], "{p:?} not in NE");
        }

        for p in sw {
            assert!(p.0[0] < mid.0[0], "{p:?} not in SW");
            assert!(p.0[1] >= mid.0[1], "{p:?} not in SW");
        }

        for p in se {
            assert!(p.0[0] >= mid.0[0], "{p:?} not in SE");
            assert!(p.0[1] >= mid.0[1], "{p:?} not in SE");
        }
    }
}
307
/// Whole-tree tests with a single subdivision level.
#[cfg(test)]
mod test_tree_d1 {
    use crate::Tree;
    use crate::point::Point;

    const DEPTH: u32 = 1;

    /// Input that is already in orthant order stays as it is.
    #[test]
    fn ordered_2d() {
        let input = [[0, 0], [2, 0], [0, 2], [2, 2]].map(Into::into).to_vec();
        let exp_points = &[[0, 0], [2, 0], [0, 2], [2, 2]].map(Into::into);

        let tree = Tree::new(input, DEPTH);

        assert_eq!(tree.points, exp_points);
        // Root start, followed by the four quadrant starts.
        assert_eq!(tree.splits, [0, 0, 1, 2, 3]);
    }

    /// Shuffled input ends up in orthant order.
    #[test]
    fn unordered_2d() {
        let input = [[0, 2], [2, 2], [2, 0], [0, 0]].map(Into::into).to_vec();
        let exp_points = &[[0, 0], [2, 0], [0, 2], [2, 2]].map(Into::into);

        let tree = Tree::new(input, DEPTH);

        assert_eq!(tree.points, exp_points);
        // Root start, followed by the four quadrant starts.
        assert_eq!(tree.splits, [0, 0, 1, 2, 3]);
    }

    /// Two opposite corners in 3-D: only the first and last octant are used.
    #[test]
    fn simple_3d() {
        let input: Vec<Point<i32, 3>> = [(0, 0, 0), (2, 2, 2)].map(Into::into).to_vec();
        let exp_points = [(0, 0, 0), (2, 2, 2)].map(Into::into);

        let tree = Tree::new(input, DEPTH);

        assert_eq!(tree.points, exp_points);
        // Root start, followed by the eight octant starts.
        assert_eq!(tree.splits, [0, 0, 1, 1, 1, 1, 1, 1, 1]);
    }
}
351
/// Whole-tree tests with two subdivision levels.
#[cfg(test)]
mod test_tree_d2 {
    use crate::Tree;

    const DEPTH: u32 = 2;

    /// A shuffled 4x4 grid. The expected points are laid out one first-level
    /// quadrant per row.
    #[test]
    #[rustfmt::skip]
    fn unordered_2d() {
        let input = [
            (3, 0), (3, 3), (2, 3), (0, 0),
            (0, 3), (0, 2), (0, 1), (1, 2),
            (1, 3), (2, 1), (3, 1), (2, 0),
            (1, 0), (1, 1), (3, 2), (2, 2),
        ].map(Into::into).to_vec();

        let exp_points = [
            (0, 0),
            (1, 0), (3, 0), (2, 0),
            (0, 1), (0, 3), (0, 2),
            (1, 1), (2, 1), (3, 1), (1, 2), (1, 3), (3, 3), (2, 3), (3, 2), (2, 2),
        ].map(Into::into);

        let tree = Tree::new(input, DEPTH);

        assert_eq!(tree.points, exp_points);
        assert_eq!(
            tree.splits,
            [
                // level 0
                0,
                // level 1
                0, 1, 4, 7,
                // level 2
                0, 0, 0, 0, 1, 1, 1, 2, 4, 4, 5, 5, 7, 8, 10, 12
            ]
        );
    }

    /// A sparse, mostly pre-sorted point set spread over a larger area.
    #[test]
    #[rustfmt::skip]
    fn sorted_larger_2d() {
        let input = [
            (0, 0), (2, 0), (0, 2), (2, 2),
            (4, 0),
            (0, 4), (2, 4), (0, 6), (2, 6),
            (4, 4), (6, 4), (4, 6), (6, 6),
        ].map(Into::into).to_vec();

        let exp_points = &[
            (0, 0), (2, 0), (0, 2), (2, 2),
            (4, 0),
            (0, 4), (0, 6), (2, 4), (2, 6),
            (4, 4), (6, 4), (4, 6), (6, 6),
        ].map(Into::into);

        let tree: Tree<i32, 2> = Tree::new(input, DEPTH);

        assert_eq!(tree.points, exp_points);
        assert_eq!(
            tree.splits,
            [
                // level 0
                0,
                // level 1
                0, 4, 5, 9,
                // level 2
                0, 1, 2, 3, 4, 4, 5, 5, 5, 5, 5, 7, 9, 9, 9, 9
            ]
        );
    }
}
439
/// Whole-tree test with three subdivision levels.
#[cfg(test)]
mod test_tree_d3 {
    use std::ops::Range;

    use crate::Tree;
    use crate::point::Point;

    const DEPTH: u32 = 3;

    /// Returns every `(x, y)` with `x` in `xs` and `y` in `ys`, with `x` as
    /// the outer (slower-changing) coordinate.
    fn range2(xs: Range<i32>, ys: Range<i32>) -> Vec<Point<i32, 2>> {
        let mut grid = Vec::new();

        for x in xs {
            for y in ys.clone() {
                grid.push([x, y].into());
            }
        }

        grid
    }

    /// A full 5x5 grid. The expected points are laid out one first-level
    /// quadrant per row.
    #[test]
    #[rustfmt::skip]
    fn unordered_large() {
        let input = range2(0..5, 0..5);

        let exp_points = &[
            (0, 0), (1, 0), (0, 1), (1, 1),
            (2, 0), (3, 0), (4, 0), (2, 1), (3, 1), (4, 1),
            (0, 2), (1, 2), (0, 3), (0, 4), (1, 3), (1, 4),
            (2, 2), (3, 2), (4, 2), (2, 3), (2, 4), (3, 3), (3, 4), (4, 3), (4, 4),
        ].map(Into::into);

        let tree = Tree::new(input, DEPTH);

        assert_eq!(tree.points, exp_points);
        assert_eq!(
            tree.splits,
            [
                // level 0
                0,
                // level 1
                0, 4, 10, 16,
                // level 2
                0, 1, 2, 3, 4, 5, 7, 8, 10, 11, 12, 14, 16, 17, 19, 21,
                // level 3
                0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3,
                4, 4, 4, 4, 5, 5, 5, 5, 7, 7, 7, 7, 8, 8, 8, 8,
                10, 10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 14, 14, 14, 14,
                16, 16, 16, 16, 17, 17, 17, 17, 19, 19, 19, 19, 21, 21, 21, 21
            ]
        );
    }
}
495
/// Table-driven tests for `Tree::compute_splits`.
#[cfg(test)]
mod test_splits {
    use crate::Tree;

    /// Each case pairs orthant keys with the expected start offset of every
    /// one of the four quadrants.
    #[test]
    #[rustfmt::skip]
    fn test_splits_2d() {
        let cases: &[(&[usize], &[usize])] = &[
            (&[0],          &[0, 1, 1, 1]),
            (&[0, 0],       &[0, 2, 2, 2]),
            (&[1, 2],       &[0, 0, 1, 2]),
            (&[0, 1, 2, 3], &[0, 1, 2, 3]),
            (&[0, 1, 2, 2], &[0, 1, 2, 4]),
            (&[0, 1, 1, 3], &[0, 1, 3, 3]),
            (&[0, 0, 0, 3], &[0, 3, 3, 3]),
        ];

        for &(keys, expected) in cases {
            // A fresh, zeroed table for every case.
            let mut splits = [0; 4];

            Tree::<i8, 2>::compute_splits(keys, &mut splits);

            assert_eq!(splits, expected);
        }
    }

    /// Same as above for the eight octants of the 3-D case.
    #[test]
    #[rustfmt::skip]
    fn test_splits_3d() {
        let cases: &[(&[usize], &[usize])] = &[
            (&[0, 7], &[0, 1, 1, 1, 1, 1, 1, 1]),
        ];

        for &(keys, expected) in cases {
            // A fresh, zeroed table for every case.
            let mut splits = [0; 8];

            Tree::<i8, 3>::compute_splits(keys, &mut splits);

            assert_eq!(splits, expected);
        }
    }
}
542
/// Property tests for `Tree::sort_layer` on random 2-D `i8` point sets.
#[cfg(test)]
mod proptests {
    use std::collections::VecDeque;

    use proptest::prelude::*;

    use crate::Tree;
    use crate::bound::Bound;
    use crate::point::Point;
    use crate::util;

    type PointType = i8;
    const N: usize = 2;

    /// Exclusive upper bound on the number of points generated per case.
    const MAX_POINTS: usize = 20;

    /// Asserts that `p` lies in orthant `orth` relative to `mid`.
    ///
    /// Bit `i` of `orth` selects the side on axis `i`: set means
    /// `p >= mid` on that axis, clear means `p < mid`.
    fn assert_point_in_orthant<T: Ord + std::fmt::Debug, const N: usize>(
        p: &Point<T, N>,
        mid: &Point<T, N>,
        mut orth: usize,
    ) {
        for i in 0..N {
            if orth & 1 == 1 {
                assert!(
                    p.0[i] >= mid.0[i],
                    "point {p:?} is < {mid:?} midpoint (index {i})",
                );
            } else {
                assert!(
                    p.0[i] < mid.0[i],
                    "point {p:?} is >= {mid:?} midpoint (index {i})",
                );
            }

            orth >>= 1;
        }
    }

    /// Runs one `sort_layer` pass over `points` around the centre of their
    /// bound and returns that centre with the region-relative split offsets.
    fn sort_one_layer(points: &mut [Point<PointType, N>]) -> (Point<PointType, N>, Vec<usize>) {
        let n = points.len();
        let mut keys = vec![0usize; n];
        let mut swaps = Vec::with_capacity(n);
        let mut splits = vec![0; util::num_divs::<N>()];

        let Some(bound) = Bound::from_points(points) else {
            unreachable!("We always have at least one point")
        };
        let mid = bound.center();

        Tree::sort_layer(mid.clone(), points, &mut keys, &mut swaps, &mut splits);

        (mid, splits)
    }

    proptest! {
        // One pass always yields exactly one split per orthant, whatever
        // offset the splits are shifted by.
        #[test]
        fn test_sort_layer_num_splits(
            // Every split is at most the number of points, so capping `lo`
            // here keeps `s + lo` below from overflowing `usize`.
            lo in 0..=(usize::MAX - MAX_POINTS),
            points in prop::collection::vec(
                prop::array::uniform(any::<PointType>()),
                1..MAX_POINTS
            )
        ) {
            let mut points: Vec<Point<PointType, N>> = points.into_iter()
                .map(Point::from)
                .collect();

            let (_mid, splits) = sort_one_layer(&mut points);

            let mut split_queue = VecDeque::with_capacity(points.len());
            split_queue.extend(splits.iter().copied().map(|s| s + lo));

            assert_eq!(split_queue.len(), util::num_divs::<N>(), "Expected {} splits, found {}", util::num_divs::<N>(), split_queue.len());
        }

        // After one pass, every point between two consecutive splits lies in
        // the orthant with the matching index.
        #[test]
        fn test_sort_layer_sorted(
            points in prop::collection::vec(
                prop::array::uniform(any::<PointType>()),
                1..MAX_POINTS
            )
        ) {
            let mut points: Vec<Point<PointType, N>> = points.into_iter()
                .map(Point::from)
                .collect();

            let n = points.len();
            let (mid, splits) = sort_one_layer(&mut points);
            let split_queue: VecDeque<usize> = splits.into_iter().collect();

            for (i, &lo) in split_queue.iter().enumerate() {
                // An orthant ends where the next one starts; the last one
                // runs to the end of the slice.
                let hi = split_queue.get(i + 1).copied().unwrap_or(n);

                let orth = &points[lo..hi];
                for p in orth {
                    assert_point_in_orthant(p, &mid, i);
                }
            }
        }
    }
}