use core::{
alloc::{GlobalAlloc, Layout},
cell::UnsafeCell,
ptr,
};
/// Size of one WebAssembly linear-memory page in bytes (64 KiB).
const PAGE_SIZE: usize = 65536;

/// A length of linear memory measured in whole wasm pages.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
struct PageCount(usize);

impl PageCount {
    /// Returns this page count as a length in bytes.
    ///
    /// Saturates at `usize::MAX` instead of overflowing. On wasm32 the maximum of
    /// 65536 pages (4 GiB) is 2^32 bytes, one more than `usize::MAX`. A plain
    /// multiplication would panic in debug builds and wrap to 0 in release builds.
    fn size_in_bytes(&self) -> usize {
        self.0.saturating_mul(PAGE_SIZE)
    }
}

/// Page count used to signal that `memory.grow` failed (it reports `usize::MAX`).
const ERROR_PAGE_COUNT: PageCount = PageCount(usize::MAX);
extern "C" {
    // Symbol marking where the heap begins. Only its address is used by the
    // allocator, never the byte stored there.
    // NOTE(review): expected to be defined by the linker (wasm-ld emits
    // `__heap_base`); confirm for the toolchain in use.
    static __heap_base: u8;
}
// SAFETY: NOTE(review): `BumpAllocator` mutates its `UnsafeCell` fields through
// `&self` without any locking or atomics. This impl is therefore only sound if
// the allocator is never used from two threads at once, e.g. on a single-threaded
// wasm32 target. Confirm the build does not enable wasm threads/shared memory.
unsafe impl Sync for BumpAllocator {}
/// A bump allocator over wasm linear memory.
///
/// All state lives in `UnsafeCell`s and is mutated through `&self` with no
/// synchronization, so sharing one instance across threads is not sound.
pub struct BumpAllocator {
    /// First byte of the heap; `next` is rewound here when nothing is live.
    heap_start: UnsafeCell<usize>,
    /// Exclusive end of the memory known to be available, in bytes.
    /// Zero means the allocator has not been initialized yet.
    heap_end: UnsafeCell<usize>,
    /// Next free address; each allocation is carved out starting here.
    next: UnsafeCell<usize>,
    /// Start address of the most recent allocation, used to roll `next` back on free.
    last_alloc: UnsafeCell<usize>,
    /// Number of allocations handed out and not yet freed.
    allocations: UnsafeCell<usize>,
}
impl BumpAllocator {
pub const unsafe fn new() -> Self {
Self {
next: UnsafeCell::new(0),
heap_start: UnsafeCell::new(0),
heap_end: UnsafeCell::new(0),
allocations: UnsafeCell::new(0),
last_alloc: UnsafeCell::new(0),
}
}
fn memory_grow(&self, delta: PageCount) -> PageCount {
PageCount(core::arch::wasm32::memory_grow(0, delta.0))
}
fn memory_size(&self) -> PageCount {
PageCount(core::arch::wasm32::memory_size(0))
}
}
/// Bump allocation over wasm linear memory.
///
/// `alloc` carves each request out of the region starting at `next` and grows
/// memory when the request does not fit. `dealloc` reclaims space in only two
/// cases:
/// - the freed block is the most recent allocation, so `next` rewinds to its start;
/// - the count of outstanding allocations drops to zero, so the whole heap
///   rewinds to `heap_start`.
///
/// Any other freed block stays unused until such a reset.
unsafe impl GlobalAlloc for BumpAllocator {
    /// Returns `layout.size()` bytes aligned to `layout.align()`.
    ///
    /// Returns null if the address computation would overflow or if
    /// `memory.grow` fails.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // SAFETY: the allocator does no synchronization and is only sound when used
        // single-threaded. Under that assumption these are the only live references
        // to the cells for the duration of this call.
        let heap_end = &mut *self.heap_end.get();
        let next = &mut *self.next.get();

        // `heap_end == 0` marks the uninitialized state. The heap then starts at the
        // `__heap_base` symbol and ends at the current size of memory 0.
        if *heap_end == 0 {
            // Only the symbol's address is needed, so take it without creating a
            // reference to the extern static.
            let heap_base = unsafe { ptr::addr_of!(__heap_base) as usize };
            let actual_size = self.memory_size().size_in_bytes();
            *next = heap_base;
            *self.heap_start.get() = heap_base;
            *self.last_alloc.get() = heap_base;
            *heap_end = actual_size;
        }

        // `align_up` computes `next + align - 1` unchecked. Near the top of the
        // address space that would panic (allocators must not unwind) or wrap to a
        // low address and hand out memory overlapping live allocations. Refuse instead.
        if next.checked_add(layout.align()).is_none() {
            return ptr::null_mut();
        }
        let alloc_start = align_up(*next, layout.align());
        let alloc_end = match alloc_start.checked_add(layout.size()) {
            Some(end) => end,
            None => return ptr::null_mut(),
        };

        if alloc_end > *heap_end {
            let pages = pages_to_request(alloc_end - *heap_end);
            if self.memory_grow(pages) == ERROR_PAGE_COUNT {
                return ptr::null_mut();
            }
            // Saturate rather than wrap. At the wasm32 4 GiB limit the exclusive end
            // does not fit in `usize`, and a wrapped value of 0 would be mistaken for
            // the uninitialized state above, resetting the heap under live blocks.
            *heap_end = (*heap_end).saturating_add(pages.size_in_bytes());
        }

        *self.allocations.get() += 1;
        *self.last_alloc.get() = alloc_start;
        *next = alloc_end;
        alloc_start as *mut u8
    }

    /// Releases an allocation.
    ///
    /// Bytes are reclaimed only if `ptr` is the most recent allocation, or if this
    /// was the last outstanding allocation.
    unsafe fn dealloc(&self, ptr: *mut u8, _layout: Layout) {
        let allocations = self.allocations.get();
        let last_alloc = self.last_alloc.get();
        let next = self.next.get();

        *allocations -= 1;
        if *allocations == 0 {
            // Nothing is live any more: rewind the whole heap.
            let heap_start = self.heap_start.get();
            *next = *heap_start;
            *last_alloc = *heap_start;
        } else if *last_alloc as *mut u8 == ptr {
            // Freeing the newest block: give its bytes back by rewinding `next`.
            *next = *last_alloc;
        }
    }
}
/// Rounds `addr` up to the nearest multiple of `align`.
///
/// `align` must be a non-zero power of two, as `Layout::align` guarantees.
/// The caller must also ensure `addr + align` does not overflow.
fn align_up(addr: usize, align: usize) -> usize {
    let padded = addr + align - 1;
    let mask = !(align - 1);
    padded & mask
}
const fn pages_to_request(space_needed: usize) -> PageCount {
PageCount((space_needed >> 16) + ((space_needed & 0xffff != 0) as usize))
}