Skip to main content

oxihuman_core/
spatial_index.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
// Silences `dead_code` warnings for this whole module.
// NOTE(review): the reason is not visible from this file (possibly public items
// that are unused elsewhere in the crate) — confirm it is still needed, or
// narrow it to the specific items that need it.
#![allow(dead_code)]
5
/// A node in the octree: either a leaf that stores point indices, or an
/// internal node with exactly eight children.
///
/// The build and insert routines in this module leave `point_indices` empty on
/// internal nodes; only leaves (`children == None`) hold points.
#[derive(Debug)]
pub struct OctreeNode {
    /// Lower corner of this node's axis-aligned bounding box.
    pub bounds_min: [f32; 3],
    /// Upper corner of this node's axis-aligned bounding box.
    pub bounds_max: [f32; 3],
    /// Indices into `Octree::points` for points in this leaf node.
    pub point_indices: Vec<usize>,
    /// Eight children if this is an internal node, indexed by octant:
    /// bit 0 set = upper half in x, bit 1 = upper half in y, bit 2 = upper half in z.
    pub children: Option<Box<[OctreeNode; 8]>>,
}
16
/// A spatial octree over a set of 3-D points.
///
/// Points are stored once in `points`; tree nodes refer to them by index.
#[derive(Debug)]
pub struct Octree {
    /// Root node of the tree.
    pub root: OctreeNode,
    /// All points in insertion order; `OctreeNode::point_indices` index into this.
    pub points: Vec<[f32; 3]>,
    /// Depth at which leaves stop splitting (the root is depth 0).
    pub max_depth: u32,
    /// Leaf capacity above which a leaf is split (while `max_depth` allows).
    pub max_points_per_leaf: usize,
}
25
26// ─── Helpers ──────────────────────────────────────────────────────────────────
27
/// Returns `true` when `p` lies inside the closed box `[min, max]` on all three axes.
///
/// Comparisons involving NaN are false, so a NaN coordinate is never contained.
fn aabb_contains(min: &[f32; 3], max: &[f32; 3], p: &[f32; 3]) -> bool {
    (0..3).all(|axis| min[axis] <= p[axis] && p[axis] <= max[axis])
}
36
/// Returns `true` when the closed boxes `[amin, amax]` and `[bmin, bmax]`
/// intersect; boxes that merely touch on a face, edge or corner count as overlapping.
fn aabb_overlaps_aabb(amin: &[f32; 3], amax: &[f32; 3], bmin: &[f32; 3], bmax: &[f32; 3]) -> bool {
    (0..3).all(|axis| amin[axis] <= bmax[axis] && bmin[axis] <= amax[axis])
}
45
/// Returns `true` when the sphere (`center`, `radius`) touches the box `[min, max]`.
///
/// `center` is clamped onto the box to find the closest box point, and that
/// point's squared distance is compared with `radius²`.
///
/// A box with inverted bounds (`min > max`) or NaN bounds on any axis is treated
/// as empty and never overlaps. Inverted bounds are what a min/max scan over an
/// empty point set leaves behind, and passing them to `f32::clamp` would panic.
fn aabb_overlaps_sphere(min: &[f32; 3], max: &[f32; 3], center: &[f32; 3], radius: f32) -> bool {
    let mut dist_sq = 0.0_f32;
    for ((&c, &lo), &hi) in center.iter().zip(min).zip(max) {
        // `f32::clamp` panics on these; an invalid box cannot contain anything anyway.
        if lo.is_nan() || hi.is_nan() || lo > hi {
            return false;
        }
        let d = c - c.clamp(lo, hi);
        dist_sq += d * d;
    }
    dist_sq <= radius * radius
}
55
/// Slab test: does the ray `origin + t * dir` (given as `inv_dir = 1 / dir`)
/// pass through the box `[min, max]` for some `t` in `[0, max_dist]`?
///
/// A non-finite `inv_dir` component marks an axis the ray does not move along;
/// in that case the ray can only hit the box if `origin` already lies within
/// that axis' slab.
fn aabb_overlaps_ray(
    min: &[f32; 3],
    max: &[f32; 3],
    origin: &[f32; 3],
    inv_dir: &[f32; 3],
    max_dist: f32,
) -> bool {
    // Parametric interval of the ray that is still inside every slab seen so far.
    let mut enter = 0.0_f32;
    let mut exit = max_dist;
    for axis in 0..3 {
        let (lo, hi) = (min[axis], max[axis]);
        let (o, inv) = (origin[axis], inv_dir[axis]);
        if !inv.is_finite() {
            // Parallel to this slab: the ray must start inside it.
            if o < lo || o > hi {
                return false;
            }
            continue;
        }
        let t_lo = (lo - o) * inv;
        let t_hi = (hi - o) * inv;
        enter = enter.max(t_lo.min(t_hi));
        exit = exit.min(t_lo.max(t_hi));
    }
    enter <= exit
}
82
/// Squared Euclidean distance between `a` and `b`; no square root is taken.
fn dist_sq(a: &[f32; 3], b: &[f32; 3]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(p, q)| {
            let delta = p - q;
            delta * delta
        })
        .fold(0.0_f32, |acc, sq| acc + sq)
}
89
/// Returns the `(min, max)` corners of child `octant` of the box
/// `[parent_min, parent_max]`.
///
/// Bit 0 of `octant` selects the upper half along x, bit 1 along y and bit 2
/// along z; a cleared bit selects the lower half. The split is at the midpoint.
fn octant_bounds(
    parent_min: &[f32; 3],
    parent_max: &[f32; 3],
    octant: usize,
) -> ([f32; 3], [f32; 3]) {
    let mut lo = *parent_min;
    let mut hi = *parent_max;
    for axis in 0..3 {
        let mid = (parent_min[axis] + parent_max[axis]) * 0.5;
        if octant & (1 << axis) == 0 {
            hi[axis] = mid;
        } else {
            lo[axis] = mid;
        }
    }
    (lo, hi)
}
132
133// ─── Recursive build ─────────────────────────────────────────────────────────
134
135fn build_node(
136    indices: Vec<usize>,
137    points: &[[f32; 3]],
138    min: [f32; 3],
139    max: [f32; 3],
140    depth: u32,
141    max_depth: u32,
142    max_per_leaf: usize,
143) -> OctreeNode {
144    if depth >= max_depth || indices.len() <= max_per_leaf {
145        return OctreeNode {
146            bounds_min: min,
147            bounds_max: max,
148            point_indices: indices,
149            children: None,
150        };
151    }
152
153    let mut child_indices: [Vec<usize>; 8] = Default::default();
154    let mid = [
155        (min[0] + max[0]) * 0.5,
156        (min[1] + max[1]) * 0.5,
157        (min[2] + max[2]) * 0.5,
158    ];
159
160    for idx in &indices {
161        let p = &points[*idx];
162        let ox = if p[0] >= mid[0] { 1 } else { 0 };
163        let oy = if p[1] >= mid[1] { 2 } else { 0 };
164        let oz = if p[2] >= mid[2] { 4 } else { 0 };
165        child_indices[ox | oy | oz].push(*idx);
166    }
167
168    let children: [OctreeNode; 8] = std::array::from_fn(|i| {
169        let (cmin, cmax) = octant_bounds(&min, &max, i);
170        build_node(
171            std::mem::take(&mut child_indices[i]),
172            points,
173            cmin,
174            cmax,
175            depth + 1,
176            max_depth,
177            max_per_leaf,
178        )
179    });
180
181    OctreeNode {
182        bounds_min: min,
183        bounds_max: max,
184        point_indices: Vec::new(), // internal nodes hold no points directly
185        children: Some(Box::new(children)),
186    }
187}
188
189// ─── Public build ─────────────────────────────────────────────────────────────
190
191/// Build an octree from a set of 3-D points.
192pub fn build_octree(points: &[[f32; 3]], max_depth: u32, max_per_leaf: usize) -> Octree {
193    let mut min = [f32::MAX; 3];
194    let mut max = [f32::MIN; 3];
195    for p in points {
196        for i in 0..3 {
197            if p[i] < min[i] {
198                min[i] = p[i];
199            }
200            if p[i] > max[i] {
201                max[i] = p[i];
202            }
203        }
204    }
205    // Add small epsilon to avoid degenerate bounds
206    for i in 0..3 {
207        max[i] += 1e-4;
208        min[i] -= 1e-4;
209    }
210
211    let indices: Vec<usize> = (0..points.len()).collect();
212    let root = build_node(indices, points, min, max, 0, max_depth, max_per_leaf);
213    Octree {
214        root,
215        points: points.to_vec(),
216        max_depth,
217        max_points_per_leaf: max_per_leaf,
218    }
219}
220
221// ─── Queries ─────────────────────────────────────────────────────────────────
222
223fn collect_sphere(
224    node: &OctreeNode,
225    center: &[f32; 3],
226    radius: f32,
227    result: &mut Vec<usize>,
228    points: &[[f32; 3]],
229) {
230    if !aabb_overlaps_sphere(&node.bounds_min, &node.bounds_max, center, radius) {
231        return;
232    }
233    if let Some(children) = &node.children {
234        for child in children.iter() {
235            collect_sphere(child, center, radius, result, points);
236        }
237    } else {
238        let r2 = radius * radius;
239        for &idx in &node.point_indices {
240            if dist_sq(&points[idx], center) <= r2 {
241                result.push(idx);
242            }
243        }
244    }
245}
246
247/// Return indices of all points within `radius` of `center`.
248pub fn query_sphere(octree: &Octree, center: [f32; 3], radius: f32) -> Vec<usize> {
249    let mut result = Vec::new();
250    collect_sphere(&octree.root, &center, radius, &mut result, &octree.points);
251    result
252}
253
254fn collect_aabb(
255    node: &OctreeNode,
256    min: &[f32; 3],
257    max: &[f32; 3],
258    result: &mut Vec<usize>,
259    points: &[[f32; 3]],
260) {
261    if !aabb_overlaps_aabb(&node.bounds_min, &node.bounds_max, min, max) {
262        return;
263    }
264    if let Some(children) = &node.children {
265        for child in children.iter() {
266            collect_aabb(child, min, max, result, points);
267        }
268    } else {
269        for &idx in &node.point_indices {
270            if aabb_contains(min, max, &points[idx]) {
271                result.push(idx);
272            }
273        }
274    }
275}
276
277/// Return indices of all points within the axis-aligned bounding box [min, max].
278pub fn query_aabb(octree: &Octree, min: [f32; 3], max: [f32; 3]) -> Vec<usize> {
279    let mut result = Vec::new();
280    collect_aabb(&octree.root, &min, &max, &mut result, &octree.points);
281    result
282}
283
284fn collect_nn(
285    node: &OctreeNode,
286    query: &[f32; 3],
287    best_dist_sq: &mut f32,
288    best_idx: &mut Option<usize>,
289    points: &[[f32; 3]],
290) {
291    // Prune if closest possible point in this node is farther than current best
292    let mut node_dist_sq = 0.0_f32;
293    for ((&q, &bmin), &bmax) in query
294        .iter()
295        .zip(node.bounds_min.iter())
296        .zip(node.bounds_max.iter())
297    {
298        if bmin > bmax {
299            return;
300        } // inverted/empty bounds
301        let v = q.clamp(bmin, bmax);
302        let d = q - v;
303        node_dist_sq += d * d;
304    }
305    if node_dist_sq >= *best_dist_sq {
306        return;
307    }
308
309    if let Some(children) = &node.children {
310        for child in children.iter() {
311            collect_nn(child, query, best_dist_sq, best_idx, points);
312        }
313    } else {
314        for &idx in &node.point_indices {
315            let d2 = dist_sq(&points[idx], query);
316            if d2 < *best_dist_sq {
317                *best_dist_sq = d2;
318                *best_idx = Some(idx);
319            }
320        }
321    }
322}
323
324/// Find the nearest neighbour to `query`. Returns (index, distance).
325pub fn nearest_neighbor(octree: &Octree, query: [f32; 3]) -> Option<(usize, f32)> {
326    let mut best_dist_sq = f32::MAX;
327    let mut best_idx = None;
328    collect_nn(
329        &octree.root,
330        &query,
331        &mut best_dist_sq,
332        &mut best_idx,
333        &octree.points,
334    );
335    best_idx.map(|idx| (idx, best_dist_sq.sqrt()))
336}
337
338fn collect_knn(
339    node: &OctreeNode,
340    query: &[f32; 3],
341    heap: &mut Vec<(f32, usize)>, // (dist_sq, idx), max-heap by dist_sq
342    k: usize,
343    points: &[[f32; 3]],
344) {
345    let worst = if heap.len() == k { heap[0].0 } else { f32::MAX };
346    let mut node_dist_sq = 0.0_f32;
347    for ((&q, &bmin), &bmax) in query
348        .iter()
349        .zip(node.bounds_min.iter())
350        .zip(node.bounds_max.iter())
351    {
352        if bmin > bmax {
353            return;
354        } // inverted/empty bounds
355        let v = q.clamp(bmin, bmax);
356        let d = q - v;
357        node_dist_sq += d * d;
358    }
359    if node_dist_sq >= worst {
360        return;
361    }
362
363    if let Some(children) = &node.children {
364        for child in children.iter() {
365            collect_knn(child, query, heap, k, points);
366        }
367    } else {
368        for &idx in &node.point_indices {
369            let d2 = dist_sq(&points[idx], query);
370            let cur_worst = if heap.len() == k { heap[0].0 } else { f32::MAX };
371            if d2 < cur_worst {
372                if heap.len() == k {
373                    // Remove worst
374                    heap.remove(0);
375                }
376                // Insert maintaining max-heap property (simple insertion sort for small k)
377                let pos = heap.partition_point(|&(d, _)| d < d2);
378                heap.insert(pos, (d2, idx));
379            }
380        }
381    }
382}
383
384/// Return the k nearest neighbours sorted by ascending distance.
385pub fn k_nearest_neighbors(octree: &Octree, query: [f32; 3], k: usize) -> Vec<(usize, f32)> {
386    if k == 0 {
387        return Vec::new();
388    }
389    let mut heap: Vec<(f32, usize)> = Vec::with_capacity(k + 1);
390    collect_knn(&octree.root, &query, &mut heap, k, &octree.points);
391    heap.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
392    heap.into_iter().map(|(d2, idx)| (idx, d2.sqrt())).collect()
393}
394
395fn node_depth(node: &OctreeNode) -> u32 {
396    match &node.children {
397        None => 0,
398        Some(children) => 1 + children.iter().map(node_depth).max().unwrap_or(0),
399    }
400}
401
402/// Actual depth of the built octree (0 = single leaf).
403pub fn octree_depth(octree: &Octree) -> u32 {
404    node_depth(&octree.root)
405}
406
407fn count_leaves(node: &OctreeNode) -> usize {
408    match &node.children {
409        None => 1,
410        Some(children) => children.iter().map(count_leaves).sum(),
411    }
412}
413
414pub fn octree_leaf_count(octree: &Octree) -> usize {
415    count_leaves(&octree.root)
416}
417
418fn count_points(node: &OctreeNode) -> usize {
419    match &node.children {
420        None => node.point_indices.len(),
421        Some(children) => children.iter().map(count_points).sum(),
422    }
423}
424
425pub fn octree_point_count(octree: &Octree) -> usize {
426    count_points(&octree.root)
427}
428
429/// Return (depth, leaf_count, total_points).
430pub fn octree_stats(octree: &Octree) -> (u32, usize, usize) {
431    (
432        octree_depth(octree),
433        octree_leaf_count(octree),
434        octree_point_count(octree),
435    )
436}
437
438/// Insert a new point into the octree. Returns the new point's index.
439/// Note: rebuilds the affected leaf; simple but correct.
440pub fn insert_point(octree: &mut Octree, point: [f32; 3]) -> usize {
441    let idx = octree.points.len();
442    octree.points.push(point);
443    insert_into_node(
444        &mut octree.root,
445        point,
446        idx,
447        0,
448        octree.max_depth,
449        octree.max_points_per_leaf,
450        &octree.points.clone(),
451    );
452    idx
453}
454
455fn insert_into_node(
456    node: &mut OctreeNode,
457    point: [f32; 3],
458    idx: usize,
459    depth: u32,
460    max_depth: u32,
461    max_per_leaf: usize,
462    points: &[[f32; 3]],
463) {
464    if !aabb_contains(&node.bounds_min, &node.bounds_max, &point) {
465        // Expand bounds to include point
466        for ((&pv, bmin), bmax) in point
467            .iter()
468            .zip(node.bounds_min.iter_mut())
469            .zip(node.bounds_max.iter_mut())
470        {
471            if pv < *bmin {
472                *bmin = pv - 1e-4;
473            }
474            if pv > *bmax {
475                *bmax = pv + 1e-4;
476            }
477        }
478    }
479
480    if let Some(children) = &mut node.children {
481        // Find the right child
482        let mid = [
483            (node.bounds_min[0] + node.bounds_max[0]) * 0.5,
484            (node.bounds_min[1] + node.bounds_max[1]) * 0.5,
485            (node.bounds_min[2] + node.bounds_max[2]) * 0.5,
486        ];
487        let ox = if point[0] >= mid[0] { 1 } else { 0 };
488        let oy = if point[1] >= mid[1] { 2 } else { 0 };
489        let oz = if point[2] >= mid[2] { 4 } else { 0 };
490        insert_into_node(
491            &mut children[ox | oy | oz],
492            point,
493            idx,
494            depth + 1,
495            max_depth,
496            max_per_leaf,
497            points,
498        );
499    } else {
500        node.point_indices.push(idx);
501        // Split if over capacity and not at max depth
502        if node.point_indices.len() > max_per_leaf && depth < max_depth {
503            let all_indices = std::mem::take(&mut node.point_indices);
504            let min = node.bounds_min;
505            let max = node.bounds_max;
506            let mut child_indices: [Vec<usize>; 8] = Default::default();
507            let mid = [
508                (min[0] + max[0]) * 0.5,
509                (min[1] + max[1]) * 0.5,
510                (min[2] + max[2]) * 0.5,
511            ];
512            for i in all_indices {
513                let p = &points[i];
514                let ox = if p[0] >= mid[0] { 1 } else { 0 };
515                let oy = if p[1] >= mid[1] { 2 } else { 0 };
516                let oz = if p[2] >= mid[2] { 4 } else { 0 };
517                child_indices[ox | oy | oz].push(i);
518            }
519            let children: [OctreeNode; 8] = std::array::from_fn(|i| {
520                let (cmin, cmax) = octant_bounds(&min, &max, i);
521                OctreeNode {
522                    bounds_min: cmin,
523                    bounds_max: cmax,
524                    point_indices: std::mem::take(&mut child_indices[i]),
525                    children: None,
526                }
527            });
528            node.children = Some(Box::new(children));
529        }
530    }
531}
532
533fn collect_ray(
534    node: &OctreeNode,
535    origin: &[f32; 3],
536    inv_dir: &[f32; 3],
537    max_dist: f32,
538    result: &mut Vec<usize>,
539) {
540    if !aabb_overlaps_ray(
541        &node.bounds_min,
542        &node.bounds_max,
543        origin,
544        inv_dir,
545        max_dist,
546    ) {
547        return;
548    }
549    if let Some(children) = &node.children {
550        for child in children.iter() {
551            collect_ray(child, origin, inv_dir, max_dist, result);
552        }
553    } else {
554        // Return all points in leaf intersected by ray's AABB region
555        for &idx in &node.point_indices {
556            result.push(idx);
557        }
558    }
559}
560
561/// Return indices of points in leaf nodes intersected by the ray.
562pub fn ray_query(
563    octree: &Octree,
564    origin: [f32; 3],
565    direction: [f32; 3],
566    max_dist: f32,
567) -> Vec<usize> {
568    let inv_dir = [
569        if direction[0].abs() > 1e-10 {
570            1.0 / direction[0]
571        } else {
572            f32::INFINITY
573        },
574        if direction[1].abs() > 1e-10 {
575            1.0 / direction[1]
576        } else {
577            f32::INFINITY
578        },
579        if direction[2].abs() > 1e-10 {
580            1.0 / direction[2]
581        } else {
582            f32::INFINITY
583        },
584    ];
585    let mut result = Vec::new();
586    collect_ray(&octree.root, &origin, &inv_dir, max_dist, &mut result);
587    result.sort_unstable();
588    result.dedup();
589    result
590}
591
592// ─── Tests ────────────────────────────────────────────────────────────────────
593
#[cfg(test)]
mod tests {
    use super::*;

