use std::{
sync::Arc,
sync::atomic::{AtomicUsize, Ordering},
thread,
time::{Duration, Instant},
};
use crate::{SelectableReceiver, SelectableSender, internals::UNSELECTED};
/// Object-safe interface through which `Select` drives a receive arm.
trait SelectOpTrait {
    /// Non-blocking readiness probe; `Select` returns a ready arm without parking.
    fn is_ready(&self) -> bool;

    /// Registers this arm as case `case` against the shared `slot`; `Select`
    /// later reads `slot` to learn which case won.
    fn register(&self, case: usize, slot: Arc<AtomicUsize>);

    /// Withdraws whatever registration was made against `slot`.
    fn abort(&self, slot: &Arc<AtomicUsize>);
}
/// Object-safe interface through which `Select` drives a send arm.
trait SendOpTrait {
    /// Non-blocking readiness probe; `Select` returns a ready arm without parking.
    fn is_ready(&self) -> bool;

    /// Registers this arm as case `case` against the shared `slot`; `Select`
    /// later reads `slot` to learn which case won.
    fn register(&self, case: usize, slot: Arc<AtomicUsize>);

    /// Withdraws whatever registration was made against `slot`.
    fn abort(&self, slot: &Arc<AtomicUsize>);
}
/// Type-erased arm stored by `Select`: either a receive or a send operation.
enum AnyOp {
    /// Receive arm; `Select::recv` builds it from a `SelectOp`.
    Recv(Box<dyn SelectOpTrait + Sync + Send>),
    /// Send arm; `Select::send` builds it from a `SendOp`.
    Send(Box<dyn SendOpTrait + Sync + Send>),
}
impl AnyOp {
fn is_ready(&self) -> bool {
match self {
AnyOp::Recv(op) => op.is_ready(),
AnyOp::Send(op) => op.is_ready(),
}
}
fn register(&self, case_id: usize, selected: Arc<AtomicUsize>) {
match self {
AnyOp::Recv(op) => op.register(case_id, selected),
AnyOp::Send(op) => op.register(case_id, selected),
}
}
fn abort(&self, selected: &Arc<AtomicUsize>) {
match self {
AnyOp::Recv(op) => op.abort(selected),
AnyOp::Send(op) => op.abort(selected),
}
}
}
/// Adapter that lets a selectable receiver be stored as a `SelectOpTrait` object.
///
/// The struct itself places no bound on `R`; the `SelectOpTrait` impl carries
/// the `SelectableReceiver` requirement, so the bound is stated only where it
/// is actually used.
struct SelectOp<R> {
    /// The receiving handle this arm polls, registers and aborts.
    receiver: R,
}
impl<R: SelectableReceiver + Send + Sync + 'static> SelectOpTrait for SelectOp<R>
where
R::Output: Send,
{
fn is_ready(&self) -> bool {
self.receiver.is_ready()
}
fn register(&self, case_id: usize, selected: Arc<AtomicUsize>) {
self.receiver.register_select(case_id, selected);
}
fn abort(&self, selected: &Arc<AtomicUsize>) {
self.receiver.abort_select(selected);
}
}
/// Adapter that lets a selectable sender be stored as a `SendOpTrait` object.
///
/// The struct itself places no bound on `S`; the `SendOpTrait` impl carries
/// the `SelectableSender` requirement, so the bound is stated only where it is
/// actually used.
struct SendOp<S> {
    /// The sending handle this arm polls, registers and aborts.
    sender: S,
}
impl<S: SelectableSender + Send + Sync + 'static> SendOpTrait for SendOp<S>
where
S::Input: Send,
{
fn is_ready(&self) -> bool {
self.sender.is_ready()
}
fn register(&self, case_id: usize, selected: Arc<AtomicUsize>) {
self.sender.register_select(case_id, selected);
}
fn abort(&self, selected: &Arc<AtomicUsize>) {
self.sender.abort_select(selected);
}
}
/// A dynamic selector over a set of send/receive arms.
///
/// Arms are added with `recv` / `send`, then raced with one of the `select*`
/// methods; the winning arm is reported by its registration index.
pub struct Select {
    /// Registered arms; an arm's position in this vector is its index.
    ops: Vec<AnyOp>,
}
/// The arm chosen by a `Select` operation.
///
/// `index` equals the value returned by `Select::recv` / `Select::send` when
/// the winning arm was registered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SelectedOperation {
    /// Zero-based registration index of the winning arm.
    pub index: usize,
}
/// Process-wide rotation counter. Every polling round takes the next value
/// (modulo the arm count) as its starting arm, so simultaneously ready arms are
/// picked in turn instead of always favouring the lowest index.
static FAIRNESS_CTR: AtomicUsize = AtomicUsize::new(0);
impl Select {
    /// Creates a selector with no registered arms.
    pub fn new() -> Self {
        Select { ops: Vec::new() }
    }

