use anyhow::{Result, anyhow};
use std::sync::Arc;
use validator::Validate;
use super::serialize::{LayerSeparateDetails, LayoutTypeDetails};
use super::{
BlockDimension, Layout, LayoutConfig, MemoryDescriptor, MemoryRegion, OwnedMemoryRegion,
};
/// Layout in which every layer is backed by its own memory region.
///
/// `new` requires exactly `config.num_layers` regions and records the base
/// address of each one. Within a layer, a `(block_id, outer_id)` pair is
/// located at `base + block_id * block_stride + outer_id * outer_stride`.
#[derive(Debug)]
pub struct LayerSeparateLayout {
/// Layout configuration; `new` validates it before deriving any sizes.
config: LayoutConfig,
/// Base address of each layer's region as reported by `addr()` at
/// construction time, indexed by layer id.
layer_base_addrs: Vec<usize>,
/// Dimension ordering inside a layer; decides how the two strides below
/// are derived from `region_size`.
block_dim: BlockDimension,
/// Distance in bytes between two consecutive block ids within one layer.
block_stride: usize,
/// Distance in bytes between two consecutive outer ids within one layer.
outer_stride: usize,
/// Size in bytes of one addressable region:
/// `page_size * inner_dim * dtype_width_bytes`.
region_size: usize,
/// The backing memory regions, one per layer, in layer order.
memory_regions: Vec<Arc<dyn MemoryRegion>>,
}
impl LayerSeparateLayout {
    /// Builds a layer-separate layout over one memory region per layer.
    ///
    /// The size of a single addressable region is
    /// `page_size * inner_dim * dtype_width_bytes`. The strides depend on
    /// `block_dim`:
    ///
    /// * `BlockDimension::BlockIsSecondDim`: `block_stride = region_size`,
    ///   `outer_stride = region_size * num_blocks`.
    /// * any other value: `outer_stride = region_size`,
    ///   `block_stride = region_size * outer_dim`.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// * `config.validate()` fails,
    /// * `memory.len()` differs from `config.num_layers`,
    /// * any derived size or stride does not fit in `usize`, or
    /// * a region is smaller than `num_blocks * outer_dim * region_size` bytes.
    pub fn new(
        config: LayoutConfig,
        memory: Vec<Arc<dyn MemoryRegion>>,
        block_dim: BlockDimension,
    ) -> Result<Self> {
        config.validate()?;

        if memory.len() != config.num_layers {
            return Err(anyhow!(
                "Memory region count ({}) must match num_layers ({})",
                memory.len(),
                config.num_layers
            ));
        }

        // All size arithmetic is checked: a silently wrapped product would
        // yield a too-small `required_size`, let undersized regions pass the
        // check below, and later produce addresses outside the region.
        let region_size = Self::checked_product(
            &[config.page_size, config.inner_dim, config.dtype_width_bytes],
            "region size",
        )?;

        let (block_stride, outer_stride) = if block_dim == BlockDimension::BlockIsSecondDim {
            let block_stride = region_size;
            let outer_stride =
                Self::checked_product(&[block_stride, config.num_blocks], "outer stride")?;
            (block_stride, outer_stride)
        } else {
            let outer_stride = region_size;
            let block_stride =
                Self::checked_product(&[outer_stride, config.outer_dim], "block stride")?;
            (block_stride, outer_stride)
        };

        let required_size = Self::checked_product(
            &[config.num_blocks, config.outer_dim, region_size],
            "per-layer size",
        )?;

        let mut layer_base_addrs = Vec::with_capacity(config.num_layers);
        for (i, mem) in memory.iter().enumerate() {
            if mem.size() < required_size {
                return Err(anyhow!(
                    "Memory region {} too small for layout. Required: {} bytes, got: {} bytes",
                    i,
                    required_size,
                    mem.size()
                ));
            }
            layer_base_addrs.push(mem.addr());
        }

        Ok(Self {
            config,
            layer_base_addrs,
            block_dim,
            block_stride,
            outer_stride,
            region_size,
            memory_regions: memory,
        })
    }

    /// Multiplies `factors` together, failing instead of wrapping on overflow.
    ///
    /// A zero factor short-circuits to `0` so that a mathematically zero
    /// product is never reported as an overflow. `what` names the quantity
    /// being computed and is only used in the error message.
    fn checked_product(factors: &[usize], what: &str) -> Result<usize> {
        if factors.contains(&0) {
            return Ok(0);
        }
        factors
            .iter()
            .try_fold(1usize, |acc, &factor| acc.checked_mul(factor))
            .ok_or_else(|| {
                anyhow!(
                    "Overflow while computing {} from factors {:?}",
                    what,
                    factors
                )
            })
    }

