use std::cell::UnsafeCell;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, Ordering};
#[repr(C)]
pub struct FreeBlock {
next: UnsafeCell<Option<NonNull<FreeBlock>>>,
}
impl FreeBlock {
#[inline]
pub unsafe fn read_next(&self) -> Option<NonNull<FreeBlock>> {
unsafe { *self.next.get() }
}
#[inline]
pub unsafe fn write_next(&self, val: Option<NonNull<FreeBlock>>) {
unsafe { self.next.get().write(val) }
}
}
unsafe impl Send for FreeBlock {}
pub struct TreiberStack {
head: AtomicU64,
}
impl TreiberStack {
pub const fn new() -> Self {
Self {
head: AtomicU64::new(0),
}
}
#[inline]
fn pack(ptr: Option<NonNull<FreeBlock>>, tag: u16) -> u64 {
let raw = match ptr {
Some(p) => p.as_ptr() as u64,
None => 0,
};
((tag as u64) << 48) | (raw & 0x0000_FFFF_FFFF_FFFF)
}
#[inline]
fn unpack(val: u64) -> (Option<NonNull<FreeBlock>>, u16) {
let raw = (val & 0x0000_FFFF_FFFF_FFFF) as *mut FreeBlock;
let ptr = NonNull::new(raw);
let tag = (val >> 48) as u16;
(ptr, tag)
}
pub fn push(&self, block: NonNull<FreeBlock>) {
loop {
let current = self.head.load(Ordering::Acquire);
let (head, tag) = Self::unpack(current);
unsafe { block.as_ref().write_next(head) };
let new = Self::pack(Some(block), tag.wrapping_add(1));
if self
.head
.compare_exchange_weak(current, new, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return;
}
}
}
pub fn pop(&self) -> Option<NonNull<FreeBlock>> {
loop {
let current = self.head.load(Ordering::Acquire);
let (head, tag) = Self::unpack(current);
let head = head?;
let next = unsafe { head.as_ref().read_next() };
let new = Self::pack(next, tag.wrapping_add(1));
if self
.head
.compare_exchange_weak(current, new, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return Some(head);
}
}
}
pub fn push_chain(&self, first: NonNull<FreeBlock>, last: NonNull<FreeBlock>) {
loop {
let current = self.head.load(Ordering::Acquire);
let (head, tag) = Self::unpack(current);
unsafe { last.as_ref().write_next(head) };
let new = Self::pack(Some(first), tag.wrapping_add(1));
if self
.head
.compare_exchange_weak(current, new, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return;
}
}
}
#[cfg(test)]
pub fn is_empty(&self) -> bool {
let (head, _) = Self::unpack(self.head.load(Ordering::Relaxed));
head.is_none()
}
}
unsafe impl Send for TreiberStack {}
unsafe impl Sync for TreiberStack {}
pub struct ThreadFreelist {
head: Option<NonNull<FreeBlock>>,
tail: Option<NonNull<FreeBlock>>,
count: usize,
}
impl ThreadFreelist {
pub const fn new() -> Self {
Self {
head: None,
tail: None,
count: 0,
}
}
#[inline]
pub fn push(&mut self, block: NonNull<FreeBlock>) {
unsafe { block.as_ref().write_next(self.head) };
if self.head.is_none() {
self.tail = Some(block);
}
self.head = Some(block);
self.count += 1;
}
#[inline]
pub fn pop(&mut self) -> Option<NonNull<FreeBlock>> {
let block = self.head?;
self.head = unsafe { block.as_ref().read_next() };
self.count -= 1;
if self.head.is_none() {
self.tail = None;
}
Some(block)
}
#[inline]
pub fn push_chain(
&mut self,
first: NonNull<FreeBlock>,
last: NonNull<FreeBlock>,
chain_count: usize,
) {
unsafe { last.as_ref().write_next(self.head) };
if self.head.is_none() {
self.tail = Some(last);
}
self.head = Some(first);
self.count += chain_count;
}
#[inline]
pub fn count(&self) -> usize {
self.count
}
#[cfg(test)]
#[inline]
pub fn is_empty(&self) -> bool {
self.head.is_none()
}
pub fn drain_all(&mut self) -> Option<(NonNull<FreeBlock>, NonNull<FreeBlock>, usize)> {
if self.count == 0 {
return None;
}
let head = self.head.unwrap();
let tail = self.tail.unwrap();
let count = self.count;
self.head = None;
self.tail = None;
self.count = 0;
Some((head, tail, count))
}
pub fn drain(
&mut self,
keep: usize,
) -> Option<(NonNull<FreeBlock>, NonNull<FreeBlock>, usize)> {
if self.count <= keep {
return None;
}
let drain_count = self.count - keep;
let mut split = self.head.unwrap();
for _ in 1..keep {
split = unsafe { split.as_ref().read_next().unwrap() };
}
let chain_head = unsafe { split.as_ref().read_next().unwrap() };
let chain_tail = self.tail.unwrap();
unsafe { split.as_ref().write_next(None) };
self.tail = Some(split);
self.count = keep;
Some((chain_head, chain_tail, drain_count))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn alloc_block() -> NonNull<FreeBlock> {
let block = Box::new(FreeBlock {
next: UnsafeCell::new(None),
});
NonNull::from(Box::leak(block))
}
unsafe fn free_block(b: NonNull<FreeBlock>) {
drop(unsafe { Box::from_raw(b.as_ptr()) });
}
#[test]
fn thread_freelist_push_pop() {
let mut fl = ThreadFreelist::new();
assert!(fl.is_empty());
let b1 = alloc_block();
let b2 = alloc_block();
fl.push(b1);
fl.push(b2);
assert_eq!(fl.count(), 2);
let p1 = fl.pop().unwrap();
assert_eq!(p1, b2); let p2 = fl.pop().unwrap();
assert_eq!(p2, b1);
assert!(fl.is_empty());
unsafe {
free_block(b1);
free_block(b2);
}
}
#[test]
fn thread_freelist_drain() {
let mut fl = ThreadFreelist::new();
let mut blocks = Vec::new();
for _ in 0..10 {
let b = alloc_block();
fl.push(b);
blocks.push(b);
}
assert_eq!(fl.count(), 10);
let (chain_head, chain_tail, drained) = fl.drain(4).unwrap();
assert_eq!(drained, 6);
assert_eq!(fl.count(), 4);
let mut n = chain_head;
let mut chain_len = 1usize;
while let Some(next) = unsafe { n.as_ref().read_next() } {
n = next;
chain_len += 1;
}
assert_eq!(chain_len, 6);
assert_eq!(n, chain_tail);
for b in blocks {
unsafe { free_block(b) };
}
}
#[test]
fn treiber_push_pop() {
let stack = TreiberStack::new();
assert!(stack.is_empty());
let b1 = alloc_block();
let b2 = alloc_block();
stack.push(b1);
stack.push(b2);
assert_eq!(stack.pop().unwrap(), b2);
assert_eq!(stack.pop().unwrap(), b1);
assert!(stack.pop().is_none());
unsafe {
free_block(b1);
free_block(b2);
}
}
#[test]
fn treiber_push_chain() {
let stack = TreiberStack::new();
let b1 = alloc_block();
let b2 = alloc_block();
let b3 = alloc_block();
unsafe {
b1.as_ref().write_next(Some(b2));
b2.as_ref().write_next(Some(b3));
b3.as_ref().write_next(None);
}
stack.push_chain(b1, b3);
assert_eq!(stack.pop().unwrap(), b1);
assert_eq!(stack.pop().unwrap(), b2);
assert_eq!(stack.pop().unwrap(), b3);
assert!(stack.pop().is_none());
unsafe {
free_block(b1);
free_block(b2);
free_block(b3);
}
}
#[test]
fn treiber_concurrent_push_pop() {
use std::sync::Arc;
use std::thread;
let stack = Arc::new(TreiberStack::new());
let num_threads = 8;
let ops_per_thread = 1000;
let handles: Vec<_> = (0..num_threads)
.map(|_| {
let s = Arc::clone(&stack);
thread::spawn(move || {
for _ in 0..ops_per_thread {
let b = alloc_block();
s.push(b);
}
for _ in 0..ops_per_thread {
let p = s.pop();
assert!(p.is_some());
}
})
})
.collect();
for h in handles {
h.join().unwrap();
}
let mut remaining = 0;
while stack.pop().is_some() {
remaining += 1;
}
assert_eq!(remaining, 0);
}
}