use crate::shim::atomic::{AtomicBool, Ordering};
use crate::shim::cell::UnsafeCell;
use crate::shim::notify::SingleWaiterNotify;
use crate::shim::sync::Arc;
use core::num::NonZeroUsize;
use smallring::spsc::{Consumer, PopError, Producer, PushError, new};
/// Creates a bounded single-producer / single-consumer channel.
///
/// `capacity` and the const parameter `N` are forwarded to `smallring::spsc::new`;
/// both returned halves share one `Inner` behind an `Arc`.
///
/// Returns `(sender, receiver)`.
pub fn channel<T, const N: usize>(capacity: NonZeroUsize) -> (Sender<T, N>, Receiver<T, N>) {
    let (producer, consumer) = new::<T, N>(capacity);
    let shared = Arc::new(Inner::<T, N> {
        producer: UnsafeCell::new(producer),
        consumer: UnsafeCell::new(consumer),
        closed: AtomicBool::new(false),
        recv_notify: SingleWaiterNotify::new(),
        send_notify: SingleWaiterNotify::new(),
    });
    let tx = Sender {
        inner: Arc::clone(&shared),
    };
    (tx, Receiver { inner: shared })
}
/// State shared by a `Sender` / `Receiver` pair through an `Arc`.
struct Inner<T, const N: usize = 32> {
    /// Write half of the ring buffer; only `Sender` methods access it.
    producer: UnsafeCell<Producer<T, N>>,
    /// Read half of the ring buffer; only `Receiver` methods access it.
    consumer: UnsafeCell<Consumer<T, N>>,
    /// Set to `true` when either the `Sender` or the `Receiver` is dropped.
    closed: AtomicBool,
    /// Signalled by the sender after a successful push and when the `Sender` is
    /// dropped, so a waiting receiver re-checks the buffer.
    recv_notify: SingleWaiterNotify,
    /// Signalled by the receiver after popping or clearing values and when the
    /// `Receiver` is dropped, so a waiting sender re-checks for space.
    send_notify: SingleWaiterNotify,
}
// SAFETY: `producer` is only dereferenced from `Sender` methods and `consumer` only from
// `Receiver` methods, so the two halves never alias each other; `closed` is an atomic and
// the notifiers are presumably `Sync` themselves (TODO confirm for `SingleWaiterNotify`).
// NOTE(review): `Sender`/`Receiver` methods such as `try_send`/`try_recv` take `&self` yet
// mutate through these cells. Because `Arc<Inner>` then makes both handles `Sync` (assuming
// `Producer`/`Consumer` are `Send` for `T: Send`), safe code could share one `&Sender` (or
// `&Receiver`) between threads and race on the producer (or consumer). Confirm whether the
// handles should be `!Sync` or those methods should take `&mut self`.
unsafe impl<T: Send, const N: usize> Sync for Inner<T, N> {}
impl<T, const N: usize> core::fmt::Debug for Inner<T, N> {
    /// Shows only the `closed` flag; the ring halves are not inspected.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Inner");
        out.field("closed", &self.closed.load(Ordering::Acquire));
        out.finish()
    }
}
/// Sending half of a channel created by [`channel`].
///
/// Dropping it marks the channel closed and wakes a waiting receiver.
pub struct Sender<T, const N: usize> {
    /// Shared channel state, also held by the matching `Receiver`.
    inner: Arc<Inner<T, N>>,
}
impl<T, const N: usize> core::fmt::Debug for Sender<T, N> {
    /// Shows the closed flag together with the current occupancy and capacity.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("Sender");
        builder.field("closed", &self.is_closed());
        builder.field("len", &self.len());
        builder.field("capacity", &self.capacity());
        builder.finish()
    }
}
/// Receiving half of a channel created by [`channel`].
///
/// Dropping it marks the channel closed and wakes a waiting sender.
pub struct Receiver<T, const N: usize> {
    /// Shared channel state, also held by the matching `Sender`.
    inner: Arc<Inner<T, N>>,
}
impl<T, const N: usize> core::fmt::Debug for Receiver<T, N> {
    /// Shows emptiness, the current occupancy and the capacity.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("Receiver");
        builder.field("is_empty", &self.is_empty());
        builder.field("len", &self.len());
        builder.field("capacity", &self.capacity());
        builder.finish()
    }
}
/// Iterator returned by [`Receiver::drain`]; each `next` call is a `try_recv`.
///
/// It keeps the receiver mutably borrowed for as long as it lives.
pub struct Drain<'a, T, const N: usize> {
    receiver: &'a mut Receiver<T, N>,
}
impl<'a, T, const N: usize> core::fmt::Debug for Drain<'a, T, N> {
    /// Reports the borrowed receiver's current length and emptiness.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rx = &*self.receiver;
        let mut builder = f.debug_struct("Drain");
        builder.field("len", &rx.len());
        builder.field("is_empty", &rx.is_empty());
        builder.finish()
    }
}
impl<'a, T, const N: usize> Iterator for Drain<'a, T, N> {
    type Item = T;

