use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
/// Upper bound applied to every per-dimension workgroup count returned by the
/// `dispatch_1d`, `dispatch_2d` and `dispatch_3d` helpers in this file.
// NOTE(review): 65535 presumably mirrors the default WebGPU
// `maxComputeWorkgroupsPerDimension` limit — confirm against the target device limits.
pub const MAX_WORKGROUPS: u32 = 65535;
/// Builds a compute pipeline from WGSL source text.
///
/// A shader module is created from `wgsl_source`, then a compute pipeline is
/// created from that module using `entry_point` as the entry function. No
/// explicit pipeline layout is supplied (`layout: None`) and no pipeline cache
/// is used.
///
/// This function has no error return; use [`checked_compute_pipeline`] when
/// shader compilation messages need to be surfaced as a `Result`.
///
/// When the `tracing` feature is enabled the call is instrumented at debug
/// level, with `device` and `wgsl_source` skipped from the recorded fields.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "debug", skip(device, wgsl_source))
)]
pub fn compute_pipeline(
    device: &wgpu::Device,
    wgsl_source: &str,
    entry_point: &str,
) -> wgpu::ComputePipeline {
    let module_desc = wgpu::ShaderModuleDescriptor {
        label: Some("oxiui-compute-wgpu shader"),
        source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
    };
    let module = device.create_shader_module(module_desc);

    let pipeline_desc = wgpu::ComputePipelineDescriptor {
        label: Some("oxiui-compute-wgpu pipeline"),
        layout: None,
        module: &module,
        entry_point: Some(entry_point),
        compilation_options: wgpu::PipelineCompilationOptions::default(),
        cache: None,
    };
    device.create_compute_pipeline(&pipeline_desc)
}
/// Returns the workgroup count for a 1-D dispatch over `n` items.
///
/// The result is `ceil(n / workgroup_size)`, capped at [`MAX_WORKGROUPS`].
/// When the cap applies, the returned count no longer covers all `n` items.
///
/// # Panics
///
/// Panics if `workgroup_size` is zero (division by zero inside `div_ceil`).
pub fn dispatch_1d(n: u32, workgroup_size: u32) -> u32 {
    let needed = n.div_ceil(workgroup_size);
    u32::min(needed, MAX_WORKGROUPS)
}
/// Returns the `(x, y)` workgroup counts for a 2-D dispatch.
///
/// Each axis is computed independently as `ceil(extent / workgroup_extent)`
/// and capped at [`MAX_WORKGROUPS`].
///
/// # Panics
///
/// Panics if `wg_x` or `wg_y` is zero (division by zero inside `div_ceil`).
pub fn dispatch_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> (u32, u32) {
    let groups_x = width.div_ceil(wg_x).min(MAX_WORKGROUPS);
    let groups_y = height.div_ceil(wg_y).min(MAX_WORKGROUPS);
    (groups_x, groups_y)
}
pub fn dispatch_3d(x: u32, y: u32, z: u32, wg_x: u32, wg_y: u32, wg_z: u32) -> (u32, u32, u32) {
(
x.div_ceil(wg_x).min(MAX_WORKGROUPS),
y.div_ceil(wg_y).min(MAX_WORKGROUPS),
z.div_ceil(wg_z).min(MAX_WORKGROUPS),
)
}
/// Hashes a WGSL source string together with an entry-point name.
///
/// The source is fed to a fresh `DefaultHasher` first, then the entry point,
/// and the finished 64-bit value is returned. `PipelineCache` uses this value
/// as its map key.
fn wgsl_hash(source: &str, entry_point: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    // Order matters: source first, entry point second.
    for part in [source, entry_point] {
        part.hash(&mut hasher);
    }
    hasher.finish()
}
/// Cache of compiled compute pipelines, keyed by a hash of the WGSL source
/// and entry-point name (see `wgsl_hash`).
// NOTE(review): only the 64-bit hash is stored as the key, not the source text,
// so two distinct shaders whose hashes collide would share one cache slot.
pub struct PipelineCache {
/// Compiled pipelines, indexed by the `wgsl_hash` of (source, entry point).
cache: HashMap<u64, Arc<wgpu::ComputePipeline>>,
/// Number of compilations performed by `get_or_compile` (cache misses).
compile_count: usize,
}
impl PipelineCache {
    /// Creates an empty cache whose compile counter starts at zero.
    pub fn new() -> Self {
        Self {
            cache: HashMap::new(),
            compile_count: 0,
        }
    }

