// ares/agents/orchestrator.rs

1use crate::{
2    agents::{Agent, AgentRegistry},
3    llm::LLMClient,
4    types::{AgentContext, AgentType, AppError, Result},
5    AppState,
6};
7use async_trait::async_trait;
8use std::sync::Arc;
9
10/// Orchestrator agent that coordinates multiple specialized agents.
11///
12/// This agent decomposes complex queries into subtasks and delegates
13/// them to appropriate specialized agents via the AgentRegistry.
14pub struct OrchestratorAgent {
15    llm: Box<dyn LLMClient>,
16    state: AppState,
17    agent_registry: Arc<AgentRegistry>,
18}
19
20impl OrchestratorAgent {
21    /// Creates a new OrchestratorAgent with the given dependencies.
22    pub fn new(
23        llm: Box<dyn LLMClient>,
24        state: AppState,
25        agent_registry: Arc<AgentRegistry>,
26    ) -> Self {
27        Self {
28            llm,
29            state,
30            agent_registry,
31        }
32    }
33
34    /// Decompose a complex task into subtasks for specialized agents
35    async fn decompose_task(&self, input: &str) -> Result<Vec<(String, String)>> {
36        // Get available agents from registry
37        let available_agents = self.agent_registry.agent_names();
38        let agent_list = available_agents
39            .iter()
40            .filter(|name| **name != "orchestrator" && **name != "router")
41            .cloned()
42            .collect::<Vec<_>>()
43            .join(", ");
44
45        let system_prompt = format!(
46            r#"You are a task decomposition agent. Break down complex queries into subtasks for specialized agents.
47
48Available agents: {}
49
50Return a JSON array of tasks:
51[
52    {{"agent": "sales", "task": "Get Q1 revenue"}},
53    {{"agent": "product", "task": "List top products"}}
54]
55
56Only respond with valid JSON."#,
57            agent_list
58        );
59
60        let response = self.llm.generate_with_system(&system_prompt, input).await?;
61
62        // Parse JSON response
63        let tasks: Vec<serde_json::Value> = serde_json::from_str(&response)
64            .map_err(|e| AppError::LLM(format!("Failed to parse tasks: {}", e)))?;
65
66        let mut result = Vec::new();
67        for task in tasks {
68            let agent_name = task["agent"].as_str().unwrap_or("product").to_string();
69            let task_str = task["task"].as_str().unwrap_or("").to_string();
70
71            // Validate agent exists in registry
72            if self.agent_registry.has_agent(&agent_name) {
73                result.push((agent_name, task_str));
74            } else {
75                // Fall back to product agent if unknown
76                result.push(("product".to_string(), task_str));
77            }
78        }
79
80        Ok(result)
81    }
82
83    /// Execute a subtask using the appropriate agent from the registry
84    async fn execute_subtask(
85        &self,
86        agent_name: &str,
87        task: &str,
88        context: &AgentContext,
89    ) -> Result<String> {
90        // Create agent from registry (handles model and tool configuration)
91        let agent = self.agent_registry.create_agent(agent_name).await?;
92        agent.execute(task, context).await
93    }
94}
95
96#[async_trait]
97impl Agent for OrchestratorAgent {
98    async fn execute(&self, input: &str, context: &AgentContext) -> Result<String> {
99        // Decompose the task into subtasks
100        let subtasks = self.decompose_task(input).await?;
101
102        if subtasks.is_empty() {
103            return self.llm.generate(input).await;
104        }
105
106        // Execute subtasks sequentially (could be parallelized in future)
107        let mut results = Vec::new();
108        for (agent_name, task) in subtasks {
109            let result = self.execute_subtask(&agent_name, &task, context).await?;
110            results.push(format!("[{}] {}", agent_name, result));
111        }
112
113        // Synthesize results into final response
114        let synthesis_prompt = format!(
115            "Original query: {}\n\nSubtask results:\n{}\n\nProvide a comprehensive answer:",
116            input,
117            results.join("\n\n")
118        );
119
120        self.llm.generate(&synthesis_prompt).await
121    }
122
123    fn system_prompt(&self) -> String {
124        // Get system prompt from config if available
125        let config = self.state.config_manager.config();
126        config
127            .get_agent("orchestrator")
128            .and_then(|a| a.system_prompt.clone())
129            .unwrap_or_else(|| {
130                "You are an orchestrator agent that coordinates multiple specialized agents to answer complex queries.".to_string()
131            })
132    }
133
134    fn agent_type(&self) -> AgentType {
135        AgentType::Orchestrator
136    }
137}