use super::{
DecompositionStrategy, SubAgent, SubTask, SubTaskResult, SubTaskStatus, SwarmConfig, SwarmStats,
};
use crate::provider::{CompletionRequest, ContentPart, Message, ProviderRegistry, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Coordinates decomposition of a task into subtasks and tracks the sub-agents
/// and statistics associated with executing them.
pub struct Orchestrator {
/// Swarm configuration supplied to [`Orchestrator::new`]; its `model` and
/// `max_subagents` values are read by this type.
config: SwarmConfig,
/// Provider registry loaded from the vault in [`Orchestrator::new`].
providers: ProviderRegistry,
/// All known subtasks, keyed by subtask id.
subtasks: HashMap<String, SubTask>,
/// Sub-agents created via `create_subagent`, keyed by sub-agent id.
subagents: HashMap<String, SubAgent>,
/// Ids of subtasks that were reported as successfully completed.
completed: Vec<String>,
/// Model identifier used for decomposition requests and handed to new sub-agents.
model: String,
/// Name of the provider used to look up the completion backend in `providers`.
provider: String,
/// Running counters (spawned / completed / failed sub-agents, tool calls).
stats: SwarmStats,
}
impl Orchestrator {
/// Builds an orchestrator, resolving which provider and model it will use.
///
/// The model specification is taken from `config.model`, falling back to the
/// `CODETETHER_DEFAULT_MODEL` environment variable. When the specification names
/// a provider explicitly, that provider must be available; otherwise a default
/// provider is picked with [`choose_default_provider`]. A blank model id is
/// replaced by [`default_model_for_provider`].
///
/// # Errors
///
/// Fails when the provider registry cannot be loaded, when no providers are
/// available, or when an explicitly requested provider is not available.
pub async fn new(config: SwarmConfig) -> Result<Self> {
    use crate::provider::parse_model_string;

    let providers = ProviderRegistry::from_vault().await?;
    let provider_list = providers.list();
    if provider_list.is_empty() {
        anyhow::bail!("No providers available for orchestration");
    }

    // Shared fallback used whenever no provider was named explicitly.
    let pick_default_provider = || -> Result<String> {
        choose_default_provider(provider_list.as_slice())
            .map(|name| name.to_string())
            .ok_or_else(|| anyhow::anyhow!("No providers available for orchestration"))
    };

    // Configuration wins over the environment variable.
    let requested = match config.model.clone() {
        Some(configured) => Some(configured),
        None => std::env::var("CODETETHER_DEFAULT_MODEL").ok(),
    };

    let (provider, model) = match &requested {
        None => {
            let provider = pick_default_provider()?;
            let model = default_model_for_provider(&provider);
            (provider, model)
        }
        Some(spec) => {
            let (explicit, model_id) = parse_model_string(spec);
            // "zhipuai" is accepted as an alias for the "zai" provider.
            let explicit = match explicit {
                Some("zhipuai") => Some("zai"),
                other => other,
            };

            let provider = match explicit {
                None => pick_default_provider()?,
                Some(name) if provider_list.contains(&name) => name.to_string(),
                Some(name) => {
                    anyhow::bail!(
                        "Provider '{}' selected explicitly but is unavailable. Available providers: {}",
                        name,
                        provider_list.join(", ")
                    );
                }
            };

            let model = if model_id.trim().is_empty() {
                default_model_for_provider(&provider)
            } else {
                model_id.to_string()
            };
            (provider, model)
        }
    };

    tracing::info!("Orchestrator using model {} via {}", model, provider);

    Ok(Self {
        config,
        providers,
        subtasks: HashMap::new(),
        subagents: HashMap::new(),
        completed: Vec::new(),
        model,
        provider,
        stats: SwarmStats::default(),
    })
}
/// Reports whether `model` belongs to a family for which the orchestrator
/// requests a sampling temperature of 1.0.
///
/// Matching is a case-insensitive (ASCII) substring test.
fn prefers_temperature_one(model: &str) -> bool {
    const FAMILY_MARKERS: [&str; 3] = ["kimi-k2", "glm-", "minimax"];

    let lowered = model.to_ascii_lowercase();
    FAMILY_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
/// Decomposes `task` into subtasks according to `strategy` and registers them.
///
/// With [`DecompositionStrategy::None`] the task is wrapped in a single
/// "Main Task" subtask without contacting a provider. Otherwise the configured
/// provider is asked to produce a JSON decomposition; an empty reply falls back
/// to the same single-subtask result.
///
/// The returned subtasks are in the order the model listed them and carry the
/// `stage` values computed by the stage assignment, so callers can rely on
/// `stage` without looking the subtasks up again.
///
/// # Errors
///
/// Fails when the configured provider is missing from the registry, when the
/// completion request fails, or when the reply cannot be parsed.
pub async fn decompose(
    &mut self,
    task: &str,
    strategy: DecompositionStrategy,
) -> Result<Vec<SubTask>> {
    if strategy == DecompositionStrategy::None {
        let subtask = SubTask::new("Main Task", task);
        self.subtasks.insert(subtask.id.clone(), subtask.clone());
        return Ok(vec![subtask]);
    }

    let decomposition_prompt = self.build_decomposition_prompt(task, strategy);
    let provider = self
        .providers
        .get(&self.provider)
        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", self.provider))?;

    // Some model families are driven at temperature 1.0; everything else uses 0.7.
    let temperature = if Self::prefers_temperature_one(&self.model) {
        1.0
    } else {
        0.7
    };

    let request = CompletionRequest {
        messages: vec![Message {
            role: Role::User,
            content: vec![ContentPart::Text {
                text: decomposition_prompt,
            }],
        }],
        tools: Vec::new(),
        model: self.model.clone(),
        temperature: Some(temperature),
        top_p: None,
        max_tokens: Some(8192),
        stop: Vec::new(),
    };
    let response = provider.complete(request).await?;

    // Only the textual parts of the reply are relevant for parsing.
    let text = response
        .message
        .content
        .iter()
        .filter_map(|p| match p {
            ContentPart::Text { text } => Some(text.clone()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join("\n");
    tracing::debug!("Decomposition response: {}", text);

    if text.trim().is_empty() {
        tracing::warn!("Empty decomposition response, falling back to single task");
        let subtask = SubTask::new("Main Task", task);
        self.subtasks.insert(subtask.id.clone(), subtask.clone());
        return Ok(vec![subtask]);
    }

    let parsed = self.parse_decomposition(&text)?;
    for subtask in &parsed {
        self.subtasks.insert(subtask.id.clone(), subtask.clone());
    }
    self.assign_stages();

    // `assign_stages` updates the copies stored in `self.subtasks`; re-read them so
    // the returned subtasks do not carry stale (pre-assignment) stage values.
    let subtasks: Vec<SubTask> = parsed
        .into_iter()
        .map(|original| {
            let staged = self.subtasks.get(&original.id).cloned();
            staged.unwrap_or(original)
        })
        .collect();

    tracing::info!(
        "Decomposed task into {} subtasks across {} stages",
        subtasks.len(),
        self.max_stage() + 1
    );
    Ok(subtasks)
}
/// Renders the prompt that asks the model to decompose `task`.
///
/// The prompt embeds a strategy-specific instruction, the configured maximum
/// number of subtasks, and the JSON schema the reply is expected to follow.
/// Dependencies are requested by subtask *name*, because the schema has no id
/// field and the parser resolves dependencies by name.
///
/// # Panics
///
/// Panics if called with [`DecompositionStrategy::None`]; `decompose` handles
/// that strategy before building a prompt.
fn build_decomposition_prompt(&self, task: &str, strategy: DecompositionStrategy) -> String {
    let strategy_instruction = match strategy {
        DecompositionStrategy::Automatic => {
            "Analyze the task and determine the optimal way to decompose it into parallel subtasks."
        }
        DecompositionStrategy::ByDomain => {
            "Decompose the task by domain expertise (e.g., research, coding, analysis, verification)."
        }
        DecompositionStrategy::ByData => {
            "Decompose the task by data partition (e.g., different files, sections, or datasets)."
        }
        DecompositionStrategy::ByStage => {
            "Decompose the task by workflow stages (e.g., gather, process, synthesize)."
        }
        DecompositionStrategy::None => unreachable!(),
    };
    format!(
        r#"You are a task orchestrator. Your job is to decompose complex tasks into parallelizable subtasks.
TASK: {task}
STRATEGY: {strategy_instruction}
CONSTRAINTS:
- Maximum {max_subtasks} subtasks
- Each subtask should be independently executable
- Identify dependencies between subtasks (which must complete before others can start)
- Reference each dependency by the exact "name" of the subtask it depends on
- Assign a specialty/role to each subtask
OUTPUT FORMAT (JSON):
```json
{{
"subtasks": [
{{
"name": "Subtask Name",
"instruction": "Detailed instruction for this subtask",
"specialty": "Role/specialty (e.g., Researcher, Coder, Analyst)",
"dependencies": ["Name of a subtask that must finish first"],
"priority": 1
}}
]
}}
```
Decompose the task now:"#,
        task = task,
        strategy_instruction = strategy_instruction,
        max_subtasks = self.config.max_subagents,
    )
}
/// Parses the model's decomposition reply into subtasks.
///
/// The JSON document is located inside a ```` ```json ```` fence when present,
/// otherwise between the first `{` and the last `}` of the reply. Dependencies
/// are given by subtask name and translated to subtask ids; names that match no
/// subtask, and a subtask's dependency on itself, are dropped.
///
/// # Errors
///
/// Returns an error when the extracted text is not valid JSON of the expected
/// shape.
fn parse_decomposition(&self, response: &str) -> Result<Vec<SubTask>> {
    /// Returns the part of `response` most likely to contain the JSON payload.
    fn extract_json(response: &str) -> &str {
        const JSON_FENCE: &str = "```json";
        const FENCE: &str = "```";

        let candidate = match response.find(JSON_FENCE) {
            Some(open) => {
                let after_fence = &response[open + JSON_FENCE.len()..];
                match after_fence.find(FENCE) {
                    Some(close) => return &after_fence[..close],
                    // Unterminated fence: the marker itself can never be valid
                    // JSON, so keep looking only in the text that follows it.
                    None => after_fence,
                }
            }
            None => response,
        };

        match (candidate.find('{'), candidate.rfind('}')) {
            // Only slice when the braces are in order; a `}` before the first `{`
            // would otherwise produce an out-of-order range and panic.
            (Some(start), Some(end)) if start < end => &candidate[start..=end],
            _ => candidate,
        }
    }

    #[derive(Deserialize)]
    struct DecompositionResponse {
        subtasks: Vec<SubTaskDef>,
    }
    #[derive(Deserialize)]
    struct SubTaskDef {
        name: String,
        instruction: String,
        specialty: Option<String>,
        #[serde(default)]
        dependencies: Vec<String>,
        #[serde(default)]
        priority: i32,
    }

    let parsed: DecompositionResponse = serde_json::from_str(extract_json(response).trim())
        .map_err(|e| anyhow::anyhow!("Failed to parse decomposition: {}", e))?;

    // First pass: create every subtask so that all names can be resolved to ids,
    // regardless of the order in which the model listed them.
    let mut pending = Vec::with_capacity(parsed.subtasks.len());
    let mut name_to_id: HashMap<String, String> = HashMap::with_capacity(parsed.subtasks.len());
    for def in &parsed.subtasks {
        let subtask = SubTask::new(&def.name, &def.instruction).with_priority(def.priority);
        let subtask = if let Some(ref specialty) = def.specialty {
            subtask.with_specialty(specialty)
        } else {
            subtask
        };
        name_to_id.insert(def.name.clone(), subtask.id.clone());
        pending.push((subtask, def.dependencies.clone()));
    }

    // Second pass: translate dependency names into ids.
    let result: Vec<SubTask> = pending
        .into_iter()
        .map(|(mut subtask, deps)| {
            let resolved: Vec<String> = deps
                .iter()
                .filter_map(|dep| name_to_id.get(dep))
                // A subtask waiting on itself could never become runnable.
                .filter(|dep_id| **dep_id != subtask.id)
                .cloned()
                .collect();
            subtask.dependencies = resolved;
            subtask
        })
        .collect();
    Ok(result)
}
fn assign_stages(&mut self) {
let mut changed = true;
while changed {
changed = false;
let updates: Vec<(String, usize)> = self
.subtasks
.iter()
.filter_map(|(id, subtask)| {
if subtask.dependencies.is_empty() {
if subtask.stage != 0 {
Some((id.clone(), 0))
} else {
None
}
} else {
let max_dep_stage = subtask
.dependencies
.iter()
.filter_map(|dep_id| self.subtasks.get(dep_id))
.map(|dep| dep.stage)
.max()
.unwrap_or(0);
let new_stage = max_dep_stage + 1;
if subtask.stage != new_stage {
Some((id.clone(), new_stage))
} else {
None
}
}
})
.collect();
for (id, new_stage) in updates {
if let Some(subtask) = self.subtasks.get_mut(&id) {
subtask.stage = new_stage;
changed = true;
}
}
}
}
/// Returns the highest stage among all registered subtasks, or 0 when there are none.
fn max_stage(&self) -> usize {
    self.subtasks
        .values()
        .fold(0, |highest, subtask| highest.max(subtask.stage))
}
/// Collects the subtasks that are still pending and whose `can_run` check
/// passes against the list of completed subtask ids.
pub fn ready_subtasks(&self) -> Vec<&SubTask> {
    let mut ready = Vec::new();
    for candidate in self.subtasks.values() {
        if candidate.status != SubTaskStatus::Pending {
            continue;
        }
        if candidate.can_run(&self.completed) {
            ready.push(candidate);
        }
    }
    ready
}
/// Collects every registered subtask whose assigned stage equals `stage`.
pub fn subtasks_for_stage(&self, stage: usize) -> Vec<&SubTask> {
    let mut in_stage = Vec::new();
    for subtask in self.subtasks.values() {
        if subtask.stage == stage {
            in_stage.push(subtask);
        }
    }
    in_stage
}
/// Creates a sub-agent for `subtask`, registers it, and returns a copy.
///
/// The agent is named "`<specialty>` Agent", where the specialty defaults to
/// "General" when the subtask has none. The spawned-agents counter is incremented.
pub fn create_subagent(&mut self, subtask: &SubTask) -> SubAgent {
    let specialty = match &subtask.specialty {
        Some(role) => role.clone(),
        None => "General".to_string(),
    };
    let name = format!("{} Agent", specialty);

    let subagent = SubAgent::new(name, specialty, &subtask.id, &self.model, &self.provider);
    let registry_key = subagent.id.clone();
    self.subagents.insert(registry_key, subagent.clone());
    self.stats.subagents_spawned += 1;
    subagent
}
/// Records the outcome of the subtask identified by `subtask_id`.
///
/// Unknown ids are ignored. For a known subtask, its `complete` method is
/// invoked with the success flag, the tool-call total is increased, and either
/// the completed or the failed counter is bumped; successful ids are also
/// appended to the completed list.
pub fn complete_subtask(&mut self, subtask_id: &str, result: SubTaskResult) {
    let subtask = match self.subtasks.get_mut(subtask_id) {
        Some(found) => found,
        None => return,
    };

    let succeeded = result.success;
    subtask.complete(succeeded);
    self.stats.total_tool_calls += result.tool_calls;

    if !succeeded {
        self.stats.subagents_failed += 1;
        return;
    }
    self.completed.push(subtask_id.to_string());
    self.stats.subagents_completed += 1;
}
/// Returns references to every registered subtask (in map iteration order).
pub fn all_subtasks(&self) -> Vec<&SubTask> {
    self.subtasks
        .iter()
        .map(|(_, subtask)| subtask)
        .collect()
}
/// Returns a shared reference to the accumulated swarm statistics.
pub fn stats(&self) -> &SwarmStats {
&self.stats
}
/// Returns a mutable reference to the swarm statistics so callers can update them.
pub fn stats_mut(&mut self) -> &mut SwarmStats {
&mut self.stats
}
/// Reports whether every registered subtask has reached a terminal status
/// (completed, failed, or cancelled). Returns `true` when no subtasks exist.
pub fn is_complete(&self) -> bool {
    let has_unfinished = self.subtasks.values().any(|subtask| {
        !matches!(
            subtask.status,
            SubTaskStatus::Completed | SubTaskStatus::Failed | SubTaskStatus::Cancelled
        )
    });
    !has_unfinished
}
/// Returns the provider registry held by this orchestrator.
pub fn providers(&self) -> &ProviderRegistry {
&self.providers
}
/// Returns the model identifier selected when the orchestrator was created.
pub fn model(&self) -> &str {
&self.model
}
/// Returns the provider name selected when the orchestrator was created.
pub fn provider(&self) -> &str {
&self.provider
}
}
/// Picks the provider to use when none was requested explicitly.
///
/// The first entry of a fixed preference order that also appears in
/// `providers` wins. If none of the preferred names is available, the first
/// element of `providers` is returned; an empty slice yields `None`.
pub(crate) fn choose_default_provider<'a>(providers: &'a [&'a str]) -> Option<&'a str> {
    const PREFERENCE_ORDER: [&str; 12] = [
        "openai",
        "anthropic",
        "github-copilot",
        "github-copilot-enterprise",
        "openai-codex",
        "zai",
        "minimax",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
        "bedrock",
    ];

    PREFERENCE_ORDER
        .iter()
        .find_map(|wanted| {
            providers
                .iter()
                .copied()
                .find(|candidate| candidate == wanted)
        })
        .or_else(|| providers.first().copied())
}
/// Returns the default model identifier for `provider`.
///
/// Providers without a dedicated entry fall back to "gpt-4o".
pub(crate) fn default_model_for_provider(provider: &str) -> String {
    let model_id = match provider {
        "moonshotai" => "kimi-k2.5",
        "anthropic" => "claude-sonnet-4-20250514",
        "bedrock" => "us.anthropic.claude-opus-4-6-v1",
        "openai" => "gpt-4o",
        "google" => "gemini-2.5-pro",
        "zhipuai" | "zai" => "glm-5",
        "openrouter" => "z-ai/glm-5",
        "novita" => "Qwen/Qwen3.5-35B-A3B",
        "github-copilot" | "github-copilot-enterprise" | "openai-codex" => "gpt-5-mini",
        _ => "gpt-4o",
    };
    model_id.to_string()
}
/// Messages describing sub-agent activity.
///
/// NOTE(review): presumably sent from sub-agents to the orchestrator; the
/// transport is not visible in this file — confirm against the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubAgentMessage {
/// Progress report for a subtask, carrying a step count and a status string.
Progress {
subagent_id: String,
subtask_id: String,
steps: usize,
status: String,
},
/// Report of a single tool invocation and whether it succeeded.
ToolCall {
subagent_id: String,
tool_name: String,
success: bool,
},
/// Final result produced by a sub-agent.
Completed {
subagent_id: String,
result: SubTaskResult,
},
/// Request for a resource identified by type and id.
ResourceRequest {
subagent_id: String,
resource_type: String,
resource_id: String,
},
}
/// Messages describing orchestrator instructions.
///
/// NOTE(review): presumably sent from the orchestrator to sub-agents; the
/// transport is not visible in this file — confirm against the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrchestratorMessage {
/// Carries the subtask to work on (boxed, which keeps this variant small).
Start { subtask: Box<SubTask> },
/// Carries the content of a resource identified by type and id.
Resource {
resource_type: String,
resource_id: String,
content: String,
},
/// Asks the recipient to stop, with a human-readable reason.
Terminate { reason: String },
/// Carries the result text associated with a dependency id.
ContextUpdate {
dependency_id: String,
result: String,
},
}