mesh_gpu/
sdf.rs

1//! GPU-accelerated SDF (Signed Distance Field) computation.
2//!
3//! This module provides GPU-accelerated computation of signed distance fields
4//! from triangle meshes. It uses WGPU compute shaders for parallel processing.
5
6use tracing::{debug, info, warn};
7use wgpu::{BindGroupLayout, ComputePipeline, ShaderModule};
8
9use mesh_repair::Mesh;
10
11use crate::buffers::{MeshBuffers, SdfGridBuffers, TileConfig};
12use crate::context::GpuContext;
13use crate::error::{GpuError, GpuResult};
14
/// WGSL source for the SDF compute shader.
///
/// Embedded at compile time via `include_str!` from `shaders/sdf_compute.wgsl`
/// (resolved relative to this source file).
const SDF_SHADER: &str = include_str!("shaders/sdf_compute.wgsl");
17
/// Parameters describing the voxel grid an SDF is sampled on.
///
/// Implements `PartialEq` so parameter sets can be compared directly
/// (for example in tests or when deciding whether a cached result is reusable).
#[derive(Debug, Clone, PartialEq)]
pub struct GpuSdfParams {
    /// Grid dimensions `[x, y, z]`, in voxels.
    pub dims: [usize; 3],
    /// Grid origin in world coordinates.
    pub origin: [f32; 3],
    /// Voxel size in world units.
    pub voxel_size: f32,
}
28
/// Result of a GPU SDF computation.
///
/// Implements `Clone` and `PartialEq` so results can be duplicated and
/// compared by callers and tests.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuSdfResult {
    /// Computed SDF values, one per voxel.
    pub values: Vec<f32>,
    /// Grid dimensions the values were computed for.
    pub dims: [usize; 3],
    /// Wall-clock computation time in milliseconds.
    pub compute_time_ms: f64,
}
39
40/// Pipeline for GPU SDF computation.
41///
42/// This struct caches the compiled shader and pipeline, allowing efficient
43/// reuse across multiple SDF computations.
44pub struct SdfPipeline {
45    #[allow(dead_code)] // Kept for potential future use (shader introspection)
46    shader: ShaderModule,
47    pipeline: ComputePipeline,
48    bind_group_layout: BindGroupLayout,
49}
50
51impl SdfPipeline {
52    /// Create a new SDF computation pipeline.
53    pub fn new(ctx: &GpuContext) -> GpuResult<Self> {
54        debug!("Creating SDF compute pipeline");
55
56        // Compile shader
57        let shader = ctx
58            .device
59            .create_shader_module(wgpu::ShaderModuleDescriptor {
60                label: Some("sdf_compute"),
61                source: wgpu::ShaderSource::Wgsl(SDF_SHADER.into()),
62            });
63
64        // Create bind group layout
65        let bind_group_layout =
66            ctx.device
67                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
68                    label: Some("sdf_bind_group_layout"),
69                    entries: &[
70                        // Triangles storage buffer (read-only)
71                        wgpu::BindGroupLayoutEntry {
72                            binding: 0,
73                            visibility: wgpu::ShaderStages::COMPUTE,
74                            ty: wgpu::BindingType::Buffer {
75                                ty: wgpu::BufferBindingType::Storage { read_only: true },
76                                has_dynamic_offset: false,
77                                min_binding_size: None,
78                            },
79                            count: None,
80                        },
81                        // Grid params uniform buffer
82                        wgpu::BindGroupLayoutEntry {
83                            binding: 1,
84                            visibility: wgpu::ShaderStages::COMPUTE,
85                            ty: wgpu::BindingType::Buffer {
86                                ty: wgpu::BufferBindingType::Uniform,
87                                has_dynamic_offset: false,
88                                min_binding_size: None,
89                            },
90                            count: None,
91                        },
92                        // SDF values storage buffer (read-write)
93                        wgpu::BindGroupLayoutEntry {
94                            binding: 2,
95                            visibility: wgpu::ShaderStages::COMPUTE,
96                            ty: wgpu::BindingType::Buffer {
97                                ty: wgpu::BufferBindingType::Storage { read_only: false },
98                                has_dynamic_offset: false,
99                                min_binding_size: None,
100                            },
101                            count: None,
102                        },
103                    ],
104                });
105
106        // Create pipeline layout
107        let pipeline_layout = ctx
108            .device
109            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
110                label: Some("sdf_pipeline_layout"),
111                bind_group_layouts: &[&bind_group_layout],
112                push_constant_ranges: &[],
113            });
114
115        // Create compute pipeline
116        let pipeline = ctx
117            .device
118            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
119                label: Some("sdf_compute_pipeline"),
120                layout: Some(&pipeline_layout),
121                module: &shader,
122                entry_point: Some("compute_sdf"),
123                compilation_options: Default::default(),
124                cache: None,
125            });
126
127        Ok(Self {
128            shader,
129            pipeline,
130            bind_group_layout,
131        })
132    }
133
134    /// Compute SDF for a mesh.
135    ///
136    /// # Arguments
137    /// * `ctx` - GPU context
138    /// * `mesh_buffers` - Mesh data already uploaded to GPU
139    /// * `params` - SDF computation parameters
140    ///
141    /// # Returns
142    /// The computed SDF values.
143    pub fn compute(
144        &self,
145        ctx: &GpuContext,
146        mesh_buffers: &MeshBuffers,
147        params: &GpuSdfParams,
148    ) -> GpuResult<GpuSdfResult> {
149        let start = std::time::Instant::now();
150        let total_voxels = params.dims[0] * params.dims[1] * params.dims[2];
151
152        info!(
153            dims = ?params.dims,
154            total_voxels = total_voxels,
155            triangles = mesh_buffers.triangle_count,
156            "Computing SDF on GPU"
157        );
158
159        // Allocate grid buffers
160        let grid_buffers = SdfGridBuffers::allocate(
161            ctx,
162            params.dims,
163            params.origin,
164            params.voxel_size,
165            mesh_buffers.triangle_count,
166        )?;
167
168        // Create bind group
169        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
170            label: Some("sdf_bind_group"),
171            layout: &self.bind_group_layout,
172            entries: &[
173                wgpu::BindGroupEntry {
174                    binding: 0,
175                    resource: mesh_buffers.triangles.as_entire_binding(),
176                },
177                wgpu::BindGroupEntry {
178                    binding: 1,
179                    resource: grid_buffers.params.as_entire_binding(),
180                },
181                wgpu::BindGroupEntry {
182                    binding: 2,
183                    resource: grid_buffers.values.as_entire_binding(),
184                },
185            ],
186        });
187
188        // Create command encoder
189        let mut encoder = ctx
190            .device
191            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
192                label: Some("sdf_compute_encoder"),
193            });
194
195        // Dispatch compute shader
196        {
197            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
198                label: Some("sdf_compute_pass"),
199                timestamp_writes: None,
200            });
201
202            compute_pass.set_pipeline(&self.pipeline);
203            compute_pass.set_bind_group(0, &bind_group, &[]);
204
205            // Workgroup size is 256, so dispatch enough workgroups
206            let workgroups = (total_voxels as u32).div_ceil(256);
207            compute_pass.dispatch_workgroups(workgroups, 1, 1);
208        }
209
210        // Submit commands
211        ctx.queue.submit([encoder.finish()]);
212
213        // Download results
214        let values = grid_buffers.download_values(ctx)?;
215
216        let compute_time_ms = start.elapsed().as_secs_f64() * 1000.0;
217        info!(
218            voxels = total_voxels,
219            time_ms = compute_time_ms,
220            "SDF computation complete"
221        );
222
223        Ok(GpuSdfResult {
224            values,
225            dims: params.dims,
226            compute_time_ms,
227        })
228    }
229}
230
231/// Compute SDF on GPU with automatic fallback.
232///
233/// This is the main entry point for GPU SDF computation. It handles:
234/// - GPU availability detection
235/// - Pipeline creation and caching
236/// - Automatic tiling for large grids
237/// - Error handling with graceful fallback
238///
239/// # Arguments
240/// * `mesh` - Source mesh
241/// * `params` - SDF computation parameters
242///
243/// # Returns
244/// The computed SDF values, or an error if GPU computation fails.
245pub fn compute_sdf_gpu(mesh: &Mesh, params: &GpuSdfParams) -> GpuResult<GpuSdfResult> {
246    let ctx = GpuContext::try_get()?;
247
248    // Upload mesh to GPU
249    let mesh_buffers = MeshBuffers::from_mesh(ctx, mesh)?;
250
251    // Check if we need tiling
252    let total_voxels = params.dims[0] * params.dims[1] * params.dims[2];
253    let max_voxels = ctx.max_storage_buffer_size() as usize / std::mem::size_of::<f32>();
254
255    if total_voxels > max_voxels {
256        // Use tiled computation
257        compute_sdf_tiled(ctx, &mesh_buffers, params)
258    } else {
259        // Direct computation
260        let pipeline = SdfPipeline::new(ctx)?;
261        pipeline.compute(ctx, &mesh_buffers, params)
262    }
263}
264
265/// Compute SDF using tiled processing for large grids.
266fn compute_sdf_tiled(
267    ctx: &GpuContext,
268    mesh_buffers: &MeshBuffers,
269    params: &GpuSdfParams,
270) -> GpuResult<GpuSdfResult> {
271    let start = std::time::Instant::now();
272    let total_voxels = params.dims[0] * params.dims[1] * params.dims[2];
273
274    // Determine tile configuration based on available memory
275    let available_memory = ctx.estimate_available_memory();
276    // Reserve memory for mesh and overhead
277    let grid_memory =
278        available_memory.saturating_sub(mesh_buffers.triangles_size() + 256 * 1024 * 1024);
279    let tile_config = TileConfig::for_memory_budget(grid_memory);
280
281    let tile_counts = tile_config.tile_count(params.dims);
282    let total_tiles = tile_config.total_tiles(params.dims);
283
284    info!(
285        grid_dims = ?params.dims,
286        tile_size = ?tile_config.tile_size,
287        tiles = total_tiles,
288        "Using tiled SDF computation"
289    );
290
291    // Allocate result buffer
292    let mut result = vec![0.0f32; total_voxels];
293
294    // Create pipeline once
295    let pipeline = SdfPipeline::new(ctx)?;
296
297    // Process each tile
298    for tz in 0..tile_counts[2] {
299        for ty in 0..tile_counts[1] {
300            for tx in 0..tile_counts[0] {
301                let tile_origin_voxels = [
302                    tx * tile_config.tile_size[0],
303                    ty * tile_config.tile_size[1],
304                    tz * tile_config.tile_size[2],
305                ];
306
307                // Calculate tile dimensions (may be smaller at edges)
308                let tile_dims = [
309                    (params.dims[0] - tile_origin_voxels[0]).min(tile_config.tile_size[0]),
310                    (params.dims[1] - tile_origin_voxels[1]).min(tile_config.tile_size[1]),
311                    (params.dims[2] - tile_origin_voxels[2]).min(tile_config.tile_size[2]),
312                ];
313
314                // Calculate tile origin in world coordinates
315                let tile_origin_world = [
316                    params.origin[0] + (tile_origin_voxels[0] as f32) * params.voxel_size,
317                    params.origin[1] + (tile_origin_voxels[1] as f32) * params.voxel_size,
318                    params.origin[2] + (tile_origin_voxels[2] as f32) * params.voxel_size,
319                ];
320
321                let tile_params = GpuSdfParams {
322                    dims: tile_dims,
323                    origin: tile_origin_world,
324                    voxel_size: params.voxel_size,
325                };
326
327                // Compute tile
328                let tile_result = pipeline.compute(ctx, mesh_buffers, &tile_params)?;
329
330                // Copy tile results to main grid
331                copy_tile_to_grid(
332                    &tile_result.values,
333                    &mut result,
334                    params.dims,
335                    tile_origin_voxels,
336                    tile_dims,
337                );
338
339                debug!(tile_x = tx, tile_y = ty, tile_z = tz, "Tile processed");
340            }
341        }
342    }
343
344    let compute_time_ms = start.elapsed().as_secs_f64() * 1000.0;
345    info!(
346        tiles = total_tiles,
347        time_ms = compute_time_ms,
348        "Tiled SDF computation complete"
349    );
350
351    Ok(GpuSdfResult {
352        values: result,
353        dims: params.dims,
354        compute_time_ms,
355    })
356}
357
/// Copy tile results to the main grid.
///
/// Both buffers use the same layout, with Z varying fastest:
/// `index = z + y * dims[2] + x * dims[1] * dims[2]`.
///
/// # Arguments
/// * `tile_values` - Values of the tile, laid out with `tile_dims`
/// * `grid_values` - Destination grid, laid out with `grid_dims`
/// * `grid_dims` - Dimensions `[x, y, z]` of the destination grid
/// * `tile_origin` - Voxel coordinates of the tile's first voxel in the grid
/// * `tile_dims` - Dimensions `[x, y, z]` of the tile
///
/// The copy is best-effort and never panics on out-of-range input: the tile is
/// clipped per axis to the part that lies inside the grid, so an overhanging
/// tile cannot wrap around into neighbouring rows, and rows are shortened if
/// either buffer is smaller than its dimensions imply.
fn copy_tile_to_grid(
    tile_values: &[f32],
    grid_values: &mut [f32],
    grid_dims: [usize; 3],
    tile_origin: [usize; 3],
    tile_dims: [usize; 3],
) {
    // Per-axis extent of the tile that actually lies inside the grid.
    let nx = tile_dims[0].min(grid_dims[0].saturating_sub(tile_origin[0]));
    let ny = tile_dims[1].min(grid_dims[1].saturating_sub(tile_origin[1]));
    let nz = tile_dims[2].min(grid_dims[2].saturating_sub(tile_origin[2]));
    if nx == 0 || ny == 0 || nz == 0 {
        return;
    }

    // Strides for the Z-fastest layout shared with mesh_to_sdf.
    let tile_stride_y = tile_dims[2];
    let tile_stride_x = tile_dims[1] * tile_dims[2];
    let grid_stride_y = grid_dims[2];
    let grid_stride_x = grid_dims[1] * grid_dims[2];

    // For a fixed (x, y) the Z run is contiguous in both buffers, so each run
    // is copied as one block instead of striding across the slowest axis.
    for x in 0..nx {
        for y in 0..ny {
            let tile_start = y * tile_stride_y + x * tile_stride_x;
            let grid_start = tile_origin[2]
                + (tile_origin[1] + y) * grid_stride_y
                + (tile_origin[0] + x) * grid_stride_x;

            // Shorten the run if either buffer is smaller than its dims claim.
            let run = nz
                .min(tile_values.len().saturating_sub(tile_start))
                .min(grid_values.len().saturating_sub(grid_start));
            if run == 0 {
                continue;
            }

            grid_values[grid_start..grid_start + run]
                .copy_from_slice(&tile_values[tile_start..tile_start + run]);
        }
    }
}
384
385/// Try to compute SDF on GPU, returning None if GPU is unavailable.
386///
387/// This is a convenience function that doesn't return an error for GPU
388/// unavailability, making it easy to implement fallback logic.
389pub fn try_compute_sdf_gpu(mesh: &Mesh, params: &GpuSdfParams) -> Option<GpuSdfResult> {
390    match compute_sdf_gpu(mesh, params) {
391        Ok(result) => Some(result),
392        Err(GpuError::NotAvailable) => {
393            debug!("GPU not available for SDF computation");
394            None
395        }
396        Err(e) => {
397            warn!("GPU SDF computation failed: {}", e);
398            None
399        }
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use mesh_repair::Vertex;

    /// Build an axis-aligned cube spanning [-1, 1] on every axis
    /// (edge length 2), centered at the origin, with 12 triangles.
    fn create_test_cube() -> Mesh {
        let mut mesh = Mesh::new();

        // Corners of the cube: bottom ring (z = -1) first, then top ring (z = 1)
        let coords = [
            [-1.0, -1.0, -1.0],
            [1.0, -1.0, -1.0],
            [1.0, 1.0, -1.0],
            [-1.0, 1.0, -1.0],
            [-1.0, -1.0, 1.0],
            [1.0, -1.0, 1.0],
            [1.0, 1.0, 1.0],
            [-1.0, 1.0, 1.0],
        ];

        mesh.vertices.extend(
            coords
                .iter()
                .map(|&[x, y, z]| Vertex::from_coords(x, y, z)),
        );

        // Cube faces (2 triangles per face)
        let faces = [
            [0, 1, 2],
            [0, 2, 3], // Front
            [4, 6, 5],
            [4, 7, 6], // Back
            [0, 5, 1],
            [0, 4, 5], // Bottom
            [2, 7, 3],
            [2, 6, 7], // Top
            [0, 3, 7],
            [0, 7, 4], // Left
            [1, 5, 6],
            [1, 6, 2], // Right
        ];

        mesh.faces.extend(faces.iter().copied());

        mesh
    }

    #[test]
    fn test_gpu_sdf_params() {
        let params = GpuSdfParams {
            dims: [10, 10, 10],
            origin: [-2.0, -2.0, -2.0],
            voxel_size: 0.4,
        };

        assert_eq!(params.dims[0] * params.dims[1] * params.dims[2], 1000);
    }

    #[test]
    fn test_copy_tile_to_grid() {
        // 2x2x4 grid, Z varying fastest; a 1x2x2 tile placed at (1, 0, 2).
        let mut grid = vec![0.0f32; 16];
        copy_tile_to_grid(
            &[1.0, 2.0, 3.0, 4.0],
            &mut grid,
            [2, 2, 4],
            [1, 0, 2],
            [1, 2, 2],
        );

        let mut expected = vec![0.0f32; 16];
        expected[10] = 1.0;
        expected[11] = 2.0;
        expected[14] = 3.0;
        expected[15] = 4.0;
        assert_eq!(grid, expected);
    }

    #[test]
    fn test_try_compute_sdf_gpu() {
        let mesh = create_test_cube();
        let params = GpuSdfParams {
            dims: [5, 5, 5],
            origin: [-2.0, -2.0, -2.0],
            voxel_size: 0.8,
        };

        // Passes whether or not a GPU is available; when a result is produced
        // it must describe the grid that was requested.
        if let Some(result) = try_compute_sdf_gpu(&mesh, &params) {
            assert_eq!(result.dims, params.dims);
        }
    }
}