ray_tracer/hitable/primitive/
cube.rs

1use crate::float::Float;
2use crate::vector::Vec3;
3use crate::ray::Ray;
4use crate::hit::Hit;
5use crate::hitable::Hitable;
6use crate::hitable::primitive::Rectangle;
7use crate::hitable::transform::Translation;
8use crate::hitable::primitive::Group;
9use crate::boundingbox::BoundingBox;
10use crate::constants::Axis;
11use crate::utils::axis_to_index;
12
13pub struct Cube<T>
14    where T: Float
15{
16    length: T,
17    width: T,
18    height: T,
19    faces: Group<T>
20}
21
22impl<T> Cube<T>
23    where T: Float
24{
25    pub fn new(length: T, width: T, height: T) -> Self {
26        let axes = [
27            (Axis::X, Axis::Y),
28            (Axis::Y, Axis::X),
29            (Axis::Y, Axis::Z),
30            (Axis::Z, Axis::Y),
31            (Axis::Z, Axis::X),
32            (Axis::X, Axis::Z)
33        ];
34
35        let lengths = [
36            (length, width, height),
37            (width, length, height),
38            (width, height, length),
39            (height, width, length),
40            (height, length, width),
41            (length, height, width)
42        ];
43
44        let half = T::from(0.5).unwrap();
45        let mut faces = Group::<T>::new();
46
47        for i in 0..6 {
48            let (width, height, depth) = lengths[i];
49            let (width_axis, height_axis) = axes[i];
50            let face = Box::new(Rectangle::<T>::new(width, width_axis, height, height_axis));
51            let translation = face.get_normal() * depth * half;
52            let face : Box<Hitable<T>> = Box::new(Translation::new(face, translation));
53            faces.add_hitable(face);
54        };
55
56        Cube {
57            length,
58            width,
59            height,
60            faces
61        }
62    }
63}
64
65impl<T> Hitable<T> for Cube<T>
66    where T: Float
67{
68    fn hit(&self, ray: &Ray<T>, t_min: T, t_max: T) -> Option<Hit<T>> {
69        self.faces.hit(ray, t_min, t_max)
70    }
71
72    fn get_bounds(&self) -> &BoundingBox<T> {
73        self.faces.get_bounds()
74    }
75
76    fn unwrap(self: Box<Self>) -> Box<dyn Hitable<T>> {
77        self
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Casts a ray from `origin` along `direction` at `cube` and asserts that it
    /// hits at `point`, with surface normal `normal`, at ray parameter `t`.
    fn assert_hit(
        cube: &Cube<f64>,
        origin: [f64; 3],
        direction: [f64; 3],
        point: [f64; 3],
        normal: [f64; 3],
        t: f64,
    ) {
        let ray = Ray::from_array(origin, direction);
        let hit = cube.hit(&ray, 0.0, 100.0).unwrap_or_else(|| {
            panic!("ray from {:?} along {:?} should hit the cube", origin, direction)
        });
        assert_eq!(hit.point.get_data(), point, "hit point for ray from {:?}", origin);
        assert_eq!(hit.normal.get_data(), normal, "normal for ray from {:?}", origin);
        assert_eq!(hit.t, t, "t for ray from {:?}", origin);
    }

    /// Casts a ray from `origin` along `direction` and asserts that it misses `cube`.
    fn assert_miss(cube: &Cube<f64>, origin: [f64; 3], direction: [f64; 3]) {
        let ray = Ray::from_array(origin, direction);
        assert!(
            cube.hit(&ray, 0.0, 100.0).is_none(),
            "ray from {:?} along {:?} should miss the cube",
            origin,
            direction
        );
    }

    #[test]
    fn init() {
        let length = 2.0;
        let width = 2.0;
        let height = 4.0;
        let _cube = Cube::<f64>::new(length, width, height);
    }

    #[test]
    fn hit() {
        let length = 2.0;
        let width = 4.0;
        let height = 6.0;
        let cube = Cube::<f64>::new(length, width, height);

        // (origin, direction, expected point, expected normal, expected t)
        let cases = [
            // YZ faces
            ([-8.0, 0.0, 0.0], [2.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], 3.5),
            ([8.0, 0.0, 0.0], [-2.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], 3.5),
            // XZ faces
            ([0.0, -8.0, 0.0], [0.0, 2.0, 0.0], [0.0, -2.0, 0.0], [0.0, -1.0, 0.0], 3.0),
            ([0.0, 8.0, 0.0], [0.0, -2.0, 0.0], [0.0, 2.0, 0.0], [0.0, 1.0, 0.0], 3.0),
            // XY faces
            ([0.0, 0.0, 8.0], [0.0, 0.0, -2.0], [0.0, 0.0, 3.0], [0.0, 0.0, 1.0], 2.5),
            ([0.0, 0.0, -8.0], [0.0, 0.0, 2.0], [0.0, 0.0, -3.0], [0.0, 0.0, -1.0], 2.5),
        ];
        for &(origin, direction, point, normal, t) in cases.iter() {
            assert_hit(&cube, origin, direction, point, normal, t);
        }

        // Ray pointing away from the cube.
        assert_miss(&cube, [-8.0, 0.0, 0.0], [-2.0, 0.0, 0.0]);
        // Ray passing just outside the edge shared by the y = 2 and z = 3 faces.
        assert_miss(&cube, [-8.0, 2.001, 3.001], [2.0, 0.0, 0.0]);
    }

    #[test]
    fn bounds() {
        let length = 2.0;
        let width = 4.0;
        let height = 6.0;
        let cube = Cube::<f64>::new(length, width, height);
        let bounds = cube.get_bounds();
        assert_eq!(bounds.get_p0().get_data(), [-1.0, -2.0, -3.0]);
        assert_eq!(bounds.get_p1().get_data(), [1.0, 2.0, 3.0]);
    }
}