#[cfg(feature = "jobs-postgres")]
pub mod pg;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
/// Errors produced while queueing or running a job.
#[derive(Debug, thiserror::Error)]
pub enum JobError {
/// The job failed but may succeed if retried (subject to `Job::MAX_ATTEMPTS`).
#[error("retryable: {0}")]
Retryable(String),
/// The job failed permanently; it is dead-lettered without further retries.
#[error("fatal: {0}")]
Fatal(String),
/// The queue infrastructure itself failed (serialization, closed channel, ...).
#[error("queue error: {0}")]
Queue(String),
}
/// A unit of background work.
///
/// Implementors must be (de)serializable so payloads can be enqueued as JSON
/// and decoded again inside a worker.
#[async_trait::async_trait]
pub trait Job: Send + Sync + Sized + Serialize + DeserializeOwned + 'static {
/// Unique name used to route envelopes to the registered handler.
const NAME: &'static str;
/// Total attempts allowed before a retryable failure is dead-lettered.
const MAX_ATTEMPTS: u32 = 5;
/// Execute the job once; return [`JobError::Retryable`] to request a retry.
async fn run(&self) -> Result<(), JobError>;
}
/// Abstraction over a job backend: registration, dispatch, and worker lifecycle.
///
/// NOTE(review): the generic methods make this trait not object-safe, so it is
/// meant to be used via generics or a concrete type, not `dyn JobQueue`.
#[async_trait::async_trait]
pub trait JobQueue: Send + Sync + 'static {
/// Make the queue able to decode and run jobs of type `T`.
async fn register<T: Job>(&self);
/// Serialize `payload` and enqueue it for execution.
async fn dispatch<T: Job>(&self, payload: &T) -> Result<(), JobError>;
/// Spawn the background workers.
async fn start(&self);
/// Stop the background workers.
async fn shutdown(&self);
/// Number of jobs dispatched but not yet finished.
async fn pending_count(&self) -> usize;
}
/// Internal message passed from `dispatch` to the workers: the serialized job
/// plus its retry bookkeeping.
#[derive(Debug, Clone)]
struct JobEnvelope {
/// `Job::NAME` of the target handler.
name: &'static str,
/// JSON-serialized job, decoded back to the concrete type by the handler.
payload: serde_json::Value,
/// Zero-based count of attempts already made.
attempt: u32,
/// Copied from `Job::MAX_ATTEMPTS` at dispatch time.
max_attempts: u32,
}
/// Type-erased job handler: decodes the JSON payload to the concrete job type
/// and runs it. Cloned cheaply via `Arc` for each execution.
type HandlerFn = Arc<
dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<(), JobError>> + Send>>
+ Send
+ Sync,
>;
/// Registry mapping a job's `NAME` to its handler closure and retry budget.
#[derive(Default)]
struct HandlerRegistry {
// `Job::NAME` -> (decode-and-run closure, `MAX_ATTEMPTS`).
handlers: HashMap<&'static str, (HandlerFn, u32)>, }
impl HandlerRegistry {
    /// Record the type-erased handler and retry budget for job type `T`,
    /// keyed by its `NAME`.
    fn register<T: Job>(&mut self) {
        let run: HandlerFn = Arc::new(move |raw| {
            Box::pin(async move {
                // Decode the JSON payload back to the concrete job type,
                // then execute it.
                let decoded: Result<T, _> = serde_json::from_value(raw);
                match decoded {
                    Ok(job) => job.run().await,
                    Err(err) => Err(JobError::Queue(err.to_string())),
                }
            })
        });
        self.handlers.insert(T::NAME, (run, T::MAX_ATTEMPTS));
    }

