/// Deepest level a node may reach. Nodes at this depth never subdivide, so they
/// keep accepting points regardless of the split threshold.
pub const MAX_DEPTH: u8 = 17;

/// A node of a point octree over an axis-aligned cube.
///
/// The node covers the half-open box `[center[i] - half_len, center[i] + half_len)`
/// on each axis `i`. Leaves (`children == None`) hold points directly; when a leaf
/// subdivides, its points are moved down into its eight children.
#[derive(Debug)]
pub struct OctreeNode {
    /// Centre of the cube covered by this node.
    pub center: [f64; 3],
    /// Half of the cube's edge length.
    pub half_len: f64,
    /// Number of subdivisions between the root (depth 0) and this node.
    pub depth: u8,
    /// Points held by this node; `insert` only ever stores points in leaves.
    pub points: Vec<[f64; 3]>,
    /// The eight octants, once subdivided. Child `i` lies on the positive side of
    /// the centre along axis `a` (0 = x, 1 = y, 2 = z) iff bit `a` of `i` is set.
    pub children: Option<Box<[OctreeNode; 8]>>,
}

impl OctreeNode {
    /// Creates an empty root node (depth 0) covering the cube centred at
    /// `center` with half edge length `half_len`.
    pub fn new(center: [f64; 3], half_len: f64) -> Self {
        Self::with_depth(center, half_len, 0)
    }

    /// Creates an empty leaf at the given depth.
    fn with_depth(center: [f64; 3], half_len: f64, depth: u8) -> Self {
        Self { center, half_len, depth, points: Vec::new(), children: None }
    }

    /// Inserts `p` into the subtree rooted at this node.
    ///
    /// Points outside this node's cube are silently ignored. A leaf that already
    /// holds `threshold` or more points subdivides on the next insertion (unless
    /// it is at [`MAX_DEPTH`]) and pushes its points down into the new children.
    pub fn insert(&mut self, p: [f64; 3], threshold: usize) {
        if self.contains(p) {
            self.insert_contained(p, threshold);
        }
    }

    /// Inserts every point of `pts`, in order, as if by repeated [`Self::insert`].
    pub fn insert_batch(&mut self, pts: &[[f64; 3]], threshold: usize) {
        for &p in pts {
            self.insert(p, threshold);
        }
    }

    /// Inserts a point already known to lie inside this node.
    ///
    /// Containment is checked only once, at the public entry point. Below that,
    /// points are routed by comparing against each node's centre, which splits the
    /// parent into eight octants with no gaps. Re-testing against each child's own
    /// bounds could drop points, because `(c - h) + h` need not round back to `c`.
    fn insert_contained(&mut self, p: [f64; 3], threshold: usize) {
        if self.children.is_none() {
            if self.depth >= MAX_DEPTH || self.points.len() < threshold {
                self.points.push(p);
                return;
            }
            self.subdivide();
            // Interior nodes hold no points: move the former leaf's points down.
            for old in std::mem::take(&mut self.points) {
                self.insert_into_child(old, threshold);
            }
        }
        self.insert_into_child(p, threshold);
    }

    /// Routes `p` into the child octant selected by [`Self::octant_of`].
    /// Does nothing if this node has not been subdivided.
    fn insert_into_child(&mut self, p: [f64; 3], threshold: usize) {
        let idx = self.octant_of(p);
        if let Some(children) = self.children.as_mut() {
            children[idx].insert_contained(p, threshold);
        }
    }

    /// Index (0..8) of the child octant for `p`: bit `a` is set when `p` is at or
    /// beyond the centre along axis `a`, matching the half-open node bounds.
    fn octant_of(&self, p: [f64; 3]) -> usize {
        (0..3)
            .filter(|&axis| p[axis] >= self.center[axis])
            .map(|axis| 1usize << axis)
            .sum()
    }

    /// Whether `p` lies in this node's half-open cube (NaN coordinates never do).
    fn contains(&self, p: [f64; 3]) -> bool {
        let h = self.half_len;
        (0..3).all(|i| p[i] >= self.center[i] - h && p[i] < self.center[i] + h)
    }

    /// Splits this node into eight children one level deeper, each with half the
    /// parent's `half_len`. Child `i` is offset by `+h` on axis `a` when bit `a` of
    /// `i` is set and by `-h` otherwise, matching [`Self::octant_of`].
    fn subdivide(&mut self) {
        let h = self.half_len / 2.0;
        let depth = self.depth + 1;
        let center = self.center;
        self.children = Some(Box::new(std::array::from_fn(|idx| {
            let child_center = std::array::from_fn(|axis| {
                if idx & (1 << axis) != 0 {
                    center[axis] + h
                } else {
                    center[axis] - h
                }
            });
            OctreeNode::with_depth(child_center, h, depth)
        })));
    }

