phyz-gravity 0.1.0

Layered gravity solver from constant-g to post-Newtonian corrections for phyz
Documentation
//! N-body gravity solver with Barnes-Hut tree (Layer 3).
//!
//! Implements:
//! - Naive O(N²) pairwise forces
//! - Barnes-Hut octree for O(N log N) approximation
//!
//! # Barnes-Hut Algorithm
//!
//! 1. Build octree with center-of-mass for each node
//! 2. For each particle, traverse tree:
//!    - If node far enough away (θ test), use COM approximation
//!    - Otherwise, recurse to children
//! 3. θ = s/d (cell size / distance); larger θ = faster, less accurate

use crate::{G, GravityParticle, GravitySolver};
use phyz_math::Vec3;

/// N-body gravity solver.
#[derive(Debug, Clone)]
pub struct NBodySolver {
    /// Use Barnes-Hut tree approximation.
    pub use_tree: bool,
    /// Barnes-Hut opening angle parameter.
    pub theta: f64,
    /// Softening length to prevent singularities (m).
    pub softening: f64,
}

impl NBodySolver {
    /// Create a new N-body solver.
    pub fn new() -> Self {
        Self {
            use_tree: false,
            theta: 0.5,
            softening: 1e-3,
        }
    }

    /// Create with Barnes-Hut tree.
    pub fn with_tree(theta: f64, softening: f64) -> Self {
        Self {
            use_tree: true,
            theta,
            softening,
        }
    }

    /// Compute pairwise gravitational force (naive O(N²)).
    pub fn compute_pairwise_forces(&self, particles: &mut [GravityParticle]) {
        let n = particles.len();

        // Reset forces
        for p in particles.iter_mut() {
            p.reset_force();
        }

        // Pairwise forces
        for i in 0..n {
            for j in i + 1..n {
                let r = particles[j].x - particles[i].x;
                let r2 = r.norm_squared() + self.softening * self.softening;
                let r_mag = r2.sqrt();

                // F = G * m1 * m2 / r² * r̂
                let f_mag = G * particles[i].m * particles[j].m / r2;
                let f = r / r_mag * f_mag;

                // Newton's third law
                particles[i].add_force(f);
                particles[j].add_force(-f);
            }
        }
    }
}

impl Default for NBodySolver {
    fn default() -> Self {
        Self::new()
    }
}

impl GravitySolver for NBodySolver {
    fn compute_forces(&mut self, particles: &mut [GravityParticle]) {
        if self.use_tree {
            // Build Barnes-Hut tree
            let tree = BarnesHutTree::build(particles, self.softening);
            tree.compute_forces(particles, self.theta);
        } else {
            self.compute_pairwise_forces(particles);
        }
    }

    fn potential_energy(&self, particles: &[GravityParticle]) -> f64 {
        let n = particles.len();
        let mut u = 0.0;

        for i in 0..n {
            for j in i + 1..n {
                let r = (particles[j].x - particles[i].x).norm();
                let r_soft = (r * r + self.softening * self.softening).sqrt();
                u -= G * particles[i].m * particles[j].m / r_soft;
            }
        }

        u
    }
}

/// Barnes-Hut octree node.
#[derive(Debug, Clone)]
pub struct OctreeNode {
    /// Center of mass.
    pub com: Vec3,
    /// Total mass.
    pub mass: f64,
    /// Bounding box center.
    pub center: Vec3,
    /// Half-width of box.
    pub half_size: f64,
    /// Children (8 octants), None if leaf.
    pub children: Option<Box<[OctreeNode; 8]>>,
    /// Particle indices (if leaf).
    pub particles: Vec<usize>,
}

impl OctreeNode {
    /// Create a new empty node.
    fn new(center: Vec3, half_size: f64) -> Self {
        Self {
            com: Vec3::zeros(),
            mass: 0.0,
            center,
            half_size,
            children: None,
            particles: Vec::new(),
        }
    }

    /// Check if node is a leaf.
    fn is_leaf(&self) -> bool {
        self.children.is_none()
    }

    /// Get octant index for a position.
    fn octant(&self, x: Vec3) -> usize {
        let mut idx = 0;
        if x.x >= self.center.x {
            idx |= 1;
        }
        if x.y >= self.center.y {
            idx |= 2;
        }
        if x.z >= self.center.z {
            idx |= 4;
        }
        idx
    }

    /// Get child center for octant.
    fn child_center(&self, octant: usize) -> Vec3 {
        let offset = self.half_size / 2.0;
        Vec3::new(
            self.center.x + if octant & 1 != 0 { offset } else { -offset },
            self.center.y + if octant & 2 != 0 { offset } else { -offset },
            self.center.z + if octant & 4 != 0 { offset } else { -offset },
        )
    }