    /// Pops the next value; iteration ends on the first `try_recv` error
    /// (nothing buffered, whether or not the channel is closed).
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        match self.receiver.try_recv() {
            Ok(item) => Some(item),
            Err(_) => None,
        }
    }

    /// Uses the number of currently buffered values as both bounds.
    ///
    /// NOTE(review): the `Sender` can keep pushing while this iterator is alive and
    /// `next` will yield those values as well, so the upper bound can be exceeded;
    /// consider `(len, None)` if anything relies on it.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let buffered = self.receiver.len();
        (buffered, Some(buffered))
    }
}
/// Error returned by `Sender::send` and `Sender::send_slice`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SendError<T> {
    /// The channel is closed; the payload that could not be delivered is handed back.
    Closed(T),
}
/// Error returned by `Receiver::try_recv`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TryRecvError {
    /// Nothing is buffered right now, but the channel is still open.
    Empty,
    /// Nothing is buffered and the channel has been closed.
    Closed,
}
/// Error returned by `Sender::try_send`; both variants hand the rejected value back.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrySendError<T> {
    /// Every slot in the buffer is occupied.
    Full(T),
    /// The channel has been closed.
    Closed(T),
}
impl<T, const N: usize> Sender<T, N> {
    // SAFETY (for every `unsafe` block in this impl): each block only dereferences the
    // pointer handed out by `inner.producer`, which is reached solely through `Sender`
    // methods — the `Receiver` only ever touches `inner.consumer`.
    // NOTE(review): these methods take `&self`; if one `&Sender` were shared between
    // threads, two calls could overlap on the producer. Confirm callers never do that.

