use graphile_worker::{
HookRegistry, IntoTaskHandlerResult, JobComplete, JobFail, JobPermanentlyFail, JobSpec, Plugin,
TaskHandler, Worker, WorkerContext,
};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::task::spawn_local;
use tokio::time::{sleep, Instant};
use crate::helpers::with_test_db;
mod helpers;
/// Cheaply clonable counter backed by a shared atomic.
///
/// Every clone points at the same `AtomicU32`, so a copy registered as a
/// worker extension and the copy kept by the test observe the same value.
#[derive(Clone, Debug)]
struct CompletedCounter(Arc<AtomicU32>);

impl CompletedCounter {
    /// Creates a counter starting at zero.
    fn new() -> Self {
        // `AtomicU32::default()` is 0, matching an explicit `AtomicU32::new(0)`.
        Self(Arc::default())
    }

    /// Adds one to the shared count.
    fn increment(&self) {
        let Self(shared) = self;
        shared.fetch_add(1, Ordering::SeqCst);
    }

    /// Returns the current shared count.
    fn get(&self) -> u32 {
        let Self(shared) = self;
        shared.load(Ordering::SeqCst)
    }
}
/// Tallies of how often each lifecycle hook registered by
/// [`BatcherHooksPlugin`] has run.
#[derive(Debug, Default)]
struct BatcherHookCounters {
    /// Bumped by the `JobComplete` hook.
    complete: AtomicU32,
    /// Bumped by the `JobFail` hook.
    fail: AtomicU32,
    /// Bumped by the `JobPermanentlyFail` hook.
    permanent: AtomicU32,
}

/// Test plugin that counts lifecycle hook invocations.
///
/// Clones share the same [`BatcherHookCounters`] instance.
#[derive(Clone, Debug)]
struct BatcherHooksPlugin {
    counters: Arc<BatcherHookCounters>,
}

impl BatcherHooksPlugin {
    /// Builds a plugin whose tallies all start at zero.
    fn new() -> Self {
        let counters = Arc::default();
        Self { counters }
    }

    /// Returns another handle to the shared tallies, so a test can keep
    /// reading them after the plugin itself has been moved elsewhere.
    fn counters(&self) -> Arc<BatcherHookCounters> {
        Arc::clone(&self.counters)
    }
}
impl Plugin for BatcherHooksPlugin {
fn register(self, hooks: &mut HookRegistry) {
let counters = self.counters.clone();
hooks.on(JobComplete, move |_ctx| {
let counters = counters.clone();
async move {
counters.complete.fetch_add(1, Ordering::SeqCst);
}
});
let counters = self.counters.clone();
hooks.on(JobFail, move |_ctx| {
let counters = counters.clone();
async move {
counters.fail.fetch_add(1, Ordering::SeqCst);
}
});
let counters = self.counters.clone();
hooks.on(JobPermanentlyFail, move |_ctx| {
let counters = counters.clone();
async move {
counters.permanent.fetch_add(1, Ordering::SeqCst);
}
});
}
}
/// Job payload whose handler always succeeds.
#[derive(Serialize, Deserialize)]
struct SuccessJob {
    id: u32,
}

impl TaskHandler for SuccessJob {
    const IDENTIFIER: &'static str = "success_job";

    /// Bumps the `CompletedCounter` extension when one is registered on the
    /// worker, then reports success.
    async fn run(self, ctx: WorkerContext) -> impl IntoTaskHandlerResult {
        let registered = ctx.get_ext::<CompletedCounter>();
        if let Some(counter) = registered {
            counter.increment();
        }

        let outcome: Result<(), String> = Ok(());
        outcome
    }
}
/// Job payload whose handler always fails.
#[derive(Serialize, Deserialize)]
struct FailJob {
    id: u32,
}

impl TaskHandler for FailJob {
    const IDENTIFIER: &'static str = "fail_job";

