#![allow(missing_docs)]
pub mod arena;
pub mod cuda;
pub mod disk;
pub mod nixl;
pub mod object;
pub mod torch;
pub use cuda::*;
pub use disk::*;
pub use object::ObjectStorage;
use torch::*;
use std::{
alloc::{Layout, alloc_zeroed, dealloc},
collections::HashMap,
fmt::Debug,
ptr::NonNull,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Convenience alias for a `Result` whose error type is [`StorageError`].
pub type StorageResult<T> = std::result::Result<T, StorageError>;
/// Tag describing what kind of memory or medium backs a [`Storage`] implementation.
///
/// The type is `Copy`, comparable, hashable and serde-serializable, so it can be used
/// as a map key or embedded in serialized descriptors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum StorageType {
/// Host memory; this is the type reported by [`SystemStorage`].
System,
/// Device memory.
/// NOTE(review): the `u32` is presumably the device ordinal — confirm against the `cuda` module.
Device(u32),
/// NOTE(review): presumably page-locked (pinned) host memory — confirm against the `cuda` module.
Pinned,
/// Disk-backed storage.
/// NOTE(review): the meaning of the `u64` payload is not visible here — confirm against the `disk` module.
Disk(u64),
/// Storage associated with NIXL.
/// NOTE(review): exact semantics live in the `nixl` module — confirm.
Nixl,
/// Type reported by the null storages in the `tests` module, which own no memory.
Null,
}
/// Marker trait with no methods; [`SystemStorage`] implements it.
/// NOTE(review): presumably tags storage that lives on the local node — confirm intended meaning.
pub trait Local {}
/// Marker trait with no methods.
/// NOTE(review): presumably the counterpart of [`Local`] for storage owned by another node — confirm.
pub trait Remote {}
/// Marker trait with no methods; [`SystemStorage`] implements it.
/// NOTE(review): presumably tags storage whose memory can be dereferenced by host code — confirm.
pub trait SystemAccessible {}
/// Marker trait with no methods.
/// NOTE(review): presumably tags storage whose memory can be accessed by CUDA — confirm.
pub trait CudaAccessible {}
/// Errors reported by storage allocation, access and registration operations.
///
/// `Display` output for each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
#[allow(missing_docs)]
pub enum StorageError {
/// An allocation request could not be satisfied; the payload describes the cause.
#[error("Storage allocation failed: {0}")]
AllocationFailed(String),
/// The storage cannot be accessed in the requested way; the payload describes why.
#[error("Storage not accessible: {0}")]
NotAccessible(String),
/// A storage configuration was rejected; the payload describes the problem.
#[error("Invalid storage configuration: {0}")]
InvalidConfig(String),
/// A storage operation failed; the payload describes the operation and the reason.
/// [`SystemStorage`]'s `memset` uses this variant for out-of-range requests.
#[error("Storage operation failed: {0}")]
OperationFailed(String),
/// A CUDA driver error; `#[from]` lets `?` convert a `DriverError` into this variant.
#[error("CUDA error: {0}")]
Cuda(#[from] cudarc::driver::DriverError),
/// A registration was attempted with a key that is already in use; the payload is that key.
#[error("Registration key already exists: {0}")]
RegistrationKeyExists(String),
/// No registration handle exists for the given key; the payload is that key.
#[error("Handle not found for key: {0}")]
HandleNotFound(String),
/// A NIXL error; `#[from]` lets `?` convert a `nixl_sys::NixlError` into this variant.
#[error("NIXL error: {0}")]
NixlError(#[from] nixl_sys::NixlError),
/// An offset or range fell outside the storage; the payload describes the violation.
#[error("Out of bounds: {0}")]
OutOfBounds(String),
}
/// Core storage abstraction: reports the kind, address and size of a storage region and
/// hands out raw pointers to it.
///
/// Implementors must be `Debug + Send + Sync + 'static`.
pub trait Storage: Debug + Send + Sync + 'static {
/// Returns the [`StorageType`] tag describing what backs this storage.
fn storage_type(&self) -> StorageType;
/// Returns the address of the storage as an integer.
/// For [`SystemStorage`] this is the numeric value of its base pointer.
fn addr(&self) -> u64;
/// Returns the size of the storage in bytes.
fn size(&self) -> usize;
/// Returns a raw const pointer to the storage.
///
/// # Safety
/// NOTE(review): the contract is not spelled out in code. Presumably the caller must keep
/// the storage alive while the pointer is in use and only dereference it from a context
/// that can access the memory — confirm. Implementations may return a null pointer
/// (the null storages in the `tests` module do).
unsafe fn as_ptr(&self) -> *const u8;
/// Returns a raw mutable pointer to the storage.
///
/// # Safety
/// NOTE(review): presumably the same requirements as `as_ptr`, plus the usual aliasing
/// rules for writes through the pointer — confirm. Implementations may return a null
/// pointer (the null storages in the `tests` module do).
unsafe fn as_mut_ptr(&mut self) -> *mut u8;
}
/// Associates an implementor with a concrete [`Storage`] type.
pub trait StorageTypeProvider {
/// The concrete storage type this provider is tied to.
type StorageType: Storage;
/// Returns the `TypeId` of the associated `StorageType`.
///
/// The default implementation ignores `self` and depends only on the associated type.
fn storage_type_id(&self) -> std::any::TypeId {
std::any::TypeId::of::<Self::StorageType>()
}
}
/// Extension of [`Storage`] for backends that can fill a byte range with a value.
pub trait StorageMemset: Storage {
/// Fills `size` bytes starting at byte `offset` with `value`.
///
/// # Errors
/// Returns a [`StorageError`] when the request cannot be carried out. The
/// [`SystemStorage`] implementation reports a range that exceeds the storage size as
/// [`StorageError::OperationFailed`].
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError>;
}
/// Extension of [`Storage`] for backends that can hold keyed registration handles.
///
/// The `Send + Sync + 'static` bounds repeat what [`Storage`] already requires.
pub trait RegisterableStorage: Storage + Send + Sync + 'static {
/// Attaches `handle` to this storage under `key`.
///
/// # Errors
/// Returns a [`StorageError`] if the handle cannot be stored. [`SystemStorage`] returns
/// [`StorageError::RegistrationKeyExists`] when `key` is already registered.
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError>;
/// Returns `true` if a handle is currently stored under `key`.
fn is_registered(&self, key: &str) -> bool;
/// Returns a reference to the handle stored under `key`, or `None` if there is none.
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle>;
}
/// A handle representing one registration of a storage.
///
/// The trait name is spelled `RegistationHandle` (sic); it is part of the public API.
/// Handles must be `Any + Send + Sync + 'static`.
pub trait RegistationHandle: std::any::Any + Send + Sync + 'static {
/// Releases whatever this handle represents.
///
/// [`RegistrationHandles`] calls this on every stored handle before discarding it.
fn release(&mut self);
}
/// Keyed collection of boxed [`RegistationHandle`]s.
///
/// The collection must be emptied through its `release` method before it is dropped;
/// its `Drop` implementation panics if handles remain.
#[derive(Default)]
pub struct RegistrationHandles {
// Registration key -> handle stored under that key.
handles: HashMap<String, Box<dyn RegistationHandle>>,
}
impl RegistrationHandles {
    /// Creates an empty collection with no registered handles.
    pub fn new() -> Self {
        Self {
            handles: HashMap::new(),
        }
    }

