use crate::refs::RefBox;
use crate::stack::StackVec;
use core::fmt;
use core::fmt::Write;
use core::mem::MaybeUninit;
use core::ops::Index;
use core::slice;
/// Byte-backed bump arena that hands out typed, correctly aligned slots.
///
/// Every allocation is carved from the end of a caller-provided buffer of
/// `MaybeUninit<u8>`; memory is only given back in LIFO order through
/// [`StackAlloc::goto_checkpoint`] or [`StackAlloc::free`]. The arena only
/// tracks raw bytes, so it never runs destructors for values stored in it.
pub struct StackAlloc<'a>(StackVec<'a, u8>);

impl<'lex> StackAlloc<'lex> {
    /// Creates an arena backed by `raw`.
    #[inline]
    pub const fn from_slice(raw: &'lex mut [MaybeUninit<u8>]) -> Self {
        Self(StackVec::from_slice(raw))
    }

    /// Reserves one uninitialised slot suitably aligned for `T`.
    ///
    /// Returns `None` if the backing stack cannot provide the bytes
    /// (alignment padding included).
    #[inline]
    pub fn alloc<T>(&mut self) -> Option<&'lex mut MaybeUninit<T>> {
        Some(&mut self.alloc_many(1)?[0])
    }

    /// Reserves `n` contiguous uninitialised slots suitably aligned for `T`.
    ///
    /// The padding needed to align the first slot is consumed from the arena
    /// as well. Returns `None` if the request does not fit, including when
    /// `n * size_of::<T>()` plus padding would overflow `usize`.
    #[inline]
    pub fn alloc_many<T>(&mut self, n: usize) -> Option<&'lex mut [MaybeUninit<T>]> {
        let curr_len = self.0.len();
        // SAFETY: `curr_len` bytes are already in use, so this address is at
        // most one past the end of the backing buffer.
        let curr_ptr = unsafe { self.0.get_base().add(curr_len) };
        let pad = curr_ptr.align_offset(align_of::<T>());
        // Checked arithmetic: a wrapped total would reserve fewer bytes than
        // the `n`-element slice built below, handing out memory we don't own.
        let total = n.checked_mul(size_of::<T>())?.checked_add(pad)?;
        // SAFETY: a successful `alloc(total)` reserves the `total` bytes that
        // start at `curr_ptr` (assumes `StackVec::alloc` grows the stack at its
        // current end — matches how `len` is used above). Skipping `pad` bytes
        // aligns the start for `T`, and the remaining bytes hold `n` slots.
        unsafe {
            self.0.alloc(total)?;
            let slot = curr_ptr.add(pad) as *mut MaybeUninit<T>;
            Some(slice::from_raw_parts_mut(slot, n))
        }
    }

    /// Moves `t` into a fresh slot and returns it wrapped in a [`RefBox`].
    ///
    /// Returns `None` (dropping `t`) when there is no room.
    #[inline]
    pub fn save<T>(&mut self, t: T) -> Option<RefBox<'lex, T>> {
        // SAFETY: the slot was just allocated, is initialised by `write`, and
        // no other reference to it exists.
        self.alloc().map(|x| unsafe { RefBox::new(x.write(t)) })
    }

    /// Moves the value held by `t` into this arena and returns a new
    /// [`RefBox`] pointing at the arena copy.
    ///
    /// Returns `None` when there is no room; `t` is then dropped unconsumed.
    #[inline]
    pub fn save_refboxed<T>(&mut self, t: RefBox<T>) -> Option<RefBox<'lex, T>> {
        // SAFETY: the slot is freshly allocated and initialised before the
        // `RefBox` is built over it.
        unsafe {
            let slot = self.alloc::<T>()?;
            let slot = slot.write(t.into_inner());
            Some(RefBox::new(slot))
        }
    }

    /// Clones every element of `t` into a freshly allocated, aligned slice.
    ///
    /// Returns `None` when there is no room. If a `clone` panics, the
    /// elements cloned so far are leaked (the arena never drops them).
    #[inline]
    pub fn save_slice<T: Clone>(&mut self, t: &[T]) -> Option<RefBox<'lex, [T]>> {
        let slot: &mut [MaybeUninit<T>] = self.alloc_many(t.len())?;
        for (s, x) in slot.iter_mut().zip(t.iter()) {
            s.write(x.clone());
        }
        // SAFETY: `slot` has exactly `t.len()` elements and every one of them
        // was initialised by the loop above.
        unsafe {
            let ans = core::slice::from_raw_parts_mut(slot.as_mut_ptr() as *mut T, t.len());
            Some(RefBox::new(ans))
        }
    }

    /// Returns a checkpoint (the number of bytes currently in use) that can
    /// later be passed to [`StackAlloc::goto_checkpoint`].
    #[inline]
    pub fn check_point(&self) -> usize {
        self.0.len()
    }

    /// Rewinds the arena to checkpoint `cp`, releasing every byte allocated
    /// after it.
    ///
    /// # Safety
    /// `cp` must come from [`StackAlloc::check_point`] on this arena. The
    /// released bytes are reused by later allocations, so no reference to a
    /// slot handed out after `cp` may be used afterwards.
    ///
    /// # Panics
    /// Panics with "checkpoint math is wrong" if `cp` exceeds the current
    /// length or the backing stack rejects the release.
    #[inline]
    pub unsafe fn goto_checkpoint(&mut self, cp: usize) {
        let to_free = self
            .0
            .len()
            .checked_sub(cp)
            .expect("checkpoint math is wrong");
        self.0.free(to_free).expect("checkpoint math is wrong");
    }

    /// Releases the top `n` bytes of the arena.
    ///
    /// Any failure reported by the backing stack is deliberately ignored.
    ///
    /// # Safety
    /// No reference into the released bytes may be used afterwards.
    #[inline]
    pub unsafe fn free(&mut self, n: usize) {
        let _ = self.0.free(n);
    }

    /// Number of bytes currently in use (padding included).
    #[inline]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when no bytes are in use.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Builds a pointer with address `addr` carrying the provenance of the
    /// arena's base pointer.
    #[inline(always)]
    pub fn with_addr(&self, addr: usize) -> *mut u8 {
        self.0.get_base().with_addr(addr)
    }
}
/// Streams formatted text directly into a [`StackAlloc`], producing a
/// `&'lex mut str` without an intermediate heap buffer.
///
/// Call [`StackWriter::finish`] to keep the text or
/// [`StackWriter::discard`] to release it.
#[must_use]
pub struct StackWriter<'me, 'lex> {
    /// Arena that receives the bytes.
    alloc: &'me mut StackAlloc<'lex>,
    /// Arena length when the writer was created; the text occupies
    /// `start..alloc.len()`.
    start: usize,
}

