#![deny(missing_docs)]
#![warn(rust_2018_idioms)]
use crossbeam_channel as mpsc;
use parking_lot_core::SpinWait;
use std::cell::UnsafeCell;
use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref;
use std::ptr;
use std::sync::atomic;
use std::sync::mpsc as std_mpsc;
use std::sync::Arc;
use std::thread;
use std::time;
/// Park interval, in nanoseconds, used once spinning gives up: blocked senders and receivers
/// call `thread::park_timeout(Duration::new(0, SPINTIME))`, i.e. they re-check every 100µs.
const SPINTIME: u32 = 100 * 1_000;
/// Contents of one ring slot: the broadcast value and how many readers must consume it.
struct SeatState<T> {
    /// Number of readers this seat's value was written for.
    max: usize,
    /// The broadcast value; `None` while the seat is empty or after the last reader took it.
    val: Option<T>,
}

/// Interior-mutable wrapper around a [`SeatState`] that can be shared between threads.
///
/// Access goes through the raw pointer returned by [`UnsafeCell::get`]; it is up to the
/// callers to make sure a mutation never overlaps with any other access.
struct MutSeatState<T>(UnsafeCell<SeatState<T>>);

// SAFETY: NOTE(review): this impl is unconditional in `T`. Its soundness rests on the bus
// protocol (a single writer that only rewrites fully-read seats, and readers that only clone
// through `Seat::take`, which requires `T: Sync`) — confirm no other access path exists.
unsafe impl<T> Sync for MutSeatState<T> {}

impl<T> Deref for MutSeatState<T> {
    type Target = UnsafeCell<SeatState<T>>;

    fn deref(&self) -> &UnsafeCell<SeatState<T>> {
        let MutSeatState(cell) = self;
        cell
    }
}

impl<T> fmt::Debug for MutSeatState<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("MutSeatState");
        tuple.field(&self.0);
        tuple.finish()
    }
}
/// One slot of the broadcast ring.
struct Seat<T> {
    /// How many readers have finished reading the current value.
    read: atomic::AtomicUsize,
    /// The value and its expected reader count.
    state: MutSeatState<T>,
    /// A writer thread that parked waiting for this seat to be fully read, if any; the last
    /// reader takes it and unparks it.
    waiting: AtomicOption<thread::Thread>,
}

impl<T> fmt::Debug for Seat<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Seat {
            read,
            state,
            waiting,
        } = self;
        f.debug_struct("Seat")
            .field("read", read)
            .field("state", state)
            .field("waiting", waiting)
            .finish()
    }
}
impl<T: Clone + Sync> Seat<T> {
    /// Reads this seat's value on behalf of one reader.
    ///
    /// Every reader except the last gets a clone. The reader that brings the read count up to
    /// `max` moves the value out, leaving the seat empty, and unparks the writer registered in
    /// `waiting` (if any).
    ///
    /// # Panics
    ///
    /// Panics if the seat has already been read `max` times, or if a non-final reader finds
    /// the seat empty.
    fn take(&self) -> T {
        let read = self.read.load(atomic::Ordering::Acquire);
        let cell = self.state.get();

        // SAFETY: `max` is only written by the bus writer, which does not rewrite a seat while
        // it still has outstanding readers (see `Bus::broadcast_inner`). Reading the field
        // directly through the raw pointer avoids creating a reference to the whole
        // `SeatState`, which the writer may be reading concurrently.
        let max = unsafe { (*cell).max };
        assert!(
            read < max,
            "reader hit seat with exhausted reader count"
        );

        let (v, waiting) = if read + 1 == max {
            // `read + 1 == max` means every other expected reader has already bumped `read`
            // (and our Acquire load synchronized with those bumps), so we are the last reader:
            // move the value out instead of cloning it, and pick up a parked writer, if any.
            let waiting = self.waiting.take();
            // SAFETY: no other reader can still be accessing `val` (see above), and the writer
            // does not rewrite the seat before we bump `read` below. Only the `val` field is
            // borrowed mutably.
            let v = unsafe { (*cell).val.take() }.unwrap();
            (v, waiting)
        } else {
            // SAFETY: shared access to `val` only; the last reader, which takes `val` mutably,
            // cannot get past its own `read + 1 == max` check until our `fetch_add` below.
            let v = unsafe { (*cell).val.clone() }
                .expect("seat that should be occupied was empty");
            (v, None)
        };

        self.read.fetch_add(1, atomic::Ordering::AcqRel);
        if let Some(t) = waiting {
            t.unpark();
        }
        v
    }
}
impl<T> Default for Seat<T> {
    /// An empty seat: no value, zero expected readers, nothing read, and no parked writer.
    fn default() -> Self {
        let empty = SeatState { max: 0, val: None };
        Self {
            read: atomic::AtomicUsize::new(0),
            state: MutSeatState(UnsafeCell::new(empty)),
            waiting: AtomicOption::empty(),
        }
    }
}
/// State shared between the [`Bus`] writer and all of its readers.
struct BusInner<T> {
    /// The ring of seats. `Bus::new` allocates one more seat than the requested capacity.
    ring: Vec<Seat<T>>,
    /// Number of seats in `ring`; all indices are taken modulo this.
    len: usize,
    /// Index of the next seat the writer will fill. A reader whose head differs from `tail`
    /// has a value to read.
    tail: atomic::AtomicUsize,
    /// Set when the [`Bus`] is dropped.
    closed: atomic::AtomicBool,
}

