1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use spin::Mutex;

use crate::Semaphore;

/// A condition variable.
///
/// This type allows multiple tasks to wait for an action to be completed.
/// This type is similar to [`std::sync::Condvar`], except that it is
/// async and cross-platform, and cannot be spuriously woken: a waiter only
/// proceeds after [`notify_one`](Condvar::notify_one) or
/// [`notify_all`](Condvar::notify_all) has released a permit for it.
///
/// Notifications are not remembered. If no task is waiting when
/// [`notify_one`](Condvar::notify_one) or [`notify_all`](Condvar::notify_all)
/// is called, the call has no effect, so a task that starts waiting afterwards
/// will not observe it.
///
/// # Examples
/// ```rust
/// use no_std_async::{Condvar, Mutex};
/// # fn main() { std::thread::spawn(|| pollster::block_on(task1())); pollster::block_on(task2()); }
///
/// static CONDVAR: Condvar = Condvar::new();
/// static VALUE: Mutex<u8> = Mutex::new(0);
///
/// async fn task1() {
///     for _ in 0..10 {
///         *VALUE.lock().await += 1; // some work
///     }
///     CONDVAR.notify_one(); // we're done!
/// }
///
/// async fn task2() {
///     CONDVAR.wait().await; // wait for task1 to finish
///     assert_eq!(10, *VALUE.lock().await);
/// }
/// ```
pub struct Condvar {
    // Starts with zero permits; every notification releases one permit per
    // task it intends to wake, and every waiter acquires exactly one.
    semaphore: Semaphore,
    // Number of `wait` calls that have registered themselves and have not yet
    // been matched with a notification.
    num_waiters: Mutex<usize>,
}
impl Condvar {
    /// Creates a new [`Condvar`] with no waiting tasks.
    ///
    /// This is a `const fn`, so a [`Condvar`] can be stored in a `static`.
    pub const fn new() -> Self {
        Self {
            // Zero permits: a waiter can only proceed once a notification releases one.
            semaphore: Semaphore::new(0),
            num_waiters: Mutex::new(0),
        }
    }

    /// Blocks the current task until this condition variable receives a notification.
    ///
    /// The task is registered as a waiter when the returned future is first
    /// polled, not when `wait` is called. A notification issued while no task
    /// is registered is discarded rather than remembered.
    ///
    /// NOTE(review): the registration is not undone if the future is dropped
    /// before it completes, so the waiter count can then overstate the number
    /// of live waiters and a later notification would release a permit that
    /// no task is waiting for. Confirm whether callers may cancel `wait`
    /// (for example through a timeout or a `select`).
    pub async fn wait(&self) {
        // Register before suspending so a concurrent notifier can see this task.
        // The guard is a temporary and is released at the end of the statement,
        // i.e. before the `.await` below.
        *self.num_waiters.lock() += 1;
        self.semaphore.acquire(1).await;
    }

    /// Notifies a single task waiting on this condition variable.
    ///
    /// Does nothing if no task is currently registered as waiting.
    pub fn notify_one(&self) {
        // The lock is held across the release so that the waiter count and the
        // number of released permits change together.
        let mut num_waiters = self.num_waiters.lock();
        if *num_waiters > 0 {
            *num_waiters -= 1;
            self.semaphore.release(1);
        }
    }

    /// Notifies all tasks waiting on this condition variable.
    ///
    /// Does nothing if no task is currently registered as waiting.
    pub fn notify_all(&self) {
        let mut num_waiters = self.num_waiters.lock();
        // Read the count and reset it to zero in one step; one permit is then
        // released for every task that was registered.
        let total_waiters = core::mem::take(&mut *num_waiters);
        if total_waiters > 0 {
            self.semaphore.release(total_waiters);
        }
    }
}

impl Default for Condvar {
    /// Creates a new [`Condvar`] with no waiting tasks; equivalent to [`Condvar::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use core::time::Duration;
    use std::thread;

    use super::*;

    // `notify_one` is exercised by the doctest on `Condvar`.

    /// Three blocked waiters must all complete after a single `notify_all`.
    #[test]
    fn notify_all() {
        static CONDVAR: Condvar = Condvar::new();
        // Grace period used both for the waiters to register and for them to wake up.
        const SETTLE: Duration = Duration::from_millis(100);

        // One OS thread per waiter, each driving its own `wait` future to completion.
        let waiters = [(); 3].map(|()| thread::spawn(|| pollster::block_on(CONDVAR.wait())));

        thread::sleep(SETTLE); // make sure all tasks are waiting
        CONDVAR.notify_all();
        thread::sleep(SETTLE); // time for everything to synchronize

        for waiter in &waiters {
            assert!(waiter.is_finished());
        }
    }
}