mesh_gpu/
surface_nets.rs

//! GPU-accelerated Surface Nets isosurface extraction.
//!
//! Surface Nets is an algorithm for extracting isosurfaces from volumetric
//! data (such as signed distance fields). It produces higher-quality meshes
//! than Marching Cubes with a simpler implementation.
6
7use bytemuck::{Pod, Zeroable};
8use tracing::{debug, info, warn};
9use wgpu::util::DeviceExt;
10use wgpu::{BindGroupLayout, ComputePipeline};
11
12use mesh_repair::Mesh;
13
14use crate::context::GpuContext;
15use crate::error::{GpuError, GpuResult};
16
17/// Shader source for Surface Nets.
18const SURFACE_NETS_SHADER: &str = include_str!("shaders/surface_nets.wgsl");
19
/// Configuration for one GPU Surface Nets extraction run.
///
/// The [`Default`] value describes an empty grid (`dims == [0, 0, 0]`) placed
/// at the world origin, with unit-sized voxels and an iso-value of zero.
#[derive(Clone, Debug)]
pub struct GpuSurfaceNetsParams {
    /// Number of grid samples along each axis, as `[x, y, z]`.
    pub dims: [usize; 3],
    /// World-space position of the grid origin.
    pub origin: [f32; 3],
    /// Edge length of a single voxel, in world units.
    pub voxel_size: f32,
    /// Field value at which the surface is extracted (typically 0.0 for an SDF).
    pub iso_value: f32,
}

