use std::{
array,
fmt::Debug,
future::Future,
mem::MaybeUninit,
ops::{BitAnd, Deref},
};
#[cfg(not(loom))]
use std::{
cell::UnsafeCell,
sync::atomic::{AtomicUsize, Ordering},
thread::LocalKey,
};
use async_local::{AsContext, AsyncLocal, Context};
#[cfg(feature = "async-std-runtime")]
use async_std::task::{spawn, yield_now};
use async_t::async_trait;
use cache_padded::CachePadded;
#[cfg(loom)]
use loom::{
cell::UnsafeCell,
sync::atomic::{AtomicUsize, Ordering},
thread::LocalKey,
};
#[cfg(feature = "tokio-runtime")]
use tokio::task::{spawn, yield_now};
use crate::{
assignment::{CompletionReceipt, PendingAssignment, UnboundedSlice},
helpers::*,
task::{AutoBatchedTask, Receiver, TaskRef},
MAX_BUFFER_LEN, MIN_BUFFER_LEN,
};
/// Bit offset applied to values stored in the queue's slot word on 64-bit
/// targets (half the pointer width). In this file the slot is initialised to
/// `N << INDEX_SHIFT` and advanced by `1 << INDEX_SHIFT` per enqueued task.
#[cfg(target_pointer_width = "64")]
pub(crate) const INDEX_SHIFT: usize = 32;
/// 32-bit counterpart of the slot-word bit offset (half the pointer width).
#[cfg(target_pointer_width = "32")]
pub(crate) const INDEX_SHIFT: usize = 16;
/// One slot of a queue buffer: an interior-mutable, possibly uninitialised `T`.
///
/// The wrapper exists so the cell can be marked `Send`/`Sync`; it dereferences
/// to the underlying `UnsafeCell<MaybeUninit<T>>` so callers can obtain a raw
/// pointer to the storage.
#[doc(hidden)]
pub struct BufferCell<T: Send + Sync + Sized + 'static>(UnsafeCell<MaybeUninit<T>>);

impl<T: Send + Sync + Sized + 'static> Deref for BufferCell<T> {
    type Target = UnsafeCell<MaybeUninit<T>>;

    /// Exposes the wrapped cell.
    fn deref(&self) -> &Self::Target {
        let BufferCell(cell) = self;
        cell
    }
}

impl<T: Send + Sync + Sized + 'static> BufferCell<T> {
    /// Creates a cell whose contents are uninitialised.
    fn new_uninit() -> Self {
        Self(UnsafeCell::new(MaybeUninit::uninit()))
    }
}

