agent_chain_core/tools/
structured.rs

1//! Structured tool that can operate on any number of inputs.
2//!
3//! This module provides the `StructuredTool` struct for creating tools
4//! that accept multiple typed arguments, mirroring `langchain_core.tools.structured`.
5
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::Value;
14
15use crate::error::{Error, Result};
16use crate::runnables::RunnableConfig;
17
18use super::base::{
19    ArgsSchema, BaseTool, FILTERED_ARGS, HandleToolError, HandleValidationError, ResponseFormat,
20    ToolException, ToolInput, ToolOutput,
21};
22
23/// Type alias for sync structured tool function.
24pub type StructuredToolFunc = Arc<dyn Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync>;
25
26/// Type alias for async structured tool function.
27pub type AsyncStructuredToolFunc = Arc<
28    dyn Fn(HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
29        + Send
30        + Sync,
31>;
32
33/// Tool that can operate on any number of inputs.
34///
35/// Unlike `Tool`, which accepts a single string input, `StructuredTool`
36/// accepts a dictionary of typed arguments.
37pub struct StructuredTool {
38    /// The unique name of the tool.
39    name: String,
40    /// A description of what the tool does.
41    description: String,
42    /// The function to run when the tool is called.
43    func: Option<StructuredToolFunc>,
44    /// The asynchronous version of the function.
45    coroutine: Option<AsyncStructuredToolFunc>,
46    /// The input arguments' schema.
47    args_schema: ArgsSchema,
48    /// Whether to return the tool's output directly.
49    return_direct: bool,
50    /// Whether to log the tool's progress.
51    verbose: bool,
52    /// How to handle tool errors.
53    handle_tool_error: HandleToolError,
54    /// How to handle validation errors.
55    handle_validation_error: HandleValidationError,
56    /// The tool response format.
57    response_format: ResponseFormat,
58    /// Optional tags for the tool.
59    tags: Option<Vec<String>>,
60    /// Optional metadata for the tool.
61    metadata: Option<HashMap<String, Value>>,
62    /// Optional provider-specific extras.
63    extras: Option<HashMap<String, Value>>,
64}
65
66impl Debug for StructuredTool {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("StructuredTool")
69            .field("name", &self.name)
70            .field("description", &self.description)
71            .field("args_schema", &self.args_schema)
72            .field("return_direct", &self.return_direct)
73            .field("response_format", &self.response_format)
74            .finish()
75    }
76}
77
78impl StructuredTool {
79    /// Create a new StructuredTool.
80    pub fn new(
81        name: impl Into<String>,
82        description: impl Into<String>,
83        args_schema: ArgsSchema,
84    ) -> Self {
85        Self {
86            name: name.into(),
87            description: description.into(),
88            func: None,
89            coroutine: None,
90            args_schema,
91            return_direct: false,
92            verbose: false,
93            handle_tool_error: HandleToolError::Bool(false),
94            handle_validation_error: HandleValidationError::Bool(false),
95            response_format: ResponseFormat::Content,
96            tags: None,
97            metadata: None,
98            extras: None,
99        }
100    }
101
102    /// Set the sync function.
103    pub fn with_func(mut self, func: StructuredToolFunc) -> Self {
104        self.func = Some(func);
105        self
106    }
107
108    /// Set the async function.
109    pub fn with_coroutine(mut self, coroutine: AsyncStructuredToolFunc) -> Self {
110        self.coroutine = Some(coroutine);
111        self
112    }
113
114    /// Set whether to return directly.
115    pub fn with_return_direct(mut self, return_direct: bool) -> Self {
116        self.return_direct = return_direct;
117        self
118    }
119
120    /// Set the response format.
121    pub fn with_response_format(mut self, format: ResponseFormat) -> Self {
122        self.response_format = format;
123        self
124    }
125
126    /// Set tags.
127    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
128        self.tags = Some(tags);
129        self
130    }
131
132    /// Set metadata.
133    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
134        self.metadata = Some(metadata);
135        self
136    }
137
138    /// Set extras.
139    pub fn with_extras(mut self, extras: HashMap<String, Value>) -> Self {
140        self.extras = Some(extras);
141        self
142    }
143
144    /// Set handle_tool_error.
145    pub fn with_handle_tool_error(mut self, handler: HandleToolError) -> Self {
146        self.handle_tool_error = handler;
147        self
148    }
149
150    /// Set handle_validation_error.
151    pub fn with_handle_validation_error(mut self, handler: HandleValidationError) -> Self {
152        self.handle_validation_error = handler;
153        self
154    }
155
156    /// Create a tool from a function.
157    ///
158    /// This is the main way to create a StructuredTool.
159    pub fn from_function<F>(
160        func: F,
161        name: impl Into<String>,
162        description: impl Into<String>,
163        args_schema: ArgsSchema,
164    ) -> Self
165    where
166        F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
167    {
168        Self::new(name, description, args_schema).with_func(Arc::new(func))
169    }
170
171    /// Create a tool from a sync and async function pair.
172    pub fn from_function_with_async<F, AF, Fut>(
173        func: F,
174        coroutine: AF,
175        name: impl Into<String>,
176        description: impl Into<String>,
177        args_schema: ArgsSchema,
178    ) -> Self
179    where
180        F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
181        AF: Fn(HashMap<String, Value>) -> Fut + Send + Sync + 'static,
182        Fut: Future<Output = Result<Value>> + Send + 'static,
183    {
184        Self::new(name, description, args_schema)
185            .with_func(Arc::new(func))
186            .with_coroutine(Arc::new(move |args| Box::pin(coroutine(args))))
187    }
188
189    /// Create a tool from an async function only.
190    pub fn from_async_function<AF, Fut>(
191        coroutine: AF,
192        name: impl Into<String>,
193        description: impl Into<String>,
194        args_schema: ArgsSchema,
195    ) -> Self
196    where
197        AF: Fn(HashMap<String, Value>) -> Fut + Send + Sync + 'static,
198        Fut: Future<Output = Result<Value>> + Send + 'static,
199    {
200        Self::new(name, description, args_schema)
201            .with_coroutine(Arc::new(move |args| Box::pin(coroutine(args))))
202    }
203
204    /// Extract the arguments from the tool input.
205    fn extract_args(&self, input: ToolInput) -> Result<HashMap<String, Value>> {
206        match input {
207            ToolInput::String(s) => {
208                // Try to parse as JSON
209                if let Ok(Value::Object(obj)) = serde_json::from_str(&s) {
210                    Ok(obj.into_iter().collect())
211                } else {
212                    // Use as single argument if schema has one field
213                    let props = self.args_schema.properties();
214                    if props.len() == 1 {
215                        let key = props.keys().next().unwrap().clone();
216                        let mut args = HashMap::new();
217                        args.insert(key, Value::String(s));
218                        Ok(args)
219                    } else {
220                        Err(Error::ToolInvocation(
221                            "String input not allowed for multi-argument tool".to_string(),
222                        ))
223                    }
224                }
225            }
226            ToolInput::Dict(d) => Ok(d),
227            ToolInput::ToolCall(tc) => {
228                let args = tc.args();
229                if let Some(obj) = args.as_object() {
230                    Ok(obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
231                } else {
232                    Err(Error::ToolInvocation(
233                        "ToolCall args must be an object".to_string(),
234                    ))
235                }
236            }
237        }
238    }
239
240    /// Filter out arguments that shouldn't be passed to the function.
241    fn filter_args(&self, args: HashMap<String, Value>) -> HashMap<String, Value> {
242        args.into_iter()
243            .filter(|(k, _)| !FILTERED_ARGS.contains(&k.as_str()))
244            .collect()
245    }
246}
247
248#[async_trait]
249impl BaseTool for StructuredTool {
250    fn name(&self) -> &str {
251        &self.name
252    }
253
254    fn description(&self) -> &str {
255        &self.description
256    }
257
258    fn args_schema(&self) -> Option<&ArgsSchema> {
259        Some(&self.args_schema)
260    }
261
262    fn return_direct(&self) -> bool {
263        self.return_direct
264    }
265
266    fn verbose(&self) -> bool {
267        self.verbose
268    }
269
270    fn tags(&self) -> Option<&[String]> {
271        self.tags.as_deref()
272    }
273
274    fn metadata(&self) -> Option<&HashMap<String, Value>> {
275        self.metadata.as_ref()
276    }
277
278    fn handle_tool_error(&self) -> &HandleToolError {
279        &self.handle_tool_error
280    }
281
282    fn handle_validation_error(&self) -> &HandleValidationError {
283        &self.handle_validation_error
284    }
285
286    fn response_format(&self) -> ResponseFormat {
287        self.response_format
288    }
289
290    fn extras(&self) -> Option<&HashMap<String, Value>> {
291        self.extras.as_ref()
292    }
293
294    fn run(&self, input: ToolInput, _config: Option<RunnableConfig>) -> Result<ToolOutput> {
295        let args = self.extract_args(input)?;
296        let filtered_args = self.filter_args(args);
297
298        if let Some(ref func) = self.func {
299            match func(filtered_args) {
300                Ok(result) => {
301                    match self.response_format {
302                        ResponseFormat::Content => match result {
303                            Value::String(s) => Ok(ToolOutput::String(s)),
304                            other => Ok(ToolOutput::Json(other)),
305                        },
306                        ResponseFormat::ContentAndArtifact => {
307                            // Expect a tuple [content, artifact]
308                            if let Value::Array(arr) = result {
309                                if arr.len() == 2 {
310                                    Ok(ToolOutput::ContentAndArtifact {
311                                        content: arr[0].clone(),
312                                        artifact: arr[1].clone(),
313                                    })
314                                } else {
315                                    Err(Error::ToolInvocation(
316                                        "content_and_artifact response must be a 2-tuple"
317                                            .to_string(),
318                                    ))
319                                }
320                            } else {
321                                Err(Error::ToolInvocation(
322                                    "content_and_artifact response must be a 2-tuple".to_string(),
323                                ))
324                            }
325                        }
326                    }
327                }
328                Err(e) => {
329                    if let Error::ToolInvocation(msg) = &e {
330                        let exc = ToolException::new(msg.clone());
331                        if let Some(handled) =
332                            super::base::handle_tool_error_impl(&exc, &self.handle_tool_error)
333                        {
334                            return Ok(ToolOutput::String(handled));
335                        }
336                    }
337                    Err(e)
338                }
339            }
340        } else {
341            Err(Error::ToolInvocation(
342                "StructuredTool does not support sync invocation.".to_string(),
343            ))
344        }
345    }
346
347    async fn arun(&self, input: ToolInput, config: Option<RunnableConfig>) -> Result<ToolOutput> {
348        let args = self.extract_args(input.clone())?;
349        let filtered_args = self.filter_args(args);
350
351        if let Some(ref coroutine) = self.coroutine {
352            match coroutine(filtered_args).await {
353                Ok(result) => match self.response_format {
354                    ResponseFormat::Content => match result {
355                        Value::String(s) => Ok(ToolOutput::String(s)),
356                        other => Ok(ToolOutput::Json(other)),
357                    },
358                    ResponseFormat::ContentAndArtifact => {
359                        if let Value::Array(arr) = result {
360                            if arr.len() == 2 {
361                                Ok(ToolOutput::ContentAndArtifact {
362                                    content: arr[0].clone(),
363                                    artifact: arr[1].clone(),
364                                })
365                            } else {
366                                Err(Error::ToolInvocation(
367                                    "content_and_artifact response must be a 2-tuple".to_string(),
368                                ))
369                            }
370                        } else {
371                            Err(Error::ToolInvocation(
372                                "content_and_artifact response must be a 2-tuple".to_string(),
373                            ))
374                        }
375                    }
376                },
377                Err(e) => {
378                    if let Error::ToolInvocation(msg) = &e {
379                        let exc = ToolException::new(msg.clone());
380                        if let Some(handled) =
381                            super::base::handle_tool_error_impl(&exc, &self.handle_tool_error)
382                        {
383                            return Ok(ToolOutput::String(handled));
384                        }
385                    }
386                    Err(e)
387                }
388            }
389        } else {
390            // Fall back to sync implementation
391            self.run(input, config)
392        }
393    }
394}
395
396/// Builder for creating StructuredTool instances.
397pub struct StructuredToolBuilder {
398    name: Option<String>,
399    description: Option<String>,
400    func: Option<StructuredToolFunc>,
401    coroutine: Option<AsyncStructuredToolFunc>,
402    args_schema: Option<ArgsSchema>,
403    return_direct: bool,
404    response_format: ResponseFormat,
405    parse_docstring: bool,
406    error_on_invalid_docstring: bool,
407    tags: Option<Vec<String>>,
408    metadata: Option<HashMap<String, Value>>,
409    extras: Option<HashMap<String, Value>>,
410}
411
412impl StructuredToolBuilder {
413    /// Create a new StructuredToolBuilder.
414    pub fn new() -> Self {
415        Self {
416            name: None,
417            description: None,
418            func: None,
419            coroutine: None,
420            args_schema: None,
421            return_direct: false,
422            response_format: ResponseFormat::Content,
423            parse_docstring: false,
424            error_on_invalid_docstring: false,
425            tags: None,
426            metadata: None,
427            extras: None,
428        }
429    }
430
431    /// Set the name.
432    pub fn name(mut self, name: impl Into<String>) -> Self {
433        self.name = Some(name.into());
434        self
435    }
436
437    /// Set the description.
438    pub fn description(mut self, description: impl Into<String>) -> Self {
439        self.description = Some(description.into());
440        self
441    }
442
443    /// Set the sync function.
444    pub fn func<F>(mut self, func: F) -> Self
445    where
446        F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
447    {
448        self.func = Some(Arc::new(func));
449        self
450    }
451
452    /// Set the async function.
453    pub fn coroutine<AF, Fut>(mut self, coroutine: AF) -> Self
454    where
455        AF: Fn(HashMap<String, Value>) -> Fut + Send + Sync + 'static,
456        Fut: Future<Output = Result<Value>> + Send + 'static,
457    {
458        self.coroutine = Some(Arc::new(move |args| Box::pin(coroutine(args))));
459        self
460    }
461
462    /// Set the args schema.
463    pub fn args_schema(mut self, schema: ArgsSchema) -> Self {
464        self.args_schema = Some(schema);
465        self
466    }
467
468    /// Set return_direct.
469    pub fn return_direct(mut self, return_direct: bool) -> Self {
470        self.return_direct = return_direct;
471        self
472    }
473
474    /// Set the response format.
475    pub fn response_format(mut self, format: ResponseFormat) -> Self {
476        self.response_format = format;
477        self
478    }
479
480    /// Set parse_docstring.
481    pub fn parse_docstring(mut self, parse: bool) -> Self {
482        self.parse_docstring = parse;
483        self
484    }
485
486    /// Set error_on_invalid_docstring.
487    pub fn error_on_invalid_docstring(mut self, error: bool) -> Self {
488        self.error_on_invalid_docstring = error;
489        self
490    }
491
492    /// Set tags.
493    pub fn tags(mut self, tags: Vec<String>) -> Self {
494        self.tags = Some(tags);
495        self
496    }
497
498    /// Set metadata.
499    pub fn metadata(mut self, metadata: HashMap<String, Value>) -> Self {
500        self.metadata = Some(metadata);
501        self
502    }
503
504    /// Set extras.
505    pub fn extras(mut self, extras: HashMap<String, Value>) -> Self {
506        self.extras = Some(extras);
507        self
508    }
509
510    /// Build the StructuredTool.
511    pub fn build(self) -> Result<StructuredTool> {
512        let name = self
513            .name
514            .ok_or_else(|| Error::InvalidConfig("Tool name is required".to_string()))?;
515        let description = self.description.unwrap_or_default();
516        let args_schema = self.args_schema.unwrap_or_default();
517
518        if self.func.is_none() && self.coroutine.is_none() {
519            return Err(Error::InvalidConfig(
520                "Function and/or coroutine must be provided".to_string(),
521            ));
522        }
523
524        Ok(StructuredTool {
525            name,
526            description,
527            func: self.func,
528            coroutine: self.coroutine,
529            args_schema,
530            return_direct: self.return_direct,
531            verbose: false,
532            handle_tool_error: HandleToolError::Bool(false),
533            handle_validation_error: HandleValidationError::Bool(false),
534            response_format: self.response_format,
535            tags: self.tags,
536            metadata: self.metadata,
537            extras: self.extras,
538        })
539    }
540}
541
542impl Default for StructuredToolBuilder {
543    fn default() -> Self {
544        Self::new()
545    }
546}
547
548/// Helper function to create an args schema from field definitions.
549pub fn create_args_schema(
550    name: &str,
551    properties: HashMap<String, Value>,
552    required: Vec<String>,
553    description: Option<&str>,
554) -> ArgsSchema {
555    let mut schema = serde_json::json!({
556        "type": "object",
557        "title": name,
558        "properties": properties,
559        "required": required,
560    });
561
562    if let Some(desc) = description {
563        schema["description"] = Value::String(desc.to_string());
564    }
565
566    ArgsSchema::JsonSchema(schema)
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Property map in which every listed field is declared with `json_type`.
    fn props_of(json_type: &str, fields: &[&str]) -> HashMap<String, Value> {
        fields
            .iter()
            .map(|field| ((*field).to_string(), serde_json::json!({ "type": json_type })))
            .collect()
    }

    /// Owned copies of the given field names, for a schema's `required` list.
    fn required(fields: &[&str]) -> Vec<String> {
        fields.iter().map(|field| (*field).to_string()).collect()
    }

    /// Numeric argument lookup that falls back to `0.0`.
    fn number_arg(args: &HashMap<String, Value>, key: &str) -> f64 {
        args.get(key).and_then(Value::as_f64).unwrap_or(0.0)
    }

    /// A tool built with `from_function` reports the name and description it was given.
    #[test]
    fn test_structured_tool_creation() {
        let fields = ["a", "b"];
        let schema = create_args_schema(
            "add_numbers",
            props_of("number", &fields),
            required(&fields),
            Some("Add two numbers"),
        );

        let tool = StructuredTool::from_function(
            |args| Ok(Value::from(number_arg(&args, "a") + number_arg(&args, "b"))),
            "add",
            "Adds two numbers together",
            schema,
        );

        assert_eq!(tool.name(), "add");
        assert_eq!(tool.description(), "Adds two numbers together");
    }

    /// Running a tool with dict input calls the function and yields JSON output.
    #[test]
    fn test_structured_tool_run() {
        let fields = ["x", "y"];
        let schema = create_args_schema(
            "multiply",
            props_of("number", &fields),
            required(&fields),
            None,
        );

        let tool = StructuredTool::from_function(
            |args| Ok(Value::from(number_arg(&args, "x") * number_arg(&args, "y"))),
            "multiply",
            "Multiplies two numbers",
            schema,
        );

        let input = HashMap::from([
            ("x".to_string(), Value::from(3.0)),
            ("y".to_string(), Value::from(4.0)),
        ]);

        let product = match tool.run(ToolInput::Dict(input), None).unwrap() {
            ToolOutput::Json(value) => value,
            _ => panic!("Expected Json output"),
        };
        assert_eq!(product.as_f64().unwrap(), 12.0);
    }

    /// The builder produces a tool carrying the configured name and `return_direct`.
    #[test]
    fn test_structured_tool_builder() {
        let fields = ["name"];
        let schema = create_args_schema(
            "greet",
            props_of("string", &fields),
            required(&fields),
            None,
        );

        let tool = StructuredToolBuilder::new()
            .name("greet")
            .description("Greets a person")
            .args_schema(schema)
            .func(|args| {
                let name = args
                    .get("name")
                    .and_then(Value::as_str)
                    .unwrap_or("stranger");
                Ok(Value::String(format!("Hello, {}!", name)))
            })
            .return_direct(true)
            .build()
            .unwrap();

        assert_eq!(tool.name(), "greet");
        assert!(tool.return_direct());
    }

    /// `create_args_schema` exposes title, description and properties in its JSON schema.
    #[test]
    fn test_create_args_schema() {
        let fields = ["field1"];
        let schema = create_args_schema(
            "test_schema",
            props_of("string", &fields),
            required(&fields),
            Some("Test description"),
        );

        let json = schema.to_json_schema();
        assert_eq!(json["title"], "test_schema");
        assert_eq!(json["description"], "Test description");
        assert!(json["properties"]["field1"].is_object());
    }

    /// `arun` on a sync-only tool falls back to the sync function.
    #[tokio::test]
    async fn test_structured_tool_arun() {
        let fields = ["a", "b"];
        let schema = create_args_schema(
            "concat",
            props_of("string", &fields),
            required(&fields),
            None,
        );

        let tool = StructuredTool::from_function(
            |args| {
                let left = args.get("a").and_then(Value::as_str).unwrap_or("");
                let right = args.get("b").and_then(Value::as_str).unwrap_or("");
                Ok(Value::String(format!("{}{}", left, right)))
            },
            "concat",
            "Concatenates two strings",
            schema,
        );

        let input = HashMap::from([
            ("a".to_string(), Value::String("Hello".to_string())),
            ("b".to_string(), Value::String("World".to_string())),
        ]);

        let joined = match tool.arun(ToolInput::Dict(input), None).await.unwrap() {
            ToolOutput::String(text) => text,
            _ => panic!("Expected String output"),
        };
        assert_eq!(joined, "HelloWorld");
    }
}