use crate::channel::{
    decode_state, encode_state, queue::Queue, Priority, SendError, SendErrorKind, TryRecvError,
    TrySendError, INIT_STATE, MAX_BUFFER, MAX_CAPACITY, OPEN_MASK,
};
use futures::{
    future::poll_fn,
    stream::{FusedStream, Stream},
    task::AtomicWaker,
};
use std::{
    pin::Pin,
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
    task::{Context, Poll},
};

/// Creates an unbounded mpsc channel for communicating between asynchronous
/// tasks.
///
/// A `send` on this channel will always succeed as long as the receive half has
/// not been closed. If the receiver falls behind, messages will be arbitrarily
/// buffered.
///
/// **Note** that the amount of available system memory is an implicit bound on
/// the channel. Using an `unbounded` channel can cause the process to run out
/// of memory. In this case, the process will be aborted.
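///
/// A minimal usage sketch (marked `ignore`; assumes an async context and that
/// `Priority` and `futures::stream::StreamExt` are in scope):
///
/// ```ignore
/// use futures::stream::StreamExt;
///
/// let (tx, mut rx) = unbounded();
///
/// // Sends never block; "quick" sends go to the high-priority queue.
/// tx.unbounded_send("normal").unwrap();
/// tx.unbounded_quick_send("quick").unwrap();
///
/// // The receiver yields `(Priority, T)` pairs, draining the
/// // high-priority queue first.
/// assert_eq!(rx.next().await, Some((Priority::High, "quick")));
/// assert_eq!(rx.next().await, Some((Priority::Normal, "normal")));
/// ```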
pub fn unbounded<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
    let inner = Arc::new(UnboundedInner {
        state: AtomicUsize::new(INIT_STATE),
        message_queue: Queue::new(),
        quick_message_queue: Queue::new(),
        num_senders: AtomicUsize::new(1),
        recv_task: AtomicWaker::new(),
    });

    let tx = UnboundedSenderInner {
        inner: inner.clone(),
    };

    let rx = UnboundedReceiver { inner: Some(inner) };

    (UnboundedSender(Some(tx)), rx)
}

#[derive(Debug)]
struct UnboundedInner<T> {
    // Internal channel state. Consists of the number of messages stored in the
    // channel as well as a flag signalling that the channel is closed.
    state: AtomicUsize,

    // Atomic, FIFO queue used to send messages to the receiver
    message_queue: Queue<T>,

    // Atomic, FIFO queue used to send high-priority ("quick") messages to the receiver
    quick_message_queue: Queue<T>,

    // Number of senders in existence
    num_senders: AtomicUsize,

    // Handle to the receiver's task.
    recv_task: AtomicWaker,
}
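
// A sketch of the packed `state` word, as implied by the constants used in
// this file (the exact values live in the parent `channel` module):
//
//     [ open flag (OPEN_MASK, highest bit) | num_messages (low bits, <= MAX_CAPACITY) ]
//
// `encode_state` / `decode_state` convert between this packed word and the
// decoded `State { is_open, num_messages }` view.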

impl<T> UnboundedInner<T> {
    // Clear `open` flag in the state, keep `num_messages` intact.
    fn set_closed(&self) {
        let curr = self.state.load(SeqCst);
        if !decode_state(curr).is_open {
            return;
        }

        self.state.fetch_and(!OPEN_MASK, SeqCst);
    }
}

unsafe impl<T: Send> Send for UnboundedInner<T> {}
unsafe impl<T: Send> Sync for UnboundedInner<T> {}

#[derive(Debug)]
struct UnboundedSenderInner<T> {
    // Channel state shared between the sender and receiver.
    inner: Arc<UnboundedInner<T>>,
}

// We never project `Pin<&mut UnboundedSenderInner>` to `Pin<&mut T>`
impl<T> Unpin for UnboundedSenderInner<T> {}

impl<T> UnboundedSenderInner<T> {
    fn poll_ready_nb(&self) -> Poll<Result<(), SendError>> {
        let state = decode_state(self.inner.state.load(SeqCst));
        if state.is_open {
            Poll::Ready(Ok(()))
        } else {
            Poll::Ready(Err(SendError {
                kind: SendErrorKind::Disconnected,
            }))
        }
    }

    // Push message to the queue and signal to the receiver
    fn queue_push_and_signal(&self, msg: T, priority: Priority) {
        // Push the message onto the message queue
        match priority {
            Priority::High => self.inner.quick_message_queue.push(msg),
            Priority::Normal => self.inner.message_queue.push(msg),
        }

        // Signal to the receiver that a message has been enqueued. If the
        // receiver is parked, this will unpark the task.
        self.inner.recv_task.wake();
    }

    // Increment the number of queued messages. Returns the resulting number,
    // or `None` if the receiver has closed the channel.
    fn inc_num_messages(&self) -> Option<usize> {
        let mut curr = self.inner.state.load(SeqCst);

        loop {
            let mut state = decode_state(curr);

            // The receiver end closed the channel.
            if !state.is_open {
                return None;
            }

            // This is probably never hit: odds are the process will run out
            // of memory first. It may be worth returning something else in
            // this case.
            assert!(
                state.num_messages < MAX_CAPACITY,
                "buffer space \
                    exhausted; sending this message would overflow the state"
            );

            state.num_messages += 1;

            let next = encode_state(&state);
            match self
                .inner
                .state
                .compare_exchange(curr, next, SeqCst, SeqCst)
            {
                Ok(_) => return Some(state.num_messages),
                Err(actual) => curr = actual,
            }
        }
    }

    /// Returns whether the senders send to the same receiver.
    fn same_receiver(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }

    /// Returns a pointer to the inner state shared with the receiver.
    ///
    /// The returned pointer must not be dereferenced and should only be used for hashing!
    fn ptr(&self) -> *const UnboundedInner<T> {
        &*self.inner
    }

    /// Returns whether this channel is closed without needing a context.
    fn is_closed(&self) -> bool {
        !decode_state(self.inner.state.load(SeqCst)).is_open
    }

    /// Closes this channel from the sender side, preventing any new messages.
    fn close_channel(&self) {
        // There's no need to park this sender, it's dropping,
        // and we don't want to check for capacity, so skip
        // that stuff from `do_send`.

        self.inner.set_closed();
        self.inner.recv_task.wake();
    }
}

impl<T> Clone for UnboundedSenderInner<T> {
    fn clone(&self) -> UnboundedSenderInner<T> {
        // Since this atomic op isn't actually guarding any memory and we don't
        // care about any orderings besides the ordering on this single atomic
        // variable, a relaxed ordering would suffice; `SeqCst` is used here
        // conservatively.
        let mut curr = self.inner.num_senders.load(SeqCst);

        loop {
            // If the maximum number of senders has been reached, then fail
            if curr == MAX_BUFFER {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }

            debug_assert!(curr < MAX_BUFFER);

            let next = curr + 1;
            match self
                .inner
                .num_senders
                .compare_exchange(curr, next, SeqCst, SeqCst)
            {
                // On success, `compare_exchange` guarantees the previous value
                // was `curr`, so the increment took effect. The ABA problem
                // doesn't matter here: we only care that the number of senders
                // never exceeds the maximum.
                Ok(_) => {
                    return UnboundedSenderInner {
                        inner: self.inner.clone(),
                    }
                }
                Err(actual) => curr = actual,
            }
        }
    }
}

impl<T> Drop for UnboundedSenderInner<T> {
    fn drop(&mut self) {
        // Ordering between variables doesn't matter here
        let prev = self.inner.num_senders.fetch_sub(1, SeqCst);

        if prev == 1 {
            self.close_channel();
        }
    }
}

/// The transmission end of an unbounded mpsc channel.
///
/// This value is created by the [`unbounded`](unbounded) function.
#[derive(Debug)]
pub struct UnboundedSender<T>(Option<UnboundedSenderInner<T>>);

impl<T> UnboundedSender<T> {
    /// Check if the channel is ready to receive a message.
    pub fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), SendError>> {
        let inner = self.0.as_ref().ok_or(SendError {
            kind: SendErrorKind::Disconnected,
        })?;
        inner.poll_ready_nb()
    }

    /// Returns whether this channel is closed without needing a context.
    pub fn is_closed(&self) -> bool {
        self.0
            .as_ref()
            .map(UnboundedSenderInner::is_closed)
            .unwrap_or(true)
    }

    /// Closes this channel from the sender side, preventing any new messages.
    pub fn close_channel(&self) {
        if let Some(inner) = &self.0 {
            inner.close_channel();
        }
    }

    /// Disconnects this sender from the channel, closing it if there are no more senders left.
    pub fn disconnect(&mut self) {
        self.0 = None;
    }

    /// Tries to get the inner queue length. Returns `None` if the channel is
    /// closed or this sender has disconnected.
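    ///
    /// A sketch of the expected behaviour (assuming a fresh, open channel):
    ///
    /// ```ignore
    /// assert_eq!(tx.len(), Some(0));
    /// tx.unbounded_send(1).unwrap();
    /// assert_eq!(tx.len(), Some(1));
    /// ```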
    pub fn len(&self) -> Option<usize> {
        self.0.as_ref().and_then(|inner| {
            let state = decode_state(inner.inner.state.load(SeqCst));
            if state.is_open {
                Some(state.num_messages)
            } else {
                None
            }
        })
    }

    // Do the send without parking current task.
    fn do_send_nb(&self, msg: T, priority: Priority) -> Result<(), TrySendError<T>> {
        if let Some(inner) = &self.0 {
            if inner.inc_num_messages().is_some() {
                inner.queue_push_and_signal(msg, priority);
                return Ok(());
            }
        }

        Err(TrySendError {
            err: SendError {
                kind: SendErrorKind::Disconnected,
            },
            val: msg,
        })
    }

    /// Sends a message on the channel as a plain `async fn`, without going
    /// through the `Sink` trait.
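    ///
    /// A usage sketch (`ignore`d; assumes an async context):
    ///
    /// ```ignore
    /// tx.async_send("hello").await?;
    /// ```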
    pub async fn async_send(&self, msg: T) -> Result<(), SendError> {
        let mut msg = Some(msg);
        poll_fn(|cx| {
            let item = msg.take().unwrap();
            match self.poll_ready(cx)? {
                Poll::Ready(()) => Poll::Ready(self.start_send(item)),
                Poll::Pending => {
                    msg = Some(item);
                    Poll::Pending
                }
            }
        })
        .await
    }

    /// Sends a high-priority message on the channel as a plain `async fn`,
    /// without going through the `Sink` trait.
    pub async fn async_quick_send(&self, msg: T) -> Result<(), SendError> {
        let mut msg = Some(msg);
        poll_fn(|cx| {
            let item = msg.take().unwrap();
            match self.poll_ready(cx)? {
                Poll::Ready(()) => Poll::Ready(self.start_quick_send(item)),
                Poll::Pending => {
                    msg = Some(item);
                    Poll::Pending
                }
            }
        })
        .await
    }

    /// Send a message on the channel.
    ///
    /// This method should only be called after `poll_ready` has been used to
    /// verify that the channel is ready to receive a message.
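    ///
    /// A hedged sketch of the intended call pattern (inside a manual
    /// `Future::poll` implementation, where a `Context` `cx` is available):
    ///
    /// ```ignore
    /// if let Poll::Ready(Ok(())) = tx.poll_ready(cx) {
    ///     tx.start_send(msg)?;
    /// }
    /// ```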
    pub fn start_send(&self, msg: T) -> Result<(), SendError> {
        self.do_send_nb(msg, Priority::Normal).map_err(|e| e.err)
    }

    /// Send a high-priority message on the channel.
    ///
    /// This method should only be called after `poll_ready` has been used to
    /// verify that the channel is ready to receive a message.
    pub fn start_quick_send(&self, msg: T) -> Result<(), SendError> {
        self.do_send_nb(msg, Priority::High).map_err(|e| e.err)
    }

    /// Sends a message along this channel.
    ///
    /// This is an unbounded sender, so this function differs from `Sink::send`
    /// by ensuring the return type reflects that the channel is always ready to
    /// receive messages.
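    ///
    /// For example (a sketch; on failure the rejected message travels back
    /// inside the `TrySendError`):
    ///
    /// ```ignore
    /// if tx.unbounded_send(msg).is_err() {
    ///     // The receiver is gone; the message was not delivered.
    /// }
    /// ```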
    pub fn unbounded_send(&self, msg: T) -> Result<(), TrySendError<T>> {
        self.do_send_nb(msg, Priority::Normal)
    }

    /// Sends a high-priority message along this channel.
    ///
    /// This is an unbounded sender, so this function differs from `Sink::send`
    /// by ensuring the return type reflects that the channel is always ready to
    /// receive messages.
    pub fn unbounded_quick_send(&self, msg: T) -> Result<(), TrySendError<T>> {
        self.do_send_nb(msg, Priority::High)
    }

    /// Returns whether the senders send to the same receiver.
    pub fn same_receiver(&self, other: &Self) -> bool {
        match (&self.0, &other.0) {
            (Some(inner), Some(other)) => inner.same_receiver(other),
            _ => false,
        }
    }

    /// Hashes the receiver into the provided hasher
    pub fn hash_receiver<H>(&self, hasher: &mut H)
    where
        H: std::hash::Hasher,
    {
        use std::hash::Hash;

        let ptr = self.0.as_ref().map(|inner| inner.ptr());
        ptr.hash(hasher);
    }
}

impl<T> Clone for UnboundedSender<T> {
    fn clone(&self) -> UnboundedSender<T> {
        UnboundedSender(self.0.clone())
    }
}

/// The receiving end of an unbounded mpsc channel.
///
/// This value is created by the [`unbounded`](unbounded) function.
#[derive(Debug)]
pub struct UnboundedReceiver<T> {
    inner: Option<Arc<UnboundedInner<T>>>,
}

// `Pin<&mut UnboundedReceiver<T>>` is never projected to `Pin<&mut T>`
impl<T> Unpin for UnboundedReceiver<T> {}

impl<T> UnboundedReceiver<T> {
    /// Closes the receiving half of a channel, without dropping it.
    ///
    /// This prevents any further messages from being sent on the channel while
    /// still enabling the receiver to drain messages that are buffered.
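    ///
    /// A drain-after-close sketch (`ignore`d):
    ///
    /// ```ignore
    /// rx.close();
    /// // New sends now fail, but already-buffered messages remain readable.
    /// while let Ok(Some((_priority, msg))) = rx.try_next() {
    ///     // handle `msg`
    /// }
    /// ```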
    pub fn close(&mut self) {
        if let Some(inner) = &mut self.inner {
            inner.set_closed();
        }
    }

    /// Tries to receive the next message without notifying a context if empty.
    ///
    /// It is not recommended to call this function from inside of a future;
    /// call it only when you've otherwise arranged to be notified when the
    /// channel is no longer empty.
    ///
    /// This function will panic if called after `try_next` or `poll_next` has
    /// returned `None`.
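    ///
    /// A sketch of the three possible outcomes:
    ///
    /// ```ignore
    /// match rx.try_next() {
    ///     Ok(Some((priority, msg))) => { /* a message, with its priority */ }
    ///     Ok(None) => { /* channel closed and fully drained */ }
    ///     Err(_) => { /* no message ready right now */ }
    /// }
    /// ```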
    pub fn try_next(&mut self) -> Result<Option<(Priority, T)>, TryRecvError> {
        match self.next_message() {
            Poll::Ready(msg) => Ok(msg),
            Poll::Pending => Err(TryRecvError { _priv: () }),
        }
    }

    fn next_message(&mut self) -> Poll<Option<(Priority, T)>> {
        let inner = self
            .inner
            .as_mut()
            .expect("Receiver::next_message called after `None`");

        match unsafe { inner.quick_message_queue.pop_spin() } {
            Some(msg) => {
                // Decrement number of messages
                self.dec_num_messages();

                Poll::Ready(Some((Priority::High, msg)))
            }
            None => {
                match unsafe { inner.message_queue.pop_spin() } {
                    Some(msg) => {
                        // Decrement number of messages
                        self.dec_num_messages();

                        Poll::Ready(Some((Priority::Normal, msg)))
                    }
                    None => {
                        let state = decode_state(inner.state.load(SeqCst));
                        if state.is_closed() {
                            // If closed flag is set AND there are no pending messages
                            // it means end of stream
                            self.inner = None;
                            Poll::Ready(None)
                        } else {
                            // If queue is open, we need to return Pending
                            // to be woken up when new messages arrive.
                            // If queue is closed but num_messages is non-zero,
                            // it means that senders updated the state,
                            // but didn't put message to queue yet,
                            // so we need to park until sender unparks the task
                            // after queueing the message.
                            Poll::Pending
                        }
                    }
                }
            }
        }
    }

    fn dec_num_messages(&self) {
        if let Some(inner) = &self.inner {
            // OPEN_MASK is highest bit, so it's unaffected by subtraction
            // unless there's underflow, and we know there's no underflow
            // because number of messages at this point is always > 0.
            inner.state.fetch_sub(1, SeqCst);
        }
    }
}

impl<T> FusedStream for UnboundedReceiver<T> {
    fn is_terminated(&self) -> bool {
        self.inner.is_none()
    }
}

impl<T> Stream for UnboundedReceiver<T> {
    type Item = (Priority, T);

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Try to read a message off of the message queue.
        match self.next_message() {
            Poll::Ready(msg) => {
                if msg.is_none() {
                    self.inner = None;
                }
                Poll::Ready(msg)
            }
            Poll::Pending => {
                // There are no messages to read, in this case, park.
                self.inner.as_ref().unwrap().recv_task.register(cx.waker());
                // Check the queue again after registering the waker to avoid
                // a race: a message could have been pushed between the
                // previous `next_message` call and `register`.
                self.next_message()
            }
        }
    }
}

impl<T> Drop for UnboundedReceiver<T> {
    fn drop(&mut self) {
        // Drain the channel of all pending messages
        self.close();
        if self.inner.is_some() {
            loop {
                match self.next_message() {
                    Poll::Ready(Some(_)) => {}
                    Poll::Ready(None) => break,
                    Poll::Pending => {
                        let state = decode_state(self.inner.as_ref().unwrap().state.load(SeqCst));

                        // If the channel is closed, then there is no need to park.
                        if state.is_closed() {
                            break;
                        }

                        // TODO: Spinning isn't ideal, it might be worth
                        // investigating using a condvar or some other strategy
                        // here. That said, if this case is hit, then another thread
                        // is about to push the value into the queue and this isn't
                        // the only spinlock in the impl right now.
                        #[cfg(any(
                            target_arch = "x86",
                            target_arch = "x86_64",
                            target_arch = "aarch64",
                            target_arch = "arm"
                        ))]
                        std::hint::spin_loop();
                        #[cfg(not(any(
                            target_arch = "x86",
                            target_arch = "x86_64",
                            target_arch = "aarch64",
                            target_arch = "arm"
                        )))]
                        std::thread::yield_now();
                    }
                }
            }
        }
    }
}