use crate::error::{GpuError, Result};
/// Typed handle to a GPU buffer created from a single `T` value.
///
/// `T` is recorded only at the type level through `PhantomData`; the struct
/// itself stores just the `wgpu::Buffer`.
pub struct UniformBuffer<T: bytemuck::Pod> {
/// Underlying GPU buffer.
buffer: wgpu::Buffer,
/// Zero-sized marker that ties this wrapper to `T`.
_marker: std::marker::PhantomData<T>,
}
impl<T: bytemuck::Pod> UniformBuffer<T> {
    /// Builds a uniform buffer whose initial contents are the bytes of `data`.
    ///
    /// The buffer itself is created by `crate::buffer::create_uniform_buffer`,
    /// which receives `device`, the raw bytes of `data`, and `label`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Buffer`] when `T` is zero-sized, or when
    /// `size_of::<T>()` is not a multiple of 16 bytes.
    #[must_use = "GPU buffer allocated but not used"]
    pub fn new(device: &wgpu::Device, data: &T, label: &str) -> Result<Self> {
        match std::mem::size_of::<T>() {
            0 => Err(GpuError::Buffer("uniform buffer type has zero size".into())),
            size if !size.is_multiple_of(16) => Err(GpuError::Buffer(format!(
                "uniform buffer type size ({size} bytes) must be a multiple of 16"
            ))),
            size => {
                tracing::debug!(label, size, "creating typed uniform buffer");
                let bytes = bytemuck::bytes_of(data);
                Ok(Self {
                    buffer: crate::buffer::create_uniform_buffer(device, bytes, label),
                    _marker: std::marker::PhantomData,
                })
            }
        }
    }

    /// Passes the bytes of `data` to `queue.write_buffer`, targeting offset 0
    /// of the wrapped buffer.
    #[inline]
    pub fn write(&self, queue: &wgpu::Queue, data: &T) {
        let bytes = bytemuck::bytes_of(data);
        queue.write_buffer(&self.buffer, 0, bytes);
    }

    /// Borrows the wrapped `wgpu::Buffer`.
    #[must_use]
    #[inline]
    pub fn buffer(&self) -> &wgpu::Buffer {
        let Self { buffer, .. } = self;
        buffer
    }
}
/// Typed handle to a GPU buffer that holds up to `count` elements of `T`.
///
/// `T` is recorded only at the type level through `PhantomData`.
pub struct StorageBuffer<T: bytemuck::Pod> {
/// Underlying GPU buffer.
buffer: wgpu::Buffer,
/// Element capacity recorded at construction; `write` rejects slices longer than this.
count: usize,
/// Zero-sized marker that ties this wrapper to `T`.
_marker: std::marker::PhantomData<T>,
}
impl<T: bytemuck::Pod> StorageBuffer<T> {
    /// Creates a storage buffer initialised with the contents of `data`.
    ///
    /// The recorded element capacity is `data.len()`. `label` and `read_only`
    /// are forwarded unchanged to `crate::buffer::create_storage_buffer`
    /// (presumably `read_only` selects the buffer's usage flags — confirm
    /// against that helper).
    #[must_use = "GPU buffer allocated but not used"]
    pub fn new(device: &wgpu::Device, data: &[T], label: &str, read_only: bool) -> Self {
        tracing::debug!(
            label,
            count = data.len(),
            element_size = std::mem::size_of::<T>(),
            read_only,
            "creating typed storage buffer"
        );
        let buffer = crate::buffer::create_storage_buffer(
            device,
            bytemuck::cast_slice(data),
            label,
            read_only,
        );
        Self {
            buffer,
            count: data.len(),
            _marker: std::marker::PhantomData,
        }
    }

