//! A bounded, single-producer queue with FIFO semantics whose items can be
//! stolen in batches and moved to another queue.

use alloc::boxed::Box;
use alloc::sync::Arc;
use core::iter::FusedIterator;
use core::mem::{drop, MaybeUninit};
use core::panic::{RefUnwindSafe, UnwindSafe};
use core::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
use cache_padded::CachePadded;
use crate::config::{AtomicUnsignedLong, AtomicUnsignedShort, UnsignedShort};
use crate::loom_exports::cell::UnsafeCell;
use crate::loom_exports::{debug_or_loom_assert, debug_or_loom_assert_eq};
use crate::{allocate_buffer, pack, unpack, StealError};
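
/// Ring buffer and shared indices of the queue.
///
/// `heads` packs two positions into a single atomic word: the *worker head*
/// (the next item to be popped or booked for stealing) and the *stealer head*
/// (the earliest slot that may still be read by an in-progress steal or
/// drain). The tail, one past the last pushed item, is only written by the
/// worker. Positions use wrapping arithmetic over `UnsignedShort` and are
/// mapped to buffer indices with `mask`.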
#[derive(Debug)]
struct Queue<T> {
    /// Positions of the worker head and stealer head, packed into one atomic.
    heads: CachePadded<AtomicUnsignedLong>,
    /// Position of the tail; only mutated by the worker.
    tail: CachePadded<AtomicUnsignedShort>,
    /// Power-of-two ring buffer storing the items.
    buffer: Box<[UnsafeCell<MaybeUninit<T>>]>,
    /// Bit mask mapping a position to a buffer index (`capacity - 1`).
    mask: UnsignedShort,
}
impl<T> Queue<T> {
    /// Reads the item at the given position.
    ///
    /// Safety: the slot must hold an initialized item that is not moved out
    /// twice nor read concurrently.
    #[inline]
    unsafe fn read_at(&self, position: UnsignedShort) -> T {
        let index = (position & self.mask) as usize;
        (*self.buffer).as_ref()[index].with(|slot| slot.read().assume_init())
    }

    /// Writes an item at the given position.
    ///
    /// Safety: the slot must not be read or written concurrently.
    #[inline]
    unsafe fn write_at(&self, position: UnsignedShort, item: T) {
        let index = (position & self.mask) as usize;
        (*self.buffer).as_ref()[index].with_mut(|slot| slot.write(MaybeUninit::new(item)));
    }
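
    /// Books up to `max_count` items for stealing or draining.
    ///
    /// On success, the worker head is advanced past the booked items while
    /// the stealer head is left in place, and the previous (common) head
    /// position is returned together with the number of booked items. The
    /// caller is responsible for reading the booked slots and for bringing
    /// the stealer head back up to the worker head afterwards.
    ///
    /// Fails with `StealError::Busy` if another steal or drain is in
    /// progress, or with `StealError::Empty` if the queue is empty or
    /// `count_fn` requests nothing.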
fn book_items<C>(
&self,
mut count_fn: C,
max_count: UnsignedShort,
) -> Result<(UnsignedShort, UnsignedShort), StealError>
where
C: FnMut(usize) -> usize,
{
        let mut heads = self.heads.load(Acquire);

        loop {
            let (worker_head, stealer_head) = unpack(heads);

            // Bail out if the heads differ: another steal or drain is in
            // progress and has not released the stealer head yet.
            if stealer_head != worker_head {
                return Err(StealError::Busy);
            }

            let tail = self.tail.load(Acquire);
            let item_count = tail.wrapping_sub(worker_head);

            // Check for emptiness before calling `count_fn`.
            if item_count == 0 {
                return Err(StealError::Empty);
            }

            // Clamp the requested count to the caller's maximum and to the
            // number of items actually available.
            let count = (count_fn(item_count as usize).min(max_count as usize) as UnsignedShort)
                .min(item_count);

            // A zero count must be rejected explicitly: a successful
            // compare-exchange that leaves the heads unchanged would not be
            // visible to other stealers.
            if count == 0 {
                return Err(StealError::Empty);
            }

            // Book the items by moving the worker head only.
            let new_heads = pack(worker_head.wrapping_add(count), stealer_head);

            match self
                .heads
                .compare_exchange_weak(heads, new_heads, Acquire, Acquire)
            {
                Ok(_) => return Ok((stealer_head, count)),
                // Lost a race or spurious CAS failure: retry with the
                // freshly read value.
                Err(h) => heads = h,
            }
        }
}
    /// Returns the capacity of the queue.
    #[inline]
    fn capacity(&self) -> UnsignedShort {
        self.mask.wrapping_add(1)
    }
}
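
// Dropping the queue drops all items that were pushed but never popped,
// stolen or drained.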
impl<T> Drop for Queue<T> {
fn drop(&mut self) {
let worker_head = unpack(self.heads.load(Relaxed)).0;
let tail = self.tail.load(Relaxed);
let count = tail.wrapping_sub(worker_head);
for offset in 0..count {
drop(unsafe { self.read_at(worker_head.wrapping_add(offset)) })
}
}
}
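
/// Handle for single-threaded push, pop and drain operations on the queue.
///
/// The worker is the only handle that can add items: items are pushed at the
/// tail and popped from the head in FIFO order.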
#[derive(Debug)]
pub struct Worker<T> {
queue: Arc<Queue<T>>,
}
impl<T> Worker<T> {
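    /// Creates a new queue and its worker handle.
    ///
    /// The actual capacity is `min_capacity` rounded up to the next power of
    /// two, and cannot exceed half the range of the position type.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the paths are illustrative since the module's
    /// public re-exports are not shown here.
    ///
    /// ```ignore
    /// let worker = Worker::<u64>::new(100); // actual capacity: 128
    /// assert!(worker.capacity() >= 100);
    /// ```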
pub fn new(min_capacity: usize) -> Self {
const MAX_CAPACITY: usize = 1 << (UnsignedShort::BITS - 1);
assert!(
min_capacity <= MAX_CAPACITY,
"the capacity of the queue cannot exceed {}",
MAX_CAPACITY
);
let capacity = min_capacity.next_power_of_two();
let buffer = allocate_buffer(capacity);
let mask = capacity as UnsignedShort - 1;
let queue = Arc::new(Queue {
heads: CachePadded::new(AtomicUnsignedLong::new(0)),
tail: CachePadded::new(AtomicUnsignedShort::new(0)),
buffer,
mask,
});
Worker { queue }
}
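
    /// Creates a new `Stealer` handle associated with this queue.
    ///
    /// Any number of stealer handles can be created, either with this method
    /// or by cloning an existing one.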
pub fn stealer(&self) -> Stealer<T> {
Stealer {
queue: self.queue.clone(),
}
}
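
    /// Returns the actual capacity of the queue.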
pub fn capacity(&self) -> usize {
self.queue.capacity() as usize
}
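
    /// Returns the number of items that can currently be pushed.
    ///
    /// Slots still being read by an in-progress steal or drain are not
    /// counted as spare capacity.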
pub fn spare_capacity(&self) -> usize {
let stealer_head = unpack(self.queue.heads.load(Relaxed)).1;
let tail = self.queue.tail.load(Relaxed);
let len = tail.wrapping_sub(stealer_head);
(self.queue.capacity() - len) as usize
}
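
    /// Returns `true` if the queue holds no items that the worker could pop.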
pub fn is_empty(&self) -> bool {
let worker_head = unpack(self.queue.heads.load(Relaxed)).0;
let tail = self.queue.tail.load(Relaxed);
tail == worker_head
}
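
    /// Attempts to push one item, handing it back as the error value if the
    /// queue is full.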
pub fn push(&self, item: T) -> Result<(), T> {
        // The stealer head bounds the writable region: slots between it and
        // the worker head may still be read by a steal or drain in progress.
        let stealer_head = unpack(self.queue.heads.load(Acquire)).1;
        let tail = self.queue.tail.load(Relaxed);

        // Bail out if the queue is full.
        if tail.wrapping_sub(stealer_head) > self.queue.mask {
            return Err(item);
        }

        // Write the item and publish it with a `Release` store of the tail.
        unsafe { self.queue.write_at(tail, item) };
        self.queue.tail.store(tail.wrapping_add(1), Release);

        Ok(())
}
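
    /// Pushes items from an iterator until the queue is full or the iterator
    /// is exhausted; items that do not fit are not pushed.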
pub fn extend<I: IntoIterator<Item = T>>(&self, iter: I) {
        // As in `push`, the stealer head bounds the writable region.
        let stealer_head = unpack(self.queue.heads.load(Acquire)).1;
        let mut tail = self.queue.tail.load(Relaxed);
        let max_tail = stealer_head.wrapping_add(self.queue.capacity());

        // Write items until the queue is full or the iterator is exhausted.
        for item in iter {
            if tail == max_tail {
                break;
            }
            unsafe { self.queue.write_at(tail, item) };
            tail = tail.wrapping_add(1);
        }

        // Publish all newly written items at once.
        self.queue.tail.store(tail, Release);
}
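
    /// Pops the item at the front of the queue, if any.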
pub fn pop(&self) -> Option<T> {
        let mut heads = self.queue.heads.load(Acquire);

        let prev_worker_head = loop {
            let (worker_head, stealer_head) = unpack(heads);
            let tail = self.queue.tail.load(Relaxed);

            // Bail out if the queue is empty.
            if tail == worker_head {
                return None;
            }

            // Move the worker head forward. If no steal is in progress (both
            // heads are equal), the stealer head is moved along with it so
            // that the popped slot becomes immediately reusable; the cast
            // from `bool` keeps the computation branchless.
            let next_heads = pack(
                worker_head.wrapping_add(1),
                stealer_head.wrapping_add((stealer_head == worker_head) as UnsignedShort),
            );

            let res = self
                .queue
                .heads
                .compare_exchange_weak(heads, next_heads, AcqRel, Acquire);

            match res {
                Ok(_) => break worker_head,
                // Lost a race against a stealer, or the compare-exchange
                // failed spuriously: retry with the freshly read value.
                Err(h) => heads = h,
            }
        };

        unsafe { Some(self.queue.read_at(prev_worker_head)) }
}
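
    /// Books items for removal and returns a draining iterator over them.
    ///
    /// `count_fn` receives the number of available items and returns how
    /// many of them should be drained. The booked items are only moved out
    /// as the iterator is advanced; dropping the iterator drains the rest.
    ///
    /// Fails with `StealError::Busy` if a steal or another drain is in
    /// progress, or with `StealError::Empty` if the queue is empty or
    /// `count_fn` returns zero.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the paths are illustrative since the module's
    /// public re-exports are not shown here.
    ///
    /// ```ignore
    /// let worker = Worker::new(4);
    /// worker.extend([1, 2, 3]);
    ///
    /// // Drain roughly half of the available items.
    /// let drained: Vec<_> = worker.drain(|n| n / 2).unwrap().collect();
    /// assert_eq!(drained, vec![1]);
    /// ```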
pub fn drain<C>(&self, count_fn: C) -> Result<Drain<'_, T>, StealError>
where
C: FnMut(usize) -> usize,
{
let (head, count) = self.queue.book_items(count_fn, UnsignedShort::MAX)?;
Ok(Drain {
queue: &self.queue,
head,
from_head: head,
to_head: head.wrapping_add(count),
})
}
}
impl<T> UnwindSafe for Worker<T> {}
impl<T> RefUnwindSafe for Worker<T> {}
unsafe impl<T: Send> Send for Worker<T> {}
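
/// Draining iterator returned by `Worker::drain`.
///
/// Booked items that have not been iterated over yet are read and dropped
/// when the iterator itself is dropped.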
#[derive(Debug)]
pub struct Drain<'a, T> {
queue: &'a Queue<T>,
head: UnsignedShort,
from_head: UnsignedShort,
to_head: UnsignedShort,
}
impl<'a, T> Iterator for Drain<'a, T> {
type Item = T;
fn next(&mut self) -> Option<T> {
if self.head == self.to_head {
return None;
}
let item = Some(unsafe { self.queue.read_at(self.head) });
self.head = self.head.wrapping_add(1);
        // Once all booked items have been read, bring the stealer head up to
        // the worker head so that the freed slots can be reused by the worker.
        if self.head == self.to_head {
let mut heads = self.queue.heads.load(Relaxed);
loop {
let (worker_head, stealer_head) = unpack(heads);
debug_or_loom_assert_eq!(stealer_head, self.from_head);
let res = self.queue.heads.compare_exchange_weak(
heads,
pack(worker_head, worker_head),
AcqRel,
Acquire,
);
match res {
Ok(_) => break,
Err(h) => {
heads = h;
}
}
}
}
item
}
fn size_hint(&self) -> (usize, Option<usize>) {
let sz = self.to_head.wrapping_sub(self.head) as usize;
(sz, Some(sz))
}
}
impl<'a, T> ExactSizeIterator for Drain<'a, T> {}
impl<'a, T> FusedIterator for Drain<'a, T> {}
impl<'a, T> Drop for Drain<'a, T> {
    fn drop(&mut self) {
        // Read out and drop any items that were booked but not yet iterated
        // over; this also restores the stealer head once the last item has
        // been read.
        for _item in self {}
    }
}
impl<'a, T> UnwindSafe for Drain<'a, T> {}
impl<'a, T> RefUnwindSafe for Drain<'a, T> {}
unsafe impl<'a, T: Send> Send for Drain<'a, T> {}
unsafe impl<'a, T: Send> Sync for Drain<'a, T> {}
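
/// Handle for stealing batches of items from the queue into another
/// worker's queue.
///
/// Stealers can be cloned and shared between threads, but only one steal or
/// drain operation can be in progress at any given time.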
#[derive(Debug)]
pub struct Stealer<T> {
queue: Arc<Queue<T>>,
}
impl<T> Stealer<T> {
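    /// Steals items from this queue and moves them to the tail of the
    /// destination worker's queue.
    ///
    /// `count_fn` receives the number of items available for stealing and
    /// returns how many should be taken; the result is additionally capped
    /// by the free capacity of the destination queue. On success, the number
    /// of transferred items is returned.
    ///
    /// Fails with `StealError::Busy` if another steal or a drain is in
    /// progress, or with `StealError::Empty` if no items end up being booked
    /// (for instance because the queue is empty, `count_fn` returned zero or
    /// the destination queue is full).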
pub fn steal<C>(&self, dest: &Worker<T>, count_fn: C) -> Result<usize, StealError>
where
C: FnMut(usize) -> usize,
{
        // Compute the free capacity of the destination queue, using its
        // stealer head as the lower bound.
        let dest_tail = dest.queue.tail.load(Relaxed);
        let dest_stealer_head = unpack(dest.queue.heads.load(Acquire)).1;
        let dest_free_capacity = dest.queue.capacity() - dest_tail.wrapping_sub(dest_stealer_head);
        debug_or_loom_assert!(dest_free_capacity <= dest.queue.capacity());

        // Book items on the source queue: this moves the worker head forward
        // while the stealer head stays behind to mark the steal in progress.
        let (stealer_head, transfer_count) = self.queue.book_items(count_fn, dest_free_capacity)?;
        debug_or_loom_assert!(transfer_count <= dest_free_capacity);

        // Move the booked items to the destination buffer.
        for offset in 0..transfer_count {
unsafe {
let item = self.queue.read_at(stealer_head.wrapping_add(offset));
dest.queue.write_at(dest_tail.wrapping_add(offset), item);
}
}
dest.queue
.tail
.store(dest_tail.wrapping_add(transfer_count), Release);
        // Complete the steal: bring the source stealer head up to the worker
        // head so that the booked slots can be reused by the source worker.
        let mut heads = self.queue.heads.load(Relaxed);
loop {
let (worker_head, sh) = unpack(heads);
debug_or_loom_assert_eq!(stealer_head, sh);
let res = self.queue.heads.compare_exchange_weak(
heads,
pack(worker_head, worker_head),
AcqRel,
Acquire,
);
match res {
Ok(_) => return Ok(transfer_count as usize),
Err(h) => {
heads = h;
}
}
}
}
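
    /// Steals items from this queue, moving all but the last of them to the
    /// tail of the destination worker's queue and returning the last one
    /// directly.
    ///
    /// `count_fn` receives the number of items available for stealing and
    /// returns how many should be taken in total, including the item that is
    /// returned; the result is capped by the free capacity of the
    /// destination queue plus one. On success, the returned tuple holds the
    /// popped item and the number of items transferred to the destination.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the paths are illustrative since the module's
    /// public re-exports are not shown here.
    ///
    /// ```ignore
    /// let worker = Worker::new(4);
    /// worker.extend([1, 2, 3]);
    ///
    /// let dest = Worker::new(4);
    /// let (item, transferred) = worker.stealer().steal_and_pop(&dest, |n| n).unwrap();
    ///
    /// // All 3 items were booked: 2 were moved to `dest` and the last one
    /// // was returned directly.
    /// assert_eq!(item, 3);
    /// assert_eq!(transferred, 2);
    /// ```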
pub fn steal_and_pop<C>(&self, dest: &Worker<T>, count_fn: C) -> Result<(T, usize), StealError>
where
C: FnMut(usize) -> usize,
{
        // Compute the free capacity of the destination queue, using its
        // stealer head as the lower bound.
        let dest_tail = dest.queue.tail.load(Relaxed);
        let dest_stealer_head = unpack(dest.queue.heads.load(Acquire)).1;
        let dest_free_capacity = dest.queue.capacity() - dest_tail.wrapping_sub(dest_stealer_head);
        debug_or_loom_assert!(dest_free_capacity <= dest.queue.capacity());

        // One extra item is booked because the last booked item is returned
        // directly rather than transferred to the destination queue.
        let (stealer_head, count) = self.queue.book_items(count_fn, dest_free_capacity + 1)?;
        let transfer_count = count - 1;
        debug_or_loom_assert!(transfer_count <= dest_free_capacity);
for offset in 0..transfer_count {
unsafe {
let item = self.queue.read_at(stealer_head.wrapping_add(offset));
dest.queue.write_at(dest_tail.wrapping_add(offset), item);
}
}
        // The last booked item is returned to the caller rather than copied
        // to the destination queue.
        let last_item = unsafe {
            self.queue
                .read_at(stealer_head.wrapping_add(transfer_count))
        };
dest.queue
.tail
.store(dest_tail.wrapping_add(transfer_count), Release);
let mut heads = self.queue.heads.load(Relaxed);
loop {
let (worker_head, sh) = unpack(heads);
debug_or_loom_assert_eq!(stealer_head, sh);
let res = self.queue.heads.compare_exchange_weak(
heads,
pack(worker_head, worker_head),
AcqRel,
Acquire,
);
match res {
Ok(_) => return Ok((last_item, transfer_count as usize)),
Err(h) => {
heads = h;
}
}
}
}
}
impl<T> Clone for Stealer<T> {
fn clone(&self) -> Self {
Stealer {
queue: self.queue.clone(),
}
}
}
impl<T> UnwindSafe for Stealer<T> {}
impl<T> RefUnwindSafe for Stealer<T> {}
unsafe impl<T: Send> Send for Stealer<T> {}
unsafe impl<T: Send> Sync for Stealer<T> {}
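
// A small, single-threaded sanity check of the API above. This is an
// illustrative sketch rather than part of the crate's test suite; the
// `not(loom)` guard is an assumption (based on the `loom_exports` module)
// meant to keep it out of loom model-checking builds.
#[cfg(all(test, not(loom)))]
mod usage_example {
    use super::*;

    #[test]
    fn fifo_push_pop_and_steal() {
        let worker = Worker::new(4);
        assert!(worker.is_empty());

        // Fill the queue; `new(4)` guarantees room for at least 4 items.
        for i in 0..4 {
            assert!(worker.push(i).is_ok());
        }
        assert_eq!(worker.spare_capacity(), 0);

        // Items come back in FIFO order.
        assert_eq!(worker.pop(), Some(0));

        // Move one of the remaining items to another worker's queue.
        let stealer = worker.stealer();
        let dest = Worker::new(4);
        assert_eq!(stealer.steal(&dest, |n| n / 2).ok(), Some(1));
        assert_eq!(dest.pop(), Some(1));

        // The source queue still holds the items that were not stolen.
        assert_eq!(worker.pop(), Some(2));
        assert_eq!(worker.pop(), Some(3));
        assert_eq!(worker.pop(), None);
    }
}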