1use std::{
15 fmt,
16 sync::{
17 atomic::{AtomicU16, AtomicU8, Ordering},
18 Arc,
19 },
20 time::{Duration, Instant},
21};
22
23use event_listener::{Event as EventLib, Listener};
24
/// Message shared by every "no notifier left" wait error.
const WAIT_ERR_STR: &str = "No notifier available";

/// Error returned by a [`Waiter`] once every [`Notifier`] has been dropped.
pub struct WaitError;

impl fmt::Debug for WaitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(WAIT_ERR_STR)
    }
}

impl fmt::Display for WaitError {
    /// Renders exactly the same text as the `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}

impl std::error::Error for WaitError {}
42
43#[repr(u8)]
44pub enum WaitDeadlineError {
45 Deadline,
46 WaitError,
47}
48
49impl fmt::Display for WaitDeadlineError {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 write!(f, "{self:?}")
52 }
53}
54
55impl fmt::Debug for WaitDeadlineError {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 match self {
58 Self::Deadline => f.write_str("Deadline reached"),
59 Self::WaitError => f.write_str(WAIT_ERR_STR),
60 }
61 }
62}
63
64impl std::error::Error for WaitDeadlineError {}
65
66#[repr(u8)]
67pub enum WaitTimeoutError {
68 Timeout,
69 WaitError,
70}
71
72impl fmt::Display for WaitTimeoutError {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 write!(f, "{self:?}")
75 }
76}
77
78impl fmt::Debug for WaitTimeoutError {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 match self {
81 Self::Timeout => f.write_str("Timeout expired"),
82 Self::WaitError => f.write_str(WAIT_ERR_STR),
83 }
84 }
85}
86
87impl std::error::Error for WaitTimeoutError {}
88
/// Message carried by [`NotifyError`].
const NOTIFY_ERR_STR: &str = "No waiter available";

/// Error returned by [`Notifier::notify`] once every [`Waiter`] has been dropped.
pub struct NotifyError;

impl fmt::Debug for NotifyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(NOTIFY_ERR_STR)
    }
}

