//! runmat-accelerate 0.4.5
//!
//! Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V).
use smallvec::SmallVec;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

/// Identity of a single buffer binding inside a bind group, keyed by the
/// buffer's *address* rather than its contents (see `make_key`).
#[derive(Clone, Debug, Eq)]
struct BindingKey {
    // Binding slot number taken from `wgpu::BindGroupEntry::binding`.
    binding: u32,
    // Address of the `&wgpu::Buffer` cast to `usize`; valid only while the
    // buffer stays at a stable address — stale entries are purged via
    // `BindGroupCache::invalidate_buffer`.
    buffer_ptr: usize,
    // Byte offset from `wgpu::BufferBinding::offset`.
    offset: u64,
    // Bound size from `wgpu::BufferBinding::size`; `None` presumably means
    // "rest of the buffer" per wgpu semantics — TODO confirm.
    size: Option<u64>,
}

impl PartialEq for BindingKey {
    /// Two binding keys are equal when every field matches; compared as a
    /// tuple so the field list stays in one place.
    fn eq(&self, other: &Self) -> bool {
        (self.binding, self.buffer_ptr, self.offset, self.size)
            == (other.binding, other.buffer_ptr, other.offset, other.size)
    }
}

impl Hash for BindingKey {
    /// Feeds the same fields to the hasher as `eq` compares. Hashing the
    /// tuple writes each field in order, identical to hashing them one by
    /// one.
    fn hash<H: Hasher>(&self, state: &mut H) {
        (self.binding, self.buffer_ptr, self.offset, self.size).hash(state);
    }
}

/// Cache key for a whole bind group: the layout's address plus the ordered
/// list of buffer bindings. Built by `BindGroupCache::make_key`.
#[derive(Clone, Debug, Eq)]
struct BindGroupKey {
    // Address of the `&wgpu::BindGroupLayout` cast to `usize` (pointer
    // identity — assumes the layout is not moved while cached).
    layout_ptr: usize,
    // Inline storage for up to 8 bindings before spilling to the heap.
    bindings: SmallVec<[BindingKey; 8]>,
}

impl PartialEq for BindGroupKey {
    /// Keys match when they refer to the same layout and carry identical
    /// binding lists (order-sensitive).
    fn eq(&self, other: &Self) -> bool {
        if self.layout_ptr != other.layout_ptr {
            return false;
        }
        self.bindings == other.bindings
    }
}

impl Hash for BindGroupKey {
    /// Hashes exactly the fields `eq` compares: layout address first, then
    /// the binding list.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { layout_ptr, bindings } = self;
        layout_ptr.hash(state);
        bindings.hash(state);
    }
}

/// Thread-safe cache of `wgpu::BindGroup`s keyed by layout + buffer
/// bindings, with per-buffer invalidation and hit/miss accounting.
pub struct BindGroupCache {
    // Primary store: full key -> shared bind group.
    inner: Mutex<HashMap<BindGroupKey, Arc<wgpu::BindGroup>>>,
    // Reverse index: buffer address -> every cached key that references it;
    // lets `invalidate_buffer` evict without scanning `inner`.
    index: Mutex<HashMap<usize, HashSet<BindGroupKey>>>,
    // Global hit counter (relaxed atomics — counters are advisory only).
    hits: AtomicU64,
    // Global miss counter.
    misses: AtomicU64,
    // Per-layout stats: layout address -> (hits, misses).
    per_layout: Mutex<HashMap<usize, (u64, u64)>>,
}

impl Default for BindGroupCache {
    /// Equivalent to [`BindGroupCache::new`]: an empty cache with zeroed
    /// counters.
    fn default() -> Self {
        Self::new()
    }
}

