Skip to main content

codetether_agent/swarm/
orchestrator.rs

1//! Orchestrator for decomposing tasks and coordinating sub-agents
2//!
3//! The orchestrator analyzes complex tasks and decomposes them into
4//! parallelizable subtasks, then coordinates their execution.
5
6use super::{
7    DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus, SwarmConfig, SwarmStats,
8};
9use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
10use anyhow::Result;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// The orchestrator manages task decomposition and sub-agent coordination.
///
/// It owns every subtask and sub-agent it creates (keyed by their IDs), the
/// provider/model pair used for the decomposition request, and the running
/// [`SwarmStats`] counters.
pub struct Orchestrator {
    /// Swarm configuration; `max_subagents` caps the subtask count requested
    /// in the decomposition prompt and `model` may select the model to use
    config: SwarmConfig,

    /// Provider registry for AI calls
    providers: ProviderRegistry,

    /// All subtasks, keyed by subtask ID
    subtasks: HashMap<String, SubTask>,

    /// All sub-agents, keyed by sub-agent ID
    subagents: HashMap<String, SubAgent>,

    /// IDs of subtasks that completed successfully (failed ones are not added)
    completed: Vec<String>,

    /// Model ID used for orchestration requests
    model: String,

    /// Name of the provider the model is requested from
    provider: String,

