#![warn(clippy::pedantic, missing_docs)]
#![allow(clippy::module_name_repetitions)]
use std::{any::type_name, collections::HashMap, marker::PhantomData, sync::Arc};
use async_trait::async_trait;
use eyre::Context;
use ora_common::{
task::{TaskDataFormat, TaskDefinition, WorkerSelector},
timeout::TimeoutPolicy,
UnixNanos,
};
use ora_worker::{RawHandler, TaskContext};
use serde::{de::DeserializeOwned, Serialize};
pub mod client;
/// A unit of work that can be turned into a [`TaskDefinition`] and executed by a [`Handler`].
///
/// Values are serialized with [`Task::format`] when a definition is built via
/// [`Task::task`]; a successful run produces a [`Task::Output`].
pub trait Task: Serialize + DeserializeOwned + Send + Sync + 'static {
    /// The value produced by a successful run of this task.
    type Output: Serialize + DeserializeOwned + Send + Sync + 'static;

    /// Returns the selector used to route this task to workers.
    ///
    /// The default is the last `::`-separated segment of the type name, with any
    /// generic arguments stripped (e.g. `my_crate::tasks::Foo<u32>` becomes `Foo`).
    #[must_use]
    fn worker_selector() -> WorkerSelector {
        let name = type_name::<Self>();
        // Strip generic arguments first so any `::` inside them does not affect the split below.
        let name = name.split_once('<').map_or(name, |(head, _)| head);
        // `rsplit` always yields at least one item, so the fallback is never actually used.
        let name = name.rsplit("::").next().unwrap_or(name);
        WorkerSelector {
            kind: std::borrow::Cow::Borrowed(name),
        }
    }

    /// Returns the serialization format for task data and output. Defaults to JSON.
    #[must_use]
    fn format() -> TaskDataFormat {
        TaskDataFormat::Json
    }

    /// Returns the timeout policy for this task. Defaults to [`TimeoutPolicy::Never`].
    #[must_use]
    fn timeout() -> TimeoutPolicy {
        TimeoutPolicy::Never
    }

    /// Builds a [`TaskDefinition`] carrying `self` serialized with [`Task::format`].
    ///
    /// The definition has `target` set to `UnixNanos(0)`, no labels, and uses
    /// [`Task::worker_selector`] and [`Task::timeout`].
    ///
    /// # Panics
    ///
    /// Panics if [`Task::format`] returns [`TaskDataFormat::Unknown`], or if
    /// serializing `self` in the chosen format fails.
    fn task(&self) -> TaskDefinition<Self>
    where
        Self: Serialize,
    {
        // Read the format once so the encoding used and the recorded `data_format` always agree.
        let data_format = Self::format();
        let data = match &data_format {
            TaskDataFormat::Unknown => panic!("invalid data format"),
            TaskDataFormat::MessagePack => rmp_serde::to_vec_named(self)
                .expect("failed to serialize task as MessagePack"),
            TaskDataFormat::Json => {
                serde_json::to_vec(self).expect("failed to serialize task as JSON")
            }
        };
        TaskDefinition {
            target: UnixNanos(0),
            worker_selector: Self::worker_selector(),
            data,
            data_format,
            labels: HashMap::default(),
            timeout: Self::timeout(),
            _task_type: PhantomData,
        }
    }
}
#[async_trait]
pub trait Handler<T>
where
Self: Sized + Send + Sync + 'static,
T: Task,
{
async fn run(&self, ctx: TaskContext, task: T) -> eyre::Result<T::Output>;
#[doc(hidden)]
fn raw_handler(self) -> Arc<dyn RawHandler + Send + Sync> {
self.raw_handler_with_selector(T::worker_selector())
}
#[doc(hidden)]
fn raw_handler_with_selector(
self,
selector: WorkerSelector,
) -> Arc<dyn RawHandler + Send + Sync> {
Arc::new(WorkerAdapter {
selector,
worker: self,
_task: PhantomData,
})
}
}
pub trait IntoHandler {
fn handler<T>(self) -> Arc<dyn RawHandler + Send + Sync>
where
Self: Handler<T>,
T: Task,
{
<Self as Handler<T>>::raw_handler(self)
}
fn handler_with_selector<T>(self, selector: WorkerSelector) -> Arc<dyn RawHandler + Send + Sync>
where
Self: Handler<T>,
T: Task,
{
<Self as Handler<T>>::raw_handler_with_selector(self, selector)
}
}
impl<W> IntoHandler for W where W: Sized + Send + Sync + 'static {}
/// Adapts a typed [`Handler`] for task type `T` into a [`RawHandler`].
struct WorkerAdapter<W: Handler<T>, T: Task> {
    /// The wrapped typed handler.
    worker: W,
    /// Selector returned from [`RawHandler::selector`].
    selector: WorkerSelector,
    /// Associates the adapter with `T` without storing a `T` value.
    _task: PhantomData<&'static T>,
}
#[async_trait]
impl<W, T> RawHandler for WorkerAdapter<W, T>
where
    T: Task,
    W: Handler<T> + Send + Sync + 'static,
{
    fn selector(&self) -> &WorkerSelector {
        &self.selector
    }

    fn output_format(&self) -> TaskDataFormat {
        T::format()
    }

    /// Decodes the payload according to its recorded `data_format`, runs the
    /// wrapped handler, and encodes the output with [`Task::format`].
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the payload format is unknown or the payload cannot be decoded;
    /// - the handler fails;
    /// - `T::format()` is [`TaskDataFormat::Unknown`];
    /// - the output cannot be serialized.
    async fn run(&self, context: TaskContext, task: TaskDefinition) -> eyre::Result<Vec<u8>> {
        let task: T = match task.data_format {
            TaskDataFormat::Unknown => eyre::bail!("failed to deserialize unknown data format"),
            TaskDataFormat::MessagePack => {
                rmp_serde::from_slice(&task.data).wrap_err("failed to deserialize MessagePack")?
            }
            TaskDataFormat::Json => {
                serde_json::from_slice(&task.data).wrap_err("failed to deserialize JSON")?
            }
        };

        let out = self.worker.run(context, task).await?;

        // Encode with the same format advertised by `output_format`.
        match T::format() {
            // A misconfigured task is reported as an error rather than panicking inside the worker.
            TaskDataFormat::Unknown => eyre::bail!("invalid task format"),
            TaskDataFormat::MessagePack => {
                rmp_serde::to_vec_named(&out).wrap_err("failed to serialize output")
            }
            TaskDataFormat::Json => serde_json::to_vec(&out).wrap_err("failed to serialize output"),
        }
    }
}