use crate::mesh3d::Vec3;
use crate::frustum_culling::BoundingBox;
/// A point octree that spatially partitions values of type `T` by a `Vec3` position.
#[derive(Debug, Clone)]
pub struct Octree<T> {
    /// Root node; created at depth 0 covering the bounds given to `Octree::new`.
    root: OctreeNode<T>,
    /// A node is only subdivided while its depth is below this value.
    max_depth: usize,
    /// A node is subdivided once it directly holds more than this many objects.
    max_objects: usize,
}
/// One node of the octree: an axis-aligned region that stores objects directly and,
/// once subdivided, owns eight child octants.
#[derive(Debug, Clone)]
struct OctreeNode<T> {
    /// Region covered by this node; point membership treats it as a closed box.
    bounds: BoundingBox,
    /// Objects stored directly at this node (not pushed down into a child).
    objects: Vec<OctreeObject<T>>,
    /// The eight octants after subdivision; `None` while the node is a leaf.
    children: Option<Box<[OctreeNode<T>; 8]>>,
    /// Distance from the root (the root has depth 0).
    depth: usize,
}
/// A stored value together with the position it was inserted at.
#[derive(Debug, Clone)]
struct OctreeObject<T> {
    /// Insertion position; all spatial tests use this.
    position: Vec3,
    /// The caller's payload, handed back by reference from queries.
    data: T,
}
impl<T: Clone> Octree<T> {
    /// Creates an empty octree whose root covers `bounds`.
    ///
    /// A node is split into eight octants once it directly holds more than `max_objects`
    /// objects, provided its depth is still below `max_depth` (the root has depth 0).
    pub fn new(bounds: BoundingBox, max_depth: usize, max_objects: usize) -> Self {
        let root = OctreeNode::new(bounds, 0);
        Self {
            root,
            max_depth,
            max_objects,
        }
    }

    /// Stores `data` at `position`.
    ///
    /// Positions outside the root bounds are silently discarded.
    pub fn insert(&mut self, position: Vec3, data: T) {
        let Self {
            root,
            max_depth,
            max_objects,
        } = self;
        root.insert(position, data, *max_depth, *max_objects);
    }

    /// Returns references to every value whose position lies inside `bounds`
    /// (faces inclusive).
    pub fn query_range(&self, bounds: &BoundingBox) -> Vec<&T> {
        let mut found = Vec::new();
        self.root.query_range(bounds, &mut found);
        found
    }

    /// Returns references to every value whose position is within `radius` of `center`
    /// (distance inclusive).
    pub fn query_sphere(&self, center: Vec3, radius: f32) -> Vec<&T> {
        let mut found = Vec::new();
        self.root.query_sphere(center, radius, &mut found);
        found
    }

    /// Removes all values and collapses the tree back to a single empty root.
    pub fn clear(&mut self) {
        self.root.clear()
    }

    /// Returns the total number of stored values across all nodes.
    pub fn count(&self) -> usize {
        self.root.count()
    }
}
impl<T: Clone> OctreeNode<T> {
    /// Creates an empty leaf node covering `bounds` at the given `depth`.
    fn new(bounds: BoundingBox, depth: usize) -> Self {
        Self {
            bounds,
            objects: Vec::new(),
            children: None,
            depth,
        }
    }

    /// Inserts `data` at `position` into this subtree.
    ///
    /// Points outside this node's bounds are silently ignored. If the node is already
    /// subdivided, the object is forwarded to the first child that contains it. Otherwise
    /// it is stored here. Once a leaf holds more than `max_objects` objects and is
    /// shallower than `max_depth`, it is split into eight octants and its objects are
    /// pushed down.
    fn insert(&mut self, position: Vec3, data: T, max_depth: usize, max_objects: usize) {
        if !self.contains_point(position) {
            return;
        }
        if let Some(children) = &mut self.children {
            if let Some(child) = children.iter_mut().find(|c| c.contains_point(position)) {
                child.insert(position, data, max_depth, max_objects);
                return;
            }
        }
        self.objects.push(OctreeObject { position, data });
        // Only a leaf may be split. Subdividing an already-split node would replace its
        // children with empty ones and silently drop every object stored below it.
        let is_leaf = self.children.is_none();
        if is_leaf && self.objects.len() > max_objects && self.depth < max_depth {
            self.subdivide();
            self.redistribute(max_depth, max_objects);
        }
    }

    /// Splits this node into eight empty child octants one level deeper.
    ///
    /// Child bounds are taken directly from the parent's `min`, `center()` and `max`
    /// coordinates rather than re-derived from sizes and offsets. The octants therefore
    /// share faces exactly and, provided `center()` lies inside the box, tile the parent
    /// with no floating-point gaps. Octant `i` covers the upper half of x, y or z when
    /// bit 2, 1 or 0 of `i` is set, respectively.
    fn subdivide(&mut self) {
        let min = self.bounds.min;
        let max = self.bounds.max;
        let mid = self.bounds.center();
        let depth = self.depth + 1;
        let children: [OctreeNode<T>; 8] = std::array::from_fn(|i| {
            let (lo_x, hi_x) = if i & 0b100 == 0 { (min.x, mid.x) } else { (mid.x, max.x) };
            let (lo_y, hi_y) = if i & 0b010 == 0 { (min.y, mid.y) } else { (mid.y, max.y) };
            let (lo_z, hi_z) = if i & 0b001 == 0 { (min.z, mid.z) } else { (mid.z, max.z) };
            OctreeNode::new(
                BoundingBox::new(Vec3::new(lo_x, lo_y, lo_z), Vec3::new(hi_x, hi_y, hi_z)),
                depth,
            )
        });
        self.children = Some(Box::new(children));
    }

