1use crate::agents::Agent;
8use crate::llm::LLMClient;
9use crate::tools::registry::ToolRegistry;
10use crate::types::{AgentContext, AgentType, Result, ToolDefinition};
11use crate::utils::toml_config::AgentConfig;
12use async_trait::async_trait;
13use std::sync::Arc;
14
/// An agent whose prompt, tool access and tool-loop settings come from configuration
/// (`AgentConfig`) or from explicit parameters, rather than from a hard-coded type.
pub struct ConfigurableAgent {
    /// Name the agent was created with, stored exactly as passed (not lowercased).
    name: String,
    /// Agent category: derived from `name` in `new`, or supplied directly to `with_params`.
    agent_type: AgentType,
    /// Language-model client used by `execute` to generate responses.
    llm: Box<dyn LLMClient>,
    /// System prompt sent as the first message of every `execute` call.
    system_prompt: String,
    /// Shared tool registry; `None` means this agent cannot use any tools.
    tool_registry: Option<Arc<ToolRegistry>>,
    /// Names of the tools this agent is permitted to use (checked against the registry).
    allowed_tools: Vec<String>,
    /// Upper bound on tool-call iterations. Only exposed via an accessor here;
    /// NOTE(review): presumably enforced by the caller's tool loop — confirm.
    max_tool_iterations: usize,
    /// Exposed via an accessor only; NOTE(review): presumably tells the caller whether
    /// tool calls may run concurrently — confirm against the tool-loop implementation.
    parallel_tools: bool,
}
34
35impl ConfigurableAgent {
36 pub fn new(
45 name: &str,
46 config: &AgentConfig,
47 llm: Box<dyn LLMClient>,
48 tool_registry: Option<Arc<ToolRegistry>>,
49 ) -> Self {
50 let agent_type = Self::name_to_type(name);
51 let system_prompt = config
52 .system_prompt
53 .clone()
54 .unwrap_or_else(|| Self::default_system_prompt(name));
55
56 Self {
57 name: name.to_string(),
58 agent_type,
59 llm,
60 system_prompt,
61 tool_registry,
62 allowed_tools: config.tools.clone(),
63 max_tool_iterations: config.max_tool_iterations,
64 parallel_tools: config.parallel_tools,
65 }
66 }
67
68 #[allow(clippy::too_many_arguments)]
70 pub fn with_params(
71 name: &str,
72 agent_type: AgentType,
73 llm: Box<dyn LLMClient>,
74 system_prompt: String,
75 tool_registry: Option<Arc<ToolRegistry>>,
76 allowed_tools: Vec<String>,
77 max_tool_iterations: usize,
78 parallel_tools: bool,
79 ) -> Self {
80 Self {
81 name: name.to_string(),
82 agent_type,
83 llm,
84 system_prompt,
85 tool_registry,
86 allowed_tools,
87 max_tool_iterations,
88 parallel_tools,
89 }
90 }
91
92 fn name_to_type(name: &str) -> AgentType {
94 match name.to_lowercase().as_str() {
95 "router" => AgentType::Router,
96 "orchestrator" => AgentType::Orchestrator,
97 "product" => AgentType::Product,
98 "invoice" => AgentType::Invoice,
99 "sales" => AgentType::Sales,
100 "finance" => AgentType::Finance,
101 "hr" => AgentType::HR,
102 _ => AgentType::Orchestrator, }
104 }
105
106 fn default_system_prompt(name: &str) -> String {
108 match name.to_lowercase().as_str() {
109 "router" => r#"You are a routing agent that classifies user queries.
110Available agents: product, invoice, sales, finance, hr, orchestrator.
111Respond with ONLY the agent name (one word, lowercase)."#
112 .to_string(),
113
114 "orchestrator" => r#"You are an orchestrator agent for complex queries.
115Break down requests, delegate to specialists, and synthesize results."#
116 .to_string(),
117
118 "product" => r#"You are a Product Agent for product-related queries.
119Handle catalog, specifications, inventory, and pricing questions."#
120 .to_string(),
121
122 "invoice" => r#"You are an Invoice Agent for billing queries.
123Handle invoices, payments, and billing history."#
124 .to_string(),
125
126 "sales" => r#"You are a Sales Agent for sales analytics.
127Handle performance metrics, revenue, and customer data."#
128 .to_string(),
129
130 "finance" => r#"You are a Finance Agent for financial analysis.
131Handle statements, budgets, and expense management."#
132 .to_string(),
133
134 "hr" => r#"You are an HR Agent for human resources.
135Handle employee info, policies, and benefits."#
136 .to_string(),
137
138 _ => format!("You are a {} agent.", name),
139 }
140 }
141
142 pub fn name(&self) -> &str {
144 &self.name
145 }
146
147 pub fn max_tool_iterations(&self) -> usize {
149 self.max_tool_iterations
150 }
151
152 pub fn parallel_tools(&self) -> bool {
154 self.parallel_tools
155 }
156
157 pub fn has_tools(&self) -> bool {
159 !self.allowed_tools.is_empty() && self.tool_registry.is_some()
160 }
161
162 pub fn tool_registry(&self) -> Option<&Arc<ToolRegistry>> {
164 self.tool_registry.as_ref()
165 }
166
167 pub fn allowed_tools(&self) -> &[String] {
169 &self.allowed_tools
170 }
171
172 pub fn get_filtered_tool_definitions(&self) -> Vec<ToolDefinition> {
178 match &self.tool_registry {
179 Some(registry) => {
180 let allowed: Vec<&str> = self.allowed_tools.iter().map(|s| s.as_str()).collect();
181 registry.get_tool_definitions_for(&allowed)
182 }
183 None => Vec::new(),
184 }
185 }
186
187 pub fn can_use_tool(&self, tool_name: &str) -> bool {
189 self.allowed_tools.contains(&tool_name.to_string())
190 && self
191 .tool_registry
192 .as_ref()
193 .map(|r| r.is_enabled(tool_name))
194 .unwrap_or(false)
195 }
196}
197
198#[async_trait]
199impl Agent for ConfigurableAgent {
200 async fn execute(&self, input: &str, context: &AgentContext) -> Result<String> {
201 let mut messages = vec![("system".to_string(), self.system_prompt.clone())];
203
204 if let Some(memory) = &context.user_memory {
206 let memory_context = format!(
207 "User preferences: {}",
208 memory
209 .preferences
210 .iter()
211 .map(|p| format!("{}: {}", p.key, p.value))
212 .collect::<Vec<_>>()
213 .join(", ")
214 );
215 messages.push(("system".to_string(), memory_context));
216 }
217
218 for msg in context.conversation_history.iter().rev().take(5).rev() {
220 let role = match msg.role {
221 crate::types::MessageRole::User => "user",
222 crate::types::MessageRole::Assistant => "assistant",
223 _ => "system",
224 };
225 messages.push((role.to_string(), msg.content.clone()));
226 }
227
228 messages.push(("user".to_string(), input.to_string()));
229
230 self.llm.generate_with_history(&messages).await
231 }
232
233 fn system_prompt(&self) -> String {
234 self.system_prompt.clone()
235 }
236
237 fn agent_type(&self) -> AgentType {
238 self.agent_type
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LLMResponse;
    use std::collections::HashMap;

    /// Minimal `LLMClient` that answers every request with `"mock"`.
    struct MockLLM;

    #[async_trait]
    impl LLMClient for MockLLM {
        async fn generate(&self, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_system(&self, _: &str, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_history(&self, _: &[(String, String)]) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_tools(
            &self,
            _: &str,
            _: &[ToolDefinition],
        ) -> Result<LLMResponse> {
            Ok(LLMResponse {
                content: "mock".to_string(),
                tool_calls: vec![],
                finish_reason: "stop".to_string(),
            })
        }
        async fn stream(
            &self,
            _: &str,
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
            Ok(Box::new(futures::stream::empty()))
        }
        fn model_name(&self) -> &str {
            "mock"
        }
    }

    /// Builds an `AgentConfig` with default settings and the given tool allow-list.
    fn config_with_tools(tools: &[&str]) -> AgentConfig {
        AgentConfig {
            model: "default".to_string(),
            system_prompt: None,
            tools: tools.iter().map(|&t| t.to_string()).collect(),
            max_tool_iterations: 5,
            parallel_tools: false,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn test_name_to_type() {
        assert!(matches!(
            ConfigurableAgent::name_to_type("router"),
            AgentType::Router
        ));
        assert!(matches!(
            ConfigurableAgent::name_to_type("PRODUCT"),
            AgentType::Product
        ));
        assert!(matches!(
            ConfigurableAgent::name_to_type("unknown"),
            AgentType::Orchestrator
        ));
    }

    #[test]
    fn test_default_system_prompt() {
        let prompt = ConfigurableAgent::default_system_prompt("router");
        assert!(prompt.contains("routing"));

        let prompt = ConfigurableAgent::default_system_prompt("product");
        assert!(prompt.contains("Product"));
    }

    #[test]
    fn test_allowed_tools() {
        let config = config_with_tools(&["calculator", "web_search"]);
        let agent = ConfigurableAgent::new("orchestrator", &config, Box::new(MockLLM), None);

        assert_eq!(agent.allowed_tools().len(), 2);
        assert!(agent.allowed_tools().contains(&"calculator".to_string()));
        assert!(agent.allowed_tools().contains(&"web_search".to_string()));
    }

    #[test]
    fn test_has_tools_requires_both_config_and_registry() {
        let with_tools = config_with_tools(&["calculator"]);

        // Tools configured, but no registry.
        let agent = ConfigurableAgent::new("orchestrator", &with_tools, Box::new(MockLLM), None);
        assert!(!agent.has_tools());

        // Registry attached, but no tools configured.
        let agent_empty = ConfigurableAgent::new(
            "product",
            &config_with_tools(&[]),
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(!agent_empty.has_tools());

        // Both present.
        let agent_full = ConfigurableAgent::new(
            "product",
            &with_tools,
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(agent_full.has_tools());
    }

    #[test]
    fn test_can_use_tool_without_registry_is_false() {
        let config = config_with_tools(&["calculator"]);
        let agent = ConfigurableAgent::new("product", &config, Box::new(MockLLM), None);

        assert!(!agent.can_use_tool("calculator"));
        assert!(!agent.can_use_tool("web_search"));
    }

    #[test]
    fn test_config_system_prompt_overrides_default() {
        let mut config = config_with_tools(&[]);
        config.system_prompt = Some("custom prompt".to_string());
        let agent = ConfigurableAgent::new("Router", &config, Box::new(MockLLM), None);

        assert_eq!(agent.system_prompt(), "custom prompt");
        assert_eq!(agent.name(), "Router");
        assert!(matches!(agent.agent_type(), AgentType::Router));
    }

    #[test]
    fn test_with_params_passes_values_through() {
        let agent = ConfigurableAgent::with_params(
            "custom",
            AgentType::Sales,
            Box::new(MockLLM),
            "prompt".to_string(),
            None,
            vec!["calculator".to_string()],
            3,
            true,
        );

        assert_eq!(agent.name(), "custom");
        assert_eq!(agent.max_tool_iterations(), 3);
        assert!(agent.parallel_tools());
        assert_eq!(agent.allowed_tools(), &["calculator".to_string()]);
        assert!(matches!(agent.agent_type(), AgentType::Sales));
    }
}