use core::alloc::{GlobalAlloc, Layout};
use core::mem::MaybeUninit;
use core::ptr::{null_mut, NonNull};
use core::sync::atomic::{AtomicU8, Ordering};
#[cfg(feature = "use_libc")]
use errno::Errno;
use spin::{Mutex, MutexGuard};
use crate::blocklist::{BlockList, Stats, Validity};
#[cfg(not(feature = "use_libc"))]
use crate::unix::{self, mmap, MmapError};
/// Round `value` up to the nearest multiple of `increment`.
///
/// Zero stays zero; any other value is rounded to the next multiple
/// (values already on a multiple are unchanged).
fn round_up(value: usize, increment: usize) -> usize {
    if value == 0 {
        return 0;
    }
    let chunks = (value - 1) / increment + 1;
    chunks * increment
}
/// Abstraction over the source of new memory for the allocator
/// (OS `mmap`, raw syscalls, or an in-memory test heap).
pub trait HeapGrower {
    /// Error produced when the heap cannot be grown.
    type Err;
    /// Grow the heap by at least `size` bytes.
    ///
    /// Returns a pointer to the new region and the number of usable bytes
    /// actually obtained (implementations round up, e.g. to a page size,
    /// so this may exceed `size`).
    ///
    /// # Safety
    /// Returns a raw pointer; the caller must not use the region beyond
    /// the returned length, and must treat a null pointer / zero length
    /// as "nothing allocated".
    unsafe fn grow_heap(&mut self, size: usize) -> Result<(*mut u8, usize), Self::Err>;
}
/// Heap grower backed by `libc::mmap` (enabled with the `use_libc` feature).
#[cfg(feature = "use_libc")]
#[derive(Default)]
pub struct LibcHeapGrower {
    // Total pages obtained from the OS so far.
    pages: usize,
    // Number of successful grow_heap calls.
    growths: usize,
}
#[cfg(feature = "use_libc")]
impl HeapGrower for LibcHeapGrower {
    type Err = Errno;
    /// Obtain at least `size` bytes of fresh, zeroed, page-aligned memory
    /// via anonymous `mmap`, rounded up to a whole number of pages.
    ///
    /// # Safety
    /// Returns a raw pointer into a freshly mapped region; the caller must
    /// stay within the returned length.
    unsafe fn grow_heap(&mut self, size: usize) -> Result<(*mut u8, usize), Self::Err> {
        if size == 0 {
            return Ok((null_mut(), 0));
        }
        let pagesize = sysconf::page::pagesize();
        let to_allocate = round_up(size, pagesize);
        let ptr = libc::mmap(
            null_mut(),
            to_allocate,
            libc::PROT_WRITE | libc::PROT_READ,
            libc::MAP_ANON | libc::MAP_PRIVATE,
            // POSIX requires fd == -1 for anonymous mappings; some systems
            // (e.g. macOS) reject other values. The original passed 0.
            -1,
            0,
        );
        // mmap signals failure with MAP_FAILED ((void *)-1), not NULL.
        // Comparing against MAP_FAILED is portable, unlike casting the
        // pointer to i64 (which breaks on 32-bit targets).
        if ptr == libc::MAP_FAILED {
            return Err(errno::errno());
        }
        if ptr.is_null() {
            return Ok((ptr as *mut u8, 0));
        }
        self.pages += to_allocate / pagesize;
        self.growths += 1;
        Ok((ptr as *mut u8, to_allocate))
    }
}
/// Heap grower backed by the crate's own raw `mmap` syscall wrapper
/// (default, when the `use_libc` feature is disabled).
#[cfg(not(feature = "use_libc"))]
#[derive(Default)]
pub struct SyscallHeapGrower {
    // Total pages obtained from the OS so far.
    pages: usize,
    // Number of successful grow_heap calls.
    growths: usize,
}
#[cfg(not(feature = "use_libc"))]
impl HeapGrower for SyscallHeapGrower {
    type Err = MmapError;
    /// Obtain at least `size` bytes via an anonymous `mmap` syscall,
    /// rounded up to a whole number of pages.
    ///
    /// # Safety
    /// Returns a raw pointer into a freshly mapped region; the caller must
    /// stay within the returned length.
    unsafe fn grow_heap(&mut self, size: usize) -> Result<(*mut u8, usize), MmapError> {
        if size == 0 {
            return Ok((null_mut(), 0));
        }
        // NOTE(review): page size is hard-coded rather than queried from
        // the OS — assumes 4 KiB pages; confirm for the intended targets.
        let pagesize = 4096;
        let to_allocate = round_up(size, pagesize);
        let region = mmap(
            null_mut(),
            to_allocate,
            unix::PROT_WRITE | unix::PROT_READ,
            unix::MAP_ANON | unix::MAP_PRIVATE,
            0,
            0,
        )?;
        if region.is_null() {
            return Ok((region as *mut u8, 0));
        }
        self.growths += 1;
        self.pages += to_allocate / pagesize;
        Ok((region as *mut u8, to_allocate))
    }
}
/// Core allocator state: a free-block list plus a strategy (`G`) for
/// obtaining more memory from the system.
pub struct RawAlloc<G> {
    // Source of new memory (OS mmap or a test heap).
    pub grower: G,
    // Free list tracking deallocated / not-yet-used ranges.
    pub blocks: BlockList,
}
impl<G> Drop for RawAlloc<G> {
    fn drop(&mut self) {
        // Deliberately leak the free list instead of dropping it.
        // NOTE(review): intent inferred from the take+forget pair — the
        // list's nodes appear to live inside memory this allocator manages,
        // so running BlockList's destructor at teardown would touch memory
        // the allocator no longer owns. Confirm against BlockList's Drop.
        let blocks = core::mem::take(&mut self.blocks);
        core::mem::forget(blocks);
    }
}
impl<G: HeapGrower + Default> Default for RawAlloc<G> {
    /// A fresh allocator: default-constructed grower, empty free list.
    fn default() -> Self {
        Self {
            grower: Default::default(),
            blocks: Default::default(),
        }
    }
}
impl<G: HeapGrower> RawAlloc<G> {
    /// Construct an allocator around a specific grower.
    #[allow(dead_code)]
    pub fn new(grower: G) -> Self {
        RawAlloc {
            grower,
            blocks: BlockList::default(),
        }
    }

