use std::{
sync::Arc,
sync::atomic::{AtomicPtr, AtomicUsize, Ordering::*},
thread,
};
use crossbeam_queue::SegQueue;
use crate::{
error::{RecvError, SendError, TryRecvError},
waiter::{
RecvWaiter, RecvWaiterGuard, RecvWaiterList, SelectWaiter, UNSELECTED,
abort_select_waiters, new_recv_waiter_list, push_select_waiter, wake_one_recv_waiter,
wake_select_all, wake_select_one,
},
};
/// Shared state behind every [`Sender`] / [`Receiver`] handle of one unbounded channel.
pub(crate) struct Chan<T> {
    /// Pending messages. Crossbeam's `SegQueue` has no capacity limit, so pushes never block.
    queue: SegQueue<T>,
    /// Receivers blocked in `Receiver::recv`; `send` wakes one of them per message.
    recv_waiters: RecvWaiterList,
    /// Select waiters registered via `register_select` (presumably the head of a list
    /// managed by the `waiter` module — confirm there).
    select_waiters: Arc<AtomicPtr<SelectWaiter>>,
    /// Number of live `Sender` handles; receivers treat zero as "disconnected".
    sender_count: AtomicUsize,
    /// Number of live `Receiver` handles (maintained by clone/drop, not read in this file).
    receiver_count: AtomicUsize,
}
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let chan = Arc::new(Chan {
queue: SegQueue::new(),
recv_waiters: new_recv_waiter_list(),
select_waiters: Arc::new(AtomicPtr::new(std::ptr::null_mut())),
sender_count: AtomicUsize::new(1),
receiver_count: AtomicUsize::new(1),
});
log_debug!("unbounded_mpmc::unbounded: chan={:p}", Arc::as_ptr(&chan));
(Sender(Arc::clone(&chan)), Receiver(chan))
}
/// Sending half of an unbounded MPMC channel created by [`channel`].
///
/// Cloning increments the channel's live-sender count. Dropping the last clone lets
/// receivers observe the channel as disconnected once its queue is empty.
pub struct Sender<T>(pub(crate) Arc<Chan<T>>);
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.0.sender_count.fetch_add(1, Relaxed);
Sender(Arc::clone(&self.0))
}
}
impl<T> Drop for Sender<T> {
    /// Releases this handle's share of the live-sender count.
    ///
    /// The handle that brings the count to zero disconnects the channel. It calls
    /// `wake_one_recv_waiter` once and `wake_select_all`, so a blocked receiver and the
    /// pending select waiters re-check the channel.
    ///
    /// NOTE(review): only one receiver waiter is woken here; other receivers parked in
    /// `recv` are not woken by this call — confirm they are woken elsewhere.
    fn drop(&mut self) {
        let chan = &self.0;
        let previous = chan.sender_count.fetch_sub(1, AcqRel);
        log_debug!(
            "unbounded_mpmc::sender_drop: chan={:p}, remaining_senders={}",
            Arc::as_ptr(chan),
            previous - 1
        );
        if previous != 1 {
            // Other senders are still alive; nothing to announce.
            return;
        }
        log_debug!(
            "unbounded_mpmc::sender_drop: chan={:p}, disconnecting",
            Arc::as_ptr(chan)
        );
        wake_one_recv_waiter(&chan.recv_waiters, UNSELECTED);
        wake_select_all(&chan.select_waiters);
    }
}
impl<T> Sender<T> {
    /// Enqueues `val`, then notifies one blocked receiver and one select waiter.
    ///
    /// The queue is unbounded, so this never blocks and, as written, always returns
    /// `Ok(())`.
    ///
    /// NOTE(review): `receiver_count` is not consulted, so sending with no live receivers
    /// still succeeds and the value stays queued.
    pub fn send(&self, val: T) -> Result<(), SendError<T>> {
        let chan = &*self.0;
        chan.queue.push(val);
        log_debug!(
            "unbounded_mpmc::send: chan={:p}, queue_len={}",
            Arc::as_ptr(&self.0),
            chan.queue.len()
        );
        wake_one_recv_waiter(&chan.recv_waiters, UNSELECTED);
        wake_select_one(&chan.select_waiters);
        Ok(())
    }
}
/// Receiving half of an unbounded MPMC channel created by [`channel`].
///
/// Clones share the same queue and compete for messages: each message is popped by
/// exactly one receiver.
pub struct Receiver<T>(pub(crate) Arc<Chan<T>>);
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.0.receiver_count.fetch_add(1, Relaxed);
Receiver(Arc::clone(&self.0))
}
}
impl<T> Drop for Receiver<T> {
    /// Gives back this handle's slot in the live-receiver count.
    ///
    /// Nothing in this file reads that count, so dropping a receiver has no other effect
    /// on the channel here.
    fn drop(&mut self) {
        let live_receivers = &self.0.receiver_count;
        live_receivers.fetch_sub(1, AcqRel);
    }
}
impl<T> Receiver<T> {
pub fn try_recv(&self) -> Result<T, TryRecvError> {
if let Some(v) = self.0.queue.pop() {
log_debug!(
"unbounded_mpmc::try_recv: chan={:p}, got value",
Arc::as_ptr(&self.0)
);
return Ok(v);
}
if self.0.sender_count.load(Acquire) == 0 {
log_debug!(
"unbounded_mpmc::try_recv: chan={:p}, Disconnected",
Arc::as_ptr(&self.0)
);
Err(TryRecvError::Disconnected)
} else {
log_debug!(
"unbounded_mpmc::try_recv: chan={:p}, Empty",
Arc::as_ptr(&self.0)
);
Err(TryRecvError::Empty)
}
}
pub fn recv(&self) -> Result<T, RecvError> {
let marker = Arc::new(AtomicUsize::new(UNSELECTED));
loop {
if let Some(v) = self.0.queue.pop() {
return Ok(v);
}
if self.0.sender_count.load(Acquire) == 0 {
return Err(RecvError::Disconnected);
}
if let Some(v) = self.0.queue.pop() {
return Ok(v);
}
if self.0.sender_count.load(Acquire) == 0 {
return Err(RecvError::Disconnected);
}
let waiter = RecvWaiter::new(usize::MAX, Arc::clone(&marker));
let _guard = RecvWaiterGuard::register(waiter, &self.0.recv_waiters);
if let Some(v) = self.0.queue.pop() {
return Ok(v);
}
if self.0.sender_count.load(Acquire) == 0 {
return Err(RecvError::Disconnected);
}
if marker.load(Acquire) != UNSELECTED {
if let Some(v) = self.0.queue.pop() {
return Ok(v);
}
return Err(RecvError::Disconnected);
}
thread::park();
}
}
pub(crate) fn is_ready(&self) -> bool {
!self.0.queue.is_empty() || self.0.sender_count.load(Acquire) == 0
}
pub(crate) fn register_select(&self, case_id: usize, selected: Arc<AtomicUsize>) {
log_trace!(
"unbounded_mpmc::register_select: chan={:p}, case_id={}",
Arc::as_ptr(&self.0),
case_id
);
let ptr = SelectWaiter::alloc(case_id, selected);
push_select_waiter(ptr, &self.0.select_waiters);
}
pub(crate) fn abort_select(&self, selected: &Arc<AtomicUsize>) {
log_trace!(
"unbounded_mpmc::abort_select: chan={:p}",
Arc::as_ptr(&self.0)
);
abort_select_waiters(&self.0.select_waiters, selected);
}
pub fn complete_recv(&self) -> Result<T, RecvError> {
self.recv()
}
}
impl<T> crate::SelectableReceiver for Receiver<T> {
type Output = T;
fn is_ready(&self) -> bool {
self.is_ready()
}
fn register_select(
&self,
case_id: usize,
selected: std::sync::Arc<std::sync::atomic::AtomicUsize>,
) {
self.register_select(case_id, selected)
}
fn abort_select(&self, selected: &std::sync::Arc<std::sync::atomic::AtomicUsize>) {
self.abort_select(selected)
}
fn complete(&self) -> Result<Self::Output, crate::RecvError> {
self.complete_recv()
}
}
impl<T: Send + 'static> crate::SelectableSender for Sender<T> {
type Input = T;
fn is_ready(&self) -> bool {
true
}
fn register_select(&self, _case_id: usize, _selected: Arc<AtomicUsize>) {}
fn abort_select(&self, _selected: &Arc<AtomicUsize>) {}
fn complete_send(&self, value: T) -> Result<(), crate::SendError<T>> {
self.send(value)
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread;
    use std::time::Duration;

    use crate::Select;

    use super::*;

    #[test]
    fn basic_send_recv() {
        let (tx, rx) = channel();
        tx.send(42u32).unwrap();
        assert_eq!(rx.try_recv(), Ok(42));
    }

    #[test]
    fn fifo_ordering() {
        let (tx, rx) = channel();
        for i in 0..5u32 {
            tx.send(i).unwrap();
        }
        for i in 0..5u32 {
            assert_eq!(rx.recv().unwrap(), i);
        }
    }

    #[test]
    fn empty_then_disconnected() {
        let (tx, rx) = channel::<i32>();
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
    }

    #[test]
    fn blocking_recv_wakes_on_send() {
        let (tx, rx) = channel::<u32>();
        let handle = thread::spawn(move || rx.recv().unwrap());
        // Give the receiver a chance to park before the send.
        thread::sleep(Duration::from_millis(20));
        tx.send(7).unwrap();
        assert_eq!(handle.join().unwrap(), 7);
    }

    #[test]
    fn blocking_recv_wakes_on_disconnect() {
        let (tx, rx) = channel::<i32>();
        let handle = thread::spawn(move || rx.recv());
        // Give the receiver a chance to park before the last sender is dropped.
        thread::sleep(Duration::from_millis(20));
        drop(tx);
        assert_eq!(handle.join().unwrap(), Err(RecvError::Disconnected));
    }

    #[test]
    fn multiple_senders() {
        let (tx, rx) = channel::<u32>();
        let mut handles = Vec::new();
        for i in 0..4u32 {
            let tx_clone = tx.clone();
            handles.push(thread::spawn(move || tx_clone.send(i).unwrap()));
        }
        drop(tx);
        for h in handles {
            h.join().unwrap();
        }
        let mut collected: Vec<u32> = (0..4).map(|_| rx.recv().unwrap()).collect();
        collected.sort();
        assert_eq!(collected, vec![0, 1, 2, 3]);
    }

    #[test]
    fn multiple_receivers_share_messages() {
        let (tx, rx) = channel::<u32>();
        let rx2 = rx.clone();
        let h1 = thread::spawn(move || rx.recv().unwrap());
        let h2 = thread::spawn(move || rx2.recv().unwrap());
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        let mut results = vec![h1.join().unwrap(), h2.join().unwrap()];
        results.sort();
        assert_eq!(results, vec![1, 2]);
    }

    #[test]
    fn select_recv_ready() {
        let (tx, rx) = channel::<i32>();
        tx.send(99).unwrap();
        let mut sel = Select::new();
        let idx = sel.recv(rx.clone());
        let op = sel.try_select().expect("should be ready");
        assert_eq!(op.index, idx);
        let val = rx.complete_recv().unwrap();
        assert_eq!(val, 99);
    }

    #[test]
    fn select_not_ready_when_empty() {
        let (_tx, rx) = channel::<i32>();
        let mut sel = Select::new();
        sel.recv(rx.clone());
        assert!(sel.try_select().is_none());
    }

    // The spawned thread selects on message 1. The main thread then steals message 1, and
    // only afterwards is message 2 sent, so `complete_recv` must deliver 2 rather than fail.
    #[test]
    fn complete_recv_waits_when_race_steals_message() {
        let (tx, rx) = channel();
        let rx_other = rx.clone();
        tx.send(1).unwrap();
        let started = Arc::new(AtomicBool::new(false));
        let consumed = Arc::new(AtomicBool::new(false));
        let start_flag = started.clone();
        let consumed_flag = consumed.clone();
        let tx_clone = tx.clone();
        let handle = thread::spawn(move || {
            let mut sel = Select::new();
            sel.recv(rx_other.clone());
            let _ = sel.select();
            start_flag.store(true, Ordering::SeqCst);
            while !consumed_flag.load(Ordering::SeqCst) {
                thread::yield_now();
            }
            tx_clone.send(2).unwrap();
            rx_other.complete_recv().unwrap()
        });
        while !started.load(Ordering::SeqCst) {
            thread::yield_now();
        }
        let value = rx.recv().unwrap();
        consumed.store(true, Ordering::SeqCst);
        assert_eq!(value, 1);
        let result = handle.join().unwrap();
        assert_eq!(result, 2);
    }

    // Receivers poll with `try_recv` until `Disconnected`; every message sent must be
    // counted exactly once across all receivers.
    #[test]
    fn stress_concurrent_senders_receivers() {
        const SENDERS: usize = 4;
        const RECEIVERS: usize = 4;
        const MSGS_PER_SENDER: u32 = 500;
        const TOTAL: u32 = SENDERS as u32 * MSGS_PER_SENDER;
        let (tx, rx) = channel::<u32>();
        let mut handles = Vec::new();
        for s in 0..SENDERS {
            let tx_clone = tx.clone();
            handles.push(thread::spawn(move || {
                for i in 0..MSGS_PER_SENDER {
                    tx_clone.send(s as u32 * MSGS_PER_SENDER + i).unwrap();
                }
            }));
        }
        let received = Arc::new(AtomicUsize::new(0));
        for _ in 0..RECEIVERS {
            let rx_clone = rx.clone();
            let counter = Arc::clone(&received);
            handles.push(thread::spawn(move || {
                loop {
                    match rx_clone.try_recv() {
                        Ok(_) => {
                            counter.fetch_add(1, Relaxed);
                        }
                        Err(TryRecvError::Empty) => thread::yield_now(),
                        Err(TryRecvError::Disconnected) => break,
                        Err(TryRecvError::Lagged { .. }) => {
                            unreachable!("unbounded_mpmc cannot lag")
                        }
                    }
                }
            }));
        }
        drop(tx);
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(received.load(Relaxed) as u32, TOTAL);
    }
}