1#![allow(dead_code)]
5
/// One node of an octree: an axis-aligned box that is either a leaf storing
/// point indices directly, or an interior node with exactly eight children.
#[derive(Debug, Clone, PartialEq)]
pub struct OctreeNode {
    /// Lower corner of this node's axis-aligned bounding box.
    pub bounds_min: [f32; 3],
    /// Upper corner of this node's axis-aligned bounding box.
    pub bounds_max: [f32; 3],
    /// Indices into the owning tree's point list. Leaves hold their points
    /// here; interior nodes built or split by this module leave it empty.
    pub point_indices: Vec<usize>,
    /// The eight child octants, or `None` for a leaf. Child `i` covers the
    /// upper half along x/y/z when bit 0/1/2 of `i` is set.
    pub children: Option<Box<[OctreeNode; 8]>>,
}
16
17#[derive(Debug)]
19pub struct Octree {
20 pub root: OctreeNode,
21 pub points: Vec<[f32; 3]>,
22 pub max_depth: u32,
23 pub max_points_per_leaf: usize,
24}
25
/// Returns `true` when `p` lies inside the closed box `[min, max]` on every
/// axis. Any NaN coordinate makes the test fail.
fn aabb_contains(min: &[f32; 3], max: &[f32; 3], p: &[f32; 3]) -> bool {
    (0..3).all(|axis| min[axis] <= p[axis] && p[axis] <= max[axis])
}
36
/// Returns `true` when the closed boxes `[amin, amax]` and `[bmin, bmax]`
/// share at least one point (touching faces count as overlapping).
fn aabb_overlaps_aabb(amin: &[f32; 3], amax: &[f32; 3], bmin: &[f32; 3], bmax: &[f32; 3]) -> bool {
    // The two intervals must overlap on every axis.
    (0..3).all(|axis| amin[axis] <= bmax[axis] && bmin[axis] <= amax[axis])
}
45
/// Returns `true` when the closed box `[min, max]` and the closed sphere of
/// `radius` around `center` intersect.
///
/// An inverted box (`min > max` on some axis) contains nothing, so it never
/// overlaps. An empty tree's root carries such bounds. Boxes with NaN bounds
/// are treated the same way.
fn aabb_overlaps_sphere(min: &[f32; 3], max: &[f32; 3], center: &[f32; 3], radius: f32) -> bool {
    // Must come first: `f32::clamp` panics when `min > max` or either bound is
    // NaN, which previously made sphere queries on an empty octree panic.
    if !(0..3).all(|axis| min[axis] <= max[axis]) {
        return false;
    }
    // Squared distance from `center` to the nearest point of the box.
    let mut gap_sq = 0.0_f32;
    for axis in 0..3 {
        let nearest = center[axis].clamp(min[axis], max[axis]);
        let d = center[axis] - nearest;
        gap_sq += d * d;
    }
    gap_sq <= radius * radius
}
55
/// Slab test: does the segment `origin + t * dir` for `t` in `[0, max_dist]`
/// touch the closed box `[min, max]`?
///
/// `inv_dir` holds the per-axis reciprocals of the ray direction. A
/// non-finite entry marks an axis the ray is treated as parallel to.
fn aabb_overlaps_ray(
    min: &[f32; 3],
    max: &[f32; 3],
    origin: &[f32; 3],
    inv_dir: &[f32; 3],
    max_dist: f32,
) -> bool {
    let mut t_enter = 0.0_f32;
    let mut t_exit = max_dist;
    for axis in 0..3 {
        let (lo, hi) = (min[axis], max[axis]);
        let (o, inv) = (origin[axis], inv_dir[axis]);
        if !inv.is_finite() {
            // A parallel ray can only hit the box if its origin already lies
            // between this axis' two planes.
            if o < lo || o > hi {
                return false;
            }
            continue;
        }
        let near = (lo - o) * inv;
        let far = (hi - o) * inv;
        t_enter = t_enter.max(near.min(far));
        t_exit = t_exit.min(near.max(far));
    }
    t_enter <= t_exit
}
82
/// Squared Euclidean distance between `a` and `b`.
fn dist_sq(a: &[f32; 3], b: &[f32; 3]) -> f32 {
    a.iter().zip(b).fold(0.0, |acc, (x, y)| {
        let d = x - y;
        acc + d * d
    })
}
89
/// Returns the `(min, max)` corners of child `octant` of the box
/// `[parent_min, parent_max]`, split at the box's midpoint.
///
/// Bit 0/1/2 of `octant` selects the upper half along x/y/z; higher bits are
/// ignored.
fn octant_bounds(
    parent_min: &[f32; 3],
    parent_max: &[f32; 3],
    octant: usize,
) -> ([f32; 3], [f32; 3]) {
    let mut child_min = [0.0_f32; 3];
    let mut child_max = [0.0_f32; 3];
    for axis in 0..3 {
        let mid = (parent_min[axis] + parent_max[axis]) * 0.5;
        let upper = octant & (1 << axis) != 0;
        let (lo, hi) = if upper {
            (mid, parent_max[axis])
        } else {
            (parent_min[axis], mid)
        };
        child_min[axis] = lo;
        child_max[axis] = hi;
    }
    (child_min, child_max)
}
132
133fn build_node(
136 indices: Vec<usize>,
137 points: &[[f32; 3]],
138 min: [f32; 3],
139 max: [f32; 3],
140 depth: u32,
141 max_depth: u32,
142 max_per_leaf: usize,
143) -> OctreeNode {
144 if depth >= max_depth || indices.len() <= max_per_leaf {
145 return OctreeNode {
146 bounds_min: min,
147 bounds_max: max,
148 point_indices: indices,
149 children: None,
150 };
151 }
152
153 let mut child_indices: [Vec<usize>; 8] = Default::default();
154 let mid = [
155 (min[0] + max[0]) * 0.5,
156 (min[1] + max[1]) * 0.5,
157 (min[2] + max[2]) * 0.5,
158 ];
159
160 for idx in &indices {
161 let p = &points[*idx];
162 let ox = if p[0] >= mid[0] { 1 } else { 0 };
163 let oy = if p[1] >= mid[1] { 2 } else { 0 };
164 let oz = if p[2] >= mid[2] { 4 } else { 0 };
165 child_indices[ox | oy | oz].push(*idx);
166 }
167
168 let children: [OctreeNode; 8] = std::array::from_fn(|i| {
169 let (cmin, cmax) = octant_bounds(&min, &max, i);
170 build_node(
171 std::mem::take(&mut child_indices[i]),
172 points,
173 cmin,
174 cmax,
175 depth + 1,
176 max_depth,
177 max_per_leaf,
178 )
179 });
180
181 OctreeNode {
182 bounds_min: min,
183 bounds_max: max,
184 point_indices: Vec::new(), children: Some(Box::new(children)),
186 }
187}
188
189pub fn build_octree(points: &[[f32; 3]], max_depth: u32, max_per_leaf: usize) -> Octree {
193 let mut min = [f32::MAX; 3];
194 let mut max = [f32::MIN; 3];
195 for p in points {
196 for i in 0..3 {
197 if p[i] < min[i] {
198 min[i] = p[i];
199 }
200 if p[i] > max[i] {
201 max[i] = p[i];
202 }
203 }
204 }
205 for i in 0..3 {
207 max[i] += 1e-4;
208 min[i] -= 1e-4;
209 }
210
211 let indices: Vec<usize> = (0..points.len()).collect();
212 let root = build_node(indices, points, min, max, 0, max_depth, max_per_leaf);
213 Octree {
214 root,
215 points: points.to_vec(),
216 max_depth,
217 max_points_per_leaf: max_per_leaf,
218 }
219}
220
221fn collect_sphere(
224 node: &OctreeNode,
225 center: &[f32; 3],
226 radius: f32,
227 result: &mut Vec<usize>,
228 points: &[[f32; 3]],
229) {
230 if !aabb_overlaps_sphere(&node.bounds_min, &node.bounds_max, center, radius) {
231 return;
232 }
233 if let Some(children) = &node.children {
234 for child in children.iter() {
235 collect_sphere(child, center, radius, result, points);
236 }
237 } else {
238 let r2 = radius * radius;
239 for &idx in &node.point_indices {
240 if dist_sq(&points[idx], center) <= r2 {
241 result.push(idx);
242 }
243 }
244 }
245}
246
247pub fn query_sphere(octree: &Octree, center: [f32; 3], radius: f32) -> Vec<usize> {
249 let mut result = Vec::new();
250 collect_sphere(&octree.root, ¢er, radius, &mut result, &octree.points);
251 result
252}
253
254fn collect_aabb(
255 node: &OctreeNode,
256 min: &[f32; 3],
257 max: &[f32; 3],
258 result: &mut Vec<usize>,
259 points: &[[f32; 3]],
260) {
261 if !aabb_overlaps_aabb(&node.bounds_min, &node.bounds_max, min, max) {
262 return;
263 }
264 if let Some(children) = &node.children {
265 for child in children.iter() {
266 collect_aabb(child, min, max, result, points);
267 }
268 } else {
269 for &idx in &node.point_indices {
270 if aabb_contains(min, max, &points[idx]) {
271 result.push(idx);
272 }
273 }
274 }
275}
276
277pub fn query_aabb(octree: &Octree, min: [f32; 3], max: [f32; 3]) -> Vec<usize> {
279 let mut result = Vec::new();
280 collect_aabb(&octree.root, &min, &max, &mut result, &octree.points);
281 result
282}
283
284fn collect_nn(
285 node: &OctreeNode,
286 query: &[f32; 3],
287 best_dist_sq: &mut f32,
288 best_idx: &mut Option<usize>,
289 points: &[[f32; 3]],
290) {
291 let mut node_dist_sq = 0.0_f32;
293 for ((&q, &bmin), &bmax) in query
294 .iter()
295 .zip(node.bounds_min.iter())
296 .zip(node.bounds_max.iter())
297 {
298 if bmin > bmax {
299 return;
300 } let v = q.clamp(bmin, bmax);
302 let d = q - v;
303 node_dist_sq += d * d;
304 }
305 if node_dist_sq >= *best_dist_sq {
306 return;
307 }
308
309 if let Some(children) = &node.children {
310 for child in children.iter() {
311 collect_nn(child, query, best_dist_sq, best_idx, points);
312 }
313 } else {
314 for &idx in &node.point_indices {
315 let d2 = dist_sq(&points[idx], query);
316 if d2 < *best_dist_sq {
317 *best_dist_sq = d2;
318 *best_idx = Some(idx);
319 }
320 }
321 }
322}
323
324pub fn nearest_neighbor(octree: &Octree, query: [f32; 3]) -> Option<(usize, f32)> {
326 let mut best_dist_sq = f32::MAX;
327 let mut best_idx = None;
328 collect_nn(
329 &octree.root,
330 &query,
331 &mut best_dist_sq,
332 &mut best_idx,
333 &octree.points,
334 );
335 best_idx.map(|idx| (idx, best_dist_sq.sqrt()))
336}
337
338fn collect_knn(
339 node: &OctreeNode,
340 query: &[f32; 3],
341 heap: &mut Vec<(f32, usize)>, k: usize,
343 points: &[[f32; 3]],
344) {
345 let worst = if heap.len() == k { heap[0].0 } else { f32::MAX };
346 let mut node_dist_sq = 0.0_f32;
347 for ((&q, &bmin), &bmax) in query
348 .iter()
349 .zip(node.bounds_min.iter())
350 .zip(node.bounds_max.iter())
351 {
352 if bmin > bmax {
353 return;
354 } let v = q.clamp(bmin, bmax);
356 let d = q - v;
357 node_dist_sq += d * d;
358 }
359 if node_dist_sq >= worst {
360 return;
361 }
362
363 if let Some(children) = &node.children {
364 for child in children.iter() {
365 collect_knn(child, query, heap, k, points);
366 }
367 } else {
368 for &idx in &node.point_indices {
369 let d2 = dist_sq(&points[idx], query);
370 let cur_worst = if heap.len() == k { heap[0].0 } else { f32::MAX };
371 if d2 < cur_worst {
372 if heap.len() == k {
373 heap.remove(0);
375 }
376 let pos = heap.partition_point(|&(d, _)| d < d2);
378 heap.insert(pos, (d2, idx));
379 }
380 }
381 }
382}
383
384pub fn k_nearest_neighbors(octree: &Octree, query: [f32; 3], k: usize) -> Vec<(usize, f32)> {
386 if k == 0 {
387 return Vec::new();
388 }
389 let mut heap: Vec<(f32, usize)> = Vec::with_capacity(k + 1);
390 collect_knn(&octree.root, &query, &mut heap, k, &octree.points);
391 heap.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
392 heap.into_iter().map(|(d2, idx)| (idx, d2.sqrt())).collect()
393}
394
395fn node_depth(node: &OctreeNode) -> u32 {
396 match &node.children {
397 None => 0,
398 Some(children) => 1 + children.iter().map(node_depth).max().unwrap_or(0),
399 }
400}
401
402pub fn octree_depth(octree: &Octree) -> u32 {
404 node_depth(&octree.root)
405}
406
407fn count_leaves(node: &OctreeNode) -> usize {
408 match &node.children {
409 None => 1,
410 Some(children) => children.iter().map(count_leaves).sum(),
411 }
412}
413
414pub fn octree_leaf_count(octree: &Octree) -> usize {
415 count_leaves(&octree.root)
416}
417
418fn count_points(node: &OctreeNode) -> usize {
419 match &node.children {
420 None => node.point_indices.len(),
421 Some(children) => children.iter().map(count_points).sum(),
422 }
423}
424
425pub fn octree_point_count(octree: &Octree) -> usize {
426 count_points(&octree.root)
427}
428
429pub fn octree_stats(octree: &Octree) -> (u32, usize, usize) {
431 (
432 octree_depth(octree),
433 octree_leaf_count(octree),
434 octree_point_count(octree),
435 )
436}
437
438pub fn insert_point(octree: &mut Octree, point: [f32; 3]) -> usize {
441 let idx = octree.points.len();
442 octree.points.push(point);
443 insert_into_node(
444 &mut octree.root,
445 point,
446 idx,
447 0,
448 octree.max_depth,
449 octree.max_points_per_leaf,
450 &octree.points.clone(),
451 );
452 idx
453}
454
455fn insert_into_node(
456 node: &mut OctreeNode,
457 point: [f32; 3],
458 idx: usize,
459 depth: u32,
460 max_depth: u32,
461 max_per_leaf: usize,
462 points: &[[f32; 3]],
463) {
464 if !aabb_contains(&node.bounds_min, &node.bounds_max, &point) {
465 for ((&pv, bmin), bmax) in point
467 .iter()
468 .zip(node.bounds_min.iter_mut())
469 .zip(node.bounds_max.iter_mut())
470 {
471 if pv < *bmin {
472 *bmin = pv - 1e-4;
473 }
474 if pv > *bmax {
475 *bmax = pv + 1e-4;
476 }
477 }
478 }
479
480 if let Some(children) = &mut node.children {
481 let mid = [
483 (node.bounds_min[0] + node.bounds_max[0]) * 0.5,
484 (node.bounds_min[1] + node.bounds_max[1]) * 0.5,
485 (node.bounds_min[2] + node.bounds_max[2]) * 0.5,
486 ];
487 let ox = if point[0] >= mid[0] { 1 } else { 0 };
488 let oy = if point[1] >= mid[1] { 2 } else { 0 };
489 let oz = if point[2] >= mid[2] { 4 } else { 0 };
490 insert_into_node(
491 &mut children[ox | oy | oz],
492 point,
493 idx,
494 depth + 1,
495 max_depth,
496 max_per_leaf,
497 points,
498 );
499 } else {
500 node.point_indices.push(idx);
501 if node.point_indices.len() > max_per_leaf && depth < max_depth {
503 let all_indices = std::mem::take(&mut node.point_indices);
504 let min = node.bounds_min;
505 let max = node.bounds_max;
506 let mut child_indices: [Vec<usize>; 8] = Default::default();
507 let mid = [
508 (min[0] + max[0]) * 0.5,
509 (min[1] + max[1]) * 0.5,
510 (min[2] + max[2]) * 0.5,
511 ];
512 for i in all_indices {
513 let p = &points[i];
514 let ox = if p[0] >= mid[0] { 1 } else { 0 };
515 let oy = if p[1] >= mid[1] { 2 } else { 0 };
516 let oz = if p[2] >= mid[2] { 4 } else { 0 };
517 child_indices[ox | oy | oz].push(i);
518 }
519 let children: [OctreeNode; 8] = std::array::from_fn(|i| {
520 let (cmin, cmax) = octant_bounds(&min, &max, i);
521 OctreeNode {
522 bounds_min: cmin,
523 bounds_max: cmax,
524 point_indices: std::mem::take(&mut child_indices[i]),
525 children: None,
526 }
527 });
528 node.children = Some(Box::new(children));
529 }
530 }
531}
532
533fn collect_ray(
534 node: &OctreeNode,
535 origin: &[f32; 3],
536 inv_dir: &[f32; 3],
537 max_dist: f32,
538 result: &mut Vec<usize>,
539) {
540 if !aabb_overlaps_ray(
541 &node.bounds_min,
542 &node.bounds_max,
543 origin,
544 inv_dir,
545 max_dist,
546 ) {
547 return;
548 }
549 if let Some(children) = &node.children {
550 for child in children.iter() {
551 collect_ray(child, origin, inv_dir, max_dist, result);
552 }
553 } else {
554 for &idx in &node.point_indices {
556 result.push(idx);
557 }
558 }
559}
560
561pub fn ray_query(
563 octree: &Octree,
564 origin: [f32; 3],
565 direction: [f32; 3],
566 max_dist: f32,
567) -> Vec<usize> {
568 let inv_dir = [
569 if direction[0].abs() > 1e-10 {
570 1.0 / direction[0]
571 } else {
572 f32::INFINITY
573 },
574 if direction[1].abs() > 1e-10 {
575 1.0 / direction[1]
576 } else {
577 f32::INFINITY
578 },
579 if direction[2].abs() > 1e-10 {
580 1.0 / direction[2]
581 } else {
582 f32::INFINITY
583 },
584 ];
585 let mut result = Vec::new();
586 collect_ray(&octree.root, &origin, &inv_dir, max_dist, &mut result);
587 result.sort_unstable();
588 result.dedup();
589 result
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// Every integer lattice point of the cube `[0, n)^3`, with `z` varying
    /// fastest and `x` slowest.
    fn grid_points(n: usize) -> Vec<[f32; 3]> {
        (0..n)
            .flat_map(|x| {
                (0..n).flat_map(move |y| (0..n).map(move |z| [x as f32, y as f32, z as f32]))
            })
            .collect()
    }

    #[test]
    fn test_build_empty() {
        let tree = build_octree(&[], 4, 8);
        assert_eq!(octree_point_count(&tree), 0);
    }

    #[test]
    fn test_build_single_point() {
        let pts = vec![[1.0, 2.0, 3.0]];
        let tree = build_octree(&pts, 4, 8);
        assert_eq!(octree_point_count(&tree), 1);
    }

    #[test]
    fn test_query_sphere_finds_nearby() {
        let pts = grid_points(5);
        let tree = build_octree(&pts, 4, 8);
        let hits = query_sphere(&tree, [0.0, 0.0, 0.0], 1.5);
        assert!(!hits.is_empty());
        // Every hit must really lie inside the sphere (small rounding slack).
        for &i in &hits {
            let norm = tree.points[i].iter().map(|c| c * c).sum::<f32>().sqrt();
            assert!(norm <= 1.5 + 1e-4);
        }
    }

    #[test]
    fn test_query_sphere_excludes_far() {
        let pts = vec![[0.0, 0.0, 0.0], [100.0, 100.0, 100.0]];
        let tree = build_octree(&pts, 4, 4);
        assert_eq!(query_sphere(&tree, [0.0, 0.0, 0.0], 1.0), vec![0]);
    }

    #[test]
    fn test_query_aabb() {
        let pts = grid_points(4);
        let tree = build_octree(&pts, 4, 4);
        let hits = query_aabb(&tree, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!(!hits.is_empty());
        for &i in &hits {
            assert!(tree.points[i].iter().all(|&c| (0.0..=1.0).contains(&c)));
        }
    }

    #[test]
    fn test_nearest_neighbor_exact() {
        let pts = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        let tree = build_octree(&pts, 4, 4);
        let (idx, dist) = nearest_neighbor(&tree, [1.0, 0.0, 0.0]).expect("should succeed");
        assert_eq!(idx, 1);
        assert!(dist < 1e-5);
    }

    #[test]
    fn test_nearest_neighbor_empty() {
        let tree = build_octree(&[], 4, 4);
        assert_eq!(nearest_neighbor(&tree, [0.0, 0.0, 0.0]), None);
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let pts = grid_points(4);
        let tree = build_octree(&pts, 4, 4);
        let knn = k_nearest_neighbors(&tree, [1.0, 1.0, 1.0], 3);
        assert_eq!(knn.len(), 3);
        // Distances must be non-decreasing.
        assert!(knn.windows(2).all(|w| w[0].1 <= w[1].1 + 1e-5));
    }

    #[test]
    fn test_octree_depth() {
        let pts = grid_points(3);
        let tree = build_octree(&pts, 4, 2);
        let depth = octree_depth(&tree);
        assert!((1..=4).contains(&depth));
    }

    #[test]
    fn test_octree_leaf_count_positive() {
        let pts = grid_points(3);
        let tree = build_octree(&pts, 3, 4);
        assert!(octree_leaf_count(&tree) >= 1);
    }

    #[test]
    fn test_octree_point_count_matches() {
        let pts = grid_points(3);
        let tree = build_octree(&pts, 4, 4);
        assert_eq!(octree_point_count(&tree), pts.len());
    }

    #[test]
    fn test_octree_stats() {
        let pts = grid_points(3);
        let tree = build_octree(&pts, 4, 4);
        let (depth, leaves, total) = octree_stats(&tree);
        assert_eq!(total, pts.len());
        assert!(leaves >= 1);
        assert!(depth <= 4);
    }

    #[test]
    fn test_insert_point() {
        let pts = vec![[0.0, 0.0, 0.0]];
        let mut tree = build_octree(&pts, 4, 4);
        assert_eq!(insert_point(&mut tree, [5.0, 5.0, 5.0]), 1);
        assert_eq!(octree_point_count(&tree), 2);
    }

    #[test]
    fn test_ray_query() {
        let pts = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 5.0, 0.0]];
        let tree = build_octree(&pts, 4, 4);
        // A ray along +z through the origin must reach at least one leaf.
        let hits = ray_query(&tree, [0.0, 0.0, -1.0], [0.0, 0.0, 1.0], 100.0);
        assert!(!hits.is_empty());
    }

    #[test]
    fn test_k_nearest_zero_k() {
        let pts = grid_points(3);
        let tree = build_octree(&pts, 4, 4);
        assert!(k_nearest_neighbors(&tree, [0.0, 0.0, 0.0], 0).is_empty());
    }
}