    /// Returns the pipeline for `(wgsl_source, entry_point)`, compiling it on
    /// first use.
    ///
    /// The lookup key is `wgsl_hash(wgsl_source, entry_point)`. On a hit the
    /// stored `Arc` is cloned and returned. On a miss the shader is compiled
    /// with [`compute_pipeline`], stored, the compile counter is incremented,
    /// and the new `Arc` is returned.
    pub fn get_or_compile(
        &mut self,
        device: &wgpu::Device,
        wgsl_source: &str,
        entry_point: &str,
    ) -> Arc<wgpu::ComputePipeline> {
        use std::collections::hash_map::Entry;

        match self.cache.entry(wgsl_hash(wgsl_source, entry_point)) {
            Entry::Occupied(hit) => Arc::clone(hit.get()),
            Entry::Vacant(slot) => {
                let compiled = Arc::new(compute_pipeline(device, wgsl_source, entry_point));
                slot.insert(Arc::clone(&compiled));
                self.compile_count += 1;
                compiled
            }
        }
    }

    /// Number of pipelines compiled so far, i.e. the number of cache misses.
    pub fn compile_count(&self) -> usize {
        self.compile_count
    }

    /// Number of pipelines currently stored in the cache.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` when no pipeline is stored in the cache.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }
}
impl Default for PipelineCache {
fn default() -> Self {
Self::new()
}
}
/// Value returned by `DispatchBuilder::submit`.
pub struct DispatchResult {
/// Submission index returned by `wgpu::Queue::submit` for the dispatch's
/// command buffer; it can be passed to `wgpu::Device::poll` to wait on it.
pub submission_index: wgpu::SubmissionIndex,
}
/// Builder that records and submits a single compute pass for one pipeline.
///
/// All references are borrowed for `'a`; nothing is owned except the list of
/// bind-group slots.
pub struct DispatchBuilder<'a> {
/// Pipeline bound for the compute pass.
pipeline: &'a wgpu::ComputePipeline,
/// `(index, bind group)` pairs, applied in insertion order at submit time.
bind_groups: Vec<(u32, &'a wgpu::BindGroup)>,
/// Workgroup counts `(x, y, z)` passed to `dispatch_workgroups`.
dispatch: (u32, u32, u32),
/// Optional label used for both the command encoder and the compute pass.
label: Option<&'a str>,
}
impl<'a> DispatchBuilder<'a> {
    /// Starts a builder for `pipeline` with no bind groups, a `(1, 1, 1)`
    /// dispatch size and no label.
    pub fn new(pipeline: &'a wgpu::ComputePipeline) -> Self {
        Self {
            pipeline,
            bind_groups: Vec::new(),
            dispatch: (1, 1, 1),
            label: None,
        }
    }

    /// Appends `bind_group` to be set at slot `index` when the pass is recorded.
    ///
    /// Entries are applied in the order they were added; adding the same index
    /// twice records both, so the later one is set last.
    pub fn bind(mut self, index: u32, bind_group: &'a wgpu::BindGroup) -> Self {
        let entry = (index, bind_group);
        self.bind_groups.push(entry);
        self
    }

    /// Sets a 1-D dispatch covering `n` items with the given workgroup size.
    ///
    /// The x count comes from `crate::pipeline::dispatch_1d`; y and z are 1.
    pub fn dispatch_1d(self, n: u32, workgroup_size: u32) -> Self {
        let groups_x = crate::pipeline::dispatch_1d(n, workgroup_size);
        Self {
            dispatch: (groups_x, 1, 1),
            ..self
        }
    }

    /// Sets the workgroup counts directly; the values are stored as given.
    pub fn dispatch_xyz(self, x: u32, y: u32, z: u32) -> Self {
        Self {
            dispatch: (x, y, z),
            ..self
        }
    }

    /// Sets the label applied to both the command encoder and the compute pass.
    pub fn label(self, l: &'a str) -> Self {
        Self {
            label: Some(l),
            ..self
        }
    }

