Skip to main content

limit_agent/
executor.rs

1use crate::error::AgentError;
2use crate::registry::ToolRegistry;
3use serde_json::Value;
4use std::sync::Arc;
5use std::time::Duration;
6use tracing::instrument;
7
8/// Represents a single tool call
9#[derive(Debug, Clone)]
10pub struct ToolCall {
11    pub id: String,
12    pub name: String,
13    pub args: Value,
14}
15
16impl ToolCall {
17    pub fn new(id: impl Into<String>, name: impl Into<String>, args: Value) -> Self {
18        Self {
19            id: id.into(),
20            name: name.into(),
21            args,
22        }
23    }
24}
25
26/// Represents the result of a tool execution
27#[derive(Debug, Clone)]
28pub struct ToolResult {
29    pub call_id: String,
30    pub output: Result<Value, AgentError>,
31}
32
33/// Executor for tool calls with parallel/conditional execution
34pub struct ToolExecutor {
35    registry: Arc<ToolRegistry>,
36    max_concurrent: usize,
37    timeout: Duration,
38}
39
40impl ToolExecutor {
41    /// Create a new ToolExecutor with default settings
42    pub fn new(registry: ToolRegistry) -> Self {
43        Self {
44            registry: Arc::new(registry),
45            max_concurrent: 5,
46            timeout: Duration::from_secs(60),
47        }
48    }
49
50    /// Set maximum concurrent tool executions
51    pub fn with_max_concurrent(mut self, max: usize) -> Self {
52        self.max_concurrent = max;
53        self
54    }
55
56    /// Set timeout for each tool execution
57    pub fn with_timeout(mut self, timeout: Duration) -> Self {
58        self.timeout = timeout;
59        self
60    }
61
62    /// Execute multiple tool calls with conditional parallel/sequential logic
63    #[instrument(skip(self, calls))]
64    pub async fn execute_tools(&self, calls: Vec<ToolCall>) -> Vec<ToolResult> {
65        if calls.is_empty() {
66            return Vec::new();
67        }
68
69        // Analyze dependencies and categorize calls
70        let (independent, dependent) = self.categorize_calls(&calls);
71
72        // Execute independent calls in parallel (with concurrency limit)
73        let mut results = self.execute_parallel(independent).await;
74
75        // Execute dependent calls sequentially
76        results.extend(self.execute_sequential(dependent).await);
77
78        results
79    }
80
81    /// Categorize tool calls into independent and dependent groups
82    fn categorize_calls(&self, calls: &[ToolCall]) -> (Vec<ToolCall>, Vec<ToolCall>) {
83        let mut independent = Vec::new();
84        let mut dependent = Vec::new();
85
86        for call in calls {
87            if self.has_dependencies(call) {
88                dependent.push(call.clone());
89            } else {
90                independent.push(call.clone());
91            }
92        }
93
94        (independent, dependent)
95    }
96
97    /// Check if a tool call has dependencies on previous calls
98    fn has_dependencies(&self, call: &ToolCall) -> bool {
99        // Simple heuristic: if args contain references to variable patterns like $var or $output_N
100        let args_str = serde_json::to_string(&call.args).unwrap_or_default();
101
102        // Look for dependency markers in arguments
103        args_str.contains("$output_") || args_str.contains("$var_") || args_str.contains("$result_")
104    }
105
106    /// Execute independent tool calls in parallel with concurrency limit
107    #[instrument(skip(self, calls))]
108    async fn execute_parallel(&self, calls: Vec<ToolCall>) -> Vec<ToolResult> {
109        if calls.is_empty() {
110            return Vec::new();
111        }
112
113        use futures::stream::{self, StreamExt};
114
115        stream::iter(calls)
116            .map(|call| {
117                let registry = Arc::clone(&self.registry);
118                let timeout = self.timeout;
119                async move {
120                    let result = tokio::time::timeout(
121                        timeout,
122                        registry.execute(&call.name, call.args.clone()),
123                    )
124                    .await;
125
126                    ToolResult {
127                        call_id: call.id,
128                        output: result.unwrap_or(Err(AgentError::ToolError(
129                            "Tool execution timed out".to_string(),
130                        ))),
131                    }
132                }
133            })
134            .buffer_unordered(self.max_concurrent)
135            .collect()
136            .await
137    }
138
139    /// Execute dependent tool calls sequentially
140    #[instrument(skip(self, calls))]
141    async fn execute_sequential(&self, calls: Vec<ToolCall>) -> Vec<ToolResult> {
142        let mut results = Vec::new();
143
144        for call in calls {
145            let result = tokio::time::timeout(
146                self.timeout,
147                self.registry.execute(&call.name, call.args.clone()),
148            )
149            .await;
150
151            results.push(ToolResult {
152                call_id: call.id,
153                output: result.unwrap_or(Err(AgentError::ToolError(
154                    "Tool execution timed out".to_string(),
155                ))),
156            });
157        }
158
159        results
160    }
161}
162
163impl Clone for ToolExecutor {
164    fn clone(&self) -> Self {
165        Self {
166            registry: Arc::clone(&self.registry),
167            max_concurrent: self.max_concurrent,
168            timeout: self.timeout,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::EchoTool;
    use async_trait::async_trait;
    use serde_json::json;

    #[test]
    fn test_tool_call_new() {
        let call = ToolCall::new("1", "echo", json!({"test": "value"}));
        assert_eq!(call.id, "1");
        assert_eq!(call.name, "echo");
        assert_eq!(call.args, json!({"test": "value"}));
    }

    #[test]
    fn test_executor_new() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        assert_eq!(executor.max_concurrent, 5);
        assert_eq!(executor.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_executor_with_config() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry)
            .with_max_concurrent(10)
            .with_timeout(Duration::from_secs(30));

        assert_eq!(executor.max_concurrent, 10);
        assert_eq!(executor.timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_execute_tools_empty() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);
        let results = executor.execute_tools(vec![]).await;

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_execute_tools_single() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let executor = ToolExecutor::new(registry);
        let call = ToolCall::new("1", "echo", json!({"test": "value"}));
        let results = executor.execute_tools(vec![call]).await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].call_id, "1");
        assert!(results[0].output.is_ok());
    }

    #[tokio::test]
    async fn test_execute_tools_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let executor = ToolExecutor::new(registry);

        // Multiple calls with no dependency markers: all take the parallel path.
        let calls = vec![
            ToolCall::new("1", "echo", json!({"id": 1})),
            ToolCall::new("2", "echo", json!({"id": 2})),
            ToolCall::new("3", "echo", json!({"id": 3})),
        ];

        let results = executor.execute_tools(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.output.is_ok()));
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_with_dependencies() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let executor = ToolExecutor::new(registry);

        // The second call references `$output_1`, so it is classified as dependent.
        let calls = vec![
            ToolCall::new("1", "echo", json!({"id": 1})),
            ToolCall::new("2", "echo", json!({"input": "$output_1", "id": 2})),
        ];

        let results = executor.execute_tools(calls).await;

        assert_eq!(results.len(), 2);
        // Both should complete successfully.
        assert!(results[0].output.is_ok());
        assert!(results[1].output.is_ok());
    }

    #[tokio::test]
    async fn test_execute_tools_timeout() {
        /// Tool that sleeps far longer than the executor's timeout.
        struct SlowTool;

        #[async_trait]
        impl crate::tool::Tool for SlowTool {
            fn name(&self) -> &str {
                "slow"
            }

            async fn execute(&self, _args: Value) -> Result<Value, AgentError> {
                tokio::time::sleep(Duration::from_secs(2)).await;
                Ok(json!({"status": "done"}))
            }
        }

        let mut registry = ToolRegistry::new();
        registry.register(SlowTool).unwrap();

        let executor = ToolExecutor::new(registry).with_timeout(Duration::from_millis(100));
        let call = ToolCall::new("1", "slow", json!({}));
        let results = executor.execute_tools(vec![call]).await;

        assert_eq!(results.len(), 1);
        assert!(results[0].output.is_err());
        assert!(matches!(
            results[0].output.as_ref().unwrap_err(),
            AgentError::ToolError(_)
        ));
    }

    #[tokio::test]
    async fn test_execute_tools_tool_not_found() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let call = ToolCall::new("1", "nonexistent", json!({}));
        let results = executor.execute_tools(vec![call]).await;

        assert_eq!(results.len(), 1);
        assert!(results[0].output.is_err());
    }

    #[test]
    fn test_categorize_calls_no_dependencies() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ToolCall::new("1", "echo", json!({"id": 1})),
            ToolCall::new("2", "echo", json!({"id": 2})),
        ];

        let (independent, dependent) = executor.categorize_calls(&calls);

        assert_eq!(independent.len(), 2);
        assert!(dependent.is_empty());
    }

    #[test]
    fn test_categorize_calls_with_dependencies() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ToolCall::new("1", "echo", json!({"id": 1})),
            ToolCall::new("2", "echo", json!({"input": "$output_1", "id": 2})),
            ToolCall::new("3", "echo", json!({"id": 3})),
        ];

        let (independent, dependent) = executor.categorize_calls(&calls);

        let independent_ids: Vec<&str> = independent.iter().map(|c| c.id.as_str()).collect();
        let dependent_ids: Vec<&str> = dependent.iter().map(|c| c.id.as_str()).collect();
        assert_eq!(independent_ids, ["1", "3"]);
        assert_eq!(dependent_ids, ["2"]);
    }

    #[tokio::test]
    async fn test_parallel_with_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Instant;

        /// Tool that records how many of its executions overlap in time.
        struct ConcurrentTool {
            active: Arc<AtomicUsize>,
            peak: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl crate::tool::Tool for ConcurrentTool {
            fn name(&self) -> &str {
                "concurrent"
            }

            async fn execute(&self, _args: Value) -> Result<Value, AgentError> {
                let now_active = self.active.fetch_add(1, Ordering::SeqCst) + 1;
                self.peak.fetch_max(now_active, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(100)).await;
                self.active.fetch_sub(1, Ordering::SeqCst);
                Ok(json!({"active": now_active}))
            }
        }

        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let mut registry = ToolRegistry::new();
        registry
            .register(ConcurrentTool {
                active: Arc::clone(&active),
                peak: Arc::clone(&peak),
            })
            .unwrap();

        let executor = ToolExecutor::new(registry).with_max_concurrent(2);

        let calls: Vec<ToolCall> = (0..5)
            .map(|i| ToolCall::new(i.to_string(), "concurrent", json!({})))
            .collect();

        let start = Instant::now();
        let results = executor.execute_tools(calls).await;
        let duration = start.elapsed();

        assert_eq!(results.len(), 5);
        assert!(results.iter().all(|r| r.output.is_ok()));

        // The limit is respected AND actually used: exactly two calls overlapped.
        assert_eq!(peak.load(Ordering::SeqCst), 2);
        assert_eq!(active.load(Ordering::SeqCst), 0);

        // 5 calls x 100ms spread over at most 2 lanes cannot finish in under 250ms.
        // (In practice it is ~300ms: three waves.) No upper bound is asserted because
        // wall-clock ceilings are flaky on loaded CI machines; the peak check above
        // already proves the calls did not run serially.
        assert!(duration >= Duration::from_millis(250));
    }
}