use graphile_worker_ctx::WorkerContext;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Type-erased, cheaply clonable task handler.
///
/// It is called with the job's [`WorkerContext`] and returns a boxed `Send` future that
/// resolves to the job's [`TaskHandlerOutcome`]. The `Send + Sync` bounds on the callable
/// allow the same handler to be shared behind the `Arc`.
pub type TaskHandlerFn = Arc<
    dyn Fn(WorkerContext) -> Pin<Box<dyn Future<Output = TaskHandlerOutcome> + Send>>
        + Send
        + Sync,
>;
/// Final result produced by running one job's handler.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskHandlerOutcome {
    /// The handler finished successfully.
    Complete,
    /// The handler failed.
    Failed {
        /// Human-readable description of the failure.
        error: String,
        /// Optional payload to associate with the failed job in place of its original one.
        /// Batch jobs set this to the JSON array of failed items.
        /// NOTE(review): how the worker applies it is not visible here — confirm.
        replacement_payload: Option<Value>,
    },
}

impl TaskHandlerOutcome {
    /// Builds a [`TaskHandlerOutcome::Failed`] carrying `error` and no replacement payload.
    pub fn failed(error: impl Into<String>) -> Self {
        let error = error.into();
        Self::Failed {
            error,
            replacement_payload: None,
        }
    }

    /// Builds a [`TaskHandlerOutcome::Failed`] carrying `error` and `replacement_payload`.
    pub fn failed_with_replacement(
        error: impl Into<String>,
        replacement_payload: impl Into<Value>,
    ) -> Self {
        // Convert in the same order as the field list: error first, then the payload.
        let error = error.into();
        let replacement_payload = Some(replacement_payload.into());
        Self::Failed {
            error,
            replacement_payload,
        }
    }
}
/// Normalizes whatever a [`TaskHandler::run`] future yields into a `Result` with a
/// `Debug`-printable error.
pub trait IntoTaskHandlerResult {
    /// Consumes `self`, producing `Ok(())` for success or `Err(_)` describing the failure.
    fn into_task_handler_result(self) -> Result<(), impl Debug>;
}
/// A task registration: the task identifier paired with its type-erased handler.
///
/// Build one with [`JobDefinition::of`] for a [`TaskHandler`] or with
/// [`JobDefinition::of_batch`] for a [`BatchTaskHandler`].
#[derive(Clone)]
pub struct JobDefinition {
    identifier: &'static str,
    handler: TaskHandlerFn,
}

impl JobDefinition {
    /// Creates a definition for a single-item [`TaskHandler`] `T`.
    ///
    /// The handler deserializes each job payload into `T`, runs it, and maps the result to a
    /// [`TaskHandlerOutcome`].
    pub fn of<T: TaskHandler>() -> Self {
        // `ctx` is received by value, so it is moved straight into the future; no clone needed.
        let handler = |ctx: WorkerContext| {
            Box::pin(run_task_from_worker_ctx_outcome::<T>(ctx))
                as Pin<Box<dyn Future<Output = TaskHandlerOutcome> + Send>>
        };
        Self {
            identifier: T::IDENTIFIER,
            handler: Arc::new(handler),
        }
    }

    /// Creates a definition for a [`BatchTaskHandler`] `T`.
    ///
    /// The handler expects a JSON array payload and runs all of its items in one call.
    pub fn of_batch<T: BatchTaskHandler>() -> Self {
        // Same as `of`: the owned context moves directly into the batch future.
        let handler = |ctx: WorkerContext| {
            Box::pin(run_batch_task_from_worker_ctx::<T>(ctx))
                as Pin<Box<dyn Future<Output = TaskHandlerOutcome> + Send>>
        };
        Self {
            identifier: T::IDENTIFIER,
            handler: Arc::new(handler),
        }
    }

