// proof_engine/spatial/mod.rs

//! Spatial acceleration structures for fast proximity queries.
//!
//! # Structures
//!
//! - `SpatialGrid`     — uniform 3D hash grid (O(1) insert, O(k + m) radius query)
//! - `SpatialGrid2D`   — 2D version for screen-space queries
//! - `Aabb`            — axis-aligned bounding box with overlap and ray tests
//! - `Bvh` / `BvhNode` — bounding volume hierarchy for static geometry
//! - `KdTree`          — k-d tree for nearest-neighbor queries
//! - `Frustum`         — camera frustum for view culling
//! - `SpatialIndex`    — common trait interface for spatial structures
//!
//! Used for:
//! - Fast glyph proximity (cohesion, repulsion forces)
//! - Collision detection between entities
//! - Field influence queries
//! - Particle flocking neighbor search

use glam::{Vec2, Vec3};
use std::collections::HashMap;

20// ── SpatialIndex trait ────────────────────────────────────────────────────────
21
22/// Common interface for spatial acceleration structures.
23pub trait SpatialIndex<T: Clone> {
24    /// Insert an item at the given position.
25    fn insert(&mut self, pos: Vec3, item: T);
26
27    /// Query all items within `radius` of `center`.
28    fn query_radius(&self, center: Vec3, radius: f32) -> Vec<(T, Vec3, f32)>;
29
30    /// Query the `k` nearest items to `center`.
31    fn k_nearest(&self, center: Vec3, k: usize) -> Vec<(T, Vec3, f32)>;
32
33    /// Remove all items.
34    fn clear(&mut self);
35
36    /// Total number of stored items.
37    fn len(&self) -> usize;
38
39    fn is_empty(&self) -> bool { self.len() == 0 }
40}
41
42// ── SpatialGrid ───────────────────────────────────────────────────────────────
43
44/// A uniform 3D spatial hash grid.
45///
46/// Space is divided into cubic cells of `cell_size`. Items are stored in
47/// buckets by their cell coordinate. Queries scan all cells overlapping
48/// the query sphere.
49///
50/// # Complexity
51///
52/// - Insert: O(1) average
53/// - Range query: O(k + m) where k = items in range, m = cells overlapping sphere
54/// - Rebuild: O(n)
55pub struct SpatialGrid<T: Clone> {
56    /// Cell size (world units per cell edge).
57    pub cell_size: f32,
58    /// Grid cells: (cx, cy, cz) → [(position, item)].
59    cells: HashMap<(i32, i32, i32), Vec<(Vec3, T)>>,
60    item_count: usize,
61}
62
63impl<T: Clone> SpatialGrid<T> {
64    pub fn new(cell_size: f32) -> Self {
65        Self {
66            cell_size: cell_size.max(0.001),
67            cells: HashMap::new(),
68            item_count: 0,
69        }
70    }
71
72    fn cell_key(&self, pos: Vec3) -> (i32, i32, i32) {
73        (
74            (pos.x / self.cell_size).floor() as i32,
75            (pos.y / self.cell_size).floor() as i32,
76            (pos.z / self.cell_size).floor() as i32,
77        )
78    }
79
80    pub fn insert(&mut self, pos: Vec3, item: T) {
81        let key = self.cell_key(pos);
82        self.cells.entry(key).or_default().push((pos, item));
83        self.item_count += 1;
84    }
85
86    /// Query all items within `radius` of `center`, returning (item, position, distance).
87    pub fn query_radius(&self, center: Vec3, radius: f32) -> Vec<(T, Vec3, f32)> {
88        let r2 = radius * radius;
89        let cell_r = (radius / self.cell_size).ceil() as i32;
90        let (cx, cy, cz) = self.cell_key(center);
91
92        let mut results = Vec::new();
93        for dx in -cell_r..=cell_r {
94            for dy in -cell_r..=cell_r {
95                for dz in -cell_r..=cell_r {
96                    let key = (cx + dx, cy + dy, cz + dz);
97                    if let Some(bucket) = self.cells.get(&key) {
98                        for (pos, item) in bucket {
99                            let d2 = (*pos - center).length_squared();
100                            if d2 <= r2 {
101                                results.push((item.clone(), *pos, d2.sqrt()));
102                            }
103                        }
104                    }
105                }
106            }
107        }
108        results
109    }
110
111    /// Find the k nearest items to `center`.
112    pub fn k_nearest(&self, center: Vec3, k: usize) -> Vec<(T, Vec3, f32)> {
113        // Start with small radius and expand until we have k results
114        let mut radius = self.cell_size;
115        let mut results;
116        loop {
117            results = self.query_radius(center, radius);
118            if results.len() >= k || radius > 1000.0 { break; }
119            radius *= 2.0;
120        }
121        results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
122        results.truncate(k);
123        results
124    }
125
126    /// Find the single nearest item.
127    pub fn nearest(&self, center: Vec3) -> Option<(T, Vec3, f32)> {
128        self.k_nearest(center, 1).into_iter().next()
129    }
130
131    pub fn clear(&mut self) {
132        self.cells.clear();
133        self.item_count = 0;
134    }
135
136    pub fn len(&self) -> usize { self.item_count }
137    pub fn is_empty(&self) -> bool { self.item_count == 0 }
138    pub fn bucket_count(&self) -> usize { self.cells.len() }
139
140    /// Average items per occupied bucket (load factor).
141    pub fn avg_bucket_load(&self) -> f32 {
142        if self.cells.is_empty() { return 0.0; }
143        self.item_count as f32 / self.cells.len() as f32
144    }
145
146    /// Rebuild from an iterator of (position, item) pairs.
147    pub fn rebuild(&mut self, items: impl Iterator<Item = (Vec3, T)>) {
148        self.clear();
149        for (pos, item) in items {
150            self.insert(pos, item);
151        }
152    }
153
154    /// Iterate over all items.
155    pub fn iter(&self) -> impl Iterator<Item = (&Vec3, &T)> {
156        self.cells.values().flat_map(|bucket| bucket.iter().map(|(p, t)| (p, t)))
157    }
158}
159
160// ── SpatialGrid2D ─────────────────────────────────────────────────────────────
161
162/// A uniform 2D spatial hash grid for screen-space queries.
163pub struct SpatialGrid2D<T: Clone> {
164    pub cell_size:  f32,
165    cells: HashMap<(i32, i32), Vec<(Vec2, T)>>,
166    item_count: usize,
167}
168
169impl<T: Clone> SpatialGrid2D<T> {
170    pub fn new(cell_size: f32) -> Self {
171        Self { cell_size: cell_size.max(0.001), cells: HashMap::new(), item_count: 0 }
172    }
173
174    fn cell_key(&self, pos: Vec2) -> (i32, i32) {
175        (
176            (pos.x / self.cell_size).floor() as i32,
177            (pos.y / self.cell_size).floor() as i32,
178        )
179    }
180
181    pub fn insert(&mut self, pos: Vec2, item: T) {
182        let key = self.cell_key(pos);
183        self.cells.entry(key).or_default().push((pos, item));
184        self.item_count += 1;
185    }
186
187    pub fn query_radius(&self, center: Vec2, radius: f32) -> Vec<(T, Vec2, f32)> {
188        let r2 = radius * radius;
189        let cell_r = (radius / self.cell_size).ceil() as i32;
190        let (cx, cy) = self.cell_key(center);
191        let mut results = Vec::new();
192        for dx in -cell_r..=cell_r {
193            for dy in -cell_r..=cell_r {
194                if let Some(bucket) = self.cells.get(&(cx + dx, cy + dy)) {
195                    for (pos, item) in bucket {
196                        let d2 = (*pos - center).length_squared();
197                        if d2 <= r2 {
198                            results.push((item.clone(), *pos, d2.sqrt()));
199                        }
200                    }
201                }
202            }
203        }
204        results
205    }
206
207    pub fn k_nearest(&self, center: Vec2, k: usize) -> Vec<(T, Vec2, f32)> {
208        let mut radius = self.cell_size;
209        let mut results;
210        loop {
211            results = self.query_radius(center, radius);
212            if results.len() >= k || radius > 10000.0 { break; }
213            radius *= 2.0;
214        }
215        results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
216        results.truncate(k);
217        results
218    }
219
220    pub fn clear(&mut self) { self.cells.clear(); self.item_count = 0; }
221    pub fn len(&self) -> usize { self.item_count }
222    pub fn is_empty(&self) -> bool { self.item_count == 0 }
223}
224
225// ── AABB ──────────────────────────────────────────────────────────────────────
226
227/// Axis-aligned bounding box in 3D.
228#[derive(Debug, Clone, Copy, PartialEq)]
229pub struct Aabb {
230    pub min: Vec3,
231    pub max: Vec3,
232}
233
234impl Aabb {
235    pub fn new(min: Vec3, max: Vec3) -> Self { Self { min, max } }
236
237    pub fn from_center_half_extents(center: Vec3, half: Vec3) -> Self {
238        Self { min: center - half, max: center + half }
239    }
240
241    pub fn from_sphere(center: Vec3, radius: f32) -> Self {
242        let r = Vec3::splat(radius);
243        Self { min: center - r, max: center + r }
244    }
245
246    pub fn center(&self) -> Vec3 { (self.min + self.max) * 0.5 }
247    pub fn half_extents(&self) -> Vec3 { (self.max - self.min) * 0.5 }
248    pub fn size(&self) -> Vec3 { self.max - self.min }
249
250    pub fn contains_point(&self, p: Vec3) -> bool {
251        p.x >= self.min.x && p.x <= self.max.x
252            && p.y >= self.min.y && p.y <= self.max.y
253            && p.z >= self.min.z && p.z <= self.max.z
254    }
255
256    pub fn intersects(&self, other: &Aabb) -> bool {
257        self.min.x <= other.max.x && self.max.x >= other.min.x
258            && self.min.y <= other.max.y && self.max.y >= other.min.y
259            && self.min.z <= other.max.z && self.max.z >= other.min.z
260    }
261
262    pub fn intersects_sphere(&self, center: Vec3, radius: f32) -> bool {
263        let closest = center.clamp(self.min, self.max);
264        (closest - center).length_squared() <= radius * radius
265    }
266
267    /// Expand to include another AABB.
268    pub fn union(&self, other: &Aabb) -> Aabb {
269        Aabb {
270            min: self.min.min(other.min),
271            max: self.max.max(other.max),
272        }
273    }
274
275    /// Expand by `amount` in all directions.
276    pub fn expand(&self, amount: f32) -> Aabb {
277        let e = Vec3::splat(amount);
278        Aabb { min: self.min - e, max: self.max + e }
279    }
280
281    pub fn surface_area(&self) -> f32 {
282        let s = self.size();
283        2.0 * (s.x * s.y + s.y * s.z + s.z * s.x)
284    }
285
286    /// Ray-AABB intersection. Returns (t_min, t_max) if hit, None otherwise.
287    pub fn ray_intersect(&self, origin: Vec3, dir: Vec3) -> Option<(f32, f32)> {
288        let inv_dir = Vec3::new(
289            if dir.x != 0.0 { 1.0 / dir.x } else { f32::INFINITY },
290            if dir.y != 0.0 { 1.0 / dir.y } else { f32::INFINITY },
291            if dir.z != 0.0 { 1.0 / dir.z } else { f32::INFINITY },
292        );
293        let t1 = (self.min - origin) * inv_dir;
294        let t2 = (self.max - origin) * inv_dir;
295        let t_min = t1.min(t2).max_element();
296        let t_max = t1.max(t2).min_element();
297        if t_max >= t_min && t_max >= 0.0 { Some((t_min, t_max)) } else { None }
298    }
299}
300
301// ── BvhNode ───────────────────────────────────────────────────────────────────
302
303/// A node in a Bounding Volume Hierarchy.
304///
305/// BVH provides O(log n) ray queries and O(log n) sphere overlap queries.
306/// Build using `Bvh::build(items)` where each item is an `(Aabb, T)` pair.
307#[derive(Debug, Clone)]
308pub enum BvhNode<T: Clone> {
309    Leaf {
310        bounds: Aabb,
311        item:   T,
312    },
313    Branch {
314        bounds: Aabb,
315        left:   Box<BvhNode<T>>,
316        right:  Box<BvhNode<T>>,
317    },
318}
319
320impl<T: Clone> BvhNode<T> {
321    pub fn bounds(&self) -> Aabb {
322        match self {
323            BvhNode::Leaf { bounds, .. } => *bounds,
324            BvhNode::Branch { bounds, .. } => *bounds,
325        }
326    }
327
328    /// Query all items whose AABB overlaps `sphere(center, radius)`.
329    pub fn query_sphere(&self, center: Vec3, radius: f32, out: &mut Vec<T>) {
330        if !self.bounds().intersects_sphere(center, radius) { return; }
331        match self {
332            BvhNode::Leaf { item, .. } => out.push(item.clone()),
333            BvhNode::Branch { left, right, .. } => {
334                left.query_sphere(center, radius, out);
335                right.query_sphere(center, radius, out);
336            }
337        }
338    }
339
340    /// Ray query — returns all items whose AABB the ray intersects.
341    pub fn query_ray(&self, origin: Vec3, dir: Vec3, max_t: f32, out: &mut Vec<(T, f32)>) {
342        match self.bounds().ray_intersect(origin, dir) {
343            None => return,
344            Some((t_min, _)) if t_min > max_t => return,
345            _ => {}
346        }
347        match self {
348            BvhNode::Leaf { item, bounds } => {
349                if let Some((t, _)) = bounds.ray_intersect(origin, dir) {
350                    out.push((item.clone(), t));
351                }
352            }
353            BvhNode::Branch { left, right, .. } => {
354                left.query_ray(origin, dir, max_t, out);
355                right.query_ray(origin, dir, max_t, out);
356            }
357        }
358    }
359
360    /// AABB overlap query.
361    pub fn query_aabb(&self, query: &Aabb, out: &mut Vec<T>) {
362        if !self.bounds().intersects(query) { return; }
363        match self {
364            BvhNode::Leaf { item, bounds } => {
365                if bounds.intersects(query) { out.push(item.clone()); }
366            }
367            BvhNode::Branch { left, right, .. } => {
368                left.query_aabb(query, out);
369                right.query_aabb(query, out);
370            }
371        }
372    }
373}
374
375/// A complete BVH tree.
376pub struct Bvh<T: Clone> {
377    root: Option<BvhNode<T>>,
378    pub item_count: usize,
379}
380
381impl<T: Clone> Bvh<T> {
382    /// Build a BVH from a list of (bounds, item) pairs using SAH heuristic.
383    pub fn build(items: Vec<(Aabb, T)>) -> Self {
384        let count = items.len();
385        let root = if items.is_empty() { None } else { Some(Self::build_recursive(items)) };
386        Self { root, item_count: count }
387    }
388
389    fn build_recursive(mut items: Vec<(Aabb, T)>) -> BvhNode<T> {
390        if items.len() == 1 {
391            let (bounds, item) = items.remove(0);
392            return BvhNode::Leaf { bounds, item };
393        }
394
395        // Compute combined bounds
396        let bounds = items.iter()
397            .map(|(b, _)| *b)
398            .reduce(|a, b| a.union(&b))
399            .unwrap();
400
401        // Split along the longest axis at the centroid median
402        let size = bounds.size();
403        let axis = if size.x >= size.y && size.x >= size.z { 0 }
404                   else if size.y >= size.z { 1 }
405                   else { 2 };
406
407        let centroid = |b: &Aabb| -> f32 {
408            match axis { 0 => b.center().x, 1 => b.center().y, _ => b.center().z }
409        };
410
411        items.sort_by(|(a, _), (b, _)| centroid(a).partial_cmp(&centroid(b))
412            .unwrap_or(std::cmp::Ordering::Equal));
413
414        let mid = items.len() / 2;
415        let right_items = items.split_off(mid);
416        let left_items = items;
417
418        let left = Box::new(Self::build_recursive(left_items));
419        let right = Box::new(Self::build_recursive(right_items));
420
421        BvhNode::Branch { bounds, left, right }
422    }
423
424    pub fn query_sphere(&self, center: Vec3, radius: f32) -> Vec<T> {
425        let mut out = Vec::new();
426        if let Some(root) = &self.root {
427            root.query_sphere(center, radius, &mut out);
428        }
429        out
430    }
431
432    pub fn query_ray(&self, origin: Vec3, dir: Vec3, max_t: f32) -> Vec<(T, f32)> {
433        let mut out = Vec::new();
434        if let Some(root) = &self.root {
435            root.query_ray(origin, dir.normalize_or_zero(), max_t, &mut out);
436        }
437        out.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
438        out
439    }
440
441    pub fn query_aabb(&self, bounds: &Aabb) -> Vec<T> {
442        let mut out = Vec::new();
443        if let Some(root) = &self.root {
444            root.query_aabb(bounds, &mut out);
445        }
446        out
447    }
448
449    pub fn is_empty(&self) -> bool { self.root.is_none() }
450    pub fn item_count(&self) -> usize { self.item_count }
451}
452
453// ── KdTree ────────────────────────────────────────────────────────────────────
454
455/// A 3D k-d tree for efficient nearest-neighbor queries.
456///
457/// Best for static point clouds queried repeatedly.
458/// Build with `KdTree::build(points)`.
459#[derive(Debug, Clone)]
460pub struct KdTree<T: Clone> {
461    nodes: Vec<KdNode<T>>,
462    pub item_count: usize,
463}
464
465#[derive(Debug, Clone)]
466struct KdNode<T: Clone> {
467    pos:   Vec3,
468    item:  T,
469    left:  Option<usize>,
470    right: Option<usize>,
471    axis:  u8,  // 0=x, 1=y, 2=z
472}
473
474impl<T: Clone> KdTree<T> {
475    /// Build a k-d tree from a list of (position, item) pairs.
476    pub fn build(points: Vec<(Vec3, T)>) -> Self {
477        let count = points.len();
478        let mut tree = Self { nodes: Vec::with_capacity(count), item_count: count };
479        if !points.is_empty() {
480            tree.build_recursive(points, 0);
481        }
482        tree
483    }
484
485    fn build_recursive(&mut self, mut points: Vec<(Vec3, T)>, depth: usize) -> usize {
486        let axis = (depth % 3) as u8;
487
488        // Sort by the current axis
489        points.sort_by(|a, b| {
490            let va = match axis { 0 => a.0.x, 1 => a.0.y, _ => a.0.z };
491            let vb = match axis { 0 => b.0.x, 1 => b.0.y, _ => b.0.z };
492            va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
493        });
494
495        let mid = points.len() / 2;
496        let mut right_points = points.split_off(mid + 1);
497        let left_points = points.split_off(mid);
498        let (pos, item) = points.remove(0);
499
500        let idx = self.nodes.len();
501        self.nodes.push(KdNode { pos, item, left: None, right: None, axis });
502
503        if !left_points.is_empty() {
504            let left_idx = self.build_recursive(left_points, depth + 1);
505            self.nodes[idx].left = Some(left_idx);
506        }
507        if !right_points.is_empty() {
508            let right_idx = self.build_recursive(right_points, depth + 1);
509            self.nodes[idx].right = Some(right_idx);
510        }
511
512        idx
513    }
514
515    /// Find the k nearest neighbors to `query`.
516    /// Returns (item, position, distance) tuples sorted by distance.
517    pub fn k_nearest(&self, query: Vec3, k: usize) -> Vec<(T, Vec3, f32)> {
518        if self.nodes.is_empty() { return Vec::new(); }
519        let mut heap: Vec<(f32, usize)> = Vec::new(); // (dist_sq, node_idx)
520        self.nn_search(0, query, k, &mut heap);
521        heap.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
522        heap.into_iter().map(|(d2, idx)| {
523            let node = &self.nodes[idx];
524            (node.item.clone(), node.pos, d2.sqrt())
525        }).collect()
526    }
527
528    fn nn_search(&self, node_idx: usize, query: Vec3, k: usize, heap: &mut Vec<(f32, usize)>) {
529        let node = &self.nodes[node_idx];
530        let d2 = (node.pos - query).length_squared();
531
532        // Check if this node belongs in the heap
533        let worst = heap.iter().map(|&(d, _)| d).fold(f32::NEG_INFINITY, f32::max);
534        if heap.len() < k || d2 < worst {
535            heap.push((d2, node_idx));
536            if heap.len() > k {
537                // Remove worst
538                let worst_idx = heap.iter()
539                    .enumerate()
540                    .max_by(|a, b| a.1.0.partial_cmp(&b.1.0).unwrap_or(std::cmp::Ordering::Equal))
541                    .map(|(i, _)| i)
542                    .unwrap();
543                heap.swap_remove(worst_idx);
544            }
545        }
546
547        // Determine which subtree to explore first
548        let axis_val = match node.axis { 0 => query.x, 1 => query.y, _ => query.z };
549        let node_val = match node.axis { 0 => node.pos.x, 1 => node.pos.y, _ => node.pos.z };
550        let (near, far) = if axis_val <= node_val {
551            (node.left, node.right)
552        } else {
553            (node.right, node.left)
554        };
555
556        if let Some(near_idx) = near {
557            self.nn_search(near_idx, query, k, heap);
558        }
559
560        // Check if far side could have closer points
561        let plane_dist_sq = (axis_val - node_val) * (axis_val - node_val);
562        let current_worst = heap.iter().map(|&(d, _)| d).fold(f32::NEG_INFINITY, f32::max);
563        if let Some(far_idx) = far {
564            if heap.len() < k || plane_dist_sq < current_worst {
565                self.nn_search(far_idx, query, k, heap);
566            }
567        }
568    }
569
570    /// Find all points within `radius` of `query`.
571    pub fn radius_search(&self, query: Vec3, radius: f32) -> Vec<(T, Vec3, f32)> {
572        if self.nodes.is_empty() { return Vec::new(); }
573        let mut results = Vec::new();
574        self.range_search(0, query, radius * radius, &mut results);
575        results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
576        results
577    }
578
579    fn range_search(&self, node_idx: usize, query: Vec3, r2: f32, out: &mut Vec<(T, Vec3, f32)>) {
580        let node = &self.nodes[node_idx];
581        let d2 = (node.pos - query).length_squared();
582        if d2 <= r2 {
583            out.push((node.item.clone(), node.pos, d2.sqrt()));
584        }
585
586        let axis_val = match node.axis { 0 => query.x, 1 => query.y, _ => query.z };
587        let node_val = match node.axis { 0 => node.pos.x, 1 => node.pos.y, _ => node.pos.z };
588        let plane_d2 = (axis_val - node_val) * (axis_val - node_val);
589
590        let (near, far) = if axis_val <= node_val {
591            (node.left, node.right)
592        } else {
593            (node.right, node.left)
594        };
595
596        if let Some(near_idx) = near {
597            self.range_search(near_idx, query, r2, out);
598        }
599        if plane_d2 <= r2 {
600            if let Some(far_idx) = far {
601                self.range_search(far_idx, query, r2, out);
602            }
603        }
604    }
605
606    pub fn is_empty(&self) -> bool { self.nodes.is_empty() }
607    pub fn len(&self) -> usize { self.item_count }
608}
609
610// ── Frustum culling ───────────────────────────────────────────────────────────
611
612/// A camera frustum for view culling.
613///
614/// Used to discard objects outside the camera view before rendering.
615/// Defined by six half-space planes: left, right, top, bottom, near, far.
616#[derive(Debug, Clone)]
617pub struct Frustum {
618    /// Six plane normals (pointing inward).
619    planes: [(Vec3, f32); 6],  // (normal, d) where n·p + d >= 0 means inside
620}
621
622impl Frustum {
623    /// Build a frustum from a view-projection matrix.
624    ///
625    /// Works with row-major matrices (as provided by `glam::Mat4`).
626    pub fn from_matrix(vp: glam::Mat4) -> Self {
627        let m = vp.to_cols_array_2d();
628        // Gribb-Hartmann method
629        let planes_raw = [
630            // Left:   row3 + row0
631            [m[0][3]+m[0][0], m[1][3]+m[1][0], m[2][3]+m[2][0], m[3][3]+m[3][0]],
632            // Right:  row3 - row0
633            [m[0][3]-m[0][0], m[1][3]-m[1][0], m[2][3]-m[2][0], m[3][3]-m[3][0]],
634            // Bottom: row3 + row1
635            [m[0][3]+m[0][1], m[1][3]+m[1][1], m[2][3]+m[2][1], m[3][3]+m[3][1]],
636            // Top:    row3 - row1
637            [m[0][3]-m[0][1], m[1][3]-m[1][1], m[2][3]-m[2][1], m[3][3]-m[3][1]],
638            // Near:   row3 + row2
639            [m[0][3]+m[0][2], m[1][3]+m[1][2], m[2][3]+m[2][2], m[3][3]+m[3][2]],
640            // Far:    row3 - row2
641            [m[0][3]-m[0][2], m[1][3]-m[1][2], m[2][3]-m[2][2], m[3][3]-m[3][2]],
642        ];
643
644        let mut planes = [(Vec3::ZERO, 0.0_f32); 6];
645        for (i, raw) in planes_raw.iter().enumerate() {
646            let n = Vec3::new(raw[0], raw[1], raw[2]);
647            let len = n.length().max(1e-6);
648            planes[i] = (n / len, raw[3] / len);
649        }
650        Self { planes }
651    }
652
653    /// Test if a sphere overlaps the frustum.
654    pub fn sphere_inside(&self, center: Vec3, radius: f32) -> bool {
655        for &(n, d) in &self.planes {
656            if n.dot(center) + d < -radius {
657                return false;
658            }
659        }
660        true
661    }
662
663    /// Test if an AABB overlaps the frustum (conservative test).
664    pub fn aabb_inside(&self, bounds: &Aabb) -> bool {
665        let center = bounds.center();
666        let half = bounds.half_extents();
667        for &(n, d) in &self.planes {
668            // Compute positive vertex (farthest in plane normal direction)
669            let r = half.x * n.x.abs() + half.y * n.y.abs() + half.z * n.z.abs();
670            if n.dot(center) + d < -r {
671                return false;
672            }
673        }
674        true
675    }
676
677    /// Test if a point is inside the frustum.
678    pub fn point_inside(&self, p: Vec3) -> bool {
679        self.planes.iter().all(|&(n, d)| n.dot(p) + d >= 0.0)
680    }
681}
682
683// ── Proximity pairs ───────────────────────────────────────────────────────────
684
685/// Find all pairs of points closer than `max_dist`.
686///
687/// Returns `(i, j, distance)` for each pair where i < j.
688/// Uses a spatial grid for O(n log n) performance.
689pub fn find_close_pairs(positions: &[Vec3], max_dist: f32) -> Vec<(usize, usize, f32)> {
690    let mut grid: SpatialGrid<usize> = SpatialGrid::new(max_dist);
691    for (i, &pos) in positions.iter().enumerate() {
692        grid.insert(pos, i);
693    }
694
695    let mut pairs = Vec::new();
696    let r2 = max_dist * max_dist;
697    for (i, &pos) in positions.iter().enumerate() {
698        let nearby = grid.query_radius(pos, max_dist);
699        for (j, npos, _) in nearby {
700            if j > i {
701                let d2 = (pos - npos).length_squared();
702                if d2 <= r2 {
703                    pairs.push((i, j, d2.sqrt()));
704                }
705            }
706        }
707    }
708    pairs
709}
710
711/// Find all positions within `radius` of any of the given `query_points`.
712///
713/// Returns pairs of (query_index, position_index, distance).
714pub fn batch_radius_query(
715    query_points: &[Vec3],
716    positions: &[Vec3],
717    radius: f32,
718) -> Vec<(usize, usize, f32)> {
719    let mut grid: SpatialGrid<usize> = SpatialGrid::new(radius);
720    for (i, &pos) in positions.iter().enumerate() {
721        grid.insert(pos, i);
722    }
723    let mut results = Vec::new();
724    for (qi, &qpos) in query_points.iter().enumerate() {
725        for (idx, _, dist) in grid.query_radius(qpos, radius) {
726            results.push((qi, idx, dist));
727        }
728    }
729    results
730}
731
// ── Tests ──────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn spatial_grid_insert_query() {
        let mut grid: SpatialGrid<u32> = SpatialGrid::new(1.0);
        grid.insert(Vec3::new(0.5, 0.5, 0.5), 1);
        grid.insert(Vec3::new(10.0, 10.0, 10.0), 2);

        let near = grid.query_radius(Vec3::ZERO, 2.0);
        assert_eq!(near.len(), 1, "should find 1 item near origin");
        // Queries return cloned items by value, so compare directly (no deref).
        assert_eq!(near[0].0, 1u32);
    }

    #[test]
    fn spatial_grid_k_nearest() {
        let mut grid: SpatialGrid<usize> = SpatialGrid::new(0.5);
        for i in 0..10 {
            grid.insert(Vec3::new(i as f32, 0.0, 0.0), i);
        }
        let nn = grid.k_nearest(Vec3::new(4.5, 0.0, 0.0), 2);
        assert_eq!(nn.len(), 2);
        // Nearest two should be 4 and 5 (both at distance 0.5).
        let mut ids: Vec<usize> = nn.iter().map(|(id, _, _)| *id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![4, 5]);
    }

    #[test]
    fn aabb_intersects() {
        let a = Aabb::new(Vec3::ZERO, Vec3::ONE);
        let b = Aabb::new(Vec3::new(0.5, 0.5, 0.5), Vec3::new(1.5, 1.5, 1.5));
        let c = Aabb::new(Vec3::new(2.0, 0.0, 0.0), Vec3::new(3.0, 1.0, 1.0));
        assert!(a.intersects(&b));
        assert!(!a.intersects(&c));
    }

    #[test]
    fn aabb_ray_hit() {
        let aabb = Aabb::new(Vec3::splat(-1.0), Vec3::splat(1.0));
        let hit = aabb.ray_intersect(Vec3::new(-5.0, 0.0, 0.0), Vec3::X);
        assert!(hit.is_some(), "ray should hit the AABB");
        let miss = aabb.ray_intersect(Vec3::new(-5.0, 5.0, 0.0), Vec3::X);
        assert!(miss.is_none(), "ray should miss");
    }

    #[test]
    fn bvh_sphere_query() {
        let items: Vec<(Aabb, usize)> = (0..10).map(|i| {
            let c = Vec3::new(i as f32, 0.0, 0.0);
            (Aabb::from_sphere(c, 0.4), i)
        }).collect();
        let bvh = Bvh::build(items);
        let mut hits = bvh.query_sphere(Vec3::new(4.5, 0.0, 0.0), 1.0);
        hits.sort_unstable();
        // Only the boxes around x=4 and x=5 come within 1.0 of x=4.5;
        // the boxes around x=3 and x=6 are 1.1 away.
        assert_eq!(hits, vec![4, 5]);
    }

    #[test]
    fn kd_tree_nearest() {
        let points: Vec<(Vec3, usize)> = (0..10).map(|i| {
            (Vec3::new(i as f32, 0.0, 0.0), i)
        }).collect();
        let tree = KdTree::build(points);
        let nn = tree.k_nearest(Vec3::new(3.1, 0.0, 0.0), 1);
        assert_eq!(nn.len(), 1);
        assert_eq!(nn[0].0, 3usize);
    }

    #[test]
    fn kd_tree_radius_search() {
        let points: Vec<(Vec3, usize)> = (0..20).map(|i| {
            (Vec3::new(i as f32, 0.0, 0.0), i)
        }).collect();
        let tree = KdTree::build(points);
        let results = tree.radius_search(Vec3::new(10.0, 0.0, 0.0), 2.5);
        let mut ids: Vec<usize> = results.iter().map(|(id, _, _)| *id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![8, 9, 10, 11, 12]);
    }

    #[test]
    fn kd_tree_k_nearest_matches_brute_force() {
        // Deterministic pseudo-random cloud (LCG) so the test is reproducible.
        let mut seed: u32 = 0x1234_5678;
        let mut next = move || {
            seed = seed.wrapping_mul(1_103_515_245).wrapping_add(12_345);
            ((seed >> 8) % 1000) as f32 / 100.0
        };
        let cloud: Vec<(Vec3, usize)> =
            (0..200).map(|i| (Vec3::new(next(), next(), next()), i)).collect();
        let tree = KdTree::build(cloud.clone());
        assert_eq!(tree.len(), cloud.len());
        for query in [Vec3::splat(5.0), Vec3::ZERO, Vec3::new(9.0, 1.0, 4.0)] {
            let got: Vec<f32> = tree.k_nearest(query, 7).iter().map(|(_, _, d)| *d).collect();
            let mut want: Vec<f32> = cloud
                .iter()
                .map(|(p, _)| (*p - query).length_squared().sqrt())
                .collect();
            want.sort_by(f32::total_cmp);
            want.truncate(7);
            assert_eq!(got, want);
        }
    }

    #[test]
    fn find_close_pairs_correct() {
        let positions = vec![
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(0.5, 0.0, 0.0),
            Vec3::new(10.0, 0.0, 0.0),
        ];
        let pairs = find_close_pairs(&positions, 1.0);
        assert_eq!(pairs.len(), 1, "should find exactly one close pair");
        assert_eq!(pairs[0].0, 0);
        assert_eq!(pairs[0].1, 1);
    }

    #[test]
    fn spatial_grid_2d_query() {
        let mut grid: SpatialGrid2D<u32> = SpatialGrid2D::new(1.0);
        grid.insert(Vec2::new(0.5, 0.5), 10);
        grid.insert(Vec2::new(5.0, 5.0), 20);
        let near = grid.query_radius(Vec2::ZERO, 1.5);
        assert_eq!(near.len(), 1);
        assert_eq!(near[0].0, 10u32);
    }
}