use std::sync::{Arc, OnceLock};
use papaya::HashMap;
use svod_device::device::Program;
use svod_ir::UOp;
/// A compiled kernel together with the metadata stored alongside it.
///
/// Instances are held behind `Arc` as the values of the global kernel cache in this
/// file and are handed out by `get_or_compile_kernel`.
pub struct CachedKernel {
/// Boxed `Program` trait object.
/// NOTE(review): presumably the compiled, launchable program — confirm.
pub program: Box<dyn Program>,
/// Device string.
/// NOTE(review): presumably the same device identifier used in the cache key — confirm.
pub device: String,
/// NOTE(review): presumably the generated kernel source text — confirm.
pub code: String,
/// NOTE(review): presumably the name of the kernel function inside `code` — confirm.
pub entry_point: String,
/// NOTE(review): presumably the names of the kernel's variable arguments — confirm order
/// and meaning against the code that builds this struct.
pub var_names: Vec<String>,
/// NOTE(review): presumably indices of the global buffers the kernel uses — confirm.
pub globals: Vec<usize>,
/// NOTE(review): presumably indices of the buffers the kernel writes — confirm.
pub outs: Vec<usize>,
/// NOTE(review): presumably indices of the buffers the kernel reads — confirm.
pub ins: Vec<usize>,
/// NOTE(review): presumably marks kernels that may be run in parallel from the host —
/// confirm the exact contract with the executor.
pub host_parallel_safe: bool,
/// Three shared `UOp` handles.
/// NOTE(review): presumably the (possibly symbolic) launch size per dimension — confirm.
pub global_size: [Arc<UOp>; 3],
/// Optional counterpart of `global_size`.
/// NOTE(review): presumably the work-group size, `None` when unspecified — confirm.
pub local_size: Option<[Arc<UOp>; 3]>,
}
/// Cache key: a `u64` paired with a `String`. `get_or_compile_kernel` builds it from its
/// `ast_id` and `device` arguments.
type KernelKey = (u64, String);
/// Process-wide kernel cache. Starts uninitialised; `kernels()` creates the map on first use.
static KERNELS: OnceLock<HashMap<KernelKey, Arc<CachedKernel>>> = OnceLock::new();
/// Returns the global kernel cache, initialising it to an empty map on the first call.
fn kernels() -> &'static HashMap<KernelKey, Arc<CachedKernel>> {
KERNELS.get_or_init(HashMap::new)
}
/// Returns the cached kernel for `(ast_id, device)`, compiling and caching it on a miss.
///
/// `compile_fn` is invoked only when the initial lookup misses. No map guard is held
/// while it runs, so several threads that miss on the same key may each compile; the
/// first insertion wins, every caller receives that winning entry, and the losing
/// compilations are dropped.
///
/// # Errors
///
/// Propagates the error returned by `compile_fn`. Nothing is inserted in that case.
pub fn get_or_compile_kernel<F, E>(
    ast_id: u64,
    device: &str,
    compile_fn: F,
) -> Result<Arc<CachedKernel>, E>
where
    F: FnOnce() -> Result<CachedKernel, E>,
{
    use papaya::{Compute, Operation};

    let key = (ast_id, device.to_string());
    let map = kernels();

    // Fast path. The guard is confined to this block: papaya cannot reclaim retired
    // entries while a guard is alive, so it must not stay pinned during compilation.
    {
        let guard = map.guard();
        if let Some(hit) = map.get(&key, &guard) {
            return Ok(Arc::clone(hit));
        }
    }

    let fresh = Arc::new(compile_fn()?);

    // Re-pin only for the insertion itself.
    let guard = map.guard();
    let outcome = map.compute(
        key,
        // The closure may be re-run, hence the `Arc::clone` instead of moving `fresh`.
        |entry| match entry {
            // Another thread inserted first: keep its kernel and discard ours.
            Some((_, existing)) => Operation::Abort(Arc::clone(existing)),
            None => Operation::Insert(Arc::clone(&fresh)),
        },
        &guard,
    );

    match outcome {
        Compute::Inserted(_, kernel) => Ok(Arc::clone(kernel)),
        Compute::Aborted(kernel) => Ok(kernel),
        // Not expected, since the closure only returns `Insert` or `Abort`; fall back
        // to the freshly compiled kernel rather than panicking.
        _ => Ok(fresh),
    }
}
pub fn clear_all() {
let guard = kernels().guard();
kernels().clear(&guard);
}
/// Evicts every cached kernel whose `ast_id` is absent from `live_ids`.
///
/// Entries are matched on the `u64` part of the key only, so all device variants of a
/// dead id are removed. Stale keys are gathered first and removed in a second pass.
pub fn gc_unused_kernels(live_ids: &std::collections::HashSet<u64>) {
    let cache = kernels();
    let guard = cache.guard();

    let mut stale: Vec<KernelKey> = Vec::new();
    for (key, _) in cache.iter(&guard) {
        if !live_ids.contains(&key.0) {
            stale.push(key.clone());
        }
    }

    for key in &stale {
        cache.remove(key, &guard);
    }
}