// hehe_agent/executor.rs

1use crate::config::AgentConfig;
2use crate::error::{AgentError, Result};
3use crate::event::AgentEvent;
4use crate::response::{AgentResponse, ToolCallRecord};
5use crate::session::Session;
6use hehe_core::message::{ContentBlock, ToolResult, ToolUse};
7use hehe_core::{Context, Message};
8use hehe_llm::{CompletionRequest, LlmProvider};
9use hehe_tools::ToolExecutor;
10use std::sync::Arc;
11use std::time::Instant;
12use tokio::sync::mpsc;
13use tracing::{debug, info, warn};
14
15pub struct Executor {
16    config: AgentConfig,
17    llm: Arc<dyn LlmProvider>,
18    tools: Option<Arc<ToolExecutor>>,
19}
20
21impl Executor {
22    pub fn new(
23        config: AgentConfig,
24        llm: Arc<dyn LlmProvider>,
25        tools: Option<Arc<ToolExecutor>>,
26    ) -> Self {
27        Self { config, llm, tools }
28    }
29
30    pub async fn execute(&self, session: &Session, user_input: &str) -> Result<AgentResponse> {
31        let user_message = Message::user(user_input);
32        session.add_message(user_message);
33
34        let mut all_tool_calls = Vec::new();
35        let mut iterations = 0;
36
37        loop {
38            iterations += 1;
39            session.increment_iterations();
40
41            if iterations > self.config.max_iterations {
42                return Err(AgentError::MaxIterationsReached(self.config.max_iterations));
43            }
44
45            info!(iteration = iterations, "Starting agent loop iteration");
46
47            let request = self.build_request(session);
48            let response = self.llm.complete(request).await?;
49
50            let tool_uses = response.message.tool_uses();
51
52            if tool_uses.is_empty() {
53                let text = response.text_content();
54                session.add_message(Message::assistant(&text));
55
56                return Ok(AgentResponse::new(session.id().clone(), text)
57                    .with_tool_calls(all_tool_calls)
58                    .with_iterations(iterations));
59            }
60
61            let mut assistant_content = Vec::new();
62            if !response.text_content().is_empty() {
63                assistant_content.push(ContentBlock::text(response.text_content()));
64            }
65            for tu in &tool_uses {
66                assistant_content.push(ContentBlock::tool_use(ToolUse::new(
67                    &tu.id,
68                    &tu.name,
69                    tu.input.clone(),
70                )));
71            }
72            session.add_message(Message::new(hehe_core::Role::Assistant, assistant_content));
73
74            let tool_results = self.execute_tools(&tool_uses).await;
75
76            for (tu, (output, duration_ms, is_error)) in tool_uses.iter().zip(&tool_results) {
77                all_tool_calls.push(ToolCallRecord {
78                    id: tu.id.clone(),
79                    name: tu.name.clone(),
80                    input: tu.input.clone(),
81                    output: output.clone(),
82                    is_error: *is_error,
83                    duration_ms: *duration_ms,
84                });
85            }
86
87            session.increment_tool_calls(tool_results.len());
88
89            let tool_result_content: Vec<ContentBlock> = tool_uses
90                .iter()
91                .zip(&tool_results)
92                .map(|(tu, (output, _, is_error))| {
93                    if *is_error {
94                        ContentBlock::tool_result(ToolResult::error(&tu.id, output))
95                    } else {
96                        ContentBlock::tool_result(ToolResult::success(&tu.id, output))
97                    }
98                })
99                .collect();
100
101            session.add_message(Message::tool(tool_result_content));
102        }
103    }
104
105    pub async fn execute_stream(
106        &self,
107        session: &Session,
108        user_input: &str,
109        tx: mpsc::Sender<AgentEvent>,
110    ) -> Result<AgentResponse> {
111        let _ = tx.send(AgentEvent::message_start(session.id().clone())).await;
112
113        let result = self.execute(session, user_input).await;
114
115        match &result {
116            Ok(response) => {
117                let _ = tx.send(AgentEvent::text_complete(response.text.clone())).await;
118                let _ = tx.send(AgentEvent::message_end(session.id().clone())).await;
119            }
120            Err(e) => {
121                let _ = tx.send(AgentEvent::error(e.to_string())).await;
122            }
123        }
124
125        result
126    }
127
128    fn build_request(&self, session: &Session) -> CompletionRequest {
129        let messages = session.last_messages(self.config.max_context_messages);
130
131        let mut request = CompletionRequest::new(&self.config.model, messages)
132            .with_system(&self.config.system_prompt)
133            .with_temperature(self.config.temperature);
134
135        if let Some(max_tokens) = self.config.max_tokens {
136            request = request.with_max_tokens(max_tokens as u32);
137        }
138
139        if self.config.tools_enabled {
140            if let Some(tools) = &self.tools {
141                let definitions = tools.registry().definitions();
142                if !definitions.is_empty() {
143                    request = request.with_tools(definitions);
144                }
145            }
146        }
147
148        request
149    }
150
151    async fn execute_tools(
152        &self,
153        tool_uses: &[&ToolUse],
154    ) -> Vec<(String, u64, bool)> {
155        let Some(tools) = &self.tools else {
156            return tool_uses
157                .iter()
158                .map(|tu| (format!("Tool execution not available: {}", tu.name), 0, true))
159                .collect();
160        };
161
162        let ctx = Context::new().with_timeout(self.config.tool_timeout());
163        let mut results = Vec::with_capacity(tool_uses.len());
164
165        for tu in tool_uses {
166            let start = Instant::now();
167            debug!(tool = %tu.name, id = %tu.id, "Executing tool");
168
169            let result = tools.execute(&ctx, &tu.name, tu.input.clone()).await;
170            let duration_ms = start.elapsed().as_millis() as u64;
171
172            match result {
173                Ok(output) => {
174                    info!(tool = %tu.name, duration_ms, is_error = output.is_error, "Tool completed");
175                    results.push((output.content, duration_ms, output.is_error));
176                }
177                Err(e) => {
178                    warn!(tool = %tu.name, error = %e, "Tool execution failed");
179                    results.push((e.to_string(), duration_ms, true));
180                }
181            }
182        }
183
184        results
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use hehe_core::capability::Capabilities;
    use hehe_core::stream::StreamChunk;
    use hehe_llm::{BoxStream, CompletionResponse, LlmError, ModelInfo};
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// Scripted provider: returns queued responses in FIFO order, then a fixed
    /// "Default response" assistant message once the queue is exhausted.
    struct MockLlm {
        responses: Mutex<VecDeque<CompletionResponse>>,
    }

    impl MockLlm {
        fn new(responses: Vec<CompletionResponse>) -> Self {
            Self {
                responses: Mutex::new(responses.into()),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> &Capabilities {
            static CAPS: std::sync::OnceLock<Capabilities> = std::sync::OnceLock::new();
            CAPS.get_or_init(Capabilities::text_basic)
        }

        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<CompletionResponse, LlmError> {
            // The lock guard is a temporary dropped at the end of this statement.
            let next = self.responses.lock().unwrap().pop_front();
            Ok(next.unwrap_or_else(|| {
                CompletionResponse::new("id", "model", Message::assistant("Default response"))
            }))
        }

        async fn complete_stream(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<BoxStream<StreamChunk>, LlmError> {
            use futures::stream;
            Ok(Box::pin(stream::empty()))
        }

        async fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, LlmError> {
            Ok(vec![])
        }

        fn default_model(&self) -> &str {
            "mock"
        }
    }

    /// Assistant message consisting of a single `test_tool` call.
    fn tool_call_message() -> Message {
        Message::new(
            hehe_core::Role::Assistant,
            vec![ContentBlock::tool_use(ToolUse::new(
                "call_1",
                "test_tool",
                serde_json::json!({}),
            ))],
        )
    }

    #[tokio::test]
    async fn test_executor_simple_response() {
        let config = AgentConfig::new("mock", "You are helpful.");
        let llm = Arc::new(MockLlm::new(vec![CompletionResponse::new(
            "resp-1",
            "mock",
            Message::assistant("Hello!"),
        )]));

        let executor = Executor::new(config, llm, None);
        let session = Session::new();

        let response = executor.execute(&session, "Hi").await.unwrap();

        assert_eq!(response.text(), "Hello!");
        assert_eq!(response.iterations, 1);
        assert!(!response.has_tool_calls());
    }

    #[tokio::test]
    async fn test_executor_tool_call_then_answer() {
        let config = AgentConfig::new("mock", "You are helpful.").with_max_iterations(3);

        // Only one scripted reply: the follow-up query falls back to the mock's
        // "Default response", which contains no tool calls and ends the loop.
        let llm = Arc::new(MockLlm::new(vec![CompletionResponse::new(
            "resp-1",
            "mock",
            tool_call_message(),
        )]));

        // No tool runtime: the call is recorded as an error result, not a failure.
        let executor = Executor::new(config, llm, None);
        let session = Session::new();

        let response = executor.execute(&session, "Hi").await.unwrap();

        assert_eq!(response.text(), "Default response");
        assert_eq!(response.iterations, 2);
        assert!(response.has_tool_calls());
    }

    #[tokio::test]
    async fn test_executor_max_iterations() {
        let config = AgentConfig::new("mock", "You are helpful.").with_max_iterations(2);

        let llm = Arc::new(MockLlm::new(vec![
            CompletionResponse::new("resp-1", "mock", tool_call_message()),
            CompletionResponse::new("resp-2", "mock", tool_call_message()),
            CompletionResponse::new("resp-3", "mock", tool_call_message()),
        ]));

        let executor = Executor::new(config, llm, None);
        let session = Session::new();

        let result = executor.execute(&session, "Hi").await;

        assert!(matches!(result, Err(AgentError::MaxIterationsReached(2))));
    }
}