use crate::flavor::FlavorMP;
use crate::sink::AsyncSink;
#[cfg(feature = "trace_log")]
use crate::tokio_task_id;
use crate::weak::WeakTx;
use crate::{shared::*, trace_log, MTx, NotCloneable, SenderType, Tx};
use std::cell::Cell;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::mem::{needs_drop, MaybeUninit};
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Asynchronous sending handle of a channel.
///
/// The handle keeps a reference-counted pointer to the channel state and calls
/// `close_tx()` on that state when it is dropped.
pub struct AsyncTx<F: Flavor> {
/// Channel state reached through this handle.
pub(crate) shared: Arc<ChannelShared<F>>,
// `Cell<()>` is `!Sync`, so this zero-sized marker keeps the auto trait `Sync` from being
// derived for the handle.
_phan: PhantomData<Cell<()>>,
}
/// Debug output is the literal `AsyncTx` followed by the address of this handle.
impl<F: Flavor> fmt::Debug for AsyncTx<F> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("AsyncTx{:p}", self))
    }
}
/// Display output is the literal `AsyncTx` followed by the address of this handle.
impl<F: Flavor> fmt::Display for AsyncTx<F> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("AsyncTx{:p}", self))
    }
}
// SAFETY: NOTE(review) — this declares the handle movable across threads for every flavor,
// with no bound on `F` or `F::Item`. Soundness therefore rests on `ChannelShared<F>` being
// safe to use from whichever thread owns the handle — TODO confirm against its definition.
unsafe impl<F: Flavor> Send for AsyncTx<F> {}
/// Dropping the handle notifies the channel state through `close_tx()`.
impl<F: Flavor> Drop for AsyncTx<F> {
    #[inline(always)]
    fn drop(&mut self) {
        let shared = &self.shared;
        shared.close_tx();
    }
}
impl<F: Flavor> From<Tx<F>> for AsyncTx<F> {
fn from(value: Tx<F>) -> Self {
value.add_tx();
Self::new(value.shared.clone())
}
}
impl<F: Flavor> AsyncTx<F> {
#[inline]
pub(crate) fn new(shared: Arc<ChannelShared<F>>) -> Self {
Self { shared, _phan: Default::default() }
}
#[inline]
pub fn into_sink(self) -> AsyncSink<F> {
AsyncSink::new(self)
}
#[inline]
pub fn into_blocking(self) -> Tx<F> {
self.into()
}
#[inline(always)]
pub fn is_disconnected(&self) -> bool {
self.shared.is_rx_closed()
}
}
impl<F: Flavor> AsyncTx<F> {
    /// Creates a future that sends `item` into the channel.
    ///
    /// Nothing is attempted until the returned [`SendFuture`] is polled. The future resolves
    /// to `Ok(())` once the item has been handed over, or to `Err(SendError(item))` when
    /// [`poll_send`](Self::poll_send) reports failure.
    #[inline(always)]
    pub fn send<'a>(&'a self, item: F::Item) -> SendFuture<'a, F> {
        SendFuture { tx: self, item: MaybeUninit::new(item), waker: None }
    }

    /// Attempts to send `item` without waiting.
    ///
    /// # Errors
    ///
    /// - `TrySendError::Disconnected(item)` when the channel state reports `is_rx_closed()`.
    /// - `TrySendError::Full(item)` when the underlying `try_send` returns `false`.
    #[inline]
    pub fn try_send(&self, item: F::Item) -> Result<(), TrySendError<F::Item>> {
        if self.shared.is_rx_closed() {
            return Err(TrySendError::Disconnected(item));
        }
        let _item = MaybeUninit::new(item);
        if self.shared.inner.try_send(&_item) {
            self.shared.on_send();
            Ok(())
        } else {
            // SAFETY: `_item` was initialised just above; presumably a `false` return means
            // the queue did not take the value, so it is still owned here — TODO confirm
            // against the queue implementation.
            unsafe { Err(TrySendError::Full(_item.assume_init())) }
        }
    }

    /// Sends `item`, giving up once a tokio sleep of `duration` completes.
    ///
    /// NOTE(review): the `tokio` and `async_std` variants of this method are not mutually
    /// exclusive at the `cfg` level — confirm the crate forbids enabling both features.
    #[cfg(feature = "tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    #[inline]
    pub fn send_timeout(
        &self, item: F::Item, duration: std::time::Duration,
    ) -> SendTimeoutFuture<'_, F, tokio::time::Sleep, ()> {
        let sleep = tokio::time::sleep(duration);
        self.send_with_timer(item, sleep)
    }

    /// Sends `item`, giving up once an async-std sleep of `duration` completes.
    #[cfg(feature = "async_std")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async_std")))]
    #[inline]
    pub fn send_timeout(
        &self, item: F::Item, duration: std::time::Duration,
    ) -> SendTimeoutFuture<'_, F, impl Future<Output = ()>, ()> {
        let sleep = async_std::task::sleep(duration);
        self.send_with_timer(item, sleep)
    }

    /// Sends `item`, using the caller-supplied future `fut` as the timer.
    ///
    /// The returned [`SendTimeoutFuture`] polls `fut` whenever the send itself is pending and
    /// treats its completion as the timeout.
    #[inline]
    pub fn send_with_timer<FR, R>(&self, item: F::Item, fut: FR) -> SendTimeoutFuture<'_, F, FR, R>
    where
        FR: Future<Output = R>,
    {
        SendTimeoutFuture { tx: self, item: MaybeUninit::new(item), waker: None, sleep: fut }
    }

    /// Core polling routine shared by the send futures.
    ///
    /// `item` is the value to send and `o_waker` is the caller-owned slot holding this
    /// send's waker registration between polls.
    ///
    /// Returns:
    /// - `Poll::Ready(Ok(()))` when `try_send` succeeds or the double check reports
    ///   `WakerState::Done`; `o_waker` is emptied in both cases.
    /// - `Poll::Ready(Err(()))` when the receiver side is closed, when `reg_waker_async`
    ///   yields `Some(Poll::Ready(()))`, or when the double check reports
    ///   `WakerState::Closed`. Only the last case empties `o_waker` here; the callers in
    ///   this file clear the slot themselves on error.
    /// - `Poll::Pending` when the waker stays registered and the send has not completed.
    #[inline(always)]
    pub(crate) fn poll_send<'a, const SINK: bool>(
        &self, ctx: &'a mut Context, item: &MaybeUninit<F::Item>,
        o_waker: &'a mut Option<<F::Send as Registry>::Waker>,
    ) -> Poll<Result<(), ()>> {
        let shared = &self.shared;
        if shared.is_rx_closed() {
            trace_log!("tx{:?}: closed {:?}", tokio_task_id!(), o_waker);
            return Poll::Ready(Err(()));
        }
        loop {
            if shared.inner.try_send(item) {
                shared.on_send();
                if let Some(_waker) = o_waker.take() {
                    trace_log!("tx{:?}: send {:?}", tokio_task_id!(), _waker);
                } else {
                    trace_log!("tx{:?}: send", tokio_task_id!());
                }
                return Poll::Ready(Ok(()));
            }
            // Only spin on the first attempt, i.e. before any waker has been registered.
            if o_waker.is_none() {
                if let Some(mut backoff) = shared.get_async_backoff() {
                    loop {
                        backoff.spin();
                        if shared.inner.try_send(item) {
                            shared.on_send();
                            trace_log!("tx{:?}: send", tokio_task_id!());
                            return Poll::Ready(Ok(()));
                        }
                        if backoff.is_completed() {
                            break;
                        }
                    }
                }
            }
            match shared.senders.reg_waker_async(ctx, o_waker) {
                Some(Poll::Pending) => return Poll::Pending,
                Some(Poll::Ready(())) => return Poll::Ready(Err(())),
                _ => {}
            }
            // Re-check after registering the waker.
            let state = shared.sender_double_check::<SINK>(item, o_waker);
            trace_log!("tx{:?}: sender_double_check {:?} {}", tokio_task_id!(), o_waker, state);
            if state < WakerState::Woken as u8 {
                return Poll::Pending;
            } else if state > WakerState::Woken as u8 {
                if state == WakerState::Done as u8 {
                    // Task id first, waker second, matching every other trace line here.
                    trace_log!("tx{:?}: send {:?} done", tokio_task_id!(), o_waker);
                    let _ = o_waker.take();
                    return Poll::Ready(Ok(()));
                } else {
                    debug_assert_eq!(state, WakerState::Closed as u8);
                    trace_log!("tx{:?}: closed {:?}", tokio_task_id!(), o_waker);
                    let _ = o_waker.take();
                    return Poll::Ready(Err(()));
                }
            }
            // Woken in between: go round again and retry the send.
            debug_assert_eq!(state, WakerState::Woken as u8);
            continue;
        }
    }
}
/// Future returned by `AsyncTx::send`.
///
/// NOTE(review): the item is stored in a `MaybeUninit` and the `Drop` impl only releases it
/// when a waker is present; a future dropped before its first poll (waker still `None`)
/// appears to leak the item without running its destructor — TODO confirm this is intended.
#[must_use]
pub struct SendFuture<'a, F: Flavor> {
// Sender the item is sent through.
tx: &'a AsyncTx<F>,
// The item to send; read back out with `assume_init_read` when the send fails.
item: MaybeUninit<F::Item>,
// Waker slot handed to `poll_send`; starts as `None`.
waker: Option<<F::Send as Registry>::Waker>,
}
// SAFETY: the future owns an `F::Item`, hence the `F::Item: Send` bound.
// NOTE(review): the waker type `<F::Send as Registry>::Waker` is not bounded here; its
// thread-safety is assumed — TODO confirm against the `Registry` implementations.
unsafe impl<F: Flavor> Send for SendFuture<'_, F> where F::Item: Send {}
/// Cancels a registered send: the waker is abandoned and, when `abandon_send_waker`
/// returns `true`, the still-owned item is dropped in place.
impl<F: Flavor> Drop for SendFuture<'_, F> {
    #[inline]
    fn drop(&mut self) {
        // Without a registered waker there is nothing to cancel.
        let Some(waker) = self.waker.as_ref() else {
            return;
        };
        let abandoned = self.tx.shared.abandon_send_waker(waker);
        if abandoned && needs_drop::<F::Item>() {
            unsafe { self.item.assume_init_drop() };
        }
    }
}
/// Polls the send through `AsyncTx::poll_send` and maps a failure to `SendError`,
/// handing the item back to the caller.
impl<F: Flavor> Future for SendFuture<'_, F>
where
    F::Item: Unpin,
{
    type Output = Result<(), SendError<F::Item>>;

    #[inline]
    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        let Poll::Ready(outcome) = this.tx.poll_send::<false>(ctx, &this.item, &mut this.waker)
        else {
            return Poll::Pending;
        };
        Poll::Ready(match outcome {
            Ok(()) => {
                debug_assert!(this.waker.is_none());
                Ok(())
            }
            Err(()) => {
                // Clear the waker slot so `Drop` does not touch the item again.
                let _ = this.waker.take();
                Err(SendError(unsafe { this.item.assume_init_read() }))
            }
        })
    }
}
/// Future returned by `AsyncTx::send_with_timer` (and `send_timeout`): a send that gives up
/// once the timer future `sleep` completes.
///
/// NOTE(review): as the `Drop` impl only releases the item when a waker is present, a future
/// dropped before its first poll appears to leak the item — TODO confirm this is intended.
#[must_use]
pub struct SendTimeoutFuture<'a, F, FR, R>
where
F: Flavor,
FR: Future<Output = R>,
{
// Sender the item is sent through.
tx: &'a AsyncTx<F>,
// Timer future; polled only while the send itself is pending.
sleep: FR,
// The item to send; read back out with `assume_init_read` on failure or timeout.
item: MaybeUninit<F::Item>,
// Waker slot handed to `poll_send`; starts as `None`.
waker: Option<<F::Send as Registry>::Waker>,
}
// SAFETY: the future owns an `F::Item` (inside `MaybeUninit`), so it may only cross threads
// when the item itself is `Send`; this is the same bound the `Send` impl of `SendFuture`
// carries.
// NOTE(review): the timer future `FR` is stored by value but is not required to be `Send`
// here, because `AsyncTxTrait::send_with_timer` promises a `Send` future without bounding
// `FR` — confirm whether `FR: Send` should be demanded as well.
unsafe impl<F, FR, R> Send for SendTimeoutFuture<'_, F, FR, R>
where
    F: Flavor,
    F::Item: Send,
    FR: Future<Output = R>,
{
}
/// Cancels a registered send: the waker is abandoned and, when `abandon_send_waker`
/// returns `true`, the still-owned item is dropped in place.
impl<F, FR, R> Drop for SendTimeoutFuture<'_, F, FR, R>
where
    F: Flavor,
    FR: Future<Output = R>,
{
    #[inline]
    fn drop(&mut self) {
        // Without a registered waker there is nothing to cancel.
        let Some(waker) = self.waker.as_ref() else {
            return;
        };
        let abandoned = self.tx.shared.abandon_send_waker(waker);
        if abandoned && needs_drop::<F::Item>() {
            unsafe { self.item.assume_init_drop() };
        }
    }
}
/// Polls the send first; only while it is pending is the timer future polled.
///
/// Resolves to `Ok(())` on a completed send, `Err(SendTimeoutError::Disconnected(item))` when
/// `poll_send` fails, and `Err(SendTimeoutError::Timeout(item))` when the timer completes and
/// the registered waker could still be abandoned.
impl<F, FR, R> Future for SendTimeoutFuture<'_, F, FR, R>
where
F: Flavor,
FR: Future<Output = R>,
F::Item: Send + 'static + Unpin,
{
type Output = Result<(), SendTimeoutError<F::Item>>;
#[inline]
fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
// SAFETY: `FR` is not required to be `Unpin`, so the pin is unwrapped manually; the code
// below never moves `sleep` out of the struct and only re-pins it in place.
let mut _self = unsafe { self.get_unchecked_mut() };
match _self.tx.poll_send::<false>(ctx, &_self.item, &mut _self.waker) {
Poll::Ready(Ok(())) => {
debug_assert!(_self.waker.is_none());
Poll::Ready(Ok(()))
}
Poll::Ready(Err(())) => {
// Clear the waker slot so `Drop` does not touch the item that is handed back below.
let _ = _self.waker.take();
Poll::Ready(Err(SendTimeoutError::Disconnected(unsafe {
_self.item.assume_init_read()
})))
}
Poll::Pending => {
// SAFETY: `_self` came from a pinned reference and `sleep` is never moved, so
// re-pinning the field is sound.
let sleep = unsafe { Pin::new_unchecked(&mut _self.sleep) };
if sleep.poll(ctx).is_ready() {
// Timer fired. The waker is taken out of the slot and abandoned here, so `Drop`
// will find `None`. The `unwrap` assumes `poll_send` leaves a waker in the slot
// whenever it returns `Pending` — TODO confirm against `reg_waker_async`.
if _self.tx.shared.abandon_send_waker(&_self.waker.take().unwrap()) {
return Poll::Ready(Err(SendTimeoutError::Timeout(unsafe {
_self.item.assume_init_read()
})));
} else {
// `abandon_send_waker` returned `false`: presumably the send completed
// concurrently, so success is reported — TODO confirm.
return Poll::Ready(Ok(()));
}
}
Poll::Pending
}
}
}
}
/// Common interface of the async sender types in this file (`AsyncTx` and `MAsyncTx`).
///
/// The query methods are implemented in this file by forwarding to the channel state
/// obtained through `AsRef<ChannelShared<F>>`.
pub trait AsyncTxTrait<T>: Send + 'static + fmt::Debug + fmt::Display {
/// Attempts to send `item` without waiting; the item is returned inside the error.
fn try_send(&self, item: T) -> Result<(), TrySendError<T>>;
/// Length reported by the channel state.
fn len(&self) -> usize;
/// Capacity reported by the channel state; `Option` presumably distinguishes unbounded
/// channels — TODO confirm.
fn capacity(&self) -> Option<usize>;
/// Whether the channel state reports itself empty.
fn is_empty(&self) -> bool;
/// Whether the channel state reports itself full.
fn is_full(&self) -> bool;
/// Whether no receiver is left (the implementations here test `get_rx_count() == 0`).
fn is_disconnected(&self) -> bool;
/// Sender count reported by the channel state.
fn get_tx_count(&self) -> usize;
/// Receiver count reported by the channel state.
fn get_rx_count(&self) -> usize;
/// Consumes the sender and returns `count` senders in a `Vec`.
///
/// The `AsyncTx` implementation asserts `count == 1`; the `MAsyncTx` implementation
/// clones itself to reach `count` entries.
fn clone_to_vec(self, count: usize) -> Vec<Self>
where
Self: Sized;
/// Pair of waker counts reported by the channel state; which side each element refers to
/// is not visible here — TODO confirm.
fn get_wakers_count(&self) -> (usize, usize);
/// Returns a future that sends `item`, yielding the item back inside `SendError` on failure.
fn send(&self, item: T) -> impl Future<Output = Result<(), SendError<T>>> + Send
where
T: Send + 'static + Unpin;
/// Like `send`, but gives up after `duration` using the runtime's sleep timer.
#[cfg(any(feature = "tokio", feature = "async_std"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "tokio", feature = "async_std"))))]
fn send_timeout<'a>(
&'a self, item: T, duration: std::time::Duration,
) -> impl Future<Output = Result<(), SendTimeoutError<T>>> + Send
where
T: Send + 'static + Unpin;
/// Like `send`, but gives up once the caller-supplied future `fut` completes.
fn send_with_timer<FR, R>(
&self, item: T, fut: FR,
) -> impl Future<Output = Result<(), SendTimeoutError<T>>> + Send
where
FR: Future<Output = R>,
T: Send + 'static + Unpin;
}
impl<F: Flavor> AsyncTxTrait<F::Item> for AsyncTx<F> {
#[inline(always)]
fn clone_to_vec(self, count: usize) -> Vec<Self> {
assert_eq!(count, 1);
vec![self]
}
#[inline(always)]
fn try_send(&self, item: F::Item) -> Result<(), TrySendError<F::Item>> {
AsyncTx::try_send(self, item)
}
#[inline(always)]
fn send(&self, item: F::Item) -> impl Future<Output = Result<(), SendError<F::Item>>> + Send
where
F::Item: Send + 'static + Unpin,
{
AsyncTx::send(self, item)
}
#[cfg(any(feature = "tokio", feature = "async_std"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "tokio", feature = "async_std"))))]
#[inline(always)]
fn send_timeout<'a>(
&'a self, item: F::Item, duration: std::time::Duration,
) -> impl Future<Output = Result<(), SendTimeoutError<F::Item>>> + Send
where
F::Item: Send + 'static + Unpin,
{
AsyncTx::send_timeout(self, item, duration)
}
#[inline(always)]
fn send_with_timer<FR, R>(
&self, item: F::Item, fut: FR,
) -> impl Future<Output = Result<(), SendTimeoutError<F::Item>>> + Send
where
FR: Future<Output = R>,
F::Item: Send + 'static + Unpin,
{
AsyncTx::send_with_timer(self, item, fut)
}
#[inline(always)]
fn len(&self) -> usize {
self.as_ref().len()
}
#[inline(always)]
fn capacity(&self) -> Option<usize> {
self.as_ref().capacity()
}
#[inline(always)]
fn is_empty(&self) -> bool {
self.as_ref().is_empty()
}
#[inline(always)]
fn is_full(&self) -> bool {
self.as_ref().is_full()
}
#[inline(always)]
fn is_disconnected(&self) -> bool {
self.as_ref().get_rx_count() == 0
}
#[inline(always)]
fn get_tx_count(&self) -> usize {
self.as_ref().get_tx_count()
}
#[inline(always)]
fn get_rx_count(&self) -> usize {
self.as_ref().get_rx_count()
}
fn get_wakers_count(&self) -> (usize, usize) {
self.as_ref().get_wakers_count()
}
}
/// Cloneable async sender: a newtype around `AsyncTx` that adds `Clone` and `Sync`.
pub struct MAsyncTx<F: Flavor>(pub(crate) AsyncTx<F>);
/// Debug output is the literal `MAsyncTx` followed by the address of this handle.
impl<F: Flavor> fmt::Debug for MAsyncTx<F> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("MAsyncTx{:p}", self))
    }
}
/// Display output is the literal `MAsyncTx` followed by the address of this handle.
impl<F: Flavor> fmt::Display for MAsyncTx<F> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("MAsyncTx{:p}", self))
    }
}
// SAFETY: NOTE(review) — the wrapped `AsyncTx` is `!Sync` because of its `Cell` marker; this
// impl re-enables sharing references across threads for the cloneable wrapper, with no
// bound on `F`. Soundness relies on the channel state tolerating concurrent senders for the
// flavors this type is constructed with — TODO confirm.
unsafe impl<F: Flavor> Sync for MAsyncTx<F> {}
impl<F: Flavor> Clone for MAsyncTx<F> {
#[inline]
fn clone(&self) -> Self {
let inner = &self.0;
inner.shared.add_tx();
Self(AsyncTx::new(inner.shared.clone()))
}
}
impl<F: Flavor> From<MAsyncTx<F>> for AsyncTx<F> {
fn from(tx: MAsyncTx<F>) -> Self {
tx.0
}
}
impl<F: Flavor> MAsyncTx<F> {
#[inline]
pub(crate) fn new(shared: Arc<ChannelShared<F>>) -> Self {
Self(AsyncTx::new(shared))
}
#[inline]
pub fn into_sink(self) -> AsyncSink<F> {
AsyncSink::new(self.0)
}
#[inline]
pub fn into_blocking(self) -> MTx<F> {
self.into()
}
#[inline]
pub fn downgrade(&self) -> WeakTx<F>
where
F: FlavorMP,
{
WeakTx(self.shared.clone())
}
}
impl<F: Flavor> Deref for MAsyncTx<F> {
type Target = AsyncTx<F>;
#[inline(always)]
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<F: Flavor> From<MTx<F>> for MAsyncTx<F> {
fn from(value: MTx<F>) -> Self {
value.add_tx();
Self(AsyncTx::new(value.shared.clone()))
}
}
/// `AsyncTxTrait` for the cloneable sender: sending forwards to the wrapped `AsyncTx`,
/// queries forward to the shared channel state.
impl<F: Flavor + FlavorMP> AsyncTxTrait<F::Item> for MAsyncTx<F> {
    /// Consumes the sender and returns exactly `count` senders: `count - 1` clones followed
    /// by `self`.
    ///
    /// # Panics
    ///
    /// Panics when `count` is zero. The previous `0..count - 1` range underflowed in that
    /// case: an overflow panic in debug builds, an effectively unbounded clone loop in
    /// release builds.
    #[inline(always)]
    fn clone_to_vec(self, count: usize) -> Vec<Self> {
        assert!(count >= 1, "clone_to_vec requires count >= 1");
        let mut v = Vec::with_capacity(count);
        // `1..count` yields `count - 1` iterations without subtracting on a `usize`.
        for _ in 1..count {
            v.push(self.clone());
        }
        v.push(self);
        v
    }

