use std::fmt::Debug;
use std::sync::Arc;
use graphile_worker_database::Database;
use graphile_worker_extensions::ReadOnlyExtensions;
use graphile_worker_job::Job;
pub use graphile_worker_task_details::{SharedTaskDetails, TaskDetails};
use serde_json::Value;
/// Everything available while a single job is being run: the job itself,
/// the database handle, the schema name, the worker's id, read-only
/// extensions and the shared task details.
///
/// Construct it with [`WorkerContext::builder`] or
/// [`WorkerContext::from_shared_job`]. Cloning is `derive`d; the job is held
/// behind an `Arc`, so cloning does not copy the job itself.
#[derive(Clone, Debug)]
pub struct WorkerContext {
/// Optional payload override. When `None`, `payload()` returns the job's own payload.
payload: Option<Value>,
/// Database handle (possibly an SQLx pool when the `driver-sqlx` feature is on).
database: Database,
/// Schema name as supplied by the caller; the name suggests it is already
/// quoted/escaped for SQL use — NOTE(review): not verified in this file.
escaped_schema: String,
/// The job being processed, shared so it need not be copied per context.
job: Arc<Job>,
/// Identifier of the worker running the job.
worker_id: String,
/// Read-only, type-keyed extensions; looked up through `get_ext`.
extensions: ReadOnlyExtensions,
/// Task details shared with the caller (type defined in `graphile_worker_task_details`).
task_details: SharedTaskDetails,
/// Flag exposed through `use_local_time()`; the builder defaults it to `false`.
/// Presumably selects local vs. UTC time handling — confirm with consumers.
use_local_time: bool,
}
impl WorkerContext {
pub fn builder() -> WorkerContextBuilder {
WorkerContextBuilder::default()
}
pub fn from_shared_job(
job: Arc<Job>,
database: impl Into<Database>,
escaped_schema: String,
worker_id: String,
extensions: ReadOnlyExtensions,
task_details: SharedTaskDetails,
use_local_time: bool,
) -> Self {
Self {
payload: None,
database: database.into(),
escaped_schema,
job,
worker_id,
extensions,
task_details,
use_local_time,
}
}
pub fn payload(&self) -> &Value {
self.payload.as_ref().unwrap_or_else(|| self.job.payload())
}
pub fn database(&self) -> &Database {
&self.database
}
#[cfg(feature = "driver-sqlx")]
pub fn try_pg_pool(&self) -> Option<&sqlx::PgPool> {
self.database
.downcast_ref::<graphile_worker_database::sqlx::SqlxDatabase>()
.map(|database| database.pool())
}
#[cfg(feature = "driver-sqlx")]
pub fn pg_pool(&self) -> &sqlx::PgPool {
self.try_pg_pool()
.expect("WorkerContext does not use the SQLx database driver")
}
pub fn escaped_schema(&self) -> &str {
&self.escaped_schema
}
pub fn job(&self) -> &Job {
self.job.as_ref()
}
pub fn worker_id(&self) -> &str {
&self.worker_id
}
pub fn extensions(&self) -> &ReadOnlyExtensions {
&self.extensions
}
pub fn task_details(&self) -> &SharedTaskDetails {
&self.task_details
}
pub fn use_local_time(&self) -> bool {
self.use_local_time
}
pub fn get_ext<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.extensions.get()
}
}
/// Step-by-step constructor for [`WorkerContext`].
///
/// Every field starts as `None` (via `Default`). `build()` panics when any of
/// `database`, `escaped_schema`, `job`, `worker_id`, `extensions` or
/// `task_details` is still unset. `payload` may stay unset (the job's payload
/// is used instead), and `use_local_time` falls back to `false`.
#[derive(Clone, Default, Debug)]
pub struct WorkerContextBuilder {
/// Optional payload override.
payload: Option<Value>,
/// Required; set by `database` or `pg_pool`.
database: Option<Database>,
/// Required.
escaped_schema: Option<String>,
/// Required; wrapped in an `Arc` by `build()`.
job: Option<Job>,
/// Required.
worker_id: Option<String>,
/// Required.
extensions: Option<ReadOnlyExtensions>,
/// Required.
task_details: Option<SharedTaskDetails>,
/// Optional; `build()` uses `false` when unset.
use_local_time: Option<bool>,
}
impl WorkerContextBuilder {
    /// Overrides the payload reported by [`WorkerContext::payload`].
    ///
    /// Optional: when left unset, the built context reports the job's own payload.
    #[must_use]
    pub fn payload(mut self, payload: Value) -> Self {
        self.payload = Some(payload);
        self
    }

