no_std_async/
condvar.rs

1use spin::Mutex;
2
3use crate::Semaphore;
4
5/// A condition variable.
6///
7/// This type allows multiple tasks to wait for an action to be completed.
8/// This type is similar to [`std::sync::Condvar`], except that it is
9/// async and cross-platform, and cannot be spuriously woken.
10///
11/// # Examples
12/// ```rust
13/// use no_std_async::{Condvar, Mutex};
14/// # fn main() { std::thread::spawn(|| pollster::block_on(task1())); pollster::block_on(task2()); }
15///
16/// static CONDVAR: Condvar = Condvar::new();
17/// static VALUE: Mutex<u8> = Mutex::new(0);
18///
19/// async fn task1() {
20///     for i in 0..10 {
21///         *VALUE.lock().await += 1; // some work
22///     }
23///     CONDVAR.notify_one(); // we're done!
24/// }
25///
26/// async fn task2() {
27///     CONDVAR.wait().await; // wait for task1 to finish
28///     assert_eq!(10, *VALUE.lock().await);
29/// }
30/// ```
31pub struct Condvar {
32    semaphore: Semaphore,
33    num_waiters: Mutex<usize>,
34}
35impl Condvar {
36    /// Creates a new [`Condvar`].
37    pub const fn new() -> Self {
38        Self {
39            semaphore: Semaphore::new(0),
40            num_waiters: Mutex::new(0),
41        }
42    }
43
44    /// Blocks the current task until this condition variable receives a notification.
45    pub async fn wait(&self) {
46        let mut num_waiters = self.num_waiters.lock();
47        *num_waiters += 1;
48        drop(num_waiters);
49        self.semaphore.acquire(1).await;
50    }
51
52    /// Notifies a single task waiting on this condition variable.
53    pub fn notify_one(&self) {
54        let mut num_waiters = self.num_waiters.lock();
55        if *num_waiters > 0 {
56            *num_waiters -= 1;
57            self.semaphore.release(1);
58        }
59    }
60
61    /// Notifies all tasks waiting on this condition variable.
62    pub fn notify_all(&self) {
63        let mut num_waiters = self.num_waiters.lock();
64        if *num_waiters > 0 {
65            let total_waiters = *num_waiters;
66            *num_waiters = 0;
67            self.semaphore.release(total_waiters);
68        }
69    }
70}
71
#[cfg(test)]
mod tests {
    use core::time::Duration;
    use std::thread;

    use super::*;

    // `notify_one` is already exercised by the doctest on `Condvar`.

    /// Parks three threads on the same condvar and checks that one
    /// `notify_all` call lets every one of them run to completion.
    #[test]
    fn notify_all() {
        static CONDVAR: Condvar = Condvar::new();
        // Grace period used both before and after the notification.
        const SETTLE: Duration = Duration::from_millis(100);

        let waiters = [(); 3].map(|()| thread::spawn(|| pollster::block_on(CONDVAR.wait())));

        // Give every thread time to reach the wait.
        thread::sleep(SETTLE);

        CONDVAR.notify_all();

        // Give the woken threads time to finish.
        thread::sleep(SETTLE);

        for waiter in &waiters {
            assert!(waiter.is_finished());
        }
    }
}