use std::{
    cell::UnsafeCell,
    collections::VecDeque,
    mem::ManuallyDrop,
    ptr,
    sync::{
        Arc, Mutex, MutexGuard, PoisonError,
        atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering::*},
    },
    thread,
};
use crate::{
    error::{RecvError, SendError, TryRecvError},
    internals::UNSELECTED,
    waiter::{
        RecvWaiter, RecvWaiterGuard, RecvWaiterList, SelectWaiter, abort_select_waiters,
        drain_select_waiters, new_recv_waiter_list, push_select_waiter, wake_all_recv_waiters,
        wake_select_all, wake_select_one,
    },
};
struct SenderWaiter<T> {
value: UnsafeCell<ManuallyDrop<T>>,
taken: AtomicBool,
thread: thread::Thread,
next: AtomicPtr<SenderWaiter<T>>,
}
unsafe impl<T: Send> Send for SenderWaiter<T> {}
unsafe impl<T: Send> Sync for SenderWaiter<T> {}
impl<T> SenderWaiter<T> {
fn new(value: T) -> Self {
SenderWaiter {
value: UnsafeCell::new(ManuallyDrop::new(value)),
taken: AtomicBool::new(false),
thread: thread::current(),
next: AtomicPtr::new(ptr::null_mut()),
}
}
}
struct Chan<T> {
sender_waiters: AtomicPtr<SenderWaiter<T>>,
recv_waiters: RecvWaiterList,
select_waiters: Arc<AtomicPtr<SelectWaiter>>,
send_select_waiters: Arc<AtomicPtr<SelectWaiter>>,
sender_count: AtomicUsize,
receiver_count: AtomicUsize,
}
pub struct Sender<T>(Arc<Chan<T>>);
impl<T: Send> Sender<T> {
pub fn send(&self, value: T) -> Result<(), SendError<T>> {
if self.0.receiver_count.load(Acquire) == 0 {
return Err(SendError(value));
}
let waiter = SenderWaiter::new(value);
let waiter_ptr = &waiter as *const SenderWaiter<T> as *mut SenderWaiter<T>;
loop {
let head = self.0.sender_waiters.load(Acquire);
waiter.next.store(head, Relaxed);
if self
.0
.sender_waiters
.compare_exchange(head, waiter_ptr, AcqRel, Acquire)
.is_ok()
{
break;
}
}
wake_all_recv_waiters(&self.0.recv_waiters, UNSELECTED);
wake_select_one(&self.0.select_waiters);
loop {
if waiter.taken.load(Acquire) {
return Ok(());
}
if self.0.receiver_count.load(Acquire) == 0 {
self.remove_sender_waiter(waiter_ptr);
let val = unsafe { ManuallyDrop::into_inner(ptr::read(waiter.value.get())) };
return Err(SendError(val));
}
thread::park();
}
}
fn remove_sender_waiter(&self, ptr: *mut SenderWaiter<T>) {
loop {
let head = self.0.sender_waiters.load(Acquire);
if head.is_null() {
return;
}
if head == ptr {
let next = unsafe { (*ptr).next.load(Acquire) };
if self
.0
.sender_waiters
.compare_exchange(head, next, AcqRel, Acquire)
.is_ok()
{
return;
}
continue; }
let mut current = head;
loop {
let next_ptr = unsafe { (*current).next.load(Acquire) };
if next_ptr == ptr {
let my_next = unsafe { (*ptr).next.load(Acquire) };
if unsafe {
(*current)
.next
.compare_exchange(next_ptr, my_next, AcqRel, Acquire)
.is_ok()
} {
return;
}
break; }
if next_ptr.is_null() {
return; }
current = next_ptr;
}
}
}
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.0.sender_count.fetch_add(1, Relaxed);
Sender(Arc::clone(&self.0))
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
let prev = self.0.sender_count.fetch_sub(1, AcqRel);
if prev == 1 {
wake_all_recv_waiters(&self.0.recv_waiters, UNSELECTED);
wake_select_all(&self.0.select_waiters);
}
}
}
pub struct Receiver<T>(Arc<Chan<T>>);
impl<T: Send> Receiver<T> {
pub fn try_recv(&self) -> Result<T, TryRecvError> {
if let Some(val) = self.pop_sender() {
return Ok(val);
}
if self.0.sender_count.load(Acquire) == 0 {
Err(TryRecvError::Disconnected)
} else {
Err(TryRecvError::Empty)
}
}
pub fn recv(&self) -> Result<T, RecvError> {
loop {
if let Some(val) = self.pop_sender() {
return Ok(val);
}
if self.0.sender_count.load(Acquire) == 0 {
return Err(RecvError::Disconnected);
}
let marker = Arc::new(AtomicUsize::new(UNSELECTED));
let waiter = RecvWaiter::new(usize::MAX, Arc::clone(&marker));
let _guard = RecvWaiterGuard::register(waiter, &self.0.recv_waiters);
wake_select_one(&self.0.send_select_waiters);
if let Some(val) = self.pop_sender() {
return Ok(val);
}
if self.0.sender_count.load(Acquire) == 0 {
return Err(RecvError::Disconnected);
}
thread::park();
}
}
fn pop_sender(&self) -> Option<T> {
loop {
let head = self.0.sender_waiters.load(Acquire);
if head.is_null() {
return None;
}
let next = unsafe { (*head).next.load(Acquire) };
if self
.0
.sender_waiters
.compare_exchange(head, next, AcqRel, Acquire)
.is_ok()
{
let val = unsafe { ManuallyDrop::into_inner(ptr::read((*head).value.get())) };
unsafe { (*head).taken.store(true, Release) };
unsafe { (*head).thread.unpark() };
return Some(val);
}
}
}
pub(crate) fn is_ready(&self) -> bool {
!self.0.sender_waiters.load(Acquire).is_null() || self.0.sender_count.load(Acquire) == 0
}
pub(crate) fn register_select(&self, case_id: usize, selected: Arc<AtomicUsize>) {
let ptr = SelectWaiter::alloc(case_id, selected);
push_select_waiter(ptr, &self.0.select_waiters);
}
pub(crate) fn abort_select(&self, selected: &Arc<AtomicUsize>) {
abort_select_waiters(&self.0.select_waiters, selected);
}
pub fn complete_recv(&self) -> Result<T, RecvError> {
self.recv()
}
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.0.receiver_count.fetch_add(1, Relaxed);
Receiver(Arc::clone(&self.0))
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
let prev = self.0.receiver_count.fetch_sub(1, AcqRel);
if prev == 1 {
let mut current = self.0.sender_waiters.load(Acquire);
while !current.is_null() {
let next = unsafe { (*current).next.load(Acquire) };
unsafe { (*current).thread.unpark() };
current = next;
}
drain_select_waiters(&self.0.select_waiters);
wake_select_all(&self.0.send_select_waiters);
}
}
}
impl<T: Send + 'static> crate::SelectableSender for Sender<T> {
type Input = T;
fn is_ready(&self) -> bool {
!self.0.recv_waiters.lock().unwrap().is_empty() || self.0.receiver_count.load(Acquire) == 0
}
fn register_select(&self, case_id: usize, selected: Arc<AtomicUsize>) {
let ptr = SelectWaiter::alloc(case_id, selected);
push_select_waiter(ptr, &self.0.send_select_waiters);
}
fn abort_select(&self, selected: &Arc<AtomicUsize>) {
abort_select_waiters(&self.0.send_select_waiters, selected);
}
fn complete_send(&self, value: T) -> Result<(), crate::SendError<T>> {
self.send(value)
}
}
impl<T: Send> crate::SelectableReceiver for Receiver<T> {
type Output = T;
fn is_ready(&self) -> bool {
self.is_ready()
}
fn register_select(&self, case_id: usize, selected: Arc<AtomicUsize>) {
self.register_select(case_id, selected)
}
fn abort_select(&self, selected: &Arc<AtomicUsize>) {
self.abort_select(selected)
}
fn complete(&self) -> Result<Self::Output, RecvError> {
self.complete_recv()
}
}
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let chan = Arc::new(Chan {
sender_waiters: AtomicPtr::new(ptr::null_mut()),
recv_waiters: new_recv_waiter_list(),
select_waiters: Arc::new(AtomicPtr::new(ptr::null_mut())),
send_select_waiters: Arc::new(AtomicPtr::new(ptr::null_mut())),
sender_count: AtomicUsize::new(1),
receiver_count: AtomicUsize::new(1),
});
(Sender(Arc::clone(&chan)), Receiver(chan))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::time::Duration;

    #[test]
    fn test_basic_rendezvous() {
        let (tx, rx) = channel::<i32>();
        let handle = thread::spawn(move || rx.recv().unwrap());
        tx.send(42).unwrap();
        assert_eq!(handle.join().unwrap(), 42);
    }

    #[test]
    fn test_try_recv_empty() {
        let (tx, rx) = channel::<i32>();
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
    }

    #[test]
    fn test_try_recv_wins() {
        let (tx, rx) = channel::<i32>();
        let handle = thread::spawn(move || tx.send(99));
        thread::sleep(Duration::from_millis(20));
        assert_eq!(rx.try_recv(), Ok(99));
        handle.join().unwrap().unwrap();
    }

    #[test]
    fn test_sender_disconnect_wakes_recv() {
        let (tx, rx) = channel::<i32>();
        let handle = thread::spawn(move || rx.recv());
        thread::sleep(Duration::from_millis(10));
        drop(tx);
        assert_eq!(handle.join().unwrap(), Err(RecvError::Disconnected));
    }

    #[test]
    fn test_receiver_disconnect_wakes_sender() {
        let (tx, rx) = channel::<i32>();
        let handle = thread::spawn(move || tx.send(7));
        thread::sleep(Duration::from_millis(10));
        drop(rx);
        assert_eq!(handle.join().unwrap(), Err(SendError(7)));
    }

    #[test]
    fn test_multiple_senders_one_receiver() {
        let (tx, rx) = channel::<i32>();
        let tx2 = tx.clone();
        let tx3 = tx.clone();
        let h1 = thread::spawn(move || tx.send(1).unwrap());
        let h2 = thread::spawn(move || tx2.send(2).unwrap());
        let h3 = thread::spawn(move || tx3.send(3).unwrap());
        let mut results = vec![rx.recv().unwrap(), rx.recv().unwrap(), rx.recv().unwrap()];
        results.sort();
        assert_eq!(results, vec![1, 2, 3]);
        h1.join().unwrap();
        h2.join().unwrap();
        h3.join().unwrap();
    }

    #[test]
    fn test_blocked_senders_served_in_arrival_order() {
        let (tx, rx) = channel::<i32>();
        let tx2 = tx.clone();
        let first = thread::spawn(move || tx.send(1));
        thread::sleep(Duration::from_millis(30));
        let second = thread::spawn(move || tx2.send(2));
        thread::sleep(Duration::from_millis(30));
        // The oldest blocked sender must be served first (FIFO, no starvation).
        assert_eq!(rx.recv(), Ok(1));
        assert_eq!(rx.recv(), Ok(2));
        first.join().unwrap().unwrap();
        second.join().unwrap().unwrap();
    }

    #[test]
    fn test_select_arm_rendezvous() {
        use crate::{Select, bounded_mpmc};
        let (tx_rdv, rx_rdv) = channel::<i32>();
        let rx_never = bounded_mpmc::never::<i32>();
        let h = thread::spawn(move || tx_rdv.send(55).unwrap());
        thread::sleep(Duration::from_millis(20));
        let mut sel = Select::new();
        let i_rdv = sel.recv(rx_rdv.clone());
        let _i_never = sel.recv(rx_never);
        let op = sel.select();
        assert_eq!(op.index, i_rdv);
        assert_eq!(rx_rdv.complete_recv(), Ok(55));
        h.join().unwrap();
    }

    #[test]
    fn test_one_sender_multiple_receivers_each_gets_one() {
        let (tx, rx) = channel::<i32>();
        let rx2 = rx.clone();
        let rx3 = rx.clone();
        let h1 = thread::spawn(move || rx.recv().unwrap());
        let h2 = thread::spawn(move || rx2.recv().unwrap());
        let h3 = thread::spawn(move || rx3.recv().unwrap());
        thread::sleep(Duration::from_millis(20));
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();
        let mut results = vec![h1.join().unwrap(), h2.join().unwrap(), h3.join().unwrap()];
        results.sort();
        assert_eq!(results, vec![1, 2, 3]);
    }

    #[test]
    fn test_mpmc_stress() {
        const SENDERS: usize = 4;
        const PER_SENDER: usize = 64;
        const TOTAL: usize = SENDERS * PER_SENDER;
        let (tx, rx) = channel::<usize>();
        let counter = Arc::new(AtomicUsize::new(0));
        let sum = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for s in 0..SENDERS {
            let txc = tx.clone();
            handles.push(thread::spawn(move || {
                for i in 0..PER_SENDER {
                    // The receiver drains until every sender is gone, so no send may fail.
                    txc.send(s * PER_SENDER + i).expect("receiver is still connected");
                }
            }));
        }
        drop(tx);
        let ctr = Arc::clone(&counter);
        let acc = Arc::clone(&sum);
        handles.push(thread::spawn(move || {
            while let Ok(v) = rx.recv() {
                ctr.fetch_add(1, Relaxed);
                acc.fetch_add(v, Relaxed);
            }
        }));
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.load(Relaxed), TOTAL);
        // Each of 0..TOTAL delivered exactly once: catches lost or duplicated values.
        assert_eq!(sum.load(Relaxed), TOTAL * (TOTAL - 1) / 2);
    }

    #[test]
    fn test_mpmc_multiple_receivers_stress() {
        const SENDERS: usize = 4;
        const RECEIVERS: usize = 3;
        const PER_SENDER: usize = 64;
        const TOTAL: usize = SENDERS * PER_SENDER;
        let (tx, rx) = channel::<usize>();
        let senders: Vec<_> = (0..SENDERS)
            .map(|s| {
                let txc = tx.clone();
                thread::spawn(move || {
                    for i in 0..PER_SENDER {
                        txc.send(s * PER_SENDER + i).expect("receivers are still connected");
                    }
                })
            })
            .collect();
        drop(tx);
        let receivers: Vec<_> = (0..RECEIVERS)
            .map(|_| {
                let rxc = rx.clone();
                thread::spawn(move || {
                    let mut got = Vec::new();
                    while let Ok(v) = rxc.recv() {
                        got.push(v);
                    }
                    got
                })
            })
            .collect();
        drop(rx);
        for h in senders {
            h.join().unwrap();
        }
        let mut all: Vec<usize> = receivers
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();
        all.sort_unstable();
        assert_eq!(all, (0..TOTAL).collect::<Vec<_>>());
    }

    #[test]
    fn test_try_recv_disconnected_immediately() {
        let (tx, rx) = channel::<i32>();
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
    }
}