claude_agent_sdk/orchestration/patterns/
parallel.rs

//! # Parallel Orchestration Pattern
//!
//! Multiple agents execute in parallel on the same input, and their outputs are aggregated.
//!
//! ```text
//!        ┌→ Agent A ─┐
//! Input ─┼→ Agent B ─┼→ Aggregator → Output
//!        └→ Agent C ─┘
//! ```
//!
//! Use cases:
//! - Multi-angle analysis
//! - Parallel task processing
//! - Lower latency by running independent agents concurrently
15
16use crate::orchestration::{
17    Result,
18    agent::{Agent, AgentInput, AgentOutput},
19    context::{AgentExecution, ExecutionContext},
20    orchestrator::{BaseOrchestrator, Orchestrator, OrchestratorInput, OrchestratorOutput},
21};
22use futures::future::join_all;
23use std::sync::Arc;
24use tokio::sync::Semaphore;
25
26/// Parallel orchestrator that executes agents concurrently
27pub struct ParallelOrchestrator {
28    base: BaseOrchestrator,
29    max_retries: usize,
30    parallel_limit: usize,
31}
32
33impl ParallelOrchestrator {
34    /// Create a new parallel orchestrator
35    pub fn new() -> Self {
36        Self {
37            base: BaseOrchestrator::new(
38                "ParallelOrchestrator",
39                "Executes agents in parallel and aggregates their outputs",
40            ),
41            max_retries: 3,
42            parallel_limit: 10,
43        }
44    }
45
46    /// Set max retries per agent
47    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
48        self.max_retries = max_retries;
49        self
50    }
51
52    /// Set parallel execution limit
53    pub fn with_parallel_limit(mut self, limit: usize) -> Self {
54        self.parallel_limit = limit;
55        self
56    }
57
58    /// Execute agents in parallel
59    async fn execute_parallel(
60        &self,
61        agents: Vec<Box<dyn Agent>>,
62        input: AgentInput,
63        ctx: &ExecutionContext,
64    ) -> Result<Vec<AgentOutput>> {
65        let semaphore = Arc::new(Semaphore::new(self.parallel_limit));
66        let agents_count = agents.len();
67        let mut futures = Vec::new();
68
69        for (index, agent) in agents.iter().enumerate() {
70            let agent_ref = agent.as_ref();
71            let input_clone = input.clone();
72            let semaphore_clone = semaphore.clone();
73            let ctx_clone = ctx.clone();
74            let base_name = self.base.name().to_string();
75
76            let future = async move {
77                // Acquire semaphore permit
78                let _permit = semaphore_clone.acquire().await.unwrap();
79
80                // Create execution record
81                let mut exec_record = AgentExecution::new(agent_ref.name(), input_clone.clone());
82
83                if ctx_clone.is_logging_enabled() {
84                    println!(
85                        "[{}] Executing agent {}/{}: {}",
86                        base_name,
87                        index + 1,
88                        agents_count,
89                        agent_ref.name()
90                    );
91                }
92
93                // Execute agent with retry
94                let output =
95                    Self::execute_agent_with_retry_static(agent_ref, input_clone, self.max_retries)
96                        .await;
97
98                let success = output.is_successful();
99
100                if success {
101                    exec_record.succeed(output.clone());
102                } else {
103                    exec_record.fail(output.content.clone());
104                }
105
106                // Add to trace if enabled
107                if ctx_clone.is_tracing_enabled() {
108                    ctx_clone.add_execution(exec_record).await;
109                }
110
111                (agent_ref.name().to_string(), output, success)
112            };
113
114            futures.push(future);
115        }
116
117        // Wait for all agents to complete
118        let results = join_all(futures).await;
119
120        // Check for failures and collect outputs
121        let mut outputs = Vec::new();
122        let mut failed_agents = Vec::new();
123
124        for (agent_name, output, success) in results {
125            if success {
126                outputs.push(output);
127            } else {
128                failed_agents.push(agent_name);
129            }
130        }
131
132        // If any agents failed, return error
133        if !failed_agents.is_empty() {
134            return Err(
135                crate::orchestration::errors::OrchestrationError::agent_failure(
136                    failed_agents.join(", "),
137                    "Execution failed",
138                ),
139            );
140        }
141
142        Ok(outputs)
143    }
144
145    // Static version for use in async block
146    async fn execute_agent_with_retry_static(
147        agent: &dyn Agent,
148        input: AgentInput,
149        max_retries: usize,
150    ) -> AgentOutput {
151        let mut last_error = None;
152
153        for attempt in 0..=max_retries {
154            match agent.execute(input.clone()).await {
155                Ok(output) => return output,
156                Err(e) => {
157                    last_error = Some(e.to_string());
158                    if attempt < max_retries {
159                        tokio::time::sleep(std::time::Duration::from_millis(
160                            100 * 2_u64.pow(attempt as u32),
161                        ))
162                        .await;
163                    }
164                },
165            }
166        }
167
168        // All retries failed
169        AgentOutput::new(format!(
170            "Agent {} failed after {} retries: {}",
171            agent.name(),
172            max_retries,
173            last_error.unwrap_or_else(|| "Unknown error".to_string())
174        ))
175        .with_confidence(0.0)
176    }
177}
178
179impl Default for ParallelOrchestrator {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185#[async_trait::async_trait]
186impl Orchestrator for ParallelOrchestrator {
187    fn name(&self) -> &str {
188        self.base.name()
189    }
190
191    fn description(&self) -> &str {
192        self.base.description()
193    }
194
195    async fn orchestrate(
196        &self,
197        agents: Vec<Box<dyn Agent>>,
198        input: OrchestratorInput,
199    ) -> Result<OrchestratorOutput> {
200        if agents.is_empty() {
201            return Err(
202                crate::orchestration::errors::OrchestrationError::invalid_config(
203                    "At least one agent is required",
204                ),
205            );
206        }
207
208        // Create execution context
209        let mut config = crate::orchestration::context::ExecutionConfig::new();
210        config.parallel_limit = self.parallel_limit;
211        let ctx = ExecutionContext::new(config);
212
213        let agent_input = self.base.input_to_agent_input(&input);
214
215        // Execute agents in parallel
216        let outputs = match self.execute_parallel(agents, agent_input, &ctx).await {
217            Ok(outputs) => outputs,
218            Err(e) => {
219                ctx.complete_trace().await;
220                let trace = ctx.get_trace().await;
221                return Ok(OrchestratorOutput::failure(e.to_string(), trace));
222            },
223        };
224
225        // Complete trace
226        ctx.complete_trace().await;
227        let trace = ctx.get_trace().await;
228
229        // Aggregate results
230        let aggregated = self.aggregate_results(&outputs);
231
232        Ok(OrchestratorOutput::success(aggregated, outputs, trace))
233    }
234}
235
236impl ParallelOrchestrator {
237    /// Aggregate multiple agent outputs into a single result
238    fn aggregate_results(&self, outputs: &[AgentOutput]) -> String {
239        if outputs.is_empty() {
240            return String::new();
241        }
242
243        if outputs.len() == 1 {
244            return outputs[0].content.clone();
245        }
246
247        // Combine all outputs
248        let mut result = String::from("Parallel execution results:\n\n");
249
250        for (index, output) in outputs.iter().enumerate() {
251            result.push_str(&format!("{}. {}\n", index + 1, output.content));
252        }
253
254        result
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::orchestration::agent::SimpleAgent;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Build an agent whose closure records how many probe agents are inside it at once.
    ///
    /// `in_flight` is incremented on entry and decremented before returning; `max_seen` keeps
    /// the largest `in_flight` value any probe observed.
    fn concurrency_probe_agent(
        index: usize,
        in_flight: Arc<AtomicUsize>,
        max_seen: Arc<AtomicUsize>,
    ) -> Box<dyn Agent> {
        Box::new(SimpleAgent::new(
            format!("Agent{}", index),
            format!("Agent number {}", index),
            move |_input| {
                let running_now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                max_seen.fetch_max(running_now, Ordering::SeqCst);
                in_flight.fetch_sub(1, Ordering::SeqCst);

                Ok(AgentOutput::new(format!("Agent {} done", index)))
            },
        ))
    }

    #[tokio::test]
    async fn test_parallel_orchestrator() {
        let orchestrator = ParallelOrchestrator::new();

        // Create three agents that execute independently
        let agent1: Box<dyn Agent> = Box::new(SimpleAgent::new("Agent1", "First", |input| {
            Ok(AgentOutput::new(format!(
                "Result 1 from: {}",
                input.content
            )))
        }));

        let agent2: Box<dyn Agent> = Box::new(SimpleAgent::new("Agent2", "Second", |input| {
            Ok(AgentOutput::new(format!(
                "Result 2 from: {}",
                input.content
            )))
        }));

        let agent3: Box<dyn Agent> = Box::new(SimpleAgent::new("Agent3", "Third", |input| {
            Ok(AgentOutput::new(format!(
                "Result 3 from: {}",
                input.content
            )))
        }));

        let agents: Vec<Box<dyn Agent>> = vec![agent1, agent2, agent3];

        let input = OrchestratorInput::new("Test input");

        let output = orchestrator.orchestrate(agents, input).await.unwrap();

        assert!(output.is_successful());
        assert_eq!(output.agent_outputs.len(), 3);
        assert!(output.result.contains("Parallel execution results"));
        assert!(output.result.contains("Result 1 from: Test input"));
        assert!(output.result.contains("Result 2 from: Test input"));
        assert!(output.result.contains("Result 3 from: Test input"));
    }

    #[tokio::test]
    async fn test_parallel_execution_is_parallel() {
        let orchestrator = ParallelOrchestrator::new();

        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let agents: Vec<Box<dyn Agent>> = (0..5)
            .map(|i| concurrency_probe_agent(i, Arc::clone(&in_flight), Arc::clone(&max_seen)))
            .collect();

        let input = OrchestratorInput::new("Test");
        let output = orchestrator.orchestrate(agents, input).await.unwrap();

        assert!(output.is_successful());
        assert_eq!(output.agent_outputs.len(), 5);
        assert_eq!(in_flight.load(Ordering::SeqCst), 0);

        // NOTE: the probe closures contain no `.await`, so unless `SimpleAgent` runs them on
        // separate threads, overlap between agents cannot be observed here; this only confirms
        // that the agents actually ran through the probe.
        let max_val = max_seen.load(Ordering::SeqCst);
        assert!(
            max_val >= 1,
            "Expected at least 1 agent to execute (max concurrent: {})",
            max_val
        );
    }

    #[tokio::test]
    async fn test_parallel_orchestrator_empty_agents() {
        let orchestrator = ParallelOrchestrator::new();
        let agents: Vec<Box<dyn Agent>> = vec![];
        let input = OrchestratorInput::new("Test");

        let result = orchestrator.orchestrate(agents, input).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            crate::orchestration::errors::OrchestrationError::InvalidConfig(_)
        ));
    }

    #[tokio::test]
    async fn test_parallel_with_limit() {
        let orchestrator = ParallelOrchestrator::new().with_parallel_limit(2);

        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let agents: Vec<Box<dyn Agent>> = (0..5)
            .map(|i| concurrency_probe_agent(i, Arc::clone(&in_flight), Arc::clone(&max_seen)))
            .collect();

        let input = OrchestratorInput::new("Test");
        let output = orchestrator.orchestrate(agents, input).await.unwrap();

        assert!(output.is_successful());
        assert_eq!(in_flight.load(Ordering::SeqCst), 0);

        // With a limit of 2, the probes must never see more than 2 agents running at once.
        let max_val = max_seen.load(Ordering::SeqCst);
        assert!(max_val <= 2, "Expected max 2 concurrent, got {}", max_val);
    }

    #[test]
    fn test_aggregate_results_shapes() {
        let orchestrator = ParallelOrchestrator::new();

        assert_eq!(orchestrator.aggregate_results(&[]), "");

        let single = [AgentOutput::new("only".to_string())];
        assert_eq!(orchestrator.aggregate_results(&single), "only");

        let many = [
            AgentOutput::new("a".to_string()),
            AgentOutput::new("b".to_string()),
        ];
        assert_eq!(
            orchestrator.aggregate_results(&many),
            "Parallel execution results:\n\n1. a\n2. b\n"
        );
    }
}