mesh_gpu/
collision.rs

1//! GPU-accelerated collision detection for self-intersection testing.
2//!
3//! This module provides GPU-accelerated self-intersection detection using
4//! WGPU compute shaders. It uses AABB-based broad phase culling followed
5//! by exact triangle-triangle intersection tests using the Separating Axis
6//! Theorem (SAT).
7
8use bytemuck::{Pod, Zeroable};
9use tracing::{debug, info, warn};
10use wgpu::util::DeviceExt;
11use wgpu::{BindGroupLayout, ComputePipeline};
12
13use mesh_repair::Mesh;
14
15use crate::buffers::MeshBuffers;
16use crate::context::GpuContext;
17use crate::error::{GpuError, GpuResult};
18
19/// Shader source for collision detection.
20const COLLISION_SHADER: &str = include_str!("shaders/collision.wgsl");
21
/// Tunable inputs for a GPU self-intersection query.
#[derive(Debug, Clone)]
pub struct GpuCollisionParams {
    /// Upper limit on the number of intersecting pairs that are reported.
    ///
    /// A value of `0` means "unlimited": the limit is then derived from the
    /// triangle count (up to buffer size).
    pub max_pairs: usize,
    /// Tolerance forwarded to the shader for geometric comparisons.
    pub epsilon: f32,
    /// When `true`, triangle pairs that share vertices are not tested.
    pub skip_adjacent: bool,
}

