use crate::get_field_ptr;
use crate::locks::LockTrait;
use crate::prelude::*;
use crate::slots::*;
use core::alloc::Layout;
use core::cell::UnsafeCell;
use core::ptr::NonNull;
use intrusive_doubly_list::DoublyLinkPointerRaw;
use intrusive_doubly_list::IntrusiveDLinkListRaw;
/// Magic value stored in every live slab header; it is overwritten with 0
/// when the slab is released so stale pointers fail `validate_slab`.
const MAGIC_HEADER: u64 = 0xFB4E_4C07_ACCE_45A2;
/// Per-slab free routine stored in the header and invoked by
/// `Slab::free_slot` (`free_master` for master slabs, `free_slave` for
/// slave slabs).
type FreeFunc<T, LOCK> =
    fn(slab_ptr: NonNull<Slab<T, LOCK>>, slot_ptr: NonNull<T>, slab_lay: Layout) -> Result<()>;
/// Interior-mutable forward/backward link used by the intrusive slab list.
type SlabPtr<T, LOCK> = UnsafeCell<Option<NonNull<Slab<T, LOCK>>>>;
/// Header placed at the start of every slab buffer; the slot area follows
/// immediately after it.
///
/// `#[repr(C)]` is required: `get_field_ptr!` computes raw field offsets,
/// so fields must not be reordered without auditing every call site.
///
/// A "master" slab owns the lock and the list of "slave" slabs; each slave
/// records its master in `master_ptr` (a master records itself).
#[repr(C)]
pub struct Slab<T, LOCK>
where
    LOCK: LockTrait,
{
    /// Equals `MAGIC_HEADER` while the slab is live, zero once released.
    magic: UnsafeCell<u64>,
    /// Serializes all bookkeeping; slaves lock through their master's lock.
    lock: LOCK,
    /// On a master: head of the list of slabs with free slots.
    slave_list: IntrusiveDLinkListRaw<Slab<T, LOCK>>,
    /// Back-pointer to the owning master; a master points to itself.
    master_ptr: Option<NonNull<Slab<T, LOCK>>>,
    /// Intrusive forward link (see the `DoublyLinkPointerRaw` impl below).
    next_slab: SlabPtr<T, LOCK>,
    /// Intrusive backward link (see the `DoublyLinkPointerRaw` impl below).
    last_slab: SlabPtr<T, LOCK>,
    /// How a slot of this slab is returned (`free_master` / `free_slave`).
    free_behaviour: FreeFunc<T, LOCK>,
    /// Manager for the slot area that follows this header.
    slots: UnsafeCell<Slots<T>>,
    /// References keeping this slab alive (incremented per slave created
    /// against it in `alloc_slab_ptr`, and via `atomic_ref_up`).
    ref_count: UnsafeCell<u32>,
    /// Whether this node currently sits on a slave list.
    is_linked: UnsafeCell<bool>,
}
/// Returns `slot_ptr` to a master slab (installed as the master's
/// `free_behaviour`).
///
/// Takes the master's lock for the bookkeeping phase. If pushing the slot
/// back leaves the master fully free and unreferenced, the magic is zeroed
/// and the whole buffer is released; otherwise the master is queued on the
/// slave list so `alloc_slot` can reuse the freed slot.
pub fn free_master<T, LOCK>(
    master_ptr: NonNull<Slab<T, LOCK>>,
    slot_ptr: NonNull<T>,
    slab_lay: Layout,
) -> Result<()>
where
    LOCK: LockTrait,
{
    let lock_ptr = get_field_ptr!(master_ptr, lock).as_ptr();
    // SAFETY-NOTE(review): relies on the caller (`Slab::free_slot`) having
    // validated the magic, so `master_ptr` points at a live header.
    let _guard = unsafe { (*lock_ptr).lock() };
    let slots_ptr = get_field_ptr!(master_ptr, slots);
    let slots_ = slots_ptr.unsafe_ref().get();
    unsafe {
        (*slots_).try_push_slot(slot_ptr)?;
    }
    let can_drop = Slab::<T, LOCK>::can_drop(master_ptr);
    if can_drop {
        // Invalidate the header first so racing lookups fail validation,
        // and release the lock before its own storage is freed.
        Slab::<T, LOCK>::write_magic(master_ptr, 0);
        drop(_guard);
        free_slab(master_ptr.cast(), slab_lay)?;
        Ok(())
    } else {
        // Still in use: (re)queue the master so its freed slot is findable.
        // NOTE(review): this pushes unconditionally — assumes `push` is a
        // no-op / safe for an already-linked node (via `is_linked`);
        // confirm in `IntrusiveDLinkListRaw`.
        let slave_list_ptr = get_field_ptr!(master_ptr, slave_list);
        IntrusiveDLinkListRaw::<Slab<T, LOCK>>::push(slave_list_ptr, master_ptr);
        Ok(())
    }
}
/// Returns `slot_ptr` to a slave slab (installed as a slave's
/// `free_behaviour`).
///
/// Locks the slave's master, pushes the slot back, and tears the slave down
/// when it becomes empty — which drops one reference on the master and may
/// cascade into freeing the master as well.
pub fn free_slave<T, LOCK>(
    slab_ptr: NonNull<Slab<T, LOCK>>,
    slot_ptr: NonNull<T>,
    slab_lay: Layout,
) -> Result<()>
where
    LOCK: LockTrait,
{
    let master_ptr = Slab::<T, LOCK>::try_get_master_ptr(slab_ptr)?;
    Slab::<T, LOCK>::validate_slab(master_ptr)?;
    // All mutation below happens under the master's lock.
    let lock_ptr = get_field_ptr!(master_ptr, lock).as_ptr();
    let _guard = unsafe { (*lock_ptr).lock() };
    let slave_list_ptr = get_field_ptr!(master_ptr, slave_list);
    let slots_ptr = get_field_ptr!(slab_ptr, slots);
    let slots_ = slots_ptr.unsafe_ref().get();
    unsafe {
        (*slots_).try_push_slot(slot_ptr)?;
    }
    let can_drop = Slab::<T, LOCK>::can_drop(slab_ptr);
    if can_drop {
        // The slave is empty: drop its reference on the master, unlink it,
        // and check whether the master can now go away too.
        Slab::<T, LOCK>::ref_down(master_ptr);
        IntrusiveDLinkListRaw::<Slab<T, LOCK>>::remove(slave_list_ptr, slab_ptr);
        let can_drop_master = Slab::<T, LOCK>::can_drop(master_ptr);
        // Invalidate headers before unlocking so stale pointers fail
        // validation; the lock must be released before the master's own
        // storage is freed.
        Slab::<T, LOCK>::write_magic(slab_ptr, 0);
        if can_drop_master {
            Slab::<T, LOCK>::write_magic(master_ptr, 0);
            drop(_guard);
            free_slab(slab_ptr.cast(), slab_lay)?;
            free_slab(master_ptr.cast(), slab_lay)?;
        } else {
            drop(_guard);
            free_slab(slab_ptr.cast(), slab_lay)?;
        }
        Ok(())
    } else {
        // Still holds live slots: (re)queue it so the freed slot is reused.
        // NOTE(review): unconditional push — assumes `push` tolerates an
        // already-linked node (via `is_linked`).
        IntrusiveDLinkListRaw::<Slab<T, LOCK>>::push(slave_list_ptr, slab_ptr);
        Ok(())
    }
}
impl<T, LOCK> Slab<T, LOCK>
where
LOCK: LockTrait,
{
const SLAB_ALIGNMENT: usize = align_of::<Slab<T, LOCK>>();
const SLAB_HEADE_SIZE: usize = size_of::<Slab<T, LOCK>>();
const _CHECK: u8 = const {
let lock_size = size_of::<LOCK>();
assert!(
lock_size <= size_of::<u128>(),
"The Lock Size Should Be Smaller Than 16 Bytes"
);
0
};
fn can_drop(slab_ptr: NonNull<Slab<T, LOCK>>) -> bool {
let slots_ptr = get_field_ptr!(slab_ptr, slots);
let slots_ = slots_ptr.unsafe_ref().get();
let state = unsafe { (*slots_).get_state() };
let ref_count_ptr = get_field_ptr!(slab_ptr, ref_count);
let ref_count = unsafe { *ref_count_ptr.unsafe_ref().get() };
ref_count == 0 && state == SlotsState::Free
}
pub fn alloc_slot(
slab_ptr_master: NonNull<Self>,
slab_memory_layout: Layout,
) -> Result<NonNull<T>> {
Self::validate_slab(slab_ptr_master)?;
let lock_ptr = get_field_ptr!(slab_ptr_master, lock);
let _guard = lock_ptr.unsafe_ref().lock();
let master_ptr = Self::try_get_master_ptr(slab_ptr_master)?;
(master_ptr == slab_ptr_master).on_err(SlabError::FatalError)?;
let slave_list_ptr = get_field_ptr!(slab_ptr_master, slave_list);
while let Some(slab_ptr) = IntrusiveDLinkListRaw::<Self>::pop(slave_list_ptr) {
let slots_ptr = get_field_ptr!(slab_ptr, slots);
let slots_ = slots_ptr.unsafe_ref().get();
match unsafe { (*slots_).try_pop_slot() } {
Ok(slot) => {
let state = unsafe { (*slots_).get_state() };
if state == SlotsState::Partial {
IntrusiveDLinkListRaw::<Self>::push(slave_list_ptr, slab_ptr);
}
return Ok(slot);
}
Err(err) => match err {
SlabError::OutOfMemory => {
let state = unsafe { (*slots_).get_state() };
if state == SlotsState::Full {
IntrusiveDLinkListRaw::<Self>::push(slave_list_ptr, slab_ptr);
} else {
return Err(SlabError::FatalError);
}
}
_ => {
return Err(err);
}
},
}
}
let slave_slab_ptr =
Self::alloc_slab_ptr(slab_memory_layout, Some(slab_ptr_master), free_slave)?;
let slots_ptr = get_field_ptr!(slave_slab_ptr, slots);
let slots_ = slots_ptr.unsafe_ref().get();
let ptr = unsafe {
(*slots_)
.try_pop_slot()
.map_err(|_| SlabError::FatalError)?
};
let state = unsafe { (*slots_).get_state() };
if state == SlotsState::Full {
IntrusiveDLinkListRaw::<Self>::remove(slave_list_ptr, slave_slab_ptr);
}
Ok(ptr.cast())
}
pub fn free_slot(slot_ptr: NonNull<T>, slab_memory_layout: Layout) -> Result<()> {
let address = slot_ptr.as_address();
let slab_address = address.align_down(slab_memory_layout.size());
let slab_ptr = NonNull::<Slab<T, LOCK>>::from_address(slab_address)
.ok_or(SlabError::InvalidPointer)?;
Self::validate_slab(slab_ptr)?;
let free_behaviour_ptr = get_field_ptr!(slab_ptr, free_behaviour);
let free_behaviour = unsafe { free_behaviour_ptr.read() };
(free_behaviour)(slab_ptr, slot_ptr, slab_memory_layout)
}
pub fn alloc_slab_ptr(
slab_layout: Layout,
master_ptr: Option<NonNull<Slab<T, LOCK>>>,
free_behaviour: FreeFunc<T, LOCK>,
) -> Result<NonNull<Slab<T, LOCK>>> {
let _ = Self::_CHECK;
let buffer_ptr = alloc_slab(slab_layout).ok_or(SlabError::OutOfMemory)?;
let buffer_size = slab_layout.size();
buffer_ptr
.is_aligned_on_pow2(Self::SLAB_ALIGNMENT)
.on_err(SlabError::AlignmentMismatch)?;
(Self::SLAB_HEADE_SIZE < buffer_size).on_err(SlabError::OutOfMemory)?;
let slot_area_ptr = buffer_ptr.unsafe_unsafe_add(Self::SLAB_HEADE_SIZE);
let slot_area_size = buffer_size - Self::SLAB_HEADE_SIZE;
let slots = Slots::<T>::new(slot_area_ptr, slot_area_size)?;
let slab_ptr = buffer_ptr.cast::<Slab<T, LOCK>>();
let master_fix_ptr = match master_ptr {
Some(ptr) => {
let ref_c_ptr = get_field_ptr!(ptr, ref_count);
unsafe { *ref_c_ptr.unsafe_ref().get() += 1 }
ptr
}
None => slab_ptr,
};
let slab = Self {
magic: UnsafeCell::new(MAGIC_HEADER),
slots: UnsafeCell::new(slots),
ref_count: UnsafeCell::new(0),
slave_list: IntrusiveDLinkListRaw::new(),
master_ptr: Some(master_fix_ptr),
next_slab: UnsafeCell::new(None),
last_slab: UnsafeCell::new(None),
is_linked: UnsafeCell::new(false),
lock: LOCK::init(),
free_behaviour,
};
unsafe { slab_ptr.write(slab) };
IntrusiveDLinkListRaw::init_node(slab_ptr);
let slave_list_ptr = get_field_ptr!(master_fix_ptr, slave_list);
IntrusiveDLinkListRaw::<Self>::push(slave_list_ptr, slab_ptr);
Ok(slab_ptr)
}
#[inline]
pub fn validate_slab(slab_ptr: NonNull<Slab<T, LOCK>>) -> Result<()> {
Self::check_magic_raw(slab_ptr)?;
let master_field_ptr = get_field_ptr!(slab_ptr, master_ptr);
let master_ptr_raw = unsafe { master_field_ptr.read() };
let Some(master_ptr) = master_ptr_raw else {
return Ok(());
};
Self::check_magic_raw(master_ptr)
}
#[inline]
fn check_magic_raw(slab_ptr: NonNull<Slab<T, LOCK>>) -> Result<()> {
let magic_ptr = get_field_ptr!(slab_ptr, magic);
let magic = unsafe { *magic_ptr.unsafe_ref().get() };
(magic == MAGIC_HEADER).as_result((), SlabError::InvalidPointer)
}
fn write_magic(slab_ptr: NonNull<Slab<T, LOCK>>, val: u64) {
let magic_ptr = get_field_ptr!(slab_ptr, magic);
unsafe { *magic_ptr.unsafe_ref().get() = val };
}
fn try_get_master_ptr(slab_ptr: NonNull<Slab<T, LOCK>>) -> Result<NonNull<Slab<T, LOCK>>> {
let master_field_ptr = get_field_ptr!(slab_ptr, master_ptr);
let master_ptr_raw = unsafe { master_field_ptr.read() };
if let Some(master_ptr) = master_ptr_raw {
Ok(master_ptr)
} else {
Err(SlabError::InvalidPointer)
}
}
#[inline]
pub fn ref_down(slab_ptr: NonNull<Slab<T, LOCK>>) {
let ref_count_ptr = get_field_ptr!(slab_ptr, ref_count);
unsafe { *ref_count_ptr.unsafe_ref().get() -= 1 };
}
#[inline]
#[allow(unused)]
pub fn atomic_ref_down(slab_ptr: NonNull<Slab<T, LOCK>>) -> Result<()> {
Self::validate_slab(slab_ptr)?;
let lock_ptr = get_field_ptr!(slab_ptr, lock);
let _guard = lock_ptr.unsafe_ref().lock();
let ref_count_ptr = get_field_ptr!(slab_ptr, ref_count);
unsafe { *ref_count_ptr.unsafe_ref().get() -= 1 };
Ok(())
}
#[inline]
pub fn atomic_ref_up(slab_ptr: NonNull<Slab<T, LOCK>>) -> Result<()> {
Self::validate_slab(slab_ptr)?;
let lock_ptr = get_field_ptr!(slab_ptr, lock);
let _guard = lock_ptr.unsafe_ref().lock();
let ref_count_ptr = get_field_ptr!(slab_ptr, ref_count);
unsafe { *ref_count_ptr.unsafe_ref().get() += 1 };
Ok(())
}
pub fn atomic_release_master(
master_ptr: NonNull<Slab<T, LOCK>>,
slab_lay: Layout,
) -> Result<()> {
Self::validate_slab(master_ptr)?;
let lock_ptr = get_field_ptr!(master_ptr, lock);
let _guard = lock_ptr.unsafe_ref().lock();
Self::ref_down(master_ptr);
let can_drop = Self::can_drop(master_ptr);
if can_drop {
Self::write_magic(master_ptr, 0);
drop(_guard);
free_slab(master_ptr.cast(), slab_lay)?;
}
Ok(())
}
}
/// Raw link accessors letting `IntrusiveDLinkListRaw` thread slabs together
/// through the `next_slab`/`last_slab`/`is_linked` header fields.
///
/// SAFETY-NOTE(review): each method assumes `node_ptr` addresses a live,
/// initialized `Slab` header and that the caller serializes list mutation
/// (the master lock in this file).
impl<T, LOCK> DoublyLinkPointerRaw<Slab<T, LOCK>> for Slab<T, LOCK>
where
    LOCK: LockTrait,
{
    fn get_next(node_ptr: NonNull<Slab<T, LOCK>>) -> Option<NonNull<Slab<T, LOCK>>> {
        unsafe { *node_ptr.unsafe_ref().next_slab.get() }
    }
    fn get_last(node_ptr: NonNull<Slab<T, LOCK>>) -> Option<NonNull<Slab<T, LOCK>>> {
        unsafe { *node_ptr.unsafe_ref().last_slab.get() }
    }
    fn set_next(node_ptr: NonNull<Slab<T, LOCK>>, next_ptr: Option<NonNull<Slab<T, LOCK>>>) {
        unsafe { *node_ptr.unsafe_ref().next_slab.get() = next_ptr }
    }
    fn set_last(node_ptr: NonNull<Slab<T, LOCK>>, last_ptr: Option<NonNull<Slab<T, LOCK>>>) {
        unsafe { *node_ptr.unsafe_ref().last_slab.get() = last_ptr }
    }
    fn set_link_state(node_ptr: NonNull<Slab<T, LOCK>>, state: bool) {
        unsafe { *node_ptr.unsafe_ref().is_linked.get() = state }
    }
    fn is_linked(node_ptr: NonNull<Slab<T, LOCK>>) -> bool {
        unsafe { *node_ptr.unsafe_ref().is_linked.get() }
    }
}
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::{define_allocation_hooks, locks::SpinLock, prelude::*, slab::Slab};
    use std::{alloc::Layout, ptr::NonNull};
    /// Stub allocation hook: never expected to run, present only so the
    /// allocation symbols resolve when linking the test binary.
    pub fn stub_alloc(_: Layout) -> Option<NonNull<u8>> {
        panic!("fake allocation only use for make linker quit");
    }
    /// Stub free hook: counterpart of `stub_alloc`, equally unreachable.
    pub fn stub_free(_: NonNull<u8>, _: Layout) -> Result<()> {
        panic!("fake allocation only use for make linker quit");
    }
    define_allocation_hooks!(stub_alloc, stub_free);
    /// Guards the slab header footprint: it must stay within 96 bytes for a
    /// `u128` payload with a `SpinLock`.
    #[test]
    fn basic_test() {
        let header_bytes = size_of::<Slab<u128, SpinLock>>();
        assert!(header_bytes <= 96, "Slab Header Size Test failed");
    }
}