use crate::codec::{Codec, sealed};
use crate::error::{BoxError, CodecError, WorkflowError};
use crate::loop_result::LoopResult;
use crate::priority::Priority;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
/// Descriptive and scheduling metadata attached to a task.
///
/// Every field is optional or empty in `TaskMetadata::default()`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TaskMetadata {
    /// Human-readable name, if any.
    pub display_name: Option<String>,
    /// Free-form description, if any.
    pub description: Option<String>,
    /// Optional timeout. NOTE(review): enforced elsewhere — confirm semantics with the executor.
    pub timeout: Option<Duration>,
    /// Optional retry configuration.
    pub retries: Option<RetryPolicy>,
    /// Free-form tags.
    pub tags: Vec<String>,
    /// Optional version string.
    pub version: Option<String>,
    /// Optional priority. Missing in serialized input => `None`; omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub priority: Option<Priority>,
}
impl TaskMetadata {
    /// Builds metadata from the raw fields of a workflow node.
    ///
    /// The raw `priority` byte is converted with `Priority::from_u8`; when that
    /// conversion yields `None` the priority is simply left unset. The
    /// `display_name` and `description` fields are always `None`.
    #[must_use]
    pub fn from_node_fields(
        timeout: Option<Duration>,
        retries: Option<RetryPolicy>,
        version: Option<String>,
        priority: Option<u8>,
        tags: Vec<String>,
    ) -> Self {
        let priority = priority.and_then(Priority::from_u8);
        Self {
            display_name: None,
            description: None,
            timeout,
            retries,
            tags,
            version,
            priority,
        }
    }
}
/// Retry configuration for a task.
///
/// NOTE(review): the fields suggest exponential backoff (`initial_delay`
/// scaled by `backoff_multiplier`, optionally capped by `max_delay`); the
/// actual schedule is computed outside this file — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of retries; `max_attempts` is accepted as an alias when
    /// deserializing. NOTE(review): "attempts" often counts the first try while
    /// "retries" does not — confirm the alias is not off by one.
    #[serde(alias = "max_attempts")]
    pub max_retries: u32,
    /// Initial delay (presumably before the first retry).
    pub initial_delay: Duration,
    /// Backoff multiplier.
    pub backoff_multiplier: f32,
    /// Optional delay cap; `None` when absent from serialized input.
    #[serde(default)]
    pub max_delay: Option<Duration>,
}
impl RetryPolicy {
    /// Returns a policy allowing `max_retries` retries with a 1 s initial delay,
    /// a 2.0 backoff multiplier and no delay cap.
    #[must_use]
    pub fn with_max_retries(max_retries: u32) -> Self {
        const DEFAULT_INITIAL_DELAY: Duration = Duration::from_secs(1);
        const DEFAULT_BACKOFF_MULTIPLIER: f32 = 2.0;
        Self {
            initial_delay: DEFAULT_INITIAL_DELAY,
            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
            max_delay: None,
            max_retries,
        }
    }
}
pub use crate::branch_results::NamedBranchResults;
/// Pairs a value with the name of the branch that produced it.
///
/// NOTE(review): not referenced in this file; presumably used when
/// (de)serializing branch results elsewhere — confirm. With the `rkyv`
/// feature enabled it also derives the rkyv archive traits.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct BranchEnvelope<T> {
    /// Name of the branch.
    pub branch: String,
    /// The branch's result value.
    pub result: T,
}
/// Encoded branch outputs keyed by branch name, bundled with the codec that
/// decodes them on demand. Handed to heterogeneous join functions.
pub struct BranchOutputs<C> {
    outputs: HashMap<String, Bytes>,
    codec: Arc<C>,
}
impl<C> BranchOutputs<C> {
    /// Wraps a map of encoded outputs together with their codec.
    pub fn new(outputs: HashMap<String, Bytes>, codec: Arc<C>) -> Self {
        BranchOutputs { codec, outputs }
    }
    /// Iterates over the names of all branches with a recorded output.
    ///
    /// Iteration order follows the underlying `HashMap` and is unspecified.
    pub fn branch_names(&self) -> impl Iterator<Item = &str> {
        self.outputs.keys().map(String::as_str)
    }
    /// Returns `true` if branch `name` has a recorded output.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.outputs.contains_key(name)
    }
    /// Number of recorded branch outputs.
    #[must_use]
    pub fn len(&self) -> usize {
        self.outputs.len()
    }
    /// Returns `true` when no branch output was recorded.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.outputs.is_empty()
    }
}
impl<C: Codec> BranchOutputs<C> {
    /// Decodes the output stored under task `T`'s id as `T::Output`.
    ///
    /// # Errors
    /// Fails when no output exists under that id or when decoding fails.
    pub fn get<T>(&self) -> Result<T::Output, BoxError>
    where
        T: RegisterableTask,
        T::Input: Send + 'static,
        T::Output: Send + 'static,
        T::Future: Send + 'static,
        C: sealed::DecodeValue<T::Output>,
    {
        let branch_id = <T as RegisterableTask>::task_id();
        self.get_by_id(branch_id)
    }
    /// Decodes the output of branch `name` as `T`.
    ///
    /// # Errors
    /// Returns a boxed `WorkflowError::BranchNotFound` when `name` has no
    /// output, or the codec's error when decoding fails.
    pub fn get_by_id<T>(&self, name: &str) -> Result<T, BoxError>
    where
        C: sealed::DecodeValue<T>,
    {
        let encoded = match self.outputs.get(name) {
            Some(bytes) => bytes,
            None => return Err(WorkflowError::BranchNotFound(name.to_string()).into()),
        };
        // `decode` takes `Bytes` by value; cloning `Bytes` is a refcount bump.
        self.codec.decode(encoded.clone())
    }
}
/// Provides a static identifier for a task type.
///
/// Blanket-implemented for every [`RegisterableTask`].
pub trait TaskIdentifier {
    /// Returns the task's identifier.
    fn task_id() -> &'static str;
}
/// A [`CoreTask`] that can be registered by type, exposing a static id and
/// metadata.
///
/// These `where` clauses on the associated types are not implied bounds, so
/// generic code bounded by `T: RegisterableTask` restates them (see
/// `BranchOutputs::get`).
pub trait RegisterableTask: CoreTask + Send + Sync + 'static
where
    Self::Input: Send + 'static,
    Self::Output: Send + 'static,
    Self::Future: Send + 'static,
{
    /// Returns the identifier this task is registered under.
    fn task_id() -> &'static str;
    /// Returns the task's metadata.
    fn metadata() -> TaskMetadata;
}
// Every registerable task is a `TaskIdentifier`, forwarding to
// `RegisterableTask::task_id` (fully qualified to avoid ambiguity between the
// two same-named associated functions).
impl<T> TaskIdentifier for T
where
    T: RegisterableTask,
    T::Input: Send + 'static,
    T::Output: Send + 'static,
    T::Future: Send + 'static,
{
    fn task_id() -> &'static str {
        <T as RegisterableTask>::task_id()
    }
}
/// An asynchronous unit of work turning one `Input` into one `Output`.
pub trait CoreTask: Send + Sync {
    /// Value consumed by the task.
    type Input;
    /// Value produced on success.
    type Output;
    /// `Send` future returned by [`CoreTask::run`], resolving to
    /// `Result<Self::Output, BoxError>`.
    type Future: Future<Output = Result<Self::Output, BoxError>> + Send;
    /// Returns a future that performs the task on `input`.
    fn run(&self, input: Self::Input) -> Self::Future;
}
/// Adapts a closure `Fn(I) -> Fut` into a [`CoreTask`]; build one with
/// [`fn_task`].
pub struct FnTask<F, I, O, Fut>(F, PhantomData<fn(I) -> (O, Fut)>);
impl<F, I, O, Fut> CoreTask for FnTask<F, I, O, Fut>
where
    F: Fn(I) -> Fut + Send + Sync,
    I: Send,
    O: Send,
    Fut: Future<Output = Result<O, BoxError>> + Send,
{
    type Input = I;
    type Output = O;
    type Future = Fut;
    fn run(&self, input: I) -> Self::Future {
        let FnTask(callable, _) = self;
        callable(input)
    }
}
/// Wraps `f` in a [`FnTask`].
pub fn fn_task<F, I, O, Fut>(f: F) -> FnTask<F, I, O, Fut>
where
    F: Fn(I) -> Fut,
{
    let marker = PhantomData;
    FnTask(f, marker)
}
/// A heap-pinned, type-erased `Send` future resolving to
/// `Result<Bytes, BoxError>`.
///
/// It is the `Future` type of [`UntypedCoreTask`], so tasks with different
/// concrete futures can share one trait-object type. Like std's futures it is
/// `#[must_use]`: dropping it without polling discards the work.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct BytesFuture(Pin<Box<dyn Future<Output = Result<Bytes, BoxError>> + Send>>);
impl Future for BytesFuture {
    type Output = Result<Bytes, BoxError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The inner future is already pinned on the heap; just forward.
        self.0.as_mut().poll(cx)
    }
}
impl BytesFuture {
    /// Boxes and pins `fut`, erasing its concrete type.
    pub fn new<F>(fut: F) -> Self
    where
        F: Future<Output = Result<Bytes, BoxError>> + Send + 'static,
    {
        BytesFuture(Box::pin(fut))
    }
}
/// Byte-in/byte-out task trait object returned by the `to_*` / `wrap_*`
/// adapters in this module.
pub type UntypedCoreTask =
    Box<dyn CoreTask<Input = Bytes, Output = Bytes, Future = BytesFuture> + Send + Sync>;
