use crate::{PAGE_SIZE, grow_memory};
use core::{
alloc::{GlobalAlloc, Layout},
ptr::{self, null_mut},
};
// SAFETY: all allocator state lives in `static mut`s with no locking, so this
// `Sync` impl is only sound on single-threaded targets (e.g. wasm32 without
// threads). NOTE(review): confirm the crate never enables threading.
unsafe impl Sync for BumpFreeListAllocator {}
/// Zero-sized global allocator: bump allocation from a growable heap with a
/// first-fit free list for reuse of freed blocks.
pub struct BumpFreeListAllocator;
impl BumpFreeListAllocator {
    /// Creates the allocator. All state is kept in module-level statics, so
    /// this is a const no-op suitable for `#[global_allocator]`.
    pub const fn new() -> Self {
        Self
    }
}
/// Free-list header written in-place at the start of each freed block.
/// Every allocation is at least 16 bytes, so the header always fits.
struct Node {
    // Next free block, or null at the end of the list.
    next: *mut Node,
    // Size of this block in bytes (rounded up to a multiple of 16).
    size: usize,
}
// Head of the singly-linked list of freed blocks (null when empty).
static mut FREE_LIST: *mut Node = null_mut();
// Bump pointer: address of the next byte to hand out. 0 until first growth.
static mut HEAP_TOP: usize = 0;
// One past the last byte obtained via `grow_memory`. 0 means "no heap yet".
static mut HEAP_END: usize = 0;
unsafe impl GlobalAlloc for BumpFreeListAllocator {
    /// Allocate by first-fit search of the free list, falling back to bump
    /// allocation from the top of the heap. Returns null on growth failure.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // Enforce a minimum 16-byte size and alignment so every block can
        // later hold a `Node` header when it is freed.
        let align_req = layout.align().max(16);
        let size = layout.size().max(16);
        let size = (size + 15) & !15;
        unsafe {
            let mut prev = ptr::addr_of_mut!(FREE_LIST);
            let mut curr = *prev;
            while !curr.is_null() {
                // Fix: a node must also satisfy the requested alignment.
                // Freed blocks are only guaranteed 16-byte aligned, so an
                // over-aligned layout has to check the node address too
                // (`align_req` is a power of two, so masking works).
                if (*curr).size >= size && (curr as usize) & (align_req - 1) == 0 {
                    // Unlink and hand out the whole node (no splitting; any
                    // excess size is lost until the block is freed again).
                    *prev = (*curr).next;
                    return curr as *mut u8;
                }
                prev = ptr::addr_of_mut!((*curr).next);
                curr = *prev;
            }
        }
        unsafe { self.bump_alloc(size, align_req) }
    }

    /// Free a block by writing a `Node` header into it and pushing it onto
    /// the front of the free list. Memory is never returned to the system.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // Mirror the rounding done in `alloc` so the recorded size matches
        // what was actually handed out.
        let size = layout.size().max(16);
        let size = (size + 15) & !15;
        unsafe {
            let node = ptr as *mut Node;
            (*node).size = size;
            (*node).next = FREE_LIST;
            FREE_LIST = node;
        }
    }

    /// Resize in place when `ptr` is the most recent bump allocation;
    /// otherwise fall back to allocate-copy-free.
    #[cfg(feature = "realloc")]
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let old_size = (layout.size().max(16) + 15) & !15;
        let req_new_size = (new_size.max(16) + 15) & !15;
        let heap_top = unsafe { HEAP_TOP };
        // In-place path: the block ends exactly at the bump pointer.
        if ptr as usize + old_size == heap_top {
            let diff = req_new_size.saturating_sub(old_size);
            if diff == 0 {
                // Same rounded size, or a shrink: keep the block as-is
                // (the tail of a shrunk block is not reclaimed).
                return ptr;
            }
            unsafe {
                if HEAP_TOP + diff <= HEAP_END {
                    HEAP_TOP += diff;
                    return ptr;
                }
                let pages_needed =
                    ((HEAP_TOP + diff - HEAP_END + PAGE_SIZE - 1) / PAGE_SIZE).max(1);
                if grow_memory(pages_needed) != usize::MAX {
                    HEAP_END += pages_needed * PAGE_SIZE;
                    HEAP_TOP += diff;
                    return ptr;
                }
            }
        }
        // Fallback: allocate a new block, copy the payload, free the old one.
        unsafe {
            let new_ptr = self.alloc(Layout::from_size_align_unchecked(new_size, layout.align()));
            if !new_ptr.is_null() {
                // Fix: copy only the bytes that fit in BOTH blocks. The old
                // code always copied `layout.size()` bytes, which wrote past
                // the end of the new allocation when shrinking — the
                // `GlobalAlloc::realloc` contract is to copy min(old, new).
                ptr::copy_nonoverlapping(ptr, new_ptr, layout.size().min(new_size));
                self.dealloc(ptr, layout);
            }
            new_ptr
        }
    }
}
impl BumpFreeListAllocator {
    /// Carve `size` bytes (aligned to `align`) off the top of the heap,
    /// growing the underlying memory with `grow_memory` when the current
    /// region is exhausted. Returns null when the memory cannot grow.
    unsafe fn bump_alloc(&self, size: usize, align: usize) -> *mut u8 {
        unsafe {
            let mut ptr = HEAP_TOP;
            // Round the bump pointer up to the requested alignment
            // (`align` is a power of two per `Layout`'s contract).
            ptr = (ptr + align - 1) & !(align - 1);
            // `ptr < HEAP_TOP` catches wrap-around in the align-up above.
            if ptr + size > HEAP_END || ptr < HEAP_TOP {
                let bytes_needed = (ptr + size).saturating_sub(HEAP_END);
                let pages_needed = ((bytes_needed + PAGE_SIZE - 1) / PAGE_SIZE).max(1);
                let prev_page = grow_memory(pages_needed);
                // usize::MAX is grow_memory's failure sentinel.
                if prev_page == usize::MAX {
                    return null_mut();
                }
                if HEAP_END == 0 {
                    // First growth: the heap starts where memory previously
                    // ended (prev_page pages in).
                    let memory_start = prev_page * PAGE_SIZE;
                    ptr = memory_start;
                    // memory_start is page-aligned, so this is a no-op for
                    // align <= PAGE_SIZE. NOTE(review): an alignment larger
                    // than PAGE_SIZE could push ptr + size past HEAP_END —
                    // confirm such layouts never occur.
                    ptr = (ptr + align - 1) & !(align - 1);
                    HEAP_END = memory_start + pages_needed * PAGE_SIZE;
                } else {
                    // Later growths extend the existing contiguous region, so
                    // the already-aligned `ptr` stays valid.
                    HEAP_END += pages_needed * PAGE_SIZE;
                }
            }
            HEAP_TOP = ptr + size;
            ptr as *mut u8
        }
    }
    /// Reset all allocator state to empty.
    ///
    /// # Safety
    /// Invalidates every outstanding allocation; the caller must ensure none
    /// are used afterwards. Intended for tests.
    pub unsafe fn reset() {
        unsafe {
            FREE_LIST = null_mut();
            HEAP_TOP = 0;
            HEAP_END = 0;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reset_heap;
    use core::alloc::Layout;
    use std::sync::{Mutex, MutexGuard};

    // The allocator state is global, so tests must never run concurrently.
    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    /// Test harness: holds the serialization lock for its lifetime and
    /// resets all global allocator state on construction and on drop.
    struct SafeAllocator {
        inner: BumpFreeListAllocator,
        _guard: MutexGuard<'static, ()>,
    }

    impl SafeAllocator {
        fn new() -> Self {
            let guard = TEST_MUTEX.lock().unwrap();
            unsafe {
                BumpFreeListAllocator::reset();
                reset_heap();
                Self {
                    inner: BumpFreeListAllocator::new(),
                    _guard: guard,
                }
            }
        }

        fn alloc(&self, layout: Layout) -> *mut u8 {
            unsafe { self.inner.alloc(layout) }
        }

        fn dealloc(&self, ptr: *mut u8, layout: Layout) {
            unsafe { self.inner.dealloc(ptr, layout) }
        }
    }

    impl Drop for SafeAllocator {
        fn drop(&mut self) {
            unsafe {
                BumpFreeListAllocator::reset();
                reset_heap();
            }
        }
    }

    #[test]
    fn test_basic_allocation() {
        let allocator = SafeAllocator::new();
        let layout = Layout::new::<u64>();
        let block = allocator.alloc(layout);
        assert!(!block.is_null());
        // The returned memory must be writable and readable.
        unsafe {
            *block.cast::<u64>() = 42;
            assert_eq!(*block.cast::<u64>(), 42);
        }
        allocator.dealloc(block, layout);
    }

    #[test]
    fn test_multiple_allocations() {
        let allocator = SafeAllocator::new();
        let layout = Layout::new::<u32>();
        let first = allocator.alloc(layout);
        let second = allocator.alloc(layout);
        assert!(!first.is_null());
        assert!(!second.is_null());
        assert_ne!(first, second);
        // Every allocation is rounded up to at least 16 bytes, so two
        // consecutive blocks must be at least 16 bytes apart.
        let distance = (second as usize).wrapping_sub(first as usize);
        assert!(distance >= 16);
    }

    #[test]
    fn test_memory_grow() {
        let allocator = SafeAllocator::new();
        // Two 40 KiB blocks force the heap to grow past its first region.
        let layout = Layout::from_size_align(40 * 1024, 16).unwrap();
        let first = allocator.alloc(layout);
        let second = allocator.alloc(layout);
        assert!(!first.is_null());
        assert!(!second.is_null());
        assert_ne!(first, second);
        // Touch every byte of both blocks to prove the memory is usable.
        unsafe {
            first.write_bytes(1, layout.size());
            second.write_bytes(2, layout.size());
        }
    }

    #[test]
    fn test_freelist_reuse() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(128, 16).unwrap();
        let original = allocator.alloc(layout);
        allocator.dealloc(original, layout);
        // An equally-sized allocation must be served from the free list.
        let reused = allocator.alloc(layout);
        assert_eq!(original, reused);
    }

    #[test]
    fn test_realloc_in_place() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(32, 16).unwrap();
        let block = allocator.alloc(layout);
        assert!(!block.is_null());
        unsafe {
            block.write_bytes(0xAA, layout.size());
        }
        let new_size = 64;
        let new_layout = Layout::from_size_align(new_size, 16).unwrap();
        let grown = unsafe { allocator.inner.realloc(block, layout, new_size) };
        // With the feature enabled the block is at the heap top and grows in
        // place; the default trait realloc always moves it.
        #[cfg(feature = "realloc")]
        assert_eq!(block, grown);
        #[cfg(not(feature = "realloc"))]
        assert_ne!(block, grown);
        unsafe {
            // Old contents preserved, new tail writable.
            assert_eq!(*grown.cast::<u8>(), 0xAA);
            assert_eq!(*grown.add(31).cast::<u8>(), 0xAA);
            grown.add(32).write_bytes(0xBB, new_size - layout.size());
            assert_eq!(*grown.add(32).cast::<u8>(), 0xBB);
        }
        allocator.dealloc(grown, new_layout);
    }

