use nalgebra::{Matrix3, Matrix4, RealField, Unit};
use num_traits::ToPrimitive;
use crate::{
error::{Result, TransformationError},
geometry::{Aabb, Mesh},
rt::{Hit, Ray},
traits::{Bounded, Traceable},
};
/// A placement of a borrowed [`Mesh`] in world space via an object-to-world transform.
///
/// Because the mesh is only borrowed, several instances can share the same
/// geometry. The inverse transform, the world-space bounding box, and the
/// normal matrix are stored alongside the mesh so they need not be recomputed
/// per ray.
#[derive(Debug)]
pub struct Instance<'a, T: RealField + Copy> {
    /// The instanced geometry; rays are intersected with it in object space.
    mesh: &'a Mesh<T>,
    /// Inverse of `object_to_world`; maps world-space points/vectors into object space.
    world_to_object: Matrix4<T>,
    /// The transform that places the mesh in world space.
    object_to_world: Matrix4<T>,
    /// The mesh's bounding box after transformation by `object_to_world`.
    world_aabb: Aabb<T>,
    /// Transpose of the upper-left 3x3 block of `world_to_object`; used to map
    /// object-space normals to world space.
    normal_transform: Matrix3<T>,
}
impl<'a, T: RealField + Copy + ToPrimitive> Instance<'a, T> {
    /// Creates an instance that places `mesh` in world space using `transform`.
    ///
    /// `transform` maps object-space coordinates to world space. Its inverse,
    /// the world-space bounding box and the normal matrix are computed once
    /// here so that tracing does not repeat that work for every ray.
    ///
    /// # Errors
    ///
    /// Returns a [`TransformationError::NonInvertibleMatrix`] error if
    /// `transform` has no inverse, and propagates any error produced while
    /// computing or transforming the mesh's bounding box.
    pub fn new(mesh: &'a Mesh<T>, transform: Matrix4<T>) -> Result<Self> {
        let world_to_object = transform
            .try_inverse()
            .ok_or(TransformationError::NonInvertibleMatrix)?;
        let object_to_world = transform;
        let world_aabb = mesh.aabb()?.transform(&object_to_world)?;
        // Normals transform by the inverse-transpose of the linear part of the
        // object-to-world map. The upper-left 3x3 of the 4x4 inverse equals the
        // inverse of the linear part only for affine transforms (bottom row
        // `[0, 0, 0, 1]`).
        // NOTE(review): projective transforms are presumably not expected here.
        let normal_transform = world_to_object.fixed_view::<3, 3>(0, 0).transpose();
        Ok(Self {
            mesh,
            world_to_object,
            object_to_world,
            world_aabb,
            normal_transform,
        })
    }

    /// Returns the instanced mesh.
    ///
    /// The reference carries the instance's `'a` lifetime, so it may be kept
    /// after the borrow of `self` ends.
    pub const fn mesh(&self) -> &'a Mesh<T> {
        self.mesh
    }

    /// Returns the mesh's bounding box in world space, computed at construction.
    pub const fn world_aabb(&self) -> &Aabb<T> {
        &self.world_aabb
    }

    /// Maps a world-space ray into the mesh's object space.
    ///
    /// The origin is transformed as a point and the direction as a vector. The
    /// direction is then re-normalized, so distances along the returned ray are
    /// in object-space units and generally differ from world-space distances.
    fn transform_ray_to_object_space(&self, ray: &Ray<T>) -> Ray<T> {
        let object_origin = self.world_to_object.transform_point(&ray.origin);
        let object_direction_vector = self.world_to_object.transform_vector(&ray.direction);
        let object_direction = Unit::new_normalize(object_direction_vector);
        Ray::new(object_origin, object_direction)
    }

    /// Converts a hit found along `object_ray` back into world space, in place.
    ///
    /// Both normals are mapped with the normal matrix and re-normalized. The
    /// hit distance is recomputed along `world_ray`: the object-space hit point
    /// is transformed to world space and projected onto the world ray's unit
    /// direction.
    fn transform_hit_to_world_space(&self, hit: &mut Hit<T>, world_ray: &Ray<T>, object_ray: &Ray<T>) {
        let world_geometric_normal_vector = self.normal_transform * hit.geometric_normal.as_ref();
        hit.geometric_normal = Unit::new_normalize(world_geometric_normal_vector);
        let world_interpolated_normal_vector = self.normal_transform * hit.interpolated_normal.as_ref();
        hit.interpolated_normal = Unit::new_normalize(world_interpolated_normal_vector);
        let object_hit_point = object_ray.origin + object_ray.direction.scale(hit.distance);
        let world_hit_point = self.object_to_world.transform_point(&object_hit_point);
        // Projecting onto the unit direction (rather than taking the norm)
        // keeps the sign and tolerates tiny off-ray drift from the transform.
        let to_hit = world_hit_point - world_ray.origin;
        hit.distance = to_hit.dot(&world_ray.direction);
    }
}
impl<T: RealField + Copy + ToPrimitive> Traceable<T> for Instance<'_, T> {
    /// Finds the closest intersection of a world-space `ray` with the instanced mesh.
    ///
    /// The ray is mapped into object space and intersected with the mesh. Any
    /// hit's normals and distance are then converted back to world space.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the mesh's own `intersect`.
    fn intersect(&self, ray: &Ray<T>) -> Result<Option<Hit<T>>> {
        let object_ray = self.transform_ray_to_object_space(ray);
        let object_hit = self.mesh.intersect(&object_ray)?;
        Ok(object_hit.map(|mut hit| {
            self.transform_hit_to_world_space(&mut hit, ray, &object_ray);
            hit
        }))
    }

    /// Reports whether `ray` hits the instanced mesh within `max_distance`,
    /// measured in world-space units along the ray.
    ///
    /// An infinite `max_distance` is supported and stays infinite in object
    /// space.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the mesh's own `intersect_any`.
    fn intersect_any(&self, ray: &Ray<T>, max_distance: T) -> Result<bool> {
        let object_ray = self.transform_ray_to_object_space(ray);
        // A world-space step of length t along the unit direction d becomes an
        // object-space step of length t * |L d|, where L is the linear part of
        // `world_to_object`. Scaling the limit directly avoids building a
        // world-space endpoint `origin + d * max_distance`. For an infinite
        // limit that endpoint produces `0 * inf = NaN` terms in the matrix
        // product, which would make the whole object-space limit NaN.
        let object_units_per_world_unit = self
            .world_to_object
            .transform_vector(&ray.direction)
            .norm();
        let object_max_distance = max_distance * object_units_per_world_unit;
        self.mesh.intersect_any(&object_ray, object_max_distance)
    }
}