    /// Return a clone of the handler entry registered under `name`, if any.
    fn lookup(&self, name: &str) -> Option<(HandlerFn, u32)> {
        self.handlers.get(name).cloned()
    }
}
/// Async callback invoked when a job is dead-lettered (fatal error or
/// retries exhausted).
pub type DeadLetterFn =
Arc<dyn Fn(JobDeadLetter) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
/// Snapshot of a permanently failed job, handed to the dead-letter callback.
#[derive(Debug, Clone)]
pub struct JobDeadLetter {
/// `Job::NAME` of the failed job.
pub name: &'static str,
/// The JSON payload as originally dispatched.
pub payload: serde_json::Value,
/// Number of attempts made before giving up.
pub attempts: u32,
/// Rendered error message from the final failure.
pub error: String,
}
/// Tokio-channel-backed job queue that runs jobs in-process.
///
/// Jobs are fanned out to `worker_count` spawned tasks sharing one unbounded
/// channel. Nothing is persisted: pending jobs are lost when the queue drops.
pub struct InMemoryJobQueue {
/// Sender side; also cloned into workers for re-enqueueing retries.
tx: mpsc::UnboundedSender<JobEnvelope>,
/// Receiver, taken exactly once by `start` (`None` afterwards).
rx: Mutex<Option<mpsc::UnboundedReceiver<JobEnvelope>>>,
/// Shared handler registry consulted by the workers.
registry: Arc<Mutex<HandlerRegistry>>,
/// Handles of spawned worker tasks; non-empty while running.
workers: Mutex<Vec<JoinHandle<()>>>,
/// Number of worker tasks `start` will spawn.
worker_count: usize,
/// Optional callback for permanently failed jobs.
dead_letter: Arc<Mutex<Option<DeadLetterFn>>>,
/// Count of dispatched-but-unfinished jobs (includes jobs in retry backoff).
pending: Arc<std::sync::atomic::AtomicUsize>,
}
impl InMemoryJobQueue {
    /// Create a queue whose `start` will spawn `worker_count` worker tasks.
    #[must_use]
    pub fn with_workers(worker_count: usize) -> Self {
        let (tx, rx) = mpsc::unbounded_channel();
        Self {
            worker_count,
            tx,
            rx: Mutex::new(Some(rx)),
            registry: Arc::new(Mutex::new(HandlerRegistry::default())),
            workers: Mutex::new(Vec::new()),
            dead_letter: Arc::new(Mutex::new(None)),
            pending: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        }
    }

    /// Create a queue with the default worker count (4).
    #[must_use]
    pub fn new() -> Self {
        Self::with_workers(4)
    }

    /// Install a callback invoked whenever a job fails fatally or exhausts
    /// its retries. Replaces any previously installed callback.
    pub async fn on_dead_letter<F, Fut>(&self, callback: F)
    where
        F: Fn(JobDeadLetter) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let wrapped: DeadLetterFn = Arc::new(move |letter| Box::pin(callback(letter)));
        self.dead_letter.lock().await.replace(wrapped);
    }
}
impl Default for InMemoryJobQueue {
fn default() -> Self {
Self::new()
}
}
/// Build an [`InMemoryJobQueue`] from the application's job settings.
///
/// Uses `concurrency` for the worker count (default 4). If `backend` names
/// anything other than `"memory"` a warning is logged, since this helper can
/// only construct the in-memory backend.
#[cfg(feature = "config")]
#[must_use]
pub fn inmemory_from_settings(s: &crate::config::JobsSettings) -> Arc<InMemoryJobQueue> {
    let workers = match s.concurrency {
        Some(c) => c as usize,
        None => 4,
    };
    match s.backend.as_deref() {
        Some(backend) if backend != "memory" => {
            tracing::warn!(
                target: "rustango::jobs",
                backend = %backend,
                "jobs.backend = `{backend}` but inmemory_from_settings only builds InMemoryJobQueue. \
                 Wire the desired backend directly via Arc::new(...). See the docstring."
            );
        }
        _ => {}
    }
    Arc::new(InMemoryJobQueue::with_workers(workers))
}
#[async_trait::async_trait]
impl JobQueue for InMemoryJobQueue {
    /// Register the handler for job type `T` so dispatched payloads can be
    /// decoded and executed by the workers.
    async fn register<T: Job>(&self) {
        self.registry.lock().await.register::<T>();
    }

