use crate::error::{Error, Result};
use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::Mutex;
use wgpu;
/// A GPU storage buffer that can be written from and read back to the CPU.
///
/// The buffer is created with `STORAGE | COPY_DST | COPY_SRC` usage. Its
/// allocation is rounded up to a multiple of [`wgpu::COPY_BUFFER_ALIGNMENT`],
/// because wgpu rejects buffer-to-buffer copies and queue writes whose size is
/// not aligned. [`GpuBuffer::size`] still reports the size that was requested.
pub struct GpuBuffer {
    /// Underlying wgpu buffer; its allocation may exceed `size` by up to 3 bytes of padding.
    buffer: wgpu::Buffer,
    /// Logical size in bytes, exactly as requested by the caller.
    size: usize,
    /// Device the buffer was created on; also used for staging buffers and encoders.
    device: Arc<wgpu::Device>,
}

impl GpuBuffer {
    /// Rounds `len` up to the next multiple of [`wgpu::COPY_BUFFER_ALIGNMENT`].
    fn aligned_len(len: usize) -> usize {
        let align = wgpu::COPY_BUFFER_ALIGNMENT as usize;
        match len % align {
            0 => len,
            rem => len + (align - rem),
        }
    }

    /// Creates a new storage buffer on `device` that can hold `size` bytes.
    ///
    /// The underlying allocation is padded to the copy alignment so that
    /// [`GpuBuffer::read_data`] and [`GpuBuffer::write_data`] pass wgpu validation
    /// for any `size`.
    pub fn new(device: Arc<wgpu::Device>, size: usize) -> Self {
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("veda-gpu-buffer"),
            size: Self::aligned_len(size) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        Self {
            buffer,
            size,
            device,
        }
    }

    /// Schedules an upload of `data` to the start of the buffer via `queue`.
    ///
    /// wgpu requires queue writes to be a multiple of `COPY_BUFFER_ALIGNMENT`
    /// bytes. When `data.len()` is not aligned, the data is copied into a
    /// zero-padded temporary, so the (at most 3) bytes immediately following
    /// `data` in the buffer are overwritten with zeros.
    ///
    /// # Errors
    /// Returns an error if `data` is longer than [`GpuBuffer::size`].
    pub fn write_data(&self, queue: &wgpu::Queue, data: &[u8]) -> Result<()> {
        if data.len() > self.size {
            return Err(Error::gpu("Data too large for buffer"));
        }
        let padded_len = Self::aligned_len(data.len());
        if padded_len == data.len() {
            queue.write_buffer(&self.buffer, 0, data);
        } else {
            // padded_len <= aligned_len(self.size), i.e. still inside the allocation.
            let mut padded = Vec::with_capacity(padded_len);
            padded.extend_from_slice(data);
            padded.resize(padded_len, 0);
            queue.write_buffer(&self.buffer, 0, &padded);
        }
        Ok(())
    }

    /// Copies the buffer's contents back to the CPU.
    ///
    /// Submits a copy into a temporary `MAP_READ` staging buffer, calls
    /// `device.poll(Maintain::Wait)`, then maps the staging buffer. The
    /// returned vector is exactly [`GpuBuffer::size`] bytes long; alignment
    /// padding is stripped.
    ///
    /// # Errors
    /// Returns an error if the map callback's channel is closed or if mapping
    /// the staging buffer fails.
    pub async fn read_data(&self, queue: &wgpu::Queue) -> Result<Vec<u8>> {
        // wgpu panics when slicing an empty buffer, and there is nothing to read anyway.
        if self.size == 0 {
            return Ok(Vec::new());
        }
        let copy_len = Self::aligned_len(self.size) as u64;

        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("veda-staging-buffer"),
            size: copy_len,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("veda-copy-encoder"),
            });
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &staging_buffer, 0, copy_len);
        queue.submit(Some(encoder.finish()));

        let buffer_slice = staging_buffer.slice(..);
        let (sender, receiver) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            // The receiver is gone only if this future was dropped before the
            // callback ran; nobody is waiting then, so don't panic inside wgpu.
            let _ = sender.send(result);
        });
        self.device.poll(wgpu::Maintain::Wait);
        receiver
            .await
            .map_err(|e| Error::gpu(format!("Failed to receive map result: {}", e)))?
            .map_err(|e| Error::gpu(format!("Failed to map buffer: {:?}", e)))?;

        // The mapped view is a temporary of this statement, so it is released before unmap.
        let mut result = buffer_slice.get_mapped_range().to_vec();
        staging_buffer.unmap();
        result.truncate(self.size);
        Ok(result)
    }

    /// Returns the underlying wgpu buffer (e.g. for bind groups).
    pub fn buffer(&self) -> &wgpu::Buffer {
        &self.buffer
    }

    /// Returns the logical size in bytes that was requested at construction.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns the device this buffer was created on.
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }
}
/// A pool of reusable [`GpuBuffer`]s for one device, bucketed by exact size.
///
/// Buffers handed back via [`BufferPool::release`] are kept until a later
/// [`BufferPool::acquire`] of the same size takes them, or until
/// [`BufferPool::clear`] is called. The pool places no limit on how many idle
/// buffers it retains.
pub struct BufferPool {
    /// Device on which new buffers are created.
    device: Arc<wgpu::Device>,
    /// Idle buffers, keyed by their requested size in bytes.
    free_buffers: Mutex<HashMap<usize, Vec<GpuBuffer>>>,
}

impl BufferPool {
    /// Creates an empty pool that allocates new buffers on `device`.
    pub fn new(device: Arc<wgpu::Device>) -> Self {
        Self {
            device,
            free_buffers: Mutex::new(HashMap::new()),
        }
    }

    /// Returns an idle buffer of exactly `size` bytes, or creates a new one.
    ///
    /// A reused buffer keeps whatever contents it had when it was released.
    pub fn acquire(&self, size: usize) -> GpuBuffer {
        // The guard is a temporary of this statement, so the lock is released
        // before a fallback allocation runs, instead of being held across it.
        let reused = self.free_buffers.lock().get_mut(&size).and_then(Vec::pop);
        reused.unwrap_or_else(|| GpuBuffer::new(Arc::clone(&self.device), size))
    }

    /// Returns `buffer` to the pool so that a later `acquire` of the same size can reuse it.
    ///
    /// NOTE: this does not check that `buffer` was created on this pool's device.
    pub fn release(&self, buffer: GpuBuffer) {
        self.free_buffers
            .lock()
            .entry(buffer.size)
            .or_default()
            .push(buffer);
    }

    /// Drops every idle buffer held by the pool.
    ///
    /// Buffers that are currently acquired are not affected.
    pub fn clear(&self) {
        self.free_buffers.lock().clear();
    }
}