use crate::async_util::AtomicWaker;
use crate::error::{RecvError, SendError, TryRecvError};
use crate::internal::cache_padded::CachePadded;
use crate::sync_util;
use core::marker::PhantomPinned;
use std::cell::UnsafeCell;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{self, AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread::{self, Thread};
/// A single link in the singly-linked message queue shared by producers and the consumer.
///
/// The queue always contains one "stub" node whose `value` is `None`. The consumer's `tail`
/// points at it, and real messages hang off its `next` chain.
struct Node<T> {
    /// Pointer to the successor node, or null if none has been linked yet. Written once, by
    /// the producer that swapped this node out of `head`; read by the receiving side.
    next: AtomicPtr<Node<T>>,
    /// The payload. `None` for the stub created in `MpscShared::new` and for nodes whose
    /// value the consumer has already taken.
    value: UnsafeCell<Option<T>>,
}
/// State shared between every sender and the single receiver of an MPSC channel.
///
/// Messages travel through a linked list of [`Node`]s: producers append at `head`, and the
/// consumer removes from `tail`.
pub(crate) struct MpscShared<T> {
    /// Most recently enqueued node. Producers atomically swap their new node in here.
    head: CachePadded<AtomicPtr<Node<T>>>,
    /// The consumer's current stub node. Only the receiving side (and `Drop`, via `&mut`)
    /// reads or writes it, so it is a plain `UnsafeCell` rather than an atomic.
    tail: CachePadded<UnsafeCell<*mut Node<T>>>,
    /// Set by a blocking receiver just before it parks. Cleared by CAS by whichever side
    /// claims the wake-up.
    consumer_parked: AtomicBool,
    /// Thread handle of a parked blocking receiver. Taken by the producer that wins the
    /// `consumer_parked` CAS.
    consumer_thread: UnsafeCell<Option<Thread>>,
    /// Waker registered by an async receiver (`RecvFuture`).
    consumer_waker: AtomicWaker,
    /// Number of live `Sender`/`AsyncSender` handles. Zero means the channel is disconnected.
    sender_count: AtomicUsize,
    /// Number of queued messages, maintained with relaxed atomics. It may be transiently
    /// inaccurate while producers and the consumer run concurrently.
    pub(crate) current_len: AtomicUsize, }
impl<T> fmt::Debug for MpscShared<T> {
    /// Formats a best-effort snapshot of the channel state.
    ///
    /// Atomics are read with relaxed ordering, so the printed values may be mutually
    /// inconsistent while the channel is in use. The consumer-owned `tail` cell and the parked
    /// thread handle are never dereferenced.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let head_ptr = self.head.load(Ordering::Relaxed);
        let parked = self.consumer_parked.load(Ordering::Relaxed);
        let senders = self.sender_count.load(Ordering::Relaxed);
        let queued = self.current_len.load(Ordering::Relaxed);

        let mut out = f.debug_struct("MpscShared");
        out.field("head", &head_ptr);
        out.field("tail", &"<UnsafeCell>");
        out.field("consumer_parked", &parked);
        out.field("consumer_waker", &self.consumer_waker);
        out.field("sender_count", &senders);
        out.field("current_len", &queued);
        out.finish_non_exhaustive()
    }
}
// SAFETY (NOTE(review)): `MpscShared` holds raw pointers and `UnsafeCell`s, so thread-safety is
// asserted by hand. Values of `T` move from sender threads to the receiver, hence `T: Send`.
// Sharing is only sound under this module's access discipline: `tail` is touched only by the
// single receiving side, and `consumer_thread` is handed off through the CAS on
// `consumer_parked`. Re-check both whenever either protocol changes.
unsafe impl<T: Send> Send for MpscShared<T> {}
unsafe impl<T: Send> Sync for MpscShared<T> {}
impl<T: Send> MpscShared<T> {
    /// Creates an empty channel state holding only the stub node.
    ///
    /// `sender_count` starts at 1, presumably for the sender handle created together with
    /// this state (verify against the constructor of the channel pair).
    pub(crate) fn new() -> Self {
        let stub_ptr = Box::into_raw(Box::new(Node {
            next: AtomicPtr::new(ptr::null_mut()),
            value: UnsafeCell::new(None),
        }));
        MpscShared {
            head: CachePadded::new(AtomicPtr::new(stub_ptr)),
            tail: CachePadded::new(UnsafeCell::new(stub_ptr)),
            consumer_parked: AtomicBool::new(false),
            consumer_thread: UnsafeCell::new(None),
            consumer_waker: AtomicWaker::new(),
            sender_count: AtomicUsize::new(1),
            current_len: AtomicUsize::new(0),
        }
    }