/// Generates a byte-level `CoreTask` impl for a wrapper struct that has
/// `func: Arc<_>`, `codec: Arc<_>` and `task_id: String` fields.
///
/// The generated `run` decodes the `Bytes` input as `$input` with the codec,
/// awaits `func` on it, and encodes the result back to `Bytes`. Codec failures
/// become `CodecError::DecodeFailed` / `CodecError::EncodeFailed` tagged with
/// the wrapper's task id.
///
/// `CoreTask`, `Bytes`, `BytesFuture`, `Arc`, `BoxError` and `CodecError` are
/// resolved at the invocation site, so they must be in scope there; the task
/// id type is always reached through `$crate`.
macro_rules! impl_codec_task {
    (
        $wrapper:ident < $($gen:ident),+ >
        where $func_type:ty : Fn($input:ty) -> $fut_type:ty,
        $($bound:tt)+
    ) => {
        impl< $($gen),+ > CoreTask for $wrapper < $($gen),+ >
        where
            $func_type : Fn($input) -> $fut_type + Send + Sync + 'static,
            $($bound)+
        {
            type Input = Bytes;
            type Output = Bytes;
            type Future = BytesFuture;
            fn run(&self, input: Bytes) -> Self::Future {
                // Clone the shared handles so the future (which `BytesFuture::new`
                // requires to be 'static) does not borrow `self`.
                let func = Arc::clone(&self.func);
                let codec = Arc::clone(&self.codec);
                let task_id = self.task_id.clone();
                BytesFuture::new(async move {
                    let decoded_input = codec.decode::<$input>(input)
                        .map_err(|e| -> BoxError { Box::new(CodecError::DecodeFailed {
                            task_id: $crate::TaskId::from(task_id.as_str()),
                            expected_type: std::any::type_name::<$input>(),
                            source: e,
                        }) })?;
                    let output = func(decoded_input).await?;
                    codec.encode(&output)
                        .map_err(|e| -> BoxError { Box::new(CodecError::EncodeFailed {
                            task_id: $crate::TaskId::from(task_id.as_str()),
                            source: e,
                        }) })
                })
            }
        }
    };
}
/// Byte-level adapter around a typed closure `F: Fn(I) -> Fut`; its
/// `CoreTask` impl is generated by the `impl_codec_task!` invocation below.
struct UntypedTaskFnWrapper<F, I, O, Fut, C> {
    func: Arc<F>,
    codec: Arc<C>,
    // Used to tag codec errors with the owning task's id.
    task_id: String,
    _phantom: std::marker::PhantomData<fn(I) -> (O, Fut)>,
}
// Bytes -> I (decode) -> F -> O -> Bytes (encode), using codec `C` both ways.
impl_codec_task!(
    UntypedTaskFnWrapper<F, I, O, Fut, C>
    where F: Fn(I) -> Fut,
    I: Send + 'static,
    O: Send + 'static,
    Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
);
/// Erases a typed closure into an [`UntypedCoreTask`] identified by `id`.
///
/// Convenience form of [`to_core_task_arc`] that wraps `func` in an `Arc`.
pub fn to_core_task<F, I, O, Fut, C>(id: &str, func: F, codec: Arc<C>) -> UntypedCoreTask
where
    F: Fn(I) -> Fut + Send + Sync + 'static,
    I: Send + 'static,
    O: Send + 'static,
    Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
{
    let shared_func = Arc::new(func);
    to_core_task_arc::<F, I, O, Fut, C>(id, shared_func, codec)
}
pub fn to_core_task_arc<F, I, O, Fut, C>(id: &str, func: Arc<F>, codec: Arc<C>) -> UntypedCoreTask
where
F: Fn(I) -> Fut + Send + Sync + 'static,
I: Send + 'static,
O: Send + 'static,
Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
{
Box::new(UntypedTaskFnWrapper {
func,
codec,
task_id: id.to_string(),
_phantom: std::marker::PhantomData,
})
}
/// Erases a shared closure returning `LoopResult<O>` into an
/// [`UntypedCoreTask`] usable as a loop body.
///
/// The task decodes its `Bytes` input as `I`, awaits `func`, splits the result
/// with `LoopResult::into_decision`, encodes the inner `O` with `codec`, and
/// returns `crate::codec::encode_loop_envelope(decision, &encoded_inner)`.
/// Codec failures are reported as `CodecError` values tagged with `id`.
pub fn to_core_loop_task_arc<F, I, O, Fut, C>(
    id: &str,
    func: Arc<F>,
    codec: Arc<C>,
) -> UntypedCoreTask
where
    F: Fn(I) -> Fut + Send + Sync + 'static,
    I: Send + 'static,
    O: Send + 'static,
    Fut: Future<Output = Result<LoopResult<O>, BoxError>> + Send + 'static,
    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
{
    /// Byte-level adapter around a loop-body closure.
    struct LoopTaskFnWrapper<F, I, O, Fut, C> {
        func: Arc<F>,
        codec: Arc<C>,
        task_id: String,
        _phantom: PhantomData<fn(I) -> (O, Fut)>,
    }
    impl<F, I, O, Fut, C> CoreTask for LoopTaskFnWrapper<F, I, O, Fut, C>
    where
        F: Fn(I) -> Fut + Send + Sync + 'static,
        I: Send + 'static,
        O: Send + 'static,
        Fut: Future<Output = Result<LoopResult<O>, BoxError>> + Send + 'static,
        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
    {
        type Input = Bytes;
        type Output = Bytes;
        type Future = BytesFuture;
        fn run(&self, input: Bytes) -> Self::Future {
            let id = self.task_id.clone();
            let body = Arc::clone(&self.func);
            let codec = Arc::clone(&self.codec);
            BytesFuture::new(async move {
                let typed_input = codec.decode::<I>(input).map_err(|source| -> BoxError {
                    Box::new(CodecError::DecodeFailed {
                        task_id: crate::TaskId::from(id.as_str()),
                        expected_type: std::any::type_name::<I>(),
                        source,
                    })
                })?;
                let (decision, inner) = body(typed_input).await?.into_decision();
                let payload = codec.encode(&inner).map_err(|source| -> BoxError {
                    Box::new(CodecError::EncodeFailed {
                        task_id: crate::TaskId::from(id.as_str()),
                        source,
                    })
                })?;
                Ok(crate::codec::encode_loop_envelope(decision, &payload))
            })
        }
    }
    let wrapper = LoopTaskFnWrapper::<F, I, O, Fut, C> {
        task_id: id.to_owned(),
        func,
        codec,
        _phantom: PhantomData,
    };
    Box::new(wrapper)
}
/// Erases a shared [`CoreTask`] whose output is `LoopResult<O>` into an
/// [`UntypedCoreTask`] usable as a loop body.
///
/// The task decodes its input as `T::Input`, runs `task`, splits the result with
/// `LoopResult::into_decision`, encodes the inner `O`, and returns
/// `crate::codec::encode_loop_envelope(decision, &encoded_inner)`. Codec
/// failures are reported as `CodecError` values tagged with `id`.
pub fn wrap_core_loop_task<T, O, C>(id: &str, task: Arc<T>, codec: Arc<C>) -> UntypedCoreTask
where
    T: CoreTask<Output = LoopResult<O>> + 'static,
    T::Input: Send + 'static,
    O: Send + 'static,
    T::Future: Send + 'static,
    C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<O> + 'static,
{
    /// Byte-level adapter around a typed loop-body task.
    struct LoopCoreTaskWrapper<T, O, C> {
        task: Arc<T>,
        codec: Arc<C>,
        task_id: String,
        _phantom: PhantomData<fn() -> O>,
    }
    impl<T, O, C> CoreTask for LoopCoreTaskWrapper<T, O, C>
    where
        T: CoreTask<Output = LoopResult<O>> + Send + Sync + 'static,
        T::Input: Send + 'static,
        O: Send + 'static,
        T::Future: Send + 'static,
        C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<O>,
    {
        type Input = Bytes;
        type Output = Bytes;
        type Future = BytesFuture;
        fn run(&self, input: Bytes) -> Self::Future {
            let id = self.task_id.clone();
            let inner_task = Arc::clone(&self.task);
            let codec = Arc::clone(&self.codec);
            BytesFuture::new(async move {
                let typed_input =
                    codec.decode::<T::Input>(input).map_err(|source| -> BoxError {
                        Box::new(CodecError::DecodeFailed {
                            task_id: crate::TaskId::from(id.as_str()),
                            expected_type: std::any::type_name::<T::Input>(),
                            source,
                        })
                    })?;
                let (decision, inner) = inner_task.run(typed_input).await?.into_decision();
                let payload = codec.encode(&inner).map_err(|source| -> BoxError {
                    Box::new(CodecError::EncodeFailed {
                        task_id: crate::TaskId::from(id.as_str()),
                        source,
                    })
                })?;
                Ok(crate::codec::encode_loop_envelope(decision, &payload))
            })
        }
    }
    let wrapper = LoopCoreTaskWrapper::<T, O, C> {
        task_id: id.to_owned(),
        task,
        codec,
        _phantom: PhantomData,
    };
    Box::new(wrapper)
}
type BoxedBranchFn<I, O> = Box<
dyn Fn(I) -> std::pin::Pin<Box<dyn Future<Output = Result<O, BoxError>> + Send>> + Send + Sync,
>;
pub(crate) struct Branch<I, O> {
id: String,
func: BoxedBranchFn<I, O>,
}
pub(crate) fn branch<F, Fut, I, O>(id: &str, f: F) -> Branch<I, O>
where
F: Fn(I) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
I: 'static,
O: 'static,
{
Branch {
id: id.to_string(),
func: Box::new(move |i| Box::pin(f(i))),
}
}
pub(crate) struct ErasedBranch {
pub(crate) id: String,
pub(crate) task: UntypedCoreTask,
}
impl<I, O> Branch<I, O> {
pub fn erase<C>(self, codec: Arc<C>) -> ErasedBranch
where
I: Send + 'static,
O: Send + 'static,
C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
{
ErasedBranch {
id: self.id.clone(),
task: branch_to_core_task(self, codec),
}
}
}
#[allow(clippy::items_after_statements)]
pub(crate) fn branch_to_core_task<I, O, C>(branch: Branch<I, O>, codec: Arc<C>) -> UntypedCoreTask
where
I: Send + 'static,
O: Send + 'static,
C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
{
let id = branch.id.clone();
let func = Arc::new(branch.func);
struct ArcBranchWrapper<I, O, C> {
func: Arc<BoxedBranchFn<I, O>>,
codec: Arc<C>,
task_id: String,
_phantom: PhantomData<fn(I) -> O>,
}
impl_codec_task!(
ArcBranchWrapper<I, O, C>
where BoxedBranchFn<I, O>: Fn(I) -> std::pin::Pin<Box<dyn Future<Output = Result<O, BoxError>> + Send>>,
I: Send + 'static,
O: Send + 'static,
C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
);
Box::new(ArcBranchWrapper {
func,
codec,
task_id: id,
_phantom: PhantomData,
})
}
/// Byte-level adapter around a join closure that receives every branch's
/// encoded output as a [`BranchOutputs`].
// `type_complexity` is allowed because the phantom marker spells out the join
// closure's full signature.
#[allow(clippy::type_complexity)]
struct HeterogeneousJoinTaskWrapper<F, JoinOutput, Fut, C> {
    func: Arc<F>,
    codec: Arc<C>,
    task_id: String,
    _phantom: PhantomData<fn(BranchOutputs<C>) -> (JoinOutput, Fut)>,
}
impl<F, JoinOutput, Fut, C> CoreTask for HeterogeneousJoinTaskWrapper<F, JoinOutput, Fut, C>
where
    F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
    JoinOutput: Send + 'static,
    Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
    C: Codec
        + sealed::EncodeValue<JoinOutput>
        + sealed::DecodeValue<NamedBranchResults>
        + Send
        + Sync
        + 'static,
{
    type Input = Bytes;
    type Output = Bytes;
    type Future = BytesFuture;
    /// Decodes the input as `NamedBranchResults`, hands the per-branch bytes to
    /// the join closure as `BranchOutputs`, and encodes the closure's output.
    fn run(&self, input: Bytes) -> Self::Future {
        let id = self.task_id.clone();
        let join = Arc::clone(&self.func);
        let codec = Arc::clone(&self.codec);
        BytesFuture::new(async move {
            let named_results: NamedBranchResults =
                codec.decode(input).map_err(|source| -> BoxError {
                    Box::new(CodecError::DecodeFailed {
                        task_id: crate::TaskId::from(id.as_str()),
                        expected_type: std::any::type_name::<NamedBranchResults>(),
                        source,
                    })
                })?;
            let per_branch = BranchOutputs::new(named_results.into_map(), Arc::clone(&codec));
            let joined = join(per_branch).await?;
            codec.encode(&joined).map_err(|source| -> BoxError {
                Box::new(CodecError::EncodeFailed {
                    task_id: crate::TaskId::from(id.as_str()),
                    source,
                })
            })
        })
    }
}
/// Erases a heterogeneous join closure into an [`UntypedCoreTask`].
///
/// Convenience form of [`to_heterogeneous_join_task_arc`] that wraps `func`
/// in an `Arc`.
pub fn to_heterogeneous_join_task<F, JoinOutput, Fut, C>(
    id: &str,
    func: F,
    codec: Arc<C>,
) -> UntypedCoreTask
where
    F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
    JoinOutput: Send + 'static,
    Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
    C: Codec
        + sealed::EncodeValue<JoinOutput>
        + sealed::DecodeValue<NamedBranchResults>
        + Send
        + Sync
        + 'static,
{
    let shared_func = Arc::new(func);
    to_heterogeneous_join_task_arc::<F, JoinOutput, Fut, C>(id, shared_func, codec)
}
/// Erases a shared heterogeneous join closure into an [`UntypedCoreTask`].
///
/// The task decodes its input as `NamedBranchResults`, passes the per-branch
/// bytes to `func` as [`BranchOutputs`], and encodes the join's output. Codec
/// failures are reported as `CodecError` values tagged with `id`.
pub fn to_heterogeneous_join_task_arc<F, JoinOutput, Fut, C>(
    id: &str,
    func: Arc<F>,
    codec: Arc<C>,
) -> UntypedCoreTask
where
    F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
    JoinOutput: Send + 'static,
    Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
    C: Codec
        + sealed::EncodeValue<JoinOutput>
        + sealed::DecodeValue<NamedBranchResults>
        + Send
        + Sync
        + 'static,
{
    let wrapper = HeterogeneousJoinTaskWrapper::<F, JoinOutput, Fut, C> {
        task_id: id.to_owned(),
        func,
        codec,
        _phantom: PhantomData,
    };
    Box::new(wrapper)
}
/// Erases a shared typed [`CoreTask`] into an [`UntypedCoreTask`].
///
/// The task decodes its `Bytes` input as `T::Input` with `codec`, runs `task`,
/// and encodes the `T::Output` result. Codec failures are reported as
/// `CodecError` values tagged with `id`.
pub fn wrap_core_task<T, C>(id: &str, task: Arc<T>, codec: Arc<C>) -> UntypedCoreTask
where
    T: CoreTask + 'static,
    T::Input: Send + 'static,
    T::Output: Send + 'static,
    T::Future: Send + 'static,
    C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
{
    /// Byte-level adapter around a typed task.
    struct CoreTaskWrapper<T, C> {
        task: Arc<T>,
        codec: Arc<C>,
        task_id: String,
    }
    impl<T, C> CoreTask for CoreTaskWrapper<T, C>
    where
        T: CoreTask + Send + Sync + 'static,
        T::Input: Send + 'static,
        T::Output: Send + 'static,
        T::Future: Send + 'static,
        C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output>,
    {
        type Input = Bytes;
        type Output = Bytes;
        type Future = BytesFuture;
        fn run(&self, input: Bytes) -> Self::Future {
            let id = self.task_id.clone();
            let inner_task = Arc::clone(&self.task);
            let codec = Arc::clone(&self.codec);
            BytesFuture::new(async move {
                let typed_input =
                    codec.decode::<T::Input>(input).map_err(|source| -> BoxError {
                        Box::new(CodecError::DecodeFailed {
                            task_id: crate::TaskId::from(id.as_str()),
                            expected_type: std::any::type_name::<T::Input>(),
                            source,
                        })
                    })?;
                let produced = inner_task.run(typed_input).await?;
                codec.encode(&produced).map_err(|source| -> BoxError {
                    Box::new(CodecError::EncodeFailed {
                        task_id: crate::TaskId::from(id.as_str()),
                        source,
                    })
                })
            })
        }
    }
    let wrapper = CoreTaskWrapper::<T, C> {
        task_id: id.to_owned(),
        task,
        codec,
    };
    Box::new(wrapper)
}