use crate::batcher::{CompletionBatcher, FailureBatcher};
use crate::cron::CronBuilder;
use crate::local_queue::LocalQueueConfig;
use crate::runner::WorkerFn;
use crate::sql::task_identifiers::{get_tasks_details, SharedTaskDetails};
use crate::utils::escape_identifier;
use crate::Worker;
use futures::FutureExt;
use graphile_worker_crontab_parser::{parse_crontab, CrontabParseError};
use graphile_worker_crontab_types::Crontab;
use graphile_worker_database::{Database, DbError};
use graphile_worker_extensions::Extensions;
use graphile_worker_lifecycle_hooks::{Event, HookRegistry, Plugin};
use graphile_worker_migrations::migrate;
use graphile_worker_runtime::Notify;
use graphile_worker_shutdown_signal::{shutdown_signal, ShutdownSignal};
use graphile_worker_task_handler::{BatchTaskHandler, JobDefinition, TaskHandler};
use rand::Rng;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
/// Something that can be attached to [`WorkerOptions`] as one or more crontab entries.
///
/// Implemented in this module for ready-made [`Crontab`]s, for [`CronBuilder`]s, and for
/// crontab text (`&str`, `String`, `&String`). Text inputs yield a `Result` because the
/// text has to be parsed first.
pub trait CronInput {
    /// What appending produces: `WorkerOptions` for inputs that cannot fail, or
    /// `Result<WorkerOptions, CrontabParseError>` for inputs that must be parsed.
    type Output;

    /// Consumes `self` and appends its crontab entries to `options`.
    fn append_to(self, options: WorkerOptions) -> Self::Output;
}
impl CronInput for Crontab {
    type Output = WorkerOptions;

    /// Appends this single, already-parsed crontab entry. This cannot fail.
    fn append_to(self, options: WorkerOptions) -> Self::Output {
        options.with_crons(std::iter::once(self))
    }
}
impl<T: TaskHandler> CronInput for CronBuilder<T> {
type Output = WorkerOptions;
fn append_to(self, options: WorkerOptions) -> Self::Output {
self.build().append_to(options)
}
}
impl CronInput for &str {
    type Output = Result<WorkerOptions, CrontabParseError>;

    /// Parses `self` as crontab text and appends every entry it contains.
    ///
    /// If parsing fails, `options` is dropped and the parse error is returned.
    fn append_to(self, mut options: WorkerOptions) -> Self::Output {
        match parse_crontab(self) {
            Ok(parsed) => {
                options.append_crontabs(parsed);
                Ok(options)
            }
            Err(error) => Err(error.into()),
        }
    }
}
impl CronInput for String {
    type Output = Result<WorkerOptions, CrontabParseError>;

    /// Parses the owned crontab text by delegating to the `&str` implementation.
    fn append_to(self, options: WorkerOptions) -> Self::Output {
        <&str as CronInput>::append_to(&self, options)
    }
}
impl CronInput for &String {
    type Output = Result<WorkerOptions, CrontabParseError>;

