1use bytemuck::{Pod, Zeroable};
8use tracing::{debug, info, warn};
9use wgpu::util::DeviceExt;
10use wgpu::{BindGroupLayout, ComputePipeline};
11
12use mesh_repair::Mesh;
13
14use crate::context::GpuContext;
15use crate::error::{GpuError, GpuResult};
16
17const SURFACE_NETS_SHADER: &str = include_str!("shaders/surface_nets.wgsl");
19
/// Parameters describing the SDF sample grid handed to the GPU Surface Nets
/// extractor.
#[derive(Debug, Clone)]
pub struct GpuSurfaceNetsParams {
    /// Number of SDF samples along each axis (x, y, z).
    pub dims: [usize; 3],
    /// World-space position of the first sample.
    pub origin: [f32; 3],
    /// Spacing between adjacent samples along every axis.
    pub voxel_size: f32,
    /// SDF value at which the surface is extracted (0.0 by default).
    pub iso_value: f32,
}
32
33impl Default for GpuSurfaceNetsParams {
34 fn default() -> Self {
35 Self {
36 dims: [0, 0, 0],
37 origin: [0.0, 0.0, 0.0],
38 voxel_size: 1.0,
39 iso_value: 0.0,
40 }
41 }
42}
43
/// Output of a GPU Surface Nets extraction run.
#[derive(Debug)]
pub struct GpuSurfaceNetsResult {
    /// Extracted mesh (vertex positions and normals; see `build_mesh`).
    pub mesh: Mesh,
    /// Number of grid cells the iso-surface passes through.
    pub active_cells: usize,
    /// Number of vertices generated on the GPU.
    pub vertex_count: usize,
    /// Wall-clock time of the whole extraction (dispatch + readback), in
    /// milliseconds.
    pub compute_time_ms: f64,
}
56
57#[repr(C)]
59#[derive(Clone, Copy, Debug, Pod, Zeroable)]
60struct GpuOutputVertex {
61 position: [f32; 4], normal: [f32; 4], }
64
/// Grid parameters laid out for the shader's uniform buffer (binding 1).
///
/// Fields are widened to 16-byte multiples to satisfy WGSL uniform-buffer
/// alignment; the fourth component of `origin`/`dims` and `_padding` carry
/// no data.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct ShaderGridParams {
    // World-space grid origin in xyz; w is padding (always 0.0).
    origin: [f32; 4],
    // Sample counts in xyz; w is padding (always 0).
    dims: [u32; 4],
    voxel_size: f32,
    iso_value: f32,
    // Pads the struct to a 16-byte multiple.
    _padding: [f32; 2],
}
75
/// Compute pipelines and shared bind group layout for GPU Surface Nets
/// isosurface extraction. Built once via [`SurfaceNetsPipeline::new`] and
/// reusable across [`SurfaceNetsPipeline::extract`] calls.
pub struct SurfaceNetsPipeline {
    // Pass 1: shader entry `identify_active_cells`.
    identify_pipeline: ComputePipeline,
    // Pass 2: shader entry `generate_vertices`.
    generate_pipeline: ComputePipeline,
    // Shared by both passes; binding slots documented in `new`.
    bind_group_layout: BindGroupLayout,
}
82
impl SurfaceNetsPipeline {
    /// Compiles the embedded Surface Nets WGSL shader and builds both compute
    /// pipelines plus the bind group layout they share.
    pub fn new(ctx: &GpuContext) -> GpuResult<Self> {
        debug!("Creating Surface Nets compute pipeline");

        let shader = ctx
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("surface_nets"),
                source: wgpu::ShaderSource::Wgsl(SURFACE_NETS_SHADER.into()),
            });

        // Binding slots — must match both the bind group assembled in
        // `extract` and the declarations in surface_nets.wgsl:
        //   0 = SDF samples (read-only storage)
        //   1 = grid parameters (uniform)
        //   2 = active-cell table (read-write storage)
        //   3 = output vertices (read-write storage)
        //   4 = vertex counter (read-write storage)
        let bind_group_layout =
            ctx.device
                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                    label: Some("surface_nets_bind_group_layout"),
                    entries: &[
                        // 0: input SDF values
                        wgpu::BindGroupLayoutEntry {
                            binding: 0,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Storage { read_only: true },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                        // 1: grid parameters (ShaderGridParams)
                        wgpu::BindGroupLayoutEntry {
                            binding: 1,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Uniform,
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                        // 2: per-cell activity flags
                        wgpu::BindGroupLayoutEntry {
                            binding: 2,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Storage { read_only: false },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                        // 3: generated vertices (GpuOutputVertex)
                        wgpu::BindGroupLayoutEntry {
                            binding: 3,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Storage { read_only: false },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                        // 4: vertex count (single u32)
                        wgpu::BindGroupLayoutEntry {
                            binding: 4,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Storage { read_only: false },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                    ],
                });

        let pipeline_layout = ctx
            .device
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("surface_nets_pipeline_layout"),
                bind_group_layouts: &[&bind_group_layout],
                push_constant_ranges: &[],
            });

        // Pass 1: flags the cells the iso-surface crosses.
        let identify_pipeline =
            ctx.device
                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some("surface_nets_identify_pipeline"),
                    layout: Some(&pipeline_layout),
                    module: &shader,
                    entry_point: Some("identify_active_cells"),
                    compilation_options: Default::default(),
                    cache: None,
                });

        // Pass 2: emits the vertices.
        let generate_pipeline =
            ctx.device
                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some("surface_nets_generate_pipeline"),
                    layout: Some(&pipeline_layout),
                    module: &shader,
                    entry_point: Some("generate_vertices"),
                    compilation_options: Default::default(),
                    cache: None,
                });

        Ok(Self {
            identify_pipeline,
            generate_pipeline,
            bind_group_layout,
        })
    }

    /// Runs the two-pass Surface Nets extraction over `sdf_buffer` and reads
    /// the resulting vertices back to the CPU.
    ///
    /// `sdf_buffer` must hold `dims[0] * dims[1] * dims[2]` f32 samples with
    /// STORAGE usage (as uploaded by [`extract_isosurface_gpu`]).
    ///
    /// Returns an empty result without touching the GPU when any axis has
    /// fewer than two samples (zero cells).
    pub fn extract(
        &self,
        ctx: &GpuContext,
        sdf_buffer: &wgpu::Buffer,
        params: &GpuSurfaceNetsParams,
    ) -> GpuResult<GpuSurfaceNetsResult> {
        let start = std::time::Instant::now();

        // One cell between each pair of adjacent samples per axis;
        // saturating_sub keeps a zero-sample axis from underflowing.
        let cells = [
            params.dims[0].saturating_sub(1),
            params.dims[1].saturating_sub(1),
            params.dims[2].saturating_sub(1),
        ];
        let total_cells = cells[0] * cells[1] * cells[2];

        if total_cells == 0 {
            return Ok(GpuSurfaceNetsResult {
                mesh: Mesh::new(),
                active_cells: 0,
                vertex_count: 0,
                compute_time_ms: 0.0,
            });
        }

        info!(
            dims = ?params.dims,
            cells = total_cells,
            "Extracting isosurface on GPU"
        );

        // Pack grid parameters into the shader's uniform layout (w components
        // are alignment padding — see ShaderGridParams).
        let grid_params = ShaderGridParams {
            origin: [params.origin[0], params.origin[1], params.origin[2], 0.0],
            dims: [
                params.dims[0] as u32,
                params.dims[1] as u32,
                params.dims[2] as u32,
                0,
            ],
            voxel_size: params.voxel_size,
            iso_value: params.iso_value,
            _padding: [0.0, 0.0],
        };

        let params_buffer = ctx
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("surface_nets_params"),
                contents: bytemuck::bytes_of(&grid_params),
                usage: wgpu::BufferUsages::UNIFORM,
            });

        // One u32 flag per cell, written by the identify pass.
        let active_cells_size = total_cells * std::mem::size_of::<u32>();
        let active_cells_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("surface_nets_active_cells"),
            size: active_cells_size as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // Worst case: every cell is active and emits one vertex.
        let vertices_size = total_cells * std::mem::size_of::<GpuOutputVertex>();
        let vertices_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("surface_nets_vertices"),
            size: vertices_size as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // Single u32 vertex counter; COPY_DST so it can be zeroed below.
        let count_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("surface_nets_count"),
            size: std::mem::size_of::<u32>() as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Reset the counter before dispatch.
        ctx.queue
            .write_buffer(&count_buffer, 0, bytemuck::bytes_of(&0u32));

        // Slot assignments mirror the layout built in `new`.
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("surface_nets_bind_group"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: sdf_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: params_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: active_cells_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: vertices_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 4,
                    resource: count_buffer.as_entire_binding(),
                },
            ],
        });

        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("surface_nets_encoder"),
            });

        // One invocation per cell at a workgroup size of 256 (must match the
        // @workgroup_size in the shader — TODO confirm against
        // surface_nets.wgsl).
        let workgroups = (total_cells as u32).div_ceil(256);

        // Pass 1: identify active cells.
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("surface_nets_identify_pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(&self.identify_pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            compute_pass.dispatch_workgroups(workgroups, 1, 1);
        }

        // Pass 2: generate one vertex per active cell.
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("surface_nets_generate_pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(&self.generate_pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            compute_pass.dispatch_workgroups(workgroups, 1, 1);
        }

        ctx.queue.submit([encoder.finish()]);

        // Read back the counter, then the (worst-case sized) vertex buffer.
        // NOTE(review): the full buffer is downloaded even when only
        // `vertex_count` entries are valid — if the shader compacts vertices
        // at the front via the atomic counter, this could be trimmed.
        let vertex_count = self.download_count(ctx, &count_buffer)?;
        let vertices = self.download_vertices(ctx, &vertices_buffer, total_cells)?;

        let mesh = self.build_mesh(&vertices, vertex_count as usize, &cells, params);

        let compute_time_ms = start.elapsed().as_secs_f64() * 1000.0;
        info!(
            vertices = vertex_count,
            faces = mesh.faces.len(),
            time_ms = compute_time_ms,
            "Isosurface extraction complete"
        );

        Ok(GpuSurfaceNetsResult {
            mesh,
            // Surface Nets emits one vertex per active cell, so the two
            // counts coincide — presumably; verify against the shader.
            active_cells: vertex_count as usize, vertex_count: vertex_count as usize,
            compute_time_ms,
        })
    }

    /// Copies the single-u32 counter into a MAP_READ staging buffer and
    /// blocks until the value is available on the CPU.
    ///
    /// # Errors
    /// Returns [`GpuError::BufferMapping`] if the map callback channel closes
    /// or the mapping itself fails.
    fn download_count(&self, ctx: &GpuContext, buffer: &wgpu::Buffer) -> GpuResult<u32> {
        let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("count_staging"),
            size: std::mem::size_of::<u32>() as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, std::mem::size_of::<u32>() as u64);
        ctx.queue.submit([encoder.finish()]);

        // Standard wgpu readback: request the map, then poll with Wait so the
        // callback fires synchronously before we recv().
        let slice = staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).unwrap();
        });
        ctx.device.poll(wgpu::Maintain::Wait);

        rx.recv()
            .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
            .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;

        // The mapped range must be dropped before unmap().
        let data = slice.get_mapped_range();
        let count = *bytemuck::from_bytes::<u32>(&data);
        drop(data);
        staging.unmap();

        Ok(count)
    }

    /// Copies `count` vertices from the GPU output buffer into a staging
    /// buffer, maps it, and returns the data as a `Vec`.
    ///
    /// # Errors
    /// Returns [`GpuError::BufferMapping`] if the map callback channel closes
    /// or the mapping itself fails.
    fn download_vertices(
        &self,
        ctx: &GpuContext,
        buffer: &wgpu::Buffer,
        count: usize,
    ) -> GpuResult<Vec<GpuOutputVertex>> {
        let size = count * std::mem::size_of::<GpuOutputVertex>();
        let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("vertices_staging"),
            size: size as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
        ctx.queue.submit([encoder.finish()]);

        // Same map/poll/recv readback protocol as download_count.
        let slice = staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).unwrap();
        });
        ctx.device.poll(wgpu::Maintain::Wait);

        rx.recv()
            .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
            .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;

        let data = slice.get_mapped_range();
        let vertices: Vec<GpuOutputVertex> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        staging.unmap();

        Ok(vertices)
    }

    /// Converts the first `vertex_count` GPU vertices into a `Mesh`.
    ///
    /// Entries beyond `vertex_count` are uninitialized scratch space and are
    /// skipped. Only vertices (with normals) are produced here — no faces are
    /// built, and `_cells` / `_params` are currently unused.
    fn build_mesh(
        &self,
        gpu_vertices: &[GpuOutputVertex],
        vertex_count: usize,
        _cells: &[usize; 3],
        _params: &GpuSurfaceNetsParams,
    ) -> Mesh {
        use mesh_repair::Vertex;

        let mut mesh = Mesh::new();

        for v in gpu_vertices.iter().take(vertex_count) {
            // Only xyz of each vec4 carries data; w is alignment padding.
            let mut vertex = Vertex::from_coords(
                v.position[0] as f64,
                v.position[1] as f64,
                v.position[2] as f64,
            );
            vertex.normal = Some(nalgebra::Vector3::new(
                v.normal[0] as f64,
                v.normal[1] as f64,
                v.normal[2] as f64,
            ));
            mesh.vertices.push(vertex);
        }

        mesh
    }
}
488
489pub fn extract_isosurface_gpu(
491 sdf_values: &[f32],
492 params: &GpuSurfaceNetsParams,
493) -> GpuResult<GpuSurfaceNetsResult> {
494 let ctx = GpuContext::try_get()?;
495
496 let sdf_buffer = ctx
498 .device
499 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
500 label: Some("surface_nets_sdf"),
501 contents: bytemuck::cast_slice(sdf_values),
502 usage: wgpu::BufferUsages::STORAGE,
503 });
504
505 let pipeline = SurfaceNetsPipeline::new(ctx)?;
506 pipeline.extract(ctx, &sdf_buffer, params)
507}
508
509pub fn try_extract_isosurface_gpu(
511 sdf_values: &[f32],
512 params: &GpuSurfaceNetsParams,
513) -> Option<GpuSurfaceNetsResult> {
514 match extract_isosurface_gpu(sdf_values, params) {
515 Ok(result) => Some(result),
516 Err(GpuError::NotAvailable) => {
517 debug!("GPU not available for isosurface extraction");
518 None
519 }
520 Err(e) => {
521 warn!("GPU isosurface extraction failed: {}", e);
522 None
523 }
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_surface_nets_params_default() {
        let defaults = GpuSurfaceNetsParams::default();
        assert_eq!(defaults.iso_value, 0.0);
        assert_eq!(defaults.voxel_size, 1.0);
    }

    #[test]
    fn test_try_extract_isosurface_gpu() {
        // 3x3x3 sample grid, everything outside (positive SDF) except the
        // centre sample at (1, 1, 1) => index 1 + 1*3 + 1*9 = 13.
        let mut sdf = vec![1.0f32; 27];
        sdf[13] = -1.0;

        let params = GpuSurfaceNetsParams {
            dims: [3, 3, 3],
            origin: [0.0, 0.0, 0.0],
            voxel_size: 1.0,
            iso_value: 0.0,
        };

        // Must succeed or degrade gracefully to None; never panic.
        let _result = try_extract_isosurface_gpu(&sdf, &params);
    }
}