1use bytemuck::{Pod, Zeroable};
9use tracing::{debug, info, warn};
10use wgpu::util::DeviceExt;
11use wgpu::{BindGroupLayout, ComputePipeline};
12
13use mesh_repair::Mesh;
14
15use crate::buffers::MeshBuffers;
16use crate::context::GpuContext;
17use crate::error::{GpuError, GpuResult};
18
/// WGSL source of the collision compute shader, embedded at compile time from
/// `shaders/collision.wgsl`.
///
/// `CollisionPipeline::new` builds two compute pipelines from this module, using the
/// `compute_aabbs` and `test_intersections` entry points.
const COLLISION_SHADER: &str = include_str!("shaders/collision.wgsl");
21
/// Tuning parameters for GPU self-intersection detection.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuCollisionParams {
    /// Capacity of the GPU output buffer, in triangle pairs.
    ///
    /// A value of `0` lets the detector derive the capacity from the triangle count
    /// instead of using a fixed cap.
    pub max_pairs: usize,
    /// Numeric tolerance forwarded unchanged to the collision shader.
    ///
    /// NOTE(review): how the tolerance is applied is defined in the WGSL shader, which is
    /// not part of this file.
    pub epsilon: f32,
    /// Forwarded to the shader as a `0`/`1` flag.
    ///
    /// NOTE(review): presumably makes the shader ignore triangle pairs that share
    /// vertices or edges — confirm against the shader.
    pub skip_adjacent: bool,
}

