use bytemuck::Pod;
use std::collections::HashMap;
use std::marker::PhantomData;
/// Creates a buffer initialised with `data` and usable as a storage buffer.
///
/// The buffer is created with `STORAGE | COPY_DST | COPY_SRC` usage. An empty
/// `label` is forwarded to wgpu as "no label" (`None`).
///
/// # Panics
///
/// Panics if `data` is empty.
pub fn storage_buffer_init(device: &wgpu::Device, label: &str, data: &[u8]) -> wgpu::Buffer {
    use wgpu::util::DeviceExt as _;

    assert!(
        !data.is_empty(),
        "storage_buffer_init: data must be non-empty"
    );

    let usage = wgpu::BufferUsages::STORAGE
        | wgpu::BufferUsages::COPY_DST
        | wgpu::BufferUsages::COPY_SRC;
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: non_empty_label(label),
        contents: data,
        usage,
    };
    device.create_buffer_init(&descriptor)
}
/// Creates a buffer initialised with `data` and usable as a uniform buffer.
///
/// The buffer is created with `UNIFORM | COPY_DST` usage. An empty `label` is
/// forwarded to wgpu as "no label" (`None`).
///
/// # Panics
///
/// Panics if `data` is empty.
pub fn uniform_buffer(device: &wgpu::Device, label: &str, data: &[u8]) -> wgpu::Buffer {
    use wgpu::util::DeviceExt as _;

    assert!(!data.is_empty(), "uniform_buffer: data must be non-empty");

    let usage = wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST;
    let label = non_empty_label(label);
    device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label,
        contents: data,
        usage,
    })
}
/// Creates an unmapped staging buffer of `size` bytes for GPU-to-host reads.
///
/// The buffer is created with `MAP_READ | COPY_DST` usage and is not mapped at
/// creation. An empty `label` is forwarded to wgpu as "no label" (`None`).
///
/// # Panics
///
/// Panics if `size` is zero.
pub fn staging_buffer(device: &wgpu::Device, label: &str, size: u64) -> wgpu::Buffer {
    assert!(size > 0, "staging_buffer: size must be > 0");

    let descriptor = wgpu::BufferDescriptor {
        label: non_empty_label(label),
        size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}
/// Reads the first `len` elements of type `T` back from `buf` into host memory.
///
/// The data is copied into a temporary staging buffer, the copy is submitted on
/// `queue`, and the calling thread then blocks in `device.poll(..)` until the
/// staging buffer has been mapped.
///
/// The GPU-side copy length is rounded up to `wgpu::COPY_BUFFER_ALIGNMENT`,
/// because wgpu rejects buffer-to-buffer copies whose size is not a multiple of
/// it. Only the first `len * size_of::<T>()` bytes are returned. When rounding
/// applies, `buf` must be at least as large as the rounded size.
///
/// `buf` is the source of a buffer-to-buffer copy, which wgpu only permits for
/// buffers created with `COPY_SRC` usage.
///
/// # Panics
///
/// Panics if `len * size_of::<T>()` is zero or overflows `usize`, if polling
/// the device fails, if the map callback never reports a result, or if mapping
/// the staging buffer fails.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "debug", skip(device, queue, buf))
)]
pub fn read_back<T: Pod>(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    buf: &wgpu::Buffer,
    len: usize,
) -> Vec<T> {
    // Checked so that an overflowing request fails loudly instead of wrapping
    // to a smaller size in release builds.
    let byte_len = len
        .checked_mul(std::mem::size_of::<T>())
        .expect("read_back: requested size overflows usize");
    assert!(byte_len > 0, "read_back: requested size must be > 0");
    // `usize` -> `u64` never truncates on targets with pointers of 64 bits or fewer.
    let copy_size = (byte_len as u64).next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT);

    let staging = staging_buffer(device, "oxiui-compute-wgpu readback staging", copy_size);
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("oxiui-compute-wgpu readback encoder"),
    });
    encoder.copy_buffer_to_buffer(buf, 0, &staging, 0, copy_size);
    queue.submit(std::iter::once(encoder.finish()));

    // The map callback hands its result back through a channel; the blocking
    // poll below is what drives the callback.
    let slice = staging.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    device
        .poll(wgpu::PollType::wait_indefinitely())
        .expect("read_back: device poll failed");
    rx.recv()
        .expect("read_back: channel closed before map callback")
        .expect("read_back: GPU mapping failed");

    let mapped = slice.get_mapped_range();
    // Drop any alignment padding before reinterpreting the bytes as `T`.
    let result: Vec<T> = bytemuck::cast_slice::<u8, T>(&mapped[..byte_len]).to_vec();
    // The mapped view must be released before the buffer can be unmapped.
    drop(mapped);
    staging.unmap();
    result
}
/// Creates a storage buffer initialised with `data` by writing through a
/// mapping that is established at creation time.
///
/// The buffer is created with `STORAGE | COPY_SRC` usage and is unmapped before
/// it is returned. An empty `label` is forwarded to wgpu as "no label".
///
/// wgpu requires the size of a buffer created with `mapped_at_creation: true`
/// to be a multiple of `wgpu::COPY_BUFFER_ALIGNMENT`, so the allocation is
/// rounded up to that alignment. `data` occupies the start of the buffer and any
/// trailing padding bytes are zeroed. For inputs whose length is already
/// aligned the buffer size equals `data.len()`.
///
/// # Panics
///
/// Panics if `data` is empty.
pub fn mapped_storage_init(device: &wgpu::Device, label: &str, data: &[u8]) -> wgpu::Buffer {
    assert!(
        !data.is_empty(),
        "mapped_storage_init: data must be non-empty"
    );

    // `usize` -> `u64` never truncates on targets with pointers of 64 bits or fewer.
    let padded_size = (data.len() as u64).next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT);
    let buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: non_empty_label(label),
        size: padded_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: true,
    });

    {
        // Scoped so the mapped view is dropped before `unmap` is called.
        let mut view = buffer.slice(..).get_mapped_range_mut();
        let (payload, padding) = view.split_at_mut(data.len());
        payload.copy_from_slice(data);
        padding.fill(0);
    }
    buffer.unmap();
    buffer
}
/// Reads `len` elements of type `T` from `src`, starting `byte_offset` bytes
/// into the buffer, and returns them in a `Vec`.
///
/// The data is copied into a temporary unlabeled staging buffer and the calling
/// thread blocks in `device.poll(..)` until that buffer has been mapped.
///
/// The GPU-side copy length is rounded up to `wgpu::COPY_BUFFER_ALIGNMENT`,
/// because wgpu rejects buffer-to-buffer copies whose size is not a multiple of
/// it. Only the first `len * size_of::<T>()` bytes are returned. When rounding
/// applies, `src` must extend at least that far past `byte_offset`.
///
/// wgpu also requires `byte_offset` to be a multiple of
/// `wgpu::COPY_BUFFER_ALIGNMENT` and `src` to carry `COPY_SRC` usage; neither
/// is checked here.
///
/// # Panics
///
/// Panics if `len * size_of::<T>()` is zero or overflows `usize`, if polling
/// the device fails, if the map callback never reports a result, or if mapping
/// the staging buffer fails.
pub fn read_back_range<T: bytemuck::Pod>(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    src: &wgpu::Buffer,
    byte_offset: u64,
    len: usize,
) -> Vec<T> {
    // Checked so that an overflowing request fails loudly instead of wrapping
    // to a smaller size in release builds.
    let byte_len = len
        .checked_mul(std::mem::size_of::<T>())
        .expect("read_back_range: requested size overflows usize");
    assert!(byte_len > 0, "read_back_range: requested size must be > 0");
    // `usize` -> `u64` never truncates on targets with pointers of 64 bits or fewer.
    let copy_size = (byte_len as u64).next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT);

    // The empty label makes `staging_buffer` create an unlabeled buffer.
    let staging = staging_buffer(device, "", copy_size);
    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    encoder.copy_buffer_to_buffer(src, byte_offset, &staging, 0, copy_size);
    queue.submit(std::iter::once(encoder.finish()));

    // The map callback hands its result back through a channel; the blocking
    // poll below is what drives the callback.
    let slice = staging.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    device
        .poll(wgpu::PollType::wait_indefinitely())
        .expect("read_back_range: device poll failed");
    rx.recv()
        .expect("read_back_range: channel closed before map callback")
        .expect("read_back_range: GPU mapping failed");

    let mapped = slice.get_mapped_range();
    // Drop any alignment padding before reinterpreting the bytes as `T`.
    let result = bytemuck::cast_slice::<u8, T>(&mapped[..byte_len]).to_vec();
    // The mapped view must be released before the buffer can be unmapped.
    drop(mapped);
    staging.unmap();
    result
}
/// `async` variant of the read-back helper: copies the first `len` elements of
/// type `T` out of `src` and returns them, reporting failures as
/// `ComputeError::Operation` instead of panicking.
///
/// NOTE: despite the `async` signature, this function blocks the calling thread
/// in `device.poll(..)` and on the channel receive. Its only await point is an
/// already-completed future, so it never yields to the executor while waiting
/// for the GPU.
///
/// The GPU-side copy length is rounded up to `wgpu::COPY_BUFFER_ALIGNMENT`,
/// because wgpu rejects buffer-to-buffer copies whose size is not a multiple of
/// it. Only the first `len * size_of::<T>()` bytes are returned. When rounding
/// applies, `src` must be at least as large as the rounded size.
///
/// # Errors
///
/// Returns `ComputeError::Operation` with `op = "read_back_async"` if the
/// requested byte size overflows `usize`, if polling the device fails, if the
/// map callback never reports a result, or if mapping the staging buffer fails.
///
/// # Panics
///
/// Panics if `len * size_of::<T>()` is zero.
pub async fn read_back_async<T: bytemuck::Pod>(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    src: &wgpu::Buffer,
    len: usize,
) -> Result<Vec<T>, crate::ComputeError> {
    // All failures are reported with the same operation tag.
    let op_error = |detail: String| crate::ComputeError::Operation {
        op: "read_back_async",
        detail,
    };

    // Checked so that an overflowing request is reported instead of wrapping to
    // a smaller size in release builds.
    let byte_len = len
        .checked_mul(std::mem::size_of::<T>())
        .ok_or_else(|| op_error("requested size overflows usize".into()))?;
    assert!(byte_len > 0, "read_back_async: requested size must be > 0");
    // `usize` -> `u64` never truncates on targets with pointers of 64 bits or fewer.
    let copy_size = (byte_len as u64).next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT);

    let staging = staging_buffer(device, "read-back-async", copy_size);
    let mut encoder = device.create_command_encoder(&Default::default());
    encoder.copy_buffer_to_buffer(src, 0, &staging, 0, copy_size);
    queue.submit(std::iter::once(encoder.finish()));

    let (tx, rx) = std::sync::mpsc::channel::<Result<(), wgpu::BufferAsyncError>>();
    staging.slice(..).map_async(wgpu::MapMode::Read, move |r| {
        let _ = tx.send(r);
    });

    // Completes immediately; kept so the function body contains an await point.
    std::future::ready(()).await;

    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| op_error(e.to_string()))?;
    rx.recv()
        .map_err(|_| op_error("channel closed before map callback fired".into()))?
        .map_err(|e| op_error(e.to_string()))?;

    let mapped = staging.slice(..).get_mapped_range();
    // Drop any alignment padding before reinterpreting the bytes as `T`.
    let data = bytemuck::cast_slice::<u8, T>(&mapped[..byte_len]).to_vec();
    // The mapped view must be released before the buffer can be unmapped.
    drop(mapped);
    staging.unmap();
    Ok(data)
}
/// A `wgpu::Buffer` paired with the element type `T` and the element count it
/// was created for.
pub struct TypedBuffer<T: bytemuck::Pod> {
    /// The underlying GPU buffer.
    buffer: wgpu::Buffer,
    /// Number of `T` elements the buffer was created for.
    len: usize,
    /// Ties the buffer to its element type without storing a `T`.
    _phantom: PhantomData<T>,
}

