// tempest-rt 0.0.1
//
// TempestDB Deterministic Async Runtime
// Documentation
//! Single-value, single-producer, single-consumer channel.
//!
//! Create a channel with [`channel`]. The [`Sender`] sends exactly one value; the [`Receiver`]
//! obtains it by awaiting [`Receiver::recv`] or, without waiting, via [`Receiver::try_recv`].
//! If the receiver is dropped first, [`Sender::send`] returns the value back as `Err`; if the
//! sender is dropped without sending, the receiver observes a closed error.

use std::{
    cell::RefCell,
    future::poll_fn,
    rc::Rc,
    task::{Poll, Waker},
};

use derive_more::{Display, Error};

/// State shared by the [`Sender`] and [`Receiver`] of one channel.
struct Inner<T> {
    /// The sent value: `None` until `send` stores it, and `None` again once the receiver has
    /// taken it.
    value: Option<T>,
    /// Waker recorded by the receiver's most recent pending poll; `send` takes and wakes it.
    waker: Option<Waker>,
}

/// Sending half of a oneshot channel.
///
/// Created by [`channel`]. [`Sender::send`] takes the sender by value, so at most one value can
/// be sent through it.
#[derive(derive_more::Debug)]
pub struct Sender<T> {
    // Debug output shows only the address of the shared state, not its contents.
    #[debug("{:p}", Rc::as_ptr(inner))]
    inner: Rc<RefCell<Inner<T>>>,
}

/// Receiving half of a oneshot channel.
///
/// Created by [`channel`]. Await [`Receiver::recv`] to obtain the sent value, or check for it
/// without waiting via [`Receiver::try_recv`].
// NOTE(review): earlier docs said this type implements `Future`, but no such impl is visible in
// this file — confirm whether one exists elsewhere or whether `recv` is the only async entry.
#[derive(derive_more::Debug)]
#[must_use = "futures do nothing unless awaited"]
pub struct Receiver<T> {
    // Debug output shows only the address of the shared state, not its contents.
    #[debug("{:p}", Rc::as_ptr(inner))]
    inner: Rc<RefCell<Inner<T>>>,
}

/// Builds a fresh oneshot channel and hands back its two halves.
///
/// Both halves point at the same, initially empty, shared state: no value has been stored and no
/// waker has been registered.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let shared = Rc::new(RefCell::new(Inner {
        value: None,
        waker: None,
    }));
    // The sender gets its own strong reference; the receiver takes over the original one.
    let sender = Sender {
        inner: Rc::clone(&shared),
    };
    let receiver = Receiver { inner: shared };
    (sender, receiver)
}

impl<T> Sender<T> {
    /// Tries to send `val` to the [`Receiver`].
    ///
    /// # Returns
    ///
    /// - `Ok(())` when the send has succeeded
    /// - `Err(val)` when the channel was closed (receiver was dropped)
    pub fn send(self, val: T) -> Result<(), T> {
        if Rc::strong_count(&self.inner) == 1 {
            return Err(val);
        }
        let mut borrow = self.inner.borrow_mut();
        borrow.value = Some(val);
        if let Some(waker) = borrow.waker.take() {
            waker.wake();
        }
        Ok(())
    }
}

/// Error returned by [`Receiver::recv`] when the sender has been dropped without sending a value.
///
/// Because `send` consumes the sender, the same error is reported when the value was already
/// taken out earlier (for example by [`Receiver::try_recv`]) and nothing is left to receive.
#[derive(Debug, Display, Error, PartialEq, Eq)]
#[display("sender has been dropped")]
pub struct RecvError;

/// Error returned by [`Receiver::try_recv`].
#[derive(Debug, Display, Error, PartialEq, Eq)]
pub enum TryRecvError {
    /// The sender is still alive but has not sent a value yet.
    #[display("channel is empty")]
    Empty,
    /// The sender was dropped without sending a value.
    ///
    /// Also reported once a sent value has already been taken, since `send` consumes the sender.
    #[display("sender has been dropped")]
    Closed,
}

impl<T> Receiver<T> {
    /// Polls the channel once on behalf of the task behind `cx`.
    ///
    /// Yields `Ready(Ok(value))` if a value is stored, `Ready(Err(RecvError))` if nothing is
    /// stored and the sender is gone, and otherwise records `cx`'s waker and yields `Pending`.
    pub(crate) fn poll_recv(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<T, RecvError>> {
        let mut state = self.inner.borrow_mut();
        match state.value.take() {
            Some(received) => Poll::Ready(Ok(received)),
            // This receiver holds the only remaining reference, so the sender is gone.
            None if Rc::strong_count(&self.inner) == 1 => Poll::Ready(Err(RecvError)),
            None => {
                // Remember who to wake once the sender stores a value.
                state.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }

    /// Waits for the value and returns it, consuming the receiver.
    ///
    /// Resolves to `Err` if the sender was dropped without sending a value.
    pub async fn recv(mut self) -> Result<T, RecvError> {
        let wait = poll_fn(|cx| self.poll_recv(cx));
        wait.await
    }

    /// Checks for the value without waiting.
    ///
    /// Returns `Err(Empty)` while the sender is alive but has not sent yet, and `Err(Closed)`
    /// if the sender is gone and no value is stored.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        // Only consulted when no value is stored; reading the count has no side effects.
        let failure = if Rc::strong_count(&self.inner) == 1 {
            TryRecvError::Closed
        } else {
            TryRecvError::Empty
        };
        self.inner.borrow_mut().value.take().ok_or(failure)
    }
}

// Unit tests for the oneshot channel; each one drives its body with `block_on` on a default
// `VirtualIo`.
#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;

    use crate::{block_on, spawn};

    use super::*;

    // A value sent before `recv` is awaited is delivered.
    #[test]
    fn test_oneshot_send_recv() {
        block_on(VirtualIo::default(), async {
            let (tx, rx) = channel();
            tx.send(5).unwrap();
            assert_eq!(rx.recv().await.unwrap(), 5);
        })
    }

    // Dropping the sender before `recv` is awaited yields `RecvError`.
    #[test]
    fn test_oneshot_sender_dropped() {
        block_on(VirtualIo::default(), async {
            let (tx, rx) = channel::<i32>();
            drop(tx);
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    // Sending after the receiver was dropped hands the value back as `Err`.
    #[test]
    fn test_oneshot_receiver_dropped() {
        block_on(VirtualIo::default(), async {
            let (tx, rx) = channel::<i32>();
            drop(rx);
            assert_eq!(tx.send(99), Err(99));
        });
    }

    // The send happens inside a spawned task; the receiver reads after that task has finished.
    #[test]
    fn test_oneshot_from_task() {
        block_on(VirtualIo::default(), async {
            let (tx, rx) = channel();
            let handle = spawn(async {
                tx.send(42).unwrap();
            });

            handle.await.unwrap();
            let result = rx.recv().await.unwrap();
            assert_eq!(result, 42);
        });
    }

    // `try_recv` reports `Empty` while the sender is alive and has not sent.
    #[test]
    fn test_oneshot_try_recv_empty() {
        block_on(VirtualIo::default(), async {
            let (_tx, mut rx) = channel::<i32>();
            assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        });
    }

    // `try_recv` reports `Closed` once the sender was dropped without sending.
    #[test]
    fn test_oneshot_try_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = channel::<i32>();
            drop(tx);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
        });
    }

    // `try_recv` returns a value that was already sent.
    #[test]
    fn test_oneshot_try_recv_value() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = channel();
            tx.send(42).unwrap();
            assert_eq!(rx.try_recv(), Ok(42));
        });
    }
}