ray_tracer/hitable/primitive/sphere.rs

1use crate::float::Float;
2use crate::vector::Vec3;
3use crate::ray::Ray;
4use crate::hit::Hit;
5use crate::hitable::Hitable;
6use crate::boundingbox::BoundingBox;
7
8pub struct Sphere<T>
9    where T: Float
10{
11    radius: T,
12    bounds: BoundingBox<T>
13}
14
15impl<T> Sphere<T>
16    where T: Float
17{
18    pub fn new(radius: T) -> Self {
19        let mut sphere = Sphere {
20            radius,
21            bounds: BoundingBox::<T>::new(Vec3::<T>::new(), Vec3::<T>::new())
22        };
23        sphere.update_bounds();
24        sphere
25    }
26
27    pub fn get_radius(&self) -> T {
28        self.radius
29    }
30
31    pub fn set_radius(&mut self, radius: T) {
32        self.radius = radius;
33        self.update_bounds();
34    }
35
36    fn update_bounds(&mut self) {
37        let one = Vec3::<T>::from_array([T::one(), T::one(), T::one()]);
38        let p0 = &one * self.get_radius() * (- T::one());
39        let p1 = &one * self.get_radius();
40        self.bounds = BoundingBox::<T>::new(p0, p1);
41    }
42}
43
44impl<T> Hitable<T> for Sphere<T>
45    where T: Float
46{
47    fn hit(&self, ray: &Ray<T>, t_min: T, t_max: T) -> Option<Hit<T>> {
48        // Intersection of a line and a sphere:
49        //
50        // p(t) = origin + t * direction
51        //
52        //   dot( (p(t) - center), (p(t) - center) )
53        // = radius * radius
54        //
55        //   t * t * dot(direction, direction)
56        // + 2 * t * dot(direction, origin - center)
57        // + dot(origin - center, origin - center)
58        // - radius * radius
59        // = 0
60        //
61        // t = (- b +/- sqrt(b * b - 4 * a * c)) / (2 * a)
62        // 
63        // drop 2s coming from b
64        let oc = ray.get_origin();
65        let a = ray.get_direction().dot(ray.get_direction());
66        let b = ray.get_direction().dot(&oc);
67        let c = oc.dot(&oc) - self.get_radius() * self.get_radius();
68        let discriminant = b * b - a * c;
69        if discriminant <= T::zero() {
70            return None;
71        }
72        let discriminant = discriminant.sqrt();
73        let t0 = (- b - discriminant) / a;
74        let t1 = (- b + discriminant) / a;
75        let t = if t0 >= t_min && t0 < t_max { t0 }
76                else if t1 >= t_min && t1 < t_max { t1 }
77                else { return None; };
78
79        let point = ray.get_point(t);
80        let normal = (&point) / self.get_radius();
81        let hit = Hit {
82            point,
83            normal,
84            t
85        };
86
87        Some(hit)
88    }
89
90    fn get_bounds(&self) -> &BoundingBox<T> {
91        &self.bounds
92    }
93
94    fn unwrap(self: Box<Self>) -> Box<dyn Hitable<T>> {
95        self
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn init() {
        let sphere = Sphere::<f64>::new(1.0);
        assert_eq!(sphere.get_radius(), 1.0);

        let radius = 5.5;
        let sphere = Sphere::<f64>::new(radius);
        assert_eq!(sphere.get_radius(), 5.5);
    }

    #[test]
    fn set() {
        let mut sphere = Sphere::<f64>::new(1.0);
        sphere.set_radius(5.5);
        assert_eq!(sphere.get_radius(), 5.5);
    }

    #[test]
    fn hit() {
        let sphere = Sphere::<f64>::new(2.0);

        // Ray aimed straight at the centre along +x: enters at x = -2 when t = 3.
        let ray = Ray::from_array([-8.0, 0.0, 0.0], [2.0, 0.0, 0.0]);
        let hit = sphere
            .hit(&ray, 0.0, 100.0)
            .expect("a ray aimed at the centre must hit the sphere");
        assert_eq!(hit.point.get_data(), [-2.0, 0.0, 0.0]);
        assert_eq!(hit.normal.get_data(), [-1.0, 0.0, 0.0]);
        assert_eq!(hit.t, 3.0);

        // Parallel ray offset by 2.1 (> radius) passes beside the sphere.
        let ray = Ray::from_array([-8.0, 2.1, 0.0], [2.0, 0.0, 0.0]);
        assert!(sphere.hit(&ray, 0.0, 100.0).is_none());

        // Ray travelling along +y from (-8, 0, 0) never reaches the sphere.
        let ray = Ray::from_array([-8.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        assert!(sphere.hit(&ray, 0.0, 100.0).is_none());
    }

    #[test]
    fn hit_respects_interval() {
        let sphere = Sphere::<f64>::new(2.0);

        // Origin inside the sphere: the near root (t = -2) is below t_min,
        // so the far root (t = 2, the exit point) is reported.
        let ray = Ray::from_array([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        let hit = sphere
            .hit(&ray, 0.0, 100.0)
            .expect("a ray starting inside the sphere must hit its far side");
        assert_eq!(hit.point.get_data(), [2.0, 0.0, 0.0]);
        assert_eq!(hit.normal.get_data(), [1.0, 0.0, 0.0]);
        assert_eq!(hit.t, 2.0);

        // Roots are t = 3 and t = 5; the upper bound is exclusive, so
        // t_max = 3 rejects both.
        let ray = Ray::from_array([-8.0, 0.0, 0.0], [2.0, 0.0, 0.0]);
        assert!(sphere.hit(&ray, 0.0, 3.0).is_none());
    }

    #[test]
    fn bounds() {
        let radius = 2.5;
        let sphere = Sphere::<f64>::new(radius);
        let bounds = sphere.get_bounds();
        assert_eq!(bounds.get_p0().get_data(), [-2.5, -2.5, -2.5]);
        assert_eq!(bounds.get_p1().get_data(), [2.5, 2.5, 2.5]);
    }
}
172}