use crate::block_manager::storage::StorageType;
use super::{BlockLayout, BlockLayoutConfig, LayoutConfig, LayoutError, LayoutType};
use super::super::storage::{
nixl::{MemType, NixlAgent, NixlRegisterableStorage, NixlStorage, OptArgs},
Storage, StorageAllocator,
};
use super::{FullyContiguous, FullyContiguousConfig};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A block layout that can be registered with a NIXL agent.
///
/// Combines [`BlockLayout`], [`BlockLayoutNixlStorage`] and
/// [`ToSerializedNixlBlockLayout`] and adds the registration entry point.
pub trait NixlLayout:
    BlockLayout + BlockLayoutNixlStorage + ToSerializedNixlBlockLayout
{
    /// Registers this layout with `agent`.
    ///
    /// # Arguments
    /// * `agent` - the NIXL agent to register with.
    /// * `opt_args` - optional arguments passed along to the registration call.
    ///
    /// # Errors
    /// Returns an error if registration fails.
    fn nixl_register(
        &mut self,
        agent: &NixlAgent,
        opt_args: Option<&OptArgs>,
    ) -> anyhow::Result<()>;
}
/// Exposes the NIXL-related properties of a layout.
pub trait BlockLayoutNixlStorage {
    /// Returns the NIXL memory type reported for this layout.
    fn mem_type(&self) -> MemType;

    /// Returns the device identifier reported for this layout.
    fn device_id(&self) -> u64;
}
/// Blanket implementation: any layout whose storage type is NIXL-registerable
/// is itself a [`NixlLayout`].
impl<T> NixlLayout for T
where
    T: BlockLayout + BlockLayoutNixlStorage + ToSerializedNixlBlockLayout + ?Sized,
    T::StorageType: NixlRegisterableStorage,
{
    /// Registers every storage returned by `storage_mut()` with `agent`,
    /// stopping at (and returning) the first failure.
    fn nixl_register(
        &mut self,
        agent: &NixlAgent,
        opt_args: Option<&OptArgs>,
    ) -> anyhow::Result<()> {
        self.storage_mut()
            .into_iter()
            .try_for_each(|storage| -> anyhow::Result<()> {
                // `?` performs the same error conversion into `anyhow::Error`
                // as it would in a plain loop body.
                storage.nixl_register(agent, opt_args)?;
                Ok(())
            })
    }
}
impl LayoutConfig {
    /// Builds a NIXL-capable layout of the requested `layout_type` on top of
    /// caller-provided `storage`.
    ///
    /// The configuration is cloned; `self` is left untouched.
    ///
    /// # Errors
    /// Propagates the [`LayoutError`] returned by the selected layout's
    /// constructor.
    pub fn create_layout<S: Storage + NixlRegisterableStorage>(
        &self,
        layout_type: LayoutType,
        storage: Vec<S>,
    ) -> Result<impl NixlLayout<StorageType = S>, LayoutError> {
        let config = self.clone();
        match layout_type {
            LayoutType::FullyContiguous => FullyContiguous::new(config, storage),
        }
    }

