use std::collections::VecDeque;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll, Waker};
use futures::sink::Sink;
use futures::stream::{FusedStream, Stream};
use thiserror::Error;
use crate::cell::UnsafeCell;
/// Error returned by `UnboundedReceiver::try_next` when the queue is currently empty but the
/// channel is still open, i.e. more items may still arrive.
#[derive(Error, Debug)]
#[error("queue is empty")]
pub struct TryRecvError {
// Private field: the error can only be constructed from within this module (and its children).
_marker: PhantomData<()>,
}
/// Error returned by `UnboundedSender::send_now` when the channel has already been closed.
#[derive(Error, Debug)]
#[error("failed to send")]
pub struct SendError<T> {
/// The item that could not be sent, handed back to the caller.
pub inner: T,
}
/// Error produced by the `Sink` implementation for `&UnboundedSender<T>` when the channel is
/// closed. Unlike `SendError`, it does not carry the rejected item back.
#[derive(Error, Debug)]
#[error("failed to send")]
pub struct TrySendError {
// Private field: the error can only be constructed from within this module (and its children).
_marker: PhantomData<()>,
}
/// Shared channel state, owned jointly (through an `Rc`) by the receiver and every sender.
#[derive(Debug)]
struct Inner<T> {
// Waker of the receiver task, stored by `poll_next_impl` when it finds the queue empty and the
// channel open; woken when an item is sent or the channel is closed.
rx_waker: Option<Waker>,
// Set by `close_impl`; afterwards sends fail, but already-buffered items remain readable.
closed: bool,
// Number of live `UnboundedSender` handles; the channel is closed when it drops to zero.
sender_ctr: usize,
// FIFO of items that have been sent but not yet received.
items: VecDeque<T>,
// `Rc<()>` marker makes `Inner` `!Send` and `!Sync`, confining the channel to one thread.
_marker: PhantomData<Rc<()>>,
}
impl<T> Inner<T> {
fn close_impl(&mut self) {
self.closed = true;
if let Some(ref m) = self.rx_waker {
m.wake_by_ref();
}
}
#[inline]
fn try_next_impl(&mut self) -> Result<Option<T>, TryRecvError> {
match (self.items.pop_front(), self.closed) {
(Some(m), _) => Ok(Some(m)),
(None, true) => Ok(None),
(None, false) => Err(TryRecvError {
_marker: PhantomData,
}),
}
}
#[inline]
fn poll_next_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
match (self.items.pop_front(), self.closed) {
(Some(m), _) => Poll::Ready(Some(m)),
(None, false) => {
self.rx_waker = Some(cx.waker().clone());
Poll::Pending
}
(None, true) => Poll::Ready(None),
}
}
#[inline]
fn is_terminated_impl(&self) -> bool {
self.items.is_empty() && self.closed
}
#[inline]
fn send_impl(&mut self, item: T) -> Result<(), SendError<T>> {
if self.closed {
return Err(SendError { inner: item });
}
self.items.push_back(item);
if let Some(ref m) = self.rx_waker {
m.wake_by_ref();
}
Ok(())
}
#[inline]
fn pre_clone_sender_impl(&mut self) {
self.sender_ctr += 1;
}
#[inline]
fn drop_sender_impl(&mut self) {
let sender_ctr = {
self.sender_ctr -= 1;
self.sender_ctr
};
if sender_ctr == 0 {
self.close_impl();
}
}
}
/// Receiving half of an `unbounded` channel.
///
/// Implements `Stream` (and `FusedStream`); dropping it closes the channel, so further sends fail.
#[derive(Debug)]
pub struct UnboundedReceiver<T> {
// State shared with every `UnboundedSender` of the same channel.
inner: Rc<UnsafeCell<Inner<T>>>,
}
impl<T> UnboundedReceiver<T> {
    /// Dequeues the next item without waiting.
    ///
    /// Returns `Ok(Some(item))` while items are buffered and `Ok(None)` once the channel is
    /// closed and fully drained.
    ///
    /// # Errors
    /// Returns [`TryRecvError`] when the queue is empty but the channel is still open.
    pub fn try_next(&self) -> Result<Option<T>, TryRecvError> {
        let cell = &*self.inner;
        // SAFETY: `Rc` confines the channel to one thread and no other access to the cell is
        // live during this single `Inner` call.
        unsafe { cell.with_mut(|state| state.try_next_impl()) }
    }

