lash_core/tool_dispatch/
scheduling.rs1use std::future::Future;
2use std::sync::Arc;
3
4use futures_util::stream::{FuturesUnordered, StreamExt};
5
6use crate::{ProgressSender, ToolCallRecord, ToolScheduling};
7
8use super::context::ToolDispatchContext;
9use super::preparation::dispatch_tool_call;
10
11#[derive(Clone)]
12pub struct ParallelToolCallSpec {
13 pub index: usize,
14 pub tool_name: String,
15 pub args: serde_json::Value,
16}
17
18#[derive(Clone)]
19pub struct ParallelToolCallOutcome {
20 pub index: usize,
21 pub record: ToolCallRecord,
22}
23
24pub(crate) async fn dispatch_parallel_tool_call(
25 context: Arc<ToolDispatchContext<'_>>,
26 spec: ParallelToolCallSpec,
27 progress: Option<ProgressSender>,
28) -> ParallelToolCallOutcome {
29 let outcome = dispatch_tool_call(&context, spec.tool_name, spec.args, progress.as_ref()).await;
30 ParallelToolCallOutcome {
31 index: spec.index,
32 record: outcome.record,
33 }
34}
35
36pub(crate) fn resolve_tool_scheduling(
40 context: &ToolDispatchContext<'_>,
41 tool_name: &str,
42) -> ToolScheduling {
43 context
44 .surface
45 .tools
46 .iter()
47 .find(|def| def.manifest.name == tool_name)
48 .map(|def| def.manifest.scheduling)
49 .unwrap_or_default()
50}
51
52pub(crate) async fn schedule_tool_batch<T, O, IndexOf, SchedulingOf, Run, Fut>(
58 items: Vec<T>,
59 index_of: IndexOf,
60 scheduling_of: SchedulingOf,
61 run: Run,
62) -> Vec<O>
63where
64 T: Send + 'static,
65 O: Send + 'static,
66 IndexOf: Fn(&T) -> usize,
67 SchedulingOf: Fn(&T) -> ToolScheduling,
68 Run: Fn(T) -> Fut,
69 Fut: Future<Output = O> + Send,
70{
71 let mut parallel_items = Vec::new();
72 let mut serial_items = Vec::new();
73 for item in items {
74 let index = index_of(&item);
75 match scheduling_of(&item) {
76 ToolScheduling::Parallel => parallel_items.push((index, item)),
77 ToolScheduling::Serial => serial_items.push((index, item)),
78 }
79 }
80
81 let mut outcomes = Vec::new();
82
83 let mut pending = FuturesUnordered::new();
84 for (index, item) in parallel_items {
85 let future = run(item);
86 pending.push(async move { (index, future.await) });
87 }
88 while let Some(outcome) = pending.next().await {
89 outcomes.push(outcome);
90 }
91
92 serial_items.sort_by_key(|(index, _)| *index);
93 for (index, item) in serial_items {
94 outcomes.push((index, run(item).await));
95 }
96
97 outcomes.sort_by_key(|(index, _)| *index);
98 outcomes.into_iter().map(|(_, outcome)| outcome).collect()
99}
100
101pub async fn dispatch_parallel_tool_calls(
103 context: Arc<ToolDispatchContext<'_>>,
104 specs: Vec<ParallelToolCallSpec>,
105 progress: Option<&ProgressSender>,
106) -> Vec<ParallelToolCallOutcome> {
107 let progress = progress.cloned();
108 schedule_tool_batch(
109 specs,
110 |spec| spec.index,
111 {
112 let context = Arc::clone(&context);
113 move |spec| resolve_tool_scheduling(&context, &spec.tool_name)
114 },
115 move |spec| dispatch_parallel_tool_call(Arc::clone(&context), spec, progress.clone()),
116 )
117 .await
118}