use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use cudarc::cublas::CudaBlas;
use cudarc::cublaslt::sys as cublaslt_sys;
use cudarc::cudnn::sys as cudnn_sys;
use cudarc::driver::CudaContext;
// Process-wide, lazily initialised CUDA resources. Each cell stores `None`
// when its initialiser failed, so a failed initialisation is cached and not
// retried on later calls.

/// Cached CUDA driver context, populated by `cuda_context()`.
static CTX: OnceLock<Option<Arc<CudaContext>>> = OnceLock::new();
/// Cached cuBLAS handle behind a mutex, populated by `cuda_blas()`.
static BLAS: OnceLock<Option<Arc<Mutex<CudaBlas>>>> = OnceLock::new();
/// Cached cuBLASLt handle, stored as an integer and cast back to
/// `cublasLtHandle_t` by `cuda_blas_lt_handle()`.
// NOTE(review): presumably kept as `usize` because the raw handle type is not
// `Send`/`Sync` and so could not live in a `static` directly; confirm.
static BLAS_LT_HANDLE: OnceLock<Option<usize>> = OnceLock::new();
/// Cached cuDNN handle, stored as an integer and cast back to
/// `cudnnHandle_t` by `cuda_dnn_handle()`.
// NOTE(review): same `usize` storage as `BLAS_LT_HANDLE`, presumably for the
// same reason; confirm.
static DNN_HANDLE: OnceLock<Option<usize>> = OnceLock::new();
/// Returns the shared CUDA context for device ordinal 0, creating it on the
/// first call.
///
/// The outcome of the first attempt is cached in `CTX`: if context creation
/// returns an error or panics, every later call yields `None` without
/// retrying.
///
/// While the context is being created the process-wide panic hook is replaced
/// with a no-op so that a panic raised by `CudaContext::new` prints nothing;
/// the previous hook is reinstalled right after. Because the hook is global,
/// panics on other threads during that window are silenced as well.
pub fn cuda_context() -> Option<Arc<CudaContext>> {
    let cached = CTX.get_or_init(|| {
        // Swap in a silent hook for the duration of the attempt.
        let saved_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {}));

        let attempt = std::panic::catch_unwind(AssertUnwindSafe(|| CudaContext::new(0)));

        // Put the caller's hook back before inspecting the outcome.
        std::panic::set_hook(saved_hook);

        // A caught panic (outer `Err`) and a driver error (inner `Err`) both
        // collapse to `None`.
        attempt.ok().and_then(Result::ok)
    });
    cached.clone()
}
/// Returns the shared cuBLAS handle, creating it on the first call.
///
/// The handle is built on the default stream of the context returned by
/// `cuda_context()` and wrapped in `Arc<Mutex<_>>`. Yields `None` when no
/// context is available or `CudaBlas::new` fails; that outcome is cached in
/// `BLAS` and not retried.
pub fn cuda_blas() -> Option<Arc<Mutex<CudaBlas>>> {
    let cached = BLAS.get_or_init(|| {
        let context = cuda_context()?;
        match CudaBlas::new(context.default_stream()) {
            Ok(blas) => Some(Arc::new(Mutex::new(blas))),
            Err(_) => None,
        }
    });
    cached.clone()
}
/// Returns the shared raw cuBLASLt handle, creating it on the first call.
///
/// A CUDA context is obtained via `cuda_context()` before the handle is
/// created. Yields `None` when no context is available or handle creation
/// fails; that outcome is cached in `BLAS_LT_HANDLE` and not retried.
pub fn cuda_blas_lt_handle() -> Option<cublaslt_sys::cublasLtHandle_t> {
    let stored = BLAS_LT_HANDLE.get_or_init(|| {
        // Held only so that a context exists while the handle is created.
        let _context = cuda_context()?;
        match cudarc::cublaslt::result::create_handle() {
            // Stored as an integer; converted back to the handle type below.
            Ok(raw) => Some(raw as usize),
            Err(_) => None,
        }
    });
    let addr = (*stored)?;
    Some(addr as cublaslt_sys::cublasLtHandle_t)
}
/// Returns the shared raw cuDNN handle, creating it on the first call.
///
/// The handle is bound to the default stream of the context returned by
/// `cuda_context()`. Yields `None` when no context is available, when handle
/// creation or stream binding fails, or when any of those steps panics. The
/// outcome is cached in `DNN_HANDLE` and not retried.
///
/// While the handle is being set up the process-wide panic hook is replaced
/// with a no-op so that a panic inside the cuDNN calls prints nothing; the
/// previous hook is reinstalled right after. Because the hook is global,
/// panics on other threads during that window are silenced as well.
pub fn cuda_dnn_handle() -> Option<cudnn_sys::cudnnHandle_t> {
    DNN_HANDLE
        .get_or_init(|| {
            let ctx = cuda_context()?;

            // Swap in a silent hook for the duration of the attempt.
            let prev = std::panic::take_hook();
            std::panic::set_hook(Box::new(|_| {}));

            let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
                let handle = cudarc::cudnn::result::create_handle().ok()?;
                let stream = ctx.default_stream();

                // SAFETY: `handle` was just returned by `create_handle` and has
                // not been destroyed; `stream` is the default stream of the
                // live context `ctx`, which outlives this call.
                let bound = unsafe {
                    cudarc::cudnn::result::set_stream(
                        handle,
                        stream.cu_stream() as cudnn_sys::cudaStream_t,
                    )
                };

                if bound.is_err() {
                    // The handle never escapes this closure on failure, so it
                    // must be released here or it leaks. A destroy error is
                    // ignored: there is nothing further to do with it.
                    // SAFETY: `handle` is valid, unshared, and destroyed
                    // exactly once.
                    let _ = unsafe { cudarc::cudnn::result::destroy_handle(handle) };
                    return None;
                }

                // Stored as an integer; converted back to the handle type below.
                Some(handle as usize)
            }));

            // Put the caller's hook back before inspecting the outcome.
            std::panic::set_hook(prev);

            // A caught panic (outer `Err`) collapses to `None`.
            result.ok().flatten()
        })
        .map(|h| h as cudnn_sys::cudnnHandle_t)
}