use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
use futures::{stream, Stream};
use tokio::sync::Notify;
/// A shared queue of jobs that can be consumed concurrently by multiple async workers.
///
/// Workers take jobs with [`WorkQueue::next_job`] (or [`WorkQueue::to_stream`]). A taken
/// job counts as in progress until its [`JobHandle`] is dropped. `next_job` keeps waiting
/// while the queue is empty but some job is still in progress, since more jobs may be
/// pushed meanwhile. It returns `None` only once the queue is empty *and* nothing is in
/// progress.
#[derive(Debug)]
pub struct WorkQueue<T> {
/// Pending jobs plus the in-progress count, kept under one lock so both are always
/// observed consistently.
inner: Mutex<QueueInner<T>>,
/// Wakes workers blocked in `next_job` so they re-check the queue state.
notify: Notify,
}
/// Mutable queue state; only ever accessed through `WorkQueue::inner`'s lock.
#[derive(Debug)]
struct QueueInner<T> {
/// Jobs waiting to be handed out, in FIFO order (pushed at the back, popped from the front).
jobs: VecDeque<T>,
/// Number of jobs handed out whose `JobHandle` has not been dropped yet.
in_progress: usize,
}
impl<T> Default for WorkQueue<T> {
fn default() -> Self {
Self {
inner: Default::default(),
notify: Default::default(),
}
}
}
impl<T> Default for QueueInner<T> {
fn default() -> Self {
Self {
jobs: Default::default(),
in_progress: Default::default(),
}
}
}
impl<T> WorkQueue<T> {
    /// Waits for and takes the job at the front of the queue.
    ///
    /// Returns `Some(handle)` as soon as a job is available. The job counts as in progress
    /// until the returned [`JobHandle`] is dropped. If the queue is empty while other jobs
    /// are still in progress, this waits, because more jobs may still be pushed. Returns
    /// `None` once the queue is empty and no job is in progress.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub async fn next_job(self: &Arc<Self>) -> Option<JobHandle<T>> {
        loop {
            let waiting = {
                let mut inner = self.inner.lock().expect("lock poisoned");
                if let Some(job) = inner.jobs.pop_front() {
                    inner.in_progress += 1;
                    return Some(JobHandle {
                        job,
                        queue: Arc::clone(self),
                    });
                }
                if inner.in_progress == 0 {
                    return None;
                }
                // Create the `Notified` future while still holding the lock. Every state
                // change happens under this lock, and `notify_waiters` wakes any `Notified`
                // that already exists (even one not yet polled). So no wakeup can slip in
                // between releasing the lock and awaiting below.
                self.notify.notified()
            };
            // The guard was dropped at the end of the block above. Holding a
            // `std::sync::MutexGuard` across an `.await` would make this future `!Send`.
            waiting.await;
        }
    }

    /// Appends a job to the back of the queue and wakes workers waiting in
    /// [`WorkQueue::next_job`].
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn push_job(&self, job: T) {
        let mut inner = self.inner.lock().expect("lock poisoned");
        inner.jobs.push_back(job);
        self.notify.notify_waiters();
    }

    /// Returns the number of jobs waiting in the queue (jobs in progress are not counted).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn num_jobs(&self) -> usize {
        self.inner.lock().expect("lock poisoned").jobs.len()
    }

    /// Converts the queue into a stream of job handles.
    ///
    /// The stream yields the same handles as repeated calls to [`WorkQueue::next_job`]
    /// and ends the first time that returns `None`.
    pub fn to_stream(self: Arc<Self>) -> impl Stream<Item = JobHandle<T>> {
        stream::unfold(self, |work_queue| async move {
            let next = work_queue.next_job().await;
            next.map(|handle| (handle, work_queue))
        })
    }

    /// Records that one in-progress job has finished; called when a [`JobHandle`] drops.
    fn complete_job(&self) {
        let mut inner = self.inner.lock().expect("lock poisoned");
        inner.in_progress -= 1;
        // A worker only waits while `jobs` is empty, and `push_job` already wakes it when
        // new work arrives. The only transition it still needs to hear about from here is
        // the last in-progress job finishing, so it can return `None`. Waking every waiter
        // on each completion would just make them re-lock and go back to sleep.
        if inner.in_progress == 0 {
            self.notify.notify_waiters();
        }
    }
}
/// A job taken from a [`WorkQueue`].
///
/// While this handle is alive the job counts as in progress. Dropping it marks the job as
/// complete in the originating queue.
#[derive(Debug)]
pub struct JobHandle<T> {
/// The job payload.
job: T,
/// The queue this job came from; its in-progress count is decremented when this handle drops.
queue: Arc<WorkQueue<T>>,
}
impl<T> JobHandle<T> {
    /// Borrows the job payload carried by this handle.
    pub fn inner(&self) -> &T {
        let Self { job, .. } = self;
        job
    }
}
impl<T> Drop for JobHandle<T> {
fn drop(&mut self) {
self.queue.complete_job();
}
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicU32, Ordering},
            Arc,
        },
        time::Duration,
    };

    use futures::{FutureExt, StreamExt};
    use tokio::sync::Notify;

    use super::WorkQueue;

    #[derive(Debug)]
    struct TestJob(u32);

    /// Pins the `Notify` semantics `WorkQueue::next_job` relies on: `notify_waiters` stores
    /// no permit for future waiters, but does wake a `Notified` that was created before
    /// the call, even if it had not been polled yet.
    #[test]
    fn notification_assumption_holds() {
        let not = Notify::new();
        assert!(not.notified().now_or_never().is_none());
        not.notify_waiters();
        assert!(not.notified().now_or_never().is_none());
        let waiter = not.notified();
        not.notify_waiters();
        assert!(waiter.now_or_never().is_some());
    }

    /// Consumes jobs, sleeping on every multiple of 5, and adds each job value to `sum`.
    async fn job_worker_simple(queue: Arc<WorkQueue<TestJob>>, sum: Arc<AtomicU32>) {
        while let Some(job) = queue.next_job().await {
            if job.inner().0 % 5 == 0 {
                tokio::time::sleep(Duration::from_millis(50)).await;
            }
            sum.fetch_add(job.inner().0, Ordering::SeqCst);
        }
    }

    /// Consumes jobs and, for a job of value k > 0, pushes two jobs of value k - 1 while
    /// still holding the original job's handle.
    async fn job_worker_binary(queue: Arc<WorkQueue<TestJob>>, sum: Arc<AtomicU32>) {
        while let Some(job) = queue.next_job().await {
            tokio::time::sleep(Duration::from_millis(10)).await;
            sum.fetch_add(job.inner().0, Ordering::SeqCst);
            if job.inner().0 > 0 {
                queue.push_job(TestJob(job.inner().0 - 1));
                queue.push_job(TestJob(job.inner().0 - 1));
            }
        }
    }

    #[tokio::test]
    async fn empty_queue_exits_immediately() {
        let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
        assert!(q.next_job().await.is_none());
    }

    #[tokio::test]
    async fn large_front_loaded_queue_terminates() {
        let num_jobs = 1_000;
        let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
        for job in (0..num_jobs).map(TestJob) {
            q.push_job(job);
        }
        let mut workers = Vec::new();
        let output = Arc::new(AtomicU32::new(0));
        for _ in 0..3 {
            workers.push(tokio::spawn(job_worker_simple(q.clone(), output.clone())));
        }
        for worker in workers {
            worker.await.expect("task panicked");
        }
        let expected_total = (num_jobs * (num_jobs - 1)) / 2;
        assert_eq!(output.load(Ordering::SeqCst), expected_total);
    }

    #[tokio::test]
    async fn stream_interface_works() {
        let num_jobs = 1_000;
        let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
        for job in (0..num_jobs).map(TestJob) {
            q.push_job(job);
        }
        let mut current = 0;
        let mut stream = Box::pin(q.to_stream());
        while let Some(job) = stream.next().await {
            assert_eq!(job.inner().0, current);
            current += 1;
        }
        // Without this check the test would also pass if the stream yielded nothing.
        assert_eq!(current, num_jobs);
    }

    #[tokio::test]
    async fn complex_queue_terminates() {
        let num_jobs = 5;
        let q: Arc<WorkQueue<TestJob>> = Arc::new(Default::default());
        for _ in 0..num_jobs {
            q.push_job(TestJob(num_jobs));
        }
        let mut workers = Vec::new();
        let output = Arc::new(AtomicU32::new(0));
        for _ in 0..3 {
            workers.push(tokio::spawn(job_worker_binary(q.clone(), output.clone())));
        }
        for worker in workers {
            worker.await.expect("task panicked");
        }
        // A job of value k contributes k plus the totals of two jobs of value k - 1:
        // S(k) = k + 2 * S(k - 1), S(0) = 0, so S(5) = 57, and five such jobs give 285.
        let expected_total = 285;
        assert_eq!(output.load(Ordering::SeqCst), expected_total);
    }
}