// agent_core/controller/tools/executor.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::Arc;
4
5use tokio::sync::{mpsc, Mutex};
6use tokio_util::sync::CancellationToken;
7
8use super::registry::ToolRegistry;
9use super::types::{ToolBatchResult, ToolContext, ToolRequest, ToolResult};
10use crate::controller::types::TurnId;
11
12/// Manages tool execution with support for parallel batch execution.
13pub struct ToolExecutor {
14    registry: Arc<ToolRegistry>,
15    tool_result_tx: mpsc::Sender<ToolResult>,
16    batch_result_tx: mpsc::Sender<ToolBatchResult>,
17    batch_counter: AtomicI64,
18}
19
20impl ToolExecutor {
21    /// Create a new tool executor.
22    ///
23    /// # Arguments
24    /// * `registry` - Tool registry for looking up tools
25    /// * `tool_result_tx` - Channel for individual tool results (UI feedback)
26    /// * `batch_result_tx` - Channel for batch results (sending to LLM)
27    pub fn new(
28        registry: Arc<ToolRegistry>,
29        tool_result_tx: mpsc::Sender<ToolResult>,
30        batch_result_tx: mpsc::Sender<ToolBatchResult>,
31    ) -> Self {
32        Self {
33            registry,
34            tool_result_tx,
35            batch_result_tx,
36            batch_counter: AtomicI64::new(0),
37        }
38    }
39
40    /// Execute a batch of tools in parallel.
41    ///
42    /// Results are emitted individually to tool_result_tx for UI feedback,
43    /// and the complete batch is sent to batch_result_tx when all tools finish.
44    ///
45    /// Returns the batch ID.
46    pub async fn execute_batch(
47        &self,
48        session_id: i64,
49        turn_id: Option<TurnId>,
50        requests: Vec<ToolRequest>,
51        cancel_token: CancellationToken,
52    ) -> i64 {
53        let batch_id = self.batch_counter.fetch_add(1, Ordering::SeqCst) + 1;
54        let expected_count = requests.len();
55
56        if expected_count == 0 {
57            // Empty batch - send empty result immediately
58            let batch_result = ToolBatchResult {
59                batch_id,
60                session_id,
61                turn_id,
62                results: Vec::new(),
63            };
64            let _ = self.batch_result_tx.send(batch_result).await;
65            return batch_id;
66        }
67
68        tracing::debug!(
69            batch_id,
70            session_id,
71            tool_count = expected_count,
72            "Starting tool batch execution"
73        );
74
75        // Create batch state
76        let batch = Arc::new(ToolExecutorBatch {
77            batch_id,
78            session_id,
79            turn_id: turn_id.clone(),
80            tool_result_tx: self.tool_result_tx.clone(),
81            batch_result_tx: self.batch_result_tx.clone(),
82            requests: requests.clone(),
83            results: Mutex::new(HashMap::new()),
84            expected_count,
85        });
86
87        // Start all tools concurrently
88        for request in requests {
89            let batch = batch.clone();
90            let registry = self.registry.clone();
91            let cancel = cancel_token.clone();
92            let turn_id = turn_id.clone();
93
94            tokio::spawn(async move {
95                batch
96                    .run_tool(registry, request, turn_id, cancel)
97                    .await;
98            });
99        }
100
101        batch_id
102    }
103
104    /// Execute a single tool (convenience method that creates a batch of 1).
105    pub async fn execute(
106        &self,
107        session_id: i64,
108        turn_id: Option<TurnId>,
109        request: ToolRequest,
110        cancel_token: CancellationToken,
111    ) -> i64 {
112        self.execute_batch(session_id, turn_id, vec![request], cancel_token)
113            .await
114    }
115}
116
117/// Internal batch state for tracking parallel tool executions.
118struct ToolExecutorBatch {
119    batch_id: i64,
120    session_id: i64,
121    turn_id: Option<TurnId>,
122    tool_result_tx: mpsc::Sender<ToolResult>,
123    batch_result_tx: mpsc::Sender<ToolBatchResult>,
124    requests: Vec<ToolRequest>,
125    results: Mutex<HashMap<String, ToolResult>>,
126    expected_count: usize,
127}
128
129impl ToolExecutorBatch {
130    /// Run a single tool and add result to the batch.
131    async fn run_tool(
132        &self,
133        registry: Arc<ToolRegistry>,
134        request: ToolRequest,
135        turn_id: Option<TurnId>,
136        cancel_token: CancellationToken,
137    ) {
138        let tool_use_id = request.tool_use_id.clone();
139        let tool_name = request.tool_name.clone();
140        let input = request.input.clone();
141
142        tracing::debug!(
143            batch_id = self.batch_id,
144            session_id = self.session_id,
145            tool_name = %tool_name,
146            tool_use_id = %tool_use_id,
147            "Starting tool execution"
148        );
149
150        // Look up tool in registry
151        let tool = registry.get(&tool_name).await;
152
153        let result = match tool {
154            None => {
155                // Tool not found
156                tracing::warn!(
157                    batch_id = self.batch_id,
158                    tool_name = %tool_name,
159                    "Tool not found in registry"
160                );
161                ToolResult::error(
162                    self.session_id,
163                    tool_name,
164                    tool_use_id,
165                    input,
166                    format!("Tool not found: {}", request.tool_name),
167                    turn_id,
168                )
169            }
170            Some(tool) => {
171                // Get display name from tool's display config
172                let display_name = Some(tool.display_config().display_name);
173
174                // Build tool context
175                let context = ToolContext {
176                    session_id: self.session_id,
177                    tool_use_id: tool_use_id.clone(),
178                    turn_id: turn_id.clone(),
179                };
180
181                // Execute tool with cancellation support
182                tokio::select! {
183                    exec_result = tool.execute(context, input.clone()) => {
184                        match exec_result {
185                            Ok(content) => {
186                                tracing::info!(
187                                    batch_id = self.batch_id,
188                                    tool_name = %tool_name,
189                                    result_bytes = content.len(),
190                                    "Tool execution succeeded"
191                                );
192                                // Compute compact summary for compaction
193                                let compact_summary = Some(tool.compact_summary(&input, &content));
194                                ToolResult::success(
195                                    self.session_id,
196                                    tool_name,
197                                    display_name,
198                                    tool_use_id,
199                                    input,
200                                    content,
201                                    turn_id,
202                                    compact_summary,
203                                )
204                            }
205                            Err(error) => {
206                                tracing::warn!(
207                                    batch_id = self.batch_id,
208                                    tool_name = %tool_name,
209                                    error = %error,
210                                    "Tool execution failed"
211                                );
212                                ToolResult::error(
213                                    self.session_id,
214                                    tool_name,
215                                    tool_use_id,
216                                    input,
217                                    error,
218                                    turn_id,
219                                )
220                            }
221                        }
222                    }
223                    _ = cancel_token.cancelled() => {
224                        tracing::warn!(
225                            batch_id = self.batch_id,
226                            tool_name = %tool_name,
227                            "Tool execution cancelled"
228                        );
229                        ToolResult::timeout(
230                            self.session_id,
231                            tool_name,
232                            tool_use_id,
233                            input,
234                            turn_id,
235                        )
236                    }
237                }
238            }
239        };
240
241        self.add_result(result).await;
242    }
243
244    /// Add a result to the batch and check for completion.
245    async fn add_result(&self, result: ToolResult) {
246        // Send individual result for UI feedback
247        let _ = self.tool_result_tx.send(result.clone()).await;
248
249        let mut results = self.results.lock().await;
250        results.insert(result.tool_use_id.clone(), result);
251
252        tracing::debug!(
253            batch_id = self.batch_id,
254            completed = results.len(),
255            expected = self.expected_count,
256            "Tool completed in batch"
257        );
258
259        // Check if all tools have completed
260        if results.len() == self.expected_count {
261            self.send_batch_result(&results).await;
262        }
263    }
264
265    /// Send the complete batch result.
266    async fn send_batch_result(&self, results: &HashMap<String, ToolResult>) {
267        // Build results in original request order
268        let ordered_results: Vec<ToolResult> = self
269            .requests
270            .iter()
271            .filter_map(|req| results.get(&req.tool_use_id).cloned())
272            .collect();
273
274        let batch_result = ToolBatchResult {
275            batch_id: self.batch_id,
276            session_id: self.session_id,
277            turn_id: self.turn_id.clone(),
278            results: ordered_results,
279        };
280
281        tracing::debug!(
282            batch_id = self.batch_id,
283            session_id = self.session_id,
284            result_count = batch_result.results.len(),
285            "Sending batch result"
286        );
287
288        let _ = self.batch_result_tx.send(batch_result).await;
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{Executable, ToolResultStatus, ToolType};
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Upper bound on any wait for a channel message. If a regression stops
    /// results from being sent, the test fails instead of hanging.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// Receive the next message from `rx`; fail on timeout or channel closure.
    async fn recv_within<T>(rx: &mut mpsc::Receiver<T>) -> T {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for a channel message")
            .expect("channel closed before a message arrived")
    }

    /// Tool that echoes its `message` input back as `Echo: <message>`.
    struct EchoTool;

    impl Executable for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes input back"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object","properties":{"message":{"type":"string"}}}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message")
                .to_string();
            Box::pin(async move { Ok(format!("Echo: {}", message)) })
        }
    }

    /// Tool that sleeps for 10 seconds before succeeding; used to exercise cancellation.
    struct SlowTool;

    impl Executable for SlowTool {
        fn name(&self) -> &str {
            "slow"
        }

        fn description(&self) -> &str {
            "A slow tool for testing timeouts"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok("done".to_string())
            })
        }
    }

    #[tokio::test]
    async fn test_execute_single_tool() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let mut input = HashMap::new();
        input.insert(
            "message".to_string(),
            serde_json::Value::String("hello".to_string()),
        );

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "echo".to_string(),
            input,
        };

        let cancel = CancellationToken::new();
        let batch_id = executor.execute(1, None, request, cancel).await;
        assert_eq!(batch_id, 1, "first batch on a fresh executor gets ID 1");

        // Wait for individual result
        let result = recv_within(&mut tool_rx).await;
        assert_eq!(result.status, ToolResultStatus::Success);
        assert!(result.content.contains("Echo: hello"));

        // Wait for batch result
        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.batch_id, batch_id);
        assert_eq!(batch.results.len(), 1);
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let requests: Vec<ToolRequest> = (0..3)
            .map(|i| {
                let mut input = HashMap::new();
                input.insert(
                    "message".to_string(),
                    serde_json::Value::String(format!("msg_{}", i)),
                );
                ToolRequest {
                    tool_use_id: format!("tool_{}", i),
                    tool_name: "echo".to_string(),
                    input,
                }
            })
            .collect();

        let cancel = CancellationToken::new();
        executor.execute_batch(1, None, requests, cancel).await;

        // Collect individual results
        for _ in 0..3 {
            let result = recv_within(&mut tool_rx).await;
            assert_eq!(result.status, ToolResultStatus::Success);
        }

        // The batch result must follow the original request order, regardless
        // of the order in which the tools finished.
        let batch = recv_within(&mut batch_rx).await;
        let ids: Vec<&str> = batch
            .results
            .iter()
            .map(|r| r.tool_use_id.as_str())
            .collect();
        assert_eq!(ids, ["tool_0", "tool_1", "tool_2"]);
    }

    #[tokio::test]
    async fn test_empty_batch_reports_immediately() {
        let registry = Arc::new(ToolRegistry::new());

        let (tool_tx, _tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let batch_id = executor
            .execute_batch(7, None, Vec::new(), CancellationToken::new())
            .await;

        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.batch_id, batch_id);
        assert_eq!(batch.session_id, 7);
        assert!(batch.results.is_empty());
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let registry = Arc::new(ToolRegistry::new());

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "nonexistent".to_string(),
            input: HashMap::new(),
        };

        let cancel = CancellationToken::new();
        executor.execute(1, None, request, cancel).await;

        let result = recv_within(&mut tool_rx).await;
        assert_eq!(result.status, ToolResultStatus::Error);
        assert!(result.error.unwrap().contains("not found"));

        // An unknown tool still completes its batch.
        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.results.len(), 1);
    }

    #[tokio::test]
    async fn test_cancelled_tool_completes_batch() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(SlowTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let request = ToolRequest {
            tool_use_id: "slow_1".to_string(),
            tool_name: "slow".to_string(),
            input: HashMap::new(),
        };

        // Cancel up front: the 10s tool must not be awaited to completion.
        let cancel = CancellationToken::new();
        cancel.cancel();
        executor.execute(1, None, request, cancel).await;

        let result = recv_within(&mut tool_rx).await;
        assert_eq!(result.tool_use_id, "slow_1");
        assert_ne!(result.status, ToolResultStatus::Success);

        let batch = recv_within(&mut batch_rx).await;
        assert_eq!(batch.results.len(), 1);
    }
}