1use crate::ray::{HitRecord, Ray};
4use crate::MaterialId;
5use nalgebra::{Point3, Unit, Vector3};
6
7#[derive(Debug, Clone)]
9pub enum Primitive {
10 Plane {
12 point: Point3<f64>,
13 normal: Unit<Vector3<f64>>,
14 },
15
16 AaBox { min: Point3<f64>, max: Point3<f64> },
18
19 Cylinder {
21 center: Point3<f64>,
22 axis: Unit<Vector3<f64>>,
23 radius: f64,
24 half_height: f64,
25 capped: bool,
26 },
27
28 Sheet {
30 center: Point3<f64>,
31 normal: Unit<Vector3<f64>>,
32 u_axis: Unit<Vector3<f64>>,
33 half_width: f64,
34 half_height: f64,
35 thickness: f64,
36 },
37}
38
39#[derive(Debug, Clone)]
41pub struct SceneObject {
42 pub primitive: Primitive,
43 pub material: MaterialId,
44 pub label: String,
45}
46
47impl Primitive {
48 pub fn intersect(
51 &self,
52 ray: &Ray,
53 t_min: f64,
54 t_max: f64,
55 material_id: MaterialId,
56 ) -> Option<HitRecord> {
57 match self {
58 Primitive::Plane { point, normal } => {
59 intersect_plane(ray, point, normal, t_min, t_max, material_id)
60 }
61 Primitive::AaBox { min, max } => {
62 intersect_aa_box(ray, min, max, t_min, t_max, material_id)
63 }
64 Primitive::Cylinder {
65 center,
66 axis,
67 radius,
68 half_height,
69 capped,
70 } => intersect_cylinder(
71 ray,
72 center,
73 axis,
74 *radius,
75 *half_height,
76 *capped,
77 t_min,
78 t_max,
79 material_id,
80 ),
81 Primitive::Sheet {
82 center,
83 normal,
84 u_axis,
85 half_width,
86 half_height,
87 ..
88 } => intersect_sheet(
89 ray,
90 center,
91 normal,
92 u_axis,
93 *half_width,
94 *half_height,
95 t_min,
96 t_max,
97 material_id,
98 ),
99 }
100 }
101}
102
103fn intersect_plane(
108 ray: &Ray,
109 point: &Point3<f64>,
110 normal: &Unit<Vector3<f64>>,
111 t_min: f64,
112 t_max: f64,
113 material_id: MaterialId,
114) -> Option<HitRecord> {
115 let denom = ray.direction.dot(normal.as_ref());
116 if denom.abs() < 1e-10 {
117 return None; }
119 let t = (point - ray.origin).dot(normal.as_ref()) / denom;
120 if t < t_min || t > t_max {
121 return None;
122 }
123 let mut hit = HitRecord {
124 t,
125 point: ray.at(t),
126 normal: *normal,
127 front_face: true,
128 material: material_id,
129 };
130 hit.set_face_normal(ray, *normal);
131 Some(hit)
132}
133
134fn intersect_aa_box(
139 ray: &Ray,
140 min: &Point3<f64>,
141 max: &Point3<f64>,
142 t_min: f64,
143 t_max: f64,
144 material_id: MaterialId,
145) -> Option<HitRecord> {
146 let mut tmin = t_min;
147 let mut tmax = t_max;
148 let mut hit_axis = 0usize;
149
150 for i in 0..3 {
151 let inv_d = 1.0 / ray.direction[i];
152 let mut t0 = (min[i] - ray.origin[i]) * inv_d;
153 let mut t1 = (max[i] - ray.origin[i]) * inv_d;
154 if inv_d < 0.0 {
155 std::mem::swap(&mut t0, &mut t1);
156 }
157 if t0 > tmin {
158 tmin = t0;
159 hit_axis = i;
160 }
161 if t1 < tmax {
162 tmax = t1;
163 }
164 if tmax < tmin {
165 return None;
166 }
167 }
168
169 let t = tmin;
170 let point = ray.at(t);
171 let mut normal = Vector3::zeros();
172 normal[hit_axis] = if ray.direction[hit_axis] < 0.0 {
173 1.0
174 } else {
175 -1.0
176 };
177 let outward_normal = Unit::new_normalize(normal);
178
179 let mut hit = HitRecord {
180 t,
181 point,
182 normal: outward_normal,
183 front_face: true,
184 material: material_id,
185 };
186 hit.set_face_normal(ray, outward_normal);
187 Some(hit)
188}
189
190#[allow(clippy::too_many_arguments)]
195fn intersect_cylinder(
196 ray: &Ray,
197 center: &Point3<f64>,
198 axis: &Unit<Vector3<f64>>,
199 radius: f64,
200 half_height: f64,
201 capped: bool,
202 t_min: f64,
203 t_max: f64,
204 material_id: MaterialId,
205) -> Option<HitRecord> {
206 let oc = ray.origin - center;
207 let d = ray.direction.as_ref();
208 let a_vec = axis.as_ref();
209
210 let d_perp = d - d.dot(a_vec) * a_vec;
212 let oc_perp = oc - oc.dot(a_vec) * a_vec;
213
214 let a = d_perp.dot(&d_perp);
215 let b = 2.0 * d_perp.dot(&oc_perp);
216 let c = oc_perp.dot(&oc_perp) - radius * radius;
217
218 let discriminant = b * b - 4.0 * a * c;
219 if discriminant < 0.0 {
220 return None;
221 }
222
223 let sqrt_d = discriminant.sqrt();
224 let mut best_t = None;
225 let mut best_normal = Vector3::zeros();
226
227 for sign in [-1.0, 1.0] {
229 let t = (-b + sign * sqrt_d) / (2.0 * a);
230 if t < t_min || t > t_max {
231 continue;
232 }
233 let p = ray.at(t);
234 let h = (p - center).dot(a_vec);
235 if h.abs() <= half_height && (best_t.is_none() || t < best_t.unwrap()) {
236 let n = (p - center) - h * a_vec;
237 best_t = Some(t);
238 best_normal = n / radius;
239 }
240 }
241
242 if capped {
244 for sign in [-1.0f64, 1.0] {
245 let cap_center = center + sign * half_height * a_vec;
246 let cap_normal = sign * a_vec;
247 let denom = d.dot(&cap_normal);
248 if denom.abs() < 1e-10 {
249 continue;
250 }
251 let t = (cap_center - ray.origin).dot(&cap_normal) / denom;
252 if t < t_min || t > t_max {
253 continue;
254 }
255 let p = ray.at(t);
256 let diff = p - cap_center;
257 let dist2 = diff.dot(&diff) - diff.dot(&cap_normal).powi(2);
258 if dist2 <= radius * radius && (best_t.is_none() || t < best_t.unwrap()) {
259 best_t = Some(t);
260 best_normal = cap_normal;
261 }
262 }
263 }
264
265 let t = best_t?;
266 let outward_normal = Unit::new_normalize(best_normal);
267 let mut hit = HitRecord {
268 t,
269 point: ray.at(t),
270 normal: outward_normal,
271 front_face: true,
272 material: material_id,
273 };
274 hit.set_face_normal(ray, outward_normal);
275 Some(hit)
276}
277
278#[allow(clippy::too_many_arguments)]
283fn intersect_sheet(
284 ray: &Ray,
285 center: &Point3<f64>,
286 normal: &Unit<Vector3<f64>>,
287 u_axis: &Unit<Vector3<f64>>,
288 half_width: f64,
289 half_height: f64,
290 t_min: f64,
291 t_max: f64,
292 material_id: MaterialId,
293) -> Option<HitRecord> {
294 let denom = ray.direction.dot(normal.as_ref());
295 if denom.abs() < 1e-10 {
296 return None;
297 }
298 let t = (center - ray.origin).dot(normal.as_ref()) / denom;
299 if t < t_min || t > t_max {
300 return None;
301 }
302
303 let point = ray.at(t);
304 let local = point - center;
305 let v_axis = normal.cross(u_axis.as_ref());
306
307 let u = local.dot(u_axis.as_ref());
308 let v = local.dot(&v_axis);
309
310 if u.abs() > half_width || v.abs() > half_height {
311 return None;
312 }
313
314 let mut hit = HitRecord {
315 t,
316 point,
317 normal: *normal,
318 front_face: true,
319 material: material_id,
320 };
321 hit.set_face_normal(ray, *normal);
322 Some(hit)
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ray starting at `origin` and heading straight down the z axis.
    fn ray_towards_negative_z(origin: Point3<f64>) -> Ray {
        Ray::new(origin, Unit::new_normalize(Vector3::new(0.0, 0.0, -1.0)))
    }

    /// Intersects `ray` with the plane z = 0 (normal +z).
    fn hit_ground_plane(ray: &Ray) -> Option<HitRecord> {
        intersect_plane(
            ray,
            &Point3::origin(),
            &Vector3::z_axis(),
            0.001,
            f64::INFINITY,
            0,
        )
    }

    /// Intersects `ray` with a 1x1 sheet centred at the origin in the z = 0
    /// plane, with its u axis along +x.
    fn hit_unit_sheet(ray: &Ray) -> Option<HitRecord> {
        intersect_sheet(
            ray,
            &Point3::origin(),
            &Vector3::z_axis(),
            &Vector3::x_axis(),
            0.5,
            0.5,
            0.001,
            f64::INFINITY,
            0,
        )
    }

    #[test]
    fn ray_hits_plane() {
        let ray = ray_towards_negative_z(Point3::new(0.0, 0.0, 1.0));
        let t = hit_ground_plane(&ray).map(|hit| hit.t);
        assert!(matches!(t, Some(t) if (t - 1.0).abs() < 1e-6));
    }

    #[test]
    fn ray_misses_plane_parallel() {
        let ray = Ray::new(
            Point3::new(0.0, 0.0, 1.0),
            Unit::new_normalize(Vector3::new(1.0, 0.0, 0.0)),
        );
        assert!(hit_ground_plane(&ray).is_none());
    }

    #[test]
    fn ray_hits_sheet() {
        let ray = ray_towards_negative_z(Point3::new(0.0, 0.0, 1.0));
        assert!(hit_unit_sheet(&ray).is_some());
    }

    #[test]
    fn ray_misses_sheet_outside_bounds() {
        let ray = ray_towards_negative_z(Point3::new(2.0, 0.0, 1.0));
        assert!(hit_unit_sheet(&ray).is_none());
    }

    #[test]
    fn ray_hits_aa_box() {
        let ray = ray_towards_negative_z(Point3::new(0.0, 0.0, 2.0));
        let t = intersect_aa_box(
            &ray,
            &Point3::new(-1.0, -1.0, -1.0),
            &Point3::new(1.0, 1.0, 1.0),
            0.001,
            f64::INFINITY,
            0,
        )
        .map(|hit| hit.t);
        assert!(matches!(t, Some(t) if (t - 1.0).abs() < 1e-6));
    }
}