use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;
use wgpu::{
BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor,
BindGroupLayoutEntry, BindingType, Buffer, BufferBindingType, ComputePipeline,
ComputePipelineDescriptor, Device, PipelineLayoutDescriptor, Queue, ShaderModule,
ShaderModuleDescriptor, ShaderSource, ShaderStages,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// Divisor used by [`workgroup_count`] to turn an element count into a number
/// of workgroups to dispatch.
///
/// NOTE(review): presumably mirrors the `@workgroup_size` declared in the WGSL
/// shaders — confirm against the shader sources.
pub const WORKGROUP_SIZE: u32 = 256;
/// Caches shader modules, compute pipelines and bind group layouts created on
/// a single wgpu [`Device`].
///
/// Every map sits behind its own `parking_lot::Mutex`, so the lookup methods
/// take `&self`. "Static" maps are keyed by `&'static str`; the "dynamic"
/// variants are keyed by owned `String`s for names only known at runtime.
pub struct PipelineCache {
/// Device on which every cached object is created.
device: Arc<Device>,
/// Stored by `new`; not read by the methods defined alongside this struct,
/// which is why the dead-code lint is silenced.
#[allow(dead_code)]
queue: Arc<Queue>,
/// Shader modules keyed by a static shader name.
modules: Mutex<HashMap<&'static str, Arc<ShaderModule>>>,
/// Shader modules keyed by a runtime (owned) shader name.
dynamic_modules: Mutex<HashMap<String, Arc<ShaderModule>>>,
/// Compute pipelines keyed by `(shader name, entry point)`, both static.
pipelines: Mutex<HashMap<(&'static str, &'static str), Arc<ComputePipeline>>>,
/// Compute pipelines keyed by `(shader name, entry point)`, both owned.
dynamic_pipelines: Mutex<HashMap<(String, String), Arc<ComputePipeline>>>,
/// Bind group layouts keyed by their buffer-count description.
layouts: Mutex<HashMap<LayoutKey, Arc<BindGroupLayout>>>,
}
/// Cache key describing a bind group layout made only of buffer bindings.
///
/// As built by `PipelineCache::get_or_create_layout`: storage buffers occupy
/// bindings `0..num_storage_buffers`, and uniform buffers follow immediately
/// after, starting at binding `num_storage_buffers`.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct LayoutKey {
/// Total number of storage-buffer bindings (read-only ones included).
pub num_storage_buffers: u32,
/// Number of uniform-buffer bindings, placed after all storage bindings.
pub num_uniform_buffers: u32,
/// How many of the leading storage bindings are read-only. This is a count
/// within `num_storage_buffers`, not in addition to it; a value larger than
/// `num_storage_buffers` simply makes every storage binding read-only.
pub num_readonly_storage: u32,
}
impl PipelineCache {
pub fn new(device: Arc<Device>, queue: Arc<Queue>) -> Self {
Self {
device,
queue,
modules: Mutex::new(HashMap::new()),
dynamic_modules: Mutex::new(HashMap::new()),
pipelines: Mutex::new(HashMap::new()),
dynamic_pipelines: Mutex::new(HashMap::new()),
layouts: Mutex::new(HashMap::new()),
}
}
pub fn get_or_create_module(&self, name: &'static str, source: &str) -> Arc<ShaderModule> {
let mut modules = self.modules.lock();
if let Some(module) = modules.get(name) {
return module.clone();
}
let module = self.device.create_shader_module(ShaderModuleDescriptor {
label: Some(name),
source: ShaderSource::Wgsl(source.into()),
});
let module = Arc::new(module);
modules.insert(name, module.clone());
module
}
pub fn get_or_create_module_from_source(&self, name: &str, source: &str) -> Arc<ShaderModule> {
let mut modules = self.dynamic_modules.lock();
if let Some(module) = modules.get(name) {
return module.clone();
}
let module = self.device.create_shader_module(ShaderModuleDescriptor {
label: Some(name),
source: ShaderSource::Wgsl(source.into()),
});
let module = Arc::new(module);
modules.insert(name.to_string(), module.clone());
module
}
pub fn get_or_create_pipeline(
&self,
shader_name: &'static str,
entry_point: &'static str,
module: &ShaderModule,
layout: &BindGroupLayout,
) -> Arc<ComputePipeline> {
let key = (shader_name, entry_point);
let mut pipelines = self.pipelines.lock();
if let Some(pipeline) = pipelines.get(&key) {
return pipeline.clone();
}
let pipeline_layout = self
.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some(&format!("{}_layout", shader_name)),
bind_group_layouts: &[layout],
immediate_size: 0, });
let pipeline = self
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some(&format!("{}_{}", shader_name, entry_point)),
layout: Some(&pipeline_layout),
module,
entry_point: Some(entry_point),
compilation_options: Default::default(),
cache: None,
});
let pipeline = Arc::new(pipeline);
pipelines.insert(key, pipeline.clone());
pipeline
}
pub fn get_or_create_dynamic_pipeline(
&self,
shader_name: &str,
entry_point: &str,
module: &ShaderModule,
layout: &BindGroupLayout,
) -> Arc<ComputePipeline> {
let key = (shader_name.to_string(), entry_point.to_string());
let mut pipelines = self.dynamic_pipelines.lock();
if let Some(pipeline) = pipelines.get(&key) {
return pipeline.clone();
}
let pipeline_layout = self
.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some(&format!("{}_layout", shader_name)),
bind_group_layouts: &[layout],
immediate_size: 0,
});
let pipeline = self
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some(&format!("{}_{}", shader_name, entry_point)),
layout: Some(&pipeline_layout),
module,
entry_point: Some(entry_point),
compilation_options: Default::default(),
cache: None,
});
let pipeline = Arc::new(pipeline);
pipelines.insert(key, pipeline.clone());
pipeline
}
pub fn get_or_create_layout(&self, key: LayoutKey) -> Arc<BindGroupLayout> {
let mut layouts = self.layouts.lock();
if let Some(layout) = layouts.get(&key) {
return layout.clone();
}
let mut entries = Vec::new();
for i in 0..key.num_storage_buffers {
let read_only = i < key.num_readonly_storage;
entries.push(BindGroupLayoutEntry {
binding: i,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
});
}
for i in 0..key.num_uniform_buffers {
entries.push(BindGroupLayoutEntry {
binding: key.num_storage_buffers + i,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
});
}
let layout = self
.device
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("linalg_layout"),
entries: &entries,
});
let layout = Arc::new(layout);
layouts.insert(key, layout.clone());
layout
}
pub fn create_bind_group(&self, layout: &BindGroupLayout, buffers: &[&Buffer]) -> BindGroup {
let entries: Vec<BindGroupEntry> = buffers
.iter()
.enumerate()
.map(|(i, buffer)| BindGroupEntry {
binding: i as u32,
resource: buffer.as_entire_binding(),
})
.collect();
self.device.create_bind_group(&BindGroupDescriptor {
label: Some("compute_bind_group"),
layout,
entries: &entries,
})
}
pub fn device(&self) -> &Device {
&self.device
}
}
#[inline]
pub fn workgroup_count(n: usize) -> u32 {
((n as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE
}
#[allow(dead_code)]
pub fn dtype_suffix(dtype: DType) -> Result<&'static str> {
match dtype {
DType::F32 => Ok("f32"),
DType::F64 => Err(Error::UnsupportedDType {
dtype,
op: "WGSL (f64 not supported in WebGPU)",
}),
_ => Err(Error::UnsupportedDType {
dtype,
op: "linalg",
}),
}
}
/// Builds the string `"{op}_{suffix}"`, where `suffix` comes from
/// [`dtype_suffix`] for `dtype`.
///
/// # Errors
///
/// Propagates the error from [`dtype_suffix`] when `dtype` has no suffix.
#[allow(dead_code)]
pub fn entry_point(op: &str, dtype: DType) -> Result<String> {
    dtype_suffix(dtype).map(|suffix| format!("{}_{}", op, suffix))
}