hyper_tree/
lib.rs

1#![deny(unsafe_code)]
2
3mod bound;
4mod point;
5mod traits;
6mod util;
7
8use core::cmp::Ord;
9use core::marker::Copy;
10use core::ops::Sub;
11use std::collections::VecDeque;
12
13use bound::Bound;
14
15use crate::point::Point;
16use crate::traits::Epsilon;
17use crate::traits::Mean;
18
19pub type QuadTree<T> = Tree<T, 2>;
20pub type OcTree<T> = Tree<T, 3>;
21
22pub struct Tree<T, const N: usize> {
23    /// The points stored by our tree. These are sorted into 2^N regions by their respective
24    /// orthants, and each subregion is sorted in the same way, etc `depth` times.
25    points: Vec<Point<T, N>>,
26
27    /// The splits of our tree allow us to index into `points` to recover a particular orthant.
28    ///
29    /// They are stored in level-order in groups of size 2^N and represent the starting index in
30    /// `points` for that orthant.
31    splits: Vec<usize>,
32
33    /// The max depth of our tree.
34    depth: u32,
35}
36
37impl<T, const N: usize> Tree<T, N> {
38    fn uninit(points: Vec<Point<T, N>>, depth: u32) -> Self
39    where
40        T: Ord + Copy,
41    {
42        // * On depth `d`, we get `2^N^d` splits
43        // * 2^N^0 + 2^N^1 + ... + 2^N^d = 2^N^(d + 1) / (2^N - 1)
44        // * approximate with 2^N^(d + 1)
45        let num_splits = util::num_divs::<N>().pow(depth + 1);
46        let splits = Vec::with_capacity(num_splits);
47
48        Self {
49            points,
50            splits,
51            depth,
52        }
53    }
54
55    pub fn new(points: Vec<Point<T, N>>, depth: u32) -> Self
56    where
57        T: Mean + Epsilon + Sub<Output = T> + Ord,
58    {
59        let mut tree = Self::uninit(points, depth);
60
61        tree.build();
62
63        tree
64    }
65
66    fn build(&mut self)
67    where
68        T: Ord + Mean + Epsilon + Sub<Output = T>,
69    {
70        let n = self.points.len();
71
72        let Some(bound) = Bound::from_points(&self.points) else {
73            return;
74        };
75
76        let mut keys = vec![0; n];
77        let mut split_queue = VecDeque::with_capacity(n);
78        let mut bound_queue = VecDeque::with_capacity(n);
79
80        // Buf of indices used for sorting points
81        let mut swaps = Vec::with_capacity(n);
82
83        // Small buffer to hold the splits for the current layer
84        let mut splits = vec![0; util::num_divs::<N>()];
85
86        split_queue.push_back(0);
87        bound_queue.push_back(bound);
88
89        for d in 0..self.depth {
90            // A whole layer of the tree contains 2^N^d regions
91            let regions = util::num_divs::<N>().pow(d);
92
93            for _ in 0..regions {
94                let Some(lo) = split_queue.pop_front() else {
95                    unreachable!()
96                };
97
98                let Some(bound) = bound_queue.pop_front() else {
99                    unreachable!()
100                };
101
102                self.splits.push(lo);
103
104                // Splits are only ever not monotone when we are
105                // about to process a new layer of the tree
106                let hi = split_queue.front().copied().unwrap_or(n);
107                let hi = if hi < lo { n } else { hi };
108
109                let points = &mut self.points[lo..hi];
110                let keys = &mut keys[lo..hi];
111
112                let mid = bound.center();
113                Self::sort_layer(mid, points, keys, &mut swaps, &mut splits);
114
115                split_queue.extend(splits.iter().copied().map(|s| s + lo));
116                splits.fill(0);
117
118                // If the bound is minimal we can't do this
119                let Some(bounds) = bound.split() else {
120                    continue;
121                };
122
123                bound_queue.extend(bounds);
124            }
125        }
126
127        self.splits.extend(split_queue);
128    }
129
130    fn sort_layer(
131        mid: Point<T, N>,
132        points: &mut [Point<T, N>],
133        keys: &mut [usize],
134        swaps: &mut Vec<usize>,
135        splits: &mut [usize],
136    ) where
137        T: Ord,
138    {
139        debug_assert_eq!(points.len(), keys.len());
140        let n = points.len();
141
142        // For each dimensional axis
143        for i in 0..N {
144            // For each indexed point
145            for (j, p) in points.iter().enumerate() {
146                if p.0[i] >= mid.0[i] {
147                    keys[j] |= 1 << i;
148                }
149            }
150        }
151
152        // Sort by keys order
153        swaps.extend(0..n);
154        util::argsort(keys, swaps);
155        util::sort_by_argsort(points, swaps);
156
157        Self::compute_splits(keys, splits);
158
159        keys.fill(0);
160        swaps.clear();
161    }
162
163    fn compute_splits(keys: &[usize], splits: &mut [usize]) {
164        for &k in keys {
165            splits[k] += 1; // 0 <= k < 2^N
166        }
167
168        // Accumulate the list
169        for i in 1..util::num_divs::<N>() {
170            splits[i] += splits[i - 1];
171        }
172
173        splits.rotate_right(1);
174        splits[0] = 0;
175    }
176}
177
#[cfg(test)]
mod test_sort_layer {
    use std::collections::VecDeque;
    use std::fmt::Debug;