    /// Closes the channel: further sends fail, while already-buffered items stay receivable.
    pub fn close(&self) {
        let cell = &*self.inner;
        // SAFETY: as in `try_next` — single-threaded, non-overlapping access.
        unsafe { cell.with_mut(|state| state.close_impl()) }
    }
}
impl<T> Stream for UnboundedReceiver<T> {
    type Item = T;

    /// Yields the next buffered item, `None` once the channel is closed and drained, or
    /// registers the task's waker and returns `Pending` while the queue is empty.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The receiver only holds an `Rc`, which is `Unpin`, so it can be unpinned freely.
        let this = self.get_mut();
        let cell = &*this.inner;
        // SAFETY: `Rc` confines the channel to one thread and no other access to the cell is
        // live during this single `Inner` call.
        unsafe { cell.with_mut(|state| state.poll_next_impl(cx)) }
    }
}
impl<T> FusedStream for UnboundedReceiver<T> {
    /// Reports `true` once the channel is closed and no buffered items remain.
    fn is_terminated(&self) -> bool {
        let cell = &*self.inner;
        // SAFETY: read-only access on a single-threaded channel; no mutable access is live.
        unsafe { cell.with(|state| state.is_terminated_impl()) }
    }
}
impl<T> Drop for UnboundedReceiver<T> {
    /// Closes the channel when the receiver goes away, so senders stop queueing items that
    /// nobody will read.
    fn drop(&mut self) {
        let cell = &*self.inner;
        // SAFETY: single-threaded channel; no other access to the cell is live during drop.
        unsafe {
            cell.with_mut(|state| state.close_impl());
        }
    }
}
/// Sending half of an `unbounded` channel.
///
/// Cloning produces another handle to the same channel; the channel is closed once the last
/// handle is dropped, or earlier via `close_now`.
#[derive(Debug)]
pub struct UnboundedSender<T> {
// State shared with the receiver and every other sender of the same channel.
inner: Rc<UnsafeCell<Inner<T>>>,
}
impl<T> UnboundedSender<T> {
    /// Enqueues `item` immediately and wakes the receiver; never waits.
    ///
    /// # Errors
    /// Returns [`SendError`] with `item` inside if the channel has been closed.
    pub fn send_now(&self, item: T) -> Result<(), SendError<T>> {
        let cell = &*self.inner;
        // SAFETY: `Rc` confines the channel to one thread and no other access to the cell is
        // live during this single `Inner` call.
        unsafe { cell.with_mut(move |state| state.send_impl(item)) }
    }

    /// Closes the whole channel (for every sender and the receiver), not just this handle.
    pub fn close_now(&self) {
        let cell = &*self.inner;
        // SAFETY: as in `send_now` — single-threaded, non-overlapping access.
        unsafe { cell.with_mut(|state| state.close_impl()) }
    }
}
impl<T> Clone for UnboundedSender<T> {
fn clone(&self) -> Self {
unsafe { self.inner.with_mut(|inner| inner.pre_clone_sender_impl()) }
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Drop for UnboundedSender<T> {
    /// Releases this handle's share of the sender count; dropping the last sender closes the
    /// channel.
    fn drop(&mut self) {
        let cell = &*self.inner;
        // SAFETY: single-threaded channel; no other access to the cell is live during drop.
        unsafe {
            cell.with_mut(|state| state.drop_sender_impl());
        }
    }
}
impl<T> Sink<T> for &'_ UnboundedSender<T> {
type Error = TrySendError;
fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
self.send_now(item).map_err(|_| TrySendError {
_marker: PhantomData,
})
}
fn poll_ready(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let closed = unsafe { self.inner.with(|inner| inner.closed) };
match closed {
false => Poll::Ready(Ok(())),
true => Poll::Ready(Err(TrySendError {
_marker: PhantomData,
})),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.close_now();
Poll::Ready(Ok(()))
}
}
/// Creates a single-threaded, unbounded channel and returns its `(sender, receiver)` pair.
///
/// The sender may be cloned. The channel stays open until every sender is dropped, the
/// receiver is dropped, or either side closes it explicitly.
pub fn unbounded<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
    let state = Inner {
        items: VecDeque::new(),
        rx_waker: None,
        // Counts the one sender handle returned below.
        sender_ctr: 1,
        closed: false,
        _marker: PhantomData,
    };
    let shared = Rc::new(UnsafeCell::new(state));
    let sender = UnboundedSender {
        inner: Rc::clone(&shared),
    };
    let receiver = UnboundedReceiver { inner: shared };
    (sender, receiver)
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::sink::SinkExt;
    use futures::stream::{FusedStream, StreamExt};
    use tokio::task::{spawn_local, LocalSet};
    use tokio::test;
    use tokio::time::sleep;

