casper_node/utils/work_queue.rs

//! Work queue for finite work.
//!
//! A queue that allows for processing a variable amount of work that may spawn more jobs, but is
//! expected to finish eventually.

use std::{
    collections::VecDeque,
    sync::{Arc, Mutex},
};

use futures::{stream, Stream};
use tokio::sync::Notify;

/// Multi-producer, multi-consumer async job queue with end conditions.
///
/// Keeps track of in-progress jobs and can indicate to workers that all work has been finished.
/// Intended to be used for jobs that will spawn other jobs during processing, but stop once all
/// jobs have finished.
///
/// # Example use
///
/// ```rust
/// #![allow(non_snake_case)]
/// # use std::{sync::Arc, time::Duration};
/// #
/// # use futures::stream::{futures_unordered::FuturesUnordered, StreamExt};
/// #
/// # use casper_node::utils::work_queue::WorkQueue;
/// #
/// type DemoJob = (&'static str, usize);
///
/// /// Job processing function.
/// ///
/// /// For a given job `(name, n)`, returns two jobs with `n = n - 1`, unless `n == 0`.
/// async fn process_job(job: DemoJob) -> Vec<DemoJob> {
///     tokio::time::sleep(Duration::from_millis(25)).await;
///
///     let (tag, n) = job;
///
///     if n == 0 {
///         Vec::new()
///     } else {
///         vec![(tag, n - 1), (tag, n - 1)]
///     }
/// }
///
/// /// Job-processing worker.
/// ///
/// /// `id` is the worker ID for logging.
/// async fn worker(id: usize, q: Arc<WorkQueue<DemoJob>>) {
///     println!("worker {}: init", id);
///
///     while let Some(job) = q.next_job().await {
///         println!("worker {}: start job {:?}", id, job.inner());
///         for new_job in process_job(job.inner().clone()).await {
///             q.push_job(new_job);
///         }
///         println!("worker {}: finish job {:?}", id, job.inner());
///     }
///
///     println!("worker {}: shutting down", id);
/// }
///
/// const WORKER_COUNT: usize = 3;
/// #
/// # async fn test_func() {
/// let q = Arc::new(WorkQueue::default());
/// q.push_job(("A", 3));
///
/// let workers: FuturesUnordered<_> = (0..WORKER_COUNT).map(|id| worker(id, q.clone())).collect();
///
/// // Wait for all workers to finish.
/// workers.for_each(|_| async move {}).await;
/// # }
/// # let rt = tokio::runtime::Runtime::new().unwrap();
/// # let handle = rt.handle();
/// # handle.block_on(test_func());
/// ```
#[derive(Debug)]
pub struct WorkQueue<T> {
    /// Inner workings of the queue.
    inner: Mutex<QueueInner<T>>,
    /// Notifier for waiting tasks.
    notify: Notify,
}

/// Queue inner state.
#[derive(Debug)]
struct QueueInner<T> {
    /// Jobs currently in the queue.
    jobs: VecDeque<T>,
    /// Number of jobs that have been popped from the queue using `next_job` but not finished.
    in_progress: usize,
}

// Manual `Default` implementation, since deriving it would require a `T: Default` trait bound.
impl<T> Default for WorkQueue<T> {
    fn default() -> Self {
        Self {
            inner: Default::default(),
            notify: Default::default(),
        }
    }
}

impl<T> Default for QueueInner<T> {
    fn default() -> Self {
        Self {
            jobs: Default::default(),
            in_progress: Default::default(),
        }
    }
}

impl<T> WorkQueue<T> {
    /// Pop a job from the queue.
    ///
    /// If there is a job in the queue, returns the job and increases the internal in-progress
    /// counter by one.
    ///
    /// If there are still jobs in progress, but none queued, waits until either of these conditions
    /// changes, then retries.
    ///
    /// If there are no jobs available and no jobs in progress, returns `None`.
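    ///
    /// # Example
    ///
    /// A minimal sketch of the end condition, using `u32` jobs purely for illustration:
    ///
    /// ```rust
    /// # use std::sync::Arc;
    /// # use casper_node::utils::work_queue::WorkQueue;
    /// # async fn example() {
    /// // `u32` is an arbitrary job type chosen for this sketch.
    /// let q: Arc<WorkQueue<u32>> = Arc::new(WorkQueue::default());
    ///
    /// // Nothing queued and nothing in progress: `next_job` resolves to `None` right away.
    /// assert!(q.next_job().await.is_none());
    /// # }
    /// # let rt = tokio::runtime::Runtime::new().unwrap();
    /// # rt.block_on(example());
    /// ```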
    pub async fn next_job(self: &Arc<Self>) -> Option<JobHandle<T>> {
        loop {
            let waiting;
            {
                let mut inner = self.inner.lock().expect("lock poisoned");
                match inner.jobs.pop_front() {
                    Some(job) => {
                        // We got a job, increase the `in_progress` count and return.
                        inner.in_progress += 1;
                        return Some(JobHandle {
                            job,
                            queue: self.clone(),
                        });
                    }
                    None => {
                        // No job found. Check if we are completely done.
                        if inner.in_progress == 0 {
                            // No more jobs, no jobs in progress. We are done!
                            return None;
                        }

                        // Otherwise, we have to wait.
                        waiting = self.notify.notified();
                    }
                }
            }

            // Note: Any notification sent while executing this segment (after the guard has been
            // dropped, but before `waiting.await` has been entered) will still be picked up by
            // `waiting.await`, as the call to `notified()` marks the beginning of the waiting
            // period, not `waiting.await`. See `tests::notification_assumption_holds`.

            // After releasing the lock, wait for a new job to arrive or for a job to finish.
            waiting.await;
        }
    }

    /// Pushes a job onto the queue.
    ///
    /// If there are any workers waiting on `next_job`, one of them will receive the job.
    pub fn push_job(&self, job: T) {
        let mut inner = self.inner.lock().expect("lock poisoned");

        inner.jobs.push_back(job);
        self.notify.notify_waiters();
    }

    /// Returns the number of jobs in the queue.
    pub fn num_jobs(&self) -> usize {
        self.inner.lock().expect("lock poisoned").jobs.len()
    }

    /// Creates a streaming consumer of the work queue.
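    ///
    /// # Example
    ///
    /// A minimal usage sketch, assuming `u32` jobs purely for illustration; the stream yields
    /// job handles until all work is done:
    ///
    /// ```rust
    /// # use std::sync::Arc;
    /// # use futures::StreamExt;
    /// # use casper_node::utils::work_queue::WorkQueue;
    /// # async fn example() {
    /// let q: Arc<WorkQueue<u32>> = Arc::new(WorkQueue::default());
    /// q.push_job(1);
    /// q.push_job(2);
    ///
    /// // Pin the stream so `next()` can be called, then drain it. The stream ends once no
    /// // jobs are queued and none are in progress.
    /// let mut stream = Box::pin(q.to_stream());
    /// while let Some(handle) = stream.next().await {
    ///     println!("job: {}", handle.inner());
    /// }
    /// # }
    /// # let rt = tokio::runtime::Runtime::new().unwrap();
    /// # rt.block_on(example());
    /// ```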
    #[inline]
    pub fn to_stream(self: Arc<Self>) -> impl Stream<Item = JobHandle<T>> {
        stream::unfold(self, |work_queue| async move {
            let next = work_queue.next_job().await;
            next.map(|handle| (handle, work_queue))
        })
    }

    /// Mark job completion.
    ///
    /// This is an internal function used by `JobHandle`; it locks the internal queue and decreases
    /// the in-progress count by one.
    fn complete_job(&self) {
        let mut inner = self.inner.lock().expect("lock poisoned");

        inner.in_progress -= 1;
        self.notify.notify_waiters();
    }
}

/// Handle containing a job.
///
/// Holds a job popped from the job queue.
///
/// The job will be considered completed once `JobHandle` has been dropped.
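///
/// # Example
///
/// A minimal sketch of the drop semantics, using a `u32` job purely for illustration:
///
/// ```rust
/// # use std::sync::Arc;
/// # use casper_node::utils::work_queue::WorkQueue;
/// # async fn example() {
/// let q: Arc<WorkQueue<u32>> = Arc::new(WorkQueue::default());
/// q.push_job(42);
///
/// let handle = q.next_job().await.expect("a job was queued");
/// assert_eq!(*handle.inner(), 42);
///
/// // Dropping the handle marks the job as finished; with nothing queued and nothing in
/// // progress, the queue now reports completion.
/// drop(handle);
/// assert!(q.next_job().await.is_none());
/// # }
/// # let rt = tokio::runtime::Runtime::new().unwrap();
/// # rt.block_on(example());
/// ```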
#[derive(Debug)]
pub struct JobHandle<T> {
    /// The protected job.
    job: T,
    /// Queue job was removed from.
    queue: Arc<WorkQueue<T>>,
}

impl<T> JobHandle<T> {
    /// Returns a reference to the inner job.
    pub fn inner(&self) -> &T {
        &self.job
    }
}

impl<T> Drop for JobHandle<T> {
    fn drop(&mut self) {
        self.queue.complete_job();
    }
}

#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicU32, Ordering},
            Arc,
        },
        time::Duration,
    };

    use futures::{FutureExt, StreamExt};
    use tokio::sync::Notify;

    use super::WorkQueue;

    #[derive(Debug)]
    struct TestJob(u32);

    // Verify that the assumption made about `Notify` -- namely that a call to `notified()` is
    // enough to "register" the waiter -- holds.
    #[test]
    fn notification_assumption_holds() {
        let not = Notify::new();

        // First attempt to await a notification, should return pending.
        assert!(not.notified().now_or_never().is_none());

        // Second, we notify, then try notification again. Should also return pending, as we were
        // "not around" when the notification happened.
        not.notify_waiters();
        assert!(not.notified().now_or_never().is_none());

        // Finally, we "register" for notification beforehand.
        let waiter = not.notified();
        not.notify_waiters();
        assert!(waiter.now_or_never().is_some());
    }

    /// Process jobs, sleeping a short amount of time on every 5th job.
    async fn job_worker_simple(queue: Arc<WorkQueue<TestJob>>, sum: Arc<AtomicU32>) {
        while let Some(job) = queue.next_job().await {
            if job.inner().0 % 5 == 0 {
                tokio::time::sleep(Duration::from_millis(50)).await;
            }

            sum.fetch_add(job.inner().0, Ordering::SeqCst);
        }
    }

    /// Process jobs, sleeping a short amount of time on every job.
    ///
    /// Spawns two additional jobs for every job processed, decreasing the job number until reaching
    /// zero.
    async fn job_worker_binary(queue: Arc<WorkQueue<TestJob>>, sum: Arc<AtomicU32>) {
        while let Some(job) = queue.next_job().await {
            tokio::time::sleep(Duration::from_millis(10)).await;

            sum.fetch_add(job.inner().0, Ordering::SeqCst);

            if job.inner().0 > 0 {
                queue.push_job(TestJob(job.inner().0 - 1));
                queue.push_job(TestJob(job.inner().0 - 1));
            }
        }
    }

    #[tokio::test]
    async fn empty_queue_exits_immediately() {
        let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
        assert!(q.next_job().await.is_none());
    }

    #[tokio::test]
    async fn large_front_loaded_queue_terminates() {
        let num_jobs = 1_000;
        let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
        for job in (0..num_jobs).map(TestJob) {
            q.push_job(job);
        }

        let mut workers = Vec::new();
        let output = Arc::new(AtomicU32::new(0));
        for _ in 0..3 {
            workers.push(tokio::spawn(job_worker_simple(q.clone(), output.clone())));
        }

        // We use a different pattern for waiting here; see the doctest for a solution that does not
        // spawn.
        for worker in workers {
            worker.await.expect("task panicked");
        }

        let expected_total = (num_jobs * (num_jobs - 1)) / 2;
        assert_eq!(output.load(Ordering::SeqCst), expected_total);
    }

    #[tokio::test]
    async fn stream_interface_works() {
        let num_jobs = 1_000;
        let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
        for job in (0..num_jobs).map(TestJob) {
            q.push_job(job);
        }

        let mut current = 0;
        let mut stream = Box::pin(q.to_stream());
        while let Some(job) = stream.next().await {
            assert_eq!(job.inner().0, current);
            current += 1;
        }
    }

    #[tokio::test]
    async fn complex_queue_terminates() {
        let num_jobs = 5;
        let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
        for _ in 0..num_jobs {
            q.push_job(TestJob(num_jobs));
        }

        let mut workers = Vec::new();
        let output = Arc::new(AtomicU32::new(0));
        for _ in 0..3 {
            workers.push(tokio::spawn(job_worker_binary(q.clone(), output.clone())));
        }

        // We use a different pattern for waiting here; see the doctest for a solution that does not
        // spawn.
        for worker in workers {
            worker.await.expect("task panicked");
        }

        // A single job starting at `k` will add `SUM_{n=0}^{k} (k-n) * 2^n`, which is
        // 57 for `k=5`. We start 5 jobs, so we expect `5 * 57 = 285` to be the result.
        let expected_total = 285;
        assert_eq!(output.load(Ordering::SeqCst), expected_total);
    }
}