1use crate::error::AgentError;
2use crate::registry::ToolRegistry;
3use serde_json::Value;
4use std::sync::Arc;
5use std::time::Duration;
6use tracing::instrument;
7
8#[derive(Debug, Clone)]
10pub struct ToolCall {
11 pub id: String,
12 pub name: String,
13 pub args: Value,
14}
15
16impl ToolCall {
17 pub fn new(id: impl Into<String>, name: impl Into<String>, args: Value) -> Self {
18 Self {
19 id: id.into(),
20 name: name.into(),
21 args,
22 }
23 }
24}
25
26#[derive(Debug, Clone)]
28pub struct ToolResult {
29 pub call_id: String,
30 pub output: Result<Value, AgentError>,
31}
32
33pub struct ToolExecutor {
35 registry: Arc<ToolRegistry>,
36 max_concurrent: usize,
37 timeout: Duration,
38}
39
40impl ToolExecutor {
41 pub fn new(registry: ToolRegistry) -> Self {
43 Self {
44 registry: Arc::new(registry),
45 max_concurrent: 5,
46 timeout: Duration::from_secs(60),
47 }
48 }
49
50 pub fn with_max_concurrent(mut self, max: usize) -> Self {
52 self.max_concurrent = max;
53 self
54 }
55
56 pub fn with_timeout(mut self, timeout: Duration) -> Self {
58 self.timeout = timeout;
59 self
60 }
61
62 #[instrument(skip(self, calls))]
64 pub async fn execute_tools(&self, calls: Vec<ToolCall>) -> Vec<ToolResult> {
65 if calls.is_empty() {
66 return Vec::new();
67 }
68
69 let (independent, dependent) = self.categorize_calls(&calls);
71
72 let mut results = self.execute_parallel(independent).await;
74
75 results.extend(self.execute_sequential(dependent).await);
77
78 results
79 }
80
81 fn categorize_calls(&self, calls: &[ToolCall]) -> (Vec<ToolCall>, Vec<ToolCall>) {
83 let mut independent = Vec::new();
84 let mut dependent = Vec::new();
85
86 for call in calls {
87 if self.has_dependencies(call) {
88 dependent.push(call.clone());
89 } else {
90 independent.push(call.clone());
91 }
92 }
93
94 (independent, dependent)
95 }
96
97 fn has_dependencies(&self, call: &ToolCall) -> bool {
99 let args_str = serde_json::to_string(&call.args).unwrap_or_default();
101
102 args_str.contains("$output_") || args_str.contains("$var_") || args_str.contains("$result_")
104 }
105
106 #[instrument(skip(self, calls))]
108 async fn execute_parallel(&self, calls: Vec<ToolCall>) -> Vec<ToolResult> {
109 if calls.is_empty() {
110 return Vec::new();
111 }
112
113 use futures::stream::{self, StreamExt};
114
115 stream::iter(calls)
116 .map(|call| {
117 let registry = Arc::clone(&self.registry);
118 let timeout = self.timeout;
119 async move {
120 let result = tokio::time::timeout(
121 timeout,
122 registry.execute(&call.name, call.args.clone()),
123 )
124 .await;
125
126 ToolResult {
127 call_id: call.id,
128 output: result.unwrap_or(Err(AgentError::ToolError(
129 "Tool execution timed out".to_string(),
130 ))),
131 }
132 }
133 })
134 .buffer_unordered(self.max_concurrent)
135 .collect()
136 .await
137 }
138
139 #[instrument(skip(self, calls))]
141 async fn execute_sequential(&self, calls: Vec<ToolCall>) -> Vec<ToolResult> {
142 let mut results = Vec::new();
143
144 for call in calls {
145 let result = tokio::time::timeout(
146 self.timeout,
147 self.registry.execute(&call.name, call.args.clone()),
148 )
149 .await;
150
151 results.push(ToolResult {
152 call_id: call.id,
153 output: result.unwrap_or(Err(AgentError::ToolError(
154 "Tool execution timed out".to_string(),
155 ))),
156 });
157 }
158
159 results
160 }
161}
162
163impl Clone for ToolExecutor {
164 fn clone(&self) -> Self {
165 Self {
166 registry: Arc::clone(&self.registry),
167 max_concurrent: self.max_concurrent,
168 timeout: self.timeout,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::EchoTool;
    use async_trait::async_trait;
    use serde_json::json;

    #[test]
    fn test_tool_call_new() {
        let call = ToolCall::new("1", "echo", json!({"test": "value"}));
        assert_eq!(call.id, "1");
        assert_eq!(call.name, "echo");
        assert_eq!(call.args, json!({"test": "value"}));
    }

    #[test]
    fn test_executor_new() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        assert_eq!(executor.max_concurrent, 5);
        assert_eq!(executor.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_executor_with_config() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry)
            .with_max_concurrent(10)
            .with_timeout(Duration::from_secs(30));

        assert_eq!(executor.max_concurrent, 10);
        assert_eq!(executor.timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_execute_tools_empty() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);
        let results = executor.execute_tools(vec![]).await;

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_execute_tools_single() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let executor = ToolExecutor::new(registry);
        let call = ToolCall::new("1", "echo", json!({"test": "value"}));
        let results = executor.execute_tools(vec![call]).await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].call_id, "1");
        assert!(results[0].output.is_ok());
    }

    #[tokio::test]
    async fn test_execute_tools_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ToolCall::new("1", "echo", json!({"id": 1})),
            ToolCall::new("2", "echo", json!({"id": 2})),
            ToolCall::new("3", "echo", json!({"id": 3})),
        ];

        let results = executor.execute_tools(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.output.is_ok()));

        // Every call is answered exactly once, whatever order the results arrive in.
        let mut ids: Vec<&str> = results.iter().map(|r| r.call_id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["1", "2", "3"]);
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_with_dependencies() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ToolCall::new("1", "echo", json!({"id": 1})),
            ToolCall::new("2", "echo", json!({"input": "$output_1", "id": 2})),
        ];

        let results = executor.execute_tools(calls).await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].call_id, "1");
        assert!(results[0].output.is_ok());
        assert_eq!(results[1].call_id, "2");
        assert!(results[1].output.is_ok());
    }

    #[tokio::test]
    async fn test_execute_tools_timeout() {
        struct SlowTool;

        #[async_trait]
        impl crate::tool::Tool for SlowTool {
            fn name(&self) -> &str {
                "slow"
            }

            async fn execute(&self, _args: Value) -> Result<Value, AgentError> {
                tokio::time::sleep(Duration::from_secs(2)).await;
                Ok(json!({"status": "done"}))
            }
        }

        let mut registry = ToolRegistry::new();
        registry.register(SlowTool).unwrap();

        let executor = ToolExecutor::new(registry).with_timeout(Duration::from_millis(100));
        let call = ToolCall::new("1", "slow", json!({}));
        let results = executor.execute_tools(vec![call]).await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].call_id, "1");
        assert!(matches!(
            results[0].output.as_ref().unwrap_err(),
            AgentError::ToolError(_)
        ));
    }

    #[tokio::test]
    async fn test_execute_tools_tool_not_found() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let call = ToolCall::new("1", "nonexistent", json!({}));
        let results = executor.execute_tools(vec![call]).await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].call_id, "1");
        assert!(results[0].output.is_err());
    }

    #[test]
    fn test_categorize_calls_no_dependencies() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ToolCall::new("1", "echo", json!({"id": 1})),
            ToolCall::new("2", "echo", json!({"id": 2})),
        ];

        let (independent, dependent) = executor.categorize_calls(&calls);

        assert_eq!(independent.len(), 2);
        assert!(dependent.is_empty());
    }

    #[test]
    fn test_categorize_calls_with_dependencies() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ToolCall::new("1", "echo", json!({"id": 1})),
            ToolCall::new("2", "echo", json!({"input": "$output_1", "id": 2})),
            ToolCall::new("3", "echo", json!({"id": 3})),
        ];

        let (independent, dependent) = executor.categorize_calls(&calls);

        assert_eq!(independent.len(), 2);
        assert_eq!(dependent.len(), 1);
        assert_eq!(dependent[0].id, "2");
    }

    #[tokio::test]
    async fn test_parallel_with_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Instant;

        // Records how many invocations are in flight at once, and the highest value seen.
        struct ConcurrentTool {
            in_flight: Arc<AtomicUsize>,
            peak: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl crate::tool::Tool for ConcurrentTool {
            fn name(&self) -> &str {
                "concurrent"
            }

            async fn execute(&self, _args: Value) -> Result<Value, AgentError> {
                let count = self.in_flight.fetch_add(1, Ordering::SeqCst);
                self.peak.fetch_max(count + 1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(100)).await;
                self.in_flight.fetch_sub(1, Ordering::SeqCst);
                Ok(json!({"count": count}))
            }
        }

        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let mut registry = ToolRegistry::new();
        registry
            .register(ConcurrentTool {
                in_flight: Arc::clone(&in_flight),
                peak: Arc::clone(&peak),
            })
            .unwrap();

        let executor = ToolExecutor::new(registry).with_max_concurrent(2);

        let calls: Vec<ToolCall> = (0..5)
            .map(|i| ToolCall::new(i.to_string(), "concurrent", json!({})))
            .collect();

        let start = Instant::now();
        let results = executor.execute_tools(calls).await;
        let elapsed = start.elapsed();

        assert_eq!(results.len(), 5);
        assert!(results.iter().all(|r| r.output.is_ok()));

        // The cap was respected and actually reached (i.e. calls really overlapped).
        assert_eq!(peak.load(Ordering::SeqCst), 2);
        assert_eq!(in_flight.load(Ordering::SeqCst), 0);

        // Five 100 ms calls, two at a time, need at least three rounds. No upper bound:
        // wall-clock ceilings are flaky on loaded CI machines.
        assert!(
            elapsed >= Duration::from_millis(300),
            "finished too fast: {:?}",
            elapsed
        );
    }
}