/// Owns a `wgpu` device together with the queue used to submit work to it.
///
/// Construct one with [`GpuContext::new`]. The fields are crate-visible so other
/// modules can build pipelines and bind groups directly against the device.
pub struct GpuContext {
/// Logical device: creates buffers, command encoders, and is polled for completion.
pub(crate) device: wgpu::Device,
/// Queue that command buffers recorded against `device` are submitted to.
pub(crate) queue: wgpu::Queue,
}
impl GpuContext {
pub fn new() -> Result<Self, String> {
pollster::block_on(Self::new_async())
}
async fn new_async() -> Result<Self, String> {
let instance = wgpu::Instance::default();
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
})
.await
.ok_or_else(|| "no GPU adapter found".to_string())?;
let (device, queue) = adapter
.request_device(&wgpu::DeviceDescriptor::default(), None)
.await
.map_err(|e| format!("device request failed: {e}"))?;
Ok(Self { device, queue })
}
pub(crate) fn create_storage_buffer(&self, data: &[f32]) -> wgpu::Buffer {
use wgpu::util::DeviceExt;
self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("storage buffer"),
contents: bytemuck::cast_slice(data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
})
}
pub(crate) fn create_output_buffer(&self, size: u64) -> wgpu::Buffer {
self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("output buffer"),
size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
})
}
pub(crate) fn create_staging_buffer(&self, size: u64) -> wgpu::Buffer {
self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("staging buffer"),
size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
})
}
pub(crate) fn read_buffer(&self, buffer: &wgpu::Buffer, size: u64) -> Vec<f32> {
let staging = self.create_staging_buffer(size);
let mut encoder = self.device.create_command_encoder(&Default::default());
encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size);
self.queue.submit(std::iter::once(encoder.finish()));
let slice = staging.slice(..);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
self.device.poll(wgpu::Maintain::Wait);
rx.recv().unwrap().unwrap();
let data = slice.get_mapped_range();
bytemuck::cast_slice(&data).to_vec()
}
pub(crate) fn dispatch(
&self,
pipeline: &wgpu::ComputePipeline,
bind_group: &wgpu::BindGroup,
workgroups_x: u32,
) {
let mut encoder = self.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(pipeline);
pass.set_bind_group(0, bind_group, &[]);
pass.dispatch_workgroups(workgroups_x, 1, 1);
}
self.queue.submit(std::iter::once(encoder.finish()));
}
}