    /// Sets the database handle. Required unless [`Self::pg_pool`] is used.
    #[must_use]
    pub fn database(mut self, database: impl Into<Database>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Sets the database from an SQLx pool.
    ///
    /// This is a shorthand for [`Self::database`] and overwrites any database
    /// set earlier.
    #[cfg(feature = "driver-sqlx")]
    #[must_use]
    pub fn pg_pool(self, pg_pool: sqlx::PgPool) -> Self {
        self.database(pg_pool)
    }

    /// Sets the schema name (required).
    #[must_use]
    pub fn escaped_schema(mut self, escaped_schema: impl Into<String>) -> Self {
        self.escaped_schema = Some(escaped_schema.into());
        self
    }

    /// Sets the job to run (required). It is moved into an `Arc` by [`Self::build`].
    #[must_use]
    pub fn job(mut self, job: Job) -> Self {
        self.job = Some(job);
        self
    }

    /// Sets the worker identifier (required).
    #[must_use]
    pub fn worker_id(mut self, worker_id: impl Into<String>) -> Self {
        self.worker_id = Some(worker_id.into());
        self
    }

    /// Sets the read-only extensions (required).
    #[must_use]
    pub fn extensions(mut self, extensions: ReadOnlyExtensions) -> Self {
        self.extensions = Some(extensions);
        self
    }

    /// Sets the shared task details (required).
    #[must_use]
    pub fn task_details(mut self, task_details: SharedTaskDetails) -> Self {
        self.task_details = Some(task_details);
        self
    }

    /// Sets the local-time flag. Optional; it defaults to `false`.
    #[must_use]
    pub fn use_local_time(mut self, use_local_time: bool) -> Self {
        self.use_local_time = Some(use_local_time);
        self
    }

    /// Consumes the builder and produces a [`WorkerContext`].
    ///
    /// # Panics
    ///
    /// Panics with `UninitializedField("<name>")` for the first unset required
    /// field. The fields are checked in this order: `database`,
    /// `escaped_schema`, `job`, `worker_id`, `extensions`, `task_details`.
    #[must_use]
    pub fn build(self) -> WorkerContext {
        WorkerContext {
            payload: self.payload,
            database: self.database.unwrap_or_else(|| missing_field("database")),
            escaped_schema: self
                .escaped_schema
                .unwrap_or_else(|| missing_field("escaped_schema")),
            job: Arc::new(self.job.unwrap_or_else(|| missing_field("job"))),
            worker_id: self.worker_id.unwrap_or_else(|| missing_field("worker_id")),
            extensions: self
                .extensions
                .unwrap_or_else(|| missing_field("extensions")),
            task_details: self
                .task_details
                .unwrap_or_else(|| missing_field("task_details")),
            // Unset means "not requested", i.e. `false`.
            use_local_time: self.use_local_time.unwrap_or_default(),
        }
    }
}
/// Aborts a build because a required builder field was never set.
///
/// The function is generic over `T` so it can serve as the fallback of
/// `Option::unwrap_or_else` for a field of any type. It never returns.
///
/// `#[track_caller]` makes the reported panic location the builder line that
/// asked for the missing field, not this helper. `#[cold]` marks it as the
/// unlikely path.
///
/// # Panics
///
/// Always panics, with the message `UninitializedField("<field>")`.
#[cold]
#[track_caller]
fn missing_field<T>(field: &str) -> T {
    panic!("UninitializedField(\"{}\")", field)
}
#[cfg(all(test, feature = "driver-sqlx"))]
mod tests {
    use super::*;
    use graphile_worker_extensions::Extensions;
    use graphile_worker_job::Job;
    use serde_json::json;
    use sqlx::{postgres::PgPoolOptions, PgPool};