impl<T> fmt::Debug for BusInner<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let BusInner {
            ring,
            len,
            tail,
            closed,
        } = self;
        f.debug_struct("BusInner")
            .field("ring", ring)
            .field("len", len)
            .field("tail", tail)
            .field("closed", closed)
            .finish()
    }
}
/// The sending half of a single-producer, multi-consumer broadcast channel.
///
/// Each value passed to [`Bus::broadcast`] (or [`Bus::try_broadcast`]) is delivered, in order,
/// to every [`BusReader`] obtained from [`Bus::add_rx`] before the value was sent. The bus has
/// a fixed capacity chosen in [`Bus::new`]. When the slowest reader is that many values behind,
/// `broadcast` blocks and `try_broadcast` hands the value back.
///
/// Dropping the `Bus` closes it: readers still receive the values already broadcast, after
/// which their receive calls report disconnection.
pub struct Bus<T> {
    /// Ring and flags shared with every reader.
    state: Arc<BusInner<T>>,
    /// Number of live readers; recorded as the expected read count of each newly written seat.
    readers: usize,
    /// Per seat: how many of the readers counted in its `max` have been dropped without
    /// reading it.
    rleft: Vec<usize>,
    /// Dropped readers send their head position here.
    leaving: (mpsc::Sender<usize>, mpsc::Receiver<usize>),
    /// Blocked readers register `(thread, head)` here to be unparked after a broadcast.
    // The channel-pair type is spelled out in place; an alias would only be used once.
    #[allow(clippy::type_complexity)]
    waiting: (
        mpsc::Sender<(thread::Thread, usize)>,
        mpsc::Receiver<(thread::Thread, usize)>,
    ),
    /// Hands threads to the helper thread spawned in [`Bus::new`], which unparks them.
    unpark: mpsc::Sender<thread::Thread>,
    /// Scratch buffer, reused across broadcasts, for waiters that must stay registered.
    cache: Vec<(thread::Thread, usize)>,
}

