// unbounded_spsc/select.rs

1#![allow(dead_code)]
2
3use std;
4use crate::{blocking, Receiver, RecvError, SelectionResult};
5
6/// A "receiver set" structure used to manage a set of receivers being selected
7/// over.
8pub struct Select {
9  inner   : std::cell::UnsafeCell <Inner>,
10  next_id : std::cell::Cell <usize>
11}
12impl !Send for Select {}
13
14/// Handle to a receiver which is currently a member of a `Select` set of
15/// receivers, used to keep the receiver in the set as well as to interact
16/// with the underlying receiver.
17pub struct Handle <'rx, T : Send + 'rx> {
18  /// The ID of this handle, used to compare against the return value of
19  /// `Select:::wait()`
20  id       : usize,
21  selector : *mut Inner,
22  next     : *mut Handle <'static, ()>,
23  prev     : *mut Handle <'static, ()>,
24  added    : bool,
25  packet   : &'rx (dyn Packet + 'rx),
26  // due to our fun transmutes, be sure to place this at the end. (nothing
27  // previous relies on T)
28  rx       : &'rx Receiver <T>
29}
30
31struct Inner {
32  head : *mut Handle <'static, ()>,
33  tail : *mut Handle <'static, ()>
34}
35
36struct HandleIter {
37  cur : *mut Handle <'static, ()>
38}
39
/// Outcome of handing a wake-up token to a packet via
/// `Packet::start_selection`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StartResult {
  /// The packet accepted the token; `Select::wait` moves on to the next
  /// packet and blocks once every packet has accepted.
  Installed,
  /// The packet refused the token; `Select::wait` withdraws the tokens given
  /// to earlier packets and returns this packet's handle ID immediately.
  Abort
}
45
46pub trait Packet {
47  fn can_recv        (&self) -> bool;
48  fn start_selection (&self, token : blocking::SignalToken) -> StartResult;
49  fn abort_selection (&self) -> bool;
50}
51
52impl Select {
53  /// New empty selection structure.
54  pub const fn new () -> Select {
55    Select {
56      inner: std::cell::UnsafeCell::new (Inner {
57        head: std::ptr::null_mut(),
58        tail: std::ptr::null_mut()
59      }),
60      next_id: std::cell::Cell::new (1)
61    }
62  }
63
64  /// New handle into this receiver set for a new receiver; does *not* add the
65  /// receiver to the receiver set, for that call `add` on the handle itself.
66  pub fn handle <'a, T> (&'a self, rx : &'a Receiver <T>) -> Handle <'a, T>
67    where T : Send
68  {
69    let id = self.next_id.get();
70    self.next_id.set (id + 1);
71    Handle {
72      id,
73      selector: self.inner.get(),
74      next:     std::ptr::null_mut(),
75      prev:     std::ptr::null_mut(),
76      added:    false,
77      rx,
78      packet:   rx
79    }
80  }
81
82  /// Wait for an "event" on this receiver set. Returns an ID that can be
83  /// queried against any active `Handle` structures `id` method. The handle
84  /// with the matching `id` will have some sort of "event" available on it:
85  /// either that data is available or the corresponding channel has been
86  /// closed.
87  pub fn wait (&self) -> usize {
88    self.wait2 (true)
89  }
90
91  /// Helper method for skipping the "preflight checks" during testing
92  fn wait2 (&self, do_preflight_checks : bool) -> usize {
93    unsafe {
94      // Stage 1: preflight checks
95      if do_preflight_checks {
96        for handle in self.iter() {
97          if (*handle).packet.can_recv() {
98            return (*handle).id();
99          }
100        }
101      }
102      // Stage 2: begin blocking process
103      let (wait_token, signal_token) = blocking::tokens();
104      for (i, handle) in self.iter().enumerate() {
105        match (*handle).packet.start_selection (signal_token.clone()) {
106          StartResult::Installed => {}
107          StartResult::Abort     => {
108            for handle in self.iter().take (i) {
109              (*handle).packet.abort_selection();
110            }
111            return (*handle).id;
112          }
113        }
114      }
115      // Stage 3: no message availble, actually block
116      wait_token.wait();
117      // Stage 4: must be a message; find it
118      let mut ready_id = usize::MAX;
119      for handle in self.iter() {
120        if (*handle).packet.abort_selection() {
121          ready_id = (*handle).id;
122        }
123      }
124
125      // must have found a ready receiver
126      assert_ne!(ready_id, usize::MAX);
127      ready_id
128    }
129  }
130
131  fn iter (&self) -> HandleIter {
132    HandleIter {
133      cur: unsafe { &*self.inner.get() }.head
134    }
135  }
136}
137
138impl std::fmt::Debug for Select {
139  fn fmt (&self, f : &mut std::fmt::Formatter) -> std::fmt::Result {
140    write!(f, "Select {{ .. }}")
141  }
142}
143
144impl Drop for Select {
145  fn drop (&mut self) {
146    unsafe {
147      assert!((&*self.inner.get()).head.is_null());
148      assert!((&*self.inner.get()).tail.is_null());
149    }
150  }
151}
152
153impl <'rx, T> Handle <'rx, T> where T : Send {
154  #[inline]
155  pub const fn id (&self) -> usize {
156    self.id
157  }
158
159  pub fn recv (&self) -> Result <T, RecvError> {
160    self.rx.recv()
161  }
162
163  /// Add this handle to the receiver set that the handle was created from.
164  pub unsafe fn add (&mut self) {
165    if self.added {
166      return
167    }
168
169    let selector = unsafe { &mut *self.selector };
170    let me = std::ptr::from_mut::<Handle <'rx, T>> (self) as *mut Handle <'static, ()>;
171    if selector.head.is_null() {
172      selector.head = me;
173    } else {
174      unsafe {
175        (*me).prev = selector.tail;
176        assert!((*me).next.is_null());
177        (*selector.tail).next = me;
178      }
179    }
180    selector.tail = me;
181
182    self.added = true;
183  }
184
185  /// Remove this handle from the receiver set.
186  pub unsafe fn remove (&mut self) {
187    if !self.added {
188      return
189    }
190
191    let selector = unsafe { &mut *self.selector };
192    let me = std::ptr::from_mut::<Handle <'rx, T>>(self) as *mut Handle <'static, ()>;
193    if self.prev.is_null() {
194      assert_eq!(selector.head, me);
195      selector.head = self.next;
196    } else {
197      unsafe { (*self.prev).next = self.next; }
198    }
199    if self.next.is_null() {
200      assert_eq!(selector.tail, me);
201      selector.tail = self.prev;
202    } else {
203      unsafe { (*self.next).prev = self.prev; }
204    }
205
206    self.next = std::ptr::null_mut();
207    self.prev = std::ptr::null_mut();
208    self.added = false;
209  }
210}
211
212impl <'rx, T> std::fmt::Debug for Handle <'rx, T> where T : Send + 'rx {
213  fn fmt (&self, f : &mut std::fmt::Formatter) -> std::fmt::Result {
214    write!(f, "Handle {{ .. }}")
215  }
216}
217
218impl <T> Drop for Handle <'_, T> where T : Send {
219  fn drop (&mut self) {
220    unsafe { self.remove() }
221  }
222}
223
224impl Iterator for HandleIter {
225  type Item = *mut Handle <'static, ()>;
226  fn next (&mut self) -> Option <*mut Handle <'static, ()>> {
227    if self.cur.is_null() {
228      None
229    } else {
230      let ret = Some (self.cur);
231      unsafe {
232        self.cur = (*self.cur).next;
233      }
234      ret
235    }
236  }
237}
238
239impl <T> Packet for Receiver <T> {
240  #[inline]
241  fn can_recv (&self) -> bool {
242    self.can_recv_()
243  }
244  #[inline]
245  fn start_selection (&self, token : blocking::SignalToken) -> StartResult {
246    match self.start_selection_ (token) {
247      SelectionResult::SelSuccess  => StartResult::Installed,
248      SelectionResult::SelCanceled => StartResult::Abort
249    }
250  }
251  #[inline]
252  fn abort_selection (&self) -> bool {
253    self.abort_selection_()
254  }
255}
256
/// Blocks until one of several receivers has an event, then evaluates the
/// arm belonging to that receiver.
///
/// Each arm has the form `pattern = rx.method() => expression`. The macro:
/// 1. builds a `Select`;
/// 2. creates and adds a handle for every `rx`;
/// 3. waits for an event;
/// 4. for the ready receiver, binds `pattern` to the result of calling
///    `method` on its handle (normally `recv`) and evaluates `expression`.
///
/// The whole invocation evaluates to the chosen arm's value. A trailing comma
/// after the last arm is accepted.
///
/// `Select` is named unqualified, so it must be in scope where the macro is
/// invoked. Inside the arms, each `rx` identifier refers to its `Handle`
/// rather than to the original receiver.
#[macro_export]
macro_rules! select {
  (
    $($name:pat = $rx:ident.$meth:ident() => $code:expr),+ $(,)?
  ) => {{
    let sel = Select::new();
    $(
    let mut $rx = sel.handle (&$rx);
    )+
    // SAFETY: the handles are locals of this block, declared after `sel`, so
    // they are dropped (and unlink themselves) before `sel` is.
    unsafe {
      $($rx.add();)+
    }
    let ret = sel.wait();
    $(
    if ret == $rx.id() {
      let $name = $rx.$meth(); $code
    } else
    )+
    { unreachable!() }
  }}
}
278
#[cfg(test)]
mod tests {
  //! Tests for `Select` and the `select!` macro. Tests that spawn a helper
  //! thread join it at the end, so a panic on that thread fails the test
  //! instead of being silently discarded.
  use super::*;
  use super::super::*;

