Skip to main content

langgraph_prebuilt/
tool_node.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value as JsonValue;
6use langgraph_checkpoint::config::RunnableConfig;
7use langgraph::runnable::{Runnable, RunnableError};
8
9use crate::traits::{BaseTool, ToolError};
10use crate::types::{Message, ToolCall};
11
12/// Result of executing a tool call.
13enum ToolCallResult {
14    /// Normal tool message.
15    Message(Message),
16    /// A Command returned by the tool (for state updates, resume, goto).
17    Command {
18        /// The tool_call_id for this invocation.
19        tool_call_id: String,
20        /// Extra messages from the Command.update (e.g., ToolMessages with state updates).
21        extra_messages: Vec<JsonValue>,
22        /// State update fields from Command.update (excluding messages).
23        state_update: serde_json::Map<String, JsonValue>,
24    },
25}
26
/// Message sent back when a tool call names a tool that is not registered.
///
/// Placeholders: `{requested_tool}` (the requested name) and `{available_tools}`
/// (comma-separated registered names).
const INVALID_TOOL_NAME_ERROR: &str = "Error: {requested_tool} is not a valid tool, try one of [{available_tools}].";

/// Generic tool-call error template. Placeholder: `{error}`.
///
/// Not referenced by this module's code; it is kept next to the other templates, so the
/// `dead_code` lint is silenced explicitly instead of surfacing as a build warning.
#[allow(dead_code)]
const TOOL_CALL_ERROR: &str = "Error: {error}\n Please fix your mistakes.";

