atomic_try_update/
barrier.rs

1//! User-friendly barriers that use `atomic_try_update` to handle startup and teardown race conditions.
2use std::{error::Error, fmt::Display};
3
4use crate::{atomic_try_update, bits::FlagU64, Atom};
5
/// Outcome of `ShutdownBarrier::wait()`.
///
/// Use `is_cancelled()` to tell a normal shutdown (every worker called
/// `done()`) apart from a cancellation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShutdownBarrierWaitResult {
    /// `true` if the barrier was cancelled rather than shut down normally.
    cancelled: bool,
}
9
/// Outcome of a successful `ShutdownBarrier::done()` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShutdownBarrierDoneResult {
    /// `true` if the barrier had been cancelled when this `done()` ran.
    cancelled: bool,
    /// `true` if this `done()` dropped the worker count to zero without a
    /// cancellation, i.e. it completed the pool of work.
    shutdown_leader: bool,
}
14
15impl ShutdownBarrierWaitResult {
16    /// This will return true for all waiters if at least one
17    /// waiter called cancel() before shutdown.
18    pub fn is_cancelled(&self) -> bool {
19        self.cancelled
20    }
21}
22
23impl ShutdownBarrierDoneResult {
24    pub fn is_cancelled(&self) -> bool {
25        self.cancelled
26    }
27    pub fn is_leader(&self) -> bool {
28        self.shutdown_leader
29    }
30}
31
32/// Similar to `tokio::sync::Barrier`, but you don't need to know how
33/// many waiters there will be up front.  It starts with one waiter,
34/// and more can be added dynamically with `spawn()` (which will return
35/// an error if invoked after shutdown).
36///
37/// This barrier's API handles two commonly-overlooked race conditions:
38///
39///  - It is OK to listen for completion before the first worker
40///    starts execution.  The `wait()` will not finish until all the
41///    workers that spawn complete.
42///  - It is OK to start listening for completion after the last
43///    worker exits.  In this case, the `wait()` will immediately
44///    complete.
45///
46/// You can also invoke `cancel()`, which causes the wait result's
47/// `is_cancelled()` method to return true for all waiters.
48pub struct ShutdownBarrier {
49    state: Atom<FlagU64, u64>,
50    /// We send false for normal shutdown; true for cancellation
51    broadcast: tokio::sync::broadcast::Sender<bool>,
52}
53
/// Snapshot of the barrier state taken by `ShutdownBarrier::wait()`.
enum WaitResult {
    /// Workers remain and no cancellation was seen; wait for the broadcast.
    StillRunning,
    /// The worker count already reached zero without cancellation.
    Shutdown,
    /// `cancel()` has already taken effect.
    Cancelled,
}
59
/// Outcome of the atomic state update performed by `ShutdownBarrier::done()`.
#[derive(Debug)]
enum DoneResult {
    /// The barrier had been cancelled; the worker count was decremented.
    Cancelled,
    /// The worker count was already zero, so the state was left unchanged.
    AlreadyDone,
    /// This call dropped the worker count from one to zero.
    ShutdownLeader,
    /// Other workers are still registered after this call.
    Running,
}
67
68impl Default for ShutdownBarrier {
69    fn default() -> Self {
70        let this = Self {
71            state: Default::default(),
72            broadcast: tokio::sync::broadcast::channel(1).0,
73        };
74        unsafe {
75            atomic_try_update(&this.state, |s| {
76                s.set_val(1);
77                (true, ())
78            });
79        }
80        this
81    }
82}
83
/// Error returned by `ShutdownBarrier` operations that can no longer take
/// effect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownBarrierError {
    /// The barrier has already completed (its worker count reached zero) or
    /// has been cancelled.
    AlreadyShutdown,
}

impl Error for ShutdownBarrierError {}

