atomic_try_update/barrier.rs
1//! User-friendly barriers that use `atomic_try_update` to handle startup and teardown race conditions.
2use std::{error::Error, fmt::Display};
3
4use crate::{atomic_try_update, bits::FlagU64, Atom};
5
/// Outcome handed back to a caller of `ShutdownBarrier::wait()`.
///
/// The common traits are derived so results can be logged, copied and
/// compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShutdownBarrierWaitResult {
    /// True when the barrier was cancelled rather than shut down normally.
    cancelled: bool,
}
9
/// Outcome handed back to a caller of `ShutdownBarrier::done()`.
///
/// The common traits are derived so results can be logged, copied and
/// compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShutdownBarrierDoneResult {
    /// True when the pool of work protected by the barrier was cancelled.
    cancelled: bool,
    /// True when this particular `done()` call completed the pool of work.
    shutdown_leader: bool,
}
14
15impl ShutdownBarrierWaitResult {
16 /// This will return true for all waiters if at least one
17 /// waiter called cancel() before shutdown.
18 pub fn is_cancelled(&self) -> bool {
19 self.cancelled
20 }
21}
22
23impl ShutdownBarrierDoneResult {
24 pub fn is_cancelled(&self) -> bool {
25 self.cancelled
26 }
27 pub fn is_leader(&self) -> bool {
28 self.shutdown_leader
29 }
30}
31
/// Similar to `tokio::sync::Barrier`, but you don't need to know how
/// many workers there will be up front. It starts with one worker,
/// and more can be added dynamically with `spawn()` (which will return
/// an error if invoked after shutdown).
///
/// This barrier's API handles two commonly-overlooked race conditions:
///
/// - It is OK to listen for completion before the first worker
///   starts execution. The `wait()` will not finish until all the
///   workers that spawn complete.
/// - It is OK to start listening for completion after the last
///   worker exits. In this case, the `wait()` will immediately
///   complete.
///
/// You can also invoke `cancel()`, which causes the wait result's
/// `is_cancelled()` method to return true for all waiters.
pub struct ShutdownBarrier {
    /// Shared state word: the value is the number of registered workers
    /// that have not yet called `done()`, and the flag is set once
    /// `cancel()` succeeds.
    state: Atom<FlagU64, u64>,
    /// Wakes tasks blocked in `wait()`.
    /// We send false for normal shutdown; true for cancellation
    broadcast: tokio::sync::broadcast::Sender<bool>,
}
53
/// What `wait()` observed when it inspected the barrier state.
enum WaitResult {
    /// Workers remain and no cancellation was seen; the caller must await
    /// the broadcast channel.
    StillRunning,
    /// The worker count is zero and the cancellation flag is not set.
    Shutdown,
    /// The cancellation flag is set.
    Cancelled,
}
59
/// Outcome of the state transition attempted by `done()`.
#[derive(Debug)]
enum DoneResult {
    /// The cancellation flag was set when `done()` ran.
    Cancelled,
    /// The worker count was already zero (and the barrier was not
    /// cancelled); surfaced to callers as `AlreadyShutdown`.
    AlreadyDone,
    /// This call retired the last worker; the caller broadcasts the normal
    /// shutdown message.
    ShutdownLeader,
    /// Other workers are still registered after this call.
    Running,
}
67
68impl Default for ShutdownBarrier {
69 fn default() -> Self {
70 let this = Self {
71 state: Default::default(),
72 broadcast: tokio::sync::broadcast::channel(1).0,
73 };
74 unsafe {
75 atomic_try_update(&this.state, |s| {
76 s.set_val(1);
77 (true, ())
78 });
79 }
80 this
81 }
82}
83
/// Errors reported by `ShutdownBarrier` operations.
///
/// `Clone`, `Copy`, `PartialEq` and `Eq` are derived so callers can compare
/// errors directly (for example with `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownBarrierError {
    /// The barrier was already cancelled or its worker count had already
    /// reached zero when the operation was attempted.
    AlreadyShutdown,
}
88
// Marker impl: all `Error` methods keep their defaults, so no underlying
// source error is reported.
impl Error for ShutdownBarrierError {}
90
91impl Display for ShutdownBarrierError {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 write!(f, "{self:?}")
94 }
95}
96
97impl ShutdownBarrier {
98 /// Register another worker with the barrier.
99 ///
100 /// Returns Error if the barrier has already been completed. The barrier
101 /// starts with one worker (the one that is the parent of all work),
102 /// so this will never happen if you are careful not to invoke `spawn()`
103 /// after the parent task invokes `done()`
104 pub fn spawn(&self) -> Result<(), ShutdownBarrierError> {
105 let already_shutdown = unsafe {
106 atomic_try_update(&self.state, |s| {
107 let count = s.get_val();
108 if s.get_flag() || count == 0 {
109 (false, true) // already shutdown
110 } else {
111 s.set_val(count + 1);
112 (true, false)
113 }
114 })
115 };
116 if already_shutdown {
117 Err(ShutdownBarrierError::AlreadyShutdown)
118 } else {
119 Ok(())
120 }
121 }
122
123 /// Inform the barrier that whatever work all the workers are performing
124 /// has been cancelled. This call causes `wait()` to return immediately
125 /// with `cancelled = true`.
126 pub fn cancel(&self) -> Result<(), ShutdownBarrierError> {
127 let already_shutdown = unsafe {
128 atomic_try_update(&self.state, |s| {
129 let count = s.get_val();
130 if s.get_flag() || count == 0 {
131 (false, true)
132 } else {
133 s.set_flag(true);
134 (true, false)
135 }
136 })
137 };
138 if already_shutdown {
139 Err(ShutdownBarrierError::AlreadyShutdown)
140 } else {
141 // send true for cancellation; false on success
142 _ = self.broadcast.send(true);
143 Ok(())
144 }
145 }
146
147 /// Inform the barrier that a single worker has completed.
148 ///
149 /// Returns a `ShutdownBarrierDoneResult`, with `is_cancelled = true` if the
150 /// pool of work protected by the barrier was cancelled, and
151 /// `shutdown_leader = true` if this call to `done()` was the one that completed
152 /// the pool of work. Workers can check for `shutdown_leader = true` to
153 /// perform clean up logic outside the thread of control that invokes `done()`.
154 pub fn done(&self) -> Result<ShutdownBarrierDoneResult, ShutdownBarrierError> {
155 let done_result = unsafe {
156 atomic_try_update(&self.state, |s| {
157 let count = s.get_val();
158 s.set_val(count - 1);
159 if s.get_flag() {
160 (true, DoneResult::Cancelled)
161 } else if count == 0 {
162 (false, DoneResult::AlreadyDone)
163 } else if count == 1 {
164 (true, DoneResult::ShutdownLeader)
165 } else {
166 (true, DoneResult::Running)
167 }
168 })
169 };
170 match done_result {
171 DoneResult::Cancelled => Ok(ShutdownBarrierDoneResult {
172 cancelled: true,
173 shutdown_leader: false,
174 }),
175 DoneResult::ShutdownLeader => {
176 _ = self.broadcast.send(false);
177 Ok(ShutdownBarrierDoneResult {
178 cancelled: false,
179 shutdown_leader: true,
180 })
181 }
182 DoneResult::Running => Ok(ShutdownBarrierDoneResult {
183 cancelled: false,
184 shutdown_leader: false,
185 }),
186 DoneResult::AlreadyDone => Err(ShutdownBarrierError::AlreadyShutdown),
187 }
188 }
189
190 /// Waits until the number of workers reaches zero. This can be called at any time
191 /// and can be called multiple times.
192 pub async fn wait(&self) -> Result<ShutdownBarrierWaitResult, ShutdownBarrierError> {
193 // We have to subscribe before we check state. Otherwise, some some thread
194 // could send the shutdown message after our subscription begins!
195 let mut rx = self.broadcast.subscribe();
196 let wait_result = unsafe {
197 atomic_try_update(&self.state, |s| {
198 let count = s.get_val();
199 if s.get_flag() {
200 (false, WaitResult::Cancelled)
201 } else if count == 0 {
202 (true, WaitResult::Shutdown)
203 } else {
204 (false, WaitResult::StillRunning)
205 }
206 })
207 };
208 match wait_result {
209 WaitResult::StillRunning => {
210 let cancelled = rx
211 .recv()
212 .await
213 .map_err(|_| ShutdownBarrierError::AlreadyShutdown)?;
214 Ok(ShutdownBarrierWaitResult { cancelled })
215 }
216 WaitResult::Shutdown => Ok(ShutdownBarrierWaitResult { cancelled: false }),
217 WaitResult::Cancelled => Ok(ShutdownBarrierWaitResult { cancelled: true }),
218 }
219 }
220 /// Returns a new shutdown barrier with a single worker. The caller
221 /// should spawn() all the work that needs to be done, then invoke
222 /// done(). This makes sure the worker count doesn't spuriously
223 /// reach zero while work is being spawned.
224 pub fn new() -> Self {
225 Default::default()
226 }
227}