    /// Running counters (sub-agents spawned/completed/failed, tool calls)
    stats: SwarmStats,
}
40
41impl Orchestrator {
42    /// Create a new orchestrator
43    pub async fn new(config: SwarmConfig) -> Result<Self> {
44        use crate::provider::parse_model_string;
45
46        let providers = ProviderRegistry::from_vault().await?;
47        let provider_list = providers.list();
48
49        if provider_list.is_empty() {
50            anyhow::bail!("No providers available for orchestration");
51        }
52
53        // Parse model from config, env var, or use default
54        let model_str = config
55            .model
56            .clone()
57            .or_else(|| std::env::var("CODETETHER_DEFAULT_MODEL").ok());
58
59        let (provider, model) = if let Some(ref model_str) = model_str {
60            let (prov, mod_id) = parse_model_string(model_str);
61            let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
62            let provider = prov
63                .filter(|p| provider_list.contains(p))
64                .unwrap_or(provider_list[0])
65                .to_string();
66            let model = mod_id.to_string();
67            (provider, model)
68        } else {
69            // Default to GLM-5 via Z.AI for swarm.
70            let provider = if provider_list.contains(&"zai") {
71                "zai".to_string()
72            } else if provider_list.contains(&"openrouter") {
73                "openrouter".to_string()
74            } else {
75                provider_list[0].to_string()
76            };
77            let model = Self::default_model_for_provider(&provider);
78            (provider, model)
79        };
80
81        tracing::info!("Orchestrator using model {} via {}", model, provider);
82
83        Ok(Self {
84            config,
85            providers,
86            subtasks: HashMap::new(),
87            subagents: HashMap::new(),
88            completed: Vec::new(),
89            model,
90            provider,
91            stats: SwarmStats::default(),
92        })
93    }
94
95    /// Get default model for a provider
96    fn default_model_for_provider(provider: &str) -> String {
97        match provider {
98            "moonshotai" => "kimi-k2.5".to_string(),
99            "anthropic" => "claude-sonnet-4-20250514".to_string(),
100            "openai" => "gpt-4o".to_string(),
101            "google" => "gemini-2.5-pro".to_string(),
102            "zhipuai" | "zai" => "glm-5".to_string(),
103            "openrouter" => "z-ai/glm-5".to_string(),
104            "novita" => "qwen/qwen3-coder-next".to_string(),
105            "github-copilot" | "github-copilot-enterprise" => "gpt-5-mini".to_string(),
106            _ => "glm-5".to_string(),
107        }
108    }
109
110    fn prefers_temperature_one(model: &str) -> bool {
111        let normalized = model.to_ascii_lowercase();
112        normalized.contains("kimi-k2")
113            || normalized.contains("glm-")
114            || normalized.contains("minimax")
115    }
116
117    /// Decompose a complex task into subtasks
118    pub async fn decompose(
119        &mut self,
120        task: &str,
121        strategy: DecompositionStrategy,
122    ) -> Result<Vec<SubTask>> {
123        if strategy == DecompositionStrategy::None {
124            // Single task, no decomposition
125            let subtask = SubTask::new("Main Task", task);
126            self.subtasks.insert(subtask.id.clone(), subtask.clone());
127            return Ok(vec![subtask]);
128        }
129
130        // Use AI to decompose the task
131        let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
132
133        let provider = self
134            .providers
135            .get(&self.provider)
136            .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;
137
138        let temperature = if Self::prefers_temperature_one(&self.model) {
139            1.0
140        } else {
141            0.7
142        };
143
144        let request = CompletionRequest {
145            messages: vec![Message {
146                role: Role::User,
147                content: vec![ContentPart::Text {
148                    text: decomposition_prompt,
149                }],
150            }],
151            tools: Vec::new(),
152            model: self.model.clone(),
153            temperature: Some(temperature),
154            top_p: None,
155            max_tokens: Some(8192),
156            stop: Vec::new(),
157        };
158
159        let response = provider.complete(request).await?;
160
161        // Parse the decomposition response
162        let text = response
163            .message
164            .content
165            .iter()
166            .filter_map(|p| match p {
167                ContentPart::Text { text } => Some(text.clone()),
168                _ => None,
169            })
170            .collect::<Vec<_>>()
171            .join("\n");
172
173        tracing::debug!("Decomposition response: {}", text);
174
175        if text.trim().is_empty() {
176            // Fallback to single task if decomposition fails
177            tracing::warn!("Empty decomposition response, falling back to single task");
178            let subtask = SubTask::new("Main Task", task);
179            self.subtasks.insert(subtask.id.clone(), subtask.clone());
180            return Ok(vec![subtask]);
181        }
182
183        let subtasks = self.parse_decomposition(&text)?;
184
185        // Store subtasks
186        for subtask in &subtasks {
187            self.subtasks.insert(subtask.id.clone(), subtask.clone());
188        }
189
190        // Assign stages based on dependencies
191        self.assign_stages();
192
193        tracing::info!(
194            "Decomposed task into {} subtasks across {} stages",
195            subtasks.len(),
196            self.max_stage() + 1
197        );
198
199        Ok(subtasks)
200    }
201
202    /// Build the decomposition prompt
203    fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
204        let strategy_instruction = match strategy {
205            DecompositionStrategy::Automatic => {
206                "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
207            }
208            DecompositionStrategy::ByDomain => {
209                "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
210            }
211            DecompositionStrategy::ByData => {
212                "Decompose the task by data partition (e.g., different files, sections, or datasets)."
213            }
214            DecompositionStrategy::ByStage => {
215                "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
216            }
217            DecompositionStrategy::None => unreachable!(),
218        };
219
220        format!(
221            r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
222
223TASK: {task}
224
225STRATEGY: {strategy_instruction}
226
227CONSTRAINTS:
228- Maximum {max_subtasks} subtasks
229- Each subtask should be independently executable
230- Identify dependencies between subtasks (which must complete before others can start)
231- Assign a specialty/role to each subtask
232
233OUTPUT FORMAT (JSON):
234```json
235{{
236  "subtasks": [
237    {{
238      "name": "Subtask Name",
239      "instruction": "Detailed instruction for this subtask",
240      "specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
241      "dependencies": ["id-of-dependency-1"],
242      "priority": 1
243    }}
244  ]
245}}
246```
247
248Decompose the task now:"#,
249            task = task,
250            strategy_instruction = strategy_instruction,
251            max_subtasks = self.config.max_subagents,
252        )
253    }
254
255    /// Parse the decomposition response
256    fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
257        // Try to extract JSON from the response
258        let json_str = if let Some(start) = response.find("```json") {
259            let start = start + 7;
260            if let Some(end) = response[start..].find("```") {
261                &response[start..start + end]
262            } else {
263                response
264            }
265        } else if let Some(start) = response.find('{') {
266            if let Some(end) = response.rfind('}') {
267                &response[start..=end]
268            } else {
269                response
270            }
271        } else {
272            response
273        };
274
275        #[derive(Deserialize)]
276        struct DecompositionResponse {
277            subtasks: Vec<SubTaskDef>,
278        }
279
280        #[derive(Deserialize)]
281        struct SubTaskDef {
282            name: String,
283            instruction: String,
284            specialty: Option<String>,
285            #[serde(default)]
286            dependencies: Vec<String>,
287            #[serde(default)]
288            priority: i32,
289        }
290
291        let parsed: DecompositionResponse = serde_json::from_str(json_str.trim())
292            .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;
293
294        // Create SubTask objects with proper IDs
295        let mut subtasks = Vec::new();
296        let mut name_to_id: HashMap<String, String> = HashMap::new();
297
298        // First pass: create subtasks and map names to IDs
299        for def in &parsed.subtasks {
300            let subtask = SubTask::new(&def.name, &def.instruction).with_priority(def.priority);
301
302            let subtask = if let Some(ref specialty) = def.specialty {
303                subtask.with_specialty(specialty)
304            } else {
305                subtask
306            };
307
308            name_to_id.insert(def.name.clone(), subtask.id.clone());
309            subtasks.push((subtask, def.dependencies.clone()));
310        }
311
312        // Second pass: resolve dependencies
313        let result: Vec<SubTask> = subtasks
314            .into_iter()
315            .map(|(mut subtask, deps)| {
316                let resolved_deps: Vec<String> = deps
317                    .iter()
318                    .filter_map(|dep| name_to_id.get(dep).cloned())
319                    .collect();
320                subtask.dependencies = resolved_deps;
321                subtask
322            })
323            .collect();
324
325        Ok(result)
326    }
327
328    /// Assign stages to subtasks based on dependencies
329    fn assign_stages(&mut self) {
330        let mut changed = true;
331
332        while changed {
333            changed = false;
334
335            // First collect all updates needed
336            let updates: Vec<(String, usize)> = self
337                .subtasks
338                .iter()
339                .filter_map(|(id, subtask)| {
340                    if subtask.dependencies.is_empty() {
341                        if subtask.stage != 0 {
342                            Some((id.clone(), 0))
343                        } else {
344                            None
345                        }
346                    } else {
347                        let max_dep_stage = subtask
348                            .dependencies
349                            .iter()
350                            .filter_map(|dep_id| self.subtasks.get(dep_id))
351                            .map(|dep| dep.stage)
352                            .max()
353                            .unwrap_or(0);
354
355                        let new_stage = max_dep_stage + 1;
356                        if subtask.stage != new_stage {
357                            Some((id.clone(), new_stage))
358                        } else {
359                            None
360                        }
361                    }
362                })
363                .collect();
364
365            // Then apply updates
366            for (id, new_stage) in updates {
367                if let Some(subtask) = self.subtasks.get_mut(&id) {
368                    subtask.stage = new_stage;
369                    changed = true;
370                }
371            }
372        }
373    }
374
375    /// Get maximum stage number
376    fn max_stage(&self) -> usize {
377        self.subtasks.values().map(|s| s.stage).max().unwrap_or(0)
378    }
379
380    /// Get subtasks ready to execute (dependencies satisfied)
381    pub fn ready_subtasks(&self) -> Vec<&SubTask> {
382        self.subtasks
383            .values()
384            .filter(|s| s.status == SubTaskStatus::Pending && s.can_run(&self.completed))
385            .collect()
386    }
387
388    /// Get subtasks for a specific stage
389    pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
390        self.subtasks
391            .values()
392            .filter(|s| s.stage == stage)
393            .collect()
394    }
395
396    /// Create a sub-agent for a subtask
397    pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
398        let specialty = subtask
399            .specialty
400            .clone()
401            .unwrap_or_else(|| "General".to_string());
402        let name = format!("{} Agent", specialty);
403
404        let subagent = SubAgent::new(name, specialty, &subtask.id, &self.model, &self.provider);
405
406        self.subagents.insert(subagent.id.clone(), subagent.clone());
407        self.stats.subagents_spawned += 1;
408
409        subagent
410    }
411
412    /// Mark a subtask as completed
413    pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
414        if let Some(subtask) = self.subtasks.get_mut(subtask_id) {
415            subtask.complete(result.success);
416
417            if result.success {
418                self.completed.push(subtask_id.to_string());
419                self.stats.subagents_completed += 1;
420            } else {
421                self.stats.subagents_failed += 1;
422            }
423
424            self.stats.total_tool_calls += result.tool_calls;
425        }
426    }
427
428    /// Get all subtasks
429    pub fn all_subtasks(&self) -> Vec<&SubTask> {
430        self.subtasks.values().collect()
431    }
432
433    /// Get statistics
434    pub fn stats(&self) -> &SwarmStats {
435        &self.stats
436    }
437
438    /// Get mutable statistics
439    pub fn stats_mut(&mut self) -> &mut SwarmStats {
440        &mut self.stats
441    }
442
443    /// Check if all subtasks are complete
444    pub fn is_complete(&self) -> bool {
445        self.subtasks.values().all(|s| {
446            matches!(
447                s.status,
448                SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled
449            )
450        })
451    }
452
453    /// Get the provider registry
454    pub fn providers(&self) -> &ProviderRegistry {
455        &self.providers
456    }
457
458    /// Get current model
459    pub fn model(&self) -> &str {
460        &self.model
461    }
462
463    /// Get current provider
464    pub fn provider(&self) -> &str {
465        &self.provider
466    }
467}
468
/// Message from sub-agent to orchestrator.
///
/// Derives `Serialize`/`Deserialize`, so variant and field names are part of
/// the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubAgentMessage {
    /// Progress update for a subtask
    Progress {
        subagent_id: String,
        subtask_id: String,
        // NOTE(review): presumably the number of steps taken so far — confirm with the sender
        steps: usize,
        status: String,
    },

    /// A tool call was made, together with whether it succeeded
    ToolCall {
        subagent_id: String,
        tool_name: String,
        success: bool,
    },

    /// Subtask completed; carries the subtask's result
    Completed {
        subagent_id: String,
        result: SubTaskResult,
    },

    /// Request for a resource, identified by type and ID
    ResourceRequest {
        subagent_id: String,
        resource_type: String,
        resource_id: String,
    },
}
500
/// Message from orchestrator to sub-agent.
///
/// Derives `Serialize`/`Deserialize`, so variant and field names are part of
/// the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrchestratorMessage {
    /// Start execution of the given subtask (boxed)
    Start { subtask: Box<SubTask> },

    /// Provide a resource, identified by type and ID, with its content
    Resource {
        resource_type: String,
        resource_id: String,
        content: String,
    },

    /// Terminate execution, with a reason
    Terminate { reason: String },

    /// Context update (from completed dependency)
    ContextUpdate {
        dependency_id: String,
        result: String,
    },
}