use crate::sync::primitive::{Arc, AtomicUsize, AtomicWaker, Ordering};
use core::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use crossbeam_utils::CachePadded;
/// Creates a new credit channel and returns its two halves.
///
/// The shared state starts with a sender count of one (the returned
/// [`Sender`]) and the [`Receiver`] starts out holding no credits.
pub fn channel() -> (Sender, Receiver) {
    let shared = Arc::new(State::default());
    (
        Sender(Arc::clone(&shared)),
        Receiver {
            state: shared,
            credits: 0,
        },
    )
}
/// The receiving half of the channel.
///
/// Credits taken from the shared state accumulate in `credits` until they
/// are released with [`Receiver::finish`].
pub struct Receiver {
    /// State shared with the [`Sender`] handles.
    state: Arc<State>,
    /// Credits acquired but not yet released via `finish`.
    credits: usize,
}
impl Receiver {
    /// Waits until credits are available or the channel is closed.
    ///
    /// Resolves to `Some(n)`, where `n` is the total number of credits
    /// currently held (this includes credits from earlier acquisitions that
    /// have not been released through [`Receiver::finish`]). Resolves to
    /// `None` once the sender count has reached zero and no credits remain.
    #[inline]
    pub async fn acquire(&mut self) -> Option<usize> {
        Acquire(self).await
    }

    /// Poll-based counterpart of [`Receiver::acquire`].
    ///
    /// If no credits are held and the sender count is non-zero, `cx`'s waker
    /// is registered and `Poll::Pending` is returned.
    #[inline]
    pub fn poll_acquire(&mut self, cx: &mut Context) -> Poll<Option<usize>> {
        // Fast path: credits may already be waiting.
        if let Some(held) = self.drain_remaining() {
            return Poll::Ready(Some(held));
        }

        // Register first, then look again. A `submit` that raced with the
        // check above is either picked up now or will wake this waker.
        self.state.receiver.register(cx.waker());
        if let Some(held) = self.drain_remaining() {
            return Poll::Ready(Some(held));
        }

        if self.state.senders.load(Ordering::Acquire) != 0 {
            return Poll::Pending;
        }

        // Every sender is gone. The last ones may have submitted right before
        // dropping, so drain once more; an empty drain reports `None`.
        Poll::Ready(self.drain_remaining())
    }

    /// Releases `count` of the credits previously reported by `acquire`.
    ///
    /// In debug builds this asserts that `count` does not exceed the credits
    /// currently held.
    #[inline]
    pub fn finish(&mut self, count: usize) {
        debug_assert!(self.credits >= count);
        self.credits -= count;
    }

    /// Moves newly submitted credits out of the shared counter into
    /// `self.credits`. Returns the running total when it is non-zero.
    #[inline]
    fn drain_remaining(&mut self) -> Option<usize> {
        let submitted = self.state.remaining.swap(0, Ordering::Acquire);
        self.credits += submitted;
        if self.credits > 0 {
            Some(self.credits)
        } else {
            None
        }
    }
}
/// The sending half of the channel.
///
/// Every clone counts as an additional sender. The receiver only observes
/// the channel as closed once the shared sender count drops to zero, which
/// happens when every handle has been dropped.
pub struct Sender(Arc<State>);

impl Clone for Sender {
    /// Increments the shared sender count, then shares the state.
    ///
    /// A derived `Clone` would only bump the `Arc` refcount and leave
    /// `State::senders` unchanged, while `Sender`'s `Drop` always decrements
    /// it. The count would then hit zero (closing the channel for the
    /// receiver) while clones are still alive, and wrap below zero later.
    #[inline]
    fn clone(&self) -> Self {
        // `Relaxed` suffices: the caller holds a live sender, so the count is
        // at least one and this increment is ordered before that sender's own
        // later decrement. This is the same reasoning `Arc::clone` relies on.
        self.0.senders.fetch_add(1, Ordering::Relaxed);
        Self(self.0.clone())
    }
}
impl Sender {
    /// Makes `count` more credits available to the receiver and wakes its
    /// registered task.
    #[inline]
    pub fn submit(&self, count: usize) {
        // `Release` pairs with the receiver's `Acquire` swap of `remaining`.
        let shared = &self.0;
        shared.remaining.fetch_add(count, Ordering::Release);
        shared.receiver.wake();
    }
}
impl Drop for Sender {
    /// Decrements the sender count and wakes the receiver so it can re-check
    /// whether the channel has closed.
    #[inline]
    fn drop(&mut self) {
        // `Release` pairs with the `Acquire` load of `senders` in
        // `Receiver::poll_acquire`. Credits submitted before this drop are
        // therefore visible once the receiver sees the count reach zero.
        self.0.senders.fetch_sub(1, Ordering::Release);
        self.0.receiver.wake();
    }
}
/// State shared between the [`Receiver`] and the [`Sender`] handles.
///
/// NOTE(review): the counters are `CachePadded`, presumably so that sender
/// traffic on `remaining` and on `senders` does not share a cache line.
/// TODO confirm the intent.
struct State {
    /// Credits submitted by senders that the receiver has not yet swapped out.
    remaining: CachePadded<AtomicUsize>,
    /// Waker registered by `Receiver::poll_acquire`, woken by senders.
    receiver: AtomicWaker,
    /// Sender count: starts at one and is decremented by each `Sender` drop.
    /// The receiver treats zero as "closed".
    senders: CachePadded<AtomicUsize>,
}
impl Default for State {
fn default() -> Self {
Self {
remaining: Default::default(),
receiver: Default::default(),
senders: AtomicUsize::new(1).into(),
}
}
}
/// Future returned by [`Receiver::acquire`].
///
/// Each poll is forwarded to [`Receiver::poll_acquire`].
struct Acquire<'a>(&'a mut Receiver);

