use cudarc::driver::{
CudaContext, CudaStream,
sys::{CUcontext, CUstream, cuCtxPopCurrent_v2, cuCtxPushCurrent_v2, cudaError_enum},
};
use std::pin::Pin;
use std::{marker::PhantomData, sync::Arc};
/// Abstraction over anything that can hand out a raw CUDA driver context handle.
///
/// Implemented in this module for cudarc's [`CudaContext`] and [`CudaStream`], and for
/// [`ExternalCudaContext`] (a handle supplied by the caller).
pub trait DynamoCudaContextProvider {
    /// Returns the raw `CUcontext` handle.
    ///
    /// # Safety
    ///
    /// The returned value is a raw driver handle; callers must not use it after the underlying
    /// context has been destroyed. NOTE(review): the full contract implementors expect of
    /// callers is not spelled out in this file — confirm and document it here.
    unsafe fn cu_context(&self) -> cudarc::driver::sys::CUcontext;

    /// Pushes this provider's context onto the calling thread's CUDA context stack and returns
    /// a guard that pops it again when dropped.
    ///
    /// Keep the guard alive for as long as the context should stay current: dropping it
    /// (including via `let _ = ...`) pops the context immediately.
    ///
    /// # Panics
    ///
    /// Panics if pushing the context fails (see [`DynamoCudaContextGuard::new`]).
    #[must_use = "the CUDA context is popped as soon as the returned guard is dropped"]
    fn bind_to_thread(&self) -> Pin<Box<DynamoCudaContextGuard>> {
        // SAFETY: relies on the implementor's guarantee that `cu_context` returns a live
        // context handle. NOTE(review): the guard may outlive `self`; the caller must keep the
        // context alive while the guard exists.
        unsafe { DynamoCudaContextGuard::new(self.cu_context()) }
    }
}
pub trait DynamoCudaStreamProvider {
unsafe fn cu_stream(&self) -> cudarc::driver::sys::CUstream;
fn context(&self) -> Arc<dyn DynamoCudaContextProvider>;
}
pub struct DynamoCudaContextGuard {
context: cudarc::driver::sys::CUcontext,
_pin: std::marker::PhantomPinned,
_not_send_sync: PhantomData<*const ()>,
}
impl DynamoCudaContextGuard {
    /// Pushes `context` onto the calling thread's CUDA context stack and returns a guard that
    /// pops it when dropped.
    ///
    /// # Safety
    ///
    /// `context` must be a valid CUDA context handle that stays alive for as long as the
    /// returned guard exists.
    ///
    /// # Panics
    ///
    /// Panics if `cuCtxPushCurrent_v2` does not return `CUDA_SUCCESS`.
    #[must_use = "the CUDA context is popped as soon as the returned guard is dropped"]
    pub unsafe fn new(context: CUcontext) -> Pin<Box<Self>> {
        // SAFETY: the caller guarantees `context` is a valid, live context handle.
        let result = unsafe { cuCtxPushCurrent_v2(context) };
        if result != cudaError_enum::CUDA_SUCCESS {
            panic!("Failed to push CUDA context: {:?}", result);
        }
        Box::pin(Self {
            context,
            _pin: std::marker::PhantomPinned,
            _not_send_sync: PhantomData,
        })
    }

    /// Returns the context handle this guard pushed.
    pub fn context(&self) -> cudarc::driver::sys::CUcontext {
        self.context
    }
}
impl Drop for DynamoCudaContextGuard {
    /// Pops the calling thread's current CUDA context and checks that it is the one this guard
    /// pushed.
    ///
    /// Problems are reported on stderr rather than by panicking (a panic inside `drop` while
    /// already unwinding aborts the process).
    fn drop(&mut self) {
        let mut popped_context: CUcontext = std::ptr::null_mut();
        // SAFETY: `popped_context` is a valid, writable out-pointer for the duration of the call.
        let result = unsafe { cuCtxPopCurrent_v2(&mut popped_context) };
        if result != cudaError_enum::CUDA_SUCCESS {
            eprintln!("Warning: Failed to pop CUDA context in drop: {:?}", result);
            // On failure `popped_context` cannot be trusted (most likely it is still the null
            // it was initialised with), so comparing it against `self.context` would only emit
            // a second, misleading "does not match" warning.
            return;
        }
        if popped_context != self.context {
            eprintln!(
                "Warning: Popped context {:?} does not match expected context {:?}",
                popped_context, self.context
            );
        }
    }
}
/// Wraps a raw `CUcontext` handle supplied by the caller.
///
/// The wrapper never creates or destroys the context (there is no `Drop` impl), so whoever
/// supplied the handle must keep the context alive while any wrapper referring to it exists.
pub struct ExternalCudaContext {
    context: CUcontext,
}

// SAFETY: the wrapper only stores and returns the raw handle; it never dereferences it.
// NOTE(review): soundness also relies on the CUDA driver allowing this context handle to be
// made current from any thread — confirm for the contexts callers pass in.
unsafe impl Send for ExternalCudaContext {}
// SAFETY: every access goes through `&self` and copies a plain handle value; there is no
// interior mutation.
unsafe impl Sync for ExternalCudaContext {}

impl ExternalCudaContext {
    /// Wraps `context` and returns it behind an `Arc` so it can be shared, e.g. as an
    /// `Arc<dyn DynamoCudaContextProvider>`.
    pub fn new(context: CUcontext) -> Arc<Self> {
        let wrapper = Self { context };
        Arc::new(wrapper)
    }

    /// Returns the wrapped raw context handle.
    pub fn cu_context(&self) -> CUcontext {
        let Self { context } = self;
        *context
    }
}
impl DynamoCudaContextProvider for ExternalCudaContext {
    /// Returns the wrapped handle.
    ///
    /// Reads the field directly rather than calling the identically named inherent
    /// `cu_context`, so correctness does not hinge on inherent-over-trait method resolution.
    unsafe fn cu_context(&self) -> cudarc::driver::sys::CUcontext {
        self.context
    }
}
/// A raw `CUstream` handle supplied by the caller, paired with a provider for its context.
///
/// The wrapper never creates or destroys the stream (there is no `Drop` impl).
pub struct ExternalCudaStream {
    stream: CUstream,
    context: Arc<dyn DynamoCudaContextProvider>,
}

impl ExternalCudaStream {
    /// Pairs `stream` with the provider of its associated context.
    pub fn new(stream: CUstream, context: Arc<dyn DynamoCudaContextProvider>) -> Self {
        ExternalCudaStream {
            context,
            stream,
        }
    }
}
impl DynamoCudaStreamProvider for ExternalCudaStream {
    /// Returns the wrapped stream handle.
    unsafe fn cu_stream(&self) -> cudarc::driver::sys::CUstream {
        let Self { stream, .. } = self;
        *stream
    }

    /// Returns another reference-counted handle to the stored context provider.
    fn context(&self) -> Arc<dyn DynamoCudaContextProvider> {
        Arc::clone(&self.context)
    }
}
impl DynamoCudaContextProvider for CudaContext {
unsafe fn cu_context(&self) -> cudarc::driver::sys::CUcontext {
self.cu_ctx()
}
}
impl DynamoCudaContextProvider for CudaStream {
unsafe fn cu_context(&self) -> cudarc::driver::sys::CUcontext {
self.context().cu_context()
}
}
impl DynamoCudaStreamProvider for CudaStream {
unsafe fn cu_stream(&self) -> cudarc::driver::sys::CUstream {
self.cu_stream()
}
fn context(&self) -> Arc<dyn DynamoCudaContextProvider> {
self.context().clone()
}
}