use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, FromQueryResult, Statement, Value};
use tracing::Instrument;
use crate::jobs::retry::{apply_failure, apply_success};
use crate::jobs::{JobDescriptor, JobResult, JobRow, RetryPolicy};
use crate::state::AppState;
/// Runtime configuration for the background job worker.
#[derive(Debug, Clone)]
pub struct JobConfig {
    /// How long the worker sleeps between polls of the job table.
    pub poll_interval: Duration,
    /// Maximum number of jobs claimed per poll (used as a SQL `LIMIT`).
    pub batch_size: i32,
    /// Queue names the worker pulls from; must not be empty.
    pub queues: Vec<String>,
    /// Claim duration: `locked_until` is set to `NOW()` plus this value
    /// when a job is picked up.
    pub job_timeout: Duration,
}
impl Default for JobConfig {
    /// Defaults: poll every 5 seconds, claim up to 10 jobs at a time from
    /// the `"default"` queue, and lock each claimed job for 30 seconds.
    fn default() -> Self {
        let queues = vec![String::from("default")];
        Self {
            job_timeout: Duration::from_secs(30),
            batch_size: 10,
            poll_interval: Duration::from_secs(5),
            queues,
        }
    }
}
impl JobConfig {
    /// Sets how long the worker sleeps between polls of the job table.
    pub fn poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval = interval;
        self
    }

    /// Sets the maximum number of jobs claimed per poll.
    ///
    /// # Panics
    /// Panics if `size` is not positive. The value is bound directly as a
    /// SQL `LIMIT`, where `0` would silently claim nothing and a negative
    /// value errors in Postgres — surfacing the misconfiguration eagerly
    /// mirrors the non-empty check in [`Self::queues`].
    pub fn batch_size(mut self, size: i32) -> Self {
        assert!(size > 0, "batch_size must be positive");
        self.batch_size = size;
        self
    }

    /// Replaces the set of queues the worker pulls from.
    ///
    /// # Panics
    /// Panics if `queues` yields no items.
    pub fn queues(mut self, queues: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let q: Vec<String> = queues.into_iter().map(Into::into).collect();
        assert!(!q.is_empty(), "queues must not be empty");
        self.queues = q;
        self
    }

    /// Sets the claim duration used to compute `locked_until` when a job
    /// is picked up.
    pub fn job_timeout(mut self, timeout: Duration) -> Self {
        self.job_timeout = timeout;
        self
    }
}
/// Background worker that polls the database for pending jobs and runs them.
pub(crate) struct Worker {
    // Shared application state handed to every job handler.
    state: Arc<AppState>,
    // Polling/claiming configuration (interval, queues, batch size, lock timeout).
    config: JobConfig,
}
impl Worker {
    /// Creates a worker that polls using `config` and hands `state` to
    /// every job handler it dispatches.
    pub(crate) fn new(state: Arc<AppState>, config: JobConfig) -> Self {
        Self { state, config }
    }

