use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use crate::error::{RecvError, RecvTimeoutError, TryRecvError};
use futures::{task::AtomicWaker, Stream};
use std::future::Future;
use crate::{inner::Inner, util::async_recv};
/// Receiving half of a channel.
///
/// Wraps a `crossbeam_channel::Receiver` together with the shared [`Inner`]
/// state that this file uses for receiver counting (`inc_rx` / `dec_rx`) and
/// for the queues of parked wakers (`signal_queues`).
pub struct Receiver<T> {
/// Underlying crossbeam receiver that actually holds the messages.
pub(crate) rx: crossbeam_channel::Receiver<T>,
/// Shared bookkeeping state; every clone of this receiver holds the same `Arc`.
inner: Arc<Inner>,
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.inner.inc_rx();
Self {
rx: self.rx.clone(),
inner: self.inner.clone(),
}
}
}
impl<T> Drop for Receiver<T> {
    /// Decrements the receiver count and, when `dec_rx()` reports `1`
    /// (presumably: this was the last live receiver — confirm against `Inner`),
    /// wakes every waker still parked in the receive and send queues.
    fn drop(&mut self) {
        if self.inner.dec_rx() != 1 {
            return;
        }

        // Drain receivers first, then senders, all under the same queues handle.
        let mut queues = self.inner.signal_queues();
        while let Some(parked) = queues.pop_recv() {
            parked.as_ref().wake();
        }
        while let Some(parked) = queues.pop_send() {
            parked.as_ref().wake();
        }
    }
}
impl<T> Receiver<T> {
pub(crate) fn new(rx: crossbeam_channel::Receiver<T>, inner: Arc<Inner>) -> Self {
Self { rx, inner }
}
pub fn recv(&self) -> Result<T, RecvError> {
let res = self.rx.recv();
if res.is_ok() {
self.signal_send();
}
Ok(res?)
}
pub fn recv_timeout(&self, timeout: Duration) -> Result<T, RecvTimeoutError> {
let res = self.rx.recv_timeout(timeout);
if res.is_ok() {
self.signal_send();
}
Ok(res?)
}
pub fn try_recv(&self) -> Result<T, TryRecvError> {
let res = self.rx.try_recv();
if res.is_ok() {
self.signal_send();
}
Ok(res?)
}
#[inline(always)]
pub(crate) fn signal_send(&self) {
if let Some(waker) = { self.inner.signal_queues().pop_send() } {
waker.as_ref().wake();
}
}
pub fn recv_async(&self) -> RecvFut<'_, T> {
RecvFut {
rx: &self.rx,
inner: &self.inner,
poll_cnt: 0,
waker: AtomicWaker::new(),
}
}
pub fn into_stream(self) -> RecvStream<T> {
RecvStream {
inner: Box::new(RecvStreamInner {
rx: self,
poll_cnt: 0,
waker: AtomicWaker::new(),
}),
}
}
}
/// Future returned by `Receiver::recv_async`; resolves to `Result<T, RecvError>`.
pub struct RecvFut<'a, T> {
/// Borrowed crossbeam receiver the value is taken from.
rx: &'a crossbeam_channel::Receiver<T>,
/// Borrowed shared state holding the waker queues.
inner: &'a Arc<Inner>,
/// Incremented each time a poll returns `Pending`; a non-zero value means the
/// address of `waker` has been added to the receive queue at least once.
poll_cnt: u32,
/// Holds the task waker; its address is used as this future's key in the receive queue.
waker: AtomicWaker,
}
// Marks the future as freely movable so `Future::poll` can use `get_mut()`.
// NOTE(review): `RecvFut::poll` stores the address of the `waker` field in the
// receive queue, and an `Unpin` future may be moved between polls — confirm that
// a stale address can never be popped and dereferenced.
impl<'a, T> Unpin for RecvFut<'a, T> {}
impl<'a, T> RecvFut<'a, T> {
/// Polling routine shared by `RecvFut` (via `Future::poll`) and `RecvStream`
/// (via `Stream::poll_next`).
///
/// * `rx` – crossbeam receiver to take the value from.
/// * `inner` – shared state providing `signal_queues()`.
/// * `poll_cnt` – non-zero once `waker`'s address has been added to the receive
///   queue by an earlier call; incremented whenever this call returns `Pending`.
/// * `waker` – the caller's `AtomicWaker`; its address is the key used in the
///   receive queue, and the task waker from `cx` is registered into it.
/// * `cx` – task context of the current poll.
///
/// Returns `Poll::Ready(Ok(value))` when a value was received,
/// `Poll::Ready(Err(RecvError))` when the channel reports disconnection, and
/// `Poll::Pending` after enqueueing `waker` otherwise.
fn poll(
rx: &crossbeam_channel::Receiver<T>,
inner: &Arc<Inner>,
poll_cnt: &mut u32,
waker: &AtomicWaker,
cx: &mut Context<'_>,
) -> Poll<Result<T, RecvError>> {
// First attempt, made before the task waker is registered. The range `0..1`
// yields exactly one iteration, so this body runs once.
for _ in 0..1 {
// `async_recv` is a crate helper not visible here; the arms below show it
// yields `Ok(value)`, `TryRecvError::Empty` or `TryRecvError::Disconnected`.
match async_recv(rx) {
Ok(value) => {
let mut signal_queues = inner.signal_queues();
// Remove our own entry if an earlier poll enqueued it.
if *poll_cnt > 0 {
signal_queues.remove_recv(waker as *const AtomicWaker as usize);
}
// A value was taken, so wake one parked sender; the queues handle is
// dropped explicitly before `wake()` is called.
if let Some(waker_ptr) = signal_queues.pop_send() {
drop(signal_queues);
waker_ptr.as_ref().wake();
} else {
drop(signal_queues);
}
return Poll::Ready(Ok(value));
}
// Nothing available yet: fall through to the registration path below.
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Disconnected) => {
// Disconnected: drain both queues and wake everyone. Unlike the success
// path, the wakes happen before the queues handle is dropped.
let mut signal_queues = inner.signal_queues();
while let Some(waker_ptr) = signal_queues.pop_send() {
waker_ptr.as_ref().wake();
}
while let Some(waker_ptr) = signal_queues.pop_recv() {
waker_ptr.as_ref().wake();
}
drop(signal_queues);
return Poll::Ready(Err(RecvError));
}
}
}
// Store the current task's waker so a later `wake()` on this `AtomicWaker`
// reschedules the task.
waker.register(cx.waker());
// Second attempt, now with the queues handle held and the waker registered —
// presumably to avoid missing a value that arrived between the first attempt
// and the registration (confirm against the sender side).
let mut signal_queues = inner.signal_queues();
match rx.try_recv() {
Ok(value) => {
if *poll_cnt > 0 {
signal_queues.remove_recv(waker as *const AtomicWaker as usize);
}
// No `else` here: when no sender is parked, `signal_queues` is dropped by
// the `return` below.
if let Some(waker_ptr) = signal_queues.pop_send() {
drop(signal_queues);
waker_ptr.as_ref().wake();
}
return Poll::Ready(Ok(value));
}
Err(crossbeam_channel::TryRecvError::Empty) => {}
Err(crossbeam_channel::TryRecvError::Disconnected) => {
while let Some(waker_ptr) = signal_queues.pop_send() {
waker_ptr.as_ref().wake();
}
while let Some(waker_ptr) = signal_queues.pop_recv() {
waker_ptr.as_ref().wake();
}
drop(signal_queues);
return Poll::Ready(Err(RecvError));
}
}
// Still empty: (re-)enqueue this waker, keyed by its address. Any entry from
// an earlier poll is removed first, then the address is added again.
let waker_ptr = waker as *const AtomicWaker as usize;
if *poll_cnt > 0 {
signal_queues.remove_recv(waker_ptr);
}
*poll_cnt += 1;
signal_queues.add_recv(waker_ptr);
// Before returning `Pending`, wake one parked sender if there is one; the
// queues handle is dropped before `wake()` is called.
if let Some(waker_ptr) = signal_queues.pop_send() {
drop(signal_queues);
waker_ptr.as_ref().wake();
} else {
drop(signal_queues);
}
Poll::Pending
}
}
impl<'a, T> Future for RecvFut<'a, T> {
    type Output = Result<T, RecvError>;

