#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(feature = "cuda")]
use crate::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig};
#[cfg(feature = "cuda")]
use crate::error::Result;
/// Process-wide cache of compiled modules, keyed by the caller-supplied cache key.
/// Initialized lazily by `get_kernel_cache`. Each module sits behind its own `Mutex`,
/// so locking a module for a launch does not hold the map lock.
#[cfg(feature = "cuda")]
static KERNEL_CACHE: OnceLock<Mutex<HashMap<String, Arc<Mutex<CudaModule>>>>> = OnceLock::new();
/// Number of `get_or_compile_kernel` calls that found their key already cached.
/// Updated and read with `Ordering::Relaxed`; in this file it only feeds statistics.
#[cfg(feature = "cuda")]
static KERNEL_CACHE_HITS: AtomicU64 = AtomicU64::new(0);
/// Number of `get_or_compile_kernel` calls that missed the cache. It is incremented
/// before compilation is attempted, so failed compiles are counted too.
#[cfg(feature = "cuda")]
static KERNEL_CACHE_MISSES: AtomicU64 = AtomicU64::new(0);
/// Returns the process-wide kernel cache, creating an empty map on first use.
///
/// Initialization goes through `OnceLock::get_or_init`, so every caller observes
/// the same `Mutex`, no matter which thread gets here first.
#[cfg(feature = "cuda")]
pub(crate) fn get_kernel_cache() -> &'static Mutex<HashMap<String, Arc<Mutex<CudaModule>>>> {
    KERNEL_CACHE.get_or_init(Mutex::default)
}
/// Acquires the kernel-cache map lock.
///
/// # Errors
///
/// Returns `GpuError::KernelLaunch` when the mutex is poisoned, i.e. a previous
/// holder panicked while the lock was held.
#[cfg(feature = "cuda")]
fn lock_cache(
    cache: &Mutex<HashMap<String, Arc<Mutex<CudaModule>>>>,
) -> Result<std::sync::MutexGuard<'_, HashMap<String, Arc<Mutex<CudaModule>>>>> {
    match cache.lock() {
        Ok(guard) => Ok(guard),
        Err(e) => Err(crate::GpuError::KernelLaunch(format!("Cache lock poisoned: {e}"))),
    }
}
/// Acquires the lock guarding a single cached module.
///
/// # Errors
///
/// Returns `GpuError::KernelLaunch` when the module's mutex is poisoned.
#[cfg(feature = "cuda")]
fn lock_module(module: &Mutex<CudaModule>) -> Result<std::sync::MutexGuard<'_, CudaModule>> {
    match module.lock() {
        Ok(guard) => Ok(guard),
        Err(e) => Err(crate::GpuError::KernelLaunch(format!("Module lock poisoned: {e}"))),
    }
}
/// Returns the cached module for `key`, compiling `ptx` into a new module on a miss.
///
/// `key` alone identifies the module. A hit returns the cached module without
/// looking at `ptx`, so callers must use distinct keys for distinct PTX sources.
///
/// The cache lock is held only for the lookup and for the insertion, never during
/// compilation. If two threads miss on the same key at the same time, both compile
/// and both count a miss. The first module inserted wins and is returned to both,
/// so every caller shares a single module per key.
///
/// # Errors
///
/// Returns an error if the cache lock is poisoned or if `CudaModule::from_ptx`
/// fails. A failed compile still counts as a miss.
#[cfg(feature = "cuda")]
pub(crate) fn get_or_compile_kernel(
    ctx: &CudaContext,
    key: &str,
    ptx: &str,
) -> Result<Arc<Mutex<CudaModule>>> {
    let cache = get_kernel_cache();
    {
        let cache_guard = lock_cache(cache)?;
        if let Some(module) = cache_guard.get(key) {
            KERNEL_CACHE_HITS.fetch_add(1, Ordering::Relaxed);
            return Ok(Arc::clone(module));
        }
    }
    // Compile without holding the cache lock, so lookups of other keys are not
    // blocked behind a slow PTX compile.
    KERNEL_CACHE_MISSES.fetch_add(1, Ordering::Relaxed);
    eprintln!("[KERNEL-CACHE] Compiling: {key}");
    let compiled = Arc::new(Mutex::new(CudaModule::from_ptx(ctx, ptx)?));
    // Another thread may have inserted this key while we were compiling. Keep the
    // module that got there first instead of overwriting it, so all callers share
    // one module. If ours lost the race, it is dropped when `compiled` goes out of
    // scope, which happens after the cache lock below has been released.
    let shared = {
        let mut cache_guard = lock_cache(cache)?;
        let slot = cache_guard
            .entry(key.to_string())
            .or_insert_with(|| Arc::clone(&compiled));
        Arc::clone(slot)
    };
    Ok(shared)
}
/// Fetches the module for `cache_key` (compiling `ptx` on a cache miss), locks it,
/// and launches `kernel_name` from it on `stream`.
///
/// The module's mutex stays locked until this function returns, so two callers
/// cannot be inside `launch_kernel` with the same module at the same time.
///
/// # Errors
///
/// Propagates errors from `get_or_compile_kernel`, from locking the module
/// (poisoned mutex), and from `CudaStream::launch_kernel`.
#[cfg(feature = "cuda")]
pub(crate) fn compile_lock_launch(
    ctx: &CudaContext,
    stream: &CudaStream,
    cache_key: &str,
    ptx: &str,
    kernel_name: &str,
    config: &LaunchConfig,
    args: &mut [*mut std::ffi::c_void],
) -> Result<()> {
    let shared_module = get_or_compile_kernel(ctx, cache_key, ptx)?;
    let mut locked = lock_module(&shared_module)?;
    // SAFETY: NOTE(review) — this is a safe fn that forwards caller-supplied raw
    // pointers in `args` into an unsafe launch. Soundness depends on every caller
    // passing pointers that are valid for, and match, the kernel's parameter list;
    // nothing here enforces that. Consider making this an `unsafe fn`.
    // TODO: confirm the exact safety contract of `launch_kernel`.
    let launched = unsafe { stream.launch_kernel(&mut locked, kernel_name, config, args) };
    launched?;
    Ok(())
}
/// Number of cache hits recorded by `get_or_compile_kernel` since the last reset.
#[cfg(feature = "cuda")]
#[must_use]
pub fn kernel_cache_hits() -> u64 {
    AtomicU64::load(&KERNEL_CACHE_HITS, Ordering::Relaxed)
}
/// Number of cache misses (compile attempts, including failed ones) recorded by
/// `get_or_compile_kernel` since the last reset.
#[cfg(feature = "cuda")]
#[must_use]
pub fn kernel_cache_misses() -> u64 {
    AtomicU64::load(&KERNEL_CACHE_MISSES, Ordering::Relaxed)
}
/// Resets the hit and miss counters to zero.
///
/// The two counters are stored independently, hits first. A lookup running
/// concurrently may therefore be counted between the two stores.
#[cfg(feature = "cuda")]
pub fn reset_kernel_cache_stats() {
    for counter in [&KERNEL_CACHE_HITS, &KERNEL_CACHE_MISSES] {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Drops every cached module from the map and resets the hit/miss counters.
///
/// This is best-effort for the map. If the cache was never initialized, or its
/// lock is poisoned, the map is left alone, but the counters are still reset.
/// Modules already handed out as `Arc`s stay alive until their holders drop them.
#[cfg(feature = "cuda")]
pub fn clear_kernel_cache() {
    let locked = KERNEL_CACHE.get().and_then(|cache| lock_cache(cache).ok());
    if let Some(mut map) = locked {
        map.clear();
    }
    reset_kernel_cache_stats();
}
/// Cache-hit counter for builds without the `cuda` feature: nothing is cached,
/// so this is always `0`.
#[cfg(not(feature = "cuda"))]
#[must_use]
pub fn kernel_cache_hits() -> u64 {
    0_u64
}
/// Cache-miss counter for builds without the `cuda` feature: nothing is compiled,
/// so this is always `0`.
#[cfg(not(feature = "cuda"))]
#[must_use]
pub fn kernel_cache_misses() -> u64 {
    0_u64
}
/// Statistics reset for builds without the `cuda` feature; a no-op.
#[cfg(not(feature = "cuda"))]
pub fn reset_kernel_cache_stats() {
    // Intentionally empty: the non-CUDA build keeps no statistics.
}
/// Cache clear for builds without the `cuda` feature; a no-op.
#[cfg(not(feature = "cuda"))]
pub fn clear_kernel_cache() {
    // Intentionally empty: the non-CUDA build has no cache to clear.
}
/// Serializes tests that reset or assert on the process-wide kernel-cache
/// statistics.
///
/// The default test harness runs tests on parallel threads, and the counters and
/// cache are globals. Without this lock, one test's reset or counter increments
/// can land between another test's reset and its assertions. A panicking test
/// poisons the lock; the poison is ignored so the remaining tests still run.
#[cfg(test)]
fn stats_test_lock() -> std::sync::MutexGuard<'static, ()> {
    static LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
    LOCK.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Resets the statistics and checks both counters read back as zero.
    /// Callers must hold `stats_test_lock()`.
    fn assert_clean_stats() {
        reset_kernel_cache_stats();
        assert_eq!(kernel_cache_hits(), 0);
        assert_eq!(kernel_cache_misses(), 0);
    }

    #[test]
    fn test_kernel_cache_stats_initial() {
        let _serial = stats_test_lock();
        assert_clean_stats();
    }

    #[test]
    fn test_clear_kernel_cache() {
        let _serial = stats_test_lock();
        clear_kernel_cache();
        assert_clean_stats();
    }

    #[test]
    fn test_idempotent_operations() {
        let _serial = stats_test_lock();
        for _ in 0..3 {
            reset_kernel_cache_stats();
            clear_kernel_cache();
        }
        assert_clean_stats();
    }
}

#[cfg(all(test, feature = "cuda"))]
mod cuda_tests {
    use super::*;

    /// Resets the statistics and checks both counters read back as zero.
    /// Callers must hold `stats_test_lock()`.
    fn assert_clean_stats() {
        reset_kernel_cache_stats();
        assert_eq!(kernel_cache_hits(), 0);
        assert_eq!(kernel_cache_misses(), 0);
    }

    /// Bumps the raw counters by `hits` / `misses` and checks the public getters
    /// report exactly those values. Callers must hold `stats_test_lock()`.
    fn assert_counter_round_trip(hits: u64, misses: u64) {
        assert_clean_stats();
        for _ in 0..hits {
            KERNEL_CACHE_HITS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
        for _ in 0..misses {
            KERNEL_CACHE_MISSES.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
        assert_eq!(kernel_cache_hits(), hits);
        assert_eq!(kernel_cache_misses(), misses);
        assert_clean_stats();
    }

    /// Clears the cache and checks the map is empty afterwards.
    fn clear_and_assert_empty() {
        clear_kernel_cache();
        let guard = lock_cache(get_kernel_cache()).expect("Cache lock should not be poisoned");
        assert!(guard.is_empty(), "Cache should be empty");
    }

    #[test]
    fn test_get_kernel_cache_static_and_reentrant() {
        // Only touches the cache map's mutex, never the statistics, so it does
        // not need `stats_test_lock()`.
        let cache1 = get_kernel_cache();
        let cache2 = get_kernel_cache();
        assert!(std::ptr::eq(cache1, cache2));
        for _ in 0..3 {
            let _guard = lock_cache(cache1).expect("lock");
        }
    }

    #[test]
    fn test_clear_kernel_cache_clears_hashmap() {
        let _serial = stats_test_lock();
        clear_and_assert_empty();
    }

    #[test]
    fn test_atomic_counter_operations() {
        let _serial = stats_test_lock();
        assert_counter_round_trip(5, 3);
        assert_counter_round_trip(100, 50);
    }

    #[test]
    fn test_clear_uninitialized_cache() {
        // NOTE(review): another test in the same process may already have
        // initialized the cache, so this only exercises clearing whatever state
        // exists, not a guaranteed-uninitialized cache.
        let _serial = stats_test_lock();
        clear_kernel_cache();
        assert_clean_stats();
    }
}