    /// Runs the worker loop until a shutdown signal arrives (Ctrl-C, or
    /// SIGTERM on unix targets).
    ///
    /// Each iteration claims a batch of jobs, dispatches them one at a
    /// time, then waits in a `select!` on the poll interval and the two
    /// signal futures. Returns immediately (after an error log) if no
    /// `DatabaseConnection` is registered in `AppState`.
    pub(crate) async fn run(self) {
        // Pin both signal futures once, up front, so the loop can re-poll
        // them via `as_mut()` without recreating them each iteration.
        let mut ctrl_c = pin!(tokio::signal::ctrl_c());
        let mut sigterm: std::pin::Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async {
            #[cfg(unix)]
            {
                use tokio::signal::unix::SignalKind;
                tokio::signal::unix::signal(SignalKind::terminate())
                    .expect("failed to install SIGTERM handler")
                    .recv()
                    .await;
            }
            // Non-unix targets have no SIGTERM; park this arm forever so
            // only Ctrl-C can end the loop.
            #[cfg(not(unix))]
            {
                std::future::pending::<()>().await;
            }
        });
        tracing::info!(
            queues = ?self.config.queues,
            poll_interval_secs = self.config.poll_interval.as_secs(),
            "Job worker started"
        );
        // Without a database there is nothing to poll; bail out rather
        // than spin.
        let Some(db) = self.state.get::<DatabaseConnection>() else {
            tracing::error!(
                "Job worker: no DatabaseConnection in AppState — worker will not start. \
                Call .with_database() before .jobs()."
            );
            return;
        };
        loop {
            // Claim phase: atomically grab up to `batch_size` pending jobs.
            match claim_batch(db, &self.config).await {
                Ok(jobs) => {
                    let n = jobs.len();
                    if n > 0 {
                        tracing::debug!(claimed = n, "Claimed job batch");
                    }
                    // Jobs run serially; shutdown signals are only observed
                    // again after the whole batch has been dispatched.
                    for job in jobs {
                        self.dispatch(db, job).await;
                    }
                }
                Err(e) => {
                    // Claim failures (e.g. transient DB errors) are logged
                    // and retried on the next poll.
                    tracing::error!(error = %e, "Failed to claim jobs from database");
                }
            }
            tokio::select! {
                _ = tokio::time::sleep(self.config.poll_interval) => {}
                _ = ctrl_c.as_mut() => {
                    tracing::info!("Job worker received shutdown signal, stopping.");
                    break;
                }
                _ = sigterm.as_mut() => {
                    tracing::info!("Job worker received shutdown signal, stopping.");
                    break;
                }
            }
        }
    }

    /// Looks up the handler for `job` in the `inventory` registry, runs it
    /// inside a tracing span, and records the outcome back to the database.
    async fn dispatch(&self, db: &DatabaseConnection, job: JobRow) {
        let handler = inventory::iter::<JobDescriptor>
            .into_iter()
            .find(|d| d.job_type == job.job_type);
        // A job type with no registered handler can never succeed: fail it
        // permanently (zero retries, RetryPolicy::None).
        let Some(descriptor) = handler else {
            tracing::warn!(
                job_id = %job.id,
                job_type = %job.job_type,
                "No handler registered for job type — permanently failing job"
            );
            let _ = apply_failure(
                db,
                job.id,
                &format!("no handler registered for job type: {}", job.job_type),
                job.attempts,
                0,
                &RetryPolicy::None,
            )
            .await;
            return;
        };
        let span = tracing::info_span!(
            "job",
            job_type = %job.job_type,
            job_id = %job.id,
            trace_id = job.trace_id.as_deref().unwrap_or(""),
        );
        // NOTE(review): the handler is not wrapped in a timeout here;
        // `config.job_timeout` only feeds `locked_until` at claim time.
        // Confirm a hung handler is reaped elsewhere (e.g. via lock expiry).
        let result: JobResult = (descriptor.handle)(job.payload.clone(), self.state.clone())
            .instrument(span)
            .await;
        let policy = build_policy(
            descriptor.retry_policy,
            job.max_retries,
            descriptor.retry_delay_secs,
        );
        match result {
            Ok(()) => {
                tracing::debug!(job_id = %job.id, job_type = %job.job_type, "Job completed");
                if let Err(e) = apply_success(db, job.id).await {
                    tracing::error!(job_id = %job.id, error = %e, "Failed to mark job as completed");
                }
            }
            Err(e) => {
                // Handler returned an error: the retry bookkeeping decides
                // between rescheduling and permanent failure.
                tracing::warn!(job_id = %job.id, job_type = %job.job_type, error = %e, "Job failed");
                if let Err(db_err) = apply_failure(
                    db,
                    job.id,
                    &e.to_string(),
                    job.attempts,
                    job.max_retries,
                    &policy,
                )
                .await
                {
                    tracing::error!(job_id = %job.id, error = %db_err, "Failed to record job failure");
                }
            }
        }
    }
}
/// Translates a descriptor's string policy name into a concrete `RetryPolicy`.
///
/// Unrecognized names (including `"exponential"` itself) fall back to
/// exponential backoff.
fn build_policy(retry_policy: &str, max_retries: i32, delay_secs: f64) -> RetryPolicy {
    let base = Duration::from_secs_f64(delay_secs);
    if retry_policy == "none" {
        RetryPolicy::none()
    } else if retry_policy == "fixed" {
        RetryPolicy::fixed(max_retries, base)
    } else {
        RetryPolicy::exponential(max_retries, base)
    }
}
/// Claims up to `config.batch_size` pending jobs and decodes the rows the
/// claim statement returned. Stops at the first row that fails to decode.
async fn claim_batch(
    db: &DatabaseConnection,
    config: &JobConfig,
) -> Result<Vec<JobRow>, sea_orm::DbErr> {
    let rows = db.query_all(build_claim_stmt(config)).await?;
    let mut jobs = Vec::with_capacity(rows.len());
    for row in &rows {
        jobs.push(JobRow::from_query_result(row, "")?);
    }
    Ok(jobs)
}
/// Builds the atomic claim statement: select-and-lock pending jobs
/// (`FOR UPDATE SKIP LOCKED` so concurrent workers never double-claim),
/// flip them to `running`, and return the updated rows.
///
/// Bind order: queue names ($1..$k), then batch size, then lock timeout
/// in seconds.
fn build_claim_stmt(config: &JobConfig) -> Statement {
    let queue_count = config.queues.len();
    let mut placeholders = String::new();
    for i in 1..=queue_count {
        if i > 1 {
            placeholders.push_str(", ");
        }
        placeholders.push('$');
        placeholders.push_str(&i.to_string());
    }
    let batch_param = queue_count + 1;
    let timeout_param = queue_count + 2;
    let sql = format!(
        r#"WITH claimed AS (
SELECT id FROM rapina_jobs
WHERE status = 'pending'
AND queue IN ({placeholders})
AND run_at <= NOW()
ORDER BY run_at ASC
LIMIT ${batch_param}
FOR UPDATE SKIP LOCKED
)
UPDATE rapina_jobs
SET status = 'running',
started_at = NOW(),
locked_until = NOW() + make_interval(secs => ${timeout_param})
FROM claimed
WHERE rapina_jobs.id = claimed.id
RETURNING rapina_jobs.*"#
    );
    let mut values: Vec<Value> = Vec::with_capacity(queue_count + 2);
    for queue in &config.queues {
        values.push(Value::String(Some(Box::new(queue.clone()))));
    }
    values.push(Value::Int(Some(config.batch_size)));
    values.push(Value::Double(Some(config.job_timeout.as_secs_f64())));
    Statement::from_sql_and_values(DbBackend::Postgres, &sql, values)
}
#[cfg(test)]
mod tests {
    //! Unit tests for retry-policy translation, config building, and the
    //! shape/parameters of the SQL claim statement. None of these touch a
    //! real database.
    use super::*;
    use std::sync::Arc;
    use crate::state::AppState;

