Skip to main content

hisab/geo/
primitives.rs

1use super::*;
2
3/// A ray defined by an origin and a direction.
4///
5/// # Examples
6///
7/// ```
8/// use hisab::geo::{Ray, Sphere, ray_sphere};
9/// use glam::Vec3;
10///
11/// let ray = Ray::new(Vec3::new(0.0, 0.0, -5.0), Vec3::Z).unwrap();
12/// let sphere = Sphere::new(Vec3::ZERO, 1.0).unwrap();
13/// let t = ray_sphere(&ray, &sphere).unwrap();
14/// assert!((t - 4.0).abs() < 1e-5);
15/// ```
16#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
17pub struct Ray {
18    pub origin: Vec3,
19    /// Should be normalized for correct distance results.
20    pub direction: Vec3,
21}
22
23impl Ray {
24    /// Create a new ray. Direction is normalized automatically.
25    ///
26    /// # Errors
27    /// Returns [`crate::HisabError::InvalidInput`] if `direction` is zero-length.
28    #[inline]
29    pub fn new(origin: Vec3, direction: Vec3) -> Result<Self, crate::HisabError> {
30        let len_sq = direction.length_squared();
31        if len_sq < crate::EPSILON_F32 {
32            return Err(crate::HisabError::InvalidInput(
33                "ray direction must be non-zero".into(),
34            ));
35        }
36        Ok(Self {
37            origin,
38            direction: direction.normalize(),
39        })
40    }
41
42    /// Point along the ray at parameter `t`.
43    #[must_use]
44    #[inline]
45    pub fn at(&self, t: f32) -> Vec3 {
46        self.origin + self.direction * t
47    }
48}
49
50impl fmt::Display for Ray {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        let p = f.precision();
53        let o = self.origin;
54        let d = self.direction;
55        match p {
56            Some(p) => write!(
57                f,
58                "Ray({:.p$}, {:.p$}, {:.p$} -> {:.p$}, {:.p$}, {:.p$})",
59                o.x, o.y, o.z, d.x, d.y, d.z
60            ),
61            None => write!(
62                f,
63                "Ray({}, {}, {} -> {}, {}, {})",
64                o.x, o.y, o.z, d.x, d.y, d.z
65            ),
66        }
67    }
68}
69
70/// An infinite plane defined by a normal and a signed distance from the origin.
71///
72/// Points **on** the plane satisfy `dot(normal, point) - distance == 0`.
73#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
74pub struct Plane {
75    pub normal: Vec3,
76    pub distance: f32,
77}
78
79impl Plane {
80    /// Create a plane from a point on the plane and a normal.
81    ///
82    /// # Errors
83    /// Returns [`crate::HisabError::InvalidInput`] if `normal` is zero-length.
84    #[inline]
85    pub fn from_point_normal(point: Vec3, normal: Vec3) -> Result<Self, crate::HisabError> {
86        let len_sq = normal.length_squared();
87        if len_sq < crate::EPSILON_F32 {
88            return Err(crate::HisabError::InvalidInput(
89                "plane normal must be non-zero".into(),
90            ));
91        }
92        let n = normal * len_sq.sqrt().recip();
93        Ok(Self {
94            normal: n,
95            distance: n.dot(point),
96        })
97    }
98
99    /// Signed distance from a point to the plane.
100    /// Positive = same side as normal, negative = opposite side.
101    #[must_use]
102    #[inline]
103    pub fn signed_distance(&self, point: Vec3) -> f32 {
104        self.normal.dot(point) - self.distance
105    }
106}
107
108impl fmt::Display for Plane {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        let p = f.precision();
111        let n = self.normal;
112        match p {
113            Some(p) => write!(
114                f,
115                "Plane(n=({:.p$}, {:.p$}, {:.p$}), d={:.p$})",
116                n.x, n.y, n.z, self.distance
117            ),
118            None => write!(
119                f,
120                "Plane(n=({}, {}, {}), d={})",
121                n.x, n.y, n.z, self.distance
122            ),
123        }
124    }
125}
126
127/// An axis-aligned bounding box.
128#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
129pub struct Aabb {
130    pub min: Vec3,
131    pub max: Vec3,
132}
133
134impl Aabb {
135    /// Create a new AABB. Min/max are corrected if swapped.
136    #[must_use]
137    #[inline]
138    pub fn new(a: Vec3, b: Vec3) -> Self {
139        Self {
140            min: a.min(b),
141            max: a.max(b),
142        }
143    }
144
145    /// Check whether a point is inside (or on the boundary of) this AABB.
146    #[must_use]
147    #[inline]
148    pub fn contains(&self, point: Vec3) -> bool {
149        point.cmpge(self.min).all() && point.cmple(self.max).all()
150    }
151
152    /// Center point of the AABB.
153    #[must_use]
154    #[inline]
155    pub fn center(&self) -> Vec3 {
156        (self.min + self.max) * 0.5
157    }
158
159    /// Size (extents) of the AABB.
160    #[must_use]
161    #[inline]
162    pub fn size(&self) -> Vec3 {
163        self.max - self.min
164    }
165
166    /// Merge two AABBs into one that encloses both.
167    #[must_use]
168    #[inline]
169    pub fn merge(&self, other: &Aabb) -> Aabb {
170        Aabb {
171            min: self.min.min(other.min),
172            max: self.max.max(other.max),
173        }
174    }
175
176    /// Compute the AABB of this AABB after applying an affine transform.
177    ///
178    /// Uses the Arvo/Koppelman method, which avoids transforming all 8 corners.
179    /// For each output axis `i`, the contribution of input column `j` is:
180    /// `min[i] += min(m[i][j]*old_min[j], m[i][j]*old_max[j])`.
181    ///
182    /// Only the upper-left 3×3 rotation/scale portion of `transform` and the
183    /// translation column are used (the homogeneous row is ignored, so this is
184    /// correct for affine — but not projective — transforms).
185    #[must_use]
186    #[inline]
187    pub fn transformed(&self, transform: glam::Mat4) -> Aabb {
188        // Extract the 3×3 linear part and the translation.
189        let col = [
190            transform.x_axis.truncate(), // column 0
191            transform.y_axis.truncate(), // column 1
192            transform.z_axis.truncate(), // column 2
193        ];
194        let translation = transform.w_axis.truncate();
195
196        let old_min = self.min.to_array();
197        let old_max = self.max.to_array();
198
199        let mut new_min = translation;
200        let mut new_max = translation;
201
202        // For each output axis i, accumulate contributions from each input axis j.
203        let new_min_arr = new_min.as_mut();
204        let new_max_arr = new_max.as_mut();
205        for j in 0..3 {
206            let col_arr = col[j].to_array();
207            for i in 0..3 {
208                let lo = col_arr[i] * old_min[j];
209                let hi = col_arr[i] * old_max[j];
210                if lo < hi {
211                    new_min_arr[i] += lo;
212                    new_max_arr[i] += hi;
213                } else {
214                    new_min_arr[i] += hi;
215                    new_max_arr[i] += lo;
216                }
217            }
218        }
219
220        Aabb {
221            min: new_min,
222            max: new_max,
223        }
224    }
225}
226
227impl fmt::Display for Aabb {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        let p = f.precision();
230        match p {
231            Some(p) => write!(
232                f,
233                "Aabb(({:.p$}, {:.p$}, {:.p$})..({:.p$}, {:.p$}, {:.p$}))",
234                self.min.x, self.min.y, self.min.z, self.max.x, self.max.y, self.max.z
235            ),
236            None => write!(
237                f,
238                "Aabb(({}, {}, {})..({}, {}, {}))",
239                self.min.x, self.min.y, self.min.z, self.max.x, self.max.y, self.max.z
240            ),
241        }
242    }
243}
244
245/// A sphere defined by a center and radius.
246#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
247pub struct Sphere {
248    pub center: Vec3,
249    pub radius: f32,
250}
251
252impl Sphere {
253    /// Create a new sphere.
254    ///
255    /// # Errors
256    /// Returns [`crate::HisabError::InvalidInput`] if `radius` is negative.
257    #[inline]
258    pub fn new(center: Vec3, radius: f32) -> Result<Self, crate::HisabError> {
259        if radius < 0.0 {
260            return Err(crate::HisabError::InvalidInput(
261                "sphere radius must be non-negative".into(),
262            ));
263        }
264        Ok(Self { center, radius })
265    }
266
267    /// Check whether a point is inside (or on the surface of) this sphere.
268    #[must_use]
269    #[inline]
270    pub fn contains_point(&self, point: Vec3) -> bool {
271        (point - self.center).length_squared() <= self.radius * self.radius
272    }
273}
274
275impl fmt::Display for Sphere {
276    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277        let p = f.precision();
278        let c = self.center;
279        match p {
280            Some(p) => write!(
281                f,
282                "Sphere(({:.p$}, {:.p$}, {:.p$}), r={:.p$})",
283                c.x, c.y, c.z, self.radius
284            ),
285            None => write!(f, "Sphere(({}, {}, {}), r={})", c.x, c.y, c.z, self.radius),
286        }
287    }
288}
289
290/// Ray-plane intersection. Returns the `t` parameter if the ray hits the plane
291/// (only `t >= 0`, i.e. forward hits).
292#[must_use]
293#[inline]
294pub fn ray_plane(ray: &Ray, plane: &Plane) -> Option<f32> {
295    let denom = plane.normal.dot(ray.direction);
296    if denom.abs() < crate::EPSILON_F32 {
297        return None; // Ray parallel to plane
298    }
299    let t = (plane.distance - plane.normal.dot(ray.origin)) / denom;
300    if t >= 0.0 { Some(t) } else { None }
301}
302
303/// Ray-sphere intersection using the quadratic formula.
304/// Returns the nearest `t >= 0` if the ray hits the sphere.
305///
306/// Assumes `ray.direction` is normalized (guaranteed by `Ray::new`),
307/// so the quadratic coefficient `a = 1` and is eliminated.
308#[must_use]
309#[inline]
310pub fn ray_sphere(ray: &Ray, sphere: &Sphere) -> Option<f32> {
311    let oc = ray.origin - sphere.center;
312    // With normalized direction: a=1, so b=2*dot(oc,d), c=dot(oc,oc)-r²
313    // Use half-b form: half_b = dot(oc, d), discriminant = half_b² - c
314    let half_b = oc.dot(ray.direction);
315    let c = oc.dot(oc) - sphere.radius * sphere.radius;
316    let discriminant = half_b * half_b - c;
317
318    if discriminant < 0.0 {
319        return None;
320    }
321
322    let sqrt_d = discriminant.sqrt();
323    let t1 = -half_b - sqrt_d;
324    let t2 = -half_b + sqrt_d;
325
326    if t1 >= 0.0 {
327        Some(t1)
328    } else if t2 >= 0.0 {
329        Some(t2)
330    } else {
331        None
332    }
333}
334
335/// Ray-AABB intersection using the slab method.
336/// Returns the nearest `t >= 0` if the ray hits the AABB.
337#[must_use]
338#[inline]
339pub fn ray_aabb(ray: &Ray, aabb: &Aabb) -> Option<f32> {
340    let origin = ray.origin.to_array();
341    let dir = ray.direction.to_array();
342    let bb_min = aabb.min.to_array();
343    let bb_max = aabb.max.to_array();
344
345    let mut t_min = f32::NEG_INFINITY;
346    let mut t_max = f32::INFINITY;
347
348    for i in 0..3 {
349        if dir[i].abs() < crate::EPSILON_F32 {
350            if origin[i] < bb_min[i] || origin[i] > bb_max[i] {
351                return None;
352            }
353        } else {
354            let inv_d = 1.0 / dir[i];
355            let mut t1 = (bb_min[i] - origin[i]) * inv_d;
356            let mut t2 = (bb_max[i] - origin[i]) * inv_d;
357            if t1 > t2 {
358                std::mem::swap(&mut t1, &mut t2);
359            }
360            t_min = t_min.max(t1);
361            t_max = t_max.min(t2);
362            if t_min > t_max {
363                return None;
364            }
365        }
366    }
367
368    if t_min >= 0.0 {
369        Some(t_min)
370    } else if t_max >= 0.0 {
371        Some(t_max)
372    } else {
373        None
374    }
375}
376
377// ---------------------------------------------------------------------------
378
379// OBB (Oriented Bounding Box)
380// ---------------------------------------------------------------------------
381
382/// An oriented bounding box defined by a center, half-extents, and rotation.
383#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
384pub struct Obb {
385    /// Center of the OBB.
386    pub center: Vec3,
387    /// Half-extents along each local axis.
388    pub half_extents: Vec3,
389    /// Rotation quaternion (local → world).
390    pub rotation: glam::Quat,
391}
392
393impl Obb {
394    /// Create a new OBB.
395    #[must_use]
396    #[inline]
397    pub fn new(center: Vec3, half_extents: Vec3, rotation: glam::Quat) -> Self {
398        Self {
399            center,
400            half_extents,
401            rotation,
402        }
403    }
404
405    /// The three local axes (columns of the rotation matrix) in world space.
406    #[must_use]
407    #[inline]
408    pub fn axes(&self) -> [Vec3; 3] {
409        let m = glam::Mat3::from_quat(self.rotation);
410        [m.x_axis, m.y_axis, m.z_axis]
411    }
412
413    /// Check whether a point is inside (or on the surface of) this OBB.
414    #[must_use]
415    #[inline]
416    pub fn contains_point(&self, point: Vec3) -> bool {
417        let d = point - self.center;
418        let axes = self.axes();
419        let he = self.half_extents.to_array();
420        for (i, axis) in axes.iter().enumerate() {
421            if d.dot(*axis).abs() > he[i] + crate::EPSILON_F32 {
422                return false;
423            }
424        }
425        true
426    }
427
428    /// Closest point on this OBB to a given point.
429    #[must_use]
430    #[inline]
431    pub fn closest_point(&self, point: Vec3) -> Vec3 {
432        let d = point - self.center;
433        let axes = self.axes();
434        let he = self.half_extents.to_array();
435        let mut result = self.center;
436        for (i, axis) in axes.iter().enumerate() {
437            let dist = d.dot(*axis).clamp(-he[i], he[i]);
438            result += *axis * dist;
439        }
440        result
441    }
442}
443
444/// Ray-OBB intersection. Returns the `t` parameter if the ray hits the OBB.
445#[must_use]
446#[inline]
447pub fn ray_obb(ray: &Ray, obb: &Obb) -> Option<f32> {
448    let d = obb.center - ray.origin;
449    let axes = obb.axes();
450    let he = obb.half_extents.to_array();
451
452    let mut t_min = f32::NEG_INFINITY;
453    let mut t_max = f32::INFINITY;
454
455    for i in 0..3 {
456        let e = axes[i].dot(d);
457        let f = axes[i].dot(ray.direction);
458
459        if f.abs() > crate::EPSILON_F32 {
460            let inv_f = 1.0 / f;
461            let mut t1 = (e - he[i]) * inv_f;
462            let mut t2 = (e + he[i]) * inv_f;
463            if t1 > t2 {
464                std::mem::swap(&mut t1, &mut t2);
465            }
466            t_min = t_min.max(t1);
467            t_max = t_max.min(t2);
468            if t_min > t_max {
469                return None;
470            }
471        } else if (-e - he[i]) > 0.0 || (-e + he[i]) < 0.0 {
472            return None;
473        }
474    }
475
476    if t_min >= 0.0 {
477        Some(t_min)
478    } else if t_max >= 0.0 {
479        Some(t_max)
480    } else {
481        None
482    }
483}
484
485// ---------------------------------------------------------------------------
486// Capsule
487// ---------------------------------------------------------------------------
488
489/// A capsule defined by a line segment and a radius.
490///
491/// The capsule is the Minkowski sum of the segment and a sphere.
492#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
493pub struct Capsule {
494    /// Start point of the capsule's axis.
495    pub start: Vec3,
496    /// End point of the capsule's axis.
497    pub end: Vec3,
498    /// Radius of the capsule.
499    pub radius: f32,
500}
501
502impl Capsule {
503    /// Create a new capsule.
504    ///
505    /// # Errors
506    ///
507    /// Returns [`crate::HisabError::InvalidInput`] if `radius` is negative.
508    #[inline]
509    pub fn new(start: Vec3, end: Vec3, radius: f32) -> Result<Self, crate::HisabError> {
510        if radius < 0.0 {
511            return Err(crate::HisabError::InvalidInput(
512                "capsule radius must be non-negative".into(),
513            ));
514        }
515        Ok(Self { start, end, radius })
516    }
517
518    /// Check whether a point is inside the capsule.
519    #[must_use]
520    #[inline]
521    pub fn contains_point(&self, point: Vec3) -> bool {
522        let seg = Segment::new(self.start, self.end);
523        seg.distance_to_point(point) <= self.radius + crate::EPSILON_F32
524    }
525
526    /// Length of the capsule's axis (not including the hemispherical caps).
527    #[must_use]
528    #[inline]
529    pub fn axis_length(&self) -> f32 {
530        (self.end - self.start).length()
531    }
532}
533
534/// Ray-capsule intersection. Returns the nearest `t >= 0` if the ray hits.
535#[must_use]
536pub fn ray_capsule(ray: &Ray, capsule: &Capsule) -> Option<f32> {
537    // Test against the infinite cylinder, then clamp to segment + check hemispheres
538    let ab = capsule.end - capsule.start;
539    let ab_len_sq = ab.dot(ab);
540
541    if ab_len_sq < crate::EPSILON_F32 {
542        // Degenerate capsule: just a sphere
543        let sphere = Sphere {
544            center: capsule.start,
545            radius: capsule.radius,
546        };
547        return ray_sphere(ray, &sphere);
548    }
549
550    // Closest approach of ray to segment axis
551    let ao = ray.origin - capsule.start;
552    let d_par = ray.direction.dot(ab) / ab_len_sq;
553    let o_par = ao.dot(ab) / ab_len_sq;
554
555    let d_perp = ray.direction - ab * d_par;
556    let o_perp = ao - ab * o_par;
557
558    let a = d_perp.dot(d_perp);
559    let b = 2.0 * d_perp.dot(o_perp);
560    let c = o_perp.dot(o_perp) - capsule.radius * capsule.radius;
561
562    let disc = b * b - 4.0 * a * c;
563    if disc < 0.0 {
564        // Try sphere caps
565        let s1 = Sphere {
566            center: capsule.start,
567            radius: capsule.radius,
568        };
569        let s2 = Sphere {
570            center: capsule.end,
571            radius: capsule.radius,
572        };
573        let t1 = ray_sphere(ray, &s1);
574        let t2 = ray_sphere(ray, &s2);
575        return match (t1, t2) {
576            (Some(a), Some(b)) => Some(a.min(b)),
577            (Some(a), None) | (None, Some(a)) => Some(a),
578            _ => None,
579        };
580    }
581
582    let inv_2a = 0.5 / a;
583    let sqrt_disc = disc.sqrt();
584    let t1 = (-b - sqrt_disc) * inv_2a;
585    let t2 = (-b + sqrt_disc) * inv_2a;
586
587    let mut best: Option<f32> = None;
588    let mut check = |t: f32| {
589        if t >= 0.0 {
590            let p = ray.at(t);
591            let proj = (p - capsule.start).dot(ab) / ab_len_sq;
592            if (0.0..=1.0).contains(&proj) {
593                best = Some(best.map_or(t, |b: f32| b.min(t)));
594            }
595        }
596    };
597    check(t1);
598    check(t2);
599
600    // Also check hemisphere caps
601    let s1 = Sphere {
602        center: capsule.start,
603        radius: capsule.radius,
604    };
605    let s2 = Sphere {
606        center: capsule.end,
607        radius: capsule.radius,
608    };
609    if let Some(t) = ray_sphere(ray, &s1) {
610        let p = ray.at(t);
611        if (p - capsule.start).dot(ab) <= 0.0 {
612            best = Some(best.map_or(t, |b: f32| b.min(t)));
613        }
614    }
615    if let Some(t) = ray_sphere(ray, &s2) {
616        let p = ray.at(t);
617        if (p - capsule.end).dot(ab) >= 0.0 {
618            best = Some(best.map_or(t, |b: f32| b.min(t)));
619        }
620    }
621
622    best
623}
624
625// ---------------------------------------------------------------------------
626
#[cfg(test)]
mod tests {
    use super::*;
    use glam::{Mat4, Vec3};

