1use std::collections::HashMap;
7use std::fmt::Debug;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::Value;
14
15use crate::error::{Error, Result};
16use crate::runnables::RunnableConfig;
17
18use super::base::{
19 ArgsSchema, BaseTool, FILTERED_ARGS, HandleToolError, HandleValidationError, ResponseFormat,
20 ToolException, ToolInput, ToolOutput,
21};
22
23pub type StructuredToolFunc = Arc<dyn Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync>;
25
26pub type AsyncStructuredToolFunc = Arc<
28 dyn Fn(HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>
29 + Send
30 + Sync,
31>;
32
33pub struct StructuredTool {
38 name: String,
40 description: String,
42 func: Option<StructuredToolFunc>,
44 coroutine: Option<AsyncStructuredToolFunc>,
46 args_schema: ArgsSchema,
48 return_direct: bool,
50 verbose: bool,
52 handle_tool_error: HandleToolError,
54 handle_validation_error: HandleValidationError,
56 response_format: ResponseFormat,
58 tags: Option<Vec<String>>,
60 metadata: Option<HashMap<String, Value>>,
62 extras: Option<HashMap<String, Value>>,
64}
65
66impl Debug for StructuredTool {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("StructuredTool")
69 .field("name", &self.name)
70 .field("description", &self.description)
71 .field("args_schema", &self.args_schema)
72 .field("return_direct", &self.return_direct)
73 .field("response_format", &self.response_format)
74 .finish()
75 }
76}
77
78impl StructuredTool {
79 pub fn new(
81 name: impl Into<String>,
82 description: impl Into<String>,
83 args_schema: ArgsSchema,
84 ) -> Self {
85 Self {
86 name: name.into(),
87 description: description.into(),
88 func: None,
89 coroutine: None,
90 args_schema,
91 return_direct: false,
92 verbose: false,
93 handle_tool_error: HandleToolError::Bool(false),
94 handle_validation_error: HandleValidationError::Bool(false),
95 response_format: ResponseFormat::Content,
96 tags: None,
97 metadata: None,
98 extras: None,
99 }
100 }
101
102 pub fn with_func(mut self, func: StructuredToolFunc) -> Self {
104 self.func = Some(func);
105 self
106 }
107
108 pub fn with_coroutine(mut self, coroutine: AsyncStructuredToolFunc) -> Self {
110 self.coroutine = Some(coroutine);
111 self
112 }
113
114 pub fn with_return_direct(mut self, return_direct: bool) -> Self {
116 self.return_direct = return_direct;
117 self
118 }
119
120 pub fn with_response_format(mut self, format: ResponseFormat) -> Self {
122 self.response_format = format;
123 self
124 }
125
126 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
128 self.tags = Some(tags);
129 self
130 }
131
132 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
134 self.metadata = Some(metadata);
135 self
136 }
137
138 pub fn with_extras(mut self, extras: HashMap<String, Value>) -> Self {
140 self.extras = Some(extras);
141 self
142 }
143
144 pub fn with_handle_tool_error(mut self, handler: HandleToolError) -> Self {
146 self.handle_tool_error = handler;
147 self
148 }
149
150 pub fn with_handle_validation_error(mut self, handler: HandleValidationError) -> Self {
152 self.handle_validation_error = handler;
153 self
154 }
155
156 pub fn from_function<F>(
160 func: F,
161 name: impl Into<String>,
162 description: impl Into<String>,
163 args_schema: ArgsSchema,
164 ) -> Self
165 where
166 F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
167 {
168 Self::new(name, description, args_schema).with_func(Arc::new(func))
169 }
170
171 pub fn from_function_with_async<F, AF, Fut>(
173 func: F,
174 coroutine: AF,
175 name: impl Into<String>,
176 description: impl Into<String>,
177 args_schema: ArgsSchema,
178 ) -> Self
179 where
180 F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
181 AF: Fn(HashMap<String, Value>) -> Fut + Send + Sync + 'static,
182 Fut: Future<Output = Result<Value>> + Send + 'static,
183 {
184 Self::new(name, description, args_schema)
185 .with_func(Arc::new(func))
186 .with_coroutine(Arc::new(move |args| Box::pin(coroutine(args))))
187 }
188
189 pub fn from_async_function<AF, Fut>(
191 coroutine: AF,
192 name: impl Into<String>,
193 description: impl Into<String>,
194 args_schema: ArgsSchema,
195 ) -> Self
196 where
197 AF: Fn(HashMap<String, Value>) -> Fut + Send + Sync + 'static,
198 Fut: Future<Output = Result<Value>> + Send + 'static,
199 {
200 Self::new(name, description, args_schema)
201 .with_coroutine(Arc::new(move |args| Box::pin(coroutine(args))))
202 }
203
204 fn extract_args(&self, input: ToolInput) -> Result<HashMap<String, Value>> {
206 match input {
207 ToolInput::String(s) => {
208 if let Ok(Value::Object(obj)) = serde_json::from_str(&s) {
210 Ok(obj.into_iter().collect())
211 } else {
212 let props = self.args_schema.properties();
214 if props.len() == 1 {
215 let key = props.keys().next().unwrap().clone();
216 let mut args = HashMap::new();
217 args.insert(key, Value::String(s));
218 Ok(args)
219 } else {
220 Err(Error::ToolInvocation(
221 "String input not allowed for multi-argument tool".to_string(),
222 ))
223 }
224 }
225 }
226 ToolInput::Dict(d) => Ok(d),
227 ToolInput::ToolCall(tc) => {
228 let args = tc.args();
229 if let Some(obj) = args.as_object() {
230 Ok(obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
231 } else {
232 Err(Error::ToolInvocation(
233 "ToolCall args must be an object".to_string(),
234 ))
235 }
236 }
237 }
238 }
239
240 fn filter_args(&self, args: HashMap<String, Value>) -> HashMap<String, Value> {
242 args.into_iter()
243 .filter(|(k, _)| !FILTERED_ARGS.contains(&k.as_str()))
244 .collect()
245 }
246}
247
248#[async_trait]
249impl BaseTool for StructuredTool {
250 fn name(&self) -> &str {
251 &self.name
252 }
253
254 fn description(&self) -> &str {
255 &self.description
256 }
257
258 fn args_schema(&self) -> Option<&ArgsSchema> {
259 Some(&self.args_schema)
260 }
261
262 fn return_direct(&self) -> bool {
263 self.return_direct
264 }
265
266 fn verbose(&self) -> bool {
267 self.verbose
268 }
269
270 fn tags(&self) -> Option<&[String]> {
271 self.tags.as_deref()
272 }
273
274 fn metadata(&self) -> Option<&HashMap<String, Value>> {
275 self.metadata.as_ref()
276 }
277
278 fn handle_tool_error(&self) -> &HandleToolError {
279 &self.handle_tool_error
280 }
281
282 fn handle_validation_error(&self) -> &HandleValidationError {
283 &self.handle_validation_error
284 }
285
286 fn response_format(&self) -> ResponseFormat {
287 self.response_format
288 }
289
290 fn extras(&self) -> Option<&HashMap<String, Value>> {
291 self.extras.as_ref()
292 }
293
294 fn run(&self, input: ToolInput, _config: Option<RunnableConfig>) -> Result<ToolOutput> {
295 let args = self.extract_args(input)?;
296 let filtered_args = self.filter_args(args);
297
298 if let Some(ref func) = self.func {
299 match func(filtered_args) {
300 Ok(result) => {
301 match self.response_format {
302 ResponseFormat::Content => match result {
303 Value::String(s) => Ok(ToolOutput::String(s)),
304 other => Ok(ToolOutput::Json(other)),
305 },
306 ResponseFormat::ContentAndArtifact => {
307 if let Value::Array(arr) = result {
309 if arr.len() == 2 {
310 Ok(ToolOutput::ContentAndArtifact {
311 content: arr[0].clone(),
312 artifact: arr[1].clone(),
313 })
314 } else {
315 Err(Error::ToolInvocation(
316 "content_and_artifact response must be a 2-tuple"
317 .to_string(),
318 ))
319 }
320 } else {
321 Err(Error::ToolInvocation(
322 "content_and_artifact response must be a 2-tuple".to_string(),
323 ))
324 }
325 }
326 }
327 }
328 Err(e) => {
329 if let Error::ToolInvocation(msg) = &e {
330 let exc = ToolException::new(msg.clone());
331 if let Some(handled) =
332 super::base::handle_tool_error_impl(&exc, &self.handle_tool_error)
333 {
334 return Ok(ToolOutput::String(handled));
335 }
336 }
337 Err(e)
338 }
339 }
340 } else {
341 Err(Error::ToolInvocation(
342 "StructuredTool does not support sync invocation.".to_string(),
343 ))
344 }
345 }
346
347 async fn arun(&self, input: ToolInput, config: Option<RunnableConfig>) -> Result<ToolOutput> {
348 let args = self.extract_args(input.clone())?;
349 let filtered_args = self.filter_args(args);
350
351 if let Some(ref coroutine) = self.coroutine {
352 match coroutine(filtered_args).await {
353 Ok(result) => match self.response_format {
354 ResponseFormat::Content => match result {
355 Value::String(s) => Ok(ToolOutput::String(s)),
356 other => Ok(ToolOutput::Json(other)),
357 },
358 ResponseFormat::ContentAndArtifact => {
359 if let Value::Array(arr) = result {
360 if arr.len() == 2 {
361 Ok(ToolOutput::ContentAndArtifact {
362 content: arr[0].clone(),
363 artifact: arr[1].clone(),
364 })
365 } else {
366 Err(Error::ToolInvocation(
367 "content_and_artifact response must be a 2-tuple".to_string(),
368 ))
369 }
370 } else {
371 Err(Error::ToolInvocation(
372 "content_and_artifact response must be a 2-tuple".to_string(),
373 ))
374 }
375 }
376 },
377 Err(e) => {
378 if let Error::ToolInvocation(msg) = &e {
379 let exc = ToolException::new(msg.clone());
380 if let Some(handled) =
381 super::base::handle_tool_error_impl(&exc, &self.handle_tool_error)
382 {
383 return Ok(ToolOutput::String(handled));
384 }
385 }
386 Err(e)
387 }
388 }
389 } else {
390 self.run(input, config)
392 }
393 }
394}
395
396pub struct StructuredToolBuilder {
398 name: Option<String>,
399 description: Option<String>,
400 func: Option<StructuredToolFunc>,
401 coroutine: Option<AsyncStructuredToolFunc>,
402 args_schema: Option<ArgsSchema>,
403 return_direct: bool,
404 response_format: ResponseFormat,
405 parse_docstring: bool,
406 error_on_invalid_docstring: bool,
407 tags: Option<Vec<String>>,
408 metadata: Option<HashMap<String, Value>>,
409 extras: Option<HashMap<String, Value>>,
410}
411
412impl StructuredToolBuilder {
413 pub fn new() -> Self {
415 Self {
416 name: None,
417 description: None,
418 func: None,
419 coroutine: None,
420 args_schema: None,
421 return_direct: false,
422 response_format: ResponseFormat::Content,
423 parse_docstring: false,
424 error_on_invalid_docstring: false,
425 tags: None,
426 metadata: None,
427 extras: None,
428 }
429 }
430
431 pub fn name(mut self, name: impl Into<String>) -> Self {
433 self.name = Some(name.into());
434 self
435 }
436
437 pub fn description(mut self, description: impl Into<String>) -> Self {
439 self.description = Some(description.into());
440 self
441 }
442
443 pub fn func<F>(mut self, func: F) -> Self
445 where
446 F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
447 {
448 self.func = Some(Arc::new(func));
449 self
450 }
451
452 pub fn coroutine<AF, Fut>(mut self, coroutine: AF) -> Self
454 where
455 AF: Fn(HashMap<String, Value>) -> Fut + Send + Sync + 'static,
456 Fut: Future<Output = Result<Value>> + Send + 'static,
457 {
458 self.coroutine = Some(Arc::new(move |args| Box::pin(coroutine(args))));
459 self
460 }
461
462 pub fn args_schema(mut self, schema: ArgsSchema) -> Self {
464 self.args_schema = Some(schema);
465 self
466 }
467
468 pub fn return_direct(mut self, return_direct: bool) -> Self {
470 self.return_direct = return_direct;
471 self
472 }
473
474 pub fn response_format(mut self, format: ResponseFormat) -> Self {
476 self.response_format = format;
477 self
478 }
479
480 pub fn parse_docstring(mut self, parse: bool) -> Self {
482 self.parse_docstring = parse;
483 self
484 }
485
486 pub fn error_on_invalid_docstring(mut self, error: bool) -> Self {
488 self.error_on_invalid_docstring = error;
489 self
490 }
491
492 pub fn tags(mut self, tags: Vec<String>) -> Self {
494 self.tags = Some(tags);
495 self
496 }
497
498 pub fn metadata(mut self, metadata: HashMap<String, Value>) -> Self {
500 self.metadata = Some(metadata);
501 self
502 }
503
504 pub fn extras(mut self, extras: HashMap<String, Value>) -> Self {
506 self.extras = Some(extras);
507 self
508 }
509
510 pub fn build(self) -> Result<StructuredTool> {
512 let name = self
513 .name
514 .ok_or_else(|| Error::InvalidConfig("Tool name is required".to_string()))?;
515 let description = self.description.unwrap_or_default();
516 let args_schema = self.args_schema.unwrap_or_default();
517
518 if self.func.is_none() && self.coroutine.is_none() {
519 return Err(Error::InvalidConfig(
520 "Function and/or coroutine must be provided".to_string(),
521 ));
522 }
523
524 Ok(StructuredTool {
525 name,
526 description,
527 func: self.func,
528 coroutine: self.coroutine,
529 args_schema,
530 return_direct: self.return_direct,
531 verbose: false,
532 handle_tool_error: HandleToolError::Bool(false),
533 handle_validation_error: HandleValidationError::Bool(false),
534 response_format: self.response_format,
535 tags: self.tags,
536 metadata: self.metadata,
537 extras: self.extras,
538 })
539 }
540}
541
542impl Default for StructuredToolBuilder {
543 fn default() -> Self {
544 Self::new()
545 }
546}
547
548pub fn create_args_schema(
550 name: &str,
551 properties: HashMap<String, Value>,
552 required: Vec<String>,
553 description: Option<&str>,
554) -> ArgsSchema {
555 let mut schema = serde_json::json!({
556 "type": "object",
557 "title": name,
558 "properties": properties,
559 "required": required,
560 });
561
562 if let Some(desc) = description {
563 schema["description"] = Value::String(desc.to_string());
564 }
565
566 ArgsSchema::JsonSchema(schema)
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a schema in which every listed field is a required string.
    fn string_schema(name: &str, fields: &[&str]) -> ArgsSchema {
        let props = fields
            .iter()
            .map(|field| (field.to_string(), serde_json::json!({"type": "string"})))
            .collect();
        let required = fields.iter().map(|field| field.to_string()).collect();
        create_args_schema(name, props, required, None)
    }

    #[test]
    fn test_structured_tool_creation() {
        let schema = create_args_schema(
            "add_numbers",
            {
                let mut props = HashMap::new();
                props.insert("a".to_string(), serde_json::json!({"type": "number"}));
                props.insert("b".to_string(), serde_json::json!({"type": "number"}));
                props
            },
            vec!["a".to_string(), "b".to_string()],
            Some("Add two numbers"),
        );

        let tool = StructuredTool::from_function(
            |args| {
                let a = args.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0);
                let b = args.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0);
                Ok(Value::from(a + b))
            },
            "add",
            "Adds two numbers together",
            schema,
        );

        assert_eq!(tool.name(), "add");
        assert_eq!(tool.description(), "Adds two numbers together");
    }

    #[test]
    fn test_structured_tool_run() {
        let schema = create_args_schema(
            "multiply",
            {
                let mut props = HashMap::new();
                props.insert("x".to_string(), serde_json::json!({"type": "number"}));
                props.insert("y".to_string(), serde_json::json!({"type": "number"}));
                props
            },
            vec!["x".to_string(), "y".to_string()],
            None,
        );

        let tool = StructuredTool::from_function(
            |args| {
                let x = args.get("x").and_then(|v| v.as_f64()).unwrap_or(0.0);
                let y = args.get("y").and_then(|v| v.as_f64()).unwrap_or(0.0);
                Ok(Value::from(x * y))
            },
            "multiply",
            "Multiplies two numbers",
            schema,
        );

        let mut input = HashMap::new();
        input.insert("x".to_string(), Value::from(3.0));
        input.insert("y".to_string(), Value::from(4.0));

        let result = tool.run(ToolInput::Dict(input), None).unwrap();
        match result {
            ToolOutput::Json(v) => assert_eq!(v.as_f64().unwrap(), 12.0),
            _ => panic!("Expected Json output"),
        }
    }

    #[test]
    fn test_structured_tool_builder() {
        let tool = StructuredToolBuilder::new()
            .name("greet")
            .description("Greets a person")
            .args_schema(create_args_schema(
                "greet",
                {
                    let mut props = HashMap::new();
                    props.insert("name".to_string(), serde_json::json!({"type": "string"}));
                    props
                },
                vec!["name".to_string()],
                None,
            ))
            .func(|args| {
                let name = args
                    .get("name")
                    .and_then(|v| v.as_str())
                    .unwrap_or("stranger");
                Ok(Value::String(format!("Hello, {}!", name)))
            })
            .return_direct(true)
            .build()
            .unwrap();

        assert_eq!(tool.name(), "greet");
        assert!(tool.return_direct());
    }

    /// `build` must reject a missing name and a missing implementation.
    #[test]
    fn test_structured_tool_builder_validation() {
        let missing_name = StructuredToolBuilder::new()
            .func(|_args| Ok(Value::Null))
            .build();
        assert!(matches!(missing_name, Err(Error::InvalidConfig(_))));

        let missing_impl = StructuredToolBuilder::new().name("noop").build();
        assert!(matches!(missing_impl, Err(Error::InvalidConfig(_))));
    }

    #[test]
    fn test_create_args_schema() {
        let schema = create_args_schema(
            "test_schema",
            {
                let mut props = HashMap::new();
                props.insert("field1".to_string(), serde_json::json!({"type": "string"}));
                props
            },
            vec!["field1".to_string()],
            Some("Test description"),
        );

        let json = schema.to_json_schema();
        assert_eq!(json["title"], "test_schema");
        assert_eq!(json["description"], "Test description");
        assert!(json["properties"]["field1"].is_object());
    }

    /// A plain (non-JSON) string is mapped onto the only declared property.
    #[test]
    fn test_structured_tool_run_single_arg_string_input() {
        let tool = StructuredTool::from_function(
            |args| {
                let name = args
                    .get("name")
                    .and_then(|v| v.as_str())
                    .unwrap_or("stranger");
                Ok(Value::String(format!("Hello, {}!", name)))
            },
            "greet",
            "Greets a person",
            string_schema("greet", &["name"]),
        );

        match tool.run(ToolInput::String("Alice".to_string()), None).unwrap() {
            ToolOutput::String(s) => assert_eq!(s, "Hello, Alice!"),
            _ => panic!("Expected String output"),
        }
    }

    /// A plain string cannot be mapped onto a tool with several arguments.
    #[test]
    fn test_structured_tool_run_rejects_plain_string_for_multi_arg_tool() {
        let tool = StructuredTool::from_function(
            |_args| Ok(Value::Null),
            "concat",
            "Concatenates two strings",
            string_schema("concat", &["a", "b"]),
        );

        assert!(tool
            .run(ToolInput::String("not json".to_string()), None)
            .is_err());
    }

    #[test]
    fn test_structured_tool_content_and_artifact() {
        let tool = StructuredTool::from_function(
            |_args| Ok(serde_json::json!(["summary", {"rows": 3}])),
            "report",
            "Produces a report",
            string_schema("report", &[]),
        )
        .with_response_format(ResponseFormat::ContentAndArtifact);

        match tool.run(ToolInput::Dict(HashMap::new()), None).unwrap() {
            ToolOutput::ContentAndArtifact { content, artifact } => {
                assert_eq!(content, "summary");
                assert_eq!(artifact["rows"], 3);
            }
            _ => panic!("Expected ContentAndArtifact output"),
        }
    }

    #[test]
    fn test_structured_tool_content_and_artifact_requires_pair() {
        let tool = StructuredTool::from_function(
            |_args| Ok(Value::String("only content".to_string())),
            "report",
            "Produces a report",
            string_schema("report", &[]),
        )
        .with_response_format(ResponseFormat::ContentAndArtifact);

        assert!(tool.run(ToolInput::Dict(HashMap::new()), None).is_err());
    }

    /// Sync-only tool: `arun` falls back to the synchronous `func`.
    #[tokio::test]
    async fn test_structured_tool_arun() {
        let schema = create_args_schema(
            "concat",
            {
                let mut props = HashMap::new();
                props.insert("a".to_string(), serde_json::json!({"type": "string"}));
                props.insert("b".to_string(), serde_json::json!({"type": "string"}));
                props
            },
            vec!["a".to_string(), "b".to_string()],
            None,
        );

        let tool = StructuredTool::from_function(
            |args| {
                let a = args.get("a").and_then(|v| v.as_str()).unwrap_or("");
                let b = args.get("b").and_then(|v| v.as_str()).unwrap_or("");
                Ok(Value::String(format!("{}{}", a, b)))
            },
            "concat",
            "Concatenates two strings",
            schema,
        );

        let mut input = HashMap::new();
        input.insert("a".to_string(), Value::String("Hello".to_string()));
        input.insert("b".to_string(), Value::String("World".to_string()));

        let result = tool.arun(ToolInput::Dict(input), None).await.unwrap();
        match result {
            ToolOutput::String(s) => assert_eq!(s, "HelloWorld"),
            _ => panic!("Expected String output"),
        }
    }

    /// Async-only tool: `arun` awaits the coroutine, and `run` refuses to execute.
    #[tokio::test]
    async fn test_structured_tool_arun_with_coroutine() {
        let tool = StructuredTool::from_async_function(
            |args: HashMap<String, Value>| async move {
                let a = args.get("a").and_then(|v| v.as_str()).unwrap_or("");
                let b = args.get("b").and_then(|v| v.as_str()).unwrap_or("");
                Ok::<Value, Error>(Value::String(format!("{}-{}", a, b)))
            },
            "join",
            "Joins two strings",
            string_schema("join", &["a", "b"]),
        );

        let mut input = HashMap::new();
        input.insert("a".to_string(), Value::String("x".to_string()));
        input.insert("b".to_string(), Value::String("y".to_string()));

        match tool.arun(ToolInput::Dict(input.clone()), None).await.unwrap() {
            ToolOutput::String(s) => assert_eq!(s, "x-y"),
            _ => panic!("Expected String output"),
        }
        assert!(tool.run(ToolInput::Dict(input), None).is_err());
    }
}