1use super::{
7 DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus, SwarmConfig, SwarmStats,
8};
9use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
10use anyhow::Result;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14pub struct Orchestrator {
16 config: SwarmConfig,
18
19 providers: ProviderRegistry,
21
22 subtasks: HashMap<String, SubTask>,
24
25 subagents: HashMap<String, SubAgent>,
27
28 completed: Vec<String>,
30
31 model: String,
33
34 provider: String,
36
37 stats: SwarmStats,
39}
40
41impl Orchestrator {
42 pub async fn new(config: SwarmConfig) -> Result<Self> {
44 use crate::provider::parse_model_string;
45
46 let providers = ProviderRegistry::from_vault().await?;
47 let provider_list = providers.list();
48
49 if provider_list.is_empty() {
50 anyhow::bail!("No providers available for orchestration");
51 }
52
53 let model_str = config
55 .model
56 .clone()
57 .or_else(|| std::env::var("CODETETHER_DEFAULT_MODEL").ok());
58
59 let (provider, model) = if let Some(ref model_str) = model_str {
60 let (prov, mod_id) = parse_model_string(model_str);
61 let provider = prov
62 .filter(|p| provider_list.contains(p))
63 .unwrap_or(provider_list[0])
64 .to_string();
65 let model = mod_id.to_string();
66 (provider, model)
67 } else {
68 let provider = if provider_list.contains(&"novita") {
70 "novita".to_string()
71 } else if provider_list.contains(&"moonshotai") {
72 "moonshotai".to_string()
73 } else {
74 provider_list[0].to_string()
75 };
76 let model = Self::default_model_for_provider(&provider);
77 (provider, model)
78 };
79
80 tracing::info!("Orchestrator using model {} via {}", model, provider);
81
82 Ok(Self {
83 config,
84 providers,
85 subtasks: HashMap::new(),
86 subagents: HashMap::new(),
87 completed: Vec::new(),
88 model,
89 provider,
90 stats: SwarmStats::default(),
91 })
92 }
93
94 fn default_model_for_provider(provider: &str) -> String {
96 match provider {
97 "moonshotai" => "kimi-k2.5".to_string(),
98 "anthropic" => "claude-sonnet-4-20250514".to_string(),
99 "openai" => "gpt-4o".to_string(),
100 "google" => "gemini-2.5-pro".to_string(),
101 "openrouter" => "stepfun/step-3.5-flash:free".to_string(),
102 "novita" => "qwen/qwen3-coder-next".to_string(),
103 _ => "kimi-k2.5".to_string(),
104 }
105 }
106
107 pub async fn decompose(
109 &mut self,
110 task: &str,
111 strategy: DecompositionStrategy,
112 ) -> Result<Vec<SubTask>> {
113 if strategy == DecompositionStrategy::None {
114 let subtask = SubTask::new("Main Task", task);
116 self.subtasks.insert(subtask.id.clone(), subtask.clone());
117 return Ok(vec![subtask]);
118 }
119
120 let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
122
123 let provider = self
124 .providers
125 .get(&self.provider)
126 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;
127
128 let temperature = if self.model.starts_with("kimi-k2") {
129 1.0
130 } else {
131 0.7
132 };
133
134 let request = CompletionRequest {
135 messages: vec![Message {
136 role: Role::User,
137 content: vec![ContentPart::Text {
138 text: decomposition_prompt,
139 }],
140 }],
141 tools: Vec::new(),
142 model: self.model.clone(),
143 temperature: Some(temperature),
144 top_p: None,
145 max_tokens: Some(8192),
146 stop: Vec::new(),
147 };
148
149 let response = provider.complete(request).await?;
150
151 let text = response
153 .message
154 .content
155 .iter()
156 .filter_map(|p| match p {
157 ContentPart::Text { text } => Some(text.clone()),
158 _ => None,
159 })
160 .collect::<Vec<_>>()
161 .join("\n");
162
163 tracing::debug!("Decomposition response: {}", text);
164
165 if text.trim().is_empty() {
166 tracing::warn!("Empty decomposition response, falling back to single task");
168 let subtask = SubTask::new("Main Task", task);
169 self.subtasks.insert(subtask.id.clone(), subtask.clone());
170 return Ok(vec![subtask]);
171 }
172
173 let subtasks = self.parse_decomposition(&text)?;
174
175 for subtask in &subtasks {
177 self.subtasks.insert(subtask.id.clone(), subtask.clone());
178 }
179
180 self.assign_stages();
182
183 tracing::info!(
184 "Decomposed task into {} subtasks across {} stages",
185 subtasks.len(),
186 self.max_stage() + 1
187 );
188
189 Ok(subtasks)
190 }
191
192 fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
194 let strategy_instruction = match strategy {
195 DecompositionStrategy::Automatic => {
196 "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
197 }
198 DecompositionStrategy::ByDomain => {
199 "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
200 }
201 DecompositionStrategy::ByData => {
202 "Decompose the task by data partition (e.g., different files, sections, or datasets)."
203 }
204 DecompositionStrategy::ByStage => {
205 "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
206 }
207 DecompositionStrategy::None => unreachable!(),
208 };
209
210 format!(
211 r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
212
213TASK: {task}
214
215STRATEGY: {strategy_instruction}
216
217CONSTRAINTS:
218- Maximum {max_subtasks} subtasks
219- Each subtask should be independently executable
220- Identify dependencies between subtasks (which must complete before others can start)
221- Assign a specialty/role to each subtask
222
223OUTPUT FORMAT (JSON):
224```json
225{{
226 "subtasks": [
227 {{
228 "name": "Subtask Name",
229 "instruction": "Detailed instruction for this subtask",
230 "specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
231 "dependencies": ["id-of-dependency-1"],
232 "priority": 1
233 }}
234 ]
235}}
236```
237
238Decompose the task now:"#,
239 task = task,
240 strategy_instruction = strategy_instruction,
241 max_subtasks = self.config.max_subagents,
242 )
243 }
244
245 fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
247 let json_str = if let Some(start) = response.find("```json") {
249 let start = start + 7;
250 if let Some(end) = response[start..].find("```") {
251 &response[start..start + end]
252 } else {
253 response
254 }
255 } else if let Some(start) = response.find('{') {
256 if let Some(end) = response.rfind('}') {
257 &response[start..=end]
258 } else {
259 response
260 }
261 } else {
262 response
263 };
264
265 #[derive(Deserialize)]
266 struct DecompositionResponse {
267 subtasks: Vec<SubTaskDef>,
268 }
269
270 #[derive(Deserialize)]
271 struct SubTaskDef {
272 name: String,
273 instruction: String,
274 specialty: Option<String>,
275 #[serde(default)]
276 dependencies: Vec<String>,
277 #[serde(default)]
278 priority: i32,
279 }
280
281 let parsed: DecompositionResponse = serde_json::from_str(json_str.trim())
282 .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;
283
284 let mut subtasks = Vec::new();
286 let mut name_to_id: HashMap<String, String> = HashMap::new();
287
288 for def in &parsed.subtasks {
290 let subtask = SubTask::new(&def.name, &def.instruction).with_priority(def.priority);
291
292 let subtask = if let Some(ref specialty) = def.specialty {
293 subtask.with_specialty(specialty)
294 } else {
295 subtask
296 };
297
298 name_to_id.insert(def.name.clone(), subtask.id.clone());
299 subtasks.push((subtask, def.dependencies.clone()));
300 }
301
302 let result: Vec<SubTask> = subtasks
304 .into_iter()
305 .map(|(mut subtask, deps)| {
306 let resolved_deps: Vec<String> = deps
307 .iter()
308 .filter_map(|dep| name_to_id.get(dep).cloned())
309 .collect();
310 subtask.dependencies = resolved_deps;
311 subtask
312 })
313 .collect();
314
315 Ok(result)
316 }
317
318 fn assign_stages(&mut self) {
320 let mut changed = true;
321
322 while changed {
323 changed = false;
324
325 let updates: Vec<(String, usize)> = self
327 .subtasks
328 .iter()
329 .filter_map(|(id, subtask)| {
330 if subtask.dependencies.is_empty() {
331 if subtask.stage != 0 {
332 Some((id.clone(), 0))
333 } else {
334 None
335 }
336 } else {
337 let max_dep_stage = subtask
338 .dependencies
339 .iter()
340 .filter_map(|dep_id| self.subtasks.get(dep_id))
341 .map(|dep| dep.stage)
342 .max()
343 .unwrap_or(0);
344
345 let new_stage = max_dep_stage + 1;
346 if subtask.stage != new_stage {
347 Some((id.clone(), new_stage))
348 } else {
349 None
350 }
351 }
352 })
353 .collect();
354
355 for (id, new_stage) in updates {
357 if let Some(subtask) = self.subtasks.get_mut(&id) {
358 subtask.stage = new_stage;
359 changed = true;
360 }
361 }
362 }
363 }
364
365 fn max_stage(&self) -> usize {
367 self.subtasks.values().map(|s| s.stage).max().unwrap_or(0)
368 }
369
370 pub fn ready_subtasks(&self) -> Vec<&SubTask> {
372 self.subtasks
373 .values()
374 .filter(|s| s.status == SubTaskStatus::Pending && s.can_run(&self.completed))
375 .collect()
376 }
377
378 pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
380 self.subtasks
381 .values()
382 .filter(|s| s.stage == stage)
383 .collect()
384 }
385
386 pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
388 let specialty = subtask
389 .specialty
390 .clone()
391 .unwrap_or_else(|| "General".to_string());
392 let name = format!("{} Agent", specialty);
393
394 let subagent = SubAgent::new(name, specialty, &subtask.id, &self.model, &self.provider);
395
396 self.subagents.insert(subagent.id.clone(), subagent.clone());
397 self.stats.subagents_spawned += 1;
398
399 subagent
400 }
401
402 pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
404 if let Some(subtask) = self.subtasks.get_mut(subtask_id) {
405 subtask.complete(result.success);
406
407 if result.success {
408 self.completed.push(subtask_id.to_string());
409 self.stats.subagents_completed += 1;
410 } else {
411 self.stats.subagents_failed += 1;
412 }
413
414 self.stats.total_tool_calls += result.tool_calls;
415 }
416 }
417
418 pub fn all_subtasks(&self) -> Vec<&SubTask> {
420 self.subtasks.values().collect()
421 }
422
423 pub fn stats(&self) -> &SwarmStats {
425 &self.stats
426 }
427
428 pub fn stats_mut(&mut self) -> &mut SwarmStats {
430 &mut self.stats
431 }
432
433 pub fn is_complete(&self) -> bool {
435 self.subtasks.values().all(|s| {
436 matches!(
437 s.status,
438 SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled
439 )
440 })
441 }
442
443 pub fn providers(&self) -> &ProviderRegistry {
445 &self.providers
446 }
447
448 pub fn model(&self) -> &str {
450 &self.model
451 }
452
453 pub fn provider(&self) -> &str {
455 &self.provider
456 }
457}
458
459#[derive(Debug, Clone, Serialize, Deserialize)]
461pub enum SubAgentMessage {
462 Progress {
464 subagent_id: String,
465 subtask_id: String,
466 steps: usize,
467 status: String,
468 },
469
470 ToolCall {
472 subagent_id: String,
473 tool_name: String,
474 success: bool,
475 },
476
477 Completed {
479 subagent_id: String,
480 result: SubTaskResult,
481 },
482
483 ResourceRequest {
485 subagent_id: String,
486 resource_type: String,
487 resource_id: String,
488 },
489}
490
491#[derive(Debug, Clone, Serialize, Deserialize)]
493pub enum OrchestratorMessage {
494 Start { subtask: Box<SubTask> },
496
497 Resource {
499 resource_type: String,
500 resource_id: String,
501 content: String,
502 },
503
504 Terminate { reason: String },
506
507 ContextUpdate {
509 dependency_id: String,
510 result: String,
511 },
512}