use std::cell::UnsafeCell;
use std::future::Future;
use std::mem::ManuallyDrop;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll, Waker};
use super::{RecvError, SendError, TryRecvError, TrySendError};
/// Fixed-capacity FIFO ring buffer backing a local channel.
///
/// The capacity is rounded up to a power of two so a slot index is simply
/// `counter & mask`. `head` and `tail` are free-running counters advanced with
/// `wrapping_add`, so `tail.wrapping_sub(head)` is the number of stored values.
struct RingBuffer<T> {
    /// Start of the slot storage, taken from a `Vec<T>` that is never dropped
    /// as a `Vec` until `Drop` rebuilds it.
    buf: *mut T,
    /// Capacity the storage was actually allocated with. `with_capacity` may
    /// over-allocate, and `Vec::from_raw_parts` must receive the real value.
    alloc_cap: usize,
    /// Usable capacity minus one (usable capacity is a power of two).
    mask: usize,
    /// Counter of the next value to pop.
    head: usize,
    /// Counter of the next slot to push into.
    tail: usize,
}

impl<T> RingBuffer<T> {
    /// Allocates room for `capacity` values, rounded up to a power of two.
    ///
    /// # Panics
    /// Panics if `capacity` is zero, or if rounding it up to a power of two
    /// overflows `usize`.
    fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "channel capacity must be > 0");
        // `next_power_of_two` wraps to 0 on overflow in release builds, which
        // would produce a nonsensical mask; fail loudly instead.
        let capacity = capacity
            .checked_next_power_of_two()
            .expect("channel capacity overflow");
        let mut slots = ManuallyDrop::new(Vec::<T>::with_capacity(capacity));
        // `as_mut_ptr`, not `as_ptr`: values are written through this pointer,
        // which `Vec::as_ptr`'s contract forbids.
        let buf = slots.as_mut_ptr();
        let alloc_cap = slots.capacity();
        Self {
            buf,
            alloc_cap,
            mask: capacity - 1,
            head: 0,
            tail: 0,
        }
    }

    /// Number of values the buffer can hold at once.
    fn capacity(&self) -> usize {
        self.mask + 1
    }

    /// Number of values currently stored.
    fn len(&self) -> usize {
        self.tail.wrapping_sub(self.head)
    }

    /// `true` when no values are stored.
    fn is_empty(&self) -> bool {
        self.head == self.tail
    }

    /// `true` when every slot holds a value.
    fn is_full(&self) -> bool {
        self.len() == self.capacity()
    }

    /// Appends `value` at the tail.
    ///
    /// # Safety
    /// The buffer must not be full; otherwise a live value would be
    /// overwritten without being dropped.
    unsafe fn push(&mut self, value: T) {
        debug_assert!(!self.is_full());
        // SAFETY: `tail & mask < capacity <= alloc_cap`, and the caller
        // guarantees the slot is vacant.
        unsafe {
            self.buf.add(self.tail & self.mask).write(value);
        }
        self.tail = self.tail.wrapping_add(1);
    }

    /// Removes and returns the value at the head.
    ///
    /// # Safety
    /// The buffer must not be empty; otherwise an uninitialized slot is read.
    unsafe fn pop(&mut self) -> T {
        debug_assert!(!self.is_empty());
        // SAFETY: the caller guarantees the head slot holds an initialized
        // value; advancing `head` below hands ownership of it to the caller.
        let val = unsafe { self.buf.add(self.head & self.mask).read() };
        self.head = self.head.wrapping_add(1);
        val
    }
}

impl<T> Drop for RingBuffer<T> {
    /// Drops every value still stored (oldest first), then frees the storage.
    fn drop(&mut self) {
        while !self.is_empty() {
            // SAFETY: slots in `head..tail` hold initialized values.
            unsafe {
                self.buf.add(self.head & self.mask).drop_in_place();
            }
            self.head = self.head.wrapping_add(1);
        }
        // SAFETY: `buf` and `alloc_cap` come from the `Vec` leaked in `new`;
        // length 0 because every stored value was dropped above.
        unsafe {
            drop(Vec::from_raw_parts(self.buf, 0, self.alloc_cap));
        }
    }
}
/// Intrusive list node for a sender blocked on a full channel.
///
/// The node is embedded in the waiting future; `WaiterList` only links nodes
/// together through the raw `next`/`prev` pointers and never owns them.
struct Waiter {
    /// Task to wake when the sender should retry.
    waker: Option<Waker>,
    /// Following node in the list, or null.
    next: *mut Waiter,
    /// Preceding node in the list, or null.
    prev: *mut Waiter,
    /// `true` while the node is linked into a `WaiterList`.
    queued: bool,
}