    /// All integer lattice points of the cube `[0, n)^3`, with x varying slowest.
    fn grid_points(n: usize) -> Vec<[f32; 3]> {
        (0..n)
            .flat_map(|x| {
                (0..n).flat_map(move |y| (0..n).map(move |z| [x as f32, y as f32, z as f32]))
            })
            .collect()
    }

    #[test]
    fn test_build_empty() {
        let tree = build_octree(&[], 4, 8);
        assert_eq!(octree_point_count(&tree), 0);
    }

    #[test]
    fn test_build_single_point() {
        let tree = build_octree(&[[1.0, 2.0, 3.0]], 4, 8);
        assert_eq!(octree_point_count(&tree), 1);
    }

    #[test]
    fn test_query_sphere_finds_nearby() {
        let tree = build_octree(&grid_points(5), 4, 8);
        let hits = query_sphere(&tree, [0.0, 0.0, 0.0], 1.5);
        assert!(!hits.is_empty());
        // Every hit must actually lie inside the sphere.
        for &i in &hits {
            let [x, y, z] = tree.points[i];
            assert!((x * x + y * y + z * z).sqrt() <= 1.5 + 1e-4);
        }
    }

    #[test]
    fn test_query_sphere_excludes_far() {
        let tree = build_octree(&[[0.0, 0.0, 0.0], [100.0, 100.0, 100.0]], 4, 4);
        assert_eq!(query_sphere(&tree, [0.0, 0.0, 0.0], 1.0), vec![0]);
    }

