use core::mem;
use crate::{is_aligned, HeapStats, PageAllocatorProvider};
use super::align_up;
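/// Magic value written into every allocated block's header; verified on `dealloc` to catch corruption and invalid frees.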
const HEAP_MAGIC: u32 = 0xF0B0CAFE;
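/// Header stored immediately before the pointer handed out by `alloc`. It records the
/// real size of the allocation and how far the pointer sits from the start of the
/// underlying block (`pre_padding`), so `dealloc` can reconstruct and free the block.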
#[repr(C, align(16))]
struct AllocatedHeapBlockInfo {
magic: u32,
size: usize,
pre_padding: usize,
}
const KERNEL_HEAP_BLOCK_INFO_SIZE: usize = mem::size_of::<AllocatedHeapBlockInfo>();
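/// Node of the doubly-linked free list, stored in-place at the start of each free region.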
#[derive(Debug)]
struct HeapFreeBlock {
prev: *mut HeapFreeBlock,
next: *mut HeapFreeBlock,
size: usize,
}
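/// A best-fit heap allocator backed by a page allocator. The heap is grown lazily,
/// whole pages at a time; freed regions are coalesced with their physical neighbours
/// and kept in an address-ordered free list.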
pub struct HeapAllocator<const PAGE_SIZE: usize, T: PageAllocatorProvider<PAGE_SIZE>> {
heap_start: usize,
total_heap_size: usize,
free_list_addr: *mut HeapFreeBlock,
free_size: usize,
used_size: usize,
page_allocator: T,
}
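// SAFETY: the raw pointers stored in the allocator refer to heap memory that the
// allocator manages exclusively and never shares, so moving it to another thread is sound.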
unsafe impl<const PAGE_SIZE: usize, T: PageAllocatorProvider<PAGE_SIZE>> Send
for HeapAllocator<PAGE_SIZE, T>
{
}
impl<const PAGE_SIZE: usize, T> HeapAllocator<PAGE_SIZE, T>
where
T: PageAllocatorProvider<PAGE_SIZE>,
{
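/// Detects a cycle in the free list using Floyd's two-pointer algorithm
/// (`fast` advances two nodes per iteration, `slow` one). Returns `true` if a cycle exists.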
fn is_free_blocks_in_cycle(&self) -> bool {
let mut slow = self.free_list_addr;
let mut fast = self.free_list_addr;
if fast.is_null() {
return false;
} else {
fast = unsafe { (*fast).next };
}
while fast != slow {
if fast.is_null() {
return false;
} else {
fast = unsafe { (*fast).next };
}
if fast.is_null() {
return false;
} else {
fast = unsafe { (*fast).next };
}
if slow.is_null() {
return false;
} else {
slow = unsafe { (*slow).next };
}
}
true
}
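/// Cross-checks the doubly-linked free list by counting nodes forwards via `next`
/// and backwards via `prev`. Returns `true` if the counts disagree, i.e. the list is corrupted.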
fn check_free_blocks(&self) -> bool {
let mut forward_count = 0;
let mut last: *mut HeapFreeBlock = core::ptr::null_mut();
for block in self.iter_free_blocks() {
forward_count += 1;
last = block as _;
}
// Walk back from the last block via `prev`; a consistent doubly-linked list
// yields the same count in both directions.
let mut backward_count = 0;
while !last.is_null() {
backward_count += 1;
last = unsafe { (*last).prev };
}
forward_count != backward_count
}
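/// Returns `true` if the free list is cyclic or inconsistent.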
fn check_issues(&self) -> bool {
self.is_free_blocks_in_cycle() || self.check_free_blocks()
}
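/// Finds the smallest free block that can hold `size` bytes (best-fit), growing the
/// heap by whole pages and retrying if no suitable block exists.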
fn get_free_block(&mut self, size: usize) -> *mut HeapFreeBlock {
if self.total_heap_size == 0 {
let size = align_up(size, PAGE_SIZE);
self.allocate_more_pages(size / PAGE_SIZE);
return self.get_free_block(size);
}
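// Best-fit: pick the smallest free block that is still large enough.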
let mut best_block: *mut HeapFreeBlock = core::ptr::null_mut();
for block in self.iter_free_blocks() {
if block.size >= size
&& (best_block.is_null() || block.size < unsafe { (*best_block).size })
{
best_block = block as _;
}
}
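// Nothing fits; grow the heap by whole pages and retry with the page-aligned size.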
if best_block.is_null() {
let size = align_up(size, PAGE_SIZE);
self.allocate_more_pages(size / PAGE_SIZE);
return self.get_free_block(size);
}
best_block
}
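/// Walks the free list from `free_list_addr`, following `next` pointers.
/// The items are `&mut` references created from raw pointers even though this
/// takes `&self`, so the allocator must not be mutated while iterating.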
fn iter_free_blocks(&self) -> impl Iterator<Item = &mut HeapFreeBlock> {
let mut current_block = self.free_list_addr;
core::iter::from_fn(move || {
if current_block.is_null() {
None
} else {
let block = current_block;
current_block = unsafe { (*current_block).next };
Some(unsafe { &mut *block })
}
})
}
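/// Grows the heap by `pages` pages and adds the new region to the free list.
/// The bookkeeping treats the heap as one contiguous range starting at `heap_start`,
/// so the page allocator is expected to return memory directly after the current heap end.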
fn allocate_more_pages(&mut self, pages: usize) {
assert!(pages > 0);
let new_heap_start = self.page_allocator.allocate_pages(pages).unwrap() as usize;
if self.total_heap_size == 0 {
self.heap_start = new_heap_start;
}
self.total_heap_size += pages * PAGE_SIZE;
if self.free_list_addr.is_null() {
let free_block = new_heap_start as *mut HeapFreeBlock;
unsafe {
(*free_block).prev = core::ptr::null_mut();
(*free_block).next = core::ptr::null_mut();
(*free_block).size = pages * PAGE_SIZE;
}
self.free_list_addr = free_block;
} else {
unsafe {
self.free_block(new_heap_start as _, pages * PAGE_SIZE);
}
}
self.free_size += pages * PAGE_SIZE;
}
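/// Returns the region `[freeing_block, freeing_block + size)` to the free list,
/// merging it with physically adjacent free blocks and keeping the list sorted by address.
/// Panics on a double free or if the region overlaps an existing free block.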
unsafe fn free_block(&mut self, freeing_block: usize, size: usize) {
// The freed range must lie entirely inside the heap.
assert!(freeing_block >= self.heap_start);
assert!(freeing_block + size <= self.heap_start + self.total_heap_size);
let freeing_block = freeing_block as *mut HeapFreeBlock;
let freeing_block_start = freeing_block as usize;
let freeing_block_end = freeing_block_start + size;
let mut prev_block: *mut HeapFreeBlock = core::ptr::null_mut();
let mut next_block: *mut HeapFreeBlock = core::ptr::null_mut();
let mut closest_prev_block: *mut HeapFreeBlock = core::ptr::null_mut();
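// Scan the free list to find blocks physically adjacent to the freed region (for
// coalescing), validate that the region is not already free or overlapping, and
// remember the closest free block below it for address-ordered insertion.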
for block in self.iter_free_blocks() {
let block_addr = block as *mut _ as usize;
let block_end = block_addr + block.size;
if block_addr == freeing_block_start {
panic!("double free");
}
assert!(
(freeing_block_end <= block_addr) || (freeing_block_start >= block_end),
"Free block at {:x}..{:x} is in the middle of another block at {:x}..{:x}",
freeing_block_start,
freeing_block_end,
block_addr,
block_end
);
if block_end == freeing_block_start {
prev_block = block as _;
} else if freeing_block_end == block_addr {
next_block = block as _;
}
if block_addr < freeing_block_start {
if closest_prev_block.is_null() || block_addr > (closest_prev_block as usize) {
closest_prev_block = block as _;
}
}
}
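// The freed region bridges two free blocks: absorb it and `next_block` into `prev_block`.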
if !prev_block.is_null() && !next_block.is_null() {
let new_block = prev_block;
(*new_block).size += size + (*next_block).size;
if !(*next_block).next.is_null() {
(*(*next_block).next).prev = new_block;
}
if !(*next_block).prev.is_null() {
(*(*next_block).prev).next = new_block;
} else {
self.free_list_addr = new_block;
}
(*new_block).next = (*next_block).next;
} else if !prev_block.is_null() {
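// The freed region directly follows `prev_block`: just grow it.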
(*prev_block).size += size;
} else if !next_block.is_null() {
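// The freed region directly precedes `next_block`: move the free-block header down
// to the start of the freed region and take over `next_block`'s links and size.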
let new_block = freeing_block;
(*new_block).size = (*next_block).size + size;
(*new_block).prev = (*next_block).prev;
(*new_block).next = (*next_block).next;
if !(*next_block).next.is_null() {
(*(*next_block).next).prev = new_block;
}
if !(*next_block).prev.is_null() {
(*(*next_block).prev).next = new_block;
} else {
self.free_list_addr = new_block;
}
} else {
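// No adjacent free block: link the region into the free list at its address-ordered position.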
if closest_prev_block.is_null() {
(*freeing_block).prev = core::ptr::null_mut();
(*freeing_block).next = self.free_list_addr;
(*freeing_block).size = size;
if !(*freeing_block).next.is_null() {
(*(*freeing_block).next).prev = freeing_block;
}
self.free_list_addr = freeing_block;
} else {
let closest_next_block = (*closest_prev_block).next;
(*freeing_block).prev = closest_prev_block;
(*freeing_block).next = closest_next_block;
(*freeing_block).size = size;
(*closest_prev_block).next = freeing_block;
if !closest_next_block.is_null() {
(*closest_next_block).prev = freeing_block;
}
}
}
}
}
impl<const PAGE_SIZE: usize, T> HeapAllocator<PAGE_SIZE, T>
where
T: PageAllocatorProvider<PAGE_SIZE>,
{
pub fn new(page_allocator: T) -> Self {
Self {
heap_start: 0,
free_list_addr: core::ptr::null_mut(),
total_heap_size: 0,
free_size: 0,
used_size: 0,
page_allocator,
}
}
pub fn stats(&self) -> HeapStats {
HeapStats {
allocated: self.used_size,
free_size: self.free_size,
heap_size: self.total_heap_size,
}
}
pub fn debug_free_blocks(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
self.iter_free_blocks()
.map(|block| (block as *mut _ as usize, block.size))
}
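/// Allocates memory for `layout`. The returned pointer is aligned to `layout.align()`
/// and is preceded by an `AllocatedHeapBlockInfo` header describing the allocation.
/// Grows the heap if no free block is large enough; returns a null pointer only if
/// no suitable block could be obtained.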
pub unsafe fn alloc(&mut self, layout: core::alloc::Layout) -> *mut u8 {
let block_info_layout = core::alloc::Layout::new::<AllocatedHeapBlockInfo>();
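// Build a combined layout: header first, then the payload aligned to at least 16 bytes.
// `whole_block_offset` is the payload's offset within that combined layout.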
let (whole_layout, whole_block_offset) = block_info_layout
.extend(layout.align_to(16).unwrap())
.unwrap();
let size_to_allocate = whole_layout.pad_to_align().size();
let free_block = self.get_free_block(size_to_allocate);
if free_block.is_null() {
return core::ptr::null_mut();
}
let free_block_size = (*free_block).size;
// The allocation occupies the front of the block; any remainder starts here.
let allocation_end = free_block as usize + size_to_allocate;
let new_free_block = allocation_end as *mut HeapFreeBlock;
// Minimum block size needed to both satisfy the allocation and leave room for a
// free-block header in the remainder.
let whole_size = size_to_allocate + mem::size_of::<HeapFreeBlock>();
let mut this_allocation_size = size_to_allocate;
if free_block_size > whole_size {
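// Split: the tail of the block becomes a new free block that takes the old block's
// place in the free list.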
(*new_free_block).prev = (*free_block).prev;
(*new_free_block).next = (*free_block).next;
(*new_free_block).size = free_block_size - size_to_allocate;
if !(*new_free_block).next.is_null() {
(*(*new_free_block).next).prev = new_free_block;
}
if !(*new_free_block).prev.is_null() {
(*(*new_free_block).prev).next = new_free_block;
} else {
self.free_list_addr = new_free_block;
}
} else {
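// The remainder is too small to hold a free-block header: hand out the whole block
// and unlink it from the free list.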
this_allocation_size = free_block_size;
if !(*free_block).prev.is_null() {
(*(*free_block).prev).next = (*free_block).next;
} else {
self.free_list_addr = (*free_block).next;
}
if !(*free_block).next.is_null() {
(*(*free_block).next).prev = (*free_block).prev;
}
}
self.free_size -= this_allocation_size;
self.used_size += this_allocation_size;
if self.check_issues() {
panic!("Found issues in `alloc`");
}
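// Choose where the returned pointer goes inside the block: it must satisfy
// `layout.align()` and leave at least `KERNEL_HEAP_BLOCK_INFO_SIZE` bytes before it
// for the block header.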
let base = free_block as usize;
let possible_next_offset = align_up(base, layout.align()) - base;
let allocated_block_offset = if possible_next_offset < KERNEL_HEAP_BLOCK_INFO_SIZE {
possible_next_offset + KERNEL_HEAP_BLOCK_INFO_SIZE.max(layout.align())
} else {
possible_next_offset
};
assert!(allocated_block_offset >= KERNEL_HEAP_BLOCK_INFO_SIZE);
assert!(allocated_block_offset <= whole_block_offset);
let allocated_ptr = (free_block as *mut u8).add(allocated_block_offset);
let allocated_block_info =
allocated_ptr.sub(KERNEL_HEAP_BLOCK_INFO_SIZE) as *mut AllocatedHeapBlockInfo;
assert!(is_aligned(allocated_ptr as _, layout.align()),
"base_block={allocated_block_info:p}, offset={allocated_block_offset}, ptr={allocated_ptr:?}, layout={layout:?}, should_be_addr={:x}",
align_up(allocated_block_info as usize, layout.align()));
(*allocated_block_info).magic = HEAP_MAGIC;
(*allocated_block_info).size = this_allocation_size;
(*allocated_block_info).pre_padding = allocated_block_offset;
allocated_ptr
}
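/// Frees a block previously returned by [`Self::alloc`]. Panics if the block header's
/// magic value does not match.
///
/// # Safety
///
/// `ptr` must have been returned by [`Self::alloc`] on this allocator, with the same
/// `layout` that was used to allocate it.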
pub unsafe fn dealloc(&mut self, ptr: *mut u8, layout: core::alloc::Layout) {
assert!(!ptr.is_null());
let base_layout = core::alloc::Layout::new::<AllocatedHeapBlockInfo>();
let (whole_layout, _) = base_layout.extend(layout.align_to(16).unwrap()).unwrap();
let size_to_free_from_layout = whole_layout.pad_to_align().size();
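// The header sits immediately below the pointer handed out by `alloc`.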
let allocated_block_info =
ptr.sub(KERNEL_HEAP_BLOCK_INFO_SIZE) as *mut AllocatedHeapBlockInfo;
assert_eq!((*allocated_block_info).magic, HEAP_MAGIC);
assert!((*allocated_block_info).size >= size_to_free_from_layout);
assert!((*allocated_block_info).pre_padding >= KERNEL_HEAP_BLOCK_INFO_SIZE);
let this_allocation_size = (*allocated_block_info).size;
let freeing_block = ptr.sub((*allocated_block_info).pre_padding) as usize;
self.free_block(freeing_block, this_allocation_size);
self.used_size -= this_allocation_size;
self.free_size += this_allocation_size;
if self.check_issues() {
panic!("Found issues in `dealloc`");
}
}
}