1use std::collections::HashMap;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::{mpsc, Mutex};
7use tokio::task::JoinHandle;
8use tokio_util::sync::CancellationToken;
9
10use super::registry::ToolRegistry;
11use super::types::{ToolBatchResult, ToolContext, ToolRequest, ToolResult};
12use crate::controller::types::TurnId;
13use crate::permissions::{PermissionRegistry, PermissionRequest};
14
15pub struct ToolExecutor {
17 registry: Arc<ToolRegistry>,
18 permission_registry: Arc<PermissionRegistry>,
19 tool_result_tx: mpsc::Sender<ToolResult>,
20 batch_result_tx: mpsc::Sender<ToolBatchResult>,
21 batch_counter: AtomicI64,
22}
23
24impl ToolExecutor {
25 pub fn new(
33 registry: Arc<ToolRegistry>,
34 permission_registry: Arc<PermissionRegistry>,
35 tool_result_tx: mpsc::Sender<ToolResult>,
36 batch_result_tx: mpsc::Sender<ToolBatchResult>,
37 ) -> Self {
38 Self {
39 registry,
40 permission_registry,
41 tool_result_tx,
42 batch_result_tx,
43 batch_counter: AtomicI64::new(0),
44 }
45 }
46
47 pub async fn execute_batch(
63 &self,
64 session_id: i64,
65 turn_id: Option<TurnId>,
66 requests: Vec<ToolRequest>,
67 cancel_token: CancellationToken,
68 ) -> i64 {
69 let batch_id = self.batch_counter.fetch_add(1, Ordering::SeqCst) + 1;
70 let expected_count = requests.len();
71
72 if expected_count == 0 {
73 let batch_result = ToolBatchResult {
75 batch_id,
76 session_id,
77 turn_id,
78 results: Vec::new(),
79 };
80 if let Err(e) = self.batch_result_tx.send(batch_result).await {
81 tracing::debug!("Failed to send empty batch result: {}", e);
82 }
83 return batch_id;
84 }
85
86 tracing::debug!(
87 batch_id,
88 session_id,
89 tool_count = expected_count,
90 "Starting tool batch execution"
91 );
92
93 let mut all_permissions: Vec<PermissionRequest> = Vec::new();
95 let mut tools_needing_permissions: Vec<String> = Vec::new();
96
97 for request in &requests {
98 if let Some(tool) = self.registry.get(&request.tool_name).await {
99 if tool.handles_own_permissions() {
101 continue;
102 }
103
104 let context = ToolContext::new(session_id, &request.tool_use_id, turn_id.clone());
106
107 if let Some(perms) = tool.required_permissions(&context, &request.input) {
109 if !perms.is_empty() {
110 tools_needing_permissions.push(request.tool_use_id.clone());
111 all_permissions.extend(perms);
112 }
113 }
114 }
115 }
116
117 let permissions_pre_approved = if !all_permissions.is_empty() {
119 tracing::debug!(
120 batch_id,
121 permission_count = all_permissions.len(),
122 tool_count = tools_needing_permissions.len(),
123 "Requesting permissions"
124 );
125
126 if all_permissions.len() == 1 {
128 let permission = all_permissions.into_iter().next().unwrap();
130 match self
131 .permission_registry
132 .request_permission(session_id, permission, turn_id.clone())
133 .await
134 {
135 Ok(rx) => {
136 match rx.await {
137 Ok(response) => {
138 if response.granted {
139 tracing::info!(batch_id, "Single permission approved");
140 true
141 } else {
142 tracing::info!(batch_id, "Single permission denied");
143
144 let error_results: Vec<ToolResult> = requests
146 .iter()
147 .map(|req| {
148 ToolResult::error(
149 session_id,
150 req.tool_name.clone(),
151 req.tool_use_id.clone(),
152 req.input.clone(),
153 "Permission denied by user".to_string(),
154 turn_id.clone(),
155 )
156 })
157 .collect();
158
159 for result in &error_results {
161 if let Err(e) =
162 self.tool_result_tx.send(result.clone()).await
163 {
164 tracing::debug!("Failed to send tool result: {}", e);
165 }
166 }
167
168 let batch_result = ToolBatchResult {
170 batch_id,
171 session_id,
172 turn_id,
173 results: error_results,
174 };
175 if let Err(e) = self.batch_result_tx.send(batch_result).await {
176 tracing::debug!("Failed to send batch result: {}", e);
177 }
178
179 return batch_id;
180 }
181 }
182 Err(_) => {
183 tracing::info!(batch_id, "Single permission request cancelled");
185
186 let error_results: Vec<ToolResult> = requests
187 .iter()
188 .map(|req| {
189 ToolResult::error(
190 session_id,
191 req.tool_name.clone(),
192 req.tool_use_id.clone(),
193 req.input.clone(),
194 "Permission request cancelled".to_string(),
195 turn_id.clone(),
196 )
197 })
198 .collect();
199
200 for result in &error_results {
201 if let Err(e) = self.tool_result_tx.send(result.clone()).await {
202 tracing::debug!("Failed to send tool result: {}", e);
203 }
204 }
205
206 let batch_result = ToolBatchResult {
207 batch_id,
208 session_id,
209 turn_id,
210 results: error_results,
211 };
212 if let Err(e) = self.batch_result_tx.send(batch_result).await {
213 tracing::debug!("Failed to send batch result: {}", e);
214 }
215
216 return batch_id;
217 }
218 }
219 }
220 Err(e) => {
221 tracing::warn!(batch_id, error = %e, "Failed to request single permission");
222
223 let error_results: Vec<ToolResult> = requests
224 .iter()
225 .map(|req| {
226 ToolResult::error(
227 session_id,
228 req.tool_name.clone(),
229 req.tool_use_id.clone(),
230 req.input.clone(),
231 format!("Permission request failed: {}", e),
232 turn_id.clone(),
233 )
234 })
235 .collect();
236
237 for result in &error_results {
238 if let Err(e) = self.tool_result_tx.send(result.clone()).await {
239 tracing::debug!("Failed to send tool result: {}", e);
240 }
241 }
242
243 let batch_result = ToolBatchResult {
244 batch_id,
245 session_id,
246 turn_id,
247 results: error_results,
248 };
249 if let Err(e) = self.batch_result_tx.send(batch_result).await {
250 tracing::debug!("Failed to send batch result: {}", e);
251 }
252
253 return batch_id;
254 }
255 }
256 } else {
257 match self
259 .permission_registry
260 .register_batch(session_id, all_permissions, turn_id.clone())
261 .await
262 {
263 Ok(rx) => {
264 match rx.await {
266 Ok(response) => {
267 if !response.denied_requests.is_empty() {
270 tracing::info!(
271 batch_id,
272 denied_count = response.denied_requests.len(),
273 "Batch permissions denied"
274 );
275
276 let error_results: Vec<ToolResult> = requests
278 .iter()
279 .map(|req| {
280 ToolResult::error(
281 session_id,
282 req.tool_name.clone(),
283 req.tool_use_id.clone(),
284 req.input.clone(),
285 "Permission denied by user".to_string(),
286 turn_id.clone(),
287 )
288 })
289 .collect();
290
291 for result in &error_results {
293 if let Err(e) =
294 self.tool_result_tx.send(result.clone()).await
295 {
296 tracing::debug!("Failed to send tool result: {}", e);
297 }
298 }
299
300 let batch_result = ToolBatchResult {
302 batch_id,
303 session_id,
304 turn_id,
305 results: error_results,
306 };
307 if let Err(e) = self.batch_result_tx.send(batch_result).await {
308 tracing::debug!("Failed to send batch result: {}", e);
309 }
310
311 return batch_id;
312 }
313
314 tracing::info!(
315 batch_id,
316 grant_count = response.approved_grants.len(),
317 "Batch permissions approved"
318 );
319 true
320 }
321 Err(_) => {
322 tracing::info!(batch_id, "Batch permission request cancelled");
324
325 let error_results: Vec<ToolResult> = requests
327 .iter()
328 .map(|req| {
329 ToolResult::error(
330 session_id,
331 req.tool_name.clone(),
332 req.tool_use_id.clone(),
333 req.input.clone(),
334 "Permission request cancelled".to_string(),
335 turn_id.clone(),
336 )
337 })
338 .collect();
339
340 for result in &error_results {
342 if let Err(e) = self.tool_result_tx.send(result.clone()).await {
343 tracing::debug!("Failed to send tool result: {}", e);
344 }
345 }
346
347 let batch_result = ToolBatchResult {
349 batch_id,
350 session_id,
351 turn_id,
352 results: error_results,
353 };
354 if let Err(e) = self.batch_result_tx.send(batch_result).await {
355 tracing::debug!("Failed to send batch result: {}", e);
356 }
357
358 return batch_id;
359 }
360 }
361 }
362 Err(e) => {
363 tracing::warn!(
365 batch_id,
366 error = %e,
367 "Failed to register batch permission request"
368 );
369 false
370 }
371 }
372 }
373 } else {
374 true
376 };
377
378 let batch = Arc::new(ToolExecutorBatch {
380 batch_id,
381 session_id,
382 turn_id: turn_id.clone(),
383 tool_result_tx: self.tool_result_tx.clone(),
384 batch_result_tx: self.batch_result_tx.clone(),
385 requests: requests.clone(),
386 results: Mutex::new(HashMap::new()),
387 expected_count,
388 permissions_pre_approved,
389 task_handles: Mutex::new(Vec::with_capacity(expected_count)),
390 });
391
392 for request in requests {
394 let batch_clone = batch.clone();
395 let registry = self.registry.clone();
396 let cancel = cancel_token.clone();
397 let turn_id = turn_id.clone();
398
399 let handle = tokio::spawn(async move {
400 batch_clone
401 .run_tool(registry, request, turn_id, cancel)
402 .await;
403 });
404
405 batch.task_handles.lock().await.push(handle);
407 }
408
409 batch_id
410 }
411
412 pub async fn execute(
414 &self,
415 session_id: i64,
416 turn_id: Option<TurnId>,
417 request: ToolRequest,
418 cancel_token: CancellationToken,
419 ) -> i64 {
420 self.execute_batch(session_id, turn_id, vec![request], cancel_token)
421 .await
422 }
423}
424
425struct ToolExecutorBatch {
427 batch_id: i64,
428 session_id: i64,
429 turn_id: Option<TurnId>,
430 tool_result_tx: mpsc::Sender<ToolResult>,
431 batch_result_tx: mpsc::Sender<ToolBatchResult>,
432 requests: Vec<ToolRequest>,
433 results: Mutex<HashMap<String, ToolResult>>,
434 expected_count: usize,
435 permissions_pre_approved: bool,
437 task_handles: Mutex<Vec<JoinHandle<()>>>,
439}
440
441impl ToolExecutorBatch {
442 async fn run_tool(
444 &self,
445 registry: Arc<ToolRegistry>,
446 request: ToolRequest,
447 turn_id: Option<TurnId>,
448 cancel_token: CancellationToken,
449 ) {
450 let tool_use_id = request.tool_use_id.clone();
451 let tool_name = request.tool_name.clone();
452 let input = request.input.clone();
453
454 tracing::debug!(
455 batch_id = self.batch_id,
456 session_id = self.session_id,
457 tool_name = %tool_name,
458 tool_use_id = %tool_use_id,
459 "Starting tool execution"
460 );
461
462 let tool = registry.get(&tool_name).await;
464
465 let result = match tool {
466 None => {
467 tracing::warn!(
469 batch_id = self.batch_id,
470 tool_name = %tool_name,
471 "Tool not found in registry"
472 );
473 ToolResult::error(
474 self.session_id,
475 tool_name,
476 tool_use_id,
477 input,
478 format!("Tool not found: {}", request.tool_name),
479 turn_id,
480 )
481 }
482 Some(tool) => {
483 let display_name = Some(tool.display_config().display_name);
485
486 let context = ToolContext {
488 session_id: self.session_id,
489 tool_use_id: tool_use_id.clone(),
490 turn_id: turn_id.clone(),
491 permissions_pre_approved: self.permissions_pre_approved,
492 };
493
494 tokio::select! {
496 exec_result = tool.execute(context, input.clone()) => {
497 match exec_result {
498 Ok(content) => {
499 tracing::info!(
500 batch_id = self.batch_id,
501 tool_name = %tool_name,
502 result_bytes = content.len(),
503 "Tool execution succeeded"
504 );
505 let compact_summary = Some(tool.compact_summary(&input, &content));
507 ToolResult::success(
508 self.session_id,
509 tool_name,
510 display_name,
511 tool_use_id,
512 input,
513 content,
514 turn_id,
515 compact_summary,
516 )
517 }
518 Err(error) => {
519 tracing::warn!(
520 batch_id = self.batch_id,
521 tool_name = %tool_name,
522 error = %error,
523 "Tool execution failed"
524 );
525 ToolResult::error(
526 self.session_id,
527 tool_name,
528 tool_use_id,
529 input,
530 error,
531 turn_id,
532 )
533 }
534 }
535 }
536 _ = cancel_token.cancelled() => {
537 tracing::warn!(
538 batch_id = self.batch_id,
539 tool_name = %tool_name,
540 "Tool execution cancelled"
541 );
542 ToolResult::timeout(
543 self.session_id,
544 tool_name,
545 tool_use_id,
546 input,
547 turn_id,
548 )
549 }
550 }
551 }
552 };
553
554 self.add_result(result).await;
555 }
556
557 async fn add_result(&self, result: ToolResult) {
559 if let Err(e) = self.tool_result_tx.send(result.clone()).await {
561 tracing::debug!("Failed to send tool result: {}", e);
562 }
563
564 let mut results = self.results.lock().await;
565 results.insert(result.tool_use_id.clone(), result);
566
567 tracing::debug!(
568 batch_id = self.batch_id,
569 completed = results.len(),
570 expected = self.expected_count,
571 "Tool completed in batch"
572 );
573
574 if results.len() == self.expected_count {
576 self.send_batch_result(&results).await;
577 }
578 }
579
580 async fn send_batch_result(&self, results: &HashMap<String, ToolResult>) {
582 let ordered_results: Vec<ToolResult> = self
584 .requests
585 .iter()
586 .filter_map(|req| results.get(&req.tool_use_id).cloned())
587 .collect();
588
589 let batch_result = ToolBatchResult {
590 batch_id: self.batch_id,
591 session_id: self.session_id,
592 turn_id: self.turn_id.clone(),
593 results: ordered_results,
594 };
595
596 tracing::debug!(
597 batch_id = self.batch_id,
598 session_id = self.session_id,
599 result_count = batch_result.results.len(),
600 "Sending batch result"
601 );
602
603 if let Err(e) = self.batch_result_tx.send(batch_result).await {
604 tracing::debug!("Failed to send batch result: {}", e);
605 }
606 }
607
608 #[allow(dead_code)] async fn await_completion(&self, timeout: Option<Duration>) -> (usize, usize, usize) {
621 let handles: Vec<JoinHandle<()>> = {
622 let mut guard = self.task_handles.lock().await;
623 std::mem::take(&mut *guard)
624 };
625
626 let total = handles.len();
627 let mut completed = 0;
628 let mut panicked = 0;
629 let mut timed_out = 0;
630
631 for handle in handles {
632 let result = if let Some(timeout_duration) = timeout {
633 match tokio::time::timeout(timeout_duration, handle).await {
634 Ok(join_result) => Some(join_result),
635 Err(_) => {
636 timed_out += 1;
637 tracing::warn!(
638 batch_id = self.batch_id,
639 "Task did not complete within timeout"
640 );
641 None
642 }
643 }
644 } else {
645 Some(handle.await)
646 };
647
648 if let Some(join_result) = result {
649 match join_result {
650 Ok(()) => completed += 1,
651 Err(e) => {
652 panicked += 1;
653 if e.is_panic() {
654 tracing::error!(
655 batch_id = self.batch_id,
656 error = %e,
657 "Task panicked"
658 );
659 } else {
660 tracing::warn!(
661 batch_id = self.batch_id,
662 error = %e,
663 "Task was cancelled"
664 );
665 }
666 }
667 }
668 }
669 }
670
671 tracing::debug!(
672 batch_id = self.batch_id,
673 total,
674 completed,
675 panicked,
676 timed_out,
677 "Batch task completion summary"
678 );
679
680 (completed, panicked, timed_out)
681 }
682
683 #[allow(dead_code)] async fn active_task_count(&self) -> usize {
686 let handles = self.task_handles.lock().await;
687 handles.iter().filter(|h| !h.is_finished()).count()
688 }
689}
690
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{Executable, ToolResultStatus, ToolType};
    use crate::controller::types::ControllerEvent;
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Builds a permission registry backed by a throwaway event channel.
    fn create_test_permission_registry() -> Arc<PermissionRegistry> {
        let (event_tx, _event_rx) = mpsc::channel::<ControllerEvent>(10);
        Arc::new(PermissionRegistry::new(event_tx))
    }

    /// Wires an executor to fresh result channels (capacity 10) and returns it
    /// together with the receiving ends.
    fn build_executor(
        registry: Arc<ToolRegistry>,
    ) -> (
        ToolExecutor,
        mpsc::Receiver<ToolResult>,
        mpsc::Receiver<ToolBatchResult>,
    ) {
        let permission_registry = create_test_permission_registry();
        let (tool_tx, tool_rx) = mpsc::channel(10);
        let (batch_tx, batch_rx) = mpsc::channel(10);
        let executor = ToolExecutor::new(registry, permission_registry, tool_tx, batch_tx);
        (executor, tool_rx, batch_rx)
    }

    /// Request for the `echo` tool carrying `message` as its only input.
    fn echo_request(tool_use_id: String, message: String) -> ToolRequest {
        let input = std::iter::once((
            "message".to_string(),
            serde_json::Value::String(message),
        ))
        .collect();
        ToolRequest {
            tool_use_id,
            tool_name: "echo".to_string(),
            input,
        }
    }

    /// Request with id `test_1` and no input for the tool named `tool_name`.
    fn bare_request(tool_name: &str) -> ToolRequest {
        ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: tool_name.to_string(),
            input: HashMap::new(),
        }
    }

    /// Tool that answers immediately with `Echo: <message>`.
    struct EchoTool;

    impl Executable for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes input back"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object","properties":{"message":{"type":"string"}}}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            let message = match input.get("message").and_then(|v| v.as_str()) {
                Some(text) => text.to_string(),
                None => "no message".to_string(),
            };
            Box::pin(async move { Ok(format!("Echo: {}", message)) })
        }
    }

    /// Tool that sleeps for ten seconds before answering, so a test can cancel it.
    struct SlowTool;

    impl Executable for SlowTool {
        fn name(&self) -> &str {
            "slow"
        }

        fn description(&self) -> &str {
            "A slow tool for testing timeouts"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok("done".to_string())
            })
        }
    }

    /// A single echo request yields one successful tool result and a
    /// one-element batch result.
    #[tokio::test]
    async fn test_execute_single_tool() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let (executor, mut tool_rx, mut batch_rx) = build_executor(registry);

        let request = echo_request("test_1".to_string(), "hello".to_string());
        executor
            .execute(1, None, request, CancellationToken::new())
            .await;

        let result = tool_rx.recv().await.unwrap();
        assert_eq!(result.status, ToolResultStatus::Success);
        assert!(result.content.contains("Echo: hello"));

        let batch = batch_rx.recv().await.unwrap();
        assert_eq!(batch.results.len(), 1);
    }

    /// Three echo requests yield three successful tool results and a
    /// three-element batch result.
    #[tokio::test]
    async fn test_execute_batch() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let (executor, mut tool_rx, mut batch_rx) = build_executor(registry);

        let mut requests: Vec<ToolRequest> = Vec::with_capacity(3);
        for i in 0..3 {
            requests.push(echo_request(format!("tool_{}", i), format!("msg_{}", i)));
        }

        executor
            .execute_batch(1, None, requests, CancellationToken::new())
            .await;

        for _ in 0..3 {
            let result = tool_rx.recv().await.unwrap();
            assert_eq!(result.status, ToolResultStatus::Success);
        }

        let batch = batch_rx.recv().await.unwrap();
        assert_eq!(batch.results.len(), 3);
    }

    /// Requesting an unregistered tool yields an error result mentioning
    /// "not found".
    #[tokio::test]
    async fn test_tool_not_found() {
        let registry = Arc::new(ToolRegistry::new());

        // The batch receiver is kept alive but never read.
        let (executor, mut tool_rx, _batch_rx) = build_executor(registry);

        executor
            .execute(1, None, bare_request("nonexistent"), CancellationToken::new())
            .await;

        let result = tool_rx.recv().await.unwrap();
        assert_eq!(result.status, ToolResultStatus::Error);
        assert!(result.error.unwrap().contains("not found"));
    }

    /// Cancelling the token while the slow tool is running yields a result
    /// with `Timeout` status.
    #[tokio::test]
    async fn test_tool_cancellation() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(SlowTool)).await.unwrap();

        // The batch receiver is kept alive but never read.
        let (executor, mut tool_rx, _batch_rx) = build_executor(registry);

        let cancel = CancellationToken::new();
        let canceller = cancel.clone();

        executor.execute(1, None, bare_request("slow"), cancel).await;

        // Trigger the cancellation shortly after the tool has been started.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            canceller.cancel();
        });

        let result = tool_rx.recv().await.unwrap();
        assert_eq!(result.status, ToolResultStatus::Timeout);
    }
}