#[cfg(test)]
mod tests;
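
// A fixed-capacity allocator which hands out regions of a caller-provided
// byte buffer. Region data is allocated from the front of the buffer while
// region headers are allocated from the back, so the two grow towards each
// other. Freed headers are kept on an intrusive free list for reuse.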
use core::cell::{Cell, UnsafeCell};
use core::fmt::{self, Arguments};
use core::marker::PhantomData;
use core::mem::{align_of, forget, replace, size_of, MaybeUninit};
use core::num::NonZeroU8;
use core::ops::{Deref, DerefMut};
use core::ptr;
use core::slice;
use musli::buf::Error;
use musli::{Allocator, Buf};
use crate::DEFAULT_STACK_BUFFER;
/// Required alignment of the backing buffer and of every region header.
const ALIGNMENT: usize = 8;
/// The size of a region header in bytes.
const HEADER_U32: u32 = size_of::<Header>() as u32;
/// The maximum size of the backing buffer and of any single write.
const MAX_BYTES: u32 = i32::MAX as u32;
// Statically assert that the alignment of `Header` divides the alignment the
// allocator maintains for the backing buffer.
const _: () = {
    if ALIGNMENT % align_of::<Header>() != 0 {
        panic!("Header is not aligned by 8");
    }
};
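/// A fixed-size buffer with 8-byte alignment, suitable for backing a
/// [`Stack`] allocator.
///
/// The over-alignment guarantees that the alignment assertion in
/// [`Stack::new`] holds for this buffer.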
#[repr(align(8))]
pub struct StackBuffer<const N: usize = DEFAULT_STACK_BUFFER> {
data: [MaybeUninit<u8>; N],
}
impl<const N: usize> StackBuffer<N> {
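    /// Construct a new, uninitialized stack buffer.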
pub const fn new() -> Self {
Self {
            // SAFETY: An array of `MaybeUninit<u8>` is valid even when fully
            // uninitialized.
            data: unsafe { MaybeUninit::uninit().assume_init() },
}
}
}
impl<const N: usize> Default for StackBuffer<N> {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<const N: usize> Deref for StackBuffer<N> {
type Target = [MaybeUninit<u8>];
#[inline]
fn deref(&self) -> &Self::Target {
&self.data
}
}
impl<const N: usize> DerefMut for StackBuffer<N> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.data
}
}
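/// A no-std friendly allocator operating over a caller-provided slice of
/// bytes.
///
/// State is kept in an [`UnsafeCell`] so that multiple [`StackBuf`] handles
/// can share the allocator through `&self`.
///
/// # Examples
///
/// A minimal usage sketch. It assumes that `StackBuffer` and `Stack` are
/// re-exported from the crate root, which may not match the actual public
/// API:
///
/// ```ignore
/// use musli::{Allocator, Buf};
///
/// let mut buffer = StackBuffer::<1024>::new();
/// let alloc = Stack::new(&mut buffer);
///
/// let mut buf = alloc.alloc().expect("buffer exhausted");
/// assert!(buf.write(b"hello"));
/// assert_eq!(buf.as_slice(), b"hello");
/// ```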
pub struct Stack<'a> {
internal: UnsafeCell<Internal>,
_marker: PhantomData<&'a mut [MaybeUninit<u8>]>,
}
impl<'a> Stack<'a> {
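    /// Construct a new stack allocator wrapping the provided buffer.
    ///
    /// Trailing bytes which do not fit the 8-byte alignment are left unused.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is larger than [`MAX_BYTES`] or if it is not
    /// aligned on an 8-byte boundary.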
pub fn new(buffer: &'a mut [MaybeUninit<u8>]) -> Self {
        assert!(
            buffer.len() <= MAX_BYTES as usize,
            "Buffer of length {} is too large, expected at most {} bytes",
            buffer.len(),
            MAX_BYTES
        );
        assert!(
            buffer.as_ptr() as usize % ALIGNMENT == 0,
            "Buffer at address {:08x} is not aligned to {} bytes",
            buffer.as_ptr() as usize,
            ALIGNMENT
        );

        // Round the usable size down to a multiple of the alignment so that
        // headers carved from the end of the buffer stay aligned.
        let size = buffer.len() as u32;
        let size = size - size % (ALIGNMENT as u32);
Self {
internal: UnsafeCell::new(Internal {
free: None,
head: None,
tail: None,
bytes: 0,
headers: 0,
occupied: 0,
size,
data: buffer.as_mut_ptr(),
}),
_marker: PhantomData,
}
}
}
impl Allocator for Stack<'_> {
type Buf<'this> = StackBuf<'this> where Self: 'this;
#[inline(always)]
fn alloc(&self) -> Option<Self::Buf<'_>> {
        // Allocate an empty region; capacity is claimed as bytes are written.
        let region = unsafe { (*self.internal.get()).alloc(0)? };
Some(StackBuf {
region: Cell::new(region.id),
internal: &self.internal,
})
}
}
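/// A buffer allocated out of a [`Stack`] allocator.
///
/// The handle stores the identifier of its region rather than a pointer,
/// since the region may be moved by writes that force a reallocation.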
pub struct StackBuf<'a> {
region: Cell<HeaderId>,
internal: &'a UnsafeCell<Internal>,
}
impl<'a> Buf for StackBuf<'a> {
#[inline]
fn write(&mut self, bytes: &[u8]) -> bool {
        if bytes.is_empty() {
            return true;
        }

        // Reject writes beyond the maximum size the allocator supports.
        if bytes.len() > MAX_BYTES as usize {
            return false;
        }
let bytes_len = bytes.len() as u32;
unsafe {
let i = &mut *self.internal.get();
let region = i.region(self.region.get());
let len = region.len;
            // Fast path: the region already has enough spare capacity for
            // the incoming bytes.
            let mut region = 'out: {
                if region.cap - len >= bytes_len {
                    break 'out region;
                }

                // Otherwise grow the region, which may move it.
                let requested = len + bytes_len;

                let Some(region) = i.realloc(self.region.get(), len, requested) else {
                    return false;
                };

                self.region.set(region.id);
                region
            };
            // Copy the incoming bytes just past the initialized prefix.
            let dst = i.data.wrapping_add((region.start + len) as usize).cast();
            ptr::copy_nonoverlapping(bytes.as_ptr(), dst, bytes.len());
            region.len += bytes.len() as u32;
true
}
}
#[inline]
fn write_buffer<B>(&mut self, buf: B) -> bool
where
B: Buf,
{
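        // Fast path: if the incoming buffer is the region immediately
        // trailing this one in memory, its bytes can be shifted back and the
        // two regions merged without going through `write`.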
'out: {
let other_ptr = buf.as_slice().as_ptr().cast();
unsafe {
let i = &mut *self.internal.get();
let mut this = i.region(self.region.get());
debug_assert!(this.cap >= this.len);
let data_cap_ptr = this.data_cap_ptr(i.data);
if !ptr::eq(data_cap_ptr.cast_const(), other_ptr) {
break 'out;
}
                let Some(next) = this.next else {
                    break 'out;
                };

                // Prevent the incoming buffer from freeing its region on
                // drop; it is reclaimed manually through `free_region` below.
                forget(buf);

                let next = i.region(next);

                // Move the trailing bytes back so that they become
                // contiguous with the bytes already written to this region.
                let diff = this.cap - this.len;

                if diff > 0 {
                    let to_ptr = data_cap_ptr.wrapping_sub(diff as usize);
                    ptr::copy(data_cap_ptr, to_ptr, next.len as usize);
                }

                let old = i.free_region(next);

                this.cap += old.cap;
                this.len += old.len;
                return true;
}
}
self.write(buf.as_slice())
}
#[inline(always)]
fn len(&self) -> usize {
unsafe {
let i = &*self.internal.get();
i.header(self.region.get()).len as usize
}
}
#[inline(always)]
fn as_slice(&self) -> &[u8] {
unsafe {
let i = &*self.internal.get();
let this = i.header(self.region.get());
let ptr = i.data.wrapping_add(this.start as usize).cast();
slice::from_raw_parts(ptr, this.len as usize)
}
}
#[inline(always)]
fn write_fmt(&mut self, arguments: Arguments<'_>) -> Result<(), Error> {
fmt::write(self, arguments).map_err(|_| Error)
}
}
impl fmt::Write for StackBuf<'_> {
#[inline]
fn write_str(&mut self, s: &str) -> fmt::Result {
if !self.write(s.as_bytes()) {
return Err(fmt::Error);
}
Ok(())
}
}
impl Drop for StackBuf<'_> {
fn drop(&mut self) {
        // Give the region back to the allocator.
        unsafe {
            (*self.internal.get()).free(self.region.get());
        }
}
}
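/// A handle to a region, pairing its identifier with a raw pointer to its
/// header.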
struct Region {
id: HeaderId,
ptr: *mut Header,
}
impl Region {
    /// Pointer to one past the end of the region's capacity.
    #[inline]
    unsafe fn data_cap_ptr(&self, data: *mut MaybeUninit<u8>) -> *mut MaybeUninit<u8> {
        data.wrapping_add((self.start + self.cap) as usize)
    }

    /// Pointer to the base of the region's data.
    #[inline]
    unsafe fn data_base_ptr(&self, data: *mut MaybeUninit<u8>) -> *mut MaybeUninit<u8> {
        data.wrapping_add(self.start as usize)
    }
}
impl Deref for Region {
type Target = Header;
#[inline]
fn deref(&self) -> &Self::Target {
unsafe { &*self.ptr }
}
}
impl DerefMut for Region {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.ptr }
}
}
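/// The identifier of a region header. Identifiers are 1-based so that
/// `Option<HeaderId>` occupies a single byte thanks to the `NonZeroU8`
/// niche.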
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(test, derive(PartialOrd, Ord, Hash))]
#[repr(transparent)]
struct HeaderId(NonZeroU8);
impl HeaderId {
    /// Construct a new header identifier.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `value` is non-zero.
    #[inline]
    const unsafe fn new_unchecked(value: u8) -> Self {
        Self(NonZeroU8::new_unchecked(value))
    }

    /// Get the value of the identifier.
    #[inline]
    fn get(self) -> u8 {
        self.0.get()
    }
}
/// Bookkeeping for the allocator, kept behind an [`UnsafeCell`].
struct Internal {
    /// Head of the free list of headers available for reuse.
    free: Option<HeaderId>,
    /// First region in memory order.
    head: Option<HeaderId>,
    /// Last region in memory order.
    tail: Option<HeaderId>,
    /// Number of bytes of region data claimed from the front of the buffer.
    bytes: u32,
    /// Number of headers carved out of the end of the buffer so far.
    headers: u8,
    /// Number of regions currently in the `Occupy` state.
    occupied: u8,
    /// Usable size of the buffer, which shrinks as headers are allocated.
    size: u32,
    /// Base pointer of the backing buffer.
    data: *mut MaybeUninit<u8>,
}
impl Internal {
    /// Access the header at `at`.
    #[inline]
    fn header(&self, at: HeaderId) -> &Header {
        // SAFETY: Headers are always initialized before their identifiers
        // are handed out.
        unsafe {
            &*self
                .data
                .wrapping_add(self.region_to_addr(at))
                .cast::<Header>()
        }
    }

    /// Access the header at `at` through a raw mutable pointer.
    #[inline]
    fn header_mut(&mut self, at: HeaderId) -> *mut Header {
        self.data
            .wrapping_add(self.region_to_addr(at))
            .cast::<Header>()
    }

    /// Construct a [`Region`] handle for the given identifier.
    #[inline]
    fn region(&mut self, id: HeaderId) -> Region {
        Region {
            id,
            ptr: self.header_mut(id),
        }
    }
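    /// Unlink the header from the doubly-linked list of regions.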
unsafe fn unlink(&mut self, header: &Header) {
if let Some(next) = header.next {
(*self.header_mut(next)).prev = header.prev;
} else {
self.tail = header.prev;
}
if let Some(prev) = header.prev {
(*self.header_mut(prev)).next = header.next;
} else {
self.head = header.next;
}
}
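    /// Detach the region from its current position and move it to the back
    /// of the list, keeping the list ordered by region address.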
unsafe fn replace_back(&mut self, region: &mut Region) {
let prev = region.prev.take();
let next = region.next.take();
if let Some(prev) = prev {
(*self.header_mut(prev)).next = next;
}
if let Some(next) = next {
(*self.header_mut(next)).prev = prev;
}
if self.head == Some(region.id) {
self.head = next;
}
self.push_back(region);
}
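    /// Append the region to the back of the linked list.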
unsafe fn push_back(&mut self, region: &mut Region) {
if self.head.is_none() {
self.head = Some(region.id);
}
if let Some(tail) = self.tail.replace(region.id) {
region.prev = Some(tail);
(*self.region(tail).ptr).next = Some(region.id);
}
}
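    /// Reset the header, push it onto the free list, and unlink it from the
    /// linked list of regions, returning the previous header value.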
unsafe fn free_region(&mut self, region: Region) -> Header {
let old = region.ptr.replace(Header {
start: 0,
len: 0,
cap: 0,
state: State::Free,
next_free: self.free.replace(region.id),
prev: None,
next: None,
});
self.unlink(&old);
old
}
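    /// Allocate a region with at least `requested` bytes of capacity.
    ///
    /// An existing hole is reused when one is large enough; otherwise a
    /// header is taken from the free list or newly carved out of the end of
    /// the buffer.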
unsafe fn alloc(&mut self, requested: u32) -> Option<Region> {
        // Prefer reusing an existing hole which is large enough.
        if self.occupied > 0 {
            if let Some(mut region) =
                self.find_region(|h| h.state == State::Occupy && h.cap >= requested)
            {
                self.occupied -= 1;
                region.state = State::Used;
                return Some(region);
            }
        }
        let mut region = 'out: {
            // Reuse a header from the free list, placing its region at the
            // current end of the data section.
            if let Some(mut region) = self.pop_free() {
                let bytes = self.bytes + requested;

                if bytes > self.size {
                    return None;
                }

                region.start = self.bytes;
                region.state = State::Used;
                region.cap = requested;
                self.bytes = bytes;
                break 'out region;
            }

            // Carve a new header out of the end of the buffer, which shrinks
            // the usable size of the data section.
            let bytes = self.bytes + requested;
            let headers = self.headers.checked_add(1)?;
            let size = self.size.checked_sub(HEADER_U32)?;

            if bytes > size {
                return None;
            }

            let start = replace(&mut self.bytes, bytes);
            self.headers = headers;
            self.size = size;
let region = self.region(HeaderId::new_unchecked(headers));
region.ptr.write(Header {
start,
len: 0,
cap: requested,
state: State::Used,
next_free: None,
prev: None,
next: None,
});
region
};
self.push_back(&mut region);
Some(region)
}
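    /// Free the given region.
    ///
    /// A tail region is released back to the data section immediately, while
    /// an interior region is either marked as a hole or merged into its
    /// predecessor.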
unsafe fn free(&mut self, region: HeaderId) {
let mut region = self.region(region);
debug_assert_eq!(region.state, State::Used);
debug_assert_eq!(region.next_free, None);
        // The region is the tail; release it (and any trailing hole) back to
        // the data section.
        if region.next.is_none() {
            self.free_tail(region);
            return;
        }

        // The region is the first in the list, so there is nothing to merge
        // it into; mark it as an unoccupied hole instead.
        let Some(prev) = region.prev else {
            self.occupied += 1;
            region.state = State::Occupy;
            region.len = 0;
            return;
        };

        // Merge the freed region into its predecessor.
        let mut prev = self.region(prev);

        debug_assert!(matches!(prev.state, State::Occupy | State::Used));

        let region = self.free_region(region);

        prev.cap += region.cap;

        if region.next.is_none() {
            self.bytes = region.start;
        }
}
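    /// Free the tail region, also reclaiming a hole that immediately
    /// precedes it.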
unsafe fn free_tail(&mut self, current: Region) {
        debug_assert_eq!(self.tail, Some(current.id));

        let current = self.free_region(current);

        debug_assert_eq!(current.next, None);

        // Give the capacity back to the data section.
        self.bytes -= current.cap;

        let Some(prev) = current.prev else {
            return;
        };

        // If the new tail is an unoccupied hole, reclaim it as well.
        let prev = self.region(prev);

        if prev.state == State::Occupy {
            let prev = self.free_region(prev);
            self.bytes -= prev.cap;
            self.occupied -= 1;
        }
}
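    /// Grow the region identified by `from` to a capacity of `requested`
    /// bytes, preserving its first `len` bytes of data, and return the
    /// (possibly moved) region.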
unsafe fn realloc(&mut self, from: HeaderId, len: u32, requested: u32) -> Option<Region> {
let mut from = self.region(from);
        // The region is the tail, so it can be extended in place.
        if from.next.is_none() {
            let additional = requested - from.cap;

            if self.bytes + additional > self.size {
                return None;
            }

            from.cap += additional;
            self.bytes += additional;
            return Some(from);
        }
        // If the preceding region is a hole with enough combined capacity,
        // shift the data back into it and merge the two regions.
        'bail: {
            let Some(prev) = from.prev else {
                break 'bail;
            };

            let mut prev = self.region(prev);

            if prev.state != State::Occupy || prev.cap + len < requested {
                break 'bail;
            }

            let prev_ptr = prev.data_base_ptr(self.data);
            let from_ptr = from.data_base_ptr(self.data);

            let from = self.free_region(from);

            ptr::copy(from_ptr, prev_ptr, from.len as usize);

            prev.state = State::Used;
            prev.cap += from.cap;
            prev.len = from.len;
            return Some(prev);
        }
        // An empty region carries no data, so it can simply be moved to the
        // end of the data section with the requested capacity.
        if from.cap == 0 {
            let bytes = self.bytes + requested;

            if bytes > self.size {
                return None;
            }

            from.start = self.bytes;
            from.cap = requested;
            self.replace_back(&mut from);
            self.bytes = bytes;
            return Some(from);
        }
        // Fall back to allocating a fresh region and copying the data over.
        let mut to = self.alloc(requested)?;

        let from_data = self
            .data
            .wrapping_add(from.start as usize)
            .cast::<u8>()
            .cast_const();

        let to_data = self.data.wrapping_add(to.start as usize).cast::<u8>();

        ptr::copy_nonoverlapping(from_data, to_data, len as usize);
        to.len = len;
        self.free(from.id);
        Some(to)
}
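    /// Find the first region in the linked list that satisfies `condition`.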
unsafe fn find_region<T>(&mut self, mut condition: T) -> Option<Region>
where
T: FnMut(&Header) -> bool,
{
let mut next = self.head;
while let Some(id) = next {
let ptr = self.header_mut(id);
if condition(&*ptr) {
return Some(Region { id, ptr });
}
next = (*ptr).next;
}
None
}
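    /// Pop a reusable header off the free list, if any.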
unsafe fn pop_free(&mut self) -> Option<Region> {
let id = self.free.take()?;
let ptr = self.header_mut(id);
self.free = (*ptr).next_free.take();
Some(Region { id, ptr })
}
#[inline]
fn region_to_addr(&self, at: HeaderId) -> usize {
region_to_addr(self.size, self.headers, at)
}
}
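/// Map a header identifier to its byte offset within the buffer. Headers are
/// laid out from the end of the buffer towards the front, so identifier 1
/// lives at the highest address.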
#[inline]
fn region_to_addr(size: u32, headers: u8, at: HeaderId) -> usize {
(size + u32::from(headers - at.get()) * HEADER_U32) as usize
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum State {
    /// The header is on the free list and does not describe a region.
    Free = 0,
    /// The region is an unused hole between live allocations which may be
    /// reclaimed by a subsequent allocation.
    Occupy,
    /// The region is currently allocated.
    Used,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(align(8))]
struct Header {
    /// Offset of the region's data within the buffer.
    start: u32,
    /// Number of initialized bytes in the region.
    len: u32,
    /// Capacity of the region in bytes.
    cap: u32,
    /// The state of the region.
    state: State,
    /// Link to the next header on the free list, if this header is free.
    next_free: Option<HeaderId>,
    /// The previous region in memory order.
    prev: Option<HeaderId>,
    /// The next region in memory order.
    next: Option<HeaderId>,
}