use crossbeam_epoch::{pin as epoch_pin, unprotected, Atomic, Guard, Owned, Shared};
use crossbeam_utils::Backoff;
use std::{
cell::UnsafeCell,
cmp::min,
mem::{self, MaybeUninit},
slice,
sync::atomic::{AtomicUsize, Ordering},
};
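// The block size is tied to the pointer width so that the `read` bitmap (one bit per slot)
// always fits in a single `AtomicUsize`.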
#[cfg(target_pointer_width = "16")]
const BLOCK_SIZE: usize = 16;
#[cfg(target_pointer_width = "32")]
const BLOCK_SIZE: usize = 32;
#[cfg(target_pointer_width = "64")]
const BLOCK_SIZE: usize = 64;
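// Maximum number of consumed blocks handed to a single deferred-drop closure in `clear_with`.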
const DEFERRED_BLOCK_BATCH_SIZE: usize = 32;
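/// Discrete chunk of values with atomic read/write access.
///
/// `write` is the next slot index to claim, `read` is a bitmap of slots whose writes have
/// completed, `slots` holds the values themselves, and `next` points at the previously
/// installed (older) block in the bucket's chain.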
struct Block<T> {
write: AtomicUsize,
read: AtomicUsize,
slots: [MaybeUninit<UnsafeCell<T>>; BLOCK_SIZE],
next: Atomic<Block<T>>,
}
impl<T> Block<T> {
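    /// Creates a new [`Block`].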
pub fn new() -> Self {
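        // SAFETY: all-zero bytes are a valid `Block<T>`: the atomic counters start at zero,
        // `next` becomes a null pointer, and the slots are `MaybeUninit`, which requires no
        // initialization.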
unsafe { MaybeUninit::zeroed().assume_init() }
}
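    /// Gets the length of the next (older) block in the chain, or 0 if there is none.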
pub(crate) fn next_len(&self, guard: &Guard) -> usize {
let tail = self.next.load(Ordering::Acquire, guard);
if tail.is_null() {
return 0;
}
let tail_block = unsafe { tail.deref() };
tail_block.len()
}
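    /// Gets the number of values which can currently be read out of this block.
    ///
    /// Derived from the `read` bitmap: the count of contiguous completed writes starting at
    /// slot zero.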
pub fn len(&self) -> usize {
self.read.load(Ordering::Acquire).trailing_ones() as usize
}
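    /// Whether or not this block is quiesced i.e. every write that has claimed a slot has
    /// also finished publishing its value.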
pub fn is_quiesced(&self) -> bool {
let len = self.len();
if len == BLOCK_SIZE {
return true;
}
min(self.write.load(Ordering::Acquire), BLOCK_SIZE) == len
}
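    /// Gets a slice of the values written to this block so far.
    ///
    /// The slice only covers the contiguous run of completed writes reported by `len`;
    /// callers that need every in-flight write wait for the block to quiesce first.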
pub fn data(&self) -> &[T] {
let len = self.len();
unsafe {
let head = self.slots.get_unchecked(0).as_ptr();
slice::from_raw_parts(head as *const T, len)
}
}
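    /// Attempts to push a value into this block.
    ///
    /// Returns `Err(value)` if the block is already full, so the caller can install a new
    /// block and retry.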
pub fn push(&self, value: T) -> Result<(), T> {
let index = self.write.fetch_add(1, Ordering::AcqRel);
if index >= BLOCK_SIZE {
return Err(value);
}
unsafe {
self.slots.get_unchecked(index).assume_init_ref().get().write(value);
}
self.read.fetch_or(1 << index, Ordering::AcqRel);
Ok(())
}
}
unsafe impl<T: Send> Send for Block<T> {}
unsafe impl<T: Sync> Sync for Block<T> {}
impl<T> Drop for Block<T> {
fn drop(&mut self) {
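        // Wait for any in-flight writes to land, then drop only the slots that were
        // actually initialized.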
while !self.is_quiesced() {}
unsafe {
let len = self.len();
for i in 0..len {
self.slots.get_unchecked(i).assume_init_ref().get().drop_in_place();
}
}
}
}
impl<T> std::fmt::Debug for Block<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let has_next = unsafe { !self.next.load(Ordering::Acquire, unprotected()).is_null() };
f.debug_struct("Block")
.field("type", &std::any::type_name::<T>())
.field("block_size", &BLOCK_SIZE)
.field("write", &self.write.load(Ordering::Acquire))
.field("read", &self.read.load(Ordering::Acquire))
.field("len", &self.len())
.field("has_next", &has_next)
.finish()
}
}
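/// A bucket of values built from a chain of fixed-size blocks.
///
/// Values are pushed into the most recently installed block (the tail), and each block links
/// back to the one installed before it. Readers can snapshot the contents (`data`/`data_with`)
/// or drain them (`clear`/`clear_with`) concurrently with writers.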
#[derive(Debug)]
pub struct AtomicBucket<T> {
tail: Atomic<Block<T>>,
}
impl<T> AtomicBucket<T> {
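    /// Creates a new, empty `AtomicBucket`.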
pub fn new() -> Self {
AtomicBucket { tail: Atomic::null() }
}
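    /// Checks whether or not this bucket is empty.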
pub fn is_empty(&self) -> bool {
let guard = &epoch_pin();
let tail = self.tail.load(Ordering::Acquire, guard);
if tail.is_null() {
return true;
}
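        // The tail block might have just been installed and still be empty, so also check the
        // block behind it before reporting the bucket as empty.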
let tail_block = unsafe { tail.deref() };
tail_block.len() == 0 && tail_block.next_len(guard) == 0
}
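    /// Pushes a value into the bucket.
    ///
    /// Writes go into the current tail block; when that block is full, the writer races to
    /// install a fresh block, links it back to the old tail, and retries.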
pub fn push(&self, value: T) {
let mut original = value;
let guard = &epoch_pin();
loop {
let mut tail = self.tail.load(Ordering::Acquire, guard);
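            // No tail block installed yet: race to install the first one. If another writer
            // wins, use the block they installed.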
if tail.is_null() {
match self.tail.compare_exchange(
Shared::null(),
Owned::new(Block::new()),
Ordering::AcqRel,
Ordering::Acquire,
guard,
) {
Ok(ptr) => tail = ptr,
Err(e) => tail = e.current,
}
}
let tail_block = unsafe { tail.deref() };
match tail_block.push(original) {
Ok(_) => return,
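                // The block was full, so the value came back to us; try to install a new
                // tail block.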
Err(value) => {
match self.tail.compare_exchange(
tail,
Owned::new(Block::new()),
Ordering::AcqRel,
Ordering::Acquire,
guard,
) {
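                        // We installed the new tail: link it to the previous block, then push
                        // into it.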
Ok(ptr) => {
let new_tail = unsafe { ptr.deref() };
new_tail.next.store(tail, Ordering::Release);
match new_tail.push(value) {
Ok(_) => return,
Err(value) => {
original = value;
continue;
}
}
}
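                        // Another writer installed a new tail first; take the value back and
                        // retry from the top.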
Err(_) => original = value,
}
}
}
}
}
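    /// Collects all of the values currently in the bucket into a `Vec`.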
pub fn data(&self) -> Vec<T>
where
T: Clone,
{
let mut values = Vec::new();
self.data_with(|block| values.extend_from_slice(block));
values
}
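    /// Iterates over the values in the bucket, calling `f` with the contents of each block.
    ///
    /// Blocks are visited from newest to oldest, and the iteration waits for each block's
    /// in-flight writes to quiesce before reading it.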
pub fn data_with<F>(&self, mut f: F)
where
F: FnMut(&[T]),
{
let guard = &epoch_pin();
let backoff = Backoff::new();
let mut block_ptr = self.tail.load(Ordering::Acquire, guard);
while !block_ptr.is_null() {
let block = unsafe { block_ptr.deref() };
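            // Wait for in-flight writes so the slice we hand to `f` is fully initialized.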
while !block.is_quiesced() {
backoff.snooze();
}
let data = block.data();
f(data);
block_ptr = block.next.load(Ordering::Acquire, guard);
}
}
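    /// Clears the bucket, discarding the values.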
pub fn clear(&self) {
self.clear_with(|_: &[T]| {})
}
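    /// Clears the bucket, calling `f` with the contents of each block before it is freed.
    ///
    /// The tail pointer is swapped to null first, so new writers see an empty bucket while
    /// the detached chain is drained; the consumed blocks are reclaimed through the epoch
    /// garbage collector in batches of `DEFERRED_BLOCK_BATCH_SIZE`.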
pub fn clear_with<F>(&self, mut f: F)
where
F: FnMut(&[T]),
{
let guard = &epoch_pin();
let mut block_ptr = self.tail.load(Ordering::Acquire, guard);
if !block_ptr.is_null()
&& self
.tail
.compare_exchange(
block_ptr,
Shared::null(),
Ordering::SeqCst,
Ordering::SeqCst,
guard,
)
.is_ok()
{
let backoff = Backoff::new();
let mut freeable_blocks = Vec::new();
while !block_ptr.is_null() {
let block = unsafe { block_ptr.deref() };
while !block.is_quiesced() {
backoff.snooze();
}
let data = block.data();
f(data);
let old_block_ptr =
mem::replace(&mut block_ptr, block.next.load(Ordering::Acquire, guard));
freeable_blocks.push(old_block_ptr);
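                // Hand the consumed blocks to the epoch GC in batches so no single deferred
                // closure has to free an unbounded chain.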
if freeable_blocks.len() >= DEFERRED_BLOCK_BATCH_SIZE {
let blocks = mem::take(&mut freeable_blocks);
unsafe {
guard.defer_unchecked(move || {
for block in blocks {
drop(block.into_owned());
}
});
}
}
}
if !freeable_blocks.is_empty() {
unsafe {
guard.defer_unchecked(move || {
for block in freeable_blocks {
drop(block.into_owned());
}
});
}
}
guard.flush();
}
}
}
impl<T> Default for AtomicBucket<T> {
fn default() -> Self {
        Self::new()
}
}
#[cfg(test)]
mod tests {
use super::{AtomicBucket, Block, BLOCK_SIZE};
use crossbeam_utils::thread::scope;
#[test]
fn test_create_new_block() {
let block: Block<u64> = Block::new();
assert_eq!(block.len(), 0);
let data = block.data();
assert_eq!(data.len(), 0);
}
#[test]
fn test_block_write_then_read() {
let block = Block::new();
assert_eq!(block.len(), 0);
let data = block.data();
assert_eq!(data.len(), 0);
let result = block.push(42);
assert!(result.is_ok());
assert_eq!(block.len(), 1);
let data = block.data();
assert_eq!(data.len(), 1);
assert_eq!(data[0], 42);
}
#[test]
fn test_block_write_until_full_then_read() {
let block = Block::new();
assert_eq!(block.len(), 0);
let data = block.data();
assert_eq!(data.len(), 0);
let mut i = 0;
let mut total = 0;
while i < BLOCK_SIZE as u64 {
assert!(block.push(i).is_ok());
total += i;
i += 1;
}
let data = block.data();
assert_eq!(data.len(), BLOCK_SIZE);
let sum: u64 = data.iter().sum();
assert_eq!(sum, total);
let result = block.push(42);
assert!(result.is_err());
}
#[test]
fn test_block_write_until_full_then_read_mt() {
let block = Block::new();
assert_eq!(block.len(), 0);
let data = block.data();
assert_eq!(data.len(), 0);
let res = scope(|s| {
let t1 = s.spawn(|_| {
let mut i = 0;
let mut total = 0;
while i < BLOCK_SIZE as u64 / 2 {
assert!(block.push(i).is_ok());
total += i;
i += 1;
}
total
});
let t2 = s.spawn(|_| {
let mut i = 0;
let mut total = 0;
while i < BLOCK_SIZE as u64 / 2 {
assert!(block.push(i).is_ok());
total += i;
i += 1;
}
total
});
let t1_total = t1.join().unwrap();
let t2_total = t2.join().unwrap();
t1_total + t2_total
});
let total = res.unwrap();
let data = block.data();
assert_eq!(data.len(), BLOCK_SIZE);
let sum: u64 = data.iter().sum();
assert_eq!(sum, total);
let result = block.push(42);
assert!(result.is_err());
}
#[test]
fn test_bucket_write_then_read() {
let bucket = AtomicBucket::new();
bucket.push(42);
let snapshot = bucket.data();
assert_eq!(snapshot.len(), 1);
assert_eq!(snapshot[0], 42);
}
#[test]
fn test_bucket_multiple_blocks_write_then_read() {
let bucket = AtomicBucket::new();
let snapshot = bucket.data();
assert_eq!(snapshot.len(), 0);
let target = (BLOCK_SIZE * 3 + BLOCK_SIZE / 2) as u64;
let mut i = 0;
let mut total = 0;
while i < target {
bucket.push(i);
total += i;
i += 1;
}
let snapshot = bucket.data();
assert_eq!(snapshot.len(), target as usize);
let sum: u64 = snapshot.iter().sum();
assert_eq!(sum, total);
}
#[test]
fn test_bucket_write_then_read_mt() {
let bucket = AtomicBucket::new();
let snapshot = bucket.data();
assert_eq!(snapshot.len(), 0);
let res = scope(|s| {
let t1 = s.spawn(|_| {
let mut i = 0;
let mut total = 0;
while i < BLOCK_SIZE as u64 * 100_000 {
bucket.push(i);
total += i;
i += 1;
}
total
});
let t2 = s.spawn(|_| {
let mut i = 0;
let mut total = 0;
while i < BLOCK_SIZE as u64 * 100_000 {
bucket.push(i);
total += i;
i += 1;
}
total
});
let t1_total = t1.join().unwrap();
let t2_total = t2.join().unwrap();
t1_total + t2_total
});
let total = res.unwrap();
let snapshot = bucket.data();
assert_eq!(snapshot.len(), BLOCK_SIZE * 200_000);
let sum = snapshot.iter().sum::<u64>();
assert_eq!(sum, total);
}
#[test]
fn test_clear_and_clear_with() {
let bucket = AtomicBucket::new();
let snapshot = bucket.data();
assert_eq!(snapshot.len(), 0);
let mut i = 0;
let mut total_pushed = 0;
while i < BLOCK_SIZE * 4 {
bucket.push(i);
total_pushed += i;
i += 1;
}
let snapshot = bucket.data();
assert_eq!(snapshot.len(), i);
let mut total_accumulated = 0;
bucket.clear_with(|xs| total_accumulated += xs.iter().sum::<usize>());
assert_eq!(total_pushed, total_accumulated);
let snapshot = bucket.data();
assert_eq!(snapshot.len(), 0);
}
#[test]
fn test_bucket_len_and_next_len() {
let bucket = AtomicBucket::new();
assert!(bucket.is_empty());
let snapshot = bucket.data();
assert_eq!(snapshot.len(), 0);
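        // Push across two blocks while asserting `is_empty` stays false, which exercises the
        // `Block::next_len` codepath once the second block is installed.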
let mut i = 0;
while i < BLOCK_SIZE * 2 {
bucket.push(i);
assert!(!bucket.is_empty());
i += 1;
}
}
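    #[test]
    fn test_bucket_data_with_then_clear() {
        // A minimal usage sketch combining the pieces above: push enough values to span
        // several blocks, read them back with `data_with`, then clear the bucket.
        let bucket = AtomicBucket::new();
        let target = BLOCK_SIZE * 2 + 3;
        let mut i = 0;
        while i < target {
            bucket.push(i);
            i += 1;
        }

        // `data_with` visits blocks newest-first, so only the count and sum are asserted,
        // not the ordering of values.
        let mut count = 0;
        let mut sum = 0;
        bucket.data_with(|xs| {
            count += xs.len();
            sum += xs.iter().sum::<usize>();
        });
        assert_eq!(count, target);
        assert_eq!(sum, (0..target).sum::<usize>());

        bucket.clear();
        assert!(bucket.is_empty());
    }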
}