use crate::{control::*, error::*};
use derivative::Derivative;
use std::sync::atomic::*;
use std::{mem::ManuallyDrop, num::NonZeroUsize};
#[cfg(feature = "futures_api")]
use futures::{executor::*, sink::Sink, stream::*, task::*};
#[cfg(feature = "futures_api")]
use std::pin::Pin;
/// The sending half of a ring channel, created by [`ring_channel`].
///
/// Messages go into a bounded buffer shared with every [`RingReceiver`] of the same channel.
/// Once the buffer is full, new messages overwrite the oldest ones instead of blocking.
/// Senders can be cloned; two senders compare equal iff they belong to the same channel.
#[derive(Derivative, Eq, PartialEq)]
#[derivative(Debug(bound = ""))]
pub struct RingSender<T> {
    // Every endpoint holds its own copy of the shared control block. The copy is wrapped in
    // `ManuallyDrop` because only the last endpoint to disconnect actually drops it
    // (see `Drop for RingSender` / `Drop for RingReceiver`).
    #[derivative(Debug = "ignore")]
    handle: ManuallyDrop<ControlBlockRef<T>>,
}
// SAFETY: NOTE(review): presumably `ControlBlockRef` wraps a raw pointer and is therefore not
// auto-`Send`. Moving a sender to another thread only lets values of `T` cross threads through
// the shared buffer, hence the `T: Send` bound. No `Sync` impl is provided; `send` takes
// `&mut self`. Confirm against the definition of `ControlBlockRef`.
unsafe impl<T: Send> Send for RingSender<T> {}
impl<T> RingSender<T> {
fn new(handle: ManuallyDrop<ControlBlockRef<T>>) -> Self {
Self { handle }
}
pub fn send(&mut self, message: T) -> Result<(), SendError<T>> {
if self.handle.receivers.load(Ordering::Relaxed) > 0 {
self.handle.buffer.push(message);
#[cfg(feature = "futures_api")]
fence(Ordering::SeqCst);
#[cfg(feature = "futures_api")]
self.handle.waitlist.wake();
Ok(())
} else {
Err(SendError::Disconnected(message))
}
}
}
impl<T> Clone for RingSender<T> {
fn clone(&self) -> Self {
self.handle.senders.fetch_add(1, Ordering::Relaxed);
RingSender::new(self.handle.clone())
}
}
impl<T> Drop for RingSender<T> {
    /// Unregisters this sender.
    ///
    /// The last sender to go wakes any waiting receivers. If the receiving side is already gone
    /// as well, it also releases the shared control block.
    fn drop(&mut self) {
        // Only the sender that takes the counter from 1 to 0 performs the disconnect.
        if self.handle.senders.fetch_sub(1, Ordering::AcqRel) == 1 {
            // Wake tasks parked in `poll_next` so they re-run `try_recv`, which reports
            // `Disconnected` once `senders` reaches 0 and the buffer is empty.
            #[cfg(feature = "futures_api")]
            fence(Ordering::SeqCst);
            #[cfg(feature = "futures_api")]
            self.handle.waitlist.wake();
            // The last sender and the last receiver both clear `connected`. Whichever of them
            // finds it already cleared is the final endpoint alive and frees the control block.
            if !self.handle.connected.swap(false, Ordering::AcqRel) {
                // SAFETY: the swap above returns `false` to exactly one of the two final
                // endpoints, so the control block is dropped once. `self.handle` is not used
                // again after this point.
                unsafe { ManuallyDrop::drop(&mut self.handle) }
            }
        }
    }
}
#[cfg(feature = "futures_api")]
impl<T> Sink<T> for RingSender<T> {
    type Error = SendError<T>;