    use crate::Tree;
    use crate::bound::Bound;
    use crate::point::Point;
    use crate::traits::Mean;
    use crate::util;

    /// Runs `Tree::sort_layer` on `points` around the centre of their bounding box.
    ///
    /// Returns that centre and the orthant starts shifted by `lo`, so the offsets index the
    /// buffer `points` was sliced from.
    fn sort_layer_wrapper<T: Copy + Ord + Mean + Debug, const N: usize>(
        points: &mut [Point<T, N>],
        lo: usize,
    ) -> (Point<T, N>, VecDeque<usize>) {
        let bound = Bound::from_points(points)
            .unwrap_or_else(|| panic!("Provide at least one point"));

        let len = points.len();
        let centre = bound.center();
        let mut keys = vec![0; len];
        let mut swaps = Vec::with_capacity(len);
        let mut starts = vec![0; util::num_divs::<N>()];

        Tree::sort_layer(centre.clone(), points, &mut keys, &mut swaps, &mut starts);

        let shifted: VecDeque<usize> = starts.into_iter().map(|s| s + lo).collect();
        (centre, shifted)
    }

    /// Checks that each quadrant range `starts[k]..` (the last running to the end of `points`)
    /// only holds points on the expected sides of `mid`.
    fn assert_quadrants(points: &[Point<i32, 2>], mid: &Point<i32, 2>, starts: [usize; 4]) {
        let [a, b, c, d] = starts;
        let end = points.len();

        points[a..b].iter().for_each(|p| {
            assert!(p.0[0] < mid.0[0], "{p:?} not in NW");
            assert!(p.0[1] < mid.0[1], "{p:?} not in NW");
        });

        points[b..c].iter().for_each(|p| {
            assert!(p.0[0] >= mid.0[0], "{p:?} not in NE");
            assert!(p.0[1] < mid.0[1], "{p:?} not in NE");
        });

        points[c..d].iter().for_each(|p| {
            assert!(p.0[0] < mid.0[0], "{p:?} not in SW");
            assert!(p.0[1] >= mid.0[1], "{p:?} not in SW");
        });

        points[d..end].iter().for_each(|p| {
            assert!(p.0[0] >= mid.0[0], "{p:?} not in SE");
            assert!(p.0[1] >= mid.0[1], "{p:?} not in SE");
        });
    }

    #[test]
    fn no_offset() {
        let mut points: Vec<Point<i32, 2>> =
            [[0, 2], [2, 2], [2, 0], [0, 0]].map(Into::into).to_vec();
        let expected: &[Point<i32, 2>; 4] = &[[0, 0], [2, 0], [0, 2], [2, 2]].map(Into::into);

        let (mid, starts) = sort_layer_wrapper(&mut points, 0);

        assert_eq!(points, expected);
        assert_eq!(starts, [0, 1, 2, 3]);
        assert_eq!(mid, (1, 1).into());

        assert_quadrants(&points, &mid, std::array::from_fn(|i| starts[i]));
    }

