1use ghostflow_core::Tensor;
11use crate::linear::Linear;
12use crate::Module;
13
14#[derive(Debug, Clone)]
16pub struct Mesh {
17 pub vertices: Tensor,
19 pub faces: Vec<[usize; 3]>,
21 pub features: Option<Tensor>,
23}
24
25impl Mesh {
26 pub fn new(vertices: Tensor, faces: Vec<[usize; 3]>) -> Self {
28 Mesh {
29 vertices,
30 faces,
31 features: None,
32 }
33 }
34
35 pub fn with_features(vertices: Tensor, faces: Vec<[usize; 3]>, features: Tensor) -> Self {
37 Mesh {
38 vertices,
39 faces,
40 features: Some(features),
41 }
42 }
43
44 pub fn num_vertices(&self) -> usize {
46 self.vertices.dims()[0]
47 }
48
49 pub fn num_faces(&self) -> usize {
51 self.faces.len()
52 }
53
54 pub fn compute_adjacency(&self) -> Vec<Vec<usize>> {
56 let num_verts = self.num_vertices();
57 let mut adjacency = vec![Vec::new(); num_verts];
58
59 for face in &self.faces {
60 adjacency[face[0]].push(face[1]);
62 adjacency[face[0]].push(face[2]);
63 adjacency[face[1]].push(face[0]);
64 adjacency[face[1]].push(face[2]);
65 adjacency[face[2]].push(face[0]);
66 adjacency[face[2]].push(face[1]);
67 }
68
69 for neighbors in &mut adjacency {
71 neighbors.sort_unstable();
72 neighbors.dedup();
73 }
74
75 adjacency
76 }
77
78 pub fn compute_face_normals(&self) -> Result<Tensor, String> {
80 let verts_data = self.vertices.data_f32();
81 let mut normals = Vec::with_capacity(self.num_faces() * 3);
82
83 for face in &self.faces {
84 let v0 = &verts_data[face[0] * 3..face[0] * 3 + 3];
85 let v1 = &verts_data[face[1] * 3..face[1] * 3 + 3];
86 let v2 = &verts_data[face[2] * 3..face[2] * 3 + 3];
87
88 let e1 = [v1[0] - v0[0], v1[1] - v0[1], v1[2] - v0[2]];
90 let e2 = [v2[0] - v0[0], v2[1] - v0[1], v2[2] - v0[2]];
91
92 let normal = [
94 e1[1] * e2[2] - e1[2] * e2[1],
95 e1[2] * e2[0] - e1[0] * e2[2],
96 e1[0] * e2[1] - e1[1] * e2[0],
97 ];
98
99 let length = (normal[0] * normal[0] + normal[1] * normal[1] + normal[2] * normal[2]).sqrt();
101 if length > 1e-8 {
102 normals.push(normal[0] / length);
103 normals.push(normal[1] / length);
104 normals.push(normal[2] / length);
105 } else {
106 normals.extend_from_slice(&[0.0, 0.0, 1.0]);
107 }
108 }
109
110 Tensor::from_slice(&normals, &[self.num_faces(), 3])
111 .map_err(|e| format!("Failed to create normals: {:?}", e))
112 }
113}
114
115pub struct MeshConv {
117 in_features: usize,
118 out_features: usize,
119 weight: Linear,
120}
121
122impl MeshConv {
123 pub fn new(in_features: usize, out_features: usize) -> Self {
125 MeshConv {
126 in_features,
127 out_features,
128 weight: Linear::new(in_features, out_features),
129 }
130 }
131
132 pub fn forward(&self, features: &Tensor, adjacency: &[Vec<usize>]) -> Result<Tensor, String> {
134 let feat_data = features.data_f32();
135 let dims = features.dims();
136 let num_vertices = dims[0];
137
138 let mut aggregated = Vec::with_capacity(num_vertices * self.in_features);
140
141 for v in 0..num_vertices {
142 let neighbors = &adjacency[v];
143
144 if neighbors.is_empty() {
145 let start = v * self.in_features;
147 aggregated.extend_from_slice(&feat_data[start..start + self.in_features]);
148 } else {
149 let mut avg_features = vec![0.0; self.in_features];
151
152 for &neighbor in neighbors {
153 let start = neighbor * self.in_features;
154 for i in 0..self.in_features {
155 avg_features[i] += feat_data[start + i];
156 }
157 }
158
159 let num_neighbors = neighbors.len() as f32;
160 for feat in &mut avg_features {
161 *feat /= num_neighbors;
162 }
163
164 aggregated.extend_from_slice(&avg_features);
165 }
166 }
167
168 let aggregated_tensor = Tensor::from_slice(&aggregated, &[num_vertices, self.in_features])
169 .map_err(|e| format!("Failed to create aggregated tensor: {:?}", e))?;
170
171 Ok(self.weight.forward(&aggregated_tensor))
173 }
174}
175
/// Stateless namespace for mesh downsampling operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct MeshPool;
178
179impl MeshPool {
180 pub fn pool(mesh: &Mesh, stride: usize) -> Result<Mesh, String> {
182 let verts_data = mesh.vertices.data_f32();
183 let num_verts = mesh.num_vertices();
184
185 let mut new_verts = Vec::new();
187 let mut vertex_map = vec![None; num_verts];
188 let mut new_idx = 0;
189
190 for i in (0..num_verts).step_by(stride) {
191 vertex_map[i] = Some(new_idx);
192 new_verts.extend_from_slice(&verts_data[i * 3..i * 3 + 3]);
193 new_idx += 1;
194 }
195
196 let mut new_faces = Vec::new();
198 for face in &mesh.faces {
199 if let (Some(v0), Some(v1), Some(v2)) = (vertex_map[face[0]], vertex_map[face[1]], vertex_map[face[2]]) {
200 new_faces.push([v0, v1, v2]);
201 }
202 }
203
204 let new_vertices = Tensor::from_slice(&new_verts, &[new_idx, 3])
205 .map_err(|e| format!("Failed to create pooled vertices: {:?}", e))?;
206
207 Ok(Mesh::new(new_vertices, new_faces))
208 }
209}
210
211pub struct MeshEncoder {
213 conv1: MeshConv,
214 conv2: MeshConv,
215 conv3: MeshConv,
216 fc: Linear,
217}
218
219impl MeshEncoder {
220 pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize) -> Self {
222 MeshEncoder {
223 conv1: MeshConv::new(input_dim, hidden_dim),
224 conv2: MeshConv::new(hidden_dim, hidden_dim * 2),
225 conv3: MeshConv::new(hidden_dim * 2, hidden_dim * 4),
226 fc: Linear::new(hidden_dim * 4, output_dim),
227 }
228 }
229
230 pub fn forward(&self, mesh: &Mesh) -> Result<Tensor, String> {
232 let adjacency = mesh.compute_adjacency();
233
234 let features = if let Some(ref feat) = mesh.features {
236 feat.clone()
237 } else {
238 mesh.vertices.clone()
239 };
240
241 let mut x = self.conv1.forward(&features, &adjacency)?;
243 x = x.relu();
244
245 x = self.conv2.forward(&x, &adjacency)?;
246 x = x.relu();
247
248 x = self.conv3.forward(&x, &adjacency)?;
249 x = x.relu();
250
251 let pooled = self.global_max_pool(&x)?;
253
254 Ok(self.fc.forward(&pooled))
256 }
257
258 fn global_max_pool(&self, x: &Tensor) -> Result<Tensor, String> {
259 let data = x.data_f32();
260 let dims = x.dims();
261 let num_vertices = dims[0];
262 let feature_dim = dims[1];
263
264 let mut result = vec![f32::NEG_INFINITY; feature_dim];
265
266 for v in 0..num_vertices {
267 for f in 0..feature_dim {
268 let val = data[v * feature_dim + f];
269 result[f] = result[f].max(val);
270 }
271 }
272
273 Tensor::from_slice(&result, &[1, feature_dim])
274 .map_err(|e| format!("Failed to pool: {:?}", e))
275 }
276}
277
/// Stateless namespace for constructing simple reference meshes.
#[derive(Debug, Clone, Copy, Default)]
pub struct MeshUtils;
280
281impl MeshUtils {
282 pub fn create_cube() -> Mesh {
284 let vertices = vec![
285 -1.0f32, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, ];
294
295 let faces = vec![
296 [0, 1, 2], [0, 2, 3],
298 [4, 6, 5], [4, 7, 6],
300 [0, 3, 7], [0, 7, 4],
302 [1, 5, 6], [1, 6, 2],
304 [3, 2, 6], [3, 6, 7],
306 [0, 4, 5], [0, 5, 1],
308 ];
309
310 let verts_tensor = Tensor::from_slice(&vertices, &[8, 3]).unwrap();
311 Mesh::new(verts_tensor, faces)
312 }
313
314 pub fn create_tetrahedron() -> Mesh {
316 let vertices = vec![
317 0.0f32, 0.0, 0.0,
318 1.0, 0.0, 0.0,
319 0.5, 1.0, 0.0,
320 0.5, 0.5, 1.0,
321 ];
322
323 let faces = vec![
324 [0, 1, 2],
325 [0, 1, 3],
326 [0, 2, 3],
327 [1, 2, 3],
328 ];
329
330 let verts_tensor = Tensor::from_slice(&vertices, &[4, 3]).unwrap();
331 Mesh::new(verts_tensor, faces)
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mesh_creation() {
        let mesh = MeshUtils::create_cube();
        assert_eq!(mesh.num_vertices(), 8);
        assert_eq!(mesh.num_faces(), 12);
    }

    #[test]
    fn test_mesh_adjacency() {
        let adj = MeshUtils::create_cube().compute_adjacency();
        assert_eq!(adj.len(), 8);
        for ring in adj.iter() {
            // Every cube corner touches three edges of the cube itself.
            assert!(ring.len() >= 3, "Each vertex should have at least 3 neighbors");
        }
    }

    #[test]
    fn test_face_normals() {
        let normals = MeshUtils::create_cube()
            .compute_face_normals()
            .unwrap();
        assert_eq!(normals.dims(), &[12, 3]);
    }

    #[test]
    fn test_mesh_conv() {
        let mesh = MeshUtils::create_cube();
        let layer = MeshConv::new(3, 16);
        let out = layer
            .forward(&mesh.vertices, &mesh.compute_adjacency())
            .unwrap();
        assert_eq!(out.dims(), &[8, 16]);
    }

    #[test]
    fn test_mesh_pool() {
        let original = MeshUtils::create_cube();
        let reduced = MeshPool::pool(&original, 2).unwrap();

        assert!(reduced.num_vertices() <= original.num_vertices());
        assert!(reduced.num_faces() <= original.num_faces());
    }

    #[test]
    fn test_mesh_encoder() {
        let mesh = MeshUtils::create_cube();
        let embedding = MeshEncoder::new(3, 16, 128).forward(&mesh).unwrap();
        assert_eq!(embedding.dims(), &[1, 128]);
    }

    #[test]
    fn test_tetrahedron() {
        let mesh = MeshUtils::create_tetrahedron();
        assert_eq!(mesh.num_vertices(), 4);
        assert_eq!(mesh.num_faces(), 4);
    }
}