    /// Absolute tolerance used by the float comparisons below.
    const EPS: f32 = 1e-5;

    /// True when `a` and `b` are closer than `EPS` (Euclidean distance).
    fn approx_vec3(a: Vec3, b: Vec3) -> bool {
        (a - b).length() < EPS
    }

    // --- Aabb::transformed tests --------------------------------------------

    #[test]
    fn transformed_identity_unchanged() {
        let aabb = Aabb::new(Vec3::new(-1.0, -2.0, -3.0), Vec3::new(1.0, 2.0, 3.0));
        let result = aabb.transformed(Mat4::IDENTITY);
        assert!(approx_vec3(result.min, aabb.min));
        assert!(approx_vec3(result.max, aabb.max));
    }

    #[test]
    fn transformed_translation_only() {
        let aabb = Aabb::new(Vec3::new(-1.0, -1.0, -1.0), Vec3::new(1.0, 1.0, 1.0));
        let t = Mat4::from_translation(Vec3::new(3.0, 5.0, -2.0));
        let result = aabb.transformed(t);
        assert!(approx_vec3(result.min, Vec3::new(2.0, 4.0, -3.0)));
        assert!(approx_vec3(result.max, Vec3::new(4.0, 6.0, -1.0)));
    }

    #[test]
    fn transformed_uniform_scale() {
        let aabb = Aabb::new(Vec3::new(-1.0, -1.0, -1.0), Vec3::new(1.0, 1.0, 1.0));
        let s = Mat4::from_scale(Vec3::splat(2.0));
        let result = aabb.transformed(s);
        assert!(approx_vec3(result.min, Vec3::splat(-2.0)));
        assert!(approx_vec3(result.max, Vec3::splat(2.0)));
    }

