capt/
lib.rs

1//! # Collision-Affording Point Trees: SIMD-Amenable Nearest Neighbors for Fast Collision Checking
2//!
3//! This is a Rust implementation of the _collision-affording point tree_ (CAPT), a data structure
4//! for SIMD-parallel collision-checking between spheres and point clouds.
5//!
6//! You may also want to look at the following other sources:
7//!
8//! - [The paper](https://arxiv.org/abs/2406.02807)
9//! - [C++ implementation](https://github.com/KavrakiLab/vamp)
10//! - [Blog post about it](https://www.claytonwramsey.com/blog/captree)
11//! - [Demo video](https://youtu.be/BzDKdrU1VpM)
12//!
13//! If you use this in an academic work, please cite it as follows:
14//!
15//! ```bibtex
16//! @InProceedings{capt,
17//!   title = {Collision-Affording Point Trees: {SIMD}-Amenable Nearest Neighbors for Fast Collision Checking},
18//!   author = {Ramsey, Clayton W. and Kingston, Zachary and Thomason, Wil and Kavraki, Lydia E.},
19//!   booktitle = {Robotics: Science and Systems},
20//!   date = {2024},
21//!   url = {http://arxiv.org/abs/2406.02807},
22//!   note = {To Appear.}
23//! }
24//! ```
25//!
26//! ## Usage
27//!
28//! The core data structure in this library is the [`Capt`], which is a search tree used for
29//! collision checking. [`Capt`]s are polymorphic over dimension and data type. On construction,
30//! they take in a list of points in a point cloud and a _radius range_: a tuple of the minimum and
31//! maximum radius used for querying.
32//!
33//! ```rust
34//! use capt::Capt;
35//!
36//! // list of points in cloud
37//! let points = [[0.0, 1.1], [0.2, 3.1]];
38//! let r_min = 0.05;
39//! let r_max = 2.0;
40//!
41//! let capt = Capt::<2>::new(&points, (r_min, r_max));
42//! ```
43//!
44//! Once you have a `Capt`, you can use it for collision-checking against spheres.
45//! Correct answers are only guaranteed if you collision-check against spheres with a radius inside
46//! the radius range.
47//!
48//! ```rust
49//! # use capt::Capt;
50//! # let points = [[0.0, 1.1], [0.2, 3.1]];
51//! # let capt = Capt::<2>::new(&points, (0.05, 2.0));
52//! let center = [0.0, 0.0]; // center of sphere
53//! let radius0 = 1.0; // radius of sphere
54//! assert!(!capt.collides(&center, radius0));
55//!
56//! let radius1 = 1.5;
57//! assert!(capt.collides(&center, radius1));
58//! ```
59//!
60//! ## Optional features
61//!
//! This crate exposes one feature, `simd`, which enables a SIMD-parallel interface for querying
//! `Capt`s. The `simd` feature requires nightly Rust and therefore should be considered unstable.
//! Enabling it provides the function `Capt::collides_simd`, a parallel collision checker for
//! batches of search queries.
66//!
67//! ## License
68//!
69//! This work is licensed to you under the Apache 2.0 license.
70#![cfg_attr(feature = "simd", feature(portable_simd))]
71#![cfg_attr(not(test), no_std)]
72#![warn(clippy::pedantic, clippy::cargo, clippy::nursery, missing_docs)]
73
74extern crate alloc;
75use alloc::{boxed::Box, vec, vec::Vec};
76
77use core::{
78    array,
79    fmt::Debug,
80    mem::size_of,
81    ops::{Add, Sub},
82};
83
84#[cfg(feature = "simd")]
85use core::{
86    ops::{AddAssign, Mul, SubAssign},
87    ptr,
88    simd::{
89        LaneCount, Mask, Simd, SimdElement, SupportedLaneCount,
90        cmp::{SimdPartialEq, SimdPartialOrd},
91        ptr::SimdConstPtr,
92    },
93};
94
95use elain::{Align, Alignment};
96
97/// A generic trait representing values which may be used as an "axis;" that is, elements of a
98/// vector representing a point.
99///
100/// An array of `Axis` values is a point which can be stored in a [`Capt`].
101/// Accordingly, this trait specifies nearly all the requirements for points that [`Capt`]s require.
/// The only exception is that [`Axis`] values really ought to be [`Ord`] instead of [`PartialOrd`];
/// however, due to the disaster that is IEEE 754 floating point numbers, `f32` and `f64` are not
/// totally ordered. As a compromise, we relax the `Ord` requirement so that you can use floats in a
/// `Capt`.
106///
107/// # Examples
108///
109/// ```
110/// #[derive(Clone, Copy, PartialOrd, PartialEq)]
111/// enum HyperInt {
112///     MinusInf,
113///     Real(i32),
114///     PlusInf,
115/// }
116///
117/// impl std::ops::Add for HyperInt {
118/// // ...
119/// #    type Output = Self;
120/// #
121/// #    fn add(self, rhs: Self) -> Self {
122/// #        match (self, rhs) {
123/// #            (Self::MinusInf, Self::PlusInf) => Self::Real(0), // evil, but who cares?
124/// #            (Self::MinusInf, _) | (_, Self::MinusInf) => Self::MinusInf,
125/// #            (Self::PlusInf, _) | (_, Self::PlusInf) => Self::PlusInf,
126/// #            (Self::Real(x), Self::Real(y)) => Self::Real(x + y),
127/// #        }
128/// #    }
129/// }
130///
131///
132/// impl std::ops::Sub for HyperInt {
133/// // ...
134/// #    type Output = Self;
135/// #
136/// #    fn sub(self, rhs: Self) -> Self {
137/// #        match (self, rhs) {
138/// #            (Self::MinusInf, Self::MinusInf) | (Self::PlusInf, Self::PlusInf) => Self::Real(0), // evil, but who cares?
139/// #            (Self::MinusInf, _) | (_, Self::PlusInf) => Self::MinusInf,
140/// #            (Self::PlusInf, _) | (_, Self::MinusInf) => Self::PlusInf,
141/// #            (Self::Real(x), Self::Real(y)) => Self::Real(x - y),
142/// #        }
143/// #    }
144/// }
145///
146/// impl capt::Axis for HyperInt {
147///     const ZERO: Self = Self::Real(0);
148///     const INFINITY: Self = Self::PlusInf;
149///     const NEG_INFINITY: Self = Self::MinusInf;
150///
151///     fn is_finite(self) -> bool {
152///         matches!(self, Self::Real(_))
153///     }
154///
155///     fn in_between(self, rhs: Self) -> Self {
156///         match (self, rhs) {
157///             (Self::PlusInf, Self::MinusInf) | (Self::MinusInf, Self::PlusInf) => Self::Real(0),
158///             (Self::MinusInf, _) | (_, Self::MinusInf) => Self::MinusInf,
159///             (Self::PlusInf, _) | (_, Self::PlusInf) => Self::PlusInf,
160///             (Self::Real(a), Self::Real(b)) => Self::Real((a + b) / 2)
161///         }
162///     }
163///
164///     fn square(self) -> Self {
165///         match self {
166///             Self::PlusInf | Self::MinusInf => Self::PlusInf,
167///             Self::Real(a) => Self::Real(a * a),
168///         }
169///     }
170/// }
171/// ```
pub trait Axis: PartialOrd + Copy + Sub<Output = Self> + Add<Output = Self> {
    /// A zero value.
    ///
    /// This is the point radius used when a tree is built without an explicit one.
    const ZERO: Self;
    /// A value which is larger than any finite value.
    ///
    /// Construction uses this value to pad point buffers and as the initial value of split
    /// tests.
    const INFINITY: Self;
    /// A value which is smaller than any finite value.
    const NEG_INFINITY: Self;

