latches/task/
mod.rs

1use core::{
2    fmt,
3    future::Future,
4    hint,
5    pin::Pin,
6    sync::atomic::{
7        AtomicUsize,
8        Ordering::{Acquire, Relaxed, Release},
9    },
10    task::{Context, Poll},
11};
12
13use crate::{lock::Mutex, macros, WaitTimeoutResult};
14
15use self::waiters::Waiters;
16
17#[cfg(test)]
18mod tests;
19
20mod waiters;
21
/// A latch is a downward counter which can be used to coordinate tasks. The
/// value of the counter is initialized on creation. Tasks may suspend on the
/// latch until the counter is decremented to 0.
///
/// In contrast to [`Barrier`], it is a one-shot phenomenon, which means the
/// counter will not be reset after reaching 0. However, it has a useful
/// property in that calling [`count_down()`] or [`arrive()`] never makes the
/// caller wait for the counter to reach 0.
///
/// It spins on every polling of waiting futures.
///
/// # Examples
///
/// Created with `1`, it can be used as a simple gate: all tasks calling
/// [`wait()`] will be suspended until a task calls [`count_down()`].
///
/// Created with `N`, it can be used to make one or more tasks wait until `N`
/// operations have completed, or an operation has completed `N` times.
///
/// [`Barrier`]: std::sync::Barrier
/// [`Future`]: std::future::Future
/// [`arrive()`]: Latch::arrive
/// [`count_down()`]: Latch::count_down
/// [`wait()`]: Latch::wait
///
/// ```
/// # use tokio::{runtime::Builder, task};
/// use std::sync::{
///     atomic::{AtomicU32, Ordering},
///     Arc, RwLock,
/// };
///
/// use latches::task::Latch;
///
/// # Builder::new_multi_thread().build().unwrap().block_on(async move {
/// let init_gate = Arc::new(Latch::new(1));
/// let operation = Arc::new(Latch::new(30));
/// let results = Arc::new(RwLock::new(Vec::<AtomicU32>::new()));
///
/// for i in 0..10 {
///     let gate = init_gate.clone();
///     let part = operation.clone();
///     let res = results.clone();
///
///     // Each task needs to process 3 operations
///     task::spawn(async move {
///         gate.wait().await;
///
///         let db = res.read().unwrap();
///         for j in 0..3 {
///             db[i * 3 + j].store((i * 3 + j) as u32, Ordering::Relaxed);
///             part.count_down();
///         }
///     });
/// }
///
/// let res = results.clone();
/// task::spawn(async move {
///     // Init some statuses, e.g. DB, File System, etc.
///     let mut db = res.write().unwrap();
///     for _ in 0..30 {
///         db.push(AtomicU32::new(0));
///     }
///     init_gate.count_down();
/// });
///
/// // All 30 operations will be done after this line
/// // Or use operation.watch(T) to set the timeout
/// operation.wait().await;
///
/// let res: Vec<_> = results.read()
///     .unwrap()
///     .iter()
///     .map(|i| i.load(Ordering::Relaxed))
///     .collect();
/// assert_eq!(res, Vec::from_iter(0..30));
/// # });
/// ```
pub struct Latch {
    /// The current count; 0 means the latch has been released.
    stat: AtomicUsize,
    /// Pending wakers registered by `wait()`/`watch()` futures; drained and
    /// woken by `done()` once the counter reaches 0.
    lock: Mutex<Waiters>,
}
104
impl Latch {
    /// Creates a new latch initialized with the given count.
    ///
    /// # Examples
    ///
    /// ```
    /// use latches::task::Latch;
    ///
    /// let latch = Latch::new(10);
    /// # drop(latch);
    /// ```
    #[must_use]
    #[inline]
    pub const fn new(count: usize) -> Self {
        Self {
            stat: AtomicUsize::new(count),
            lock: Mutex::new(Waiters::new()),
        }
    }