    /// Creates a storage buffer with room for `count` elements of `T` without
    /// uploading any initial data.
    ///
    /// The requested byte size is `count * size_of::<T>()`; if that product
    /// overflows `usize` it saturates at `usize::MAX` rather than wrapping.
    /// `label` and `read_only` are forwarded unchanged to
    /// `crate::buffer::create_storage_buffer_empty`.
    #[must_use = "GPU buffer allocated but not used"]
    pub fn empty(device: &wgpu::Device, count: usize, label: &str, read_only: bool) -> Self {
        // Saturate instead of wrapping so an absurd `count` can never alias a
        // small allocation.
        let size = count.saturating_mul(std::mem::size_of::<T>()) as u64;
        tracing::debug!(
            label,
            count,
            element_size = std::mem::size_of::<T>(),
            read_only,
            "creating empty typed storage buffer"
        );
        let buffer = crate::buffer::create_storage_buffer_empty(device, size, label, read_only);
        Self {
            buffer,
            count,
            _marker: std::marker::PhantomData,
        }
    }

    /// Writes `data` to the start of the buffer via `queue.write_buffer`.
    ///
    /// `data` may be shorter than the buffer's capacity; elements past
    /// `data.len()` are not touched by this call.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Buffer`] when:
    /// - `data.len()` exceeds the element capacity recorded at construction, or
    /// - the byte length of `data` is not a multiple of
    ///   `wgpu::COPY_BUFFER_ALIGNMENT`, which `Queue::write_buffer` requires.
    pub fn write(&self, queue: &wgpu::Queue, data: &[T]) -> Result<()> {
        if data.len() > self.count {
            return Err(GpuError::Buffer(format!(
                "storage buffer write exceeds capacity: {} > {}",
                data.len(),
                self.count
            )));
        }

        let bytes: &[u8] = bytemuck::cast_slice(data);

        // wgpu validates that the copy size is aligned; checking here surfaces
        // the problem through this method's `Result` instead of as a wgpu
        // validation error raised from inside `write_buffer`.
        let alignment = wgpu::COPY_BUFFER_ALIGNMENT;
        if !(bytes.len() as u64).is_multiple_of(alignment) {
            return Err(GpuError::Buffer(format!(
                "storage buffer write size ({} bytes) must be a multiple of {alignment}",
                bytes.len()
            )));
        }

        queue.write_buffer(&self.buffer, 0, bytes);
        Ok(())
    }

    /// Borrows the wrapped `wgpu::Buffer`.
    #[must_use]
    #[inline]
    pub fn buffer(&self) -> &wgpu::Buffer {
        &self.buffer
    }

