pub use nixl_sys::{
Agent as NixlAgent, MemType, MemoryRegion, NixlDescriptor, OptArgs,
RegistrationHandle as NixlRegistrationHandle,
};
use derive_getters::Getters;
use serde::{Deserialize, Serialize};
use super::{
CudaContextProivder, DeviceStorage, DiskStorage, PinnedStorage, RegistationHandle,
RegisterableStorage, Remote, Storage, StorageError, StorageType, SystemStorage,
};
/// A [`NixlStorage`] descriptor paired with the name of the NIXL agent it belongs to and an
/// optional notification string.
///
/// Field declaration order determines serde's serialized field order; for positional formats
/// (e.g. bincode) reordering the fields breaks compatibility with already-serialized data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NixlRemoteDescriptor {
/// The memory region being described.
storage: NixlStorage,
/// Name of the NIXL agent, returned by `agent_name()`.
agent: String,
/// Optional notification payload; presumably delivered to the remote agent alongside a
/// transfer — TODO confirm against callers.
notif: Option<String>,
}
impl NixlRemoteDescriptor {
    /// Wraps `storage` with the name of its agent; the descriptor starts without a notification.
    pub(crate) fn new(storage: NixlStorage, agent: String) -> Self {
        Self {
            agent,
            storage,
            notif: None,
        }
    }

    /// Length of the described region, read straight from the wrapped [`NixlStorage`].
    pub fn size(&self) -> usize {
        self.storage.size
    }

    /// Attaches a notification string, replacing any previous one.
    pub fn set_notif(&mut self, notif: String) {
        self.notif = Some(notif);
    }

    /// Drops the current notification string, if any.
    pub fn clear_notif(&mut self) {
        self.notif = None;
    }

    /// Returns an owned copy of the notification string, or `None` if none is set.
    pub fn get_notif(&self) -> Option<String> {
        self.notif.as_deref().map(str::to_owned)
    }