    /// Decrements the latch count, waking up all pending tasks if the
    /// counter reaches 0 after the decrement.
    ///
    /// - If the counter has already reached 0 then do nothing.
    /// - If the current count is greater than 0 then it is decremented.
    /// - If the new count is 0 then all pending tasks are woken up.
    ///
    /// # Examples
    ///
    /// ```
    /// use latches::task::Latch;
    ///
    /// let latch = Latch::new(1);
    /// latch.count_down();
    /// ```
    pub fn count_down(&self) {
        macros::decrement!(self, 1);
    }

    /// Decrements the latch count by `n`, waking up all pending tasks if the
    /// counter reaches 0 after the decrement.
    ///
    /// It will not cause an overflow by decrementing the counter.
    ///
    /// - If the `n` is 0 or the counter has already reached 0 then do
    ///   nothing.
    /// - If the current count is greater than `n` then it is decremented by
    ///   `n`.
    /// - If the current count is greater than 0 and less than or equal to
    ///   `n`, then the new count will be 0, and all pending tasks are woken
    ///   up.
    ///
    /// # Examples
    ///
    /// ```
    /// use latches::task::Latch;
    ///
    /// let latch = Latch::new(10);
    ///
    /// // Do a batch upsert SQL and return `updated_rows` = 10 in runtime.
    /// # let updated_rows = 10;
    /// latch.arrive(updated_rows);
    /// assert_eq!(latch.count(), 0);
    /// ```
    pub fn arrive(&self, n: usize) {
        // Decrementing by 0 can never release the latch, so skip the
        // atomic traffic entirely.
        if n == 0 {
            return;
        }

        macros::decrement!(self, n);
    }

    /// Acquires the current count.
    ///
    /// It is typically used for debugging and testing.
    ///
    /// # Examples
    ///
    /// ```
    /// use latches::task::Latch;
    ///
    /// let latch = Latch::new(3);
    /// assert_eq!(latch.count(), 3);
    /// ```
    #[must_use]
    #[inline]
    pub fn count(&self) -> usize {
        self.stat.load(Acquire)
    }

    /// Checks that the counter has reached 0.
    ///
    /// # Errors
    ///
    /// This function will return an error with the current count if the
    /// counter has not reached 0.
    ///
    /// # Examples
    ///
    /// ```
    /// use latches::task::Latch;
    ///
    /// let latch = Latch::new(1);
    /// assert_eq!(latch.try_wait(), Err(1));
    /// latch.count_down();
    /// assert_eq!(latch.try_wait(), Ok(()));
    /// ```
    #[inline]
    pub fn try_wait(&self) -> Result<(), usize> {
        macros::once_try_wait!(self)
    }

