1use std::collections::HashMap;
7use std::future::Future;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use crate::error::{Error, Result};
13use crate::runnables::Runnable;
14
15use super::base::{ArgsSchema, ResponseFormat};
16use super::simple::Tool;
17use super::structured::{StructuredTool, create_args_schema};
18
19#[derive(Debug, Clone, Default)]
21pub struct ToolConfig {
22 pub name: Option<String>,
24 pub description: Option<String>,
26 pub return_direct: bool,
28 pub args_schema: Option<ArgsSchema>,
30 pub infer_schema: bool,
32 pub response_format: ResponseFormat,
34 pub parse_docstring: bool,
36 pub error_on_invalid_docstring: bool,
38 pub extras: Option<HashMap<String, Value>>,
40}
41
42impl ToolConfig {
43 pub fn new() -> Self {
45 Self {
46 infer_schema: true,
47 ..Default::default()
48 }
49 }
50
51 pub fn with_name(mut self, name: impl Into<String>) -> Self {
53 self.name = Some(name.into());
54 self
55 }
56
57 pub fn with_description(mut self, description: impl Into<String>) -> Self {
59 self.description = Some(description.into());
60 self
61 }
62
63 pub fn with_return_direct(mut self, return_direct: bool) -> Self {
65 self.return_direct = return_direct;
66 self
67 }
68
69 pub fn with_args_schema(mut self, schema: ArgsSchema) -> Self {
71 self.args_schema = Some(schema);
72 self
73 }
74
75 pub fn with_infer_schema(mut self, infer_schema: bool) -> Self {
77 self.infer_schema = infer_schema;
78 self
79 }
80
81 pub fn with_response_format(mut self, format: ResponseFormat) -> Self {
83 self.response_format = format;
84 self
85 }
86
87 pub fn with_parse_docstring(mut self, parse: bool) -> Self {
89 self.parse_docstring = parse;
90 self
91 }
92
93 pub fn with_extras(mut self, extras: HashMap<String, Value>) -> Self {
95 self.extras = Some(extras);
96 self
97 }
98}
99
100pub fn create_simple_tool<F>(
104 name: impl Into<String>,
105 description: impl Into<String>,
106 func: F,
107) -> Tool
108where
109 F: Fn(String) -> Result<String> + Send + Sync + 'static,
110{
111 Tool::from_function(func, name, description)
112}
113
114pub fn create_simple_tool_async<F, AF, Fut>(
116 name: impl Into<String>,
117 description: impl Into<String>,
118 func: F,
119 coroutine: AF,
120) -> Tool
121where
122 F: Fn(String) -> Result<String> + Send + Sync + 'static,
123 AF: Fn(String) -> Fut + Send + Sync + 'static,
124 Fut: Future<Output = Result<String>> + Send + 'static,
125{
126 Tool::from_function_with_async(func, coroutine, name, description)
127}
128
129pub fn create_structured_tool<F>(
133 name: impl Into<String>,
134 description: impl Into<String>,
135 args_schema: ArgsSchema,
136 func: F,
137) -> StructuredTool
138where
139 F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
140{
141 StructuredTool::from_function(func, name, description, args_schema)
142}
143
144pub fn create_structured_tool_async<F, AF, Fut>(
146 name: impl Into<String>,
147 description: impl Into<String>,
148 args_schema: ArgsSchema,
149 func: F,
150 coroutine: AF,
151) -> StructuredTool
152where
153 F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
154 AF: Fn(HashMap<String, Value>) -> Fut + Send + Sync + 'static,
155 Fut: Future<Output = Result<Value>> + Send + 'static,
156{
157 StructuredTool::from_function_with_async(func, coroutine, name, description, args_schema)
158}
159
160pub fn create_tool_with_config<F>(func: F, config: ToolConfig) -> Result<StructuredTool>
162where
163 F: Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync + 'static,
164{
165 let name = config
166 .name
167 .ok_or_else(|| Error::InvalidConfig("Tool name is required".to_string()))?;
168 let description = config.description.unwrap_or_default();
169 let args_schema = config.args_schema.unwrap_or_default();
170
171 let mut tool = StructuredTool::from_function(func, name, description, args_schema);
172
173 if config.return_direct {
174 tool = tool.with_return_direct(true);
175 }
176
177 tool = tool.with_response_format(config.response_format);
178
179 if let Some(extras) = config.extras {
180 tool = tool.with_extras(extras);
181 }
182
183 Ok(tool)
184}
185
186pub fn convert_runnable_to_tool<R>(
190 runnable: Arc<R>,
191 name: impl Into<String>,
192 description: impl Into<String>,
193) -> StructuredTool
194where
195 R: Runnable<Input = HashMap<String, Value>, Output = Value> + Send + Sync + 'static,
196{
197 let name = name.into();
198 let description = description.into();
199
200 let runnable_clone = runnable.clone();
201 let func = move |args: HashMap<String, Value>| runnable_clone.invoke(args, None);
202
203 let schema = ArgsSchema::JsonSchema(serde_json::json!({
205 "type": "object",
206 "properties": {},
207 "additionalProperties": true
208 }));
209
210 StructuredTool::from_function(func, name, description, schema)
211}
212
213pub type ToolFromSchemaFn = Box<dyn Fn(HashMap<String, Value>) -> Result<Value> + Send + Sync>;
215
216pub fn tool_from_schema(
221 name: impl Into<String>,
222 description: impl Into<String>,
223 properties: Vec<(&str, &str, &str, bool)>, ) -> impl FnOnce(ToolFromSchemaFn) -> StructuredTool {
225 let name = name.into();
226 let description = description.into();
227
228 let mut props = HashMap::new();
229 let mut required = Vec::new();
230
231 for (prop_name, prop_type, prop_desc, is_required) in properties {
232 props.insert(
233 prop_name.to_string(),
234 serde_json::json!({
235 "type": prop_type,
236 "description": prop_desc
237 }),
238 );
239 if is_required {
240 required.push(prop_name.to_string());
241 }
242 }
243
244 let schema = create_args_schema(&name, props, required, Some(&description));
245
246 move |func| StructuredTool::from_function(func, name, description, schema)
247}
248
249pub fn get_description_from_runnable<R>(_runnable: &R) -> String
251where
252 R: Runnable,
253{
254 "Takes an input and produces an output.".to_string()
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::base::BaseTool;

    /// Builds a `{"type": "number"}` property entry keyed by `key`.
    fn number_prop(key: &str) -> (String, Value) {
        (key.to_string(), serde_json::json!({"type": "number"}))
    }

    #[test]
    fn test_create_simple_tool() {
        let echo = create_simple_tool("echo", "Echoes the input", |text| {
            Ok(format!("Echo: {}", text))
        });

        assert_eq!(echo.name(), "echo");
        assert_eq!(echo.description(), "Echoes the input");
    }

    #[test]
    fn test_create_structured_tool() {
        let props: HashMap<String, Value> =
            ["a", "b"].iter().copied().map(number_prop).collect();
        let required: Vec<String> = ["a", "b"].iter().map(|key| key.to_string()).collect();
        let schema = create_args_schema("add", props, required, None);

        let add = create_structured_tool("add", "Adds two numbers", schema, |args| {
            let operand = |key: &str| args.get(key).and_then(|v| v.as_f64()).unwrap_or(0.0);
            Ok(Value::from(operand("a") + operand("b")))
        });

        assert_eq!(add.name(), "add");
    }

    #[test]
    fn test_tool_config() {
        let cfg = ToolConfig::new()
            .with_name("test")
            .with_description("A test tool")
            .with_return_direct(true)
            .with_response_format(ResponseFormat::ContentAndArtifact);

        assert_eq!(cfg.name.as_deref(), Some("test"));
        assert!(cfg.return_direct);
        assert_eq!(cfg.response_format, ResponseFormat::ContentAndArtifact);
    }

    #[test]
    fn test_create_tool_with_config() {
        let schema = ArgsSchema::JsonSchema(serde_json::json!({
            "type": "object",
            "properties": {
                "input": {"type": "string"}
            }
        }));
        let cfg = ToolConfig::new()
            .with_name("configured_tool")
            .with_description("A configured tool")
            .with_args_schema(schema);

        let configured = create_tool_with_config(
            |args| Ok(args.get("input").cloned().unwrap_or(Value::Null)),
            cfg,
        )
        .expect("config carries a name, so building the tool must succeed");

        assert_eq!(configured.name(), "configured_tool");
    }

    #[test]
    fn test_tool_from_schema() {
        let finish = tool_from_schema(
            "greet",
            "Greets a person",
            vec![("name", "string", "The person's name", true)],
        );

        let greeter = finish(Box::new(|args| {
            let who = args
                .get("name")
                .and_then(|v| v.as_str())
                .unwrap_or("stranger");
            Ok(Value::String(format!("Hello, {}!", who)))
        }));

        assert_eq!(greeter.name(), "greet");
    }
}