lellm_agent/tools/
executor.rs1use std::borrow::Cow;
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use lellm_core::{Message, ToolCall};
9
10use super::ToolCallResult;
11
/// Scheduling class of a registered tool, consulted when a batch of tool
/// calls is split into concurrently-runnable and one-at-a-time groups.
///
/// This is a plain fieldless enum, so it is `Copy` and `Hash` (usable as a
/// map/set key) in addition to the comparison traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParallelSafety {
    /// Eligible for concurrent execution; `ToolExecutor::partition_calls`
    /// places these calls in its first ("safe") group.
    Safe,
    /// Tied to a [`ToolCategory`]. `ToolExecutor::partition_calls` currently
    /// schedules this class exactly like [`ParallelSafety::Exclusive`].
    CategoryExclusive,
    /// Placed in the exclusive group; also the class `ToolExecutor::safety_for`
    /// reports for names that were never registered.
    Exclusive,
}
19
/// Label attached to a tool registered via `ToolRegistration::category_exclusive`.
///
/// Wraps a `Cow<'static, str>` so the built-in labels below are stored
/// without allocation, while [`ToolCategory::custom`] can also take owned
/// `String` names.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolCategory(pub Cow<'static, str>);

impl ToolCategory {
    /// Built-in category labelled `"database"`.
    pub const DATABASE: Self = Self(Cow::Borrowed("database"));
    /// Built-in category labelled `"file_io"`.
    pub const FILE_IO: Self = Self(Cow::Borrowed("file_io"));
    /// Built-in category labelled `"network"`.
    pub const NETWORK: Self = Self(Cow::Borrowed("network"));

    /// Builds a category with an arbitrary label.
    ///
    /// A `&'static str` argument is kept borrowed; a `String` is stored owned.
    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = name.into();
        ToolCategory(label)
    }
}
33
34pub struct ToolRegistration {
36 safety: ParallelSafety,
37 #[allow(dead_code)]
38 category: Option<ToolCategory>,
39 func: ToolFn,
40}
41
42type ToolFn = Arc<
44 dyn Fn(&serde_json::Value) -> Pin<Box<dyn std::future::Future<Output = ToolCallResult> + Send>>
45 + Send
46 + Sync,
47>;
48
49impl ToolRegistration {
50 pub fn safe<F, Fut>(f: F) -> Self
51 where
52 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
53 Fut: std::future::Future<Output = ToolCallResult> + Send + 'static,
54 {
55 Self {
56 safety: ParallelSafety::Safe,
57 category: None,
58 func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
59 }
60 }
61
62 pub fn category_exclusive<F, Fut>(category: ToolCategory, f: F) -> Self
63 where
64 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
65 Fut: std::future::Future<Output = ToolCallResult> + Send + 'static,
66 {
67 Self {
68 safety: ParallelSafety::CategoryExclusive,
69 category: Some(category),
70 func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
71 }
72 }
73
74 pub fn exclusive<F, Fut>(f: F) -> Self
75 where
76 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
77 Fut: std::future::Future<Output = ToolCallResult> + Send + 'static,
78 {
79 Self {
80 safety: ParallelSafety::Exclusive,
81 category: None,
82 func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
83 }
84 }
85}
86
87#[derive(Default)]
89pub struct ToolExecutor {
90 tools: HashMap<String, ToolFn>,
91 safety: HashMap<String, ParallelSafety>,
92}
93
94impl ToolExecutor {
95 pub fn new() -> Self {
96 Self {
97 tools: HashMap::new(),
98 safety: HashMap::new(),
99 }
100 }
101
102 pub fn register(&mut self, name: &str, reg: ToolRegistration) {
103 self.safety.insert(name.to_string(), reg.safety.clone());
104 self.tools.insert(name.to_string(), reg.func);
105 }
106
107 pub fn safety_for(&self, name: &str) -> ParallelSafety {
108 self.safety
109 .get(name)
110 .cloned()
111 .unwrap_or(ParallelSafety::Exclusive)
112 }
113
114 pub async fn execute(&self, call: &ToolCall) -> ToolCallResult {
115 match self.tools.get(&call.name) {
116 Some(tool_fn) => tool_fn(&call.arguments).await,
117 None => ToolCallResult::Err(format!("unknown tool: {}", call.name)),
118 }
119 }
120
121 pub async fn execute_batch(&self, calls: &[ToolCall]) -> Vec<Message> {
122 let mut results = Vec::new();
123 for call in calls {
124 let result = self.execute(call).await;
125 let content = match result {
126 ToolCallResult::Ok(s) => s,
127 ToolCallResult::Err(e) => format!("tool error: {e}"),
128 };
129 results.push(Message::ToolResult {
130 tool_call_id: call.id.clone(),
131 content: lellm_core::text_block(content),
132 });
133 }
134 results
135 }
136
137 pub fn partition_calls(&self, calls: &[ToolCall]) -> (Vec<ToolCall>, Vec<ToolCall>) {
139 let mut safe = Vec::new();
140 let mut exclusive = Vec::new();
141 for call in calls {
142 let safety = self.safety_for(&call.name);
143 match safety {
144 ParallelSafety::Safe => safe.push(call.clone()),
145 ParallelSafety::CategoryExclusive | ParallelSafety::Exclusive => {
146 exclusive.push(call.clone());
147 }
148 }
149 }
150 (safe, exclusive)
151 }
152}