    #[test]
    fn with_offset() {
        let mut points: Vec<Point<i32, 2>> =
            [[0, 0], [0, 0], [0, 0], [0, 2], [2, 2], [2, 0], [0, 0]]
                .map(Into::into)
                .to_vec();
        let expected: &[Point<i32, 2>; 7] =
            &[[0, 0], [0, 0], [0, 0], [0, 0], [2, 0], [0, 2], [2, 2]].map(Into::into);

        // Only the tail starting at `offset` is sorted; the prefix must stay untouched.
        let offset = 3;
        let (mid, starts) = sort_layer_wrapper(&mut points[offset..], offset);

        assert_eq!(points, expected);
        assert_eq!(starts, [3, 4, 5, 6]);
        assert_eq!(mid, (1, 1).into());

        assert_quadrants(&points, &mid, std::array::from_fn(|i| starts[i]));
    }
}
307
#[cfg(test)]
mod test_tree_d1 {
    use crate::Tree;
    use crate::point::Point;

    /// Every tree in this module is subdivided exactly once.
    const DEPTH: u32 = 1;

    /// Builds 2-D points from `[x, y]` pairs.
    fn pts2(raw: [[i32; 2]; 4]) -> Vec<Point<i32, 2>> {
        raw.into_iter().map(|xy| xy.into()).collect()
    }

    #[test]
    fn ordered_2d() {
        let tree = Tree::new(pts2([[0, 0], [2, 0], [0, 2], [2, 2]]), DEPTH);

        assert_eq!(tree.points, pts2([[0, 0], [2, 0], [0, 2], [2, 2]]));
        assert_eq!(tree.splits, [0, 0, 1, 2, 3]);
    }

    #[test]
    fn unordered_2d() {
        let tree = Tree::new(pts2([[0, 2], [2, 2], [2, 0], [0, 0]]), DEPTH);

        assert_eq!(tree.points, pts2([[0, 0], [2, 0], [0, 2], [2, 2]]));
        assert_eq!(tree.splits, [0, 0, 1, 2, 3]);
    }

    #[test]
    fn simple_3d() {
        let input: Vec<Point<i32, 3>> = vec![(0, 0, 0).into(), (2, 2, 2).into()];
        let tree = Tree::new(input, DEPTH);

        let expected: [Point<i32, 3>; 2] = [(0, 0, 0).into(), (2, 2, 2).into()];
        assert_eq!(tree.points, expected);
        assert_eq!(tree.splits, [0, 0, 1, 1, 1, 1, 1, 1, 1]);
    }
}
351
#[cfg(test)]
mod test_tree_d2 {
    use crate::Tree;
    use crate::point::Point;

    const DEPTH: u32 = 2;

    /// Builds 2-D points from `(x, y)` pairs.
    fn pts(raw: &[(i32, i32)]) -> Vec<Point<i32, 2>> {
        raw.iter().map(|&xy| xy.into()).collect()
    }

    #[test]
    #[rustfmt::skip]
    fn unordered_2d() {
        let input = pts(&[
            (3, 0), (3, 3), (2, 3), (0, 0),
            (0, 3), (0, 2), (0, 1), (1, 2),
            (1, 3), (2, 1), (3, 1), (2, 0),
            (1, 0), (1, 1), (3, 2), (2, 2),
        ]);

        let tree = Tree::new(input, DEPTH);

        let expected_points = pts(&[
            // mid: (1, 1)

            (0, 0),

            (1, 0),    (3, 0), (2, 0),                                                      // mid: (2, 0)

            (0, 1),                       (0, 3), (0, 2),                                   // mid: (0, 2)

            (1, 1),    (2, 1), (3, 1),    (1, 2), (1, 3),    (3, 3), (2, 3), (3, 2), (2, 2) // mid: (2, 2)
        ]);

        let levels: [&[usize]; 3] = [
            // level 0: the whole point set
            &[0],
            // level 1
            &[0, 1, 4, 7],
            // level 2
            &[0, 0, 0, 0,    1, 1, 1, 2,    4, 4, 5, 5,    7, 8, 10, 12],
        ];

        assert_eq!(tree.points, expected_points);
        assert_eq!(tree.splits, levels.concat());
    }