impl<T> fmt::Debug for Bus<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Bus")
            .field("state", &self.state)
            .field("readers", &self.readers)
            .field("rleft", &self.rleft)
            .field("leaving", &self.leaving)
            .field("waiting", &self.waiting)
            .field("unpark", &self.unpark)
            .field("cache", &self.cache)
            .finish()
    }
}
impl<T> Bus<T> {
    /// Allocates a new bus that buffers up to `len` values not yet consumed by every reader.
    ///
    /// Once the slowest reader is `len` values behind, [`broadcast`](Bus::broadcast) blocks and
    /// [`try_broadcast`](Bus::try_broadcast) fails. Internally the ring has `len + 1` seats, so
    /// that a full ring can be told apart from an empty one. A `len` of 0 is treated as 1: with
    /// a single seat the tail would wrap onto itself, readers would never see the value, and the
    /// next broadcast would block forever.
    ///
    /// This also spawns a helper thread that wakes up blocked readers. The thread exits when the
    /// bus is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `len + 1` overflows `usize`.
    pub fn new(len: usize) -> Bus<T> {
        let len = len
            .max(1)
            .checked_add(1)
            .expect("bus capacity overflows usize");
        let inner = Arc::new(BusInner {
            ring: (0..len).map(|_| Seat::default()).collect(),
            tail: atomic::AtomicUsize::new(0),
            closed: atomic::AtomicBool::new(false),
            len,
        });

        // NOTE(review): presumably warms up `Instant` on macOS before any reader measures a
        // timeout. The condition previously read `cfg!(target = "macos")`. `target` is not a cfg
        // key, so that was always false and the workaround never ran.
        if !cfg!(miri) && cfg!(target_os = "macos") {
            let _ = time::Instant::now().elapsed();
        }

        // The helper thread unparks every thread it is sent. It exits when `unpark_tx` (owned
        // solely by the returned `Bus`) is dropped.
        let (unpark_tx, unpark_rx) = mpsc::unbounded::<thread::Thread>();
        thread::spawn(move || {
            for t in unpark_rx.iter() {
                t.unpark();
            }
        });

        Bus {
            state: inner,
            readers: 0,
            rleft: vec![0; len],
            leaving: mpsc::unbounded(),
            waiting: mpsc::unbounded(),
            unpark: unpark_tx,
            cache: Vec::new(),
        }
    }

    /// Number of reads seat `at` must reach before it counts as fully read. This is the number
    /// of readers it was written for, minus those that have since left without reading it.
    #[inline]
    fn expected(&self, at: usize) -> usize {
        // SAFETY: only the writer (us) writes `max`. Reading the field through the raw pointer
        // avoids a reference to the whole `SeatState`, whose `val` the last reader may be taking
        // concurrently.
        let max = unsafe { (*self.state.ring[at].state.get()).max };
        max - self.rleft[at]
    }

    /// Shared implementation of `broadcast` and `try_broadcast`.
    ///
    /// If `block` is false and the ring is full, `val` is returned as `Err(val)`.
    fn broadcast_inner(&mut self, val: T, block: bool) -> Result<(), T> {
        let tail = self.state.tail.load(atomic::Ordering::Relaxed);

        // Keep one empty seat between the readers and the tail, so readers can distinguish a
        // full ring from an empty one. We may only write at `tail` once the seat after it (the
        // fence) has been fully read.
        let fence = (tail + 1) % self.state.len;

        let spintime = time::Duration::new(0, SPINTIME);
        // Spin with back-off first, so a seat that frees up quickly does not cost a park.
        let mut sw = SpinWait::new();
        loop {
            let fence_read = self.state.ring[fence].read.load(atomic::Ordering::Acquire);

            // Is there room left in the ring?
            if fence_read == self.expected(fence) {
                break;
            }

            // No. Account for readers that have left. Every seat from a departed reader's head
            // up to `tail` counted that reader in its `max`, so record one missing read on each.
            while let Ok(mut left) = self.leaving.1.try_recv() {
                self.readers -= 1;
                while left != tail {
                    self.rleft[left] += 1;
                    left = (left + 1) % self.state.len
                }
            }

            // Re-check against the possibly lowered expected count.
            if fence_read == self.expected(fence) {
                break;
            } else if block {
                // Register ourselves so that the last reader of the fence seat unparks us.
                self.state.ring[fence]
                    .waiting
                    .swap(Some(Box::new(thread::current())));

                // No-op Release RMW on `read`, so that readers acquiring `read` afterwards also
                // see the `waiting` registration above.
                self.state.ring[fence]
                    .read
                    .fetch_add(0, atomic::Ordering::Release);

                if !sw.spin() {
                    // Park with a timeout rather than indefinitely. A reader that finished
                    // before we registered will never unpark us.
                    thread::park_timeout(spintime);
                }
                continue;
            } else {
                return Err(val);
            }
        }

        // The fence seat has been fully read. The protocol relies on this implying that the
        // seat at `tail` is free as well.
        let readers = self.readers;
        {
            let next = &self.state.ring[tail];
            // SAFETY: we are the only writer. Readers only access seats before `tail`, and the
            // fence check above ensures the previous occupant of this seat has been consumed.
            let state = unsafe { &mut *next.state.get() };
            state.max = readers;
            state.val = Some(val);
            next.waiting.take();
            next.read.store(0, atomic::Ordering::Release);
        }
        self.rleft[tail] = 0;

        // Publish the new value to readers.
        let tail = (tail + 1) % self.state.len;
        self.state.tail.store(tail, atomic::Ordering::Release);

        // Wake blocked readers. Those registered at the new `tail` have already consumed
        // everything up to it and are waiting for the *next* value, so keep them registered.
        while let Ok((t, at)) = self.waiting.1.try_recv() {
            if at == tail {
                self.cache.push((t, at))
            } else {
                self.unpark.send(t).unwrap();
            }
        }
        for w in self.cache.drain(..) {
            // Unbounded channel whose receiver we own: this cannot block or fail.
            self.waiting.0.send(w).unwrap();
        }

        Ok(())
    }

