use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use cudarc::cublas::CudaBlas;
use cudarc::cublaslt::sys as cublaslt_sys;
use cudarc::cudnn::sys as cudnn_sys;
use cudarc::driver::{CudaContext, CudaSlice};
/// Cached CUDA context for device 0. The first `cuda_context` call fills it; a stored
/// `None` records that creation failed and is returned as-is on later calls.
static CTX: OnceLock<Option<Arc<CudaContext>>> = OnceLock::new();
/// Cached cuBLAS handle (behind a `Mutex`, so users are serialised); `None` on failure.
static BLAS: OnceLock<Option<Arc<Mutex<CudaBlas>>>> = OnceLock::new();
/// Cached cuBLASLt handle, stored as a `usize` and cast back to the handle type by
/// `cuda_blas_lt_handle`. NOTE(review): presumably stored as an integer because the
/// handle type is a raw pointer, which is neither `Send` nor `Sync` and so could not
/// live in a `static OnceLock` directly.
static BLAS_LT_HANDLE: OnceLock<Option<usize>> = OnceLock::new();
/// Cached cuDNN handle, stored as a `usize` for the same reason as `BLAS_LT_HANDLE`
/// and cast back by `cuda_dnn_handle`.
static DNN_HANDLE: OnceLock<Option<usize>> = OnceLock::new();
/// Returns the process-wide CUDA context for device 0, or `None` if it cannot be created.
///
/// Creation is attempted at most once; its outcome, including failure, is cached in
/// `CTX`, and every call returns a clone of the cached value.
///
/// Both an `Err` from `CudaContext::new` and a panic raised inside it are mapped to
/// `None`. While the attempt runs, the global panic hook is swapped for a no-op so
/// nothing is printed.
///
/// NOTE(review): the panic hook is process-global, so a panic on another thread
/// during this window is silenced as well.
pub fn cuda_context() -> Option<Arc<CudaContext>> {
    let cached = CTX.get_or_init(|| {
        let saved_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {}));
        let attempt = std::panic::catch_unwind(AssertUnwindSafe(|| CudaContext::new(0)));
        // Restore the caller's hook before inspecting the outcome.
        std::panic::set_hook(saved_hook);
        // Outer `Err` = caught panic, inner `Err` = driver error; both become `None`.
        attempt.ok().and_then(Result::ok)
    });
    cached.clone()
}
pub fn cuda_blas() -> Option<Arc<Mutex<CudaBlas>>> {
BLAS.get_or_init(|| {
let ctx = cuda_context()?;
let stream = ctx.default_stream();
CudaBlas::new(stream).ok().map(|b| Arc::new(Mutex::new(b)))
})
.clone()
}
pub fn cuda_blas_lt_handle() -> Option<cublaslt_sys::cublasLtHandle_t> {
BLAS_LT_HANDLE
.get_or_init(|| {
let _ctx = cuda_context()?;
let handle = cudarc::cublaslt::result::create_handle().ok()?;
Some(handle as usize)
})
.map(|h| h as cublaslt_sys::cublasLtHandle_t)
}
/// Returns the process-wide cuDNN handle, bound to the default stream of
/// [`cuda_context`]. Returns `None` if CUDA or cuDNN is unavailable.
///
/// The handle is created once and cached in `DNN_HANDLE`. Failures are cached as
/// `None`; this includes a panic during creation, which is caught silently. If
/// binding the new handle to the stream fails, the handle is destroyed rather than
/// leaked.
///
/// NOTE(review): the raw handle is returned without any lock. Callers must
/// coordinate concurrent use themselves — confirm cuDNN's thread-safety rules.
pub fn cuda_dnn_handle() -> Option<cudnn_sys::cudnnHandle_t> {
    DNN_HANDLE
        .get_or_init(|| {
            let ctx = cuda_context()?;
            // The context may have been created on another thread; make it current here.
            ctx.bind_to_thread().ok()?;
            let stream = ctx.default_stream();
            let cu_stream = stream.cu_stream() as cudnn_sys::cudaStream_t;
            let prev = std::panic::take_hook();
            std::panic::set_hook(Box::new(|_| {}));
            let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
                let handle = cudarc::cudnn::result::create_handle().ok()?;
                // SAFETY: `handle` was just created and is not shared yet. `cu_stream`
                // comes from `stream`, which stays alive for this whole closure.
                match unsafe { cudarc::cudnn::result::set_stream(handle, cu_stream) } {
                    Ok(_) => Some(handle as usize),
                    Err(_) => {
                        // Previously the handle leaked on this path. Release it, on a
                        // best-effort basis: we are already reporting failure.
                        // SAFETY: `handle` is valid, unshared, and never used again.
                        let _ = unsafe { cudarc::cudnn::result::destroy_handle(handle) };
                        None
                    }
                }
            }));
            std::panic::set_hook(prev);
            result.ok().flatten()
        })
        .map(|h| h as cudnn_sys::cudnnHandle_t)
}
/// Size in bytes of the shared cuBLASLt workspace buffer (4 MiB).
pub const CUBLASLT_WORKSPACE_BYTES: usize = 4 << 20;
/// Size in bytes of the shared cuDNN workspace buffer (32 MiB).
pub const CUDNN_WORKSPACE_BYTES: usize = 32 << 20;
/// Cached cuBLASLt scratch buffer, allocated on first use by `cuda_blas_lt_workspace`.
/// The `Mutex` serialises users; `None` records a failed initialisation.
static BLAS_LT_WORKSPACE: OnceLock<Option<Arc<Mutex<CudaSlice<u8>>>>> = OnceLock::new();
/// Cached cuDNN scratch buffer, allocated on first use by `cuda_dnn_workspace`.
/// The `Mutex` serialises users; `None` records a failed initialisation.
static DNN_WORKSPACE: OnceLock<Option<Arc<Mutex<CudaSlice<u8>>>>> = OnceLock::new();
/// Returns the shared cuBLASLt workspace: a zero-filled buffer of
/// `CUBLASLT_WORKSPACE_BYTES` bytes allocated on the default stream.
///
/// Returns `None` if the cuBLASLt handle is unavailable or the allocation fails. The
/// outcome is computed once and cached; every call returns a clone of the same `Arc`.
pub fn cuda_blas_lt_workspace() -> Option<Arc<Mutex<CudaSlice<u8>>>> {
    let slot = BLAS_LT_WORKSPACE.get_or_init(|| {
        // Only allocate when cuBLASLt itself is usable.
        if cuda_blas_lt_handle().is_none() {
            return None;
        }
        let context = cuda_context()?;
        let stream = context.default_stream();
        match stream.alloc_zeros::<u8>(CUBLASLT_WORKSPACE_BYTES) {
            Ok(buffer) => Some(Arc::new(Mutex::new(buffer))),
            Err(_) => None,
        }
    });
    slot.clone()
}
/// Returns the shared cuDNN workspace: a zero-filled buffer of
/// `CUDNN_WORKSPACE_BYTES` bytes allocated on the default stream.
///
/// Returns `None` if the cuDNN handle is unavailable or the allocation fails. The
/// outcome is computed once and cached; every call returns a clone of the same `Arc`.
pub fn cuda_dnn_workspace() -> Option<Arc<Mutex<CudaSlice<u8>>>> {
    let slot = DNN_WORKSPACE.get_or_init(|| {
        // Only allocate when cuDNN itself is usable.
        if cuda_dnn_handle().is_none() {
            return None;
        }
        let context = cuda_context()?;
        let stream = context.default_stream();
        match stream.alloc_zeros::<u8>(CUDNN_WORKSPACE_BYTES) {
            Ok(buffer) => Some(Arc::new(Mutex::new(buffer))),
            Err(_) => None,
        }
    });
    slot.clone()
}
/// Returns the fixed label `"cuda-0"` when a CUDA context is available, else `None`.
///
/// This is a constant identifier for device 0, not the hardware product name.
pub fn device_name() -> Option<String> {
    if cuda_context().is_some() {
        Some(String::from("cuda-0"))
    } else {
        None
    }
}