use super::{BlockDataExt, BlockError, Storage};
use crate::block_manager::storage::StorageType;
/// Type-level tag distinguishing what granularity a memory view covers.
///
/// The provided implementors ([`BlockKind`], [`LayerKind`]) are unit structs
/// that are only ever carried through `PhantomData`, so the trait requires
/// them to be trivially copyable and shareable across threads.
pub trait Kind: Sized + Copy + Clone + std::fmt::Debug + Send + Sync {}

/// Kind tag used by [`BlockView`] and [`BlockViewMut`].
#[derive(Debug, Clone, Copy)]
pub struct BlockKind;

/// Kind tag used by [`LayerView`] and [`LayerViewMut`].
#[derive(Debug, Clone, Copy)]
pub struct LayerKind;

impl Kind for BlockKind {}
impl Kind for LayerKind {}
/// Immutable [`MemoryView`] tagged with [`BlockKind`].
pub type BlockView<'v, S> = MemoryView<'v, S, BlockKind>;
/// Immutable [`MemoryView`] tagged with [`LayerKind`].
pub type LayerView<'v, S> = MemoryView<'v, S, LayerKind>;
/// Mutable [`MemoryViewMut`] tagged with [`BlockKind`].
pub type BlockViewMut<'v, S> = MemoryViewMut<'v, S, BlockKind>;
/// Mutable [`MemoryViewMut`] tagged with [`LayerKind`].
pub type LayerViewMut<'v, S> = MemoryViewMut<'v, S, LayerKind>;
/// Read-only view of an address range that belongs to some block data.
///
/// The view holds a shared borrow of the owning block data for `'a`, so that
/// data cannot be mutably borrowed while the view is alive. `K` is a
/// [`Kind`] tag that distinguishes [`BlockKind`] views from [`LayerKind`]
/// views at the type level; it occupies no space.
#[derive(Debug)]
pub struct MemoryView<'a, S, K>
where
    S: Storage,
    K: Kind,
{
    /// Borrow of the owning block data; ties the view's lifetime to it.
    _block_data: &'a dyn BlockDataExt<S>,
    /// Start address of the viewed region.
    addr: usize,
    /// Size of the viewed region (presumably bytes, matching the `u8`
    /// pointer handed out by `as_ptr`).
    size: usize,
    /// Storage type recorded for the viewed memory.
    storage_type: StorageType,
    /// Zero-sized marker carrying the view's [`Kind`].
    kind: std::marker::PhantomData<K>,
}
impl<'a, S: Storage, K: Kind> MemoryView<'a, S, K> {
    /// Creates a view over `size` units starting at `addr`.
    ///
    /// No validation is performed; this currently always returns `Ok`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out elsewhere in this file.
    /// Presumably the caller must guarantee that `addr` and `size` describe
    /// memory owned by `_block_data` that stays valid for `'a`, and that
    /// `storage_type` correctly describes that memory.
    pub(crate) unsafe fn new(
        _block_data: &'a dyn BlockDataExt<S>,
        addr: usize,
        size: usize,
        storage_type: StorageType,
    ) -> Result<Self, BlockError> {
        let view = Self {
            _block_data,
            addr,
            size,
            storage_type,
            kind: std::marker::PhantomData,
        };
        Ok(view)
    }

    /// Returns the start of the viewed region as a raw const pointer.
    ///
    /// # Safety
    ///
    /// Depending on the storage type, the address may not be dereferenceable
    /// from the host (e.g. device memory). Callers must only read through the
    /// pointer where that is valid and must stay within [`Self::size`].
    pub unsafe fn as_ptr(&self) -> *const u8 {
        let start = self.addr;
        start as *const u8
    }

    /// Returns the size of the viewed region.
    pub fn size(&self) -> usize {
        self.size
    }
}
/// Mutable view of an address range that belongs to some block data.
///
/// The view holds an exclusive borrow of the owning block data for `'a`, so
/// no other view of that data can exist while this one is alive. `K` is a
/// [`Kind`] tag that distinguishes [`BlockKind`] views from [`LayerKind`]
/// views at the type level; it occupies no space.
#[derive(Debug)]
pub struct MemoryViewMut<'a, S, K>
where
    S: Storage,
    K: Kind,
{
    /// Exclusive borrow of the owning block data; ties the view's lifetime to it.
    _block_data: &'a mut dyn BlockDataExt<S>,
    /// Start address of the viewed region.
    addr: usize,
    /// Size of the viewed region (presumably bytes, matching the `u8`
    /// pointer handed out by `as_mut_ptr`).
    size: usize,
    /// Storage type recorded for the viewed memory.
    storage_type: StorageType,
    /// Zero-sized marker carrying the view's [`Kind`].
    kind: std::marker::PhantomData<K>,
}
impl<'a, S, K> MemoryViewMut<'a, S, K>
where
    S: Storage,
    K: Kind,
{
    /// Creates a mutable view over `size` units starting at `addr`.
    ///
    /// No validation is performed; this currently always returns `Ok`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out elsewhere in this file.
    /// Presumably the caller must guarantee that `addr` and `size` describe
    /// memory owned by `_block_data` that stays valid for `'a`, and that
    /// `storage_type` correctly describes that memory.
    pub(crate) unsafe fn new(
        _block_data: &'a mut dyn BlockDataExt<S>,
        addr: usize,
        size: usize,
        storage_type: StorageType,
    ) -> Result<Self, BlockError> {
        let view = Self {
            _block_data,
            addr,
            size,
            storage_type,
            kind: std::marker::PhantomData,
        };
        Ok(view)
    }

    /// Returns the start of the viewed region as a raw mutable pointer.
    ///
    /// Requires `&mut self` so the pointer is only obtained through an
    /// exclusive borrow of the view.
    ///
    /// # Safety
    ///
    /// Depending on the storage type, the address may not be dereferenceable
    /// from the host (e.g. device memory). Callers must only access memory
    /// through the pointer where that is valid and must stay within
    /// [`Self::size`].
    pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        let start = self.addr;
        start as *mut u8
    }

    /// Returns the size of the viewed region.
    pub fn size(&self) -> usize {
        self.size
    }
}
mod nixl {
    //! NIXL adapters for [`MemoryView`] and [`MemoryViewMut`].

