1use crate::agents::Agent;
8use crate::llm::LLMClient;
9use crate::tools::registry::ToolRegistry;
10use crate::types::{AgentContext, AgentType, Result, ToolDefinition};
11use crate::utils::toml_config::AgentConfig;
12use async_trait::async_trait;
13use std::sync::Arc;
14
/// An [`Agent`] whose prompt and tool permissions come from configuration
/// rather than from a dedicated Rust type.
///
/// Instances are built with [`ConfigurableAgent::new`] (from an `AgentConfig`)
/// or [`ConfigurableAgent::with_params`] (from explicit values).
pub struct ConfigurableAgent {
    /// Agent name exactly as passed to the constructor.
    name: String,
    /// Agent type; `new` derives it from `name`, `with_params` takes it as given.
    agent_type: AgentType,
    /// LLM client that `execute` sends the assembled conversation to.
    llm: Box<dyn LLMClient>,
    /// First system message of every request built by `execute`.
    system_prompt: String,
    /// Optional shared tool registry; without it the agent reports no tools.
    tool_registry: Option<Arc<ToolRegistry>>,
    /// Names of the tools this agent is permitted to use.
    allowed_tools: Vec<String>,
    /// Tool-iteration limit copied from the configuration. It is only stored
    /// and exposed here; presumably the caller driving the tool loop enforces
    /// it — TODO confirm.
    max_tool_iterations: usize,
    /// Parallel-tools flag copied from the configuration. It is only stored
    /// and exposed here; presumably consumed by the tool executor — TODO confirm.
    parallel_tools: bool,
}
34
35impl ConfigurableAgent {
36 pub fn new(
45 name: &str,
46 config: &AgentConfig,
47 llm: Box<dyn LLMClient>,
48 tool_registry: Option<Arc<ToolRegistry>>,
49 ) -> Self {
50 let agent_type = Self::name_to_type(name);
51 let system_prompt = config
52 .system_prompt
53 .clone()
54 .unwrap_or_else(|| Self::default_system_prompt(name));
55
56 Self {
57 name: name.to_string(),
58 agent_type,
59 llm,
60 system_prompt,
61 tool_registry,
62 allowed_tools: config.tools.clone(),
63 max_tool_iterations: config.max_tool_iterations,
64 parallel_tools: config.parallel_tools,
65 }
66 }
67
68 #[allow(clippy::too_many_arguments)]
70 pub fn with_params(
71 name: &str,
72 agent_type: AgentType,
73 llm: Box<dyn LLMClient>,
74 system_prompt: String,
75 tool_registry: Option<Arc<ToolRegistry>>,
76 allowed_tools: Vec<String>,
77 max_tool_iterations: usize,
78 parallel_tools: bool,
79 ) -> Self {
80 Self {
81 name: name.to_string(),
82 agent_type,
83 llm,
84 system_prompt,
85 tool_registry,
86 allowed_tools,
87 max_tool_iterations,
88 parallel_tools,
89 }
90 }
91
92 fn name_to_type(name: &str) -> AgentType {
94 AgentType::from_string(name)
95 }
96
97 fn default_system_prompt(name: &str) -> String {
99 match name.to_lowercase().as_str() {
100 "router" => r#"You are a routing agent that classifies user queries.
101Available agents: product, invoice, sales, finance, hr, orchestrator.
102Respond with ONLY the agent name (one word, lowercase)."#
103 .to_string(),
104
105 "orchestrator" => r#"You are an orchestrator agent for complex queries.
106Break down requests, delegate to specialists, and synthesize results."#
107 .to_string(),
108
109 "product" => r#"You are a Product Agent for product-related queries.
110Handle catalog, specifications, inventory, and pricing questions."#
111 .to_string(),
112
113 "invoice" => r#"You are an Invoice Agent for billing queries.
114Handle invoices, payments, and billing history."#
115 .to_string(),
116
117 "sales" => r#"You are a Sales Agent for sales analytics.
118Handle performance metrics, revenue, and customer data."#
119 .to_string(),
120
121 "finance" => r#"You are a Finance Agent for financial analysis.
122Handle statements, budgets, and expense management."#
123 .to_string(),
124
125 "hr" => r#"You are an HR Agent for human resources.
126Handle employee info, policies, and benefits."#
127 .to_string(),
128
129 _ => format!("You are a {} agent.", name),
130 }
131 }
132
133 pub fn name(&self) -> &str {
135 &self.name
136 }
137
138 pub fn max_tool_iterations(&self) -> usize {
140 self.max_tool_iterations
141 }
142
143 pub fn parallel_tools(&self) -> bool {
145 self.parallel_tools
146 }
147
148 pub fn has_tools(&self) -> bool {
150 !self.allowed_tools.is_empty() && self.tool_registry.is_some()
151 }
152
153 pub fn tool_registry(&self) -> Option<&Arc<ToolRegistry>> {
155 self.tool_registry.as_ref()
156 }
157
158 pub fn allowed_tools(&self) -> &[String] {
160 &self.allowed_tools
161 }
162
163 pub fn get_filtered_tool_definitions(&self) -> Vec<ToolDefinition> {
169 match &self.tool_registry {
170 Some(registry) => {
171 let allowed: Vec<&str> = self.allowed_tools.iter().map(|s| s.as_str()).collect();
172 registry.get_tool_definitions_for(&allowed)
173 }
174 None => Vec::new(),
175 }
176 }
177
178 pub fn can_use_tool(&self, tool_name: &str) -> bool {
180 self.allowed_tools.contains(&tool_name.to_string())
181 && self
182 .tool_registry
183 .as_ref()
184 .map(|r| r.is_enabled(tool_name))
185 .unwrap_or(false)
186 }
187}
188
189#[async_trait]
190impl Agent for ConfigurableAgent {
191 async fn execute(&self, input: &str, context: &AgentContext) -> Result<String> {
192 let mut messages = vec![("system".to_string(), self.system_prompt.clone())];
194
195 if let Some(memory) = &context.user_memory {
197 let memory_context = format!(
198 "User preferences: {}",
199 memory
200 .preferences
201 .iter()
202 .map(|p| format!("{}: {}", p.key, p.value))
203 .collect::<Vec<_>>()
204 .join(", ")
205 );
206 messages.push(("system".to_string(), memory_context));
207 }
208
209 for msg in context.conversation_history.iter().rev().take(5).rev() {
211 let role = match msg.role {
212 crate::types::MessageRole::User => "user",
213 crate::types::MessageRole::Assistant => "assistant",
214 _ => "system",
215 };
216 messages.push((role.to_string(), msg.content.clone()));
217 }
218
219 messages.push(("user".to_string(), input.to_string()));
220
221 self.llm.generate_with_history(&messages).await
222 }
223
224 fn system_prompt(&self) -> String {
225 self.system_prompt.clone()
226 }
227
228 fn agent_type(&self) -> AgentType {
229 self.agent_type.clone()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LLMResponse;
    use std::collections::HashMap;

    /// Boxed string stream type returned by the streaming `LLMClient` methods.
    type MockStream = Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>;

    /// Minimal [`LLMClient`] stub shared by every test in this module.
    ///
    /// Text generation returns `"mock"`, tool-aware generation returns a
    /// response with no tool calls, and streams are empty.
    struct MockLLM;

    /// Canned tool-aware response used by [`MockLLM`].
    fn mock_response() -> LLMResponse {
        LLMResponse {
            content: "mock".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: None,
        }
    }

    #[async_trait]
    impl LLMClient for MockLLM {
        async fn generate(&self, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }

        async fn generate_with_system(&self, _: &str, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }

        async fn generate_with_history(&self, _: &[(String, String)]) -> Result<String> {
            Ok("mock".to_string())
        }

        async fn generate_with_tools(
            &self,
            _: &str,
            _: &[ToolDefinition],
        ) -> Result<LLMResponse> {
            Ok(mock_response())
        }

        async fn stream(&self, _: &str) -> Result<MockStream> {
            Ok(Box::new(futures::stream::empty()))
        }

        async fn stream_with_system(&self, _: &str, _: &str) -> Result<MockStream> {
            Ok(Box::new(futures::stream::empty()))
        }

        async fn stream_with_history(&self, _: &[(String, String)]) -> Result<MockStream> {
            Ok(Box::new(futures::stream::empty()))
        }

        fn model_name(&self) -> &str {
            "mock"
        }

        async fn generate_with_tools_and_history(
            &self,
            _: &[crate::llm::coordinator::ConversationMessage],
            _: &[ToolDefinition],
        ) -> Result<LLMResponse> {
            Ok(mock_response())
        }
    }

    /// Builds an `AgentConfig` with no custom prompt and the given tool names.
    fn config_with_tools(tools: &[&str]) -> AgentConfig {
        AgentConfig {
            model: "default".to_string(),
            system_prompt: None,
            tools: tools.iter().map(|tool| tool.to_string()).collect(),
            max_tool_iterations: 5,
            parallel_tools: false,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn test_name_to_type() {
        assert!(matches!(
            ConfigurableAgent::name_to_type("router"),
            AgentType::Router
        ));
        // Known names are matched case-insensitively.
        assert!(matches!(
            ConfigurableAgent::name_to_type("PRODUCT"),
            AgentType::Product
        ));
        // Unknown names become custom agent types...
        assert!(matches!(
            ConfigurableAgent::name_to_type("unknown"),
            AgentType::Custom(_)
        ));
        // ...that carry the original name.
        match ConfigurableAgent::name_to_type("my-custom-agent") {
            AgentType::Custom(name) => assert_eq!(name, "my-custom-agent"),
            _ => panic!("Expected Custom variant"),
        }
    }

    #[test]
    fn test_default_system_prompt() {
        let prompt = ConfigurableAgent::default_system_prompt("router");
        assert!(prompt.contains("routing"));

        let prompt = ConfigurableAgent::default_system_prompt("product");
        assert!(prompt.contains("Product"));

        // Preset lookup ignores the casing of the name.
        let prompt = ConfigurableAgent::default_system_prompt("ROUTER");
        assert!(prompt.contains("routing"));

        // Names without a preset get the generic prompt.
        let prompt = ConfigurableAgent::default_system_prompt("legal");
        assert_eq!(prompt, "You are a legal agent.");
    }

    #[test]
    fn test_allowed_tools() {
        let config = config_with_tools(&["calculator", "web_search"]);

        let agent = ConfigurableAgent::new("orchestrator", &config, Box::new(MockLLM), None);

        assert_eq!(agent.allowed_tools().len(), 2);
        assert!(agent.allowed_tools().contains(&"calculator".to_string()));
        assert!(agent.allowed_tools().contains(&"web_search".to_string()));
    }

    #[test]
    fn test_has_tools_requires_both_config_and_registry() {
        // Tools configured, but no registry attached.
        let config = config_with_tools(&["calculator"]);
        let agent = ConfigurableAgent::new("orchestrator", &config, Box::new(MockLLM), None);
        assert!(!agent.has_tools());

        // Registry attached, but no tools configured.
        let config_empty = config_with_tools(&[]);
        let agent_empty = ConfigurableAgent::new(
            "product",
            &config_empty,
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(!agent_empty.has_tools());

        // Both present: the agent reports tools.
        let agent_full = ConfigurableAgent::new(
            "orchestrator",
            &config,
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(agent_full.has_tools());
    }
}