    /// Returns a future that suspends the current task to wait until the
    /// counter reaches 0.
    ///
    /// When the future is polled:
    ///
    /// - If the current count is 0 then ready immediately.
    /// - If the current count is greater than 0 then pending with a waker
    ///   that will be awakened by a [`count_down()`]/[`arrive()`] invocation
    ///   which causes the counter to reach 0.
    ///
    /// [`count_down()`]: Latch::count_down
    /// [`arrive()`]: Latch::arrive
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::runtime::Builder;
    /// use std::{sync::Arc, thread};
    ///
    /// use latches::task::Latch;
    ///
    /// # Builder::new_multi_thread().build().unwrap().block_on(async move {
    /// let latch = Arc::new(Latch::new(1));
    /// let l1 = latch.clone();
    ///
    /// thread::spawn(move || l1.count_down());
    /// latch.wait().await;
    /// # });
    /// ```
    #[inline]
    pub const fn wait(&self) -> LatchWait<'_> {
        // `id` starts as `None`; the waiter list assigns one on the first
        // pending poll (see `Waiters::upsert`).
        LatchWait {
            id: None,
            latch: self,
        }
    }

    /// Returns a future that suspends the current task to wait until the
    /// counter reaches 0 or the timer completes.
    ///
    /// It requires an asynchronous timer, which provides greater flexibility
    /// for optimization. For example, some implementations provide higher
    /// precision timers, while other implementations sacrifice timing
    /// accuracy for performance. Some async libraries provide a global timer
    /// pool; if your project is using these libraries you should consider
    /// using their built-in timers first.
    ///
    /// When the future is polled:
    ///
    /// - If the current count is 0 then [`Reached`] ready immediately.
    /// - If the timer is done then [`TimedOut(timer_res)`] ready immediately.
    /// - If the current count is greater than 0 then pending with a waker
    ///   that will be awakened by a [`count_down()`]/[`arrive()`] invocation
    ///   which causes the counter to reach 0, or awakened by the timer.
    ///
    /// [`Reached`]: WaitTimeoutResult::Reached
    /// [`TimedOut(timer_res)`]: WaitTimeoutResult::TimedOut
    /// [`count_down()`]: Latch::count_down
    /// [`arrive()`]: Latch::arrive
    ///
    /// # Examples
    ///
    /// This example shows how to extend your own `wait_timeout`.
    ///
    /// It is based on tokio; you can use other implementations that you
    /// prefer, like async-std, futures-timer, async-io, gloo-timers, etc.
    ///
    /// ```
    /// # use tokio::runtime::Builder;
    /// use std::ops::Deref;
    ///
    /// use tokio::time::{sleep, Duration};
    /// use latches::{task::Latch as Inner, WaitTimeoutResult as Res};
    ///
    /// #[repr(transparent)]
    /// struct Latch(Inner);
    ///
    /// impl Latch {
    ///     const fn new(count: usize) -> Latch {
    ///         Latch(Inner::new(count))
    ///     }
    /// }
    ///
    /// impl Latch {
    ///     async fn wait_timeout(&self, dur: Duration) -> Res<()> {
    ///         self.0.watch(sleep(dur)).await
    ///     }
    /// }
    ///
    /// impl Deref for Latch {
    ///     type Target = Inner;
    ///
    ///     fn deref(&self) -> &Self::Target {
    ///         &self.0
    ///     }
    /// }
    ///
    /// # Builder::new_multi_thread().enable_time().build().unwrap()
    /// # .block_on(async move {
    /// let latch = Latch::new(3);
    /// let dur = Duration::from_millis(10);
    ///
    /// latch.count_down();
    /// assert!(latch.wait_timeout(dur).await.is_timed_out());
    /// latch.arrive(2);
    /// assert!(latch.wait_timeout(dur).await.is_reached());
    /// # });
    /// ```
    #[inline]
    pub const fn watch<T>(&self, timer: T) -> LatchWatch<'_, T> {
        LatchWatch {
            id: None,
            latch: self,
            timer,
        }
    }

    /// Fast-path check used by the wait futures: returns `true` once the
    /// counter (`s`) has reached 0.
    // NOTE(review): the spin/backoff strategy lives in
    // `macros::spin_try_wait!`; this wrapper only supplies the "done"
    // predicate (`s == 0`) — confirm details against the macro definition.
    fn spin(&self) -> bool {
        macros::spin_try_wait!(self, s, true, s == 0);
    }

    /// Wakes every task currently registered in the waiter list.
    // `#[cold]`: expected to run at most once per latch, when the counter
    // first hits 0, so keep it off the hot decrement path.
    #[cold]
    fn done(&self) {
        Waiters::wake_all(&self.lock);
    }
}
340
/// Future returned by [`Latch::wait`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct LatchWait<'a> {
    /// Slot id in the latch's waiter list; `None` until the first pending
    /// poll registers this future.
    id: Option<usize>,
    /// The latch this future is waiting on.
    latch: &'a Latch,
}
347
348impl Future for LatchWait<'_> {
349    type Output = ();
350
351    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
352        let Self { latch, id } = self.get_mut();
353
354        if latch.spin() {
355            Poll::Ready(())
356        } else {
357            let mut lock = latch.lock.lock();
358
359            if latch.stat.load(Acquire) == 0 {
360                Poll::Ready(())
361            } else {
362                lock.upsert(id, cx.waker());
363
364                Poll::Pending
365            }
366        }
367    }
368}
369
/// Future returned by [`Latch::watch`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct LatchWatch<'a, T> {
    /// Slot id in the latch's waiter list; `None` until the first pending
    /// poll registers this future.
    id: Option<usize>,
    /// The latch this future is waiting on.
    latch: &'a Latch,
    /// User-supplied timer future; structurally pinned (see `timer()` and
    /// the `poll` impl).
    timer: T,
}
377
378impl<T> LatchWatch<'_, T> {
379    /// Gets the pinned timer.
380    ///
381    /// It is typically used to reset, cancel or pre-boot the timer, this
382    /// depends on the timer implementation.
383    ///
384    /// # Examples
385    ///
386    /// This example shows how to reset a tokio timer, other libraries may or
387    /// may not have other ways to resetting timers.
388    ///
389    /// ```
390    /// # use tokio::runtime::Builder;
391    /// use std::pin::Pin;
392    ///
393    /// use tokio::time::{sleep, Duration, Instant};
394    /// use latches::task::Latch;
395    ///
396    /// # Builder::new_multi_thread().enable_time().build().unwrap()
397    /// # .block_on(async move {
398    /// let init_dur = Duration::from_millis(100);
399    /// let reset_dur = Duration::from_millis(10);
400    /// let latch = Latch::new(1);
401    /// let start = Instant::now();
402    /// let mut result = latch.watch(sleep(init_dur));
403    /// let mut result = unsafe { Pin::new_unchecked(&mut result) };
404    ///
405    /// result.as_mut()
406    ///     .timer() // Get `Pin<&mut tokio::time::Sleep>` here
407    ///     .reset(start + reset_dur);
408    /// result.await;
409    /// assert!((reset_dur..init_dur).contains(&start.elapsed()));
410    /// # });
411    /// ```
412    #[must_use]
413    #[inline]
414    pub fn timer(self: Pin<&mut Self>) -> Pin<&mut T> {
415        // SAFETY: LatchWatch does not implement Drop, not repr(packed),
416        // auto implement Unpin if T is Unpin cuz other fields are Unpin.
417        unsafe {
418            let Self { timer, .. } = self.get_unchecked_mut();
419            Pin::new_unchecked(timer)
420        }
421    }
422}
423
424impl<T: Future> Future for LatchWatch<'_, T> {
425    type Output = WaitTimeoutResult<T::Output>;
426
427    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
428        // SAFETY: LatchWatch does not implement Drop, not repr(packed),
429        // auto implement Unpin if T is Unpin cuz other fields are Unpin.
430        let Self { id, latch, timer } = unsafe { self.get_unchecked_mut() };
431        let timer = unsafe { Pin::new_unchecked(timer) };
432
433        if latch.spin() {
434            Poll::Ready(WaitTimeoutResult::Reached)
435        } else {
436            // Acquire lock after pulling timer, minimizing lock-in effects.
437            let out = timer.poll(cx);
438            let mut lock = latch.lock.lock();
439
440            if latch.stat.load(Acquire) == 0 {
441                Poll::Ready(WaitTimeoutResult::Reached)
442            } else {
443                match out {
444                    Poll::Ready(t) => {
445                        lock.remove(id);
446
447                        Poll::Ready(WaitTimeoutResult::TimedOut(t))
448                    }
449                    Poll::Pending => {
450                        lock.upsert(id, cx.waker());
451
452                        Poll::Pending
453                    }
454                }
455            }
456        }
457    }
458}
459
460impl fmt::Debug for Latch {
461    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462        f.debug_struct("Latch")
463            .field("count", &self.stat)
464            .finish_non_exhaustive()
465    }
466}
467
468impl fmt::Debug for LatchWait<'_> {
469    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
470        f.debug_struct("LatchWait").finish_non_exhaustive()
471    }
472}
473
474impl<T> fmt::Debug for LatchWatch<'_, T> {
475    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476        f.debug_struct("LatchWatch").finish_non_exhaustive()
477    }
478}