use super::bounded_sync::{BoundedMessage, BoundedMpscShared, Permit, Receiver, Sender};
use crate::error::{RecvError, SendError, TrySendError};
use crate::mpsc::unbounded;
use crate::TryRecvError;
use futures_core::Stream;
use std::future::Future;
use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
/// Sending half of the bounded MPSC channel, offering an async `send` and a
/// non-blocking `try_send`.
#[derive(Debug)]
pub struct AsyncSender<T: Send> {
/// Channel state held through an `Arc`; every handle produced by `clone`
/// points at the same shared state.
pub(crate) shared: Arc<BoundedMpscShared<T>>,
/// Flipped to `true` the first time `Drop` runs, so the sender-count
/// decrement happens at most once per handle.
pub(crate) closed: AtomicBool,
}
/// Receiving half of the bounded MPSC channel, usable through `recv`,
/// `try_recv`, or as a `Stream`.
#[derive(Debug)]
pub struct AsyncReceiver<T: Send> {
/// Channel state held through an `Arc`, the same shared-state type the
/// senders hold.
pub(crate) shared: Arc<BoundedMpscShared<T>>,
/// Flipped to `true` the first time `Drop` runs; the receive paths read it
/// and report disconnection once it is set.
pub(crate) closed: AtomicBool,
}
impl<T: Send> AsyncSender<T> {
    /// Returns a future that sends `value` once a gate permit is available.
    ///
    /// The future resolves to `Ok(())` after the message has been queued, or
    /// to `Err(SendError::Closed)` if the receiver has been dropped.
    pub fn send(&self, value: T) -> SendFuture<'_, T> {
        SendFuture {
            acquire: self.shared.gate.acquire_async(),
            sender: self,
            value: Some(value),
            is_rendezvous: self.capacity() == 0,
            _phantom: PhantomPinned,
        }
    }

    /// Attempts to send `value` without waiting.
    ///
    /// # Errors
    ///
    /// * `TrySendError::Closed(value)` if the receiver has been dropped, or
    ///   if the underlying channel rejects the message.
    /// * `TrySendError::Full(value)` if no gate permit can be acquired
    ///   immediately.
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        if self.is_closed() {
            return Err(TrySendError::Closed(value));
        }
        if !self.shared.gate.try_acquire() {
            return Err(TrySendError::Full(value));
        }
        // The permit travels inside the message, so it lives exactly as long
        // as the message does.
        let permit = Permit {
            gate: self.shared.gate.clone(),
            is_rendezvous: self.capacity() == 0,
        };
        let message = BoundedMessage {
            value,
            _permit: permit,
        };
        if let Err(msg) = unbounded::send_internal(&self.shared.channel, message) {
            return Err(TrySendError::Closed(msg.value));
        }
        Ok(())
    }

    /// Creates another sender for the same channel and increments the shared
    /// sender count.
    pub fn clone(&self) -> Self {
        self.shared
            .channel
            .sender_count
            .fetch_add(1, Ordering::Relaxed);
        Self {
            shared: self.shared.clone(),
            closed: AtomicBool::new(false),
        }
    }

    /// Returns `true` once the receiver has flagged itself as dropped.
    pub fn is_closed(&self) -> bool {
        self.shared.channel.receiver_dropped.load(Ordering::Acquire)
    }

    /// Returns the channel's current length counter.
    ///
    /// The value is read with a relaxed load, so it is only a snapshot.
    pub fn len(&self) -> usize {
        self.shared.channel.current_len.load(Ordering::Relaxed)
    }

    /// Returns `true` when the length counter reads zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the capacity reported by the gate (`0` denotes a rendezvous
    /// channel).
    pub fn capacity(&self) -> usize {
        self.shared.gate.capacity()
    }

    /// Returns `true` when the channel holds at least `capacity()` messages.
    ///
    /// `>=` is used rather than `==` so that a zero-capacity channel is
    /// reported as full even while a hand-off message is momentarily counted
    /// in `len()`; with `==` that state (`len == 1`, `capacity == 0`) would be
    /// reported as "not full".
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity()
    }

    /// Converts this handle into the synchronous `Sender` for the same
    /// channel.
    ///
    /// `Drop` is deliberately skipped for `self`, so the shared sender count
    /// is left unchanged: the returned `Sender` takes over this handle's slot.
    pub fn to_sync(self) -> Sender<T> {
        // SAFETY: `self.shared` is a valid, initialized `Arc`. `self` is
        // forgotten immediately afterwards, so the bitwise copy made here is
        // the only one that will ever be dropped (no double drop).
        let shared = unsafe { std::ptr::read(&self.shared) };
        mem::forget(self);
        Sender {
            shared,
            closed: AtomicBool::new(false),
        }
    }
}
impl<T: Send> Drop for AsyncSender<T> {
    /// Unregisters this sender from the channel, at most once per handle.
    ///
    /// When the handle being dropped is the last registered sender, the
    /// consumer is woken and one gate permit is released.
    fn drop(&mut self) {
        // `swap` returns the previous value; `true` means teardown already ran.
        let already_closed = self.closed.swap(true, Ordering::AcqRel);
        if already_closed {
            return;
        }

        let channel = &self.shared.channel;
        // `fetch_sub` yields the count *before* the decrement.
        let senders_before = channel.sender_count.fetch_sub(1, Ordering::AcqRel);
        if senders_before != 1 {
            return;
        }

        // Last sender gone.
        // NOTE(review): presumably the wake-up lets a parked receiver observe
        // the disconnect — confirm against the channel implementation.
        channel.wake_consumer();
        self.shared.gate.release();
    }
}
impl<T: Send> AsyncReceiver<T> {
    /// Returns a future that resolves to the next message, or to a
    /// `RecvError` when the channel reports one.
    pub fn recv(&self) -> RecvFuture<'_, T> {
        RecvFuture {
            receiver: self,
            rendezvous_permit_released: false,
        }
    }

    /// Attempts to receive a message without waiting.
    ///
    /// # Errors
    ///
    /// Returns `TryRecvError::Disconnected` when this receiver's `closed`
    /// flag is set; otherwise forwards whatever error the underlying channel
    /// reports.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(TryRecvError::Disconnected);
        }
        // Zero-capacity channel: hand a permit to the gate before looking for
        // a message.
        // NOTE(review): this releases a permit on every call, even when no
        // message is subsequently received — confirm that the gate tolerates
        // repeated releases.
        if self.capacity() == 0 {
            self.shared.gate.release();
        }
        self.shared.channel.try_recv_internal().map(|msg| msg.value)
    }

    /// Returns `true` when no senders remain registered and the length
    /// counter reads zero.
    pub fn is_closed(&self) -> bool {
        let chan = &self.shared.channel;
        chan.sender_count.load(Ordering::Acquire) == 0 && self.is_empty()
    }

    /// Returns the channel's current length counter.
    ///
    /// The value is read with a relaxed load, so it is only a snapshot.
    pub fn len(&self) -> usize {
        self.shared.channel.current_len.load(Ordering::Relaxed)
    }

    /// Returns `true` when the length counter reads zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the capacity reported by the gate (`0` denotes a rendezvous
    /// channel).
    pub fn capacity(&self) -> usize {
        self.shared.gate.capacity()
    }

    /// Returns `true` when the channel holds at least `capacity()` messages.
    ///
    /// `>=` is used rather than `==` so that a zero-capacity channel is
    /// reported as full even while a hand-off message is momentarily counted
    /// in `len()`; with `==` that state (`len == 1`, `capacity == 0`) would be
    /// reported as "not full".
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity()
    }

    /// Converts this handle into the synchronous `Receiver` for the same
    /// channel.
    ///
    /// `Drop` is deliberately skipped for `self`, so the channel is not
    /// flagged as receiver-dropped and no messages are drained.
    pub fn to_sync(self) -> Receiver<T> {
        // SAFETY: `self.shared` is a valid, initialized `Arc`. `self` is
        // forgotten immediately afterwards, so the bitwise copy made here is
        // the only one that will ever be dropped (no double drop).
        let shared = unsafe { std::ptr::read(&self.shared) };
        mem::forget(self);
        Receiver {
            shared,
            closed: AtomicBool::new(false),
        }
    }
}
impl<T: Send> Drop for AsyncReceiver<T> {
    /// Tears the receiving side down, at most once per handle.
    ///
    /// In order: flags the channel as receiver-dropped, discards every
    /// message still queued, then releases one gate permit.
    fn drop(&mut self) {
        // `swap` returns the previous value; `true` means teardown already ran.
        if self.closed.swap(true, Ordering::AcqRel) {
            return;
        }

        let shared = &self.shared;
        // Publish the flag first so senders checking `is_closed` see it.
        shared
            .channel
            .receiver_dropped
            .store(true, Ordering::Release);

        // Drain: each received message is dropped as soon as it is popped.
        loop {
            if shared.channel.try_recv_internal().is_err() {
                break;
            }
        }

        // NOTE(review): presumably this wakes a sender parked on the gate so
        // it can observe the closed channel — confirm against the gate.
        shared.gate.release();
    }
}
/// Future returned by `AsyncSender::send`.
///
/// It first waits on the gate's acquire future and then queues the message.
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct SendFuture<'a, T: Send> {
/// Pending permit acquisition, created from `gate.acquire_async()`.
acquire: crate::coord::AcquireFuture<'a>,
/// The sender this future was created from.
sender: &'a AsyncSender<T>,
/// The message to send; `None` once it has been sent or discarded.
value: Option<T>,
/// `true` when the channel capacity was `0` at creation time; copied into
/// the `Permit` attached to the message.
is_rendezvous: bool,
/// Opts the future out of `Unpin`.
_phantom: PhantomPinned,
}
impl<'a, T: Send> Future for SendFuture<'a, T> {
    type Output = Result<(), SendError>;

