Skip to main content

codetether_agent/swarm/
orchestrator.rs

1//! Orchestrator for decomposing tasks and coordinating sub-agents
2//!
3//! The orchestrator analyzes complex tasks and decomposes them into
4//! parallelizable subtasks, then coordinates their execution.
5
6use super::{
7    DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus,
8    SwarmConfig, SwarmStats,
9};
10use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
11use anyhow::Result;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// The orchestrator manages task decomposition and sub-agent coordination.
///
/// It owns the subtasks produced by `decompose`, the sub-agents created for
/// them, the list of successfully completed subtask IDs, and the
/// provider/model pair used for the orchestration (decomposition) call.
pub struct Orchestrator {
    /// Swarm configuration; this file reads `model` and `max_subagents` from it
    config: SwarmConfig,

    /// Provider registry used to look up the provider for AI calls
    providers: ProviderRegistry,

    /// All known subtasks, keyed by subtask ID
    subtasks: HashMap<String, SubTask>,

    /// All sub-agents created so far, keyed by sub-agent ID
    subagents: HashMap<String, SubAgent>,

    /// IDs of subtasks that completed successfully (failed subtasks are not added)
    completed: Vec<String>,

    /// Model ID used for orchestration requests
    model: String,

    /// Name of the provider the orchestration requests are sent to
    provider: String,

