//! langgraph_prebuilt/types.rs
//!
//! Message, tool-call, and content types, plus the `add_messages` reducer.

use std::collections::HashSet;
use std::fmt;

use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
4
/// Serde default for a tool message's `status` field: `"success"`.
fn default_tool_status() -> String {
    String::from("success")
}
8
9/// A tool call requested by the AI model.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ToolCall {
12    /// The name of the tool to call.
13    pub name: String,
14    /// The arguments to pass to the tool, as a JSON object.
15    pub args: JsonValue,
16    /// A unique identifier for this tool call.
17    #[serde(default)]
18    pub id: Option<String>,
19}
20
21/// Message types for the agent system.
22///
23/// Mirrors the LangChain message types: HumanMessage, AIMessage,
24/// SystemMessage, ToolMessage, and RemoveMessage.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum Message {
28    /// A message from the human user.
29    Human {
30        content: MessageContent,
31        #[serde(default)]
32        id: Option<String>,
33    },
34    /// A message from the AI assistant.
35    Ai {
36        content: MessageContent,
37        #[serde(default)]
38        tool_calls: Vec<ToolCall>,
39        #[serde(default)]
40        id: Option<String>,
41        /// Token usage from the LLM API response, if available.
42        #[serde(default)]
43        usage: Option<crate::traits::LlmUsage>,
44        /// Thinking/reasoning content from models that support it (e.g., DeepSeek, o1/o3).
45        #[serde(default, skip_serializing_if = "Option::is_none")]
46        thinking: Option<String>,
47    },
48    /// A system message providing instructions.
49    System {
50        content: MessageContent,
51        #[serde(default)]
52        id: Option<String>,
53    },
54    /// A message containing the result of a tool call.
55    Tool {
56        content: MessageContent,
57        tool_call_id: String,
58        #[serde(default)]
59        name: Option<String>,
60        #[serde(default)]
61        id: Option<String>,
62        /// Status of the tool call: "success" or "error"
63        #[serde(default = "default_tool_status")]
64        status: String,
65    },
66    /// A message that removes a previous message by ID.
67    Remove {
68        id: String,
69    },
70}
71
72/// Content of a message - can be a simple string or structured content blocks.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74#[serde(untagged)]
75pub enum MessageContent {
76    /// Simple text content.
77    Text(String),
78    /// Structured content blocks (for multimodal messages).
79    Blocks(Vec<ContentBlock>),
80}
81
82/// A block of content within a message (text, image, etc.).
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(tag = "type", rename_all = "snake_case")]
85pub enum ContentBlock {
86    Text {
87        text: String,
88    },
89    ImageUrl {
90        image_url: ImageUrl,
91    },
92}
93
94/// An image URL reference.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct ImageUrl {
97    pub url: String,
98    #[serde(default)]
99    pub detail: Option<String>,
100}
101
102impl Message {
103    /// Get the text content of the message, if any.
104    pub fn text(&self) -> Option<&str> {
105        match self {
106            Message::Human { content, .. }
107            | Message::Ai { content, .. }
108            | Message::System { content, .. }
109            | Message::Tool { content, .. } => match content {
110                MessageContent::Text(s) => Some(s.as_str()),
111                MessageContent::Blocks(blocks) => {
112                    // Return the first text block
113                    blocks.iter().find_map(|b| match b {
114                        ContentBlock::Text { text } => Some(text.as_str()),
115                        _ => None,
116                    })
117                }
118            },
119            Message::Remove { .. } => None,
120        }
121    }
122
123    /// Get the message ID, if any.
124    pub fn id(&self) -> Option<&str> {
125        match self {
126            Message::Human { id, .. }
127            | Message::Ai { id, .. }
128            | Message::System { id, .. }
129            | Message::Tool { id, .. } => id.as_deref(),
130            Message::Remove { id } => Some(id.as_str()),
131        }
132    }
133
134    /// Check if this message has tool calls.
135    pub fn has_tool_calls(&self) -> bool {
136        match self {
137            Message::Ai { tool_calls, .. } => !tool_calls.is_empty(),
138            _ => false,
139        }
140    }
141
142    /// Get tool calls from the message.
143    pub fn tool_calls(&self) -> &[ToolCall] {
144        match self {
145            Message::Ai { tool_calls, .. } => tool_calls,
146            _ => &[],
147        }
148    }
149
150    /// Create a human message.
151    pub fn human(content: impl Into<String>) -> Self {
152        Message::Human {
153            content: MessageContent::Text(content.into()),
154            id: None,
155        }
156    }
157
158    /// Create an AI message.
159    pub fn ai(content: impl Into<String>) -> Self {
160        Message::Ai {
161            content: MessageContent::Text(content.into()),
162            tool_calls: vec![],
163            id: None,
164            usage: None,
165            thinking: None,
166        }
167    }
168
169    /// Create an AI message with tool calls.
170    pub fn ai_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
171        Message::Ai {
172            content: MessageContent::Text(content.into()),
173            tool_calls,
174            id: None,
175            usage: None,
176            thinking: None,
177        }
178    }
179
180    /// Create an AI message with token usage information.
181    pub fn ai_with_usage(content: impl Into<String>, usage: crate::traits::LlmUsage) -> Self {
182        Message::Ai {
183            content: MessageContent::Text(content.into()),
184            tool_calls: vec![],
185            id: None,
186            usage: Some(usage),
187            thinking: None,
188        }
189    }
190
191    /// Create an AI message with tool calls and token usage.
192    pub fn ai_with_tool_calls_and_usage(
193        content: impl Into<String>,
194        tool_calls: Vec<ToolCall>,
195        usage: crate::traits::LlmUsage,
196    ) -> Self {
197        Message::Ai {
198            content: MessageContent::Text(content.into()),
199            tool_calls,
200            id: None,
201            usage: Some(usage),
202            thinking: None,
203        }
204    }
205
206    /// Get token usage from the message, if available.
207    pub fn usage(&self) -> Option<&crate::traits::LlmUsage> {
208        match self {
209            Message::Ai { usage, .. } => usage.as_ref(),
210            _ => None,
211        }
212    }
213
214    /// Get the thinking/reasoning content, if available.
215    pub fn thinking(&self) -> Option<&str> {
216        match self {
217            Message::Ai { thinking, .. } => thinking.as_deref(),
218            _ => None,
219        }
220    }
221
222    /// Create an AI message with thinking/reasoning content.
223    pub fn ai_with_thinking(
224        content: impl Into<String>,
225        thinking: impl Into<String>,
226    ) -> Self {
227        Message::Ai {
228            content: MessageContent::Text(content.into()),
229            tool_calls: vec![],
230            id: None,
231            usage: None,
232            thinking: Some(thinking.into()),
233        }
234    }
235
236    /// Create a system message.
237    pub fn system(content: impl Into<String>) -> Self {
238        Message::System {
239            content: MessageContent::Text(content.into()),
240            id: None,
241        }
242    }
243
244    /// Create a tool result message with success status.
245    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
246        Message::Tool {
247            content: MessageContent::Text(content.into()),
248            tool_call_id: tool_call_id.into(),
249            name: None,
250            id: None,
251            status: "success".to_string(),
252        }
253    }
254
255    /// Create a tool error message.
256    pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
257        Message::Tool {
258            content: MessageContent::Text(content.into()),
259            tool_call_id: tool_call_id.into(),
260            name: None,
261            id: None,
262            status: "error".to_string(),
263        }
264    }
265}
266
267impl fmt::Display for Message {
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        match self {
270            Message::Human { content, .. } => write!(f, "[Human] {}", content_text(content)),
271            Message::Ai { content, tool_calls, thinking, .. } => {
272                let text = content_text(content);
273                if let Some(t) = thinking {
274                    write!(f, "[Thinking] {}\n[AI] {}", t, text)?;
275                    if !tool_calls.is_empty() {
276                        let calls: Vec<String> = tool_calls
277                            .iter()
278                            .map(|tc| format!("{}({})", tc.name, tc.args))
279                            .collect();
280                        write!(f, " → {}", calls.join(", "))?;
281                    }
282                    Ok(())
283                } else if tool_calls.is_empty() {
284                    write!(f, "[AI] {}", text)
285                } else {
286                    let calls: Vec<String> = tool_calls
287                        .iter()
288                        .map(|tc| format!("{}({})", tc.name, tc.args))
289                        .collect();
290                    if text.is_empty() {
291                        write!(f, "[AI] → {}", calls.join(", "))
292                    } else {
293                        write!(f, "[AI] {} → {}", text, calls.join(", "))
294                    }
295                }
296            }
297            Message::System { content, .. } => write!(f, "[System] {}", content_text(content)),
298            Message::Tool { content, name, status, .. } => {
299                let tool_name = name.as_deref().unwrap_or("tool");
300                let text = content_text(content);
301                if status == "error" {
302                    write!(f, "[Tool:{}] ERROR: {}", tool_name, text)
303                } else {
304                    write!(f, "[Tool:{}] {}", tool_name, text)
305                }
306            }
307            Message::Remove { id } => write!(f, "[Remove:{}]", id),
308        }
309    }
310}
311
312fn content_text(content: &MessageContent) -> &str {
313    match content {
314        MessageContent::Text(s) => s.as_str(),
315        MessageContent::Blocks(blocks) => blocks
316            .iter()
317            .find_map(|b| match b {
318                ContentBlock::Text { text } => Some(text.as_str()),
319                _ => None,
320            })
321            .unwrap_or(""),
322    }
323}
324
325impl From<String> for MessageContent {
326    fn from(s: String) -> Self {
327        MessageContent::Text(s)
328    }
329}
330
331impl From<&str> for MessageContent {
332    fn from(s: &str) -> Self {
333        MessageContent::Text(s.to_string())
334    }
335}
336
337/// Merge function for messages: appends new messages to existing ones.
338/// This is the default reducer for the `messages` field in agent states.
339pub fn add_messages(current: JsonValue, update: JsonValue) -> JsonValue {
340    // Check if the update is a "Reset" signal: {"reset": true, "messages": [...]}
341    if let Some(obj) = update.as_object() {
342        if obj.get("reset").and_then(|v| v.as_bool()) == Some(true) {
343            if let Some(msgs) = obj.get("messages").and_then(|v| v.as_array()) {
344                return JsonValue::Array(msgs.clone());
345            }
346        }
347    }
348
349    let messages: Vec<JsonValue> = match current {
350        JsonValue::Array(arr) => arr,
351        _ => vec![],
352    };
353
354    let new_messages: Vec<JsonValue> = match update {
355        JsonValue::Array(arr) => arr,
356        other => vec![other],
357    };
358
359    // Handle RemoveMessage by filtering out messages with matching IDs
360    let mut result: Vec<JsonValue> = Vec::new();
361    let mut remove_ids: Vec<String> = Vec::new();
362
363    // Collect IDs to remove
364    for msg in &new_messages {
365        if let Some(obj) = msg.as_object() {
366            if obj.get("type").and_then(|v| v.as_str()) == Some("remove") {
367                if let Some(id) = obj.get("id").and_then(|v| v.as_str()) {
368                    remove_ids.push(id.to_string());
369                }
370            }
371        }
372    }
373
374    // Add existing messages, skipping removed ones
375    for msg in messages {
376        if let Some(id) = msg.get("id").and_then(|v| v.as_str()) {
377            if remove_ids.contains(&id.to_string()) {
378                continue;
379            }
380        }
381        result.push(msg);
382    }
383
384    // Add new non-remove messages
385    for msg in new_messages {
386        if let Some(obj) = msg.as_object() {
387            if obj.get("type").and_then(|v| v.as_str()) == Some("remove") {
388                continue;
389            }
390        }
391        result.push(msg);
392    }
393
394    JsonValue::Array(result)
395}
396
397/// Merge function for messages with reference signature.
398///
399/// This is the version compatible with `#[channel(reducer = "...")]` in the
400/// derive macro, which expects `fn(&JsonValue, &JsonValue) -> JsonValue`.
401///
402/// ```ignore
403/// #[derive(StateGraph)]
404/// struct MyState {
405///     #[channel(reducer = "add_messages_ref")]
406///     messages: Vec<Message>,
407/// }
408/// ```
409pub fn add_messages_ref(current: &JsonValue, update: &JsonValue) -> JsonValue {
410    add_messages(current.clone(), update.clone())
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A human message exposes its text and has no ID by default.
    #[test]
    fn test_message_human() {
        let greeting = Message::human("Hello");
        assert_eq!(greeting.text(), Some("Hello"));
        assert!(greeting.id().is_none());
    }

    /// Tool calls passed to the constructor are reported by the accessors.
    #[test]
    fn test_message_ai_with_tool_calls() {
        let search_call = ToolCall {
            name: "search".into(),
            args: json!({"query": "test"}),
            id: Some("call_1".into()),
        };
        let reply = Message::ai_with_tool_calls("", vec![search_call]);

        assert!(reply.has_tool_calls());
        let calls = reply.tool_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
    }

    /// Updates are appended to the existing history.
    #[test]
    fn test_add_messages() {
        let history = json!([
            {"type": "human", "content": "Hi"},
        ]);
        let incoming = json!([
            {"type": "ai", "content": "Hello"},
        ]);

        let merged = add_messages(history, incoming);
        assert_eq!(merged.as_array().unwrap().len(), 2);
    }

    /// A remove entry deletes the matching message and is not stored itself.
    #[test]
    fn test_remove_message() {
        let history = json!([
            {"type": "human", "content": "Hi", "id": "msg1"},
            {"type": "ai", "content": "Hello", "id": "msg2"},
        ]);
        let incoming = json!([
            {"type": "remove", "id": "msg1"},
        ]);

        let merged = add_messages(history, incoming);
        let remaining = merged.as_array().unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0]["id"], "msg2");
    }

    /// A message survives a JSON round trip.
    #[test]
    fn test_message_serialization() {
        let original = Message::human("Hello world");
        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("human"));
        assert!(encoded.contains("Hello world"));

        let decoded: Message = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.text(), Some("Hello world"));
    }
}
474}