use core::num::NonZeroUsize;
use crate::spsc::parking_shards::ParkingShardsPtr;
mod receiver;
mod sender;
pub use receiver::Receiver;
pub use sender::Sender;
/// Creates a sharded channel, returning its [`Sender`] and a first [`Receiver`].
///
/// * `max_shards` — number of shards; must be a power of two.
/// * `capacity_per_shard` — number of values each shard can buffer.
///
/// The returned receiver is bound to one shard. Receivers for the remaining
/// shards are obtained with `Receiver::clone`, which returns `None` once every
/// shard has a receiver; dropping a receiver frees its shard for reuse.
///
/// # Panics
///
/// Panics if `max_shards` is not a power of two.
pub fn channel<T>(
    max_shards: NonZeroUsize,
    capacity_per_shard: NonZeroUsize,
) -> (Sender<T>, Receiver<T>) {
    // Checked in release builds as well: previously this was a `debug_assert!`,
    // so release builds silently accepted a shard count the channel declares
    // unsupported. NOTE(review): presumably shard selection relies on this
    // (e.g. index masking) — confirm in `ParkingShardsPtr`. The check runs once
    // per channel construction, so it costs nothing on the hot path.
    assert!(
        max_shards.is_power_of_two(),
        "number of shards must be a power of 2"
    );
    let shards = ParkingShardsPtr::new(max_shards, capacity_per_shard);
    (
        Sender::new(shards.clone(), max_shards.get()),
        Receiver::new(shards, max_shards),
    )
}
#[cfg(all(test, not(feature = "loom")))]
mod test {
    use super::*;
    use crate::thread;
    use alloc_crate::vec::Vec;

    /// Smoke test: one receiver thread per shard, each taking `ITER` values,
    /// while the main thread sends `THREADS * ITER` values.
    ///
    /// Each shard buffers at most 16 values, far fewer than the 100 each
    /// receiver consumes, so this exercises `send` against full shards. It
    /// relies on `send` spreading values evenly across shards (round-robin,
    /// as pinned by `test_multi_receiver`); otherwise a receiver would wait
    /// forever.
    #[test]
    fn basic() {
        const THREADS: u32 = 4;
        const ITER: u32 = 100;
        let (mut tx, rx) = channel(
            NonZeroUsize::new(THREADS as usize).unwrap(),
            NonZeroUsize::new(16).unwrap(),
        );
        thread::scope(move |scope| {
            // Bind the other THREADS - 1 shards via clones; the original
            // receiver is moved into the last thread below.
            for _ in 0..THREADS - 1 {
                let mut rx = rx.clone().unwrap();
                scope.spawn(move || {
                    for _ in 0..ITER {
                        let _ = rx.recv();
                    }
                });
            }
            let mut rx = rx;
            scope.spawn(move || {
                for _ in 0..ITER {
                    let _ = rx.recv();
                }
            });
            for i in 0..THREADS * ITER {
                tx.send(i);
            }
        });
    }

    /// Dropping a receiver frees its shard: cloning a surviving receiver then
    /// yields a receiver for the vacated shard, and no further clone succeeds.
    #[test]
    fn receiver_clone_reuses_dropped_shard() {
        let (mut tx, mut rx0) =
            channel::<usize>(NonZeroUsize::new(2).unwrap(), NonZeroUsize::new(4).unwrap());
        let mut rx1 = rx0.clone().unwrap();
        tx.send(0);
        tx.send(1);
        assert_eq!(rx0.recv(), 0);
        assert_eq!(rx1.recv(), 1);
        drop(rx0);
        let mut rx2 = rx1.clone().unwrap();
        // Both shards are bound again (rx1 and rx2), so no shard is free.
        assert!(rx1.clone().is_none());
        // The third send lands on the shard rx0 used to own.
        tx.send(2);
        assert_eq!(rx2.try_recv(), Some(2));
    }

    /// Non-blocking operations on a single shard of capacity 4: `try_recv` on
    /// an empty shard yields `None`, `try_send` fails once the shard is full,
    /// and values come back in the order they were sent.
    #[test]
    fn test_try_ops() {
        let (mut tx, mut rx) =
            channel::<usize>(NonZeroUsize::new(1).unwrap(), NonZeroUsize::new(4).unwrap());
        assert_eq!(rx.try_recv(), None);
        for i in 0..4 {
            tx.try_send(i).unwrap();
        }
        assert!(tx.try_send(99).is_err());
        for i in 0..4 {
            assert_eq!(rx.try_recv(), Some(i));
        }
        assert_eq!(rx.try_recv(), None);
    }

