capt/lib.rs
1//! # Collision-Affording Point Trees: SIMD-Amenable Nearest Neighbors for Fast Collision Checking
2//!
3//! This is a Rust implementation of the _collision-affording point tree_ (CAPT), a data structure
4//! for SIMD-parallel collision-checking between spheres and point clouds.
5//!
6//! You may also want to look at the following other sources:
7//!
8//! - [The paper](https://arxiv.org/abs/2406.02807)
9//! - [C++ implementation](https://github.com/KavrakiLab/vamp)
10//! - [Blog post about it](https://www.claytonwramsey.com/blog/captree)
11//! - [Demo video](https://youtu.be/BzDKdrU1VpM)
12//!
13//! If you use this in an academic work, please cite it as follows:
14//!
15//! ```bibtex
16//! @InProceedings{capt,
17//! title = {Collision-Affording Point Trees: {SIMD}-Amenable Nearest Neighbors for Fast Collision Checking},
18//! author = {Ramsey, Clayton W. and Kingston, Zachary and Thomason, Wil and Kavraki, Lydia E.},
19//! booktitle = {Robotics: Science and Systems},
20//! date = {2024},
21//! url = {http://arxiv.org/abs/2406.02807},
22//! note = {To Appear.}
23//! }
24//! ```
25//!
26//! ## Usage
27//!
28//! The core data structure in this library is the [`Capt`], which is a search tree used for
29//! collision checking. [`Capt`]s are polymorphic over dimension and data type. On construction,
30//! they take in a list of points in a point cloud and a _radius range_: a tuple of the minimum and
31//! maximum radius used for querying.
32//!
33//! ```rust
34//! use capt::Capt;
35//!
36//! // list of points in cloud
37//! let points = [[0.0, 1.1], [0.2, 3.1]];
38//! let r_min = 0.05;
39//! let r_max = 2.0;
40//!
41//! let capt = Capt::<2>::new(&points, (r_min, r_max));
42//! ```
43//!
44//! Once you have a `Capt`, you can use it for collision-checking against spheres.
45//! Correct answers are only guaranteed if you collision-check against spheres with a radius inside
46//! the radius range.
47//!
48//! ```rust
49//! # use capt::Capt;
50//! # let points = [[0.0, 1.1], [0.2, 3.1]];
51//! # let capt = Capt::<2>::new(&points, (0.05, 2.0));
52//! let center = [0.0, 0.0]; // center of sphere
53//! let radius0 = 1.0; // radius of sphere
//! assert!(!capt.collides(&center, radius0));
55//!
56//! let radius1 = 1.5;
//! assert!(capt.collides(&center, radius1));
58//! ```
59//!
60//! ## Optional features
61//!
62//! This crate exposes one feature, `simd`, which enables a SIMD-parallel interface for querying
63//! `Capt`s. The `simd` feature requires nightly Rust and therefore should be considered unstable.
64//! This enables the function `Capt::collides_simd`, a parallel collision checker for batches of
65//! search queries.
66//!
67//! ## License
68//!
69//! This work is licensed to you under the Apache 2.0 license.
70#![cfg_attr(feature = "simd", feature(portable_simd))]
71#![cfg_attr(not(test), no_std)]
72#![warn(clippy::pedantic, clippy::cargo, clippy::nursery, missing_docs)]
73
74extern crate alloc;
75use alloc::{boxed::Box, vec, vec::Vec};
76
77use core::{
78 array,
79 fmt::Debug,
80 mem::size_of,
81 ops::{Add, Sub},
82};
83
84#[cfg(feature = "simd")]
85use core::{
86 ops::{AddAssign, Mul, SubAssign},
87 ptr,
88 simd::{
89 LaneCount, Mask, Simd, SimdElement, SupportedLaneCount,
90 cmp::{SimdPartialEq, SimdPartialOrd},
91 ptr::SimdConstPtr,
92 },
93};
94
95use elain::{Align, Alignment};
96
97/// A generic trait representing values which may be used as an "axis;" that is, elements of a
98/// vector representing a point.
99///
100/// An array of `Axis` values is a point which can be stored in a [`Capt`].
101/// Accordingly, this trait specifies nearly all the requirements for points that [`Capt`]s require.
102/// The only exception is that [`Axis`] values really ought to be [`Ord`] instead of [`PartialOrd`];
/// however, due to the disaster that is IEEE 754 floating point numbers, `f32` and `f64` are not
104/// totally ordered. As a compromise, we relax the `Ord` requirement so that you can use floats in a
105/// `Capt`.
106///
107/// # Examples
108///
109/// ```
110/// #[derive(Clone, Copy, PartialOrd, PartialEq)]
111/// enum HyperInt {
112/// MinusInf,
113/// Real(i32),
114/// PlusInf,
115/// }
116///
117/// impl std::ops::Add for HyperInt {
118/// // ...
119/// # type Output = Self;
120/// #
121/// # fn add(self, rhs: Self) -> Self {
122/// # match (self, rhs) {
123/// # (Self::MinusInf, Self::PlusInf) => Self::Real(0), // evil, but who cares?
124/// # (Self::MinusInf, _) | (_, Self::MinusInf) => Self::MinusInf,
125/// # (Self::PlusInf, _) | (_, Self::PlusInf) => Self::PlusInf,
126/// # (Self::Real(x), Self::Real(y)) => Self::Real(x + y),
127/// # }
128/// # }
129/// }
130///
131///
132/// impl std::ops::Sub for HyperInt {
133/// // ...
134/// # type Output = Self;
135/// #
136/// # fn sub(self, rhs: Self) -> Self {
137/// # match (self, rhs) {
138/// # (Self::MinusInf, Self::MinusInf) | (Self::PlusInf, Self::PlusInf) => Self::Real(0), // evil, but who cares?
139/// # (Self::MinusInf, _) | (_, Self::PlusInf) => Self::MinusInf,
140/// # (Self::PlusInf, _) | (_, Self::MinusInf) => Self::PlusInf,
141/// # (Self::Real(x), Self::Real(y)) => Self::Real(x - y),
142/// # }
143/// # }
144/// }
145///
146/// impl capt::Axis for HyperInt {
147/// const ZERO: Self = Self::Real(0);
148/// const INFINITY: Self = Self::PlusInf;
149/// const NEG_INFINITY: Self = Self::MinusInf;
150///
151/// fn is_finite(self) -> bool {
152/// matches!(self, Self::Real(_))
153/// }
154///
155/// fn in_between(self, rhs: Self) -> Self {
156/// match (self, rhs) {
157/// (Self::PlusInf, Self::MinusInf) | (Self::MinusInf, Self::PlusInf) => Self::Real(0),
158/// (Self::MinusInf, _) | (_, Self::MinusInf) => Self::MinusInf,
159/// (Self::PlusInf, _) | (_, Self::PlusInf) => Self::PlusInf,
160/// (Self::Real(a), Self::Real(b)) => Self::Real((a + b) / 2)
161/// }
162/// }
163///
164/// fn square(self) -> Self {
165/// match self {
166/// Self::PlusInf | Self::MinusInf => Self::PlusInf,
167/// Self::Real(a) => Self::Real(a * a),
168/// }
169/// }
170/// }
171/// ```
pub trait Axis: PartialOrd + Copy + Sub<Output = Self> + Add<Output = Self> {
    /// A zero value.
    ///
    /// This is the starting value when squared distances are accumulated, and the point radius
    /// used by [`Capt::try_new`].
    const ZERO: Self;
    /// A value which is larger than any finite value.
    ///
    /// Construction pads the point cloud and the affordance lanes with this value.
    const INFINITY: Self;
    /// A value which is smaller than any finite value.
    const NEG_INFINITY: Self;

