agent_core/controller/tools/
executor.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::Arc;
4
5use tokio::sync::{mpsc, Mutex};
6use tokio_util::sync::CancellationToken;
7
8use super::registry::ToolRegistry;
9use super::types::{ToolBatchResult, ToolContext, ToolRequest, ToolResult};
10use crate::controller::types::TurnId;
11
12/// Manages tool execution with support for parallel batch execution.
13pub struct ToolExecutor {
14    registry: Arc<ToolRegistry>,
15    tool_result_tx: mpsc::Sender<ToolResult>,
16    batch_result_tx: mpsc::Sender<ToolBatchResult>,
17    batch_counter: AtomicI64,
18}
19
20impl ToolExecutor {
21    /// Create a new tool executor.
22    ///
23    /// # Arguments
24    /// * `registry` - Tool registry for looking up tools
25    /// * `tool_result_tx` - Channel for individual tool results (UI feedback)
26    /// * `batch_result_tx` - Channel for batch results (sending to LLM)
27    pub fn new(
28        registry: Arc<ToolRegistry>,
29        tool_result_tx: mpsc::Sender<ToolResult>,
30        batch_result_tx: mpsc::Sender<ToolBatchResult>,
31    ) -> Self {
32        Self {
33            registry,
34            tool_result_tx,
35            batch_result_tx,
36            batch_counter: AtomicI64::new(0),
37        }
38    }
39
40    /// Execute a batch of tools in parallel.
41    ///
42    /// Results are emitted individually to tool_result_tx for UI feedback,
43    /// and the complete batch is sent to batch_result_tx when all tools finish.
44    ///
45    /// Returns the batch ID.
46    pub async fn execute_batch(
47        &self,
48        session_id: i64,
49        turn_id: Option<TurnId>,
50        requests: Vec<ToolRequest>,
51        cancel_token: CancellationToken,
52    ) -> i64 {
53        let batch_id = self.batch_counter.fetch_add(1, Ordering::SeqCst) + 1;
54        let expected_count = requests.len();
55
56        if expected_count == 0 {
57            // Empty batch - send empty result immediately
58            let batch_result = ToolBatchResult {
59                batch_id,
60                session_id,
61                turn_id,
62                results: Vec::new(),
63            };
64            if let Err(e) = self.batch_result_tx.send(batch_result).await {
65                tracing::debug!("Failed to send empty batch result: {}", e);
66            }
67            return batch_id;
68        }
69
70        tracing::debug!(
71            batch_id,
72            session_id,
73            tool_count = expected_count,
74            "Starting tool batch execution"
75        );
76
77        // Create batch state
78        let batch = Arc::new(ToolExecutorBatch {
79            batch_id,
80            session_id,
81            turn_id: turn_id.clone(),
82            tool_result_tx: self.tool_result_tx.clone(),
83            batch_result_tx: self.batch_result_tx.clone(),
84            requests: requests.clone(),
85            results: Mutex::new(HashMap::new()),
86            expected_count,
87        });
88
89        // Start all tools concurrently
90        for request in requests {
91            let batch = batch.clone();
92            let registry = self.registry.clone();
93            let cancel = cancel_token.clone();
94            let turn_id = turn_id.clone();
95
96            tokio::spawn(async move {
97                batch
98                    .run_tool(registry, request, turn_id, cancel)
99                    .await;
100            });
101        }
102
103        batch_id
104    }
105
106    /// Execute a single tool (convenience method that creates a batch of 1).
107    pub async fn execute(
108        &self,
109        session_id: i64,
110        turn_id: Option<TurnId>,
111        request: ToolRequest,
112        cancel_token: CancellationToken,
113    ) -> i64 {
114        self.execute_batch(session_id, turn_id, vec![request], cancel_token)
115            .await
116    }
117}
118
119/// Internal batch state for tracking parallel tool executions.
120struct ToolExecutorBatch {
121    batch_id: i64,
122    session_id: i64,
123    turn_id: Option<TurnId>,
124    tool_result_tx: mpsc::Sender<ToolResult>,
125    batch_result_tx: mpsc::Sender<ToolBatchResult>,
126    requests: Vec<ToolRequest>,
127    results: Mutex<HashMap<String, ToolResult>>,
128    expected_count: usize,
129}
130
131impl ToolExecutorBatch {
132    /// Run a single tool and add result to the batch.
133    async fn run_tool(
134        &self,
135        registry: Arc<ToolRegistry>,
136        request: ToolRequest,
137        turn_id: Option<TurnId>,
138        cancel_token: CancellationToken,
139    ) {
140        let tool_use_id = request.tool_use_id.clone();
141        let tool_name = request.tool_name.clone();
142        let input = request.input.clone();
143
144        tracing::debug!(
145            batch_id = self.batch_id,
146            session_id = self.session_id,
147            tool_name = %tool_name,
148            tool_use_id = %tool_use_id,
149            "Starting tool execution"
150        );
151
152        // Look up tool in registry
153        let tool = registry.get(&tool_name).await;
154
155        let result = match tool {
156            None => {
157                // Tool not found
158                tracing::warn!(
159                    batch_id = self.batch_id,
160                    tool_name = %tool_name,
161                    "Tool not found in registry"
162                );
163                ToolResult::error(
164                    self.session_id,
165                    tool_name,
166                    tool_use_id,
167                    input,
168                    format!("Tool not found: {}", request.tool_name),
169                    turn_id,
170                )
171            }
172            Some(tool) => {
173                // Get display name from tool's display config
174                let display_name = Some(tool.display_config().display_name);
175
176                // Build tool context
177                let context = ToolContext {
178                    session_id: self.session_id,
179                    tool_use_id: tool_use_id.clone(),
180                    turn_id: turn_id.clone(),
181                };
182
183                // Execute tool with cancellation support
184                tokio::select! {
185                    exec_result = tool.execute(context, input.clone()) => {
186                        match exec_result {
187                            Ok(content) => {
188                                tracing::info!(
189                                    batch_id = self.batch_id,
190                                    tool_name = %tool_name,
191                                    result_bytes = content.len(),
192                                    "Tool execution succeeded"
193                                );
194                                // Compute compact summary for compaction
195                                let compact_summary = Some(tool.compact_summary(&input, &content));
196                                ToolResult::success(
197                                    self.session_id,
198                                    tool_name,
199                                    display_name,
200                                    tool_use_id,
201                                    input,
202                                    content,
203                                    turn_id,
204                                    compact_summary,
205                                )
206                            }
207                            Err(error) => {
208                                tracing::warn!(
209                                    batch_id = self.batch_id,
210                                    tool_name = %tool_name,
211                                    error = %error,
212                                    "Tool execution failed"
213                                );
214                                ToolResult::error(
215                                    self.session_id,
216                                    tool_name,
217                                    tool_use_id,
218                                    input,
219                                    error,
220                                    turn_id,
221                                )
222                            }
223                        }
224                    }
225                    _ = cancel_token.cancelled() => {
226                        tracing::warn!(
227                            batch_id = self.batch_id,
228                            tool_name = %tool_name,
229                            "Tool execution cancelled"
230                        );
231                        ToolResult::timeout(
232                            self.session_id,
233                            tool_name,
234                            tool_use_id,
235                            input,
236                            turn_id,
237                        )
238                    }
239                }
240            }
241        };
242
243        self.add_result(result).await;
244    }
245
246    /// Add a result to the batch and check for completion.
247    async fn add_result(&self, result: ToolResult) {
248        // Send individual result for UI feedback
249        if let Err(e) = self.tool_result_tx.send(result.clone()).await {
250            tracing::debug!("Failed to send tool result: {}", e);
251        }
252
253        let mut results = self.results.lock().await;
254        results.insert(result.tool_use_id.clone(), result);
255
256        tracing::debug!(
257            batch_id = self.batch_id,
258            completed = results.len(),
259            expected = self.expected_count,
260            "Tool completed in batch"
261        );
262
263        // Check if all tools have completed
264        if results.len() == self.expected_count {
265            self.send_batch_result(&results).await;
266        }
267    }
268
269    /// Send the complete batch result.
270    async fn send_batch_result(&self, results: &HashMap<String, ToolResult>) {
271        // Build results in original request order
272        let ordered_results: Vec<ToolResult> = self
273            .requests
274            .iter()
275            .filter_map(|req| results.get(&req.tool_use_id).cloned())
276            .collect();
277
278        let batch_result = ToolBatchResult {
279            batch_id: self.batch_id,
280            session_id: self.session_id,
281            turn_id: self.turn_id.clone(),
282            results: ordered_results,
283        };
284
285        tracing::debug!(
286            batch_id = self.batch_id,
287            session_id = self.session_id,
288            result_count = batch_result.results.len(),
289            "Sending batch result"
290        );
291
292        if let Err(e) = self.batch_result_tx.send(batch_result).await {
293            tracing::debug!("Failed to send batch result: {}", e);
294        }
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{Executable, ToolResultStatus, ToolType};
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Upper bound on any single channel receive. A regression that loses a result should
    /// fail the test with a clear message rather than hang the whole test run.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// Receive the next message from `rx`. Panics if nothing arrives within
    /// `RECV_TIMEOUT` or if every sender has been dropped.
    async fn recv_within<T>(rx: &mut mpsc::Receiver<T>) -> T {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for a channel message")
            .expect("channel closed before a message arrived")
    }

    /// Build a request for the `echo` tool carrying `message`.
    fn echo_request(tool_use_id: &str, message: &str) -> ToolRequest {
        let mut input = HashMap::new();
        input.insert(
            "message".to_string(),
            serde_json::Value::String(message.to_string()),
        );
        ToolRequest {
            tool_use_id: tool_use_id.to_string(),
            tool_name: "echo".to_string(),
            input,
        }
    }

    /// Tool that answers immediately with `Echo: <message>`.
    struct EchoTool;

    impl Executable for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes input back"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object","properties":{"message":{"type":"string"}}}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message")
                .to_string();
            Box::pin(async move { Ok(format!("Echo: {}", message)) })
        }
    }

    /// Tool that sleeps for ten seconds; used to exercise cancellation.
    struct SlowTool;

    impl Executable for SlowTool {
        fn name(&self) -> &str {
            "slow"
        }

        fn description(&self) -> &str {
            "A slow tool for testing timeouts"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok("done".to_string())
            })
        }
    }

    #[tokio::test]
    async fn test_execute_single_tool() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let cancel = CancellationToken::new();
        executor
            .execute(1, None, echo_request("test_1", "hello"), cancel)
            .await;

        // Individual result first...
        let result = recv_within(&mut tool_rx).await;
        assert_eq!(result.status, ToolResultStatus::Success);
        assert!(result.content.contains("Echo: hello"));

        // ...then the batch containing it.
        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.results.len(), 1);
        assert_eq!(batch.results[0].tool_use_id, "test_1");
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let requests: Vec<ToolRequest> = (0..3)
            .map(|i| echo_request(&format!("tool_{}", i), &format!("msg_{}", i)))
            .collect();

        let cancel = CancellationToken::new();
        executor.execute_batch(1, None, requests, cancel).await;

        // Individual results may arrive in any order.
        for _ in 0..3 {
            let result = recv_within(&mut tool_rx).await;
            assert_eq!(result.status, ToolResultStatus::Success);
        }

        // The batch result must preserve the original request order.
        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.results.len(), 3);
        let ids: Vec<&str> = batch
            .results
            .iter()
            .map(|r| r.tool_use_id.as_str())
            .collect();
        assert_eq!(ids, ["tool_0", "tool_1", "tool_2"]);
        for (i, result) in batch.results.iter().enumerate() {
            assert!(result.content.contains(&format!("msg_{}", i)));
        }
    }

    #[tokio::test]
    async fn test_empty_batch() {
        let registry = Arc::new(ToolRegistry::new());

        let (tool_tx, _tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let first = executor
            .execute_batch(7, None, Vec::new(), CancellationToken::new())
            .await;
        let second = executor
            .execute_batch(7, None, Vec::new(), CancellationToken::new())
            .await;
        assert_eq!((first, second), (1, 2));

        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.batch_id, first);
        assert_eq!(batch.session_id, 7);
        assert!(batch.results.is_empty());
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let registry = Arc::new(ToolRegistry::new());

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "nonexistent".to_string(),
            input: HashMap::new(),
        };

        let cancel = CancellationToken::new();
        executor.execute(1, None, request, cancel).await;

        let result = recv_within(&mut tool_rx).await;
        assert_eq!(result.status, ToolResultStatus::Error);
        assert!(result.error.unwrap().contains("not found"));

        // A missing tool must still complete its batch.
        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.results.len(), 1);
        assert_eq!(batch.results[0].status, ToolResultStatus::Error);
    }

    #[tokio::test]
    async fn test_tool_cancellation() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(SlowTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "slow".to_string(),
            input: HashMap::new(),
        };

        let cancel = CancellationToken::new();
        let cancel_clone = cancel.clone();

        // Start execution
        executor.execute(1, None, request, cancel).await;

        // Cancel after a short delay
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });

        // Cancellation is reported with the timeout status.
        let result = recv_within(&mut tool_rx).await;
        assert_eq!(result.status, ToolResultStatus::Timeout);

        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.results.len(), 1);
        assert_eq!(batch.results[0].status, ToolResultStatus::Timeout);
    }
}