use wgpu::BufferUsages;
use pollster::FutureExt;
use tiny_wgpu::{
BindGroupItem, Compute, ComputeKernel, ComputeProgram, Storage
};
/// State for this example: the GPU context plus the registry of named GPU
/// resources that `main` builds up and then looks up by label.
struct ComputeExample {
    /// Named resources (`buffers`, `bind_groups`, `compute_pipelines`) that
    /// `main` indexes by string label.
    storage: Storage,
    /// GPU handle whose `device` and `queue` `main` uses to record, upload
    /// and submit work.
    compute: Compute,
}
impl ComputeProgram for ComputeExample {
fn storage(&self) -> &tiny_wgpu::Storage {
&self.storage
}
fn storage_mut(&mut self) -> &mut tiny_wgpu::Storage {
&mut self.storage
}
fn compute(&self) -> &tiny_wgpu::Compute {
&self.compute
}
}
/// Runs one compute round-trip on the GPU and checks the result.
///
/// Steps:
/// 1. Uploads the values `0..ELEMENT_COUNT` into a storage buffer.
/// 2. Dispatches the `compute` entry point of `compute.wgsl`.
/// 3. Copies the buffer into its staging buffer and reads it back.
/// 4. Prints the values and asserts that every element equals twice its index.
///
/// # Panics
///
/// Panics if any read-back element differs from `2 * index`.
fn main() {
    // Number of `u32` elements processed. The buffer size, the upload, the
    // read-back and the check are all derived from this single value so they
    // cannot drift apart.
    const ELEMENT_COUNT: u32 = 128;
    // NOTE(review): WORKGROUP_COUNT * (the @workgroup_size declared in
    // compute.wgsl, not visible here) must be >= ELEMENT_COUNT, or some
    // elements are never processed — confirm against the shader.
    const WORKGROUP_COUNT: u32 = 8;

    let compute = Compute::new(
        wgpu::Features::empty(),
        wgpu::Limits::default()
    ).block_on();
    let storage = Default::default();
    let mut program = ComputeExample { compute, storage };

    // Input data: element `i` holds the value `i`.
    let data: Vec<u32> = (0..ELEMENT_COUNT).collect();

    program.add_module("compute", wgpu::include_wgsl!("compute.wgsl"));
    program.add_buffer(
        "example_buffer",
        BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
        // Size in bytes, derived from the data being uploaded. `as _` converts
        // to whatever integer type `add_buffer` expects for its size.
        std::mem::size_of_val(data.as_slice()) as _
    );
    program.add_staging_buffer("example_buffer");
    program.add_bind_group("example_bind_group", &[
        BindGroupItem::StorageBuffer { label: "example_buffer", min_binding_size: 4, read_only: false }
    ]);
    {
        let bind_groups = &["example_bind_group"];
        let push_constant_ranges = &[];
        program.add_compute_pipelines("compute", bind_groups, &[ComputeKernel { label: "compute", entry_point: "compute" }], push_constant_ranges, None);
    }

    program.compute.queue.write_buffer(
        &program.storage().buffers["example_buffer"],
        0,
        bytemuck::cast_slice(&data)
    );

    let mut encoder = program.compute.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: None
    });
    {
        // Scoped so the pass (which borrows `encoder`) ends before the copy below.
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None
        });
        cpass.set_pipeline(&program.storage().compute_pipelines["compute"]);
        cpass.set_bind_group(0, &program.storage().bind_groups["example_bind_group"], &[]);
        cpass.dispatch_workgroups(WORKGROUP_COUNT, 1, 1);
    }
    program.copy_buffer_to_staging(&mut encoder, "example_buffer");
    program.compute.queue.submit(Some(encoder.finish()));

    // Presumably requests the staging-buffer mapping. `Maintain::Wait` then
    // blocks until the submitted GPU work has finished, so the read below
    // sees the results.
    program.prepare_staging_buffer("example_buffer");
    program.compute.device.poll(wgpu::Maintain::Wait);

    let mut output = vec![0u32; ELEMENT_COUNT as usize];
    program.read_staging_buffer(
        "example_buffer",
        &mut output[..]
    );

    // The shader is expected to double every element in place.
    for (index, &value) in output.iter().enumerate() {
        print!("{} ", value);
        assert_eq!(value, (index as u32) * 2, "unexpected value at index {}", index);
    }
    // Terminate the space-separated line of values.
    println!();
}