// phyz_gravity/nbody.rs

//! N-body gravity solver with Barnes-Hut tree (Layer 3).
//!
//! Implements:
//! - Naive O(N²) pairwise forces
//! - Barnes-Hut octree for O(N log N) approximation
//!
//! # Barnes-Hut Algorithm
//!
//! 1. Build an octree that stores the total mass and center of mass of each node.
//! 2. For each particle, traverse the tree:
//!    - If a node is far enough away (θ test), use its center-of-mass approximation.
//!    - Otherwise, recurse into its children.
//! 3. θ = s/d (cell size / distance); a larger θ is faster but less accurate.

use crate::{G, GravityParticle, GravitySolver};
use phyz_math::Vec3;

/// N-body gravity solver.
///
/// Holds the configuration shared by the direct pairwise summation and the
/// Barnes-Hut tree approximation.
#[derive(Debug, Clone)]
pub struct NBodySolver {
    /// Use the Barnes-Hut tree approximation instead of direct summation.
    pub use_tree: bool,
    /// Barnes-Hut opening angle parameter (cell size / distance threshold).
    pub theta: f64,
    /// Softening length to prevent singularities (m).
    pub softening: f64,
}

29impl NBodySolver {
30    /// Create a new N-body solver.
31    pub fn new() -> Self {
32        Self {
33            use_tree: false,
34            theta: 0.5,
35            softening: 1e-3,
36        }
37    }
38
39    /// Create with Barnes-Hut tree.
40    pub fn with_tree(theta: f64, softening: f64) -> Self {
41        Self {
42            use_tree: true,
43            theta,
44            softening,
45        }
46    }
47
48    /// Compute pairwise gravitational force (naive O(N²)).
49    pub fn compute_pairwise_forces(&self, particles: &mut [GravityParticle]) {
50        let n = particles.len();
51
52        // Reset forces
53        for p in particles.iter_mut() {
54            p.reset_force();
55        }
56
57        // Pairwise forces
58        for i in 0..n {
59            for j in i + 1..n {
60                let r = particles[j].x - particles[i].x;
61                let r2 = r.norm_squared() + self.softening * self.softening;
62                let r_mag = r2.sqrt();
63
64                // F = G * m1 * m2 / r² * r̂
65                let f_mag = G * particles[i].m * particles[j].m / r2;
66                let f = r / r_mag * f_mag;
67
68                // Newton's third law
69                particles[i].add_force(f);
70                particles[j].add_force(-f);
71            }
72        }
73    }
74}
75
76impl Default for NBodySolver {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl GravitySolver for NBodySolver {
83    fn compute_forces(&mut self, particles: &mut [GravityParticle]) {
84        if self.use_tree {
85            // Build Barnes-Hut tree
86            let tree = BarnesHutTree::build(particles, self.softening);
87            tree.compute_forces(particles, self.theta);
88        } else {
89            self.compute_pairwise_forces(particles);
90        }
91    }
92
93    fn potential_energy(&self, particles: &[GravityParticle]) -> f64 {
94        let n = particles.len();
95        let mut u = 0.0;
96
97        for i in 0..n {
98            for j in i + 1..n {
99                let r = (particles[j].x - particles[i].x).norm();
100                let r_soft = (r * r + self.softening * self.softening).sqrt();
101                u -= G * particles[i].m * particles[j].m / r_soft;
102            }
103        }
104
105        u
106    }
107}
108
109/// Barnes-Hut octree node.
110#[derive(Debug, Clone)]
111pub struct OctreeNode {
112    /// Center of mass.
113    pub com: Vec3,
114    /// Total mass.
115    pub mass: f64,
116    /// Bounding box center.
117    pub center: Vec3,
118    /// Half-width of box.
119    pub half_size: f64,
120    /// Children (8 octants), None if leaf.
121    pub children: Option<Box<[OctreeNode; 8]>>,
122    /// Particle indices (if leaf).
123    pub particles: Vec<usize>,
124}
125
126impl OctreeNode {
127    /// Create a new empty node.
128    fn new(center: Vec3, half_size: f64) -> Self {
129        Self {
130            com: Vec3::zeros(),
131            mass: 0.0,
132            center,
133            half_size,
134            children: None,
135            particles: Vec::new(),
136        }
137    }
138
139    /// Check if node is a leaf.
140    fn is_leaf(&self) -> bool {
141        self.children.is_none()
142    }
143
144    /// Get octant index for a position.
145    fn octant(&self, x: Vec3) -> usize {
146        let mut idx = 0;
147        if x.x >= self.center.x {
148            idx |= 1;
149        }
150        if x.y >= self.center.y {
151            idx |= 2;
152        }
153        if x.z >= self.center.z {
154            idx |= 4;
155        }
156        idx
157    }
158
159    /// Get child center for octant.
160    fn child_center(&self, octant: usize) -> Vec3 {
161        let offset = self.half_size / 2.0;
162        Vec3::new(
163            self.center.x + if octant & 1 != 0 { offset } else { -offset },
164            self.center.y + if octant & 2 != 0 { offset } else { -offset },
165            self.center.z + if octant & 4 != 0 { offset } else { -offset },
166        )
167    }
168
169    /// Insert a particle into the tree.
170    fn insert(&mut self, particle_idx: usize, particle_pos: Vec3, particle_mass: f64) {
171        // Update center of mass
172        let total_mass = self.mass + particle_mass;
173        if total_mass > 0.0 {
174            self.com = (self.com * self.mass + particle_pos * particle_mass) / total_mass;
175        }
176        self.mass = total_mass;
177
178        if self.is_leaf() {
179            if self.particles.is_empty() {
180                // Empty leaf: just add particle
181                self.particles.push(particle_idx);
182            } else if self.particles.len() == 1 {
183                // Split leaf into internal node
184                let _existing_idx = self.particles[0];
185                self.particles.clear();
186
187                // Create children
188                let children = Box::new([
189                    OctreeNode::new(self.child_center(0), self.half_size / 2.0),
190                    OctreeNode::new(self.child_center(1), self.half_size / 2.0),
191                    OctreeNode::new(self.child_center(2), self.half_size / 2.0),
192                    OctreeNode::new(self.child_center(3), self.half_size / 2.0),
193                    OctreeNode::new(self.child_center(4), self.half_size / 2.0),
194                    OctreeNode::new(self.child_center(5), self.half_size / 2.0),
195                    OctreeNode::new(self.child_center(6), self.half_size / 2.0),
196                    OctreeNode::new(self.child_center(7), self.half_size / 2.0),
197                ]);
198
199                // Re-insert existing (this is a hack; we don't have its position)
200                // In a real implementation, we'd store positions separately
201                // For now, we'll just mark this as needing external position data
202                self.children = Some(children);
203
204                // Insert new particle
205                let octant = self.octant(particle_pos);
206                if let Some(ref mut children) = self.children {
207                    children[octant].insert(particle_idx, particle_pos, particle_mass);
208                }
209            } else {
210                // Shouldn't happen
211                self.particles.push(particle_idx);
212            }
213        } else {
214            // Internal node: recurse
215            let octant = self.octant(particle_pos);
216            if let Some(ref mut children) = self.children {
217                children[octant].insert(particle_idx, particle_pos, particle_mass);
218            }
219        }
220    }
221
222    /// Compute gravitational acceleration from this node on a particle.
223    fn acceleration(&self, x: Vec3, softening: f64) -> Vec3 {
224        let r = self.com - x;
225        let r2 = r.norm_squared() + softening * softening;
226        let r_mag = r2.sqrt();
227
228        // a = G * M / r² * r̂
229        G * self.mass / r2 * (r / r_mag)
230    }
231
232    /// Recursively compute force on a particle.
233    fn compute_force_on(&self, particle: &GravityParticle, theta: f64, softening: f64) -> Vec3 {
234        if self.mass == 0.0 {
235            return Vec3::zeros();
236        }
237
238        let r = (self.com - particle.x).norm();
239
240        // Barnes-Hut criterion: s/d < θ
241        let s = 2.0 * self.half_size;
242        if self.is_leaf() || (s / r) < theta {
243            // Use COM approximation
244            self.acceleration(particle.x, softening) * particle.m
245        } else {
246            // Recurse to children
247            let mut force = Vec3::zeros();
248            if let Some(ref children) = self.children {
249                for child in children.iter() {
250                    force += child.compute_force_on(particle, theta, softening);
251                }
252            }
253            force
254        }
255    }
256}
257
258/// Barnes-Hut tree for O(N log N) gravity.
259#[derive(Debug, Clone)]
260pub struct BarnesHutTree {
261    /// Root node.
262    pub root: OctreeNode,
263    /// Softening length.
264    pub softening: f64,
265}
266
267impl BarnesHutTree {
268    /// Build tree from particles.
269    pub fn build(particles: &[GravityParticle], softening: f64) -> Self {
270        // Compute bounding box
271        let mut min = Vec3::new(f64::INFINITY, f64::INFINITY, f64::INFINITY);
272        let mut max = Vec3::new(f64::NEG_INFINITY, f64::NEG_INFINITY, f64::NEG_INFINITY);
273
274        for p in particles {
275            min.x = min.x.min(p.x.x);
276            min.y = min.y.min(p.x.y);
277            min.z = min.z.min(p.x.z);
278            max.x = max.x.max(p.x.x);
279            max.y = max.y.max(p.x.y);
280            max.z = max.z.max(p.x.z);
281        }
282
283        let center = (min + max) / 2.0;
284        let half_size = ((max - min).norm() / 2.0) * 1.1; // 10% padding
285
286        let mut root = OctreeNode::new(center, half_size);
287
288        // Insert all particles
289        for (i, p) in particles.iter().enumerate() {
290            root.insert(i, p.x, p.m);
291        }
292
293        Self { root, softening }
294    }
295
296    /// Compute forces on all particles using tree.
297    pub fn compute_forces(&self, particles: &mut [GravityParticle], theta: f64) {
298        for p in particles.iter_mut() {
299            p.reset_force();
300            let f = self.root.compute_force_on(p, theta, self.softening);
301            p.add_force(f);
302        }
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Two equal masses on the x-axis: the force is purely along x, attractive,
    /// and equal-and-opposite.
    #[test]
    fn test_nbody_two_particle() {
        let mut solver = NBodySolver::new();
        let mut particles = vec![
            GravityParticle::new(Vec3::new(0.0, 0.0, 0.0), Vec3::zeros(), 1e10),
            GravityParticle::new(Vec3::new(1.0, 0.0, 0.0), Vec3::zeros(), 1e10),
        ];

        solver.compute_forces(&mut particles);

        // Force should be along x-axis
        assert!(particles[0].f.y.abs() < 1e-20);
        assert!(particles[0].f.z.abs() < 1e-20);

        // Attractive: the particle at the origin is pulled towards +x.
        assert!(particles[0].f.x > 0.0);

        // Newton's third law
        assert!((particles[0].f.x + particles[1].f.x).abs() < 1e-20);
    }

    /// The root of a freshly built tree carries the total mass of all particles.
    #[test]
    fn test_barnes_hut_tree() {
        let particles = vec![
            GravityParticle::new(Vec3::new(0.0, 0.0, 0.0), Vec3::zeros(), 1e10),
            GravityParticle::new(Vec3::new(1.0, 0.0, 0.0), Vec3::zeros(), 1e10),
            GravityParticle::new(Vec3::new(0.0, 1.0, 0.0), Vec3::zeros(), 1e10),
        ];

        let tree = BarnesHutTree::build(&particles, 1e-3);

        assert_eq!(tree.root.mass, 3e10);
        assert!(tree.root.half_size > 0.0);
    }

    /// Potential energy of two equal masses one unit apart.
    #[test]
    fn test_potential_energy() {
        let solver = NBodySolver::new();
        let particles = vec![
            GravityParticle::new(Vec3::new(0.0, 0.0, 0.0), Vec3::zeros(), 1e10),
            GravityParticle::new(Vec3::new(1.0, 0.0, 0.0), Vec3::zeros(), 1e10),
        ];

        let u = solver.potential_energy(&particles);

        // U = -G*m1*m2/r ≈ -6.67e-11 * 1e10 * 1e10 / 1.0 = -6.67e9
        assert!(u < 0.0);
        assert!((u + 6.67e9).abs() < 1e8);
    }
}