1use bytemuck::{Pod, Zeroable};
7use wgpu::util::DeviceExt;
8use wgpu::{Buffer, BufferUsages};
9
10use mesh_repair::Mesh;
11
12use crate::context::GpuContext;
13use crate::error::{GpuError, GpuResult};
14
/// A triangle laid out for upload into a GPU storage buffer.
///
/// Each vertex is stored as `[f32; 4]` with the fourth component unused
/// (written as `0.0` by [`GpuTriangle::new`]), padding every position to a
/// 16-byte boundary — presumably to match the shader-side struct layout
/// (TODO confirm against the WGSL definition). Total size: 48 bytes (see
/// the alignment test below).
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct GpuTriangle {
    /// First vertex position (xyz); w is padding.
    pub v0: [f32; 4],
    /// Second vertex position (xyz); w is padding.
    pub v1: [f32; 4],
    /// Third vertex position (xyz); w is padding.
    pub v2: [f32; 4],
}
28
29impl GpuTriangle {
30 pub fn new(v0: [f32; 3], v1: [f32; 3], v2: [f32; 3]) -> Self {
32 Self {
33 v0: [v0[0], v0[1], v0[2], 0.0],
34 v1: [v1[0], v1[1], v1[2], 0.0],
35 v2: [v2[0], v2[1], v2[2], 0.0],
36 }
37 }
38}
39
/// A mesh vertex laid out for upload into a GPU storage buffer.
///
/// Total size: 32 bytes (see the alignment test below); `_padding` keeps
/// the struct a multiple of 16 bytes.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct GpuVertex {
    /// Vertex position (xyz); the fourth component is padding.
    pub position: [f32; 4],
    /// Per-vertex offset value; `0.0` when the source vertex has none.
    pub offset: f32,
    /// Per-vertex tag; `0` when the source vertex has none.
    pub tag: u32,
    /// Explicit trailing padding to a 16-byte multiple.
    pub _padding: [f32; 2],
}
53
/// GPU storage buffers holding one mesh's geometry.
pub struct MeshBuffers {
    /// Storage buffer of [`GpuTriangle`] records, one per face.
    pub triangles: Buffer,
    /// Storage buffer of [`GpuVertex`] records, one per vertex.
    pub vertices: Buffer,
    /// Number of triangles stored in `triangles`.
    pub triangle_count: u32,
    /// Number of vertices stored in `vertices`.
    pub vertex_count: u32,
}
65
66impl MeshBuffers {
67 pub fn from_mesh(ctx: &GpuContext, mesh: &Mesh) -> GpuResult<Self> {
76 let triangle_count = mesh.faces.len();
77 let vertex_count = mesh.vertices.len();
78
79 let max_triangles =
81 ctx.max_storage_buffer_size() as usize / std::mem::size_of::<GpuTriangle>();
82 if triangle_count > max_triangles {
83 return Err(GpuError::MeshTooLarge {
84 triangles: triangle_count,
85 max: max_triangles,
86 });
87 }
88
89 let gpu_triangles: Vec<GpuTriangle> = mesh
91 .faces
92 .iter()
93 .map(|face| {
94 let v0 = &mesh.vertices[face[0] as usize].position;
95 let v1 = &mesh.vertices[face[1] as usize].position;
96 let v2 = &mesh.vertices[face[2] as usize].position;
97 GpuTriangle::new(
98 [v0.x as f32, v0.y as f32, v0.z as f32],
99 [v1.x as f32, v1.y as f32, v1.z as f32],
100 [v2.x as f32, v2.y as f32, v2.z as f32],
101 )
102 })
103 .collect();
104
105 let gpu_vertices: Vec<GpuVertex> = mesh
107 .vertices
108 .iter()
109 .map(|v| GpuVertex {
110 position: [
111 v.position.x as f32,
112 v.position.y as f32,
113 v.position.z as f32,
114 0.0,
115 ],
116 offset: v.offset.unwrap_or(0.0),
117 tag: v.tag.unwrap_or(0),
118 _padding: [0.0, 0.0],
119 })
120 .collect();
121
122 let triangles = ctx
124 .device
125 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
126 label: Some("mesh_triangles"),
127 contents: bytemuck::cast_slice(&gpu_triangles),
128 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
129 });
130
131 let vertices = ctx
132 .device
133 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
134 label: Some("mesh_vertices"),
135 contents: bytemuck::cast_slice(&gpu_vertices),
136 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
137 });
138
139 Ok(Self {
140 triangles,
141 vertices,
142 triangle_count: triangle_count as u32,
143 vertex_count: vertex_count as u32,
144 })
145 }
146
147 pub fn triangles_size(&self) -> u64 {
149 self.triangles.size()
150 }
151
152 pub fn vertices_size(&self) -> u64 {
154 self.vertices.size()
155 }
156}
157
/// Uniform-buffer parameters describing an SDF sampling grid.
///
/// Total size: 48 bytes (see the alignment test below); field padding
/// presumably mirrors the shader-side uniform layout — TODO confirm
/// against the WGSL definition.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct GpuGridParams {
    /// Grid origin (xyz); the fourth component is padding.
    pub origin: [f32; 4],
    /// Grid dimensions in voxels (xyz); the fourth component is padding.
    pub dims: [u32; 4],
    /// Edge length of one voxel.
    pub voxel_size: f32,
    /// Number of triangles in the mesh buffer the grid is evaluated against.
    pub triangle_count: u32,
    /// Explicit trailing padding to a 16-byte multiple.
    pub _padding: [f32; 2],
}
173
/// GPU buffers for computing and storing a signed-distance grid.
pub struct SdfGridBuffers {
    /// Uniform buffer holding a [`GpuGridParams`].
    pub params: Buffer,
    /// Storage buffer with one `f32` SDF value per voxel.
    pub values: Buffer,
    /// Storage buffer with one `f32` offset value per voxel.
    pub offsets: Buffer,
    /// Grid dimensions in voxels (x, y, z).
    pub dims: [usize; 3],
    /// Product of the three dimensions; element count of `values`/`offsets`.
    pub total_voxels: usize,
}
187
188impl SdfGridBuffers {
189 pub fn allocate(
201 ctx: &GpuContext,
202 dims: [usize; 3],
203 origin: [f32; 3],
204 voxel_size: f32,
205 triangle_count: u32,
206 ) -> GpuResult<Self> {
207 let total_voxels = dims[0] * dims[1] * dims[2];
208
209 let max_voxels = ctx.max_storage_buffer_size() as usize / std::mem::size_of::<f32>();
211 if total_voxels > max_voxels {
212 return Err(GpuError::GridTooLarge {
213 dims,
214 total: total_voxels,
215 max: max_voxels,
216 });
217 }
218
219 let grid_params = GpuGridParams {
220 origin: [origin[0], origin[1], origin[2], 0.0],
221 dims: [dims[0] as u32, dims[1] as u32, dims[2] as u32, 0],
222 voxel_size,
223 triangle_count,
224 _padding: [0.0, 0.0],
225 };
226
227 let params = ctx
229 .device
230 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
231 label: Some("sdf_grid_params"),
232 contents: bytemuck::bytes_of(&grid_params),
233 usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
234 });
235
236 let values_size = (total_voxels * std::mem::size_of::<f32>()) as u64;
238 let values = ctx.device.create_buffer(&wgpu::BufferDescriptor {
239 label: Some("sdf_values"),
240 size: values_size,
241 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
242 mapped_at_creation: false,
243 });
244
245 let offsets = ctx.device.create_buffer(&wgpu::BufferDescriptor {
247 label: Some("sdf_offsets"),
248 size: values_size,
249 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
250 mapped_at_creation: false,
251 });
252
253 Ok(Self {
254 params,
255 values,
256 offsets,
257 dims,
258 total_voxels,
259 })
260 }
261
262 pub fn download_values(&self, ctx: &GpuContext) -> GpuResult<Vec<f32>> {
270 self.download_buffer(ctx, &self.values)
271 }
272
273 pub fn download_offsets(&self, ctx: &GpuContext) -> GpuResult<Vec<f32>> {
281 self.download_buffer(ctx, &self.offsets)
282 }
283
284 fn download_buffer(&self, ctx: &GpuContext, buffer: &Buffer) -> GpuResult<Vec<f32>> {
286 let buffer_size = buffer.size();
287
288 let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
290 label: Some("sdf_staging"),
291 size: buffer_size,
292 usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
293 mapped_at_creation: false,
294 });
295
296 let mut encoder = ctx
298 .device
299 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
300 label: Some("sdf_download"),
301 });
302 encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, buffer_size);
303 ctx.queue.submit([encoder.finish()]);
304
305 let slice = staging.slice(..);
307 let (tx, rx) = std::sync::mpsc::channel();
308 slice.map_async(wgpu::MapMode::Read, move |result| {
309 tx.send(result).unwrap();
310 });
311
312 ctx.device.poll(wgpu::Maintain::Wait);
314
315 rx.recv()
317 .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
318 .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;
319
320 let data = slice.get_mapped_range();
322 let values: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
323
324 drop(data);
326 staging.unmap();
327
328 Ok(values)
329 }
330
331 pub fn upload_values(&self, ctx: &GpuContext, values: &[f32]) -> GpuResult<()> {
337 if values.len() != self.total_voxels {
338 return Err(GpuError::Execution(format!(
339 "value count mismatch: expected {}, got {}",
340 self.total_voxels,
341 values.len()
342 )));
343 }
344 ctx.queue
345 .write_buffer(&self.values, 0, bytemuck::cast_slice(values));
346 Ok(())
347 }
348
349 pub fn upload_offsets(&self, ctx: &GpuContext, offsets: &[f32]) -> GpuResult<()> {
355 if offsets.len() != self.total_voxels {
356 return Err(GpuError::Execution(format!(
357 "offset count mismatch: expected {}, got {}",
358 self.total_voxels,
359 offsets.len()
360 )));
361 }
362 ctx.queue
363 .write_buffer(&self.offsets, 0, bytemuck::cast_slice(offsets));
364 Ok(())
365 }
366}
367
/// Configuration for splitting a large grid into GPU-sized tiles.
#[derive(Debug, Clone)]
pub struct TileConfig {
    /// Tile dimensions in voxels (x, y, z).
    pub tile_size: [usize; 3],
    /// Overlap between adjacent tiles, in voxels.
    pub overlap: usize,
}
376
377impl Default for TileConfig {
378 fn default() -> Self {
379 Self {
380 tile_size: [128, 128, 128],
381 overlap: 2,
382 }
383 }
384}
385
386impl TileConfig {
387 pub fn for_memory_budget(budget_bytes: u64) -> Self {
389 let bytes_per_voxel = 8u64;
391 let max_voxels = budget_bytes / bytes_per_voxel;
392
393 let side = (max_voxels as f64).cbrt() as usize;
395 let side = side.clamp(32, 256); Self {
398 tile_size: [side, side, side],
399 overlap: 2,
400 }
401 }
402
403 pub fn tile_count(&self, grid_dims: [usize; 3]) -> [usize; 3] {
405 [
406 grid_dims[0].div_ceil(self.tile_size[0]),
407 grid_dims[1].div_ceil(self.tile_size[1]),
408 grid_dims[2].div_ceil(self.tile_size[2]),
409 ]
410 }
411
412 pub fn total_tiles(&self, grid_dims: [usize; 3]) -> usize {
414 let count = self.tile_count(grid_dims);
415 count[0] * count[1] * count[2]
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    // The GPU-side structs must keep the exact byte sizes the shaders and
    // the buffer-limit arithmetic assume.

    #[test]
    fn test_gpu_triangle_alignment() {
        assert_eq!(size_of::<GpuTriangle>(), 48);
    }

    #[test]
    fn test_gpu_vertex_alignment() {
        assert_eq!(size_of::<GpuVertex>(), 32);
    }

    #[test]
    fn test_gpu_grid_params_alignment() {
        assert_eq!(size_of::<GpuGridParams>(), 48);
    }

    #[test]
    fn test_tile_config_default() {
        let TileConfig { tile_size, overlap } = TileConfig::default();
        assert_eq!(tile_size, [128, 128, 128]);
        assert_eq!(overlap, 2);
    }

    #[test]
    fn test_tile_count_calculation() {
        let config = TileConfig {
            tile_size: [100; 3],
            overlap: 0,
        };
        let dims = [250, 250, 250];

        assert_eq!(config.tile_count(dims), [3, 3, 3]);
        assert_eq!(config.total_tiles(dims), 27);
    }
}