    /// Task identifier this definition was built for.
    pub fn identifier(&self) -> &'static str {
        self.identifier
    }

    /// Returns a shared handle to the handler (a reference-count bump, not a deep copy).
    pub fn handler(&self) -> TaskHandlerFn {
        Arc::clone(&self.handler)
    }

    /// Splits the definition into its identifier and handler.
    pub fn into_parts(self) -> (&'static str, TaskHandlerFn) {
        (self.identifier, self.handler)
    }
}
/// A handler that returns nothing is always considered successful.
impl IntoTaskHandlerResult for () {
    fn into_task_handler_result(self) -> Result<(), impl Debug> {
        Result::<(), ()>::Ok(())
    }
}

/// A `Result<(), D>` already has the required shape and is passed through untouched.
impl<D: Debug> IntoTaskHandlerResult for Result<(), D> {
    fn into_task_handler_result(self) -> Result<(), impl Debug> {
        self
    }
}
/// A serializable job type that handles one job at a time.
///
/// Each job payload is deserialized into `Self`, and then [`TaskHandler::run`] is awaited.
pub trait TaskHandler: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static {
    /// Identifier stored in this handler's [`JobDefinition`].
    const IDENTIFIER: &'static str;

    /// Returns the [`JobDefinition`] that runs this type via [`JobDefinition::of`].
    fn definition() -> JobDefinition
    where
        Self: Sized,
    {
        JobDefinition::of::<Self>()
    }

    /// Executes the job.
    ///
    /// The output may be any [`IntoTaskHandlerResult`], for example `()` or `Result<(), E>`.
    fn run(
        self,
        ctx: WorkerContext,
    ) -> impl Future<Output = impl IntoTaskHandlerResult> + Send + 'static;
}
/// Result reported by a batch handler for one batch job.
#[derive(Debug, Clone, PartialEq)]
pub enum BatchTaskResult<E> {
    /// Every item succeeded.
    Complete,
    /// The batch as a whole failed with a single error.
    FailAll(E),
    /// One result per input item, positionally matching the items passed to the handler.
    ItemResults(Vec<Result<(), E>>),
}

/// Normalizes whatever a [`BatchTaskHandler::run_batch`] future yields into a
/// [`BatchTaskResult`] with a `Debug`-printable error.
pub trait IntoBatchTaskHandlerResult {
    /// Consumes `self`, producing the equivalent [`BatchTaskResult`].
    fn into_batch_task_handler_result(self) -> BatchTaskResult<impl Debug>;
}

/// Returning nothing means the whole batch succeeded.
impl IntoBatchTaskHandlerResult for () {
    fn into_batch_task_handler_result(self) -> BatchTaskResult<impl Debug> {
        BatchTaskResult::Complete::<()>
    }
}

/// `Ok(())` completes the batch; `Err(e)` fails every item with `e`.
impl<D: Debug> IntoBatchTaskHandlerResult for Result<(), D> {
    fn into_batch_task_handler_result(self) -> BatchTaskResult<impl Debug> {
        self.err()
            .map_or(BatchTaskResult::Complete, BatchTaskResult::FailAll)
    }
}

/// A vector of per-item results is reported as item-level results.
impl<D: Debug> IntoBatchTaskHandlerResult for Vec<Result<(), D>> {
    fn into_batch_task_handler_result(self) -> BatchTaskResult<impl Debug> {
        BatchTaskResult::ItemResults(self)
    }
}

/// An explicit [`BatchTaskResult`] is passed through unchanged.
impl<D: Debug> IntoBatchTaskHandlerResult for BatchTaskResult<D> {
    fn into_batch_task_handler_result(self) -> BatchTaskResult<impl Debug> {
        self
    }
}
/// A serializable job type whose jobs carry a JSON array of items that are processed together.
///
/// The payload is deserialized into `Vec<Self>` and handed to
/// [`BatchTaskHandler::run_batch`] in a single call.
pub trait BatchTaskHandler: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static {
    /// Identifier stored in this handler's [`JobDefinition`].
    const IDENTIFIER: &'static str;

    /// Returns the [`JobDefinition`] that runs this type via [`JobDefinition::of_batch`].
    fn definition() -> JobDefinition
    where
        Self: Sized,
    {
        JobDefinition::of_batch::<Self>()
    }