    /// Running counters (sub-agents spawned/completed/failed, tool calls)
    stats: SwarmStats,
}
41
42impl Orchestrator {
43    /// Create a new orchestrator
44    pub async fn new(config: SwarmConfig) -> Result<Self> {
45        use crate::provider::parse_model_string;
46        
47        let providers = ProviderRegistry::from_vault().await?;
48        let provider_list = providers.list();
49        
50        if provider_list.is_empty() {
51            anyhow::bail!("No providers available for orchestration");
52        }
53        
54        // Parse model from config or use default
55        let (provider, model) = if let Some(ref model_str) = config.model {
56            let (prov, mod_id) = parse_model_string(model_str);
57            let provider = prov
58                .filter(|p| provider_list.contains(p))
59                .unwrap_or(provider_list[0])
60                .to_string();
61            let model = mod_id.to_string();
62            (provider, model)
63        } else {
64            let provider = provider_list[0].to_string();
65            let model = Self::default_model_for_provider(&provider);
66            (provider, model)
67        };
68        
69        tracing::info!("Orchestrator using model {} via {}", model, provider);
70        
71        Ok(Self {
72            config,
73            providers,
74            subtasks: HashMap::new(),
75            subagents: HashMap::new(),
76            completed: Vec::new(),
77            model,
78            provider,
79            stats: SwarmStats::default(),
80        })
81    }
82    
83    /// Get default model for a provider
84    fn default_model_for_provider(provider: &str) -> String {
85        match provider {
86            "moonshotai" => "kimi-k2.5".to_string(),
87            "anthropic" => "claude-sonnet-4-20250514".to_string(),
88            "openai" => "gpt-4o".to_string(),
89            "google" => "gemini-2.5-pro".to_string(),
90            "openrouter" => "stepfun/step-3.5-flash:free".to_string(),
91            _ => "kimi-k2.5".to_string(),
92        }
93    }
94    
95    /// Decompose a complex task into subtasks
96    pub async fn decompose(
97        &mut self,
98        task: &str,
99        strategy: DecompositionStrategy,
100    ) -> Result<Vec<SubTask>> {
101        if strategy == DecompositionStrategy::None {
102            // Single task, no decomposition
103            let subtask = SubTask::new("Main Task", task);
104            self.subtasks.insert(subtask.id.clone(), subtask.clone());
105            return Ok(vec![subtask]);
106        }
107        
108        // Use AI to decompose the task
109        let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
110        
111        let provider = self.providers.get(&self.provider)
112            .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;
113        
114        let temperature = if self.model.starts_with("kimi-k2") { 1.0 } else { 0.7 };
115        
116        let request = CompletionRequest {
117            messages: vec![Message {
118                role: Role::User,
119                content: vec![ContentPart::Text { text: decomposition_prompt }],
120            }],
121            tools: Vec::new(),
122            model: self.model.clone(),
123            temperature: Some(temperature),
124            top_p: None,
125            max_tokens: Some(8192),
126            stop: Vec::new(),
127        };
128        
129        let response = provider.complete(request).await?;
130        
131        // Parse the decomposition response
132        let text = response.message.content
133            .iter()
134            .filter_map(|p| match p {
135                ContentPart::Text { text } => Some(text.clone()),
136                _ => None,
137            })
138            .collect::<Vec<_>>()
139            .join("\n");
140        
141        tracing::debug!("Decomposition response: {}", text);
142        
143        if text.trim().is_empty() {
144            // Fallback to single task if decomposition fails
145            tracing::warn!("Empty decomposition response, falling back to single task");
146            let subtask = SubTask::new("Main Task", task);
147            self.subtasks.insert(subtask.id.clone(), subtask.clone());
148            return Ok(vec![subtask]);
149        }
150        
151        let subtasks = self.parse_decomposition(&text)?;
152        
153        // Store subtasks
154        for subtask in &subtasks {
155            self.subtasks.insert(subtask.id.clone(), subtask.clone());
156        }
157        
158        // Assign stages based on dependencies
159        self.assign_stages();
160        
161        tracing::info!(
162            "Decomposed task into {} subtasks across {} stages",
163            subtasks.len(),
164            self.max_stage() + 1
165        );
166        
167        Ok(subtasks)
168    }
169    
170    /// Build the decomposition prompt
171    fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
172        let strategy_instruction = match strategy {
173            DecompositionStrategy::Automatic => {
174                "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
175            }
176            DecompositionStrategy::ByDomain => {
177                "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
178            }
179            DecompositionStrategy::ByData => {
180                "Decompose the task by data partition (e.g., different files, sections, or datasets)."
181            }
182            DecompositionStrategy::ByStage => {
183                "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
184            }
185            DecompositionStrategy::None => unreachable!(),
186        };
187        
188        format!(
189            r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
190
191TASK: {task}
192
193STRATEGY: {strategy_instruction}
194
195CONSTRAINTS:
196- Maximum {max_subtasks} subtasks
197- Each subtask should be independently executable
198- Identify dependencies between subtasks (which must complete before others can start)
199- Assign a specialty/role to each subtask
200
201OUTPUT FORMAT (JSON):
202```json
203{{
204  "subtasks": [
205    {{
206      "name": "Subtask Name",
207      "instruction": "Detailed instruction for this subtask",
208      "specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
209      "dependencies": ["id-of-dependency-1"],
210      "priority": 1
211    }}
212  ]
213}}
214```
215
216Decompose the task now:"#,
217            task = task,
218            strategy_instruction = strategy_instruction,
219            max_subtasks = self.config.max_subagents,
220        )
221    }
222    
223    /// Parse the decomposition response
224    fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
225        // Try to extract JSON from the response
226        let json_str = if let Some(start) = response.find("```json") {
227            let start = start + 7;
228            if let Some(end) = response[start..].find("```") {
229                &response[start..start + end]
230            } else {
231                response
232            }
233        } else if let Some(start) = response.find('{') {
234            if let Some(end) = response.rfind('}') {
235                &response[start..=end]
236            } else {
237                response
238            }
239        } else {
240            response
241        };
242        
243        #[derive(Deserialize)]
244        struct DecompositionResponse {
245            subtasks: Vec<SubTaskDef>,
246        }
247        
248        #[derive(Deserialize)]
249        struct SubTaskDef {
250            name: String,
251            instruction: String,
252            specialty: Option<String>,
253            #[serde(default)]
254            dependencies: Vec<String>,
255            #[serde(default)]
256            priority: i32,
257        }
258        
259        let parsed: DecompositionResponse = serde_json::from_str(json_str.trim())
260            .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;
261        
262        // Create SubTask objects with proper IDs
263        let mut subtasks = Vec::new();
264        let mut name_to_id: HashMap<String, String> = HashMap::new();
265        
266        // First pass: create subtasks and map names to IDs
267        for def in &parsed.subtasks {
268            let subtask = SubTask::new(&def.name, &def.instruction)
269                .with_priority(def.priority);
270            
271            let subtask = if let Some(ref specialty) = def.specialty {
272                subtask.with_specialty(specialty)
273            } else {
274                subtask
275            };
276            
277            name_to_id.insert(def.name.clone(), subtask.id.clone());
278            subtasks.push((subtask, def.dependencies.clone()));
279        }
280        
281        // Second pass: resolve dependencies
282        let result: Vec<SubTask> = subtasks
283            .into_iter()
284            .map(|(mut subtask, deps)| {
285                let resolved_deps: Vec<String> = deps
286                    .iter()
287                    .filter_map(|dep| name_to_id.get(dep).cloned())
288                    .collect();
289                subtask.dependencies = resolved_deps;
290                subtask
291            })
292            .collect();
293        
294        Ok(result)
295    }
296    
297    /// Assign stages to subtasks based on dependencies
298    fn assign_stages(&mut self) {
299        let mut changed = true;
300        
301        while changed {
302            changed = false;
303            
304            // First collect all updates needed
305            let updates: Vec<(String, usize)> = self.subtasks.iter().filter_map(|(id, subtask)| {
306                if subtask.dependencies.is_empty() {
307                    if subtask.stage != 0 {
308                        Some((id.clone(), 0))
309                    } else {
310                        None
311                    }
312                } else {
313                    let max_dep_stage = subtask
314                        .dependencies
315                        .iter()
316                        .filter_map(|dep_id| self.subtasks.get(dep_id))
317                        .map(|dep| dep.stage)
318                        .max()
319                        .unwrap_or(0);
320                    
321                    let new_stage = max_dep_stage + 1;
322                    if subtask.stage != new_stage {
323                        Some((id.clone(), new_stage))
324                    } else {
325                        None
326                    }
327                }
328            }).collect();
329            
330            // Then apply updates
331            for (id, new_stage) in updates {
332                if let Some(subtask) = self.subtasks.get_mut(&id) {
333                    subtask.stage = new_stage;
334                    changed = true;
335                }
336            }
337        }
338    }
339    
340    /// Get maximum stage number
341    fn max_stage(&self) -> usize {
342        self.subtasks.values().map(|s| s.stage).max().unwrap_or(0)
343    }
344    
345    /// Get subtasks ready to execute (dependencies satisfied)
346    pub fn ready_subtasks(&self) -> Vec<&SubTask> {
347        self.subtasks
348            .values()
349            .filter(|s| s.status == SubTaskStatus::Pending && s.can_run(&self.completed))
350            .collect()
351    }
352    
353    /// Get subtasks for a specific stage
354    pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
355        self.subtasks
356            .values()
357            .filter(|s| s.stage == stage)
358            .collect()
359    }
360    
361    /// Create a sub-agent for a subtask
362    pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
363        let specialty = subtask.specialty.clone().unwrap_or_else(|| "General".to_string());
364        let name = format!("{} Agent", specialty);
365        
366        let subagent = SubAgent::new(
367            name,
368            specialty,
369            &subtask.id,
370            &self.model,
371            &self.provider,
372        );
373        
374        self.subagents.insert(subagent.id.clone(), subagent.clone());
375        self.stats.subagents_spawned += 1;
376        
377        subagent
378    }
379    
380    /// Mark a subtask as completed
381    pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
382        if let Some(subtask) = self.subtasks.get_mut(subtask_id) {
383            subtask.complete(result.success);
384            
385            if result.success {
386                self.completed.push(subtask_id.to_string());
387                self.stats.subagents_completed += 1;
388            } else {
389                self.stats.subagents_failed += 1;
390            }
391            
392            self.stats.total_tool_calls += result.tool_calls;
393        }
394    }
395    
396    /// Get all subtasks
397    pub fn all_subtasks(&self) -> Vec<&SubTask> {
398        self.subtasks.values().collect()
399    }
400    
401    /// Get statistics
402    pub fn stats(&self) -> &SwarmStats {
403        &self.stats
404    }
405    
406    /// Get mutable statistics
407    pub fn stats_mut(&mut self) -> &mut SwarmStats {
408        &mut self.stats
409    }
410    
411    /// Check if all subtasks are complete
412    pub fn is_complete(&self) -> bool {
413        self.subtasks.values().all(|s| {
414            matches!(s.status, SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled)
415        })
416    }
417    
418    /// Get the provider registry
419    pub fn providers(&self) -> &ProviderRegistry {
420        &self.providers
421    }
422    
423    /// Get current model
424    pub fn model(&self) -> &str {
425        &self.model
426    }
427    
428    /// Get current provider
429    pub fn provider(&self) -> &str {
430        &self.provider
431    }
432}
433
/// Message from sub-agent to orchestrator.
///
/// Serializable with serde. Nothing in this file constructs or consumes these
/// messages; the field notes below describe the data carried, not a protocol.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubAgentMessage {
    /// Progress update
    Progress {
        /// ID of the reporting sub-agent
        subagent_id: String,
        /// ID of the subtask the update refers to
        subtask_id: String,
        /// Step count reported by the sub-agent
        steps: usize,
        /// Free-form status text
        status: String,
    },

    /// Tool call made
    ToolCall {
        /// ID of the sub-agent that made the call
        subagent_id: String,
        /// Name of the tool that was called
        tool_name: String,
        /// Whether the tool call succeeded
        success: bool,
    },

    /// Subtask completed
    Completed {
        /// ID of the sub-agent that finished
        subagent_id: String,
        /// Outcome of the subtask
        result: SubTaskResult,
    },

    /// Request for resources
    ResourceRequest {
        /// ID of the requesting sub-agent
        subagent_id: String,
        /// Kind of resource being requested
        resource_type: String,
        /// Identifier of the requested resource
        resource_id: String,
    },
}
465
/// Message from orchestrator to sub-agent.
///
/// Serializable with serde. Nothing in this file constructs or consumes these
/// messages; the field notes below describe the data carried, not a protocol.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrchestratorMessage {
    /// Start execution
    Start {
        /// The subtask to execute (boxed, presumably to keep the enum small)
        subtask: Box<SubTask>,
    },

    /// Provide resource
    Resource {
        /// Kind of resource being provided
        resource_type: String,
        /// Identifier of the provided resource
        resource_id: String,
        /// The resource content as text
        content: String,
    },

    /// Terminate execution
    Terminate {
        /// Why execution is being terminated
        reason: String,
    },

    /// Context update (from completed dependency)
    ContextUpdate {
        /// ID of the dependency the update comes from
        dependency_id: String,
        /// Result text of that dependency
        result: String,
    },
}
491}