/// A compiled WGSL compute pipeline bundled with the bind group layouts it
/// was created from, so callers can build matching `wgpu::BindGroup`s.
pub struct ComputePipeline {
    // Underlying wgpu pipeline; exposed read-only via `raw()`.
    pipeline: wgpu::ComputePipeline,
    // One layout per shader `@group(N)`, stored in group order.
    bind_group_layouts: Vec<wgpu::BindGroupLayout>,
}
impl ComputePipeline {
    /// Builds a pipeline whose single bind group (`@group(0)`) holds
    /// `buffer_count` storage buffers: binding 0 is read-write (the output,
    /// by convention) and every higher binding is read-only input.
    pub fn new(
        device: &wgpu::Device,
        wgsl_source: &str,
        entry_point: &str,
        buffer_count: u32,
    ) -> Self {
        let entries: Vec<wgpu::BindGroupLayoutEntry> = (0..buffer_count)
            .map(|i| wgpu::BindGroupLayoutEntry {
                binding: i,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    // Binding 0 is writable; all others are read-only.
                    ty: wgpu::BufferBindingType::Storage { read_only: i > 0 },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            })
            .collect();
        Self::with_layout(device, wgsl_source, entry_point, &entries)
    }

    /// Builds a pipeline with a single bind group described by `entries`.
    pub fn with_layout(
        device: &wgpu::Device,
        wgsl_source: &str,
        entry_point: &str,
        entries: &[wgpu::BindGroupLayoutEntry],
    ) -> Self {
        Self::with_layouts(device, wgsl_source, entry_point, &[entries])
    }

