use core::alloc::Layout;
use core::cmp::{max, min};
use core::mem::size_of;
use core::ptr::{self, NonNull};
use core::result::Result;
use crate::math::log2;
const MIN_HEAP_ALIGN: usize = 4096;
/// Reasons a requested `(size, align)` pair cannot be converted into a
/// valid power-of-two block size for this allocator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationSizeError {
    /// The requested alignment is not a power of two, or it exceeds
    /// `MIN_HEAP_ALIGN` (the alignment guaranteed for the heap base).
    BadAlignment,
    /// The rounded-up block size is larger than the entire heap.
    TooLarge,
}
/// Errors that `Heap::allocate` can return.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationError {
    /// The request was valid, but no free block large enough is
    /// currently available.
    HeapExhausted,
    /// The request itself was invalid; the inner error says why.
    InvalidSize(AllocationSizeError),
}
/// Validation failures reported by `Heap::new` when checking the
/// supplied memory region.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HeapError {
    /// The heap base pointer is not aligned to `MIN_HEAP_ALIGN`.
    BadBaseAlignment,
    /// The heap size is not a power of two.
    BadSizeAlignment,
    /// The heap size is too small for the configured block orders.
    BadHeapSize,
    /// The smallest block order cannot hold a free-list node.
    MinBlockTooSmall,
}
/// Intrusive node of a singly linked free list.
///
/// A `FreeBlock` is written directly into the leading bytes of each free
/// block of heap memory, so the free blocks themselves form the list and
/// no separate bookkeeping storage is required.
struct FreeBlock {
    /// Next free block of the same size class, or null at the list's end.
    next: *mut FreeBlock,
}

