use super::*;
use std::{
collections::HashMap,
sync::{Arc, Mutex, OnceLock},
};
use cudarc::driver::CudaContext;
use crate::block_manager::numa_allocator;
/// Allocate `size` bytes of pinned (page-locked) host memory, preferring the
/// write-combined flavor and transparently falling back to a plain pinned
/// allocation when the platform rejects the write-combined flag.
///
/// # Safety
///
/// A valid CUDA context must be bound to the calling thread before this
/// function is invoked.
unsafe fn malloc_host_prefer_writecombined(size: usize) -> Result<*mut u8, StorageError> {
    // First attempt: pinned memory with the write-combined flag.
    let write_combined = unsafe {
        cudarc::driver::result::malloc_host(
            size,
            cudarc::driver::sys::CU_MEMHOSTALLOC_WRITECOMBINED,
        )
    };
    if let Ok(ptr) = write_combined {
        return Ok(ptr as *mut u8);
    }
    tracing::debug!("Write-combined memory not supported, using regular pinned memory");
    // Fallback: plain pinned allocation (no flags).
    unsafe { cudarc::driver::result::malloc_host(size, 0) }
        .map(|ptr| ptr as *mut u8)
        .map_err(StorageError::Cuda)
}
/// Marker trait for [`Storage`] implementations whose memory can be used
/// directly by the CUDA driver (device memory and pinned host memory).
pub trait CudaAccessible: Storage {}

/// Exposes the CUDA context a storage object was created under.
// NOTE(review): trait name has a typo ("Proivder" -> "Provider"); kept as-is
// because callers elsewhere reference it by this exact name.
pub trait CudaContextProivder {
    fn cuda_context(&self) -> &Arc<CudaContext>;
}
/// Process-wide registry of CUDA contexts, keyed by device ordinal.
/// All access goes through the singleton behind [`Cuda::instance`], so each
/// device's context is created at most once and shared via `Arc`.
pub struct Cuda {
    // One cached context per device ordinal.
    contexts: HashMap<usize, Arc<CudaContext>>,
}
impl Cuda {
    /// Empty registry; only called from [`Cuda::instance`].
    fn new() -> Self {
        Self {
            contexts: HashMap::new(),
        }
    }

    /// Cached context for `device_id`, or `None` if it was never created.
    pub fn device(device_id: usize) -> Option<Arc<CudaContext>> {
        let registry = Cuda::instance().lock().unwrap();
        registry.get_existing_context(device_id)
    }

    /// Cached context for `device_id`, creating and caching it on first use.
    pub fn device_or_create(device_id: usize) -> Result<Arc<CudaContext>, StorageError> {
        let mut registry = Cuda::instance().lock().unwrap();
        registry.get_context(device_id)
    }

    /// Whether a context for `device_id` has already been created.
    pub fn is_initialized(device_id: usize) -> bool {
        let registry = Cuda::instance().lock().unwrap();
        registry.has_context(device_id)
    }

