Skip to main content

lash_core/tool_dispatch/scheduling.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use futures_util::stream::{FuturesUnordered, StreamExt};
5
6use crate::{ProgressSender, ToolCallRecord, ToolScheduling};
7
8use super::context::ToolDispatchContext;
9use super::preparation::dispatch_tool_call;
10
11#[derive(Clone)]
12pub struct ParallelToolCallSpec {
13    pub index: usize,
14    pub tool_name: String,
15    pub args: serde_json::Value,
16}
17
18#[derive(Clone)]
19pub struct ParallelToolCallOutcome {
20    pub index: usize,
21    pub record: ToolCallRecord,
22}
23
24pub(crate) async fn dispatch_parallel_tool_call(
25    context: Arc<ToolDispatchContext<'_>>,
26    spec: ParallelToolCallSpec,
27    progress: Option<ProgressSender>,
28) -> ParallelToolCallOutcome {
29    let outcome = dispatch_tool_call(&context, spec.tool_name, spec.args, progress.as_ref()).await;
30    ParallelToolCallOutcome {
31        index: spec.index,
32        record: outcome.record,
33    }
34}
35
36/// Resolve the [`ToolScheduling`] declared on a tool's definition. Unknown
37/// tool names default to [`ToolScheduling::Parallel`] — the dispatcher
38/// will still surface an "unknown tool" error via the normal path.
39pub(crate) fn resolve_tool_scheduling(
40    context: &ToolDispatchContext<'_>,
41    tool_name: &str,
42) -> ToolScheduling {
43    context
44        .surface
45        .tools
46        .iter()
47        .find(|def| def.manifest.name == tool_name)
48        .map(|def| def.manifest.scheduling)
49        .unwrap_or_default()
50}
51
52/// Schedule a batch using Lash's tool execution policy.
53///
54/// Parallel-safe tools run concurrently first, then serial tools run
55/// one-at-a-time in original index order. Returned outputs are sorted by the
56/// same original index so callers keep their source/model ordering.
57pub(crate) async fn schedule_tool_batch<T, O, IndexOf, SchedulingOf, Run, Fut>(
58    items: Vec<T>,
59    index_of: IndexOf,
60    scheduling_of: SchedulingOf,
61    run: Run,
62) -> Vec<O>
63where
64    T: Send + 'static,
65    O: Send + 'static,
66    IndexOf: Fn(&T) -> usize,
67    SchedulingOf: Fn(&T) -> ToolScheduling,
68    Run: Fn(T) -> Fut,
69    Fut: Future<Output = O> + Send,
70{
71    let mut parallel_items = Vec::new();
72    let mut serial_items = Vec::new();
73    for item in items {
74        let index = index_of(&item);
75        match scheduling_of(&item) {
76            ToolScheduling::Parallel => parallel_items.push((index, item)),
77            ToolScheduling::Serial => serial_items.push((index, item)),
78        }
79    }
80
81    let mut outcomes = Vec::new();
82
83    let mut pending = FuturesUnordered::new();
84    for (index, item) in parallel_items {
85        let future = run(item);
86        pending.push(async move { (index, future.await) });
87    }
88    while let Some(outcome) = pending.next().await {
89        outcomes.push(outcome);
90    }
91
92    serial_items.sort_by_key(|(index, _)| *index);
93    for (index, item) in serial_items {
94        outcomes.push((index, run(item).await));
95    }
96
97    outcomes.sort_by_key(|(index, _)| *index);
98    outcomes.into_iter().map(|(_, outcome)| outcome).collect()
99}
100
101/// Dispatch a batch of tool calls produced by one model response.
102pub async fn dispatch_parallel_tool_calls(
103    context: Arc<ToolDispatchContext<'_>>,
104    specs: Vec<ParallelToolCallSpec>,
105    progress: Option<&ProgressSender>,
106) -> Vec<ParallelToolCallOutcome> {
107    let progress = progress.cloned();
108    schedule_tool_batch(
109        specs,
110        |spec| spec.index,
111        {
112            let context = Arc::clone(&context);
113            move |spec| resolve_tool_scheduling(&context, &spec.tool_name)
114        },
115        move |spec| dispatch_parallel_tool_call(Arc::clone(&context), spec, progress.clone()),
116    )
117    .await
118}