Skip to main content

agent_core_runtime/controller/tools/
executor.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::{mpsc, Mutex};
7use tokio::task::JoinHandle;
8use tokio_util::sync::CancellationToken;
9
10use super::registry::ToolRegistry;
11use super::types::{ToolBatchResult, ToolContext, ToolRequest, ToolResult};
12use crate::controller::types::TurnId;
13use crate::permissions::{PermissionRegistry, PermissionRequest};
14
15/// Manages tool execution with support for parallel batch execution.
16pub struct ToolExecutor {
17    registry: Arc<ToolRegistry>,
18    permission_registry: Arc<PermissionRegistry>,
19    tool_result_tx: mpsc::Sender<ToolResult>,
20    batch_result_tx: mpsc::Sender<ToolBatchResult>,
21    batch_counter: AtomicI64,
22}
23
24impl ToolExecutor {
25    /// Create a new tool executor.
26    ///
27    /// # Arguments
28    /// * `registry` - Tool registry for looking up tools
29    /// * `permission_registry` - Permission registry for batch permission requests
30    /// * `tool_result_tx` - Channel for individual tool results (UI feedback)
31    /// * `batch_result_tx` - Channel for batch results (sending to LLM)
32    pub fn new(
33        registry: Arc<ToolRegistry>,
34        permission_registry: Arc<PermissionRegistry>,
35        tool_result_tx: mpsc::Sender<ToolResult>,
36        batch_result_tx: mpsc::Sender<ToolBatchResult>,
37    ) -> Self {
38        Self {
39            registry,
40            permission_registry,
41            tool_result_tx,
42            batch_result_tx,
43            batch_counter: AtomicI64::new(0),
44        }
45    }
46
47    /// Execute a batch of tools in parallel.
48    ///
49    /// This method implements batch permission handling:
50    /// 1. Collects required permissions from all tools via `required_permissions()`
51    /// 2. Requests batch approval from the permission registry (single UI prompt)
52    /// 3. If approved: executes all tools with `permissions_pre_approved: true`
53    /// 4. If denied: returns error results for all tools
54    ///
55    /// Tools that handle their own permissions (`handles_own_permissions() -> true`)
56    /// are always executed regardless of batch permission status.
57    ///
58    /// Results are emitted individually to tool_result_tx for UI feedback,
59    /// and the complete batch is sent to batch_result_tx when all tools finish.
60    ///
61    /// Returns the batch ID.
62    pub async fn execute_batch(
63        &self,
64        session_id: i64,
65        turn_id: Option<TurnId>,
66        requests: Vec<ToolRequest>,
67        cancel_token: CancellationToken,
68    ) -> i64 {
69        let batch_id = self.batch_counter.fetch_add(1, Ordering::SeqCst) + 1;
70        let expected_count = requests.len();
71
72        if expected_count == 0 {
73            // Empty batch - send empty result immediately
74            let batch_result = ToolBatchResult {
75                batch_id,
76                session_id,
77                turn_id,
78                results: Vec::new(),
79            };
80            if let Err(e) = self.batch_result_tx.send(batch_result).await {
81                tracing::debug!("Failed to send empty batch result: {}", e);
82            }
83            return batch_id;
84        }
85
86        tracing::debug!(
87            batch_id,
88            session_id,
89            tool_count = expected_count,
90            "Starting tool batch execution"
91        );
92
93        // Collect permission requirements from all tools
94        let mut all_permissions: Vec<PermissionRequest> = Vec::new();
95        let mut tools_needing_permissions: Vec<String> = Vec::new();
96
97        for request in &requests {
98            if let Some(tool) = self.registry.get(&request.tool_name).await {
99                // Skip tools that handle their own permissions
100                if tool.handles_own_permissions() {
101                    continue;
102                }
103
104                // Build context for permission check
105                let context = ToolContext::new(session_id, &request.tool_use_id, turn_id.clone());
106
107                // Collect required permissions
108                if let Some(perms) = tool.required_permissions(&context, &request.input) {
109                    if !perms.is_empty() {
110                        tools_needing_permissions.push(request.tool_use_id.clone());
111                        all_permissions.extend(perms);
112                    }
113                }
114            }
115        }
116
117        // Determine if batch permissions are pre-approved
118        let permissions_pre_approved = if !all_permissions.is_empty() {
119            tracing::debug!(
120                batch_id,
121                permission_count = all_permissions.len(),
122                tool_count = tools_needing_permissions.len(),
123                "Requesting permissions"
124            );
125
126            // Use single permission request for single permission, batch for multiple
127            if all_permissions.len() == 1 {
128                // Single permission - use simpler PermissionPanel UI
129                let permission = all_permissions.into_iter().next().unwrap();
130                match self
131                    .permission_registry
132                    .request_permission(session_id, permission, turn_id.clone())
133                    .await
134                {
135                    Ok(rx) => {
136                        match rx.await {
137                            Ok(response) => {
138                                if response.granted {
139                                    tracing::info!(batch_id, "Single permission approved");
140                                    true
141                                } else {
142                                    tracing::info!(batch_id, "Single permission denied");
143
144                                    // Create error results for all tools
145                                    let error_results: Vec<ToolResult> = requests
146                                        .iter()
147                                        .map(|req| {
148                                            ToolResult::error(
149                                                session_id,
150                                                req.tool_name.clone(),
151                                                req.tool_use_id.clone(),
152                                                req.input.clone(),
153                                                "Permission denied by user".to_string(),
154                                                turn_id.clone(),
155                                            )
156                                        })
157                                        .collect();
158
159                                    // Send individual error results
160                                    for result in &error_results {
161                                        if let Err(e) =
162                                            self.tool_result_tx.send(result.clone()).await
163                                        {
164                                            tracing::debug!("Failed to send tool result: {}", e);
165                                        }
166                                    }
167
168                                    // Send batch result
169                                    let batch_result = ToolBatchResult {
170                                        batch_id,
171                                        session_id,
172                                        turn_id,
173                                        results: error_results,
174                                    };
175                                    if let Err(e) = self.batch_result_tx.send(batch_result).await {
176                                        tracing::debug!("Failed to send batch result: {}", e);
177                                    }
178
179                                    return batch_id;
180                                }
181                            }
182                            Err(_) => {
183                                // Channel closed - permission request was cancelled
184                                tracing::info!(batch_id, "Single permission request cancelled");
185
186                                let error_results: Vec<ToolResult> = requests
187                                    .iter()
188                                    .map(|req| {
189                                        ToolResult::error(
190                                            session_id,
191                                            req.tool_name.clone(),
192                                            req.tool_use_id.clone(),
193                                            req.input.clone(),
194                                            "Permission request cancelled".to_string(),
195                                            turn_id.clone(),
196                                        )
197                                    })
198                                    .collect();
199
200                                for result in &error_results {
201                                    if let Err(e) = self.tool_result_tx.send(result.clone()).await {
202                                        tracing::debug!("Failed to send tool result: {}", e);
203                                    }
204                                }
205
206                                let batch_result = ToolBatchResult {
207                                    batch_id,
208                                    session_id,
209                                    turn_id,
210                                    results: error_results,
211                                };
212                                if let Err(e) = self.batch_result_tx.send(batch_result).await {
213                                    tracing::debug!("Failed to send batch result: {}", e);
214                                }
215
216                                return batch_id;
217                            }
218                        }
219                    }
220                    Err(e) => {
221                        tracing::warn!(batch_id, error = %e, "Failed to request single permission");
222
223                        let error_results: Vec<ToolResult> = requests
224                            .iter()
225                            .map(|req| {
226                                ToolResult::error(
227                                    session_id,
228                                    req.tool_name.clone(),
229                                    req.tool_use_id.clone(),
230                                    req.input.clone(),
231                                    format!("Permission request failed: {}", e),
232                                    turn_id.clone(),
233                                )
234                            })
235                            .collect();
236
237                        for result in &error_results {
238                            if let Err(e) = self.tool_result_tx.send(result.clone()).await {
239                                tracing::debug!("Failed to send tool result: {}", e);
240                            }
241                        }
242
243                        let batch_result = ToolBatchResult {
244                            batch_id,
245                            session_id,
246                            turn_id,
247                            results: error_results,
248                        };
249                        if let Err(e) = self.batch_result_tx.send(batch_result).await {
250                            tracing::debug!("Failed to send batch result: {}", e);
251                        }
252
253                        return batch_id;
254                    }
255                }
256            } else {
257                // Multiple permissions - use BatchPermissionPanel UI
258                match self
259                    .permission_registry
260                    .register_batch(session_id, all_permissions, turn_id.clone())
261                    .await
262                {
263                Ok(rx) => {
264                    // Wait for permission response
265                    match rx.await {
266                        Ok(response) => {
267                            // Batch permissions: all-or-none model
268                            // If any requests were denied, fail all tools
269                            if !response.denied_requests.is_empty() {
270                                tracing::info!(
271                                    batch_id,
272                                    denied_count = response.denied_requests.len(),
273                                    "Batch permissions denied"
274                                );
275
276                                // Create error results for all tools
277                                let error_results: Vec<ToolResult> = requests
278                                    .iter()
279                                    .map(|req| {
280                                        ToolResult::error(
281                                            session_id,
282                                            req.tool_name.clone(),
283                                            req.tool_use_id.clone(),
284                                            req.input.clone(),
285                                            "Permission denied by user".to_string(),
286                                            turn_id.clone(),
287                                        )
288                                    })
289                                    .collect();
290
291                                // Send individual error results
292                                for result in &error_results {
293                                    if let Err(e) =
294                                        self.tool_result_tx.send(result.clone()).await
295                                    {
296                                        tracing::debug!("Failed to send tool result: {}", e);
297                                    }
298                                }
299
300                                // Send batch result
301                                let batch_result = ToolBatchResult {
302                                    batch_id,
303                                    session_id,
304                                    turn_id,
305                                    results: error_results,
306                                };
307                                if let Err(e) = self.batch_result_tx.send(batch_result).await {
308                                    tracing::debug!("Failed to send batch result: {}", e);
309                                }
310
311                                return batch_id;
312                            }
313
314                            tracing::info!(
315                                batch_id,
316                                grant_count = response.approved_grants.len(),
317                                "Batch permissions approved"
318                            );
319                            true
320                        }
321                        Err(_) => {
322                            // Channel closed - permission request was cancelled
323                            tracing::info!(batch_id, "Batch permission request cancelled");
324
325                            // Create error results for all tools
326                            let error_results: Vec<ToolResult> = requests
327                                .iter()
328                                .map(|req| {
329                                    ToolResult::error(
330                                        session_id,
331                                        req.tool_name.clone(),
332                                        req.tool_use_id.clone(),
333                                        req.input.clone(),
334                                        "Permission request cancelled".to_string(),
335                                        turn_id.clone(),
336                                    )
337                                })
338                                .collect();
339
340                            // Send individual error results
341                            for result in &error_results {
342                                if let Err(e) = self.tool_result_tx.send(result.clone()).await {
343                                    tracing::debug!("Failed to send tool result: {}", e);
344                                }
345                            }
346
347                            // Send batch result
348                            let batch_result = ToolBatchResult {
349                                batch_id,
350                                session_id,
351                                turn_id,
352                                results: error_results,
353                            };
354                            if let Err(e) = self.batch_result_tx.send(batch_result).await {
355                                tracing::debug!("Failed to send batch result: {}", e);
356                            }
357
358                            return batch_id;
359                        }
360                    }
361                }
362                Err(e) => {
363                    // Failed to register batch - treat as permission denied
364                    tracing::warn!(
365                        batch_id,
366                        error = %e,
367                        "Failed to register batch permission request"
368                    );
369                    false
370                }
371                }
372            }
373        } else {
374            // No permissions needed
375            true
376        };
377
378        // Create batch state
379        let batch = Arc::new(ToolExecutorBatch {
380            batch_id,
381            session_id,
382            turn_id: turn_id.clone(),
383            tool_result_tx: self.tool_result_tx.clone(),
384            batch_result_tx: self.batch_result_tx.clone(),
385            requests: requests.clone(),
386            results: Mutex::new(HashMap::new()),
387            expected_count,
388            permissions_pre_approved,
389            task_handles: Mutex::new(Vec::with_capacity(expected_count)),
390        });
391
392        // Start all tools concurrently, storing JoinHandles for tracking
393        for request in requests {
394            let batch_clone = batch.clone();
395            let registry = self.registry.clone();
396            let cancel = cancel_token.clone();
397            let turn_id = turn_id.clone();
398
399            let handle = tokio::spawn(async move {
400                batch_clone
401                    .run_tool(registry, request, turn_id, cancel)
402                    .await;
403            });
404
405            // Store the handle for tracking
406            batch.task_handles.lock().await.push(handle);
407        }
408
409        batch_id
410    }
411
412    /// Execute a single tool (convenience method that creates a batch of 1).
413    pub async fn execute(
414        &self,
415        session_id: i64,
416        turn_id: Option<TurnId>,
417        request: ToolRequest,
418        cancel_token: CancellationToken,
419    ) -> i64 {
420        self.execute_batch(session_id, turn_id, vec![request], cancel_token)
421            .await
422    }
423}
424
425/// Internal batch state for tracking parallel tool executions.
426struct ToolExecutorBatch {
427    batch_id: i64,
428    session_id: i64,
429    turn_id: Option<TurnId>,
430    tool_result_tx: mpsc::Sender<ToolResult>,
431    batch_result_tx: mpsc::Sender<ToolBatchResult>,
432    requests: Vec<ToolRequest>,
433    results: Mutex<HashMap<String, ToolResult>>,
434    expected_count: usize,
435    /// Whether permissions were pre-approved by the batch executor.
436    permissions_pre_approved: bool,
437    /// JoinHandles for spawned tool tasks, enabling graceful shutdown and panic detection.
438    task_handles: Mutex<Vec<JoinHandle<()>>>,
439}
440
441impl ToolExecutorBatch {
442    /// Run a single tool and add result to the batch.
443    async fn run_tool(
444        &self,
445        registry: Arc<ToolRegistry>,
446        request: ToolRequest,
447        turn_id: Option<TurnId>,
448        cancel_token: CancellationToken,
449    ) {
450        let tool_use_id = request.tool_use_id.clone();
451        let tool_name = request.tool_name.clone();
452        let input = request.input.clone();
453
454        tracing::debug!(
455            batch_id = self.batch_id,
456            session_id = self.session_id,
457            tool_name = %tool_name,
458            tool_use_id = %tool_use_id,
459            "Starting tool execution"
460        );
461
462        // Look up tool in registry
463        let tool = registry.get(&tool_name).await;
464
465        let result = match tool {
466            None => {
467                // Tool not found
468                tracing::warn!(
469                    batch_id = self.batch_id,
470                    tool_name = %tool_name,
471                    "Tool not found in registry"
472                );
473                ToolResult::error(
474                    self.session_id,
475                    tool_name,
476                    tool_use_id,
477                    input,
478                    format!("Tool not found: {}", request.tool_name),
479                    turn_id,
480                )
481            }
482            Some(tool) => {
483                // Get display name from tool's display config
484                let display_name = Some(tool.display_config().display_name);
485
486                // Build tool context with pre-approved flag from batch
487                let context = ToolContext {
488                    session_id: self.session_id,
489                    tool_use_id: tool_use_id.clone(),
490                    turn_id: turn_id.clone(),
491                    permissions_pre_approved: self.permissions_pre_approved,
492                };
493
494                // Execute tool with cancellation support
495                tokio::select! {
496                    exec_result = tool.execute(context, input.clone()) => {
497                        match exec_result {
498                            Ok(content) => {
499                                tracing::info!(
500                                    batch_id = self.batch_id,
501                                    tool_name = %tool_name,
502                                    result_bytes = content.len(),
503                                    "Tool execution succeeded"
504                                );
505                                // Compute compact summary for compaction
506                                let compact_summary = Some(tool.compact_summary(&input, &content));
507                                ToolResult::success(
508                                    self.session_id,
509                                    tool_name,
510                                    display_name,
511                                    tool_use_id,
512                                    input,
513                                    content,
514                                    turn_id,
515                                    compact_summary,
516                                )
517                            }
518                            Err(error) => {
519                                tracing::warn!(
520                                    batch_id = self.batch_id,
521                                    tool_name = %tool_name,
522                                    error = %error,
523                                    "Tool execution failed"
524                                );
525                                ToolResult::error(
526                                    self.session_id,
527                                    tool_name,
528                                    tool_use_id,
529                                    input,
530                                    error,
531                                    turn_id,
532                                )
533                            }
534                        }
535                    }
536                    _ = cancel_token.cancelled() => {
537                        tracing::warn!(
538                            batch_id = self.batch_id,
539                            tool_name = %tool_name,
540                            "Tool execution cancelled"
541                        );
542                        ToolResult::timeout(
543                            self.session_id,
544                            tool_name,
545                            tool_use_id,
546                            input,
547                            turn_id,
548                        )
549                    }
550                }
551            }
552        };
553
554        self.add_result(result).await;
555    }
556
557    /// Add a result to the batch and check for completion.
558    async fn add_result(&self, result: ToolResult) {
559        // Send individual result for UI feedback
560        if let Err(e) = self.tool_result_tx.send(result.clone()).await {
561            tracing::debug!("Failed to send tool result: {}", e);
562        }
563
564        let mut results = self.results.lock().await;
565        results.insert(result.tool_use_id.clone(), result);
566
567        tracing::debug!(
568            batch_id = self.batch_id,
569            completed = results.len(),
570            expected = self.expected_count,
571            "Tool completed in batch"
572        );
573
574        // Check if all tools have completed
575        if results.len() == self.expected_count {
576            self.send_batch_result(&results).await;
577        }
578    }
579
580    /// Send the complete batch result.
581    async fn send_batch_result(&self, results: &HashMap<String, ToolResult>) {
582        // Build results in original request order
583        let ordered_results: Vec<ToolResult> = self
584            .requests
585            .iter()
586            .filter_map(|req| results.get(&req.tool_use_id).cloned())
587            .collect();
588
589        let batch_result = ToolBatchResult {
590            batch_id: self.batch_id,
591            session_id: self.session_id,
592            turn_id: self.turn_id.clone(),
593            results: ordered_results,
594        };
595
596        tracing::debug!(
597            batch_id = self.batch_id,
598            session_id = self.session_id,
599            result_count = batch_result.results.len(),
600            "Sending batch result"
601        );
602
603        if let Err(e) = self.batch_result_tx.send(batch_result).await {
604            tracing::debug!("Failed to send batch result: {}", e);
605        }
606    }
607
608    /// Await completion of all spawned tasks with an optional timeout.
609    ///
610    /// This method drains the task handles and awaits each one. If a timeout is provided,
611    /// tasks that don't complete within the timeout will be logged but not forcefully aborted
612    /// (Tokio tasks can only be aborted by dropping the handle, which we do after timeout).
613    ///
614    /// # Arguments
615    /// * `timeout` - Optional timeout duration. If None, waits indefinitely.
616    ///
617    /// # Returns
618    /// A tuple of (completed_count, panicked_count, timed_out_count)
619    #[allow(dead_code)] // Available for future use in graceful shutdown
620    async fn await_completion(&self, timeout: Option<Duration>) -> (usize, usize, usize) {
621        let handles: Vec<JoinHandle<()>> = {
622            let mut guard = self.task_handles.lock().await;
623            std::mem::take(&mut *guard)
624        };
625
626        let total = handles.len();
627        let mut completed = 0;
628        let mut panicked = 0;
629        let mut timed_out = 0;
630
631        for handle in handles {
632            let result = if let Some(timeout_duration) = timeout {
633                match tokio::time::timeout(timeout_duration, handle).await {
634                    Ok(join_result) => Some(join_result),
635                    Err(_) => {
636                        timed_out += 1;
637                        tracing::warn!(
638                            batch_id = self.batch_id,
639                            "Task did not complete within timeout"
640                        );
641                        None
642                    }
643                }
644            } else {
645                Some(handle.await)
646            };
647
648            if let Some(join_result) = result {
649                match join_result {
650                    Ok(()) => completed += 1,
651                    Err(e) => {
652                        panicked += 1;
653                        if e.is_panic() {
654                            tracing::error!(
655                                batch_id = self.batch_id,
656                                error = %e,
657                                "Task panicked"
658                            );
659                        } else {
660                            tracing::warn!(
661                                batch_id = self.batch_id,
662                                error = %e,
663                                "Task was cancelled"
664                            );
665                        }
666                    }
667                }
668            }
669        }
670
671        tracing::debug!(
672            batch_id = self.batch_id,
673            total,
674            completed,
675            panicked,
676            timed_out,
677            "Batch task completion summary"
678        );
679
680        (completed, panicked, timed_out)
681    }
682
683    /// Returns the number of tasks that are still running.
684    #[allow(dead_code)] // Available for future use in monitoring
685    async fn active_task_count(&self) -> usize {
686        let handles = self.task_handles.lock().await;
687        handles.iter().filter(|h| !h.is_finished()).count()
688    }
689}
690
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{Executable, ToolResultStatus, ToolType};
    use crate::controller::types::ControllerEvent;
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a single channel message.
    ///
    /// Without a bound, a regression that stops the executor from sending a
    /// result would hang the whole test suite instead of failing it. This is
    /// kept well below `SlowTool`'s 10 s sleep so that a broken cancellation
    /// path fails fast with a clear message.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// Receives the next message from `rx`.
    ///
    /// Panics with a descriptive message (mentioning `what`) if no message
    /// arrives within [`RECV_TIMEOUT`] or if the channel closes first.
    async fn recv_within<T>(rx: &mut mpsc::Receiver<T>, what: &str) -> T {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .unwrap_or_else(|_| panic!("timed out waiting for {}", what))
            .unwrap_or_else(|| panic!("channel closed before {} was received", what))
    }

    /// Builds a `PermissionRegistry` for tests.
    ///
    /// The event receiver is dropped when this function returns, so nothing
    /// listens for permission events the registry may emit. The test tools
    /// below do not override `required_permissions()`, and these tests
    /// presumably rely on that to avoid any interactive approval.
    /// NOTE(review): nothing here configures auto-approval; confirm how
    /// `PermissionRegistry` handles batches with no required permissions.
    fn create_test_permission_registry() -> Arc<PermissionRegistry> {
        let (event_tx, _event_rx) = mpsc::channel::<ControllerEvent>(10);
        Arc::new(PermissionRegistry::new(event_tx))
    }

    /// Test tool that returns `"Echo: <message>"`, or `"Echo: no message"`
    /// when the `message` input is missing or not a string.
    struct EchoTool;

    impl Executable for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes input back"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object","properties":{"message":{"type":"string"}}}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message")
                .to_string();
            Box::pin(async move { Ok(format!("Echo: {}", message)) })
        }
    }

    /// Test tool that sleeps for 10 seconds before succeeding; used to
    /// exercise cancellation of an in-flight tool.
    struct SlowTool;

    impl Executable for SlowTool {
        fn name(&self) -> &str {
            "slow"
        }

        fn description(&self) -> &str {
            "A slow tool for testing timeouts"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok("done".to_string())
            })
        }
    }

    #[tokio::test]
    async fn test_execute_single_tool() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let permission_registry = create_test_permission_registry();
        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, permission_registry, tool_tx, batch_tx);

        let mut input = HashMap::new();
        input.insert(
            "message".to_string(),
            serde_json::Value::String("hello".to_string()),
        );

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "echo".to_string(),
            input,
        };

        let cancel = CancellationToken::new();
        executor.execute(1, None, request, cancel).await;

        // Individual result (UI feedback channel).
        let result = recv_within(&mut tool_rx, "the tool result").await;
        assert_eq!(result.status, ToolResultStatus::Success);
        assert!(result.content.contains("Echo: hello"));

        // Aggregated batch result (LLM channel).
        let batch = recv_within(&mut batch_rx, "the batch result").await;
        assert_eq!(batch.results.len(), 1);
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let permission_registry = create_test_permission_registry();
        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, permission_registry, tool_tx, batch_tx);

        let requests: Vec<ToolRequest> = (0..3)
            .map(|i| {
                let mut input = HashMap::new();
                input.insert(
                    "message".to_string(),
                    serde_json::Value::String(format!("msg_{}", i)),
                );
                ToolRequest {
                    tool_use_id: format!("tool_{}", i),
                    tool_name: "echo".to_string(),
                    input,
                }
            })
            .collect();

        let cancel = CancellationToken::new();
        executor.execute_batch(1, None, requests, cancel).await;

        // Tools run in parallel, so individual results may arrive in any order;
        // only check properties shared by all of them.
        for _ in 0..3 {
            let result = recv_within(&mut tool_rx, "a batch member's tool result").await;
            assert_eq!(result.status, ToolResultStatus::Success);
            assert!(result.content.contains("Echo: msg_"));
        }

        let batch = recv_within(&mut batch_rx, "the batch result").await;
        assert_eq!(batch.results.len(), 3);
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let registry = Arc::new(ToolRegistry::new());

        let permission_registry = create_test_permission_registry();
        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, _batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, permission_registry, tool_tx, batch_tx);

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "nonexistent".to_string(),
            input: HashMap::new(),
        };

        let cancel = CancellationToken::new();
        executor.execute(1, None, request, cancel).await;

        let result = recv_within(&mut tool_rx, "the not-found tool result").await;
        assert_eq!(result.status, ToolResultStatus::Error);
        let error = result
            .error
            .expect("a not-found result should carry an error message");
        assert!(
            error.contains("not found"),
            "unexpected error message: {}",
            error
        );
    }

    #[tokio::test]
    async fn test_tool_cancellation() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(SlowTool)).await.unwrap();

        let permission_registry = create_test_permission_registry();
        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, _batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, permission_registry, tool_tx, batch_tx);

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "slow".to_string(),
            input: HashMap::new(),
        };

        let cancel = CancellationToken::new();
        let cancel_clone = cancel.clone();

        // Start execution.
        executor.execute(1, None, request, cancel).await;

        // Cancel after a short delay, while SlowTool is still sleeping.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });

        // SlowTool sleeps for 10 s, longer than RECV_TIMEOUT, so a result can
        // only arrive in time if cancellation works. The executor is expected
        // to report cancellation with the `Timeout` status.
        let result = recv_within(&mut tool_rx, "the cancelled tool result").await;
        assert_eq!(result.status, ToolResultStatus::Timeout);
    }
}