// mesh_gpu/buffers.rs

//! GPU buffer types for mesh and SDF grid data.
//!
//! This module provides structures for efficiently transferring mesh and grid
//! data to and from the GPU.

6use bytemuck::{Pod, Zeroable};
7use wgpu::util::DeviceExt;
8use wgpu::{Buffer, BufferUsages};
9
10use mesh_repair::Mesh;
11
12use crate::context::GpuContext;
13use crate::error::{GpuError, GpuResult};
14
15/// GPU-friendly triangle representation with aligned fields.
16///
17/// Uses 4-component vectors for proper GPU alignment (vec4 = 16 bytes).
18#[repr(C)]
19#[derive(Clone, Copy, Debug, Pod, Zeroable)]
20pub struct GpuTriangle {
21    /// First vertex position (xyz) + padding.
22    pub v0: [f32; 4],
23    /// Second vertex position (xyz) + padding.
24    pub v1: [f32; 4],
25    /// Third vertex position (xyz) + padding.
26    pub v2: [f32; 4],
27}
28
29impl GpuTriangle {
30    /// Create a GPU triangle from vertex positions.
31    pub fn new(v0: [f32; 3], v1: [f32; 3], v2: [f32; 3]) -> Self {
32        Self {
33            v0: [v0[0], v0[1], v0[2], 0.0],
34            v1: [v1[0], v1[1], v1[2], 0.0],
35            v2: [v2[0], v2[1], v2[2], 0.0],
36        }
37    }
38}
39
40/// GPU-friendly vertex representation with offset data.
41#[repr(C)]
42#[derive(Clone, Copy, Debug, Pod, Zeroable)]
43pub struct GpuVertex {
44    /// Vertex position (xyz) + padding.
45    pub position: [f32; 4],
46    /// Offset value for variable thickness shells.
47    pub offset: f32,
48    /// Vertex tag/region ID.
49    pub tag: u32,
50    /// Padding for alignment.
51    pub _padding: [f32; 2],
52}
53
54/// GPU buffers containing mesh geometry data.
55pub struct MeshBuffers {
56    /// Buffer containing triangle data.
57    pub triangles: Buffer,
58    /// Buffer containing vertex data (for offset queries).
59    pub vertices: Buffer,
60    /// Number of triangles.
61    pub triangle_count: u32,
62    /// Number of vertices.
63    pub vertex_count: u32,
64}
65
66impl MeshBuffers {
67    /// Create GPU buffers from a mesh.
68    ///
69    /// # Arguments
70    /// * `ctx` - GPU context
71    /// * `mesh` - Source mesh to upload
72    ///
73    /// # Returns
74    /// GPU buffers containing the mesh data.
75    pub fn from_mesh(ctx: &GpuContext, mesh: &Mesh) -> GpuResult<Self> {
76        let triangle_count = mesh.faces.len();
77        let vertex_count = mesh.vertices.len();
78
79        // Check mesh size limits
80        let max_triangles =
81            ctx.max_storage_buffer_size() as usize / std::mem::size_of::<GpuTriangle>();
82        if triangle_count > max_triangles {
83            return Err(GpuError::MeshTooLarge {
84                triangles: triangle_count,
85                max: max_triangles,
86            });
87        }
88
89        // Convert triangles to GPU format
90        let gpu_triangles: Vec<GpuTriangle> = mesh
91            .faces
92            .iter()
93            .map(|face| {
94                let v0 = &mesh.vertices[face[0] as usize].position;
95                let v1 = &mesh.vertices[face[1] as usize].position;
96                let v2 = &mesh.vertices[face[2] as usize].position;
97                GpuTriangle::new(
98                    [v0.x as f32, v0.y as f32, v0.z as f32],
99                    [v1.x as f32, v1.y as f32, v1.z as f32],
100                    [v2.x as f32, v2.y as f32, v2.z as f32],
101                )
102            })
103            .collect();
104
105        // Convert vertices to GPU format
106        let gpu_vertices: Vec<GpuVertex> = mesh
107            .vertices
108            .iter()
109            .map(|v| GpuVertex {
110                position: [
111                    v.position.x as f32,
112                    v.position.y as f32,
113                    v.position.z as f32,
114                    0.0,
115                ],
116                offset: v.offset.unwrap_or(0.0),
117                tag: v.tag.unwrap_or(0),
118                _padding: [0.0, 0.0],
119            })
120            .collect();
121
122        // Create GPU buffers
123        let triangles = ctx
124            .device
125            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
126                label: Some("mesh_triangles"),
127                contents: bytemuck::cast_slice(&gpu_triangles),
128                usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
129            });
130
131        let vertices = ctx
132            .device
133            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
134                label: Some("mesh_vertices"),
135                contents: bytemuck::cast_slice(&gpu_vertices),
136                usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
137            });
138
139        Ok(Self {
140            triangles,
141            vertices,
142            triangle_count: triangle_count as u32,
143            vertex_count: vertex_count as u32,
144        })
145    }
146
147    /// Get the size of triangle buffer in bytes.
148    pub fn triangles_size(&self) -> u64 {
149        self.triangles.size()
150    }
151
152    /// Get the size of vertex buffer in bytes.
153    pub fn vertices_size(&self) -> u64 {
154        self.vertices.size()
155    }
156}
157
158/// Uniform parameters for SDF grid computation.
159#[repr(C)]
160#[derive(Clone, Copy, Debug, Pod, Zeroable)]
161pub struct GpuGridParams {
162    /// Grid origin (min corner) in world coordinates.
163    pub origin: [f32; 4],
164    /// Grid dimensions (x, y, z, padding).
165    pub dims: [u32; 4],
166    /// Voxel size in world units.
167    pub voxel_size: f32,
168    /// Number of triangles in mesh.
169    pub triangle_count: u32,
170    /// Padding for alignment.
171    pub _padding: [f32; 2],
172}
173
174/// GPU buffers for SDF grid computation and storage.
175pub struct SdfGridBuffers {
176    /// Uniform buffer containing grid parameters.
177    pub params: Buffer,
178    /// Storage buffer for SDF values (read/write).
179    pub values: Buffer,
180    /// Storage buffer for offset values (read/write).
181    pub offsets: Buffer,
182    /// Grid dimensions.
183    pub dims: [usize; 3],
184    /// Total number of voxels.
185    pub total_voxels: usize,
186}
187
188impl SdfGridBuffers {
189    /// Allocate GPU buffers for an SDF grid.
190    ///
191    /// # Arguments
192    /// * `ctx` - GPU context
193    /// * `dims` - Grid dimensions [x, y, z]
194    /// * `origin` - Grid origin in world coordinates
195    /// * `voxel_size` - Size of each voxel
196    /// * `triangle_count` - Number of triangles in mesh
197    ///
198    /// # Returns
199    /// Allocated GPU buffers for the grid.
200    pub fn allocate(
201        ctx: &GpuContext,
202        dims: [usize; 3],
203        origin: [f32; 3],
204        voxel_size: f32,
205        triangle_count: u32,
206    ) -> GpuResult<Self> {
207        let total_voxels = dims[0] * dims[1] * dims[2];
208
209        // Check grid size limits
210        let max_voxels = ctx.max_storage_buffer_size() as usize / std::mem::size_of::<f32>();
211        if total_voxels > max_voxels {
212            return Err(GpuError::GridTooLarge {
213                dims,
214                total: total_voxels,
215                max: max_voxels,
216            });
217        }
218
219        let grid_params = GpuGridParams {
220            origin: [origin[0], origin[1], origin[2], 0.0],
221            dims: [dims[0] as u32, dims[1] as u32, dims[2] as u32, 0],
222            voxel_size,
223            triangle_count,
224            _padding: [0.0, 0.0],
225        };
226
227        // Create uniform buffer for grid parameters
228        let params = ctx
229            .device
230            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
231                label: Some("sdf_grid_params"),
232                contents: bytemuck::bytes_of(&grid_params),
233                usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
234            });
235
236        // Create storage buffer for SDF values
237        let values_size = (total_voxels * std::mem::size_of::<f32>()) as u64;
238        let values = ctx.device.create_buffer(&wgpu::BufferDescriptor {
239            label: Some("sdf_values"),
240            size: values_size,
241            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
242            mapped_at_creation: false,
243        });
244
245        // Create storage buffer for offset values
246        let offsets = ctx.device.create_buffer(&wgpu::BufferDescriptor {
247            label: Some("sdf_offsets"),
248            size: values_size,
249            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
250            mapped_at_creation: false,
251        });
252
253        Ok(Self {
254            params,
255            values,
256            offsets,
257            dims,
258            total_voxels,
259        })
260    }
261
262    /// Download SDF values from GPU to CPU.
263    ///
264    /// # Arguments
265    /// * `ctx` - GPU context
266    ///
267    /// # Returns
268    /// Vector of SDF values.
269    pub fn download_values(&self, ctx: &GpuContext) -> GpuResult<Vec<f32>> {
270        self.download_buffer(ctx, &self.values)
271    }
272
273    /// Download offset values from GPU to CPU.
274    ///
275    /// # Arguments
276    /// * `ctx` - GPU context
277    ///
278    /// # Returns
279    /// Vector of offset values.
280    pub fn download_offsets(&self, ctx: &GpuContext) -> GpuResult<Vec<f32>> {
281        self.download_buffer(ctx, &self.offsets)
282    }
283
284    /// Download a buffer's contents to CPU memory.
285    fn download_buffer(&self, ctx: &GpuContext, buffer: &Buffer) -> GpuResult<Vec<f32>> {
286        let buffer_size = buffer.size();
287
288        // Create staging buffer for readback
289        let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
290            label: Some("sdf_staging"),
291            size: buffer_size,
292            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
293            mapped_at_creation: false,
294        });
295
296        // Copy GPU buffer to staging buffer
297        let mut encoder = ctx
298            .device
299            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
300                label: Some("sdf_download"),
301            });
302        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, buffer_size);
303        ctx.queue.submit([encoder.finish()]);
304
305        // Map staging buffer and read data
306        let slice = staging.slice(..);
307        let (tx, rx) = std::sync::mpsc::channel();
308        slice.map_async(wgpu::MapMode::Read, move |result| {
309            tx.send(result).unwrap();
310        });
311
312        // Wait for GPU to finish
313        ctx.device.poll(wgpu::Maintain::Wait);
314
315        // Check mapping result
316        rx.recv()
317            .map_err(|_| GpuError::BufferMapping("channel closed".into()))?
318            .map_err(|e| GpuError::BufferMapping(format!("{:?}", e)))?;
319
320        // Read data from mapped buffer
321        let data = slice.get_mapped_range();
322        let values: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
323
324        // Unmap buffer (drop guard first)
325        drop(data);
326        staging.unmap();
327
328        Ok(values)
329    }
330
331    /// Upload SDF values from CPU to GPU.
332    ///
333    /// # Arguments
334    /// * `ctx` - GPU context
335    /// * `values` - SDF values to upload
336    pub fn upload_values(&self, ctx: &GpuContext, values: &[f32]) -> GpuResult<()> {
337        if values.len() != self.total_voxels {
338            return Err(GpuError::Execution(format!(
339                "value count mismatch: expected {}, got {}",
340                self.total_voxels,
341                values.len()
342            )));
343        }
344        ctx.queue
345            .write_buffer(&self.values, 0, bytemuck::cast_slice(values));
346        Ok(())
347    }
348
349    /// Upload offset values from CPU to GPU.
350    ///
351    /// # Arguments
352    /// * `ctx` - GPU context
353    /// * `offsets` - Offset values to upload
354    pub fn upload_offsets(&self, ctx: &GpuContext, offsets: &[f32]) -> GpuResult<()> {
355        if offsets.len() != self.total_voxels {
356            return Err(GpuError::Execution(format!(
357                "offset count mismatch: expected {}, got {}",
358                self.total_voxels,
359                offsets.len()
360            )));
361        }
362        ctx.queue
363            .write_buffer(&self.offsets, 0, bytemuck::cast_slice(offsets));
364        Ok(())
365    }
366}
367
/// Configuration for tiled processing of large grids.
#[derive(Debug, Clone)]
pub struct TileConfig {
    /// Size of each tile in voxels [x, y, z].
    pub tile_size: [usize; 3],
    /// Overlap between tiles (for algorithms that need neighbor data).
    pub overlap: usize,
}

