use std::collections::{HashMap, HashSet};
use std::sync::Mutex;
/// Returns a 64-bit identifier for `b`, obtained by feeding the handle's `Hash`
/// impl into a freshly constructed `DefaultHasher`.
///
/// The value is only meaningful inside the current process.
///
/// NOTE(review): this is a hash, not a unique id. Two distinct buffers could in
/// principle produce the same value. What identity wgpu's `Hash` impl exposes
/// depends on the wgpu version in use; confirm before relying on it.
pub fn buf_id(b: &wgpu::Buffer) -> u64 {
    use std::hash::{Hash, Hasher};

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    b.hash(&mut hasher);
    hasher.finish()
}
/// Returns a 64-bit identifier for the compute pipeline `p`, computed the same
/// way as `buf_id`: the handle's `Hash` impl fed into a new `DefaultHasher`.
///
/// NOTE(review): like `buf_id`, this can collide in principle; it is an
/// in-process hash, not a guaranteed-unique id.
fn pipeline_id(p: &wgpu::ComputePipeline) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    p.hash(&mut state);
    state.finish()
}
/// Key identifying one cached compute dispatch: a pipeline plus the buffers bound to it.
///
/// Every field is a 64-bit hash of a wgpu handle (see `buf_id`). Slots `b0`..`b2`
/// are always populated. The `one`/`two` constructors store `0` in the unused
/// `b1`/`b2` slots. Slots `b3`..`b6` are `None` when no buffer is bound there.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Hash of the compute pipeline handle.
    pub pipeline: u64,
    /// Hash of the first bound buffer (always present).
    pub b0: u64,
    /// Hash of the second bound buffer, or `0` when only one buffer is bound.
    pub b1: u64,
    /// Hash of the third bound buffer, or `0` when fewer than three are bound.
    pub b2: u64,
    /// Hash of the fourth bound buffer, if any.
    pub b3: Option<u64>,
    /// Hash of the fifth bound buffer, if any.
    pub b4: Option<u64>,
    /// Hash of the sixth bound buffer, if any.
    pub b5: Option<u64>,
    /// Hash of the seventh bound buffer, if any.
    pub b6: Option<u64>,
}
impl CacheKey {
pub fn one(p: &wgpu::ComputePipeline, b0: &wgpu::Buffer) -> Self {
Self {
pipeline: pipeline_id(p),
b0: buf_id(b0),
b1: 0,
b2: 0,
b3: None,
b4: None,
b5: None,
b6: None,
}
}
pub fn two(p: &wgpu::ComputePipeline, b0: &wgpu::Buffer, b1: &wgpu::Buffer) -> Self {
Self {
pipeline: pipeline_id(p),
b0: buf_id(b0),
b1: buf_id(b1),
b2: 0,
b3: None,
b4: None,
b5: None,
b6: None,
}
}
pub fn three(
p: &wgpu::ComputePipeline,
b0: &wgpu::Buffer,
b1: &wgpu::Buffer,
b2: &wgpu::Buffer,
) -> Self {
Self {
pipeline: pipeline_id(p),
b0: buf_id(b0),
b1: buf_id(b1),
b2: buf_id(b2),
b3: None,
b4: None,
b5: None,
b6: None,
}
}
pub fn four(
p: &wgpu::ComputePipeline,
b0: &wgpu::Buffer,
b1: &wgpu::Buffer,
b2: &wgpu::Buffer,
b3: &wgpu::Buffer,
) -> Self {
Self {
pipeline: pipeline_id(p),
b0: buf_id(b0),
b1: buf_id(b1),
b2: buf_id(b2),
b3: Some(buf_id(b3)),
b4: None,
b5: None,
b6: None,
}
}
pub fn five(
p: &wgpu::ComputePipeline,
b0: &wgpu::Buffer,
b1: &wgpu::Buffer,
b2: &wgpu::Buffer,
b3: &wgpu::Buffer,
b4: &wgpu::Buffer,
) -> Self {
Self {
pipeline: pipeline_id(p),
b0: buf_id(b0),
b1: buf_id(b1),
b2: buf_id(b2),
b3: Some(buf_id(b3)),
b4: Some(buf_id(b4)),
b5: None,
b6: None,
}
}
pub fn six(
p: &wgpu::ComputePipeline,
b0: &wgpu::Buffer,
b1: &wgpu::Buffer,
b2: &wgpu::Buffer,
b3: &wgpu::Buffer,
b4: &wgpu::Buffer,
b5: &wgpu::Buffer,
) -> Self {
Self {
pipeline: pipeline_id(p),
b0: buf_id(b0),
b1: buf_id(b1),
b2: buf_id(b2),
b3: Some(buf_id(b3)),
b4: Some(buf_id(b4)),
b5: Some(buf_id(b5)),
b6: None,
}
}
#[allow(clippy::too_many_arguments)]
pub fn seven(
p: &wgpu::ComputePipeline,
b0: &wgpu::Buffer,
b1: &wgpu::Buffer,
b2: &wgpu::Buffer,
b3: &wgpu::Buffer,
b4: &wgpu::Buffer,
b5: &wgpu::Buffer,
b6: &wgpu::Buffer,
) -> Self {
Self {
pipeline: pipeline_id(p),
b0: buf_id(b0),
b1: buf_id(b1),
b2: buf_id(b2),
b3: Some(buf_id(b3)),
b4: Some(buf_id(b4)),
b5: Some(buf_id(b5)),
b6: Some(buf_id(b6)),
}
}
fn touches(&self, ids: &HashSet<u64>) -> bool {
ids.contains(&self.b0)
|| ids.contains(&self.b1)
|| ids.contains(&self.b2)
|| self.b3.is_some_and(|id| ids.contains(&id))
|| self.b4.is_some_and(|id| ids.contains(&id))
|| self.b5.is_some_and(|id| ids.contains(&id))
|| self.b6.is_some_and(|id| ids.contains(&id))
}
}
/// A cached per-dispatch resource pair: a uniform buffer and the bind group built for it.
///
/// `BindGroupCache::get_or_create` hands out clones of the stored value.
#[derive(Clone, Debug)]
pub struct CachedDispatch {
    /// Uniform buffer associated with this dispatch.
    pub uniform: wgpu::Buffer,
    /// Bind group created for this dispatch's pipeline/buffer combination.
    pub bind_group: wgpu::BindGroup,
}
/// Cache of `CachedDispatch` entries keyed by `CacheKey`.
///
/// The map is guarded by a `Mutex`, so every operation takes `&self`.
pub struct BindGroupCache {
    // All access goes through `lock()`; see the methods on `BindGroupCache`.
    inner: Mutex<HashMap<CacheKey, CachedDispatch>>,
}
impl BindGroupCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(HashMap::new()),
        }
    }

    /// Locks the map, recovering the guard if the mutex was poisoned.
    ///
    /// Recovery is sound because the map cannot be left half-modified.
    /// `HashMap::entry(..).or_insert_with(f)` evaluates `f` before inserting,
    /// so a panicking `build` closure leaves the map untouched. Without
    /// recovery, one such panic would make every later cache call panic.
    fn lock_map(&self) -> std::sync::MutexGuard<'_, HashMap<CacheKey, CachedDispatch>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns a clone of the entry for `key`, creating it with `build` on a miss.
    ///
    /// `build` runs while the cache lock is held. It is therefore called at most
    /// once per missing key even under contention. It must not call back into
    /// this cache, because re-locking a `std::sync::Mutex` on the same thread
    /// may deadlock or panic.
    pub fn get_or_create<F>(&self, key: CacheKey, build: F) -> CachedDispatch
    where
        F: FnOnce() -> CachedDispatch,
    {
        self.lock_map().entry(key).or_insert_with(build).clone()
    }

    /// Removes every cached entry.
    pub fn clear(&self) {
        self.lock_map().clear();
    }

    /// Removes every entry whose key references any of the buffer ids in `ids`.
    ///
    /// On non-wasm targets, setting the `RULLAMA_TRACE_BINDCACHE` environment
    /// variable prints a summary line to stderr. The print happens after the
    /// lock is released.
    pub fn invalidate_buffers(&self, ids: &[u64]) {
        if ids.is_empty() {
            return;
        }
        let id_set: HashSet<u64> = ids.iter().copied().collect();
        // Sizes are captured under the lock and only used for tracing; tracing
        // is compiled out on wasm32, so silence the unused warning there.
        #[cfg_attr(target_arch = "wasm32", allow(unused_variables))]
        let (before, after) = {
            let mut guard = self.lock_map();
            let before = guard.len();
            guard.retain(|k, _| !k.touches(&id_set));
            (before, guard.len())
        };
        #[cfg(not(target_arch = "wasm32"))]
        if std::env::var("RULLAMA_TRACE_BINDCACHE").is_ok() {
            eprintln!(
                "[bindcache] invalidate_buffers: {} ids, removed {} entries ({} -> {})",
                ids.len(),
                before - after,
                before,
                after
            );
        }
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.lock_map().len()
    }

    /// Returns `true` if the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.lock_map().is_empty()
    }
}
impl Default for BindGroupCache {
fn default() -> Self {
Self::new()
}
}