use crate::bytes::{AllocationController, AllocationError};
use alloc::alloc::Layout;
use alloc::vec::Vec;
use bytemuck::Contiguous;
use core::{alloc::LayoutError, marker::PhantomData, mem::MaybeUninit, ptr::NonNull};
/// Maximum alignment supported by this module's allocations: the alignment
/// of `u128` on the current target. All constructors debug-assert against it.
pub const MAX_ALIGN: usize = core::mem::align_of::<u128>();
/// A raw memory region: base pointer, size in bytes, and alignment.
///
/// The `'a` lifetime ties the region to external memory it may borrow,
/// without storing an actual reference.
struct Allocation<'a> {
    // Base pointer; for size 0 this is a dangling sentinel at address `align`.
    ptr: NonNull<u8>,
    // Region size in bytes.
    size: usize,
    // Alignment the pointer satisfies; expected to be <= MAX_ALIGN.
    align: usize,
    _lifetime: PhantomData<&'a ()>,
}
impl<'a> Allocation<'a> {
    /// Wraps an existing, already-initialized allocation.
    ///
    /// # Safety
    /// `ptr` must point to a live region of `size` bytes aligned to `align`
    /// (or be the dangling sentinel at address `align` when `size == 0`),
    /// valid for the lifetime `'a`.
    pub unsafe fn new_init(ptr: NonNull<u8>, size: usize, align: usize) -> Self {
        debug_assert!(
            align <= MAX_ALIGN,
            "alignment exceeds maximum supported alignment"
        );
        // `align` is already a plain `usize`; the previous
        // `Contiguous::into_integer()` call was an identity conversion.
        debug_assert!(
            ptr.as_ptr().align_offset(align) == 0,
            "pointer is not properly aligned"
        );
        Self {
            ptr,
            size,
            align,
            _lifetime: PhantomData,
        }
    }

    /// Creates an empty (size 0) allocation whose pointer is the well-aligned
    /// dangling sentinel at address `align`. The pointer is never dereferenced.
    fn dangling(align: usize) -> Allocation<'a> {
        // Fabricate a non-null pointer at address `align` without provenance;
        // `Layout` guarantees align >= 1, so `NonNull::new` cannot fail here.
        let ptr = core::ptr::null_mut::<u8>().wrapping_add(align);
        Self {
            ptr: NonNull::new(ptr).unwrap(),
            size: 0,
            align,
            _lifetime: PhantomData,
        }
    }
}
/// [`AllocationController`] implementation backed by the global allocator
/// (`alloc::alloc`); owns its `Allocation` and frees it on drop.
pub(crate) struct NativeAllocationController<'a> {
    allocation: Allocation<'a>,
}
impl<'a> NativeAllocationController<'a> {
    /// Allocates a buffer whose capacity is `data.len()` rounded up to a
    /// multiple of `align`, and copies `data` into its head.
    ///
    /// # Errors
    /// Returns a [`LayoutError`] if the rounded capacity and `align` do not
    /// form a valid `Layout`.
    pub(crate) fn alloc_with_data(data: &[u8], align: usize) -> Result<Self, LayoutError> {
        debug_assert!(
            align <= MAX_ALIGN,
            "alignment exceeds maximum supported alignment"
        );
        // `alloc_with_capacity` requires a capacity that is a multiple of the
        // alignment. (`align` is a plain `usize`; the former
        // `Contiguous::into_integer()` calls were identity conversions.)
        let capacity = data.len().next_multiple_of(align);
        let mut controller = Self::alloc_with_capacity(capacity, align)?;
        // SAFETY: the buffer holds at least `data.len()` bytes, and viewing
        // initialized `u8`s as `MaybeUninit<u8>` is sound.
        unsafe {
            controller.memory_mut()[..data.len()].copy_from_slice(core::slice::from_raw_parts(
                data.as_ptr().cast(),
                data.len(),
            ));
        }
        Ok(controller)
    }