impl Future for Acquire<'_> {
    type Output = Option<usize>;

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // The only field is a `&mut Receiver`, so `Acquire` is `Unpin` and
        // can be unpinned without `unsafe`.
        Pin::get_mut(self).0.poll_acquire(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::loom;

    /// Runs one sender thread and one receiver thread inside `loom::model`.
    ///
    /// - The sender submits `send_batch_size` credits `iterations` times. It
    ///   then drops its handle, which was moved into the thread closure.
    /// - The receiver repeatedly acquires credits and releases them in chunks
    ///   of at most `recv_batch_size`.
    /// - Once `acquire` returns `None`, the receiver asserts that it saw
    ///   exactly `iterations * send_batch_size` credits in total.
    fn loom_scenario(iterations: usize, send_batch_size: usize, recv_batch_size: usize) {
        // A zero `recv_batch_size` would spin forever in the inner release
        // loop below (`count.min(0)` never decreases `count`).
        assert_ne!(send_batch_size, 0);
        assert_ne!(recv_batch_size, 0);
        loom::model(move || {
            let (send, mut recv) = channel();
            let sender = loom::thread::spawn(move || {
                for _ in 0..iterations {
                    send.submit(send_batch_size);
                    loom::hint::spin_loop();
                }
                // `send` is dropped here, closing the channel.
            });
            let receiver = loom::thread::spawn(move || {
                loom::future::block_on(async move {
                    let mut total = 0;
                    while let Some(mut count) = recv.acquire().await {
                        // `acquire` only yields `Some` when credits are held.
                        assert_ne!(count, 0);
                        // Release the acquired credits in bounded chunks.
                        while count > 0 {
                            let to_finish = count.min(recv_batch_size);
                            recv.finish(to_finish);
                            total += to_finish;
                            count -= to_finish;
                        }
                    }
                    // The channel closed; every submitted credit must have been seen.
                    assert_eq!(total, iterations * send_batch_size);
                })
            });
            // NOTE(review): the joins are skipped under `cfg(loom)`; presumably
            // `loom::model` waits for spawned threads itself. Confirm.
            if cfg!(not(loom)) {
                sender.join().unwrap();
                receiver.join().unwrap();
            }
        });
    }

    // Sizes are much smaller under `cfg(loom)`, presumably to bound the
    // number of interleavings the model checker explores.
    const ITERATIONS: usize = if cfg!(loom) { 1 } else { 100 };
    const SEND_BATCH_SIZE: usize = if cfg!(loom) { 2 } else { 8 };
    const RECV_BATCH_SIZE: usize = if cfg!(loom) { 2 } else { 8 };

    /// No submissions: the receiver must observe closure with a zero total.
    #[test]
    fn loom_no_items() {
        loom_scenario(0, 1, 1);
    }

    /// One credit per submit, one credit released per `finish`.
    #[test]
    fn loom_single_item() {
        loom_scenario(ITERATIONS, 1, 1);
    }

    /// Batched submits, single-credit releases.
    #[test]
    fn loom_send_batch() {
        loom_scenario(ITERATIONS, SEND_BATCH_SIZE, 1);
    }

    /// Single-credit submits, batched releases.
    #[test]
    fn loom_recv_batch() {
        loom_scenario(ITERATIONS, 1, RECV_BATCH_SIZE);
    }

    /// Batched on both sides.
    #[test]
    fn loom_both_batch() {
        loom_scenario(ITERATIONS, SEND_BATCH_SIZE, RECV_BATCH_SIZE);
    }
}