    /// Validity / statistics snapshot of the free list.
    pub fn stats(&self) -> (Validity, Stats) {
        self.blocks.stats()
    }

    /// Number of bytes a block for `layout` occupies: the layout is
    /// aligned to at least 16 bytes and padded to its alignment so
    /// consecutive blocks stay aligned.
    pub fn block_size(layout: Layout) -> usize {
        let aligned_layout = layout
            .align_to(16)
            .expect("Whoa, serious memory issues")
            .pad_to_align();
        aligned_layout.size()
    }

    /// Allocate a block satisfying `layout`; returns null if the grower
    /// fails.
    ///
    /// # Safety
    /// Performs raw-pointer arithmetic; the caller must uphold the usual
    /// GlobalAlloc contract for `layout` and the returned pointer.
    pub unsafe fn alloc(&mut self, layout: Layout) -> *mut u8 {
        let needed_size = RawAlloc::<G>::block_size(layout);
        // Fast path: reuse a block from the free list.
        if let Some(range) = self.blocks.pop_size(needed_size) {
            return range.start.as_ptr();
        }
        // Slow path: grow the heap.
        let growth = self.grower.grow_heap(needed_size);
        let (ptr, size) = match growth {
            Err(_) => return null_mut(),
            Ok(res) => res,
        };
        if size == needed_size {
            return ptr;
        }
        // Hand the unused tail of the freshly grown region to the free
        // list so it can service future allocations.
        let free_ptr = NonNull::new_unchecked(ptr.add(needed_size));
        if size >= needed_size + BlockList::header_size() {
            self.blocks.add_block(free_ptr, size - needed_size);
        } else {
            // Leftover too small to hold a free-list header; in release
            // builds this sliver is silently leaked.
            debug_assert!(
                false,
                "Unexpected memory left over. Is page_size a multiple of header size?"
            );
        }
        ptr
    }