    /// Moves this node's directly stored objects into the child octants containing them.
    ///
    /// Objects that no child contains stay at this node. Payloads are moved, not cloned.
    /// Does nothing if the node has no children.
    fn redistribute(&mut self, max_depth: usize, max_objects: usize) {
        if let Some(children) = &mut self.children {
            for obj in std::mem::take(&mut self.objects) {
                let position = obj.position;
                match children.iter_mut().find(|c| c.contains_point(position)) {
                    Some(child) => child.insert(position, obj.data, max_depth, max_objects),
                    None => self.objects.push(obj),
                }
            }
        }
    }

    /// Returns `true` if `point` lies inside this node's bounds (faces inclusive).
    fn contains_point(&self, point: Vec3) -> bool {
        Self::point_in_box(&self.bounds, point)
    }

    /// Returns `true` if `point` lies inside the closed box `bounds`.
    fn point_in_box(bounds: &BoundingBox, point: Vec3) -> bool {
        point.x >= bounds.min.x
            && point.x <= bounds.max.x
            && point.y >= bounds.min.y
            && point.y <= bounds.max.y
            && point.z >= bounds.min.z
            && point.z <= bounds.max.z
    }

    /// Returns `true` if this node's bounds overlap `bounds` (touching faces count).
    fn intersects_bounds(&self, bounds: &BoundingBox) -> bool {
        self.bounds.max.x >= bounds.min.x
            && self.bounds.min.x <= bounds.max.x
            && self.bounds.max.y >= bounds.min.y
            && self.bounds.min.y <= bounds.max.y
            && self.bounds.max.z >= bounds.min.z
            && self.bounds.min.z <= bounds.max.z
    }

    /// Appends to `results` every payload in this subtree whose position lies inside
    /// `bounds` (faces inclusive). Subtrees whose bounds do not overlap `bounds` are
    /// skipped.
    fn query_range<'a>(&'a self, bounds: &BoundingBox, results: &mut Vec<&'a T>) {
        if !self.intersects_bounds(bounds) {
            return;
        }
        results.extend(
            self.objects
                .iter()
                .filter(|obj| Self::point_in_box(bounds, obj.position))
                .map(|obj| &obj.data),
        );
        if let Some(children) = &self.children {
            for child in children.iter() {
                child.query_range(bounds, results);
            }
        }
    }

    /// Appends to `results` every payload in this subtree within `radius` of `center`
    /// (distance inclusive).
    ///
    /// The closest point of this node's box to `center` is found by clamping. If even
    /// that point is farther than `radius`, the whole subtree is skipped.
    fn query_sphere<'a>(&'a self, center: Vec3, radius: f32, results: &mut Vec<&'a T>) {
        let closest = Vec3::new(
            center.x.max(self.bounds.min.x).min(self.bounds.max.x),
            center.y.max(self.bounds.min.y).min(self.bounds.max.y),
            center.z.max(self.bounds.min.z).min(self.bounds.max.z),
        );
        if center.distance(&closest) > radius {
            return;
        }
        results.extend(
            self.objects
                .iter()
                .filter(|obj| center.distance(&obj.position) <= radius)
                .map(|obj| &obj.data),
        );
        if let Some(children) = &self.children {
            for child in children.iter() {
                child.query_sphere(center, radius, results);
            }
        }
    }

    /// Removes every object in this subtree and collapses it back into a single leaf.
    fn clear(&mut self) {
        self.objects.clear();
        // Dropping the children releases the whole subtree; there is no need to clear
        // each child first.
        self.children = None;
    }

    /// Returns the number of objects stored in this subtree, including this node.
    fn count(&self) -> usize {
        let below: usize = self
            .children
            .iter()
            .flat_map(|children| children.iter())
            .map(Self::count)
            .sum();
        self.objects.len() + below
    }
}
/// Summary statistics describing the shape of an octree, as produced by `get_stats`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OctreeStats {
    /// Number of nodes in the tree, including the root.
    pub total_nodes: usize,
    /// Number of stored objects across all nodes.
    pub total_objects: usize,
    /// Depth of the deepest node (the root has depth 0).
    pub max_depth_reached: usize,
    /// `total_objects / total_nodes`, or `0.0` when there are no nodes.
    pub average_objects_per_node: f32,
}
impl<T: Clone> Octree<T> {
    /// Walks the whole tree and summarises its shape.
    ///
    /// `average_objects_per_node` is `total_objects / total_nodes`. It falls back to `0.0`
    /// when the node count is zero.
    pub fn get_stats(&self) -> OctreeStats {
        let (total_nodes, max_depth_reached) = self.root.count_nodes();
        let total_objects = self.count();
        let average_objects_per_node = match total_nodes {
            0 => 0.0,
            n => total_objects as f32 / n as f32,
        };
        OctreeStats {
            total_nodes,
            total_objects,
            max_depth_reached,
            average_objects_per_node,
        }
    }
}
impl<T: Clone> OctreeNode<T> {
    /// Returns `(node_count, deepest_depth)` for the subtree rooted at this node.
    ///
    /// The count includes this node itself. The depth is the largest `depth` value found
    /// anywhere in the subtree, and is never less than this node's own depth.
    fn count_nodes(&self) -> (usize, usize) {
        self.children
            .iter()
            .flat_map(|octants| octants.iter())
            .map(Self::count_nodes)
            .fold((1, self.depth), |(nodes, deepest), (sub_nodes, sub_depth)| {
                (nodes + sub_nodes, deepest.max(sub_depth))
            })
    }
}