use std::sync::Arc;
use std::time::Duration;

use sqlx_postgres::PgPool;
use tokio::task::JoinSet;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
use tracing::{info, info_span, warn, Instrument};

use crate::queue;
use crate::worker::executor::{ExecutionContext, ExecutionOutcome, Executor};
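
/// Polls the job queue and executes claimed jobs on a fixed number of
/// concurrent worker loops, with exponential idle backoff and a graceful,
/// bounded shutdown drain.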
pub struct WorkerRuntime {
    pool: PgPool,
    executor: Arc<dyn Executor>,
    concurrency: usize,
    poll_interval: Duration,
    max_poll_interval: Duration,
    worker_id_prefix: String,
    shutdown_grace: Duration,
}
impl WorkerRuntime {
    /// Creates a runtime with defaults: 4 workers, a 500 ms base poll
    /// interval backing off to 2 s, and a 30 s shutdown grace period.
    pub fn new(pool: PgPool, executor: Arc<dyn Executor>) -> Self {
        Self {
            pool,
            executor,
            concurrency: 4,
            poll_interval: Duration::from_millis(500),
            max_poll_interval: Duration::from_secs(2),
            worker_id_prefix: "worker".into(),
            shutdown_grace: Duration::from_secs(30),
        }
    }

    /// Sets the number of concurrent worker loops (clamped to at least 1).
    pub fn with_concurrency(mut self, n: usize) -> Self {
        self.concurrency = n.max(1);
        self
    }

    /// Sets the idle polling interval; `max` is clamped to at least `base`.
    pub fn with_poll_interval(mut self, base: Duration, max: Duration) -> Self {
        self.poll_interval = base;
        self.max_poll_interval = max.max(base);
        self
    }

    pub fn with_worker_id_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.worker_id_prefix = prefix.into();
        self
    }

    pub fn with_shutdown_grace(mut self, grace: Duration) -> Self {
        self.shutdown_grace = grace;
        self
    }
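
    /// Runs the worker pool until `cancel` is triggered, then waits up to
    /// `shutdown_grace` for in-flight jobs before aborting any remaining
    /// worker tasks.
    ///
    /// A minimal wiring sketch (illustrative; how `pool` and `executor` are
    /// constructed is outside this module):
    ///
    /// ```ignore
    /// let cancel = CancellationToken::new();
    /// let runtime = WorkerRuntime::new(pool, executor)
    ///     .with_concurrency(8)
    ///     .with_shutdown_grace(Duration::from_secs(60));
    /// let handle = tokio::spawn(runtime.run(cancel.clone()));
    /// tokio::signal::ctrl_c().await?;
    /// cancel.cancel();
    /// handle.await?;
    /// ```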
    pub async fn run(self, cancel: CancellationToken) {
        let mut set = JoinSet::new();
        for n in 0..self.concurrency {
            let worker_id = format!("{}-{n}", self.worker_id_prefix);
            let pool = self.pool.clone();
            let executor = self.executor.clone();
            let token = cancel.clone();
            let base = self.poll_interval;
            let max_iv = self.max_poll_interval;
            set.spawn(worker_loop(worker_id, pool, executor, token, base, max_iv));
        }

        cancel.cancelled().await;
        info!(
            grace_seconds = self.shutdown_grace.as_secs(),
            "shutdown requested; waiting for in-flight jobs"
        );

        // Drain: wait for every worker loop to observe the cancellation and
        // finish its current job, but never longer than the grace period.
        let drain = async { while set.join_next().await.is_some() {} };
        match tokio::time::timeout(self.shutdown_grace, drain).await {
            Ok(()) => info!("all workers exited within grace period"),
            Err(_) => {
                warn!("shutdown grace period expired; aborting remaining workers");
                set.abort_all();
                while set.join_next().await.is_some() {}
            }
        }
    }
}
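
/// A single worker loop: claims the next job, executes it, and records the
/// outcome. Sleeps with exponential backoff (doubling from `base_interval`
/// up to `max_interval`) while the queue is empty, and backs off at
/// `max_interval` after a fetch error.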
async fn worker_loop(
    worker_id: String,
    pool: PgPool,
    executor: Arc<dyn Executor>,
    cancel: CancellationToken,
    base_interval: Duration,
    max_interval: Duration,
) {
    let mut idle_backoff = base_interval;
    loop {
        if cancel.is_cancelled() {
            break;
        }
        match queue::fetch_next(&pool, &worker_id).await {
            Ok(Some(job)) => {
                // A job was claimed; reset the idle backoff.
                idle_backoff = base_interval;
                let span = info_span!(
                    "job",
                    job_id = %job.id,
                    kind = job.kind.as_str(),
                    attempt = job.attempts,
                    worker_id = %worker_id
                );
                let ctx = ExecutionContext {
                    pool: pool.clone(),
                    shutdown: cancel.clone(),
                    worker_id: worker_id.clone(),
                };
                let job_id = job.id;
                let kind_label = job.kind.as_str();
                metrics::counter!("worker_jobs_started_total", "kind" => kind_label)
                    .increment(1);
                let started = std::time::Instant::now();
                let outcome = executor.execute(&ctx, &job).instrument(span).await;
                let elapsed = started.elapsed().as_secs_f64();
                // Persist the terminal state and derive a label for metrics.
                let outcome_label: &'static str = match outcome {
                    ExecutionOutcome::Succeeded => {
                        if let Err(e) = queue::mark_succeeded(&pool, job_id).await {
                            warn!(error = %e, %job_id, "failed to mark succeeded");
                        }
                        "succeeded"
                    }
                    ExecutionOutcome::Failed(msg) => {
                        match queue::mark_failed_or_retry(&pool, job_id, &msg).await {
                            Ok(updated) => {
                                if updated.status == crate::JobStatus::FailedPermanent {
                                    "failed_permanent"
                                } else {
                                    "retrying"
                                }
                            }
                            Err(e) => {
                                warn!(error = %e, %job_id, "failed to mark failed_or_retry");
                                "error"
                            }
                        }
                    }
                    ExecutionOutcome::Cancelled => {
                        if let Err(e) = queue::finalize_cancelled(&pool, job_id).await {
                            warn!(error = %e, %job_id, "failed to finalize cancelled");
                        }
                        "cancelled"
                    }
                };
                metrics::counter!(
                    "worker_jobs_completed_total",
                    "kind" => kind_label,
                    "outcome" => outcome_label,
                )
                .increment(1);
                metrics::histogram!(
                    "worker_job_duration_seconds",
                    "kind" => kind_label,
                    "outcome" => outcome_label,
                )
                .record(elapsed);
            }
            Ok(None) => {
                // Queue is empty: sleep, then double the backoff up to the cap.
                tokio::select! {
                    _ = cancel.cancelled() => break,
                    _ = sleep(idle_backoff) => {}
                }
                idle_backoff = (idle_backoff * 2).min(max_interval);
            }
            Err(e) => {
                // Fetch failed (e.g. the database is unreachable): back off at
                // the maximum interval before retrying.
                warn!(error = %e, "fetch_next error; backing off");
                tokio::select! {
                    _ = cancel.cancelled() => break,
                    _ = sleep(max_interval) => {}
                }
            }
        }
    }
    info!(worker_id = %worker_id, "worker loop exited");
}
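
// For reference, the only `Executor` shape this runtime relies on is the call
// above: an object-safe `execute(&ctx, &job)` that is awaited and yields an
// `ExecutionOutcome`. A hypothetical stub (illustrative only; `NoopExecutor`,
// the `async_trait` wiring, and the `Job` type name are assumptions, where
// `Job` stands in for whatever job type `queue::fetch_next` returns):
//
//     struct NoopExecutor;
//
//     #[async_trait::async_trait]
//     impl Executor for NoopExecutor {
//         async fn execute(&self, _ctx: &ExecutionContext, _job: &Job) -> ExecutionOutcome {
//             ExecutionOutcome::Succeeded
//         }
//     }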