use crate::batcher::{CompletionBatcher, FailureBatcher};
use crate::local_queue::LocalQueueConfig;
use crate::runner::WorkerFn;
use crate::sql::task_identifiers::{get_tasks_details, SharedTaskDetails};
use crate::utils::escape_identifier;
use crate::Worker;
use futures::FutureExt;
use graphile_worker_crontab_parser::{parse_crontab, CrontabParseError};
use graphile_worker_crontab_types::Crontab;
use graphile_worker_ctx::WorkerContext;
use graphile_worker_extensions::Extensions;
use graphile_worker_lifecycle_hooks::{Event, HookRegistry, Plugin};
use graphile_worker_migrations::migrate;
use graphile_worker_shutdown_signal::{shutdown_signal, ShutdownSignal};
use graphile_worker_task_handler::{run_task_from_worker_ctx, TaskHandler};
use rand::RngCore;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::Notify;
fn manual_shutdown_signal_pair() -> (ShutdownSignal, Arc<Notify>) {
let notify = Arc::new(Notify::new());
let notify_for_signal = notify.clone();
let signal = async move {
notify_for_signal.notified().await;
}
.boxed()
.shared();
(signal, notify)
}
/// Merges two shutdown signals into one.
///
/// The returned signal completes as soon as either `left` or `right`
/// completes; the other one is simply dropped.
fn combine_shutdown_signals(left: ShutdownSignal, right: ShutdownSignal) -> ShutdownSignal {
    let first_to_fire = async move {
        tokio::select! {
            _ = left => {},
            _ = right => {},
        }
    };

    first_to_fire.boxed().shared()
}
/// Builder-style configuration consumed by `WorkerOptions::init` to create a `Worker`.
///
/// All fields start from their `Default` value; unset optional values are
/// resolved to concrete defaults when `init` runs.
#[derive(Default)]
pub struct WorkerOptions {
/// Worker concurrency. `init` falls back to `num_cpus::get()` when unset;
/// the `concurrency` setter rejects `0`.
concurrency: Option<usize>,
/// Poll interval handed to the worker. `init` defaults it to 1000 ms.
poll_interval: Option<Duration>,
/// Task handlers keyed by task identifier, filled in by `define_job`.
jobs: HashMap<String, WorkerFn>,
/// Ready-made connection pool. When set, `init` uses it and ignores `database_url`.
pg_pool: Option<PgPool>,
/// Connection URL used by `init` to open a pool when `pg_pool` is unset.
database_url: Option<String>,
/// Maximum number of connections for the pool opened from `database_url`
/// (defaults to 20 in `init`). Not applied to a pool supplied via `pg_pool`.
max_pg_conn: Option<u32>,
/// Database schema name. `init` defaults it to `graphile_worker`.
schema: Option<String>,
/// Flags collected through `add_forbidden_flag`. When this list is non-empty,
/// `init` discards `local_queue_config`.
// NOTE(review): presumably jobs carrying one of these flags are not picked up — confirm in the runner.
forbidden_flags: Vec<String>,
/// Crontab entries accumulated by `with_crontab`; `init` uses an empty list when unset.
crontabs: Option<Vec<Crontab>>,
/// Forwarded unchanged to the worker.
// NOTE(review): presumably switches cron scheduling from UTC to local time — confirm.
use_local_time: bool,
/// Values registered through `add_extension`.
extensions: Extensions,
/// Lifecycle hook registry populated by `on` and `add_plugin`.
hooks: HookRegistry,
/// Whether the worker's shutdown signal also includes `shutdown_signal()`
/// in addition to the manual one. `init` treats unset as `true`.
listen_os_shutdown_signals: Option<bool>,
/// Local queue configuration set by `local_queue`; only kept by `init` when
/// no forbidden flags are configured.
local_queue_config: Option<LocalQueueConfig>,
/// When set, `init` creates a `CompletionBatcher` with this delay.
complete_job_batch_delay: Option<Duration>,
/// When set, `init` creates a `FailureBatcher` with this delay.
fail_job_batch_delay: Option<Duration>,
}
/// Errors returned by `WorkerOptions::init` while building a worker.
#[derive(Error, Debug)]
pub enum WorkerBuildError {
/// Wraps a `sqlx::Error` (converted automatically via `#[from]`), e.g. when
/// `init` fails to connect the pool it creates from `database_url`.
#[error("Error occurred while connecting to the PostgreSQL database: {0}")]
ConnectError(#[from] sqlx::Error),
/// Wraps a `GraphileWorkerError` (converted automatically via `#[from]`).
// NOTE(review): presumably raised by the schema-escaping / task-details queries in `init` — confirm.
#[error("Error occurred while executing a query: {0}")]
QueryError(#[from] crate::errors::GraphileWorkerError),
/// Neither `pg_pool` nor `database_url` was configured, so no pool can be obtained.
#[error("Missing database_url configuration - must provide either database_url or pg_pool")]
MissingDatabaseUrl,
/// Wraps a `MigrateError` (converted automatically via `#[from]`).
// NOTE(review): presumably raised by the `migrate` call in `init` — confirm.
#[error("Error occurred while migrating the database schema: {0}")]
MigrationError(#[from] graphile_worker_migrations::MigrateError),
}
impl WorkerOptions {
    /// Consumes the options and builds a [`Worker`].
    ///
    /// Steps, in order:
    /// 1. Use the supplied pool, or connect one from `database_url`
    ///    (at most `max_pg_conn` connections, 20 by default).
    /// 2. Escape the schema name (default `graphile_worker`) and run the migrations.
    /// 3. Load the task details for every identifier registered with `define_job`.
    /// 4. Generate a worker id of the form `graphile_worker_<18 hex chars>`.
    /// 5. Build the shutdown signal: always manually triggerable, and also
    ///    reacting to `shutdown_signal()` unless that was disabled.
    /// 6. Create the completion / failure batchers when their delays were configured.
    ///
    /// Defaults applied here: poll interval 1000 ms, concurrency `num_cpus::get()`.
    ///
    /// # Errors
    ///
    /// * [`WorkerBuildError::MissingDatabaseUrl`] when neither `pg_pool` nor
    ///   `database_url` was provided.
    /// * Any other [`WorkerBuildError`] variant when connecting, escaping the
    ///   schema, migrating or loading the task details fails (converted by `?`).
    pub async fn init(self) -> Result<Worker, WorkerBuildError> {
        let listen_os_shutdown_signals = self.listen_os_shutdown_signals.unwrap_or(true);

        // An explicitly supplied pool wins; `database_url` is only consulted without one.
        let pg_pool = match self.pg_pool {
            Some(pg_pool) => pg_pool,
            None => {
                let db_url = self
                    .database_url
                    .ok_or(WorkerBuildError::MissingDatabaseUrl)?;

                PgPoolOptions::new()
                    .max_connections(self.max_pg_conn.unwrap_or(20))
                    .connect(&db_url)
                    .await?
            }
        };

        let schema = self
            .schema
            .unwrap_or_else(|| String::from("graphile_worker"));
        let escaped_schema = escape_identifier(&pg_pool, &schema).await?;

        migrate(&pg_pool, &escaped_schema).await?;

        let task_details: SharedTaskDetails = get_tasks_details(
            &pg_pool,
            &escaped_schema,
            self.jobs.keys().cloned().collect(),
        )
        .await?
        .into();

        // 9 random bytes -> 18 hex characters appended to the worker id prefix.
        let mut random_bytes = [0u8; 9];
        rand::rng().fill_bytes(&mut random_bytes);

        let (manual_signal, shutdown_notifier) = manual_shutdown_signal_pair();
        let shutdown_signal = if listen_os_shutdown_signals {
            combine_shutdown_signals(manual_signal, shutdown_signal())
        } else {
            manual_signal
        };

        let worker_id = format!("graphile_worker_{}", hex::encode(random_bytes));
        let poll_interval = self.poll_interval.unwrap_or(Duration::from_millis(1000));
        let hooks = Arc::new(self.hooks);
        let concurrency = self.concurrency.unwrap_or_else(num_cpus::get);

        // The local queue is only kept when no forbidden flags are configured.
        // NOTE(review): presumably the local queue cannot honour forbidden flags — confirm.
        let local_queue_config = if self.forbidden_flags.is_empty() {
            self.local_queue_config
        } else {
            None
        };

        // Batchers are opt-in: each exists only if its delay was configured.
        let completion_batcher = self.complete_job_batch_delay.map(|delay| {
            CompletionBatcher::new(
                delay,
                pg_pool.clone(),
                escaped_schema.clone(),
                worker_id.clone(),
                hooks.clone(),
                shutdown_signal.clone(),
            )
        });
        let failure_batcher = self.fail_job_batch_delay.map(|delay| {
            FailureBatcher::new(
                delay,
                pg_pool.clone(),
                escaped_schema.clone(),
                worker_id.clone(),
                hooks.clone(),
                shutdown_signal.clone(),
            )
        });

        Ok(Worker {
            worker_id,
            concurrency,
            poll_interval,
            jobs: self.jobs,
            pg_pool,
            escaped_schema,
            task_details,
            forbidden_flags: self.forbidden_flags,
            crontabs: self.crontabs.unwrap_or_default(),
            use_local_time: self.use_local_time,
            shutdown_signal,
            shutdown_notifier,
            extensions: self.extensions.into(),
            hooks,
            local_queue_config,
            completion_batcher,
            failure_batcher,
        })
    }

    /// Sets the database schema name (`init` uses `graphile_worker` when unset).
    pub fn schema(mut self, value: &str) -> Self {
        self.schema = Some(value.into());
        self
    }

    /// Sets the worker concurrency (`init` uses `num_cpus::get()` when unset).
    ///
    /// # Panics
    ///
    /// Panics if `value` is `0`.
    pub fn concurrency(mut self, value: usize) -> Self {
        assert!(value > 0, "Concurrency must be greater than 0");
        self.concurrency = Some(value);
        self
    }

    /// Sets the poll interval (`init` uses 1000 ms when unset).
    pub fn poll_interval(mut self, value: Duration) -> Self {
        self.poll_interval = Some(value);
        self
    }

    /// Supplies an existing connection pool; `init` then ignores `database_url`
    /// and `max_pg_conn`.
    pub fn pg_pool(mut self, value: PgPool) -> Self {
        self.pg_pool = Some(value);
        self
    }

    /// Sets the connection URL used by `init` to open a pool when no pool was supplied.
    pub fn database_url(mut self, value: &str) -> Self {
        self.database_url = Some(value.into());
        self
    }

    /// Sets the maximum number of connections of the pool opened from
    /// `database_url` (`init` uses 20 when unset).
    pub fn max_pg_conn(mut self, value: u32) -> Self {
        self.max_pg_conn = Some(value);
        self
    }

    /// Registers the task handler `T` under `T::IDENTIFIER`.
    ///
    /// Registering another handler with the same identifier replaces the
    /// previous one (map insertion semantics).
    pub fn define_job<T: TaskHandler>(mut self) -> Self {
        // The closure owns `ctx`, so it is handed to the runner as-is; no clone is needed.
        let worker_fn = move |ctx: WorkerContext| run_task_from_worker_ctx::<T>(ctx).boxed();

        self.jobs
            .insert(T::IDENTIFIER.to_string(), Box::new(worker_fn));
        self
    }

    /// Adds `flag` to the list of forbidden flags.
    ///
    /// Note that `init` discards any local queue configuration once at least
    /// one forbidden flag is present.
    pub fn add_forbidden_flag(mut self, flag: &str) -> Self {
        self.forbidden_flags.push(flag.into());
        self
    }

    /// Parses `input` as a crontab and appends its entries to those already configured.
    ///
    /// # Errors
    ///
    /// Returns the [`CrontabParseError`] reported by `parse_crontab`; in that
    /// case the builder is consumed and not returned.
    pub fn with_crontab(mut self, input: &str) -> Result<Self, CrontabParseError> {
        let parsed = parse_crontab(input)?;
        self.crontabs.get_or_insert_with(Vec::new).extend(parsed);
        Ok(self)
    }

    /// Sets the `use_local_time` flag that `init` forwards to the worker.
    pub fn use_local_time(mut self, value: bool) -> Self {
        self.use_local_time = value;
        self
    }

    /// Chooses whether the worker's shutdown signal also reacts to
    /// `shutdown_signal()` (enabled when unset). The manual shutdown signal is
    /// always available.
    pub fn listen_os_shutdown_signals(mut self, value: bool) -> Self {
        self.listen_os_shutdown_signals = Some(value);
        self
    }

    /// Inserts `value` into the extensions handed to the worker.
    pub fn add_extension<T: Clone + Send + Sync + Debug + 'static>(mut self, value: T) -> Self {
        self.extensions.insert(value);
        self
    }

    /// Registers `handler` for `event` in the lifecycle hook registry.
    pub fn on<E, F, Fut>(mut self, event: E, handler: F) -> Self
    where
        E: Event,
        F: Fn(E::Context) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = E::Output> + Send + 'static,
    {
        self.hooks.on(event, handler);
        self
    }

    /// Lets `plugin` register its hooks in the lifecycle hook registry.
    pub fn add_plugin<P: Plugin>(mut self, plugin: P) -> Self {
        plugin.register(&mut self.hooks);
        self
    }

    /// Sets the local queue configuration (ignored by `init` when forbidden
    /// flags are configured).
    pub fn local_queue(mut self, config: LocalQueueConfig) -> Self {
        self.local_queue_config = Some(config);
        self
    }

    /// Enables the completion batcher, created by `init` with the given delay.
    pub fn complete_job_batch_delay(mut self, delay: Duration) -> Self {
        self.complete_job_batch_delay = Some(delay);
        self
    }

    /// Enables the failure batcher, created by `init` with the given delay.
    pub fn fail_job_batch_delay(mut self, delay: Duration) -> Self {
        self.fail_job_batch_delay = Some(delay);
        self
    }
}