use super::*;
use std::{
collections::HashMap,
sync::{Arc, Mutex, OnceLock},
};
use cudarc::driver::CudaContext;
use dynamo_memory::MemoryDescriptor as _;
pub trait CudaAccessible: Storage {}
pub trait CudaContextProivder {
fn cuda_context(&self) -> &Arc<CudaContext>;
}
pub struct Cuda {
contexts: HashMap<usize, Arc<CudaContext>>,
}
impl Cuda {
fn new() -> Self {
Self {
contexts: HashMap::new(),
}
}
pub fn device(device_id: usize) -> Option<Arc<CudaContext>> {
Cuda::instance()
.lock()
.unwrap()
.get_existing_context(device_id)
}
pub fn device_or_create(device_id: usize) -> Result<Arc<CudaContext>, StorageError> {
Cuda::instance().lock().unwrap().get_context(device_id)
}
pub fn is_initialized(device_id: usize) -> bool {
Cuda::instance().lock().unwrap().has_context(device_id)
}
fn instance() -> &'static Mutex<Cuda> {
static INSTANCE: OnceLock<Mutex<Cuda>> = OnceLock::new();
INSTANCE.get_or_init(|| Mutex::new(Cuda::new()))
}
fn get_context(&mut self, device_id: usize) -> Result<Arc<CudaContext>, StorageError> {
if let Some(ctx) = self.contexts.get(&device_id) {
return Ok(ctx.clone());
}
let ctx = CudaContext::new(device_id)?;
self.contexts.insert(device_id, ctx.clone());
Ok(ctx)
}
pub fn get_existing_context(&self, device_id: usize) -> Option<Arc<CudaContext>> {
self.contexts.get(&device_id).cloned()
}
pub fn has_context(&self, device_id: usize) -> bool {
self.contexts.contains_key(&device_id)
}
}
#[derive(Debug)]
pub struct PinnedStorage {
inner: dynamo_memory::PinnedStorage,
handles: RegistrationHandles,
}
impl Local for PinnedStorage {}
impl SystemAccessible for PinnedStorage {}
impl CudaAccessible for PinnedStorage {}
impl PinnedStorage {
#[deprecated(since = "1.0.0", note = "Use PinnedStorage::new_for_device instead")]
pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
let inner =
dynamo_memory::PinnedStorage::new_for_device(size, Some(ctx.cu_device() as u32))?;
Ok(Self {
inner,
handles: RegistrationHandles::new(),
})
}
pub fn new_for_device(size: usize, device_id: Option<u32>) -> Result<Self, StorageError> {
static DEPRECATION_WARN: std::sync::Once = std::sync::Once::new();
if std::env::var("DYN_KVBM_ENABLE_NUMA").is_ok() {
DEPRECATION_WARN.call_once(|| {
tracing::warn!(
"DYN_KVBM_ENABLE_NUMA is deprecated for PinnedStorage::new_for_device; \
NUMA is now enabled by default. Use DYN_MEMORY_DISABLE_NUMA=1 to disable."
);
});
}
let inner = dynamo_memory::PinnedStorage::new_for_device(size, device_id)?;
Ok(Self {
inner,
handles: RegistrationHandles::new(),
})
}
}
impl Drop for PinnedStorage {
fn drop(&mut self) {
self.handles.release();
}
}
impl Storage for PinnedStorage {
fn storage_type(&self) -> StorageType {
StorageType::Pinned
}
fn addr(&self) -> u64 {
self.inner.addr() as u64
}
fn size(&self) -> usize {
self.inner.size()
}
unsafe fn as_ptr(&self) -> *const u8 {
unsafe { self.inner.as_ptr() }
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
unsafe { self.inner.as_mut_ptr() }
}
}
impl CudaContextProivder for PinnedStorage {
fn cuda_context(&self) -> &Arc<CudaContext> {
self.inner.ctx()
}
}
impl RegisterableStorage for PinnedStorage {
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
self.handles.register(key, handle)
}
fn is_registered(&self, key: &str) -> bool {
self.handles.is_registered(key)
}
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.registration_handle(key)
}
}
impl StorageMemset for PinnedStorage {
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError> {
if offset + size > self.inner.size() {
return Err(StorageError::OperationFailed(
"memset: offset + size > storage size".into(),
));
}
unsafe {
let ptr = self.inner.as_mut_ptr().add(offset);
std::ptr::write_bytes(ptr, value, size);
}
Ok(())
}
}
pub struct PinnedAllocator {
ctx: Arc<CudaContext>,
}
impl Default for PinnedAllocator {
fn default() -> Self {
Self {
ctx: Cuda::device_or_create(0).expect("Failed to create CUDA context"),
}
}
}
impl PinnedAllocator {
pub fn new(device_id: usize) -> Result<Self, StorageError> {
Ok(Self {
ctx: Cuda::device_or_create(device_id)?,
})
}
}
impl StorageAllocator<PinnedStorage> for PinnedAllocator {
fn allocate(&self, size: usize) -> Result<PinnedStorage, StorageError> {
PinnedStorage::new_for_device(size, Some(self.ctx.cu_device() as u32))
}
}
#[derive(Debug)]
enum DeviceStorageType {
Owned, Torch { _tensor: Arc<dyn TorchTensor> }, }
#[derive(Debug)]
pub struct DeviceStorage {
ptr: u64,
size: usize,
ctx: Arc<CudaContext>,
handles: RegistrationHandles,
_storage_type: DeviceStorageType,
}
impl Local for DeviceStorage {}
impl CudaAccessible for DeviceStorage {}
impl DeviceStorage {
pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = unsafe { cudarc::driver::result::malloc_sync(size).map_err(StorageError::Cuda)? };
Ok(Self {
ptr,
size,
ctx: ctx.clone(),
handles: RegistrationHandles::new(),
_storage_type: DeviceStorageType::Owned,
})
}
pub fn new_from_torch(
ctx: &Arc<CudaContext>,
tensor: Arc<dyn TorchTensor>,
) -> Result<Self, StorageError> {
let device = tensor.device();
let TorchDevice::Cuda(device_id) = device else {
return Err(StorageError::InvalidConfig("Tensor is not CUDA!".into()));
};
if device_id != ctx.cu_device() as usize {
return Err(StorageError::InvalidConfig(
"Tensor is not on the same device as the context!".into(),
));
}
let data_ptr = tensor.data_ptr();
let size = tensor.size_bytes();
Ok(Self {
ptr: data_ptr,
size,
ctx: ctx.clone(),
handles: RegistrationHandles::new(),
_storage_type: DeviceStorageType::Torch { _tensor: tensor },
})
}
pub fn context(&self) -> &Arc<CudaContext> {
&self.ctx
}
}
impl Storage for DeviceStorage {
fn storage_type(&self) -> StorageType {
StorageType::Device(self.ctx.cu_device() as u32)
}
fn addr(&self) -> u64 {
self.ptr
}
fn size(&self) -> usize {
self.size
}
unsafe fn as_ptr(&self) -> *const u8 {
self.ptr as *const u8
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr as *mut u8
}
}
impl CudaContextProivder for DeviceStorage {
fn cuda_context(&self) -> &Arc<CudaContext> {
&self.ctx
}
}
impl Drop for DeviceStorage {
fn drop(&mut self) {
self.handles.release();
match &self._storage_type {
DeviceStorageType::Owned => {
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap()
}
DeviceStorageType::Torch { _tensor } => {
}
}
}
}
impl RegisterableStorage for DeviceStorage {
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
self.handles.register(key, handle)
}
fn is_registered(&self, key: &str) -> bool {
self.handles.is_registered(key)
}
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.registration_handle(key)
}
}
pub struct DeviceAllocator {
ctx: Arc<CudaContext>,
}
impl Default for DeviceAllocator {
fn default() -> Self {
Self {
ctx: CudaContext::new(0).expect("Failed to create CUDA context"),
}
}
}
impl DeviceAllocator {
pub fn new(device_id: usize) -> Result<Self, StorageError> {
Ok(Self {
ctx: Cuda::device_or_create(device_id)?,
})
}
pub fn ctx(&self) -> &Arc<CudaContext> {
&self.ctx
}
}
impl StorageAllocator<DeviceStorage> for DeviceAllocator {
fn allocate(&self, size: usize) -> Result<DeviceStorage, StorageError> {
DeviceStorage::new(&self.ctx, size)
}
}
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
use super::*;
#[derive(Debug, Clone)]
struct MockTensor {
device: TorchDevice,
data_ptr: u64,
size_bytes: usize,
}
impl MockTensor {
pub fn new(device: TorchDevice, data_ptr: u64, size_bytes: usize) -> Self {
Self {
device,
data_ptr,
size_bytes,
}
}
}
impl TorchTensor for MockTensor {
fn device(&self) -> TorchDevice {
self.device.clone()
}
fn data_ptr(&self) -> u64 {
self.data_ptr
}
fn size_bytes(&self) -> usize {
self.size_bytes
}
fn shape(&self) -> Vec<usize> {
vec![self.size_bytes]
}
fn stride(&self) -> Vec<usize> {
vec![1]
}
}
#[test]
fn test_device_storage_from_torch_valid_tensor() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
let actual_storage =
std::mem::ManuallyDrop::new(DeviceStorage::new(&ctx, size_bytes).unwrap());
let tensor = MockTensor::new(TorchDevice::Cuda(0), actual_storage.addr(), size_bytes);
let storage = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor)).unwrap();
assert_eq!(storage.size(), size_bytes);
assert_eq!(storage.storage_type(), StorageType::Device(0));
assert_eq!(storage.addr(), actual_storage.addr());
}
#[test]
fn test_device_storage_from_torch_cpu_tensor_fails() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
let actual_storage = DeviceStorage::new(&ctx, size_bytes).unwrap();
let tensor = MockTensor::new(
TorchDevice::Other("cpu".to_string()),
actual_storage.addr(),
size_bytes,
);
let result = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor));
assert!(result.is_err());
if let Err(StorageError::InvalidConfig(msg)) = result {
assert!(msg.contains("Tensor is not CUDA"));
} else {
panic!("Expected InvalidConfig error for CPU tensor");
}
}
#[test]
fn test_device_storage_wrong_device() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
let actual_storage = DeviceStorage::new(&ctx, size_bytes).unwrap();
let tensor = MockTensor::new(TorchDevice::Cuda(1), actual_storage.addr(), size_bytes);
let result = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor));
assert!(result.is_err());
}
#[allow(deprecated)]
#[test]
fn test_pinned_storage_new_without_numa() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size = 8192;
let mut storage =
PinnedStorage::new(&ctx, size).expect("PinnedStorage::new should succeed");
assert_eq!(storage.size(), size);
assert_eq!(storage.storage_type(), StorageType::Pinned);
assert_ne!(storage.addr(), 0, "Address should be non-zero");
unsafe {
let ptr = storage.as_mut_ptr();
assert!(!ptr.is_null(), "Pointer should not be null");
for i in 0..size {
std::ptr::write_volatile(ptr.add(i), (i & 0xFF) as u8);
}
for i in 0..size {
let val = std::ptr::read_volatile(ptr.add(i));
assert_eq!(
val,
(i & 0xFF) as u8,
"Memory content mismatch at offset {}",
i
);
}
}
}
}