Skip to main content

svod_runtime/
kernel_cache.rs

1//! Global kernel deduplication cache.
2//!
3//! This module provides a global concurrent cache that maps (UOp ID, device) pairs to compiled kernels.
4//! Uses papaya's lock-free HashMap for thread-safe access across parallel tensor operations.
5//!
6//! # Thread Safety
7//!
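//! All operations are thread-safe. Multiple threads can look up and compile kernels
//! concurrently without explicit synchronization. If several threads miss on the same
//! key at once, each may compile the kernel, but only one result is kept and shared.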
10//!
11//! # Deduplication
12//!
13//! Thanks to hash consing in `ir/src/uop/hash_consing.rs`, identical ASTs automatically
14//! have identical IDs, making kernel deduplication trivial. The key includes both the
15//! AST ID and the device string to support multi-GPU systems where the same kernel
16//! might be compiled differently for different devices.
17
18use std::sync::{Arc, OnceLock};
19
20use papaya::HashMap;
21use svod_device::device::Program;
22use svod_ir::UOp;
23
/// Compiled kernel stored in the global cache and reused across tensors.
///
/// Instances are shared through `Arc<CachedKernel>`; the struct itself does not
/// implement `Clone` because `Box<dyn Program>` is not `Clone`.
pub struct CachedKernel {
    /// The compiled, executable program.
    pub program: Box<dyn Program>,
    /// Device string this kernel was compiled for (e.g., "CPU", "CUDA:0").
    pub device: String,
    /// Generated source code, kept for debugging/profiling.
    pub code: String,
    /// Name of the kernel's entry point.
    pub entry_point: String,
    /// Variable names, in the order the compiled program expects its positional
    /// values; used to convert a name->value map into positional arguments.
    pub var_names: Vec<String>,
    /// Global buffer slots in kernel argument order.
    /// Matches Tinygrad's ProgramSpec.globals semantics.
    pub globals: Vec<usize>,
    /// Output buffer slots written by STORE operations.
    /// Matches Tinygrad's ProgramSpec.outs semantics.
    pub outs: Vec<usize>,
    /// Input buffer slots read by LOAD operations.
    /// Matches Tinygrad's ProgramSpec.ins semantics.
    pub ins: Vec<usize>,
    /// Whether host-level scheduling may overlap this program with other kernels.
    ///
    /// Thread-safety is required by the `Program` trait. This flag is about
    /// backend/kernel semantics, not Rust synchronization safety.
    pub host_parallel_safe: bool,
    /// Symbolic global work size (one expression per dimension), evaluated with
    /// runtime vars before dispatch.
    pub global_size: [Arc<UOp>; 3],
    /// Symbolic local work size. `None` means direct global-id execution.
    pub local_size: Option<[Arc<UOp>; 3]>,
}
59
60/// Cache key: (AST ID, device string).
61///
62/// Using both AST ID and device allows the same logical kernel to be compiled
63/// differently for different devices (e.g., CPU vs CUDA, or CUDA:0 vs CUDA:1).
64type KernelKey = (u64, String);
65
66// Global kernel dedup cache using lock-free concurrent HashMap.
67//
68// Maps (UOp ID, device) -> Arc<CachedKernel>.
69// Kernels live until explicitly cleared via clear_all().
70static KERNELS: OnceLock<HashMap<KernelKey, Arc<CachedKernel>>> = OnceLock::new();
71
72fn kernels() -> &'static HashMap<KernelKey, Arc<CachedKernel>> {
73    KERNELS.get_or_init(HashMap::new)
74}
75
76/// Get or compile a kernel by UOp ID and device.
77///
78/// Thread-safe: if multiple threads call this with the same key concurrently,
79/// exactly one will compile the kernel, and all others will receive a clone
80/// of the Arc to that kernel.
81///
82/// # Arguments
83///
84/// * `ast_id` - The UOp ID of the kernel AST (from hash consing)
85/// * `device` - Device string (e.g., "CPU", "CUDA:0")
86/// * `compile_fn` - Function to compile the kernel if not cached
87///
88/// # Returns
89///
90/// Arc to the cached kernel (either from cache or freshly compiled).
91///
92/// # Errors
93///
94/// Returns error if compilation fails
95pub fn get_or_compile_kernel<F, E>(ast_id: u64, device: &str, compile_fn: F) -> Result<Arc<CachedKernel>, E>
96where
97    F: FnOnce() -> Result<CachedKernel, E>,
98{
99    let key = (ast_id, device.to_string());
100    let map = kernels();
101    let guard = map.guard();
102
103    // Fast path: kernel already cached
104    if let Some(cached) = map.get(&key, &guard) {
105        return Ok(Arc::clone(cached));
106    }
107
108    // Slow path: compile kernel (expensive)
109    let compiled = compile_fn()?;
110    let cached = Arc::new(compiled);
111
112    // Atomic insert - if another thread beat us, use their kernel
113    use papaya::{Compute, Operation};
114    match map.compute(
115        key,
116        |entry| match entry {
117            Some((_, existing)) => Operation::Abort(Arc::clone(existing)),
118            None => Operation::Insert(Arc::clone(&cached)),
119        },
120        &guard,
121    ) {
122        Compute::Inserted(_, kernel) => Ok(Arc::clone(kernel)),
123        Compute::Aborted(kernel) => Ok(kernel),
124        _ => Ok(cached),
125    }
126}
127
128/// Clear all cached kernels.
129///
130/// This is primarily useful for testing to ensure test isolation.
131/// Thread-safe.
132pub fn clear_all() {
133    let guard = kernels().guard();
134    kernels().clear(&guard);
135}
136
137/// Remove kernels whose AST IDs are no longer in the live UOp set.
138///
139/// Call this after `gc_dead_refs()` to clean up compiled kernels for
140/// discarded UOps. This prevents kernel cache memory accumulation during
141/// beam search and other optimization passes.
142///
143/// # Arguments
144///
145/// * `live_ids` - Set of UOp IDs that are still alive in the UOp cache
146///
147/// # Example
148///
149/// ```ignore
150/// svod_ir::uop::gc_dead_refs();
151/// let live_ids = svod_ir::uop::live_uop_ids();
152/// svod_runtime::kernel_cache::gc_unused_kernels(&live_ids);
153/// ```
154pub fn gc_unused_kernels(live_ids: &std::collections::HashSet<u64>) {
155    let map = kernels();
156    let guard = map.guard();
157
158    // Collect keys to remove (can't mutate while iterating)
159    let to_remove: Vec<KernelKey> =
160        map.iter(&guard).filter(|((ast_id, _), _)| !live_ids.contains(ast_id)).map(|(k, _)| k.clone()).collect();
161
162    // Remove dead entries
163    for key in to_remove {
164        map.remove(&key, &guard);
165    }
166}