1use crate::agents::Agent;
8use crate::llm::LLMClient;
9use crate::tools::registry::ToolRegistry;
10use crate::types::{AgentContext, AgentType, Result, ToolDefinition};
11use crate::utils::toml_config::AgentConfig;
12use async_trait::async_trait;
13use std::sync::Arc;
14
15pub struct ConfigurableAgent {
17 name: String,
19 agent_type: AgentType,
21 llm: Box<dyn LLMClient>,
23 system_prompt: String,
25 tool_registry: Option<Arc<ToolRegistry>>,
27 allowed_tools: Vec<String>,
29 max_tool_iterations: usize,
31 parallel_tools: bool,
33}
34
35impl ConfigurableAgent {
36 pub fn new(
45 name: &str,
46 config: &AgentConfig,
47 llm: Box<dyn LLMClient>,
48 tool_registry: Option<Arc<ToolRegistry>>,
49 ) -> Self {
50 let agent_type = Self::name_to_type(name);
51 let system_prompt = config
52 .system_prompt
53 .clone()
54 .unwrap_or_else(|| Self::default_system_prompt(name));
55
56 Self {
57 name: name.to_string(),
58 agent_type,
59 llm,
60 system_prompt,
61 tool_registry,
62 allowed_tools: config.tools.clone(),
63 max_tool_iterations: config.max_tool_iterations,
64 parallel_tools: config.parallel_tools,
65 }
66 }
67
68 #[allow(clippy::too_many_arguments)]
70 pub fn with_params(
71 name: &str,
72 agent_type: AgentType,
73 llm: Box<dyn LLMClient>,
74 system_prompt: String,
75 tool_registry: Option<Arc<ToolRegistry>>,
76 allowed_tools: Vec<String>,
77 max_tool_iterations: usize,
78 parallel_tools: bool,
79 ) -> Self {
80 Self {
81 name: name.to_string(),
82 agent_type,
83 llm,
84 system_prompt,
85 tool_registry,
86 allowed_tools,
87 max_tool_iterations,
88 parallel_tools,
89 }
90 }
91
92 fn name_to_type(name: &str) -> AgentType {
94 AgentType::from_string(name)
95 }
96
97 fn default_system_prompt(name: &str) -> String {
99 match name.to_lowercase().as_str() {
100 "router" => r#"You are a routing agent that classifies user queries.
101Available agents: product, invoice, sales, finance, hr, orchestrator.
102Respond with ONLY the agent name (one word, lowercase)."#
103 .to_string(),
104
105 "orchestrator" => r#"You are an orchestrator agent for complex queries.
106Break down requests, delegate to specialists, and synthesize results."#
107 .to_string(),
108
109 "product" => r#"You are a Product Agent for product-related queries.
110Handle catalog, specifications, inventory, and pricing questions."#
111 .to_string(),
112
113 "invoice" => r#"You are an Invoice Agent for billing queries.
114Handle invoices, payments, and billing history."#
115 .to_string(),
116
117 "sales" => r#"You are a Sales Agent for sales analytics.
118Handle performance metrics, revenue, and customer data."#
119 .to_string(),
120
121 "finance" => r#"You are a Finance Agent for financial analysis.
122Handle statements, budgets, and expense management."#
123 .to_string(),
124
125 "hr" => r#"You are an HR Agent for human resources.
126Handle employee info, policies, and benefits."#
127 .to_string(),
128
129 _ => format!("You are a {} agent.", name),
130 }
131 }
132
133 pub fn name(&self) -> &str {
135 &self.name
136 }
137
138 pub fn max_tool_iterations(&self) -> usize {
140 self.max_tool_iterations
141 }
142
143 pub fn parallel_tools(&self) -> bool {
145 self.parallel_tools
146 }
147
148 pub fn has_tools(&self) -> bool {
150 !self.allowed_tools.is_empty() && self.tool_registry.is_some()
151 }
152
153 pub fn tool_registry(&self) -> Option<&Arc<ToolRegistry>> {
155 self.tool_registry.as_ref()
156 }
157
158 pub fn allowed_tools(&self) -> &[String] {
160 &self.allowed_tools
161 }
162
163 pub fn get_filtered_tool_definitions(&self) -> Vec<ToolDefinition> {
169 match &self.tool_registry {
170 Some(registry) => {
171 let allowed: Vec<&str> = self.allowed_tools.iter().map(|s| s.as_str()).collect();
172 registry.get_tool_definitions_for(&allowed)
173 }
174 None => Vec::new(),
175 }
176 }
177
178 pub fn can_use_tool(&self, tool_name: &str) -> bool {
180 self.allowed_tools.contains(&tool_name.to_string())
181 && self
182 .tool_registry
183 .as_ref()
184 .map(|r| r.is_enabled(tool_name))
185 .unwrap_or(false)
186 }
187}
188
189#[async_trait]
190impl Agent for ConfigurableAgent {
191 async fn execute(&self, input: &str, context: &AgentContext) -> Result<String> {
192 let mut messages = vec![("system".to_string(), self.system_prompt.clone())];
194
195 if let Some(memory) = &context.user_memory {
197 let memory_context = format!(
198 "User preferences: {}",
199 memory
200 .preferences
201 .iter()
202 .map(|p| format!("{}: {}", p.key, p.value))
203 .collect::<Vec<_>>()
204 .join(", ")
205 );
206 messages.push(("system".to_string(), memory_context));
207 }
208
209 for msg in context.conversation_history.iter().rev().take(5).rev() {
211 let role = match msg.role {
212 crate::types::MessageRole::User => "user",
213 crate::types::MessageRole::Assistant => "assistant",
214 _ => "system",
215 };
216 messages.push((role.to_string(), msg.content.clone()));
217 }
218
219 messages.push(("user".to_string(), input.to_string()));
220
221 self.llm.generate_with_history(&messages).await
222 }
223
224 fn system_prompt(&self) -> String {
225 self.system_prompt.clone()
226 }
227
228 fn agent_type(&self) -> AgentType {
229 self.agent_type.clone()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LLMResponse;
    use crate::utils::toml_config::AgentConfig;
    use std::collections::HashMap;

    /// LLM stub shared by all tests: every call succeeds with the fixed text
    /// `"mock"` (or an empty stream), so no network access is needed.
    struct MockLLM;

    #[async_trait]
    impl LLMClient for MockLLM {
        async fn generate(&self, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_system(&self, _: &str, _: &str) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_history(&self, _: &[(String, String)]) -> Result<String> {
            Ok("mock".to_string())
        }
        async fn generate_with_tools(
            &self,
            _: &str,
            _: &[ToolDefinition],
        ) -> Result<LLMResponse> {
            Ok(LLMResponse {
                content: "mock".to_string(),
                tool_calls: vec![],
                finish_reason: "stop".to_string(),
            })
        }
        async fn stream(
            &self,
            _: &str,
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>
        {
            Ok(Box::new(futures::stream::empty()))
        }
        async fn stream_with_system(
            &self,
            _: &str,
            _: &str,
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>
        {
            Ok(Box::new(futures::stream::empty()))
        }
        async fn stream_with_history(
            &self,
            _: &[(String, String)],
        ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>
        {
            Ok(Box::new(futures::stream::empty()))
        }
        fn model_name(&self) -> &str {
            "mock"
        }
    }

    /// Builds an `AgentConfig` with the given tool allow-list and otherwise
    /// fixed values (no prompt override, 5 tool iterations, sequential tools).
    fn config_with_tools(tools: &[&str]) -> AgentConfig {
        AgentConfig {
            model: "default".to_string(),
            system_prompt: None,
            tools: tools.iter().map(|tool| (*tool).to_owned()).collect(),
            max_tool_iterations: 5,
            parallel_tools: false,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn test_name_to_type() {
        assert!(matches!(
            ConfigurableAgent::name_to_type("router"),
            AgentType::Router
        ));
        // Built-in names are matched case-insensitively.
        assert!(matches!(
            ConfigurableAgent::name_to_type("PRODUCT"),
            AgentType::Product
        ));
        // Unknown names become custom agents that keep the original name.
        assert!(matches!(
            ConfigurableAgent::name_to_type("unknown"),
            AgentType::Custom(_)
        ));
        if let AgentType::Custom(name) = ConfigurableAgent::name_to_type("my-custom-agent") {
            assert_eq!(name, "my-custom-agent");
        } else {
            panic!("Expected Custom variant");
        }
    }

    #[test]
    fn test_default_system_prompt() {
        let prompt = ConfigurableAgent::default_system_prompt("router");
        assert!(prompt.contains("routing"));

        let prompt = ConfigurableAgent::default_system_prompt("product");
        assert!(prompt.contains("Product"));
    }

    #[test]
    fn test_allowed_tools() {
        let config = config_with_tools(&["calculator", "web_search"]);

        let agent = ConfigurableAgent::new("orchestrator", &config, Box::new(MockLLM), None);

        assert_eq!(agent.allowed_tools().len(), 2);
        assert!(agent.allowed_tools().contains(&"calculator".to_string()));
        assert!(agent.allowed_tools().contains(&"web_search".to_string()));
    }

    #[test]
    fn test_has_tools_requires_both_config_and_registry() {
        // Tools configured, but no registry attached.
        let config = config_with_tools(&["calculator"]);
        let agent = ConfigurableAgent::new("orchestrator", &config, Box::new(MockLLM), None);
        assert!(!agent.has_tools());

        // Registry attached, but the allow-list is empty.
        let config_empty = config_with_tools(&[]);
        let agent_empty = ConfigurableAgent::new(
            "product",
            &config_empty,
            Box::new(MockLLM),
            Some(Arc::new(ToolRegistry::new())),
        );
        assert!(!agent_empty.has_tools());
    }

    #[test]
    fn test_tool_access_without_registry() {
        // An allow-listed tool is still unusable when no registry is attached,
        // and no definitions can be resolved.
        let config = config_with_tools(&["calculator"]);
        let agent = ConfigurableAgent::new("orchestrator", &config, Box::new(MockLLM), None);

        assert!(!agent.can_use_tool("calculator"));
        assert!(!agent.can_use_tool("web_search"));
        assert!(agent.get_filtered_tool_definitions().is_empty());
    }
}