use glam::Vec3;
use crate::util::dot_simd;
/// One cell of the octree, stored by value in a flat arena (`Vec<OctNode>`).
///
/// A cell is an axis-aligned box together with the aggregate (total mass and
/// centre of mass) of the bodies that were inserted into it.
#[derive(Clone, Copy, Debug)]
pub(crate) struct OctNode {
    /// Lower corner of the cell's bounding box.
    min: Vec3,
    /// Upper corner of the cell's bounding box.
    max: Vec3,
    /// Accumulated mass of the cell; every inserted body contributes `1.0`.
    mass: f32,
    /// Mass-weighted mean position of the bodies accounted for in `mass`.
    com: Vec3,
    /// Index of the body stored in a leaf, `None` for empty or subdivided cells.
    body: Option<u32>,
    /// Arena indices of the eight octant children; `u32::MAX` marks "no child".
    children: [u32; 8],
    /// `true` once the cell has been subdivided into children.
    has_children: bool,
}
impl OctNode {
    /// Creates a cell spanning `min..max` with no mass, no body and no children.
    fn empty(min: Vec3, max: Vec3) -> Self {
        Self {
            children: [u32::MAX; 8],
            has_children: false,
            body: None,
            mass: 0.0,
            com: Vec3::ZERO,
            min,
            max,
        }
    }

    /// Test-only accessor for the accumulated mass.
    #[cfg(test)]
    pub(crate) fn mass(&self) -> f32 {
        self.mass
    }

    /// Test-only accessor for the centre of mass.
    #[cfg(test)]
    pub(crate) fn com(&self) -> Vec3 {
        self.com
    }

    /// Length of the longest edge of the bounding box.
    #[inline]
    fn size(&self) -> f32 {
        let extent = self.max - self.min;
        extent.max_element()
    }

    /// Midpoint of the bounding box.
    #[inline]
    fn center(&self) -> Vec3 {
        let corner_sum = self.min + self.max;
        corner_sum * 0.5
    }

    /// Returns the octant index (0..8) that `p` falls into.
    ///
    /// Bit 0 is set when `p.x` is at or beyond the centre, bit 1 likewise for
    /// `y`, bit 2 for `z`. Points exactly on a centre plane go to the upper half.
    #[inline]
    fn octant_of(&self, p: Vec3) -> usize {
        let mid = self.center();
        let mut octant = 0;
        if p.x >= mid.x {
            octant |= 1;
        }
        if p.y >= mid.y {
            octant |= 2;
        }
        if p.z >= mid.z {
            octant |= 4;
        }
        octant
    }

