use core::cell::UnsafeCell;
use core::marker::PhantomData;
use core::mem::{MaybeUninit, align_of, forget, replace, size_of};
use core::num::NonZeroU16;
use core::ops::{Deref, DerefMut};
use core::ptr;
use super::{Alloc, AllocError, Allocator, SliceBuffer};
#[cfg(test)]
mod tests;
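/// The maximum number of bytes the slice allocator will accept, both for the
/// backing buffer and for a single allocation request.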
const MAX_BYTES: usize = i32::MAX as usize;
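/// An allocator backed by a caller-provided byte buffer. Allocated bytes grow
/// from the front of the buffer while region headers grow from the back, so
/// no other source of memory is needed.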
pub struct Slice<'a> {
internal: UnsafeCell<Internal>,
_marker: PhantomData<&'a mut [MaybeUninit<u8>]>,
}
impl<'a> Slice<'a> {
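/// Construct a new slice allocator wrapping the given buffer.
///
/// # Panics
///
/// Panics if the buffer is larger than `MAX_BYTES` bytes.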
pub fn new<B>(buffer: &'a mut B) -> Self
where
B: ?Sized + SliceBuffer,
{
let buffer = buffer.as_uninit_bytes();
let size = buffer.len();
assert!(
size <= MAX_BYTES,
"Buffer of {size} bytes is larger than the maximum {MAX_BYTES}"
);
let mut data = Range::new(buffer.as_mut_ptr_range());
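// Headers are carved off the back of the buffer, so align the end pointer
// down to the alignment of `Header` (falling back to an empty range if the
// buffer is too small to align).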
let align = data.end.align_offset(align_of::<Header>());
if align != 0 {
let sub = align_of::<Header>() - align;
if sub <= size {
unsafe {
data.end = data.end.sub(sub);
}
} else {
data.end = data.start;
}
}
Self {
internal: UnsafeCell::new(Internal {
free_head: None,
tail: None,
occupied: None,
full: data,
free: data,
}),
_marker: PhantomData,
}
}
}
unsafe impl<'a> Allocator for &'a Slice<'_> {
#[inline]
fn __do_not_implement() {}
const IS_GLOBAL: bool = false;
type Alloc<T> = SliceAlloc<'a, T>;
#[inline]
fn alloc<T>(self, value: T) -> Result<Self::Alloc<T>, AllocError> {
if size_of::<T>() == 0 {
return Ok(SliceAlloc::ZST);
}
let region = unsafe {
let i = &mut *self.internal.get();
let region = i.alloc(size_of::<T>(), align_of::<T>());
let Some(region) = region else {
return Err(AllocError);
};
region.range.start.cast::<T>().write(value);
Some(region.id)
};
Ok(SliceAlloc {
region,
internal: Some(&self.internal),
cap: 1,
_marker: PhantomData,
})
}
#[inline]
fn alloc_empty<T>(self) -> Self::Alloc<T> {
if size_of::<T>() == 0 {
return SliceAlloc::ZST;
}
let region = unsafe {
(*self.internal.get())
.alloc(0, align_of::<T>())
.map(|r| r.id)
};
SliceAlloc {
region,
internal: Some(&self.internal),
cap: 0,
_marker: PhantomData,
}
}
}
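/// An allocation handed out by a [`Slice`] allocator. It refers to a region
/// inside the underlying buffer and releases that region when dropped.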
pub struct SliceAlloc<'a, T> {
region: Option<HeaderId>,
internal: Option<&'a UnsafeCell<Internal>>,
cap: usize,
_marker: PhantomData<T>,
}
impl<T> SliceAlloc<'_, T> {
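/// Handle used for zero-sized types: it owns no region and carries no
/// reference back to the allocator.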
const ZST: Self = Self {
region: None,
internal: None,
cap: usize::MAX,
_marker: PhantomData,
};
/// Release the region backing this handle, if any.
#[inline]
fn free(&mut self) {
let (Some(region), Some(internal)) = (self.region.take(), self.internal) else {
return;
};
unsafe {
(*internal.get()).free(region);
}
}
}
impl<T> Alloc<T> for SliceAlloc<'_, T> {
#[inline]
fn as_ptr(&self) -> *const T {
let (Some(region), Some(internal)) = (self.region, self.internal) else {
return ptr::NonNull::dangling().as_ptr();
};
unsafe {
let i = &*internal.get();
let this = i.header(region);
this.range.start.cast_const().cast()
}
}
#[inline]
fn as_mut_ptr(&mut self) -> *mut T {
let (Some(region), Some(internal)) = (self.region, self.internal) else {
return ptr::NonNull::dangling().as_ptr();
};
unsafe {
let i = &*internal.get();
let this = i.header(region);
this.range.start.cast()
}
}
#[inline]
fn capacity(&self) -> usize {
let Some(internal) = self.internal else {
return usize::MAX;
};
let Some(region_id) = self.region else {
return 0;
};
unsafe {
let i = &mut *internal.get();
i.region(region_id).capacity()
}
}
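/// Ensure that there is capacity for `len + additional` elements, growing the
/// backing region if necessary.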
#[inline]
fn resize(&mut self, len: usize, additional: usize) -> Result<(), AllocError> {
// Use checked arithmetic so that a huge `len + additional` cannot wrap
// around and incorrectly take the fast path.
if len.checked_add(additional).is_some_and(|n| n <= self.cap) {
return Ok(());
}
let Some(internal) = self.internal else {
debug_assert_eq!(
size_of::<T>(),
0,
"Only ZSTs should lack an internal pointer"
);
return Ok(());
};
let Some(region_id) = self.region else {
return Err(AllocError);
};
let Some(len) = len.checked_mul(size_of::<T>()) else {
return Err(AllocError);
};
let Some(additional) = additional.checked_mul(size_of::<T>()) else {
return Err(AllocError);
};
let Some(requested) = len.checked_add(additional) else {
return Err(AllocError);
};
if requested > MAX_BYTES {
return Err(AllocError);
}
unsafe {
let i = &mut *internal.get();
let region = i.region(region_id);
let actual = region.capacity();
if actual >= requested {
self.cap = actual / size_of::<T>();
return Ok(());
};
let Some(region) = i.realloc(region_id, len, requested, align_of::<T>()) else {
return Err(AllocError);
};
self.region = Some(region.id);
self.cap = region.capacity() / size_of::<T>();
Ok(())
}
}
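/// Try to absorb `buf` when its region immediately follows this one in the
/// buffer, moving its `other_len` elements down so that they sit directly
/// after this allocation's `this_len` elements.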
#[inline]
fn try_merge<B>(&mut self, this_len: usize, buf: B, other_len: usize) -> Result<(), B>
where
B: Alloc<T>,
{
let Some(internal) = self.internal else {
return Ok(());
};
let Some(region) = self.region else {
return Err(buf);
};
let this_len = this_len * size_of::<T>();
let other_len = other_len * size_of::<T>();
let other_ptr = buf.as_ptr().cast();
unsafe {
let i = &mut *internal.get();
let mut this = i.region(region);
debug_assert!(this.capacity() >= this_len);
if !ptr::eq(this.range.end.cast_const(), other_ptr) {
return Err(buf);
}
let Some(next) = this.next else {
return Err(buf);
};
forget(buf);
let next = i.region(next);
let to = this.range.start.wrapping_add(this_len);
if this.range.end != to {
this.range.end.copy_to(to, other_len);
}
let old = i.free_region(next);
this.range.end = old.range.end;
Ok(())
}
}
}
impl<T> Drop for SliceAlloc<'_, T> {
#[inline]
fn drop(&mut self) {
self.free()
}
}
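/// A region handle: the region's identifier paired with a raw pointer to its
/// header.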
struct Region {
id: HeaderId,
ptr: *mut Header,
}
impl Deref for Region {
type Target = Header;
#[inline]
fn deref(&self) -> &Self::Target {
unsafe { &*self.ptr }
}
}
impl DerefMut for Region {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.ptr }
}
}
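/// One-based identifier of a header, counted backwards from the end of the
/// buffer. Backed by `NonZeroU16` so that `Option<HeaderId>` needs no extra
/// space.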
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(test, derive(PartialOrd, Ord, Hash))]
#[repr(transparent)]
struct HeaderId(NonZeroU16);
impl HeaderId {
#[cfg(test)]
const unsafe fn new_unchecked(value: u16) -> Self {
Self(unsafe { NonZeroU16::new_unchecked(value) })
}
#[inline]
fn new(value: isize) -> Option<Self> {
Some(Self(NonZeroU16::new(u16::try_from(value).ok()?)?))
}
#[inline]
fn get(self) -> usize {
self.0.get() as usize
}
}
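/// A raw start/end pointer pair. Unlike `core::ops::Range`, this type is
/// `Copy`.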
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Range {
start: *mut MaybeUninit<u8>,
end: *mut MaybeUninit<u8>,
}
impl Range {
fn new(range: core::ops::Range<*mut MaybeUninit<u8>>) -> Self {
Self {
start: range.start,
end: range.end,
}
}
fn head(self) -> Range {
Self {
start: self.start,
end: self.start,
}
}
}
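/// Allocator state. `full` spans the usable, header-aligned buffer and `free`
/// is the gap between allocated bytes (growing from the front) and allocated
/// headers (growing from the back). `free_head` heads the list of header
/// slots available for reuse, `tail` is the last region in the region list,
/// and `occupied` remembers a freed region at the front of the list whose
/// bytes may be reused by a later allocation.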
struct Internal {
free_head: Option<HeaderId>,
tail: Option<HeaderId>,
occupied: Option<HeaderId>,
full: Range,
free: Range,
}
impl Internal {
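/// Test helper: number of bytes currently taken from the front of the buffer.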
#[cfg(test)]
#[inline]
fn bytes(&self) -> usize {
unsafe { self.free.start.byte_offset_from(self.full.start) as usize }
}
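/// Test helper: number of header slots carved off the back of the buffer.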
#[cfg(test)]
#[inline]
fn headers(&self) -> usize {
unsafe {
self.full
.end
.cast::<Header>()
.offset_from(self.free.end.cast()) as usize
}
}
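/// Number of free bytes left between allocated data and allocated headers.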
#[inline]
fn remaining(&self) -> usize {
unsafe { self.free.end.byte_offset_from(self.free.start) as usize }
}
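/// Headers are stored back-to-front at the end of the buffer: identifier `n`
/// refers to the `Header` located `n` slots below `full.end`.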
#[inline]
fn header(&self, at: HeaderId) -> &Header {
unsafe { &*self.full.end.cast::<Header>().wrapping_sub(at.get()) }
}
#[inline]
fn header_mut(&mut self, at: HeaderId) -> *mut Header {
self.full.end.cast::<Header>().wrapping_sub(at.get())
}
#[inline]
fn region(&mut self, id: HeaderId) -> Region {
Region {
id,
ptr: self.header_mut(id),
}
}
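/// Unlink `header` from the doubly linked region list, fixing up its
/// neighbours and the tail pointer.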
unsafe fn unlink(&mut self, header: &Header) {
unsafe {
if let Some(next) = header.next {
(*self.header_mut(next)).prev = header.prev;
} else {
self.tail = header.prev;
}
if let Some(prev) = header.prev {
(*self.header_mut(prev)).next = header.next;
}
}
}
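/// Detach `region` from its current position in the region list and re-link
/// it as the new tail.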
unsafe fn replace_back(&mut self, region: &mut Region) {
unsafe {
let prev = region.prev.take();
let next = region.next.take();
if let Some(prev) = prev {
(*self.header_mut(prev)).next = next;
}
if let Some(next) = next {
(*self.header_mut(next)).prev = prev;
}
self.push_back(region);
}
}
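/// Append `region` after the current tail.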
unsafe fn push_back(&mut self, region: &mut Region) {
unsafe {
if let Some(tail) = self.tail.replace(region.id) {
region.prev = Some(tail);
(*self.header_mut(tail)).next = Some(region.id);
}
}
}
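/// Unlink `region` and push its header slot onto the free-header list,
/// returning the header's previous contents.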
unsafe fn free_region(&mut self, region: Region) -> Header {
unsafe {
self.unlink(®ion);
region.ptr.replace(Header {
range: self.full.head(),
next: self.free_head.replace(region.id),
prev: None,
})
}
}
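/// Acquire a header for a region ending at `end`, reusing a slot from the
/// free-header list when available and otherwise carving a new slot off the
/// back of the free range.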
unsafe fn alloc_header(&mut self, end: *mut MaybeUninit<u8>) -> Option<Region> {
if let Some(region) = self.free_head.take() {
let mut region = self.region(region);
region.range.start = self.free.start;
region.range.end = end;
return Some(region);
}
debug_assert_eq!(
self.free.end.align_offset(align_of::<Header>()),
0,
"End pointer should be aligned to header"
);
let ptr = self.free.end.cast::<Header>().wrapping_sub(1);
// Besides staying inside the free range, the new header slot must not
// overlap the bytes being handed out, which extend up to `end`.
if ptr.cast::<MaybeUninit<u8>>() < end
|| ptr < self.free.start.cast()
|| ptr >= self.free.end.cast()
{
return None;
}
unsafe {
let id = HeaderId::new(self.full.end.cast::<Header>().offset_from(ptr))?;
ptr.write(Header {
range: Range::new(self.free.start..end),
prev: None,
next: None,
});
self.free.end = ptr.cast();
Some(Region { id, ptr })
}
}
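/// Allocate `requested` bytes with the given alignment, reusing the cached
/// occupied region when it fits and otherwise bumping the free range.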
unsafe fn alloc(&mut self, requested: usize, align: usize) -> Option<Region> {
if let Some(occupied) = self.occupied {
let region = self.region(occupied);
if region.capacity() >= requested && region.is_aligned(align) {
self.occupied = None;
return Some(region);
}
}
unsafe {
self.align(align)?;
if self.remaining() < requested {
return None;
}
let end = self.free.start.wrapping_add(requested);
let mut region = self.alloc_header(end)?;
self.free.start = end;
debug_assert!(self.free.start <= self.free.end);
self.push_back(&mut region);
Some(region)
}
}
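/// Advance the free start pointer until it satisfies `align`, donating the
/// padding bytes to the tail region (or to a fresh empty region when there is
/// no tail) so they are reclaimed when that region is freed.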
unsafe fn align(&mut self, align: usize) -> Option<()> {
let padding = self.free.start.align_offset(align);
if padding == 0 {
return Some(());
}
if self.remaining() < padding {
return None;
}
let aligned_start = self.free.start.wrapping_add(padding);
if let Some(tail) = self.tail {
self.region(tail).range.end = aligned_start;
} else {
unsafe {
let mut region = self.alloc_header(aligned_start)?;
self.push_back(&mut region);
}
}
self.free.start = aligned_start;
Some(())
}
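/// Free the given region. Tail regions hand their bytes straight back to the
/// free range, a freed head region is remembered in `occupied`, and any other
/// region is merged into its predecessor.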
unsafe fn free(&mut self, region: HeaderId) {
let region = self.region(region);
if region.next.is_none() {
unsafe {
self.free_tail(region);
}
return;
}
let Some(prev) = region.prev else {
debug_assert!(
self.occupied.is_none(),
"There can only be one occupied region"
);
self.occupied = Some(region.id);
return;
};
let mut prev = self.region(prev);
let region = unsafe { self.free_region(region) };
prev.range.end = region.range.end;
if region.next.is_none() {
self.free.start = region.range.start;
}
}
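/// Free the tail region, also reclaiming the `occupied` region when it sits
/// immediately before the tail.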
unsafe fn free_tail(&mut self, current: Region) {
debug_assert_eq!(self.tail, Some(current.id));
unsafe {
let current = self.free_region(current);
debug_assert_eq!(current.next, None);
self.free.start = match current.prev {
Some(prev) if self.occupied == Some(prev) => {
self.occupied = None;
let prev = self.region(prev);
self.free_region(prev).range.start
}
_ => current.range.start,
};
}
}
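/// Check that `additional` bytes fit after aligning the free start pointer,
/// returning the bumped start pointer without committing it.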
unsafe fn reserve(&mut self, additional: usize, align: usize) -> Option<*mut MaybeUninit<u8>> {
unsafe {
self.align(align)?;
}
let free_start = self.free.start.wrapping_add(additional);
if free_start > self.free.end || free_start < self.free.start {
return None;
}
Some(free_start)
}
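/// Grow region `from` to hold `requested` bytes: extend it in place when it
/// is the tail, relocate it to the back when it is empty, merge it into an
/// occupied predecessor when one fits, and otherwise allocate a new region
/// and copy `len` bytes across.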
unsafe fn realloc(
&mut self,
from: HeaderId,
len: usize,
requested: usize,
align: usize,
) -> Option<Region> {
let mut from = self.region(from);
if from.next.is_none() {
unsafe {
let additional = requested - from.capacity();
self.free.start = self.reserve(additional, align)?;
from.range.end = from.range.end.add(additional);
}
return Some(from);
}
if from.range.start == from.range.end {
unsafe {
let free_start = self.reserve(requested, align)?;
from.range.start = replace(&mut self.free.start, free_start);
from.range.end = free_start;
self.replace_back(&mut from);
}
return Some(from);
}
'bail: {
let Some(prev) = from.prev else {
break 'bail;
};
if self.occupied != Some(prev) {
break 'bail;
}
let mut prev = self.region(prev);
if prev.capacity() + from.capacity() < requested {
break 'bail;
}
if !prev.is_aligned(align) {
break 'bail;
}
unsafe {
let from = self.free_region(from);
from.range.start.copy_to(prev.range.start, len);
prev.range.end = from.range.end;
self.occupied = None;
return Some(prev);
}
}
unsafe {
let to = self.alloc(requested, align)?;
from.range.start.copy_to_nonoverlapping(to.range.start, len);
self.free(from.id);
Some(to)
}
}
}
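/// Metadata describing one region: the byte range it owns plus links to its
/// neighbours in the region list.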
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Header {
range: Range,
prev: Option<HeaderId>,
next: Option<HeaderId>,
}
impl Header {
#[inline]
fn capacity(&self) -> usize {
unsafe { self.range.end.byte_offset_from(self.range.start) as usize }
}
#[inline]
fn is_aligned(&self, align: usize) -> bool {
self.range.start.align_offset(align) == 0
}
}