use std::future::Future;
use std::marker::PhantomData;
use std::rc::Rc;
use std::task::{Poll, Waker};
use thiserror::Error;
use crate::cell::UnsafeCell;
#[derive(Debug, Error)]
#[error("channel has been closed.")]
pub struct RecvError {
_marker: PhantomData<()>,
}
/// State shared by a [`Sender`] and its [`Receiver`] through an `Rc`.
#[derive(Debug)]
struct Inner<T> {
    /// Waker of the task awaiting the receiver; `None` until the receiver returns
    /// `Pending`, and cleared again once it has been used or is no longer needed.
    rx_waker: Option<Waker>,
    /// Set when either handle is dropped (see `close_impl`); a closed channel rejects sends.
    closed: bool,
    /// Value sent but not yet taken by the receiver.
    item: Option<T>,
    /// Keeps `Inner` `!Send` and `!Sync`, like the `Rc`-based handles that own it.
    _marker: PhantomData<Rc<()>>,
}

impl<T> Inner<T> {
    /// Polls for the sent value.
    ///
    /// Returns `Ready(Ok(item))` if a value is stored. Returns `Ready(Err(RecvError))` if the
    /// channel is closed with no value. Otherwise it records the current task's waker and
    /// returns `Pending`.
    #[inline]
    fn poll_impl(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<T, RecvError>> {
        if let Some(m) = self.item.take() {
            return Poll::Ready(Ok(m));
        }
        if self.closed {
            return Poll::Ready(Err(RecvError {
                _marker: PhantomData,
            }));
        }
        // Re-polls from the same task are common; only clone (and replace) the stored
        // waker when it would wake a different task.
        let registered = matches!(&self.rx_waker, Some(w) if w.will_wake(cx.waker()));
        if !registered {
            self.rx_waker = Some(cx.waker().clone());
        }
        Poll::Pending
    }

    /// Stores `item` for the receiver and wakes it if it is waiting.
    ///
    /// Returns `Err(item)` without storing anything if the channel is already closed.
    #[inline]
    fn send_impl(&mut self, item: T) -> Result<(), T> {
        if self.closed {
            return Err(item);
        }
        self.item = Some(item);
        // At most one value is ever delivered, so the waker is consumed here instead of
        // being kept alive for as long as the shared state lives.
        if let Some(m) = self.rx_waker.take() {
            m.wake();
        }
        Ok(())
    }

    /// Marks the channel as closed.
    ///
    /// With `wake_recv == true` (the sender side in this file), a receiver still waiting
    /// for a value is woken so it can observe the closure. With `wake_recv == false` (the
    /// receiver side), the stored waker is simply released. Keeping it would let a later
    /// sender drop wake a task that no longer owns a receiver.
    fn close_impl(&mut self, wake_recv: bool) {
        self.closed = true;
        let waker = self.rx_waker.take();
        if wake_recv && self.item.is_none() {
            if let Some(m) = waker {
                m.wake();
            }
        }
    }
}
/// Receiving half of a [`channel`]: a [`Future`] that resolves to the sent value.
///
/// Resolves to `Err(RecvError)` when the channel is closed and no value was stored,
/// e.g. because the [`Sender`] was dropped without sending.
#[derive(Debug)]
pub struct Receiver<T> {
    inner: Rc<UnsafeCell<Inner<T>>>,
}

impl<T> Future for Receiver<T> {
    type Output = Result<T, RecvError>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let cell: &UnsafeCell<Inner<T>> = &self.inner;
        // SAFETY: the handles hold the state in an `Rc`, so all access happens on one
        // thread. The borrow handed to the closure is assumed to be the only live access
        // to the cell. NOTE(review): waker clone/drop runs inside this access; confirm no
        // executor re-enters the channel from there.
        unsafe { cell.with_mut(|shared| shared.poll_impl(cx)) }
    }
}