    #[test]
    fn test_query_aabb() {
        let tree = build_octree(&grid_points(4), 4, 4);
        let hits = query_aabb(&tree, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!(!hits.is_empty());
        for &i in &hits {
            assert!(tree.points[i].iter().all(|&c| (0.0..=1.0).contains(&c)));
        }
    }

    #[test]
    fn test_nearest_neighbor_exact() {
        let tree = build_octree(&[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]], 4, 4);
        let (idx, dist) = nearest_neighbor(&tree, [1.0, 0.0, 0.0]).expect("should succeed");
        assert_eq!(idx, 1);
        assert!(dist < 1e-5);
    }

    #[test]
    fn test_nearest_neighbor_empty() {
        let tree = build_octree(&[], 4, 4);
        assert!(nearest_neighbor(&tree, [0.0, 0.0, 0.0]).is_none());
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let tree = build_octree(&grid_points(4), 4, 4);
        let knn = k_nearest_neighbors(&tree, [1.0, 1.0, 1.0], 3);
        assert_eq!(knn.len(), 3);
        // Distances must be non-decreasing.
        for pair in knn.windows(2) {
            assert!(pair[0].1 <= pair[1].1 + 1e-5);
        }
    }

    #[test]
    fn test_octree_depth() {
        let depth = octree_depth(&build_octree(&grid_points(3), 4, 2));
        assert!(depth > 0);
        assert!(depth <= 4);
    }

