use std::{
sync::Arc,
sync::atomic::{AtomicPtr, AtomicUsize, Ordering::*},
thread,
time::{Duration, Instant},
};
use crate::{
error::{RecvError, SendError, TryRecvError},
internals::LockFreeBoundedRing,
waiter::{
RecvWaiter, RecvWaiterGuard, RecvWaiterList, SelectWaiter, UNSELECTED,
abort_select_waiters, new_recv_waiter_list, push_select_waiter, wake_all_recv_waiters,
wake_one_recv_waiter, wake_select_all, wake_select_one,
},
};
/// Shared state behind every `Sender` / `Receiver` handle of a single channel.
pub(crate) struct Chan<T> {
// Buffer holding the in-flight messages; pushed by senders, popped by receivers.
ring: LockFreeBoundedRing<T>,
// Waiters registered by blocking `recv` calls. `send` wakes one of them and
// dropping the last `Sender` wakes all of them.
recv_waiters: RecvWaiterList,
// Receive-side select registrations (`Receiver::register_select`). `send` wakes
// one of them and dropping the last `Sender` wakes all of them.
select_waiters: Arc<AtomicPtr<SelectWaiter>>,
// Number of live `Sender` handles as tracked in this file: bumped on clone,
// decremented on drop. Receivers treat zero as "disconnected".
pub(crate) sender_count: AtomicUsize,
// Number of live `Receiver` handles as tracked in this file: bumped on clone,
// decremented on drop.
pub(crate) receiver_count: AtomicUsize,
// Send-side select registrations (`SelectableSender::register_select`). A
// successful pop wakes one of them and dropping the last `Receiver` wakes all.
send_select_waiters: Arc<AtomicPtr<SelectWaiter>>,
}
/// Sending handle of a channel; a thin wrapper around the shared `Chan` state.
/// Cloning yields another handle to the same channel.
pub struct Sender<T>(pub(crate) Arc<Chan<T>>);
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.0.sender_count.fetch_add(1, Relaxed);
Sender(Arc::clone(&self.0))
}
}
impl<T> Drop for Sender<T> {
    /// Unregisters this handle from the live-sender counter. When it was the
    /// last sender, every blocked receiver and every receive-side select
    /// waiter is woken so they can observe the disconnect.
    fn drop(&mut self) {
        let chan = &self.0;
        let was_last_sender = chan.sender_count.fetch_sub(1, AcqRel) == 1;
        if !was_last_sender {
            return;
        }
        wake_all_recv_waiters(&chan.recv_waiters, UNSELECTED);
        wake_select_all(&chan.select_waiters);
    }
}
impl<T> Sender<T> {
pub fn send(&self, val: T) -> Result<(), SendError<T>> {
self.0.ring.try_push(val).map_err(SendError)?;
wake_one_recv_waiter(&self.0.recv_waiters, UNSELECTED);
wake_select_one(&self.0.select_waiters);
Ok(())
}
}
/// Receiving handle of a channel; a thin wrapper around the shared `Chan` state.
/// Cloning yields another handle that pops from the same buffer.
pub struct Receiver<T>(pub(crate) Arc<Chan<T>>);
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.0.receiver_count.fetch_add(1, Relaxed);
Receiver(Arc::clone(&self.0))
}
}
impl<T> Drop for Receiver<T> {
    /// Unregisters this handle from the live-receiver counter. When it was the
    /// last receiver, every send-side select waiter is woken so it can observe
    /// that nobody is left to receive.
    fn drop(&mut self) {
        let chan = &self.0;
        let was_last_receiver = chan.receiver_count.fetch_sub(1, AcqRel) == 1;
        if !was_last_receiver {
            return;
        }
        wake_select_all(&chan.send_select_waiters);
    }
}
impl<T> Receiver<T> {
    /// Pops the next buffered message, if any. On success one send-side
    /// select waiter is woken, because a slot in the ring just became free.
    fn pop_and_notify(&self) -> Option<T> {
        let msg = self.0.ring.try_pop()?;
        wake_select_one(&self.0.send_select_waiters);
        Some(msg)
    }

