use super::bounded_sync::{
unwrap_batch_messages, BoundedMessage, BoundedMpscShared, Permit, Receiver, Sender,
};
use crate::error::{
BatchSendErrorReason, RecvError, SendBatchError, SendError, TrySendBatchError, TrySendError,
};
use crate::mpsc::unbounded_v2;
use crate::{CloseError, TryRecvError};
use futures_core::Stream;
use std::future::Future;
use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
/// Asynchronous sending half of the bounded MPSC channel.
///
/// Cloning registers another sender in `shared.channel.sender_count`;
/// dropping the handle or calling [`AsyncSender::close`] deregisters it once.
#[derive(Debug)]
pub struct AsyncSender<T: Send> {
/// Channel state shared with the receiver and all other senders.
pub(crate) shared: Arc<BoundedMpscShared<T>>,
/// Set once this handle has been closed (explicitly or on drop), so the
/// `sender_count` decrement in `close_internal` runs only once.
pub(crate) closed: AtomicBool,
}
/// Asynchronous receiving half of the bounded MPSC channel.
///
/// Closing it (explicitly or on drop) marks the receiver as dropped, discards
/// any still-buffered messages and closes the permit gate.
#[derive(Debug)]
pub struct AsyncReceiver<T: Send> {
/// Channel state shared with the senders.
pub(crate) shared: Arc<BoundedMpscShared<T>>,
/// Set once this handle has been closed, so teardown runs only once and
/// later receive calls report disconnection.
pub(crate) closed: AtomicBool,
}
impl<T: Send> AsyncSender<T> {
pub fn send(&self, value: T) -> SendFuture<'_, T> {
SendFuture {
acquire: self.shared.gate.acquire_async(),
sender: self,
value: Some(value),
is_rendezvous: self.capacity() == 0, _phantom: PhantomPinned,
}
}
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
if self.is_closed() {
return Err(TrySendError::Closed(value));
}
if !self.shared.gate.try_acquire() {
return Err(TrySendError::Full(value));
}
let permit = Permit {
gate: self.shared.gate.clone(),
is_rendezvous: self.capacity() == 0,
};
let message = BoundedMessage {
value,
_permit: permit,
};
let mut cache = None;
if let Err(msg) = unbounded_v2::send_internal(&self.shared.channel, message, &mut cache) {
return Err(TrySendError::Closed(msg.value));
}
Ok(())
}
pub fn send_batch(&self, items: Vec<T>) -> SendBatchFuture<'_, T> {
let total = items.len();
SendBatchFuture {
acquire: self.shared.gate.acquire_many_async(total),
sender: self,
iter: items.into_iter(),
total,
sent: 0,
is_rendezvous: self.capacity() == 0,
_phantom: PhantomPinned,
}
}
pub fn send_batch_mut<'a>(&'a self, items: &'a mut Vec<T>) -> SendBatchMutFuture<'a, T> {
let len = items.len();
SendBatchMutFuture {
acquire: self.shared.gate.acquire_many_async(len),
sender: self,
items,
sent: 0,
is_rendezvous: self.capacity() == 0,
_phantom: PhantomPinned,
}
}
pub fn try_send_batch(&self, items: Vec<T>) -> Result<usize, TrySendBatchError<T>> {
let total = items.len();
if total == 0 {
return Ok(0);
}
if self.closed.load(Ordering::Relaxed) || self.is_closed() {
return Err(TrySendBatchError {
sent: 0,
unsent: items,
reason: BatchSendErrorReason::Closed,
});
}
let k = self.shared.gate.try_acquire_many(total);
if k == 0 {
return Err(TrySendBatchError {
sent: 0,
unsent: items,
reason: BatchSendErrorReason::Full,
});
}
if self.is_closed() {
return Err(TrySendBatchError {
sent: 0,
unsent: items,
reason: BatchSendErrorReason::Closed,
});
}
let mut iter = items.into_iter();
push_batch_messages_async(self, &mut iter, k);
if k == total {
Ok(total)
} else {
Err(TrySendBatchError {
sent: k,
unsent: iter.collect(),
reason: BatchSendErrorReason::Full,
})
}
}
pub fn try_send_batch_mut(&self, items: &mut Vec<T>) -> Result<usize, SendError> {
if items.is_empty() {
return Ok(0);
}
if self.closed.load(Ordering::Relaxed) || self.is_closed() {
return Err(SendError::Closed);
}
let k = self.shared.gate.try_acquire_many(items.len());
if k == 0 {
return Ok(0);
}
if self.is_closed() {
return Err(SendError::Closed);
}
let mut drain = items.drain(..k);
push_batch_messages_async(self, &mut drain, k);
drop(drain);
Ok(k)
}
pub fn close(&self) -> Result<(), CloseError> {
if self
.closed
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.close_internal();
Ok(())
} else {
Err(CloseError)
}
}
fn close_internal(&self) {
if self
.shared
.channel
.sender_count
.fetch_sub(1, Ordering::AcqRel)
== 1
{
self.shared.channel.wake_consumer();
self.shared.gate.release();
}
}
pub fn is_closed(&self) -> bool {
self.shared.channel.receiver_dropped.load(Ordering::Acquire)
}
pub fn len(&self) -> usize {
self.shared.channel.current_len.load(Ordering::Relaxed)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn capacity(&self) -> usize {
self.shared.gate.capacity()
}
pub fn is_full(&self) -> bool {
self.len() == self.capacity()
}
pub fn to_sync(self) -> Sender<T> {
let shared = unsafe { std::ptr::read(&self.shared) };
mem::forget(self);
Sender {
shared,
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Clone for AsyncSender<T> {
fn clone(&self) -> Self {
self
.shared
.channel
.sender_count
.fetch_add(1, Ordering::Relaxed);
Self {
shared: self.shared.clone(),
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Drop for AsyncSender<T> {
    /// Deregisters this sender unless an earlier `close()` already did.
    fn drop(&mut self) {
        let previously_closed = self.closed.swap(true, Ordering::AcqRel);
        if previously_closed {
            return;
        }
        self.close_internal();
    }
}
fn push_batch_messages_async<T: Send>(
sender: &AsyncSender<T>,
iter: &mut impl Iterator<Item = T>,
k: usize,
) {
let is_rendezvous = sender.capacity() == 0;
let shared = &sender.shared;
let mut msg_iter = iter.by_ref().map(|value| BoundedMessage {
value,
_permit: Permit {
gate: shared.gate.clone(),
is_rendezvous,
},
});
let mut cache = None;
unbounded_v2::send_batch_internal(&shared.channel, &mut msg_iter, k, &mut cache);
}
impl<T: Send> AsyncReceiver<T> {
pub fn recv(&self) -> RecvFuture<'_, T> {
RecvFuture {
receiver: self,
rendezvous_permit_released: false,
}
}
pub fn recv_batch(&self, max: usize) -> RecvBatchFuture<'_, T> {
RecvBatchFuture {
receiver: self,
max,
rendezvous_permit_released: false,
}
}
pub fn recv_batch_mut<'a>(&'a self, out: &'a mut Vec<T>, max: usize) -> RecvBatchMutFuture<'a, T> {
RecvBatchMutFuture {
receiver: self,
out,
max,
rendezvous_permit_released: false,
}
}
pub fn try_recv_batch(&self, max: usize) -> Result<Vec<T>, TryRecvError> {
let mut out = Vec::new();
self.try_recv_batch_mut(&mut out, max)?;
Ok(out)
}
pub fn try_recv_batch_mut(&self, out: &mut Vec<T>, max: usize) -> Result<usize, TryRecvError> {
if max == 0 {
return Ok(0);
}
if self.closed.load(Ordering::Relaxed) {
return Err(TryRecvError::Disconnected);
}
if self.capacity() == 0 {
self.shared.gate.release();
}
let mut msgs = Vec::new();
let k = self.shared.channel.try_recv_batch_internal(&mut msgs, max)?;
debug_assert_eq!(k, msgs.len());
Ok(unwrap_batch_messages(&self.shared.gate, msgs, out))
}
pub fn try_recv(&self) -> Result<T, TryRecvError> {
if self.closed.load(Ordering::Relaxed) {
return Err(TryRecvError::Disconnected);
}
if self.capacity() == 0 {
self.shared.gate.release();
}
self.shared.channel.try_recv_internal().map(|msg| msg.value)
}
pub fn close(&self) -> Result<(), CloseError> {
if self
.closed
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.close_internal();
Ok(())
} else {
Err(CloseError)
}
}
fn close_internal(&self) {
self
.shared
.channel
.receiver_dropped
.store(true, Ordering::Release);
while self.shared.channel.try_recv_internal().is_ok() {}
self.shared.gate.close();
}
pub fn is_closed(&self) -> bool {
let chan = &self.shared.channel;
chan.sender_count.load(Ordering::Acquire) == 0 && self.is_empty()
}
pub fn len(&self) -> usize {
self.shared.channel.current_len.load(Ordering::Relaxed)
}
pub fn sender_count(&self) -> usize {
self.shared.channel.sender_count.load(Ordering::Relaxed)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn capacity(&self) -> usize {
self.shared.gate.capacity()
}
pub fn is_full(&self) -> bool {
self.len() == self.capacity()
}
pub fn to_sync(self) -> Receiver<T> {
let shared = unsafe { std::ptr::read(&self.shared) };
mem::forget(self);
Receiver {
shared,
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Drop for AsyncReceiver<T> {
    /// Runs receiver teardown unless an earlier `close()` already did.
    fn drop(&mut self) {
        let previously_closed = self.closed.swap(true, Ordering::AcqRel);
        if previously_closed {
            return;
        }
        self.close_internal();
    }
}
/// Future returned by [`AsyncSender::send`].
///
/// Holds the value until a gate permit is acquired, then enqueues it.
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct SendFuture<'a, T: Send> {
/// Pending acquisition of a single gate permit.
acquire: crate::coord::AcquireFuture<'a>,
/// Sender the value is sent through.
sender: &'a AsyncSender<T>,
/// Value to send; `None` once it has been sent or dropped on close.
value: Option<T>,
/// Whether the channel had capacity 0 when the future was created.
is_rendezvous: bool,
/// Makes the future `!Unpin`; `poll` projects a pinned reference to `acquire`.
_phantom: PhantomPinned,
}
impl<'a, T: Send> Future for SendFuture<'a, T> {
type Output = Result<(), SendError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.as_mut().get_unchecked_mut() };
if this.sender.is_closed() {
this.value = None;
return Poll::Ready(Err(SendError::Closed));
}
match unsafe { Pin::new_unchecked(&mut this.acquire) }.poll(cx) {
Poll::Ready(()) => {
let value = this
.value
.take()
.expect("SendFuture polled after completion");
let permit = Permit {
gate: this.sender.shared.gate.clone(),
is_rendezvous: this.is_rendezvous,
};
let message = BoundedMessage {
value,
_permit: permit,
};
let mut cache = None;
match unbounded_v2::send_internal(&this.sender.shared.channel, message, &mut cache) {
Ok(()) => Poll::Ready(Ok(())),
Err(_) => Poll::Ready(Err(SendError::Closed)),
}
}
Poll::Pending => Poll::Pending,
}
}
}
/// Future returned by [`AsyncSender::send_batch`].
///
/// Sends the owned batch in chunks as gate permits become available. It
/// resolves to the batch size, or to a [`SendBatchError`] holding the unsent
/// remainder if the channel closes.
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct SendBatchFuture<'a, T: Send> {
/// Pending acquisition of permits for the items still to send.
acquire: crate::coord::AcquireManyFuture<'a>,
/// Sender the items are sent through.
sender: &'a AsyncSender<T>,
/// Items not yet sent.
iter: std::vec::IntoIter<T>,
/// Size of the original batch.
total: usize,
/// Number of items sent so far.
sent: usize,
/// Whether the channel had capacity 0 when the future was created.
is_rendezvous: bool,
/// Makes the future `!Unpin`; `poll` projects a pinned reference to `acquire`.
_phantom: PhantomPinned,
}
impl<'a, T: Send> Future for SendBatchFuture<'a, T> {
type Output = Result<usize, SendBatchError<T>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.as_mut().get_unchecked_mut() };
loop {
if this.sent == this.total {
return Poll::Ready(Ok(this.total));
}
if this.sender.is_closed() {
return Poll::Ready(Err(SendBatchError {
sent: this.sent,
unsent: this.iter.by_ref().collect(),
}));
}
match unsafe { Pin::new_unchecked(&mut this.acquire) }.poll(cx) {
Poll::Ready(k) => {
if this.sender.is_closed() {
return Poll::Ready(Err(SendBatchError {
sent: this.sent,
unsent: this.iter.by_ref().collect(),
}));
}
let k = k.min(this.total - this.sent);
push_batch_messages_async_rendezvous(this.sender, &mut this.iter, k, this.is_rendezvous);
this.sent += k;
if this.sent == this.total {
return Poll::Ready(Ok(this.total));
}
this.acquire = this.sender.shared.gate.acquire_many_async(this.total - this.sent);
}
Poll::Pending => return Poll::Pending,
}
}
}
}
/// Future returned by [`AsyncSender::send_batch_mut`].
///
/// Drains items from the front of a caller-owned vector as permits become
/// available and resolves to the number sent. Items not yet sent stay in
/// the vector.
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct SendBatchMutFuture<'a, T: Send> {
/// Pending acquisition of permits for the items still in `items`.
acquire: crate::coord::AcquireManyFuture<'a>,
/// Sender the items are sent through.
sender: &'a AsyncSender<T>,
/// Caller's buffer; sent items are drained from its front.
items: &'a mut Vec<T>,
/// Number of items sent so far.
sent: usize,
/// Whether the channel had capacity 0 when the future was created.
is_rendezvous: bool,
/// Makes the future `!Unpin`; `poll` projects a pinned reference to `acquire`.
_phantom: PhantomPinned,
}
impl<'a, T: Send> Future for SendBatchMutFuture<'a, T> {
type Output = Result<usize, SendError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.as_mut().get_unchecked_mut() };
loop {
if this.items.is_empty() {
return Poll::Ready(Ok(this.sent));
}
if this.sender.is_closed() {
return Poll::Ready(Err(SendError::Closed));
}
match unsafe { Pin::new_unchecked(&mut this.acquire) }.poll(cx) {
Poll::Ready(k) => {
if this.sender.is_closed() {
return Poll::Ready(Err(SendError::Closed));
}
let k = k.min(this.items.len());
{
let mut drain = this.items.drain(..k);
push_batch_messages_async_rendezvous(this.sender, &mut drain, k, this.is_rendezvous);
}
this.sent += k;
if this.items.is_empty() {
return Poll::Ready(Ok(this.sent));
}
this.acquire = this.sender.shared.gate.acquire_many_async(this.items.len());
}
Poll::Pending => return Poll::Pending,
}
}
}
}
/// Wraps each value from `iter` in a permit-carrying message and hands the
/// stream to `unbounded_v2::send_batch_internal` with batch size `k`.
///
/// `is_rendezvous` is stored on every permit as given.
fn push_batch_messages_async_rendezvous<T: Send>(
    sender: &AsyncSender<T>,
    iter: &mut impl Iterator<Item = T>,
    k: usize,
    is_rendezvous: bool,
) {
    let gate = &sender.shared.gate;
    let mut messages = iter.by_ref().map(|value| {
        let permit = Permit {
            gate: gate.clone(),
            is_rendezvous,
        };
        BoundedMessage {
            value,
            _permit: permit,
        }
    });
    let mut send_cache = None;
    unbounded_v2::send_batch_internal(&sender.shared.channel, &mut messages, k, &mut send_cache);
}
/// Future returned by [`AsyncReceiver::recv_batch`].
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct RecvBatchFuture<'a, T: Send> {
/// Receiver the values are taken from.
receiver: &'a AsyncReceiver<T>,
/// Maximum number of values to return.
max: usize,
/// Set after this future releases its rendezvous gate permit, so later
/// polls don't release again.
rendezvous_permit_released: bool,
}
impl<'a, T: Send> Future for RecvBatchFuture<'a, T> {
    type Output = Result<Vec<T>, RecvError>;