    /// Always ready: a ring channel never exerts backpressure, because a full buffer makes room
    /// by overwriting its oldest message.
    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Forwards to [`RingSender::send`].
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        RingSender::send(self.get_mut(), item)
    }

    /// Nothing is ever held back locally (`start_send` pushes straight into the shared
    /// buffer), so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<T>>::poll_ready(self, ctx)
    }

    /// Closing completes immediately; the actual disconnect happens when the sender is dropped.
    fn poll_close(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<T>>::poll_ready(self, ctx)
    }
}
/// The receiving half of a ring channel, created by [`ring_channel`].
///
/// All receivers of a channel take messages from the same shared buffer, so each message is
/// handed to at most one receiver. Receivers can be cloned; two receivers compare equal iff
/// they belong to the same channel.
#[derive(Derivative, Eq, PartialEq)]
#[derivative(Debug(bound = ""))]
pub struct RingReceiver<T> {
    // Every endpoint holds its own copy of the shared control block. The copy is wrapped in
    // `ManuallyDrop` because only the last endpoint to disconnect actually drops it
    // (see `Drop for RingSender` / `Drop for RingReceiver`).
    #[derivative(Debug = "ignore")]
    handle: ManuallyDrop<ControlBlockRef<T>>,
}
// SAFETY: NOTE(review): presumably `ControlBlockRef` wraps a raw pointer and is therefore not
// auto-`Send`. Moving a receiver to another thread only lets values of `T` cross threads
// through the shared buffer, hence the `T: Send` bound. No `Sync` impl is provided;
// `try_recv` takes `&mut self`. Confirm against the definition of `ControlBlockRef`.
unsafe impl<T: Send> Send for RingReceiver<T> {}
impl<T> RingReceiver<T> {
    /// Wraps a copy of the shared control block.
    ///
    /// This does not touch the `receivers` counter; callers (`clone`, `ring_channel`) are
    /// responsible for accounting for the new endpoint.
    fn new(handle: ManuallyDrop<ControlBlockRef<T>>) -> Self {
        Self { handle }
    }

    /// Blocks the current thread until a message arrives or the channel disconnects.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Disconnected`] once the buffer is empty and every sender is gone.
    #[cfg(feature = "futures_api")]
    pub fn recv(&mut self) -> Result<T, RecvError> {
        block_on(self.next()).ok_or(RecvError::Disconnected)
    }

    /// Takes the oldest buffered message, if any, without blocking.
    ///
    /// # Errors
    ///
    /// * [`TryRecvError::Empty`] if the buffer is empty but at least one sender is alive.
    /// * [`TryRecvError::Disconnected`] if the buffer is empty and every sender is gone.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        if let Some(message) = self.handle.buffer.pop() {
            return Ok(message);
        }

        if self.handle.senders.load(Ordering::Acquire) > 0 {
            return Err(TryRecvError::Empty);
        }

        // Every sender is gone, but the last one may have pushed a final message after the pop
        // above and only then dropped. The `Acquire` load synchronizes with the senders'
        // `AcqRel` decrements, so any such message is visible now. Look once more before
        // reporting a disconnect; otherwise the message would be stranded and `poll_next`
        // would end the stream early.
        self.handle.buffer.pop().ok_or(TryRecvError::Disconnected)
    }
}
impl<T> Clone for RingReceiver<T> {
fn clone(&self) -> Self {
self.handle.receivers.fetch_add(1, Ordering::Relaxed);
RingReceiver::new(self.handle.clone())
}
}
impl<T> Drop for RingReceiver<T> {
    /// Unregisters this receiver.
    ///
    /// If it was the last receiver and the sending side is already gone as well, it also
    /// releases the shared control block.
    fn drop(&mut self) {
        let was_last_receiver = self.handle.receivers.fetch_sub(1, Ordering::AcqRel) == 1;
        // `&&` short-circuits, so only the last receiver clears `connected`. Whichever of the
        // last sender / last receiver finds it already cleared is the final endpoint alive.
        if was_last_receiver && !self.handle.connected.swap(false, Ordering::AcqRel) {
            // SAFETY: the swap above returns `false` to exactly one of the two final endpoints,
            // so the control block is dropped once. `self.handle` is not used again afterwards.
            unsafe { ManuallyDrop::drop(&mut self.handle) }
        }
    }
}
#[cfg(feature = "futures_api")]
impl<T> Stream for RingReceiver<T> {
    type Item = T;