    /// Bumps the `CompletedCounter` extension when one is registered on the
    /// worker (so the counter tracks attempts), then returns an error that
    /// names the job id.
    async fn run(self, ctx: WorkerContext) -> impl IntoTaskHandlerResult {
        let registered = ctx.get_ext::<CompletedCounter>();
        if let Some(counter) = registered {
            counter.increment();
        }

        let reason = format!("Job {} failed", self.id);
        Err::<(), String>(reason)
    }
}
#[tokio::test]
async fn test_batchers_emit_lifecycle_hooks() {
with_test_db(|test_db| async move {
let utils = test_db.worker_utils();
utils.migrate().await.expect("Failed to migrate");
let attempts = CompletedCounter::new();
let plugin = BatcherHooksPlugin::new();
let hook_counters = plugin.counters();
let worker = Arc::new(
Worker::options()
.database(test_db.database.clone())
.concurrency(4)
.poll_interval(Duration::from_millis(50))
.complete_job_batch_delay(Duration::from_millis(10))
.fail_job_batch_delay(Duration::from_millis(10))
.add_extension(attempts.clone())
.add_plugin(plugin)
.define_job::<SuccessJob>()
.define_job::<FailJob>()
.init()
.await
.expect("Failed to create worker"),
);
let worker_clone = worker.clone();
let worker_handle = spawn_local(async move {
worker_clone.run().await.expect("Failed to run worker");
});
utils
.add_job(SuccessJob { id: 1 }, JobSpec::default())
.await
.expect("Failed to add success job");
utils
.add_job(
SuccessJob { id: 2 },
JobSpec::builder().queue_name("hooked_complete_queue").build(),
)
.await
.expect("Failed to add queued success job");
utils
.add_job(FailJob { id: 1 }, JobSpec::builder().max_attempts(3).build())
.await
.expect("Failed to add retryable fail job");
utils
.add_job(
FailJob { id: 2 },
JobSpec::builder()
.queue_name("hooked_fail_queue")
.max_attempts(1)
.build(),
)
.await
.expect("Failed to add permanent fail job");
let start = Instant::now();
while hook_counters.complete.load(Ordering::SeqCst) < 2
|| hook_counters.fail.load(Ordering::SeqCst) < 1
|| hook_counters.permanent.load(Ordering::SeqCst) < 1
{
if start.elapsed() > Duration::from_secs(5) {
panic!(
"Batch hooks should have fired by now, complete: {}, fail: {}, permanent: {}, attempts: {}",
hook_counters.complete.load(Ordering::SeqCst),
hook_counters.fail.load(Ordering::SeqCst),
hook_counters.permanent.load(Ordering::SeqCst),
attempts.get()
);
}
sleep(Duration::from_millis(50)).await;
}
assert_eq!(hook_counters.complete.load(Ordering::SeqCst), 2);
assert_eq!(hook_counters.fail.load(Ordering::SeqCst), 1);
assert_eq!(hook_counters.permanent.load(Ordering::SeqCst), 1);
worker.request_shutdown();
let _ = worker_handle.await;
})
.await;
}
#[tokio::test]
async fn test_complete_job_batch_delay() {
with_test_db(|test_db| async move {
let utils = test_db.worker_utils();
utils.migrate().await.expect("Failed to migrate");
let counter = CompletedCounter::new();
let worker = Arc::new(
Worker::options()
.database(test_db.database.clone())
.concurrency(4)
.poll_interval(Duration::from_millis(50))
.complete_job_batch_delay(Duration::from_millis(10))
.add_extension(counter.clone())
.define_job::<SuccessJob>()
.init()
.await
.expect("Failed to create worker"),
);
let worker_clone = worker.clone();
let worker_handle = spawn_local(async move {
worker_clone.run().await.expect("Failed to run worker");
});
for i in 0..10 {
utils
.add_job(SuccessJob { id: i }, JobSpec::default())
.await
.expect("Failed to add job");
}
let start = Instant::now();
while counter.get() < 10 {
if start.elapsed() > Duration::from_secs(5) {
panic!(
"Jobs should have completed by now, only {} completed",
counter.get()
);
}
sleep(Duration::from_millis(50)).await;
}
assert_eq!(counter.get(), 10, "All 10 jobs should have completed");
worker.request_shutdown();
let _ = worker_handle.await;
})
.await;
}
#[tokio::test]
async fn test_fail_job_batch_delay() {
with_test_db(|test_db| async move {
let utils = test_db.worker_utils();
utils.migrate().await.expect("Failed to migrate");
let counter = CompletedCounter::new();
let worker = Arc::new(
Worker::options()
.database(test_db.database.clone())
.concurrency(4)
.poll_interval(Duration::from_millis(50))
.fail_job_batch_delay(Duration::from_millis(10))
.add_extension(counter.clone())
.define_job::<FailJob>()
.init()
.await
.expect("Failed to create worker"),
);
let worker_clone = worker.clone();
let worker_handle = spawn_local(async move {
worker_clone.run().await.expect("Failed to run worker");
});
for i in 0..5 {
utils
.add_job(
FailJob { id: i },
JobSpec::builder().max_attempts(1).build(), )
.await
.expect("Failed to add job");
}
let start = Instant::now();
while counter.get() < 5 {
if start.elapsed() > Duration::from_secs(5) {
panic!(
"Jobs should have been attempted by now, only {} attempted",
counter.get()
);
}
sleep(Duration::from_millis(50)).await;
}
assert_eq!(counter.get(), 5, "All 5 jobs should have been attempted");
worker.request_shutdown();
let _ = worker_handle.await;
})
.await;
}
/// Interleaves five succeeding and five single-attempt failing jobs on a
/// worker that has both batch delays configured, and checks that each handler
/// ran exactly five times.
///
/// Two newtype wrappers keep the success and failure counters apart when they
/// are looked up as worker extensions by type.
#[tokio::test]
async fn test_both_batchers_together() {
    const JOBS_PER_KIND: u32 = 5;

    with_test_db(|test_db| async move {
        let utils = test_db.worker_utils();
        utils.migrate().await.expect("Failed to migrate");

        let success_counter = CompletedCounter::new();
        let fail_counter = CompletedCounter::new();

        /// Extension wrapper around the counter for succeeding jobs.
        #[derive(Clone, Debug)]
        struct SuccessCounter(CompletedCounter);

        /// Extension wrapper around the counter for failing jobs.
        #[derive(Clone, Debug)]
        struct FailCounter(CompletedCounter);

        #[derive(Serialize, Deserialize)]
        struct MixedSuccessJob {
            id: u32,
        }

        impl TaskHandler for MixedSuccessJob {
            const IDENTIFIER: &'static str = "mixed_success_job";

            async fn run(self, ctx: WorkerContext) -> impl IntoTaskHandlerResult {
                let registered = ctx.get_ext::<SuccessCounter>();
                if let Some(SuccessCounter(inner)) = registered {
                    inner.increment();
                }
                let outcome: Result<(), String> = Ok(());
                outcome
            }
        }

        #[derive(Serialize, Deserialize)]
        struct MixedFailJob {
            id: u32,
        }

        impl TaskHandler for MixedFailJob {
            const IDENTIFIER: &'static str = "mixed_fail_job";

            async fn run(self, ctx: WorkerContext) -> impl IntoTaskHandlerResult {
                let registered = ctx.get_ext::<FailCounter>();
                if let Some(FailCounter(inner)) = registered {
                    inner.increment();
                }
                let reason = format!("Job {} failed", self.id);
                Err::<(), String>(reason)
            }
        }

        let built = Worker::options()
            .database(test_db.database.clone())
            .concurrency(8)
            .poll_interval(Duration::from_millis(50))
            .complete_job_batch_delay(Duration::from_millis(10))
            .fail_job_batch_delay(Duration::from_millis(10))
            .add_extension(SuccessCounter(success_counter.clone()))
            .add_extension(FailCounter(fail_counter.clone()))
            .define_job::<MixedSuccessJob>()
            .define_job::<MixedFailJob>()
            .init()
            .await
            .expect("Failed to create worker");
        let worker = Arc::new(built);

        let runner = Arc::clone(&worker);
        let worker_handle = spawn_local(async move {
            runner.run().await.expect("Failed to run worker");
        });

        // Alternate success / failure so both kinds are queued side by side.
        for job_id in 0..JOBS_PER_KIND {
            utils
                .add_job(MixedSuccessJob { id: job_id }, JobSpec::default())
                .await
                .expect("Failed to add job");

            let single_attempt = JobSpec::builder().max_attempts(1).build();
            utils
                .add_job(MixedFailJob { id: job_id }, single_attempt)
                .await
                .expect("Failed to add job");
        }

        let timeout = Duration::from_secs(5);
        let started = Instant::now();
        loop {
            let all_ran =
                success_counter.get() >= JOBS_PER_KIND && fail_counter.get() >= JOBS_PER_KIND;
            if all_ran {
                break;
            }
            if started.elapsed() > timeout {
                panic!(
                    "Jobs should have completed by now, success: {}, fail: {}",
                    success_counter.get(),
                    fail_counter.get()
                );
            }
            sleep(Duration::from_millis(50)).await;
        }

        assert_eq!(success_counter.get(), JOBS_PER_KIND);
        assert_eq!(fail_counter.get(), JOBS_PER_KIND);

        worker.request_shutdown();
        let _ = worker_handle.await;
    })
    .await;
}
#[tokio::test]
async fn test_shutdown_flushes_pending_completions() {
with_test_db(|test_db| async move {
let utils = test_db.worker_utils();
utils.migrate().await.expect("Failed to migrate");
let counter = CompletedCounter::new();
let worker = Arc::new(
Worker::options()
.database(test_db.database.clone())
.concurrency(4)
.poll_interval(Duration::from_millis(50))
.complete_job_batch_delay(Duration::from_millis(100))
.add_extension(counter.clone())
.define_job::<SuccessJob>()
.init()
.await
.expect("Failed to create worker"),
);
let worker_clone = worker.clone();
let worker_handle = spawn_local(async move {
worker_clone.run().await.expect("Failed to run worker");
});
for i in 0..3 {
utils
.add_job(SuccessJob { id: i }, JobSpec::default())
.await
.expect("Failed to add job");
}
let start = Instant::now();
while counter.get() < 3 {
if start.elapsed() > Duration::from_secs(5) {
panic!("Jobs should have run by now");
}
sleep(Duration::from_millis(50)).await;
}
worker.request_shutdown();
let _ = worker_handle.await;
let remaining_jobs: (i64,) =
sqlx::query_as("SELECT COUNT(*) FROM graphile_worker._private_jobs")
.fetch_one(&test_db.test_pool)
.await
.expect("Failed to count jobs");
assert_eq!(
remaining_jobs.0, 0,
"All jobs should have been completed and removed from the database"
);
})
.await;
}
#[tokio::test]
async fn test_retryable_failures_processed_individually() {
with_test_db(|test_db| async move {
let utils = test_db.worker_utils();
utils.migrate().await.expect("Failed to migrate");
let counter = CompletedCounter::new();
let worker = Arc::new(
Worker::options()
.database(test_db.database.clone())
.concurrency(4)
.poll_interval(Duration::from_millis(50))
.fail_job_batch_delay(Duration::from_millis(10))
.add_extension(counter.clone())
.define_job::<FailJob>()
.init()
.await
.expect("Failed to create worker"),
);
let worker_clone = worker.clone();
let worker_handle = spawn_local(async move {
worker_clone.run().await.expect("Failed to run worker");
});
utils
.add_job(
FailJob { id: 1 },
JobSpec::builder().max_attempts(3).build(), )
.await
.expect("Failed to add job");
let start = Instant::now();
while counter.get() < 1 {
if start.elapsed() > Duration::from_secs(5) {
panic!("Job should have been attempted by now");
}
sleep(Duration::from_millis(50)).await;
}
sleep(Duration::from_millis(200)).await;
let job: Option<(i64, i16, i16)> = sqlx::query_as(
"SELECT id, attempts, max_attempts FROM graphile_worker._private_jobs LIMIT 1",
)
.fetch_optional(&test_db.test_pool)
.await
.expect("Failed to query job");
assert!(job.is_some(), "Job should still exist for retry");
let (_, attempts, max_attempts) = job.unwrap();
assert_eq!(attempts, 1, "Job should have 1 attempt");
assert_eq!(max_attempts, 3, "Job should have max_attempts=3");
worker.request_shutdown();
let _ = worker_handle.await;
})
.await;
}
#[tokio::test]
async fn test_permanent_failure_unlocks_job_and_queue() {
with_test_db(|test_db| async move {
let utils = test_db.worker_utils();
utils.migrate().await.expect("Failed to migrate");
let counter = CompletedCounter::new();
let worker = Arc::new(
Worker::options()
.database(test_db.database.clone())
.concurrency(4)
.poll_interval(Duration::from_millis(50))
.fail_job_batch_delay(Duration::from_millis(10))
.add_extension(counter.clone())
.define_job::<FailJob>()
.init()
.await
.expect("Failed to create worker"),
);
let worker_clone = worker.clone();
let worker_handle = spawn_local(async move {
worker_clone.run().await.expect("Failed to run worker");
});
utils
.add_job(
FailJob { id: 1 },
JobSpec::builder()
.max_attempts(1)
.queue_name("test_queue")
.build(),
)
.await
.expect("Failed to add job");
let start = Instant::now();
while counter.get() < 1 {
if start.elapsed() > Duration::from_secs(5) {
panic!("Job should have been attempted by now");
}
sleep(Duration::from_millis(50)).await;
}
let start = Instant::now();
let (job, queue) = loop {
let job: Option<(i64, Option<String>, Option<chrono::DateTime<chrono::Utc>>)> =
sqlx::query_as(
"SELECT id, locked_by, locked_at FROM graphile_worker._private_jobs LIMIT 1",
)
.fetch_optional(&test_db.test_pool)
.await
.expect("Failed to query job");
let queue: Option<(i32, Option<String>, Option<chrono::DateTime<chrono::Utc>>)> =
sqlx::query_as(
"SELECT id, locked_by, locked_at FROM graphile_worker._private_job_queues WHERE queue_name = 'test_queue' LIMIT 1",
)
.fetch_optional(&test_db.test_pool)
.await
.expect("Failed to query queue");
let job_unlocked = job
.as_ref()
.is_some_and(|(_, locked_by, locked_at)| locked_by.is_none() && locked_at.is_none());
let queue_unlocked = queue.as_ref().is_some_and(
|(_, queue_locked_by, queue_locked_at)| {
queue_locked_by.is_none() && queue_locked_at.is_none()
},
);
if job_unlocked && queue_unlocked {
break (job.unwrap(), queue.unwrap());
}
if start.elapsed() > Duration::from_secs(5) {
panic!(
"Job and queue should have been unlocked after permanent failure. job={job:?}, queue={queue:?}"
);
}
sleep(Duration::from_millis(50)).await;
};
let (_, locked_by, locked_at) = job;
assert!(
locked_by.is_none(),
"Job locked_by should be NULL after permanent failure"
);
assert!(
locked_at.is_none(),
"Job locked_at should be NULL after permanent failure"
);
let (_, queue_locked_by, queue_locked_at) = queue;
assert!(
queue_locked_by.is_none(),
"Queue locked_by should be NULL after permanent failure"
);
assert!(
queue_locked_at.is_none(),
"Queue locked_at should be NULL after permanent failure"
);
worker.request_shutdown();
let _ = worker_handle.await;
})
.await;
}