    #[test]
    fn test_realloc_not_in_place() {
        let allocator = SafeAllocator::new();
        let layout1 = Layout::from_size_align(32, 16).unwrap();
        let block1 = allocator.alloc(layout1);
        assert!(!block1.is_null());
        // A second allocation sits on top of the first, so block1 can no
        // longer grow in place.
        let layout2 = Layout::from_size_align(32, 16).unwrap();
        let block2 = allocator.alloc(layout2);
        assert!(!block2.is_null());
        assert_ne!(block1, block2);
        unsafe {
            block1.write_bytes(0xAA, layout1.size());
        }
        let new_size = 64;
        let new_layout = Layout::from_size_align(new_size, 16).unwrap();
        let moved = unsafe { allocator.inner.realloc(block1, layout1, new_size) };
        assert!(!moved.is_null());
        assert_ne!(block1, moved);
        unsafe {
            // Contents must have been copied to the new location.
            assert_eq!(*moved.cast::<u8>(), 0xAA);
            assert_eq!(*moved.add(31).cast::<u8>(), 0xAA);
        }
        allocator.dealloc(moved, new_layout);
        allocator.dealloc(block2, layout2);
    }

    #[test]
    fn test_realloc_shrink() {
        let allocator = SafeAllocator::new();
        let layout = Layout::from_size_align(64, 16).unwrap();
        let block = allocator.alloc(layout);
        assert!(!block.is_null());
        unsafe {
            block.write_bytes(0xCC, layout.size());
        }
        let new_size = 32;
        let new_layout = Layout::from_size_align(new_size, 16).unwrap();
        let shrunk = unsafe { allocator.inner.realloc(block, layout, new_size) };
        // With the feature enabled a shrink of the top block stays in place;
        // the default trait realloc always moves it.
        #[cfg(feature = "realloc")]
        assert_eq!(block, shrunk);
        #[cfg(not(feature = "realloc"))]
        assert_ne!(block, shrunk);
        unsafe {
            assert_eq!(*shrunk.cast::<u8>(), 0xCC);
            assert_eq!(*shrunk.add(31).cast::<u8>(), 0xCC);
        }
        allocator.dealloc(shrunk, new_layout);
    }
}