// async_barrier/lib.rs

1//! An async barrier.
2//!
3//! This crate is an async version of [`std::sync::Barrier`].
4
5#![forbid(unsafe_code)]
6#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
7
8use async_mutex::Mutex;
9use event_listener::Event;
10
11/// A counter to synchronize multiple tasks at the same time.
12#[derive(Debug)]
13pub struct Barrier {
14    n: usize,
15    state: Mutex<State>,
16    event: Event,
17}
18
/// Mutable bookkeeping for a [`Barrier`], only ever accessed through its mutex.
#[derive(Debug)]
struct State {
    /// Number of tasks that have arrived in the current round.
    count: usize,
    /// Advanced (wrapping) each time a round completes; a waiter compares it with
    /// the value it saw on arrival to learn that its round has been released.
    generation_id: u64,
}

25impl Barrier {
26    /// Creates a barrier that can block the given number of tasks.
27    ///
28    /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks
29    /// at once when the `n`th task calls [`wait()`].
30    ///
31    /// [`wait()`]: `Barrier::wait()`
32    ///
33    /// # Examples
34    ///
35    /// ```
36    /// use async_barrier::Barrier;
37    ///
38    /// let barrier = Barrier::new(5);
39    /// ```
40    pub const fn new(n: usize) -> Barrier {
41        Barrier {
42            n,
43            state: Mutex::new(State {
44                count: 0,
45                generation_id: 0,
46            }),
47            event: Event::new(),
48        }
49    }
50
51    /// Blocks the current task until all tasks reach this point.
52    ///
53    /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
54    ///
55    /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
56    /// last task to call this method.
57    ///
58    /// # Examples
59    ///
60    /// ```
61    /// use async_barrier::Barrier;
62    /// use futures_lite::future;
63    /// use std::sync::Arc;
64    /// use std::thread;
65    ///
66    /// let barrier = Arc::new(Barrier::new(5));
67    ///
68    /// for _ in 0..5 {
69    ///     let b = barrier.clone();
70    ///     thread::spawn(move || {
71    ///         future::block_on(async {
72    ///             // The same messages will be printed together.
73    ///             // There will NOT be interleaving of "before" and "after".
74    ///             println!("before wait");
75    ///             b.wait().await;
76    ///             println!("after wait");
77    ///         });
78    ///     });
79    /// }
80    /// ```
81    pub async fn wait(&self) -> BarrierWaitResult {
82        let mut state = self.state.lock().await;
83        let local_gen = state.generation_id;
84        state.count += 1;
85
86        if state.count < self.n {
87            while local_gen == state.generation_id && state.count < self.n {
88                let listener = self.event.listen();
89                drop(state);
90                listener.await;
91                state = self.state.lock().await;
92            }
93            BarrierWaitResult { is_leader: false }
94        } else {
95            state.count = 0;
96            state.generation_id = state.generation_id.wrapping_add(1);
97            self.event.notify(std::usize::MAX);
98            BarrierWaitResult { is_leader: true }
99        }
100    }
101}
102
/// Returned by [`Barrier::wait()`] when all tasks have called it.
///
/// Exactly one task per round (the last one to arrive) receives a result whose
/// [`is_leader()`](BarrierWaitResult::is_leader) is `true`.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_barrier::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// assert!(barrier_wait_result.is_leader());
/// # });
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BarrierWaitResult {
    is_leader: bool,
}

120impl BarrierWaitResult {
121    /// Returns `true` if this task was the last to call to [`Barrier::wait()`].
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// # futures_lite::future::block_on(async {
127    /// use async_barrier::Barrier;
128    /// use futures_lite::future;
129    ///
130    /// let barrier = Barrier::new(2);
131    /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await;
132    /// assert_eq!(a.is_leader(), false);
133    /// assert_eq!(b.is_leader(), true);
134    /// # });
135    /// ```
136    pub fn is_leader(&self) -> bool {
137        self.is_leader
138    }
139}