helios_engine/
agent.rs

1#![allow(dead_code)]
2#![allow(unused_variables)]
3use crate::chat::{ChatMessage, ChatSession};
4use crate::config::Config;
5use crate::error::{HeliosError, Result};
6use crate::llm::{LLMClient, LLMProviderType};
7use crate::tools::{ToolRegistry, ToolResult};
8use serde_json::Value;
9
/// Key prefix that namespaces agent-scoped memory inside the chat session's metadata map,
/// so `Agent::clear_memory` can drop agent keys without touching other session metadata.
const AGENT_MEMORY_PREFIX: &str = "agent:";
11
/// An LLM-backed agent that keeps a chat transcript and can call registered tools.
///
/// Construct one with [`Agent::builder`].
pub struct Agent {
    /// Name supplied at construction; returned by [`Agent::name`].
    name: String,
    /// Client used for every model request made by the tool-calling loop.
    llm_client: LLMClient,
    /// Tools whose definitions are offered to the model and executed on its request.
    tool_registry: ToolRegistry,
    /// Conversation history; its metadata map also backs the agent-scoped memory API.
    chat_session: ChatSession,
    /// Maximum number of tool-calling rounds allowed while answering one message.
    max_iterations: usize,
}
19
20impl Agent {
21    async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
22        let provider_type = if let Some(local_config) = config.local {
23            LLMProviderType::Local(local_config)
24        } else {
25            LLMProviderType::Remote(config.llm)
26        };
27
28        let llm_client = LLMClient::new(provider_type).await?;
29
30        Ok(Self {
31            name: name.into(),
32            llm_client,
33            tool_registry: ToolRegistry::new(),
34            chat_session: ChatSession::new(),
35            max_iterations: 10,
36        })
37    }
38
39    pub fn builder(name: impl Into<String>) -> AgentBuilder {
40        AgentBuilder::new(name)
41    }
42
43    pub fn name(&self) -> &str {
44        &self.name
45    }
46
47    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
48        self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
49    }
50
51    pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
52        self.tool_registry.register(tool);
53    }
54
55    pub fn tool_registry(&self) -> &ToolRegistry {
56        &self.tool_registry
57    }
58
59    pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
60        &mut self.tool_registry
61    }
62
63    pub fn chat_session(&self) -> &ChatSession {
64        &self.chat_session
65    }
66
67    pub fn chat_session_mut(&mut self) -> &mut ChatSession {
68        &mut self.chat_session
69    }
70
71    pub fn clear_history(&mut self) {
72        self.chat_session.clear();
73    }
74
75    pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
76        let user_message = message.into();
77        self.chat_session.add_user_message(user_message.clone());
78
79        // Execute agent loop with tool calling
80        let response = self.execute_with_tools().await?;
81
82        Ok(response)
83    }
84
85    async fn execute_with_tools(&mut self) -> Result<String> {
86        let mut iterations = 0;
87        let tool_definitions = self.tool_registry.get_definitions();
88
89        loop {
90            if iterations >= self.max_iterations {
91                return Err(HeliosError::AgentError(
92                    "Maximum iterations reached".to_string(),
93                ));
94            }
95
96            let messages = self.chat_session.get_messages();
97            let tools_option = if tool_definitions.is_empty() {
98                None
99            } else {
100                Some(tool_definitions.clone())
101            };
102
103            let response = self.llm_client.chat(messages, tools_option).await?;
104
105            // Check if the response includes tool calls
106            if let Some(ref tool_calls) = response.tool_calls {
107                // Add assistant message with tool calls
108                self.chat_session.add_message(response.clone());
109
110                // Execute each tool call
111                for tool_call in tool_calls {
112                    let tool_name = &tool_call.function.name;
113                    let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
114                        .unwrap_or(Value::Object(serde_json::Map::new()));
115
116                    let tool_result = self
117                        .tool_registry
118                        .execute(tool_name, tool_args)
119                        .await
120                        .unwrap_or_else(|e| {
121                            ToolResult::error(format!("Tool execution failed: {}", e))
122                        });
123
124                    // Add tool result message
125                    let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
126                    self.chat_session.add_message(tool_message);
127                }
128
129                iterations += 1;
130                continue;
131            }
132
133            // No tool calls, we have the final response
134            self.chat_session.add_message(response.clone());
135            return Ok(response.content);
136        }
137    }
138
139    pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
140        self.send_message(message).await
141    }
142
143    pub fn set_max_iterations(&mut self, max: usize) {
144        self.max_iterations = max;
145    }
146
147    pub fn get_session_summary(&self) -> String {
148        self.chat_session.get_summary()
149    }
150
151    pub fn clear_memory(&mut self) {
152        // Only clear agent-scoped memory keys to avoid wiping general session metadata
153        self.chat_session
154            .metadata
155            .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
156    }
157
158    #[inline]
159    fn prefixed_key(key: &str) -> String {
160        format!("{}{}", AGENT_MEMORY_PREFIX, key)
161    }
162
163    // Agent-scoped memory API (namespaced under "agent:")
164    pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
165        let key = key.into();
166        self.chat_session
167            .set_metadata(Self::prefixed_key(&key), value);
168    }
169
170    pub fn get_memory(&self, key: &str) -> Option<&String> {
171        self.chat_session.get_metadata(&Self::prefixed_key(key))
172    }
173
174    pub fn remove_memory(&mut self, key: &str) -> Option<String> {
175        self.chat_session.remove_metadata(&Self::prefixed_key(key))
176    }
177
178    // Convenience helpers to reduce duplication in examples
179    pub fn increment_counter(&mut self, key: &str) -> u32 {
180        let current = self
181            .get_memory(key)
182            .and_then(|v| v.parse::<u32>().ok())
183            .unwrap_or(0);
184        let next = current + 1;
185        self.set_memory(key, next.to_string());
186        next
187    }
188
189    pub fn increment_tasks_completed(&mut self) -> u32 {
190        self.increment_counter("tasks_completed")
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Builds an agent named "test_agent" from the default config, panicking on failure.
    async fn build_test_agent() -> Agent {
        Agent::builder("test_agent")
            .config(Config::new_default())
            .build()
            .await
            .expect("agent should build from the default config")
    }

    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await;
        assert!(agent.is_ok());
    }

    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let mut agent = build_test_agent().await;

        // Set and get namespaced memory
        agent.set_memory("working_directory", "/tmp");
        assert_eq!(
            agent.get_memory("working_directory"),
            Some(&"/tmp".to_string())
        );

        // The underlying chat session stores the prefixed key...
        assert_eq!(
            agent
                .chat_session()
                .get_metadata("agent:working_directory"),
            Some(&"/tmp".to_string())
        );
        // ...and not the bare one.
        assert!(agent.chat_session().get_metadata("working_directory").is_none());

        // Removal is namespaced as well.
        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let mut agent = build_test_agent().await;

        // One agent memory key and one general (non-agent) session metadata key.
        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        agent.clear_memory();

        // Agent memory removed, general session metadata preserved.
        assert!(agent.get_memory("tasks_completed").is_none());
        assert_eq!(
            agent.chat_session().get_metadata("session_start"),
            Some(&"now".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let mut agent = build_test_agent().await;

        // tasks_completed increments from 0
        assert_eq!(agent.increment_tasks_completed(), 1);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"1".to_string()));

        assert_eq!(agent.increment_tasks_completed(), 2);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"2".to_string()));

        // generic counter
        assert_eq!(agent.increment_counter("files_accessed"), 1);
        assert_eq!(agent.increment_counter("files_accessed"), 2);
        assert_eq!(agent.get_memory("files_accessed"), Some(&"2".to_string()));
    }

    #[tokio::test]
    async fn test_agent_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_system_prompt() {
        let mut agent = build_test_agent().await;
        agent.set_system_prompt("You are a test agent");

        assert_eq!(
            agent.chat_session().system_prompt,
            Some("You are a test agent".to_string())
        );
    }

    #[tokio::test]
    async fn test_agent_tool_registry() {
        let mut agent = build_test_agent().await;

        // Initially no tools
        assert!(agent.tool_registry().list_tools().is_empty());

        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    #[tokio::test]
    async fn test_agent_clear_history() {
        let mut agent = build_test_agent().await;

        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    #[tokio::test]
    async fn test_mock_tool_registration_and_execution() {
        let mut agent = build_test_agent().await;

        agent.register_tool(Box::new(MockTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["mock_tool".to_string()]
        );

        let with_input = MockTool
            .execute(serde_json::json!({ "input": "hello" }))
            .await
            .unwrap();
        assert_eq!(with_input.output, "Mock tool output: hello");

        // A missing "input" argument falls back to "default".
        let without_input = MockTool.execute(serde_json::json!({})).await.unwrap();
        assert_eq!(without_input.output, "Mock tool output: default");
    }

    /// Minimal tool that echoes its "input" string argument.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let mut params = HashMap::new();
            params.insert(
                "input".to_string(),
                ToolParameter {
                    param_type: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: Some(true),
                },
            );
            params
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            let input = args
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}
408
409pub struct AgentBuilder {
410    name: String,
411    config: Option<Config>,
412    system_prompt: Option<String>,
413    tools: Vec<Box<dyn crate::tools::Tool>>,
414    max_iterations: usize,
415}
416
417impl AgentBuilder {
418    pub fn new(name: impl Into<String>) -> Self {
419        Self {
420            name: name.into(),
421            config: None,
422            system_prompt: None,
423            tools: Vec::new(),
424            max_iterations: 10,
425        }
426    }
427
428    pub fn config(mut self, config: Config) -> Self {
429        self.config = Some(config);
430        self
431    }
432
433    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
434        self.system_prompt = Some(prompt.into());
435        self
436    }
437
438    pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
439        self.tools.push(tool);
440        self
441    }
442
443    pub fn max_iterations(mut self, max: usize) -> Self {
444        self.max_iterations = max;
445        self
446    }
447
448    pub async fn build(self) -> Result<Agent> {
449        let config = self
450            .config
451            .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
452
453        let mut agent = Agent::new(self.name, config).await?;
454
455        if let Some(prompt) = self.system_prompt {
456            agent.set_system_prompt(prompt);
457        }
458
459        for tool in self.tools {
460            agent.register_tool(tool);
461        }
462
463        agent.set_max_iterations(self.max_iterations);
464
465        Ok(agent)
466    }
467}