use crate::{Error, Job};
use async_trait::async_trait;
use chrono::Utc;
use futures::FutureExt;
use sea_orm::DatabaseConnection;
use std::collections::HashMap;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
/// Hook that lets a job execute inside a per-tenant context.
///
/// `WorkerLoop` calls [`TenantScopeProvider::with_scope`] only when a provider has been
/// installed with `WorkerLoop::with_tenant_scope` *and* the claimed job row carries a
/// `tenant_id`; in every other case the job future is awaited directly.
#[async_trait]
pub trait TenantScopeProvider: Send + Sync {
    /// Runs `f` (the job's execution future) on behalf of tenant `tenant_id` and returns
    /// the job's result.
    ///
    /// Implementations should await `f` inside whatever tenant context they establish; if `f`
    /// is never awaited the job body does not run. Returning `Err` (with or without running
    /// `f`) makes the worker treat the execution as a failed attempt.
    async fn with_scope(
        &self,
        tenant_id: i64,
        f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
    ) -> Result<(), Error>;
}
/// Polling and concurrency settings for a `WorkerLoop`.
///
/// Build one with [`WorkerConfig::new`] (or `WorkerConfig::default()`) and chain the setters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Queue names, polled in this order on every pass (earlier queues are tried first).
    pub queues: Vec<String>,
    /// Maximum number of jobs the worker handles concurrently (size of its semaphore).
    pub max_jobs: usize,
    /// How long the worker sleeps after a pass over all queues found nothing to claim.
    pub sleep_duration: Duration,
    /// When `true`, the first reaper/claim error ends the run loop with that error instead of
    /// being logged and skipped.
    pub stop_on_error: bool,
    /// Passed to the queue reaper for every queue. Presumably the age after which a claimed
    /// but unfinished job becomes claimable again — confirm against `db::reaper`.
    pub visibility_timeout: Duration,
}

impl Default for WorkerConfig {
    /// Polls the `"default"` queue with 10 concurrent jobs, 1 s idle sleep, errors logged
    /// rather than fatal, and a 300 s visibility timeout.
    fn default() -> Self {
        Self {
            queues: vec!["default".to_string()],
            max_jobs: 10,
            sleep_duration: Duration::from_secs(1),
            stop_on_error: false,
            visibility_timeout: Duration::from_secs(300),
        }
    }
}

impl WorkerConfig {
    /// Config that polls `queues`; every other field takes its `Default` value.
    pub fn new(queues: Vec<String>) -> Self {
        Self {
            queues,
            ..Default::default()
        }
    }

    /// Sets the maximum number of concurrently handled jobs.
    pub fn max_jobs(mut self, max: usize) -> Self {
        self.max_jobs = max;
        self
    }

    /// Sets the timeout handed to the queue reaper.
    pub fn with_visibility_timeout(mut self, d: Duration) -> Self {
        self.visibility_timeout = d;
        self
    }

    /// Sets how long the worker sleeps when no queue had a claimable job.
    pub fn with_sleep_duration(mut self, d: Duration) -> Self {
        self.sleep_duration = d;
        self
    }

    /// Sets whether reaper/claim errors stop the worker (`true`) or are only logged.
    pub fn with_stop_on_error(mut self, stop: bool) -> Self {
        self.stop_on_error = stop;
        self
    }
}
/// Type-erased job runner, stored per job type in `WorkerLoop::handlers`.
///
/// Invoked with the job row's JSON payload and its attempt count; resolves to the job's
/// result paired with the delay to wait before the next attempt if this one failed.
type JobHandler = Arc<
    dyn Fn(String, u32) -> Pin<Box<dyn Future<Output = (Result<(), Error>, Duration)> + Send>>
        + Send
        + Sync,
