crystal_ball 0.3.0

A path tracing library written in Rust.
Documentation
use std::sync::Arc;

use crate::math::{Bounds3, Hit, Point2, Point3, Ray, Vec3, Vec4};
use crate::shapes::triangle_mesh::TriangleMesh;
use crate::shapes::Shape;

/// A single triangle from a [`TriangleMesh`].
#[derive(Clone)]
pub struct Triangle {
    /// The triangle's corresponding [`TriangleMesh`].
    pub triangle_mesh: Arc<TriangleMesh>,
    /// The triangle's index in the [`TriangleMesh`].
    pub index: usize,
}

impl Triangle {
    /// Create a new [`Triangle`].
    pub fn new(triangle_mesh: Arc<TriangleMesh>, index: usize) -> Self {
        Triangle {
            triangle_mesh,
            index,
        }
    }

    /// Return the triangle's vertices.
    pub fn vertices(&self) -> [Point3; 3] {
        let mesh_index = self.index * 3;
        let indices = &self.triangle_mesh.indices[mesh_index..mesh_index + 3];

        [
            self.triangle_mesh.vertices[indices[0] as usize],
            self.triangle_mesh.vertices[indices[1] as usize],
            self.triangle_mesh.vertices[indices[2] as usize],
        ]
    }

    /// Return the triangle's normals.
    ///
    /// If the [`TriangleMesh`] doesn't provide any normals,
    /// they are automatically calculated assuming flat shading.
    pub fn normals(&self) -> [Vec3; 3] {
        let mesh_index = self.index * 3;
        let indices = &self.triangle_mesh.indices[mesh_index..mesh_index + 3];

        if self.triangle_mesh.normals.is_empty() {
            let normal = self.face_normal();

            [normal; 3]
        } else {
            [
                self.triangle_mesh.normals[indices[0] as usize],
                self.triangle_mesh.normals[indices[1] as usize],
                self.triangle_mesh.normals[indices[2] as usize],
            ]
        }
    }

    /// Return the triangle's tangents.
    ///
    /// Returns [`None`] if the [`TriangleMesh`] doesn't provide any tangents.
    pub fn tangents(&self) -> Option<[Vec4; 3]> {
        let mesh_index = self.index * 3;
        let indices = &self.triangle_mesh.indices[mesh_index..mesh_index + 3];

        if self.triangle_mesh.tangents.is_empty() {
            None
        } else {
            Some([
                self.triangle_mesh.tangents[indices[0] as usize],
                self.triangle_mesh.tangents[indices[1] as usize],
                self.triangle_mesh.tangents[indices[2] as usize],
            ])
        }
    }

    /// Return the triangle's UV coordinates.
    ///
    /// Returns [`None`] if the [`TriangleMesh`] doesn't provide UV coordinates.
    pub fn uvs(&self) -> Option<[Point2; 3]> {
        let mesh_index = self.index * 3;
        let indices = &self.triangle_mesh.indices[mesh_index..mesh_index + 3];

        if self.triangle_mesh.uvs.is_empty() {
            None
        } else {
            Some([
                self.triangle_mesh.uvs[indices[0] as usize],
                self.triangle_mesh.uvs[indices[1] as usize],
                self.triangle_mesh.uvs[indices[2] as usize],
            ])
        }
    }

    /// Calculate the normal of the triangle's face (flat normal).
    pub fn face_normal(&self) -> Vec3 {
        let vertices = self.vertices();

        let edge1 = vertices[1] - vertices[0];
        let edge2 = vertices[2] - vertices[0];

        Vec3::cross(edge1, edge2).normalize()
    }

    /// Calculate the triangle's circumcenter.
    pub fn circumcenter(&self) -> Point3 {
        let vertices = self.vertices();

        let u = vertices[1] - vertices[0];
        let v = vertices[2] - vertices[1];
        let w = vertices[0] - vertices[2];
        let n = Vec3::cross(u, v);

        (vertices[0] + vertices[1]) * 0.5
            - Vec3::dot(v, w) * Vec3::cross(n, u) / n.magnitude_squared() * 0.5
    }

    /// Calculate the triangle's vertex mean.
    pub fn vertex_mean(&self) -> Point3 {
        let vertices = self.vertices();

        (vertices[0] + vertices[1] + vertices[2]) / 3.0
    }
}