    #[must_use]
    /// Determine whether this value is finite or infinite.
    ///
    /// Points containing any value for which this returns `false` are rejected during
    /// construction.
    fn is_finite(self) -> bool;

    #[must_use]
    /// Compute a value of `Self` which is halfway between `self` and `rhs`.
    /// If there are no legal values between `self` and `rhs`, it is acceptable to return `self`
    /// instead.
    fn in_between(self, rhs: Self) -> Self;

    #[must_use]
    /// Compute the square of this value.
    ///
    /// Radii are squared with this method so that they can be compared against squared
    /// distances.
    fn square(self) -> Self;
}
194
#[cfg(feature = "simd")]
/// A trait used for SIMD elements.
///
/// This is a marker trait with no methods of its own: it bundles [`SimdElement`], [`Default`],
/// and [`Axis`] so that SIMD query functions can name a single bound for their scalar type.
pub trait AxisSimdElement: SimdElement + Default + Axis {}
198
#[cfg(feature = "simd")]
/// A trait used for masks over SIMD vectors, used for parallel querying on [`Capt`]s.
///
/// Implementors are SIMD vectors with `L` lanes which support lane-wise comparison and
/// arithmetic. The two associated functions convert or reduce the comparison mask produced by
/// such a vector.
///
/// The interface for this trait should be considered unstable since the standard SIMD API may
/// change with Rust versions.
pub trait AxisSimd<const L: usize>:
    Sized
    + SimdPartialOrd
    + Add<Output = Self>
    + AddAssign
    + Sub<Output = Self>
    + SubAssign
    + Mul<Output = Self>
where
    LaneCount<L>: SupportedLaneCount,
{
    /// Cast a mask for a SIMD vector into a mask of `isize`s.
    ///
    /// The result has the same number of lanes, `L`, as the input mask.
    fn cast_mask(mask: <Self as SimdPartialEq>::Mask) -> Mask<isize, L>;
    /// Determine whether a mask contains any true elements.
    fn mask_any(mask: <Self as SimdPartialEq>::Mask) -> bool;
}
220
/// An index type used for lookups into and out of arrays.
///
/// This is implemented so that [`Capt`]s can use smaller index sizes (such as [`u32`] or [`u16`])
/// for improved memory performance.
///
/// Conversions to and from `usize` are fallible: a tree whose buffers grow beyond what the index
/// type can represent fails to build with [`NewCaptError::TooManyPoints`].
pub trait Index: TryFrom<usize> + TryInto<usize> + Copy {
    /// The zero index. This must be equal to `(0usize).try_into().unwrap()`.
    const ZERO: Self;
}
229
#[cfg(feature = "simd")]
/// A SIMD parallel version of [`Index`].
///
/// This is used for implementing SIMD lookups in a [`Capt`].
/// The interface for this trait should be considered unstable since the standard SIMD API may
/// change with Rust versions.
pub trait IndexSimd: SimdElement + Default {
    #[must_use]
    /// Convert a SIMD array of `Self` to a SIMD array of `usize`, without checking that each
    /// element is valid.
    ///
    /// The output has the same lane count `L` as the input; lane `n` of the result is the
    /// conversion of lane `n` of `x`.
    ///
    /// # Safety
    ///
    /// This function is only safe if all values of `x` are valid when converted to a `usize`.
    unsafe fn to_simd_usize_unchecked<const L: usize>(x: Simd<Self, L>) -> Simd<usize, L>
    where
        LaneCount<L>: SupportedLaneCount;
}
248
// Implement `Axis` (and, with the `simd` feature, `AxisSimdElement` and `AxisSimd`) for a
// primitive floating-point type.
//
// `$t` is the floating-point type to implement the traits for.
// NOTE(review): `$tm` is accepted but never used inside the macro body; it looks like a leftover
// from an earlier design. Confirm before removing it, since the call sites pass it.
macro_rules! impl_axis {
    ($t: ty, $tm: ty) => {
        impl Axis for $t {
            const ZERO: Self = 0.0;
            const INFINITY: Self = <$t>::INFINITY;
            const NEG_INFINITY: Self = <$t>::NEG_INFINITY;
            fn is_finite(self) -> bool {
                // Delegate to the inherent method of the primitive type.
                <$t>::is_finite(self)
            }

            fn in_between(self, rhs: Self) -> Self {
                // Arithmetic mean of the two values.
                (self + rhs) / 2.0
            }

            fn square(self) -> Self {
                self * self
            }
        }

        #[cfg(feature = "simd")]
        impl AxisSimdElement for $t {}

        #[cfg(feature = "simd")]
        impl<const L: usize> AxisSimd<L> for Simd<$t, L>
        where
            LaneCount<L>: SupportedLaneCount,
        {
            fn cast_mask(mask: <Self as SimdPartialEq>::Mask) -> Mask<isize, L> {
                // Convert the comparison mask into a pointer-width mask via `Into`.
                mask.into()
            }
            fn mask_any(mask: <Self as SimdPartialEq>::Mask) -> bool {
                mask.any()
            }
        }
    };
}
285
// Implement `Index` (and, with the `simd` feature, `IndexSimd`) for a primitive unsigned
// integer type `$t`.
macro_rules! impl_idx {
    ($t: ty) => {
        impl Index for $t {
            const ZERO: Self = 0;
        }

        #[cfg(feature = "simd")]
        impl IndexSimd for $t {
            unsafe fn to_simd_usize_unchecked<const L: usize>(x: Simd<Self, L>) -> Simd<usize, L>
            where
                LaneCount<L>: SupportedLaneCount,
            {
                // Convert lane by lane through a plain array, then rebuild the vector.
                let lanes: [usize; L] = x.to_array().map(|lane| {
                    // SAFETY: the caller guarantees that every lane of `x` is representable as
                    // a `usize`, so the conversion cannot fail.
                    unsafe { usize::try_from(lane).unwrap_unchecked() }
                });
                Simd::from_array(lanes)
            }
        }
    };
}
303
// Floating-point types usable as point coordinates. The second argument is the integer type
// passed as the macro's `$tm` parameter.
impl_axis!(f32, i32);
impl_axis!(f64, i64);

