helios_engine/
agent.rs

//! # Agent Module
//!
//! This module defines the [`Agent`] struct, the core of the Helios Engine. Agents are
//! autonomous entities that can interact with users, use tools, and manage their
//! own chat history. The [`AgentBuilder`] provides a convenient way to construct
//! and configure agents.
7
8#![allow(dead_code)]
9#![allow(unused_variables)]
10use crate::chat::{ChatMessage, ChatSession};
11use crate::config::Config;
12use crate::error::{HeliosError, Result};
13use crate::llm::{LLMClient, LLMProviderType};
14use crate::tools::{ToolRegistry, ToolResult};
15use serde_json::Value;
16
/// Namespace placed in front of every agent-owned key in the chat session's
/// metadata, so agent memory never collides with general session metadata.
const AGENT_MEMORY_PREFIX: &str = "agent:";
19
20/// Represents an LLM-powered agent that can chat, use tools, and manage a conversation.
21pub struct Agent {
22    /// The name of the agent.
23    name: String,
24    /// The client for interacting with the Large Language Model.
25    llm_client: LLMClient,
26    /// The registry of tools available to the agent.
27    tool_registry: ToolRegistry,
28    /// The chat session, which stores the conversation history.
29    chat_session: ChatSession,
30    /// The maximum number of iterations for tool execution in a single turn.
31    max_iterations: usize,
32}
33
34impl Agent {
35    /// Creates a new agent with the given name and configuration.
36    ///
37    /// # Arguments
38    ///
39    /// * `name` - The name of the agent.
40    /// * `config` - The configuration for the agent.
41    ///
42    /// # Returns
43    ///
44    /// A `Result` containing the new `Agent` instance.
45    async fn new(name: impl Into<String>, config: Config) -> Result<Self> {
46        let provider_type = if let Some(local_config) = config.local {
47            LLMProviderType::Local(local_config)
48        } else {
49            LLMProviderType::Remote(config.llm)
50        };
51
52        let llm_client = LLMClient::new(provider_type).await?;
53
54        Ok(Self {
55            name: name.into(),
56            llm_client,
57            tool_registry: ToolRegistry::new(),
58            chat_session: ChatSession::new(),
59            max_iterations: 10,
60        })
61    }
62
63    /// Returns a new `AgentBuilder` for constructing an agent.
64    ///
65    /// # Arguments
66    ///
67    /// * `name` - The name of the agent.
68    pub fn builder(name: impl Into<String>) -> AgentBuilder {
69        AgentBuilder::new(name)
70    }
71
72    /// Returns the name of the agent.
73    pub fn name(&self) -> &str {
74        &self.name
75    }
76
77    /// Sets the system prompt for the agent.
78    ///
79    /// # Arguments
80    ///
81    /// * `prompt` - The system prompt to set.
82    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
83        self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
84    }
85
86    /// Registers a tool with the agent.
87    ///
88    /// # Arguments
89    ///
90    /// * `tool` - The tool to register.
91    pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
92        self.tool_registry.register(tool);
93    }
94
95    /// Returns a reference to the agent's tool registry.
96    pub fn tool_registry(&self) -> &ToolRegistry {
97        &self.tool_registry
98    }
99
100    /// Returns a mutable reference to the agent's tool registry.
101    pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
102        &mut self.tool_registry
103    }
104
105    /// Returns a reference to the agent's chat session.
106    pub fn chat_session(&self) -> &ChatSession {
107        &self.chat_session
108    }
109
110    /// Returns a mutable reference to the agent's chat session.
111    pub fn chat_session_mut(&mut self) -> &mut ChatSession {
112        &mut self.chat_session
113    }
114
115    /// Clears the agent's chat history.
116    pub fn clear_history(&mut self) {
117        self.chat_session.clear();
118    }
119
120    /// Sends a message to the agent and gets a response.
121    ///
122    /// # Arguments
123    ///
124    /// * `message` - The message to send.
125    ///
126    /// # Returns
127    ///
128    /// A `Result` containing the agent's response.
129    pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
130        let user_message = message.into();
131        self.chat_session.add_user_message(user_message.clone());
132
133        // Execute agent loop with tool calling
134        let response = self.execute_with_tools().await?;
135
136        Ok(response)
137    }
138
139    /// Executes the agent's main loop, including tool calls.
140    async fn execute_with_tools(&mut self) -> Result<String> {
141        let mut iterations = 0;
142        let tool_definitions = self.tool_registry.get_definitions();
143
144        loop {
145            if iterations >= self.max_iterations {
146                return Err(HeliosError::AgentError(
147                    "Maximum iterations reached".to_string(),
148                ));
149            }
150
151            let messages = self.chat_session.get_messages();
152            let tools_option = if tool_definitions.is_empty() {
153                None
154            } else {
155                Some(tool_definitions.clone())
156            };
157
158            let response = self.llm_client.chat(messages, tools_option).await?;
159
160            // Check if the response includes tool calls
161            if let Some(ref tool_calls) = response.tool_calls {
162                // Add assistant message with tool calls
163                self.chat_session.add_message(response.clone());
164
165                // Execute each tool call
166                for tool_call in tool_calls {
167                    let tool_name = &tool_call.function.name;
168                    let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
169                        .unwrap_or(Value::Object(serde_json::Map::new()));
170
171                    let tool_result = self
172                        .tool_registry
173                        .execute(tool_name, tool_args)
174                        .await
175                        .unwrap_or_else(|e| {
176                            ToolResult::error(format!("Tool execution failed: {}", e))
177                        });
178
179                    // Add tool result message
180                    let tool_message = ChatMessage::tool(tool_result.output, tool_call.id.clone());
181                    self.chat_session.add_message(tool_message);
182                }
183
184                iterations += 1;
185                continue;
186            }
187
188            // No tool calls, we have the final response
189            self.chat_session.add_message(response.clone());
190            return Ok(response.content);
191        }
192    }
193
194    /// A convenience method for sending a message to the agent.
195    pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
196        self.send_message(message).await
197    }
198
199    /// Sets the maximum number of iterations for tool execution.
200    ///
201    /// # Arguments
202    ///
203    /// * `max` - The maximum number of iterations.
204    pub fn set_max_iterations(&mut self, max: usize) {
205        self.max_iterations = max;
206    }
207
208    /// Returns a summary of the current chat session.
209    pub fn get_session_summary(&self) -> String {
210        self.chat_session.get_summary()
211    }
212
213    /// Clears the agent's memory (agent-scoped metadata).
214    pub fn clear_memory(&mut self) {
215        // Only clear agent-scoped memory keys to avoid wiping general session metadata
216        self.chat_session
217            .metadata
218            .retain(|k, _| !k.starts_with(AGENT_MEMORY_PREFIX));
219    }
220
221    /// Prefixes a key with the agent memory prefix.
222    #[inline]
223    fn prefixed_key(key: &str) -> String {
224        format!("{}{}", AGENT_MEMORY_PREFIX, key)
225    }
226
227    // Agent-scoped memory API (namespaced under "agent:")
228    /// Sets a value in the agent's memory.
229    pub fn set_memory(&mut self, key: impl Into<String>, value: impl Into<String>) {
230        let key = key.into();
231        self.chat_session
232            .set_metadata(Self::prefixed_key(&key), value);
233    }
234
235    /// Gets a value from the agent's memory.
236    pub fn get_memory(&self, key: &str) -> Option<&String> {
237        self.chat_session.get_metadata(&Self::prefixed_key(key))
238    }
239
240    /// Removes a value from the agent's memory.
241    pub fn remove_memory(&mut self, key: &str) -> Option<String> {
242        self.chat_session.remove_metadata(&Self::prefixed_key(key))
243    }
244
245    // Convenience helpers to reduce duplication in examples
246    /// Increments a counter in the agent's memory.
247    pub fn increment_counter(&mut self, key: &str) -> u32 {
248        let current = self
249            .get_memory(key)
250            .and_then(|v| v.parse::<u32>().ok())
251            .unwrap_or(0);
252        let next = current + 1;
253        self.set_memory(key, next.to_string());
254        next
255    }
256
257    /// Increments the "tasks_completed" counter in the agent's memory.
258    pub fn increment_tasks_completed(&mut self) -> u32 {
259        self.increment_counter("tasks_completed")
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{CalculatorTool, Tool, ToolParameter, ToolResult};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Tests that an agent can be created using the builder.
    #[tokio::test]
    async fn test_agent_creation_via_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent").config(config).build().await;
        assert!(agent.is_ok());
    }

    /// Tests that building without a config is reported as an error, not a panic.
    #[tokio::test]
    async fn test_agent_builder_requires_config() {
        let agent = Agent::builder("test_agent").build().await;
        assert!(agent.is_err());
    }

    /// Tests the namespacing of agent memory.
    #[tokio::test]
    async fn test_agent_memory_namespacing_set_get_remove() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        // Set and get namespaced memory
        agent.set_memory("working_directory", "/tmp");
        assert_eq!(
            agent.get_memory("working_directory"),
            Some(&"/tmp".to_string())
        );

        // Ensure underlying chat session stored the prefixed key
        assert_eq!(
            agent.chat_session().get_metadata("agent:working_directory"),
            Some(&"/tmp".to_string())
        );
        // Non-prefixed key should not exist
        assert!(agent
            .chat_session()
            .get_metadata("working_directory")
            .is_none());

        // Remove should also be namespaced
        let removed = agent.remove_memory("working_directory");
        assert_eq!(removed.as_deref(), Some("/tmp"));
        assert!(agent.get_memory("working_directory").is_none());
    }

    /// Tests that clearing agent memory only affects agent-scoped data.
    #[tokio::test]
    async fn test_agent_clear_memory_scoped() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        // Set an agent memory and a general (non-agent) session metadata key
        agent.set_memory("tasks_completed", "3");
        agent
            .chat_session_mut()
            .set_metadata("session_start", "now");

        // Clear only agent-scoped memory
        agent.clear_memory();

        // Agent memory removed
        assert!(agent.get_memory("tasks_completed").is_none());
        // General session metadata preserved
        assert_eq!(
            agent.chat_session().get_metadata("session_start"),
            Some(&"now".to_string())
        );
    }

    /// Tests the increment helper methods for agent memory.
    #[tokio::test]
    async fn test_agent_increment_helpers() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        // tasks_completed increments from 0
        let n1 = agent.increment_tasks_completed();
        assert_eq!(n1, 1);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"1".to_string()));

        let n2 = agent.increment_tasks_completed();
        assert_eq!(n2, 2);
        assert_eq!(agent.get_memory("tasks_completed"), Some(&"2".to_string()));

        // generic counter
        let f1 = agent.increment_counter("files_accessed");
        assert_eq!(f1, 1);
        let f2 = agent.increment_counter("files_accessed");
        assert_eq!(f2, 2);
        assert_eq!(agent.get_memory("files_accessed"), Some(&"2".to_string()));
    }

    /// Tests the full functionality of the agent builder.
    #[tokio::test]
    async fn test_agent_builder() {
        let config = Config::new_default();
        let agent = Agent::builder("test_agent")
            .config(config)
            .system_prompt("You are a helpful assistant")
            .max_iterations(5)
            .tool(Box::new(CalculatorTool))
            .build()
            .await
            .unwrap();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    /// Tests that the iteration cap can be changed after construction.
    #[tokio::test]
    async fn test_agent_set_max_iterations() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        agent.set_max_iterations(3);
        assert_eq!(agent.max_iterations, 3);
    }

    /// Tests setting the system prompt for an agent.
    #[tokio::test]
    async fn test_agent_system_prompt() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();
        agent.set_system_prompt("You are a test agent");

        // Check that the system prompt is set in chat session
        let session = agent.chat_session();
        assert_eq!(
            session.system_prompt,
            Some("You are a test agent".to_string())
        );
    }

    /// Tests the tool registry functionality of an agent.
    #[tokio::test]
    async fn test_agent_tool_registry() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        // Initially no tools
        assert!(agent.tool_registry().list_tools().is_empty());

        // Register a tool
        agent.register_tool(Box::new(CalculatorTool));
        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["calculator".to_string()]
        );
    }

    /// Tests that a registered tool can be executed through the agent's registry.
    #[tokio::test]
    async fn test_agent_executes_registered_mock_tool() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .tool(Box::new(MockTool))
            .build()
            .await
            .unwrap();

        assert_eq!(
            agent.tool_registry().list_tools(),
            vec!["mock_tool".to_string()]
        );

        let result = agent
            .tool_registry_mut()
            .execute("mock_tool", serde_json::json!({ "input": "hello" }))
            .await
            .unwrap();
        assert_eq!(result.output, "Mock tool output: hello");
    }

    /// Tests clearing the chat history of an agent.
    #[tokio::test]
    async fn test_agent_clear_history() {
        let config = Config::new_default();
        let mut agent = Agent::builder("test_agent")
            .config(config)
            .build()
            .await
            .unwrap();

        // Add a message to the chat session
        agent.chat_session_mut().add_user_message("Hello");
        assert!(!agent.chat_session().messages.is_empty());

        // Clear history
        agent.clear_history();
        assert!(agent.chat_session().messages.is_empty());
    }

    /// Minimal tool that echoes its `input` argument, used to exercise tool execution.
    struct MockTool;

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> HashMap<String, ToolParameter> {
            let mut params = HashMap::new();
            params.insert(
                "input".to_string(),
                ToolParameter {
                    param_type: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: Some(true),
                },
            );
            params
        }

        async fn execute(&self, args: Value) -> crate::Result<ToolResult> {
            let input = args
                .get("input")
                .and_then(|v| v.as_str())
                .unwrap_or("default");
            Ok(ToolResult::success(format!("Mock tool output: {}", input)))
        }
    }
}
483
484pub struct AgentBuilder {
485    name: String,
486    config: Option<Config>,
487    system_prompt: Option<String>,
488    tools: Vec<Box<dyn crate::tools::Tool>>,
489    max_iterations: usize,
490}
491
492impl AgentBuilder {
493    pub fn new(name: impl Into<String>) -> Self {
494        Self {
495            name: name.into(),
496            config: None,
497            system_prompt: None,
498            tools: Vec::new(),
499            max_iterations: 10,
500        }
501    }
502
503    pub fn config(mut self, config: Config) -> Self {
504        self.config = Some(config);
505        self
506    }
507
508    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
509        self.system_prompt = Some(prompt.into());
510        self
511    }
512
513    pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
514        self.tools.push(tool);
515        self
516    }
517
518    pub fn max_iterations(mut self, max: usize) -> Self {
519        self.max_iterations = max;
520        self
521    }
522
523    pub async fn build(self) -> Result<Agent> {
524        let config = self
525            .config
526            .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
527
528        let mut agent = Agent::new(self.name, config).await?;
529
530        if let Some(prompt) = self.system_prompt {
531            agent.set_system_prompt(prompt);
532        }
533
534        for tool in self.tools {
535            agent.register_tool(tool);
536        }
537
538        agent.set_max_iterations(self.max_iterations);
539
540        Ok(agent)
541    }
542}