pub use nixl_sys::{
Agent as NixlAgent, MemType, MemoryRegion, NixlDescriptor, OptArgs,
RegistrationHandle as NixlRegistrationHandle,
};
use derive_getters::Getters;
use serde::{Deserialize, Serialize};
use super::{
CudaContextProivder, DeviceStorage, PinnedStorage, RegistationHandle, RegisterableStorage,
Remote, Storage, StorageError, StorageType, SystemStorage,
};
/// Empty marker trait for storage whose memory can be reached through NIXL.
///
/// It has no methods; presumably it is used as a trait bound by code that needs
/// NIXL-addressable memory (verify against call sites).
pub trait NixlAccessible {}

impl StorageType {
    /// NIXL memory kind used when describing storage of this type.
    ///
    /// `System` and `Pinned` map to DRAM, `Device` maps to VRAM, and `Nixl` and
    /// `Null` map to `Unknown`.
    pub fn nixl_mem_type(&self) -> MemType {
        match self {
            Self::System | Self::Pinned => MemType::Dram,
            Self::Device(_) => MemType::Vram,
            Self::Nixl | Self::Null => MemType::Unknown,
        }
    }

    /// Device id reported to NIXL: the id carried by `Device`, and `0` for every
    /// other storage type.
    pub fn nixl_device_id(&self) -> u64 {
        // Kept exhaustive (no `_` arm) so adding a StorageType variant forces a
        // decision here.
        match *self {
            Self::Device(ordinal) => ordinal as u64,
            Self::System | Self::Pinned | Self::Nixl | Self::Null => 0,
        }
    }
}
impl RegistationHandle for NixlRegistrationHandle {
fn release(&mut self) {
if let Err(e) = self.deregister() {
tracing::error!("Failed to deregister Nixl storage: {}", e);
}
}
}
/// Key under which the NIXL registration handle is stored in
/// [`RegisterableStorage`]'s handle bookkeeping.
const NIXL_REGISTRATION_KEY: &str = "nixl";

/// Storage that can be registered with a NIXL agent.
///
/// All behaviour comes from default methods. Implementors only need the
/// supertraits: [`RegisterableStorage`] stores the registration handle and
/// [`NixlDescriptor`] describes the region (address, memory type, device) to NIXL.
pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized {
    /// Registers this storage's memory with `agent` and stores the resulting
    /// handle under the `"nixl"` key.
    ///
    /// # Errors
    ///
    /// Fails if the agent rejects the registration (its error is converted into
    /// [`StorageError`]), or if [`RegisterableStorage::register`] refuses the
    /// handle. What `register` does for an already-used key is defined by that
    /// trait, not here.
    fn nixl_register(
        &mut self,
        agent: &NixlAgent,
        opt_args: Option<&OptArgs>,
    ) -> Result<(), StorageError> {
        let handle = agent.register_memory(self, opt_args)?;
        self.register(NIXL_REGISTRATION_KEY, Box::new(handle))
    }

    /// Whether a handle is currently stored under the `"nixl"` key.
    fn is_nixl_registered(&self) -> bool {
        self.is_registered(NIXL_REGISTRATION_KEY)
    }

    /// Name of the NIXL agent this storage is registered with.
    ///
    /// Returns `None` in three cases:
    /// - no handle is stored under the `"nixl"` key;
    /// - the stored handle is not a [`NixlRegistrationHandle`];
    /// - the handle itself reports no agent name.
    fn nixl_agent_name(&self) -> Option<String> {
        let handle = self.registration_handle(NIXL_REGISTRATION_KEY)?;
        // Upcast the handle trait object to `Any` to recover the concrete NIXL type.
        // NOTE(review): this relies on `RegistationHandle` having `Any` as a
        // supertrait. If `registration_handle` ever yielded `&Box<dyn ...>`, the
        // cast would target the `Box` and the downcast would always return `None`.
        let nixl_handle =
            (handle as &dyn std::any::Any).downcast_ref::<NixlRegistrationHandle>()?;
        nixl_handle.agent_name()
    }

    /// Snapshot of this storage as a serializable [`NixlStorage`] descriptor, or
    /// `None` if it has not been registered with NIXL.
    ///
    /// # Safety
    ///
    /// The returned descriptor copies the raw address and does not borrow `self`,
    /// so nothing ties its lifetime to the underlying allocation. The caller must
    /// keep the memory allocated for as long as the descriptor is used. It should
    /// presumably also stay NIXL-registered for that time.
    unsafe fn as_nixl_descriptor(&self) -> Option<NixlStorage> {
        self.is_nixl_registered().then(|| NixlStorage {
            addr: self.addr(),
            // Disambiguate: both `Storage` and `MemoryRegion` provide `size`.
            size: MemoryRegion::size(self),
            mem_type: self.mem_type(),
            device_id: self.device_id(),
        })
    }
}
/// Serializable description of a NIXL memory region.
///
/// Holds plain values only: it does not own or borrow the memory it describes.
/// `NixlRegisterableStorage::as_nixl_descriptor` builds one from registered
/// storage.
///
/// NOTE(review): `derive(Getters)` presumably generates inherent `addr()`,
/// `size()`, `mem_type()` and `device_id()` accessors returning references. In
/// method-call syntax these would take precedence over the same-named
/// `Storage`/`NixlDescriptor` methods below. Use `Storage::addr(&x)` and the like
/// when the trait version is intended. Field order also fixes the serialized
/// layout, so do not reorder fields casually.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Getters)]
pub struct NixlStorage {
    /// Start address of the region (stored as an integer, never dereferenced here).
    addr: u64,
    /// Length of the region, taken from `MemoryRegion::size` (presumably bytes).
    size: usize,
    /// NIXL memory kind of the region (e.g. DRAM or VRAM).
    mem_type: MemType,
    /// Device the region belongs to. The host-memory impls in this module report 0.
    device_id: u64,
}
// `Remote` marker: presumably flags the region as possibly owned by another agent.
impl Remote for NixlStorage {}
// A NIXL descriptor is by definition addressable through NIXL.
impl NixlAccessible for NixlStorage {}
// A NIXL descriptor acts as storage by reporting its recorded address and size.
// It is always classified as `StorageType::Nixl`.
impl Storage for NixlStorage {
    /// Recorded start address.
    fn addr(&self) -> u64 {
        self.addr
    }

