ray_tracer/hitable/primitive/
sphere.rs1use crate::float::Float;
2use crate::vector::Vec3;
3use crate::ray::Ray;
4use crate::hit::Hit;
5use crate::hitable::Hitable;
6use crate::boundingbox::BoundingBox;
7
8pub struct Sphere<T>
9 where T: Float
10{
11 radius: T,
12 bounds: BoundingBox<T>
13}
14
15impl<T> Sphere<T>
16 where T: Float
17{
18 pub fn new(radius: T) -> Self {
19 let mut sphere = Sphere {
20 radius,
21 bounds: BoundingBox::<T>::new(Vec3::<T>::new(), Vec3::<T>::new())
22 };
23 sphere.update_bounds();
24 sphere
25 }
26
27 pub fn get_radius(&self) -> T {
28 self.radius
29 }
30
31 pub fn set_radius(&mut self, radius: T) {
32 self.radius = radius;
33 self.update_bounds();
34 }
35
36 fn update_bounds(&mut self) {
37 let one = Vec3::<T>::from_array([T::one(), T::one(), T::one()]);
38 let p0 = &one * self.get_radius() * (- T::one());
39 let p1 = &one * self.get_radius();
40 self.bounds = BoundingBox::<T>::new(p0, p1);
41 }
42}
43
44impl<T> Hitable<T> for Sphere<T>
45 where T: Float
46{
47 fn hit(&self, ray: &Ray<T>, t_min: T, t_max: T) -> Option<Hit<T>> {
48 let oc = ray.get_origin();
65 let a = ray.get_direction().dot(ray.get_direction());
66 let b = ray.get_direction().dot(&oc);
67 let c = oc.dot(&oc) - self.get_radius() * self.get_radius();
68 let discriminant = b * b - a * c;
69 if discriminant <= T::zero() {
70 return None;
71 }
72 let discriminant = discriminant.sqrt();
73 let t0 = (- b - discriminant) / a;
74 let t1 = (- b + discriminant) / a;
75 let t = if t0 >= t_min && t0 < t_max { t0 }
76 else if t1 >= t_min && t1 < t_max { t1 }
77 else { return None; };
78
79 let point = ray.get_point(t);
80 let normal = (&point) / self.get_radius();
81 let hit = Hit {
82 point,
83 normal,
84 t
85 };
86
87 Some(hit)
88 }
89
90 fn get_bounds(&self) -> &BoundingBox<T> {
91 &self.bounds
92 }
93
94 fn unwrap(self: Box<Self>) -> Box<dyn Hitable<T>> {
95 self
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn init() {
        let sphere = Sphere::<f64>::new(1.0);
        assert_eq!(sphere.get_radius(), 1.0);

        let radius = 5.5;
        let sphere = Sphere::<f64>::new(radius);
        assert_eq!(sphere.get_radius(), 5.5);
    }

    #[test]
    fn set() {
        let mut sphere = Sphere::<f64>::new(1.0);
        let radius = 5.5;
        sphere.set_radius(radius);
        assert_eq!(sphere.get_radius(), 5.5);
    }

    #[test]
    fn hit() {
        let sphere = Sphere::<f64>::new(2.0);

        // Head-on ray from outside: the near intersection is reported.
        let ray = Ray::from_array([-8.0, 0.0, 0.0], [2.0, 0.0, 0.0]);
        let hit = sphere
            .hit(&ray, 0.0, 100.0)
            .expect("ray aimed at the centre must hit");
        assert_eq!(hit.point.get_data(), [-2.0, 0.0, 0.0]);
        assert_eq!(hit.normal.get_data(), [-1.0, 0.0, 0.0]);
        assert_eq!(hit.t, 3.0);

        // Parallel ray passing just above the sphere misses.
        let ray = Ray::from_array([-8.0, 2.1, 0.0], [2.0, 0.0, 0.0]);
        assert!(sphere.hit(&ray, 0.0, 100.0).is_none());

        // Ray travelling perpendicular to the centre line passes beside the sphere.
        let ray = Ray::from_array([-8.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        assert!(sphere.hit(&ray, 0.0, 100.0).is_none());
    }

    #[test]
    fn hit_from_inside() {
        // From the centre the near root is at t = -2 (< t_min), so the far
        // root at t = 2 must be chosen.
        let sphere = Sphere::<f64>::new(2.0);
        let ray = Ray::from_array([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        let hit = sphere
            .hit(&ray, 0.0, 100.0)
            .expect("ray starting at the centre must exit the sphere");
        assert_eq!(hit.point.get_data(), [2.0, 0.0, 0.0]);
        assert_eq!(hit.normal.get_data(), [1.0, 0.0, 0.0]);
        assert_eq!(hit.t, 2.0);
    }

    #[test]
    fn hit_respects_t_max() {
        // Both intersections (t = 3 and t = 5) lie beyond t_max.
        let sphere = Sphere::<f64>::new(2.0);
        let ray = Ray::from_array([-8.0, 0.0, 0.0], [2.0, 0.0, 0.0]);
        assert!(sphere.hit(&ray, 0.0, 2.5).is_none());
    }

    #[test]
    fn bounds() {
        let radius = 2.5;
        let sphere = Sphere::<f64>::new(radius);
        let bounds = sphere.get_bounds();
        assert_eq!(bounds.get_p0().get_data(), [-2.5, -2.5, -2.5]);
        assert_eq!(bounds.get_p1().get_data(), [2.5, 2.5, 2.5]);
    }
}