1use std::collections::HashMap;
7use std::fmt::Debug;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::Value;
14
15use crate::error::{Error, Result};
16use crate::runnables::RunnableConfig;
17
18use super::base::{
19 ArgsSchema, BaseTool, HandleToolError, HandleValidationError, ResponseFormat, ToolException,
20 ToolInput, ToolOutput,
21};
22
23pub type ToolFunc = Arc<dyn Fn(String) -> Result<String> + Send + Sync>;
25
26pub type AsyncToolFunc =
28 Arc<dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<String>> + Send>> + Send + Sync>;
29
30pub struct Tool {
35 name: String,
37 description: String,
39 func: Option<ToolFunc>,
41 coroutine: Option<AsyncToolFunc>,
43 args_schema: Option<ArgsSchema>,
45 return_direct: bool,
47 verbose: bool,
49 handle_tool_error: HandleToolError,
51 handle_validation_error: HandleValidationError,
53 response_format: ResponseFormat,
55 tags: Option<Vec<String>>,
57 metadata: Option<HashMap<String, Value>>,
59 extras: Option<HashMap<String, Value>>,
61}
62
63impl Debug for Tool {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("Tool")
66 .field("name", &self.name)
67 .field("description", &self.description)
68 .field("return_direct", &self.return_direct)
69 .field("response_format", &self.response_format)
70 .finish()
71 }
72}
73
74impl Tool {
75 pub fn new(
77 name: impl Into<String>,
78 func: Option<ToolFunc>,
79 description: impl Into<String>,
80 ) -> Self {
81 Self {
82 name: name.into(),
83 description: description.into(),
84 func,
85 coroutine: None,
86 args_schema: None,
87 return_direct: false,
88 verbose: false,
89 handle_tool_error: HandleToolError::Bool(false),
90 handle_validation_error: HandleValidationError::Bool(false),
91 response_format: ResponseFormat::Content,
92 tags: None,
93 metadata: None,
94 extras: None,
95 }
96 }
97
98 pub fn with_coroutine(mut self, coroutine: AsyncToolFunc) -> Self {
100 self.coroutine = Some(coroutine);
101 self
102 }
103
104 pub fn with_args_schema(mut self, schema: ArgsSchema) -> Self {
106 self.args_schema = Some(schema);
107 self
108 }
109
110 pub fn with_return_direct(mut self, return_direct: bool) -> Self {
112 self.return_direct = return_direct;
113 self
114 }
115
116 pub fn with_response_format(mut self, format: ResponseFormat) -> Self {
118 self.response_format = format;
119 self
120 }
121
122 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
124 self.tags = Some(tags);
125 self
126 }
127
128 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
130 self.metadata = Some(metadata);
131 self
132 }
133
134 pub fn with_extras(mut self, extras: HashMap<String, Value>) -> Self {
136 self.extras = Some(extras);
137 self
138 }
139
140 pub fn from_function<F>(
142 func: F,
143 name: impl Into<String>,
144 description: impl Into<String>,
145 ) -> Self
146 where
147 F: Fn(String) -> Result<String> + Send + Sync + 'static,
148 {
149 Self::new(name, Some(Arc::new(func)), description)
150 }
151
152 pub fn from_function_with_async<F, AF, Fut>(
154 func: F,
155 coroutine: AF,
156 name: impl Into<String>,
157 description: impl Into<String>,
158 ) -> Self
159 where
160 F: Fn(String) -> Result<String> + Send + Sync + 'static,
161 AF: Fn(String) -> Fut + Send + Sync + 'static,
162 Fut: Future<Output = Result<String>> + Send + 'static,
163 {
164 Self::new(name, Some(Arc::new(func)), description)
165 .with_coroutine(Arc::new(move |input| Box::pin(coroutine(input))))
166 }
167
168 fn extract_single_input(&self, input: ToolInput) -> Result<String> {
170 match input {
171 ToolInput::String(s) => Ok(s),
172 ToolInput::Dict(d) => {
173 let all_args: Vec<_> = d.values().collect();
176 if all_args.len() != 1 {
177 return Err(Error::ToolInvocation(format!(
178 "Too many arguments to single-input tool {}. Consider using StructuredTool instead. Args: {:?}",
179 self.name, all_args
180 )));
181 }
182 match all_args[0] {
183 Value::String(s) => Ok(s.clone()),
184 other => Ok(other.to_string()),
185 }
186 }
187 ToolInput::ToolCall(tc) => {
188 let args = tc.args();
189 if let Some(obj) = args.as_object() {
190 let values: Vec<_> = obj.values().collect();
191 if values.len() != 1 {
192 return Err(Error::ToolInvocation(format!(
193 "Too many arguments to single-input tool {}. Consider using StructuredTool instead.",
194 self.name,
195 )));
196 }
197 match &values[0] {
198 Value::String(s) => Ok(s.clone()),
199 other => Ok(other.to_string()),
200 }
201 } else if let Some(s) = args.as_str() {
202 Ok(s.to_string())
203 } else {
204 Ok(args.to_string())
205 }
206 }
207 }
208 }
209}
210
211#[async_trait]
212impl BaseTool for Tool {
213 fn name(&self) -> &str {
214 &self.name
215 }
216
217 fn description(&self) -> &str {
218 &self.description
219 }
220
221 fn args_schema(&self) -> Option<&ArgsSchema> {
222 self.args_schema.as_ref()
223 }
224
225 fn return_direct(&self) -> bool {
226 self.return_direct
227 }
228
229 fn verbose(&self) -> bool {
230 self.verbose
231 }
232
233 fn tags(&self) -> Option<&[String]> {
234 self.tags.as_deref()
235 }
236
237 fn metadata(&self) -> Option<&HashMap<String, Value>> {
238 self.metadata.as_ref()
239 }
240
241 fn handle_tool_error(&self) -> &HandleToolError {
242 &self.handle_tool_error
243 }
244
245 fn handle_validation_error(&self) -> &HandleValidationError {
246 &self.handle_validation_error
247 }
248
249 fn response_format(&self) -> ResponseFormat {
250 self.response_format
251 }
252
253 fn extras(&self) -> Option<&HashMap<String, Value>> {
254 self.extras.as_ref()
255 }
256
257 fn args(&self) -> HashMap<String, Value> {
258 if self.args_schema.is_some() {
261 return self.args_schema.as_ref().unwrap().properties();
262 }
263 let mut props = HashMap::new();
264 props.insert(
265 "tool_input".to_string(),
266 serde_json::json!({"type": "string"}),
267 );
268 props
269 }
270
271 fn run(&self, input: ToolInput, _config: Option<RunnableConfig>) -> Result<ToolOutput> {
272 let string_input = self.extract_single_input(input)?;
273
274 if let Some(ref func) = self.func {
275 match func(string_input) {
276 Ok(result) => Ok(ToolOutput::String(result)),
277 Err(e) => {
278 if let Error::ToolInvocation(msg) = &e {
280 let exc = ToolException::new(msg.clone());
281 if let Some(handled) =
282 super::base::handle_tool_error_impl(&exc, &self.handle_tool_error)
283 {
284 return Ok(ToolOutput::String(handled));
285 }
286 }
287 Err(e)
288 }
289 }
290 } else {
291 Err(Error::ToolInvocation(
292 "Tool does not support sync invocation.".to_string(),
293 ))
294 }
295 }
296
297 async fn arun(&self, input: ToolInput, config: Option<RunnableConfig>) -> Result<ToolOutput> {
298 let string_input = self.extract_single_input(input.clone())?;
299
300 if let Some(ref coroutine) = self.coroutine {
301 match coroutine(string_input).await {
302 Ok(result) => Ok(ToolOutput::String(result)),
303 Err(e) => {
304 if let Error::ToolInvocation(msg) = &e {
305 let exc = ToolException::new(msg.clone());
306 if let Some(handled) =
307 super::base::handle_tool_error_impl(&exc, &self.handle_tool_error)
308 {
309 return Ok(ToolOutput::String(handled));
310 }
311 }
312 Err(e)
313 }
314 }
315 } else {
316 self.run(input, config)
318 }
319 }
320}
321
322pub struct ToolBuilder {
324 name: Option<String>,
325 description: Option<String>,
326 func: Option<ToolFunc>,
327 coroutine: Option<AsyncToolFunc>,
328 args_schema: Option<ArgsSchema>,
329 return_direct: bool,
330 response_format: ResponseFormat,
331 tags: Option<Vec<String>>,
332 metadata: Option<HashMap<String, Value>>,
333 extras: Option<HashMap<String, Value>>,
334}
335
336impl ToolBuilder {
337 pub fn new() -> Self {
339 Self {
340 name: None,
341 description: None,
342 func: None,
343 coroutine: None,
344 args_schema: None,
345 return_direct: false,
346 response_format: ResponseFormat::Content,
347 tags: None,
348 metadata: None,
349 extras: None,
350 }
351 }
352
353 pub fn name(mut self, name: impl Into<String>) -> Self {
355 self.name = Some(name.into());
356 self
357 }
358
359 pub fn description(mut self, description: impl Into<String>) -> Self {
361 self.description = Some(description.into());
362 self
363 }
364
365 pub fn func<F>(mut self, func: F) -> Self
367 where
368 F: Fn(String) -> Result<String> + Send + Sync + 'static,
369 {
370 self.func = Some(Arc::new(func));
371 self
372 }
373
374 pub fn coroutine<AF, Fut>(mut self, coroutine: AF) -> Self
376 where
377 AF: Fn(String) -> Fut + Send + Sync + 'static,
378 Fut: Future<Output = Result<String>> + Send + 'static,
379 {
380 self.coroutine = Some(Arc::new(move |input| Box::pin(coroutine(input))));
381 self
382 }
383
384 pub fn args_schema(mut self, schema: ArgsSchema) -> Self {
386 self.args_schema = Some(schema);
387 self
388 }
389
390 pub fn return_direct(mut self, return_direct: bool) -> Self {
392 self.return_direct = return_direct;
393 self
394 }
395
396 pub fn response_format(mut self, format: ResponseFormat) -> Self {
398 self.response_format = format;
399 self
400 }
401
402 pub fn tags(mut self, tags: Vec<String>) -> Self {
404 self.tags = Some(tags);
405 self
406 }
407
408 pub fn metadata(mut self, metadata: HashMap<String, Value>) -> Self {
410 self.metadata = Some(metadata);
411 self
412 }
413
414 pub fn extras(mut self, extras: HashMap<String, Value>) -> Self {
416 self.extras = Some(extras);
417 self
418 }
419
420 pub fn build(self) -> Result<Tool> {
422 let name = self
423 .name
424 .ok_or_else(|| Error::InvalidConfig("Tool name is required".to_string()))?;
425 let description = self.description.unwrap_or_default();
426
427 if self.func.is_none() && self.coroutine.is_none() {
428 return Err(Error::InvalidConfig(
429 "Function and/or coroutine must be provided".to_string(),
430 ));
431 }
432
433 Ok(Tool {
434 name,
435 description,
436 func: self.func,
437 coroutine: self.coroutine,
438 args_schema: self.args_schema,
439 return_direct: self.return_direct,
440 verbose: false,
441 handle_tool_error: HandleToolError::Bool(false),
442 handle_validation_error: HandleValidationError::Bool(false),
443 response_format: self.response_format,
444 tags: self.tags,
445 metadata: self.metadata,
446 extras: self.extras,
447 })
448 }
449}
450
451impl Default for ToolBuilder {
452 fn default() -> Self {
453 Self::new()
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_creation() {
        let tool = Tool::from_function(
            |input| Ok(format!("Echo: {}", input)),
            "echo",
            "Echoes the input",
        );

        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echoes the input");
    }

    #[test]
    fn test_tool_run() {
        let tool = Tool::from_function(
            |input| Ok(format!("Hello, {}!", input)),
            "greet",
            "Greets the user",
        );

        let result = tool
            .run(ToolInput::String("World".to_string()), None)
            .unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Hello, World!"),
            _ => panic!("Expected String output"),
        }
    }

    #[test]
    fn test_tool_run_with_dict() {
        let tool = Tool::from_function(
            |input| Ok(format!("Got: {}", input)),
            "process",
            "Processes input",
        );

        let mut dict = HashMap::new();
        dict.insert("query".to_string(), Value::String("test".to_string()));

        let result = tool.run(ToolInput::Dict(dict), None).unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Got: test"),
            _ => panic!("Expected String output"),
        }
    }

    #[test]
    fn test_tool_run_with_non_string_dict_value() {
        let tool = Tool::from_function(
            |input| Ok(format!("Got: {}", input)),
            "process",
            "Processes input",
        );

        // Non-string JSON values are serialized rather than rejected.
        let mut dict = HashMap::new();
        dict.insert("count".to_string(), Value::from(42_i64));

        let result = tool.run(ToolInput::Dict(dict), None).unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Got: 42"),
            _ => panic!("Expected String output"),
        }
    }

    #[test]
    fn test_tool_run_rejects_multiple_dict_args() {
        let tool = Tool::from_function(Ok, "identity", "Returns input unchanged");

        let mut dict = HashMap::new();
        dict.insert("a".to_string(), Value::String("1".to_string()));
        dict.insert("b".to_string(), Value::String("2".to_string()));

        assert!(tool.run(ToolInput::Dict(dict), None).is_err());
    }

    #[test]
    fn test_tool_run_rejects_empty_dict() {
        let tool = Tool::from_function(Ok, "identity", "Returns input unchanged");

        assert!(tool.run(ToolInput::Dict(HashMap::new()), None).is_err());
    }

    #[test]
    fn test_tool_run_without_sync_func_errors() {
        let tool = ToolBuilder::new()
            .name("async_only")
            .coroutine(|input: String| async move {
                let out: Result<String> = Ok(input);
                out
            })
            .build()
            .unwrap();

        assert!(tool
            .run(ToolInput::String("x".to_string()), None)
            .is_err());
    }

    #[test]
    fn test_tool_args() {
        let tool = Tool::from_function(Ok, "identity", "Returns input unchanged");

        let args = tool.args();
        assert!(args.contains_key("tool_input"));
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new()
            .name("test_tool")
            .description("A test tool")
            .func(Ok)
            .return_direct(true)
            .build()
            .unwrap();

        assert_eq!(tool.name(), "test_tool");
        assert!(tool.return_direct());
    }

    #[test]
    fn test_tool_builder_requires_name() {
        assert!(ToolBuilder::new().func(Ok).build().is_err());
    }

    #[test]
    fn test_tool_builder_requires_callable() {
        assert!(ToolBuilder::new().name("no_impl").build().is_err());
    }

    #[tokio::test]
    async fn test_tool_arun() {
        let tool = Tool::from_function(
            |input| Ok(format!("Sync: {}", input)),
            "sync_tool",
            "A sync tool",
        );

        // No coroutine configured, so arun falls back to the sync function.
        let result = tool
            .arun(ToolInput::String("test".to_string()), None)
            .await
            .unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Sync: test"),
            _ => panic!("Expected String output"),
        }
    }

    #[tokio::test]
    async fn test_tool_arun_prefers_coroutine() {
        let tool = Tool::from_function_with_async(
            |input| Ok(format!("Sync: {}", input)),
            |input: String| async move {
                let out: Result<String> = Ok(format!("Async: {}", input));
                out
            },
            "dual_tool",
            "A tool with both sync and async implementations",
        );

        let result = tool
            .arun(ToolInput::String("test".to_string()), None)
            .await
            .unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Async: test"),
            _ => panic!("Expected String output"),
        }
    }
}