1use std::sync::Arc;
29
30use async_trait::async_trait;
31
32pub use super::node_context::LeafContext;
33use super::node_context::NodeContext;
34use crate::error::GraphError;
35use crate::exec::execution_engine::ExecutionEngine;
36use crate::state::workflow_state::{MergeStrategy, WorkflowState};
37use crate::state::{State, StateMerge};
38
39pub use super::barrier_node::{BarrierDefaultAction, BarrierNode};
42pub use super::parallel_node::{ParallelErrorStrategy, ParallelNode, ParallelNodeBuilder};
43
44#[async_trait]
61pub trait LeafNode<S: WorkflowState = State>: Send + Sync {
62 async fn execute(&self, ctx: &mut LeafContext<'_, S>) -> Result<(), GraphError>;
64}
65
66#[async_trait]
76pub trait ExecutorOperation<S: WorkflowState = State>: Send + Sync {
77 async fn execute(&self, engine: &mut ExecutionEngine<'_, S>) -> Result<(), GraphError>;
79}
80
81#[async_trait]
88pub trait FlowNode<S: WorkflowState = State>: Send + Sync {
89 async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError>;
91}
92
93pub enum NodeKind<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
105 Task(TaskNode<S>),
107 Condition(ConditionNode<S>),
109 Barrier(BarrierNode<S>),
111 Parallel(ParallelNode<S, M>),
113 External(Arc<dyn FlowNode<S>>),
115 ExternalLeaf(Arc<dyn LeafNode<S>>),
117 Subgraph(super::compiled_subgraph::CompiledSubgraph<S>),
122}
123
124impl<S: WorkflowState, M: MergeStrategy<S>> Clone for NodeKind<S, M> {
125 fn clone(&self) -> Self {
126 match self {
127 Self::Task(n) => Self::Task(n.clone()),
128 Self::Condition(n) => Self::Condition(n.clone()),
129 Self::Barrier(n) => Self::Barrier(n.clone()),
130 Self::Parallel(n) => Self::Parallel(n.clone()),
131 Self::External(n) => Self::External(n.clone()),
132 Self::ExternalLeaf(n) => Self::ExternalLeaf(n.clone()),
133 Self::Subgraph(n) => Self::Subgraph(n.clone()),
134 }
135 }
136}
137
138pub type TaskFn<S> = Arc<dyn Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync>;
142
143#[derive(Clone)]
145pub struct TaskNode<S: WorkflowState = State> {
146 pub name: String,
147 pub func: TaskFn<S>,
148}
149
150impl<S: WorkflowState> TaskNode<S> {
151 pub fn new(
152 name: impl Into<String>,
153 func: impl Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync + 'static,
154 ) -> Self {
155 Self {
156 name: name.into(),
157 func: Arc::new(func),
158 }
159 }
160}
161
162#[async_trait]
166impl<S: WorkflowState> FlowNode<S> for TaskNode<S> {
167 async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
168 (self.func)(ctx)
169 }
170}
171
/// Shared predicate that decides whether a [`ConditionNode`] branch is taken.
///
/// Written with an explicit `for<'a>` binder; this is the same type as the elided
/// `Arc<dyn Fn(&S) -> bool + Send + Sync>`.
pub type BranchCondition<S> = Arc<dyn for<'a> Fn(&'a S) -> bool + Send + Sync>;
176
177#[derive(Clone)]
179pub struct ConditionNode<S: WorkflowState = State> {
180 pub name: String,
181 pub branches: Vec<(String, BranchCondition<S>)>,
182}
183
184impl<S: WorkflowState> ConditionNode<S> {
185 pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder<S> {
186 ConditionNodeBuilder {
187 name: name.into(),
188 branches: Vec::new(),
189 }
190 }
191}
192
193pub struct ConditionNodeBuilder<S: WorkflowState = State> {
195 name: String,
196 branches: Vec<(String, BranchCondition<S>)>,
197}
198
199impl<S: WorkflowState> ConditionNodeBuilder<S> {
200 pub fn branch(
201 mut self,
202 target: impl Into<String>,
203 condition: impl Fn(&S) -> bool + Send + Sync + 'static,
204 ) -> Self {
205 self.branches.push((target.into(), Arc::new(condition)));
206 self
207 }
208
209 pub fn build(self) -> ConditionNode<S> {
210 ConditionNode {
211 name: self.name,
212 branches: self.branches,
213 }
214 }
215}
216
217#[async_trait]
219impl<S: WorkflowState> LeafNode<S> for ConditionNode<S> {
220 async fn execute(&self, ctx: &mut LeafContext<'_, S>) -> Result<(), GraphError> {
221 let state = ctx.state();
222 for (target, condition) in &self.branches {
223 if condition(state) {
224 ctx.goto(target);
225 return Ok(());
226 }
227 }
228 Ok(())
229 }
230}
231
232#[async_trait]
234impl<S: WorkflowState> FlowNode<S> for ConditionNode<S> {
235 async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
236 let state = ctx.state();
237 for (target, condition) in &self.branches {
238 if condition(state) {
239 ctx.goto(target);
240 return Ok(());
241 }
242 }
243 Ok(())
244 }
245}
246
/// Alias for a [`FlowNode`] trait object over workflow state `S`.
///
/// A bare `dyn` type is unsized. NOTE(review): it is therefore presumably used behind a
/// pointer such as `Arc<GraphNode<S>>` or `&GraphNode<S>` — confirm against callers.
pub type GraphNode<S> = dyn FlowNode<S>;