    /// Broadcasts `val` to all current readers without blocking.
    ///
    /// # Errors
    ///
    /// If the bus is full (the slowest reader has not yet consumed enough values to free a
    /// seat), `val` is handed back as `Err(val)`.
    pub fn try_broadcast(&mut self, val: T) -> Result<(), T> {
        self.broadcast_inner(val, false)
    }

    /// Broadcasts `val` to all current readers, blocking while the bus is full.
    ///
    /// The call returns once the slowest reader has made room, or has been dropped.
    pub fn broadcast(&mut self, val: T) {
        if let Err(..) = self.broadcast_inner(val, true) {
            unreachable!("blocking broadcast_inner can't fail");
        }
    }

    /// Creates a new reader that receives every value broadcast after this call.
    pub fn add_rx(&mut self) -> BusReader<T> {
        self.readers += 1;
        BusReader {
            bus: Arc::clone(&self.state),
            head: self.state.tail.load(atomic::Ordering::Relaxed),
            leaving: self.leaving.0.clone(),
            waiting: self.waiting.0.clone(),
            closed: false,
        }
    }
}
impl<T> Drop for Bus<T> {
    /// Marks the bus as closed. Readers keep receiving the values already broadcast and then
    /// report disconnection.
    fn drop(&mut self) {
        self.state.closed.store(true, atomic::Ordering::Relaxed);
        // The `closed` store above is Relaxed. This no-op AcqRel RMW on `tail` appears to exist
        // so that readers, which load `tail` with Acquire before checking `closed`, are
        // guaranteed to observe the flag. NOTE(review): confirm against `recv_inner`.
        self.state.tail.fetch_add(0, atomic::Ordering::AcqRel);
    }
}
/// How long a receive call is willing to wait for a value.
#[derive(Copy, Clone)]
enum RecvCondition {
    /// Return immediately if nothing is available.
    Try,
    /// Wait until a value arrives or the bus is closed.
    Block,
    /// Wait for at most the given duration.
    Timeout(time::Duration),
}
/// The receiving half of a [`Bus`].
///
/// A reader observes, in order, every value broadcast after it was created with
/// [`Bus::add_rx`]. Values are received with [`recv`](BusReader::recv),
/// [`try_recv`](BusReader::try_recv), [`recv_timeout`](BusReader::recv_timeout), or by
/// iterating. Dropping a reader tells the bus to stop waiting for it.
pub struct BusReader<T> {
    /// Ring state shared with the bus.
    bus: Arc<BusInner<T>>,
    /// Index of the next seat this reader will read.
    head: usize,
    /// Used on drop to tell the bus where this reader stopped.
    leaving: mpsc::Sender<usize>,
    /// Used to register this thread to be unparked after the next broadcast.
    waiting: mpsc::Sender<(thread::Thread, usize)>,
    /// Set once disconnection has been observed, so later calls fail immediately.
    closed: bool,
}

