1use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use thiserror::Error;
14
15use crate::error::Result;
16use crate::messages::{BaseMessage, ToolCall, ToolMessage};
17use crate::runnables::{RunnableConfig, ensure_config};
18
/// Argument names that are filtered out of a tool's arguments.
///
/// NOTE(review): nothing in this file reads this list; presumably callers strip these
/// runtime-only parameters from generated schemas — confirm against its users.
pub const FILTERED_ARGS: &[&str] = &[
    "run_manager",
    "callbacks",
];

/// Content-block `"type"` values accepted inside a tool message.
///
/// [`is_message_content_block`] treats a JSON object as a valid content block exactly when
/// its `"type"` field is one of these strings.
pub const TOOL_MESSAGE_BLOCK_TYPES: &[&str] = &[
    "text",
    "image_url",
    "image",
    "json",
    "search_result",
    "custom_tool_call_output",
    "document",
    "file",
];
33
34#[derive(Debug, Error)]
36#[error("Schema annotation error: {message}")]
37pub struct SchemaAnnotationError {
38 pub message: String,
39}
40
41impl SchemaAnnotationError {
42 pub fn new(message: impl Into<String>) -> Self {
43 Self {
44 message: message.into(),
45 }
46 }
47}
48
49#[derive(Debug, Error)]
55#[error("{0}")]
56pub struct ToolException(pub String);
57
58impl ToolException {
59 pub fn new(message: impl Into<String>) -> Self {
60 Self(message.into())
61 }
62}
63
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
66#[serde(rename_all = "snake_case")]
67pub enum ResponseFormat {
68 #[default]
70 Content,
71 ContentAndArtifact,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77#[serde(untagged)]
78pub enum ArgsSchema {
79 JsonSchema(Value),
81 TypeName(String),
83}
84
85impl Default for ArgsSchema {
86 fn default() -> Self {
87 ArgsSchema::JsonSchema(serde_json::json!({
88 "type": "object",
89 "properties": {}
90 }))
91 }
92}
93
94impl ArgsSchema {
95 pub fn to_json_schema(&self) -> Value {
97 match self {
98 ArgsSchema::JsonSchema(schema) => schema.clone(),
99 ArgsSchema::TypeName(name) => serde_json::json!({
100 "type": "object",
101 "title": name,
102 "properties": {}
103 }),
104 }
105 }
106
107 pub fn properties(&self) -> HashMap<String, Value> {
109 match self {
110 ArgsSchema::JsonSchema(schema) => schema
111 .get("properties")
112 .and_then(|p| p.as_object())
113 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
114 .unwrap_or_default(),
115 ArgsSchema::TypeName(_) => HashMap::new(),
116 }
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct ToolDefinition {
123 pub name: String,
125 pub description: String,
127 pub parameters: Value,
129}
130
131#[derive(Clone)]
133pub enum HandleToolError {
134 None,
136 Bool(bool),
138 Message(String),
140 Handler(Arc<dyn Fn(&ToolException) -> String + Send + Sync>),
142}
143
144impl Debug for HandleToolError {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 HandleToolError::None => write!(f, "HandleToolError::None"),
148 HandleToolError::Bool(b) => f.debug_tuple("HandleToolError::Bool").field(b).finish(),
149 HandleToolError::Message(m) => {
150 f.debug_tuple("HandleToolError::Message").field(m).finish()
151 }
152 HandleToolError::Handler(_) => write!(f, "HandleToolError::Handler(<function>)"),
153 }
154 }
155}
156
157impl Default for HandleToolError {
158 fn default() -> Self {
159 HandleToolError::Bool(false)
160 }
161}
162
/// Policy for turning an input-validation error message into a reply message.
///
/// As interpreted by `handle_validation_error_impl`:
/// - `None` / `Bool(false)` produce no message;
/// - `Bool(true)` yields a generic validation message;
/// - `Message` substitutes a fixed string;
/// - `Handler` computes the text from the validation error.
#[derive(Clone)]
pub enum HandleValidationError {
    /// Do not handle the error.
    None,
    /// `true` reports a generic message; `false` does not handle the error.
    Bool(bool),
    /// Report this fixed message.
    Message(String),
    /// Compute the reported message from the validation error text.
    Handler(Arc<dyn Fn(&str) -> String + Send + Sync>),
}

impl Debug for HandleValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hand-written because the boxed closure in `Handler` has no `Debug` impl.
        match self {
            Self::None => f.write_str("HandleValidationError::None"),
            Self::Bool(flag) => {
                let mut tuple = f.debug_tuple("HandleValidationError::Bool");
                tuple.field(flag);
                tuple.finish()
            }
            Self::Message(text) => {
                let mut tuple = f.debug_tuple("HandleValidationError::Message");
                tuple.field(text);
                tuple.finish()
            }
            Self::Handler(_) => f.write_str("HandleValidationError::Handler(<function>)"),
        }
    }
}