    /// Drives the send: bail out if the receiver is gone, wait for a gate
    /// permit, then queue the message together with that permit.
    ///
    /// # Panics
    ///
    /// Panics if polled again after the message has already been taken and
    /// the gate acquisition reports ready.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: no field is treated as structurally pinned. `acquire` is
        // re-pinned below with `Pin::new`, which only compiles for `Unpin`
        // types, and `value` is a plain `Option<T>` that is never pinned, so
        // obtaining `&mut Self` cannot move anything that relies on pinning.
        let fut = unsafe { self.as_mut().get_unchecked_mut() };

        if fut.sender.is_closed() {
            // Drop the payload right away; it can never be delivered.
            fut.value = None;
            return Poll::Ready(Err(SendError::Closed));
        }

        if Pin::new(&mut fut.acquire).poll(cx).is_pending() {
            return Poll::Pending;
        }

        let payload = fut
            .value
            .take()
            .expect("SendFuture polled after completion");
        // The permit rides inside the message so it lives as long as the
        // message does.
        let message = BoundedMessage {
            value: payload,
            _permit: Permit {
                gate: fut.sender.shared.gate.clone(),
                is_rendezvous: fut.is_rendezvous,
            },
        };

        let outcome = unbounded::send_internal(&fut.sender.shared.channel, message)
            .map_err(|_| SendError::Closed);
        Poll::Ready(outcome)
    }
}
/// Future returned by `AsyncReceiver::recv`.
#[must_use = "futures do nothing unless you .await or poll them"]
#[derive(Debug)]
pub struct RecvFuture<'a, T: Send> {
/// The receiver this future was created from.
receiver: &'a AsyncReceiver<T>,
/// Set after this future has released a gate permit for a zero-capacity
/// channel, so later polls do not release a second one.
rendezvous_permit_released: bool,
}
impl<'a, T: Send> Future for RecvFuture<'a, T> {
    type Output = Result<T, RecvError>;

