use anyhow::{Result, anyhow};
use std::sync::Arc;
use validator::Validate;
use super::serialize::{BlockFormat, FullyContiguousDetails, LayoutTypeDetails};
use super::{Layout, LayoutConfig, MemoryDescriptor, MemoryRegion, OwnedMemoryRegion};
/// A layout in which every `(block, layer, outer)` region is carved out of one
/// contiguous memory region.
///
/// Regions are packed in `[block][layer][outer]` order starting at the region's base
/// address; each region is `page_size * inner_dim * dtype_width_bytes` bytes.
#[derive(Debug)]
pub struct FullyContiguousLayout {
// Dimensions and element width the layout was constructed from.
config: LayoutConfig,
// Start address of `memory`, captured once at construction time.
base_addr: usize,
// Bytes between consecutive blocks (`layer_stride * num_layers`).
block_stride: usize,
// Bytes between consecutive layers inside a block (`outer_stride * outer_dim`).
layer_stride: usize,
// Bytes between consecutive outer indices inside a layer (equal to `region_size`).
outer_stride: usize,
// Size in bytes of a single `(block, layer, outer)` region.
region_size: usize,
// The single backing allocation every region is addressed within.
memory: Arc<dyn MemoryRegion>,
// Block format reported through `serialization_details`.
block_format: BlockFormat,
}
impl FullyContiguousLayout {
    /// Creates a fully contiguous layout over `memory` with the default [`BlockFormat`].
    ///
    /// Regions are packed in `[block][layer][outer]` order: each region is
    /// `page_size * inner_dim * dtype_width_bytes` bytes, `outer_dim` regions form a layer,
    /// `num_layers` layers form a block, and `num_blocks` blocks follow each other starting
    /// at `memory.addr()`.
    ///
    /// # Errors
    ///
    /// Returns an error if `config` fails validation, if any stride or the total layout
    /// size does not fit in `usize`, if `memory` is smaller than the layout requires, or if
    /// the addressed range would overflow `usize`.
    pub fn new(config: LayoutConfig, memory: Arc<dyn MemoryRegion>) -> Result<Self> {
        config.validate()?;
        let base_addr = memory.addr();

        // Checked arithmetic: with plain `*` an oversized config panics in debug builds and
        // silently wraps in release builds, where a wrapped (small) `required_size` could
        // pass the size check below and later produce addresses outside `memory`.
        let region_size = config
            .page_size
            .checked_mul(config.inner_dim)
            .and_then(|bytes| bytes.checked_mul(config.dtype_width_bytes))
            .ok_or_else(|| anyhow!("Layout region size overflows usize"))?;
        let outer_stride = region_size;
        let layer_stride = outer_stride
            .checked_mul(config.outer_dim)
            .ok_or_else(|| anyhow!("Layout layer stride overflows usize"))?;
        let block_stride = layer_stride
            .checked_mul(config.num_layers)
            .ok_or_else(|| anyhow!("Layout block stride overflows usize"))?;
        let required_size = block_stride
            .checked_mul(config.num_blocks)
            .ok_or_else(|| anyhow!("Layout total size overflows usize"))?;

        if memory.size() < required_size {
            return Err(anyhow!(
                "Memory region too small for layout. Required: {} bytes, got: {} bytes",
                required_size,
                memory.size()
            ));
        }

        // Every address produced by `calculate_address` is strictly below
        // `base_addr + required_size`, so ensuring this sum fits makes that
        // arithmetic overflow-free.
        if base_addr.checked_add(required_size).is_none() {
            return Err(anyhow!(
                "Memory region at {:#x} cannot address {} bytes without overflowing usize",
                base_addr,
                required_size
            ));
        }

        Ok(Self {
            config,
            base_addr,
            block_stride,
            layer_stride,
            outer_stride,
            region_size,
            memory,
            block_format: BlockFormat::default(),
        })
    }