    /// Lazily-initialized process-wide singleton registry.
    fn instance() -> &'static Mutex<Cuda> {
        static INSTANCE: OnceLock<Mutex<Cuda>> = OnceLock::new();
        INSTANCE.get_or_init(|| Mutex::new(Cuda::new()))
    }

    /// Look up the context for `device_id`, creating and caching it if absent.
    fn get_context(&mut self, device_id: usize) -> Result<Arc<CudaContext>, StorageError> {
        match self.contexts.get(&device_id) {
            Some(existing) => Ok(existing.clone()),
            None => {
                let created = CudaContext::new(device_id)?;
                self.contexts.insert(device_id, created.clone());
                Ok(created)
            }
        }
    }

    /// Non-creating lookup of a cached context.
    pub fn get_existing_context(&self, device_id: usize) -> Option<Arc<CudaContext>> {
        self.contexts.get(&device_id).cloned()
    }

    /// True if a context for `device_id` is cached.
    pub fn has_context(&self, device_id: usize) -> bool {
        self.contexts.contains_key(&device_id)
    }
}
/// Page-locked (pinned) host memory allocated through the CUDA driver,
/// readable and writable from the CPU and usable by CUDA.
#[derive(Debug)]
pub struct PinnedStorage {
    // Raw host address stored as an integer; see `addr`/`as_ptr`.
    ptr: u64,
    // Allocation size in bytes.
    size: usize,
    // External registration handles; released before the memory is freed (see Drop).
    handles: RegistrationHandles,
    // Context the allocation was made under; kept alive with the storage.
    ctx: Arc<CudaContext>,
}
// Marker traits: pinned memory is local to this process and accessible from
// both the host (`SystemAccessible`) and CUDA (`CudaAccessible`).
impl Local for PinnedStorage {}
impl SystemAccessible for PinnedStorage {}
impl CudaAccessible for PinnedStorage {}
impl PinnedStorage {
    /// Allocate `size` bytes of pinned host memory under `ctx`.
    ///
    /// When NUMA support is enabled, allocation is first attempted through the
    /// NUMA worker pool for the context's device; on failure it falls back to
    /// a direct driver allocation (write-combined preferred).
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Cuda`] when binding the context or allocating fails.
    pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
        unsafe {
            // Driver calls below require this context to be current on the thread.
            ctx.bind_to_thread().map_err(StorageError::Cuda)?;
            let ptr = if numa_allocator::is_numa_enabled() {
                let device_id = ctx.cu_device() as u32;
                match numa_allocator::worker_pool::NumaWorkerPool::global()
                    .allocate_pinned_for_gpu(size, device_id)
                {
                    Ok(ptr) => ptr,
                    Err(e) => {
                        // NUMA placement is best-effort; fall back rather than fail.
                        tracing::warn!("NUMA allocation failed: {}, using direct allocation", e);
                        malloc_host_prefer_writecombined(size)?
                    }
                }
            } else {
                malloc_host_prefer_writecombined(size)?
            };
            assert!(!ptr.is_null(), "Failed to allocate pinned memory");
            // NOTE(review): for `*mut u8` (align 1) this is always true; harmless.
            assert!(ptr.is_aligned(), "Pinned memory is not aligned");
            // Keeps later `ptr.add(offset)` arithmetic within `isize` bounds.
            assert!(size < isize::MAX as usize);
            let ptr = ptr as u64;
            Ok(Self {
                ptr,
                size,
                handles: RegistrationHandles::new(),
                ctx: ctx.clone(),
            })
        }
    }
}
impl Drop for PinnedStorage {
    fn drop(&mut self) {
        // Release any registration handles before freeing the backing memory.
        self.handles.release();
        unsafe {
            // NOTE(review): NUMA-pool allocations are also freed here via
            // `free_host` — this assumes the pool hands out driver-pinned
            // pointers compatible with cuMemFreeHost; TODO confirm against
            // NumaWorkerPool::allocate_pinned_for_gpu.
            if let Err(e) = cudarc::driver::result::free_host(self.ptr as _) {
                // Never panic inside drop; log and continue.
                tracing::error!(
                    "Failed to free pinned storage at 0x{:x} (size={}): {}",
                    self.ptr,
                    self.size,
                    e
                );
            }
        }
    }
}
// Plain field-forwarding accessors for the `Storage` contract.
impl Storage for PinnedStorage {
    fn storage_type(&self) -> StorageType {
        StorageType::Pinned
    }
    /// Raw host address of the allocation as an integer.
    fn addr(&self) -> u64 {
        self.ptr
    }
    /// Allocation size in bytes.
    fn size(&self) -> usize {
        self.size
    }
    unsafe fn as_ptr(&self) -> *const u8 {
        self.ptr as *const u8
    }
    unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr as *mut u8
    }
}
impl CudaContextProivder for PinnedStorage {
    /// The context this allocation was made under.
    fn cuda_context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }
}
// Delegates registration bookkeeping to the internal `RegistrationHandles`.
impl RegisterableStorage for PinnedStorage {
    fn register(
        &mut self,
        key: &str,
        handle: Box<dyn RegistationHandle>,
    ) -> Result<(), StorageError> {
        self.handles.register(key, handle)
    }
    fn is_registered(&self, key: &str) -> bool {
        self.handles.is_registered(key)
    }
    fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
        self.handles.registration_handle(key)
    }
}
impl StorageMemset for PinnedStorage {
    /// Fill `size` bytes starting at byte `offset` with `value`.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::OperationFailed`] when `[offset, offset + size)`
    /// does not fit inside the allocation, including the case where
    /// `offset + size` overflows `usize`.
    fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError> {
        // Checked arithmetic: in release builds a plain `offset + size` wraps
        // on overflow, which could defeat the bounds check and permit an
        // out-of-bounds write.
        match offset.checked_add(size) {
            Some(end) if end <= self.size => {}
            _ => {
                return Err(StorageError::OperationFailed(
                    "memset: offset + size > storage size".into(),
                ));
            }
        }
        // SAFETY: the range [offset, offset + size) was validated above to lie
        // within the pinned allocation owned by `self`.
        unsafe {
            let dst = (self.ptr as *mut u8).add(offset);
            std::ptr::write_bytes(dst, value, size);
        }
        Ok(())
    }
}
/// Allocator producing [`PinnedStorage`] allocations tied to one CUDA context.
pub struct PinnedAllocator {
    ctx: Arc<CudaContext>,
}
impl Default for PinnedAllocator {
    /// Allocator for device 0, reusing the shared [`Cuda`] context registry.
    ///
    /// # Panics
    /// Panics if the CUDA context cannot be created.
    fn default() -> Self {
        Self {
            ctx: Cuda::device_or_create(0).expect("Failed to create CUDA context"),
        }
    }
}
impl PinnedAllocator {
    /// Allocator for the given device, creating/caching its context as needed.
    pub fn new(device_id: usize) -> Result<Self, StorageError> {
        Ok(Self {
            ctx: Cuda::device_or_create(device_id)?,
        })
    }
}
impl StorageAllocator<PinnedStorage> for PinnedAllocator {
    /// Allocate `size` bytes of pinned host memory under this allocator's context.
    fn allocate(&self, size: usize) -> Result<PinnedStorage, StorageError> {
        PinnedStorage::new(&self.ctx, size)
    }
}
/// Records how a [`DeviceStorage`]'s memory was obtained, which determines
/// who frees it on drop (see `Drop for DeviceStorage`).
#[derive(Debug)]
enum DeviceStorageType {
    /// Memory allocated by this module via the CUDA driver; freed in `Drop`.
    Owned,
    /// Memory borrowed from a torch tensor; the `Arc` keeps the tensor (and
    /// therefore the memory) alive, and nothing is freed here on drop.
    Torch { _tensor: Arc<dyn TorchTensor> },
}
/// CUDA device memory, either owned (driver-allocated) or borrowed from a
/// torch tensor (see `DeviceStorageType`).
#[derive(Debug)]
pub struct DeviceStorage {
    // Raw device address as an integer.
    ptr: u64,
    // Allocation size in bytes.
    size: usize,
    // Context the memory belongs to.
    ctx: Arc<CudaContext>,
    // External registration handles; released on drop.
    handles: RegistrationHandles,
    // Ownership discriminant controlling Drop behavior.
    _storage_type: DeviceStorageType,
}
// Marker traits: device memory is process-local and CUDA-accessible, but is
// deliberately NOT `SystemAccessible` (the host cannot dereference it directly).
impl Local for DeviceStorage {}
impl CudaAccessible for DeviceStorage {}
impl DeviceStorage {
    /// Allocate `size` bytes of device memory on the context's device.
    ///
    /// # Errors
    /// Returns [`StorageError::Cuda`] when binding the context or allocating fails.
    pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
        ctx.bind_to_thread().map_err(StorageError::Cuda)?;
        // SAFETY: the context is bound to the current thread; the returned
        // device pointer remains valid until freed in `Drop`.
        let device_ptr =
            unsafe { cudarc::driver::result::malloc_sync(size) }.map_err(StorageError::Cuda)?;
        Ok(Self {
            ptr: device_ptr,
            size,
            ctx: ctx.clone(),
            handles: RegistrationHandles::new(),
            _storage_type: DeviceStorageType::Owned,
        })
    }

    /// Wrap an existing CUDA tensor's memory without taking ownership.
    ///
    /// The tensor must live on the same device as `ctx`; the returned storage
    /// holds the tensor `Arc` so the memory outlives this object.
    ///
    /// # Errors
    /// Returns [`StorageError::InvalidConfig`] for non-CUDA tensors or a
    /// device mismatch with `ctx`.
    pub fn new_from_torch(
        ctx: &Arc<CudaContext>,
        tensor: Arc<dyn TorchTensor>,
    ) -> Result<Self, StorageError> {
        let device_id = match tensor.device() {
            TorchDevice::Cuda(id) => id,
            _ => return Err(StorageError::InvalidConfig("Tensor is not CUDA!".into())),
        };
        if device_id != ctx.cu_device() as usize {
            return Err(StorageError::InvalidConfig(
                "Tensor is not on the same device as the context!".into(),
            ));
        }
        let addr = tensor.data_ptr();
        let nbytes = tensor.size_bytes();
        Ok(Self {
            ptr: addr,
            size: nbytes,
            ctx: ctx.clone(),
            handles: RegistrationHandles::new(),
            _storage_type: DeviceStorageType::Torch { _tensor: tensor },
        })
    }

    /// The CUDA context this storage belongs to.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }
}
// Plain field-forwarding accessors for the `Storage` contract.
impl Storage for DeviceStorage {
    /// Device storage tagged with its device ordinal.
    fn storage_type(&self) -> StorageType {
        StorageType::Device(self.ctx.cu_device() as u32)
    }
    /// Raw device address as an integer.
    fn addr(&self) -> u64 {
        self.ptr
    }
    /// Allocation size in bytes.
    fn size(&self) -> usize {
        self.size
    }
    // NOTE(review): these return the device address cast to a host pointer
    // type; it is not dereferenceable from the CPU.
    unsafe fn as_ptr(&self) -> *const u8 {
        self.ptr as *const u8
    }
    unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr as *mut u8
    }
}
impl CudaContextProivder for DeviceStorage {
    /// The context this memory belongs to.
    fn cuda_context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }
}
impl Drop for DeviceStorage {
fn drop(&mut self) {
self.handles.release();
match &self._storage_type {
DeviceStorageType::Owned => {
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap()
}
DeviceStorageType::Torch { _tensor } => {
}
}
}
}
// Delegates registration bookkeeping to the internal `RegistrationHandles`.
impl RegisterableStorage for DeviceStorage {
    fn register(
        &mut self,
        key: &str,
        handle: Box<dyn RegistationHandle>,
    ) -> Result<(), StorageError> {
        self.handles.register(key, handle)
    }
    fn is_registered(&self, key: &str) -> bool {
        self.handles.is_registered(key)
    }
    fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
        self.handles.registration_handle(key)
    }
}
/// Allocator producing owned [`DeviceStorage`] allocations on one CUDA context.
pub struct DeviceAllocator {
    ctx: Arc<CudaContext>,
}
impl Default for DeviceAllocator {
fn default() -> Self {
Self {
ctx: CudaContext::new(0).expect("Failed to create CUDA context"),
}
}
}
impl DeviceAllocator {
    /// Allocator for the given device, creating/caching its context as needed.
    pub fn new(device_id: usize) -> Result<Self, StorageError> {
        Ok(Self {
            ctx: Cuda::device_or_create(device_id)?,
        })
    }
    /// The context this allocator allocates under.
    pub fn ctx(&self) -> &Arc<CudaContext> {
        &self.ctx
    }
}
impl StorageAllocator<DeviceStorage> for DeviceAllocator {
    /// Allocate `size` bytes of device memory under this allocator's context.
    fn allocate(&self, size: usize) -> Result<DeviceStorage, StorageError> {
        DeviceStorage::new(&self.ctx, size)
    }
}
// These tests require real CUDA hardware and are gated behind the
// `testing-cuda` feature.
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
    use super::*;

    /// Minimal `TorchTensor` stand-in reporting a caller-supplied device,
    /// pointer, and size; it owns no memory itself.
    #[derive(Debug, Clone)]
    struct MockTensor {
        device: TorchDevice,
        data_ptr: u64,
        size_bytes: usize,
    }

    impl MockTensor {
        pub fn new(device: TorchDevice, data_ptr: u64, size_bytes: usize) -> Self {
            Self {
                device,
                data_ptr,
                size_bytes,
            }
        }
    }

    impl TorchTensor for MockTensor {
        fn device(&self) -> TorchDevice {
            self.device.clone()
        }
        fn data_ptr(&self) -> u64 {
            self.data_ptr
        }
        fn size_bytes(&self) -> usize {
            self.size_bytes
        }
        // Shape/stride are placeholders; these tests only rely on size_bytes.
        fn shape(&self) -> Vec<usize> {
            vec![self.size_bytes]
        }
        fn stride(&self) -> Vec<usize> {
            vec![1]
        }
    }

    /// Wrapping an existing device allocation through `new_from_torch`
    /// preserves the address, size, and storage type.
    #[test]
    fn test_device_storage_from_torch_valid_tensor() {
        let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
        let size_bytes = 1024;
        // ManuallyDrop: the torch-backed storage aliases this memory without
        // owning it; the backing allocation is intentionally leaked for the test.
        let actual_storage =
            std::mem::ManuallyDrop::new(DeviceStorage::new(&ctx, size_bytes).unwrap());
        let tensor = MockTensor::new(TorchDevice::Cuda(0), actual_storage.addr(), size_bytes);
        let storage = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor)).unwrap();
        assert_eq!(storage.size(), size_bytes);
        assert_eq!(storage.storage_type(), StorageType::Device(0));
        assert_eq!(storage.addr(), actual_storage.addr());
    }

    /// A CPU tensor must be rejected with `InvalidConfig`.
    #[test]
    fn test_device_storage_from_torch_cpu_tensor_fails() {
        let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
        let size_bytes = 1024;
        let actual_storage = DeviceStorage::new(&ctx, size_bytes).unwrap();
        let tensor = MockTensor::new(
            TorchDevice::Other("cpu".to_string()),
            actual_storage.addr(),
            size_bytes,
        );
        let result = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor));
        assert!(result.is_err());
        if let Err(StorageError::InvalidConfig(msg)) = result {
            assert!(msg.contains("Tensor is not CUDA"));
        } else {
            panic!("Expected InvalidConfig error for CPU tensor");
        }
    }

    /// A tensor on a different device than the context must be rejected.
    #[test]
    fn test_device_storage_wrong_device() {
        let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
        let size_bytes = 1024;
        let actual_storage = DeviceStorage::new(&ctx, size_bytes).unwrap();
        let tensor = MockTensor::new(TorchDevice::Cuda(1), actual_storage.addr(), size_bytes);
        let result = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor));
        assert!(result.is_err());
    }

    /// The pinned allocation helper returns host-writable memory regardless
    /// of whether the write-combined path or the fallback was taken.
    #[test]
    fn test_malloc_host_prefer_writecombined_allocates_memory() {
        let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
        let size = 4096;
        unsafe {
            ctx.bind_to_thread().expect("Failed to bind CUDA context");
            let ptr = malloc_host_prefer_writecombined(size)
                .expect("malloc_host_prefer_writecombined should succeed");
            assert!(!ptr.is_null(), "Allocated pointer should not be null");
            std::ptr::write_volatile(ptr, 0xAB);
            let val = std::ptr::read_volatile(ptr);
            assert_eq!(val, 0xAB, "Should be able to write and read pinned memory");
            cudarc::driver::result::free_host(ptr as _).expect("Failed to free pinned memory");
        }
    }

    /// End-to-end PinnedStorage allocation via the non-NUMA path: the buffer
    /// must be fully readable/writable from the host.
    #[test]
    fn test_pinned_storage_new_without_numa() {
        assert!(
            !numa_allocator::is_numa_enabled(),
            "NUMA should be disabled for this test"
        );
        let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
        let size = 8192;
        let mut storage =
            PinnedStorage::new(&ctx, size).expect("PinnedStorage::new should succeed");
        assert_eq!(storage.size(), size);
        assert_eq!(storage.storage_type(), StorageType::Pinned);
        assert_ne!(storage.addr(), 0, "Address should be non-zero");
        unsafe {
            let ptr = storage.as_mut_ptr();
            assert!(!ptr.is_null(), "Pointer should not be null");
            // Fill with a deterministic pattern, then verify every byte.
            for i in 0..size {
                std::ptr::write_volatile(ptr.add(i), (i & 0xFF) as u8);
            }
            for i in 0..size {
                let val = std::ptr::read_volatile(ptr.add(i));
                assert_eq!(
                    val,
                    (i & 0xFF) as u8,
                    "Memory content mismatch at offset {}",
                    i
                );
            }
        }
    }
}