    /// Job fixture: id 1, task `test_task`, payload `{"test": "data"}`.
    fn create_test_job() -> Job {
        Job::builder()
            .id(1)
            .payload(json!({"test": "data"}))
            .task_identifier("test_task".to_string())
            .build()
    }

    /// Pool created with `connect_lazy`; none of these tests run a query on it.
    fn create_test_pool() -> PgPool {
        let url = "postgres://test:test@localhost/test";
        PgPoolOptions::new()
            .connect_lazy(url)
            .expect("Failed to create lazy pool")
    }

    /// Empty read-only extension set.
    fn create_extensions() -> ReadOnlyExtensions {
        ReadOnlyExtensions::new(Extensions::default())
    }

    /// Builder with every required field filled in. `payload` and
    /// `use_local_time` are left for each test to decide.
    fn builder_with_required(job: Job, schema: &str, worker_id: &str) -> WorkerContextBuilder {
        WorkerContext::builder()
            .pg_pool(create_test_pool())
            .escaped_schema(schema)
            .job(job)
            .worker_id(worker_id)
            .extensions(create_extensions())
            .task_details(SharedTaskDetails::default())
    }

    /// Marker extension used to check typed extension lookup.
    #[derive(Clone, Debug)]
    struct TestExtension {
        value: &'static str,
    }

    #[tokio::test]
    async fn test_worker_context_builder() {
        let ctx = builder_with_required(create_test_job(), "graphile_worker", "worker-1")
            .payload(json!({"key": "value"}))
            .use_local_time(true)
            .build();

        assert_eq!(ctx.payload(), &json!({"key": "value"}));
        assert_eq!(ctx.escaped_schema(), "graphile_worker");
        assert_eq!(ctx.worker_id(), "worker-1");
        assert!(ctx.use_local_time());
    }

    #[tokio::test]
    async fn test_worker_context_builder_use_local_time_default() {
        let ctx = builder_with_required(create_test_job(), "schema", "worker")
            .payload(json!({}))
            .build();

        assert!(!ctx.use_local_time());
    }

    #[tokio::test]
    async fn test_worker_context_from_shared_job_uses_job_payload() {
        let job = Arc::new(create_test_job());
        let extensions = {
            let mut ext = Extensions::default();
            ext.insert(TestExtension { value: "present" });
            ReadOnlyExtensions::new(ext)
        };

        let ctx = WorkerContext::from_shared_job(
            Arc::clone(&job),
            create_test_pool(),
            "graphile_worker".to_string(),
            "worker-1".to_string(),
            extensions,
            SharedTaskDetails::default(),
            true,
        );

        assert_eq!(ctx.payload(), job.payload());
        assert_eq!(ctx.job().id(), job.id());
        assert_eq!(ctx.escaped_schema(), "graphile_worker");
        assert_eq!(ctx.worker_id(), "worker-1");
        assert!(ctx.use_local_time());
        assert_eq!(
            ctx.extensions().get::<TestExtension>().unwrap().value,
            "present"
        );
        assert_eq!(ctx.get_ext::<TestExtension>().unwrap().value, "present");
    }

    #[tokio::test]
    async fn test_worker_context_builder_uses_job_payload_when_payload_missing() {
        let job = create_test_job();
        let expected_payload = job.payload().clone();

        let ctx = builder_with_required(job, "schema", "worker").build();

        assert_eq!(ctx.payload(), &expected_payload);
    }
}