impl Display for ShutdownBarrierError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Spelled out rather than forwarded to `Debug`, so the user-facing text
        // cannot change silently if the `Debug` representation ever does.
        match self {
            Self::AlreadyShutdown => f.write_str("AlreadyShutdown"),
        }
    }
}
96
97impl ShutdownBarrier {
98    /// Register another worker with the barrier.
99    ///
100    /// Returns Error if the barrier has already been completed.  The barrier
101    ///         starts with one worker (the one that is the parent of all work),
102    ///         so this will never happen if you are careful not to invoke `spawn()`
103    ///         after the parent task invokes `done()`
104    pub fn spawn(&self) -> Result<(), ShutdownBarrierError> {
105        let already_shutdown = unsafe {
106            atomic_try_update(&self.state, |s| {
107                let count = s.get_val();
108                if s.get_flag() || count == 0 {
109                    (false, true) // already shutdown
110                } else {
111                    s.set_val(count + 1);
112                    (true, false)
113                }
114            })
115        };
116        if already_shutdown {
117            Err(ShutdownBarrierError::AlreadyShutdown)
118        } else {
119            Ok(())
120        }
121    }
122
123    /// Inform the barrier that whatever work all the workers are performing
124    /// has been cancelled.  This call causes `wait()` to return immediately
125    /// with `cancelled = true`.
126    pub fn cancel(&self) -> Result<(), ShutdownBarrierError> {
127        let already_shutdown = unsafe {
128            atomic_try_update(&self.state, |s| {
129                let count = s.get_val();
130                if s.get_flag() || count == 0 {
131                    (false, true)
132                } else {
133                    s.set_flag(true);
134                    (true, false)
135                }
136            })
137        };
138        if already_shutdown {
139            Err(ShutdownBarrierError::AlreadyShutdown)
140        } else {
141            // send true for cancellation; false on success
142            _ = self.broadcast.send(true);
143            Ok(())
144        }
145    }
146
147    /// Inform the barrier that a single worker has completed.
148    ///
149    /// Returns a `ShutdownBarrierDoneResult`, with `is_cancelled = true` if the
150    /// pool of work protected by the barrier was cancelled, and
151    /// `shutdown_leader = true` if this call to `done()` was the one that completed
152    /// the pool of work.  Workers can check for `shutdown_leader = true` to
153    /// perform clean up logic outside the thread of control that invokes `done()`.
154    pub fn done(&self) -> Result<ShutdownBarrierDoneResult, ShutdownBarrierError> {
155        let done_result = unsafe {
156            atomic_try_update(&self.state, |s| {
157                let count = s.get_val();
158                s.set_val(count - 1);
159                if s.get_flag() {
160                    (true, DoneResult::Cancelled)
161                } else if count == 0 {
162                    (false, DoneResult::AlreadyDone)
163                } else if count == 1 {
164                    (true, DoneResult::ShutdownLeader)
165                } else {
166                    (true, DoneResult::Running)
167                }
168            })
169        };
170        match done_result {
171            DoneResult::Cancelled => Ok(ShutdownBarrierDoneResult {
172                cancelled: true,
173                shutdown_leader: false,
174            }),
175            DoneResult::ShutdownLeader => {
176                _ = self.broadcast.send(false);
177                Ok(ShutdownBarrierDoneResult {
178                    cancelled: false,
179                    shutdown_leader: true,
180                })
181            }
182            DoneResult::Running => Ok(ShutdownBarrierDoneResult {
183                cancelled: false,
184                shutdown_leader: false,
185            }),
186            DoneResult::AlreadyDone => Err(ShutdownBarrierError::AlreadyShutdown),
187        }
188    }
189
190    /// Waits until the number of workers reaches zero.  This can be called at any time
191    /// and can be called multiple times.
192    pub async fn wait(&self) -> Result<ShutdownBarrierWaitResult, ShutdownBarrierError> {
193        // We have to subscribe before we check state.  Otherwise, some some thread
194        // could send the shutdown message after our subscription begins!
195        let mut rx = self.broadcast.subscribe();
196        let wait_result = unsafe {
197            atomic_try_update(&self.state, |s| {
198                let count = s.get_val();
199                if s.get_flag() {
200                    (false, WaitResult::Cancelled)
201                } else if count == 0 {
202                    (true, WaitResult::Shutdown)
203                } else {
204                    (false, WaitResult::StillRunning)
205                }
206            })
207        };
208        match wait_result {
209            WaitResult::StillRunning => {
210                let cancelled = rx
211                    .recv()
212                    .await
213                    .map_err(|_| ShutdownBarrierError::AlreadyShutdown)?;
214                Ok(ShutdownBarrierWaitResult { cancelled })
215            }
216            WaitResult::Shutdown => Ok(ShutdownBarrierWaitResult { cancelled: false }),
217            WaitResult::Cancelled => Ok(ShutdownBarrierWaitResult { cancelled: true }),
218        }
219    }
220    /// Returns a new shutdown barrier with a single worker.  The caller
221    /// should spawn() all the work that needs to be done, then invoke
222    /// done().  This makes sure the worker count doesn't spuriously
223    /// reach zero while work is being spawned.
224    pub fn new() -> Self {
225        Default::default()
226    }
227}