    /// With one receiver thread blocked per shard, writing to shard 0 must
    /// wake the receiver bound to shard 0, not some other waiting receiver.
    #[test]
    fn shard_futex_wakes_receiver_for_written_shard() {
        use std::sync::mpsc::channel as std_channel;
        use std::sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        };
        use std::time::{Duration, Instant};
        const SHARDS: usize = 8;
        // Spawn orders for the receiver threads. Shard 0 is always fourth, so
        // its receiver is never the first or last one to start waiting.
        const ORDERS: [[usize; SHARDS]; 4] = [
            [1, 2, 3, 0, 4, 5, 6, 7],
            [7, 6, 5, 0, 4, 3, 2, 1],
            [2, 4, 6, 0, 1, 3, 5, 7],
            [7, 5, 3, 0, 6, 4, 2, 1],
        ];
        // Upper bound on waiting for every receiver thread to report ready.
        // Exceeding it means a receiver thread died before incrementing `ready`.
        const READY_TIMEOUT: Duration = Duration::from_secs(5);
        for attempt in 0..10 {
            let (mut tx, rx0) = channel::<usize>(
                NonZeroUsize::new(SHARDS).unwrap(),
                NonZeroUsize::new(1).unwrap(),
            );
            // One receiver per shard. The test assumes clones claim shards in
            // increasing order, so `receivers[i]` is bound to shard `i`.
            let mut receivers = Vec::new();
            receivers.push(rx0);
            for shard in 1..SHARDS {
                let rx = receivers[shard - 1].clone().unwrap();
                receivers.push(rx);
            }
            let mut receivers: Vec<_> = receivers.into_iter().map(Some).collect();
            let (done_tx, done_rx) = std_channel();
            let ready = Arc::new(AtomicUsize::new(0));
            let mut handles = Vec::new();
            for shard in ORDERS[attempt % ORDERS.len()] {
                let mut rx = receivers[shard].take().unwrap();
                let done_tx = done_tx.clone();
                let ready = ready.clone();
                handles.push(thread::spawn(move || {
                    assert_eq!(rx.try_recv(), None);
                    ready.fetch_add(1, Ordering::Release);
                    let value = rx.recv();
                    done_tx.send((shard, value)).unwrap();
                }));
                // Stagger spawns so the threads start waiting in the chosen order.
                std::thread::sleep(Duration::from_millis(5));
            }
            // Only the worker threads keep a completion sender from here on.
            drop(done_tx);
            // Wait until every receiver has passed its empty-shard check. A
            // receiver thread that panics before incrementing `ready` would
            // otherwise make this loop spin forever, so the wait is bounded.
            let ready_deadline = Instant::now() + READY_TIMEOUT;
            loop {
                let ready_count = ready.load(Ordering::Acquire);
                if ready_count == SHARDS {
                    break;
                }
                assert!(
                    Instant::now() < ready_deadline,
                    "attempt {attempt}: only {ready_count}/{SHARDS} receiver threads became ready; a receiver thread likely panicked"
                );
                std::thread::yield_now();
            }
            // `ready` is bumped just before `recv`, so give the threads time to
            // actually block inside it.
            std::thread::sleep(Duration::from_millis(100));
            tx.send(0);
            let first_woken = done_rx.recv_timeout(Duration::from_millis(50));
            let target_woke = first_woken == Ok((0, 0));
            // Feed the remaining shards (assuming round-robin placement after
            // shard 0) so every other receiver can finish.
            for shard in 1..SHARDS {
                tx.send(shard * 10);
            }
            let mut completed = [false; SHARDS];
            if let Ok((shard, _value)) = first_woken {
                completed[shard] = true;
            }
            let cleanup_started = Instant::now();
            while !completed.iter().all(|done| *done)
                && cleanup_started.elapsed() < Duration::from_secs(1)
            {
                while let Ok((shard, _value)) = done_rx.try_recv() {
                    completed[shard] = true;
                }
                std::thread::sleep(Duration::from_millis(1));
            }
            // Join only when every thread reported back. A stuck thread would
            // make `join` block forever, so it is left detached instead.
            if completed.iter().all(|done| *done) {
                for handle in handles {
                    handle.join().unwrap();
                }
            }
            assert!(
                target_woke,
                "attempt {attempt}: writing to shard 0 did not wake the receiver bound to shard 0; first completion: {first_woken:?}, completed: {completed:?}"
            );
        }
    }

    /// A receiver blocked in `recv` on an empty channel is woken by a later send.
    #[test]
    fn test_receiver_parks_and_wakes() {
        let (mut tx, mut rx) = channel::<usize>(
            NonZeroUsize::new(1).unwrap(),
            NonZeroUsize::new(16).unwrap(),
        );
        let h = thread::spawn(move || {
            let val = rx.recv();
            assert_eq!(val, 42);
        });
        // Give the receiver a chance to block before anything is sent.
        std::thread::sleep(std::time::Duration::from_millis(10));
        tx.send(42);
        h.join().unwrap();
    }

    /// Two shards: only one clone succeeds, and sends alternate between
    /// shards, so `rx` sees the even values and `rx2` the odd ones.
    #[test]
    fn test_multi_receiver() {
        let (mut tx, rx) = channel::<usize>(
            NonZeroUsize::new(2).unwrap(),
            NonZeroUsize::new(64).unwrap(),
        );
        let mut rx2 = rx.clone().unwrap();
        assert!(rx.clone().is_none());
        let mut rx = rx;
        for i in 0..10 {
            tx.send(i);
        }
        for i in 0..5 {
            assert_eq!(rx.recv(), i * 2);
        }
        for i in 0..5 {
            assert_eq!(rx2.recv(), i * 2 + 1);
        }
    }

    /// Values still buffered when both channel ends are dropped are each
    /// dropped exactly once.
    #[test]
    fn test_drop_remaining() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        struct DropCounter(Arc<AtomicUsize>);
        impl Drop for DropCounter {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }
        let dropped = Arc::new(AtomicUsize::new(0));
        {
            // `_rx` (not `_`) keeps the receiver alive until the end of scope.
            let (mut tx, _rx) = channel::<DropCounter>(
                NonZeroUsize::new(1).unwrap(),
                NonZeroUsize::new(8).unwrap(),
            );
            for _ in 0..5 {
                tx.send(DropCounter(dropped.clone()));
            }
        }
        assert_eq!(dropped.load(Ordering::SeqCst), 5);
    }
}