    /// Yields the next message, ends with `None` once the channel is empty and disconnected,
    /// and otherwise parks the task until a sender wakes it.
    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Fast path: a message or a disconnect settles the poll right away.
        match self.try_recv() {
            Err(TryRecvError::Empty) => {}
            settled => return Poll::Ready(settled.ok()),
        }

        // Register the waker *before* checking again. A send landing between the first attempt
        // and the registration is then caught by the second attempt rather than missed. The
        // fence presumably pairs with the one issued by the sender before waking.
        self.handle.waitlist.wait(ctx.waker().clone());
        fence(Ordering::SeqCst);

        match self.try_recv() {
            Err(TryRecvError::Empty) => Poll::Pending,
            settled => Poll::Ready(settled.ok()),
        }
    }
}
/// Creates a ring channel whose shared buffer holds at most `capacity` messages.
///
/// Returns one sender and one receiver, both of which can be cloned. Once the buffer is full,
/// each new message overwrites the oldest one instead of blocking the sender.
pub fn ring_channel<T>(capacity: NonZeroUsize) -> (RingSender<T>, RingReceiver<T>) {
    let shared = ManuallyDrop::new(ControlBlockRef::new(capacity.get()));
    let sender = RingSender::new(ManuallyDrop::clone(&shared));
    let receiver = RingReceiver::new(shared);
    (sender, receiver)
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::{collection::*, prelude::*};
    use rayon::{iter::repeatn, prelude::*};
    use std::{cmp::min, iter};

    #[cfg(feature = "futures_api")]
    use futures::{prelude::*, stream};

    // ---- Endpoint identity and bookkeeping -------------------------------------------------

    #[test]
    fn ring_channel_is_associated_with_a_single_control_block() {
        let (s, r) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        assert_eq!(s.handle, r.handle);
    }

    #[test]
    fn senders_are_equal_if_they_are_associated_with_the_same_ring_channel() {
        let (s1, _) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        let (s2, _) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        assert_eq!(s1, s1.clone());
        assert_eq!(s2, s2.clone());
        assert_ne!(s1, s2);
    }

    #[test]
    fn receivers_are_equal_if_they_are_associated_with_the_same_ring_channel() {
        let (_, r1) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        let (_, r2) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        assert_eq!(r1, r1.clone());
        assert_eq!(r2, r2.clone());
        assert_ne!(r1, r2);
    }

    #[test]
    fn cloning_sender_increments_senders_counter() {
        let (s, _r) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        // The clone is the behavior under test, even though `s` is not used afterwards.
        #[allow(clippy::redundant_clone)]
        let x = s.clone();
        assert_eq!(x.handle.senders.load(Ordering::Relaxed), 2);
        assert_eq!(x.handle.receivers.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn cloning_receiver_increments_receivers_counter() {
        let (_s, r) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        // The clone is the behavior under test, even though `r` is not used afterwards.
        #[allow(clippy::redundant_clone)]
        let x = r.clone();
        assert_eq!(x.handle.senders.load(Ordering::Relaxed), 1);
        assert_eq!(x.handle.receivers.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn dropping_sender_decrements_senders_counter() {
        // Binding to `_` drops the sender immediately.
        let (_, r) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        assert_eq!(r.handle.senders.load(Ordering::Relaxed), 0);
        assert_eq!(r.handle.receivers.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dropping_receiver_decrements_receivers_counter() {
        // Binding to `_` drops the receiver immediately.
        let (s, _) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        assert_eq!(s.handle.senders.load(Ordering::Relaxed), 1);
        assert_eq!(s.handle.receivers.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn channel_is_disconnected_if_there_are_no_senders() {
        let (_, r) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        assert_eq!(r.handle.senders.load(Ordering::Relaxed), 0);
        assert!(!r.handle.connected.load(Ordering::Relaxed));
    }

    #[test]
    fn channel_is_disconnected_if_there_are_no_receivers() {
        let (s, _) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
        assert_eq!(s.handle.receivers.load(Ordering::Relaxed), 0);
        assert!(!s.handle.connected.load(Ordering::Relaxed));
    }

    // ---- Property tests for the synchronous API --------------------------------------------

    proptest! {
        #[test]
        fn endpoints_are_safe_to_send_across_threads(m in 1..=100usize, n in 1..=100usize) {
            #[derive(Clone)]
            enum Endpoint<T> {
                Sender(RingSender<T>),
                Receiver(RingReceiver<T>),
            }

            let (s, r) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
            let ls = repeatn(Endpoint::Sender(s), m);
            let rs = repeatn(Endpoint::Receiver(r), n);
            // Clones are created and dropped on rayon worker threads.
            ls.chain(rs).for_each(drop);
        }

        #[test]
        fn send_succeeds_on_connected_channel(cap in 1..=100usize, msgs in vec("[a-z]", 1..=100)) {
            let (s, _r) = ring_channel(NonZeroUsize::new(cap).unwrap());
            repeatn(s, msgs.len()).zip(msgs.par_iter().cloned()).for_each(|(mut c, msg)| {
                assert_eq!(c.send(msg), Ok(()));
            });
        }

        #[test]
        fn send_fails_on_disconnected_channel(cap in 1..=100usize, msgs in vec("[a-z]", 1..=100)) {
            let (s, _) = ring_channel(NonZeroUsize::new(cap).unwrap());
            repeatn(s, msgs.len()).zip(msgs.par_iter().cloned()).for_each(|(mut c, msg)| {
                assert_eq!(c.send(msg.clone()), Err(SendError::Disconnected(msg)));
            });
        }

        #[test]
        fn send_overwrites_old_messages(cap in 1..=100usize, mut msgs in vec("[a-z]", 1..=100)) {
            let (mut s, r) = ring_channel(NonZeroUsize::new(cap).unwrap());
            // Only the last `cap` messages survive; the rest are overwritten.
            let overwritten = msgs.len() - min(msgs.len(), cap);

            for msg in msgs.iter().cloned() {
                assert_eq!(s.send(msg), Ok(()));
            }

            assert_eq!(
                iter::from_fn(move || r.handle.buffer.pop()).collect::<Vec<_>>(),
                msgs.drain(..).skip(overwritten).collect::<Vec<_>>()
            );
        }

        #[cfg(feature = "futures_api")]
        #[test]
        fn recv_succeeds_on_non_empty_connected_channel(msgs in vec("[a-z]", 1..=100)) {
            let (s, r) = ring_channel(NonZeroUsize::new(msgs.len()).unwrap());

            // Tag each message with its index so the concurrently received ones can be reordered.
            for msg in msgs.iter().cloned().enumerate() {
                s.handle.buffer.push(msg);
            }

            let mut received = vec![(0usize, Default::default()); msgs.len()];
            repeatn(r, msgs.len()).zip(received.par_iter_mut()).for_each(|(mut c, slot)| {
                *slot = c.recv().unwrap();
            });

            received.sort_by_key(|(k, _)| *k);
            assert_eq!(received.drain(..).map(|(_, v)| v).collect::<Vec<_>>(), msgs);
        }

        #[cfg(feature = "futures_api")]
        #[test]
        fn recv_succeeds_on_non_empty_disconnected_channel(msgs in vec("[a-z]", 1..=100)) {
            let (_, r) = ring_channel(NonZeroUsize::new(msgs.len()).unwrap());

            // Buffered messages must still be delivered after every sender is gone.
            for msg in msgs.iter().cloned().enumerate() {
                r.handle.buffer.push(msg);
            }

            let mut received = vec![(0usize, Default::default()); msgs.len()];
            repeatn(r, msgs.len()).zip(received.par_iter_mut()).for_each(|(mut c, slot)| {
                *slot = c.recv().unwrap();
            });

            received.sort_by_key(|(k, _)| *k);
            assert_eq!(received.drain(..).map(|(_, v)| v).collect::<Vec<_>>(), msgs);
        }

        #[cfg(feature = "futures_api")]
        #[test]
        fn recv_fails_on_empty_disconnected_channel(cap in 1..=100usize, n in 1..=100usize) {
            let (_, r) = ring_channel::<()>(NonZeroUsize::new(cap).unwrap());
            repeatn(r, n).for_each(move |mut r| {
                assert_eq!(r.recv(), Err(RecvError::Disconnected));
            });
        }

        #[test]
        fn try_recv_succeeds_on_non_empty_connected_channel(msgs in vec("[a-z]", 1..=100)) {
            let (s, r) = ring_channel(NonZeroUsize::new(msgs.len()).unwrap());

            // Tag each message with its index so the concurrently received ones can be reordered.
            for msg in msgs.iter().cloned().enumerate() {
                s.handle.buffer.push(msg);
            }

            let mut received = vec![(0usize, Default::default()); msgs.len()];
            repeatn(r, msgs.len()).zip(received.par_iter_mut()).for_each(|(mut c, slot)| {
                *slot = c.try_recv().unwrap();
            });

            received.sort_by_key(|(k, _)| *k);
            assert_eq!(received.drain(..).map(|(_, v)| v).collect::<Vec<_>>(), msgs);
        }

        #[test]
        fn try_recv_succeeds_on_non_empty_disconnected_channel(msgs in vec("[a-z]", 1..=100)) {
            let (_, r) = ring_channel(NonZeroUsize::new(msgs.len()).unwrap());

            // Buffered messages must still be delivered after every sender is gone.
            for msg in msgs.iter().cloned().enumerate() {
                r.handle.buffer.push(msg);
            }

            let mut received = vec![(0usize, Default::default()); msgs.len()];
            repeatn(r, msgs.len()).zip(received.par_iter_mut()).for_each(|(mut c, slot)| {
                *slot = c.try_recv().unwrap();
            });

            received.sort_by_key(|(k, _)| *k);
            assert_eq!(received.drain(..).map(|(_, v)| v).collect::<Vec<_>>(), msgs);
        }

        #[test]
        fn try_recv_fails_on_empty_connected_channel(cap in 1..=100usize, n in 1..=100usize) {
            let (_s, r) = ring_channel::<()>(NonZeroUsize::new(cap).unwrap());
            repeatn(r, n).for_each(|mut r| {
                assert_eq!(r.try_recv(), Err(TryRecvError::Empty));
            });
        }

        #[test]
        fn try_recv_fails_on_empty_disconnected_channel(cap in 1..=100usize, n in 1..=100usize) {
            let (_, r) = ring_channel::<()>(NonZeroUsize::new(cap).unwrap());
            repeatn(r, n).for_each(move |mut r| {
                assert_eq!(r.try_recv(), Err(TryRecvError::Disconnected));
            });
        }
    }

    // ---- Property tests for the `futures` integration (Sink / Stream / wakeups) ------------

    #[cfg(feature = "futures_api")]
    proptest! {
        #[test]
        fn sink(cap in 1..=100usize, mut msgs in vec_deque("[a-z]", 1..=100)) {
            let (mut tx, mut rx) = ring_channel(NonZeroUsize::new(cap).unwrap());
            let overwritten = msgs.len() - min(msgs.len(), cap);

            assert_eq!(block_on(tx.send_all(&mut iter(msgs.clone()).map(Ok))), Ok(()));
            drop(tx);

            assert_eq!(
                iter::from_fn(move || rx.try_recv().ok()).collect::<Vec<_>>(),
                msgs.drain(..).skip(overwritten).collect::<Vec<_>>()
            );
        }

        #[test]
        fn stream(cap in 1..=100usize, mut msgs in vec_deque("[a-z]", 1..=100)) {
            let (mut tx, rx) = ring_channel(NonZeroUsize::new(cap).unwrap());
            let overwritten = msgs.len() - min(msgs.len(), cap);

            for msg in msgs.iter().cloned() {
                assert_eq!(tx.send(msg), Ok(()));
            }

            // Dropping the only sender lets the stream terminate once drained.
            drop(tx);

            assert_eq!(
                block_on_stream(rx).collect::<Vec<_>>(),
                msgs.drain(..).skip(overwritten).collect::<Vec<_>>()
            );
        }

        #[test]
        fn stream_wakes_on_disconnect(n in 1..=100usize) {
            let (tx, rx) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
            rayon::scope(move |s| {
                for _ in 0..n {
                    let rx = rx.clone();
                    s.spawn(move |_| assert_eq!(block_on_stream(rx).collect::<Vec<_>>(), vec![]));
                }

                s.spawn(move |_| drop(tx));
            });
        }

        #[test]
        fn stream_wakes_on_send(n in 1..=100usize) {
            let (tx, rx) = ring_channel(NonZeroUsize::new(n).unwrap());
            rayon::scope(move |s| {
                // Each receiver keeps a sender alive until it has its message, so the channel
                // cannot disconnect underneath the other receivers.
                for _ in 0..n {
                    let tx = tx.clone();
                    let mut rx = rx.clone();
                    s.spawn(move |_| {
                        assert_eq!(block_on(rx.next()), Some(42));
                        drop(tx);
                    });
                }

                for _ in 0..n {
                    let mut tx = tx.clone();
                    s.spawn(move |_| assert_eq!(tx.send(42), Ok(())));
                }
            });
        }

        #[test]
        fn stream_wakes_on_send_all(n in 1..=100usize) {
            let (mut tx, rx) = ring_channel(NonZeroUsize::new(n).unwrap());
            rayon::scope(move |s| {
                // Each receiver keeps a sender alive until it has its message, so the channel
                // cannot disconnect underneath the other receivers.
                for _ in 0..n {
                    let tx = tx.clone();
                    let mut rx = rx.clone();
                    s.spawn(move |_| {
                        assert_eq!(block_on(rx.next()), Some(42));
                        drop(tx);
                    });
                }

                let mut msgs = stream::iter(vec![Ok(42); n]);
                s.spawn(move |_| assert_eq!(block_on(tx.send_all(&mut msgs)), Ok(())));
            });
        }

        #[test]
        fn recv_wakes_on_disconnect(n in 1..=100usize) {
            let (tx, rx) = ring_channel::<()>(NonZeroUsize::new(1).unwrap());
            rayon::scope(move |s| {
                for _ in 0..n {
                    let mut rx = rx.clone();
                    s.spawn(move |_| assert_eq!(rx.recv(), Err(RecvError::Disconnected)));
                }

                s.spawn(move |_| drop(tx));
            });
        }

        #[test]
        fn recv_wakes_on_send(n in 1..=100usize) {
            let (tx, rx) = ring_channel(NonZeroUsize::new(n).unwrap());
            rayon::scope(move |s| {
                // Each receiver keeps a sender alive until it has its message, so the channel
                // cannot disconnect underneath the other receivers.
                for _ in 0..n {
                    let tx = tx.clone();
                    let mut rx = rx.clone();
                    s.spawn(move |_| {
                        assert_eq!(rx.recv(), Ok(42));
                        drop(tx);
                    });
                }

                for _ in 0..n {
                    let mut tx = tx.clone();
                    s.spawn(move |_| assert_eq!(tx.send(42), Ok(())));
                }
            });
        }

        #[test]
        fn recv_wakes_on_send_all(n in 1..=100usize) {
            let (mut tx, rx) = ring_channel(NonZeroUsize::new(n).unwrap());
            rayon::scope(move |s| {
                // Each receiver keeps a sender alive until it has its message, so the channel
                // cannot disconnect underneath the other receivers.
                for _ in 0..n {
                    let tx = tx.clone();
                    let mut rx = rx.clone();
                    s.spawn(move |_| {
                        assert_eq!(rx.recv(), Ok(42));
                        drop(tx);
                    });
                }

                let mut msgs = stream::iter(vec![Ok(42); n]);
                s.spawn(move |_| assert_eq!(block_on(tx.send_all(&mut msgs)), Ok(())));
            });
        }
    }
}