use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use serde_json::Value;
use sqlx::types::chrono::{DateTime, Utc};
use sqlx::{PgPool, Row};
use tokio::sync::{Mutex, Notify};
use tokio::task::JoinHandle;
use super::{DeadLetterFn, HandlerRegistry, Job, JobDeadLetter, JobError, JobQueue};
/// PostgreSQL-backed implementation of [`JobQueue`].
///
/// Jobs are persisted in the `rustango_jobs` table (created by
/// [`PgJobQueue::ensure_table`]) and consumed by a pool of tokio worker
/// tasks spawned from `start`.
pub struct PgJobQueue {
    // Connection pool used for all queue SQL.
    pool: PgPool,
    // Maps job names to handlers; shared with every worker task.
    registry: Arc<Mutex<HandlerRegistry>>,
    // Handles of spawned worker tasks; drained during shutdown.
    workers: Mutex<Vec<JoinHandle<()>>>,
    // Number of worker tasks `start` spawns.
    worker_count: usize,
    // Optional callback invoked when a job is dead-lettered.
    dead_letter: Arc<Mutex<Option<DeadLetterFn>>>,
    // How long an idle worker waits before re-polling the table.
    poll_interval: Duration,
    // Set to true by `shutdown` to stop the worker loops.
    shutdown: Arc<AtomicBool>,
    // Wakes idle workers on dispatch and on shutdown.
    notify: Arc<Notify>,
    // "host:<hostname>:pid:<pid>" prefix used to build per-worker lock ids.
    worker_id_prefix: String,
}
impl PgJobQueue {
    /// Builds a queue on `pool` that will spawn `worker_count` workers.
    ///
    /// The worker-id prefix embeds the hostname (from `$HOSTNAME`, when set)
    /// and the process id so row locks are attributable across processes.
    #[must_use]
    pub fn with_workers(pool: PgPool, worker_count: usize) -> Self {
        let id_prefix = format!(
            "host:{}:pid:{}",
            hostname().unwrap_or_else(|| "unknown".into()),
            std::process::id()
        );
        Self {
            pool,
            registry: Arc::new(Mutex::new(HandlerRegistry::default())),
            workers: Mutex::new(Vec::new()),
            worker_count,
            dead_letter: Arc::new(Mutex::new(None)),
            poll_interval: Duration::from_secs(1),
            shutdown: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
            worker_id_prefix: id_prefix,
        }
    }

    /// Builds a queue with the default worker count (4).
    #[must_use]
    pub fn new(pool: PgPool) -> Self {
        Self::with_workers(pool, 4)
    }

    /// Overrides the idle polling interval (builder style).
    #[must_use]
    pub fn poll_interval(mut self, d: Duration) -> Self {
        self.poll_interval = d;
        self
    }

    /// Registers a callback invoked whenever a job is dead-lettered.
    /// Replaces any previously configured callback.
    pub async fn on_dead_letter<F, Fut>(&self, callback: F)
    where
        F: Fn(JobDeadLetter) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        let boxed: DeadLetterFn = Arc::new(move |dl| Box::pin(callback(dl)));
        *self.dead_letter.lock().await = Some(boxed);
    }

    /// Creates the `rustango_jobs` table and its pickup index if missing.
    ///
    /// # Errors
    /// Returns any [`sqlx::Error`] raised while executing the DDL.
    pub async fn ensure_table(pool: &PgPool) -> Result<(), sqlx::Error> {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS rustango_jobs (
                id BIGSERIAL PRIMARY KEY,
                name TEXT NOT NULL,
                payload JSONB NOT NULL,
                attempt INTEGER NOT NULL DEFAULT 0,
                max_attempts INTEGER NOT NULL,
                run_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                locked_at TIMESTAMPTZ,
                locked_by TEXT,
                last_error TEXT,
                created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
            )",
        )
        .execute(pool)
        .await?;
        // Partial index matching the pickup query's predicate.
        sqlx::query(
            "CREATE INDEX IF NOT EXISTS rustango_jobs_pickup_idx
             ON rustango_jobs (run_at)
             WHERE locked_at IS NULL",
        )
        .execute(pool)
        .await?;
        Ok(())
    }

    /// Unlocks jobs whose lock is older than `older_than`, making them
    /// eligible for pickup again (e.g. after a worker crashed mid-job).
    /// Returns the number of rows reclaimed.
    ///
    /// # Errors
    /// Returns any [`sqlx::Error`] raised by the UPDATE.
    pub async fn reclaim_stuck_jobs(
        pool: &PgPool,
        older_than: Duration,
    ) -> Result<u64, sqlx::Error> {
        // `as_secs_f64` keeps sub-second precision; the previous
        // `as_secs() as f64` truncated to whole seconds, so the `.3`
        // formatting below always produced ".000".
        let secs = older_than.as_secs_f64();
        let res = sqlx::query(
            "UPDATE rustango_jobs
             SET locked_at = NULL, locked_by = NULL
             WHERE locked_at IS NOT NULL
               AND locked_at < NOW() - ($1 || ' seconds')::INTERVAL",
        )
        .bind(format!("{secs:.3}"))
        .execute(pool)
        .await?;
        Ok(res.rows_affected())
    }
}
#[async_trait::async_trait]
impl JobQueue for PgJobQueue {
    /// Registers the handler for job type `T` under `T::NAME`.
    async fn register<T: Job>(&self) {
        self.registry.lock().await.register::<T>();
    }