    // --- build_policy: descriptor string -> RetryPolicy ------------------

    #[test]
    fn descriptor_exponential_produces_exponential_policy() {
        let policy = build_policy("exponential", 5, 2.0);
        assert!(matches!(
            policy,
            RetryPolicy::Exponential { max_retries: 5, .. }
        ));
    }

    #[test]
    fn descriptor_fixed_produces_fixed_policy() {
        let policy = build_policy("fixed", 3, 30.0);
        assert!(matches!(policy, RetryPolicy::Fixed { max_retries: 3, .. }));
    }

    #[test]
    fn descriptor_none_produces_none_policy() {
        let policy = build_policy("none", 0, 0.0);
        assert!(matches!(policy, RetryPolicy::None));
    }

    // Unrecognized policy names must degrade gracefully, not panic.
    #[test]
    fn descriptor_unknown_policy_falls_back_to_exponential() {
        let policy = build_policy("bogus", 3, 1.0);
        assert!(matches!(policy, RetryPolicy::Exponential { .. }));
    }

    #[test]
    fn descriptor_base_delay_is_forwarded() {
        let policy = build_policy("exponential", 3, 5.0);
        match policy {
            RetryPolicy::Exponential { base_delay, .. } => {
                assert_eq!(base_delay, Duration::from_secs(5));
            }
            _ => panic!("expected Exponential"),
        }
    }

