lellm_agent/runtime/tools/
executor.rs1use std::borrow::Cow;
6use std::sync::Arc;
7
8use lellm_core::{Message, ToolCall, ToolError, ToolErrorKind, ToolResult};
9
10use super::super::event::AgentEvent;
11use super::super::retry::RetryPolicy;
12use super::{ToolCatalog, ToolFn, ToolSnapshot};
13use tokio::sync::mpsc::Sender;
14
/// How a tool may be scheduled relative to other tool calls in the same batch.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ParallelSafety {
    /// May run concurrently with other calls.
    Safe,
    /// Serialized with other calls whose tool shares the same [`ToolCategory`];
    /// a tool of this kind registered without a category is scheduled like
    /// `Exclusive`.
    CategoryExclusive,
    /// Serialized with every other `Exclusive` call in a batch. Calls to tools
    /// missing from the snapshot are scheduled this way too.
    Exclusive,
}
22
/// Label grouping tools whose calls must not run concurrently with one another.
///
/// Calls to [`ParallelSafety::CategoryExclusive`] tools that share a category
/// are executed one at a time; different categories are scheduled independently.
/// Categories compare, hash and order purely by their name.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ToolCategory(pub Cow<'static, str>);

impl ToolCategory {
    /// Predefined `"file_io"` category.
    pub const FILE_IO: ToolCategory = ToolCategory(Cow::Borrowed("file_io"));
    /// Predefined `"network"` category.
    pub const NETWORK: ToolCategory = ToolCategory(Cow::Borrowed("network"));
    /// Predefined `"database"` category.
    pub const DATABASE: ToolCategory = ToolCategory(Cow::Borrowed("database"));