    /// Serialize `payload` and enqueue it for execution.
    ///
    /// # Errors
    /// Returns [`JobError::Queue`] if serialization fails or the channel is
    /// closed.
    async fn dispatch<T: Job>(&self, payload: &T) -> Result<(), JobError> {
        let value = serde_json::to_value(payload).map_err(|e| JobError::Queue(e.to_string()))?;
        let envelope = JobEnvelope {
            name: T::NAME,
            payload: value,
            attempt: 0,
            max_attempts: T::MAX_ATTEMPTS,
        };
        self.pending
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        if let Err(e) = self.tx.send(envelope) {
            // BUGFIX: roll the counter back on a failed send; previously a
            // closed channel left `pending_count` permanently inflated.
            self.pending
                .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
            return Err(JobError::Queue(e.to_string()));
        }
        Ok(())
    }

    /// Spawn the worker tasks. Idempotent while the queue is running.
    ///
    /// # Panics
    /// Panics if called again after `shutdown`: the receiver was consumed by
    /// the first `start` and cannot be recreated.
    async fn start(&self) {
        let mut workers = self.workers.lock().await;
        if !workers.is_empty() {
            // Already started; spawning again would duplicate workers.
            return;
        }
        let rx = self
            .rx
            .lock()
            .await
            .take()
            .expect("queue already started without workers");
        // All workers share one receiver behind a mutex and take turns
        // receiving; jobs still run concurrently once pulled.
        let rx = Arc::new(Mutex::new(rx));
        for _ in 0..self.worker_count {
            let rx = rx.clone();
            let registry = self.registry.clone();
            let dead_letter = self.dead_letter.clone();
            let pending = self.pending.clone();
            let tx_for_retries = self.tx.clone();
            workers.push(tokio::spawn(async move {
                worker_loop(rx, registry, dead_letter, pending, tx_for_retries).await;
            }));
        }
    }

    /// Abort all workers and wait for them to exit. Jobs currently executing
    /// or awaiting retry backoff may be dropped.
    async fn shutdown(&self) {
        let mut workers = self.workers.lock().await;
        for h in workers.drain(..) {
            h.abort();
            // JoinError from an aborted task is expected; ignore it.
            let _ = h.await;
        }
    }