    /// Computes the address of the region identified by
    /// `(block_id, layer_id, outer_id)`.
    ///
    /// The result is
    /// `layer_base_addrs[layer_id] + block_id * block_stride + outer_id * outer_stride`.
    ///
    /// # Errors
    ///
    /// Returns an error if any id is not strictly below its configured count
    /// (`num_blocks`, `num_layers`, `outer_dim`), or if the address does not
    /// fit in `usize`.
    fn calculate_address(
        &self,
        block_id: usize,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<usize> {
        if block_id >= self.config.num_blocks {
            return Err(anyhow!(
                "Block ID {} out of range (max: {})",
                block_id,
                self.config.num_blocks
            ));
        }
        if layer_id >= self.config.num_layers {
            return Err(anyhow!(
                "Layer ID {} out of range (max: {})",
                layer_id,
                self.config.num_layers
            ));
        }
        if outer_id >= self.config.outer_dim {
            return Err(anyhow!(
                "Outer ID {} out of range (max: {})",
                outer_id,
                self.config.outer_dim
            ));
        }

        // `new` pushes exactly one base address per region after checking
        // that the region count equals `num_layers`, so this index is in
        // bounds once `layer_id` has passed the check above.
        let base_addr = self.layer_base_addrs[layer_id];

        let block_offset = block_id.checked_mul(self.block_stride);
        let outer_offset = outer_id.checked_mul(self.outer_stride);
        let addr = match (block_offset, outer_offset) {
            (Some(block_off), Some(outer_off)) => block_off
                .checked_add(outer_off)
                .and_then(|offset| base_addr.checked_add(offset)),
            _ => None,
        };

        addr.ok_or_else(|| {
            anyhow!(
                "Address overflow for block {}, layer {}, outer {}",
                block_id,
                layer_id,
                outer_id
            )
        })
    }

    /// Returns the dimension ordering this layout was constructed with.
    pub fn block_dim(&self) -> BlockDimension {
        self.block_dim
    }

    /// Returns mutable access to the per-layer memory region handles.
    ///
    /// NOTE(review): base addresses are captured once in `new`; replacing an
    /// element of this slice does not refresh them, so callers presumably
    /// must not swap in regions with a different address — confirm with
    /// call sites.
    pub fn memory_regions_mut(&mut self) -> &mut [Arc<dyn MemoryRegion>] {
        &mut self.memory_regions
    }
}
impl Layout for LayerSeparateLayout {
    /// Borrows the configuration this layout was built from.
    fn config(&self) -> &LayoutConfig {
        let Self { config, .. } = self;
        config
    }

    /// Borrows the per-layer backing regions in layer order.
    fn memory_regions(&self) -> &[OwnedMemoryRegion] {
        self.memory_regions.as_slice()
    }

    /// Describes the region addressed by `(block_id, layer_id, outer_id)`.
    ///
    /// The descriptor's size is always `region_size`; any error from the
    /// address calculation is passed through unchanged.
    fn memory_region(
        &self,
        block_id: usize,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<MemoryDescriptor> {
        self.calculate_address(block_id, layer_id, outer_id)
            .map(|addr| MemoryDescriptor::new(addr, self.region_size))
    }

    /// Lists the allocation size needed for each layer.
    ///
    /// Every layer needs `num_blocks * outer_dim * region_size` bytes, and
    /// there is one entry per layer.
    fn required_allocations(&self) -> Vec<usize> {
        let cfg = &self.config;
        let bytes_per_layer = cfg.num_blocks * cfg.outer_dim * self.region_size;
        std::iter::repeat(bytes_per_layer)
            .take(cfg.num_layers)
            .collect()
    }

    /// Always `false` for this layout.
    fn is_fully_contiguous(&self) -> bool {
        false
    }

    /// Number of blocks, taken from the configuration.
    fn num_blocks(&self) -> usize {
        let cfg = &self.config;
        cfg.num_blocks
    }

    /// Number of layers, taken from the configuration.
    fn num_layers(&self) -> usize {
        let cfg = &self.config;
        cfg.num_layers
    }

    /// Outer dimension, taken from the configuration.
    fn outer_dim(&self) -> usize {
        let cfg = &self.config;
        cfg.outer_dim
    }

    /// Page size, taken from the configuration.
    fn page_size(&self) -> usize {
        let cfg = &self.config;
        cfg.page_size
    }

    /// Inner dimension, taken from the configuration.
    fn inner_dim(&self) -> usize {
        let cfg = &self.config;
        cfg.inner_dim
    }

    /// Element width in bytes, taken from the configuration.
    fn dtype_width_bytes(&self) -> usize {
        let cfg = &self.config;
        cfg.dtype_width_bytes
    }

    /// Produces the serialization details for this layout, carrying only its
    /// `block_dim`.
    fn serialization_details(&self) -> LayoutTypeDetails {
        let details = LayerSeparateDetails {
            block_dim: self.block_dim,
        };
        LayoutTypeDetails::LayerSeparate(details)
    }
}
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::super::tests::*;
    use super::*;