    /// Builds a category from a `&'static str` or an owned `String`.
    ///
    /// Because equality is by name, `custom("file_io")` equals
    /// [`ToolCategory::FILE_IO`].
    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = name.into();
        ToolCategory(label)
    }
}
36
37#[derive(Clone)]
42pub struct ToolRegistration {
43 pub(crate) definition: lellm_core::ToolDefinition,
44 pub(crate) safety: ParallelSafety,
45 pub(crate) category: Option<ToolCategory>,
46 pub(crate) func: ToolFn,
47}
48
49impl ToolRegistration {
50 pub fn definition(&self) -> &lellm_core::ToolDefinition {
52 &self.definition
53 }
54
55 pub fn safety(&self) -> &ParallelSafety {
57 &self.safety
58 }
59
60 pub fn category(&self) -> Option<&ToolCategory> {
62 self.category.as_ref()
63 }
64
65 pub fn safe<F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
66 where
67 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
68 Fut: std::future::Future<Output = ToolResult> + Send + 'static,
69 {
70 Self {
71 definition: def,
72 safety: ParallelSafety::Safe,
73 category: None,
74 func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
75 }
76 }
77
78 pub fn safe_fn<T, F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
83 where
84 T: for<'de> serde::Deserialize<'de> + Send + 'static,
85 F: Fn(T) -> Fut + Send + Sync + 'static,
86 Fut: std::future::Future<Output = ToolResult> + Send + 'static,
87 {
88 let f = Arc::new(f);
89 Self::safe(def, move |value| {
90 let f = Arc::clone(&f);
91 let result = serde_json::from_value::<T>(value.clone());
92 Box::pin(async move {
93 match result {
94 Ok(parsed) => f(parsed).await,
95 Err(e) => Err(ToolError::invalid_input(format!(
96 "invalid tool arguments: {e}"
97 ))),
98 }
99 })
100 })
101 }
102
103 pub fn category_exclusive<F, Fut>(
104 def: lellm_core::ToolDefinition,
105 category: ToolCategory,
106 f: F,
107 ) -> Self
108 where
109 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
110 Fut: std::future::Future<Output = ToolResult> + Send + 'static,
111 {
112 Self {
113 definition: def,
114 safety: ParallelSafety::CategoryExclusive,
115 category: Some(category),
116 func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
117 }
118 }
119
120 pub fn exclusive<F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
121 where
122 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
123 Fut: std::future::Future<Output = ToolResult> + Send + 'static,
124 {
125 Self {
126 definition: def,
127 safety: ParallelSafety::Exclusive,
128 category: None,
129 func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
130 }
131 }
132}
133
134#[derive(Debug)]
142pub struct BatchExecutionResult {
143 pub results: Vec<Message>,
145 pub panicked: bool,
147}
148
149#[derive(Clone)]
154pub struct ToolExecutor {
155 catalog: Arc<dyn ToolCatalog>,
156 retry_policy: RetryPolicy,
157}
158
159impl ToolExecutor {
160 pub fn new(catalog: Arc<dyn ToolCatalog>) -> Self {
162 Self {
163 catalog,
164 retry_policy: RetryPolicy::default(),
165 }
166 }
167
168 pub fn with_catalog(catalog: Arc<dyn ToolCatalog>) -> Self {
170 Self::new(catalog)
171 }
172
173 pub fn with_retry_policy(catalog: Arc<dyn ToolCatalog>, policy: RetryPolicy) -> Self {
175 Self {
176 catalog,
177 retry_policy: policy,
178 }
179 }
180
181 pub fn set_retry_policy(&mut self, policy: RetryPolicy) {
183 self.retry_policy = policy;
184 }
185
186 pub fn retry_policy(&self) -> RetryPolicy {
188 self.retry_policy.clone()
189 }
190
191 pub async fn snapshot(&self) -> Arc<ToolSnapshot> {
195 self.catalog.snapshot().await
196 }
197
198 pub async fn execute_with_snapshot(
202 &self,
203 call: &ToolCall,
204 snapshot: &ToolSnapshot,
205 ) -> ToolResult {
206 match snapshot.get(&call.name) {
207 Some(entry) => {
208 self.retry_policy
209 .execute_with_retry(&entry.func, &call.arguments)
210 .await
211 }
212 None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
213 }
214 }
215
216 pub async fn execute_with_emission(
218 &self,
219 call: &ToolCall,
220 snapshot: &ToolSnapshot,
221 tx: &Sender<AgentEvent>,
222 ) -> ToolResult {
223 match snapshot.get(&call.name) {
224 Some(entry) => {
225 self.retry_policy
226 .execute_with_retry_and_emission(&entry.func, &call.arguments, tx, &call.id)
227 .await
228 }
229 None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
230 }
231 }
232}
233
234pub async fn execute_batch_with(
255 calls: &[ToolCall],
256 snapshot: &ToolSnapshot,
257 retry_policy: &RetryPolicy,
258) -> BatchExecutionResult {
259 if calls.is_empty() {
260 return BatchExecutionResult {
261 results: Vec::new(),
262 panicked: false,
263 };
264 }
265
266 let mut safe_calls: Vec<(usize, ToolCall)> = Vec::new();
268 let mut category_calls: std::collections::HashMap<ToolCategory, Vec<(usize, ToolCall)>> =
269 std::collections::HashMap::new();
270 let mut exclusive_calls: Vec<(usize, ToolCall)> = Vec::new();
271
272 for (idx, call) in calls.iter().enumerate() {
273 let safety = snapshot
274 .get(&call.name)
275 .map(|t| t.safety.clone())
276 .unwrap_or(ParallelSafety::Exclusive);
277
278 match safety {
279 ParallelSafety::Safe => safe_calls.push((idx, call.clone())),
280 ParallelSafety::CategoryExclusive => {
281 if let Some(cat) = snapshot.get(&call.name).and_then(|t| t.category.clone()) {
282 category_calls
283 .entry(cat)
284 .or_default()
285 .push((idx, call.clone()));
286 } else {
287 exclusive_calls.push((idx, call.clone()));
288 }
289 }
290 ParallelSafety::Exclusive => exclusive_calls.push((idx, call.clone())),
291 }
292 }
293
294 let mut group_handles: Vec<tokio::task::JoinHandle<Vec<(usize, Message)>>> = Vec::new();
296 let mut group_indices: Vec<Vec<usize>> = Vec::new();
297
298 let snapshot = Arc::new(snapshot.clone_for_spawn());
299 let retry_policy = retry_policy.clone();
300
301 if !safe_calls.is_empty() {
303 let s = Arc::clone(&snapshot);
304 let rp = retry_policy.clone();
305 let indices: Vec<usize> = safe_calls.iter().map(|(i, _)| *i).collect();
306 group_handles.push(tokio::spawn(async move {
307 run_parallel_indexed_with(&s, &rp, safe_calls).await
308 }));
309 group_indices.push(indices);
310 }
311
312 for group_calls in category_calls.into_values() {
314 let s = Arc::clone(&snapshot);
315 let rp = retry_policy.clone();
316 let indices: Vec<usize> = group_calls.iter().map(|(i, _)| *i).collect();
317 group_handles.push(tokio::spawn(async move {
318 run_serial_indexed_with(&s, &rp, group_calls).await
319 }));
320 group_indices.push(indices);
321 }
322
323 if !exclusive_calls.is_empty() {
325 let s = Arc::clone(&snapshot);
326 let rp = retry_policy.clone();
327 let indices: Vec<usize> = exclusive_calls.iter().map(|(i, _)| *i).collect();
328 group_handles.push(tokio::spawn(async move {
329 run_serial_indexed_with(&s, &rp, exclusive_calls).await
330 }));
331 group_indices.push(indices);
332 }
333
334 let mut results: Vec<Option<Message>> = vec![None; calls.len()];
336 let mut panicked = false;
337 let all_handles = futures_util::future::join_all(group_handles).await;
338
339 for (handle_result, indices) in all_handles.into_iter().zip(group_indices.into_iter()) {
340 match handle_result {
341 Ok(indexed_messages) => {
342 for (idx, msg) in indexed_messages {
343 results[idx] = Some(msg);
344 }
345 }
346 Err(join_err) => {
347 panicked = true;
348 for idx in indices {
349 let call = &calls[idx];
350 results[idx] = Some(Message::tool_result(
351 call,
352 &Err(ToolError {
353 kind: ToolErrorKind::Internal,
354 message: format!("tool group task panicked: {join_err}"),
355 }),
356 ));
357 }
358 }
359 }
360 }
361
362 BatchExecutionResult {
363 results: results.into_iter().flatten().collect(),
364 panicked,
365 }
366}
367
368impl ToolSnapshot {
371 pub fn clone_for_spawn(&self) -> Arc<indexmap::IndexMap<String, ToolRegistration>> {
373 self.tools.clone()
374 }
375}
376
377async fn run_parallel_indexed_with(
380 tools: &Arc<indexmap::IndexMap<String, ToolRegistration>>,
381 retry_policy: &RetryPolicy,
382 calls: Vec<(usize, ToolCall)>,
383) -> Vec<(usize, Message)> {
384 let handles: Vec<_> = calls
385 .iter()
386 .map(|(idx, call)| {
387 let tools = Arc::clone(tools);
388 let rp = retry_policy.clone();
389 let call = call.clone();
390 let idx = *idx;
391 tokio::spawn(async move {
392 let result = match tools.get(&call.name) {
393 Some(entry) => rp.execute_with_retry(&entry.func, &call.arguments).await,
394 None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
395 };
396 (idx, Message::tool_result(&call, &result))
397 })
398 })
399 .collect();
400
401 let raw = futures_util::future::join_all(handles).await;
402 raw.into_iter()
403 .zip(calls.into_iter())
404 .map(|(h, (idx, call))| match h {
405 Ok((_, msg)) => (idx, msg),
406 Err(join_err) => (
407 idx,
408 Message::tool_result(
409 &call,
410 &Err(ToolError {
411 kind: ToolErrorKind::Internal,
412 message: format!("tool '{}' task panicked: {join_err}", call.name),
413 }),
414 ),
415 ),
416 })
417 .collect()
418}
419
420async fn run_serial_indexed_with(
423 tools: &Arc<indexmap::IndexMap<String, ToolRegistration>>,
424 retry_policy: &RetryPolicy,
425 calls: Vec<(usize, ToolCall)>,
426) -> Vec<(usize, Message)> {
427 let mut results = Vec::with_capacity(calls.len());
428 for (idx, call) in calls {
429 let exec_result = match tools.get(&call.name) {
430 Some(entry) => {
431 retry_policy
432 .execute_with_retry(&entry.func, &call.arguments)
433 .await
434 }
435 None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
436 };
437 results.push((idx, Message::tool_result(&call, &exec_result)));
438 }
439 results
440}