use core::alloc::Layout;
use core::cell::UnsafeCell;
#[cfg(feature = "unsize")]
use core::marker::Unsize;
use core::mem::MaybeUninit;
use core::ops::Range;
use core::ptr::{NonNull, Pointee};
use core::{mem, ptr};
use crate::base::{
ClonesafeStorage, ExactSizeStorage, FromLeakedStorage, LeaksafeStorage, MultiItemStorage,
Storage, StorageSafe,
};
use crate::error::{Result, StorageError};
use crate::handles::{Handle, OffsetMetaHandle};
use crate::utils;
/// Number of `S`-sized blocks needed to hold `size` bytes.
///
/// Rounds *up*: a trailing partial block still has to be reserved. Otherwise a value smaller
/// than one block (e.g. an `i32` in a `usize` heap) would occupy zero blocks and alias with
/// the next allocation.
///
/// # Panics
/// Divides by `size_of::<S>()`, so a zero-sized `S` panics.
fn blocks<S>(size: usize) -> usize {
    let block = mem::size_of::<S>();
    // Written as quotient + remainder flag instead of `(size + block - 1) / block` so it cannot
    // overflow for sizes close to `usize::MAX`.
    size / block + usize::from(size % block != 0)
}

/// Number of `S`-sized blocks needed to hold `capacity` values of `T`.
///
/// Uses the same rounding as [`blocks`], so block counts derived from an element count always
/// match those derived from the equivalent `Layout` size.
fn blocks_for<S, T>(capacity: usize) -> usize {
    blocks::<S>(mem::size_of::<T>() * capacity)
}
/// Marks every block in `range` as in use.
///
/// In debug builds, asserts that none of those blocks was already taken.
fn lock_range<const N: usize>(lock: &mut spin::MutexGuard<'_, [bool; N]>, range: Range<usize>) {
    for slot in &mut lock[range] {
        debug_assert!(!*slot);
        *slot = true;
    }
}
/// Marks every block in `range` as free again.
///
/// In debug builds, asserts that each of those blocks was actually taken.
fn unlock_range<const N: usize>(lock: &mut spin::MutexGuard<'_, [bool; N]>, range: Range<usize>) {
    for slot in &mut lock[range] {
        debug_assert!(*slot);
        *slot = false;
    }
}
/// Finds the first run of free blocks large enough for `size` bytes.
///
/// Returns the block range (it does not lock it). A request needing zero blocks yields the
/// empty range `0..0`.
///
/// # Errors
/// - `StorageError::InsufficientSpace` when the request needs more blocks than the heap has.
/// - `StorageError::NoSlots` when the heap is big enough but no contiguous free run exists.
fn find_open<S, const N: usize>(
    lock: &spin::MutexGuard<'_, [bool; N]>,
    size: usize,
) -> Result<Range<usize>> {
    let needed = blocks::<S>(size);
    if needed == 0 {
        return Ok(0..0);
    }
    if needed > N {
        return Err(StorageError::InsufficientSpace {
            expected: size,
            available: Some(mem::size_of::<S>() * N),
        });
    }

    // Length of the current run of free blocks ending at `idx`.
    let mut run = 0;
    for (idx, &taken) in lock.iter().enumerate() {
        run = if taken { 0 } else { run + 1 };
        if run >= needed {
            let first = idx + 1 - needed;
            return Ok(first..(idx + 1));
        }
    }
    Err(StorageError::NoSlots)
}
/// A fixed-capacity heap made of `N` blocks of type `S`.
///
/// Each allocation occupies a contiguous run of blocks inside `storage`. Which blocks are in
/// use is tracked by the `used` table. The storage traits are implemented for `&VirtHeap`, so
/// one heap (for example a `static`) can back many containers at the same time.
#[derive(Debug)]
pub struct VirtHeap<S, const N: usize> {
    // `used[i]` is `true` while block `i` belongs to a live allocation; the spin mutex
    // serializes all updates to this table.
    used: spin::Mutex<[bool; N]>,
    // Backing memory. Blocks are handed out through raw pointers and may be uninitialized.
    storage: UnsafeCell<[MaybeUninit<S>; N]>,
}
impl<S, const N: usize> VirtHeap<S, N>
where
    S: StorageSafe,
{
    /// Creates an empty heap with every block marked free.
    ///
    /// This is a `const fn`, so the heap can be placed directly in a `static`.
    pub const fn new() -> VirtHeap<S, N> {
        // SAFETY: an array of `MaybeUninit<S>` has no validity invariant, so an
        // "uninitialized" value of that array type is already a valid value.
        let blocks: [MaybeUninit<S>; N] =
            unsafe { MaybeUninit::<[MaybeUninit<S>; N]>::uninit().assume_init() };
        let used = spin::Mutex::new([false; N]);
        VirtHeap {
            used,
            storage: UnsafeCell::new(blocks),
        }
    }
}
impl<S, const N: usize> VirtHeap<S, N>
where
S: StorageSafe,
{
fn find_lock(&self, size: usize) -> Result<usize> {
let mut used = self.used.lock();
let open = find_open::<S, N>(&used, size)?;
let start = open.start;
lock_range(&mut used, open);
Ok(start)
}
fn grow_in_place<T>(
&self,
handle: OffsetMetaHandle<[T]>,
old_layout: Layout,
new_layout: Layout,
) -> bool {
let mut used = self.used.lock();
let old_blocks = blocks::<S>(old_layout.size());
let new_blocks = blocks::<S>(new_layout.size());
let after_old = (handle.offset() + old_blocks)..(handle.offset() + new_blocks);
let has_space = used[after_old.clone()].iter().all(|&i| !i);
if has_space {
lock_range(&mut used, after_old);
}
has_space
}
fn grow_move<T>(
&self,
handle: <&Self as Storage>::Handle<[T]>,
new_layout: Layout,
) -> Option<usize> {
let mut used = self.used.lock();
let old_range = handle.offset()..(handle.offset() + blocks_for::<S, T>(handle.metadata()));
if handle.metadata() != 0 {
unlock_range(&mut used, old_range.clone());
}
let new_range = match find_open::<S, N>(&used, new_layout.size()) {
Ok(open) => open,
Err(_) => {
if handle.metadata() != 0 {
lock_range(&mut used, old_range);
}
return None;
}
};
let new_start = new_range.start;
lock_range(&mut used, new_range);
unsafe { &mut *self.storage.get() }.copy_within(old_range, new_start);
Some(new_start)
}
}
/// Same as [`VirtHeap::new`]: a heap with every block free.
impl<S: StorageSafe, const N: usize> Default for VirtHeap<S, N> {
    fn default() -> Self {
        Self::new()
    }
}
unsafe impl<S, const N: usize> Storage for &VirtHeap<S, N>
where
S: StorageSafe,
{
type Handle<T: ?Sized> = OffsetMetaHandle<T>;
unsafe fn get<T: ?Sized>(&self, handle: Self::Handle<T>) -> NonNull<T> {
let slice_ptr = unsafe { ptr::addr_of_mut!((*self.storage.get())[handle.offset()]) };
let ptr = unsafe { NonNull::new_unchecked(slice_ptr).cast() };
NonNull::from_raw_parts(ptr, handle.metadata())
}
fn from_raw_parts<T: ?Sized + Pointee>(
handle: Self::Handle<()>,
meta: T::Metadata,
) -> Self::Handle<T> {
<Self::Handle<T>>::from_raw_parts(handle, meta)
}
fn cast<T: ?Sized + Pointee, U>(handle: Self::Handle<T>) -> Self::Handle<U> {
handle.cast()
}
fn cast_unsized<T: ?Sized + Pointee, U: ?Sized + Pointee<Metadata = T::Metadata>>(
handle: Self::Handle<T>,
) -> Self::Handle<U> {
handle.cast_unsized()
}
#[cfg(feature = "unsize")]
fn coerce<T: ?Sized + Pointee + Unsize<U>, U: ?Sized + Pointee>(
handle: Self::Handle<T>,
) -> Self::Handle<U> {
handle.coerce()
}
fn allocate_single<T: ?Sized + Pointee>(
&mut self,
meta: T::Metadata,
) -> Result<Self::Handle<T>> {
self.allocate(meta)
}
unsafe fn deallocate_single<T: ?Sized>(&mut self, handle: Self::Handle<T>) {
unsafe { self.deallocate(handle) }
}
unsafe fn try_grow<T>(
&mut self,
handle: Self::Handle<[T]>,
capacity: usize,
) -> Result<Self::Handle<[T]>> {
debug_assert!(capacity >= handle.metadata());
let old_layout = Layout::array::<T>(handle.metadata()).expect("Valid handle");
let new_layout = Layout::array::<T>(capacity).map_err(|_| StorageError::exceeds_max())?;
if self.grow_in_place(handle, old_layout, new_layout) {
Ok(OffsetMetaHandle::from_offset_meta(
handle.offset(),
capacity,
))
} else if let Some(new_start) = self.grow_move(handle, new_layout) {
Ok(OffsetMetaHandle::from_offset_meta(new_start, capacity))
} else {
Err(StorageError::InsufficientSpace {
expected: new_layout.size(),
available: None,
})
}
}
unsafe fn try_shrink<T>(
&mut self,
handle: Self::Handle<[T]>,
capacity: usize,
) -> Result<Self::Handle<[T]>> {
debug_assert!(capacity <= handle.metadata());
unlock_range(
&mut self.used.lock(),
(handle.offset() + capacity)..(handle.offset() + handle.metadata()),
);
Ok(OffsetMetaHandle::from_offset_meta(
handle.offset(),
capacity,
))
}
}
unsafe impl<S: StorageSafe, const N: usize> MultiItemStorage for &VirtHeap<S, N> {
    /// Allocates room for a `T` with metadata `meta` in the first free run of blocks.
    ///
    /// # Errors
    /// Fails if the layout can never fit in `[S; N]`, or if no free run is large enough.
    fn allocate<T: ?Sized + Pointee>(&mut self, meta: T::Metadata) -> Result<Self::Handle<T>> {
        let layout = utils::layout_of::<T>(meta);
        utils::validate_layout_for::<[S; N]>(layout)?;
        self.find_lock(layout.size())
            .map(|start| OffsetMetaHandle::from_offset_meta(start, meta))
    }