    /// Stores `handle` under `key`.
    ///
    /// # Errors
    /// Returns [`StorageError::RegistrationKeyExists`] (carrying the key) when `key` is
    /// already present. The existing handle is left untouched and the rejected `handle`
    /// is dropped without `release()` being called on it.
    pub fn register(
        &mut self,
        key: &str,
        handle: Box<dyn RegistationHandle>,
    ) -> Result<(), StorageError> {
        use std::collections::hash_map::Entry;

        // The entry API does the duplicate check and the insertion with one hash lookup
        // instead of a `contains_key` followed by an `insert`.
        match self.handles.entry(key.to_string()) {
            Entry::Occupied(existing) => {
                Err(StorageError::RegistrationKeyExists(existing.key().clone()))
            }
            Entry::Vacant(slot) => {
                slot.insert(handle);
                Ok(())
            }
        }
    }

    /// Calls `release()` on every stored handle, then removes them all.
    ///
    /// This must run before the collection is dropped; dropping a non-empty collection
    /// panics.
    fn release(&mut self) {
        self.handles
            .values_mut()
            .for_each(|handle| handle.release());
        self.handles.clear();
    }

    /// Returns `true` if a handle is stored under `key`.
    fn is_registered(&self, key: &str) -> bool {
        self.handles.contains_key(key)
    }

