// zero_pool/future.rs

use std::sync::atomic::Ordering;
use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
use std::time::Duration;

use crate::padded_type::PaddedAtomicUsize;

7// public work future with arc wrapped fields
8#[derive(Clone)]
9pub struct WorkFuture {
10    remaining: Arc<PaddedAtomicUsize>,
11    state: Arc<(Mutex<()>, Condvar)>,
12}
13
14impl WorkFuture {
15    // create a new work future for the given number of tasks
16    pub(crate) fn new(task_count: usize) -> Self {
17        WorkFuture {
18            remaining: Arc::new(PaddedAtomicUsize::new(task_count)),
19            state: Arc::new((Mutex::new(()), Condvar::new())),
20        }
21    }
22
23    // check if all tasks are complete
24    pub fn is_complete(&self) -> bool {
25        self.remaining.load(Ordering::Acquire) == 0
26    }
27
28    // wait for all tasks to complete
29    pub fn wait(self) {
30        if self.is_complete() {
31            return;
32        }
33
34        let (lock, cvar) = &*self.state;
35        let mut guard = lock.lock().unwrap();
36
37        while !self.is_complete() {
38            guard = cvar.wait(guard).unwrap();
39        }
40    }
41
42    // wait for all tasks with timeout
43    pub fn wait_timeout(self, timeout: Duration) -> bool {
44        if self.is_complete() {
45            return true;
46        }
47
48        let (lock, cvar) = &*self.state;
49        let mut guard = lock.lock().unwrap();
50
51        while !self.is_complete() {
52            let (new_guard, timeout_result) = cvar.wait_timeout(guard, timeout).unwrap();
53            guard = new_guard;
54            if timeout_result.timed_out() {
55                return self.is_complete();
56            }
57        }
58        true
59    }
60
61    // get remaining task count
62    pub fn remaining_count(&self) -> usize {
63        self.remaining.load(Ordering::Relaxed)
64    }
65
66    // complets multiple tasks, decrements counter and notifies if all done
67    #[inline]
68    pub(crate) fn complete_many(&self, count: usize) {
69        let remaining_count = self.remaining.fetch_sub(count, Ordering::Release);
70
71        // if this completed the last tasks, notify waiters
72        if remaining_count == count {
73            let (lock, cvar) = &*self.state;
74            let _guard = lock.lock().unwrap();
75            cvar.notify_all();
76        }
77    }
78}