latches/task/mod.rs
1use core::{
2 fmt,
3 future::Future,
4 hint,
5 pin::Pin,
6 sync::atomic::{
7 AtomicUsize,
8 Ordering::{Acquire, Relaxed, Release},
9 },
10 task::{Context, Poll},
11};
12
13use crate::{lock::Mutex, macros, WaitTimeoutResult};
14
15use self::waiters::Waiters;
16
17#[cfg(test)]
18mod tests;
19
20mod waiters;
21
22/// A latch is a downward counter which can be used to coordinate tasks. The
23/// value of the counter is initialized on creation. Tasks may suspend on the
24/// latch until the counter is decremented to 0.
25///
26/// In contrast to [`Barrier`], it is a one-shot phenomenon, that mean the
27/// counter will not be reset after reaching 0. However, it has a useful
28/// property in that it does not make tasks wait for the counter to reach 0 by
29/// calling [`count_down()`] or [`arrive()`].
30///
31/// It spins on every polling of waiting futures.
32///
33/// # Examples
34///
35/// Created by `1` can be used as a simple gate, all tasks calling [`wait()`]
36/// will be suspended until a task calls [`count_down()`].
37///
38/// Created by `N` can be used to make one or more tasks wait until `N`
39/// operations have completed, or an operation has completed 'N' times.
40///
41/// [`Barrier`]: std::sync::Barrier
42/// [`Future`]: std::future::Future
43/// [`arrive()`]: Latch::arrive
44/// [`count_down()`]: Latch::count_down
45/// [`wait()`]: Latch::wait
46///
47/// ```
48/// # use tokio::{runtime::Builder, task};
49/// use std::sync::{
50/// atomic::{AtomicU32, Ordering},
51/// Arc, RwLock,
52/// };
53///
54/// use latches::task::Latch;
55///
56/// # Builder::new_multi_thread().build().unwrap().block_on(async move {
57/// let init_gate = Arc::new(Latch::new(1));
58/// let operation = Arc::new(Latch::new(30));
59/// let results = Arc::new(RwLock::new(Vec::<AtomicU32>::new()));
60///
61/// for i in 0..10 {
62/// let gate = init_gate.clone();
63/// let part = operation.clone();
64/// let res = results.clone();
65///
66/// // Each task need to process 3 operations
67/// task::spawn(async move {
68/// gate.wait().await;
69///
70/// let db = res.read().unwrap();
71/// for j in 0..3 {
72/// db[i * 3 + j].store((i * 3 + j) as u32, Ordering::Relaxed);
73/// part.count_down();
74/// }
75/// });
76/// }
77///
78/// let res = results.clone();
79/// task::spawn(async move {
80/// // Init some statuses, e.g. DB, File System, etc.
81/// let mut db = res.write().unwrap();
82/// for _ in 0..30 {
83/// db.push(AtomicU32::new(0));
84/// }
85/// init_gate.count_down();
86/// });
87///
88/// // All 30 operations will be done after this line
89/// // Or use operation.watch(T) to set the timeout
90/// operation.wait().await;
91///
92/// let res: Vec<_> = results.read()
93/// .unwrap()
94/// .iter()
95/// .map(|i| i.load(Ordering::Relaxed))
96/// .collect();
97/// assert_eq!(res, Vec::from_iter(0..30));
98/// # });
99/// ```
100pub struct Latch {
101 stat: AtomicUsize,
102 lock: Mutex<Waiters>,
103}
104
105impl Latch {
106 /// Creates a new latch initialized with the given count.
107 ///
108 /// # Examples
109 ///
110 /// ```
111 /// use latches::task::Latch;
112 ///
113 /// let latch = Latch::new(10);
114 /// # drop(latch);
115 /// ```
116 #[must_use]
117 #[inline]
118 pub const fn new(count: usize) -> Self {
119 Self {
120 stat: AtomicUsize::new(count),
121 lock: Mutex::new(Waiters::new()),
122 }
123 }
124
125 /// Decrements the latch count, wake up all pending tasks if the counter
126 /// reaches 0 after decrement.
127 ///
128 /// - If the counter has reached 0 then do nothing.
129 /// - If the current count is greater than 0 then it is decremented.
130 /// - If the new count is 0 then all pending tasks are waked up.
131 ///
132 /// # Examples
133 ///
134 /// ```
135 /// use latches::task::Latch;
136 ///
137 /// let latch = Latch::new(1);
138 /// latch.count_down();
139 /// ```
140 pub fn count_down(&self) {
141 macros::decrement!(self, 1);
142 }
143
144 /// Decrements the latch count by `n`, wake up all pending tasks if the
145 /// counter reaches 0 after decrement.
146 ///
147 /// It will not cause an overflow by decrement the counter.
148 ///
149 /// - If the `n` is 0 or the counter has reached 0 then do nothing.
150 /// - If the current count is greater than `n` then decremented by `n`.
151 /// - If the current count is greater than 0 and less than or equal to `n`,
152 /// then the new count will be 0, and all pending tasks are waked up.
153 ///
154 /// # Examples
155 ///
156 /// ```
157 /// use latches::task::Latch;
158 ///
159 /// let latch = Latch::new(10);
160 ///
161 /// // Do a batch upsert SQL and return `updatedRows` = 10 in runtime.
162 /// # let updatedRows = 10;
163 /// latch.arrive(updatedRows);
164 /// assert_eq!(latch.count(), 0);
165 /// ```
166 pub fn arrive(&self, n: usize) {
167 if n == 0 {
168 return;
169 }
170
171 macros::decrement!(self, n);
172 }
173
174 /// Acquires the current count.
175 ///
176 /// It is typically used for debugging and testing.
177 ///
178 /// # Examples
179 ///
180 /// ```
181 /// use latches::task::Latch;
182 ///
183 /// let latch = Latch::new(3);
184 /// assert_eq!(latch.count(), 3);
185 /// ```
186 #[must_use]
187 #[inline]
188 pub fn count(&self) -> usize {
189 self.stat.load(Acquire)
190 }
191
192 /// Checks that the counter has reached 0.
193 ///
194 /// # Errors
195 ///
196 /// This function will return an error with the current count if the
197 /// counter has not reached 0.
198 ///
199 /// # Examples
200 ///
201 /// ```
202 /// use latches::task::Latch;
203 ///
204 /// let latch = Latch::new(1);
205 /// assert_eq!(latch.try_wait(), Err(1));
206 /// latch.count_down();
207 /// assert_eq!(latch.try_wait(), Ok(()));
208 /// ```
209 #[inline]
210 pub fn try_wait(&self) -> Result<(), usize> {
211 macros::once_try_wait!(self)
212 }
213
214 /// Returns a future that suspends the current task to wait until the
215 /// counter reaches 0.
216 ///
217 /// When the future is polled:
218 ///
219 /// - If the current count is 0 then ready immediately.
220 /// - If the current count is greater than 0 then pending with a waker that
221 /// will be awakened by a [`count_down()`]/[`arrive()`] invocation which
222 /// causes the counter reaches 0.
223 ///
224 /// [`count_down()`]: Latch::count_down
225 /// [`arrive()`]: Latch::arrive
226 ///
227 /// # Examples
228 ///
229 /// ```
230 /// # use tokio::runtime::Builder;
231 /// use std::{sync::Arc, thread};
232 ///
233 /// use latches::task::Latch;
234 ///
235 /// # Builder::new_multi_thread().build().unwrap().block_on(async move {
236 /// let latch = Arc::new(Latch::new(1));
237 /// let l1 = latch.clone();
238 ///
239 /// thread::spawn(move || l1.count_down());
240 /// latch.wait().await;
241 /// # });
242 /// ```
243 #[inline]
244 pub const fn wait(&self) -> LatchWait<'_> {
245 LatchWait {
246 id: None,
247 latch: self,
248 }
249 }
250
251 /// Returns a future that suspends the current task to wait until the
252 /// counter reaches 0 or the timer done.
253 ///
254 /// It requires an asynchronous timer, which provides greater flexibility
255 /// for optimization. For example, some implementations provide higher
256 /// precision timers, while other implementations sacrifice timing accuracy
257 /// for performance. Some async libraries provide a global timer pool, if
258 /// your project is using these libraries you should consider using their
259 /// built-in timers first.
260 ///
261 /// When the future is polled:
262 ///
263 /// - If the current count is 0 then [`Reached`] ready immediately.
264 /// - If the timer is done then [`TimedOut(timer_res)`] ready immediately.
265 /// - If the current count is greater than 0 then pending with a waker that
266 /// will be awakened by a [`count_down()`]/[`arrive()`] invocation which
267 /// causes the counter reaches 0, or awakened by the timer.
268 ///
269 /// [`Reached`]: WaitTimeoutResult::Reached
270 /// [`TimedOut(timer_res)`]: WaitTimeoutResult::TimedOut
271 /// [`count_down()`]: Latch::count_down
272 /// [`arrive()`]: Latch::arrive
273 ///
274 /// # Examples
275 ///
276 /// This example shows how to extend your own `wait_timeout`.
277 ///
278 /// It is based on tokio, you can use other implementations that your
279 /// prefers, like async-std, futures-timer, async-io, gloo-timers, etc.
280 ///
281 /// ```
282 /// # use tokio::runtime::Builder;
283 /// use std::ops::Deref;
284 ///
285 /// use tokio::time::{sleep, Duration};
286 /// use latches::{task::Latch as Inner, WaitTimeoutResult as Res};
287 ///
288 /// #[repr(transparent)]
289 /// struct Latch(Inner);
290 ///
291 /// impl Latch {
292 /// const fn new(count: usize) -> Latch {
293 /// Latch(Inner::new(count))
294 /// }
295 /// }
296 ///
297 /// impl Latch {
298 /// async fn wait_timeout(&self, dur: Duration) -> Res<()> {
299 /// self.0.watch(sleep(dur)).await
300 /// }
301 /// }
302 ///
303 /// impl Deref for Latch {
304 /// type Target = Inner;
305 ///
306 /// fn deref(&self) -> &Self::Target {
307 /// &self.0
308 /// }
309 /// }
310 ///
311 /// # Builder::new_multi_thread().enable_time().build().unwrap()
312 /// # .block_on(async move {
313 /// let latch = Latch::new(3);
314 /// let dur = Duration::from_millis(10);
315 ///
316 /// latch.count_down();
317 /// assert!(latch.wait_timeout(dur).await.is_timed_out());
318 /// latch.arrive(2);
319 /// assert!(latch.wait_timeout(dur).await.is_reached());
320 /// # });
321 /// ```
322 #[inline]
323 pub const fn watch<T>(&self, timer: T) -> LatchWatch<'_, T> {
324 LatchWatch {
325 id: None,
326 latch: self,
327 timer,
328 }
329 }
330
331 fn spin(&self) -> bool {
332 macros::spin_try_wait!(self, s, true, s == 0);
333 }
334
335 #[cold]
336 fn done(&self) {
337 Waiters::wake_all(&self.lock);
338 }
339}
340
341/// Future returned by [`Latch::wait`].
342#[must_use = "futures do nothing unless you `.await` or poll them"]
343pub struct LatchWait<'a> {
344 id: Option<usize>,
345 latch: &'a Latch,
346}
347
348impl Future for LatchWait<'_> {
349 type Output = ();
350
351 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
352 let Self { latch, id } = self.get_mut();
353
354 if latch.spin() {
355 Poll::Ready(())
356 } else {
357 let mut lock = latch.lock.lock();
358
359 if latch.stat.load(Acquire) == 0 {
360 Poll::Ready(())
361 } else {
362 lock.upsert(id, cx.waker());
363
364 Poll::Pending
365 }
366 }
367 }
368}
369
370/// Future returned by [`Latch::watch`].
371#[must_use = "futures do nothing unless you `.await` or poll them"]
372pub struct LatchWatch<'a, T> {
373 id: Option<usize>,
374 latch: &'a Latch,
375 timer: T,
376}
377
378impl<T> LatchWatch<'_, T> {
379 /// Gets the pinned timer.
380 ///
381 /// It is typically used to reset, cancel or pre-boot the timer, this
382 /// depends on the timer implementation.
383 ///
384 /// # Examples
385 ///
386 /// This example shows how to reset a tokio timer, other libraries may or
387 /// may not have other ways to resetting timers.
388 ///
389 /// ```
390 /// # use tokio::runtime::Builder;
391 /// use std::pin::Pin;
392 ///
393 /// use tokio::time::{sleep, Duration, Instant};
394 /// use latches::task::Latch;
395 ///
396 /// # Builder::new_multi_thread().enable_time().build().unwrap()
397 /// # .block_on(async move {
398 /// let init_dur = Duration::from_millis(100);
399 /// let reset_dur = Duration::from_millis(10);
400 /// let latch = Latch::new(1);
401 /// let start = Instant::now();
402 /// let mut result = latch.watch(sleep(init_dur));
403 /// let mut result = unsafe { Pin::new_unchecked(&mut result) };
404 ///
405 /// result.as_mut()
406 /// .timer() // Get `Pin<&mut tokio::time::Sleep>` here
407 /// .reset(start + reset_dur);
408 /// result.await;
409 /// assert!((reset_dur..init_dur).contains(&start.elapsed()));
410 /// # });
411 /// ```
412 #[must_use]
413 #[inline]
414 pub fn timer(self: Pin<&mut Self>) -> Pin<&mut T> {
415 // SAFETY: LatchWatch does not implement Drop, not repr(packed),
416 // auto implement Unpin if T is Unpin cuz other fields are Unpin.
417 unsafe {
418 let Self { timer, .. } = self.get_unchecked_mut();
419 Pin::new_unchecked(timer)
420 }
421 }
422}
423
424impl<T: Future> Future for LatchWatch<'_, T> {
425 type Output = WaitTimeoutResult<T::Output>;
426
427 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
428 // SAFETY: LatchWatch does not implement Drop, not repr(packed),
429 // auto implement Unpin if T is Unpin cuz other fields are Unpin.
430 let Self { id, latch, timer } = unsafe { self.get_unchecked_mut() };
431 let timer = unsafe { Pin::new_unchecked(timer) };
432
433 if latch.spin() {
434 Poll::Ready(WaitTimeoutResult::Reached)
435 } else {
436 // Acquire lock after pulling timer, minimizing lock-in effects.
437 let out = timer.poll(cx);
438 let mut lock = latch.lock.lock();
439
440 if latch.stat.load(Acquire) == 0 {
441 Poll::Ready(WaitTimeoutResult::Reached)
442 } else {
443 match out {
444 Poll::Ready(t) => {
445 lock.remove(id);
446
447 Poll::Ready(WaitTimeoutResult::TimedOut(t))
448 }
449 Poll::Pending => {
450 lock.upsert(id, cx.waker());
451
452 Poll::Pending
453 }
454 }
455 }
456 }
457 }
458}
459
460impl fmt::Debug for Latch {
461 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462 f.debug_struct("Latch")
463 .field("count", &self.stat)
464 .finish_non_exhaustive()
465 }
466}
467
468impl fmt::Debug for LatchWait<'_> {
469 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
470 f.debug_struct("LatchWait").finish_non_exhaustive()
471 }
472}
473
474impl<T> fmt::Debug for LatchWatch<'_, T> {
475 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476 f.debug_struct("LatchWatch").finish_non_exhaustive()
477 }
478}