//! A thread-local second stack: a growable allocation used to hand out
//! short-lived slices of uninitialized memory and iterator-backed buffers,
//! reusing one allocation per thread instead of allocating on every call.

mod allocation;

use allocation::Allocation;
use std::{
    cell::UnsafeCell,
    mem::{size_of, MaybeUninit},
    ptr::{self, NonNull},
    slice,
};

// One lazily-initialized stack per thread, shared by the free functions below.
thread_local!(
    static THREAD_LOCAL: Stack = Stack::new()
);

/// A reusable stack of short-lived allocations. Slices are handed out to
/// closures and reclaimed in LIFO order as those closures return.
pub struct Stack(UnsafeCell<Allocation>);

impl Drop for Stack {
    fn drop(&mut self) {
        // By the time the stack itself is dropped, no handed-out slices can
        // still be borrowed, so the backing allocation is freed outright.
        let stack = self.0.get_mut();
        unsafe {
            stack.force_dealloc();
        }
    }
}

impl Stack {
    pub fn new() -> Self {
        // Start with no backing memory; it is reserved lazily on first use.
        Self(UnsafeCell::new(Allocation::null()))
    }

    /// Places a single uninitialized `T` on the stack and runs `f` on it.
    pub fn uninit<T, R, F>(&self, f: F) -> R
    where
        F: FnOnce(&mut MaybeUninit<T>) -> R,
    {
        self.uninit_slice(1, |slice| f(&mut slice[0]))
    }

    /// Places a slice of `len` uninitialized `T`s on the stack and runs `f`
    /// on it. The space is reclaimed once `f` returns.
    pub fn uninit_slice<T, F, R>(&self, len: usize, f: F) -> R
    where
        F: FnOnce(&mut [MaybeUninit<T>]) -> R,
    {
        // Zero-sized types need no backing memory; a Vec's spare capacity
        // yields a correctly typed slice without allocating.
        if size_of::<T>() == 0 {
            let mut tmp = Vec::<T>::with_capacity(len);
            let slice = &mut tmp.spare_capacity_mut()[..len];
            return f(slice);
        }
        if len == 0 {
            return f(&mut []);
        }
        // The guard must stay alive until `f` returns so that the stack is
        // restored afterward, including on unwind.
        let (_restore, (ptr, len)) = unsafe {
            let stack = &mut *self.0.get();
            stack.get_slice(&self.0, len)
        };
        let slice = unsafe { slice::from_raw_parts_mut(ptr as *mut MaybeUninit<T>, len) };
        f(slice)
    }

    /// Buffers an iterator onto the stack and runs `f` on the resulting
    /// slice. Growth is amortized by doubling, like `Vec`, but the items
    /// live on this stack instead of the heap.
    pub fn buffer<T, F, R, I>(&self, i: I, f: F) -> R
    where
        I: Iterator<Item = T>,
        F: FnOnce(&mut [T]) -> R,
    {
        // Zero-sized types take no stack space; collect through a Vec so the
        // rest of the code never has to account for ZSTs.
        if size_of::<T>() == 0 {
            let mut v: Vec<_> = i.collect();
            return f(&mut v);
        }

        // Tracks a partially initialized buffer: `len` items starting at
        // `base` are initialized, `capacity` items are reserved, and
        // `restore` rewinds the stack when dropped.
        struct Writer<'a, T> {
            restore: Option<DropStack<'a>>,
            base: *mut T,
            len: usize,
            capacity: usize,
        }

        impl<T> Writer<'_, T> {
            // Safety: requires `self.len < self.capacity`.
            unsafe fn write(&mut self, item: T) {
                self.base.add(self.len).write(item);
                self.len += 1;
            }

            // Doubles the reservation in place. This is only possible when
            // the stack still refers to the same allocation we reserved from
            // and enough bytes remain after our reservation.
            fn try_reuse(&mut self, stack: &mut Allocation) -> bool {
                if let Some(prev) = &self.restore {
                    if prev.restore.ref_eq(stack) {
                        // Doubling the capacity takes `capacity` more items.
                        let required_bytes = size_of::<T>() * self.capacity;
                        if stack.remaining_bytes() >= required_bytes {
                            stack.len += required_bytes;
                            self.capacity *= 2;
                            return true;
                        }
                    }
                }
                false
            }
        }

        impl<T> Drop for Writer<'_, T> {
            fn drop(&mut self) {
                // The buffer owns its items: drop each initialized one, both
                // on the normal path after `f` returns and on unwind.
                unsafe {
                    for i in 0..self.len {
                        self.base.add(i).drop_in_place()
                    }
                }
            }
        }
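
        // Fill loop: reserve geometrically growing slices, moving the items
        // written so far on each growth, then hand the initialized prefix to
        // `f`. `Writer`'s `Drop` impl cleans up if the iterator or `f` panics.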
        unsafe {
            // `NonNull::dangling` (non-null and aligned, unlike `null_mut`)
            // keeps the zero-length `from_raw_parts_mut` below sound when
            // the iterator yields no items.
            let mut writer = Writer {
                restore: None,
                base: NonNull::dangling().as_ptr(),
                capacity: 0,
                len: 0,
            };
            for next in i {
                if writer.capacity == writer.len {
                    let stack = &mut *self.0.get();
                    // Extend in place when possible; otherwise reserve a
                    // doubled slice and move the items written so far.
                    if !writer.try_reuse(stack) {
                        let (restore, (base, capacity)) =
                            stack.get_slice(&self.0, (writer.len * 2).max(1));
                        if writer.len != 0 {
                            ptr::copy_nonoverlapping(writer.base, base, writer.len);
                        }
                        // Replacing the guard releases the superseded
                        // reservation now that the items have moved.
                        writer.restore = Some(restore);
                        writer.capacity = capacity;
                        writer.base = base;
                    }
                }
                writer.write(next);
            }
            let buffer = slice::from_raw_parts_mut(writer.base, writer.len);
            f(buffer)
        }
    }
}
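
/// Runs `f` with a slice of `len` uninitialized `T`s on this thread's stack.
///
/// A minimal usage sketch (the example assumes this crate is named
/// `second_stack`, which is why it is not compiled as a doctest):
///
/// ```ignore
/// use std::mem::MaybeUninit;
///
/// let sum = second_stack::uninit_slice(4, |scratch: &mut [MaybeUninit<u32>]| {
///     for (i, slot) in scratch.iter_mut().enumerate() {
///         slot.write(i as u32);
///     }
///     // Safety: every element was initialized in the loop above.
///     scratch.iter().map(|slot| unsafe { slot.assume_init_read() }).sum::<u32>()
/// });
/// assert_eq!(sum, 6);
/// ```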
pub fn uninit_slice<T, F, R>(len: usize, f: F) -> R
where
    F: FnOnce(&mut [MaybeUninit<T>]) -> R,
{
    THREAD_LOCAL.with(|stack| stack.uninit_slice(len, f))
}
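
/// Runs `f` with a single uninitialized `T` on this thread's stack.
///
/// A sketch of the expected shape (same crate-name assumption as above):
///
/// ```ignore
/// let value = second_stack::uninit(|slot: &mut std::mem::MaybeUninit<u64>| {
///     slot.write(42);
///     // Safety: `slot` was written on the line above.
///     unsafe { slot.assume_init_read() }
/// });
/// assert_eq!(value, 42);
/// ```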
pub fn uninit<T, F, R>(f: F) -> R
where
    F: FnOnce(&mut MaybeUninit<T>) -> R,
{
    THREAD_LOCAL.with(|stack| stack.uninit(f))
}
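
/// Buffers the iterator `i` onto this thread's stack and runs `f` on the
/// collected slice, dropping the items afterward.
///
/// A sketch of the expected shape (same crate-name assumption as above):
///
/// ```ignore
/// let total = second_stack::buffer(0u32..5, |items: &mut [u32]| {
///     items.iter().sum::<u32>()
/// });
/// assert_eq!(total, 10);
/// ```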
pub fn buffer<T, F, R, I>(i: I, f: F) -> R
where
    I: Iterator<Item = T>,
    F: FnOnce(&mut [T]) -> R,
{
    THREAD_LOCAL.with(|stack| stack.buffer(i, f))
}

/// A guard that records the stack state at reservation time. On drop it
/// either rewinds the live allocation or releases a superseded one.
pub(crate) struct DropStack<'a> {
    pub restore: Allocation,
    pub location: &'a UnsafeCell<Allocation>,
}

impl Drop for DropStack<'_> {
    fn drop(&mut self) {
        unsafe {
            let current = &mut *self.location.get();
            if current.ref_eq(&self.restore) {
                // Still the same allocation: pop back to the saved length.
                current.len = self.restore.len;
            } else {
                // The stack migrated to a fresh allocation while this guard
                // was live; release the superseded one instead.
                self.restore.try_dealloc();
            }
        }
    }
}
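
// Smoke tests sketching how the three entry points fit together. They rely
// only on the public functions defined above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn uninit_slice_round_trips() {
        let sum: u32 = uninit_slice(8, |slice| {
            for (i, slot) in slice.iter_mut().enumerate() {
                slot.write(i as u32);
            }
            // Safety: every element was initialized in the loop above.
            slice.iter().map(|slot| unsafe { slot.assume_init_read() }).sum()
        });
        assert_eq!(sum, 28);
    }

    #[test]
    fn buffer_collects_and_drops() {
        // `String` items exercise `Writer`'s growth and drop paths.
        let joined = buffer((0..10).map(|i| i.to_string()), |items| items.join(","));
        assert_eq!(joined, "0,1,2,3,4,5,6,7,8,9");
    }

    #[test]
    fn buffer_handles_an_empty_iterator() {
        // The empty case must not build a slice from an invalid pointer.
        let len = buffer(std::iter::empty::<u64>(), |items| items.len());
        assert_eq!(len, 0);
    }
}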