1use std::sync::Arc;
7use std::time::Duration;
8use tokio::time::sleep;
9
10use super::{ToolError, ToolOperationResult, ToolRegistry};
11use crate::types::{ContentBlock, Message, ToolResult, ToolUse};
12
/// Retry and concurrency policy used by [`ToolExecutor`].
///
/// Build one with [`Default::default`] or [`ToolExecutionConfigBuilder`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolExecutionConfig {
    /// Number of additional attempts after the first one fails with a
    /// retryable error; a tool call runs at most `max_retries + 1` times.
    pub max_retries: u32,

    /// Delay slept before the first retry.
    pub retry_delay: Duration,

    /// When `true`, the delay doubles after every retry, capped at
    /// [`max_retry_delay`](Self::max_retry_delay).
    pub exponential_backoff: bool,

    /// Upper bound applied to the delay whenever exponential backoff grows it.
    pub max_retry_delay: Duration,

    /// When `true`, multiple tool calls in one batch are executed concurrently.
    pub parallel_execution: bool,

    /// Maximum number of tool calls in flight at once during parallel execution.
    pub max_concurrent_tools: usize,
}

impl Default for ToolExecutionConfig {
    /// Three retries starting at 500 ms with exponential backoff capped at
    /// 10 s, and parallel execution of up to four tools at a time.
    fn default() -> Self {
        Self {
            max_retries: 3,
            retry_delay: Duration::from_millis(500),
            exponential_backoff: true,
            max_retry_delay: Duration::from_secs(10),
            parallel_execution: true,
            max_concurrent_tools: 4,
        }
    }
}
47
48pub struct ToolExecutor {
53 registry: Arc<ToolRegistry>,
55
56 config: ToolExecutionConfig,
58}
59
60impl ToolExecutor {
61 pub fn new(registry: Arc<ToolRegistry>) -> Self {
63 Self {
64 registry,
65 config: ToolExecutionConfig::default(),
66 }
67 }
68
69 pub fn with_config(registry: Arc<ToolRegistry>, config: ToolExecutionConfig) -> Self {
71 Self { registry, config }
72 }
73
74 pub async fn execute_with_retry(&self, tool_use: &ToolUse) -> ToolOperationResult<ToolResult> {
82 let mut last_error = None;
83 let mut delay = self.config.retry_delay;
84
85 for attempt in 0..=self.config.max_retries {
86 match self.registry.execute(tool_use).await {
87 Ok(result) => {
88 if let Some(true) = result.is_error {
90 if attempt < self.config.max_retries && self.should_retry_error(&result) {
91 last_error = Some(ToolError::ExecutionFailed {
92 source: format!("Tool returned error: {:?}", result.content).into(),
93 });
94
95 if attempt < self.config.max_retries {
96 sleep(delay).await;
97 if self.config.exponential_backoff {
98 delay = std::cmp::min(delay * 2, self.config.max_retry_delay);
99 }
100 }
101 continue;
102 }
103 }
104 return Ok(result);
105 }
106 Err(err) => {
107 if attempt < self.config.max_retries && self.should_retry_error_type(&err) {
108 last_error = Some(err);
109 sleep(delay).await;
110 if self.config.exponential_backoff {
111 delay = std::cmp::min(delay * 2, self.config.max_retry_delay);
112 }
113 } else {
114 return Err(err);
115 }
116 }
117 }
118 }
119
120 Err(last_error.unwrap_or_else(|| ToolError::ExecutionFailed {
121 source: "Maximum retries exceeded".to_string().into(),
122 }))
123 }
124
125 pub async fn execute_multiple(
133 &self,
134 tool_uses: &[ToolUse],
135 ) -> Vec<ToolOperationResult<ToolResult>> {
136 if self.config.parallel_execution && tool_uses.len() > 1 {
137 self.execute_parallel_with_concurrency(tool_uses).await
138 } else {
139 let mut results = Vec::with_capacity(tool_uses.len());
140 for tool_use in tool_uses {
141 results.push(self.execute_with_retry(tool_use).await);
142 }
143 results
144 }
145 }
146
147 async fn execute_parallel_with_concurrency(
149 &self,
150 tool_uses: &[ToolUse],
151 ) -> Vec<ToolOperationResult<ToolResult>> {
152 use futures::stream::{self, StreamExt};
153
154 let semaphore = Arc::new(tokio::sync::Semaphore::new(
156 self.config.max_concurrent_tools,
157 ));
158
159 let futures = tool_uses.iter().enumerate().map(|(index, tool_use)| {
160 let registry = self.registry.clone();
161 let semaphore = semaphore.clone();
162 let tool_use = tool_use.clone();
163 let config = self.config.clone();
164
165 async move {
166 let _permit = semaphore.acquire().await.unwrap();
167 let executor = ToolExecutor::with_config(registry, config);
168 (index, executor.execute_with_retry(&tool_use).await)
169 }
170 });
171
172 let mut results: Vec<(usize, ToolOperationResult<ToolResult>)> = stream::iter(futures)
173 .buffer_unordered(self.config.max_concurrent_tools)
174 .collect()
175 .await;
176
177 results.sort_by_key(|(index, _)| *index);
179 results.into_iter().map(|(_, result)| result).collect()
180 }
181
182 pub fn extract_tool_uses(&self, message: &Message) -> Vec<ToolUse> {
190 message
191 .content
192 .iter()
193 .filter_map(|block| {
194 if let ContentBlock::ToolUse { id, name, input } = block {
195 Some(ToolUse {
196 id: id.clone(),
197 name: name.clone(),
198 input: input.clone(),
199 })
200 } else {
201 None
202 }
203 })
204 .collect()
205 }
206
207 fn should_retry_error(&self, _result: &ToolResult) -> bool {
209 false
212 }
213
214 fn should_retry_error_type(&self, error: &ToolError) -> bool {
216 match error {
217 ToolError::ExecutionFailed { .. } => true,
218 ToolError::Timeout { .. } => true,
219 ToolError::ValidationFailed { .. } => false, ToolError::NotFound { .. } => false, ToolError::RegistryError { .. } => false, }
223 }
224
225 pub fn registry(&self) -> &Arc<ToolRegistry> {
227 &self.registry
228 }
229
230 pub fn config(&self) -> &ToolExecutionConfig {
232 &self.config
233 }
234
235 pub fn set_config(&mut self, config: ToolExecutionConfig) {
237 self.config = config;
238 }
239}
240
241pub struct ToolExecutionConfigBuilder {
243 config: ToolExecutionConfig,
244}
245
246impl ToolExecutionConfigBuilder {
247 pub fn new() -> Self {
249 Self {
250 config: ToolExecutionConfig::default(),
251 }
252 }
253
254 pub fn max_retries(mut self, max_retries: u32) -> Self {
256 self.config.max_retries = max_retries;
257 self
258 }
259
260 pub fn retry_delay(mut self, delay: Duration) -> Self {
262 self.config.retry_delay = delay;
263 self
264 }
265
266 pub fn exponential_backoff(mut self, enabled: bool) -> Self {
268 self.config.exponential_backoff = enabled;
269 self
270 }
271
272 pub fn max_retry_delay(mut self, delay: Duration) -> Self {
274 self.config.max_retry_delay = delay;
275 self
276 }
277
278 pub fn parallel_execution(mut self, enabled: bool) -> Self {
280 self.config.parallel_execution = enabled;
281 self
282 }
283
284 pub fn max_concurrent_tools(mut self, max: usize) -> Self {
286 self.config.max_concurrent_tools = max;
287 self
288 }
289
290 pub fn build(self) -> ToolExecutionConfig {
292 self.config
293 }
294}
295
296impl Default for ToolExecutionConfigBuilder {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolBuilder;
    use crate::tools::{ToolFunction, ToolRegistry};
    use crate::types::ToolResult;
    use async_trait::async_trait;
    use serde_json::{Value, json};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Tool that returns an error for its first `fail_count` invocations and
    /// then succeeds with a message naming the (1-based) attempt number.
    struct TestRetryTool {
        attempts: Arc<AtomicUsize>,
        fail_count: usize,
    }

    #[async_trait]
    impl ToolFunction for TestRetryTool {
        async fn execute(
            &self,
            _input: Value,
        ) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
            let attempt = self.attempts.fetch_add(1, Ordering::SeqCst);
            if attempt < self.fail_count {
                Err("Simulated failure".into())
            } else {
                Ok(ToolResult::success(
                    "test_id",
                    format!("Success on attempt {}", attempt + 1),
                ))
            }
        }
    }

    /// Tool that sleeps for `delay` before succeeding; used to observe concurrency.
    struct TestSlowTool {
        delay: Duration,
    }

    #[async_trait]
    impl ToolFunction for TestSlowTool {
        async fn execute(
            &self,
            _input: Value,
        ) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
            sleep(self.delay).await;
            Ok(ToolResult::success("test_id", "Slow tool completed"))
        }
    }

    /// Builds a `ToolUse` for tool `name` with an empty JSON input.
    fn make_tool_use(id: &str, name: &str) -> ToolUse {
        ToolUse {
            id: id.to_string(),
            name: name.to_string(),
            input: json!({}),
        }
    }

    /// Registers a `TestRetryTool` under `name` and returns the registry along
    /// with the tool's shared attempt counter.
    fn retry_registry(
        name: &'static str,
        description: &'static str,
        fail_count: usize,
    ) -> (ToolRegistry, Arc<AtomicUsize>) {
        let mut registry = ToolRegistry::new();
        let attempts = Arc::new(AtomicUsize::new(0));
        registry
            .register(
                name,
                ToolBuilder::new(name, description).build(),
                Box::new(TestRetryTool {
                    attempts: Arc::clone(&attempts),
                    fail_count,
                }),
            )
            .unwrap();
        (registry, attempts)
    }

    /// Asserts that `result` carries text content equal to `expected`.
    fn assert_text(result: ToolResult, expected: &str) {
        if let crate::types::ToolResultContent::Text(content) = result.content {
            assert_eq!(content, expected);
        } else {
            panic!("Expected text content");
        }
    }

    #[tokio::test]
    async fn test_successful_execution() {
        let (registry, attempts) = retry_registry("test_tool", "Test tool", 0);
        let executor = ToolExecutor::new(Arc::new(registry));

        let result = executor
            .execute_with_retry(&make_tool_use("test_id", "test_tool"))
            .await
            .unwrap();

        assert_text(result, "Success on attempt 1");
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_logic() {
        let (registry, attempts) =
            retry_registry("retry_tool", "Tool that fails then succeeds", 2);
        let config = ToolExecutionConfigBuilder::new()
            .max_retries(3)
            .retry_delay(Duration::from_millis(10))
            .exponential_backoff(false)
            .build();
        let executor = ToolExecutor::with_config(Arc::new(registry), config);

        let result = executor
            .execute_with_retry(&make_tool_use("test_id", "retry_tool"))
            .await
            .unwrap();

        assert_text(result, "Success on attempt 3");
        // Two simulated failures followed by one success.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_sequential_execution_preserves_order() {
        let (registry, attempts) = retry_registry("seq_tool", "Sequential tool", 0);
        let config = ToolExecutionConfigBuilder::new()
            .parallel_execution(false)
            .build();
        let executor = ToolExecutor::with_config(Arc::new(registry), config);
        let tool_uses: Vec<ToolUse> = (1..=3)
            .map(|i| make_tool_use(&format!("seq_{i}"), "seq_tool"))
            .collect();

        let results = executor.execute_multiple(&tool_uses).await;

        assert_eq!(results.len(), 3);
        for (i, result) in results.into_iter().enumerate() {
            assert_text(result.unwrap(), &format!("Success on attempt {}", i + 1));
        }
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_parallel_execution() {
        let tool_delay = Duration::from_millis(100);
        let mut registry = ToolRegistry::new();
        registry
            .register(
                "slow_tool",
                ToolBuilder::new("slow_tool", "Slow tool for testing parallelism").build(),
                Box::new(TestSlowTool { delay: tool_delay }),
            )
            .unwrap();

        let config = ToolExecutionConfigBuilder::new()
            .parallel_execution(true)
            .max_concurrent_tools(3)
            .build();
        let executor = ToolExecutor::with_config(Arc::new(registry), config);
        let tool_uses: Vec<ToolUse> = ["test_1", "test_2", "test_3"]
            .iter()
            .map(|id| make_tool_use(id, "slow_tool"))
            .collect();

        let start = std::time::Instant::now();
        let results = executor.execute_multiple(&tool_uses).await;
        let elapsed = start.elapsed();

        // Run one after another, the three calls take at least 3 * tool_delay.
        // Finishing sooner proves they overlapped without a CI-fragile tight bound.
        assert!(
            elapsed < tool_delay * 3,
            "expected concurrent execution, took {elapsed:?}"
        );
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|result| result.is_ok()));
    }

    #[test]
    fn test_config_builder() {
        let config = ToolExecutionConfigBuilder::new()
            .max_retries(5)
            .retry_delay(Duration::from_millis(100))
            .exponential_backoff(true)
            .max_retry_delay(Duration::from_secs(5))
            .parallel_execution(false)
            .max_concurrent_tools(2)
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.retry_delay, Duration::from_millis(100));
        assert!(config.exponential_backoff);
        assert_eq!(config.max_retry_delay, Duration::from_secs(5));
        assert!(!config.parallel_execution);
        assert_eq!(config.max_concurrent_tools, 2);
    }
}
492