1use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::error::{GraphError, ObservedError, TerminalError};
13use crate::event::{BarrierId, GraphEvent, SpanId};
14use crate::graph::Edge;
15use crate::state::State;
16
17pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
20pub use crate::llm_node::{AgentNode, LLMNode};
21pub use crate::tool_node::ToolNode;
22
/// Routing decision returned by a node after it has run.
///
/// - `Goto(target)`: jump to the node named `target`.
/// - `GoToNext`: continue with the next node in sequence.
/// - `End`: stop; `SubGraph::execute` treats this as "skip the remaining nodes".
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NextStep {
    /// Jump to the node with this name.
    Goto(String),
    /// Proceed to the following node.
    GoToNext,
    /// Stop executing the current sequence.
    End,
}
35
36#[derive(Debug)]
38pub enum StreamNodeResult {
39 Done {
41 next: NextStep,
43 span_id: SpanId,
45 },
46 BarrierPaused {
48 barrier_id: BarrierId,
50 node_name: String,
52 span_id: SpanId,
54 timeout: Option<std::time::Duration>,
56 default_action: crate::barrier_node::BarrierDefaultAction,
58 },
59 Observed {
65 error: ObservedError,
67 next: NextStep,
69 span_id: SpanId,
71 },
72}
73
74#[async_trait]
76pub trait GraphNode: Send + Sync {
77 async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError>;
79
80 async fn execute_stream(
89 &self,
90 state: &mut State,
91 _sink: &tokio::sync::mpsc::Sender<GraphEvent>,
92 span_id: SpanId,
93 ) -> Result<StreamNodeResult, GraphError> {
94 let next = self.execute(state).await?;
95 Ok(StreamNodeResult::Done { next, span_id })
96 }
97}
98
99pub enum NodeKind {
101 Task(TaskNode),
103 Agent(Box<AgentNode>),
105 Tool(ToolNode),
107 Condition(ConditionNode),
109 Loop(Box<LoopNode>),
111 Barrier(BarrierNode),
113}
114
115pub type TaskFn = Arc<dyn Fn(&mut State) -> Result<(), GraphError> + Send + Sync>;
120
121pub type BranchCondition = Arc<dyn Fn(&State) -> bool + Send + Sync>;
124
125pub struct TaskNode {
127 pub name: String,
128 pub func: TaskFn,
129}
130
131impl TaskNode {
132 pub fn new(
133 name: impl Into<String>,
134 func: impl Fn(&mut State) -> Result<(), GraphError> + Send + Sync + 'static,
135 ) -> Self {
136 Self {
137 name: name.into(),
138 func: Arc::new(func),
139 }
140 }
141}
142
143#[async_trait]
144impl GraphNode for TaskNode {
145 async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
146 (self.func)(state)?;
147 Ok(NextStep::GoToNext)
148 }
149}
150
151pub struct ConditionNode {
155 pub name: String,
156 pub branches: Vec<(String, BranchCondition)>,
157 pub otherwise_target: Option<String>,
160}
161
162impl ConditionNode {
163 pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder {
164 ConditionNodeBuilder {
165 name: name.into(),
166 branches: Vec::new(),
167 otherwise_target: None,
168 }
169 }
170}
171
172pub struct ConditionNodeBuilder {
174 name: String,
175 branches: Vec<(String, BranchCondition)>,
176 otherwise_target: Option<String>,
177}
178
179impl ConditionNodeBuilder {
180 pub fn branch(
181 mut self,
182 target: impl Into<String>,
183 condition: impl Fn(&State) -> bool + Send + Sync + 'static,
184 ) -> Self {
185 self.branches.push((target.into(), Arc::new(condition)));
186 self
187 }
188
189 pub fn otherwise(mut self, target: impl Into<String>) -> Self {
201 self.otherwise_target = Some(target.into());
202 self
203 }
204
205 pub fn build(self) -> ConditionNode {
206 ConditionNode {
207 name: self.name,
208 branches: self.branches,
209 otherwise_target: self.otherwise_target,
210 }
211 }
212}
213
214#[async_trait]
215impl GraphNode for ConditionNode {
216 async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
217 for (target, condition) in &self.branches {
218 if condition(state) {
219 return Ok(NextStep::Goto(target.clone()));
220 }
221 }
222 if let Some(ref target) = self.otherwise_target {
224 return Ok(NextStep::Goto(target.clone()));
225 }
226 Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
227 node: self.name.clone(),
228 source: "no matching branch and no otherwise target".into(),
229 }))
230 }
231}
232
233#[derive(Default)]
240pub struct SubGraph {
241 pub nodes: Vec<Arc<dyn GraphNode>>,
242 pub edges: Vec<Edge>,
243}
244
245impl SubGraph {
246 pub fn new() -> Self {
247 Self::default()
248 }
249
250 pub async fn execute(&self, state: &mut State) -> Result<(), GraphError> {
256 for node in &self.nodes {
257 match node.execute(state).await? {
258 NextStep::GoToNext => {
259 }
261 NextStep::End => {
262 break;
264 }
265 NextStep::Goto(target) => {
266 return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
267 "SubGraph does not support Goto(\"{}\"). Use Graph::edge_if for conditional jumps.",
268 target
269 ))));
270 }
271 }
272 }
273 Ok(())
274 }
275}
276
277pub struct LoopNode {
294 pub name: String,
295 pub body: SubGraph,
296 pub continue_condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
297 pub max_iterations: usize,
298}
299
300impl LoopNode {
301 pub fn new(
302 name: impl Into<String>,
303 body: SubGraph,
304 continue_condition: impl Fn(&State) -> bool + Send + Sync + 'static,
305 max_iterations: usize,
306 ) -> Self {
307 Self {
308 name: name.into(),
309 body,
310 continue_condition: Arc::new(continue_condition),
311 max_iterations,
312 }
313 }
314}
315
316#[async_trait]
317impl GraphNode for LoopNode {
318 async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
319 for i in 0..self.max_iterations {
320 tracing::debug!(
321 loop_name = %self.name,
322 iteration = i + 1,
323 max = self.max_iterations,
324 "executing loop body"
325 );
326
327 self.body.execute(state).await?;
328
329 if !(self.continue_condition)(state) {
330 tracing::debug!(
331 loop_name = %self.name,
332 iterations = i + 1,
333 "loop condition met, exiting"
334 );
335 return Ok(NextStep::GoToNext);
336 }
337 }
338
339 Err(GraphError::Terminal(TerminalError::LoopLimitExceeded {
340 limit: self.max_iterations,
341 }))
342 }
343}
344
345#[async_trait]
348impl GraphNode for NodeKind {
349 async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
350 match self {
351 Self::Task(n) => n.execute(state).await,
352 Self::Agent(n) => n.execute(state).await,
353 Self::Tool(n) => n.execute(state).await,
354 Self::Condition(n) => n.execute(state).await,
355 Self::Loop(n) => n.execute(state).await,
356 Self::Barrier(n) => n.execute(state).await,
357 }
358 }
359
360 async fn execute_stream(
361 &self,
362 state: &mut State,
363 sink: &tokio::sync::mpsc::Sender<GraphEvent>,
364 span_id: SpanId,
365 ) -> Result<StreamNodeResult, GraphError> {
366 match self {
367 Self::Task(n) => n.execute_stream(state, sink, span_id).await,
368 Self::Agent(n) => n.execute_stream(state, sink, span_id).await,
369 Self::Tool(n) => n.execute_stream(state, sink, span_id).await,
370 Self::Condition(n) => n.execute_stream(state, sink, span_id).await,
371 Self::Loop(n) => n.execute_stream(state, sink, span_id).await,
372 Self::Barrier(n) => n.execute_stream(state, sink, span_id).await,
373 }
374 }
375}