particular/compute_method/
storage.rs

1use crate::compute_method::{
2    math::{AsPrimitive, BitAnd, CmpNe, Float, FloatVector, FromPrimitive, Sum, Zero, SIMD},
3    tree::{
4        partition::{BoundingBox, SubDivide},
5        Node, NodeID, Orthtree,
6    },
7    ComputeMethod,
8};
9
/// An object in space reduced to a position and a mass.
///
/// With `#[repr(C)]`, the fields are laid out in declaration order: `position` first, then `mass`.
#[repr(C)]
#[derive(Debug, Default, Clone, Copy)]
pub struct PointMass<V, S> {
    /// Where the object is located.
    pub position: V,
    /// How much mass the object carries.
    pub mass: S,
}
19
20impl<V: Zero, S: Zero> PointMass<V, S> {
21    /// [`PointMass`] with position and mass set to [`Zero::ZERO`].
22    pub const ZERO: Self = PointMass::new(V::ZERO, S::ZERO);
23}
24
25impl<V, S> PointMass<V, S> {
26    /// Creates a new [`PointMass`] with the given position and mass.
27    #[inline]
28    pub const fn new(position: V, mass: S) -> Self {
29        Self { position, mass }
30    }
31
32    /// Creates a new [`PointMass`] with the given lanes of positions and masses.
33    #[inline]
34    pub fn new_lane(position: V::Lane, mass: S::Lane) -> Self
35    where
36        V: SIMD,
37        S: SIMD,
38    {
39        Self::new(V::new_lane(position), S::new_lane(mass))
40    }
41
42    /// Returns the [`PointMass`] corresponding to the center of mass and total mass of the given
43    /// slice of point-masses.
44    #[inline]
45    pub fn new_com(data: &[Self]) -> Self
46    where
47        V: FloatVector<Float = S> + Sum + Copy,
48        S: Float + FromPrimitive<usize> + Sum + Copy,
49    {
50        let tot = data.iter().map(|p| p.mass).sum();
51        let com = if tot == S::ZERO {
52            data.iter().map(|p| p.position).sum::<V>() / data.len().as_()
53        } else {
54            data.iter().map(|p| p.position * (p.mass / tot)).sum()
55        };
56
57        Self::new(com, tot)
58    }
59
60    /// Creates a new [`PointMass`] with all lanes set to the given position and mass.
61    #[inline]
62    pub fn splat_lane(position: V::Element, mass: S::Element) -> Self
63    where
64        V: SIMD,
65        S: SIMD,
66    {
67        Self::new(V::splat(position), S::splat(mass))
68    }
69
70    /// Returns a [`SIMD`] point-masses from a slice of [`SIMD::Element`] point-masses.
71    #[inline]
72    pub fn slice_to_lane<const L: usize, T, E>(slice: &[PointMass<T, E>]) -> Self
73    where
74        T: Clone + Zero,
75        E: Clone + Zero,
76        V: SIMD<Lane = [T; L], Element = T>,
77        S: SIMD<Lane = [E; L], Element = E>,
78    {
79        let mut lane = [PointMass::ZERO; L];
80        lane[..slice.len()].clone_from_slice(slice);
81        Self::new_lane(lane.clone().map(|p| p.position), lane.map(|p| p.mass))
82    }
83
84    /// Returns an iterator of [`SIMD`] point-masses from a slice of [`SIMD::Element`] point-masses.
85    #[inline]
86    pub fn slice_to_lanes<'a, const L: usize, T, E>(
87        slice: &'a [PointMass<T, E>],
88    ) -> impl Iterator<Item = Self> + 'a
89    where
90        T: Clone + Zero,
91        E: Clone + Zero,
92        V: SIMD<Lane = [T; L], Element = T> + 'a,
93        S: SIMD<Lane = [E; L], Element = E> + 'a,
94    {
95        slice.chunks(L).map(Self::slice_to_lane)
96    }
97
98    /// Returns true if the mass is zero.
99    #[inline]
100    pub fn is_massless(&self) -> bool
101    where
102        S: PartialEq + Zero,
103    {
104        self.mass == S::ZERO
105    }
106
107    /// Returns false if the mass is zero.
108    #[inline]
109    pub fn is_massive(&self) -> bool
110    where
111        S: PartialEq + Zero,
112    {
113        self.mass != S::ZERO
114    }
115
116    /// Computes the gravitational force exerted on the current point-mass using the given position
117    /// and mass. This method is optimised in the case where `V` and `S` are scalar types.
118    ///
119    /// If the position of the current point-mass is guaranteed to be different from the given
120    /// position, this computation can be more efficient with `CHECK_ZERO` set to false.
121    #[inline]
122    pub fn force_scalar<const CHECK_ZERO: bool>(&self, position: V, mass: S, softening: S) -> V
123    where
124        V: FloatVector<Float = S> + Copy,
125        S: Float + Copy,
126    {
127        let dir = position - self.position;
128        let norm = dir.norm_squared();
129        let norm_s = norm + (softening * softening);
130
131        // Branch removed by the compiler when `CHECK_ZERO` is false.
132        if CHECK_ZERO && norm == S::ZERO {
133            dir
134        } else {
135            dir * (mass / (norm_s * norm_s.sqrt()))
136        }
137    }
138
139    /// Computes the gravitational force exerted on the current point-mass using the given position
140    /// and mass. This method is optimised in the case where `V` and `S` are simd types.
141    ///
142    /// If the position of the current point-mass is guaranteed to be different from the given
143    /// position, this computation can be more efficient with `CHECK_ZERO` set to false.
144    #[inline]
145    pub fn force_simd<const CHECK_ZERO: bool>(&self, position: V, mass: S, softening: S) -> V
146    where
147        V: FloatVector<Float = S> + Copy,
148        S: Float + BitAnd<Output = S> + CmpNe<Output = S> + Copy,
149    {
150        let dir = position - self.position;
151        let norm = dir.norm_squared();
152        let norm_s = norm + (softening * softening);
153        let f = mass * (norm_s * norm_s * norm_s).rsqrt();
154
155        // Branch removed by the compiler when `CHECK_ZERO` is false.
156        if CHECK_ZERO {
157            dir * f.bitand(norm.cmp_ne(S::ZERO))
158        } else {
159            dir * f
160        }
161    }
162
163    /// Computes the gravitational acceleration exerted on the current point-mass by the specified
164    /// node of the given [`Orthtree`] following the [Barnes-Hut](https://en.wikipedia.org/wiki/Barnes%E2%80%93Hut_simulation)
165    /// approximation with the given `theta` parameter, provided `V` and `S` are scalar types.
166    #[inline]
167    pub fn acceleration_tree<const X: usize, const D: usize>(
168        &self,
169        tree: &Orthtree<X, D, S, PointMass<V, S>>,
170        node: Option<NodeID>,
171        theta: S,
172        softening: S,
173    ) -> V
174    where
175        V: FloatVector<Float = S> + Copy + Sum,
176        S: Float + PartialOrd + Copy,
177    {
178        let mut acceleration = V::ZERO;
179
180        let estimate = X * (tree.nodes.len() as f32).ln() as usize; // TODO: find a proper estimate
181        let mut stack = Vec::with_capacity(estimate);
182        stack.push(node);
183
184        while let Some(node) = stack.pop() {
185            let id = match node {
186                Some(id) => id as usize,
187                None => continue,
188            };
189
190            let p2 = tree.data[id];
191            let dir = p2.position - self.position;
192            let norm = dir.norm_squared();
193
194            if norm == S::ZERO {
195                continue;
196            }
197
198            match tree.nodes[id] {
199                Node::Internal(node) if theta < node.bbox.width() / norm.sqrt() => {
200                    stack.extend(node.orthant);
201                }
202                _ => {
203                    let norm_s = norm + (softening * softening);
204                    acceleration += dir * (p2.mass / (norm_s * norm_s.sqrt()));
205                }
206            }
207        }
208
209        acceleration
210    }
211}
212
213/// Flexible, copyable storage with references to affected particles and a generic massive storage.
214#[derive(Debug)]
215pub struct ParticleSystem<'p, V, S, T: ?Sized> {
216    /// Particles for which the acceleration is computed.
217    pub affected: &'p [PointMass<V, S>],
218    /// Particles responsible for the acceleration exerted on the `affected` particles, in a
219    /// storage `T`.
220    pub massive: &'p T,
221}
222
223impl<V, S, T: ?Sized> Clone for ParticleSystem<'_, V, S, T> {
224    fn clone(&self) -> Self {
225        *self
226    }
227}
228
229impl<V, S, T: ?Sized> Copy for ParticleSystem<'_, V, S, T> {}
230
231impl<'p, V, S, T: ?Sized> ParticleSystem<'p, V, S, T> {
232    /// Creates a new [`ParticleSystem`] with the given slice of particles and massive storage.
233    #[inline]
234    pub const fn with(affected: &'p [PointMass<V, S>], massive: &'p T) -> Self {
235        Self { affected, massive }
236    }
237}
238
239/// [`ParticleSystem`] with a slice of particles for the massive storage.
240pub type ParticleSliceSystem<'p, V, S> = ParticleSystem<'p, V, S, [PointMass<V, S>]>;
241
242/// Storage with particles in an [`Orthtree`] and its root.
243#[derive(Clone, Debug)]
244pub struct ParticleTree<const X: usize, const D: usize, V, S> {
245    root: Option<NodeID>,
246    tree: Orthtree<X, D, S, PointMass<V, S>>,
247}
248
249impl<const X: usize, const D: usize, V, S> ParticleTree<X, D, V, S> {
250    /// Returns the root of the [`Orthtree`].
251    #[inline]
252    pub const fn root(&self) -> Option<NodeID> {
253        self.root
254    }
255
256    /// Returns a reference to the [`Orthtree`].
257    #[inline]
258    pub const fn get(&self) -> &Orthtree<X, D, S, PointMass<V, S>> {
259        &self.tree
260    }
261}
262
263impl<const X: usize, const D: usize, V, S> From<&[PointMass<V, S>]> for ParticleTree<X, D, V, S>
264where
265    V: Copy + FloatVector<Float = S, Array = [S; D]>,
266    S: Copy + Float + Sum + PartialOrd + FromPrimitive<usize>,
267    BoundingBox<[S; D]>: SubDivide<Division = [BoundingBox<[S; D]>; X]>,
268{
269    #[inline]
270    fn from(slice: &[PointMass<V, S>]) -> Self {
271        let mut tree = Orthtree::with_capacity(slice.len());
272        let root = tree.build_node(slice, |p| p.position.into(), PointMass::new_com);
273
274        Self { root, tree }
275    }
276}
277
278/// [`ParticleSystem`] with a [`ParticleTree`] for the massive storage.
279pub type ParticleTreeSystem<'p, const X: usize, const D: usize, V, S> =
280    ParticleSystem<'p, V, S, ParticleTree<X, D, V, S>>;
281
282/// Storage inside of which the massive particles are placed before the massless ones.
283///
284/// Allows for easy optimisation of the computation of forces between massive and massless
285/// particles.
286#[derive(Clone, Debug)]
287pub struct ParticleOrdered<V, S> {
288    massive_len: usize,
289    particles: Vec<PointMass<V, S>>,
290}
291
292impl<V, S> ParticleOrdered<V, S> {
293    /// Creates a new [`ParticleOrdered`] with the given massive and massless particles.
294    #[inline]
295    pub fn with<I, U>(massive: I, massless: U) -> Self
296    where
297        S: PartialEq + Zero,
298        I: IntoIterator<Item = PointMass<V, S>>,
299        U: IntoIterator<Item = PointMass<V, S>>,
300    {
301        let particles = massive.into_iter().chain(massless).collect::<Vec<_>>();
302        let massive_len = particles
303            .iter()
304            .position(PointMass::is_massless)
305            .unwrap_or(particles.len());
306
307        Self {
308            massive_len,
309            particles,
310        }
311    }
312
313    /// Returns the number of stored massive particles.
314    #[inline]
315    pub const fn massive_len(&self) -> usize {
316        self.massive_len
317    }
318
319    /// Returns a reference to the massive particles.
320    #[inline]
321    pub fn massive(&self) -> &[PointMass<V, S>] {
322        &self.particles[..self.massive_len]
323    }
324
325    /// Returns a reference to the massless particles.
326    #[inline]
327    pub fn massless(&self) -> &[PointMass<V, S>] {
328        &self.particles[self.massive_len..]
329    }
330
331    /// Returns a reference to the particles.
332    #[inline]
333    pub fn particles(&self) -> &[PointMass<V, S>] {
334        &self.particles
335    }
336
337    /// Returns a mutable reference to the massive particles.
338    #[inline]
339    pub fn massive_mut(&mut self) -> &mut [PointMass<V, S>] {
340        &mut self.particles[..self.massive_len]
341    }
342
343    /// Returns a mutable reference to the massless particles.
344    #[inline]
345    pub fn massless_mut(&mut self) -> &mut [PointMass<V, S>] {
346        &mut self.particles[self.massive_len..]
347    }
348
349    /// Returns a mutable reference to the stored ordered particles.
350    #[inline]
351    pub fn particles_mut(&mut self) -> &mut [PointMass<V, S>] {
352        &mut self.particles
353    }
354}
355
356impl<V, S> From<&[PointMass<V, S>]> for ParticleOrdered<V, S>
357where
358    V: Clone,
359    S: Clone + PartialEq + Zero,
360{
361    #[inline]
362    fn from(particles: &[PointMass<V, S>]) -> Self {
363        Self::with(
364            particles.iter().filter(|p| p.is_massive()).cloned(),
365            particles.iter().filter(|p| p.is_massless()).cloned(),
366        )
367    }
368}
369
370/// Storage for particles which has a copy of the stored particles inside a [`ParticleOrdered`].
371#[derive(Clone, Debug)]
372pub struct ParticleReordered<'p, V, S> {
373    /// Original, unordered particles.
374    pub unordered: &'p [PointMass<V, S>],
375    ordered: ParticleOrdered<V, S>,
376}
377
378impl<V, S> ParticleReordered<'_, V, S> {
379    /// Returns a reference to the [`ParticleOrdered`].
380    #[inline]
381    pub const fn ordered(&self) -> &ParticleOrdered<V, S> {
382        &self.ordered
383    }
384
385    /// Returns the number of stored massive particles.
386    #[inline]
387    pub const fn massive_len(&self) -> usize {
388        self.ordered.massive_len()
389    }
390
391    /// Returns a reference to the massive particles.
392    #[inline]
393    pub fn massive(&self) -> &[PointMass<V, S>] {
394        self.ordered.massive()
395    }
396
397    /// Returns a reference to the massless particles.
398    #[inline]
399    pub fn massless(&self) -> &[PointMass<V, S>] {
400        self.ordered.massless()
401    }
402
403    /// Returns a reference to the stored ordered particles.
404    #[inline]
405    pub fn reordered(&self) -> &[PointMass<V, S>] {
406        self.ordered.particles()
407    }
408}
409
410impl<'p, V, S> From<&'p [PointMass<V, S>]> for ParticleReordered<'p, V, S>
411where
412    V: Clone,
413    S: Clone + Zero + PartialEq,
414{
415    #[inline]
416    fn from(affected: &'p [PointMass<V, S>]) -> Self {
417        Self {
418            unordered: affected,
419            ordered: ParticleOrdered::from(affected),
420        }
421    }
422}
423
424impl<V, S, C, O> ComputeMethod<&[PointMass<V, S>]> for C
425where
426    O: IntoIterator,
427    for<'a> C: ComputeMethod<ParticleSliceSystem<'a, V, S>, Output = O>,
428{
429    type Output = O;
430
431    #[inline]
432    fn compute(&mut self, slice: &[PointMass<V, S>]) -> Self::Output {
433        self.compute(ParticleSliceSystem {
434            affected: slice,
435            massive: slice,
436        })
437    }
438}
439
440impl<V, S, C, O> ComputeMethod<&ParticleOrdered<V, S>> for C
441where
442    O: IntoIterator,
443    for<'a> C: ComputeMethod<ParticleSliceSystem<'a, V, S>, Output = O>,
444{
445    type Output = O;
446
447    #[inline]
448    fn compute(&mut self, ordered: &ParticleOrdered<V, S>) -> Self::Output {
449        self.compute(ParticleSliceSystem {
450            affected: ordered.particles(),
451            massive: ordered.massive(),
452        })
453    }
454}
455
456impl<V, S, C, O> ComputeMethod<&ParticleReordered<'_, V, S>> for C
457where
458    O: IntoIterator,
459    for<'a> C: ComputeMethod<ParticleSliceSystem<'a, V, S>, Output = O>,
460{
461    type Output = O;
462
463    #[inline]
464    fn compute(&mut self, reordered: &ParticleReordered<V, S>) -> Self::Output {
465        self.compute(ParticleSliceSystem {
466            affected: reordered.unordered,
467            massive: reordered.massive(),
468        })
469    }
470}
471
472impl<const X: usize, const D: usize, V, S, C, O> ComputeMethod<ParticleSliceSystem<'_, V, S>> for C
473where
474    O: IntoIterator,
475    for<'a> C: ComputeMethod<ParticleTreeSystem<'a, X, D, V, S>, Output = O>,
476    V: Copy + FloatVector<Float = S, Array = [S; D]>,
477    S: Copy + Float + Sum + PartialOrd + FromPrimitive<usize>,
478    BoundingBox<[S; D]>: SubDivide<Division = [BoundingBox<[S; D]>; X]>,
479{
480    type Output = O;
481
482    #[inline]
483    fn compute(&mut self, system: ParticleSliceSystem<V, S>) -> Self::Output {
484        self.compute(ParticleTreeSystem {
485            affected: system.affected,
486            massive: &ParticleTree::from(system.massive),
487        })
488    }
489}