1use super::{
7 DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus, SwarmConfig, SwarmStats,
8};
9use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
10use anyhow::Result;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14pub struct Orchestrator {
16 config: SwarmConfig,
18
19 providers: ProviderRegistry,
21
22 subtasks: HashMap<String, SubTask>,
24
25 subagents: HashMap<String, SubAgent>,
27
28 completed: Vec<String>,
30
31 model: String,
33
34 provider: String,
36
37 stats: SwarmStats,
39}
40
41impl Orchestrator {
42 pub async fn new(config: SwarmConfig) -> Result<Self> {
44 use crate::provider::parse_model_string;
45
46 let providers = ProviderRegistry::from_vault().await?;
47 let provider_list = providers.list();
48
49 if provider_list.is_empty() {
50 anyhow::bail!("No providers available for orchestration");
51 }
52
53 let model_str = config
55 .model
56 .clone()
57 .or_else(|| std::env::var("CODETETHER_DEFAULT_MODEL").ok());
58
59 let (provider, model) = if let Some(ref model_str) = model_str {
60 let (prov, mod_id) = parse_model_string(model_str);
61 let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
62 let provider = prov
63 .filter(|p| provider_list.contains(p))
64 .unwrap_or(provider_list[0])
65 .to_string();
66 let model = mod_id.to_string();
67 (provider, model)
68 } else {
69 let provider = if provider_list.contains(&"zai") {
71 "zai".to_string()
72 } else if provider_list.contains(&"openrouter") {
73 "openrouter".to_string()
74 } else {
75 provider_list[0].to_string()
76 };
77 let model = Self::default_model_for_provider(&provider);
78 (provider, model)
79 };
80
81 tracing::info!("Orchestrator using model {} via {}", model, provider);
82
83 Ok(Self {
84 config,
85 providers,
86 subtasks: HashMap::new(),
87 subagents: HashMap::new(),
88 completed: Vec::new(),
89 model,
90 provider,
91 stats: SwarmStats::default(),
92 })
93 }
94
/// Maps a provider name to the model id used when no model was configured.
///
/// Unknown providers fall back to `glm-5`.
fn default_model_for_provider(provider: &str) -> String {
    // Pick the static id first, allocate once at the end.
    let model_id = match provider {
        "moonshotai" => "kimi-k2.5",
        "anthropic" => "claude-sonnet-4-20250514",
        "openai" => "gpt-4o",
        "google" => "gemini-2.5-pro",
        "zhipuai" | "zai" => "glm-5",
        "openrouter" => "z-ai/glm-5",
        "novita" => "qwen/qwen3-coder-next",
        "github-copilot" | "github-copilot-enterprise" => "gpt-5-mini",
        _ => "glm-5",
    };
    model_id.to_owned()
}
109
/// Reports whether `model` belongs to a family that is requested with
/// temperature 1.0 (`kimi-k2`, `glm-`, `minimax`); matching is ASCII
/// case-insensitive and substring-based.
fn prefers_temperature_one(model: &str) -> bool {
    const MARKERS: [&str; 3] = ["kimi-k2", "glm-", "minimax"];

    let lowered = model.to_ascii_lowercase();
    MARKERS.iter().any(|marker| lowered.contains(*marker))
}
116
117 pub async fn decompose(
119 &mut self,
120 task: &str,
121 strategy: DecompositionStrategy,
122 ) -> Result<Vec<SubTask>> {
123 if strategy == DecompositionStrategy::None {
124 let subtask = SubTask::new("Main Task", task);
126 self.subtasks.insert(subtask.id.clone(), subtask.clone());
127 return Ok(vec![subtask]);
128 }
129
130 let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
132
133 let provider = self
134 .providers
135 .get(&self.provider)
136 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;
137
138 let temperature = if Self::prefers_temperature_one(&self.model) {
139 1.0
140 } else {
141 0.7
142 };
143
144 let request = CompletionRequest {
145 messages: vec![Message {
146 role: Role::User,
147 content: vec![ContentPart::Text {
148 text: decomposition_prompt,
149 }],
150 }],
151 tools: Vec::new(),
152 model: self.model.clone(),
153 temperature: Some(temperature),
154 top_p: None,
155 max_tokens: Some(8192),
156 stop: Vec::new(),
157 };
158
159 let response = provider.complete(request).await?;
160
161 let text = response
163 .message
164 .content
165 .iter()
166 .filter_map(|p| match p {
167 ContentPart::Text { text } => Some(text.clone()),
168 _ => None,
169 })
170 .collect::<Vec<_>>()
171 .join("\n");
172
173 tracing::debug!("Decomposition response: {}", text);
174
175 if text.trim().is_empty() {
176 tracing::warn!("Empty decomposition response, falling back to single task");
178 let subtask = SubTask::new("Main Task", task);
179 self.subtasks.insert(subtask.id.clone(), subtask.clone());
180 return Ok(vec![subtask]);
181 }
182
183 let subtasks = self.parse_decomposition(&text)?;
184
185 for subtask in &subtasks {
187 self.subtasks.insert(subtask.id.clone(), subtask.clone());
188 }
189
190 self.assign_stages();
192
193 tracing::info!(
194 "Decomposed task into {} subtasks across {} stages",
195 subtasks.len(),
196 self.max_stage() + 1
197 );
198
199 Ok(subtasks)
200 }
201
202 fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
204 let strategy_instruction = match strategy {
205 DecompositionStrategy::Automatic => {
206 "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
207 }
208 DecompositionStrategy::ByDomain => {
209 "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
210 }
211 DecompositionStrategy::ByData => {
212 "Decompose the task by data partition (e.g., different files, sections, or datasets)."
213 }
214 DecompositionStrategy::ByStage => {
215 "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
216 }
217 DecompositionStrategy::None => unreachable!(),
218 };
219
220 format!(
221 r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
222
223TASK: {task}
224
225STRATEGY: {strategy_instruction}
226
227CONSTRAINTS:
228- Maximum {max_subtasks} subtasks
229- Each subtask should be independently executable
230- Identify dependencies between subtasks (which must complete before others can start)
231- Assign a specialty/role to each subtask
232
233OUTPUT FORMAT (JSON):
234```json
235{{
236 "subtasks": [
237 {{
238 "name": "Subtask Name",
239 "instruction": "Detailed instruction for this subtask",
240 "specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
241 "dependencies": ["id-of-dependency-1"],
242 "priority": 1
243 }}
244 ]
245}}
246```
247
248Decompose the task now:"#,
249 task = task,
250 strategy_instruction = strategy_instruction,
251 max_subtasks = self.config.max_subagents,
252 )
253 }
254
255 fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
257 let json_str = if let Some(start) = response.find("```json") {
259 let start = start + 7;
260 if let Some(end) = response[start..].find("```") {
261 &response[start..start + end]
262 } else {
263 response
264 }
265 } else if let Some(start) = response.find('{') {
266 if let Some(end) = response.rfind('}') {
267 &response[start..=end]
268 } else {
269 response
270 }
271 } else {
272 response
273 };
274
275 #[derive(Deserialize)]
276 struct DecompositionResponse {
277 subtasks: Vec<SubTaskDef>,
278 }
279
280 #[derive(Deserialize)]
281 struct SubTaskDef {
282 name: String,
283 instruction: String,
284 specialty: Option<String>,
285 #[serde(default)]
286 dependencies: Vec<String>,
287 #[serde(default)]
288 priority: i32,
289 }
290
291 let parsed: DecompositionResponse = serde_json::from_str(json_str.trim())
292 .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;
293
294 let mut subtasks = Vec::new();
296 let mut name_to_id: HashMap<String, String> = HashMap::new();
297
298 for def in &parsed.subtasks {
300 let subtask = SubTask::new(&def.name, &def.instruction).with_priority(def.priority);
301
302 let subtask = if let Some(ref specialty) = def.specialty {
303 subtask.with_specialty(specialty)
304 } else {
305 subtask
306 };
307
308 name_to_id.insert(def.name.clone(), subtask.id.clone());
309 subtasks.push((subtask, def.dependencies.clone()));
310 }
311
312 let result: Vec<SubTask> = subtasks
314 .into_iter()
315 .map(|(mut subtask, deps)| {
316 let resolved_deps: Vec<String> = deps
317 .iter()
318 .filter_map(|dep| name_to_id.get(dep).cloned())
319 .collect();
320 subtask.dependencies = resolved_deps;
321 subtask
322 })
323 .collect();
324
325 Ok(result)
326 }
327
328 fn assign_stages(&mut self) {
330 let mut changed = true;
331
332 while changed {
333 changed = false;
334
335 let updates: Vec<(String, usize)> = self
337 .subtasks
338 .iter()
339 .filter_map(|(id, subtask)| {
340 if subtask.dependencies.is_empty() {
341 if subtask.stage != 0 {
342 Some((id.clone(), 0))
343 } else {
344 None
345 }
346 } else {
347 let max_dep_stage = subtask
348 .dependencies
349 .iter()
350 .filter_map(|dep_id| self.subtasks.get(dep_id))
351 .map(|dep| dep.stage)
352 .max()
353 .unwrap_or(0);
354
355 let new_stage = max_dep_stage + 1;
356 if subtask.stage != new_stage {
357 Some((id.clone(), new_stage))
358 } else {
359 None
360 }
361 }
362 })
363 .collect();
364
365 for (id, new_stage) in updates {
367 if let Some(subtask) = self.subtasks.get_mut(&id) {
368 subtask.stage = new_stage;
369 changed = true;
370 }
371 }
372 }
373 }
374
375 fn max_stage(&self) -> usize {
377 self.subtasks.values().map(|s| s.stage).max().unwrap_or(0)
378 }
379
380 pub fn ready_subtasks(&self) -> Vec<&SubTask> {
382 self.subtasks
383 .values()
384 .filter(|s| s.status == SubTaskStatus::Pending && s.can_run(&self.completed))
385 .collect()
386 }
387
388 pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
390 self.subtasks
391 .values()
392 .filter(|s| s.stage == stage)
393 .collect()
394 }
395
396 pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
398 let specialty = subtask
399 .specialty
400 .clone()
401 .unwrap_or_else(|| "General".to_string());
402 let name = format!("{} Agent", specialty);
403
404 let subagent = SubAgent::new(name, specialty, &subtask.id, &self.model, &self.provider);
405
406 self.subagents.insert(subagent.id.clone(), subagent.clone());
407 self.stats.subagents_spawned += 1;
408
409 subagent
410 }
411
412 pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
414 if let Some(subtask) = self.subtasks.get_mut(subtask_id) {
415 subtask.complete(result.success);
416
417 if result.success {
418 self.completed.push(subtask_id.to_string());
419 self.stats.subagents_completed += 1;
420 } else {
421 self.stats.subagents_failed += 1;
422 }
423
424 self.stats.total_tool_calls += result.tool_calls;
425 }
426 }
427
428 pub fn all_subtasks(&self) -> Vec<&SubTask> {
430 self.subtasks.values().collect()
431 }
432
433 pub fn stats(&self) -> &SwarmStats {
435 &self.stats
436 }
437
438 pub fn stats_mut(&mut self) -> &mut SwarmStats {
440 &mut self.stats
441 }
442
443 pub fn is_complete(&self) -> bool {
445 self.subtasks.values().all(|s| {
446 matches!(
447 s.status,
448 SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled
449 )
450 })
451 }
452
453 pub fn providers(&self) -> &ProviderRegistry {
455 &self.providers
456 }
457
458 pub fn model(&self) -> &str {
460 &self.model
461 }
462
463 pub fn provider(&self) -> &str {
465 &self.provider
466 }
467}
468
469#[derive(Debug, Clone, Serialize, Deserialize)]
471pub enum SubAgentMessage {
472 Progress {
474 subagent_id: String,
475 subtask_id: String,
476 steps: usize,
477 status: String,
478 },
479
480 ToolCall {
482 subagent_id: String,
483 tool_name: String,
484 success: bool,
485 },
486
487 Completed {
489 subagent_id: String,
490 result: SubTaskResult,
491 },
492
493 ResourceRequest {
495 subagent_id: String,
496 resource_type: String,
497 resource_id: String,
498 },
499}
500
501#[derive(Debug, Clone, Serialize, Deserialize)]
503pub enum OrchestratorMessage {
504 Start { subtask: Box<SubTask> },
506
507 Resource {
509 resource_type: String,
510 resource_id: String,
511 content: String,
512 },
513
514 Terminate { reason: String },
516
517 ContextUpdate {
519 dependency_id: String,
520 result: String,
521 },
522}