    /// Receives up to `max` values into a fresh `Vec`; `max == 0` completes
    /// immediately with an empty vector.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let max = this.max;
        if max == 0 {
            return Poll::Ready(Ok(Vec::new()));
        }
        let mut received = Vec::new();
        let polled = poll_recv_batch_bounded(
            this.receiver,
            cx,
            &mut received,
            max,
            &mut this.rendezvous_permit_released,
        );
        polled.map(|result| result.map(|_count| received))
    }
}
/// Future returned by [`AsyncReceiver::recv_batch_mut`].
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct RecvBatchMutFuture<'a, T: Send> {
/// Receiver the values are taken from.
receiver: &'a AsyncReceiver<T>,
/// Caller-provided buffer that receives the values.
out: &'a mut Vec<T>,
/// Maximum number of values to receive.
max: usize,
/// Set after this future releases its rendezvous gate permit, so later
/// polls don't release again.
rendezvous_permit_released: bool,
}
impl<'a, T: Send> Future for RecvBatchMutFuture<'a, T> {
    type Output = Result<usize, RecvError>;

    /// Receives up to `max` values into the caller's buffer and resolves to
    /// the count; `max == 0` completes immediately with `0`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match this.max {
            0 => Poll::Ready(Ok(0)),
            max => poll_recv_batch_bounded(
                this.receiver,
                cx,
                &mut *this.out,
                max,
                &mut this.rendezvous_permit_released,
            ),
        }
    }
}
/// Shared poll logic for the batch-receive futures.
///
/// Fails with `RecvError::Disconnected` if the receiver handle is closed. On a
/// rendezvous channel it releases one gate permit, at most once per future,
/// using `rendezvous_permit_released` as the flag. It then polls the channel
/// for up to `max` messages and unwraps them into `out`, returning the count
/// reported by `unwrap_batch_messages`.
fn poll_recv_batch_bounded<T: Send>(
    receiver: &AsyncReceiver<T>,
    cx: &mut Context<'_>,
    out: &mut Vec<T>,
    max: usize,
    rendezvous_permit_released: &mut bool,
) -> Poll<Result<usize, RecvError>> {
    if receiver.closed.load(Ordering::Relaxed) {
        return Poll::Ready(Err(RecvError::Disconnected));
    }
    let must_release = receiver.capacity() == 0 && !*rendezvous_permit_released;
    if must_release {
        receiver.shared.gate.release();
        *rendezvous_permit_released = true;
    }
    let mut batch = Vec::new();
    let polled = receiver
        .shared
        .channel
        .poll_recv_batch_internal(cx, &mut batch, max);
    match polled {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        Poll::Ready(Ok(count)) => {
            debug_assert_eq!(count, batch.len());
            let gate = &receiver.shared.gate;
            Poll::Ready(Ok(unwrap_batch_messages(gate, batch, out)))
        }
    }
}
/// Future returned by [`AsyncReceiver::recv`].
#[must_use = "futures do nothing unless you .await or poll them"]
#[derive(Debug)]
pub struct RecvFuture<'a, T: Send> {
/// Receiver the value is taken from.
receiver: &'a AsyncReceiver<T>,
/// Set after this future releases its rendezvous gate permit, so later
/// polls don't release again.
rendezvous_permit_released: bool,
}
impl<'a, T: Send> Future for RecvFuture<'a, T> {
    type Output = Result<T, RecvError>;