impl Write for StackWriter<'_, '_> {
    /// Appends the UTF-8 bytes of `s` to the arena, reporting `fmt::Error`
    /// when the backing stack refuses them.
    #[inline]
    fn write_str(&mut self, s: &str) -> fmt::Result {
        match self.alloc.0.push_slice(s.as_bytes()) {
            Some(()) => Ok(()),
            None => Err(fmt::Error),
        }
    }
}

impl<'me, 'lex> StackWriter<'me, 'lex> {
    /// Starts a new string at the current end of `alloc`.
    #[inline]
    pub fn new(alloc: &'me mut StackAlloc<'lex>) -> Self {
        Self {
            start: alloc.0.len(),
            alloc,
        }
    }

    /// Ends writing and returns everything written since [`StackWriter::new`]
    /// as a string slice that lives as long as the arena's buffer.
    #[inline]
    pub fn finish(self) -> &'lex mut str {
        let stack = &self.alloc.0;
        let written = stack.len() - self.start;
        // SAFETY: bytes `start..len` were appended exclusively through
        // `write_str`, i.e. copied from `&str` values, so they form valid
        // UTF-8 inside the arena's buffer.
        unsafe {
            let head = stack.get_base().add(self.start);
            let bytes = core::slice::from_raw_parts_mut(head, written);
            core::str::from_utf8_unchecked_mut(bytes)
        }
    }

    /// Throws away everything written since [`StackWriter::new`], returning
    /// those bytes to the arena.
    #[inline]
    pub fn discard(self) {
        let checkpoint = self.start;
        // SAFETY: `checkpoint` was the arena length at creation and the writer
        // is consumed, so no reference into the released bytes survives.
        unsafe { self.alloc.goto_checkpoint(checkpoint) }
    }
}
/// Typed bump allocator handing out `&'a mut T` slots from a caller-provided
/// buffer of `MaybeUninit<T>`.
pub struct StackAllocator<'a, T>(StackVec<'a, T>);

