use std::sync::Arc;
use async_trait::async_trait;
use crate::error::GraphError;
use crate::execution_engine::ExecutionEngine;
pub use crate::node_context::LeafContext;
use crate::node_context::NodeContext;
use crate::state::{State, StateMerge};
use crate::workflow_state::{MergeStrategy, WorkflowState};
pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
pub use crate::parallel_node::{ParallelErrorStrategy, ParallelNode, ParallelNodeBuilder};
/// A graph node that executes against a [`LeafContext`].
///
/// [`ConditionNode`] implements this trait: it reads the workflow state through
/// `ctx.state()` and selects a successor with `ctx.goto(..)`. Implementations must
/// be `Send + Sync`, and `#[async_trait]` boxes the returned future so the trait
/// stays usable as `Arc<dyn LeafNode<S>>` (see `NodeKind::ExternalLeaf`).
#[async_trait]
pub trait LeafNode<S = State>: Send + Sync
where
    S: WorkflowState,
{
    /// Runs this node once with mutable access to `ctx`.
    ///
    /// # Errors
    ///
    /// Returns a [`GraphError`] when the node fails.
    async fn execute(&self, ctx: &mut LeafContext<'_, S>) -> Result<(), GraphError>;
}
/// An operation that receives the whole [`ExecutionEngine`] mutably, not a
/// per-node context.
///
/// Implementations must be `Send + Sync`. `#[async_trait]` boxes the future
/// returned by `execute`.
#[async_trait]
pub trait ExecutorOperation<S = State>: Send + Sync
where
    S: WorkflowState,
{
    /// Applies this operation to `engine`.
    ///
    /// # Errors
    ///
    /// Returns a [`GraphError`] when the operation fails.
    async fn execute(&self, engine: &mut ExecutionEngine<S>) -> Result<(), GraphError>;
}
/// A graph node that executes against a [`NodeContext`].
///
/// In this module, [`TaskNode`] and [`ConditionNode`] implement it, and
/// `NodeKind::External` stores arbitrary implementations as
/// `Arc<dyn FlowNode<S>>`. Implementations must be `Send + Sync`.
/// `#[async_trait]` boxes the returned future.
#[async_trait]
pub trait FlowNode<S = State>: Send + Sync
where
    S: WorkflowState,
{
    /// Runs this node once with mutable access to `ctx`.
    ///
    /// # Errors
    ///
    /// Returns a [`GraphError`] when the node fails.
    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError>;
}
/// Every kind of node a workflow graph can hold.
///
/// `S` is the workflow state type. `M` is the merge strategy, which only the
/// [`ParallelNode`] variant uses.
pub enum NodeKind<S = State, M = StateMerge>
where
    S: WorkflowState,
    M: MergeStrategy<S>,
{
    /// A named synchronous closure; see [`TaskNode`].
    Task(TaskNode<S>),
    /// First-match conditional routing; see [`ConditionNode`].
    Condition(ConditionNode<S>),
    /// See [`BarrierNode`].
    Barrier(BarrierNode<S>),
    /// See [`ParallelNode`]; parameterised by the merge strategy `M`.
    Parallel(ParallelNode<S, M>),
    /// A user-supplied [`FlowNode`] implementation behind an `Arc`.
    External(Arc<dyn FlowNode<S>>),
    /// A user-supplied [`LeafNode`] implementation behind an `Arc`.
    ExternalLeaf(Arc<dyn LeafNode<S>>),
}
impl<S: WorkflowState, M: MergeStrategy<S>> Clone for NodeKind<S, M> {
fn clone(&self) -> Self {
match self {
Self::Task(n) => Self::Task(n.clone()),
Self::Condition(n) => Self::Condition(n.clone()),
Self::Barrier(n) => Self::Barrier(n.clone()),
Self::Parallel(n) => Self::Parallel(n.clone()),
Self::External(n) => Self::External(n.clone()),
Self::ExternalLeaf(n) => Self::ExternalLeaf(n.clone()),
}
}
}
/// Type-erased, synchronous body of a [`TaskNode`].
///
/// The closure is held behind an [`Arc`], so cloning a `TaskNode` shares the
/// closure instead of copying it.
pub type TaskFn<S> =
    Arc<dyn Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync>;

/// A named node whose work is a single synchronous closure.
#[derive(Clone)]
pub struct TaskNode<S = State>
where
    S: WorkflowState,
{
    /// Name of this task node.
    pub name: String,
    /// Closure invoked by this node's `FlowNode::execute`.
    pub func: TaskFn<S>,
}
impl<S: WorkflowState> TaskNode<S> {
pub fn new(
name: impl Into<String>,
func: impl Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync + 'static,
) -> Self {
Self {
name: name.into(),
func: Arc::new(func),
}
}
}
/// Executes a [`TaskNode`] by calling its closure with `ctx`.
///
/// The closure runs synchronously inside the future and nothing is awaited,
/// so the future finishes on its first poll. The closure's result is returned
/// unchanged.
#[async_trait]
impl<S> FlowNode<S> for TaskNode<S>
where
    S: WorkflowState,
{
    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
        let run = &self.func;
        run(ctx)
    }
}
/// Predicate over the workflow state that decides whether a
/// [`ConditionNode`] branch is taken.
pub type BranchCondition<S> = Arc<dyn Fn(&S) -> bool + Send + Sync>;

