1use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use async_trait::async_trait;
12use futures::StreamExt;
13
14use super::base::{
15 InterruptState, InvocationState, MultiAgentBase, MultiAgentEvent,
16 MultiAgentEventStream, MultiAgentInput, MultiAgentResult, NodeResult, NodeResultValue, Status,
17};
18use crate::agent::Agent;
19use crate::hooks::{
20 AfterInvocationEvent, AfterToolCallEvent, BeforeInvocationEvent, BeforeToolCallEvent,
21 HookEvent, HookRegistry,
22};
23use crate::types::tools::{ToolResult as ToolResultType, ToolUse};
24use crate::types::errors::{Result, StrandsError};
25use crate::types::streaming::{Metrics, Usage};
26
27pub type EdgeCondition = Arc<dyn Fn(&GraphState) -> bool + Send + Sync>;
29
30pub struct GraphEdge {
32 pub from_node: String,
33 pub to_node: String,
34 pub condition: Option<EdgeCondition>,
35}
36
37impl GraphEdge {
38 pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
40 Self {
41 from_node: from.into(),
42 to_node: to.into(),
43 condition: None,
44 }
45 }
46
47 pub fn conditional(
49 from: impl Into<String>,
50 to: impl Into<String>,
51 condition: impl Fn(&GraphState) -> bool + Send + Sync + 'static,
52 ) -> Self {
53 Self {
54 from_node: from.into(),
55 to_node: to.into(),
56 condition: Some(Arc::new(condition)),
57 }
58 }
59
60 pub fn should_traverse(&self, state: &GraphState) -> bool {
62 match &self.condition {
63 Some(cond) => cond(state),
64 None => true,
65 }
66 }
67}
68
69pub struct GraphNode {
71 pub node_id: String,
72 pub agent: Agent,
73 pub dependencies: HashSet<String>,
74 pub status: Status,
75 pub result: Option<NodeResult>,
76 pub execution_time_ms: u64,
77}
78
79impl GraphNode {
80 pub fn new(node_id: impl Into<String>, agent: Agent) -> Self {
81 Self {
82 node_id: node_id.into(),
83 agent,
84 dependencies: HashSet::new(),
85 status: Status::Pending,
86 result: None,
87 execution_time_ms: 0,
88 }
89 }
90
91 pub fn reset(&mut self) {
93 self.status = Status::Pending;
94 self.result = None;
95 self.execution_time_ms = 0;
96 }
97}
98
99#[derive(Debug, Clone, Default)]
101pub struct GraphState {
102 pub status: Status,
103 pub task: String,
104 pub completed_nodes: HashSet<String>,
105 pub failed_nodes: HashSet<String>,
106 pub execution_order: Vec<String>,
107 pub results: HashMap<String, NodeResult>,
108 pub accumulated_usage: Usage,
109 pub accumulated_metrics: Metrics,
110 pub execution_count: u32,
111 pub execution_time_ms: u64,
112 pub start_time: Option<Instant>,
113 pub total_nodes: usize,
114}
115
116impl GraphState {
117 pub fn should_continue(
119 &self,
120 max_node_executions: Option<usize>,
121 execution_timeout: Option<Duration>,
122 ) -> (bool, &'static str) {
123 if let Some(max) = max_node_executions {
124 if self.execution_order.len() >= max {
125 return (false, "Max node executions reached");
126 }
127 }
128
129 if let (Some(timeout), Some(start)) = (execution_timeout, self.start_time) {
130 if start.elapsed() > timeout {
131 return (false, "Execution timed out");
132 }
133 }
134
135 (true, "Continuing")
136 }
137}
138
139#[derive(Debug, Clone)]
141pub struct GraphResult {
142 pub status: Status,
143 pub results: HashMap<String, NodeResult>,
144 pub execution_order: Vec<String>,
145 pub accumulated_usage: Usage,
146 pub accumulated_metrics: Metrics,
147 pub execution_time_ms: u64,
148 pub total_nodes: usize,
149 pub completed_nodes: usize,
150 pub failed_nodes: usize,
151 pub entry_points: Vec<String>,
152}
153
154impl From<GraphResult> for MultiAgentResult {
155 fn from(gr: GraphResult) -> Self {
156 MultiAgentResult {
157 status: gr.status,
158 results: gr.results,
159 accumulated_usage: gr.accumulated_usage,
160 accumulated_metrics: gr.accumulated_metrics,
161 execution_count: gr.execution_order.len() as u32,
162 execution_time_ms: gr.execution_time_ms,
163 interrupts: Vec::new(),
164 }
165 }
166}
167
/// Execution limits and revisit policy for a graph run.
#[derive(Debug, Clone)]
pub struct GraphConfig {
    /// Maximum number of node executions per run; `None` disables the limit.
    pub max_node_executions: Option<usize>,
    /// Wall-clock budget for a whole run; `None` disables the limit.
    pub execution_timeout: Option<Duration>,
    /// Per-node time budget; `None` disables the limit.
    pub node_timeout: Option<Duration>,
    /// Whether a completed node is reset before it is executed again.
    pub reset_on_revisit: bool,
}

