use std::{
alloc::{AllocError, Allocator, GlobalAlloc, Layout},
ptr::{self, NonNull},
sync::Mutex,
};
use crate::{bucket::Bucket, realloc::Realloc, AllocResult};
/// Size-class allocator: `N` fixed buckets, each serving requests up to the
/// corresponding entry of `sizes`, plus one dynamic bucket for everything
/// larger. Not thread-safe on its own — callers (see `Rulloc`) must
/// serialize access.
struct InternalAllocator<const N: usize> {
    /// Upper size bound in bytes for the bucket at the same index.
    /// Assumed sorted ascending — routing takes the first class that fits;
    /// TODO confirm callers uphold this.
    sizes: [usize; N],
    /// One bucket per size class in `sizes`.
    buckets: [Bucket; N],
    /// Fallback bucket for requests larger than every entry of `sizes`.
    dyn_bucket: Bucket,
}
impl<const N: usize> InternalAllocator<N> {
    /// Builds an allocator with one fixed bucket per entry of `sizes` and an
    /// empty dynamic fallback bucket.
    pub const fn with_bucket_sizes(sizes: [usize; N]) -> Self {
        // `[expr; N]` array repetition needs `Copy`; repeating a `const` item
        // does not, so this works for non-`Copy` `Bucket`.
        const BUCKET: Bucket = Bucket::new();
        Self {
            sizes,
            buckets: [BUCKET; N],
            dyn_bucket: Bucket::new(),
        }
    }

    /// Returns the index of the first size class that can hold `layout`, or
    /// `N` — one past the fixed buckets — to denote the dynamic bucket.
    fn bucket_index_of(&self, layout: Layout) -> usize {
        let requested = layout.size();
        self.sizes
            .iter()
            .position(|&class| requested <= class)
            .unwrap_or(N)
    }

    /// Resolves an index produced by [`Self::bucket_index_of`] to a bucket;
    /// index `N` maps to the dynamic bucket.
    fn bucket_mut(&mut self, index: usize) -> &mut Bucket {
        if index != self.buckets.len() {
            &mut self.buckets[index]
        } else {
            &mut self.dyn_bucket
        }
    }

    /// Picks the bucket responsible for `layout`.
    #[inline]
    fn dispatch(&mut self, layout: Layout) -> &mut Bucket {
        self.bucket_mut(self.bucket_index_of(layout))
    }

    /// Allocates `layout` from the bucket matching its size class.
    ///
    /// # Safety
    /// Same contract as the underlying [`Bucket::allocate`].
    #[inline]
    pub unsafe fn allocate(&mut self, layout: Layout) -> AllocResult {
        self.dispatch(layout).allocate(layout)
    }

    /// Returns a block to the bucket that owns its size class.
    ///
    /// # Safety
    /// `address` must have been obtained from [`Self::allocate`] with the
    /// same `layout`.
    #[inline]
    pub unsafe fn deallocate(&mut self, address: NonNull<u8>, layout: Layout) {
        self.dispatch(layout).deallocate(address, layout)
    }

