1use super::{
7 DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus, SwarmConfig, SwarmStats,
8};
9use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
10use anyhow::Result;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14pub struct Orchestrator {
16 config: SwarmConfig,
18
19 providers: ProviderRegistry,
21
22 subtasks: HashMap<String, SubTask>,
24
25 subagents: HashMap<String, SubAgent>,
27
28 completed: Vec<String>,
30
31 model: String,
33
34 provider: String,
36
37 stats: SwarmStats,
39}
40
41impl Orchestrator {
42 pub async fn new(config: SwarmConfig) -> Result<Self> {
44 use crate::provider::parse_model_string;
45
46 let providers = ProviderRegistry::from_vault().await?;
47 let provider_list = providers.list();
48
49 if provider_list.is_empty() {
50 anyhow::bail!("No providers available for orchestration");
51 }
52
53 let model_str = config
55 .model
56 .clone()
57 .or_else(|| std::env::var("CODETETHER_DEFAULT_MODEL").ok());
58
59 let (provider, model) = if let Some(ref model_str) = model_str {
60 let (prov, mod_id) = parse_model_string(model_str);
61 let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
62 let provider = prov
63 .filter(|p| provider_list.contains(p))
64 .unwrap_or(provider_list[0])
65 .to_string();
66 let model = mod_id.to_string();
67 (provider, model)
68 } else {
69 let provider = if provider_list.contains(&"zai") {
71 "zai".to_string()
72 } else if provider_list.contains(&"openrouter") {
73 "openrouter".to_string()
74 } else {
75 provider_list[0].to_string()
76 };
77 let model = Self::default_model_for_provider(&provider);
78 (provider, model)
79 };
80
81 tracing::info!("Orchestrator using model {} via {}", model, provider);
82
83 Ok(Self {
84 config,
85 providers,
86 subtasks: HashMap::new(),
87 subagents: HashMap::new(),
88 completed: Vec::new(),
89 model,
90 provider,
91 stats: SwarmStats::default(),
92 })
93 }
94
95 fn default_model_for_provider(provider: &str) -> String {
97 match provider {
98 "moonshotai" => "kimi-k2.5".to_string(),
99 "anthropic" => "claude-sonnet-4-20250514".to_string(),
100 "openai" => "gpt-4o".to_string(),
101 "google" => "gemini-2.5-pro".to_string(),
102 "zhipuai" | "zai" => "glm-5".to_string(),
103 "openrouter" => "z-ai/glm-5".to_string(),
104 "novita" => "qwen/qwen3-coder-next".to_string(),
105 "github-copilot" | "github-copilot-enterprise" => "gpt-5-mini".to_string(),
106 _ => "glm-5".to_string(),
107 }
108 }
109
110 pub async fn decompose(
112 &mut self,
113 task: &str,
114 strategy: DecompositionStrategy,
115 ) -> Result<Vec<SubTask>> {
116 if strategy == DecompositionStrategy::None {
117 let subtask = SubTask::new("Main Task", task);
119 self.subtasks.insert(subtask.id.clone(), subtask.clone());
120 return Ok(vec![subtask]);
121 }
122
123 let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
125
126 let provider = self
127 .providers
128 .get(&self.provider)
129 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;
130
131 let temperature = if self.model.contains("kimi-k2") || self.model.contains("glm-") {
132 1.0
133 } else {
134 0.7
135 };
136
137 let request = CompletionRequest {
138 messages: vec![Message {
139 role: Role::User,
140 content: vec![ContentPart::Text {
141 text: decomposition_prompt,
142 }],
143 }],
144 tools: Vec::new(),
145 model: self.model.clone(),
146 temperature: Some(temperature),
147 top_p: None,
148 max_tokens: Some(8192),
149 stop: Vec::new(),
150 };
151
152 let response = provider.complete(request).await?;
153
154 let text = response
156 .message
157 .content
158 .iter()
159 .filter_map(|p| match p {
160 ContentPart::Text { text } => Some(text.clone()),
161 _ => None,
162 })
163 .collect::<Vec<_>>()
164 .join("\n");
165
166 tracing::debug!("Decomposition response: {}", text);
167
168 if text.trim().is_empty() {
169 tracing::warn!("Empty decomposition response, falling back to single task");
171 let subtask = SubTask::new("Main Task", task);
172 self.subtasks.insert(subtask.id.clone(), subtask.clone());
173 return Ok(vec![subtask]);
174 }
175
176 let subtasks = self.parse_decomposition(&text)?;
177
178 for subtask in &subtasks {
180 self.subtasks.insert(subtask.id.clone(), subtask.clone());
181 }
182
183 self.assign_stages();
185
186 tracing::info!(
187 "Decomposed task into {} subtasks across {} stages",
188 subtasks.len(),
189 self.max_stage() + 1
190 );
191
192 Ok(subtasks)
193 }
194
195 fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
197 let strategy_instruction = match strategy {
198 DecompositionStrategy::Automatic => {
199 "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
200 }
201 DecompositionStrategy::ByDomain => {
202 "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
203 }
204 DecompositionStrategy::ByData => {
205 "Decompose the task by data partition (e.g., different files, sections, or datasets)."
206 }
207 DecompositionStrategy::ByStage => {
208 "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
209 }
210 DecompositionStrategy::None => unreachable!(),
211 };
212
213 format!(
214 r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
215
216TASK: {task}
217
218STRATEGY: {strategy_instruction}
219
220CONSTRAINTS:
221- Maximum {max_subtasks} subtasks
222- Each subtask should be independently executable
223- Identify dependencies between subtasks (which must complete before others can start)
224- Assign a specialty/role to each subtask
225
226OUTPUT FORMAT (JSON):
227```json
228{{
229 "subtasks": [
230 {{
231 "name": "Subtask Name",
232 "instruction": "Detailed instruction for this subtask",
233 "specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
234 "dependencies": ["id-of-dependency-1"],
235 "priority": 1
236 }}
237 ]
238}}
239```
240
241Decompose the task now:"#,
242 task = task,
243 strategy_instruction = strategy_instruction,
244 max_subtasks = self.config.max_subagents,
245 )
246 }
247
248 fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
250 let json_str = if let Some(start) = response.find("```json") {
252 let start = start + 7;
253 if let Some(end) = response[start..].find("```") {
254 &response[start..start + end]
255 } else {
256 response
257 }
258 } else if let Some(start) = response.find('{') {
259 if let Some(end) = response.rfind('}') {
260 &response[start..=end]
261 } else {
262 response
263 }
264 } else {
265 response
266 };
267
268 #[derive(Deserialize)]
269 struct DecompositionResponse {
270 subtasks: Vec<SubTaskDef>,
271 }
272
273 #[derive(Deserialize)]
274 struct SubTaskDef {
275 name: String,
276 instruction: String,
277 specialty: Option<String>,
278 #[serde(default)]
279 dependencies: Vec<String>,
280 #[serde(default)]
281 priority: i32,
282 }
283
284 let parsed: DecompositionResponse = serde_json::from_str(json_str.trim())
285 .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;
286
287 let mut subtasks = Vec::new();
289 let mut name_to_id: HashMap<String, String> = HashMap::new();
290
291 for def in &parsed.subtasks {
293 let subtask = SubTask::new(&def.name, &def.instruction).with_priority(def.priority);
294
295 let subtask = if let Some(ref specialty) = def.specialty {
296 subtask.with_specialty(specialty)
297 } else {
298 subtask
299 };
300
301 name_to_id.insert(def.name.clone(), subtask.id.clone());
302 subtasks.push((subtask, def.dependencies.clone()));
303 }
304
305 let result: Vec<SubTask> = subtasks
307 .into_iter()
308 .map(|(mut subtask, deps)| {
309 let resolved_deps: Vec<String> = deps
310 .iter()
311 .filter_map(|dep| name_to_id.get(dep).cloned())
312 .collect();
313 subtask.dependencies = resolved_deps;
314 subtask
315 })
316 .collect();
317
318 Ok(result)
319 }
320
321 fn assign_stages(&mut self) {
323 let mut changed = true;
324
325 while changed {
326 changed = false;
327
328 let updates: Vec<(String, usize)> = self
330 .subtasks
331 .iter()
332 .filter_map(|(id, subtask)| {
333 if subtask.dependencies.is_empty() {
334 if subtask.stage != 0 {
335 Some((id.clone(), 0))
336 } else {
337 None
338 }
339 } else {
340 let max_dep_stage = subtask
341 .dependencies
342 .iter()
343 .filter_map(|dep_id| self.subtasks.get(dep_id))
344 .map(|dep| dep.stage)
345 .max()
346 .unwrap_or(0);
347
348 let new_stage = max_dep_stage + 1;
349 if subtask.stage != new_stage {
350 Some((id.clone(), new_stage))
351 } else {
352 None
353 }
354 }
355 })
356 .collect();
357
358 for (id, new_stage) in updates {
360 if let Some(subtask) = self.subtasks.get_mut(&id) {
361 subtask.stage = new_stage;
362 changed = true;
363 }
364 }
365 }
366 }
367
368 fn max_stage(&self) -> usize {
370 self.subtasks.values().map(|s| s.stage).max().unwrap_or(0)
371 }
372
373 pub fn ready_subtasks(&self) -> Vec<&SubTask> {
375 self.subtasks
376 .values()
377 .filter(|s| s.status == SubTaskStatus::Pending && s.can_run(&self.completed))
378 .collect()
379 }
380
381 pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
383 self.subtasks
384 .values()
385 .filter(|s| s.stage == stage)
386 .collect()
387 }
388
389 pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
391 let specialty = subtask
392 .specialty
393 .clone()
394 .unwrap_or_else(|| "General".to_string());
395 let name = format!("{} Agent", specialty);
396
397 let subagent = SubAgent::new(name, specialty, &subtask.id, &self.model, &self.provider);
398
399 self.subagents.insert(subagent.id.clone(), subagent.clone());
400 self.stats.subagents_spawned += 1;
401
402 subagent
403 }
404
405 pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
407 if let Some(subtask) = self.subtasks.get_mut(subtask_id) {
408 subtask.complete(result.success);
409
410 if result.success {
411 self.completed.push(subtask_id.to_string());
412 self.stats.subagents_completed += 1;
413 } else {
414 self.stats.subagents_failed += 1;
415 }
416
417 self.stats.total_tool_calls += result.tool_calls;
418 }
419 }
420
421 pub fn all_subtasks(&self) -> Vec<&SubTask> {
423 self.subtasks.values().collect()
424 }
425
426 pub fn stats(&self) -> &SwarmStats {
428 &self.stats
429 }
430
431 pub fn stats_mut(&mut self) -> &mut SwarmStats {
433 &mut self.stats
434 }
435
436 pub fn is_complete(&self) -> bool {
438 self.subtasks.values().all(|s| {
439 matches!(
440 s.status,
441 SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled
442 )
443 })
444 }
445
446 pub fn providers(&self) -> &ProviderRegistry {
448 &self.providers
449 }
450
451 pub fn model(&self) -> &str {
453 &self.model
454 }
455
456 pub fn provider(&self) -> &str {
458 &self.provider
459 }
460}
461
462#[derive(Debug, Clone, Serialize, Deserialize)]
464pub enum SubAgentMessage {
465 Progress {
467 subagent_id: String,
468 subtask_id: String,
469 steps: usize,
470 status: String,
471 },
472
473 ToolCall {
475 subagent_id: String,
476 tool_name: String,
477 success: bool,
478 },
479
480 Completed {
482 subagent_id: String,
483 result: SubTaskResult,
484 },
485
486 ResourceRequest {
488 subagent_id: String,
489 resource_type: String,
490 resource_id: String,
491 },
492}
493
494#[derive(Debug, Clone, Serialize, Deserialize)]
496pub enum OrchestratorMessage {
497 Start { subtask: Box<SubTask> },
499
500 Resource {
502 resource_type: String,
503 resource_id: String,
504 content: String,
505 },
506
507 Terminate { reason: String },
509
510 ContextUpdate {
512 dependency_id: String,
513 result: String,
514 },
515}