impl<T: bytemuck::Pod> TypedBuffer<T> {
    /// Size in bytes of `len` elements of `T`.
    ///
    /// Uses a checked multiplication so an overflowing size fails loudly
    /// instead of wrapping to a smaller value in release builds.
    fn byte_size_for(len: usize) -> u64 {
        let bytes = len
            .checked_mul(std::mem::size_of::<T>())
            .expect("TypedBuffer: byte size overflows usize");
        // `usize` -> `u64` never truncates on targets with pointers of 64 bits or fewer.
        bytes as u64
    }

    /// Creates an unmapped buffer sized for `len` elements of `T` with the
    /// given `usage`. An empty `label` is forwarded to wgpu as "no label".
    ///
    /// # Panics
    ///
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn new(device: &wgpu::Device, label: &str, usage: wgpu::BufferUsages, len: usize) -> Self {
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: non_empty_label(label),
            size: Self::byte_size_for(len),
            usage,
            mapped_at_creation: false,
        });
        TypedBuffer {
            buffer,
            len,
            _phantom: PhantomData,
        }
    }

    /// Creates a buffer initialised with `data` via [`storage_buffer_init`],
    /// i.e. with `STORAGE | COPY_DST | COPY_SRC` usage.
    ///
    /// # Panics
    ///
    /// Panics if `data` is empty (or `T` is zero-sized), because
    /// `storage_buffer_init` rejects empty byte slices.
    pub fn from_data(device: &wgpu::Device, label: &str, data: &[T]) -> Self {
        let bytes = bytemuck::cast_slice(data);
        let buffer = storage_buffer_init(device, label, bytes);
        TypedBuffer {
            buffer,
            len: data.len(),
            _phantom: PhantomData,
        }
    }

    /// Number of `T` elements in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the buffer holds zero elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Size of the element data in bytes (`len * size_of::<T>()`).
    pub fn byte_len(&self) -> u64 {
        Self::byte_size_for(self.len)
    }

    /// Returns a binding resource covering the entire underlying buffer.
    pub fn as_entire_binding(&self) -> wgpu::BindingResource<'_> {
        self.buffer.as_entire_binding()
    }

    /// Borrows the underlying `wgpu::Buffer`.
    pub fn inner(&self) -> &wgpu::Buffer {
        &self.buffer
    }

    /// Schedules a write of `data` to the start of the buffer through `queue`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from [`TypedBuffer::len`].
    pub fn upload(&self, queue: &wgpu::Queue, data: &[T]) {
        assert_eq!(data.len(), self.len, "upload length mismatch");
        queue.write_buffer(&self.buffer, 0, bytemuck::cast_slice(data));
    }

    /// Reads all elements back to the host with [`read_back`], blocking until
    /// the data is available.
    ///
    /// An empty buffer yields an empty `Vec` without touching the GPU, because
    /// `read_back` rejects zero-sized requests.
    pub fn download(&self, device: &wgpu::Device, queue: &wgpu::Queue) -> Vec<T> {
        if self.is_empty() {
            return Vec::new();
        }
        read_back(device, queue, &self.buffer, self.len)
    }
}
/// A pool of reusable GPU buffers, bucketed by rounded size and usage flags.
///
/// Requested sizes are rounded up to the next power of two, with a minimum of
/// 256, and buffers are only reused for requests that land in the same
/// `(rounded size, usage)` bucket.
pub struct BufferPool {
    /// Idle buffers keyed by `(rounded size, usage)`.
    buckets: HashMap<(u64, wgpu::BufferUsages), Vec<wgpu::Buffer>>,
}

