helios_engine/
agent.rs

1use crate::chat::{ChatMessage, ChatSession};
2use crate::config::Config;
3use crate::error::{HeliosError, Result};
4use crate::llm::LLMClient;
5use crate::tools::{ToolRegistry, ToolResult};
6use serde_json::Value;
7
8pub struct Agent {
9    name: String,
10    llm_client: LLMClient,
11    tool_registry: ToolRegistry,
12    chat_session: ChatSession,
13    max_iterations: usize,
14}
15
16impl Agent {
17    pub fn new(name: impl Into<String>, config: Config) -> Self {
18        Self {
19            name: name.into(),
20            llm_client: LLMClient::new(config.llm),
21            tool_registry: ToolRegistry::new(),
22            chat_session: ChatSession::new(),
23            max_iterations: 10,
24        }
25    }
26
27    pub fn builder(name: impl Into<String>) -> AgentBuilder {
28        AgentBuilder::new(name)
29    }
30
31    pub fn name(&self) -> &str {
32        &self.name
33    }
34
35    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
36        self.chat_session = self.chat_session.clone().with_system_prompt(prompt);
37    }
38
39    pub fn register_tool(&mut self, tool: Box<dyn crate::tools::Tool>) {
40        self.tool_registry.register(tool);
41    }
42
43    pub fn tool_registry(&self) -> &ToolRegistry {
44        &self.tool_registry
45    }
46
47    pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry {
48        &mut self.tool_registry
49    }
50
51    pub fn chat_session(&self) -> &ChatSession {
52        &self.chat_session
53    }
54
55    pub fn chat_session_mut(&mut self) -> &mut ChatSession {
56        &mut self.chat_session
57    }
58
59    pub fn clear_history(&mut self) {
60        self.chat_session.clear();
61    }
62
63    pub async fn send_message(&mut self, message: impl Into<String>) -> Result<String> {
64        let user_message = message.into();
65        self.chat_session.add_user_message(user_message.clone());
66
67        // Execute agent loop with tool calling
68        let response = self.execute_with_tools().await?;
69        
70        Ok(response)
71    }
72
73    async fn execute_with_tools(&mut self) -> Result<String> {
74        let mut iterations = 0;
75        let tool_definitions = self.tool_registry.get_definitions();
76
77        loop {
78            if iterations >= self.max_iterations {
79                return Err(HeliosError::AgentError(
80                    "Maximum iterations reached".to_string(),
81                ));
82            }
83
84            let messages = self.chat_session.get_messages();
85            let tools_option = if tool_definitions.is_empty() {
86                None
87            } else {
88                Some(tool_definitions.clone())
89            };
90
91            let response = self.llm_client.chat(messages, tools_option).await?;
92
93            // Check if the response includes tool calls
94            if let Some(ref tool_calls) = response.tool_calls {
95                // Add assistant message with tool calls
96                self.chat_session.add_message(response.clone());
97
98                // Execute each tool call
99                for tool_call in tool_calls {
100                    let tool_name = &tool_call.function.name;
101                    let tool_args: Value = serde_json::from_str(&tool_call.function.arguments)
102                        .unwrap_or(Value::Object(serde_json::Map::new()));
103
104                    let tool_result = self
105                        .tool_registry
106                        .execute(tool_name, tool_args)
107                        .await
108                        .unwrap_or_else(|e| {
109                            ToolResult::error(format!("Tool execution failed: {}", e))
110                        });
111
112                    // Add tool result message
113                    let tool_message = ChatMessage::tool(
114                        tool_result.output,
115                        tool_call.id.clone(),
116                    );
117                    self.chat_session.add_message(tool_message);
118                }
119
120                iterations += 1;
121                continue;
122            }
123
124            // No tool calls, we have the final response
125            self.chat_session.add_message(response.clone());
126            return Ok(response.content);
127        }
128    }
129
130    pub async fn chat(&mut self, message: impl Into<String>) -> Result<String> {
131        self.send_message(message).await
132    }
133
134    pub fn set_max_iterations(&mut self, max: usize) {
135        self.max_iterations = max;
136    }
137}
138
139pub struct AgentBuilder {
140    name: String,
141    config: Option<Config>,
142    system_prompt: Option<String>,
143    tools: Vec<Box<dyn crate::tools::Tool>>,
144    max_iterations: usize,
145}
146
147impl AgentBuilder {
148    pub fn new(name: impl Into<String>) -> Self {
149        Self {
150            name: name.into(),
151            config: None,
152            system_prompt: None,
153            tools: Vec::new(),
154            max_iterations: 10,
155        }
156    }
157
158    pub fn config(mut self, config: Config) -> Self {
159        self.config = Some(config);
160        self
161    }
162
163    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
164        self.system_prompt = Some(prompt.into());
165        self
166    }
167
168    pub fn tool(mut self, tool: Box<dyn crate::tools::Tool>) -> Self {
169        self.tools.push(tool);
170        self
171    }
172
173    pub fn max_iterations(mut self, max: usize) -> Self {
174        self.max_iterations = max;
175        self
176    }
177
178    pub fn build(self) -> Result<Agent> {
179        let config = self
180            .config
181            .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
182
183        let mut agent = Agent::new(self.name, config);
184
185        if let Some(prompt) = self.system_prompt {
186            agent.set_system_prompt(prompt);
187        }
188
189        for tool in self.tools {
190            agent.register_tool(tool);
191        }
192
193        agent.set_max_iterations(self.max_iterations);
194
195        Ok(agent)
196    }
197}