1use async_trait::async_trait;
4use serde::{Deserialize, Serialize, de::DeserializeOwned};
5use serde_json::{Value, json};
6
7use crate::Result;
8use crate::llm::ToolDefinition;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ToolResult {
13 pub tool_call_id: String,
15 pub content: String,
17 #[serde(default)]
19 pub ephemeral: bool,
20}
21
22impl ToolResult {
23 pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
24 Self {
25 tool_call_id: tool_call_id.into(),
26 content: content.into(),
27 ephemeral: false,
28 }
29 }
30
31 pub fn with_ephemeral(mut self, ephemeral: bool) -> Self {
32 self.ephemeral = ephemeral;
33 self
34 }
35}
36
37#[async_trait]
39pub trait Tool: Send + Sync {
40 fn name(&self) -> &str;
42
43 fn description(&self) -> &str;
45
46 fn definition(&self) -> ToolDefinition;
48
49 async fn execute(
51 &self,
52 args: Value,
53 overrides: Option<DependencyOverrides>,
54 ) -> Result<ToolResult>;
55
56 fn ephemeral(&self) -> EphemeralConfig {
58 EphemeralConfig::None
59 }
60}
61
/// Ephemerality policy attached to a tool's results.
///
/// `FunctionTool` flags a result as ephemeral for every variant other than
/// `None`. NOTE(review): how `Single` and `Count` are enforced is decided by
/// the consumer of these results, which is not visible here — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EphemeralConfig {
    /// Results are not ephemeral (the default).
    #[default]
    None,
    /// Ephemeral; presumably only the most recent result is retained — TODO confirm.
    Single,
    /// Ephemeral; presumably the most recent `n` results are retained — TODO confirm.
    Count(usize),
}
73
74pub struct FunctionTool<T, F>
76where
77 T: DeserializeOwned + Send + Sync + 'static,
78 F: Fn(T) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
79 + Send
80 + Sync,
81{
82 name: String,
83 description: String,
84 parameters_schema: serde_json::Map<String, Value>,
85 func: F,
86 ephemeral_config: EphemeralConfig,
87 _marker: std::marker::PhantomData<T>,
88}
89
90impl<T, F> FunctionTool<T, F>
91where
92 T: DeserializeOwned + Send + Sync + 'static,
93 F: Fn(T) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
94 + Send
95 + Sync,
96{
97 pub fn new(
99 name: impl Into<String>,
100 description: impl Into<String>,
101 parameters_schema: serde_json::Map<String, Value>,
102 func: F,
103 ) -> Self {
104 Self {
105 name: name.into(),
106 description: description.into(),
107 parameters_schema,
108 func,
109 ephemeral_config: EphemeralConfig::None,
110 _marker: std::marker::PhantomData,
111 }
112 }
113
114 pub fn with_ephemeral(mut self, config: EphemeralConfig) -> Self {
116 self.ephemeral_config = config;
117 self
118 }
119}
120
121#[async_trait]
122impl<T, F> Tool for FunctionTool<T, F>
123where
124 T: DeserializeOwned + Send + Sync + 'static,
125 F: Fn(T) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
126 + Send
127 + Sync,
128{
129 fn name(&self) -> &str {
130 &self.name
131 }
132
133 fn description(&self) -> &str {
134 &self.description
135 }
136
137 fn definition(&self) -> ToolDefinition {
138 ToolDefinition::new(
139 &self.name,
140 &self.description,
141 self.parameters_schema.clone(),
142 )
143 }
144
145 async fn execute(
146 &self,
147 args: Value,
148 _overrides: Option<DependencyOverrides>,
149 ) -> Result<ToolResult> {
150 let parsed: T = serde_json::from_value(args)?;
151 let content = (self.func)(parsed).await?;
152 Ok(ToolResult::new("", content)
153 .with_ephemeral(self.ephemeral_config != EphemeralConfig::None))
154 }
155
156 fn ephemeral(&self) -> EphemeralConfig {
157 self.ephemeral_config
158 }
159}
160
161pub struct ToolBuilder {
163 name: String,
164 description: String,
165 parameters_schema: serde_json::Map<String, Value>,
166 ephemeral: EphemeralConfig,
167}
168
169impl ToolBuilder {
170 pub fn new(name: impl Into<String>) -> Self {
171 Self {
172 name: name.into(),
173 description: String::new(),
174 parameters_schema: serde_json::Map::new(),
175 ephemeral: EphemeralConfig::None,
176 }
177 }
178
179 pub fn description(mut self, desc: impl Into<String>) -> Self {
180 self.description = desc.into();
181 self
182 }
183
184 pub fn parameter(mut self, name: &str, schema: Value) -> Self {
185 self.parameters_schema.insert(name.to_string(), schema);
186 self
187 }
188
189 pub fn string_param(self, name: &str, description: &str) -> Self {
190 self.parameter(
191 name,
192 json!({
193 "type": "string",
194 "description": description
195 }),
196 )
197 }
198
199 pub fn number_param(self, name: &str, description: &str) -> Self {
200 self.parameter(
201 name,
202 json!({
203 "type": "number",
204 "description": description
205 }),
206 )
207 }
208
209 pub fn boolean_param(self, name: &str, description: &str) -> Self {
210 self.parameter(
211 name,
212 json!({
213 "type": "boolean",
214 "description": description
215 }),
216 )
217 }
218
219 pub fn ephemeral(mut self, config: EphemeralConfig) -> Self {
220 self.ephemeral = config;
221 self
222 }
223
224 pub fn build<F, T>(self, func: F) -> Box<dyn Tool>
225 where
226 T: DeserializeOwned + Send + Sync + 'static,
227 F: Fn(T) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
228 + Send
229 + Sync
230 + 'static,
231 {
232 let mut tool = FunctionTool::new(self.name, self.description, self.parameters_schema, func);
233 tool.ephemeral_config = self.ephemeral;
234 Box::new(tool)
235 }
236}
237
/// Named, type-erased values that a caller may pass to `Tool::execute`.
///
/// Each value is a `Box<dyn Any + Send + Sync>`; presumably consumers recover
/// concrete values with `downcast_ref` — TODO confirm against callers.
pub type DependencyOverrides = std::collections::HashMap<
    String,
    Box<dyn std::any::Any + Send + Sync>,
