// ai_agent/streaming_tool_executor.rs

// Source: /data/home/swei/claudecode/openclaudecode/src/services/tools/StreamingToolExecutor.ts
//! Streaming tool executor that starts executing tools as they stream in from the API.
//!
//! Translated from TypeScript StreamingToolExecutor.ts.
//! - Concurrency-safe tools can execute in parallel with other concurrency-safe tools
//! - Non-concurrency-safe tools must execute alone (exclusive access)
//! - Results are buffered and emitted in the order tools were registered

9use std::sync::Arc;
10
11use tokio::sync::{Mutex, Notify, mpsc};
12
13pub use crate::tools::orchestration::ToolMessageUpdate;
14use crate::types::{
15    Message, MessageRole, ToolAnnotations, ToolCall, ToolDefinition, ToolInputSchema, ToolResult,
16};
17
/// Lifecycle stage of a tracked tool call.
///
/// Normal flow is `Queued` -> `Executing` -> `Completed` -> `Yielded`. A tool can also move
/// directly from `Queued` or `Executing` to `Completed` when it is cancelled or resolved
/// externally (for example by `discard` or `mark_complete`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ToolStatus {
    /// Registered but not started yet.
    Queued,
    /// Claimed for execution / currently running.
    Executing,
    /// Finished or cancelled; buffered results may not have been handed out yet.
    Completed,
    /// Results have been handed out to the consumer.
    Yielded,
}