/// Message sent back when a tool fails and error handling is enabled.
///
/// Placeholders: `{tool_name}`, `{tool_kwargs}` (the JSON-encoded call arguments) and
/// `{error}` (the error's `Display` text).
const TOOL_EXECUTION_ERROR: &str = "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again.";
31
32/// A node that executes tool calls from the AI's response.
33///
34/// ToolNode reads tool calls from the last AI message and executes them
35/// in parallel, returning the results as tool messages.
36pub struct ToolNode {
37    tools: HashMap<String, Arc<dyn BaseTool>>,
38    handle_tool_errors: bool,
39}
40
41impl ToolNode {
42    /// Create a new ToolNode with the given tools.
43    pub fn new(tools: Vec<Arc<dyn BaseTool>>) -> Self {
44        let tool_map: HashMap<String, Arc<dyn BaseTool>> = tools
45            .into_iter()
46            .map(|t| (t.name().to_string(), t))
47            .collect();
48
49        Self {
50            tools: tool_map,
51            handle_tool_errors: true,
52        }
53    }
54
55    /// Set whether to handle tool errors gracefully (returning error messages
56    /// instead of propagating).
57    pub fn with_error_handling(mut self, handle: bool) -> Self {
58        self.handle_tool_errors = handle;
59        self
60    }
61
62    /// Get the list of available tool names.
63    pub fn tool_names(&self) -> Vec<&str> {
64        self.tools.keys().map(|s| s.as_str()).collect()
65    }
66
67    /// Extract tool calls from the input state.
68    /// Expects a JSON object with a "messages" array containing AI messages with tool calls.
69    fn extract_tool_calls(input: &JsonValue) -> Vec<ToolCall> {
70        let messages = match input.get("messages") {
71            Some(JsonValue::Array(arr)) => arr,
72            _ => return vec![],
73        };
74
75        // Get the last AI message with tool calls
76        for msg in messages.iter().rev() {
77            if let Some(obj) = msg.as_object() {
78                if obj.get("type").and_then(|v| v.as_str()) == Some("ai") {
79                    if let Some(JsonValue::Array(calls)) = obj.get("tool_calls") {
80                        return calls
81                            .iter()
82                            .filter_map(|tc| serde_json::from_value(tc.clone()).ok())
83                            .collect();
84                    }
85                }
86            }
87        }
88
89        vec![]
90    }
91}
92
93#[async_trait]
94impl Runnable for ToolNode {
95    fn invoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
96        // Use tokio runtime for sync invocation
97        match tokio::runtime::Handle::try_current() {
98            Ok(handle) => handle.block_on(self.ainvoke(input, config)),
99            Err(_) => {
100                let rt = tokio::runtime::Runtime::new()
101                    .map_err(|e| RunnableError::Node(e.to_string()))?;
102                rt.block_on(self.ainvoke(input, config))
103            }
104        }
105    }
106
107    async fn ainvoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
108        let tool_calls = Self::extract_tool_calls(input);
109
110        if tool_calls.is_empty() {
111            return Ok(serde_json::json!({}));
112        }
113
114        // Execute all tool calls concurrently using JoinSet
115        let mut join_set = tokio::task::JoinSet::new();
116        for tc in tool_calls {
117            let tool = self.tools.get(&tc.name).cloned();
118            let config = config.clone();
119            let handle_errors = self.handle_tool_errors;
120            let tool_name = tc.name.clone();
121            let available_tools: Vec<String> = self.tools.keys().cloned().collect();
122
123            join_set.spawn(async move {
124                let tool = match tool {
125                    Some(t) => t,
126                    None => {
127                        return Err(ToolError::NotFound(
128                            INVALID_TOOL_NAME_ERROR
129                                .replace("{requested_tool}", &tc.name)
130                                .replace("{available_tools}", &available_tools.join(", ")),
131                        ));
132                    }
133                };
134
135                let result = tool.ainvoke(&tc.args, &config).await;
136                let tool_call_id = tc.id.clone().unwrap_or_default();
137
138                match result {
139                    Ok(output) => {
140                        // If the output is a string, try to parse it as JSON
141                        // (handles tools that return serialized JSON strings)
142                        let output = match &output {
143                            JsonValue::String(s) => serde_json::from_str(s).unwrap_or(output),
144                            _ => output,
145                        };
146
147                        // Check if the tool returned a Command (has "update" or "resume" field)
148                        if let Some(obj) = output.as_object() {
149                            if obj.contains_key("update") || obj.contains_key("resume") {
150                                let mut state_update = serde_json::Map::new();
151                                let mut extra_messages: Vec<JsonValue> = Vec::new();
152
153                                if let Some(update) = obj.get("update") {
154                                    if let Some(update_obj) = update.as_object() {
155                                        // Extract messages from update, fix up tool_call_id
156                                        if let Some(JsonValue::Array(msgs)) = update_obj.get("messages") {
157                                            for msg in msgs {
158                                                let mut msg = msg.clone();
159                                                // Fix up tool_call_id in each message
160                                                if let Some(msg_obj) = msg.as_object_mut() {
161                                                    if msg_obj.contains_key("tool_call_id") {
162                                                        msg_obj.insert(
163                                                            "tool_call_id".to_string(),
164                                                            JsonValue::String(tool_call_id.clone()),
165                                                        );
166                                                    }
167                                                }
168                                                extra_messages.push(msg);
169                                            }
170                                        }
171                                        // Collect non-messages fields as state updates
172                                        for (k, v) in update_obj {
173                                            if k != "messages" {
174                                                state_update.insert(k.clone(), v.clone());
175                                            }
176                                        }
177                                    }
178                                }
179
180                                return Ok(ToolCallResult::Command {
181                                    tool_call_id,
182                                    extra_messages,
183                                    state_update,
184                                });
185                            }
186                        }
187
188                        let content = match output {
189                            JsonValue::String(s) => s,
190                            other => serde_json::to_string_pretty(&other).unwrap_or_else(|_| format!("{:?}", other)),
191                        };
192                        Ok(ToolCallResult::Message(Message::tool_result(tool_call_id, content)))
193                    }
194                    Err(crate::traits::ToolError::Interrupt(interrupt)) => {
195                        Err(crate::traits::ToolError::Interrupt(interrupt))
196                    }
197                    Err(e) => {
198                        if handle_errors {
199                            let error_msg = TOOL_EXECUTION_ERROR
200                                .replace("{tool_name}", &tool_name)
201                                .replace("{tool_kwargs}", &serde_json::to_string(&tc.args).unwrap_or_default())
202                                .replace("{error}", &e.to_string());
203                            Ok(ToolCallResult::Message(Message::tool_error(tool_call_id, error_msg)))
204                        } else {
205                            Err(e)
206                        }
207                    }
208                }
209            });
210        }
211
212        // Collect results from all spawned tasks
213        let mut messages: Vec<JsonValue> = Vec::new();
214        let mut state_updates: serde_json::Map<String, JsonValue> = serde_json::Map::new();
215
216        while let Some(result) = join_set.join_next().await {
217            let msg_result = result.map_err(|e| RunnableError::Node(e.to_string()))?;
218            match msg_result {
219                Ok(ToolCallResult::Message(msg)) => {
220                    messages.push(serde_json::to_value(msg).map_err(|e| RunnableError::Node(e.to_string()))?);
221                }
222                Ok(ToolCallResult::Command { tool_call_id, extra_messages, state_update }) => {
223                    if extra_messages.is_empty() {
224                        // No messages in Command — add a default tool response
225                        let default_msg = Message::tool_result(tool_call_id, "Command processed");
226                        messages.push(serde_json::to_value(default_msg).map_err(|e| RunnableError::Node(e.to_string()))?);
227                    } else {
228                        messages.extend(extra_messages);
229                    }
230                    // Merge state updates from Command
231                    for (k, v) in state_update {
232                        state_updates.insert(k, v);
233                    }
234                }
235                Err(ToolError::Interrupt(interrupt)) => {
236                    return Err(RunnableError::Interrupt(interrupt));
237                }
238                Err(e) => {
239                    return Err(RunnableError::Node(e.to_string()));
240                }
241            }
242        }
243
244        // Build result: messages + any state updates from Commands
245        let mut result = serde_json::json!({ "messages": messages });
246        if let Some(obj) = result.as_object_mut() {
247            for (k, v) in state_updates {
248                obj.insert(k, v);
249            }
250        }
251
252        Ok(result)
253    }
254
255    fn name(&self) -> &str {
256        "ToolNode"
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// A single AI message with one tool call yields that call.
    #[test]
    fn test_extract_tool_calls() {
        let input = serde_json::json!({
            "messages": [
                {"type": "human", "content": "Search for cats"},
                {"type": "ai", "content": "", "tool_calls": [
                    {"name": "search", "args": {"query": "cats"}, "id": "call_1"}
                ]}
            ]
        });

        let calls = ToolNode::extract_tool_calls(&input);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
    }

    /// An AI message without `tool_calls` yields nothing.
    #[test]
    fn test_extract_no_tool_calls() {
        let input = serde_json::json!({
            "messages": [
                {"type": "human", "content": "Hello"},
                {"type": "ai", "content": "Hi there!"}
            ]
        });

        let calls = ToolNode::extract_tool_calls(&input);
        assert!(calls.is_empty());
    }

    /// A missing or non-array `messages` field yields nothing instead of failing.
    #[test]
    fn test_extract_without_messages_list() {
        assert!(ToolNode::extract_tool_calls(&serde_json::json!({})).is_empty());
        assert!(
            ToolNode::extract_tool_calls(&serde_json::json!({"messages": "not a list"})).is_empty()
        );
    }

    /// Tool messages after the AI message are skipped, and call order is preserved.
    #[test]
    fn test_extract_skips_trailing_tool_messages_and_keeps_order() {
        let input = serde_json::json!({
            "messages": [
                {"type": "human", "content": "Weather in Paris and Rome?"},
                {"type": "ai", "content": "", "tool_calls": [
                    {"name": "weather", "args": {"city": "Paris"}, "id": "call_1"},
                    {"name": "forecast", "args": {"city": "Rome"}, "id": "call_2"}
                ]},
                {"type": "tool", "content": "sunny", "tool_call_id": "call_1"}
            ]
        });

        let calls = ToolNode::extract_tool_calls(&input);
        let names: Vec<&str> = calls.iter().map(|call| call.name.as_str()).collect();
        assert_eq!(names, ["weather", "forecast"]);
    }
}