agent_chain_core/tools/
base.rs

1//! Base classes and utilities for LangChain tools.
2//!
3//! This module provides the core tool abstractions, mirroring
4//! `langchain_core.tools.base`.
5
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use thiserror::Error;
14
15use crate::error::Result;
16use crate::messages::{BaseMessage, ToolCall, ToolMessage};
17use crate::runnables::{RunnableConfig, ensure_config};
18
/// Argument names that are filtered out of tool schemas
/// (`run_manager` and `callbacks`).
pub const FILTERED_ARGS: &[&str] = &[
    "run_manager",
    "callbacks",
];
21
/// Content-block `type` values accepted inside a tool message.
///
/// A JSON object counts as a valid content block only when its `"type"`
/// field is one of these strings.
pub const TOOL_MESSAGE_BLOCK_TYPES: &[&str] = &[
    // textual / structured payloads
    "text",
    // image payloads
    "image_url",
    "image",
    // structured data and search results
    "json",
    "search_result",
    // provider-specific tool output
    "custom_tool_call_output",
    // document / file attachments
    "document",
    "file",
];
33
34/// Error raised when args_schema is missing or has incorrect type annotation.
35#[derive(Debug, Error)]
36#[error("Schema annotation error: {message}")]
37pub struct SchemaAnnotationError {
38    pub message: String,
39}
40
41impl SchemaAnnotationError {
42    pub fn new(message: impl Into<String>) -> Self {
43        Self {
44            message: message.into(),
45        }
46    }
47}
48
49/// Exception thrown when a tool execution error occurs.
50///
51/// This exception allows tools to signal errors without stopping the agent.
52/// The error is handled according to the tool's `handle_tool_error` setting,
53/// and the result is returned as an observation to the agent.
54#[derive(Debug, Error)]
55#[error("{0}")]
56pub struct ToolException(pub String);
57
58impl ToolException {
59    pub fn new(message: impl Into<String>) -> Self {
60        Self(message.into())
61    }
62}
63
64/// Represents the response format for a tool.
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
66#[serde(rename_all = "snake_case")]
67pub enum ResponseFormat {
68    /// The output is interpreted as the contents of a ToolMessage.
69    #[default]
70    Content,
71    /// The output is expected to be a tuple of (content, artifact).
72    ContentAndArtifact,
73}
74
75/// Represents a tool's schema, which can be a JSON schema or a type reference.
76#[derive(Debug, Clone, Serialize, Deserialize)]
77#[serde(untagged)]
78pub enum ArgsSchema {
79    /// A JSON schema definition.
80    JsonSchema(Value),
81    /// A type name reference.
82    TypeName(String),
83}
84
85impl Default for ArgsSchema {
86    fn default() -> Self {
87        ArgsSchema::JsonSchema(serde_json::json!({
88            "type": "object",
89            "properties": {}
90        }))
91    }
92}
93
94impl ArgsSchema {
95    /// Get the JSON schema for this args schema.
96    pub fn to_json_schema(&self) -> Value {
97        match self {
98            ArgsSchema::JsonSchema(schema) => schema.clone(),
99            ArgsSchema::TypeName(name) => serde_json::json!({
100                "type": "object",
101                "title": name,
102                "properties": {}
103            }),
104        }
105    }
106
107    /// Get properties from the schema.
108    pub fn properties(&self) -> HashMap<String, Value> {
109        match self {
110            ArgsSchema::JsonSchema(schema) => schema
111                .get("properties")
112                .and_then(|p| p.as_object())
113                .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
114                .unwrap_or_default(),
115            ArgsSchema::TypeName(_) => HashMap::new(),
116        }
117    }
118}
119
120/// Represents a tool's definition for LLM function calling.
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct ToolDefinition {
123    /// The name of the tool.
124    pub name: String,
125    /// A description of what the tool does.
126    pub description: String,
127    /// JSON schema for the tool's parameters.
128    pub parameters: Value,
129}
130
131/// How to handle tool errors.
132#[derive(Clone)]
133pub enum HandleToolError {
134    /// Don't handle errors (re-raise them).
135    None,
136    /// Return a generic error message.
137    Bool(bool),
138    /// Return a specific error message.
139    Message(String),
140    /// Use a custom function to handle the error.
141    Handler(Arc<dyn Fn(&ToolException) -> String + Send + Sync>),
142}
143
144impl Debug for HandleToolError {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        match self {
147            HandleToolError::None => write!(f, "HandleToolError::None"),
148            HandleToolError::Bool(b) => f.debug_tuple("HandleToolError::Bool").field(b).finish(),
149            HandleToolError::Message(m) => {
150                f.debug_tuple("HandleToolError::Message").field(m).finish()
151            }
152            HandleToolError::Handler(_) => write!(f, "HandleToolError::Handler(<function>)"),
153        }
154    }
155}
156
157impl Default for HandleToolError {
158    fn default() -> Self {
159        HandleToolError::Bool(false)
160    }
161}
162
/// Policy for turning a tool-input validation error into an observation
/// string.
///
/// `None` and `Bool(false)` both leave the error unhandled; `Bool(true)`
/// substitutes a generic message (see `handle_validation_error_impl`).
#[derive(Clone)]
pub enum HandleValidationError {
    /// Don't handle errors (re-raise them).
    None,
    /// `true` returns a generic error message; `false` leaves it unhandled.
    Bool(bool),
    /// Return this fixed message.
    Message(String),
    /// Compute the returned message from the validation error text.
    Handler(Arc<dyn Fn(&str) -> String + Send + Sync>),
}

