Skip to main content

proof_engine/scene/
bvh.rs

1//! Bounding Volume Hierarchy for fast spatial queries.
2
3use crate::glyph::GlyphId;
4use glam::Vec3;
5
6/// Axis-aligned bounding box.
7#[derive(Clone, Debug, PartialEq)]
8pub struct Aabb {
9    pub min: Vec3,
10    pub max: Vec3,
11}
12
13impl Aabb {
14    pub fn new(min: Vec3, max: Vec3) -> Self { Self { min, max } }
15
16    pub fn empty() -> Self {
17        Self { min: Vec3::splat(f32::MAX), max: Vec3::splat(f32::MIN) }
18    }
19
20    pub fn from_point(p: Vec3, half: f32) -> Self {
21        Self { min: p - Vec3::splat(half), max: p + Vec3::splat(half) }
22    }
23
24    pub fn union(&self, other: &Aabb) -> Self {
25        Self { min: self.min.min(other.min), max: self.max.max(other.max) }
26    }
27
28    pub fn center(&self) -> Vec3 { (self.min + self.max) * 0.5 }
29    pub fn extents(&self) -> Vec3 { (self.max - self.min) * 0.5 }
30    pub fn surface_area(&self) -> f32 {
31        let d = self.max - self.min;
32        2.0 * (d.x * d.y + d.y * d.z + d.z * d.x)
33    }
34
35    pub fn contains_point(&self, p: Vec3) -> bool {
36        p.x >= self.min.x && p.x <= self.max.x &&
37        p.y >= self.min.y && p.y <= self.max.y &&
38        p.z >= self.min.z && p.z <= self.max.z
39    }
40
41    pub fn intersects_sphere(&self, center: Vec3, radius: f32) -> bool {
42        let closest = center.clamp(self.min, self.max);
43        (closest - center).length_squared() <= radius * radius
44    }
45
46    pub fn intersects_aabb(&self, other: &Aabb) -> bool {
47        self.min.x <= other.max.x && self.max.x >= other.min.x &&
48        self.min.y <= other.max.y && self.max.y >= other.min.y &&
49        self.min.z <= other.max.z && self.max.z >= other.min.z
50    }
51
52    /// Ray-AABB intersection (slab method). Returns distance or None.
53    pub fn ray_intersect(&self, origin: Vec3, dir_inv: Vec3) -> Option<f32> {
54        let t1 = (self.min - origin) * dir_inv;
55        let t2 = (self.max - origin) * dir_inv;
56        let tmin = t1.min(t2);
57        let tmax = t1.max(t2);
58        let enter = tmin.x.max(tmin.y).max(tmin.z);
59        let exit  = tmax.x.min(tmax.y).min(tmax.z);
60        if exit >= enter && exit >= 0.0 { Some(enter.max(0.0)) } else { None }
61    }
62
63    pub fn longest_axis(&self) -> usize {
64        let d = self.max - self.min;
65        if d.x >= d.y && d.x >= d.z { 0 } else if d.y >= d.z { 1 } else { 2 }
66    }
67}
68
69// ─── BVH node ─────────────────────────────────────────────────────────────────
70
71#[derive(Clone, Debug)]
72pub enum BvhNode {
73    Leaf {
74        aabb:  Aabb,
75        items: Vec<GlyphId>,
76    },
77    Internal {
78        aabb:  Aabb,
79        left:  Box<BvhNode>,
80        right: Box<BvhNode>,
81    },
82}
83
84impl BvhNode {
85    pub fn aabb(&self) -> &Aabb {
86        match self {
87            Self::Leaf   { aabb, .. }     => aabb,
88            Self::Internal { aabb, .. }   => aabb,
89        }
90    }
91
92    pub fn count(&self) -> usize {
93        match self {
94            Self::Leaf   { items, .. }   => items.len(),
95            Self::Internal { left, right, .. } => left.count() + right.count(),
96        }
97    }
98
99    pub fn depth(&self) -> usize {
100        match self {
101            Self::Leaf { .. } => 1,
102            Self::Internal { left, right, .. } => 1 + left.depth().max(right.depth()),
103        }
104    }
105}
106
107// ─── BVH ──────────────────────────────────────────────────────────────────────
108
109/// A flat BVH built from (GlyphId, Aabb) pairs.
110pub struct Bvh {
111    pub root: BvhNode,
112}
113
114impl Bvh {
115    const LEAF_MAX: usize = 4;
116
117    pub fn build(items: &[(GlyphId, Aabb)]) -> Self {
118        if items.is_empty() {
119            return Self {
120                root: BvhNode::Leaf { aabb: Aabb::empty(), items: Vec::new() },
121            };
122        }
123        let root = Self::build_node(items);
124        Self { root }
125    }
126
127    fn build_node(items: &[(GlyphId, Aabb)]) -> BvhNode {
128        // Compute enclosing AABB
129        let mut aabb = items[0].1.clone();
130        for (_, b) in &items[1..] { aabb = aabb.union(b); }
131
132        if items.len() <= Self::LEAF_MAX {
133            return BvhNode::Leaf {
134                aabb,
135                items: items.iter().map(|(id, _)| *id).collect(),
136            };
137        }
138
139        // Split along longest axis (median)
140        let axis = aabb.longest_axis();
141        let mut sorted = items.to_vec();
142        sorted.sort_by(|a, b| {
143            let ca = a.1.center();
144            let cb = b.1.center();
145            let av = [ca.x, ca.y, ca.z][axis];
146            let bv = [cb.x, cb.y, cb.z][axis];
147            av.partial_cmp(&bv).unwrap()
148        });
149        let mid = sorted.len() / 2;
150        let left  = Box::new(Self::build_node(&sorted[..mid]));
151        let right = Box::new(Self::build_node(&sorted[mid..]));
152        BvhNode::Internal { aabb, left, right }
153    }
154
155    /// Find all glyph IDs whose bounding box intersects the sphere.
156    pub fn sphere_query(&self, center: Vec3, radius: f32) -> Vec<GlyphId> {
157        let mut results = Vec::new();
158        Self::sphere_query_node(&self.root, center, radius, &mut results);
159        results
160    }
161
162    fn sphere_query_node(node: &BvhNode, center: Vec3, radius: f32, out: &mut Vec<GlyphId>) {
163        if !node.aabb().intersects_sphere(center, radius) { return; }
164        match node {
165            BvhNode::Leaf { items, .. } => { out.extend(items); }
166            BvhNode::Internal { left, right, .. } => {
167                Self::sphere_query_node(left,  center, radius, out);
168                Self::sphere_query_node(right, center, radius, out);
169            }
170        }
171    }
172
173    /// Find all glyph IDs intersected by a ray, sorted by distance.
174    pub fn ray_query(&self, origin: Vec3, direction: Vec3) -> Vec<(GlyphId, f32)> {
175        let dir = direction.normalize_or_zero();
176        let dir_inv = Vec3::new(
177            if dir.x.abs() > 1e-7 { 1.0 / dir.x } else { f32::MAX },
178            if dir.y.abs() > 1e-7 { 1.0 / dir.y } else { f32::MAX },
179            if dir.z.abs() > 1e-7 { 1.0 / dir.z } else { f32::MAX },
180        );
181        let mut results = Vec::new();
182        Self::ray_query_node(&self.root, origin, dir_inv, &mut results);
183        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
184        results
185    }
186
187    fn ray_query_node(
188        node: &BvhNode, origin: Vec3, dir_inv: Vec3, out: &mut Vec<(GlyphId, f32)>,
189    ) {
190        if node.aabb().ray_intersect(origin, dir_inv).is_none() { return; }
191        match node {
192            BvhNode::Leaf { items, aabb } => {
193                if let Some(t) = aabb.ray_intersect(origin, dir_inv) {
194                    for &id in items { out.push((id, t)); }
195                }
196            }
197            BvhNode::Internal { left, right, .. } => {
198                Self::ray_query_node(left,  origin, dir_inv, out);
199                Self::ray_query_node(right, origin, dir_inv, out);
200            }
201        }
202    }
203
204    /// AABB overlap query.
205    pub fn aabb_query(&self, query: &Aabb) -> Vec<GlyphId> {
206        let mut results = Vec::new();
207        Self::aabb_query_node(&self.root, query, &mut results);
208        results
209    }
210
211    fn aabb_query_node(node: &BvhNode, query: &Aabb, out: &mut Vec<GlyphId>) {
212        if !node.aabb().intersects_aabb(query) { return; }
213        match node {
214            BvhNode::Leaf { items, .. } => { out.extend(items); }
215            BvhNode::Internal { left, right, .. } => {
216                Self::aabb_query_node(left,  query, out);
217                Self::aabb_query_node(right, query, out);
218            }
219        }
220    }
221
222    pub fn depth(&self) -> usize { self.root.depth() }
223    pub fn count(&self) -> usize { self.root.count() }
224}
225
226// ─── Tests ────────────────────────────────────────────────────────────────────
227
228#[cfg(test)]
229mod tests {
230    use super::*;
231    use glam::Vec3;
232
233    fn gid(n: u32) -> GlyphId { GlyphId(n) }
234
235    #[test]
236    fn aabb_contains() {
237        let a = Aabb::from_point(Vec3::ZERO, 1.0);
238        assert!( a.contains_point(Vec3::ZERO));
239        assert!(!a.contains_point(Vec3::new(2.0, 0.0, 0.0)));
240    }
241
242    #[test]
243    fn aabb_sphere() {
244        let a = Aabb::from_point(Vec3::ZERO, 1.0);
245        assert!( a.intersects_sphere(Vec3::ZERO, 0.5));
246        assert!( a.intersects_sphere(Vec3::new(2.0, 0.0, 0.0), 1.5));
247        assert!(!a.intersects_sphere(Vec3::new(5.0, 0.0, 0.0), 1.0));
248    }
249
250    #[test]
251    fn bvh_sphere_query() {
252        let items: Vec<(GlyphId, Aabb)> = (0u32..10)
253            .map(|i| (gid(i), Aabb::from_point(Vec3::new(i as f32 * 3.0, 0.0, 0.0), 0.5)))
254            .collect();
255        let bvh = Bvh::build(&items);
256        let hits = bvh.sphere_query(Vec3::ZERO, 2.0);
257        assert!(hits.contains(&gid(0)));
258        assert!(!hits.contains(&gid(5)));
259    }
260
261    #[test]
262    fn bvh_ray_query() {
263        let items = vec![
264            (gid(0), Aabb::from_point(Vec3::new(0.0, 0.0, 5.0),  0.5)),
265            (gid(1), Aabb::from_point(Vec3::new(0.0, 0.0, 20.0), 0.5)),
266            (gid(2), Aabb::from_point(Vec3::new(10.0, 0.0, 5.0), 0.5)),
267        ];
268        let bvh = Bvh::build(&items);
269        let hits = bvh.ray_query(Vec3::ZERO, Vec3::Z);
270        // Should hit items 0 and 1 (along Z), not item 2 (off to X side)
271        let ids: Vec<GlyphId> = hits.iter().map(|(id, _)| *id).collect();
272        assert!(ids.contains(&gid(0)));
273        assert!(ids.contains(&gid(1)));
274    }
275
276    #[test]
277    fn bvh_empty() {
278        let bvh = Bvh::build(&[]);
279        let hits = bvh.sphere_query(Vec3::ZERO, 100.0);
280        assert!(hits.is_empty());
281    }
282
283    #[test]
284    fn aabb_longest_axis() {
285        let a = Aabb::new(Vec3::ZERO, Vec3::new(10.0, 3.0, 1.0));
286        assert_eq!(a.longest_axis(), 0);
287        let b = Aabb::new(Vec3::ZERO, Vec3::new(1.0, 10.0, 3.0));
288        assert_eq!(b.longest_axis(), 1);
289    }
290}