1#![allow(dead_code)]
9#![allow(unused_variables)]
10use crate::chat::{ChatMessage, ChatSession};
11use crate::config::Config;
12use crate::error::{HeliosError, Result};
13use crate::llm::{LLMClient, LLMProviderType};
14use crate::tools::{ToolRegistry, ToolResult};
15use serde_json::Value;
16
/// Namespace prepended to every key written through the agent memory helpers
/// (`set_memory`, `get_memory`, `remove_memory`), keeping agent memory separate
/// from other session metadata so `clear_memory` can remove only agent entries.
const AGENT_MEMORY_PREFIX: &str = "agent:";
19
20pub struct Agent {
22 name: String,
24 llm_client: LLMClient,
26 tool_registry: ToolRegistry,
28 chat_session: ChatSession,
30 max_iterations: usize,
32 react_mode: bool,
34 react_prompt: Option<String>,
36}
37
38impl Agent {
39 async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
50 #[cfg(feature = "local")]
51 let provider_type = if let Some(local_config) = config.local {
52 LLMProviderType::Local(local_config)
53 } else {
54 LLMProviderType::Remote(config.llm)
55 };
56
57 #[cfg(not(feature = "local"))]
58 let provider_type = LLMProviderType::Remote(config.llm);
59
60 let llm_client = LLMClient::new(provider_type).await?;
61
62 Ok(Self {
63 name: name.into(),
64 llm_client,
65 tool_registry: ToolRegistry::new(),
66 chat_session: ChatSession::new(),
67 max_iterations: 10,
68 react_mode: false,
69 react_prompt: None,
70 })
71 }
72
73 pub fn builder(name: impl Into<String>) -> AgentBuilder {
79 AgentBuilder::new(name)
80 }
81
82 pub async fn quick(name: impl Into<String>) -> Result<Self> {
103 let config = Config::load_or_default("config.toml");
104 Agent::builder(name).config(config).build().await
105 }
106
107 pub fn name(&self) -> &str {
109 &self.name
110 }
111
112 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
118 self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
119 }
120
121 pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
127 self.tool_registry.register(tool);
128 }
129
130 pub fn tool_registry(&self) -> &ToolRegistry {
132 &self.tool_registry
133 }
134
135 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
137 &mut self.tool_registry
138 }
139
140 pub fn chat_session(&self) -> &ChatSession {
142 &self.chat_session
143 }
144
145 pub fn chat_session_mut(&mut self) -> &mut ChatSession {
147 &mut self.chat_session
148 }
149
150 pub fn clear_history(&mut self) {
152 self.chat_session.clear();
153 }
154
155 pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
165 let user_message = message.into();
166 self.chat_session.add_user_message(user_message.clone());
167
168 let response = self.execute_with_tools().await?;
170
171 Ok(response)
172 }
173
174 const DEFAULT_REASONING_PROMPT: &'static str = r#"Before taking any action, think through this step by step:
176
1771. What is the user asking for?
1782. What information or tools do I need to answer this?
1793. What is my plan to solve this problem?
180
181Provide your reasoning in a clear, structured way."#;
182
183 async fn generate_reasoning(&self) -> Result<String> {
189 let reasoning_prompt = self
190 .react_prompt
191 .as_deref()
192 .unwrap_or(Self::DEFAULT_REASONING_PROMPT);
193
194 let mut reasoning_messages = self.chat_session.get_messages();
196 reasoning_messages.push(ChatMessage::user(reasoning_prompt));
197
198 let response = self
200 .llm_client
201 .chat(reasoning_messages, None, None, None, None)
202 .await?;
203
204 Ok(response.content)
205 }
206
207 async fn handle_react_reasoning(&mut self) -> Result<()> {
213 if self.react_mode && !self.tool_registry.get_definitions().is_empty() {
215 let reasoning = self.generate_reasoning().await?;
216
217 println!("\n💠ReAct Reasoning:\n{}\n", reasoning);
219
220 self.chat_session
223 .add_message(ChatMessage::assistant(format!(
224 "[Reasoning]: {}",
225 reasoning
226 )));
227 }
228 Ok(())
229 }
230
231 async fn execute_with_tools(&mut self) -> Result<String> {
233 self.execute_with_tools_streaming().await
234 }
235
236 async fn execute_with_tools_streaming(&mut self) -> Result<String> {
238 self.execute_with_tools_streaming_with_params(None, None, None)
239 .await
240 }
241
242 async fn execute_with_tools_with_params(
244 &mut self,
245 temperature: Option<f32>,
246 max_tokens: Option<u32>,
247 stop: Option<Vec<String>>,
248 ) -> Result<String> {
249 self.handle_react_reasoning().await?;
251
252 let mut iterations = 0;
253 let tool_definitions = self.tool_registry.get_definitions();
254
255 loop {
256 if iterations >= self.max_iterations {
257 return Err(HeliosError::AgentError(
258 "Maximum iterations reached".to_string(),
259 ));
260 }
261
262 let messages = self.chat_session.get_messages();
263 let tools_option = if tool_definitions.is_empty() {
264 None
265 } else {
266 Some(tool_definitions.clone())
267 };
268
269 let response = self
270 .llm_client
271 .chat(
272 messages,
273 tools_option,
274 temperature,
275 max_tokens,
276 stop.clone(),
277 )
278 .await?;
279
280 if let Some(ref tool_calls) = response.tool_calls {
282 self.chat_session.add_message(response.clone());
284
285 for tool_call in tool_calls {
287 let tool_name = &tool_call.function.name;
288 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
289 .unwrap_or(Value::Object(serde_json::Map::new()));
290
291 let tool_result = self
292 .tool_registry
293 .execute(tool_name, tool_args)
294 .await
295 .unwrap_or_else(|e| {
296 ToolResult::error(format!("Tool execution failed: {}", e))
297 });
298
299 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
301 self.chat_session.add_message(tool_message);
302 }
303
304 iterations += 1;
305 continue;
306 }
307
308 self.chat_session.add_message(response.clone());
310 return Ok(response.content);
311 }
312 }
313
314 async fn execute_with_tools_streaming_with_params(
316 &mut self,
317 temperature: Option<f32>,
318 max_tokens: Option<u32>,
319 stop: Option<Vec<String>>,
320 ) -> Result<String> {
321 self.handle_react_reasoning().await?;
323
324 let mut iterations = 0;
325 let tool_definitions = self.tool_registry.get_definitions();
326
327 loop {
328 if iterations >= self.max_iterations {
329 return Err(HeliosError::AgentError(
330 "Maximum iterations reached".to_string(),
331 ));
332 }
333
334 let messages = self.chat_session.get_messages();
335 let tools_option = if tool_definitions.is_empty() {
336 None
337 } else {
338 Some(tool_definitions.clone())
339 };
340
341 let mut streamed_content = String::new();
342
343 let stream_result = self
344 .llm_client
345 .chat_stream(
346 messages,
347 tools_option, temperature,
349 max_tokens,
350 stop.clone(),
351 |chunk| {
352 print!("{}", chunk);
354 let _ = std::io::Write::flush(&mut std::io::stdout());
355 streamed_content.push_str(chunk);
356 },
357 )
358 .await;
359
360 let response = stream_result?;
361
362 println!();
364
365 if let Some(ref tool_calls) = response.tool_calls {
367 let mut msg_with_content = response.clone();
369 msg_with_content.content = streamed_content.clone();
370 self.chat_session.add_message(msg_with_content);
371
372 for tool_call in tool_calls {
374 let tool_name = &tool_call.function.name;
375 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
376 .unwrap_or(Value::Object(serde_json::Map::new()));
377
378 let tool_result = self
379 .tool_registry
380 .execute(tool_name, tool_args)
381 .await
382 .unwrap_or_else(|e| {
383 ToolResult::error(format!("Tool execution failed: {}", e))
384 });
385
386 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
388 self.chat_session.add_message(tool_message);
389 }
390
391 iterations += 1;
392 continue;
393 }
394
395 let mut final_msg = response;
397 final_msg.content = streamed_content.clone();
398 self.chat_session.add_message(final_msg);
399 return Ok(streamed_content);
400 }
401 }
402
403 pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
405 self.send_message(message).await
406 }
407
408 pub async fn ask(&mut self, question: impl Into<String>) -> Result<String> {
410 self.chat(question).await
411 }
412
413 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
415 self.set_system_prompt(prompt);
416 self
417 }
418
419 pub fn with_tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
421 self.register_tool(tool);
422 self
423 }
424
425 pub fn with_tools(mut self, tools: Vec<Box<dyn crate::tools::Tool>>) -> Self {
427 for tool in tools {
428 self.register_tool(tool);
429 }
430 self
431 }
432
433 pub fn set_max_iterations(&mut self, max: usize) {
439 self.max_iterations = max;
440 }
441
442 pub fn get_session_summary(&self) -> String {
444 self.chat_session.get_summary()
445 }
446
447 pub fn clear_memory(&mut self) {
449 self.chat_session
451 .metadata
452 .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
453 }
454
455 #[inline]
457 fn prefixed_key(key: &str) -> String {
458 format!("{}{}", AGENT_MEMORY_PREFIX, key)
459 }
460
461 pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
464 let key = key.into();
465 self.chat_session
466 .set_metadata(Self::prefixed_key(&key), value);
467 }
468
469 pub fn get_memory(&self, key: &str) -> Option<&String> {
471 self.chat_session.get_metadata(&Self::prefixed_key(key))
472 }
473
474 pub fn remove_memory(&mut self, key: &str) -> Option<String> {
476 self.chat_session.remove_metadata(&Self::prefixed_key(key))
477 }
478
479 pub fn increment_counter(&mut self, key: &str) -> u32 {
482 let current = self
483 .get_memory(key)
484 .and_then(|v| v.parse::<u32>().ok())
485 .unwrap_or(0);
486 let next = current + 1;
487 self.set_memory(key, next.to_string());
488 next
489 }
490
491 pub fn increment_tasks_completed(&mut self) -> u32 {
493 self.increment_counter("tasks_completed")
494 }
495
496 pub async fn chat_with_history(
514 &mut self,
515 messages: Vec<ChatMessage>,
516 temperature: Option<f32>,
517 max_tokens: Option<u32>,
518 stop: Option<Vec<String>>,
519 ) -> Result<String> {
520 let mut temp_session = ChatSession::new();
522
523 for message in messages {
525 temp_session.add_message(message);
526 }
527
528 self.execute_with_tools_temp_session(temp_session, temperature, max_tokens, stop)
530 .await
531 }
532
533 async fn execute_with_tools_temp_session(
535 &mut self,
536 mut temp_session: ChatSession,
537 temperature: Option<f32>,
538 max_tokens: Option<u32>,
539 stop: Option<Vec<String>>,
540 ) -> Result<String> {
541 let mut iterations = 0;
542 let tool_definitions = self.tool_registry.get_definitions();
543
544 loop {
545 if iterations >= self.max_iterations {
546 return Err(HeliosError::AgentError(
547 "Maximum iterations reached".to_string(),
548 ));
549 }
550
551 let messages = temp_session.get_messages();
552 let tools_option = if tool_definitions.is_empty() {
553 None
554 } else {
555 Some(tool_definitions.clone())
556 };
557
558 let response = self
559 .llm_client
560 .chat(
561 messages,
562 tools_option,
563 temperature,
564 max_tokens,
565 stop.clone(),
566 )
567 .await?;
568
569 if let Some(ref tool_calls) = response.tool_calls {
571 temp_session.add_message(response.clone());
573
574 for tool_call in tool_calls {
576 let tool_name = &tool_call.function.name;
577 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
578 .unwrap_or(Value::Object(serde_json::Map::new()));
579
580 let tool_result = self
581 .tool_registry
582 .execute(tool_name, tool_args)
583 .await
584 .unwrap_or_else(|e| {
585 ToolResult::error(format!("Tool execution failed: {}", e))
586 });
587
588 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
590 temp_session.add_message(tool_message);
591 }
592
593 iterations += 1;
594 continue;
595 }
596
597 return Ok(response.content);
599 }
600 }
601
602 pub async fn chat_stream_with_history<F>(
621 &mut self,
622 messages: Vec<ChatMessage>,
623 temperature: Option<f32>,
624 max_tokens: Option<u32>,
625 stop: Option<Vec<String>>,
626 on_chunk: F,
627 ) -> Result<ChatMessage>
628 where
629 F: FnMut(&str) + Send,
630 {
631 let mut temp_session = ChatSession::new();
633
634 for message in messages {
636 temp_session.add_message(message);
637 }
638
639 self.execute_streaming_with_tools_temp_session(
642 temp_session,
643 temperature,
644 max_tokens,
645 stop,
646 on_chunk,
647 )
648 .await
649 }
650
651 async fn execute_streaming_with_tools_temp_session<F>(
653 &mut self,
654 mut temp_session: ChatSession,
655 temperature: Option<f32>,
656 max_tokens: Option<u32>,
657 stop: Option<Vec<String>>,
658 mut on_chunk: F,
659 ) -> Result<ChatMessage>
660 where
661 F: FnMut(&str) + Send,
662 {
663 let mut iterations = 0;
664 let tool_definitions = self.tool_registry.get_definitions();
665
666 loop {
667 if iterations >= self.max_iterations {
668 return Err(HeliosError::AgentError(
669 "Maximum iterations reached".to_string(),
670 ));
671 }
672
673 let messages = temp_session.get_messages();
674 let tools_option = if tool_definitions.is_empty() {
675 None
676 } else {
677 Some(tool_definitions.clone())
678 };
679
680 let mut streamed_content = String::new();
682
683 let stream_result = self
684 .llm_client
685 .chat_stream(
686 messages,
687 tools_option,
688 temperature,
689 max_tokens,
690 stop.clone(),
691 |chunk| {
692 on_chunk(chunk);
693 streamed_content.push_str(chunk);
694 },
695 )
696 .await;
697
698 match stream_result {
699 Ok(response) => {
700 if let Some(ref tool_calls) = response.tool_calls {
702 let mut msg_with_content = response.clone();
704 msg_with_content.content = streamed_content.clone();
705 temp_session.add_message(msg_with_content);
706
707 for tool_call in tool_calls {
709 let tool_name = &tool_call.function.name;
710 let tool_args: Value =
711 serde_json::from_str(&tool_call.function.arguments)
712 .unwrap_or(Value::Object(serde_json::Map::new()));
713
714 let tool_result = self
715 .tool_registry
716 .execute(tool_name, tool_args)
717 .await
718 .unwrap_or_else(|e| {
719 ToolResult::error(format!("Tool execution failed: {}", e))
720 });
721
722 let tool_message =
724 ChatMessage::tool(tool_result.output, tool_call.id.clone());
725 temp_session.add_message(tool_message);
726 }
727
728 iterations += 1;
729 continue; } else {
731 let mut final_msg = response;
733 final_msg.content = streamed_content;
734 return Ok(final_msg);
735 }
736 }
737 Err(e) => return Err(e),
738 }
739 }
740 }
741}
742
743pub struct AgentBuilder {
744 name: String,
745 config: Option<Config>,
746 system_prompt: Option<String>,
747 tools: Vec<Box<dyn crate::tools::Tool>>,
748 max_iterations: usize,
749 react_mode: bool,
750 react_prompt: Option<String>,
751}
752
753impl AgentBuilder {
754 pub fn new(name: impl Into<String>) -> Self {
755 Self {
756 name: name.into(),
757 config: None,
758 system_prompt: None,
759 tools: Vec::new(),
760 max_iterations: 10,
761 react_mode: false,
762 react_prompt: None,
763 }
764 }
765
766 pub fn config(mut self, config: Config) -> Self {
767 self.config = Some(config);
768 self
769 }
770
771 pub fn auto_config(mut self) -> Self {
773 self.config = Some(Config::load_or_default("config.toml"));
774 self
775 }
776
777 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
778 self.system_prompt = Some(prompt.into());
779 self
780 }
781
782 pub fn prompt(self, prompt: impl Into<String>) -> Self {
784 self.system_prompt(prompt)
785 }
786
787 pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
789 self.tools.push(tool);
790 self
791 }
792
793 pub fn with_tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
795 self.tools.push(tool);
796 self
797 }
798
799 pub fn tools(mut self, tools: Vec<Box<dyn crate::tools::Tool>>) -> Self {
819 self.tools.extend(tools);
820 self
821 }
822
823 pub fn with_tools(mut self, tools: Vec<Box<dyn crate::tools::Tool>>) -> Self {
825 self.tools.extend(tools);
826 self
827 }
828
829 pub fn max_iterations(mut self, max: usize) -> Self {
830 self.max_iterations = max;
831 self
832 }
833
834 pub fn react(mut self) -> Self {
855 self.react_mode = true;
856 self
857 }
858
859 pub fn react_with_prompt(mut self, prompt: impl Into<String>) -> Self {
891 self.react_mode = true;
892 self.react_prompt = Some(prompt.into());
893 self
894 }
895
896 pub async fn build(self) -> Result<Agent> {
897 let config = self
898 .config
899 .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
900
901 let mut agent = Agent::new(self.name, config).await?;
902
903 if let Some(prompt) = self.system_prompt {
904 agent.set_system_prompt(prompt);
905 }
906
907 for tool in self.tools {
908 agent.register_tool(tool);
909 }
910
911 agent.set_max_iterations(self.max_iterations);
912 agent.react_mode = self.react_mode;
913 agent.react_prompt = self.react_prompt;
914
915 Ok(agent)
916 }
917}
918
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent").config(config).build().await;
        assert!(agent.is_ok());
    }

    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        agent.set_memory("working_directory", "/tmp");
        assert_eq!(
            agent.get_memory("working_directory"),
            Some(&"/tmp".to_string())
        );

        // The raw metadata key carries the agent namespace prefix.
        assert_eq!(
            agent.chat_session().get_metadata("agent:working_directory"),
            Some(&"/tmp".to_string())
        );
        assert!(agent
            .chat_session()
            .get_metadata("working_directory")
            .is_none());

        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        agent.clear_memory();

        // Agent memory is gone, unrelated metadata survives.
        assert!(agent.get_memory("tasks_completed").is_none());
        assert_eq!(
            agent.chat_session().get_metadata("session_start"),
            Some(&"now".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        let n1 = agent.increment_tasks_completed();
        assert_eq!(n1, 1);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"1".to_string()));

        let n2 = agent.increment_tasks_completed();
        assert_eq!(n2, 2);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"2".to_string()));

        let f1 = agent.increment_counter("files_accessed");
        assert_eq!(f1, 1);
        let f2 = agent.increment_counter("files_accessed");
        assert_eq!(f2, 2);
        assert_eq!(agent.get_memory("files_accessed"), Some(&"2".to_string()));
    }

    #[tokio::test]
    async fn test_agent_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_system_prompt() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();
        agent.set_system_prompt("You are a test agent");

        let session = agent.chat_session();
        assert_eq!(
            session.system_prompt,
            Some("You are a test agent".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_tool_registry() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_clear_history() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Exercises `MockTool` end to end: it is registered through the builder
    /// and executed through the agent's registry. Previously `MockTool` was
    /// defined but never used by any test.
    #[tokio::test]
    async fn test_agent_executes_registered_mock_tool() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .tool(Box::new(MockTool))
            .build()
            .await
            .unwrap();

        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["mock_tool".to_string()]
        );

        let result = agent
            .tool_registry_mut()
            .execute("mock_tool", serde_json::json!({ "input": "hello" }))
            .await
            .unwrap();
        assert_eq!(result.output, "Mock tool output: hello");
    }

    /// Minimal tool that echoes its `input` argument (or `"default"`).
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let mut params = HashMap::new();
            params.insert(
                "input".to_string(),
                ToolParameter {
                    param_type: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: Some(true),
                },
            );
            params
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            let input = args
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}