use crate::block_manager::storage::StorageType;
use super::{
BlockLayout, BlockLayoutConfig, GenericBlockLayout, LayoutConfig, LayoutError, LayoutType,
};
use super::super::storage::{
Storage, StorageAllocator,
nixl::{NixlAgent, NixlRegisterableStorage, NixlStorage, OptArgs},
};
use super::{FullyContiguous, FullyContiguousConfig, LayerSeparate, LayerSeparateConfig};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A [`BlockLayout`] whose backing storage can be registered with a NIXL agent and which can
/// also be serialized through [`ToSerializedNixlBlockLayout`].
pub trait NixlLayout: BlockLayout + ToSerializedNixlBlockLayout {
    /// Registers the storage regions backing this layout with `agent`.
    ///
    /// `opt_args` carries optional NIXL registration arguments.
    ///
    /// # Errors
    ///
    /// Implementations return an error when registration fails.
    fn nixl_register(&mut self, agent: &NixlAgent, opt_args: Option<&OptArgs>)
        -> anyhow::Result<()>;
}
/// Blanket implementation: every serializable [`BlockLayout`] whose storage type supports NIXL
/// registration is automatically a [`NixlLayout`].
impl<T> NixlLayout for T
where
    T: BlockLayout + ToSerializedNixlBlockLayout + ?Sized,
    T::StorageType: NixlRegisterableStorage,
{
    /// Registers each backing storage in order, forwarding `agent` and `opt_args` unchanged.
    ///
    /// Stops at the first failure and returns that error. Storages registered before the
    /// failure are not unregistered; no rollback is attempted here.
    fn nixl_register(
        &mut self,
        agent: &NixlAgent,
        opt_args: Option<&OptArgs>,
    ) -> anyhow::Result<()> {
        self.storage_mut()
            .into_iter()
            .try_for_each(|region| -> anyhow::Result<()> {
                region.nixl_register(agent, opt_args)?;
                Ok(())
            })
    }
}
impl LayoutConfig {
    /// Builds a NIXL-capable layout of kind `layout_type` over caller-provided `storage`.
    ///
    /// The configuration is cloned into the new layout; `self` is left untouched.
    ///
    /// # Errors
    ///
    /// Propagates any [`LayoutError`] returned by the concrete layout constructor.
    pub fn create_layout<S: Storage + NixlRegisterableStorage>(
        &self,
        layout_type: LayoutType,
        storage: Vec<S>,
    ) -> Result<Box<dyn NixlLayout<StorageType = S>>, LayoutError> {
        let owned_config = self.clone();
        let layout: Box<dyn NixlLayout<StorageType = S>> = match layout_type {
            LayoutType::FullyContiguous => {
                Box::new(FullyContiguous::new(owned_config, storage)?)
            }
            LayoutType::LayerSeparate { outer_contiguous } => Box::new(LayerSeparate::new(
                owned_config,
                storage,
                outer_contiguous,
            )?),
        };
        Ok(layout)
    }

