Skip to main content

ai_agent/tools/
orchestration.rs

// Source: /data/home/swei/claudecode/openclaudecode/src/services/tools/toolOrchestration.ts
//! Tool orchestration: partitions tool calls into batches and runs each batch
//! either serially or with bounded concurrency.
//!
//! Translated from TypeScript toolOrchestration.ts
5
6use crate::AgentError;
7use crate::constants::env::ai;
8use crate::types::{
9    Message, MessageRole, ToolAnnotations, ToolCall, ToolDefinition, ToolInputSchema, ToolResult,
10};
11use futures_util::stream::{self, StreamExt};
12use serde::Serialize;
13
14use crate::tool_errors::format_tool_error;
15use crate::tool_result_storage::process_tool_result;
16use crate::tool_validation::validate_tool_input;
17
18/// Maximum number of concurrent tool executions (matches TypeScript default)
19pub const MAX_TOOL_USE_CONCURRENCY: usize = 10;
20
21/// Get max tool use concurrency from environment variable
22pub fn get_max_tool_use_concurrency() -> usize {
23    std::env::var(ai::MAX_TOOL_USE_CONCURRENCY)
24        .ok()
25        .and_then(|v| v.parse::<usize>().ok())
26        .unwrap_or(MAX_TOOL_USE_CONCURRENCY)
27}
28
29/// A batch of tool calls that can be executed together
30#[derive(Debug, Clone)]
31pub struct ToolBatch {
32    /// Whether this batch is concurrency-safe (can run in parallel)
33    pub is_concurrency_safe: bool,
34    /// The tool calls in this batch
35    pub blocks: Vec<ToolCall>,
36}
37
38/// Modifier for tool use context (for contextModifier support).
39#[derive(Debug, Clone)]
40pub struct ContextModifier {
41    pub tool_use_id: String,
42    pub modify_context: fn(crate::types::ToolContext) -> crate::types::ToolContext,
43}
44
45/// Message update for tool orchestration
46#[derive(Debug, Clone)]
47pub struct ToolMessageUpdate {
48    /// The message to add to the conversation
49    pub message: Option<Message>,
50    /// Updated context after this tool (for serial execution)
51    pub new_context: Option<crate::types::ToolContext>,
52    /// Context modifiers for this tool result (for contextModifier support)
53    pub context_modifier: Option<ContextModifier>,
54}
55
56/// Partition tool calls into batches where each batch is either:
57/// 1. A single non-concurrency-safe tool, or
58/// 2. Multiple consecutive concurrency-safe tools
59pub fn partition_tool_calls(tool_calls: &[ToolCall], tools: &[ToolDefinition]) -> Vec<ToolBatch> {
60    let mut batches: Vec<ToolBatch> = Vec::new();
61
62    for tool_use in tool_calls {
63        // Find the tool definition
64        let tool = tools.iter().find(|t| t.name == tool_use.name);
65
66        // Check concurrency safety
67        // Matches TypeScript: use the tool's isConcurrencySafe method
68        // If tool not found or isConcurrencySafe throws, treat as not concurrency-safe
69        let is_concurrency_safe = tool
70            .map(|t| t.is_concurrency_safe(&tool_use.arguments))
71            .unwrap_or(false);
72
73        // Check if we can add to the last batch
74        if is_concurrency_safe {
75            if let Some(last) = batches.last_mut() {
76                if last.is_concurrency_safe {
77                    // Add to existing concurrency-safe batch
78                    last.blocks.push(tool_use.clone());
79                    continue;
80                }
81            }
82        }
83
84        // Create new batch (either non-concurrency-safe or first in a concurrency-safe group)
85        batches.push(ToolBatch {
86            is_concurrency_safe,
87            blocks: vec![tool_use.clone()],
88        });
89    }
90
91    batches
92}
93
/// Marks a tool call as finished by dropping its id from the in-progress set.
///
/// Ids that are unknown or were already removed are ignored, because
/// `HashSet::remove` simply reports `false` for them.
pub fn mark_tool_use_as_complete(
    in_progress_ids: &mut std::collections::HashSet<String>,
    tool_use_id: &str,
) {
    let _was_in_progress = in_progress_ids.remove(tool_use_id);
}
101
102/// Run tools serially (for non-concurrency-safe tools)
103/// This matches TypeScript's runToolsSerially function
104pub async fn run_tools_serially<F, Fut>(
105    tool_calls: Vec<ToolCall>,
106    tool_context: crate::types::ToolContext,
107    tools: Vec<ToolDefinition>,
108    mut executor: F,
109    project_dir: Option<String>,
110    session_id: Option<String>,
111) -> Vec<ToolMessageUpdate>
112where
113    F: FnMut(String, serde_json::Value, String) -> Fut + Send,
114    Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
115{
116    let mut updates = Vec::new();
117    let mut current_context = tool_context;
118    let mut in_progress_ids = std::collections::HashSet::new();
119
120    for tool_call in tool_calls {
121        let tool_name = tool_call.name.clone();
122        let tool_args = tool_call.arguments.clone();
123        let tool_call_id = tool_call.id.clone();
124
125        // Mark tool as in progress
126        in_progress_ids.insert(tool_call_id.clone());
127
128        // Check abort signal based on tool's interruptBehavior
129        // 'block' tools ignore abort and complete; 'cancel'/default respect abort
130        let tool_def = tools.iter().find(|t| t.name == tool_name);
131        let interrupt_behavior = tool_def.map(|t| t.interrupt_behavior()).unwrap_or_default();
132        if !matches!(interrupt_behavior, crate::tools::types::InterruptBehavior::Block)
133            && current_context.abort_signal.is_aborted()
134        {
135            let error_content =
136                "<tool_use_error>Tool execution aborted by user interrupt</tool_use_error>"
137                    .to_string();
138            updates.push(ToolMessageUpdate {
139                message: Some(Message {
140                    role: MessageRole::Tool,
141                    content: error_content,
142                    tool_call_id: Some(tool_call_id.clone()),
143                    is_error: Some(true),
144                    ..Default::default()
145                }),
146                new_context: Some(current_context.clone()),
147                context_modifier: None,
148            });
149            mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
150            continue;
151        }
152
153        // Input validation (matches TS Zod schema validation)
154        if let Err(validation_err) = validate_tool_input(&tool_name, &tool_args, &tools) {
155            let error_content = format!(
156                "<tool_use_error>InputValidationError: {}</tool_use_error>",
157                validation_err
158            );
159            updates.push(ToolMessageUpdate {
160                message: Some(Message {
161                    role: MessageRole::Tool,
162                    content: error_content,
163                    tool_call_id: Some(tool_call_id.clone()),
164                    is_error: Some(true),
165                    ..Default::default()
166                }),
167                new_context: Some(current_context.clone()),
168                context_modifier: None,
169            });
170            mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
171            continue;
172        }
173
174        // Execute the tool (pass tool_call_id)
175        match executor(tool_name.clone(), tool_args.clone(), tool_call_id.clone()).await {
176            Ok(mut result) => {
177                // Large result persistence (matches TS processToolResultBlock)
178                let persisted = process_tool_result(
179                    &result.content,
180                    &tool_name,
181                    &tool_call_id,
182                    project_dir.as_deref(),
183                    session_id.as_deref(),
184                    None, // Use default threshold
185                );
186                result.content = persisted.0;
187                result.was_persisted = Some(persisted.1);
188
189                let message = Message {
190                    role: MessageRole::Tool,
191                    content: result.content,
192                    tool_call_id: Some(tool_call_id.clone()),
193                    is_error: result.is_error,
194                    ..Default::default()
195                };
196
197                updates.push(ToolMessageUpdate {
198                    message: Some(message),
199                    new_context: Some(current_context.clone()),
200                    context_modifier: None,
201                });
202            }
203            Err(e) => {
204                // Format error using tool_errors (matches TS formatError)
205                let error_content = format!(
206                    "<tool_use_error>Error: {}</tool_use_error>",
207                    format_tool_error(&e)
208                );
209                let message = Message {
210                    role: MessageRole::Tool,
211                    content: error_content,
212                    tool_call_id: Some(tool_call_id.clone()),
213                    is_error: Some(true),
214                    ..Default::default()
215                };
216
217                updates.push(ToolMessageUpdate {
218                    message: Some(message),
219                    new_context: Some(current_context.clone()),
220                    context_modifier: None,
221                });
222            }
223        }
224
225        // Mark tool as complete
226        mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
227    }
228
229    updates
230}
231
232/// Run tools concurrently (for concurrency-safe tools)
233/// Uses the all() generator pattern from TypeScript with concurrency limit
234pub async fn run_tools_concurrently<F, Fut>(
235    tool_calls: Vec<ToolCall>,
236    tool_context: crate::types::ToolContext,
237    tools: Vec<ToolDefinition>,
238    mut executor: F,
239    project_dir: Option<String>,
240    session_id: Option<String>,
241) -> Vec<ToolMessageUpdate>
242where
243    F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
244    Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
245{
246    let max_concurrency = get_max_tool_use_concurrency();
247    let mut updates = Vec::new();
248
249    // Create a stream of tool executions
250    let executions: Vec<_> = tool_calls
251        .into_iter()
252        .map(|tool_call| {
253            let mut exec = executor.clone();
254            let tool_name = tool_call.name.clone();
255            let tool_args = tool_call.arguments.clone();
256            let tool_call_id = tool_call.id.clone();
257            let tools = tools.clone();
258            let project_dir = project_dir.clone();
259            let session_id = session_id.clone();
260            let abort_signal = tool_context.abort_signal.clone();
261
262            async move {
263                // Check abort signal based on interruptBehavior
264                // 'block' tools ignore abort; 'cancel'/default respect abort
265                let tool_def = tools.iter().find(|t| t.name == tool_name);
266                let interrupt_behavior =
267                    tool_def.map(|t| t.interrupt_behavior()).unwrap_or_default();
268                if !matches!(interrupt_behavior, crate::tools::types::InterruptBehavior::Block)
269                    && abort_signal.is_aborted()
270                {
271                    return (
272                        tool_call_id,
273                        Err(AgentError::Tool("Tool execution aborted by user interrupt".to_string())),
274                    );
275                }
276
277                // Input validation
278                if let Err(validation_err) = validate_tool_input(&tool_name, &tool_args, &tools) {
279                    let error_content = format!(
280                        "<tool_use_error>InputValidationError: {}</tool_use_error>",
281                        validation_err
282                    );
283                    return (
284                        tool_call_id,
285                        Err(AgentError::Tool(format!(
286                            "InputValidationError: {}",
287                            validation_err
288                        ))),
289                    );
290                }
291                let result = exec(tool_name.clone(), tool_args, tool_call_id.clone()).await;
292                (tool_call_id, result)
293            }
294        })
295        .collect();
296
297    // Run with bounded concurrency using buffer_unordered
298    let mut stream = stream::iter(executions).buffer_unordered(max_concurrency);
299
300    while let Some((tool_call_id, result)) = stream.next().await {
301        match result {
302            Ok(tool_result) => {
303                // Large result persistence
304                let (content, _) = process_tool_result(
305                    &tool_result.content,
306                    "", // tool name not tracked in concurrent path
307                    &tool_call_id,
308                    project_dir.as_deref(),
309                    session_id.as_deref(),
310                    None,
311                );
312                let message = Message {
313                    role: MessageRole::Tool,
314                    content,
315                    tool_call_id: Some(tool_call_id),
316                    ..Default::default()
317                };
318
319                updates.push(ToolMessageUpdate {
320                    message: Some(message),
321                    new_context: None,
322                    context_modifier: None,
323                });
324            }
325            Err(e) => {
326                let error_content = format!(
327                    "<tool_use_error>Error: {}</tool_use_error>",
328                    format_tool_error(&e)
329                );
330                let message = Message {
331                    role: MessageRole::Tool,
332                    content: error_content,
333                    tool_call_id: Some(tool_call_id),
334                    is_error: Some(true),
335                    ..Default::default()
336                };
337
338                updates.push(ToolMessageUpdate {
339                    message: Some(message),
340                    new_context: None,
341                    context_modifier: None,
342                });
343            }
344        }
345    }
346
347    updates
348}
349
350/// Run all tools with proper partitioning and concurrency
351/// This is the main entry point that matches TypeScript's runTools()
352pub async fn run_tools<F, Fut>(
353    tool_calls: Vec<ToolCall>,
354    tools: Vec<ToolDefinition>,
355    tool_context: crate::types::ToolContext,
356    executor: F,
357    project_dir: Option<String>,
358    session_id: Option<String>,
359) -> Vec<ToolMessageUpdate>
360where
361    F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
362    Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
363{
364    let batches = partition_tool_calls(&tool_calls, &tools);
365    let mut all_updates = Vec::new();
366    let mut current_context = tool_context;
367
368    for batch in batches {
369        let tools_clone = tools.clone();
370        let project_dir_clone = project_dir.clone();
371        let session_id_clone = session_id.clone();
372
373        if batch.is_concurrency_safe {
374            // Run concurrency-safe batch concurrently
375            let updates = run_tools_concurrently(
376                batch.blocks,
377                current_context.clone(),
378                tools_clone,
379                executor.clone(),
380                project_dir_clone,
381                session_id_clone,
382            )
383            .await;
384            all_updates.extend(updates);
385        } else {
386            // Run non-concurrency-safe batch serially
387            let updates = run_tools_serially(
388                batch.blocks,
389                current_context.clone(),
390                tools_clone,
391                executor.clone(),
392                project_dir_clone,
393                session_id_clone,
394            )
395            .await;
396
397            // Update context after serial execution
398            if let Some(last_update) = updates.last() {
399                if let Some(ctx) = &last_update.new_context {
400                    current_context = ctx.clone();
401                }
402            }
403
404            all_updates.extend(updates);
405        }
406    }
407
408    all_updates
409}
410
411#[cfg(test)]
412mod tests {
413    use super::*;
414    use crate::types::ToolInputSchema;
415
416    fn create_test_tool(name: &str, concurrency_safe: bool) -> ToolDefinition {
417        ToolDefinition {
418            name: name.to_string(),
419            description: format!("Test tool {}", name),
420            input_schema: ToolInputSchema {
421                schema_type: "object".to_string(),
422                properties: serde_json::json!({}),
423                required: None,
424            },
425            annotations: if concurrency_safe {
426                Some(ToolAnnotations {
427                    concurrency_safe: Some(true),
428                    ..Default::default()
429                })
430            } else {
431                None
432            },
433            should_defer: None,
434            always_load: None,
435            is_mcp: None,
436            search_hint: None,
437            aliases: None,
438            user_facing_name: None,
439            interrupt_behavior: None,
440        }
441    }
442
443    #[test]
444    fn test_get_max_tool_use_concurrency_default() {
445        // Without env var, should return default
446        // Note: In real env with no var set, returns default
447        assert_eq!(get_max_tool_use_concurrency(), MAX_TOOL_USE_CONCURRENCY);
448    }
449
450    #[test]
451    fn test_get_max_tool_use_concurrency_value() {
452        // Just test that the function returns a value
453        let result = get_max_tool_use_concurrency();
454        assert!(result > 0);
455    }
456
457    #[test]
458    fn test_partition_tool_calls_all_non_safe() {
459        let tool_calls = vec![
460            ToolCall {
461                id: "1".to_string(),
462                r#type: "function".to_string(),
463                name: "Bash".to_string(),
464                arguments: serde_json::json!({}),
465            },
466            ToolCall {
467                id: "2".to_string(),
468                r#type: "function".to_string(),
469                name: "Edit".to_string(),
470                arguments: serde_json::json!({}),
471            },
472        ];
473        let tools = vec![
474            create_test_tool("Bash", false),
475            create_test_tool("Edit", false),
476        ];
477
478        let batches = partition_tool_calls(&tool_calls, &tools);
479        assert_eq!(batches.len(), 2);
480        assert!(!batches[0].is_concurrency_safe);
481        assert!(!batches[1].is_concurrency_safe);
482    }
483
484    #[test]
485    fn test_partition_tool_calls_mixed() {
486        let tool_calls = vec![
487            ToolCall {
488                id: "1".to_string(),
489                r#type: "function".to_string(),
490                name: "Read".to_string(),
491                arguments: serde_json::json!({}),
492            },
493            ToolCall {
494                id: "2".to_string(),
495                r#type: "function".to_string(),
496                name: "Glob".to_string(),
497                arguments: serde_json::json!({}),
498            },
499            ToolCall {
500                id: "3".to_string(),
501                r#type: "function".to_string(),
502                name: "Bash".to_string(),
503                arguments: serde_json::json!({}),
504            },
505            ToolCall {
506                id: "4".to_string(),
507                r#type: "function".to_string(),
508                name: "Grep".to_string(),
509                arguments: serde_json::json!({}),
510            },
511        ];
512        let tools = vec![
513            create_test_tool("Read", true),
514            create_test_tool("Glob", true),
515            create_test_tool("Bash", false),
516            create_test_tool("Grep", true),
517        ];
518
519        let batches = partition_tool_calls(&tool_calls, &tools);
520        // Should be: [Read,Glob] (concurrency safe), [Bash] (non-safe), [Grep] (concurrency safe)
521        assert_eq!(batches.len(), 3);
522        assert!(batches[0].is_concurrency_safe);
523        assert_eq!(batches[0].blocks.len(), 2);
524        assert!(!batches[1].is_concurrency_safe);
525        assert!(batches[2].is_concurrency_safe);
526    }
527
528    #[test]
529    fn test_partition_tool_calls_with_unknown_tool() {
530        let tool_calls = vec![ToolCall {
531            id: "1".to_string(),
532            r#type: "function".to_string(),
533            name: "UnknownTool".to_string(),
534            arguments: serde_json::json!({}),
535        }];
536        let tools = vec![];
537
538        let batches = partition_tool_calls(&tool_calls, &tools);
539        assert_eq!(batches.len(), 1);
540        // Unknown tools should be treated as not concurrency-safe
541        assert!(!batches[0].is_concurrency_safe);
542    }
543
544    #[tokio::test]
545    async fn test_run_tools_serially() {
546        let tool_calls = vec![ToolCall {
547            id: "1".to_string(),
548            r#type: "function".to_string(),
549            name: "test".to_string(),
550            arguments: serde_json::json!({}),
551        }];
552
553        let tool_context = crate::types::ToolContext::default();
554        let tools = vec![create_test_tool("test", false)];
555
556        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
557            Ok(crate::types::ToolResult {
558                result_type: "tool_result".to_string(),
559                tool_use_id: "1".to_string(),
560                content: "success".to_string(),
561                is_error: Some(false),
562                was_persisted: Some(false),
563            })
564        };
565
566        let updates =
567            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
568        assert_eq!(updates.len(), 1);
569        assert!(updates[0].message.is_some());
570    }
571
572    #[tokio::test]
573    async fn test_run_tools_concurrently() {
574        let tool_calls = vec![
575            ToolCall {
576                id: "1".to_string(),
577                r#type: "function".to_string(),
578                name: "test1".to_string(),
579                arguments: serde_json::json!({}),
580            },
581            ToolCall {
582                id: "2".to_string(),
583                r#type: "function".to_string(),
584                name: "test2".to_string(),
585                arguments: serde_json::json!({}),
586            },
587        ];
588
589        let tool_context = crate::types::ToolContext::default();
590        let tools = vec![
591            create_test_tool("test1", true),
592            create_test_tool("test2", true),
593        ];
594
595        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
596            Ok(crate::types::ToolResult {
597                result_type: "tool_result".to_string(),
598                tool_use_id: "1".to_string(),
599                content: "success".to_string(),
600                is_error: Some(false),
601                was_persisted: Some(false),
602            })
603        };
604
605        let updates =
606            run_tools_concurrently(tool_calls, tool_context, tools, executor, None, None).await;
607        assert_eq!(updates.len(), 2);
608    }
609
610    #[tokio::test]
611    async fn test_run_tools_with_partitioning() {
612        let tool_calls = vec![
613            ToolCall {
614                id: "1".to_string(),
615                r#type: "function".to_string(),
616                name: "Read".to_string(),
617                arguments: serde_json::json!({}),
618            },
619            ToolCall {
620                id: "2".to_string(),
621                r#type: "function".to_string(),
622                name: "Glob".to_string(),
623                arguments: serde_json::json!({}),
624            },
625            ToolCall {
626                id: "3".to_string(),
627                r#type: "function".to_string(),
628                name: "Bash".to_string(),
629                arguments: serde_json::json!({}),
630            },
631        ];
632        let tools = vec![
633            create_test_tool("Read", true),
634            create_test_tool("Glob", true),
635            create_test_tool("Bash", false),
636        ];
637
638        let tool_context = crate::types::ToolContext::default();
639
640        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
641            Ok(crate::types::ToolResult {
642                result_type: "tool_result".to_string(),
643                tool_use_id: "1".to_string(),
644                content: "success".to_string(),
645                is_error: Some(false),
646                was_persisted: Some(false),
647            })
648        };
649
650        let updates = run_tools(tool_calls, tools, tool_context, executor, None, None).await;
651        assert_eq!(updates.len(), 3);
652    }
653
654    #[test]
655    fn test_mark_tool_use_as_complete() {
656        let mut in_progress = std::collections::HashSet::new();
657        in_progress.insert("tool1".to_string());
658        in_progress.insert("tool2".to_string());
659
660        mark_tool_use_as_complete(&mut in_progress, "tool1");
661
662        assert!(!in_progress.contains("tool1"));
663        assert!(in_progress.contains("tool2"));
664    }
665
666    #[tokio::test]
667    async fn test_run_tools_serially_aborted() {
668        use crate::utils::abort_controller::create_abort_controller_default;
669
670        let tool_calls = vec![ToolCall {
671            id: "1".to_string(),
672            r#type: "function".to_string(),
673            name: "test".to_string(),
674            arguments: serde_json::json!({}),
675        }];
676
677        let controller = create_abort_controller_default();
678        controller.abort(None); // Pre-abort
679        let abort_signal = controller.signal().clone();
680
681        let tool_context = crate::types::ToolContext {
682            cwd: "/tmp".to_string(),
683            abort_signal,
684        };
685        let tools = vec![create_tool_with_interrupt("test", Some("cancel".into()))];
686
687        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
688            Ok(crate::types::ToolResult {
689                result_type: "tool_result".to_string(),
690                tool_use_id: "1".to_string(),
691                content: "should not reach".to_string(),
692                is_error: Some(false),
693                was_persisted: Some(false),
694            })
695        };
696
697        let updates =
698            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
699        assert_eq!(updates.len(), 1);
700        let msg = updates[0].message.as_ref().unwrap();
701        assert!(msg.is_error == Some(true));
702        assert!(msg.content.contains("aborted"));
703    }
704
705    #[tokio::test]
706    async fn test_run_tools_concurrently_aborted() {
707        use crate::utils::abort_controller::create_abort_controller_default;
708
709        let tool_calls = vec![ToolCall {
710            id: "1".to_string(),
711            r#type: "function".to_string(),
712            name: "Read".to_string(),
713            arguments: serde_json::json!({}),
714        }];
715
716        let controller = create_abort_controller_default();
717        controller.abort(None); // Pre-abort
718        let abort_signal = controller.signal().clone();
719
720        let tool_context = crate::types::ToolContext {
721            cwd: "/tmp".to_string(),
722            abort_signal,
723        };
724        let tools = vec![create_tool_with_interrupt("Read", Some("cancel".into()))];
725
726        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
727            Ok(crate::types::ToolResult {
728                result_type: "tool_result".to_string(),
729                tool_use_id: "1".to_string(),
730                content: "should not reach".to_string(),
731                is_error: Some(false),
732                was_persisted: Some(false),
733            })
734        };
735
736        let updates = run_tools_concurrently(
737            tool_calls, tool_context, tools, executor, None, None,
738        )
739        .await;
740        assert_eq!(updates.len(), 1);
741        let msg = updates[0].message.as_ref().unwrap();
742        assert!(msg.is_error == Some(true));
743    }
744
745    fn create_tool_with_interrupt(
746        name: &str,
747        interrupt: Option<String>,
748    ) -> ToolDefinition {
749        ToolDefinition {
750            name: name.to_string(),
751            description: format!("Test tool {}", name),
752            input_schema: ToolInputSchema {
753                schema_type: "object".to_string(),
754                properties: serde_json::json!({}),
755                required: None,
756            },
757            annotations: None,
758            should_defer: None,
759            always_load: None,
760            is_mcp: None,
761            search_hint: None,
762            aliases: None,
763            user_facing_name: None,
764            interrupt_behavior: interrupt,
765        }
766    }
767
768    #[tokio::test]
769    async fn test_interrupt_cancel_tool_aborted() {
770        use crate::utils::abort_controller::create_abort_controller_default;
771
772        let tool_calls = vec![ToolCall {
773            id: "1".to_string(),
774            r#type: "function".to_string(),
775            name: "CancelTool".to_string(),
776            arguments: serde_json::json!({}),
777        }];
778
779        let controller = create_abort_controller_default();
780        controller.abort(None);
781        let abort_signal = controller.signal().clone();
782
783        let tool_context = crate::types::ToolContext {
784            cwd: "/tmp".to_string(),
785            abort_signal,
786        };
787        let tools = vec![create_tool_with_interrupt("CancelTool", Some("cancel".into()))];
788
789        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
790            Ok(crate::types::ToolResult {
791                result_type: "tool_result".to_string(),
792                tool_use_id: "1".to_string(),
793                content: "should not reach".to_string(),
794                is_error: Some(false),
795                was_persisted: Some(false),
796            })
797        };
798
799        let updates =
800            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
801        assert_eq!(updates.len(), 1);
802        let msg = updates[0].message.as_ref().unwrap();
803        assert!(msg.is_error == Some(true));
804        assert!(msg.content.contains("aborted"));
805    }
806
807    #[tokio::test]
808    async fn test_interrupt_block_tool_ignores_abort() {
809        use crate::utils::abort_controller::create_abort_controller_default;
810
811        let tool_calls = vec![ToolCall {
812            id: "1".to_string(),
813            r#type: "function".to_string(),
814            name: "BlockTool".to_string(),
815            arguments: serde_json::json!({}),
816        }];
817
818        let controller = create_abort_controller_default();
819        controller.abort(None); // Pre-abort, but block tool should ignore
820        let abort_signal = controller.signal().clone();
821
822        let tool_context = crate::types::ToolContext {
823            cwd: "/tmp".to_string(),
824            abort_signal,
825        };
826        let tools = vec![create_tool_with_interrupt("BlockTool", Some("block".into()))];
827
828        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
829            Ok(crate::types::ToolResult {
830                result_type: "tool_result".to_string(),
831                tool_use_id: "1".to_string(),
832                content: "block tool completed".to_string(),
833                is_error: Some(false),
834                was_persisted: Some(false),
835            })
836        };
837
838        let updates =
839            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
840        assert_eq!(updates.len(), 1);
841        let msg = updates[0].message.as_ref().unwrap();
842        // Block tool should complete normally, NOT aborted
843        assert!(msg.is_error != Some(true));
844        assert!(msg.content.contains("block tool completed"));
845    }
846
847    #[tokio::test]
848    async fn test_interrupt_default_treated_as_block() {
849        use crate::utils::abort_controller::create_abort_controller_default;
850
851        let tool_calls = vec![ToolCall {
852            id: "1".to_string(),
853            r#type: "function".to_string(),
854            name: "DefaultTool".to_string(),
855            arguments: serde_json::json!({}),
856        }];
857
858        let controller = create_abort_controller_default();
859        controller.abort(None);
860        let abort_signal = controller.signal().clone();
861
862        let tool_context = crate::types::ToolContext {
863            cwd: "/tmp".to_string(),
864            abort_signal,
865        };
866        // No interrupt_behavior set → defaults to Block
867        let tools = vec![create_tool_with_interrupt("DefaultTool", None)];
868
869        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
870            Ok(crate::types::ToolResult {
871                result_type: "tool_result".to_string(),
872                tool_use_id: "1".to_string(),
873                content: "default completed".to_string(),
874                is_error: Some(false),
875                was_persisted: Some(false),
876            })
877        };
878
879        let updates =
880            run_tools_serially(tool_calls, tool_context, tools, executor, None, None).await;
881        assert_eq!(updates.len(), 1);
882        let msg = updates[0].message.as_ref().unwrap();
883        // Default (block) should complete normally despite abort
884        assert!(msg.is_error != Some(true));
885    }
886
887    #[tokio::test]
888    async fn test_interrupt_concurrently_block_ignores_abort() {
889        use crate::utils::abort_controller::create_abort_controller_default;
890
891        let tool_calls = vec![ToolCall {
892            id: "1".to_string(),
893            r#type: "function".to_string(),
894            name: "BlockTool".to_string(),
895            arguments: serde_json::json!({}),
896        }];
897
898        let controller = create_abort_controller_default();
899        controller.abort(None);
900        let abort_signal = controller.signal().clone();
901
902        let tool_context = crate::types::ToolContext {
903            cwd: "/tmp".to_string(),
904            abort_signal,
905        };
906        let tools =
907            vec![create_tool_with_interrupt("BlockTool", Some("block".into()))];
908
909        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
910            Ok(crate::types::ToolResult {
911                result_type: "tool_result".to_string(),
912                tool_use_id: "1".to_string(),
913                content: "concurrent block done".to_string(),
914                is_error: Some(false),
915                was_persisted: Some(false),
916            })
917        };
918
919        let updates =
920            run_tools_concurrently(tool_calls, tool_context, tools, executor, None, None).await;
921        assert_eq!(updates.len(), 1);
922        let msg = updates[0].message.as_ref().unwrap();
923        assert!(msg.content.contains("concurrent block done"));
924    }
925}