    /// Wakes the receiver after a message was published or a sender handle was dropped.
    ///
    /// If a blocking receiver has announced itself via `consumer_parked`, the caller that wins
    /// the CAS takes the stored thread handle and unparks it. Any registered async waker is
    /// woken unconditionally.
    #[inline]
    fn wake_consumer(&self) {
        // Store-load barrier. The caller has just published a node (`next.store`) or
        // decremented `sender_count`. The blocking receiver does the mirror image: it sets
        // `consumer_parked = true`, then re-checks the queue. Without a SeqCst fence between
        // each side's store and its following load, both loads may see stale values (store-load
        // reordering is allowed even on x86). The producer would then skip the unpark while the
        // receiver parks on what it believes is an empty queue. This only closes the window if
        // the receiver also issues a SeqCst fence between raising the flag and re-checking.
        atomic::fence(Ordering::SeqCst);
        if self.consumer_parked.load(Ordering::Relaxed)
            && self
                .consumer_parked
                .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
        {
            // SAFETY: winning the CAS (Acquire, reading the receiver's Release store of `true`)
            // claims the handle the receiver wrote before raising the flag.
            // NOTE(review): the receiver may re-enter `recv` and overwrite this cell while we
            // are still inside `take()` (e.g. after a spurious wake-up). Closing that race needs
            // a protocol change beyond this function.
            if let Some(thread_handle) = unsafe { (*self.consumer_thread.get()).take() } {
                sync_util::unpark_thread(&thread_handle);
            }
        }
        self.consumer_waker.wake();
    }