    #[must_use]
    /// Determine whether this value is finite or infinite.
    ///
    /// Construction of a [`Capt`] fails with [`NewCaptError::NonFinite`] if this returns `false`
    /// for any coordinate of any input point.
    fn is_finite(self) -> bool;

    #[must_use]
    /// Compute a value of `Self` which is halfway between `self` and `rhs`.
    /// If there are no legal values between `self` and `rhs`, it is acceptable to return `self`
    /// instead.
    ///
    /// During construction this is called on the two middle values of a cell to produce the
    /// test value that splits the cell.
    fn in_between(self, rhs: Self) -> Self;

    #[must_use]
    /// Compute the square of this value.
    ///
    /// All distance comparisons in this crate are carried out on squared values.
    fn square(self) -> Self;
}
194
#[cfg(feature = "simd")]
/// A trait used for SIMD elements.
///
/// This trait declares no items of its own; it only bundles the `SimdElement`, `Default`, and
/// [`Axis`] bounds required of a scalar type that is queried in SIMD batches.
pub trait AxisSimdElement: SimdElement + Default + Axis {}
198
#[cfg(feature = "simd")]
/// A trait used for masks over SIMD vectors, used for parallel querying on [`Capt`]s.
///
/// `L` is the number of lanes in the vector.
///
/// The interface for this trait should be considered unstable since the standard SIMD API may
/// change with Rust versions.
pub trait AxisSimd<const L: usize>:
    Sized
    + SimdPartialOrd
    + Add<Output = Self>
    + AddAssign
    + Sub<Output = Self>
    + SubAssign
    + Mul<Output = Self>
where
    LaneCount<L>: SupportedLaneCount,
{
    /// Cast a mask for a SIMD vector into a mask of `isize`s.
    ///
    /// The query code uses this to turn the result of a vector comparison into the `isize` mask
    /// type it uses alongside pointer gathers.
    fn cast_mask(mask: <Self as SimdPartialEq>::Mask) -> Mask<isize, L>;
    /// Determine whether a mask contains any true elements.
    fn mask_any(mask: <Self as SimdPartialEq>::Mask) -> bool;
}
220
/// An index type used for lookups into and out of arrays.
///
/// This is implemented so that [`Capt`]s can use smaller index sizes (such as [`u32`] or [`u16`])
/// for improved memory performance.
///
/// Construction converts buffer offsets from `usize` with `TryFrom`; a failed conversion is
/// reported as [`NewCaptError::TooManyPoints`]. Queries convert back with `TryInto` without
/// checking the result.
pub trait Index: TryFrom<usize> + TryInto<usize> + Copy {
    /// The zero index. This must be equal to `(0usize).try_into().unwrap()`.
    const ZERO: Self;
}
229
#[cfg(feature = "simd")]
/// A SIMD parallel version of [`Index`].
///
/// This is used for implementing SIMD lookups in a [`Capt`].
/// The interface for this trait should be considered unstable since the standard SIMD API may
/// change with Rust versions.
pub trait IndexSimd: SimdElement + Default {
    #[must_use]
    /// Convert a SIMD array of `Self` to a SIMD array of `usize`, without checking that each
    /// element is valid.
    ///
    /// `L` is the number of lanes of both the input and the output vector.
    ///
    /// # Safety
    ///
    /// This function is only safe if all values of `x` are valid when converted to a `usize`.
    unsafe fn to_simd_usize_unchecked<const L: usize>(x: Simd<Self, L>) -> Simd<usize, L>
    where
        LaneCount<L>: SupportedLaneCount;
}
248
// Implement `Axis` for the floating-point type `$t` and, when the `simd` feature is enabled,
// `AxisSimdElement` for `$t` and `AxisSimd<L>` for `Simd<$t, L>`.
//
// NOTE(review): the second argument `$tm` is accepted but never referenced in the expansion.
macro_rules! impl_axis {
    ($t: ty, $tm: ty) => {
        impl Axis for $t {
            const ZERO: Self = 0.0;
            const INFINITY: Self = <$t>::INFINITY;
            const NEG_INFINITY: Self = <$t>::NEG_INFINITY;
            fn is_finite(self) -> bool {
                // forward to the inherent method of the float type
                <$t>::is_finite(self)
            }

            fn in_between(self, rhs: Self) -> Self {
                // arithmetic mean of the two values
                (self + rhs) / 2.0
            }

            fn square(self) -> Self {
                self * self
            }
        }

        #[cfg(feature = "simd")]
        impl AxisSimdElement for $t {}

        #[cfg(feature = "simd")]
        impl<const L: usize> AxisSimd<L> for Simd<$t, L>
        where
            LaneCount<L>: SupportedLaneCount,
        {
            fn cast_mask(mask: <Self as SimdPartialEq>::Mask) -> Mask<isize, L> {
                mask.into()
            }
            fn mask_any(mask: <Self as SimdPartialEq>::Mask) -> bool {
                mask.any()
            }
        }
    };
}
285
// Implement `Index` for the integer type `$t` and, when the `simd` feature is enabled,
// `IndexSimd` for `$t` as well.
macro_rules! impl_idx {
    ($t: ty) => {
        impl Index for $t {
            const ZERO: Self = 0;
        }

        #[cfg(feature = "simd")]
        impl IndexSimd for $t {
            unsafe fn to_simd_usize_unchecked<const L: usize>(x: Simd<Self, L>) -> Simd<usize, L>
            where
                LaneCount<L>: SupportedLaneCount,
            {
                // SAFETY: the caller promises that every lane of `x` converts to a `usize`, so
                // the unchecked unwrap of each lane-wise conversion never sees an `Err`.
                unsafe { x.to_array().map(|a| a.try_into().unwrap_unchecked()).into() }
            }
        }
    };
}
303
// Floating-point types usable as point coordinates.
impl_axis!(f32, i32);
impl_axis!(f64, i64);