impl Shape for Triangle {
    fn intersects(&self, ray: Ray) -> Option<Hit> {
        let vertices = self.vertices();
        let normals = self.normals();
        let tangents = self.tangents();
        let uvs = self.uvs();

        let (edge1, edge2) = (vertices[1] - vertices[0], vertices[2] - vertices[0]);

        let p_vec = Vec3::cross(ray.direction, edge2);
        let determinant = Vec3::dot(edge1, p_vec);

        if determinant.abs() < 1e-9 {
            return None;
        }

        let inverse_determinant = 1.0 / determinant;

        let t_vec = ray.origin - vertices[0];
        let u = inverse_determinant * Vec3::dot(t_vec, p_vec);

        if !(0.0..=1.0).contains(&u) {
            return None;
        }

        let q_vec = Vec3::cross(t_vec, edge1);
        let v = inverse_determinant * Vec3::dot(ray.direction, q_vec);

        if v < 0.0 || u + v > 1.0 {
            return None;
        }

        let intersection_distance = inverse_determinant * Vec3::dot(edge2, q_vec);

        if intersection_distance < 0.0 {
            return None;
        }

        let intersection_point = ray.get(intersection_distance);

        let w = 1.0 - u - v;
        // Note that the Möller–Trumbore algorithm (https://cadxfem.org/inf/Fast%20MinimumStorage%20RayTriangle%20Intersection.pdf)
        // describes a point T(u,v) by
        // T(u, v) = (1 - u - v) V0 + u V1 + v V2
        // Because u + v + w = 1 and therefore w = 1 - u - v we interpolate by
        // N(u, v, w) = w N0 + u N1 + v N2
        let mut normal = (w * normals[0] + u * normals[1] + v * normals[2]).normalize();
        let mut tangent = tangents.map(|t| (w * t[0] + u * t[1] + v * t[2]).normalize());

        let uv = uvs
            .map(|uvs| {
                // Same explanation for interpolation as above
                w * uvs[0] + u * uvs[1] + v * uvs[2]
            })
            // TODO: Generate in glTF loader when missing?
            // TODO: use standard triangle UVs
            .unwrap_or_default();

        if Vec3::dot(ray.direction, normal) > 0.0 {
            normal = -normal;
            tangent = tangent.map(|t| -t);
        }

        Some(Hit::new(
            intersection_point,
            normal,
            tangent,
            intersection_distance,
            uv,
        ))
    }

    fn bounds(&self) -> Bounds3 {
        let vertices = self.vertices();

        Bounds3::new(vertices[0], vertices[1]).include_point(vertices[2])
    }
}

#[cfg(test)]
mod tests {
    use std::f64::consts::{PI, TAU};

    use assert_approx_eq::assert_approx_eq;

    use crate::math::{Point3, Transformable, Vec3};
    use crate::util::EPSILON_F64;

    use super::TriangleMesh;

    #[test]
    fn rotate_triangle() {
        let indices = vec![0, 1, 2];
        let vertices = vec![
            Point3::new(-1.0, -0.5, -1.0),
            Point3::new(0.0, 1.0, -1.0),
            Point3::new(1.0, -0.5, 2.0),
        ];
        let normals = vec![Vec3::cross(vertices[1] - vertices[0], vertices[2] - vertices[0]); 3];
        let tangents = vec![];
        let uvs = vec![];

        let triangle_mesh = TriangleMesh::new(indices, vertices.clone(), normals, tangents, uvs);

        let rotated_x_triangle = triangle_mesh.clone().rotate_x(TAU);
        let rotated_y_triangle = triangle_mesh.clone().rotate_y(TAU);
        let rotated_z_triangle = triangle_mesh.clone().rotate_z(TAU);

        let rotated_triangle = triangle_mesh.clone().rotate_x(PI).rotate_y(PI).rotate_z(PI);

        assert_approx_eq!(vertices[0], rotated_x_triangle.vertices[0], EPSILON_F64);
        assert_approx_eq!(vertices[1], rotated_x_triangle.vertices[1], EPSILON_F64);
        assert_approx_eq!(vertices[2], rotated_x_triangle.vertices[2], EPSILON_F64);
        assert_approx_eq!(vertices[0], rotated_y_triangle.vertices[0], EPSILON_F64);
        assert_approx_eq!(vertices[1], rotated_y_triangle.vertices[1], EPSILON_F64);
        assert_approx_eq!(vertices[2], rotated_y_triangle.vertices[2], EPSILON_F64);
        assert_approx_eq!(vertices[0], rotated_z_triangle.vertices[0], EPSILON_F64);
        assert_approx_eq!(vertices[1], rotated_z_triangle.vertices[1], EPSILON_F64);
        assert_approx_eq!(vertices[2], rotated_z_triangle.vertices[2], EPSILON_F64);
        assert_approx_eq!(vertices[0], rotated_triangle.vertices[0], EPSILON_F64);
        assert_approx_eq!(vertices[1], rotated_triangle.vertices[1], EPSILON_F64);
        assert_approx_eq!(vertices[2], rotated_triangle.vertices[2], EPSILON_F64);
    }
}