impl BufferPool {
    /// Smallest bucket size handed out by the pool.
    const MIN_BUCKET_SIZE: u64 = 256;

    /// Maps a requested size to its bucket size.
    ///
    /// If the next power of two does not fit in a `u64`, the requested size is
    /// used unrounded, so the bucket size is never smaller than the request.
    /// An unchecked `next_power_of_two` would wrap to 0 in release builds and
    /// then be clamped up to the minimum bucket.
    fn bucket_size(size: u64) -> u64 {
        size.checked_next_power_of_two()
            .unwrap_or(size)
            .max(Self::MIN_BUCKET_SIZE)
    }

    /// Creates an empty pool.
    pub fn new() -> Self {
        BufferPool {
            buckets: HashMap::new(),
        }
    }

    /// Returns an idle buffer from the bucket matching `size` and `usage`, or
    /// creates a new unmapped buffer of the bucket size labeled `"pool-buffer"`
    /// when the bucket has none.
    pub fn acquire(
        &mut self,
        device: &wgpu::Device,
        size: u64,
        usage: wgpu::BufferUsages,
    ) -> wgpu::Buffer {
        let rounded = Self::bucket_size(size);
        // Look the bucket up without inserting, so a miss leaves no empty entry behind.
        let pooled = self.buckets.get_mut(&(rounded, usage)).and_then(Vec::pop);
        match pooled {
            Some(buf) => buf,
            None => device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("pool-buffer"),
                size: rounded,
                usage,
                mapped_at_creation: false,
            }),
        }
    }

    /// Returns `buffer` to the bucket selected by `size` and `usage`.
    ///
    /// The bucket is chosen from the arguments only and `buffer` itself is not
    /// inspected, so callers should pass the same `size` and `usage` that were
    /// used to acquire it.
    pub fn release(&mut self, size: u64, usage: wgpu::BufferUsages, buffer: wgpu::Buffer) {
        self.buckets
            .entry((Self::bucket_size(size), usage))
            .or_default()
            .push(buffer);
    }

    /// Number of idle buffers in the bucket matching `size` and `usage`.
    pub fn available_count(&self, size: u64, usage: wgpu::BufferUsages) -> usize {
        self.buckets
            .get(&(Self::bucket_size(size), usage))
            .map_or(0, Vec::len)
    }
}

