ares/agents/
configurable.rs

//! Configurable Agent implementation
//!
//! This module provides a generic agent whose behavior is driven by TOML
//! configuration (`ares.toml`). It replaces the hardcoded agent implementations
//! with a flexible, configuration-driven approach.
6
7use crate::agents::Agent;
8use crate::llm::LLMClient;
9use crate::tools::registry::ToolRegistry;
10use crate::types::{AgentContext, AgentType, Result, ToolDefinition};
11use crate::utils::toml_config::AgentConfig;
12use async_trait::async_trait;
13use std::sync::Arc;
14
15/// A configurable agent that derives its behavior from TOML configuration
16pub struct ConfigurableAgent {
17    /// The agent's name/type identifier
18    name: String,
19    /// The agent type enum value
20    agent_type: AgentType,
21    /// The LLM client to use for generation
22    llm: Box<dyn LLMClient>,
23    /// The system prompt from configuration
24    system_prompt: String,
25    /// Tools available to this agent
26    tool_registry: Option<Arc<ToolRegistry>>,
27    /// List of tool names this agent is allowed to use
28    allowed_tools: Vec<String>,
29    /// Maximum tool calling iterations
30    max_tool_iterations: usize,
31    /// Whether to execute tools in parallel
32    parallel_tools: bool,
33}
34
35impl ConfigurableAgent {
36    /// Create a new configurable agent from TOML config
37    ///
38    /// # Arguments
39    ///
40    /// * `name` - The agent name (used to determine AgentType)
41    /// * `config` - The agent configuration from ares.toml
42    /// * `llm` - The LLM client (already created from the model config)
43    /// * `tool_registry` - Optional tool registry for tool calling
44    pub fn new(
45        name: &str,
46        config: &AgentConfig,
47        llm: Box<dyn LLMClient>,
48        tool_registry: Option<Arc<ToolRegistry>>,
49    ) -> Self {
50        let agent_type = Self::name_to_type(name);
51        let system_prompt = config
52            .system_prompt
53            .clone()
54            .unwrap_or_else(|| Self::default_system_prompt(name));
55
56        Self {
57            name: name.to_string(),
58            agent_type,
59            llm,
60            system_prompt,
61            tool_registry,
62            allowed_tools: config.tools.clone(),
63            max_tool_iterations: config.max_tool_iterations,
64            parallel_tools: config.parallel_tools,
65        }
66    }
67
68    /// Create a new configurable agent with explicit parameters
69    #[allow(clippy::too_many_arguments)]
70    pub fn with_params(
71        name: &str,
72        agent_type: AgentType,
73        llm: Box<dyn LLMClient>,
74        system_prompt: String,
75        tool_registry: Option<Arc<ToolRegistry>>,
76        allowed_tools: Vec<String>,
77        max_tool_iterations: usize,
78        parallel_tools: bool,
79    ) -> Self {
80        Self {
81            name: name.to_string(),
82            agent_type,
83            llm,
84            system_prompt,
85            tool_registry,
86            allowed_tools,
87            max_tool_iterations,
88            parallel_tools,
89        }
90    }
91
92    /// Convert agent name to AgentType
93    fn name_to_type(name: &str) -> AgentType {
94        match name.to_lowercase().as_str() {
95            "router" => AgentType::Router,
96            "orchestrator" => AgentType::Orchestrator,
97            "product" => AgentType::Product,
98            "invoice" => AgentType::Invoice,
99            "sales" => AgentType::Sales,
100            "finance" => AgentType::Finance,
101            "hr" => AgentType::HR,
102            _ => AgentType::Orchestrator, // Default to orchestrator for unknown types
103        }
104    }
105
106    /// Get default system prompt for an agent type
107    fn default_system_prompt(name: &str) -> String {
108        match name.to_lowercase().as_str() {
109            "router" => r#"You are a routing agent that classifies user queries.
110Available agents: product, invoice, sales, finance, hr, orchestrator.
111Respond with ONLY the agent name (one word, lowercase)."#
112                .to_string(),
113
114            "orchestrator" => r#"You are an orchestrator agent for complex queries.
115Break down requests, delegate to specialists, and synthesize results."#
116                .to_string(),
117
118            "product" => r#"You are a Product Agent for product-related queries.
119Handle catalog, specifications, inventory, and pricing questions."#
120                .to_string(),
121
122            "invoice" => r#"You are an Invoice Agent for billing queries.
123Handle invoices, payments, and billing history."#
124                .to_string(),
125
126            "sales" => r#"You are a Sales Agent for sales analytics.
127Handle performance metrics, revenue, and customer data."#
128                .to_string(),
129
130            "finance" => r#"You are a Finance Agent for financial analysis.
131Handle statements, budgets, and expense management."#
132                .to_string(),
133
134            "hr" => r#"You are an HR Agent for human resources.
135Handle employee info, policies, and benefits."#
136                .to_string(),
137
138            _ => format!("You are a {} agent.", name),
139        }
140    }
141
142    /// Get the agent name
143    pub fn name(&self) -> &str {
144        &self.name
145    }
146
147    /// Get the max tool iterations setting
148    pub fn max_tool_iterations(&self) -> usize {
149        self.max_tool_iterations
150    }
151
152    /// Get the parallel tools setting
153    pub fn parallel_tools(&self) -> bool {
154        self.parallel_tools
155    }
156
157    /// Check if this agent has tools configured
158    pub fn has_tools(&self) -> bool {
159        !self.allowed_tools.is_empty() && self.tool_registry.is_some()
160    }
161
162    /// Get the tool registry (if any)
163    pub fn tool_registry(&self) -> Option<&Arc<ToolRegistry>> {
164        self.tool_registry.as_ref()
165    }
166
167    /// Get the list of allowed tool names for this agent
168    pub fn allowed_tools(&self) -> &[String] {
169        &self.allowed_tools
170    }
171
172    /// Get tool definitions for only this agent's allowed tools
173    ///
174    /// This filters the tool registry to only return tools that:
175    /// 1. Are in this agent's allowed tools list
176    /// 2. Are enabled in the tool registry
177    pub fn get_filtered_tool_definitions(&self) -> Vec<ToolDefinition> {
178        match &self.tool_registry {
179            Some(registry) => {
180                let allowed: Vec<&str> = self.allowed_tools.iter().map(|s| s.as_str()).collect();
181                registry.get_tool_definitions_for(&allowed)
182            }
183            None => Vec::new(),
184        }
185    }
186
187    /// Check if a specific tool is allowed for this agent
188    pub fn can_use_tool(&self, tool_name: &str) -> bool {
189        self.allowed_tools.contains(&tool_name.to_string())
190            && self
191                .tool_registry
192                .as_ref()
193                .map(|r| r.is_enabled(tool_name))
194                .unwrap_or(false)
195    }
196}
197
198#[async_trait]
199impl Agent for ConfigurableAgent {
200    async fn execute(&self, input: &str, context: &AgentContext) -> Result<String> {
201        // Build context with conversation history if available
202        let mut messages = vec![("system".to_string(), self.system_prompt.clone())];
203
204        // Add user memory if available
205        if let Some(memory) = &context.user_memory {
206            let memory_context = format!(
207                "User preferences: {}",
208                memory
209                    .preferences
210                    .iter()
211                    .map(|p| format!("{}: {}", p.key, p.value))
212                    .collect::<Vec<_>>()
213                    .join(", ")
214            );
215            messages.push(("system".to_string(), memory_context));
216        }
217
218        // Add recent conversation history (last 5 messages)
219        for msg in context.conversation_history.iter().rev().take(5).rev() {
220            let role = match msg.role {
221                crate::types::MessageRole::User => "user",
222                crate::types::MessageRole::Assistant => "assistant",
223                _ => "system",
224            };
225            messages.push((role.to_string(), msg.content.clone()));
226        }
227
228        messages.push(("user".to_string(), input.to_string()));
229
230        self.llm.generate_with_history(&messages).await
231    }
232
233    fn system_prompt(&self) -> String {
234        self.system_prompt.clone()
235    }
236
237    fn agent_type(&self) -> AgentType {
238        self.agent_type
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LLMResponse;
    use std::collections::HashMap;

    /// Minimal LLM client returning canned values. None of these tests call
    /// the model; `ConfigurableAgent` simply needs a client to be constructed.
    struct MockLLM;

    #[async_trait]
    impl LLMClient for MockLLM {
        async fn generate(&self, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_system(&self, _: &str, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_history(&self, _: &[(String, String)]) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_tools(
            &self,
            _: &str,
            _: &[ToolDefinition],
        ) -> Result<LLMResponse> {
            Ok(LLMResponse {
                content: "mock".to_string(),
                tool_calls: vec![],
                finish_reason: "stop".to_string(),
            })
        }
        async fn stream(
            &self,
            _: &str,
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
            Ok(Box::new(futures::stream::empty()))
        }
        fn model_name(&self) -> &str {
            "mock"
        }
    }

    /// Build an `AgentConfig` allowing `tools`, with otherwise fixed test defaults
    /// (no system prompt, 5 tool iterations, sequential tools).
    fn config_with_tools(tools: &[&str]) -> AgentConfig {
        AgentConfig {
            model: "default".to_string(),
            system_prompt: None,
            tools: tools.iter().map(|&t| t.to_string()).collect(),
            max_tool_iterations: 5,
            parallel_tools: false,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn test_name_to_type() {
        assert!(matches!(
            ConfigurableAgent::name_to_type("router"),
            AgentType::Router
        ));
        assert!(matches!(
            ConfigurableAgent::name_to_type("PRODUCT"),
            AgentType::Product
        ));
        assert!(matches!(
            ConfigurableAgent::name_to_type("unknown"),
            AgentType::Orchestrator
        ));
    }

    #[test]
    fn test_default_system_prompt() {
        let prompt = ConfigurableAgent::default_system_prompt("router");
        assert!(prompt.contains("routing"));

        let prompt = ConfigurableAgent::default_system_prompt("product");
        assert!(prompt.contains("Product"));

        // Unknown names fall back to a generic prompt using the name verbatim.
        assert_eq!(
            ConfigurableAgent::default_system_prompt("custom"),
            "You are a custom agent."
        );
    }

    #[test]
    fn test_configured_system_prompt_overrides_default() {
        let mut config = config_with_tools(&[]);
        config.system_prompt = Some("Custom prompt".to_string());

        let agent = ConfigurableAgent::new("router", &config, Box::new(MockLLM), None);

        assert_eq!(agent.system_prompt(), "Custom prompt");
        assert!(matches!(agent.agent_type(), AgentType::Router));
    }

    #[test]
    fn test_accessors_reflect_config() {
        let agent = ConfigurableAgent::new(
            "orchestrator",
            &config_with_tools(&[]),
            Box::new(MockLLM),
            None,
        );

        assert_eq!(agent.name(), "orchestrator");
        assert_eq!(agent.max_tool_iterations(), 5);
        assert!(!agent.parallel_tools());
    }

    #[test]
    fn test_allowed_tools() {
        let agent = ConfigurableAgent::new(
            "orchestrator",
            &config_with_tools(&["calculator", "web_search"]),
            Box::new(MockLLM),
            None, // No registry for this test
        );

        assert_eq!(agent.allowed_tools().len(), 2);
        assert!(agent.allowed_tools().contains(&"calculator".to_string()));
        assert!(agent.allowed_tools().contains(&"web_search".to_string()));
    }

    #[test]
    fn test_tool_queries_without_registry() {
        let agent = ConfigurableAgent::new(
            "orchestrator",
            &config_with_tools(&["calculator"]),
            Box::new(MockLLM),
            None,
        );

        // Allowed by config, but unusable without a registry.
        assert!(agent.tool_registry().is_none());
        assert!(!agent.can_use_tool("calculator"));
        assert!(agent.get_filtered_tool_definitions().is_empty());
    }

    #[test]
    fn test_has_tools_requires_both_config_and_registry() {
        // Agent with tools config but no registry
        let agent = ConfigurableAgent::new(
            "orchestrator",
            &config_with_tools(&["calculator"]),
            Box::new(MockLLM),
            None,
        );
        assert!(!agent.has_tools()); // No registry

        // Agent with empty tools
        let agent_empty = ConfigurableAgent::new(
            "product",
            &config_with_tools(&[]),
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(!agent_empty.has_tools()); // Empty tools list
        assert!(!agent_empty.can_use_tool("calculator")); // Not in the allowed list
    }
}