// Unsigned integer types usable as indices into the affordance buffers.
impl_idx!(u8);
impl_idx!(u16);
impl_idx!(u32);
impl_idx!(u64);
impl_idx!(usize);
312
/// Restrict `x` to the interval bounded by `min` and `max`.
///
/// Returns `min` if `x < min`, otherwise `max` if `x > max`, otherwise `x` itself. The lower
/// bound is tested first. A value which is incomparable to both bounds (such as a floating-point
/// `NaN`) is returned unchanged.
fn clamp<A: PartialOrd>(x: A, min: A, max: A) -> A {
    if x < min {
        return min;
    }
    if x > max {
        return max;
    }
    x
}
323
/// Descend the implicit binary tree stored in `tests` for `L` query points at once.
///
/// `centers[k]` holds the `k`-th coordinate of every query point. At each level, a lane moves to
/// child `2 * idx + 1` and additionally advances by one when its coordinate on the current axis
/// is greater than or equal to the test value. The axis cycles through `0..K`.
///
/// Returns, for each lane, the final node index minus `tests.len()`.
#[inline]
#[allow(clippy::cast_possible_wrap)]
#[cfg(feature = "simd")]
fn forward_pass_simd<A, const K: usize, const L: usize>(
    tests: &[A],
    centers: &[Simd<A, L>; K],
) -> Simd<isize, L>
where
    Simd<A, L>: AxisSimd<L>,
    A: AxisSimdElement,
    LaneCount<L>: SupportedLaneCount,
{
    let mut test_idxs: Simd<isize, L> = Simd::splat(0);
    let mut k = 0;
    // One iteration per level of the tree.
    for _ in 0..tests.len().trailing_ones() {
        let test_ptrs = Simd::splat(tests.as_ptr()).wrapping_offset(test_idxs);
        // SAFETY (assumed, TODO confirm against callers): this relies on `tests.len()` being one
        // less than a power of two, so that every index reachable within `trailing_ones()` levels
        // is in bounds of `tests`.
        let relevant_tests: Simd<A, L> = unsafe { Simd::gather_ptr(test_ptrs) };
        let cmp_results: Mask<isize, L> =
            Simd::<A, L>::cast_mask(centers[k % K].simd_ge(relevant_tests));

        let one = Simd::splat(1);
        // `to_int()` is masked with 1 so that a true lane contributes exactly +1.
        test_idxs = (test_idxs << one) + one + (cmp_results.to_int() & Simd::splat(1));
        k = (k + 1) % K;
    }

    // Convert node indices into offsets past the end of `tests`.
    test_idxs - Simd::splat(tests.len() as isize)
}
351
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// A stable-safe wrapper for `[A; L]` which is aligned to `L`.
/// Equivalent to a `Simd`, but easier to work with.
///
/// One `MySimd` stores a single coordinate axis for `L` points.
struct MySimd<A, const L: usize>
where
    Align<L>: Alignment,
{
    /// The lane values.
    data: [A; L],
    /// Marker field which carries the alignment requirement.
    _align: Align<L>,
}
363
#[derive(Clone, Debug, PartialEq, Eq)]
#[allow(clippy::module_name_repetitions)]
/// A collision-affording point tree (CAPT), which allows for efficient collision-checking in a
/// SIMD-parallel manner between spheres and point clouds.
///
/// # Generic parameters
///
/// - `K`: The dimension of the space.
/// - `L`: The lane size of this tree. Internally, this is the upper bound on the width of a SIMD
///   lane that can be used in this data structure. The alignment of this structure must be a power
///   of two.
/// - `A`: The value of the axes of each point. This should typically be `f32` or `f64`. This should
///   implement [`Axis`].
/// - `I`: The index integer. This should generally be an unsigned integer, such as `usize` or
///   `u32`. This should implement [`Index`].
///
/// # Examples
///
/// ```
/// // list of points in cloud
/// let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]];
///
/// // query radii must be between 0.0 and 0.2
/// let t = capt::Capt::<2>::new(&points, (0.0, 0.2));
///
/// assert!(!t.collides(&[0.0, 0.3], 0.1));
/// assert!(t.collides(&[0.0, 0.2], 0.15));
/// ```
pub struct Capt<const K: usize, const L: usize = 8, A = f32, I = usize>
where
    Align<L>: Alignment,
{
    /// The test values for determining which part of the tree to enter.
    ///
    /// The first element of `tests` should be the first value to test against.
    /// If we are less than `tests[0]`, we move on to `tests[1]`; if not, we move on to `tests[2]`.
    /// At the `i`-th test performed in sequence of the traversal, if we are less than
    /// `tests[idx]`, we advance to `2 * idx + 1`; otherwise, we go to `2 * idx + 2`.
    ///
    /// The length of `tests` must be `N`, rounded up to the next power of 2, minus one.
    tests: Box<[A]>,
    /// Axis-aligned bounding boxes containing the set of afforded points for each cell.
    ///
    /// There is one box per leaf cell.
    aabbs: Box<[Aabb<A, K>]>,
    /// Indexes for the starts of the affordance buffer subsequence of `points` corresponding to
    /// each leaf cell in the tree.
    /// This buffer is padded with one extra index at the end, holding the total length of the
    /// affordance buffer, for the sake of branchless computation.
    starts: Box<[I]>,
    /// The sets of afforded points for each cell.
    ///
    /// `afforded[k]` holds the `k`-th coordinate of every afforded point, grouped into lanes of
    /// `L` values. All `K` buffers have the same length.
    afforded: [Box<[MySimd<A, L>]>; K],
    /// A radius attributed to every point in the cloud. It is added to the query radius at the
    /// start of each collision check.
    r_point: A,
}
416
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[doc(hidden)]
/// A prismatic (axis-aligned) bounding volume in `K` dimensions.
///
/// The volume is described by its two extreme corners, `lo` and `hi`.
pub struct Aabb<A, const K: usize> {
    /// The lower bound on the volume.
    pub lo: [A; K],
    /// The upper bound on the volume.
    pub hi: [A; K],
}
427
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
/// The errors which can occur when calling [`Capt::try_new`] or
/// [`Capt::try_with_point_radius`].
pub enum NewCaptError {
    /// There were too many points in the provided cloud to be represented without integer
    /// overflow.
    ///
    /// This is reported when a position in the affordance buffer cannot be converted into the
    /// index type of the tree.
    TooManyPoints,
    /// At least one of the points had a non-finite value.
    NonFinite,
}
438
439impl<A, I, const K: usize, const L: usize> Capt<K, L, A, I>
440where
441    A: Axis,
442    I: Index,
443    Align<L>: Alignment,
444{
445    /// Construct a new CAPT containing all the points in `points`.
446    ///
447    /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the
448    /// radius of the balls which will be queried against the tree.
449    ///
450    /// # Panics
451    ///
452    /// This function will panic if there are too many points in the tree to be addressed by `I`, or
453    /// if any points contain non-finite non-real value. This can even be the case if there are
454    /// fewer points in `points` than can be addressed by `I` as the CAPT may duplicate points
455    /// for efficiency.
456    ///
457    /// # Examples
458    ///
459    /// ```
460    /// let points = [[0.0]];
461    ///
462    /// let capt = capt::Capt::<1>::new(&points, (0.0, f32::INFINITY));
463    ///
464    /// assert!(capt.collides(&[1.0], 1.5));
465    /// assert!(!capt.collides(&[1.0], 0.5));
466    /// ```
467    ///
468    /// If there are too many points in `points`, this could cause a panic!
469    ///
470    /// ```rust,should_panic
471    /// let points = [[0.0]; 256];
472    ///
473    /// // note that we are using `u8` as our index type
474    /// let capt = capt::Capt::<1, 8, f32, u8>::new(&points, (0.0, f32::INFINITY));
475    /// ```
476    pub fn new(points: &[[A; K]], r_range: (A, A)) -> Self {
477        Self::try_new(points, r_range)
478            .expect("index type I must be able to support all points in CAPT during construction")
479    }
480
481    /// Construct a new CAPT containing all the points in `points` with a point radius `r_point`.
482    ///
483    /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the
484    /// radius of the balls which will be queried against the tree.
485    ///
486    /// # Panics
487    ///
488    /// This function will panic if there are too many points in the tree to be addressed by `I`, or
489    /// if any points contain non-finite non-real value. This can even be the case if there are
490    /// fewer points in `points` than can be addressed by `I` as the CAPT may duplicate points
491    /// for efficiency.
492    ///
493    /// # Examples
494    ///
495    /// ```
496    /// let points = [[0.0]];
497    ///
498    /// let capt = capt::Capt::<1>::with_point_radius(&points, (0.0, f32::INFINITY), 0.2);
499    ///
500    /// assert!(capt.collides(&[1.0], 1.5));
501    /// assert!(!capt.collides(&[1.0], 0.5));
502    /// ```
503    ///
504    /// If there are too many points in `points`, this could cause a panic!
505    ///
506    /// ```rust,should_panic
507    /// let points = [[0.0]; 256];
508    ///
509    /// // note that we are using `u8` as our index type
510    /// let capt = capt::Capt::<1, 8, f32, u8>::with_point_radius(&points, (0.0, f32::INFINITY), 0.2);
511    /// ```
512    pub fn with_point_radius(points: &[[A; K]], r_range: (A, A), r_point: A) -> Self {
513        Self::try_with_point_radius(points, r_range, r_point)
514            .expect("index type I must be able to support all points in CAPT during construction")
515    }
516
    /// Construct a new CAPT containing all the points in `points`, checking for index overflow.
    ///
    /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the
    /// radius of the balls which will be queried against the tree.
    ///
    /// The points are given a radius of [`Axis::ZERO`]; this is a shorthand for calling
    /// [`Capt::try_with_point_radius`] with that radius.
    ///
    /// # Errors
    ///
    /// This function will return `Err(NewCaptError::TooManyPoints)` if there are too many points to
    /// be indexed by `I`. It will return `Err(NewCaptError::NonFinite)` if any element of
    /// `points` is non-finite.
    ///
    /// # Examples
    ///
    /// Unwrapping the output from this function is equivalent to calling [`Capt::new`].
    ///
    /// ```
    /// let points = [[0.0]];
    ///
    /// let capt = capt::Capt::<1>::try_new(&points, (0.0, f32::INFINITY)).unwrap();
    /// ```
    ///
    /// In failure, we get an `Err`.
    ///
    /// ```
    /// let points = [[0.0]; 256];
    ///
    /// // note that we are using `u8` as our index type
    /// let opt = capt::Capt::<1, 8, f32, u8>::try_new(&points, (0.0, f32::INFINITY));
    ///
    /// assert!(opt.is_err());
    /// ```
    pub fn try_new(points: &[[A; K]], r_range: (A, A)) -> Result<Self, NewCaptError> {
        Self::try_with_point_radius(points, r_range, A::ZERO)
    }
551
552    /// Construct a new CAPT containing all the points in `points` with a point radius `r_point`,
553    /// checking for index overflow.
554    ///
555    /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the
556    /// radius of the balls which will be queried against the tree.
557    ///
558    /// # Errors
559    ///
560    /// This function will return `Err(NewCaptError::TooManyPoints)` if there are too many points to
561    /// be indexed by `I`. It will return `Err(NewCaptError::NonFinite)` if any element of
562    /// `points` is non-finite.
563    ///
564    /// # Examples
565    ///
566    /// Unwrapping the output from this function is equivalent to calling
567    /// [`Capt::with_point_radius`].
568    ///
569    /// ```
570    /// let points = [[0.0]];
571    ///
572    /// let capt = capt::Capt::<1>::try_with_point_radius(&points, (0.0, f32::INFINITY), 0.01).unwrap();
573    /// ```
574    ///
575    /// In failure, we get an `Err`.
576    ///
577    /// ```
578    /// let points = [[0.0]; 256];
579    ///
580    /// // note that we are using `u8` as our index type
581    /// let opt =
582    ///     capt::Capt::<1, 8, f32, u8>::try_with_point_radius(&points, (0.0, f32::INFINITY), 0.01);
583    ///
584    /// assert!(opt.is_err());
585    /// ```
586    pub fn try_with_point_radius(
587        points: &[[A; K]],
588        r_range: (A, A),
589        r_point: A,
590    ) -> Result<Self, NewCaptError> {
591        let n2 = points.len().next_power_of_two();
592
593        if points.iter().any(|p| p.iter().any(|x| !x.is_finite())) {
594            return Err(NewCaptError::NonFinite);
595        }
596
597        let mut tests = vec![A::INFINITY; n2 - 1].into_boxed_slice();
598
599        // hack: just pad with infinity to make it a power of 2
600        let mut points2 = vec![[A::INFINITY; K]; n2].into_boxed_slice();
601        points2[..points.len()].copy_from_slice(points);
602        // hack - reduce number of reallocations by allocating a lot of points from the start
603        let mut afforded = array::from_fn(|_| Vec::with_capacity(n2 * 100));
604        let mut starts = vec![I::ZERO; n2 + 1].into_boxed_slice();
605
606        let mut aabbs = vec![
607            Aabb {
608                lo: [A::NEG_INFINITY; K],
609                hi: [A::INFINITY; K],
610            };
611            n2
612        ]
613        .into_boxed_slice();
614
615        unsafe {
616            // SAFETY: We tested that `points` contains no `NaN` values.
617            Self::new_help(
618                &mut points2,
619                &mut tests,
620                &mut aabbs,
621                &mut afforded,
622                &mut starts,
623                0,
624                0,
625                r_range,
626                Vec::new(),
627                Aabb::ALL,
628            )?;
629        }
630
631        Ok(Self {
632            tests,
633            starts,
634            afforded: afforded.map(Vec::into_boxed_slice),
635            aabbs,
636            r_point,
637        })
638    }
639
640    #[allow(clippy::too_many_arguments, clippy::too_many_lines)]
641    /// # Safety
642    ///
643    /// This function will contain undefined behavior if `points` contains any `NaN` values.
644    unsafe fn new_help(
645        points: &mut [[A; K]],
646        tests: &mut [A],
647        aabbs: &mut [Aabb<A, K>],
648        afforded: &mut [Vec<MySimd<A, L>>; K],
649        starts: &mut [I],
650        k: usize,
651        i: usize,
652        r_range: (A, A),
653        in_range: Vec<[A; K]>,
654        cell: Aabb<A, K>,
655    ) -> Result<(), NewCaptError> {
656        unsafe {
657            let rsq_min = r_range.0.square();
658            if let [rep] = *points {
659                let z = i - tests.len();
660                let aabb = &mut aabbs[z];
661                *aabb = Aabb { lo: rep, hi: rep };
662                if rep[0].is_finite() {
663                    // lanes for afforded points
664                    let mut news = [[A::INFINITY; L]; K];
665                    for k in 0..K {
666                        news[k][0] = rep[k];
667                    }
668
669                    // index into the current lane
670                    let mut j = 1;
671
672                    // populate affordance buffer if the representative doesn't cover everything
673                    if !cell.contained_by_ball(&rep, rsq_min) {
674                        for ak in afforded.iter_mut() {
675                            ak.reserve(ak.len() + in_range.len() / L);
676                        }
677                        for p in in_range {
678                            aabb.insert(&p);
679
680                            // start a new lane if it's full
681                            if j == L {
682                                for k in 0..K {
683                                    afforded[k].push(MySimd {
684                                        data: news[k],
685                                        _align: Align::NEW,
686                                    });
687                                }
688                                j = 0;
689                            }
690
691                            // add this point to the lane
692                            for k in 0..K {
693                                news[k][j] = p[k];
694                            }
695
696                            j += 1;
697                        }
698                    }
699
700                    // fill out the last lane with infinities
701                    for k in 0..K {
702                        afforded[k].push(MySimd {
703                            data: news[k],
704                            _align: Align::NEW,
705                        });
706                    }
707                }
708
709                starts[z + 1] = afforded[0]
710                    .len()
711                    .try_into()
712                    .map_err(|_| NewCaptError::TooManyPoints)?;
713                return Ok(());
714            }
715
716            let test = median_partition(points, k);
717            tests[i] = test;
718
719            let (lhs, rhs) = points.split_at_mut(points.len() / 2);
720            let (lo_vol, hi_vol) = cell.split(test, k);
721
722            let lo_too_small = distsq(lo_vol.lo, lo_vol.hi) <= rsq_min;
723            let hi_too_small = distsq(hi_vol.lo, hi_vol.hi) <= rsq_min;
724
725            // retain only points which might be in the affordance buffer for the split-out cells
726            let (lo_afford, hi_afford) = match (lo_too_small, hi_too_small) {
727                (false, false) => {
728                    let mut lo_afford = in_range;
729                    let mut hi_afford = lo_afford.clone();
730                    lo_afford.retain(|pt| pt[k] <= test + r_range.1);
731                    lo_afford.extend(rhs.iter().filter(|pt| pt[k] <= test + r_range.1));
732                    hi_afford.retain(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]);
733                    hi_afford.extend(
734                        lhs.iter()
735                            .filter(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]),
736                    );
737
738                    (lo_afford, hi_afford)
739                }
740                (false, true) => {
741                    let mut lo_afford = in_range;
742                    lo_afford.retain(|pt| pt[k] <= test + r_range.1);
743                    lo_afford.extend(rhs.iter().filter(|pt| pt[k] <= test + r_range.1));
744
745                    (lo_afford, Vec::new())
746                }
747                (true, false) => {
748                    let mut hi_afford = in_range;
749                    hi_afford.retain(|pt| test - r_range.1 <= pt[k]);
750                    hi_afford.extend(
751                        lhs.iter()
752                            .filter(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]),
753                    );
754
755                    (Vec::new(), hi_afford)
756                }
757                (true, true) => (Vec::new(), Vec::new()),
758            };
759
760            let next_k = (k + 1) % K;
761            Self::new_help(
762                lhs,
763                tests,
764                aabbs,
765                afforded,
766                starts,
767                next_k,
768                2 * i + 1,
769                r_range,
770                lo_afford,
771                lo_vol,
772            )?;
773            Self::new_help(
774                rhs,
775                tests,
776                aabbs,
777                afforded,
778                starts,
779                next_k,
780                2 * i + 2,
781                r_range,
782                hi_afford,
783                hi_vol,
784            )?;
785
786            Ok(())
787        }
788    }
789
790    #[must_use]
791    /// Determine whether a point in this tree is within a distance of `radius` to `center`.
792    ///
793    /// Note that this function will accept query radii outside of the range `r_range` passed to the
794    /// construction for this CAPT in [`Capt::new`] or [`Capt::try_new`]. However, if the query
795    /// radius is outside this range, the tree may erroneously return `false` (that is, erroneously
796    /// report non-collision).
797    ///
798    /// # Examples
799    ///
800    /// ```
801    /// let points = [[0.0; 3], [1.0; 3], [0.1, 0.5, 1.0]];
802    /// let capt = capt::Capt::<3>::new(&points, (0.0, 1.0));
803    ///
804    /// assert!(capt.collides(&[1.1; 3], 0.2));
805    /// assert!(!capt.collides(&[2.0; 3], 1.0));
806    ///
807    /// // no guarantees about what this is, since the radius is greater than the construction range
808    /// println!(
809    ///     "collision check result is {:?}",
810    ///     capt.collides(&[100.0; 3], 100.0)
811    /// );
812    /// ```
813    pub fn collides(&self, center: &[A; K], mut radius: A) -> bool {
814        radius = radius + self.r_point;
815        // forward pass through the tree
816        let mut test_idx = 0;
817        let mut k = 0;
818        for _ in 0..self.tests.len().trailing_ones() {
819            test_idx = 2 * test_idx
820                + 1
821                + usize::from(unsafe { *self.tests.get_unchecked(test_idx) } <= center[k]);
822            k = (k + 1) % K;
823        }
824
825        // retrieve affordance buffer location
826        let rsq = radius.square();
827        let i = test_idx - self.tests.len();
828        let aabb = unsafe { self.aabbs.get_unchecked(i) };
829        if aabb.closest_distsq_to(center) > rsq {
830            return false;
831        }
832
833        let mut range = unsafe {
834            // SAFETY: The conversion worked the first way.
835            self.starts[i].try_into().unwrap_unchecked()
836                ..self.starts[i + 1].try_into().unwrap_unchecked()
837        };
838
839        // check affordance buffer
840        range.any(|i| {
841            (0..L).any(|j| {
842                let mut aff_pt = [A::INFINITY; K];
843                for (ak, sk) in aff_pt.iter_mut().zip(&self.afforded) {
844                    *ak = sk[i].data[j];
845                }
846                distsq(aff_pt, *center) <= rsq
847            })
848        })
849    }
850
851    #[must_use]
852    #[doc(hidden)]
853    /// Get the total memory used (stack + heap) by this structure, measured in bytes.
854    /// This function should not be considered stable; it is only used internally for benchmarks.
855    pub const fn memory_used(&self) -> usize {
856        size_of::<Self>()
857            + K * self.afforded[0].len() * size_of::<A>()
858            + self.starts.len() * size_of::<I>()
859            + self.tests.len() * size_of::<I>()
860            + self.aabbs.len() * size_of::<Aabb<A, K>>()
861    }
862
863    #[must_use]
864    #[doc(hidden)]
865    #[allow(clippy::cast_precision_loss)]
866    /// Get the average number of affordances per point.
867    /// This function should not be considered stable; it is only used internally for benchmarks.
868    pub fn affordance_size(&self) -> f64 {
869        self.afforded.len() as f64 / (self.tests.len() + 1) as f64
870    }
871}
872
// SIMD-parallel queries; only compiled when the `simd` feature is enabled.
// NOTE(review): the `mismatching_type_param_order` allowance is presumably needed because the
// generic parameters are declared here in a different order than `Capt` lists them -- confirm.
#[allow(clippy::mismatching_type_param_order)]
#[cfg(feature = "simd")]
impl<A, I, const K: usize, const L: usize> Capt<K, L, A, I>
where
    I: IndexSimd,
    A: Mul<Output = A>,
    Align<L>: Alignment,
{
    #[must_use]
    /// Determine whether any sphere in the list of provided spheres intersects a point in this
    /// tree.
    ///
    /// The `L` spheres are passed in structure-of-arrays form: `centers[k]` holds the `k`-th
    /// coordinate of every sphere, and lane `j` of `radii` is the radius of sphere `j`.
    /// The radius stored for the points of the cloud is added to every query radius before
    /// testing.
    ///
    /// Returns `true` if at least one of the spheres collides, and `false` if none of them do.
    ///
    /// # Examples
    ///
    /// ```
    /// #![feature(portable_simd)]
    /// use std::simd::Simd;
    ///
    /// let points = [[1.0, 2.0], [1.1, 1.1]];
    ///
    /// let centers = [
    ///     Simd::from_array([1.0, 1.1, 1.2, 1.3]), // x-positions
    ///     Simd::from_array([1.0, 1.1, 1.2, 1.3]), // y-positions
    /// ];
    /// let radii = Simd::splat(0.05);
    ///
    /// let tree = capt::Capt::<2, 4, f32, u32>::new(&points, (0.0, 0.1));
    ///
    /// println!("{tree:?}");
    ///
    /// assert!(tree.collides_simd(&centers, radii));
    /// ```
    pub fn collides_simd(&self, centers: &[Simd<A, L>; K], mut radii: Simd<A, L>) -> bool
    where
        LaneCount<L>: SupportedLaneCount,
        Simd<A, L>: AxisSimd<L>,
        A: AxisSimdElement,
    {
        // inflate every query radius by the radius of the cloud's own points
        radii += Simd::splat(self.r_point);
        // Per-lane result of the tree traversal. It is used below as an element offset into both
        // `aabbs` and `starts`, i.e. (presumably) the index of the leaf cell each sphere lands in.
        let zs = forward_pass_simd(&self.tests, centers);

        // lanes whose sphere may still touch the bounding box of its cell
        let mut inbounds = Mask::splat(true);

        // One pointer per lane, aimed at that lane's bounding box and then reinterpreted as a
        // pointer to the scalars inside the box, so it can be stepped one scalar at a time.
        let mut aabb_ptrs = Simd::splat(self.aabbs.as_ptr()).wrapping_offset(zs).cast();

        // SAFETY: NOTE(review): relies on every lane of `zs` indexing a valid element of `aabbs`,
        // and on a box being laid out as `K` lower bounds followed by `K` upper bounds, so that
        // the `2 * K` consecutive scalar reads below stay inside the box -- confirm against the
        // definition of `Aabb` and of `forward_pass_simd`.
        unsafe {
            // First `K` scalars: treated as lower bounds. A lane stays in bounds on this axis if
            // `bound - radius <= center`. Lanes already rejected are not loaded at all; they get
            // the filler value instead, and the `&=` keeps them rejected.
            for center in centers {
                inbounds &= Simd::<A, L>::cast_mask(
                    (Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY))
                        - radii)
                        .simd_le(*center),
                );
                // step to the next scalar of each box
                aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1));
            }
            // Next `K` scalars: treated as upper bounds. A lane stays in bounds on this axis if
            // `bound >= center - radius`.
            for center in centers {
                inbounds &= Simd::<A, L>::cast_mask(
                    Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY))
                        .simd_ge(*center - radii),
                );
                aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1));
            }
        }
        // no sphere reaches the bounding box of its cell, so nothing can collide
        if !inbounds.any() {
            return false;
        }

        // retrieve start/end pointers for the affordance buffer
        // (the end of a cell's range is the `starts` entry directly after its own)
        // SAFETY: NOTE(review): assumes `zs` and `zs + 1` are both valid indices into `starts`,
        // and that the stored indices convert to `usize` without loss -- confirm at construction.
        let start_ptrs = Simd::splat(self.starts.as_ptr()).wrapping_offset(zs);
        let starts = unsafe { I::to_simd_usize_unchecked(Simd::gather_ptr(start_ptrs)) }.to_array();
        let ends = unsafe {
            I::to_simd_usize_unchecked(Simd::gather_ptr(start_ptrs.wrapping_add(Simd::splat(1))))
        }
        .to_array();

        // For every sphere that survived the bounding-box test, scan the blocks of its cell.
        // Each block holds `L` afforded points, which are all compared against that one sphere
        // at the same time.
        starts
            .into_iter()
            .zip(ends)
            .zip(inbounds.to_array())
            .enumerate()
            // keep `(lane, (start, end))` only for the lanes still in bounds
            .filter_map(|(j, r)| r.1.then_some((j, r.0)))
            .any(|(j, (start, end))| {
                // broadcast the center and radius of sphere `j` to all lanes
                let mut n_center = [Simd::splat(A::ZERO); K];
                for k in 0..K {
                    n_center[k] = Simd::splat(centers[k][j]);
                }
                let rs = Simd::splat(radii[j]);
                let rs_sq = rs * rs;
                (start..end).any(|i| {
                    // squared distance from the sphere center to each of the block's points
                    let mut dists_sq = Simd::splat(A::ZERO);
                    #[allow(clippy::needless_range_loop)]
                    for k in 0..K {
                        // SAFETY: NOTE(review): reads the `data` of block `i` as a `Simd<A, L>`.
                        // This needs `i` to be in bounds for the column and `data` to have the
                        // size and alignment of `Simd<A, L>`; presumably the `Align<L>: Alignment`
                        // bound is what provides the alignment -- confirm.
                        let vals: Simd<A, L> = unsafe {
                            *ptr::from_ref(&self.afforded[k].get_unchecked(i).data).cast()
                        };
                        let diff = vals - n_center[k];
                        dists_sq += diff * diff;
                    }
                    // a hit on any lane means sphere `j` collides
                    Simd::<A, L>::mask_any(dists_sq.simd_le(rs_sq))
                })
            })
    }
}
975
976fn distsq<A: Axis, const K: usize>(a: [A; K], b: [A; K]) -> A {
977    let mut total = A::ZERO;
978    for i in 0..K {
979        total = total + (a[i] - b[i]).square();
980    }
981    total
982}
983
984impl<A, const K: usize> Aabb<A, K>
985where
986    A: Axis,
987{
988    const ALL: Self = Self {
989        lo: [A::NEG_INFINITY; K],
990        hi: [A::INFINITY; K],
991    };
992
993    /// Split this volume by a test plane with value `test` along `dim`.
994    const fn split(mut self, test: A, dim: usize) -> (Self, Self) {
995        let mut rhs = self;
996        self.hi[dim] = test;
997        rhs.lo[dim] = test;
998
999        (self, rhs)
1000    }
1001
1002    fn contained_by_ball(&self, center: &[A; K], rsq: A) -> bool {
1003        let mut dist = A::ZERO;
1004
1005        #[allow(clippy::needless_range_loop)]
1006        for k in 0..K {
1007            let lo_diff = (self.lo[k] - center[k]).square();
1008            let hi_diff = (self.hi[k] - center[k]).square();
1009
1010            dist = dist + if lo_diff < hi_diff { hi_diff } else { lo_diff };
1011        }
1012
1013        dist <= rsq
1014    }
1015
1016    #[doc(hidden)]
1017    pub fn closest_distsq_to(&self, pt: &[A; K]) -> A {
1018        let mut dist = A::ZERO;
1019
1020        #[allow(clippy::needless_range_loop)]
1021        for d in 0..K {
1022            let clamped = clamp(pt[d], self.lo[d], self.hi[d]);
1023            dist = dist + (pt[d] - clamped).square();
1024        }
1025
1026        dist
1027    }
1028
1029    fn insert(&mut self, point: &[A; K]) {
1030        self.lo
1031            .iter_mut()
1032            .zip(&mut self.hi)
1033            .zip(point)
1034            .for_each(|((l, h), &x)| {
1035                if *l > x {
1036                    *l = x;
1037                }
1038                if x > *h {
1039                    *h = x;
1040                }
1041            });
1042    }
1043}
1044
1045#[inline]
1046/// Calculate the "true" median (halfway between two midpoints) and partition `points` about said
1047/// median along axis `d`.
1048///
1049/// # Safety
1050///
1051/// This function will result in undefined behavior if `points` contains any `NaN` values.
1052unsafe fn median_partition<A: Axis, const K: usize>(points: &mut [[A; K]], k: usize) -> A {
1053    unsafe {
1054        let (lh, med_hi, _) = points.select_nth_unstable_by(points.len() / 2, |a, b| {
1055            a[k].partial_cmp(&b[k]).unwrap_unchecked()
1056        });
1057        let med_lo = lh
1058            .iter_mut()
1059            .map(|p| p[k])
1060            .max_by(|a, b| a.partial_cmp(b).unwrap_unchecked())
1061            .unwrap();
1062        A::in_between(med_lo, med_hi[k])
1063    }
1064}
1065
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng, rngs::SmallRng};

    use super::*;

    /// A small two-dimensional point cloud shared by several of the tests below.
    const TRIANGLE: [[f32; 2]; 3] = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]];

    #[test]
    /// Building a tree from a handful of points should not panic.
    fn build_simple() {
        let t = Capt::<2>::new(&TRIANGLE, (0.0, 0.2));
        println!("{t:?}");
    }

    #[test]
    /// A sphere which reaches one of the points must be reported as colliding.
    fn exact_query_single() {
        let t = Capt::<2>::new(&TRIANGLE, (0.0, 0.2));

        println!("{t:?}");

        let query = [0.0, -0.01];
        assert!(t.collides(&query, 0.12));
    }

    #[test]
    /// A small sphere sitting almost on top of a point must be reported as colliding.
    fn another_one() {
        let t = Capt::<2>::new(&TRIANGLE, (0.0, 0.2));

        println!("{t:?}");

        let query = [0.003_265_380_9, 0.106_527_805];
        assert!(t.collides(&query, 0.02));
    }

    #[test]
    /// Queries against a three-dimensional cloud which contains a repeated point.
    fn three_d() {
        let repeated = [0.1, -1.1, 0.5];
        let points = [[0.0; 3], repeated, [-0.2, -0.3, 0.25], repeated];

        let t = Capt::<3>::new(&points, (0.0, 0.2));

        println!("{t:?}");

        // the same center collides or not depending only on the radius
        let query = [0.0, 0.1, 0.0];
        assert!(t.collides(&query, 0.11));
        assert!(!t.collides(&query, 0.05));
    }

    #[test]
    /// Compare the tree against a brute-force check on seeded random queries.
    fn fuzz() {
        const R: f32 = 0.04;
        let mut rng = SmallRng::seed_from_u64(1234);
        let t = Capt::<2>::new(&TRIANGLE, (0.0, R));

        // reference answer: test the query against every point of the cloud
        let brute_force = |q: [f32; 2]| TRIANGLE.iter().any(|a| distsq(*a, q) <= R * R);

        for _ in 0..10_000 {
            let p = [rng.random_range(-1.0..1.0), rng.random_range(-1.0..1.0)];
            let collides = brute_force(p);
            println!("{p:?}; {collides}");
            assert_eq!(collides, t.collides(&p, R));
        }
    }

    #[test]
    /// Query a tree which was built with an extremely narrow radius range.
    ///
    /// The original note on this test reads: this test _should_ fail, but it doesn't somehow?
    /// Note that the values named `R_SQ` and `rsq_range` are handed to the constructor as the
    /// radius range.
    fn weird_bounds() {
        const R_SQ: f32 = 1.0;
        let points = [
            [-1.0, 0.0],
            [0.001, 0.0],
            [0.0, 0.5],
            [-1.0, 10.0],
            [-2.0, 10.0],
            [-3.0, 10.0],
            [-0.5, 0.0],
            [-11.0, 1.0],
            [-1.0, -0.5],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
            [6.0, 6.0],
            [7.0, 7.0],
        ];
        let rsq_range = (R_SQ - f32::EPSILON, R_SQ + f32::EPSILON);
        let t = Capt::<2>::new(&points, rsq_range);
        println!("{t:?}");

        let query = [-0.001, -0.2];
        assert!(t.collides(&query, 1.0));
    }

    #[test]
    #[allow(clippy::float_cmp)]
    /// The median must split the points into a lower half and an upper half.
    fn does_it_partition() {
        let mut points = vec![[1.0], [2.0], [1.5], [2.1], [-0.5]];
        let median = unsafe { median_partition(&mut points, 0) };
        assert_eq!(median, 1.25);

        let (lower, upper) = points.split_at(points.len() / 2);
        assert!(lower.iter().all(|p0| p0[0] <= median));
        assert!(upper.iter().all(|p0| p0[0] >= median));
    }

    #[test]
    /// Points with a nonzero radius can be hit by spheres which do not reach their centers.
    fn point_radius() {
        let points = [[0.0, 0.0], [0.0, 1.0]];
        let r_range = (0.0, 1.0);

        let capt: Capt<_, 8, _, u32> = Capt::with_point_radius(&points, r_range, 0.5);

        let query = [0.6, 0.0];
        assert!(capt.collides(&query, 0.2));
        assert!(!capt.collides(&query, 0.05));
    }
}