impl Default for BufferPool {
    /// Equivalent to [`BufferPool::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A region handed out by `SubAllocator::alloc`, described by its start offset
/// and its size (presumably both in bytes of the backing buffer — confirm
/// against callers).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SubRegion {
/// Start of the region; `alloc` sets this to the aligned cursor position.
pub offset: u64,
/// Size of the region exactly as requested from `alloc` (not rounded up).
pub size: u64,
}
/// A linear allocator that hands out aligned sub-regions of one backing buffer.
///
/// Regions are carved off front-to-back by advancing a cursor. Individual
/// regions cannot be freed; [`SubAllocator::reset`] rewinds the cursor so the
/// whole buffer can be reused.
pub struct SubAllocator {
    /// The buffer that the returned regions refer to.
    buffer: wgpu::Buffer,
    /// Upper bound for the end of any region; never exceeds the buffer size.
    capacity: u64,
    /// End of the most recently allocated region (0 when nothing is allocated).
    cursor: u64,
    /// Alignment applied to every region offset; always at least 1.
    alignment: u64,
}

impl SubAllocator {
    /// Creates an allocator over `buffer`.
    ///
    /// `capacity` is clamped to `buffer.size()` so that an overstated capacity
    /// can never produce regions that extend past the end of the buffer. An
    /// `alignment` of 0 is treated as 1.
    pub fn new(buffer: wgpu::Buffer, capacity: u64, alignment: u64) -> Self {
        let capacity = capacity.min(buffer.size());
        SubAllocator {
            buffer,
            capacity,
            cursor: 0,
            alignment: alignment.max(1),
        }
    }