/// A routing node that calls `goto` with the target of the first branch whose
/// predicate returns `true`.
///
/// Branches are evaluated front to back, and evaluation stops at the first
/// match. If no predicate matches, no `goto` is issued and execution still
/// returns `Ok(())`.
#[derive(Clone)]
pub struct ConditionNode<S = State>
where
    S: WorkflowState,
{
    /// Name of this condition node.
    pub name: String,
    /// `(target, predicate)` pairs in evaluation order.
    pub branches: Vec<(String, BranchCondition<S>)>,
}
impl<S> ConditionNode<S>
where
    S: WorkflowState,
{
    /// Starts a [`ConditionNodeBuilder`] for a node named `name`.
    ///
    /// The builder starts with no branches.
    pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder<S> {
        let name = name.into();
        ConditionNodeBuilder {
            name,
            branches: Vec::new(),
        }
    }
}
/// Builder for [`ConditionNode`], obtained from [`ConditionNode::builder`].
///
/// Add branches with [`branch`](Self::branch) and finish with
/// [`build`](Self::build). Branches keep the order in which they were added,
/// which is the order in which they are evaluated.
///
/// Marked `#[must_use]` because a builder that is dropped without calling
/// `build()` silently discards every branch added to it.
#[must_use = "a ConditionNodeBuilder does nothing until `build()` is called"]
pub struct ConditionNodeBuilder<S: WorkflowState = State> {
    /// Name given to the node produced by `build`.
    name: String,
    /// Branches collected so far, in insertion order.
    branches: Vec<(String, BranchCondition<S>)>,
}
impl<S: WorkflowState> ConditionNodeBuilder<S> {
pub fn branch(
mut self,
target: impl Into<String>,
condition: impl Fn(&S) -> bool + Send + Sync + 'static,
) -> Self {
self.branches.push((target.into(), Arc::new(condition)));
self
}
pub fn build(self) -> ConditionNode<S> {
ConditionNode {
name: self.name,
branches: self.branches,
}
}
}
/// Leaf-context routing for [`ConditionNode`].
///
/// Evaluates the branch predicates in order against `ctx.state()` and calls
/// `ctx.goto` with the first matching target. Predicates after the first
/// match are not evaluated. When nothing matches, `goto` is not called.
/// Every path returns `Ok(())`.
#[async_trait]
impl<S> LeafNode<S> for ConditionNode<S>
where
    S: WorkflowState,
{
    async fn execute(&self, ctx: &mut LeafContext<'_, S>) -> Result<(), GraphError> {
        let state = ctx.state();
        // `selected` borrows `self.branches`, not `ctx`, so the shared borrow
        // taken by `ctx.state()` has ended before `goto` needs `&mut ctx`.
        let selected = self
            .branches
            .iter()
            .find(|(_, predicate)| predicate(state))
            .map(|(target, _)| target);
        if let Some(target) = selected {
            ctx.goto(target);
        }
        Ok(())
    }
}
/// Node-context routing for [`ConditionNode`].
///
/// Evaluates the branch predicates in order against `ctx.state()` and calls
/// `ctx.goto` with the first matching target. Predicates after the first
/// match are not evaluated. When nothing matches, `goto` is not called.
/// Every path returns `Ok(())`.
#[async_trait]
impl<S> FlowNode<S> for ConditionNode<S>
where
    S: WorkflowState,
{
    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
        let state = ctx.state();
        // `selected` borrows `self.branches`, not `ctx`, so the shared borrow
        // taken by `ctx.state()` has ended before `goto` needs `&mut ctx`.
        let selected = self
            .branches
            .iter()
            .find(|(_, predicate)| predicate(state))
            .map(|(target, _)| target);
        if let Some(target) = selected {
            ctx.goto(target);
        }
        Ok(())
    }
}
/// Shorthand for the `FlowNode<S>` trait-object type.
///
/// As a bare `dyn` type it is unsized, so it must be used behind a pointer
/// such as `Arc`, `Box`, or `&`.
pub type GraphNode<S> = dyn FlowNode<S>;