    /// Pops the oldest linked message without blocking.
    ///
    /// Must only be called from the receiving side, because it reads and writes the
    /// unsynchronised `tail` cell.
    ///
    /// # Errors
    /// * [`TryRecvError::Empty`]: no linked message right now (a producer may be midway
    ///   through linking one) and at least one sender is alive.
    /// * [`TryRecvError::Disconnected`]: no linked message and every sender has been dropped.
    pub(crate) fn try_recv_internal(&self) -> Result<T, TryRecvError> {
        // SAFETY: only the receiving side touches `tail`, and it always points at the live stub.
        let tail_ptr = unsafe { *self.tail.get() };
        // Acquire pairs with the producer's Release store of `next`, making the payload visible.
        let mut next_ptr = unsafe { (*tail_ptr).next.load(Ordering::Acquire) };
        if next_ptr.is_null() {
            if self.sender_count.load(Ordering::Acquire) != 0 {
                return Err(TryRecvError::Empty);
            }
            // All senders are gone, but the last one may have linked a message between the
            // load above and the count check. The Acquire load of the count makes every
            // completed send visible, so look once more before reporting disconnection.
            next_ptr = unsafe { (*tail_ptr).next.load(Ordering::Acquire) };
            if next_ptr.is_null() {
                return Err(TryRecvError::Disconnected);
            }
        }
        // SAFETY: `next_ptr` was fully initialised before being linked, and only the receiver
        // reads a linked node's value. The old stub already has its `next` set, and producers
        // only ever write `next` on the node they swapped out of `head`, so no producer will
        // touch it again and the receiver may free it.
        unsafe {
            let value = (*(*next_ptr).value.get())
                .take()
                .expect("linked MPSC node must carry a value");
            *self.tail.get() = next_ptr;
            self.current_len.fetch_sub(1, Ordering::Relaxed);
            drop(Box::from_raw(tail_ptr));
            Ok(value)
        }
    }
}
impl<T> Drop for MpscShared<T> {
fn drop(&mut self) {
let mut current_node_ptr = *self.tail.get_mut();
while !current_node_ptr.is_null() {
let node_box = unsafe { Box::from_raw(current_node_ptr) };
current_node_ptr = node_box.next.load(Ordering::Relaxed);
}
}
}
/// Synchronous sending half of the unbounded channel.
///
/// Cloning adds another producer. The receiver observes disconnection once every sender
/// handle has been dropped.
#[derive(Debug)]
pub struct Sender<T: Send> {
    /// State shared with the receiver and every other sender.
    pub(crate) shared: Arc<MpscShared<T>>,
}
/// Async-flavoured sending half of the unbounded channel. `send` returns a [`SendFuture`].
///
/// Cloning adds another producer. It shares the same `sender_count` as [`Sender`].
#[derive(Debug)]
pub struct AsyncSender<T: Send> {
    /// State shared with the receiver and every other sender.
    pub(crate) shared: Arc<MpscShared<T>>,
}
/// Appends `value` to the queue and wakes the receiver.
///
/// Producers never block: a node is allocated, swapped into `head`, then linked from its
/// predecessor. Between the swap and the link, the receiver treats the queue as empty past
/// that predecessor.
///
/// # Errors
/// Always returns `Ok(())` in the current implementation.
fn send_internal<T: Send>(shared: &Arc<MpscShared<T>>, value: T) -> Result<(), SendError> {
    let new_node_ptr = Box::into_raw(Box::new(Node {
        next: AtomicPtr::new(ptr::null_mut()),
        value: UnsafeCell::new(Some(value)),
    }));
    // Count the message *before* it becomes reachable. Incrementing afterwards let the receiver
    // pop the node and decrement first, transiently wrapping `current_len` to `usize::MAX`.
    // The Release store of `next` below orders this increment before the receiver's decrement.
    shared.current_len.fetch_add(1, Ordering::Relaxed);
    let old_head_ptr = shared.head.swap(new_node_ptr, Ordering::AcqRel);
    // SAFETY: the receiver frees a node only after its `next` is non-null, and only this
    // producer (the one that swapped it out of `head`) sets that link, so `old_head_ptr` is live.
    unsafe {
        (*old_head_ptr).next.store(new_node_ptr, Ordering::Release);
    }
    shared.wake_consumer();
    Ok(())
}
impl<T: Send> Sender<T> {
    /// Enqueues `value` for the receiver. This never waits, because the channel is unbounded.
    ///
    /// # Errors
    /// Forwards the result of the shared send path.
    pub fn send(&self, value: T) -> Result<(), SendError> {
        let shared = &self.shared;
        send_internal(shared, value)
    }

    /// Returns the number of queued messages, read with relaxed ordering (may already be stale).
    pub fn len(&self) -> usize {
        let counter = &self.shared.current_len;
        counter.load(Ordering::Relaxed)
    }
}
impl<T: Send> AsyncSender<T> {
    /// Returns a future that enqueues `value` when it is first polled.
    ///
    /// Nothing is sent until the future is polled or awaited.
    pub fn send(&self, value: T) -> SendFuture<'_, T> {
        let pending = Some(value);
        SendFuture {
            producer: self,
            value: pending,
            _phantom: PhantomPinned,
        }
    }

    /// Returns the number of queued messages, read with relaxed ordering (may already be stale).
    pub fn len(&self) -> usize {
        let counter = &self.shared.current_len;
        counter.load(Ordering::Relaxed)
    }
}
impl<T: Send> Clone for Sender<T> {
fn clone(&self) -> Self {
self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
Sender {
shared: Arc::clone(&self.shared),
}
}
}
impl<T: Send> Drop for Sender<T> {
fn drop(&mut self) {
if self.shared.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
self.shared.wake_consumer();
}
}
}
impl<T: Send> Clone for AsyncSender<T> {
fn clone(&self) -> Self {
self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
AsyncSender {
shared: Arc::clone(&self.shared),
}
}
}
impl<T: Send> Drop for AsyncSender<T> {
fn drop(&mut self) {
if self.shared.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
self.shared.wake_consumer();
}
}
}
/// Receiving half of the channel (single consumer), with a blocking `recv`.
#[derive(Debug)]
pub struct Receiver<T: Send> {
    /// State shared with every sender.
    pub(crate) shared: Arc<MpscShared<T>>,
    /// `*mut ()` makes the type `!Send` and `!Sync`. `Send` is re-added just below, so the
    /// receiver can move between threads but cannot be shared by reference across them.
    pub(crate) _phantom: PhantomData<*mut ()>, }