  #[test]
  fn smoke() {
    let (tx1, rx1) = channel::<i32>();
    let (tx2, rx2) = channel::<i32>();
    tx1.send (1).unwrap();
    select! {
      foo = rx1.recv() => { assert_eq!(foo.unwrap(), 1); },
      _bar = rx2.recv() => panic!()
    }
    tx2.send (2).unwrap();
    select! {
      _foo = rx1.recv() => panic!(),
      bar = rx2.recv() => assert_eq!(bar.unwrap(), 2)
    }
    drop(tx1);
    select! {
      foo = rx1.recv() => { foo.unwrap_err(); },
      _bar = rx2.recv() => panic!()
    }
    drop(tx2);
    select! {
      bar = rx2.recv() => { bar.unwrap_err(); }
    }
  }

  #[test]
  fn smoke2() {
    let (_tx1, rx1) = channel::<i32>();
    let (_tx2, rx2) = channel::<i32>();
    let (_tx3, rx3) = channel::<i32>();
    let (_tx4, rx4) = channel::<i32>();
    let (tx5, rx5) = channel::<i32>();
    tx5.send (4).unwrap();
    select! {
      _foo = rx1.recv() => panic!("1"),
      _foo = rx2.recv() => panic!("2"),
      _foo = rx3.recv() => panic!("3"),
      _foo = rx4.recv() => panic!("4"),
      foo = rx5.recv() => { assert_eq!(foo.unwrap(), 4); }
    }
  }