    /// Returns `true` once the live-sender counter has reached zero.
    fn all_senders_gone(&self) -> bool {
        self.0.sender_count.load(Acquire) == 0
    }

    /// Attempts to receive a message without blocking.
    ///
    /// # Errors
    ///
    /// * `TryRecvError::Empty` - nothing is buffered but senders are alive.
    /// * `TryRecvError::Disconnected` - nothing is buffered and every sender
    ///   has been dropped.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        if let Some(msg) = self.pop_and_notify() {
            return Ok(msg);
        }
        if !self.all_senders_gone() {
            return Err(TryRecvError::Empty);
        }
        // A sender may have pushed its final message and dropped between the
        // failed pop above and the counter load. Now that no sender can push
        // any more, one more pop decides for good whether the buffer is empty,
        // so a buffered message is never hidden behind `Disconnected`.
        self.pop_and_notify().ok_or(TryRecvError::Disconnected)
    }

    /// Receives a message, parking the current thread until one is available.
    ///
    /// # Errors
    ///
    /// Returns `RecvError::Disconnected` when the buffer is empty and every
    /// sender has been dropped.
    pub fn recv(&self) -> Result<T, RecvError> {
        loop {
            // Fast path: a message is already buffered.
            if let Some(msg) = self.pop_and_notify() {
                return Ok(msg);
            }
            if self.all_senders_gone() {
                // Same race as in `try_recv`: drain a message that was pushed
                // just before the last sender went away.
                return self.pop_and_notify().ok_or(RecvError::Disconnected);
            }

            // Register as a waiter *before* re-checking, so a send that lands
            // after the check below still finds this waiter to wake.
            // NOTE(review): the meaning of the `usize::MAX` case id and of the
            // marker is defined by the waiter module - confirm there.
            let marker = Arc::new(AtomicUsize::new(UNSELECTED));
            let waiter = RecvWaiter::new(usize::MAX, Arc::clone(&marker));
            // The guard lives until the end of this iteration (or the return).
            let _guard = RecvWaiterGuard::register(waiter, &self.0.recv_waiters);

            // Re-check: a send or disconnect may have happened while we were
            // registering, in which case nobody would wake us.
            if let Some(msg) = self.pop_and_notify() {
                return Ok(msg);
            }
            if self.all_senders_gone() {
                return self.pop_and_notify().ok_or(RecvError::Disconnected);
            }

            // `park` may return spuriously; the loop re-validates the state.
            thread::park();
        }
    }

    /// Reports whether a receive would complete right away: either a message
    /// is buffered or every sender has been dropped.
    pub(crate) fn is_ready(&self) -> bool {
        !self.0.ring.is_empty() || self.0.sender_count.load(Acquire) == 0
    }

    /// Registers a select waiter for `case_id` on the receive side of this
    /// channel; `selected` is handed to the waiter unchanged.
    pub(crate) fn register_select(&self, case_id: usize, selected: Arc<AtomicUsize>) {
        log_trace!(
            "bounded_mpmc::register_select: chan={:p}, case_id={}",
            Arc::as_ptr(&self.0),
            case_id
        );
        let ptr = SelectWaiter::alloc(case_id, selected);
        push_select_waiter(ptr, &self.0.select_waiters);
    }

    /// Aborts the receive-side select waiters associated with `selected`.
    pub(crate) fn abort_select(&self, selected: &Arc<AtomicUsize>) {
        log_trace!(
            "bounded_mpmc::abort_select: chan={:p}",
            Arc::as_ptr(&self.0)
        );
        abort_select_waiters(&self.0.select_waiters, selected);
    }

    /// Completes a selected receive; equivalent to calling [`Receiver::recv`].
    pub fn complete_recv(&self) -> Result<T, RecvError> {
        self.recv()
    }
}
impl<T> crate::SelectableReceiver for Receiver<T> {
type Output = T;
fn is_ready(&self) -> bool {
self.is_ready()
}
fn register_select(
&self,
case_id: usize,
selected: std::sync::Arc<std::sync::atomic::AtomicUsize>,
) {
self.register_select(case_id, selected)
}
fn abort_select(&self, selected: &std::sync::Arc<std::sync::atomic::AtomicUsize>) {
self.abort_select(selected)
}
fn complete(&self) -> Result<Self::Output, crate::RecvError> {
self.complete_recv()
}
}
impl<T: Send + 'static> crate::SelectableSender for Sender<T> {
type Input = T;
fn is_ready(&self) -> bool {
!self.0.ring.is_full() || self.0.receiver_count.load(Acquire) == 0
}
fn register_select(&self, case_id: usize, selected: Arc<AtomicUsize>) {
let ptr = SelectWaiter::alloc(case_id, selected);
push_select_waiter(ptr, &self.0.send_select_waiters);
}
fn abort_select(&self, selected: &Arc<AtomicUsize>) {
abort_select_waiters(&self.0.send_select_waiters, selected);
}
fn complete_send(&self, value: T) -> Result<(), crate::SendError<T>> {
self.send(value)
}
}
/// Creates a channel whose ring is built with the given `capacity` and
/// returns its sending and receiving halves.
///
/// Both handle counters start at one, matching the single `Sender` and single
/// `Receiver` handed back.
pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    let ring = LockFreeBoundedRing::new(capacity);
    let recv_waiters = new_recv_waiter_list();
    let select_waiters = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
    let sender_count = AtomicUsize::new(1);
    let receiver_count = AtomicUsize::new(1);
    let send_select_waiters = Arc::new(AtomicPtr::new(std::ptr::null_mut()));

    let shared = Arc::new(Chan {
        ring,
        recv_waiters,
        select_waiters,
        sender_count,
        receiver_count,
        send_select_waiters,
    });

    let tx = Sender(Arc::clone(&shared));
    let rx = Receiver(shared);
    (tx, rx)
}
pub fn after(duration: Duration) -> Receiver<Instant> {
let chan = Arc::new(Chan {
ring: LockFreeBoundedRing::new(1),
recv_waiters: new_recv_waiter_list(),
select_waiters: Arc::new(AtomicPtr::new(std::ptr::null_mut())),
sender_count: AtomicUsize::new(1),
receiver_count: AtomicUsize::new(1),
send_select_waiters: Arc::new(AtomicPtr::new(std::ptr::null_mut())),
});
let rx = Receiver(Arc::clone(&chan));
if let Err(error) = thread::Builder::new()
.name("selectables::mpmc_after".to_owned())
.spawn(move || {
thread::sleep(duration);
let _ = chan.ring.try_push(Instant::now());
wake_one_recv_waiter(&chan.recv_waiters, UNSELECTED);
wake_select_one(&chan.select_waiters);
})
{
eprintln!("failed to spawn 'mpmc_after' thread: {}", error);
}
rx
}
/// Returns a receiver that no `Sender` is ever created for.
///
/// The ring is built with capacity zero and the sender counter starts at one
/// and is never decremented here, so the receiver looks connected but never
/// has anything to deliver.
pub fn never<T>() -> Receiver<T> {
    let idle = Chan {
        ring: LockFreeBoundedRing::new(0),
        recv_waiters: new_recv_waiter_list(),
        select_waiters: Arc::new(AtomicPtr::new(std::ptr::null_mut())),
        sender_count: AtomicUsize::new(1),
        receiver_count: AtomicUsize::new(1),
        send_select_waiters: Arc::new(AtomicPtr::new(std::ptr::null_mut())),
    };
    Receiver(Arc::new(idle))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;
    use std::thread;
    use std::time::Duration;
    use crate::{Select, select};

    /// Two values sent on one handle come back through `recv`.
    #[test]
    fn test_basic_send_recv() {
        let (tx, rx) = channel(4);
        for value in 1..=2 {
            tx.send(value).unwrap();
        }
        for expected in 1..=2 {
            assert_eq!(rx.recv().unwrap(), expected);
        }
    }

    /// `try_recv` reports `Empty` on a fresh channel and `send` fails once the
    /// buffer is full; popping frees a slot again.
    #[test]
    fn test_try_recv_empty_and_full() {
        let (tx, rx) = channel(2);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        let overflow = tx.send(3);
        assert!(overflow.is_err(), "buffer full");

        assert_eq!(rx.try_recv(), Ok(1));
        tx.send(3).unwrap();
        for expected in 2..=3 {
            assert_eq!(rx.try_recv(), Ok(expected));
        }
    }

    /// A zero-capacity channel rejects every send.
    #[test]
    fn test_zero_capacity_always_full() {
        let (tx, _rx) = channel::<i32>(0);
        let outcome = tx.send(42);
        assert!(outcome.is_err());
    }

    /// `never()` is connected but never yields a value.
    #[test]
    fn test_never_channel_is_alive_but_empty() {
        let idle: Receiver<i32> = never();
        assert_eq!(idle.try_recv(), Err(TryRecvError::Empty));

        let mut sel = Select::new();
        sel.recv(idle.clone());
        assert!(sel.try_select().is_none());
    }

    /// `after()` delivers exactly one instant and then stays empty.
    #[test]
    fn test_after_fires_once() {
        let timer = after(Duration::from_millis(30));
        let fired_at = timer.recv().unwrap();
        assert!(fired_at.elapsed() < Duration::from_secs(5));
        assert_eq!(timer.try_recv(), Err(TryRecvError::Empty));
    }

    /// Dropping the only sender disconnects the channel once it is drained.
    #[test]
    fn test_sender_drop_disconnects() {
        let (tx, rx) = channel::<i32>(4);
        tx.send(1).unwrap();
        drop(tx);

        assert_eq!(rx.try_recv(), Ok(1));
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
        assert!(rx.recv().is_err());
    }

    /// The channel stays connected while any sender clone is alive.
    #[test]
    fn test_clone_sender_keeps_alive() {
        let (original, rx) = channel(4);
        let survivor = original.clone();
        drop(original);

        survivor.send(42).unwrap();
        drop(survivor);

        assert_eq!(rx.recv().unwrap(), 42);
        // Every sender is gone now, so the next receive must fail.
        assert!(rx.recv().is_err());
    }

    /// Messages from several sender clones all arrive.
    #[test]
    fn test_multiple_senders() {
        let (tx1, rx) = channel(16);
        let tx2 = tx1.clone();
        let tx3 = tx1.clone();
        tx1.send(1).unwrap();
        tx2.send(2).unwrap();
        tx3.send(3).unwrap();

        let mut got: Vec<i32> = Vec::with_capacity(3);
        for _ in 0..3 {
            got.push(rx.recv().unwrap());
        }
        got.sort();
        assert_eq!(got, vec![1, 2, 3]);
    }

    /// Receiver clones share one buffer.
    #[test]
    fn test_multiple_receivers() {
        let (tx, rx1) = channel(16);
        let rx2 = rx1.clone();
        tx.send(10).unwrap();
        tx.send(20).unwrap();

        // Try the first receiver and fall back to the second.
        let take = || rx1.try_recv().unwrap_or_else(|_| rx2.try_recv().unwrap());
        let first = take();
        let second = take();

        let mut vals = vec![first, second];
        vals.sort();
        assert_eq!(vals, vec![10, 20]);
    }

    /// A receiver blocked in `recv` is woken by a later send.
    #[test]
    fn test_blocking_recv_woken_by_send() {
        let (tx, rx) = channel(4);
        let blocked = thread::spawn(move || rx.recv().unwrap());
        // Give the receiver time to block.
        thread::sleep(Duration::from_millis(20));
        tx.send(42).unwrap();
        let got = blocked.join().unwrap();
        assert_eq!(got, 42);
    }

    /// A receiver blocked in `recv` is woken when the last sender drops.
    #[test]
    fn test_blocking_recv_woken_by_disconnect() {
        let (tx, rx) = channel::<i32>(4);
        let blocked = thread::spawn(move || rx.recv());
        // Give the receiver time to block.
        thread::sleep(Duration::from_millis(20));
        drop(tx);
        let outcome = blocked.join().unwrap();
        assert!(outcome.is_err());
    }

    /// `is_ready` tracks buffered data and disconnection.
    #[test]
    fn test_is_ready() {
        let (tx, rx) = channel(4);
        assert!(!rx.is_ready());

        tx.send(1).unwrap();
        assert!(rx.is_ready());

        rx.try_recv().unwrap();
        assert!(!rx.is_ready());

        drop(tx);
        // Disconnected counts as ready.
        assert!(rx.is_ready());
    }

    /// Several producers and consumers move every message exactly once.
    #[test]
    fn test_concurrent_producers_consumers() {
        const PRODUCERS: usize = 4;
        const PER_PRODUCER: usize = 256;
        const TOTAL: usize = PRODUCERS * PER_PRODUCER;

        let (tx, rx) = channel(32);
        let received = Arc::new(AtomicUsize::new(0));
        let mut workers = Vec::with_capacity(PRODUCERS * 2);

        for _ in 0..PRODUCERS {
            let tx = tx.clone();
            workers.push(thread::spawn(move || {
                for i in 0..PER_PRODUCER {
                    // Retry until the ring accepts the value.
                    while tx.send(i).is_err() {
                        thread::yield_now();
                    }
                }
            }));
        }

        for _ in 0..PRODUCERS {
            let rx = rx.clone();
            let received = Arc::clone(&received);
            workers.push(thread::spawn(move || {
                loop {
                    match rx.try_recv() {
                        Ok(_) => {
                            received.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(TryRecvError::Empty) => thread::yield_now(),
                        Err(TryRecvError::Disconnected) => return,
                        Err(TryRecvError::Lagged { .. }) => unreachable!("bounded_mpmc cannot lag"),
                    }
                }
            }));
        }

        // Drop the original sender so only the producer clones keep the
        // channel connected.
        drop(tx);
        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(received.load(Ordering::Relaxed), TOTAL);
    }

    /// Values still buffered when the channel goes away are dropped.
    #[test]
    fn test_drop_values_in_buffer() {
        use std::sync::atomic::AtomicBool;

        static DROPPED: AtomicBool = AtomicBool::new(false);

        #[derive(Debug)]
        struct Guard;

        impl Drop for Guard {
            fn drop(&mut self) {
                DROPPED.store(true, Ordering::Relaxed);
            }
        }

        let (tx, rx) = channel(4);
        tx.send(Guard).unwrap();
        drop(tx);
        drop(rx);

        let was_dropped = DROPPED.load(Ordering::Relaxed);
        assert!(was_dropped);
    }

    /// Messages come out in the order they went in.
    #[test]
    fn test_fifo_ordering() {
        let (tx, rx) = channel(16);
        (0..8u32).for_each(|i| tx.send(i).unwrap());
        (0..8u32).for_each(|i| assert_eq!(rx.recv().unwrap(), i));
    }

    /// Buffered messages are delivered before the disconnect is reported.
    #[test]
    fn test_recv_drains_before_disconnect() {
        let (tx, rx) = channel(4);
        for value in 1..=2 {
            tx.send(value).unwrap();
        }
        drop(tx);

        for expected in 1..=2 {
            assert_eq!(rx.recv().unwrap(), expected);
        }
        assert!(rx.recv().is_err());
    }

    /// A select-based send on a full channel completes once a receiver pops.
    #[test]
    fn test_send_select_wakes_when_receiver_pops() {
        let (tx, rx) = channel::<i32>(1);
        tx.send(1).unwrap();

        let tx2 = tx.clone();
        let rx2 = rx.clone();
        let selecting = thread::spawn(move || {
            select! {
                send(tx2, 99) -> res => assert!(res.is_ok()),
            }
        });

        // Give the select time to register before freeing the slot.
        thread::sleep(Duration::from_millis(20));
        assert_eq!(rx.recv().unwrap(), 1);
        selecting.join().unwrap();
        assert_eq!(rx2.recv().unwrap(), 99);
    }

    /// An empty, disconnected receiver is ready and reports `Disconnected`.
    #[test]
    fn test_select_ready_when_disconnected_and_empty() {
        let (tx, rx) = channel::<i32>(4);
        drop(tx);
        assert!(rx.is_ready());
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
    }
}