use crate::*;
use std::{any::Any, marker::PhantomData};
#[cfg(feature = "eventual-fairness")]
use nanorand::RNG;
/// Identifies a selection by its index in a `Selector`'s `selections` list; fired signals push
/// their token onto the shared `signalled` queue so the waiting thread knows what to poll.
type Token = usize;
/// Signal registered with a channel on behalf of one selection of a waiting `Selector`.
///
/// Fields, in order:
/// 0. the thread that is unparked when the signal fires (the selections pass `thread::current()`),
/// 1. the token of the selection this signal belongs to,
/// 2. a flag that is set to `true` once the signal has fired,
/// 3. the queue of fired tokens shared with the owning `Selector`.
struct SelectSignal(
thread::Thread,
Token,
AtomicBool,
Arc<Spinlock<VecDeque<Token>>>,
);
impl Signal for SelectSignal {
fn fire(&self) -> bool {
self.2.store(true, Ordering::SeqCst);
self.3.lock().push_back(self.1);
self.0.unpark();
false
}
fn as_any(&self) -> &(dyn Any + 'static) {
self
}
fn as_ptr(&self) -> *const () {
self as *const _ as *const ()
}
}
/// One operation registered on a `Selector`; `T` is the value produced when the operation
/// finishes.
trait Selection<'a, T> {
/// Attempts the operation, registering a wake-up hook with the channel if it cannot finish
/// yet. Returns `Some` when the operation finished (or failed) without needing to wait.
fn init(&mut self) -> Option<T>;
/// Re-checks the operation after `init`; returns `Some` once a result is available.
fn poll(&mut self) -> Option<T>;
/// Removes the hook registered by `init`, if any.
fn deinit(&mut self);
}
/// An error that may be returned when waiting on a `Selector` with a timeout or deadline.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SelectError {
    /// No selection completed before the deadline passed.
    Timeout,
}

impl fmt::Display for SelectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            SelectError::Timeout => "timeout occurred",
        };
        // `pad` applies the formatter's width, fill and precision flags to the text.
        f.pad(description)
    }
}

