ghostflow_nn/
mesh.rs

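//! Mesh Processing
//!
//! Building blocks for neural networks on triangle meshes:
//! - Mesh representation (vertex positions, triangular faces, optional per-vertex features)
//! - Mesh convolution (neighbor-feature averaging followed by a linear layer)
//! - Graph-based processing via per-vertex adjacency lists
//! - Mesh pooling by vertex decimation (no unpooling is implemented yet)
//! - Global mesh feature extraction (`MeshEncoder`)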
9
10use ghostflow_core::Tensor;
11use crate::linear::Linear;
12use crate::Module;
13
14/// Mesh representation
15#[derive(Debug, Clone)]
16pub struct Mesh {
17    /// Vertices: [num_vertices, 3] (x, y, z coordinates)
18    pub vertices: Tensor,
19    /// Faces: [num_faces, 3] (vertex indices for triangles)
20    pub faces: Vec<[usize; 3]>,
21    /// Vertex features: [num_vertices, feature_dim]
22    pub features: Option<Tensor>,
23}
24
25impl Mesh {
26    /// Create new mesh
27    pub fn new(vertices: Tensor, faces: Vec<[usize; 3]>) -> Self {
28        Mesh {
29            vertices,
30            faces,
31            features: None,
32        }
33    }
34    
35    /// Create mesh with features
36    pub fn with_features(vertices: Tensor, faces: Vec<[usize; 3]>, features: Tensor) -> Self {
37        Mesh {
38            vertices,
39            faces,
40            features: Some(features),
41        }
42    }
43    
44    /// Get number of vertices
45    pub fn num_vertices(&self) -> usize {
46        self.vertices.dims()[0]
47    }
48    
49    /// Get number of faces
50    pub fn num_faces(&self) -> usize {
51        self.faces.len()
52    }
53    
54    /// Compute adjacency list
55    pub fn compute_adjacency(&self) -> Vec<Vec<usize>> {
56        let num_verts = self.num_vertices();
57        let mut adjacency = vec![Vec::new(); num_verts];
58        
59        for face in &self.faces {
60            // Add edges for each triangle
61            adjacency[face[0]].push(face[1]);
62            adjacency[face[0]].push(face[2]);
63            adjacency[face[1]].push(face[0]);
64            adjacency[face[1]].push(face[2]);
65            adjacency[face[2]].push(face[0]);
66            adjacency[face[2]].push(face[1]);
67        }
68        
69        // Remove duplicates and sort
70        for neighbors in &mut adjacency {
71            neighbors.sort_unstable();
72            neighbors.dedup();
73        }
74        
75        adjacency
76    }
77    
78    /// Compute face normals
79    pub fn compute_face_normals(&self) -> Result<Tensor, String> {
80        let verts_data = self.vertices.data_f32();
81        let mut normals = Vec::with_capacity(self.num_faces() * 3);
82        
83        for face in &self.faces {
84            let v0 = &verts_data[face[0] * 3..face[0] * 3 + 3];
85            let v1 = &verts_data[face[1] * 3..face[1] * 3 + 3];
86            let v2 = &verts_data[face[2] * 3..face[2] * 3 + 3];
87            
88            // Compute edges
89            let e1 = [v1[0] - v0[0], v1[1] - v0[1], v1[2] - v0[2]];
90            let e2 = [v2[0] - v0[0], v2[1] - v0[1], v2[2] - v0[2]];
91            
92            // Cross product
93            let normal = [
94                e1[1] * e2[2] - e1[2] * e2[1],
95                e1[2] * e2[0] - e1[0] * e2[2],
96                e1[0] * e2[1] - e1[1] * e2[0],
97            ];
98            
99            // Normalize
100            let length = (normal[0] * normal[0] + normal[1] * normal[1] + normal[2] * normal[2]).sqrt();
101            if length > 1e-8 {
102                normals.push(normal[0] / length);
103                normals.push(normal[1] / length);
104                normals.push(normal[2] / length);
105            } else {
106                normals.extend_from_slice(&[0.0, 0.0, 1.0]);
107            }
108        }
109        
110        Tensor::from_slice(&normals, &[self.num_faces(), 3])
111            .map_err(|e| format!("Failed to create normals: {:?}", e))
112    }
113}
114
115/// Mesh convolution layer
116pub struct MeshConv {
117    in_features: usize,
118    out_features: usize,
119    weight: Linear,
120}
121
122impl MeshConv {
123    /// Create new mesh convolution
124    pub fn new(in_features: usize, out_features: usize) -> Self {
125        MeshConv {
126            in_features,
127            out_features,
128            weight: Linear::new(in_features, out_features),
129        }
130    }
131    
132    /// Forward pass
133    pub fn forward(&self, features: &Tensor, adjacency: &[Vec<usize>]) -> Result<Tensor, String> {
134        let feat_data = features.data_f32();
135        let dims = features.dims();
136        let num_vertices = dims[0];
137        
138        // Aggregate neighbor features
139        let mut aggregated = Vec::with_capacity(num_vertices * self.in_features);
140        
141        for v in 0..num_vertices {
142            let neighbors = &adjacency[v];
143            
144            if neighbors.is_empty() {
145                // No neighbors, use self features
146                let start = v * self.in_features;
147                aggregated.extend_from_slice(&feat_data[start..start + self.in_features]);
148            } else {
149                // Average neighbor features
150                let mut avg_features = vec![0.0; self.in_features];
151                
152                for &neighbor in neighbors {
153                    let start = neighbor * self.in_features;
154                    for i in 0..self.in_features {
155                        avg_features[i] += feat_data[start + i];
156                    }
157                }
158                
159                let num_neighbors = neighbors.len() as f32;
160                for feat in &mut avg_features {
161                    *feat /= num_neighbors;
162                }
163                
164                aggregated.extend_from_slice(&avg_features);
165            }
166        }
167        
168        let aggregated_tensor = Tensor::from_slice(&aggregated, &[num_vertices, self.in_features])
169            .map_err(|e| format!("Failed to create aggregated tensor: {:?}", e))?;
170        
171        // Apply linear transformation
172        Ok(self.weight.forward(&aggregated_tensor))
173    }
174}
175
176/// Mesh pooling (vertex decimation)
177pub struct MeshPool;
178
179impl MeshPool {
180    /// Pool mesh by selecting every nth vertex
181    pub fn pool(mesh: &Mesh, stride: usize) -> Result<Mesh, String> {
182        let verts_data = mesh.vertices.data_f32();
183        let num_verts = mesh.num_vertices();
184        
185        // Select vertices
186        let mut new_verts = Vec::new();
187        let mut vertex_map = vec![None; num_verts];
188        let mut new_idx = 0;
189        
190        for i in (0..num_verts).step_by(stride) {
191            vertex_map[i] = Some(new_idx);
192            new_verts.extend_from_slice(&verts_data[i * 3..i * 3 + 3]);
193            new_idx += 1;
194        }
195        
196        // Update faces
197        let mut new_faces = Vec::new();
198        for face in &mesh.faces {
199            if let (Some(v0), Some(v1), Some(v2)) = (vertex_map[face[0]], vertex_map[face[1]], vertex_map[face[2]]) {
200                new_faces.push([v0, v1, v2]);
201            }
202        }
203        
204        let new_vertices = Tensor::from_slice(&new_verts, &[new_idx, 3])
205            .map_err(|e| format!("Failed to create pooled vertices: {:?}", e))?;
206        
207        Ok(Mesh::new(new_vertices, new_faces))
208    }
209}
210
211/// Mesh encoder
212pub struct MeshEncoder {
213    conv1: MeshConv,
214    conv2: MeshConv,
215    conv3: MeshConv,
216    fc: Linear,
217}
218
219impl MeshEncoder {
220    /// Create new mesh encoder
221    pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize) -> Self {
222        MeshEncoder {
223            conv1: MeshConv::new(input_dim, hidden_dim),
224            conv2: MeshConv::new(hidden_dim, hidden_dim * 2),
225            conv3: MeshConv::new(hidden_dim * 2, hidden_dim * 4),
226            fc: Linear::new(hidden_dim * 4, output_dim),
227        }
228    }
229    
230    /// Forward pass
231    pub fn forward(&self, mesh: &Mesh) -> Result<Tensor, String> {
232        let adjacency = mesh.compute_adjacency();
233        
234        // Use vertex positions as initial features if no features provided
235        let features = if let Some(ref feat) = mesh.features {
236            feat.clone()
237        } else {
238            mesh.vertices.clone()
239        };
240        
241        // Apply mesh convolutions
242        let mut x = self.conv1.forward(&features, &adjacency)?;
243        x = x.relu();
244        
245        x = self.conv2.forward(&x, &adjacency)?;
246        x = x.relu();
247        
248        x = self.conv3.forward(&x, &adjacency)?;
249        x = x.relu();
250        
251        // Global pooling (max over vertices)
252        let pooled = self.global_max_pool(&x)?;
253        
254        // Final linear layer
255        Ok(self.fc.forward(&pooled))
256    }
257    
258    fn global_max_pool(&self, x: &Tensor) -> Result<Tensor, String> {
259        let data = x.data_f32();
260        let dims = x.dims();
261        let num_vertices = dims[0];
262        let feature_dim = dims[1];
263        
264        let mut result = vec![f32::NEG_INFINITY; feature_dim];
265        
266        for v in 0..num_vertices {
267            for f in 0..feature_dim {
268                let val = data[v * feature_dim + f];
269                result[f] = result[f].max(val);
270            }
271        }
272        
273        Tensor::from_slice(&result, &[1, feature_dim])
274            .map_err(|e| format!("Failed to pool: {:?}", e))
275    }
276}
277
278/// Mesh utilities
279pub struct MeshUtils;
280
281impl MeshUtils {
282    /// Create a simple cube mesh
283    pub fn create_cube() -> Mesh {
284        let vertices = vec![
285            -1.0f32, -1.0, -1.0,  // 0
286             1.0, -1.0, -1.0,  // 1
287             1.0,  1.0, -1.0,  // 2
288            -1.0,  1.0, -1.0,  // 3
289            -1.0, -1.0,  1.0,  // 4
290             1.0, -1.0,  1.0,  // 5
291             1.0,  1.0,  1.0,  // 6
292            -1.0,  1.0,  1.0,  // 7
293        ];
294        
295        let faces = vec![
296            // Front
297            [0, 1, 2], [0, 2, 3],
298            // Back
299            [4, 6, 5], [4, 7, 6],
300            // Left
301            [0, 3, 7], [0, 7, 4],
302            // Right
303            [1, 5, 6], [1, 6, 2],
304            // Top
305            [3, 2, 6], [3, 6, 7],
306            // Bottom
307            [0, 4, 5], [0, 5, 1],
308        ];
309        
310        let verts_tensor = Tensor::from_slice(&vertices, &[8, 3]).unwrap();
311        Mesh::new(verts_tensor, faces)
312    }
313    
314    /// Create a simple tetrahedron mesh
315    pub fn create_tetrahedron() -> Mesh {
316        let vertices = vec![
317            0.0f32, 0.0, 0.0,
318            1.0, 0.0, 0.0,
319            0.5, 1.0, 0.0,
320            0.5, 0.5, 1.0,
321        ];
322        
323        let faces = vec![
324            [0, 1, 2],
325            [0, 1, 3],
326            [0, 2, 3],
327            [1, 2, 3],
328        ];
329        
330        let verts_tensor = Tensor::from_slice(&vertices, &[4, 3]).unwrap();
331        Mesh::new(verts_tensor, faces)
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mesh_creation() {
        let cube = MeshUtils::create_cube();
        assert_eq!(cube.num_vertices(), 8);
        assert_eq!(cube.num_faces(), 12);
    }

    #[test]
    fn test_mesh_adjacency() {
        let cube = MeshUtils::create_cube();
        let adjacency = cube.compute_adjacency();

        assert_eq!(adjacency.len(), 8);
        for (v, neighbors) in adjacency.iter().enumerate() {
            // Each cube vertex touches several triangulated faces.
            assert!(neighbors.len() >= 3, "Each vertex should have at least 3 neighbors");
            assert!(!neighbors.contains(&v), "vertex {} must not neighbor itself", v);
            assert!(
                neighbors.windows(2).all(|w| w[0] < w[1]),
                "neighbors of vertex {} must be sorted and duplicate-free",
                v
            );
        }
    }

    #[test]
    fn test_tetrahedron_adjacency() {
        // In a tetrahedron every vertex is connected to the other three.
        let tetra = MeshUtils::create_tetrahedron();
        let adjacency = tetra.compute_adjacency();
        assert_eq!(
            adjacency,
            vec![vec![1, 2, 3], vec![0, 2, 3], vec![0, 1, 3], vec![0, 1, 2]]
        );
    }

    #[test]
    fn test_face_normals() {
        let cube = MeshUtils::create_cube();
        let normals = cube.compute_face_normals().unwrap();
        assert_eq!(normals.dims(), &[12, 3]); // 12 faces, 3D normals

        // No cube face is degenerate, so every normal must be unit length.
        let data = normals.data_f32();
        for n in data.chunks(3) {
            let len = (n[0] * n[0] + n[1] * n[1] + n[2] * n[2]).sqrt();
            assert!((len - 1.0).abs() < 1e-5, "normal {:?} is not unit length", n);
        }
    }

    #[test]
    fn test_mesh_conv() {
        let conv = MeshConv::new(3, 16);
        let cube = MeshUtils::create_cube();
        let adjacency = cube.compute_adjacency();

        let output = conv.forward(&cube.vertices, &adjacency).unwrap();
        assert_eq!(output.dims(), &[8, 16]); // 8 vertices, 16 features
    }

    #[test]
    fn test_mesh_pool() {
        let cube = MeshUtils::create_cube();

        // Stride 2 keeps vertices 0, 2, 4, 6.
        let pooled = MeshPool::pool(&cube, 2).unwrap();
        assert_eq!(pooled.num_vertices(), 4);
        assert_eq!(pooled.vertices.dims(), &[4, 3]);
        assert!(pooled.num_faces() <= cube.num_faces());

        // Stride 1 keeps everything.
        let identity = MeshPool::pool(&cube, 1).unwrap();
        assert_eq!(identity.num_vertices(), cube.num_vertices());
        assert_eq!(identity.num_faces(), cube.num_faces());
    }

    #[test]
    fn test_mesh_encoder() {
        let encoder = MeshEncoder::new(3, 16, 128);
        let cube = MeshUtils::create_cube();

        let features = encoder.forward(&cube).unwrap();
        assert_eq!(features.dims(), &[1, 128]); // Global features
    }

    #[test]
    fn test_tetrahedron() {
        let tetra = MeshUtils::create_tetrahedron();
        assert_eq!(tetra.num_vertices(), 4);
        assert_eq!(tetra.num_faces(), 4);
    }
}