// SAFETY (NOTE(review)): moving the receiver is fine because only one thread uses it at a time.
// Staying `!Sync` is what keeps `&self` methods that read the unsynchronised `tail` cell
// (e.g. `is_empty`) confined to one thread.
unsafe impl<T: Send> Send for Receiver<T> {}
/// Receiving half of the channel (single consumer) for async code. `recv` returns a
/// [`RecvFuture`].
#[derive(Debug)]
pub struct AsyncReceiver<T: Send> {
    /// State shared with every sender.
    pub(crate) shared: Arc<MpscShared<T>>,
    /// `*mut ()` makes the type `!Send` and `!Sync`. `Send` is re-added just below.
    pub(crate) _phantom: PhantomData<*mut ()>, }
// SAFETY (NOTE(review)): same reasoning as for `Receiver`. Moving is fine, and staying `!Sync`
// keeps `&self` access to the unsynchronised `tail` cell on a single thread.
unsafe impl<T: Send> Send for AsyncReceiver<T> {}
impl<T: Send> Receiver<T> {
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.shared.try_recv_internal()
}
pub fn recv(&mut self) -> Result<T, RecvError> {
loop {
match self.try_recv() {
Ok(value) => return Ok(value),
Err(TryRecvError::Disconnected) => return Err(RecvError::Disconnected),
Err(TryRecvError::Empty) => {
unsafe {
*self.shared.consumer_thread.get() = Some(thread::current());
}
self.shared.consumer_parked.store(true, Ordering::Release);
if let Ok(value) = self.try_recv() {
if self
.shared
.consumer_parked
.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
unsafe {
*self.shared.consumer_thread.get() = None;
}
}
return Ok(value);
}
if self.shared.sender_count.load(Ordering::Acquire) == 0 {
let tail_ptr = unsafe { *self.shared.tail.get() };
let next_ptr = unsafe { (*tail_ptr).next.load(Ordering::Acquire) };
if next_ptr.is_null() {
if self
.shared
.consumer_parked
.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
unsafe {
*self.shared.consumer_thread.get() = None;
}
}
return Err(RecvError::Disconnected);
}
}
sync_util::park_thread();
if self
.shared
.consumer_parked
.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
unsafe {
*self.shared.consumer_thread.get() = None;
}
}
}
}
}
}
pub fn len(&self) -> usize {
self.shared.current_len.load(Ordering::Relaxed)
}
pub fn is_empty(&self) -> bool {
if self.shared.current_len.load(Ordering::Relaxed) == 0 {
let tail_node_ptr = unsafe { *self.shared.tail.get() };
let next_node_ptr = unsafe { (*tail_node_ptr).next.load(Ordering::Acquire) };
return next_node_ptr.is_null();
}
false
}
}
impl<T: Send> AsyncReceiver<T> {
    /// Receives a message without waiting.
    ///
    /// # Errors
    /// [`TryRecvError::Empty`] if nothing is ready. [`TryRecvError::Disconnected`] if nothing is
    /// ready and every sender has been dropped.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let shared = &self.shared;
        shared.try_recv_internal()
    }

    /// Returns a future that resolves to the next message, or to `RecvError::Disconnected`.
    pub fn recv(&mut self) -> RecvFuture<'_, T> {
        RecvFuture { consumer: self }
    }

    /// Returns the number of queued messages, read with relaxed ordering (may already be stale).
    pub fn len(&self) -> usize {
        let counter = &self.shared.current_len;
        counter.load(Ordering::Relaxed)
    }

    /// Returns `true` if no message is currently linked behind the receiver's stub node.
    ///
    /// A non-zero counter short-circuits to `false` without touching the list.
    pub fn is_empty(&self) -> bool {
        let shared = &*self.shared;
        if shared.current_len.load(Ordering::Relaxed) != 0 {
            return false;
        }
        // SAFETY: only the receiver touches `tail`, and `AsyncReceiver` is `!Sync`.
        let stub = unsafe { *shared.tail.get() };
        let successor = unsafe { (*stub).next.load(Ordering::Acquire) };
        successor.is_null()
    }
}
impl<T: Send> Drop for Receiver<T> {
fn drop(&mut self) {
while self.shared.try_recv_internal().is_ok() {
}
}
}
impl<T: Send> Drop for AsyncReceiver<T> {
fn drop(&mut self) {
while self.shared.try_recv_internal().is_ok() {
}
}
}
/// Future returned by [`AsyncSender::send`]. It performs the send when first polled.
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct SendFuture<'a, T: Send> {
    /// Sender whose shared state the value is pushed into.
    producer: &'a AsyncSender<T>,
    /// Value waiting to be sent. `None` once it has been sent.
    value: Option<T>,
    /// Opts the future out of `Unpin`. NOTE(review): `poll` does not appear to rely on a stable
    /// address, so this looks unnecessary. Confirm before removing.
    _phantom: PhantomPinned,
}
impl<'a, T: Send> Future for SendFuture<'a, T> {
    type Output = Result<(), SendError>;