    /// Sends `value`, waiting for a free slot while the buffer is full.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Closed`] with the value if the channel is closed, either
    /// before the push is attempted or after being woken while waiting for space.
    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
        let mut pending = value;
        loop {
            pending = match self.try_send(pending) {
                Ok(()) => return Ok(()),
                Err(TrySendError::Closed(rejected)) => return Err(SendError::Closed(rejected)),
                Err(TrySendError::Full(rejected)) => rejected,
            };
            // Sleep until the receiver frees a slot (or goes away), then re-check closure.
            self.inner.send_notify.notified().await;
            if self.inner.closed.load(Ordering::Acquire) {
                return Err(SendError::Closed(pending));
            }
        }
    }

    /// Attempts to push `value` without waiting; on success the receiver is signalled.
    ///
    /// # Errors
    ///
    /// * [`TrySendError::Closed`] if the channel is already closed.
    /// * [`TrySendError::Full`] if no slot is free.
    ///
    /// Both variants return `value` to the caller.
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        if self.inner.closed.load(Ordering::Acquire) {
            return Err(TrySendError::Closed(value));
        }
        self.inner.producer.with_mut(|producer| {
            let pushed = unsafe { (*producer).push(value) };
            match pushed {
                Ok(()) => {
                    self.inner.recv_notify.notify_one();
                    Ok(())
                }
                Err(PushError::Full(rejected)) => Err(TrySendError::Full(rejected)),
            }
        })
    }

    /// Returns `true` once the shared `closed` flag is set (by dropping either half).
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.inner.closed.load(Ordering::Acquire)
    }

    /// Total number of slots, as reported by the producer half.
    #[inline]
    pub fn capacity(&self) -> usize {
        let cell = &self.inner.producer;
        cell.with(|p| unsafe { (*p).capacity() })
    }

    /// Number of values currently buffered, as seen from the producer half.
    #[inline]
    pub fn len(&self) -> usize {
        let cell = &self.inner.producer;
        cell.with(|p| unsafe { (*p).slots() })
    }

    /// Returns `true` when the producer half sees no buffered values.
    #[inline]
    pub fn is_empty(&self) -> bool {
        let cell = &self.inner.producer;
        cell.with(|p| unsafe { (*p).is_empty() })
    }

    /// Number of slots that can currently accept a value.
    #[inline]
    pub fn free_slots(&self) -> usize {
        let cell = &self.inner.producer;
        cell.with(|p| unsafe { (*p).free_slots() })
    }

    /// Returns `true` when the producer half sees no free slot.
    #[inline]
    pub fn is_full(&self) -> bool {
        let cell = &self.inner.producer;
        cell.with(|p| unsafe { (*p).is_full() })
    }
}
impl<T: Copy, const N: usize> Sender<T, N> {
    /// Copies as many leading elements of `values` into the channel as currently fit,
    /// without waiting.
    ///
    /// Returns how many elements were pushed: `0` if the channel is closed or full.
    /// The receiver is signalled only when at least one element went in.
    pub fn try_send_slice(&self, values: &[T]) -> usize {
        if self.inner.closed.load(Ordering::Acquire) {
            return 0;
        }
        self.inner.producer.with_mut(|producer| {
            // SAFETY: the producer is only reached through `Sender` methods.
            let count = unsafe { (*producer).push_slice(values) };
            if count != 0 {
                self.inner.recv_notify.notify_one();
            }
            count
        })
    }

    /// Sends every element of `values`, waiting for space whenever the buffer fills.
    ///
    /// Returns `Ok(values.len())` once everything has been pushed.
    ///
    /// # Errors
    ///
    /// Returns `SendError::Closed(())` if the channel closes before every element is
    /// sent. Elements pushed before that point stay in the channel, but their count is
    /// not reported.
    pub async fn send_slice(&self, values: &[T]) -> Result<usize, SendError<()>> {
        let mut sent = 0;
        while sent < values.len() {
            if self.inner.closed.load(Ordering::Acquire) {
                return Err(SendError::Closed(()));
            }
            sent += self.try_send_slice(&values[sent..]);
            if sent >= values.len() {
                break;
            }
            // Wait for the receiver to free space (or drop), then re-check closure.
            self.inner.send_notify.notified().await;
            if self.inner.closed.load(Ordering::Acquire) {
                return Err(SendError::Closed(()));
            }
        }
        Ok(sent)
    }
}
impl<T, const N: usize> Receiver<T, N> {
    // SAFETY (for every `unsafe` block in this impl): each block only dereferences the
    // pointer handed out by `inner.consumer`, which is reached solely through `Receiver`
    // methods — the `Sender` only ever touches `inner.producer`.
    // NOTE(review): most methods below take `&self`; if one `&Receiver` were shared
    // between threads, two calls could overlap on the consumer. Confirm callers never
    // do that (or move these methods to `&mut self`).