    /// Releases the blocks covered by `handle`'s allocation.
    unsafe fn deallocate<T: ?Sized + Pointee>(&mut self, handle: Self::Handle<T>) {
        // SAFETY: the caller guarantees `handle` is a live allocation from this heap, so the
        // pointer and its metadata describe a valid value layout.
        let layout = unsafe { Layout::for_value_raw(self.get(handle).as_ptr()) };
        let first = handle.offset();
        let span = first..(first + blocks::<S>(layout.size()));
        unlock_range(&mut self.used.lock(), span);
    }
}
impl<S, const N: usize> ExactSizeStorage for &VirtHeap<S, N>
where
    S: StorageSafe,
{
    /// Whether a `T` with metadata `meta` could ever be allocated in this heap.
    ///
    /// Applies the same size/alignment check `allocate` performs against the whole `[S; N]`
    /// backing array. The previous check compared against a single block only, so it rejected
    /// anything larger than `size_of::<S>()` and accepted over-aligned types that `allocate`
    /// refuses.
    fn will_fit<T: ?Sized + Pointee>(&self, meta: T::Metadata) -> bool {
        let layout = utils::layout_of::<T>(meta);
        utils::validate_layout_for::<[S; N]>(layout).is_ok()
    }

    /// Maximum number of `T`s whose bytes fit in the heap.
    ///
    /// Zero-sized `T` takes no space, so the result is `usize::MAX` instead of a division by
    /// zero.
    fn max_range<T>(&self) -> usize {
        let layout = Layout::new::<T>();
        (mem::size_of::<S>() * N)
            .checked_div(layout.size())
            .unwrap_or(usize::MAX)
    }
}
// SAFETY (NOTE(review): confirm against the trait contracts): `&VirtHeap` is `Copy`, and every
// copy refers to the same `used` table and `storage` array. A handle therefore stays valid
// whichever copy of the reference uses it, and memory that is never deallocated stays reserved
// in the heap.
unsafe impl<S: StorageSafe, const N: usize> ClonesafeStorage for &VirtHeap<S, N> {}
unsafe impl<S: StorageSafe, const N: usize> LeaksafeStorage for &VirtHeap<S, N> {}
unsafe impl<S: StorageSafe, const N: usize> FromLeakedStorage for &VirtHeap<S, N> {
    /// Rebuilds a handle from a pointer previously produced by this heap.
    ///
    /// Panics if `leaked` lies before the start of the heap, because the block offset would be
    /// negative.
    unsafe fn unleak_ptr<T: ?Sized>(&self, leaked: *mut T) -> Self::Handle<T> {
        let base = self.storage.get() as *const S;
        // SAFETY: the caller guarantees `leaked` came from this heap, so it is in the same
        // allocation as `base` and sits on a block boundary.
        let distance = unsafe { leaked.cast::<S>().offset_from(base) };
        let offset: usize = distance.try_into().unwrap();
        OffsetMetaHandle::from_offset_meta(offset, ptr::metadata(leaked))
    }
}
// SAFETY: the heap owns its blocks by value, and its only other state is the `spin::Mutex`
// block table, so moving it to another thread is sound when `S: Send`.
unsafe impl<S, const N: usize> Send for VirtHeap<S, N> where S: Send + StorageSafe {}
// SAFETY: every change to the `used` table goes through the spin mutex, so concurrent
// allocations through `&VirtHeap` receive disjoint block ranges. Block memory is only reached
// through the raw pointer from `UnsafeCell::get`.
unsafe impl<S, const N: usize> Sync for VirtHeap<S, N> where S: Sync + StorageSafe {}
#[cfg(test)]
mod tests {
    use crate::boxed::Box;
    use crate::collections::Vec;