    /// Adds a receive arm and returns its index.
    ///
    /// Arms are numbered from 0 in registration order; this is the value that
    /// `SelectedOperation::index` carries when the arm is chosen. `rx` is stored
    /// by value and never cloned, so no `Clone` bound is required.
    pub fn recv<R>(&mut self, rx: R) -> usize
    where
        R: SelectableReceiver + Send + Sync + 'static,
        R::Output: Send + 'static,
    {
        let idx = self.ops.len();
        self.ops.push(AnyOp::Recv(Box::new(SelectOp { receiver: rx })));
        idx
    }

    /// Adds a send arm and returns its index (numbered like `recv`).
    ///
    /// Only readiness is tracked here; the `select!` macro performs the actual
    /// transfer through `SelectableSender::complete_send` once this arm wins.
    pub fn send<S>(&mut self, tx: S) -> usize
    where
        S: SelectableSender + Send + Sync + 'static,
        S::Input: Send + 'static,
    {
        let idx = self.ops.len();
        self.ops.push(AnyOp::Send(Box::new(SendOp { sender: tx })));
        idx
    }

    /// Blocks until one arm is selected and returns it.
    ///
    /// # Panics
    ///
    /// Panics if no arms have been registered.
    pub fn select(&mut self) -> SelectedOperation {
        self.select_impl(None)
            .expect("select_impl without a deadline only returns once an arm is selected")
    }

    /// Returns an arm that is ready right now, or `None` without blocking.
    ///
    /// # Panics
    ///
    /// Panics if no arms have been registered.
    pub fn try_select(&mut self) -> Option<SelectedOperation> {
        self.select_impl(Some(Instant::now()))
    }

    /// Waits at most `timeout` for an arm to be selected.
    ///
    /// If `timeout` is so large that the deadline cannot be represented as an
    /// `Instant`, this waits without a deadline (like `select`) instead of
    /// panicking on the overflowing addition.
    ///
    /// # Panics
    ///
    /// Panics if no arms have been registered.
    pub fn select_timeout(&mut self, timeout: Duration) -> Option<SelectedOperation> {
        match Instant::now().checked_add(timeout) {
            Some(deadline) => self.select_impl(Some(deadline)),
            None => self.select_impl(None),
        }
    }

    /// Waits until `deadline` for an arm to be selected.
    ///
    /// # Panics
    ///
    /// Panics if no arms have been registered.
    pub fn select_deadline(&mut self, deadline: Instant) -> Option<SelectedOperation> {
        self.select_impl(Some(deadline))
    }