impl<T> fmt::Debug for BusReader<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BusReader")
            .field("bus", &self.bus)
            .field("head", &self.head)
            .field("leaving", &self.leaving)
            .field("waiting", &self.waiting)
            .field("closed", &self.closed)
            .finish()
    }
}
impl<T: Clone + Sync> BusReader<T> {
    /// Shared implementation of the receive methods; `block` selects how long to wait.
    fn recv_inner(&mut self, block: RecvCondition) -> Result<T, std_mpsc::RecvTimeoutError> {
        if self.closed {
            return Err(std_mpsc::RecvTimeoutError::Disconnected);
        }

        let start = match block {
            RecvCondition::Timeout(_) => Some(time::Instant::now()),
            _ => None,
        };

        let spintime = time::Duration::new(0, SPINTIME);
        let mut was_closed = false;
        let mut sw = SpinWait::new();
        let mut first = true;
        loop {
            let tail = self.bus.tail.load(atomic::Ordering::Acquire);
            if tail != self.head {
                break;
            }

            // Nothing to read. If the bus is closed, look at `tail` once more before giving up.
            // A final broadcast may have landed between our `tail` load and the bus being
            // dropped.
            if self.bus.closed.load(atomic::Ordering::Relaxed) {
                if !was_closed {
                    was_closed = true;
                    continue;
                }
                self.closed = true;
                return Err(std_mpsc::RecvTimeoutError::Disconnected);
            }

            if let RecvCondition::Try = block {
                return Err(std_mpsc::RecvTimeoutError::Timeout);
            }

            // Ask the bus to unpark us after its next broadcast (once per call).
            if first {
                if let Err(..) = self.waiting.send((thread::current(), self.head)) {
                    // The receiving end lives in the `Bus`, so the bus has been dropped. Go
                    // around again to observe `closed`.
                    atomic::fence(atomic::Ordering::SeqCst);
                    continue;
                }
                first = false;
            }

            if !sw.spin() {
                match block {
                    RecvCondition::Timeout(t) => {
                        match t.checked_sub(start.as_ref().unwrap().elapsed()) {
                            // Never park past the deadline, nor longer than one spin interval.
                            Some(left) => thread::park_timeout(left.min(spintime)),
                            None => return Err(std_mpsc::RecvTimeoutError::Timeout),
                        }
                    }
                    RecvCondition::Block => thread::park_timeout(spintime),
                    RecvCondition::Try => unreachable!(),
                }
            }
        }

        let head = self.head;
        let ret = self.bus.ring[head].take();
        self.head = (head + 1) % self.bus.len;
        Ok(ret)
    }

    /// Attempts to receive the next value without blocking.
    ///
    /// # Errors
    ///
    /// Returns [`TryRecvError::Empty`](std_mpsc::TryRecvError::Empty) if no value is available
    /// yet. Returns [`TryRecvError::Disconnected`](std_mpsc::TryRecvError::Disconnected) if the
    /// bus has been dropped and every value has been received.
    pub fn try_recv(&mut self) -> Result<T, std_mpsc::TryRecvError> {
        self.recv_inner(RecvCondition::Try).map_err(|e| match e {
            std_mpsc::RecvTimeoutError::Disconnected => std_mpsc::TryRecvError::Disconnected,
            std_mpsc::RecvTimeoutError::Timeout => std_mpsc::TryRecvError::Empty,
        })
    }

    /// Receives the next value, blocking until one is broadcast.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`](std_mpsc::RecvError) if the bus has been dropped and every value
    /// has been received.
    pub fn recv(&mut self) -> Result<T, std_mpsc::RecvError> {
        match self.recv_inner(RecvCondition::Block) {
            Ok(val) => Ok(val),
            Err(std_mpsc::RecvTimeoutError::Disconnected) => Err(std_mpsc::RecvError),
            _ => unreachable!("blocking recv_inner can't fail"),
        }
    }

    /// Receives the next value, blocking for at most `timeout`.
    ///
    /// # Errors
    ///
    /// Returns [`RecvTimeoutError::Timeout`](std_mpsc::RecvTimeoutError::Timeout) if no value
    /// arrived in time. Returns
    /// [`RecvTimeoutError::Disconnected`](std_mpsc::RecvTimeoutError::Disconnected) if the bus
    /// has been dropped and every value has been received.
    pub fn recv_timeout(
        &mut self,
        timeout: time::Duration,
    ) -> Result<T, std_mpsc::RecvTimeoutError> {
        self.recv_inner(RecvCondition::Timeout(timeout))
    }
}
impl<T> BusReader<T> {
    /// Returns a blocking iterator over the values this reader receives.
    ///
    /// Each step behaves like [`recv`](BusReader::recv). Iteration ends once the bus has been
    /// dropped and every value has been received.
    pub fn iter(&mut self) -> BusIter<'_, T> {
        BusIter(self)
    }
}
impl<T> Drop for BusReader<T> {
    /// Reports this reader's head to the bus, so the bus stops counting this reader on the
    /// seats it never got to read.
    fn drop(&mut self) {
        // If the bus is already gone there is nobody to notify, so a send error is ignored.
        let _ = self.leaving.send(self.head);
    }
}
/// A blocking iterator that borrows a [`BusReader`].
///
/// Created by [`BusReader::iter`] or by iterating over `&mut BusReader`.
pub struct BusIter<'a, T>(&'a mut BusReader<T>);