    /// Allocates storage through `allocator` and builds a NIXL-capable layout of kind
    /// `layout_type` on top of it.
    ///
    /// # Errors
    ///
    /// Propagates any [`LayoutError`] returned by the concrete layout's `allocate`.
    pub fn allocate_layout<S: Storage + NixlRegisterableStorage>(
        &self,
        layout_type: LayoutType,
        allocator: Arc<dyn StorageAllocator<S>>,
    ) -> Result<Box<dyn NixlLayout<StorageType = S>>, LayoutError> {
        let owned_config = self.clone();
        let alloc = allocator.as_ref();
        let layout: Box<dyn NixlLayout<StorageType = S>> = match layout_type {
            LayoutType::FullyContiguous => {
                Box::new(FullyContiguous::allocate(owned_config, alloc)?)
            }
            LayoutType::LayerSeparate { outer_contiguous } => Box::new(LayerSeparate::allocate(
                owned_config,
                alloc,
                outer_contiguous,
            )?),
        };
        Ok(layout)
    }
}
/// Conversion of a NIXL-registerable block layout into a [`SerializedNixlBlockLayout`].
pub trait ToSerializedNixlBlockLayout: BlockLayout<StorageType: NixlRegisterableStorage> {
    /// Encodes this layout's configuration, base offset(s), storage descriptors and storage
    /// type into a [`SerializedNixlBlockLayout`].
    ///
    /// # Errors
    ///
    /// Implementations return a [`LayoutError`] when the layout cannot be described.
    fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError>;
}
/// Opaque byte encoding of a NIXL block layout.
///
/// Produced by [`ToSerializedNixlBlockLayout::serialize`] and decoded back into a layout by
/// [`SerializedNixlBlockLayout::deserialize`].
#[derive(Clone, Deserialize, Serialize)]
pub struct SerializedNixlBlockLayout(Vec<u8>);
/// Tagged union of the layout kinds that can round-trip through
/// [`SerializedNixlBlockLayout`]; the variant selects which concrete layout is rebuilt.
#[derive(Clone, Debug, Deserialize, Serialize)]
enum NixlBlockLayoutKinds {
    /// Payload for a [`FullyContiguous`] layout.
    FullyContiguous(SerializableNixlLayout<FullyContiguousConfig>),
    /// Payload for a [`LayerSeparate`] layout.
    LayerSeparate(SerializableNixlLayout<LayerSeparateConfig>),
}
/// Wire representation shared by every serializable layout kind.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct SerializableNixlLayout<C: BlockLayoutConfig> {
    /// Layout-specific configuration.
    config: C,
    /// Base offset(s) of the layout: a single entry for `FullyContiguous`, the layout's
    /// `base_offsets` list for `LayerSeparate`.
    base_offsets: Vec<usize>,
    /// NIXL descriptors for each storage region backing the layout.
    storage_descriptors: Vec<NixlStorage>,
    /// Storage type reported by the original layout.
    storage_type: StorageType,
}
impl<C> SerializableNixlLayout<C>
where
    C: BlockLayoutConfig + Serialize + for<'de> Deserialize<'de> + Clone + std::fmt::Debug,
{
    /// Bundles the parts of a layout into its wire representation.
    ///
    /// No validation is performed here; consistency checks happen on serialization and
    /// deserialization.
    fn new(
        config: C,
        base_offsets: Vec<usize>,
        storage_descriptors: Vec<NixlStorage>,
        storage_type: StorageType,
    ) -> Self {
        Self {
            storage_type,
            storage_descriptors,
            base_offsets,
            config,
        }
    }
}
/// Collects the NIXL descriptor of every storage in `storages`, in order.
///
/// # Errors
///
/// Returns [`LayoutError::OperationFailed`] as soon as a storage yields no descriptor.
fn serialize_storages<S: NixlRegisterableStorage>(
    storages: Vec<&S>,
) -> Result<Vec<NixlStorage>, LayoutError> {
    storages
        .into_iter()
        .map(|region| {
            // SAFETY: NOTE(review): `as_nixl_descriptor` is `unsafe`; its contract is defined on
            // `NixlRegisterableStorage` and is not visible here — confirm it holds for a
            // shared borrow of storage owned by a live layout.
            let descriptor = unsafe { region.as_nixl_descriptor() };
            descriptor.ok_or_else(|| {
                LayoutError::OperationFailed(
                    "Storage does not provide NIXL descriptors for serialization".to_string(),
                )
            })
        })
        .collect()
}
impl<S: NixlRegisterableStorage> ToSerializedNixlBlockLayout for FullyContiguous<S> {
    /// Serializes this layout as a `FullyContiguous` JSON payload.
    ///
    /// # Errors
    ///
    /// - [`LayoutError::InvalidConfig`] unless the layout is backed by exactly one storage.
    /// - [`LayoutError::OperationFailed`] if that storage exposes no NIXL descriptor.
    /// - JSON encoding failures, converted into [`LayoutError`] via `?`.
    fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError> {
        let backing = self.storage();
        // Deserialization rebuilds a FullyContiguous layout from exactly one descriptor,
        // so reject anything else before encoding.
        if backing.len() != 1 {
            return Err(LayoutError::InvalidConfig(
                "FullyContiguous reconstruction expects exactly one NixlStorage descriptor"
                    .to_string(),
            ));
        }

        let payload = NixlBlockLayoutKinds::FullyContiguous(SerializableNixlLayout::new(
            self.config.clone(),
            vec![self.base_offset],
            serialize_storages(backing)?,
            *self.storage_type(),
        ));
        let bytes = serde_json::to_vec(&payload)?;
        Ok(SerializedNixlBlockLayout(bytes))
    }
}
impl<S: NixlRegisterableStorage> ToSerializedNixlBlockLayout for LayerSeparate<S> {
    /// Serializes this layout as a `LayerSeparate` JSON payload carrying every storage
    /// descriptor together with the layout's base offsets.
    ///
    /// # Errors
    ///
    /// - [`LayoutError::OperationFailed`] if any storage exposes no NIXL descriptor.
    /// - JSON encoding failures, converted into [`LayoutError`] via `?`.
    fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError> {
        let payload = NixlBlockLayoutKinds::LayerSeparate(SerializableNixlLayout::new(
            self.config.clone(),
            self.base_offsets.clone(),
            serialize_storages(self.storage())?,
            *self.storage_type(),
        ));
        let bytes = serde_json::to_vec(&payload)?;
        Ok(SerializedNixlBlockLayout(bytes))
    }
}
impl SerializedNixlBlockLayout {
pub fn deserialize(
&self,
) -> Result<Arc<dyn BlockLayout<StorageType = NixlStorage>>, LayoutError> {
let nixl_block_layout: NixlBlockLayoutKinds = serde_json::from_slice(&self.0)?;
match nixl_block_layout {
NixlBlockLayoutKinds::FullyContiguous(config) => {
if config.storage_descriptors.len() != 1 {
return Err(LayoutError::InvalidConfig(
"FullyContiguous reconstruction expects exactly one NixlStorage descriptor"
.to_string(),
));
}
let storage = config.storage_descriptors[0].clone();
let layout = FullyContiguous::new_internal(
config.config.clone(),
storage, config.storage_type,
config.base_offsets[0],
)?;
Ok(Arc::new(layout))
}
NixlBlockLayoutKinds::LayerSeparate(config) => {
if config.storage_descriptors.len() != config.config.num_layers() {
return Err(LayoutError::InvalidConfig(
"LayerSeparate reconstruction expects exactly one NixlStorage descriptor per layer"
.to_string(),
));
}
let storages = config.storage_descriptors.to_vec();
let layout = LayerSeparate::new_internal(
config.config.clone(),
storages,
config.storage_type,
config.base_offsets,
)?;
Ok(Arc::new(layout))
}
}
}
}
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::super::*;
    use super::*;
    use crate::block_manager::storage::SystemAllocator;
    use dynamo_runtime::logging::init as init_logging;

    /// Registers a `FullyContiguous` layout with a NIXL agent, round-trips it through
    /// serialization, and checks the reconstructed view reports the same storage type.
    #[test]
    fn test_nixl_layout() {
        init_logging();

        let layout_config = LayoutConfig::builder()
            .num_blocks(10)
            .num_layers(2)
            .outer_dim(2)
            .page_size(4)
            .inner_dim(13)
            .build()
            .unwrap();
        layout_config.validate().unwrap();

        let mut local = FullyContiguous::allocate(layout_config, &SystemAllocator).unwrap();
        let agent = NixlAgent::new("test").unwrap();

        tracing::info!("Registering layout");
        local.nixl_register(&agent, None).unwrap();
        tracing::info!("Layout registered");

        let expected_storage_type = local.storage_type();
        let encoded = local.serialize().unwrap();
        let remote = SerializedNixlBlockLayout::deserialize(&encoded).unwrap();
        println!("Nixl layout: {:?}", remote);
        assert_eq!(expected_storage_type, remote.storage_type());

        // The reconstructed layout must also be usable through the type-erased interface.
        let _: Arc<dyn GenericBlockLayout> = remote;

        drop(local);
        tracing::info!("Layout dropped");
    }
}