    /// Shared engine behind every public `select*` method.
    ///
    /// `deadline == None` waits indefinitely; `Some(d)` gives up and returns
    /// `None` once `Instant::now() >= d`.
    ///
    /// Each round:
    /// 1. polls every arm, starting at a rotating offset for fairness;
    /// 2. registers every arm against one fresh shared `selected` slot;
    /// 3. re-polls (same offset) to catch arms that became ready while registering;
    /// 4. parks, with a timeout if there is a deadline, unless the slot is already claimed;
    /// 5. calls `abort` on every arm, then returns the winner, gives up at the
    ///    deadline, or starts another round (e.g. after a spurious wakeup).
    ///
    /// # Panics
    ///
    /// Panics if no arms have been registered.
    fn select_impl(&mut self, deadline: Option<Instant>) -> Option<SelectedOperation> {
        assert!(!self.ops.is_empty(), "Select with no registered operations");
        let n = self.ops.len();
        log_debug!("select::select_impl: arms={}, deadline={:?}", n, deadline);
        loop {
            let start = FAIRNESS_CTR.fetch_add(1, Ordering::Relaxed) % n;
            log_debug!("select::try phase: start_index={}", start);

            // Phase 1: fast path, no registration needed.
            if let Some(idx) = self.first_ready_from(start) {
                log_debug!("select::try phase: ready idx={}", idx);
                return Some(SelectedOperation { index: idx });
            }
            if Self::deadline_reached(deadline) {
                log_debug!("select::select_impl: deadline reached before park");
                return None;
            }

            // Phase 2: every arm shares one slot so at most one case can win it.
            let selected = Arc::new(AtomicUsize::new(UNSELECTED));
            for (idx, op) in self.ops.iter().enumerate() {
                log_debug!("select::register phase: registering arm={}", idx);
                op.register(idx, Arc::clone(&selected));
            }
            log_debug!("select::register phase: registered {} waiters", n);

            // Phase 3: close the race between the first poll and registration.
            if let Some(idx) = self.first_ready_from(start) {
                log_debug!("select::recheck: arm {} became ready during register", idx);
                // If the slot was already claimed, the exchange fails and the
                // earlier claim stands; either way the slot is no longer UNSELECTED.
                let _ = selected.compare_exchange(
                    UNSELECTED,
                    idx,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                );
            }

            // Phase 4: sleep only if nobody has claimed the slot yet.
            if selected.load(Ordering::SeqCst) == UNSELECTED {
                match deadline {
                    None => {
                        log_debug!("select::park phase: parking indefinitely");
                        thread::park();
                    }
                    Some(dl) => {
                        let wait = dl.saturating_duration_since(Instant::now());
                        log_debug!("select::park phase: parking with timeout={:?}", wait);
                        thread::park_timeout(wait);
                    }
                }
            }

            // Phase 5: always deregister before returning or retrying.
            for op in &self.ops {
                op.abort(&selected);
            }

            let won = selected.load(Ordering::SeqCst);
            if won != UNSELECTED {
                log_debug!("select::winner: idx={}", won);
                return Some(SelectedOperation { index: won });
            }
            if Self::deadline_reached(deadline) {
                return None;
            }
        }
    }

    /// Returns the first ready arm, scanning all arms round-robin from `start`.
    fn first_ready_from(&self, start: usize) -> Option<usize> {
        let n = self.ops.len();
        (0..n)
            .map(|offset| (start + offset) % n)
            .find(|&idx| self.ops[idx].is_ready())
    }

