agent_chain_core/tools/
simple.rs

//! Tool that takes in a function or coroutine directly.
2//!
3//! This module provides the `Tool` struct for creating simple single-input tools,
4//! mirroring `langchain_core.tools.simple`.
5
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::Value;
14
15use crate::error::{Error, Result};
16use crate::runnables::RunnableConfig;
17
18use super::base::{
19    ArgsSchema, BaseTool, HandleToolError, HandleValidationError, ResponseFormat, ToolException,
20    ToolInput, ToolOutput,
21};
22
23/// Type alias for sync tool function.
24pub type ToolFunc = Arc<dyn Fn(String) -> Result<String> + Send + Sync>;
25
26/// Type alias for async tool function.
27pub type AsyncToolFunc =
28    Arc<dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<String>> + Send>> + Send + Sync>;
29
30/// Tool that takes in a function or coroutine directly.
31///
32/// This is the simplest form of tool that takes a single string input
33/// and returns a string output.
34pub struct Tool {
35    /// The unique name of the tool.
36    name: String,
37    /// A description of what the tool does.
38    description: String,
39    /// The function to run when the tool is called.
40    func: Option<ToolFunc>,
41    /// The asynchronous version of the function.
42    coroutine: Option<AsyncToolFunc>,
43    /// Optional schema for the tool's input arguments.
44    args_schema: Option<ArgsSchema>,
45    /// Whether to return the tool's output directly.
46    return_direct: bool,
47    /// Whether to log the tool's progress.
48    verbose: bool,
49    /// How to handle tool errors.
50    handle_tool_error: HandleToolError,
51    /// How to handle validation errors.
52    handle_validation_error: HandleValidationError,
53    /// The tool response format.
54    response_format: ResponseFormat,
55    /// Optional tags for the tool.
56    tags: Option<Vec<String>>,
57    /// Optional metadata for the tool.
58    metadata: Option<HashMap<String, Value>>,
59    /// Optional provider-specific extras.
60    extras: Option<HashMap<String, Value>>,
61}
62
63impl Debug for Tool {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("Tool")
66            .field("name", &self.name)
67            .field("description", &self.description)
68            .field("return_direct", &self.return_direct)
69            .field("response_format", &self.response_format)
70            .finish()
71    }
72}
73
74impl Tool {
75    /// Create a new Tool.
76    pub fn new(
77        name: impl Into<String>,
78        func: Option<ToolFunc>,
79        description: impl Into<String>,
80    ) -> Self {
81        Self {
82            name: name.into(),
83            description: description.into(),
84            func,
85            coroutine: None,
86            args_schema: None,
87            return_direct: false,
88            verbose: false,
89            handle_tool_error: HandleToolError::Bool(false),
90            handle_validation_error: HandleValidationError::Bool(false),
91            response_format: ResponseFormat::Content,
92            tags: None,
93            metadata: None,
94            extras: None,
95        }
96    }
97
98    /// Set the coroutine (async function).
99    pub fn with_coroutine(mut self, coroutine: AsyncToolFunc) -> Self {
100        self.coroutine = Some(coroutine);
101        self
102    }
103
104    /// Set the args schema.
105    pub fn with_args_schema(mut self, schema: ArgsSchema) -> Self {
106        self.args_schema = Some(schema);
107        self
108    }
109
110    /// Set whether to return directly.
111    pub fn with_return_direct(mut self, return_direct: bool) -> Self {
112        self.return_direct = return_direct;
113        self
114    }
115
116    /// Set the response format.
117    pub fn with_response_format(mut self, format: ResponseFormat) -> Self {
118        self.response_format = format;
119        self
120    }
121
122    /// Set tags.
123    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
124        self.tags = Some(tags);
125        self
126    }
127
128    /// Set metadata.
129    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
130        self.metadata = Some(metadata);
131        self
132    }
133
134    /// Set extras.
135    pub fn with_extras(mut self, extras: HashMap<String, Value>) -> Self {
136        self.extras = Some(extras);
137        self
138    }
139
140    /// Create a Tool from a function.
141    pub fn from_function<F>(
142        func: F,
143        name: impl Into<String>,
144        description: impl Into<String>,
145    ) -> Self
146    where
147        F: Fn(String) -> Result<String> + Send + Sync + 'static,
148    {
149        Self::new(name, Some(Arc::new(func)), description)
150    }
151
152    /// Create a Tool from a sync and async function pair.
153    pub fn from_function_with_async<F, AF, Fut>(
154        func: F,
155        coroutine: AF,
156        name: impl Into<String>,
157        description: impl Into<String>,
158    ) -> Self
159    where
160        F: Fn(String) -> Result<String> + Send + Sync + 'static,
161        AF: Fn(String) -> Fut + Send + Sync + 'static,
162        Fut: Future<Output = Result<String>> + Send + 'static,
163    {
164        Self::new(name, Some(Arc::new(func)), description)
165            .with_coroutine(Arc::new(move |input| Box::pin(coroutine(input))))
166    }
167
168    /// Extract the single input from the tool input.
169    fn extract_single_input(&self, input: ToolInput) -> Result<String> {
170        match input {
171            ToolInput::String(s) => Ok(s),
172            ToolInput::Dict(d) => {
173                // For backwards compatibility, if run_input is a dict,
174                // extract the single value
175                let all_args: Vec<_> = d.values().collect();
176                if all_args.len() != 1 {
177                    return Err(Error::ToolInvocation(format!(
178                        "Too many arguments to single-input tool {}. Consider using StructuredTool instead. Args: {:?}",
179                        self.name, all_args
180                    )));
181                }
182                match all_args[0] {
183                    Value::String(s) => Ok(s.clone()),
184                    other => Ok(other.to_string()),
185                }
186            }
187            ToolInput::ToolCall(tc) => {
188                let args = tc.args();
189                if let Some(obj) = args.as_object() {
190                    let values: Vec<_> = obj.values().collect();
191                    if values.len() != 1 {
192                        return Err(Error::ToolInvocation(format!(
193                            "Too many arguments to single-input tool {}. Consider using StructuredTool instead.",
194                            self.name,
195                        )));
196                    }
197                    match &values[0] {
198                        Value::String(s) => Ok(s.clone()),
199                        other => Ok(other.to_string()),
200                    }
201                } else if let Some(s) = args.as_str() {
202                    Ok(s.to_string())
203                } else {
204                    Ok(args.to_string())
205                }
206            }
207        }
208    }
209}
210
211#[async_trait]
212impl BaseTool for Tool {
213    fn name(&self) -> &str {
214        &self.name
215    }
216
217    fn description(&self) -> &str {
218        &self.description
219    }
220
221    fn args_schema(&self) -> Option<&ArgsSchema> {
222        self.args_schema.as_ref()
223    }
224
225    fn return_direct(&self) -> bool {
226        self.return_direct
227    }
228
229    fn verbose(&self) -> bool {
230        self.verbose
231    }
232
233    fn tags(&self) -> Option<&[String]> {
234        self.tags.as_deref()
235    }
236
237    fn metadata(&self) -> Option<&HashMap<String, Value>> {
238        self.metadata.as_ref()
239    }
240
241    fn handle_tool_error(&self) -> &HandleToolError {
242        &self.handle_tool_error
243    }
244
245    fn handle_validation_error(&self) -> &HandleValidationError {
246        &self.handle_validation_error
247    }
248
249    fn response_format(&self) -> ResponseFormat {
250        self.response_format
251    }
252
253    fn extras(&self) -> Option<&HashMap<String, Value>> {
254        self.extras.as_ref()
255    }
256
257    fn args(&self) -> HashMap<String, Value> {
258        // For backwards compatibility, if the function signature is ambiguous,
259        // assume it takes a single string input.
260        if self.args_schema.is_some() {
261            return self.args_schema.as_ref().unwrap().properties();
262        }
263        let mut props = HashMap::new();
264        props.insert(
265            "tool_input".to_string(),
266            serde_json::json!({"type": "string"}),
267        );
268        props
269    }
270
271    fn run(&self, input: ToolInput, _config: Option<RunnableConfig>) -> Result<ToolOutput> {
272        let string_input = self.extract_single_input(input)?;
273
274        if let Some(ref func) = self.func {
275            match func(string_input) {
276                Ok(result) => Ok(ToolOutput::String(result)),
277                Err(e) => {
278                    // Check if we should handle the error
279                    if let Error::ToolInvocation(msg) = &e {
280                        let exc = ToolException::new(msg.clone());
281                        if let Some(handled) =
282                            super::base::handle_tool_error_impl(&exc, &self.handle_tool_error)
283                        {
284                            return Ok(ToolOutput::String(handled));
285                        }
286                    }
287                    Err(e)
288                }
289            }
290        } else {
291            Err(Error::ToolInvocation(
292                "Tool does not support sync invocation.".to_string(),
293            ))
294        }
295    }
296
297    async fn arun(&self, input: ToolInput, config: Option<RunnableConfig>) -> Result<ToolOutput> {
298        let string_input = self.extract_single_input(input.clone())?;
299
300        if let Some(ref coroutine) = self.coroutine {
301            match coroutine(string_input).await {
302                Ok(result) => Ok(ToolOutput::String(result)),
303                Err(e) => {
304                    if let Error::ToolInvocation(msg) = &e {
305                        let exc = ToolException::new(msg.clone());
306                        if let Some(handled) =
307                            super::base::handle_tool_error_impl(&exc, &self.handle_tool_error)
308                        {
309                            return Ok(ToolOutput::String(handled));
310                        }
311                    }
312                    Err(e)
313                }
314            }
315        } else {
316            // Fall back to sync implementation
317            self.run(input, config)
318        }
319    }
320}
321
322/// Builder for creating Tool instances.
323pub struct ToolBuilder {
324    name: Option<String>,
325    description: Option<String>,
326    func: Option<ToolFunc>,
327    coroutine: Option<AsyncToolFunc>,
328    args_schema: Option<ArgsSchema>,
329    return_direct: bool,
330    response_format: ResponseFormat,
331    tags: Option<Vec<String>>,
332    metadata: Option<HashMap<String, Value>>,
333    extras: Option<HashMap<String, Value>>,
334}
335
336impl ToolBuilder {
337    /// Create a new ToolBuilder.
338    pub fn new() -> Self {
339        Self {
340            name: None,
341            description: None,
342            func: None,
343            coroutine: None,
344            args_schema: None,
345            return_direct: false,
346            response_format: ResponseFormat::Content,
347            tags: None,
348            metadata: None,
349            extras: None,
350        }
351    }
352
353    /// Set the name.
354    pub fn name(mut self, name: impl Into<String>) -> Self {
355        self.name = Some(name.into());
356        self
357    }
358
359    /// Set the description.
360    pub fn description(mut self, description: impl Into<String>) -> Self {
361        self.description = Some(description.into());
362        self
363    }
364
365    /// Set the sync function.
366    pub fn func<F>(mut self, func: F) -> Self
367    where
368        F: Fn(String) -> Result<String> + Send + Sync + 'static,
369    {
370        self.func = Some(Arc::new(func));
371        self
372    }
373
374    /// Set the async function.
375    pub fn coroutine<AF, Fut>(mut self, coroutine: AF) -> Self
376    where
377        AF: Fn(String) -> Fut + Send + Sync + 'static,
378        Fut: Future<Output = Result<String>> + Send + 'static,
379    {
380        self.coroutine = Some(Arc::new(move |input| Box::pin(coroutine(input))));
381        self
382    }
383
384    /// Set the args schema.
385    pub fn args_schema(mut self, schema: ArgsSchema) -> Self {
386        self.args_schema = Some(schema);
387        self
388    }
389
390    /// Set return_direct.
391    pub fn return_direct(mut self, return_direct: bool) -> Self {
392        self.return_direct = return_direct;
393        self
394    }
395
396    /// Set the response format.
397    pub fn response_format(mut self, format: ResponseFormat) -> Self {
398        self.response_format = format;
399        self
400    }
401
402    /// Set tags.
403    pub fn tags(mut self, tags: Vec<String>) -> Self {
404        self.tags = Some(tags);
405        self
406    }
407
408    /// Set metadata.
409    pub fn metadata(mut self, metadata: HashMap<String, Value>) -> Self {
410        self.metadata = Some(metadata);
411        self
412    }
413
414    /// Set extras.
415    pub fn extras(mut self, extras: HashMap<String, Value>) -> Self {
416        self.extras = Some(extras);
417        self
418    }
419
420    /// Build the Tool.
421    pub fn build(self) -> Result<Tool> {
422        let name = self
423            .name
424            .ok_or_else(|| Error::InvalidConfig("Tool name is required".to_string()))?;
425        let description = self.description.unwrap_or_default();
426
427        if self.func.is_none() && self.coroutine.is_none() {
428            return Err(Error::InvalidConfig(
429                "Function and/or coroutine must be provided".to_string(),
430            ));
431        }
432
433        Ok(Tool {
434            name,
435            description,
436            func: self.func,
437            coroutine: self.coroutine,
438            args_schema: self.args_schema,
439            return_direct: self.return_direct,
440            verbose: false,
441            handle_tool_error: HandleToolError::Bool(false),
442            handle_validation_error: HandleValidationError::Bool(false),
443            response_format: self.response_format,
444            tags: self.tags,
445            metadata: self.metadata,
446            extras: self.extras,
447        })
448    }
449}
450
451impl Default for ToolBuilder {
452    fn default() -> Self {
453        Self::new()
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_creation() {
        let tool = Tool::from_function(
            |input| Ok(format!("Echo: {}", input)),
            "echo",
            "Echoes the input",
        );

        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echoes the input");
    }

    #[test]
    fn test_tool_run() {
        let tool = Tool::from_function(
            |input| Ok(format!("Hello, {}!", input)),
            "greet",
            "Greets the user",
        );

        let result = tool
            .run(ToolInput::String("World".to_string()), None)
            .unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Hello, World!"),
            _ => panic!("Expected String output"),
        }
    }

    #[test]
    fn test_tool_run_with_dict() {
        let tool = Tool::from_function(
            |input| Ok(format!("Got: {}", input)),
            "process",
            "Processes input",
        );

        let mut dict = HashMap::new();
        dict.insert("query".to_string(), Value::String("test".to_string()));

        let result = tool.run(ToolInput::Dict(dict), None).unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Got: test"),
            _ => panic!("Expected String output"),
        }
    }

    #[test]
    fn test_tool_run_with_non_string_dict_value() {
        let tool = Tool::from_function(Ok, "identity", "Returns input unchanged");

        let mut dict = HashMap::new();
        dict.insert("n".to_string(), serde_json::json!(42));

        // Non-string JSON values are passed on in their JSON serialization.
        match tool.run(ToolInput::Dict(dict), None).unwrap() {
            ToolOutput::String(s) => assert_eq!(s, "42"),
            _ => panic!("Expected String output"),
        }
    }

    #[test]
    fn test_tool_run_with_multiple_dict_values_fails() {
        let tool = Tool::from_function(Ok, "single", "Accepts a single input");

        let mut dict = HashMap::new();
        dict.insert("a".to_string(), Value::String("1".to_string()));
        dict.insert("b".to_string(), Value::String("2".to_string()));

        assert!(tool.run(ToolInput::Dict(dict), None).is_err());
    }

    #[test]
    fn test_tool_args() {
        let tool = Tool::from_function(Ok, "identity", "Returns input unchanged");

        let args = tool.args();
        assert!(args.contains_key("tool_input"));
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new()
            .name("test_tool")
            .description("A test tool")
            .func(Ok)
            .return_direct(true)
            .build()
            .unwrap();

        assert_eq!(tool.name(), "test_tool");
        assert!(tool.return_direct());
    }

    #[test]
    fn test_tool_builder_requires_name_and_callable() {
        assert!(ToolBuilder::new().func(Ok).build().is_err());
        assert!(ToolBuilder::new().name("no_callable").build().is_err());
    }

    #[test]
    fn test_tool_run_without_sync_func_fails() {
        let tool = ToolBuilder::new()
            .name("async_only")
            .coroutine(|input: String| async move { Ok::<String, Error>(input) })
            .build()
            .unwrap();

        assert!(tool
            .run(ToolInput::String("x".to_string()), None)
            .is_err());
    }

    #[tokio::test]
    async fn test_tool_arun() {
        let tool = Tool::from_function(
            |input| Ok(format!("Sync: {}", input)),
            "sync_tool",
            "A sync tool",
        );

        // Should fall back to sync implementation
        let result = tool
            .arun(ToolInput::String("test".to_string()), None)
            .await
            .unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Sync: test"),
            _ => panic!("Expected String output"),
        }
    }

    #[tokio::test]
    async fn test_tool_arun_prefers_coroutine() {
        let tool = Tool::from_function_with_async(
            |input| Ok(format!("Sync: {}", input)),
            |input: String| async move { Ok::<String, Error>(format!("Async: {}", input)) },
            "dual_tool",
            "Has sync and async implementations",
        );

        let result = tool
            .arun(ToolInput::String("test".to_string()), None)
            .await
            .unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "Async: test"),
            _ => panic!("Expected String output"),
        }
    }
}