impl<'a, T> StackAllocator<'a, T> {
    /// Creates an allocator backed by `buf`.
    #[inline]
    pub const fn new(buf: &'a mut [MaybeUninit<T>]) -> Self {
        Self(StackVec::from_slice(buf))
    }

    /// Moves `elem` into the next free slot and returns a reference to it.
    ///
    /// Zero-sized values occupy no slot: they do not change
    /// [`len`](Self::len) and are never dropped.
    ///
    /// # Errors
    /// Returns `Err(elem)` when the buffer has no free slot.
    #[inline]
    pub fn save(&mut self, elem: T) -> Result<&'a mut T, T> {
        if size_of::<T>() == 0 {
            // Ownership of `elem` passes to the returned reference. Dropping
            // it here would hand the caller a `&mut T` to a value whose
            // destructor already ran, which is unsound for ZST guards/tokens.
            core::mem::forget(elem);
            // SAFETY: for a zero-sized `T`, any non-null, well-aligned pointer
            // (such as `dangling_mut`) is valid for reads and writes of `T`.
            return Ok(unsafe { &mut *core::ptr::dangling_mut() });
        }
        // SAFETY: after a successful `alloc(1)` the top slot is reserved and
        // `peek_raw` yields a pointer to it; it is initialised by `write`
        // before the reference is created.
        unsafe {
            match self.0.alloc(1) {
                None => Err(elem),
                Some(_) => {
                    let slot = self.0.peek_raw().unwrap_unchecked();
                    slot.write(elem);
                    Ok(&mut *slot)
                }
            }
        }
    }

    /// Returns a checkpoint (the current element count) for use with
    /// [`goto_checkpoint`](Self::goto_checkpoint) and
    /// [`try_index_checkpoint`](Self::try_index_checkpoint).
    #[inline]
    pub fn check_point(&self) -> usize {
        self.0.len()
    }

    /// Returns the elements saved since checkpoint `cp`, in storage order.
    ///
    /// Returns `None` if `cp` is past the current length or the backing stack
    /// cannot provide the range.
    ///
    /// # Safety
    /// The returned slice is not tied to the borrow of `self`; the caller must
    /// not use it after those elements are released (e.g. by
    /// [`goto_checkpoint`](Self::goto_checkpoint)).
    #[inline]
    pub unsafe fn try_index_checkpoint(&self, cp: usize) -> Option<&'a [T]> {
        let live = self.0.len().checked_sub(cp)?;
        let addr = self.0.peek_many(live)?.as_ptr().addr();
        // Rebuild the start pointer using the provenance of `peek_raw`.
        let p = self.0.peek_raw()?.with_addr(addr);
        // SAFETY: `p` addresses the first of the top `live` saved elements;
        // the caller upholds the lifetime contract documented above.
        unsafe { Some(slice::from_raw_parts(p, live)) }
    }

    /// Panicking variant of [`try_index_checkpoint`](Self::try_index_checkpoint).
    ///
    /// # Safety
    /// Same contract as [`try_index_checkpoint`](Self::try_index_checkpoint).
    ///
    /// # Panics
    /// Panics with "checkpoint math is wrong" when the range is unavailable.
    #[inline]
    pub unsafe fn index_checkpoint(&self, cp: usize) -> &'a [T] {
        unsafe {
            self.try_index_checkpoint(cp)
                .expect("checkpoint math is wrong")
        }
    }

    /// Rewinds to checkpoint `cp`, flushing every element saved after it.
    ///
    /// # Safety
    /// No reference obtained for those elements (via [`save`](Self::save) or
    /// [`try_index_checkpoint`](Self::try_index_checkpoint)) may be used
    /// afterwards.
    ///
    /// # Panics
    /// Panics with "checkpoint math is wrong" if `cp` exceeds the current
    /// length or the backing stack rejects the flush.
    #[inline]
    pub unsafe fn goto_checkpoint(&mut self, cp: usize) {
        let live = self
            .0
            .len()
            .checked_sub(cp)
            .expect("checkpoint math is wrong");
        self.0.flush(live).expect("checkpoint math is wrong");
    }

    /// Direct mutable access to the backing stack.
    ///
    /// # Safety
    /// The caller must keep the stack consistent with references already
    /// handed out by this allocator.
    #[inline(always)]
    pub unsafe fn get_inner(&mut self) -> &mut StackVec<'a, T> {
        &mut self.0
    }

    /// Builds a pointer with address `addr` carrying the provenance of the
    /// buffer's base pointer.
    #[inline(always)]
    pub fn with_addr(&self, addr: usize) -> *mut T {
        self.0.get_base().with_addr(addr)
    }

    /// Number of (non-zero-sized) elements currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when no elements are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stack::make_storage;
    use core::mem::ManuallyDrop;
    use core::mem::{MaybeUninit, align_of};

    /// Address of `slot`, used to check alignment of arena allocations.
    #[inline]
    fn addr_of<T>(slot: &mut MaybeUninit<T>) -> usize {
        slot as *mut _ as usize
    }

    /// Over-aligned type used to exercise alignment padding in the arena.
    #[repr(align(32))]
    #[derive(Copy, Clone, Debug, PartialEq)]
    struct OverAligned(u8);

    /// Zero-sized type.
    struct Zst;

    #[test]
    fn stack_alloc_alignment() {
        let mut backing: [_; 1024] = make_storage();
        let mut arena = StackAlloc::from_slice(&mut backing);

        let s1 = arena.alloc::<u16>().expect("u16 should fit");
        assert_eq!(addr_of(s1) % align_of::<u16>(), 0, "u16 not aligned");

        let s2 = arena.alloc::<OverAligned>().expect("OverAligned");
        assert_eq!(
            addr_of(s2) % align_of::<OverAligned>(),
            0,
            "OverAligned mis-aligned"
        );
        *s2 = MaybeUninit::new(OverAligned(2));
        // SAFETY: `s2` was initialised on the previous line.
        unsafe { assert_eq!(s2.assume_init(), OverAligned(2)) }

        let s3 = arena.alloc::<Zst>().expect("ZST");
        *s3 = MaybeUninit::new(Zst);
        assert_eq!(addr_of(s3) % align_of::<Zst>(), 0, "ZST mis-aligned");

        let s4 = arena.alloc::<[u64; 3]>().expect("[u64;3]");
        assert_eq!(
            addr_of(s4) % align_of::<[u64; 3]>(),
            0,
            "array mis-aligned"
        );

        // Exhaust the arena, then check that it stays exhausted.
        while arena.alloc::<u64>().is_some() {}
        assert!(arena.alloc::<u64>().is_none(), "OOM must remain OOM");
    }

    #[test]
    fn stack_writer_write_and_finish() {
        let mut backing: [_; 1024] = make_storage();
        let mut arena = StackAlloc::from_slice(&mut backing);

        let mut writer = StackWriter::new(&mut arena);
        write!(writer, "hello").unwrap();
        write!(writer, " world {}", 42).unwrap();
        let result = writer.finish();
        assert_eq!(result, "hello world 42");

        // Discarded output must be rolled back.
        let mut writer = StackWriter::new(&mut arena);
        write!(writer, "junk").unwrap();
        writer.discard();

        let mut writer = StackWriter::new(&mut arena);
        write!(writer, "finish").unwrap();
        let result2 = writer.finish();
        assert_eq!(result2, "finish");
        // Earlier results must survive later writers.
        assert_eq!(result, "hello world 42");

        // `len()` counts bytes in use, so only the two kept strings remain;
        // the discarded "junk" must not be counted.
        assert_eq!(
            arena.len(),
            result.len() + result2.len(),
            "discarded writer bytes should have been reclaimed"
        );
    }

    #[test]
    fn test_stack_allocator_basic() {
        use alloc::boxed::Box;
        let mut storage = [const { MaybeUninit::<Box<i32>>::uninit() }; 8];
        let mut alloc = StackAllocator::new(&mut storage);
        let a = alloc.save(Box::new(10)).unwrap();
        let b = alloc.save(Box::new(20)).unwrap();
        assert_eq!(**a, 10);
        assert_eq!(**b, 20);
        let cp = alloc.check_point();
        let c = alloc.save(Box::new(30)).unwrap();
        assert_eq!(*c, Box::new(30));
        // SAFETY: `c` is not used after the rewind.
        unsafe {
            alloc.goto_checkpoint(cp);
        }
        let d = alloc.save(Box::new(99)).unwrap();
        assert_eq!(*d, Box::new(99));
    }

    #[test]
    fn test_stack_alloc_boxes() {
        use alloc::boxed::Box;
        let mut storage = [const { MaybeUninit::uninit() }; 1024];
        let mut alloc = StackAlloc::from_slice(&mut storage);
        let a = alloc.save_slice(&[10, 2]).unwrap();
        let mut mem = ManuallyDrop::new(Box::new(20));
        let b = alloc
            .save_refboxed(unsafe { RefBox::drop_this(&mut mem) })
            .unwrap();
        assert_eq!(**b, 20);
        assert_eq!(&*a, &[10, 2]);
        let cp = alloc.check_point();
        let c = alloc.save(Box::new(30)).unwrap();
        assert_eq!(*c, Box::new(30));
        core::mem::drop(c);
        // SAFETY: `c` was dropped above, so nothing refers past `cp`.
        unsafe {
            alloc.goto_checkpoint(cp);
        }
        let d = alloc.save(Box::new(99)).unwrap();
        assert_eq!(*d, Box::new(99));
    }
}
/// Append-only registry of `T` values stored in a caller-provided buffer and
/// addressed by insertion index.
pub struct Registery<'a, T>(StackVec<'a, T>);