    /// Returns the nodes at depth `target`, in child order.
    ///
    /// If this node itself is at `target`, returns just this node without
    /// descending further.
    pub fn nodes_at_depth(&self, target: u8) -> Vec<&OctreeNode> {
        if self.depth == target {
            return vec![self];
        }
        match &self.children {
            Some(children) => children.iter().flat_map(|c| c.nodes_at_depth(target)).collect(),
            None => Vec::new(),
        }
    }

    /// Returns references to every stored point `q` with `min[i] <= q[i] <= max[i]`
    /// on all three axes. Both bounds are inclusive.
    pub fn range_query(&self, min: [f64; 3], max: [f64; 3]) -> Vec<&[f64; 3]> {
        let mut found = Vec::new();
        self.collect_in_range(min, max, &mut found);
        found
    }

    /// Appends the points of this subtree that lie inside the closed query box to
    /// `out`, skipping subtrees whose cube cannot intersect it.
    fn collect_in_range<'a>(&'a self, min: [f64; 3], max: [f64; 3], out: &mut Vec<&'a [f64; 3]>) {
        let h = self.half_len;
        // The node is [lo, hi) while the query is closed. A node whose lower face
        // equals `max` can still hold points equal to `max`, so the lower-face
        // test must be `<=`. The open upper face correctly keeps `>`.
        let overlaps =
            (0..3).all(|i| self.center[i] + h > min[i] && self.center[i] - h <= max[i]);
        if !overlaps {
            return;
        }
        match &self.children {
            None => out.extend(
                self.points
                    .iter()
                    .filter(|q| (0..3).all(|i| q[i] >= min[i] && q[i] <= max[i])),
            ),
            Some(children) => {
                for child in children.iter() {
                    child.collect_in_range(min, max, out);
                }
            }
        }
    }
}
/// Picks how many times `world_size_m` must be halved before it no longer exceeds
/// `max_error_m`, capped at [`MAX_DEPTH`].
///
/// Returns 0 when `world_size_m` already fits (or either argument is NaN).
/// Returns `MAX_DEPTH` when the requested accuracy is not reachable within the cap.
/// NOTE(review): `world_size_m` is presumably the root cube's full edge length;
/// confirm against callers.
pub fn depth_for_accuracy(max_error_m: f64, world_size_m: f64) -> u8 {
    let halvings = std::iter::successors(Some(world_size_m), |edge| Some(edge / 2.0))
        .take_while(|&edge| edge > max_error_m)
        .take(usize::from(MAX_DEPTH))
        .count();
    // `take(MAX_DEPTH)` bounds the count, so this cast cannot truncate.
    halvings as u8
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of points found by a query spanning `[-half, half]` on every axis.
    fn count_within(tree: &OctreeNode, half: f64) -> usize {
        tree.range_query([-half; 3], [half; 3]).len()
    }

    #[test]
    fn insert_and_range_query() {
        let mut tree = OctreeNode::new([0.0; 3], 100.0);
        for p in [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [-50.0, -50.0, -50.0]] {
            tree.insert(p, 4);
        }
        assert_eq!(tree.range_query([0.0; 3], [5.0; 3]).len(), 2);
        assert_eq!(count_within(&tree, 100.0), 3);
    }

    #[test]
    fn points_outside_root_are_dropped() {
        let mut tree = OctreeNode::new([0.0; 3], 10.0);
        tree.insert([100.0, 0.0, 0.0], 4);
        assert!(tree.range_query([-10.0; 3], [10.0; 3]).is_empty());
    }

    #[test]
    fn subdivision_preserves_all_points() {
        let samples = [
            [10.0, 10.0, 10.0],
            [-10.0, 10.0, 10.0],
            [10.0, -10.0, 10.0],
            [10.0, 10.0, -10.0],
            [-10.0, -10.0, -10.0],
        ];
        let mut tree = OctreeNode::new([0.0; 3], 100.0);
        tree.insert_batch(&samples, 2);
        assert!(tree.children.is_some(), "should have subdivided past threshold");
        assert_eq!(
            count_within(&tree, 100.0),
            samples.len(),
            "no points lost during subdivision"
        );
    }

    #[test]
    fn depth_for_accuracy_scales() {
        let cases = [(1.0, 1000.0, 10), (2000.0, 1000.0, 0), (1e-9, 1000.0, MAX_DEPTH)];
        for (max_error, world, expected) in cases {
            assert_eq!(depth_for_accuracy(max_error, world), expected);
        }
    }
}