use super::{Storage, StorageError};
use offset_allocator::{Allocation, Allocator};
use std::sync::{Arc, Mutex};
/// Errors produced by [`ArenaAllocator`] construction and allocation.
#[derive(Debug, thiserror::Error)]
pub enum ArenaError {
/// The requested page size was not a power of two (zero included).
#[error("Page size must be a power of 2")]
PageSizeNotAligned,
/// The arena could not satisfy the request: out of free pages, or the
/// requested page count did not fit the allocator's `u32` size type.
#[error("Allocation failed")]
AllocationFailed,
/// The total page count (storage size / page size) overflowed `u32`.
#[error("Failed to convert pages to u32")]
PagesNotConvertible,
/// A NIXL descriptor was requested but the storage has no agent name.
#[error("Storage not registered with NIXL")]
NotRegisteredWithNixl,
/// Propagated from the underlying storage layer.
#[error("Storage error: {0}")]
StorageError(#[from] StorageError),
}
/// Page-granular sub-allocator over a single contiguous [`Storage`] region.
///
/// Cloning is cheap: clones share the same backing storage and the same
/// free-page state behind `Arc`s.
#[derive(Clone)]
pub struct ArenaAllocator<S: Storage> {
// Backing memory region; shared with every ArenaBuffer handed out.
storage: Arc<S>,
// Page-granular offset allocator guarding the free list.
allocator: Arc<Mutex<Allocator>>,
// Page size in bytes; validated in `new` to be a power of two.
page_size: u64,
}
impl<S: Storage> std::fmt::Debug for ArenaAllocator<S> {
    /// Formats via `debug_struct` so the alternate flag (`{:#?}`) produces
    /// pretty-printed output, and `page_size` is rendered with `Debug` like
    /// the rest of the struct. The allocator mutex is intentionally omitted;
    /// `finish_non_exhaustive` makes that omission visible as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ArenaAllocator")
            .field("storage", &self.storage)
            .field("page_size", &self.page_size)
            .finish_non_exhaustive()
    }
}
/// A page-aligned slice of an [`ArenaAllocator`]'s backing storage.
///
/// The reserved pages are returned to the arena when this buffer is dropped.
pub struct ArenaBuffer<S: Storage> {
// Byte offset of this buffer from the start of the backing storage.
offset: u64,
// Absolute address: storage base address + offset.
address: u64,
// Size the caller asked for (may be smaller than the pages reserved).
requested_size: usize,
// Keeps the backing storage alive for the buffer's lifetime.
storage: Arc<S>,
// Page-range token required to free the pages on drop.
allocation: Allocation,
// Shared free list so Drop can return the pages.
allocator: Arc<Mutex<Allocator>>,
}
impl<S: Storage> ArenaAllocator<S> {
    /// Builds an arena over `storage`, carving it into `page_size`-byte pages.
    ///
    /// # Errors
    /// * [`ArenaError::PageSizeNotAligned`] — `page_size` is not a power of two.
    /// * [`ArenaError::PagesNotConvertible`] — the page count overflows `u32`.
    pub fn new(storage: S, page_size: usize) -> Result<Self, ArenaError> {
        if !page_size.is_power_of_two() {
            return Err(ArenaError::PageSizeNotAligned);
        }
        let storage = Arc::new(storage);
        // The offset allocator tracks sizes as u32 page counts.
        let page_count: u32 = (storage.size() / page_size)
            .try_into()
            .map_err(|_| ArenaError::PagesNotConvertible)?;
        Ok(Self {
            storage,
            allocator: Arc::new(Mutex::new(Allocator::new(page_count))),
            page_size: page_size as u64,
        })
    }
    /// Reserves at least `size` bytes, rounded up to whole pages.
    ///
    /// # Errors
    /// [`ArenaError::AllocationFailed`] — not enough contiguous free pages,
    /// or the page count for this request does not fit in `u32`.
    pub fn allocate(&self, size: usize) -> Result<ArenaBuffer<S>, ArenaError> {
        let requested = size as u64;
        let page_count: u32 = requested
            .div_ceil(self.page_size)
            .try_into()
            .map_err(|_| ArenaError::AllocationFailed)?;
        // Hold the lock only for the free-list mutation.
        let allocation = self
            .allocator
            .lock()
            .unwrap()
            .allocate(page_count)
            .ok_or(ArenaError::AllocationFailed)?;
        let offset = u64::from(allocation.offset) * self.page_size;
        let address = self.storage.addr() + offset;
        // The allocator hands out pages inside the arena by construction.
        debug_assert!(address + requested <= self.storage.addr() + self.storage.size() as u64);
        Ok(ArenaBuffer {
            offset,
            address,
            requested_size: size,
            allocation,
            storage: Arc::clone(&self.storage),
            allocator: Arc::clone(&self.allocator),
        })
    }
}
impl<S: Storage> std::fmt::Debug for ArenaBuffer<S> {
    /// Debug rendering of the buffer's address, requested size, storage kind,
    /// and the backing-storage `Arc` pointer (identifies which arena owns it).
    ///
    /// Fixes the previous hand-rolled format, which labeled the storage
    /// pointer as `allocator`, wrote `addr` without a colon, and dropped the
    /// space before the closing brace. `debug_struct` also honors `{:#?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ArenaBuffer")
            .field("addr", &self.address)
            .field("size", &self.requested_size)
            .field("kind", &self.storage.storage_type())
            .field("storage", &Arc::as_ptr(&self.storage))
            .finish_non_exhaustive()
    }
}
impl<S: Storage> ArenaBuffer<S> {
/// Absolute start address of this buffer within the backing storage.
pub fn address(&self) -> u64 {
self.address
}
/// Number of bytes originally requested (not the page-rounded size).
pub fn size(&self) -> usize {
self.requested_size
}
}
// NIXL integration: descriptor construction and trait impls that are only
// available when the backing storage is itself NIXL-capable.
mod nixl {
use super::super::nixl::*;
use super::super::*;
use super::*;
impl<S: Storage> ArenaBuffer<S>
where
S: NixlRegisterableStorage,
{
/// Builds a remote descriptor covering exactly this buffer's slice
/// (`offset..offset + requested_size`) of the registered storage.
///
/// # Errors
/// [`ArenaError::NotRegisteredWithNixl`] if the storage reports no
/// agent name; storage errors propagate via `From<StorageError>`.
pub fn nixl_remote_descriptor(&self) -> Result<NixlRemoteDescriptor, ArenaError> {
let agent = self.storage.nixl_agent_name();
match agent {
Some(agent) => {
let storage = NixlStorage::from_storage_with_offset(
self.storage.as_ref(),
self.offset as usize,
self.requested_size,
)?;
Ok(NixlRemoteDescriptor::new(storage, agent))
}
_ => Err(ArenaError::NotRegisteredWithNixl),
}
}
}
// NOTE(review): unlike `nixl_remote_descriptor` above, the two impls below
// expose the WHOLE backing storage (base pointer, full size), not this
// buffer's offset/requested_size slice. Confirm the asymmetry is
// intentional (e.g. NIXL registration works at storage granularity).
impl<S: Storage> MemoryRegion for ArenaBuffer<S>
where
S: MemoryRegion,
{
// Base pointer of the entire backing storage, not `self.address`.
unsafe fn as_ptr(&self) -> *const u8 {
unsafe { Storage::as_ptr(self.storage.as_ref()) }
}
// Full storage size; shadows the inherent `size()` (requested bytes)
// when called through the trait.
fn size(&self) -> usize {
Storage::size(self.storage.as_ref())
}
}
impl<S: Storage> NixlDescriptor for ArenaBuffer<S>
where
S: NixlDescriptor,
{
// Memory type is inherited from the backing storage.
fn mem_type(&self) -> MemType {
NixlDescriptor::mem_type(self.storage.as_ref())
}
// Device id is inherited from the backing storage.
fn device_id(&self) -> u64 {
NixlDescriptor::device_id(self.storage.as_ref())
}
}
}
impl<S: Storage> Drop for ArenaBuffer<S> {
    /// Returns this buffer's pages to the arena's free list.
    fn drop(&mut self) {
        // Recover from a poisoned mutex instead of unwrapping: a panic inside
        // `drop` while another panic is unwinding aborts the process. The
        // free list is plain bookkeeping data and stays usable even if a
        // previous holder of the lock panicked.
        self.allocator
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .free(self.allocation);
    }
}
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use super::*;
use crate::block_manager::storage::SystemStorage;
// Every test uses a 10-page, 4 KiB-per-page (40 KiB) arena.
const PAGE_SIZE: usize = 4096;
const PAGE_COUNT: usize = 10;
const TOTAL_STORAGE_SIZE: usize = PAGE_SIZE * PAGE_COUNT;
// Fresh arena backed by system memory.
fn create_allocator() -> ArenaAllocator<SystemStorage> {
let storage = SystemStorage::new(TOTAL_STORAGE_SIZE).unwrap();
ArenaAllocator::new(storage, PAGE_SIZE).unwrap()
}
// Construction succeeds with a power-of-two page size.
#[test]
fn test_arena_allocator_new_success() {
let storage = SystemStorage::new(TOTAL_STORAGE_SIZE).unwrap();
let allocator_result = ArenaAllocator::new(storage, PAGE_SIZE);
assert!(allocator_result.is_ok());
}
// A non-power-of-two page size is rejected with PageSizeNotAligned.
#[test]
fn test_arena_allocator_new_invalid_page_size() {
let storage = SystemStorage::new(TOTAL_STORAGE_SIZE).unwrap();
let allocator_result = ArenaAllocator::new(storage, PAGE_SIZE + 1);
assert!(allocator_result.is_err());
assert_matches!(allocator_result, Err(ArenaError::PageSizeNotAligned));
}
// First allocation starts at the storage base address.
#[test]
fn test_allocate_single_buffer() {
let allocator = create_allocator();
let buffer_size = PAGE_SIZE * 2;
let buffer_result = allocator.allocate(buffer_size);
assert!(buffer_result.is_ok());
let buffer = buffer_result.unwrap();
assert_eq!(buffer.size(), buffer_size);
assert_eq!(buffer.address(), allocator.storage.addr()); }
// Consecutive allocations are laid out back to back.
#[test]
fn test_allocate_multiple_buffers() {
let allocator = create_allocator();
let buffer_size1 = PAGE_SIZE * 2;
let buffer1_result = allocator.allocate(buffer_size1);
assert!(buffer1_result.is_ok());
let buffer1 = buffer1_result.unwrap();
assert_eq!(buffer1.size(), buffer_size1);
assert_eq!(buffer1.address(), allocator.storage.addr());
let buffer_size2 = PAGE_SIZE * 3;
let buffer2_result = allocator.allocate(buffer_size2);
assert!(buffer2_result.is_ok());
let buffer2 = buffer2_result.unwrap();
assert_eq!(buffer2.size(), buffer_size2);
assert_eq!(
buffer2.address(),
allocator.storage.addr() + buffer_size1 as u64
);
}
// The whole arena can be handed out as one buffer.
#[test]
fn test_allocate_exact_size() {
let allocator = create_allocator();
let buffer_size = TOTAL_STORAGE_SIZE;
let buffer_result = allocator.allocate(buffer_size);
assert!(buffer_result.is_ok());
let buffer = buffer_result.unwrap();
assert_eq!(buffer.size(), buffer_size);
}
// Requests larger than the arena fail with AllocationFailed.
#[test]
fn test_allocate_too_large() {
let allocator = create_allocator();
let buffer_size = TOTAL_STORAGE_SIZE + PAGE_SIZE;
let buffer_result = allocator.allocate(buffer_size);
assert!(buffer_result.is_err());
assert_matches!(buffer_result, Err(ArenaError::AllocationFailed));
}
// Dropping a buffer returns its pages; the same range is reusable.
#[test]
fn test_buffer_drop_and_reallocate() {
let allocator = create_allocator();
let buffer_size = PAGE_SIZE * 6;
{
let buffer1 = allocator.allocate(buffer_size).unwrap();
assert_eq!(buffer1.size(), buffer_size);
assert_eq!(buffer1.address(), allocator.storage.addr());
}
let buffer2_result = allocator.allocate(buffer_size);
assert!(buffer2_result.is_ok());
let buffer2 = buffer2_result.unwrap();
assert_eq!(buffer2.size(), buffer_size);
assert_eq!(buffer2.address(), allocator.storage.addr()); }
// Two half-arena buffers exhaust the arena; a third allocation fails.
#[test]
fn test_allocate_fill_and_fail() {
let allocator = create_allocator();
let buffer_size_half = TOTAL_STORAGE_SIZE / 2;
let buffer1 = allocator.allocate(buffer_size_half).unwrap();
assert_eq!(buffer1.size(), buffer_size_half);
let buffer2 = allocator.allocate(buffer_size_half).unwrap();
assert_eq!(buffer2.size(), buffer_size_half);
assert_eq!(
buffer2.address(),
allocator.storage.addr() + buffer_size_half as u64
);
let buffer3_result = allocator.allocate(PAGE_SIZE);
assert!(buffer3_result.is_err());
assert_matches!(buffer3_result, Err(ArenaError::AllocationFailed));
}
// Sub-page request still succeeds (rounds up to one page internally).
#[test]
fn test_allocate_non_page_aligned_single_byte() {
let allocator = create_allocator();
let buffer = allocator.allocate(1).unwrap();
assert_eq!(buffer.size(), 1);
}
// size() reports the requested bytes, not the page-rounded amount.
#[test]
fn test_allocate_non_page_aligned_almost_full_page() {
let allocator = create_allocator();
let buffer = allocator.allocate(PAGE_SIZE - 1).unwrap();
assert_eq!(buffer.size(), PAGE_SIZE - 1);
}
// One byte over a page boundary consumes two pages but reports size + 1.
#[test]
fn test_allocate_non_page_aligned_just_over_one_page() {
let allocator = create_allocator();
let buffer = allocator.allocate(PAGE_SIZE + 1).unwrap();
assert_eq!(buffer.size(), PAGE_SIZE + 1);
}
// Half the arena plus one byte rounds up to 6 pages, so two such
// allocations (12 pages) cannot both fit in the 10-page arena.
#[test]
fn test_allocate_half_plus_one_byte_twice_exhausts_arena() {
let allocator = create_allocator();
let allocation_size = (PAGE_COUNT / 2 * PAGE_SIZE) + 1;
let buffer1_result = allocator.allocate(allocation_size);
assert!(buffer1_result.is_ok(), "First allocation should succeed");
let buffer1 = buffer1_result.unwrap();
assert_eq!(buffer1.size(), allocation_size);
let pages_for_first_alloc = (allocation_size as u64).div_ceil(allocator.page_size);
assert_eq!(pages_for_first_alloc, (PAGE_COUNT / 2 + 1) as u64);
let buffer2_result = allocator.allocate(allocation_size);
assert!(
buffer2_result.is_err(),
"Second allocation should fail due to insufficient pages"
);
assert_matches!(buffer2_result, Err(ArenaError::AllocationFailed));
}
// Five 2-page (PAGE_SIZE + 1 byte) allocations fill all 10 pages; the
// internal fragmentation leaves no room for even a 1-byte request.
#[test]
fn test_fill_with_non_aligned_and_fail() {
let allocator = create_allocator();
let single_alloc_size = PAGE_SIZE + 1; let num_possible_allocs = PAGE_COUNT / 2;
let mut allocated_buffers = Vec::with_capacity(num_possible_allocs);
for i in 0..num_possible_allocs {
let buffer_result = allocator.allocate(single_alloc_size);
assert!(buffer_result.is_ok(), "Allocation {} should succeed", i + 1);
let buffer = buffer_result.unwrap();
assert_eq!(buffer.size(), single_alloc_size);
allocated_buffers.push(buffer);
}
let final_alloc_result = allocator.allocate(1);
assert!(
final_alloc_result.is_err(),
"Final allocation of 1 byte should fail as arena is full"
);
assert_matches!(final_alloc_result, Err(ArenaError::AllocationFailed));
}
}