    /// Creates a layout like [`FullyContiguousLayout::new`], but records `block_format`
    /// instead of the default format.
    ///
    /// # Errors
    ///
    /// Same conditions as [`FullyContiguousLayout::new`].
    pub(crate) fn new_with_format(
        config: LayoutConfig,
        memory: Arc<dyn MemoryRegion>,
        block_format: BlockFormat,
    ) -> Result<Self> {
        let mut layout = Self::new(config, memory)?;
        layout.block_format = block_format;
        Ok(layout)
    }

    /// Returns the block format recorded for this layout.
    pub fn block_format(&self) -> BlockFormat {
        self.block_format
    }

    /// Computes the start address of the region at `(block_id, layer_id, outer_id)`.
    ///
    /// # Errors
    ///
    /// Returns an error if any index is not below its configured dimension
    /// (`num_blocks`, `num_layers`, `outer_dim`).
    fn calculate_address(
        &self,
        block_id: usize,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<usize> {
        if block_id >= self.config.num_blocks {
            return Err(anyhow!(
                "Block ID {} out of range (max: {})",
                block_id,
                self.config.num_blocks
            ));
        }
        if layer_id >= self.config.num_layers {
            return Err(anyhow!(
                "Layer ID {} out of range (max: {})",
                layer_id,
                self.config.num_layers
            ));
        }
        if outer_id >= self.config.outer_dim {
            return Err(anyhow!(
                "Outer ID {} out of range (max: {})",
                outer_id,
                self.config.outer_dim
            ));
        }
        // Cannot overflow: the indices are in range and `new` verified that
        // `base_addr + block_stride * num_blocks` fits in `usize`.
        Ok(self.base_addr
            + block_id * self.block_stride
            + layer_id * self.layer_stride
            + outer_id * self.outer_stride)
    }

    /// Gives mutable access to the backing memory handle.
    ///
    /// NOTE(review): `base_addr` and the size check are captured at construction, so
    /// replacing the `Arc` with a region at a different address or of a smaller size makes
    /// computed addresses stale. Confirm that callers only swap in a handle to the same
    /// memory.
    pub fn memory_arc_mut(&mut self) -> &mut Arc<dyn MemoryRegion> {
        &mut self.memory
    }
}
impl Layout for FullyContiguousLayout {
    /// Borrows the configuration the layout was built from.
    fn config(&self) -> &LayoutConfig {
        &self.config
    }

    /// Exposes the single backing allocation as a one-element slice.
    fn memory_regions(&self) -> &[OwnedMemoryRegion] {
        std::slice::from_ref(&self.memory)
    }

    /// Describes the `(block_id, layer_id, outer_id)` region as an address and a size.
    ///
    /// Index validation errors from the address computation are passed through unchanged.
    fn memory_region(
        &self,
        block_id: usize,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<MemoryDescriptor> {
        self.calculate_address(block_id, layer_id, outer_id)
            .map(|addr| MemoryDescriptor::new(addr, self.region_size))
    }

    /// A fully contiguous layout needs exactly one allocation, spanning every block.
    fn required_allocations(&self) -> Vec<usize> {
        let total_bytes = self.block_stride * self.config.num_blocks;
        vec![total_bytes]
    }

    /// Always `true`: all regions live in one contiguous allocation.
    fn is_fully_contiguous(&self) -> bool {
        true
    }

    fn num_blocks(&self) -> usize {
        self.config().num_blocks
    }

    fn num_layers(&self) -> usize {
        self.config().num_layers
    }

    fn outer_dim(&self) -> usize {
        self.config().outer_dim
    }

    fn page_size(&self) -> usize {
        self.config().page_size
    }

    fn inner_dim(&self) -> usize {
        self.config().inner_dim
    }

    fn dtype_width_bytes(&self) -> usize {
        self.config().dtype_width_bytes
    }

    /// Reports this layout as fully contiguous, together with its recorded block format.
    fn serialization_details(&self) -> LayoutTypeDetails {
        let details = FullyContiguousDetails {
            block_format: self.block_format,
        };
        LayoutTypeDetails::FullyContiguous(details)
    }
}
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::super::tests::*;
    use super::*;

