use crate::constants::physical::G;
/// One cubic cell of a Barnes–Hut octree.
///
/// A cell either holds nothing, holds a single particle directly
/// (`particle_index` is `Some`), or has been subdivided into up to eight
/// child cells addressed by octant.
#[derive(Debug, Clone, PartialEq)]
pub struct OctreeNode {
    /// Geometric centre of the cube.
    pub center: [f64; 3],
    /// Half of the cube's edge length.
    pub half_width: f64,
    /// Sum of the masses accumulated in this cell.
    pub total_mass: f64,
    /// Mass-weighted mean position of the accumulated masses.
    pub center_of_mass: [f64; 3],
    /// Child cells, indexed by the value returned from `octant_index`;
    /// a slot is `None` until a child is created for that octant.
    pub children: [Option<Box<OctreeNode>>; 8],
    /// Index (into the caller's particle arrays) of the particle stored
    /// directly in this cell, if any.
    pub particle_index: Option<usize>,
}

impl OctreeNode {
    /// Creates an empty cell: zero mass, no particle and no children.
    pub fn new(center: [f64; 3], half_width: f64) -> Self {
        Self {
            center,
            half_width,
            total_mass: 0.0,
            center_of_mass: [0.0; 3],
            children: Default::default(),
            particle_index: None,
        }
    }

    /// Returns `true` when none of the eight child slots is occupied.
    pub fn is_leaf(&self) -> bool {
        !self.children.iter().any(Option::is_some)
    }

    /// Returns the octant (0..=7) of `pos` relative to this cell's centre.
    ///
    /// Bit 0 is set when `pos` is strictly greater than the centre along x,
    /// bit 1 along y and bit 2 along z. A coordinate equal to the centre
    /// falls on the "low" side.
    pub fn octant_index(&self, pos: &[f64; 3]) -> usize {
        let bit = |axis: usize| usize::from(pos[axis] > self.center[axis]) << axis;
        bit(0) | bit(1) | bit(2)
    }