// SAFETY: the cell performs no synchronisation of its own. Soundness relies on
// every user of the raw pointer coordinating access externally — presumably
// through the queue's slot/occupancy protocol; confirm against the consumers.
unsafe impl<T: Send + Sync + Sized + 'static> Send for BufferCell<T> {}
unsafe impl<T: Send + Sync + Sized + 'static> Sync for BufferCell<T> {}
/// A queue whose tasks are gathered into batches and handed to
/// `batch_process` as a `PendingAssignment`.
#[async_trait]
pub trait TaskQueue: Send + Sync + Sized + 'static {
/// Input type passed to `auto_batch` for each unit of work.
type Task: Send + Sync + Sized + 'static;
/// NOTE(review): presumably the per-task result produced when a batch is
/// resolved — confirm against `AutoBatchedTask` / `CompletionReceipt`.
type Value: Send;
/// Processes one batch. The implementation receives the pending assignment
/// for the batch and must return the `CompletionReceipt` for it.
#[allow(clippy::needless_lifetimes)]
async fn batch_process<const N: usize>(
assignment: PendingAssignment<'async_trait, Self, N>,
) -> CompletionReceipt<Self>;
/// Wraps `task` in an `AutoBatchedTask` bound to this queue.
///
/// Only available when `Self` has a local queue of `TaskRef<Self>` cells with
/// capacity `N`.
fn auto_batch<const N: usize>(task: Self::Task) -> AutoBatchedTask<Self, N>
where
Self: LocalQueue<N, BufferCell = TaskRef<Self>>,
{
AutoBatchedTask::new(task)
}
}
/// A queue whose tasks are batched and processed without returning a value to
/// the caller that enqueued them.
#[async_trait]
pub trait BackgroundQueue: Send + Sync + Sized + 'static {
/// Input type passed to `auto_batch` for each unit of work.
type Task: Send + Sync + Sized + 'static;
/// Processes one batch of enqueued tasks, supplied as an `UnboundedSlice`.
#[allow(clippy::needless_lifetimes)]
async fn batch_process<const N: usize>(tasks: UnboundedSlice<'async_trait, Self::Task, N>);
/// Enqueues `task` for background processing by delegating to
/// `StackQueue::background_process`; returns immediately with no result.
///
/// Only available when `Self` has a local queue of `BufferCell<Self::Task>`
/// cells with capacity `N`.
fn auto_batch<const N: usize>(task: Self::Task)
where
Self: LocalQueue<N, BufferCell = BufferCell<Self::Task>>,
{
StackQueue::background_process::<Self>(task);
}
}
/// A queue whose batches are reduced by a caller-supplied closure.
///
/// This trait declares no default for `batch_reduce`; in this file's tests the
/// implementation is supplied by the `#[local_queue]` attribute.
#[async_trait]
pub trait BatchReducer: Send + Sync + Sized + 'static {
/// Input type enqueued by each call to `batch_reduce`.
type Task: Send + Sync + Sized + 'static;
/// Enqueues `task` and offers `f` as the reducer over a batch of tasks.
///
/// NOTE(review): the `Option<R>` result is presumably `Some` only for the
/// caller whose closure actually ran over the batch and `None` for the
/// others — confirm against the implementation.
async fn batch_reduce<const N: usize, F, R, Fut>(task: Self::Task, f: F) -> Option<R>
where
Self: LocalQueue<N, BufferCell = BufferCell<Self::Task>>,
F: FnOnce(UnboundedSlice<'async_trait, Self::Task, N>) -> Fut + Send,
Fut: Future<Output = R> + Send;
}
/// Associates a type with a `StackQueue` of capacity `N` stored behind a
/// `LocalKey`.
pub trait LocalQueue<const N: usize> {
/// Element type stored in each slot of the queue's buffer.
type BufferCell: Send + Sync + Sized + 'static;
/// Returns the `LocalKey` through which the queue is accessed.
fn queue() -> &'static LocalKey<StackQueue<Self::BufferCell, N>>;
}
/// The atomically shared part of a `StackQueue`: two atomic words plus the
/// fixed-size buffer of `N` cells.
#[doc(hidden)]
pub struct Inner<T: Sync + Sized + 'static, const N: usize = 512> {
/// Atomic slot word; the enqueue paths in this file advance it with
/// `fetch_add(1 << INDEX_SHIFT)`.
pub(crate) slot: CachePadded<AtomicUsize>,
/// Atomic occupancy word, tested against a per-region mask before writing
/// at a region boundary and incremented when a new batch is started.
pub(crate) occupancy: CachePadded<AtomicUsize>,
/// Storage for the queued items.
pub(crate) buffer: [T; N],
}
impl<T, const N: usize> Default for Inner<BufferCell<T>, N>
where
    T: Send + Sync + Sized + 'static,
{
    /// Builds an empty shared state: every buffer cell uninitialised and both
    /// atomic words set to zero.
    fn default() -> Self {
        Self {
            // Fields are listed buffer-first so construction order matches the
            // order in which the values were previously created.
            buffer: array::from_fn(|_| BufferCell::new_uninit()),
            slot: CachePadded::new(AtomicUsize::new(0)),
            occupancy: CachePadded::new(AtomicUsize::new(0)),
        }
    }
}
impl<T, const N: usize> Default for Inner<TaskRef<T>, N>
where
    T: TaskQueue,
{
    /// Builds an empty shared state: every task reference uninitialised and
    /// both atomic words set to zero.
    fn default() -> Self {
        Self {
            // Fields are listed buffer-first so construction order matches the
            // order in which the values were previously created.
            buffer: array::from_fn(|_| TaskRef::new_uninit()),
            slot: CachePadded::new(AtomicUsize::new(0)),
            occupancy: CachePadded::new(AtomicUsize::new(0)),
        }
    }
}
impl<T, const N: usize> Inner<BufferCell<T>, N>
where
    T: Send + Sync + Sized + 'static,
{
    /// Runs `f` with a raw mutable pointer to the storage of the cell at
    /// `index` and returns whatever `f` returns.
    ///
    /// # Safety
    ///
    /// `index` must be less than `N`; the lookup is not bounds-checked. The
    /// caller is also responsible for making sure that whatever `f` does
    /// through the pointer does not conflict with other accesses to the cell.
    #[cfg(not(loom))]
    #[inline(always)]
    pub(crate) unsafe fn with_buffer_cell<F, R>(&self, f: F, index: usize) -> R
    where
        F: FnOnce(*mut MaybeUninit<T>) -> R,
    {
        let raw: *mut MaybeUninit<T> = self.buffer.get_unchecked(index).get();
        f(raw)
    }

    /// Loom variant of `with_buffer_cell`: the pointer is obtained through
    /// loom's tracked `get_mut().with(..)` accessor.
    ///
    /// # Safety
    ///
    /// Same contract as the non-loom variant: `index` must be less than `N`
    /// and accesses through the pointer must not conflict with other users.
    #[cfg(loom)]
    #[inline(always)]
    pub(crate) unsafe fn with_buffer_cell<F, R>(&self, f: F, index: usize) -> R
    where
        F: FnOnce(*mut MaybeUninit<T>) -> R,
    {
        self.buffer.get_unchecked(index).get_mut().with(f)
    }
}
// SAFETY (review): asserts that `Inner` may be shared across threads whenever
// its element type is `Sync`. The struct's own fields are two atomics and an
// array of `T`; any interior mutability lives inside `T` itself, so this relies
// on `T`'s users synchronising access — confirm against the buffer consumers.
unsafe impl<T, const N: usize> Sync for Inner<T, N> where T: Sync + Sized + 'static {}
/// Error returned when the occupancy check at a region boundary finds the
/// region that is about to be written still occupied.
#[derive(Debug)]
pub(crate) struct QueueFull;
/// A fixed-capacity queue of `N` cells, reached through a `LocalKey` (see
/// `LocalQueue::queue`).
///
/// It pairs non-atomic copies of the slot and occupancy words with the
/// atomically shared `Inner` state.
#[derive(AsContext)]
pub struct StackQueue<T: Sync + Sized + 'static, const N: usize = 512> {
/// Non-atomic copy of the slot word; overwritten after every enqueue with
/// the value this queue just produced, and read to find the write index.
slot: CachePadded<UnsafeCell<usize>>,
/// Non-atomic copy of the occupancy word, refreshed from the shared atomic
/// when a region looks occupied and whenever a new batch is started.
occupancy: CachePadded<UnsafeCell<usize>>,
/// Shared atomic state and buffer, wrapped in an `async_local` `Context`.
inner: Context<Inner<T, N>>,
}
impl<T, const N: usize> Default for StackQueue<T, N>
where
    T: Sync + Sized + 'static,
    Inner<T, N>: Default,
{
    /// Creates an empty queue.
    ///
    /// The local slot copy starts at `N << INDEX_SHIFT`, the local occupancy
    /// copy at zero, and the shared state at `Inner::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `N` is not a power of two, or if it lies outside
    /// `MIN_BUFFER_LEN..=MAX_BUFFER_LEN`.
    fn default() -> Self {
        // `is_power_of_two` cannot overflow, unlike comparing against
        // `next_power_of_two`, which panics in debug builds for very large `N`
        // before the intended message is ever shown.
        assert!(
            N.is_power_of_two(),
            "StackQueue buffer size must be power of 2"
        );
        assert!(
            N >= MIN_BUFFER_LEN,
            "StackQueue buffer size must be at least {}, got {}",
            MIN_BUFFER_LEN,
            N
        );
        assert!(
            N <= MAX_BUFFER_LEN,
            "StackQueue buffer size must be at most {}, got {}",
            MAX_BUFFER_LEN,
            N
        );

        StackQueue {
            slot: CachePadded::new(UnsafeCell::new(N << INDEX_SHIFT)),
            occupancy: CachePadded::new(UnsafeCell::new(0)),
            inner: Context::new(Inner::default()),
        }
    }
}
impl<T, const N: usize> StackQueue<T, N>
where
    T: Sync + Sized + 'static,
{
    /// Runs `f` with a read-only pointer to the local slot copy.
    ///
    /// # Safety
    ///
    /// The caller must ensure the pointer is not used while the slot copy is
    /// being written through `with_slot_mut`.
    #[cfg(not(loom))]
    #[inline(always)]
    unsafe fn with_slot<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*const usize) -> R,
    {
        let slot_ptr: *const usize = self.slot.get();
        f(slot_ptr)
    }

    /// Loom variant of `with_slot`, using loom's tracked read accessor.
    ///
    /// # Safety
    ///
    /// Same contract as the non-loom variant.
    #[cfg(loom)]
    #[inline(always)]
    unsafe fn with_slot<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*const usize) -> R,
    {
        let tracked = self.slot.get();
        tracked.with(f)
    }

    /// Runs `f` with a mutable pointer to the local slot copy.
    ///
    /// # Safety
    ///
    /// The caller must ensure no other access to the slot copy overlaps with
    /// the use of this pointer.
    #[cfg(not(loom))]
    #[inline(always)]
    unsafe fn with_slot_mut<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*mut usize) -> R,
    {
        let slot_ptr: *mut usize = self.slot.get();
        f(slot_ptr)
    }

    /// Loom variant of `with_slot_mut`, using loom's tracked write accessor.
    ///
    /// # Safety
    ///
    /// Same contract as the non-loom variant.
    #[cfg(loom)]
    #[inline(always)]
    unsafe fn with_slot_mut<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*mut usize) -> R,
    {
        let tracked = self.slot.get_mut();
        tracked.with(f)
    }

    /// Runs `f` with a read-only pointer to the local occupancy copy.
    ///
    /// # Safety
    ///
    /// The caller must ensure the pointer is not used while the occupancy copy
    /// is being written through `with_occupancy_mut`.
    #[cfg(not(loom))]
    #[inline(always)]
    unsafe fn with_occupancy<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*const usize) -> R,
    {
        let occupancy_ptr: *const usize = self.occupancy.get();
        f(occupancy_ptr)
    }

    /// Loom variant of `with_occupancy`, using loom's tracked read accessor.
    ///
    /// # Safety
    ///
    /// Same contract as the non-loom variant.
    #[cfg(loom)]
    #[inline(always)]
    unsafe fn with_occupancy<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*const usize) -> R,
    {
        let tracked = self.occupancy.get();
        tracked.with(f)
    }

    /// Runs `f` with a mutable pointer to the local occupancy copy.
    ///
    /// # Safety
    ///
    /// The caller must ensure no other access to the occupancy copy overlaps
    /// with the use of this pointer.
    #[cfg(not(loom))]
    #[inline(always)]
    unsafe fn with_occupancy_mut<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*mut usize) -> R,
    {
        let occupancy_ptr: *mut usize = self.occupancy.get();
        f(occupancy_ptr)
    }

    /// Loom variant of `with_occupancy_mut`, using loom's tracked write
    /// accessor.
    ///
    /// # Safety
    ///
    /// Same contract as the non-loom variant.
    #[cfg(loom)]
    #[inline(always)]
    unsafe fn with_occupancy_mut<F, R>(&self, f: F) -> R
    where
        F: FnOnce(*mut usize) -> R,
    {
        let tracked = self.occupancy.get_mut();
        tracked.with(f)
    }

    /// Returns the buffer index at which the next item will be written,
    /// derived from the local slot copy via `slot_index`.
    #[inline(always)]
    fn current_write_index(&self) -> usize {
        unsafe { self.with_slot(|slot_ptr| slot_index::<N>(&*slot_ptr)) }
    }

    /// Checks whether the region containing `index` is free to be written.
    ///
    /// The local occupancy copy is consulted first. Only when it reports the
    /// region as occupied is the shared atomic reloaded, cached locally, and
    /// tested again.
    ///
    /// Returns `Err(QueueFull)` when the freshly loaded occupancy still has
    /// bits set within the region's mask.
    fn check_regional_occupancy(&self, index: &usize) -> Result<(), QueueFull> {
        let mask = region_mask::<N>(index);

        // Fast path: no atomic access when the cached word already shows the
        // region as free.
        let cached_bits = unsafe { self.with_occupancy(|cached| (*cached).bitand(mask)) };
        if cached_bits == 0 {
            return Ok(());
        }

        // Slow path: refresh the cache from the shared word.
        // NOTE(review): this load is `Relaxed`; confirm the side that clears
        // occupancy provides whatever ordering is needed before the region's
        // cells are overwritten.
        let shared = self.inner.occupancy.load(Ordering::Relaxed);
        unsafe {
            self.with_occupancy_mut(move |cached| *cached = shared);
        }

        if shared.bitand(mask) == 0 {
            Ok(())
        } else {
            Err(QueueFull)
        }
    }

    /// Adds the region-specific increment for `index` (from `one_shifted`) to
    /// the shared occupancy word and stores the resulting value in the local
    /// copy.
    fn occupy_region(&self, index: &usize) {
        let increment = one_shifted::<N>(index);
        let before = self.inner.occupancy.fetch_add(increment, Ordering::Relaxed);
        // Wrapping, to mirror the wrapping behaviour of the atomic add itself.
        let after = before.wrapping_add(increment);
        unsafe {
            self.with_occupancy_mut(move |cached| *cached = after);
        }
    }

    /// Stores `slot` in the local slot copy and returns the value it replaced.
    ///
    /// # Safety
    ///
    /// The caller must ensure no other access to the local slot copy overlaps
    /// with this call.
    #[inline(always)]
    unsafe fn replace_slot(&self, slot: usize) -> usize {
        self.with_slot_mut(move |slot_ptr| std::ptr::replace(slot_ptr, slot))
    }
}
impl<T, const N: usize> StackQueue<TaskRef<T>, N>
where
    T: TaskQueue,
{
    /// Initialises the task reference at `index` with the task and receiver
    /// pointer produced by `write_with`, which is given the reference's state
    /// pointer.
    ///
    /// # Safety
    ///
    /// `*index` must be less than `N`; the lookup is not bounds-checked. The
    /// caller must also uphold whatever `TaskRef::set_task` requires.
    unsafe fn write_with<F>(&self, index: &usize, write_with: F)
    where
        F: FnOnce(*const AtomicUsize) -> (T::Task, *const Receiver<T>),
    {
        let target = self.inner.buffer.get_unchecked(*index);
        let (task, receiver) = write_with(target.state_ptr());
        target.set_task(task, receiver);
    }

    /// Writes a task into the next buffer position.
    ///
    /// Returns `Ok(Some(assignment))` when this write starts a new batch,
    /// `Ok(None)` when it was appended to a batch that already exists, and
    /// `Err(QueueFull)` when the write index sits on a region boundary and that
    /// region is still occupied. On `Err`, `write_with` is never called.
    pub(crate) fn enqueue<'a, F>(
        &self,
        write_with: F,
    ) -> Result<Option<PendingAssignment<'a, T, N>>, QueueFull>
    where
        T: LocalQueue<N, BufferCell = TaskRef<T>>,
        F: FnOnce(*const AtomicUsize) -> (T::Task, *const Receiver<T>),
    {
        let write_index = self.current_write_index();

        // Occupancy only needs checking when entering a new region.
        let at_region_boundary = (write_index & (region_size::<N>() - 1)) == 0;
        if at_region_boundary {
            self.check_regional_occupancy(&write_index)?;
        }

        // The cell is written before the shared slot word is advanced.
        unsafe {
            self.write_with(&write_index, write_with);
        }

        // NOTE(review): `Relaxed` — confirm the consumer side establishes the
        // ordering needed to observe the cell write above.
        let increment: usize = 1 << INDEX_SHIFT;
        let base_slot = self.inner.slot.fetch_add(increment, Ordering::Relaxed);
        let prev_slot = unsafe { self.replace_slot(base_slot.wrapping_add(increment)) };

        // Same batch as before: not at index zero, and the active phase bit is
        // identical in the shared slot and in the previously cached one.
        if write_index != 0 && ((base_slot ^ prev_slot) & active_phase_bit::<N>(&base_slot)) == 0 {
            return Ok(None);
        }

        self.occupy_region(&write_index);
        let queue = unsafe { T::queue().guarded_ref() };
        Ok(Some(PendingAssignment::new(base_slot, queue)))
    }
}
impl<T, const N: usize> StackQueue<BufferCell<T>, N>
where
    T: Send + Sync + Sized + 'static,
{
    /// Writes `task` into the next buffer position.
    ///
    /// Returns `Ok(Some(slice))` when this write starts a new batch,
    /// `Ok(None)` when it was appended to a batch that already exists, and
    /// `Err(task)` — handing the task back — when the write index sits on a
    /// region boundary and that region is still occupied.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out in this file. The returned
    /// slice is bound to `Q::queue()`, so `self` is presumably required to be
    /// that same queue — confirm.
    #[doc(hidden)]
    pub unsafe fn push<'a, Q>(&self, task: T) -> Result<Option<UnboundedSlice<'a, T, N>>, T>
    where
        Q: LocalQueue<N, BufferCell = BufferCell<T>>,
    {
        let write_index = self.current_write_index();

        // Occupancy only needs checking when entering a new region.
        let at_region_boundary = (write_index & (region_size::<N>() - 1)) == 0;
        if at_region_boundary && self.check_regional_occupancy(&write_index).is_err() {
            return Err(task);
        }

        // The cell is written before the shared slot word is advanced.
        unsafe {
            self
                .inner
                .with_buffer_cell(|cell| cell.write(MaybeUninit::new(task)), write_index);
        }

        // NOTE(review): `Relaxed` — confirm the consumer side establishes the
        // ordering needed to observe the cell write above.
        let increment: usize = 1 << INDEX_SHIFT;
        let base_slot = self.inner.slot.fetch_add(increment, Ordering::Relaxed);
        let prev_slot = unsafe { self.replace_slot(base_slot.wrapping_add(increment)) };

        // Same batch as before: not at index zero, and the active phase bit is
        // identical in the shared slot and in the previously cached one.
        if write_index != 0 && ((base_slot ^ prev_slot) & active_phase_bit::<N>(&base_slot)) == 0 {
            return Ok(None);
        }

        self.occupy_region(&write_index);
        let queue = unsafe { Q::queue().guarded_ref() };
        Ok(Some(UnboundedSlice::new(base_slot, queue)))
    }

    /// Enqueues `task` on `Q`'s queue for background processing.
    ///
    /// If the push starts a new batch, a task running `Q::batch_process` is
    /// spawned for it. If the queue is full, a task is spawned that retries the
    /// push, yielding to the runtime between attempts, until it succeeds.
    fn background_process<Q>(task: T)
    where
        Q: BackgroundQueue<Task = T> + LocalQueue<N, BufferCell = BufferCell<T>>,
    {
        Q::queue().with(|queue| {
            let outcome = unsafe { queue.push::<Q>(task) };
            match outcome {
                Ok(Some(assignment)) => {
                    spawn(async move {
                        Q::batch_process::<N>(assignment).await;
                    });
                }
                // Joined a batch that some other push is responsible for.
                Ok(None) => {}
                Err(rejected) => {
                    spawn(async move {
                        let mut pending = rejected;
                        loop {
                            let attempt = Q::queue().with(|queue| unsafe { queue.push::<Q>(pending) });
                            match attempt {
                                Ok(Some(assignment)) => {
                                    Q::batch_process::<N>(assignment).await;
                                    break;
                                }
                                Ok(None) => break,
                                Err(returned) => {
                                    // Still full: keep the task and let other
                                    // work run before trying again.
                                    pending = returned;
                                    yield_now().await;
                                }
                            }
                        }
                    });
                }
            }
        });
    }
}
// Tests for the queue. Compiled only for test builds that either run under
// loom or enable one of the async runtime features.
#[cfg(all(
test,
any(loom, feature = "tokio-runtime", feature = "async-std-runtime")
))]
mod test {
#[cfg(not(loom))]
use std::{thread, time::Duration};
#[cfg(feature = "async-std-runtime")]
use async_std::task::yield_now;
#[cfg(not(loom))]
use futures::{future::join_all, stream::FuturesUnordered, StreamExt};
#[cfg(all(not(loom), feature = "tokio-runtime"))]
use tokio::{
sync::{oneshot, Barrier},
task::{spawn, yield_now},
};
use crate::{
assignment::{CompletionReceipt, PendingAssignment},
local_queue, TaskQueue,
};
#[cfg(all(not(loom), feature = "tokio-runtime"))]
use crate::{queue::UnboundedSlice, BackgroundQueue};
// Queue that resolves every task to its own input value.
struct EchoQueue;
#[local_queue]
impl TaskQueue for EchoQueue {
type Task = usize;
type Value = usize;
async fn batch_process<const N: usize>(
batch: PendingAssignment<'async_trait, Self, N>,
) -> CompletionReceipt<Self> {
batch.into_assignment().map(|val| val)
}
}
// 100 concurrently awaited tasks each resolve to their own input, in order.
#[cfg(not(loom))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn it_process_tasks() {
let batch: Vec<usize> = join_all((0..100).map(EchoQueue::auto_batch)).await;
assert_eq!(batch, (0..100).collect::<Vec<usize>>());
}
// Awaits 512 tasks one after another; the test passes if none of them hangs
// or panics.
#[cfg(not(loom))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn it_cycles() {
for i in 0..512 {
EchoQueue::auto_batch(i).await;
}
}
// Queue whose batch processing sleeps for 50 ms inside `with_blocking` before
// resolving each task to its input.
#[cfg(not(loom))]
struct SlowQueue;
#[cfg(not(loom))]
#[local_queue]
impl TaskQueue for SlowQueue {
type Task = usize;
type Value = usize;
async fn batch_process<const N: usize>(
batch: PendingAssignment<'async_trait, Self, N>,
) -> CompletionReceipt<Self> {
batch
.with_blocking(|batch| {
let assignment = batch.into_assignment();
thread::sleep(Duration::from_millis(50));
assignment.map(|task| task)
})
.await
}
}
// Aborts a spawned task that is awaiting a slow batch; the test passes if the
// abort does not cause a panic or crash.
#[cfg(all(not(loom), any(feature = "tokio-runtime")))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
async fn it_has_drop_safety() {
let handle = spawn(async {
SlowQueue::auto_batch(0).await;
});
yield_now().await;
handle.abort();
}
// Queue that yields to the runtime once between taking the assignment and
// resolving it.
#[cfg(not(loom))]
struct YieldQueue;
#[cfg(not(loom))]
#[local_queue]
impl TaskQueue for YieldQueue {
type Task = usize;
type Value = usize;
async fn batch_process<const N: usize>(
batch: PendingAssignment<'async_trait, Self, N>,
) -> CompletionReceipt<Self> {
let assignment = batch.into_assignment();
yield_now().await;
assignment.map(|val| val)
}
}
// Spawns 256 tasks that each enqueue onto `YieldQueue` and aborts every one of
// them shortly after it passes a barrier, so aborts can land while batches are
// being processed.
#[cfg(all(not(loom), any(feature = "tokio-runtime")))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
async fn it_negotiates_receiver_drop() {
use std::sync::Arc;
let tasks: FuturesUnordered<_> = (0..256)
.map(|i| async move {
let barrier = Arc::new(Barrier::new(2));
let task_barrier = barrier.clone();
let handle = tokio::task::spawn(async move {
task_barrier.wait().await;
YieldQueue::auto_batch(i).await;
});
barrier.wait().await;
yield_now().await;
handle.abort()
})
.collect();
tasks.collect::<Vec<_>>().await;
}
// Loom model of the same scenario: one thread resolves a task while another
// thread drops the receiver that was registered for it.
#[cfg(loom)]
#[test]
fn it_negotiates_receiver_drop() {
use std::{hint::unreachable_unchecked, ptr::addr_of};
use futures::pin_mut;
use futures_test::task::noop_waker;
use loom::sync::{Arc, Condvar, Mutex};
use crate::task::{AutoBatchedTask, Receiver, State, TaskRef};
loom::model(|| {
let task: Arc<TaskRef<EchoQueue>> = Arc::new(TaskRef::new_uninit());
// Flag + condvar used to hold the resolver back until the task is set.
let barrier = Arc::new((Mutex::new(false), Condvar::new()));
// Resolver thread: waits for the flag, then resolves the task.
let resolver_handle = {
let task = task.clone();
let barrier = barrier.clone();
loom::thread::spawn(move || {
let (lock, cvar) = &*barrier;
let mut task_initialized = lock.lock().unwrap();
while !*task_initialized {
task_initialized = cvar.wait(task_initialized).unwrap();
}
unsafe {
task.resolve_unchecked(9001);
}
})
};
// Receiver thread: builds a batched task around a receiver, registers it on
// the task reference, raises the flag, then drops the batched task.
let receiver_handle = {
loom::thread::spawn(move || {
let waker = noop_waker();
let rx: Receiver<EchoQueue> = Receiver::new(task.state_ptr(), waker);
let auto_batched_task: AutoBatchedTask<EchoQueue, 256> = AutoBatchedTask {
state: State::Batched(rx),
};
pin_mut!(auto_batched_task);
let rx = match &auto_batched_task.state {
State::Batched(rx) => {
addr_of!(*rx)
}
// The state was constructed as `Batched` just above.
_ => unsafe { unreachable_unchecked() },
};
unsafe { task.set_task(9001, rx) };
let (lock, cvar) = &*barrier;
let mut task_initialized = lock.lock().unwrap();
*task_initialized = true;
cvar.notify_one();
#[allow(clippy::drop_non_drop)]
drop(auto_batched_task);
})
};
resolver_handle.join().unwrap();
receiver_handle.join().unwrap();
});
}
// Background queue whose tasks carry a oneshot sender; processing sends each
// task's value back through its sender.
#[cfg(all(not(loom), any(feature = "tokio-runtime")))]
struct EchoBackgroundQueue;
#[cfg(all(not(loom), any(feature = "tokio-runtime")))]
#[local_queue]
impl BackgroundQueue for EchoBackgroundQueue {
type Task = (usize, oneshot::Sender<usize>);
async fn batch_process<const N: usize>(tasks: UnboundedSlice<'async_trait, Self::Task, N>) {
let tasks = tasks.into_bounded().to_vec();
for (val, tx) in tasks.into_iter() {
// A send error only means the receiving half was dropped; it is ignored.
tx.send(val).ok();
}
}
}
// Every one of 100 background tasks eventually echoes its value back.
#[cfg(all(not(loom), any(feature = "tokio-runtime")))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
async fn it_process_background_tasks() {
// All tasks are enqueued before any receiver is awaited, hence the collect.
#[allow(clippy::needless_collect)]
let receivers: Vec<_> = (0..100_usize)
.map(|i| {
let (tx, rx) = oneshot::channel::<usize>();
EchoBackgroundQueue::auto_batch((i, tx));
rx
})
.collect();
for (i, rx) in receivers.into_iter().enumerate() {
assert_eq!(rx.await, Ok(i));
}
}
// Sums 0..10000 through batch reducers: each call contributes the sum its
// reducer returned (or zero when it returned `None`), and the grand total must
// equal the plain sum.
#[cfg(not(loom))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn it_batch_reduces() {
use crate::BatchReducer;
struct MultiSetter;
#[local_queue]
impl BatchReducer for MultiSetter {
type Task = usize;
}
let tasks: FuturesUnordered<_> = (0..10000)
.map(|i| {
MultiSetter::batch_reduce(i, |slice| async move {
slice.into_bounded().iter().sum::<usize>()
})
})
.collect();
let total = tasks
.fold(0_usize, |total, value| async move {
total + value.unwrap_or_default()
})
.await;
assert_eq!(total, (0..10000).sum());
}
}