1#![allow(dead_code)]
2
3use std;
4use crate::{blocking, Receiver, RecvError, SelectionResult};
5
6pub struct Select {
9 inner : std::cell::UnsafeCell <Inner>,
10 next_id : std::cell::Cell <usize>
11}
12impl !Send for Select {}
13
14pub struct Handle <'rx, T : Send + 'rx> {
18 id : usize,
21 selector : *mut Inner,
22 next : *mut Handle <'static, ()>,
23 prev : *mut Handle <'static, ()>,
24 added : bool,
25 packet : &'rx (dyn Packet + 'rx),
26 rx : &'rx Receiver <T>
29}
30
31struct Inner {
32 head : *mut Handle <'static, ()>,
33 tail : *mut Handle <'static, ()>
34}
35
36struct HandleIter {
37 cur : *mut Handle <'static, ()>
38}
39
/// Outcome of asking a [`Packet`] to start a selection.
///
/// Derives `Debug`, `Clone` and `Copy` in addition to the equality traits so
/// that values can be printed, compared with `assert_eq!`, and passed around
/// freely; the enum is a plain two-state flag with no payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StartResult {
    /// The signal token was accepted; the selector goes on to the next handle.
    Installed,
    /// Selection must not proceed for this packet; the selector rolls back
    /// the handles it already visited and returns this handle's id.
    Abort,
}
45
46pub trait Packet {
47 fn can_recv (&self) -> bool;
48 fn start_selection (&self, token : blocking::SignalToken) -> StartResult;
49 fn abort_selection (&self) -> bool;
50}
51
52impl Select {
53 pub const fn new () -> Select {
55 Select {
56 inner: std::cell::UnsafeCell::new (Inner {
57 head: std::ptr::null_mut(),
58 tail: std::ptr::null_mut()
59 }),
60 next_id: std::cell::Cell::new (1)
61 }
62 }
63
64 pub fn handle <'a, T> (&'a self, rx : &'a Receiver <T>) -> Handle <'a, T>
67 where T : Send
68 {
69 let id = self.next_id.get();
70 self.next_id.set (id + 1);
71 Handle {
72 id,
73 selector: self.inner.get(),
74 next: std::ptr::null_mut(),
75 prev: std::ptr::null_mut(),
76 added: false,
77 rx,
78 packet: rx
79 }
80 }
81
82 pub fn wait (&self) -> usize {
88 self.wait2 (true)
89 }
90
91 fn wait2 (&self, do_preflight_checks : bool) -> usize {
93 unsafe {
94 if do_preflight_checks {
96 for handle in self.iter() {
97 if (*handle).packet.can_recv() {
98 return (*handle).id();
99 }
100 }
101 }
102 let (wait_token, signal_token) = blocking::tokens();
104 for (i, handle) in self.iter().enumerate() {
105 match (*handle).packet.start_selection (signal_token.clone()) {
106 StartResult::Installed => {}
107 StartResult::Abort => {
108 for handle in self.iter().take (i) {
109 (*handle).packet.abort_selection();
110 }
111 return (*handle).id;
112 }
113 }
114 }
115 wait_token.wait();
117 let mut ready_id = usize::MAX;
119 for handle in self.iter() {
120 if (*handle).packet.abort_selection() {
121 ready_id = (*handle).id;
122 }
123 }
124
125 assert_ne!(ready_id, usize::MAX);
127 ready_id
128 }
129 }
130
131 fn iter (&self) -> HandleIter {
132 HandleIter {
133 cur: unsafe { &*self.inner.get() }.head
134 }
135 }
136}
137
138impl std::fmt::Debug for Select {
139 fn fmt (&self, f : &mut std::fmt::Formatter) -> std::fmt::Result {
140 write!(f, "Select {{ .. }}")
141 }
142}
143
144impl Drop for Select {
145 fn drop (&mut self) {
146 unsafe {
147 assert!((&*self.inner.get()).head.is_null());
148 assert!((&*self.inner.get()).tail.is_null());
149 }
150 }
151}
152
153impl <'rx, T> Handle <'rx, T> where T : Send {
154 #[inline]
155 pub const fn id (&self) -> usize {
156 self.id
157 }
158
159 pub fn recv (&self) -> Result <T, RecvError> {
160 self.rx.recv()
161 }
162
163 pub unsafe fn add (&mut self) {
165 if self.added {
166 return
167 }
168
169 let selector = unsafe { &mut *self.selector };
170 let me = std::ptr::from_mut::<Handle <'rx, T>> (self) as *mut Handle <'static, ()>;
171 if selector.head.is_null() {
172 selector.head = me;
173 } else {
174 unsafe {
175 (*me).prev = selector.tail;
176 assert!((*me).next.is_null());
177 (*selector.tail).next = me;
178 }
179 }
180 selector.tail = me;
181
182 self.added = true;
183 }
184
185 pub unsafe fn remove (&mut self) {
187 if !self.added {
188 return
189 }
190
191 let selector = unsafe { &mut *self.selector };
192 let me = std::ptr::from_mut::<Handle <'rx, T>>(self) as *mut Handle <'static, ()>;
193 if self.prev.is_null() {
194 assert_eq!(selector.head, me);
195 selector.head = self.next;
196 } else {
197 unsafe { (*self.prev).next = self.next; }
198 }
199 if self.next.is_null() {
200 assert_eq!(selector.tail, me);
201 selector.tail = self.prev;
202 } else {
203 unsafe { (*self.next).prev = self.prev; }
204 }
205
206 self.next = std::ptr::null_mut();
207 self.prev = std::ptr::null_mut();
208 self.added = false;
209 }
210}
211
212impl <'rx, T> std::fmt::Debug for Handle <'rx, T> where T : Send + 'rx {
213 fn fmt (&self, f : &mut std::fmt::Formatter) -> std::fmt::Result {
214 write!(f, "Handle {{ .. }}")
215 }
216}
217
218impl <T> Drop for Handle <'_, T> where T : Send {
219 fn drop (&mut self) {
220 unsafe { self.remove() }
221 }
222}
223
224impl Iterator for HandleIter {
225 type Item = *mut Handle <'static, ()>;
226 fn next (&mut self) -> Option <*mut Handle <'static, ()>> {
227 if self.cur.is_null() {
228 None
229 } else {
230 let ret = Some (self.cur);
231 unsafe {
232 self.cur = (*self.cur).next;
233 }
234 ret
235 }
236 }
237}
238
239impl <T> Packet for Receiver <T> {
240 #[inline]
241 fn can_recv (&self) -> bool {
242 self.can_recv_()
243 }
244 #[inline]
245 fn start_selection (&self, token : blocking::SignalToken) -> StartResult {
246 match self.start_selection_ (token) {
247 SelectionResult::SelSuccess => StartResult::Installed,
248 SelectionResult::SelCanceled => StartResult::Abort
249 }
250 }
251 #[inline]
252 fn abort_selection (&self) -> bool {
253 self.abort_selection_()
254 }
255}
256
/// Waits on several receivers at once and runs the arm of the one that
/// became ready.
///
/// Each arm has the form `pattern = receiver.method() => expression`, where
/// `receiver` must be a plain identifier. Arms are separated by commas; a
/// trailing comma after the last arm is accepted.
///
/// Expansion, in order:
/// 1. a fresh `Select` is created;
/// 2. every receiver identifier is shadowed by a handle for it;
/// 3. all handles are added to the selector;
/// 4. `wait()` is called, and the arm whose handle id matches is run with
///    `pattern` bound to the result of `handle.method()`.
///
/// NOTE(review): `Select` is named without a `$crate::` path, so it has to
/// be in scope at the call site; the module path of this file is not visible
/// here, so the path was left as is.
#[macro_export]
macro_rules! select {
    (
        $($name:pat = $rx:ident.$meth:ident() => $code:expr),+ $(,)?
    ) => {{
        let sel = Select::new();
        $(
            let mut $rx = sel.handle(&$rx);
        )+
        // The handles are locals of this block and are not moved after
        // being added, which is what `Handle::add` requires.
        unsafe {
            $($rx.add();)+
        }
        let ret = sel.wait();
        $(
            if ret == $rx.id() {
                let $name = $rx.$meth(); $code
            } else
        )+
        // `wait` only ever returns the id of one of the handles above.
        { unreachable!() }
    }}
}
278
// Tests for `Select`, `Handle` and the `select!` macro.
//
// Helper threads are joined at the end of each test that spawns one, so a
// panic inside the helper is reported as a failure of that test instead of
// being silently dropped together with the `JoinHandle`.
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::*;

    /// Basic ready / closed detection on two channels.
    #[test]
    fn smoke() {
        let (tx1, rx1) = channel::<i32>();
        let (tx2, rx2) = channel::<i32>();
        tx1.send(1).unwrap();
        select! {
            foo = rx1.recv() => { assert_eq!(foo.unwrap(), 1); },
            _bar = rx2.recv() => panic!()
        }
        tx2.send(2).unwrap();
        select! {
            _foo = rx1.recv() => panic!(),
            bar = rx2.recv() => assert_eq!(bar.unwrap(), 2)
        }
        drop(tx1);
        select! {
            foo = rx1.recv() => { foo.unwrap_err(); },
            _bar = rx2.recv() => panic!()
        }
        drop(tx2);
        select! {
            bar = rx2.recv() => { bar.unwrap_err(); }
        }
    }

    /// Only the last of five receivers has data.
    #[test]
    fn smoke2() {
        let (_tx1, rx1) = channel::<i32>();
        let (_tx2, rx2) = channel::<i32>();
        let (_tx3, rx3) = channel::<i32>();
        let (_tx4, rx4) = channel::<i32>();
        let (tx5, rx5) = channel::<i32>();
        tx5.send(4).unwrap();
        select! {
            _foo = rx1.recv() => panic!("1"),
            _foo = rx2.recv() => panic!("2"),
            _foo = rx3.recv() => panic!("3"),
            _foo = rx4.recv() => panic!("4"),
            foo = rx5.recv() => { assert_eq!(foo.unwrap(), 4); }
        }
    }

    /// A receiver whose sender is gone counts as ready.
    #[test]
    fn closed() {
        let (_tx1, rx1) = channel::<i32>();
        let (tx2, rx2) = channel::<i32>();
        drop(tx2);

        select! {
            _a1 = rx1.recv() => panic!(),
            a2 = rx2.recv() => { a2.unwrap_err(); }
        }
    }

    /// A blocked select is woken by a send, and later by a disconnect.
    #[test]
    fn unblocks() {
        let (tx1, rx1) = channel::<i32>();
        let (_tx2, rx2) = channel::<i32>();
        let (tx3, rx3) = channel::<i32>();

        let helper = std::thread::spawn(move || {
            for _ in 0..20 {
                std::thread::yield_now();
            }
            tx1.send(1).unwrap();
            rx3.recv().unwrap();
            for _ in 0..20 {
                std::thread::yield_now();
            }
        });

        select! {
            a = rx1.recv() => { assert_eq!(a.unwrap(), 1); },
            _b = rx2.recv() => panic!()
        }
        tx3.send(1).unwrap();
        select! {
            a = rx1.recv() => assert!(a.is_err()),
            _b = rx2.recv() => panic!()
        }
        helper.join().unwrap();
    }

    /// Two receivers become ready; two selects drain exactly both.
    #[test]
    fn both_ready() {
        let (tx1, rx1) = channel::<i32>();
        let (tx2, rx2) = channel::<i32>();
        let (tx3, rx3) = channel::<bool>();

        let helper = std::thread::spawn(move || {
            for _ in 0..20 {
                std::thread::yield_now();
            }
            tx1.send(1).unwrap();
            tx2.send(2).unwrap();
            rx3.recv().unwrap();
        });

        select! {
            a = rx1.recv() => { assert_eq!(a.unwrap(), 1); },
            a = rx2.recv() => { assert_eq!(a.unwrap(), 2); }
        }
        select! {
            a = rx1.recv() => { assert_eq!(a.unwrap(), 1); },
            a = rx2.recv() => { assert_eq!(a.unwrap(), 2); }
        }
        assert_eq!(rx1.try_recv(), Err(TryRecvError::Empty));
        assert_eq!(rx2.try_recv(), Err(TryRecvError::Empty));
        tx3.send(true).unwrap();
        helper.join().unwrap();
    }

    /// Many alternating sends, each acknowledged before the next.
    #[test]
    fn stress() {
        const AMT: i32 = 10000;
        let (tx1, rx1) = channel::<i32>();
        let (tx2, rx2) = channel::<i32>();
        let (tx3, rx3) = channel::<bool>();

        let helper = std::thread::spawn(move || {
            for i in 0..AMT {
                if i % 2 == 0 {
                    tx1.send(i).unwrap();
                } else {
                    tx2.send(i).unwrap();
                }
                rx3.recv().unwrap();
            }
        });

        for i in 0..AMT {
            select! {
                i1 = rx1.recv() => { assert!(i % 2 == 0 && i == i1.unwrap()); },
                i2 = rx2.recv() => { assert!(i % 2 == 1 && i == i2.unwrap()); }
            }
            tx3.send(true).unwrap();
        }
        helper.join().unwrap();
    }

    #[test]
    fn preflight1() {
        let (tx, rx) = channel();
        tx.send(true).unwrap();
        select! {
            _n = rx.recv() => {}
        }
    }

    #[test]
    fn preflight2() {
        let (tx, rx) = channel();
        tx.send(true).unwrap();
        tx.send(true).unwrap();
        select! {
            _n = rx.recv() => {}
        }
    }

    // The remaining preflight tests call `wait2(false)` to skip the
    // readiness scan and exercise the token-installation path directly.

    #[test]
    fn preflight4() {
        let (tx, rx) = channel();
        tx.send(true).unwrap();
        let s = Select::new();
        let mut h = s.handle(&rx);
        unsafe {
            h.add();
        }
        assert_eq!(s.wait2(false), h.id);
    }

    #[test]
    fn preflight5() {
        let (tx, rx) = channel();
        tx.send(true).unwrap();
        tx.send(true).unwrap();
        let s = Select::new();
        let mut h = s.handle(&rx);
        unsafe {
            h.add();
        }
        assert_eq!(s.wait2(false), h.id);
    }

    #[test]
    fn preflight7() {
        let (tx, rx) = channel::<bool>();
        drop(tx);
        let s = Select::new();
        let mut h = s.handle(&rx);
        unsafe {
            h.add();
        }
        assert_eq!(s.wait2(false), h.id);
    }

    #[test]
    fn preflight8() {
        let (tx, rx) = channel();
        tx.send(true).unwrap();
        drop(tx);
        rx.recv().unwrap();
        let s = Select::new();
        let mut h = s.handle(&rx);
        unsafe {
            h.add();
        }
        assert_eq!(s.wait2(false), h.id);
    }

    /// Select blocks in a helper thread on a channel that has seen no
    /// traffic yet.
    #[test]
    fn oneshot_data_waiting() {
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let helper = std::thread::spawn(move || {
            select! {
                _n = rx1.recv() => {}
            }
            tx2.send(true).unwrap();
        });

        for _ in 0..100 {
            std::thread::yield_now()
        }
        tx1.send(true).unwrap();
        rx2.recv().unwrap();
        helper.join().unwrap();
    }

    /// Same as above, but the channel has already carried (and drained)
    /// two messages before the select starts.
    #[test]
    fn stream_data_waiting() {
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        tx1.send(true).unwrap();
        tx1.send(true).unwrap();
        rx1.recv().unwrap();
        rx1.recv().unwrap();
        let helper = std::thread::spawn(move || {
            select! {
                _n = rx1.recv() => {}
            }
            tx2.send(true).unwrap();
        });

        for _ in 0..100 {
            std::thread::yield_now()
        }
        tx1.send(true).unwrap();
        rx2.recv().unwrap();
        helper.join().unwrap();
    }

    #[test]
    fn fmt_debug_select() {
        let sel = Select::new();
        assert_eq!(format!("{sel:?}"), "Select { .. }");
    }

    #[test]
    fn fmt_debug_handle() {
        let (_, rx) = channel::<i32>();
        let sel = Select::new();
        let handle = sel.handle(&rx);
        assert_eq!(format!("{handle:?}"), "Handle { .. }");
    }
}