    /// Forwards to `AsyncTx::try_send` on the wrapped sender.
    #[inline(always)]
    fn try_send(&self, item: F::Item) -> Result<(), TrySendError<F::Item>> {
        self.0.try_send(item)
    }

    /// Forwards to `AsyncTx::send` on the wrapped sender.
    #[inline(always)]
    fn send(&self, item: F::Item) -> impl Future<Output = Result<(), SendError<F::Item>>> + Send
    where
        F::Item: Send + 'static + Unpin,
    {
        self.0.send(item)
    }

    /// Forwards to `AsyncTx::send_timeout` on the wrapped sender.
    #[cfg(any(feature = "tokio", feature = "async_std"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "tokio", feature = "async_std"))))]
    #[inline(always)]
    fn send_timeout<'a>(
        &'a self, item: F::Item, duration: std::time::Duration,
    ) -> impl Future<Output = Result<(), SendTimeoutError<F::Item>>> + Send
    where
        F::Item: Send + 'static + Unpin,
    {
        self.0.send_timeout(item, duration)
    }

    /// Forwards to `AsyncTx::send_with_timer` on the wrapped sender.
    #[inline(always)]
    fn send_with_timer<FR, R>(
        &self, item: F::Item, fut: FR,
    ) -> impl Future<Output = Result<(), SendTimeoutError<F::Item>>> + Send
    where
        FR: Future<Output = R>,
        F::Item: Send + 'static + Unpin,
    {
        self.0.send_with_timer::<FR, R>(item, fut)
    }