impl Default for HandleValidationError {
    /// Validation errors are not handled by default (`Bool(false)`).
    fn default() -> Self {
        Self::Bool(false)
    }
}
200
201#[derive(Debug, Clone)]
203pub enum ToolInput {
204 String(String),
206 Dict(HashMap<String, Value>),
208 ToolCall(ToolCall),
210}
211
212impl From<String> for ToolInput {
213 fn from(s: String) -> Self {
214 ToolInput::String(s)
215 }
216}
217
218impl From<&str> for ToolInput {
219 fn from(s: &str) -> Self {
220 ToolInput::String(s.to_string())
221 }
222}
223
224impl From<HashMap<String, Value>> for ToolInput {
225 fn from(d: HashMap<String, Value>) -> Self {
226 ToolInput::Dict(d)
227 }
228}
229
230impl From<ToolCall> for ToolInput {
231 fn from(tc: ToolCall) -> Self {
232 ToolInput::ToolCall(tc)
233 }
234}
235
236impl From<Value> for ToolInput {
237 fn from(v: Value) -> Self {
238 match v {
239 Value::String(s) => ToolInput::String(s),
240 Value::Object(obj) => {
241 if obj.get("type").and_then(|t| t.as_str()) == Some("tool_call")
243 && let (Some(id), Some(name), Some(args)) = (
244 obj.get("id").and_then(|i| i.as_str()),
245 obj.get("name").and_then(|n| n.as_str()),
246 obj.get("args"),
247 )
248 {
249 return ToolInput::ToolCall(ToolCall::with_id(id, name, args.clone()));
250 }
251 ToolInput::Dict(obj.into_iter().collect())
252 }
253 _ => ToolInput::String(v.to_string()),
254 }
255 }
256}
257
258#[derive(Debug, Clone)]
260pub enum ToolOutput {
261 String(String),
263 Message(ToolMessage),
265 ContentAndArtifact { content: Value, artifact: Value },
267 Json(Value),
269}
270
271impl From<String> for ToolOutput {
272 fn from(s: String) -> Self {
273 ToolOutput::String(s)
274 }
275}
276
277impl From<&str> for ToolOutput {
278 fn from(s: &str) -> Self {
279 ToolOutput::String(s.to_string())
280 }
281}
282
283impl From<ToolMessage> for ToolOutput {
284 fn from(m: ToolMessage) -> Self {
285 ToolOutput::Message(m)
286 }
287}
288
289impl From<Value> for ToolOutput {
290 fn from(v: Value) -> Self {
291 ToolOutput::Json(v)
292 }
293}
294
295#[async_trait]
300pub trait BaseTool: Send + Sync + Debug {
301 fn name(&self) -> &str;
303
304 fn description(&self) -> &str;
306
307 fn args_schema(&self) -> Option<&ArgsSchema> {
309 None
310 }
311
312 fn return_direct(&self) -> bool {
314 false
315 }
316
317 fn verbose(&self) -> bool {
319 false
320 }
321
322 fn tags(&self) -> Option<&[String]> {
324 None
325 }
326
327 fn metadata(&self) -> Option<&HashMap<String, Value>> {
329 None
330 }
331
332 fn handle_tool_error(&self) -> &HandleToolError {
334 &HandleToolError::Bool(false)
335 }
336
337 fn handle_validation_error(&self) -> &HandleValidationError {
339 &HandleValidationError::Bool(false)
340 }
341
342 fn response_format(&self) -> ResponseFormat {
344 ResponseFormat::Content
345 }
346
347 fn extras(&self) -> Option<&HashMap<String, Value>> {
349 None
350 }
351
352 fn is_single_input(&self) -> bool {
354 let args = self.args();
355 let keys: Vec<_> = args.keys().filter(|k| *k != "kwargs").collect();
356 keys.len() == 1
357 }
358
359 fn args(&self) -> HashMap<String, Value> {
361 self.args_schema()
362 .map(|s| s.properties())
363 .unwrap_or_default()
364 }
365
366 fn tool_call_schema(&self) -> ArgsSchema {
368 self.args_schema().cloned().unwrap_or_default()
369 }
370
371 fn definition(&self) -> ToolDefinition {
373 ToolDefinition {
374 name: self.name().to_string(),
375 description: self.description().to_string(),
376 parameters: self
377 .args_schema()
378 .map(|s| s.to_json_schema())
379 .unwrap_or_else(|| serde_json::json!({"type": "object", "properties": {}})),
380 }
381 }
382
383 fn parameters_schema(&self) -> Value {
385 self.definition().parameters
386 }
387
388 fn run(&self, input: ToolInput, config: Option<RunnableConfig>) -> Result<ToolOutput>;
390
391 async fn arun(&self, input: ToolInput, config: Option<RunnableConfig>) -> Result<ToolOutput> {
393 self.run(input, config)
395 }
396
397 async fn invoke(&self, tool_call: ToolCall) -> BaseMessage {
399 let input = ToolInput::ToolCall(tool_call.clone());
400 match self.arun(input, None).await {
401 Ok(output) => match output {
402 ToolOutput::String(s) => ToolMessage::new(s, tool_call.id()).into(),
403 ToolOutput::Message(m) => m.into(),
404 ToolOutput::ContentAndArtifact { content, artifact } => {
405 ToolMessage::with_artifact(content.to_string(), tool_call.id(), artifact).into()
406 }
407 ToolOutput::Json(v) => ToolMessage::new(v.to_string(), tool_call.id()).into(),
408 },
409 Err(e) => ToolMessage::error(e.to_string(), tool_call.id()).into(),
410 }
411 }
412
413 async fn invoke_args(&self, args: Value) -> Value {
415 let tool_call = ToolCall::new(self.name(), args);
416 let result = self.invoke(tool_call).await;
417 Value::String(result.content().to_string())
418 }
419}
420
/// Data-less marker type; nothing in this module consumes it.
///
/// NOTE(review): presumably tags a tool argument injected at call time rather than supplied
/// by the model — confirm against its users.
#[derive(Clone, Debug, Default)]
pub struct InjectedToolArg;