impl Default for TileConfig {
    /// Default: cubic 128-voxel tiles with a 2-voxel overlap.
    fn default() -> Self {
        Self {
            tile_size: [128; 3],
            overlap: 2,
        }
    }
}

impl TileConfig {
    /// Create a tile configuration sized to fit a GPU memory budget.
    ///
    /// Assumes 8 bytes per voxel (an `f32` SDF value plus an `f32` offset),
    /// picks a cubic tile whose voxel count fits the budget, and clamps the
    /// side length to the range [32, 256].
    pub fn for_memory_budget(budget_bytes: u64) -> Self {
        const BYTES_PER_VOXEL: u64 = 8;
        let voxel_budget = budget_bytes / BYTES_PER_VOXEL;
        let side = ((voxel_budget as f64).cbrt() as usize).clamp(32, 256);
        Self {
            tile_size: [side; 3],
            overlap: 2,
        }
    }

    /// Number of tiles along each axis needed to cover `grid_dims`.
    pub fn tile_count(&self, grid_dims: [usize; 3]) -> [usize; 3] {
        let mut counts = [0usize; 3];
        for (axis, count) in counts.iter_mut().enumerate() {
            *count = grid_dims[axis].div_ceil(self.tile_size[axis]);
        }
        counts
    }

    /// Total number of tiles needed to cover `grid_dims`.
    pub fn total_tiles(&self, grid_dims: [usize; 3]) -> usize {
        self.tile_count(grid_dims).iter().product()
    }
}
418
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    #[test]
    fn test_gpu_triangle_alignment() {
        // Three padded vec4 vertices: 3 * 16 = 48 bytes.
        assert_eq!(size_of::<GpuTriangle>(), 48);
    }

    #[test]
    fn test_gpu_vertex_alignment() {
        // vec4 + f32 + u32 + 2 * f32 = 32 bytes.
        assert_eq!(size_of::<GpuVertex>(), 32);
    }

    #[test]
    fn test_gpu_grid_params_alignment() {
        // vec4 + uvec4 + f32 + u32 + 2 * f32 = 48 bytes.
        assert_eq!(size_of::<GpuGridParams>(), 48);
    }

    #[test]
    fn test_tile_config_default() {
        let config = TileConfig::default();
        assert_eq!(config.tile_size, [128, 128, 128]);
        assert_eq!(config.overlap, 2);
    }

    #[test]
    fn test_tile_count_calculation() {
        let config = TileConfig {
            tile_size: [100, 100, 100],
            overlap: 0,
        };

        // A 250^3 grid requires ceil(250/100) = 3 tiles per axis, 27 total.
        let grid = [250, 250, 250];
        assert_eq!(config.tile_count(grid), [3, 3, 3]);
        assert_eq!(config.total_tiles(grid), 27);
    }
}