Skip to main content

oxigdal_gpu/
buffer.rs

1//! GPU buffer management for OxiGDAL.
2//!
3//! This module provides efficient GPU buffer management for raster data,
4//! including upload, download, and memory mapping operations.
5
6use crate::context::GpuContext;
7use crate::error::{GpuError, GpuResult};
8use bytemuck::{Pod, Zeroable};
9use std::marker::PhantomData;
10use std::sync::Arc;
11use tracing::{debug, trace};
12use wgpu::{
13    Buffer, BufferAsyncError, BufferDescriptor, BufferUsages, COPY_BUFFER_ALIGNMENT, MapMode,
14};
15
16/// GPU buffer wrapper with type safety.
17///
18/// This struct wraps a WGPU buffer and provides type-safe operations
19/// for uploading and downloading data to/from the GPU.
20pub struct GpuBuffer<T: Pod> {
21    /// The underlying WGPU buffer.
22    buffer: Arc<Buffer>,
23    /// GPU context.
24    context: GpuContext,
25    /// Number of elements in the buffer.
26    len: usize,
27    /// Buffer usage flags.
28    usage: BufferUsages,
29    /// Phantom data for type parameter.
30    _phantom: PhantomData<T>,
31}
32
33impl<T: Pod> GpuBuffer<T> {
34    /// Create a new GPU buffer with the specified size and usage.
35    ///
36    /// # Errors
37    ///
38    /// Returns an error if buffer creation fails or size is invalid.
39    pub fn new(context: &GpuContext, len: usize, usage: BufferUsages) -> GpuResult<Self> {
40        let size = Self::calculate_size(len)?;
41
42        trace!("Creating GPU buffer: {} elements, {} bytes", len, size);
43
44        let buffer = context.device().create_buffer(&BufferDescriptor {
45            label: Some("GpuBuffer"),
46            size,
47            usage,
48            mapped_at_creation: false,
49        });
50
51        Ok(Self {
52            buffer: Arc::new(buffer),
53            context: context.clone(),
54            len,
55            usage,
56            _phantom: PhantomData,
57        })
58    }
59
60    /// Create a GPU buffer from existing data.
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if buffer creation or upload fails.
65    pub fn from_data(context: &GpuContext, data: &[T], usage: BufferUsages) -> GpuResult<Self> {
66        let mut buffer = Self::new(context, data.len(), usage | BufferUsages::COPY_DST)?;
67        buffer.write(data)?;
68        Ok(buffer)
69    }
70
71    /// Create a staging buffer for CPU-GPU transfers.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if buffer creation fails.
76    pub fn staging(context: &GpuContext, len: usize) -> GpuResult<Self> {
77        Self::new(
78            context,
79            len,
80            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
81        )
82    }
83
84    /// Calculate the aligned buffer size in bytes.
85    fn calculate_size(len: usize) -> GpuResult<u64> {
86        let element_size = std::mem::size_of::<T>();
87        let size = len
88            .checked_mul(element_size)
89            .ok_or_else(|| GpuError::invalid_buffer("Buffer size overflow"))?;
90
91        // Align to COPY_BUFFER_ALIGNMENT for efficient transfers
92        let aligned_size = ((size as u64 + COPY_BUFFER_ALIGNMENT - 1) / COPY_BUFFER_ALIGNMENT)
93            * COPY_BUFFER_ALIGNMENT;
94
95        Ok(aligned_size)
96    }
97
98    /// Write data to the GPU buffer.
99    ///
100    /// # Errors
101    ///
102    /// Returns an error if the buffer doesn't support writes or data size
103    /// doesn't match buffer size.
104    pub fn write(&mut self, data: &[T]) -> GpuResult<()> {
105        if data.len() != self.len {
106            return Err(GpuError::invalid_buffer(format!(
107                "Data size mismatch: expected {}, got {}",
108                self.len,
109                data.len()
110            )));
111        }
112
113        if !self.usage.contains(BufferUsages::COPY_DST) {
114            return Err(GpuError::invalid_buffer(
115                "Buffer not writable (missing COPY_DST usage)",
116            ));
117        }
118
119        let bytes = bytemuck::cast_slice(data);
120        self.context.queue().write_buffer(&self.buffer, 0, bytes);
121
122        debug!("Wrote {} bytes to GPU buffer", bytes.len());
123        Ok(())
124    }
125
126    /// Read data from the GPU buffer asynchronously.
127    ///
128    /// # Errors
129    ///
130    /// Returns an error if the buffer doesn't support reads or mapping fails.
131    pub async fn read(&self) -> GpuResult<Vec<T>> {
132        if !self.usage.contains(BufferUsages::MAP_READ) {
133            return Err(GpuError::invalid_buffer(
134                "Buffer not readable (missing MAP_READ usage)",
135            ));
136        }
137
138        let buffer_slice = self.buffer.slice(..);
139
140        // Map the buffer for reading
141        let (tx, rx) = futures::channel::oneshot::channel();
142        buffer_slice.map_async(MapMode::Read, move |result| {
143            let _ = tx.send(result);
144        });
145
146        // Poll the device until the buffer is mapped
147        self.context.poll(true);
148
149        // Wait for mapping to complete
150        rx.await
151            .map_err(|_| GpuError::buffer_mapping("Channel closed"))?
152            .map_err(|e| GpuError::buffer_mapping(Self::map_error_to_string(e)))?;
153
154        // Read the data
155        let data = buffer_slice.get_mapped_range();
156        let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
157
158        // Unmap the buffer
159        drop(data);
160        self.buffer.unmap();
161
162        debug!("Read {} elements from GPU buffer", result.len());
163        Ok(result)
164    }
165
166    /// Read data from the GPU buffer synchronously (blocking).
167    ///
168    /// # Errors
169    ///
170    /// Returns an error if the buffer doesn't support reads or mapping fails.
171    pub fn read_blocking(&self) -> GpuResult<Vec<T>> {
172        pollster::block_on(self.read())
173    }
174
175    /// Copy data from another GPU buffer.
176    ///
177    /// # Errors
178    ///
179    /// Returns an error if buffer sizes don't match or copy is not supported.
180    pub fn copy_from(&mut self, source: &GpuBuffer<T>) -> GpuResult<()> {
181        if self.len != source.len {
182            return Err(GpuError::invalid_buffer(format!(
183                "Buffer size mismatch: {} != {}",
184                self.len, source.len
185            )));
186        }
187
188        if !source.usage.contains(BufferUsages::COPY_SRC) {
189            return Err(GpuError::invalid_buffer(
190                "Source buffer not copyable (missing COPY_SRC usage)",
191            ));
192        }
193
194        if !self.usage.contains(BufferUsages::COPY_DST) {
195            return Err(GpuError::invalid_buffer(
196                "Destination buffer not copyable (missing COPY_DST usage)",
197            ));
198        }
199
200        let mut encoder =
201            self.context
202                .device()
203                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
204                    label: Some("Buffer Copy"),
205                });
206
207        let size = Self::calculate_size(self.len)?;
208        encoder.copy_buffer_to_buffer(&source.buffer, 0, &self.buffer, 0, size);
209
210        self.context.queue().submit(Some(encoder.finish()));
211
212        debug!("Copied {} elements between GPU buffers", self.len);
213        Ok(())
214    }
215
216    /// Get the number of elements in the buffer.
217    pub fn len(&self) -> usize {
218        self.len
219    }
220
221    /// Check if the buffer is empty.
222    pub fn is_empty(&self) -> bool {
223        self.len == 0
224    }
225
226    /// Get the buffer size in bytes.
227    pub fn size_bytes(&self) -> u64 {
228        Self::calculate_size(self.len).unwrap_or(0)
229    }
230
231    /// Get the underlying WGPU buffer.
232    pub fn buffer(&self) -> &Buffer {
233        &self.buffer
234    }
235
236    /// Get buffer usage flags.
237    pub fn usage(&self) -> BufferUsages {
238        self.usage
239    }
240
241    /// Convert buffer mapping error to string.
242    fn map_error_to_string(error: BufferAsyncError) -> String {
243        error.to_string()
244    }
245}
246
247impl<T: Pod> Clone for GpuBuffer<T> {
248    fn clone(&self) -> Self {
249        Self {
250            buffer: Arc::clone(&self.buffer),
251            context: self.context.clone(),
252            len: self.len,
253            usage: self.usage,
254            _phantom: PhantomData,
255        }
256    }
257}
258
259impl<T: Pod> std::fmt::Debug for GpuBuffer<T> {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        f.debug_struct("GpuBuffer")
262            .field("len", &self.len)
263            .field("size_bytes", &self.size_bytes())
264            .field("usage", &self.usage)
265            .field("type", &std::any::type_name::<T>())
266            .finish()
267    }
268}
269
270/// GPU raster buffer for multi-band raster data.
271///
272/// This struct manages GPU buffers for multi-band raster data with
273/// efficient interleaved or planar storage.
274pub struct GpuRasterBuffer<T: Pod> {
275    /// GPU buffers for each band.
276    bands: Vec<GpuBuffer<T>>,
277    /// Width of the raster.
278    width: u32,
279    /// Height of the raster.
280    height: u32,
281}
282
283impl<T: Pod + Zeroable> GpuRasterBuffer<T> {
284    /// Create a new GPU raster buffer.
285    ///
286    /// # Errors
287    ///
288    /// Returns an error if buffer creation fails.
289    pub fn new(
290        context: &GpuContext,
291        width: u32,
292        height: u32,
293        num_bands: usize,
294        usage: BufferUsages,
295    ) -> GpuResult<Self> {
296        let pixels_per_band = (width as usize)
297            .checked_mul(height as usize)
298            .ok_or_else(|| GpuError::invalid_buffer("Raster size overflow"))?;
299
300        let bands = (0..num_bands)
301            .map(|_| GpuBuffer::new(context, pixels_per_band, usage))
302            .collect::<GpuResult<Vec<_>>>()?;
303
304        debug!(
305            "Created GPU raster buffer: {}x{} with {} bands",
306            width, height, num_bands
307        );
308
309        Ok(Self {
310            bands,
311            width,
312            height,
313        })
314    }
315
316    /// Create a GPU raster buffer from data.
317    ///
318    /// # Errors
319    ///
320    /// Returns an error if buffer creation or upload fails.
321    pub fn from_bands(
322        context: &GpuContext,
323        width: u32,
324        height: u32,
325        bands_data: &[Vec<T>],
326        usage: BufferUsages,
327    ) -> GpuResult<Self> {
328        let expected_size = (width as usize) * (height as usize);
329
330        for (i, band) in bands_data.iter().enumerate() {
331            if band.len() != expected_size {
332                return Err(GpuError::invalid_buffer(format!(
333                    "Band {} size mismatch: expected {}, got {}",
334                    i,
335                    expected_size,
336                    band.len()
337                )));
338            }
339        }
340
341        let bands = bands_data
342            .iter()
343            .map(|data| GpuBuffer::from_data(context, data, usage))
344            .collect::<GpuResult<Vec<_>>>()?;
345
346        Ok(Self {
347            bands,
348            width,
349            height,
350        })
351    }
352
353    /// Get a specific band buffer.
354    pub fn band(&self, index: usize) -> Option<&GpuBuffer<T>> {
355        self.bands.get(index)
356    }
357
358    /// Get mutable reference to a specific band buffer.
359    pub fn band_mut(&mut self, index: usize) -> Option<&mut GpuBuffer<T>> {
360        self.bands.get_mut(index)
361    }
362
363    /// Get all band buffers.
364    pub fn bands(&self) -> &[GpuBuffer<T>] {
365        &self.bands
366    }
367
368    /// Get the number of bands.
369    pub fn num_bands(&self) -> usize {
370        self.bands.len()
371    }
372
373    /// Get raster dimensions.
374    pub fn dimensions(&self) -> (u32, u32) {
375        (self.width, self.height)
376    }
377
378    /// Get raster width.
379    pub fn width(&self) -> u32 {
380        self.width
381    }
382
383    /// Get raster height.
384    pub fn height(&self) -> u32 {
385        self.height
386    }
387
388    /// Read all bands from GPU asynchronously.
389    ///
390    /// # Errors
391    ///
392    /// Returns an error if reading fails.
393    pub async fn read_all_bands(&self) -> GpuResult<Vec<Vec<T>>> {
394        let mut results = Vec::with_capacity(self.bands.len());
395
396        for band in &self.bands {
397            results.push(band.read().await?);
398        }
399
400        Ok(results)
401    }
402
403    /// Read all bands from GPU synchronously.
404    ///
405    /// # Errors
406    ///
407    /// Returns an error if reading fails.
408    pub fn read_all_bands_blocking(&self) -> GpuResult<Vec<Vec<T>>> {
409        pollster::block_on(self.read_all_bands())
410    }
411}
412
413impl<T: Pod> std::fmt::Debug for GpuRasterBuffer<T> {
414    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415        f.debug_struct("GpuRasterBuffer")
416            .field("width", &self.width)
417            .field("height", &self.height)
418            .field("num_bands", &self.num_bands())
419            .field("type", &std::any::type_name::<T>())
420            .finish()
421    }
422}
423
#[cfg(test)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    /// A freshly allocated storage buffer reports the requested length.
    ///
    /// Passes without checking anything when no GPU context is available.
    #[tokio::test]
    async fn test_gpu_buffer_creation() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };

        let buffer: GpuBuffer<f32> = match GpuBuffer::new(&context, 1024, BufferUsages::STORAGE) {
            Ok(buffer) => buffer,
            Err(e) => panic!("Failed to create buffer: {}", e),
        };

        assert_eq!(buffer.len(), 1024);
        assert!(!buffer.is_empty());
    }

    /// Upload, copy into a staging buffer, and read back the same values.
    ///
    /// Ignored by default; passes without checking anything when no GPU
    /// context is available.
    #[tokio::test]
    #[ignore]
    async fn test_gpu_buffer_write_read() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };

        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();

        let source_usage =
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST;
        let buffer = match GpuBuffer::from_data(&context, &data, source_usage) {
            Ok(buffer) => buffer,
            Err(e) => panic!("Failed to create buffer: {}", e),
        };

        // Create staging buffer for reading
        let mut staging = match GpuBuffer::staging(&context, 100) {
            Ok(staging) => staging,
            Err(e) => panic!("Failed to create staging buffer: {}", e),
        };

        if let Err(e) = staging.copy_from(&buffer) {
            panic!("Failed to copy buffer: {}", e);
        }

        let result = match staging.read().await {
            Ok(result) => result,
            Err(e) => panic!("Failed to read buffer: {}", e),
        };

        assert_eq!(result.len(), data.len());
        assert!(
            result
                .iter()
                .zip(data.iter())
                .all(|(a, b)| (a - b).abs() < 1e-6)
        );
    }

    /// A raster buffer reports the dimensions and band count it was built with.
    ///
    /// Passes without checking anything when no GPU context is available.
    #[tokio::test]
    async fn test_gpu_raster_buffer() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };

        let width = 64;
        let height = 64;
        let num_bands = 3;

        let raster: GpuRasterBuffer<f32> =
            match GpuRasterBuffer::new(&context, width, height, num_bands, BufferUsages::STORAGE) {
                Ok(raster) => raster,
                Err(e) => panic!("Failed to create raster buffer: {}", e),
            };

        assert_eq!(raster.width(), width);
        assert_eq!(raster.height(), height);
        assert_eq!(raster.num_bands(), num_bands);
    }
}