/// Data-less marker type; nothing in this module consumes it.
///
/// NOTE(review): presumably tags the argument that receives the current tool-call id —
/// confirm against its users.
#[derive(Clone, Debug, Default)]
pub struct InjectedToolCallId;
434
435pub fn is_tool_call(input: &Value) -> bool {
437 input.get("type").and_then(|t| t.as_str()) == Some("tool_call")
438}
439
440pub fn handle_tool_error_impl(e: &ToolException, flag: &HandleToolError) -> Option<String> {
442 match flag {
443 HandleToolError::None => None,
444 HandleToolError::Bool(false) => None,
445 HandleToolError::Bool(true) => Some(e.0.clone()),
446 HandleToolError::Message(msg) => Some(msg.clone()),
447 HandleToolError::Handler(f) => Some(f(e)),
448 }
449}
450
451pub fn handle_validation_error_impl(e: &str, flag: &HandleValidationError) -> Option<String> {
453 match flag {
454 HandleValidationError::None => None,
455 HandleValidationError::Bool(false) => None,
456 HandleValidationError::Bool(true) => Some("Tool input validation error".to_string()),
457 HandleValidationError::Message(msg) => Some(msg.clone()),
458 HandleValidationError::Handler(f) => Some(f(e)),
459 }
460}
461
462pub fn format_output(
464 content: Value,
465 artifact: Option<Value>,
466 tool_call_id: Option<&str>,
467 name: &str,
468 _status: &str,
469) -> ToolOutput {
470 if let Some(tool_call_id) = tool_call_id {
471 let msg = if let Some(artifact) = artifact {
472 ToolMessage::with_artifact(stringify_content(&content), tool_call_id, artifact)
473 } else {
474 ToolMessage::new(stringify_content(&content), tool_call_id)
475 };
476 ToolOutput::Message(msg.with_name(name))
477 } else {
478 match content {
479 Value::String(s) => ToolOutput::String(s),
480 other => ToolOutput::Json(other),
481 }
482 }
483}
484
485pub fn is_message_content_type(obj: &Value) -> bool {
487 match obj {
488 Value::String(_) => true,
489 Value::Array(arr) => arr.iter().all(is_message_content_block),
490 _ => false,
491 }
492}
493
494pub fn is_message_content_block(obj: &Value) -> bool {
496 match obj {
497 Value::String(_) => true,
498 Value::Object(map) => map
499 .get("type")
500 .and_then(|t| t.as_str())
501 .map(|t| TOOL_MESSAGE_BLOCK_TYPES.contains(&t))
502 .unwrap_or(false),
503 _ => false,
504 }
505}
506
507pub fn stringify_content(content: &Value) -> String {
509 match content {
510 Value::String(s) => s.clone(),
511 other => serde_json::to_string(other).unwrap_or_else(|_| other.to_string()),
512 }
513}
514
515pub fn prep_run_args(
517 value: ToolInput,
518 config: Option<RunnableConfig>,
519) -> (ToolInput, Option<String>, RunnableConfig) {
520 let config = ensure_config(config);
521
522 match &value {
523 ToolInput::ToolCall(tc) => {
524 let tool_call_id = Some(tc.id().to_string());
525 let input = ToolInput::Dict(
526 tc.args()
527 .as_object()
528 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
529 .unwrap_or_default(),
530 );
531 (input, tool_call_id, config)
532 }
533 _ => (value, None, config),
534 }
535}
536
537pub trait BaseToolkit: Send + Sync {
542 fn get_tools(&self) -> Vec<Arc<dyn BaseTool>>;
544}
545
546pub type DynTool = Arc<dyn BaseTool>;
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_input_from_string() {
        let input = ToolInput::from("test");
        match input {
            ToolInput::String(s) => assert_eq!(s, "test"),
            _ => panic!("Expected String variant"),
        }
    }

    #[test]
    fn test_tool_input_from_value() {
        let value = serde_json::json!({"key": "value"});
        let input = ToolInput::from(value);
        match input {
            ToolInput::Dict(d) => {
                assert_eq!(d.get("key"), Some(&Value::String("value".to_string())));
            }
            _ => panic!("Expected Dict variant"),
        }
    }

    #[test]
    fn test_is_tool_call() {
        let tc = serde_json::json!({
            "type": "tool_call",
            "id": "123",
            "name": "test",
            "args": {}
        });
        assert!(is_tool_call(&tc));

        // Was `¬_tc`: a mis-decoded `&not_tc` (HTML entity `&not;`) that did not compile.
        let not_tc = serde_json::json!({"key": "value"});
        assert!(!is_tool_call(&not_tc));
    }

    #[test]
    fn test_args_schema_properties() {
        let schema = ArgsSchema::JsonSchema(serde_json::json!({
            "type": "object",
            "properties": {
                "query": {"type": "string"}
            }
        }));
        let props = schema.properties();
        assert!(props.contains_key("query"));
    }

    #[test]
    fn test_response_format_default() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::Content);
    }

    #[test]
    fn test_handle_tool_error() {
        let exc = ToolException::new("test error");

        let result = handle_tool_error_impl(&exc, &HandleToolError::Bool(false));
        assert!(result.is_none());

        let result = handle_tool_error_impl(&exc, &HandleToolError::Bool(true));
        assert_eq!(result, Some("test error".to_string()));

        let result = handle_tool_error_impl(&exc, &HandleToolError::Message("custom".to_string()));
        assert_eq!(result, Some("custom".to_string()));
    }

    #[test]
    fn test_handle_validation_error() {
        assert!(handle_validation_error_impl("bad", &HandleValidationError::None).is_none());
        assert!(
            handle_validation_error_impl("bad", &HandleValidationError::Bool(false)).is_none()
        );
        assert_eq!(
            handle_validation_error_impl("bad", &HandleValidationError::Bool(true)),
            Some("Tool input validation error".to_string())
        );
        let handler =
            HandleValidationError::Handler(Arc::new(|e: &str| format!("handled: {e}")));
        assert_eq!(
            handle_validation_error_impl("bad", &handler),
            Some("handled: bad".to_string())
        );
    }

    #[test]
    fn test_stringify_content() {
        assert_eq!(stringify_content(&Value::String("hi".to_string())), "hi");
        assert_eq!(stringify_content(&serde_json::json!({"a": 1})), r#"{"a":1}"#);
        assert_eq!(stringify_content(&serde_json::json!(["x", 2])), r#"["x",2]"#);
    }

    #[test]
    fn test_is_message_content_type() {
        assert!(is_message_content_type(&serde_json::json!("text")));
        assert!(is_message_content_type(&serde_json::json!([
            {"type": "text", "text": "x"},
            "y"
        ])));
        assert!(!is_message_content_type(&serde_json::json!([{"type": "unknown"}])));
        assert!(!is_message_content_type(&serde_json::json!(42)));
    }
}