/// A blocking iterator that owns a [`BusReader`].
///
/// Created by calling `into_iter` on a `BusReader`.
pub struct BusIntoIter<T>(BusReader<T>);
impl<'a, T: Clone + Sync> IntoIterator for &'a mut BusReader<T> {
    type Item = T;
    type IntoIter = BusIter<'a, T>;

    /// Borrows the reader as a blocking iterator, the same as [`BusReader::iter`].
    fn into_iter(self) -> Self::IntoIter {
        BusIter(self)
    }
}

impl<T: Clone + Sync> IntoIterator for BusReader<T> {
    type Item = T;
    type IntoIter = BusIntoIter<T>;

    /// Consumes the reader, yielding values until the bus is closed and drained.
    fn into_iter(self) -> Self::IntoIter {
        BusIntoIter(self)
    }
}
impl<T: Clone + Sync> Iterator for BusIter<'_, T> {
    type Item = T;

    /// Blocks for the next value. Returns `None` once the bus is closed and drained.
    fn next(&mut self) -> Option<T> {
        match self.0.recv() {
            Ok(value) => Some(value),
            Err(_) => None,
        }
    }
}

impl<T: Clone + Sync> Iterator for BusIntoIter<T> {
    type Item = T;

    /// Blocks for the next value. Returns `None` once the bus is closed and drained.
    fn next(&mut self) -> Option<T> {
        match self.0.recv() {
            Ok(value) => Some(value),
            Err(_) => None,
        }
    }
}
/// A lock-free `Option<Box<T>>` cell: a nullable owning pointer that is swapped atomically.
struct AtomicOption<T> {
    /// Null for `None`; otherwise a pointer from `Box::into_raw` that this cell owns.
    ptr: atomic::AtomicPtr<T>,
    /// Tells drop-check and auto traits that we logically own an `Option<Box<T>>`.
    _marker: PhantomData<Option<Box<T>>>,
}

impl<T> fmt::Debug for AtomicOption<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut s = f.debug_struct("AtomicOption");
        s.field("ptr", &self.ptr);
        s.finish()
    }
}

// SAFETY: the cell only moves whole `Box<T>`s in and out and never hands out references to the
// pointee. Sending or sharing it across threads therefore only requires `T: Send`.
unsafe impl<T> Send for AtomicOption<T> where T: Send {}
unsafe impl<T> Sync for AtomicOption<T> where T: Send {}

impl<T> AtomicOption<T> {
    /// Creates an empty cell.
    fn empty() -> Self {
        AtomicOption {
            ptr: atomic::AtomicPtr::new(ptr::null_mut()),
            _marker: PhantomData,
        }
    }

    /// Stores `val` in the cell and returns whatever was stored before.
    fn swap(&self, val: Option<Box<T>>) -> Option<Box<T>> {
        let previous = if let Some(boxed) = val {
            // AcqRel: release the new box to later swappers and acquire the old one.
            self.ptr.swap(Box::into_raw(boxed), atomic::Ordering::AcqRel)
        } else {
            // Acquire: we take ownership of a box that another thread may have stored.
            self.ptr.swap(ptr::null_mut(), atomic::Ordering::Acquire)
        };
        // SAFETY: every non-null pointer in the cell came from `Box::into_raw`, and the swap
        // above transferred its ownership to us.
        (!previous.is_null()).then(|| unsafe { Box::from_raw(previous) })
    }

    /// Empties the cell, returning its previous contents.
    fn take(&self) -> Option<Box<T>> {
        self.swap(None)
    }
}

impl<T> Drop for AtomicOption<T> {
    fn drop(&mut self) {
        // Free any box still owned by the cell.
        self.take();
    }
}