use crate::error::{
CloseError, RecvError, RecvErrorTimeout, SendError, TryRecvError, TrySendError,
};
pub use async_impl::{RecvFuture, SendFuture};
mod async_impl;
mod backoff;
mod core;
mod sync_impl;
use self::core::{MpmcShared, STATE_CANCELLED, STATE_SUCCESS, STATE_WAITING};
use ::core::mem;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::sync::Arc;
/// Synchronous sending half of an MPMC channel.
///
/// Every handle owns one slot in the shared `sender_count`; `close()` (also invoked
/// from `Drop`) releases that slot at most once, guarded by the per-handle `closed` flag.
#[derive(Debug)]
pub struct Sender<T: Send> {
    // Channel state shared by all senders and receivers of this channel.
    shared: Arc<MpmcShared<T>>,
    // Set once this handle has been closed, so its `sender_count` slot is released only once.
    closed: AtomicBool,
}
/// Synchronous receiving half of an MPMC channel.
///
/// Every handle owns one slot in the shared `receiver_count`; `close()` (also invoked
/// from `Drop`) releases that slot at most once, guarded by the per-handle `closed` flag.
#[derive(Debug)]
pub struct Receiver<T: Send> {
    // Channel state shared by all senders and receivers of this channel.
    shared: Arc<MpmcShared<T>>,
    // Set once this handle has been closed, so its `receiver_count` slot is released only once.
    closed: AtomicBool,
}
/// Asynchronous (future-returning) sending half of an MPMC channel.
///
/// Shares `sender_count` with synchronous [`Sender`]s of the same channel; `close()`
/// (also invoked from `Drop`) releases this handle's slot at most once.
#[derive(Debug)]
pub struct AsyncSender<T: Send> {
    // Channel state shared by all senders and receivers of this channel.
    shared: Arc<MpmcShared<T>>,
    // Set once this handle has been closed, so its `sender_count` slot is released only once.
    closed: AtomicBool,
}
/// Asynchronous (future-returning) receiving half of an MPMC channel.
///
/// Shares `receiver_count` with synchronous [`Receiver`]s of the same channel; `close()`
/// (also invoked from `Drop`) releases this handle's slot at most once.
#[derive(Debug)]
pub struct AsyncReceiver<T: Send> {
    // Channel state shared by all senders and receivers of this channel.
    shared: Arc<MpmcShared<T>>,
    // Set once this handle has been closed, so its `receiver_count` slot is released only once.
    closed: AtomicBool,
    // Waiter state cell (starts at STATE_WAITING). Its address identifies this handle's
    // entry in `waiting_async_receivers` when `to_sync`/`Drop` deregister it.
    pub(super) state: AtomicU8,
    // Whether `state` is currently registered in `waiting_async_receivers`.
    // NOTE(review): nothing in this file sets this to `true` (it is initialised to `false`
    // everywhere and only read); presumably set elsewhere — confirm it is still used.
    pub(super) is_registered: bool,
}
/// Creates a synchronous channel with the given `capacity` and returns its first
/// sender/receiver pair.
///
/// Both handles start open and point at the same shared state. A capacity of
/// `usize::MAX` is treated as "unbounded" by `capacity()` and `is_full()`.
pub fn bounded<T: Send>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    let state = Arc::new(MpmcShared::new(capacity));
    let tx = Sender {
        shared: Arc::clone(&state),
        closed: AtomicBool::new(false),
    };
    let rx = Receiver {
        shared: state,
        closed: AtomicBool::new(false),
    };
    (tx, rx)
}
/// Creates a synchronous channel with no capacity limit.
///
/// This is [`bounded`] with the `usize::MAX` sentinel, which `capacity()` reports as
/// `None` and `is_full()` never treats as full.
pub fn unbounded<T: Send>() -> (Sender<T>, Receiver<T>) {
    const UNLIMITED: usize = usize::MAX;
    bounded::<T>(UNLIMITED)
}
/// Creates an asynchronous channel with the given `capacity` and returns its first
/// sender/receiver pair.
///
/// Both handles start open; the receiver starts in `STATE_WAITING` and unregistered.
pub fn bounded_async<T: Send>(capacity: usize) -> (AsyncSender<T>, AsyncReceiver<T>) {
    let state = Arc::new(MpmcShared::new(capacity));
    let tx = AsyncSender {
        shared: Arc::clone(&state),
        closed: AtomicBool::new(false),
    };
    let rx = AsyncReceiver {
        shared: state,
        closed: AtomicBool::new(false),
        state: AtomicU8::new(STATE_WAITING),
        is_registered: false,
    };
    (tx, rx)
}
/// Creates an asynchronous channel with no capacity limit.
///
/// This is [`bounded_async`] with the `usize::MAX` "unbounded" sentinel.
pub fn unbounded_async<T: Send>() -> (AsyncSender<T>, AsyncReceiver<T>) {
    const UNLIMITED: usize = usize::MAX;
    bounded_async::<T>(UNLIMITED)
}
impl<T: Send> Clone for Sender<T> {
fn clone(&self) -> Self {
self.shared.internal.lock().sender_count += 1;
Sender {
shared: Arc::clone(&self.shared),
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.shared.internal.lock().receiver_count += 1;
Receiver {
shared: Arc::clone(&self.shared),
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Clone for AsyncSender<T> {
fn clone(&self) -> Self {
self.shared.internal.lock().sender_count += 1;
AsyncSender {
shared: Arc::clone(&self.shared),
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Clone for AsyncReceiver<T> {
    /// Registers one more receiver in the shared `receiver_count` and returns a fresh,
    /// open async handle. Waiter state is not copied: the clone starts in
    /// `STATE_WAITING` and unregistered.
    fn clone(&self) -> Self {
        {
            let mut inner = self.shared.internal.lock();
            inner.receiver_count += 1;
        }
        Self {
            shared: Arc::clone(&self.shared),
            closed: AtomicBool::new(false),
            state: AtomicU8::new(STATE_WAITING),
            is_registered: false,
        }
    }
}
impl<T: Send> Sender<T> {
    /// Sends `item` through the synchronous send path (`sync_impl::send_sync`).
    ///
    /// # Errors
    /// Returns [`SendError::Closed`] without touching the channel if this handle has
    /// already been closed; otherwise returns whatever `send_sync` reports.
    pub fn send(&self, item: T) -> Result<(), SendError> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(SendError::Closed);
        }
        sync_impl::send_sync(self, item)
    }

    /// Attempts to send `item` via the shared `try_send_core`.
    ///
    /// # Errors
    /// Returns [`TrySendError::Closed`] carrying `item` back if this handle has been
    /// closed; otherwise whatever `try_send_core` reports.
    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(TrySendError::Closed(item));
        }
        self.shared.try_send_core(item)
    }

    /// Closes this handle, releasing its slot in the shared `sender_count`.
    ///
    /// Only the first call has an effect: the `closed` compare-exchange guarantees the
    /// count is decremented at most once per handle. `Drop` also calls this.
    ///
    /// # Errors
    /// Returns [`CloseError`] if this handle was already closed.
    pub fn close(&self) -> Result<(), CloseError> {
        if self
            .closed
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok()
        {
            self.close_internal();
            Ok(())
        } else {
            Err(CloseError)
        }
    }

    /// Decrements `sender_count`. If this was the last sender, takes every waiting sync
    /// and async receiver out of the shared lists and wakes them so they can re-check the
    /// channel state. Waiters are collected under the lock but woken after releasing it.
    fn close_internal(&self) {
        let sync_waiters;
        let async_waiters;
        {
            let mut guard = self.shared.internal.lock();
            guard.sender_count -= 1;
            if guard.sender_count == 0 {
                sync_waiters = std::mem::take(&mut guard.waiting_sync_receivers);
                async_waiters = std::mem::take(&mut guard.waiting_async_receivers);
            } else {
                return;
            }
        }
        for waiter in sync_waiters {
            // SAFETY: `waiter.state` points at the waiting receiver's state cell.
            // NOTE(review): this is dereferenced after the lock is released, so it relies on
            // the waiter keeping that cell alive until its state leaves STATE_WAITING —
            // confirm against sync_impl's timeout/cancellation path.
            if unsafe { &*waiter.state }
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                // Only the party that wins WAITING -> SUCCESS wakes the thread.
                waiter.thread.unpark();
            }
        }
        for waiter in async_waiters {
            // SAFETY: same invariant as above, for async waiters.
            if unsafe { &*waiter.state }
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                waiter.waker.wake();
            }
        }
    }

    /// Returns `true` once no receivers remain (`receiver_count == 0`).
    pub fn is_closed(&self) -> bool {
        self.shared.internal.lock().receiver_count == 0
    }

    /// Returns the channel capacity, or `None` for an unbounded (`usize::MAX`) channel.
    pub fn capacity(&self) -> Option<usize> {
        if self.shared.capacity == usize::MAX {
            None
        } else {
            Some(self.shared.capacity)
        }
    }

    /// Converts this handle into an [`AsyncSender`] on the same channel.
    ///
    /// The `sender_count` slot held by `self` is transferred as-is; no count changes.
    /// The `closed` flag is carried over. A closed handle has already released its slot,
    /// so the converted handle must also be closed; otherwise its `Drop` would decrement
    /// `sender_count` a second time (underflow, or a false disconnect of live senders).
    pub fn to_async(self) -> AsyncSender<T> {
        let closed = self.closed.load(Ordering::Relaxed);
        // SAFETY: `self` is forgotten right below, so its `Drop` never runs and the `Arc`
        // is moved out exactly once — the reference count is neither leaked nor duplicated.
        let shared = unsafe { std::ptr::read(&self.shared) };
        mem::forget(self);
        AsyncSender {
            shared,
            closed: AtomicBool::new(closed),
        }
    }

    /// Number of items currently buffered in the channel.
    #[inline]
    pub fn len(&self) -> usize {
        self.shared.internal.lock().queue.len()
    }

    /// Returns `true` when no items are buffered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` when the buffer holds `capacity` items; never `true` for an
    /// unbounded channel.
    #[inline]
    pub fn is_full(&self) -> bool {
        if self.shared.capacity == usize::MAX {
            false
        } else {
            self.len() == self.shared.capacity
        }
    }
}
impl<T: Send> Drop for Sender<T> {
fn drop(&mut self) {
let _ = self.close();
}
}
impl<T: Send> Receiver<T> {
    /// Receives an item through the synchronous receive path (`sync_impl::recv_sync`).
    ///
    /// # Errors
    /// Returns [`RecvError::Disconnected`] immediately if this handle has been closed;
    /// otherwise whatever `recv_sync` reports.
    pub fn recv(&self) -> Result<T, RecvError> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(RecvError::Disconnected);
        }
        sync_impl::recv_sync(self)
    }

    /// Attempts to receive an item via the shared `try_recv_core`.
    ///
    /// # Errors
    /// Returns [`TryRecvError::Disconnected`] if this handle has been closed; otherwise
    /// whatever `try_recv_core` reports.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(TryRecvError::Disconnected);
        }
        self.shared.try_recv_core()
    }

    /// Receives an item, giving up after `timeout` (`sync_impl::recv_timeout_sync`).
    ///
    /// NOTE(review): unlike `recv`/`try_recv`, this does not short-circuit on this
    /// handle's `closed` flag — confirm whether `recv_timeout_sync` handles that.
    pub fn recv_timeout(&self, timeout: std::time::Duration) -> Result<T, RecvErrorTimeout> {
        sync_impl::recv_timeout_sync(self, timeout)
    }

    /// Closes this handle, releasing its slot in the shared `receiver_count`.
    ///
    /// Only the first call has an effect (guarded by the `closed` compare-exchange).
    /// `Drop` also calls this.
    ///
    /// # Errors
    /// Returns [`CloseError`] if this handle was already closed.
    pub fn close(&self) -> Result<(), CloseError> {
        if self
            .closed
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok()
        {
            self.close_internal();
            Ok(())
        } else {
            Err(CloseError)
        }
    }

    /// Decrements `receiver_count` and wakes waiting senders.
    ///
    /// If this was the last receiver, every waiting sync and async sender is woken.
    /// Otherwise at most one waiting sync sender and one waiting async sender (the front
    /// of each list) are woken — presumably so they re-check the channel now that this
    /// receiver is gone; confirm intent. Waiters are woken after the lock is released.
    fn close_internal(&self) {
        let sync_waiters;
        let async_waiters;
        {
            let mut guard = self.shared.internal.lock();
            guard.receiver_count -= 1;
            if guard.receiver_count == 0 {
                sync_waiters = std::mem::take(&mut guard.waiting_sync_senders);
                async_waiters = std::mem::take(&mut guard.waiting_async_senders);
            } else {
                sync_waiters = guard.waiting_sync_senders.pop_front().into_iter().collect();
                async_waiters = guard
                    .waiting_async_senders
                    .pop_front()
                    .into_iter()
                    .collect();
            }
        }
        for waiter in sync_waiters {
            // SAFETY: `waiter.state` points at the waiting sender's state cell.
            // NOTE(review): dereferenced after the lock is released; relies on the waiter
            // keeping that cell alive until its state leaves STATE_WAITING — confirm.
            if unsafe { &*waiter.state }
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                // Only the party that wins WAITING -> SUCCESS wakes the thread.
                waiter.thread.unpark();
            }
        }
        for waiter in async_waiters {
            // SAFETY: same invariant as above, for async waiters.
            if unsafe { &*waiter.state }
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                waiter.waker.wake();
            }
        }
    }

    /// Returns `true` only when no senders remain, the queue is empty, and no sync or
    /// async senders are still waiting.
    pub fn is_closed(&self) -> bool {
        let guard = self.shared.internal.lock();
        guard.sender_count == 0
            && guard.queue.is_empty()
            && guard.waiting_sync_senders.is_empty()
            && guard.waiting_async_senders.is_empty()
    }

    /// Returns the channel capacity, or `None` for an unbounded (`usize::MAX`) channel.
    pub fn capacity(&self) -> Option<usize> {
        if self.shared.capacity == usize::MAX {
            None
        } else {
            Some(self.shared.capacity)
        }
    }

    /// Converts this handle into an [`AsyncReceiver`] on the same channel.
    ///
    /// The `receiver_count` slot held by `self` is transferred as-is. The `closed` flag
    /// is carried over. A closed handle has already released its slot, so the converted
    /// handle must stay closed; otherwise its `Drop` would decrement `receiver_count` twice.
    pub fn to_async(self) -> AsyncReceiver<T> {
        let closed = self.closed.load(Ordering::Relaxed);
        // SAFETY: `self` is forgotten right below, so its `Drop` never runs and the `Arc`
        // is moved out exactly once — the reference count is neither leaked nor duplicated.
        let shared = unsafe { std::ptr::read(&self.shared) };
        mem::forget(self);
        AsyncReceiver {
            shared,
            closed: AtomicBool::new(closed),
            state: AtomicU8::new(STATE_WAITING),
            is_registered: false,
        }
    }

    /// Number of items currently buffered in the channel.
    #[inline]
    pub fn len(&self) -> usize {
        self.shared.internal.lock().queue.len()
    }

    /// Returns `true` when no items are buffered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` when the buffer holds `capacity` items; never `true` for an
    /// unbounded channel.
    #[inline]
    pub fn is_full(&self) -> bool {
        if self.shared.capacity == usize::MAX {
            false
        } else {
            self.len() == self.shared.capacity
        }
    }
}
impl<T: Send> Drop for Receiver<T> {
fn drop(&mut self) {
let _ = self.close();
}
}
impl<T: Send> AsyncSender<T> {
    /// Returns a [`SendFuture`] that sends `item` when polled.
    ///
    /// NOTE(review): unlike `try_send`, this does not check this handle's `closed` flag
    /// itself; presumably `SendFuture` does — confirm.
    pub fn send(&self, item: T) -> SendFuture<'_, T> {
        async_impl::SendFuture::new(self, item)
    }

    /// Attempts to send `item` via the shared `try_send_core`.
    ///
    /// # Errors
    /// Returns [`TrySendError::Closed`] carrying `item` back if this handle has been
    /// closed; otherwise whatever `try_send_core` reports.
    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(TrySendError::Closed(item));
        }
        self.shared.try_send_core(item)
    }

    /// Closes this handle, releasing its slot in the shared `sender_count`.
    ///
    /// Only the first call has an effect (guarded by the `closed` compare-exchange).
    /// `Drop` also calls this.
    ///
    /// # Errors
    /// Returns [`CloseError`] if this handle was already closed.
    pub fn close(&self) -> Result<(), CloseError> {
        if self
            .closed
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok()
        {
            self.close_internal();
            Ok(())
        } else {
            Err(CloseError)
        }
    }

    /// Decrements `sender_count`. If this was the last sender, takes every waiting sync
    /// and async receiver out of the shared lists and wakes them so they can re-check the
    /// channel state. Waiters are woken after the lock is released.
    fn close_internal(&self) {
        let sync_waiters;
        let async_waiters;
        {
            let mut guard = self.shared.internal.lock();
            guard.sender_count -= 1;
            if guard.sender_count == 0 {
                sync_waiters = std::mem::take(&mut guard.waiting_sync_receivers);
                async_waiters = std::mem::take(&mut guard.waiting_async_receivers);
            } else {
                return;
            }
        }
        for waiter in sync_waiters {
            // SAFETY: `waiter.state` points at the waiting receiver's state cell.
            // NOTE(review): dereferenced after the lock is released; relies on the waiter
            // keeping that cell alive until its state leaves STATE_WAITING — confirm.
            if unsafe { &*waiter.state }
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                // Only the party that wins WAITING -> SUCCESS wakes the thread.
                waiter.thread.unpark();
            }
        }
        for waiter in async_waiters {
            // SAFETY: same invariant as above, for async waiters.
            if unsafe { &*waiter.state }
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                waiter.waker.wake();
            }
        }
    }

    /// Returns `true` once no receivers remain (`receiver_count == 0`).
    pub fn is_closed(&self) -> bool {
        self.shared.internal.lock().receiver_count == 0
    }

    /// Returns the channel capacity, or `None` for an unbounded (`usize::MAX`) channel.
    pub fn capacity(&self) -> Option<usize> {
        if self.shared.capacity == usize::MAX {
            None
        } else {
            Some(self.shared.capacity)
        }
    }

    /// Converts this handle into a synchronous [`Sender`] on the same channel.
    ///
    /// The `sender_count` slot held by `self` is transferred as-is. The `closed` flag is
    /// carried over. A closed handle has already released its slot, so the converted
    /// handle must stay closed; otherwise its `Drop` would decrement `sender_count` twice.
    pub fn to_sync(self) -> Sender<T> {
        let closed = self.closed.load(Ordering::Relaxed);
        // SAFETY: `self` is forgotten right below, so its `Drop` never runs and the `Arc`
        // is moved out exactly once — the reference count is neither leaked nor duplicated.
        let shared = unsafe { std::ptr::read(&self.shared) };
        mem::forget(self);
        Sender {
            shared,
            closed: AtomicBool::new(closed),
        }
    }

    /// Number of items currently buffered in the channel.
    #[inline]
    pub fn len(&self) -> usize {
        self.shared.internal.lock().queue.len()
    }

    /// Returns `true` when no items are buffered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` when the buffer holds `capacity` items; never `true` for an
    /// unbounded channel.
    #[inline]
    pub fn is_full(&self) -> bool {
        if self.shared.capacity == usize::MAX {
            false
        } else {
            self.len() == self.shared.capacity
        }
    }
}
impl<T: Send> Drop for AsyncSender<T> {
fn drop(&mut self) {
let _ = self.close();
}
}
impl<T: Send> AsyncReceiver<T> {
    /// Returns a [`RecvFuture`] that resolves with the next item when polled.
    pub fn recv(&self) -> RecvFuture<'_, T> {
        async_impl::RecvFuture::new(self)
    }

    /// Attempts to receive an item via the shared `try_recv_core`.
    ///
    /// # Errors
    /// Returns [`TryRecvError::Disconnected`] if this handle has been closed; otherwise
    /// whatever `try_recv_core` reports.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(TryRecvError::Disconnected);
        }
        self.shared.try_recv_core()
    }

    /// Closes this handle, releasing its slot in the shared `receiver_count`.
    ///
    /// Only the first call has an effect (guarded by the `closed` compare-exchange).
    /// `Drop` also calls this.
    ///
    /// # Errors
    /// Returns [`CloseError`] if this handle was already closed.
    pub fn close(&self) -> Result<(), CloseError> {
        if self
            .closed
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok()
        {
            self.close_internal();
            Ok(())
        } else {
            Err(CloseError)
        }
    }

    /// Decrements `receiver_count` and wakes waiting senders.
    ///
    /// If this was the last receiver, every waiting sync and async sender is woken.
    /// Otherwise at most one waiting sender of each kind (front of each list) is woken —
    /// presumably so it re-checks the channel; confirm intent. Waking happens after the
    /// lock is released.
    fn close_internal(&self) {
        let sync_waiters;
        let async_waiters;
        {
            let mut guard = self.shared.internal.lock();
            guard.receiver_count -= 1;
            if guard.receiver_count == 0 {
                sync_waiters = std::mem::take(&mut guard.waiting_sync_senders);
                async_waiters = std::mem::take(&mut guard.waiting_async_senders);
            } else {
                sync_waiters = guard.waiting_sync_senders.pop_front().into_iter().collect();
                async_waiters = guard
                    .waiting_async_senders
                    .pop_front()
                    .into_iter()
                    .collect();
            }
        }
        for waiter in sync_waiters {
            // SAFETY: `waiter.state` points at the waiting sender's state cell.
            // NOTE(review): dereferenced after the lock is released; relies on the waiter
            // keeping that cell alive until its state leaves STATE_WAITING — confirm.
            if unsafe { &*waiter.state }
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                // Only the party that wins WAITING -> SUCCESS wakes the thread.
                waiter.thread.unpark();
            }
        }
        for waiter in async_waiters {
            // SAFETY: same invariant as above, for async waiters.
            if unsafe { &*waiter.state }
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                waiter.waker.wake();
            }
        }
    }

    /// Returns `true` only when no senders remain, the queue is empty, and no sync or
    /// async senders are still waiting.
    pub fn is_closed(&self) -> bool {
        let guard = self.shared.internal.lock();
        guard.sender_count == 0
            && guard.queue.is_empty()
            && guard.waiting_sync_senders.is_empty()
            && guard.waiting_async_senders.is_empty()
    }

    /// Returns the channel capacity, or `None` for an unbounded (`usize::MAX`) channel.
    pub fn capacity(&self) -> Option<usize> {
        if self.shared.capacity == usize::MAX {
            None
        } else {
            Some(self.shared.capacity)
        }
    }

    /// Converts this handle into a synchronous [`Receiver`] on the same channel.
    ///
    /// If this handle's `state` is registered and still WAITING, it is cancelled and its
    /// entry (matched by the address of `self.state`) is removed from
    /// `waiting_async_receivers` first, because that address dies with `self`.
    /// The `receiver_count` slot is transferred as-is, and the `closed` flag is carried
    /// over so a closed handle cannot decrement `receiver_count` a second time on drop.
    pub fn to_sync(self) -> Receiver<T> {
        if self.is_registered {
            let state_ptr = &self.state as *const AtomicU8;
            if self
                .state
                .compare_exchange(
                    STATE_WAITING,
                    STATE_CANCELLED,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                let mut guard = self.shared.internal.lock();
                guard
                    .waiting_async_receivers
                    .retain(|w| w.state != state_ptr);
            }
        }
        let closed = self.closed.load(Ordering::Relaxed);
        // SAFETY: `self` is forgotten right below, so its `Drop` never runs and the `Arc`
        // is moved out exactly once — the reference count is neither leaked nor duplicated.
        let shared = unsafe { std::ptr::read(&self.shared) };
        mem::forget(self);
        Receiver {
            shared,
            closed: AtomicBool::new(closed),
        }
    }

    /// Number of items currently buffered in the channel.
    #[inline]
    pub fn len(&self) -> usize {
        self.shared.internal.lock().queue.len()
    }

    /// Returns `true` when no items are buffered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` when the buffer holds `capacity` items; never `true` for an
    /// unbounded channel.
    #[inline]
    pub fn is_full(&self) -> bool {
        if self.shared.capacity == usize::MAX {
            false
        } else {
            self.len() == self.shared.capacity
        }
    }
}
impl<T: Send> Drop for AsyncReceiver<T> {
fn drop(&mut self) {
let _ = self.close();
if self.is_registered {
let state_ptr = &self.state as *const AtomicU8;
if self
.state
.compare_exchange(
STATE_WAITING,
STATE_CANCELLED,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_ok()
{
let mut guard = self.shared.internal.lock();
guard
.waiting_async_receivers
.retain(|w| w.state != state_ptr);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, RawWaker, RawWakerVTable, Waker};
    use std::thread;
    use std::time::Duration;

    /// Builds a `Waker` whose clone/wake/drop operations are all no-ops, for polling
    /// futures by hand.
    fn dummy_waker() -> Waker {
        unsafe fn clone(_: *const ()) -> RawWaker {
            RawWaker::new(std::ptr::null(), &VTABLE)
        }
        unsafe fn wake(_: *const ()) {}
        unsafe fn wake_by_ref(_: *const ()) {}
        unsafe fn drop_raw(_: *const ()) {}
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_raw);
        // SAFETY: every vtable function ignores the (null) data pointer, so the RawWaker
        // contract is trivially upheld.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    /// A spurious unpark of a thread blocked in `recv_timeout` must not make it register
    /// a second waiter entry.
    ///
    /// Relies on `sleep`-based timing to let the receiver thread block; it may be flaky
    /// on a heavily loaded machine.
    #[test]
    fn test_mpmc_v2_recv_timeout_spurious_wakeup_leak() {
        let (tx, rx) = bounded::<i32>(5);
        let rx_shared = Arc::clone(&rx.shared);
        let receiver_handle = thread::spawn(move || rx.recv_timeout(Duration::from_secs(5)));
        thread::sleep(Duration::from_millis(50));
        {
            let guard = rx_shared.internal.lock();
            assert_eq!(guard.waiting_sync_receivers.len(), 1);
        }
        receiver_handle.thread().unpark();
        thread::sleep(Duration::from_millis(50));
        let leaked_count = {
            let guard = rx_shared.internal.lock();
            guard.waiting_sync_receivers.len()
        };
        assert_eq!(
            leaked_count, 1,
            "Waiter was leaked! Queue contains {} waiters from a single thread.",
            leaked_count
        );
        let _ = tx.send(42);
        let _ = receiver_handle.join().unwrap();
    }

    /// Two `recv` futures from cloned receivers must register independent waiter entries,
    /// and dropping each future must remove exactly its own entry.
    #[test]
    fn test_mpmc_v2_async_waker_collision_deadlock() {
        // `_tx` (not `_`) keeps the sender alive for the whole test, so the channel
        // never disconnects while the futures are pending.
        let (_tx, rx) = bounded_async::<i32>(0);
        let rx_clone = rx.clone();
        let mut fut1 = Box::pin(rx.recv());
        let mut fut2 = Box::pin(rx_clone.recv());
        let waker = dummy_waker();
        let mut cx = Context::from_waker(&waker);
        let _ = fut1.as_mut().poll(&mut cx);
        let _ = fut2.as_mut().poll(&mut cx);
        {
            let guard = rx.shared.internal.lock();
            assert_eq!(
                guard.waiting_async_receivers.len(),
                2,
                "Async waker collision occurred! Only 1 waiter was registered for 2 futures."
            );
        }
        drop(fut1);
        {
            let guard = rx.shared.internal.lock();
            assert_eq!(
                guard.waiting_async_receivers.len(),
                1,
                "fut2's waiter registration was silently lost when fut1 was dropped!"
            );
        }
        drop(fut2);
        {
            let guard = rx.shared.internal.lock();
            assert_eq!(
                guard.waiting_async_receivers.len(),
                0,
                "Queue is not empty after dropping both futures!"
            );
        }
    }
}