    // Shape shared by every test in this module.
    const BASE_ADDR: usize = 0x1000;
    const OUTER_DIM: usize = 2;
    const PAGE_SIZE: usize = 16;
    const INNER_DIM: usize = 128;
    const DTYPE_WIDTH_BYTES: usize = 2;
    /// Bytes in one addressable region.
    const REGION_SIZE: usize = PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;

    /// Builds a config with the shared shape and the given block/layer counts.
    fn layout_config(num_blocks: usize, num_layers: usize) -> LayoutConfig {
        LayoutConfig::builder()
            .num_blocks(num_blocks)
            .num_layers(num_layers)
            .outer_dim(OUTER_DIM)
            .page_size(PAGE_SIZE)
            .inner_dim(INNER_DIM)
            .dtype_width_bytes(DTYPE_WIDTH_BYTES)
            .build()
            .unwrap()
    }

    /// Creates `num_layers` back-to-back mock regions starting at `BASE_ADDR`.
    fn mock_regions(num_layers: usize, per_layer_size: usize) -> Vec<Arc<dyn MemoryRegion>> {
        (0..num_layers)
            .map(|layer| {
                MockMemory::new(BASE_ADDR + layer * per_layer_size, per_layer_size)
                    as Arc<dyn MemoryRegion>
            })
            .collect()
    }

    #[test]
    fn test_layer_separate_block_contiguous() {
        let (num_blocks, num_layers) = (10, 4);
        let per_layer_size = num_blocks * OUTER_DIM * REGION_SIZE;

        let layout = LayerSeparateLayout::new(
            layout_config(num_blocks, num_layers),
            mock_regions(num_layers, per_layer_size),
            BlockDimension::BlockIsFirstDim,
        )
        .unwrap();

        assert_eq!(layout.num_blocks(), 10);
        assert!(!layout.is_fully_contiguous());
        assert_eq!(layout.required_allocations().len(), 4);
    }

    #[test]
    fn test_layer_separate_outer_contiguous() {
        let (num_blocks, num_layers) = (10, 4);
        let per_layer_size = num_blocks * OUTER_DIM * REGION_SIZE;

        let layout = LayerSeparateLayout::new(
            layout_config(num_blocks, num_layers),
            mock_regions(num_layers, per_layer_size),
            BlockDimension::BlockIsSecondDim,
        )
        .unwrap();

        assert_eq!(layout.num_blocks(), 10);
        assert!(!layout.is_fully_contiguous());
    }

    #[test]
    fn test_memory_region() {
        let (num_blocks, num_layers) = (2, 2);
        let per_layer_size = num_blocks * OUTER_DIM * REGION_SIZE;

        let layout = LayerSeparateLayout::new(
            layout_config(num_blocks, num_layers),
            mock_regions(num_layers, per_layer_size),
            BlockDimension::BlockIsFirstDim,
        )
        .unwrap();

        // First region of the first layer sits at the base address.
        let region = layout.memory_region(0, 0, 0).unwrap();
        assert_eq!(region.addr, BASE_ADDR);
        assert_eq!(region.size, REGION_SIZE);

        // The second layer starts at the base of the second mock region.
        let region = layout.memory_region(0, 1, 0).unwrap();
        assert_eq!(region.addr, BASE_ADDR + per_layer_size);
        assert_eq!(region.size, REGION_SIZE);

        // With the block as the first dimension, outer ids are adjacent.
        let region = layout.memory_region(0, 0, 1).unwrap();
        assert_eq!(region.addr, BASE_ADDR + REGION_SIZE);
        assert_eq!(region.size, REGION_SIZE);
    }
}