use cudarc::cublas::CudaBlas;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::CudaRuntime;
use super::device::{CudaDevice, CudaError};
use crate::runtime::{Allocator, RuntimeClient};
#[inline]
unsafe fn is_cuda_context_valid() -> bool {
let mut ctx: cudarc::driver::sys::CUcontext = std::ptr::null_mut();
let result = unsafe { cudarc::driver::sys::cuCtxGetCurrent(&mut ctx) };
result == cudarc::driver::sys::CUresult::CUDA_SUCCESS && !ctx.is_null()
}
/// Per-device CUDA execution client: owns the context, the compute and copy
/// streams, the cuBLAS handle, and the caching device-memory allocator.
///
/// Cloning is cheap — every resource is shared behind an `Arc`.
#[derive(Clone)]
pub struct CudaClient {
    /// Logical device descriptor (carries the device index used at creation).
    pub(crate) device: CudaDevice,
    /// Driver context created for `device` and bound to the creating thread.
    pub(crate) context: Arc<CudaContext>,
    /// Compute stream: cuBLAS work and allocator operations are ordered here.
    pub(crate) stream: Arc<CudaStream>,
    /// Separate stream used for copies — presumably so transfers can overlap
    /// with compute (see `copy_stream_wait_event`).
    pub(crate) copy_stream: Arc<CudaStream>,
    /// cuBLAS handle created on the compute stream.
    pub(crate) cublas: Arc<CudaBlas>,
    /// Size-keyed caching allocator operating on the compute stream.
    pub(crate) allocator: CudaAllocator,
    /// Raw context/stream pair exposed for external interop.
    pub(crate) raw_handle: CudaRawHandle,
}
impl std::fmt::Debug for CudaClient {
    /// Shows only the `device` field; the remaining members are opaque FFI
    /// handles with no useful textual representation.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("CudaClient");
        builder.field("device", &self.device);
        builder.finish_non_exhaustive()
    }
}
/// Stream-ordered caching allocator for CUDA device memory.
///
/// Freed blocks are parked in a size-keyed free list and handed back on the
/// next allocation of exactly the same size, avoiding driver round-trips.
/// While `frozen` is set, the cache is bypassed: allocations always hit the
/// driver and frees go straight to `cuMemFreeAsync`.
#[derive(Clone)]
pub struct CudaAllocator {
    /// Stream that all `cuMemAllocAsync`/`cuMemFreeAsync` calls are ordered on.
    stream: Arc<CudaStream>,
    /// Free list: allocation size in bytes -> device pointers of that exact size.
    cache: Arc<std::sync::Mutex<std::collections::HashMap<usize, Vec<u64>>>>,
    /// When true, `allocate`/`deallocate` skip the cache entirely.
    frozen: Arc<std::sync::atomic::AtomicBool>,
}
impl Allocator for CudaAllocator {
    /// Allocates `size_bytes` of device memory, preferring a cached block of
    /// the exact same size when the allocator is not frozen.
    ///
    /// Returns `Ok(0)` for zero-sized requests (0 is the "no allocation"
    /// sentinel that `deallocate` ignores).
    ///
    /// # Errors
    /// Returns `Error::OutOfMemory` when the driver rejects the allocation
    /// even after synchronizing the stream (which lets any pending async
    /// frees complete) and retrying once.
    fn allocate(&self, size_bytes: usize) -> crate::error::Result<u64> {
        if size_bytes == 0 {
            return Ok(0);
        }
        // Fast path: reuse an exact-size block from the free list.
        if !self.frozen.load(std::sync::atomic::Ordering::Relaxed) {
            let mut cache = self.cache.lock().unwrap();
            if let Some(ptrs) = cache.get_mut(&size_bytes)
                && let Some(ptr) = ptrs.pop()
            {
                return Ok(ptr);
            }
        }
        // SAFETY: `ptr` is a valid out-parameter and the stream handle is
        // kept alive by `self.stream` for the duration of both calls.
        unsafe {
            let mut ptr: u64 = 0;
            let result = cudarc::driver::sys::cuMemAllocAsync(
                &mut ptr,
                size_bytes,
                self.stream.cu_stream(),
            );
            if result == cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                return Ok(ptr);
            }
            // First attempt failed (likely OOM): drain in-flight work so any
            // queued `cuMemFreeAsync` calls actually release memory, then retry.
            let _ = self.stream.synchronize();
            let result = cudarc::driver::sys::cuMemAllocAsync(
                &mut ptr,
                size_bytes,
                self.stream.cu_stream(),
            );
            if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                return Err(crate::error::Error::OutOfMemory { size: size_bytes });
            }
            Ok(ptr)
        }
    }

    /// Returns `ptr` to the size-keyed cache for reuse, or — when frozen —
    /// frees it immediately on the compute stream. The zero sentinel from
    /// zero-sized `allocate` calls is ignored.
    fn deallocate(&self, ptr: u64, size_bytes: usize) {
        if ptr == 0 {
            return;
        }
        if self.frozen.load(std::sync::atomic::Ordering::Relaxed) {
            // SAFETY: `ptr` came from `cuMemAllocAsync` on this stream and
            // is released exactly once here.
            unsafe {
                let _ = cudarc::driver::sys::cuMemFreeAsync(ptr, self.stream.cu_stream());
            }
            return;
        }
        let mut cache = self.cache.lock().unwrap();
        cache.entry(size_bytes).or_default().push(ptr);
    }

    /// Whether the cache is currently bypassed.
    fn is_frozen(&self) -> bool {
        self.frozen.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Disables caching; returns `true` to signal the freeze took effect.
    fn freeze(&self) -> bool {
        self.frozen
            .store(true, std::sync::atomic::Ordering::Relaxed);
        true
    }

    /// Re-enables caching.
    fn unfreeze(&self) {
        self.frozen
            .store(false, std::sync::atomic::Ordering::Relaxed);
    }

    /// Drains the free list, returning every cached block to the driver.
    ///
    /// The context check is performed once (it is loop-invariant): if the
    /// CUDA context has already been torn down — e.g. during process exit —
    /// freeing the pointers would be invalid, and the driver reclaims the
    /// memory itself, so the cache is simply dropped.
    fn reset(&self) -> crate::error::Result<()> {
        let mut cache = self.cache.lock().unwrap();
        // SAFETY: only queries the calling thread's current context.
        let context_alive = unsafe { is_cuda_context_valid() };
        if !context_alive {
            cache.clear();
            return Ok(());
        }
        for (_size, ptrs) in cache.drain() {
            for ptr in ptrs {
                // SAFETY: `ptr` was produced by `cuMemAllocAsync` on this
                // stream and has not been freed since entering the cache.
                unsafe {
                    let _ = cudarc::driver::sys::cuMemFreeAsync(ptr, self.stream.cu_stream());
                }
            }
        }
        Ok(())
    }
}
impl CudaClient {
    /// Creates a client for `device`: a context bound to the current thread,
    /// a compute stream, a copy stream, a cuBLAS handle, and a fresh
    /// (unfrozen, empty) caching allocator.
    ///
    /// Also raises the device's default memory-pool release threshold to
    /// `u64::MAX` so the driver keeps freed memory pooled instead of
    /// returning it to the OS; this step is best-effort and failures are
    /// deliberately ignored.
    ///
    /// # Errors
    /// `CudaError::ContextError` if context creation/binding or stream
    /// creation fails; `CudaError::CublasError` if cuBLAS initialization fails.
    pub fn new(device: CudaDevice) -> Result<Self, CudaError> {
        let context = CudaContext::new(device.index).map_err(|e| {
            CudaError::ContextError(format!(
                "Failed to create CUDA context for device {}: {:?}",
                device.index, e
            ))
        })?;
        context.bind_to_thread().map_err(|e| {
            CudaError::ContextError(format!("Failed to bind CUDA context to thread: {:?}", e))
        })?;
        let stream = context.new_stream().map_err(|e| {
            CudaError::ContextError(format!("Failed to create CUDA stream: {:?}", e))
        })?;
        let copy_stream = context.new_stream().map_err(|e| {
            CudaError::ContextError(format!("Failed to create CUDA copy stream: {:?}", e))
        })?;
        let cublas = CudaBlas::new(stream.clone())
            .map_err(|e| CudaError::CublasError(format!("Failed to initialize cuBLAS: {:?}", e)))?;
        // Best-effort: keep `cuMemFreeAsync`'d memory cached in the default
        // pool by maxing out its release threshold.
        // SAFETY: `pool` is a valid out-pointer; `threshold` outlives the
        // attribute-set call that reads it.
        unsafe {
            let mut pool: cudarc::driver::sys::CUmemoryPool = std::ptr::null_mut();
            let result =
                cudarc::driver::sys::cuDeviceGetDefaultMemPool(&mut pool, device.index as i32);
            if result == cudarc::driver::sys::CUresult::CUDA_SUCCESS && !pool.is_null() {
                let threshold: u64 = u64::MAX;
                let _ = cudarc::driver::sys::cuMemPoolSetAttribute(
                    pool,
                    cudarc::driver::sys::CUmemPool_attribute::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD,
                    &threshold as *const u64 as *mut std::ffi::c_void,
                );
            }
        }
        let allocator = CudaAllocator {
            stream: stream.clone(),
            cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
            frozen: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        };
        let raw_handle = CudaRawHandle {
            context: context.clone(),
            stream: stream.clone(),
        };
        Ok(Self {
            device,
            context,
            stream,
            copy_stream,
            cublas: Arc::new(cublas),
            allocator,
            raw_handle,
        })
    }