  #[test]
  fn closed() {
    let (_tx1, rx1) = channel::<i32>();
    let (tx2, rx2) = channel::<i32>();
    drop(tx2);

    select! {
      _a1 = rx1.recv() => panic!(),
      a2 = rx2.recv() => { a2.unwrap_err(); }
    }
  }

  #[test]
  fn unblocks() {
    let (tx1, rx1) = channel::<i32>();
    let (_tx2, rx2) = channel::<i32>();
    let (tx3, rx3) = channel::<i32>();

    let t = std::thread::spawn(move|| {
      for _ in 0..20 { std::thread::yield_now(); }
      tx1.send (1).unwrap();
      rx3.recv().unwrap();
      for _ in 0..20 { std::thread::yield_now(); }
    });

    select! {
      a = rx1.recv() => { assert_eq!(a.unwrap(), 1); },
      _b = rx2.recv() => panic!()
    }
    tx3.send (1).unwrap();
    // The helper exits (dropping `tx1`) after its `rx3.recv()`, so `rx1`
    // eventually reports disconnection.
    select! {
      a = rx1.recv() => assert!(a.is_err()),
      _b = rx2.recv() => panic!()
    }
    t.join().unwrap();
  }

  #[test]
  fn both_ready() {
    let (tx1, rx1) = channel::<i32>();
    let (tx2, rx2) = channel::<i32>();
    let (tx3, rx3) = channel::<bool>();

    let t = std::thread::spawn(move|| {
      for _ in 0..20 { std::thread::yield_now(); }
      tx1.send (1).unwrap();
      tx2.send (2).unwrap();
      rx3.recv().unwrap();
    });

    select! {
      a = rx1.recv() => { assert_eq!(a.unwrap(), 1); },
      a = rx2.recv() => { assert_eq!(a.unwrap(), 2); }
    }
    select! {
      a = rx1.recv() => { assert_eq!(a.unwrap(), 1); },
      a = rx2.recv() => { assert_eq!(a.unwrap(), 2); }
    }
    // The helper is still parked on `rx3`, so both senders are alive and the
    // channels are merely empty, not disconnected.
    assert_eq!(rx1.try_recv(), Err (TryRecvError::Empty));
    assert_eq!(rx2.try_recv(), Err (TryRecvError::Empty));
    tx3.send (true).unwrap();
    t.join().unwrap();
  }

