1use crate::error::{Result, ToolError};
2use crate::registry::ToolRegistry;
3use crate::traits::ToolOutput;
4use hehe_core::{Context, ToolCall, ToolCallStatus};
5use serde_json::Value;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::time::timeout;
9use tracing::{info, warn};
10
11pub struct ToolExecutor {
12 registry: Arc<ToolRegistry>,
13 default_timeout: Duration,
14 require_confirmation_for_dangerous: bool,
15}
16
17impl ToolExecutor {
18 pub fn new(registry: Arc<ToolRegistry>) -> Self {
19 Self {
20 registry,
21 default_timeout: Duration::from_secs(60),
22 require_confirmation_for_dangerous: true,
23 }
24 }
25
26 pub fn with_timeout(mut self, timeout: Duration) -> Self {
27 self.default_timeout = timeout;
28 self
29 }
30
31 pub fn allow_dangerous_without_confirmation(mut self) -> Self {
32 self.require_confirmation_for_dangerous = false;
33 self
34 }
35
36 pub async fn execute(
37 &self,
38 ctx: &Context,
39 name: &str,
40 input: Value,
41 ) -> Result<ToolOutput> {
42 let tool = self
43 .registry
44 .get(name)
45 .ok_or_else(|| ToolError::not_found(name))?;
46
47 if ctx.is_cancelled() {
48 return Err(ToolError::Cancelled);
49 }
50
51 tool.validate_input(&input)?;
52
53 info!(tool = name, "Executing tool");
54
55 let execute_timeout = ctx
56 .remaining()
57 .unwrap_or(self.default_timeout)
58 .min(self.default_timeout);
59
60 let result = timeout(execute_timeout, tool.execute(ctx, input)).await;
61
62 match result {
63 Ok(Ok(output)) => {
64 info!(tool = name, is_error = output.is_error, "Tool execution completed");
65 Ok(output)
66 }
67 Ok(Err(e)) => {
68 warn!(tool = name, error = %e, "Tool execution failed");
69 Err(e)
70 }
71 Err(_) => {
72 warn!(tool = name, timeout_ms = ?execute_timeout.as_millis(), "Tool execution timed out");
73 Err(ToolError::Timeout(execute_timeout.as_millis() as u64))
74 }
75 }
76 }
77
78 pub async fn execute_call(&self, ctx: &Context, call: &mut ToolCall) -> Result<ToolOutput> {
79 call.start();
80
81 match self.execute(ctx, &call.name, call.input.clone()).await {
82 Ok(output) => {
83 if output.is_error {
84 call.fail(&output.content);
85 } else {
86 call.complete(serde_json::to_value(&output.content).unwrap_or(Value::Null));
87 }
88 Ok(output)
89 }
90 Err(e) => {
91 call.fail(e.to_string());
92 Err(e)
93 }
94 }
95 }
96
97 pub fn registry(&self) -> &ToolRegistry {
98 &self.registry
99 }
100
101 pub fn is_dangerous(&self, name: &str) -> bool {
102 self.registry
103 .get(name)
104 .map(|t| t.is_dangerous())
105 .unwrap_or(false)
106 }
107
108 pub fn needs_confirmation(&self, name: &str) -> bool {
109 self.require_confirmation_for_dangerous && self.is_dangerous(name)
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Tool;
    use async_trait::async_trait;
    use hehe_core::ToolDefinition;

    /// Test tool that answers with the textual form of its JSON input.
    struct EchoTool {
        def: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            let def = ToolDefinition::new("echo", "Echoes input");
            Self { def }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }

        async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
            let echoed = input.to_string();
            Ok(ToolOutput::text(echoed))
        }
    }

    /// Test tool that sleeps for ten seconds before answering, so that a short
    /// executor timeout always fires first.
    struct SlowTool {
        def: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            let def = ToolDefinition::new("slow", "A slow tool");
            Self { def }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }

        async fn execute(&self, _ctx: &Context, _input: Value) -> Result<ToolOutput> {
            let nap = Duration::from_secs(10);
            tokio::time::sleep(nap).await;
            Ok(ToolOutput::text("done"))
        }
    }

    /// Wraps `registry` in an executor with default settings.
    fn executor_over(registry: ToolRegistry) -> ToolExecutor {
        ToolExecutor::new(Arc::new(registry))
    }

    #[tokio::test]
    async fn test_executor_execute() {
        let mut tools = ToolRegistry::new();
        tools.register(Arc::new(EchoTool::new())).unwrap();
        let executor = executor_over(tools);
        let ctx = Context::new();

        let payload = serde_json::json!({"message": "hello"});
        let output = executor.execute(&ctx, "echo", payload).await.unwrap();

        assert!(output.content.contains("hello"));
    }

    #[tokio::test]
    async fn test_executor_not_found() {
        let executor = executor_over(ToolRegistry::new());
        let ctx = Context::new();

        let outcome = executor.execute(&ctx, "nonexistent", Value::Null).await;

        assert!(matches!(outcome, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_executor_timeout() {
        let mut tools = ToolRegistry::new();
        tools.register(Arc::new(SlowTool::new())).unwrap();
        let executor = executor_over(tools).with_timeout(Duration::from_millis(100));
        let ctx = Context::new();

        let outcome = executor.execute(&ctx, "slow", Value::Null).await;

        assert!(matches!(outcome, Err(ToolError::Timeout(_))));
    }

    #[tokio::test]
    async fn test_executor_execute_call() {
        let mut tools = ToolRegistry::new();
        tools.register(Arc::new(EchoTool::new())).unwrap();
        let executor = executor_over(tools);
        let ctx = Context::new();

        let mut call = ToolCall::new("echo", serde_json::json!({"x": 1}));
        assert!(call.is_pending());

        let output = executor.execute_call(&ctx, &mut call).await.unwrap();

        assert!(call.is_completed());
        assert!(!output.is_error);
    }
}