    /// Number of dispatched jobs not yet completed, including jobs waiting
    /// out a retry backoff.
    async fn pending_count(&self) -> usize {
        self.pending.load(std::sync::atomic::Ordering::SeqCst)
    }
}
/// Core worker: pull envelopes off the shared receiver, run the matching
/// handler, and apply the retry / dead-letter policy.
///
/// The `pending` counter is decremented exactly once per dispatched job —
/// on success, on dead-lettering, or when no handler is registered. A job
/// waiting out its retry backoff is still counted as pending, since it was
/// never decremented.
async fn worker_loop(
rx: Arc<Mutex<mpsc::UnboundedReceiver<JobEnvelope>>>,
registry: Arc<Mutex<HandlerRegistry>>,
dead_letter: Arc<Mutex<Option<DeadLetterFn>>>,
pending: Arc<std::sync::atomic::AtomicUsize>,
tx: mpsc::UnboundedSender<JobEnvelope>,
) {
loop {
// Hold the receiver lock only while waiting for the next envelope, so
// workers take turns receiving but execute jobs concurrently.
let envelope = {
let mut rx_guard = rx.lock().await;
match rx_guard.recv().await {
Some(e) => e,
// Channel closed: the queue is gone, so this worker exits.
None => return, }
};
// Snapshot the handler under a short-lived registry lock; the handler
// itself runs outside the lock.
let handler = {
let reg = registry.lock().await;
reg.lookup(envelope.name)
};
let Some((handler, _max_attempts)) = handler else {
// Unknown job name: drop the envelope but keep the count honest.
tracing::error!(job = envelope.name, "no handler registered");
pending.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
continue;
};
let payload = envelope.payload.clone();
let result = handler(payload).await;
match result {
Ok(()) => {
pending.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
}
Err(JobError::Retryable(msg)) => {
let next_attempt = envelope.attempt + 1;
if next_attempt >= envelope.max_attempts {
// Retry budget exhausted: hand the job to the dead-letter
// callback, or just log if none is installed.
let dl_callback = dead_letter.lock().await.clone();
if let Some(cb) = dl_callback {
cb(JobDeadLetter {
name: envelope.name,
payload: envelope.payload.clone(),
attempts: next_attempt,
error: msg,
})
.await;
} else {
tracing::error!(job = envelope.name, attempts = next_attempt, error = %msg, "job exhausted retries");
}
pending.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
} else {
// Exponential backoff: 2^next_attempt seconds, shift capped at
// 10 (~17 min ceiling) to avoid overflow and unbounded waits.
let backoff_ms = 1000u64.saturating_mul(1u64 << next_attempt.min(10));
let mut retry = envelope.clone();
retry.attempt = next_attempt;
let tx = tx.clone();
// Re-enqueue from a detached task after the backoff so this
// worker is free to process other jobs meanwhile. Note the
// task is not tracked: a shutdown during backoff drops it.
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
let _ = tx.send(retry);
});
}
}
Err(e @ (JobError::Fatal(_) | JobError::Queue(_))) => {
// Non-retryable failure: dead-letter immediately.
let msg = e.to_string();
let dl_callback = dead_letter.lock().await.clone();
if let Some(cb) = dl_callback {
cb(JobDeadLetter {
name: envelope.name,
payload: envelope.payload.clone(),
attempts: envelope.attempt + 1,
error: msg,
})
.await;
} else {
tracing::error!(job = envelope.name, error = %msg, "job fatal");
}
pending.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Poll `cond` every 25 ms until it returns true or `timeout` elapses.
    /// Returns whether the condition held before the deadline. Replaces the
    /// previous fixed `sleep` waits, which were slow (7 s in the retry test)
    /// and flaky under load.
    async fn wait_for(timeout: Duration, mut cond: impl FnMut() -> bool) -> bool {
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            if cond() {
                return true;
            }
            if tokio::time::Instant::now() >= deadline {
                return false;
            }
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
    }

    /// Job that bumps a global counter; proves dispatch reaches handlers.
    #[derive(Serialize, Deserialize, Debug)]
    struct Increment;
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    #[async_trait::async_trait]
    impl Job for Increment {
        const NAME: &'static str = "test:increment";
        async fn run(&self) -> Result<(), JobError> {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Job that always fails — fatally or retryably depending on `fatal`.
    #[derive(Serialize, Deserialize, Debug)]
    struct AlwaysFail {
        fatal: bool,
    }
    #[async_trait::async_trait]
    impl Job for AlwaysFail {
        const NAME: &'static str = "test:always_fail";
        const MAX_ATTEMPTS: u32 = 2;
        async fn run(&self) -> Result<(), JobError> {
            if self.fatal {
                Err(JobError::Fatal("dead now".into()))
            } else {
                Err(JobError::Retryable("transient".into()))
            }
        }
    }

    /// Job that fails `fail_n` times, then records `success_marker_id`.
    #[derive(Serialize, Deserialize, Debug)]
    struct EventuallyOk {
        fail_n: u32,
        success_marker_id: u64,
    }
    static SUCCESSES: std::sync::Mutex<Vec<u64>> = std::sync::Mutex::new(Vec::new());
    static ATTEMPTS: AtomicUsize = AtomicUsize::new(0);
    #[async_trait::async_trait]
    impl Job for EventuallyOk {
        const NAME: &'static str = "test:eventually_ok";
        const MAX_ATTEMPTS: u32 = 5;
        async fn run(&self) -> Result<(), JobError> {
            let n = ATTEMPTS.fetch_add(1, Ordering::SeqCst);
            if (n as u32) < self.fail_n {
                Err(JobError::Retryable(format!("attempt {n}")))
            } else {
                SUCCESSES.lock().unwrap().push(self.success_marker_id);
                Ok(())
            }
        }
    }

    #[tokio::test]
    async fn dispatch_runs_handler() {
        COUNTER.store(0, Ordering::SeqCst);
        let q = InMemoryJobQueue::with_workers(2);
        q.register::<Increment>().await;
        q.start().await;
        q.dispatch(&Increment).await.unwrap();
        assert!(
            wait_for(Duration::from_secs(2), || COUNTER.load(Ordering::SeqCst) == 1).await,
            "handler did not run"
        );
        q.shutdown().await;
    }

    #[tokio::test]
    async fn fatal_error_goes_to_dead_letter() {
        let q = InMemoryJobQueue::with_workers(1);
        q.register::<AlwaysFail>().await;
        let captured: Arc<Mutex<Vec<JobDeadLetter>>> = Arc::new(Mutex::new(Vec::new()));
        let cap = captured.clone();
        q.on_dead_letter(move |dl| {
            let cap = cap.clone();
            async move {
                cap.lock().await.push(dl);
            }
        })
        .await;
        q.start().await;
        q.dispatch(&AlwaysFail { fatal: true }).await.unwrap();
        // `try_lock` keeps the poll closure synchronous; contention just retries.
        let probe = captured.clone();
        assert!(
            wait_for(Duration::from_secs(2), || {
                probe.try_lock().map(|v| !v.is_empty()).unwrap_or(false)
            })
            .await,
            "dead-letter callback was not invoked"
        );
        let letters = captured.lock().await;
        assert_eq!(letters.len(), 1);
        assert!(letters[0].error.contains("dead now"));
        drop(letters);
        q.shutdown().await;
    }

    #[tokio::test]
    async fn retryable_succeeds_eventually() {
        ATTEMPTS.store(0, Ordering::SeqCst);
        let q = InMemoryJobQueue::with_workers(1);
        q.register::<EventuallyOk>().await;
        q.start().await;
        let marker = 12345;
        q.dispatch(&EventuallyOk {
            fail_n: 2,
            success_marker_id: marker,
        })
        .await
        .unwrap();
        // Two retries with 2s + 4s backoff => ~6s of wall time. Poll with
        // headroom instead of the previous fixed 7s sleep: the test returns
        // as soon as the job lands and tolerates scheduler jitter.
        let ok = wait_for(Duration::from_secs(15), || {
            SUCCESSES.lock().unwrap().contains(&marker)
        })
        .await;
        assert!(ok, "expected marker, got {:?}", SUCCESSES.lock().unwrap());
        q.shutdown().await;
    }

    #[tokio::test]
    async fn unknown_job_is_logged_not_panic() {
        #[derive(Serialize, Deserialize)]
        struct UnregisteredJob;
        #[async_trait::async_trait]
        impl Job for UnregisteredJob {
            const NAME: &'static str = "test:unregistered";
            async fn run(&self) -> Result<(), JobError> {
                Ok(())
            }
        }
        let q = InMemoryJobQueue::with_workers(1);
        q.start().await;
        q.dispatch(&UnregisteredJob).await.unwrap();
        // The worker should drop the envelope AND decrement the pending count.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        while q.pending_count().await != 0 && tokio::time::Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        assert_eq!(q.pending_count().await, 0);
        q.shutdown().await;
    }

    #[tokio::test]
    async fn pending_count_tracks_in_flight() {
        // Zero workers: nothing consumes the queue, so the count only grows.
        let q = InMemoryJobQueue::with_workers(0);
        q.register::<Increment>().await;
        for _ in 0..3 {
            q.dispatch(&Increment).await.unwrap();
        }
        assert_eq!(q.pending_count().await, 3);
    }

    #[cfg(feature = "config")]
    #[test]
    fn inmemory_from_settings_uses_configured_concurrency() {
        let mut s = crate::config::JobsSettings::default();
        s.concurrency = Some(8);
        let q = inmemory_from_settings(&s);
        assert_eq!(q.worker_count, 8);
    }

    #[cfg(feature = "config")]
    #[test]
    fn inmemory_from_settings_defaults_to_four_workers() {
        let s = crate::config::JobsSettings::default();
        let q = inmemory_from_settings(&s);
        assert_eq!(q.worker_count, 4);
    }
}