use cudarc::cublas::CudaBlas;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::CudaRuntime;
use super::allocator::CudaAllocator;
use super::device::{CudaDevice, CudaError};
use super::sobol_cache::SobolDvCache;
use crate::runtime::RuntimeClient;
/// Client for a single CUDA device, bundling the context, streams, cuBLAS
/// handle, allocator and Sobol direction-vector cache created for it.
///
/// Most members are `Arc` handles, so a clone shares the same context,
/// streams, cuBLAS handle and Sobol cache as the value it was cloned from.
#[derive(Clone)]
pub struct CudaClient {
/// Device this client was created for.
pub(crate) device: CudaDevice,
/// CUDA context created for `device.index`.
pub(crate) context: Arc<CudaContext>,
/// Compute stream; also used for cuBLAS, the allocator, event recording
/// and the Sobol direction-vector upload.
pub(crate) stream: Arc<CudaStream>,
/// Second stream created from the same context; `copy_stream_wait_event`
/// makes it wait on events.
pub(crate) copy_stream: Arc<CudaStream>,
/// cuBLAS handle created on `stream`.
pub(crate) cublas: Arc<CudaBlas>,
/// Allocator built from `stream` and the device's default memory-pool
/// handle (0 when the pool could not be obtained).
pub(crate) allocator: CudaAllocator,
/// Clones of `context` and `stream` bundled as a `CudaRawHandle`.
pub(crate) raw_handle: CudaRawHandle,
/// Cache of Sobol direction vectors, filled by `warmup_sobol` and looked up
/// by dimension.
pub(crate) sobol_dv_cache: Arc<SobolDvCache>,
}
impl std::fmt::Debug for CudaClient {
    /// Prints only the `device` field and marks the output as
    /// non-exhaustive; the remaining members are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CudaClient");
        builder.field("device", &self.device);
        builder.finish_non_exhaustive()
    }
}
impl CudaClient {
/// Returns the client for `device`, reusing a cached one when available.
///
/// The cache is consulted first by device index. On a miss a fresh client
/// is built with [`Self::new_uncached`] and handed to
/// `register_or_get_client`, whose return value is what the caller receives.
///
/// # Errors
///
/// Propagates any [`CudaError`] produced while building a fresh client.
pub fn new(device: CudaDevice) -> Result<Self, CudaError> {
    let cached = super::cache::try_get_cached_client(device.index);
    match cached {
        Some(existing) => Ok(existing),
        None => {
            let fresh = Self::new_uncached(device)?;
            let index = fresh.device.index;
            Ok(super::cache::register_or_get_client(index, fresh))
        }
    }
}
/// Builds a brand-new client for `device` without touching the client cache.
///
/// Steps, in order: create the CUDA context, bind it to the calling thread,
/// create the compute and copy streams, create the cuBLAS handle on the
/// compute stream, query the device's default memory pool (setting its
/// release threshold when the query succeeds), then assemble the allocator
/// and the remaining members.
///
/// # Errors
///
/// * [`CudaError::ContextError`] when the context cannot be created or
///   bound, or when either stream cannot be created.
/// * [`CudaError::CublasError`] when cuBLAS initialization fails.
///
/// A failure to obtain or configure the default memory pool is not an
/// error: the allocator is then given a pool handle of `0`.
pub(super) fn new_uncached(device: CudaDevice) -> Result<Self, CudaError> {
    use cudarc::driver::sys;

    let context = match CudaContext::new(device.index) {
        Ok(ctx) => ctx,
        Err(err) => {
            return Err(CudaError::ContextError(format!(
                "Failed to create CUDA context for device {}: {:?}",
                device.index, err
            )));
        }
    };

    if let Err(err) = context.bind_to_thread() {
        return Err(CudaError::ContextError(format!(
            "Failed to bind CUDA context to thread: {:?}",
            err
        )));
    }

    let stream = match context.new_stream() {
        Ok(s) => s,
        Err(err) => {
            return Err(CudaError::ContextError(format!(
                "Failed to create CUDA stream: {:?}",
                err
            )));
        }
    };

    let copy_stream = match context.new_stream() {
        Ok(s) => s,
        Err(err) => {
            return Err(CudaError::ContextError(format!(
                "Failed to create CUDA copy stream: {:?}",
                err
            )));
        }
    };

    let cublas = match CudaBlas::new(Arc::clone(&stream)) {
        Ok(handle) => handle,
        Err(err) => {
            return Err(CudaError::CublasError(format!(
                "Failed to initialize cuBLAS: {:?}",
                err
            )));
        }
    };

    // SAFETY: `pool` is a valid out-pointer for the duration of the query,
    // and `threshold` outlives the `cuMemPoolSetAttribute` call that is
    // handed its address. NOTE(review): the context was bound to this thread
    // above; confirm that is all the driver requires for these two calls.
    let pool_handle: u64 = unsafe {
        let mut pool: sys::CUmemoryPool = std::ptr::null_mut();
        let status = sys::cuDeviceGetDefaultMemPool(&mut pool, device.index as i32);
        if status != sys::CUresult::CUDA_SUCCESS || pool.is_null() {
            // No usable default pool: signal that with a zero handle.
            0
        } else {
            // NOTE(review): presumably the env var is read in MiB and the
            // second argument is the fallback — confirm in `env_config`.
            let threshold: u64 = super::env_config::env_mib_to_bytes(
                "NUMR_CUDA_POOL_RELEASE_THRESHOLD_MB",
                512 * 1024 * 1024,
            );
            // Best effort: the result of setting the threshold is ignored.
            let _ = sys::cuMemPoolSetAttribute(
                pool,
                sys::CUmemPool_attribute::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD,
                &threshold as *const u64 as *mut std::ffi::c_void,
            );
            pool as u64
        }
    };

    let allocator = CudaAllocator::new(Arc::clone(&stream), pool_handle);
    let raw_handle = CudaRawHandle {
        context: Arc::clone(&context),
        stream: Arc::clone(&stream),
    };

    Ok(Self {
        device,
        context,
        stream,
        copy_stream,
        cublas: Arc::new(cublas),
        allocator,
        raw_handle,
        sobol_dv_cache: SobolDvCache::new(),
    })
}
/// Borrows the compute stream.
#[inline]
pub fn stream(&self) -> &CudaStream {
    &*self.stream
}
#[inline]
pub fn stream_arc(&self) -> &Arc<CudaStream> {
&self.stream
}
#[inline]
pub fn context(&self) -> &Arc<CudaContext> {
&self.context
}
/// Borrows the copy stream (the second stream created alongside the
/// compute stream).
#[inline]
pub fn copy_stream(&self) -> &CudaStream {
    &*self.copy_stream
}
/// Borrows the cuBLAS handle that was created on the compute stream.
#[inline]
pub fn cublas(&self) -> &CudaBlas {
    &*self.cublas
}
/// Creates a timing-disabled CUDA event and records it on the compute
/// stream.
///
/// The raw event handle is returned as a `u64`. Nothing in this method
/// destroys it on success; pass it to [`Self::destroy_event`] when done.
///
/// # Errors
///
/// Returns [`CudaError::ContextError`] if the event cannot be created or
/// recorded. When recording fails the freshly created event is destroyed
/// before returning.
pub fn record_event_on_compute(&self) -> Result<u64, CudaError> {
    use cudarc::driver::sys::{CUevent_flags, cuEventCreate, cuEventRecord};

    let mut raw_event = std::ptr::null_mut();

    // SAFETY: `raw_event` is a valid out-pointer for the created handle.
    let created =
        unsafe { cuEventCreate(&mut raw_event, CUevent_flags::CU_EVENT_DISABLE_TIMING as u32) };
    if created != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
        return Err(CudaError::ContextError(format!(
            "cuEventCreate failed: {:?}",
            created
        )));
    }