  #[test]
  fn stress() {
    const AMT: i32 = 10000;
    let (tx1, rx1) = channel::<i32>();
    let (tx2, rx2) = channel::<i32>();
    let (tx3, rx3) = channel::<bool>();

    let t = std::thread::spawn(move|| {
      for i in 0..AMT {
        if i % 2 == 0 {
          tx1.send (i).unwrap();
        } else {
          tx2.send (i).unwrap();
        }
        rx3.recv().unwrap();
      }
    });

    for i in 0..AMT {
      select! {
        i1 = rx1.recv() => { assert!(i % 2 == 0 && i == i1.unwrap()); },
        i2 = rx2.recv() => { assert!(i % 2 == 1 && i == i2.unwrap()); }
      }
      tx3.send (true).unwrap();
    }
    t.join().unwrap();
  }

  #[test]
  fn preflight1() {
    let (tx, rx) = channel();
    tx.send (true).unwrap();
    select! {
      _n = rx.recv() => {}
    }
  }

  #[test]
  fn preflight2() {
    let (tx, rx) = channel();
    tx.send (true).unwrap();
    tx.send (true).unwrap();
    select! {
      _n = rx.recv() => {}
    }
  }

  // The `preflight4..8` tests call `wait2(false)` to skip the readiness scan
  // and exercise the blocking path when the receiver is already ready.
  #[test]
  fn preflight4() {
    let (tx, rx) = channel();
    tx.send (true).unwrap();
    let s = Select::new();
    let mut h = s.handle (&rx);
    unsafe { h.add(); }
    assert_eq!(s.wait2 (false), h.id);
  }

  #[test]
  fn preflight5() {
    let (tx, rx) = channel();
    tx.send (true).unwrap();
    tx.send (true).unwrap();
    let s = Select::new();
    let mut h = s.handle(&rx);
    unsafe { h.add(); }
    assert_eq!(s.wait2 (false), h.id);
  }

  #[test]
  fn preflight7() {
    let (tx, rx) = channel::<bool>();
    drop(tx);
    let s = Select::new();
    let mut h = s.handle(&rx);
    unsafe { h.add(); }
    assert_eq!(s.wait2 (false), h.id);
  }

  #[test]
  fn preflight8() {
    let (tx, rx) = channel();
    tx.send (true).unwrap();
    drop(tx);
    rx.recv().unwrap();
    let s = Select::new();
    let mut h = s.handle(&rx);
    unsafe { h.add(); }
    assert_eq!(s.wait2 (false), h.id);
  }

  #[test]
  fn oneshot_data_waiting() {
    let (tx1, rx1) = channel();
    let (tx2, rx2) = channel();
    let t = std::thread::spawn(move|| {
      select! {
        _n = rx1.recv() => {}
      }
      tx2.send (true).unwrap();
    });

    for _ in 0..100 { std::thread::yield_now() }
    tx1.send (true).unwrap();
    rx2.recv().unwrap();
    t.join().unwrap();
  }

  #[test]
  fn stream_data_waiting() {
    let (tx1, rx1) = channel();
    let (tx2, rx2) = channel();
    tx1.send (true).unwrap();
    tx1.send (true).unwrap();
    rx1.recv().unwrap();
    rx1.recv().unwrap();
    let t = std::thread::spawn(move|| {
      select! {
        _n = rx1.recv() => {}
      }
      tx2.send (true).unwrap();
    });

    for _ in 0..100 { std::thread::yield_now() }
    tx1.send (true).unwrap();
    rx2.recv().unwrap();
    t.join().unwrap();
  }

  #[test]
  fn fmt_debug_select() {
    let sel = Select::new();
    assert_eq!(format!("{sel:?}"), "Select { .. }");
  }

  #[test]
  fn fmt_debug_handle() {
    let (_, rx) = channel::<i32>();
    let sel = Select::new();
    let handle = sel.handle(&rx);
    assert_eq!(format!("{handle:?}"), "Handle { .. }");
  }
}