impl<'a, T> Registery<'a, T> {
    /// Creates a registry backed by `buf`.
    #[inline]
    pub const fn new(buf: &'a mut [MaybeUninit<T>]) -> Self {
        Self(StackVec::from_slice(buf))
    }

    /// Appends `elem` to the end of the registry.
    ///
    /// Zero-sized values are accepted without being stored: `elem` is dropped
    /// immediately and [`len`](Self::len) does not change.
    ///
    /// # Errors
    /// Returns `Err(elem)` when the buffer has no free slot.
    #[inline]
    pub fn save(&mut self, elem: T) -> Result<(), T> {
        if size_of::<T>() == 0 {
            return Ok(());
        }
        // SAFETY: after a successful `alloc(1)` the top slot is reserved and
        // `peek_raw` yields a pointer to it, which `write` initialises.
        unsafe {
            match self.0.alloc(1) {
                None => Err(elem),
                Some(_) => {
                    self.0.peek_raw().unwrap_unchecked().write(elem);
                    Ok(())
                }
            }
        }
    }

    /// Returns the element at index `cp`, or `None` if the backing stack has
    /// no element there.
    ///
    /// NOTE(review): the reference is not tied to the borrow of `self`; it
    /// must not be used after [`goto_checkpoint`](Self::goto_checkpoint)
    /// releases that element.
    #[inline]
    pub fn get(&self, cp: usize) -> Option<&'a T> {
        // SAFETY: `get_raw` only yields pointers to stored elements.
        unsafe { Some(&*self.0.get_raw(cp)?) }
    }

    /// Returns a checkpoint (the current element count).
    #[inline]
    pub fn check_point(&self) -> usize {
        self.0.len()
    }

    /// Returns the elements registered since checkpoint `cp`, in storage
    /// order.
    ///
    /// Returns `None` if `cp` is past the current length or the backing stack
    /// cannot provide the range.
    #[inline]
    pub fn try_index_checkpoint(&self, cp: usize) -> Option<&'a [T]> {
        // `checked_sub`: a stale/invalid `cp` must yield `None`, not an
        // overflow panic (debug) or a wrapped length (release).
        let live = self.0.len().checked_sub(cp)?;
        let addr = self.0.peek_many(live)?.as_ptr().addr();
        // Rebuild the start pointer using the provenance of `peek_raw`.
        let p = self.0.peek_raw()?.with_addr(addr);
        // SAFETY: `p` addresses the first of the top `live` stored elements.
        unsafe { Some(slice::from_raw_parts(p, live)) }
    }

    /// Panicking variant of [`try_index_checkpoint`](Self::try_index_checkpoint).
    ///
    /// # Panics
    /// Panics with "checkpoint math is wrong" when the range is unavailable.
    #[inline]
    pub fn index_checkpoint(&self, cp: usize) -> &'a [T] {
        self.try_index_checkpoint(cp)
            .expect("checkpoint math is wrong")
    }

    /// Rewinds to checkpoint `cp`, flushing every element registered after it.
    ///
    /// # Safety
    /// No reference to those elements (from [`get`](Self::get) or the
    /// checkpoint slices) may be used afterwards.
    ///
    /// # Panics
    /// Panics with "checkpoint math is wrong" if `cp` exceeds the current
    /// length or the backing stack rejects the flush.
    #[inline]
    pub unsafe fn goto_checkpoint(&mut self, cp: usize) {
        let live = self
            .0
            .len()
            .checked_sub(cp)
            .expect("checkpoint math is wrong");
        self.0.flush(live).expect("checkpoint math is wrong");
    }

    /// Direct mutable access to the backing stack.
    ///
    /// # Safety
    /// The caller must keep the stack consistent with references already
    /// handed out by this registry.
    #[inline(always)]
    pub unsafe fn get_inner(&mut self) -> &mut StackVec<'a, T> {
        &mut self.0
    }

    /// Builds a pointer with address `addr` carrying the provenance of the
    /// buffer's base pointer.
    #[inline(always)]
    pub fn with_addr(&self, addr: usize) -> *mut T {
        self.0.get_base().with_addr(addr)
    }

    /// Number of (non-zero-sized) elements currently registered.
    #[inline]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when nothing is registered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