impl fmt::Display for NotifyError {
    /// Renders exactly the same text as the `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}

impl std::error::Error for NotifyError {}
105
106struct EventInner {
108 event: EventLib,
109 flag: AtomicU8,
110 notifiers: AtomicU16,
111 waiters: AtomicU16,
112}
113
/// Flag value: no notification pending and the event still open.
const UNSET: u8 = 0b00;
/// Flag bit set while a notification is pending.
const OK: u8 = 0b01;
/// Flag bit set when the last `Notifier` or the last `Waiter` is dropped.
const ERR: u8 = 0b10;

/// State observed (and pending notification consumed) by `EventInner::check`.
#[repr(u8)]
enum EventCheck {
    Unset = UNSET,
    Ok = OK,
    Err = ERR,
}

/// Outcome of raising the notification bit with `EventInner::set`.
#[repr(u8)]
enum EventSet {
    Ok = OK,
    Err = ERR,
}
130
131impl EventInner {
132 fn check(&self) -> EventCheck {
133 let f = self.flag.fetch_and(!OK, Ordering::SeqCst);
134 if f & ERR != 0 {
135 return EventCheck::Err;
136 }
137 if f == OK {
138 return EventCheck::Ok;
139 }
140 EventCheck::Unset
141 }
142
143 fn set(&self) -> EventSet {
144 let f = self.flag.fetch_or(OK, Ordering::SeqCst);
145 if f & ERR != 0 {
146 return EventSet::Err;
147 }
148 EventSet::Ok
149 }
150
151 fn err(&self) {
152 self.flag.store(ERR, Ordering::SeqCst);
153 }
154}
155
156pub fn new() -> (Notifier, Waiter) {
160 let inner = Arc::new(EventInner {
161 event: EventLib::new(),
162 flag: AtomicU8::new(UNSET),
163 notifiers: AtomicU16::new(1),
164 waiters: AtomicU16::new(1),
165 });
166 (Notifier(inner.clone()), Waiter(inner))
167}
168
169#[repr(transparent)]
171pub struct Notifier(Arc<EventInner>);
172
173impl Notifier {
174 #[inline]
176 pub fn notify(&self) -> Result<(), NotifyError> {
177 match self.0.set() {
179 EventSet::Ok => {
180 self.0.event.notify_additional_relaxed(1);
181 Ok(())
182 }
183 EventSet::Err => Err(NotifyError),
184 }
185 }
186}
187
188impl Clone for Notifier {
189 fn clone(&self) -> Self {
190 let n = self.0.notifiers.fetch_add(1, Ordering::SeqCst);
191 assert!(n != 0);
193 Self(self.0.clone())
194 }
195}
196
197impl Drop for Notifier {
198 fn drop(&mut self) {
199 let n = self.0.notifiers.fetch_sub(1, Ordering::SeqCst);
200 if n == 1 {
201 self.0.err();
203 self.0.event.notify(usize::MAX);
204 }
205 }
206}
207
208#[repr(transparent)]
209pub struct Waiter(Arc<EventInner>);
210
211impl Waiter {
212 #[inline]
214 pub async fn wait_async(&self) -> Result<(), WaitError> {
215 loop {
217 match self.0.check() {
219 EventCheck::Ok => break,
220 EventCheck::Unset => {}
221 EventCheck::Err => return Err(WaitError),
222 }
223
224 let listener = self.0.event.listen();
226
227 match self.0.check() {
229 EventCheck::Ok => break,
230 EventCheck::Unset => {}
231 EventCheck::Err => return Err(WaitError),
232 }
233
234 listener.await;
236 }
237
238 Ok(())
239 }
240
241 #[inline]
243 pub fn wait(&self) -> Result<(), WaitError> {
244 loop {
246 match self.0.check() {
248 EventCheck::Ok => break,
249 EventCheck::Unset => {}
250 EventCheck::Err => return Err(WaitError),
251 }
252
253 let listener = self.0.event.listen();
255
256 match self.0.check() {
258 EventCheck::Ok => break,
259 EventCheck::Unset => {}
260 EventCheck::Err => return Err(WaitError),
261 }
262
263 listener.wait();
265 }
266
267 Ok(())
268 }
269
270 #[inline]
272 pub fn wait_deadline(&self, deadline: Instant) -> Result<(), WaitDeadlineError> {
273 loop {
275 match self.0.check() {
277 EventCheck::Ok => break,
278 EventCheck::Unset => {}
279 EventCheck::Err => return Err(WaitDeadlineError::WaitError),
280 }
281
282 let listener = self.0.event.listen();
284
285 match self.0.check() {
287 EventCheck::Ok => break,
288 EventCheck::Unset => {}
289 EventCheck::Err => return Err(WaitDeadlineError::WaitError),
290 }
291
292 if listener.wait_deadline(deadline).is_none() {
294 return Err(WaitDeadlineError::Deadline);
295 }
296 }
297
298 Ok(())
299 }
300
301 #[inline]
303 pub fn wait_timeout(&self, timeout: Duration) -> Result<(), WaitTimeoutError> {
304 loop {
306 match self.0.check() {
308 EventCheck::Ok => break,
309 EventCheck::Unset => {}
310 EventCheck::Err => return Err(WaitTimeoutError::WaitError),
311 }
312
313 let listener = self.0.event.listen();
315
316 match self.0.check() {
318 EventCheck::Ok => break,
319 EventCheck::Unset => {}
320 EventCheck::Err => return Err(WaitTimeoutError::WaitError),
321 }
322
323 if listener.wait_timeout(timeout).is_none() {
325 return Err(WaitTimeoutError::Timeout);
326 }
327 }
328
329 Ok(())
330 }
331}
332
333impl Clone for Waiter {
334 fn clone(&self) -> Self {
335 let n = self.0.waiters.fetch_add(1, Ordering::Relaxed);
336 assert!(n != 0);
338 Self(self.0.clone())
339 }
340}
341
342impl Drop for Waiter {
343 fn drop(&mut self) {
344 let n = self.0.waiters.fetch_sub(1, Ordering::SeqCst);
345 if n == 1 {
346 self.0.err();
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc, Barrier,
        },
        time::{Duration, Instant},
    };

    use super::{WaitDeadlineError, WaitTimeoutError};

    /// Polls `counter` every 100 ms until it reaches `target`.
    ///
    /// Panics (failing the calling test) if that takes longer than 60 seconds. It must run on
    /// the test's own thread, because a panic in a detached thread is not reported as a test
    /// failure.
    fn wait_for_counter(counter: &AtomicUsize, target: usize) {
        let start = Instant::now();
        let tout = Duration::from_secs(60);
        loop {
            let n = counter.load(Ordering::Relaxed);
            if n >= target {
                break;
            }
            if start.elapsed() > tout {
                panic!("Timeout {tout:#?}. Counter: {n}/{target}");
            }
            std::thread::sleep(Duration::from_millis(100));
        }
    }

    #[test]
    fn event_timeout() {
        let barrier = Arc::new(Barrier::new(2));
        let (notifier, waiter) = super::new();
        let tslot = Duration::from_secs(1);

        let bs = barrier.clone();
        let s = std::thread::spawn(move || {
            // The producer notifies before the first barrier.
            match waiter.wait_timeout(tslot) {
                Ok(()) => {}
                Err(WaitTimeoutError::Timeout) => panic!("Timeout {tslot:#?}"),
                Err(WaitTimeoutError::WaitError) => panic!("Event closed"),
            }

            bs.wait();

            bs.wait();

            // Two notifications were sent; they coalesce into a single one.
            match waiter.wait_timeout(tslot) {
                Ok(()) => {}
                Err(WaitTimeoutError::Timeout) => panic!("Timeout {tslot:#?}"),
                Err(WaitTimeoutError::WaitError) => panic!("Event closed"),
            }

            match waiter.wait_timeout(tslot) {
                Ok(()) => panic!("Event Ok but it should be Timeout"),
                Err(WaitTimeoutError::Timeout) => {}
                Err(WaitTimeoutError::WaitError) => panic!("Event closed"),
            }

            bs.wait();

            bs.wait();

            // The notifier has been dropped: waiting must fail.
            waiter.wait().unwrap_err();

            bs.wait();
        });

        let bp = barrier.clone();
        let p = std::thread::spawn(move || {
            notifier.notify().unwrap();

            bp.wait();

            notifier.notify().unwrap();
            notifier.notify().unwrap();

            bp.wait();
            bp.wait();

            drop(notifier);

            bp.wait();
            bp.wait();
        });

        s.join().unwrap();
        p.join().unwrap();
    }

    #[test]
    fn event_deadline() {
        let barrier = Arc::new(Barrier::new(2));
        let (notifier, waiter) = super::new();
        let tslot = Duration::from_secs(1);

        let bs = barrier.clone();
        let s = std::thread::spawn(move || {
            // The producer notifies before the first barrier.
            match waiter.wait_deadline(Instant::now() + tslot) {
                Ok(()) => {}
                Err(WaitDeadlineError::Deadline) => panic!("Timeout {tslot:#?}"),
                Err(WaitDeadlineError::WaitError) => panic!("Event closed"),
            }

            bs.wait();

            bs.wait();

            // Two notifications were sent; they coalesce into a single one.
            match waiter.wait_deadline(Instant::now() + tslot) {
                Ok(()) => {}
                Err(WaitDeadlineError::Deadline) => panic!("Timeout {tslot:#?}"),
                Err(WaitDeadlineError::WaitError) => panic!("Event closed"),
            }

            match waiter.wait_deadline(Instant::now() + tslot) {
                Ok(()) => panic!("Event Ok but it should be Timeout"),
                Err(WaitDeadlineError::Deadline) => {}
                Err(WaitDeadlineError::WaitError) => panic!("Event closed"),
            }

            bs.wait();

            bs.wait();

            // The notifier has been dropped: waiting must fail.
            waiter.wait().unwrap_err();

            bs.wait();
        });

        let bp = barrier.clone();
        let p = std::thread::spawn(move || {
            notifier.notify().unwrap();

            bp.wait();

            notifier.notify().unwrap();
            notifier.notify().unwrap();

            bp.wait();
            bp.wait();

            drop(notifier);

            bp.wait();
            bp.wait();
        });

        s.join().unwrap();
        p.join().unwrap();
    }

    #[test]
    fn event_loop() {
        const N: usize = 1_000;
        static COUNTER: AtomicUsize = AtomicUsize::new(0);

        let (notifier, waiter) = super::new();
        let barrier = Arc::new(Barrier::new(2));

        let bs = barrier.clone();
        let s = std::thread::spawn(move || {
            for _ in 0..N {
                waiter.wait().unwrap();
                COUNTER.fetch_add(1, Ordering::Relaxed);
                bs.wait();
            }
        });
        let p = std::thread::spawn(move || {
            for _ in 0..N {
                notifier.notify().unwrap();
                barrier.wait();
            }
        });

        wait_for_counter(&COUNTER, N);

        s.join().unwrap();
        p.join().unwrap();
    }

    #[test]
    fn event_multiple() {
        const N: usize = 1_000;
        static COUNTER: AtomicUsize = AtomicUsize::new(0);

        let (notifier, waiter) = super::new();

        // Each consumer thread gets its own clone. `waiter` itself stays alive on this thread
        // until the producers have been joined. Otherwise both consumers could exit (dropping
        // every `Waiter` and closing the event) between a producer's counter check and its
        // `notify().unwrap()`, making that unwrap panic.
        let w1 = waiter.clone();
        let s1 = std::thread::spawn(move || {
            let mut n = 0;
            while COUNTER.fetch_add(1, Ordering::Relaxed) < N - 2 {
                w1.wait().unwrap();
                n += 1;
            }
            println!("S1: {n}");
        });
        let w2 = waiter.clone();
        let s2 = std::thread::spawn(move || {
            let mut n = 0;
            while COUNTER.fetch_add(1, Ordering::Relaxed) < N - 2 {
                w2.wait().unwrap();
                n += 1;
            }
            println!("S2: {n}");
        });

        let n1 = notifier.clone();
        let p1 = std::thread::spawn(move || {
            let mut n = 0;
            while COUNTER.load(Ordering::Relaxed) < N {
                n1.notify().unwrap();
                n += 1;
                std::thread::sleep(Duration::from_millis(1));
            }
            println!("P1: {n}");
        });
        let p2 = std::thread::spawn(move || {
            let mut n = 0;
            while COUNTER.load(Ordering::Relaxed) < N {
                notifier.notify().unwrap();
                n += 1;
                std::thread::sleep(Duration::from_millis(1));
            }
            println!("P2: {n}");
        });

        // The watchdog runs on the test thread so that a timeout actually fails the test.
        wait_for_counter(&COUNTER, N);

        p1.join().unwrap();
        p2.join().unwrap();

        s1.join().unwrap();
        s2.join().unwrap();

        drop(waiter);
    }
}