async_barrier/lib.rs
//! An async barrier.
//!
//! This crate is an async version of [`std::sync::Barrier`]: tasks wait for each other by
//! awaiting a future rather than by blocking their thread.
4
5#![forbid(unsafe_code)]
6#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
7
8use async_mutex::Mutex;
9use event_listener::Event;
10
11/// A counter to synchronize multiple tasks at the same time.
12#[derive(Debug)]
13pub struct Barrier {
14 n: usize,
15 state: Mutex<State>,
16 event: Event,
17}
18
19#[derive(Debug)]
20struct State {
21 count: usize,
22 generation_id: u64,
23}
24
25impl Barrier {
26 /// Creates a barrier that can block the given number of tasks.
27 ///
28 /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks
29 /// at once when the `n`th task calls [`wait()`].
30 ///
31 /// [`wait()`]: `Barrier::wait()`
32 ///
33 /// # Examples
34 ///
35 /// ```
36 /// use async_barrier::Barrier;
37 ///
38 /// let barrier = Barrier::new(5);
39 /// ```
40 pub const fn new(n: usize) -> Barrier {
41 Barrier {
42 n,
43 state: Mutex::new(State {
44 count: 0,
45 generation_id: 0,
46 }),
47 event: Event::new(),
48 }
49 }
50
51 /// Blocks the current task until all tasks reach this point.
52 ///
53 /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
54 ///
55 /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
56 /// last task to call this method.
57 ///
58 /// # Examples
59 ///
60 /// ```
61 /// use async_barrier::Barrier;
62 /// use futures_lite::future;
63 /// use std::sync::Arc;
64 /// use std::thread;
65 ///
66 /// let barrier = Arc::new(Barrier::new(5));
67 ///
68 /// for _ in 0..5 {
69 /// let b = barrier.clone();
70 /// thread::spawn(move || {
71 /// future::block_on(async {
72 /// // The same messages will be printed together.
73 /// // There will NOT be interleaving of "before" and "after".
74 /// println!("before wait");
75 /// b.wait().await;
76 /// println!("after wait");
77 /// });
78 /// });
79 /// }
80 /// ```
81 pub async fn wait(&self) -> BarrierWaitResult {
82 let mut state = self.state.lock().await;
83 let local_gen = state.generation_id;
84 state.count += 1;
85
86 if state.count < self.n {
87 while local_gen == state.generation_id && state.count < self.n {
88 let listener = self.event.listen();
89 drop(state);
90 listener.await;
91 state = self.state.lock().await;
92 }
93 BarrierWaitResult { is_leader: false }
94 } else {
95 state.count = 0;
96 state.generation_id = state.generation_id.wrapping_add(1);
97 self.event.notify(std::usize::MAX);
98 BarrierWaitResult { is_leader: true }
99 }
100 }
101}
102
/// Returned by [`Barrier::wait()`] when all tasks have called it.
///
/// In each round of the barrier, exactly one task receives a result for which
/// [`is_leader()`](BarrierWaitResult::is_leader) returns `true`. That task is the one whose
/// arrival completed the round.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_barrier::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// assert!(barrier_wait_result.is_leader());
/// # });
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BarrierWaitResult {
    is_leader: bool,
}

impl BarrierWaitResult {
    /// Returns `true` if this task was the last to call [`Barrier::wait()`].
    ///
    /// # Examples
    ///
    /// ```
    /// # futures_lite::future::block_on(async {
    /// use async_barrier::Barrier;
    /// use futures_lite::future;
    ///
    /// let barrier = Barrier::new(2);
    /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await;
    /// assert_eq!(a.is_leader(), false);
    /// assert_eq!(b.is_leader(), true);
    /// # });
    /// ```
    pub fn is_leader(&self) -> bool {
        self.is_leader
    }
}