use core::alloc::{GlobalAlloc, Layout};
/// An allocator that can tell whether a given allocation belongs to it.
///
/// Required of the first link (`A`) of an [`AllocChain`]. Before freeing or resizing a
/// pointer, the chain asks the first allocator whether it owns the pointer. If it does, the
/// request goes to that allocator; otherwise it is handed to the next link.
///
/// # Safety
///
/// Implementors must answer `claims` correctly: `true` for pointers this allocator handed
/// out and still owns, `false` for pointers that came from elsewhere in the chain. The
/// chain releases and resizes memory with whichever allocator this method selects, so a
/// wrong answer means memory is freed by an allocator that never allocated it.
pub unsafe trait ChainableAlloc {
    /// Returns `true` if `ptr` (currently allocated with `layout`) was allocated by `self`.
    fn claims(&self, ptr: *mut u8, layout: Layout) -> bool;
}
/// Two allocators tried in order: an owned first link, then a borrowed fallback.
///
/// Longer chains are built by nesting. `AllocChain::new(a, &b).chain(&c)` has type
/// `AllocChain<AllocChain<A, B>, C>`, so the `a`/`b` pair forms the first link and `c`
/// becomes the final fallback.
pub struct AllocChain<'a, A, B>(A, &'a B);

impl<'a, A, B> AllocChain<'a, A, B> {
    /// Creates a chain that tries `a` first and falls back to `b`.
    pub const fn new(a: A, b: &'a B) -> Self {
        AllocChain(a, b)
    }

