// workers 0.1.0
//
// A robust async PostgreSQL-backed background job processing system.
// Tests exercising the job runner against a containerized PostgreSQL database.
#![allow(missing_docs)]
#![allow(clippy::expect_used)]
#![allow(clippy::panic)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::indexing_slicing)]

use claims::{assert_none, assert_some};
use insta::assert_compact_json_snapshot;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{PgPool, Row};
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use testcontainers::ContainerAsync;
use testcontainers_modules::postgres::Postgres;
use tokio::sync::Barrier;
use workers::{BackgroundJob, Runner};

/// Shared fixtures used by the tests in this file.
mod test_utils {
    use super::*;
    use testcontainers::runners::AsyncRunner;

    /// Starts a PostgreSQL test container, connects a pool to it and applies
    /// the migrations found in `./migrations`.
    ///
    /// Returns the pool together with the container handle. Callers in this
    /// file keep the handle bound as `_container` for the whole test —
    /// presumably because dropping it stops the container.
    pub(super) async fn setup_test_db() -> anyhow::Result<(PgPool, ContainerAsync<Postgres>)> {
        let container = Postgres::default().start().await?;

        // Resolve the host and the host-side port mapped to the container's 5432.
        let db_host = container.get_host().await?;
        let db_port = container.get_host_port_ipv4(5432).await?;

        // Standard postgres/postgres credentials of the testcontainers image.
        let url = format!(
            "postgresql://postgres:postgres@{}:{}/postgres",
            db_host, db_port
        );
        let pool = PgPool::connect(&url).await?;

        // Bring the schema up to date before handing the pool to the test.
        sqlx::migrate!("./migrations").run(&pool).await?;

        Ok((pool, container))
    }

    /// Builds a [`Runner`] with the configuration most tests share: the
    /// default queue is set to two workers and the runner is asked to shut
    /// down once the queue is empty.
    ///
    /// Job types still have to be registered by the caller.
    pub(super) fn create_test_runner<Context: Clone + Send + Sync + 'static>(
        pool: PgPool,
        context: Context,
    ) -> Runner<Context> {
        Runner::new(pool, context)
            .configure_default_queue(|default_queue| default_queue.num_workers(2))
            .shutdown_when_queue_empty()
    }
}

/// Returns the `(job_type, data)` pair of every row in `background_jobs`,
/// ordered by `id`.
///
/// The explicit `ORDER BY` matters: without it PostgreSQL is free to return
/// rows in any order, while callers compare the result against ordered
/// snapshots.
///
/// # Errors
///
/// Returns an error if the query fails or if a column cannot be decoded into
/// the expected type.
async fn all_jobs(pool: &PgPool) -> anyhow::Result<Vec<(String, Value)>> {
    let rows = sqlx::query("SELECT job_type, data FROM background_jobs ORDER BY id")
        .fetch_all(pool)
        .await?;

    rows.into_iter()
        .map(|row| -> anyhow::Result<(String, Value)> {
            // `try_get` reports decode problems as errors instead of panicking.
            let job_type: String = row.try_get("job_type")?;
            let data: Value = row.try_get("data")?;
            Ok((job_type, data))
        })
        .collect()
}

/// Returns `true` if `background_jobs` contains a row with the given `id`.
async fn job_exists(id: i64, pool: &PgPool) -> anyhow::Result<bool> {
    const QUERY: &str = "SELECT id FROM background_jobs WHERE id = $1";

    let found = sqlx::query_scalar::<_, Option<i64>>(QUERY)
        .bind(id)
        .fetch_optional(pool)
        .await?
        .is_some();

    Ok(found)
}

/// Returns `true` if the job row with the given `id` cannot be selected
/// `FOR UPDATE SKIP LOCKED`.
///
/// The query yields no row both when the row is locked and when no row with
/// that `id` exists, so callers that care about the difference should check
/// [`job_exists`] as well.
async fn job_is_locked(id: i64, pool: &PgPool) -> anyhow::Result<bool> {
    const QUERY: &str = "SELECT id FROM background_jobs WHERE id = $1 FOR UPDATE SKIP LOCKED";

    let lockable_row = sqlx::query_scalar::<_, Option<i64>>(QUERY)
        .bind(id)
        .fetch_optional(pool)
        .await?;

    // No row came back: it was skipped (or absent).
    Ok(lockable_row.is_none())
}

