runmat-accelerate 0.4.5

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
use std::panic::{self, AssertUnwindSafe};
use std::path::Path;
use std::sync::Arc;

use super::bindings::build_bgl_for_layout_tag;
use super::cache::persist::PipelineMeta;
use super::cache::persist::PIPELINE_CACHE_VERSION;
use super::types::NumericPrecision;

/// Pre-compile compute pipelines from the on-disk shader cache so the first
/// real dispatch does not pay shader-compilation latency.
///
/// Scans `cache_dir` for `<stem>.json` metadata files paired with
/// `<stem>.wgsl` shader sources. Entries whose cache version, numeric
/// precision, or layout tag do not match the current configuration are
/// skipped silently. Entries that panic during pipeline creation are treated
/// as incompatible and evicted from the cache (both files removed).
///
/// * `compute_hash` — produces the cache key from (wgsl bytes, layout tag,
///   optional workgroup size).
/// * `get_or_create` — builds (or fetches) the compute pipeline for a key.
/// * `after_create_noop` — invoked once per created pipeline (e.g. to force
///   driver-side compilation via a no-op dispatch).
pub fn warmup_from_disk<FHash, FCreate, FNoop>(
    device: &wgpu::Device,
    cache_dir: Option<&Path>,
    target_precision: NumericPrecision,
    compute_hash: FHash,
    get_or_create: FCreate,
    after_create_noop: FNoop,
) where
    FHash: Fn(&[u8], &str, Option<u32>) -> u64,
    FCreate: Fn(
        u64,
        &wgpu::PipelineLayout,
        &wgpu::ShaderModule,
        &str,
        Option<&[u8]>,
        Option<&str>,
        Option<u32>,
    ) -> Arc<wgpu::ComputePipeline>,
    FNoop: Fn(&wgpu::ComputePipeline),
{
    let Some(dir) = cache_dir else {
        return;
    };
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    let mut warmed = 0usize;
    for entry in entries.flatten() {
        let meta_path = entry.path();
        if meta_path.extension().and_then(|e| e.to_str()) != Some("json") {
            continue;
        }
        let Some(stem) = meta_path.file_stem().and_then(|s| s.to_str()) else {
            continue;
        };
        let Ok(meta_bytes) = std::fs::read(&meta_path) else {
            continue;
        };
        let Ok(meta) = serde_json::from_slice::<PipelineMeta>(&meta_bytes) else {
            continue;
        };
        // Stale or incompatible cache entries are skipped without logging.
        if meta.version.unwrap_or(0) != PIPELINE_CACHE_VERSION {
            continue;
        }
        // Only warm pipelines recorded for the precision we are running at;
        // a missing precision field marks a stale entry and is also skipped.
        if meta.precision.as_deref() != Some(target_precision.as_str()) {
            continue;
        }
        let Some(layout_tag) = meta.layout_tag.as_deref() else {
            continue;
        };
        // The WGSL source lives next to the metadata under the same stem.
        let wgsl_path = dir.join(format!("{stem}.wgsl"));
        let Ok(wgsl_bytes) = std::fs::read(&wgsl_path) else {
            continue;
        };
        let Ok(wgsl_str) = std::str::from_utf8(&wgsl_bytes) else {
            continue;
        };
        let Some(bgl) = build_bgl_for_layout_tag(device, layout_tag) else {
            continue;
        };
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("warmup-pipeline-layout"),
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });
        // Route through the shared helper so the @WG@ placeholder substitution
        // matches regular pipeline creation.
        let module = crate::backend::wgpu::pipelines::create_shader_module(
            device,
            "warmup-shader-module",
            wgsl_str,
        );
        let key = compute_hash(&wgsl_bytes, layout_tag, meta.workgroup_size);
        // Pipeline creation may panic on driver/shader incompatibilities;
        // catch it so one bad cache entry cannot abort the whole warmup pass.
        let outcome = panic::catch_unwind(AssertUnwindSafe(|| {
            let pipeline = get_or_create(
                key,
                &pipeline_layout,
                &module,
                "warmup-precompiled-pipeline",
                Some(&wgsl_bytes),
                Some(layout_tag),
                meta.workgroup_size,
            );
            after_create_noop(&pipeline);
        }));
        if outcome.is_ok() {
            warmed += 1;
        } else {
            log::warn!(
                "warmup: failed to precompile pipeline {}; removing incompatible cache entry",
                stem
            );
            // Best-effort eviction: ignore removal errors.
            let _ = std::fs::remove_file(&meta_path);
            let _ = std::fs::remove_file(&wgsl_path);
        }
    }
    if warmed > 0 {
        log::info!(
            "warmup: precompiled {} pipelines from on-disk cache",
            warmed
        );
    }
}

/// Submit an empty compute pass that binds `pipeline`, forcing the driver to
/// finish any deferred compilation work, then block until the GPU is idle.
pub fn noop_after_create(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    pipeline: &wgpu::ComputePipeline,
) {
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("warmup-noop-precompiled-enc"),
    });
    {
        // Scope the pass so it ends (is dropped) before the encoder is finished.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("warmup-noop-precompiled-pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        // No dispatch: binding the pipeline is enough to trigger compilation.
    }
    queue.submit([encoder.finish()]);
    // Wait synchronously so warmup cost is paid here, not on first real use.
    device.poll(wgpu::Maintain::Wait);
}