    use super::*;

    #[test]
    async fn mpsc_works() {
        let local_set = LocalSet::new();
        local_set
            .run_until(async {
                let (tx, mut rx) = unbounded::<usize>();
                spawn_local(async move {
                    for i in 0..10 {
                        (&tx).send(i).await.expect("failed to send.");
                        sleep(Duration::from_millis(1)).await;
                    }
                });
                for i in 0..10 {
                    let received = rx.next().await.expect("failed to receive");
                    assert_eq!(i, received);
                }
                assert_eq!(rx.next().await, None);
            })
            .await;
    }

    #[test]
    async fn mpsc_drops_receiver() {
        let (tx, rx) = unbounded::<usize>();
        drop(rx);
        (&tx).send(0).await.expect_err("should fail to send.");
    }

    #[test]
    async fn mpsc_multi_sender() {
        let local_set = LocalSet::new();
        local_set
            .run_until(async {
                let (tx, mut rx) = unbounded::<usize>();
                spawn_local(async move {
                    let tx2 = tx.clone();
                    for i in 0..10 {
                        if i % 2 == 0 {
                            (&tx).send(i).await.expect("failed to send.");
                        } else {
                            (&tx2).send(i).await.expect("failed to send.");
                        }
                        sleep(Duration::from_millis(1)).await;
                    }
                    drop(tx2);
                    for i in 10..20 {
                        (&tx).send(i).await.expect("failed to send.");
                        sleep(Duration::from_millis(1)).await;
                    }
                });
                for i in 0..20 {
                    let received = rx.next().await.expect("failed to receive");
                    assert_eq!(i, received);
                }
                assert_eq!(rx.next().await, None);
            })
            .await;
    }

    #[test]
    async fn mpsc_drops_sender() {
        let (tx, mut rx) = unbounded::<usize>();
        drop(tx);
        assert_eq!(rx.next().await, None);
    }

    /// `try_next` distinguishes "empty but open" (error) from "closed and drained" (`Ok(None)`).
    #[test]
    async fn mpsc_try_next_reports_empty_then_closed() {
        let (tx, rx) = unbounded::<usize>();
        assert!(rx.try_next().is_err(), "open, empty channel should report TryRecvError");
        tx.send_now(1).expect("failed to send.");
        assert_eq!(rx.try_next().expect("item should be buffered"), Some(1));
        drop(tx);
        assert_eq!(rx.try_next().expect("channel should be closed"), None);
    }

    /// The channel only closes when the *last* sender handle is dropped.
    #[test]
    async fn mpsc_last_sender_drop_closes() {
        let (tx, rx) = unbounded::<usize>();
        let tx2 = tx.clone();
        drop(tx);
        assert!(rx.try_next().is_err(), "a remaining clone should keep the channel open");
        tx2.send_now(7).expect("failed to send.");
        drop(tx2);
        assert_eq!(rx.try_next().expect("item should be buffered"), Some(7));
        assert_eq!(rx.try_next().expect("channel should be closed"), None);
    }

    /// Closing from the receiver side rejects new sends but still yields buffered items.
    #[test]
    async fn mpsc_receiver_close_drains_buffer() {
        let (tx, mut rx) = unbounded::<usize>();
        tx.send_now(1).expect("failed to send.");
        rx.close();
        assert!(tx.send_now(2).is_err(), "sending after close should fail");
        assert!(!rx.is_terminated());
        assert_eq!(rx.next().await, Some(1));
        assert_eq!(rx.next().await, None);
        assert!(rx.is_terminated());
    }

    /// `Sink::close` on a sender closes the channel for everyone.
    #[test]
    async fn mpsc_sink_close_rejects_further_sends() {
        let (tx, mut rx) = unbounded::<usize>();
        (&tx).send(1).await.expect("failed to send.");
        (&tx).close().await.expect("failed to close.");
        (&tx).send(2).await.expect_err("should fail to send.");
        assert_eq!(rx.next().await, Some(1));
        assert_eq!(rx.next().await, None);
    }
}