    use super::*;

    /// A sized boxed array can be coerced to a boxed slice.
    #[test]
    fn test_box() {
        static HEAP: VirtHeap<usize, 4> = VirtHeap::new();
        let array_box = Box::new_in([1, 2], &HEAP);
        let slice_box = array_box.coerce::<[i32]>();
        assert_eq!(&*slice_box, &[1, 2]);
    }

    /// Several live boxes in one heap keep their own contents.
    #[test]
    fn test_multi_box() {
        static HEAP: VirtHeap<usize, 16> = VirtHeap::new();
        let boxes = [
            Box::new_in([1, 2], &HEAP),
            Box::new_in([3, 4], &HEAP),
            Box::new_in([5, 6], &HEAP),
            Box::new_in([7, 8], &HEAP),
        ];
        let expected = [[1, 2], [3, 4], [5, 6], [7, 8]];
        for (boxed, want) in boxes.iter().zip(expected) {
            assert_eq!(**boxed, want);
        }
    }

    #[test]
    fn test_vec() {
        static HEAP: VirtHeap<usize, 16> = VirtHeap::new();
        let mut values = Vec::new_in(&HEAP);
        values.push(1);
        values.push(2);
        assert_eq!(&*values, &[1, 2]);
    }

    /// Growing one vector past its neighbours must not disturb them.
    #[test]
    fn test_multi_vec() {
        static HEAP: VirtHeap<usize, 16> = VirtHeap::new();
        let mut first = Vec::new_in(&HEAP);
        let mut second = Vec::new_in(&HEAP);
        let mut third = Vec::new_in(&HEAP);
        let mut fourth = Vec::new_in(&HEAP);
        first.extend([1, 2]);
        second.extend([3, 4]);
        third.extend([5, 6]);
        fourth.extend([7, 8]);
        first.extend([9, 10, 11, 12, 13, 14, 15, 16]);
        assert_eq!(&*first, &[1, 2, 9, 10, 11, 12, 13, 14, 15, 16]);
        assert_eq!(&*second, &[3, 4]);
        assert_eq!(&*third, &[5, 6]);
        assert_eq!(&*fourth, &[7, 8]);
    }