    /// Serializes `payload`, inserts a job row, and wakes one idle worker.
    ///
    /// # Errors
    /// Returns [`JobError::Queue`] if serialization or the INSERT fails.
    async fn dispatch<T: Job>(&self, payload: &T) -> Result<(), JobError> {
        let value = serde_json::to_value(payload).map_err(|e| JobError::Queue(e.to_string()))?;
        sqlx::query(
            "INSERT INTO rustango_jobs (name, payload, max_attempts)
             VALUES ($1, $2, $3)",
        )
        .bind(T::NAME)
        .bind(&value)
        // Clamp rather than fail if MAX_ATTEMPTS does not fit in i32.
        .bind(i32::try_from(T::MAX_ATTEMPTS).unwrap_or(i32::MAX))
        .execute(&self.pool)
        .await
        .map_err(|e| JobError::Queue(e.to_string()))?;
        self.notify.notify_one();
        Ok(())
    }

    /// Spawns the worker tasks. Calling `start` again while workers are
    /// already running is a no-op.
    async fn start(&self) {
        let mut workers = self.workers.lock().await;
        if !workers.is_empty() {
            return;
        }
        for n in 0..self.worker_count {
            let pool = self.pool.clone();
            let registry = self.registry.clone();
            let dead_letter = self.dead_letter.clone();
            let shutdown = self.shutdown.clone();
            let notify = self.notify.clone();
            let poll = self.poll_interval;
            let worker_id = format!("{}:w{}", self.worker_id_prefix, n);
            let h = tokio::spawn(async move {
                worker_loop(
                    pool,
                    registry,
                    dead_letter,
                    shutdown,
                    notify,
                    poll,
                    worker_id,
                )
                .await;
            });
            workers.push(h);
        }
    }

    /// Signals shutdown, wakes idle workers, and waits up to 5s per worker.
    ///
    /// A worker that misses the deadline is aborted. (Previously the timed-out
    /// `JoinHandle` was consumed by `timeout` and dropped, leaving the task
    /// running detached forever.)
    async fn shutdown(&self) {
        self.shutdown.store(true, Ordering::SeqCst);
        self.notify.notify_waiters();
        let mut workers = self.workers.lock().await;
        for mut h in workers.drain(..) {
            // Poll through `&mut h` so we keep ownership and can abort on
            // timeout; `JoinHandle` is `Unpin`, so `&mut h` is a `Future`.
            if tokio::time::timeout(Duration::from_secs(5), &mut h)
                .await
                .is_err()
            {
                h.abort();
            }
        }
    }