impl Waiter {
    /// Returns an unlinked node with no waker registered.
    fn new() -> Self {
        let unlinked: *mut Waiter = std::ptr::null_mut();
        Waiter {
            waker: None,
            next: unlinked,
            prev: unlinked,
            queued: false,
        }
    }
}

/// FIFO of blocked senders, stored as an intrusive doubly-linked list.
///
/// `head` is the oldest waiter and `tail` the newest; both are null when the
/// list is empty.
struct WaiterList {
    head: *mut Waiter,
    tail: *mut Waiter,
}

impl WaiterList {
    /// Returns an empty list.
    fn new() -> Self {
        let nil: *mut Waiter = std::ptr::null_mut();
        WaiterList {
            head: nil,
            tail: nil,
        }
    }

    /// Links `waiter` in as the newest entry and marks it queued.
    ///
    /// # Safety
    /// `waiter` must point to a live, currently unqueued `Waiter` that keeps
    /// its address until it is unlinked again.
    unsafe fn push_back(&mut self, waiter: *mut Waiter) {
        debug_assert!(unsafe { !(*waiter).queued });
        let old_tail = self.tail;
        unsafe {
            (*waiter).queued = true;
            (*waiter).next = std::ptr::null_mut();
            (*waiter).prev = old_tail;
        }
        match std::ptr::NonNull::new(old_tail) {
            Some(last) => unsafe { (*last.as_ptr()).next = waiter },
            None => self.head = waiter,
        }
        self.tail = waiter;
    }

    /// Unlinks and returns the oldest waiter, or null when the list is empty.
    ///
    /// # Safety
    /// Every node in the list must still be live.
    unsafe fn pop_front(&mut self) -> *mut Waiter {
        let Some(front) = std::ptr::NonNull::new(self.head) else {
            return std::ptr::null_mut();
        };
        let front = front.as_ptr();
        unsafe { self.unlink(front) };
        front
    }

    /// Unlinks `waiter` if it is queued; otherwise does nothing.
    ///
    /// # Safety
    /// `waiter` must point to a live `Waiter`; if it is queued, it must be a
    /// member of this list.
    unsafe fn remove(&mut self, waiter: *mut Waiter) {
        if unsafe { (*waiter).queued } {
            unsafe { self.unlink(waiter) };
        }
    }

    /// Unlinks every waiter, oldest first, waking each registered waker, and
    /// leaves the list empty.
    ///
    /// # Safety
    /// Every node in the list must still be live.
    unsafe fn wake_all(&mut self) {
        loop {
            let waiter = unsafe { self.pop_front() };
            if waiter.is_null() {
                break;
            }
            if let Some(waker) = unsafe { (*waiter).waker.take() } {
                waker.wake();
            }
        }
    }

    /// Splices `node` out of the list and resets its link fields.
    ///
    /// # Safety
    /// `node` must be a live member of this list.
    unsafe fn unlink(&mut self, node: *mut Waiter) {
        let (before, after) = unsafe { ((*node).prev, (*node).next) };
        match std::ptr::NonNull::new(before) {
            Some(p) => unsafe { (*p.as_ptr()).next = after },
            None => self.head = after,
        }
        match std::ptr::NonNull::new(after) {
            Some(n) => unsafe { (*n.as_ptr()).prev = before },
            None => self.tail = before,
        }
        unsafe {
            (*node).next = std::ptr::null_mut();
            (*node).prev = std::ptr::null_mut();
            (*node).queued = false;
        }
    }
}
/// State shared by every handle of one channel.
struct Inner<T> {
    /// Values sent but not yet received.
    buffer: RingBuffer<T>,
    /// Waker stored by the most recent `Recv` poll that found the buffer empty.
    rx_waker: Option<Waker>,
    /// Senders parked because the buffer was full, oldest first.
    tx_waiters: WaiterList,
    /// Number of live `Sender` handles.
    sender_count: u32,
    /// Set once the `Receiver` has been dropped.
    closed: bool,
}

