use std::alloc::{alloc, dealloc, Layout};
use std::ptr::NonNull;
/// Object sizes, in bytes, served by the per-class free lists. Every entry is a power of two,
/// which `FreeList::grow` relies on to keep each carved object aligned to its class size.
const SIZE_CLASSES: [usize; 9] = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096];
/// Number of size classes (and therefore free lists) managed by a [`SlabAllocator`].
const NUM_SIZE_CLASSES: usize = SIZE_CLASSES.len();
/// Number of objects carved out of each page when a free list runs dry.
const OBJECTS_PER_PAGE: usize = 64;

/// A size-class slab allocator that hands out raw, uninitialised storage for single values.
///
/// A request for `T` is routed by `max(size_of::<T>(), align_of::<T>())`. If that fits one of
/// [`SIZE_CLASSES`], it is served from the matching free list. Otherwise it goes straight to
/// the global allocator.
///
/// The allocator owns every page it carves objects from and returns those pages to the global
/// allocator when it is dropped. Pointers obtained from [`SlabAllocator::alloc`] must therefore
/// not be used after the allocator is gone. Large (out-of-class) allocations are not tracked
/// and must be released explicitly with [`SlabAllocator::free`].
pub struct SlabAllocator {
    free_lists: [FreeList; NUM_SIZE_CLASSES],
}

/// An intrusive LIFO free list of equally sized objects, plus the pages that back them.
struct FreeList {
    /// Most recently pushed free object, or `None` when the list is empty.
    head: Option<NonNull<FreeNode>>,
    /// Size in bytes of every object on this list; also the alignment of its pages.
    size_class: usize,
    /// Objects currently sitting on the free list.
    free_count: usize,
    /// Objects currently handed out (saturates at zero on unmatched frees).
    alloc_count: usize,
    /// Base address of every page allocated by `grow`; released in `Drop`.
    pages: Vec<NonNull<u8>>,
}

/// Link written into the first bytes of each free object.
///
/// It is a single `Option<NonNull<_>>`, i.e. pointer-sized and pointer-aligned, so it fits
/// inside the smallest (16-byte) size class.
#[repr(C)]
struct FreeNode {
    next: Option<NonNull<FreeNode>>,
}

impl SlabAllocator {
    /// Creates an allocator with one empty free list per size class.
    ///
    /// No memory is reserved until the first allocation.
    pub fn new() -> Self {
        Self {
            free_lists: SIZE_CLASSES.map(FreeList::new),
        }
    }

    /// Returns the index of the smallest size class that can hold `size` bytes.
    ///
    /// Returns `None` when `size` exceeds the largest class.
    fn size_class_index(size: usize) -> Option<usize> {
        SIZE_CLASSES.iter().position(|&s| size <= s)
    }

    /// Returns the size-class index used for values of type `T`.
    ///
    /// Alignment is folded into the size because objects are only aligned to their class
    /// size. `alloc` and `free` must route a given `T` identically, so both go through here.
    fn class_index_for<T>() -> Option<usize> {
        Self::size_class_index(std::mem::size_of::<T>().max(std::mem::align_of::<T>()))
    }

    /// Returns uninitialised storage for one `T`, aligned for `T`.
    ///
    /// Returns a null pointer if the global allocator is out of memory. The pointer must be
    /// released with [`SlabAllocator::free`] for the same `T` on this allocator. It must not be
    /// used after the allocator is dropped.
    pub fn alloc<T>(&mut self) -> *mut T {
        match Self::class_index_for::<T>() {
            Some(idx) => self.free_lists[idx].alloc().cast::<T>(),
            None => self.alloc_large::<T>(),
        }
    }

    /// Releases storage previously returned by [`SlabAllocator::alloc`] for the same `T`.
    ///
    /// Null pointers are ignored. The pointee is not dropped; only the memory is reclaimed.
    /// Passing any other pointer corrupts the free list or the global heap, even though this
    /// method is not marked `unsafe`.
    pub fn free<T>(&mut self, ptr: *mut T) {
        if ptr.is_null() {
            return;
        }
        match Self::class_index_for::<T>() {
            Some(idx) => self.free_lists[idx].free(ptr.cast::<u8>()),
            None => self.free_large(ptr),
        }
    }

    /// Allocates a `T` that is too large or too aligned for any size class.
    fn alloc_large<T>(&self) -> *mut T {
        let layout = Layout::new::<T>();
        if layout.size() == 0 {
            // Calling `alloc` with a zero-sized layout is undefined behaviour. A zero-sized
            // type with alignment above the largest class lands here, so hand out a dangling
            // pointer that is correctly aligned for `T` instead.
            return NonNull::<T>::dangling().as_ptr();
        }
        // SAFETY: `layout` has non-zero size, as checked above.
        unsafe { alloc(layout).cast::<T>() }
    }

    /// Releases a pointer obtained from `alloc_large`.
    fn free_large<T>(&self, ptr: *mut T) {
        let layout = Layout::new::<T>();
        if layout.size() == 0 {
            // Mirrors `alloc_large`: zero-sized values never came from the global allocator.
            return;
        }
        // SAFETY: `ptr` was returned by `alloc(layout)` in `alloc_large` for this same `T`.
        unsafe { dealloc(ptr.cast::<u8>(), layout) };
    }

