// ares/agents/configurable.rs

//! Configurable agent implementation.
//!
//! This module provides a generic agent whose behavior is configured via TOML.
//! It replaces the hardcoded per-agent implementations with a flexible,
//! configuration-driven approach.
6
7use crate::agents::Agent;
8use crate::llm::LLMClient;
9use crate::tools::registry::ToolRegistry;
10use crate::types::{AgentContext, AgentType, Result, ToolDefinition};
11use crate::utils::toml_config::AgentConfig;
12use async_trait::async_trait;
13use std::sync::Arc;
14
15/// A configurable agent that derives its behavior from TOML configuration
16pub struct ConfigurableAgent {
17    /// The agent's name/type identifier
18    name: String,
19    /// The agent type enum value
20    agent_type: AgentType,
21    /// The LLM client to use for generation
22    llm: Box<dyn LLMClient>,
23    /// The system prompt from configuration
24    system_prompt: String,
25    /// Tools available to this agent
26    tool_registry: Option<Arc<ToolRegistry>>,
27    /// List of tool names this agent is allowed to use
28    allowed_tools: Vec<String>,
29    /// Maximum tool calling iterations
30    max_tool_iterations: usize,
31    /// Whether to execute tools in parallel
32    parallel_tools: bool,
33}
34
35impl ConfigurableAgent {
36    /// Create a new configurable agent from TOML config
37    ///
38    /// # Arguments
39    ///
40    /// * `name` - The agent name (used to determine AgentType)
41    /// * `config` - The agent configuration from ares.toml
42    /// * `llm` - The LLM client (already created from the model config)
43    /// * `tool_registry` - Optional tool registry for tool calling
44    pub fn new(
45        name: &str,
46        config: &AgentConfig,
47        llm: Box<dyn LLMClient>,
48        tool_registry: Option<Arc<ToolRegistry>>,
49    ) -> Self {
50        let agent_type = Self::name_to_type(name);
51        let system_prompt = config
52            .system_prompt
53            .clone()
54            .unwrap_or_else(|| Self::default_system_prompt(name));
55
56        Self {
57            name: name.to_string(),
58            agent_type,
59            llm,
60            system_prompt,
61            tool_registry,
62            allowed_tools: config.tools.clone(),
63            max_tool_iterations: config.max_tool_iterations,
64            parallel_tools: config.parallel_tools,
65        }
66    }
67
68    /// Create a new configurable agent with explicit parameters
69    #[allow(clippy::too_many_arguments)]
70    pub fn with_params(
71        name: &str,
72        agent_type: AgentType,
73        llm: Box<dyn LLMClient>,
74        system_prompt: String,
75        tool_registry: Option<Arc<ToolRegistry>>,
76        allowed_tools: Vec<String>,
77        max_tool_iterations: usize,
78        parallel_tools: bool,
79    ) -> Self {
80        Self {
81            name: name.to_string(),
82            agent_type,
83            llm,
84            system_prompt,
85            tool_registry,
86            allowed_tools,
87            max_tool_iterations,
88            parallel_tools,
89        }
90    }
91
92    /// Convert agent name to AgentType
93    fn name_to_type(name: &str) -> AgentType {
94        AgentType::from_string(name)
95    }
96
97    /// Get default system prompt for an agent type
98    fn default_system_prompt(name: &str) -> String {
99        match name.to_lowercase().as_str() {
100            "router" => r#"You are a routing agent that classifies user queries.
101Available agents: product, invoice, sales, finance, hr, orchestrator.
102Respond with ONLY the agent name (one word, lowercase)."#
103                .to_string(),
104
105            "orchestrator" => r#"You are an orchestrator agent for complex queries.
106Break down requests, delegate to specialists, and synthesize results."#
107                .to_string(),
108
109            "product" => r#"You are a Product Agent for product-related queries.
110Handle catalog, specifications, inventory, and pricing questions."#
111                .to_string(),
112
113            "invoice" => r#"You are an Invoice Agent for billing queries.
114Handle invoices, payments, and billing history."#
115                .to_string(),
116
117            "sales" => r#"You are a Sales Agent for sales analytics.
118Handle performance metrics, revenue, and customer data."#
119                .to_string(),
120
121            "finance" => r#"You are a Finance Agent for financial analysis.
122Handle statements, budgets, and expense management."#
123                .to_string(),
124
125            "hr" => r#"You are an HR Agent for human resources.
126Handle employee info, policies, and benefits."#
127                .to_string(),
128
129            _ => format!("You are a {} agent.", name),
130        }
131    }
132
133    /// Get the agent name
134    pub fn name(&self) -> &str {
135        &self.name
136    }
137
138    /// Get the max tool iterations setting
139    pub fn max_tool_iterations(&self) -> usize {
140        self.max_tool_iterations
141    }
142
143    /// Get the parallel tools setting
144    pub fn parallel_tools(&self) -> bool {
145        self.parallel_tools
146    }
147
148    /// Check if this agent has tools configured
149    pub fn has_tools(&self) -> bool {
150        !self.allowed_tools.is_empty() && self.tool_registry.is_some()
151    }
152
153    /// Get the tool registry (if any)
154    pub fn tool_registry(&self) -> Option<&Arc<ToolRegistry>> {
155        self.tool_registry.as_ref()
156    }
157
158    /// Get the list of allowed tool names for this agent
159    pub fn allowed_tools(&self) -> &[String] {
160        &self.allowed_tools
161    }
162
163    /// Get tool definitions for only this agent's allowed tools
164    ///
165    /// This filters the tool registry to only return tools that:
166    /// 1. Are in this agent's allowed tools list
167    /// 2. Are enabled in the tool registry
168    pub fn get_filtered_tool_definitions(&self) -> Vec<ToolDefinition> {
169        match &self.tool_registry {
170            Some(registry) => {
171                let allowed: Vec<&str> = self.allowed_tools.iter().map(|s| s.as_str()).collect();
172                registry.get_tool_definitions_for(&allowed)
173            }
174            None => Vec::new(),
175        }
176    }
177
178    /// Check if a specific tool is allowed for this agent
179    pub fn can_use_tool(&self, tool_name: &str) -> bool {
180        self.allowed_tools.contains(&tool_name.to_string())
181            && self
182                .tool_registry
183                .as_ref()
184                .map(|r| r.is_enabled(tool_name))
185                .unwrap_or(false)
186    }
187}
188
189#[async_trait]
190impl Agent for ConfigurableAgent {
191    async fn execute(&self, input: &str, context: &AgentContext) -> Result<String> {
192        // Build context with conversation history if available
193        let mut messages = vec![("system".to_string(), self.system_prompt.clone())];
194
195        // Add user memory if available
196        if let Some(memory) = &context.user_memory {
197            let memory_context = format!(
198                "User preferences: {}",
199                memory
200                    .preferences
201                    .iter()
202                    .map(|p| format!("{}: {}", p.key, p.value))
203                    .collect::<Vec<_>>()
204                    .join(", ")
205            );
206            messages.push(("system".to_string(), memory_context));
207        }
208
209        // Add recent conversation history (last 5 messages)
210        for msg in context.conversation_history.iter().rev().take(5).rev() {
211            let role = match msg.role {
212                crate::types::MessageRole::User => "user",
213                crate::types::MessageRole::Assistant => "assistant",
214                _ => "system",
215            };
216            messages.push((role.to_string(), msg.content.clone()));
217        }
218
219        messages.push(("user".to_string(), input.to_string()));
220
221        self.llm.generate_with_history(&messages).await
222    }
223
224    fn system_prompt(&self) -> String {
225        self.system_prompt.clone()
226    }
227
228    fn agent_type(&self) -> AgentType {
229        self.agent_type.clone()
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LLMResponse;
    use std::collections::HashMap;

    /// Minimal LLM client that answers every request with `"mock"` and never
    /// requests tool calls. Shared by all tests in this module.
    struct MockLLM;

    /// Canned response returned by the mock's tool-calling methods.
    fn mock_response() -> LLMResponse {
        LLMResponse {
            content: "mock".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: None,
        }
    }

    #[async_trait]
    impl LLMClient for MockLLM {
        async fn generate(&self, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_system(&self, _: &str, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_history(&self, _: &[(String, String)]) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_tools(
            &self,
            _: &str,
            _: &[ToolDefinition],
        ) -> Result<LLMResponse> {
            Ok(mock_response())
        }
        async fn stream(
            &self,
            _: &str,
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
            Ok(Box::new(futures::stream::empty()))
        }
        async fn stream_with_system(
            &self,
            _: &str,
            _: &str,
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
            Ok(Box::new(futures::stream::empty()))
        }
        async fn stream_with_history(
            &self,
            _: &[(String, String)],
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
            Ok(Box::new(futures::stream::empty()))
        }
        fn model_name(&self) -> &str {
            "mock"
        }
        async fn generate_with_tools_and_history(
            &self,
            _: &[crate::llm::coordinator::ConversationMessage],
            _: &[ToolDefinition],
        ) -> Result<LLMResponse> {
            Ok(mock_response())
        }
    }

    /// Builds an `AgentConfig` with default settings and the given tool names.
    fn config_with_tools(tools: &[&str]) -> AgentConfig {
        AgentConfig {
            model: "default".to_string(),
            system_prompt: None,
            tools: tools.iter().map(|t| t.to_string()).collect(),
            max_tool_iterations: 5,
            parallel_tools: false,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn test_name_to_type() {
        assert!(matches!(
            ConfigurableAgent::name_to_type("router"),
            AgentType::Router
        ));
        assert!(matches!(
            ConfigurableAgent::name_to_type("PRODUCT"),
            AgentType::Product
        ));
        // Unknown types return the Custom variant
        assert!(matches!(
            ConfigurableAgent::name_to_type("unknown"),
            AgentType::Custom(_)
        ));
        // The custom name is preserved
        match ConfigurableAgent::name_to_type("my-custom-agent") {
            AgentType::Custom(name) => assert_eq!(name, "my-custom-agent"),
            _ => panic!("Expected Custom variant"),
        }
    }

    #[test]
    fn test_default_system_prompt() {
        let prompt = ConfigurableAgent::default_system_prompt("router");
        assert!(prompt.contains("routing"));

        let prompt = ConfigurableAgent::default_system_prompt("product");
        assert!(prompt.contains("Product"));

        // Unknown names fall back to a generic prompt with the original casing.
        assert_eq!(
            ConfigurableAgent::default_system_prompt("Billing"),
            "You are a Billing agent."
        );
    }

    #[test]
    fn test_config_system_prompt_overrides_default() {
        let mut config = config_with_tools(&[]);
        config.system_prompt = Some("Custom prompt".to_string());

        let agent = ConfigurableAgent::new("router", &config, Box::new(MockLLM), None);
        assert_eq!(agent.system_prompt(), "Custom prompt");
    }

    #[test]
    fn test_allowed_tools() {
        let config = config_with_tools(&["calculator", "web_search"]);

        let agent = ConfigurableAgent::new(
            "orchestrator",
            &config,
            Box::new(MockLLM),
            None, // No registry for this test
        );

        assert_eq!(agent.allowed_tools().len(), 2);
        assert!(agent.allowed_tools().contains(&"calculator".to_string()));
        assert!(agent.allowed_tools().contains(&"web_search".to_string()));

        // Without a registry no tool is usable, even if it is allowed.
        assert!(!agent.can_use_tool("calculator"));
        assert!(!agent.can_use_tool("not_allowed"));
        assert!(agent.get_filtered_tool_definitions().is_empty());
    }

    #[test]
    fn test_has_tools_requires_both_config_and_registry() {
        let config = config_with_tools(&["calculator"]);

        // Tools configured but no registry
        let agent = ConfigurableAgent::new("orchestrator", &config, Box::new(MockLLM), None);
        assert!(!agent.has_tools());

        // Registry present but empty tools list
        let config_empty = config_with_tools(&[]);
        let agent_empty = ConfigurableAgent::new(
            "product",
            &config_empty,
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(!agent_empty.has_tools());

        // Both tools configured and a registry present
        let agent_ready = ConfigurableAgent::new(
            "orchestrator",
            &config,
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(agent_ready.has_tools());
    }
}