impl Default for GpuCollisionParams {
    /// Returns the default configuration: at most 1000 reported pairs, an
    /// epsilon of `1e-7`, and adjacent triangles skipped.
    fn default() -> Self {
        let max_pairs = 1000;
        let epsilon = 1e-7;
        let skip_adjacent = true;

        Self {
            max_pairs,
            epsilon,
            skip_adjacent,
        }
    }
}
43
/// Outcome of a GPU self-intersection query.
#[derive(Debug)]
pub struct GpuCollisionResult {
    /// `true` when at least one intersecting pair was counted.
    pub has_intersections: bool,
    /// Number of intersecting triangle pairs counted on the GPU.
    ///
    /// NOTE(review): this is the raw counter value, so it may exceed
    /// `intersecting_pairs.len()` if the shader keeps counting past the pair
    /// limit — confirm against the shader.
    pub intersection_count: usize,
    /// The reported pairs, each as `(face_idx_a, face_idx_b)`.
    pub intersecting_pairs: Vec<(u32, u32)>,
    /// `true` when the pair limit was reached, meaning the list may be
    /// incomplete.
    pub truncated: bool,
    /// Wall-clock duration of the query, in milliseconds.
    pub compute_time_ms: f64,
}
58
59/// Uniform parameters for the shader.
60#[repr(C)]
61#[derive(Clone, Copy, Debug, Pod, Zeroable)]
62struct ShaderCollisionParams {
63    triangle_count: u32,
64    max_pairs: u32,
65    epsilon: f32,
66    skip_adjacent: u32,
67}
68
69/// AABB structure (matches shader).
70#[repr(C)]
71#[derive(Clone, Copy, Debug, Pod, Zeroable)]
72struct GpuAABB {
73    min: [f32; 3],
74    _padding1: f32,
75    max: [f32; 3],
76    _padding2: f32,
77}
78
79/// Intersection pair (matches shader).
80#[repr(C)]
81#[derive(Clone, Copy, Debug, Pod, Zeroable)]
82struct GpuIntersectionPair {
83    tri_a: u32,
84    tri_b: u32,
85}
86
87/// Pipeline for GPU collision detection.
88pub struct CollisionPipeline {
89    aabb_pipeline: ComputePipeline,
90    test_pipeline: ComputePipeline,
91    bind_group_layout: BindGroupLayout,
92}
93
94impl CollisionPipeline {
95    /// Create a new collision detection pipeline.
96    pub fn new(ctx: &GpuContext) -> GpuResult<Self> {
97        debug!("Creating collision detection compute pipeline");
98
99        // Compile shader
100        let shader = ctx
101            .device
102            .create_shader_module(wgpu::ShaderModuleDescriptor {
103                label: Some("collision"),
104                source: wgpu::ShaderSource::Wgsl(COLLISION_SHADER.into()),
105            });
106
107        // Create bind group layout
108        let bind_group_layout =
109            ctx.device
110                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
111                    label: Some("collision_bind_group_layout"),
112                    entries: &[
113                        // Triangles (read-only)
114                        wgpu::BindGroupLayoutEntry {
115                            binding: 0,
116                            visibility: wgpu::ShaderStages::COMPUTE,
117                            ty: wgpu::BindingType::Buffer {
118                                ty: wgpu::BufferBindingType::Storage { read_only: true },
119                                has_dynamic_offset: false,
120                                min_binding_size: None,
121                            },
122                            count: None,
123                        },
124                        // Params (uniform)
125                        wgpu::BindGroupLayoutEntry {
126                            binding: 1,
127                            visibility: wgpu::ShaderStages::COMPUTE,
128                            ty: wgpu::BindingType::Buffer {
129                                ty: wgpu::BufferBindingType::Uniform,
130                                has_dynamic_offset: false,
131                                min_binding_size: None,
132                            },
133                            count: None,
134                        },
135                        // AABBs (read-write)
136                        wgpu::BindGroupLayoutEntry {
137                            binding: 2,
138                            visibility: wgpu::ShaderStages::COMPUTE,
139                            ty: wgpu::BindingType::Buffer {
140                                ty: wgpu::BufferBindingType::Storage { read_only: false },
141                                has_dynamic_offset: false,
142                                min_binding_size: None,
143                            },
144                            count: None,
145                        },
146                        // Intersection pairs (read-write)
147                        wgpu::BindGroupLayoutEntry {
148                            binding: 3,
149                            visibility: wgpu::ShaderStages::COMPUTE,
150                            ty: wgpu::BindingType::Buffer {
151                                ty: wgpu::BufferBindingType::Storage { read_only: false },
152                                has_dynamic_offset: false,
153                                min_binding_size: None,
154                            },
155                            count: None,
156                        },
157                        // Pair count (atomic)
158                        wgpu::BindGroupLayoutEntry {
159                            binding: 4,
160                            visibility: wgpu::ShaderStages::COMPUTE,
161                            ty: wgpu::BindingType::Buffer {
162                                ty: wgpu::BufferBindingType::Storage { read_only: false },
163                                has_dynamic_offset: false,
164                                min_binding_size: None,
165                            },
166                            count: None,
167                        },
168                    ],
169                });
170
171        // Create pipeline layout
172        let pipeline_layout = ctx
173            .device
174            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
175                label: Some("collision_pipeline_layout"),
176                bind_group_layouts: &[&bind_group_layout],
177                push_constant_ranges: &[],
178            });
179
180        // Create AABB computation pipeline
181        let aabb_pipeline = ctx
182            .device
183            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
184                label: Some("collision_aabb_pipeline"),
185                layout: Some(&pipeline_layout),
186                module: &shader,
187                entry_point: Some("compute_aabbs"),
188                compilation_options: Default::default(),
189                cache: None,
190            });
191
192        // Create intersection test pipeline
193        let test_pipeline = ctx
194            .device
195            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
196                label: Some("collision_test_pipeline"),
197                layout: Some(&pipeline_layout),
198                module: &shader,
199                entry_point: Some("test_intersections"),
200                compilation_options: Default::default(),
201                cache: None,
202            });
203
204        Ok(Self {
205            aabb_pipeline,
206            test_pipeline,
207            bind_group_layout,
208        })
209    }
210
211    /// Detect self-intersections in a mesh.
212    pub fn detect(
213        &self,
214        ctx: &GpuContext,
215        mesh_buffers: &MeshBuffers,
216        params: &GpuCollisionParams,
217    ) -> GpuResult<GpuCollisionResult> {
218        let start = std::time::Instant::now();
219        let triangle_count = mesh_buffers.triangle_count as usize;
220
221        if triangle_count < 2 {
222            return Ok(GpuCollisionResult {
223                has_intersections: false,
224                intersection_count: 0,
225                intersecting_pairs: Vec::new(),
226                truncated: false,
227                compute_time_ms: 0.0,
228            });
229        }
230
231        let max_pairs = if params.max_pairs == 0 {
232            triangle_count * triangle_count / 2 // Upper bound
233        } else {
234            params.max_pairs
235        };
236
237        info!(
238            triangles = triangle_count,
239            max_pairs = max_pairs,
240            "Detecting self-intersections on GPU"
241        );
242
243        // Create uniform buffer
244        let shader_params = ShaderCollisionParams {
245            triangle_count: triangle_count as u32,
246            max_pairs: max_pairs as u32,
247            epsilon: params.epsilon,
248            skip_adjacent: if params.skip_adjacent { 1 } else { 0 },
249        };
250
251        let params_buffer = ctx
252            .device
253            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
254                label: Some("collision_params"),
255                contents: bytemuck::bytes_of(&shader_params),
256                usage: wgpu::BufferUsages::UNIFORM,
257            });
258
259        // Create AABB buffer
260        let aabb_size = triangle_count * std::mem::size_of::<GpuAABB>();
261        let aabb_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
262            label: Some("collision_aabbs"),
263            size: aabb_size as u64,
264            usage: wgpu::BufferUsages::STORAGE,
265            mapped_at_creation: false,
266        });
267
268        // Create intersection pairs buffer
269        let pairs_size = max_pairs * std::mem::size_of::<GpuIntersectionPair>();
270        let pairs_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
271            label: Some("collision_pairs"),
272            size: pairs_size as u64,
273            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
274            mapped_at_creation: false,
275        });
276
277        // Create pair count buffer
278        let count_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
279            label: Some("collision_count"),
280            size: std::mem::size_of::<u32>() as u64,
281            usage: wgpu::BufferUsages::STORAGE
282                | wgpu::BufferUsages::COPY_SRC
283                | wgpu::BufferUsages::COPY_DST,
284            mapped_at_creation: false,
285        });
286
287        // Initialize count to 0
288        ctx.queue
289            .write_buffer(&count_buffer, 0, bytemuck::bytes_of(&0u32));
290
291        // Create bind group
292        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
293            label: Some("collision_bind_group"),
294            layout: &self.bind_group_layout,
295            entries: &[
296                wgpu::BindGroupEntry {
297                    binding: 0,
298                    resource: mesh_buffers.triangles.as_entire_binding(),
299                },
300                wgpu::BindGroupEntry {
301                    binding: 1,
302                    resource: params_buffer.as_entire_binding(),
303                },
304                wgpu::BindGroupEntry {
305                    binding: 2,
306                    resource: aabb_buffer.as_entire_binding(),
307                },
308                wgpu::BindGroupEntry {
309                    binding: 3,
310                    resource: pairs_buffer.as_entire_binding(),
311                },
312                wgpu::BindGroupEntry {
313                    binding: 4,
314                    resource: count_buffer.as_entire_binding(),
315                },
316            ],
317        });
318
319        // Create command encoder
320        let mut encoder = ctx
321            .device
322            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
323                label: Some("collision_encoder"),
324            });
325
326        let workgroups = (triangle_count as u32).div_ceil(256);
327
328        // Pass 1: Compute AABBs
329        {
330            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
331                label: Some("collision_aabb_pass"),
332                timestamp_writes: None,
333            });
334            compute_pass.set_pipeline(&self.aabb_pipeline);
335            compute_pass.set_bind_group(0, &bind_group, &[]);
336            compute_pass.dispatch_workgroups(workgroups, 1, 1);
337        }
338
339        // Pass 2: Test intersections
340        {
341            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
342                label: Some("collision_test_pass"),
343                timestamp_writes: None,
344            });
345            compute_pass.set_pipeline(&self.test_pipeline);
346            compute_pass.set_bind_group(0, &bind_group, &[]);
347            compute_pass.dispatch_workgroups(workgroups, 1, 1);
348        }
349
350        // Submit commands
351        ctx.queue.submit([encoder.finish()]);
352
353        // Download results
354        let pair_count = self.download_count(ctx, &count_buffer)?;
355        let pairs = self.download_pairs(ctx, &pairs_buffer, pair_count.min(max_pairs as u32))?;
356
357        let compute_time_ms = start.elapsed().as_secs_f64() * 1000.0;
358
359        let intersecting_pairs: Vec<(u32, u32)> =
360            pairs.iter().map(|p| (p.tri_a, p.tri_b)).collect();
361
362        info!(
363            pairs_found = pair_count,
364            time_ms = compute_time_ms,
365            "Collision detection complete"
366        );
367
368        Ok(GpuCollisionResult {
369            has_intersections: pair_count > 0,
370            intersection_count: pair_count as usize,
371            intersecting_pairs,
372            truncated: pair_count as usize >= max_pairs,
373            compute_time_ms,
374        })
375    }
376
377    fn download_count(&self, ctx: &GpuContext, buffer: &wgpu::Buffer) -> GpuResult<u32> {
378        let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
379            label: Some("count_staging"),
380            size: std::mem::size_of::<u32>() as u64,
381            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
382            mapped_at_creation: false,
383        });
384
385        let mut encoder = ctx
386            .device
387            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
388        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, std::mem::size_of::<u32>() as u64);
389        ctx.queue.submit([encoder.finish()]);
390
391        let slice = staging.slice(..);
392        let (tx, rx) = std::sync::mpsc::channel();
393        slice.map_async(wgpu::MapMode::Read, move |result| {
394            tx.send(result).unwrap();
395        });
396        ctx.device.poll(wgpu::Maintain::Wait);
397
398        rx.recv()
399            .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
400            .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;
401
402        let data = slice.get_mapped_range();
403        let count = *bytemuck::from_bytes::<u32>(&data);
404        drop(data);
405        staging.unmap();
406
407        Ok(count)
408    }
409
410    fn download_pairs(
411        &self,
412        ctx: &GpuContext,
413        buffer: &wgpu::Buffer,
414        count: u32,
415    ) -> GpuResult<Vec<GpuIntersectionPair>> {
416        if count == 0 {
417            return Ok(Vec::new());
418        }
419
420        let size = (count as usize) * std::mem::size_of::<GpuIntersectionPair>();
421        let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
422            label: Some("pairs_staging"),
423            size: size as u64,
424            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
425            mapped_at_creation: false,
426        });
427
428        let mut encoder = ctx
429            .device
430            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
431        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
432        ctx.queue.submit([encoder.finish()]);
433
434        let slice = staging.slice(..);
435        let (tx, rx) = std::sync::mpsc::channel();
436        slice.map_async(wgpu::MapMode::Read, move |result| {
437            tx.send(result).unwrap();
438        });
439        ctx.device.poll(wgpu::Maintain::Wait);
440
441        rx.recv()
442            .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
443            .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;
444
445        let data = slice.get_mapped_range();
446        let pairs: Vec<GpuIntersectionPair> = bytemuck::cast_slice(&data).to_vec();
447        drop(data);
448        staging.unmap();
449
450        Ok(pairs)
451    }
452}
453
454/// Detect self-intersections in a mesh on GPU.
455pub fn detect_self_intersections_gpu(
456    mesh: &Mesh,
457    params: &GpuCollisionParams,
458) -> GpuResult<GpuCollisionResult> {
459    let ctx = GpuContext::try_get()?;
460
461    // Upload mesh to GPU
462    let mesh_buffers = MeshBuffers::from_mesh(ctx, mesh)?;
463
464    let pipeline = CollisionPipeline::new(ctx)?;
465    pipeline.detect(ctx, &mesh_buffers, params)
466}
467
468/// Try to detect self-intersections on GPU, returning None if unavailable.
469pub fn try_detect_self_intersections_gpu(
470    mesh: &Mesh,
471    params: &GpuCollisionParams,
472) -> Option<GpuCollisionResult> {
473    match detect_self_intersections_gpu(mesh, params) {
474        Ok(result) => Some(result),
475        Err(GpuError::NotAvailable) => {
476            debug!("GPU not available for collision detection");
477            None
478        }
479        Err(e) => {
480            warn!("GPU collision detection failed: {}", e);
481            None
482        }
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use mesh_repair::Vertex;

    /// Builds a mesh consisting of a single triangle in the z = 0 plane.
    fn create_simple_mesh() -> Mesh {
        let mut mesh = Mesh::new();

        // Single triangle
        mesh.vertices.push(Vertex::from_coords(0.0, 0.0, 0.0));
        mesh.vertices.push(Vertex::from_coords(1.0, 0.0, 0.0));
        mesh.vertices.push(Vertex::from_coords(0.0, 1.0, 0.0));
        mesh.faces.push([0, 1, 2]);

        mesh
    }

    #[test]
    fn test_gpu_collision_params_default() {
        let params = GpuCollisionParams::default();
        assert!(params.skip_adjacent);
        assert_eq!(params.max_pairs, 1000);
        assert!(params.epsilon > 0.0);
    }

    #[test]
    fn test_try_detect_self_intersections_gpu() {
        let mesh = create_simple_mesh();
        let params = GpuCollisionParams::default();

        // Without a usable GPU the call yields `None`, which is acceptable.
        // With one, a single triangle has no second triangle to intersect, so
        // the result must be empty.
        if let Some(result) = try_detect_self_intersections_gpu(&mesh, &params) {
            assert!(!result.has_intersections);
            assert_eq!(result.intersection_count, 0);
            assert!(result.intersecting_pairs.is_empty());
            assert!(!result.truncated);
        }
    }
}
518}