use core::{
alloc::{Layout, LayoutError},
cell::{Cell, UnsafeCell},
mem::{self, MaybeUninit},
ops,
ptr::{self, NonNull},
};
use alloc_traits::AllocTime;
use crate::bump::{Allocation, Failure, Level};
use crate::leaked::LeakBox;
/// A bump allocator whose backing storage is a single inline region the
/// size and alignment of `T`.
///
/// The `#[repr(C)]` layout (index header followed by storage) matches that
/// of [`MemBump`], which is why `Bump<T>` can dereference to `MemBump` and
/// reuse its allocation logic.
// NOTE(review): the previous two `#[cfg_attr(..., doc = "```")]` attributes
// were residue of a stripped doc example; alone they emitted an unclosed
// code fence into rustdoc, so they were removed.
#[repr(C)]
pub struct Bump<T> {
    /// Number of storage bytes already handed out (the bump index).
    _index: Cell<usize>,
    /// The raw storage from which allocations are carved.
    _data: UnsafeCell<MaybeUninit<T>>,
}
/// Error returned by [`MemBump::from_mem`] when the provided memory region
/// is too small to hold the aligned allocator header.
#[derive(Debug)]
pub struct FromMemError {
// Private unit field: keeps the type from being constructed outside this crate.
_inner: (),
}
/// A bump allocator over an unsized, caller-provided byte region.
///
/// This is a dynamically sized type: the trailing `data` slice carries its
/// length in the fat-pointer metadata. `#[repr(C)]` fixes the header before
/// the data tail so a slice pointer over the region can be cast directly to
/// `*mut MemBump` (see `from_aligned_mem` and `Bump`'s `Deref` impl).
#[repr(C)]
pub struct MemBump {
/// Number of data bytes already consumed — the current bump level.
index: Cell<usize>,
/// Unsized storage; the slice length is the allocator's capacity.
data: UnsafeCell<[MaybeUninit<u8>]>,
}
impl<T> Bump<T> {
    /// Shared constructor: a fresh allocator (index 0) over `storage`.
    fn with_storage(storage: MaybeUninit<T>) -> Self {
        Bump {
            _index: Cell::new(0),
            _data: UnsafeCell::new(storage),
        }
    }

    /// Create a new allocator whose storage is left uninitialized.
    pub fn uninit() -> Self {
        Self::with_storage(MaybeUninit::uninit())
    }