impl BindGroupCache {
    /// Creates an empty cache with all counters at zero.
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(HashMap::new()),
            index: Mutex::new(HashMap::new()),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            per_layout: Mutex::new(HashMap::new()),
        }
    }

    /// Returns the cached bind group for `(layout, entries)`, calling
    /// `create` to build and cache a new one on a miss.
    ///
    /// Entries containing non-buffer resources cannot be keyed (see
    /// `make_key`); those always invoke `create`, count as a global miss,
    /// and are never cached.
    ///
    /// NOTE(review): the lookup and the insert are separate critical
    /// sections, so two threads missing on the same key may both run
    /// `create`; the later insert wins and the earlier bind group survives
    /// only in its caller. Benign for a cache, but `create` is not
    /// guaranteed to run at most once per key.
    pub fn get_or_create<F>(
        &self,
        layout: &wgpu::BindGroupLayout,
        entries: &[wgpu::BindGroupEntry],
        create: F,
    ) -> Arc<wgpu::BindGroup>
    where
        F: FnOnce() -> Arc<wgpu::BindGroup>,
    {
        match Self::make_key(layout, entries) {
            Some(key) => {
                // Fast path: clone the Arc out under a briefly-held lock.
                // A poisoned lock degrades to a miss rather than panicking.
                if let Some(existing) = self
                    .inner
                    .lock()
                    .ok()
                    .and_then(|map| map.get(&key).cloned())
                {
                    self.hits.fetch_add(1, Ordering::Relaxed);
                    log::debug!("bind_group_cache hit layout_ptr={:#x}", key.layout_ptr);
                    if let Ok(mut per) = self.per_layout.lock() {
                        let entry = per.entry(key.layout_ptr).or_insert((0, 0));
                        entry.0 = entry.0.saturating_add(1);
                    }
                    return existing;
                }
                // Miss: build OUTSIDE any lock so `create` cannot deadlock
                // against another cache operation.
                let bind_group = create();
                let layout_ptr = key.layout_ptr;
                let key_clone = key.clone();
                if let Ok(mut map) = self.inner.lock() {
                    map.insert(key, bind_group.clone());
                }
                // Maintain the reverse index (buffer ptr -> keys) consumed
                // by `invalidate_buffer`. Note the two locks are taken
                // sequentially, never nested.
                if let Ok(mut index) = self.index.lock() {
                    for binding in &key_clone.bindings {
                        index
                            .entry(binding.buffer_ptr)
                            .or_default()
                            .insert(key_clone.clone());
                    }
                }
                self.misses.fetch_add(1, Ordering::Relaxed);
                log::debug!("bind_group_cache miss layout_ptr={:#x}", layout_ptr);
                if let Ok(mut per) = self.per_layout.lock() {
                    let entry = per.entry(layout_ptr).or_insert((0, 0));
                    entry.1 = entry.1.saturating_add(1);
                }
                bind_group
            }
            None => {
                // Unkeyable entries: always build fresh, counted as a miss.
                self.misses.fetch_add(1, Ordering::Relaxed);
                create()
            }
        }
    }

    /// Returns global `(hits, misses)` accumulated since construction or
    /// the last [`reset_counters`](Self::reset_counters).
    pub fn counters(&self) -> (u64, u64) {
        (
            self.hits.load(Ordering::Relaxed),
            self.misses.load(Ordering::Relaxed),
        )
    }

    /// Zeroes the global counters and clears the per-layout stats.
    /// Cached bind groups themselves are NOT evicted.
    pub fn reset_counters(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
        if let Ok(mut per) = self.per_layout.lock() {
            per.clear();
        }
    }

    /// Snapshot of per-layout `(hits, misses)`, keyed by layout address.
    /// Returns an empty map if the stats lock is poisoned.
    pub fn per_layout_counters(&self) -> HashMap<usize, (u64, u64)> {
        self.per_layout
            .lock()
            .map(|m| m.clone())
            .unwrap_or_default()
    }

    /// Evicts every cached bind group that references the buffer at
    /// `buffer_ptr` and prunes the reverse index accordingly.
    ///
    /// Callers should invoke this before destroying or reallocating a
    /// buffer so that a future allocation at the same address cannot alias
    /// a stale cached bind group.
    pub fn invalidate_buffer(&self, buffer_ptr: usize) {
        // Phase 1 (index lock only): collect the affected keys and scrub
        // them from the index entries of every OTHER buffer they reference.
        let removed_keys = {
            let mut keys_to_remove: Vec<BindGroupKey> = Vec::new();
            if let Ok(mut index) = self.index.lock() {
                if let Some(keys_set) = index.remove(&buffer_ptr) {
                    keys_to_remove.extend(keys_set);
                    let mut empty_ptrs: HashSet<usize> = HashSet::new();
                    for key in &keys_to_remove {
                        for binding in &key.bindings {
                            if binding.buffer_ptr == buffer_ptr {
                                // Already removed wholesale above.
                                continue;
                            }
                            if let Some(set) = index.get_mut(&binding.buffer_ptr) {
                                set.remove(key);
                                if set.is_empty() {
                                    empty_ptrs.insert(binding.buffer_ptr);
                                }
                            }
                        }
                    }
                    // Drop now-empty index entries to keep the map compact.
                    for ptr in empty_ptrs {
                        index.remove(&ptr);
                    }
                } else {
                    // Buffer never appeared in any cached bind group.
                    return;
                }
            } else {
                // Poisoned lock: nothing safe to do.
                return;
            }
            keys_to_remove
        };

        // Phase 2 (inner lock only): drop the bind groups themselves. The
        // two locks are never held simultaneously, avoiding ordering issues.
        if let Ok(mut map) = self.inner.lock() {
            for key in removed_keys {
                map.remove(&key);
            }
        }
    }

    /// Builds a cache key from the layout's address plus each entry's
    /// buffer address / offset / size. Returns `None` as soon as any entry
    /// binds a non-buffer resource (sampler, texture view, ...), which this
    /// cache cannot key by address.
    ///
    /// NOTE(review): keys rely on pointer identity, so they are only valid
    /// while the layout and buffers stay at stable addresses; buffer
    /// reallocation must be paired with
    /// [`invalidate_buffer`](Self::invalidate_buffer).
    fn make_key(
        layout: &wgpu::BindGroupLayout,
        entries: &[wgpu::BindGroupEntry],
    ) -> Option<BindGroupKey> {
        let mut bindings: SmallVec<[BindingKey; 8]> = SmallVec::with_capacity(entries.len());
        for entry in entries {
            match &entry.resource {
                wgpu::BindingResource::Buffer(buffer_binding) => {
                    let buffer_ptr = buffer_binding.buffer as *const wgpu::Buffer as usize;
                    let size = buffer_binding.size.map(|v| v.get());
                    log::debug!(
                        "bind_group_cache key binding={} ptr={:#x} offset={} size={:?}",
                        entry.binding,
                        buffer_ptr,
                        buffer_binding.offset,
                        size
                    );
                    bindings.push(BindingKey {
                        binding: entry.binding,
                        buffer_ptr,
                        offset: buffer_binding.offset,
                        size,
                    });
                }
                other => {
                    log::debug!(
                        "bind_group_cache unsupported resource at binding {}: {:?}",
                        entry.binding,
                        other
                    );
                    return None;
                }
            }
        }
        Some(BindGroupKey {
            layout_ptr: layout as *const _ as usize,
            bindings,
        })
    }
}