use super::{cache_key, PipelineCache as ComputePipelineCache};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{LazyLock, Mutex, RwLock};
use vyre::error::{Error, Result};
const PIPELINE_CACHE_SHARDS: usize = 32;
#[inline]
/// Compiles (or fetches from the shared cache) a compute pipeline for
/// `wgsl_source` / `entry_point` without an explicit pipeline layout.
///
/// This is [`compile_compute_pipeline_with_layout`] with `layout = None`, so
/// wgpu derives the pipeline layout from the shader itself.
///
/// # Errors
///
/// Propagates every error returned by [`compile_compute_pipeline_with_layout`].
#[inline]
pub fn compile_compute_pipeline(
    device: &wgpu::Device,
    label: &str,
    wgsl_source: &str,
    entry_point: &str,
) -> Result<wgpu::ComputePipeline> {
    let derived_layout: Option<&wgpu::PipelineLayout> = None;
    compile_compute_pipeline_with_layout(device, label, wgsl_source, entry_point, derived_layout)
}
#[inline]
pub fn compile_compute_pipeline_with_layout(
device: &wgpu::Device,
label: &str,
wgsl_source: &str,
entry_point: &str,
layout: Option<&wgpu::PipelineLayout>,
) -> Result<wgpu::ComputePipeline> {
static PIPELINES: LazyLock<[RwLock<ComputePipelineCache>; PIPELINE_CACHE_SHARDS]> =
LazyLock::new(|| std::array::from_fn(|_| RwLock::new(ComputePipelineCache::new())));
let cache_key = cache_key(wgsl_source, entry_point);
let shard_index = pipeline_cache_shard(&cache_key);
{
let mut pipelines = PIPELINES[shard_index].write().map_err(|source| Error::Gpu {
message: format!(
"shader pipeline cache shard {shard_index} poisoned while reading '{label}' entry '{entry_point}': {source}. Fix: restart the process and inspect panics from other shader-compilation threads."
),
})?;
if let Some(pipeline) = pipelines.get(&cache_key) {
return Ok(pipeline);
}
}
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some(label),
source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
});
let driver_cache = driver_pipeline_cache(device, label)?;
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(label),
layout,
module: &module,
entry_point: Some(entry_point),
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: Some(&driver_cache),
});
let mut pipelines = PIPELINES[shard_index].write().map_err(|source| Error::Gpu {
message: format!(
"shader pipeline cache shard {shard_index} poisoned while storing '{label}' entry '{entry_point}': {source}. Fix: restart the process and inspect panics from other shader-compilation threads."
),
})?;
pipelines.insert(cache_key, pipeline.clone());
Ok(pipeline)
}
/// Maps a pipeline cache key to the index of the shard that owns it.
///
/// Uses `DefaultHasher::new()`, which is seeded identically on every call.
/// The same key therefore always lands in the same shard within a process.
/// The result is always `< PIPELINE_CACHE_SHARDS`.
fn pipeline_cache_shard(key: &str) -> usize {
    let digest = {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        key.hash(&mut state);
        state.finish()
    };
    (digest as usize) % PIPELINE_CACHE_SHARDS
}
fn driver_pipeline_cache(device: &wgpu::Device, label: &str) -> Result<wgpu::PipelineCache> {
static DRIVER_CACHES: LazyLock<Mutex<HashMap<wgpu::Device, wgpu::PipelineCache>>> =
LazyLock::new(|| Mutex::new(HashMap::new()));
let mut caches = DRIVER_CACHES.lock().map_err(|source| Error::Gpu {
message: format!(
"wgpu pipeline cache mutex poisoned while compiling '{label}': {source}. Fix: restart the process and inspect panics from other shader-compilation threads."
),
})?;
if let Some(cache) = caches.get(device) {
return Ok(cache.clone());
}
let cache = {
#[allow(unsafe_code)]
unsafe {
device.create_pipeline_cache(&wgpu::PipelineCacheDescriptor {
label: Some("vyre wgpu pipeline cache"),
data: None,
fallback: true,
})
}
};
caches.insert(device.clone(), cache.clone());
Ok(cache)
}