>;
241
/// A [`Tool`] that takes no arguments and is backed by an async function.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the
/// type can be named without repeating them.
pub struct SimpleTool<F> {
    /// Tool name reported by `Tool::name`.
    name: String,
    /// Tool description reported by `Tool::description`.
    description: String,
    /// Async function invoked on every execution.
    func: F,
}
253
254impl<F> SimpleTool<F>
255where
256 F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
257 + Send
258 + Sync,
259{
260 pub fn new(name: impl Into<String>, description: impl Into<String>, func: F) -> Self {
261 Self {
262 name: name.into(),
263 description: description.into(),
264 func,
265 }
266 }
267}
268
269#[async_trait]
270impl<F> Tool for SimpleTool<F>
271where
272 F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
273 + Send
274 + Sync,
275{
276 fn name(&self) -> &str {
277 &self.name
278 }
279
280 fn description(&self) -> &str {
281 &self.description
282 }
283
284 fn definition(&self) -> ToolDefinition {
285 ToolDefinition::new(&self.name, &self.description, serde_json::Map::new())
286 }
287
288 async fn execute(
289 &self,
290 _args: Value,
291 _overrides: Option<DependencyOverrides>,
292 ) -> Result<ToolResult> {
293 let content = (self.func)().await?;
294 Ok(ToolResult::new("", content))
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_simple_tool() {
        let tool = SimpleTool::new("ping", "Returns pong", || {
            Box::pin(async { Ok("pong".to_string()) })
        });

        assert_eq!(tool.name(), "ping");

        let result = tool.execute(json!({}), None).await.unwrap();
        assert_eq!(result.content, "pong");
    }

    #[tokio::test]
    async fn test_function_tool() {
        #[derive(Deserialize)]
        struct EchoArgs {
            message: String,
        }

        let tool = FunctionTool::new(
            "echo",
            "Echoes the message back",
            json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            })
            .as_object()
            .unwrap()
            .clone(),
            |args: EchoArgs| Box::pin(async move { Ok(args.message) }),
        );

        let result = tool
            .execute(json!({"message": "hello"}), None)
            .await
            .unwrap();
        assert_eq!(result.content, "hello");
    }

    /// A non-`None` ephemeral policy is both reported by the tool and stamped
    /// onto every result it produces.
    #[tokio::test]
    async fn test_function_tool_propagates_ephemeral_flag() {
        let tool = FunctionTool::new(
            "noop",
            "Does nothing",
            serde_json::Map::new(),
            |_args: Value| Box::pin(async { Ok(String::new()) }),
        )
        .with_ephemeral(EphemeralConfig::Count(2));

        assert_eq!(tool.ephemeral(), EphemeralConfig::Count(2));

        let result = tool.execute(json!({}), None).await.unwrap();
        assert!(result.ephemeral);
        assert_eq!(result.tool_call_id, "");
    }

    /// Serialized results that omit `ephemeral` deserialize as non-ephemeral.
    #[test]
    fn test_tool_result_ephemeral_defaults_to_false() {
        let result: ToolResult =
            serde_json::from_value(json!({"tool_call_id": "call_1", "content": "ok"})).unwrap();
        assert_eq!(result.tool_call_id, "call_1");
        assert_eq!(result.content, "ok");
        assert!(!result.ephemeral);
    }
}
343}