Skip to main content

agentik_sdk/tools/
executor.rs

1//! High-level tool execution coordinator.
2//!
3//! This module provides the `ToolExecutor` which coordinates tool execution
4//! across multiple tools, handles retries, and manages conversation flow.
5
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::time::sleep;
9
10use super::{ToolError, ToolOperationResult, ToolRegistry};
11use crate::types::{ContentBlock, Message, ToolResult, ToolUse};
12
/// Configuration for tool execution.
///
/// Build one with [`ToolExecutionConfigBuilder`], or start from
/// [`ToolExecutionConfig::default`] and override fields with struct-update syntax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolExecutionConfig {
    /// Maximum number of retry attempts for failed tools.
    ///
    /// This counts retries, not attempts: a tool is tried at most `max_retries + 1` times.
    pub max_retries: u32,

    /// Base delay to wait before the first retry.
    pub retry_delay: Duration,

    /// Whether to double the delay after each retry (exponential backoff).
    pub exponential_backoff: bool,

    /// Upper bound applied to the growing delay when exponential backoff is enabled.
    pub max_retry_delay: Duration,

    /// Whether multiple tool calls may be executed concurrently.
    pub parallel_execution: bool,

    /// Maximum number of tool executions in flight at once when running in parallel.
    pub max_concurrent_tools: usize,
}
34
35impl Default for ToolExecutionConfig {
36    fn default() -> Self {
37        Self {
38            max_retries: 3,
39            retry_delay: Duration::from_millis(500),
40            exponential_backoff: true,
41            max_retry_delay: Duration::from_secs(10),
42            parallel_execution: true,
43            max_concurrent_tools: 4,
44        }
45    }
46}
47
48/// High-level tool executor that coordinates tool execution.
49///
50/// The executor handles multiple tool calls, retry logic, error recovery,
51/// and provides higher-level abstractions for tool management.
52pub struct ToolExecutor {
53    /// The tool registry for executing tools.
54    registry: Arc<ToolRegistry>,
55
56    /// Configuration for tool execution.
57    config: ToolExecutionConfig,
58}
59
60impl ToolExecutor {
61    /// Create a new tool executor with the given registry.
62    pub fn new(registry: Arc<ToolRegistry>) -> Self {
63        Self {
64            registry,
65            config: ToolExecutionConfig::default(),
66        }
67    }
68
69    /// Create a new tool executor with custom configuration.
70    pub fn with_config(registry: Arc<ToolRegistry>, config: ToolExecutionConfig) -> Self {
71        Self { registry, config }
72    }
73
74    /// Execute a single tool with retry logic.
75    ///
76    /// # Arguments
77    /// * `tool_use` - The tool use request from Claude
78    ///
79    /// # Returns
80    /// The tool result after execution (with retries if needed).
81    pub async fn execute_with_retry(&self, tool_use: &ToolUse) -> ToolOperationResult<ToolResult> {
82        let mut last_error = None;
83        let mut delay = self.config.retry_delay;
84
85        for attempt in 0..=self.config.max_retries {
86            match self.registry.execute(tool_use).await {
87                Ok(result) => {
88                    // Check if the result indicates an error that should be retried
89                    if let Some(true) = result.is_error {
90                        if attempt < self.config.max_retries && self.should_retry_error(&result) {
91                            last_error = Some(ToolError::ExecutionFailed {
92                                source: format!("Tool returned error: {:?}", result.content).into(),
93                            });
94
95                            if attempt < self.config.max_retries {
96                                sleep(delay).await;
97                                if self.config.exponential_backoff {
98                                    delay = std::cmp::min(delay * 2, self.config.max_retry_delay);
99                                }
100                            }
101                            continue;
102                        }
103                    }
104                    return Ok(result);
105                }
106                Err(err) => {
107                    if attempt < self.config.max_retries && self.should_retry_error_type(&err) {
108                        last_error = Some(err);
109                        sleep(delay).await;
110                        if self.config.exponential_backoff {
111                            delay = std::cmp::min(delay * 2, self.config.max_retry_delay);
112                        }
113                    } else {
114                        return Err(err);
115                    }
116                }
117            }
118        }
119
120        Err(last_error.unwrap_or_else(|| ToolError::ExecutionFailed {
121            source: "Maximum retries exceeded".to_string().into(),
122        }))
123    }
124
125    /// Execute multiple tools, potentially in parallel.
126    ///
127    /// # Arguments
128    /// * `tool_uses` - Vector of tool use requests
129    ///
130    /// # Returns
131    /// Vector of tool results in the same order as input.
132    pub async fn execute_multiple(
133        &self,
134        tool_uses: &[ToolUse],
135    ) -> Vec<ToolOperationResult<ToolResult>> {
136        if self.config.parallel_execution && tool_uses.len() > 1 {
137            self.execute_parallel_with_concurrency(tool_uses).await
138        } else {
139            let mut results = Vec::with_capacity(tool_uses.len());
140            for tool_use in tool_uses {
141                results.push(self.execute_with_retry(tool_use).await);
142            }
143            results
144        }
145    }
146
147    /// Execute tools in parallel with concurrency control.
148    async fn execute_parallel_with_concurrency(
149        &self,
150        tool_uses: &[ToolUse],
151    ) -> Vec<ToolOperationResult<ToolResult>> {
152        use futures::stream::{self, StreamExt};
153
154        // Use a semaphore to limit concurrent executions
155        let semaphore = Arc::new(tokio::sync::Semaphore::new(
156            self.config.max_concurrent_tools,
157        ));
158
159        let futures = tool_uses.iter().enumerate().map(|(index, tool_use)| {
160            let registry = self.registry.clone();
161            let semaphore = semaphore.clone();
162            let tool_use = tool_use.clone();
163            let config = self.config.clone();
164
165            async move {
166                let _permit = semaphore.acquire().await.unwrap();
167                let executor = ToolExecutor::with_config(registry, config);
168                (index, executor.execute_with_retry(&tool_use).await)
169            }
170        });
171
172        let mut results: Vec<(usize, ToolOperationResult<ToolResult>)> = stream::iter(futures)
173            .buffer_unordered(self.config.max_concurrent_tools)
174            .collect()
175            .await;
176
177        // Sort results by original index to maintain order
178        results.sort_by_key(|(index, _)| *index);
179        results.into_iter().map(|(_, result)| result).collect()
180    }
181
182    /// Extract tool use requests from a message.
183    ///
184    /// # Arguments
185    /// * `message` - Message from Claude that may contain tool use requests
186    ///
187    /// # Returns
188    /// Vector of tool use requests found in the message.
189    pub fn extract_tool_uses(&self, message: &Message) -> Vec<ToolUse> {
190        message
191            .content
192            .iter()
193            .filter_map(|block| {
194                if let ContentBlock::ToolUse { id, name, input } = block {
195                    Some(ToolUse {
196                        id: id.clone(),
197                        name: name.clone(),
198                        input: input.clone(),
199                    })
200                } else {
201                    None
202                }
203            })
204            .collect()
205    }
206
207    /// Check if a tool should be retried based on the error in the result.
208    fn should_retry_error(&self, _result: &ToolResult) -> bool {
209        // Add logic to determine if specific error types should be retried
210        // For now, we'll be conservative and not retry errors in results
211        false
212    }
213
214    /// Check if a tool execution error should be retried.
215    fn should_retry_error_type(&self, error: &ToolError) -> bool {
216        match error {
217            ToolError::ExecutionFailed { .. } => true,
218            ToolError::Timeout { .. } => true,
219            ToolError::ValidationFailed { .. } => false, // Don't retry validation errors
220            ToolError::NotFound { .. } => false,         // Don't retry missing tools
221            ToolError::RegistryError { .. } => false,    // Don't retry registry errors
222        }
223    }
224
225    /// Get the underlying tool registry.
226    pub fn registry(&self) -> &Arc<ToolRegistry> {
227        &self.registry
228    }
229
230    /// Get the current execution configuration.
231    pub fn config(&self) -> &ToolExecutionConfig {
232        &self.config
233    }
234
235    /// Update the execution configuration.
236    pub fn set_config(&mut self, config: ToolExecutionConfig) {
237        self.config = config;
238    }
239}
240
241/// Builder for creating tool execution configurations.
242pub struct ToolExecutionConfigBuilder {
243    config: ToolExecutionConfig,
244}
245
246impl ToolExecutionConfigBuilder {
247    /// Create a new configuration builder with defaults.
248    pub fn new() -> Self {
249        Self {
250            config: ToolExecutionConfig::default(),
251        }
252    }
253
254    /// Set the maximum number of retry attempts.
255    pub fn max_retries(mut self, max_retries: u32) -> Self {
256        self.config.max_retries = max_retries;
257        self
258    }
259
260    /// Set the base retry delay.
261    pub fn retry_delay(mut self, delay: Duration) -> Self {
262        self.config.retry_delay = delay;
263        self
264    }
265
266    /// Enable or disable exponential backoff.
267    pub fn exponential_backoff(mut self, enabled: bool) -> Self {
268        self.config.exponential_backoff = enabled;
269        self
270    }
271
272    /// Set the maximum retry delay for exponential backoff.
273    pub fn max_retry_delay(mut self, delay: Duration) -> Self {
274        self.config.max_retry_delay = delay;
275        self
276    }
277
278    /// Enable or disable parallel execution.
279    pub fn parallel_execution(mut self, enabled: bool) -> Self {
280        self.config.parallel_execution = enabled;
281        self
282    }
283
284    /// Set the maximum number of concurrent tool executions.
285    pub fn max_concurrent_tools(mut self, max: usize) -> Self {
286        self.config.max_concurrent_tools = max;
287        self
288    }
289
290    /// Build the configuration.
291    pub fn build(self) -> ToolExecutionConfig {
292        self.config
293    }
294}
295
296impl Default for ToolExecutionConfigBuilder {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolBuilder;
    use crate::tools::{ToolFunction, ToolRegistry};
    use crate::types::ToolResult;
    use async_trait::async_trait;
    use serde_json::{Value, json};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Tool that fails its first `fail_count` invocations and then succeeds.
    ///
    /// Every invocation is counted in `attempts`, so tests can assert exactly how
    /// many times the executor called the tool.
    struct TestRetryTool {
        attempts: Arc<AtomicUsize>,
        fail_count: usize,
    }

    #[async_trait]
    impl ToolFunction for TestRetryTool {
        async fn execute(
            &self,
            _input: Value,
        ) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
            let attempt = self.attempts.fetch_add(1, Ordering::SeqCst);
            if attempt < self.fail_count {
                Err("Simulated failure".into())
            } else {
                Ok(ToolResult::success(
                    "test_id",
                    format!("Success on attempt {}", attempt + 1),
                ))
            }
        }
    }

    /// Tool that sleeps for `delay` and records the peak number of invocations
    /// running at the same time.
    ///
    /// Measuring overlap proves parallelism without a wall-clock threshold, which is
    /// flaky on loaded CI machines.
    struct TestSlowTool {
        delay: Duration,
        in_flight: Arc<AtomicUsize>,
        peak: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl ToolFunction for TestSlowTool {
        async fn execute(
            &self,
            _input: Value,
        ) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
            let running = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak.fetch_max(running, Ordering::SeqCst);
            sleep(self.delay).await;
            self.in_flight.fetch_sub(1, Ordering::SeqCst);
            Ok(ToolResult::success("test_id", "Slow tool completed"))
        }
    }

    /// Build a tool-use request for `name` with an empty JSON object as input.
    fn tool_use(id: &str, name: &str) -> ToolUse {
        ToolUse {
            id: id.to_string(),
            name: name.to_string(),
            input: json!({}),
        }
    }

    /// Create a registry holding a single `TestRetryTool` registered as `name`.
    ///
    /// Returns the registry together with the tool's shared attempt counter.
    fn registry_with_retry_tool(
        name: &'static str,
        description: &'static str,
        fail_count: usize,
    ) -> (ToolRegistry, Arc<AtomicUsize>) {
        let mut registry = ToolRegistry::new();
        let attempts = Arc::new(AtomicUsize::new(0));
        registry
            .register(
                name,
                ToolBuilder::new(name, description).build(),
                Box::new(TestRetryTool {
                    attempts: Arc::clone(&attempts),
                    fail_count,
                }),
            )
            .unwrap();
        (registry, attempts)
    }

    /// Create a registry holding a single `TestSlowTool` registered as `slow_tool`.
    ///
    /// Returns the registry together with the tool's peak-concurrency counter.
    fn registry_with_slow_tool(delay: Duration) -> (ToolRegistry, Arc<AtomicUsize>) {
        let mut registry = ToolRegistry::new();
        let peak = Arc::new(AtomicUsize::new(0));
        registry
            .register(
                "slow_tool",
                ToolBuilder::new("slow_tool", "Slow tool for testing parallelism").build(),
                Box::new(TestSlowTool {
                    delay,
                    in_flight: Arc::new(AtomicUsize::new(0)),
                    peak: Arc::clone(&peak),
                }),
            )
            .unwrap();
        (registry, peak)
    }

    #[tokio::test]
    async fn test_successful_execution() {
        let (registry, attempts) = registry_with_retry_tool("test_tool", "Test tool", 0);
        let executor = ToolExecutor::new(Arc::new(registry));

        let result = executor
            .execute_with_retry(&tool_use("test_id", "test_tool"))
            .await
            .unwrap();
        if let crate::types::ToolResultContent::Text(content) = result.content {
            assert_eq!(content, "Success on attempt 1");
        } else {
            panic!("Expected text content");
        }
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_logic() {
        // Fail the first 2 attempts, then succeed.
        let (registry, attempts) =
            registry_with_retry_tool("retry_tool", "Tool that fails then succeeds", 2);
        let config = ToolExecutionConfigBuilder::new()
            .max_retries(3)
            .retry_delay(Duration::from_millis(10))
            .exponential_backoff(false)
            .build();
        let executor = ToolExecutor::with_config(Arc::new(registry), config);

        let result = executor
            .execute_with_retry(&tool_use("test_id", "retry_tool"))
            .await
            .unwrap();
        if let crate::types::ToolResultContent::Text(content) = result.content {
            assert_eq!(content, "Success on attempt 3");
        } else {
            panic!("Expected text content");
        }
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retries_exhausted() {
        let (registry, attempts) =
            registry_with_retry_tool("flaky_tool", "Tool that always fails", usize::MAX);
        let config = ToolExecutionConfigBuilder::new()
            .max_retries(2)
            .retry_delay(Duration::from_millis(1))
            .exponential_backoff(false)
            .build();
        let executor = ToolExecutor::with_config(Arc::new(registry), config);

        let result = executor
            .execute_with_retry(&tool_use("test_id", "flaky_tool"))
            .await;
        assert!(result.is_err());
        // One initial attempt plus two retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_parallel_execution() {
        let (registry, peak) = registry_with_slow_tool(Duration::from_millis(50));
        let config = ToolExecutionConfigBuilder::new()
            .parallel_execution(true)
            .max_concurrent_tools(3)
            .build();
        let executor = ToolExecutor::with_config(Arc::new(registry), config);
        let tool_uses = vec![
            tool_use("test_1", "slow_tool"),
            tool_use("test_2", "slow_tool"),
            tool_use("test_3", "slow_tool"),
        ];

        let results = executor.execute_multiple(&tool_uses).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|result| result.is_ok()));
        // Invocations overlapped, so the tools really ran in parallel.
        assert!(peak.load(Ordering::SeqCst) > 1);
    }

    #[tokio::test]
    async fn test_concurrency_limit() {
        let (registry, peak) = registry_with_slow_tool(Duration::from_millis(20));
        let config = ToolExecutionConfigBuilder::new()
            .parallel_execution(true)
            .max_concurrent_tools(2)
            .build();
        let executor = ToolExecutor::with_config(Arc::new(registry), config);
        let tool_uses: Vec<ToolUse> = (1..=4)
            .map(|i| tool_use(&format!("test_{}", i), "slow_tool"))
            .collect();

        let results = executor.execute_multiple(&tool_uses).await;

        assert_eq!(results.len(), 4);
        assert!(results.iter().all(|result| result.is_ok()));
        // Never more than `max_concurrent_tools` executions in flight.
        assert!(peak.load(Ordering::SeqCst) <= 2);
    }

    #[test]
    fn test_config_builder() {
        let config = ToolExecutionConfigBuilder::new()
            .max_retries(5)
            .retry_delay(Duration::from_millis(100))
            .exponential_backoff(true)
            .max_retry_delay(Duration::from_secs(5))
            .parallel_execution(false)
            .max_concurrent_tools(2)
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.retry_delay, Duration::from_millis(100));
        assert!(config.exponential_backoff);
        assert_eq!(config.max_retry_delay, Duration::from_secs(5));
        assert!(!config.parallel_execution);
        assert_eq!(config.max_concurrent_tools, 2);
    }
}
492