    /// Receives the next value, waiting while the channel is empty.
    ///
    /// Returns `None` once the channel is closed and nothing is left in the buffer.
    pub async fn recv(&self) -> Option<T> {
        loop {
            match self.try_recv() {
                Ok(value) => return Some(value),
                Err(TryRecvError::Closed) => return None,
                Err(TryRecvError::Empty) => {
                    // `closed` may have been set after `try_recv` checked it (the sender
                    // pushed and then dropped), so look at the buffer once more before
                    // reporting the end of the stream.
                    if self.inner.closed.load(Ordering::Acquire) {
                        if let Ok(value) = self.try_recv() {
                            return Some(value);
                        }
                        return None;
                    }
                    self.inner.recv_notify.notified().await;
                }
            }
        }
    }

    /// Attempts to pop the next value without waiting.
    ///
    /// A successful pop signals `send_notify`, because a slot has been freed.
    ///
    /// # Errors
    ///
    /// * [`TryRecvError::Empty`] if nothing is buffered and the channel is still open.
    /// * [`TryRecvError::Closed`] if nothing is buffered and the channel is closed.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        self.inner.consumer.with_mut(|consumer| unsafe {
            match (*consumer).pop() {
                Ok(value) => {
                    self.inner.send_notify.notify_one();
                    Ok(value)
                }
                Err(PopError::Empty) => {
                    if self.inner.closed.load(Ordering::Acquire) {
                        // The sender may have pushed between the first `pop` and the
                        // `closed` load; retry once so that value is still delivered.
                        match (*consumer).pop() {
                            Ok(value) => {
                                self.inner.send_notify.notify_one();
                                Ok(value)
                            }
                            Err(PopError::Empty) => Err(TryRecvError::Closed),
                        }
                    } else {
                        Err(TryRecvError::Empty)
                    }
                }
            }
        })
    }

    /// Returns `true` when the consumer half sees no buffered values.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.inner
            .consumer
            .with(|consumer| unsafe { (*consumer).is_empty() })
    }

    /// Number of values currently buffered, as seen from the consumer half.
    #[inline]
    pub fn len(&self) -> usize {
        self.inner
            .consumer
            .with(|consumer| unsafe { (*consumer).slots() })
    }

    /// Total number of slots in the underlying buffer.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.inner
            .consumer
            .with(|consumer| unsafe { (*consumer).buffer().capacity() })
    }

    /// Returns a reference to the oldest buffered value without removing it, or `None`
    /// when nothing is buffered.
    ///
    /// NOTE(review): the returned reference is tied to `&self`, but `try_recv`/`recv` also
    /// take `&self`. Safe code can therefore pop (and drop) the value while still holding
    /// this reference, after which the slot may be reused by the sender. Confirm whether
    /// `peek` should take `&mut self`.
    #[inline]
    pub fn peek(&self) -> Option<&T> {
        // Dereferencing the raw pointer yields a borrow whose lifetime is picked by
        // inference (here: that of `&self`). Unlike the previous inferred-type
        // `transmute`, the element type is still checked by the compiler.
        self.inner
            .consumer
            .with(|consumer| unsafe { (*consumer).peek() })
    }

    /// Drops every buffered value, then signals `send_notify` so a waiting sender can
    /// use the freed slots.
    pub fn clear(&mut self) {
        self.inner
            .consumer
            .with_mut(|consumer| unsafe { (*consumer).clear() });
        self.inner.send_notify.notify_one();
    }

    /// Returns an iterator that pops values with [`Receiver::try_recv`] until it reports
    /// an error (empty, or closed and empty).
    #[inline]
    pub fn drain(&mut self) -> Drain<'_, T, N> {
        Drain { receiver: self }
    }
}
impl<T: Copy, const N: usize> Receiver<T, N> {
    /// Copies as many buffered values as fit into the front of `dest`, without waiting.
    ///
    /// Returns the number of values written. `send_notify` is signalled only when at
    /// least one value was taken.
    pub fn try_recv_slice(&mut self, dest: &mut [T]) -> usize {
        // SAFETY: the consumer is only reached through `Receiver` methods, and
        // `&mut self` rules out an overlapping call on this receiver.
        self.inner.consumer.with_mut(|consumer| unsafe {
            let received = (*consumer).pop_slice(dest);
            if received > 0 {
                self.inner.send_notify.notify_one();
            }
            received
        })
    }