impl<T> Inner<T> {
    /// Creates the state for a fresh channel: one live sender, nothing
    /// buffered, nobody waiting.
    fn new(capacity: usize) -> Self {
        let buffer = RingBuffer::new(capacity);
        let tx_waiters = WaiterList::new();
        Inner {
            closed: false,
            sender_count: 1,
            tx_waiters,
            rx_waker: None,
            buffer,
        }
    }
}
/// Channel state shared between all handles: reference-counted and confined
/// to one thread (`Rc`), mutated through an `UnsafeCell`.
type Shared<T> = Rc<UnsafeCell<Inner<T>>>;

/// Returns a mutable reference to the shared channel state.
///
/// # Safety
/// The caller must ensure that no other reference into the same `Inner`
/// (including one previously returned by this function) is used while the
/// returned reference is live.
#[inline]
// `mut_from_ref` is deliberate: mutation goes through the `UnsafeCell`, and
// exclusivity is the caller's obligation per the safety contract.
#[allow(clippy::mut_from_ref)]
unsafe fn inner<T>(shared: &Shared<T>) -> &mut Inner<T> {
    let cell: &UnsafeCell<Inner<T>> = shared;
    unsafe { &mut *cell.get() }
}
/// Creates a bounded, single-threaded channel and returns its two halves.
///
/// The buffer holds `capacity` values rounded up to the next power of two.
///
/// # Panics
/// Panics if `capacity` is zero. The call is also checked by
/// `crate::context::assert_in_runtime`, which is expected to reject calls made
/// outside `Runtime::block_on`.
pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    const CONTEXT_MSG: &str = "local::channel() called outside Runtime::block_on";
    crate::context::assert_in_runtime(CONTEXT_MSG);
    channel_inner(capacity)
}
/// Builds a channel without the runtime-context check performed by `channel`.
///
/// The returned sender and receiver share one `Inner`, which starts with a
/// sender count of one.
fn channel_inner<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    let state: Shared<T> = Rc::new(UnsafeCell::new(Inner::new(capacity)));
    let sender = Sender {
        inner: Rc::clone(&state),
    };
    let receiver = Receiver { inner: state };
    (sender, receiver)
}
pub struct Sender<T> {
inner: Shared<T>,
}
impl<T> Sender<T> {
pub fn send(&self, value: T) -> Send<'_, T> {
Send {
sender: self,
value: Some(value),
waiter: Waiter::new(),
}
}
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
let state = unsafe { inner(&self.inner) };
if state.closed {
return Err(TrySendError::Closed(value));
}
if state.buffer.is_full() {
return Err(TrySendError::Full(value));
}
unsafe { state.buffer.push(value) };
if let Some(waker) = state.rx_waker.take() {
waker.wake();
}
Ok(())
}
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
let state = unsafe { inner(&self.inner) };
state.sender_count += 1;
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
let state = unsafe { inner(&self.inner) };
state.sender_count -= 1;
if state.sender_count == 0 {
if let Some(waker) = state.rx_waker.take() {
waker.wake();
}
}
}
}
pub struct Send<'a, T> {
sender: &'a Sender<T>,
value: Option<T>,
waiter: Waiter,
}
impl<T> Future for Send<'_, T> {
type Output = Result<(), SendError<T>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
let state = unsafe { inner(&this.sender.inner) };
if state.closed {
let value = this.value.take().expect("polled after completion");
if this.waiter.queued {
unsafe { state.tx_waiters.remove(&raw mut this.waiter) };
}
return Poll::Ready(Err(SendError(value)));
}
if !state.buffer.is_full() {
let value = this.value.take().expect("polled after completion");
if this.waiter.queued {
unsafe { state.tx_waiters.remove(&raw mut this.waiter) };
}
unsafe { state.buffer.push(value) };
if let Some(waker) = state.rx_waker.take() {
waker.wake();
}
return Poll::Ready(Ok(()));
}
this.waiter.waker = Some(cx.waker().clone());
if !this.waiter.queued {
unsafe { state.tx_waiters.push_back(&raw mut this.waiter) };
}
Poll::Pending
}
}
impl<T> Drop for Send<'_, T> {
fn drop(&mut self) {
if self.waiter.queued {
let state = unsafe { inner(&self.sender.inner) };
unsafe { state.tx_waiters.remove(&raw mut self.waiter) };
}
}
}
/// Receiving half of a local bounded channel.
///
/// Dropping it closes the channel: blocked senders are woken, and every send
/// from then on fails with its value handed back.
pub struct Receiver<T> {
    inner: Shared<T>,
}