    /// Processes every item of one batch job.
    ///
    /// The output may be any [`IntoBatchTaskHandlerResult`], for example `()`,
    /// `Result<(), E>`, `Vec<Result<(), E>>` (one entry per item, in the same order as
    /// `items`), or a [`BatchTaskResult`].
    fn run_batch(
        items: Vec<Self>,
        ctx: WorkerContext,
    ) -> impl Future<Output = impl IntoBatchTaskHandlerResult> + Send + 'static;
}
/// Runs a single-item job and maps its `Result` onto a [`TaskHandlerOutcome`].
///
/// Success becomes `Complete`. An error string becomes `Failed` with no replacement payload.
async fn run_task_from_worker_ctx_outcome<T: TaskHandler>(
    worker_context: WorkerContext,
) -> TaskHandlerOutcome {
    let run_result = run_task_from_worker_ctx::<T>(worker_context).await;
    run_result.map_or_else(TaskHandlerOutcome::failed, |()| TaskHandlerOutcome::Complete)
}
/// Runs a batch job whose payload must be a JSON array of items that each deserialize into `T`.
///
/// All items are passed to [`BatchTaskHandler::run_batch`] in one call. The handler's result
/// is then converted into a [`TaskHandlerOutcome`]. When per-item results are reported, the
/// raw JSON of the failed items becomes the outcome's replacement payload.
///
/// Returns `Failed` without running the handler when the payload is not a JSON array or
/// cannot be deserialized into `Vec<T>`.
pub async fn run_batch_task_from_worker_ctx<T: BatchTaskHandler>(
    worker_context: WorkerContext,
) -> TaskHandlerOutcome {
    // Keep an owned copy of each item's raw JSON so failed items can be reported back.
    // Only the array elements are cloned. The full payload is read by reference and is
    // not deep-copied first.
    let Some(item_payloads) = worker_context.payload().as_array().cloned() else {
        return TaskHandlerOutcome::failed("batch job payload must be a JSON array");
    };
    let items = match Vec::<T>::deserialize(worker_context.payload()) {
        Ok(items) => items,
        Err(error) => return TaskHandlerOutcome::failed(format!("{error:?}")),
    };
    let item_count = items.len();
    let result = T::run_batch(items, worker_context)
        .await
        .into_batch_task_handler_result();
    batch_result_to_task_outcome(result, item_count, item_payloads)
}
fn batch_result_to_task_outcome<E: Debug>(
result: BatchTaskResult<E>,
item_count: usize,
item_payloads: Vec<Value>,
) -> TaskHandlerOutcome {
match result {
BatchTaskResult::Complete => TaskHandlerOutcome::Complete,
BatchTaskResult::FailAll(error) => TaskHandlerOutcome::failed(format!("{error:?}")),
BatchTaskResult::ItemResults(results) => {
if results.len() != item_count {
return TaskHandlerOutcome::failed(format!(
"batch handler returned {} results for {item_count} payload items",
results.len()
));
}
let mut failed_items = Vec::new();
let mut errors = Vec::new();
for (index, result) in results.into_iter().enumerate() {
let Err(error) = result else {
continue;
};
failed_items.push(item_payloads[index].clone());
errors.push(format!("{index}: {error:?}"));
}
if failed_items.is_empty() {
return TaskHandlerOutcome::Complete;
}
TaskHandlerOutcome::failed_with_replacement(
format!(
"{} batch item(s) failed: {}",
failed_items.len(),
errors.join(", ")
),
Value::Array(failed_items),
)
}
}
}
/// Deserializes the job payload into `T` and runs it.
///
/// # Errors
///
/// Returns the `Debug` text of the deserialization error if the payload does not fit `T`.
/// Returns the `Debug` text of the handler's error if [`TaskHandler::run`] reports a failure.
pub async fn run_task_from_worker_ctx<T: TaskHandler>(
    worker_context: WorkerContext,
) -> Result<(), String> {
    // The payload borrow ends with this statement, so `worker_context` can be moved below.
    let job = T::deserialize(worker_context.payload()).map_err(|e| format!("{e:?}"))?;
    job.run(worker_context)
        .await
        .into_task_handler_result()
        .map_err(|e| format!("{e:?}"))
}