casper_node/utils/work_queue.rs
1//! Work queue for finite work.
2//!
3//! A queue that allows for processing a variable amount of work that may spawn more jobs, but is
4//! expected to finish eventually.
5
6use std::{
7 collections::VecDeque,
8 sync::{Arc, Mutex},
9};
10
11use futures::{stream, Stream};
12use tokio::sync::Notify;
13
14/// Multi-producer, multi-consumer async job queue with end conditions.
15///
16/// Keeps track of in-progress jobs and can indicate to workers that all work has been finished.
17/// Intended to be used for jobs that will spawn other jobs during processing, but stop once all
18/// jobs have finished.
19///
20/// # Example use
21///
22/// ```rust
23/// #![allow(non_snake_case)]
24/// # use std::{sync::Arc, time::Duration};
25/// #
26/// # use futures::stream::{futures_unordered::FuturesUnordered, StreamExt};
27/// #
28/// # use casper_node::utils::work_queue::WorkQueue;
29/// #
30/// type DemoJob = (&'static str, usize);
31///
32/// /// Job processing function.
33/// ///
34/// /// For a given job `(name, n)`, returns two jobs with `n = n - 1`, unless `n == 0`.
35/// async fn process_job(job: DemoJob) -> Vec<DemoJob> {
36/// tokio::time::sleep(Duration::from_millis(25)).await;
37///
38/// let (tag, n) = job;
39///
40/// if n == 0 {
41/// Vec::new()
42/// } else {
43/// vec![(tag, n - 1), (tag, n - 1)]
44/// }
45/// }
46///
47/// /// Job-processing worker.
48/// ///
49/// /// `id` is the worker ID for logging.
50/// async fn worker(id: usize, q: Arc<WorkQueue<DemoJob>>) {
51/// println!("worker {}: init", id);
52///
53/// while let Some(job) = q.next_job().await {
54/// println!("worker {}: start job {:?}", id, job.inner());
55/// for new_job in process_job(job.inner().clone()).await {
56/// q.push_job(new_job);
57/// }
58/// println!("worker {}: finish job {:?}", id, job.inner());
59/// }
60///
61/// println!("worker {}: shutting down", id);
62/// }
63///
64/// const WORKER_COUNT: usize = 3;
65/// #
66/// # async fn test_func() {
67/// let q = Arc::new(WorkQueue::default());
68/// q.push_job(("A", 3));
69///
70/// let workers: FuturesUnordered<_> = (0..WORKER_COUNT).map(|id| worker(id, q.clone())).collect();
71///
72/// // Wait for all workers to finish.
73/// workers.for_each(|_| async move {}).await;
74/// # }
75/// # let rt = tokio::runtime::Runtime::new().unwrap();
76/// # let handle = rt.handle();
77/// # handle.block_on(test_func());
78/// ```
#[derive(Debug)]
pub struct WorkQueue<T> {
    /// Inner workings of the queue: the queued jobs and the in-progress counter, kept behind a
    /// single mutex.
    inner: Mutex<QueueInner<T>>,
    /// Notifier for waiting tasks; all waiters are notified whenever a job is pushed or a job
    /// is marked as completed.
    notify: Notify,
}
86
/// Queue inner state.
///
/// Always accessed through the mutex held by `WorkQueue`.
#[derive(Debug)]
struct QueueInner<T> {
    /// Jobs currently in the queue, in the order they were pushed (new jobs are appended at the
    /// back, `next_job` pops from the front).
    jobs: VecDeque<T>,
    /// Number of jobs that have been popped from the queue using `next_job` but not finished.
    ///
    /// Incremented when a job is handed out and decremented when a job is marked as completed.
    in_progress: usize,
}
95
96// Manual default implementation, since the derivation would require a `T: Default` trait bound.
97impl<T> Default for WorkQueue<T> {
98 fn default() -> Self {
99 Self {
100 inner: Default::default(),
101 notify: Default::default(),
102 }
103 }
104}
105
106impl<T> Default for QueueInner<T> {
107 fn default() -> Self {
108 Self {
109 jobs: Default::default(),
110 in_progress: Default::default(),
111 }
112 }
113}
114
115impl<T> WorkQueue<T> {
116 /// Pop a job from the queue.
117 ///
118 /// If there is a job in the queue, returns the job and increases the internal in progress
119 /// counter by one.
120 ///
121 /// If there are still jobs in progress, but none queued, waits until either of these conditions
122 /// changes, then retries.
123 ///
124 /// If there are no jobs available and no jobs in progress, returns `None`.
125 pub async fn next_job(self: &Arc<Self>) -> Option<JobHandle<T>> {
126 loop {
127 let waiting;
128 {
129 let mut inner = self.inner.lock().expect("lock poisoned");
130 match inner.jobs.pop_front() {
131 Some(job) => {
132 // We got a job, increase the `in_progress` count and return.
133 inner.in_progress += 1;
134 return Some(JobHandle {
135 job,
136 queue: self.clone(),
137 });
138 }
139 None => {
140 // No job found. Check if we are completely done.
141 if inner.in_progress == 0 {
142 // No more jobs, no jobs in progress. We are done!
143 return None;
144 }
145
146 // Otherwise, we have to wait.
147 waiting = self.notify.notified();
148 }
149 }
150 }
151
152 // Note: Any notification sent while executing this segment (after the guard has been
153 // dropped, but before `waiting.await` has been entered) will still be picked up by
154 // `waiting.await`, as the call to `notified()` marks the beginning of the waiting
155 // period, not `waiting.await`. See `tests::notification_assumption_holds`.
156
157 // After freeing the lock, wait for a new job to arrive or be finished.
158 waiting.await;
159 }
160 }
161
162 /// Pushes a job onto the queue.
163 ///
164 /// If there are any worker waiting on `next_job`, one of them will receive the job.
165 pub fn push_job(&self, job: T) {
166 let mut inner = self.inner.lock().expect("lock poisoned");
167
168 inner.jobs.push_back(job);
169 self.notify.notify_waiters();
170 }
171
172 /// Returns the number of jobs in the queue.
173 pub fn num_jobs(&self) -> usize {
174 self.inner.lock().expect("lock poisoned").jobs.len()
175 }
176
177 /// Creates a streaming consumer of the work queue.
178 #[inline]
179 pub fn to_stream(self: Arc<Self>) -> impl Stream<Item = JobHandle<T>> {
180 stream::unfold(self, |work_queue| async move {
181 let next = work_queue.next_job().await;
182 next.map(|handle| (handle, work_queue))
183 })
184 }
185
186 /// Mark job completion.
187 ///
188 /// This is an internal function to be used by `JobHandle`, which locks the internal queue and
189 /// decreases the in-progress count by one.
190 fn complete_job(&self) {
191 let mut inner = self.inner.lock().expect("lock poisoned");
192
193 inner.in_progress -= 1;
194 self.notify.notify_waiters();
195 }
196}
197
/// Handle containing a job.
///
/// Holds a job popped from the job queue, together with a reference to the queue it was popped
/// from.
///
/// The job will be considered completed once `JobHandle` has been dropped: dropping the handle
/// calls back into the queue to decrease its in-progress count.
#[derive(Debug)]
pub struct JobHandle<T> {
    /// The protected job.
    job: T,
    /// Queue job was removed from; used to report completion on drop.
    queue: Arc<WorkQueue<T>>,
}
210
211impl<T> JobHandle<T> {
212 /// Returns a reference to the inner job.
213 pub fn inner(&self) -> &T {
214 &self.job
215 }
216}
217
218impl<T> Drop for JobHandle<T> {
219 fn drop(&mut self) {
220 self.queue.complete_job();
221 }
222}
223
224#[cfg(test)]
225mod tests {
226 use std::{
227 sync::{
228 atomic::{AtomicU32, Ordering},
229 Arc,
230 },
231 time::Duration,
232 };
233
234 use futures::{FutureExt, StreamExt};
235 use tokio::sync::Notify;
236
237 use super::WorkQueue;
238
    /// Job type used by the tests; wraps the number the workers add to their running sum.
    #[derive(Debug)]
    struct TestJob(u32);
241
242 // Verify that the assumption made about `Notification` -- namely that a call to `notified()` is
243 // enough to "register" the waiter -- holds.
244 #[test]
245 fn notification_assumption_holds() {
246 let not = Notify::new();
247
248 // First attempt to await a notification, should return pending.
249 assert!(not.notified().now_or_never().is_none());
250
251 // Second, we notify, then try notification again. Should also return pending, as we were
252 // "not around" when the notification happened.
253 not.notify_waiters();
254 assert!(not.notified().now_or_never().is_none());
255
256 // Finally, we "register" for notification beforehand.
257 let waiter = not.notified();
258 not.notify_waiters();
259 assert!(waiter.now_or_never().is_some());
260 }
261
262 /// Process a job, sleeping a short amout of time on every 5th job.
263 async fn job_worker_simple(queue: Arc<WorkQueue<TestJob>>, sum: Arc<AtomicU32>) {
264 while let Some(job) = queue.next_job().await {
265 if job.inner().0 % 5 == 0 {
266 tokio::time::sleep(Duration::from_millis(50)).await;
267 }
268
269 sum.fetch_add(job.inner().0, Ordering::SeqCst);
270 }
271 }
272
273 /// Process a job, sleeping a short amount of time on every job.
274 ///
275 /// Spawns two additional jobs for every job processed, decreasing the job number until reaching
276 /// zero.
277 async fn job_worker_binary(queue: Arc<WorkQueue<TestJob>>, sum: Arc<AtomicU32>) {
278 while let Some(job) = queue.next_job().await {
279 tokio::time::sleep(Duration::from_millis(10)).await;
280
281 sum.fetch_add(job.inner().0, Ordering::SeqCst);
282
283 if job.inner().0 > 0 {
284 queue.push_job(TestJob(job.inner().0 - 1));
285 queue.push_job(TestJob(job.inner().0 - 1));
286 }
287 }
288 }
289
290 #[tokio::test]
291 async fn empty_queue_exits_immediately() {
292 let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
293 assert!(q.next_job().await.is_none());
294 }
295
296 #[tokio::test]
297 async fn large_front_loaded_queue_terminates() {
298 let num_jobs = 1_000;
299 let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
300 for job in (0..num_jobs).map(TestJob) {
301 q.push_job(job);
302 }
303
304 let mut workers = Vec::new();
305 let output = Arc::new(AtomicU32::new(0));
306 for _ in 0..3 {
307 workers.push(tokio::spawn(job_worker_simple(q.clone(), output.clone())));
308 }
309
310 // We use a different pattern for waiting here, see the doctest for a solution that does not
311 // spawn.
312 for worker in workers {
313 worker.await.expect("task panicked");
314 }
315
316 let expected_total = (num_jobs * (num_jobs - 1)) / 2;
317 assert_eq!(output.load(Ordering::SeqCst), expected_total);
318 }
319
320 #[tokio::test]
321 async fn stream_interface_works() {
322 let num_jobs = 1_000;
323 let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
324 for job in (0..num_jobs).map(TestJob) {
325 q.push_job(job);
326 }
327
328 let mut current = 0;
329 let mut stream = Box::pin(q.to_stream());
330 while let Some(job) = stream.next().await {
331 assert_eq!(job.inner().0, current);
332 current += 1;
333 }
334 }
335
336 #[tokio::test]
337 async fn complex_queue_terminates() {
338 let num_jobs = 5;
339 let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
340 for _ in 0..num_jobs {
341 q.push_job(TestJob(num_jobs));
342 }
343
344 let mut workers = Vec::new();
345 let output = Arc::new(AtomicU32::new(0));
346 for _ in 0..3 {
347 workers.push(tokio::spawn(job_worker_binary(q.clone(), output.clone())));
348 }
349
350 // We use a different pattern for waiting here, see the doctest for a solution that does not
351 // spawn.
352 for worker in workers {
353 worker.await.expect("task panicked");
354 }
355
356 // A single job starting at `k` will add `SUM_{n=0}^{k} (k-n) * 2^n`, which is
357 // 57 for `k=5`. We start 5 jobs, so we expect `5 * 57 = 285` to be the result.
358 let expected_total = 285;
359 assert_eq!(output.load(Ordering::SeqCst), expected_total);
360 }
361}