1#![allow(dead_code)]
9#![allow(unused_variables)]
10use crate::chat::{ChatMessage, ChatSession};
11use crate::config::Config;
12use crate::error::{HeliosError, Result};
13use crate::llm::{LLMClient, LLMProviderType};
14use crate::tools::{ToolRegistry, ToolResult};
15use serde_json::Value;
16
/// Namespace prepended to every key the agent stores in its session metadata,
/// keeping agent "memory" separate from other metadata entries.
const AGENT_MEMORY_PREFIX: &str = "agent:";
19
20pub struct Agent {
22 name: String,
24 llm_client: LLMClient,
26 tool_registry: ToolRegistry,
28 chat_session: ChatSession,
30 max_iterations: usize,
32 react_mode: bool,
34 react_prompt: Option<String>,
36}
37
38impl Agent {
39 async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
50 #[cfg(feature = "local")]
51 let provider_type = if let Some(local_config) = config.local {
52 LLMProviderType::Local(local_config)
53 } else {
54 LLMProviderType::Remote(config.llm)
55 };
56
57 #[cfg(not(feature = "local"))]
58 let provider_type = LLMProviderType::Remote(config.llm);
59
60 let llm_client = LLMClient::new(provider_type).await?;
61
62 Ok(Self {
63 name: name.into(),
64 llm_client,
65 tool_registry: ToolRegistry::new(),
66 chat_session: ChatSession::new(),
67 max_iterations: 10,
68 react_mode: false,
69 react_prompt: None,
70 })
71 }
72
73 pub fn builder(name: impl Into<String>) -> AgentBuilder {
79 AgentBuilder::new(name)
80 }
81
82 pub fn name(&self) -> &str {
84 &self.name
85 }
86
87 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
93 self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
94 }
95
96 pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
102 self.tool_registry.register(tool);
103 }
104
105 pub fn tool_registry(&self) -> &ToolRegistry {
107 &self.tool_registry
108 }
109
110 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
112 &mut self.tool_registry
113 }
114
115 pub fn chat_session(&self) -> &ChatSession {
117 &self.chat_session
118 }
119
120 pub fn chat_session_mut(&mut self) -> &mut ChatSession {
122 &mut self.chat_session
123 }
124
125 pub fn clear_history(&mut self) {
127 self.chat_session.clear();
128 }
129
130 pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
140 let user_message = message.into();
141 self.chat_session.add_user_message(user_message.clone());
142
143 let response = self.execute_with_tools().await?;
145
146 Ok(response)
147 }
148
149 const DEFAULT_REASONING_PROMPT: &'static str = r#"Before taking any action, think through this step by step:
151
1521. What is the user asking for?
1532. What information or tools do I need to answer this?
1543. What is my plan to solve this problem?
155
156Provide your reasoning in a clear, structured way."#;
157
158 async fn generate_reasoning(&self) -> Result<String> {
164 let reasoning_prompt = self
165 .react_prompt
166 .as_deref()
167 .unwrap_or(Self::DEFAULT_REASONING_PROMPT);
168
169 let mut reasoning_messages = self.chat_session.get_messages();
171 reasoning_messages.push(ChatMessage::user(reasoning_prompt));
172
173 let response = self
175 .llm_client
176 .chat(reasoning_messages, None, None, None, None)
177 .await?;
178
179 Ok(response.content)
180 }
181
182 async fn handle_react_reasoning(&mut self) -> Result<()> {
188 if self.react_mode && !self.tool_registry.get_definitions().is_empty() {
190 let reasoning = self.generate_reasoning().await?;
191
192 println!("\n💠ReAct Reasoning:\n{}\n", reasoning);
194
195 self.chat_session
198 .add_message(ChatMessage::assistant(format!(
199 "[Reasoning]: {}",
200 reasoning
201 )));
202 }
203 Ok(())
204 }
205
206 async fn execute_with_tools(&mut self) -> Result<String> {
208 self.execute_with_tools_streaming().await
209 }
210
211 async fn execute_with_tools_streaming(&mut self) -> Result<String> {
213 self.execute_with_tools_streaming_with_params(None, None, None)
214 .await
215 }
216
217 async fn execute_with_tools_with_params(
219 &mut self,
220 temperature: Option<f32>,
221 max_tokens: Option<u32>,
222 stop: Option<Vec<String>>,
223 ) -> Result<String> {
224 self.handle_react_reasoning().await?;
226
227 let mut iterations = 0;
228 let tool_definitions = self.tool_registry.get_definitions();
229
230 loop {
231 if iterations >= self.max_iterations {
232 return Err(HeliosError::AgentError(
233 "Maximum iterations reached".to_string(),
234 ));
235 }
236
237 let messages = self.chat_session.get_messages();
238 let tools_option = if tool_definitions.is_empty() {
239 None
240 } else {
241 Some(tool_definitions.clone())
242 };
243
244 let response = self
245 .llm_client
246 .chat(
247 messages,
248 tools_option,
249 temperature,
250 max_tokens,
251 stop.clone(),
252 )
253 .await?;
254
255 if let Some(ref tool_calls) = response.tool_calls {
257 self.chat_session.add_message(response.clone());
259
260 for tool_call in tool_calls {
262 let tool_name = &tool_call.function.name;
263 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
264 .unwrap_or(Value::Object(serde_json::Map::new()));
265
266 let tool_result = self
267 .tool_registry
268 .execute(tool_name, tool_args)
269 .await
270 .unwrap_or_else(|e| {
271 ToolResult::error(format!("Tool execution failed: {}", e))
272 });
273
274 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
276 self.chat_session.add_message(tool_message);
277 }
278
279 iterations += 1;
280 continue;
281 }
282
283 self.chat_session.add_message(response.clone());
285 return Ok(response.content);
286 }
287 }
288
289 async fn execute_with_tools_streaming_with_params(
291 &mut self,
292 temperature: Option<f32>,
293 max_tokens: Option<u32>,
294 stop: Option<Vec<String>>,
295 ) -> Result<String> {
296 self.handle_react_reasoning().await?;
298
299 let mut iterations = 0;
300 let tool_definitions = self.tool_registry.get_definitions();
301
302 loop {
303 if iterations >= self.max_iterations {
304 return Err(HeliosError::AgentError(
305 "Maximum iterations reached".to_string(),
306 ));
307 }
308
309 let messages = self.chat_session.get_messages();
310 let tools_option = if tool_definitions.is_empty() {
311 None
312 } else {
313 Some(tool_definitions.clone())
314 };
315
316 let mut streamed_content = String::new();
317
318 let stream_result = self
319 .llm_client
320 .chat_stream(
321 messages,
322 tools_option, temperature,
324 max_tokens,
325 stop.clone(),
326 |chunk| {
327 print!("{}", chunk);
329 let _ = std::io::Write::flush(&mut std::io::stdout());
330 streamed_content.push_str(chunk);
331 },
332 )
333 .await;
334
335 let response = stream_result?;
336
337 println!();
339
340 if let Some(ref tool_calls) = response.tool_calls {
342 let mut msg_with_content = response.clone();
344 msg_with_content.content = streamed_content.clone();
345 self.chat_session.add_message(msg_with_content);
346
347 for tool_call in tool_calls {
349 let tool_name = &tool_call.function.name;
350 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
351 .unwrap_or(Value::Object(serde_json::Map::new()));
352
353 let tool_result = self
354 .tool_registry
355 .execute(tool_name, tool_args)
356 .await
357 .unwrap_or_else(|e| {
358 ToolResult::error(format!("Tool execution failed: {}", e))
359 });
360
361 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
363 self.chat_session.add_message(tool_message);
364 }
365
366 iterations += 1;
367 continue;
368 }
369
370 let mut final_msg = response;
372 final_msg.content = streamed_content.clone();
373 self.chat_session.add_message(final_msg);
374 return Ok(streamed_content);
375 }
376 }
377
378 pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
380 self.send_message(message).await
381 }
382
383 pub fn set_max_iterations(&mut self, max: usize) {
389 self.max_iterations = max;
390 }
391
392 pub fn get_session_summary(&self) -> String {
394 self.chat_session.get_summary()
395 }
396
397 pub fn clear_memory(&mut self) {
399 self.chat_session
401 .metadata
402 .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
403 }
404
405 #[inline]
407 fn prefixed_key(key: &str) -> String {
408 format!("{}{}", AGENT_MEMORY_PREFIX, key)
409 }
410
411 pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
414 let key = key.into();
415 self.chat_session
416 .set_metadata(Self::prefixed_key(&key), value);
417 }
418
419 pub fn get_memory(&self, key: &str) -> Option<&String> {
421 self.chat_session.get_metadata(&Self::prefixed_key(key))
422 }
423
424 pub fn remove_memory(&mut self, key: &str) -> Option<String> {
426 self.chat_session.remove_metadata(&Self::prefixed_key(key))
427 }
428
429 pub fn increment_counter(&mut self, key: &str) -> u32 {
432 let current = self
433 .get_memory(key)
434 .and_then(|v| v.parse::<u32>().ok())
435 .unwrap_or(0);
436 let next = current + 1;
437 self.set_memory(key, next.to_string());
438 next
439 }
440
441 pub fn increment_tasks_completed(&mut self) -> u32 {
443 self.increment_counter("tasks_completed")
444 }
445
446 pub async fn chat_with_history(
464 &mut self,
465 messages: Vec<ChatMessage>,
466 temperature: Option<f32>,
467 max_tokens: Option<u32>,
468 stop: Option<Vec<String>>,
469 ) -> Result<String> {
470 let mut temp_session = ChatSession::new();
472
473 for message in messages {
475 temp_session.add_message(message);
476 }
477
478 self.execute_with_tools_temp_session(temp_session, temperature, max_tokens, stop)
480 .await
481 }
482
483 async fn execute_with_tools_temp_session(
485 &mut self,
486 mut temp_session: ChatSession,
487 temperature: Option<f32>,
488 max_tokens: Option<u32>,
489 stop: Option<Vec<String>>,
490 ) -> Result<String> {
491 let mut iterations = 0;
492 let tool_definitions = self.tool_registry.get_definitions();
493
494 loop {
495 if iterations >= self.max_iterations {
496 return Err(HeliosError::AgentError(
497 "Maximum iterations reached".to_string(),
498 ));
499 }
500
501 let messages = temp_session.get_messages();
502 let tools_option = if tool_definitions.is_empty() {
503 None
504 } else {
505 Some(tool_definitions.clone())
506 };
507
508 let response = self
509 .llm_client
510 .chat(
511 messages,
512 tools_option,
513 temperature,
514 max_tokens,
515 stop.clone(),
516 )
517 .await?;
518
519 if let Some(ref tool_calls) = response.tool_calls {
521 temp_session.add_message(response.clone());
523
524 for tool_call in tool_calls {
526 let tool_name = &tool_call.function.name;
527 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
528 .unwrap_or(Value::Object(serde_json::Map::new()));
529
530 let tool_result = self
531 .tool_registry
532 .execute(tool_name, tool_args)
533 .await
534 .unwrap_or_else(|e| {
535 ToolResult::error(format!("Tool execution failed: {}", e))
536 });
537
538 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
540 temp_session.add_message(tool_message);
541 }
542
543 iterations += 1;
544 continue;
545 }
546
547 return Ok(response.content);
549 }
550 }
551
552 pub async fn chat_stream_with_history<F>(
571 &mut self,
572 messages: Vec<ChatMessage>,
573 temperature: Option<f32>,
574 max_tokens: Option<u32>,
575 stop: Option<Vec<String>>,
576 on_chunk: F,
577 ) -> Result<ChatMessage>
578 where
579 F: FnMut(&str) + Send,
580 {
581 let mut temp_session = ChatSession::new();
583
584 for message in messages {
586 temp_session.add_message(message);
587 }
588
589 self.execute_streaming_with_tools_temp_session(
592 temp_session,
593 temperature,
594 max_tokens,
595 stop,
596 on_chunk,
597 )
598 .await
599 }
600
601 async fn execute_streaming_with_tools_temp_session<F>(
603 &mut self,
604 mut temp_session: ChatSession,
605 temperature: Option<f32>,
606 max_tokens: Option<u32>,
607 stop: Option<Vec<String>>,
608 mut on_chunk: F,
609 ) -> Result<ChatMessage>
610 where
611 F: FnMut(&str) + Send,
612 {
613 let mut iterations = 0;
614 let tool_definitions = self.tool_registry.get_definitions();
615
616 loop {
617 if iterations >= self.max_iterations {
618 return Err(HeliosError::AgentError(
619 "Maximum iterations reached".to_string(),
620 ));
621 }
622
623 let messages = temp_session.get_messages();
624 let tools_option = if tool_definitions.is_empty() {
625 None
626 } else {
627 Some(tool_definitions.clone())
628 };
629
630 let mut streamed_content = String::new();
632
633 let stream_result = self
634 .llm_client
635 .chat_stream(
636 messages,
637 tools_option,
638 temperature,
639 max_tokens,
640 stop.clone(),
641 |chunk| {
642 on_chunk(chunk);
643 streamed_content.push_str(chunk);
644 },
645 )
646 .await;
647
648 match stream_result {
649 Ok(response) => {
650 if let Some(ref tool_calls) = response.tool_calls {
652 let mut msg_with_content = response.clone();
654 msg_with_content.content = streamed_content.clone();
655 temp_session.add_message(msg_with_content);
656
657 for tool_call in tool_calls {
659 let tool_name = &tool_call.function.name;
660 let tool_args: Value =
661 serde_json::from_str(&tool_call.function.arguments)
662 .unwrap_or(Value::Object(serde_json::Map::new()));
663
664 let tool_result = self
665 .tool_registry
666 .execute(tool_name, tool_args)
667 .await
668 .unwrap_or_else(|e| {
669 ToolResult::error(format!("Tool execution failed: {}", e))
670 });
671
672 let tool_message =
674 ChatMessage::tool(tool_result.output, tool_call.id.clone());
675 temp_session.add_message(tool_message);
676 }
677
678 iterations += 1;
679 continue; } else {
681 let mut final_msg = response;
683 final_msg.content = streamed_content;
684 return Ok(final_msg);
685 }
686 }
687 Err(e) => return Err(e),
688 }
689 }
690 }
691}
692
693pub struct AgentBuilder {
694 name: String,
695 config: Option<Config>,
696 system_prompt: Option<String>,
697 tools: Vec<Box<dyn crate::tools::Tool>>,
698 max_iterations: usize,
699 react_mode: bool,
700 react_prompt: Option<String>,
701}
702
703impl AgentBuilder {
704 pub fn new(name: impl Into<String>) -> Self {
705 Self {
706 name: name.into(),
707 config: None,
708 system_prompt: None,
709 tools: Vec::new(),
710 max_iterations: 10,
711 react_mode: false,
712 react_prompt: None,
713 }
714 }
715
716 pub fn config(mut self, config: Config) -> Self {
717 self.config = Some(config);
718 self
719 }
720
721 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
722 self.system_prompt = Some(prompt.into());
723 self
724 }
725
726 pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
728 self.tools.push(tool);
729 self
730 }
731
732 pub fn tools(mut self, tools: Vec<Box<dyn crate::tools::Tool>>) -> Self {
752 self.tools.extend(tools);
753 self
754 }
755
756 pub fn max_iterations(mut self, max: usize) -> Self {
757 self.max_iterations = max;
758 self
759 }
760
761 pub fn react(mut self) -> Self {
782 self.react_mode = true;
783 self
784 }
785
786 pub fn react_with_prompt(mut self, prompt: impl Into<String>) -> Self {
818 self.react_mode = true;
819 self.react_prompt = Some(prompt.into());
820 self
821 }
822
823 pub async fn build(self) -> Result<Agent> {
824 let config = self
825 .config
826 .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
827
828 let mut agent = Agent::new(self.name, config).await?;
829
830 if let Some(prompt) = self.system_prompt {
831 agent.set_system_prompt(prompt);
832 }
833
834 for tool in self.tools {
835 agent.register_tool(tool);
836 }
837
838 agent.set_max_iterations(self.max_iterations);
839 agent.react_mode = self.react_mode;
840 agent.react_prompt = self.react_prompt;
841
842 Ok(agent)
843 }
844}
845
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent").config(config).build().await;
        assert!(agent.is_ok());
    }

    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        agent.set_memory("working_directory", "/tmp");
        assert_eq!(
            agent.get_memory("working_directory"),
            Some(&"/tmp".to_string())
        );

        // The value is stored under the namespaced key only.
        assert_eq!(
            agent.chat_session().get_metadata("agent:working_directory"),
            Some(&"/tmp".to_string())
        );
        assert!(agent
            .chat_session()
            .get_metadata("working_directory")
            .is_none());

        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        agent.clear_memory();

        // Agent memory is gone, unrelated session metadata survives.
        assert!(agent.get_memory("tasks_completed").is_none());
        assert_eq!(
            agent.chat_session().get_metadata("session_start"),
            Some(&"now".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        let n1 = agent.increment_tasks_completed();
        assert_eq!(n1, 1);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"1".to_string()));

        let n2 = agent.increment_tasks_completed();
        assert_eq!(n2, 2);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"2".to_string()));

        let f1 = agent.increment_counter("files_accessed");
        assert_eq!(f1, 1);
        let f2 = agent.increment_counter("files_accessed");
        assert_eq!(f2, 2);
        assert_eq!(agent.get_memory("files_accessed"), Some(&"2".to_string()));
    }

    #[tokio::test]
    async fn test_agent_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_system_prompt() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();
        agent.set_system_prompt("You are a test agent");

        let session = agent.chat_session();
        assert_eq!(
            session.system_prompt,
            Some("You are a test agent".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_tool_registry() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_clear_history() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Exercises `MockTool` end to end: it is registered through the builder
    /// and executed through the agent's registry with JSON arguments.
    #[tokio::test]
    async fn test_agent_executes_registered_mock_tool() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .tool(Box::new(MockTool))
            .build()
            .await
            .unwrap();

        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["mock_tool".to_string()]
        );

        let result = agent
            .tool_registry()
            .execute("mock_tool", serde_json::json!({ "input": "hello" }))
            .await
            .unwrap();
        assert!(result.output.contains("Mock tool output: hello"));
    }

    /// Minimal tool that echoes its `input` argument.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let mut params = HashMap::new();
            params.insert(
                "input".to_string(),
                ToolParameter {
                    param_type: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: Some(true),
                },
            );
            params
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            let input = args
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}