    #[test]
    #[rustfmt::skip]
    fn sorted_larger_2d() {
        let input = pts(&[
            (0, 0), (2, 0), (0, 2), (2, 2),

            (4, 0),

            (0, 4), (2, 4), (0, 6), (2, 6),

            (4, 4), (6, 4), (4, 6), (6, 6),
        ]);

        let tree: Tree<i32, 2> = Tree::new(input, DEPTH);

        let expected_points = pts(&[
            (0, 0), (2, 0), (0, 2), (2, 2),

            (4, 0),

            (0, 4), (0, 6), (2, 4), (2, 6),

            (4, 4), (6, 4), (4, 6), (6, 6),
        ]);

        let levels: [&[usize]; 3] = [
            // level 0: the whole point set
            &[0],
            // level 1
            &[0, 4, 5, 9],
            // level 2
            &[0, 1, 2, 3,    4, 4, 5, 5,    5, 5, 5, 7,    9, 9, 9, 9],
        ];

        assert_eq!(tree.points, expected_points);
        assert_eq!(tree.splits, levels.concat());
    }
}
439
#[cfg(test)]
mod test_tree_d3 {
    use std::ops::Range;

    use crate::Tree;
    use crate::point::Point;

    const DEPTH: u32 = 3;

    /// All grid points `[x, y]` with `x` in `xs` and `y` in `ys`, `x`-major.
    fn range2(xs: Range<i32>, ys: Range<i32>) -> Vec<Point<i32, 2>> {
        let mut grid = Vec::new();
        for x in xs {
            for y in ys.clone() {
                grid.push([x, y].into());
            }
        }
        grid
    }

    /// Builds 2-D points from `(x, y)` pairs.
    fn pts(raw: &[(i32, i32)]) -> Vec<Point<i32, 2>> {
        raw.iter().map(|&xy| xy.into()).collect()
    }

    #[test]
    #[rustfmt::skip]
    fn unordered_large() {
        let tree = Tree::new(range2(0..5, 0..5), DEPTH);

        let expected_points = pts(&[
            // mid: (2, 2)

            (0, 0),            (1, 0),            (0, 1),            (1, 1),                        // mid: (1, 1)

            (2, 0), (3, 0),    (4, 0),            (2, 1),            (3, 1), (4, 1),                // mid: (3, 1)

            (0, 2),            (1, 2),            (0, 3), (0, 4),    (1, 3), (1, 4),                // mid: (1, 3)

            (2, 2),            (3, 2), (4, 2),    (2, 3), (2, 4),    (3, 3), (3, 4), (4, 3), (4, 4) // mid: (3, 3)
        ]);

        // Starts of the 16 level-2 regions.
        let level2: [usize; 16] = [0, 1, 2, 3,    4, 5, 7, 8,    10, 11, 12, 14,    16, 17, 19, 21];

        // Level 0 (whole set), then level 1.
        let mut expected_splits: Vec<usize> = vec![0,    0, 4, 10, 16];
        expected_splits.extend(level2);
        // On level 3 every level-2 region places all of its points in its last child, so
        // each group of four child starts repeats the region's own start.
        expected_splits.extend(level2.iter().flat_map(|&start| [start; 4]));

        assert_eq!(tree.points, expected_points);
        assert_eq!(tree.splits, expected_splits);
    }
}
495
#[cfg(test)]
mod test_splits {
    use crate::Tree;

    #[test]
    #[rustfmt::skip]
    fn test_splits_2d() {
        // (orthant keys, expected orthant starts)
        let cases: [(&[usize], [usize; 4]); 7] = [
            (&[0],          [0, 1, 1, 1]),
            (&[0, 0],       [0, 2, 2, 2]),
            (&[1, 2],       [0, 0, 1, 2]),
            (&[0, 1, 2, 3], [0, 1, 2, 3]),
            (&[0, 1, 2, 2], [0, 1, 2, 4]),
            (&[0, 1, 1, 3], [0, 1, 3, 3]),
            (&[0, 0, 0, 3], [0, 3, 3, 3]),
        ];

        for (keys, expected) in cases {
            let mut starts = [0; 4];
            Tree::<i8, 2>::compute_splits(keys, &mut starts);
            assert_eq!(starts, expected);
        }
    }