    #[test]
    fn transformed_90_deg_rotation() {
        // Rotate 90° around Z: (x,y,z) → (-y, x, z)
        // An AABB [0,1]×[0,1]×[0,1] should become [-1,0]×[0,1]×[0,1]
        let aabb = Aabb::new(Vec3::new(0.0, 0.0, 0.0), Vec3::new(1.0, 1.0, 1.0));
        let r = Mat4::from_rotation_z(std::f32::consts::FRAC_PI_2);
        let result = aabb.transformed(r);
        // new_min.x = -1, new_max.x = 0 (within rounding)
        assert!(
            (result.min.x - (-1.0)).abs() < 1e-5,
            "min.x = {}",
            result.min.x
        );
        assert!(
            (result.max.x - 0.0).abs() < 1e-5,
            "max.x = {}",
            result.max.x
        );
        assert!(
            (result.min.y - 0.0).abs() < 1e-5,
            "min.y = {}",
            result.min.y
        );
        assert!(
            (result.max.y - 1.0).abs() < 1e-5,
            "max.y = {}",
            result.max.y
        );
        // The rotation axis is Z, so the z range must be left untouched.
        assert!(
            (result.min.z - 0.0).abs() < 1e-5,
            "min.z = {}",
            result.min.z
        );
        assert!(
            (result.max.z - 1.0).abs() < 1e-5,
            "max.z = {}",
            result.max.z
        );
    }

    #[test]
    fn transformed_result_min_le_max() {
        // Regardless of the transform, min should never exceed max.
        let aabb = Aabb::new(Vec3::new(-2.0, -3.0, -1.0), Vec3::new(2.0, 3.0, 1.0));
        // Column-major: negative scale on x plus shear and a translation.
        let m = Mat4::from_cols_array(&[
            -1.0, 0.5, 0.0, 0.0, 0.3, 2.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 2.0, 3.0, 1.0,
        ]);
        let result = aabb.transformed(m);
        assert!(result.min.x <= result.max.x);
        assert!(result.min.y <= result.max.y);
        assert!(result.min.z <= result.max.z);
    }
}