    /// Polls the channel for the next message.
    ///
    /// Resolves to `Err(RecvError::Disconnected)` when the receiver's
    /// `closed` flag is set; otherwise forwards the channel's result with the
    /// message unwrapped to its payload.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let fut = self.get_mut();
        let receiver = fut.receiver;

        if receiver.closed.load(Ordering::Relaxed) {
            return Poll::Ready(Err(RecvError::Disconnected));
        }

        // Zero-capacity channel: release exactly one permit per future, on
        // the first poll that gets this far.
        if receiver.capacity() == 0 && !fut.rendezvous_permit_released {
            receiver.shared.gate.release();
            fut.rendezvous_permit_released = true;
        }

        receiver
            .shared
            .channel
            .poll_recv_internal(cx)
            .map(|outcome| outcome.map(|msg| msg.value))
    }
}
impl<T: Send> Stream for AsyncReceiver<T> {
    type Item = T;

    /// Yields the next message, or `None` when the receiver's `closed` flag
    /// is set or the channel reports an error.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let receiver = self.get_mut();

        if receiver.closed.load(Ordering::Relaxed) {
            return Poll::Ready(None);
        }

        // NOTE(review): unlike `RecvFuture`, nothing here remembers that a
        // permit was already released, so a zero-capacity channel releases
        // one on every poll — confirm this is intended.
        if receiver.capacity() == 0 {
            receiver.shared.gate.release();
        }

        receiver
            .shared
            .channel
            .poll_recv_internal(cx)
            .map(|outcome| outcome.ok().map(|msg| msg.value))
    }
}