1use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use crate::error::{GraphError, ObservedError};
15use crate::event::BarrierId;
16use crate::ids::SpanId;
17use crate::node_context::NodeContext;
18use crate::state::{State, StateMerge};
19use crate::workflow_state::{MergeStrategy, WorkflowState};
20
21pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
24pub use crate::parallel_node::{
25 ParallelErrorStrategy, ParallelNode, ParallelNodeBuilder, ParallelNodeBuilderWithMerge,
26};
27
/// Routing decision a node hands back to the graph executor.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NextStep {
    /// Jump to the node with the given name.
    Goto(String),
    /// Continue with whatever node follows this one.
    GoToNext,
    /// Stop the workflow.
    End,
}
40
/// Bundle produced by a node step: the state deltas it recorded, the routing
/// decision for what runs next, and optional execution metadata.
///
/// Usually built with [`NodeOutput::new`] plus the `with_*` builder methods.
#[derive(Debug)]
pub struct NodeOutput {
    /// State deltas, in the order they were added by the builder methods.
    pub deltas: Vec<crate::delta::StateDelta>,
    /// Where execution should go after this node.
    pub next: NextStep,
    /// Optional metadata (carries at least `token_cost` and `has_side_effects`);
    /// `None` until one of the metadata builder methods sets it.
    pub metadata: Option<crate::node_context::NodeMetadata>,
}
53
54impl NodeOutput {
55 pub fn new(next: NextStep) -> Self {
56 Self {
57 deltas: Vec::new(),
58 next,
59 metadata: None,
60 }
61 }
62
63 pub fn with_delta(mut self, delta: crate::delta::StateDelta) -> Self {
64 self.deltas.push(delta);
65 self
66 }
67
68 pub fn with_deltas(mut self, deltas: Vec<crate::delta::StateDelta>) -> Self {
69 self.deltas.extend(deltas);
70 self
71 }
72
73 pub fn with_metadata(mut self, metadata: crate::node_context::NodeMetadata) -> Self {
74 self.metadata = Some(metadata);
75 self
76 }
77
78 pub fn with_token_cost(mut self, cost: f64) -> Self {
79 self.metadata
80 .get_or_insert_with(Default::default)
81 .token_cost = cost;
82 self
83 }
84
85 pub fn with_side_effects(mut self) -> Self {
86 self.metadata
87 .get_or_insert_with(Default::default)
88 .has_side_effects = true;
89 self
90 }
91}
92
93pub use crate::node_context::NodeMetadata;
95
/// Outcome of running a node step in streaming mode.
#[derive(Debug)]
pub enum StreamNodeResult {
    /// Execution continues after this node.
    Continue {
        /// State deltas produced by the node.
        deltas: Vec<crate::delta::StateDelta>,
        /// Routing decision for the next step.
        next: NextStep,
        /// Span the node ran under.
        span_id: SpanId,
        /// An error observed while running the node, if any. NOTE(review): presumably
        /// a non-fatal error, since execution still continues; confirm.
        observed: Option<ObservedError>,
        /// Optional execution metadata reported by the node.
        metadata: Option<NodeMetadata>,
    },
    /// Execution pauses at a barrier.
    Pause {
        /// State deltas produced before pausing.
        deltas: Vec<crate::delta::StateDelta>,
        /// Identifier of the barrier being waited on.
        barrier_id: BarrierId,
        /// Name of the node that paused.
        node_name: String,
        /// Span the node ran under.
        span_id: SpanId,
        /// Optional wait limit; presumably `default_action` applies once it
        /// elapses. TODO confirm against the executor.
        timeout: Option<std::time::Duration>,
        /// Action associated with the barrier when no explicit decision arrives
        /// (assumed from the name; verify).
        default_action: BarrierDefaultAction,
    },
    /// The node requested a fallback path instead of continuing normally.
    Fallback {
        /// State deltas produced before falling back.
        deltas: Vec<crate::delta::StateDelta>,
        /// Explanation for the fallback.
        reason: String,
        /// Name of the node that fell back.
        node_name: String,
    },
}
122
/// A node that can run inside a workflow graph operating on state `S`
/// (defaults to [`State`]).
///
/// Implementors must be `Send + Sync`. The trait is used as a trait object
/// (`Arc<dyn FlowNode<S>>` in [`NodeKind::External`]), which `async_trait` makes
/// possible by boxing the returned future.
#[async_trait]
pub trait FlowNode<S: WorkflowState = State>: Send + Sync {
    /// Runs the node against `ctx`, reading state and recording routing and
    /// updates through the context.
    ///
    /// # Errors
    ///
    /// Returns a [`GraphError`] when the node fails.
    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError>;
}
142
/// Every kind of node a graph can hold: the built-in node types plus an
/// `External` variant for any other [`FlowNode`] implementation.
///
/// `S` is the workflow state type. `M` is the merge strategy and is only used
/// by the [`NodeKind::Parallel`] variant.
pub enum NodeKind<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
    /// Runs a synchronous closure.
    Task(TaskNode<S>),
    /// Routes to the first branch whose predicate holds.
    Condition(ConditionNode<S>),
    /// A barrier node (see [`BarrierNode`]).
    Barrier(BarrierNode<S>),
    /// A parallel node parameterized by merge strategy `M` (see [`ParallelNode`]).
    Parallel(ParallelNode<S, M>),
    /// A user-supplied node, shared behind an `Arc`.
    External(Arc<dyn FlowNode<S>>),
}
161
162impl<S: WorkflowState, M: MergeStrategy<S>> Clone for NodeKind<S, M> {
163 fn clone(&self) -> Self {
164 match self {
165 Self::Task(n) => Self::Task(n.clone()),
166 Self::Condition(n) => Self::Condition(n.clone()),
167 Self::Barrier(n) => Self::Barrier(n.clone()),
168 Self::Parallel(n) => Self::Parallel(n.clone()),
169 Self::External(n) => Self::External(n.clone()),
170 }
171 }
172}
173
/// Shared, thread-safe closure that backs a [`TaskNode`]: it receives the node
/// context and returns `Ok(())` or a [`GraphError`].
pub type TaskFn<S> = Arc<dyn Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync>;
178
/// Node that runs a synchronous closure against the node context.
///
/// Cloning is cheap: `func` is an `Arc`, so clones share the same closure.
#[derive(Clone)]
pub struct TaskNode<S: WorkflowState = State> {
    /// Name of the node.
    pub name: String,
    /// Closure invoked each time the node executes.
    pub func: TaskFn<S>,
}
185
186impl<S: WorkflowState> TaskNode<S> {
187 pub fn new(
188 name: impl Into<String>,
189 func: impl Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync + 'static,
190 ) -> Self {
191 Self {
192 name: name.into(),
193 func: Arc::new(func),
194 }
195 }
196}
197
198#[async_trait]
199impl<S: WorkflowState> FlowNode<S> for TaskNode<S> {
200 async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
201 (self.func)(ctx)
202 }
203}
204
/// Shared, thread-safe predicate over the workflow state, used to guard a
/// [`ConditionNode`] branch.
pub type BranchCondition<S> = Arc<dyn Fn(&S) -> bool + Send + Sync>;
209
/// Node that picks the next step by testing branch predicates against the
/// current state.
///
/// Branches are tested in order and the first match wins. When none match, the
/// node makes no routing call. Construct it with [`ConditionNode::builder`].
#[derive(Clone)]
pub struct ConditionNode<S: WorkflowState = State> {
    /// Name of the node.
    pub name: String,
    /// `(target, predicate)` pairs in evaluation order; `target` is handed to
    /// `NodeContext::goto` when its predicate returns `true`.
    pub branches: Vec<(String, BranchCondition<S>)>,
}
216
217impl<S: WorkflowState> ConditionNode<S> {
218 pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder<S> {
219 ConditionNodeBuilder {
220 name: name.into(),
221 branches: Vec::new(),
222 }
223 }
224}
225
/// Incremental builder for [`ConditionNode`], created by [`ConditionNode::builder`].
pub struct ConditionNodeBuilder<S: WorkflowState = State> {
    // Name the finished node will carry.
    name: String,
    // Branches in insertion order, which is also their evaluation order.
    branches: Vec<(String, BranchCondition<S>)>,
}
231
232impl<S: WorkflowState> ConditionNodeBuilder<S> {
233 pub fn branch(
234 mut self,
235 target: impl Into<String>,
236 condition: impl Fn(&S) -> bool + Send + Sync + 'static,
237 ) -> Self {
238 self.branches.push((target.into(), Arc::new(condition)));
239 self
240 }
241
242 pub fn build(self) -> ConditionNode<S> {
243 ConditionNode {
244 name: self.name,
245 branches: self.branches,
246 }
247 }
248}
249
250#[async_trait]
251impl<S: WorkflowState> FlowNode<S> for ConditionNode<S> {
252 async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
253 let state = ctx.state();
254 for (target, condition) in &self.branches {
255 if condition(state) {
256 ctx.goto(target);
257 return Ok(());
258 }
259 }
260 Ok(())
261 }
262}
263
264#[async_trait]
267impl<S: WorkflowState, M: MergeStrategy<S>> FlowNode<S> for NodeKind<S, M> {
268 async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
269 match self {
270 Self::Task(n) => n.execute(ctx).await,
271 Self::Condition(n) => n.execute(ctx).await,
272 Self::Barrier(n) => n.execute(ctx).await,
273 Self::Parallel(n) => n.execute(ctx).await,
274 Self::External(n) => n.execute(ctx).await,
275 }
276 }
277}
278
/// Trait-object alias for a [`FlowNode`] over state `S`.
pub type GraphNode<S> = dyn FlowNode<S>;