    /// `true` once a deadline is set and has passed; `None` never expires.
    fn deadline_reached(deadline: Option<Instant>) -> bool {
        matches!(deadline, Some(d) if Instant::now() >= d)
    }
}
impl Default for Select {
fn default() -> Self {
Self::new()
}
}
/// Waits on a set of channel operations and evaluates the body of the arm that fires.
///
/// Arms (comma-separated, trailing comma allowed):
///
/// * `recv(rx) -> pat => body`: `pat` is bound to
///   `SelectableReceiver::complete(&rx)` when this arm wins.
/// * `send(tx, value) -> pat => body`: `pat` is bound to
///   `SelectableSender::complete_send(&tx, value)` when this arm wins; `value`
///   is only evaluated in that case.
///
/// The arm list may end with one fallback:
///
/// * `default => expr`: evaluated when no arm is ready right now
///   (`Select::try_select`).
/// * `default(duration) => expr`: evaluated when no arm is selected within
///   `duration` (`Select::select_timeout`).
///
/// Without a fallback the macro blocks in `Select::select`. Every channel
/// handle is registered via `.clone()`, so handles must implement `Clone`.
#[macro_export]
macro_rules! select {
    // Receive-only arms, blocking.
    ($(recv($rx:expr) -> $var:pat => $body:expr),+ $(,)?) => {{
        let mut __sel = $crate::Select::new();
        $( __sel.recv($rx.clone()); )+
        let __oper = __sel.select();
        let mut __n = 0usize;
        $crate::select!(@arm __oper __n $(, recv($rx) -> $var => $body)+)
    }};
    // Receive-only arms with an immediate `default`.
    ($(recv($rx:expr) -> $var:pat => $body:expr,)+ default => $def:expr $(,)?) => {{
        let mut __sel = $crate::Select::new();
        $( __sel.recv($rx.clone()); )+
        if let Some(__oper) = __sel.try_select() {
            let mut __n = 0usize;
            $crate::select!(@arm __oper __n $(, recv($rx) -> $var => $body)+)
        } else {
            $def
        }
    }};
    // Receive-only arms with a timed `default(dur)`.
    ($(recv($rx:expr) -> $var:pat => $body:expr,)+ default($dur:expr) => $def:expr $(,)?) => {{
        let mut __sel = $crate::Select::new();
        $( __sel.recv($rx.clone()); )+
        if let Some(__oper) = __sel.select_timeout($dur) {
            let mut __n = 0usize;
            $crate::select!(@arm __oper __n $(, recv($rx) -> $var => $body)+)
        } else {
            $def
        }
    }};
    // A single send arm with an immediate `default`.
    (
        send ($tx:expr, $val:expr) -> $var:pat => $body:expr,
        default => $def:expr $(,)?
    ) => {{
        let mut __sel = $crate::Select::new();
        __sel.send(($tx).clone());
        if let Some(__oper) = __sel.try_select() {
            let mut __n = 0usize;
            $crate::select!(@arm_new __oper __n, send ($tx, $val) -> $var => $body)
        } else {
            $def
        }
    }};
    // A single send arm with a timed `default(dur)`.
    (
        send ($tx:expr, $val:expr) -> $var:pat => $body:expr,
        default($dur:expr) => $def:expr $(,)?
    ) => {{
        let mut __sel = $crate::Select::new();
        __sel.send(($tx).clone());
        if let Some(__oper) = __sel.select_timeout($dur) {
            let mut __n = 0usize;
            $crate::select!(@arm_new __oper __n, send ($tx, $val) -> $var => $body)
        } else {
            $def
        }
    }};
    // Any mix of recv/send arms, blocking.
    (
        $fk:ident ($($fa:tt)*) -> $fv:pat => $fb:expr
        $(, $k:ident ($($a:tt)*) -> $v:pat => $b:expr)*
        $(,)?
    ) => {{
        let mut __sel = $crate::Select::new();
        $crate::select!(@register __sel, $fk ($($fa)*));
        $( $crate::select!(@register __sel, $k ($($a)*)); )*
        let __oper = __sel.select();
        let mut __n = 0usize;
        $crate::select!(@arm_new __oper __n,
            $fk ($($fa)*) -> $fv => $fb
            $(, $k ($($a)*) -> $v => $b)*
        )
    }};
    // Any mix of recv/send arms followed by `default` or `default(dur)`.
    //
    // This cannot be written as `$(, $k:ident (...) -> ...)* , default => ...`:
    // on reaching the `default` token the matcher could either start another
    // `$k:ident` repetition or match the literal `default`, and `macro_rules!`
    // rejects that with a hard "local ambiguity" error. Instead the arms are
    // peeled off one at a time into a bracketed accumulator (`@with_default`)
    // until `default` is the next token.
    (
        $fk:ident ($($fa:tt)*) -> $fv:pat => $fb:expr,
        $($rest:tt)+
    ) => {{
        $crate::select!(@with_default [$fk ($($fa)*) -> $fv => $fb,] $($rest)+)
    }};
    // Accumulator complete, immediate `default`.
    (@with_default [$($k:ident ($($a:tt)*) -> $v:pat => $b:expr,)+]
        default => $def:expr $(,)?
    ) => {{
        let mut __sel = $crate::Select::new();
        $( $crate::select!(@register __sel, $k ($($a)*)); )+
        if let Some(__oper) = __sel.try_select() {
            let mut __n = 0usize;
            $crate::select!(@arm_new __oper __n $(, $k ($($a)*) -> $v => $b)+)
        } else {
            $def
        }
    }};
    // Accumulator complete, timed `default(dur)`.
    (@with_default [$($k:ident ($($a:tt)*) -> $v:pat => $b:expr,)+]
        default($dur:expr) => $def:expr $(,)?
    ) => {{
        let mut __sel = $crate::Select::new();
        $( $crate::select!(@register __sel, $k ($($a)*)); )+
        if let Some(__oper) = __sel.select_timeout($dur) {
            let mut __n = 0usize;
            $crate::select!(@arm_new __oper __n $(, $k ($($a)*) -> $v => $b)+)
        } else {
            $def
        }
    }};
    // Move the next arm into the accumulator.
    (@with_default [$($arms:tt)*]
        $k:ident ($($a:tt)*) -> $v:pat => $b:expr, $($rest:tt)+
    ) => {{
        $crate::select!(@with_default [$($arms)* $k ($($a)*) -> $v => $b,] $($rest)+)
    }};
    // Registration helpers: map an arm kind to the matching `Select` method.
    (@register $sel:ident, recv ($rx:expr)) => {
        $sel.recv(($rx).clone());
    };
    (@register $sel:ident, send ($tx:expr, $val:expr)) => {
        $sel.send(($tx).clone());
    };
    // Dispatch for mixed arms: compare the winning index with a running counter
    // and complete only the matching operation.
    (@arm_new $oper:ident $n:ident,
        recv ($rx:expr) -> $var:pat => $body:expr
        $(, $k:ident ($($a:tt)*) -> $kv:pat => $kb:expr)+
    ) => {{
        let __i = $n; $n += 1;
        if $oper.index == __i {
            let $var = $crate::SelectableReceiver::complete(&($rx));
            $body
        } else {
            $crate::select!(@arm_new $oper $n $(, $k ($($a)*) -> $kv => $kb)+)
        }
    }};
    (@arm_new $oper:ident $n:ident,
        send ($tx:expr, $val:expr) -> $var:pat => $body:expr
        $(, $k:ident ($($a:tt)*) -> $kv:pat => $kb:expr)+
    ) => {{
        let __i = $n; $n += 1;
        if $oper.index == __i {
            let $var = $crate::SelectableSender::complete_send(&($tx), $val);
            $body
        } else {
            $crate::select!(@arm_new $oper $n $(, $k ($($a)*) -> $kv => $kb)+)
        }
    }};
    (@arm_new $oper:ident $n:ident, recv ($rx:expr) -> $var:pat => $body:expr) => {{
        let __i = $n;
        if $oper.index == __i {
            let $var = $crate::SelectableReceiver::complete(&($rx));
            $body
        } else {
            unreachable!(
                "select!: winning index {} >= arm count {}",
                $oper.index, __i + 1
            )
        }
    }};
    (@arm_new $oper:ident $n:ident, send ($tx:expr, $val:expr) -> $var:pat => $body:expr) => {{
        let __i = $n;
        if $oper.index == __i {
            let $var = $crate::SelectableSender::complete_send(&($tx), $val);
            $body
        } else {
            unreachable!(
                "select!: winning index {} >= arm count {}",
                $oper.index, __i + 1
            )
        }
    }};
    // Dispatch for receive-only arms (used by the first three rules).
    (@arm $oper:ident $n:ident,
        $kind:ident ($rx:expr) -> $var:pat => $body:expr,
        $($rest:tt)+
    ) => {{
        let __i = $n; $n += 1;
        if $oper.index == __i {
            let $var = $crate::SelectableReceiver::complete(&($rx));
            $body
        } else {
            $crate::select!(@arm $oper $n, $($rest)+)
        }
    }};
    (@arm $oper:ident $n:ident, $kind:ident ($rx:expr) -> $var:pat => $body:expr) => {{
        let __i = $n;
        if $oper.index == __i {
            let $var = $crate::SelectableReceiver::complete(&($rx));
            $body
        } else {
            unreachable!(
                "select!: winning index {} >= arm count {}",
                $oper.index, __i + 1
            )
        }
    }};
}
// Behavioural tests for `select!`, exercising recv and send arms over the
// crate's watch, unbounded, bounded, and rendezvous channels. Several tests use
// `thread::sleep` to order events between threads and are therefore timing-sensitive.
#[cfg(test)]
mod tests {
use std::{sync::mpsc, thread, time::Duration};
use crate::{bounded_mpmc, bounded_mpsc, rendezvous, unbounded_mpmc, unbounded_mpsc, watch};
// Nothing has been published on the watch channel, so the recv arm must not
// fire (it panics if it does) and the instant `default` branch runs; `done_tx`
// proves the default branch actually executed.
#[test]
fn watch_with_instant_default() {
let (_tx, rx) = watch::channel::<&str>();
let fired = mpsc::channel();
let (done_tx, done_rx) = fired;
select! {
recv(rx) -> _version => panic!("watch arm should not be ready yet"),
default => done_tx.send(true).unwrap(),
}
assert!(done_rx.recv().unwrap());
}
// Same setup with `default(dur)`: the select must wait out the full 20 ms
// timeout before the fallback runs.
#[test]
fn watch_with_timeout_default() {
let (_tx, rx) = watch::channel::<&str>();
let before = std::time::Instant::now();
select! {
recv(rx) -> _version => panic!("watch arm should not be ready yet"),
default(Duration::from_millis(20)) => {}
}
assert!(before.elapsed() >= Duration::from_millis(20));
}
// Blocking select over a watch arm and a message arm. The helper thread
// publishes the watch value ~10 ms before sending the message, so the watch arm
// is expected to win.
// NOTE(review): presumably `complete` on the watch receiver yields `Ok(version)`
// with version 1 after the first send — confirm. Relies on sleep ordering, so it
// could be flaky on a heavily loaded machine.
#[test]
fn mixed_recv_and_watch_blocking_select() {
let (tx_msg, rx_msg) = unbounded_mpmc::channel::<i32>();
let (tx_watch, watch_rx) = watch::channel::<&str>();
thread::spawn(move || {
thread::sleep(Duration::from_millis(10));
tx_watch.send("ready").unwrap();
thread::sleep(Duration::from_millis(10));
tx_msg.send(42).unwrap();
});
select! {
recv(watch_rx) -> version => assert_eq!(version, Ok(1)),
recv(rx_msg) -> _msg => panic!("message arm should lose this race"),
}
}
// A lone send arm on a channel with room must complete immediately, and the
// value must then be observable on the receiving side.
#[test]
fn test_send_arm_unbounded_mpmc() {
let (tx, rx) = unbounded_mpmc::channel::<i32>();
select! {
send(tx, 99) -> res => assert!(res.is_ok()),
}
assert_eq!(rx.try_recv().unwrap(), 99);
}
#[test]
fn test_send_arm_unbounded_mpsc() {
let (tx, rx) = unbounded_mpsc::channel::<i32>();
select! {
send(tx, 77) -> res => assert!(res.is_ok()),
}
assert_eq!(rx.try_recv().unwrap(), 77);
}
#[test]
fn test_send_arm_bounded_mpmc() {
let (tx, rx) = bounded_mpmc::channel::<i32>(4);
select! {
send(tx, 42) -> res => assert!(res.is_ok()),
}
assert_eq!(rx.try_recv().unwrap(), 42);
}
// A send arm on a watch channel; the published value is then read back via
// `borrow_arc`.
#[test]
fn test_send_arm_watch() {
let (tx, rx) = watch::channel::<i32>();
select! {
send(tx, 10) -> res => assert!(res.is_ok()),
}
let snapshot = rx.borrow_arc();
assert!(snapshot.is_some());
assert_eq!(*snapshot.unwrap(), 10);
}
// The capacity-1 channel is pre-filled, so the send arm cannot proceed until
// the helper thread drains one slot ~20 ms later; the select must block until then.
#[test]
fn test_send_arm_bounded_mpmc_blocks_when_full() {
let (tx, rx) = bounded_mpmc::channel::<i32>(1);
tx.send(1).unwrap();
let rx2 = rx.clone();
thread::spawn(move || {
thread::sleep(Duration::from_millis(20));
let _ = rx2.try_recv(); });
select! {
send(tx, 2) -> res => assert!(res.is_ok()),
}
}
// NOTE(review): the only receiver is dropped before the send, yet this test
// asserts `res.is_ok()`. The name suggests a disconnect error may be the
// intended outcome — confirm the intended `complete_send` semantics on a
// disconnected bounded channel.
#[test]
fn test_send_arm_disconnect_bounded() {
let (tx, rx) = bounded_mpmc::channel::<i32>(4);
drop(rx);
select! {
send(tx, 5) -> res => assert!(res.is_ok()),
}
}
// NOTE(review): same question as the bounded variant — the receiver is dropped
// but the assertion requires `Ok`.
#[test]
fn test_send_arm_disconnect_unbounded() {
let (tx, rx) = unbounded_mpmc::channel::<i32>();
drop(rx);
select! {
send(tx, 5) -> res => assert!(res.is_ok()),
}
}
// The receiving thread is started first and given ~15 ms to block in `recv`, so
// the rendezvous send has a partner; the received value is checked via `join`.
#[test]
fn test_send_arm_rendezvous() {
let (tx, rx) = rendezvous::channel::<i32>();
let handle = thread::spawn(move || rx.recv().unwrap());
thread::sleep(Duration::from_millis(15));
select! {
send(tx, 55) -> res => assert!(res.is_ok()),
}
assert_eq!(handle.join().unwrap(), 55);
}
// Both arms are ready (a message is queued; the unbounded output has room), so
// either may win; each arm asserts its own success.
#[test]
fn test_mixed_recv_send_select() {
let (tx_msg, rx_msg) = unbounded_mpmc::channel::<i32>();
let (tx_out, _rx_out) = unbounded_mpmc::channel::<i32>();
tx_msg.send(7).unwrap();
select! {
recv(rx_msg) -> msg => assert_eq!(msg.unwrap(), 7),
send(tx_out, 99) -> res => assert!(res.is_ok()),
}
}
// The capacity-1 channel is already full, and `_rx` keeps it connected, so the
// send arm is not ready and the instant `default` must run.
#[test]
fn test_send_arm_default_when_not_ready() {
let (tx, _rx) = bounded_mpsc::channel::<i32>(1);
tx.send(1).unwrap();
select! {
send(tx, 2) -> _res => panic!("send arm should not be ready"),
default => {}
}
}
// Same as above with `default(dur)`: the select must wait the full 30 ms.
#[test]
fn test_send_arm_timeout_default() {
let (tx, _rx) = bounded_mpsc::channel::<i32>(1);
tx.send(1).unwrap();
let before = std::time::Instant::now();
select! {
send(tx, 2) -> _res => panic!("send arm should not be ready"),
default(Duration::from_millis(30)) => {}
}
assert!(before.elapsed() >= Duration::from_millis(30));
}
// Only the first of three receivers has a message, so only that arm may fire.
#[test]
fn three_arm_select_fires_only_ready_arm() {
let (tx1, rx1) = unbounded_mpmc::channel::<i32>();
let (_tx2, rx2) = unbounded_mpmc::channel::<i32>();
let (_tx3, rx3) = unbounded_mpmc::channel::<i32>();
tx1.send(42).unwrap();
select! {
recv(rx1) -> msg => assert_eq!(msg.unwrap(), 42),
recv(rx2) -> _ => panic!("arm 2 must not fire"),
recv(rx3) -> _ => panic!("arm 3 must not fire"),
}
}
// Both arms stay ready for all ITERS selects; the rotating start index should
// let each arm win at least ITERS / 5 times.
// NOTE(review): FAIRNESS_CTR is a process-wide static also advanced by tests
// running in parallel, so the rotation seen here is not strictly alternating;
// the loose threshold tolerates that.
#[test]
fn fairness_distributes_roughly_evenly() {
const ITERS: usize = 200;
let (tx1, rx1) = unbounded_mpmc::channel::<()>();
let (tx2, rx2) = unbounded_mpmc::channel::<()>();
for _ in 0..ITERS {
tx1.send(()).unwrap();
tx2.send(()).unwrap();
}
let mut count = [0usize; 2];
for _ in 0..ITERS {
select! {
recv(rx1) -> _ => count[0] += 1,
recv(rx2) -> _ => count[1] += 1,
}
}
assert!(count[0] >= ITERS / 5, "arm 0 fired only {} times", count[0]);
assert!(count[1] >= ITERS / 5, "arm 1 fired only {} times", count[1]);
}
// The select must park until the message arrives ~20 ms later. The test expects
// the `never` receiver not to fire (presumably it is never ready — confirm
// against `bounded_mpmc::never`).
#[test]
fn blocking_select_waits_for_message() {
let (tx, rx) = unbounded_mpmc::channel::<i32>();
let never_rx = bounded_mpmc::never::<i32>();
thread::spawn(move || {
thread::sleep(Duration::from_millis(20));
tx.send(123).unwrap();
});
select! {
recv(rx) -> msg => assert_eq!(msg.unwrap(), 123),
recv(never_rx) -> _ => panic!("never arm must not fire"),
}
}
// NOTE(review): despite the name, `_rx_out` is still alive, so the send arm is
// not disconnected and both arms are ready; `got` is never asserted on, so the
// test only checks that whichever arm wins completes without panicking.
#[test]
fn recv_arm_wins_over_disconnected_send_arm() {
let (tx_msg, rx_msg) = unbounded_mpmc::channel::<i32>();
let (tx_out, _rx_out) = bounded_mpmc::channel::<i32>(4);
tx_msg.send(7).unwrap();
let mut got = 0i32;
select! {
recv(rx_msg) -> msg => got = msg.unwrap(),
send(tx_out, 99) -> _res => {},
}
let _ = rx_msg.try_recv(); let _ = got; }
// Only a `never` receiver is registered, so the select must fall back after the
// 40 ms timeout; the 500 ms upper bound guards against a hang. The unbounded
// `tx`/`rx` pair plays no part in the select.
#[test]
fn timeout_select_returns_within_deadline() {
let (tx, rx) = unbounded_mpmc::channel::<i32>();
drop(tx); let never_rx = bounded_mpmc::never::<i32>();
let before = std::time::Instant::now();
select! {
recv(never_rx) -> _ => panic!("should not fire"),
default(Duration::from_millis(40)) => {}
}
let _ = rx; let elapsed = before.elapsed();
assert!(elapsed >= Duration::from_millis(40));
assert!(elapsed < Duration::from_millis(500));
}
}