/// While a job is being run its row must still exist and be locked; once the
/// job has finished the row must be gone.
#[tokio::test]
async fn jobs_are_locked_when_fetched() -> anyhow::Result<()> {
    /// Rendezvous points shared between the test body and the running job.
    #[derive(Clone)]
    struct TestContext {
        /// Passed once the job has started running.
        started: Arc<Barrier>,
        /// Passed once the test body has finished its mid-run assertions.
        checked: Arc<Barrier>,
    }

    #[derive(Serialize, Deserialize)]
    struct TestJob;

    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        type Context = TestContext;

        async fn run(&self, ctx: Self::Context) -> anyhow::Result<()> {
            // Announce that we are running, then stay parked until the test
            // body has inspected the row.
            ctx.started.wait().await;
            ctx.checked.wait().await;
            Ok(())
        }
    }

    // Each barrier has two parties: the job and the test body.
    let two_party_barrier = || Arc::new(Barrier::new(2));
    let context = TestContext {
        started: two_party_barrier(),
        checked: two_party_barrier(),
    };

    let (pool, _container) = test_utils::setup_test_db().await?;

    let runner = test_utils::create_test_runner(pool.clone(), context.clone())
        .register_job_type::<TestJob>();

    let job_id = assert_some!(TestJob.enqueue(&pool).await?);

    // Enqueued but not yet picked up: present and unlocked.
    let exists_before_start = job_exists(job_id, &pool).await?;
    assert!(exists_before_start);
    let locked_before_start = job_is_locked(job_id, &pool).await?;
    assert!(!locked_before_start);

    let handle = runner.start();
    context.started.wait().await;

    // The job is parked inside `run`: still present, now locked.
    let exists_while_running = job_exists(job_id, &pool).await?;
    assert!(exists_while_running);
    let locked_while_running = job_is_locked(job_id, &pool).await?;
    assert!(locked_while_running);

    // Let the job finish and wait for the runner to wind down.
    context.checked.wait().await;
    handle.wait_for_shutdown().await;

    let exists_after_run = job_exists(job_id, &pool).await?;
    assert!(!exists_after_run);

    Ok(())
}

/// A job whose `run` returns `Ok(())` must be removed from `background_jobs`.
#[tokio::test]
async fn jobs_are_deleted_when_successfully_run() -> anyhow::Result<()> {
    #[derive(Serialize, Deserialize)]
    struct TestJob;

    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        type Context = ();

        async fn run(&self, _ctx: Self::Context) -> anyhow::Result<()> {
            // Succeeds immediately.
            Ok(())
        }
    }

    /// Number of rows currently stored in `background_jobs`.
    async fn job_count(pool: &PgPool) -> anyhow::Result<i64> {
        let query = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM background_jobs");
        Ok(query.fetch_one(pool).await?)
    }

    let (pool, _container) = test_utils::setup_test_db().await?;

    let runner = test_utils::create_test_runner(pool.clone(), ()).register_job_type::<TestJob>();

    // Fresh database: nothing queued yet.
    let before_enqueue = job_count(&pool).await?;
    assert_eq!(before_enqueue, 0);

    TestJob.enqueue(&pool).await?;
    let after_enqueue = job_count(&pool).await?;
    assert_eq!(after_enqueue, 1);

    // Run until the runner shuts itself down, then the row must be gone.
    let handle = runner.start();
    handle.wait_for_shutdown().await;
    let after_run = job_count(&pool).await?;
    assert_eq!(after_run, 0);

    Ok(())
}

/// A failing job must never become visible as unlocked while its `retries`
/// column is still `0`.
///
/// The job panics right after signalling that it has started; the test body
/// then issues blocking `FOR UPDATE` queries and checks what they observe
/// once the row lock is released.
#[tokio::test]
async fn failed_jobs_do_not_release_lock_before_updating_retry_time() -> anyhow::Result<()> {
    // Shared between the test body and the job so the test can wait until
    // the job is actually running.
    #[derive(Clone)]
    struct TestContext {
        job_started_barrier: Arc<Barrier>,
    }

    #[derive(Serialize, Deserialize)]
    struct TestJob;

    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        type Context = TestContext;

        async fn run(&self, ctx: Self::Context) -> anyhow::Result<()> {
            // Let the test body know the job has started, then fail by panicking.
            ctx.job_started_barrier.wait().await;
            panic!();
        }
    }

    // Two parties: the job and the test body.
    let test_context = TestContext {
        job_started_barrier: Arc::new(Barrier::new(2)),
    };

    let (pool, _container) = test_utils::setup_test_db().await?;

    let runner = test_utils::create_test_runner(pool.clone(), test_context.clone())
        .register_job_type::<TestJob>();

    TestJob.enqueue(&pool).await?;

    let runner = runner.start();
    // Do not query before the job is running.
    test_context.job_started_barrier.wait().await;

    // `SKIP LOCKED` is intentionally omitted here, so we block until
    // the lock on the first job is released.
    // If there is any point where the row is unlocked, but the retry
    // count is not updated, we will get a row here.
    let available_jobs =
        sqlx::query_scalar::<_, i64>("SELECT id FROM background_jobs WHERE retries = 0 FOR UPDATE")
            .fetch_all(&pool)
            .await?;
    assert_eq!(available_jobs.len(), 0);

    // Sanity check to make sure the job actually is there
    let total_jobs_including_failed =
        sqlx::query_scalar::<_, i64>("SELECT id FROM background_jobs FOR UPDATE")
            .fetch_all(&pool)
            .await?;
    assert_eq!(total_jobs_including_failed.len(), 1);

    runner.wait_for_shutdown().await;

    Ok(())
}

