1#![allow(dead_code)]
9#![allow(unused_variables)]
10use crate::chat::{ChatMessage, ChatSession};
11use crate::config::Config;
12use crate::error::{HeliosError, Result};
13use crate::llm::{LLMClient, LLMProviderType};
14use crate::tools::{ToolRegistry, ToolResult};
15use serde_json::Value;
16
/// Namespace prefix prepended to every agent-memory key before it is stored in the chat
/// session's metadata. `Agent::clear_memory` removes exactly the metadata entries whose
/// key starts with this prefix.
const AGENT_MEMORY_PREFIX: &str = "agent:";
19
/// An LLM-backed agent that owns a chat session and a set of callable tools.
pub struct Agent {
    /// Name of the agent, as returned by `Agent::name`.
    name: String,
    /// Client used to request chat completions from the configured provider.
    llm_client: LLMClient,
    /// Tools the agent may execute when the LLM responds with tool calls.
    tool_registry: ToolRegistry,
    /// Conversation state: messages, system prompt and metadata (including agent memory).
    chat_session: ChatSession,
    /// Upper bound on the number of tool-calling rounds performed for a single message.
    max_iterations: usize,
}
33
34impl Agent {
35 async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
46 let provider_type = if let Some(local_config) = config.local {
47 LLMProviderType::Local(local_config)
48 } else {
49 LLMProviderType::Remote(config.llm)
50 };
51
52 let llm_client = LLMClient::new(provider_type).await?;
53
54 Ok(Self {
55 name: name.into(),
56 llm_client,
57 tool_registry: ToolRegistry::new(),
58 chat_session: ChatSession::new(),
59 max_iterations: 10,
60 })
61 }
62
63 pub fn builder(name: impl Into<String>) -> AgentBuilder {
69 AgentBuilder::new(name)
70 }
71
72 pub fn name(&self) -> &str {
74 &self.name
75 }
76
77 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
83 self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
84 }
85
86 pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
92 self.tool_registry.register(tool);
93 }
94
95 pub fn tool_registry(&self) -> &ToolRegistry {
97 &self.tool_registry
98 }
99
100 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
102 &mut self.tool_registry
103 }
104
105 pub fn chat_session(&self) -> &ChatSession {
107 &self.chat_session
108 }
109
110 pub fn chat_session_mut(&mut self) -> &mut ChatSession {
112 &mut self.chat_session
113 }
114
115 pub fn clear_history(&mut self) {
117 self.chat_session.clear();
118 }
119
120 pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
130 let user_message = message.into();
131 self.chat_session.add_user_message(user_message.clone());
132
133 let response = self.execute_with_tools().await?;
135
136 Ok(response)
137 }
138
139 async fn execute_with_tools(&mut self) -> Result<String> {
141 let mut iterations = 0;
142 let tool_definitions = self.tool_registry.get_definitions();
143
144 loop {
145 if iterations >= self.max_iterations {
146 return Err(HeliosError::AgentError(
147 "Maximum iterations reached".to_string(),
148 ));
149 }
150
151 let messages = self.chat_session.get_messages();
152 let tools_option = if tool_definitions.is_empty() {
153 None
154 } else {
155 Some(tool_definitions.clone())
156 };
157
158 let response = self.llm_client.chat(messages, tools_option).await?;
159
160 if let Some(ref tool_calls) = response.tool_calls {
162 self.chat_session.add_message(response.clone());
164
165 for tool_call in tool_calls {
167 let tool_name = &tool_call.function.name;
168 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
169 .unwrap_or(Value::Object(serde_json::Map::new()));
170
171 let tool_result = self
172 .tool_registry
173 .execute(tool_name, tool_args)
174 .await
175 .unwrap_or_else(|e| {
176 ToolResult::error(format!("Tool execution failed: {}", e))
177 });
178
179 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
181 self.chat_session.add_message(tool_message);
182 }
183
184 iterations += 1;
185 continue;
186 }
187
188 self.chat_session.add_message(response.clone());
190 return Ok(response.content);
191 }
192 }
193
194 pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
196 self.send_message(message).await
197 }
198
199 pub fn set_max_iterations(&mut self, max: usize) {
205 self.max_iterations = max;
206 }
207
208 pub fn get_session_summary(&self) -> String {
210 self.chat_session.get_summary()
211 }
212
213 pub fn clear_memory(&mut self) {
215 self.chat_session
217 .metadata
218 .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
219 }
220
221 #[inline]
223 fn prefixed_key(key: &str) -> String {
224 format!("{}{}", AGENT_MEMORY_PREFIX, key)
225 }
226
227 pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
230 let key = key.into();
231 self.chat_session
232 .set_metadata(Self::prefixed_key(&key), value);
233 }
234
235 pub fn get_memory(&self, key: &str) -> Option<&String> {
237 self.chat_session.get_metadata(&Self::prefixed_key(key))
238 }
239
240 pub fn remove_memory(&mut self, key: &str) -> Option<String> {
242 self.chat_session.remove_metadata(&Self::prefixed_key(key))
243 }
244
245 pub fn increment_counter(&mut self, key: &str) -> u32 {
248 let current = self
249 .get_memory(key)
250 .and_then(|v| v.parse::<u32>().ok())
251 .unwrap_or(0);
252 let next = current + 1;
253 self.set_memory(key, next.to_string());
254 next
255 }
256
257 pub fn increment_tasks_completed(&mut self) -> u32 {
259 self.increment_counter("tasks_completed")
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Builds an agent named `test_agent` from the default configuration, panicking if
    /// construction fails.
    async fn default_agent() -> Agent {
        Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await
            .unwrap()
    }

    /// Building with only a configuration succeeds.
    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let built = Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await;
        assert!(built.is_ok());
    }

    /// Memory entries are stored under the prefixed key and can be read back and removed.
    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let mut agent = default_agent().await;

        agent.set_memory("working_directory", "/tmp");
        assert_eq!(
            agent.get_memory("working_directory").map(String::as_str),
            Some("/tmp")
        );

        // The raw metadata holds the value under the prefixed key only.
        let session = agent.chat_session();
        assert_eq!(
            session
                .get_metadata("agent:working_directory")
                .map(String::as_str),
            Some("/tmp")
        );
        assert!(session.get_metadata("working_directory").is_none());

        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    /// Clearing memory removes agent entries but keeps unrelated session metadata.
    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let mut agent = default_agent().await;

        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        agent.clear_memory();

        assert!(agent.get_memory("tasks_completed").is_none());
        assert_eq!(
            agent
                .chat_session()
                .get_metadata("session_start")
                .map(String::as_str),
            Some("now")
        );
    }

    /// Counters start at one and increase by one per call, independently per key.
    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let mut agent = default_agent().await;

        let first = agent.increment_tasks_completed();
        assert_eq!(first, 1);
        assert_eq!(
            agent.get_memory("tasks_completed").map(String::as_str),
            Some("1")
        );

        let second = agent.increment_tasks_completed();
        assert_eq!(second, 2);
        assert_eq!(
            agent.get_memory("tasks_completed").map(String::as_str),
            Some("2")
        );

        let files_first = agent.increment_counter("files_accessed");
        assert_eq!(files_first, 1);
        let files_second = agent.increment_counter("files_accessed");
        assert_eq!(files_second, 2);
        assert_eq!(
            agent.get_memory("files_accessed").map(String::as_str),
            Some("2")
        );
    }

    /// Every builder option is reflected in the built agent.
    #[tokio::test]
    async fn test_agent_builder() {
        let agent = Agent::builder("test_agent")
            .config(Config::new_default())
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);

        let expected_tools = vec!["calculator".to_string()];
        assert_eq!(agent.tool_registry().list_tools(), expected_tools);
    }

    /// Setting the system prompt updates the chat session.
    #[tokio::test]
    async fn test_agent_system_prompt() {
        let mut agent = default_agent().await;
        agent.set_system_prompt("You are a test agent");

        assert_eq!(
            agent.chat_session().system_prompt.as_deref(),
            Some("You are a test agent")
        );
    }

    /// Tools registered after construction show up in the registry.
    #[tokio::test]
    async fn test_agent_tool_registry() {
        let mut agent = default_agent().await;

        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));

        let expected_tools = vec!["calculator".to_string()];
        assert_eq!(agent.tool_registry().list_tools(), expected_tools);
    }

    /// Clearing the history empties the session's message list.
    #[tokio::test]
    async fn test_agent_clear_history() {
        let mut agent = default_agent().await;

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Minimal `Tool` implementation that echoes its `input` argument.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let input_parameter = ToolParameter {
                param_type: "string".to_string(),
                description: "Input parameter".to_string(),
                required: Some(true),
            };

            std::iter::once(("input".to_string(), input_parameter)).collect()
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            // Fall back to "default" when `input` is missing or not a string.
            let input = match args.get("input").and_then(Value::as_str) {
                Some(text) => text,
                None => "default",
            };
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}
483
/// Step-by-step constructor for `Agent`: collect options with the setter methods, then
/// call `build` to create the agent.
pub struct AgentBuilder {
    /// Name given to the built agent.
    name: String,
    /// Configuration for the agent; `build` fails if it was never supplied.
    config: Option<Config>,
    /// System prompt applied to the agent after construction, if one was set.
    system_prompt: Option<String>,
    /// Tools registered on the built agent, in the order they were added.
    tools: Vec<Box<dyn crate::tools::Tool>>,
    /// Iteration limit handed to the built agent (10 unless overridden).
    max_iterations: usize,
}
491
492impl AgentBuilder {
493 pub fn new(name: impl Into<String>) -> Self {
494 Self {
495 name: name.into(),
496 config: None,
497 system_prompt: None,
498 tools: Vec::new(),
499 max_iterations: 10,
500 }
501 }
502
503 pub fn config(mut self, config: Config) -> Self {
504 self.config = Some(config);
505 self
506 }
507
508 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
509 self.system_prompt = Some(prompt.into());
510 self
511 }
512
513 pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
514 self.tools.push(tool);
515 self
516 }
517
518 pub fn max_iterations(mut self, max: usize) -> Self {
519 self.max_iterations = max;
520 self
521 }
522
523 pub async fn build(self) -> Result<Agent> {
524 let config = self
525 .config
526 .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
527
528 let mut agent = Agent::new(self.name, config).await?;
529
530 if let Some(prompt) = self.system_prompt {
531 agent.set_system_prompt(prompt);
532 }
533
534 for tool in self.tools {
535 agent.register_tool(tool);
536 }
537
538 agent.set_max_iterations(self.max_iterations);
539
540 Ok(agent)
541 }
542}