1#![allow(dead_code)]
9#![allow(unused_variables)]
10use crate::chat::{ChatMessage, ChatSession};
11use crate::config::Config;
12use crate::error::{HeliosError, Result};
13use crate::llm::{LLMClient, LLMProviderType};
14use crate::tools::{ToolRegistry, ToolResult};
15use serde_json::Value;
16
/// Prefix prepended to every agent memory key before it is stored in the chat
/// session's metadata, so agent memory can be told apart from other metadata
/// entries (see `Agent::clear_memory`).
const AGENT_MEMORY_PREFIX: &str = "agent:";
19
/// An LLM-backed agent that owns a chat session and a registry of tools the
/// model may call while answering.
pub struct Agent {
    /// Name supplied when the agent was created; returned by `Agent::name`.
    name: String,
    /// Client used for every request sent to the language model.
    llm_client: LLMClient,
    /// Tools whose definitions are offered to the model and which are executed
    /// when the model requests a tool call.
    tool_registry: ToolRegistry,
    /// The agent's own conversation; its metadata also backs the agent memory.
    chat_session: ChatSession,
    /// Maximum number of model requests made while serving a single message
    /// before the agent gives up with an error.
    max_iterations: usize,
}
33
34impl Agent {
35 async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
46 #[cfg(feature = "local")]
47 let provider_type = if let Some(local_config) = config.local {
48 LLMProviderType::Local(local_config)
49 } else {
50 LLMProviderType::Remote(config.llm)
51 };
52
53 #[cfg(not(feature = "local"))]
54 let provider_type = LLMProviderType::Remote(config.llm);
55
56 let llm_client = LLMClient::new(provider_type).await?;
57
58 Ok(Self {
59 name: name.into(),
60 llm_client,
61 tool_registry: ToolRegistry::new(),
62 chat_session: ChatSession::new(),
63 max_iterations: 10,
64 })
65 }
66
67 pub fn builder(name: impl Into<String>) -> AgentBuilder {
73 AgentBuilder::new(name)
74 }
75
76 pub fn name(&self) -> &str {
78 &self.name
79 }
80
81 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
87 self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
88 }
89
90 pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
96 self.tool_registry.register(tool);
97 }
98
99 pub fn tool_registry(&self) -> &ToolRegistry {
101 &self.tool_registry
102 }
103
104 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
106 &mut self.tool_registry
107 }
108
109 pub fn chat_session(&self) -> &ChatSession {
111 &self.chat_session
112 }
113
114 pub fn chat_session_mut(&mut self) -> &mut ChatSession {
116 &mut self.chat_session
117 }
118
119 pub fn clear_history(&mut self) {
121 self.chat_session.clear();
122 }
123
124 pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
134 let user_message = message.into();
135 self.chat_session.add_user_message(user_message.clone());
136
137 let response = self.execute_with_tools().await?;
139
140 Ok(response)
141 }
142
143 async fn execute_with_tools(&mut self) -> Result<String> {
145 self.execute_with_tools_streaming().await
146 }
147
148 async fn execute_with_tools_streaming(&mut self) -> Result<String> {
150 self.execute_with_tools_streaming_with_params(None, None, None)
151 .await
152 }
153
154 async fn execute_with_tools_with_params(
156 &mut self,
157 temperature: Option<f32>,
158 max_tokens: Option<u32>,
159 stop: Option<Vec<String>>,
160 ) -> Result<String> {
161 let mut iterations = 0;
162 let tool_definitions = self.tool_registry.get_definitions();
163
164 loop {
165 if iterations >= self.max_iterations {
166 return Err(HeliosError::AgentError(
167 "Maximum iterations reached".to_string(),
168 ));
169 }
170
171 let messages = self.chat_session.get_messages();
172 let tools_option = if tool_definitions.is_empty() {
173 None
174 } else {
175 Some(tool_definitions.clone())
176 };
177
178 let response = self
179 .llm_client
180 .chat(
181 messages,
182 tools_option,
183 temperature,
184 max_tokens,
185 stop.clone(),
186 )
187 .await?;
188
189 if let Some(ref tool_calls) = response.tool_calls {
191 self.chat_session.add_message(response.clone());
193
194 for tool_call in tool_calls {
196 let tool_name = &tool_call.function.name;
197 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
198 .unwrap_or(Value::Object(serde_json::Map::new()));
199
200 let tool_result = self
201 .tool_registry
202 .execute(tool_name, tool_args)
203 .await
204 .unwrap_or_else(|e| {
205 ToolResult::error(format!("Tool execution failed: {}", e))
206 });
207
208 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
210 self.chat_session.add_message(tool_message);
211 }
212
213 iterations += 1;
214 continue;
215 }
216
217 self.chat_session.add_message(response.clone());
219 return Ok(response.content);
220 }
221 }
222
223 async fn execute_with_tools_streaming_with_params(
225 &mut self,
226 temperature: Option<f32>,
227 max_tokens: Option<u32>,
228 stop: Option<Vec<String>>,
229 ) -> Result<String> {
230 let mut iterations = 0;
231 let tool_definitions = self.tool_registry.get_definitions();
232
233 loop {
234 if iterations >= self.max_iterations {
235 return Err(HeliosError::AgentError(
236 "Maximum iterations reached".to_string(),
237 ));
238 }
239
240 let messages = self.chat_session.get_messages();
241 let tools_option = if tool_definitions.is_empty() {
242 None
243 } else {
244 Some(tool_definitions.clone())
245 };
246
247 let mut streamed_content = String::new();
248
249 let stream_result = self
250 .llm_client
251 .chat_stream(
252 messages,
253 tools_option, temperature,
255 max_tokens,
256 stop.clone(),
257 |chunk| {
258 print!("{}", chunk);
260 let _ = std::io::Write::flush(&mut std::io::stdout());
261 streamed_content.push_str(chunk);
262 },
263 )
264 .await;
265
266 let response = stream_result?;
267
268 println!();
270
271 if let Some(ref tool_calls) = response.tool_calls {
273 let mut msg_with_content = response.clone();
275 msg_with_content.content = streamed_content.clone();
276 self.chat_session.add_message(msg_with_content);
277
278 for tool_call in tool_calls {
280 let tool_name = &tool_call.function.name;
281 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
282 .unwrap_or(Value::Object(serde_json::Map::new()));
283
284 let tool_result = self
285 .tool_registry
286 .execute(tool_name, tool_args)
287 .await
288 .unwrap_or_else(|e| {
289 ToolResult::error(format!("Tool execution failed: {}", e))
290 });
291
292 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
294 self.chat_session.add_message(tool_message);
295 }
296
297 iterations += 1;
298 continue;
299 }
300
301 let mut final_msg = response;
303 final_msg.content = streamed_content.clone();
304 self.chat_session.add_message(final_msg);
305 return Ok(streamed_content);
306 }
307 }
308
309 pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
311 self.send_message(message).await
312 }
313
314 pub fn set_max_iterations(&mut self, max: usize) {
320 self.max_iterations = max;
321 }
322
323 pub fn get_session_summary(&self) -> String {
325 self.chat_session.get_summary()
326 }
327
328 pub fn clear_memory(&mut self) {
330 self.chat_session
332 .metadata
333 .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
334 }
335
336 #[inline]
338 fn prefixed_key(key: &str) -> String {
339 format!("{}{}", AGENT_MEMORY_PREFIX, key)
340 }
341
342 pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
345 let key = key.into();
346 self.chat_session
347 .set_metadata(Self::prefixed_key(&key), value);
348 }
349
350 pub fn get_memory(&self, key: &str) -> Option<&String> {
352 self.chat_session.get_metadata(&Self::prefixed_key(key))
353 }
354
355 pub fn remove_memory(&mut self, key: &str) -> Option<String> {
357 self.chat_session.remove_metadata(&Self::prefixed_key(key))
358 }
359
360 pub fn increment_counter(&mut self, key: &str) -> u32 {
363 let current = self
364 .get_memory(key)
365 .and_then(|v| v.parse::<u32>().ok())
366 .unwrap_or(0);
367 let next = current + 1;
368 self.set_memory(key, next.to_string());
369 next
370 }
371
372 pub fn increment_tasks_completed(&mut self) -> u32 {
374 self.increment_counter("tasks_completed")
375 }
376
377 pub async fn chat_with_history(
395 &mut self,
396 messages: Vec<ChatMessage>,
397 temperature: Option<f32>,
398 max_tokens: Option<u32>,
399 stop: Option<Vec<String>>,
400 ) -> Result<String> {
401 let mut temp_session = ChatSession::new();
403
404 for message in messages {
406 temp_session.add_message(message);
407 }
408
409 self.execute_with_tools_temp_session(temp_session, temperature, max_tokens, stop)
411 .await
412 }
413
414 async fn execute_with_tools_temp_session(
416 &mut self,
417 mut temp_session: ChatSession,
418 temperature: Option<f32>,
419 max_tokens: Option<u32>,
420 stop: Option<Vec<String>>,
421 ) -> Result<String> {
422 let mut iterations = 0;
423 let tool_definitions = self.tool_registry.get_definitions();
424
425 loop {
426 if iterations >= self.max_iterations {
427 return Err(HeliosError::AgentError(
428 "Maximum iterations reached".to_string(),
429 ));
430 }
431
432 let messages = temp_session.get_messages();
433 let tools_option = if tool_definitions.is_empty() {
434 None
435 } else {
436 Some(tool_definitions.clone())
437 };
438
439 let response = self
440 .llm_client
441 .chat(
442 messages,
443 tools_option,
444 temperature,
445 max_tokens,
446 stop.clone(),
447 )
448 .await?;
449
450 if let Some(ref tool_calls) = response.tool_calls {
452 temp_session.add_message(response.clone());
454
455 for tool_call in tool_calls {
457 let tool_name = &tool_call.function.name;
458 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
459 .unwrap_or(Value::Object(serde_json::Map::new()));
460
461 let tool_result = self
462 .tool_registry
463 .execute(tool_name, tool_args)
464 .await
465 .unwrap_or_else(|e| {
466 ToolResult::error(format!("Tool execution failed: {}", e))
467 });
468
469 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
471 temp_session.add_message(tool_message);
472 }
473
474 iterations += 1;
475 continue;
476 }
477
478 return Ok(response.content);
480 }
481 }
482
483 pub async fn chat_stream_with_history<F>(
502 &mut self,
503 messages: Vec<ChatMessage>,
504 temperature: Option<f32>,
505 max_tokens: Option<u32>,
506 stop: Option<Vec<String>>,
507 on_chunk: F,
508 ) -> Result<ChatMessage>
509 where
510 F: FnMut(&str) + Send,
511 {
512 let mut temp_session = ChatSession::new();
514
515 for message in messages {
517 temp_session.add_message(message);
518 }
519
520 self.execute_streaming_with_tools_temp_session(
523 temp_session,
524 temperature,
525 max_tokens,
526 stop,
527 on_chunk,
528 )
529 .await
530 }
531
532 async fn execute_streaming_with_tools_temp_session<F>(
534 &mut self,
535 mut temp_session: ChatSession,
536 temperature: Option<f32>,
537 max_tokens: Option<u32>,
538 stop: Option<Vec<String>>,
539 mut on_chunk: F,
540 ) -> Result<ChatMessage>
541 where
542 F: FnMut(&str) + Send,
543 {
544 let mut iterations = 0;
545 let tool_definitions = self.tool_registry.get_definitions();
546
547 loop {
548 if iterations >= self.max_iterations {
549 return Err(HeliosError::AgentError(
550 "Maximum iterations reached".to_string(),
551 ));
552 }
553
554 let messages = temp_session.get_messages();
555 let tools_option = if tool_definitions.is_empty() {
556 None
557 } else {
558 Some(tool_definitions.clone())
559 };
560
561 let mut streamed_content = String::new();
563
564 let stream_result = self
565 .llm_client
566 .chat_stream(
567 messages,
568 tools_option,
569 temperature,
570 max_tokens,
571 stop.clone(),
572 |chunk| {
573 on_chunk(chunk);
574 streamed_content.push_str(chunk);
575 },
576 )
577 .await;
578
579 match stream_result {
580 Ok(response) => {
581 if let Some(ref tool_calls) = response.tool_calls {
583 let mut msg_with_content = response.clone();
585 msg_with_content.content = streamed_content.clone();
586 temp_session.add_message(msg_with_content);
587
588 for tool_call in tool_calls {
590 let tool_name = &tool_call.function.name;
591 let tool_args: Value =
592 serde_json::from_str(&tool_call.function.arguments)
593 .unwrap_or(Value::Object(serde_json::Map::new()));
594
595 let tool_result = self
596 .tool_registry
597 .execute(tool_name, tool_args)
598 .await
599 .unwrap_or_else(|e| {
600 ToolResult::error(format!("Tool execution failed: {}", e))
601 });
602
603 let tool_message =
605 ChatMessage::tool(tool_result.output, tool_call.id.clone());
606 temp_session.add_message(tool_message);
607 }
608
609 iterations += 1;
610 continue; } else {
612 let mut final_msg = response;
614 final_msg.content = streamed_content;
615 return Ok(final_msg);
616 }
617 }
618 Err(e) => return Err(e),
619 }
620 }
621 }
622}
623
/// Step-by-step constructor for an `Agent`; finish with `AgentBuilder::build`.
pub struct AgentBuilder {
    /// Name the built agent will carry.
    name: String,
    /// Configuration used to create the LLM client; `build` fails without it.
    config: Option<Config>,
    /// System prompt applied to the agent's chat session, if one was set.
    system_prompt: Option<String>,
    /// Tools registered on the agent, in the order they were added.
    tools: Vec<Box<dyn crate::tools::Tool>>,
    /// Iteration limit handed to the built agent (10 unless overridden).
    max_iterations: usize,
}
631
632impl AgentBuilder {
633 pub fn new(name: impl Into<String>) -> Self {
634 Self {
635 name: name.into(),
636 config: None,
637 system_prompt: None,
638 tools: Vec::new(),
639 max_iterations: 10,
640 }
641 }
642
643 pub fn config(mut self, config: Config) -> Self {
644 self.config = Some(config);
645 self
646 }
647
648 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
649 self.system_prompt = Some(prompt.into());
650 self
651 }
652
653 pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
655 self.tools.push(tool);
656 self
657 }
658
659 pub fn tools(mut self, tools: Vec<Box<dyn crate::tools::Tool>>) -> Self {
679 self.tools.extend(tools);
680 self
681 }
682
683 pub fn max_iterations(mut self, max: usize) -> Self {
684 self.max_iterations = max;
685 self
686 }
687
688 pub async fn build(self) -> Result<Agent> {
689 let config = self
690 .config
691 .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
692
693 let mut agent = Agent::new(self.name, config).await?;
694
695 if let Some(prompt) = self.system_prompt {
696 agent.set_system_prompt(prompt);
697 }
698
699 for tool in self.tools {
700 agent.register_tool(tool);
701 }
702
703 agent.set_max_iterations(self.max_iterations);
704
705 Ok(agent)
706 }
707}
708
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Builds an agent named "test_agent" from the default configuration,
    /// panicking if construction fails.
    async fn default_agent() -> Agent {
        Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await
            .unwrap()
    }

    /// Building with only a configuration succeeds.
    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let built = Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await;
        assert!(built.is_ok());
    }

    /// Memory entries are stored under the "agent:" prefix and can be read
    /// back and removed through the un-prefixed key.
    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let mut agent = default_agent().await;

        agent.set_memory("working_directory", "/tmp");
        assert_eq!(
            agent.get_memory("working_directory").map(String::as_str),
            Some("/tmp")
        );

        let session = agent.chat_session();
        assert_eq!(
            session
                .get_metadata("agent:working_directory")
                .map(String::as_str),
            Some("/tmp")
        );
        assert!(session.get_metadata("working_directory").is_none());

        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    /// `clear_memory` drops agent memory but keeps unrelated session metadata.
    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let mut agent = default_agent().await;

        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        agent.clear_memory();

        assert!(agent.get_memory("tasks_completed").is_none());
        assert_eq!(
            agent
                .chat_session()
                .get_metadata("session_start")
                .map(String::as_str),
            Some("now")
        );
    }

    /// Counters start at 1 and advance by one per call, independently per key.
    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let mut agent = default_agent().await;

        let first_task = agent.increment_tasks_completed();
        assert_eq!(first_task, 1);
        assert_eq!(
            agent.get_memory("tasks_completed").map(String::as_str),
            Some("1")
        );

        let second_task = agent.increment_tasks_completed();
        assert_eq!(second_task, 2);
        assert_eq!(
            agent.get_memory("tasks_completed").map(String::as_str),
            Some("2")
        );

        let first_file = agent.increment_counter("files_accessed");
        assert_eq!(first_file, 1);
        let second_file = agent.increment_counter("files_accessed");
        assert_eq!(second_file, 2);
        assert_eq!(
            agent.get_memory("files_accessed").map(String::as_str),
            Some("2")
        );
    }

    /// Every builder option ends up on the built agent.
    #[tokio::test]
    async fn test_agent_builder() {
        let agent = Agent::builder("test_agent")
            .config(Config::new_default())
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    /// `set_system_prompt` is reflected in the chat session.
    #[tokio::test]
    async fn test_agent_system_prompt() {
        let mut agent = default_agent().await;
        agent.set_system_prompt("You are a test agent");

        assert_eq!(
            agent.chat_session().system_prompt.as_deref(),
            Some("You are a test agent")
        );
    }

    /// The registry starts empty and lists a tool once it is registered.
    #[tokio::test]
    async fn test_agent_tool_registry() {
        let mut agent = default_agent().await;

        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    /// `clear_history` empties the session's message list.
    #[tokio::test]
    async fn test_agent_clear_history() {
        let mut agent = default_agent().await;

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Minimal `Tool` implementation that echoes its `input` argument.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let input_param = ToolParameter {
                param_type: "string".to_string(),
                description: "Input parameter".to_string(),
                required: Some(true),
            };
            std::iter::once(("input".to_string(), input_param)).collect()
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            // Fall back to "default" when `input` is missing or not a string.
            let input = args
                .get("input")
                .and_then(Value::as_str)
                .unwrap_or("default");
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}