    /// Parses the borrowed crontab text by delegating to the `&str` implementation.
    fn append_to(self, options: WorkerOptions) -> Self::Output {
        let text: &str = self;
        text.append_to(options)
    }
}
fn manual_shutdown_signal_pair() -> (ShutdownSignal, Arc<Notify>) {
let notify = Arc::new(Notify::new());
let notify_for_signal = notify.clone();
let signal = async move {
notify_for_signal.notified().await;
}
.boxed()
.shared();
(signal, notify)
}
/// Returns a shared signal that resolves as soon as either `left` or `right` resolves.
///
/// `left` is polled before `right` on every wake-up, so it wins when both are ready at
/// the same time. The future that did not finish is dropped.
fn combine_shutdown_signals(left: ShutdownSignal, right: ShutdownSignal) -> ShutdownSignal {
    let first_of_both = async move {
        // Pin on the stack so `select` accepts the futures whether or not they are `Unpin`.
        futures::pin_mut!(left, right);
        // `select` polls its first argument before its second (left-biased).
        let _ = futures::future::select(left, right).await;
    };
    first_of_both.boxed().shared()
}
/// Opens a `sqlx` Postgres pool capped at `max_connections` and wraps it as a [`Database`].
#[cfg(feature = "driver-sqlx")]
async fn connect_default_database(db_url: &str, max_connections: u32) -> Result<Database, DbError> {
    sqlx::postgres::PgPoolOptions::new()
        .max_connections(max_connections)
        .connect(db_url)
        .await
        .map(Into::into)
        .map_err(DbError::from)
}
/// Builds a `tokio-postgres`-backed [`Database`] from `db_url`.
///
/// `max_connections` is passed through (as `usize`) as the second `from_url` argument,
/// which is presumably the pool size. Errors are converted into [`DbError`].
#[cfg(all(not(feature = "driver-sqlx"), feature = "driver-tokio-postgres"))]
async fn connect_default_database(db_url: &str, max_connections: u32) -> Result<Database, DbError> {
    use graphile_worker_database::tokio_postgres::TokioPostgresDatabase;

    let pool_size = max_connections as usize;
    TokioPostgresDatabase::from_url(db_url, pool_size)
        .map(Into::into)
        .map_err(Into::into)
}
/// Fallback used when no database driver feature is enabled.
///
/// Connecting from a URL is impossible in this configuration, so this always returns an
/// error.
#[cfg(not(any(feature = "driver-sqlx", feature = "driver-tokio-postgres")))]
async fn connect_default_database(
    _db_url: &str,
    _max_connections: u32,
) -> Result<Database, DbError> {
    const NO_DRIVER_MESSAGE: &str = "database_url requires enabling a database driver feature";
    Err(DbError::new(NO_DRIVER_MESSAGE))
}
/// Builder-style configuration for a [`Worker`]. Call `init()` to build the worker.
#[derive(Default)]
pub struct WorkerOptions {
    /// Worker concurrency; `init` falls back to `num_cpus::get()` when unset.
    concurrency: Option<usize>,
    /// Poll interval; `init` falls back to 1000 ms when unset.
    poll_interval: Option<Duration>,
    /// Job handlers keyed by task identifier.
    jobs: HashMap<String, WorkerFn>,
    /// Explicit database handle; takes precedence over `database_url`.
    database: Option<Database>,
    /// Connection URL, used only when `database` is not set.
    database_url: Option<String>,
    /// Connection limit for a URL-created database; `init` defaults it to 20.
    max_pg_conn: Option<u32>,
    /// Unescaped schema name; `init` defaults it to `graphile_worker`.
    schema: Option<String>,
    /// Forbidden flags handed to the worker; a non-empty list disables the local queue.
    forbidden_flags: Vec<String>,
    /// Accumulated crontab entries; `None` until the first entry is appended.
    crontabs: Option<Vec<Crontab>>,
    /// Passed through unchanged to the worker.
    use_local_time: bool,
    /// User-supplied extensions, converted and handed to the worker.
    extensions: Extensions,
    /// Lifecycle hook registrations from `on` and `add_plugin`.
    hooks: HookRegistry,
    /// Whether OS shutdown signals are observed; `init` treats `None` as `true`.
    listen_os_shutdown_signals: Option<bool>,
    /// Local queue configuration; ignored by `init` when forbidden flags are set.
    local_queue_config: Option<LocalQueueConfig>,
    /// When set, `init` creates a `CompletionBatcher` with this delay.
    complete_job_batch_delay: Option<Duration>,
    /// When set, `init` creates a `FailureBatcher` with this delay.
    fail_job_batch_delay: Option<Duration>,
}
#[derive(Error, Debug)]
pub enum WorkerBuildError {
#[error("Error occurred while connecting to the PostgreSQL database: {0}")]
ConnectError(#[from] DbError),
#[error("Error occurred while executing a query: {0}")]
QueryError(#[from] crate::errors::GraphileWorkerError),
#[error("Missing database configuration - must provide either database_url or database")]
MissingDatabaseUrl,
#[error("Error occurred while migrating the database schema: {0}")]
MigrationError(#[from] graphile_worker_migrations::MigrateError),
}
impl WorkerOptions {
fn append_crontabs(&mut self, mut crontabs: Vec<Crontab>) {
match self.crontabs.as_mut() {
Some(existing) => existing.append(&mut crontabs),
None => {
self.crontabs = Some(crontabs);
}
}
}
pub async fn init(self) -> Result<Worker, WorkerBuildError> {
let listen_os_shutdown_signals = self.listen_os_shutdown_signals.unwrap_or(true);
let database = match self.database {
Some(database) => database,
None => {
let db_url = self
.database_url
.ok_or(WorkerBuildError::MissingDatabaseUrl)?;
connect_default_database(&db_url, self.max_pg_conn.unwrap_or(20)).await?
}
};
let schema = self
.schema
.unwrap_or_else(|| String::from("graphile_worker"));
let escaped_schema = escape_identifier(&database, &schema).await?;
migrate(&database, &escaped_schema).await?;
let task_details: SharedTaskDetails = get_tasks_details(
&database,
&escaped_schema,
self.jobs.keys().cloned().collect(),
)
.await?
.into();
let mut random_bytes = [0u8; 9];
rand::rng().fill_bytes(&mut random_bytes);
let (manual_signal, shutdown_notifier) = manual_shutdown_signal_pair();
let shutdown_signal = if listen_os_shutdown_signals {
combine_shutdown_signals(manual_signal, shutdown_signal())
} else {
manual_signal
};
let worker_id = format!("graphile_worker_{}", hex::encode(random_bytes));
let poll_interval = self.poll_interval.unwrap_or(Duration::from_millis(1000));
let hooks = Arc::new(self.hooks);
let concurrency = self.concurrency.unwrap_or_else(num_cpus::get);
let local_queue_config = if self.forbidden_flags.is_empty() {
self.local_queue_config
} else {
None
};
let completion_batcher = self.complete_job_batch_delay.map(|delay| {
Arc::new(CompletionBatcher::new(
delay,
database.clone(),
escaped_schema.clone(),
worker_id.clone(),
hooks.clone(),
shutdown_signal.clone(),
))
});
let failure_batcher = self.fail_job_batch_delay.map(|delay| {
Arc::new(FailureBatcher::new(
delay,
database.clone(),
escaped_schema.clone(),
worker_id.clone(),
hooks.clone(),
shutdown_signal.clone(),
))
});
let worker = Worker {
worker_id,
concurrency,
poll_interval,
jobs: self.jobs,
database,
escaped_schema,
task_details,
forbidden_flags: self.forbidden_flags,
crontabs: self.crontabs.unwrap_or_default(),
use_local_time: self.use_local_time,
shutdown_signal,
shutdown_notifier,
extensions: self.extensions.into(),
hooks,
local_queue_config,
completion_batcher,
failure_batcher,
};
Ok(worker)
}
pub fn schema(mut self, value: &str) -> Self {
self.schema = Some(value.into());
self
}
pub fn concurrency(mut self, value: usize) -> Self {
assert!(value > 0, "Concurrency must be greater than 0");
self.concurrency = Some(value);
self
}
pub fn poll_interval(mut self, value: Duration) -> Self {
self.poll_interval = Some(value);
self
}
pub fn database(mut self, value: impl Into<Database>) -> Self {
self.database = Some(value.into());
self
}
#[cfg(feature = "driver-sqlx")]
pub fn pg_pool(mut self, value: sqlx::PgPool) -> Self {
self.database = Some(value.into());
self
}
pub fn database_url(mut self, value: &str) -> Self {
self.database_url = Some(value.into());
self
}
pub fn max_pg_conn(mut self, value: u32) -> Self {
self.max_pg_conn = Some(value);
self
}
pub fn define_job<T: TaskHandler>(self) -> Self {
self.define_jobs([T::definition()])
}
pub fn define_batch_job<T: BatchTaskHandler>(self) -> Self {
self.define_jobs([T::definition()])
}
pub fn define_jobs<I>(mut self, jobs: I) -> Self
where
I: IntoIterator<Item = JobDefinition>,
{
for job in jobs {
let (identifier, worker_fn) = job.into_parts();
self.jobs.insert(identifier.to_string(), worker_fn);
}
self
}
pub fn add_forbidden_flag(mut self, flag: &str) -> Self {
self.forbidden_flags.push(flag.into());
self
}
pub fn with_cron<C: CronInput>(self, cron: C) -> C::Output {
cron.append_to(self)
}
pub fn with_crons<I, C>(mut self, crontabs: I) -> Self
where
I: IntoIterator<Item = C>,
C: Into<Crontab>,
{
self.append_crontabs(crontabs.into_iter().map(Into::into).collect());
self
}
#[deprecated(note = "use WorkerOptions::with_cron(...) instead")]
pub fn with_crontab(self, input: &str) -> Result<Self, CrontabParseError> {
self.with_cron(input)
}
pub fn use_local_time(mut self, value: bool) -> Self {
self.use_local_time = value;
self
}
pub fn listen_os_shutdown_signals(mut self, value: bool) -> Self {
self.listen_os_shutdown_signals = Some(value);
self
}
pub fn add_extension<T: Clone + Send + Sync + Debug + 'static>(mut self, value: T) -> Self {
self.extensions.insert(value);
self
}
pub fn on<E, F, Fut>(mut self, event: E, handler: F) -> Self
where
E: Event,
F: Fn(E::Context) -> Fut + Send + Sync + Clone + 'static,
Fut: std::future::Future<Output = E::Output> + Send + 'static,
{
self.hooks.on(event, handler);
self
}
pub fn add_plugin<P: Plugin>(mut self, plugin: P) -> Self {
plugin.register(&mut self.hooks);
self
}
pub fn local_queue(mut self, config: LocalQueueConfig) -> Self {
self.local_queue_config = Some(config);
self
}
pub fn complete_job_batch_delay(mut self, delay: Duration) -> Self {
self.complete_job_batch_delay = Some(delay);
self
}
pub fn fail_job_batch_delay(mut self, delay: Duration) -> Self {
self.fail_job_batch_delay = Some(delay);
self
}
}