impl Default for GpuCollisionParams {
    /// Returns a cap of 1000 pairs, a tolerance of `1e-7`, and `skip_adjacent` enabled.
    fn default() -> Self {
        Self {
            max_pairs: 1000,
            epsilon: 1e-7,
            skip_adjacent: true,
        }
    }
}
43
/// Outcome of one GPU self-intersection pass.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuCollisionResult {
    /// `true` when the GPU pair counter is non-zero.
    pub has_intersections: bool,
    /// Raw value of the GPU pair counter.
    ///
    /// This can be larger than `intersecting_pairs.len()`, because only as many pairs as
    /// fit in the output buffer are read back.
    pub intersection_count: usize,
    /// Triangle index pairs read back from the GPU, at most the buffer capacity.
    pub intersecting_pairs: Vec<(u32, u32)>,
    /// `true` when the counter reached or exceeded the output buffer capacity, meaning
    /// `intersecting_pairs` may be incomplete.
    pub truncated: bool,
    /// Wall-clock duration of the detection call in milliseconds, including buffer
    /// setup and readback. Reported as `0.0` when the mesh has fewer than two triangles.
    pub compute_time_ms: f64,
}
58
59#[repr(C)]
61#[derive(Clone, Copy, Debug, Pod, Zeroable)]
62struct ShaderCollisionParams {
63 triangle_count: u32,
64 max_pairs: u32,
65 epsilon: f32,
66 skip_adjacent: u32,
67}
68
69#[repr(C)]
71#[derive(Clone, Copy, Debug, Pod, Zeroable)]
72struct GpuAABB {
73 min: [f32; 3],
74 _padding1: f32,
75 max: [f32; 3],
76 _padding2: f32,
77}
78
79#[repr(C)]
81#[derive(Clone, Copy, Debug, Pod, Zeroable)]
82struct GpuIntersectionPair {
83 tri_a: u32,
84 tri_b: u32,
85}
86
87pub struct CollisionPipeline {
89 aabb_pipeline: ComputePipeline,
90 test_pipeline: ComputePipeline,
91 bind_group_layout: BindGroupLayout,
92}
93
94impl CollisionPipeline {
95 pub fn new(ctx: &GpuContext) -> GpuResult<Self> {
97 debug!("Creating collision detection compute pipeline");
98
99 let shader = ctx
101 .device
102 .create_shader_module(wgpu::ShaderModuleDescriptor {
103 label: Some("collision"),
104 source: wgpu::ShaderSource::Wgsl(COLLISION_SHADER.into()),
105 });
106
107 let bind_group_layout =
109 ctx.device
110 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
111 label: Some("collision_bind_group_layout"),
112 entries: &[
113 wgpu::BindGroupLayoutEntry {
115 binding: 0,
116 visibility: wgpu::ShaderStages::COMPUTE,
117 ty: wgpu::BindingType::Buffer {
118 ty: wgpu::BufferBindingType::Storage { read_only: true },
119 has_dynamic_offset: false,
120 min_binding_size: None,
121 },
122 count: None,
123 },
124 wgpu::BindGroupLayoutEntry {
126 binding: 1,
127 visibility: wgpu::ShaderStages::COMPUTE,
128 ty: wgpu::BindingType::Buffer {
129 ty: wgpu::BufferBindingType::Uniform,
130 has_dynamic_offset: false,
131 min_binding_size: None,
132 },
133 count: None,
134 },
135 wgpu::BindGroupLayoutEntry {
137 binding: 2,
138 visibility: wgpu::ShaderStages::COMPUTE,
139 ty: wgpu::BindingType::Buffer {
140 ty: wgpu::BufferBindingType::Storage { read_only: false },
141 has_dynamic_offset: false,
142 min_binding_size: None,
143 },
144 count: None,
145 },
146 wgpu::BindGroupLayoutEntry {
148 binding: 3,
149 visibility: wgpu::ShaderStages::COMPUTE,
150 ty: wgpu::BindingType::Buffer {
151 ty: wgpu::BufferBindingType::Storage { read_only: false },
152 has_dynamic_offset: false,
153 min_binding_size: None,
154 },
155 count: None,
156 },
157 wgpu::BindGroupLayoutEntry {
159 binding: 4,
160 visibility: wgpu::ShaderStages::COMPUTE,
161 ty: wgpu::BindingType::Buffer {
162 ty: wgpu::BufferBindingType::Storage { read_only: false },
163 has_dynamic_offset: false,
164 min_binding_size: None,
165 },
166 count: None,
167 },
168 ],
169 });
170
171 let pipeline_layout = ctx
173 .device
174 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
175 label: Some("collision_pipeline_layout"),
176 bind_group_layouts: &[&bind_group_layout],
177 push_constant_ranges: &[],
178 });
179
180 let aabb_pipeline = ctx
182 .device
183 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
184 label: Some("collision_aabb_pipeline"),
185 layout: Some(&pipeline_layout),
186 module: &shader,
187 entry_point: Some("compute_aabbs"),
188 compilation_options: Default::default(),
189 cache: None,
190 });
191
192 let test_pipeline = ctx
194 .device
195 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
196 label: Some("collision_test_pipeline"),
197 layout: Some(&pipeline_layout),
198 module: &shader,
199 entry_point: Some("test_intersections"),
200 compilation_options: Default::default(),
201 cache: None,
202 });
203
204 Ok(Self {
205 aabb_pipeline,
206 test_pipeline,
207 bind_group_layout,
208 })
209 }
210
211 pub fn detect(
213 &self,
214 ctx: &GpuContext,
215 mesh_buffers: &MeshBuffers,
216 params: &GpuCollisionParams,
217 ) -> GpuResult<GpuCollisionResult> {
218 let start = std::time::Instant::now();
219 let triangle_count = mesh_buffers.triangle_count as usize;
220
221 if triangle_count < 2 {
222 return Ok(GpuCollisionResult {
223 has_intersections: false,
224 intersection_count: 0,
225 intersecting_pairs: Vec::new(),
226 truncated: false,
227 compute_time_ms: 0.0,
228 });
229 }
230
231 let max_pairs = if params.max_pairs == 0 {
232 triangle_count * triangle_count / 2 } else {
234 params.max_pairs
235 };
236
237 info!(
238 triangles = triangle_count,
239 max_pairs = max_pairs,
240 "Detecting self-intersections on GPU"
241 );
242
243 let shader_params = ShaderCollisionParams {
245 triangle_count: triangle_count as u32,
246 max_pairs: max_pairs as u32,
247 epsilon: params.epsilon,
248 skip_adjacent: if params.skip_adjacent { 1 } else { 0 },
249 };
250
251 let params_buffer = ctx
252 .device
253 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
254 label: Some("collision_params"),
255 contents: bytemuck::bytes_of(&shader_params),
256 usage: wgpu::BufferUsages::UNIFORM,
257 });
258
259 let aabb_size = triangle_count * std::mem::size_of::<GpuAABB>();
261 let aabb_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
262 label: Some("collision_aabbs"),
263 size: aabb_size as u64,
264 usage: wgpu::BufferUsages::STORAGE,
265 mapped_at_creation: false,
266 });
267
268 let pairs_size = max_pairs * std::mem::size_of::<GpuIntersectionPair>();
270 let pairs_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
271 label: Some("collision_pairs"),
272 size: pairs_size as u64,
273 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
274 mapped_at_creation: false,
275 });
276
277 let count_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
279 label: Some("collision_count"),
280 size: std::mem::size_of::<u32>() as u64,
281 usage: wgpu::BufferUsages::STORAGE
282 | wgpu::BufferUsages::COPY_SRC
283 | wgpu::BufferUsages::COPY_DST,
284 mapped_at_creation: false,
285 });
286
287 ctx.queue
289 .write_buffer(&count_buffer, 0, bytemuck::bytes_of(&0u32));
290
291 let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
293 label: Some("collision_bind_group"),
294 layout: &self.bind_group_layout,
295 entries: &[
296 wgpu::BindGroupEntry {
297 binding: 0,
298 resource: mesh_buffers.triangles.as_entire_binding(),
299 },
300 wgpu::BindGroupEntry {
301 binding: 1,
302 resource: params_buffer.as_entire_binding(),
303 },
304 wgpu::BindGroupEntry {
305 binding: 2,
306 resource: aabb_buffer.as_entire_binding(),
307 },
308 wgpu::BindGroupEntry {
309 binding: 3,
310 resource: pairs_buffer.as_entire_binding(),
311 },
312 wgpu::BindGroupEntry {
313 binding: 4,
314 resource: count_buffer.as_entire_binding(),
315 },
316 ],
317 });
318
319 let mut encoder = ctx
321 .device
322 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
323 label: Some("collision_encoder"),
324 });
325
326 let workgroups = (triangle_count as u32).div_ceil(256);
327
328 {
330 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
331 label: Some("collision_aabb_pass"),
332 timestamp_writes: None,
333 });
334 compute_pass.set_pipeline(&self.aabb_pipeline);
335 compute_pass.set_bind_group(0, &bind_group, &[]);
336 compute_pass.dispatch_workgroups(workgroups, 1, 1);
337 }
338
339 {
341 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
342 label: Some("collision_test_pass"),
343 timestamp_writes: None,
344 });
345 compute_pass.set_pipeline(&self.test_pipeline);
346 compute_pass.set_bind_group(0, &bind_group, &[]);
347 compute_pass.dispatch_workgroups(workgroups, 1, 1);
348 }
349
350 ctx.queue.submit([encoder.finish()]);
352
353 let pair_count = self.download_count(ctx, &count_buffer)?;
355 let pairs = self.download_pairs(ctx, &pairs_buffer, pair_count.min(max_pairs as u32))?;
356
357 let compute_time_ms = start.elapsed().as_secs_f64() * 1000.0;
358
359 let intersecting_pairs: Vec<(u32, u32)> =
360 pairs.iter().map(|p| (p.tri_a, p.tri_b)).collect();
361
362 info!(
363 pairs_found = pair_count,
364 time_ms = compute_time_ms,
365 "Collision detection complete"
366 );
367
368 Ok(GpuCollisionResult {
369 has_intersections: pair_count > 0,
370 intersection_count: pair_count as usize,
371 intersecting_pairs,
372 truncated: pair_count as usize >= max_pairs,
373 compute_time_ms,
374 })
375 }
376
377 fn download_count(&self, ctx: &GpuContext, buffer: &wgpu::Buffer) -> GpuResult<u32> {
378 let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
379 label: Some("count_staging"),
380 size: std::mem::size_of::<u32>() as u64,
381 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
382 mapped_at_creation: false,
383 });
384
385 let mut encoder = ctx
386 .device
387 .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
388 encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, std::mem::size_of::<u32>() as u64);
389 ctx.queue.submit([encoder.finish()]);
390
391 let slice = staging.slice(..);
392 let (tx, rx) = std::sync::mpsc::channel();
393 slice.map_async(wgpu::MapMode::Read, move |result| {
394 tx.send(result).unwrap();
395 });
396 ctx.device.poll(wgpu::Maintain::Wait);
397
398 rx.recv()
399 .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
400 .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;
401
402 let data = slice.get_mapped_range();
403 let count = *bytemuck::from_bytes::<u32>(&data);
404 drop(data);
405 staging.unmap();
406
407 Ok(count)
408 }
409
410 fn download_pairs(
411 &self,
412 ctx: &GpuContext,
413 buffer: &wgpu::Buffer,
414 count: u32,
415 ) -> GpuResult<Vec<GpuIntersectionPair>> {
416 if count == 0 {
417 return Ok(Vec::new());
418 }
419
420 let size = (count as usize) * std::mem::size_of::<GpuIntersectionPair>();
421 let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
422 label: Some("pairs_staging"),
423 size: size as u64,
424 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
425 mapped_at_creation: false,
426 });
427
428 let mut encoder = ctx
429 .device
430 .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
431 encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
432 ctx.queue.submit([encoder.finish()]);
433
434 let slice = staging.slice(..);
435 let (tx, rx) = std::sync::mpsc::channel();
436 slice.map_async(wgpu::MapMode::Read, move |result| {
437 tx.send(result).unwrap();
438 });
439 ctx.device.poll(wgpu::Maintain::Wait);
440
441 rx.recv()
442 .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
443 .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;
444
445 let data = slice.get_mapped_range();
446 let pairs: Vec<GpuIntersectionPair> = bytemuck::cast_slice(&data).to_vec();
447 drop(data);
448 staging.unmap();
449
450 Ok(pairs)
451 }
452}
453
454pub fn detect_self_intersections_gpu(
456 mesh: &Mesh,
457 params: &GpuCollisionParams,
458) -> GpuResult<GpuCollisionResult> {
459 let ctx = GpuContext::try_get()?;
460
461 let mesh_buffers = MeshBuffers::from_mesh(ctx, mesh)?;
463
464 let pipeline = CollisionPipeline::new(ctx)?;
465 pipeline.detect(ctx, &mesh_buffers, params)
466}
467
468pub fn try_detect_self_intersections_gpu(
470 mesh: &Mesh,
471 params: &GpuCollisionParams,
472) -> Option<GpuCollisionResult> {
473 match detect_self_intersections_gpu(mesh, params) {
474 Ok(result) => Some(result),
475 Err(GpuError::NotAvailable) => {
476 debug!("GPU not available for collision detection");
477 None
478 }
479 Err(e) => {
480 warn!("GPU collision detection failed: {}", e);
481 None
482 }
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use mesh_repair::Vertex;

    /// Builds a mesh consisting of a single triangle in the XY plane.
    fn create_simple_mesh() -> Mesh {
        let mut mesh = Mesh::new();

        mesh.vertices.push(Vertex::from_coords(0.0, 0.0, 0.0));
        mesh.vertices.push(Vertex::from_coords(1.0, 0.0, 0.0));
        mesh.vertices.push(Vertex::from_coords(0.0, 1.0, 0.0));
        mesh.faces.push([0, 1, 2]);

        mesh
    }

    #[test]
    fn test_gpu_collision_params_default() {
        let params = GpuCollisionParams::default();
        assert!(params.skip_adjacent);
        assert_eq!(params.max_pairs, 1000);
    }

    #[test]
    fn test_try_detect_self_intersections_gpu() {
        let mesh = create_simple_mesh();
        let params = GpuCollisionParams::default();

        // A GPU may be missing on the test machine, in which case `None` is the expected
        // outcome. When detection does run, a lone triangle cannot intersect anything.
        if let Some(result) = try_detect_self_intersections_gpu(&mesh, &params) {
            assert!(!result.has_intersections);
            assert_eq!(result.intersection_count, 0);
            assert!(result.intersecting_pairs.is_empty());
            assert!(!result.truncated);
        }
    }
}