/// A job that panics must stay in `background_jobs` with `retries` set to 1
/// after a single run of the runner.
#[tokio::test]
async fn panicking_in_jobs_updates_retry_counter() -> anyhow::Result<()> {
    #[derive(Serialize, Deserialize)]
    struct TestJob;

    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        type Context = ();

        async fn run(&self, _ctx: Self::Context) -> anyhow::Result<()> {
            // Always fails by panicking.
            panic!()
        }
    }

    const RETRIES_QUERY: &str = "SELECT retries FROM background_jobs WHERE id = $1 FOR UPDATE";

    let (pool, _container) = test_utils::setup_test_db().await?;

    let runner = test_utils::create_test_runner(pool.clone(), ()).register_job_type::<TestJob>();

    let job_id = assert_some!(TestJob.enqueue(&pool).await?);

    // Run the job once and wait for the runner to shut down.
    let handle = runner.start();
    handle.wait_for_shutdown().await;

    // The failed attempt must have been recorded on the row.
    let recorded_retries = sqlx::query_scalar::<_, i32>(RETRIES_QUERY)
        .bind(job_id)
        .fetch_one(&pool)
        .await?;
    assert_eq!(recorded_retries, 1);

    Ok(())
}

/// With `DEDUPLICATED = true`, enqueueing a job equal to one that is still
/// waiting in the queue yields `None` and adds no row, while a job equal to
/// one that is currently running, or one with different data, is enqueued.
#[tokio::test]
async fn jobs_can_be_deduplicated() -> anyhow::Result<()> {
    // Shared between the test body and the jobs.
    #[derive(Clone)]
    struct TestContext {
        // Number of times `run` has been entered so far.
        runs: Arc<AtomicU8>,
        job_started_barrier: Arc<Barrier>,
        assertions_finished_barrier: Arc<Barrier>,
    }

    #[derive(Serialize, Deserialize)]
    struct TestJob {
        value: String,
    }

    impl TestJob {
        fn new(value: impl Into<String>) -> Self {
            let value = value.into();
            Self { value }
        }
    }

    impl BackgroundJob for TestJob {
        const JOB_NAME: &'static str = "test";
        const DEDUPLICATED: bool = true;
        type Context = TestContext;

        async fn run(&self, ctx: Self::Context) -> anyhow::Result<()> {
            // `fetch_add` returns the previous value, so only the very first
            // execution parks on the barriers; later executions return at once.
            let runs = ctx.runs.fetch_add(1, Ordering::SeqCst);
            if runs == 0 {
                ctx.job_started_barrier.wait().await;
                ctx.assertions_finished_barrier.wait().await;
            }
            Ok(())
        }
    }

    // Each barrier has two parties: the first job execution and the test body.
    let test_context = TestContext {
        runs: Arc::new(AtomicU8::new(0)),
        job_started_barrier: Arc::new(Barrier::new(2)),
        assertions_finished_barrier: Arc::new(Barrier::new(2)),
    };

    let (pool, _container) = test_utils::setup_test_db().await?;

    // NOTE(review): the runner is built directly instead of through
    // `test_utils::create_test_runner`, so the default queue keeps its default
    // worker count — presumably so that jobs enqueued while the first one is
    // parked are not picked up before the snapshots below are taken; confirm.
    let runner = Runner::new(pool.clone(), test_context.clone())
        .register_job_type::<TestJob>()
        .shutdown_when_queue_empty();

    // Enqueue first job
    assert_some!(TestJob::new("foo").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}]]"#);

    // Try to enqueue the same job again, which should be deduplicated
    assert_none!(TestJob::new("foo").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}]]"#);

    // Start processing the first job
    let runner = runner.start();
    test_context.job_started_barrier.wait().await;

    // Enqueue the same job again, which should NOT be deduplicated,
    // since the first job is still running
    assert_some!(TestJob::new("foo").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#);

    // Try to enqueue the same job again, which should be deduplicated again
    assert_none!(TestJob::new("foo").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}]]"#);

    // Enqueue the same job but with different data, which should
    // NOT be deduplicated
    assert_some!(TestJob::new("bar").enqueue(&pool).await?);
    assert_compact_json_snapshot!(all_jobs(&pool).await?, @r#"[["test", {"value": "foo"}], ["test", {"value": "foo"}], ["test", {"value": "bar"}]]"#);

    // Resolve the final barrier to finish the test
    test_context.assertions_finished_barrier.wait().await;
    runner.wait_for_shutdown().await;

    Ok(())
}