agent_chain_core/messages/
tool.rs

1//! Tool-related message types.
2//!
3//! This module contains types for tool calls and tool messages,
4//! mirroring `langchain_core.messages.tool`.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::utils::uuid7;
10
11#[cfg(feature = "specta")]
12use specta::Type;
13
/// Marker trait for values a tool may hand back as-is.
///
/// When a custom tool is invoked with a `ToolCall`, any output whose type does not
/// implement `ToolOutputMixin` is converted to a string and wrapped in a
/// `ToolMessage`. Implementing this trait opts a type out of that coercion.
/// The trait has no methods; implementing it is purely a declaration.
pub trait ToolOutputMixin {}
20
21/// A tool call made by the AI model.
22///
23/// Represents an AI's request to call a tool. This corresponds to
24/// `ToolCall` in LangChain Python.
25#[cfg_attr(feature = "specta", derive(Type))]
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub struct ToolCall {
28    /// Unique identifier for this tool call
29    id: String,
30    /// Name of the tool to call
31    name: String,
32    /// Arguments for the tool call as a JSON object
33    args: serde_json::Value,
34}
35
36impl ToolCall {
37    /// Create a new tool call.
38    pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
39        Self {
40            id: uuid7(None).to_string(),
41            name: name.into(),
42            args,
43        }
44    }
45
46    /// Create a new tool call with a specific ID.
47    pub fn with_id(
48        id: impl Into<String>,
49        name: impl Into<String>,
50        args: serde_json::Value,
51    ) -> Self {
52        Self {
53            id: id.into(),
54            name: name.into(),
55            args,
56        }
57    }
58
59    /// Get the tool call ID.
60    pub fn id(&self) -> &str {
61        &self.id
62    }
63
64    /// Get the tool name.
65    pub fn name(&self) -> &str {
66        &self.name
67    }
68
69    /// Get the tool arguments.
70    pub fn args(&self) -> &serde_json::Value {
71        &self.args
72    }
73}
74
75/// A tool call chunk (yielded when streaming).
76///
77/// When merging tool call chunks, all string attributes are concatenated.
78/// Chunks are only merged if their values of `index` are equal and not None.
79#[cfg_attr(feature = "specta", derive(Type))]
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
81pub struct ToolCallChunk {
82    /// The name of the tool to be called (may be partial during streaming)
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub name: Option<String>,
85    /// The arguments to the tool call (may be partial JSON string during streaming)
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub args: Option<String>,
88    /// An identifier associated with the tool call
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub id: Option<String>,
91    /// The index of the tool call in a sequence
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub index: Option<i32>,
94}
95
96impl ToolCallChunk {
97    /// Create a new tool call chunk.
98    pub fn new(
99        name: Option<String>,
100        args: Option<String>,
101        id: Option<String>,
102        index: Option<i32>,
103    ) -> Self {
104        Self {
105            name,
106            args,
107            id,
108            index,
109        }
110    }
111}
112
113/// Represents an invalid tool call that failed parsing.
114///
115/// Here we add an `error` key to surface errors made during generation
116/// (e.g., invalid JSON arguments.)
117#[cfg_attr(feature = "specta", derive(Type))]
118#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
119pub struct InvalidToolCall {
120    /// The name of the tool to be called
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub name: Option<String>,
123    /// The arguments to the tool call (unparsed string)
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub args: Option<String>,
126    /// An identifier associated with the tool call
127    #[serde(skip_serializing_if = "Option::is_none")]
128    pub id: Option<String>,
129    /// An error message associated with the tool call
130    #[serde(skip_serializing_if = "Option::is_none")]
131    pub error: Option<String>,
132}
133
134impl InvalidToolCall {
135    /// Create a new invalid tool call.
136    pub fn new(
137        name: Option<String>,
138        args: Option<String>,
139        id: Option<String>,
140        error: Option<String>,
141    ) -> Self {
142        Self {
143            name,
144            args,
145            id,
146            error,
147        }
148    }
149}
150
/// A tool message containing the result of a tool call.
///
/// `ToolMessage` objects contain the result of a tool invocation. Typically, the result
/// is encoded inside the `content` field, and `tool_call_id` links the message back to
/// the [`ToolCall`] that requested it.
///
/// Serialization notes: `name` and `artifact` are omitted when `None`, while `id` has no
/// skip attribute and is therefore always emitted (as `null` when absent). When
/// deserializing, a missing `status` falls back to `default_status` and missing
/// metadata maps fall back to empty maps.
///
/// This corresponds to `ToolMessage` in LangChain Python.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolMessage {
    /// The tool result content
    content: String,
    /// The ID of the tool call this message is responding to
    tool_call_id: String,
    /// Optional unique identifier of this message (serialized as `null` when absent)
    id: Option<String>,
    /// Optional name for the tool
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Status of the tool invocation (defaults via `default_status` when missing)
    #[serde(default = "default_status")]
    status: ToolStatus,
    /// Artifact of the tool execution which is not meant to be sent to the model.
    ///
    /// Should only be specified if it is different from the message content, e.g. if only
    /// a subset of the full tool output is being passed as message content but the full
    /// output is needed in other parts of the code.
    #[serde(skip_serializing_if = "Option::is_none")]
    artifact: Option<serde_json::Value>,
    /// Additional metadata (empty when missing from the input)
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
    /// Response metadata (empty when missing from the input)
    #[serde(default)]
    response_metadata: HashMap<String, serde_json::Value>,
}
186
187fn default_status() -> ToolStatus {
188    ToolStatus::Success
189}
190
191/// Status of a tool invocation.
192#[cfg_attr(feature = "specta", derive(Type))]
193#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
194#[serde(rename_all = "lowercase")]
195pub enum ToolStatus {
196    #[default]
197    Success,
198    Error,
199}
200
201impl ToolMessage {
202    /// Create a new tool message.
203    pub fn new(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
204        Self {
205            content: content.into(),
206            tool_call_id: tool_call_id.into(),
207            id: Some(uuid7(None).to_string()),
208            name: None,
209            status: ToolStatus::Success,
210            artifact: None,
211            additional_kwargs: HashMap::new(),
212            response_metadata: HashMap::new(),
213        }
214    }
215
216    /// Create a new tool message with an explicit ID.
217    ///
218    /// Use this when deserializing or reconstructing messages where the ID must be preserved.
219    pub fn with_id(
220        id: impl Into<String>,
221        content: impl Into<String>,
222        tool_call_id: impl Into<String>,
223    ) -> Self {
224        Self {
225            content: content.into(),
226            tool_call_id: tool_call_id.into(),
227            id: Some(id.into()),
228            name: None,
229            status: ToolStatus::Success,
230            artifact: None,
231            additional_kwargs: HashMap::new(),
232            response_metadata: HashMap::new(),
233        }
234    }
235
236    /// Create a new tool message with error status.
237    pub fn error(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
238        Self {
239            content: content.into(),
240            tool_call_id: tool_call_id.into(),
241            id: Some(uuid7(None).to_string()),
242            name: None,
243            status: ToolStatus::Error,
244            artifact: None,
245            additional_kwargs: HashMap::new(),
246            response_metadata: HashMap::new(),
247        }
248    }
249
250    /// Create a new tool message with an artifact.
251    pub fn with_artifact(
252        content: impl Into<String>,
253        tool_call_id: impl Into<String>,
254        artifact: serde_json::Value,
255    ) -> Self {
256        Self {
257            content: content.into(),
258            tool_call_id: tool_call_id.into(),
259            id: Some(uuid7(None).to_string()),
260            name: None,
261            status: ToolStatus::Success,
262            artifact: Some(artifact),
263            additional_kwargs: HashMap::new(),
264            response_metadata: HashMap::new(),
265        }
266    }
267
268    /// Set the name for this tool message.
269    pub fn with_name(mut self, name: impl Into<String>) -> Self {
270        self.name = Some(name.into());
271        self
272    }
273
274    /// Get the message content.
275    pub fn content(&self) -> &str {
276        &self.content
277    }
278
279    /// Get the tool call ID this message responds to.
280    pub fn tool_call_id(&self) -> &str {
281        &self.tool_call_id
282    }
283
284    /// Get the message ID.
285    pub fn id(&self) -> Option<&str> {
286        self.id.as_deref()
287    }
288
289    /// Get the tool name.
290    pub fn name(&self) -> Option<&str> {
291        self.name.as_deref()
292    }
293
294    /// Get the status of the tool invocation.
295    pub fn status(&self) -> &ToolStatus {
296        &self.status
297    }
298
299    /// Get the artifact if present.
300    pub fn artifact(&self) -> Option<&serde_json::Value> {
301        self.artifact.as_ref()
302    }
303
304    /// Get additional kwargs.
305    pub fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
306        &self.additional_kwargs
307    }
308
309    /// Get response metadata.
310    pub fn response_metadata(&self) -> &HashMap<String, serde_json::Value> {
311        &self.response_metadata
312    }
313}
314
315impl ToolOutputMixin for ToolMessage {}
316
317/// Tool message chunk (yielded when streaming).
318///
319/// This corresponds to `ToolMessageChunk` in LangChain Python.
320#[cfg_attr(feature = "specta", derive(Type))]
321#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
322pub struct ToolMessageChunk {
323    /// The tool result content (may be partial during streaming)
324    content: String,
325    /// The ID of the tool call this message is responding to
326    tool_call_id: String,
327    /// Optional unique identifier
328    id: Option<String>,
329    /// Optional name for the tool
330    #[serde(skip_serializing_if = "Option::is_none")]
331    name: Option<String>,
332    /// Status of the tool invocation
333    #[serde(default = "default_status")]
334    status: ToolStatus,
335    /// Artifact of the tool execution
336    #[serde(skip_serializing_if = "Option::is_none")]
337    artifact: Option<serde_json::Value>,
338    /// Additional metadata
339    #[serde(default)]
340    additional_kwargs: HashMap<String, serde_json::Value>,
341    /// Response metadata
342    #[serde(default)]
343    response_metadata: HashMap<String, serde_json::Value>,
344}
345
346impl ToolMessageChunk {
347    /// Create a new tool message chunk.
348    pub fn new(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
349        Self {
350            content: content.into(),
351            tool_call_id: tool_call_id.into(),
352            id: None,
353            name: None,
354            status: ToolStatus::Success,
355            artifact: None,
356            additional_kwargs: HashMap::new(),
357            response_metadata: HashMap::new(),
358        }
359    }
360
361    /// Get the message content.
362    pub fn content(&self) -> &str {
363        &self.content
364    }
365
366    /// Get the tool call ID.
367    pub fn tool_call_id(&self) -> &str {
368        &self.tool_call_id
369    }
370
371    /// Get the message ID.
372    pub fn id(&self) -> Option<&str> {
373        self.id.as_deref()
374    }
375
376    /// Get the tool name.
377    pub fn name(&self) -> Option<&str> {
378        self.name.as_deref()
379    }
380
381    /// Get the status.
382    pub fn status(&self) -> &ToolStatus {
383        &self.status
384    }
385
386    /// Get the artifact.
387    pub fn artifact(&self) -> Option<&serde_json::Value> {
388        self.artifact.as_ref()
389    }
390
391    /// Concatenate this chunk with another chunk.
392    pub fn concat(&self, other: &ToolMessageChunk) -> ToolMessageChunk {
393        let mut content = self.content.clone();
394        content.push_str(&other.content);
395
396        // Merge status (error takes precedence)
397        let status = if self.status == ToolStatus::Error || other.status == ToolStatus::Error {
398            ToolStatus::Error
399        } else {
400            ToolStatus::Success
401        };
402
403        ToolMessageChunk {
404            content,
405            tool_call_id: self.tool_call_id.clone(),
406            id: self.id.clone().or_else(|| other.id.clone()),
407            name: self.name.clone().or_else(|| other.name.clone()),
408            status,
409            artifact: self.artifact.clone().or_else(|| other.artifact.clone()),
410            additional_kwargs: self.additional_kwargs.clone(),
411            response_metadata: self.response_metadata.clone(),
412        }
413    }
414
415    /// Convert this chunk to a complete ToolMessage.
416    pub fn to_message(&self) -> ToolMessage {
417        ToolMessage {
418            content: self.content.clone(),
419            tool_call_id: self.tool_call_id.clone(),
420            id: self.id.clone(),
421            name: self.name.clone(),
422            status: self.status.clone(),
423            artifact: self.artifact.clone(),
424            additional_kwargs: self.additional_kwargs.clone(),
425            response_metadata: self.response_metadata.clone(),
426        }
427    }
428}
429
430impl std::ops::Add for ToolMessageChunk {
431    type Output = ToolMessageChunk;
432
433    fn add(self, other: ToolMessageChunk) -> ToolMessageChunk {
434        self.concat(&other)
435    }
436}
437
438/// Factory function to create a tool call.
439///
440/// This corresponds to the `tool_call` function in LangChain Python.
441pub fn tool_call(name: impl Into<String>, args: serde_json::Value, id: Option<String>) -> ToolCall {
442    match id {
443        Some(id) => ToolCall::with_id(id, name, args),
444        None => ToolCall::new(name, args),
445    }
446}
447
448/// Factory function to create a tool call chunk.
449///
450/// This corresponds to the `tool_call_chunk` function in LangChain Python.
451pub fn tool_call_chunk(
452    name: Option<String>,
453    args: Option<String>,
454    id: Option<String>,
455    index: Option<i32>,
456) -> ToolCallChunk {
457    ToolCallChunk::new(name, args, id, index)
458}
459
460/// Factory function to create an invalid tool call.
461///
462/// This corresponds to the `invalid_tool_call` function in LangChain Python.
463pub fn invalid_tool_call(
464    name: Option<String>,
465    args: Option<String>,
466    id: Option<String>,
467    error: Option<String>,
468) -> InvalidToolCall {
469    InvalidToolCall::new(name, args, id, error)
470}
471
472/// Best-effort parsing of tools from raw tool call dictionaries.
473///
474/// This corresponds to the `default_tool_parser` function in LangChain Python.
475pub fn default_tool_parser(
476    raw_tool_calls: &[serde_json::Value],
477) -> (Vec<ToolCall>, Vec<InvalidToolCall>) {
478    let mut tool_calls = Vec::new();
479    let mut invalid_tool_calls = Vec::new();
480
481    for raw_tool_call in raw_tool_calls {
482        let function = match raw_tool_call.get("function") {
483            Some(f) => f,
484            None => continue,
485        };
486
487        let function_name = function
488            .get("name")
489            .and_then(|n| n.as_str())
490            .unwrap_or("")
491            .to_string();
492
493        let arguments_str = function
494            .get("arguments")
495            .and_then(|a| a.as_str())
496            .unwrap_or("{}");
497
498        let id = raw_tool_call
499            .get("id")
500            .and_then(|i| i.as_str())
501            .map(|s| s.to_string());
502
503        match serde_json::from_str::<serde_json::Value>(arguments_str) {
504            Ok(args) if args.is_object() => {
505                tool_calls.push(tool_call(function_name, args, id));
506            }
507            _ => {
508                invalid_tool_calls.push(invalid_tool_call(
509                    Some(function_name),
510                    Some(arguments_str.to_string()),
511                    id,
512                    None,
513                ));
514            }
515        }
516    }
517
518    (tool_calls, invalid_tool_calls)
519}
520
521/// Best-effort parsing of tool call chunks from raw tool call dictionaries.
522///
523/// This corresponds to the `default_tool_chunk_parser` function in LangChain Python.
524pub fn default_tool_chunk_parser(raw_tool_calls: &[serde_json::Value]) -> Vec<ToolCallChunk> {
525    let mut chunks = Vec::new();
526
527    for raw_tool_call in raw_tool_calls {
528        let (function_name, function_args) = match raw_tool_call.get("function") {
529            Some(f) => (
530                f.get("name")
531                    .and_then(|n| n.as_str())
532                    .map(|s| s.to_string()),
533                f.get("arguments")
534                    .and_then(|a| a.as_str())
535                    .map(|s| s.to_string()),
536            ),
537            None => (None, None),
538        };
539
540        let id = raw_tool_call
541            .get("id")
542            .and_then(|i| i.as_str())
543            .map(|s| s.to_string());
544
545        let index = raw_tool_call
546            .get("index")
547            .and_then(|i| i.as_i64())
548            .map(|i| i as i32);
549
550        chunks.push(tool_call_chunk(function_name, function_args, id, index));
551    }
552
553    chunks
554}