1use crate::config::AgentConfig;
2use crate::error::{AgentError, Result};
3use crate::event::AgentEvent;
4use crate::response::{AgentResponse, ToolCallRecord};
5use crate::session::Session;
6use hehe_core::message::{ContentBlock, ToolResult, ToolUse};
7use hehe_core::{Context, Message};
8use hehe_llm::{CompletionRequest, LlmProvider};
9use hehe_tools::ToolExecutor;
10use std::sync::Arc;
11use std::time::Instant;
12use tokio::sync::mpsc;
13use tracing::{debug, info, warn};
14
15pub struct Executor {
16 config: AgentConfig,
17 llm: Arc<dyn LlmProvider>,
18 tools: Option<Arc<ToolExecutor>>,
19}
20
21impl Executor {
22 pub fn new(
23 config: AgentConfig,
24 llm: Arc<dyn LlmProvider>,
25 tools: Option<Arc<ToolExecutor>>,
26 ) -> Self {
27 Self { config, llm, tools }
28 }
29
30 pub async fn execute(&self, session: &Session, user_input: &str) -> Result<AgentResponse> {
31 let user_message = Message::user(user_input);
32 session.add_message(user_message);
33
34 let mut all_tool_calls = Vec::new();
35 let mut iterations = 0;
36
37 loop {
38 iterations += 1;
39 session.increment_iterations();
40
41 if iterations > self.config.max_iterations {
42 return Err(AgentError::MaxIterationsReached(self.config.max_iterations));
43 }
44
45 info!(iteration = iterations, "Starting agent loop iteration");
46
47 let request = self.build_request(session);
48 let response = self.llm.complete(request).await?;
49
50 let tool_uses = response.message.tool_uses();
51
52 if tool_uses.is_empty() {
53 let text = response.text_content();
54 session.add_message(Message::assistant(&text));
55
56 return Ok(AgentResponse::new(session.id().clone(), text)
57 .with_tool_calls(all_tool_calls)
58 .with_iterations(iterations));
59 }
60
61 let mut assistant_content = Vec::new();
62 if !response.text_content().is_empty() {
63 assistant_content.push(ContentBlock::text(response.text_content()));
64 }
65 for tu in &tool_uses {
66 assistant_content.push(ContentBlock::tool_use(ToolUse::new(
67 &tu.id,
68 &tu.name,
69 tu.input.clone(),
70 )));
71 }
72 session.add_message(Message::new(hehe_core::Role::Assistant, assistant_content));
73
74 let tool_results = self.execute_tools(&tool_uses).await;
75
76 for (tu, (output, duration_ms, is_error)) in tool_uses.iter().zip(&tool_results) {
77 all_tool_calls.push(ToolCallRecord {
78 id: tu.id.clone(),
79 name: tu.name.clone(),
80 input: tu.input.clone(),
81 output: output.clone(),
82 is_error: *is_error,
83 duration_ms: *duration_ms,
84 });
85 }
86
87 session.increment_tool_calls(tool_results.len());
88
89 let tool_result_content: Vec<ContentBlock> = tool_uses
90 .iter()
91 .zip(&tool_results)
92 .map(|(tu, (output, _, is_error))| {
93 if *is_error {
94 ContentBlock::tool_result(ToolResult::error(&tu.id, output))
95 } else {
96 ContentBlock::tool_result(ToolResult::success(&tu.id, output))
97 }
98 })
99 .collect();
100
101 session.add_message(Message::tool(tool_result_content));
102 }
103 }
104
105 pub async fn execute_stream(
106 &self,
107 session: &Session,
108 user_input: &str,
109 tx: mpsc::Sender<AgentEvent>,
110 ) -> Result<AgentResponse> {
111 let _ = tx.send(AgentEvent::message_start(session.id().clone())).await;
112
113 let result = self.execute(session, user_input).await;
114
115 match &result {
116 Ok(response) => {
117 let _ = tx.send(AgentEvent::text_complete(response.text.clone())).await;
118 let _ = tx.send(AgentEvent::message_end(session.id().clone())).await;
119 }
120 Err(e) => {
121 let _ = tx.send(AgentEvent::error(e.to_string())).await;
122 }
123 }
124
125 result
126 }
127
128 fn build_request(&self, session: &Session) -> CompletionRequest {
129 let messages = session.last_messages(self.config.max_context_messages);
130
131 let mut request = CompletionRequest::new(&self.config.model, messages)
132 .with_system(&self.config.system_prompt)
133 .with_temperature(self.config.temperature);
134
135 if let Some(max_tokens) = self.config.max_tokens {
136 request = request.with_max_tokens(max_tokens as u32);
137 }
138
139 if self.config.tools_enabled {
140 if let Some(tools) = &self.tools {
141 let definitions = tools.registry().definitions();
142 if !definitions.is_empty() {
143 request = request.with_tools(definitions);
144 }
145 }
146 }
147
148 request
149 }
150
151 async fn execute_tools(
152 &self,
153 tool_uses: &[&ToolUse],
154 ) -> Vec<(String, u64, bool)> {
155 let Some(tools) = &self.tools else {
156 return tool_uses
157 .iter()
158 .map(|tu| (format!("Tool execution not available: {}", tu.name), 0, true))
159 .collect();
160 };
161
162 let ctx = Context::new().with_timeout(self.config.tool_timeout());
163 let mut results = Vec::with_capacity(tool_uses.len());
164
165 for tu in tool_uses {
166 let start = Instant::now();
167 debug!(tool = %tu.name, id = %tu.id, "Executing tool");
168
169 let result = tools.execute(&ctx, &tu.name, tu.input.clone()).await;
170 let duration_ms = start.elapsed().as_millis() as u64;
171
172 match result {
173 Ok(output) => {
174 info!(tool = %tu.name, duration_ms, is_error = output.is_error, "Tool completed");
175 results.push((output.content, duration_ms, output.is_error));
176 }
177 Err(e) => {
178 warn!(tool = %tu.name, error = %e, "Tool execution failed");
179 results.push((e.to_string(), duration_ms, true));
180 }
181 }
182 }
183
184 results
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use hehe_core::capability::Capabilities;
    use hehe_core::stream::StreamChunk;
    use hehe_llm::{BoxStream, CompletionResponse, LlmError, ModelInfo};
    use std::sync::{Mutex, OnceLock};

    /// Result type of the provider trait's fallible methods.
    type LlmResult<T> = std::result::Result<T, LlmError>;

    /// Provider stub that replays a scripted list of responses in order and
    /// answers with a fixed default once the script is exhausted.
    struct MockLlm {
        responses: Mutex<Vec<CompletionResponse>>,
    }

    impl MockLlm {
        fn new(responses: Vec<CompletionResponse>) -> Self {
            let responses = Mutex::new(responses);
            Self { responses }
        }
    }

    #[async_trait]
    impl LlmProvider for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> &Capabilities {
            static CAPS: OnceLock<Capabilities> = OnceLock::new();
            CAPS.get_or_init(Capabilities::text_basic)
        }

        async fn complete(&self, _request: CompletionRequest) -> LlmResult<CompletionResponse> {
            let mut queue = self.responses.lock().unwrap();
            // Pop the oldest scripted response, or fall back to the default.
            let next = if queue.is_empty() {
                CompletionResponse::new("id", "model", Message::assistant("Default response"))
            } else {
                queue.remove(0)
            };
            Ok(next)
        }

        async fn complete_stream(
            &self,
            _request: CompletionRequest,
        ) -> LlmResult<BoxStream<StreamChunk>> {
            Ok(Box::pin(futures::stream::empty()))
        }

        async fn list_models(&self) -> LlmResult<Vec<ModelInfo>> {
            Ok(Vec::new())
        }

        fn default_model(&self) -> &str {
            "mock"
        }
    }

    /// A plain text answer ends the loop after one iteration with no tool calls.
    #[tokio::test]
    async fn test_executor_simple_response() {
        let scripted = vec![CompletionResponse::new(
            "resp-1",
            "mock",
            Message::assistant("Hello!"),
        )];
        let executor = Executor::new(
            AgentConfig::new("mock", "You are helpful."),
            Arc::new(MockLlm::new(scripted)),
            None,
        );
        let session = Session::new();

        let response = executor.execute(&session, "Hi").await.unwrap();

        assert_eq!(response.text(), "Hello!");
        assert_eq!(response.iterations, 1);
        assert!(!response.has_tool_calls());
    }

    /// A model that keeps asking for tools trips the iteration limit.
    #[tokio::test]
    async fn test_executor_max_iterations() {
        let tool_response = Message::new(
            hehe_core::Role::Assistant,
            vec![ContentBlock::tool_use(ToolUse::new(
                "call_1",
                "test_tool",
                serde_json::json!({}),
            ))],
        );

        // Three identical tool requests: one more than the limit allows.
        let scripted = ["resp-1", "resp-2", "resp-3"]
            .iter()
            .map(|id| CompletionResponse::new(*id, "mock", tool_response.clone()))
            .collect();

        let config = AgentConfig::new("mock", "You are helpful.").with_max_iterations(2);
        let executor = Executor::new(config, Arc::new(MockLlm::new(scripted)), None);
        let session = Session::new();

        let result = executor.execute(&session, "Hi").await;

        assert!(matches!(result, Err(AgentError::MaxIterationsReached(2))));
    }
}