use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use crate::error::{SendError, SendTimeoutError, TrySendError};
use futures::task::AtomicWaker;
use std::future::Future;
use crate::{inner::Inner, util::async_send};
/// Sending half of the channel.
///
/// Wraps a `crossbeam_channel::Sender` that carries the values, plus the shared
/// `Inner` state used by `Clone`/`Drop` (`inc_tx` / `dec_tx`) and by the send
/// paths to wake parked futures (`signal_queues`).
pub struct Sender<T> {
    /// Underlying crossbeam sender; visible to the rest of the crate.
    pub(crate) tx: crossbeam_channel::Sender<T>,
    /// Shared channel state (sender bookkeeping and waker queues).
    inner: Arc<Inner>,
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.inner.inc_tx();
Self {
tx: self.tx.clone(),
inner: self.inner.clone(),
}
}
}
impl<T> Drop for Sender<T> {
    /// When the last sender handle is released, wakes every queued receiver
    /// waker and then every queued sender waker so those futures get re-polled.
    fn drop(&mut self) {
        // NOTE(review): assumes `dec_tx` returns the count *before* decrementing,
        // so `1` means this handle was the last sender — confirm in `Inner`.
        if self.inner.dec_tx() != 1 {
            return;
        }

        let mut queues = self.inner.signal_queues();
        while let Some(parked) = queues.pop_recv() {
            parked.as_ref().wake();
        }
        while let Some(parked) = queues.pop_send() {
            parked.as_ref().wake();
        }
        // NOTE(review): `self.tx` itself is only dropped after this body returns,
        // i.e. after the wakes above. If receivers detect closure through the
        // crossbeam channel rather than through `Inner`, a receiver woken here
        // could re-park before the channel actually disconnects — confirm.
    }
}
impl<T> Sender<T> {
    /// Wraps a crossbeam sender together with the shared channel state.
    pub(crate) fn new(tx: crossbeam_channel::Sender<T>, inner: Arc<Inner>) -> Self {
        Self { tx, inner }
    }

    /// Sends `value` using crossbeam's blocking `send`, then wakes one parked
    /// receiver (if any) on success.
    ///
    /// # Errors
    /// The crossbeam error is converted into [`SendError`] via `From`.
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        self.tx.send(value)?;
        self.signal_recv();
        Ok(())
    }

    /// Like [`Sender::send`], but uses crossbeam's `send_timeout` with `timeout`.
    ///
    /// # Errors
    /// The crossbeam error is converted into [`SendTimeoutError`] via `From`.
    pub fn send_timeout(&self, value: T, timeout: Duration) -> Result<(), SendTimeoutError<T>> {
        self.tx.send_timeout(value, timeout)?;
        self.signal_recv();
        Ok(())
    }

    /// Non-blocking send via crossbeam's `try_send`; wakes one parked receiver
    /// on success.
    ///
    /// # Errors
    /// The crossbeam error is converted into [`TrySendError`] via `From`.
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        self.tx.try_send(value)?;
        self.signal_recv();
        Ok(())
    }

    /// Pops one queued receiver waker and wakes it.
    ///
    /// The queue lock is deliberately held across the wake. The queue stores
    /// raw waker addresses, and a future removes its own entry under this same
    /// lock before freeing its waker (as `SendFut::drop` does). Holding the lock
    /// therefore keeps the popped waker alive while it is woken. The receive
    /// side is presumably the same — confirm.
    #[inline(always)]
    pub(crate) fn signal_recv(&self) {
        let mut queues = self.inner.signal_queues();
        if let Some(waker) = queues.pop_recv() {
            waker.as_ref().wake();
        }
    }

    /// Returns a future that sends `value` without blocking the executor.
    pub fn send_async(&self, value: T) -> SendFut<'_, T> {
        SendFut {
            tx: &self.tx,
            inner: &self.inner,
            value: Some(value),
            registered: false,
            waker: None,
        }
    }
}

/// Future returned by [`Sender::send_async`].
///
/// While the channel is full, the future parks itself by pushing the address
/// of its `AtomicWaker` into the shared send queue. That waker lives in its own
/// heap allocation, so the published address stays valid even if the future
/// (which is `Unpin`) is moved between polls.
pub struct SendFut<'a, T> {
    tx: &'a crossbeam_channel::Sender<T>,
    inner: &'a Arc<Inner>,
    /// Value still to be sent; `None` once the future has resolved.
    value: Option<T>,
    /// Whether this future's waker address is currently queued via `add_send`.
    registered: bool,
    /// Address-stable waker slot, allocated lazily on the first slow-path poll.
    waker: Option<Box<AtomicWaker>>,
}