    /// Returns the handle stored under `key`, or `None` if the key is unknown.
    fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
        self.handles.get(key).map(|boxed| &**boxed)
    }
}
impl std::fmt::Debug for RegistrationHandles {
    /// Formats the collection as `RegistrationHandles { count: N }`, where `N` is the
    /// number of stored handles. The handles themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.handles.len();
        f.write_fmt(format_args!(
            "RegistrationHandles {{ count: {:?} }}",
            count
        ))
    }
}
impl Drop for RegistrationHandles {
    /// Enforces that `RegistrationHandles::release()` ran before the collection is dropped.
    ///
    /// # Panics
    /// Panics if handles are still stored, unless the current thread is already
    /// unwinding from another panic. In that case the check is skipped and the remaining
    /// handles are dropped without `release()` being called on them.
    fn drop(&mut self) {
        if self.handles.is_empty() {
            return;
        }
        // Panicking while the thread is already unwinding is a double panic, which
        // aborts the process and hides the failure that started the unwind.
        if std::thread::panicking() {
            return;
        }
        panic!(
            "RegistrationHandles dropped with {} handles remaining; RegistrationHandles::release() needs to be explicitly called",
            self.handles.len()
        );
    }
}
/// Factory for storages of type `S`.
///
/// Allocators must be `Send + Sync`; `allocate` takes `&self`.
pub trait StorageAllocator<S: Storage>: Send + Sync {
/// Allocates a storage of `size` bytes.
///
/// # Errors
/// Returns a [`StorageError`] if the storage cannot be created.
fn allocate(&self, size: usize) -> Result<S, StorageError>;
}
/// Host-memory storage backed by a zero-initialized allocation from the global allocator.
///
/// The allocation is made in `SystemStorage::new` and freed in `Drop`, after all
/// registration handles have been released.
#[derive(Debug)]
pub struct SystemStorage {
// Base pointer of the allocation.
ptr: NonNull<u8>,
// Layout used for the allocation; the same layout is passed to `dealloc` on drop.
layout: Layout,
// Size in bytes requested at construction; this is what `size()` reports.
len: usize,
// Registration handles attached to this storage.
handles: RegistrationHandles,
}
// SAFETY: NOTE(review): soundness rests on `SystemStorage` being the only owner of the
// allocation behind `ptr`; the stored handles are `Send + Sync` by the bounds on
// `RegistationHandle` — confirm no other alias to the allocation is kept.
unsafe impl Send for SystemStorage {}
// SAFETY: NOTE(review): the safe methods that write through `ptr` take `&mut self`;
// raw-pointer access goes through the `unsafe` `Storage` accessors — confirm callers
// uphold exclusivity when sharing `&SystemStorage` across threads.
unsafe impl Sync for SystemStorage {}
impl Local for SystemStorage {}
impl SystemAccessible for SystemStorage {}
impl SystemStorage {
    /// Allocates `size` bytes of zero-initialized host memory.
    ///
    /// A request for zero bytes succeeds: the storage reports a size of `0` but is backed
    /// by a one-byte allocation, because the global allocator must never be asked for a
    /// zero-sized block.
    ///
    /// # Errors
    /// Returns [`StorageError::AllocationFailed`] if `size` does not form a valid
    /// allocation layout or if the allocator returns a null pointer.
    pub fn new(size: usize) -> Result<Self, StorageError> {
        // Calling `alloc_zeroed` with a zero-sized layout is undefined behavior, so
        // always request at least one byte. `layout` records what was really allocated,
        // which keeps the `dealloc(self.ptr, self.layout)` in `Drop` matched.
        let layout = Layout::array::<u8>(size.max(1))
            .map_err(|e| StorageError::AllocationFailed(e.to_string()))?;

        // SAFETY: `layout` has a non-zero size (at least one byte).
        let raw = unsafe { alloc_zeroed(layout) };
        let ptr = NonNull::new(raw)
            .ok_or_else(|| StorageError::AllocationFailed("memory allocation failed".into()))?;

        Ok(Self {
            ptr,
            layout,
            // Report the size the caller asked for, not the padded allocation size.
            len: size,
            handles: RegistrationHandles::new(),
        })
    }
}
impl Drop for SystemStorage {
    /// Releases every registration handle, then frees the backing allocation.
    fn drop(&mut self) {
        // Handles get their `release()` call while the memory is still allocated.
        self.handles.release();

        let base = self.ptr.as_ptr();
        let layout = self.layout;
        // SAFETY: `base` and `layout` are the pointer and layout stored together when
        // this storage was constructed, and the allocation is freed only here.
        unsafe { dealloc(base, layout) }
    }
}
impl Storage for SystemStorage {
    /// Always reports [`StorageType::System`].
    fn storage_type(&self) -> StorageType {
        StorageType::System
    }