impl Default for GpuSurfaceNetsParams {
    /// Empty grid at the origin with `voxel_size = 1.0` and `iso_value = 0.0`.
    fn default() -> Self {
        GpuSurfaceNetsParams {
            dims: [0; 3],
            origin: [0.0; 3],
            voxel_size: 1.0,
            iso_value: 0.0,
        }
    }
}
43
44/// Result of GPU Surface Nets extraction.
45#[derive(Debug)]
46pub struct GpuSurfaceNetsResult {
47    /// Extracted mesh.
48    pub mesh: Mesh,
49    /// Number of active cells found.
50    pub active_cells: usize,
51    /// Number of vertices generated.
52    pub vertex_count: usize,
53    /// Computation time in milliseconds.
54    pub compute_time_ms: f64,
55}
56
57/// GPU vertex output structure (matches shader).
58#[repr(C)]
59#[derive(Clone, Copy, Debug, Pod, Zeroable)]
60struct GpuOutputVertex {
61    position: [f32; 4], // xyz + vertex_idx
62    normal: [f32; 4],   // xyz + padding
63}
64
65/// Uniform parameters for the shader.
66#[repr(C)]
67#[derive(Clone, Copy, Debug, Pod, Zeroable)]
68struct ShaderGridParams {
69    origin: [f32; 4],
70    dims: [u32; 4],
71    voxel_size: f32,
72    iso_value: f32,
73    _padding: [f32; 2],
74}
75
76/// Pipeline for GPU Surface Nets extraction.
77pub struct SurfaceNetsPipeline {
78    identify_pipeline: ComputePipeline,
79    generate_pipeline: ComputePipeline,
80    bind_group_layout: BindGroupLayout,
81}
82
83impl SurfaceNetsPipeline {
84    /// Create a new Surface Nets pipeline.
85    pub fn new(ctx: &GpuContext) -> GpuResult<Self> {
86        debug!("Creating Surface Nets compute pipeline");
87
88        // Compile shader
89        let shader = ctx
90            .device
91            .create_shader_module(wgpu::ShaderModuleDescriptor {
92                label: Some("surface_nets"),
93                source: wgpu::ShaderSource::Wgsl(SURFACE_NETS_SHADER.into()),
94            });
95
96        // Create bind group layout
97        let bind_group_layout =
98            ctx.device
99                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
100                    label: Some("surface_nets_bind_group_layout"),
101                    entries: &[
102                        // SDF values (read-only)
103                        wgpu::BindGroupLayoutEntry {
104                            binding: 0,
105                            visibility: wgpu::ShaderStages::COMPUTE,
106                            ty: wgpu::BindingType::Buffer {
107                                ty: wgpu::BufferBindingType::Storage { read_only: true },
108                                has_dynamic_offset: false,
109                                min_binding_size: None,
110                            },
111                            count: None,
112                        },
113                        // Grid params (uniform)
114                        wgpu::BindGroupLayoutEntry {
115                            binding: 1,
116                            visibility: wgpu::ShaderStages::COMPUTE,
117                            ty: wgpu::BindingType::Buffer {
118                                ty: wgpu::BufferBindingType::Uniform,
119                                has_dynamic_offset: false,
120                                min_binding_size: None,
121                            },
122                            count: None,
123                        },
124                        // Active cells bitmap (read-write)
125                        wgpu::BindGroupLayoutEntry {
126                            binding: 2,
127                            visibility: wgpu::ShaderStages::COMPUTE,
128                            ty: wgpu::BindingType::Buffer {
129                                ty: wgpu::BufferBindingType::Storage { read_only: false },
130                                has_dynamic_offset: false,
131                                min_binding_size: None,
132                            },
133                            count: None,
134                        },
135                        // Cell vertices output (read-write)
136                        wgpu::BindGroupLayoutEntry {
137                            binding: 3,
138                            visibility: wgpu::ShaderStages::COMPUTE,
139                            ty: wgpu::BindingType::Buffer {
140                                ty: wgpu::BufferBindingType::Storage { read_only: false },
141                                has_dynamic_offset: false,
142                                min_binding_size: None,
143                            },
144                            count: None,
145                        },
146                        // Vertex count (atomic)
147                        wgpu::BindGroupLayoutEntry {
148                            binding: 4,
149                            visibility: wgpu::ShaderStages::COMPUTE,
150                            ty: wgpu::BindingType::Buffer {
151                                ty: wgpu::BufferBindingType::Storage { read_only: false },
152                                has_dynamic_offset: false,
153                                min_binding_size: None,
154                            },
155                            count: None,
156                        },
157                    ],
158                });
159
160        // Create pipeline layout
161        let pipeline_layout = ctx
162            .device
163            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
164                label: Some("surface_nets_pipeline_layout"),
165                bind_group_layouts: &[&bind_group_layout],
166                push_constant_ranges: &[],
167            });
168
169        // Create identify_active_cells pipeline
170        let identify_pipeline =
171            ctx.device
172                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
173                    label: Some("surface_nets_identify_pipeline"),
174                    layout: Some(&pipeline_layout),
175                    module: &shader,
176                    entry_point: Some("identify_active_cells"),
177                    compilation_options: Default::default(),
178                    cache: None,
179                });
180
181        // Create generate_vertices pipeline
182        let generate_pipeline =
183            ctx.device
184                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
185                    label: Some("surface_nets_generate_pipeline"),
186                    layout: Some(&pipeline_layout),
187                    module: &shader,
188                    entry_point: Some("generate_vertices"),
189                    compilation_options: Default::default(),
190                    cache: None,
191                });
192
193        Ok(Self {
194            identify_pipeline,
195            generate_pipeline,
196            bind_group_layout,
197        })
198    }
199
200    /// Extract isosurface from SDF values.
201    ///
202    /// # Arguments
203    /// * `ctx` - GPU context
204    /// * `sdf_values` - SDF values already uploaded to GPU
205    /// * `params` - Extraction parameters
206    ///
207    /// # Returns
208    /// The extracted mesh.
209    pub fn extract(
210        &self,
211        ctx: &GpuContext,
212        sdf_buffer: &wgpu::Buffer,
213        params: &GpuSurfaceNetsParams,
214    ) -> GpuResult<GpuSurfaceNetsResult> {
215        let start = std::time::Instant::now();
216
217        // Calculate number of cells (one less than voxels in each dimension)
218        let cells = [
219            params.dims[0].saturating_sub(1),
220            params.dims[1].saturating_sub(1),
221            params.dims[2].saturating_sub(1),
222        ];
223        let total_cells = cells[0] * cells[1] * cells[2];
224
225        if total_cells == 0 {
226            return Ok(GpuSurfaceNetsResult {
227                mesh: Mesh::new(),
228                active_cells: 0,
229                vertex_count: 0,
230                compute_time_ms: 0.0,
231            });
232        }
233
234        info!(
235            dims = ?params.dims,
236            cells = total_cells,
237            "Extracting isosurface on GPU"
238        );
239
240        // Create uniform buffer for grid params
241        let grid_params = ShaderGridParams {
242            origin: [params.origin[0], params.origin[1], params.origin[2], 0.0],
243            dims: [
244                params.dims[0] as u32,
245                params.dims[1] as u32,
246                params.dims[2] as u32,
247                0,
248            ],
249            voxel_size: params.voxel_size,
250            iso_value: params.iso_value,
251            _padding: [0.0, 0.0],
252        };
253
254        let params_buffer = ctx
255            .device
256            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
257                label: Some("surface_nets_params"),
258                contents: bytemuck::bytes_of(&grid_params),
259                usage: wgpu::BufferUsages::UNIFORM,
260            });
261
262        // Create active cells buffer (1 u32 flag per cell: 0 = inactive, 1 = active)
263        let active_cells_size = total_cells * std::mem::size_of::<u32>();
264        let active_cells_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
265            label: Some("surface_nets_active_cells"),
266            size: active_cells_size as u64,
267            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
268            mapped_at_creation: false,
269        });
270
271        // Create cell vertices buffer (one potential vertex per cell)
272        let vertices_size = total_cells * std::mem::size_of::<GpuOutputVertex>();
273        let vertices_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
274            label: Some("surface_nets_vertices"),
275            size: vertices_size as u64,
276            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
277            mapped_at_creation: false,
278        });
279
280        // Create vertex count buffer (atomic counter)
281        let count_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
282            label: Some("surface_nets_count"),
283            size: std::mem::size_of::<u32>() as u64,
284            usage: wgpu::BufferUsages::STORAGE
285                | wgpu::BufferUsages::COPY_SRC
286                | wgpu::BufferUsages::COPY_DST,
287            mapped_at_creation: false,
288        });
289
290        // Initialize count to 0
291        ctx.queue
292            .write_buffer(&count_buffer, 0, bytemuck::bytes_of(&0u32));
293
294        // Create bind group
295        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
296            label: Some("surface_nets_bind_group"),
297            layout: &self.bind_group_layout,
298            entries: &[
299                wgpu::BindGroupEntry {
300                    binding: 0,
301                    resource: sdf_buffer.as_entire_binding(),
302                },
303                wgpu::BindGroupEntry {
304                    binding: 1,
305                    resource: params_buffer.as_entire_binding(),
306                },
307                wgpu::BindGroupEntry {
308                    binding: 2,
309                    resource: active_cells_buffer.as_entire_binding(),
310                },
311                wgpu::BindGroupEntry {
312                    binding: 3,
313                    resource: vertices_buffer.as_entire_binding(),
314                },
315                wgpu::BindGroupEntry {
316                    binding: 4,
317                    resource: count_buffer.as_entire_binding(),
318                },
319            ],
320        });
321
322        // Create command encoder
323        let mut encoder = ctx
324            .device
325            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
326                label: Some("surface_nets_encoder"),
327            });
328
329        let workgroups = (total_cells as u32).div_ceil(256);
330
331        // Pass 1: Identify active cells
332        {
333            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
334                label: Some("surface_nets_identify_pass"),
335                timestamp_writes: None,
336            });
337            compute_pass.set_pipeline(&self.identify_pipeline);
338            compute_pass.set_bind_group(0, &bind_group, &[]);
339            compute_pass.dispatch_workgroups(workgroups, 1, 1);
340        }
341
342        // Pass 2: Generate vertices
343        {
344            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
345                label: Some("surface_nets_generate_pass"),
346                timestamp_writes: None,
347            });
348            compute_pass.set_pipeline(&self.generate_pipeline);
349            compute_pass.set_bind_group(0, &bind_group, &[]);
350            compute_pass.dispatch_workgroups(workgroups, 1, 1);
351        }
352
353        // Submit commands
354        ctx.queue.submit([encoder.finish()]);
355
356        // Download results
357        let vertex_count = self.download_count(ctx, &count_buffer)?;
358        let vertices = self.download_vertices(ctx, &vertices_buffer, total_cells)?;
359
360        // Build mesh from vertices (faces generated on CPU for now)
361        let mesh = self.build_mesh(&vertices, vertex_count as usize, &cells, params);
362
363        let compute_time_ms = start.elapsed().as_secs_f64() * 1000.0;
364        info!(
365            vertices = vertex_count,
366            faces = mesh.faces.len(),
367            time_ms = compute_time_ms,
368            "Isosurface extraction complete"
369        );
370
371        Ok(GpuSurfaceNetsResult {
372            mesh,
373            active_cells: vertex_count as usize, // Approximate
374            vertex_count: vertex_count as usize,
375            compute_time_ms,
376        })
377    }
378
379    fn download_count(&self, ctx: &GpuContext, buffer: &wgpu::Buffer) -> GpuResult<u32> {
380        let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
381            label: Some("count_staging"),
382            size: std::mem::size_of::<u32>() as u64,
383            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
384            mapped_at_creation: false,
385        });
386
387        let mut encoder = ctx
388            .device
389            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
390        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, std::mem::size_of::<u32>() as u64);
391        ctx.queue.submit([encoder.finish()]);
392
393        let slice = staging.slice(..);
394        let (tx, rx) = std::sync::mpsc::channel();
395        slice.map_async(wgpu::MapMode::Read, move |result| {
396            tx.send(result).unwrap();
397        });
398        ctx.device.poll(wgpu::Maintain::Wait);
399
400        rx.recv()
401            .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
402            .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;
403
404        let data = slice.get_mapped_range();
405        let count = *bytemuck::from_bytes::<u32>(&data);
406        drop(data);
407        staging.unmap();
408
409        Ok(count)
410    }
411
412    fn download_vertices(
413        &self,
414        ctx: &GpuContext,
415        buffer: &wgpu::Buffer,
416        count: usize,
417    ) -> GpuResult<Vec<GpuOutputVertex>> {
418        let size = count * std::mem::size_of::<GpuOutputVertex>();
419        let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
420            label: Some("vertices_staging"),
421            size: size as u64,
422            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
423            mapped_at_creation: false,
424        });
425
426        let mut encoder = ctx
427            .device
428            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
429        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
430        ctx.queue.submit([encoder.finish()]);
431
432        let slice = staging.slice(..);
433        let (tx, rx) = std::sync::mpsc::channel();
434        slice.map_async(wgpu::MapMode::Read, move |result| {
435            tx.send(result).unwrap();
436        });
437        ctx.device.poll(wgpu::Maintain::Wait);
438
439        rx.recv()
440            .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
441            .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;
442
443        let data = slice.get_mapped_range();
444        let vertices: Vec<GpuOutputVertex> = bytemuck::cast_slice(&data).to_vec();
445        drop(data);
446        staging.unmap();
447
448        Ok(vertices)
449    }
450
451    fn build_mesh(
452        &self,
453        gpu_vertices: &[GpuOutputVertex],
454        vertex_count: usize,
455        _cells: &[usize; 3],
456        _params: &GpuSurfaceNetsParams,
457    ) -> Mesh {
458        use mesh_repair::Vertex;
459
460        let mut mesh = Mesh::new();
461
462        // Add vertices
463        for v in gpu_vertices.iter().take(vertex_count) {
464            let mut vertex = Vertex::from_coords(
465                v.position[0] as f64,
466                v.position[1] as f64,
467                v.position[2] as f64,
468            );
469            vertex.normal = Some(nalgebra::Vector3::new(
470                v.normal[0] as f64,
471                v.normal[1] as f64,
472                v.normal[2] as f64,
473            ));
474            mesh.vertices.push(vertex);
475        }
476
477        // Note: Face generation is more complex in Surface Nets and requires
478        // connectivity information. For now, we return just the vertices.
479        // The face generation can be done on CPU using the cell adjacency info,
480        // or a third GPU pass could be added.
481        //
482        // For a complete implementation, faces are generated by connecting
483        // vertices from adjacent active cells that share an edge crossing.
484
485        mesh
486    }
487}
488
489/// Extract isosurface from SDF values on GPU.
490pub fn extract_isosurface_gpu(
491    sdf_values: &[f32],
492    params: &GpuSurfaceNetsParams,
493) -> GpuResult<GpuSurfaceNetsResult> {
494    let ctx = GpuContext::try_get()?;
495
496    // Upload SDF values to GPU
497    let sdf_buffer = ctx
498        .device
499        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
500            label: Some("surface_nets_sdf"),
501            contents: bytemuck::cast_slice(sdf_values),
502            usage: wgpu::BufferUsages::STORAGE,
503        });
504
505    let pipeline = SurfaceNetsPipeline::new(ctx)?;
506    pipeline.extract(ctx, &sdf_buffer, params)
507}
508
509/// Try to extract isosurface on GPU, returning None if unavailable.
510pub fn try_extract_isosurface_gpu(
511    sdf_values: &[f32],
512    params: &GpuSurfaceNetsParams,
513) -> Option<GpuSurfaceNetsResult> {
514    match extract_isosurface_gpu(sdf_values, params) {
515        Ok(result) => Some(result),
516        Err(GpuError::NotAvailable) => {
517            debug!("GPU not available for isosurface extraction");
518            None
519        }
520        Err(e) => {
521            warn!("GPU isosurface extraction failed: {}", e);
522            None
523        }
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Default parameters use a zero iso-value and unit voxels.
    #[test]
    fn test_gpu_surface_nets_params_default() {
        let GpuSurfaceNetsParams {
            voxel_size,
            iso_value,
            ..
        } = GpuSurfaceNetsParams::default();

        assert_eq!(iso_value, 0.0);
        assert_eq!(voxel_size, 1.0);
    }

    /// Smoke test: extraction must not panic, with or without a GPU.
    #[test]
    fn test_try_extract_isosurface_gpu() {
        // 3x3x3 grid that is outside (+1.0) everywhere except the center
        // sample (1, 1, 1), which sits at index 1 + 3 + 9 = 13.
        const CENTER: usize = 1 + 3 + 9;
        let sdf: Vec<f32> = (0..27usize)
            .map(|index| if index == CENTER { -1.0 } else { 1.0 })
            .collect();

        let params = GpuSurfaceNetsParams {
            dims: [3, 3, 3],
            origin: [0.0, 0.0, 0.0],
            voxel_size: 1.0,
            iso_value: 0.0,
        };

        // The outcome is deliberately ignored: `None` is a valid result when
        // no GPU is available.
        let _result = try_extract_isosurface_gpu(&sdf, &params);
    }
}
555}