impl<T> Receiver<T> {
    /// Returns a future that resolves to the next value, or to
    /// `Err(RecvError)` once the buffer is empty and every `Sender` is gone.
    pub fn recv(&self) -> Recv<'_, T> {
        Recv { receiver: self }
    }

    /// Takes the next buffered value without waiting.
    ///
    /// Freeing a slot wakes the longest-waiting blocked sender, if any.
    ///
    /// # Errors
    /// `TryRecvError::Empty` if nothing is buffered but a sender is alive;
    /// `TryRecvError::Closed` if nothing is buffered and all senders are gone.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        let state = unsafe { inner(&self.inner) };
        if state.buffer.is_empty() {
            return Err(match state.sender_count {
                0 => TryRecvError::Closed,
                _ => TryRecvError::Empty,
            });
        }
        // SAFETY: checked non-empty above.
        let value = unsafe { state.buffer.pop() };
        let next_sender = unsafe { state.tx_waiters.pop_front() };
        let woken = if next_sender.is_null() {
            None
        } else {
            unsafe { (*next_sender).waker.take() }
        };
        if let Some(waker) = woken {
            waker.wake();
        }
        Ok(value)
    }
}

impl<T> Drop for Receiver<T> {
    /// Marks the channel closed, then wakes every blocked sender so each one
    /// observes the closed flag on its next poll.
    fn drop(&mut self) {
        let state = unsafe { inner(&self.inner) };
        state.closed = true;
        unsafe { state.tx_waiters.wake_all() };
    }
}
/// Future returned by [`Receiver::recv`].
pub struct Recv<'a, T> {
    receiver: &'a Receiver<T>,
}

