use super::client::CudaClient;
use super::device::CudaDevice;
use std::collections::HashMap;
use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};
/// Process-wide cache of CUDA clients keyed by device index.
///
/// Lazily initialized by `get_or_create_client`; the lookup helpers in this module only read it.
/// Within this module entries are inserted but never removed.
static CLIENT_CACHE: OnceLock<Mutex<HashMap<usize, CudaClient>>> = OnceLock::new();
/// Reports whether the calling thread currently has a non-null CUDA context bound.
///
/// Returns `false` when the driver query fails for any reason (for example, if the driver
/// has not been initialized) or when no context is current.
///
/// # Safety
///
/// Callers must uphold the module's preconditions for invoking raw CUDA driver entry points.
/// NOTE(review): the body itself only passes a pointer to a stack local, so the exact reason
/// this is `unsafe` is not visible here — confirm and document the real contract.
#[inline]
pub(super) unsafe fn is_cuda_context_valid() -> bool {
    use cudarc::driver::sys;

    let mut current: sys::CUcontext = std::ptr::null_mut();
    // SAFETY: `current` is a live, writable out-pointer for the whole duration of the call.
    let status = unsafe { sys::cuCtxGetCurrent(&mut current) };
    if status != sys::CUresult::CUDA_SUCCESS {
        return false;
    }
    !current.is_null()
}
#[inline]
fn lock_client_cache(
cache: &Mutex<HashMap<usize, CudaClient>>,
) -> MutexGuard<'_, HashMap<usize, CudaClient>> {
cache.lock().unwrap_or_else(PoisonError::into_inner)
}
/// Returns the cached [`CudaClient`] for `device`, creating and caching one on first use.
///
/// The cache lock is held while a new client is constructed. Concurrent callers therefore
/// never create two clients for the same device, and client creation is serialized across
/// all devices.
/// NOTE(review): if `CudaClient::new` ever calls back into this cache, it would deadlock —
/// confirm it does not.
///
/// # Panics
///
/// Panics if `CudaClient::new` fails. The message names the device index and includes the
/// error's `Debug` form. The panic happens while the lock is held, which poisons the mutex;
/// later callers recover from that via `lock_client_cache`.
pub(super) fn get_or_create_client(device: &CudaDevice) -> CudaClient {
    let cache = CLIENT_CACHE.get_or_init(|| Mutex::new(HashMap::new()));
    let mut clients = lock_client_cache(cache);
    // One hash lookup: `entry` yields the cached client, or inserts a freshly built one.
    clients
        .entry(device.index)
        .or_insert_with(|| {
            CudaClient::new(device.clone()).unwrap_or_else(|err| {
                panic!(
                    "Failed to create CUDA client for device {}: {:?}",
                    device.index, err
                )
            })
        })
        .clone()
}
/// Returns a clone of the already-cached client for `device_index`, without creating one.
///
/// Yields `None` when the cache has never been initialized or holds no client for that index.
#[inline]
pub(super) fn try_get_cached_client(device_index: usize) -> Option<CudaClient> {
    CLIENT_CACHE.get().and_then(|cache| {
        let clients = lock_client_cache(cache);
        clients.get(&device_index).cloned()
    })
}
/// Returns the raw driver stream of the already-cached client for `device_index`.
///
/// Yields `None` when the cache is uninitialized or has no client for that index. It never
/// creates a client.
/// NOTE(review): the raw handle is returned after the cache lock is released. It presumably
/// stays valid only while the owning client is alive. No code in this module removes cache
/// entries, but confirm nothing elsewhere does.
#[inline]
pub(super) fn try_get_cached_stream(device_index: usize) -> Option<cudarc::driver::sys::CUstream> {
    let clients = lock_client_cache(CLIENT_CACHE.get()?);
    let client = clients.get(&device_index)?;
    Some(client.stream.cu_stream())
}
/// Prints a diagnostic line to stderr describing a failed CUDA memory operation.
///
/// Marked `#[cold]` and `#[inline(never)]` so this error path stays out of callers' hot code.
///
/// * `operation` - name of the operation that failed, printed verbatim.
/// * `ptr` - address involved (presumably a device pointer), printed in hexadecimal.
/// * `result` - driver status code, printed with its `Debug` representation.
#[cold]
#[inline(never)]
pub(super) fn log_cuda_memory_error(
    operation: &str,
    ptr: u64,
    result: cudarc::driver::sys::CUresult,
) {
    eprintln!("[numr::cuda] {operation} failed for ptr 0x{ptr:x}: {result:?}");
}