use crate::PAGE_SIZE;
use core::{
alloc::{GlobalAlloc, Layout},
ptr::{self, null_mut},
};
// SAFETY: the allocator itself is a zero-sized handle; all mutable state lives
// in the unsynchronized `FREE_LIST` static below. Declaring `Sync` is only
// sound if the allocator is never entered from two threads at once — this
// assumes a single-threaded target (TODO confirm; nothing here takes a lock).
unsafe impl Sync for FreeListAllocator {}
// Zero-sized allocator handle; presumably installed as the program's
// `#[global_allocator]` — verify at the crate root.
pub struct FreeListAllocator;
// Head of the intrusive free list. `EMPTY_FREE_LIST` is the end-of-list
// sentinel, so an "empty" heap is represented by the sentinel alone.
static mut FREE_LIST: *mut FreeListNode = EMPTY_FREE_LIST;
impl FreeListAllocator {
pub const fn new() -> Self {
FreeListAllocator
}
pub unsafe fn reset() {
unsafe {
FREE_LIST = EMPTY_FREE_LIST;
}
}
}
// End-of-list sentinel. `usize::MAX` can never be the address of a real,
// 16-byte-aligned block, so it is unambiguous as a terminator.
const EMPTY_FREE_LIST: *mut FreeListNode = usize::MAX as *mut FreeListNode;
// Header written inline at the start of every *free* block (intrusive list);
// allocated blocks carry no header.
struct FreeListNode {
    // Next free block in the list (lower address), or `EMPTY_FREE_LIST`.
    next: *mut FreeListNode,
    // Total size of this free block in bytes, header included.
    size: usize,
}
// Minimum usable block size: a block must be able to hold its own header
// once it is freed (16 bytes on 64-bit targets: two usizes).
const NODE_SIZE: usize = core::mem::size_of::<FreeListNode>();
// SAFETY: sound only under single-threaded use — TODO confirm target; the
// handle is stateless, but the shared `FREE_LIST` static is unsynchronized.
unsafe impl Send for FreeListAllocator {}
// Global allocator backed by a single intrusive free list.
//
// Invariants maintained by `alloc`/`dealloc` below:
//   * the list is sorted by strictly descending address;
//   * every block address and size is a multiple of 16;
//   * adjacent free blocks are coalesced eagerly on `dealloc`;
//   * `EMPTY_FREE_LIST` terminates the list.
//
// SAFETY: all state is the `static mut FREE_LIST`; this impl is sound only if
// the allocator is never entered concurrently (single-threaded target assumed
// — TODO confirm; the Send/Sync impls in this file make that a caller-side
// obligation).
unsafe impl GlobalAlloc for FreeListAllocator {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // Every block handed out is 16-byte aligned; stricter alignment
        // requests are unsupported and report allocation failure.
        const MIN_ALIGN: usize = 16;
        if layout.align() > MIN_ALIGN {
            return null_mut();
        }
        // Pad the request so the block can later hold a FreeListNode header,
        // then round to a 16-byte multiple (keeps all sizes/addresses aligned).
        let size = layout.size().max(NODE_SIZE);
        let size = (size + 15) & !15;
        // Walk the list through a pointer-to-link (`*mut *mut`) so the link
        // that points at the chosen node can be rewritten in place.
        let mut free_list: *mut *mut FreeListNode = ptr::addr_of_mut!(FREE_LIST);
        loop {
            if unsafe { *free_list == EMPTY_FREE_LIST } {
                break;
            }
            let node = unsafe { *free_list };
            let node_size = unsafe { (*node).size };
            if size <= node_size {
                let remaining = node_size - size;
                if remaining >= NODE_SIZE {
                    // Split: shrink the node in place and hand out its tail.
                    // `remaining` stays a multiple of 16, so the returned
                    // pointer keeps the 16-byte alignment guarantee.
                    unsafe {
                        (*node).size = remaining;
                        return (node as *mut u8).add(remaining);
                    }
                } else {
                    // Exact (or near-exact) fit: unlink and return the whole
                    // node. No split — a sub-NODE_SIZE tail could not hold a
                    // header when later freed.
                    unsafe {
                        *free_list = (*node).next;
                        return node as *mut u8;
                    }
                }
            }
            unsafe {
                free_list = ptr::addr_of_mut!((*node).next);
            }
        }
        // Nothing fit: grow the heap by whole pages, free the fresh region so
        // it coalesces with any adjacent free block, then retry the request.
        let requested_bytes = round_up(size, PAGE_SIZE);
        // `grow_memory` returns the previous page count, or usize::MAX on
        // failure — presumably a wasm `memory.grow`-style interface; confirm
        // against its definition in the crate root.
        let previous_page_count = unsafe { crate::grow_memory(requested_bytes / PAGE_SIZE) };
        if previous_page_count == usize::MAX {
            return null_mut();
        }
        let ptr = (previous_page_count * PAGE_SIZE) as *mut u8;
        unsafe {
            self.dealloc(
                ptr,
                Layout::from_size_align_unchecked(requested_bytes, PAGE_SIZE),
            );
            self.alloc(layout)
        }
    }
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        debug_assert!(ptr.align_offset(NODE_SIZE) == 0);
        let ptr = ptr as *mut FreeListNode;
        // Recompute the padded size that `alloc` actually reserved for this
        // layout, so the whole block is returned to the list.
        let size = full_size(layout);
        // One-past-the-end of the freed block; used to detect a free neighbor
        // directly above it in memory.
        let after_new = unsafe { offset_bytes(ptr, size) };
        let mut free_list: *mut *mut FreeListNode = ptr::addr_of_mut!(FREE_LIST);
        loop {
            // Case 1: end of list — append as the new last (lowest) node.
            if unsafe { *free_list == EMPTY_FREE_LIST } {
                unsafe {
                    (*ptr).next = EMPTY_FREE_LIST;
                    (*ptr).size = size;
                    *free_list = ptr;
                }
                return;
            }
            // Case 2: the current node starts exactly where the freed block
            // ends — coalesce upward by absorbing that node.
            if unsafe { *free_list == after_new } {
                let new_size = unsafe { size + (**free_list).size };
                let next = unsafe { (**free_list).next };
                // If the node below is adjacent too, collapse all three into
                // the lowest node (triple merge).
                if unsafe { next != EMPTY_FREE_LIST && offset_bytes(next, (*next).size) == ptr } {
                    unsafe {
                        (*next).size += new_size;
                        *free_list = next;
                    }
                    return;
                }
                unsafe {
                    *free_list = ptr;
                    (*ptr).size = new_size;
                    (*ptr).next = next;
                }
                return;
            }
            // Case 3: first node strictly below the freed block — either
            // extend that node (downward coalesce) or insert the block before
            // it, preserving descending address order.
            if unsafe { *free_list < ptr } {
                if unsafe { offset_bytes(*free_list, (**free_list).size) == ptr } {
                    unsafe {
                        (**free_list).size += size;
                    }
                    return;
                }
                unsafe {
                    (*ptr).next = *free_list;
                    (*ptr).size = size;
                    *free_list = ptr;
                }
                return;
            }
            unsafe {
                free_list = ptr::addr_of_mut!((**free_list).next);
            }
        }
    }
    #[cfg(feature = "realloc")]
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let old_size = full_size(layout);
        // Same padding rule as `alloc`: at least one header, 16-byte multiple.
        let new_full_size = (new_size.max(NODE_SIZE) + 15) & !15;
        if new_full_size <= old_size {
            // Shrinking: keep the block in place; free the tail if it is big
            // enough to carry a free-list header. (Both sizes are 16-byte
            // multiples, so a nonzero `diff` is always >= NODE_SIZE.)
            let diff = old_size - new_full_size;
            if diff >= NODE_SIZE {
                unsafe {
                    let remainder = ptr.add(new_full_size);
                    let remainder_layout = Layout::from_size_align_unchecked(diff, 16);
                    self.dealloc(remainder, remainder_layout);
                }
            }
            return ptr;
        }
        // Growing: try in place first by consuming the free node (if any)
        // that sits immediately after this block in memory.
        let needed = new_full_size - old_size;
        let target_addr = unsafe { ptr.add(old_size) as *mut FreeListNode };
        let mut prev = ptr::addr_of_mut!(FREE_LIST);
        loop {
            let curr = unsafe { *prev };
            if curr == EMPTY_FREE_LIST {
                break;
            }
            // List is address-descending: once below the target, no node can
            // match, so stop scanning.
            if curr < target_addr {
                break;
            }
            if curr == target_addr {
                let node_size = unsafe { (*curr).size };
                if node_size >= needed {
                    // Unlink the node; if a usable tail remains, relink the
                    // shortened remainder in its place.
                    unsafe {
                        *prev = (*curr).next;
                    }
                    let remaining_in_node = node_size - needed;
                    if remaining_in_node >= NODE_SIZE {
                        unsafe {
                            let remainder_addr = (curr as *mut u8).add(needed) as *mut FreeListNode;
                            (*remainder_addr).size = remaining_in_node;
                            (*remainder_addr).next = (*curr).next;
                            *prev = remainder_addr;
                        }
                    }
                    return ptr;
                }
                // Adjacent node exists but is too small: fall back to move.
                break;
            }
            unsafe {
                prev = ptr::addr_of_mut!((*curr).next);
            }
        }
        // Fallback: allocate-copy-free. On allocation failure the original
        // block is left untouched, per the GlobalAlloc::realloc contract.
        unsafe {
            let new_ptr = self.alloc(Layout::from_size_align_unchecked(new_size, layout.align()));
            if !new_ptr.is_null() {
                ptr::copy_nonoverlapping(ptr, new_ptr, layout.size());
                self.dealloc(ptr, layout);
            }
            new_ptr
        }
    }
}
fn full_size(layout: Layout) -> usize {
let grown = layout.size().max(NODE_SIZE);
(grown + 15) & !15
}
/// Rounds `value` up to the next multiple of `increment`, which must be a
/// power of two.
fn round_up(value: usize, increment: usize) -> usize {
    debug_assert!(increment.is_power_of_two());
    // Bias past the next boundary, then snap back down onto it.
    let biased = value + (increment - 1);
    multiple_below(biased, increment)
}
/// Rounds `value` down to the nearest multiple of `increment`, which must be
/// a power of two.
fn multiple_below(value: usize, increment: usize) -> usize {
    debug_assert!(increment.is_power_of_two());
    // For a power of two, the two's-complement negation is exactly the mask
    // that clears the low-order remainder bits (e.g. 16 -> ...1111_0000).
    let mask = increment.wrapping_neg();
    value & mask
}
/// Advances `ptr` by `offset` raw bytes (not elements) and restores the
/// node pointer type.
///
/// # Safety
///
/// The caller must ensure the offset arithmetic stays within the bounds
/// required by `pointer::add`.
unsafe fn offset_bytes(ptr: *mut FreeListNode, offset: usize) -> *mut FreeListNode {
    // SAFETY: byte-wise `add` on a u8 pointer; in-bounds per the caller's
    // contract above.
    unsafe { ptr.cast::<u8>().add(offset).cast::<FreeListNode>() }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reset_heap;
    use std::sync::{Mutex, MutexGuard};
    // Serializes the tests: all allocator state lives in a `static mut`, so
    // tests running in parallel would corrupt each other.
    static TEST_MUTEX: Mutex<()> = Mutex::new(());
    // Test harness: holds the mutex for the test's lifetime and resets the
    // allocator/heap state on both construction and drop, so every test
    // starts from (and leaves behind) a clean heap.
    struct SafeAllocator {
        inner: FreeListAllocator,
        _guard: MutexGuard<'static, ()>,
    }
    impl SafeAllocator {
        fn new() -> Self {
            let guard = TEST_MUTEX.lock().unwrap();
            unsafe {
                FreeListAllocator::reset();
                reset_heap();
                Self {
                    inner: FreeListAllocator::new(),
                    _guard: guard,
                }
            }
        }
        // Safe wrappers so the tests don't repeat `unsafe` at every call site.
        fn alloc(&self, layout: Layout) -> *mut u8 {
            unsafe { self.inner.alloc(layout) }
        }
        fn dealloc(&self, ptr: *mut u8, layout: Layout) {
            unsafe { self.inner.dealloc(ptr, layout) }
        }
    }
    impl Drop for SafeAllocator {
        fn drop(&mut self) {
            unsafe {
                FreeListAllocator::reset();
                reset_heap();
            }
        }
    }
    // A fresh allocation is non-null and the memory is readable/writable.
    #[test]
    fn test_basic_allocation() {
        let allocator = SafeAllocator::new();
        let layout = Layout::new::<u64>();
        let ptr = allocator.alloc(layout);
        assert!(!ptr.is_null());
        unsafe {
            *ptr.cast::<u64>() = 0xDEADBEEF;
            assert_eq!(*ptr.cast::<u64>(), 0xDEADBEEF);
        }
        allocator.dealloc(ptr, layout);
    }
    // Blocks are carved from the tail of the topmost free node, so
    // consecutive allocations descend in address, 16 bytes apart here.
    #[test]
    fn test_allocation_order_descending() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(16, 16).unwrap();
        let ptr1 = allocator.alloc(layout);
        let ptr2 = allocator.alloc(layout);
        assert!(!ptr1.is_null());
        assert!(!ptr2.is_null());
        assert!(ptr1 > ptr2);
        assert_eq!(ptr1 as usize - ptr2 as usize, 16);
    }
    // Freeing a block and allocating the same size again returns it.
    #[test]
    fn test_reuse() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(16, 16).unwrap();
        let ptr1 = allocator.alloc(layout);
        allocator.dealloc(ptr1, layout);
        let ptr2 = allocator.alloc(layout);
        assert_eq!(ptr1, ptr2);
    }
    // Freeing three adjacent 128-byte blocks in mixed order must coalesce
    // them into one node that can satisfy a single 384-byte request starting
    // at the lowest of the three addresses.
    #[test]
    fn test_coalescing_merge() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(128, 16).unwrap();
        let ptr1 = allocator.alloc(layout);
        let ptr2 = allocator.alloc(layout);
        let ptr3 = allocator.alloc(layout);
        assert_eq!(ptr1 as usize - ptr2 as usize, 128);
        assert_eq!(ptr2 as usize - ptr3 as usize, 128);
        allocator.dealloc(ptr2, layout);
        allocator.dealloc(ptr1, layout);
        allocator.dealloc(ptr3, layout);
        let layout_large = Layout::from_size_align(384, 16).unwrap();
        let ptr_large = allocator.alloc(layout_large);
        assert!(!ptr_large.is_null());
        assert_eq!(ptr_large, ptr3);
    }
    // Requests beyond the current heap size trigger page growth via
    // `grow_memory`; the two live blocks must not overlap.
    #[test]
    fn test_memory_growth_multi_page() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(40 * 1024, 16).unwrap();
        let ptr1 = allocator.alloc(layout);
        assert!(!ptr1.is_null());
        let ptr2 = allocator.alloc(layout);
        assert!(!ptr2.is_null());
        assert_ne!(ptr1, ptr2);
        let dist = if ptr1 > ptr2 {
            ptr1 as usize - ptr2 as usize
        } else {
            ptr2 as usize - ptr1 as usize
        };
        assert!(dist >= 40 * 1024);
    }
    // Alignments above 16 are unsupported: alloc must report failure (null).
    #[test]
    fn test_alignment_large() {
        let allocator = SafeAllocator::new();
        let layout_bad = Layout::from_size_align(32, 32).unwrap();
        let ptr = allocator.alloc(layout_bad);
        assert!(ptr.is_null());
    }
    // A freed hole between live blocks is reused before carving new space
    // below the last allocation.
    #[test]
    fn test_fragmentation_fill_hole() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(64, 16).unwrap();
        let _p1 = allocator.alloc(layout);
        let p2 = allocator.alloc(layout);
        let p3 = allocator.alloc(layout);
        allocator.dealloc(p2, layout);
        let p4 = allocator.alloc(layout);
        assert_eq!(p4, p2);
        let _p5 = allocator.alloc(layout);
        assert!(_p5 < p3);
    }
    // Shrinking via realloc: with the `realloc` feature the block is kept in
    // place; without it, GlobalAlloc's default realloc moves the data. Either
    // way the payload bytes must survive.
    #[test]
    fn test_realloc_shrink_in_place() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(128, 16).unwrap();
        let ptr = allocator.alloc(layout);
        assert!(!ptr.is_null());
        unsafe {
            ptr.write_bytes(0xAA, layout.size());
        }
        let new_size = 64;
        let new_layout = Layout::from_size_align(new_size, 16).unwrap();
        let realloc_ptr = unsafe { allocator.inner.realloc(ptr, layout, new_size) };
        #[cfg(feature = "realloc")]
        assert_eq!(ptr, realloc_ptr);
        #[cfg(not(feature = "realloc"))]
        assert_ne!(ptr, realloc_ptr);
        unsafe {
            assert_eq!(*realloc_ptr.cast::<u8>(), 0xAA);
        }
        allocator.dealloc(realloc_ptr, new_layout);
    }
    // Growing into an adjacent free block: freeing the block directly above
    // lets the feature-gated realloc extend in place; the default realloc
    // moves instead.
    #[test]
    fn test_realloc_grow_in_place() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(64, 16).unwrap();
        let ptr1 = allocator.alloc(layout);
        let ptr2 = allocator.alloc(layout);
        assert_eq!(ptr1 as usize - ptr2 as usize, 64);
        allocator.dealloc(ptr1, layout);
        let new_size = 128;
        let ptr2_new = unsafe { allocator.inner.realloc(ptr2, layout, new_size) };
        #[cfg(feature = "realloc")]
        assert_eq!(ptr2, ptr2_new);
        #[cfg(not(feature = "realloc"))]
        assert_ne!(ptr2, ptr2_new);
        allocator.dealloc(ptr2_new, Layout::from_size_align(new_size, 16).unwrap());
    }
}