1#![allow(dead_code)]
2#![allow(unused_variables)]
3use crate::chat::{ChatMessage, ChatSession};
4use crate::config::Config;
5use crate::error::{HeliosError, Result};
6use crate::llm::{LLMClient, LLMProviderType};
7use crate::tools::{ToolRegistry, ToolResult};
8use serde_json::Value;
9
/// Namespace prefix prepended to every agent-memory key before it is stored
/// in the chat session's metadata map.
///
/// Keys carrying this prefix are the ones treated as agent memory; metadata
/// stored without it is left untouched by the agent-memory helpers.
const AGENT_MEMORY_PREFIX: &str = "agent:";
11
12pub struct Agent {
13 name: String,
14 llm_client: LLMClient,
15 tool_registry: ToolRegistry,
16 chat_session: ChatSession,
17 max_iterations: usize,
18}
19
20impl Agent {
21 async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
22 let provider_type = if let Some(local_config) = config.local {
23 LLMProviderType::Local(local_config)
24 } else {
25 LLMProviderType::Remote(config.llm)
26 };
27
28 let llm_client = LLMClient::new(provider_type).await?;
29
30 Ok(Self {
31 name: name.into(),
32 llm_client,
33 tool_registry: ToolRegistry::new(),
34 chat_session: ChatSession::new(),
35 max_iterations: 10,
36 })
37 }
38
39 pub fn builder(name: impl Into<String>) -> AgentBuilder {
40 AgentBuilder::new(name)
41 }
42
43 pub fn name(&self) -> &str {
44 &self.name
45 }
46
47 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
48 self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
49 }
50
51 pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
52 self.tool_registry.register(tool);
53 }
54
55 pub fn tool_registry(&self) -> &ToolRegistry {
56 &self.tool_registry
57 }
58
59 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
60 &mut self.tool_registry
61 }
62
63 pub fn chat_session(&self) -> &ChatSession {
64 &self.chat_session
65 }
66
67 pub fn chat_session_mut(&mut self) -> &mut ChatSession {
68 &mut self.chat_session
69 }
70
71 pub fn clear_history(&mut self) {
72 self.chat_session.clear();
73 }
74
75 pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
76 let user_message = message.into();
77 self.chat_session.add_user_message(user_message.clone());
78
79 let response = self.execute_with_tools().await?;
81
82 Ok(response)
83 }
84
85 async fn execute_with_tools(&mut self) -> Result<String> {
86 let mut iterations = 0;
87 let tool_definitions = self.tool_registry.get_definitions();
88
89 loop {
90 if iterations >= self.max_iterations {
91 return Err(HeliosError::AgentError(
92 "Maximum iterations reached".to_string(),
93 ));
94 }
95
96 let messages = self.chat_session.get_messages();
97 let tools_option = if tool_definitions.is_empty() {
98 None
99 } else {
100 Some(tool_definitions.clone())
101 };
102
103 let response = self.llm_client.chat(messages, tools_option).await?;
104
105 if let Some(ref tool_calls) = response.tool_calls {
107 self.chat_session.add_message(response.clone());
109
110 for tool_call in tool_calls {
112 let tool_name = &tool_call.function.name;
113 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
114 .unwrap_or(Value::Object(serde_json::Map::new()));
115
116 let tool_result = self
117 .tool_registry
118 .execute(tool_name, tool_args)
119 .await
120 .unwrap_or_else(|e| {
121 ToolResult::error(format!("Tool execution failed: {}", e))
122 });
123
124 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
126 self.chat_session.add_message(tool_message);
127 }
128
129 iterations += 1;
130 continue;
131 }
132
133 self.chat_session.add_message(response.clone());
135 return Ok(response.content);
136 }
137 }
138
139 pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
140 self.send_message(message).await
141 }
142
143 pub fn set_max_iterations(&mut self, max: usize) {
144 self.max_iterations = max;
145 }
146
147 pub fn get_session_summary(&self) -> String {
148 self.chat_session.get_summary()
149 }
150
151 pub fn clear_memory(&mut self) {
152 self.chat_session
154 .metadata
155 .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
156 }
157
158 #[inline]
159 fn prefixed_key(key: &str) -> String {
160 format!("{}{}", AGENT_MEMORY_PREFIX, key)
161 }
162
163 pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
165 let key = key.into();
166 self.chat_session
167 .set_metadata(Self::prefixed_key(&key), value);
168 }
169
170 pub fn get_memory(&self, key: &str) -> Option<&String> {
171 self.chat_session.get_metadata(&Self::prefixed_key(key))
172 }
173
174 pub fn remove_memory(&mut self, key: &str) -> Option<String> {
175 self.chat_session.remove_metadata(&Self::prefixed_key(key))
176 }
177
178 pub fn increment_counter(&mut self, key: &str) -> u32 {
180 let current = self
181 .get_memory(key)
182 .and_then(|v| v.parse::<u32>().ok())
183 .unwrap_or(0);
184 let next = current + 1;
185 self.set_memory(key, next.to_string());
186 next
187 }
188
189 pub fn increment_tasks_completed(&mut self) -> u32 {
190 self.increment_counter("tasks_completed")
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Builds an agent named `test_agent` from the default configuration.
    ///
    /// Shared by every test that only needs a plain agent, so the builder
    /// boilerplate lives in one place.
    async fn build_test_agent() -> Agent {
        Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await;
        assert!(agent.is_ok());
    }

    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let mut agent = build_test_agent().await;

        agent.set_memory("working_directory", "/tmp");
        assert_eq!(
            agent.get_memory("working_directory"),
            Some(&"/tmp".to_string())
        );

        // The value is stored under the prefixed key only.
        assert_eq!(
            agent
                .chat_session()
                .get_metadata("agent:working_directory"),
            Some(&"/tmp".to_string())
        );
        assert!(agent.chat_session().get_metadata("working_directory").is_none());

        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let mut agent = build_test_agent().await;

        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        agent.clear_memory();

        // Agent memory is gone, unprefixed session metadata survives.
        assert!(agent.get_memory("tasks_completed").is_none());
        assert_eq!(
            agent.chat_session().get_metadata("session_start"),
            Some(&"now".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let mut agent = build_test_agent().await;

        let n1 = agent.increment_tasks_completed();
        assert_eq!(n1, 1);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"1".to_string()));

        let n2 = agent.increment_tasks_completed();
        assert_eq!(n2, 2);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"2".to_string()));

        let f1 = agent.increment_counter("files_accessed");
        assert_eq!(f1, 1);
        let f2 = agent.increment_counter("files_accessed");
        assert_eq!(f2, 2);
        assert_eq!(agent.get_memory("files_accessed"), Some(&"2".to_string()));
    }

    #[tokio::test]
    async fn test_agent_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_system_prompt() {
        let mut agent = build_test_agent().await;
        agent.set_system_prompt("You are a test agent");

        let session = agent.chat_session();
        assert_eq!(
            session.system_prompt,
            Some("You are a test agent".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_tool_registry() {
        let mut agent = build_test_agent().await;

        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_clear_history() {
        let mut agent = build_test_agent().await;

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Minimal `Tool` implementation that echoes its `input` argument.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let mut params = HashMap::new();
            params.insert(
                "input".to_string(),
                ToolParameter {
                    param_type: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: Some(true),
                },
            );
            params
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            // Falls back to "default" when `input` is absent or not a string.
            let input = args
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }

    #[test]
    fn test_mock_tool_metadata() {
        let tool = MockTool;
        assert_eq!(tool.name(), "mock_tool");
        assert_eq!(tool.description(), "A mock tool for testing");

        let params = tool.parameters();
        assert_eq!(params.len(), 1);
        let input = params.get("input").expect("`input` parameter is declared");
        assert_eq!(input.param_type, "string");
        assert_eq!(input.description, "Input parameter");
        assert_eq!(input.required, Some(true));
    }

    #[tokio::test]
    async fn test_mock_tool_execute() {
        let with_input = MockTool
            .execute(serde_json::json!({ "input": "hello" }))
            .await;
        assert!(with_input.is_ok());

        // Missing arguments must not fail: the tool uses its default input.
        let without_input = MockTool.execute(Value::Null).await;
        assert!(without_input.is_ok());
    }

    #[tokio::test]
    async fn test_agent_registers_mock_tool() {
        let mut agent = build_test_agent().await;

        agent.register_tool(Box::new(MockTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["mock_tool".to_string()]
        );
    }
}
408
409pub struct AgentBuilder {
410 name: String,
411 config: Option<Config>,
412 system_prompt: Option<String>,
413 tools: Vec<Box<dyn crate::tools::Tool>>,
414 max_iterations: usize,
415}
416
417impl AgentBuilder {
418 pub fn new(name: impl Into<String>) -> Self {
419 Self {
420 name: name.into(),
421 config: None,
422 system_prompt: None,
423 tools: Vec::new(),
424 max_iterations: 10,
425 }
426 }
427
428 pub fn config(mut self, config: Config) -> Self {
429 self.config = Some(config);
430 self
431 }
432
433 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
434 self.system_prompt = Some(prompt.into());
435 self
436 }
437
438 pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
439 self.tools.push(tool);
440 self
441 }
442
443 pub fn max_iterations(mut self, max: usize) -> Self {
444 self.max_iterations = max;
445 self
446 }
447
448 pub async fn build(self) -> Result<Agent> {
449 let config = self
450 .config
451 .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
452
453 let mut agent = Agent::new(self.name, config).await?;
454
455 if let Some(prompt) = self.system_prompt {
456 agent.set_system_prompt(prompt);
457 }
458
459 for tool in self.tools {
460 agent.register_tool(tool);
461 }
462
463 agent.set_max_iterations(self.max_iterations);
464
465 Ok(agent)
466 }
467}