#![allow(missing_docs)]
#![allow(clippy::expect_used)]
#![allow(clippy::panic)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::indexing_slicing)]
use claims::{assert_none, assert_some};
use insta::assert_compact_json_snapshot;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{PgPool, Row};
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use testcontainers::ContainerAsync;
use testcontainers_modules::postgres::Postgres;
use tokio::sync::Barrier;
use workers::{BackgroundJob, Runner};
mod test_utils {
    use super::*;
    use testcontainers::runners::AsyncRunner;

    /// Spin up a throwaway Postgres container, connect a pool to it, and run
    /// the migrations. The container handle is returned so the caller can
    /// keep it alive for the duration of the test.
    pub(super) async fn setup_test_db() -> anyhow::Result<(PgPool, ContainerAsync<Postgres>)> {
        let container = Postgres::default().start().await?;

        let url = format!(
            "postgresql://postgres:postgres@{}:{}/postgres",
            container.get_host().await?,
            container.get_host_port_ipv4(5432).await?,
        );

        let pool = PgPool::connect(&url).await?;
        sqlx::migrate!("./migrations").run(&pool).await?;

        Ok((pool, container))
    }

    /// Build a [`Runner`] whose default queue uses two workers and which
    /// shuts itself down once the queue is drained.
    pub(super) fn create_test_runner<Context: Clone + Send + Sync + 'static>(
        pool: PgPool,
        context: Context,
    ) -> Runner<Context> {
        Runner::new(pool, context)
            .configure_default_queue(|queue| queue.num_workers(2))
            .shutdown_when_queue_empty()
    }
}
/// Fetch every queued job as a `(job_type, data)` pair, in table order.
async fn all_jobs(pool: &PgPool) -> anyhow::Result<Vec<(String, Value)>> {
    let rows = sqlx::query("SELECT job_type, data FROM background_jobs")
        .fetch_all(pool)
        .await?;

    let mut jobs: Vec<(String, Value)> = Vec::with_capacity(rows.len());
    for row in &rows {
        jobs.push((row.get("job_type"), row.get("data")));
    }

    Ok(jobs)
}
/// Returns `true` if a row with the given `id` exists in `background_jobs`.
async fn job_exists(id: i64, pool: &PgPool) -> anyhow::Result<bool> {
    let found = sqlx::query_scalar::<_, Option<i64>>("SELECT id FROM background_jobs WHERE id = $1")
        .bind(id)
        .fetch_optional(pool)
        .await?
        .is_some();

    Ok(found)
}
/// Returns `true` if the job row is currently row-locked by another
/// transaction (i.e. a worker has claimed it).
///
/// `FOR UPDATE SKIP LOCKED` silently skips rows whose lock is held
/// elsewhere, so getting no row back for an existing id means the row is
/// locked. Note a *missing* row also yields `true` here — callers pair this
/// with `job_exists` to disambiguate.
async fn job_is_locked(id: i64, pool: &PgPool) -> anyhow::Result<bool> {
    let claimed = sqlx::query_scalar::<_, Option<i64>>(
        "SELECT id FROM background_jobs WHERE id = $1 FOR UPDATE SKIP LOCKED",
    )
    .bind(id)
    .fetch_optional(pool)
    .await?;

    Ok(claimed.is_none())
}
#[tokio::test]
async fn jobs_are_locked_when_fetched() -> anyhow::Result<()> {
    #[derive(Clone)]
    struct TestContext {
        job_started_barrier: Arc<Barrier>,
        assertions_finished_barrier: Arc<Barrier>,
    }

    #[derive(Serialize, Deserialize)]
    struct TestJob;

    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        type Context = TestContext;

        async fn run(&self, ctx: Self::Context) -> anyhow::Result<()> {
            // Signal the test that we are running, then park until the
            // test's lock assertions are finished.
            ctx.job_started_barrier.wait().await;
            ctx.assertions_finished_barrier.wait().await;
            Ok(())
        }
    }

    let context = TestContext {
        job_started_barrier: Arc::new(Barrier::new(2)),
        assertions_finished_barrier: Arc::new(Barrier::new(2)),
    };

    let (pool, _container) = test_utils::setup_test_db().await?;
    let runner = test_utils::create_test_runner(pool.clone(), context.clone())
        .register_job_type::<TestJob>();

    // Before the runner starts: the job row exists and is unlocked.
    let job_id = assert_some!(TestJob.enqueue(&pool).await?);
    assert!(job_exists(job_id, &pool).await?);
    assert!(!job_is_locked(job_id, &pool).await?);

    let runner = runner.start();
    context.job_started_barrier.wait().await;

    // While the worker is executing the job, its row must be locked.
    assert!(job_exists(job_id, &pool).await?);
    assert!(job_is_locked(job_id, &pool).await?);

    context.assertions_finished_barrier.wait().await;
    runner.wait_for_shutdown().await;

    // A successfully completed job is removed from the table.
    assert!(!job_exists(job_id, &pool).await?);
    Ok(())
}
#[tokio::test]
async fn jobs_are_deleted_when_successfully_run() -> anyhow::Result<()> {
    #[derive(Serialize, Deserialize)]
    struct TestJob;

    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        type Context = ();

        // A no-op job: succeeding is all this test needs.
        async fn run(&self, _ctx: Self::Context) -> anyhow::Result<()> {
            Ok(())
        }
    }

    /// Count the rows currently in the `background_jobs` table.
    async fn remaining_jobs(pool: &PgPool) -> anyhow::Result<i64> {
        Ok(
            sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM background_jobs")
                .fetch_one(pool)
                .await?,
        )
    }

    let (pool, _container) = test_utils::setup_test_db().await?;
    let runner = test_utils::create_test_runner(pool.clone(), ()).register_job_type::<TestJob>();

    // Empty queue → one job enqueued → runner drains it → empty again.
    assert_eq!(remaining_jobs(&pool).await?, 0);
    TestJob.enqueue(&pool).await?;
    assert_eq!(remaining_jobs(&pool).await?, 1);

    let runner = runner.start();
    runner.wait_for_shutdown().await;

    assert_eq!(remaining_jobs(&pool).await?, 0);
    Ok(())
}
#[tokio::test]
async fn failed_jobs_do_not_release_lock_before_updating_retry_time() -> anyhow::Result<()> {
    #[derive(Clone)]
    struct TestContext {
        job_started_barrier: Arc<Barrier>,
    }
    #[derive(Serialize, Deserialize)]
    struct TestJob;
    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        type Context = TestContext;
        async fn run(&self, ctx: Self::Context) -> anyhow::Result<()> {
            // Signal the test that the job has started, then fail via panic.
            ctx.job_started_barrier.wait().await;
            panic!();
        }
    }
    let test_context = TestContext {
        job_started_barrier: Arc::new(Barrier::new(2)),
    };
    let (pool, _container) = test_utils::setup_test_db().await?;
    let runner = test_utils::create_test_runner(pool.clone(), test_context.clone())
        .register_job_type::<TestJob>();
    TestJob.enqueue(&pool).await?;
    let runner = runner.start();
    test_context.job_started_barrier.wait().await;
    // NOTE: `SKIP LOCKED` is deliberately omitted here — this `FOR UPDATE`
    // query *blocks* until the worker releases its row lock. If there were
    // any window in which the row was unlocked but `retries` had not yet
    // been incremented, this query would observe a `retries = 0` row and
    // the assertion below would fail.
    let available_jobs =
        sqlx::query_scalar::<_, i64>("SELECT id FROM background_jobs WHERE retries = 0 FOR UPDATE")
            .fetch_all(&pool)
            .await?;
    assert_eq!(available_jobs.len(), 0);
    // Sanity check: the failed job was not deleted — its row is still
    // present (with a bumped retry counter, per the query above).
    let total_jobs_including_failed =
        sqlx::query_scalar::<_, i64>("SELECT id FROM background_jobs FOR UPDATE")
            .fetch_all(&pool)
            .await?;
    assert_eq!(total_jobs_including_failed.len(), 1);
    runner.wait_for_shutdown().await;
    Ok(())
}
#[tokio::test]
async fn panicking_in_jobs_updates_retry_counter() -> anyhow::Result<()> {
    #[derive(Serialize, Deserialize)]
    struct TestJob;

    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        type Context = ();

        // Always panics, simulating a crashing job.
        async fn run(&self, _ctx: Self::Context) -> anyhow::Result<()> {
            panic!()
        }
    }

    let (pool, _container) = test_utils::setup_test_db().await?;
    let runner = test_utils::create_test_runner(pool.clone(), ()).register_job_type::<TestJob>();
    let job_id = assert_some!(TestJob.enqueue(&pool).await?);

    runner.start().wait_for_shutdown().await;

    // The panic must have been caught and recorded as one failed attempt.
    let tries = sqlx::query_scalar::<_, i32>(
        "SELECT retries FROM background_jobs WHERE id = $1 FOR UPDATE",
    )
    .bind(job_id)
    .fetch_one(&pool)
    .await?;
    assert_eq!(tries, 1);
    Ok(())
}
/// With `DEDUPLICATED = true`, enqueueing a job whose payload matches an
/// existing *unstarted* job is a no-op (`enqueue` returns `None`); once a
/// worker has picked the job up, the same payload is accepted again.
#[tokio::test]
async fn jobs_can_be_deduplicated() -> anyhow::Result<()> {
    #[derive(Clone)]
    struct TestContext {
        // Counts how many times the job body has executed.
        runs: Arc<AtomicU8>,
        job_started_barrier: Arc<Barrier>,
        assertions_finished_barrier: Arc<Barrier>,
    }
    #[derive(Serialize, Deserialize)]
    struct TestJob {
        value: String,
    }
    impl TestJob {
        fn new(value: impl Into<String>) -> Self {
            let value = value.into();
            Self { value }
        }
    }
    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        // Opt this job type into payload-based deduplication.
        const DEDUPLICATED: bool = true;
        type Context = TestContext;
        async fn run(&self, ctx: Self::Context) -> anyhow::Result<()> {
            let runs = ctx.runs.fetch_add(1, Ordering::SeqCst);
            // Only the very first run blocks on the barriers; subsequent
            // runs (the re-enqueued duplicates) finish immediately so the
            // runner can drain the queue and shut down.
            if runs == 0 {
                ctx.job_started_barrier.wait().await;
                ctx.assertions_finished_barrier.wait().await;
            }
            Ok(())
        }
    }
    let test_context = TestContext {
        runs: Arc::new(AtomicU8::new(0)),
        job_started_barrier: Arc::new(Barrier::new(2)),
        assertions_finished_barrier: Arc::new(Barrier::new(2)),
    };
    let (pool, _container) = test_utils::setup_test_db().await?;
    // Built directly via `Runner::new` (default queue configuration) rather
    // than through `test_utils::create_test_runner`.
    let runner = Runner::new(pool.clone(), test_context.clone())
        .register_job_type::<TestJob>()
        .shutdown_when_queue_empty();
    // While the job is queued but unstarted, an identical payload is
    // deduplicated away.
    assert_some!(TestJob::new("foo").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}]]"#);
    assert_none!(TestJob::new("foo").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}]]"#);
    let runner = runner.start();
    test_context.job_started_barrier.wait().await;
    // The first "foo" job is now running, so the same payload no longer
    // counts as a duplicate and is accepted again…
    assert_some!(TestJob::new("foo").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#);
    // …but the freshly queued (unstarted) copy deduplicates further "foo"s.
    assert_none!(TestJob::new("foo").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#);
    // A different payload is never treated as a duplicate.
    assert_some!(TestJob::new("bar").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}], ["test", {"value": "bar"}]]"#);
    test_context.assertions_finished_barrier.wait().await;
    runner.wait_for_shutdown().await;
    Ok(())
}