codetether_agent/swarm/executor.rs

//! Parallel execution engine for swarm operations
//!
//! Executes subtasks in parallel across multiple sub-agents,
//! respecting dependencies and optimizing for critical path.
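//!
//! # Example
//!
//! A minimal sketch of driving the executor from async code; the crate paths
//! and the `SwarmConfig` defaults are assumed rather than guaranteed, so the
//! snippet is marked `ignore`.
//!
//! ```ignore
//! use codetether_agent::swarm::{DecompositionStrategy, SwarmConfig};
//! use codetether_agent::swarm::executor::SwarmExecutor;
//!
//! async fn demo() -> anyhow::Result<()> {
//!     // Build an executor from the default swarm configuration.
//!     let executor = SwarmExecutor::new(SwarmConfig::default());
//!
//!     // Decompose the task, run each stage's subtasks in parallel, and aggregate.
//!     let result = executor
//!         .execute("Add input validation to the CLI", DecompositionStrategy::None)
//!         .await?;
//!
//!     println!("success: {}, speedup: {:.1}x", result.success, result.stats.speedup_factor);
//!     Ok(())
//! }
//! ```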

use super::{
    orchestrator::Orchestrator,
    subtask::{SubTask, SubTaskResult},
    DecompositionStrategy, StageStats, SwarmConfig, SwarmResult,
};

// Re-export swarm types for convenience
pub use super::{Actor, ActorStatus, Handler, SwarmMessage};
use crate::{
    agent::Agent,
    provider::{CompletionRequest, ContentPart, FinishReason, Message, Provider, Role},
    swarm::{SwarmArtifact, SwarmStats},
    tool::ToolRegistry,
};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tokio::time::{timeout, Duration};

/// The swarm executor runs subtasks in parallel
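///
/// Each subtask is handled by an independent sub-agent that drives its own
/// tool-calling loop (see `run_agent_loop`). Stages run sequentially; the
/// subtasks within a stage run concurrently.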
pub struct SwarmExecutor {
    config: SwarmConfig,
    /// Optional agent for handling swarm-level coordination (reserved for future use)
    coordinator_agent: Option<Arc<tokio::sync::Mutex<Agent>>>,
}

impl SwarmExecutor {
    /// Create a new executor
    pub fn new(config: SwarmConfig) -> Self {
        Self {
            config,
            coordinator_agent: None,
        }
    }

    /// Set a coordinator agent for swarm-level coordination
    pub fn with_coordinator_agent(mut self, agent: Arc<tokio::sync::Mutex<Agent>>) -> Self {
        tracing::debug!("Setting coordinator agent for swarm execution");
        self.coordinator_agent = Some(agent);
        self
    }

    /// Get the coordinator agent if set
    pub fn coordinator_agent(&self) -> Option<&Arc<tokio::sync::Mutex<Agent>>> {
        self.coordinator_agent.as_ref()
    }

    /// Execute a task using the swarm
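    ///
    /// Decomposes the task into subtasks via the `Orchestrator`, runs each
    /// stage's subtasks in parallel, then aggregates the per-subtask outputs
    /// into a single result string along with timing and speedup statistics.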
    pub async fn execute(
        &self,
        task: &str,
        strategy: DecompositionStrategy,
    ) -> Result<SwarmResult> {
        let start_time = Instant::now();

        // Create orchestrator
        let mut orchestrator = Orchestrator::new(self.config.clone()).await?;

        tracing::info!(provider_name = %orchestrator.provider(), "Starting swarm execution for task");

        // Decompose the task
        let subtasks = orchestrator.decompose(task, strategy).await?;

        if subtasks.is_empty() {
            return Ok(SwarmResult {
                success: false,
                result: String::new(),
                subtask_results: Vec::new(),
                stats: SwarmStats::default(),
                artifacts: Vec::new(),
                error: Some("No subtasks generated".to_string()),
            });
        }

        tracing::info!(provider_name = %orchestrator.provider(), "Task decomposed into {} subtasks", subtasks.len());

        // Execute stages in order
        let max_stage = subtasks.iter().map(|s| s.stage).max().unwrap_or(0);
        let mut all_results: Vec<SubTaskResult> = Vec::new();
        let artifacts: Vec<SwarmArtifact> = Vec::new();

        // Shared state for completed results
        let completed_results: Arc<RwLock<HashMap<String, String>>> =
            Arc::new(RwLock::new(HashMap::new()));

        for stage in 0..=max_stage {
            let stage_start = Instant::now();

            let stage_subtasks: Vec<SubTask> = orchestrator
                .subtasks_for_stage(stage)
                .into_iter()
                .cloned()
                .collect();

            tracing::debug!(
                "Stage {} has {} subtasks (max_stage={})",
                stage,
                stage_subtasks.len(),
                max_stage
            );

            if stage_subtasks.is_empty() {
                continue;
            }

            tracing::info!(
                provider_name = %orchestrator.provider(),
                "Executing stage {} with {} subtasks",
                stage,
                stage_subtasks.len()
            );

            // Execute all subtasks in this stage in parallel
            let stage_results = self
                .execute_stage(
                    &orchestrator,
                    stage_subtasks,
                    completed_results.clone(),
                )
                .await?;

            // Update completed results for next stage
            {
                let mut completed = completed_results.write().await;
                for result in &stage_results {
                    completed.insert(result.subtask_id.clone(), result.result.clone());
                }
            }

            // Update orchestrator stats
            let stage_time = stage_start.elapsed().as_millis() as u64;
            let max_steps = stage_results.iter().map(|r| r.steps).max().unwrap_or(0);
            let total_steps: usize = stage_results.iter().map(|r| r.steps).sum();

            orchestrator.stats_mut().stages.push(StageStats {
                stage,
                subagent_count: stage_results.len(),
                max_steps,
                total_steps,
                execution_time_ms: stage_time,
            });

            // Mark subtasks as completed
            for result in &stage_results {
                orchestrator.complete_subtask(&result.subtask_id, result.clone());
            }

            all_results.extend(stage_results);
        }

        // Get provider name before mutable borrow
        let provider_name = orchestrator.provider().to_string();

        // Calculate final stats
        let stats = orchestrator.stats_mut();
        stats.execution_time_ms = start_time.elapsed().as_millis() as u64;
        stats.sequential_time_estimate_ms = all_results
            .iter()
            .map(|r| r.execution_time_ms)
            .sum();
        stats.calculate_critical_path();
        stats.calculate_speedup();

        // Aggregate results
        let success = all_results.iter().all(|r| r.success);
        let result = self.aggregate_results(&all_results).await?;

        tracing::info!(
            provider_name = %provider_name,
            "Swarm execution complete: {} subtasks, {:.1}x speedup",
            all_results.len(),
            stats.speedup_factor
        );

        Ok(SwarmResult {
            success,
            result,
            subtask_results: all_results,
            stats: orchestrator.stats().clone(),
            artifacts,
            error: None,
        })
    }

    /// Execute a single stage of subtasks in parallel (with rate limiting)
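    ///
    /// Concurrency is bounded by a semaphore sized to
    /// `config.max_concurrent_requests`, and each spawned subtask staggers its
    /// start by `config.request_delay_ms * index` to avoid bursting the provider.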
    async fn execute_stage(
        &self,
        orchestrator: &Orchestrator,
        subtasks: Vec<SubTask>,
        completed_results: Arc<RwLock<HashMap<String, String>>>,
    ) -> Result<Vec<SubTaskResult>> {
        let mut handles: Vec<tokio::task::JoinHandle<Result<SubTaskResult, anyhow::Error>>> = Vec::new();

        // Rate limiting: semaphore for max concurrent requests
        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.max_concurrent_requests));
        let delay_ms = self.config.request_delay_ms;

        // Get provider info for tool registry
        let model = orchestrator.model().to_string();
        let provider_name = orchestrator.provider().to_string();
        let providers = orchestrator.providers();
        let provider = providers.get(&provider_name)
            .ok_or_else(|| anyhow::anyhow!("Provider {} not found", provider_name))?;

        tracing::info!(provider_name = %provider_name, "Selected provider for subtask execution");

        // Create shared tool registry with provider for ralph and batch tool
        let tool_registry = ToolRegistry::with_provider_arc(Arc::clone(&provider), model.clone());
        let tool_definitions = tool_registry.definitions();

        for (idx, subtask) in subtasks.into_iter().enumerate() {
            let model = model.clone();
            let _provider_name = provider_name.clone();
            let provider = Arc::clone(&provider);

            // Get context from dependencies
            let context = {
                let completed = completed_results.read().await;
                let mut dep_context = String::new();
                for dep_id in &subtask.dependencies {
                    if let Some(result) = completed.get(dep_id) {
                        dep_context.push_str(&format!("\n--- Result from dependency {} ---\n{}\n", dep_id, result));
                    }
                }
                dep_context
            };

            let instruction = subtask.instruction.clone();
            let specialty = subtask.specialty.clone().unwrap_or_default();
            let subtask_id = subtask.id.clone();
            let max_steps = self.config.max_steps_per_subagent;
            let timeout_secs = self.config.subagent_timeout_secs;

            // Clone for the async block
            let tools = tool_definitions.clone();
            let registry = Arc::clone(&tool_registry);
            let sem = Arc::clone(&semaphore);
            let stagger_delay = delay_ms * idx as u64; // Stagger start times

            // Spawn the subtask execution with agentic tool loop
            let handle = tokio::spawn(async move {
                // Rate limiting: stagger start and acquire semaphore
                if stagger_delay > 0 {
                    tokio::time::sleep(Duration::from_millis(stagger_delay)).await;
                }
                let _permit = sem.acquire().await.expect("semaphore closed");

                let start = Instant::now();

                // Build the system prompt for this sub-agent
                // Use subtask_id to create unique PRD names for parallel execution
                let prd_filename = format!("prd_{}.json", subtask_id.replace("-", "_"));
                let system_prompt = format!(
                    "You are a {} specialist sub-agent (ID: {}). You have access to tools to complete your task.

IMPORTANT: You MUST use tools to make changes. Do not just describe what to do - actually do it using the tools available.

Available tools:
- read: Read file contents
- write: Write/create files
- edit: Edit existing files (search and replace)
- multiedit: Make multiple edits at once
- glob: Find files by pattern
- grep: Search file contents
- bash: Run shell commands
- webfetch: Fetch web pages
- prd: Generate structured PRD for complex tasks
- ralph: Run autonomous agent loop on a PRD

COMPLEX TASKS:
If your task is complex and involves multiple implementation steps, use the prd + ralph workflow:
1. Call prd({{action: 'analyze', task_description: '...'}}) to understand what's needed
2. Break down into user stories with acceptance criteria
3. Call prd({{action: 'save', prd_path: '{}', project: '...', feature: '...', stories: [...]}})
4. Call ralph({{action: 'run', prd_path: '{}'}}) to execute

NOTE: Use your unique PRD file '{}' so parallel agents don't conflict.

When done, provide a brief summary of what you accomplished.",
                    specialty,
                    subtask_id,
                    prd_filename,
                    prd_filename,
                    prd_filename
                );

                let user_prompt = if context.is_empty() {
                    format!("Complete this task:\n\n{}", instruction)
                } else {
                    format!(
                        "Complete this task:\n\n{}\n\nContext from prior work:\n{}",
                        instruction, context
                    )
                };

                // Run the agentic loop
                let result = run_agent_loop(
                    provider,
                    &model,
                    &system_prompt,
                    &user_prompt,
                    tools,
                    registry,
                    max_steps,
                    timeout_secs,
                ).await;

                match result {
                    Ok((output, steps, tool_calls)) => {
                        Ok(SubTaskResult {
                            subtask_id: subtask_id.clone(),
                            subagent_id: format!("agent-{}", subtask_id),
                            success: true,
                            result: output,
                            steps,
                            tool_calls,
                            execution_time_ms: start.elapsed().as_millis() as u64,
                            error: None,
                            artifacts: Vec::new(),
                        })
                    }
                    Err(e) => {
                        Ok(SubTaskResult {
                            subtask_id: subtask_id.clone(),
                            subagent_id: format!("agent-{}", subtask_id),
                            success: false,
                            result: String::new(),
                            steps: 0,
                            tool_calls: 0,
                            execution_time_ms: start.elapsed().as_millis() as u64,
                            error: Some(e.to_string()),
                            artifacts: Vec::new(),
                        })
                    }
                }
            });

            handles.push(handle);
        }

        // Wait for all handles
        let mut results = Vec::new();
        for handle in handles {
            match handle.await {
                Ok(Ok(result)) => results.push(result),
                Ok(Err(e)) => {
                    tracing::error!(provider_name = %provider_name, "Subtask error: {}", e);
                }
                Err(e) => {
                    tracing::error!(provider_name = %provider_name, "Task join error: {}", e);
                }
            }
        }

        Ok(results)
    }

    /// Aggregate results from all subtasks into a final result
    async fn aggregate_results(&self, results: &[SubTaskResult]) -> Result<String> {
        let mut aggregated = String::new();

        for (i, result) in results.iter().enumerate() {
            if result.success {
                aggregated.push_str(&format!(
                    "=== Subtask {} ===\n{}\n\n",
                    i + 1,
                    result.result
                ));
            } else {
                aggregated.push_str(&format!(
                    "=== Subtask {} (FAILED) ===\nError: {}\n\n",
                    i + 1,
                    result.error.as_deref().unwrap_or("Unknown error")
                ));
            }
        }

        Ok(aggregated)
    }

    /// Execute a single task without decomposition (for simple cases)
    pub async fn execute_single(&self, task: &str) -> Result<SwarmResult> {
        self.execute(task, DecompositionStrategy::None).await
    }
}

/// Builder for swarm execution
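///
/// A hedged sketch of typical builder usage; the example values are illustrative
/// and the remaining fields fall back to `SwarmConfig::default()`.
///
/// ```ignore
/// let executor = SwarmExecutorBuilder::new()
///     .max_subagents(4)            // cap the number of parallel sub-agents
///     .max_steps_per_subagent(20)  // per-agent step budget
///     .timeout_secs(600)           // per-subtask wall-clock limit
///     .parallel_enabled(true)
///     .build();
/// ```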
pub struct SwarmExecutorBuilder {
    config: SwarmConfig,
}

impl SwarmExecutorBuilder {
    pub fn new() -> Self {
        Self {
            config: SwarmConfig::default(),
        }
    }

    pub fn max_subagents(mut self, max: usize) -> Self {
        self.config.max_subagents = max;
        self
    }

    pub fn max_steps_per_subagent(mut self, max: usize) -> Self {
        self.config.max_steps_per_subagent = max;
        self
    }

    pub fn max_total_steps(mut self, max: usize) -> Self {
        self.config.max_total_steps = max;
        self
    }

    pub fn timeout_secs(mut self, secs: u64) -> Self {
        self.config.subagent_timeout_secs = secs;
        self
    }

    pub fn parallel_enabled(mut self, enabled: bool) -> Self {
        self.config.parallel_enabled = enabled;
        self
    }

    pub fn build(self) -> SwarmExecutor {
        SwarmExecutor::new(self.config)
    }
}

impl Default for SwarmExecutorBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Run the agentic loop for a sub-agent with tool execution
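///
/// Returns `(final_output, steps, total_tool_calls)`. The loop ends when the
/// model stops requesting tool calls, the step budget is exhausted, or the
/// per-subtask deadline passes.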
#[allow(clippy::too_many_arguments)]
async fn run_agent_loop(
    provider: Arc<dyn Provider>,
    model: &str,
    system_prompt: &str,
    user_prompt: &str,
    tools: Vec<crate::provider::ToolDefinition>,
    registry: Arc<ToolRegistry>,
    max_steps: usize,
    timeout_secs: u64,
) -> Result<(String, usize, usize)> {
    // Default sampling temperature; individual providers may adjust it per model
    // (e.g., K2 models need 0.6 when thinking is disabled).
    let temperature = 0.7;

    tracing::info!(
        model = %model,
        max_steps = max_steps,
        timeout_secs = timeout_secs,
        "Sub-agent starting agentic loop"
    );
    tracing::debug!(system_prompt = %system_prompt, "Sub-agent system prompt");
    tracing::debug!(user_prompt = %user_prompt, "Sub-agent user prompt");

    // Initialize conversation with system and user messages
    let mut messages = vec![
        Message {
            role: Role::System,
            content: vec![ContentPart::Text { text: system_prompt.to_string() }],
        },
        Message {
            role: Role::User,
            content: vec![ContentPart::Text { text: user_prompt.to_string() }],
        },
    ];

    let mut steps = 0;
    let mut total_tool_calls = 0;
    let mut final_output = String::new();

    let deadline = Instant::now() + Duration::from_secs(timeout_secs);

    loop {
        if steps >= max_steps {
            tracing::warn!(max_steps = max_steps, "Sub-agent reached max steps limit");
            break;
        }

        if Instant::now() > deadline {
            tracing::warn!(timeout_secs = timeout_secs, "Sub-agent timed out");
            break;
        }

        steps += 1;
        tracing::info!(step = steps, "Sub-agent step starting");

        let request = CompletionRequest {
            messages: messages.clone(),
            tools: tools.clone(),
            model: model.to_string(),
            temperature: Some(temperature),
            top_p: None,
            max_tokens: Some(8192),
            stop: Vec::new(),
        };

        let step_start = Instant::now();
        let response = timeout(
            Duration::from_secs(120),
            provider.complete(request),
        ).await??;
        let step_duration = step_start.elapsed();

        tracing::info!(
            step = steps,
            duration_ms = step_duration.as_millis() as u64,
            finish_reason = ?response.finish_reason,
            prompt_tokens = response.usage.prompt_tokens,
            completion_tokens = response.usage.completion_tokens,
            "Sub-agent step completed LLM call"
        );

        // Extract text from response
        let mut text_parts = Vec::new();
        let mut tool_calls = Vec::new();

        for part in &response.message.content {
            match part {
                ContentPart::Text { text } => {
                    text_parts.push(text.clone());
                }
                ContentPart::ToolCall { id, name, arguments } => {
                    tool_calls.push((id.clone(), name.clone(), arguments.clone()));
                }
                _ => {}
            }
        }

        // Log assistant output
        if !text_parts.is_empty() {
            final_output = text_parts.join("\n");
            tracing::info!(
                step = steps,
                output_len = final_output.len(),
                "Sub-agent text output"
            );
            tracing::debug!(step = steps, output = %final_output, "Sub-agent full output");
        }

        // Log tool calls
        if !tool_calls.is_empty() {
            tracing::info!(
                step = steps,
                num_tool_calls = tool_calls.len(),
                tools = ?tool_calls.iter().map(|(_, name, _)| name.as_str()).collect::<Vec<_>>(),
                "Sub-agent requesting tool calls"
            );
        }

        // Add assistant message to history
        messages.push(response.message.clone());

        // If no tool calls or stop, we're done
        if response.finish_reason != FinishReason::ToolCalls || tool_calls.is_empty() {
            tracing::info!(
                steps = steps,
                total_tool_calls = total_tool_calls,
                "Sub-agent finished"
            );
            break;
        }

        // Execute tool calls
        let mut tool_results = Vec::new();

        for (call_id, tool_name, arguments) in tool_calls {
            total_tool_calls += 1;

            tracing::info!(
                step = steps,
                tool_call_id = %call_id,
                tool = %tool_name,
                "Executing tool"
            );
            tracing::debug!(
                tool = %tool_name,
                arguments = %arguments,
                "Tool call arguments"
            );

            let tool_start = Instant::now();
            let result = if let Some(tool) = registry.get(&tool_name) {
                // Parse arguments as JSON
                let args: serde_json::Value = serde_json::from_str(&arguments)
                    .unwrap_or_else(|_| serde_json::json!({}));

                match tool.execute(args).await {
                    Ok(r) => {
                        if r.success {
                            tracing::info!(
                                tool = %tool_name,
                                duration_ms = tool_start.elapsed().as_millis() as u64,
                                success = true,
                                "Tool execution completed"
                            );
                            r.output
                        } else {
                            tracing::warn!(
                                tool = %tool_name,
                                error = %r.output,
                                "Tool returned error"
                            );
                            format!("Tool error: {}", r.output)
                        }
                    }
                    Err(e) => {
                        tracing::error!(
                            tool = %tool_name,
                            error = %e,
                            "Tool execution failed"
                        );
                        format!("Tool execution failed: {}", e)
                    }
                }
            } else {
                tracing::error!(tool = %tool_name, "Unknown tool requested");
                format!("Unknown tool: {}", tool_name)
            };

            tracing::debug!(
                tool = %tool_name,
                result_len = result.len(),
                "Tool result"
            );

            tool_results.push((call_id, result));
        }

        // Add tool results to conversation
        for (call_id, result) in tool_results {
            messages.push(Message {
                role: Role::Tool,
                content: vec![ContentPart::ToolResult {
                    tool_call_id: call_id,
                    content: result,
                }],
            });
        }
    }

    Ok((final_output, steps, total_tool_calls))
}