1use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use crate::delta::StateDelta;
15use crate::error::{GraphError, ObservedError};
16use crate::event::{BarrierId, GraphEvent};
17use crate::ids::SpanId;
18use crate::state::State;
19
20pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
23pub use crate::parallel_node::{ParallelErrorStrategy, ParallelNode, ParallelNodeBuilder};
24
/// Control-flow decision a node hands back after it has run.
///
/// In this module, `ConditionNode` yields [`NextStep::Goto`] for its first matching
/// branch (and [`NextStep::GoToNext`] otherwise), while `TaskNode` always yields
/// [`NextStep::GoToNext`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NextStep {
    /// Jump to the node whose name is carried in the payload.
    Goto(String),
    /// Fall through to the default successor; how that successor is chosen is up to the
    /// executor, which is not part of this module.
    GoToNext,
    /// Request that execution stop (presumably ends the run — confirm in the executor).
    End,
}
37
#[derive(Debug)]
/// Result of a successful [`FlowNode::execute`] call.
///
/// Usually built with [`NodeOutput::new`] plus the `with_*` builder methods.
pub struct NodeOutput {
    /// State changes produced by the node, in the order they were added.
    pub deltas: Vec<StateDelta>,
    /// Where control should go after this node.
    pub next: NextStep,
    /// Optional annotations; `None` unless set through a `with_*` builder or a literal.
    pub metadata: Option<NodeMetadata>,
}
51
/// Optional per-node annotations, attached to a [`NodeOutput`] or returned from
/// [`FlowNode::metadata_hint`].
///
/// The derived `Default` gives `token_cost == 0.0` and `has_side_effects == false`.
#[derive(Default, Clone, Debug)]
pub struct NodeMetadata {
    /// Token cost reported by the node; the unit is defined by callers.
    pub token_cost: f64,
    /// `true` when the node declares that it has side effects.
    pub has_side_effects: bool,
}
60
61impl NodeOutput {
62 pub fn new(next: NextStep) -> Self {
64 Self {
65 deltas: Vec::new(),
66 next,
67 metadata: None,
68 }
69 }
70
71 pub fn with_delta(mut self, delta: StateDelta) -> Self {
73 self.deltas.push(delta);
74 self
75 }
76
77 pub fn with_deltas(mut self, deltas: Vec<StateDelta>) -> Self {
79 self.deltas.extend(deltas);
80 self
81 }
82
83 pub fn with_metadata(mut self, metadata: NodeMetadata) -> Self {
85 self.metadata = Some(metadata);
86 self
87 }
88
89 pub fn with_token_cost(mut self, cost: f64) -> Self {
91 self.metadata
92 .get_or_insert_with(Default::default)
93 .token_cost = cost;
94 self
95 }
96
97 pub fn with_side_effects(mut self) -> Self {
99 self.metadata
100 .get_or_insert_with(Default::default)
101 .has_side_effects = true;
102 self
103 }
104}
105
#[derive(Debug)]
/// Outcome of [`FlowNode::execute_stream`].
pub enum StreamNodeResult {
    /// The node finished; execution proceeds according to `next`.
    Continue {
        /// State changes produced by the node.
        deltas: Vec<StateDelta>,
        /// Control-flow decision for the executor.
        next: NextStep,
        /// Span of this execution (the `span_id` passed into `execute_stream` by the
        /// implementations in this module).
        span_id: SpanId,
        /// Presumably an error the node observed without failing — confirm. The
        /// default `execute_stream` always sets `None`.
        observed: Option<ObservedError>,
        /// Metadata carried over from the node's `NodeOutput`.
        metadata: Option<NodeMetadata>,
    },
    /// The node asked to pause at a barrier. NOTE(review): presumably the executor
    /// suspends until the barrier resolves or `timeout` elapses — confirm there.
    Pause {
        /// State changes produced before pausing.
        deltas: Vec<StateDelta>,
        /// Identifier of the barrier being waited on.
        barrier_id: BarrierId,
        /// Name of the node that paused.
        node_name: String,
        /// Span of the pausing execution.
        span_id: SpanId,
        /// Optional wait limit; `None` presumably means "no limit" — confirm.
        timeout: Option<std::time::Duration>,
        /// Action associated with the barrier; presumably applied when no explicit
        /// resolution arrives (e.g. on timeout) — confirm.
        default_action: BarrierDefaultAction,
    },
    /// The node requested a fallback instead of a normal continuation; how the
    /// executor reacts is not visible in this module.
    Fallback {
        /// State changes produced before falling back.
        deltas: Vec<StateDelta>,
        /// Explanation of why the fallback was requested.
        reason: String,
        /// Name of the node that requested the fallback.
        node_name: String,
    },
}
150
151#[async_trait]
159pub trait FlowNode: Send + Sync {
160 async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError>;
165
166 async fn execute_stream(
175 &self,
176 state: &State,
177 _sink: &tokio::sync::mpsc::Sender<GraphEvent>,
178 span_id: SpanId,
179 ) -> Result<StreamNodeResult, GraphError> {
180 let output = self.execute(state).await?;
181 Ok(StreamNodeResult::Continue {
182 deltas: output.deltas,
183 next: output.next,
184 span_id,
185 observed: None,
186 metadata: output.metadata,
187 })
188 }
189
190 fn metadata_hint(&self) -> NodeMetadata {
201 NodeMetadata::default()
202 }
203}
204
205#[derive(Clone)]
211pub enum NodeKind {
212 Task(TaskNode),
214 Condition(ConditionNode),
216 Barrier(BarrierNode),
218 Parallel(ParallelNode),
220 External(std::sync::Arc<dyn FlowNode>),
224}
225
226pub type TaskFn = Arc<dyn Fn(&State) -> Result<Vec<StateDelta>, GraphError> + Send + Sync>;
233
234pub type BranchCondition = Arc<dyn Fn(&State) -> bool + Send + Sync>;
237
#[derive(Clone)]
/// A node backed by a synchronous closure ([`TaskFn`]).
///
/// Cloning shares the closure, because `TaskFn` is an `Arc`.
pub struct TaskNode {
    /// Name of the node.
    pub name: String,
    /// Closure invoked by `execute`.
    pub func: TaskFn,
}
244
245impl TaskNode {
246 pub fn new(
247 name: impl Into<String>,
248 func: impl Fn(&State) -> Result<Vec<StateDelta>, GraphError> + Send + Sync + 'static,
249 ) -> Self {
250 Self {
251 name: name.into(),
252 func: Arc::new(func),
253 }
254 }
255}
256
257#[async_trait]
258impl FlowNode for TaskNode {
259 async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
260 let deltas = (self.func)(state)?;
261 Ok(NodeOutput {
262 deltas,
263 next: NextStep::GoToNext,
264 metadata: None,
265 })
266 }
267
268 fn metadata_hint(&self) -> NodeMetadata {
269 NodeMetadata {
271 token_cost: 0.0,
272 has_side_effects: false,
273 }
274 }
275}
276
#[derive(Clone)]
/// A routing node that jumps to the target of the first branch whose predicate holds.
///
/// Build one with [`ConditionNode::builder`]. When no branch matches, `execute`
/// yields [`NextStep::GoToNext`].
pub struct ConditionNode {
    /// Name of the node.
    pub name: String,
    /// `(target node name, predicate)` pairs, evaluated in order.
    pub branches: Vec<(String, BranchCondition)>,
}
288
289impl ConditionNode {
290 pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder {
291 ConditionNodeBuilder {
292 name: name.into(),
293 branches: Vec::new(),
294 }
295 }
296}
297
/// Incremental builder for [`ConditionNode`], obtained from [`ConditionNode::builder`].
pub struct ConditionNodeBuilder {
    /// Name given to the resulting node.
    name: String,
    /// Branches collected so far, in insertion order.
    branches: Vec<(String, BranchCondition)>,
}
303
304impl ConditionNodeBuilder {
305 pub fn branch(
306 mut self,
307 target: impl Into<String>,
308 condition: impl Fn(&State) -> bool + Send + Sync + 'static,
309 ) -> Self {
310 self.branches.push((target.into(), Arc::new(condition)));
311 self
312 }
313
314 pub fn build(self) -> ConditionNode {
315 ConditionNode {
316 name: self.name,
317 branches: self.branches,
318 }
319 }
320}
321
322#[async_trait]
323impl FlowNode for ConditionNode {
324 async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
325 for (target, condition) in &self.branches {
326 if condition(state) {
327 return Ok(NodeOutput::new(NextStep::Goto(target.clone())));
328 }
329 }
330 Ok(NodeOutput::new(NextStep::GoToNext))
332 }
333
334 fn metadata_hint(&self) -> NodeMetadata {
335 NodeMetadata {
337 token_cost: 0.0,
338 has_side_effects: false,
339 }
340 }
341}
342
343#[async_trait]
346impl FlowNode for NodeKind {
347 async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
348 match self {
349 Self::Task(n) => n.execute(state).await,
350 Self::Condition(n) => n.execute(state).await,
351 Self::Barrier(n) => n.execute(state).await,
352 Self::Parallel(n) => n.execute_sequential(state).await,
353 Self::External(n) => n.execute(state).await,
354 }
355 }
356
357 async fn execute_stream(
358 &self,
359 state: &State,
360 sink: &tokio::sync::mpsc::Sender<GraphEvent>,
361 span_id: SpanId,
362 ) -> Result<StreamNodeResult, GraphError> {
363 match self {
364 Self::Task(n) => n.execute_stream(state, sink, span_id).await,
365 Self::Condition(n) => n.execute_stream(state, sink, span_id).await,
366 Self::Barrier(n) => n.execute_stream(state, sink, span_id).await,
367 Self::Parallel(_) => {
368 let output = self.execute(state).await?;
371 Ok(StreamNodeResult::Continue {
372 deltas: output.deltas,
373 next: output.next,
374 span_id,
375 observed: None,
376 metadata: output.metadata,
377 })
378 }
379 Self::External(n) => n.execute_stream(state, sink, span_id).await,
380 }
381 }
382}
383
/// Shorthand for the [`FlowNode`] trait object.
///
/// The alias names an unsized type, so it is used behind a pointer
/// (e.g. `Arc<GraphNode>` or `&GraphNode`).
pub type GraphNode = dyn FlowNode;