// hehe_tools/executor.rs

1use crate::error::{Result, ToolError};
2use crate::registry::ToolRegistry;
3use crate::traits::ToolOutput;
4use hehe_core::{Context, ToolCall, ToolCallStatus};
5use serde_json::Value;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::time::timeout;
9use tracing::{info, warn};
10
11pub struct ToolExecutor {
12    registry: Arc<ToolRegistry>,
13    default_timeout: Duration,
14    require_confirmation_for_dangerous: bool,
15}
16
17impl ToolExecutor {
18    pub fn new(registry: Arc<ToolRegistry>) -> Self {
19        Self {
20            registry,
21            default_timeout: Duration::from_secs(60),
22            require_confirmation_for_dangerous: true,
23        }
24    }
25
26    pub fn with_timeout(mut self, timeout: Duration) -> Self {
27        self.default_timeout = timeout;
28        self
29    }
30
31    pub fn allow_dangerous_without_confirmation(mut self) -> Self {
32        self.require_confirmation_for_dangerous = false;
33        self
34    }
35
36    pub async fn execute(
37        &self,
38        ctx: &Context,
39        name: &str,
40        input: Value,
41    ) -> Result<ToolOutput> {
42        let tool = self
43            .registry
44            .get(name)
45            .ok_or_else(|| ToolError::not_found(name))?;
46
47        if ctx.is_cancelled() {
48            return Err(ToolError::Cancelled);
49        }
50
51        tool.validate_input(&input)?;
52
53        info!(tool = name, "Executing tool");
54
55        let execute_timeout = ctx
56            .remaining()
57            .unwrap_or(self.default_timeout)
58            .min(self.default_timeout);
59
60        let result = timeout(execute_timeout, tool.execute(ctx, input)).await;
61
62        match result {
63            Ok(Ok(output)) => {
64                info!(tool = name, is_error = output.is_error, "Tool execution completed");
65                Ok(output)
66            }
67            Ok(Err(e)) => {
68                warn!(tool = name, error = %e, "Tool execution failed");
69                Err(e)
70            }
71            Err(_) => {
72                warn!(tool = name, timeout_ms = ?execute_timeout.as_millis(), "Tool execution timed out");
73                Err(ToolError::Timeout(execute_timeout.as_millis() as u64))
74            }
75        }
76    }
77
78    pub async fn execute_call(&self, ctx: &Context, call: &mut ToolCall) -> Result<ToolOutput> {
79        call.start();
80
81        match self.execute(ctx, &call.name, call.input.clone()).await {
82            Ok(output) => {
83                if output.is_error {
84                    call.fail(&output.content);
85                } else {
86                    call.complete(serde_json::to_value(&output.content).unwrap_or(Value::Null));
87                }
88                Ok(output)
89            }
90            Err(e) => {
91                call.fail(e.to_string());
92                Err(e)
93            }
94        }
95    }
96
97    pub fn registry(&self) -> &ToolRegistry {
98        &self.registry
99    }
100
101    pub fn is_dangerous(&self, name: &str) -> bool {
102        self.registry
103            .get(name)
104            .map(|t| t.is_dangerous())
105            .unwrap_or(false)
106    }
107
108    pub fn needs_confirmation(&self, name: &str) -> bool {
109        self.require_confirmation_for_dangerous && self.is_dangerous(name)
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Tool;
    use async_trait::async_trait;
    use hehe_core::ToolDefinition;

    /// Test tool that replies with its JSON input rendered as text.
    struct EchoTool {
        def: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            let def = ToolDefinition::new("echo", "Echoes input");
            EchoTool { def }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }

        async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
            let rendered = input.to_string();
            Ok(ToolOutput::text(rendered))
        }
    }

    /// Test tool that sleeps far longer than any timeout used by these tests.
    struct SlowTool {
        def: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            let def = ToolDefinition::new("slow", "A slow tool");
            SlowTool { def }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.def
        }

        async fn execute(&self, _ctx: &Context, _input: Value) -> Result<ToolOutput> {
            let nap = Duration::from_secs(10);
            tokio::time::sleep(nap).await;
            Ok(ToolOutput::text("done"))
        }
    }

    #[tokio::test]
    async fn test_executor_execute() {
        let executor = {
            let mut tools = ToolRegistry::new();
            tools.register(Arc::new(EchoTool::new())).unwrap();
            ToolExecutor::new(Arc::new(tools))
        };
        let ctx = Context::new();
        let payload = serde_json::json!({"message": "hello"});

        let reply = executor.execute(&ctx, "echo", payload).await.unwrap();

        assert!(reply.content.contains("hello"));
    }

    #[tokio::test]
    async fn test_executor_not_found() {
        let executor = ToolExecutor::new(Arc::new(ToolRegistry::new()));
        let ctx = Context::new();

        let outcome = executor.execute(&ctx, "nonexistent", Value::Null).await;

        assert!(matches!(outcome, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_executor_timeout() {
        let executor = {
            let mut tools = ToolRegistry::new();
            tools.register(Arc::new(SlowTool::new())).unwrap();
            let budget = Duration::from_millis(100);
            ToolExecutor::new(Arc::new(tools)).with_timeout(budget)
        };
        let ctx = Context::new();

        let outcome = executor.execute(&ctx, "slow", Value::Null).await;

        assert!(matches!(outcome, Err(ToolError::Timeout(_))));
    }

    #[tokio::test]
    async fn test_executor_execute_call() {
        let executor = {
            let mut tools = ToolRegistry::new();
            tools.register(Arc::new(EchoTool::new())).unwrap();
            ToolExecutor::new(Arc::new(tools))
        };
        let ctx = Context::new();

        let mut pending_call = ToolCall::new("echo", serde_json::json!({"x": 1}));
        assert!(pending_call.is_pending());

        let reply = executor
            .execute_call(&ctx, &mut pending_call)
            .await
            .unwrap();

        assert!(pending_call.is_completed());
        assert!(!reply.is_error);
    }
}