use crate::error::{TryRecvError, TrySendError};
use crate::RecvError;
use core::task::{Context, Poll};
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::Waker;
use std::thread::Thread;
/// Waiter state: registered and not yet claimed. `try_send_core` and `try_recv_core`
/// claim a waiter only by compare-exchanging this value to `STATE_SUCCESS`.
pub(crate) const STATE_WAITING: u8 = 0;
/// Waiter state: a counterpart won the compare-exchange from `STATE_WAITING`
/// and is about to wake (or has woken) the waiter.
pub(crate) const STATE_SUCCESS: u8 = 1;
/// Waiter state: never stored by the code visible in this file; presumably set by a
/// waiter that gave up, so that the claim compare-exchange fails — TODO confirm.
pub(crate) const STATE_CANCELLED: u8 = 2;
/// Payload that a queued waiter can carry.
#[derive(Debug)]
pub(crate) enum WaiterData<T> {
/// Slot holding a sender's item; becomes `None` once the item has been taken
/// (see `take_item_from_sender_slot`).
SenderItem(Option<T>),
}
/// Queue entry for a waiter that is woken with `Thread::unpark`.
#[derive(Debug)]
pub(crate) struct SyncWaiter<T> {
/// Thread handle that is unparked when this waiter is claimed.
pub(crate) thread: Thread,
/// Optional payload. `try_recv_core` only takes an item directly from a waiting
/// sender whose `data` is `Some`.
pub(crate) data: Option<WaiterData<T>>,
/// Pointer to the waiter's state byte (holds one of the `STATE_*` values).
/// The channel dereferences it inside `unsafe` blocks, so the pointee must stay
/// alive while this entry is queued — presumably guaranteed by whoever enqueues
/// the waiter; TODO confirm.
pub(crate) state: *const AtomicU8,
}
impl<T> SyncWaiter<T> {
    /// Moves the parked item out of this waiter's sender slot.
    ///
    /// Returns `None` when the waiter carries no payload or when the slot has
    /// already been emptied; the slot itself is left as `SenderItem(None)`.
    pub(crate) fn take_item_from_sender_slot(&mut self) -> Option<T> {
        match self.data.as_mut() {
            Some(WaiterData::SenderItem(slot)) => slot.take(),
            None => None,
        }
    }
}
// SAFETY (NOTE(review): not verifiable from this file alone): the raw `state` pointer is
// the field that suppresses the auto traits. These impls assume the pointed-to `AtomicU8`
// outlives the queued waiter and is only ever accessed atomically — confirm against the
// code that creates and removes waiters.
unsafe impl<T: Send> Send for SyncWaiter<T> {}
unsafe impl<T: Send> Sync for SyncWaiter<T> {}
/// Queue entry for a waiter that is woken through a task `Waker`.
#[derive(Debug)]
pub(crate) struct AsyncWaiter<T> {
/// Waker invoked when this waiter is claimed; `poll_recv_internal` refreshes it
/// when the same state pointer is polled again.
pub(crate) waker: Waker,
/// Optional payload. Receivers registered by `poll_recv_internal` use `None`;
/// `try_recv_core` only takes an item directly from a waiting sender whose
/// `data` is `Some`.
pub(crate) data: Option<WaiterData<T>>,
/// Pointer to the waiter's state byte (holds one of the `STATE_*` values).
/// Also used as the waiter's identity when `poll_recv_internal` looks for an
/// existing registration. The pointee must stay alive while this entry is
/// queued — presumably guaranteed by the owning future; TODO confirm.
pub(crate) state: *const AtomicU8,
}
impl<T> AsyncWaiter<T> {
    /// Moves the parked item out of this waiter's sender slot.
    ///
    /// Returns `None` when the waiter carries no payload or when the slot has
    /// already been emptied; the slot itself is left as `SenderItem(None)`.
    pub(crate) fn take_item_from_sender_slot(&mut self) -> Option<T> {
        match self.data.as_mut() {
            Some(WaiterData::SenderItem(slot)) => slot.take(),
            None => None,
        }
    }
}
// SAFETY (NOTE(review): not verifiable from this file alone): the raw `state` pointer is
// the field that suppresses the auto traits. These impls assume the pointed-to `AtomicU8`
// outlives the queued waiter and is only ever accessed atomically — confirm against the
// code that creates and removes waiters.
unsafe impl<T: Send> Send for AsyncWaiter<T> {}
unsafe impl<T: Send> Sync for AsyncWaiter<T> {}
/// Mutable channel state; lives inside the mutex of `MpmcShared`.
#[derive(Debug)]
pub(crate) struct MpmcChannelInternal<T> {
/// Buffered items, consumed front-first by `try_recv_core`.
pub(crate) queue: VecDeque<T>,
/// Waiting senders that are woken by unparking a thread.
pub(crate) waiting_sync_senders: VecDeque<SyncWaiter<T>>,
/// Waiting senders that are woken through a `Waker`.
pub(crate) waiting_async_senders: VecDeque<AsyncWaiter<T>>,
/// Waiting receivers that are woken by unparking a thread.
pub(crate) waiting_sync_receivers: VecDeque<SyncWaiter<T>>,
/// Waiting receivers that are woken through a `Waker`.
pub(crate) waiting_async_receivers: VecDeque<AsyncWaiter<T>>,
/// Starts at 1 in `MpmcShared::new`; once it is 0 and the queue is empty,
/// receives report disconnection.
pub(crate) sender_count: usize,
/// Starts at 1 in `MpmcShared::new`; once it is 0, `try_send_core` returns
/// `TrySendError::Closed`.
pub(crate) receiver_count: usize,
}
/// Channel state shared between endpoints: the locked internals plus the fixed capacity.
#[derive(Debug)]
pub(crate) struct MpmcShared<T> {
/// All mutable state, guarded by a `parking_lot::Mutex`.
pub(crate) internal: Mutex<MpmcChannelInternal<T>>,
/// Buffer limit. `0` means `try_send_core` never buffers unless it can hand the
/// item to a waiting receiver; `usize::MAX` disables the length check entirely.
pub(crate) capacity: usize,
}
// NOTE(review): with the manual `Send`/`Sync` impls on `SyncWaiter` and `AsyncWaiter`,
// every field here already looks `Send` for `T: Send`, so these impls may be redundant
// with what the compiler would derive — confirm before relying on or removing them.
unsafe impl<T: Send> Send for MpmcShared<T> {}
unsafe impl<T: Send> Sync for MpmcShared<T> {}
impl<T: Send> MpmcShared<T> {
pub(crate) fn new(capacity: usize) -> Self {
MpmcShared {
internal: Mutex::new(MpmcChannelInternal {
queue: VecDeque::with_capacity(if capacity == usize::MAX { 32 } else { capacity }),
waiting_sync_senders: VecDeque::new(),
waiting_async_senders: VecDeque::new(),
waiting_sync_receivers: VecDeque::new(),
waiting_async_receivers: VecDeque::new(),
sender_count: 1,
receiver_count: 1,
}),
capacity,
}
}
/// Attempts to send `item` without blocking.
///
/// Order of attempts, all under the channel lock:
/// 1. If `receiver_count` is 0, returns `TrySendError::Closed(item)`.
/// 2. Pops waiting async receivers, then waiting sync receivers, until one can be
///    claimed (`STATE_WAITING` -> `STATE_SUCCESS`). On a claim the item is pushed
///    onto the queue and that receiver is woken. Waiters that cannot be claimed
///    are discarded.
/// 3. Otherwise buffers the item if `capacity` allows it, else returns
///    `TrySendError::Full(item)`.
pub(crate) fn try_send_core(&self, item: T) -> Result<(), TrySendError<T>> {
    let mut inner = self.internal.lock();
    if inner.receiver_count == 0 {
        return Err(TrySendError::Closed(item));
    }

    // Task-based receivers are tried first.
    while let Some(receiver) = inner.waiting_async_receivers.pop_front() {
        // SAFETY: relies on the waiter's owner keeping the state byte alive while the
        // waiter is queued (NOTE(review): invariant is upheld outside this function).
        let flag = unsafe { &*receiver.state };
        let claimed = flag
            .compare_exchange(
                STATE_WAITING,
                STATE_SUCCESS,
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_ok();
        if claimed {
            inner.queue.push_back(item);
            receiver.waker.wake();
            return Ok(());
        }
    }

    // Then thread-based receivers.
    while let Some(receiver) = inner.waiting_sync_receivers.pop_front() {
        // SAFETY: same liveness assumption as for the async waiters above.
        let flag = unsafe { &*receiver.state };
        let claimed = flag
            .compare_exchange(
                STATE_WAITING,
                STATE_SUCCESS,
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_ok();
        if claimed {
            inner.queue.push_back(item);
            receiver.thread.unpark();
            return Ok(());
        }
    }

    // Nobody to hand off to: buffer only if the capacity permits it. A capacity of 0
    // never buffers; `usize::MAX` skips the length check.
    let has_room = self.capacity != 0
        && (self.capacity == usize::MAX || inner.queue.len() < self.capacity);
    if has_room {
        inner.queue.push_back(item);
        Ok(())
    } else {
        Err(TrySendError::Full(item))
    }
}
/// Attempts to receive an item without blocking.
///
/// Order of attempts, all under the channel lock:
/// 1. While the front waiting async sender carries a payload, pop it and try to claim
///    it (`STATE_WAITING` -> `STATE_SUCCESS`); on a claim its item is returned and the
///    sender is woken. The same is then done for waiting sync senders.
/// 2. Otherwise pop the front of the queue. When `capacity > 0`, one waiting sender
///    (async preferred, then sync) is claimed and woken so it can use the freed slot.
/// 3. With nothing available, returns `TryRecvError::Disconnected` if `sender_count`
///    is 0, else `TryRecvError::Empty`.
pub(crate) fn try_recv_core(&self) -> Result<T, TryRecvError> {
    let mut inner = self.internal.lock();

    // Direct hand-off from async senders that parked an item with their waiter.
    while inner
        .waiting_async_senders
        .front()
        .map_or(false, |w| w.data.is_some())
    {
        let mut sender = inner.waiting_async_senders.pop_front().unwrap();
        // SAFETY: relies on the waiter's owner keeping the state byte alive while the
        // waiter is queued (NOTE(review): invariant is upheld outside this function).
        let flag = unsafe { &*sender.state };
        let claimed = flag
            .compare_exchange(
                STATE_WAITING,
                STATE_SUCCESS,
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_ok();
        if claimed {
            let item = sender.take_item_from_sender_slot().unwrap();
            sender.waker.wake();
            return Ok(item);
        }
        // The waiter was no longer in STATE_WAITING: discard it together with its payload.
        drop(sender.data.take());
    }

    // Direct hand-off from sync senders that parked an item with their waiter.
    while inner
        .waiting_sync_senders
        .front()
        .map_or(false, |w| w.data.is_some())
    {
        let mut sender = inner.waiting_sync_senders.pop_front().unwrap();
        // SAFETY: same liveness assumption as for the async waiters above.
        let flag = unsafe { &*sender.state };
        let claimed = flag
            .compare_exchange(
                STATE_WAITING,
                STATE_SUCCESS,
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_ok();
        if claimed {
            let item = sender.take_item_from_sender_slot().unwrap();
            sender.thread.unpark();
            return Ok(item);
        }
        // The waiter was no longer in STATE_WAITING: discard it together with its payload.
        drop(sender.data.take());
    }

    // Fall back to the buffer; an empty buffer decides between Disconnected and Empty.
    let item = match inner.queue.pop_front() {
        Some(item) => item,
        None => {
            return if inner.sender_count == 0 {
                Err(TryRecvError::Disconnected)
            } else {
                Err(TryRecvError::Empty)
            };
        }
    };

    if self.capacity > 0 {
        // A slot was freed: wake at most one waiting sender, async ones first.
        let mut woke_async = false;
        while let Some(sender) = inner.waiting_async_senders.pop_front() {
            // SAFETY: same liveness assumption as above.
            let flag = unsafe { &*sender.state };
            let claimed = flag
                .compare_exchange(
                    STATE_WAITING,
                    STATE_SUCCESS,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok();
            if claimed {
                sender.waker.wake();
                woke_async = true;
                break;
            }
        }
        if !woke_async {
            while let Some(sender) = inner.waiting_sync_senders.pop_front() {
                // SAFETY: same liveness assumption as above.
                let flag = unsafe { &*sender.state };
                let claimed = flag
                    .compare_exchange(
                        STATE_WAITING,
                        STATE_SUCCESS,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok();
                if claimed {
                    sender.thread.unpark();
                    break;
                }
            }
        }
    }

    Ok(item)
}
/// Polls the channel for an item on behalf of an async receiver.
///
/// Returns `Poll::Ready(Ok(item))` when `try_recv_core` yields an item,
/// `Poll::Ready(Err(RecvError::Disconnected))` when no senders remain and nothing is
/// buffered, and `Poll::Pending` after registering (or refreshing) this receiver's
/// waiter in `waiting_async_receivers`.
///
/// `state_ptr` identifies the receiver: an existing waiter with the same pointer is
/// reused instead of queuing a second one. The pointer is dereferenced here, so it
/// must point to a live `AtomicU8` for as long as the waiter may stay queued.
pub(crate) fn poll_recv_internal(
    &self,
    cx: &mut Context<'_>,
    state_ptr: *const AtomicU8,
) -> Poll<Result<T, RecvError>> {
    loop {
        match self.try_recv_core() {
            Ok(item) => return Poll::Ready(Ok(item)),
            Err(TryRecvError::Disconnected) => {
                return Poll::Ready(Err(RecvError::Disconnected));
            }
            Err(TryRecvError::Empty) => {}
        }

        // `try_recv_core` released the lock before returning, so the channel may have
        // changed in the meantime; re-check under the lock before going to sleep.
        let mut guard = self.internal.lock();
        let sender_parked = !guard.waiting_sync_senders.is_empty()
            || !guard.waiting_async_senders.is_empty();
        if !guard.queue.is_empty() || (self.capacity == 0 && sender_parked) {
            // NOTE(review): this retry assumes `try_recv_core` can make progress on
            // whatever was observed here; confirm that a zero-capacity channel never
            // queues a sender waiter without a payload, otherwise this loop would spin.
            drop(guard);
            continue;
        }
        if guard.sender_count == 0 {
            return Poll::Ready(Err(RecvError::Disconnected));
        }

        let current_waker = cx.waker();
        match guard
            .waiting_async_receivers
            .iter_mut()
            .find(|w| w.state == state_ptr)
        {
            Some(registered) => {
                // Already queued: only replace the stored waker when it would wake a
                // different task, avoiding a clone and drop on every spurious re-poll.
                if !registered.waker.will_wake(current_waker) {
                    registered.waker = current_waker.clone();
                }
            }
            None => {
                // SAFETY: the caller passes a pointer to a live `AtomicU8` (see the
                // function docs); the store is atomic.
                unsafe { (*state_ptr).store(STATE_WAITING, Ordering::SeqCst) };
                guard.waiting_async_receivers.push_back(AsyncWaiter {
                    waker: current_waker.clone(),
                    data: None,
                    state: state_ptr,
                });
            }
        }
        return Poll::Pending;
    }
}
}
impl<T> Drop for MpmcShared<T> {
    /// Releases everything the channel still owns: buffered items first, then any
    /// items left in the slots of waiting sync and async senders.
    fn drop(&mut self) {
        // `get_mut` needs no locking because `&mut self` proves exclusive access.
        let inner = self.internal.get_mut();
        inner.queue.clear();
        inner.waiting_sync_senders.drain(..).for_each(|mut parked| {
            drop(parked.take_item_from_sender_slot());
        });
        inner.waiting_async_senders.drain(..).for_each(|mut parked| {
            drop(parked.take_item_from_sender_slot());
        });
    }
}