    /// Builds a pipeline with one bind group layout per slice in `groups`
    /// (`groups[N]` describes the shader's `@group(N)`).
    pub fn with_layouts(
        device: &wgpu::Device,
        wgsl_source: &str,
        entry_point: &str,
        groups: &[&[wgpu::BindGroupLayoutEntry]],
    ) -> Self {
        tracing::debug!(
            entry_point,
            groups = groups.len(),
            "creating compute pipeline"
        );
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("compute_shader"),
            source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
        });
        let bind_group_layouts: Vec<wgpu::BindGroupLayout> = groups
            .iter()
            .enumerate()
            .map(|(i, entries)| {
                // Label each layout with its group index for debugging tools.
                let label = format!("compute_layout_{i}");
                device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                    label: Some(&label),
                    entries,
                })
            })
            .collect();
        let layout_refs: Vec<Option<&wgpu::BindGroupLayout>> =
            bind_group_layouts.iter().map(Some).collect();
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("compute_pipeline_layout"),
            bind_group_layouts: &layout_refs,
            immediate_size: 0,
        });
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("compute_pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some(entry_point),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });
        Self {
            pipeline,
            bind_group_layouts,
        }
    }

    /// Layout for bind group `index`, or `None` when out of range.
    #[must_use]
    #[inline]
    pub fn bind_group_layout(&self, index: usize) -> Option<&wgpu::BindGroupLayout> {
        self.bind_group_layouts.get(index)
    }

    /// Number of bind group layouts this pipeline was created with.
    #[must_use]
    #[inline]
    pub fn bind_group_layout_count(&self) -> usize {
        self.bind_group_layouts.len()
    }

    /// The underlying `wgpu::ComputePipeline`.
    #[must_use]
    #[inline]
    pub fn raw(&self) -> &wgpu::ComputePipeline {
        &self.pipeline
    }

    /// Records and submits a dispatch that uses a single bind group.
    pub fn dispatch(
        &self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        bind_group: &wgpu::BindGroup,
        workgroups_x: u32,
        workgroups_y: u32,
        workgroups_z: u32,
    ) {
        self.dispatch_multi(
            device,
            queue,
            &[bind_group],
            workgroups_x,
            workgroups_y,
            workgroups_z,
        );
    }

    /// Records and submits a dispatch, binding `bind_groups[N]` to group `N`.
    pub fn dispatch_multi(
        &self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        bind_groups: &[&wgpu::BindGroup],
        workgroups_x: u32,
        workgroups_y: u32,
        workgroups_z: u32,
    ) {
        tracing::debug!(
            workgroups_x,
            workgroups_y,
            workgroups_z,
            groups = bind_groups.len(),
            "compute dispatch"
        );
        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("compute_encoder"),
        });
        self.encode_dispatch_multi(
            &mut encoder,
            bind_groups,
            workgroups_x,
            workgroups_y,
            workgroups_z,
        );
        queue.submit(std::iter::once(encoder.finish()));
    }

    /// Encodes a dispatch into an existing encoder with a single bind group.
    pub fn encode_dispatch(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        bind_group: &wgpu::BindGroup,
        workgroups_x: u32,
        workgroups_y: u32,
        workgroups_z: u32,
    ) {
        self.encode_dispatch_multi(
            encoder,
            &[bind_group],
            workgroups_x,
            workgroups_y,
            workgroups_z,
        );
    }

    /// Encodes a dispatch into an existing encoder, binding `bind_groups[N]`
    /// to group `N`. The compute pass ends when the returned pass guard drops
    /// at the end of this function.
    pub fn encode_dispatch_multi(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        bind_groups: &[&wgpu::BindGroup],
        workgroups_x: u32,
        workgroups_y: u32,
        workgroups_z: u32,
    ) {
        let mut pass = self.begin_pass(encoder, "compute_pass", bind_groups);
        pass.dispatch_workgroups(workgroups_x, workgroups_y, workgroups_z);
    }

    /// Encodes an indirect dispatch whose workgroup counts are read by the
    /// GPU from `indirect_buffer` at `indirect_offset`.
    pub fn encode_dispatch_indirect(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        bind_groups: &[&wgpu::BindGroup],
        indirect_buffer: &wgpu::Buffer,
        indirect_offset: u64,
    ) {
        tracing::debug!(indirect_offset, "compute indirect dispatch");
        let mut pass = self.begin_pass(encoder, "compute_pass_indirect", bind_groups);
        pass.dispatch_workgroups_indirect(indirect_buffer, indirect_offset);
    }

    /// Begins a compute pass on `encoder`, sets this pipeline, and binds
    /// `bind_groups[N]` to group slot `N`. Shared by the direct and indirect
    /// encoding paths so the setup sequence stays in one place.
    fn begin_pass<'e>(
        &self,
        encoder: &'e mut wgpu::CommandEncoder,
        label: &str,
        bind_groups: &[&wgpu::BindGroup],
    ) -> wgpu::ComputePass<'e> {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some(label),
            timestamp_writes: None,
        });
        pass.set_pipeline(&self.pipeline);
        for (i, bg) in bind_groups.iter().enumerate() {
            // Group indices are tiny (wgpu caps them well below u32::MAX),
            // so the cast cannot truncate.
            pass.set_bind_group(i as u32, *bg, &[]);
        }
        pass
    }
}
/// A pair of equally-sized storage buffers for iterative GPU algorithms:
/// read from `source()`, write to `dest()`, then `swap()` for the next pass.
pub struct PingPongBuffer {
    // The two backing buffers; identical size and usage flags.
    buffers: [wgpu::Buffer; 2],
    // Index (0 or 1) of the current source buffer; `dest` is the other one.
    current: usize,
}
impl PingPongBuffer {
    /// Creates two `size`-byte storage buffers labelled `{label}_a` and
    /// `{label}_b`, both usable as copy source and destination. The `a`
    /// buffer starts as the source.
    pub fn new(device: &wgpu::Device, size: u64, label: &str) -> Self {
        tracing::debug!(size, label, "creating ping-pong buffer pair");
        // Both buffers are identical apart from the label suffix, so build
        // them through one closure instead of duplicating the descriptor.
        let make = |suffix: char| {
            device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(&format!("{label}_{suffix}")),
                size,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_DST
                    | wgpu::BufferUsages::COPY_SRC,
                mapped_at_creation: false,
            })
        };
        Self {
            buffers: [make('a'), make('b')],
            current: 0,
        }
    }

    /// The buffer to read from this pass.
    #[must_use]
    #[inline]
    pub fn source(&self) -> &wgpu::Buffer {
        &self.buffers[self.current]
    }

    /// The buffer to write to this pass.
    #[must_use]
    #[inline]
    pub fn dest(&self) -> &wgpu::Buffer {
        &self.buffers[1 - self.current]
    }

    /// Exchanges the roles of source and destination.
    #[inline]
    pub fn swap(&mut self) {
        self.current = 1 - self.current;
    }

    /// Index (0 or 1) of the current source buffer.
    #[must_use]
    #[inline]
    pub fn index(&self) -> usize {
        self.current
    }
}
/// Checks each requested workgroup count against the device's
/// per-dimension compute limit.
///
/// # Errors
///
/// Returns [`crate::error::GpuError::WorkgroupLimitExceeded`] naming the
/// first axis (in x, y, z order) that exceeds
/// `limits.max_compute_workgroups_per_dimension`.
pub fn validate_dispatch(
    limits: &wgpu::Limits,
    workgroups_x: u32,
    workgroups_y: u32,
    workgroups_z: u32,
) -> crate::error::Result<()> {
    use crate::error::GpuError;
    let limit = limits.max_compute_workgroups_per_dimension;
    // All three axes share the same limit; check them in a loop so the
    // error construction isn't written out three times.
    for (axis, actual) in [
        ("x", workgroups_x),
        ("y", workgroups_y),
        ("z", workgroups_z),
    ] {
        if actual > limit {
            return Err(GpuError::WorkgroupLimitExceeded { axis, actual, limit });
        }
    }
    Ok(())
}
/// Number of 1-D workgroups needed to cover `total` invocations when each
/// workgroup handles `workgroup_size` of them (ceiling division).
#[must_use]
#[inline]
pub fn workgroups_1d(total: u32, workgroup_size: u32) -> u32 {
    // `div_ceil` rounds up without the overflow risk of the manual
    // `(total + workgroup_size - 1) / workgroup_size` form.
    u32::div_ceil(total, workgroup_size)
}
/// 2-D analogue of `workgroups_1d`: ceiling-divides the problem extent by
/// the workgroup dimensions on each axis independently.
#[must_use]
#[inline]
pub fn workgroups_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> (u32, u32) {
    let groups_x = width.div_ceil(wg_x);
    let groups_y = height.div_ceil(wg_y);
    (groups_x, groups_y)
}
// Unit tests for the pure helpers run everywhere; tests prefixed `gpu_`
// acquire a real device through `try_gpu()` and silently return (skip)
// when no adapter/context is available.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_pipeline_types() {
        // Smoke test: the type is well-formed and sized.
        let _size = std::mem::size_of::<ComputePipeline>();
    }

    #[test]
    fn workgroups_1d_exact() {
        // Totals that divide evenly need no extra group.
        assert_eq!(workgroups_1d(256, 256), 1);
        assert_eq!(workgroups_1d(512, 256), 2);
    }

    #[test]
    fn workgroups_1d_remainder() {
        // Any remainder rounds up to one additional workgroup.
        assert_eq!(workgroups_1d(257, 256), 2);
        assert_eq!(workgroups_1d(1, 256), 1);
    }

    #[test]
    fn workgroups_2d_exact() {
        assert_eq!(workgroups_2d(32, 32, 16, 16), (2, 2));
    }

    #[test]
    fn workgroups_2d_remainder() {
        // Each axis rounds up independently.
        assert_eq!(workgroups_2d(33, 17, 16, 16), (3, 2));
    }

    #[test]
    fn workgroups_1d_single() {
        // Zero total work dispatches zero workgroups.
        assert_eq!(workgroups_1d(1, 64), 1);
        assert_eq!(workgroups_1d(0, 64), 0);
    }

    #[test]
    fn workgroups_2d_single() {
        assert_eq!(workgroups_2d(1, 1, 8, 8), (1, 1));
        assert_eq!(workgroups_2d(0, 0, 8, 8), (0, 0));
    }

    #[test]
    fn validate_dispatch_within_limits() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 65535,
            ..Default::default()
        };
        assert!(validate_dispatch(&limits, 100, 100, 1).is_ok());
        // The limit itself is inclusive.
        assert!(validate_dispatch(&limits, 65535, 65535, 65535).is_ok());
    }

    #[test]
    fn validate_dispatch_exceeds_limits() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 65535,
            ..Default::default()
        };
        // One-past-the-limit fails on each axis independently.
        assert!(validate_dispatch(&limits, 65536, 1, 1).is_err());
        assert!(validate_dispatch(&limits, 1, 65536, 1).is_err());
        assert!(validate_dispatch(&limits, 1, 1, 65536).is_err());
    }

    #[test]
    fn validate_dispatch_error_contains_axis() {
        let limits = wgpu::Limits {
            max_compute_workgroups_per_dimension: 100,
            ..Default::default()
        };
        // The error message must name the offending axis.
        let err = validate_dispatch(&limits, 200, 1, 1).unwrap_err();
        assert!(err.to_string().contains("x"));
        let err = validate_dispatch(&limits, 1, 200, 1).unwrap_err();
        assert!(err.to_string().contains("y"));
    }

    #[test]
    fn workgroups_1d_large() {
        assert_eq!(workgroups_1d(1_000_000, 256), 3907);
        // u32::MAX must not overflow (div_ceil, not the +size-1 trick).
        assert_eq!(workgroups_1d(u32::MAX, 256), 16_777_216);
    }

    #[test]
    fn ping_pong_swap() {
        // Exercises the 0 <-> 1 index arithmetic used by PingPongBuffer
        // without needing a GPU device.
        let mut current = 0usize;
        assert_eq!(current, 0);
        assert_eq!(1 - current, 1);
        current = 1 - current;
        assert_eq!(current, 1);
        assert_eq!(1 - current, 0);
        current = 1 - current;
        assert_eq!(current, 0);
    }

    #[test]
    fn ping_pong_types() {
        // Smoke test: the type is well-formed and sized.
        let _size = std::mem::size_of::<PingPongBuffer>();
    }

    // Tries to create a GPU context; returns None (so callers skip) when
    // context creation fails, e.g. on CI machines without a GPU.
    fn try_gpu() -> Option<(wgpu::Device, wgpu::Queue)> {
        let ctx = pollster::block_on(crate::context::GpuContext::new()).ok()?;
        Some((ctx.device, ctx.queue))
    }

    // Doubles each input element; matches the layout produced by
    // `ComputePipeline::new`: binding 0 read-write output, binding 1
    // read-only input.
    const DOUBLE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> output: array<f32>;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3u) {
    if id.x < arrayLength(&input) {
        output[id.x] = input[id.x] * 2.0;
    }
}
"#;

    #[test]
    fn gpu_compute_pipeline_create() {
        let Some((device, _queue)) = try_gpu() else {
            return;
        };
        let pipeline = ComputePipeline::new(&device, DOUBLE_SHADER, "main", 2);
        // `new` produces exactly one bind group layout (group 0).
        assert_eq!(pipeline.bind_group_layout_count(), 1);
        assert!(pipeline.bind_group_layout(0).is_some());
        assert!(pipeline.bind_group_layout(1).is_none());
    }

    #[test]
    fn gpu_compute_dispatch_roundtrip() {
        let Some((device, queue)) = try_gpu() else {
            return;
        };
        let pipeline = ComputePipeline::new(&device, DOUBLE_SHADER, "main", 2);
        let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let input_buf = crate::buffer::create_storage_buffer(
            &device,
            bytemuck::cast_slice(&input),
            "input",
            true,
        );
        // 16 bytes = 4 f32 outputs.
        let output_buf = crate::buffer::create_storage_buffer_empty(&device, 16, "output", false);
        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("test_bg"),
            layout: pipeline.bind_group_layout(0).unwrap(),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: output_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: input_buf.as_entire_binding(),
                },
            ],
        });
        // One workgroup of 64 covers all 4 elements.
        pipeline.dispatch(&device, &queue, &bind_group, 1, 1, 1);
        let result: Vec<f32> =
            crate::buffer::read_buffer_typed(&device, &queue, &output_buf, 4).unwrap();
        assert_eq!(result, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn gpu_ping_pong_buffer() {
        let Some((device, _queue)) = try_gpu() else {
            return;
        };
        let mut pp = PingPongBuffer::new(&device, 64, "pp_test");
        assert_eq!(pp.index(), 0);
        // Compare buffer identities by address to confirm swap() exchanges
        // roles rather than creating new buffers.
        let src0 = pp.source() as *const _;
        let dst0 = pp.dest() as *const _;
        pp.swap();
        assert_eq!(pp.index(), 1);
        assert_eq!(src0, pp.dest() as *const _);
        assert_eq!(dst0, pp.source() as *const _);
    }
}