impl Default for GraphConfig {
    /// 100 node executions, a 15-minute run budget, a 5-minute node budget,
    /// and no reset on revisit.
    fn default() -> Self {
        GraphConfig {
            reset_on_revisit: false,
            node_timeout: Some(Duration::from_secs(5 * 60)),
            execution_timeout: Some(Duration::from_secs(15 * 60)),
            max_node_executions: Some(100),
        }
    }
}
187
188pub struct GraphBuilder {
190 nodes: HashMap<String, GraphNode>,
191 edges: Vec<GraphEdge>,
192 entry_points: HashSet<String>,
193 config: GraphConfig,
194 id: String,
195 hooks: HookRegistry,
196}
197
198impl Default for GraphBuilder {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl GraphBuilder {
205 pub fn new() -> Self {
206 Self {
207 nodes: HashMap::new(),
208 edges: Vec::new(),
209 entry_points: HashSet::new(),
210 config: GraphConfig::default(),
211 id: "default_graph".to_string(),
212 hooks: HookRegistry::new(),
213 }
214 }
215
216 pub fn id(mut self, id: impl Into<String>) -> Self {
218 self.id = id.into();
219 self
220 }
221
222 pub fn add_node(mut self, node_id: impl Into<String>, agent: Agent) -> Self {
224 let node_id = node_id.into();
225 self.nodes.insert(node_id.clone(), GraphNode::new(node_id, agent));
226 self
227 }
228
229 pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
231 let from = from.into();
232 let to = to.into();
233
234 if let Some(node) = self.nodes.get_mut(&to) {
235 node.dependencies.insert(from.clone());
236 }
237
238 self.edges.push(GraphEdge::new(from, to));
239 self
240 }
241
242 pub fn add_conditional_edge<F>(
244 mut self,
245 from: impl Into<String>,
246 to: impl Into<String>,
247 condition: F,
248 ) -> Self
249 where
250 F: Fn(&GraphState) -> bool + Send + Sync + 'static,
251 {
252 let from = from.into();
253 let to = to.into();
254
255 if let Some(node) = self.nodes.get_mut(&to) {
256 node.dependencies.insert(from.clone());
257 }
258
259 self.edges.push(GraphEdge::conditional(from, to, condition));
260 self
261 }
262
263 pub fn set_entry_points(mut self, entry_points: impl IntoIterator<Item = impl Into<String>>) -> Self {
265 self.entry_points = entry_points.into_iter().map(Into::into).collect();
266 self
267 }
268
269 pub fn set_entry_point(mut self, node_id: impl Into<String>) -> Self {
271 self.entry_points.insert(node_id.into());
272 self
273 }
274
275 pub fn config(mut self, config: GraphConfig) -> Self {
277 self.config = config;
278 self
279 }
280
281 pub fn max_node_executions(mut self, max: usize) -> Self {
283 self.config.max_node_executions = Some(max);
284 self
285 }
286
287 pub fn execution_timeout(mut self, timeout: Duration) -> Self {
289 self.config.execution_timeout = Some(timeout);
290 self
291 }
292
293 pub fn node_timeout(mut self, timeout: Duration) -> Self {
295 self.config.node_timeout = Some(timeout);
296 self
297 }
298
299 pub fn reset_on_revisit(mut self, enabled: bool) -> Self {
301 self.config.reset_on_revisit = enabled;
302 self
303 }
304
305 pub fn hooks(mut self, hooks: HookRegistry) -> Self {
307 self.hooks = hooks;
308 self
309 }
310
311 pub fn build(self) -> Result<Graph> {
313 if self.nodes.is_empty() {
314 return Err(StrandsError::ConfigurationError {
315 message: "Graph must have at least one node".to_string(),
316 });
317 }
318
319 let entry_points = if self.entry_points.is_empty() {
320 self.nodes
321 .values()
322 .filter(|n| n.dependencies.is_empty())
323 .map(|n| n.node_id.clone())
324 .collect()
325 } else {
326 self.entry_points
327 };
328
329 if entry_points.is_empty() {
330 return Err(StrandsError::ConfigurationError {
331 message: "Graph has no entry points (all nodes have dependencies)".to_string(),
332 });
333 }
334
335 Ok(Graph {
336 id: self.id,
337 nodes: self.nodes,
338 edges: self.edges,
339 entry_points,
340 config: self.config,
341 state: GraphState::default(),
342 hooks: self.hooks,
343 interrupt_state: InterruptState::new(),
344 })
345 }
346}
347
348pub struct Graph {
350 id: String,
351 nodes: HashMap<String, GraphNode>,
352 edges: Vec<GraphEdge>,
353 entry_points: HashSet<String>,
354 config: GraphConfig,
355 state: GraphState,
356 hooks: HookRegistry,
357 interrupt_state: InterruptState,
358}
359
360impl Graph {
361 pub fn builder() -> GraphBuilder {
363 GraphBuilder::new()
364 }
365
366 pub fn graph_id(&self) -> &str {
368 &self.id
369 }
370
371 pub fn state(&self) -> &GraphState {
373 &self.state
374 }
375
376 pub fn node_ids(&self) -> impl Iterator<Item = &str> {
378 self.nodes.keys().map(|s| s.as_str())
379 }
380
381 pub fn entry_points(&self) -> &HashSet<String> {
383 &self.entry_points
384 }
385
386 pub fn interrupt_state(&self) -> &InterruptState {
388 &self.interrupt_state
389 }
390
391 pub fn interrupt_state_mut(&mut self) -> &mut InterruptState {
393 &mut self.interrupt_state
394 }
395
396
397 pub fn call(&mut self, task: impl Into<MultiAgentInput>) -> Result<GraphResult> {
399 tokio::task::block_in_place(|| {
400 tokio::runtime::Handle::current().block_on(self.invoke_async(task.into(), None))
401 })
402 }
403
404 pub async fn invoke_async(
406 &mut self,
407 task: MultiAgentInput,
408 invocation_state: Option<&InvocationState>,
409 ) -> Result<GraphResult> {
410 let total_nodes = self.nodes.len();
411 let entry_points_vec: Vec<String> = self.entry_points.iter().cloned().collect();
412
413 let mut stream = self.stream_async(task, invocation_state);
414 let mut final_result = None;
415
416 while let Some(event) = stream.next().await {
417 if let MultiAgentEvent::Result(result) = event {
418 final_result = Some(result);
419 }
420 }
421
422 drop(stream);
423
424 final_result
425 .map(|r| GraphResult {
426 status: r.status,
427 results: r.results,
428 execution_order: self.state.execution_order.clone(),
429 accumulated_usage: r.accumulated_usage,
430 accumulated_metrics: r.accumulated_metrics,
431 execution_time_ms: r.execution_time_ms,
432 total_nodes,
433 completed_nodes: self.state.completed_nodes.len(),
434 failed_nodes: self.state.failed_nodes.len(),
435 entry_points: entry_points_vec,
436 })
437 .ok_or_else(|| StrandsError::MultiAgentError {
438 message: "Graph execution completed without result".to_string(),
439 })
440 }
441
442 pub fn stream_async<'a>(
444 &'a mut self,
445 task: MultiAgentInput,
446 _invocation_state: Option<&'a InvocationState>,
447 ) -> MultiAgentEventStream<'a> {
448 let task_str = task.to_string_lossy();
449
450 Box::pin(async_stream::stream! {
451 self.hooks.invoke(&HookEvent::BeforeInvocation(BeforeInvocationEvent)).await;
452
453 self.state = GraphState {
454 status: Status::Executing,
455 task: task_str.clone(),
456 start_time: Some(Instant::now()),
457 total_nodes: self.nodes.len(),
458 ..Default::default()
459 };
460
461 let mut queue: VecDeque<String> = self.entry_points.iter().cloned().collect();
462 let mut processed: HashSet<String> = HashSet::new();
463
464 while let Some(node_id) = queue.pop_front() {
465 if processed.contains(&node_id) {
466 continue;
467 }
468
469 let (should_continue, reason) = self.state.should_continue(
470 self.config.max_node_executions,
471 self.config.execution_timeout,
472 );
473 if !should_continue {
474 tracing::warn!("Graph execution stopped: {reason}");
475 self.state.status = Status::Failed;
476 break;
477 }
478
479 let deps_met = {
480 if let Some(node) = self.nodes.get(&node_id) {
481 node.dependencies.iter().all(|dep| self.state.completed_nodes.contains(dep))
482 } else {
483 false
484 }
485 };
486
487 if !deps_met {
488 queue.push_back(node_id);
489 continue;
490 }
491
492 if self.config.reset_on_revisit && self.state.completed_nodes.contains(&node_id) {
493 if let Some(node) = self.nodes.get_mut(&node_id) {
494 node.reset();
495 }
496 self.state.completed_nodes.remove(&node_id);
497 }
498
499 yield MultiAgentEvent::node_start(&node_id, "agent");
500
501 self.hooks.invoke(&HookEvent::BeforeToolCall(BeforeToolCallEvent::new(
502 ToolUse::new(&node_id, &node_id, serde_json::json!({}))
503 ))).await;
504
505 let result = self.execute_node(&node_id, &task_str).await;
506
507 match result {
508 Ok(node_result) => {
509
510 if node_result.status == Status::Interrupted {
511 self.interrupt_state.deactivate();
512 tracing::error!("user raised interrupt from agent | interrupts are not yet supported in graphs");
513 self.state.status = Status::Failed;
514 yield MultiAgentEvent::node_stop(&node_id, node_result);
515 break;
516 }
517
518 self.state.completed_nodes.insert(node_id.clone());
519 self.state.execution_order.push(node_id.clone());
520 self.state.accumulated_usage.add(&node_result.accumulated_usage);
521 self.state.accumulated_metrics.latency_ms += node_result.accumulated_metrics.latency_ms;
522 self.state.execution_count += 1;
523
524 if let Some(node) = self.nodes.get_mut(&node_id) {
525 node.status = Status::Completed;
526 node.execution_time_ms = node_result.execution_time_ms;
527 }
528
529 yield MultiAgentEvent::node_stop(&node_id, node_result.clone());
530
531 self.state.results.insert(node_id.clone(), node_result);
532
533 let mut next_nodes = Vec::new();
534 for edge in &self.edges {
535 if edge.from_node == node_id && edge.should_traverse(&self.state) {
536 if !processed.contains(&edge.to_node) {
537 next_nodes.push(edge.to_node.clone());
538 }
539 }
540 }
541
542 if !next_nodes.is_empty() {
543 yield MultiAgentEvent::handoff(
544 vec![node_id.clone()],
545 next_nodes.clone(),
546 None,
547 );
548 for next in next_nodes {
549 queue.push_back(next);
550 }
551 }
552 }
553 Err(e) => {
554 tracing::error!("Node {node_id} failed: {e}");
555 self.state.failed_nodes.insert(node_id.clone());
556 if let Some(node) = self.nodes.get_mut(&node_id) {
557 node.status = Status::Failed;
558 }
559
560 let error_result = NodeResult::from_error(e.to_string(), 0);
561 yield MultiAgentEvent::node_stop(&node_id, error_result);
562 }
563 }
564
565 self.hooks.invoke(&HookEvent::AfterToolCall(AfterToolCallEvent::new(
566 ToolUse::new(&node_id, &node_id, serde_json::json!({})),
567 ToolResultType::success(&node_id, "completed")
568 ))).await;
569 processed.insert(node_id);
570 }
571
572 if self.state.failed_nodes.is_empty() && self.state.status == Status::Executing {
573 self.state.status = Status::Completed;
574 } else if !self.state.failed_nodes.is_empty() {
575 self.state.status = Status::Failed;
576 }
577
578 self.state.execution_time_ms = self.state.start_time
579 .map(|s| s.elapsed().as_millis() as u64)
580 .unwrap_or(0);
581
582 self.hooks.invoke(&HookEvent::AfterInvocation(AfterInvocationEvent::new(None))).await;
583
584 let result = MultiAgentResult {
585 status: self.state.status,
586 results: self.state.results.clone(),
587 accumulated_usage: self.state.accumulated_usage.clone(),
588 accumulated_metrics: self.state.accumulated_metrics.clone(),
589 execution_count: self.state.execution_count,
590 execution_time_ms: self.state.execution_time_ms,
591 interrupts: Vec::new(),
592 };
593
594 yield MultiAgentEvent::result(result);
595 })
596 }
597
598 async fn execute_node(&mut self, node_id: &str, task: &str) -> Result<NodeResult> {
599 let start = Instant::now();
600
601 let input = self.build_node_input(node_id, task);
602
603 let node = self.nodes.get_mut(node_id).ok_or_else(|| StrandsError::InternalError {
604 message: format!("Node '{node_id}' not found"),
605 })?;
606
607 node.status = Status::Executing;
608
609 let agent_result = node.agent.invoke_async(input.as_str()).await?;
610 let execution_time_ms = start.elapsed().as_millis() as u64;
611
612 let usage = agent_result.usage.clone();
613
614 Ok(NodeResult {
615 result: NodeResultValue::Agent(agent_result),
616 execution_time_ms,
617 status: Status::Completed,
618 accumulated_usage: usage,
619 accumulated_metrics: Metrics { latency_ms: execution_time_ms, time_to_first_byte_ms: 0 },
620 execution_count: 1,
621 interrupts: Vec::new(),
622 })
623 }
624
625 fn build_node_input(&self, node_id: &str, task: &str) -> String {
626 let mut input = String::new();
627
628 let node = match self.nodes.get(node_id) {
629 Some(n) => n,
630 None => {
631 input.push_str(&format!("Task: {task}"));
632 return input;
633 }
634 };
635
636 if node.dependencies.is_empty() {
637 input.push_str(&format!("Task: {task}"));
638 } else {
639 input.push_str(&format!("Original Task: {task}\n\n"));
640 input.push_str("Inputs from previous nodes:\n\n");
641
642 for dep in &node.dependencies {
643 if let Some(result) = self.state.results.get(dep) {
644 input.push_str(&format!("From {dep}:\n"));
645 for agent_result in result.get_agent_results() {
646 let text = agent_result.text();
647 input.push_str(&format!(" - Agent: {text}\n"));
648 }
649 }
650 }
651 }
652
653 input
654 }
655}
656
657#[async_trait]
658impl MultiAgentBase for Graph {
659 fn id(&self) -> &str {
660 &self.id
661 }
662
663 async fn invoke_async(
664 &mut self,
665 task: MultiAgentInput,
666 invocation_state: Option<&InvocationState>,
667 ) -> Result<MultiAgentResult> {
668 self.invoke_async(task, invocation_state).await.map(Into::into)
669 }
670
671 fn stream_async<'a>(
672 &'a mut self,
673 task: MultiAgentInput,
674 invocation_state: Option<&'a InvocationState>,
675 ) -> MultiAgentEventStream<'a> {
676 self.stream_async(task, invocation_state)
677 }
678
679 fn serialize_state(&self) -> serde_json::Value {
680 serde_json::json!({
681 "type": "graph",
682 "id": self.id,
683 "status": format!("{:?}", self.state.status).to_lowercase(),
684 "completed_nodes": self.state.completed_nodes.iter().collect::<Vec<_>>(),
685 "failed_nodes": self.state.failed_nodes.iter().collect::<Vec<_>>(),
686 "execution_order": self.state.execution_order,
687 "current_task": self.state.task,
688 })
689 }
690
691 fn deserialize_state(&mut self, payload: &serde_json::Value) -> Result<()> {
692 if let Some(status_str) = payload.get("status").and_then(|v| v.as_str()) {
693 self.state.status = match status_str {
694 "pending" => Status::Pending,
695 "executing" => Status::Executing,
696 "completed" => Status::Completed,
697 "failed" => Status::Failed,
698 "interrupted" => Status::Interrupted,
699 _ => Status::Pending,
700 };
701 }
702
703 if let Some(completed) = payload.get("completed_nodes").and_then(|v| v.as_array()) {
704 self.state.completed_nodes = completed
705 .iter()
706 .filter_map(|v| v.as_str().map(|s| s.to_string()))
707 .collect();
708 }
709
710 if let Some(task) = payload.get("current_task").and_then(|v| v.as_str()) {
711 self.state.task = task.to_string();
712 }
713
714 Ok(())
715 }
716}
717
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_no_nodes() {
        let result = Graph::builder().build();
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_state_should_continue() {
        let state = GraphState::default();
        let (should_continue, _) = state.should_continue(Some(10), None);
        assert!(should_continue);

        let mut state = GraphState::default();
        state.execution_order = vec!["a".to_string(); 10];
        let (should_continue, reason) = state.should_continue(Some(10), None);
        assert!(!should_continue);
        assert_eq!(reason, "Max node executions reached");
    }

    #[test]
    fn test_graph_state_should_continue_timeout() {
        let mut state = GraphState::default();
        state.start_time = Some(
            Instant::now()
                .checked_sub(Duration::from_millis(50))
                .expect("monotonic clock should be at least 50ms past its origin"),
        );
        let (should_continue, reason) =
            state.should_continue(None, Some(Duration::from_millis(1)));
        assert!(!should_continue);
        assert_eq!(reason, "Execution timed out");
    }

    #[test]
    fn test_graph_state_without_limits_continues() {
        // No start time recorded: the timeout cannot fire.
        let state = GraphState::default();
        let (should_continue, reason) =
            state.should_continue(None, Some(Duration::from_millis(0)));
        assert!(should_continue);
        assert_eq!(reason, "Continuing");
    }

    #[test]
    fn test_graph_config_default() {
        let config = GraphConfig::default();
        assert_eq!(config.max_node_executions, Some(100));
        assert_eq!(config.execution_timeout, Some(Duration::from_secs(900)));
        assert_eq!(config.node_timeout, Some(Duration::from_secs(300)));
        assert!(!config.reset_on_revisit);
    }

    #[test]
    fn test_edge_traversal() {
        let mut state = GraphState::default();
        assert!(GraphEdge::new("a", "b").should_traverse(&state));
        assert!(!GraphEdge::conditional("a", "b", |_| false).should_traverse(&state));

        let edge = GraphEdge::conditional("a", "b", |s: &GraphState| {
            s.completed_nodes.contains("a")
        });
        assert!(!edge.should_traverse(&state));
        state.completed_nodes.insert("a".to_string());
        assert!(edge.should_traverse(&state));
    }

    #[test]
    fn test_node_result() {
        let result = NodeResult::from_error("test error", 100);
        assert!(result.is_error());
        assert_eq!(result.execution_time_ms, 100);
    }

    #[test]
    fn test_status_default() {
        let status = Status::default();
        assert_eq!(status, Status::Pending);
    }
}