    /// Appends `next` as a further fallback. This chain (consumed) becomes the first link
    /// of the returned one.
    pub const fn chain<T>(self, next: &T) -> AllocChain<'_, Self, T>
    where
        Self: Sized,
    {
        AllocChain(self, next)
    }
}
/// Global-allocator behaviour of a chain. Allocation tries the first allocator, then the
/// second. Frees and resizes are routed using [`ChainableAlloc::claims`].
unsafe impl<A: GlobalAlloc + ChainableAlloc, B: GlobalAlloc> GlobalAlloc for AllocChain<'_, A, B> {
    /// Allocates from the first allocator, falling back to the second if it returns null.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // SAFETY: the caller upholds `GlobalAlloc::alloc`'s contract for `layout`.
        let ptr_a = unsafe { self.0.alloc(layout) };
        if ptr_a.is_null() {
            // SAFETY: same `layout`, same caller-provided contract.
            unsafe { self.1.alloc(layout) }
        } else {
            ptr_a
        }
    }

    /// Frees `ptr` with whichever allocator claims it.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        if self.0.claims(ptr, layout) {
            // SAFETY: `ChainableAlloc` guarantees a claimed pointer was allocated by `self.0`.
            unsafe { self.0.dealloc(ptr, layout) };
        } else {
            // SAFETY: `ptr` came from this chain and is not `self.0`'s, so it is `self.1`'s.
            unsafe { self.1.dealloc(ptr, layout) };
        }
    }

    /// Resizes `ptr` to `new_size` bytes.
    ///
    /// If the first allocator owns `ptr` but cannot resize it, the contents are moved into
    /// a fresh block from the second allocator. Returns null, leaving `ptr` valid and
    /// untouched, only when both allocators fail.
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        if self.0.claims(ptr, layout) {
            // SAFETY: `ptr` belongs to `self.0`; the caller upholds `realloc`'s contract.
            let ptr_a = unsafe { self.0.realloc(ptr, layout, new_size) };
            if !ptr_a.is_null() {
                return ptr_a;
            }
            // A null result from `self.0` leaves `ptr` owned and unaltered, so it can still
            // be moved into `self.1`.
            // SAFETY: `realloc`'s contract guarantees `new_size` is non-zero and, rounded up
            // to `layout.align()`, does not overflow `isize`.
            let layout_b = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) };
            // SAFETY: `layout_b` is a valid layout with non-zero size.
            let ptr_b = unsafe { self.1.alloc(layout_b) };
            if !ptr_b.is_null() {
                // Copy only the bytes both blocks can hold. When shrinking, copying the whole
                // old size would write past the end of the new allocation.
                let copy_len = layout.size().min(new_size);
                // SAFETY: both blocks are valid for `copy_len` bytes and are distinct live
                // allocations, so they cannot overlap. `ptr` is then freed by its owner.
                unsafe {
                    ptr.copy_to_nonoverlapping(ptr_b, copy_len);
                    self.0.dealloc(ptr, layout);
                }
            }
            ptr_b
        } else {
            // SAFETY: `ptr` is not claimed by `self.0`, so it belongs to `self.1`.
            unsafe { self.1.realloc(ptr, layout, new_size) }
        }
    }
}
#[cfg(any(feature = "allocator-api", feature = "allocator-api2"))]
use {
crate::{AllocError, Allocator},
core::ptr::NonNull,
};
/// `Allocator` behaviour of a borrowed chain. Requests go to the first allocator, then the
/// second. Frees and resizes are routed using [`ChainableAlloc::claims`].
#[cfg(any(feature = "allocator-api", feature = "allocator-api2"))]
unsafe impl<A: ChainableAlloc, B> Allocator for &AllocChain<'_, A, B>
where
    for<'a> &'a A: Allocator,
    for<'a> &'a B: Allocator,
{
    /// Allocates from the first allocator, falling back to the second on failure.
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        (&self.0)
            .allocate(layout)
            .or_else(|_| self.1.allocate(layout))
    }

    /// Frees `ptr` with whichever allocator claims it.
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        if self.0.claims(ptr.as_ptr(), layout) {
            // SAFETY: a claimed pointer was allocated by `self.0`.
            unsafe { (&self.0).deallocate(ptr, layout) };
        } else {
            // SAFETY: `ptr` came from this chain and is not `self.0`'s, so it is `self.1`'s.
            unsafe { self.1.deallocate(ptr, layout) };
        }
    }

    /// Grows `ptr`. If the first allocator owns it but cannot grow it, the contents are
    /// moved into a fresh block from the second allocator. On error, `ptr` is left valid
    /// and untouched.
    unsafe fn grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        if self.0.claims(ptr.as_ptr(), old_layout) {
            // SAFETY: `ptr` belongs to `self.0`; the caller upholds `grow`'s contract.
            let res_a = unsafe { (&self.0).grow(ptr, old_layout, new_layout) };
            if res_a.is_ok() {
                return res_a;
            }
            let res_b = self.1.allocate(new_layout);
            if let Ok(ptr_b) = res_b {
                // SAFETY: when growing, the new block holds at least
                // `new_layout.size() >= old_layout.size()` bytes. It is a separate live
                // allocation, so it cannot overlap `ptr`. `ptr` is then freed by its owner.
                unsafe {
                    ptr.copy_to_nonoverlapping(ptr_b.cast(), old_layout.size());
                    (&self.0).deallocate(ptr, old_layout);
                }
            }
            res_b
        } else {
            // SAFETY: `ptr` is not claimed by `self.0`, so it belongs to `self.1`.
            unsafe { self.1.grow(ptr, old_layout, new_layout) }
        }
    }

    /// Grows like `grow`, then zeroes every byte past the old size.
    unsafe fn grow_zeroed(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        // SAFETY: the caller's contract is forwarded to `grow`. The returned block is at
        // least `new_layout.size() >= old_layout.size()` bytes long, so the zeroed range
        // `old_layout.size()..len` lies inside it.
        unsafe {
            let new_ptr = self.grow(ptr, old_layout, new_layout)?;
            let count = new_ptr.len() - old_layout.size();
            new_ptr
                .cast::<u8>()
                .add(old_layout.size())
                .write_bytes(0, count);
            Ok(new_ptr)
        }
    }

    /// Shrinks `ptr`. If the first allocator owns it but cannot shrink it, the surviving
    /// bytes are moved into a fresh block from the second allocator. On error, `ptr` is
    /// left valid and untouched.
    unsafe fn shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        if self.0.claims(ptr.as_ptr(), old_layout) {
            // SAFETY: `ptr` belongs to `self.0`; the caller upholds `shrink`'s contract.
            let res_a = unsafe { (&self.0).shrink(ptr, old_layout, new_layout) };
            if res_a.is_ok() {
                return res_a;
            }
            let res_b = self.1.allocate(new_layout);
            if let Ok(ptr_b) = res_b {
                // Only `new_layout.size()` bytes survive a shrink, and the new block may be
                // exactly that large. Copying `old_layout.size()` bytes would overrun it.
                // SAFETY: both blocks are valid for `new_layout.size() <= old_layout.size()`
                // bytes and are distinct live allocations. `ptr` is then freed by its owner.
                unsafe {
                    ptr.copy_to_nonoverlapping(ptr_b.cast(), new_layout.size());
                    (&self.0).deallocate(ptr, old_layout);
                }
            }
            res_b
        } else {
            // SAFETY: `ptr` is not claimed by `self.0`, so it belongs to `self.1`.
            unsafe { self.1.shrink(ptr, old_layout, new_layout) }
        }
    }

    fn by_ref(&self) -> &Self
    where
        Self: Sized,
    {
        self
    }
}