    #[test]
    fn test_octree_leaf_count_positive() {
        let tree = build_octree(&grid_points(3), 3, 4);
        assert!(octree_leaf_count(&tree) >= 1);
    }

    #[test]
    fn test_octree_point_count_matches() {
        let pts = grid_points(3);
        let tree = build_octree(&pts, 4, 4);
        assert_eq!(octree_point_count(&tree), pts.len());
    }

    #[test]
    fn test_octree_stats() {
        let pts = grid_points(3);
        let (depth, leaves, total) = octree_stats(&build_octree(&pts, 4, 4));
        assert_eq!(total, pts.len());
        assert!(leaves >= 1);
        assert!(depth <= 4);
    }

    #[test]
    fn test_insert_point() {
        let mut tree = build_octree(&[[0.0, 0.0, 0.0]], 4, 4);
        assert_eq!(insert_point(&mut tree, [5.0, 5.0, 5.0]), 1);
        assert_eq!(octree_point_count(&tree), 2);
    }

    #[test]
    fn test_ray_query() {
        let pts: [[f32; 3]; 3] = [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 5.0, 0.0], // far off axis
        ];
        let tree = build_octree(&pts, 4, 4);
        let hits = ray_query(&tree, [0.0, 0.0, -1.0], [0.0, 0.0, 1.0], 100.0);
        // The leaf holding the point at the origin is crossed by the ray.
        assert!(!hits.is_empty());
    }

    #[test]
    fn test_k_nearest_zero_k() {
        let tree = build_octree(&grid_points(3), 4, 4);
        assert!(k_nearest_neighbors(&tree, [0.0, 0.0, 0.0], 0).is_empty());
    }
}