    /// Compute stream (cuBLAS work and allocator operations run here).
    #[inline]
    pub fn stream(&self) -> &CudaStream {
        &self.stream
    }

    /// Shared handle to the compute stream.
    #[inline]
    pub fn stream_arc(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// Shared handle to the device context.
    #[inline]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.context
    }

    /// Dedicated copy stream (see [`Self::copy_stream_wait_event`]).
    #[inline]
    pub fn copy_stream(&self) -> &CudaStream {
        &self.copy_stream
    }

    /// cuBLAS handle created on the compute stream.
    #[inline]
    pub fn cublas(&self) -> &CudaBlas {
        &self.cublas
    }

    /// Creates a timing-disabled CUDA event and records it on the compute
    /// stream, returning the raw `CUevent` handle as a `u64`.
    ///
    /// The caller owns the event and must release it via
    /// [`Self::destroy_event`].
    ///
    /// # Errors
    /// `CudaError::ContextError` if event creation or recording fails; a
    /// created-but-unrecorded event is destroyed before returning.
    pub fn record_event_on_compute(&self) -> Result<u64, CudaError> {
        use cudarc::driver::sys::{CUevent_flags, cuEventCreate, cuEventRecord};
        unsafe {
            let mut event = std::ptr::null_mut();
            let r = cuEventCreate(&mut event, CUevent_flags::CU_EVENT_DISABLE_TIMING as u32);
            if r != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                return Err(CudaError::ContextError(format!(
                    "cuEventCreate failed: {:?}",
                    r
                )));
            }
            let r = cuEventRecord(event, self.stream.cu_stream());
            if r != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                // Don't leak the event on the error path; the destroy result
                // is ignored because the record error is what matters.
                let _ = cudarc::driver::sys::cuEventDestroy_v2(event);
                return Err(CudaError::ContextError(format!(
                    "cuEventRecord failed: {:?}",
                    r
                )));
            }
            Ok(event as u64)
        }
    }

    /// Queues a device-side wait on the copy stream for `event` (a raw
    /// `CUevent` from [`Self::record_event_on_compute`]); the host is not
    /// blocked.
    ///
    /// # Errors
    /// `CudaError::ContextError` if `cuStreamWaitEvent` fails.
    pub fn copy_stream_wait_event(&self, event: u64) -> Result<(), CudaError> {
        use cudarc::driver::sys::cuStreamWaitEvent;
        // SAFETY: the caller guarantees `event` is a live CUevent handle
        // previously returned by `record_event_on_compute`.
        unsafe {
            let r = cuStreamWaitEvent(
                self.copy_stream.cu_stream(),
                event as cudarc::driver::sys::CUevent,
                0,
            );
            if r != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                return Err(CudaError::ContextError(format!(
                    "cuStreamWaitEvent failed: {:?}",
                    r
                )));
            }
        }
        Ok(())
    }

    /// Eagerly loads the named kernel modules for this client's device.
    pub fn preload_modules(&self, module_names: &[&'static str]) -> crate::error::Result<()> {
        crate::runtime::cuda::kernels::preload_modules(
            &self.context,
            self.device.index,
            module_names,
        )
    }

    /// Destroys an event handle returned by [`Self::record_event_on_compute`].
    pub fn destroy_event(&self, event: u64) {
        // SAFETY: the caller guarantees `event` is a live CUevent handle that
        // is destroyed exactly once.
        unsafe {
            cudarc::driver::sys::cuEventDestroy_v2(event as cudarc::driver::sys::CUevent);
        }
    }
}
impl RuntimeClient<CudaRuntime> for CudaClient {
    /// The device this client was created for.
    fn device(&self) -> &CudaDevice {
        &self.device
    }

    /// Blocks until all work queued on the compute stream has completed.
    /// Failures are reported to stderr rather than propagated, since this
    /// trait method is infallible.
    fn synchronize(&self) {
        match self.stream.synchronize() {
            Ok(()) => {}
            Err(e) => eprintln!("[numr::cuda] Stream synchronization failed: {:?}", e),
        }
    }

    /// The client's caching device-memory allocator.
    fn allocator(&self) -> &CudaAllocator {
        &self.allocator
    }

    /// Raw compute-stream handle, exposed as an integer for interop.
    fn compute_stream_handle(&self) -> Option<u64> {
        let raw = self.stream.cu_stream() as u64;
        Some(raw)
    }
}
/// Raw, clonable pair of the client's context and compute stream, exposed
/// with public fields for interop with code that needs direct access to the
/// underlying `cudarc` resources.
#[derive(Clone)]
pub struct CudaRawHandle {
    /// The client's CUDA context.
    pub context: Arc<CudaContext>,
    /// The client's compute stream.
    pub stream: Arc<CudaStream>,
}