1use std::borrow::Cow;
13use std::future::Future;
14use std::ops::Deref;
15use std::pin::Pin;
16use std::sync::Arc;
17
18use crate::ToolResult;
19
20#[repr(transparent)]
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct ToolSchema(pub serde_json::Value);
29
30impl ToolSchema {
31 pub fn new(value: serde_json::Value) -> Self {
32 Self(value)
33 }
34}
35
36impl Deref for ToolSchema {
37 type Target = serde_json::Value;
38
39 fn deref(&self) -> &Self::Target {
40 &self.0
41 }
42}
43
44impl PartialEq<serde_json::Value> for ToolSchema {
45 fn eq(&self, other: &serde_json::Value) -> bool {
46 &self.0 == other
47 }
48}
49
50impl From<serde_json::Value> for ToolSchema {
51 fn from(value: serde_json::Value) -> Self {
52 Self(value)
53 }
54}
55
56impl From<ToolSchema> for serde_json::Value {
57 fn from(schema: ToolSchema) -> Self {
58 schema.0
59 }
60}
61
62impl serde::Serialize for ToolSchema {
63 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
64 where
65 S: serde::Serializer,
66 {
67 self.0.serialize(serializer)
68 }
69}
70
71impl<'de> serde::Deserialize<'de> for ToolSchema {
72 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
73 where
74 D: serde::Deserializer<'de>,
75 {
76 serde_json::Value::deserialize(deserializer).map(ToolSchema)
77 }
78}
79
80pub trait ToolArgs:
103 serde::de::DeserializeOwned + schemars::JsonSchema + Send + Sync + 'static
104{
105 const NAME: &'static str;
107 const DESCRIPTION: &'static str;
109 fn schema() -> ToolSchema;
111
112 fn parse(value: serde_json::Value) -> Result<Self, serde_json::Error>
116 where
117 Self: Sized,
118 {
119 serde_json::from_value(value)
120 }
121
122 fn tool_definition() -> ToolDefinition {
126 ToolDefinition {
127 name: Self::NAME.to_string(),
128 description: Self::DESCRIPTION.to_string(),
129 parameters: Self::schema().0,
130 cache_control: None,
131 }
132 }
133}
134
/// Declared constraint on how a tool may be scheduled relative to other tool
/// calls.
///
/// This type only records the declaration; whatever enforces it lives outside
/// this file. The per-variant notes describe the intent suggested by the
/// variant names and by the `ExecutableTool` constructors that set them.
/// Confirm against the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParallelSafety {
    /// Set by `ExecutableTool::safe`. Presumably free to run alongside any
    /// other tool call.
    Safe,
    /// Set by `ExecutableTool::category_exclusive` together with a
    /// `ToolCategory`. Presumably serialized against other tools in the same
    /// category.
    CategoryExclusive,
    /// Set by `ExecutableTool::exclusive`. Presumably must run with no other
    /// tool call in flight.
    Exclusive,
}
147
/// Named category attached to an `ExecutableTool` (see
/// `ExecutableTool::category_exclusive`).
///
/// The name is a `Cow<'static, str>`, so the built-in categories can be
/// `const` (borrowed) while [`ToolCategory::custom`] also accepts owned
/// strings. Equality and hashing follow the text, regardless of whether the
/// name is borrowed or owned.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolCategory(pub Cow<'static, str>);

impl ToolCategory {
    /// Built-in category named `"file_io"`.
    pub const FILE_IO: Self = Self::borrowed("file_io");
    /// Built-in category named `"network"`.
    pub const NETWORK: Self = Self::borrowed("network");
    /// Built-in category named `"database"`.
    pub const DATABASE: Self = Self::borrowed("database");

