may/sync/
sync_flag.rs

1use std::fmt;
2use std::sync::atomic::{AtomicIsize, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use super::blocking::SyncBlocker;
7use crate::cancel::trigger_cancel_panic;
8use crate::park::ParkError;
9use crossbeam::queue::SegQueue;
10
11/// SyncFlag primitive
12///
13/// SyncFlag allow threads and coroutines to synchronize their actions
14/// like barrier.
15///
16/// A SyncFlag is an boolean value
17/// When the SyncFlag is false, any thread or coroutine wait on it would
18/// block until it's value becomes true
19/// When the SyncFlag is true, any thread or coroutine wait on it would
20/// return immediately.
21///
22/// After the SyncFlag becomes true, it will never becomes false again.
23///
24/// # Examples
25///
26/// ```rust
27/// use std::sync::Arc;
28/// use may::coroutine;
29/// use may::sync::SyncFlag;
30///
31/// let flag = Arc::new(SyncFlag::new());
32/// let flag2 = flag.clone();
33///
34/// // spawn a coroutine, and then wait for it to start
35/// unsafe {
36///     coroutine::spawn(move || {
37///         flag2.fire();
38///         flag2.wait();
39///     });
40/// }
41///
42/// // wait for the coroutine to start up
43/// flag.wait();
44/// ```
45pub struct SyncFlag {
46    // track how many resources available for the SyncFlag
47    // if it's negative means how many threads are waiting for
48    cnt: AtomicIsize,
49    // the waiting blocker list, must be mpmc
50    to_wake: SegQueue<Arc<SyncBlocker>>,
51}
52
53impl Default for SyncFlag {
54    fn default() -> Self {
55        SyncFlag {
56            to_wake: SegQueue::new(),
57            cnt: AtomicIsize::new(0),
58        }
59    }
60}
61
62impl SyncFlag {
63    /// create a SyncFlag with the initial value
64    pub fn new() -> Self {
65        Default::default()
66    }
67
68    #[inline]
69    fn wakeup_all(&self) {
70        while let Some(w) = self.to_wake.pop() {
71            w.unpark();
72            if w.take_release() {
73                self.fire();
74            }
75        }
76    }
77
78    // return false if timeout
79    fn wait_timeout_impl(&self, dur: Option<Duration>) -> bool {
80        // try wait first
81        if self.is_fired() {
82            return true;
83        }
84
85        let cur = SyncBlocker::current();
86        // register blocker first
87        self.to_wake.push(cur.clone());
88        // dec the cnt, if it's positive, unpark one waiter
89        if self.cnt.fetch_sub(1, Ordering::SeqCst) > 0 {
90            self.wakeup_all();
91        }
92
93        match cur.park(dur) {
94            Ok(_) => true,
95            Err(err) => {
96                // check the unpark status
97                if cur.is_unparked() {
98                    self.fire();
99                } else {
100                    // register
101                    cur.set_release();
102                    // re-check unpark status
103                    if cur.is_unparked() && cur.take_release() {
104                        self.fire();
105                    }
106                }
107
108                // now we can safely go with the cancel panic
109                if err == ParkError::Canceled {
110                    trigger_cancel_panic();
111                }
112                false
113            }
114        }
115    }
116
117    /// wait for a SyncFlag
118    /// if the SyncFlag value is bigger than zero the function returns immediately
119    /// otherwise it would block the until a `fire` is executed
120    pub fn wait(&self) {
121        self.wait_timeout_impl(None);
122    }
123
124    /// same as `wait` except that with an extra timeout value
125    /// return false if timeout happened
126    pub fn wait_timeout(&self, dur: Duration) -> bool {
127        self.wait_timeout_impl(Some(dur))
128    }
129
130    /// set the SyncFlag to true
131    /// and would wakeup all threads/coroutines that are calling `wait`
132    pub fn fire(&self) {
133        self.cnt.store(isize::MAX, Ordering::SeqCst);
134
135        // try to wakeup all waiters
136        self.wakeup_all();
137    }
138
139    /// return the current SyncFlag value
140    pub fn is_fired(&self) -> bool {
141        let cnt = self.cnt.load(Ordering::SeqCst);
142        cnt > 0
143    }
144}
145
146impl fmt::Debug for SyncFlag {
147    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
148        write!(f, "SyncFlag {{ is_fired: {} }}", self.is_fired())
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn sanity_test() {
        let flag = Arc::new(SyncFlag::new());
        let flag2 = flag.clone();

        // spawn a new thread, and then wait for it to start
        let h = thread::spawn(move || {
            flag2.fire();
            flag2.wait();
        });

        // wait for the thread to start up
        flag.wait();
        assert!(flag.is_fired());
        // join so that a panic inside the spawned thread fails this test
        h.join().unwrap();
    }

    #[test]
    fn test_syncflag_fire_is_sticky() {
        let flag = SyncFlag::new();
        assert!(!flag.is_fired());
        assert_eq!(format!("{flag:?}"), "SyncFlag { is_fired: false }");

        flag.fire();
        assert!(flag.is_fired());
        assert_eq!(format!("{flag:?}"), "SyncFlag { is_fired: true }");

        // once fired, waiting returns immediately and never reports a timeout
        assert!(flag.wait_timeout(Duration::from_millis(10)));
        flag.wait();

        // firing again is harmless and the flag stays fired
        flag.fire();
        assert!(flag.is_fired());
    }

    #[test]
    fn test_syncflag_canceled() {
        use crate::sleep::sleep;

        let flag1 = Arc::new(SyncFlag::new());
        let flag2 = flag1.clone();
        let flag3 = flag1.clone();

        let h1 = go!(move || {
            flag2.wait();
        });

        let h2 = go!(move || {
            // let h1 enqueue
            sleep(Duration::from_millis(50));
            flag3.wait();
        });

        // wait h1 and h2 enqueue
        sleep(Duration::from_millis(100));
        assert!(!flag1.is_fired());
        // cancel h1
        unsafe { h1.coroutine().cancel() };
        h1.join().unwrap_err();
        // release the SyncFlag
        flag1.fire();
        h2.join().unwrap();
        assert!(flag1.is_fired());
    }

    #[test]
    fn test_syncflag_co_timeout() {
        use crate::sleep::sleep;

        let flag1 = Arc::new(SyncFlag::new());
        let flag2 = flag1.clone();
        let flag3 = flag1.clone();

        let h1 = go!(move || {
            let r = flag2.wait_timeout(Duration::from_millis(10));
            assert!(!r);
        });

        let h2 = go!(move || {
            // let h1 enqueue
            sleep(Duration::from_millis(50));
            flag3.wait();
        });

        // wait h1 timeout
        h1.join().unwrap();
        // release the SyncFlag
        flag1.fire();
        h2.join().unwrap();
    }

    #[test]
    fn test_syncflag_thread_timeout() {
        use crate::sleep::sleep;

        let flag1 = Arc::new(SyncFlag::new());
        let flag2 = flag1.clone();
        let flag3 = flag1.clone();

        let h1 = thread::spawn(move || {
            let r = flag2.wait_timeout(Duration::from_millis(10));
            assert!(!r);
        });

        let h2 = thread::spawn(move || {
            // let h1 enqueue
            sleep(Duration::from_millis(50));
            flag3.wait();
        });

        // wait h1 timeout
        h1.join().unwrap();
        // release the SyncFlag
        flag1.fire();
        h2.join().unwrap();
    }
}