use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
/// Caches compiled [`wgpu::ShaderModule`]s keyed by their WGSL source text.
///
/// The full source string is the map key, so two different sources can never
/// alias to the same entry. Keying by a bare 64-bit hash would, on a collision,
/// silently hand back the wrong shader. A 64-bit hash of the source is still
/// computed, but only as a compact identifier in log output.
pub struct ShaderCache {
    /// Compiled modules, keyed by the exact WGSL source they were built from.
    modules: HashMap<String, wgpu::ShaderModule>,
}

impl ShaderCache {
    /// Creates an empty cache.
    #[must_use]
    pub fn new() -> Self {
        Self {
            modules: HashMap::new(),
        }
    }

    /// Returns the module compiled from `wgsl_source`, compiling and caching it
    /// on first use.
    ///
    /// `label` is only used when the module is actually compiled. A cache hit
    /// returns the existing module with whatever label it was created with.
    ///
    /// This method does not inspect compilation errors: whatever
    /// `device.create_shader_module` returns is cached as-is.
    pub fn get_or_compile(
        &mut self,
        device: &wgpu::Device,
        wgsl_source: &str,
        label: &str,
    ) -> &wgpu::ShaderModule {
        // Probe with the borrowed `&str` first so a cache hit does not allocate;
        // an owned key is created only when a new module is inserted.
        if !self.modules.contains_key(wgsl_source) {
            let hash = Self::hash_source(wgsl_source);
            tracing::debug!(label, hash, "compiling and caching shader module");
            let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some(label),
                source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
            });
            self.modules.insert(wgsl_source.to_owned(), module);
        }
        self.modules
            .get(wgsl_source)
            .expect("shader module is present: it was either cached or just inserted")
    }

    /// Returns `true` if a module compiled from exactly `wgsl_source` is cached.
    #[must_use]
    pub fn contains(&self, wgsl_source: &str) -> bool {
        self.modules.contains_key(wgsl_source)
    }

    /// Removes the module compiled from `wgsl_source`, if present.
    ///
    /// Returns `true` if an entry was removed, `false` if none was cached.
    pub fn invalidate(&mut self, wgsl_source: &str) -> bool {
        let removed = self.modules.remove(wgsl_source).is_some();
        if removed {
            let hash = Self::hash_source(wgsl_source);
            tracing::debug!(hash, "shader module invalidated");
        }
        removed
    }

    /// Drops every cached module.
    pub fn clear(&mut self) {
        tracing::debug!(count = self.modules.len(), "clearing shader cache");
        self.modules.clear();
    }

    /// Number of cached modules.
    #[must_use]
    #[inline]
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    /// Returns `true` if no modules are cached.
    #[must_use]
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }

    /// Hashes `source` with [`DefaultHasher`] into a 64-bit identifier.
    ///
    /// Used only to tag log lines. It is not a cache key, because distinct
    /// sources could collide.
    fn hash_source(source: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        source.hash(&mut hasher);
        hasher.finish()
    }
}
impl Default for ShaderCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid compute shader used by the GPU-backed tests.
    const TEST_SHADER: &str = "@compute @workgroup_size(64) fn main() {}";
    /// A second valid shader whose source differs from `TEST_SHADER`.
    const OTHER_SHADER: &str = "@compute @workgroup_size(32) fn main() {}";

    /// Acquires a device for GPU-backed tests, or `None` when no adapter is
    /// available. Callers return early on `None`, so the skip is logged here
    /// rather than letting the test pass silently.
    fn try_gpu() -> Option<wgpu::Device> {
        match pollster::block_on(crate::context::GpuContext::new()) {
            Ok(ctx) => Some(ctx.device),
            Err(_) => {
                eprintln!("skipping GPU test: no adapter/device available");
                None
            }
        }
    }

    #[test]
    fn cache_empty() {
        let cache = ShaderCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn cache_default() {
        let cache = ShaderCache::default();
        assert!(cache.is_empty());
    }

    #[test]
    fn hash_deterministic() {
        let h1 = ShaderCache::hash_source(TEST_SHADER);
        let h2 = ShaderCache::hash_source(TEST_SHADER);
        assert_eq!(h1, h2);
    }

    #[test]
    fn hash_different_sources() {
        let h1 = ShaderCache::hash_source("fn a() {}");
        let h2 = ShaderCache::hash_source("fn b() {}");
        assert_ne!(h1, h2);
    }

    #[test]
    fn contains_before_compile() {
        let cache = ShaderCache::new();
        assert!(!cache.contains("fn main() {}"));
    }

    #[test]
    fn invalidate_missing_is_false() {
        let mut cache = ShaderCache::new();
        assert!(!cache.invalidate(TEST_SHADER));
        assert!(cache.is_empty());
    }

    #[test]
    fn clear_empty_is_noop() {
        let mut cache = ShaderCache::new();
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn gpu_compile_and_cache() {
        let Some(device) = try_gpu() else { return };
        let mut cache = ShaderCache::new();
        let _module = cache.get_or_compile(&device, TEST_SHADER, "test");
        assert!(cache.contains(TEST_SHADER));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn gpu_deduplication() {
        let Some(device) = try_gpu() else { return };
        let mut cache = ShaderCache::new();
        let _m1 = cache.get_or_compile(&device, TEST_SHADER, "first");
        let _m2 = cache.get_or_compile(&device, TEST_SHADER, "second");
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn gpu_distinct_sources_cached_separately() {
        let Some(device) = try_gpu() else { return };
        let mut cache = ShaderCache::new();
        let _a = cache.get_or_compile(&device, TEST_SHADER, "a");
        let _b = cache.get_or_compile(&device, OTHER_SHADER, "b");
        assert_eq!(cache.len(), 2);
        assert!(cache.contains(TEST_SHADER));
        assert!(cache.contains(OTHER_SHADER));
    }

    #[test]
    fn gpu_invalidate() {
        let Some(device) = try_gpu() else { return };
        let mut cache = ShaderCache::new();
        let _m = cache.get_or_compile(&device, TEST_SHADER, "test");
        assert!(cache.invalidate(TEST_SHADER));
        assert!(!cache.contains(TEST_SHADER));
        assert!(!cache.invalidate(TEST_SHADER));
    }

    #[test]
    fn gpu_clear() {
        let Some(device) = try_gpu() else { return };
        let mut cache = ShaderCache::new();
        let _m = cache.get_or_compile(&device, TEST_SHADER, "test");
        cache.clear();
        assert!(cache.is_empty());
    }
}