use crate::domain::{DomainError, DomainResult};
use std::{
alloc::{Layout, alloc, dealloc, realloc},
ptr::NonNull,
};
/// Stateless handle for obtaining raw heap blocks with an explicit alignment.
///
/// Every operation forwards to the `std::alloc` global-allocator functions; the type carries
/// no data, so it is `Copy` and a shared instance is also available via [`aligned_allocator`].
#[derive(Clone, Copy, Debug, Default)]
pub struct AlignedAllocator;
impl AlignedAllocator {
    /// Creates an allocator handle.
    ///
    /// The type has no fields, so this is a `const fn` usable in `static` initializers.
    pub const fn new() -> Self {
        Self
    }

    /// Allocates `size` bytes whose start address is a multiple of `alignment`.
    ///
    /// The returned memory is not initialized.
    ///
    /// # Errors
    ///
    /// - [`DomainError::InvalidInput`] if `alignment` is not a power of two, if `size` is zero,
    ///   or if `size` rounded up to `alignment` would exceed `isize::MAX`.
    /// - [`DomainError::ResourceExhausted`] if the global allocator returns null.
    ///
    /// # Safety
    ///
    /// All arguments are validated before reaching the global allocator, so any values are
    /// sound to pass. The caller owns the returned block and must release it exactly once via
    /// [`dealloc_aligned`](Self::dealloc_aligned) (or hand it to
    /// [`realloc_aligned`](Self::realloc_aligned)) with a layout of `size` bytes and
    /// `alignment`.
    pub unsafe fn alloc_aligned(&self, size: usize, alignment: usize) -> DomainResult<NonNull<u8>> {
        if !alignment.is_power_of_two() {
            return Err(DomainError::InvalidInput(format!(
                "Alignment {} is not a power of 2",
                alignment
            )));
        }
        // `std::alloc::alloc` is undefined behavior for zero-sized layouts, so reject them here
        // instead of forwarding them to the global allocator.
        if size == 0 {
            return Err(DomainError::InvalidInput(String::from(
                "Cannot allocate 0 bytes: zero-sized allocations are not supported",
            )));
        }
        let layout = Layout::from_size_align(size, alignment)
            .map_err(|e| DomainError::InvalidInput(format!("Invalid layout: {}", e)))?;
        // SAFETY: `layout` is a valid layout with a non-zero size (checked above).
        let ptr = unsafe { alloc(layout) };
        NonNull::new(ptr).ok_or_else(|| {
            DomainError::ResourceExhausted(format!(
                "Failed to allocate {} bytes with alignment {}",
                size, alignment
            ))
        })
    }

    /// Resizes the block at `ptr` to `new_size` bytes, keeping `old_layout.align()`.
    ///
    /// On success the old pointer must no longer be used. The first
    /// `min(old_layout.size(), new_size)` bytes are preserved. On any error the original block
    /// is left untouched and is still owned by the caller.
    ///
    /// # Errors
    ///
    /// - [`DomainError::InvalidInput`] if `new_size` is zero, or if `new_size` rounded up to
    ///   `old_layout.align()` would exceed `isize::MAX`.
    /// - [`DomainError::ResourceExhausted`] if the global allocator returns null.
    ///
    /// # Safety
    ///
    /// - `ptr` must denote a block currently allocated through this allocator
    ///   ([`alloc_aligned`](Self::alloc_aligned) or a previous `realloc_aligned`) that has not
    ///   been freed.
    /// - `old_layout` must be exactly the layout that block was allocated with.
    pub unsafe fn realloc_aligned(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_size: usize,
    ) -> DomainResult<NonNull<u8>> {
        // `std::alloc::realloc` is undefined behavior when `new_size` is zero.
        if new_size == 0 {
            return Err(DomainError::InvalidInput(String::from(
                "Cannot reallocate to 0 bytes: zero-sized allocations are not supported",
            )));
        }
        // `realloc` also requires `new_size`, rounded up to the alignment, not to overflow
        // `isize`; `Layout::from_size_align` rejects exactly that case.
        let new_layout = Layout::from_size_align(new_size, old_layout.align())
            .map_err(|e| DomainError::InvalidInput(format!("Invalid layout: {}", e)))?;
        // SAFETY: the caller guarantees `ptr` is live and was allocated with `old_layout`;
        // `new_layout.size()` is non-zero and its rounded size fits in `isize` (checked above).
        let new_ptr = unsafe { realloc(ptr.as_ptr(), old_layout, new_layout.size()) };
        NonNull::new(new_ptr).ok_or_else(|| {
            DomainError::ResourceExhausted(format!(
                "Failed to reallocate to {} bytes",
                new_size
            ))
        })
    }

