Skip to main content

ares/agents/
orchestrator.rs

1use crate::{
2    agents::{Agent, AgentRegistry, AgentResponse},
3    llm::LLMClient,
4    types::{AgentContext, AgentType, AppError, Result},
5    AppState,
6};
7use async_trait::async_trait;
8use std::sync::Arc;
9
10/// Orchestrator agent that coordinates multiple specialized agents.
11///
12/// This agent decomposes complex queries into subtasks and delegates
13/// them to appropriate specialized agents via the AgentRegistry.
14pub struct OrchestratorAgent {
15    llm: Box<dyn LLMClient>,
16    state: AppState,
17    agent_registry: Arc<AgentRegistry>,
18}
19
20impl OrchestratorAgent {
21    /// Creates a new OrchestratorAgent with the given dependencies.
22    pub fn new(
23        llm: Box<dyn LLMClient>,
24        state: AppState,
25        agent_registry: Arc<AgentRegistry>,
26    ) -> Self {
27        Self {
28            llm,
29            state,
30            agent_registry,
31        }
32    }
33
34    /// Decompose a complex task into subtasks for specialized agents
35    async fn decompose_task(&self, input: &str) -> Result<Vec<(String, String)>> {
36        // Get available agents from registry
37        let available_agents = self.agent_registry.agent_names();
38        let agent_list = available_agents
39            .iter()
40            .filter(|name| **name != "orchestrator" && **name != "router")
41            .cloned()
42            .collect::<Vec<_>>()
43            .join(", ");
44
45        let system_prompt = format!(
46            r#"You are a task decomposition agent. Break down complex queries into subtasks for specialized agents.
47
48Available agents: {}
49
50Return a JSON array of tasks:
51[
52    {{"agent": "sales", "task": "Get Q1 revenue"}},
53    {{"agent": "product", "task": "List top products"}}
54]
55
56Only respond with valid JSON."#,
57            agent_list
58        );
59
60        let response = self.llm.generate_with_system(&system_prompt, input).await?;
61
62        // Parse JSON response
63        let tasks: Vec<serde_json::Value> = serde_json::from_str(&response)
64            .map_err(|e| AppError::LLM(format!("Failed to parse tasks: {}", e)))?;
65
66        let mut result = Vec::new();
67        for task in tasks {
68            let agent_name = task["agent"].as_str().unwrap_or("product").to_string();
69            let task_str = task["task"].as_str().unwrap_or("").to_string();
70
71            // Validate agent exists in registry
72            if self.agent_registry.has_agent(&agent_name) {
73                result.push((agent_name, task_str));
74            } else {
75                // Fall back to product agent if unknown
76                result.push(("product".to_string(), task_str));
77            }
78        }
79
80        Ok(result)
81    }
82
83    /// Execute a subtask using the appropriate agent from the registry
84    async fn execute_subtask(
85        &self,
86        agent_name: &str,
87        task: &str,
88        context: &AgentContext,
89    ) -> Result<String> {
90        // Create agent from registry (handles model and tool configuration)
91        let agent = self.agent_registry.create_agent(agent_name).await?;
92        let resp = agent.execute(task, context).await?;
93        Ok(resp.content)
94    }
95}
96
97#[async_trait]
98impl Agent for OrchestratorAgent {
99    async fn execute(&self, input: &str, context: &AgentContext) -> Result<AgentResponse> {
100        // Decompose the task into subtasks
101        let subtasks = self.decompose_task(input).await?;
102
103        if subtasks.is_empty() {
104            let content = self.llm.generate(input).await?;
105            return Ok(AgentResponse { content, usage: None, metadata: None });
106        }
107
108        // Execute subtasks sequentially (could be parallelized in future)
109        let mut results = Vec::new();
110        for (agent_name, task) in subtasks {
111            let result = self.execute_subtask(&agent_name, &task, context).await?;
112            results.push(format!("[{}] {}", agent_name, result));
113        }
114
115        // Synthesize results into final response
116        let synthesis_prompt = format!(
117            "Original query: {}\n\nSubtask results:\n{}\n\nProvide a comprehensive answer:",
118            input,
119            results.join("\n\n")
120        );
121
122        let content = self.llm.generate(&synthesis_prompt).await?;
123        Ok(AgentResponse { content, usage: None, metadata: None })
124    }
125
126    fn system_prompt(&self) -> String {
127        // Get system prompt from config if available
128        let config = self.state.config_manager.config();
129        config
130            .get_agent("orchestrator")
131            .and_then(|a| a.system_prompt.clone())
132            .unwrap_or_else(|| {
133                "You are an orchestrator agent that coordinates multiple specialized agents to answer complex queries.".to_string()
134            })
135    }
136
137    fn agent_type(&self) -> AgentType {
138        AgentType::Orchestrator
139    }
140}