crater/primitives/
mesh.rs

1//! `N = 3` [`TriangleMesh`] and [`MeshCollection`] types.
2
3use crate::primitives::nvector::{NVector, cross, normalize_mut, sub};
4use serde::{Deserialize, Serialize};
5
6/// Represents a 3D vertex using a 64-bit float vector.
7pub type Vertex = NVector<3>;
8
9/// Represents a normal vector using a 64-bit float vector.
10pub type Normal = NVector<3>;
11
12#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
13/// Represents a triangle in 3D space.
14pub struct Triangle {
15    /// The three vertices that define the triangle.
16    pub vertices: [Vertex; 3],
17}
18
19impl Triangle {
20    /// Creates a new triangle from three vertices.
21    pub fn new(vertices: [Vertex; 3]) -> Self {
22        Self { vertices }
23    }
24
25    /// Calculates the normal vector of the triangle.
26    pub fn normal(&self) -> Normal {
27        let v0 = sub(&self.vertices[1], &self.vertices[0]);
28        let v1 = sub(&self.vertices[2], &self.vertices[0]);
29        let mut c = cross(&v0, &v1);
30        normalize_mut(&mut c);
31        c
32    }
33    /// Returns the first vertex of the triangle.
34    pub fn p0(&self) -> Vertex {
35        self.vertices[0]
36    }
37    /// Returns the second vertex of the triangle.
38    pub fn p1(&self) -> Vertex {
39        self.vertices[1]
40    }
41    /// Returns the third vertex of the triangle.
42    pub fn p2(&self) -> Vertex {
43        self.vertices[2]
44    }
45}
46
47/// Represents a mesh of triangles in 3D space.
48#[derive(Debug, Serialize, Deserialize, PartialEq)]
49pub struct TriangleMesh {
50    /// Normal vectors for each triangle in the mesh.
51    pub normals: Vec<Normal>,
52    /// The triangles that make up the mesh.
53    pub triangles: Vec<Triangle>,
54}
55
56impl TriangleMesh {
57    /// Creates a new triangle mesh from a vector of triangles.
58    pub fn new(triangles: Vec<Triangle>) -> Self {
59        let normals = triangles.iter().map(|triangle| triangle.normal()).collect();
60        Self { normals, triangles }
61    }
62    /// Returns the number of triangles in the mesh.
63    pub fn num_triangles(&self) -> usize {
64        self.triangles.len()
65    }
66    /// Returns the number of vertices in the mesh.
67    pub fn num_vertices(&self) -> usize {
68        self.triangles.len() * 3
69    }
70    /// Returns the number of normal vectors in the mesh.
71    pub fn num_normals(&self) -> usize {
72        self.normals.len()
73    }
74    /// A slice of all the vertices in the mesh.
75    pub fn vertices(&self) -> &[Vertex] {
76        let ptr = self.triangles.as_ptr() as *const Vertex;
77        unsafe { std::slice::from_raw_parts(ptr, self.num_vertices()) }
78    }
79    /// A slice of all the normal vectors in the mesh.
80    pub fn normals(&self) -> &[Normal] {
81        self.normals.as_slice()
82    }
83}
84
85/// Represents a collection of [`TriangleMesh`] objects.
86#[derive(Debug, Serialize, Deserialize, PartialEq)]
87pub struct MeshCollection {
88    pub meshes: Vec<TriangleMesh>,
89}
90
91impl MeshCollection {
92    /// Creates a new [`MeshCollection`] from a vector of [`TriangleMesh`]es
93    pub fn new(meshes: Vec<TriangleMesh>) -> Self {
94        Self { meshes }
95    }
96    /// Returns the number of meshes in the collection.
97    pub fn num_meshes(&self) -> usize {
98        self.meshes.len()
99    }
100    /// Returns the number of triangles in the collection.
101    pub fn num_triangles(&self) -> usize {
102        self.meshes.iter().map(|m| m.num_triangles()).sum()
103    }
104    /// Returns the number of vertices in the collection.
105    pub fn num_vertices(&self) -> usize {
106        self.meshes.iter().map(|m| m.num_vertices()).sum()
107    }
108    /// Returns the number of normal vectors in the collection.
109    pub fn num_normals(&self) -> usize {
110        self.meshes.iter().map(|m| m.num_normals()).sum()
111    }
112    /// Returns an iterator over all the [`Triangle`]s in the collection.
113    pub fn triangles(&self) -> impl Iterator<Item = Triangle> + '_ {
114        self.meshes.iter().flat_map(|m| m.triangles.iter().copied())
115    }
116    /// Returns an iterator over all the [`Vertex`] objects in the collection.
117    pub fn vertices(&self) -> impl Iterator<Item = Vertex> + '_ {
118        self.meshes
119            .iter()
120            .flat_map(|m| m.triangles.iter().flat_map(|t| t.vertices.iter().copied()))
121    }
122    /// Returns an iterator over all the [`Normal`]s in the collection.
123    pub fn normals(&self) -> impl Iterator<Item = Normal> + '_ {
124        self.meshes.iter().flat_map(|m| m.normals.iter().copied())
125    }
126    /// Returns the index of the first [`Vertex`] of a given [`TriangleMesh`] in the collection.
127    pub fn index_of_first_vertex(&self, mesh_index: usize) -> usize {
128        self.meshes
129            .iter()
130            .take(mesh_index)
131            .map(|m| m.num_vertices())
132            .sum()
133    }
134}
135
136impl From<MeshCollection> for TriangleMesh {
137    fn from(mesh_collection: MeshCollection) -> Self {
138        let normals = mesh_collection
139            .meshes
140            .iter()
141            .flat_map(|mesh| mesh.normals.clone())
142            .collect();
143        let triangles = mesh_collection
144            .meshes
145            .iter()
146            .flat_map(|mesh| mesh.triangles.clone())
147            .collect();
148        Self { normals, triangles }
149    }
150}
151
152impl From<TriangleMesh> for MeshCollection {
153    fn from(triangle_mesh: TriangleMesh) -> Self {
154        let meshes = vec![triangle_mesh];
155        Self { meshes }
156    }
157}
158
159impl<I> From<I> for MeshCollection
160where
161    I: IntoIterator<Item = TriangleMesh>,
162{
163    fn from(meshes: I) -> Self {
164        let meshes = meshes.into_iter().collect();
165        Self { meshes }
166    }
167}
168
#[cfg(test)]
mod tests {

