use core::{any::TypeId, mem, sync::atomic::Ordering};
use core::{
cell::UnsafeCell,
cmp, fmt,
hash::{Hash, Hasher},
marker::PhantomData,
mem::MaybeUninit,
ops::{Deref, DerefMut},
ptr::{self, NonNull},
sync::atomic::AtomicPtr,
};
use as_slice::{AsMutSlice, AsSlice};
pub mod singleton;
/// A pool of fixed-size memory blocks, each able to hold one value of type `T`.
///
/// Free blocks are kept in a singly linked list of [`Node`]s. `alloc` takes a
/// block from that list and `free` puts one back. A pool created with `new`
/// starts out empty and receives its blocks through `grow` / `grow_exact`.
pub struct Pool<T> {
/// First node of the free list, or null while the pool has no free block.
head: AtomicPtr<Node<T>>,
/// Raw-pointer marker: it removes the automatically derived `Send`/`Sync`, so
/// those traits come only from the explicit `unsafe impl`s written for `Pool`.
_not_send_or_sync: PhantomData<*const ()>,
}
// `Pool` is `Sync` only when one of the listed cfg flags is set, or when the
// crate is built for testing.
// NOTE(review): presumably the gate exists because the compare-exchange based
// free list is only trusted on these targets — confirm the exact rationale.
#[cfg(any(armv7a, armv7r, armv7m, armv8m_main, test))]
unsafe impl<T> Sync for Pool<T> {}
// `Pool` is `Send` on every target and places no bound on `T`.
unsafe impl<T> Send for Pool<T> {}
impl<T> Pool<T> {
    /// Creates a pool that does not own any memory block yet.
    ///
    /// Until memory is handed over with [`Pool::grow`] or [`Pool::grow_exact`],
    /// [`Pool::alloc`] returns `None`.
    pub const fn new() -> Self {
        Self {
            _not_send_or_sync: PhantomData,
            head: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Takes one block out of the pool.
    ///
    /// Returns `None` when the free list is empty. The returned block is tagged
    /// [`Uninit`]; call `init` on it to store a value.
    pub fn alloc(&self) -> Option<Box<T, Uninit>> {
        self.pop().map(|node| Box {
            _state: PhantomData,
            node,
        })
    }

    /// Gives a block back to the pool.
    ///
    /// When the block is tagged [`Init`] the stored value is dropped in place
    /// first; for any other state tag the contents are left untouched.
    pub fn free<S>(&self, value: Box<T, S>)
    where
        S: 'static,
    {
        let node = value.node;
        let holds_value = TypeId::of::<S>() == TypeId::of::<Init>();
        if holds_value {
            // SAFETY: the `Init` tag is only produced by `Box::init`, which wrote
            // a `T` into this node's `data` slot.
            unsafe { ptr::drop_in_place(node.as_ref().data.get()) };
        }
        self.push(node);
    }

    /// Splits `memory` into as many nodes as fit and adds them to the pool.
    ///
    /// Leading bytes are skipped as needed to satisfy the alignment of
    /// `Node<T>`; trailing bytes that cannot hold a whole node are unused.
    /// Returns the number of blocks that were added, which is `0` when the
    /// slice is too small.
    pub fn grow(&self, memory: &'static mut [u8]) -> usize {
        let node_align = mem::align_of::<Node<T>>();
        let node_size = mem::size_of::<Node<T>>();

        let mut cursor = memory.as_mut_ptr();
        let mut remaining = memory.len();

        // Advance to the first address that is suitably aligned for a node.
        let misalignment = (cursor as usize) % node_align;
        if misalignment != 0 {
            let padding = node_align - misalignment;
            if padding >= remaining {
                // Nothing is left once the padding has been skipped.
                return 0;
            }
            // SAFETY: `padding < remaining`, so the result stays inside the slice.
            cursor = unsafe { cursor.add(padding) };
            remaining -= padding;
        }

        // `Node<T>` always contains a pointer, so `node_size` is never zero.
        let count = remaining / node_size;
        for _ in 0..count {
            // SAFETY: `cursor` is derived from a slice pointer and is therefore
            // non-null; it is aligned for `Node<T>` and a whole node fits.
            self.push(unsafe { NonNull::new_unchecked(cursor as *mut Node<T>) });
            // SAFETY: moves at most to one past the end of the slice.
            cursor = unsafe { cursor.add(node_size) };
        }
        count
    }

    /// Adds every element of the node collection `memory` to the pool.
    ///
    /// Unlike [`Pool::grow`], no space is lost to alignment because `memory`
    /// is already typed as a collection of `Node<T>`. Returns the number of
    /// blocks added, i.e. the length of that collection.
    pub fn grow_exact<A>(&self, memory: &'static mut MaybeUninit<A>) -> usize
    where
        A: AsMutSlice<Element = Node<T>>,
    {
        // NOTE(review): this views the not-yet-initialized `A` as a slice of
        // nodes; it relies on `as_mut_slice` not reading the contents — confirm.
        let nodes = unsafe { (*memory.as_mut_ptr()).as_mut_slice() };
        let cap = nodes.len();
        nodes
            .iter_mut()
            .for_each(|node| self.push(NonNull::from(node)));
        cap
    }

    /// Detaches and returns the first node of the free list, or `None` when the
    /// list is empty.
    fn pop(&self) -> Option<NonNull<Node<T>>> {
        loop {
            let current = self.head.load(Ordering::Acquire);
            // A null head means the list is empty: stop right away.
            let popped = NonNull::new(current)?;
            // SAFETY: `current` is non-null and was installed by `push`.
            // NOTE(review): another context may pop and re-push this node between
            // this read and the exchange below — confirm this is acceptable on
            // every target where `Pool` is `Sync`.
            let successor = unsafe { (*current).next };
            let swapped = self.head.compare_exchange_weak(
                current,
                successor,
                Ordering::Release, // success
                Ordering::Relaxed, // failure
            );
            if swapped.is_ok() {
                return Some(popped);
            }
            // The head changed, or the weak exchange failed spuriously: retry.
        }
    }

    /// Makes `new_head` the first node of the free list.
    fn push(&self, mut new_head: NonNull<Node<T>>) {
        let mut expected = self.head.load(Ordering::Relaxed);
        loop {
            // Link the node to the head we expect to replace.
            // SAFETY: the caller hands over exclusive access to the node.
            unsafe { new_head.as_mut().next = expected };
            let swapped = self.head.compare_exchange_weak(
                expected,
                new_head.as_ptr(),
                Ordering::Release, // success
                Ordering::Relaxed, // failure
            );
            if let Err(actual) = swapped {
                // Retry with the head value that was actually observed.
                expected = actual;
            } else {
                return;
            }
        }
    }
}
/// One memory block managed by a [`Pool`]: storage for a single `T` plus the
/// link used while the block sits in the pool's free list.
pub struct Node<T> {
/// Storage for the value; it is accessed through the raw pointer returned by
/// `UnsafeCell::get`.
data: UnsafeCell<T>,
/// Next node in the free list (null for the last one); written by
/// `Pool::push` and read by `Pool::pop`.
next: *mut Node<T>,
}
/// A handle to one memory block taken out of a [`Pool`].
///
/// `STATE` is a type-level tag: [`Uninit`] marks a block that has not been
/// written yet, [`Init`] (the default) marks a block that holds a `T`.
/// Hand the block back with `Pool::free` to make it available again.
pub struct Box<T, STATE = Init> {
/// Carries the state tag; no value of type `STATE` is stored.
_state: PhantomData<STATE>,
/// The node this handle refers to.
node: NonNull<Node<T>>,
}
impl<T> Box<T, Uninit> {
    /// Stores `val` in the block and returns the same block tagged [`Init`].
    ///
    /// The previous contents of the block are overwritten without being read or
    /// dropped.
    pub fn init(self, val: T) -> Box<T, Init> {
        let node = self.node;
        // SAFETY: `data.get()` yields a pointer to this node's storage slot;
        // `write` moves `val` in without dropping whatever was there before.
        unsafe { node.as_ref().data.get().write(val) };
        Box {
            _state: PhantomData,
            node,
        }
    }
}
/// State tag for a [`Box`] whose memory has not been initialized; it has no
/// variants and is only used at the type level.
pub enum Uninit {}
/// State tag for a [`Box`] that holds an initialized value; it has no variants
/// and is only used at the type level.
pub enum Init {}
// `Box` stores a `NonNull`, which is neither `Send` nor `Sync`, so both traits
// are implemented by hand. They follow `T` and are independent of the state tag `S`.
unsafe impl<T, S> Send for Box<T, S> where T: Send {}
unsafe impl<T, S> Sync for Box<T, S> where T: Sync {}
/// Lets an initialized block be viewed as a slice whenever its contents can.
impl<A> AsSlice for Box<A>
where
    A: AsSlice,
{
    type Element = A::Element;

    /// Returns the slice view of the value stored in the block.
    fn as_slice(&self) -> &[A::Element] {
        let inner: &A = &**self;
        <A as AsSlice>::as_slice(inner)
    }
}
/// Lets an initialized block be viewed as a mutable slice whenever its contents
/// can.
impl<A> AsMutSlice for Box<A>
where
    A: AsMutSlice,
{
    /// Returns the mutable slice view of the value stored in the block.
    fn as_mut_slice(&mut self) -> &mut [A::Element] {
        let inner: &mut A = &mut **self;
        <A as AsMutSlice>::as_mut_slice(inner)
    }
}
/// Shared access to the value stored in an initialized block.
impl<T> Deref for Box<T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: a `Box<T, Init>` is produced by `init`, which stored a `T` in
        // the node's `data` slot.
        unsafe {
            let slot: *mut T = self.node.as_ref().data.get();
            &*slot
        }
    }
}
/// Exclusive access to the value stored in an initialized block.
impl<T> DerefMut for Box<T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: a `Box<T, Init>` is produced by `init`, which stored a `T` in
        // the node's `data` slot; `&mut self` makes this borrow exclusive.
        unsafe {
            let slot: *mut T = self.node.as_ref().data.get();
            &mut *slot
        }
    }
}
/// Formats the block exactly like the value it holds (`Debug`).
impl<T> fmt::Debug for Box<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &T = &**self;
        fmt::Debug::fmt(inner, f)
    }
}
/// Formats the block exactly like the value it holds (`Display`).
impl<T> fmt::Display for Box<T>
where
    T: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &T = &**self;
        fmt::Display::fmt(inner, f)
    }
}
/// Two blocks compare equal when the values they hold compare equal.
impl<T> PartialEq for Box<T>
where
    T: PartialEq,
{
    fn eq(&self, rhs: &Box<T>) -> bool {
        let (left, right): (&T, &T) = (&**self, &**rhs);
        left == right
    }
}
/// Equality between blocks is a full equivalence relation whenever it is one for `T`.
impl<T> Eq for Box<T> where T: Eq {}
/// Blocks are ordered by the values they hold.
impl<T> PartialOrd for Box<T>
where
    T: PartialOrd,
{
    fn partial_cmp(&self, rhs: &Box<T>) -> Option<cmp::Ordering> {
        let (left, right): (&T, &T) = (&**self, &**rhs);
        PartialOrd::partial_cmp(left, right)
    }
}
/// Blocks are totally ordered by the values they hold.
impl<T> Ord for Box<T>
where
    T: Ord,
{
    fn cmp(&self, rhs: &Box<T>) -> cmp::Ordering {
        let (left, right): (&T, &T) = (&**self, &**rhs);
        Ord::cmp(left, right)
    }
}
/// A block hashes exactly like the value it holds.
impl<T> Hash for Box<T>
where
    T: Hash,
{
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        let inner: &T = &**self;
        Hash::hash(inner, state)
    }
}
// Unit tests for `Pool`: growing from raw bytes and from typed node arrays,
// allocation / exhaustion / reuse, and when destructors run.
#[cfg(test)]
mod tests {
use core::{
mem::{self, MaybeUninit},
sync::atomic::{AtomicUsize, Ordering},
};
use super::{Node, Pool};
// Growing from a raw byte buffer: 1 KiB must provide seven 128-byte blocks.
#[test]
fn grow() {
static mut MEMORY: [u8; 1024] = [0; 1024];
static POOL: Pool<[u8; 128]> = Pool::new();
unsafe {
POOL.grow(&mut MEMORY);
}
// Each node is larger than 128 bytes because it also stores the link pointer,
// so fewer than eight nodes fit in 1024 bytes; seven allocations are expected.
for _ in 0..7 {
assert!(POOL.alloc().is_some());
}
}
// Growing from a buffer typed as an array of nodes: every element becomes a block.
#[test]
fn grow_exact() {
const SZ: usize = 8;
static mut MEMORY: MaybeUninit<[Node<[u8; 128]>; SZ]> = MaybeUninit::uninit();
static POOL: Pool<[u8; 128]> = Pool::new();
unsafe {
POOL.grow_exact(&mut MEMORY);
}
for _ in 0..SZ {
assert!(POOL.alloc().is_some());
}
// All SZ blocks are taken, so the pool must now be empty.
assert!(POOL.alloc().is_none());
}
// Basic life cycle: empty pool, grow, allocate, exhaust, free, allocate again.
#[test]
fn sanity() {
static mut MEMORY: [u8; 31] = [0; 31];
static POOL: Pool<u8> = Pool::new();
// A pool that has not been grown has nothing to hand out.
assert!(POOL.alloc().is_none());
POOL.grow(unsafe { &mut MEMORY });
let x = POOL.alloc().unwrap().init(0);
assert_eq!(*x, 0);
// The pool is expected to be exhausted after a single allocation.
// NOTE(review): this holds when `Node<u8>` occupies 16 bytes (64-bit hosts);
// confirm before running the test on a target with smaller pointers.
assert!(POOL.alloc().is_none());
POOL.free(x);
// The freed block can be allocated and initialized again.
assert_eq!(*POOL.alloc().unwrap().init(1), 1);
}
// Checks which operations run the destructor of the stored value.
#[test]
fn destructors() {
// Number of live `X` values: incremented by `X::new`, decremented by `drop`.
static COUNT: AtomicUsize = AtomicUsize::new(0);
struct X;
impl X {
fn new() -> X {
COUNT.fetch_add(1, Ordering::Relaxed);
X
}
}
impl Drop for X {
fn drop(&mut self) {
COUNT.fetch_sub(1, Ordering::Relaxed);
}
}
static mut MEMORY: [u8; 31] = [0; 31];
static POOL: Pool<X> = Pool::new();
POOL.grow(unsafe { &mut MEMORY });
let x = POOL.alloc().unwrap().init(X::new());
let y = POOL.alloc().unwrap().init(X::new());
let z = POOL.alloc().unwrap().init(X::new());
assert_eq!(COUNT.load(Ordering::Relaxed), 3);
// Dropping the handle does not run `X`'s destructor: the count is unchanged.
drop(x);
assert_eq!(COUNT.load(Ordering::Relaxed), 3);
// Forgetting the handle does not run the destructor either.
mem::forget(y);
assert_eq!(COUNT.load(Ordering::Relaxed), 3);
// Returning the block through `Pool::free` is what runs `X`'s destructor.
POOL.free(z);
assert_eq!(COUNT.load(Ordering::Relaxed), 2);
}
}