    /// Returns object counts summed over all size classes.
    ///
    /// Large (out-of-class) allocations are not counted.
    pub fn stats(&self) -> SlabStats {
        let free_count = self.free_lists.iter().map(|list| list.free_count).sum();
        let alloc_count = self.free_lists.iter().map(|list| list.alloc_count).sum();
        SlabStats {
            free_count,
            alloc_count,
        }
    }
}

impl Default for SlabAllocator {
    fn default() -> Self {
        Self::new()
    }
}

impl FreeList {
    /// Creates an empty list for objects of `size_class` bytes; no pages are allocated yet.
    const fn new(size_class: usize) -> Self {
        Self {
            head: None,
            size_class,
            free_count: 0,
            alloc_count: 0,
            pages: Vec::new(),
        }
    }

    /// Returns the layout of one backing page: `OBJECTS_PER_PAGE` objects aligned to the class
    /// size.
    ///
    /// Aligning the page this way also aligns every object carved from it.
    fn page_layout(&self) -> Layout {
        Layout::from_size_align(self.size_class * OBJECTS_PER_PAGE, self.size_class)
            .expect("size classes are powers of two and pages are far below isize::MAX")
    }

    /// Pops a free object, growing by one page first if the list is empty.
    ///
    /// Returns null if the global allocator cannot supply a new page.
    fn alloc(&mut self) -> *mut u8 {
        if self.head.is_none() {
            self.grow();
        }
        self.pop().map_or(std::ptr::null_mut(), NonNull::as_ptr)
    }

    /// Detaches the head object and records it as handed out.
    fn pop(&mut self) -> Option<NonNull<u8>> {
        let node = self.head?;
        // SAFETY: every node on the list was initialised by `push` and lives in a page owned
        // by this list, which stays allocated until `drop`.
        self.head = unsafe { node.as_ref().next };
        self.free_count -= 1;
        self.alloc_count += 1;
        Some(node.cast::<u8>())
    }

    /// Returns an object to the list. Null pointers are ignored.
    fn free(&mut self, ptr: *mut u8) {
        if let Some(obj) = NonNull::new(ptr) {
            self.push(obj);
            self.alloc_count = self.alloc_count.saturating_sub(1);
        }
    }

    /// Links `obj` in as the new head and counts it as free.
    fn push(&mut self, obj: NonNull<u8>) {
        let node = obj.cast::<FreeNode>();
        // SAFETY: `obj` points to `size_class` (>= 16) writable bytes aligned to `size_class`,
        // which is enough room and alignment for a `FreeNode`.
        unsafe { node.as_ptr().write(FreeNode { next: self.head }) };
        self.head = Some(node);
        self.free_count += 1;
    }

    /// Allocates one page and pushes each of its objects onto the list.
    ///
    /// If the allocation fails, the list is left unchanged.
    fn grow(&mut self) {
        let layout = self.page_layout();
        // SAFETY: `layout` has non-zero size (size_class and OBJECTS_PER_PAGE are non-zero).
        let page = match NonNull::new(unsafe { alloc(layout) }) {
            Some(page) => page,
            None => return,
        };
        // Record ownership before carving so `Drop` can always release the page.
        self.pages.push(page);
        for i in 0..OBJECTS_PER_PAGE {
            // SAFETY: `i * size_class` is below the page size, so the offset stays inside the
            // allocation and the resulting pointer is non-null.
            let obj = unsafe { NonNull::new_unchecked(page.as_ptr().add(i * self.size_class)) };
            self.push(obj);
        }
    }
}

impl Drop for FreeList {
    /// Returns every page obtained by `grow` to the global allocator.
    fn drop(&mut self) {
        let layout = self.page_layout();
        for page in self.pages.drain(..) {
            // SAFETY: each page came from `alloc(layout)` in `grow` with this same layout
            // (`size_class` never changes). It is removed from `pages` before being freed, so
            // it is deallocated exactly once.
            unsafe { dealloc(page.as_ptr(), layout) };
        }
    }
}

/// Aggregate object counts reported by [`SlabAllocator::stats`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlabStats {
    /// Objects currently on the size-class free lists, ready to be handed out.
    pub free_count: usize,
    /// Objects currently handed out by the size-class free lists.
    pub alloc_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two live allocations keep independent contents and can both be released.
    #[test]
    fn test_slab_alloc_free() {
        let mut slab = SlabAllocator::new();

        let small: *mut u32 = slab.alloc();
        assert!(!small.is_null());
        unsafe { small.write(42) };

        let wide: *mut u64 = slab.alloc();
        assert!(!wide.is_null());
        unsafe { wide.write(123) };

        let (small_value, wide_value) = unsafe { (small.read(), wide.read()) };
        assert_eq!(small_value, 42);
        assert_eq!(wide_value, 123);

        slab.free(small);
        slab.free(wide);
    }

    /// A freed slot is the next one handed out for the same type.
    #[test]
    fn test_slab_reuse() {
        let mut slab = SlabAllocator::new();
        let first = slab.alloc::<u32>();
        slab.free(first);
        let second = slab.alloc::<u32>();
        assert_eq!(first, second);
    }

    /// Types larger than the biggest size class still get a usable allocation.
    #[test]
    fn test_slab_large() {
        type Page = [u8; 8192];
        let mut slab = SlabAllocator::new();
        let page = slab.alloc::<Page>();
        assert!(!page.is_null());
        slab.free(page);
    }
}