1use super::{
7 orchestrator::Orchestrator,
8 subtask::{SubTask, SubTaskResult},
9 DecompositionStrategy, StageStats, SwarmConfig, SwarmResult,
10};
11
12pub use super::{Actor, ActorStatus, Handler, SwarmMessage};
14use crate::{
15 agent::Agent,
16 provider::{CompletionRequest, ContentPart, FinishReason, Message, Provider, Role},
17 swarm::{SwarmArtifact, SwarmStats},
18 tool::ToolRegistry,
19};
20use anyhow::Result;
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Instant;
24use tokio::sync::RwLock;
25use tokio::time::{timeout, Duration};
26
27pub struct SwarmExecutor {
29 config: SwarmConfig,
30 coordinator_agent: Option<Arc<tokio::sync::Mutex<Agent>>>,
32}
33
34impl SwarmExecutor {
35 pub fn new(config: SwarmConfig) -> Self {
37 Self {
38 config,
39 coordinator_agent: None,
40 }
41 }
42
43 pub fn with_coordinator_agent(mut self, agent: Arc<tokio::sync::Mutex<Agent>>) -> Self {
45 tracing::debug!("Setting coordinator agent for swarm execution");
46 self.coordinator_agent = Some(agent);
47 self
48 }
49
50 pub fn coordinator_agent(&self) -> Option<&Arc<tokio::sync::Mutex<Agent>>> {
52 self.coordinator_agent.as_ref()
53 }
54
55 pub async fn execute(
57 &self,
58 task: &str,
59 strategy: DecompositionStrategy,
60 ) -> Result<SwarmResult> {
61 let start_time = Instant::now();
62
63 let mut orchestrator = Orchestrator::new(self.config.clone()).await?;
65
66 tracing::info!(provider_name = %orchestrator.provider(), "Starting swarm execution for task");
67
68 let subtasks = orchestrator.decompose(task, strategy).await?;
70
71 if subtasks.is_empty() {
72 return Ok(SwarmResult {
73 success: false,
74 result: String::new(),
75 subtask_results: Vec::new(),
76 stats: SwarmStats::default(),
77 artifacts: Vec::new(),
78 error: Some("No subtasks generated".to_string()),
79 });
80 }
81
82 tracing::info!(provider_name = %orchestrator.provider(), "Task decomposed into {} subtasks", subtasks.len());
83
84 let max_stage = subtasks.iter().map(|s| s.stage).max().unwrap_or(0);
86 let mut all_results: Vec<SubTaskResult> = Vec::new();
87 let artifacts: Vec<SwarmArtifact> = Vec::new();
88
89 let completed_results: Arc<RwLock<HashMap<String, String>>> =
91 Arc::new(RwLock::new(HashMap::new()));
92
93 for stage in 0..=max_stage {
94 let stage_start = Instant::now();
95
96 let stage_subtasks: Vec<SubTask> = orchestrator
97 .subtasks_for_stage(stage)
98 .into_iter()
99 .cloned()
100 .collect();
101
102 tracing::debug!(
103 "Stage {} has {} subtasks (max_stage={})",
104 stage,
105 stage_subtasks.len(),
106 max_stage
107 );
108
109 if stage_subtasks.is_empty() {
110 continue;
111 }
112
113 tracing::info!(
114 provider_name = %orchestrator.provider(),
115 "Executing stage {} with {} subtasks",
116 stage,
117 stage_subtasks.len()
118 );
119
120 let stage_results = self
122 .execute_stage(
123 &orchestrator,
124 stage_subtasks,
125 completed_results.clone(),
126 )
127 .await?;
128
129 {
131 let mut completed = completed_results.write().await;
132 for result in &stage_results {
133 completed.insert(result.subtask_id.clone(), result.result.clone());
134 }
135 }
136
137 let stage_time = stage_start.elapsed().as_millis() as u64;
139 let max_steps = stage_results.iter().map(|r| r.steps).max().unwrap_or(0);
140 let total_steps: usize = stage_results.iter().map(|r| r.steps).sum();
141
142 orchestrator.stats_mut().stages.push(StageStats {
143 stage,
144 subagent_count: stage_results.len(),
145 max_steps,
146 total_steps,
147 execution_time_ms: stage_time,
148 });
149
150 for result in &stage_results {
152 orchestrator.complete_subtask(&result.subtask_id, result.clone());
153 }
154
155 all_results.extend(stage_results);
156 }
157
158 let provider_name = orchestrator.provider().to_string();
160
161 let stats = orchestrator.stats_mut();
163 stats.execution_time_ms = start_time.elapsed().as_millis() as u64;
164 stats.sequential_time_estimate_ms = all_results
165 .iter()
166 .map(|r| r.execution_time_ms)
167 .sum();
168 stats.calculate_critical_path();
169 stats.calculate_speedup();
170
171 let success = all_results.iter().all(|r| r.success);
173 let result = self.aggregate_results(&all_results).await?;
174
175 tracing::info!(
176 provider_name = %provider_name,
177 "Swarm execution complete: {} subtasks, {:.1}x speedup",
178 all_results.len(),
179 stats.speedup_factor
180 );
181
182 Ok(SwarmResult {
183 success,
184 result,
185 subtask_results: all_results,
186 stats: orchestrator.stats().clone(),
187 artifacts,
188 error: None,
189 })
190 }
191
192 async fn execute_stage(
194 &self,
195 orchestrator: &Orchestrator,
196 subtasks: Vec<SubTask>,
197 completed_results: Arc<RwLock<HashMap<String, String>>>,
198 ) -> Result<Vec<SubTaskResult>> {
199 let mut handles: Vec<tokio::task::JoinHandle<Result<SubTaskResult, anyhow::Error>>> = Vec::new();
200
201 let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.max_concurrent_requests));
203 let delay_ms = self.config.request_delay_ms;
204
205 let model = orchestrator.model().to_string();
207 let provider_name = orchestrator.provider().to_string();
208 let providers = orchestrator.providers();
209 let provider = providers.get(&provider_name)
210 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", provider_name))?;
211
212 tracing::info!(provider_name = %provider_name, "Selected provider for subtask execution");
213
214 let tool_registry = ToolRegistry::with_provider_arc(Arc::clone(&provider), model.clone());
216 let tool_definitions = tool_registry.definitions();
217
218 for (idx, subtask) in subtasks.into_iter().enumerate() {
219 let model = model.clone();
220 let _provider_name = provider_name.clone();
221 let provider = Arc::clone(&provider);
222
223 let context = {
225 let completed = completed_results.read().await;
226 let mut dep_context = String::new();
227 for dep_id in &subtask.dependencies {
228 if let Some(result) = completed.get(dep_id) {
229 dep_context.push_str(&format!("\n--- Result from dependency {} ---\n{}\n", dep_id, result));
230 }
231 }
232 dep_context
233 };
234
235 let instruction = subtask.instruction.clone();
236 let specialty = subtask.specialty.clone().unwrap_or_default();
237 let subtask_id = subtask.id.clone();
238 let max_steps = self.config.max_steps_per_subagent;
239 let timeout_secs = self.config.subagent_timeout_secs;
240
241 let tools = tool_definitions.clone();
243 let registry = Arc::clone(&tool_registry);
244 let sem = Arc::clone(&semaphore);
245 let stagger_delay = delay_ms * idx as u64; let handle = tokio::spawn(async move {
249 if stagger_delay > 0 {
251 tokio::time::sleep(Duration::from_millis(stagger_delay)).await;
252 }
253 let _permit = sem.acquire().await.expect("semaphore closed");
254
255 let start = Instant::now();
256
257let prd_filename = format!("prd_{}.json", subtask_id.replace("-", "_"));
260 let system_prompt = format!(
261 "You are a {} specialist sub-agent (ID: {}). You have access to tools to complete your task.
262
263IMPORTANT: You MUST use tools to make changes. Do not just describe what to do - actually do it using the tools available.
264
265Available tools:
266- read: Read file contents
267- write: Write/create files
268- edit: Edit existing files (search and replace)
269- multiedit: Make multiple edits at once
270- glob: Find files by pattern
271- grep: Search file contents
272- bash: Run shell commands
273- webfetch: Fetch web pages
274- prd: Generate structured PRD for complex tasks
275- ralph: Run autonomous agent loop on a PRD
276
277COMPLEX TASKS:
278If your task is complex and involves multiple implementation steps, use the prd + ralph workflow:
2791. Call prd({{action: 'analyze', task_description: '...'}}) to understand what's needed
2802. Break down into user stories with acceptance criteria
2813. Call prd({{action: 'save', prd_path: '{}', project: '...', feature: '...', stories: [...]}})
2824. Call ralph({{action: 'run', prd_path: '{}'}}) to execute
283
284NOTE: Use your unique PRD file '{}' so parallel agents don't conflict.
285
286When done, provide a brief summary of what you accomplished.",
287 specialty,
288 subtask_id,
289 prd_filename,
290 prd_filename,
291 prd_filename
292 );
293
294 let user_prompt = if context.is_empty() {
295 format!("Complete this task:\n\n{}", instruction)
296 } else {
297 format!(
298 "Complete this task:\n\n{}\n\nContext from prior work:\n{}",
299 instruction, context
300 )
301 };
302
303 let result = run_agent_loop(
305 provider,
306 &model,
307 &system_prompt,
308 &user_prompt,
309 tools,
310 registry,
311 max_steps,
312 timeout_secs,
313 ).await;
314
315 match result {
316 Ok((output, steps, tool_calls)) => {
317 Ok(SubTaskResult {
318 subtask_id: subtask_id.clone(),
319 subagent_id: format!("agent-{}", subtask_id),
320 success: true,
321 result: output,
322 steps,
323 tool_calls,
324 execution_time_ms: start.elapsed().as_millis() as u64,
325 error: None,
326 artifacts: Vec::new(),
327 })
328 }
329 Err(e) => {
330 Ok(SubTaskResult {
331 subtask_id: subtask_id.clone(),
332 subagent_id: format!("agent-{}", subtask_id),
333 success: false,
334 result: String::new(),
335 steps: 0,
336 tool_calls: 0,
337 execution_time_ms: start.elapsed().as_millis() as u64,
338 error: Some(e.to_string()),
339 artifacts: Vec::new(),
340 })
341 }
342 }
343 });
344
345 handles.push(handle);
346 }
347
348 let mut results = Vec::new();
350 for handle in handles {
351 match handle.await {
352 Ok(Ok(result)) => results.push(result),
353 Ok(Err(e)) => {
354 tracing::error!(provider_name = %provider_name, "Subtask error: {}", e);
355 }
356 Err(e) => {
357 tracing::error!(provider_name = %provider_name, "Task join error: {}", e);
358 }
359 }
360 }
361
362 Ok(results)
363 }
364
365 async fn aggregate_results(&self, results: &[SubTaskResult]) -> Result<String> {
367 let mut aggregated = String::new();
368
369 for (i, result) in results.iter().enumerate() {
370 if result.success {
371 aggregated.push_str(&format!(
372 "=== Subtask {} ===\n{}\n\n",
373 i + 1,
374 result.result
375 ));
376 } else {
377 aggregated.push_str(&format!(
378 "=== Subtask {} (FAILED) ===\nError: {}\n\n",
379 i + 1,
380 result.error.as_deref().unwrap_or("Unknown error")
381 ));
382 }
383 }
384
385 Ok(aggregated)
386 }
387
388 pub async fn execute_single(&self, task: &str) -> Result<SwarmResult> {
390 self.execute(task, DecompositionStrategy::None).await
391 }
392}
393
394pub struct SwarmExecutorBuilder {
396 config: SwarmConfig,
397}
398
399impl SwarmExecutorBuilder {
400 pub fn new() -> Self {
401 Self {
402 config: SwarmConfig::default(),
403 }
404 }
405
406 pub fn max_subagents(mut self, max: usize) -> Self {
407 self.config.max_subagents = max;
408 self
409 }
410
411 pub fn max_steps_per_subagent(mut self, max: usize) -> Self {
412 self.config.max_steps_per_subagent = max;
413 self
414 }
415
416 pub fn max_total_steps(mut self, max: usize) -> Self {
417 self.config.max_total_steps = max;
418 self
419 }
420
421 pub fn timeout_secs(mut self, secs: u64) -> Self {
422 self.config.subagent_timeout_secs = secs;
423 self
424 }
425
426 pub fn parallel_enabled(mut self, enabled: bool) -> Self {
427 self.config.parallel_enabled = enabled;
428 self
429 }
430
431 pub fn build(self) -> SwarmExecutor {
432 SwarmExecutor::new(self.config)
433 }
434}
435
436impl Default for SwarmExecutorBuilder {
437 fn default() -> Self {
438 Self::new()
439 }
440}
441
442#[allow(clippy::too_many_arguments)]
444async fn run_agent_loop(
445 provider: Arc<dyn Provider>,
446 model: &str,
447 system_prompt: &str,
448 user_prompt: &str,
449 tools: Vec<crate::provider::ToolDefinition>,
450 registry: Arc<ToolRegistry>,
451 max_steps: usize,
452 timeout_secs: u64,
453) -> Result<(String, usize, usize)> {
454 let temperature = 0.7;
456
457 tracing::info!(
458 model = %model,
459 max_steps = max_steps,
460 timeout_secs = timeout_secs,
461 "Sub-agent starting agentic loop"
462 );
463 tracing::debug!(system_prompt = %system_prompt, "Sub-agent system prompt");
464 tracing::debug!(user_prompt = %user_prompt, "Sub-agent user prompt");
465
466 let mut messages = vec![
468 Message {
469 role: Role::System,
470 content: vec![ContentPart::Text { text: system_prompt.to_string() }],
471 },
472 Message {
473 role: Role::User,
474 content: vec![ContentPart::Text { text: user_prompt.to_string() }],
475 },
476 ];
477
478 let mut steps = 0;
479 let mut total_tool_calls = 0;
480 let mut final_output = String::new();
481
482 let deadline = Instant::now() + Duration::from_secs(timeout_secs);
483
484 loop {
485 if steps >= max_steps {
486 tracing::warn!(max_steps = max_steps, "Sub-agent reached max steps limit");
487 break;
488 }
489
490 if Instant::now() > deadline {
491 tracing::warn!(timeout_secs = timeout_secs, "Sub-agent timed out");
492 break;
493 }
494
495 steps += 1;
496 tracing::info!(step = steps, "Sub-agent step starting");
497
498 let request = CompletionRequest {
499 messages: messages.clone(),
500 tools: tools.clone(),
501 model: model.to_string(),
502 temperature: Some(temperature),
503 top_p: None,
504 max_tokens: Some(8192),
505 stop: Vec::new(),
506 };
507
508 let step_start = Instant::now();
509 let response = timeout(
510 Duration::from_secs(120),
511 provider.complete(request),
512 ).await??;
513 let step_duration = step_start.elapsed();
514
515 tracing::info!(
516 step = steps,
517 duration_ms = step_duration.as_millis() as u64,
518 finish_reason = ?response.finish_reason,
519 prompt_tokens = response.usage.prompt_tokens,
520 completion_tokens = response.usage.completion_tokens,
521 "Sub-agent step completed LLM call"
522 );
523
524 let mut text_parts = Vec::new();
526 let mut tool_calls = Vec::new();
527
528 for part in &response.message.content {
529 match part {
530 ContentPart::Text { text } => {
531 text_parts.push(text.clone());
532 }
533 ContentPart::ToolCall { id, name, arguments } => {
534 tool_calls.push((id.clone(), name.clone(), arguments.clone()));
535 }
536 _ => {}
537 }
538 }
539
540 if !text_parts.is_empty() {
542 final_output = text_parts.join("\n");
543 tracing::info!(
544 step = steps,
545 output_len = final_output.len(),
546 "Sub-agent text output"
547 );
548 tracing::debug!(step = steps, output = %final_output, "Sub-agent full output");
549 }
550
551 if !tool_calls.is_empty() {
553 tracing::info!(
554 step = steps,
555 num_tool_calls = tool_calls.len(),
556 tools = ?tool_calls.iter().map(|(_, name, _)| name.as_str()).collect::<Vec<_>>(),
557 "Sub-agent requesting tool calls"
558 );
559 }
560
561 messages.push(response.message.clone());
563
564 if response.finish_reason != FinishReason::ToolCalls || tool_calls.is_empty() {
566 tracing::info!(
567 steps = steps,
568 total_tool_calls = total_tool_calls,
569 "Sub-agent finished"
570 );
571 break;
572 }
573
574 let mut tool_results = Vec::new();
576
577 for (call_id, tool_name, arguments) in tool_calls {
578 total_tool_calls += 1;
579
580 tracing::info!(
581 step = steps,
582 tool_call_id = %call_id,
583 tool = %tool_name,
584 "Executing tool"
585 );
586 tracing::debug!(
587 tool = %tool_name,
588 arguments = %arguments,
589 "Tool call arguments"
590 );
591
592 let tool_start = Instant::now();
593 let result = if let Some(tool) = registry.get(&tool_name) {
594 let args: serde_json::Value = serde_json::from_str(&arguments)
596 .unwrap_or_else(|_| serde_json::json!({}));
597
598 match tool.execute(args).await {
599 Ok(r) => {
600 if r.success {
601 tracing::info!(
602 tool = %tool_name,
603 duration_ms = tool_start.elapsed().as_millis() as u64,
604 success = true,
605 "Tool execution completed"
606 );
607 r.output
608 } else {
609 tracing::warn!(
610 tool = %tool_name,
611 error = %r.output,
612 "Tool returned error"
613 );
614 format!("Tool error: {}", r.output)
615 }
616 }
617 Err(e) => {
618 tracing::error!(
619 tool = %tool_name,
620 error = %e,
621 "Tool execution failed"
622 );
623 format!("Tool execution failed: {}", e)
624 }
625 }
626 } else {
627 tracing::error!(tool = %tool_name, "Unknown tool requested");
628 format!("Unknown tool: {}", tool_name)
629 };
630
631 tracing::debug!(
632 tool = %tool_name,
633 result_len = result.len(),
634 "Tool result"
635 );
636
637 tool_results.push((call_id, result));
638 }
639
640 for (call_id, result) in tool_results {
642 messages.push(Message {
643 role: Role::Tool,
644 content: vec![ContentPart::ToolResult {
645 tool_call_id: call_id,
646 content: result,
647 }],
648 });
649 }
650 }
651
652 Ok((final_output, steps, total_tool_calls))
653}