    /// Allocates `size` units at the next aligned offset.
    ///
    /// Returns `None`, leaving the allocator unchanged, when the region would
    /// end past the capacity or its end offset would overflow `u64`.
    pub fn alloc(&mut self, size: u64) -> Option<SubRegion> {
        let aligned_cursor = align_up(self.cursor, self.alignment);
        let end = aligned_cursor.checked_add(size)?;
        if end > self.capacity {
            return None;
        }
        self.cursor = end;
        Some(SubRegion {
            offset: aligned_cursor,
            size,
        })
    }

    /// Rewinds the cursor to the start, making the whole capacity available
    /// again. Previously returned regions are not tracked or invalidated.
    pub fn reset(&mut self) {
        self.cursor = 0;
    }

    /// Borrows the backing buffer.
    pub fn inner(&self) -> &wgpu::Buffer {
        &self.buffer
    }

    /// Current cursor position: the end of the last allocated region, which
    /// includes any alignment padding inserted before it.
    pub fn used(&self) -> u64 {
        self.cursor
    }

    /// Capacity left after the cursor. Alignment padding needed by the next
    /// allocation is not subtracted.
    pub fn remaining(&self) -> u64 {
        self.capacity.saturating_sub(self.cursor)
    }
}
/// Converts a label string into the optional form used by buffer descriptors:
/// `None` for an empty string, `Some(label)` otherwise.
#[inline]
fn non_empty_label(label: &str) -> Option<&str> {
    (!label.is_empty()).then_some(label)
}
/// Rounds `value` up to the nearest multiple of `alignment`.
///
/// An `alignment` of 0 means "no alignment" and returns `value` unchanged.
/// The alignment does not need to be a power of two.
#[inline]
fn align_up(value: u64, alignment: u64) -> u64 {
    match alignment {
        0 => value,
        step => {
            // Number of whole `step`-sized blocks needed to cover `value`.
            let blocks = value / step + u64::from(value % step != 0);
            blocks * step
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ComputeContext;

    /// Binds `$ctx` to a `ComputeContext`, or returns early from the enclosing
    /// test when `ComputeContext::try_new()` yields `None`.
    macro_rules! require_gpu {
        ($ctx:ident) => {
            let $ctx = match ComputeContext::try_new() {
                Some(available) => available,
                None => return,
            };
        };
    }

    /// Data uploaded with `storage_buffer_init` comes back unchanged.
    #[test]
    fn storage_buffer_init_roundtrip() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let buf = storage_buffer_init(
            &ctx.device,
            "test-storage",
            bytemuck::cast_slice::<f32, u8>(data.as_slice()),
        );
        let back = read_back::<f32>(&ctx.device, &ctx.queue, &buf, data.len());
        assert_eq!(back, data);
    }

    /// Creating a uniform buffer from non-empty data does not panic.
    #[test]
    fn uniform_buffer_created() {
        require_gpu!(ctx);
        let values: [f32; 4] = [0.1, 0.2, 0.3, 0.4];
        let _buf = uniform_buffer(
            &ctx.device,
            "test-uniform",
            bytemuck::cast_slice::<f32, u8>(&values),
        );
    }

    /// Creating a staging buffer with a non-zero size does not panic.
    #[test]
    fn staging_buffer_created() {
        require_gpu!(ctx);
        let _buf = staging_buffer(&ctx.device, "test-staging", 256);
    }

    /// Empty labels map to `None`; everything else is passed through.
    #[test]
    fn non_empty_label_behaviour() {
        assert_eq!(non_empty_label("foo"), Some("foo"));
        assert_eq!(non_empty_label(""), None);
    }

    /// Sanity check of the element-count to byte-size arithmetic for `f32`.
    #[test]
    fn typed_buffer_len_math() {
        let elem_size = std::mem::size_of::<f32>();
        assert_eq!(elem_size, 4);
        let len: usize = 8;
        let total = len * elem_size;
        assert_eq!(total, 32);
        assert_eq!(total as u64, 32u64);
    }

    /// `align_up` places the offset after a 100-unit region on the next
    /// 256 boundary.
    #[test]
    fn suballocator_offsets_aligned() {
        let first_aligned = align_up(0, 256);
        assert_eq!(first_aligned, 0);
        // Cursor position after a 100-unit region placed at the first offset.
        let second_aligned = align_up(first_aligned + 100, 256);
        assert!(
            second_aligned >= 256,
            "second offset {second_aligned} should be >= 256"
        );
        assert_eq!(second_aligned % 256, 0);
    }

    /// After `reset`, allocation starts again at offset 0.
    #[test]
    fn suballocator_reset_rewinds() {
        require_gpu!(ctx);
        let descriptor = wgpu::BufferDescriptor {
            label: Some("sub-alloc-test"),
            size: 1024,
            usage: wgpu::BufferUsages::STORAGE,
            mapped_at_creation: false,
        };
        let mut allocator = SubAllocator::new(ctx.device.create_buffer(&descriptor), 1024, 256);

        let before_reset = allocator.alloc(100).expect("first alloc should succeed");
        assert_eq!(before_reset.offset, 0);

        allocator.reset();

        let after_reset = allocator
            .alloc(100)
            .expect("post-reset alloc should succeed");
        assert_eq!(
            after_reset.offset, 0,
            "after reset, offset must restart at 0"
        );
    }

    /// The rounding rule used by the pool: next power of two, at least 256.
    #[test]
    fn buffer_pool_size_rounds_up() {
        for (requested, expected) in [(256u64, 256u64), (300, 512)] {
            assert_eq!(requested.next_power_of_two(), expected);
        }
        for (requested, expected) in [(1u64, 256u64), (255, 256)] {
            assert_eq!(requested.next_power_of_two().max(256), expected);
        }
    }

    /// Casting `f32` data to bytes and back preserves the values.
    #[test]
    fn bytemuck_pod_roundtrip() {
        let original: [f32; 3] = [1.0, 2.0, 3.0];
        let as_bytes: &[u8] = bytemuck::cast_slice(&original);
        assert_eq!(as_bytes.len(), 12);
        let restored: &[f32] = bytemuck::cast_slice(as_bytes);
        assert_eq!(restored, &original);
    }

    /// A released buffer is counted as available and consumed by the next
    /// acquire from the same bucket.
    #[test]
    fn pool_acquire_reuses_buffer() {
        require_gpu!(ctx);
        let usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
        let size: u64 = 256;
        let mut pool = BufferPool::new();
        assert_eq!(pool.available_count(size, usage), 0);

        let first = pool.acquire(&ctx.device, size, usage);
        pool.release(size, usage, first);
        assert_eq!(pool.available_count(size, usage), 1);

        let _second = pool.acquire(&ctx.device, size, usage);
        assert_eq!(pool.available_count(size, usage), 0);
    }

    /// Data written by `mapped_storage_init` is visible to a later GPU copy.
    #[test]
    fn mapped_init_roundtrip() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let bytes = bytemuck::cast_slice::<f32, u8>(data.as_slice());
        let byte_count = bytes.len() as u64;

        let src = mapped_storage_init(&ctx.device, "mapped-init-test", bytes);
        let staging = staging_buffer(&ctx.device, "mapped-init-staging", byte_count);

        // Copy the initialised buffer into the staging buffer.
        let encoder_desc = wgpu::CommandEncoderDescriptor {
            label: Some("mapped-init-readback"),
        };
        let mut encoder = ctx.device.create_command_encoder(&encoder_desc);
        encoder.copy_buffer_to_buffer(&src, 0, &staging, 0, byte_count);
        ctx.queue.submit(std::iter::once(encoder.finish()));

        // Map the staging buffer and wait for the callback to report back.
        let slice = staging.slice(..);
        let (sender, receiver) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |outcome| {
            let _ = sender.send(outcome);
        });
        ctx.device
            .poll(wgpu::PollType::wait_indefinitely())
            .expect("poll failed");
        receiver
            .recv()
            .expect("channel closed")
            .expect("map_async failed");

        let back: Vec<f32> = {
            // The view is dropped at the end of this scope, before `unmap`.
            let view = slice.get_mapped_range();
            bytemuck::cast_slice::<u8, f32>(&view).to_vec()
        };
        staging.unmap();
        assert_eq!(back, data);
    }

    /// Reading two elements at byte offset 4 yields the second and third
    /// `f32` values.
    #[test]
    fn read_back_range_returns_subslice() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let buf = storage_buffer_init(
            &ctx.device,
            "range-test",
            bytemuck::cast_slice::<f32, u8>(data.as_slice()),
        );
        let sub = read_back_range::<f32>(&ctx.device, &ctx.queue, &buf, 4, 2);
        assert_eq!(sub, vec![20.0f32, 30.0]);
    }

    /// The async read-back returns the same data as the blocking one.
    #[test]
    fn async_readback_matches_sync() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];
        let buf = storage_buffer_init(
            &ctx.device,
            "async-readback-test",
            bytemuck::cast_slice::<f32, u8>(data.as_slice()),
        );

        let sync_result = read_back::<f32>(&ctx.device, &ctx.queue, &buf, data.len());
        let pending = read_back_async::<f32>(&ctx.device, &ctx.queue, &buf, data.len());
        let async_result: Vec<f32> = pollster::block_on(pending).expect("async readback failed");

        assert_eq!(sync_result, async_result);
        assert_eq!(async_result, data);
    }
}