    /// A value exactly the size of the heap fits; a larger one does not.
    #[test]
    fn test_size() {
        static HEAP: VirtHeap<u8, 4> = VirtHeap::new();
        type Box<T> = crate::boxed::Box<T, &'static VirtHeap<u8, 4>>;
        Box::<[u8; 4]>::try_new_in([1, 2, 3, 4], &HEAP).unwrap();
        Box::<[u8; 8]>::try_new_in([1, 2, 3, 4, 5, 6, 7, 8], &HEAP).unwrap_err();
    }

    /// Values may be allocated only when their alignment does not exceed the block type's.
    #[test]
    fn test_align() {
        static FOO1: VirtHeap<u8, 4> = VirtHeap::new();
        static FOO2: VirtHeap<u16, 4> = VirtHeap::new();
        static FOO4: VirtHeap<u32, 4> = VirtHeap::new();
        static FOO8: VirtHeap<u64, 4> = VirtHeap::new();
        type Box<T, S> = crate::boxed::Box<T, &'static VirtHeap<S, 4>>;

        #[derive(Debug)]
        #[repr(align(1))]
        struct Align1;
        #[derive(Debug)]
        #[repr(align(2))]
        struct Align2;
        #[derive(Debug)]
        #[repr(align(4))]
        struct Align4;
        #[derive(Debug)]
        #[repr(align(8))]
        struct Align8;

        // 1-byte blocks: only align(1) fits.
        Box::<_, u8>::try_new_in(Align1, &FOO1).unwrap();
        Box::<_, u8>::try_new_in(Align2, &FOO1).unwrap_err();
        Box::<_, u8>::try_new_in(Align4, &FOO1).unwrap_err();
        Box::<_, u8>::try_new_in(Align8, &FOO1).unwrap_err();

        // 2-byte blocks: up to align(2).
        Box::<_, u16>::try_new_in(Align1, &FOO2).unwrap();
        Box::<_, u16>::try_new_in(Align2, &FOO2).unwrap();
        Box::<_, u16>::try_new_in(Align4, &FOO2).unwrap_err();
        Box::<_, u16>::try_new_in(Align8, &FOO2).unwrap_err();

        // 4-byte blocks: up to align(4).
        Box::<_, u32>::try_new_in(Align1, &FOO4).unwrap();
        Box::<_, u32>::try_new_in(Align2, &FOO4).unwrap();
        Box::<_, u32>::try_new_in(Align4, &FOO4).unwrap();
        Box::<_, u32>::try_new_in(Align8, &FOO4).unwrap_err();

        // 8-byte blocks: everything up to align(8).
        Box::<_, u64>::try_new_in(Align1, &FOO8).unwrap();
        Box::<_, u64>::try_new_in(Align2, &FOO8).unwrap();
        Box::<_, u64>::try_new_in(Align4, &FOO8).unwrap();
        Box::<_, u64>::try_new_in(Align8, &FOO8).unwrap();
    }

    /// A leaked box can be mutated through the leaked reference and then reclaimed.
    #[test]
    fn test_leak() {
        static HEAP: VirtHeap<usize, 16> = VirtHeap::new();
        let original = Box::new_in(1, &HEAP);
        let leaked = Box::leak(original);
        assert_eq!(*leaked, 1);
        *leaked = -1;
        assert_eq!(*leaked, -1);
        let reclaimed = unsafe { Box::from_raw_in(leaked, &HEAP) };
        assert_eq!(*reclaimed, -1);
    }

    /// A heap on the stack works too, as long as it outlives its boxes.
    #[test]
    fn test_non_static() {
        let heap: VirtHeap<u32, 4> = VirtHeap::new();
        let _ = Box::new_in(1, &heap);
    }
}