use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, time::Instant};
use chrono::Utc;
use crate::batcher::{CompletionRequest, FailureRequest};
use crate::errors::GraphileWorkerError;
use crate::local_queue::LocalQueue;
use crate::sql::{get_job::get_job, task_identifiers::SharedTaskDetails};
use crate::streams::{job_signal_stream, job_signal_stream_with_receiver, job_stream};
use crate::worker_utils::WorkerUtils;
use futures::{try_join, StreamExt, TryStreamExt};
use getset::Getters;
use graphile_worker_crontab_runner::{cron_main, ScheduleCronJobError};
use graphile_worker_crontab_types::Crontab;
use graphile_worker_ctx::WorkerContext;
use graphile_worker_extensions::ReadOnlyExtensions;
use graphile_worker_job::Job;
use graphile_worker_lifecycle_hooks::{
AfterJobRunContext, BeforeJobRunContext, HookRegistry, HookResult, JobCompleteContext,
JobFailContext, JobFetchContext, JobPermanentlyFailContext, JobStartContext, ShutdownReason,
WorkerShutdownContext, WorkerStartContext,
};
use graphile_worker_shutdown_signal::ShutdownSignal;
use thiserror::Error;
use tokio::sync::Notify;
use tracing::{debug, error, info, trace, warn, Instrument, Span};
use crate::builder::WorkerOptions;
use crate::sql::complete_job::complete_job;
use crate::tracing::link_to_job_create_span;
use crate::{sql::fail_job::fail_job, streams::StreamSource};
/// Boxed task handler stored in the worker's job registry.
///
/// Takes the per-job [`WorkerContext`] and returns a boxed future that
/// resolves to `Ok(())` on success or `Err(String)` to fail the job with
/// that message.
pub type WorkerFn = Box<
    dyn Fn(WorkerContext) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send>> + Send + Sync,