impl<T> Drop for Receiver<T> {
    /// Closes the channel so that a later send returns the value to the sender.
    fn drop(&mut self) {
        let cell: &UnsafeCell<Inner<T>> = &self.inner;
        // SAFETY: single-threaded `Rc` ownership; no other access to the cell is live
        // while the receiver is being dropped.
        unsafe { cell.with_mut(|shared| shared.close_impl(false)) }
    }
}
/// Sending half of a [`channel`]; it can deliver at most one value.
#[derive(Debug)]
pub struct Sender<T> {
    inner: Rc<UnsafeCell<Inner<T>>>,
}

impl<T> Sender<T> {
    /// Sends `item` to the paired [`Receiver`], consuming this sender.
    ///
    /// # Errors
    ///
    /// Returns `Err(item)`, handing the value back, if the channel is already closed
    /// (i.e. the receiver has been dropped).
    pub fn send(self, item: T) -> Result<(), T> {
        let cell: &UnsafeCell<Inner<T>> = &self.inner;
        // SAFETY: the state lives in an `Rc`, so access is confined to one thread. The
        // closure's borrow is assumed to be the only live access to the cell.
        // NOTE(review): the receiver's waker runs inside this access; confirm it never
        // re-enters the channel.
        unsafe { cell.with_mut(move |shared| shared.send_impl(item)) }
    }
}

impl<T> Drop for Sender<T> {
    /// Closes the channel and wakes the receiver if it is still waiting without a value.
    fn drop(&mut self) {
        let cell: &UnsafeCell<Inner<T>> = &self.inner;
        // SAFETY: single-threaded `Rc` ownership; no other access to the cell is live
        // while the sender is being dropped.
        unsafe { cell.with_mut(|shared| shared.close_impl(true)) }
    }
}
/// Creates a oneshot channel and returns its connected [`Sender`] and [`Receiver`].
///
/// Both halves share one `Rc`-held state, so neither is `Send` or `Sync`.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let state = Inner {
        item: None,
        rx_waker: None,
        closed: false,
        _marker: PhantomData,
    };
    let shared = Rc::new(UnsafeCell::new(state));
    let sender = Sender {
        inner: Rc::clone(&shared),
    };
    (sender, Receiver { inner: shared })
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Barrier;
    use tokio::task::{spawn_local, LocalSet};
    use tokio::test;
    use tokio::time::sleep;

    use super::*;

    /// A value sent before the receiver is awaited is returned immediately.
    #[test]
    async fn oneshot_works() {
        let (tx, rx) = channel();
        tx.send(0).expect("failed to send.");
        assert_eq!(rx.await.expect("failed to receive."), 0);
    }

    /// A receiver that is already pending is woken and gets the value once it is sent.
    /// This exercises the wake path in `send_impl`, which `oneshot_works` never reaches.
    #[test]
    async fn oneshot_wakes_pending_receiver() {
        let local_set = LocalSet::new();
        local_set
            .run_until(async {
                let (tx, rx) = channel::<usize>();
                spawn_local(async move {
                    sleep(Duration::from_millis(1)).await;
                    tx.send(42).expect("failed to send.");
                });
                assert_eq!(rx.await.expect("failed to receive."), 42);
            })
            .await;
    }

    /// Dropping the sender without sending makes the pending receiver fail.
    #[test]
    async fn oneshot_drops_sender() {
        let local_set = LocalSet::new();
        local_set
            .run_until(async {
                let (tx, rx) = channel::<usize>();
                spawn_local(async move {
                    sleep(Duration::from_millis(1)).await;
                    drop(tx);
                });
                rx.await.expect_err("successful to receive.");
            })
            .await;
    }

    /// Sending after the receiver has been dropped hands the value back.
    #[test]
    async fn oneshot_drops_receiver() {
        let local_set = LocalSet::new();
        local_set
            .run_until(async {
                let (tx, rx) = channel::<usize>();
                let bar = Arc::new(Barrier::new(2));
                {
                    let bar = bar.clone();
                    spawn_local(async move {
                        sleep(Duration::from_millis(1)).await;
                        drop(rx);
                        bar.wait().await;
                    });
                }
                // Ensure the receiver is gone before sending.
                bar.wait().await;
                tx.send(0).expect_err("successful to send.");
            })
            .await;
    }
}