use std::collections::HashMap;
use crate::point::{BoundingBox3D, Point3D};
/// A single node of the octree: either a leaf (no children) or an internal
/// node with exactly eight child octants.
pub struct OctreeNode {
/// Region of space covered by this node.
pub bounds: BoundingBox3D,
/// Points stored directly on this node. Leaves hold their points here;
/// an internal node keeps only points that none of its children reported
/// containing when they were distributed.
pub points: Vec<Point3D>,
/// The eight child octants, or `None` while this node is a leaf.
pub children: Option<Box<[OctreeNode; 8]>>,
/// Depth of this node; the tree root is created with depth 0 and each
/// split creates children one level deeper.
pub depth: u8,
}
impl OctreeNode {
pub fn new(bounds: BoundingBox3D, depth: u8) -> Self {
Self {
bounds,
points: Vec::new(),
children: None,
depth,
}
}
pub fn point_count(&self) -> usize {
match &self.children {
None => self.points.len(),
Some(children) => {
self.points.len() + children.iter().map(|c| c.point_count()).sum::<usize>()
}
}
}
#[inline]
pub fn is_leaf(&self) -> bool {
self.children.is_none()
}
pub fn depth_max(&self) -> u8 {
match &self.children {
None => self.depth,
Some(children) => children
.iter()
.map(|c| c.depth_max())
.max()
.unwrap_or(self.depth),
}
}
}
/// An octree spatial index over `Point3D` values, rooted at a fixed bounding box.
pub struct Octree {
/// Root node; its bounds are the bounds the tree was created with.
root: OctreeNode,
/// A leaf is split once it holds more than this many points
/// (provided it is shallower than `max_depth`).
max_points_per_node: usize,
/// Leaves at this depth are never split further.
max_depth: u8,
/// Number of points accepted by `insert` so far.
total_points: usize,
}
impl Octree {
pub fn new(bounds: BoundingBox3D) -> Self {
Self {
root: OctreeNode::new(bounds, 0),
max_points_per_node: 64,
max_depth: 16,
total_points: 0,
}
}
/// Builder-style setter: returns the tree with its per-leaf point limit
/// set to `n`. A leaf is split when it holds more than `n` points.
pub fn with_max_points(self, n: usize) -> Self {
    let mut tree = self;
    tree.max_points_per_node = n;
    tree
}
/// Builder-style setter: returns the tree with its maximum depth set to
/// `depth`. Leaves at that depth are not split any further.
pub fn with_max_depth(self, depth: u8) -> Self {
    let mut tree = self;
    tree.max_depth = depth;
    tree
}
#[inline]
pub fn len(&self) -> usize {
self.total_points
}
/// Returns `true` when no point has been accepted by `insert` yet.
#[inline]
pub fn is_empty(&self) -> bool {
    matches!(self.total_points, 0)
}
/// Inserts `point` into the tree.
///
/// A point that the root bounds do not contain is silently dropped and
/// does not change `len()`.
pub fn insert(&mut self, point: Point3D) {
    if self.root.bounds.contains(&point) {
        Self::node_insert(
            &mut self.root,
            point,
            self.max_points_per_node,
            self.max_depth,
        );
        self.total_points += 1;
    }
}
/// Inserts every point of `points`, in order, via `insert`.
///
/// Points outside the root bounds are skipped exactly as `insert` skips them.
pub fn insert_batch(&mut self, points: Vec<Point3D>) {
    points.into_iter().for_each(|pt| self.insert(pt));
}
/// Returns references to all stored points that `bbox` contains.
///
/// Subtrees whose bounds do not intersect `bbox` are skipped.
pub fn query_bbox<'a>(&'a self, bbox: &BoundingBox3D) -> Vec<&'a Point3D> {
    let mut hits: Vec<&'a Point3D> = Vec::new();
    Self::node_query_bbox(&self.root, bbox, &mut hits);
    hits
}
/// Returns references to all stored points whose Euclidean distance to
/// `(cx, cy, cz)` is at most `radius` (points exactly on the sphere are
/// included).
///
/// A negative or NaN `radius` yields an empty result. Without this guard a
/// negative radius would still produce a positive squared radius while the
/// axis-aligned query box built below would be inverted (min > max), so the
/// outcome would depend on how `intersects_3d` treats inverted boxes.
pub fn query_sphere(&self, cx: f64, cy: f64, cz: f64, radius: f64) -> Vec<&Point3D> {
    let mut result = Vec::new();
    if radius.is_nan() || radius < 0.0 {
        return result;
    }
    // Axis-aligned box enclosing the sphere; used only to prune subtrees.
    let sphere_bbox = BoundingBox3D {
        min_x: cx - radius,
        min_y: cy - radius,
        min_z: cz - radius,
        max_x: cx + radius,
        max_y: cy + radius,
        max_z: cz + radius,
    };
    // Distances are compared squared to avoid a sqrt per point.
    let r2 = radius * radius;
    Self::node_query_sphere(&self.root, cx, cy, cz, r2, &sphere_bbox, &mut result);
    result
}
/// Returns up to `k` stored points closest to `(cx, cy, cz)`, each paired
/// with its Euclidean distance, ordered by ascending distance.
///
/// Returns an empty vector when `k` is 0 or the tree holds no points.
pub fn k_nearest(&self, cx: f64, cy: f64, cz: f64, k: usize) -> Vec<(&Point3D, f64)> {
    let mut found: Vec<(&Point3D, f64)> = Vec::new();
    if k > 0 && !self.is_empty() {
        Self::node_collect_knn(&self.root, cx, cy, cz, k, f64::INFINITY, &mut found);
        // Incomparable distances (NaN) are treated as equal.
        found.sort_by(|lhs, rhs| match lhs.1.partial_cmp(&rhs.1) {
            Some(ordering) => ordering,
            None => std::cmp::Ordering::Equal,
        });
        found.truncate(k);
    }
    found
}
/// Returns references to all stored points whose `classification` equals `class`.
pub fn by_classification(&self, class: u8) -> Vec<&Point3D> {
    let mut matching: Vec<&Point3D> = Vec::new();
    Self::node_collect_by_class(&self.root, class, &mut matching);
    matching
}
/// Computes summary statistics over every point stored in the tree.
///
/// For an empty tree all numeric fields are zero, `bounds` is `None` and
/// the classification map is empty. Otherwise:
/// * `mean_z` / `std_z` are the mean and population standard deviation
///   (variance divided by the point count) of the z coordinates;
/// * `bounds` is whatever `BoundingBox3D::new` returns for the per-axis
///   minima and maxima of the stored points;
/// * `density` is the point count divided by the XY area of `bounds`, or
///   0.0 when there are no bounds or the area is not positive;
/// * `classification_counts` maps each classification to its point count.
pub fn stats(&self) -> PointCloudStats {
    let mut pts: Vec<&Point3D> = Vec::with_capacity(self.total_points);
    Self::node_collect_all(&self.root, &mut pts);

    let first = match pts.first() {
        Some(p) => *p,
        None => {
            return PointCloudStats {
                count: 0,
                bounds: None,
                mean_z: 0.0,
                std_z: 0.0,
                density: 0.0,
                classification_counts: HashMap::new(),
            }
        }
    };

    let count = pts.len();
    let n = count as f64;

    let mean_z = pts.iter().map(|p| p.z).sum::<f64>() / n;
    let variance_z = pts
        .iter()
        .map(|p| {
            let delta = p.z - mean_z;
            delta * delta
        })
        .sum::<f64>()
        / n;
    let std_z = variance_z.sqrt();

    let mut classification_counts: HashMap<u8, usize> = HashMap::new();
    for p in pts.iter() {
        *classification_counts.entry(p.classification).or_default() += 1;
    }

    // Per-axis extents, seeded from the first point and widened by the rest.
    let mut lo = (first.x, first.y, first.z);
    let mut hi = lo;
    for p in &pts[1..] {
        if p.x < lo.0 {
            lo.0 = p.x;
        }
        if p.y < lo.1 {
            lo.1 = p.y;
        }
        if p.z < lo.2 {
            lo.2 = p.z;
        }
        if p.x > hi.0 {
            hi.0 = p.x;
        }
        if p.y > hi.1 {
            hi.1 = p.y;
        }
        if p.z > hi.2 {
            hi.2 = p.z;
        }
    }
    let bounds = BoundingBox3D::new(lo.0, lo.1, lo.2, hi.0, hi.1, hi.2);

    // Points per unit of XY footprint; degenerate footprints give 0.0.
    let density = bounds
        .as_ref()
        .map(|bb| (bb.max_x - bb.min_x) * (bb.max_y - bb.min_y))
        .filter(|xy_area| *xy_area > 0.0)
        .map_or(0.0, |xy_area| n / xy_area);

    PointCloudStats {
        count,
        bounds,
        mean_z,
        std_z,
        density,
        classification_counts,
    }
}
pub fn voxel_downsample(&self, voxel_size: f64) -> Vec<Point3D> {
if voxel_size <= 0.0 {
return Vec::new();
}
let mut all_points: Vec<&Point3D> = Vec::with_capacity(self.total_points);
Self::node_collect_all(&self.root, &mut all_points);
let origin_x = self.root.bounds.min_x;
let origin_y = self.root.bounds.min_y;
let origin_z = self.root.bounds.min_z;
let mut occupied: HashMap<(i64, i64, i64), ()> = HashMap::new();
let mut result: Vec<Point3D> = Vec::new();
for &p in &all_points {
let ix = ((p.x - origin_x) / voxel_size).floor() as i64;
let iy = ((p.y - origin_y) / voxel_size).floor() as i64;
let iz = ((p.z - origin_z) / voxel_size).floor() as i64;
let key = (ix, iy, iz);
if occupied.insert(key, ()).is_none() {
result.push((*p).clone());
}
}
result
}
/// Inserts `point` into the subtree rooted at `node`.
///
/// * Leaf: the point is stored on the leaf; if the leaf then holds more
///   than `max_pts` points and is shallower than `max_d`, it is split.
/// * Internal node: the point is forwarded to the first child whose bounds
///   report containing it. If no child does, the point is kept on the
///   internal node itself so it is never lost.
fn node_insert(node: &mut OctreeNode, point: Point3D, max_pts: usize, max_d: u8) {
    match node.children.as_mut() {
        None => {
            node.points.push(point);
            if node.points.len() > max_pts && node.depth < max_d {
                Self::split_node(node, max_pts, max_d);
            }
        }
        Some(children) => {
            let target = children.iter_mut().find(|c| c.bounds.contains(&point));
            match target {
                Some(child) => Self::node_insert(child, point, max_pts, max_d),
                None => node.points.push(point),
            }
        }
    }
}
/// Turns `node` into an internal node with eight children (one per octant
/// returned by `split_octants`) and redistributes its points.
///
/// Each existing point goes to the first child whose bounds report
/// containing it (which may trigger further splits through `node_insert`);
/// points no child accepts stay on `node` in their original relative order.
fn split_node(node: &mut OctreeNode, max_pts: usize, max_d: u8) {
    let octants = node.bounds.split_octants();
    let child_depth = node.depth.saturating_add(1);
    let make_child = |i: usize| OctreeNode::new(octants[i].clone(), child_depth);
    let mut kids: Box<[OctreeNode; 8]> = Box::new([
        make_child(0),
        make_child(1),
        make_child(2),
        make_child(3),
        make_child(4),
        make_child(5),
        make_child(6),
        make_child(7),
    ]);

    // `take` leaves `node.points` empty, so pushes below only collect the
    // points that no child accepted.
    let pending = std::mem::take(&mut node.points);
    for p in pending {
        match kids.iter().position(|kid| kid.bounds.contains(&p)) {
            Some(idx) => Self::node_insert(&mut kids[idx], p, max_pts, max_d),
            None => node.points.push(p),
        }
    }
    node.children = Some(kids);
}
/// Appends to `result` every point in the subtree rooted at `node` that
/// `bbox` contains. The whole subtree is skipped when the node's bounds do
/// not intersect `bbox`.
fn node_query_bbox<'a>(
    node: &'a OctreeNode,
    bbox: &BoundingBox3D,
    result: &mut Vec<&'a Point3D>,
) {
    if node.bounds.intersects_3d(bbox) {
        result.extend(node.points.iter().filter(|&p| bbox.contains(p)));
        if let Some(kids) = node.children.as_ref() {
            kids.iter()
                .for_each(|kid| Self::node_query_bbox(kid, bbox, result));
        }
    }
}
/// Appends to `result` every point in the subtree rooted at `node` whose
/// squared distance to `(cx, cy, cz)` is at most `r2`.
///
/// `sphere_bbox` is the box supplied by the caller for pruning: a subtree
/// is skipped when the node's bounds do not intersect it.
fn node_query_sphere<'a>(
    node: &'a OctreeNode,
    cx: f64,
    cy: f64,
    cz: f64,
    r2: f64,
    sphere_bbox: &BoundingBox3D,
    result: &mut Vec<&'a Point3D>,
) {
    if node.bounds.intersects_3d(sphere_bbox) {
        let inside = |p: &Point3D| {
            let (dx, dy, dz) = (p.x - cx, p.y - cy, p.z - cz);
            dx * dx + dy * dy + dz * dz <= r2
        };
        result.extend(node.points.iter().filter(|&p| inside(p)));
        if let Some(kids) = node.children.as_ref() {
            for kid in kids.iter() {
                Self::node_query_sphere(kid, cx, cy, cz, r2, sphere_bbox, result);
            }
        }
    }
}
/// Collects nearest-neighbour candidates for `(cx, cy, cz)` from the
/// subtree rooted at `node` into `result`.
///
/// `result` holds at most `k` `(point, euclidean_distance)` pairs, kept in
/// the order the points were visited. Once `k` candidates are held, a new
/// point is admitted only if it is strictly closer than the farthest held
/// candidate, which it then replaces (on ties for farthest, the most
/// recently visited one is evicted). After a stable sort by distance this
/// yields the same neighbours as collecting every point and truncating.
///
/// `best_dist` is an upper bound on the k-th nearest distance; a subtree
/// whose closest possible point is farther than that is skipped. The bound
/// is re-derived from `result` before descending into each child, so
/// candidates found in one child tighten the search in its siblings.
fn node_collect_knn<'a>(
    node: &'a OctreeNode,
    cx: f64,
    cy: f64,
    cz: f64,
    k: usize,
    mut best_dist: f64,
    result: &mut Vec<(&'a Point3D, f64)>,
) {
    /// Index and distance of the farthest candidate (last one on ties).
    fn farthest(found: &[(&Point3D, f64)]) -> Option<(usize, f64)> {
        found
            .iter()
            .map(|candidate| candidate.1)
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
    }

    if Self::node_min_dist_sq(node, cx, cy, cz) > best_dist * best_dist {
        return;
    }

    for p in &node.points {
        let dx = p.x - cx;
        let dy = p.y - cy;
        let dz = p.z - cz;
        let dist = (dx * dx + dy * dy + dz * dz).sqrt();

        if result.len() < k {
            result.push((p, dist));
            continue;
        }
        let worst = farthest(result.as_slice());
        if let Some((idx, worst_dist)) = worst {
            if dist < worst_dist {
                // Remove + push (rather than overwrite) keeps `result` in
                // visiting order, which the tie-breaking above relies on.
                result.remove(idx);
                result.push((p, dist));
            }
        }
    }

    if let Some(children) = &node.children {
        for child in children.iter() {
            // Only a full candidate set gives a valid bound on the k-th distance.
            if result.len() >= k {
                if let Some((_, worst_dist)) = farthest(result.as_slice()) {
                    if worst_dist < best_dist {
                        best_dist = worst_dist;
                    }
                }
            }
            Self::node_collect_knn(child, cx, cy, cz, k, best_dist, result);
        }
    }
}
/// Returns the squared distance from `(cx, cy, cz)` to the nearest point of
/// `node`'s bounding box; 0.0 when the query point lies inside the box.
fn node_min_dist_sq(node: &OctreeNode, cx: f64, cy: f64, cz: f64) -> f64 {
    // Distance along one axis from `v` to the interval [lo, hi].
    let axis_gap = |v: f64, lo: f64, hi: f64| -> f64 {
        if v < lo {
            lo - v
        } else if v > hi {
            v - hi
        } else {
            0.0
        }
    };
    let bb = &node.bounds;
    let gx = axis_gap(cx, bb.min_x, bb.max_x);
    let gy = axis_gap(cy, bb.min_y, bb.max_y);
    let gz = axis_gap(cz, bb.min_z, bb.max_z);
    gx * gx + gy * gy + gz * gz
}
/// Appends to `result` every point in the subtree rooted at `node` whose
/// `classification` equals `class`. The node's own points come first,
/// followed by each child's subtree in order.
fn node_collect_by_class<'a>(node: &'a OctreeNode, class: u8, result: &mut Vec<&'a Point3D>) {
    result.extend(node.points.iter().filter(|p| p.classification == class));
    if let Some(kids) = node.children.as_ref() {
        kids.iter()
            .for_each(|kid| Self::node_collect_by_class(kid, class, result));
    }
}
/// Appends a reference to every point in the subtree rooted at `node` to
/// `result`: the node's own points first, then each child's subtree in order.
fn node_collect_all<'a>(node: &'a OctreeNode, result: &mut Vec<&'a Point3D>) {
    for p in &node.points {
        result.push(p);
    }
    if let Some(kids) = node.children.as_ref() {
        kids.iter()
            .for_each(|kid| Self::node_collect_all(kid, result));
    }
}
}
/// Summary statistics produced by `Octree::stats`.
pub struct PointCloudStats {
/// Number of points the statistics were computed over.
pub count: usize,
/// Bounding box built from the per-axis minima and maxima of the points;
/// `None` for an empty cloud or when `BoundingBox3D::new` yields `None`.
pub bounds: Option<BoundingBox3D>,
/// Mean of the z coordinates (0.0 for an empty cloud).
pub mean_z: f64,
/// Population standard deviation of the z coordinates (0.0 for an empty cloud).
pub std_z: f64,
/// Point count divided by the XY area of `bounds`; 0.0 when there are no
/// bounds or the area is not positive.
pub density: f64,
/// Number of points per `classification` value.
pub classification_counts: HashMap<u8, usize>,
}