potato_agent/agents/orchestration/
parallel.rs1use crate::agents::{
2 error::AgentError,
3 run_context::ResumeContext,
4 runner::{AgentRunOutcome, AgentRunResult, AgentRunner},
5 session::SessionState,
6};
7use async_trait::async_trait;
8use potato_util::create_uuid7;
9use serde_json::Value;
10use std::fmt::Debug;
11use std::sync::Arc;
/// How a `ParallelAgent` combines the results of its child agents.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum MergeStrategy {
    /// Wait for every child. Report the last child's response and, in
    /// `combined_text`, a JSON array of every child's response text in the
    /// order the children were added. This is the default.
    #[default]
    CollectAll,
    /// Wait for every child, but report only the result of the first child in
    /// insertion order (not the first one to finish). Its `combined_text` is
    /// cleared.
    First,
}

23#[derive(Debug)]
25pub struct ParallelAgent {
26 id: String,
27 agents: Vec<Arc<dyn AgentRunner>>,
28 strategy: MergeStrategy,
29}
30
31#[async_trait]
32impl AgentRunner for ParallelAgent {
33 fn id(&self) -> &str {
34 &self.id
35 }
36
37 async fn run(
38 &self,
39 input: &str,
40 session: &mut SessionState,
41 ) -> Result<AgentRunOutcome, AgentError> {
42 let mut handles = Vec::with_capacity(self.agents.len());
44
45 for agent in &self.agents {
46 let agent_clone = Arc::clone(agent);
47 let input_owned = input.to_string();
48 let mut child_session = SessionState::new();
50 child_session.merge(session.snapshot());
51
52 let handle = tokio::spawn(async move {
53 let result = agent_clone.run(&input_owned, &mut child_session).await;
54 (result, child_session.snapshot())
55 });
56 handles.push(handle);
57 }
58
59 let mut outcomes: Vec<AgentRunResult> = Vec::new();
60
61 for handle in handles {
62 let (outcome, child_snapshot) = handle
63 .await
64 .map_err(|e| AgentError::Error(format!("Parallel join error: {}", e)))?;
65
66 session.merge_user_data(child_snapshot);
68
69 match outcome? {
70 AgentRunOutcome::Complete(result) => {
71 outcomes.push(*result);
72 }
73 AgentRunOutcome::NeedsInput {
74 question,
75 resume_context,
76 } => {
77 return Ok(AgentRunOutcome::NeedsInput {
79 question,
80 resume_context,
81 });
82 }
83 }
84 }
85
86 if outcomes.is_empty() {
87 return Err(AgentError::Error(
88 "ParallelAgent: no agents produced results".to_string(),
89 ));
90 }
91
92 match self.strategy {
93 MergeStrategy::First => {
94 let mut result = outcomes.into_iter().next().unwrap();
95 result.combined_text = None;
96 Ok(AgentRunOutcome::complete(result))
97 }
98 MergeStrategy::CollectAll => {
99 let texts: Vec<Value> = outcomes
101 .iter()
102 .map(|r| Value::String(r.final_response.response_text()))
103 .collect();
104 let combined = Value::Array(texts).to_string();
105 let last = outcomes.into_iter().last().unwrap();
106 Ok(AgentRunOutcome::complete(AgentRunResult {
107 final_response: last.final_response,
108 iterations: last.iterations,
109 completion_reason: "all parallel agents completed".into(),
110 combined_text: Some(combined),
111 }))
112 }
113 }
114 }
115
116 async fn resume(
117 &self,
118 user_answer: &str,
119 ctx: ResumeContext,
120 session: &mut SessionState,
121 ) -> Result<AgentRunOutcome, AgentError> {
122 for agent in &self.agents {
123 if agent.id() == ctx.agent_id {
124 return agent.resume(user_answer, ctx, session).await;
125 }
126 }
127 Err(AgentError::Error(format!(
128 "No agent with id '{}' found in parallel group",
129 ctx.agent_id
130 )))
131 }
132}
133
134#[derive(Default)]
136pub struct ParallelAgentBuilder {
137 agents: Vec<Arc<dyn AgentRunner>>,
138 strategy: MergeStrategy,
139}
140
141impl ParallelAgentBuilder {
142 pub fn new() -> Self {
143 Self::default()
144 }
145
146 pub fn with_agent(mut self, agent: Arc<dyn AgentRunner>) -> Self {
147 self.agents.push(agent);
148 self
149 }
150
151 pub fn merge_strategy(mut self, strategy: MergeStrategy) -> Self {
152 self.strategy = strategy;
153 self
154 }
155
156 pub fn build(self) -> Arc<ParallelAgent> {
157 Arc::new(ParallelAgent {
158 id: create_uuid7(),
159 agents: self.agents,
160 strategy: self.strategy,
161 })
162 }
163}