// Sound: the only address published to the signal queue is that of the boxed
// `AtomicWaker`, which does not move when the `SendFut` itself is moved.
impl<'a, T> Unpin for SendFut<'a, T> {}

impl<'a, T> SendFut<'a, T> {
    /// Registers the task's waker and returns its queue key (the boxed
    /// waker's address), allocating the box on first use.
    fn register_waker(&mut self, cx: &Context<'_>) -> usize {
        let waker = self.waker.get_or_insert_with(|| Box::new(AtomicWaker::new()));
        waker.register(cx.waker());
        &**waker as *const AtomicWaker as usize
    }

    /// Queue key of this future's waker if it is currently registered.
    fn registered_key(&self) -> Option<usize> {
        if !self.registered {
            return None;
        }
        self.waker.as_deref().map(|w| w as *const AtomicWaker as usize)
    }

    /// Bookkeeping after a successful send: drop our queue entry (if any) and
    /// wake one parked receiver, all under the lock.
    fn settle_sent(&mut self) {
        let inner = self.inner;
        let mut queues = inner.signal_queues();
        if let Some(key) = self.registered_key() {
            queues.remove_send(key);
        }
        self.registered = false;
        if let Some(waker) = queues.pop_recv() {
            waker.as_ref().wake();
        }
    }

    /// Bookkeeping after finding the channel disconnected: wake every parked
    /// sender (our own entry included) and every parked receiver.
    fn settle_disconnected(&mut self) {
        let inner = self.inner;
        let mut queues = inner.signal_queues();
        while let Some(waker) = queues.pop_send() {
            waker.as_ref().wake();
        }
        while let Some(waker) = queues.pop_recv() {
            waker.as_ref().wake();
        }
        // `pop_send` drained the whole send queue, so our entry is gone too.
        self.registered = false;
    }
}

impl<'a, T> Future for SendFut<'a, T> {
    type Output = Result<(), SendError<T>>;

    /// Tries `async_send` without taking the lock first.
    ///
    /// If the channel is full, registers the task's waker and retries under the
    /// signal-queue lock. A receiver that frees a slot then either lets the
    /// retry succeed or finds this future queued and wakes it.
    ///
    /// # Panics
    /// Panics if polled again after it returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let value = this.value.take().expect("SendFut polled after completion");

        // Fast path: no lock.
        let value = match async_send(this.tx, value) {
            Ok(()) => {
                this.settle_sent();
                return Poll::Ready(Ok(()));
            }
            Err(TrySendError::Full(value)) => value,
            Err(TrySendError::Disconnected(value)) => {
                this.settle_disconnected();
                return Poll::Ready(Err(SendError(value)));
            }
        };

        // Slow path: register first, then retry while holding the lock so a
        // receiver cannot free a slot unnoticed between the retry and queueing.
        let key = this.register_waker(cx);
        let inner = this.inner;
        let mut queues = inner.signal_queues();
        match this.tx.try_send(value) {
            Ok(()) => {
                // Release before `settle_sent`, which takes the lock itself.
                drop(queues);
                this.settle_sent();
                Poll::Ready(Ok(()))
            }
            Err(crossbeam_channel::TrySendError::Disconnected(value)) => {
                drop(queues);
                this.settle_disconnected();
                Poll::Ready(Err(SendError(value)))
            }
            Err(crossbeam_channel::TrySendError::Full(value)) => {
                this.value = Some(value);
                // Refresh our queue entry: remove any previous one, then add.
                if this.registered {
                    queues.remove_send(key);
                }
                queues.add_send(key);
                this.registered = true;
                // Kick one receiver so the full channel gets drained; woken
                // under the lock so its waker cannot be freed mid-wake.
                if let Some(waker) = queues.pop_recv() {
                    waker.as_ref().wake();
                }
                Poll::Pending
            }
        }
    }
}

impl<'a, T> Drop for SendFut<'a, T> {
    /// Removes this future's waker address from the send queue (under the
    /// lock) before the boxed waker is freed, so the queue never holds a
    /// dangling address.
    fn drop(&mut self) {
        if let Some(key) = self.registered_key() {
            let mut queues = self.inner.signal_queues();
            queues.remove_send(key);
        }
    }
}