/// Positional access: `registry[id]` is the element registered at index `id`.
impl<T> Index<usize> for Registery<'_, T> {
    type Output = T;

    /// Delegates to the backing stack's own indexing, so an out-of-range `id`
    /// is handled exactly as the stack handles it.
    fn index(&self, id: usize) -> &T {
        let backing = &self.0;
        &backing[id]
    }
}
/// Exercises save / get / checkpoint rewind / indexing on [`Registery`].
#[test]
fn test_registry_basic() {
    use alloc::boxed::Box;
    let mut storage = [const { MaybeUninit::<Box<i32>>::uninit() }; 8];
    let mut alloc = Registery::new(&mut storage);
    alloc.save(Box::new(10)).unwrap();
    let a = alloc.get(0).unwrap();
    alloc.save(Box::new(20)).unwrap();
    let b = alloc.get(1).unwrap();
    assert_eq!(**a, 10);
    assert_eq!(**b, 20);

    let cp = alloc.check_point();
    assert_eq!(cp, 2);
    alloc.save(Box::new(30)).unwrap();
    let c = alloc.get(2).unwrap();
    assert_eq!(*c, Box::new(30));
    assert_eq!(alloc.len(), 3);

    // SAFETY: `c` is not used after the rewind.
    unsafe {
        alloc.goto_checkpoint(cp);
    }
    assert_eq!(alloc.len(), 2);

    // The released slot is reused by the next save.
    alloc.save(Box::new(99)).unwrap();
    let d = alloc.get(2).unwrap();
    assert_eq!(*d, Box::new(99));
    assert_eq!(*alloc[0], 10);
}