    /// Create a new allocator whose storage is zero-initialized.
    pub fn zeroed() -> Self {
        Self::with_storage(MaybeUninit::zeroed())
    }
}
#[cfg(feature = "alloc")]
impl MemBump {
/// Allocate a `MemBump` with `capacity` data bytes on the global heap.
///
/// # Panics
/// Panics with "Bad layout" if the layout computation for `capacity`
/// overflows, and diverges via `handle_alloc_error` if the global
/// allocator fails.
pub fn new(capacity: usize) -> alloc::boxed::Box<Self> {
let layout = Self::layout_from_size(capacity).expect("Bad layout");
let ptr = NonNull::new(unsafe { alloc::alloc::alloc(layout) })
.unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout));
// Attach `capacity` as slice metadata so the fat pointer describes the
// unsized `data` tail of `MemBump`.
let ptr = ptr::slice_from_raw_parts_mut(ptr.as_ptr(), capacity);
// Initialize the header: the bump index starts at 0. The cast discards
// the slice metadata and points at the leading `Cell<usize>` field.
unsafe { ptr::write(ptr as *mut Cell<usize>, Cell::new(0)) };
// SAFETY: the allocation used the layout of a `MemBump` with `capacity`
// data bytes, and its header has just been initialized.
unsafe { alloc::boxed::Box::from_raw(ptr as *mut MemBump) }
}
}
impl MemBump {
pub fn from_mem(mem: &mut [MaybeUninit<u8>]) -> Result<LeakBox<'_, Self>, FromMemError> {
let header = Self::header_layout();
let offset = mem.as_ptr().align_offset(header.align());
let mem = mem.get_mut(offset..).ok_or(FromMemError { _inner: () })?;
mem.get_mut(..header.size())
.ok_or(FromMemError { _inner: () })?
.fill(MaybeUninit::new(0));
Ok(unsafe { Self::from_mem_unchecked(mem) })
}
pub unsafe fn from_mem_unchecked(mem: &mut [MaybeUninit<u8>]) -> LeakBox<'_, Self> {
let raw = Self::from_aligned_mem(mem);
LeakBox::from_mut_unchecked(raw)
}
#[allow(unused_unsafe)]
unsafe fn from_aligned_mem(mem: &mut [MaybeUninit<u8>]) -> &mut Self {
let header = Self::header_layout();
let datasize = mem.len() - header.size();
let datasize = datasize - datasize % header.align();
debug_assert!(Self::layout_from_size(datasize).map_or(false, |l| l.size() <= mem.len()));
let raw = mem.as_mut_ptr() as *mut u8;
unsafe { &mut *(ptr::slice_from_raw_parts_mut(raw, datasize) as *mut MemBump) }
}
pub fn into_mem<'lt>(this: LeakBox<'lt, Self>) -> &'lt mut [MaybeUninit<u8>] {
let layout = Layout::for_value(&*this);
let mem_pointer = LeakBox::into_raw(this) as *mut MaybeUninit<u8>;
unsafe { &mut *ptr::slice_from_raw_parts_mut(mem_pointer, layout.size()) }
}
fn header_layout() -> Layout {
Layout::new::<Cell<usize>>()
}
fn data_layout(size: usize) -> Result<Layout, LayoutError> {
Layout::array::<UnsafeCell<MaybeUninit<u8>>>(size)
}
pub(crate) fn layout_from_size(size: usize) -> Result<Layout, LayoutError> {
let data_tail = Self::data_layout(size)?;
let (layout, _) = Self::header_layout().extend(data_tail)?;
Ok(layout.pad_to_align())
}
pub const fn capacity(&self) -> usize {
unsafe { (*(self.data.get() as *const [UnsafeCell<u8>])).len() }
}
pub fn data_ptr(&self) -> NonNull<u8> {
NonNull::new(self.data.get() as *mut u8).expect("from a reference")
}
pub fn alloc(&self, layout: Layout) -> Option<NonNull<u8>> {
Some(self.try_alloc(layout)?.ptr)
}
pub fn alloc_at(&self, layout: Layout, level: Level) -> Result<NonNull<u8>, Failure> {
let Allocation { ptr, .. } = self.try_alloc_at(layout, level.0)?;
Ok(ptr)
}
pub fn get<V>(&self) -> Option<Allocation<V>> {
let alloc = self.try_alloc(Layout::new::<V>())?;
Some(Allocation {
lifetime: alloc.lifetime,
level: alloc.level,
ptr: alloc.ptr.cast(),
})
}
pub fn get_at<V>(&self, level: Level) -> Result<Allocation<V>, Failure> {
let alloc = self.try_alloc_at(Layout::new::<V>(), level.0)?;
Ok(Allocation {
lifetime: alloc.lifetime,
level: alloc.level,
ptr: alloc.ptr.cast(),
})
}
pub unsafe fn get_unchecked<V>(&self, level: Level) -> Allocation<V> {
debug_assert!(level.0 < self.capacity());
let ptr = self.data_ptr().as_ptr();
let alloc = ptr.offset(level.0 as isize) as *mut V;
Allocation {
level,
lifetime: AllocTime::default(),
ptr: NonNull::new_unchecked(alloc),
}
}
pub fn bump_box<'bump, T: 'bump>(
&'bump self,
) -> Result<LeakBox<'bump, MaybeUninit<T>>, Failure> {
let allocation = self.get_at(self.level())?;
Ok(unsafe { allocation.uninit() }.into())
}
pub fn bump_array<'bump, T: 'bump>(
&'bump self,
n: usize,
) -> Result<LeakBox<'bump, [MaybeUninit<T>]>, Failure> {
let layout = Layout::array::<T>(n).map_err(|_| Failure::Exhausted)?;
let raw = self.alloc(layout).ok_or(Failure::Exhausted)?;
let slice = ptr::slice_from_raw_parts_mut(raw.cast().as_ptr(), n);
let uninit = unsafe { &mut *slice };
Ok(uninit.into())
}
pub fn level(&self) -> Level {
Level(self.index.get())
}
pub fn reset(&mut self) {
self.index.set(0)
}
fn try_alloc(&self, layout: Layout) -> Option<Allocation<'_>> {
let consumed = self.index.get();
match self.try_alloc_at(layout, consumed) {
Ok(alloc) => return Some(alloc),
Err(Failure::Exhausted) => return None,
Err(Failure::Mismatch { observed: _ }) => {
unreachable!("Count in Cell concurrently modified, this UB")
}
}
}
fn try_alloc_at(
&self,
layout: Layout,
expect_consumed: usize,
) -> Result<Allocation<'_>, Failure> {
assert!(layout.size() > 0);
let length = mem::size_of_val(&self.data);
let data: &UnsafeCell<[MaybeUninit<u8>]> =
unsafe { &*(&self.data as *const _ as *const UnsafeCell<_>) };
let base_ptr = data.get() as *mut u8;
let alignment = layout.align();
let requested = layout.size();
assert!(expect_consumed <= length, "{}/{}", expect_consumed, length);
let available = length.checked_sub(expect_consumed).unwrap();
let ptr_to = base_ptr.wrapping_add(expect_consumed);
let offset = ptr_to.align_offset(alignment);
if Some(requested) > available.checked_sub(offset) {
return Err(Failure::Exhausted); }
assert!(offset < available);
let at_aligned = expect_consumed.checked_add(offset).unwrap();
let new_consumed = at_aligned.checked_add(requested).unwrap();
assert!(new_consumed <= length);
assert!(at_aligned < length);
match self.bump(expect_consumed, new_consumed) {
Ok(()) => (),
Err(observed) => {
return Err(Failure::Mismatch {
observed: Level(observed),
});
}
}
let aligned = unsafe {
(base_ptr as *mut u8).add(at_aligned)
};
Ok(Allocation {
ptr: NonNull::new(aligned).unwrap(),
lifetime: AllocTime::default(),
level: Level(new_consumed),
})
}
fn bump(&self, expect: usize, consume: usize) -> Result<(), usize> {
debug_assert!(consume <= self.capacity());
debug_assert!(expect <= consume);
let prev = self.index.get();
if prev != expect {
Err(prev)
} else {
self.index.set(consume);
Ok(())
}
}
}
impl<T> ops::Deref for Bump<T> {
type Target = MemBump;
/// View this sized allocator as the unsized `MemBump` it lays out as.
fn deref(&self) -> &MemBump {
let from_layout = Layout::for_value(self);
let data_layout = Layout::new::<MaybeUninit<T>>();
let ptr = self as *const Self as *const MaybeUninit<u8>;
// Build a fat pointer whose metadata is the data length in bytes; both
// types are #[repr(C)] with the same header, so the cast is layout-valid.
let mem: *const [MaybeUninit<u8>] = ptr::slice_from_raw_parts(ptr, data_layout.size());
let bump = unsafe { &*(mem as *const MemBump) };
// Cross-check: the DST view must occupy exactly the same layout.
debug_assert_eq!(from_layout, Layout::for_value(bump));
bump
}
}
impl<T> ops::DerefMut for Bump<T> {
/// Mutable counterpart of `Deref`: same fat-pointer construction, `mut`.
fn deref_mut(&mut self) -> &mut MemBump {
let from_layout = Layout::for_value(self);
let data_layout = Layout::new::<MaybeUninit<T>>();
let ptr = self as *mut Self as *mut MaybeUninit<u8>;
// Metadata = byte size of the `T` storage; layouts match per #[repr(C)].
let mem: *mut [MaybeUninit<u8>] = ptr::slice_from_raw_parts_mut(ptr, data_layout.size());
let bump = unsafe { &mut *(mem as *mut MemBump) };
// Cross-check: the DST view must occupy exactly the same layout.
debug_assert_eq!(from_layout, Layout::for_value(bump));
bump
}
}
/// The `Bump<T>` value and its `MemBump` deref view must report the same
/// size, i.e. the fabricated fat pointer covers the whole allocator.
// Fix: the original line read `= ≎;` — the source text `&bump;` had been
// mangled by an HTML-entity decoder (`&bump;` -> U+224E) and did not compile.
#[test]
fn mem_bump_derefs_correctly() {
    let bump = Bump::<usize>::zeroed();
    let mem: &MemBump = &bump;
    assert_eq!(mem::size_of_val(&bump), mem::size_of_val(mem));
}