    /// Performs the send on the first poll and completes immediately.
    ///
    /// # Panics
    /// Panics if polled again after it has already returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: no field is structurally pinned. `value` is moved out as an ordinary `Option`
        // and never exposed as `Pin<&mut T>`, and the future itself is not moved.
        let fut = unsafe { Pin::get_unchecked_mut(self) };
        let payload = fut
            .value
            .take()
            .expect("SendFuture polled after completion");
        let outcome = send_internal(&fut.producer.shared, payload);
        Poll::Ready(outcome)
    }
}
/// Future returned by [`AsyncReceiver::recv`]. It resolves to the next message, or to
/// `RecvError::Disconnected`.
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct RecvFuture<'a, T: Send> {
    /// Exclusive borrow of the receiver for as long as the future lives.
    consumer: &'a mut AsyncReceiver<T>,
}
impl<'a, T: Send> Future for RecvFuture<'a, T> {
    type Output = Result<T, RecvError>;

    /// Polls for the next message.
    ///
    /// If the queue is empty, the task's waker is registered *before* a second attempt. A
    /// message published between the two attempts is then either seen directly or triggers a
    /// wake-up. If the last sender vanished in between, one final attempt decides between its
    /// message and disconnection.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = &self.consumer.shared;

        // Fast path: a message (or the disconnection) is already observable.
        match shared.try_recv_internal() {
            Ok(value) => return Poll::Ready(Ok(value)),
            Err(TryRecvError::Disconnected) => return Poll::Ready(Err(RecvError::Disconnected)),
            Err(TryRecvError::Empty) => {}
        }

        shared.consumer_waker.register(cx.waker());
        match shared.try_recv_internal() {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError::Disconnected)),
            Err(TryRecvError::Empty) if shared.sender_count.load(Ordering::Acquire) == 0 => {
                // Senders are all gone: take whatever is left, otherwise report disconnection.
                Poll::Ready(
                    shared
                        .try_recv_internal()
                        .map_err(|_| RecvError::Disconnected),
                )
            }
            Err(TryRecvError::Empty) => Poll::Pending,
        }
    }
}