use core::ptr;
use super::*;
use metadata::{BlockInfo, FreeList};
mod mutex;
pub use mutex::{LockingMechanism, Mutex, MutexGuard, SingleThreadedLockingMechanism};
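// Tuning knobs for the zeroed-block cache: alignments up to
// `MAX_ALIGN_WITH_CLASS` each get a dedicated freelist (one per power of
// two), while larger alignments share a single `very_large_aligns` list. A
// cached block may satisfy a smaller request as long as the wasted space is
// at most `size / ACCEPTABLE_WASTE_DIVISOR`.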
const MAX_ALIGN_WITH_CLASS: usize = 4096;
const NUM_ALIGN_CLASSES: usize = MAX_ALIGN_WITH_CLASS.ilog2() as usize + 1;
const ACCEPTABLE_WASTE_DIVISOR: usize = 8;
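/// An allocator adapter that tracks blocks known to be all zeroes.
///
/// Blocks handed back through [`DeallocateZeroed::deallocate_zeroed`] are
/// cached in per-alignment freelists rather than freed, and later reused to
/// satisfy `allocate_zeroed` requests without re-zeroing the memory.
///
/// A minimal usage sketch; `MyInner` is a hypothetical stand-in for any
/// `Allocator` implementation, and the example assumes
/// `SingleThreadedLockingMechanism` can be constructed via `Default`:
///
/// ```ignore
/// let alloc = ZeroAwareAllocator::new(MyInner::default(), SingleThreadedLockingMechanism::default());
/// let layout = Layout::array::<u64>(128).unwrap();
/// // First request: served (and zeroed) by the inner allocator.
/// let block = alloc.allocate_zeroed(layout)?;
/// // The caller knows the block is still all zeroes, so cache it for reuse.
/// unsafe { alloc.deallocate_zeroed(block.cast(), layout) };
/// // Second request: served from the cache without re-zeroing.
/// let reused = alloc.allocate_zeroed(layout)?;
/// ```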
#[derive(Default)]
pub struct ZeroAwareAllocator<A, L>
where
A: Allocator,
L: LockingMechanism,
{
inner: A,
zeroed: Mutex<Zeroed, L>,
}
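/// The cache of known-zeroed blocks: one freelist per power-of-two
/// alignment class, plus a shared list for very large alignments.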
#[derive(Default)]
struct Zeroed {
align_classes: [FreeList; NUM_ALIGN_CLASSES],
very_large_aligns: FreeList,
}
impl<A, L> ZeroAwareAllocator<A, L>
where
A: Allocator,
L: LockingMechanism,
{
#[inline]
pub const fn new(inner: A, lock: L) -> Self {
let zeroed = Zeroed {
align_classes: [const { FreeList::new() }; NUM_ALIGN_CLASSES],
very_large_aligns: FreeList::new(),
};
let zeroed = Mutex::new(zeroed, lock);
ZeroAwareAllocator { inner, zeroed }
}
#[inline]
pub fn inner(&self) -> &A {
&self.inner
}
#[inline]
pub fn inner_mut(&mut self) -> &mut A {
&mut self.inner
}
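/// Drain the cache, returning every zeroed block (and its metadata node)
/// to the inner allocator.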
pub fn return_zeroed_memory_to_inner(&mut self) {
let mut zeroed = self.zeroed.lock();
let zeroed = &mut *zeroed;
for freelist in zeroed
.align_classes
.iter_mut()
.chain(Some(&mut zeroed.very_large_aligns))
{
while let Some(node) = freelist.pop_root() {
let block_ptr = node.ptr();
let block_layout = node.layout();
unsafe {
self.deallocate_block_info(node);
}
unsafe { self.inner.deallocate(block_ptr, block_layout) };
}
}
}
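/// Allocate a `BlockInfo` tree node, from the inner allocator, describing
/// the zeroed block at `ptr`.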
fn allocate_block_info(
&self,
ptr: NonNull<u8>,
layout: Layout,
) -> Result<&'static BlockInfo<'static>, AllocError> {
let node_ptr = self.inner.allocate(Layout::new::<BlockInfo<'_>>())?;
let node_ptr = node_ptr.cast::<BlockInfo<'_>>();
unsafe {
node_ptr.write(BlockInfo::new(ptr, layout));
Ok(node_ptr.as_ref())
}
}
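/// Free a `BlockInfo` node previously created by `allocate_block_info`.
///
/// # Safety
///
/// `node` must have been returned by `allocate_block_info` on this
/// allocator and must never be used again after this call.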
unsafe fn deallocate_block_info(&self, node: &'static BlockInfo<'static>) {
self.inner
.deallocate(NonNull::from(node).cast(), Layout::new::<BlockInfo<'_>>());
}
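/// Try to satisfy `layout` from the cache of already-zeroed blocks.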
#[inline]
fn allocate_already_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
if layout.size() == 0 {
// A zero-sized allocation must still respect `layout`'s alignment, so
// return a well-aligned dangling pointer rather than `&[]`, which is
// only guaranteed to be 1-aligned.
let dangling = NonNull::new(layout.align() as *mut u8).expect("alignments are never zero");
return Ok(NonNull::slice_from_raw_parts(dangling, 0));
}
// Clamp the class index: alignments above `MAX_ALIGN_WITH_CLASS` have no
// dedicated freelist, and indexing with their raw class would panic.
// Clamping to `NUM_ALIGN_CLASSES` yields an empty slice below, so only
// `very_large_aligns` is searched for them.
let align_class = (layout.align().ilog2() as usize).min(NUM_ALIGN_CLASSES);
{
let mut zeroed = self.zeroed.lock();
let zeroed = &mut *zeroed;
let freelists = zeroed.align_classes[align_class..]
.iter_mut()
.chain(Some(&mut zeroed.very_large_aligns));
for freelist in freelists {
if let Some(node) = freelist.remove(&layout) {
let ret = node.non_null_slice_ptr();
unsafe {
self.deallocate_block_info(node);
}
return Ok(ret);
}
}
}
Err(AllocError)
}
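/// Cache `ptr` as an already-zeroed block for later reuse, or return it to
/// the inner allocator if no metadata node can be allocated for it.
///
/// # Safety
///
/// Same contract as [`Allocator::deallocate`], plus the block must be
/// entirely zeroed.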
#[inline]
unsafe fn deallocate_already_zeroed(&self, ptr: NonNull<u8>, layout: Layout) {
if layout.size() == 0 {
return;
}
match self.allocate_block_info(ptr, layout) {
Err(_) => self.inner.deallocate(ptr, layout),
Ok(node) => {
let mut zeroed = self.zeroed.lock();
let zeroed = &mut *zeroed;
let align_class = layout.align().ilog2() as usize;
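// Alignments above `MAX_ALIGN_WITH_CLASS` have no dedicated class and
// are cached in the shared `very_large_aligns` freelist.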
let freelist = zeroed
.align_classes
.get_mut(align_class)
.unwrap_or_else(|| &mut zeroed.very_large_aligns);
freelist.insert(node);
}
}
}
}
impl<A, L> Drop for ZeroAwareAllocator<A, L>
where
A: Allocator,
L: LockingMechanism,
{
fn drop(&mut self) {
self.return_zeroed_memory_to_inner();
}
}
unsafe impl<A, L> Allocator for ZeroAwareAllocator<A, L>
where
A: Allocator,
L: LockingMechanism,
{
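// Plain `allocate` prefers the inner allocator and only falls back to the
// zeroed cache, while `allocate_zeroed` checks the cache first: that is
// the whole point of tracking zeroed blocks.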
#[inline]
fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
self.inner
.allocate(layout)
.or_else(|_| self.allocate_already_zeroed(layout))
}
#[inline]
unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
self.inner.deallocate(ptr, layout);
}
#[inline]
fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
self.allocate_already_zeroed(layout)
.or_else(|_| self.inner.allocate_zeroed(layout))
}
#[inline]
unsafe fn grow(
&self,
ptr: NonNull<u8>,
old_layout: Layout,
new_layout: Layout,
) -> Result<NonNull<[u8]>, AllocError> {
self.inner.grow(ptr, old_layout, new_layout).or_else(|_| {
let new = self.allocate_already_zeroed(new_layout)?;
ptr::copy_nonoverlapping(
ptr.as_ptr().cast_const(),
new.cast().as_ptr(),
old_layout.size(),
);
// On success, ownership of the old block transfers to the allocator,
// so it must be returned to the inner allocator.
self.inner.deallocate(ptr, old_layout);
Ok(new)
})
}
#[inline]
unsafe fn grow_zeroed(
&self,
ptr: NonNull<u8>,
old_layout: Layout,
new_layout: Layout,
) -> Result<NonNull<[u8]>, AllocError> {
// Heuristic: when the bytes to copy dwarf the bytes that still need
// zeroing, try the inner allocator's `grow_zeroed` first, since it may
// be able to grow in place and skip the copy entirely.
let bytes_to_zero = new_layout.size() - old_layout.size();
let bytes_to_copy = old_layout.size();
if bytes_to_copy > 2usize.saturating_mul(bytes_to_zero) {
if let Ok(p) = self.inner.grow_zeroed(ptr, old_layout, new_layout) {
return Ok(p);
}
}
// Otherwise, prefer a cached already-zeroed block and copy into it.
let new = self.allocate_zeroed(new_layout)?;
ptr::copy_nonoverlapping(
ptr.as_ptr().cast_const(),
new.cast().as_ptr(),
old_layout.size(),
);
// On success, ownership of the old block transfers to the allocator,
// so it must be returned to the inner allocator.
self.inner.deallocate(ptr, old_layout);
Ok(new)
}
#[inline]
unsafe fn shrink(
&self,
ptr: NonNull<u8>,
old_layout: Layout,
new_layout: Layout,
) -> Result<NonNull<[u8]>, AllocError> {
self.inner.shrink(ptr, old_layout, new_layout)
}
}
impl<A, L> DeallocateZeroed for ZeroAwareAllocator<A, L>
where
A: Allocator,
L: LockingMechanism,
{
unsafe fn deallocate_zeroed(&self, pointer: NonNull<u8>, layout: Layout) {
self.deallocate_already_zeroed(pointer, layout);
}
}
mod metadata {
use super::*;
use core::cmp::Ordering;
use intrusive_splay_tree::{Node, SplayTree, TreeOrd};
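/// Metadata for one cached zeroed block: its pointer, the layout it was
/// allocated with, and the intrusive node linking it into a freelist.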
#[derive(Debug)]
pub(super) struct BlockInfo<'a> {
ptr: NonNull<u8>,
layout: Layout,
node: Node<'a>,
}
impl<'a> BlockInfo<'a> {
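/// # Safety
///
/// `ptr` must be valid for reads of `layout.size()` bytes, and that
/// memory must be entirely zeroed (checked in debug builds).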
pub(super) unsafe fn new(ptr: NonNull<u8>, layout: Layout) -> Self {
debug_assert!(
core::slice::from_raw_parts(ptr.as_ptr(), layout.size())
.iter()
.all(|b| *b == 0),
"supposedly already-zeroed block contains non-zero memory"
);
BlockInfo {
ptr,
layout,
node: Node::default(),
}
}
fn size(&self) -> usize {
self.layout.size()
}
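// The largest power-of-two alignment the block's address actually
// satisfies; this can exceed the alignment its layout requested.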
fn align(&self) -> usize {
1 << (self.ptr.as_ptr() as usize).trailing_zeros()
}
pub(super) fn ptr(&self) -> NonNull<u8> {
self.ptr
}
pub(super) fn non_null_slice_ptr(&self) -> NonNull<[u8]> {
NonNull::slice_from_raw_parts(self.ptr, self.size())
}
pub(super) fn layout(&self) -> Layout {
self.layout
}
}
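// Order blocks by size, then by the actual alignment of their address,
// then by address, so that every block gets a distinct tree position.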
impl<'a> TreeOrd<'a, BlockInfo<'a>> for BlockInfo<'a> {
fn tree_cmp(&self, other: &'a BlockInfo<'a>) -> Ordering {
self.size()
.cmp(&other.size())
.then_with(|| {
let self_align = (self.ptr.as_ptr() as usize).trailing_zeros();
let other_align = (other.ptr.as_ptr() as usize).trailing_zeros();
self_align.cmp(&other_align)
})
.then_with(|| {
let self_addr = self.ptr.as_ptr() as usize;
let other_addr = other.ptr.as_ptr() as usize;
self_addr.cmp(&other_addr)
})
}
}
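// Using a `Layout` as the search key finds a "good enough" block: `Equal`
// means the block's size matches (or overshoots within the acceptable
// waste bound) and its address is at least as aligned as requested.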
impl<'a> TreeOrd<'a, BlockInfo<'a>> for Layout {
fn tree_cmp(&self, block: &'a BlockInfo<'a>) -> Ordering {
let by_size = match self.size().cmp(&block.size()) {
Ordering::Greater => Ordering::Greater,
Ordering::Equal => Ordering::Equal,
Ordering::Less => {
let acceptable_waste = self.size() / ACCEPTABLE_WASTE_DIVISOR;
let potential_waste = block.size() - self.size();
if potential_waste <= acceptable_waste {
Ordering::Equal
} else {
Ordering::Less
}
}
};
by_size
.then_with(|| match self.align().cmp(&block.align()) {
Ordering::Less | Ordering::Equal => Ordering::Equal,
Ordering::Greater => Ordering::Greater,
})
}
}
intrusive_splay_tree::impl_intrusive_node! {
impl<'a> IntrusiveNode<'a> for BlockInfo<'a>
where
type Elem = BlockInfo<'a>,
node = node;
}
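/// A freelist of cached zeroed blocks, kept as an intrusive splay tree so
/// that repeated requests for similar layouts stay cheap to look up.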
pub(super) type FreeList = SplayTree<'static, BlockInfo<'static>>;
}