runmat-accelerate 0.4.5

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
use serde::{de::DeserializeOwned, Serialize};
use std::collections::HashMap;
use std::hash::Hash;
use std::path::PathBuf;
use std::sync::Mutex;

/// Key/value cache of autotuning results, optionally backed by a JSON file on disk.
///
/// The in-memory map sits behind a [`Mutex`], so a controller can be shared by
/// reference across threads. Disk persistence happens only when `enabled` is set and a
/// `json_path` is known.
#[derive(Debug)]
pub struct AutotuneController<K, V> {
    /// Gates loading from and saving to disk; exposed via `is_enabled`.
    enabled: bool,
    /// In-memory map from key to stored value.
    cache: Mutex<HashMap<K, V>>,
    /// JSON file backing the cache; `None` keeps entries in memory only.
    json_path: Option<PathBuf>,
}

impl<K, V> AutotuneController<K, V>
where
    K: Eq + Hash + Clone + Serialize + DeserializeOwned,
    V: Copy + Serialize + DeserializeOwned + PartialEq,
{
    /// Creates a controller whose enablement is read from the environment variable `var`.
    ///
    /// Autotuning is enabled when `var` holds `1`, `true`, `yes` or `on` (ASCII
    /// case-insensitive, surrounding whitespace ignored). An unset or non-Unicode
    /// variable means disabled, and the controller is always disabled on `wasm32`.
    ///
    /// When enabled and `base_dir` is given, entries are persisted to
    /// `<base_dir>/autotune/<kernel>/<device_tag>.json`, and an existing file at that
    /// path is loaded immediately. Load failures are not fatal: the controller starts
    /// empty and a warning is logged.
    pub fn new_from_env(
        var: &str,
        kernel: &str,
        base_dir: Option<PathBuf>,
        device_tag: &str,
    ) -> Self {
        let enabled = if cfg!(target_arch = "wasm32") {
            false
        } else {
            std::env::var(var)
                .map(|v| {
                    matches!(
                        v.trim().to_ascii_lowercase().as_str(),
                        "1" | "true" | "yes" | "on"
                    )
                })
                .unwrap_or(false)
        };
        let json_path = if enabled {
            base_dir.map(|mut dir| {
                dir.push("autotune");
                dir.push(kernel);
                dir.push(format!("{device_tag}.json"));
                dir
            })
        } else {
            None
        };
        let controller = Self {
            enabled,
            cache: Mutex::new(HashMap::new()),
            json_path,
        };
        if controller.enabled {
            // A missing file is not an error. Anything else (I/O failure or malformed
            // JSON) is reported, and the controller stays usable with an empty cache.
            if let Err(err) = controller.load_from_disk() {
                log::warn!(
                    "failed to load autotune cache from {:?}: {}",
                    controller.json_path,
                    err
                );
            }
        }
        controller
    }

    /// Returns whether autotuning was enabled when this controller was created.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Returns a copy of the value stored for `key`, if any.
    ///
    /// This consults the in-memory map regardless of `is_enabled`.
    pub fn get(&self, key: &K) -> Option<V> {
        self.lock_cache().get(key).copied()
    }

    /// Stores `value` for `key` and, if persistence is active, rewrites the JSON file.
    ///
    /// Re-inserting an identical value is a no-op and does not touch the disk. A failed
    /// write is logged and otherwise ignored; the in-memory entry is kept either way.
    pub fn insert(&self, key: K, value: V) {
        let mut cache = self.lock_cache();
        if cache.get(&key) == Some(&value) {
            return;
        }
        cache.insert(key, value);
        // The lock is held across the save so concurrent inserts write snapshots in the
        // same order they were applied to the map.
        if let Err(err) = self.save_to_disk(&cache) {
            log::warn!(
                "failed to persist autotune cache to {:?}: {}",
                self.json_path,
                err
            );
        }
    }

    /// Locks the in-memory map, recovering it if a previous holder panicked.
    ///
    /// A panic while the lock was held leaves at worst a partially updated map. That map
    /// is still safe to read and overwrite. Recovering keeps the cache working instead of
    /// silently turning every later `get`/`insert` into a no-op.
    fn lock_cache(&self) -> std::sync::MutexGuard<'_, HashMap<K, V>> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the backing JSON file when persistence is active (enabled and a path set).
    fn persist_path(&self) -> Option<&std::path::Path> {
        if self.enabled {
            self.json_path.as_deref()
        } else {
            None
        }
    }

    /// Replaces the in-memory map with the entries stored on disk.
    ///
    /// A missing file, or inactive persistence, is treated as an empty cache and is not
    /// an error.
    ///
    /// # Errors
    /// Returns an I/O error if the file cannot be read, or if it does not contain a JSON
    /// array of `[key, value]` pairs. In that case the in-memory map is left untouched.
    fn load_from_disk(&self) -> std::io::Result<()> {
        let path = match self.persist_path() {
            Some(path) => path,
            None => return Ok(()),
        };
        // Match on NotFound instead of checking `exists()` first. This avoids a race,
        // and it surfaces permission errors rather than hiding them.
        let data = match std::fs::read_to_string(path) {
            Ok(data) => data,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()),
            Err(err) => return Err(err),
        };
        // `serde_json::Error` converts into `io::Error`, so a malformed file becomes an error.
        let entries: Vec<(K, V)> = serde_json::from_str(&data)?;
        let mut cache = self.lock_cache();
        cache.clear();
        cache.extend(entries);
        Ok(())
    }

    /// Writes `cache` as a JSON array of `[key, value]` pairs to the backing file.
    ///
    /// This does nothing when persistence is inactive. The data is first written to a
    /// sibling temporary file, which is then renamed over the target. Readers therefore
    /// never observe a truncated document.
    ///
    /// # Errors
    /// Returns an I/O error if the directory cannot be created, the entries cannot be
    /// serialized, or the write/rename fails. The existing file is left intact on failure.
    fn save_to_disk(&self, cache: &HashMap<K, V>) -> std::io::Result<()> {
        let path = match self.persist_path() {
            Some(path) => path,
            None => return Ok(()),
        };
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let entries: Vec<(&K, &V)> = cache.iter().collect();
        // Propagate serialization failures. Writing a placeholder such as "[]" would
        // silently wipe the previously persisted cache.
        let payload = serde_json::to_string_pretty(&entries)?;

        // The process id keeps concurrent processes from sharing a temp file.
        let mut tmp_name = path.as_os_str().to_os_string();
        tmp_name.push(format!(".{}.tmp", std::process::id()));
        let tmp_path = PathBuf::from(tmp_name);

        let result =
            std::fs::write(&tmp_path, payload).and_then(|()| std::fs::rename(&tmp_path, path));
        if result.is_err() {
            // Best-effort cleanup; the original error is what the caller needs to see.
            let _ = std::fs::remove_file(&tmp_path);
        }
        result?;
        log::info!("autotune cache saved to {:?}", path);
        Ok(())
    }
}