    /// Releases a block previously obtained from this allocator.
    ///
    /// # Safety
    ///
    /// - `ptr` must denote a block currently allocated through this allocator that has not
    ///   already been freed.
    /// - `layout` must be exactly the layout of that block. After a successful
    ///   [`realloc_aligned`](Self::realloc_aligned), that is the new size with the original
    ///   alignment.
    pub unsafe fn dealloc_aligned(&self, ptr: NonNull<u8>, layout: Layout) {
        // SAFETY: forwarded directly from this function's safety contract.
        unsafe { dealloc(ptr.as_ptr(), layout) };
    }
}
/// Process-wide instance handed out by [`aligned_allocator`].
static ALIGNED_ALLOCATOR: AlignedAllocator = AlignedAllocator::new();

/// Returns a `'static` reference to the shared [`AlignedAllocator`].
///
/// Every call yields a reference to the same `static` item, so callers never need to
/// construct their own handle.
pub fn aligned_allocator() -> &'static AlignedAllocator {
    let shared: &'static AlignedAllocator = &ALIGNED_ALLOCATOR;
    shared
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tested power-of-two alignment must yield a correctly aligned block.
    #[test]
    fn test_aligned_allocation() {
        let allocator = AlignedAllocator::new();
        for alignment in [16, 32, 64, 128, 256] {
            // SAFETY: the size is non-zero and the alignment is a power of two; the block is
            // freed below with the same layout.
            let ptr = unsafe { allocator.alloc_aligned(1024, alignment) }.unwrap();
            assert_eq!(
                ptr.as_ptr() as usize % alignment,
                0,
                "Pointer not aligned to {} bytes",
                alignment
            );
            let layout = Layout::from_size_align(1024, alignment).unwrap();
            // SAFETY: `ptr` was allocated with exactly `layout` and is freed only once.
            unsafe { allocator.dealloc_aligned(ptr, layout) };
        }
    }

    /// Growing a block must keep its alignment and preserve every byte of the old contents.
    #[test]
    fn test_reallocation() {
        let allocator = AlignedAllocator::new();
        let alignment = 64;
        let initial_size = 1024;
        let new_size = 2048;
        let layout = Layout::from_size_align(initial_size, alignment).unwrap();

        // SAFETY: non-zero size and power-of-two alignment.
        let ptr = unsafe { allocator.alloc_aligned(initial_size, alignment) }.unwrap();
        // SAFETY: `ptr` points to `initial_size` writable bytes owned by this test.
        unsafe { std::ptr::write_bytes(ptr.as_ptr(), 0xAB, initial_size) };

        // SAFETY: `ptr` is live and was allocated with `layout`; `new_size` is non-zero.
        let new_ptr = unsafe { allocator.realloc_aligned(ptr, layout, new_size) }.unwrap();
        assert_eq!(
            new_ptr.as_ptr() as usize % alignment,
            0,
            "Reallocated pointer not aligned"
        );

        // SAFETY: realloc preserves the first `initial_size` bytes, all initialized above; the
        // slice is not used after the block is freed.
        let preserved = unsafe { std::slice::from_raw_parts(new_ptr.as_ptr(), initial_size) };
        assert!(
            preserved.iter().all(|&byte| byte == 0xAB),
            "Data not preserved during reallocation"
        );

        let new_layout = Layout::from_size_align(new_size, alignment).unwrap();
        // SAFETY: `new_ptr` owns a block with `new_layout` and is freed only once.
        unsafe { allocator.dealloc_aligned(new_ptr, new_layout) };
    }

    /// Alignments that are not powers of two must be rejected with an error.
    #[test]
    fn test_invalid_alignment() {
        let allocator = AlignedAllocator::new();
        for alignment in [0, 3, 17] {
            // SAFETY: the alignment check rejects these values before any allocation happens.
            let result = unsafe { allocator.alloc_aligned(1024, alignment) };
            assert!(result.is_err(), "alignment {} should be rejected", alignment);
        }
    }

    /// The shared accessor must always hand out the same instance.
    #[test]
    fn test_aligned_allocator_singleton() {
        let a = aligned_allocator();
        let b = aligned_allocator();
        assert!(std::ptr::eq(a, b));
    }
}