    /// Builds a category with an arbitrary name (borrowed `&'static str` or
    /// owned `String`).
    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = name.into();
        ToolCategory(label)
    }

    /// `const` constructor backing the built-in categories.
    const fn borrowed(label: &'static str) -> Self {
        ToolCategory(Cow::Borrowed(label))
    }
}
163
164#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
175pub struct ToolDefinition {
176 pub name: String,
178 pub description: String,
180 pub parameters: serde_json::Value,
182 #[serde(skip_serializing_if = "Option::is_none")]
184 pub cache_control: Option<crate::message::CacheControl>,
185}
186
187impl ToolDefinition {
188 pub fn with_cache(self, cache: crate::message::CacheControl) -> Self {
190 Self {
191 cache_control: Some(cache),
192 ..self
193 }
194 }
195
196 pub fn compute_and_clean_schema<S: schemars::JsonSchema>() -> ToolSchema {
203 let root = schemars::schema_for!(S);
204 let val = serde_json::to_value(&root)
205 .expect("Failed to serialize JsonSchema; this is a bug in schemars");
206 ToolSchema(Self::clean_schema(val))
207 }
208
209 fn clean_schema(mut value: serde_json::Value) -> serde_json::Value {
214 if let Some(obj) = value.as_object_mut() {
215 obj.remove("$schema");
217 obj.remove("$id");
218 obj.remove("title");
219 obj.remove("description");
220 }
221 value
222 }
223}
224
225pub type ToolFn = Arc<
229 dyn Fn(&serde_json::Value) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> + Send + Sync,
230>;
231
232#[doc(hidden)]
237pub fn __tool_box<F>(f: F) -> Pin<Box<dyn Future<Output = ToolResult> + Send>>
238where
239 F: Future<Output = ToolResult> + Send + 'static,
240{
241 Box::pin(f)
242}
243
244#[derive(Clone)]
259pub struct ExecutableTool {
260 pub definition: ToolDefinition,
262 pub safety: ParallelSafety,
264 pub category: Option<ToolCategory>,
266 executor: ToolFn,
268}
269
270impl ExecutableTool {
271 pub fn definition(&self) -> &ToolDefinition {
275 &self.definition
276 }
277
278 pub fn safety(&self) -> &ParallelSafety {
280 &self.safety
281 }
282
283 pub fn category(&self) -> Option<&ToolCategory> {
285 self.category.as_ref()
286 }
287
288 pub fn execute(
290 &self,
291 args: &serde_json::Value,
292 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> {
293 (self.executor)(args)
294 }
295
296 pub fn from_fn(
302 def: ToolDefinition,
303 safety: ParallelSafety,
304 category: Option<ToolCategory>,
305 f: ToolFn,
306 ) -> Self {
307 Self {
308 definition: def,
309 safety,
310 category,
311 executor: f,
312 }
313 }
314
315 pub fn safe<F, Fut>(def: ToolDefinition, f: F) -> Self
319 where
320 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
321 Fut: Future<Output = ToolResult> + Send + 'static,
322 {
323 Self {
324 definition: def,
325 safety: ParallelSafety::Safe,
326 category: None,
327 executor: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
328 }
329 }
330
331 pub fn category_exclusive<F, Fut>(def: ToolDefinition, category: ToolCategory, f: F) -> Self
333 where
334 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
335 Fut: Future<Output = ToolResult> + Send + 'static,
336 {
337 Self {
338 definition: def,
339 safety: ParallelSafety::CategoryExclusive,
340 category: Some(category),
341 executor: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
342 }
343 }
344
345 pub fn exclusive<F, Fut>(def: ToolDefinition, f: F) -> Self
347 where
348 F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
349 Fut: Future<Output = ToolResult> + Send + 'static,
350 {
351 Self {
352 definition: def,
353 safety: ParallelSafety::Exclusive,
354 category: None,
355 executor: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
356 }
357 }
358
359 pub fn safe_fn<T, F, Fut>(def: ToolDefinition, f: F) -> Self
366 where
367 T: ToolArgs + Send + 'static,
368 F: Fn(T) -> Fut + Send + Sync + 'static,
369 Fut: Future<Output = ToolResult> + Send + 'static,
370 {
371 let f = Arc::new(f);
372 Self::safe(def, move |value| {
373 let f = Arc::clone(&f);
374 let result = T::parse(value.clone());
375 async move {
376 match result {
377 Ok(parsed) => f(parsed).await,
378 Err(e) => Err(crate::ToolError::invalid_input(format!(
379 "invalid tool arguments: {e}"
380 ))),
381 }
382 }
383 })
384 }
385
386 pub fn category_exclusive_fn<T, F, Fut>(
388 def: ToolDefinition,
389 category: ToolCategory,
390 f: F,
391 ) -> Self
392 where
393 T: ToolArgs + Send + 'static,
394 F: Fn(T) -> Fut + Send + Sync + 'static,
395 Fut: Future<Output = ToolResult> + Send + 'static,
396 {
397 let f = Arc::new(f);
398 Self::category_exclusive(def, category, move |value| {
399 let f = Arc::clone(&f);
400 let result = T::parse(value.clone());
401 async move {
402 match result {
403 Ok(parsed) => f(parsed).await,
404 Err(e) => Err(crate::ToolError::invalid_input(format!(
405 "invalid tool arguments: {e}"
406 ))),
407 }
408 }
409 })
410 }
411
412 pub fn exclusive_fn<T, F, Fut>(def: ToolDefinition, f: F) -> Self
414 where
415 T: ToolArgs + Send + 'static,
416 F: Fn(T) -> Fut + Send + Sync + 'static,
417 Fut: Future<Output = ToolResult> + Send + 'static,
418 {
419 let f = Arc::new(f);
420 Self::exclusive(def, move |value| {
421 let f = Arc::clone(&f);
422 let result = T::parse(value.clone());
423 async move {
424 match result {
425 Ok(parsed) => f(parsed).await,
426 Err(e) => Err(crate::ToolError::invalid_input(format!(
427 "invalid tool arguments: {e}"
428 ))),
429 }
430 }
431 })
432 }
433}