    use super::*;

    use super::super::nixl::*;

    pub use crate::block_manager::storage::StorageType;
    pub use nixl_sys::{MemType, MemoryRegion, NixlDescriptor};

    /// Maps a [`StorageType`] to the device id reported to NIXL.
    ///
    /// `System` and `Pinned` report `0`, `Device` reports its device index
    /// widened to `u64`, and `Disk` reports the value it carries (bound as
    /// `fd`, presumably a file descriptor).
    ///
    /// # Panics
    ///
    /// Panics with "Invalid storage type" for any other storage type.
    fn nixl_device_id(storage_type: &StorageType) -> u64 {
        // Match on the place (not the reference) so bindings are by-value
        // copies, exactly as when matching `self.storage_type` directly.
        match *storage_type {
            StorageType::System | StorageType::Pinned => 0,
            StorageType::Device(device_id) => device_id as u64,
            StorageType::Disk(fd) => fd,
            _ => panic!("Invalid storage type"),
        }
    }

    impl<S: Storage, K: Kind> MemoryRegion for MemoryView<'_, S, K> {
        unsafe fn as_ptr(&self) -> *const u8 {
            self.addr as *const u8
        }

        fn size(&self) -> usize {
            // Read the field directly so this is visibly not a recursive
            // call back into `MemoryRegion::size`.
            self.size
        }
    }

    impl<S, K> NixlDescriptor for MemoryView<'_, S, K>
    where
        S: Storage + NixlDescriptor,
        K: Kind,
    {
        fn mem_type(&self) -> MemType {
            // NOTE(review): the memory type comes from the backing block data,
            // while `device_id` uses the view's cached `storage_type`; confirm
            // the two always agree.
            self._block_data.storage_type().nixl_mem_type()
        }

        fn device_id(&self) -> u64 {
            nixl_device_id(&self.storage_type)
        }
    }

    impl<S: Storage, K: Kind> MemoryRegion for MemoryViewMut<'_, S, K> {
        unsafe fn as_ptr(&self) -> *const u8 {
            self.addr as *const u8
        }

        fn size(&self) -> usize {
            // Read the field directly so this is visibly not a recursive
            // call back into `MemoryRegion::size`.
            self.size
        }
    }

    // Bounds live only in the `where` clause, matching the `MemoryView` impl
    // (declaring them in both places trips clippy `multiple_bound_locations`).
    impl<S, K> NixlDescriptor for MemoryViewMut<'_, S, K>
    where
        S: Storage + NixlDescriptor,
        K: Kind,
    {
        fn mem_type(&self) -> MemType {
            // NOTE(review): see the matching note on `MemoryView::mem_type`.
            self._block_data.storage_type().nixl_mem_type()
        }

        fn device_id(&self) -> u64 {
            nixl_device_id(&self.storage_type)
        }
    }

    impl<'a, S, K> MemoryView<'a, S, K>
    where
        S: Storage + NixlDescriptor,
        K: Kind,
    {
        /// Builds an immutable NIXL memory descriptor for this view's region,
        /// using its address, size, memory type and device id.
        pub fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'a, K, IsImmutable> {
            NixlMemoryDescriptor::new(
                self.addr as u64,
                self.size(),
                self.mem_type(),
                self.device_id(),
            )
        }
    }

    impl<'a, S, K> MemoryViewMut<'a, S, K>
    where
        S: Storage + NixlDescriptor,
        K: Kind,
    {
        /// Builds a mutable NIXL memory descriptor for this view's region,
        /// using its address, size, memory type and device id.
        ///
        /// NOTE(review): the returned descriptor carries `'a` rather than the
        /// `&mut self` borrow, so several mutable descriptors for the same
        /// region can coexist. Consider tying it to `'_` if no caller relies
        /// on the longer lifetime.
        pub fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'a, K, IsMutable> {
            NixlMemoryDescriptor::new(
                self.addr as u64,
                self.size(),
                self.mem_type(),
                self.device_id(),
            )
        }
    }
}