1#![allow(dead_code)]
9#![allow(unused_variables)]
10use crate::chat::{ChatMessage, ChatSession};
11use crate::config::Config;
12use crate::error::{HeliosError, Result};
13use crate::llm::{LLMClient, LLMProviderType};
14use crate::tools::{ToolRegistry, ToolResult};
15use serde_json::Value;
16
/// Prefix prepended to every metadata key written through the agent memory
/// helpers; `clear_memory` removes only the metadata entries whose key starts
/// with this prefix.
const AGENT_MEMORY_PREFIX: &str = "agent:";
19
/// An LLM-backed agent that owns a chat session and a registry of tools the
/// model may ask to run.
pub struct Agent {
    /// Name returned by `Agent::name`.
    name: String,
    /// Client used for every chat and streaming request.
    llm_client: LLMClient,
    /// Tools whose definitions are sent with each request and which are
    /// executed when the model asks for them.
    tool_registry: ToolRegistry,
    /// The agent's own conversation; its metadata also stores the
    /// `agent:`-prefixed memory entries.
    chat_session: ChatSession,
    /// Maximum number of tool-call rounds per request before an
    /// `AgentError` is returned.
    max_iterations: usize,
}
33
34impl Agent {
35 async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
46 let provider_type = if let Some(local_config) = config.local {
47 LLMProviderType::Local(local_config)
48 } else {
49 LLMProviderType::Remote(config.llm)
50 };
51
52 let llm_client = LLMClient::new(provider_type).await?;
53
54 Ok(Self {
55 name: name.into(),
56 llm_client,
57 tool_registry: ToolRegistry::new(),
58 chat_session: ChatSession::new(),
59 max_iterations: 10,
60 })
61 }
62
63 pub fn builder(name: impl Into<String>) -> AgentBuilder {
69 AgentBuilder::new(name)
70 }
71
72 pub fn name(&self) -> &str {
74 &self.name
75 }
76
77 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
83 self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
84 }
85
86 pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
92 self.tool_registry.register(tool);
93 }
94
95 pub fn tool_registry(&self) -> &ToolRegistry {
97 &self.tool_registry
98 }
99
100 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
102 &mut self.tool_registry
103 }
104
105 pub fn chat_session(&self) -> &ChatSession {
107 &self.chat_session
108 }
109
110 pub fn chat_session_mut(&mut self) -> &mut ChatSession {
112 &mut self.chat_session
113 }
114
115 pub fn clear_history(&mut self) {
117 self.chat_session.clear();
118 }
119
120 pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
130 let user_message = message.into();
131 self.chat_session.add_user_message(user_message.clone());
132
133 let response = self.execute_with_tools().await?;
135
136 Ok(response)
137 }
138
139 async fn execute_with_tools(&mut self) -> Result<String> {
141 self.execute_with_tools_with_params(None, None, None).await
142 }
143
144 async fn execute_with_tools_with_params(
146 &mut self,
147 temperature: Option<f32>,
148 max_tokens: Option<u32>,
149 stop: Option<Vec<String>>,
150 ) -> Result<String> {
151 let mut iterations = 0;
152 let tool_definitions = self.tool_registry.get_definitions();
153
154 loop {
155 if iterations >= self.max_iterations {
156 return Err(HeliosError::AgentError(
157 "Maximum iterations reached".to_string(),
158 ));
159 }
160
161 let messages = self.chat_session.get_messages();
162 let tools_option = if tool_definitions.is_empty() {
163 None
164 } else {
165 Some(tool_definitions.clone())
166 };
167
168 let response = self
169 .llm_client
170 .chat(
171 messages,
172 tools_option,
173 temperature,
174 max_tokens,
175 stop.clone(),
176 )
177 .await?;
178
179 if let Some(ref tool_calls) = response.tool_calls {
181 self.chat_session.add_message(response.clone());
183
184 for tool_call in tool_calls {
186 let tool_name = &tool_call.function.name;
187 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
188 .unwrap_or(Value::Object(serde_json::Map::new()));
189
190 let tool_result = self
191 .tool_registry
192 .execute(tool_name, tool_args)
193 .await
194 .unwrap_or_else(|e| {
195 ToolResult::error(format!("Tool execution failed: {}", e))
196 });
197
198 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
200 self.chat_session.add_message(tool_message);
201 }
202
203 iterations += 1;
204 continue;
205 }
206
207 self.chat_session.add_message(response.clone());
209 return Ok(response.content);
210 }
211 }
212
213 pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
215 self.send_message(message).await
216 }
217
218 pub fn set_max_iterations(&mut self, max: usize) {
224 self.max_iterations = max;
225 }
226
227 pub fn get_session_summary(&self) -> String {
229 self.chat_session.get_summary()
230 }
231
232 pub fn clear_memory(&mut self) {
234 self.chat_session
236 .metadata
237 .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
238 }
239
240 #[inline]
242 fn prefixed_key(key: &str) -> String {
243 format!("{}{}", AGENT_MEMORY_PREFIX, key)
244 }
245
246 pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
249 let key = key.into();
250 self.chat_session
251 .set_metadata(Self::prefixed_key(&key), value);
252 }
253
254 pub fn get_memory(&self, key: &str) -> Option<&String> {
256 self.chat_session.get_metadata(&Self::prefixed_key(key))
257 }
258
259 pub fn remove_memory(&mut self, key: &str) -> Option<String> {
261 self.chat_session.remove_metadata(&Self::prefixed_key(key))
262 }
263
264 pub fn increment_counter(&mut self, key: &str) -> u32 {
267 let current = self
268 .get_memory(key)
269 .and_then(|v| v.parse::<u32>().ok())
270 .unwrap_or(0);
271 let next = current + 1;
272 self.set_memory(key, next.to_string());
273 next
274 }
275
276 pub fn increment_tasks_completed(&mut self) -> u32 {
278 self.increment_counter("tasks_completed")
279 }
280
281 pub async fn chat_with_history(
299 &mut self,
300 messages: Vec<ChatMessage>,
301 temperature: Option<f32>,
302 max_tokens: Option<u32>,
303 stop: Option<Vec<String>>,
304 ) -> Result<String> {
305 let mut temp_session = ChatSession::new();
307
308 for message in messages {
310 temp_session.add_message(message);
311 }
312
313 self.execute_with_tools_temp_session(temp_session, temperature, max_tokens, stop)
315 .await
316 }
317
318 async fn execute_with_tools_temp_session(
320 &mut self,
321 mut temp_session: ChatSession,
322 temperature: Option<f32>,
323 max_tokens: Option<u32>,
324 stop: Option<Vec<String>>,
325 ) -> Result<String> {
326 let mut iterations = 0;
327 let tool_definitions = self.tool_registry.get_definitions();
328
329 loop {
330 if iterations >= self.max_iterations {
331 return Err(HeliosError::AgentError(
332 "Maximum iterations reached".to_string(),
333 ));
334 }
335
336 let messages = temp_session.get_messages();
337 let tools_option = if tool_definitions.is_empty() {
338 None
339 } else {
340 Some(tool_definitions.clone())
341 };
342
343 let response = self
344 .llm_client
345 .chat(
346 messages,
347 tools_option,
348 temperature,
349 max_tokens,
350 stop.clone(),
351 )
352 .await?;
353
354 if let Some(ref tool_calls) = response.tool_calls {
356 temp_session.add_message(response.clone());
358
359 for tool_call in tool_calls {
361 let tool_name = &tool_call.function.name;
362 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
363 .unwrap_or(Value::Object(serde_json::Map::new()));
364
365 let tool_result = self
366 .tool_registry
367 .execute(tool_name, tool_args)
368 .await
369 .unwrap_or_else(|e| {
370 ToolResult::error(format!("Tool execution failed: {}", e))
371 });
372
373 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
375 temp_session.add_message(tool_message);
376 }
377
378 iterations += 1;
379 continue;
380 }
381
382 return Ok(response.content);
384 }
385 }
386
387 pub async fn chat_stream_with_history<F>(
406 &mut self,
407 messages: Vec<ChatMessage>,
408 temperature: Option<f32>,
409 max_tokens: Option<u32>,
410 stop: Option<Vec<String>>,
411 on_chunk: F,
412 ) -> Result<ChatMessage>
413 where
414 F: FnMut(&str) + Send,
415 {
416 let mut temp_session = ChatSession::new();
418
419 for message in messages {
421 temp_session.add_message(message);
422 }
423
424 self.execute_streaming_with_tools_temp_session(
427 temp_session,
428 temperature,
429 max_tokens,
430 stop,
431 on_chunk,
432 )
433 .await
434 }
435
436 async fn execute_streaming_with_tools_temp_session<F>(
438 &mut self,
439 mut temp_session: ChatSession,
440 temperature: Option<f32>,
441 max_tokens: Option<u32>,
442 stop: Option<Vec<String>>,
443 mut on_chunk: F,
444 ) -> Result<ChatMessage>
445 where
446 F: FnMut(&str) + Send,
447 {
448 let mut iterations = 0;
449 let tool_definitions = self.tool_registry.get_definitions();
450
451 loop {
452 if iterations >= self.max_iterations {
453 return Err(HeliosError::AgentError(
454 "Maximum iterations reached".to_string(),
455 ));
456 }
457
458 let messages = temp_session.get_messages();
459 let tools_option = if tool_definitions.is_empty() {
460 None
461 } else {
462 Some(tool_definitions.clone())
463 };
464
465 if iterations == 0 {
467 let mut streamed_content = String::new();
468
469 let stream_result = self
470 .llm_client
471 .chat_stream(
472 messages,
473 tools_option,
474 temperature,
475 max_tokens,
476 stop.clone(),
477 |chunk| {
478 on_chunk(chunk);
479 streamed_content.push_str(chunk);
480 },
481 )
482 .await;
483
484 match stream_result {
485 Ok(response) => {
486 if let Some(ref tool_calls) = response.tool_calls {
488 temp_session.add_message(response.clone());
490
491 for tool_call in tool_calls {
493 let tool_name = &tool_call.function.name;
494 let tool_args: Value =
495 serde_json::from_str(&tool_call.function.arguments)
496 .unwrap_or(Value::Object(serde_json::Map::new()));
497
498 let tool_result = self
499 .tool_registry
500 .execute(tool_name, tool_args)
501 .await
502 .unwrap_or_else(|e| {
503 ToolResult::error(format!("Tool execution failed: {}", e))
504 });
505
506 let tool_message =
508 ChatMessage::tool(tool_result.output, tool_call.id.clone());
509 temp_session.add_message(tool_message);
510 }
511
512 iterations += 1;
513 continue; } else {
515 let mut final_msg = response;
517 final_msg.content = streamed_content;
518 return Ok(final_msg);
519 }
520 }
521 Err(e) => return Err(e),
522 }
523 } else {
524 let response = self
527 .llm_client
528 .chat(
529 messages,
530 tools_option,
531 temperature,
532 max_tokens,
533 stop.clone(),
534 )
535 .await?;
536
537 if let Some(ref tool_calls) = response.tool_calls {
538 temp_session.add_message(response.clone());
540
541 for tool_call in tool_calls {
543 let tool_name = &tool_call.function.name;
544 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
545 .unwrap_or(Value::Object(serde_json::Map::new()));
546
547 let tool_result = self
548 .tool_registry
549 .execute(tool_name, tool_args)
550 .await
551 .unwrap_or_else(|e| {
552 ToolResult::error(format!("Tool execution failed: {}", e))
553 });
554
555 let tool_message =
557 ChatMessage::tool(tool_result.output, tool_call.id.clone());
558 temp_session.add_message(tool_message);
559 }
560
561 iterations += 1;
562 continue;
563 }
564
565 return Ok(response);
567 }
568 }
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Builds an agent named `test_agent` from the default configuration.
    async fn default_agent() -> Agent {
        Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await
            .unwrap()
    }

    /// Building with only a config succeeds.
    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let built = Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await;
        assert!(built.is_ok());
    }

    /// Memory entries are stored under the `agent:` prefix and can be read
    /// back and removed through the unprefixed key.
    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let mut agent = default_agent().await;

        agent.set_memory("working_directory", "/tmp");
        assert_eq!(
            agent.get_memory("working_directory").map(String::as_str),
            Some("/tmp")
        );

        // The raw metadata only knows the prefixed key.
        let session = agent.chat_session();
        assert_eq!(
            session
                .get_metadata("agent:working_directory")
                .map(String::as_str),
            Some("/tmp")
        );
        assert!(session.get_metadata("working_directory").is_none());

        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    /// `clear_memory` drops agent memory but keeps unrelated session metadata.
    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let mut agent = default_agent().await;

        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        agent.clear_memory();

        assert!(agent.get_memory("tasks_completed").is_none());
        assert_eq!(
            agent
                .chat_session()
                .get_metadata("session_start")
                .map(String::as_str),
            Some("now")
        );
    }

    /// Counters start at 1 and advance by one per call, independently per key.
    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let mut agent = default_agent().await;

        for expected in 1..=2u32 {
            assert_eq!(agent.increment_tasks_completed(), expected);
            assert_eq!(
                agent.get_memory("tasks_completed"),
                Some(&expected.to_string())
            );
        }

        assert_eq!(agent.increment_counter("files_accessed"), 1);
        assert_eq!(agent.increment_counter("files_accessed"), 2);
        assert_eq!(
            agent.get_memory("files_accessed").map(String::as_str),
            Some("2")
        );
    }

    /// Every builder option ends up on the built agent.
    #[tokio::test]
    async fn test_agent_builder() {
        let agent = Agent::builder("test_agent")
            .config(Config::new_default())
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    /// `set_system_prompt` is reflected on the chat session.
    #[tokio::test]
    async fn test_agent_system_prompt() {
        let mut agent = default_agent().await;
        agent.set_system_prompt("You are a test agent");

        assert_eq!(
            agent.chat_session().system_prompt.as_deref(),
            Some("You are a test agent")
        );
    }

    /// The registry starts empty and lists a tool once it is registered.
    #[tokio::test]
    async fn test_agent_tool_registry() {
        let mut agent = default_agent().await;
        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    /// `clear_history` empties the session's message list.
    #[tokio::test]
    async fn test_agent_clear_history() {
        let mut agent = default_agent().await;

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Minimal tool that echoes its `input` argument.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let input = ToolParameter {
                param_type: "string".to_string(),
                description: "Input parameter".to_string(),
                required: Some(true),
            };
            HashMap::from([("input".to_string(), input)])
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            // Fall back to "default" when `input` is absent or not a string.
            let input = match args.get("input").and_then(|v| v.as_str()) {
                Some(text) => text,
                None => "default",
            };
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}
792
/// Step-by-step configuration for an [`Agent`]; consumed by `build`.
pub struct AgentBuilder {
    /// Name given to the built agent.
    name: String,
    /// Configuration for the LLM client; `build` fails when this is `None`.
    config: Option<Config>,
    /// System prompt applied after construction, if one was set.
    system_prompt: Option<String>,
    /// Tools registered on the agent, in the order they were added.
    tools: Vec<Box<dyn crate::tools::Tool>>,
    /// Tool-round limit applied to the built agent.
    max_iterations: usize,
}
800
801impl AgentBuilder {
802 pub fn new(name: impl Into<String>) -> Self {
803 Self {
804 name: name.into(),
805 config: None,
806 system_prompt: None,
807 tools: Vec::new(),
808 max_iterations: 10,
809 }
810 }
811
812 pub fn config(mut self, config: Config) -> Self {
813 self.config = Some(config);
814 self
815 }
816
817 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
818 self.system_prompt = Some(prompt.into());
819 self
820 }
821
822 pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
823 self.tools.push(tool);
824 self
825 }
826
827 pub fn max_iterations(mut self, max: usize) -> Self {
828 self.max_iterations = max;
829 self
830 }
831
832 pub async fn build(self) -> Result<Agent> {
833 let config = self
834 .config
835 .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
836
837 let mut agent = Agent::new(self.name, config).await?;
838
839 if let Some(prompt) = self.system_prompt {
840 agent.set_system_prompt(prompt);
841 }
842
843 for tool in self.tools {
844 agent.register_tool(tool);
845 }
846
847 agent.set_max_iterations(self.max_iterations);
848
849 Ok(agent)
850 }
851}