    /// Delegates to the inherent `RecvFut::poll` routine using this future's own
    /// receiver, shared state, poll counter and waker.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `RecvFut` is `Unpin`, so the pin can be unwrapped into a plain `&mut`.
        let fut = Pin::into_inner(self);
        RecvFut::poll(fut.rx, fut.inner, &mut fut.poll_cnt, &fut.waker, cx)
    }
}
impl<'a, T> Drop for RecvFut<'a, T> {
    /// Removes this future's waker address from the receive queue if any poll
    /// has ever returned `Pending` (i.e. `poll_cnt` is non-zero).
    fn drop(&mut self) {
        if self.poll_cnt == 0 {
            // Never enqueued, nothing to clean up.
            return;
        }

        let mut queues = self.inner.signal_queues();
        let key = &self.waker as *const AtomicWaker as usize;
        queues.remove_recv(key);
    }
}
/// State behind a [`RecvStream`]; `Receiver::into_stream` places it in a `Box`.
pub struct RecvStreamInner<T> {
/// The receiver consumed by `into_stream`.
rx: Receiver<T>,
/// Non-zero while the address of `waker` may be present in the receive queue;
/// reset to `0` by `poll_next` after each delivered item.
poll_cnt: u32,
/// Holds the task waker; its address is used as the stream's key in the receive queue.
waker: AtomicWaker,
}
/// Stream of messages created by `Receiver::into_stream`.
///
/// Yields each received value and ends (`None`) once receiving reports an error.
pub struct RecvStream<T> {
/// Heap-allocated state; presumably boxed so the address of the contained
/// waker stays the same when the `RecvStream` value itself is moved.
inner: Box<RecvStreamInner<T>>,
}
// Marks the stream as freely movable so `poll_next` can use `get_mut()`; the
// state whose address is registered in the receive queue lives inside the `Box`.
impl<T> Unpin for RecvStream<T> {}
impl<T> Stream for RecvStream<T> {
    type Item = T;

    /// Polls for the next message by delegating to `RecvFut::poll` with the
    /// boxed stream state.
    ///
    /// A received value is yielded as `Some(item)` and resets the poll counter;
    /// a receive error ends the stream with `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let state = &mut *this.inner;

        let polled = RecvFut::poll(
            &state.rx.rx,
            &state.rx.inner,
            &mut state.poll_cnt,
            &state.waker,
            cx,
        );

        match polled {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(_)) => Poll::Ready(None),
            Poll::Ready(Ok(item)) => {
                // Start counting afresh for the next item.
                state.poll_cnt = 0;
                Poll::Ready(Some(item))
            }
        }
    }
}
impl<T> Drop for RecvStream<T> {
    /// Removes the stream's waker address from the receive queue when the poll
    /// counter shows it may still be enqueued.
    fn drop(&mut self) {
        let state = &*self.inner;
        if state.poll_cnt == 0 {
            // Not enqueued since the last delivered item, nothing to clean up.
            return;
        }

        let mut queues = state.rx.inner.signal_queues();
        let key = &state.waker as *const AtomicWaker as usize;
        queues.remove_recv(key);
    }
}