    /// Builds a NIXL-capable layout of the requested `layout_type`, obtaining
    /// its storage from `allocator`.
    ///
    /// The configuration is cloned; `self` is left untouched.
    ///
    /// # Errors
    /// Propagates the [`LayoutError`] returned by the selected layout's
    /// allocating constructor.
    pub fn allocate_layout<S: Storage + NixlRegisterableStorage>(
        &self,
        layout_type: LayoutType,
        allocator: Arc<dyn StorageAllocator<S>>,
    ) -> Result<impl NixlLayout<StorageType = S>, LayoutError> {
        let config = self.clone();
        match layout_type {
            LayoutType::FullyContiguous => FullyContiguous::allocate(config, allocator.as_ref()),
        }
    }
}
/// Implemented by layouts that can be turned into a
/// [`SerializedNixlBlockLayout`].
///
/// Only layouts whose storage type is NIXL-registerable qualify.
pub trait ToSerializedNixlBlockLayout: BlockLayout<StorageType: NixlRegisterableStorage> {
    /// Produces the serialized representation of this layout.
    ///
    /// # Errors
    /// Returns a [`LayoutError`] if the layout cannot be serialized.
    fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError>;
}
/// Opaque byte buffer holding a serialized NIXL block layout.
///
/// The bytes are the JSON encoding of the module-private
/// `NixlBlockLayoutKinds` value; use `deserialize` to rebuild a layout from
/// them.
#[derive(Clone, Serialize, Deserialize)]
pub struct SerializedNixlBlockLayout(Vec<u8>);
/// Serialized payload, tagged by the concrete layout kind it describes.
///
/// The variant tells `SerializedNixlBlockLayout::deserialize` which layout
/// type to reconstruct.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum NixlBlockLayoutKinds {
    /// Payload for a `FullyContiguous` layout.
    FullyContiguous(SerializableNixlLayout<FullyContiguousConfig>),
}
/// Everything needed to rebuild a layout of configuration type `C` from its
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableNixlLayout<C: BlockLayoutConfig> {
    /// Layout-specific configuration.
    config: C,
    /// The layout's base offset.
    base_offset: usize,
    /// NIXL descriptors of the layout's storage.
    storage_descriptors: Vec<NixlStorage>,
    /// Storage type reported by the layout that was serialized.
    storage_type: StorageType,
}
impl<C> SerializableNixlLayout<C>
where
C: BlockLayoutConfig + Serialize + for<'de> Deserialize<'de> + Clone + std::fmt::Debug,
{
fn new(
config: C,
base_offset: usize,
storage_descriptors: Vec<NixlStorage>,
storage_type: StorageType,
) -> Self {
Self {
config,
base_offset,
storage_descriptors,
storage_type,
}
}
}
impl<S: NixlRegisterableStorage> ToSerializedNixlBlockLayout for FullyContiguous<S> {
    /// Serializes this layout (configuration, base offset, storage descriptor
    /// and storage type) into a JSON-encoded [`SerializedNixlBlockLayout`].
    ///
    /// # Errors
    /// * [`LayoutError::InvalidConfig`] if the layout does not report exactly
    ///   one storage.
    /// * [`LayoutError::OperationFailed`] if that storage yields no NIXL
    ///   descriptor.
    /// * Whatever error the `?` conversion produces if JSON encoding fails.
    fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError> {
        let storages = self.storage();

        // One match covers both the count check and the element access, so
        // there is no separate "missing first element" branch that could
        // never be reached.
        let storage_instance = match (storages.len(), storages.first()) {
            (1, Some(storage)) => storage,
            _ => {
                return Err(LayoutError::InvalidConfig(
                    "FullyContiguous reconstruction expects exactly one NixlStorage descriptor"
                        .to_string(),
                ));
            }
        };

        // SAFETY: NOTE(review) - the safety contract of `as_nixl_descriptor`
        // is declared outside this file. The returned descriptor is only
        // moved into the serialized payload below; confirm this satisfies the
        // documented requirements.
        let storage_descriptor =
            unsafe { storage_instance.as_nixl_descriptor() }.ok_or_else(|| {
                LayoutError::OperationFailed(
                    "Storage does not provide NIXL descriptors for serialization".to_string(),
                )
            })?;

        // The configuration is cloned only after validation succeeded, so the
        // error paths above do not pay for it.
        let serializable_data = SerializableNixlLayout::new(
            self.config.clone(),
            self.base_offset,
            vec![storage_descriptor],
            self.storage_type(),
        );
        let nixl_block_layout = NixlBlockLayoutKinds::FullyContiguous(serializable_data);

        Ok(SerializedNixlBlockLayout(serde_json::to_vec(
            &nixl_block_layout,
        )?))
    }
}
impl SerializedNixlBlockLayout {
pub fn deserialize(
&self,
) -> Result<Arc<dyn BlockLayout<StorageType = NixlStorage>>, LayoutError> {
let nixl_block_layout: NixlBlockLayoutKinds = serde_json::from_slice(&self.0)?;
match nixl_block_layout {
NixlBlockLayoutKinds::FullyContiguous(config) => {
if config.storage_descriptors.len() != 1 {
return Err(LayoutError::InvalidConfig(
"FullyContiguous reconstruction expects exactly one NixlStorage descriptor"
.to_string(),
));
}
let storage = config.storage_descriptors[0].clone();
let layout = FullyContiguous::new_internal(
config.config.clone(),
storage, config.base_offset,
config.storage_type,
)?;
Ok(Arc::new(layout))
} }
}
}
/// NIXL properties of a `FullyContiguous` layout are those of its `storage`
/// field.
impl<S> BlockLayoutNixlStorage for FullyContiguous<S>
where
    S: Storage + NixlRegisterableStorage,
{
    /// Returns the memory type reported by the layout's storage.
    fn mem_type(&self) -> MemType {
        let backing = &self.storage;
        backing.mem_type()
    }

    /// Returns the device id reported by the layout's storage.
    fn device_id(&self) -> u64 {
        let backing = &self.storage;
        backing.device_id()
    }
}
#[cfg(test)]
mod tests {
    use super::super::*;
    use super::*;
    use crate::block_manager::storage::SystemAllocator;
    use dynamo_runtime::logging::init as init_logging;

    /// Allocates a layout, registers it with a NIXL agent, serializes it and
    /// checks that the deserialized layout reports the same storage type.
    #[test]
    fn test_nixl_layout() {
        init_logging();

        let layout_config = LayoutConfig::builder()
            .num_blocks(10)
            .num_layers(2)
            .page_size(4)
            .inner_dim(13)
            .build()
            .unwrap();
        layout_config.validate().unwrap();

        let mut local_layout = FullyContiguous::allocate(layout_config, &SystemAllocator).unwrap();
        let nixl_agent = NixlAgent::new("test").unwrap();

        tracing::info!("Registering layout");
        local_layout.nixl_register(&nixl_agent, None).unwrap();
        tracing::info!("Layout registered");

        // Capture the local view before the serialization round trip.
        let expected_storage_type = local_layout.storage_type();

        let wire_form = local_layout.serialize().unwrap();
        let rebuilt_layout = wire_form.deserialize().unwrap();
        println!("Nixl layout: {:?}", rebuilt_layout);

        let actual_storage_type = rebuilt_layout.storage_type();
        assert_eq!(expected_storage_type, actual_storage_type);

        // Drop the local layout explicitly so the log line below is emitted
        // after it has been released.
        drop(local_layout);
        tracing::info!("Layout dropped");
    }
}