    /// Recorded region length.
    fn size(&self) -> usize {
        self.size
    }

    /// Recorded address reinterpreted as a read-only byte pointer.
    ///
    /// The pointer may refer to memory outside this process; it is not checked here.
    unsafe fn as_ptr(&self) -> *const u8 {
        self.addr as *const u8
    }

    /// Recorded address reinterpreted as a mutable byte pointer.
    ///
    /// Same caveat as [`Storage::as_ptr`]: nothing here validates the address.
    unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        self.addr as *mut u8
    }

    /// Descriptors always report the `Nixl` storage type.
    fn storage_type(&self) -> StorageType {
        StorageType::Nixl
    }
}
// NIXL's view of the described region: the recorded length and the recorded
// address as a byte pointer.
impl MemoryRegion for NixlStorage {
    fn size(&self) -> usize {
        self.size
    }

    unsafe fn as_ptr(&self) -> *const u8 {
        // Pure integer-to-pointer reinterpretation; no memory is touched.
        self.addr as *const u8
    }
}
// A descriptor reports back exactly the memory kind and device it was built with.
impl NixlDescriptor for NixlStorage {
    fn device_id(&self) -> u64 {
        self.device_id
    }

    fn mem_type(&self) -> MemType {
        self.mem_type
    }
}
// Plain host memory: registrable with NIXL via the trait's default methods.
// NOTE(review): unlike PinnedStorage/DeviceStorage, no `NixlAccessible` impl for
// SystemStorage is visible in this section. Confirm whether that is intentional.
impl NixlRegisterableStorage for SystemStorage {}

// Described to NIXL as DRAM on device 0.
impl NixlDescriptor for SystemStorage {
    fn device_id(&self) -> u64 {
        0
    }

    fn mem_type(&self) -> MemType {
        MemType::Dram
    }
}

// Reads the allocation's fields directly. The pinned and device impls instead
// forward to their `Storage` impls.
impl MemoryRegion for SystemStorage {
    fn size(&self) -> usize {
        self.len
    }

    unsafe fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr()
    }
}
impl NixlAccessible for PinnedStorage {}
impl NixlRegisterableStorage for PinnedStorage {}
impl MemoryRegion for PinnedStorage {
unsafe fn as_ptr(&self) -> *const u8 {
Storage::as_ptr(self)
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for PinnedStorage {
fn mem_type(&self) -> MemType {
MemType::Dram
}
fn device_id(&self) -> u64 {
0
}
}
impl NixlAccessible for DeviceStorage {}
impl NixlRegisterableStorage for DeviceStorage {}
impl MemoryRegion for DeviceStorage {
unsafe fn as_ptr(&self) -> *const u8 {
Storage::as_ptr(self)
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for DeviceStorage {
fn mem_type(&self) -> MemType {
MemType::Vram
}
fn device_id(&self) -> u64 {
CudaContextProivder::cuda_context(self).cu_device() as u64
}
}