impl std::error::Error for SelectError {}
/// Collects send and receive operations on channels and waits until one of them completes.
///
/// Operations are added with `send` and `recv`; the selector is then consumed by `wait`,
/// `wait_timeout` or `wait_deadline`. `T` is the type produced by every operation's mapper.
pub struct Selector<'a, T: 'a> {
// Registered operations; a selection's `Token` is its index in this list.
selections: Vec<Box<dyn Selection<'a, T> + 'a>>,
// Index of the selection that is tried first on the next round-robin pass.
next_poll: usize,
// Tokens of selections whose signal has fired; shared with every `SelectSignal`.
signalled: Arc<Spinlock<VecDeque<Token>>>,
// Random source used to pick the starting index when `eventual-fairness` is enabled.
#[cfg(feature = "eventual-fairness")]
rng: nanorand::WyRand,
// Raw-pointer marker, which is neither `Send` nor `Sync`.
// NOTE(review): presumably intentional because signals capture the waiting thread — confirm.
phantom: PhantomData<*const ()>,
}
impl<'a, T: 'a> Default for Selector<'a, T> {
fn default() -> Self {
Self::new()
}
}
impl<'a, T: 'a> fmt::Debug for Selector<'a, T> {
    /// Prints only the type name; none of the selector's fields are shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = formatter.debug_struct("Selector");
        repr.finish()
    }
}
impl<'a, T> Selector<'a, T> {
/// Creates a selector with no operations registered.
///
/// Add operations with `send` and `recv`, then run the selector with `wait`, `wait_timeout`
/// or `wait_deadline`.
pub fn new() -> Self {
    // Queue of fired tokens; every selection added later receives a clone of this handle.
    let signalled = Arc::default();
    Self {
        selections: Vec::new(),
        next_poll: 0,
        signalled,
        #[cfg(feature = "eventual-fairness")]
        rng: nanorand::WyRand::new(),
        phantom: PhantomData,
    }
}
/// Adds a send operation to the selector that sends `msg` through `sender`.
///
/// When this operation is the one that completes, `mapper` is called with `Ok(())` if the
/// message was sent, or with `Err(SendError(msg))` if the channel is disconnected, and its
/// return value becomes the result of the wait. The selector is returned so that further
/// operations can be chained.
pub fn send<U, F: FnMut(Result<(), SendError<U>>) -> T + 'a>(
    mut self,
    sender: &'a Sender<U>,
    msg: U,
    mapper: F,
) -> Self {
    /// State of one pending send registered on the selector.
    struct SendSelection<'a, T, F, U> {
        sender: &'a Sender<U>,
        // The message to send; taken by `init`.
        msg: Option<U>,
        token: Token,
        signalled: Arc<Spinlock<VecDeque<Token>>>,
        // Set by `init` when the send had to be parked on the channel.
        hook: Option<Arc<Hook<U, SelectSignal>>>,
        mapper: F,
        phantom: PhantomData<T>,
    }

    impl<'a, T, F, U> Selection<'a, T> for SendSelection<'a, T, F, U>
    where
        F: FnMut(Result<(), SendError<U>>) -> T,
    {
        fn init(&mut self) -> Option<T> {
            // Copied out up front so the signal-building closure does not need to borrow
            // `self`, which the blocking closure below borrows mutably.
            let token = self.token;
            let signalled = self.signalled.clone();
            let outcome = self.sender.shared.send(
                self.msg.take().unwrap(),
                true,
                |msg| {
                    let signal =
                        SelectSignal(thread::current(), token, AtomicBool::new(false), signalled);
                    Hook::slot(Some(msg), signal)
                },
                |hook| {
                    self.hook = Some(hook);
                    Ok(())
                },
            );

            // A stored hook means the send is parked on the channel: nothing to report yet.
            if self.hook.is_some() {
                return None;
            }

            let result = match outcome {
                Ok(()) => Ok(()),
                Err(TrySendTimeoutError::Disconnected(msg)) => Err(SendError(msg)),
                _ => unreachable!(),
            };
            Some((self.mapper)(result))
        }

        fn poll(&mut self) -> Option<T> {
            let disconnected = self.sender.shared.is_disconnected();
            let result = if disconnected {
                // Check the hook one last time: a message still in the slot was never sent.
                let hook = self.hook.as_ref()?;
                hook.try_take().map_or(Ok(()), |msg| Err(SendError(msg)))
            } else {
                let hook = self.hook.as_ref().unwrap();
                if !hook.is_empty() {
                    // The message is still waiting in the slot.
                    return None;
                }
                // The slot was emptied, so the message was sent.
                Ok(())
            };
            Some((self.mapper)(result))
        }

        fn deinit(&mut self) {
            let registered = match self.hook.take() {
                Some(registered) => registered,
                None => return,
            };
            let hook: Arc<Hook<U, dyn Signal>> = registered;
            // Remove our hook from the channel's list of parked senders, matching by address.
            let ours = hook.signal().as_ptr();
            wait_lock(&self.sender.shared.chan)
                .sending
                .as_mut()
                .unwrap()
                .1
                .retain(|pending| pending.signal().as_ptr() != ours);
        }
    }

    // The token is this selection's index in `selections`.
    let token = self.selections.len();
    let selection = SendSelection {
        sender,
        msg: Some(msg),
        token,
        signalled: self.signalled.clone(),
        hook: None,
        mapper,
        phantom: PhantomData,
    };
    self.selections.push(Box::new(selection));
    self
}
/// Adds a receive operation to the selector that receives from `receiver`.
///
/// When this operation is the one that completes, `mapper` is called with `Ok(msg)` for a
/// received message, or with `Err(RecvError::Disconnected)` if the channel is disconnected,
/// and its return value becomes the result of the wait. The selector is returned so that
/// further operations can be chained.
pub fn recv<U, F: FnMut(Result<U, RecvError>) -> T + 'a>(
    mut self,
    receiver: &'a Receiver<U>,
    mapper: F,
) -> Self {
    /// State of one pending receive registered on the selector.
    struct RecvSelection<'a, T, F, U> {
        receiver: &'a Receiver<U>,
        token: Token,
        signalled: Arc<Spinlock<VecDeque<Token>>>,
        // Set by `init` when the receive had to be parked on the channel.
        hook: Option<Arc<Hook<U, SelectSignal>>>,
        mapper: F,
        // Whether `poll` has taken a message from the channel.
        received: bool,
        phantom: PhantomData<T>,
    }

    impl<'a, T, F, U> Selection<'a, T> for RecvSelection<'a, T, F, U>
    where
        F: FnMut(Result<U, RecvError>) -> T,
    {
        fn init(&mut self) -> Option<T> {
            // Copied out up front so the signal-building closure does not need to borrow
            // `self`, which the blocking closure below borrows mutably.
            let token = self.token;
            let signalled = self.signalled.clone();
            let outcome = self.receiver.shared.recv(
                true,
                || {
                    let signal =
                        SelectSignal(thread::current(), token, AtomicBool::new(false), signalled);
                    Hook::trigger(signal)
                },
                |hook| {
                    self.hook = Some(hook);
                    Err(TryRecvTimeoutError::Timeout)
                },
            );

            // A stored hook means the receive is parked on the channel: nothing to report yet.
            if self.hook.is_some() {
                return None;
            }

            let result = match outcome {
                Ok(msg) => Ok(msg),
                Err(TryRecvTimeoutError::Disconnected) => Err(RecvError::Disconnected),
                _ => unreachable!(),
            };
            Some((self.mapper)(result))
        }

        fn poll(&mut self) -> Option<T> {
            // NOTE(review): the disconnect check is a separate read made after `try_recv`
            // has failed — confirm a message sent in between cannot be misreported.
            let result = match self.receiver.try_recv() {
                Ok(msg) => {
                    self.received = true;
                    Ok(msg)
                }
                Err(_) if self.receiver.shared.is_disconnected() => Err(RecvError::Disconnected),
                Err(_) => return None,
            };
            Some((self.mapper)(result))
        }

        fn deinit(&mut self) {
            let registered = match self.hook.take() {
                Some(registered) => registered,
                None => return,
            };
            let hook: Arc<Hook<U, dyn Signal>> = registered;

            // Remove our hook from the channel's list of waiting receivers, matching by address.
            let ours = hook.signal().as_ptr();
            wait_lock(&self.receiver.shared.chan)
                .waiting
                .retain(|pending| pending.signal().as_ptr() != ours);

            if self.received {
                return;
            }
            // Our signal fired but this selection never took a message: pass the wake-up on
            // so that a pending message is not left without a receiver being notified.
            let was_fired = hook
                .signal()
                .as_any()
                .downcast_ref::<SelectSignal>()
                .unwrap()
                .2
                .load(Ordering::SeqCst);
            if was_fired {
                wait_lock(&self.receiver.shared.chan).try_wake_receiver_if_pending();
            }
        }
    }

    // The token is this selection's index in `selections`.
    let token = self.selections.len();
    let selection = RecvSelection {
        receiver,
        token,
        signalled: self.signalled.clone(),
        hook: None,
        mapper,
        received: false,
        phantom: PhantomData,
    };
    self.selections.push(Box::new(selection));
    self
}
/// Runs the selection until one operation completes or `deadline` passes.
///
/// Returns `Some` with the mapped value of the first selection that completes, or `None` if
/// the deadline passes and a final poll finds nothing ready. Without a deadline this only
/// returns once a selection has produced a value. Every selection is deinitialised before
/// returning.
fn wait_inner(mut self, deadline: Option<Instant>) -> Option<T> {
    #[cfg(feature = "eventual-fairness")]
    {
        // Start from a random selection. An empty selector has no valid starting index, so
        // skip the draw instead of asking the RNG for a value in the empty range `[0, 0)`.
        if !self.selections.is_empty() {
            self.next_poll = self.rng.generate_range(0, self.selections.len());
        }
    }

    let res = 'outer: loop {
        // Initialise every selection, starting at `next_poll`; stop as soon as one of them
        // completes without having to wait.
        for _ in 0..self.selections.len() {
            if let Some(val) = self.selections[self.next_poll].init() {
                break 'outer Some(val);
            }
            self.next_poll = (self.next_poll + 1) % self.selections.len();
        }

        // Poll once before parking, in case something became ready during initialisation.
        if let Some(msg) = self.poll() {
            break 'outer Some(msg);
        }

        loop {
            // Sleep until a signal unparks this thread or the deadline is reached.
            if let Some(deadline) = deadline {
                if let Some(dur) = deadline.checked_duration_since(Instant::now()) {
                    thread::park_timeout(dur);
                }
            } else {
                thread::park();
            }

            // Deadline reached: give every selection one last chance, then give up.
            if deadline.map_or(false, |d| Instant::now() >= d) {
                break 'outer self.poll();
            }

            // No fired token means the wake-up was not caused by one of our signals.
            let token = match self.signalled.lock().pop_front() {
                Some(token) => token,
                None => continue,
            };

            // Poll only the selection whose signal fired.
            if let Some(msg) = self.selections[token].poll() {
                break 'outer Some(msg);
            }
        }
    };

    // Remove every hook that `init` registered, whatever the outcome.
    for s in &mut self.selections {
        s.deinit();
    }
    res
}
/// Polls every selection once, round-robin, starting at `next_poll`.
///
/// Returns the first value produced; `next_poll` is left pointing at the selection that
/// produced it. Returns `None` if no selection is ready.
fn poll(&mut self) -> Option<T> {
    let count = self.selections.len();
    for _ in 0..count {
        let polled = self.selections[self.next_poll].poll();
        if polled.is_some() {
            return polled;
        }
        self.next_poll = (self.next_poll + 1) % count;
    }
    None
}
/// Blocks until one of the registered operations completes and returns the value produced by
/// its mapper.
pub fn wait(self) -> T {
    let outcome = self.wait_inner(None);
    // Without a deadline `wait_inner` only returns after a selection has produced a value.
    outcome.unwrap()
}
/// Blocks until one of the registered operations completes or `dur` has elapsed.
///
/// # Errors
///
/// Returns `SelectError::Timeout` if no operation completed within `dur`.
pub fn wait_timeout(self, dur: Duration) -> Result<T, SelectError> {
    // `Instant::now() + dur` panics when the sum cannot be represented (e.g. a very large
    // `dur`). A deadline that far away can never be reached, so wait without one instead.
    let deadline = Instant::now().checked_add(dur);
    self.wait_inner(deadline).ok_or(SelectError::Timeout)
}
/// Blocks until one of the registered operations completes or `deadline` is reached.
///
/// # Errors
///
/// Returns `SelectError::Timeout` if no operation completed before `deadline`.
pub fn wait_deadline(self, deadline: Instant) -> Result<T, SelectError> {
    match self.wait_inner(Some(deadline)) {
        Some(value) => Ok(value),
        None => Err(SelectError::Timeout),
    }
}
}