1use crate::AgentError;
7use crate::constants::env::ai;
8use crate::types::{
9 Message, MessageRole, ToolAnnotations, ToolCall, ToolDefinition, ToolInputSchema, ToolResult,
10};
11use futures_util::stream::{self, StreamExt};
12use serde::Serialize;
13
14use crate::tool_errors::format_tool_error;
15use crate::tool_result_storage::process_tool_result;
16use crate::tool_validation::validate_tool_input;
17
18pub const MAX_TOOL_USE_CONCURRENCY: usize = 10;
20
21pub fn get_max_tool_use_concurrency() -> usize {
23 std::env::var(ai::MAX_TOOL_USE_CONCURRENCY)
24 .ok()
25 .and_then(|v| v.parse::<usize>().ok())
26 .unwrap_or(MAX_TOOL_USE_CONCURRENCY)
27}
28
29#[derive(Debug, Clone)]
31pub struct ToolBatch {
32 pub is_concurrency_safe: bool,
34 pub blocks: Vec<ToolCall>,
36}
37
38#[derive(Debug, Clone)]
40pub struct ContextModifier {
41 pub tool_use_id: String,
42 pub modify_context: fn(crate::types::ToolContext) -> crate::types::ToolContext,
43}
44
45#[derive(Debug, Clone)]
47pub struct ToolMessageUpdate {
48 pub message: Option<Message>,
50 pub new_context: Option<crate::types::ToolContext>,
52 pub context_modifier: Option<ContextModifier>,
54}
55
56pub fn partition_tool_calls(tool_calls: &[ToolCall], tools: &[ToolDefinition]) -> Vec<ToolBatch> {
60 let mut batches: Vec<ToolBatch> = Vec::new();
61
62 for tool_use in tool_calls {
63 let tool = tools.iter().find(|t| t.name == tool_use.name);
65
66 let is_concurrency_safe = tool
70 .map(|t| t.is_concurrency_safe(&tool_use.arguments))
71 .unwrap_or(false);
72
73 if is_concurrency_safe {
75 if let Some(last) = batches.last_mut() {
76 if last.is_concurrency_safe {
77 last.blocks.push(tool_use.clone());
79 continue;
80 }
81 }
82 }
83
84 batches.push(ToolBatch {
86 is_concurrency_safe,
87 blocks: vec![tool_use.clone()],
88 });
89 }
90
91 batches
92}
93
/// Removes `tool_use_id` from the set of in-progress tool uses.
///
/// Ids that are not in the set are ignored, so calling this more than once for
/// the same id is harmless.
pub fn mark_tool_use_as_complete(
    in_progress_ids: &mut std::collections::HashSet<String>,
    tool_use_id: &str,
) {
    // `remove` reports whether the id was present; completion is idempotent,
    // so that flag is intentionally discarded.
    let _ = in_progress_ids.remove(tool_use_id);
}
101
102pub async fn run_tools_serially<F, Fut>(
105 tool_calls: Vec<ToolCall>,
106 tool_context: crate::types::ToolContext,
107 tools: Vec<ToolDefinition>,
108 mut executor: F,
109 project_dir: Option<String>,
110 session_id: Option<String>,
111) -> Vec<ToolMessageUpdate>
112where
113 F: FnMut(String, serde_json::Value, String) -> Fut + Send,
114 Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
115{
116 let mut updates = Vec::new();
117 let mut current_context = tool_context;
118 let mut in_progress_ids = std::collections::HashSet::new();
119
120 for tool_call in tool_calls {
121 let tool_name = tool_call.name.clone();
122 let tool_args = tool_call.arguments.clone();
123 let tool_call_id = tool_call.id.clone();
124
125 in_progress_ids.insert(tool_call_id.clone());
127
128 let tool_def = tools.iter().find(|t| t.name == tool_name);
131 let interrupt_behavior = tool_def.map(|t| t.interrupt_behavior()).unwrap_or_default();
132 if !matches!(interrupt_behavior, crate::tools::types::InterruptBehavior::Block)
133 && current_context.abort_signal.is_aborted()
134 {
135 let error_content =
136 "<tool_use_error>Tool execution aborted by user interrupt</tool_use_error>"
137 .to_string();
138 updates.push(ToolMessageUpdate {
139 message: Some(Message {
140 role: MessageRole::Tool,
141 content: error_content,
142 tool_call_id: Some(tool_call_id.clone()),
143 is_error: Some(true),
144 ..Default::default()
145 }),
146 new_context: Some(current_context.clone()),
147 context_modifier: None,
148 });
149 mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
150 continue;
151 }
152
153 if let Err(validation_err) = validate_tool_input(&tool_name, &tool_args, &tools) {
155 let error_content = format!(
156 "<tool_use_error>InputValidationError: {}</tool_use_error>",
157 validation_err
158 );
159 updates.push(ToolMessageUpdate {
160 message: Some(Message {
161 role: MessageRole::Tool,
162 content: error_content,
163 tool_call_id: Some(tool_call_id.clone()),
164 is_error: Some(true),
165 ..Default::default()
166 }),
167 new_context: Some(current_context.clone()),
168 context_modifier: None,
169 });
170 mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
171 continue;
172 }
173
174 match executor(tool_name.clone(), tool_args.clone(), tool_call_id.clone()).await {
176 Ok(mut result) => {
177 let persisted = process_tool_result(
179 &result.content,
180 &tool_name,
181 &tool_call_id,
182 project_dir.as_deref(),
183 session_id.as_deref(),
184 None, );
186 result.content = persisted.0;
187 result.was_persisted = Some(persisted.1);
188
189 let message = Message {
190 role: MessageRole::Tool,
191 content: result.content,
192 tool_call_id: Some(tool_call_id.clone()),
193 is_error: result.is_error,
194 ..Default::default()
195 };
196
197 updates.push(ToolMessageUpdate {
198 message: Some(message),
199 new_context: Some(current_context.clone()),
200 context_modifier: None,
201 });
202 }
203 Err(e) => {
204 let error_content = format!(
206 "<tool_use_error>Error: {}</tool_use_error>",
207 format_tool_error(&e)
208 );
209 let message = Message {
210 role: MessageRole::Tool,
211 content: error_content,
212 tool_call_id: Some(tool_call_id.clone()),
213 is_error: Some(true),
214 ..Default::default()
215 };
216
217 updates.push(ToolMessageUpdate {
218 message: Some(message),
219 new_context: Some(current_context.clone()),
220 context_modifier: None,
221 });
222 }
223 }
224
225 mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
227 }
228
229 updates
230}
231
232pub async fn run_tools_concurrently<F, Fut>(
235 tool_calls: Vec<ToolCall>,
236 tool_context: crate::types::ToolContext,
237 tools: Vec<ToolDefinition>,
238 mut executor: F,
239 project_dir: Option<String>,
240 session_id: Option<String>,
241) -> Vec<ToolMessageUpdate>
242where
243 F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
244 Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
245{
246 let max_concurrency = get_max_tool_use_concurrency();
247 let mut updates = Vec::new();
248
249 let executions: Vec<_> = tool_calls
251 .into_iter()
252 .map(|tool_call| {
253 let mut exec = executor.clone();
254 let tool_name = tool_call.name.clone();
255 let tool_args = tool_call.arguments.clone();
256 let tool_call_id = tool_call.id.clone();
257 let tools = tools.clone();
258 let project_dir = project_dir.clone();
259 let session_id = session_id.clone();
260 let abort_signal = tool_context.abort_signal.clone();
261
262 async move {
263 let tool_def = tools.iter().find(|t| t.name == tool_name);
266 let interrupt_behavior =
267 tool_def.map(|t| t.interrupt_behavior()).unwrap_or_default();
268 if !matches!(interrupt_behavior, crate::tools::types::InterruptBehavior::Block)
269 && abort_signal.is_aborted()
270 {
271 return (
272 tool_call_id,
273 Err(AgentError::Tool("Tool execution aborted by user interrupt".to_string())),
274 );
275 }
276
277 if let Err(validation_err) = validate_tool_input(&tool_name, &tool_args, &tools) {
279 let error_content = format!(
280 "<tool_use_error>InputValidationError: {}</tool_use_error>",
281 validation_err
282 );
283 return (
284 tool_call_id,
285 Err(AgentError::Tool(format!(
286 "InputValidationError: {}",
287 validation_err
288 ))),
289 );
290 }
291 let result = exec(tool_name.clone(), tool_args, tool_call_id.clone()).await;
292 (tool_call_id, result)
293 }
294 })
295 .collect();
296
297 let mut stream = stream::iter(executions).buffer_unordered(max_concurrency);
299
300 while let Some((tool_call_id, result)) = stream.next().await {
301 match result {
302 Ok(tool_result) => {
303 let (content, _) = process_tool_result(
305 &tool_result.content,
306 "", &tool_call_id,
308 project_dir.as_deref(),
309 session_id.as_deref(),
310 None,
311 );
312 let message = Message {
313 role: MessageRole::Tool,
314 content,
315 tool_call_id: Some(tool_call_id),
316 ..Default::default()
317 };
318
319 updates.push(ToolMessageUpdate {
320 message: Some(message),
321 new_context: None,
322 context_modifier: None,
323 });
324 }
325 Err(e) => {
326 let error_content = format!(
327 "<tool_use_error>Error: {}</tool_use_error>",
328 format_tool_error(&e)
329 );
330 let message = Message {
331 role: MessageRole::Tool,
332 content: error_content,
333 tool_call_id: Some(tool_call_id),
334 is_error: Some(true),
335 ..Default::default()
336 };
337
338 updates.push(ToolMessageUpdate {
339 message: Some(message),
340 new_context: None,
341 context_modifier: None,
342 });
343 }
344 }
345 }
346
347 updates
348}
349
350pub async fn run_tools<F, Fut>(
353 tool_calls: Vec<ToolCall>,
354 tools: Vec<ToolDefinition>,
355 tool_context: crate::types::ToolContext,
356 executor: F,
357 project_dir: Option<String>,
358 session_id: Option<String>,
359) -> Vec<ToolMessageUpdate>
360where
361 F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
362 Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
363{
364 let batches = partition_tool_calls(&tool_calls, &tools);
365 let mut all_updates = Vec::new();
366 let mut current_context = tool_context;
367
368 for batch in batches {
369 let tools_clone = tools.clone();
370 let project_dir_clone = project_dir.clone();
371 let session_id_clone = session_id.clone();
372
373 if batch.is_concurrency_safe {
374 let updates = run_tools_concurrently(
376 batch.blocks,
377 current_context.clone(),
378 tools_clone,
379 executor.clone(),
380 project_dir_clone,
381 session_id_clone,
382 )
383 .await;
384 all_updates.extend(updates);
385 } else {
386 let updates = run_tools_serially(
388 batch.blocks,
389 current_context.clone(),
390 tools_clone,
391 executor.clone(),
392 project_dir_clone,
393 session_id_clone,
394 )
395 .await;
396
397 if let Some(last_update) = updates.last() {
399 if let Some(ctx) = &last_update.new_context {
400 current_context = ctx.clone();
401 }
402 }
403
404 all_updates.extend(updates);
405 }
406 }
407
408 all_updates
409}
410
// Unit tests for batching, serial/concurrent execution and interrupt handling.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolInputSchema;

    /// Builds a minimal `ToolDefinition` named `name` with an empty object
    /// schema. When `concurrency_safe` is true the tool carries
    /// `ToolAnnotations { concurrency_safe: Some(true), .. }`; otherwise it has
    /// no annotations. The partition tests assume (presumably) that
    /// `ToolDefinition::is_concurrency_safe` reads this annotation.
    fn create_test_tool(name: &str, concurrency_safe: bool) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("Test tool {}", name),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: if concurrency_safe {
                Some(ToolAnnotations {
                    concurrency_safe: Some(true),
                    ..Default::default()
                })
            } else {
                None
            },
            should_defer: None,
            always_load: None,
            is_mcp: None,
            search_hint: None,
            aliases: None,
            user_facing_name: None,
            interrupt_behavior: None,
        }
    }

    // NOTE(review): reads the real process environment, so this fails if the
    // concurrency env var is set in the environment running the tests.
    #[test]
    fn test_get_max_tool_use_concurrency_default() {
        assert_eq!(get_max_tool_use_concurrency(), MAX_TOOL_USE_CONCURRENCY);
    }

    #[test]
    fn test_get_max_tool_use_concurrency_value() {
        let result = get_max_tool_use_concurrency();
        assert!(result > 0);
    }

    // Two non-safe calls must never share a batch.
    #[test]
    fn test_partition_tool_calls_all_non_safe() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                r#type: "function".to_string(),
                name: "Bash".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "2".to_string(),
                r#type: "function".to_string(),
                name: "Edit".to_string(),
                arguments: serde_json::json!({}),
            },
        ];
        let tools = vec![
            create_test_tool("Bash", false),
            create_test_tool("Edit", false),
        ];

        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 2);
        assert!(!batches[0].is_concurrency_safe);
        assert!(!batches[1].is_concurrency_safe);
    }

    // safe, safe, unsafe, safe -> [Read, Glob] | [Bash] | [Grep]
    #[test]
    fn test_partition_tool_calls_mixed() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                r#type: "function".to_string(),
                name: "Read".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "2".to_string(),
                r#type: "function".to_string(),
                name: "Glob".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "3".to_string(),
                r#type: "function".to_string(),
                name: "Bash".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "4".to_string(),
                r#type: "function".to_string(),
                name: "Grep".to_string(),
                arguments: serde_json::json!({}),
            },
        ];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
            create_test_tool("Grep", true),
        ];

        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 3);
        assert!(batches[0].is_concurrency_safe);
        assert_eq!(batches[0].blocks.len(), 2);
        assert!(!batches[1].is_concurrency_safe);
        assert!(batches[2].is_concurrency_safe);
    }

    // A call to a tool missing from `tools` is treated as not concurrency-safe.
    #[test]
    fn test_partition_tool_calls_with_unknown_tool() {
        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "UnknownTool".to_string(),
            arguments: serde_json::json!({}),
        }];
        let tools = vec![];

        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 1);
        assert!(!batches[0].is_concurrency_safe);
    }

    // Happy path: one call, one update carrying a message.
    #[tokio::test]
    async fn test_run_tools_serially() {
        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "test".to_string(),
            arguments: serde_json::json!({}),
        }];

        let tool_context = crate::types::ToolContext::default();
        let tools = vec![create_test_tool("test", false)];

        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates =
            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
        assert_eq!(updates.len(), 1);
        assert!(updates[0].message.is_some());
    }

    // Happy path: two safe calls produce two updates.
    #[tokio::test]
    async fn test_run_tools_concurrently() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                r#type: "function".to_string(),
                name: "test1".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "2".to_string(),
                r#type: "function".to_string(),
                name: "test2".to_string(),
                arguments: serde_json::json!({}),
            },
        ];

        let tool_context = crate::types::ToolContext::default();
        let tools = vec![
            create_test_tool("test1", true),
            create_test_tool("test2", true),
        ];

        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates =
            run_tools_concurrently(tool_calls, tool_context, tools, executor, None, None).await;
        assert_eq!(updates.len(), 2);
    }

    // Read + Glob form one concurrent batch, Bash a serial one: 3 updates total.
    #[tokio::test]
    async fn test_run_tools_with_partitioning() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                r#type: "function".to_string(),
                name: "Read".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "2".to_string(),
                r#type: "function".to_string(),
                name: "Glob".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "3".to_string(),
                r#type: "function".to_string(),
                name: "Bash".to_string(),
                arguments: serde_json::json!({}),
            },
        ];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
        ];

        let tool_context = crate::types::ToolContext::default();

        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates = run_tools(tool_calls, tools, tool_context, executor, None, None).await;
        assert_eq!(updates.len(), 3);
    }

    // Only the named id is removed from the in-progress set.
    #[test]
    fn test_mark_tool_use_as_complete() {
        let mut in_progress = std::collections::HashSet::new();
        in_progress.insert("tool1".to_string());
        in_progress.insert("tool2".to_string());

        mark_tool_use_as_complete(&mut in_progress, "tool1");

        assert!(!in_progress.contains("tool1"));
        assert!(in_progress.contains("tool2"));
    }

    // A "cancel" tool with an already-aborted signal is skipped with an error.
    #[tokio::test]
    async fn test_run_tools_serially_aborted() {
        use crate::utils::abort_controller::create_abort_controller_default;

        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "test".to_string(),
            arguments: serde_json::json!({}),
        }];

        let controller = create_abort_controller_default();
        controller.abort(None);
        let abort_signal = controller.signal().clone();

        let tool_context = crate::types::ToolContext {
            cwd: "/tmp".to_string(),
            abort_signal,
        };
        let tools = vec![create_tool_with_interrupt("test", Some("cancel".into()))];

        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "should not reach".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates =
            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert!(msg.is_error == Some(true));
        assert!(msg.content.contains("aborted"));
    }

    // Same as above, through the concurrent runner.
    #[tokio::test]
    async fn test_run_tools_concurrently_aborted() {
        use crate::utils::abort_controller::create_abort_controller_default;

        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "Read".to_string(),
            arguments: serde_json::json!({}),
        }];

        let controller = create_abort_controller_default();
        controller.abort(None);
        let abort_signal = controller.signal().clone();

        let tool_context = crate::types::ToolContext {
            cwd: "/tmp".to_string(),
            abort_signal,
        };
        let tools = vec![create_tool_with_interrupt("Read", Some("cancel".into()))];

        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "should not reach".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates = run_tools_concurrently(
            tool_calls, tool_context, tools, executor, None, None,
        )
        .await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert!(msg.is_error == Some(true));
    }

    /// Builds a tool with no annotations whose `interrupt_behavior` field is
    /// set to `interrupt` (the tests use "cancel", "block" or `None`).
    fn create_tool_with_interrupt(
        name: &str,
        interrupt: Option<String>,
    ) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("Test tool {}", name),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: None,
            should_defer: None,
            always_load: None,
            is_mcp: None,
            search_hint: None,
            aliases: None,
            user_facing_name: None,
            interrupt_behavior: interrupt,
        }
    }

    // "cancel" tools are skipped once the abort signal is set.
    #[tokio::test]
    async fn test_interrupt_cancel_tool_aborted() {
        use crate::utils::abort_controller::create_abort_controller_default;

        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "CancelTool".to_string(),
            arguments: serde_json::json!({}),
        }];

        let controller = create_abort_controller_default();
        controller.abort(None);
        let abort_signal = controller.signal().clone();

        let tool_context = crate::types::ToolContext {
            cwd: "/tmp".to_string(),
            abort_signal,
        };
        let tools = vec![create_tool_with_interrupt("CancelTool", Some("cancel".into()))];

        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "should not reach".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates =
            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert!(msg.is_error == Some(true));
        assert!(msg.content.contains("aborted"));
    }

    // "block" tools still run to completion despite an abort.
    #[tokio::test]
    async fn test_interrupt_block_tool_ignores_abort() {
        use crate::utils::abort_controller::create_abort_controller_default;

        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "BlockTool".to_string(),
            arguments: serde_json::json!({}),
        }];

        let controller = create_abort_controller_default();
        controller.abort(None);
        let abort_signal = controller.signal().clone();

        let tool_context = crate::types::ToolContext {
            cwd: "/tmp".to_string(),
            abort_signal,
        };
        let tools = vec![create_tool_with_interrupt("BlockTool", Some("block".into()))];

        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "block tool completed".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates =
            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert!(msg.is_error != Some(true));
        assert!(msg.content.contains("block tool completed"));
    }

    // With no interrupt behavior configured, the tool is expected to run like
    // a "block" tool (the executor result is returned, not an abort error).
    #[tokio::test]
    async fn test_interrupt_default_treated_as_block() {
        use crate::utils::abort_controller::create_abort_controller_default;

        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "DefaultTool".to_string(),
            arguments: serde_json::json!({}),
        }];

        let controller = create_abort_controller_default();
        controller.abort(None);
        let abort_signal = controller.signal().clone();

        let tool_context = crate::types::ToolContext {
            cwd: "/tmp".to_string(),
            abort_signal,
        };
        let tools = vec![create_tool_with_interrupt("DefaultTool", None)];

        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "default completed".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates =
            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert!(msg.is_error != Some(true));
    }

    // "block" tools also ignore the abort on the concurrent path.
    #[tokio::test]
    async fn test_interrupt_concurrently_block_ignores_abort() {
        use crate::utils::abort_controller::create_abort_controller_default;

        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "BlockTool".to_string(),
            arguments: serde_json::json!({}),
        }];

        let controller = create_abort_controller_default();
        controller.abort(None);
        let abort_signal = controller.signal().clone();

        let tool_context = crate::types::ToolContext {
            cwd: "/tmp".to_string(),
            abort_signal,
        };
        let tools =
            vec![create_tool_with_interrupt("BlockTool", Some("block".into()))];

        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "concurrent block done".to_string(),
                is_error: Some(false),
                was_persisted: Some(false),
            })
        };

        let updates =
            run_tools_concurrently(tool_calls, tool_context, tools, executor, None, None).await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert!(msg.content.contains("concurrent block done"));
    }
}
925}