    /// Records the compute pass and submits it on `queue`.
    ///
    /// A new command encoder is created on `device`, one compute pass is
    /// recorded (pipeline, bind groups in insertion order with no dynamic
    /// offsets, then the dispatch), and the finished command buffer is
    /// submitted. The returned [`DispatchResult`] carries the submission index.
    pub fn submit(self, device: &wgpu::Device, queue: &wgpu::Queue) -> DispatchResult {
        let Self {
            pipeline,
            bind_groups,
            dispatch: (groups_x, groups_y, groups_z),
            label,
        } = self;

        let encoder_desc = wgpu::CommandEncoderDescriptor { label };
        let mut encoder = device.create_command_encoder(&encoder_desc);

        let pass_desc = wgpu::ComputePassDescriptor {
            label,
            timestamp_writes: None,
        };
        let mut pass = encoder.begin_compute_pass(&pass_desc);
        pass.set_pipeline(pipeline);
        for &(slot, group) in &bind_groups {
            pass.set_bind_group(slot, group, &[]);
        }
        pass.dispatch_workgroups(groups_x, groups_y, groups_z);
        // The pass must be dropped before the encoder can be finished.
        drop(pass);

        let commands = encoder.finish();
        DispatchResult {
            submission_index: queue.submit(std::iter::once(commands)),
        }
    }
}
/// Checks that an immediate-data write is correctly aligned.
///
/// Both `offset` and `data_len` must be multiples of
/// `wgpu::IMMEDIATE_DATA_ALIGNMENT`.
///
/// # Errors
///
/// Returns a message naming the misaligned quantity (the offset is checked
/// first, then the data length) together with the required alignment.
pub fn validate_immediates(offset: u32, data_len: usize) -> Result<(), String> {
    let align = wgpu::IMMEDIATE_DATA_ALIGNMENT;
    if !offset.is_multiple_of(align) {
        return Err(format!(
            "immediates offset {offset} must be aligned to {align}"
        ));
    }
    // Compare in u64 so the length is widened rather than narrowed: casting
    // `data_len` down to u32 would silently discard the high bits of lengths
    // above `u32::MAX` before the alignment test.
    let len_aligned = (data_len as u64).is_multiple_of(u64::from(align));
    if !len_aligned {
        return Err(format!(
            "immediates data length {data_len} must be aligned to {align}"
        ));
    }
    Ok(())
}
/// Returns `true` when `features` includes `wgpu::Features::IMMEDIATES`.
pub fn supports_immediates(features: wgpu::Features) -> bool {
    let required = wgpu::Features::IMMEDIATES;
    features.contains(required)
}
/// Records an indirect dispatch on `pass`.
///
/// `indirect_buffer` and `offset` are forwarded unchanged to
/// `wgpu::ComputePass::dispatch_workgroups_indirect`; this function adds no
/// validation of its own.
pub fn encode_indirect_dispatch<'enc>(
    pass: &mut wgpu::ComputePass<'enc>,
    indirect_buffer: &'enc wgpu::Buffer,
    offset: u64,
) {
    wgpu::ComputePass::dispatch_workgroups_indirect(pass, indirect_buffer, offset);
}
/// Builds a compute pipeline, reporting shader compilation errors as a `Result`.
///
/// The shader module is created from `wgsl_source`, then its compilation info
/// is fetched by blocking the current thread (`pollster::block_on`). Every
/// message of type `Error` is rendered as `line:column: message` when a
/// location is present, or as the bare message otherwise. If any were found
/// they are joined with `"; "` and returned as an error; otherwise the
/// pipeline is created with `entry_point` and returned.
///
/// # Errors
///
/// Returns `crate::ComputeError::ShaderCompilation` when the compilation info
/// contains at least one error message.
// NOTE(review): only compilation-info messages are inspected here. Whether
// wgpu also reports an invalid shader or a missing entry point through the
// device's error handler is not visible from this file — confirm.
pub fn checked_compute_pipeline(
    device: &wgpu::Device,
    wgsl_source: &str,
    entry_point: &str,
) -> Result<wgpu::ComputePipeline, crate::ComputeError> {
    use wgpu::CompilationMessageType;

    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: None,
        source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
    });

    let report = pollster::block_on(module.get_compilation_info());
    let mut failures: Vec<String> = Vec::new();
    for msg in report.messages.iter() {
        if msg.message_type == CompilationMessageType::Error {
            let rendered = match &msg.location {
                Some(loc) => format!("{}:{}: {}", loc.line_number, loc.line_position, msg.message),
                None => msg.message.clone(),
            };
            failures.push(rendered);
        }
    }
    if !failures.is_empty() {
        return Err(crate::ComputeError::ShaderCompilation(failures.join("; ")));
    }

    let pipeline_desc = wgpu::ComputePipelineDescriptor {
        label: None,
        layout: None,
        module: &module,
        entry_point: Some(entry_point),
        compilation_options: wgpu::PipelineCompilationOptions::default(),
        cache: None,
    };
    Ok(device.create_compute_pipeline(&pipeline_desc))
}
// Unit tests for the pipeline helpers. The pure arithmetic/validation tests run
// anywhere; tests that need a device obtain one through `ComputeContext::try_new`.
#[cfg(test)]
mod tests {
use super::*;
use crate::context::ComputeContext;
// Minimal shader: entry point `main`, workgroup size 1, writes each element of
// a read-write storage buffer back to itself.
const PASSTHROUGH_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> buf: array<f32>;
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
buf[gid.x] = buf[gid.x];
}
"#;
/// `compute_pipeline` builds a pipeline from the passthrough shader.
#[test]
fn compute_pipeline_compiles() {
// Returns early (test passes) when no context can be created.
let Some(ctx) = ComputeContext::try_new() else {
return; };
let _pipeline = compute_pipeline(&ctx.device, PASSTHROUGH_WGSL, "main");
}
/// `compute_pipeline` accepts a non-`main` entry point and workgroup size 64.
#[test]
fn compute_pipeline_double_shader() {
// Returns early (test passes) when no context can be created.
let Some(ctx) = ComputeContext::try_new() else {
return; };
const DOUBLE_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> data: array<f32>;
@compute @workgroup_size(64)
fn double_all(@builtin(global_invocation_id) gid: vec3<u32>) {
data[gid.x] *= 2.0;
}
"#;
let _pipeline = compute_pipeline(&ctx.device, DOUBLE_WGSL, "double_all");
}
/// 100 items at 64 per group rounds up to 2 groups.
#[test]
fn dispatch_1d_100_wg64() {
assert_eq!(dispatch_1d(100, 64), 2);
}
/// An exact multiple does not add an extra group.
#[test]
fn dispatch_1d_exact_multiple() {
assert_eq!(dispatch_1d(128, 64), 2);
}
/// A single item still needs one group.
#[test]
fn dispatch_1d_one_element() {
assert_eq!(dispatch_1d(1, 64), 1);
}
/// Very large counts are capped at `MAX_WORKGROUPS`.
#[test]
fn dispatch_1d_clamp_at_max() {
assert_eq!(dispatch_1d(u32::MAX, 1), MAX_WORKGROUPS);
}
/// Each 2-D axis is rounded up independently: ceil(100/16)=7, ceil(200/16)=13.
#[test]
fn dispatch_2d_smoke() {
assert_eq!(dispatch_2d(100, 200, 16, 16), (7, 13));
}
/// Each 3-D axis is rounded up independently: ceil(10/4)=3.
#[test]
fn dispatch_3d_smoke() {
assert_eq!(dispatch_3d(10, 10, 10, 4, 4, 4), (3, 3, 3));
}
/// A freshly constructed cache holds no pipelines.
#[test]
fn pipeline_cache_default_empty() {
let cache = PipelineCache::new();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
/// Aligned offset and length pairs are accepted.
#[test]
fn validate_immediates_aligned_ok() {
assert!(validate_immediates(0, 16).is_ok());
assert!(validate_immediates(4, 8).is_ok());
}
/// A misaligned offset is rejected even when the length is aligned.
#[test]
fn validate_immediates_offset_unaligned() {
assert!(validate_immediates(1, 4).is_err());
assert!(validate_immediates(3, 4).is_err());
}
/// A misaligned data length is rejected even when the offset is aligned.
#[test]
fn validate_immediates_data_unaligned() {
assert!(validate_immediates(0, 3).is_err());
assert!(validate_immediates(0, 5).is_err());
}
/// The cache key changes with either the source or the entry point and is
/// stable for identical inputs.
#[test]
fn wgsl_hash_distinguishes_sources() {
assert_ne!(wgsl_hash("a", "main"), wgsl_hash("b", "main"));
assert_ne!(wgsl_hash("src", "entry_a"), wgsl_hash("src", "entry_b"));
assert_eq!(wgsl_hash("same", "ep"), wgsl_hash("same", "ep"));
}
/// Requesting the same shader twice compiles once and returns the same `Arc`.
#[test]
fn pipeline_cache_hit_reuses_arc() {
// NOTE(review): `require_gpu!` is defined in `oxiui_core`; presumably it binds
// `ctx` or skips the test when no GPU context is available — confirm.
oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
let mut cache = PipelineCache::new();
let p1 = cache.get_or_compile(&ctx.device, PASSTHROUGH_WGSL, "main");
let p2 = cache.get_or_compile(&ctx.device, PASSTHROUGH_WGSL, "main");
assert!(Arc::ptr_eq(&p1, &p2), "cache hit must return the same Arc");
assert_eq!(cache.compile_count(), 1, "compiled only once");
assert_eq!(cache.len(), 1);
}
/// Two different shaders produce two compilations and two cache entries.
#[test]
fn pipeline_cache_miss_compiles_twice() {
oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
const WGSL_B: &str = r#"
@group(0) @binding(0) var<storage, read_write> buf: array<u32>;
@compute @workgroup_size(1)
fn alt(@builtin(global_invocation_id) gid: vec3<u32>) {
buf[gid.x] = gid.x;
}
"#;
let mut cache = PipelineCache::new();
let _p1 = cache.get_or_compile(&ctx.device, PASSTHROUGH_WGSL, "main");
let _p2 = cache.get_or_compile(&ctx.device, WGSL_B, "alt");
assert_eq!(
cache.compile_count(),
2,
"two different shaders → two compilations"
);
assert_eq!(cache.len(), 2);
}
/// End-to-end check of `DispatchBuilder`: uploads 0..128 as f32, runs a
/// doubling shader, reads the buffer back and compares against `input * 2`.
#[test]
fn dispatch_builder_doubling() {
use bytemuck::{Pod, Zeroable};
oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
const DOUBLE_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> data: array<f32>;
@compute @workgroup_size(64)
fn double_all(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if i < arrayLength(&data) {
data[i] = data[i] * 2.0;
}
}
"#;
const N: usize = 128;
let input: Vec<f32> = (0..N as u32).map(|i| i as f32).collect();
let buf_size = (N * std::mem::size_of::<f32>()) as u64;
// Storage buffer is created mapped so the input can be written directly.
let storage_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("storage"),
size: buf_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: true,
});
{
// Scoped so the mapped view is dropped before `unmap`.
let mut view = storage_buf.slice(..).get_mapped_range_mut();
view.copy_from_slice(bytemuck::cast_slice(&input));
}
storage_buf.unmap();
// Explicit layout: one read-write storage buffer at binding 0, compute stage only.
let layout = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}],
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: storage_buf.as_entire_binding(),
}],
});
let pipeline_layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[Some(&layout)],
immediate_size: 0,
});
let shader = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(DOUBLE_WGSL.into()),
});
// The pipeline is built by hand (not via `compute_pipeline`) so that it uses
// the explicit `pipeline_layout` created above.
let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("double_all"),
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
});
// Code under test: bind, size the dispatch, label and submit.
let result = DispatchBuilder::new(&pipeline)
.bind(0, &bind_group)
.dispatch_1d(N as u32, 64)
.label("doubling-pass")
.submit(&ctx.device, &ctx.queue);
// Staging buffer used to read the results back on the CPU.
let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("staging"),
size: buf_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("copy-back"),
});
encoder.copy_buffer_to_buffer(&storage_buf, 0, &staging, 0, buf_size);
ctx.queue.submit(std::iter::once(encoder.finish()));
// Waits on the dispatch submission returned by the builder; the poll result is ignored.
let _ = ctx.device.poll(wgpu::PollType::Wait {
submission_index: Some(result.submission_index),
timeout: None,
});
// Map the staging buffer and forward the mapping result through a channel.
let (tx, rx) = std::sync::mpsc::channel();
staging
.slice(..)
.map_async(wgpu::MapMode::Read, move |r| tx.send(r).unwrap());
// Second poll after `map_async`, before the mapping result is received below.
let _ = ctx.device.poll(wgpu::PollType::wait_indefinitely());
rx.recv().unwrap().unwrap();
let output: Vec<f32> = {
let view = staging.slice(..).get_mapped_range();
bytemuck::cast_slice::<u8, f32>(&view).to_vec()
};
// Every output element must equal twice its input, within 1e-6.
for (i, (&got, expected)) in output.iter().zip(input.iter().map(|v| v * 2.0)).enumerate() {
assert!(
(got - expected).abs() < 1e-6,
"index {i}: got {got}, expected {expected}"
);
}
// NOTE(review): this block only declares an unused generic fn bounded by
// `Pod + Zeroable`; it appears to exist solely to reference the two imported
// traits — confirm it is still needed.
let _: () = {
#[allow(dead_code)]
fn _use_pod<T: Pod + Zeroable>() {}
};
}
/// `checked_compute_pipeline` returns `Ok` for the valid passthrough shader.
#[test]
fn checked_pipeline_valid_shader_ok() {
oxiui_core::require_gpu!(ctx, ComputeContext::try_new());
let result = checked_compute_pipeline(&ctx.device, PASSTHROUGH_WGSL, "main");
assert!(result.is_ok(), "valid shader must compile without error");
}
}