use crate::loom::cell::CausalCell;
use crate::loom::future::AtomicWaker;
use crate::loom::sync::atomic::{spin_loop_hint, AtomicBool, AtomicPtr, AtomicUsize};
use crate::loom::sync::{Arc, Condvar, Mutex};
use std::fmt;
use std::ptr;
use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};
use std::usize;
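/// Sending half of the broadcast channel.
///
/// Every value passed to `send` is made available to all active receivers.
/// A `Sender` can be cloned to create additional sending handles.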
pub struct Sender<T> {
shared: Arc<Shared<T>>,
}
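/// Receiving half of the broadcast channel.
///
/// Each receiver observes every value sent after it was created. If a
/// receiver falls more than the channel capacity behind, the oldest values
/// are overwritten and the receiver reports `RecvError::Lagged`.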
pub struct Receiver<T> {
shared: Arc<Shared<T>>,
next: u64,
wait: Arc<WaitNode>,
}
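/// Error returned by `Sender::send` when there are no active receivers.
/// The value that could not be sent is handed back to the caller.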
#[derive(Debug)]
pub struct SendError<T>(pub T);
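/// Error returned by `Receiver::recv`.
///
/// `Closed` means all senders have been dropped; `Lagged(n)` means this
/// receiver fell behind and `n` values were skipped.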
#[derive(Debug, PartialEq)]
pub enum RecvError {
Closed,
Lagged(u64),
}
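/// Error returned by `Receiver::try_recv`.
///
/// In addition to the `recv` errors, `Empty` indicates that no value is
/// currently available.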
#[derive(Debug, PartialEq)]
pub enum TryRecvError {
Empty,
Closed,
Lagged(u64),
}
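/// State shared by all senders and receivers of a channel: the ring buffer of
/// slots, the index mask (capacity is a power of two), the tail position
/// guarded by a mutex, a condvar for writers waiting on readers, a lock-free
/// stack of parked receiver wakers, and the number of live senders.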
struct Shared<T> {
buffer: Box<[Slot<T>]>,
mask: usize,
tail: Mutex<Tail>,
condvar: Condvar,
wait_stack: AtomicPtr<WaitNode>,
num_tx: AtomicUsize,
}
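/// Next position to write to and the number of active receivers, guarded by
/// the `tail` mutex.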
struct Tail {
pos: u64,
rx_cnt: usize,
}
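/// One entry in the ring buffer.
///
/// `rem` counts how many receivers still need to observe the stored value;
/// `lock` packs a writer bit (bit 0) with a reader count (increments of 2).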
struct Slot<T> {
rem: AtomicUsize,
lock: AtomicUsize,
write: Write<T>,
}
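/// The stored value together with the absolute position it was written at.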
struct Write<T> {
pos: CausalCell<u64>,
val: CausalCell<Option<T>>,
}
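/// A node in the lock-free stack of receivers waiting for a value.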
#[derive(Debug)]
struct WaitNode {
queued: AtomicBool,
waker: AtomicWaker,
next: CausalCell<*const WaitNode>,
}
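/// Guard returned by `recv_ref`; on drop it releases the slot's read lock and
/// decrements the slot's remaining-receiver count.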
struct RecvGuard<'a, T> {
slot: &'a Slot<T>,
tail: &'a Mutex<Tail>,
condvar: &'a Condvar,
}
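/// Maximum number of receivers a channel may have at once.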
const MAX_RECEIVERS: usize = usize::MAX >> 1;
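/// Creates a broadcast channel with at least the requested capacity, returning
/// the `Sender` and the initial `Receiver`.
///
/// The capacity is rounded up to the next power of two so slot indexing can
/// use a bit mask. Panics if `capacity` is zero or larger than
/// `usize::MAX >> 1`.
///
/// A minimal usage sketch (the `broadcast` module path under which this is
/// exposed is an assumption, so the example is not compiled):
///
/// ```ignore
/// let (tx, mut rx) = broadcast::channel::<u32>(16);
/// tx.send(7).unwrap();
/// assert_eq!(rx.try_recv().unwrap(), 7);
/// ```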
pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
assert!(capacity > 0, "capacity must be greater than zero");
assert!(capacity <= usize::MAX >> 1, "requested capacity too large");
capacity = capacity.next_power_of_two();
let mut buffer = Vec::with_capacity(capacity);
for i in 0..capacity {
buffer.push(Slot {
rem: AtomicUsize::new(0),
lock: AtomicUsize::new(0),
write: Write {
pos: CausalCell::new((i as u64).wrapping_sub(capacity as u64)),
val: CausalCell::new(None),
},
});
}
let shared = Arc::new(Shared {
buffer: buffer.into_boxed_slice(),
mask: capacity - 1,
tail: Mutex::new(Tail { pos: 0, rx_cnt: 1 }),
condvar: Condvar::new(),
wait_stack: AtomicPtr::new(ptr::null_mut()),
num_tx: AtomicUsize::new(1),
});
let rx = Receiver {
shared: shared.clone(),
next: 0,
wait: Arc::new(WaitNode {
queued: AtomicBool::new(false),
waker: AtomicWaker::new(),
next: CausalCell::new(ptr::null()),
}),
};
let tx = Sender { shared };
(tx, rx)
}
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}
unsafe impl<T: Send> Send for Receiver<T> {}
unsafe impl<T: Send> Sync for Receiver<T> {}
impl<T> Sender<T> {
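    /// Sends a value to all active receivers, returning the number of
    /// receivers the value was sent to.
    ///
    /// Fails with `SendError`, handing the value back, when no receivers
    /// remain.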
pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
self.send2(Some(value))
.map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap()))
}
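    /// Creates a new `Receiver` that observes values sent after this call.
    ///
    /// Panics if the receiver limit (`MAX_RECEIVERS`) is reached.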
pub fn subscribe(&self) -> Receiver<T> {
let shared = self.shared.clone();
let mut tail = shared.tail.lock().unwrap();
if tail.rx_cnt == MAX_RECEIVERS {
panic!("reached the maximum number of receivers");
}
tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
let next = tail.pos;
drop(tail);
Receiver {
shared,
next,
wait: Arc::new(WaitNode {
queued: AtomicBool::new(false),
waker: AtomicWaker::new(),
next: CausalCell::new(ptr::null()),
}),
}
}
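    /// Returns the number of active receivers.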
pub fn receiver_count(&self) -> usize {
let tail = self.shared.tail.lock().unwrap();
tail.rx_cnt
}
fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> {
let mut tail = self.shared.tail.lock().unwrap();
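        // With no active receivers the send fails and the value is handed
        // back to the caller.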
if tail.rx_cnt == 0 {
return Err(SendError(value));
}
let pos = tail.pos;
let rem = tail.rx_cnt;
let idx = (pos & self.shared.mask as u64) as usize;
tail.pos = tail.pos.wrapping_add(1);
let slot = &self.shared.buffer[idx];
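        // Set the writer bit on the slot lock. Readers hold the lock in
        // increments of two, so anything above the low bit means a receiver
        // is still reading this slot and the writer must wait on the condvar
        // until it is released.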
let mut prev = slot.lock.fetch_or(1, SeqCst);
while prev & !1 != 0 {
tail = self.shared.condvar.wait(tail).unwrap();
prev = slot.lock.load(SeqCst);
if prev & 1 == 0 {
return Ok(rem);
}
}
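        // While this thread slept, the tail may have advanced far enough that
        // a newer send now owns this slot; if so, skip the stale write.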
if tail.pos.wrapping_sub(pos) > self.shared.buffer.len() as u64 {
return Ok(rem);
}
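        // Record the position and value, note how many receivers must still
        // observe it, then clear the lock bit to publish the slot.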
slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos });
slot.write.val.with_mut(|ptr| unsafe { *ptr = value });
slot.rem.store(rem, SeqCst);
slot.lock.store(0, SeqCst);
drop(tail);
self.notify_rx();
Ok(rem)
}
fn notify_rx(&self) {
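        // Take ownership of the entire wait stack in one swap and wake every
        // parked receiver. Each node was leaked into the stack as a raw `Arc`
        // pointer by `register_waker`, so it is reclaimed here.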
let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode;
while !curr.is_null() {
let waiter = unsafe { Arc::from_raw(curr) };
curr = waiter.next.with(|ptr| unsafe { *ptr });
waiter.queued.store(false, SeqCst);
waiter.waker.wake();
}
}
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Sender<T> {
let shared = self.shared.clone();
shared.num_tx.fetch_add(1, SeqCst);
Sender { shared }
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
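        // The last sender to drop closes the channel by writing a `None`
        // sentinel into the next slot; receivers observe `Closed` once they
        // catch up to it.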
if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) {
let _ = self.send2(None);
}
}
}
impl<T> Receiver<T> {
fn recv_ref(&mut self, spin: bool) -> Result<RecvGuard<'_, T>, TryRecvError> {
let idx = (self.next & self.shared.mask as u64) as usize;
let slot = &self.shared.buffer[idx];
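        // Acquire the slot's read lock. If a writer currently holds the slot,
        // either spin until it is released (used while draining on drop) or
        // report that nothing is available yet.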
if !slot.try_rx_lock() {
if spin {
while !slot.try_rx_lock() {
spin_loop_hint();
}
} else {
return Err(TryRecvError::Empty);
}
}
let guard = RecvGuard {
slot,
tail: &self.shared.tail,
condvar: &self.shared.condvar,
};
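        // The slot does not hold the value for `self.next`: either that
        // position simply has not been written yet (empty), or this receiver
        // fell behind and the value was overwritten (lagged).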
if guard.pos() != self.next {
let pos = guard.pos();
guard.drop_no_rem_dec();
if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next {
return Err(TryRecvError::Empty);
} else {
let tail = self.shared.tail.lock().unwrap();
let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64);
let missed = next.wrapping_sub(self.next);
self.next = next;
return Err(TryRecvError::Lagged(missed));
}
}
self.next = self.next.wrapping_add(1);
Ok(guard)
}
}
impl<T> Receiver<T>
where
T: Clone,
{
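    /// Attempts to return the next value without waiting.
    ///
    /// Returns `Empty` if no value is ready, `Closed` if all senders have
    /// been dropped, or `Lagged` if this receiver fell behind.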
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
let guard = self.recv_ref(false)?;
guard.clone_value().ok_or(TryRecvError::Closed)
}
    #[doc(hidden)]
    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
if let Some(value) = ok_empty(self.try_recv())? {
return Poll::Ready(Ok(value));
}
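        // Register this receiver's waker, then check once more: a value sent
        // between the first `try_recv` and the registration would otherwise
        // go unnoticed until some later wakeup.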
self.register_waker(cx.waker());
if let Some(value) = ok_empty(self.try_recv())? {
Poll::Ready(Ok(value))
} else {
Poll::Pending
}
}
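    /// Receives the next value, waiting until one is sent or all senders are
    /// dropped (`Closed`). Reports `Lagged` if values were missed.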
pub async fn recv(&mut self) -> Result<T, RecvError> {
use crate::future::poll_fn;
poll_fn(|cx| self.poll_recv(cx)).await
}
fn register_waker(&self, cx: &Waker) {
self.wait.waker.register_by_ref(cx);
if !self.wait.queued.load(SeqCst) {
self.wait.queued.store(true, SeqCst);
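            // Push this node onto the shared wait stack with a CAS loop. The
            // `Arc` is leaked into the stack as a raw pointer and reclaimed
            // by `notify_rx` or by `Shared::drop`.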
let mut curr = self.shared.wait_stack.load(SeqCst);
let node = Arc::into_raw(self.wait.clone()) as *mut _;
loop {
self.wait.next.with_mut(|ptr| unsafe { *ptr = curr });
let res = self
.shared
.wait_stack
.compare_exchange(curr, node, SeqCst, SeqCst);
match res {
Ok(_) => return,
Err(actual) => curr = actual,
}
}
}
}
}
#[cfg(feature = "stream")]
impl<T> crate::stream::Stream for Receiver<T>
where
T: Clone,
{
type Item = Result<T, RecvError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<T, RecvError>>> {
self.poll_recv(cx).map(|v| match v {
Ok(v) => Some(Ok(v)),
lag @ Err(RecvError::Lagged(_)) => Some(lag),
Err(RecvError::Closed) => None,
})
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
let mut tail = self.shared.tail.lock().unwrap();
tail.rx_cnt -= 1;
let until = tail.pos;
drop(tail);
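        // Walk every slot this receiver has not yet observed so each slot's
        // remaining-receiver count is decremented and values can be freed.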
while self.next != until {
match self.recv_ref(true) {
Ok(_) => {}
Err(TryRecvError::Closed) => break,
Err(TryRecvError::Lagged(..)) => {}
Err(TryRecvError::Empty) => panic!("unexpected empty broadcast channel"),
}
}
}
}
impl<T> Drop for Shared<T> {
fn drop(&mut self) {
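        // Reclaim any wait nodes still on the stack; they were leaked as raw
        // `Arc` pointers when the owning receivers registered their wakers.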
let mut curr = *self.wait_stack.get_mut() as *const WaitNode;
while !curr.is_null() {
let waiter = unsafe { Arc::from_raw(curr) };
curr = waiter.next.with(|ptr| unsafe { *ptr });
}
}
}
impl<T> fmt::Debug for Sender<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "broadcast::Sender")
}
}
impl<T> fmt::Debug for Receiver<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "broadcast::Receiver")
}
}
impl<T> Slot<T> {
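    /// Attempts to acquire a read lock on the slot by adding 2 to the lock
    /// word. Fails (returns `false`) if the writer bit is set.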
fn try_rx_lock(&self) -> bool {
let mut curr = self.lock.load(SeqCst);
loop {
if curr & 1 == 1 {
return false;
}
let res = self.lock.compare_exchange(curr, curr + 2, SeqCst, SeqCst);
match res {
Ok(_) => return true,
Err(actual) => curr = actual,
}
}
}
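    /// Releases a read lock. When `rem_dec` is true and this was the last
    /// receiver that needed the value, the value is dropped. If no readers
    /// remain and the writer bit is set, the condvar is notified so a parked
    /// sender can proceed.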
fn rx_unlock(&self, tail: &Mutex<Tail>, condvar: &Condvar, rem_dec: bool) {
if rem_dec {
if 1 == self.rem.fetch_sub(1, SeqCst) {
self.write.val.with_mut(|ptr| unsafe { *ptr = None });
}
}
if 1 == self.lock.fetch_sub(2, SeqCst) - 2 {
let _ = tail.lock().unwrap();
condvar.notify_all();
}
}
}
impl<'a, T> RecvGuard<'a, T> {
fn pos(&self) -> u64 {
self.slot.write.pos.with(|ptr| unsafe { *ptr })
}
fn clone_value(&self) -> Option<T>
where
T: Clone,
{
self.slot.write.val.with(|ptr| unsafe { (*ptr).clone() })
}
fn drop_no_rem_dec(self) {
use std::mem;
self.slot.rx_unlock(self.tail, self.condvar, false);
mem::forget(self);
}
}
impl<'a, T> Drop for RecvGuard<'a, T> {
fn drop(&mut self) {
self.slot.rx_unlock(self.tail, self.condvar, true)
}
}
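/// Maps `Empty` to `Ok(None)` so `poll_recv` can return `Pending`, and
/// converts the remaining `TryRecvError` variants into `RecvError`.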
fn ok_empty<T>(res: Result<T, TryRecvError>) -> Result<Option<T>, RecvError> {
match res {
Ok(value) => Ok(Some(value)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)),
Err(TryRecvError::Closed) => Err(RecvError::Closed),
}
}
impl fmt::Display for RecvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RecvError::Closed => write!(f, "channel closed"),
RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
}
}
}
impl std::error::Error for RecvError {}
impl fmt::Display for TryRecvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TryRecvError::Empty => write!(f, "channel empty"),
TryRecvError::Closed => write!(f, "channel closed"),
TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
}
}
}
impl std::error::Error for TryRecvError {}