helios_engine/
agent.rs

1use crate::chat::{ChatMessage, ChatSession};
2use crate::config::Config;
3use crate::error::{HeliosError, Result};
4use crate::llm::{LLMClient, LLMProviderType};
5use crate::tools::{ToolRegistry, ToolResult};
6use serde_json::Value;
7
8pub struct Agent {
9    name: String,
10    llm_client: LLMClient,
11    tool_registry: ToolRegistry,
12    chat_session: ChatSession,
13    max_iterations: usize,
14}
15
16impl Agent {
17    pub async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
18        let provider_type = if let Some(local_config) = config.local {
19            LLMProviderType::Local(local_config)
20        } else {
21            LLMProviderType::Remote(config.llm)
22        };
23
24        let llm_client = LLMClient::new(provider_type).await?;
25
26        Ok(Self {
27            name: name.into(),
28            llm_client,
29            tool_registry: ToolRegistry::new(),
30            chat_session: ChatSession::new(),
31            max_iterations: 10,
32        })
33    }
34
35    pub fn builder(name: impl Into<String>) -> AgentBuilder {
36        AgentBuilder::new(name)
37    }
38
39    pub fn name(&self) -> &str {
40        &self.name
41    }
42
43    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
44        self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
45    }
46
47    pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
48        self.tool_registry.register(tool);
49    }
50
51    pub fn tool_registry(&self) -> &ToolRegistry {
52        &self.tool_registry
53    }
54
55    pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
56        &mut self.tool_registry
57    }
58
59    pub fn chat_session(&self) -> &ChatSession {
60        &self.chat_session
61    }
62
63    pub fn chat_session_mut(&mut self) -> &mut ChatSession {
64        &mut self.chat_session
65    }
66
67    pub fn clear_history(&mut self) {
68        self.chat_session.clear();
69    }
70
71    pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
72        let user_message = message.into();
73        self.chat_session.add_user_message(user_message.clone());
74
75        // Execute agent loop with tool calling
76        let response = self.execute_with_tools().await?;
77
78        Ok(response)
79    }
80
81    async fn execute_with_tools(&mut self) -> Result<String> {
82        let mut iterations = 0;
83        let tool_definitions = self.tool_registry.get_definitions();
84
85        loop {
86            if iterations >= self.max_iterations {
87                return Err(HeliosError::AgentError(
88                    "Maximum iterations reached".to_string(),
89                ));
90            }
91
92            let messages = self.chat_session.get_messages();
93            let tools_option = if tool_definitions.is_empty() {
94                None
95            } else {
96                Some(tool_definitions.clone())
97            };
98
99            let response = self.llm_client.chat(messages, tools_option).await?;
100
101            // Check if the response includes tool calls
102            if let Some(ref tool_calls) = response.tool_calls {
103                // Add assistant message with tool calls
104                self.chat_session.add_message(response.clone());
105
106                // Execute each tool call
107                for tool_call in tool_calls {
108                    let tool_name = &tool_call.function.name;
109                    let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
110                        .unwrap_or(Value::Object(serde_json::Map::new()));
111
112                    let tool_result = self
113                        .tool_registry
114                        .execute(tool_name, tool_args)
115                        .await
116                        .unwrap_or_else(|e| {
117                            ToolResult::error(format!("Tool execution failed: {}", e))
118                        });
119
120                    // Add tool result message
121                    let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
122                    self.chat_session.add_message(tool_message);
123                }
124
125                iterations += 1;
126                continue;
127            }
128
129            // No tool calls, we have the final response
130            self.chat_session.add_message(response.clone());
131            return Ok(response.content);
132        }
133    }
134
135    pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
136        self.send_message(message).await
137    }
138
139    pub fn set_max_iterations(&mut self, max: usize) {
140        self.max_iterations = max;
141    }
142    
143    // Session memory methods
144    pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
145        self.chat_session.set_metadata(key, value);
146    }
147
148    pub fn get_memory(&self, key: &str) -> Option<&String> {
149        self.chat_session.get_metadata(key)
150    }
151
152    pub fn remove_memory(&mut self, key: &str) -> Option<String> {
153        self.chat_session.remove_metadata(key)
154    }
155
156    pub fn get_session_summary(&self) -> String {
157        self.chat_session.get_summary()
158    }
159
160    pub fn clear_memory(&mut self) {
161        self.chat_session.metadata.clear();
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    #[tokio::test]
    async fn test_agent_new() {
        let config = Config::new_default();
        let agent = Agent::new("test_agent", config).await;
        assert!(agent.is_ok());
    }

    #[tokio::test]
    async fn test_agent_new_default_max_iterations() {
        let config = Config::new_default();
        let agent = Agent::new("test_agent", config).await.unwrap();
        assert_eq!(agent.max_iterations, 10);
    }

    #[tokio::test]
    async fn test_agent_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_builder_requires_config() {
        // `build` must refuse to construct an agent without a config.
        let result = Agent::builder("test_agent").build().await;
        assert!(matches!(result, Err(HeliosError::AgentError(_))));
    }

    #[tokio::test]
    async fn test_agent_system_prompt() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();
        agent.set_system_prompt("You are a test agent");

        // Check that the system prompt is set in chat session
        let session = agent.chat_session();
        assert_eq!(
            session.system_prompt,
            Some("You are a test agent".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_tool_registry() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();

        // Initially no tools
        assert!(agent.tool_registry().list_tools().is_empty());

        // Register a tool
        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_register_mock_tool() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();

        agent.register_tool(Box::new(MockTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["mock_tool".to_string()]
        );
    }

    #[tokio::test]
    async fn test_mock_tool_execute() {
        let tool = MockTool;

        let result = tool
            .execute(serde_json::json!({ "input": "hello" }))
            .await
            .unwrap();
        assert_eq!(result.output, "Mock tool output: hello");

        // Missing "input" falls back to the tool's default.
        let result = tool
            .execute(Value::Object(serde_json::Map::new()))
            .await
            .unwrap();
        assert_eq!(result.output, "Mock tool output: default");
    }

    #[test]
    fn test_mock_tool_metadata() {
        let tool = MockTool;
        assert_eq!(tool.name(), "mock_tool");
        assert_eq!(tool.description(), "A mock tool for testing");

        let params = tool.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params["input"].param_type, "string");
        assert_eq!(params["input"].required, Some(true));
    }

    #[tokio::test]
    async fn test_agent_clear_history() {
        let config = Config::new_default();
        let mut agent = Agent::new("test_agent", config).await.unwrap();

        // Add a message to the chat session
        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        // Clear history
        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Minimal tool that echoes its `input` argument, used to exercise tool
    /// registration and execution without any external dependencies.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let mut params = HashMap::new();
            params.insert(
                "input".to_string(),
                ToolParameter {
                    param_type: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: Some(true),
                },
            );
            params
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            let input = args
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}
279
280pub struct AgentBuilder {
281    name: String,
282    config: Option<Config>,
283    system_prompt: Option<String>,
284    tools: Vec<Box<dyn crate::tools::Tool>>,
285    max_iterations: usize,
286}
287
288impl AgentBuilder {
289    pub fn new(name: impl Into<String>) -> Self {
290        Self {
291            name: name.into(),
292            config: None,
293            system_prompt: None,
294            tools: Vec::new(),
295            max_iterations: 10,
296        }
297    }
298
299    pub fn config(mut self, config: Config) -> Self {
300        self.config = Some(config);
301        self
302    }
303
304    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
305        self.system_prompt = Some(prompt.into());
306        self
307    }
308
309    pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
310        self.tools.push(tool);
311        self
312    }
313
314    pub fn max_iterations(mut self, max: usize) -> Self {
315        self.max_iterations = max;
316        self
317    }
318
319    pub async fn build(self) -> Result<Agent> {
320        let config = self
321            .config
322            .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
323
324        let mut agent = Agent::new(self.name, config).await?;
325
326        if let Some(prompt) = self.system_prompt {
327            agent.set_system_prompt(prompt);
328        }
329
330        for tool in self.tools {
331            agent.register_tool(tool);
332        }
333
334        agent.set_max_iterations(self.max_iterations);
335
336        Ok(agent)
337    }
338}