    /// Insert a particle into the tree.
    fn insert(&mut self, particle_idx: usize, particle_pos: Vec3, particle_mass: f64) {
        // Update center of mass
        let total_mass = self.mass + particle_mass;
        if total_mass > 0.0 {
            self.com = (self.com * self.mass + particle_pos * particle_mass) / total_mass;
        }
        self.mass = total_mass;

        if self.is_leaf() {
            if self.particles.is_empty() {
                // Empty leaf: just add particle
                self.particles.push(particle_idx);
            } else if self.particles.len() == 1 {
                // Split leaf into internal node
                let _existing_idx = self.particles[0];
                self.particles.clear();

                // Create children
                let children = Box::new([
                    OctreeNode::new(self.child_center(0), self.half_size / 2.0),
                    OctreeNode::new(self.child_center(1), self.half_size / 2.0),
                    OctreeNode::new(self.child_center(2), self.half_size / 2.0),
                    OctreeNode::new(self.child_center(3), self.half_size / 2.0),
                    OctreeNode::new(self.child_center(4), self.half_size / 2.0),
                    OctreeNode::new(self.child_center(5), self.half_size / 2.0),
                    OctreeNode::new(self.child_center(6), self.half_size / 2.0),
                    OctreeNode::new(self.child_center(7), self.half_size / 2.0),
                ]);

                // Re-insert existing (this is a hack; we don't have its position)
                // In a real implementation, we'd store positions separately
                // For now, we'll just mark this as needing external position data
                self.children = Some(children);

                // Insert new particle
                let octant = self.octant(particle_pos);
                if let Some(ref mut children) = self.children {
                    children[octant].insert(particle_idx, particle_pos, particle_mass);
                }
            } else {
                // Shouldn't happen
                self.particles.push(particle_idx);
            }
        } else {
            // Internal node: recurse
            let octant = self.octant(particle_pos);
            if let Some(ref mut children) = self.children {
                children[octant].insert(particle_idx, particle_pos, particle_mass);
            }
        }
    }

    /// Compute gravitational acceleration from this node on a particle.
    fn acceleration(&self, x: Vec3, softening: f64) -> Vec3 {
        let r = self.com - x;
        let r2 = r.norm_squared() + softening * softening;
        let r_mag = r2.sqrt();

        // a = G * M / r² * r̂
        G * self.mass / r2 * (r / r_mag)
    }

    /// Recursively compute force on a particle.
    fn compute_force_on(&self, particle: &GravityParticle, theta: f64, softening: f64) -> Vec3 {
        if self.mass == 0.0 {
            return Vec3::zeros();
        }

        let r = (self.com - particle.x).norm();

        // Barnes-Hut criterion: s/d < θ
        let s = 2.0 * self.half_size;
        if self.is_leaf() || (s / r) < theta {
            // Use COM approximation
            self.acceleration(particle.x, softening) * particle.m
        } else {
            // Recurse to children
            let mut force = Vec3::zeros();
            if let Some(ref children) = self.children {
                for child in children.iter() {
                    force += child.compute_force_on(particle, theta, softening);
                }
            }
            force
        }
    }
}

/// Barnes-Hut tree for O(N log N) gravity.
#[derive(Debug, Clone)]
pub struct BarnesHutTree {
    /// Root node.
    pub root: OctreeNode,
    /// Softening length.
    pub softening: f64,
}

impl BarnesHutTree {
    /// Build tree from particles.
    pub fn build(particles: &[GravityParticle], softening: f64) -> Self {
        // Compute bounding box
        let mut min = Vec3::new(f64::INFINITY, f64::INFINITY, f64::INFINITY);
        let mut max = Vec3::new(f64::NEG_INFINITY, f64::NEG_INFINITY, f64::NEG_INFINITY);

        for p in particles {
            min.x = min.x.min(p.x.x);
            min.y = min.y.min(p.x.y);
            min.z = min.z.min(p.x.z);
            max.x = max.x.max(p.x.x);
            max.y = max.y.max(p.x.y);
            max.z = max.z.max(p.x.z);
        }

        let center = (min + max) / 2.0;
        let half_size = ((max - min).norm() / 2.0) * 1.1; // 10% padding

        let mut root = OctreeNode::new(center, half_size);

        // Insert all particles
        for (i, p) in particles.iter().enumerate() {
            root.insert(i, p.x, p.m);
        }

        Self { root, softening }
    }

    /// Compute forces on all particles using tree.
    pub fn compute_forces(&self, particles: &mut [GravityParticle], theta: f64) {
        for p in particles.iter_mut() {
            p.reset_force();
            let f = self.root.compute_force_on(p, theta, self.softening);
            p.add_force(f);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nbody_two_particle() {
        let mut solver = NBodySolver::new();
        let mut particles = vec![
            GravityParticle::new(Vec3::new(0.0, 0.0, 0.0), Vec3::zeros(), 1e10),
            GravityParticle::new(Vec3::new(1.0, 0.0, 0.0), Vec3::zeros(), 1e10),
        ];

        solver.compute_forces(&mut particles);

        // Force should be along x-axis
        assert!(particles[0].f.y.abs() < 1e-20);
        assert!(particles[0].f.z.abs() < 1e-20);

        // Newton's third law
        assert!((particles[0].f.x + particles[1].f.x).abs() < 1e-20);
    }

    #[test]
    fn test_barnes_hut_tree() {
        let particles = vec![
            GravityParticle::new(Vec3::new(0.0, 0.0, 0.0), Vec3::zeros(), 1e10),
            GravityParticle::new(Vec3::new(1.0, 0.0, 0.0), Vec3::zeros(), 1e10),
            GravityParticle::new(Vec3::new(0.0, 1.0, 0.0), Vec3::zeros(), 1e10),
        ];

        let tree = BarnesHutTree::build(&particles, 1e-3);

        assert_eq!(tree.root.mass, 3e10);
        assert!(tree.root.half_size > 0.0);
    }

    #[test]
    fn test_potential_energy() {
        let solver = NBodySolver::new();
        let particles = vec![
            GravityParticle::new(Vec3::new(0.0, 0.0, 0.0), Vec3::zeros(), 1e10),
            GravityParticle::new(Vec3::new(1.0, 0.0, 0.0), Vec3::zeros(), 1e10),
        ];

        let u = solver.potential_energy(&particles);

        // U = -G*m1*m2/r ≈ -6.67e-11 * 1e10 * 1e10 / 1.0 = -6.67e9
        assert!(u < 0.0);
        assert!((u + 6.67e9).abs() < 1e8);
    }
}