1#![allow(dead_code)]
9#![allow(unused_variables)]
10use crate::chat::{ChatMessage, ChatSession};
11use crate::config::Config;
12use crate::error::{HeliosError, Result};
13use crate::llm::{LLMClient, LLMProviderType};
14use crate::tools::{ToolRegistry, ToolResult};
15use serde_json::Value;
16
/// Prefix that namespaces agent memory entries inside the chat session's metadata.
///
/// Every key written through `Agent::set_memory` is stored as `"agent:<key>"`, which lets
/// `Agent::clear_memory` drop only the agent's own entries and leave other metadata intact.
const AGENT_MEMORY_PREFIX: &str = "agent:";
19
20pub struct Agent {
22 name: String,
24 llm_client: LLMClient,
26 tool_registry: ToolRegistry,
28 chat_session: ChatSession,
30 max_iterations: usize,
32}
33
34impl Agent {
35 async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
46 #[cfg(feature = "local")]
47 let provider_type = if let Some(local_config) = config.local {
48 LLMProviderType::Local(local_config)
49 } else {
50 LLMProviderType::Remote(config.llm)
51 };
52
53 #[cfg(not(feature = "local"))]
54 let provider_type = LLMProviderType::Remote(config.llm);
55
56 let llm_client = LLMClient::new(provider_type).await?;
57
58 Ok(Self {
59 name: name.into(),
60 llm_client,
61 tool_registry: ToolRegistry::new(),
62 chat_session: ChatSession::new(),
63 max_iterations: 10,
64 })
65 }
66
67 pub fn builder(name: impl Into<String>) -> AgentBuilder {
73 AgentBuilder::new(name)
74 }
75
76 pub fn name(&self) -> &str {
78 &self.name
79 }
80
81 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
87 self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
88 }
89
90 pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
96 self.tool_registry.register(tool);
97 }
98
99 pub fn tool_registry(&self) -> &ToolRegistry {
101 &self.tool_registry
102 }
103
104 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
106 &mut self.tool_registry
107 }
108
109 pub fn chat_session(&self) -> &ChatSession {
111 &self.chat_session
112 }
113
114 pub fn chat_session_mut(&mut self) -> &mut ChatSession {
116 &mut self.chat_session
117 }
118
119 pub fn clear_history(&mut self) {
121 self.chat_session.clear();
122 }
123
124 pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
134 let user_message = message.into();
135 self.chat_session.add_user_message(user_message.clone());
136
137 let response = self.execute_with_tools().await?;
139
140 Ok(response)
141 }
142
143 async fn execute_with_tools(&mut self) -> Result<String> {
145 self.execute_with_tools_with_params(None, None, None).await
146 }
147
148 async fn execute_with_tools_with_params(
150 &mut self,
151 temperature: Option<f32>,
152 max_tokens: Option<u32>,
153 stop: Option<Vec<String>>,
154 ) -> Result<String> {
155 let mut iterations = 0;
156 let tool_definitions = self.tool_registry.get_definitions();
157
158 loop {
159 if iterations >= self.max_iterations {
160 return Err(HeliosError::AgentError(
161 "Maximum iterations reached".to_string(),
162 ));
163 }
164
165 let messages = self.chat_session.get_messages();
166 let tools_option = if tool_definitions.is_empty() {
167 None
168 } else {
169 Some(tool_definitions.clone())
170 };
171
172 let response = self
173 .llm_client
174 .chat(
175 messages,
176 tools_option,
177 temperature,
178 max_tokens,
179 stop.clone(),
180 )
181 .await?;
182
183 if let Some(ref tool_calls) = response.tool_calls {
185 self.chat_session.add_message(response.clone());
187
188 for tool_call in tool_calls {
190 let tool_name = &tool_call.function.name;
191 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
192 .unwrap_or(Value::Object(serde_json::Map::new()));
193
194 let tool_result = self
195 .tool_registry
196 .execute(tool_name, tool_args)
197 .await
198 .unwrap_or_else(|e| {
199 ToolResult::error(format!("Tool execution failed: {}", e))
200 });
201
202 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
204 self.chat_session.add_message(tool_message);
205 }
206
207 iterations += 1;
208 continue;
209 }
210
211 self.chat_session.add_message(response.clone());
213 return Ok(response.content);
214 }
215 }
216
217 pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
219 self.send_message(message).await
220 }
221
222 pub fn set_max_iterations(&mut self, max: usize) {
228 self.max_iterations = max;
229 }
230
231 pub fn get_session_summary(&self) -> String {
233 self.chat_session.get_summary()
234 }
235
236 pub fn clear_memory(&mut self) {
238 self.chat_session
240 .metadata
241 .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
242 }
243
244 #[inline]
246 fn prefixed_key(key: &str) -> String {
247 format!("{}{}", AGENT_MEMORY_PREFIX, key)
248 }
249
250 pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
253 let key = key.into();
254 self.chat_session
255 .set_metadata(Self::prefixed_key(&key), value);
256 }
257
258 pub fn get_memory(&self, key: &str) -> Option<&String> {
260 self.chat_session.get_metadata(&Self::prefixed_key(key))
261 }
262
263 pub fn remove_memory(&mut self, key: &str) -> Option<String> {
265 self.chat_session.remove_metadata(&Self::prefixed_key(key))
266 }
267
268 pub fn increment_counter(&mut self, key: &str) -> u32 {
271 let current = self
272 .get_memory(key)
273 .and_then(|v| v.parse::<u32>().ok())
274 .unwrap_or(0);
275 let next = current + 1;
276 self.set_memory(key, next.to_string());
277 next
278 }
279
280 pub fn increment_tasks_completed(&mut self) -> u32 {
282 self.increment_counter("tasks_completed")
283 }
284
285 pub async fn chat_with_history(
303 &mut self,
304 messages: Vec<ChatMessage>,
305 temperature: Option<f32>,
306 max_tokens: Option<u32>,
307 stop: Option<Vec<String>>,
308 ) -> Result<String> {
309 let mut temp_session = ChatSession::new();
311
312 for message in messages {
314 temp_session.add_message(message);
315 }
316
317 self.execute_with_tools_temp_session(temp_session, temperature, max_tokens, stop)
319 .await
320 }
321
322 async fn execute_with_tools_temp_session(
324 &mut self,
325 mut temp_session: ChatSession,
326 temperature: Option<f32>,
327 max_tokens: Option<u32>,
328 stop: Option<Vec<String>>,
329 ) -> Result<String> {
330 let mut iterations = 0;
331 let tool_definitions = self.tool_registry.get_definitions();
332
333 loop {
334 if iterations >= self.max_iterations {
335 return Err(HeliosError::AgentError(
336 "Maximum iterations reached".to_string(),
337 ));
338 }
339
340 let messages = temp_session.get_messages();
341 let tools_option = if tool_definitions.is_empty() {
342 None
343 } else {
344 Some(tool_definitions.clone())
345 };
346
347 let response = self
348 .llm_client
349 .chat(
350 messages,
351 tools_option,
352 temperature,
353 max_tokens,
354 stop.clone(),
355 )
356 .await?;
357
358 if let Some(ref tool_calls) = response.tool_calls {
360 temp_session.add_message(response.clone());
362
363 for tool_call in tool_calls {
365 let tool_name = &tool_call.function.name;
366 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
367 .unwrap_or(Value::Object(serde_json::Map::new()));
368
369 let tool_result = self
370 .tool_registry
371 .execute(tool_name, tool_args)
372 .await
373 .unwrap_or_else(|e| {
374 ToolResult::error(format!("Tool execution failed: {}", e))
375 });
376
377 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
379 temp_session.add_message(tool_message);
380 }
381
382 iterations += 1;
383 continue;
384 }
385
386 return Ok(response.content);
388 }
389 }
390
391 pub async fn chat_stream_with_history<F>(
410 &mut self,
411 messages: Vec<ChatMessage>,
412 temperature: Option<f32>,
413 max_tokens: Option<u32>,
414 stop: Option<Vec<String>>,
415 on_chunk: F,
416 ) -> Result<ChatMessage>
417 where
418 F: FnMut(&str) + Send,
419 {
420 let mut temp_session = ChatSession::new();
422
423 for message in messages {
425 temp_session.add_message(message);
426 }
427
428 self.execute_streaming_with_tools_temp_session(
431 temp_session,
432 temperature,
433 max_tokens,
434 stop,
435 on_chunk,
436 )
437 .await
438 }
439
440 async fn execute_streaming_with_tools_temp_session<F>(
442 &mut self,
443 mut temp_session: ChatSession,
444 temperature: Option<f32>,
445 max_tokens: Option<u32>,
446 stop: Option<Vec<String>>,
447 mut on_chunk: F,
448 ) -> Result<ChatMessage>
449 where
450 F: FnMut(&str) + Send,
451 {
452 let mut iterations = 0;
453 let tool_definitions = self.tool_registry.get_definitions();
454
455 loop {
456 if iterations >= self.max_iterations {
457 return Err(HeliosError::AgentError(
458 "Maximum iterations reached".to_string(),
459 ));
460 }
461
462 let messages = temp_session.get_messages();
463 let tools_option = if tool_definitions.is_empty() {
464 None
465 } else {
466 Some(tool_definitions.clone())
467 };
468
469 if iterations == 0 {
471 let mut streamed_content = String::new();
472
473 let stream_result = self
474 .llm_client
475 .chat_stream(
476 messages,
477 tools_option,
478 temperature,
479 max_tokens,
480 stop.clone(),
481 |chunk| {
482 on_chunk(chunk);
483 streamed_content.push_str(chunk);
484 },
485 )
486 .await;
487
488 match stream_result {
489 Ok(response) => {
490 if let Some(ref tool_calls) = response.tool_calls {
492 temp_session.add_message(response.clone());
494
495 for tool_call in tool_calls {
497 let tool_name = &tool_call.function.name;
498 let tool_args: Value =
499 serde_json::from_str(&tool_call.function.arguments)
500 .unwrap_or(Value::Object(serde_json::Map::new()));
501
502 let tool_result = self
503 .tool_registry
504 .execute(tool_name, tool_args)
505 .await
506 .unwrap_or_else(|e| {
507 ToolResult::error(format!("Tool execution failed: {}", e))
508 });
509
510 let tool_message =
512 ChatMessage::tool(tool_result.output, tool_call.id.clone());
513 temp_session.add_message(tool_message);
514 }
515
516 iterations += 1;
517 continue; } else {
519 let mut final_msg = response;
521 final_msg.content = streamed_content;
522 return Ok(final_msg);
523 }
524 }
525 Err(e) => return Err(e),
526 }
527 } else {
528 let response = self
531 .llm_client
532 .chat(
533 messages,
534 tools_option,
535 temperature,
536 max_tokens,
537 stop.clone(),
538 )
539 .await?;
540
541 if let Some(ref tool_calls) = response.tool_calls {
542 temp_session.add_message(response.clone());
544
545 for tool_call in tool_calls {
547 let tool_name = &tool_call.function.name;
548 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
549 .unwrap_or(Value::Object(serde_json::Map::new()));
550
551 let tool_result = self
552 .tool_registry
553 .execute(tool_name, tool_args)
554 .await
555 .unwrap_or_else(|e| {
556 ToolResult::error(format!("Tool execution failed: {}", e))
557 });
558
559 let tool_message =
561 ChatMessage::tool(tool_result.output, tool_call.id.clone());
562 temp_session.add_message(tool_message);
563 }
564
565 iterations += 1;
566 continue;
567 }
568
569 return Ok(response);
571 }
572 }
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Builds a plain `test_agent` from the default config, panicking if construction fails.
    async fn build_test_agent() -> Agent {
        let config = Config::new_default();
        Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let outcome = Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let mut agent = build_test_agent().await;
        let expected = "/tmp".to_string();

        agent.set_memory("working_directory", "/tmp");
        assert_eq!(agent.get_memory("working_directory"), Some(&expected));

        // The value must live under the namespaced key, never the bare one.
        let session = agent.chat_session();
        assert_eq!(
            session.get_metadata("agent:working_directory"),
            Some(&expected)
        );
        assert!(session.get_metadata("working_directory").is_none());

        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let mut agent = build_test_agent().await;

        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        agent.clear_memory();

        // Agent memory is gone, unrelated session metadata survives.
        assert!(agent.get_memory("tasks_completed").is_none());
        assert_eq!(
            agent.chat_session().get_metadata("session_start"),
            Some(&"now".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let mut agent = build_test_agent().await;

        for expected in 1..=2u32 {
            assert_eq!(agent.increment_tasks_completed(), expected);
            assert_eq!(
                agent.get_memory("tasks_completed"),
                Some(&expected.to_string())
            );
        }

        assert_eq!(agent.increment_counter("files_accessed"), 1);
        assert_eq!(agent.increment_counter("files_accessed"), 2);
        assert_eq!(agent.get_memory("files_accessed"), Some(&"2".to_string()));
    }

    #[tokio::test]
    async fn test_agent_builder() {
        let agent = Agent::builder("test_agent")
            .config(Config::new_default())
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_system_prompt() {
        let mut agent = build_test_agent().await;
        agent.set_system_prompt("You are a test agent");

        assert_eq!(
            agent.chat_session().system_prompt,
            Some("You are a test agent".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_tool_registry() {
        let mut agent = build_test_agent().await;
        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_clear_history() {
        let mut agent = build_test_agent().await;

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Minimal tool for exercising tool plumbing; echoes its `input` argument.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let input = ToolParameter {
                param_type: "string".to_string(),
                description: "Input parameter".to_string(),
                required: Some(true),
            };
            std::iter::once(("input".to_string(), input)).collect()
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            let input = match args.get("input").and_then(Value::as_str) {
                Some(text) => text,
                None => "default",
            };
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}
796
797pub struct AgentBuilder {
798 name: String,
799 config: Option<Config>,
800 system_prompt: Option<String>,
801 tools: Vec<Box<dyn crate::tools::Tool>>,
802 max_iterations: usize,
803}
804
805impl AgentBuilder {
806 pub fn new(name: impl Into<String>) -> Self {
807 Self {
808 name: name.into(),
809 config: None,
810 system_prompt: None,
811 tools: Vec::new(),
812 max_iterations: 10,
813 }
814 }
815
816 pub fn config(mut self, config: Config) -> Self {
817 self.config = Some(config);
818 self
819 }
820
821 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
822 self.system_prompt = Some(prompt.into());
823 self
824 }
825
826 pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
827 self.tools.push(tool);
828 self
829 }
830
831 pub fn max_iterations(mut self, max: usize) -> Self {
832 self.max_iterations = max;
833 self
834 }
835
836 pub async fn build(self) -> Result<Agent> {
837 let config = self
838 .config
839 .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
840
841 let mut agent = Agent::new(self.name, config).await?;
842
843 if let Some(prompt) = self.system_prompt {
844 agent.set_system_prompt(prompt);
845 }
846
847 for tool in self.tools {
848 agent.register_tool(tool);
849 }
850
851 agent.set_max_iterations(self.max_iterations);
852
853 Ok(agent)
854 }
855}