    /// Reallocates per `realloc`, moving the block to a different bucket
    /// when the new layout falls into a different size class.
    ///
    /// # Safety
    /// `realloc` must describe a live block previously allocated here.
    pub unsafe fn reallocate(&mut self, realloc: &Realloc) -> AllocResult {
        let source = self.bucket_index_of(realloc.old_layout);
        let target = self.bucket_index_of(realloc.new_layout);

        // Same size class: the owning bucket resizes without a move between
        // buckets.
        if source == target {
            return self.bucket_mut(source).reallocate(realloc);
        }

        // Crossing size classes: allocate in the target bucket first, copy
        // the payload, and only then release the old block — this order keeps
        // the data reachable if the new allocation fails.
        let moved = self.bucket_mut(target).allocate(realloc.new_layout)?;
        ptr::copy_nonoverlapping(
            realloc.address.as_ptr(),
            moved.as_mut_ptr(),
            realloc.count(),
        );
        self.bucket_mut(source)
            .deallocate(realloc.address, realloc.old_layout);

        Ok(moved)
    }
}
/// Public allocator front end. All access to the bucket allocator is
/// serialized through a [`Mutex`]; `N` is the number of fixed size-class
/// buckets (3 by default, see `with_default_config`).
pub struct Rulloc<const N: usize = 3> {
    allocator: Mutex<InternalAllocator<N>>,
}
// SAFETY: every mutation of the inner `InternalAllocator` goes through the
// `Mutex`, so sharing `&Rulloc` across threads is sound. NOTE(review): the
// manual impl is presumably needed because `Bucket` contains raw pointers
// that suppress the auto trait — confirm against `Bucket`'s definition.
unsafe impl<const N: usize> Sync for Rulloc<N> {}
impl Rulloc {
pub const fn with_default_config() -> Self {
Self {
allocator: Mutex::new(InternalAllocator::with_bucket_sizes([128, 1024, 8192])),
}
}
}
impl<const N: usize> Rulloc<N> {
pub fn with_bucket_sizes(sizes: [usize; N]) -> Self {
Self {
allocator: Mutex::new(InternalAllocator::with_bucket_sizes(sizes)),
}
}
}
impl Default for Rulloc {
fn default() -> Self {
Rulloc::with_default_config()
}
}
unsafe impl<const N: usize> Allocator for Rulloc<N> {
fn allocate(&self, layout: Layout) -> AllocResult {
unsafe {
match self.allocator.lock() {
Ok(mut allocator) => allocator.allocate(layout),
Err(_) => Err(AllocError),
}
}
}
unsafe fn deallocate(&self, address: NonNull<u8>, layout: Layout) {
if let Ok(mut allocator) = self.allocator.lock() {
allocator.deallocate(address, layout)
}
}
unsafe fn shrink(
&self,
address: NonNull<u8>,
old_layout: Layout,
new_layout: Layout,
) -> AllocResult {
match self.allocator.lock() {
Ok(mut allocator) => {
allocator.reallocate(&Realloc::shrink(address, old_layout, new_layout))
}
Err(_) => Err(AllocError),
}
}
unsafe fn grow(
&self,
address: NonNull<u8>,
old_layout: Layout,
new_layout: Layout,
) -> AllocResult {
match self.allocator.lock() {
Ok(mut allocator) => {
allocator.reallocate(&Realloc::grow(address, old_layout, new_layout))
}
Err(_) => Err(AllocError),
}
}
unsafe fn grow_zeroed(
&self,
address: NonNull<u8>,
old_layout: Layout,
new_layout: Layout,
) -> AllocResult {
let new_address = self.grow(address, old_layout, new_layout)?;
let zero_from = new_address
.as_mut_ptr()
.map_addr(|addr| addr + old_layout.size());
zero_from.write_bytes(0, new_layout.size() - old_layout.size());
Ok(new_address)
}
}
unsafe impl GlobalAlloc for Rulloc {
    /// [`GlobalAlloc`] front end over [`Allocator::allocate`]; per the trait
    /// contract, failure is reported as a null pointer rather than an error.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        match self.allocate(layout) {
            Ok(address) => address.cast().as_ptr(),
            Err(_) => ptr::null_mut(),
        }
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // SAFETY: the `GlobalAlloc` contract guarantees `ptr` was returned by
        // `alloc` with this allocator and is therefore non-null.
        self.deallocate(NonNull::new_unchecked(ptr), layout)
    }

    unsafe fn realloc(&self, address: *mut u8, old_layout: Layout, new_size: usize) -> *mut u8 {
        // The `GlobalAlloc` contract requires `new_size` to be non-zero and
        // not overflow when rounded up to `old_layout.align()`, so this
        // layout construction cannot fail for conforming callers.
        let new_layout = Layout::from_size_align(new_size, old_layout.align()).unwrap();
        let address = NonNull::new_unchecked(address);
        // BUGFIX: the comparison was inverted (`old <= new` dispatched to
        // `shrink`), sending growing requests down the shrink path and
        // vice versa. Shrink only when the size does not increase.
        let result = if new_size <= old_layout.size() {
            self.shrink(address, old_layout, new_layout)
        } else {
            self.grow(address, old_layout, new_layout)
        };
        match result {
            Ok(new_address) => new_address.as_mut_ptr(),
            Err(_) => ptr::null_mut(),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync,
        thread::{self, ThreadId},
    };
    use super::*;
    use crate::platform::PAGE_SIZE;

    /// Smoke test of the public [`Allocator`] front end: two simultaneous
    /// allocations of very different sizes must keep their contents intact.
    #[test]
    fn internal_allocator_wrapper() {
        let allocator = Rulloc::with_default_config();
        unsafe {
            // Small request — served by a fixed bucket.
            let layout1 = Layout::array::<u8>(8).unwrap();
            let mut addr1 = allocator.allocate(layout1).unwrap();
            addr1.as_mut().fill(69);
            // Much larger request; presumably lands in a different bucket —
            // TODO confirm against the default size classes.
            let layout2 = Layout::array::<u8>(PAGE_SIZE * 2).unwrap();
            let mut addr2 = allocator.allocate(layout2).unwrap();
            addr2.as_mut().fill(42);
            // The second allocation must not have clobbered the first.
            for value in addr1.as_ref() {
                assert_eq!(value, &69);
            }
            allocator.deallocate(addr1.cast(), layout1);
            // And the first block's deallocation must not disturb the second.
            for value in addr2.as_ref() {
                assert_eq!(value, &42);
            }
            allocator.deallocate(addr2.cast(), layout2);
        }
    }

    /// White-box test of `InternalAllocator`: allocations land in the bucket
    /// whose size class fits them, oversized requests hit the dynamic bucket,
    /// and cross-bucket reallocation moves the region and preserves data.
    #[test]
    fn buckets() {
        unsafe {
            let sizes = [8, 16, 24];
            let mut allocator = InternalAllocator::<3>::with_bucket_sizes(sizes);
            // Asserts the region count currently owned by each fixed bucket.
            macro_rules! verify_number_of_regions_per_bucket {
                ($expected:expr) => {
                    for i in 0..sizes.len() {
                        assert_eq!(allocator.buckets[i].regions().len(), $expected[i]);
                    }
                };
            }
            // Fill the three fixed buckets one at a time.
            let layout1 = Layout::array::<u8>(sizes[0]).unwrap();
            let addr1 = allocator.allocate(layout1).unwrap().cast();
            verify_number_of_regions_per_bucket!([1, 0, 0]);
            let layout2 = Layout::array::<u8>(sizes[1]).unwrap();
            let addr2 = allocator.allocate(layout2).unwrap().cast();
            verify_number_of_regions_per_bucket!([1, 1, 0]);
            let layout3 = Layout::array::<u8>(sizes[2]).unwrap();
            let addr3 = allocator.allocate(layout3).unwrap().cast();
            verify_number_of_regions_per_bucket!([1, 1, 1]);
            // Deallocations drain the matching bucket only.
            allocator.deallocate(addr1, layout1);
            verify_number_of_regions_per_bucket!([0, 1, 1]);
            allocator.deallocate(addr2, layout2);
            verify_number_of_regions_per_bucket!([0, 0, 1]);
            allocator.deallocate(addr3, layout3);
            verify_number_of_regions_per_bucket!([0, 0, 0]);
            // Larger than every size class — must go to the dynamic bucket.
            let layout4 = Layout::array::<u8>(sizes[2] + 128).unwrap();
            let addr4 = allocator.allocate(layout4).unwrap().cast();
            verify_number_of_regions_per_bucket!([0, 0, 0]);
            assert_eq!(allocator.dyn_bucket.regions().len(), 1);
            allocator.deallocate(addr4, layout4);
            assert_eq!(allocator.dyn_bucket.regions().len(), 0);
            // Growing reallocation should migrate the region across buckets.
            let mut realloc_addr = allocator.allocate(layout1).unwrap();
            let corruption_check = 213;
            realloc_addr.as_mut().fill(corruption_check);
            realloc_addr = allocator
                .reallocate(&Realloc::grow(realloc_addr.cast(), layout1, layout2))
                .unwrap();
            verify_number_of_regions_per_bucket!([0, 1, 0]);
            realloc_addr = allocator
                .reallocate(&Realloc::grow(realloc_addr.cast(), layout2, layout3))
                .unwrap();
            verify_number_of_regions_per_bucket!([0, 0, 1]);
            // The original payload must survive both migrations.
            for value in &realloc_addr.as_ref()[..layout1.size()] {
                assert_eq!(*value, corruption_check);
            }
        }
    }

    /// Asserts that every bucket (fixed and dynamic) has released all of its
    /// regions — used as a leak check after the threaded tests.
    fn verify_buckets_are_empty(allocator: Rulloc) {
        let internal = allocator.allocator.lock().unwrap();
        for bucket in &internal.buckets {
            assert_eq!(bucket.regions().len(), 0);
        }
        assert_eq!(internal.dyn_bucket.regions().len(), 0);
    }

    /// Each thread allocates, writes its own `ThreadId`, then all threads
    /// rendezvous before reading back — so every buffer is checked while all
    /// other threads' allocations are still live.
    #[test]
    fn multiple_threads_synchronized_allocs_and_deallocs() {
        let allocator = Rulloc::with_default_config();
        let num_threads = 8;
        let barrier = sync::Barrier::new(num_threads);
        thread::scope(|scope| {
            for _ in 0..num_threads {
                scope.spawn(|| unsafe {
                    let num_elements = 1024;
                    let layout = Layout::array::<ThreadId>(num_elements).unwrap();
                    let addr = allocator.allocate(layout).unwrap().cast::<ThreadId>();
                    // Tag the whole buffer with this thread's identity.
                    let id = thread::current().id();
                    for i in 0..num_elements {
                        *addr.as_ptr().add(i) = id;
                    }
                    // Wait until every thread has written its buffer.
                    barrier.wait();
                    // No other thread's writes may have leaked into ours.
                    for i in 0..num_elements {
                        assert_eq!(*addr.as_ptr().add(i), id);
                    }
                    allocator.deallocate(addr.cast(), layout);
                });
            }
        });
        verify_buckets_are_empty(allocator);
    }

    /// Stress test: threads hammer alloc/write/read/dealloc cycles over a
    /// range of sizes with no per-iteration synchronization (only a barrier
    /// between size rounds), looking for cross-thread corruption.
    #[test]
    fn multiple_threads_unsynchronized_allocs_and_deallocs() {
        let allocator = Rulloc::with_default_config();
        let num_threads = 8;
        let barrier = sync::Barrier::new(num_threads);
        thread::scope(|scope| {
            for _ in 0..num_threads {
                scope.spawn(|| unsafe {
                    let layouts = [16, 256, 1024, 2048, 4096, 8192]
                        .map(|size| Layout::array::<u8>(size).unwrap());
                    // Miri is slow — shrink the iteration count under it.
                    let num_allocs = if cfg!(miri) { 20 } else { 1000 };
                    for layout in layouts {
                        barrier.wait();
                        for _ in 0..num_allocs {
                            let addr = allocator.allocate(layout).unwrap().cast::<u8>();
                            if cfg!(miri) {
                                // Under Miri, spot-check first/middle/last
                                // bytes instead of the full buffer.
                                let offsets = [0, layout.size() / 2, layout.size() - 1];
                                let values = [1, 5, 10];
                                for (offset, value) in offsets.iter().zip(values) {
                                    *addr.as_ptr().add(*offset) = value;
                                }
                                for (offset, value) in offsets.iter().zip(values) {
                                    assert_eq!(*addr.as_ptr().add(*offset), value);
                                }
                            } else {
                                // Full write/read-back of a position-derived
                                // pattern.
                                for i in 0..layout.size() {
                                    *addr.as_ptr().add(i) = (i % 256) as u8;
                                }
                                for i in 0..layout.size() {
                                    assert_eq!(*addr.as_ptr().add(i), (i % 256) as u8);
                                }
                            }
                            allocator.deallocate(addr, layout);
                        }
                    }
                });
            }
        });
        verify_buckets_are_empty(allocator);
    }
}