use crate::async_util::AtomicWaker;
use crate::error::{RecvError, TryRecvError};
use crate::mpsc::block_queue::UnboundedBlockQueue;
use crate::sync_util;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::thread::Thread;
/// State shared between the sending side(s) and the receiving side of an unbounded
/// MPSC channel.
///
/// NOTE(review): the single `consumer_thread` / `consumer_waker` slot suggests exactly one
/// receiver at a time — confirm against the `Receiver` type.
pub(crate) struct MpscShared<T> {
    /// Unbounded storage for in-flight messages; drained by the `try_recv*` / `poll_recv*`
    /// methods.
    pub(crate) queue: UnboundedBlockQueue<T>,
    /// Cleared (true -> false) by `wake_consumer` before it unparks `consumer_thread`.
    /// Presumably set by the blocking receive path just before parking — confirm there.
    pub(crate) consumer_parked: AtomicBool,
    /// Handle of the parked consumer thread; `wake_consumer` takes it out to unpark it.
    pub(crate) consumer_thread: Mutex<Option<Thread>>,
    /// Waker registered by the async receive paths (`poll_recv*`); woken by `wake_consumer`.
    pub(crate) consumer_waker: AtomicWaker,
    /// Not read in this view; presumably set when the receiver is dropped so senders can
    /// observe it — confirm at the use sites.
    pub(crate) receiver_dropped: AtomicBool,
    /// Number of live senders; starts at 1. Receive paths report disconnection once this is 0
    /// and the queue is empty.
    pub(crate) sender_count: AtomicUsize,
    /// Approximate number of queued messages. Decremented (Relaxed) on every successful pop
    /// here; the increment happens elsewhere (presumably on send — confirm).
    pub(crate) current_len: AtomicUsize,
}
impl<T> fmt::Debug for MpscShared<T> {
    /// Writes a diagnostic snapshot of the channel's flags and counters.
    ///
    /// Each atomic is read with `Relaxed` ordering, so under concurrent use the printed values
    /// are best-effort and need not be mutually consistent. The queue, thread handle and
    /// `receiver_dropped` flag are omitted (`finish_non_exhaustive` prints `..`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("MpscShared");
        out.field("consumer_parked", &self.consumer_parked.load(Ordering::Relaxed));
        out.field("consumer_waker", &self.consumer_waker);
        out.field("sender_count", &self.sender_count.load(Ordering::Relaxed));
        out.field("current_len", &self.current_len.load(Ordering::Relaxed));
        out.finish_non_exhaustive()
    }
}
// SAFETY: NOTE(review) — these impls assert that `MpscShared<T>` may be sent to and shared
// between threads whenever `T: Send`.
// - In this view, values of `T` are only moved into and out of `queue` by value (`pop`
//   returns `T`, `pop_batch` appends owned values to a `Vec<T>`). No `&T` is handed out
//   through `&self`. That is the usual reason `Sync` needs only `T: Send`, not `T: Sync`.
// - The std fields (atomics, `Mutex<Option<Thread>>`) are already `Send + Sync`.
// - Soundness additionally requires `UnboundedBlockQueue<T>` and `AtomicWaker` to be safe
//   for concurrent use through `&self`. Confirm against their implementations.
unsafe impl<T: Send> Send for MpscShared<T> {}
unsafe impl<T: Send> Sync for MpscShared<T> {}
impl<T> MpscShared<T> {
pub(crate) fn new() -> Self {
MpscShared {
queue: UnboundedBlockQueue::new(),
consumer_parked: AtomicBool::new(false),
consumer_thread: Mutex::new(None),
consumer_waker: AtomicWaker::new(),
receiver_dropped: AtomicBool::new(false),
sender_count: AtomicUsize::new(1),
current_len: AtomicUsize::new(0),
}
}
#[inline]
pub(crate) fn wake_consumer(&self) {
self.consumer_waker.wake();
if self
.consumer_parked
.compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
if let Some(thread_handle) = self.consumer_thread.lock().unwrap().take() {
sync_util::unpark_thread(&thread_handle);
}
}
}
pub(crate) fn try_recv_internal(&self) -> Result<T, TryRecvError> {
match self.queue.pop() {
Some(val) => {
self.current_len.fetch_sub(1, Ordering::Relaxed);
Ok(val)
}
None => {
if self.sender_count.load(Ordering::Acquire) == 0 {
Err(TryRecvError::Disconnected)
} else {
Err(TryRecvError::Empty)
}
}
}
}
pub(crate) fn try_recv_batch_internal(
&self,
out: &mut Vec<T>,
max: usize,
) -> Result<usize, TryRecvError> {
let k = self.queue.pop_batch(out, max);
if k > 0 {
self.current_len.fetch_sub(k, Ordering::Relaxed);
return Ok(k);
}
if self.sender_count.load(Ordering::Acquire) == 0 {
let k = self.queue.pop_batch(out, max);
if k > 0 {
self.current_len.fetch_sub(k, Ordering::Relaxed);
return Ok(k);
}
Err(TryRecvError::Disconnected)
} else {
Err(TryRecvError::Empty)
}
}
pub(crate) fn poll_recv_batch_internal(
&self,
cx: &mut Context<'_>,
out: &mut Vec<T>,
max: usize,
) -> Poll<Result<usize, RecvError>> {
if max == 0 {
return Poll::Ready(Ok(0));
}
match self.try_recv_batch_internal(out, max) {
Ok(k) => Poll::Ready(Ok(k)),
Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError::Disconnected)),
Err(TryRecvError::Empty) => {
self.consumer_waker.register(cx.waker());
match self.try_recv_batch_internal(out, max) {
Ok(k) => Poll::Ready(Ok(k)),
Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError::Disconnected)),
Err(TryRecvError::Empty) => Poll::Pending,
}
}
}
}
pub(crate) fn poll_recv_internal(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
loop {
match self.try_recv_internal() {
Ok(value) => return Poll::Ready(Ok(value)),
Err(TryRecvError::Disconnected) => return Poll::Ready(Err(RecvError::Disconnected)),
Err(TryRecvError::Empty) => {
self.consumer_waker.register(cx.waker());
match self.try_recv_internal() {
Ok(value) => return Poll::Ready(Ok(value)),
Err(TryRecvError::Disconnected) => return Poll::Ready(Err(RecvError::Disconnected)),
Err(TryRecvError::Empty) => {
if self.sender_count.load(Ordering::Acquire) == 0 {
match self.try_recv_internal() {
Ok(value) => return Poll::Ready(Ok(value)),
_ => return Poll::Ready(Err(RecvError::Disconnected)),
}
}
return Poll::Pending;
}
}
}
}
}
}
}