1use super::{
7 DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus, SwarmConfig, SwarmStats,
8};
9use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
10use anyhow::Result;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14pub struct Orchestrator {
16 config: SwarmConfig,
18
19 providers: ProviderRegistry,
21
22 subtasks: HashMap<String, SubTask>,
24
25 subagents: HashMap<String, SubAgent>,
27
28 completed: Vec<String>,
30
31 model: String,
33
34 provider: String,
36
37 stats: SwarmStats,
39}
40
41impl Orchestrator {
42 pub async fn new(config: SwarmConfig) -> Result<Self> {
44 use crate::provider::parse_model_string;
45
46 let providers = ProviderRegistry::from_vault().await?;
47 let provider_list = providers.list();
48
49 if provider_list.is_empty() {
50 anyhow::bail!("No providers available for orchestration");
51 }
52
53 let model_str = config
55 .model
56 .clone()
57 .or_else(|| std::env::var("CODETETHER_DEFAULT_MODEL").ok());
58
59 let (provider, model) = if let Some(ref model_str) = model_str {
60 let (prov, mod_id) = parse_model_string(model_str);
61 let provider = prov
62 .filter(|p| provider_list.contains(p))
63 .unwrap_or(provider_list[0])
64 .to_string();
65 let model = mod_id.to_string();
66 (provider, model)
67 } else {
68 let provider = if provider_list.contains(&"novita") {
70 "novita".to_string()
71 } else if provider_list.contains(&"moonshotai") {
72 "moonshotai".to_string()
73 } else {
74 provider_list[0].to_string()
75 };
76 let model = Self::default_model_for_provider(&provider);
77 (provider, model)
78 };
79
80 tracing::info!("Orchestrator using model {} via {}", model, provider);
81
82 Ok(Self {
83 config,
84 providers,
85 subtasks: HashMap::new(),
86 subagents: HashMap::new(),
87 completed: Vec::new(),
88 model,
89 provider,
90 stats: SwarmStats::default(),
91 })
92 }
93
94 fn default_model_for_provider(provider: &str) -> String {
96 match provider {
97 "moonshotai" => "kimi-k2.5".to_string(),
98 "anthropic" => "claude-sonnet-4-20250514".to_string(),
99 "openai" => "gpt-4o".to_string(),
100 "google" => "gemini-2.5-pro".to_string(),
101 "openrouter" => "z-ai/glm-4.7".to_string(),
102 "novita" => "qwen/qwen3-coder-next".to_string(),
103 "github-copilot" | "github-copilot-enterprise" => "gpt-5-mini".to_string(),
104 _ => "kimi-k2.5".to_string(),
105 }
106 }
107
108 pub async fn decompose(
110 &mut self,
111 task: &str,
112 strategy: DecompositionStrategy,
113 ) -> Result<Vec<SubTask>> {
114 if strategy == DecompositionStrategy::None {
115 let subtask = SubTask::new("Main Task", task);
117 self.subtasks.insert(subtask.id.clone(), subtask.clone());
118 return Ok(vec![subtask]);
119 }
120
121 let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
123
124 let provider = self
125 .providers
126 .get(&self.provider)
127 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;
128
129 let temperature = if self.model.starts_with("kimi-k2") {
130 1.0
131 } else {
132 0.7
133 };
134
135 let request = CompletionRequest {
136 messages: vec![Message {
137 role: Role::User,
138 content: vec![ContentPart::Text {
139 text: decomposition_prompt,
140 }],
141 }],
142 tools: Vec::new(),
143 model: self.model.clone(),
144 temperature: Some(temperature),
145 top_p: None,
146 max_tokens: Some(8192),
147 stop: Vec::new(),
148 };
149
150 let response = provider.complete(request).await?;
151
152 let text = response
154 .message
155 .content
156 .iter()
157 .filter_map(|p| match p {
158 ContentPart::Text { text } => Some(text.clone()),
159 _ => None,
160 })
161 .collect::<Vec<_>>()
162 .join("\n");
163
164 tracing::debug!("Decomposition response: {}", text);
165
166 if text.trim().is_empty() {
167 tracing::warn!("Empty decomposition response, falling back to single task");
169 let subtask = SubTask::new("Main Task", task);
170 self.subtasks.insert(subtask.id.clone(), subtask.clone());
171 return Ok(vec![subtask]);
172 }
173
174 let subtasks = self.parse_decomposition(&text)?;
175
176 for subtask in &subtasks {
178 self.subtasks.insert(subtask.id.clone(), subtask.clone());
179 }
180
181 self.assign_stages();
183
184 tracing::info!(
185 "Decomposed task into {} subtasks across {} stages",
186 subtasks.len(),
187 self.max_stage() + 1
188 );
189
190 Ok(subtasks)
191 }
192
193 fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
195 let strategy_instruction = match strategy {
196 DecompositionStrategy::Automatic => {
197 "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
198 }
199 DecompositionStrategy::ByDomain => {
200 "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
201 }
202 DecompositionStrategy::ByData => {
203 "Decompose the task by data partition (e.g., different files, sections, or datasets)."
204 }
205 DecompositionStrategy::ByStage => {
206 "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
207 }
208 DecompositionStrategy::None => unreachable!(),
209 };
210
211 format!(
212 r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
213
214TASK: {task}
215
216STRATEGY: {strategy_instruction}
217
218CONSTRAINTS:
219- Maximum {max_subtasks} subtasks
220- Each subtask should be independently executable
221- Identify dependencies between subtasks (which must complete before others can start)
222- Assign a specialty/role to each subtask
223
224OUTPUT FORMAT (JSON):
225```json
226{{
227 "subtasks": [
228 {{
229 "name": "Subtask Name",
230 "instruction": "Detailed instruction for this subtask",
231 "specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
232 "dependencies": ["id-of-dependency-1"],
233 "priority": 1
234 }}
235 ]
236}}
237```
238
239Decompose the task now:"#,
240 task = task,
241 strategy_instruction = strategy_instruction,
242 max_subtasks = self.config.max_subagents,
243 )
244 }
245
246 fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
248 let json_str = if let Some(start) = response.find("```json") {
250 let start = start + 7;
251 if let Some(end) = response[start..].find("```") {
252 &response[start..start + end]
253 } else {
254 response
255 }
256 } else if let Some(start) = response.find('{') {
257 if let Some(end) = response.rfind('}') {
258 &response[start..=end]
259 } else {
260 response
261 }
262 } else {
263 response
264 };
265
266 #[derive(Deserialize)]
267 struct DecompositionResponse {
268 subtasks: Vec<SubTaskDef>,
269 }
270
271 #[derive(Deserialize)]
272 struct SubTaskDef {
273 name: String,
274 instruction: String,
275 specialty: Option<String>,
276 #[serde(default)]
277 dependencies: Vec<String>,
278 #[serde(default)]
279 priority: i32,
280 }
281
282 let parsed: DecompositionResponse = serde_json::from_str(json_str.trim())
283 .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;
284
285 let mut subtasks = Vec::new();
287 let mut name_to_id: HashMap<String, String> = HashMap::new();
288
289 for def in &parsed.subtasks {
291 let subtask = SubTask::new(&def.name, &def.instruction).with_priority(def.priority);
292
293 let subtask = if let Some(ref specialty) = def.specialty {
294 subtask.with_specialty(specialty)
295 } else {
296 subtask
297 };
298
299 name_to_id.insert(def.name.clone(), subtask.id.clone());
300 subtasks.push((subtask, def.dependencies.clone()));
301 }
302
303 let result: Vec<SubTask> = subtasks
305 .into_iter()
306 .map(|(mut subtask, deps)| {
307 let resolved_deps: Vec<String> = deps
308 .iter()
309 .filter_map(|dep| name_to_id.get(dep).cloned())
310 .collect();
311 subtask.dependencies = resolved_deps;
312 subtask
313 })
314 .collect();
315
316 Ok(result)
317 }
318
319 fn assign_stages(&mut self) {
321 let mut changed = true;
322
323 while changed {
324 changed = false;
325
326 let updates: Vec<(String, usize)> = self
328 .subtasks
329 .iter()
330 .filter_map(|(id, subtask)| {
331 if subtask.dependencies.is_empty() {
332 if subtask.stage != 0 {
333 Some((id.clone(), 0))
334 } else {
335 None
336 }
337 } else {
338 let max_dep_stage = subtask
339 .dependencies
340 .iter()
341 .filter_map(|dep_id| self.subtasks.get(dep_id))
342 .map(|dep| dep.stage)
343 .max()
344 .unwrap_or(0);
345
346 let new_stage = max_dep_stage + 1;
347 if subtask.stage != new_stage {
348 Some((id.clone(), new_stage))
349 } else {
350 None
351 }
352 }
353 })
354 .collect();
355
356 for (id, new_stage) in updates {
358 if let Some(subtask) = self.subtasks.get_mut(&id) {
359 subtask.stage = new_stage;
360 changed = true;
361 }
362 }
363 }
364 }
365
366 fn max_stage(&self) -> usize {
368 self.subtasks.values().map(|s| s.stage).max().unwrap_or(0)
369 }
370
371 pub fn ready_subtasks(&self) -> Vec<&SubTask> {
373 self.subtasks
374 .values()
375 .filter(|s| s.status == SubTaskStatus::Pending && s.can_run(&self.completed))
376 .collect()
377 }
378
379 pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
381 self.subtasks
382 .values()
383 .filter(|s| s.stage == stage)
384 .collect()
385 }
386
387 pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
389 let specialty = subtask
390 .specialty
391 .clone()
392 .unwrap_or_else(|| "General".to_string());
393 let name = format!("{} Agent", specialty);
394
395 let subagent = SubAgent::new(name, specialty, &subtask.id, &self.model, &self.provider);
396
397 self.subagents.insert(subagent.id.clone(), subagent.clone());
398 self.stats.subagents_spawned += 1;
399
400 subagent
401 }
402
403 pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
405 if let Some(subtask) = self.subtasks.get_mut(subtask_id) {
406 subtask.complete(result.success);
407
408 if result.success {
409 self.completed.push(subtask_id.to_string());
410 self.stats.subagents_completed += 1;
411 } else {
412 self.stats.subagents_failed += 1;
413 }
414
415 self.stats.total_tool_calls += result.tool_calls;
416 }
417 }
418
419 pub fn all_subtasks(&self) -> Vec<&SubTask> {
421 self.subtasks.values().collect()
422 }
423
424 pub fn stats(&self) -> &SwarmStats {
426 &self.stats
427 }
428
429 pub fn stats_mut(&mut self) -> &mut SwarmStats {
431 &mut self.stats
432 }
433
434 pub fn is_complete(&self) -> bool {
436 self.subtasks.values().all(|s| {
437 matches!(
438 s.status,
439 SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled
440 )
441 })
442 }
443
444 pub fn providers(&self) -> &ProviderRegistry {
446 &self.providers
447 }
448
449 pub fn model(&self) -> &str {
451 &self.model
452 }
453
454 pub fn provider(&self) -> &str {
456 &self.provider
457 }
458}
459
460#[derive(Debug, Clone, Serialize, Deserialize)]
462pub enum SubAgentMessage {
463 Progress {
465 subagent_id: String,
466 subtask_id: String,
467 steps: usize,
468 status: String,
469 },
470
471 ToolCall {
473 subagent_id: String,
474 tool_name: String,
475 success: bool,
476 },
477
478 Completed {
480 subagent_id: String,
481 result: SubTaskResult,
482 },
483
484 ResourceRequest {
486 subagent_id: String,
487 resource_type: String,
488 resource_id: String,
489 },
490}
491
492#[derive(Debug, Clone, Serialize, Deserialize)]
494pub enum OrchestratorMessage {
495 Start { subtask: Box<SubTask> },
497
498 Resource {
500 resource_type: String,
501 resource_id: String,
502 content: String,
503 },
504
505 Terminate { reason: String },
507
508 ContextUpdate {
510 dependency_id: String,
511 result: String,
512 },
513}