impl<T> Future for Recv<'_, T> {
    type Output = Result<T, RecvError>;

    /// Yields the next buffered value. When the buffer is empty, it resolves
    /// to `Err(RecvError)` if no sender is left; otherwise it stores the
    /// task's waker and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let state = unsafe { inner(&self.receiver.inner) };
        if state.buffer.is_empty() {
            if state.sender_count == 0 {
                return Poll::Ready(Err(RecvError));
            }
            state.rx_waker = Some(cx.waker().clone());
            return Poll::Pending;
        }
        // SAFETY: checked non-empty above.
        let value = unsafe { state.buffer.pop() };
        // One slot was freed: hand it to the longest-waiting blocked sender.
        let next_sender = unsafe { state.tx_waiters.pop_front() };
        let woken = if next_sender.is_null() {
            None
        } else {
            unsafe { (*next_sender).waker.take() }
        };
        if let Some(waker) = woken {
            waker.wake();
        }
        Poll::Ready(Ok(value))
    }
}
// Unit tests: raw `RingBuffer` behavior first, then channel semantics driven
// through `channel_inner`, which skips the runtime-context check in `channel`.
#[cfg(test)]
mod tests {
use super::*;
use std::task::{RawWaker, RawWakerVTable};
// Builds a waker whose vtable entries are all no-ops, except `clone`, which
// returns another such waker. This lets tests poll futures by hand without
// an executor.
fn noop_waker() -> Waker {
fn noop(_: *const ()) {}
fn noop_clone(p: *const ()) -> RawWaker {
RawWaker::new(p, &VTABLE)
}
const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}
// Polls `f` exactly once with the no-op waker. Wakeups are therefore not
// observable; tests re-poll explicitly instead.
fn poll_once<F: Future>(f: Pin<&mut F>) -> Poll<F::Output> {
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
f.poll(&mut cx)
}
#[test]
fn ring_buffer_push_pop() {
let mut rb = RingBuffer::<u32>::new(4);
assert_eq!(rb.capacity(), 4);
assert!(rb.is_empty());
unsafe { rb.push(1) };
unsafe { rb.push(2) };
unsafe { rb.push(3) };
assert_eq!(rb.len(), 3);
assert_eq!(unsafe { rb.pop() }, 1);
assert_eq!(unsafe { rb.pop() }, 2);
assert_eq!(unsafe { rb.pop() }, 3);
assert!(rb.is_empty());
}
#[test]
fn ring_buffer_full() {
let mut rb = RingBuffer::<u32>::new(2);
assert_eq!(rb.capacity(), 2);
unsafe { rb.push(1) };
unsafe { rb.push(2) };
assert!(rb.is_full());
assert_eq!(rb.len(), 2);
assert_eq!(unsafe { rb.pop() }, 1);
assert!(!rb.is_full());
}
// Ten fill/drain cycles push the head/tail counters well past the capacity,
// exercising the `& mask` slot indexing.
#[test]
fn ring_buffer_wrap_around() {
let mut rb = RingBuffer::<u32>::new(4);
for cycle in 0..10u32 {
let base = cycle * 4;
unsafe { rb.push(base) };
unsafe { rb.push(base + 1) };
unsafe { rb.push(base + 2) };
unsafe { rb.push(base + 3) };
assert!(rb.is_full());
assert_eq!(unsafe { rb.pop() }, base);
assert_eq!(unsafe { rb.pop() }, base + 1);
assert_eq!(unsafe { rb.pop() }, base + 2);
assert_eq!(unsafe { rb.pop() }, base + 3);
assert!(rb.is_empty());
}
}
#[test]
fn ring_buffer_rounds_up_to_power_of_two() {
let rb = RingBuffer::<u8>::new(3);
assert_eq!(rb.capacity(), 4);
let rb = RingBuffer::<u8>::new(5);
assert_eq!(rb.capacity(), 8);
let rb = RingBuffer::<u8>::new(8);
assert_eq!(rb.capacity(), 8);
}
#[test]
#[should_panic(expected = "capacity must be > 0")]
fn ring_buffer_zero_capacity_panics() {
let _ = RingBuffer::<u8>::new(0);
}
// Values still stored when the buffer is dropped must be dropped exactly once.
#[test]
fn ring_buffer_drop_remaining() {
use std::cell::Cell;
use std::rc::Rc;
let dropped = Rc::new(Cell::new(0u32));
struct DropCounter(Rc<Cell<u32>>);
impl Drop for DropCounter {
fn drop(&mut self) {
self.0.set(self.0.get() + 1);
}
}
let mut rb = RingBuffer::new(4);
unsafe { rb.push(DropCounter(dropped.clone())) };
unsafe { rb.push(DropCounter(dropped.clone())) };
unsafe { rb.push(DropCounter(dropped.clone())) };
assert_eq!(dropped.get(), 0);
drop(rb);
assert_eq!(dropped.get(), 3);
}
#[test]
fn send_recv_single() {
let (tx, rx) = channel_inner::<u32>(4);
tx.try_send(1).unwrap();
tx.try_send(2).unwrap();
tx.try_send(3).unwrap();
assert_eq!(rx.try_recv().unwrap(), 1);
assert_eq!(rx.try_recv().unwrap(), 2);
assert_eq!(rx.try_recv().unwrap(), 3);
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
}
#[test]
fn fifo_ordering() {
let (tx, rx) = channel_inner(8);
for i in 0..8u32 {
tx.try_send(i).unwrap();
}
for i in 0..8u32 {
assert_eq!(rx.try_recv().unwrap(), i);
}
}
// A full buffer rejects `try_send` and hands the value back; space frees up
// after one receive.
#[test]
fn try_send_full() {
let (tx, rx) = channel_inner(2);
tx.try_send(1u32).unwrap();
tx.try_send(2).unwrap();
let err = tx.try_send(3).unwrap_err();
assert!(err.is_full());
assert_eq!(err.into_inner(), 3);
assert_eq!(rx.try_recv().unwrap(), 1);
tx.try_send(3).unwrap();
}
#[test]
fn try_recv_empty() {
let (tx, rx) = channel_inner::<u32>(4);
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
tx.try_send(1).unwrap();
assert_eq!(rx.try_recv().unwrap(), 1);
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
}
// Buffered values are still delivered after the last sender is gone; only
// then does the receiver report `Closed`.
#[test]
fn sender_drop_signals_closed() {
let (tx, rx) = channel_inner::<u32>(4);
tx.try_send(42).unwrap();
drop(tx);
assert_eq!(rx.try_recv().unwrap(), 42);
assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
}
#[test]
fn receiver_drop_signals_closed() {
let (tx, rx) = channel_inner::<u32>(4);
drop(rx);
let err = tx.try_send(1).unwrap_err();
assert!(err.is_closed());
}
#[test]
fn multiple_senders() {
let (tx1, rx) = channel_inner(8);
let tx2 = tx1.clone();
tx1.try_send(1u32).unwrap();
tx2.try_send(2).unwrap();
tx1.try_send(3).unwrap();
assert_eq!(rx.try_recv().unwrap(), 1);
assert_eq!(rx.try_recv().unwrap(), 2);
assert_eq!(rx.try_recv().unwrap(), 3);
}
// NOTE(review): with the no-op waker the wake itself is not observed; this
// only checks that the channel reports `Closed` after a pending `recv`.
#[test]
fn last_sender_drop_wakes_receiver() {
let (tx, rx) = channel_inner::<u32>(4);
let mut recv_fut = std::pin::pin!(rx.recv());
assert!(poll_once(recv_fut.as_mut()).is_pending());
drop(tx);
assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
}
#[test]
fn recv_pending_then_ready() {
let (tx, rx) = channel_inner::<u32>(4);
let mut recv_fut = std::pin::pin!(rx.recv());
assert!(poll_once(recv_fut.as_mut()).is_pending());
tx.try_send(99).unwrap();
match poll_once(recv_fut.as_mut()) {
Poll::Ready(Ok(99)) => {}
other => panic!("expected Ready(Ok(99)), got {other:?}"),
}
}
// A send blocked on a full buffer completes on re-poll once `try_recv` has
// freed a slot, and its value lands after the earlier ones.
#[test]
fn send_pending_then_ready() {
let (tx, rx) = channel_inner(2);
tx.try_send(1u32).unwrap();
tx.try_send(2).unwrap();
let mut send_fut = std::pin::pin!(tx.send(3));
assert!(poll_once(send_fut.as_mut()).is_pending());
assert_eq!(rx.try_recv().unwrap(), 1);
match poll_once(send_fut.as_mut()) {
Poll::Ready(Ok(())) => {}
other => panic!("expected Ready(Ok(())), got {other:?}"),
}
assert_eq!(rx.try_recv().unwrap(), 2);
assert_eq!(rx.try_recv().unwrap(), 3);
}
// Dropping a pending send unlinks its waiter; its value (3) is never buffered.
#[test]
fn send_cancelled_on_drop() {
let (tx, rx) = channel_inner(2);
tx.try_send(1u32).unwrap();
tx.try_send(2).unwrap();
{
let mut send_fut = std::pin::pin!(tx.send(3));
assert!(poll_once(send_fut.as_mut()).is_pending());
}
assert_eq!(rx.try_recv().unwrap(), 1);
assert_eq!(rx.try_recv().unwrap(), 2);
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
}
#[test]
fn capacity_one() {
let (tx, rx) = channel_inner(1);
tx.try_send(42u32).unwrap();
assert!(tx.try_send(43).unwrap_err().is_full());
assert_eq!(rx.try_recv().unwrap(), 42);
tx.try_send(43).unwrap();
assert_eq!(rx.try_recv().unwrap(), 43);
}
// Capacity 3 rounds up to 4, so exactly four sends fit.
#[test]
fn non_power_of_two_rounds_up() {
let (tx, rx) = channel_inner(3);
for i in 0..4u32 {
tx.try_send(i).unwrap();
}
assert!(tx.try_send(4).unwrap_err().is_full());
for i in 0..4u32 {
assert_eq!(rx.try_recv().unwrap(), i);
}
}
// The channel stays open until the last of three senders is dropped.
#[test]
fn clone_sender_increments_count() {
let (tx, rx) = channel_inner::<u32>(4);
let tx2 = tx.clone();
let tx3 = tx.clone();
drop(tx);
drop(tx2);
tx3.try_send(1).unwrap();
assert_eq!(rx.try_recv().unwrap(), 1);
drop(tx3);
assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
}
#[test]
fn recv_drains_buffer_after_all_senders_drop() {
let (tx, rx) = channel_inner(8);
tx.try_send(1u32).unwrap();
tx.try_send(2).unwrap();
tx.try_send(3).unwrap();
drop(tx);
assert_eq!(rx.try_recv().unwrap(), 1);
assert_eq!(rx.try_recv().unwrap(), 2);
assert_eq!(rx.try_recv().unwrap(), 3);
assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
}
#[test]
fn send_after_receiver_drop_returns_closed() {
let (tx, rx) = channel_inner::<u32>(4);
drop(rx);
let mut send_fut = std::pin::pin!(tx.send(1));
match poll_once(send_fut.as_mut()) {
Poll::Ready(Err(SendError(1))) => {}
other => panic!("expected Ready(Err(SendError(1))), got {other:?}"),
}
}
// Two senders block on a full buffer; each completes on re-poll after one
// receive frees a slot.
#[test]
fn multiple_senders_blocked_then_unblocked() {
let (tx1, rx) = channel_inner(2);
let tx2 = tx1.clone();
tx1.try_send(1u32).unwrap();
tx2.try_send(2).unwrap();
let mut send1 = std::pin::pin!(tx1.send(3));
let mut send2 = std::pin::pin!(tx2.send(4));
assert!(poll_once(send1.as_mut()).is_pending());
assert!(poll_once(send2.as_mut()).is_pending());
assert_eq!(rx.try_recv().unwrap(), 1);
match poll_once(send1.as_mut()) {
Poll::Ready(Ok(())) => {}
other => panic!("expected Ready(Ok(())), got {other:?}"),
}
assert_eq!(rx.try_recv().unwrap(), 2);
match poll_once(send2.as_mut()) {
Poll::Ready(Ok(())) => {}
other => panic!("expected Ready(Ok(())), got {other:?}"),
}
assert_eq!(rx.try_recv().unwrap(), 3);
assert_eq!(rx.try_recv().unwrap(), 4);
}
// Dropping the receiver wakes the blocked sender; on re-poll it gets its
// value back as `SendError`.
#[test]
fn receiver_drop_wakes_blocked_senders() {
let (tx, rx) = channel_inner(1);
tx.try_send(1u32).unwrap();
let mut send_fut = std::pin::pin!(tx.send(2));
assert!(poll_once(send_fut.as_mut()).is_pending());
drop(rx);
match poll_once(send_fut.as_mut()) {
Poll::Ready(Err(SendError(2))) => {}
other => panic!("expected Ready(Err(SendError(2))), got {other:?}"),
}
}
// Buffered values are dropped once both halves are gone and the shared state
// is released.
#[test]
fn drop_values_on_channel_close() {
use std::cell::Cell;
use std::rc::Rc;
let dropped = Rc::new(Cell::new(0u32));
struct DropCounter(Rc<Cell<u32>>);
impl Drop for DropCounter {
fn drop(&mut self) {
self.0.set(self.0.get() + 1);
}
}
let (tx, rx) = channel_inner(4);
tx.try_send(DropCounter(dropped.clone())).unwrap();
tx.try_send(DropCounter(dropped.clone())).unwrap();
tx.try_send(DropCounter(dropped.clone())).unwrap();
assert_eq!(dropped.get(), 0);
drop(tx);
drop(rx);
assert_eq!(dropped.get(), 3);
}
// Many single send/receive pairs through a 64-slot buffer.
#[test]
fn stress_sequential_send_recv() {
let (tx, rx) = channel_inner(64);
for i in 0..100_000u64 {
tx.try_send(i).unwrap();
assert_eq!(rx.try_recv().unwrap(), i);
}
}
// Repeatedly fills the buffer to capacity, checks it rejects one more value,
// then drains it in order.
#[test]
fn stress_fill_drain_cycles() {
let (tx, rx) = channel_inner(64);
for _ in 0..1_000 {
for i in 0..64u32 {
tx.try_send(i).unwrap();
}
assert!(tx.try_send(999).unwrap_err().is_full());
for i in 0..64u32 {
assert_eq!(rx.try_recv().unwrap(), i);
}
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
}
}
#[test]
fn stress_interleaved_small_buffer() {
let (tx, rx) = channel_inner(2);
for i in 0..50_000u64 {
tx.try_send(i).unwrap();
assert_eq!(rx.try_recv().unwrap(), i);
}
}
}