// runmat-accelerate 0.4.5
//
// Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
// Documentation
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};

/// Category tag attached to a pooled GPU buffer.
///
/// Together with an element count it forms the key under which
/// `BufferResidency` files idle buffers, so buffers of different classes are
/// never handed out in place of one another.
///
/// NOTE(review): the variant names suggest the operation that produces the
/// buffer (matmul partials/outputs, SYRK outputs, fusion outputs); confirm the
/// exact meaning against the call sites.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum BufferUsageClass {
    /// Catch-all class (presumably for buffers with no dedicated category).
    Generic,
    MatmulPartial,
    MatmulOut,
    SyrkOut,
    FusionOut,
}

/// Lookup key for the buffer pool: a usage class paired with an element count.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct ResidencyKey {
    /// Usage category the buffer was acquired/released under.
    usage: BufferUsageClass,
    /// Number of elements (not bytes) the buffer was requested for.
    len: usize,
}

impl ResidencyKey {
    fn new(usage: BufferUsageClass, len: usize) -> Self {
        Self { usage, len }
    }
}

impl Hash for ResidencyKey {
    /// Feeds the usage class and then the element count into `state`, in that
    /// order, so keys that compare equal also hash equal.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { usage, len } = self;
        Hash::hash(usage, state);
        Hash::hash(len, state);
    }
}

/// Pool of idle GPU buffers, bucketed by usage class and element count.
///
/// Buffers handed to `release` are queued for later reuse by `acquire`.
pub struct BufferResidency {
    /// Idle buffers per key, guarded by a mutex. Each queue is used
    /// first-in/first-out: `release` pushes to the back, `acquire` pops from
    /// the front.
    pools: Mutex<HashMap<ResidencyKey, VecDeque<Arc<wgpu::Buffer>>>>,
    /// Upper bound on how many idle buffers a single key's queue may hold;
    /// buffers released beyond this bound are dropped instead of pooled.
    max_per_key: usize,
}

impl BufferResidency {
    /// Creates an empty pool that keeps at most `max_per_key` idle buffers for
    /// each `(usage, len)` combination. A value of `0` disables pooling.
    pub fn new(max_per_key: usize) -> Self {
        Self {
            pools: Mutex::new(HashMap::new()),
            max_per_key,
        }
    }

    /// Allocates a fresh `size_bytes`-byte buffer usable as a storage buffer
    /// and as both source and destination of copies.
    fn create_buffer(device: &wgpu::Device, label: &str, size_bytes: u64) -> Arc<wgpu::Buffer> {
        Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: size_bytes,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        }))
    }

    /// Returns a buffer able to hold `len` elements of `element_size` bytes.
    ///
    /// A pooled buffer filed under `(usage, len)` is reused when one with the
    /// exact requested byte size is available; otherwise a new buffer is
    /// created on `device` with the given `label`.
    ///
    /// The returned flag is `true` when the buffer came from the pool and
    /// `false` when it was freshly allocated. Zero-length requests are never
    /// served from the pool and always get a minimal freshly allocated buffer.
    pub fn acquire(
        &self,
        device: &wgpu::Device,
        usage: BufferUsageClass,
        len: usize,
        element_size: usize,
        label: &str,
    ) -> (Arc<wgpu::Buffer>, bool) {
        // Treat a zero element size as one byte so no path ever asks for a
        // zero-byte allocation.
        let element_bytes = element_size.max(1) as u64;

        if len == 0 {
            return (Self::create_buffer(device, label, element_bytes), false);
        }

        // Saturate rather than wrap: an absurd request should be rejected by
        // the device, not silently turned into a tiny buffer.
        let size_bytes = (len as u64).saturating_mul(element_bytes);
        let key = ResidencyKey::new(usage, len);

        {
            // A poisoned lock only means another thread panicked while holding
            // it; the map itself is still structurally valid, so keep pooling.
            let mut pools = self
                .pools
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if let Some(queue) = pools.get_mut(&key) {
                // The key does not encode the element size, so two requests
                // with the same `len` may need different byte sizes. Only hand
                // out a buffer whose allocation matches exactly, oldest first.
                let hit = queue.iter().position(|pooled| pooled.size() == size_bytes);
                if let Some(buffer) = hit.and_then(|idx| queue.remove(idx)) {
                    log::trace!(
                        "buffer_residency: reuse {:?} len={} ptr={:p}",
                        usage,
                        len,
                        Arc::as_ptr(&buffer)
                    );
                    return (buffer, true);
                }
            }
            // Guard dropped here so the allocation below runs unlocked.
        }

        let buffer = Self::create_buffer(device, label, size_bytes);
        log::trace!(
            "buffer_residency: new {:?} len={} ptr={:p}",
            usage,
            len,
            Arc::as_ptr(&buffer)
        );
        (buffer, false)
    }

    /// Hands `buffer` back to the pool under `(usage, len)` for later reuse.
    ///
    /// The buffer is dropped instead of pooled when `len` is zero, when the
    /// queue for this key already holds `max_per_key` buffers, or when this
    /// exact buffer is already sitting in the queue.
    pub fn release(&self, usage: BufferUsageClass, len: usize, buffer: Arc<wgpu::Buffer>) {
        if len == 0 {
            return;
        }

        let key = ResidencyKey::new(usage, len);
        // Recover from poisoning for the same reason as in `acquire`.
        let mut pools = self
            .pools
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Check capacity before touching the entry API so a full (or disabled)
        // pool does not accumulate empty queues in the map.
        let pooled = pools.get(&key).map_or(0, VecDeque::len);
        if pooled >= self.max_per_key {
            log::trace!(
                "buffer_residency: drop {:?} len={} ptr={:p} (pool full)",
                usage,
                len,
                Arc::as_ptr(&buffer)
            );
            return;
        }

        let queue = pools.entry(key).or_default();
        // Pooling the same allocation twice would let two later callers
        // receive the same GPU buffer.
        if queue.iter().any(|idle| Arc::ptr_eq(idle, &buffer)) {
            log::trace!(
                "buffer_residency: drop {:?} len={} ptr={:p} (already pooled)",
                usage,
                len,
                Arc::as_ptr(&buffer)
            );
            return;
        }

        log::trace!(
            "buffer_residency: release {:?} len={} ptr={:p}",
            usage,
            len,
            Arc::as_ptr(&buffer)
        );
        queue.push_back(buffer);
    }
}