1use std::sync::Arc;
10
11use tokio::sync::{Mutex, Notify, mpsc};
12
13pub use crate::tools::orchestration::ToolMessageUpdate;
14use crate::types::{
15 Message, MessageRole, ToolAnnotations, ToolCall, ToolDefinition, ToolInputSchema, ToolResult,
16};
17
/// Lifecycle of a tool call tracked by [`StreamingToolExecutor`].
///
/// Field-less, so it is `Copy`/`Eq`: statuses are compared and reassigned constantly
/// and never need an explicit clone.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ToolStatus {
    /// Registered but not started.
    Queued,
    /// Claimed and currently running.
    Executing,
    /// Finished or cancelled; an outcome (real or synthetic) has been produced.
    Completed,
    /// Handed out by `get_completed_results` / `get_remaining_results`.
    Yielded,
}
26
27#[derive(Clone)]
29struct TrackedTool {
30 id: String,
31 name: String,
32 status: ToolStatus,
33 is_concurrency_safe: bool,
34 args: serde_json::Value,
35 results: Vec<ToolMessageUpdate>,
37}
38
39type ToolExecutorFn = Arc<
41 dyn Fn(
42 String,
43 serde_json::Value,
44 String,
45 ) -> std::pin::Pin<
46 Box<
47 dyn std::future::Future<Output = Result<ToolResult, crate::AgentError>>
48 + Send
49 + Sync,
50 >,
51 > + Send
52 + Sync,
53>;
54
55struct SharedState {
57 tools: Vec<TrackedTool>,
58 has_errored: bool,
59 discarded: bool,
60}
61
62pub struct StreamingToolExecutor {
64 state: Arc<Mutex<SharedState>>,
65 executor: ToolExecutorFn,
66 tools_def: Vec<ToolDefinition>,
67 sibling_abort: Arc<Notify>,
68 result_tx: mpsc::UnboundedSender<ToolMessageUpdate>,
70 notify: Arc<Notify>,
71}
72
73impl StreamingToolExecutor {
74 pub fn new(
76 executor: ToolExecutorFn,
77 tools_def: Vec<ToolDefinition>,
78 ) -> (Self, mpsc::UnboundedReceiver<ToolMessageUpdate>) {
79 let (tx, rx) = mpsc::unbounded_channel();
80 (
81 Self {
82 state: Arc::new(Mutex::new(SharedState {
83 tools: Vec::new(),
84 has_errored: false,
85 discarded: false,
86 })),
87 executor,
88 tools_def,
89 sibling_abort: Arc::new(Notify::new()),
90 result_tx: tx,
91 notify: Arc::new(Notify::new()),
92 },
93 rx,
94 )
95 }
96
97 pub fn add_tool(&self, name: String, id: String, args: serde_json::Value) {
99 let is_concurrency_safe = self
100 .tools_def
101 .iter()
102 .find(|t| t.name == name)
103 .map(|t| t.is_concurrency_safe(&args))
104 .unwrap_or(false);
105
106 let known = self.tools_def.iter().any(|t| t.name == name);
107 let tool = TrackedTool {
108 id: id.clone(),
109 name: name.clone(),
110 status: ToolStatus::Queued,
111 is_concurrency_safe,
112 args,
113 results: Vec::new(),
114 };
115
116 let state = self.state.clone();
118 let sibling_abort = self.sibling_abort.clone();
119 let executor = self.executor.clone();
120 let tools_def = self.tools_def.clone();
121 let result_tx = self.result_tx.clone();
122 let notify = self.notify.clone();
123
124 tokio::spawn(async move {
125 if !known {
127 let update = create_synthetic_error(&id, "streaming_fallback", &name);
128 let mut guard = state.lock().await;
129 guard.tools.push(TrackedTool {
130 status: ToolStatus::Completed,
131 results: Vec::new(),
132 ..tool
133 });
134 drop(guard);
135 result_tx.send(update).ok();
136 notify.notify_one();
137 return;
138 }
139
140 {
142 let mut guard = state.lock().await;
143 guard.tools.push(tool);
144 }
145
146 process_queue(state, executor, tools_def, result_tx, notify, sibling_abort).await;
148 });
149 }
150
151 pub async fn mark_complete(&self, tool_use_id: &str) {
153 let mut guard = self.state.lock().await;
154 if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == tool_use_id) {
155 tool.status = ToolStatus::Completed;
156 }
157 drop(guard);
158 self.notify.notify_one();
159 }
160
161 pub async fn get_is_concurrency_safe(&self, tool_use_id: &str) -> bool {
163 let guard = self.state.lock().await;
164 guard
165 .tools
166 .iter()
167 .find(|t| t.id == tool_use_id)
168 .map(|t| t.is_concurrency_safe)
169 .unwrap_or(false)
170 }
171
172 pub async fn has_unfinished_tools(&self) -> bool {
174 let guard = self.state.lock().await;
175 guard
176 .tools
177 .iter()
178 .any(|t| t.status != ToolStatus::Completed && t.status != ToolStatus::Yielded)
179 }
180
181 pub async fn has_executing_tools(&self) -> bool {
183 let guard = self.state.lock().await;
184 guard
185 .tools
186 .iter()
187 .any(|t| t.status == ToolStatus::Executing)
188 }
189
190 pub async fn discard(&self) {
192 let to_cancel: Vec<(String, String)> = {
193 let mut guard = self.state.lock().await;
194 guard.discarded = true;
195 guard
196 .tools
197 .iter()
198 .filter(|t| t.status == ToolStatus::Queued || t.status == ToolStatus::Executing)
199 .map(|t| (t.id.clone(), t.name.clone()))
200 .collect()
201 };
202 for (id, name) in to_cancel {
203 let mut guard = self.state.lock().await;
204 if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
205 tool.status = ToolStatus::Completed;
206 }
207 drop(guard);
208 self.result_tx
209 .send(create_synthetic_error(&id, "streaming_fallback", &name))
210 .ok();
211 }
212 self.notify.notify_one();
213 }
214
215 pub async fn trigger_sibling_abort(&self) {
217 let mut guard = self.state.lock().await;
218 guard.has_errored = true;
219 let ids: Vec<(String, String)> = guard
220 .tools
221 .iter()
222 .filter(|t| t.status == ToolStatus::Executing)
223 .map(|t| (t.id.clone(), t.name.clone()))
224 .collect();
225 drop(guard);
226
227 self.sibling_abort.notify_waiters();
228 for (id, name) in ids {
229 let update = create_synthetic_error(&id, "sibling_error", &name);
230 self.result_tx.send(update).ok();
231 }
232 self.notify.notify_one();
233 }
234
235 pub async fn set_tool_result(
237 &self,
238 tool_call_id: String,
239 result: Result<ToolResult, crate::AgentError>,
240 ) {
241 let message = match result {
242 Ok(tool_result) => {
243 let msg = Message {
244 role: MessageRole::Tool,
245 content: tool_result.content,
246 tool_call_id: Some(tool_call_id.clone()),
247 is_error: tool_result.is_error,
248 ..Default::default()
249 };
250 ToolMessageUpdate {
251 message: Some(msg),
252 new_context: None,
253 context_modifier: None,
254 }
255 }
256 Err(e) => {
257 let error_content = format!("<tool_use_error>Error: {}</tool_use_error>", e);
258 let msg = Message {
259 role: MessageRole::Tool,
260 content: error_content,
261 tool_call_id: Some(tool_call_id.clone()),
262 is_error: Some(true),
263 ..Default::default()
264 };
265 ToolMessageUpdate {
266 message: Some(msg),
267 new_context: None,
268 context_modifier: None,
269 }
270 }
271 };
272
273 self.mark_complete(&tool_call_id).await;
275 self.store_result(&tool_call_id, message.clone()).await;
277 self.result_tx.send(message).ok();
279 self.notify.notify_one();
280 }
281
282 async fn store_result(&self, tool_call_id: &str, update: ToolMessageUpdate) {
284 let mut guard = self.state.lock().await;
285 if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == tool_call_id) {
286 tool.results.push(update);
287 }
288 }
289
290 pub async fn get_completed_results(&self) -> Vec<ToolMessageUpdate> {
294 let mut guard = self.state.lock().await;
295 let to_yield: Vec<(usize, String)> = guard
297 .tools
298 .iter()
299 .enumerate()
300 .filter_map(|(i, tool)| {
301 if tool.status == ToolStatus::Yielded {
302 return None;
303 }
304 if tool.status == ToolStatus::Executing && !tool.is_concurrency_safe {
305 return None; }
307 if tool.status == ToolStatus::Completed && !tool.results.is_empty() {
308 return Some((i, tool.id.clone()));
309 }
310 None
311 })
312 .collect();
313
314 let mut results = Vec::new();
316 for (i, _id) in to_yield {
317 if let Some(tool) = guard.tools.get_mut(i) {
318 tool.status = ToolStatus::Yielded;
319 results.append(&mut tool.results);
320 }
321 }
322
323 results
324 }
325
326 pub async fn get_remaining_results(
328 &self,
329 result_rx: &mut mpsc::UnboundedReceiver<ToolMessageUpdate>,
330 ) -> Vec<ToolMessageUpdate> {
331 let mut all_results = Vec::new();
332
333 while let Ok(update) = result_rx.try_recv() {
335 all_results.push(update);
336 }
337
338 while self.has_unfinished_tools().await {
340 self.notify.notified().await;
341
342 while let Ok(update) = result_rx.try_recv() {
344 all_results.push(update);
345 }
346
347 if self.has_executing_tools().await {
349 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
350 }
351 }
352
353 while let Ok(update) = result_rx.try_recv() {
355 all_results.push(update);
356 }
357
358 {
360 let mut guard = self.state.lock().await;
361 for tool in guard.tools.iter_mut() {
362 if tool.status != ToolStatus::Yielded {
363 tool.status = ToolStatus::Yielded;
364 }
365 }
366 }
367
368 all_results
369 }
370
371 pub async fn discard_sync(&self) {
373 let mut guard = self.state.lock().await;
374 guard.discarded = true;
375 let to_cancel: Vec<(String, String)> = guard
376 .tools
377 .iter()
378 .filter(|t| t.status == ToolStatus::Queued || t.status == ToolStatus::Executing)
379 .map(|t| (t.id.clone(), t.name.clone()))
380 .collect();
381 drop(guard);
382
383 for (id, name) in to_cancel {
384 let mut guard = self.state.lock().await;
385 if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
386 tool.status = ToolStatus::Completed;
387 }
388 drop(guard);
389 self.result_tx
390 .send(create_synthetic_error(&id, "streaming_fallback", &name))
391 .ok();
392 }
393 self.notify.notify_one();
394 }
395}
396
397async fn process_queue(
399 state: Arc<Mutex<SharedState>>,
400 executor: ToolExecutorFn,
401 _tools_def: Vec<ToolDefinition>,
402 result_tx: mpsc::UnboundedSender<ToolMessageUpdate>,
403 notify: Arc<Notify>,
404 sibling_abort: Arc<Notify>,
405) {
406 let snapshot: Vec<(String, String, serde_json::Value, bool, bool, bool)> = {
408 let guard = state.lock().await;
409 guard
410 .tools
411 .iter()
412 .map(|t| {
413 let is_queued = t.status == ToolStatus::Queued;
414 let is_executing = t.status == ToolStatus::Executing;
415 (
416 t.id.clone(),
417 t.name.clone(),
418 t.args.clone(),
419 t.is_concurrency_safe,
420 is_queued,
421 is_executing,
422 )
423 })
424 .collect()
425 };
426
427 let mut can_run: Vec<(String, String, serde_json::Value, bool)> = Vec::new();
429 for (id, name, args, is_safe, is_queued, is_executing) in &snapshot {
430 if !is_queued {
431 continue;
432 }
433 let blocked = snapshot
434 .iter()
435 .any(|(_, _, _, other_safe, _, other_exec)| *other_exec && !*other_safe);
436 if blocked && !*is_safe {
437 continue;
439 }
440 can_run.push((id.clone(), name.clone(), args.clone(), *is_safe));
441 }
442
443 for (id, name, args, is_safe) in can_run {
444 execute_tool(
445 state.clone(),
446 id.clone(),
447 name.clone(),
448 args,
449 is_safe,
450 executor.clone(),
451 sibling_abort.clone(),
452 result_tx.clone(),
453 notify.clone(),
454 )
455 .await;
456 if !is_safe {
457 break;
458 }
459 }
460
461 notify.notify_one();
462}
463
464async fn execute_tool(
466 state: Arc<Mutex<SharedState>>,
467 id: String,
468 name: String,
469 args: serde_json::Value,
470 _is_concurrency_safe: bool,
471 executor: ToolExecutorFn,
472 sibling_abort: Arc<Notify>,
473 result_tx: mpsc::UnboundedSender<ToolMessageUpdate>,
474 notify: Arc<Notify>,
475) {
476 let guard = state.lock().await;
478 if guard.discarded {
479 drop(guard);
480 result_tx
481 .send(create_synthetic_error(&id, "streaming_fallback", &name))
482 .ok();
483 return;
484 }
485 if guard.has_errored {
486 drop(guard);
487 result_tx
488 .send(create_synthetic_error(&id, "sibling_error", &name))
489 .ok();
490 return;
491 }
492 drop(guard);
493
494 {
496 let mut guard = state.lock().await;
497 if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
498 tool.status = ToolStatus::Executing;
499 }
500 }
501
502 {
504 let sab = sibling_abort.clone();
505 sab.notified().await;
506 }
507
508 let result = executor(name.clone(), args.clone(), id.clone()).await;
510
511 {
513 let mut guard = state.lock().await;
514 if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
515 tool.status = ToolStatus::Completed;
516 }
517 if let Ok(tool_result) = &result {
519 if tool_result.is_error == Some(true) && name == "Bash" {
520 guard.has_errored = true;
521 let siblings: Vec<(String, String)> = guard
522 .tools
523 .iter()
524 .filter(|t| t.status == ToolStatus::Executing)
525 .map(|t| (t.id.clone(), t.name.clone()))
526 .collect();
527 drop(guard);
528 sibling_abort.notify_waiters();
529 for (sid, sname) in siblings {
530 result_tx
531 .send(create_synthetic_error(&sid, "sibling_error", &sname))
532 .ok();
533 }
534 notify.notify_one();
535 return;
536 }
537 }
538 drop(guard);
539 }
540
541 let message = match result {
543 Ok(tool_result) => ToolMessageUpdate {
544 message: Some(Message {
545 role: MessageRole::Tool,
546 content: tool_result.content,
547 tool_call_id: Some(id.clone()),
548 is_error: tool_result.is_error,
549 ..Default::default()
550 }),
551 new_context: None,
552 context_modifier: None,
553 },
554 Err(e) => ToolMessageUpdate {
555 message: Some(Message {
556 role: MessageRole::Tool,
557 content: format!("<tool_use_error>Error: {}</tool_use_error>", e),
558 tool_call_id: Some(id.clone()),
559 is_error: Some(true),
560 ..Default::default()
561 }),
562 new_context: None,
563 context_modifier: None,
564 },
565 };
566 result_tx.send(message.clone()).ok();
567 {
569 let mut guard = state.lock().await;
570 if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == id) {
571 tool.results.push(message);
572 }
573 }
574 notify.notify_one();
575}
576
577fn create_synthetic_error(reason: &str, tool_call_id: &str, tool_name: &str) -> ToolMessageUpdate {
579 let message = match reason {
580 "streaming_fallback" => Message {
581 role: MessageRole::User,
582 content: format!(
583 "Streaming fallback - tool '{}' execution discarded",
584 tool_name
585 ),
586 ..Default::default()
587 },
588 "sibling_error" => Message {
589 role: MessageRole::User,
590 content: format!("Cancelled: parallel tool call '{}' errored", tool_name),
591 ..Default::default()
592 },
593 "user_interrupted" => Message {
594 role: MessageRole::User,
595 content: "User rejected tool use".to_string(),
596 ..Default::default()
597 },
598 _ => Message {
599 role: MessageRole::User,
600 content: format!("Tool '{}' error", tool_name),
601 ..Default::default()
602 },
603 };
604
605 ToolMessageUpdate {
606 message: Some(message),
607 new_context: None,
608 context_modifier: None,
609 }
610}
611
612pub fn get_tool_concurrency_info(
614 tool_calls: &[ToolCall],
615 tools: &[ToolDefinition],
616) -> Vec<(String, String, bool, serde_json::Value)> {
617 tool_calls
618 .iter()
619 .map(|tc| {
620 let is_safe = tools
621 .iter()
622 .find(|t| t.name == tc.name)
623 .map(|t| t.is_concurrency_safe(&tc.arguments))
624 .unwrap_or(false);
625 (
626 tc.id.clone(),
627 tc.name.clone(),
628 is_safe,
629 tc.arguments.clone(),
630 )
631 })
632 .collect()
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    /// Executor stub whose every invocation succeeds with the given `ToolResult` fields.
    fn ok_executor(result_type: &str, tool_use_id: &str, content: &str) -> ToolExecutorFn {
        let result_type = result_type.to_string();
        let tool_use_id = tool_use_id.to_string();
        let content = content.to_string();
        Arc::new(move |_name, _args, _id| {
            let result_type = result_type.clone();
            let tool_use_id = tool_use_id.clone();
            let content = content.clone();
            Box::pin(async move {
                Ok(ToolResult {
                    result_type,
                    tool_use_id,
                    content,
                    is_error: Some(false),
                    was_persisted: None,
                })
            })
        })
    }

    /// Registers a tool call with empty JSON arguments.
    fn add(exe: &StreamingToolExecutor, name: &str, id: &str) {
        exe.add_tool(name.to_string(), id.to_string(), serde_json::json!({}));
    }

    #[tokio::test]
    async fn test_create_executor() {
        let (exe, _rx) = StreamingToolExecutor::new(ok_executor("tool_result", "1", "ok"), vec![]);
        add(&exe, "Bash", "tool1");
        sleep(Duration::from_millis(50)).await;
        let tracked = exe.state.lock().await.tools.len();
        assert_eq!(tracked, 1);
    }

    #[tokio::test]
    async fn test_mark_complete() {
        let (exe, _rx) = StreamingToolExecutor::new(ok_executor("t", "1", "ok"), vec![]);
        add(&exe, "Bash", "tool1");
        exe.mark_complete("tool1").await;
        sleep(Duration::from_millis(50)).await;
        let guard = exe.state.lock().await;
        assert_eq!(guard.tools[0].status, ToolStatus::Completed);
    }

    #[tokio::test]
    async fn test_discard() {
        let (exe, mut rx) = StreamingToolExecutor::new(ok_executor("t", "1", "ok"), vec![]);
        add(&exe, "Bash", "tool1");
        add(&exe, "Glob", "tool2");
        sleep(Duration::from_millis(50)).await;

        exe.discard().await;

        let received = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert!(received >= 1);
    }

    #[tokio::test]
    async fn test_trigger_sibling_abort() {
        let (exe, mut rx) = StreamingToolExecutor::new(ok_executor("t", "1", "ok"), vec![]);
        add(&exe, "Bash", "tool1");
        add(&exe, "Glob", "tool2");
        sleep(Duration::from_millis(50)).await;

        // Pretend both tools are running so the abort has siblings to cancel.
        {
            let mut guard = exe.state.lock().await;
            for target in ["tool1", "tool2"] {
                if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == target) {
                    tool.status = ToolStatus::Executing;
                }
            }
        }

        exe.trigger_sibling_abort().await;

        let guard = exe.state.lock().await;
        assert!(guard.has_errored);

        let received = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert!(received >= 1);
    }

    #[tokio::test]
    async fn test_set_tool_result() {
        let (exe, mut rx) =
            StreamingToolExecutor::new(ok_executor("tool_result", "1", "command output"), vec![]);
        add(&exe, "Bash", "tool1");

        let result = ToolResult {
            result_type: "tool_result".to_string(),
            tool_use_id: "tool1".to_string(),
            content: "command output".to_string(),
            is_error: Some(false),
            was_persisted: None,
        };
        exe.set_tool_result("tool1".to_string(), Ok(result)).await;

        let update = rx.recv().await;
        assert!(update.is_some());
        let message = update.unwrap().message.unwrap();
        assert_eq!(message.content, "command output");
    }

    #[test]
    fn test_get_tool_concurrency_info() {
        let bash = ToolDefinition {
            name: "Bash".to_string(),
            description: "Execute commands".to_string(),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: Some(ToolAnnotations {
                concurrency_safe: Some(true),
                ..Default::default()
            }),
            should_defer: None,
            always_load: None,
            is_mcp: None,
            search_hint: None,
            aliases: None,
            user_facing_name: None,
            interrupt_behavior: None,
        };
        let call = ToolCall {
            id: "1".to_string(),
            r#type: "function".to_string(),
            name: "Bash".to_string(),
            arguments: serde_json::json!({}),
        };

        let info = get_tool_concurrency_info(&[call], &[bash]);
        assert_eq!(info.len(), 1);
        let (_, _, is_safe, _) = &info[0];
        assert!(*is_safe);
    }
}