#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(feature = "cuda")]
use crate::driver::{CudaContext, CudaModule};
#[cfg(feature = "cuda")]
use crate::error::Result;
#[cfg(feature = "cuda")]
static KERNEL_CACHE: OnceLock<Mutex<HashMap<String, Arc<Mutex<CudaModule>>>>> = OnceLock::new();
#[cfg(feature = "cuda")]
static KERNEL_CACHE_HITS: AtomicU64 = AtomicU64::new(0);
#[cfg(feature = "cuda")]
static KERNEL_CACHE_MISSES: AtomicU64 = AtomicU64::new(0);
/// Returns the process-wide kernel module cache, creating it on first use.
///
/// Subsequent calls always return the same static map.
#[cfg(feature = "cuda")]
pub(crate) fn get_kernel_cache() -> &'static Mutex<HashMap<String, Arc<Mutex<CudaModule>>>> {
    // Lazily initialize an empty map; `OnceLock` guarantees exactly one init.
    KERNEL_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}
/// Returns the cached compiled module for `key`, compiling `ptx` on a miss.
///
/// Compilation happens outside the cache lock so unrelated keys are not
/// blocked behind a slow compile.
///
/// # Errors
/// Returns `GpuError::KernelLaunch` if the cache mutex is poisoned, and
/// propagates any error from `CudaModule::from_ptx`.
#[cfg(feature = "cuda")]
pub(crate) fn get_or_compile_kernel(
    ctx: &CudaContext,
    key: &str,
    ptx: &str,
) -> Result<Arc<Mutex<CudaModule>>> {
    let cache = get_kernel_cache();

    // Fast path: the module is already cached.
    {
        let cache_guard = cache.lock().map_err(|e| {
            crate::GpuError::KernelLaunch(format!("Cache lock poisoned: {}", e))
        })?;
        if let Some(module) = cache_guard.get(key) {
            KERNEL_CACHE_HITS.fetch_add(1, Ordering::Relaxed);
            return Ok(Arc::clone(module));
        }
    }

    // Slow path: compile without holding the lock.
    KERNEL_CACHE_MISSES.fetch_add(1, Ordering::Relaxed);
    eprintln!("[KERNEL-CACHE] Compiling: {}", key);
    let module = CudaModule::from_ptx(ctx, ptx)?;
    let module_arc = Arc::new(Mutex::new(module));

    // Re-check under the lock: another thread may have compiled and inserted
    // the same key while we were compiling. Keep the first-inserted module
    // (instead of blindly overwriting it) so every caller shares a single
    // CudaModule per key; our duplicate is simply dropped in that case.
    let mut cache_guard = cache.lock().map_err(|e| {
        crate::GpuError::KernelLaunch(format!("Cache lock poisoned: {}", e))
    })?;
    let entry = cache_guard
        .entry(key.to_string())
        .or_insert_with(|| Arc::clone(&module_arc));
    Ok(Arc::clone(entry))
}
/// Number of kernel-cache hits recorded since the last stats reset.
#[cfg(feature = "cuda")]
#[must_use]
pub fn kernel_cache_hits() -> u64 {
    // Relaxed ordering is sufficient: this is a monotonic statistic,
    // not a synchronization point.
    KERNEL_CACHE_HITS.load(std::sync::atomic::Ordering::Relaxed)
}
/// Number of kernel-cache misses (compilations) recorded since the last
/// stats reset.
#[cfg(feature = "cuda")]
#[must_use]
pub fn kernel_cache_misses() -> u64 {
    // Relaxed ordering is sufficient for a pure statistics read.
    KERNEL_CACHE_MISSES.load(std::sync::atomic::Ordering::Relaxed)
}
/// Zeros both hit and miss counters, e.g. between measurement runs.
#[cfg(feature = "cuda")]
pub fn reset_kernel_cache_stats() {
    for counter in [&KERNEL_CACHE_HITS, &KERNEL_CACHE_MISSES] {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Drops every cached module and resets the hit/miss statistics.
#[cfg(feature = "cuda")]
pub fn clear_kernel_cache() {
    if let Some(cache) = KERNEL_CACHE.get() {
        // Recover from a poisoned lock instead of silently skipping the
        // clear: a panic in another thread while holding the lock does not
        // make the HashMap itself invalid, and skipping would leave stale
        // modules behind while the stats below still get reset.
        let mut guard = cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.clear();
    }
    reset_kernel_cache_stats();
}
/// Number of kernel-cache hits.
///
/// Always `0` when the `cuda` feature is disabled: no cache exists.
#[cfg(not(feature = "cuda"))]
#[must_use]
pub fn kernel_cache_hits() -> u64 {
    0
}
/// Number of kernel-cache misses.
///
/// Always `0` when the `cuda` feature is disabled: no cache exists.
#[cfg(not(feature = "cuda"))]
#[must_use]
pub fn kernel_cache_misses() -> u64 {
    0
}
/// No-op without the `cuda` feature: there are no counters to reset.
#[cfg(not(feature = "cuda"))]
pub fn reset_kernel_cache_stats() {}
/// No-op without the `cuda` feature: there is no cache to clear.
#[cfg(not(feature = "cuda"))]
pub fn clear_kernel_cache() {}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests run with or without the `cuda` feature: the stub
    // implementations and the real ones both report zero after a reset.

    #[test]
    fn test_kernel_cache_stats_initial() {
        reset_kernel_cache_stats();
        assert_eq!(kernel_cache_hits(), 0);
        assert_eq!(kernel_cache_misses(), 0);
    }

    #[test]
    fn test_clear_kernel_cache() {
        // Clearing also resets the stats.
        clear_kernel_cache();
        assert_eq!(kernel_cache_hits(), 0);
        assert_eq!(kernel_cache_misses(), 0);
    }

    #[test]
    fn test_kernel_cache_hits_after_reset() {
        reset_kernel_cache_stats();
        assert_eq!(kernel_cache_hits(), 0);
    }

    #[test]
    fn test_kernel_cache_misses_after_reset() {
        reset_kernel_cache_stats();
        assert_eq!(kernel_cache_misses(), 0);
    }

    #[test]
    fn test_reset_kernel_cache_stats_idempotent() {
        // Repeated resets must leave the counters at zero.
        for _ in 0..3 {
            reset_kernel_cache_stats();
        }
        assert_eq!(kernel_cache_hits(), 0);
        assert_eq!(kernel_cache_misses(), 0);
    }

    #[test]
    fn test_clear_kernel_cache_idempotent() {
        // Repeated clears must be safe and leave the counters at zero.
        for _ in 0..3 {
            clear_kernel_cache();
        }
        assert_eq!(kernel_cache_hits(), 0);
        assert_eq!(kernel_cache_misses(), 0);
    }

    #[test]
    fn test_stats_consistency() {
        // Reset and clear agree: both leave hits and misses at zero.
        reset_kernel_cache_stats();
        assert_eq!(kernel_cache_hits(), 0);
        assert_eq!(kernel_cache_misses(), 0);
        clear_kernel_cache();
        assert_eq!(kernel_cache_hits(), 0);
        assert_eq!(kernel_cache_misses(), 0);
    }
}
// Tests that require the `cuda` feature to be enabled. They exercise the
// cache map and the atomic counters directly; none of them compiles a PTX
// module, so they don't appear to need a physical GPU — TODO confirm that
// linking with the feature enabled doesn't itself require a CUDA runtime.
#[cfg(all(test, feature = "cuda"))]
mod cuda_tests {
use super::*;
#[test]
fn test_get_kernel_cache_returns_valid_cache() {
// Acquiring and dropping the lock proves the static initialized cleanly.
let cache = get_kernel_cache();
let guard = cache.lock().expect("Cache lock should not be poisoned");
drop(guard);
}
#[test]
fn test_get_kernel_cache_is_static() {
// Both calls must return a reference to the same OnceLock-backed value.
let cache1 = get_kernel_cache();
let cache2 = get_kernel_cache();
assert!(std::ptr::eq(cache1, cache2));
}
#[test]
fn test_cache_lock_reentrant() {
// Not true reentrancy: each guard is dropped at the end of its scope
// before the next lock, so this checks sequential re-acquisition works.
let cache = get_kernel_cache();
{
let _guard1 = cache.lock().expect("First lock should succeed");
}
{
let _guard2 = cache.lock().expect("Second lock should succeed");
}
{
let _guard3 = cache.lock().expect("Third lock should succeed");
}
}
#[test]
fn test_clear_kernel_cache_clears_hashmap() {
// After clearing, the underlying map must contain no modules.
clear_kernel_cache();
let cache = get_kernel_cache();
let guard = cache.lock().expect("Lock should succeed");
assert!(guard.is_empty(), "Cache should be empty after clear");
}
#[test]
fn test_atomic_counter_operations() {
// Drive the counters directly and confirm the public getters and
// reset observe the same values.
reset_kernel_cache_stats();
KERNEL_CACHE_HITS.fetch_add(5, std::sync::atomic::Ordering::Relaxed);
KERNEL_CACHE_MISSES.fetch_add(3, std::sync::atomic::Ordering::Relaxed);
assert_eq!(kernel_cache_hits(), 5);
assert_eq!(kernel_cache_misses(), 3);
reset_kernel_cache_stats();
assert_eq!(kernel_cache_hits(), 0);
assert_eq!(kernel_cache_misses(), 0);
}
#[test]
fn test_counter_concurrent_increments() {
// NOTE(review): despite the name, these increments run on one thread;
// this verifies accumulation counts, not cross-thread behavior.
reset_kernel_cache_stats();
for _ in 0..100 {
KERNEL_CACHE_HITS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
for _ in 0..50 {
KERNEL_CACHE_MISSES.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
assert_eq!(kernel_cache_hits(), 100);
assert_eq!(kernel_cache_misses(), 50);
reset_kernel_cache_stats();
}
#[test]
fn test_clear_uninitialized_cache() {
// clear_kernel_cache must be safe even if the OnceLock was never
// initialized (it uses KERNEL_CACHE.get(), not get_or_init).
clear_kernel_cache();
assert_eq!(kernel_cache_hits(), 0);
assert_eq!(kernel_cache_misses(), 0);
}
#[test]
fn test_cache_hashmap_operations() {
// A freshly cleared cache exposes an empty map under the lock.
clear_kernel_cache();
let cache = get_kernel_cache();
{
let guard = cache.lock().expect("Lock should succeed");
assert!(guard.is_empty());
}
}
}