    /// Fills `dest` with received values, waiting while the channel is empty.
    ///
    /// Returns `dest.len()` when the slice was filled. A smaller count means the channel
    /// closed first. In that case every value still buffered at close time has already
    /// been copied out, up to the space left in `dest`.
    pub async fn recv_slice(&mut self, dest: &mut [T]) -> usize {
        let mut total_received = 0;
        while total_received < dest.len() {
            total_received += self.try_recv_slice(&mut dest[total_received..]);
            if total_received >= dest.len() {
                break;
            }
            if self.inner.closed.load(Ordering::Acquire) {
                // The sender may have pushed more values after the pop above and before
                // dropping. Collect everything that is left, so a short count really
                // means "nothing more to read" (mirrors the final re-check in `recv`).
                while total_received < dest.len() {
                    let received = self.try_recv_slice(&mut dest[total_received..]);
                    if received == 0 {
                        break;
                    }
                    total_received += received;
                }
                return total_received;
            }
            self.inner.recv_notify.notified().await;
        }
        total_received
    }
}
impl<T, const N: usize> Drop for Receiver<T, N> {
fn drop(&mut self) {
self.inner.closed.store(true, Ordering::Release);
self.inner.send_notify.notify_one();
}
}
impl<T, const N: usize> Drop for Sender<T, N> {
fn drop(&mut self) {
self.inner.closed.store(true, Ordering::Release);
self.inner.recv_notify.notify_one();
}
}
// Tokio-driven tests for the SPSC channel; compiled only when the `loom` feature is off.
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;

    // Values come out in the order they were sent.
    #[tokio::test]
    async fn test_basic_send_recv() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(4).unwrap());
        tx.send(1).await.unwrap();
        tx.send(2).await.unwrap();
        tx.send(3).await.unwrap();
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, Some(2));
        assert_eq!(rx.recv().await, Some(3));
    }

    // Non-blocking round trip; an empty but open channel reports `TryRecvError::Empty`.
    #[tokio::test]
    async fn test_try_send_recv() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(4).unwrap());
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
    }

    // Buffered values are still delivered after the sender is dropped; then `recv`
    // yields `None`.
    #[tokio::test]
    async fn test_channel_closed_on_sender_drop() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(4).unwrap());
        tx.send(1).await.unwrap();
        drop(tx);
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, None);
    }

    // Once the receiver is gone, `send` fails and hands the value back.
    #[tokio::test]
    async fn test_channel_closed_on_receiver_drop() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(4).unwrap());
        drop(rx);
        assert!(matches!(tx.send(1).await, Err(SendError::Closed(1))));
    }

    // Sender and receiver run as separate tasks; the receiver loop ends when the sender
    // task finishes and drops `tx`.
    #[tokio::test]
    async fn test_cross_task_communication() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(4).unwrap());
        let sender_handle = tokio::spawn(async move {
            for i in 0..10 {
                tx.send(i).await.unwrap();
            }
        });
        let receiver_handle = tokio::spawn(async move {
            let mut sum = 0;
            while let Some(value) = rx.recv().await {
                sum += value;
            }
            sum
        });
        sender_handle.await.unwrap();
        let sum = receiver_handle.await.unwrap();
        // 0 + 1 + ... + 9
        assert_eq!(sum, 45);
    }

    // With all 4 slots taken, `try_send` reports `Full`; the spawned `send` calls wait
    // until `recv` frees space.
    #[tokio::test]
    async fn test_backpressure() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(4).unwrap());
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        tx.try_send(3).unwrap();
        tx.try_send(4).unwrap();
        assert!(matches!(tx.try_send(5), Err(TrySendError::Full(5))));
        let send_handle = tokio::spawn(async move {
            tx.send(5).await.unwrap();
            tx.send(6).await.unwrap();
        });
        // Give the spawned task a chance to block on the full buffer.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, Some(2));
        assert_eq!(rx.recv().await, Some(3));
        assert_eq!(rx.recv().await, Some(4));
        assert_eq!(rx.recv().await, Some(5));
        assert_eq!(rx.recv().await, Some(6));
        send_handle.await.unwrap();
    }

    // Receiver-side capacity/len/is_empty queries.
    #[tokio::test]
    async fn test_capacity_and_len() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(8).unwrap());
        assert_eq!(rx.capacity(), 8);
        assert_eq!(rx.len(), 0);
        assert!(rx.is_empty());
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert_eq!(rx.len(), 2);
        assert!(!rx.is_empty());
    }

    // Sender-side capacity/len/free_slots/is_full queries across empty, partial and
    // full states, and after one value is received.
    #[tokio::test]
    async fn test_sender_capacity_queries() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(8).unwrap());
        assert_eq!(tx.capacity(), 8);
        assert_eq!(tx.len(), 0);
        assert_eq!(tx.free_slots(), 8);
        assert!(!tx.is_full());
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        tx.try_send(3).unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert_eq!(tx.len(), 3);
        assert_eq!(tx.free_slots(), 5);
        assert!(!tx.is_full());
        tx.try_send(4).unwrap();
        tx.try_send(5).unwrap();
        tx.try_send(6).unwrap();
        tx.try_send(7).unwrap();
        tx.try_send(8).unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert_eq!(tx.len(), 8);
        assert_eq!(tx.free_slots(), 0);
        assert!(tx.is_full());
        rx.recv().await;
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert_eq!(tx.len(), 7);
        assert_eq!(tx.free_slots(), 1);
        assert!(!tx.is_full());
    }

    // A slice that fits entirely is pushed in one call, in order.
    #[tokio::test]
    async fn test_try_send_slice() {
        let (tx, rx) = channel::<u32, 32>(NonZeroUsize::new(16).unwrap());
        let data = [1, 2, 3, 4, 5];
        let sent = tx.try_send_slice(&data);
        assert_eq!(sent, 5);
        assert_eq!(rx.len(), 5);
        for i in 0..5 {
            assert_eq!(rx.recv().await.unwrap(), data[i]);
        }
    }

    // With 3 slots left, only the first 3 elements of the second slice are pushed.
    #[tokio::test]
    async fn test_try_send_slice_partial() {
        let (tx, rx) = channel::<u32, 32>(NonZeroUsize::new(8).unwrap());
        let initial = [1, 2, 3, 4, 5];
        tx.try_send_slice(&initial);
        let more = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let sent = tx.try_send_slice(&more);
        assert_eq!(sent, 3);
        assert_eq!(rx.len(), 8);
        assert!(tx.is_full());
        for i in 1..=8 {
            assert_eq!(rx.recv().await.unwrap(), i);
        }
    }

    // `send_slice` reports the full length when everything fits.
    #[tokio::test]
    async fn test_send_slice() {
        let (tx, rx) = channel::<u32, 32>(NonZeroUsize::new(16).unwrap());
        let data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let result = tx.send_slice(&data).await;
        assert_eq!(result.unwrap(), 10);
        assert_eq!(rx.len(), 10);
        for i in 0..10 {
            assert_eq!(rx.recv().await.unwrap(), data[i]);
        }
    }

    // 8 values through a 4-slot channel: `send_slice` must wait for the receiver.
    #[tokio::test]
    async fn test_send_slice_with_backpressure() {
        let (tx, rx) = channel::<u32, 32>(NonZeroUsize::new(4).unwrap());
        let data = [1, 2, 3, 4, 5, 6, 7, 8];
        let send_handle = tokio::spawn(async move { tx.send_slice(&data).await.unwrap() });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        for i in 1..=4 {
            assert_eq!(rx.recv().await.unwrap(), i);
        }
        let sent = send_handle.await.unwrap();
        assert_eq!(sent, 8);
        for i in 5..=8 {
            assert_eq!(rx.recv().await.unwrap(), i);
        }
    }

    // `peek` shows the oldest value without consuming it.
    #[tokio::test]
    async fn test_peek() {
        let (tx, rx) = channel::<i32, 32>(NonZeroUsize::new(8).unwrap());
        assert!(rx.peek().is_none());
        tx.try_send(42).unwrap();
        tx.try_send(100).unwrap();
        tx.try_send(200).unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert_eq!(rx.peek(), Some(&42));
        // Peeking again yields the same value and leaves the length unchanged.
        assert_eq!(rx.peek(), Some(&42));
        assert_eq!(rx.len(), 3);
    }

    // After each `recv`, `peek` moves on to the next value (non-`Copy` payload).
    #[tokio::test]
    async fn test_peek_after_recv() {
        let (tx, rx) = channel::<String, 32>(NonZeroUsize::new(8).unwrap());
        tx.try_send("first".to_string()).unwrap();
        tx.try_send("second".to_string()).unwrap();
        tx.try_send("third".to_string()).unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert_eq!(rx.peek(), Some(&"first".to_string()));
        rx.recv().await.unwrap();
        assert_eq!(rx.peek(), Some(&"second".to_string()));
        rx.recv().await.unwrap();
        assert_eq!(rx.peek(), Some(&"third".to_string()));
        rx.recv().await.unwrap();
        assert!(rx.peek().is_none());
    }

    // `clear` empties the buffer.
    #[tokio::test]
    async fn test_clear() {
        let (tx, mut rx) = channel::<i32, 32>(NonZeroUsize::new(16).unwrap());
        for i in 0..10 {
            tx.try_send(i).unwrap();
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert_eq!(rx.len(), 10);
        rx.clear();
        assert_eq!(rx.len(), 0);
        assert!(rx.is_empty());
    }

    // `clear` drops every buffered value exactly once.
    #[tokio::test]
    async fn test_clear_with_drop() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
        // Counts how many instances have been dropped.
        #[derive(Debug)]
        struct DropCounter {
            counter: Arc<AtomicUsize>,
        }
        impl Drop for DropCounter {
            fn drop(&mut self) {
                self.counter.fetch_add(1, AtomicOrdering::SeqCst);
            }
        }
        let counter = Arc::new(AtomicUsize::new(0));
        {
            let (tx, mut rx) = channel::<DropCounter, 32>(NonZeroUsize::new(16).unwrap());
            for _ in 0..8 {
                tx.try_send(DropCounter {
                    counter: counter.clone(),
                })
                .unwrap();
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
            assert_eq!(counter.load(AtomicOrdering::SeqCst), 0);
            rx.clear();
            assert_eq!(counter.load(AtomicOrdering::SeqCst), 8);
        }
    }

    // `drain` yields every buffered value in order and leaves the channel empty.
    #[tokio::test]
    async fn test_drain() {
        let (tx, mut rx) = channel::<i32, 32>(NonZeroUsize::new(16).unwrap());
        for i in 0..10 {
            tx.try_send(i).unwrap();
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let collected: Vec<i32> = rx.drain().collect();
        assert_eq!(collected, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert!(rx.is_empty());
    }

    // `_tx` is kept alive so the drain ends on `Empty`, not on `Closed`.
    #[tokio::test]
    async fn test_drain_empty() {
        let (_tx, mut rx) = channel::<i32, 32>(NonZeroUsize::new(8).unwrap());
        let collected: Vec<i32> = rx.drain().collect();
        assert!(collected.is_empty());
    }

    // `size_hint` tracks the shrinking buffer as values are drained.
    #[tokio::test]
    async fn test_drain_size_hint() {
        let (tx, mut rx) = channel::<i32, 32>(NonZeroUsize::new(16).unwrap());
        for i in 0..5 {
            tx.try_send(i).unwrap();
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let mut drain = rx.drain();
        assert_eq!(drain.size_hint(), (5, Some(5)));
        drain.next();
        assert_eq!(drain.size_hint(), (4, Some(4)));
        drain.next();
        assert_eq!(drain.size_hint(), (3, Some(3)));
    }

    // `try_recv_slice` fills the destination and leaves the rest buffered.
    #[tokio::test]
    async fn test_try_recv_slice() {
        let (tx, mut rx) = channel::<u32, 32>(NonZeroUsize::new(16).unwrap());
        for i in 0..10 {
            tx.try_send(i).unwrap();
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let mut dest = [0u32; 5];
        let received = rx.try_recv_slice(&mut dest);
        assert_eq!(received, 5);
        assert_eq!(dest, [0, 1, 2, 3, 4]);
        assert_eq!(rx.len(), 5);
    }

    // A destination larger than the buffer receives only what is available.
    #[tokio::test]
    async fn test_try_recv_slice_partial() {
        let (tx, mut rx) = channel::<u32, 32>(NonZeroUsize::new(16).unwrap());
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        tx.try_send(3).unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let mut dest = [0u32; 10];
        let received = rx.try_recv_slice(&mut dest);
        assert_eq!(received, 3);
        assert_eq!(&dest[0..3], &[1, 2, 3]);
        assert!(rx.is_empty());
    }

    // `recv_slice` returns immediately when enough values are already buffered.
    #[tokio::test]
    async fn test_recv_slice() {
        let (tx, mut rx) = channel::<u32, 32>(NonZeroUsize::new(16).unwrap());
        for i in 1..=10 {
            tx.try_send(i).unwrap();
        }
        let mut dest = [0u32; 10];
        let received = rx.recv_slice(&mut dest).await;
        assert_eq!(received, 10);
        assert_eq!(dest, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert!(rx.is_empty());
    }

    // `recv_slice` waits across several sends until the destination is full.
    #[tokio::test]
    async fn test_recv_slice_with_wait() {
        let (tx, mut rx) = channel::<u32, 32>(NonZeroUsize::new(4).unwrap());
        let recv_handle = tokio::spawn(async move {
            let mut dest = [0u32; 8];
            let received = rx.recv_slice(&mut dest).await;
            (received, dest)
        });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        for i in 1..=8 {
            tx.send(i).await.unwrap();
            tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        }
        let (received, dest) = recv_handle.await.unwrap();
        assert_eq!(received, 8);
        assert_eq!(dest, [1, 2, 3, 4, 5, 6, 7, 8]);
    }

    // On a closed channel `recv_slice` returns the short count instead of waiting.
    #[tokio::test]
    async fn test_recv_slice_channel_closed() {
        let (tx, mut rx) = channel::<u32, 32>(NonZeroUsize::new(8).unwrap());
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        tx.try_send(3).unwrap();
        drop(tx);
        let mut dest = [0u32; 10];
        let received = rx.recv_slice(&mut dest).await;
        assert_eq!(received, 3);
        assert_eq!(&dest[0..3], &[1, 2, 3]);
    }

    // Sender and receiver queries stay consistent across slice send, peek, slice
    // receive and drain.
    #[tokio::test]
    async fn test_combined_new_apis() {
        let (tx, mut rx) = channel::<u32, 32>(NonZeroUsize::new(16).unwrap());
        let data = [1, 2, 3, 4, 5];
        tx.try_send_slice(&data);
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        assert_eq!(tx.len(), 5);
        assert_eq!(rx.len(), 5);
        assert_eq!(rx.capacity(), 16);
        assert_eq!(rx.peek(), Some(&1));
        let mut dest = [0u32; 3];
        rx.try_recv_slice(&mut dest);
        assert_eq!(dest, [1, 2, 3]);
        assert_eq!(rx.len(), 2);
        assert_eq!(tx.free_slots(), 14);
        let remaining: Vec<u32> = rx.drain().collect();
        assert_eq!(remaining, vec![4, 5]);
        assert!(rx.is_empty());
        assert!(!tx.is_full());
    }
}