use core::ptr;
use super::*;
use metadata::{BlockInfo, FreeList, LiveSet};
mod mutex;
pub use mutex::{LockingMechanism, Mutex, MutexGuard, SingleThreadedLockingMechanism};
/// Largest alignment that gets a dedicated per-alignment free-list class.
const MAX_ALIGN_WITH_CLASS: usize = 4096;
/// One free-list class per power-of-two alignment in `1..=MAX_ALIGN_WITH_CLASS`
/// (log2(4096) + 1 classes); larger alignments share an overflow list.
const NUM_ALIGN_CLASSES: usize = MAX_ALIGN_WITH_CLASS.ilog2() as usize + 1;
/// A cached block may satisfy a smaller request as long as the wasted tail is
/// at most `requested_size / ACCEPTABLE_WASTE_DIVISOR` (see the `Layout`
/// search-key comparator below).
const ACCEPTABLE_WASTE_DIVISOR: usize = 8;
/// An allocator adapter that remembers blocks handed back while still
/// all-zero (via [`DeallocateZeroed::deallocate_zeroed`]) so that later
/// zeroed allocations can reuse them without touching the inner allocator.
#[derive(Default)]
pub struct ZeroAwareAllocator<A, L>
where
    A: Allocator,
    L: LockingMechanism,
{
    // The underlying allocator that actually provides memory.
    inner: A,
    // Bookkeeping for known-zeroed blocks, guarded by the chosen locking
    // mechanism `L`.
    zeroed: Mutex<Zeroed, L>,
}
/// State behind the mutex: free-lists of cached zeroed blocks plus the set of
/// cached blocks currently handed out to callers.
#[derive(Default)]
struct Zeroed {
    // One free-list per power-of-two alignment class up to
    // `MAX_ALIGN_WITH_CLASS`.
    align_classes: [FreeList; NUM_ALIGN_CLASSES],
    // Overflow free-list for blocks whose alignment exceeds every class above.
    very_large_aligns: FreeList,
    // Blocks that came out of the cache and are currently live (owned by a
    // caller), keyed by pointer so their true layout can be recovered.
    live_set: LiveSet,
}
impl<A, L> ZeroAwareAllocator<A, L>
where
A: Allocator,
L: LockingMechanism,
{
/// Creates a `ZeroAwareAllocator` wrapping `inner`, protecting its
/// zeroed-block bookkeeping with `lock`.
#[inline]
pub const fn new(inner: A, lock: L) -> Self {
    ZeroAwareAllocator {
        inner,
        zeroed: Mutex::new(
            Zeroed {
                align_classes: [const { FreeList::new() }; NUM_ALIGN_CLASSES],
                very_large_aligns: FreeList::new(),
                live_set: LiveSet::new(),
            },
            lock,
        ),
    }
}
/// Returns a shared reference to the wrapped inner allocator.
#[inline]
pub fn inner(&self) -> &A {
    &self.inner
}

/// Returns an exclusive reference to the wrapped inner allocator.
#[inline]
pub fn inner_mut(&mut self) -> &mut A {
    &mut self.inner
}
/// Drains every cached zeroed block, returning both the block and its
/// metadata node to the inner allocator.
pub fn return_zeroed_memory_to_inner(&mut self) {
    let mut guard = self.zeroed.lock();
    let state = &mut *guard;

    let all_lists = state
        .align_classes
        .iter_mut()
        .chain(Some(&mut state.very_large_aligns));
    for list in all_lists {
        while let Some(info) = list.pop_root() {
            let memory = info.ptr();
            let memory_layout = info.layout();
            // Free the metadata node first, then the block itself; both were
            // allocated from `inner`.
            unsafe {
                self.deallocate_block_info(info);
                self.inner.deallocate(memory, memory_layout);
            }
        }
    }
}
/// Allocates (from the inner allocator) a `BlockInfo` node describing the
/// zeroed block at `ptr` with `layout`.
///
/// The reference is returned as `'static` because the node lives until it is
/// explicitly freed with `deallocate_block_info`; callers must not use it
/// after that point.
fn allocate_block_info(
    &self,
    ptr: NonNull<u8>,
    layout: Layout,
) -> Result<&'static BlockInfo<'static>, AllocError> {
    let node_ptr = self.inner.allocate(Layout::new::<BlockInfo<'_>>())?;
    let node_ptr = node_ptr.cast::<BlockInfo<'_>>();
    // SAFETY: `node_ptr` is freshly allocated with the layout of a
    // `BlockInfo`, so it is properly aligned and large enough to initialize
    // and then borrow.
    unsafe {
        node_ptr.write(BlockInfo::new(ptr, layout));
        Ok(node_ptr.as_ref())
    }
}
/// Frees a `BlockInfo` node previously created by `allocate_block_info`.
///
/// # Safety
///
/// `node` must have come from `allocate_block_info` on this same allocator,
/// must already be unlinked from any tree (callers remove it first), and must
/// not be used again afterwards.
unsafe fn deallocate_block_info(&self, node: &'static BlockInfo<'static>) {
    self.inner
        .deallocate(NonNull::from(node).cast(), Layout::new::<BlockInfo<'_>>());
}
/// Tries to satisfy `layout` from the cached already-zeroed blocks.
///
/// Searches the free-list for the requested alignment class and every larger
/// class (a stricter alignment always satisfies a weaker request), plus the
/// overflow list. On success the block is moved into the live-set so that
/// `deallocate`/`grow`/`shrink` can recover its true layout later.
#[inline]
fn allocate_already_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
    debug_assert_ne!(layout.size(), 0);
    let align_class = layout.align().ilog2() as usize;
    {
        let mut zeroed = self.zeroed.lock();
        let zeroed = &mut *zeroed;
        let freelists = zeroed
            .align_classes
            .get_mut(align_class..NUM_ALIGN_CLASSES)
            .into_iter()
            .flat_map(|classes| classes.iter_mut())
            .chain(Some(&mut zeroed.very_large_aligns));
        for freelist in freelists {
            if let Some(node) = freelist.remove(&layout) {
                let ret = node.non_null_slice_ptr();
                debug_assert!(
                    ret.len() >= layout.size(),
                    "{ret:#p}'s size should be greater than or equal to user layout's size\n\
                     actual size = {}\n\
                     user layout's size = {}",
                    ret.len(),
                    layout.size()
                );
                debug_assert!(
                    ret.len() >= node.layout().size(),
                    "{ret:#p}'s size should be greater than or equal to its original layout's size\n\
                     actual size = {}\n\
                     original layout's size = {}",
                    ret.len(),
                    // Fix: report the original layout's size as the message
                    // claims, not the user layout's size.
                    node.layout().size()
                );
                debug_assert_eq!(
                    ret.cast::<u8>().as_ptr() as usize % layout.align(),
                    0,
                    "{ret:#p} should be aligned to user layout's alignment of {:#x}",
                    layout.align()
                );
                debug_assert_eq!(
                    ret.cast::<u8>().as_ptr() as usize % node.layout().align(),
                    0,
                    // Fix: this check is against the block's *original*
                    // layout, not the user's.
                    "{ret:#p} should be aligned to original layout's alignment of {:#x}",
                    node.layout().align()
                );
                // Cached blocks must still be entirely zero when handed out.
                debug_assert!({
                    let slice = unsafe {
                        core::slice::from_raw_parts(
                            node.ptr().as_ptr().cast_const(),
                            node.layout().size(),
                        )
                    };
                    slice.iter().all(|b| *b == 0)
                });
                zeroed.live_set.insert(node);
                return Ok(ret);
            }
        }
    }
    Err(AllocError)
}
/// Caches the all-zero block at `ptr` in the matching free-list rather than
/// returning it to the inner allocator.
///
/// # Safety
///
/// `ptr` must denote a currently-allocated block fitting `layout`, and all
/// `layout.size()` bytes must already be zero (debug-asserted below).
#[inline]
unsafe fn deallocate_already_zeroed(&self, ptr: NonNull<u8>, layout: Layout) {
    debug_assert_ne!(layout.size(), 0);
    debug_assert_eq!(ptr.as_ptr() as usize % layout.align(), 0);
    debug_assert!({
        let slice = core::slice::from_raw_parts(ptr.as_ptr().cast_const(), layout.size());
        slice.iter().all(|b| *b == 0)
    });
    let mut zeroed = self.zeroed.lock();
    let zeroed = &mut *zeroed;
    let node = match zeroed.live_set.remove(&ptr) {
        // The block originally came from the cache: reuse its metadata node,
        // which records the true (possibly larger) layout.
        Some(node) => {
            debug_assert!(
                node.layout().size() >= layout.size(),
                "actual size should be greater than or equal to user's size\n\
                 actual size = {}\n\
                 user's size = {}",
                node.layout().size(),
                layout.size(),
            );
            debug_assert_eq!(ptr.as_ptr() as usize % node.layout().align(), 0);
            // The tail beyond the user's layout must also still be zero.
            debug_assert!({
                let slice = core::slice::from_raw_parts(
                    ptr.add(layout.size()).as_ptr().cast_const(),
                    node.layout().size() - layout.size(),
                );
                slice.iter().all(|b| *b == 0)
            });
            node
        }
        // Unknown block: record it with the user's layout. If even the small
        // metadata allocation fails, hand the block back to the inner
        // allocator instead of caching it.
        None => match self.allocate_block_info(ptr, layout) {
            Ok(node) => node,
            Err(_) => {
                self.inner.deallocate(ptr, layout);
                return;
            }
        },
    };
    // File the block under its alignment class; alignments beyond the last
    // class fall through to the overflow list.
    let align_class = node.layout().align().ilog2() as usize;
    let freelist = zeroed
        .align_classes
        .get_mut(align_class)
        .unwrap_or_else(|| &mut zeroed.very_large_aligns);
    freelist.insert(node);
}
/// Shared prologue for `grow` and `grow_zeroed`: recovers the block's true
/// ("actual") old layout and decides what the caller must do next.
///
/// # Safety
///
/// Same contract as `Allocator::grow`: `ptr` must denote a live block fitting
/// `user_old_layout`, with `new_layout.size() >= user_old_layout.size()`.
///
/// NOTE(review): in the `DoGrow`/`DoShrink` cases the live-set entry and its
/// metadata node are released *before* the caller attempts the resize, so if
/// that attempt fails the block's true layout is forgotten and a later
/// `deallocate` will see only the user's layout — confirm this error path is
/// acceptable.
unsafe fn pre_grow(
    &self,
    ptr: NonNull<u8>,
    user_old_layout: Layout,
    new_layout: Layout,
) -> PreGrow {
    debug_assert_ne!(user_old_layout.size(), 0);
    debug_assert_eq!(
        ptr.as_ptr() as usize % user_old_layout.align(),
        0,
        "{ptr:#p} should be aligned to user's layout's alignment of {:#x}",
        user_old_layout.align()
    );
    let mut zeroed = self.zeroed.lock();
    let zeroed = &mut *zeroed;
    let actual_old_layout = if let Some(node) = zeroed.live_set.find(&ptr) {
        debug_assert_eq!(node.ptr(), ptr);
        debug_assert!(node.layout().size() >= user_old_layout.size());
        debug_assert_eq!(
            ptr.as_ptr() as usize % node.layout().align(),
            0,
            "{ptr:#p} should be aligned to original layout's alignment of {:#x}",
            node.layout().align()
        );
        // Already big and aligned enough: hand the block back unchanged; it
        // stays in the live-set.
        if node.layout().size() >= new_layout.size()
            && node.layout().align() >= new_layout.align()
        {
            return PreGrow::Reuse {
                ptr: NonNull::slice_from_raw_parts(ptr, node.layout().size()),
                actual_layout: node.layout(),
            };
        }
        // Otherwise the inner allocator must resize it; stop tracking it.
        let actual_old_layout = node.layout();
        zeroed
            .live_set
            .remove(node)
            .expect("just found the node in the live-set, should still be there");
        self.deallocate_block_info(node);
        actual_old_layout
    } else {
        // Not cache-tracked: the user's layout is the allocation layout.
        user_old_layout
    };
    // A "grow" from the user's perspective can be a shrink in reality when
    // the cached block was larger than the user ever knew.
    if new_layout.size() >= actual_old_layout.size() {
        PreGrow::DoGrow { actual_old_layout }
    } else {
        PreGrow::DoShrink { actual_old_layout }
    }
}
}
/// Outcome of `ZeroAwareAllocator::pre_grow`.
enum PreGrow {
    /// The cached block already satisfies the request; return it as-is.
    Reuse {
        // The full block, with its true (cached) length.
        ptr: NonNull<[u8]>,
        // The layout the block was originally allocated with.
        actual_layout: Layout,
    },
    /// The requested size is at least the block's true size: really grow.
    DoGrow { actual_old_layout: Layout },
    /// The requested size is below the block's true size: shrink instead.
    DoShrink { actual_old_layout: Layout },
}
impl<A, L> Drop for ZeroAwareAllocator<A, L>
where
    A: Allocator,
    L: LockingMechanism,
{
    // Return all cached zeroed blocks (and their metadata nodes) to the inner
    // allocator before the wrapper goes away, so nothing leaks out of the
    // free-lists.
    fn drop(&mut self) {
        self.return_zeroed_memory_to_inner();
    }
}
unsafe impl<A, L> Allocator for ZeroAwareAllocator<A, L>
where
A: Allocator,
L: LockingMechanism,
{
/// Allocates from the inner allocator, falling back to a cached
/// already-zeroed block when the inner allocator is out of memory.
#[inline]
fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
    // Zero-sized requests can never be served from the cache (cached blocks
    // always have non-zero size, and `allocate_already_zeroed` debug-asserts
    // a non-zero size), so forward them straight to the inner allocator —
    // consistent with the zero-size guards in the other trait methods.
    if layout.size() == 0 {
        return self.inner.allocate(layout);
    }
    self.inner
        .allocate(layout)
        .or_else(|_| self.allocate_already_zeroed(layout))
}
/// Deallocates `ptr`, consulting the live-set to recover the block's true
/// (possibly larger) layout when it was handed out from the zeroed cache.
#[inline]
unsafe fn deallocate(&self, ptr: NonNull<u8>, user_layout: Layout) {
    if user_layout.size() == 0 {
        // Zero-sized blocks are never cached; forward directly.
        self.inner.deallocate(ptr, user_layout);
    } else {
        let mut guard = self.zeroed.lock();
        let state = &mut *guard;
        let actual_layout = match state.live_set.remove(&ptr) {
            Some(info) => {
                let recorded = info.layout();
                self.deallocate_block_info(info);
                recorded
            }
            None => user_layout,
        };
        self.inner.deallocate(ptr, actual_layout);
    }
}
/// Allocates zeroed memory, preferring the cache of already-zeroed blocks
/// and only asking the inner allocator on a cache miss.
#[inline]
fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
    match layout.size() {
        // The cache never holds zero-sized blocks.
        0 => self.inner.allocate_zeroed(layout),
        _ => self
            .allocate_already_zeroed(layout)
            .or_else(|_| self.inner.allocate_zeroed(layout)),
    }
}
/// Grows `ptr`'s allocation to `new_layout`.
///
/// A cache-tracked block that is already big enough is reused as-is;
/// otherwise the inner allocator resizes it, falling back to a fresh
/// already-zeroed block plus a copy if that fails.
#[inline]
unsafe fn grow(
    &self,
    ptr: NonNull<u8>,
    user_old_layout: Layout,
    new_layout: Layout,
) -> Result<NonNull<[u8]>, AllocError> {
    if user_old_layout.size() == 0 {
        return self.inner.grow(ptr, user_old_layout, new_layout);
    }
    let actual_old_layout = match self.pre_grow(ptr, user_old_layout, new_layout) {
        PreGrow::Reuse {
            ptr,
            actual_layout: _,
        } => {
            // The cached block already satisfies the new size and alignment.
            return Ok(ptr);
        }
        PreGrow::DoShrink { actual_old_layout } => {
            // The block's true size already exceeds the request, so a "grow"
            // from the user's view is really a shrink.
            return self.inner.shrink(ptr, actual_old_layout, new_layout);
        }
        PreGrow::DoGrow { actual_old_layout } => actual_old_layout,
    };
    // NOTE(review): `pre_grow` has already dropped this block's live-set
    // entry, so if both attempts below fail, the block's true layout is
    // forgotten and a later `deallocate` would pass the user's layout to
    // `inner` — confirm whether this error path can matter in practice.
    match self.inner.grow(ptr, actual_old_layout, new_layout) {
        Ok(ptr) => Ok(ptr),
        Err(_) => {
            // Fall back to a cached zeroed block of the new size and copy the
            // user's bytes over.
            let new_ptr = self.allocate_already_zeroed(new_layout)?;
            debug_assert_ne!(new_ptr.cast::<u8>(), ptr);
            ptr::copy_nonoverlapping(
                ptr.as_ptr().cast_const(),
                new_ptr.cast::<u8>().as_ptr(),
                user_old_layout.size(),
            );
            self.inner.deallocate(ptr, actual_old_layout);
            Ok(new_ptr)
        }
    }
}
/// Grows `ptr`'s allocation to `new_layout`, guaranteeing that every byte
/// past the user's old size comes back zeroed.
#[inline]
unsafe fn grow_zeroed(
    &self,
    ptr: NonNull<u8>,
    user_old_layout: Layout,
    new_layout: Layout,
) -> Result<NonNull<[u8]>, AllocError> {
    if user_old_layout.size() == 0 {
        // Nothing to preserve; a fresh zeroed allocation is equivalent.
        return self.inner.allocate_zeroed(new_layout);
    }
    let actual_old_layout = match self.pre_grow(ptr, user_old_layout, new_layout) {
        PreGrow::Reuse { ptr, actual_layout } => {
            // Reusing the cached block in place: the user may have written
            // anywhere within its true size, so re-zero the tail past their
            // old size before handing it back.
            let slop = ptr.cast::<u8>().add(user_old_layout.size());
            let slop_len = actual_layout.size() - user_old_layout.size();
            slop.write_bytes(0, slop_len);
            return Ok(ptr);
        }
        PreGrow::DoShrink { actual_old_layout } => {
            // The block's true size already exceeds the request; shrink it,
            // then zero between the user's old size and the new size.
            let new_ptr = self.inner.shrink(ptr, actual_old_layout, new_layout)?;
            debug_assert!(new_layout.size() < actual_old_layout.size());
            debug_assert!(new_layout.size() >= user_old_layout.size());
            let to_zero = new_ptr.cast::<u8>().add(user_old_layout.size());
            let to_zero_len = new_layout.size() - user_old_layout.size();
            to_zero.write_bytes(0, to_zero_len);
            return Ok(new_ptr);
        }
        PreGrow::DoGrow { actual_old_layout } => actual_old_layout,
    };
    // Heuristic: only attempt growing in place when most of the new block is
    // preserved bytes rather than bytes that must be zeroed.
    let bytes_to_zero = new_layout.size().saturating_sub(actual_old_layout.size());
    let bytes_to_copy = new_layout.size() - bytes_to_zero;
    if bytes_to_copy > 2usize.saturating_mul(bytes_to_zero) {
        // Fix: `grow_zeroed` requires the layout the block was actually
        // allocated with — for cache-tracked blocks that is
        // `actual_old_layout` (cf. `grow` above and the deallocate below),
        // not the user's layout.
        if let Ok(p) = self.inner.grow_zeroed(ptr, actual_old_layout, new_layout) {
            // The inner allocator only zeroes bytes past
            // `actual_old_layout.size()`; the user may have dirtied the gap
            // between their old size and the block's true old size, so scrub
            // it too (mirrors the `Reuse` branch above). For untracked
            // blocks the two layouts coincide and the gap is empty.
            let gap = p.cast::<u8>().add(user_old_layout.size());
            let gap_len = actual_old_layout.size() - user_old_layout.size();
            gap.write_bytes(0, gap_len);
            return Ok(p);
        }
    }
    // Fall back to a fresh zeroed block: copy the user's old bytes; the
    // remainder is already zero.
    let new = self.allocate_zeroed(new_layout)?;
    ptr::copy_nonoverlapping(
        ptr.as_ptr().cast_const(),
        new.cast().as_ptr(),
        user_old_layout.size(),
    );
    self.inner.deallocate(ptr, actual_old_layout);
    Ok(new)
}
/// Shrinks `ptr`'s allocation to `new_layout`, substituting the block's
/// recorded layout when it was handed out from the zeroed cache.
#[inline]
unsafe fn shrink(
    &self,
    ptr: NonNull<u8>,
    old_layout: Layout,
    new_layout: Layout,
) -> Result<NonNull<[u8]>, AllocError> {
    if old_layout.size() == 0 {
        return self.inner.shrink(ptr, old_layout, new_layout);
    }
    let mut guard = self.zeroed.lock();
    let state = &mut *guard;
    let effective_old_layout = match state.live_set.remove(&ptr) {
        Some(info) => {
            let recorded = info.layout();
            self.deallocate_block_info(info);
            recorded
        }
        None => old_layout,
    };
    self.inner.shrink(ptr, effective_old_layout, new_layout)
}
}
impl<A, L> DeallocateZeroed for ZeroAwareAllocator<A, L>
where
    A: Allocator,
    L: LockingMechanism,
{
    /// Accepts a block the caller promises is entirely zeroed; non-empty
    /// blocks go into the zeroed cache instead of back to `inner`.
    unsafe fn deallocate_zeroed(&self, ptr: NonNull<u8>, layout: Layout) {
        match layout.size() {
            0 => self.inner.deallocate(ptr, layout),
            _ => self.deallocate_already_zeroed(ptr, layout),
        }
    }
}
mod metadata {
use super::*;
use core::cmp::Ordering;
use intrusive_splay_tree::{Node, SplayTree, TreeOrd};
/// Metadata describing one cached zeroed block. Each `BlockInfo` lives in an
/// allocation of its own and is linked intrusively (via `node`) into either a
/// free-list tree or the live-set tree.
#[derive(Debug)]
pub(super) struct BlockInfo<'a> {
    // Start of the zeroed block.
    ptr: NonNull<u8>,
    // The layout the block was allocated with.
    layout: Layout,
    // Intrusive tree node shared by the `ByAlignClass` and `ByPointer` trees,
    // so a block is in at most one tree at a time.
    node: Node<'a>,
}
impl<'a> BlockInfo<'a> {
    /// Records a zeroed block at `ptr` with `layout`.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of `layout.size()` bytes (the debug
    /// assertion reads them), and those bytes must all be zero.
    pub(super) unsafe fn new(ptr: NonNull<u8>, layout: Layout) -> Self {
        debug_assert!(
            core::slice::from_raw_parts(ptr.as_ptr(), layout.size())
                .iter()
                .all(|b| *b == 0),
            "supposedly already-zeroed block contains non-zero memory"
        );
        BlockInfo {
            ptr,
            layout,
            node: Node::default(),
        }
    }