// Unsigned integer types usable as affordance-buffer indices.
impl_idx!(u8);
impl_idx!(u16);
impl_idx!(u32);
impl_idx!(u64);
impl_idx!(usize);
312
/// Restrict `x` to the closed interval `[min, max]`.
///
/// Returns `min` when `x < min`, otherwise `max` when `x > max`, otherwise `x` itself. A value
/// which is unordered with respect to both bounds (such as a floating-point `NaN`) is returned
/// unchanged, since neither comparison holds.
fn clamp<A: PartialOrd>(x: A, min: A, max: A) -> A {
    // the lower bound is tested first, so it wins if the bounds are inverted
    if x < min {
        return min;
    }
    if x > max {
        return max;
    }
    x
}
323
#[inline]
#[allow(clippy::cast_possible_wrap)]
#[cfg(feature = "simd")]
/// Walk the implicit binary tree stored in `tests` for `L` query points at once.
///
/// `centers[k]` holds the `k`-th coordinate of every query point. The return value holds, for
/// each lane, the final tree index minus `tests.len()`; callers use it as an offset into the
/// per-leaf arrays.
fn forward_pass_simd<A, const K: usize, const L: usize>(
    tests: &[A],
    centers: &[Simd<A, L>; K],
) -> Simd<isize, L>
where
    Simd<A, L>: AxisSimd<L>,
    A: AxisSimdElement,
    LaneCount<L>: SupportedLaneCount,
{
    // every lane starts at the root
    let mut test_idxs: Simd<isize, L> = Simd::splat(0);
    // axis compared at the current depth; cycles through `0..K`
    let mut k = 0;
    // one iteration per level of the tree
    for _ in 0..tests.len().trailing_ones() {
        let test_ptrs = Simd::splat(tests.as_ptr()).wrapping_offset(test_idxs);
        // SAFETY: after `j` iterations each index is at most `2^(j+1) - 2`, and a length with
        // `t` trailing ones is at least `2^t - 1`, so every gathered index is inside `tests`.
        let relevant_tests: Simd<A, L> = unsafe { Simd::gather_ptr(test_ptrs) };
        // NOTE: `k` is already reduced modulo `K` below, so `k % K` is redundant here.
        let cmp_results: Mask<isize, L> =
            Simd::<A, L>::cast_mask(centers[k % K].simd_ge(relevant_tests));

        let one = Simd::splat(1);
        // descend to child `2 * idx + 1`, plus one more where `center >= test`
        test_idxs = (test_idxs << one) + one + (cmp_results.to_int() & Simd::splat(1));
        k = (k + 1) % K;
    }

    // convert tree indices into offsets past the end of `tests`
    test_idxs - Simd::splat(tests.len() as isize)
}
351
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// A stable-safe wrapper for `[A; L]` which is aligned to `L`.
/// Equivalent to a `Simd`, but easier to work with.
///
/// Each value stores one lane group of an affordance buffer: `L` values of a single coordinate
/// axis.
///
/// NOTE(review): the alignment comes from the `elain` marker field; presumably `Align<L>` means
/// `L` bytes, which may be less than the alignment of `Simd<A, L>` — confirm against `elain`.
struct MySimd<A, const L: usize>
where
    Align<L>: Alignment,
{
    /// The `L` stored values.
    data: [A; L],
    /// Marker field that only imposes the alignment.
    _align: Align<L>,
}
363
#[derive(Clone, Debug, PartialEq, Eq)]
#[allow(clippy::module_name_repetitions)]
/// A collision-affording point tree (CAPT), which allows for efficient collision-checking in a
/// SIMD-parallel manner between spheres and point clouds.
///
/// # Generic parameters
///
/// - `K`: The dimension of the space.
/// - `L`: The lane size of this tree. Internally, this is the upper bound on the width of a SIMD
///   lane that can be used in this data structure. The alignment of this structure must be a power
///   of two.
/// - `A`: The value of the axes of each point. This should typically be `f32` or `f64`. This should
///   implement [`Axis`].
/// - `I`: The index integer. This should generally be an unsigned integer, such as `usize` or
///   `u32`. This should implement [`Index`].
///
/// # Examples
///
/// ```
/// // list of points in cloud
/// let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]];
///
/// // query radii must be between 0.0 and 0.2
/// let t = capt::Capt::<2>::new(&points, (0.0, 0.2));
///
/// assert!(!t.collides(&[0.0, 0.3], 0.1));
/// assert!(t.collides(&[0.0, 0.2], 0.15));
/// ```
pub struct Capt<const K: usize, const L: usize = 8, A = f32, I = usize>
where
    Align<L>: Alignment,
{
    /// The test values for determining which part of the tree to enter.
    ///
    /// The first element of `tests` should be the first value to test against.
    /// If we are less than `tests[0]`, we move on to `tests[1]`; if not, we move on to `tests[2]`.
    /// At the `i`-th test performed in sequence of the traversal, if we are less than
    /// `tests[idx]`, we advance to `2 * idx + 1`; otherwise, we go to `2 * idx + 2`.
    ///
    /// The length of `tests` must be `N`, rounded up to the next power of 2, minus one.
    tests: Box<[A]>,
    /// Axis-aligned bounding boxes containing the set of afforded points for each cell.
    aabbs: Box<[Aabb<A, K>]>,
    /// Indexes for the starts of the affordance buffer subsequence of `points` corresponding to
    /// each leaf cell in the tree.
    /// This buffer is padded with one extra `I` at the end holding the total length of the
    /// affordance buffer, for the sake of branchless computation.
    starts: Box<[I]>,
    /// The sets of afforded points for each cell.
    ///
    /// There is one buffer per coordinate axis; element `i` of buffer `k` holds the `k`-th
    /// coordinate of `L` afforded points.
    afforded: [Box<[MySimd<A, L>]>; K],
    /// The radius of every point in the cloud; it is added to the radius of each query.
    r_point: A,
}
416
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[doc(hidden)]
/// A prismatic bounding volume.
///
/// The SIMD query path reads `lo` and then `hi` through raw pointers as `2 * K` consecutive
/// values, so the field order matters together with `#[repr(C)]`.
pub struct Aabb<A, const K: usize> {
    /// The lower bound on the volume.
    pub lo: [A; K],
    /// The upper bound on the volume.
    pub hi: [A; K],
}
427
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
/// The errors which can occur when calling [`Capt::try_new`].
///
/// The same errors are returned by [`Capt::try_with_point_radius`].
pub enum NewCaptError {
    /// There were too many points in the provided cloud to be represented without integer
    /// overflow.
    ///
    /// This is reported when an offset into the affordance buffer does not fit in the index
    /// type `I`.
    TooManyPoints,
    /// At least one of the points had a non-finite value.
    NonFinite,
}
438
439impl<A, I, const K: usize, const L: usize> Capt<K, L, A, I>
440where
441 A: Axis,
442 I: Index,
443 Align<L>: Alignment,
444{
445 /// Construct a new CAPT containing all the points in `points`.
446 ///
447 /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the
448 /// radius of the balls which will be queried against the tree.
449 ///
450 /// # Panics
451 ///
452 /// This function will panic if there are too many points in the tree to be addressed by `I`, or
453 /// if any points contain non-finite non-real value. This can even be the case if there are
454 /// fewer points in `points` than can be addressed by `I` as the CAPT may duplicate points
455 /// for efficiency.
456 ///
457 /// # Examples
458 ///
459 /// ```
460 /// let points = [[0.0]];
461 ///
462 /// let capt = capt::Capt::<1>::new(&points, (0.0, f32::INFINITY));
463 ///
464 /// assert!(capt.collides(&[1.0], 1.5));
465 /// assert!(!capt.collides(&[1.0], 0.5));
466 /// ```
467 ///
468 /// If there are too many points in `points`, this could cause a panic!
469 ///
470 /// ```rust,should_panic
471 /// let points = [[0.0]; 256];
472 ///
473 /// // note that we are using `u8` as our index type
474 /// let capt = capt::Capt::<1, 8, f32, u8>::new(&points, (0.0, f32::INFINITY));
475 /// ```
476 pub fn new(points: &[[A; K]], r_range: (A, A)) -> Self {
477 Self::try_new(points, r_range)
478 .expect("index type I must be able to support all points in CAPT during construction")
479 }
480
481 /// Construct a new CAPT containing all the points in `points` with a point radius `r_point`.
482 ///
483 /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the
484 /// radius of the balls which will be queried against the tree.
485 ///
486 /// # Panics
487 ///
488 /// This function will panic if there are too many points in the tree to be addressed by `I`, or
489 /// if any points contain non-finite non-real value. This can even be the case if there are
490 /// fewer points in `points` than can be addressed by `I` as the CAPT may duplicate points
491 /// for efficiency.
492 ///
493 /// # Examples
494 ///
495 /// ```
496 /// let points = [[0.0]];
497 ///
498 /// let capt = capt::Capt::<1>::with_point_radius(&points, (0.0, f32::INFINITY), 0.2);
499 ///
500 /// assert!(capt.collides(&[1.0], 1.5));
501 /// assert!(!capt.collides(&[1.0], 0.5));
502 /// ```
503 ///
504 /// If there are too many points in `points`, this could cause a panic!
505 ///
506 /// ```rust,should_panic
507 /// let points = [[0.0]; 256];
508 ///
509 /// // note that we are using `u8` as our index type
510 /// let capt = capt::Capt::<1, 8, f32, u8>::with_point_radius(&points, (0.0, f32::INFINITY), 0.2);
511 /// ```
512 pub fn with_point_radius(points: &[[A; K]], r_range: (A, A), r_point: A) -> Self {
513 Self::try_with_point_radius(points, r_range, r_point)
514 .expect("index type I must be able to support all points in CAPT during construction")
515 }
516
517 /// Construct a new CAPT containing all the points in `points`, checking for index overflow.
518 ///
519 /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the
520 /// radius of the balls which will be queried against the tree.
521 ///
522 /// # Errors
523 ///
524 /// This function will return `Err(NewCaptError::TooManyPoints)` if there are too many points to
525 /// be indexed by `I`. It will return `Err(NewCaptError::NonFinite)` if any element of
526 /// `points` is non-finite.
527 ///
528 /// # Examples
529 ///
530 /// Unwrapping the output from this function is equivalent to calling [`Capt::new`].
531 ///
532 /// ```
533 /// let points = [[0.0]];
534 ///
535 /// let capt = capt::Capt::<1>::try_new(&points, (0.0, f32::INFINITY)).unwrap();
536 /// ```
537 ///
538 /// In failure, we get an `Err`.
539 ///
540 /// ```
541 /// let points = [[0.0]; 256];
542 ///
543 /// // note that we are using `u8` as our index type
544 /// let opt = capt::Capt::<1, 8, f32, u8>::try_new(&points, (0.0, f32::INFINITY));
545 ///
546 /// assert!(opt.is_err());
547 /// ```
548 pub fn try_new(points: &[[A; K]], r_range: (A, A)) -> Result<Self, NewCaptError> {
549 Self::try_with_point_radius(points, r_range, A::ZERO)
550 }
551
552 /// Construct a new CAPT containing all the points in `points` with a point radius `r_point`,
553 /// checking for index overflow.
554 ///
555 /// `r_range` is a `(minimum, maximum)` pair containing the lower and upper bound on the
556 /// radius of the balls which will be queried against the tree.
557 ///
558 /// # Errors
559 ///
560 /// This function will return `Err(NewCaptError::TooManyPoints)` if there are too many points to
561 /// be indexed by `I`. It will return `Err(NewCaptError::NonFinite)` if any element of
562 /// `points` is non-finite.
563 ///
564 /// # Examples
565 ///
566 /// Unwrapping the output from this function is equivalent to calling
567 /// [`Capt::with_point_radius`].
568 ///
569 /// ```
570 /// let points = [[0.0]];
571 ///
572 /// let capt = capt::Capt::<1>::try_with_point_radius(&points, (0.0, f32::INFINITY), 0.01).unwrap();
573 /// ```
574 ///
575 /// In failure, we get an `Err`.
576 ///
577 /// ```
578 /// let points = [[0.0]; 256];
579 ///
580 /// // note that we are using `u8` as our index type
581 /// let opt =
582 /// capt::Capt::<1, 8, f32, u8>::try_with_point_radius(&points, (0.0, f32::INFINITY), 0.01);
583 ///
584 /// assert!(opt.is_err());
585 /// ```
586 pub fn try_with_point_radius(
587 points: &[[A; K]],
588 r_range: (A, A),
589 r_point: A,
590 ) -> Result<Self, NewCaptError> {
591 let n2 = points.len().next_power_of_two();
592
593 if points.iter().any(|p| p.iter().any(|x| !x.is_finite())) {
594 return Err(NewCaptError::NonFinite);
595 }
596
597 let mut tests = vec![A::INFINITY; n2 - 1].into_boxed_slice();
598
599 // hack: just pad with infinity to make it a power of 2
600 let mut points2 = vec![[A::INFINITY; K]; n2].into_boxed_slice();
601 points2[..points.len()].copy_from_slice(points);
602 // hack - reduce number of reallocations by allocating a lot of points from the start
603 let mut afforded = array::from_fn(|_| Vec::with_capacity(n2 * 100));
604 let mut starts = vec![I::ZERO; n2 + 1].into_boxed_slice();
605
606 let mut aabbs = vec![
607 Aabb {
608 lo: [A::NEG_INFINITY; K],
609 hi: [A::INFINITY; K],
610 };
611 n2
612 ]
613 .into_boxed_slice();
614
615 unsafe {
616 // SAFETY: We tested that `points` contains no `NaN` values.
617 Self::new_help(
618 &mut points2,
619 &mut tests,
620 &mut aabbs,
621 &mut afforded,
622 &mut starts,
623 0,
624 0,
625 r_range,
626 Vec::new(),
627 Aabb::ALL,
628 )?;
629 }
630
631 Ok(Self {
632 tests,
633 starts,
634 afforded: afforded.map(Vec::into_boxed_slice),
635 aabbs,
636 r_point,
637 })
638 }
639
640 #[allow(clippy::too_many_arguments, clippy::too_many_lines)]
641 /// # Safety
642 ///
643 /// This function will contain undefined behavior if `points` contains any `NaN` values.
644 unsafe fn new_help(
645 points: &mut [[A; K]],
646 tests: &mut [A],
647 aabbs: &mut [Aabb<A, K>],
648 afforded: &mut [Vec<MySimd<A, L>>; K],
649 starts: &mut [I],
650 k: usize,
651 i: usize,
652 r_range: (A, A),
653 in_range: Vec<[A; K]>,
654 cell: Aabb<A, K>,
655 ) -> Result<(), NewCaptError> {
656 unsafe {
657 let rsq_min = r_range.0.square();
658 if let [rep] = *points {
659 let z = i - tests.len();
660 let aabb = &mut aabbs[z];
661 *aabb = Aabb { lo: rep, hi: rep };
662 if rep[0].is_finite() {
663 // lanes for afforded points
664 let mut news = [[A::INFINITY; L]; K];
665 for k in 0..K {
666 news[k][0] = rep[k];
667 }
668
669 // index into the current lane
670 let mut j = 1;
671
672 // populate affordance buffer if the representative doesn't cover everything
673 if !cell.contained_by_ball(&rep, rsq_min) {
674 for ak in afforded.iter_mut() {
675 ak.reserve(ak.len() + in_range.len() / L);
676 }
677 for p in in_range {
678 aabb.insert(&p);
679
680 // start a new lane if it's full
681 if j == L {
682 for k in 0..K {
683 afforded[k].push(MySimd {
684 data: news[k],
685 _align: Align::NEW,
686 });
687 }
688 j = 0;
689 }
690
691 // add this point to the lane
692 for k in 0..K {
693 news[k][j] = p[k];
694 }
695
696 j += 1;
697 }
698 }
699
700 // fill out the last lane with infinities
701 for k in 0..K {
702 afforded[k].push(MySimd {
703 data: news[k],
704 _align: Align::NEW,
705 });
706 }
707 }
708
709 starts[z + 1] = afforded[0]
710 .len()
711 .try_into()
712 .map_err(|_| NewCaptError::TooManyPoints)?;
713 return Ok(());
714 }
715
716 let test = median_partition(points, k);
717 tests[i] = test;
718
719 let (lhs, rhs) = points.split_at_mut(points.len() / 2);
720 let (lo_vol, hi_vol) = cell.split(test, k);
721
722 let lo_too_small = distsq(lo_vol.lo, lo_vol.hi) <= rsq_min;
723 let hi_too_small = distsq(hi_vol.lo, hi_vol.hi) <= rsq_min;
724
725 // retain only points which might be in the affordance buffer for the split-out cells
726 let (lo_afford, hi_afford) = match (lo_too_small, hi_too_small) {
727 (false, false) => {
728 let mut lo_afford = in_range;
729 let mut hi_afford = lo_afford.clone();
730 lo_afford.retain(|pt| pt[k] <= test + r_range.1);
731 lo_afford.extend(rhs.iter().filter(|pt| pt[k] <= test + r_range.1));
732 hi_afford.retain(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]);
733 hi_afford.extend(
734 lhs.iter()
735 .filter(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]),
736 );
737
738 (lo_afford, hi_afford)
739 }
740 (false, true) => {
741 let mut lo_afford = in_range;
742 lo_afford.retain(|pt| pt[k] <= test + r_range.1);
743 lo_afford.extend(rhs.iter().filter(|pt| pt[k] <= test + r_range.1));
744
745 (lo_afford, Vec::new())
746 }
747 (true, false) => {
748 let mut hi_afford = in_range;
749 hi_afford.retain(|pt| test - r_range.1 <= pt[k]);
750 hi_afford.extend(
751 lhs.iter()
752 .filter(|pt| pt[k].is_finite() && test - r_range.1 <= pt[k]),
753 );
754
755 (Vec::new(), hi_afford)
756 }
757 (true, true) => (Vec::new(), Vec::new()),
758 };
759
760 let next_k = (k + 1) % K;
761 Self::new_help(
762 lhs,
763 tests,
764 aabbs,
765 afforded,
766 starts,
767 next_k,
768 2 * i + 1,
769 r_range,
770 lo_afford,
771 lo_vol,
772 )?;
773 Self::new_help(
774 rhs,
775 tests,
776 aabbs,
777 afforded,
778 starts,
779 next_k,
780 2 * i + 2,
781 r_range,
782 hi_afford,
783 hi_vol,
784 )?;
785
786 Ok(())
787 }
788 }
789
790 #[must_use]
791 /// Determine whether a point in this tree is within a distance of `radius` to `center`.
792 ///
793 /// Note that this function will accept query radii outside of the range `r_range` passed to the
794 /// construction for this CAPT in [`Capt::new`] or [`Capt::try_new`]. However, if the query
795 /// radius is outside this range, the tree may erroneously return `false` (that is, erroneously
796 /// report non-collision).
797 ///
798 /// # Examples
799 ///
800 /// ```
801 /// let points = [[0.0; 3], [1.0; 3], [0.1, 0.5, 1.0]];
802 /// let capt = capt::Capt::<3>::new(&points, (0.0, 1.0));
803 ///
804 /// assert!(capt.collides(&[1.1; 3], 0.2));
805 /// assert!(!capt.collides(&[2.0; 3], 1.0));
806 ///
807 /// // no guarantees about what this is, since the radius is greater than the construction range
808 /// println!(
809 /// "collision check result is {:?}",
810 /// capt.collides(&[100.0; 3], 100.0)
811 /// );
812 /// ```
813 pub fn collides(&self, center: &[A; K], mut radius: A) -> bool {
814 radius = radius + self.r_point;
815 // forward pass through the tree
816 let mut test_idx = 0;
817 let mut k = 0;
818 for _ in 0..self.tests.len().trailing_ones() {
819 test_idx = 2 * test_idx
820 + 1
821 + usize::from(unsafe { *self.tests.get_unchecked(test_idx) } <= center[k]);
822 k = (k + 1) % K;
823 }
824
825 // retrieve affordance buffer location
826 let rsq = radius.square();
827 let i = test_idx - self.tests.len();
828 let aabb = unsafe { self.aabbs.get_unchecked(i) };
829 if aabb.closest_distsq_to(center) > rsq {
830 return false;
831 }
832
833 let mut range = unsafe {
834 // SAFETY: The conversion worked the first way.
835 self.starts[i].try_into().unwrap_unchecked()
836 ..self.starts[i + 1].try_into().unwrap_unchecked()
837 };
838
839 // check affordance buffer
840 range.any(|i| {
841 (0..L).any(|j| {
842 let mut aff_pt = [A::INFINITY; K];
843 for (ak, sk) in aff_pt.iter_mut().zip(&self.afforded) {
844 *ak = sk[i].data[j];
845 }
846 distsq(aff_pt, *center) <= rsq
847 })
848 })
849 }
850
851 #[must_use]
852 #[doc(hidden)]
853 /// Get the total memory used (stack + heap) by this structure, measured in bytes.
854 /// This function should not be considered stable; it is only used internally for benchmarks.
855 pub const fn memory_used(&self) -> usize {
856 size_of::<Self>()
857 + K * self.afforded[0].len() * size_of::<A>()
858 + self.starts.len() * size_of::<I>()
859 + self.tests.len() * size_of::<I>()
860 + self.aabbs.len() * size_of::<Aabb<A, K>>()
861 }
862
863 #[must_use]
864 #[doc(hidden)]
865 #[allow(clippy::cast_precision_loss)]
866 /// Get the average number of affordances per point.
867 /// This function should not be considered stable; it is only used internally for benchmarks.
868 pub fn affordance_size(&self) -> f64 {
869 self.afforded.len() as f64 / (self.tests.len() + 1) as f64
870 }
871}
872
#[allow(clippy::mismatching_type_param_order)]
#[cfg(feature = "simd")]
impl<A, I, const K: usize, const L: usize> Capt<K, L, A, I>
where
    I: IndexSimd,
    A: Mul<Output = A>,
    Align<L>: Alignment,
{
    #[must_use]
    /// Determine whether any sphere in the list of provided spheres intersects a point in this
    /// tree.
    ///
    /// `centers[k]` holds the `k`-th coordinate of each of the `L` sphere centers, and `radii`
    /// holds the radius of each sphere. The point radius of the tree is added to every radius
    /// before testing.
    ///
    /// # Examples
    ///
    /// ```
    /// #![feature(portable_simd)]
    /// use std::simd::Simd;
    ///
    /// let points = [[1.0, 2.0], [1.1, 1.1]];
    ///
    /// let centers = [
    ///     Simd::from_array([1.0, 1.1, 1.2, 1.3]), // x-positions
    ///     Simd::from_array([1.0, 1.1, 1.2, 1.3]), // y-positions
    /// ];
    /// let radii = Simd::splat(0.05);
    ///
    /// let tree = capt::Capt::<2, 4, f32, u32>::new(&points, (0.0, 0.1));
    ///
    /// println!("{tree:?}");
    ///
    /// assert!(tree.collides_simd(&centers, radii));
    /// ```
    pub fn collides_simd(&self, centers: &[Simd<A, L>; K], mut radii: Simd<A, L>) -> bool
    where
        LaneCount<L>: SupportedLaneCount,
        Simd<A, L>: AxisSimd<L>,
        A: AxisSimdElement,
    {
        radii += Simd::splat(self.r_point);
        // per-lane position of the leaf cell reached by each center
        let zs = forward_pass_simd(&self.tests, centers);

        // lanes whose ball still overlaps the bounding box of its cell
        let mut inbounds = Mask::splat(true);

        // point at the first value of each lane's `Aabb` (`lo[0]`)
        let mut aabb_ptrs = Simd::splat(self.aabbs.as_ptr()).wrapping_offset(zs).cast();

        unsafe {
            // first `K` values of the box: lower bounds, require `lo - r <= center`
            for center in centers {
                inbounds &= Simd::<A, L>::cast_mask(
                    (Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY))
                        - radii)
                        .simd_le(*center),
                );
                aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1));
            }
            // next `K` values of the box: upper bounds, require `hi >= center - r`
            for center in centers {
                inbounds &= Simd::<A, L>::cast_mask(
                    Simd::gather_select_ptr(aabb_ptrs, inbounds, Simd::splat(A::NEG_INFINITY))
                        .simd_ge(*center - radii),
                );
                aabb_ptrs = aabb_ptrs.wrapping_add(Simd::splat(1));
            }
        }
        if !inbounds.any() {
            return false;
        }

        // retrieve start/end pointers for the affordance buffer
        let start_ptrs = Simd::splat(self.starts.as_ptr()).wrapping_offset(zs);
        let starts = unsafe { I::to_simd_usize_unchecked(Simd::gather_ptr(start_ptrs)) }.to_array();
        let ends = unsafe {
            I::to_simd_usize_unchecked(Simd::gather_ptr(start_ptrs.wrapping_add(Simd::splat(1))))
        }
        .to_array();

        // test each surviving query sphere against every lane of its cell's affordance buffer
        starts
            .into_iter()
            .zip(ends)
            .zip(inbounds.to_array())
            .enumerate()
            .filter_map(|(j, r)| r.1.then_some((j, r.0)))
            .any(|(j, (start, end))| {
                // broadcast the `j`-th center and radius across all lanes
                let mut n_center = [Simd::splat(A::ZERO); K];
                for k in 0..K {
                    n_center[k] = Simd::splat(centers[k][j]);
                }
                let rs = Simd::splat(radii[j]);
                let rs_sq = rs * rs;
                (start..end).any(|i| {
                    let mut dists_sq = Simd::splat(A::ZERO);
                    #[allow(clippy::needless_range_loop)]
                    for k in 0..K {
                        // SAFETY: `i` lies between two entries of `starts`, which are offsets
                        // into the affordance buffers. `[A; L]` has the same size and element
                        // layout as `Simd<A, L>` but possibly a smaller alignment, so the value
                        // is read with `read_unaligned` instead of a plain dereference.
                        let vals: Simd<A, L> = unsafe {
                            ptr::from_ref(&self.afforded[k].get_unchecked(i).data)
                                .cast::<Simd<A, L>>()
                                .read_unaligned()
                        };
                        let diff = vals - n_center[k];
                        dists_sq += diff * diff;
                    }
                    Simd::<A, L>::mask_any(dists_sq.simd_le(rs_sq))
                })
            })
    }
}
975
976fn distsq<A: Axis, const K: usize>(a: [A; K], b: [A; K]) -> A {
977 let mut total = A::ZERO;
978 for i in 0..K {
979 total = total + (a[i] - b[i]).square();
980 }
981 total
982}
983
984impl<A, const K: usize> Aabb<A, K>
985where
986 A: Axis,
987{
988 const ALL: Self = Self {
989 lo: [A::NEG_INFINITY; K],
990 hi: [A::INFINITY; K],
991 };
992
993 /// Split this volume by a test plane with value `test` along `dim`.
994 const fn split(mut self, test: A, dim: usize) -> (Self, Self) {
995 let mut rhs = self;
996 self.hi[dim] = test;
997 rhs.lo[dim] = test;
998
999 (self, rhs)
1000 }
1001
1002 fn contained_by_ball(&self, center: &[A; K], rsq: A) -> bool {
1003 let mut dist = A::ZERO;
1004
1005 #[allow(clippy::needless_range_loop)]
1006 for k in 0..K {
1007 let lo_diff = (self.lo[k] - center[k]).square();
1008 let hi_diff = (self.hi[k] - center[k]).square();
1009
1010 dist = dist + if lo_diff < hi_diff { hi_diff } else { lo_diff };
1011 }
1012
1013 dist <= rsq
1014 }
1015
1016 #[doc(hidden)]
1017 pub fn closest_distsq_to(&self, pt: &[A; K]) -> A {
1018 let mut dist = A::ZERO;
1019
1020 #[allow(clippy::needless_range_loop)]
1021 for d in 0..K {
1022 let clamped = clamp(pt[d], self.lo[d], self.hi[d]);
1023 dist = dist + (pt[d] - clamped).square();
1024 }
1025
1026 dist
1027 }
1028
1029 fn insert(&mut self, point: &[A; K]) {
1030 self.lo
1031 .iter_mut()
1032 .zip(&mut self.hi)
1033 .zip(point)
1034 .for_each(|((l, h), &x)| {
1035 if *l > x {
1036 *l = x;
1037 }
1038 if x > *h {
1039 *h = x;
1040 }
1041 });
1042 }
1043}
1044
1045#[inline]
1046/// Calculate the "true" median (halfway between two midpoints) and partition `points` about said
1047/// median along axis `d`.
1048///
1049/// # Safety
1050///
1051/// This function will result in undefined behavior if `points` contains any `NaN` values.
1052unsafe fn median_partition<A: Axis, const K: usize>(points: &mut [[A; K]], k: usize) -> A {
1053 unsafe {
1054 let (lh, med_hi, _) = points.select_nth_unstable_by(points.len() / 2, |a, b| {
1055 a[k].partial_cmp(&b[k]).unwrap_unchecked()
1056 });
1057 let med_lo = lh
1058 .iter_mut()
1059 .map(|p| p[k])
1060 .max_by(|a, b| a.partial_cmp(b).unwrap_unchecked())
1061 .unwrap();
1062 A::in_between(med_lo, med_hi[k])
1063 }
1064}
1065
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng, rngs::SmallRng};

    use super::*;

    /// Building a tree from a small cloud and debug-printing it must not panic.
    #[test]
    fn build_simple() {
        let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]];
        let t = Capt::<2>::new(&points, (0.0, 0.2));
        println!("{t:?}");
    }

    /// A query sphere that reaches the first point of the cloud must be reported as colliding.
    #[test]
    fn exact_query_single() {
        let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]];
        let t = Capt::<2>::new(&points, (0.0, 0.2));

        println!("{t:?}");

        let q0 = [0.0, -0.01];
        assert!(t.collides(&q0, 0.12));
    }

    /// A small query sphere centered very close to a cloud point must collide.
    #[test]
    fn another_one() {
        let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]];
        let t = Capt::<2>::new(&points, (0.0, 0.2));

        println!("{t:?}");

        let q0 = [0.003_265_380_9, 0.106_527_805];
        assert!(t.collides(&q0, 0.02));
    }

    /// Three-dimensional cloud (including a duplicated point): one query radius that reaches
    /// the origin point and one that does not.
    #[test]
    fn three_d() {
        let points = [
            [0.0; 3],
            [0.1, -1.1, 0.5],
            [-0.2, -0.3, 0.25],
            [0.1, -1.1, 0.5],
        ];

        let t = Capt::<3>::new(&points, (0.0, 0.2));

        println!("{t:?}");
        assert!(t.collides(&[0.0, 0.1, 0.0], 0.11));
        assert!(!t.collides(&[0.0, 0.1, 0.0], 0.05));
    }

    /// Compare the tree against a brute-force scan on 10,000 seeded random query centers.
    #[test]
    fn fuzz() {
        const R: f32 = 0.04;
        let points = [[0.0, 0.1], [0.4, -0.2], [-0.2, -0.1]];
        let mut rng = SmallRng::seed_from_u64(1234);
        let t = Capt::<2>::new(&points, (0.0, R));

        for _ in 0..10_000 {
            let p = [rng.random_range(-1.0..1.0), rng.random_range(-1.0..1.0)];
            let collides = points.iter().any(|a| distsq(*a, p) <= R * R);
            // Report the offending query only on failure instead of printing every iteration.
            assert_eq!(
                collides,
                t.collides(&p, R),
                "tree disagrees with brute force at {p:?} (brute force says {collides})"
            );
        }
    }

    /// Query at radius 1.0 against a tree built with a very narrow radius range around 1.0.
    ///
    /// The original author noted that this case was expected to fail, yet it passes.
    /// NOTE(review): the reason for that expectation is not recorded; confirm before relying on
    /// this test as a regression guard.
    #[test]
    fn weird_bounds() {
        // A plain radius (not squared): `Capt::new` and `collides` both take radii.
        const R: f32 = 1.0;
        let points = [
            [-1.0, 0.0],
            [0.001, 0.0],
            [0.0, 0.5],
            [-1.0, 10.0],
            [-2.0, 10.0],
            [-3.0, 10.0],
            [-0.5, 0.0],
            [-11.0, 1.0],
            [-1.0, -0.5],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
            [6.0, 6.0],
            [7.0, 7.0],
        ];
        let r_range = (R - f32::EPSILON, R + f32::EPSILON);
        let t = Capt::<2>::new(&points, r_range);
        println!("{t:?}");

        assert!(t.collides(&[-0.001, -0.2], R));
    }

    /// `median_partition` must return the midpoint of the two middle values and leave the
    /// slice partitioned about it.
    #[test]
    #[allow(clippy::float_cmp)] // the expected median is exactly representable
    fn does_it_partition() {
        let mut points = vec![[1.0], [2.0], [1.5], [2.1], [-0.5]];
        // SAFETY: none of the points above contain NaN.
        let median = unsafe { median_partition(&mut points, 0) };
        assert_eq!(median, 1.25);
        for p0 in &points[..points.len() / 2] {
            assert!(p0[0] <= median);
        }

        for p0 in &points[points.len() / 2..] {
            assert!(p0[0] >= median);
        }
    }

    /// With a point radius of 0.5, a query collides once it reaches the inflated point and
    /// misses when it falls short of it.
    #[test]
    fn point_radius() {
        let points = [[0.0, 0.0], [0.0, 1.0]];
        let r_range = (0.0, 1.0);

        let capt: Capt<_, 8, _, u32> = Capt::with_point_radius(&points, r_range, 0.5);
        assert!(capt.collides(&[0.6, 0.0], 0.2));
        assert!(!capt.collides(&[0.6, 0.0], 0.05));
    }
}