    use super::*;

    /// Two triangles in the `z = 0` plane that together cover the unit square. Both are wound so
    /// that their normal is `[0, 0, 1]`.
    fn unit_square_triangles() -> [Triangle; 2] {
        [
            Triangle::new([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
            Triangle::new([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]]),
        ]
    }

    #[test]
    fn test_triangle_creation() {
        let corners = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let triangle = Triangle::new(corners);
        assert_eq!(triangle.vertices, corners);
    }

    #[test]
    fn test_triangle_normal() {
        let [triangle, _] = unit_square_triangles();
        assert_eq!(triangle.normal(), [0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_triangle_mesh_creation() {
        let mesh = TriangleMesh::new(unit_square_triangles().to_vec());

        assert_eq!(mesh.normals.len(), 2);
        assert_eq!(mesh.triangles.len(), 2);
    }

    #[test]
    fn test_triangle_mesh_normals() {
        let mesh = TriangleMesh::new(unit_square_triangles().to_vec());

        assert_eq!(mesh.normals, vec![[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]);
    }

    #[test]
    fn test_mesh_collection() {
        // One single-triangle mesh per fixture triangle.
        let [first, second] = unit_square_triangles();
        let collection = MeshCollection::new(vec![
            TriangleMesh::new(vec![first]),
            TriangleMesh::new(vec![second]),
        ]);

        assert_eq!(collection.triangles().count(), 2);
        assert_eq!(collection.vertices().count(), 6);
        assert_eq!(collection.normals().count(), 2);
    }
}