1use tracing::{debug, info, warn};
7use wgpu::{BindGroupLayout, ComputePipeline, ShaderModule};
8
9use mesh_repair::Mesh;
10
11use crate::buffers::{MeshBuffers, SdfGridBuffers, TileConfig};
12use crate::context::GpuContext;
13use crate::error::{GpuError, GpuResult};
14
/// WGSL source of the SDF compute shader, embedded into the binary at compile time.
///
/// The pipeline built from this source looks up the `compute_sdf` entry point, so
/// the shader file must define it.
const SDF_SHADER: &str = include_str!("shaders/sdf_compute.wgsl");
17
/// Describes the regular voxel grid on which the signed distance field is sampled.
#[derive(Debug, Clone)]
pub struct GpuSdfParams {
    /// Number of voxels along the x, y and z axes; the grid holds
    /// `dims[0] * dims[1] * dims[2]` samples in total.
    pub dims: [usize; 3],
    /// World-space position associated with voxel index `(0, 0, 0)`.
    /// NOTE(review): whether this is the voxel's corner or its centre is decided
    /// by the shader, which is not visible here — confirm.
    pub origin: [f32; 3],
    /// World-space distance between neighbouring voxels along every axis.
    pub voxel_size: f32,
}
28
/// Output of a GPU SDF computation.
#[derive(Debug)]
pub struct GpuSdfResult {
    /// Flattened distance values for the grid.
    /// The tiled path stores them with `z` varying fastest, then `y`, then `x`;
    /// the single-dispatch path presumably uses the same layout — confirm against
    /// the shader.
    pub values: Vec<f32>,
    /// Grid dimensions the values were computed for (copied from the request).
    pub dims: [usize; 3],
    /// Wall-clock time of the computation in milliseconds, measured on the host.
    pub compute_time_ms: f64,
}
39
/// Reusable GPU state for SDF evaluation: the compiled shader, the compute
/// pipeline built from it, and the bind group layout its bindings must follow.
pub struct SdfPipeline {
    // Never read after construction, hence the lint override.
    // NOTE(review): presumably retained so the module shares the pipeline's
    // lifetime — confirm it is still needed.
    #[allow(dead_code)]
    shader: ShaderModule,
    /// Compute pipeline bound to the shader's `compute_sdf` entry point.
    pipeline: ComputePipeline,
    /// Layout with three buffer bindings: 0 = read-only storage, 1 = uniform,
    /// 2 = read-write storage.
    bind_group_layout: BindGroupLayout,
}
50
51impl SdfPipeline {
52 pub fn new(ctx: &GpuContext) -> GpuResult<Self> {
54 debug!("Creating SDF compute pipeline");
55
56 let shader = ctx
58 .device
59 .create_shader_module(wgpu::ShaderModuleDescriptor {
60 label: Some("sdf_compute"),
61 source: wgpu::ShaderSource::Wgsl(SDF_SHADER.into()),
62 });
63
64 let bind_group_layout =
66 ctx.device
67 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
68 label: Some("sdf_bind_group_layout"),
69 entries: &[
70 wgpu::BindGroupLayoutEntry {
72 binding: 0,
73 visibility: wgpu::ShaderStages::COMPUTE,
74 ty: wgpu::BindingType::Buffer {
75 ty: wgpu::BufferBindingType::Storage { read_only: true },
76 has_dynamic_offset: false,
77 min_binding_size: None,
78 },
79 count: None,
80 },
81 wgpu::BindGroupLayoutEntry {
83 binding: 1,
84 visibility: wgpu::ShaderStages::COMPUTE,
85 ty: wgpu::BindingType::Buffer {
86 ty: wgpu::BufferBindingType::Uniform,
87 has_dynamic_offset: false,
88 min_binding_size: None,
89 },
90 count: None,
91 },
92 wgpu::BindGroupLayoutEntry {
94 binding: 2,
95 visibility: wgpu::ShaderStages::COMPUTE,
96 ty: wgpu::BindingType::Buffer {
97 ty: wgpu::BufferBindingType::Storage { read_only: false },
98 has_dynamic_offset: false,
99 min_binding_size: None,
100 },
101 count: None,
102 },
103 ],
104 });
105
106 let pipeline_layout = ctx
108 .device
109 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
110 label: Some("sdf_pipeline_layout"),
111 bind_group_layouts: &[&bind_group_layout],
112 push_constant_ranges: &[],
113 });
114
115 let pipeline = ctx
117 .device
118 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
119 label: Some("sdf_compute_pipeline"),
120 layout: Some(&pipeline_layout),
121 module: &shader,
122 entry_point: Some("compute_sdf"),
123 compilation_options: Default::default(),
124 cache: None,
125 });
126
127 Ok(Self {
128 shader,
129 pipeline,
130 bind_group_layout,
131 })
132 }
133
134 pub fn compute(
144 &self,
145 ctx: &GpuContext,
146 mesh_buffers: &MeshBuffers,
147 params: &GpuSdfParams,
148 ) -> GpuResult<GpuSdfResult> {
149 let start = std::time::Instant::now();
150 let total_voxels = params.dims[0] * params.dims[1] * params.dims[2];
151
152 info!(
153 dims = ?params.dims,
154 total_voxels = total_voxels,
155 triangles = mesh_buffers.triangle_count,
156 "Computing SDF on GPU"
157 );
158
159 let grid_buffers = SdfGridBuffers::allocate(
161 ctx,
162 params.dims,
163 params.origin,
164 params.voxel_size,
165 mesh_buffers.triangle_count,
166 )?;
167
168 let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
170 label: Some("sdf_bind_group"),
171 layout: &self.bind_group_layout,
172 entries: &[
173 wgpu::BindGroupEntry {
174 binding: 0,
175 resource: mesh_buffers.triangles.as_entire_binding(),
176 },
177 wgpu::BindGroupEntry {
178 binding: 1,
179 resource: grid_buffers.params.as_entire_binding(),
180 },
181 wgpu::BindGroupEntry {
182 binding: 2,
183 resource: grid_buffers.values.as_entire_binding(),
184 },
185 ],
186 });
187
188 let mut encoder = ctx
190 .device
191 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
192 label: Some("sdf_compute_encoder"),
193 });
194
195 {
197 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
198 label: Some("sdf_compute_pass"),
199 timestamp_writes: None,
200 });
201
202 compute_pass.set_pipeline(&self.pipeline);
203 compute_pass.set_bind_group(0, &bind_group, &[]);
204
205 let workgroups = (total_voxels as u32).div_ceil(256);
207 compute_pass.dispatch_workgroups(workgroups, 1, 1);
208 }
209
210 ctx.queue.submit([encoder.finish()]);
212
213 let values = grid_buffers.download_values(ctx)?;
215
216 let compute_time_ms = start.elapsed().as_secs_f64() * 1000.0;
217 info!(
218 voxels = total_voxels,
219 time_ms = compute_time_ms,
220 "SDF computation complete"
221 );
222
223 Ok(GpuSdfResult {
224 values,
225 dims: params.dims,
226 compute_time_ms,
227 })
228 }
229}
230
231pub fn compute_sdf_gpu(mesh: &Mesh, params: &GpuSdfParams) -> GpuResult<GpuSdfResult> {
246 let ctx = GpuContext::try_get()?;
247
248 let mesh_buffers = MeshBuffers::from_mesh(ctx, mesh)?;
250
251 let total_voxels = params.dims[0] * params.dims[1] * params.dims[2];
253 let max_voxels = ctx.max_storage_buffer_size() as usize / std::mem::size_of::<f32>();
254
255 if total_voxels > max_voxels {
256 compute_sdf_tiled(ctx, &mesh_buffers, params)
258 } else {
259 let pipeline = SdfPipeline::new(ctx)?;
261 pipeline.compute(ctx, &mesh_buffers, params)
262 }
263}
264
265fn compute_sdf_tiled(
267 ctx: &GpuContext,
268 mesh_buffers: &MeshBuffers,
269 params: &GpuSdfParams,
270) -> GpuResult<GpuSdfResult> {
271 let start = std::time::Instant::now();
272 let total_voxels = params.dims[0] * params.dims[1] * params.dims[2];
273
274 let available_memory = ctx.estimate_available_memory();
276 let grid_memory =
278 available_memory.saturating_sub(mesh_buffers.triangles_size() + 256 * 1024 * 1024);
279 let tile_config = TileConfig::for_memory_budget(grid_memory);
280
281 let tile_counts = tile_config.tile_count(params.dims);
282 let total_tiles = tile_config.total_tiles(params.dims);
283
284 info!(
285 grid_dims = ?params.dims,
286 tile_size = ?tile_config.tile_size,
287 tiles = total_tiles,
288 "Using tiled SDF computation"
289 );
290
291 let mut result = vec![0.0f32; total_voxels];
293
294 let pipeline = SdfPipeline::new(ctx)?;
296
297 for tz in 0..tile_counts[2] {
299 for ty in 0..tile_counts[1] {
300 for tx in 0..tile_counts[0] {
301 let tile_origin_voxels = [
302 tx * tile_config.tile_size[0],
303 ty * tile_config.tile_size[1],
304 tz * tile_config.tile_size[2],
305 ];
306
307 let tile_dims = [
309 (params.dims[0] - tile_origin_voxels[0]).min(tile_config.tile_size[0]),
310 (params.dims[1] - tile_origin_voxels[1]).min(tile_config.tile_size[1]),
311 (params.dims[2] - tile_origin_voxels[2]).min(tile_config.tile_size[2]),
312 ];
313
314 let tile_origin_world = [
316 params.origin[0] + (tile_origin_voxels[0] as f32) * params.voxel_size,
317 params.origin[1] + (tile_origin_voxels[1] as f32) * params.voxel_size,
318 params.origin[2] + (tile_origin_voxels[2] as f32) * params.voxel_size,
319 ];
320
321 let tile_params = GpuSdfParams {
322 dims: tile_dims,
323 origin: tile_origin_world,
324 voxel_size: params.voxel_size,
325 };
326
327 let tile_result = pipeline.compute(ctx, mesh_buffers, &tile_params)?;
329
330 copy_tile_to_grid(
332 &tile_result.values,
333 &mut result,
334 params.dims,
335 tile_origin_voxels,
336 tile_dims,
337 );
338
339 debug!(tile_x = tx, tile_y = ty, tile_z = tz, "Tile processed");
340 }
341 }
342 }
343
344 let compute_time_ms = start.elapsed().as_secs_f64() * 1000.0;
345 info!(
346 tiles = total_tiles,
347 time_ms = compute_time_ms,
348 "Tiled SDF computation complete"
349 );
350
351 Ok(GpuSdfResult {
352 values: result,
353 dims: params.dims,
354 compute_time_ms,
355 })
356}
357
/// Copies one tile's samples into the matching region of the full grid.
///
/// Both buffers are flattened with `z` varying fastest, then `y`, then `x`
/// (`index = z + y * dims[2] + x * dims[1] * dims[2]`). For a fixed `(x, y)` the
/// tile's `z` samples are therefore contiguous in both buffers, so each such
/// run is moved with a single slice copy instead of element by element.
///
/// * `tile_values` - samples of the tile, laid out for `tile_dims`.
/// * `grid_values` - destination buffer, laid out for `grid_dims`.
/// * `tile_origin` - grid index of the tile's first voxel.
///
/// The tile is expected to lie within the grid. A run that would reach past
/// the end of either buffer is shortened (or skipped) rather than panicking,
/// so an undersized buffer yields a partial copy.
fn copy_tile_to_grid(
    tile_values: &[f32],
    grid_values: &mut [f32],
    grid_dims: [usize; 3],
    tile_origin: [usize; 3],
    tile_dims: [usize; 3],
) {
    let run_len = tile_dims[2];

    for x in 0..tile_dims[0] {
        let grid_x = tile_origin[0] + x;

        for y in 0..tile_dims[1] {
            let grid_y = tile_origin[1] + y;

            // Flat index of the first `z` sample of this run in each buffer.
            let tile_start = (x * tile_dims[1] + y) * run_len;
            let grid_start = (grid_x * grid_dims[1] + grid_y) * grid_dims[2] + tile_origin[2];

            // Clamp the run to what both buffers actually hold.
            let len = run_len
                .min(tile_values.len().saturating_sub(tile_start))
                .min(grid_values.len().saturating_sub(grid_start));
            if len == 0 {
                continue;
            }

            grid_values[grid_start..grid_start + len]
                .copy_from_slice(&tile_values[tile_start..tile_start + len]);
        }
    }
}
384
385pub fn try_compute_sdf_gpu(mesh: &Mesh, params: &GpuSdfParams) -> Option<GpuSdfResult> {
390 match compute_sdf_gpu(mesh, params) {
391 Ok(result) => Some(result),
392 Err(GpuError::NotAvailable) => {
393 debug!("GPU not available for SDF computation");
394 None
395 }
396 Err(e) => {
397 warn!("GPU SDF computation failed: {}", e);
398 None
399 }
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use mesh_repair::Vertex;

    /// Builds an axis-aligned cube spanning `[-1, 1]` on every axis
    /// (8 vertices, 12 triangles).
    fn create_test_cube() -> Mesh {
        let mut mesh = Mesh::new();

        let coords = [
            // z = -1 corners
            [-1.0, -1.0, -1.0],
            [1.0, -1.0, -1.0],
            [1.0, 1.0, -1.0],
            [-1.0, 1.0, -1.0],
            // z = +1 corners
            [-1.0, -1.0, 1.0],
            [1.0, -1.0, 1.0],
            [1.0, 1.0, 1.0],
            [-1.0, 1.0, 1.0],
        ];

        for c in &coords {
            mesh.vertices.push(Vertex::from_coords(c[0], c[1], c[2]));
        }

        let faces = [
            // z = -1
            [0, 1, 2],
            [0, 2, 3],
            // z = +1
            [4, 6, 5],
            [4, 7, 6],
            // y = -1
            [0, 5, 1],
            [0, 4, 5],
            // y = +1
            [2, 7, 3],
            [2, 6, 7],
            // x = -1
            [0, 3, 7],
            [0, 7, 4],
            // x = +1
            [1, 5, 6],
            [1, 6, 2],
        ];

        for f in &faces {
            mesh.faces.push(*f);
        }

        mesh
    }

    #[test]
    fn test_gpu_sdf_params() {
        let params = GpuSdfParams {
            dims: [10, 10, 10],
            origin: [-2.0, -2.0, -2.0],
            voxel_size: 0.4,
        };

        assert_eq!(params.dims[0] * params.dims[1] * params.dims[2], 1000);
    }

    #[test]
    fn test_try_compute_sdf_gpu() {
        let mesh = create_test_cube();
        let params = GpuSdfParams {
            dims: [5, 5, 5],
            origin: [-2.0, -2.0, -2.0],
            voxel_size: 0.8,
        };

        // `None` is an acceptable outcome: machines without a usable GPU must
        // not fail this test. When a result is produced it must describe the
        // requested grid.
        if let Some(result) = try_compute_sdf_gpu(&mesh, &params) {
            assert_eq!(result.dims, params.dims);
        }
    }
}