    /// Return `ptr`'s block to the free list.
    ///
    /// # Safety
    /// `ptr` must have been returned by `alloc` with the same `layout`
    /// and must not be used after this call.
    pub unsafe fn dealloc(&mut self, ptr: *mut u8, layout: Layout) {
        let size = RawAlloc::<G>::block_size(layout);
        self.blocks.add_block(NonNull::new_unchecked(ptr), size);
    }
}
/// Lazily-initialized, lockable wrapper around `RawAlloc`, suitable for
/// use as a global allocator.
pub struct GenericAllocator<G> {
    // Initialization state: 0 = uninitialized, 1 = in progress, 2 = ready
    // (see get_raw).
    init: AtomicU8,
    // The allocator itself; only valid once `init` reaches 2.
    raw: MaybeUninit<Mutex<RawAlloc<G>>>,
}
impl<G: HeapGrower + Default> Default for GenericAllocator<G> {
fn default() -> Self {
Self::new()
}
}
impl<G> GenericAllocator<G> {
    /// Create an uninitialized allocator. `const` so it can back a
    /// `#[global_allocator]` static; real setup happens on first use.
    pub const fn new() -> Self {
        Self {
            init: AtomicU8::new(0),
            raw: MaybeUninit::uninit(),
        }
    }
}
impl<G: HeapGrower + Default> GenericAllocator<G> {
pub unsafe fn get_raw(&self) -> MutexGuard<RawAlloc<G>> {
let mut state = self.init.compare_and_swap(0, 1, Ordering::SeqCst);
if state == 0 {
let raw_loc: *const Mutex<RawAlloc<G>> = self.raw.as_ptr();
let raw_mut: *mut Mutex<RawAlloc<G>> = raw_loc as *mut Mutex<RawAlloc<G>>;
raw_mut.write(Mutex::new(RawAlloc::default()));
let mx: &mut Mutex<RawAlloc<G>> = raw_mut.as_mut().unwrap();
self.init.store(2, Ordering::SeqCst);
return mx.lock();
}
while state == 1 {
core::sync::atomic::spin_loop_hint();
state = self.init.load(Ordering::SeqCst);
}
let ptr = self.raw.as_ptr().as_ref().unwrap();
ptr.lock()
}
pub fn stats(&self) -> (Validity, Stats) {
unsafe { self.get_raw().stats() }
}
}
/// The user-facing allocator: a `GenericAllocator` over whichever heap
/// grower the `use_libc` feature selects.
#[derive(Default)]
pub struct UnixAllocator {
    #[cfg(not(feature = "use_libc"))]
    alloc: GenericAllocator<SyscallHeapGrower>,
    #[cfg(feature = "use_libc")]
    alloc: GenericAllocator<LibcHeapGrower>,
}
impl UnixAllocator {
    /// Const constructor, suitable for a `#[global_allocator]` static.
    pub const fn new() -> Self {
        Self {
            alloc: GenericAllocator::new(),
        }
    }

    /// Validity / statistics snapshot of the underlying allocator.
    pub fn stats(&self) -> (Validity, Stats) {
        self.alloc.stats()
    }
}
// SAFETY: delegates to RawAlloc behind a (lazily initialized) spin mutex;
// the GlobalAlloc contract on `layout` and `ptr` is forwarded to callers.
unsafe impl GlobalAlloc for UnixAllocator {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        self.alloc.get_raw().alloc(layout)
    }
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        self.alloc.get_raw().dealloc(ptr, layout)
    }
}
/// A fixed-size, in-memory "heap" used to exercise the allocator in tests
/// without touching the OS.
pub struct ToyHeap {
    // Toy page size used for rounding (64 bytes by default).
    pub page_size: usize,
    // Bytes handed out so far (bump-allocator high-water mark).
    pub size: usize,
    // The backing storage itself.
    pub heap: [u8; 256 * 1024],
}

