Skip to main content

potato_agent/agents/orchestration/
parallel.rs

1use crate::agents::{
2    error::AgentError,
3    run_context::ResumeContext,
4    runner::{AgentRunOutcome, AgentRunResult, AgentRunner},
5    session::SessionState,
6};
7use async_trait::async_trait;
8use potato_util::create_uuid7;
9use serde_json::Value;
10use std::fmt::Debug;
11use std::sync::Arc;
12
/// How to combine results from parallel agents.
///
/// Defaults to [`MergeStrategy::CollectAll`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum MergeStrategy {
    /// Collect every agent's response text, in the order the agents were registered,
    /// into a JSON array that is returned as the result's `combined_text`.
    #[default]
    CollectAll,
    /// Return the result of the first agent that completes (with `combined_text` unset).
    First,
}
22
23/// Runs all agents concurrently; merges results according to `strategy`.
24#[derive(Debug)]
25pub struct ParallelAgent {
26    id: String,
27    agents: Vec<Arc<dyn AgentRunner>>,
28    strategy: MergeStrategy,
29}
30
31#[async_trait]
32impl AgentRunner for ParallelAgent {
33    fn id(&self) -> &str {
34        &self.id
35    }
36
37    async fn run(
38        &self,
39        input: &str,
40        session: &mut SessionState,
41    ) -> Result<AgentRunOutcome, AgentError> {
42        // Each child gets a snapshot clone of the session; we merge back after join.
43        let mut handles = Vec::with_capacity(self.agents.len());
44
45        for agent in &self.agents {
46            let agent_clone = Arc::clone(agent);
47            let input_owned = input.to_string();
48            // Give each child a fresh session clone
49            let mut child_session = SessionState::new();
50            child_session.merge(session.snapshot());
51
52            let handle = tokio::spawn(async move {
53                let result = agent_clone.run(&input_owned, &mut child_session).await;
54                (result, child_session.snapshot())
55            });
56            handles.push(handle);
57        }
58
59        let mut outcomes: Vec<AgentRunResult> = Vec::new();
60
61        for handle in handles {
62            let (outcome, child_snapshot) = handle
63                .await
64                .map_err(|e| AgentError::Error(format!("Parallel join error: {}", e)))?;
65
66            // Merge child session back into parent, skipping system keys like __ancestor_ids
67            session.merge_user_data(child_snapshot);
68
69            match outcome? {
70                AgentRunOutcome::Complete(result) => {
71                    outcomes.push(*result);
72                }
73                AgentRunOutcome::NeedsInput {
74                    question,
75                    resume_context,
76                } => {
77                    // Abort on first NeedsInput — cannot continue without user input
78                    return Ok(AgentRunOutcome::NeedsInput {
79                        question,
80                        resume_context,
81                    });
82                }
83            }
84        }
85
86        if outcomes.is_empty() {
87            return Err(AgentError::Error(
88                "ParallelAgent: no agents produced results".to_string(),
89            ));
90        }
91
92        match self.strategy {
93            MergeStrategy::First => {
94                let mut result = outcomes.into_iter().next().unwrap();
95                result.combined_text = None;
96                Ok(AgentRunOutcome::complete(result))
97            }
98            MergeStrategy::CollectAll => {
99                // Combine all text responses into a JSON array
100                let texts: Vec<Value> = outcomes
101                    .iter()
102                    .map(|r| Value::String(r.final_response.response_text()))
103                    .collect();
104                let combined = Value::Array(texts).to_string();
105                let last = outcomes.into_iter().last().unwrap();
106                Ok(AgentRunOutcome::complete(AgentRunResult {
107                    final_response: last.final_response,
108                    iterations: last.iterations,
109                    completion_reason: "all parallel agents completed".into(),
110                    combined_text: Some(combined),
111                }))
112            }
113        }
114    }
115
116    async fn resume(
117        &self,
118        user_answer: &str,
119        ctx: ResumeContext,
120        session: &mut SessionState,
121    ) -> Result<AgentRunOutcome, AgentError> {
122        for agent in &self.agents {
123            if agent.id() == ctx.agent_id {
124                return agent.resume(user_answer, ctx, session).await;
125            }
126        }
127        Err(AgentError::Error(format!(
128            "No agent with id '{}' found in parallel group",
129            ctx.agent_id
130        )))
131    }
132}
133
134/// Builder for `ParallelAgent`.
135#[derive(Default)]
136pub struct ParallelAgentBuilder {
137    agents: Vec<Arc<dyn AgentRunner>>,
138    strategy: MergeStrategy,
139}
140
141impl ParallelAgentBuilder {
142    pub fn new() -> Self {
143        Self::default()
144    }
145
146    pub fn with_agent(mut self, agent: Arc<dyn AgentRunner>) -> Self {
147        self.agents.push(agent);
148        self
149    }
150
151    pub fn merge_strategy(mut self, strategy: MergeStrategy) -> Self {
152        self.strategy = strategy;
153        self
154    }
155
156    pub fn build(self) -> Arc<ParallelAgent> {
157        Arc::new(ParallelAgent {
158            id: create_uuid7(),
159            agents: self.agents,
160            strategy: self.strategy,
161        })
162    }
163}