1use super::{
7 DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus,
8 SwarmConfig, SwarmStats,
9};
10use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
11use anyhow::Result;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15pub struct Orchestrator {
17 config: SwarmConfig,
19
20 providers: ProviderRegistry,
22
23 subtasks: HashMap<String, SubTask>,
25
26 subagents: HashMap<String, SubAgent>,
28
29 completed: Vec<String>,
31
32 model: String,
34
35 provider: String,
37
38 stats: SwarmStats,
40}
41
42impl Orchestrator {
43 pub async fn new(config: SwarmConfig) -> Result<Self> {
45 use crate::provider::parse_model_string;
46
47 let providers = ProviderRegistry::from_vault().await?;
48 let provider_list = providers.list();
49
50 if provider_list.is_empty() {
51 anyhow::bail!("No providers available for orchestration");
52 }
53
54 let (provider, model) = if let Some(ref model_str) = config.model {
56 let (prov, mod_id) = parse_model_string(model_str);
57 let provider = prov
58 .filter(|p| provider_list.contains(p))
59 .unwrap_or(provider_list[0])
60 .to_string();
61 let model = mod_id.to_string();
62 (provider, model)
63 } else {
64 let provider = provider_list[0].to_string();
65 let model = Self::default_model_for_provider(&provider);
66 (provider, model)
67 };
68
69 tracing::info!("Orchestrator using model {} via {}", model, provider);
70
71 Ok(Self {
72 config,
73 providers,
74 subtasks: HashMap::new(),
75 subagents: HashMap::new(),
76 completed: Vec::new(),
77 model,
78 provider,
79 stats: SwarmStats::default(),
80 })
81 }
82
83 fn default_model_for_provider(provider: &str) -> String {
85 match provider {
86 "moonshotai" => "kimi-k2.5".to_string(),
87 "anthropic" => "claude-sonnet-4-20250514".to_string(),
88 "openai" => "gpt-4o".to_string(),
89 "google" => "gemini-2.5-pro".to_string(),
90 "openrouter" => "stepfun/step-3.5-flash:free".to_string(),
91 _ => "kimi-k2.5".to_string(),
92 }
93 }
94
95 pub async fn decompose(
97 &mut self,
98 task: &str,
99 strategy: DecompositionStrategy,
100 ) -> Result<Vec<SubTask>> {
101 if strategy == DecompositionStrategy::None {
102 let subtask = SubTask::new("Main Task", task);
104 self.subtasks.insert(subtask.id.clone(), subtask.clone());
105 return Ok(vec![subtask]);
106 }
107
108 let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
110
111 let provider = self.providers.get(&self.provider)
112 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;
113
114 let temperature = if self.model.starts_with("kimi-k2") { 1.0 } else { 0.7 };
115
116 let request = CompletionRequest {
117 messages: vec![Message {
118 role: Role::User,
119 content: vec![ContentPart::Text { text: decomposition_prompt }],
120 }],
121 tools: Vec::new(),
122 model: self.model.clone(),
123 temperature: Some(temperature),
124 top_p: None,
125 max_tokens: Some(8192),
126 stop: Vec::new(),
127 };
128
129 let response = provider.complete(request).await?;
130
131 let text = response.message.content
133 .iter()
134 .filter_map(|p| match p {
135 ContentPart::Text { text } => Some(text.clone()),
136 _ => None,
137 })
138 .collect::<Vec<_>>()
139 .join("\n");
140
141 tracing::debug!("Decomposition response: {}", text);
142
143 if text.trim().is_empty() {
144 tracing::warn!("Empty decomposition response, falling back to single task");
146 let subtask = SubTask::new("Main Task", task);
147 self.subtasks.insert(subtask.id.clone(), subtask.clone());
148 return Ok(vec![subtask]);
149 }
150
151 let subtasks = self.parse_decomposition(&text)?;
152
153 for subtask in &subtasks {
155 self.subtasks.insert(subtask.id.clone(), subtask.clone());
156 }
157
158 self.assign_stages();
160
161 tracing::info!(
162 "Decomposed task into {} subtasks across {} stages",
163 subtasks.len(),
164 self.max_stage() + 1
165 );
166
167 Ok(subtasks)
168 }
169
170 fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
172 let strategy_instruction = match strategy {
173 DecompositionStrategy::Automatic => {
174 "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
175 }
176 DecompositionStrategy::ByDomain => {
177 "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
178 }
179 DecompositionStrategy::ByData => {
180 "Decompose the task by data partition (e.g., different files, sections, or datasets)."
181 }
182 DecompositionStrategy::ByStage => {
183 "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
184 }
185 DecompositionStrategy::None => unreachable!(),
186 };
187
188 format!(
189 r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
190
191TASK: {task}
192
193STRATEGY: {strategy_instruction}
194
195CONSTRAINTS:
196- Maximum {max_subtasks} subtasks
197- Each subtask should be independently executable
198- Identify dependencies between subtasks (which must complete before others can start)
199- Assign a specialty/role to each subtask
200
201OUTPUT FORMAT (JSON):
202```json
203{{
204 "subtasks": [
205 {{
206 "name": "Subtask Name",
207 "instruction": "Detailed instruction for this subtask",
208 "specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
209 "dependencies": ["id-of-dependency-1"],
210 "priority": 1
211 }}
212 ]
213}}
214```
215
216Decompose the task now:"#,
217 task = task,
218 strategy_instruction = strategy_instruction,
219 max_subtasks = self.config.max_subagents,
220 )
221 }
222
223 fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
225 let json_str = if let Some(start) = response.find("```json") {
227 let start = start + 7;
228 if let Some(end) = response[start..].find("```") {
229 &response[start..start + end]
230 } else {
231 response
232 }
233 } else if let Some(start) = response.find('{') {
234 if let Some(end) = response.rfind('}') {
235 &response[start..=end]
236 } else {
237 response
238 }
239 } else {
240 response
241 };
242
243 #[derive(Deserialize)]
244 struct DecompositionResponse {
245 subtasks: Vec<SubTaskDef>,
246 }
247
248 #[derive(Deserialize)]
249 struct SubTaskDef {
250 name: String,
251 instruction: String,
252 specialty: Option<String>,
253 #[serde(default)]
254 dependencies: Vec<String>,
255 #[serde(default)]
256 priority: i32,
257 }
258
259 let parsed: DecompositionResponse = serde_json::from_str(json_str.trim())
260 .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;
261
262 let mut subtasks = Vec::new();
264 let mut name_to_id: HashMap<String, String> = HashMap::new();
265
266 for def in &parsed.subtasks {
268 let subtask = SubTask::new(&def.name, &def.instruction)
269 .with_priority(def.priority);
270
271 let subtask = if let Some(ref specialty) = def.specialty {
272 subtask.with_specialty(specialty)
273 } else {
274 subtask
275 };
276
277 name_to_id.insert(def.name.clone(), subtask.id.clone());
278 subtasks.push((subtask, def.dependencies.clone()));
279 }
280
281 let result: Vec<SubTask> = subtasks
283 .into_iter()
284 .map(|(mut subtask, deps)| {
285 let resolved_deps: Vec<String> = deps
286 .iter()
287 .filter_map(|dep| name_to_id.get(dep).cloned())
288 .collect();
289 subtask.dependencies = resolved_deps;
290 subtask
291 })
292 .collect();
293
294 Ok(result)
295 }
296
297 fn assign_stages(&mut self) {
299 let mut changed = true;
300
301 while changed {
302 changed = false;
303
304 let updates: Vec<(String, usize)> = self.subtasks.iter().filter_map(|(id, subtask)| {
306 if subtask.dependencies.is_empty() {
307 if subtask.stage != 0 {
308 Some((id.clone(), 0))
309 } else {
310 None
311 }
312 } else {
313 let max_dep_stage = subtask
314 .dependencies
315 .iter()
316 .filter_map(|dep_id| self.subtasks.get(dep_id))
317 .map(|dep| dep.stage)
318 .max()
319 .unwrap_or(0);
320
321 let new_stage = max_dep_stage + 1;
322 if subtask.stage != new_stage {
323 Some((id.clone(), new_stage))
324 } else {
325 None
326 }
327 }
328 }).collect();
329
330 for (id, new_stage) in updates {
332 if let Some(subtask) = self.subtasks.get_mut(&id) {
333 subtask.stage = new_stage;
334 changed = true;
335 }
336 }
337 }
338 }
339
340 fn max_stage(&self) -> usize {
342 self.subtasks.values().map(|s| s.stage).max().unwrap_or(0)
343 }
344
345 pub fn ready_subtasks(&self) -> Vec<&SubTask> {
347 self.subtasks
348 .values()
349 .filter(|s| s.status == SubTaskStatus::Pending && s.can_run(&self.completed))
350 .collect()
351 }
352
353 pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
355 self.subtasks
356 .values()
357 .filter(|s| s.stage == stage)
358 .collect()
359 }
360
361 pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
363 let specialty = subtask.specialty.clone().unwrap_or_else(|| "General".to_string());
364 let name = format!("{} Agent", specialty);
365
366 let subagent = SubAgent::new(
367 name,
368 specialty,
369 &subtask.id,
370 &self.model,
371 &self.provider,
372 );
373
374 self.subagents.insert(subagent.id.clone(), subagent.clone());
375 self.stats.subagents_spawned += 1;
376
377 subagent
378 }
379
380 pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
382 if let Some(subtask) = self.subtasks.get_mut(subtask_id) {
383 subtask.complete(result.success);
384
385 if result.success {
386 self.completed.push(subtask_id.to_string());
387 self.stats.subagents_completed += 1;
388 } else {
389 self.stats.subagents_failed += 1;
390 }
391
392 self.stats.total_tool_calls += result.tool_calls;
393 }
394 }
395
396 pub fn all_subtasks(&self) -> Vec<&SubTask> {
398 self.subtasks.values().collect()
399 }
400
401 pub fn stats(&self) -> &SwarmStats {
403 &self.stats
404 }
405
406 pub fn stats_mut(&mut self) -> &mut SwarmStats {
408 &mut self.stats
409 }
410
411 pub fn is_complete(&self) -> bool {
413 self.subtasks.values().all(|s| {
414 matches!(s.status, SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled)
415 })
416 }
417
418 pub fn providers(&self) -> &ProviderRegistry {
420 &self.providers
421 }
422
423 pub fn model(&self) -> &str {
425 &self.model
426 }
427
428 pub fn provider(&self) -> &str {
430 &self.provider
431 }
432}
433
434#[derive(Debug, Clone, Serialize, Deserialize)]
436pub enum SubAgentMessage {
437 Progress {
439 subagent_id: String,
440 subtask_id: String,
441 steps: usize,
442 status: String,
443 },
444
445 ToolCall {
447 subagent_id: String,
448 tool_name: String,
449 success: bool,
450 },
451
452 Completed {
454 subagent_id: String,
455 result: SubTaskResult,
456 },
457
458 ResourceRequest {
460 subagent_id: String,
461 resource_type: String,
462 resource_id: String,
463 },
464}
465
466#[derive(Debug, Clone, Serialize, Deserialize)]
468pub enum OrchestratorMessage {
469 Start {
471 subtask: Box<SubTask>,
472 },
473
474 Resource {
476 resource_type: String,
477 resource_id: String,
478 content: String,
479 },
480
481 Terminate {
483 reason: String,
484 },
485
486 ContextUpdate {
488 dependency_id: String,
489 result: String,
490 },
491}