impl Default for ToyHeap {
    /// An empty 256 KiB heap with a 64-byte toy page size.
    fn default() -> Self {
        Self {
            page_size: 64,
            size: 0,
            heap: [0; 256 * 1024],
        }
    }
}
/// Error returned when the fixed-size `ToyHeap` runs out of space.
pub struct ToyHeapOverflowError();
impl HeapGrower for ToyHeap {
    type Err = ToyHeapOverflowError;
    /// Hand out the next `size` bytes (rounded up to the toy page size)
    /// from the fixed internal buffer, bump-allocator style.
    ///
    /// # Safety
    /// Returns a raw pointer into `self.heap`; the caller must not use it
    /// beyond the returned length or after the `ToyHeap` is dropped.
    unsafe fn grow_heap(&mut self, size: usize) -> Result<(*mut u8, usize), Self::Err> {
        let allocating = round_up(size, self.page_size);
        // Bug fix: check the *rounded* amount actually consumed. The
        // original compared `self.size + size`, so a request near the end
        // of the buffer could round up past `heap.len()` and hand out
        // memory beyond the backing array.
        if self.size + allocating > self.heap.len() {
            return Err(ToyHeapOverflowError());
        }
        let ptr = self.heap.as_mut_ptr().add(self.size);
        self.size += allocating;
        Ok((ptr, allocating))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use test_env_log::test;

    /// End-to-end exercise of RawAlloc over the in-memory ToyHeap:
    /// sequential allocation, deallocation, and reuse of freed blocks
    /// (including splitting a freed block across several smaller allocs).
    #[test]
    fn test_basic() {
        let toy_heap = ToyHeap::default();
        let mut allocator = RawAlloc::new(toy_heap);
        const BLOCKS: usize = 3;
        let layouts: [Layout; BLOCKS] = [
            Layout::from_size_align(64, 16).unwrap(),
            Layout::from_size_align(64, 16).unwrap(),
            Layout::from_size_align(224, 16).unwrap(),
        ];
        // Allocate three blocks; the free list must stay valid throughout.
        let pointers: [*mut u8; BLOCKS] = unsafe {
            let mut pointers = [null_mut(); BLOCKS];
            for (i, &l) in layouts.iter().enumerate() {
                pointers[i] = allocator.alloc(l);
                let (validity, _stats) = allocator.stats();
                assert!(validity.is_valid());
            }
            pointers
        };
        // Fresh allocations from a bump-style heap should be contiguous.
        for i in 0..BLOCKS - 1 {
            let l = layouts[i];
            let expected = unsafe { pointers[i].add(l.size()) };
            let found = pointers[i + 1];
            assert_eq!(expected, found);
        }
        // The toy heap should have grown by exactly the page-rounded total.
        let toy_heap = &allocator.grower;
        let page_size = toy_heap.page_size;
        let total_allocated: usize = layouts.iter().map(|l| l.size()).sum();
        let page_space = round_up(total_allocated, page_size);
        assert_eq!(toy_heap.size, page_space);
        // Free the middle block; it should appear on the free list.
        unsafe { allocator.dealloc(pointers[1], layouts[1]) };
        let (validity, _stats) = allocator.stats();
        assert!(validity.is_valid());
        let mut iter = allocator.blocks.iter();
        let first = iter.next();
        assert!(first.is_some());
        let first = first.expect("This should not be null");
        assert_eq!(first.size(), layouts[1].size());
        let next_exists = iter.next().is_some();
        log::info!("dealloc: {}", allocator.blocks);
        assert!(next_exists);
        log::info!("post-alloc: {}", allocator.blocks);
        unsafe {
            // Too big for the freed 64-byte block: must extend the heap
            // past the end of the last allocation.
            let newp = allocator.alloc(Layout::from_size_align(112, 16).unwrap());
            let (validity, _stats) = allocator.stats();
            assert!(validity.is_valid());
            assert_eq!(
                newp,
                pointers[2].add(round_up(layouts[2].size(), page_size))
            );
            log::info!("p112: {}", allocator.blocks);
            // These three smaller allocations should carve up the freed
            // 64-byte block from back to front.
            let p32 = allocator.alloc(Layout::from_size_align(32, 16).unwrap());
            let (validity, _stats) = allocator.stats();
            assert!(validity.is_valid());
            assert_eq!(p32, pointers[1].add(32));
            log::info!("p32: {}", allocator.blocks);
            let p8 = allocator.alloc(Layout::from_size_align(16, 4).unwrap());
            let (validity, _stats) = allocator.stats();
            assert!(validity.is_valid());
            log::info!("p8: {}", allocator.blocks);
            let p16 = allocator.alloc(Layout::from_size_align(8, 1).unwrap());
            let (validity, _stats) = allocator.stats();
            assert!(validity.is_valid());
            log::info!("p16: {}", allocator.blocks);
            assert_eq!(p8, pointers[1].add(16));
            assert_eq!(p16, pointers[1]);
            log::info!("done: {}", allocator.blocks);
        };
    }
}