1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use rsevents::{Awaitable, ManualResetEvent, State};
use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT};
use std::time::Duration;

/// A countdown event is a special type of [`ManualResetEvent`] that makes it easy to wait for a
/// given number of tasks to complete asynchronously, and then carry out some action. A countdown
/// event is first initialized with a count equal to the number of outstanding tasks, and each time
/// a task is completed, [`CountdownEvent::tick()`] is called. A call to
/// [`CountdownEvent::wait()`](Awaitable::wait()) will block until all outstanding tasks have
/// completed and the internal counter reaches 0.
///
/// Countdown events are thread-safe and may be wrapped in an [`Arc`](std::sync::Arc) to easily
/// share across threads.
pub struct CountdownEvent {
    /// Number of outstanding ticks remaining before the event becomes set.
    count: AtomicUsize,
    /// The event waiters block on; it is set once `count` reaches zero.
    event: ManualResetEvent,
}

impl CountdownEvent {
    /// Creates a new countdown event with the internal count initialized to `count`. If a count of
    /// zero is specified, the event is immediately set.
    pub fn new(count: usize) -> Self {
        let initial_state = if count == 0 { State::Set } else { State::Unset };
        Self {
            count: AtomicUsize::new(count),
            event: ManualResetEvent::new(initial_state),
        }
    }

    /// Decrements the internal countdown. When the internal countdown reaches zero, the countdown
    /// event enters a [set](State::Set) state and any outstanding or future calls to
    /// [`Awaitable::wait()`] will be let through without blocking (until [the event is reset](CountdownEvent::reset())).
    ///
    /// Calling `tick()` when the countdown is already at zero has no effect: the count stays at
    /// zero instead of wrapping around to `usize::MAX`.
    pub fn tick(&self) {
        // AcqRel: the Release half publishes everything this thread did before ticking, and the
        // Acquire half lets the thread performing the final decrement observe the work of every
        // earlier ticker. That thread then sets the event, so (assuming `ManualResetEvent`'s
        // set/wait synchronize with each other) waiters observe all tasks' work.
        // `checked_sub` makes the closure return `None` at zero, leaving the count untouched.
        let previous = self
            .count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |n| n.checked_sub(1));

        // Only the tick that moves the count from 1 to 0 sets the event.
        if previous == Ok(1) {
            self.event.set();
        }
    }

    /// Resets a countdown event to the specified `count`. If a count of zero is specified, the
    /// event is immediately set; otherwise it is unset.
    ///
    /// The counter store and the event update are two separate steps, so `reset()` should not run
    /// concurrently with [`CountdownEvent::tick()`]. A racing tick may be applied to the old count,
    /// or may set the event just before it is unset again.
    pub fn reset(&self, count: usize) {
        self.count.store(count, Ordering::Relaxed);
        if count == 0 {
            self.event.set();
        } else {
            self.event.reset();
        }
    }

    /// Gets the current internal countdown value.
    ///
    /// With concurrent tickers this is only a snapshot and may already be stale when it returns.
    pub fn count(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }
}

impl Awaitable for CountdownEvent {
    /// Blocks the calling thread until the internal countdown of the [`CountdownEvent`] reaches
    /// zero (i.e. until the underlying event is set).
    fn wait(&self) {
        let inner = &self.event;
        inner.wait();
    }

    /// Checks whether the countdown has already reached zero, without waiting. Delegates to the
    /// underlying event's `wait0()`.
    fn wait0(&self) -> bool {
        let inner = &self.event;
        inner.wait0()
    }

    /// Waits at most `limit` for the countdown to reach zero, delegating to the underlying
    /// event's `wait_for()`. Presumably returns `true` if the event became set within the limit.
    /// TODO confirm against the rsevents docs.
    fn wait_for(&self, limit: Duration) -> bool {
        let inner = &self.event;
        inner.wait_for(limit)
    }
}

/// A single tick on a countdown of one sets the event and drains the counter.
#[test]
fn basic_countdown() {
    let countdown = CountdownEvent::new(1);
    assert!(!countdown.wait0());
    countdown.tick();
    assert!(countdown.wait0());
    assert_eq!(countdown.count(), 0);
}

/// Resetting a completed countdown unsets the event, and the new round can complete again.
#[test]
fn reset_countdown() {
    let countdown = CountdownEvent::new(1);
    assert!(!countdown.wait0());
    countdown.tick();
    assert!(countdown.wait0());

    countdown.reset(1);
    assert!(!countdown.wait0());
    assert_eq!(countdown.count(), 1);

    // The reset countdown must be completable just like a fresh one.
    countdown.tick();
    assert!(countdown.wait0());
}

/// A countdown created with a zero count is set from the start.
#[test]
fn start_at_zero() {
    let countdown = CountdownEvent::new(0);
    assert!(countdown.wait0());
    assert_eq!(countdown.count(), 0);
}

/// Several worker threads each tick once; the main thread's `wait()` must return only after all
/// of them have ticked.
#[test]
fn threaded_countdown() {
    use std::sync::Arc;
    use std::thread;

    const WORKERS: usize = 2;

    let countdown = Arc::new(CountdownEvent::new(WORKERS));
    assert!(!countdown.wait0());

    let workers: Vec<_> = (0..WORKERS)
        .map(|_| {
            let countdown = Arc::clone(&countdown);
            thread::spawn(move || {
                // This worker has not ticked yet, so the count is still >= 1 here.
                assert!(!countdown.wait0());
                countdown.tick();
            })
        })
        .collect();

    countdown.wait();
    assert!(countdown.wait0());
    assert_eq!(countdown.count(), 0);

    // To catch any panics
    for worker in workers {
        worker.join().unwrap();
    }
}