    // The block's size in bytes, per its recorded layout.
    fn size(&self) -> usize {
        self.layout.size()
    }

    // The block's *address* alignment: the largest power of two dividing the
    // pointer value. Note this is derived from the address, not from
    // `layout.align()`.
    fn align(&self) -> usize {
        1 << (self.ptr.as_ptr() as usize).trailing_zeros()
    }

    // Start of the block.
    pub(super) fn ptr(&self) -> NonNull<u8> {
        self.ptr
    }

    // The whole block as a fat pointer spanning `layout.size()` bytes.
    pub(super) fn non_null_slice_ptr(&self) -> NonNull<[u8]> {
        NonNull::slice_from_raw_parts(self.ptr, self.size())
    }

    // The layout the block was allocated with.
    pub(super) fn layout(&self) -> Layout {
        self.layout
    }
}
/// Tree marker: orders blocks by (size, pointer alignment, address).
pub(super) struct ByAlignClass;
/// A free-list of cached zeroed blocks, kept as an intrusive splay tree.
pub(super) type FreeList = SplayTree<'static, ByAlignClass>;
impl<'a> TreeOrd<'a, ByAlignClass> for BlockInfo<'a> {
    /// Total order for the free-list trees: size first, then how aligned the
    /// block's address is, then the address itself as a final tiebreak.
    fn tree_cmp(&self, other: &'a BlockInfo<'a>) -> Ordering {
        let by_size = self.size().cmp(&other.size());
        let by_ptr_align = || {
            let self_tz = (self.ptr.as_ptr() as usize).trailing_zeros();
            let other_tz = (other.ptr.as_ptr() as usize).trailing_zeros();
            self_tz.cmp(&other_tz)
        };
        let by_address =
            || (self.ptr.as_ptr() as usize).cmp(&(other.ptr.as_ptr() as usize));
        by_size.then_with(by_ptr_align).then_with(by_address)
    }
}
/// Lets a `Layout` act as a fuzzy search key against the free-list trees:
/// `Equal` means "this block is acceptable for the request". A block matches
/// when it is at least as large (wasting at most
/// `size / ACCEPTABLE_WASTE_DIVISOR` bytes) and its address is at least as
/// aligned as requested — note `block.align()` is derived from the block's
/// address, not from its recorded layout.
impl<'a> TreeOrd<'a, ByAlignClass> for Layout {
    fn tree_cmp(&self, block: &'a BlockInfo<'a>) -> Ordering {
        // Size: an exact or slightly-larger block counts as a match; a much
        // larger block compares `Less` so the search keeps descending toward
        // a tighter fit.
        let by_size = match self.size().cmp(&block.size()) {
            Ordering::Greater => Ordering::Greater,
            Ordering::Equal => Ordering::Equal,
            Ordering::Less => {
                let acceptable_waste = self.size() / ACCEPTABLE_WASTE_DIVISOR;
                let potential_waste = block.size() - self.size();
                if potential_waste <= acceptable_waste {
                    Ordering::Equal
                } else {
                    Ordering::Less
                }
            }
        };
        // Alignment: any block at least as aligned as requested is
        // acceptable; otherwise steer toward more-aligned entries.
        by_size
            .then_with(|| match self.align().cmp(&block.align()) {
                Ordering::Less | Ordering::Equal => Ordering::Equal,
                Ordering::Greater => Ordering::Greater,
            })
    }
}
// Wire `BlockInfo::node` up as the intrusive node for the by-align-class
// free-list trees. The same field also backs `ByPointer` (below), so a block
// can be in at most one tree at a time.
intrusive_splay_tree::impl_intrusive_node! {
    impl<'a> IntrusiveNode<'a> for ByAlignClass
    where
        type Elem = BlockInfo<'a>,
        node = node;
}
/// Tree marker: orders blocks purely by their address.
pub(super) struct ByPointer;
/// The set of cache-originated blocks currently handed out to callers.
pub(super) type LiveSet = SplayTree<'static, ByPointer>;
impl<'a> TreeOrd<'a, ByPointer> for BlockInfo<'a> {
    /// Live-set order: purely by block address.
    fn tree_cmp(&self, other: &'a BlockInfo<'a>) -> Ordering {
        self.ptr.cmp(&other.ptr)
    }
}

impl<'a> TreeOrd<'a, ByPointer> for NonNull<u8> {
    /// Lets a bare pointer be used as a lookup key into the live-set.
    fn tree_cmp(&self, other: &'a BlockInfo<'a>) -> Ordering {
        self.cmp(&other.ptr)
    }
}
// Wire the same `BlockInfo::node` field up for the live-set tree; a block
// moves between a free-list and the live-set, never residing in both.
intrusive_splay_tree::impl_intrusive_node! {
    impl<'a> IntrusiveNode<'a> for ByPointer
    where
        type Elem = BlockInfo<'a>,
        node = node;
}
}