Skip to main content

gizmo_physics_core/
raycast.rs

1use crate::components::{ColliderShape, Transform};
2use gizmo_core::entity::Entity;
3use gizmo_math::Aabb;
4use gizmo_math::Vec3;
5
6/// Ray for raycasting
7#[derive(Debug, Clone, Copy)]
8pub struct Ray {
9    pub origin: Vec3,
10    pub direction: Vec3, // Should be normalized
11}
12
13impl Ray {
14    pub fn new(origin: Vec3, direction: Vec3) -> Self {
15        Self {
16            origin,
17            direction: direction.normalize(),
18        }
19    }
20
21    pub fn point_at(&self, t: f32) -> Vec3 {
22        self.origin + self.direction * t
23    }
24}
25
26/// Result of a raycast hit
27#[derive(Debug, Clone, Copy)]
28pub struct RaycastHit {
29    pub entity: Entity,
30    pub point: Vec3,
31    pub normal: Vec3,
32    pub distance: f32,
33}
34
35/// Raycast query system
36pub struct Raycast;
37
38impl Raycast {
39    /// Test ray against AABB
40    pub fn ray_aabb(ray: &Ray, aabb: &Aabb) -> Option<f32> {
41        let mut tmin: f32 = 0.0;
42        let mut tmax = f32::INFINITY;
43
44        for i in 0..3 {
45            let origin = match i {
46                0 => ray.origin.x,
47                1 => ray.origin.y,
48                _ => ray.origin.z,
49            };
50            let dir = match i {
51                0 => ray.direction.x,
52                1 => ray.direction.y,
53                _ => ray.direction.z,
54            };
55            let min = match i {
56                0 => aabb.min.x,
57                1 => aabb.min.y,
58                _ => aabb.min.z,
59            };
60            let max = match i {
61                0 => aabb.max.x,
62                1 => aabb.max.y,
63                _ => aabb.max.z,
64            };
65
66            if dir.abs() < 1e-8 {
67                // Ray is parallel to slab
68                if origin < min || origin > max {
69                    return None;
70                }
71            } else {
72                let inv_d = 1.0 / dir;
73                let mut t1 = (min - origin) * inv_d;
74                let mut t2 = (max - origin) * inv_d;
75
76                if t1 > t2 {
77                    std::mem::swap(&mut t1, &mut t2);
78                }
79
80                tmin = tmin.max(t1);
81                tmax = tmax.min(t2);
82
83                if tmin > tmax {
84                    return None;
85                }
86            }
87        }
88
89        Some(tmin)
90    }
91
92    /// Test ray against sphere
93    pub fn ray_sphere(ray: &Ray, center: Vec3, radius: f32) -> Option<(f32, Vec3)> {
94        let oc = ray.origin - center;
95        let b = oc.dot(ray.direction);
96        let c = oc.dot(oc) - radius * radius;
97        let discriminant = b * b - c;
98
99        if discriminant < 0.0 {
100            return None;
101        }
102
103        let sqrt_d = discriminant.sqrt();
104        let t1 = -b - sqrt_d;
105        let t2 = -b + sqrt_d;
106
107        let t = if t1 > 0.0 {
108            t1
109        } else if t2 > 0.0 {
110            t2
111        } else {
112            return None;
113        };
114
115        let hit_point = ray.point_at(t);
116        let normal = (hit_point - center).try_normalize().unwrap_or(Vec3::Y);
117
118        Some((t, normal))
119    }
120
121    /// Test ray against box (OBB)
122    pub fn ray_box(
123        ray: &Ray,
124        center: Vec3,
125        rotation: gizmo_math::Quat,
126        half_extents: Vec3,
127    ) -> Option<(f32, Vec3)> {
128        // Transform ray to box's local space
129        let inv_rot = rotation.inverse();
130        let local_origin = inv_rot * (ray.origin - center);
131        let local_dir = inv_rot * ray.direction;
132
133        let local_ray = Ray::new(local_origin, local_dir);
134
135        // Create AABB in local space
136        let local_aabb = Aabb::from_center_half_extents(Vec3::ZERO, half_extents);
137
138        if let Some(t) = Self::ray_aabb(&local_ray, &local_aabb) {
139            let local_hit = local_ray.point_at(t);
140
141            // Calculate normal in local space
142            let mut normal = Vec3::ZERO;
143
144            let epsilon = 1e-4;
145            for i in 0..3 {
146                if (local_hit[i] - half_extents[i]).abs() < epsilon {
147                    normal[i] = 1.0;
148                }
149                if (local_hit[i] + half_extents[i]).abs() < epsilon {
150                    normal[i] = -1.0;
151                }
152            }
153            normal = normal.try_normalize().unwrap_or(Vec3::Y);
154
155            // Transform normal back to world space
156            let world_normal = rotation * normal;
157
158            Some((t, world_normal))
159        } else {
160            None
161        }
162    }
163
164    /// Test ray against capsule
165    pub fn ray_capsule(
166        ray: &Ray,
167        center: Vec3,
168        rotation: gizmo_math::Quat,
169        radius: f32,
170        half_height: f32,
171    ) -> Option<(f32, Vec3)> {
172        // Transform to local space
173        let inv_rot = rotation.inverse();
174        let local_origin = inv_rot * (ray.origin - center);
175        let local_dir = inv_rot * ray.direction;
176
177        // Capsule is aligned along Y axis in local space
178        let p1 = Vec3::new(0.0, half_height, 0.0);
179        let p2 = Vec3::new(0.0, -half_height, 0.0);
180
181        // Ray-cylinder intersection
182        let ba = p2 - p1;
183        let oc = local_origin - p1;
184
185        let baba = ba.dot(ba);
186        let bard = ba.dot(local_dir);
187        let baoc = ba.dot(oc);
188
189        let k2 = baba - bard * bard;
190        let k1 = baba * oc.dot(local_dir) - baoc * bard;
191        let k0 = baba * oc.dot(oc) - baoc * baoc - radius * radius * baba;
192
193        if k2.abs() >= 1e-8 {
194            let h = k1 * k1 - k2 * k0;
195            if h >= 0.0 {
196                let t = (-k1 - h.sqrt()) / k2;
197                // Check if hit is within cylinder height
198                let y = baoc + t * bard;
199                if y > 0.0 && y < baba {
200                    let hit_point = local_origin + local_dir * t;
201                    let normal = (hit_point - (p1 + ba * (y / baba)))
202                        .try_normalize()
203                        .unwrap_or(Vec3::Y);
204                    let world_normal = rotation * normal;
205                    return Some((t, world_normal));
206                }
207            }
208        }
209
210        // Check sphere caps
211        let mut best_t = f32::INFINITY;
212        let mut best_normal = Vec3::ZERO;
213
214        for &cap_center in &[p1, p2] {
215            let oc = local_origin - cap_center;
216            let a = local_dir.dot(local_dir);
217            let b = 2.0 * oc.dot(local_dir);
218            let c = oc.dot(oc) - radius * radius;
219            let discriminant = b * b - 4.0 * a * c;
220
221            if discriminant >= 0.0 {
222                let t = (-b - discriminant.sqrt()) / (2.0 * a);
223                if t > 0.0 && t < best_t {
224                    best_t = t;
225                    let hit = local_origin + local_dir * t;
226                    best_normal = (hit - cap_center).try_normalize().unwrap_or(Vec3::Y);
227                }
228            }
229        }
230
231        if best_t < f32::INFINITY {
232            let world_normal = rotation * best_normal;
233            Some((best_t, world_normal))
234        } else {
235            None
236        }
237    }
238
239    /// Test ray against collider shape
240    pub fn ray_shape(
241        ray: &Ray,
242        shape: &ColliderShape,
243        transform: &Transform,
244    ) -> Option<(f32, Vec3)> {
245        match shape {
246            ColliderShape::Sphere(s) => Self::ray_sphere(ray, transform.position, s.radius),
247            ColliderShape::Box(b) => {
248                Self::ray_box(ray, transform.position, transform.rotation, b.half_extents)
249            }
250            ColliderShape::Capsule(c) => Self::ray_capsule(
251                ray,
252                transform.position,
253                transform.rotation,
254                c.radius,
255                c.half_height,
256            ),
257            ColliderShape::Plane(p) => {
258                // Ray-plane intersection
259                let denom = ray.direction.dot(p.normal);
260                if denom.abs() > 1e-6 {
261                    let t = (p.distance - ray.origin.dot(p.normal)) / denom;
262                    if t >= 0.0 {
263                        let normal = if denom < 0.0 { p.normal } else { -p.normal };
264                        Some((t, normal))
265                    } else {
266                        None
267                    }
268                } else {
269                    None
270                }
271            }
272            ColliderShape::TriMesh(tm) => {
273                let mut best_t = f32::INFINITY;
274                let mut best_normal = Vec3::ZERO;
275                let inv_rot = transform.rotation.inverse();
276                let local_origin = inv_rot * (ray.origin - transform.position);
277                let local_dir = inv_rot * ray.direction;
278                let local_ray = Ray::new(local_origin, local_dir);
279
280                if !tm.bvh.nodes.is_empty() {
281                    let mut stack = Vec::with_capacity(64);
282                    stack.push(0); // root node
283
284                    while let Some(node_idx) = stack.pop() {
285                        let node = &tm.bvh.nodes[node_idx];
286
287                        // Check AABB
288                        if Self::ray_aabb(&local_ray, &node.aabb).is_none() {
289                            continue;
290                        }
291
292                        if node.is_leaf() {
293                            let start = (node.first_tri_index * 3) as usize;
294                            let end = start + (node.tri_count * 3) as usize;
295                            for i in (start..end).step_by(3) {
296                                let v0 = tm.vertices[tm.indices[i] as usize];
297                                let v1 = tm.vertices[tm.indices[i + 1] as usize];
298                                let v2 = tm.vertices[tm.indices[i + 2] as usize];
299
300                                let e1 = v1 - v0;
301                                let e2 = v2 - v0;
302                                let h = local_dir.cross(e2);
303                                let a = e1.dot(h);
304                                if a.abs() < 1e-6 {
305                                    continue;
306                                }
307                                let f = 1.0 / a;
308                                let s = local_origin - v0;
309                                let u = f * s.dot(h);
310                                if !(0.0..=1.0).contains(&u) {
311                                    continue;
312                                }
313                                let q = s.cross(e1);
314                                let v = f * local_dir.dot(q);
315                                if v < 0.0 || u + v > 1.0 {
316                                    continue;
317                                }
318                                let t = f * e2.dot(q);
319                                if t > 0.0 && t < best_t {
320                                    best_t = t;
321                                    best_normal = e1.cross(e2).try_normalize().unwrap_or(Vec3::Y);
322                                    if best_normal.dot(local_dir) > 0.0 {
323                                        best_normal = -best_normal;
324                                    }
325                                }
326                            }
327                        } else {
328                            if node.left_child >= 0 {
329                                stack.push(node.left_child as usize);
330                            }
331                            if node.right_child >= 0 {
332                                stack.push(node.right_child as usize);
333                            }
334                        }
335                    }
336                } else {
337                    // Fallback to naive loop if BVH is missing
338                    for chunk in tm.indices.chunks_exact(3) {
339                        let v0 = tm.vertices[chunk[0] as usize];
340                        let v1 = tm.vertices[chunk[1] as usize];
341                        let v2 = tm.vertices[chunk[2] as usize];
342                        let e1 = v1 - v0;
343                        let e2 = v2 - v0;
344                        let h = local_dir.cross(e2);
345                        let a = e1.dot(h);
346                        if a.abs() < 1e-6 {
347                            continue;
348                        }
349                        let f = 1.0 / a;
350                        let s = local_origin - v0;
351                        let u = f * s.dot(h);
352                        if !(0.0..=1.0).contains(&u) {
353                            continue;
354                        }
355                        let q = s.cross(e1);
356                        let v = f * local_dir.dot(q);
357                        if v < 0.0 || u + v > 1.0 {
358                            continue;
359                        }
360                        let t = f * e2.dot(q);
361                        if t > 0.0 && t < best_t {
362                            best_t = t;
363                            best_normal = e1.cross(e2).try_normalize().unwrap_or(Vec3::Y);
364                            if best_normal.dot(local_dir) > 0.0 {
365                                best_normal = -best_normal;
366                            }
367                        }
368                    }
369                }
370
371                if best_t < f32::INFINITY {
372                    Some((best_t, transform.rotation * best_normal))
373                } else {
374                    None
375                }
376            }
377            ColliderShape::ConvexHull(ch) => {
378                let mut min = Vec3::splat(f32::MAX);
379                let mut max = Vec3::splat(f32::MIN);
380                for v in ch.vertices.iter() {
381                    min.x = min.x.min(v.x);
382                    min.y = min.y.min(v.y);
383                    min.z = min.z.min(v.z);
384                    max.x = max.x.max(v.x);
385                    max.y = max.y.max(v.y);
386                    max.z = max.z.max(v.z);
387                }
388                let center = (min + max) * 0.5;
389                let half_extents = (max - min) * 0.5;
390
391                // Adjust transform to local space of the original transform
392                let world_center = transform.position + transform.rotation * center;
393                Self::ray_box(ray, world_center, transform.rotation, half_extents)
394            }
395            ColliderShape::Compound(shapes) => {
396                let mut closest_dist = f32::MAX;
397                let mut closest_normal = Vec3::ZERO;
398                for (local_t, sub_shape) in shapes {
399                    let world_pos =
400                        transform.position + transform.rotation.mul_vec3(local_t.position);
401                    let world_rot = transform.rotation * local_t.rotation;
402                    let world_t =
403                        crate::components::Transform::new(world_pos).with_rotation(world_rot);
404                    if let Some((d, n)) = Self::ray_shape(ray, sub_shape, &world_t) {
405                        if d < closest_dist {
406                            closest_dist = d;
407                            closest_normal = n;
408                        }
409                    }
410                }
411                if closest_dist < f32::MAX {
412                    Some((closest_dist, closest_normal))
413                } else {
414                    None
415                }
416            }
417        }
418    }
419}
420
421#[cfg(test)]
422mod tests {
423    use super::*;
424
425    #[test]
426    fn test_ray_sphere() {
427        let ray = Ray::new(Vec3::new(0.0, 0.0, -5.0), Vec3::new(0.0, 0.0, 1.0));
428        let center = Vec3::ZERO;
429        let radius = 1.0;
430
431        let result = Raycast::ray_sphere(&ray, center, radius);
432        assert!(result.is_some());
433
434        let (t, _normal) = result.unwrap();
435        assert!((t - 4.0).abs() < 0.01);
436    }
437
438    #[test]
439    fn test_ray_aabb() {
440        let ray = Ray::new(Vec3::new(0.0, 0.0, -5.0), Vec3::new(0.0, 0.0, 1.0));
441        let aabb = Aabb::from_center_half_extents(Vec3::ZERO, Vec3::splat(1.0));
442
443        let result = Raycast::ray_aabb(&ray, &aabb);
444        assert!(result.is_some());
445
446        let t = result.unwrap();
447        assert!((t - 4.0).abs() < 0.01);
448    }
449
450    #[test]
451    fn test_ray_miss() {
452        let ray = Ray::new(Vec3::new(5.0, 0.0, 0.0), Vec3::new(0.0, 0.0, 1.0));
453        let center = Vec3::ZERO;
454        let radius = 1.0;
455
456        let result = Raycast::ray_sphere(&ray, center, radius);
457        assert!(result.is_none());
458    }
459
460    #[test]
461    fn test_ray_box() {
462        let ray = Ray::new(Vec3::new(0.0, 0.0, -5.0), Vec3::new(0.0, 0.0, 1.0));
463        let center = Vec3::ZERO;
464        let result = Raycast::ray_box(&ray, center, gizmo_math::Quat::IDENTITY, Vec3::splat(1.0));
465        assert!(result.is_some());
466        let (t, normal) = result.unwrap();
467        assert!((t - 4.0).abs() < 0.01);
468        assert!((normal.z - -1.0).abs() < 0.01);
469    }
470
471    #[test]
472    fn test_ray_capsule() {
473        let ray = Ray::new(Vec3::new(0.0, 0.0, -5.0), Vec3::new(0.0, 0.0, 1.0));
474        let center = Vec3::ZERO;
475        let result = Raycast::ray_capsule(&ray, center, gizmo_math::Quat::IDENTITY, 1.0, 1.0);
476        assert!(result.is_some());
477        let (t, normal) = result.unwrap();
478        assert!((t - 4.0).abs() < 0.01);
479        assert!((normal.z - -1.0).abs() < 0.01);
480    }
481
482    #[test]
483    fn test_ray_capsule_parallel() {
484        let ray = Ray::new(Vec3::new(0.0, 10.0, 0.0), Vec3::new(0.0, -1.0, 0.0));
485        let center = Vec3::ZERO;
486        // The ray is parallel to the Y axis (the capsule's internal axis).
487        // It hits the top sphere cap. The height is half_height = 1.0.
488        // The top sphere cap is centered at Y=1.0 with radius 1.0. Hit should be at Y=2.0.
489        let result = Raycast::ray_capsule(&ray, center, gizmo_math::Quat::IDENTITY, 1.0, 1.0);
490        assert!(result.is_some());
491        let (t, normal) = result.unwrap();
492        assert!((t - 8.0).abs() < 0.01); // 10.0 - 2.0 = 8.0
493        assert!((normal.y - 1.0).abs() < 0.01);
494    }
495
496    #[test]
497    fn test_ray_plane_backface() {
498        // Plane is at Z=0, pointing towards +Z.
499        let plane = crate::components::PlaneShape {
500            normal: Vec3::Z,
501            distance: 0.0,
502        };
503        let shape = ColliderShape::Plane(plane);
504
505        // Ray from -5 looking towards +Z
506        let ray = Ray::new(Vec3::new(0.0, 0.0, -5.0), Vec3::new(0.0, 0.0, 1.0));
507        let result = Raycast::ray_shape(&ray, &shape, &Transform::new(Vec3::ZERO));
508        assert!(result.is_some());
509        assert_eq!(result.unwrap().1, -Vec3::Z); // Should be flipped since ray hits the backface
510    }
511}