vyre-wgpu 0.1.0

wgpu backend for vyre IR — implements VyreBackend, owns GPU runtime, buffer pool, pipeline cache
Documentation
use super::{cache_key, PipelineCache as ComputePipelineCache};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{LazyLock, Mutex, RwLock};
use vyre::error::{Error, Result};

/// Number of independently locked shards in the process-wide compute pipeline
/// cache; a key is assigned to a shard by taking its hash modulo this value.
const PIPELINE_CACHE_SHARDS: usize = 32;

/// Build (or fetch from cache) a `wgpu` compute pipeline for a WGSL shader
/// without supplying an explicit pipeline layout.
///
/// This is a convenience wrapper: it forwards every argument to
/// `compile_compute_pipeline_with_layout` and passes `None` for the layout.
/// Compiled pipelines are kept in a static, process-wide cache owned by that
/// function.
///
/// # Errors
///
/// Propagates whatever `compile_compute_pipeline_with_layout` returns. In the
/// code visible here that is an `Error::Gpu` raised when one of the cache
/// locks is poisoned; the results of wgpu's shader-module and pipeline
/// creation calls are used directly and do not produce an `Err` on this path.
///
/// # Examples
///
/// ```ignore
/// use vyre::runtime::shader::compile_compute_pipeline;
///
/// // Requires a live wgpu device.
/// // let pipeline = compile_compute_pipeline(&device, "my_shader", "@compute fn main() {}", "main")?;
/// ```
#[inline]
pub fn compile_compute_pipeline(
    device: &wgpu::Device,
    label: &str,
    wgsl_source: &str,
    entry_point: &str,
) -> Result<wgpu::ComputePipeline> {
    // No caller-provided layout for this entry point.
    let no_explicit_layout: Option<&wgpu::PipelineLayout> = None;
    compile_compute_pipeline_with_layout(
        device,
        label,
        wgsl_source,
        entry_point,
        no_explicit_layout,
    )
}

/// Compile a WGSL compute shader with an explicit pipeline layout.
///
/// # Errors
///
/// Returns an actionable GPU error if the shader or pipeline cache locks are
/// poisoned.
#[inline]
pub fn compile_compute_pipeline_with_layout(
    device: &wgpu::Device,
    label: &str,
    wgsl_source: &str,
    entry_point: &str,
    layout: Option<&wgpu::PipelineLayout>,
) -> Result<wgpu::ComputePipeline> {
    static PIPELINES: LazyLock<[RwLock<ComputePipelineCache>; PIPELINE_CACHE_SHARDS]> =
        LazyLock::new(|| std::array::from_fn(|_| RwLock::new(ComputePipelineCache::new())));

    let cache_key = cache_key(wgsl_source, entry_point);
    let shard_index = pipeline_cache_shard(&cache_key);
    {
        let mut pipelines = PIPELINES[shard_index].write().map_err(|source| Error::Gpu {
            message: format!(
                "shader pipeline cache shard {shard_index} poisoned while reading '{label}' entry '{entry_point}': {source}. Fix: restart the process and inspect panics from other shader-compilation threads."
            ),
        })?;
        if let Some(pipeline) = pipelines.get(&cache_key) {
            return Ok(pipeline);
        }
    }

    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
    });
    let driver_cache = driver_pipeline_cache(device, label)?;
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout,
        module: &module,
        entry_point: Some(entry_point),
        compilation_options: wgpu::PipelineCompilationOptions::default(),
        cache: Some(&driver_cache),
    });
    let mut pipelines = PIPELINES[shard_index].write().map_err(|source| Error::Gpu {
        message: format!(
            "shader pipeline cache shard {shard_index} poisoned while storing '{label}' entry '{entry_point}': {source}. Fix: restart the process and inspect panics from other shader-compilation threads."
        ),
    })?;
    pipelines.insert(cache_key, pipeline.clone());
    Ok(pipeline)
}

/// Map a cache key to the index of the shard that stores it.
///
/// The key is hashed with a freshly constructed `DefaultHasher` and the digest
/// is reduced modulo `PIPELINE_CACHE_SHARDS`, so the result is always a valid
/// index into the shard array.
fn pipeline_cache_shard(key: &str) -> usize {
    let digest = {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(key, &mut state);
        Hasher::finish(&state)
    };
    (digest as usize) % PIPELINE_CACHE_SHARDS
}

/// Return the driver-level `wgpu::PipelineCache` associated with `device`,
/// creating and memoizing it on first use.
///
/// One cache is kept per device in a static map guarded by a mutex; the value
/// handed back is a clone of the stored handle. `label` is used only to make
/// the poisoned-mutex error message identify the shader being compiled.
///
/// NOTE(review): the map stores a clone of every `wgpu::Device` it sees and
/// nothing in this function removes entries — confirm that keeping those
/// clones for the lifetime of the process is acceptable.
///
/// # Errors
///
/// Returns `Error::Gpu` if the map's mutex is poisoned.
fn driver_pipeline_cache(device: &wgpu::Device, label: &str) -> Result<wgpu::PipelineCache> {
    static DRIVER_CACHES: LazyLock<Mutex<HashMap<wgpu::Device, wgpu::PipelineCache>>> =
        LazyLock::new(|| Mutex::new(HashMap::new()));

    let mut caches = DRIVER_CACHES.lock().map_err(|source| Error::Gpu {
        message: format!(
            "wgpu pipeline cache mutex poisoned while compiling '{label}': {source}. Fix: restart the process and inspect panics from other shader-compilation threads."
        ),
    })?;

    // Already created for this device: hand back a clone of the stored handle.
    if let Some(existing) = caches.get(device).cloned() {
        return Ok(existing);
    }

    // `wgpu::Device::create_pipeline_cache` is an unsafe API; per the wgpu
    // documentation its safety contract concerns the bytes passed in `data`.
    // The descriptor used here keeps that surface as small as possible:
    //
    //   * `data: None`     — no externally sourced cache bytes are supplied.
    //   * `fallback: true` — wgpu is asked for an empty cache when supplied
    //                        data cannot be used (moot here, as none is given).
    let descriptor = wgpu::PipelineCacheDescriptor {
        label: Some("vyre wgpu pipeline cache"),
        data: None,
        fallback: true,
    };
    // The extra block makes the `unsafe` expression a block tail, which is
    // where an outer lint attribute on an expression is accepted; the allow
    // therefore covers only this one call while the rest of the function
    // stays under the workspace `unsafe_code = "deny"` lint.
    let created = {
        #[allow(unsafe_code)]
        // SAFETY: the descriptor carries `data: None`, so no untrusted bytes
        // reach the driver through this call.
        unsafe {
            device.create_pipeline_cache(&descriptor)
        }
    };
    caches.insert(device.clone(), created.clone());
    Ok(created)
}