    /// Builds the small 2-block x 2-layer x 2-outer config shared by several tests.
    fn small_config() -> LayoutConfig {
        LayoutConfig::builder()
            .num_blocks(2)
            .num_layers(2)
            .outer_dim(2)
            .page_size(16)
            .inner_dim(128)
            .dtype_width_bytes(2)
            .build()
            .unwrap()
    }

    #[test]
    fn test_fully_contiguous_layout_creation() {
        let config = LayoutConfig::builder()
            .num_blocks(10)
            .num_layers(4)
            .outer_dim(2)
            .page_size(16)
            .inner_dim(128)
            .dtype_width_bytes(2)
            .build()
            .unwrap();
        let required_bytes = config.required_bytes();
        assert_eq!(required_bytes, 10 * 4 * 2 * 16 * 128 * 2);
        let memory = MockMemory::new(0x1000, required_bytes);
        let layout = FullyContiguousLayout::new(config, memory).unwrap();
        assert_eq!(layout.num_blocks(), 10);
        assert!(layout.is_fully_contiguous());
    }

    #[test]
    fn test_memory_region() {
        let config = small_config();
        let required_size = config.required_bytes();
        let memory = MockMemory::new(0x1000, required_size);
        let layout = FullyContiguousLayout::new(config.clone(), memory).unwrap();
        let region_size = config.page_size * config.inner_dim * config.dtype_width_bytes;
        let region = layout.memory_region(0, 0, 0).unwrap();
        assert_eq!(region.addr, 0x1000);
        assert_eq!(region.size, region_size);
        let region = layout.memory_region(0, 0, 1).unwrap();
        assert_eq!(region.addr, 0x1000 + region_size);
        assert_eq!(region.size, region_size);
        let region = layout.memory_region(0, 1, 0).unwrap();
        assert_eq!(region.addr, 0x1000 + 2 * region_size);
        assert_eq!(region.size, region_size);
        let region = layout.memory_region(1, 0, 0).unwrap();
        assert_eq!(
            region.addr,
            0x1000 + (config.outer_dim * config.num_layers * region_size)
        );
        assert_eq!(region.size, region_size);
    }

    /// A backing region one byte short of the requirement must be rejected.
    #[test]
    fn test_memory_too_small_is_rejected() {
        let config = small_config();
        let required_size = config.required_bytes();
        let memory = MockMemory::new(0x1000, required_size - 1);
        assert!(FullyContiguousLayout::new(config, memory).is_err());
    }

    /// Each index must be strictly below its dimension; the last valid region ends
    /// exactly at the end of the backing memory.
    #[test]
    fn test_out_of_range_ids_are_rejected() {
        let config = small_config();
        let memory = MockMemory::new(0x1000, config.required_bytes());
        let layout = FullyContiguousLayout::new(config, memory).unwrap();
        assert!(layout.memory_region(2, 0, 0).is_err());
        assert!(layout.memory_region(0, 2, 0).is_err());
        assert!(layout.memory_region(0, 0, 2).is_err());

        // region_size = 16 * 128 * 2 = 4096; (1, 1, 1) is the 8th region (index 7).
        let last = layout.memory_region(1, 1, 1).unwrap();
        assert_eq!(last.addr, 0x1000 + 7 * 4096);
        assert_eq!(last.size, 4096);
    }

    /// A fully contiguous layout asks for a single allocation covering all blocks.
    #[test]
    fn test_required_allocations_matches_config() {
        let config = small_config();
        let memory = MockMemory::new(0x1000, config.required_bytes());
        let layout = FullyContiguousLayout::new(config, memory).unwrap();
        assert_eq!(layout.required_allocations(), vec![2 * 2 * 2 * 16 * 128 * 2]);
    }
}