    #[test]
    fn descriptor_fixed_delay_is_forwarded() {
        let policy = build_policy("fixed", 3, 20.0);
        match policy {
            RetryPolicy::Fixed { delay, .. } => {
                assert_eq!(delay, Duration::from_secs(20));
            }
            _ => panic!("expected Fixed"),
        }
    }

    // --- Worker::run: early exit when AppState has no DB ------------------

    // `run` must return promptly (not loop or hang) when no
    // DatabaseConnection was registered.
    #[tokio::test]
    async fn worker_exits_immediately_without_database() {
        let state = Arc::new(AppState::new());
        let worker = Worker::new(state, JobConfig::default());
        let handle = tokio::spawn(worker.run());
        let result = tokio::time::timeout(Duration::from_millis(500), handle).await;
        assert!(
            result.is_ok(),
            "worker should return quickly when no DB is in AppState"
        );
        assert!(result.unwrap().is_ok(), "worker task should not panic");
    }

    // --- JobConfig defaults and builder -----------------------------------

    #[test]
    fn job_config_defaults() {
        let config = JobConfig::default();
        assert_eq!(config.poll_interval, Duration::from_secs(5));
        assert_eq!(config.batch_size, 10);
        assert_eq!(config.queues, vec!["default"]);
        assert_eq!(config.job_timeout, Duration::from_secs(30));
    }

    #[test]
    fn job_config_builder_methods() {
        let config = JobConfig::default()
            .poll_interval(Duration::from_secs(2))
            .batch_size(5)
            .queues(["emails", "default"])
            .job_timeout(Duration::from_secs(60));
        assert_eq!(config.poll_interval, Duration::from_secs(2));
        assert_eq!(config.batch_size, 5);
        assert_eq!(config.queues, vec!["emails", "default"]);
        assert_eq!(config.job_timeout, Duration::from_secs(60));
    }

    // --- build_claim_stmt: SQL shape and bind parameters -------------------

    #[test]
    fn build_claim_stmt_sql_shape() {
        let config = JobConfig::default();
        let stmt = build_claim_stmt(&config);
        let sql = &stmt.sql;
        assert!(
            sql.contains("FOR UPDATE SKIP LOCKED"),
            "should lock claimed rows"
        );
        assert!(sql.contains("'running'"), "should transition to running");
        assert!(sql.contains("locked_until"), "should set lock expiry");
        assert!(
            sql.contains("RETURNING rapina_jobs.*"),
            "should return the claimed rows"
        );
        assert!(sql.contains("run_at <= NOW()"), "should filter by run_at");
    }

    // Expected parameter count is queues.len() + 2 (batch size + timeout).
    #[test]
    fn build_claim_stmt_param_count_single_queue() {
        let config = JobConfig::default();
        let stmt = build_claim_stmt(&config);
        let params = stmt.values.as_ref().map(|v| v.0.len()).unwrap_or(0);
        assert_eq!(params, 3);
    }

    #[test]
    fn build_claim_stmt_param_count_multiple_queues() {
        let config = JobConfig::default().queues(["default", "emails", "heavy"]);
        let stmt = build_claim_stmt(&config);
        let params = stmt.values.as_ref().map(|v| v.0.len()).unwrap_or(0);
        assert_eq!(params, 5);
    }

    #[test]
    fn build_claim_stmt_uses_postgres_backend() {
        let stmt = build_claim_stmt(&JobConfig::default());
        assert_eq!(stmt.db_backend, DbBackend::Postgres);
    }

    // Queue names bind first, then batch size, then timeout seconds.
    #[test]
    fn build_claim_stmt_queue_values() {
        let config = JobConfig::default().queues(["emails"]);
        let stmt = build_claim_stmt(&config);
        let values = &stmt.values.as_ref().unwrap().0;
        assert_eq!(
            values[0],
            Value::String(Some(Box::new("emails".to_string())))
        );
    }

    #[test]
    fn build_claim_stmt_batch_and_timeout_values() {
        let config = JobConfig::default()
            .batch_size(7)
            .job_timeout(Duration::from_secs(45));
        let stmt = build_claim_stmt(&config);
        let values = &stmt.values.as_ref().unwrap().0;
        assert_eq!(values[1], Value::Int(Some(7)));
        assert_eq!(values[2], Value::Double(Some(45.0)));
    }
}