    #[inline(always)]
    fn len(&self) -> usize {
        self.as_ref().len()
    }

    #[inline(always)]
    fn capacity(&self) -> Option<usize> {
        self.as_ref().capacity()
    }

    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.as_ref().is_empty()
    }

    #[inline(always)]
    fn is_full(&self) -> bool {
        self.as_ref().is_full()
    }

    /// True when the channel state counts zero receivers.
    #[inline(always)]
    fn is_disconnected(&self) -> bool {
        self.as_ref().get_rx_count() == 0
    }

    #[inline(always)]
    fn get_tx_count(&self) -> usize {
        self.as_ref().get_tx_count()
    }

    #[inline(always)]
    fn get_rx_count(&self) -> usize {
        self.as_ref().get_rx_count()
    }

    fn get_wakers_count(&self) -> (usize, usize) {
        self.as_ref().get_wakers_count()
    }
}
/// Dereferences to the channel state, so its methods can be called on the sender directly.
impl<F: Flavor> Deref for AsyncTx<F> {
    type Target = ChannelShared<F>;

    #[inline(always)]
    fn deref(&self) -> &ChannelShared<F> {
        &*self.shared
    }
}
/// Borrows the channel state behind this sender.
impl<F: Flavor> AsRef<ChannelShared<F>> for AsyncTx<F> {
    #[inline(always)]
    fn as_ref(&self) -> &ChannelShared<F> {
        &*self.shared
    }
}
impl<F: Flavor> AsRef<ChannelShared<F>> for MAsyncTx<F> {
#[inline(always)]
fn as_ref(&self) -> &ChannelShared<F> {
&self.0.shared
}
}
/// Lets generic code build an `AsyncTx` from channel state through `SenderType`.
impl<T, F: Flavor<Item = T>> SenderType for AsyncTx<F> {
    type Flavor = F;

    #[inline(always)]
    fn new(shared: Arc<ChannelShared<F>>) -> Self {
        // Resolves to the inherent constructor, which takes priority over this trait method.
        let tx: AsyncTx<F> = AsyncTx::new(shared);
        tx
    }
}
// Tags `AsyncTx` with the crate's `NotCloneable` marker; presumably used to tell
// single-handle senders apart from cloneable ones in generic code — TODO confirm.
impl<F: Flavor> NotCloneable for AsyncTx<F> {}
/// Lets generic code build an `MAsyncTx` from channel state through `SenderType`; only
/// available for flavors that implement `FlavorMP`.
impl<T, F: Flavor<Item = T> + FlavorMP> SenderType for MAsyncTx<F> {
    type Flavor = F;

    #[inline(always)]
    fn new(shared: Arc<ChannelShared<F>>) -> Self {
        // Resolves to the inherent constructor, which takes priority over this trait method.
        let tx: MAsyncTx<F> = MAsyncTx::new(shared);
        tx
    }
}