    /// Returns the element capacity recorded when the buffer was created.
    #[must_use]
    #[inline]
    pub fn count(&self) -> usize {
        self.count
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 16-byte payload: satisfies the multiple-of-16 size rule.
    #[repr(C)]
    #[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
    struct Aligned16 {
        data: [f32; 4],
    }

    /// 32-byte payload: satisfies the multiple-of-16 size rule.
    #[repr(C)]
    #[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
    struct Aligned32 {
        a: [f32; 4],
        b: [f32; 4],
    }

    /// 48-byte payload: satisfies the multiple-of-16 size rule.
    #[repr(C)]
    #[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
    struct Aligned48 {
        data: [f32; 12],
    }

    /// 12-byte payload: violates the multiple-of-16 size rule.
    #[repr(C)]
    #[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
    struct Unaligned12 {
        data: [f32; 3],
    }

    /// Returns a device/queue pair, or `None` when `GpuContext::new` fails so
    /// that GPU-dependent tests can return early instead of failing.
    fn try_gpu() -> Option<(wgpu::Device, wgpu::Queue)> {
        pollster::block_on(crate::context::GpuContext::new())
            .ok()
            .map(|ctx| (ctx.device, ctx.queue))
    }

    /// Sanity-checks the sizes of the helper payload types.
    #[test]
    fn uniform_buffer_alignment_check() {
        let size16 = std::mem::size_of::<Aligned16>();
        let size32 = std::mem::size_of::<Aligned32>();
        let size12 = std::mem::size_of::<Unaligned12>();

        assert_eq!(size16, 16);
        assert_eq!(size32, 32);
        assert_eq!(size12, 12);

        assert!(size16.is_multiple_of(16));
        assert!(size32.is_multiple_of(16));
        assert!(!size12.is_multiple_of(16));
    }

    /// Compile-time check that `StorageBuffer<f32>` is a well-formed type.
    #[test]
    fn storage_buffer_types() {
        let _size = std::mem::size_of::<StorageBuffer<f32>>();
    }

    /// Compile-time check that `UniformBuffer<Aligned16>` is a well-formed type.
    #[test]
    fn uniform_buffer_types() {
        let _size = std::mem::size_of::<UniformBuffer<Aligned16>>();
    }

    /// The element type must not affect the wrapper's own size.
    #[test]
    fn storage_buffer_phantom_data() {
        let scalar = std::mem::size_of::<StorageBuffer<f32>>();
        let vector = std::mem::size_of::<StorageBuffer<[f32; 4]>>();
        assert_eq!(scalar, vector);
    }

    #[test]
    fn gpu_uniform_buffer_create_aligned() {
        if let Some((device, _queue)) = try_gpu() {
            let data = Aligned16 { data: [1.0; 4] };
            let buf = UniformBuffer::new(&device, &data, "test_uniform");
            assert!(buf.is_ok());
        }
    }

    #[test]
    fn gpu_uniform_buffer_reject_unaligned() {
        if let Some((device, _queue)) = try_gpu() {
            let data = Unaligned12 { data: [1.0; 3] };
            let result = UniformBuffer::new(&device, &data, "bad_uniform");
            assert!(result.is_err());
        }
    }

    #[test]
    fn gpu_uniform_buffer_write() {
        if let Some((device, queue)) = try_gpu() {
            let initial = Aligned32 {
                a: [1.0; 4],
                b: [2.0; 4],
            };
            let buf = UniformBuffer::new(&device, &initial, "write_test").unwrap();

            let updated = Aligned32 {
                a: [3.0; 4],
                b: [4.0; 4],
            };
            buf.write(&queue, &updated);
        }
    }

    #[test]
    fn gpu_storage_buffer_create() {
        if let Some((device, _queue)) = try_gpu() {
            let zeros: [f32; 8] = [0.0; 8];
            let buf = StorageBuffer::new(&device, &zeros, "test_storage", false);
            assert_eq!(buf.count(), 8);
        }
    }

    #[test]
    fn gpu_storage_buffer_empty() {
        if let Some((device, _queue)) = try_gpu() {
            let buf = StorageBuffer::<f32>::empty(&device, 256, "empty_storage", true);
            assert_eq!(buf.count(), 256);
        }
    }

    #[test]
    fn gpu_storage_buffer_write_ok() {
        if let Some((device, queue)) = try_gpu() {
            let initial: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
            let buf = StorageBuffer::new(&device, &initial, "write_test", false);
            let outcome = buf.write(&queue, &[5.0, 6.0, 7.0, 8.0]);
            assert!(outcome.is_ok());
        }
    }

    #[test]
    fn gpu_storage_buffer_write_exceeds_capacity() {
        if let Some((device, queue)) = try_gpu() {
            let initial: [f32; 2] = [1.0, 2.0];
            let buf = StorageBuffer::new(&device, &initial, "small", false);

            let too_big: [f32; 4] = [0.0; 4];
            let outcome = buf.write(&queue, &too_big);
            assert!(outcome.is_err());
        }
    }

    #[test]
    fn gpu_storage_buffer_write_partial() {
        if let Some((device, queue)) = try_gpu() {
            let initial: [f32; 8] = [0.0; 8];
            let buf = StorageBuffer::new(&device, &initial, "partial_write", false);

            let partial: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
            let outcome = buf.write(&queue, &partial);
            assert!(outcome.is_ok());
        }
    }

    #[test]
    fn gpu_storage_buffer_count_and_buffer() {
        if let Some((device, _queue)) = try_gpu() {
            let initial: [f32; 16] = [0.0; 16];
            let buf = StorageBuffer::new(&device, &initial, "accessors", false);
            assert_eq!(buf.count(), 16);
            let _raw: &wgpu::Buffer = buf.buffer();
        }
    }

    #[test]
    fn gpu_uniform_buffer_48_aligned() {
        if let Some((device, queue)) = try_gpu() {
            assert!(std::mem::size_of::<Aligned48>().is_multiple_of(16));

            let initial = Aligned48 { data: [0.5; 12] };
            let created = UniformBuffer::new(&device, &initial, "aligned48");
            assert!(created.is_ok());
            let buf = created.unwrap();

            let updated = Aligned48 { data: [1.0; 12] };
            buf.write(&queue, &updated);
        }
    }
}