Skip to main content

codetether_agent/swarm/
orchestrator.rs

//! Orchestrator for decomposing tasks and coordinating sub-agents.
//!
//! The orchestrator analyzes complex tasks, decomposes them into
//! parallelizable subtasks, and then coordinates their execution.
5
6use super::{
7    DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus, SwarmConfig, SwarmStats,
8};
9use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
10use anyhow::Result;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14/// The orchestrator manages task decomposition and sub-agent coordination
15pub struct Orchestrator {
16    /// Configuration
17    config: SwarmConfig,
18
19    /// Provider registry for AI calls
20    providers: ProviderRegistry,
21
22    /// All subtasks
23    subtasks: HashMap<String, SubTask>,
24
25    /// All sub-agents
26    subagents: HashMap<String, SubAgent>,
27
28    /// Completed subtask IDs
29    completed: Vec<String>,
30
31    /// Current model for orchestration
32    model: String,
33
34    /// Current provider
35    provider: String,
36
37    /// Stats
38    stats: SwarmStats,
39}
40
41impl Orchestrator {
42    /// Create a new orchestrator
43    pub async fn new(config: SwarmConfig) -> Result<Self> {
44        use crate::provider::parse_model_string;
45
46        let providers = ProviderRegistry::from_vault().await?;
47        let provider_list = providers.list();
48
49        if provider_list.is_empty() {
50            anyhow::bail!("No providers available for orchestration");
51        }
52
53        // Parse model from config, env var, or use default
54        let model_str = config
55            .model
56            .clone()
57            .or_else(|| std::env::var("CODETETHER_DEFAULT_MODEL").ok());
58
59        let (provider, model) = if let Some(ref model_str) = model_str {
60            let (prov, mod_id) = parse_model_string(model_str);
61            let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
62            let provider = prov
63                .filter(|p| provider_list.contains(p))
64                .unwrap_or(provider_list[0])
65                .to_string();
66            let model = mod_id.to_string();
67            (provider, model)
68        } else {
69            // Default to GLM-5 via Z.AI for swarm.
70            let provider = if provider_list.contains(&"zai") {
71                "zai".to_string()
72            } else if provider_list.contains(&"openrouter") {
73                "openrouter".to_string()
74            } else {
75                provider_list[0].to_string()
76            };
77            let model = Self::default_model_for_provider(&provider);
78            (provider, model)
79        };
80
81        tracing::info!("Orchestrator using model {} via {}", model, provider);
82
83        Ok(Self {
84            config,
85            providers,
86            subtasks: HashMap::new(),
87            subagents: HashMap::new(),
88            completed: Vec::new(),
89            model,
90            provider,
91            stats: SwarmStats::default(),
92        })
93    }
94
95    /// Get default model for a provider
96    fn default_model_for_provider(provider: &str) -> String {
97        match provider {
98            "moonshotai" => "kimi-k2.5".to_string(),
99            "anthropic" => "claude-sonnet-4-20250514".to_string(),
100            "openai" => "gpt-4o".to_string(),
101            "google" => "gemini-2.5-pro".to_string(),
102            "zhipuai" | "zai" => "glm-5".to_string(),
103            "openrouter" => "z-ai/glm-5".to_string(),
104            "novita" => "qwen/qwen3-coder-next".to_string(),
105            "github-copilot" | "github-copilot-enterprise" => "gpt-5-mini".to_string(),
106            _ => "glm-5".to_string(),
107        }
108    }
109
110    /// Decompose a complex task into subtasks
111    pub async fn decompose(
112        &mut self,
113        task: &str,
114        strategy: DecompositionStrategy,
115    ) -> Result<Vec<SubTask>> {
116        if strategy == DecompositionStrategy::None {
117            // Single task, no decomposition
118            let subtask = SubTask::new("Main Task", task);
119            self.subtasks.insert(subtask.id.clone(), subtask.clone());
120            return Ok(vec![subtask]);
121        }
122
123        // Use AI to decompose the task
124        let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
125
126        let provider = self
127            .providers
128            .get(&self.provider)
129            .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;
130
131        let temperature = if self.model.contains("kimi-k2") || self.model.contains("glm-") {
132            1.0
133        } else {
134            0.7
135        };
136
137        let request = CompletionRequest {
138            messages: vec![Message {
139                role: Role::User,
140                content: vec![ContentPart::Text {
141                    text: decomposition_prompt,
142                }],
143            }],
144            tools: Vec::new(),
145            model: self.model.clone(),
146            temperature: Some(temperature),
147            top_p: None,
148            max_tokens: Some(8192),
149            stop: Vec::new(),
150        };
151
152        let response = provider.complete(request).await?;
153
154        // Parse the decomposition response
155        let text = response
156            .message
157            .content
158            .iter()
159            .filter_map(|p| match p {
160                ContentPart::Text { text } => Some(text.clone()),
161                _ => None,
162            })
163            .collect::<Vec<_>>()
164            .join("\n");
165
166        tracing::debug!("Decomposition response: {}", text);
167
168        if text.trim().is_empty() {
169            // Fallback to single task if decomposition fails
170            tracing::warn!("Empty decomposition response, falling back to single task");
171            let subtask = SubTask::new("Main Task", task);
172            self.subtasks.insert(subtask.id.clone(), subtask.clone());
173            return Ok(vec![subtask]);
174        }
175
176        let subtasks = self.parse_decomposition(&text)?;
177
178        // Store subtasks
179        for subtask in &subtasks {
180            self.subtasks.insert(subtask.id.clone(), subtask.clone());
181        }
182
183        // Assign stages based on dependencies
184        self.assign_stages();
185
186        tracing::info!(
187            "Decomposed task into {} subtasks across {} stages",
188            subtasks.len(),
189            self.max_stage() + 1
190        );
191
192        Ok(subtasks)
193    }
194
195    /// Build the decomposition prompt
196    fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
197        let strategy_instruction = match strategy {
198            DecompositionStrategy::Automatic => {
199                "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
200            }
201            DecompositionStrategy::ByDomain => {
202                "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
203            }
204            DecompositionStrategy::ByData => {
205                "Decompose the task by data partition (e.g., different files, sections, or datasets)."
206            }
207            DecompositionStrategy::ByStage => {
208                "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
209            }
210            DecompositionStrategy::None => unreachable!(),
211        };
212
213        format!(
214            r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
215
216TASK: {task}
217
218STRATEGY: {strategy_instruction}
219
220CONSTRAINTS:
221- Maximum {max_subtasks} subtasks
222- Each subtask should be independently executable
223- Identify dependencies between subtasks (which must complete before others can start)
224- Assign a specialty/role to each subtask
225
226OUTPUT FORMAT (JSON):
227```json
228{{
229  "subtasks": [
230    {{
231      "name": "Subtask Name",
232      "instruction": "Detailed instruction for this subtask",
233      "specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
234      "dependencies": ["id-of-dependency-1"],
235      "priority": 1
236    }}
237  ]
238}}
239```
240
241Decompose the task now:"#,
242            task = task,
243            strategy_instruction = strategy_instruction,
244            max_subtasks = self.config.max_subagents,
245        )
246    }
247
248    /// Parse the decomposition response
249    fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
250        // Try to extract JSON from the response
251        let json_str = if let Some(start) = response.find("```json") {
252            let start = start + 7;
253            if let Some(end) = response[start..].find("```") {
254                &response[start..start + end]
255            } else {
256                response
257            }
258        } else if let Some(start) = response.find('{') {
259            if let Some(end) = response.rfind('}') {
260                &response[start..=end]
261            } else {
262                response
263            }
264        } else {
265            response
266        };
267
268        #[derive(Deserialize)]
269        struct DecompositionResponse {
270            subtasks: Vec<SubTaskDef>,
271        }
272
273        #[derive(Deserialize)]
274        struct SubTaskDef {
275            name: String,
276            instruction: String,
277            specialty: Option<String>,
278            #[serde(default)]
279            dependencies: Vec<String>,
280            #[serde(default)]
281            priority: i32,
282        }
283
284        let parsed: DecompositionResponse = serde_json::from_str(json_str.trim())
285            .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;
286
287        // Create SubTask objects with proper IDs
288        let mut subtasks = Vec::new();
289        let mut name_to_id: HashMap<String, String> = HashMap::new();
290
291        // First pass: create subtasks and map names to IDs
292        for def in &parsed.subtasks {
293            let subtask = SubTask::new(&def.name, &def.instruction).with_priority(def.priority);
294
295            let subtask = if let Some(ref specialty) = def.specialty {
296                subtask.with_specialty(specialty)
297            } else {
298                subtask
299            };
300
301            name_to_id.insert(def.name.clone(), subtask.id.clone());
302            subtasks.push((subtask, def.dependencies.clone()));
303        }
304
305        // Second pass: resolve dependencies
306        let result: Vec<SubTask> = subtasks
307            .into_iter()
308            .map(|(mut subtask, deps)| {
309                let resolved_deps: Vec<String> = deps
310                    .iter()
311                    .filter_map(|dep| name_to_id.get(dep).cloned())
312                    .collect();
313                subtask.dependencies = resolved_deps;
314                subtask
315            })
316            .collect();
317
318        Ok(result)
319    }
320
321    /// Assign stages to subtasks based on dependencies
322    fn assign_stages(&mut self) {
323        let mut changed = true;
324
325        while changed {
326            changed = false;
327
328            // First collect all updates needed
329            let updates: Vec<(String, usize)> = self
330                .subtasks
331                .iter()
332                .filter_map(|(id, subtask)| {
333                    if subtask.dependencies.is_empty() {
334                        if subtask.stage != 0 {
335                            Some((id.clone(), 0))
336                        } else {
337                            None
338                        }
339                    } else {
340                        let max_dep_stage = subtask
341                            .dependencies
342                            .iter()
343                            .filter_map(|dep_id| self.subtasks.get(dep_id))
344                            .map(|dep| dep.stage)
345                            .max()
346                            .unwrap_or(0);
347
348                        let new_stage = max_dep_stage + 1;
349                        if subtask.stage != new_stage {
350                            Some((id.clone(), new_stage))
351                        } else {
352                            None
353                        }
354                    }
355                })
356                .collect();
357
358            // Then apply updates
359            for (id, new_stage) in updates {
360                if let Some(subtask) = self.subtasks.get_mut(&id) {
361                    subtask.stage = new_stage;
362                    changed = true;
363                }
364            }
365        }
366    }
367
368    /// Get maximum stage number
369    fn max_stage(&self) -> usize {
370        self.subtasks.values().map(|s| s.stage).max().unwrap_or(0)
371    }
372
373    /// Get subtasks ready to execute (dependencies satisfied)
374    pub fn ready_subtasks(&self) -> Vec<&SubTask> {
375        self.subtasks
376            .values()
377            .filter(|s| s.status == SubTaskStatus::Pending && s.can_run(&self.completed))
378            .collect()
379    }
380
381    /// Get subtasks for a specific stage
382    pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
383        self.subtasks
384            .values()
385            .filter(|s| s.stage == stage)
386            .collect()
387    }
388
389    /// Create a sub-agent for a subtask
390    pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
391        let specialty = subtask
392            .specialty
393            .clone()
394            .unwrap_or_else(|| "General".to_string());
395        let name = format!("{} Agent", specialty);
396
397        let subagent = SubAgent::new(name, specialty, &subtask.id, &self.model, &self.provider);
398
399        self.subagents.insert(subagent.id.clone(), subagent.clone());
400        self.stats.subagents_spawned += 1;
401
402        subagent
403    }
404
405    /// Mark a subtask as completed
406    pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
407        if let Some(subtask) = self.subtasks.get_mut(subtask_id) {
408            subtask.complete(result.success);
409
410            if result.success {
411                self.completed.push(subtask_id.to_string());
412                self.stats.subagents_completed += 1;
413            } else {
414                self.stats.subagents_failed += 1;
415            }
416
417            self.stats.total_tool_calls += result.tool_calls;
418        }
419    }
420
421    /// Get all subtasks
422    pub fn all_subtasks(&self) -> Vec<&SubTask> {
423        self.subtasks.values().collect()
424    }
425
426    /// Get statistics
427    pub fn stats(&self) -> &SwarmStats {
428        &self.stats
429    }
430
431    /// Get mutable statistics
432    pub fn stats_mut(&mut self) -> &mut SwarmStats {
433        &mut self.stats
434    }
435
436    /// Check if all subtasks are complete
437    pub fn is_complete(&self) -> bool {
438        self.subtasks.values().all(|s| {
439            matches!(
440                s.status,
441                SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled
442            )
443        })
444    }
445
446    /// Get the provider registry
447    pub fn providers(&self) -> &ProviderRegistry {
448        &self.providers
449    }
450
451    /// Get current model
452    pub fn model(&self) -> &str {
453        &self.model
454    }
455
456    /// Get current provider
457    pub fn provider(&self) -> &str {
458        &self.provider
459    }
460}
461
462/// Message from sub-agent to orchestrator
463#[derive(Debug, Clone, Serialize, Deserialize)]
464pub enum SubAgentMessage {
465    /// Progress update
466    Progress {
467        subagent_id: String,
468        subtask_id: String,
469        steps: usize,
470        status: String,
471    },
472
473    /// Tool call made
474    ToolCall {
475        subagent_id: String,
476        tool_name: String,
477        success: bool,
478    },
479
480    /// Subtask completed
481    Completed {
482        subagent_id: String,
483        result: SubTaskResult,
484    },
485
486    /// Request for resources
487    ResourceRequest {
488        subagent_id: String,
489        resource_type: String,
490        resource_id: String,
491    },
492}
493
494/// Message from orchestrator to sub-agent
495#[derive(Debug, Clone, Serialize, Deserialize)]
496pub enum OrchestratorMessage {
497    /// Start execution
498    Start { subtask: Box<SubTask> },
499
500    /// Provide resource
501    Resource {
502        resource_type: String,
503        resource_id: String,
504        content: String,
505    },
506
507    /// Terminate execution
508    Terminate { reason: String },
509
510    /// Context update (from completed dependency)
511    ContextUpdate {
512        dependency_id: String,
513        result: String,
514    },
515}