>;
/// Polls the configured queues, claims jobs from the database and executes them on
/// Tokio tasks with bounded concurrency.
pub struct WorkerLoop {
    /// Polling/concurrency settings supplied at construction.
    config: WorkerConfig,
    /// Job handlers keyed by `std::any::type_name` of the registered job type.
    handlers: HashMap<String, JobHandler>,
    /// Created with `config.max_jobs` permits; one permit is held per in-flight job.
    semaphore: Arc<Semaphore>,
    /// Set by `shutdown()` or by a SIGTERM/Ctrl-C; checked by the run loop.
    shutdown: Arc<AtomicBool>,
    /// Optional provider used to run tenant-tagged jobs inside a tenant scope.
    tenant_scope: Option<Arc<dyn TenantScopeProvider>>,
    /// Random UUID string identifying this worker's claims in the database.
    worker_id: String,
}
impl WorkerLoop {
pub fn new(config: WorkerConfig) -> Self {
let semaphore = Arc::new(Semaphore::new(config.max_jobs));
Self {
config,
handlers: HashMap::new(),
semaphore,
shutdown: Arc::new(AtomicBool::new(false)),
tenant_scope: None,
worker_id: uuid::Uuid::new_v4().to_string(),
}
}
pub fn with_tenant_scope(mut self, provider: Arc<dyn TenantScopeProvider>) -> Self {
self.tenant_scope = Some(provider);
self
}
pub fn register<J>(&mut self)
where
J: Job + serde::de::DeserializeOwned + 'static,
{
let type_name = std::any::type_name::<J>().to_string();
let handler: JobHandler = Arc::new(move |data: String, attempt: u32| {
Box::pin(async move {
let job: J = match serde_json::from_str::<J>(&data) {
Ok(j) => j,
Err(e) => {
return (
Err(Error::DeserializationFailed(e.to_string())),
Duration::from_secs(5),
)
}
};
let delay = job.retry_delay(attempt);
let result = job.handle().await;
(result, delay)
})
});
self.handlers.insert(type_name, handler);
}
pub fn from_registry(config: WorkerConfig) -> Self {
let mut w = Self::new(config);
crate::db::Queue::apply_registrars(&mut w);
w
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::SeqCst);
}
pub async fn run(&self) -> Result<(), Error> {
let conn: &'static DatabaseConnection = crate::db::Queue::connection();
info!(
worker_id = %self.worker_id,
queues = ?self.config.queues,
max_jobs = self.config.max_jobs,
"WorkerLoop starting"
);
let signal_task = {
let shutdown = self.shutdown.clone();
tokio::spawn(async move {
let mut sigterm = match tokio::signal::unix::signal(
tokio::signal::unix::SignalKind::terminate(),
) {
Ok(s) => s,
Err(e) => {
error!(error = %e, "failed to install SIGTERM handler — requesting shutdown");
shutdown.store(true, Ordering::SeqCst);
return;
}
};
tokio::select! {
_ = sigterm.recv() => {
info!("SIGTERM received — shutting down WorkerLoop");
}
_ = tokio::signal::ctrl_c() => {
info!("Ctrl-C received — shutting down WorkerLoop");
}
}
shutdown.store(true, Ordering::SeqCst);
})
};
let _signal_guard = AbortOnDrop(signal_task);
'outer: loop {
if self.shutdown.load(Ordering::SeqCst) {
info!(worker_id = %self.worker_id, "Shutdown flag set — draining in-flight jobs");
let _drain_guard = self
.semaphore
.acquire_many(self.config.max_jobs as u32)
.await;
crate::db::requeue_claimed_by(conn, &self.worker_id)
.await
.map_err(|e| {
error!(error = %e, "requeue_claimed_by failed during shutdown");
e
})?;
info!(worker_id = %self.worker_id, "WorkerLoop shut down cleanly");
return Ok(());
}
for queue in &self.config.queues {
match crate::db::reaper(conn, queue, self.config.visibility_timeout).await {
Ok(()) => {}
Err(e) => {
error!(queue = %queue, error = %e, "reaper error");
if self.config.stop_on_error {
return Err(e);
}
}
}
match crate::db::claim(conn, queue, &self.worker_id).await {
Ok(Some(job_row)) => {
self.spawn_job(conn, job_row);
continue 'outer; }
Ok(None) => {} Err(e) => {
error!(queue = %queue, error = %e, "claim error");
if self.config.stop_on_error {
return Err(e);
}
}
}
}
tokio::time::sleep(self.config.sleep_duration).await;
}
}
fn spawn_job(&self, conn: &'static DatabaseConnection, job_row: crate::db::JobRow) {
let permit = self.semaphore.clone();
let handlers = self.handlers.clone();
let tenant_scope = self.tenant_scope.clone();
let worker_id = self.worker_id.clone();
let shutdown = self.shutdown.clone();
tokio::spawn(async move {
if shutdown.load(Ordering::SeqCst) {
return;
}
let _permit = permit.acquire_owned().await.expect("semaphore closed");
if shutdown.load(Ordering::SeqCst) {
return;
}
let job_id = job_row.id;
let job_type = job_row.job_type.clone();
let tenant_id = job_row.tenant_id;
let attempts = job_row.attempts;
let max_retries = job_row.max_retries;
debug!(
job_id = %job_id,
job_type = %job_type,
attempts = attempts,
tenant_id = ?tenant_id,
worker_id = %worker_id,
"Executing job"
);
let handler = match handlers.get(&job_type) {
Some(h) => h.clone(),
None => {
warn!(job_type = %job_type, "No handler registered — releasing job for retry");
let available_at = Utc::now()
+ chrono::Duration::from_std(Duration::from_secs(5)).unwrap_or_default();
crate::db::release_job(conn, job_id, attempts + 1, available_at)
.await
.ok();
return;
}
};
let result = AssertUnwindSafe(async move {
match (&tenant_scope, tenant_id) {
(Some(scope), Some(id)) => {
let fut = Box::pin(async move {
let (res, _delay) = handler(job_row.payload.clone(), attempts).await;
res
});
(scope.with_scope(id, fut).await, Duration::from_secs(5))
}
_ => handler(job_row.payload.clone(), attempts).await,
}
})
.catch_unwind()
.await;
match result {
Ok((Ok(()), _)) => {
debug!(job_id = %job_id, job_type = %job_type, "Job succeeded — deleting row");
crate::db::delete_job(conn, job_id).await.ok();
}
Ok((Err(e), retry_delay)) => {
error!(job_id = %job_id, job_type = %job_type, error = %e, "Job handler returned error");
handle_failure(
conn,
job_id,
attempts,
max_retries,
&e.to_string(),
retry_delay,
)
.await;
}
Err(_panic) => {
error!(job_id = %job_id, job_type = %job_type, "Job handler panicked — counting as failure");
let msg = "job handler panicked";
let delay = default_jitter_delay(attempts);
handle_failure(conn, job_id, attempts, max_retries, msg, delay).await;
}
}
});
}
}
async fn handle_failure(
conn: &'static DatabaseConnection,
job_id: i64,
attempts: u32,
max_retries: u32,
err_msg: &str,
retry_delay: Duration,
) {
if attempts + 1 >= max_retries {
warn!(job_id = %job_id, attempts = attempts, "Job exhausted retries — parking as failed");
crate::db::fail_job(conn, job_id, err_msg).await.ok();
} else {
let available_at = Utc::now() + chrono::Duration::from_std(retry_delay).unwrap_or_default();
debug!(
job_id = %job_id,
retry_at = %available_at,
"Scheduling job retry"
);
crate::db::release_job(conn, job_id, attempts + 1, available_at)
.await
.ok();
}
}
/// Full-jitter exponential backoff: a uniformly random whole number of seconds in
/// `0..=min(15 min, 5 s * 2^attempt)`, with the exponential term saturating instead of
/// overflowing.
fn default_jitter_delay(attempt: u32) -> Duration {
    use rand::Rng;

    const BASE_SECS: u64 = 5;
    const CAP_SECS: u64 = 15 * 60;

    // 2^attempt, pinned to u64::MAX once the shift would exceed the type's width.
    let growth = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
    let ceiling = BASE_SECS.saturating_mul(growth).min(CAP_SECS);

    Duration::from_secs(rand::thread_rng().gen_range(0..=ceiling))
}
/// Guard that aborts the wrapped Tokio task when it goes out of scope.
///
/// `WorkerLoop::run` keeps the signal-listener task in one so that task is cancelled
/// whenever `run` returns.
struct AbortOnDrop(tokio::task::JoinHandle<()>);

impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        let Self(task) = self;
        task.abort();
    }
}
/// Alias for [`WorkerLoop`] (presumably kept so existing callers using `Worker` still compile).
pub type Worker = WorkerLoop;
// Unit tests for construction, tenant-scope dispatch, the shutdown flag and backoff bounds.
// None of them touch the database.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Compile-time check: the trait must remain usable as `Arc<dyn TenantScopeProvider>`.
    #[test]
    fn test_tenant_scope_provider_is_object_safe() {
        struct NoopProvider;
        #[async_trait]
        impl TenantScopeProvider for NoopProvider {
            async fn with_scope(
                &self,
                _tenant_id: i64,
                f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
            ) -> Result<(), Error> {
                f.await
            }
        }
        let _provider: Arc<dyn TenantScopeProvider> = Arc::new(NoopProvider);
    }

    // Test double that records every tenant id it is asked to scope and can be configured
    // to reject the tenant without awaiting the job future.
    struct MockScopeProvider {
        called_with: Arc<Mutex<Vec<i64>>>,
        should_fail: bool,
    }

    impl MockScopeProvider {
        // Provider that records the call and then runs the job future.
        fn new() -> Self {
            Self {
                called_with: Arc::new(Mutex::new(Vec::new())),
                should_fail: false,
            }
        }
        // Provider that records the call and then returns `TenantNotFound`.
        fn failing() -> Self {
            Self {
                called_with: Arc::new(Mutex::new(Vec::new())),
                should_fail: true,
            }
        }
    }

    #[async_trait]
    impl TenantScopeProvider for MockScopeProvider {
        async fn with_scope(
            &self,
            tenant_id: i64,
            f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
        ) -> Result<(), Error> {
            self.called_with.lock().unwrap().push(tenant_id);
            if self.should_fail {
                // Fails before awaiting `f`, like a provider that cannot resolve the tenant.
                return Err(Error::tenant_not_found(tenant_id));
            }
            f.await
        }
    }

    #[test]
    fn test_worker_loop_new() {
        let w = WorkerLoop::new(WorkerConfig::default());
        assert!(w.tenant_scope.is_none());
        assert!(!w.worker_id.is_empty());
    }

    #[test]
    fn test_with_tenant_scope_stores_provider() {
        let w = WorkerLoop::new(WorkerConfig::default());
        let provider = Arc::new(MockScopeProvider::new());
        let w = w.with_tenant_scope(provider);
        assert!(w.tenant_scope.is_some());
    }

    #[test]
    fn test_worker_without_scope_has_none_by_default() {
        let w = WorkerLoop::new(WorkerConfig::default());
        assert!(w.tenant_scope.is_none());
    }

    #[tokio::test]
    async fn test_mock_scope_provider_calls_future() {
        let provider = MockScopeProvider::new();
        let calls = provider.called_with.clone();
        let result = provider.with_scope(42, Box::pin(async { Ok(()) })).await;
        assert!(result.is_ok());
        assert_eq!(calls.lock().unwrap().as_slice(), &[42]);
    }

    #[tokio::test]
    async fn test_mock_scope_provider_failure_returns_tenant_not_found() {
        let provider = MockScopeProvider::failing();
        let result = provider.with_scope(99, Box::pin(async { Ok(()) })).await;
        assert!(matches!(
            result,
            Err(Error::TenantNotFound { tenant_id: 99 })
        ));
    }

    // The three dispatch tests below replicate the rule `WorkerLoop::spawn_job` uses to pick
    // a path: call `with_scope` only when a provider is configured AND the job has a tenant
    // id; otherwise await the job future directly.
    #[tokio::test]
    async fn test_scope_dispatch_tenant_id_some_calls_with_scope() {
        let mock = MockScopeProvider::new();
        let calls = mock.called_with.clone();
        let provider: Arc<dyn TenantScopeProvider> = Arc::new(mock);
        let tenant_id: Option<i64> = Some(1);
        let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = Some(provider);
        let job_ran = Arc::new(Mutex::new(false));
        let job_ran_clone = job_ran.clone();
        let job_fut = Box::pin(async move {
            *job_ran_clone.lock().unwrap() = true;
            Ok(())
        });
        let result = match (&tenant_scope, tenant_id) {
            (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
            _ => job_fut.await,
        };
        assert!(result.is_ok());
        assert_eq!(calls.lock().unwrap().as_slice(), &[1i64]);
        assert!(*job_ran.lock().unwrap(), "job future must have been called");
    }

    #[tokio::test]
    async fn test_scope_dispatch_tenant_id_none_skips_with_scope() {
        let mock = MockScopeProvider::new();
        let calls = mock.called_with.clone();
        let provider: Arc<dyn TenantScopeProvider> = Arc::new(mock);
        let tenant_id: Option<i64> = None;
        let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = Some(provider);
        let job_ran = Arc::new(Mutex::new(false));
        let job_ran_clone = job_ran.clone();
        let job_fut = Box::pin(async move {
            *job_ran_clone.lock().unwrap() = true;
            Ok(())
        });
        let result = match (&tenant_scope, tenant_id) {
            (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
            _ => job_fut.await,
        };
        assert!(result.is_ok());
        assert!(
            calls.lock().unwrap().is_empty(),
            "with_scope must not be called when tenant_id is None"
        );
        assert!(
            *job_ran.lock().unwrap(),
            "job future must still run directly"
        );
    }

    #[tokio::test]
    async fn test_scope_dispatch_no_provider_runs_job_directly() {
        let tenant_id: Option<i64> = Some(1);
        let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = None;
        let job_ran = Arc::new(Mutex::new(false));
        let job_ran_clone = job_ran.clone();
        let job_fut = Box::pin(async move {
            *job_ran_clone.lock().unwrap() = true;
            Ok(())
        });
        let result = match (&tenant_scope, tenant_id) {
            (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
            _ => job_fut.await,
        };
        assert!(result.is_ok());
        assert!(
            *job_ran.lock().unwrap(),
            "job must run directly without a provider"
        );
    }

    #[test]
    fn test_shutdown_sets_flag() {
        let w = WorkerLoop::new(WorkerConfig::default());
        assert!(!w.shutdown.load(Ordering::SeqCst));
        w.shutdown();
        assert!(w.shutdown.load(Ordering::SeqCst));
    }

    #[test]
    fn test_worker_config_visibility_timeout_default() {
        let c = WorkerConfig::default();
        assert_eq!(c.visibility_timeout, Duration::from_secs(300));
    }

    // Upper bounds follow min(900, 5 * 2^attempt): 5 s, 40 s, and the 15-minute cap.
    // Repeated because the delay is random.
    #[test]
    fn test_default_jitter_delay_bounds() {
        for _ in 0..50 {
            assert!(default_jitter_delay(0).as_secs() <= 5);
            assert!(default_jitter_delay(3).as_secs() <= 40);
            assert!(default_jitter_delay(30).as_secs() <= 900);
        }
    }
}