impl Debug for HandleValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Bool(flag) => f
                .debug_tuple("HandleValidationError::Bool")
                .field(flag)
                .finish(),
            Self::Message(text) => f
                .debug_tuple("HandleValidationError::Message")
                .field(text)
                .finish(),
            Self::None => f.write_str("HandleValidationError::None"),
            // A closure has no meaningful Debug output, so print a placeholder.
            Self::Handler(_) => f.write_str("HandleValidationError::Handler(<function>)"),
        }
    }
}

impl Default for HandleValidationError {
    /// Defaults to `Bool(false)`, i.e. validation errors are not handled.
    fn default() -> Self {
        Self::Bool(false)
    }
}
200
201/// Input type for tools - can be a string, dict, or ToolCall.
202#[derive(Debug, Clone)]
203pub enum ToolInput {
204    /// A simple string input.
205    String(String),
206    /// A dictionary of arguments.
207    Dict(HashMap<String, Value>),
208    /// A full tool call.
209    ToolCall(ToolCall),
210}
211
212impl From<String> for ToolInput {
213    fn from(s: String) -> Self {
214        ToolInput::String(s)
215    }
216}
217
218impl From<&str> for ToolInput {
219    fn from(s: &str) -> Self {
220        ToolInput::String(s.to_string())
221    }
222}
223
224impl From<HashMap<String, Value>> for ToolInput {
225    fn from(d: HashMap<String, Value>) -> Self {
226        ToolInput::Dict(d)
227    }
228}
229
230impl From<ToolCall> for ToolInput {
231    fn from(tc: ToolCall) -> Self {
232        ToolInput::ToolCall(tc)
233    }
234}
235
236impl From<Value> for ToolInput {
237    fn from(v: Value) -> Self {
238        match v {
239            Value::String(s) => ToolInput::String(s),
240            Value::Object(obj) => {
241                // Check if this is a tool call
242                if obj.get("type").and_then(|t| t.as_str()) == Some("tool_call")
243                    && let (Some(id), Some(name), Some(args)) = (
244                        obj.get("id").and_then(|i| i.as_str()),
245                        obj.get("name").and_then(|n| n.as_str()),
246                        obj.get("args"),
247                    )
248                {
249                    return ToolInput::ToolCall(ToolCall::with_id(id, name, args.clone()));
250                }
251                ToolInput::Dict(obj.into_iter().collect())
252            }
253            _ => ToolInput::String(v.to_string()),
254        }
255    }
256}
257
258/// Output type for tools.
259#[derive(Debug, Clone)]
260pub enum ToolOutput {
261    /// A simple string output.
262    String(String),
263    /// A ToolMessage output.
264    Message(ToolMessage),
265    /// A content and artifact tuple.
266    ContentAndArtifact { content: Value, artifact: Value },
267    /// Raw JSON value.
268    Json(Value),
269}
270
271impl From<String> for ToolOutput {
272    fn from(s: String) -> Self {
273        ToolOutput::String(s)
274    }
275}
276
277impl From<&str> for ToolOutput {
278    fn from(s: &str) -> Self {
279        ToolOutput::String(s.to_string())
280    }
281}
282
283impl From<ToolMessage> for ToolOutput {
284    fn from(m: ToolMessage) -> Self {
285        ToolOutput::Message(m)
286    }
287}
288
289impl From<Value> for ToolOutput {
290    fn from(v: Value) -> Self {
291        ToolOutput::Json(v)
292    }
293}
294
295/// Base trait for all LangChain tools.
296///
297/// This trait defines the interface that all LangChain tools must implement.
298/// Tools are components that can be called by agents to perform specific actions.
299#[async_trait]
300pub trait BaseTool: Send + Sync + Debug {
301    /// Get the unique name of the tool.
302    fn name(&self) -> &str;
303
304    /// Get the description of what the tool does.
305    fn description(&self) -> &str;
306
307    /// Get the args schema for the tool.
308    fn args_schema(&self) -> Option<&ArgsSchema> {
309        None
310    }
311
312    /// Whether to return the tool's output directly.
313    fn return_direct(&self) -> bool {
314        false
315    }
316
317    /// Whether to log the tool's progress.
318    fn verbose(&self) -> bool {
319        false
320    }
321
322    /// Get tags associated with the tool.
323    fn tags(&self) -> Option<&[String]> {
324        None
325    }
326
327    /// Get metadata associated with the tool.
328    fn metadata(&self) -> Option<&HashMap<String, Value>> {
329        None
330    }
331
332    /// Get how to handle tool errors.
333    fn handle_tool_error(&self) -> &HandleToolError {
334        &HandleToolError::Bool(false)
335    }
336
337    /// Get how to handle validation errors.
338    fn handle_validation_error(&self) -> &HandleValidationError {
339        &HandleValidationError::Bool(false)
340    }
341
342    /// Get the response format for the tool.
343    fn response_format(&self) -> ResponseFormat {
344        ResponseFormat::Content
345    }
346
347    /// Get optional provider-specific extra fields.
348    fn extras(&self) -> Option<&HashMap<String, Value>> {
349        None
350    }
351
352    /// Check if the tool accepts only a single input argument.
353    fn is_single_input(&self) -> bool {
354        let args = self.args();
355        let keys: Vec<_> = args.keys().filter(|k| *k != "kwargs").collect();
356        keys.len() == 1
357    }
358
359    /// Get the tool's input arguments schema.
360    fn args(&self) -> HashMap<String, Value> {
361        self.args_schema()
362            .map(|s| s.properties())
363            .unwrap_or_default()
364    }
365
366    /// Get the schema for tool calls, excluding injected arguments.
367    fn tool_call_schema(&self) -> ArgsSchema {
368        self.args_schema().cloned().unwrap_or_default()
369    }
370
371    /// Get the tool definition for LLM function calling.
372    fn definition(&self) -> ToolDefinition {
373        ToolDefinition {
374            name: self.name().to_string(),
375            description: self.description().to_string(),
376            parameters: self
377                .args_schema()
378                .map(|s| s.to_json_schema())
379                .unwrap_or_else(|| serde_json::json!({"type": "object", "properties": {}})),
380        }
381    }
382
383    /// Get the JSON schema for the tool's parameters.
384    fn parameters_schema(&self) -> Value {
385        self.definition().parameters
386    }
387
388    /// Run the tool synchronously.
389    fn run(&self, input: ToolInput, config: Option<RunnableConfig>) -> Result<ToolOutput>;
390
391    /// Run the tool asynchronously.
392    async fn arun(&self, input: ToolInput, config: Option<RunnableConfig>) -> Result<ToolOutput> {
393        // Default implementation uses sync run
394        self.run(input, config)
395    }
396
397    /// Invoke the tool with a ToolCall.
398    async fn invoke(&self, tool_call: ToolCall) -> BaseMessage {
399        let input = ToolInput::ToolCall(tool_call.clone());
400        match self.arun(input, None).await {
401            Ok(output) => match output {
402                ToolOutput::String(s) => ToolMessage::new(s, tool_call.id()).into(),
403                ToolOutput::Message(m) => m.into(),
404                ToolOutput::ContentAndArtifact { content, artifact } => {
405                    ToolMessage::with_artifact(content.to_string(), tool_call.id(), artifact).into()
406                }
407                ToolOutput::Json(v) => ToolMessage::new(v.to_string(), tool_call.id()).into(),
408            },
409            Err(e) => ToolMessage::error(e.to_string(), tool_call.id()).into(),
410        }
411    }
412
413    /// Invoke the tool directly with arguments.
414    async fn invoke_args(&self, args: Value) -> Value {
415        let tool_call = ToolCall::new(self.name(), args);
416        let result = self.invoke(tool_call).await;
417        Value::String(result.content().to_string())
418    }
419}
420
/// Marker for tool arguments that are supplied at runtime.
///
/// Arguments tagged with this marker are left out of the schema sent to
/// language models and are injected during execution instead.
#[derive(Default, Clone, Debug)]
pub struct InjectedToolArg;
427
/// Marker for a tool parameter that receives the tool call ID at runtime.
#[derive(Default, Clone, Debug)]
pub struct InjectedToolCallId;
434
435/// Check if an input is a tool call dictionary.
436pub fn is_tool_call(input: &Value) -> bool {
437    input.get("type").and_then(|t| t.as_str()) == Some("tool_call")
438}
439
440/// Handle a tool exception based on the configured flag.
441pub fn handle_tool_error_impl(e: &ToolException, flag: &HandleToolError) -> Option<String> {
442    match flag {
443        HandleToolError::None => None,
444        HandleToolError::Bool(false) => None,
445        HandleToolError::Bool(true) => Some(e.0.clone()),
446        HandleToolError::Message(msg) => Some(msg.clone()),
447        HandleToolError::Handler(f) => Some(f(e)),
448    }
449}
450
451/// Handle a validation error based on the configured flag.
452pub fn handle_validation_error_impl(e: &str, flag: &HandleValidationError) -> Option<String> {
453    match flag {
454        HandleValidationError::None => None,
455        HandleValidationError::Bool(false) => None,
456        HandleValidationError::Bool(true) => Some("Tool input validation error".to_string()),
457        HandleValidationError::Message(msg) => Some(msg.clone()),
458        HandleValidationError::Handler(f) => Some(f(e)),
459    }
460}
461
462/// Format tool output as appropriate.
463pub fn format_output(
464    content: Value,
465    artifact: Option<Value>,
466    tool_call_id: Option<&str>,
467    name: &str,
468    _status: &str,
469) -> ToolOutput {
470    if let Some(tool_call_id) = tool_call_id {
471        let msg = if let Some(artifact) = artifact {
472            ToolMessage::with_artifact(stringify_content(&content), tool_call_id, artifact)
473        } else {
474            ToolMessage::new(stringify_content(&content), tool_call_id)
475        };
476        ToolOutput::Message(msg.with_name(name))
477    } else {
478        match content {
479            Value::String(s) => ToolOutput::String(s),
480            other => ToolOutput::Json(other),
481        }
482    }
483}
484
485/// Check if content is a valid message content type.
486pub fn is_message_content_type(obj: &Value) -> bool {
487    match obj {
488        Value::String(_) => true,
489        Value::Array(arr) => arr.iter().all(is_message_content_block),
490        _ => false,
491    }
492}
493
494/// Check if object is a valid message content block.
495pub fn is_message_content_block(obj: &Value) -> bool {
496    match obj {
497        Value::String(_) => true,
498        Value::Object(map) => map
499            .get("type")
500            .and_then(|t| t.as_str())
501            .map(|t| TOOL_MESSAGE_BLOCK_TYPES.contains(&t))
502            .unwrap_or(false),
503        _ => false,
504    }
505}
506
507/// Convert content to string, preferring JSON format.
508pub fn stringify_content(content: &Value) -> String {
509    match content {
510        Value::String(s) => s.clone(),
511        other => serde_json::to_string(other).unwrap_or_else(|_| other.to_string()),
512    }
513}
514
515/// Prepare arguments for tool execution.
516pub fn prep_run_args(
517    value: ToolInput,
518    config: Option<RunnableConfig>,
519) -> (ToolInput, Option<String>, RunnableConfig) {
520    let config = ensure_config(config);
521
522    match &value {
523        ToolInput::ToolCall(tc) => {
524            let tool_call_id = Some(tc.id().to_string());
525            let input = ToolInput::Dict(
526                tc.args()
527                    .as_object()
528                    .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
529                    .unwrap_or_default(),
530            );
531            (input, tool_call_id, config)
532        }
533        _ => (value, None, config),
534    }
535}
536
537/// Base class for toolkits containing related tools.
538///
539/// A toolkit is a collection of related tools that can be used together
540/// to accomplish a specific task or work with a particular system.
541pub trait BaseToolkit: Send + Sync {
542    /// Get all tools in the toolkit.
543    fn get_tools(&self) -> Vec<Arc<dyn BaseTool>>;
544}
545
546/// Type alias for dynamic tool reference.
547pub type DynTool = Arc<dyn BaseTool>;
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_input_from_string() {
        let input = ToolInput::from("test");
        match input {
            ToolInput::String(s) => assert_eq!(s, "test"),
            _ => panic!("Expected String variant"),
        }
    }

    #[test]
    fn test_tool_input_from_value() {
        let value = serde_json::json!({"key": "value"});
        let input = ToolInput::from(value);
        match input {
            ToolInput::Dict(d) => {
                assert_eq!(d.get("key"), Some(&Value::String("value".to_string())));
            }
            _ => panic!("Expected Dict variant"),
        }
    }

    #[test]
    fn test_tool_input_from_tool_call_value() {
        let value = serde_json::json!({
            "type": "tool_call",
            "id": "call-1",
            "name": "search",
            "args": {"q": "rust"}
        });
        match ToolInput::from(value) {
            ToolInput::ToolCall(tc) => assert_eq!(tc.id().to_string(), "call-1"),
            other => panic!("Expected ToolCall variant, got {other:?}"),
        }
    }

    #[test]
    fn test_tool_input_from_scalar_value() {
        // Non-string, non-object JSON is kept as its JSON text.
        match ToolInput::from(serde_json::json!(42)) {
            ToolInput::String(s) => assert_eq!(s, "42"),
            other => panic!("Expected String variant, got {other:?}"),
        }
    }

    #[test]
    fn test_is_tool_call() {
        let tc = serde_json::json!({
            "type": "tool_call",
            "id": "123",
            "name": "test",
            "args": {}
        });
        assert!(is_tool_call(&tc));

        let not_tc = serde_json::json!({"key": "value"});
        assert!(!is_tool_call(&not_tc));
    }

    #[test]
    fn test_args_schema_properties() {
        let schema = ArgsSchema::JsonSchema(serde_json::json!({
            "type": "object",
            "properties": {
                "query": {"type": "string"}
            }
        }));
        let props = schema.properties();
        assert!(props.contains_key("query"));
    }

    #[test]
    fn test_args_schema_type_name_json_schema() {
        let schema = ArgsSchema::TypeName("Query".to_string());
        assert_eq!(schema.to_json_schema()["title"], "Query");
        assert!(schema.properties().is_empty());
        assert!(ArgsSchema::default().properties().is_empty());
    }

    #[test]
    fn test_response_format_default() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::Content);
    }

    #[test]
    fn test_handle_tool_error() {
        let exc = ToolException::new("test error");

        let result = handle_tool_error_impl(&exc, &HandleToolError::Bool(false));
        assert!(result.is_none());

        let result = handle_tool_error_impl(&exc, &HandleToolError::None);
        assert!(result.is_none());

        let result = handle_tool_error_impl(&exc, &HandleToolError::Bool(true));
        assert_eq!(result, Some("test error".to_string()));

        let result = handle_tool_error_impl(&exc, &HandleToolError::Message("custom".to_string()));
        assert_eq!(result, Some("custom".to_string()));

        let handler =
            HandleToolError::Handler(Arc::new(|e: &ToolException| format!("handled: {e}")));
        let result = handle_tool_error_impl(&exc, &handler);
        assert_eq!(result, Some("handled: test error".to_string()));
    }

    #[test]
    fn test_handle_validation_error() {
        assert!(handle_validation_error_impl("bad", &HandleValidationError::None).is_none());
        assert!(handle_validation_error_impl("bad", &HandleValidationError::Bool(false)).is_none());
        assert_eq!(
            handle_validation_error_impl("bad", &HandleValidationError::Bool(true)),
            Some("Tool input validation error".to_string())
        );
        assert_eq!(
            handle_validation_error_impl("bad", &HandleValidationError::Message("m".to_string())),
            Some("m".to_string())
        );
        let handler = HandleValidationError::Handler(Arc::new(|e: &str| format!("invalid: {e}")));
        assert_eq!(
            handle_validation_error_impl("bad", &handler),
            Some("invalid: bad".to_string())
        );
    }

    #[test]
    fn test_stringify_content() {
        // Strings are used verbatim, everything else is compact JSON.
        assert_eq!(stringify_content(&serde_json::json!("plain")), "plain");
        assert_eq!(stringify_content(&serde_json::json!({"a": 1})), r#"{"a":1}"#);
        assert_eq!(stringify_content(&serde_json::json!([1, 2])), "[1,2]");
    }

    #[test]
    fn test_is_message_content_type() {
        assert!(is_message_content_type(&serde_json::json!("hi")));
        assert!(is_message_content_type(&serde_json::json!([
            "hi",
            {"type": "text", "text": "x"}
        ])));
        assert!(!is_message_content_type(&serde_json::json!([{"type": "unknown"}])));
        assert!(!is_message_content_type(&serde_json::json!(3)));
        assert!(!is_message_content_block(&serde_json::json!({"text": "no type"})));
    }

    #[test]
    fn test_format_output_without_tool_call_id() {
        let out = format_output(serde_json::json!("done"), None, None, "tool", "success");
        assert!(matches!(out, ToolOutput::String(ref s) if s == "done"));

        let out = format_output(serde_json::json!({"k": 1}), None, None, "tool", "success");
        assert!(matches!(out, ToolOutput::Json(_)));
    }

    #[test]
    fn test_prep_run_args_unwraps_tool_call() {
        let call = ToolCall::with_id("call-9", "search", serde_json::json!({"q": "rust"}));
        let (input, call_id, _config) = prep_run_args(ToolInput::ToolCall(call), None);
        assert_eq!(call_id.as_deref(), Some("call-9"));
        match input {
            ToolInput::Dict(d) => assert_eq!(d.get("q"), Some(&serde_json::json!("rust"))),
            other => panic!("Expected Dict variant, got {other:?}"),
        }
    }

    #[test]
    fn test_prep_run_args_passes_through_plain_input() {
        let (input, call_id, _config) = prep_run_args(ToolInput::from("plain"), None);
        assert!(call_id.is_none());
        assert!(matches!(input, ToolInput::String(ref s) if s == "plain"));
    }
}