    // SAFETY: `raw_event` was just created successfully and the stream
    // handle comes from a stream owned by `self`.
    let recorded = unsafe { cuEventRecord(raw_event, self.stream.cu_stream()) };
    if recorded != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
        // Do not leak the event when it could not be recorded.
        // SAFETY: `raw_event` is a live event that nothing else references.
        unsafe {
            cudarc::driver::sys::cuEventDestroy_v2(raw_event);
        }
        return Err(CudaError::ContextError(format!(
            "cuEventRecord failed: {:?}",
            recorded
        )));
    }

    Ok(raw_event as u64)
}
/// Makes the copy stream wait for `event`.
///
/// `event` is a raw event handle encoded as `u64`, as returned by
/// [`Self::record_event_on_compute`].
///
/// # Errors
///
/// Returns [`CudaError::ContextError`] if `cuStreamWaitEvent` reports
/// anything other than success.
pub fn copy_stream_wait_event(&self, event: u64) -> Result<(), CudaError> {
    use cudarc::driver::sys::cuStreamWaitEvent;

    // SAFETY: the stream handle belongs to a stream owned by `self`; the
    // caller is responsible for `event` being a live event handle.
    let status = unsafe {
        cuStreamWaitEvent(
            self.copy_stream.cu_stream(),
            event as cudarc::driver::sys::CUevent,
            0,
        )
    };

    if status == cudarc::driver::sys::CUresult::CUDA_SUCCESS {
        Ok(())
    } else {
        Err(CudaError::ContextError(format!(
            "cuStreamWaitEvent failed: {:?}",
            status
        )))
    }
}
/// Forwards to the kernel module preloader with this client's context and
/// device index.
///
/// # Errors
///
/// Returns whatever error the kernel preloader reports.
pub fn preload_modules(&self, module_names: &[&'static str]) -> crate::error::Result<()> {
    use crate::runtime::cuda::kernels;

    let device_index = self.device.index;
    kernels::preload_modules(&self.context, device_index, module_names)
}
pub fn warmup_sobol(&self, dimension: usize) -> crate::error::Result<()> {
use crate::ops::common::quasirandom::{SOBOL_BITS, SOBOL_MAX_DIMENSIONS};
if dimension == 0 {
return Err(crate::error::Error::InvalidArgument {
arg: "dimension",
reason: "Sobol dimension must be at least 1".into(),
});
}
if dimension > SOBOL_MAX_DIMENSIONS {
return Err(crate::error::Error::InvalidArgument {
arg: "dimension",
reason: format!(
"Sobol dimension {} exceeds maximum supported value {}",
dimension, SOBOL_MAX_DIMENSIONS
),
});
}
let dim_u32 = dimension as u32;
if self.sobol_dv_cache.get(dim_u32).is_some() {
return Ok(());
}
let direction_vectors =
crate::ops::common::quasirandom::compute_all_direction_vectors(dimension);
let num_u32s = direction_vectors.len();
debug_assert_eq!(num_u32s, dimension * SOBOL_BITS);
let dv_bytes = bytemuck::cast_slice::<u32, u8>(&direction_vectors);
let dv_ptr: u64 = unsafe {
let mut ptr: u64 = 0;
let r = cudarc::driver::sys::cuMemAllocAsync(
&mut ptr,
dv_bytes.len(),
self.stream.cu_stream(),
);
if r != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
return Err(crate::error::Error::OutOfMemory {
size: dv_bytes.len(),
});
}
ptr
};
unsafe {
let r = cudarc::driver::sys::cuMemcpyHtoDAsync_v2(
dv_ptr,
dv_bytes.as_ptr() as *const std::ffi::c_void,
dv_bytes.len(),
self.stream.cu_stream(),
);
if r != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
let _ = cudarc::driver::sys::cuMemFreeAsync(dv_ptr, self.stream.cu_stream());
return Err(crate::error::Error::Backend(format!(
"Sobol warmup H2D copy failed: {:?}",
r
)));
}
}
self.stream
.synchronize()
.map_err(|e| crate::error::Error::Internal(format!("stream sync failed: {:?}", e)))?;
unsafe { self.sobol_dv_cache.insert(dim_u32, dv_ptr, num_u32s) };
Ok(())
}
/// Destroys the CUDA event whose raw handle is `event`.
///
/// The driver's return code is not inspected, so a failure goes unreported.
pub fn destroy_event(&self, event: u64) {
    let handle = event as cudarc::driver::sys::CUevent;
    // SAFETY: the caller is responsible for `event` being a live event
    // handle that is not used again after this call.
    unsafe {
        cudarc::driver::sys::cuEventDestroy_v2(handle);
    }
}
}
impl RuntimeClient<CudaRuntime> for CudaClient {
    /// Borrows the device this client was created for.
    fn device(&self) -> &CudaDevice {
        let Self { device, .. } = self;
        device
    }

    /// Synchronizes the compute stream.
    ///
    /// The trait method returns nothing, so a failure is written to stderr
    /// instead of being propagated.
    fn synchronize(&self) {
        let outcome = self.stream.synchronize();
        if let Err(err) = outcome {
            eprintln!("[numr::cuda] Stream synchronization failed: {:?}", err);
        }
    }

    /// Borrows the allocator owned by this client.
    fn allocator(&self) -> &CudaAllocator {
        let Self { allocator, .. } = self;
        allocator
    }

    /// Returns the raw compute-stream handle as a `u64`; always `Some` for
    /// this backend.
    fn compute_stream_handle(&self) -> Option<u64> {
        let raw = self.stream.cu_stream();
        Some(raw as u64)
    }
}
/// Pair of shared context and stream handles.
///
/// `CudaClient::new_uncached` fills it with clones of the client's context
/// and compute stream.
#[derive(Clone)]
pub struct CudaRawHandle {
/// Shared handle to the CUDA context.
pub context: Arc<CudaContext>,
/// Shared handle to the CUDA stream.
pub stream: Arc<CudaStream>,
}