27/// A tool being tracked for execution.
28#[derive(Clone)]
29struct TrackedTool {
30    id: String,
31    name: String,
32    status: ToolStatus,
33    is_concurrency_safe: bool,
34    args: serde_json::Value,
35    /// Results accumulated from this tool (for get_completed_results)
36    results: Vec<ToolMessageUpdate>,
37}
38
39/// A boxed executor function that takes tool name, args, and call ID.
40type ToolExecutorFn = Arc<
41    dyn Fn(
42            String,
43            serde_json::Value,
44            String,
45        ) -> std::pin::Pin<
46            Box<
47                dyn std::future::Future<Output = Result<ToolResult, crate::AgentError>>
48                    + Send
49                    + Sync,
50            >,
51        > + Send
52        + Sync,
53>;
54
55/// Shared state for the streaming executor.
56struct SharedState {
57    tools: Vec<TrackedTool>,
58    has_errored: bool,
59    discarded: bool,
60}
61
62/// Executes tools as they stream in with concurrency control.
63pub struct StreamingToolExecutor {
64    state: Arc<Mutex<SharedState>>,
65    executor: ToolExecutorFn,
66    tools_def: Vec<ToolDefinition>,
67    sibling_abort: Arc<Notify>,
68    /// Channel for delivering results to the consumer.
69    result_tx: mpsc::UnboundedSender<ToolMessageUpdate>,
70    notify: Arc<Notify>,
71}
72
73impl StreamingToolExecutor {
74    /// Create a new streaming tool executor.
75    pub fn new(
76        executor: ToolExecutorFn,
77        tools_def: Vec<ToolDefinition>,
78    ) -> (Self, mpsc::UnboundedReceiver<ToolMessageUpdate>) {
79        let (tx, rx) = mpsc::unbounded_channel();
80        (
81            Self {
82                state: Arc::new(Mutex::new(SharedState {
83                    tools: Vec::new(),
84                    has_errored: false,
85                    discarded: false,
86                })),
87                executor,
88                tools_def,
89                sibling_abort: Arc::new(Notify::new()),
90                result_tx: tx,
91                notify: Arc::new(Notify::new()),
92            },
93            rx,
94        )
95    }
96
97    /// Add a tool to the execution queue. Will start executing immediately if conditions allow.
98    pub fn add_tool(&self, name: String, id: String, args: serde_json::Value) {
99        let is_concurrency_safe = self
100            .tools_def
101            .iter()
102            .find(|t| t.name == name)
103            .map(|t| t.is_concurrency_safe(&args))
104            .unwrap_or(false);
105
106        let known = self.tools_def.iter().any(|t| t.name == name);
107        let tool = TrackedTool {
108            id: id.clone(),
109            name: name.clone(),
110            status: ToolStatus::Queued,
111            is_concurrency_safe,
112            args,
113            results: Vec::new(),
114        };
115
116        // Push to state and process queue in background
117        let state = self.state.clone();
118        let sibling_abort = self.sibling_abort.clone();
119        let executor = self.executor.clone();
120        let tools_def = self.tools_def.clone();
121        let result_tx = self.result_tx.clone();
122        let notify = self.notify.clone();
123
124        tokio::spawn(async move {
125            // Check for unknown tool
126            if !known {
127                let update = create_synthetic_error(&id, "streaming_fallback", &name);
128                let mut guard = state.lock().await;
129                guard.tools.push(TrackedTool {
130                    status: ToolStatus::Completed,
131                    results: Vec::new(),
132                    ..tool
133                });
134                drop(guard);
135                result_tx.send(update).ok();
136                notify.notify_one();
137                return;
138            }
139
140            // Update state
141            {
142                let mut guard = state.lock().await;
143                guard.tools.push(tool);
144            }
145
146            // Process queue
147            process_queue(state, executor, tools_def, result_tx, notify, sibling_abort).await;
148        });
149    }
150
151    /// Mark a tool use as complete.
152    pub async fn mark_complete(&self, tool_use_id: &str) {
153        let mut guard = self.state.lock().await;
154        if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == tool_use_id) {
155            tool.status = ToolStatus::Completed;
156        }
157        drop(guard);
158        self.notify.notify_one();
159    }
160
161    /// Get a tool's concurrency safety flag.
162    pub async fn get_is_concurrency_safe(&self, tool_use_id: &str) -> bool {
163        let guard = self.state.lock().await;
164        guard
165            .tools
166            .iter()
167            .find(|t| t.id == tool_use_id)
168            .map(|t| t.is_concurrency_safe)
169            .unwrap_or(false)
170    }
171
172    /// Check if there are unfinished tools.
173    pub async fn has_unfinished_tools(&self) -> bool {
174        let guard = self.state.lock().await;
175        guard
176            .tools
177            .iter()
178            .any(|t| t.status != ToolStatus::Completed && t.status != ToolStatus::Yielded)
179    }
180
181    /// Check if any tools are currently executing.
182    pub async fn has_executing_tools(&self) -> bool {
183        let guard = self.state.lock().await;
184        guard
185            .tools
186            .iter()
187            .any(|t| t.status == ToolStatus::Executing)
188    }
189
190    /// Discard all pending and in-progress tools.
191    pub async fn discard(&self) {
192        let to_cancel: Vec<(String, String)> = {
193            let mut guard = self.state.lock().await;
194            guard.discarded = true;
195            guard
196                .tools
197                .iter()
198                .filter(|t| t.status == ToolStatus::Queued || t.status == ToolStatus::Executing)
199                .map(|t| (t.id.clone(), t.name.clone()))
200                .collect()
201        };
202        for (id, name) in to_cancel {
203            let mut guard = self.state.lock().await;
204            if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
205                tool.status = ToolStatus::Completed;
206            }
207            drop(guard);
208            self.result_tx
209                .send(create_synthetic_error(&id, "streaming_fallback", &name))
210                .ok();
211        }
212        self.notify.notify_one();
213    }
214
215    /// Trigger sibling abort (called when Bash tool errors).
216    pub async fn trigger_sibling_abort(&self) {
217        let mut guard = self.state.lock().await;
218        guard.has_errored = true;
219        let ids: Vec<(String, String)> = guard
220            .tools
221            .iter()
222            .filter(|t| t.status == ToolStatus::Executing)
223            .map(|t| (t.id.clone(), t.name.clone()))
224            .collect();
225        drop(guard);
226
227        self.sibling_abort.notify_waiters();
228        for (id, name) in ids {
229            let update = create_synthetic_error(&id, "sibling_error", &name);
230            self.result_tx.send(update).ok();
231        }
232        self.notify.notify_one();
233    }
234
235    /// Set tool result from external execution.
236    pub async fn set_tool_result(
237        &self,
238        tool_call_id: String,
239        result: Result<ToolResult, crate::AgentError>,
240    ) {
241        let message = match result {
242            Ok(tool_result) => {
243                let msg = Message {
244                    role: MessageRole::Tool,
245                    content: tool_result.content,
246                    tool_call_id: Some(tool_call_id.clone()),
247                    is_error: tool_result.is_error,
248                    ..Default::default()
249                };
250                ToolMessageUpdate {
251                    message: Some(msg),
252                    new_context: None,
253                    context_modifier: None,
254                }
255            }
256            Err(e) => {
257                let error_content = format!("<tool_use_error>Error: {}</tool_use_error>", e);
258                let msg = Message {
259                    role: MessageRole::Tool,
260                    content: error_content,
261                    tool_call_id: Some(tool_call_id.clone()),
262                    is_error: Some(true),
263                    ..Default::default()
264                };
265                ToolMessageUpdate {
266                    message: Some(msg),
267                    new_context: None,
268                    context_modifier: None,
269                }
270            }
271        };
272
273        // Mark complete (adds to state if missing)
274        self.mark_complete(&tool_call_id).await;
275        // Store result for get_completed_results
276        self.store_result(&tool_call_id, message.clone()).await;
277        // Always send the result to the channel
278        self.result_tx.send(message).ok();
279        self.notify.notify_one();
280    }
281
282    /// Store a result in the tracked tool for get_completed_results iteration.
283    async fn store_result(&self, tool_call_id: &str, update: ToolMessageUpdate) {
284        let mut guard = self.state.lock().await;
285        if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == tool_call_id) {
286            tool.results.push(update);
287        }
288    }
289
290    /// Get completed results that haven't been yielded yet.
291    /// Yields progress messages immediately, then results in order.
292    /// Stops yielding when encountering a non-concurrency-safe executing tool.
293    pub async fn get_completed_results(&self) -> Vec<ToolMessageUpdate> {
294        let mut guard = self.state.lock().await;
295        // Phase 1: collect indices of tools to yield (read-only)
296        let to_yield: Vec<(usize, String)> = guard
297            .tools
298            .iter()
299            .enumerate()
300            .filter_map(|(i, tool)| {
301                if tool.status == ToolStatus::Yielded {
302                    return None;
303                }
304                if tool.status == ToolStatus::Executing && !tool.is_concurrency_safe {
305                    return None; // Break here
306                }
307                if tool.status == ToolStatus::Completed && !tool.results.is_empty() {
308                    return Some((i, tool.id.clone()));
309                }
310                None
311            })
312            .collect();
313
314        // Phase 2: mark as yielded and collect results
315        let mut results = Vec::new();
316        for (i, _id) in to_yield {
317            if let Some(tool) = guard.tools.get_mut(i) {
318                tool.status = ToolStatus::Yielded;
319                results.append(&mut tool.results);
320            }
321        }
322
323        results
324    }
325
326    /// Wait for remaining tools and collect their results.
327    pub async fn get_remaining_results(
328        &self,
329        result_rx: &mut mpsc::UnboundedReceiver<ToolMessageUpdate>,
330    ) -> Vec<ToolMessageUpdate> {
331        let mut all_results = Vec::new();
332
333        // Collect any results already available
334        while let Ok(update) = result_rx.try_recv() {
335            all_results.push(update);
336        }
337
338        // Wait for all tools to complete
339        while self.has_unfinished_tools().await {
340            self.notify.notified().await;
341
342            // Collect results from channel
343            while let Ok(update) = result_rx.try_recv() {
344                all_results.push(update);
345            }
346
347            // Small delay to avoid busy loop
348            if self.has_executing_tools().await {
349                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
350            }
351        }
352
353        // Final collection
354        while let Ok(update) = result_rx.try_recv() {
355            all_results.push(update);
356        }
357
358        // Mark all remaining tools as yielded
359        {
360            let mut guard = self.state.lock().await;
361            for tool in guard.tools.iter_mut() {
362                if tool.status != ToolStatus::Yielded {
363                    tool.status = ToolStatus::Yielded;
364                }
365            }
366        }
367
368        all_results
369    }
370
371    /// Discard all pending and in-progress tools.
372    pub async fn discard_sync(&self) {
373        let mut guard = self.state.lock().await;
374        guard.discarded = true;
375        let to_cancel: Vec<(String, String)> = guard
376            .tools
377            .iter()
378            .filter(|t| t.status == ToolStatus::Queued || t.status == ToolStatus::Executing)
379            .map(|t| (t.id.clone(), t.name.clone()))
380            .collect();
381        drop(guard);
382
383        for (id, name) in to_cancel {
384            let mut guard = self.state.lock().await;
385            if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
386                tool.status = ToolStatus::Completed;
387            }
388            drop(guard);
389            self.result_tx
390                .send(create_synthetic_error(&id, "streaming_fallback", &name))
391                .ok();
392        }
393        self.notify.notify_one();
394    }
395}
396
397/// Process the queue, starting execution for queued tools if allowed.
398async fn process_queue(
399    state: Arc<Mutex<SharedState>>,
400    executor: ToolExecutorFn,
401    _tools_def: Vec<ToolDefinition>,
402    result_tx: mpsc::UnboundedSender<ToolMessageUpdate>,
403    notify: Arc<Notify>,
404    sibling_abort: Arc<Notify>,
405) {
406    // Snapshot state
407    let snapshot: Vec<(String, String, serde_json::Value, bool, bool, bool)> = {
408        let guard = state.lock().await;
409        guard
410            .tools
411            .iter()
412            .map(|t| {
413                let is_queued = t.status == ToolStatus::Queued;
414                let is_executing = t.status == ToolStatus::Executing;
415                (
416                    t.id.clone(),
417                    t.name.clone(),
418                    t.args.clone(),
419                    t.is_concurrency_safe,
420                    is_queued,
421                    is_executing,
422                )
423            })
424            .collect()
425    };
426
427    // Find tools that can run
428    let mut can_run: Vec<(String, String, serde_json::Value, bool)> = Vec::new();
429    for (id, name, args, is_safe, is_queued, is_executing) in &snapshot {
430        if !is_queued {
431            continue;
432        }
433        let blocked = snapshot
434            .iter()
435            .any(|(_, _, _, other_safe, _, other_exec)| *other_exec && !*other_safe);
436        if blocked && !*is_safe {
437            // Non-safe blocked by another executing — skip (will be picked by the executing one)
438            continue;
439        }
440        can_run.push((id.clone(), name.clone(), args.clone(), *is_safe));
441    }
442
443    for (id, name, args, is_safe) in can_run {
444        execute_tool(
445            state.clone(),
446            id.clone(),
447            name.clone(),
448            args,
449            is_safe,
450            executor.clone(),
451            sibling_abort.clone(),
452            result_tx.clone(),
453            notify.clone(),
454        )
455        .await;
456        if !is_safe {
457            break;
458        }
459    }
460
461    notify.notify_one();
462}
463
464/// Execute a single tool in the background.
465async fn execute_tool(
466    state: Arc<Mutex<SharedState>>,
467    id: String,
468    name: String,
469    args: serde_json::Value,
470    _is_concurrency_safe: bool,
471    executor: ToolExecutorFn,
472    sibling_abort: Arc<Notify>,
473    result_tx: mpsc::UnboundedSender<ToolMessageUpdate>,
474    notify: Arc<Notify>,
475) {
476    // Pre-flight checks
477    let guard = state.lock().await;
478    if guard.discarded {
479        drop(guard);
480        result_tx
481            .send(create_synthetic_error(&id, "streaming_fallback", &name))
482            .ok();
483        return;
484    }
485    if guard.has_errored {
486        drop(guard);
487        result_tx
488            .send(create_synthetic_error(&id, "sibling_error", &name))
489            .ok();
490        return;
491    }
492    drop(guard);
493
494    // Mark as executing
495    {
496        let mut guard = state.lock().await;
497        if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
498            tool.status = ToolStatus::Executing;
499        }
500    }
501
502    // Wait for sibling abort before starting
503    {
504        let sab = sibling_abort.clone();
505        sab.notified().await;
506    }
507
508    // Execute the tool
509    let result = executor(name.clone(), args.clone(), id.clone()).await;
510
511    // Mark complete and send result
512    {
513        let mut guard = state.lock().await;
514        if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
515            tool.status = ToolStatus::Completed;
516        }
517        // Check for Bash error cascade
518        if let Ok(tool_result) = &result {
519            if tool_result.is_error == Some(true) && name == "Bash" {
520                guard.has_errored = true;
521                let siblings: Vec<(String, String)> = guard
522                    .tools
523                    .iter()
524                    .filter(|t| t.status == ToolStatus::Executing)
525                    .map(|t| (t.id.clone(), t.name.clone()))
526                    .collect();
527                drop(guard);
528                sibling_abort.notify_waiters();
529                for (sid, sname) in siblings {
530                    result_tx
531                        .send(create_synthetic_error(&sid, "sibling_error", &sname))
532                        .ok();
533                }
534                notify.notify_one();
535                return;
536            }
537        }
538        drop(guard);
539    }
540
541    // Send result
542    let message = match result {
543        Ok(tool_result) => ToolMessageUpdate {
544            message: Some(Message {
545                role: MessageRole::Tool,
546                content: tool_result.content,
547                tool_call_id: Some(id.clone()),
548                is_error: tool_result.is_error,
549                ..Default::default()
550            }),
551            new_context: None,
552            context_modifier: None,
553        },
554        Err(e) => ToolMessageUpdate {
555            message: Some(Message {
556                role: MessageRole::Tool,
557                content: format!("<tool_use_error>Error: {}</tool_use_error>", e),
558                tool_call_id: Some(id.clone()),
559                is_error: Some(true),
560                ..Default::default()
561            }),
562            new_context: None,
563            context_modifier: None,
564        },
565    };
566    result_tx.send(message.clone()).ok();
567    // Also store in state for get_completed_results
568    {
569        let mut guard = state.lock().await;
570        if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
571            tool.results.push(message);
572        }
573    }
574    notify.notify_one();
575}
576
577/// Create a synthetic error message for cancelled/aborted tools.
578fn create_synthetic_error(reason: &str, tool_call_id: &str, tool_name: &str) -> ToolMessageUpdate {
579    let message = match reason {
580        "streaming_fallback" => Message {
581            role: MessageRole::User,
582            content: format!(
583                "Streaming fallback - tool '{}' execution discarded",
584                tool_name
585            ),
586            ..Default::default()
587        },
588        "sibling_error" => Message {
589            role: MessageRole::User,
590            content: format!("Cancelled: parallel tool call '{}' errored", tool_name),
591            ..Default::default()
592        },
593        "user_interrupted" => Message {
594            role: MessageRole::User,
595            content: "User rejected tool use".to_string(),
596            ..Default::default()
597        },
598        _ => Message {
599            role: MessageRole::User,
600            content: format!("Tool '{}' error", tool_name),
601            ..Default::default()
602        },
603    };
604
605    ToolMessageUpdate {
606        message: Some(message),
607        new_context: None,
608        context_modifier: None,
609    }
610}
611
612/// Get tool concurrency info for the streaming executor.
613pub fn get_tool_concurrency_info(
614    tool_calls: &[ToolCall],
615    tools: &[ToolDefinition],
616) -> Vec<(String, String, bool, serde_json::Value)> {
617    tool_calls
618        .iter()
619        .map(|tc| {
620            let is_safe = tools
621                .iter()
622                .find(|t| t.name == tc.name)
623                .map(|t| t.is_concurrency_safe(&tc.arguments))
624                .unwrap_or(false);
625            (
626                tc.id.clone(),
627                tc.name.clone(),
628                is_safe,
629                tc.arguments.clone(),
630            )
631        })
632        .collect()
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    /// Build an executor whose every call succeeds with the given `result_type` and `content`.
    ///
    /// Every test below passes an empty `tools_def`, so `add_tool` treats each tool as unknown
    /// and never invokes this callback.
    fn executor_returning(result_type: &'static str, content: &'static str) -> ToolExecutorFn {
        Arc::new(move |_name, _args, _id| {
            Box::pin(async move {
                Ok(ToolResult {
                    result_type: result_type.to_string(),
                    tool_use_id: "1".to_string(),
                    content: content.to_string(),
                    is_error: Some(false),
                    was_persisted: None,
                })
            })
        })
    }

    #[tokio::test]
    async fn test_create_executor() {
        let (exe, _rx) =
            StreamingToolExecutor::new(executor_returning("tool_result", "ok"), vec![]);
        exe.add_tool(
            "Bash".to_string(),
            "tool1".to_string(),
            serde_json::json!({}),
        );
        // Registration runs on a spawned task; give it a moment.
        sleep(Duration::from_millis(50)).await;
        assert_eq!(exe.state.lock().await.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_mark_complete() {
        let (exe, _rx) = StreamingToolExecutor::new(executor_returning("t", "ok"), vec![]);
        exe.add_tool(
            "Bash".to_string(),
            "tool1".to_string(),
            serde_json::json!({}),
        );
        exe.mark_complete("tool1").await;
        sleep(Duration::from_millis(50)).await;
        let guard = exe.state.lock().await;
        assert_eq!(guard.tools[0].status, ToolStatus::Completed);
    }

    #[tokio::test]
    async fn test_discard() {
        let (exe, mut rx) = StreamingToolExecutor::new(executor_returning("t", "ok"), vec![]);
        for (name, id) in [("Bash", "tool1"), ("Glob", "tool2")] {
            exe.add_tool(name.to_string(), id.to_string(), serde_json::json!({}));
        }
        // Let the spawned registration tasks run.
        sleep(Duration::from_millis(50)).await;

        exe.discard().await;

        let delivered = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert!(delivered >= 1);
    }

    #[tokio::test]
    async fn test_trigger_sibling_abort() {
        let (exe, mut rx) = StreamingToolExecutor::new(executor_returning("t", "ok"), vec![]);
        for (name, id) in [("Bash", "tool1"), ("Glob", "tool2")] {
            exe.add_tool(name.to_string(), id.to_string(), serde_json::json!({}));
        }
        sleep(Duration::from_millis(50)).await;

        // Force both tools into the executing state by hand.
        {
            let mut guard = exe.state.lock().await;
            for tool in guard
                .tools
                .iter_mut()
                .filter(|t| t.id == "tool1" || t.id == "tool2")
            {
                tool.status = ToolStatus::Executing;
            }
        }

        exe.trigger_sibling_abort().await;

        let guard = exe.state.lock().await;
        assert!(guard.has_errored);

        let delivered = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert!(delivered >= 1);
    }

    #[tokio::test]
    async fn test_set_tool_result() {
        let (exe, mut rx) =
            StreamingToolExecutor::new(executor_returning("tool_result", "command output"), vec![]);
        exe.add_tool(
            "Bash".to_string(),
            "tool1".to_string(),
            serde_json::json!({}),
        );

        // NOTE(review): this relies on the spawned registration task not having run yet on the
        // current-thread test runtime, so the first message received is the one sent here.
        exe.set_tool_result(
            "tool1".to_string(),
            Ok(ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "tool1".to_string(),
                content: "command output".to_string(),
                is_error: Some(false),
                was_persisted: None,
            }),
        )
        .await;

        let update = rx.recv().await;
        assert!(update.is_some());
        let msg = update.unwrap().message.unwrap();
        assert_eq!(msg.content, "command output");
    }

    #[test]
    fn test_get_tool_concurrency_info() {
        let bash = ToolDefinition {
            name: "Bash".to_string(),
            description: "Execute commands".to_string(),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: Some(ToolAnnotations {
                concurrency_safe: Some(true),
                ..Default::default()
            }),
            should_defer: None,
            always_load: None,
            is_mcp: None,
            search_hint: None,
            aliases: None,
            user_facing_name: None,
            interrupt_behavior: None,
        };
        let call = ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "Bash".to_string(),
            arguments: serde_json::json!({}),
        };

        let info = get_tool_concurrency_info(&[call], &[bash]);
        assert_eq!(info.len(), 1);
        assert!(info[0].2);
    }
}