    /// Counts unlocked rows. Best effort: returns 0 on any query error.
    async fn pending_count(&self) -> usize {
        sqlx::query("SELECT COUNT(*)::BIGINT AS n FROM rustango_jobs WHERE locked_at IS NULL")
            .fetch_one(&self.pool)
            .await
            .ok()
            .and_then(|row| row.try_get::<i64, _>("n").ok())
            .map_or(0, |n| usize::try_from(n).unwrap_or(0))
    }
}
#[allow(clippy::too_many_arguments)]
async fn worker_loop(
pool: PgPool,
registry: Arc<Mutex<HandlerRegistry>>,
dead_letter: Arc<Mutex<Option<DeadLetterFn>>>,
shutdown: Arc<AtomicBool>,
notify: Arc<Notify>,
poll_interval: Duration,
worker_id: String,
) {
while !shutdown.load(Ordering::SeqCst) {
match pick_one(&pool, &worker_id).await {
Ok(Some(row)) => {
run_one(&pool, ®istry, &dead_letter, row).await;
}
Ok(None) => {
tokio::select! {
() = tokio::time::sleep(poll_interval) => {}
() = notify.notified() => {}
}
}
Err(e) => {
tracing::error!(error = %e, "PgJobQueue pickup failed");
tokio::time::sleep(poll_interval).await;
}
}
}
}
/// A job row claimed from `rustango_jobs` by [`pick_one`].
#[derive(Debug)]
struct PickedJob {
    // Primary key (`id` column).
    id: i64,
    // Registered job name used to look up the handler.
    name: String,
    // JSON payload passed to the handler.
    payload: Value,
    // Attempts completed so far (value before this run).
    attempt: i32,
    // Attempt ceiling copied from the job type at dispatch time.
    max_attempts: i32,
}
/// Atomically claims one due, unlocked job for `worker_id`.
///
/// Uses the `SELECT … FOR UPDATE SKIP LOCKED` CTE pattern so concurrent
/// workers neither block each other nor claim the same row; the outer
/// UPDATE stamps `locked_at`/`locked_by` in the same statement. Returns
/// `Ok(None)` when no job is currently due.
async fn pick_one(pool: &PgPool, worker_id: &str) -> Result<Option<PickedJob>, sqlx::Error> {
    let row = sqlx::query(
        "WITH next AS (
            SELECT id FROM rustango_jobs
            WHERE locked_at IS NULL AND run_at <= NOW()
            ORDER BY run_at, id
            FOR UPDATE SKIP LOCKED
            LIMIT 1
        )
        UPDATE rustango_jobs
        SET locked_at = NOW(), locked_by = $1
        WHERE id IN (SELECT id FROM next)
        RETURNING id, name, payload, attempt, max_attempts",
    )
    .bind(worker_id)
    .fetch_optional(pool)
    .await?;
    let Some(row) = row else { return Ok(None) };
    Ok(Some(PickedJob {
        id: row.try_get("id")?,
        name: row.try_get("name")?,
        payload: row.try_get("payload")?,
        attempt: row.try_get("attempt")?,
        max_attempts: row.try_get("max_attempts")?,
    }))
}
/// Executes a claimed job and persists the outcome:
///
/// - `Ok` deletes the row.
/// - `Retryable` re-schedules with exponential backoff until `max_attempts`
///   is reached, then dead-letters.
/// - `Fatal` / `Queue` errors dead-letter immediately.
async fn run_one(
    pool: &PgPool,
    registry: &Arc<Mutex<HandlerRegistry>>,
    dead_letter: &Arc<Mutex<Option<DeadLetterFn>>>,
    job: PickedJob,
) {
    // Clone the handler out of the registry so the lock is not held while
    // the job runs.
    let handler = registry.lock().await.lookup_owned(&job.name);
    let Some((handler, static_name)) = handler else {
        // No handler: the row stays locked so it is not re-claimed in a hot
        // loop; `reclaim_stuck_jobs` can make it visible again later.
        tracing::warn!(job = %job.name, id = job.id, "no handler registered — leaving locked");
        return;
    };
    let result = handler(job.payload.clone()).await;
    match result {
        Ok(()) => {
            // Success: remove the row permanently.
            let _ = sqlx::query("DELETE FROM rustango_jobs WHERE id = $1")
                .bind(job.id)
                .execute(pool)
                .await;
        }
        Err(JobError::Retryable(msg)) => {
            let next_attempt = job.attempt + 1;
            if next_attempt >= job.max_attempts {
                handle_dead_letter(pool, dead_letter, &job, static_name, &msg).await;
            } else {
                // Exponential backoff: 2^attempt seconds, shift capped at 10
                // (~17 min max delay); `saturating_mul` guards overflow.
                let backoff_ms = 1000u64.saturating_mul(1u64 << (next_attempt as u32).min(10));
                // Bump the attempt counter, push run_at into the future, and
                // release the lock so any worker can retry it.
                let _ = sqlx::query(
                    "UPDATE rustango_jobs
                     SET attempt = $1,
                         run_at = NOW() + ($2 || ' milliseconds')::INTERVAL,
                         locked_at = NULL,
                         locked_by = NULL,
                         last_error = $3
                     WHERE id = $4",
                )
                .bind(next_attempt)
                .bind(backoff_ms.to_string())
                .bind(&msg)
                .bind(job.id)
                .execute(pool)
                .await;
            }
        }
        Err(e @ (JobError::Fatal(_) | JobError::Queue(_))) => {
            // Non-retryable: dead-letter straight away.
            let msg = e.to_string();
            handle_dead_letter(pool, dead_letter, &job, static_name, &msg).await;
        }
    }
}
/// Delivers a failed job to the configured dead-letter callback — or logs
/// it when none is set — then deletes the row permanently.
async fn handle_dead_letter(
    pool: &PgPool,
    dead_letter: &Arc<Mutex<Option<DeadLetterFn>>>,
    job: &PickedJob,
    static_name: &'static str,
    error: &str,
) {
    // Clone the callback out of the mutex so the guard is not held across
    // the callback's await point.
    let callback = dead_letter.lock().await.clone();
    match callback {
        Some(cb) => {
            let letter = JobDeadLetter {
                name: static_name,
                payload: job.payload.clone(),
                attempts: u32::try_from(job.attempt + 1).unwrap_or(0),
                error: error.to_owned(),
            };
            cb(letter).await;
        }
        None => {
            tracing::error!(
                job = static_name,
                attempts = job.attempt + 1,
                error,
                "PgJobQueue dead-letter (no callback configured)"
            );
        }
    }
    // The row is removed regardless of whether a callback was configured.
    let _ = sqlx::query("DELETE FROM rustango_jobs WHERE id = $1")
        .bind(job.id)
        .execute(pool)
        .await;
}
impl HandlerRegistry {
    /// Returns a clone of the handler registered under `name` together with
    /// the `&'static str` key it was registered with, or `None` if no
    /// handler is registered.
    fn lookup_owned(&self, name: &str) -> Option<(super::HandlerFn, &'static str)> {
        // `get_key_value` recovers the stored `'static` key in the same hash
        // lookup; the previous version did a second O(n) scan over the keys.
        let (static_name, (handler, _)) = self.handlers.get_key_value(name)?;
        Some((handler.clone(), *static_name))
    }
}
/// Best-effort hostname lookup via the `HOSTNAME` environment variable.
/// Returns `None` when the variable is unset, non-Unicode, or empty.
fn hostname() -> Option<String> {
    match std::env::var("HOSTNAME") {
        Ok(name) if !name.is_empty() => Some(name),
        _ => None,
    }
}
// Compile-time reference to `DateTime<Utc>`; has no runtime effect.
// NOTE(review): apparently kept so the chrono re-export imports at the top
// of the file stay used — confirm before removing either.
const _: fn() = || {
    let _: Option<DateTime<Utc>> = None;
};
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn worker_id_prefix_includes_pid() {
let q = PgJobQueue::with_workers(dummy_pool(), 0);
assert!(q.worker_id_prefix.contains("pid:"));
}
#[tokio::test]
async fn poll_interval_is_overridable() {
let q = PgJobQueue::with_workers(dummy_pool(), 0).poll_interval(Duration::from_millis(250));
assert_eq!(q.poll_interval, Duration::from_millis(250));
}
#[tokio::test]
async fn dead_letter_callback_can_be_set() {
let q = PgJobQueue::with_workers(dummy_pool(), 0);
q.on_dead_letter(|_| async {}).await;
assert!(q.dead_letter.lock().await.is_some());
}
#[tokio::test]
async fn register_lookups_handler_under_static_name() {
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
struct Demo;
#[async_trait::async_trait]
impl Job for Demo {
const NAME: &'static str = "demo:job";
async fn run(&self) -> Result<(), JobError> {
Ok(())
}
}
let q = PgJobQueue::with_workers(dummy_pool(), 0);
q.register::<Demo>().await;
let r = q.registry.lock().await.lookup_owned("demo:job");
assert!(r.is_some());
let (_, name) = r.unwrap();
assert_eq!(name, "demo:job");
}
#[tokio::test]
async fn register_lookup_returns_none_for_unknown_name() {
let q = PgJobQueue::with_workers(dummy_pool(), 0);
assert!(q.registry.lock().await.lookup_owned("unknown").is_none());
}
fn dummy_pool() -> PgPool {
sqlx::postgres::PgPoolOptions::new()
.max_connections(1)
.connect_lazy("postgres://localhost:1/none")
.expect("lazy pool")
}
}