    /// Polls for the next message and yields its value.
    ///
    /// Fails with `RecvError::Disconnected` if the receiver handle is closed.
    /// On a rendezvous channel it releases one gate permit, only on the first
    /// poll.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let receiver = this.receiver;
        if receiver.closed.load(Ordering::Relaxed) {
            return Poll::Ready(Err(RecvError::Disconnected));
        }
        let must_release = receiver.capacity() == 0 && !this.rendezvous_permit_released;
        if must_release {
            receiver.shared.gate.release();
            this.rendezvous_permit_released = true;
        }
        receiver
            .shared
            .channel
            .poll_recv_internal(cx)
            .map(|outcome| outcome.map(|msg| msg.value))
    }
}
impl<T: Send> Stream for AsyncReceiver<T> {
    type Item = T;

    /// Yields the next received value. The stream ends (`None`) once the
    /// handle is closed or the channel reports an error.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let receiver = self.get_mut();
        if receiver.closed.load(Ordering::Relaxed) {
            return Poll::Ready(None);
        }
        // NOTE(review): unlike `RecvFuture`, there is no per-item flag here, so
        // a rendezvous channel releases a gate permit on every poll, including
        // ones that return `Pending`. Confirm the gate tolerates this.
        if receiver.capacity() == 0 {
            receiver.shared.gate.release();
        }
        receiver
            .shared
            .channel
            .poll_recv_internal(cx)
            .map(|outcome| outcome.ok().map(|msg| msg.value))
    }
}