impl FreeBlock {
    /// Builds a list node whose successor is `next`.
    const fn new(next: *mut FreeBlock) -> FreeBlock {
        Self { next }
    }
}
/// A fixed-capacity binary-buddy allocator over an externally provided
/// memory region.
///
/// `N` is the number of block-size classes ("orders"): order 0 blocks
/// are `min_block_size` bytes, each higher order doubles the size, and
/// order `N - 1` covers the whole heap.
#[derive(Debug)]
pub struct Heap<const N: usize> {
// Start of the managed region; expected to be `MIN_HEAP_ALIGN`-aligned.
heap_base: *mut u8,
// Total size of the region in bytes; expected to be a power of two.
heap_size: usize,
// One intrusive free list per order; `free_lists[i]` holds free blocks
// of size `min_block_size << i`.
free_lists: [*mut FreeBlock; N],
// Smallest allocatable block size: `heap_size >> (N - 1)`.
min_block_size: usize,
// Cached `log2(min_block_size)`, used to convert sizes to orders.
min_block_size_log2: u8,
}
// SAFETY: a `Heap` owns its raw pointers exclusively — they are only
// dereferenced through `&mut self` methods — so moving the whole
// allocator to another thread is sound. NOTE(review): this assumes no
// other code aliases the underlying memory region; confirm at the
// construction sites.
unsafe impl<const N: usize> Send for Heap<N> {}
impl<const N: usize> Heap<N> {
/// Creates a buddy allocator over the `heap_size` bytes starting at
/// `heap_base`, after validating the region.
///
/// Checks performed:
/// - `heap_base` must be aligned to `MIN_HEAP_ALIGN`;
/// - the order-0 block (`heap_size >> (N - 1)`) must be able to hold a
///   `FreeBlock` list node;
/// - `heap_size` must be a power of two.
///
/// # Safety
/// The caller must guarantee that `heap_base..heap_base + heap_size` is
/// valid memory owned exclusively by this allocator for its lifetime.
pub unsafe fn new(heap_base: NonNull<u8>, heap_size: usize) -> Result<Self, HeapError> {
// NOTE(review): if `N - 1` is >= the bit width of usize this shift
// overflows (panic in debug builds) — confirm N is bounded by callers.
let min_block_size = heap_size >> (N - 1);
if heap_base.as_ptr() as usize & (MIN_HEAP_ALIGN - 1) != 0 {
return Err(HeapError::BadBaseAlignment);
}
// NOTE(review): `min_block_size` is derived from `heap_size` by a right
// shift, so this condition can never be true as written — the intended
// invariant behind `BadHeapSize` should be confirmed.
if heap_size < min_block_size {
return Err(HeapError::BadHeapSize);
}
// The free list stores a `FreeBlock` node inside each free block, so
// the smallest block must be at least that large.
if min_block_size < size_of::<FreeBlock>() {
return Err(HeapError::MinBlockTooSmall);
}
if !heap_size.is_power_of_two() {
return Err(HeapError::BadSizeAlignment);
}
Ok(Self::new_unchecked(heap_base.as_ptr(), heap_size))
}
/// Like [`Heap::new`] but performs no validation.
///
/// Initially the whole region is one free block of the top order,
/// recorded by pointing `free_lists[N - 1]` at the heap base. No
/// `FreeBlock` node is written into the (possibly uninitialized) memory
/// itself; `free_list_pop` special-cases the top order for this reason.
///
/// # Safety
/// In addition to the memory requirements of [`Heap::new`], the caller
/// must ensure every invariant that `new` checks actually holds.
pub const unsafe fn new_unchecked(heap_base: *mut u8, heap_size: usize) -> Self {
let min_block_size = heap_size >> (N - 1);
let mut free_lists: [*mut FreeBlock; N] = [core::ptr::null_mut(); N];
free_lists[N - 1] = heap_base as *mut FreeBlock;
Self {
heap_base: heap_base,
heap_size,
free_lists,
min_block_size,
min_block_size_log2: log2(min_block_size),
}
}
/// Converts a request into the actual block size that will be used:
/// at least `align`, at least `min_block_size`, rounded up to the next
/// power of two.
///
/// # Errors
/// `BadAlignment` if `align` is not a power of two or exceeds
/// `MIN_HEAP_ALIGN`; `TooLarge` if the rounded size exceeds the heap.
fn allocation_size(&self, mut size: usize, align: usize) -> Result<usize, AllocationSizeError> {
if !align.is_power_of_two() {
return Err(AllocationSizeError::BadAlignment);
}
// Alignments beyond the heap base's own alignment cannot be honored.
if align > MIN_HEAP_ALIGN {
return Err(AllocationSizeError::BadAlignment);
}
// A power-of-two block is naturally aligned to its own size, so
// growing the block to `align` satisfies the alignment request.
if align > size {
size = align;
}
size = max(size, self.min_block_size);
size = size.next_power_of_two();
if size > self.heap_size {
return Err(AllocationSizeError::TooLarge);
}
Ok(size)
}
/// Maps a request to its order: 0 for `min_block_size`, `N - 1` for the
/// whole heap.
fn allocation_order(&self, size: usize, align: usize) -> Result<usize, AllocationSizeError> {
self.allocation_size(size, align)
.map(|s| (log2(s) - self.min_block_size_log2) as usize)
}
/// Block size in bytes for a given order: `min_block_size << order`.
const fn order_size(&self, order: usize) -> usize {
1 << (self.min_block_size_log2 as usize + order)
}
/// Removes and returns the head of the free list for `order`, if any.
fn free_list_pop(&mut self, order: usize) -> Option<*mut u8> {
let candidate = self.free_lists[order];
if !candidate.is_null() {
if order != self.free_lists.len() - 1 {
self.free_lists[order] = unsafe { (*candidate).next };
} else {
// Top order: the initial heap block was never initialized as a
// `FreeBlock` node (see `new_unchecked`), so its `next` field must
// not be read. Only one top-order block can exist (it spans the
// whole heap), so the list simply becomes empty.
self.free_lists[order] = ptr::null_mut();
}
Some(candidate as *mut u8)
} else {
None
}
}
/// Pushes `block` onto the front of the free list for `order`, writing
/// a `FreeBlock` node into the block's leading bytes.
unsafe fn free_list_insert(&mut self, order: usize, block: *mut u8) {
let free_block_ptr = block as *mut FreeBlock;
*free_block_ptr = FreeBlock::new(self.free_lists[order]);
self.free_lists[order] = free_block_ptr;
}
/// Unlinks `block` from the free list for `order` by a linear scan.
/// Returns `true` if the block was found and removed.
fn free_list_remove(&mut self, order: usize, block: *mut u8) -> bool {
let block_ptr = block as *mut FreeBlock;
// `checking` always points at the link (head or some node's `next`)
// that may need to be rewritten to skip the removed node.
let mut checking: &mut *mut FreeBlock = &mut self.free_lists[order];
while !(*checking).is_null() {
if *checking == block_ptr {
*checking = unsafe { (*(*checking)).next };
return true;
}
checking = unsafe { &mut ((*(*checking)).next) };
}
false
}
/// Repeatedly halves `block` from `order` down to `order_needed`,
/// returning the upper half of each split to the corresponding free
/// list. The caller keeps the (lower-half) block itself.
unsafe fn split_free_block(&mut self, block: *mut u8, mut order: usize, order_needed: usize) {
let mut size_to_split = self.order_size(order);
while order > order_needed {
size_to_split >>= 1;
order -= 1;
let split = block.add(size_to_split);
self.free_list_insert(order, split);
}
}
/// Computes the buddy of `block` at `order` by XOR-ing the block's
/// offset from the heap base with the block size. Returns `None` for a
/// block that spans the whole heap (it has no buddy).
fn buddy(&self, order: usize, block: *mut u8) -> Option<*mut u8> {
assert!(block >= self.heap_base);
let relative = unsafe { block.offset_from(self.heap_base) } as usize;
let size = self.order_size(order);
if size >= self.heap_size {
None
} else {
Some(unsafe { self.heap_base.add(relative ^ size) })
}
}
/// Allocates a block satisfying `layout`.
///
/// Searches the free lists from the smallest sufficient order upward;
/// a larger block found is split down to the required size.
///
/// # Errors
/// `InvalidSize` if the layout cannot be served at all, `HeapExhausted`
/// if no free block is currently large enough.
pub fn allocate(&mut self, layout: Layout) -> Result<*mut u8, AllocationError> {
match self.allocation_order(layout.size(), layout.align()) {
Ok(order_needed) => {
for order in order_needed..self.free_lists.len() {
if let Some(block) = self.free_list_pop(order) {
if order > order_needed {
unsafe { self.split_free_block(block, order, order_needed) };
}
return Ok(block);
}
}
Err(AllocationError::HeapExhausted)
}
Err(e) => Err(AllocationError::InvalidSize(e)),
}
}
/// Returns `ptr` to the heap, coalescing the freed block with its buddy
/// at each order for as long as the buddy is also free.
///
/// # Safety
/// `ptr` must have been returned by [`Heap::allocate`] on this heap
/// with the same `layout`, and must not be used after this call.
///
/// # Panics
/// Panics if `layout` does not describe a valid allocation for this heap.
pub unsafe fn deallocate(&mut self, ptr: *mut u8, layout: Layout) {
let initial_order = self
.allocation_order(layout.size(), layout.align())
.expect("Tried to dispose of invalid block");
let mut block = ptr;
for order in initial_order..self.free_lists.len() {
if let Some(buddy) = self.buddy(order, block) {
// If the buddy is free, merge: the combined block starts at the
// lower of the two addresses, and we retry at the next order up.
if self.free_list_remove(order, buddy) {
block = min(block, buddy);
continue;
}
}
// Buddy is absent (top order) or still allocated: stop merging and
// record the block at the current order.
self.free_list_insert(order, block);
return;
}
}
}
#[cfg(test)]
mod test {
extern crate std;
use super::*;
// Exercises request rounding (`allocation_size`) and the size-to-order
// mapping (`allocation_order`) on a 256-byte, 5-order heap
// (min block = 16 bytes).
#[test]
fn test_allocation_size_and_order() {
unsafe {
let heap_size = 256;
let layout = std::alloc::Layout::from_size_align(heap_size, 4096).unwrap();
let mem = std::alloc::alloc(layout);
let heap: Heap<5> = Heap::new(NonNull::new(mem).unwrap(), heap_size).unwrap();
// Alignment above MIN_HEAP_ALIGN is rejected.
assert_eq!(
Err(AllocationSizeError::BadAlignment),
heap.allocation_size(256, 8192)
);
// A request grown (by alignment) beyond the heap is rejected.
assert_eq!(
Err(AllocationSizeError::TooLarge),
heap.allocation_size(256, 256 * 2)
);
// Sizes round up to at least the minimum block, then to powers of two.
assert_eq!(Ok(16), heap.allocation_size(0, 1));
assert_eq!(Ok(16), heap.allocation_size(1, 1));
assert_eq!(Ok(16), heap.allocation_size(16, 1));
assert_eq!(Ok(32), heap.allocation_size(17, 1));
assert_eq!(Ok(32), heap.allocation_size(32, 32));
assert_eq!(Ok(256), heap.allocation_size(256, 256));
// Alignment larger than the size dictates the block size.
assert_eq!(Ok(64), heap.allocation_size(16, 64));
// Orders run 0 (16 bytes) through 4 (whole 256-byte heap).
assert_eq!(Ok(0), heap.allocation_order(0, 1));
assert_eq!(Ok(0), heap.allocation_order(1, 1));
assert_eq!(Ok(0), heap.allocation_order(16, 16));
assert_eq!(Ok(1), heap.allocation_order(32, 32));
assert_eq!(Ok(2), heap.allocation_order(64, 64));
assert_eq!(Ok(3), heap.allocation_order(128, 128));
assert_eq!(Ok(4), heap.allocation_order(256, 256));
assert_eq!(
Err(AllocationSizeError::TooLarge),
heap.allocation_order(512, 512)
);
std::alloc::dealloc(mem, layout);
}
}
// Verifies the XOR-based buddy computation: buddies are symmetric pairs
// at each order, and the whole-heap block has no buddy.
#[test]
fn test_buddy() {
unsafe {
let heap_size = 256;
let layout = std::alloc::Layout::from_size_align(heap_size, 4096).unwrap();
let mem = std::alloc::alloc(layout);
let heap: Heap<5> = Heap::new(NonNull::new(mem).unwrap(), heap_size).unwrap();
let block_16_0 = mem;
let block_16_1 = mem.offset(16);
assert_eq!(Some(block_16_1), heap.buddy(0, block_16_0));
assert_eq!(Some(block_16_0), heap.buddy(0, block_16_1));
let block_32_0 = mem;
let block_32_1 = mem.offset(32);
assert_eq!(Some(block_32_1), heap.buddy(1, block_32_0));
assert_eq!(Some(block_32_0), heap.buddy(1, block_32_1));
let block_32_2 = mem.offset(64);
let block_32_3 = mem.offset(96);
assert_eq!(Some(block_32_3), heap.buddy(1, block_32_2));
assert_eq!(Some(block_32_2), heap.buddy(1, block_32_3));
// The top-order block spans the entire heap and has no buddy.
let block_256_0 = mem;
assert_eq!(None, heap.buddy(4, block_256_0));
std::alloc::dealloc(mem, layout);
}
}
// End-to-end test: splitting on allocation, exhaustion, fragmentation,
// and buddy coalescing on deallocation, checked via exact block addresses.
#[test]
fn test_alloc_and_dealloc() {
unsafe {
let heap_size = 256;
let layout = std::alloc::Layout::from_size_align(heap_size, 4096).unwrap();
let mem = std::alloc::alloc(layout);
let mut heap: Heap<5> = Heap::new(NonNull::new(mem).unwrap(), heap_size).unwrap();
// First small allocation splits the heap down to a 16-byte block at base.
let block_16_0 = heap
.allocate(Layout::from_size_align(8, 8).unwrap())
.unwrap();
assert_eq!(mem, block_16_0);
// A request larger than the heap is an invalid size, not exhaustion.
let bigger_than_heap = heap.allocate(Layout::from_size_align(heap_size, 4096).unwrap());
assert_eq!(
Err(AllocationError::InvalidSize(AllocationSizeError::TooLarge)),
bigger_than_heap
);
// A valid size with no free block of that order is exhaustion.
let bigger_than_free =
heap.allocate(Layout::from_size_align(heap_size, heap_size).unwrap());
assert_eq!(Err(AllocationError::HeapExhausted), bigger_than_free);
// Subsequent allocations come from the split-off remainders.
let block_16_1 = heap
.allocate(Layout::from_size_align(8, 8).unwrap())
.unwrap();
assert_eq!(mem.offset(16), block_16_1);
let block_16_2 = heap
.allocate(Layout::from_size_align(8, 8).unwrap())
.unwrap();
assert_eq!(mem.offset(32), block_16_2);
let block_32_2 = heap
.allocate(Layout::from_size_align(32, 32).unwrap())
.unwrap();
assert_eq!(mem.offset(64), block_32_2);
let block_16_3 = heap
.allocate(Layout::from_size_align(8, 8).unwrap())
.unwrap();
assert_eq!(mem.offset(48), block_16_3);
let block_128_1 = heap
.allocate(Layout::from_size_align(128, 128).unwrap())
.unwrap();
assert_eq!(mem.offset(128), block_128_1);
// 64 bytes remain free in total, but not as one contiguous block.
let too_fragmented = heap.allocate(Layout::from_size_align(64, 64).unwrap());
assert_eq!(Err(AllocationError::HeapExhausted), too_fragmented);
// Freeing everything in the lower half should coalesce back to 128 bytes.
heap.deallocate(block_32_2, Layout::from_size_align(32, 32).unwrap());
heap.deallocate(block_16_0, Layout::from_size_align(8, 8).unwrap());
heap.deallocate(block_16_3, Layout::from_size_align(8, 8).unwrap());
heap.deallocate(block_16_1, Layout::from_size_align(8, 8).unwrap());
heap.deallocate(block_16_2, Layout::from_size_align(8, 8).unwrap());
let block_128_0 = heap
.allocate(Layout::from_size_align(128, 128).unwrap())
.unwrap();
assert_eq!(mem.offset(0), block_128_0);
// Freeing both halves coalesces to the full heap again.
heap.deallocate(block_128_1, Layout::from_size_align(128, 128).unwrap());
heap.deallocate(block_128_0, Layout::from_size_align(128, 128).unwrap());
let block_256_0 = heap
.allocate(Layout::from_size_align(256, 256).unwrap())
.unwrap();
assert_eq!(mem.offset(0), block_256_0);
std::alloc::dealloc(mem, layout);
}
}
}