Skip to main content

ares/agents/
configurable.rs

//! Configurable Agent implementation
//!
//! This module provides a generic agent that can be configured via TOML.
//! It replaces the hardcoded agent implementations with a flexible,
//! configuration-driven approach.
6
7use crate::agents::Agent;
8use crate::llm::LLMClient;
9use crate::tools::registry::ToolRegistry;
10use crate::types::{AgentContext, AgentType, Result, ToolDefinition};
11use crate::utils::toml_config::AgentConfig;
12use async_trait::async_trait;
13use std::sync::Arc;
14
15/// A configurable agent that derives its behavior from TOML configuration
16pub struct ConfigurableAgent {
17    /// The agent's name/type identifier
18    name: String,
19    /// The agent type enum value
20    agent_type: AgentType,
21    /// The LLM client to use for generation
22    llm: Box<dyn LLMClient>,
23    /// The system prompt from configuration
24    system_prompt: String,
25    /// Tools available to this agent
26    tool_registry: Option<Arc<ToolRegistry>>,
27    /// List of tool names this agent is allowed to use
28    allowed_tools: Vec<String>,
29    /// Maximum tool calling iterations
30    max_tool_iterations: usize,
31    /// Whether to execute tools in parallel
32    parallel_tools: bool,
33}
34
35impl ConfigurableAgent {
36    /// Create a new configurable agent from TOML config
37    ///
38    /// # Arguments
39    ///
40    /// * `name` - The agent name (used to determine AgentType)
41    /// * `config` - The agent configuration from ares.toml
42    /// * `llm` - The LLM client (already created from the model config)
43    /// * `tool_registry` - Optional tool registry for tool calling
44    pub fn new(
45        name: &str,
46        config: &AgentConfig,
47        llm: Box<dyn LLMClient>,
48        tool_registry: Option<Arc<ToolRegistry>>,
49    ) -> Self {
50        let agent_type = Self::name_to_type(name);
51        let system_prompt = config
52            .system_prompt
53            .clone()
54            .unwrap_or_else(|| Self::default_system_prompt(name));
55
56        Self {
57            name: name.to_string(),
58            agent_type,
59            llm,
60            system_prompt,
61            tool_registry,
62            allowed_tools: config.tools.clone(),
63            max_tool_iterations: config.max_tool_iterations,
64            parallel_tools: config.parallel_tools,
65        }
66    }
67
68    /// Create a new configurable agent with explicit parameters
69    #[allow(clippy::too_many_arguments)]
70    pub fn with_params(
71        name: &str,
72        agent_type: AgentType,
73        llm: Box<dyn LLMClient>,
74        system_prompt: String,
75        tool_registry: Option<Arc<ToolRegistry>>,
76        allowed_tools: Vec<String>,
77        max_tool_iterations: usize,
78        parallel_tools: bool,
79    ) -> Self {
80        Self {
81            name: name.to_string(),
82            agent_type,
83            llm,
84            system_prompt,
85            tool_registry,
86            allowed_tools,
87            max_tool_iterations,
88            parallel_tools,
89        }
90    }
91
92    /// Convert agent name to AgentType
93    fn name_to_type(name: &str) -> AgentType {
94        AgentType::from_string(name)
95    }
96
97    /// Get default system prompt for an agent type
98    fn default_system_prompt(name: &str) -> String {
99        match name.to_lowercase().as_str() {
100            "router" => r#"You are a routing agent that classifies user queries.
101Available agents: product, invoice, sales, finance, hr, orchestrator.
102Respond with ONLY the agent name (one word, lowercase)."#
103                .to_string(),
104
105            "orchestrator" => r#"You are an orchestrator agent for complex queries.
106Break down requests, delegate to specialists, and synthesize results."#
107                .to_string(),
108
109            "product" => r#"You are a Product Agent for product-related queries.
110Handle catalog, specifications, inventory, and pricing questions."#
111                .to_string(),
112
113            "invoice" => r#"You are an Invoice Agent for billing queries.
114Handle invoices, payments, and billing history."#
115                .to_string(),
116
117            "sales" => r#"You are a Sales Agent for sales analytics.
118Handle performance metrics, revenue, and customer data."#
119                .to_string(),
120
121            "finance" => r#"You are a Finance Agent for financial analysis.
122Handle statements, budgets, and expense management."#
123                .to_string(),
124
125            "hr" => r#"You are an HR Agent for human resources.
126Handle employee info, policies, and benefits."#
127                .to_string(),
128
129            _ => format!("You are a {} agent.", name),
130        }
131    }
132
133    /// Get the agent name
134    pub fn name(&self) -> &str {
135        &self.name
136    }
137
138    /// Get the max tool iterations setting
139    pub fn max_tool_iterations(&self) -> usize {
140        self.max_tool_iterations
141    }
142
143    /// Get the parallel tools setting
144    pub fn parallel_tools(&self) -> bool {
145        self.parallel_tools
146    }
147
148    /// Check if this agent has tools configured
149    pub fn has_tools(&self) -> bool {
150        !self.allowed_tools.is_empty() && self.tool_registry.is_some()
151    }
152
153    /// Get the tool registry (if any)
154    pub fn tool_registry(&self) -> Option<&Arc<ToolRegistry>> {
155        self.tool_registry.as_ref()
156    }
157
158    /// Get the list of allowed tool names for this agent
159    pub fn allowed_tools(&self) -> &[String] {
160        &self.allowed_tools
161    }
162
163    /// Get tool definitions for only this agent's allowed tools
164    ///
165    /// This filters the tool registry to only return tools that:
166    /// 1. Are in this agent's allowed tools list
167    /// 2. Are enabled in the tool registry
168    pub fn get_filtered_tool_definitions(&self) -> Vec<ToolDefinition> {
169        match &self.tool_registry {
170            Some(registry) => {
171                let allowed: Vec<&str> = self.allowed_tools.iter().map(|s| s.as_str()).collect();
172                registry.get_tool_definitions_for(&allowed)
173            }
174            None => Vec::new(),
175        }
176    }
177
178    /// Check if a specific tool is allowed for this agent
179    pub fn can_use_tool(&self, tool_name: &str) -> bool {
180        self.allowed_tools.contains(&tool_name.to_string())
181            && self
182                .tool_registry
183                .as_ref()
184                .map(|r| r.is_enabled(tool_name))
185                .unwrap_or(false)
186    }
187}
188
189#[async_trait]
190impl Agent for ConfigurableAgent {
191    async fn execute(&self, input: &str, context: &AgentContext) -> Result<String> {
192        // Build context with conversation history if available
193        let mut messages = vec![("system".to_string(), self.system_prompt.clone())];
194
195        // Add user memory if available
196        if let Some(memory) = &context.user_memory {
197            let memory_context = format!(
198                "User preferences: {}",
199                memory
200                    .preferences
201                    .iter()
202                    .map(|p| format!("{}: {}", p.key, p.value))
203                    .collect::<Vec<_>>()
204                    .join(", ")
205            );
206            messages.push(("system".to_string(), memory_context));
207        }
208
209        // Add recent conversation history (last 5 messages)
210        for msg in context.conversation_history.iter().rev().take(5).rev() {
211            let role = match msg.role {
212                crate::types::MessageRole::User => "user",
213                crate::types::MessageRole::Assistant => "assistant",
214                _ => "system",
215            };
216            messages.push((role.to_string(), msg.content.clone()));
217        }
218
219        messages.push(("user".to_string(), input.to_string()));
220
221        self.llm.generate_with_history(&messages).await
222    }
223
224    fn system_prompt(&self) -> String {
225        self.system_prompt.clone()
226    }
227
228    fn agent_type(&self) -> AgentType {
229        self.agent_type.clone()
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LLMResponse;
    use crate::utils::toml_config::AgentConfig;
    use std::collections::HashMap;

    /// Minimal `LLMClient` stub shared by all tests: every generation call
    /// answers `"mock"` and every stream is empty.
    struct MockLLM;

    #[async_trait]
    impl LLMClient for MockLLM {
        async fn generate(&self, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_system(&self, _: &str, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_history(&self, _: &[(String, String)]) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_tools(
            &self,
            _: &str,
            _: &[ToolDefinition],
        ) -> Result<LLMResponse> {
            Ok(LLMResponse {
                content: "mock".to_string(),
                tool_calls: vec![],
                finish_reason: "stop".to_string(),
            })
        }
        async fn stream(
            &self,
            _: &str,
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>
        {
            Ok(Box::new(futures::stream::empty()))
        }
        async fn stream_with_system(
            &self,
            _: &str,
            _: &str,
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>
        {
            Ok(Box::new(futures::stream::empty()))
        }
        async fn stream_with_history(
            &self,
            _: &[(String, String)],
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>
        {
            Ok(Box::new(futures::stream::empty()))
        }
        fn model_name(&self) -> &str {
            "mock"
        }
    }

    /// Build an `AgentConfig` with the given tool names and otherwise fixed
    /// test defaults (no custom system prompt, 5 iterations, sequential tools).
    fn config_with_tools(tools: &[&str]) -> AgentConfig {
        AgentConfig {
            model: "default".to_string(),
            system_prompt: None,
            tools: tools.iter().map(|tool| tool.to_string()).collect(),
            max_tool_iterations: 5,
            parallel_tools: false,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn test_name_to_type() {
        assert!(matches!(
            ConfigurableAgent::name_to_type("router"),
            AgentType::Router
        ));
        assert!(matches!(
            ConfigurableAgent::name_to_type("PRODUCT"),
            AgentType::Product
        ));
        // Unknown types now return Custom variant
        assert!(matches!(
            ConfigurableAgent::name_to_type("unknown"),
            AgentType::Custom(_)
        ));
        // Verify the custom name is preserved
        if let AgentType::Custom(name) = ConfigurableAgent::name_to_type("my-custom-agent") {
            assert_eq!(name, "my-custom-agent");
        } else {
            panic!("Expected Custom variant");
        }
    }

    #[test]
    fn test_default_system_prompt() {
        let prompt = ConfigurableAgent::default_system_prompt("router");
        assert!(prompt.contains("routing"));

        let prompt = ConfigurableAgent::default_system_prompt("product");
        assert!(prompt.contains("Product"));

        // Names without a built-in prompt fall back to the generic template.
        assert_eq!(
            ConfigurableAgent::default_system_prompt("custom"),
            "You are a custom agent."
        );
    }

    #[test]
    fn test_allowed_tools() {
        let config = config_with_tools(&["calculator", "web_search"]);

        let agent = ConfigurableAgent::new(
            "orchestrator",
            &config,
            Box::new(MockLLM),
            None, // No registry for this test
        );

        assert_eq!(agent.allowed_tools().len(), 2);
        assert!(agent.allowed_tools().contains(&"calculator".to_string()));
        assert!(agent.allowed_tools().contains(&"web_search".to_string()));
    }

    #[test]
    fn test_has_tools_requires_both_config_and_registry() {
        // Agent with tools config but no registry
        let config = config_with_tools(&["calculator"]);
        let agent = ConfigurableAgent::new("orchestrator", &config, Box::new(MockLLM), None);
        assert!(!agent.has_tools()); // No registry

        // Agent with empty tools
        let config_empty = config_with_tools(&[]);
        let agent_empty = ConfigurableAgent::new(
            "product",
            &config_empty,
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(!agent_empty.has_tools()); // Empty tools list
    }
}