1use std::sync::{
15 atomic::{AtomicBool, Ordering::*},
16 Arc,
17};
18
19use tokio::sync::Semaphore;
20
21#[derive(Debug, Clone)]
22pub struct Signal {
23 shared: Arc<Inner>,
24}
25
26#[derive(Debug)]
27struct Inner {
28 semaphore: Semaphore,
29 triggered: AtomicBool,
30}
31
32impl Signal {
33 pub fn new() -> Self {
34 Signal {
35 shared: Arc::new(Inner {
36 semaphore: Semaphore::new(0),
37 triggered: AtomicBool::new(false),
38 }),
39 }
40 }
41
42 pub fn trigger(&self) {
43 let result = self
44 .shared
45 .triggered
46 .compare_exchange(false, true, AcqRel, Acquire);
47
48 if result.is_ok() {
49 self.shared.semaphore.add_permits(usize::MAX >> 3);
52 }
53 }
54
55 pub fn is_triggered(&self) -> bool {
56 self.shared.triggered.load(Acquire)
57 }
58
59 pub async fn wait(&self) {
60 if !self.is_triggered() {
61 let _ = self.shared.semaphore.acquire().await;
62 }
63 }
64}
65
66impl Default for Signal {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn signal_test() {
        let signal = Signal::new();

        // Publisher: triggers twice to check that a repeated trigger is harmless.
        let r#pub = tokio::task::spawn({
            let signal = signal.clone();

            async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                signal.trigger();
                signal.trigger();
            }
        });

        // Fast subscriber: normally starts waiting before the trigger happens.
        let fast_sub = tokio::task::spawn({
            let signal = signal.clone();

            async move {
                signal.wait().await;
                assert!(signal.is_triggered());
            }
        });

        // Slow subscriber: normally starts waiting after the trigger happened.
        let slow_sub = tokio::task::spawn({
            let signal = signal.clone();

            async move {
                tokio::time::sleep(Duration::from_millis(400)).await;
                signal.wait().await;
                assert!(signal.is_triggered());
            }
        });

        let (pub_res, fast_res, slow_res) = tokio::time::timeout(
            Duration::from_secs(5),
            futures::future::join3(r#pub, fast_sub, slow_sub),
        )
        .await
        .expect("tasks did not finish before the timeout");

        // A panic inside a spawned task (e.g. a failed assertion) only shows
        // up as a `JoinError`, so each result must be checked explicitly.
        pub_res.expect("publisher task failed");
        fast_res.expect("fast subscriber task failed");
        slow_res.expect("slow subscriber task failed");

        assert!(signal.is_triggered());
    }

    #[tokio::test]
    async fn wait_after_trigger_returns_immediately() {
        let signal = Signal::default();
        assert!(!signal.is_triggered());

        signal.trigger();
        assert!(signal.is_triggered());

        tokio::time::timeout(Duration::from_millis(100), signal.wait())
            .await
            .expect("wait() blocked on an already-triggered signal");
    }
}