    /// Returns the `(min, max)` corners of octant `q`, using the same bit
    /// layout as [`OctNode::octant_of`].
    #[inline]
    fn child_bounds(&self, q: usize) -> (Vec3, Vec3) {
        let mid = self.center();
        // Per axis: a set bit selects the upper half, a clear bit the lower half.
        let half = |bit: usize, lo: f32, centre: f32, hi: f32| -> (f32, f32) {
            if (q & bit) != 0 {
                (centre, hi)
            } else {
                (lo, centre)
            }
        };
        let (x_lo, x_hi) = half(1, self.min.x, mid.x, self.max.x);
        let (y_lo, y_hi) = half(2, self.min.y, mid.y, self.max.y);
        let (z_lo, z_hi) = half(4, self.min.z, mid.z, self.max.z);
        (Vec3::new(x_lo, y_lo, z_lo), Vec3::new(x_hi, y_hi, z_hi))
    }
}
/// Rebuilds the octree for `positions` inside `arena`.
///
/// The arena is always cleared first; when `positions` is empty it is left
/// empty. Otherwise the root cell is pushed at index `0` as a cube centred on
/// the bounding box of the positions, whose side is the longest bounding-box
/// edge or `min_cell_size`, whichever is larger. Each position is then
/// inserted as a body whose id is its index in `positions`, and finally the
/// mass / centre of mass of every subdivided cell is aggregated bottom-up.
pub(crate) fn build_octree(arena: &mut Vec<OctNode>, positions: &[Vec3], min_cell_size: f32) {
    arena.clear();
    let (first, rest) = match positions.split_first() {
        Some(parts) => parts,
        None => return,
    };

    // Axis-aligned bounding box of all positions.
    let (lo, hi) = rest
        .iter()
        .fold((*first, *first), |(lo, hi), p| (lo.min(*p), hi.max(*p)));

    // Grow the box into a cube so every octant stays cubic as well.
    let side = (hi - lo).max_element().max(min_cell_size);
    let mid = (lo + hi) * 0.5;
    let half_side = Vec3::splat(side * 0.5);
    arena.push(OctNode::empty(mid - half_side, mid + half_side));

    for (body, &pos) in positions.iter().enumerate() {
        // Body ids are stored as `u32`; this assumes fewer than `u32::MAX` bodies.
        insert(arena, 0, body as u32, pos, min_cell_size);
    }
    finalise(arena, 0);
}
/// Inserts `body`, located at `pos`, into the subtree rooted at `tree[idx]`.
///
/// * A subdivided cell forwards the body to the child octant containing `pos`,
///   creating that child on demand.
/// * An untouched leaf adopts the body (`mass = 1.0`, `com = pos`).
/// * An occupied leaf is either *aggregated* (the new body is folded into the
///   leaf's `mass` / `com`, the stored `body` id is kept) or *split* (the
///   resident body and the new one are pushed down into child cells).
///
/// An occupied leaf aggregates when any of the following holds:
/// * its size is at most `min_cell_size`;
/// * it already holds more than one body, since splitting would lose the
///   bodies that were folded in earlier;
/// * both bodies fall into the same octant and that octant's bounds equal the
///   leaf's own bounds. This happens once the cell is as small as `f32`
///   resolution allows (e.g. coincident points with `min_cell_size == 0.0`);
///   without this guard the recursion would never terminate.
///
/// # Panics
///
/// Panics if an occupied leaf has no stored body, which would indicate a
/// corrupted tree.
fn insert(tree: &mut Vec<OctNode>, idx: usize, body: u32, pos: Vec3, min_cell_size: f32) {
    // Subdivided cell: descend into the octant that contains `pos`.
    if tree[idx].has_children {
        let q = tree[idx].octant_of(pos);
        let child = create_or_get_child(tree, idx, q);
        insert(tree, child, body, pos, min_cell_size);
        return;
    }

    // Snapshot of the leaf (`OctNode` is `Copy`) so it can be inspected freely.
    let leaf = tree[idx];

    // Untouched leaf: it simply adopts the body.
    if leaf.body.is_none() && leaf.mass == 0.0 {
        let cell = &mut tree[idx];
        cell.body = Some(body);
        cell.com = pos;
        cell.mass = 1.0;
        return;
    }

    let existing_pos = leaf.com;
    let q_existing = leaf.octant_of(existing_pos);
    let q_new = leaf.octant_of(pos);
    // Subdividing makes no progress when both bodies land in an octant whose
    // bounds are identical to the current cell's bounds.
    let cannot_shrink =
        q_existing == q_new && leaf.child_bounds(q_new) == (leaf.min, leaf.max);

    if leaf.size() <= min_cell_size || leaf.mass > 1.0 || cannot_shrink {
        let new_mass = leaf.mass + 1.0;
        let cell = &mut tree[idx];
        cell.com = (leaf.com * leaf.mass + pos) / new_mass;
        cell.mass = new_mass;
        return;
    }

    // Split: turn the leaf into an internal cell and push both bodies down.
    let cell = &mut tree[idx];
    let existing = cell.body.take().expect("leaf without a body");
    cell.mass = 0.0;
    cell.com = Vec3::ZERO;
    cell.has_children = true;

    let c_existing = create_or_get_child(tree, idx, q_existing);
    insert(tree, c_existing, existing, existing_pos, min_cell_size);
    let c_new = create_or_get_child(tree, idx, q_new);
    insert(tree, c_new, body, pos, min_cell_size);
}
/// Returns the arena index of `parent`'s child in `octant`, allocating it first
/// if that slot is still unset (`u32::MAX`).
///
/// A freshly allocated child is an empty cell covering the octant's bounds; it
/// is appended to the arena and recorded in the parent's `children` table.
fn create_or_get_child(tree: &mut Vec<OctNode>, parent: usize, octant: usize) -> usize {
    let slot = tree[parent].children[octant];
    if slot == u32::MAX {
        let (lo, hi) = tree[parent].child_bounds(octant);
        let fresh = tree.len();
        tree.push(OctNode::empty(lo, hi));
        tree[parent].children[octant] = fresh as u32;
        fresh
    } else {
        slot as usize
    }
}
/// Recomputes `mass` and `com` of the subdivided cell `tree[idx]` from its
/// children, recursing depth-first so every child is finalised before it is
/// read. Leaves (cells without children) are left untouched.
///
/// The cell's mass becomes the sum of the child masses and its centre of mass
/// the mass-weighted mean of the child centres; when the total mass is not
/// positive the weighted sum is stored undivided.
fn finalise(tree: &mut [OctNode], idx: usize) {
    if !tree[idx].has_children {
        return;
    }

    // Copy the child table so the tree can be mutated while walking it.
    let slots = tree[idx].children;
    let mut total_mass = 0.0_f32;
    let mut weighted_sum = Vec3::ZERO;

    for slot in slots.iter().copied().filter(|&slot| slot != u32::MAX) {
        let child = slot as usize;
        finalise(tree, child);
        let (child_mass, child_com) = (tree[child].mass, tree[child].com);
        total_mass += child_mass;
        weighted_sum += child_com * child_mass;
    }

    if total_mass > 0.0 {
        weighted_sum /= total_mass;
    }

    let node = &mut tree[idx];
    node.mass = total_mass;
    node.com = weighted_sum;
}
/// Sums the pairwise force contributions of the subtree rooted at `tree[idx]`
/// on the body `target_idx` located at `target_pos`.
///
/// * Cells with no mass contribute nothing.
/// * A leaf contributes a single `pair_force_3d` term from its centre of mass.
///   If the leaf's stored body is the target itself, one unit of mass is
///   removed first (and nothing is contributed if no mass remains).
/// * A subdivided cell is treated as a single point mass when
///   `size² < theta2 · distance²` (a Barnes–Hut-style opening test on the
///   squared quantities); otherwise its existing children are visited.
///
/// `repulsion` is passed through unchanged to `pair_force_3d`.
pub(crate) fn accumulate_repulsion_3d(
    tree: &[OctNode],
    idx: usize,
    target_pos: Vec3,
    target_idx: u32,
    theta2: f32,
    repulsion: f32,
) -> Vec3 {
    let node = &tree[idx];
    if node.mass <= 0.0 {
        return Vec3::ZERO;
    }

    if node.has_children {
        let dist2 = (node.com - target_pos).length_squared();
        let extent = node.size();
        let far_enough = extent * extent < theta2 * dist2;
        if far_enough {
            return pair_force_3d(target_pos, node.com, node.mass, repulsion);
        }
        // Too close to approximate: open the cell and visit each child.
        return node
            .children
            .iter()
            .filter(|&&child| child != u32::MAX)
            .fold(Vec3::ZERO, |acc, &child| {
                acc + accumulate_repulsion_3d(
                    tree,
                    child as usize,
                    target_pos,
                    target_idx,
                    theta2,
                    repulsion,
                )
            });
    }

    // Leaf: discount the target's own unit of mass when it is the stored body.
    let effective_mass = match node.body {
        Some(stored) if stored == target_idx => node.mass - 1.0,
        _ => node.mass,
    };
    if effective_mass <= 0.0 {
        return Vec3::ZERO;
    }
    pair_force_3d(target_pos, node.com, effective_mass, repulsion)
}
/// Softened inverse-square interaction between a point at `from` and a point
/// mass `mass` at `to`.
///
/// The result is `(to - from) / d * (repulsion * mass / d²)` with
/// `d² = |to - from|² + 0.01`; the constant keeps the expression finite when
/// the two points coincide. For positive `repulsion * mass` the vector points
/// from `from` towards `to`.
// NOTE(review): how the caller applies the sign of this vector is not visible
// in this file — confirm against the call site.
#[inline]
fn pair_force_3d(from: Vec3, to: Vec3, mass: f32, repulsion: f32) -> Vec3 {
    let offset = to - from;
    let softened_sq = offset.length_squared() + 0.01;
    let scale = repulsion * mass / softened_sq;
    (offset / softened_sq.sqrt()) * scale
}
/// Returns the sum of the squared components of every vector in `positions`
/// (i.e. the sum of their squared lengths), or `0.0` for an empty slice.
///
/// The slice is reinterpreted as a flat `[f32]` of `3 * positions.len()`
/// elements and handed to `dot_simd` as both operands.
// NOTE(review): `dead_code` is allowed presumably because this helper is not
// referenced in every build configuration — confirm and document the reason.
#[allow(dead_code)]
pub(crate) fn sum_positions_sqnorm(positions: &[Vec3]) -> f32 {
    // Compile-time guard for the reinterpretation below: `Vec3` must be exactly
    // three tightly packed `f32`s with at least `f32` alignment. This fails the
    // build if `Vec3` is ever swapped for a padded or SIMD-aligned type.
    const _: () = assert!(
        std::mem::size_of::<Vec3>() == 3 * std::mem::size_of::<f32>()
            && std::mem::align_of::<Vec3>() >= std::mem::align_of::<f32>()
    );

    if positions.is_empty() {
        return 0.0;
    }

    // SAFETY: `positions` is a valid, initialised slice, so its pointer is
    // non-null and valid for reads of `positions.len()` `Vec3`s for the lifetime
    // of the borrow. The assertion above guarantees each `Vec3` occupies exactly
    // `3 * size_of::<f32>()` bytes with suitable alignment, so the same memory
    // is valid for `positions.len() * 3` `f32` reads, and the byte length is
    // unchanged, so it cannot overflow `isize`. This assumes `Vec3` is a plain
    // `x, y, z` struct of `f32` with a defined field layout, as glam documents.
    let flat: &[f32] = unsafe {
        std::slice::from_raw_parts(positions.as_ptr().cast::<f32>(), positions.len() * 3)
    };
    dot_simd(flat, flat)
}