    /// Allocates `capacity` uninitialized bytes with the given alignment.
    /// `capacity` must be a multiple of `align` (checked in debug builds).
    ///
    /// # Errors
    /// Returns a [`LayoutError`] if `capacity`/`align` do not form a valid
    /// `Layout`.
    pub(crate) fn alloc_with_capacity(capacity: usize, align: usize) -> Result<Self, LayoutError> {
        debug_assert!(
            align <= MAX_ALIGN,
            "alignment exceeds maximum supported alignment"
        );
        debug_assert!(
            capacity.is_multiple_of(align),
            "capacity must be a multiple of alignment"
        );
        let layout = Layout::from_size_align(capacity, align)?;
        let ptr = buffer_alloc(layout);
        // SAFETY: `buffer_alloc` returns a live allocation for `layout`, or
        // the valid dangling sentinel when `layout.size() == 0`.
        let allocation = unsafe { Allocation::new_init(ptr, layout.size(), layout.align()) };
        Ok(Self { allocation })
    }

    /// Takes over the heap buffer of `elems` without copying, reusing its
    /// bytes as the controller's memory. `E: NoUninit` guarantees the element
    /// bytes are fully initialized (no padding).
    pub(crate) fn from_elems<E>(elems: Vec<E>) -> Self
    where
        E: bytemuck::NoUninit + Send + Sync,
    {
        // Prevent the Vec from freeing its buffer; we take ownership of it.
        let mut elems = core::mem::ManuallyDrop::new(elems);
        // SAFETY: shrinking the length is always allowed; the element bytes
        // stay alive inside the buffer and are simply reinterpreted as data.
        unsafe { elems.set_len(0) };
        // With len == 0 the spare capacity spans the entire heap block, so
        // this layout matches exactly what the Vec allocated
        // (capacity * size_of::<E>(), align_of::<E>()).
        let data = elems.spare_capacity_mut();
        let layout = Layout::for_value(data);
        let ptr = NonNull::new(elems.as_mut_ptr() as *mut u8).unwrap();
        // SAFETY: `ptr` is the Vec's buffer pointer (dangling-but-aligned for
        // capacity 0, which matches the size-0 contract) and fits `layout`;
        // `new_init` re-checks alignment in debug builds.
        let alloc = unsafe { Allocation::new_init(ptr, layout.size(), layout.align()) };
        Self { allocation: alloc }
    }
}
impl AllocationController for NativeAllocationController<'_> {
    /// Grows the backing buffer to at least `size` bytes aligned to `align`.
    fn grow(&mut self, size: usize, align: usize) -> Result<(), AllocationError> {
        debug_assert!(
            align <= MAX_ALIGN,
            "alignment exceeds maximum supported alignment"
        );
        debug_assert!(
            size > self.allocation.size,
            "new size must be larger than current size"
        );
        let requested =
            Layout::from_size_align(size, align).map_err(|_| AllocationError::OutOfMemory)?;
        // SAFETY: the stored size/align describe the current live allocation,
        // which was itself produced from a valid `Layout`.
        let current = unsafe {
            Layout::from_size_align_unchecked(self.allocation.size, self.allocation.align)
        };
        let (layout, ptr) = buffer_grow(current, self.allocation.ptr, requested);
        // SAFETY: `buffer_grow` hands back a live allocation matching `layout`.
        self.allocation = unsafe { Allocation::new_init(ptr, layout.size(), layout.align()) };
        Ok(())
    }

    /// Hands the buffer pointer to the caller and leaves a zero-sized
    /// placeholder behind, so `Drop` has nothing left to free.
    fn try_detach(&mut self) -> Option<NonNull<u8>> {
        let detached = self.allocation.ptr;
        self.allocation = Allocation::dangling(self.allocation.align);
        Some(detached)
    }

    fn alloc_align(&self) -> usize {
        self.allocation.align
    }

    fn memory(&self) -> &[MaybeUninit<u8>] {
        let base = self.allocation.ptr.as_ptr().cast();
        // SAFETY: `base` covers `size` bytes for the allocation's lifetime.
        unsafe { core::slice::from_raw_parts(base, self.allocation.size) }
    }

    unsafe fn memory_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        let base = self.allocation.ptr.as_ptr().cast();
        let len = self.allocation.size;
        // SAFETY: `base` covers `len` writable bytes and `&mut self` grants
        // exclusive access.
        unsafe { core::slice::from_raw_parts_mut(base, len) }
    }
}
impl Drop for NativeAllocationController<'_> {
    fn drop(&mut self) {
        let (size, align) = (self.allocation.size, self.allocation.align);
        // SAFETY: `size`/`align` were validated when the allocation was made,
        // so they always form a valid layout.
        let layout = unsafe { Layout::from_size_align_unchecked(size, align) };
        buffer_dealloc(layout, self.allocation.ptr);
    }
}
/// Allocates a buffer for `layout` from the global allocator.
///
/// Zero-sized layouts are never heap-allocated: they get a well-aligned
/// dangling sentinel at address `layout.align()` instead.
fn buffer_alloc(layout: Layout) -> NonNull<u8> {
    if layout.size() != 0 {
        // SAFETY: the layout has a non-zero size, as `alloc` requires.
        let raw = unsafe { alloc::alloc::alloc(layout) };
        return NonNull::new(raw.cast()).unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout));
    }
    // `Layout` guarantees align >= 1, so this pointer is never null.
    NonNull::new(core::ptr::null_mut::<u8>().wrapping_add(layout.align())).unwrap()
}
/// Releases a buffer previously produced by `buffer_alloc`/`buffer_grow`.
///
/// Zero-sized buffers were never heap-allocated, so only verify (in debug
/// builds) that the sentinel pointer is the expected dangling value.
fn buffer_dealloc(layout: Layout, buffer: NonNull<u8>) {
    if layout.size() == 0 {
        expect_dangling(layout.align(), buffer);
        return;
    }
    // SAFETY: `buffer` was allocated with exactly this `layout`.
    unsafe { alloc::alloc::dealloc(buffer.as_ptr().cast(), layout) };
}
/// Grows `buffer` (currently described by `old_layout`) so it can hold at
/// least `min_layout`, returning the new layout and pointer.
///
/// Strategy: try `realloc` (which keeps the OLD alignment) whenever possible;
/// when a larger alignment is requested, speculatively `realloc` and accept
/// the result only if it happens to satisfy the new alignment. Otherwise fall
/// back to alloc + copy + dealloc.
fn buffer_grow(
    old_layout: Layout,
    buffer: NonNull<u8>,
    min_layout: Layout,
) -> (Layout, NonNull<u8>) {
    // The target alignment never shrinks; the size is rounded up to a
    // multiple of it so the "capacity is a multiple of align" invariant holds.
    let new_align = min_layout.align().max(old_layout.align()); let new_size = min_layout.size().next_multiple_of(new_align);
    if new_size > isize::MAX as usize {
        alloc_overflow();
    }
    assert!(new_size > old_layout.size(), "size must actually grow");
    // A zero-sized "allocation" is just the dangling sentinel, so there is
    // nothing to move: perform a fresh allocation.
    if old_layout.size() == 0 {
        expect_dangling(old_layout.align(), buffer);
        let new_layout = Layout::from_size_align(new_size, new_align).unwrap();
        let buffer = buffer_alloc(new_layout);
        return (new_layout, buffer);
    };
    // Note: realloc preserves the allocation's ORIGINAL alignment, which is
    // why the layout here uses `old_layout.align()`.
    let realloc = || {
        let new_layout = Layout::from_size_align(new_size, old_layout.align()).unwrap();
        // SAFETY: `buffer` was allocated with `old_layout`; `new_size` is a
        // valid, non-overflowing size for that alignment.
        let ptr = unsafe { alloc::alloc::realloc(buffer.as_ptr(), old_layout, new_layout.size()) };
        (new_layout, ptr)
    };
    // Same alignment requested: plain realloc is always correct.
    if new_align == old_layout.align() {
        let (new_layout, ptr) = realloc();
        let buffer = NonNull::new(ptr);
        let buffer = buffer.unwrap_or_else(|| alloc::alloc::handle_alloc_error(new_layout));
        return (new_layout, buffer);
    }
    // Process-wide flag: keep speculating that realloc returns pointers that
    // happen to satisfy the larger alignment; after the first observed
    // violation, skip straight to the alloc+copy path on future grows.
    #[cfg(target_has_atomic = "8")]
    mod alignment_assumption {
        use core::sync::atomic::{AtomicBool, Ordering};
        static SPECULATE: AtomicBool = AtomicBool::new(true);
        pub fn speculate() -> bool {
            SPECULATE.load(Ordering::Relaxed)
        }
        pub fn report_violation() {
            SPECULATE.store(false, Ordering::Relaxed)
        }
    }
    // Without byte-sized atomics there is no safe flag, so never speculate.
    #[cfg(not(target_has_atomic = "8"))]
    mod alignment_assumption {
        pub fn speculate() -> bool {
            false
        }
        pub fn report_violation() {}
    }
    let mut old_buffer = buffer;
    let mut old_layout = old_layout;
    if alignment_assumption::speculate() {
        let (realloc_layout, ptr) = realloc();
        if let Some(buffer) = NonNull::new(ptr) {
            if buffer.align_offset(new_align) == 0 {
                // Lucky: the reallocated pointer already satisfies the new
                // alignment; hand it back (layout keeps the old alignment,
                // matching how the block was actually allocated).
                return (realloc_layout, buffer);
            }
            // Realloc succeeded but is misaligned for `new_align`: remember
            // that, then copy out of the REALLOCATED buffer below.
            alignment_assumption::report_violation();
            old_buffer = buffer.cast();
            old_layout = realloc_layout;
        } else {
            // Realloc failed; the original buffer is still intact, so fall
            // through to the fresh alloc+copy path using it as the source.
        }
    }
    // Fallback: allocate at the required alignment, copy, free the old block.
    let new_layout = Layout::from_size_align(new_size, new_align).unwrap();
    let new_buffer = buffer_alloc(new_layout);
    // SAFETY: both regions are live, `new_buffer` holds at least
    // `old_layout.size()` bytes, and a fresh allocation cannot overlap the
    // old one.
    unsafe {
        core::ptr::copy_nonoverlapping(
            old_buffer.as_ptr(),
            new_buffer.as_ptr().cast(),
            old_layout.size(),
        );
    }
    buffer_dealloc(old_layout, old_buffer);
    (new_layout, new_buffer)
}
/// Debug-build sanity check that a zero-sized buffer uses the dangling
/// sentinel pointer whose address equals its alignment.
fn expect_dangling(align: usize, buffer: NonNull<u8>) {
    let addr = buffer.as_ptr();
    debug_assert!(
        addr.wrapping_sub(align).is_null(),
        "expected a nullptr for size 0"
    );
}
/// Panics when a requested allocation size exceeds the supported maximum
/// (`isize::MAX`). Marked `#[cold]` to keep the happy path's codegen tight.
#[cold]
pub fn alloc_overflow() -> ! {
    panic!("Overflow, too many elements")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bytes::AllocationController;

    #[test]
    fn test_core_allocation_controller_alloc_with_capacity() {
        let controller = NativeAllocationController::alloc_with_capacity(64, 8).unwrap();
        assert_eq!(controller.alloc_align(), 8);
        assert_eq!(controller.memory().len(), 64);
    }

    #[test]
    fn test_core_allocation_controller_alloc_with_data() {
        let data = b"hello world test";
        let controller = NativeAllocationController::alloc_with_data(data, 8).unwrap();
        assert_eq!(controller.alloc_align(), 8);
        // Capacity is rounded up to a multiple of the alignment.
        assert!(controller.memory().len() >= data.len());
        assert_eq!(controller.memory().len() % 8, 0);
        let memory = controller.memory();
        // SAFETY: the first `data.len()` bytes were initialized by the copy
        // in `alloc_with_data`.
        let memory_slice =
            unsafe { core::slice::from_raw_parts(memory.as_ptr() as *const u8, data.len()) };
        assert_eq!(memory_slice, data);
    }

    #[test]
    fn test_core_allocation_controller_from_elems() {
        // `vec![…]` allocates exactly `len` elements, so the reclaimed buffer
        // covers precisely len * size_of::<u32>() bytes.
        let elems = vec![1u32, 2, 3, 4];
        let expected_bytes = elems.len() * core::mem::size_of::<u32>();
        let controller = NativeAllocationController::from_elems(elems);
        assert_eq!(controller.alloc_align(), core::mem::align_of::<u32>());
        assert_eq!(controller.memory().len(), expected_bytes);
    }

    #[test]
    fn test_core_allocation_controller_grow() {
        let mut controller = NativeAllocationController::alloc_with_capacity(32, 8).unwrap();
        let old_memory_len = controller.memory().len();
        controller.grow(64, 8).unwrap();
        assert_eq!(controller.alloc_align(), 8);
        assert!(controller.memory().len() >= 64);
        assert!(controller.memory().len() > old_memory_len);
    }

    #[test]
    fn test_buffer_alloc_zero_size() {
        // Zero-sized allocations return the aligned dangling sentinel.
        let layout = Layout::from_size_align(0, 8).unwrap();
        let ptr = buffer_alloc(layout);
        assert_eq!(ptr.as_ptr().align_offset(8), 0);
        buffer_dealloc(layout, ptr);
    }

    #[test]
    fn test_buffer_grow_from_zero() {
        let old_layout = Layout::from_size_align(0, 8).unwrap();
        let buffer = buffer_alloc(old_layout);
        let min_layout = Layout::from_size_align(64, 8).unwrap();
        let (new_layout, new_buffer) = buffer_grow(old_layout, buffer, min_layout);
        assert!(new_layout.size() >= 64);
        assert_eq!(new_layout.align(), 8);
        buffer_dealloc(new_layout, new_buffer);
    }

    #[test]
    fn test_memory_access() {
        let data = b"test data";
        let controller = NativeAllocationController::alloc_with_data(data, 8).unwrap();
        let memory = controller.memory();
        assert!(memory.len() >= data.len());
        assert_eq!(memory.len() % 8, 0);
        // SAFETY: the head of the buffer was initialized from `data`.
        let memory_slice =
            unsafe { core::slice::from_raw_parts(memory.as_ptr() as *const u8, data.len()) };
        assert_eq!(memory_slice, data);
    }

    #[test]
    fn test_memory_mut_access() {
        let mut controller = NativeAllocationController::alloc_with_capacity(16, 8).unwrap();
        // SAFETY: writing through `memory_mut` initializes the bytes we later
        // read back with `assume_init`.
        unsafe {
            let memory = controller.memory_mut();
            assert_eq!(memory.len(), 16);
            memory[0].write(42);
            memory[1].write(84);
        }
        let memory = controller.memory();
        unsafe {
            assert_eq!(memory[0].assume_init(), 42);
            assert_eq!(memory[1].assume_init(), 84);
        }
    }

    // The checked precondition is a `debug_assert!`, which is compiled out in
    // release builds; gate the test so `cargo test --release` does not fail
    // with "test did not panic as expected".
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "capacity must be a multiple of alignment")]
    fn test_debug_assert_capacity_alignment_mismatch() {
        let _ = NativeAllocationController::alloc_with_capacity(33, 8);
    }
}