claude_agent_sdk/orchestration/patterns/
parallel.rs1use crate::orchestration::{
17 Result,
18 agent::{Agent, AgentInput, AgentOutput},
19 context::{AgentExecution, ExecutionContext},
20 orchestrator::{BaseOrchestrator, Orchestrator, OrchestratorInput, OrchestratorOutput},
21};
22use futures::future::join_all;
23use std::sync::Arc;
24use tokio::sync::Semaphore;
25
26pub struct ParallelOrchestrator {
28 base: BaseOrchestrator,
29 max_retries: usize,
30 parallel_limit: usize,
31}
32
33impl ParallelOrchestrator {
34 pub fn new() -> Self {
36 Self {
37 base: BaseOrchestrator::new(
38 "ParallelOrchestrator",
39 "Executes agents in parallel and aggregates their outputs",
40 ),
41 max_retries: 3,
42 parallel_limit: 10,
43 }
44 }
45
46 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
48 self.max_retries = max_retries;
49 self
50 }
51
52 pub fn with_parallel_limit(mut self, limit: usize) -> Self {
54 self.parallel_limit = limit;
55 self
56 }
57
58 async fn execute_parallel(
60 &self,
61 agents: Vec<Box<dyn Agent>>,
62 input: AgentInput,
63 ctx: &ExecutionContext,
64 ) -> Result<Vec<AgentOutput>> {
65 let semaphore = Arc::new(Semaphore::new(self.parallel_limit));
66 let agents_count = agents.len();
67 let mut futures = Vec::new();
68
69 for (index, agent) in agents.iter().enumerate() {
70 let agent_ref = agent.as_ref();
71 let input_clone = input.clone();
72 let semaphore_clone = semaphore.clone();
73 let ctx_clone = ctx.clone();
74 let base_name = self.base.name().to_string();
75
76 let future = async move {
77 let _permit = semaphore_clone.acquire().await.unwrap();
79
80 let mut exec_record = AgentExecution::new(agent_ref.name(), input_clone.clone());
82
83 if ctx_clone.is_logging_enabled() {
84 println!(
85 "[{}] Executing agent {}/{}: {}",
86 base_name,
87 index + 1,
88 agents_count,
89 agent_ref.name()
90 );
91 }
92
93 let output =
95 Self::execute_agent_with_retry_static(agent_ref, input_clone, self.max_retries)
96 .await;
97
98 let success = output.is_successful();
99
100 if success {
101 exec_record.succeed(output.clone());
102 } else {
103 exec_record.fail(output.content.clone());
104 }
105
106 if ctx_clone.is_tracing_enabled() {
108 ctx_clone.add_execution(exec_record).await;
109 }
110
111 (agent_ref.name().to_string(), output, success)
112 };
113
114 futures.push(future);
115 }
116
117 let results = join_all(futures).await;
119
120 let mut outputs = Vec::new();
122 let mut failed_agents = Vec::new();
123
124 for (agent_name, output, success) in results {
125 if success {
126 outputs.push(output);
127 } else {
128 failed_agents.push(agent_name);
129 }
130 }
131
132 if !failed_agents.is_empty() {
134 return Err(
135 crate::orchestration::errors::OrchestrationError::agent_failure(
136 failed_agents.join(", "),
137 "Execution failed",
138 ),
139 );
140 }
141
142 Ok(outputs)
143 }
144
145 async fn execute_agent_with_retry_static(
147 agent: &dyn Agent,
148 input: AgentInput,
149 max_retries: usize,
150 ) -> AgentOutput {
151 let mut last_error = None;
152
153 for attempt in 0..=max_retries {
154 match agent.execute(input.clone()).await {
155 Ok(output) => return output,
156 Err(e) => {
157 last_error = Some(e.to_string());
158 if attempt < max_retries {
159 tokio::time::sleep(std::time::Duration::from_millis(
160 100 * 2_u64.pow(attempt as u32),
161 ))
162 .await;
163 }
164 },
165 }
166 }
167
168 AgentOutput::new(format!(
170 "Agent {} failed after {} retries: {}",
171 agent.name(),
172 max_retries,
173 last_error.unwrap_or_else(|| "Unknown error".to_string())
174 ))
175 .with_confidence(0.0)
176 }
177}
178
179impl Default for ParallelOrchestrator {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
185#[async_trait::async_trait]
186impl Orchestrator for ParallelOrchestrator {
187 fn name(&self) -> &str {
188 self.base.name()
189 }
190
191 fn description(&self) -> &str {
192 self.base.description()
193 }
194
195 async fn orchestrate(
196 &self,
197 agents: Vec<Box<dyn Agent>>,
198 input: OrchestratorInput,
199 ) -> Result<OrchestratorOutput> {
200 if agents.is_empty() {
201 return Err(
202 crate::orchestration::errors::OrchestrationError::invalid_config(
203 "At least one agent is required",
204 ),
205 );
206 }
207
208 let mut config = crate::orchestration::context::ExecutionConfig::new();
210 config.parallel_limit = self.parallel_limit;
211 let ctx = ExecutionContext::new(config);
212
213 let agent_input = self.base.input_to_agent_input(&input);
214
215 let outputs = match self.execute_parallel(agents, agent_input, &ctx).await {
217 Ok(outputs) => outputs,
218 Err(e) => {
219 ctx.complete_trace().await;
220 let trace = ctx.get_trace().await;
221 return Ok(OrchestratorOutput::failure(e.to_string(), trace));
222 },
223 };
224
225 ctx.complete_trace().await;
227 let trace = ctx.get_trace().await;
228
229 let aggregated = self.aggregate_results(&outputs);
231
232 Ok(OrchestratorOutput::success(aggregated, outputs, trace))
233 }
234}
235
236impl ParallelOrchestrator {
237 fn aggregate_results(&self, outputs: &[AgentOutput]) -> String {
239 if outputs.is_empty() {
240 return String::new();
241 }
242
243 if outputs.len() == 1 {
244 return outputs[0].content.clone();
245 }
246
247 let mut result = String::from("Parallel execution results:\n\n");
249
250 for (index, output) in outputs.iter().enumerate() {
251 result.push_str(&format!("{}. {}\n", index + 1, output.content));
252 }
253
254 result
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::orchestration::agent::SimpleAgent;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds `count` agents named `Agent{i}` that return `"Agent {i} done"`.
    ///
    /// While an agent's closure runs, `in_flight` counts it, and `peak` records the highest
    /// value `in_flight` reached.
    ///
    /// NOTE(review): the closures are synchronous, so unless `SimpleAgent` offloads them to
    /// other threads they cannot actually overlap under `join_all`. The tests below then
    /// mainly exercise the bookkeeping and the semaphore path, not real overlap. Confirm.
    fn tracking_agents(
        count: usize,
        in_flight: &Arc<AtomicUsize>,
        peak: &Arc<AtomicUsize>,
    ) -> Vec<Box<dyn Agent>> {
        (0..count)
            .map(|i| {
                let in_flight = Arc::clone(in_flight);
                let peak = Arc::clone(peak);
                let agent: Box<dyn Agent> = Box::new(SimpleAgent::new(
                    format!("Agent{}", i),
                    format!("Agent number {}", i),
                    move |_input| {
                        let running_now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        peak.fetch_max(running_now, Ordering::SeqCst);
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        Ok(AgentOutput::new(format!("Agent {} done", i)))
                    },
                ));
                agent
            })
            .collect()
    }

    #[tokio::test]
    async fn test_parallel_orchestrator() {
        let orchestrator = ParallelOrchestrator::new();

        let agent1: Box<dyn Agent> = Box::new(SimpleAgent::new("Agent1", "First", |input| {
            Ok(AgentOutput::new(format!("Result 1 from: {}", input.content)))
        }));

        let agent2: Box<dyn Agent> = Box::new(SimpleAgent::new("Agent2", "Second", |input| {
            Ok(AgentOutput::new(format!("Result 2 from: {}", input.content)))
        }));

        let agent3: Box<dyn Agent> = Box::new(SimpleAgent::new("Agent3", "Third", |input| {
            Ok(AgentOutput::new(format!("Result 3 from: {}", input.content)))
        }));

        let agents: Vec<Box<dyn Agent>> = vec![agent1, agent2, agent3];

        let input = OrchestratorInput::new("Test input");

        let output = orchestrator.orchestrate(agents, input).await.unwrap();

        assert!(output.is_successful());
        assert_eq!(output.agent_outputs.len(), 3);
        assert!(output.result.contains("Parallel execution results"));
        assert!(output.result.contains("Result 1 from: Test input"));
        assert!(output.result.contains("Result 2 from: Test input"));
        assert!(output.result.contains("Result 3 from: Test input"));
    }

    #[tokio::test]
    async fn test_parallel_execution_is_parallel() {
        let orchestrator = ParallelOrchestrator::new();

        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let agents = tracking_agents(5, &in_flight, &peak);

        let input = OrchestratorInput::new("Test");
        let output = orchestrator.orchestrate(agents, input).await.unwrap();

        assert!(output.is_successful());
        assert_eq!(output.agent_outputs.len(), 5);

        // Outputs must come back in agent order, and every agent must have finished.
        for (i, agent_output) in output.agent_outputs.iter().enumerate() {
            assert_eq!(agent_output.content, format!("Agent {} done", i));
        }
        assert_eq!(in_flight.load(Ordering::SeqCst), 0);

        let max_val = peak.load(Ordering::SeqCst);
        assert!(
            max_val >= 1,
            "Expected at least 1 agent to execute (max concurrent: {})",
            max_val
        );
    }

    #[tokio::test]
    async fn test_parallel_orchestrator_empty_agents() {
        let orchestrator = ParallelOrchestrator::new();
        let agents: Vec<Box<dyn Agent>> = vec![];
        let input = OrchestratorInput::new("Test");

        let result = orchestrator.orchestrate(agents, input).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            crate::orchestration::errors::OrchestrationError::InvalidConfig(_)
        ));
    }

    #[tokio::test]
    async fn test_parallel_with_limit() {
        let orchestrator = ParallelOrchestrator::new().with_parallel_limit(2);

        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let agents = tracking_agents(5, &in_flight, &peak);

        let input = OrchestratorInput::new("Test");
        let output = orchestrator.orchestrate(agents, input).await.unwrap();

        assert!(output.is_successful());
        assert_eq!(in_flight.load(Ordering::SeqCst), 0);

        let max_val = peak.load(Ordering::SeqCst);
        assert!(max_val <= 2, "Expected max 2 concurrent, got {}", max_val);
    }

    #[tokio::test]
    async fn test_parallel_failing_agent_reports_failure() {
        // No retries, so the failing agent is tried once and the test never sleeps.
        let orchestrator = ParallelOrchestrator::new().with_max_retries(0);

        let ok_agent: Box<dyn Agent> = Box::new(SimpleAgent::new("Good", "Succeeds", |_input| {
            Ok(AgentOutput::new("fine"))
        }));
        let bad_agent: Box<dyn Agent> = Box::new(SimpleAgent::new("Bad", "Fails", |_input| {
            Err(crate::orchestration::errors::OrchestrationError::invalid_config(
                "boom",
            ))
        }));

        let input = OrchestratorInput::new("Test");
        let output = orchestrator
            .orchestrate(vec![ok_agent, bad_agent], input)
            .await
            .unwrap();

        // Agent failures are reported through the output, not as an `Err`.
        assert!(!output.is_successful());
    }
}