    #[test]
    #[rustfmt::skip]
    fn test_splits_3d() {
        // (orthant keys, expected orthant starts)
        let cases: [(&[usize], [usize; 8]); 1] = [
            (&[0, 7],       [0, 1, 1, 1, 1, 1, 1, 1]),
        ];

        for (keys, expected) in cases {
            let mut starts = [0; 8];
            Tree::<i8, 3>::compute_splits(keys, &mut starts);
            assert_eq!(starts, expected);
        }
    }
}
542
#[cfg(test)]
mod proptests {
    use std::collections::VecDeque;

    use proptest::prelude::*;

    use crate::Tree;
    use crate::bound::Bound;
    use crate::point::Point;
    use crate::util;

    type PointType = i8;
    const N: usize = 2;

    /// Asserts that `p` lies in orthant `orth` around `mid`.
    ///
    /// Bit `i` of `orth` selects the side along axis `i`: set means `p.0[i] >= mid.0[i]`,
    /// clear means `p.0[i] < mid.0[i]`.
    fn assert_point_in_orthant<T: Ord + std::fmt::Debug, const N: usize>(
        p: &Point<T, N>,
        mid: &Point<T, N>,
        mut orth: usize,
    ) {
        for i in 0..N {
            if orth & 1 == 1 {
                assert!(
                    p.0[i] >= mid.0[i],
                    "point {p:?} is < {mid:?} midpoint (index {i})",
                );
            } else {
                assert!(
                    p.0[i] < mid.0[i],
                    "point {p:?} is >= {mid:?} midpoint (index {i})",
                );
            }

            orth >>= 1;
        }
    }

    proptest! {
        // `sort_layer` always reports one start per orthant, wherever the region sits.
        #[test]
        fn test_sort_layer_num_splits(
            // Bounded so that shifting a split (at most the point count, < 20) by `lo`
            // cannot overflow `usize`, which would panic in debug builds.
            lo in 0..usize::MAX / 2,
            points in prop::collection::vec(
                prop::array::uniform(any::<PointType>()),
                1..20
            )
        ) {
            let mut points: Vec<Point<PointType, N>> = points.into_iter()
                .map(Point::from)
                .collect();

            let n = points.len();
            let mut keys = vec![0usize; n];
            let mut split_queue = VecDeque::with_capacity(n);

            let mut swaps = Vec::with_capacity(n);
            let mut splits = vec![0; util::num_divs::<N>()];

            let Some(bound) = Bound::from_points(&points) else {
                unreachable!("We always have at least one point")
            };
            let mid = bound.center();

            Tree::sort_layer(mid.clone(), &mut points, &mut keys, &mut swaps, &mut splits);
            split_queue.extend(splits.iter().copied().map(|s| s + lo));

            // We should always have 2^N splits
            assert_eq!(split_queue.len(), util::num_divs::<N>(), "Expected {} splits, found {}", util::num_divs::<N>(), split_queue.len());
        }

        // After `sort_layer`, every point lies in the orthant its range claims.
        #[test]
        fn test_sort_layer_sorted(
            points in prop::collection::vec(
                prop::array::uniform(any::<PointType>()),
                1..20
            )
        ) {
            let lo = 0;
            let mut points: Vec<Point<PointType, N>> = points.into_iter()
                .map(Point::from)
                .collect();

            let n = points.len();
            let mut keys = vec![0usize; n];
            let mut split_queue = VecDeque::with_capacity(n);

            let mut swaps = Vec::with_capacity(n);
            let mut splits = vec![0; util::num_divs::<N>()];

            let Some(bound) = Bound::from_points(&points) else {
                unreachable!("We always have at least one point")
            };
            let mid = bound.center();

            Tree::sort_layer(mid.clone(), &mut points, &mut keys, &mut swaps, &mut splits);
            split_queue.extend(splits.iter().copied().map(|s| s + lo));

            // For every orthant
            for (i, &lo) in split_queue.iter().enumerate() {
                let hi = split_queue.get(i + 1).copied().unwrap_or(n);

                // Get all the points in that orthant
                let orth = &points[lo..hi];
                for p in orth {
                    assert_point_in_orthant(p, &mid, i);
                }
            }
        }
    }
}