1use crate::context::GpuContext;
7use crate::error::{GpuError, GpuResult};
8use bytemuck::{Pod, Zeroable};
9use std::marker::PhantomData;
10use std::sync::Arc;
11use tracing::{debug, trace};
12use wgpu::{
13 Buffer, BufferAsyncError, BufferDescriptor, BufferUsages, COPY_BUFFER_ALIGNMENT, MapMode,
14};
15
/// A typed handle to a `wgpu` buffer holding `len` elements of `T`.
///
/// Cloning is cheap: clones share the same underlying GPU allocation via
/// `Arc`, so they alias the same device memory.
pub struct GpuBuffer<T: Pod> {
    // Shared wgpu buffer handle; Arc makes Clone a refcount bump, not a copy.
    buffer: Arc<Buffer>,
    // Owning context, used for device/queue access and polling during reads.
    context: GpuContext,
    // Logical length in elements of T (not bytes); the underlying allocation
    // may be larger due to COPY_BUFFER_ALIGNMENT rounding.
    len: usize,
    // Usage flags the buffer was created with; validated before read/write/copy.
    usage: BufferUsages,
    // Ties the element type to the handle without storing any T.
    _phantom: PhantomData<T>,
}
32
33impl<T: Pod> GpuBuffer<T> {
34 pub fn new(context: &GpuContext, len: usize, usage: BufferUsages) -> GpuResult<Self> {
40 let size = Self::calculate_size(len)?;
41
42 trace!("Creating GPU buffer: {} elements, {} bytes", len, size);
43
44 let buffer = context.device().create_buffer(&BufferDescriptor {
45 label: Some("GpuBuffer"),
46 size,
47 usage,
48 mapped_at_creation: false,
49 });
50
51 Ok(Self {
52 buffer: Arc::new(buffer),
53 context: context.clone(),
54 len,
55 usage,
56 _phantom: PhantomData,
57 })
58 }
59
60 pub fn from_data(context: &GpuContext, data: &[T], usage: BufferUsages) -> GpuResult<Self> {
66 let mut buffer = Self::new(context, data.len(), usage | BufferUsages::COPY_DST)?;
67 buffer.write(data)?;
68 Ok(buffer)
69 }
70
71 pub fn staging(context: &GpuContext, len: usize) -> GpuResult<Self> {
77 Self::new(
78 context,
79 len,
80 BufferUsages::MAP_READ | BufferUsages::COPY_DST,
81 )
82 }
83
84 fn calculate_size(len: usize) -> GpuResult<u64> {
86 let element_size = std::mem::size_of::<T>();
87 let size = len
88 .checked_mul(element_size)
89 .ok_or_else(|| GpuError::invalid_buffer("Buffer size overflow"))?;
90
91 let aligned_size = ((size as u64 + COPY_BUFFER_ALIGNMENT - 1) / COPY_BUFFER_ALIGNMENT)
93 * COPY_BUFFER_ALIGNMENT;
94
95 Ok(aligned_size)
96 }
97
98 pub fn write(&mut self, data: &[T]) -> GpuResult<()> {
105 if data.len() != self.len {
106 return Err(GpuError::invalid_buffer(format!(
107 "Data size mismatch: expected {}, got {}",
108 self.len,
109 data.len()
110 )));
111 }
112
113 if !self.usage.contains(BufferUsages::COPY_DST) {
114 return Err(GpuError::invalid_buffer(
115 "Buffer not writable (missing COPY_DST usage)",
116 ));
117 }
118
119 let bytes = bytemuck::cast_slice(data);
120 self.context.queue().write_buffer(&self.buffer, 0, bytes);
121
122 debug!("Wrote {} bytes to GPU buffer", bytes.len());
123 Ok(())
124 }
125
126 pub async fn read(&self) -> GpuResult<Vec<T>> {
132 if !self.usage.contains(BufferUsages::MAP_READ) {
133 return Err(GpuError::invalid_buffer(
134 "Buffer not readable (missing MAP_READ usage)",
135 ));
136 }
137
138 let buffer_slice = self.buffer.slice(..);
139
140 let (tx, rx) = futures::channel::oneshot::channel();
142 buffer_slice.map_async(MapMode::Read, move |result| {
143 let _ = tx.send(result);
144 });
145
146 self.context.poll(true);
148
149 rx.await
151 .map_err(|_| GpuError::buffer_mapping("Channel closed"))?
152 .map_err(|e| GpuError::buffer_mapping(Self::map_error_to_string(e)))?;
153
154 let data = buffer_slice.get_mapped_range();
156 let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
157
158 drop(data);
160 self.buffer.unmap();
161
162 debug!("Read {} elements from GPU buffer", result.len());
163 Ok(result)
164 }
165
166 pub fn read_blocking(&self) -> GpuResult<Vec<T>> {
172 pollster::block_on(self.read())
173 }
174
175 pub fn copy_from(&mut self, source: &GpuBuffer<T>) -> GpuResult<()> {
181 if self.len != source.len {
182 return Err(GpuError::invalid_buffer(format!(
183 "Buffer size mismatch: {} != {}",
184 self.len, source.len
185 )));
186 }
187
188 if !source.usage.contains(BufferUsages::COPY_SRC) {
189 return Err(GpuError::invalid_buffer(
190 "Source buffer not copyable (missing COPY_SRC usage)",
191 ));
192 }
193
194 if !self.usage.contains(BufferUsages::COPY_DST) {
195 return Err(GpuError::invalid_buffer(
196 "Destination buffer not copyable (missing COPY_DST usage)",
197 ));
198 }
199
200 let mut encoder =
201 self.context
202 .device()
203 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
204 label: Some("Buffer Copy"),
205 });
206
207 let size = Self::calculate_size(self.len)?;
208 encoder.copy_buffer_to_buffer(&source.buffer, 0, &self.buffer, 0, size);
209
210 self.context.queue().submit(Some(encoder.finish()));
211
212 debug!("Copied {} elements between GPU buffers", self.len);
213 Ok(())
214 }
215
216 pub fn len(&self) -> usize {
218 self.len
219 }
220
221 pub fn is_empty(&self) -> bool {
223 self.len == 0
224 }
225
226 pub fn size_bytes(&self) -> u64 {
228 Self::calculate_size(self.len).unwrap_or(0)
229 }
230
231 pub fn buffer(&self) -> &Buffer {
233 &self.buffer
234 }
235
236 pub fn usage(&self) -> BufferUsages {
238 self.usage
239 }
240
241 fn map_error_to_string(error: BufferAsyncError) -> String {
243 error.to_string()
244 }
245}
246
247impl<T: Pod> Clone for GpuBuffer<T> {
248 fn clone(&self) -> Self {
249 Self {
250 buffer: Arc::clone(&self.buffer),
251 context: self.context.clone(),
252 len: self.len,
253 usage: self.usage,
254 _phantom: PhantomData,
255 }
256 }
257}
258
259impl<T: Pod> std::fmt::Debug for GpuBuffer<T> {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.debug_struct("GpuBuffer")
262 .field("len", &self.len)
263 .field("size_bytes", &self.size_bytes())
264 .field("usage", &self.usage)
265 .field("type", &std::any::type_name::<T>())
266 .finish()
267 }
268}
269
/// A multi-band raster held on the GPU: one [`GpuBuffer`] per band, each
/// sized `width * height` elements.
pub struct GpuRasterBuffer<T: Pod> {
    // One buffer per band; every band has width * height elements.
    bands: Vec<GpuBuffer<T>>,
    // Raster width in pixels.
    width: u32,
    // Raster height in pixels.
    height: u32,
}
282
283impl<T: Pod + Zeroable> GpuRasterBuffer<T> {
284 pub fn new(
290 context: &GpuContext,
291 width: u32,
292 height: u32,
293 num_bands: usize,
294 usage: BufferUsages,
295 ) -> GpuResult<Self> {
296 let pixels_per_band = (width as usize)
297 .checked_mul(height as usize)
298 .ok_or_else(|| GpuError::invalid_buffer("Raster size overflow"))?;
299
300 let bands = (0..num_bands)
301 .map(|_| GpuBuffer::new(context, pixels_per_band, usage))
302 .collect::<GpuResult<Vec<_>>>()?;
303
304 debug!(
305 "Created GPU raster buffer: {}x{} with {} bands",
306 width, height, num_bands
307 );
308
309 Ok(Self {
310 bands,
311 width,
312 height,
313 })
314 }
315
316 pub fn from_bands(
322 context: &GpuContext,
323 width: u32,
324 height: u32,
325 bands_data: &[Vec<T>],
326 usage: BufferUsages,
327 ) -> GpuResult<Self> {
328 let expected_size = (width as usize) * (height as usize);
329
330 for (i, band) in bands_data.iter().enumerate() {
331 if band.len() != expected_size {
332 return Err(GpuError::invalid_buffer(format!(
333 "Band {} size mismatch: expected {}, got {}",
334 i,
335 expected_size,
336 band.len()
337 )));
338 }
339 }
340
341 let bands = bands_data
342 .iter()
343 .map(|data| GpuBuffer::from_data(context, data, usage))
344 .collect::<GpuResult<Vec<_>>>()?;
345
346 Ok(Self {
347 bands,
348 width,
349 height,
350 })
351 }
352
353 pub fn band(&self, index: usize) -> Option<&GpuBuffer<T>> {
355 self.bands.get(index)
356 }
357
358 pub fn band_mut(&mut self, index: usize) -> Option<&mut GpuBuffer<T>> {
360 self.bands.get_mut(index)
361 }
362
363 pub fn bands(&self) -> &[GpuBuffer<T>] {
365 &self.bands
366 }
367
368 pub fn num_bands(&self) -> usize {
370 self.bands.len()
371 }
372
373 pub fn dimensions(&self) -> (u32, u32) {
375 (self.width, self.height)
376 }
377
378 pub fn width(&self) -> u32 {
380 self.width
381 }
382
383 pub fn height(&self) -> u32 {
385 self.height
386 }
387
388 pub async fn read_all_bands(&self) -> GpuResult<Vec<Vec<T>>> {
394 let mut results = Vec::with_capacity(self.bands.len());
395
396 for band in &self.bands {
397 results.push(band.read().await?);
398 }
399
400 Ok(results)
401 }
402
403 pub fn read_all_bands_blocking(&self) -> GpuResult<Vec<Vec<T>>> {
409 pollster::block_on(self.read_all_bands())
410 }
411}
412
413impl<T: Pod> std::fmt::Debug for GpuRasterBuffer<T> {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 f.debug_struct("GpuRasterBuffer")
416 .field("width", &self.width)
417 .field("height", &self.height)
418 .field("num_bands", &self.num_bands())
419 .field("type", &std::any::type_name::<T>())
420 .finish()
421 }
422}
423
#[cfg(test)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    // Each test bails out quietly when no GPU adapter is available, so the
    // suite still passes on CI machines without a device.

    #[tokio::test]
    async fn test_gpu_buffer_creation() {
        let ctx = match GpuContext::new().await {
            Ok(ctx) => ctx,
            Err(_) => return,
        };

        let buf: GpuBuffer<f32> = GpuBuffer::new(&ctx, 1024, BufferUsages::STORAGE)
            .unwrap_or_else(|e| {
                panic!("Failed to create buffer: {}", e);
            });

        assert_eq!(buf.len(), 1024);
        assert!(!buf.is_empty());
    }

    #[tokio::test]
    #[ignore]
    async fn test_gpu_buffer_write_read() {
        let ctx = match GpuContext::new().await {
            Ok(ctx) => ctx,
            Err(_) => return,
        };

        let input: Vec<f32> = (0..100).map(|i| i as f32).collect();

        let src = GpuBuffer::from_data(
            &ctx,
            &input,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        )
        .unwrap_or_else(|e| {
            panic!("Failed to create buffer: {}", e);
        });

        let mut staging = GpuBuffer::staging(&ctx, 100).unwrap_or_else(|e| {
            panic!("Failed to create staging buffer: {}", e);
        });

        staging.copy_from(&src).unwrap_or_else(|e| {
            panic!("Failed to copy buffer: {}", e);
        });

        let readback = staging.read().await.unwrap_or_else(|e| {
            panic!("Failed to read buffer: {}", e);
        });

        assert_eq!(readback.len(), input.len());
        for (got, want) in readback.iter().zip(input.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[tokio::test]
    async fn test_gpu_raster_buffer() {
        let ctx = match GpuContext::new().await {
            Ok(ctx) => ctx,
            Err(_) => return,
        };

        let (width, height, num_bands) = (64, 64, 3);

        let raster: GpuRasterBuffer<f32> =
            GpuRasterBuffer::new(&ctx, width, height, num_bands, BufferUsages::STORAGE)
                .unwrap_or_else(|e| {
                    panic!("Failed to create raster buffer: {}", e);
                });

        assert_eq!(raster.width(), width);
        assert_eq!(raster.height(), height);
        assert_eq!(raster.num_bands(), num_bands);
    }
}