    /// Numeric value of the allocation's base pointer.
    fn addr(&self) -> u64 {
        let base: *mut u8 = self.ptr.as_ptr();
        base as u64
    }

    /// Size in bytes that was requested when the storage was created.
    fn size(&self) -> usize {
        self.len
    }

    /// Base pointer of the allocation as a const pointer.
    ///
    /// # Safety
    /// The returned pointer is only meaningful while this storage is alive; the memory
    /// is freed when the storage is dropped.
    unsafe fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr().cast_const()
    }

    /// Base pointer of the allocation as a mutable pointer.
    ///
    /// # Safety
    /// The returned pointer is only meaningful while this storage is alive; the memory
    /// is freed when the storage is dropped.
    unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        let base: *mut u8 = self.ptr.as_ptr();
        base
    }
}
impl StorageMemset for SystemStorage {
    /// Fills `size` bytes starting at byte `offset` with `value`.
    ///
    /// # Errors
    /// Returns [`StorageError::OperationFailed`] when `offset + size` exceeds the storage
    /// size, including when that sum overflows `usize`. Nothing is written in that case.
    fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError> {
        // `checked_add` keeps a huge `offset`/`size` pair from wrapping around to a small
        // value that would pass the bounds check and then write out of bounds.
        let in_bounds = matches!(offset.checked_add(size), Some(end) if end <= self.len);
        if !in_bounds {
            return Err(StorageError::OperationFailed(
                "memset: offset + size > storage size".into(),
            ));
        }

        // SAFETY: `offset + size <= self.len` was verified above without overflow, so
        // the written range lies inside the allocation behind `self.ptr`.
        unsafe {
            let start = self.ptr.as_ptr().add(offset);
            std::ptr::write_bytes(start, value, size);
        }
        Ok(())
    }
}
// Registration support for `SystemStorage`; every method delegates to the embedded
// `RegistrationHandles` collection.
impl RegisterableStorage for SystemStorage {
/// Stores `handle` under `key`.
///
/// # Errors
/// Returns [`StorageError::RegistrationKeyExists`] if `key` is already registered.
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
self.handles.register(key, handle)
}
/// Returns `true` if a handle is stored under `key`.
fn is_registered(&self, key: &str) -> bool {
self.handles.is_registered(key)
}
/// Returns the handle stored under `key`, or `None` if there is none.
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.registration_handle(key)
}
}
/// Stateless allocator that produces [`SystemStorage`] instances.
#[derive(Debug, Default, Clone, Copy)]
pub struct SystemAllocator;
impl StorageAllocator<SystemStorage> for SystemAllocator {
/// Allocates a new [`SystemStorage`] of `size` bytes by calling `SystemStorage::new`.
///
/// # Errors
/// Propagates any [`StorageError`] returned by `SystemStorage::new`.
fn allocate(&self, size: usize) -> Result<SystemStorage, StorageError> {
SystemStorage::new(size)
}
}
/// Null storage and allocator stand-ins that record a size but own no memory.
///
/// Note: the module is `pub` and carries no `#[cfg(test)]` attribute here, so it is
/// compiled into regular builds as well.
#[allow(missing_docs)]
pub mod tests {
use super::*;
/// Storage stub that remembers a size and has no backing memory.
/// NOTE(review): presumably stands in for device memory in tests — confirm with callers.
#[derive(Debug)]
pub struct NullDeviceStorage {
// Size in bytes reported by `size()`.
size: u64,
}
impl NullDeviceStorage {
/// Creates a stub that reports `size` bytes.
pub fn new(size: u64) -> Self {
Self { size }
}
}
impl Storage for NullDeviceStorage {
/// Always reports [`StorageType::Null`].
fn storage_type(&self) -> StorageType {
StorageType::Null
}
/// Always `0`; there is no backing memory.
fn addr(&self) -> u64 {
0
}
/// The size given at construction, converted from `u64` with an `as usize` cast.
fn size(&self) -> usize {
self.size as usize
}
/// Always returns a null pointer; it must not be dereferenced.
unsafe fn as_ptr(&self) -> *const u8 {
std::ptr::null()
}
/// Always returns a null pointer; it must not be dereferenced.
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
std::ptr::null_mut()
}
}
/// Allocator that hands out [`NullDeviceStorage`] stubs.
pub struct NullDeviceAllocator;
impl StorageAllocator<NullDeviceStorage> for NullDeviceAllocator {
/// Returns a stub reporting `size` bytes; this never fails.
fn allocate(&self, size: usize) -> Result<NullDeviceStorage, StorageError> {
Ok(NullDeviceStorage::new(size as u64))
}
}
/// Storage stub that remembers a size and has no backing memory.
/// NOTE(review): presumably stands in for host memory in tests — confirm with callers.
#[derive(Debug)]
pub struct NullHostStorage {
// Size in bytes reported by `size()`.
size: u64,
}
impl NullHostStorage {
/// Creates a stub that reports `size` bytes.
pub fn new(size: u64) -> Self {
Self { size }
}
}
impl Storage for NullHostStorage {
/// Always reports [`StorageType::Null`].
fn storage_type(&self) -> StorageType {
StorageType::Null
}
/// Always `0`; there is no backing memory.
fn addr(&self) -> u64 {
0
}
/// The size given at construction, converted from `u64` with an `as usize` cast.
fn size(&self) -> usize {
self.size as usize
}
/// Always returns a null pointer; it must not be dereferenced.
unsafe fn as_ptr(&self) -> *const u8 {
std::ptr::null()
}
/// Always returns a null pointer; it must not be dereferenced.
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
std::ptr::null_mut()
}
}
/// Allocator that hands out [`NullHostStorage`] stubs.
pub struct NullHostAllocator;
impl StorageAllocator<NullHostStorage> for NullHostAllocator {
/// Returns a stub reporting `size` bytes; this never fails.
fn allocate(&self, size: usize) -> Result<NullHostStorage, StorageError> {
Ok(NullHostStorage::new(size as u64))
}
}
}