1use crate::chat::{ChatMessage, ChatSession};
2use crate::config::Config;
3use crate::error::{HeliosError, Result};
4use crate::llm::{LLMClient, LLMProviderType};
5use crate::tools::{ToolRegistry, ToolResult};
6use serde_json::Value;
7
8pub struct Agent {
9 name: String,
10 llm_client: LLMClient,
11 tool_registry: ToolRegistry,
12 chat_session: ChatSession,
13 max_iterations: usize,
14}
15
16impl Agent {
17 pub async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
18 let provider_type = if let Some(local_config) = config.local {
19 LLMProviderType::Local(local_config)
20 } else {
21 LLMProviderType::Remote(config.llm)
22 };
23
24 let llm_client = LLMClient::new(provider_type).await?;
25
26 Ok(Self {
27 name: name.into(),
28 llm_client,
29 tool_registry: ToolRegistry::new(),
30 chat_session: ChatSession::new(),
31 max_iterations: 10,
32 })
33 }
34
35 pub fn builder(name: impl Into<String>) -> AgentBuilder {
36 AgentBuilder::new(name)
37 }
38
39 pub fn name(&self) -> &str {
40 &self.name
41 }
42
43 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
44 self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
45 }
46
47 pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
48 self.tool_registry.register(tool);
49 }
50
51 pub fn tool_registry(&self) -> &ToolRegistry {
52 &self.tool_registry
53 }
54
55 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
56 &mut self.tool_registry
57 }
58
59 pub fn chat_session(&self) -> &ChatSession {
60 &self.chat_session
61 }
62
63 pub fn chat_session_mut(&mut self) -> &mut ChatSession {
64 &mut self.chat_session
65 }
66
67 pub fn clear_history(&mut self) {
68 self.chat_session.clear();
69 }
70
71 pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
72 let user_message = message.into();
73 self.chat_session.add_user_message(user_message.clone());
74
75 let response = self.execute_with_tools().await?;
77
78 Ok(response)
79 }
80
81 async fn execute_with_tools(&mut self) -> Result<String> {
82 let mut iterations = 0;
83 let tool_definitions = self.tool_registry.get_definitions();
84
85 loop {
86 if iterations >= self.max_iterations {
87 return Err(HeliosError::AgentError(
88 "Maximum iterations reached".to_string(),
89 ));
90 }
91
92 let messages = self.chat_session.get_messages();
93 let tools_option = if tool_definitions.is_empty() {
94 None
95 } else {
96 Some(tool_definitions.clone())
97 };
98
99 let response = self.llm_client.chat(messages, tools_option).await?;
100
101 if let Some(ref tool_calls) = response.tool_calls {
103 self.chat_session.add_message(response.clone());
105
106 for tool_call in tool_calls {
108 let tool_name = &tool_call.function.name;
109 let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
110 .unwrap_or(Value::Object(serde_json::Map::new()));
111
112 let tool_result = self
113 .tool_registry
114 .execute(tool_name, tool_args)
115 .await
116 .unwrap_or_else(|e| {
117 ToolResult::error(format!("Tool execution failed: {}", e))
118 });
119
120 let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
122 self.chat_session.add_message(tool_message);
123 }
124
125 iterations += 1;
126 continue;
127 }
128
129 self.chat_session.add_message(response.clone());
131 return Ok(response.content);
132 }
133 }
134
135 pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
136 self.send_message(message).await
137 }
138
139 pub fn set_max_iterations(&mut self, max: usize) {
140 self.max_iterations = max;
141 }
142
143 pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
145 self.chat_session.set_metadata(key, value);
146 }
147
148 pub fn get_memory(&self, key: &str) -> Option<&String> {
149 self.chat_session.get_metadata(key)
150 }
151
152 pub fn remove_memory(&mut self, key: &str) -> Option<String> {
153 self.chat_session.remove_metadata(key)
154 }
155
156 pub fn get_session_summary(&self) -> String {
157 self.chat_session.get_summary()
158 }
159
160 pub fn clear_memory(&mut self) {
161 self.chat_session.metadata.clear();
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    #[tokio::test]
    async fn test_agent_new() {
        let config = Config::new_default();
        let agent = Agent::new("test_agent", config).await;
        assert!(agent.is_ok());
    }

    #[tokio::test]
    async fn test_agent_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    /// `build` must fail when no config was supplied.
    #[tokio::test]
    async fn test_agent_builder_requires_config() {
        let result = Agent::builder("test_agent").build().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_agent_system_prompt() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();
        agent.set_system_prompt("You are a test agent");

        let session = agent.chat_session();
        assert_eq!(
            session.system_prompt,
            Some("You are a test agent".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_tool_registry() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();

        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    /// A custom tool implementation can be registered like a built-in one.
    #[tokio::test]
    async fn test_agent_register_mock_tool() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();

        agent.register_tool(Box::new(MockTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["mock_tool".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_set_max_iterations() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();

        agent.set_max_iterations(3);
        assert_eq!(agent.max_iterations, 3);
    }

    #[tokio::test]
    async fn test_agent_clear_history() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Minimal tool that echoes its `input` argument (or "default").
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let mut params = HashMap::new();
            params.insert(
                "input".to_string(),
                ToolParameter {
                    param_type: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: Some(true),
                },
            );
            params
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            let input = args
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }

    /// `MockTool` echoes the provided input and falls back to "default".
    #[tokio::test]
    async fn test_mock_tool_execute() {
        let tool = MockTool;

        let with_input = tool
            .execute(serde_json::json!({ "input": "hello" }))
            .await
            .unwrap();
        assert_eq!(with_input.output, "Mock tool output: hello");

        let without_input = tool.execute(serde_json::json!({})).await.unwrap();
        assert_eq!(without_input.output, "Mock tool output: default");
    }
}
279
280pub struct AgentBuilder {
281 name: String,
282 config: Option<Config>,
283 system_prompt: Option<String>,
284 tools: Vec<Box<dyn crate::tools::Tool>>,
285 max_iterations: usize,
286}
287
288impl AgentBuilder {
289 pub fn new(name: impl Into<String>) -> Self {
290 Self {
291 name: name.into(),
292 config: None,
293 system_prompt: None,
294 tools: Vec::new(),
295 max_iterations: 10,
296 }
297 }
298
299 pub fn config(mut self, config: Config) -> Self {
300 self.config = Some(config);
301 self
302 }
303
304 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
305 self.system_prompt = Some(prompt.into());
306 self
307 }
308
309 pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
310 self.tools.push(tool);
311 self
312 }
313
314 pub fn max_iterations(mut self, max: usize) -> Self {
315 self.max_iterations = max;
316 self
317 }
318
319 pub async fn build(self) -> Result<Agent> {
320 let config = self
321 .config
322 .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
323
324 let mut agent = Agent::new(self.name, config).await?;
325
326 if let Some(prompt) = self.system_prompt {
327 agent.set_system_prompt(prompt);
328 }
329
330 for tool in self.tools {
331 agent.register_tool(tool);
332 }
333
334 agent.set_max_iterations(self.max_iterations);
335
336 Ok(agent)
337 }
338}