    /// Name of the agent this descriptor refers to.
    pub fn agent_name(&self) -> &str {
        &self.agent
    }
}
/// Marker trait (no methods) tagging storage types that are presumably usable as NIXL
/// transfer endpoints.
///
/// Implemented in this module for `NixlStorage`, `PinnedStorage`, `DeviceStorage` and
/// `DiskStorage`. NOTE(review): `SystemStorage` implements `NixlRegisterableStorage` here but
/// no `NixlAccessible` impl for it is visible — confirm whether that is intentional.
pub trait NixlAccessible {}
impl StorageType {
    /// Maps this storage kind onto the memory type NIXL expects.
    ///
    /// Host allocations (plain and pinned) are DRAM, device allocations are VRAM, disk-backed
    /// storage is a file, and the descriptor-only / null kinds have no concrete memory type.
    pub fn nixl_mem_type(&self) -> MemType {
        match self {
            Self::System | Self::Pinned => MemType::Dram,
            Self::Device(_) => MemType::Vram,
            Self::Disk(_) => MemType::File,
            Self::Nixl | Self::Null => MemType::Unknown,
        }
    }
}
impl RegistationHandle for NixlRegistrationHandle {
fn release(&mut self) {
if let Err(e) = self.deregister() {
tracing::error!("Failed to deregister Nixl storage: {}", e);
}
}
}
/// Registers `storage` with `agent` and stores the resulting NIXL handle on the storage under
/// the `"nixl"` key.
///
/// # Errors
/// Propagates the error from `agent.register_memory` (converted into [`StorageError`] via
/// `?`) and whatever `storage.register` returns.
fn handle_nixl_register<S: NixlRegisterableStorage>(
    storage: &mut S,
    agent: &NixlAgent,
    opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> {
    let nixl_handle = agent.register_memory(storage, opt_args)?;
    storage.register("nixl", Box::new(nixl_handle))
}
/// Storage that can be registered with a NIXL agent.
///
/// Implementors are [`RegisterableStorage`] (the NIXL handle is stored on the storage under
/// the `"nixl"` key) and [`NixlDescriptor`] (address, size, memory type and device id are
/// read from it when building descriptors).
pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized {
/// Registers this storage with `agent` and stores the resulting handle under `"nixl"`.
///
/// The default delegates to `handle_nixl_register`; implementors may override it
/// (`DiskStorage` in this module does, to add a pre-check and an unlink step).
///
/// # Errors
/// Returns a [`StorageError`] if NIXL registration or storing the handle fails.
fn nixl_register(
&mut self,
agent: &NixlAgent,
opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> {
handle_nixl_register(self, agent, opt_args)
}
/// Returns `true` if a registration handle is stored under the `"nixl"` key.
fn is_nixl_registered(&self) -> bool {
self.is_registered("nixl")
}
/// Name of the NIXL agent this storage was registered with, if it can be determined.
///
/// Returns `None` when no `"nixl"` handle is stored, when the stored handle does not
/// downcast to [`NixlRegistrationHandle`], or when that handle reports no agent name.
///
/// NOTE(review): if `registration_handle` returns `&Box<dyn RegistationHandle>` (as the
/// `handle_box` name suggests), casting it to `&dyn Any` yields an `Any` whose concrete type
/// is the `Box` itself, so `downcast_ref::<NixlRegistrationHandle>()` would always be `None`.
/// The boxed value (`&**handle_box`) would have to be downcast instead — confirm the return
/// type of `registration_handle`.
fn nixl_agent_name(&self) -> Option<String> {
self.registration_handle("nixl")
.and_then(|handle_box| {
(handle_box as &dyn std::any::Any)
.downcast_ref::<NixlRegistrationHandle>()
.map(|nixl_handle| nixl_handle.agent_name())
})?
}
/// Builds a [`NixlStorage`] descriptor (address, size, memory type, device id) for this
/// storage, or returns `None` if it has not been registered with NIXL.
///
/// # Safety
/// NOTE(review): the reason this is `unsafe` is not visible here. The returned descriptor is
/// a plain copy of the address and does not borrow `self`; presumably callers must keep this
/// storage (and its NIXL registration) alive while the descriptor is in use — TODO confirm.
unsafe fn as_nixl_descriptor(&self) -> Option<NixlStorage> {
if self.is_nixl_registered() {
Some(NixlStorage {
addr: self.addr(),
// Named explicitly, presumably to avoid ambiguity with `Storage::size`.
size: MemoryRegion::size(self),
mem_type: self.mem_type(),
device_id: self.device_id(),
})
} else {
None
}
}
}
/// Plain-data NIXL memory descriptor: base address, size, NIXL memory type and device id.
///
/// It records an address rather than owning an allocation; cloning copies the address only.
///
/// `#[derive(Getters)]` generates inherent by-reference accessors (`addr() -> &u64`,
/// `size() -> &usize`, ...). Inherent methods win over the `Storage` / `MemoryRegion` /
/// `NixlDescriptor` trait methods in method-call syntax, so `desc.size()` yields `&usize`;
/// call e.g. `Storage::size(&desc)` to get the trait's by-value result.
///
/// Field order is part of the serialized form for positional serde formats.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Getters)]
pub struct NixlStorage {
/// Base address of the region.
addr: u64,
/// Length of the region (offsets are added to `addr` directly, i.e. byte units).
size: usize,
/// NIXL memory type of the region.
mem_type: MemType,
/// Device id; for disk-backed storage this slot carries the file descriptor.
device_id: u64,
}
// Marker impls: a `NixlStorage` descriptor is treated as remote and NIXL-accessible.
impl Remote for NixlStorage {}
impl NixlAccessible for NixlStorage {}
impl NixlStorage {
    /// Creates a descriptor for the sub-range `[offset, offset + size)` of `storage`.
    ///
    /// The memory type and device id are copied from `storage`; the address is
    /// `storage.addr() + offset`.
    ///
    /// # Errors
    /// Returns [`StorageError::OutOfBounds`] if the requested range does not fit inside
    /// `storage`, including when `offset + size` would overflow `usize`.
    pub(crate) fn from_storage_with_offset<S: Storage + NixlDescriptor>(
        storage: &S,
        offset: usize,
        size: usize,
    ) -> Result<Self, StorageError> {
        let total = Storage::size(storage);

        // `checked_add` rejects ranges whose end wraps past `usize::MAX`. A plain `+` panics in
        // debug builds and silently wraps in release builds, which would let an out-of-range
        // request pass the bounds check and produce a bogus address.
        let fits = matches!(offset.checked_add(size), Some(end) if end <= total);
        if !fits {
            return Err(StorageError::OutOfBounds(format!(
                "Offset: {}, Size: {}, Total Size: {}",
                offset, size, total
            )));
        }

        Ok(Self {
            // `offset <= total` is guaranteed above, so this stays within `storage`'s range
            // (assuming `storage` itself describes a valid region).
            addr: storage.addr() + offset as u64,
            size,
            mem_type: storage.mem_type(),
            device_id: storage.device_id(),
        })
    }
}
impl Storage for NixlStorage {
    /// Always [`StorageType::Nixl`]: this type only records an address and length.
    fn storage_type(&self) -> StorageType {
        StorageType::Nixl
    }

    /// The recorded base address.
    fn addr(&self) -> u64 {
        let Self { addr, .. } = self;
        *addr
    }

    /// The recorded region length.
    fn size(&self) -> usize {
        let Self { size, .. } = self;
        *size
    }

    /// Reinterprets the recorded address as a const pointer. The region may not be
    /// host-accessible (see `mem_type`), so callers must not assume it is dereferenceable.
    unsafe fn as_ptr(&self) -> *const u8 {
        let Self { addr, .. } = self;
        *addr as *const u8
    }

    /// Reinterprets the recorded address as a mutable pointer; same caveats as `as_ptr`.
    unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        let Self { addr, .. } = self;
        *addr as *mut u8
    }
}
impl MemoryRegion for NixlStorage {
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
fn size(&self) -> usize {
self.size
}
}
impl NixlDescriptor for NixlStorage {
    /// Memory type recorded when the descriptor was built.
    fn mem_type(&self) -> MemType {
        let Self { mem_type, .. } = self;
        *mem_type
    }

    /// Device id recorded when the descriptor was built.
    fn device_id(&self) -> u64 {
        let Self { device_id, .. } = self;
        *device_id
    }
}
// System (host) storage uses the trait's default `nixl_register`.
// NOTE(review): unlike Pinned/Device/Disk storage, no `NixlAccessible` impl for SystemStorage
// appears in this module — confirm whether that omission is intentional.
impl NixlRegisterableStorage for SystemStorage {}
impl MemoryRegion for SystemStorage {
    /// Base pointer of the host allocation, taken from the `ptr` field.
    unsafe fn as_ptr(&self) -> *const u8 {
        let Self { ptr, .. } = self;
        ptr.as_ptr()
    }

    /// Allocation length, taken from the `len` field.
    fn size(&self) -> usize {
        let Self { len, .. } = self;
        *len
    }
}
impl NixlDescriptor for SystemStorage {
    /// Plain host memory is presented to NIXL as DRAM.
    fn mem_type(&self) -> MemType {
        MemType::Dram
    }

    /// Host memory always reports device id `0`.
    fn device_id(&self) -> u64 {
        0u64
    }
}
// Pinned host memory: NIXL-accessible, registered via the trait's default `nixl_register`.
impl NixlAccessible for PinnedStorage {}
impl NixlRegisterableStorage for PinnedStorage {}
impl MemoryRegion for PinnedStorage {
unsafe fn as_ptr(&self) -> *const u8 {
unsafe { Storage::as_ptr(self) }
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for PinnedStorage {
    /// Pinned memory lives on the host, so NIXL sees it as DRAM.
    fn mem_type(&self) -> MemType {
        MemType::Dram
    }

    /// Host memory always reports device id `0`.
    fn device_id(&self) -> u64 {
        0u64
    }
}
// Device memory: NIXL-accessible, registered via the trait's default `nixl_register`.
impl NixlAccessible for DeviceStorage {}
impl NixlRegisterableStorage for DeviceStorage {}
impl MemoryRegion for DeviceStorage {
unsafe fn as_ptr(&self) -> *const u8 {
unsafe { Storage::as_ptr(self) }
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for DeviceStorage {
    /// Device allocations are presented to NIXL as VRAM.
    fn mem_type(&self) -> MemType {
        MemType::Vram
    }

    /// The `cu_device()` value of the CUDA context backing this storage, cast to `u64`.
    ///
    /// NOTE(review): if `cu_device()` returns a signed integer, a negative value would
    /// sign-extend under `as u64` — presumably never happens for a valid device; confirm.
    fn device_id(&self) -> u64 {
        let ctx = CudaContextProivder::cuda_context(self);
        ctx.cu_device() as u64
    }
}
impl NixlAccessible for DiskStorage {}

impl NixlRegisterableStorage for DiskStorage {
    /// Registers the backing file with NIXL, then unlinks it.
    ///
    /// Registration is refused up front if the file has already been unlinked, since (per the
    /// error message) GDS registration would fail in that case. After a successful
    /// registration the file is unlinked; NIXL keeps referring to it through the file
    /// descriptor reported by `device_id`.
    ///
    /// # Errors
    /// `StorageError::AllocationFailed` if already unlinked; otherwise any error from
    /// registration or from `unlink`.
    fn nixl_register(
        &mut self,
        agent: &NixlAgent,
        opt_args: Option<&OptArgs>,
    ) -> Result<(), StorageError> {
        let already_unlinked = self.unlinked();
        if already_unlinked {
            let reason = "Disk storage has already been unlinked. GDS registration will fail.";
            return Err(StorageError::AllocationFailed(reason.to_string()));
        }

        // Order matters: register while the file still exists, unlink only afterwards.
        handle_nixl_register(self, agent, opt_args)?;
        self.unlink()?;

        Ok(())
    }
}
impl MemoryRegion for DiskStorage {
unsafe fn as_ptr(&self) -> *const u8 {
unsafe { Storage::as_ptr(self) }
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for DiskStorage {
    /// Disk-backed storage is presented to NIXL as a file.
    fn mem_type(&self) -> MemType {
        MemType::File
    }

    /// For file-backed memory the file descriptor is passed in the device-id slot.
    fn device_id(&self) -> u64 {
        let fd = self.fd();
        fd
    }
}