langgraph_prebuilt/
tool_node.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value as JsonValue;
6use langgraph_checkpoint::config::RunnableConfig;
7use langgraph::runnable::{Runnable, RunnableError};
8
9use crate::traits::{BaseTool, ToolError};
10use crate::types::{Message, ToolCall};
11
12enum ToolCallResult {
14 Message(Message),
16 Command {
18 tool_call_id: String,
20 extra_messages: Vec<JsonValue>,
22 state_update: serde_json::Map<String, JsonValue>,
24 },
25}
26
/// Text reported when a tool call names a tool that is not registered.
/// Placeholders: `{requested_tool}` and `{available_tools}` (comma-separated).
const INVALID_TOOL_NAME_ERROR: &str =
    "Error: {requested_tool} is not a valid tool, try one of [{available_tools}].";

/// Generic tool-error template with a single `{error}` placeholder.
/// NOTE(review): not referenced anywhere in this file.
const TOOL_CALL_ERROR: &str = "Error: {error}\n Please fix your mistakes.";

/// Text of the error tool message produced when a tool fails and error
/// handling is enabled. Placeholders: `{tool_name}`, `{tool_kwargs}` (the call
/// arguments serialized as JSON) and `{error}`.
const TOOL_EXECUTION_ERROR: &str = "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again.";
31
32pub struct ToolNode {
37 tools: HashMap<String, Arc<dyn BaseTool>>,
38 handle_tool_errors: bool,
39}
40
41impl ToolNode {
42 pub fn new(tools: Vec<Arc<dyn BaseTool>>) -> Self {
44 let tool_map: HashMap<String, Arc<dyn BaseTool>> = tools
45 .into_iter()
46 .map(|t| (t.name().to_string(), t))
47 .collect();
48
49 Self {
50 tools: tool_map,
51 handle_tool_errors: true,
52 }
53 }
54
55 pub fn with_error_handling(mut self, handle: bool) -> Self {
58 self.handle_tool_errors = handle;
59 self
60 }
61
62 pub fn tool_names(&self) -> Vec<&str> {
64 self.tools.keys().map(|s| s.as_str()).collect()
65 }
66
67 fn extract_tool_calls(input: &JsonValue) -> Vec<ToolCall> {
70 let messages = match input.get("messages") {
71 Some(JsonValue::Array(arr)) => arr,
72 _ => return vec![],
73 };
74
75 for msg in messages.iter().rev() {
77 if let Some(obj) = msg.as_object() {
78 if obj.get("type").and_then(|v| v.as_str()) == Some("ai") {
79 if let Some(JsonValue::Array(calls)) = obj.get("tool_calls") {
80 return calls
81 .iter()
82 .filter_map(|tc| serde_json::from_value(tc.clone()).ok())
83 .collect();
84 }
85 }
86 }
87 }
88
89 vec![]
90 }
91}
92
93#[async_trait]
94impl Runnable for ToolNode {
95 fn invoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
96 match tokio::runtime::Handle::try_current() {
98 Ok(handle) => handle.block_on(self.ainvoke(input, config)),
99 Err(_) => {
100 let rt = tokio::runtime::Runtime::new()
101 .map_err(|e| RunnableError::Node(e.to_string()))?;
102 rt.block_on(self.ainvoke(input, config))
103 }
104 }
105 }
106
107 async fn ainvoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
108 let tool_calls = Self::extract_tool_calls(input);
109
110 if tool_calls.is_empty() {
111 return Ok(serde_json::json!({}));
112 }
113
114 let mut join_set = tokio::task::JoinSet::new();
116 for tc in tool_calls {
117 let tool = self.tools.get(&tc.name).cloned();
118 let config = config.clone();
119 let handle_errors = self.handle_tool_errors;
120 let tool_name = tc.name.clone();
121 let available_tools: Vec<String> = self.tools.keys().cloned().collect();
122
123 join_set.spawn(async move {
124 let tool = match tool {
125 Some(t) => t,
126 None => {
127 return Err(ToolError::NotFound(
128 INVALID_TOOL_NAME_ERROR
129 .replace("{requested_tool}", &tc.name)
130 .replace("{available_tools}", &available_tools.join(", ")),
131 ));
132 }
133 };
134
135 let result = tool.ainvoke(&tc.args, &config).await;
136 let tool_call_id = tc.id.clone().unwrap_or_default();
137
138 match result {
139 Ok(output) => {
140 let output = match &output {
143 JsonValue::String(s) => serde_json::from_str(s).unwrap_or(output),
144 _ => output,
145 };
146
147 if let Some(obj) = output.as_object() {
149 if obj.contains_key("update") || obj.contains_key("resume") {
150 let mut state_update = serde_json::Map::new();
151 let mut extra_messages: Vec<JsonValue> = Vec::new();
152
153 if let Some(update) = obj.get("update") {
154 if let Some(update_obj) = update.as_object() {
155 if let Some(JsonValue::Array(msgs)) = update_obj.get("messages") {
157 for msg in msgs {
158 let mut msg = msg.clone();
159 if let Some(msg_obj) = msg.as_object_mut() {
161 if msg_obj.contains_key("tool_call_id") {
162 msg_obj.insert(
163 "tool_call_id".to_string(),
164 JsonValue::String(tool_call_id.clone()),
165 );
166 }
167 }
168 extra_messages.push(msg);
169 }
170 }
171 for (k, v) in update_obj {
173 if k != "messages" {
174 state_update.insert(k.clone(), v.clone());
175 }
176 }
177 }
178 }
179
180 return Ok(ToolCallResult::Command {
181 tool_call_id,
182 extra_messages,
183 state_update,
184 });
185 }
186 }
187
188 let content = match output {
189 JsonValue::String(s) => s,
190 other => serde_json::to_string_pretty(&other).unwrap_or_else(|_| format!("{:?}", other)),
191 };
192 Ok(ToolCallResult::Message(Message::tool_result(tool_call_id, content)))
193 }
194 Err(crate::traits::ToolError::Interrupt(interrupt)) => {
195 Err(crate::traits::ToolError::Interrupt(interrupt))
196 }
197 Err(e) => {
198 if handle_errors {
199 let error_msg = TOOL_EXECUTION_ERROR
200 .replace("{tool_name}", &tool_name)
201 .replace("{tool_kwargs}", &serde_json::to_string(&tc.args).unwrap_or_default())
202 .replace("{error}", &e.to_string());
203 Ok(ToolCallResult::Message(Message::tool_error(tool_call_id, error_msg)))
204 } else {
205 Err(e)
206 }
207 }
208 }
209 });
210 }
211
212 let mut messages: Vec<JsonValue> = Vec::new();
214 let mut state_updates: serde_json::Map<String, JsonValue> = serde_json::Map::new();
215
216 while let Some(result) = join_set.join_next().await {
217 let msg_result = result.map_err(|e| RunnableError::Node(e.to_string()))?;
218 match msg_result {
219 Ok(ToolCallResult::Message(msg)) => {
220 messages.push(serde_json::to_value(msg).map_err(|e| RunnableError::Node(e.to_string()))?);
221 }
222 Ok(ToolCallResult::Command { tool_call_id, extra_messages, state_update }) => {
223 if extra_messages.is_empty() {
224 let default_msg = Message::tool_result(tool_call_id, "Command processed");
226 messages.push(serde_json::to_value(default_msg).map_err(|e| RunnableError::Node(e.to_string()))?);
227 } else {
228 messages.extend(extra_messages);
229 }
230 for (k, v) in state_update {
232 state_updates.insert(k, v);
233 }
234 }
235 Err(ToolError::Interrupt(interrupt)) => {
236 return Err(RunnableError::Interrupt(interrupt));
237 }
238 Err(e) => {
239 return Err(RunnableError::Node(e.to_string()));
240 }
241 }
242 }
243
244 let mut result = serde_json::json!({ "messages": messages });
246 if let Some(obj) = result.as_object_mut() {
247 for (k, v) in state_updates {
248 obj.insert(k, v);
249 }
250 }
251
252 Ok(result)
253 }
254
255 fn name(&self) -> &str {
256 "ToolNode"
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap a list of message objects into the node's `{"messages": [...]}` input.
    fn conversation(messages: Vec<JsonValue>) -> JsonValue {
        serde_json::json!({ "messages": messages })
    }

    #[test]
    fn test_extract_tool_calls() {
        let input = conversation(vec![
            serde_json::json!({"type": "human", "content": "Search for cats"}),
            serde_json::json!({"type": "ai", "content": "", "tool_calls": [
                {"name": "search", "args": {"query": "cats"}, "id": "call_1"}
            ]}),
        ]);

        let extracted = ToolNode::extract_tool_calls(&input);

        assert_eq!(extracted.len(), 1);
        assert_eq!(extracted[0].name, "search");
    }

    #[test]
    fn test_extract_no_tool_calls() {
        let input = conversation(vec![
            serde_json::json!({"type": "human", "content": "Hello"}),
            serde_json::json!({"type": "ai", "content": "Hi there!"}),
        ]);

        assert!(ToolNode::extract_tool_calls(&input).is_empty());
    }
}