    /// Returns the centre of the child cell for `octant`.
    ///
    /// Each axis is shifted by half of `half_width`, towards the high side
    /// when the corresponding octant bit (1, 2, 4 for x, y, z) is set and
    /// towards the low side otherwise.
    pub fn child_center(&self, octant: usize) -> [f64; 3] {
        let quarter = self.half_width / 2.0;
        let shift = |mask: usize| if octant & mask == 0 { -quarter } else { quarter };
        [
            self.center[0] + shift(1),
            self.center[1] + shift(2),
            self.center[2] + shift(4),
        ]
    }
}
/// Builds an octree over all particles.
///
/// The root cell is created with the given `center` and `half_width`, and
/// the particles are inserted one at a time in index order.
///
/// # Panics
///
/// Panics if `masses` has fewer elements than `positions`, because each
/// insertion indexes `masses` with the particle index.
pub fn build_octree(
    positions: &[[f64; 3]],
    masses: &[f64],
    center: [f64; 3],
    half_width: f64,
) -> OctreeNode {
    let mut tree = OctreeNode::new(center, half_width);
    (0..positions.len())
        .for_each(|particle| insert_particle(&mut tree, positions, masses, particle));
    tree
}
/// Inserts particle `idx` into the subtree rooted at `node`.
///
/// * An untouched cell (no particle, no children) simply stores the particle.
/// * A cell that already stores a particle is split: its payload is moved
///   into the matching child cell and the new particle is then routed down.
/// * An already subdivided cell adds the particle to its running mass and
///   centre of mass and forwards it to the matching child.
///
/// Particles that cannot be separated by further subdivision are merged into
/// the occupied leaf instead of recursing forever. This covers particles at
/// exactly the same position, and cells so small that halving them no longer
/// changes the child centre in floating point (for example when particles
/// lie outside the root cube and always fall into the same corner octant).
/// A merged leaf keeps the index of the first particle it received.
///
/// # Panics
///
/// Panics if `idx` is out of bounds for `positions` or `masses`.
fn insert_particle(node: &mut OctreeNode, positions: &[[f64; 3]], masses: &[f64], idx: usize) {
    /// Folds a point mass into the cell's total mass and centre of mass.
    fn absorb(cell: &mut OctreeNode, pos: &[f64; 3], mass: f64) {
        let previous = cell.total_mass;
        let total = previous + mass;
        if total != 0.0 {
            for (com, p) in cell.center_of_mass.iter_mut().zip(pos) {
                *com = (*com * previous + p * mass) / total;
            }
        } else {
            // Only massless contributions so far: avoid a 0/0 NaN.
            cell.center_of_mass = *pos;
        }
        cell.total_mass = total;
    }

    let m = masses[idx];
    let pos = positions[idx];

    // Emptiness is decided structurally rather than by mass, so a subdivided
    // cell whose subtree happens to be massless is never mistaken for empty.
    if node.particle_index.is_none() && node.is_leaf() {
        node.total_mass = m;
        node.center_of_mass = pos;
        node.particle_index = Some(idx);
        return;
    }

    if node.particle_index.is_some() {
        // Identical position: subdivision can never separate the two, so
        // just add the mass (the centre of mass is unchanged).
        if pos == node.center_of_mass {
            node.total_mass += m;
            return;
        }

        // Subdivision is exhausted once shifting the centre by the child
        // offset no longer changes it on any axis (or the width is zero or
        // not finite).
        let quarter = (node.half_width / 2.0).abs();
        let can_split = quarter.is_finite()
            && node.center.iter().any(|&c| c + quarter > c || c - quarter < c);
        if !can_split {
            absorb(node, &pos, m);
            return;
        }

        // Split: move the leaf's payload into the child for its octant.
        // An occupied cell never has children, so the slot is free.
        let oct = node.octant_index(&node.center_of_mass);
        let mut child = OctreeNode::new(node.child_center(oct), node.half_width / 2.0);
        child.total_mass = node.total_mass;
        child.center_of_mass = node.center_of_mass;
        child.particle_index = node.particle_index.take();
        node.children[oct] = Some(Box::new(child));
    }

    // Internal cell: account for the new particle here, then route it down.
    absorb(node, &pos, m);
    let oct = node.octant_index(&pos);
    let child_center = node.child_center(oct);
    let child_half_width = node.half_width / 2.0;
    let child = node.children[oct]
        .get_or_insert_with(|| Box::new(OctreeNode::new(child_center, child_half_width)));
    insert_particle(child, positions, masses, idx);
}
/// Returns the acceleration at `pos` produced by the masses stored in `tree`.
///
/// Starts from a zero vector and lets `bh_walk` accumulate the contributions
/// of the tree using opening angle `theta` and softening length `softening`.
pub fn barnes_hut_acceleration(
    tree: &OctreeNode,
    pos: &[f64; 3],
    theta: f64,
    softening: f64,
) -> [f64; 3] {
    let mut total = [0.0_f64; 3];
    bh_walk(tree, pos, theta, softening, &mut total);
    total
}
/// Recursively accumulates into `acc` the acceleration at `pos` due to `node`.
///
/// Cells whose total mass is below `1e-100` contribute nothing. A cell is
/// treated as a single point mass at its centre of mass when it is a leaf or
/// when `2 * half_width / r < theta`, where `r` is the softened distance
/// `sqrt(|d|^2 + softening^2)`. Otherwise its children are visited.
///
/// The tiny-distance guard (`r < 1e-30`) is applied only to the point-mass
/// evaluation. With zero softening, a target that sits exactly on the centre
/// of mass of a subdivided cell must still descend into the children rather
/// than discard the whole subtree.
fn bh_walk(node: &OctreeNode, pos: &[f64; 3], theta: f64, softening: f64, acc: &mut [f64; 3]) {
    if node.total_mass < 1e-100 {
        return;
    }
    let dx = node.center_of_mass[0] - pos[0];
    let dy = node.center_of_mass[1] - pos[1];
    let dz = node.center_of_mass[2] - pos[2];
    let r2 = dx * dx + dy * dy + dz * dz + softening * softening;
    let r = r2.sqrt();

    // For r == 0 the ratio is +inf (or NaN for a zero-width cell); either
    // way the comparison is false and a subdivided cell gets opened.
    let accept = node.is_leaf() || 2.0 * node.half_width / r < theta;
    if !accept {
        for child in node.children.iter().flatten() {
            bh_walk(child, pos, theta, softening, acc);
        }
        return;
    }

    // Point mass on top of the target (e.g. the particle's own leaf with
    // zero softening): skip to avoid dividing by zero.
    if r < 1e-30 {
        return;
    }
    let f = G * node.total_mass / (r * r2);
    acc[0] += f * dx;
    acc[1] += f * dy;
    acc[2] += f * dz;
}
/// Computes the Barnes–Hut acceleration of every particle.
///
/// Builds one octree from `positions` and `masses` inside the cube described
/// by `center` and `half_width`, then evaluates the acceleration at each
/// particle position. The result has one entry per position, in the same
/// order.
pub fn all_accelerations_barnes_hut(
    positions: &[[f64; 3]],
    masses: &[f64],
    center: [f64; 3],
    half_width: f64,
    theta: f64,
    softening: f64,
) -> Vec<[f64; 3]> {
    let tree = build_octree(positions, masses, center, half_width);
    positions
        .iter()
        .map(|point| barnes_hut_acceleration(&tree, point, theta, softening))
        .collect()
}
/// Parameters for one `leapfrog_step_barnes_hut` call.
///
/// Derives the same common traits as `OctreeNode` so values can be logged,
/// duplicated and compared.
#[derive(Clone, Debug, PartialEq)]
pub struct LeapfrogParams {
    /// Time step of the integrator.
    pub dt: f64,
    /// Centre of the root octree cell.
    pub center: [f64; 3],
    /// Half edge length of the root octree cell.
    pub half_width: f64,
    /// Barnes–Hut opening parameter passed to the tree walk.
    pub theta: f64,
    /// Softening length passed to the tree walk.
    pub softening: f64,
}
/// Advances positions and velocities by one time step `params.dt`.
///
/// The step is done in three stages: velocities receive half a step of the
/// current accelerations, positions receive a full step of the updated
/// velocities, and velocities then receive another half step using
/// accelerations recomputed at the new positions. Accelerations come from
/// `all_accelerations_barnes_hut` with the tree settings in `params`.
pub fn leapfrog_step_barnes_hut(
    positions: &mut [[f64; 3]],
    velocities: &mut [[f64; 3]],
    masses: &[f64],
    params: &LeapfrogParams,
) {
    /// Adds `half_dt * a` to every velocity, pairing entries by index.
    fn half_kick(vels: &mut [[f64; 3]], accs: &[[f64; 3]], half_dt: f64) {
        for (vel, acc) in vels.iter_mut().zip(accs) {
            for axis in 0..3 {
                vel[axis] += half_dt * acc[axis];
            }
        }
    }

    let half_dt = 0.5 * params.dt;
    let accelerations = |points: &[[f64; 3]]| {
        all_accelerations_barnes_hut(
            points,
            masses,
            params.center,
            params.half_width,
            params.theta,
            params.softening,
        )
    };

    // First half kick with accelerations at the old positions.
    let before = accelerations(&*positions);
    half_kick(velocities, &before, half_dt);

    // Drift with the half-updated velocities.
    for (point, vel) in positions.iter_mut().zip(velocities.iter()) {
        for axis in 0..3 {
            point[axis] += params.dt * vel[axis];
        }
    }

    // Second half kick with accelerations at the new positions.
    let after = accelerations(&*positions);
    half_kick(velocities, &after, half_dt);
}
/// Returns 1.1 times the largest absolute coordinate found in `positions`.
///
/// The extent is measured from the origin, so the result is a half-width for
/// a cube centred there. Returns `0.0` for an empty slice.
pub fn optimal_half_width(positions: &[[f64; 3]]) -> f64 {
    let extent = positions
        .iter()
        .flatten()
        .fold(0.0_f64, |largest, coord| largest.max(coord.abs()));
    extent * 1.1
}
pub fn tree_potential_energy(
positions: &[[f64; 3]],
masses: &[f64],
center: [f64; 3],
half_width: f64,
theta: f64,
softening: f64,
) -> f64 {
let mut pe = 0.0;
let tree = build_octree(positions, masses, center, half_width);
for (i, pos) in positions.iter().enumerate() {
let acc = barnes_hut_acceleration(&tree, pos, theta, softening);
let a_mag = (acc[0] * acc[0] + acc[1] * acc[1] + acc[2] * acc[2]).sqrt();
pe -= masses[i] * a_mag * softening;
}
pe * 0.5
}