>;
/// A configured job worker bound to a Postgres pool.
///
/// Built via [`Worker::options`] (a `WorkerOptions` builder). Public getters
/// for most fields are generated by `getset`; fields marked `#[getset(skip)]`
/// are internal only.
#[derive(Getters)]
#[getset(get = "pub")]
pub struct Worker {
    /// Identifier used to lock jobs in the database for this worker instance.
    pub(crate) worker_id: String,
    /// Maximum number of jobs processed concurrently.
    pub(crate) concurrency: usize,
    /// How often to poll Postgres for jobs when no NOTIFY signal arrives.
    pub(crate) poll_interval: Duration,
    /// Registered task handlers, keyed by task identifier.
    pub(crate) jobs: HashMap<String, WorkerFn>,
    pub(crate) pg_pool: sqlx::PgPool,
    /// Schema name, already escaped for safe interpolation into SQL.
    pub(crate) escaped_schema: String,
    /// Shared mapping between numeric task ids and task identifiers.
    pub(crate) task_details: SharedTaskDetails,
    /// Jobs carrying any of these flags are never fetched by this worker.
    pub(crate) forbidden_flags: Vec<String>,
    /// Cron entries driven by the crontab scheduler (may be empty).
    pub(crate) crontabs: Vec<Crontab>,
    /// When true, `Utc::now()` is passed to SQL instead of using the DB clock.
    pub(crate) use_local_time: bool,
    pub(crate) shutdown_signal: ShutdownSignal,
    #[getset(skip)]
    /// Notifier triggered by [`Worker::request_shutdown`].
    pub(crate) shutdown_notifier: Arc<Notify>,
    /// Extensions exposed to task handlers through the worker context.
    pub(crate) extensions: ReadOnlyExtensions,
    /// Lifecycle hook registry (start/fetch/run/complete/fail/shutdown).
    pub(crate) hooks: Arc<HookRegistry>,
    #[getset(skip)]
    /// When set, jobs are prefetched through a `LocalQueue` instead of
    /// fetched one-by-one per signal.
    pub(crate) local_queue_config: Option<crate::local_queue::LocalQueueConfig>,
    #[getset(skip)]
    /// When set, job completions are persisted in batches.
    pub(crate) completion_batcher: Option<crate::batcher::CompletionBatcher>,
    #[getset(skip)]
    /// When set, job failures are persisted in batches.
    pub(crate) failure_batcher: Option<crate::batcher::FailureBatcher>,
}
#[derive(Error, Debug)]
pub enum WorkerRuntimeError {
#[error("Unexpected error occured while processing job : '{0}'")]
ProcessJob(#[from] ProcessJobError),
#[error("Failed to listen to postgres notifications : '{0}'")]
PgListen(#[from] GraphileWorkerError),
#[error("Error occured while trying to schedule cron job : {0}")]
Crontab(#[from] ScheduleCronJobError),
}
impl Worker {
    /// Returns a default [`WorkerOptions`] builder used to configure and
    /// construct a `Worker`.
    pub fn options() -> WorkerOptions {
        WorkerOptions::default()
    }

    /// Runs the worker until the shutdown signal fires or an error occurs.
    ///
    /// Emits the worker-start hook, then drives the job runner and the
    /// crontab scheduler concurrently. After both complete, pending batched
    /// completions/failures are flushed, and the worker-shutdown hook is
    /// emitted with a reason derived from the outcome before the error (if
    /// any) is propagated.
    pub async fn run(&self) -> Result<(), WorkerRuntimeError> {
        self.hooks
            .emit(WorkerStartContext {
                pool: self.pg_pool.clone(),
                worker_id: self.worker_id.clone(),
                extensions: self.extensions.clone(),
            })
            .await;
        let local_queue = self.create_local_queue();
        let job_runner = self.job_runner_internal(local_queue);
        let crontab_scheduler = self.crontab_scheduler();
        // try_join! cancels the sibling future as soon as either errors.
        let result = try_join!(crontab_scheduler, job_runner);
        // Drain batchers so no completion/failure is lost on shutdown.
        if let Some(batcher) = &self.completion_batcher {
            batcher.await_shutdown().await;
        }
        if let Some(batcher) = &self.failure_batcher {
            batcher.await_shutdown().await;
        }
        let reason = match &result {
            Ok(_) => ShutdownReason::Graceful,
            Err(_) => ShutdownReason::Error,
        };
        self.hooks
            .emit(WorkerShutdownContext {
                pool: self.pg_pool.clone(),
                worker_id: self.worker_id.clone(),
                reason,
            })
            .await;
        // Propagate the original error only after the shutdown hook has run.
        result?;
        Ok(())
    }

    /// Processes the currently queued jobs once, then returns.
    ///
    /// Jobs stream in with up to `concurrency` processed at a time. When a
    /// job belongs to a job queue, the inner loop keeps fetching follow-up
    /// jobs so that queue drains before the stream item completes.
    pub async fn run_once(&self) -> Result<(), WorkerRuntimeError> {
        let job_stream = job_stream(
            self.pg_pool.clone(),
            self.shutdown_signal.clone(),
            self.task_details.clone(),
            self.escaped_schema.clone(),
            self.worker_id.clone(),
            self.forbidden_flags.clone(),
            self.use_local_time,
        );
        job_stream
            .for_each_concurrent(self.concurrency, |mut job| async move {
                loop {
                    let job_id = *job.id();
                    let has_queue = job.job_queue_id().is_some();
                    let result =
                        run_and_release_job(Arc::new(job), self, &StreamSource::RunOnce).await;
                    match result {
                        Ok(_) => {
                            info!(job_id, "Job processed");
                        }
                        Err(e) => {
                            // Errors are logged, not propagated: other jobs in
                            // the stream still get processed.
                            error!("Error while processing job : {:?}", e);
                        }
                    };
                    if !has_queue {
                        break;
                    }
                    info!(job_id, "Job has queue, fetching another job");
                    // `now` is only supplied when local time is requested;
                    // otherwise the SQL layer uses the database clock.
                    let now = self.use_local_time.then(Utc::now);
                    let task_details_guard = self.task_details.read().await;
                    let new_job = get_job(
                        self.pg_pool(),
                        &task_details_guard,
                        self.escaped_schema(),
                        self.worker_id(),
                        self.forbidden_flags(),
                        now,
                    )
                    .await
                    // NOTE(review): a fetch error here is silently treated as
                    // "no more jobs" and ends the queue-drain loop — confirm
                    // this best-effort behavior is intended.
                    .unwrap_or(None);
                    drop(task_details_guard);
                    let Some(new_job) = new_job else {
                        break;
                    };
                    job = new_job;
                }
            })
            .await;
        Ok(())
    }

    /// Builds the [`LocalQueue`] (and its job-signal receiver) when local
    /// queueing is configured; returns `None` to use the direct runner.
    fn create_local_queue(&self) -> Option<(LocalQueue, crate::streams::JobSignalReceiver)> {
        if let Some(ref config) = self.local_queue_config {
            // Channel sized at 2x concurrency so signals can buffer while all
            // slots are busy.
            let (tx, rx) = tokio::sync::mpsc::channel(self.concurrency * 2);
            let queue = LocalQueue::new(crate::local_queue::LocalQueueParams {
                config: config.clone(),
                pg_pool: self.pg_pool.clone(),
                escaped_schema: self.escaped_schema.clone(),
                worker_id: self.worker_id.clone(),
                task_details: self.task_details.clone(),
                poll_interval: self.poll_interval,
                continuous: true,
                shutdown_signal: Some(self.shutdown_signal.clone()),
                hooks: self.hooks.clone(),
                job_signal_sender: tx,
                use_local_time: self.use_local_time,
            });
            Some((queue, rx))
        } else {
            None
        }
    }

    /// Dispatches to the local-queue runner or the direct runner depending on
    /// whether a local queue was created.
    async fn job_runner_internal(
        &self,
        local_queue: Option<(LocalQueue, crate::streams::JobSignalReceiver)>,
    ) -> Result<(), WorkerRuntimeError> {
        match local_queue {
            Some((local_queue, rx)) => self.job_runner_with_local_queue(local_queue, rx).await,
            None => self.job_runner_direct().await,
        }
    }

    /// Continuous job runner backed by a [`LocalQueue`] of prefetched jobs.
    ///
    /// Listens on the combined signal stream (poll interval + pg LISTEN +
    /// local-queue refills), pulls jobs from the local queue, and runs them
    /// with up to `concurrency` in flight. The queue is released on exit.
    async fn job_runner_with_local_queue(
        &self,
        local_queue: LocalQueue,
        job_signal_rx: crate::streams::JobSignalReceiver,
    ) -> Result<(), WorkerRuntimeError> {
        let job_signal = job_signal_stream_with_receiver(
            self.pg_pool.clone(),
            self.poll_interval,
            self.shutdown_signal.clone(),
            self.concurrency,
            job_signal_rx,
        )
        .await?;
        debug!("Listening for jobs with LocalQueue...");
        job_signal
            .map(Ok::<_, ProcessJobError>)
            .try_for_each_concurrent(self.concurrency, |source| {
                let local_queue = local_queue.clone();
                async move {
                    // A LISTEN notification means new work exists in the DB;
                    // nudge the queue to refill.
                    if matches!(source, StreamSource::PgListener) {
                        local_queue.pulse(1).await;
                    }
                    let job = local_queue.get_job(&self.forbidden_flags).await;
                    if let Some(job) = job {
                        let job = Arc::new(job);
                        self.hooks
                            .emit(JobFetchContext {
                                job: job.clone(),
                                worker_id: self.worker_id().clone(),
                            })
                            .await;
                        run_and_release_job(job.clone(), self, &source).await?;
                    }
                    Ok(())
                }
            })
            .await?;
        // Best-effort release: failure to release the queue is not fatal.
        if let Err(e) = local_queue.release().await {
            warn!(error = %e, "Error releasing LocalQueue");
        }
        Ok(())
    }

    /// Continuous job runner without a local queue: each signal triggers a
    /// single fetch-and-run via [`process_one_job`].
    async fn job_runner_direct(&self) -> Result<(), WorkerRuntimeError> {
        let job_signal = job_signal_stream(
            self.pg_pool.clone(),
            self.poll_interval,
            self.shutdown_signal.clone(),
            self.concurrency,
        )
        .await?;
        debug!("Listening for jobs...");
        job_signal
            .map(Ok::<_, ProcessJobError>)
            .try_for_each_concurrent(self.concurrency, |source| async move {
                let res = process_one_job(self, source).await?;
                if let Some(job) = res {
                    debug!(job_id = job.id(), "Job processed");
                }
                Ok(())
            })
            .await?;
        Ok(())
    }

    /// Runs the crontab scheduler until shutdown; no-op when no crontabs are
    /// configured.
    async fn crontab_scheduler(&self) -> Result<(), WorkerRuntimeError> {
        if self.crontabs().is_empty() {
            return Ok(());
        }
        cron_main(
            self.pg_pool(),
            self.escaped_schema(),
            self.crontabs(),
            *self.use_local_time(),
            self.shutdown_signal.clone(),
            &self.hooks,
        )
        .await?;
        Ok(())
    }

    /// Creates a [`WorkerUtils`] sharing this worker's pool, schema, hooks,
    /// and task details.
    pub fn create_utils(&self) -> WorkerUtils {
        WorkerUtils::new(self.pg_pool.clone(), self.escaped_schema.clone())
            .with_hooks(self.hooks.clone())
            .with_task_details(self.task_details.clone())
    }

    /// Requests a graceful shutdown by waking all waiters on the shutdown
    /// notifier.
    pub fn request_shutdown(&self) {
        self.shutdown_notifier.notify_waiters();
    }
}
#[derive(Error, Debug)]
pub enum ProcessJobError {
#[error("An error occured while releasing a job : '{0}'")]
ReleaseJobError(#[from] ReleaseJobError),
#[error("An error occured while fetching a job to run : '{0}'")]
GetJobError(#[from] GraphileWorkerError),
}
/// Fetches at most one job from Postgres and runs it to completion.
///
/// Returns `Ok(Some(job))` when a job was fetched and released (successfully
/// or not), `Ok(None)` when no job was available, and an error when fetching
/// or releasing failed.
async fn process_one_job(
    worker: &Worker,
    source: StreamSource,
) -> Result<Option<Job>, ProcessJobError> {
    // Only pass a timestamp when local time is requested; otherwise the SQL
    // layer falls back to the database clock.
    let local_now = worker.use_local_time.then(Utc::now);
    let details = worker.task_details.read().await;
    let fetched = match get_job(
        worker.pg_pool(),
        &details,
        worker.escaped_schema(),
        worker.worker_id(),
        worker.forbidden_flags(),
        local_now,
    )
    .await
    {
        Ok(maybe_job) => maybe_job,
        Err(e) => {
            error!("Could not get job : {:?}", e);
            return Err(e.into());
        }
    };
    // Release the read lock before running the (potentially long) job.
    drop(details);

    let Some(fetched) = fetched else {
        trace!(source = ?source, "No job found");
        return Ok(None);
    };

    let job = Arc::new(fetched);
    worker
        .hooks
        .emit(JobFetchContext {
            job: Arc::clone(&job),
            worker_id: worker.worker_id().clone(),
        })
        .await;
    run_and_release_job(Arc::clone(&job), worker, &source).await?;
    // Hand the job back by value; clone only if another Arc is still alive.
    Ok(Some(
        Arc::try_unwrap(job).unwrap_or_else(|arc| (*arc).clone()),
    ))
}
async fn run_and_release_job(
job: Arc<Job>,
worker: &Worker,
source: &StreamSource,
) -> Result<(), ProcessJobError> {
let before_result = worker
.hooks
.intercept(BeforeJobRunContext {
job: job.clone(),
worker_id: worker.worker_id().clone(),
payload: job.payload().clone(),
})
.await;
let (job_result, duration) = match before_result {
HookResult::Continue => {
worker
.hooks
.emit(JobStartContext {
job: job.clone(),
worker_id: worker.worker_id().clone(),
})
.await;
let start = Instant::now();
let job_result = run_job(&job, worker, source).await;
let duration = start.elapsed();
let result_for_hook = job_result
.as_ref()
.map(|_| ())
.map_err(|e| format!("{e:?}"));
let after_result = worker
.hooks
.intercept(AfterJobRunContext {
job: job.clone(),
worker_id: worker.worker_id().clone(),
result: result_for_hook,
duration,
})
.await;
match after_result {
HookResult::Continue => (job_result, duration),
HookResult::Skip => (Ok(()), duration),
HookResult::Fail(msg) => (Err(RunJobError::TaskError(msg)), duration),
}
}
HookResult::Skip => {
debug!(job_id = job.id(), "Job skipped by before_job_run hook");
(Ok(()), Duration::ZERO)
}
HookResult::Fail(msg) => {
debug!(
job_id = job.id(),
"Job failed by before_job_run hook: {}", msg
);
(Err(RunJobError::TaskError(msg)), Duration::ZERO)
}
};
release_job(job_result, job.clone(), worker, duration)
.await
.map_err(|e| {
error!("Release job error : {:?}", e);
e
})?;
Ok(())
}
/// Internal error describing why a single job execution failed.
#[derive(Error, Debug)]
enum RunJobError {
    /// The job's numeric task id is missing from the shared task details.
    #[error("Cannot find any task identifier for given task id '{0}'. This is probably a bug !")]
    IdentifierNotFound(i32),
    /// No handler is registered for the resolved task identifier.
    #[error("Cannot find any task fn for given task identifier '{0}'. This is probably a bug !")]
    FnNotFound(String),
    /// The spawned task panicked or was cancelled (surfaced as a JoinError).
    #[error("Task failed execution to complete : {0}")]
    TaskPanic(#[from] tokio::task::JoinError),
    /// The task handler returned an error string.
    #[error("Task returned the following error : {0}")]
    TaskError(String),
    /// The task was aborted after the shutdown grace period elapsed.
    #[error("Task was aborted by shutdown signal")]
    TaskAborted,
}
/// Executes a single job's task function inside an instrumented span.
///
/// Resolves the task identifier and handler, spawns the handler on its own
/// tokio task (so a panic surfaces as a `JoinError` instead of unwinding the
/// worker), and races it against the shutdown signal with a 5-second grace
/// period before aborting.
#[tracing::instrument(
    "run_job",
    skip(job, worker, source),
    fields(
        job_id = job.id(),
        messaging.system = "graphile-worker",
        messaging.operation.name = "run_job",
        messaging.destination.name = tracing::field::Empty,
        otel.name = tracing::field::Empty
    )
)]
async fn run_job(job: &Job, worker: &Worker, source: &StreamSource) -> Result<(), RunJobError> {
    // Link this span to the span under which the job was created (propagated
    // through the payload).
    link_to_job_create_span(job.payload().clone());
    let task_id = job.task_id();
    let task_details_guard = worker.task_details.read().await;
    let task_identifier = task_details_guard
        .get(task_id)
        .ok_or_else(|| RunJobError::IdentifierNotFound(*task_id))?
        .clone();
    // Release the read lock before running the (potentially long) task.
    drop(task_details_guard);
    // Fill in the span fields left Empty by the instrument attribute.
    let span = Span::current();
    span.record("otel.name", task_identifier.as_str());
    span.record("messaging.destination.name", task_identifier.as_str());
    let task_fn = worker
        .jobs()
        .get(&task_identifier)
        .ok_or_else(|| RunJobError::FnNotFound(task_identifier.clone()))?;
    debug!(source = ?source, job_id = job.id(), task_identifier, task_id, "Found task");
    let payload = job.payload().to_string();
    let worker_ctx = WorkerContext::builder()
        .payload(job.payload().clone())
        .pg_pool(worker.pg_pool().clone())
        .escaped_schema(worker.escaped_schema().clone())
        .job(job.clone())
        .worker_id(worker.worker_id().clone())
        .extensions(worker.extensions().clone())
        .task_details(worker.task_details().clone())
        .use_local_time(worker.use_local_time)
        .build();
    let task_fut = task_fn(worker_ctx);
    let start = Instant::now();
    // Spawn so a panicking handler is isolated and reported as TaskPanic.
    let job_task = tokio::spawn(task_fut.instrument(span));
    let abort_handle = job_task.abort_handle();
    let mut shutdown_signal = worker.shutdown_signal().clone();
    // Once shutdown fires, give the in-flight task 5 seconds before aborting.
    let shutdown_timeout = async {
        (&mut shutdown_signal).await;
        tokio::time::sleep(Duration::from_secs(5)).await;
    };
    tokio::select! {
        res = job_task => {
            match res {
                Err(e) => Err(RunJobError::TaskPanic(e)),
                Ok(Err(e)) => Err(RunJobError::TaskError(e)),
                Ok(Ok(_)) => Ok(()),
            }
        }
        _ = shutdown_timeout => {
            abort_handle.abort();
            warn!(task_identifier, payload, job_id = job.id(), "Job interrupted by shutdown signal after 5 seconds timeout");
            Err(RunJobError::TaskAborted)
        }
    }?;
    let duration = start.elapsed();
    info!(
        task_identifier,
        payload,
        job_id = job.id(),
        duration = duration.as_millis(),
        "Completed task with success"
    );
    Ok(())
}
/// Error returned when persisting a job's outcome (complete or fail) fails.
#[derive(Error, Debug)]
#[error("Failed to release job '{job_id}'. {source}")]
pub struct ReleaseJobError {
    // Id of the job whose release failed.
    job_id: i64,
    #[source]
    source: GraphileWorkerError,
}
/// Persists the outcome of a job run and emits the matching lifecycle hooks.
///
/// On success the job is completed; on error it is failed (to be retried
/// until `max_attempts` is exhausted). When completion/failure batchers are
/// configured, persistence and hook emission are delegated to the batcher;
/// otherwise the SQL call runs inline and hooks are emitted here.
async fn release_job(
    job_result: Result<(), RunJobError>,
    job: Arc<Job>,
    worker: &Worker,
    duration: Duration,
) -> Result<(), ReleaseJobError> {
    match job_result {
        Ok(_) => {
            if let Some(batcher) = &worker.completion_batcher {
                // Batched path: the batcher persists the completion (and
                // presumably emits JobComplete itself — see batcher module).
                batcher
                    .complete(CompletionRequest {
                        job_id: *job.id(),
                        has_queue: job.job_queue_id().is_some(),
                        job,
                        duration,
                    })
                    .await;
            } else {
                // Inline path: complete in SQL, then emit the hook.
                complete_job(
                    worker.pg_pool(),
                    &job,
                    worker.worker_id(),
                    worker.escaped_schema(),
                )
                .await
                .map_err(|e| ReleaseJobError {
                    job_id: *job.id(),
                    source: e,
                })?;
                worker
                    .hooks
                    .emit(JobCompleteContext {
                        job,
                        worker_id: worker.worker_id().clone(),
                        duration,
                    })
                    .await;
            }
        }
        Err(e) => {
            let error_str = format!("{e:?}");
            // Retry until attempts reach max_attempts; the final failure is
            // permanent.
            let will_retry = job.attempts() < job.max_attempts();
            if let Some(batcher) = &worker.failure_batcher {
                // NOTE(review): this log block duplicates the inline path
                // below; a shared helper would keep the two in sync.
                if !will_retry {
                    error!(
                        error = ?e,
                        task_id = job.task_id(),
                        payload = ?job.payload(),
                        job_id = job.id(),
                        "Job max attempts reached"
                    );
                } else {
                    warn!(
                        error = ?e,
                        task_id = job.task_id(),
                        payload = ?job.payload(),
                        job_id = job.id(),
                        "Failed task"
                    );
                }
                // Batched path: the batcher persists the failure and handles
                // the fail/permanent-fail hooks.
                batcher
                    .fail(FailureRequest {
                        job,
                        error: error_str,
                        will_retry,
                    })
                    .await;
            } else {
                // Inline path: emit the appropriate hook, then fail in SQL.
                if !will_retry {
                    error!(
                        error = ?e,
                        task_id = job.task_id(),
                        payload = ?job.payload(),
                        job_id = job.id(),
                        "Job max attempts reached"
                    );
                    worker
                        .hooks
                        .emit(JobPermanentlyFailContext {
                            job: job.clone(),
                            worker_id: worker.worker_id().clone(),
                            error: error_str.clone(),
                        })
                        .await;
                } else {
                    warn!(
                        error = ?e,
                        task_id = job.task_id(),
                        payload = ?job.payload(),
                        job_id = job.id(),
                        "Failed task"
                    );
                    worker
                        .hooks
                        .emit(JobFailContext {
                            job: job.clone(),
                            worker_id: worker.worker_id().clone(),
                            error: error_str.clone(),
                            will_retry,
                        })
                        .await;
                }
                fail_job(
                    worker.pg_pool(),
                    &job,
                    worker.escaped_schema(),
                    worker.worker_id(),
                    &error_str,
                    None,
                )
                .await
                .map_err(|e| ReleaseJobError {
                    job_id: *job.id(),
                    source: e,
                })?;
            }
        }
    }
    Ok(())
}