agent_chain_core/messages/
ai.rs

1//! AI message type.
2//!
3//! This module contains the `AIMessage` and `AIMessageChunk` types which represent
4//! messages from an AI model. Mirrors `langchain_core.messages.ai`.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[cfg(feature = "specta")]
10use specta::Type;
11
12use super::tool::{
13    InvalidToolCall, ToolCall, ToolCallChunk, default_tool_chunk_parser, default_tool_parser,
14    invalid_tool_call, tool_call,
15};
16use crate::utils::json::parse_partial_json;
17use crate::utils::merge::{merge_dicts, merge_lists};
18use crate::utils::usage::{dict_int_add_json, dict_int_sub_floor_json};
19use crate::utils::uuid::{LC_AUTO_PREFIX, LC_ID_PREFIX, uuid7};
20
21/// Breakdown of input token counts.
22///
23/// Does *not* need to sum to full input token count. Does *not* need to have all keys.
24#[cfg_attr(feature = "specta", derive(Type))]
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
26pub struct InputTokenDetails {
27    /// Audio input tokens.
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub audio: Option<i64>,
30    /// Input tokens that were cached and there was a cache miss.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub cache_creation: Option<i64>,
33    /// Input tokens that were cached and there was a cache hit.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub cache_read: Option<i64>,
36}
37
38/// Breakdown of output token counts.
39///
40/// Does *not* need to sum to full output token count. Does *not* need to have all keys.
41#[cfg_attr(feature = "specta", derive(Type))]
42#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
43pub struct OutputTokenDetails {
44    /// Audio output tokens.
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub audio: Option<i64>,
47    /// Reasoning output tokens.
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub reasoning: Option<i64>,
50}
51
52/// Usage metadata for a message, such as token counts.
53///
54/// This is a standard representation of token usage that is consistent across models.
55#[cfg_attr(feature = "specta", derive(Type))]
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
57pub struct UsageMetadata {
58    /// Count of input (or prompt) tokens. Sum of all input token types.
59    pub input_tokens: i64,
60    /// Count of output (or completion) tokens. Sum of all output token types.
61    pub output_tokens: i64,
62    /// Total token count. Sum of `input_tokens` + `output_tokens`.
63    pub total_tokens: i64,
64    /// Breakdown of input token counts.
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub input_token_details: Option<InputTokenDetails>,
67    /// Breakdown of output token counts.
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub output_token_details: Option<OutputTokenDetails>,
70}
71
72impl UsageMetadata {
73    /// Create a new usage metadata with the given token counts.
74    pub fn new(input_tokens: i64, output_tokens: i64) -> Self {
75        Self {
76            input_tokens,
77            output_tokens,
78            total_tokens: input_tokens + output_tokens,
79            input_token_details: None,
80            output_token_details: None,
81        }
82    }
83
84    /// Add another UsageMetadata to this one.
85    pub fn add(&self, other: &UsageMetadata) -> Self {
86        Self {
87            input_tokens: self.input_tokens + other.input_tokens,
88            output_tokens: self.output_tokens + other.output_tokens,
89            total_tokens: self.total_tokens + other.total_tokens,
90            input_token_details: match (&self.input_token_details, &other.input_token_details) {
91                (Some(a), Some(b)) => Some(InputTokenDetails {
92                    audio: match (a.audio, b.audio) {
93                        (Some(x), Some(y)) => Some(x + y),
94                        (Some(x), None) | (None, Some(x)) => Some(x),
95                        (None, None) => None,
96                    },
97                    cache_creation: match (a.cache_creation, b.cache_creation) {
98                        (Some(x), Some(y)) => Some(x + y),
99                        (Some(x), None) | (None, Some(x)) => Some(x),
100                        (None, None) => None,
101                    },
102                    cache_read: match (a.cache_read, b.cache_read) {
103                        (Some(x), Some(y)) => Some(x + y),
104                        (Some(x), None) | (None, Some(x)) => Some(x),
105                        (None, None) => None,
106                    },
107                }),
108                (Some(a), None) => Some(a.clone()),
109                (None, Some(b)) => Some(b.clone()),
110                (None, None) => None,
111            },
112            output_token_details: match (&self.output_token_details, &other.output_token_details) {
113                (Some(a), Some(b)) => Some(OutputTokenDetails {
114                    audio: match (a.audio, b.audio) {
115                        (Some(x), Some(y)) => Some(x + y),
116                        (Some(x), None) | (None, Some(x)) => Some(x),
117                        (None, None) => None,
118                    },
119                    reasoning: match (a.reasoning, b.reasoning) {
120                        (Some(x), Some(y)) => Some(x + y),
121                        (Some(x), None) | (None, Some(x)) => Some(x),
122                        (None, None) => None,
123                    },
124                }),
125                (Some(a), None) => Some(a.clone()),
126                (None, Some(b)) => Some(b.clone()),
127                (None, None) => None,
128            },
129        }
130    }
131}
132
133/// An AI message in the conversation.
134///
135/// An `AIMessage` is returned from a chat model as a response to a prompt.
136/// This message represents the output of the model and consists of both
137/// the raw output as returned by the model and standardized fields
138/// (e.g., tool calls, usage metadata).
139///
140/// This corresponds to `AIMessage` in LangChain Python.
141#[cfg_attr(feature = "specta", derive(Type))]
142#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
143pub struct AIMessage {
144    /// The message content
145    content: String,
146    /// Optional unique identifier
147    id: Option<String>,
148    /// Optional name for the message
149    #[serde(skip_serializing_if = "Option::is_none")]
150    name: Option<String>,
151    /// Tool calls made by the AI
152    #[serde(default)]
153    tool_calls: Vec<ToolCall>,
154    /// Tool calls with parsing errors associated with the message
155    #[serde(default)]
156    invalid_tool_calls: Vec<InvalidToolCall>,
157    /// If present, usage metadata for a message, such as token counts.
158    #[serde(skip_serializing_if = "Option::is_none")]
159    usage_metadata: Option<UsageMetadata>,
160    /// Additional metadata
161    #[serde(default)]
162    additional_kwargs: HashMap<String, serde_json::Value>,
163    /// Response metadata (e.g., response headers, logprobs, token counts, model name)
164    #[serde(default)]
165    response_metadata: HashMap<String, serde_json::Value>,
166}
167
168impl AIMessage {
169    /// Create a new AI message.
170    pub fn new(content: impl Into<String>) -> Self {
171        Self {
172            content: content.into(),
173            id: Some(uuid7(None).to_string()),
174            name: None,
175            tool_calls: Vec::new(),
176            invalid_tool_calls: Vec::new(),
177            usage_metadata: None,
178            additional_kwargs: HashMap::new(),
179            response_metadata: HashMap::new(),
180        }
181    }
182
183    /// Create a new AI message with an explicit ID.
184    ///
185    /// Use this when deserializing or reconstructing messages where the ID must be preserved.
186    pub fn with_id(id: impl Into<String>, content: impl Into<String>) -> Self {
187        Self {
188            content: content.into(),
189            id: Some(id.into()),
190            name: None,
191            tool_calls: Vec::new(),
192            invalid_tool_calls: Vec::new(),
193            usage_metadata: None,
194            additional_kwargs: HashMap::new(),
195            response_metadata: HashMap::new(),
196        }
197    }
198
199    /// Create a new AI message with tool calls.
200    pub fn with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
201        Self {
202            content: content.into(),
203            id: Some(uuid7(None).to_string()),
204            name: None,
205            tool_calls,
206            invalid_tool_calls: Vec::new(),
207            usage_metadata: None,
208            additional_kwargs: HashMap::new(),
209            response_metadata: HashMap::new(),
210        }
211    }
212
213    /// Create a new AI message with tool calls and an explicit ID.
214    ///
215    /// Use this when deserializing or reconstructing messages where the ID must be preserved.
216    pub fn with_id_and_tool_calls(
217        id: impl Into<String>,
218        content: impl Into<String>,
219        tool_calls: Vec<ToolCall>,
220    ) -> Self {
221        Self {
222            content: content.into(),
223            id: Some(id.into()),
224            name: None,
225            tool_calls,
226            invalid_tool_calls: Vec::new(),
227            usage_metadata: None,
228            additional_kwargs: HashMap::new(),
229            response_metadata: HashMap::new(),
230        }
231    }
232
233    /// Create a new AI message with both valid and invalid tool calls.
234    pub fn with_all_tool_calls(
235        content: impl Into<String>,
236        tool_calls: Vec<ToolCall>,
237        invalid_tool_calls: Vec<InvalidToolCall>,
238    ) -> Self {
239        Self {
240            content: content.into(),
241            id: Some(uuid7(None).to_string()),
242            name: None,
243            tool_calls,
244            invalid_tool_calls,
245            usage_metadata: None,
246            additional_kwargs: HashMap::new(),
247            response_metadata: HashMap::new(),
248        }
249    }
250
251    /// Set the name for this message.
252    pub fn with_name(mut self, name: impl Into<String>) -> Self {
253        self.name = Some(name.into());
254        self
255    }
256
257    /// Set invalid tool calls for this message.
258    pub fn with_invalid_tool_calls(mut self, invalid_tool_calls: Vec<InvalidToolCall>) -> Self {
259        self.invalid_tool_calls = invalid_tool_calls;
260        self
261    }
262
263    /// Set usage metadata for this message.
264    pub fn with_usage_metadata(mut self, usage_metadata: UsageMetadata) -> Self {
265        self.usage_metadata = Some(usage_metadata);
266        self
267    }
268
269    /// Get the message content.
270    pub fn content(&self) -> &str {
271        &self.content
272    }
273
274    /// Get the message ID.
275    pub fn id(&self) -> Option<&str> {
276        self.id.as_deref()
277    }
278
279    /// Get the message name.
280    pub fn name(&self) -> Option<&str> {
281        self.name.as_deref()
282    }
283
284    /// Get the tool calls.
285    pub fn tool_calls(&self) -> &[ToolCall] {
286        &self.tool_calls
287    }
288
289    /// Get the invalid tool calls.
290    pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
291        &self.invalid_tool_calls
292    }
293
294    /// Get usage metadata if present.
295    pub fn usage_metadata(&self) -> Option<&UsageMetadata> {
296        self.usage_metadata.as_ref()
297    }
298
299    /// Add annotations to the message (e.g., citations from web search).
300    /// Annotations are stored in additional_kwargs under the "annotations" key.
301    pub fn with_annotations<T: Serialize>(mut self, annotations: Vec<T>) -> Self {
302        if let Ok(value) = serde_json::to_value(&annotations) {
303            self.additional_kwargs
304                .insert("annotations".to_string(), value);
305        }
306        self
307    }
308
309    /// Get annotations from the message if present.
310    pub fn annotations(&self) -> Option<&serde_json::Value> {
311        self.additional_kwargs.get("annotations")
312    }
313
314    /// Get additional kwargs.
315    pub fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
316        &self.additional_kwargs
317    }
318
319    /// Get response metadata.
320    pub fn response_metadata(&self) -> &HashMap<String, serde_json::Value> {
321        &self.response_metadata
322    }
323
324    /// Set response metadata.
325    pub fn with_response_metadata(
326        mut self,
327        response_metadata: HashMap<String, serde_json::Value>,
328    ) -> Self {
329        self.response_metadata = response_metadata;
330        self
331    }
332
333    /// Set additional kwargs.
334    pub fn with_additional_kwargs(
335        mut self,
336        additional_kwargs: HashMap<String, serde_json::Value>,
337    ) -> Self {
338        self.additional_kwargs = additional_kwargs;
339        self
340    }
341
342    /// Get a pretty representation of the message.
343    ///
344    /// This corresponds to `pretty_repr` in LangChain Python.
345    pub fn pretty_repr(&self, _html: bool) -> String {
346        let title = "AI Message";
347        let sep_len = (80 - title.len() - 2) / 2;
348        let sep: String = "=".repeat(sep_len);
349        let header = format!("{} {} {}", sep, title, sep);
350
351        let mut lines = vec![header];
352
353        if let Some(name) = &self.name {
354            lines.push(format!("Name: {}", name));
355        }
356
357        lines.push(String::new());
358        lines.push(self.content.clone());
359
360        format_tool_calls_repr(&self.tool_calls, &self.invalid_tool_calls, &mut lines);
361
362        lines.join("\n").trim().to_string()
363    }
364}
365
366/// Helper function to format tool calls for pretty_repr.
367fn format_tool_calls_repr(
368    tool_calls: &[ToolCall],
369    invalid_tool_calls: &[InvalidToolCall],
370    lines: &mut Vec<String>,
371) {
372    if !tool_calls.is_empty() {
373        lines.push("Tool Calls:".to_string());
374        for tc in tool_calls {
375            lines.push(format!("  {} ({})", tc.name(), tc.id()));
376            lines.push(format!(" Call ID: {}", tc.id()));
377            lines.push("  Args:".to_string());
378            if let serde_json::Value::Object(args) = tc.args() {
379                for (arg, value) in args {
380                    lines.push(format!("    {}: {}", arg, value));
381                }
382            } else {
383                lines.push(format!("    {}", tc.args()));
384            }
385        }
386    }
387    if !invalid_tool_calls.is_empty() {
388        lines.push("Invalid Tool Calls:".to_string());
389        for itc in invalid_tool_calls {
390            let name = itc.name.as_deref().unwrap_or("Tool");
391            let id = itc.id.as_deref().unwrap_or("unknown");
392            lines.push(format!("  {} ({})", name, id));
393            lines.push(format!(" Call ID: {}", id));
394            if let Some(error) = &itc.error {
395                lines.push(format!("  Error: {}", error));
396            }
397            lines.push("  Args:".to_string());
398            if let Some(args) = &itc.args {
399                lines.push(format!("    {}", args));
400            }
401        }
402    }
403}
404
405/// Position indicator for an aggregated AIMessageChunk.
406///
407/// If a chunk with `chunk_position="last"` is aggregated into a stream,
408/// `tool_call_chunks` in message content will be parsed into `tool_calls`.
409#[cfg_attr(feature = "specta", derive(Type))]
410#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
411#[serde(rename_all = "lowercase")]
412pub enum ChunkPosition {
413    /// This is the last chunk in the stream
414    Last,
415}
416
417/// AI message chunk (yielded when streaming).
418///
419/// This is returned from a chat model during streaming to incrementally
420/// build up a complete AIMessage.
421///
422/// This corresponds to `AIMessageChunk` in LangChain Python.
423#[cfg_attr(feature = "specta", derive(Type))]
424#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
425pub struct AIMessageChunk {
426    /// The message content (may be partial during streaming)
427    content: String,
428    /// Optional unique identifier
429    id: Option<String>,
430    /// Optional name for the message
431    #[serde(skip_serializing_if = "Option::is_none")]
432    name: Option<String>,
433    /// Tool calls made by the AI
434    #[serde(default)]
435    tool_calls: Vec<ToolCall>,
436    /// Tool calls with parsing errors
437    #[serde(default)]
438    invalid_tool_calls: Vec<InvalidToolCall>,
439    /// Tool call chunks (for streaming tool calls)
440    #[serde(default)]
441    tool_call_chunks: Vec<ToolCallChunk>,
442    /// If present, usage metadata for a message
443    #[serde(skip_serializing_if = "Option::is_none")]
444    usage_metadata: Option<UsageMetadata>,
445    /// Additional metadata
446    #[serde(default)]
447    additional_kwargs: HashMap<String, serde_json::Value>,
448    /// Response metadata
449    #[serde(default)]
450    response_metadata: HashMap<String, serde_json::Value>,
451    /// Optional span represented by an aggregated AIMessageChunk.
452    ///
453    /// If a chunk with `chunk_position=Some(ChunkPosition::Last)` is aggregated into a stream,
454    /// `tool_call_chunks` in message content will be parsed into `tool_calls`.
455    #[serde(skip_serializing_if = "Option::is_none")]
456    chunk_position: Option<ChunkPosition>,
457}
458
459impl AIMessageChunk {
460    /// Create a new AI message chunk.
461    pub fn new(content: impl Into<String>) -> Self {
462        Self {
463            content: content.into(),
464            id: None,
465            name: None,
466            tool_calls: Vec::new(),
467            invalid_tool_calls: Vec::new(),
468            tool_call_chunks: Vec::new(),
469            usage_metadata: None,
470            additional_kwargs: HashMap::new(),
471            response_metadata: HashMap::new(),
472            chunk_position: None,
473        }
474    }
475
476    /// Create a new AI message chunk with an ID.
477    pub fn with_id(id: impl Into<String>, content: impl Into<String>) -> Self {
478        Self {
479            content: content.into(),
480            id: Some(id.into()),
481            name: None,
482            tool_calls: Vec::new(),
483            invalid_tool_calls: Vec::new(),
484            tool_call_chunks: Vec::new(),
485            usage_metadata: None,
486            additional_kwargs: HashMap::new(),
487            response_metadata: HashMap::new(),
488            chunk_position: None,
489        }
490    }
491
492    /// Create a new AI message chunk with tool call chunks.
493    pub fn with_tool_call_chunks(
494        content: impl Into<String>,
495        tool_call_chunks: Vec<ToolCallChunk>,
496    ) -> Self {
497        Self {
498            content: content.into(),
499            id: None,
500            name: None,
501            tool_calls: Vec::new(),
502            invalid_tool_calls: Vec::new(),
503            tool_call_chunks,
504            usage_metadata: None,
505            additional_kwargs: HashMap::new(),
506            response_metadata: HashMap::new(),
507            chunk_position: None,
508        }
509    }
510
511    /// Get the message content.
512    pub fn content(&self) -> &str {
513        &self.content
514    }
515
516    /// Get the message ID.
517    pub fn id(&self) -> Option<&str> {
518        self.id.as_deref()
519    }
520
521    /// Get the message name.
522    pub fn name(&self) -> Option<&str> {
523        self.name.as_deref()
524    }
525
526    /// Get the tool calls.
527    pub fn tool_calls(&self) -> &[ToolCall] {
528        &self.tool_calls
529    }
530
531    /// Get the invalid tool calls.
532    pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
533        &self.invalid_tool_calls
534    }
535
536    /// Get the tool call chunks.
537    pub fn tool_call_chunks(&self) -> &[ToolCallChunk] {
538        &self.tool_call_chunks
539    }
540
541    /// Get usage metadata if present.
542    pub fn usage_metadata(&self) -> Option<&UsageMetadata> {
543        self.usage_metadata.as_ref()
544    }
545
546    /// Get additional kwargs.
547    pub fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
548        &self.additional_kwargs
549    }
550
551    /// Get response metadata.
552    pub fn response_metadata(&self) -> &HashMap<String, serde_json::Value> {
553        &self.response_metadata
554    }
555
556    /// Get chunk position.
557    pub fn chunk_position(&self) -> Option<&ChunkPosition> {
558        self.chunk_position.as_ref()
559    }
560
561    /// Set chunk position.
562    pub fn set_chunk_position(&mut self, position: Option<ChunkPosition>) {
563        self.chunk_position = position;
564    }
565
566    /// Set tool calls.
567    pub fn set_tool_calls(&mut self, tool_calls: Vec<ToolCall>) {
568        self.tool_calls = tool_calls;
569    }
570
571    /// Set invalid tool calls.
572    pub fn set_invalid_tool_calls(&mut self, invalid_tool_calls: Vec<InvalidToolCall>) {
573        self.invalid_tool_calls = invalid_tool_calls;
574    }
575
576    /// Set tool call chunks.
577    pub fn set_tool_call_chunks(&mut self, tool_call_chunks: Vec<ToolCallChunk>) {
578        self.tool_call_chunks = tool_call_chunks;
579    }
580
581    /// Initialize tool calls from tool call chunks.
582    ///
583    /// This parses the tool_call_chunks and populates tool_calls and invalid_tool_calls.
584    /// This corresponds to `init_tool_calls` model validator in Python.
585    pub fn init_tool_calls(&mut self) {
586        if self.tool_call_chunks.is_empty() {
587            if !self.tool_calls.is_empty() {
588                self.tool_call_chunks = self
589                    .tool_calls
590                    .iter()
591                    .map(|tc| ToolCallChunk {
592                        name: Some(tc.name().to_string()),
593                        args: Some(tc.args().to_string()),
594                        id: Some(tc.id().to_string()),
595                        index: None,
596                    })
597                    .collect();
598            }
599            if !self.invalid_tool_calls.is_empty() {
600                self.tool_call_chunks
601                    .extend(self.invalid_tool_calls.iter().map(|tc| ToolCallChunk {
602                        name: tc.name.clone(),
603                        args: tc.args.clone(),
604                        id: tc.id.clone(),
605                        index: None,
606                    }));
607            }
608            return;
609        }
610
611        let mut new_tool_calls = Vec::new();
612        let mut new_invalid_tool_calls = Vec::new();
613
614        for chunk in &self.tool_call_chunks {
615            let args_result = if let Some(args_str) = &chunk.args {
616                if args_str.is_empty() {
617                    Ok(serde_json::Value::Object(serde_json::Map::new()))
618                } else {
619                    parse_partial_json(args_str, false)
620                }
621            } else {
622                Ok(serde_json::Value::Object(serde_json::Map::new()))
623            };
624
625            match args_result {
626                Ok(args) if args.is_object() => {
627                    new_tool_calls.push(tool_call(
628                        chunk.name.clone().unwrap_or_default(),
629                        args,
630                        chunk.id.clone(),
631                    ));
632                }
633                _ => {
634                    new_invalid_tool_calls.push(invalid_tool_call(
635                        chunk.name.clone(),
636                        chunk.args.clone(),
637                        chunk.id.clone(),
638                        None,
639                    ));
640                }
641            }
642        }
643
644        self.tool_calls = new_tool_calls;
645        self.invalid_tool_calls = new_invalid_tool_calls;
646    }
647
648    /// Concatenate this chunk with another chunk.
649    ///
650    /// This merges content, tool_call_chunks, and metadata.
651    /// For more sophisticated merging of multiple chunks, use `add_ai_message_chunks`.
652    pub fn concat(&self, other: &AIMessageChunk) -> AIMessageChunk {
653        add_ai_message_chunks(self.clone(), vec![other.clone()])
654    }
655
656    /// Convert this chunk to a complete AIMessage.
657    pub fn to_message(&self) -> AIMessage {
658        AIMessage {
659            content: self.content.clone(),
660            id: self.id.clone(),
661            name: self.name.clone(),
662            tool_calls: self.tool_calls.clone(),
663            invalid_tool_calls: self.invalid_tool_calls.clone(),
664            usage_metadata: self.usage_metadata.clone(),
665            additional_kwargs: self.additional_kwargs.clone(),
666            response_metadata: self.response_metadata.clone(),
667        }
668    }
669
670    /// Get a pretty representation of the message.
671    ///
672    /// This corresponds to `pretty_repr` in LangChain Python.
673    pub fn pretty_repr(&self, _html: bool) -> String {
674        let title = "AIMessageChunk";
675        let sep_len = (80 - title.len() - 2) / 2;
676        let sep: String = "=".repeat(sep_len);
677        let header = format!("{} {} {}", sep, title, sep);
678
679        let mut lines = vec![header];
680
681        if let Some(name) = &self.name {
682            lines.push(format!("Name: {}", name));
683        }
684
685        lines.push(String::new());
686        lines.push(self.content.clone());
687
688        format_tool_calls_repr(&self.tool_calls, &self.invalid_tool_calls, &mut lines);
689
690        lines.join("\n").trim().to_string()
691    }
692}
693
694/// Add multiple AIMessageChunks together.
695///
696/// This corresponds to `add_ai_message_chunks` in LangChain Python.
697///
698/// # Arguments
699///
700/// * `left` - The first AIMessageChunk.
701/// * `others` - Other AIMessageChunks to add.
702///
703/// # Returns
704///
705/// The resulting AIMessageChunk.
706pub fn add_ai_message_chunks(left: AIMessageChunk, others: Vec<AIMessageChunk>) -> AIMessageChunk {
707    // Merge content (simple string concatenation for now)
708    let mut content = left.content.clone();
709    for other in &others {
710        content.push_str(&other.content);
711    }
712
713    // Merge additional_kwargs using merge_dicts
714    let additional_kwargs = {
715        let left_val = serde_json::to_value(&left.additional_kwargs).unwrap_or_default();
716        let other_vals: Vec<serde_json::Value> = others
717            .iter()
718            .map(|o| serde_json::to_value(&o.additional_kwargs).unwrap_or_default())
719            .collect();
720        match merge_dicts(left_val, other_vals) {
721            Ok(merged) => serde_json::from_value(merged).unwrap_or_default(),
722            Err(_) => left.additional_kwargs.clone(),
723        }
724    };
725
726    // Merge response_metadata using merge_dicts
727    let response_metadata = {
728        let left_val = serde_json::to_value(&left.response_metadata).unwrap_or_default();
729        let other_vals: Vec<serde_json::Value> = others
730            .iter()
731            .map(|o| serde_json::to_value(&o.response_metadata).unwrap_or_default())
732            .collect();
733        match merge_dicts(left_val, other_vals) {
734            Ok(merged) => serde_json::from_value(merged).unwrap_or_default(),
735            Err(_) => left.response_metadata.clone(),
736        }
737    };
738
739    // Merge tool_call_chunks using merge_lists
740    let tool_call_chunks = {
741        let left_chunks: Vec<serde_json::Value> = left
742            .tool_call_chunks
743            .iter()
744            .filter_map(|tc| serde_json::to_value(tc).ok())
745            .collect();
746        let other_chunks: Vec<Option<Vec<serde_json::Value>>> = others
747            .iter()
748            .map(|o| {
749                Some(
750                    o.tool_call_chunks
751                        .iter()
752                        .filter_map(|tc| serde_json::to_value(tc).ok())
753                        .collect(),
754                )
755            })
756            .collect();
757
758        match merge_lists(Some(left_chunks), other_chunks) {
759            Ok(Some(merged)) => merged
760                .into_iter()
761                .map(|v| {
762                    let name = v.get("name").and_then(|n| n.as_str()).map(String::from);
763                    let args = v.get("args").and_then(|a| a.as_str()).map(String::from);
764                    let id = v.get("id").and_then(|i| i.as_str()).map(String::from);
765                    let index = v.get("index").and_then(|i| i.as_i64()).map(|i| i as i32);
766                    ToolCallChunk {
767                        name,
768                        args,
769                        id,
770                        index,
771                    }
772                })
773                .collect(),
774            _ => {
775                let mut chunks = left.tool_call_chunks.clone();
776                for other in &others {
777                    chunks.extend(other.tool_call_chunks.clone());
778                }
779                chunks
780            }
781        }
782    };
783
784    // Merge usage metadata
785    let usage_metadata =
786        if left.usage_metadata.is_some() || others.iter().any(|o| o.usage_metadata.is_some()) {
787            let mut result = left.usage_metadata.clone();
788            for other in &others {
789                result = Some(add_usage(result.as_ref(), other.usage_metadata.as_ref()));
790            }
791            result
792        } else {
793            None
794        };
795
796    // Select ID with priority: provider-assigned > lc_run-* > lc_*
797    let chunk_id = {
798        let mut candidates = vec![left.id.as_deref()];
799        candidates.extend(others.iter().map(|o| o.id.as_deref()));
800
801        // First pass: pick the first provider-assigned id (non-run-* and non-lc_*)
802        let mut selected_id: Option<&str> = None;
803        for id_str in candidates.iter().flatten() {
804            if !id_str.starts_with(LC_ID_PREFIX) && !id_str.starts_with(LC_AUTO_PREFIX) {
805                selected_id = Some(id_str);
806                break;
807            }
808        }
809
810        // Second pass: prefer lc_run-* IDs over lc_* IDs
811        if selected_id.is_none() {
812            for id_str in candidates.iter().flatten() {
813                if id_str.starts_with(LC_ID_PREFIX) {
814                    selected_id = Some(id_str);
815                    break;
816                }
817            }
818        }
819
820        // Third pass: take any remaining ID (auto-generated lc_* IDs)
821        if selected_id.is_none()
822            && let Some(id_str) = candidates.iter().flatten().next()
823        {
824            selected_id = Some(id_str);
825        }
826
827        selected_id.map(String::from)
828    };
829
830    // Determine chunk_position: if any chunk has "last", result is "last"
831    let chunk_position = if left.chunk_position == Some(ChunkPosition::Last)
832        || others
833            .iter()
834            .any(|o| o.chunk_position == Some(ChunkPosition::Last))
835    {
836        Some(ChunkPosition::Last)
837    } else {
838        None
839    };
840
841    let mut result = AIMessageChunk {
842        content,
843        id: chunk_id,
844        name: left
845            .name
846            .clone()
847            .or_else(|| others.iter().find_map(|o| o.name.clone())),
848        tool_calls: left.tool_calls.clone(),
849        invalid_tool_calls: left.invalid_tool_calls.clone(),
850        tool_call_chunks,
851        usage_metadata,
852        additional_kwargs,
853        response_metadata,
854        chunk_position,
855    };
856
857    // Initialize tool calls from chunks if this is the last chunk
858    if result.chunk_position == Some(ChunkPosition::Last) {
859        result.init_tool_calls();
860    }
861
862    result
863}
864
865impl std::ops::Add for AIMessageChunk {
866    type Output = AIMessageChunk;
867
868    fn add(self, other: AIMessageChunk) -> AIMessageChunk {
869        add_ai_message_chunks(self, vec![other])
870    }
871}
872
873impl std::iter::Sum for AIMessageChunk {
874    fn sum<I: Iterator<Item = AIMessageChunk>>(iter: I) -> AIMessageChunk {
875        let chunks: Vec<AIMessageChunk> = iter.collect();
876        if chunks.is_empty() {
877            AIMessageChunk::new("")
878        } else {
879            let first = chunks[0].clone();
880            let rest = chunks[1..].to_vec();
881            add_ai_message_chunks(first, rest)
882        }
883    }
884}
885
886/// Add two UsageMetadata objects together.
887///
888/// This function recursively adds the token counts from both UsageMetadata objects.
889/// Uses the generic `_dict_int_op` pattern from Python.
890///
891/// # Example
892///
893/// ```
894/// use agent_chain_core::messages::{add_usage, UsageMetadata, InputTokenDetails};
895///
896/// let left = UsageMetadata {
897///     input_tokens: 5,
898///     output_tokens: 0,
899///     total_tokens: 5,
900///     input_token_details: Some(InputTokenDetails {
901///         audio: None,
902///         cache_creation: None,
903///         cache_read: Some(3),
904///     }),
905///     output_token_details: None,
906/// };
907/// let right = UsageMetadata {
908///     input_tokens: 0,
909///     output_tokens: 10,
910///     total_tokens: 10,
911///     input_token_details: None,
912///     output_token_details: None,
913/// };
914///
915/// let result = add_usage(Some(&left), Some(&right));
916/// assert_eq!(result.input_tokens, 5);
917/// assert_eq!(result.output_tokens, 10);
918/// assert_eq!(result.total_tokens, 15);
919/// ```
920pub fn add_usage(left: Option<&UsageMetadata>, right: Option<&UsageMetadata>) -> UsageMetadata {
921    match (left, right) {
922        (None, None) => UsageMetadata::default(),
923        (Some(l), None) => l.clone(),
924        (None, Some(r)) => r.clone(),
925        (Some(l), Some(r)) => {
926            let left_json = serde_json::to_value(l).unwrap_or_default();
927            let right_json = serde_json::to_value(r).unwrap_or_default();
928
929            match dict_int_add_json(&left_json, &right_json) {
930                Ok(merged) => serde_json::from_value(merged).unwrap_or_else(|_| l.add(r)),
931                Err(_) => l.add(r),
932            }
933        }
934    }
935}
936
937/// Subtract two UsageMetadata objects.
938///
939/// Token counts cannot be negative so the actual operation is `max(left - right, 0)`.
940/// Uses the generic `_dict_int_op` pattern from Python.
941///
942/// # Example
943///
944/// ```
945/// use agent_chain_core::messages::{subtract_usage, UsageMetadata, InputTokenDetails};
946///
947/// let left = UsageMetadata {
948///     input_tokens: 5,
949///     output_tokens: 10,
950///     total_tokens: 15,
951///     input_token_details: Some(InputTokenDetails {
952///         audio: None,
953///         cache_creation: None,
954///         cache_read: Some(4),
955///     }),
956///     output_token_details: None,
957/// };
958/// let right = UsageMetadata {
959///     input_tokens: 3,
960///     output_tokens: 8,
961///     total_tokens: 11,
962///     input_token_details: None,
963///     output_token_details: None,
964/// };
965///
966/// let result = subtract_usage(Some(&left), Some(&right));
967/// assert_eq!(result.input_tokens, 2);
968/// assert_eq!(result.output_tokens, 2);
969/// assert_eq!(result.total_tokens, 4);
970/// ```
971pub fn subtract_usage(
972    left: Option<&UsageMetadata>,
973    right: Option<&UsageMetadata>,
974) -> UsageMetadata {
975    match (left, right) {
976        (None, None) => UsageMetadata::default(),
977        (Some(l), None) => l.clone(),
978        (None, Some(_)) => UsageMetadata::default(),
979        (Some(l), Some(r)) => {
980            let left_json = serde_json::to_value(l).unwrap_or_default();
981            let right_json = serde_json::to_value(r).unwrap_or_default();
982
983            match dict_int_sub_floor_json(&left_json, &right_json) {
984                Ok(subtracted) => {
985                    serde_json::from_value(subtracted).unwrap_or_else(|_| subtract_manual(l, r))
986                }
987                Err(_) => subtract_manual(l, r),
988            }
989        }
990    }
991}
992
993/// Manual subtraction fallback for UsageMetadata.
994fn subtract_manual(l: &UsageMetadata, r: &UsageMetadata) -> UsageMetadata {
995    UsageMetadata {
996        input_tokens: (l.input_tokens - r.input_tokens).max(0),
997        output_tokens: (l.output_tokens - r.output_tokens).max(0),
998        total_tokens: (l.total_tokens - r.total_tokens).max(0),
999        input_token_details: match (&l.input_token_details, &r.input_token_details) {
1000            (Some(a), Some(b)) => Some(InputTokenDetails {
1001                audio: a.audio.map(|x| (x - b.audio.unwrap_or(0)).max(0)),
1002                cache_creation: a
1003                    .cache_creation
1004                    .map(|x| (x - b.cache_creation.unwrap_or(0)).max(0)),
1005                cache_read: a.cache_read.map(|x| (x - b.cache_read.unwrap_or(0)).max(0)),
1006            }),
1007            (Some(a), None) => Some(a.clone()),
1008            (None, Some(b)) => Some(InputTokenDetails {
1009                audio: b.audio.map(|_| 0),
1010                cache_creation: b.cache_creation.map(|_| 0),
1011                cache_read: b.cache_read.map(|_| 0),
1012            }),
1013            (None, None) => None,
1014        },
1015        output_token_details: match (&l.output_token_details, &r.output_token_details) {
1016            (Some(a), Some(b)) => Some(OutputTokenDetails {
1017                audio: a.audio.map(|x| (x - b.audio.unwrap_or(0)).max(0)),
1018                reasoning: a.reasoning.map(|x| (x - b.reasoning.unwrap_or(0)).max(0)),
1019            }),
1020            (Some(a), None) => Some(a.clone()),
1021            (None, Some(b)) => Some(OutputTokenDetails {
1022                audio: b.audio.map(|_| 0),
1023                reasoning: b.reasoning.map(|_| 0),
1024            }),
1025            (None, None) => None,
1026        },
1027    }
1028}
1029
1030/// Parse tool calls from additional_kwargs for backwards compatibility.
1031///
1032/// This corresponds to `_backwards_compat_tool_calls` in LangChain Python.
1033/// It checks `additional_kwargs["tool_calls"]` and parses them into
1034/// either `tool_calls`/`invalid_tool_calls` (for AIMessage) or
1035/// `tool_call_chunks` (for AIMessageChunk).
1036///
1037/// # Arguments
1038///
1039/// * `additional_kwargs` - The additional_kwargs HashMap to check
1040/// * `is_chunk` - Whether this is for an AIMessageChunk (uses chunk parser) or AIMessage
1041///
1042/// # Returns
1043///
1044/// A tuple of (tool_calls, invalid_tool_calls, tool_call_chunks) where only
1045/// the appropriate fields are populated based on `is_chunk`.
1046pub fn backwards_compat_tool_calls(
1047    additional_kwargs: &HashMap<String, serde_json::Value>,
1048    is_chunk: bool,
1049) -> (Vec<ToolCall>, Vec<InvalidToolCall>, Vec<ToolCallChunk>) {
1050    let mut tool_calls = Vec::new();
1051    let mut invalid_tool_calls = Vec::new();
1052    let mut tool_call_chunks = Vec::new();
1053
1054    if let Some(raw_tool_calls) = additional_kwargs.get("tool_calls")
1055        && let Some(raw_array) = raw_tool_calls.as_array()
1056    {
1057        if is_chunk {
1058            tool_call_chunks = default_tool_chunk_parser(raw_array);
1059        } else {
1060            let (parsed_calls, parsed_invalid) = default_tool_parser(raw_array);
1061            tool_calls = parsed_calls;
1062            invalid_tool_calls = parsed_invalid;
1063        }
1064    }
1065
1066    (tool_calls, invalid_tool_calls, tool_call_chunks)
1067}
1068
#[cfg(test)]
mod tests {
    // Unit tests for usage arithmetic, backwards-compatible tool-call parsing,
    // and AIMessageChunk merging (`+`, `Sum`, ID selection, chunk position).
    use super::*;
    use serde_json::json;

    #[test]
    fn test_add_usage_basic() {
        let left = UsageMetadata {
            input_tokens: 5,
            output_tokens: 0,
            total_tokens: 5,
            input_token_details: Some(InputTokenDetails {
                audio: None,
                cache_creation: None,
                cache_read: Some(3),
            }),
            output_token_details: None,
        };
        let right = UsageMetadata {
            input_tokens: 0,
            output_tokens: 10,
            total_tokens: 10,
            input_token_details: None,
            output_token_details: Some(OutputTokenDetails {
                audio: None,
                reasoning: Some(4),
            }),
        };

        let result = add_usage(Some(&left), Some(&right));

        assert_eq!(result.input_tokens, 5);
        assert_eq!(result.output_tokens, 10);
        assert_eq!(result.total_tokens, 15);
        assert!(result.input_token_details.is_some());
        assert_eq!(
            result.input_token_details.as_ref().unwrap().cache_read,
            Some(3)
        );
        assert!(result.output_token_details.is_some());
        assert_eq!(
            result.output_token_details.as_ref().unwrap().reasoning,
            Some(4)
        );
    }

    #[test]
    fn test_add_usage_none_cases() {
        let usage = UsageMetadata::new(10, 20);

        // Both None
        let result = add_usage(None, None);
        assert_eq!(result.input_tokens, 0);
        assert_eq!(result.output_tokens, 0);
        assert_eq!(result.total_tokens, 0);

        // Left Some, Right None
        let result = add_usage(Some(&usage), None);
        assert_eq!(result.input_tokens, 10);
        assert_eq!(result.output_tokens, 20);

        // Left None, Right Some
        let result = add_usage(None, Some(&usage));
        assert_eq!(result.input_tokens, 10);
        assert_eq!(result.output_tokens, 20);
    }

    #[test]
    fn test_subtract_usage_basic() {
        let left = UsageMetadata {
            input_tokens: 5,
            output_tokens: 10,
            total_tokens: 15,
            input_token_details: Some(InputTokenDetails {
                audio: None,
                cache_creation: None,
                cache_read: Some(4),
            }),
            output_token_details: None,
        };
        let right = UsageMetadata {
            input_tokens: 3,
            output_tokens: 8,
            total_tokens: 11,
            input_token_details: None,
            output_token_details: Some(OutputTokenDetails {
                audio: None,
                reasoning: Some(4),
            }),
        };

        let result = subtract_usage(Some(&left), Some(&right));

        assert_eq!(result.input_tokens, 2);
        assert_eq!(result.output_tokens, 2);
        assert_eq!(result.total_tokens, 4);
        // cache_read should remain 4 (4 - 0 = 4)
        assert!(result.input_token_details.is_some());
        assert_eq!(
            result.input_token_details.as_ref().unwrap().cache_read,
            Some(4)
        );
        // reasoning should be 0 (0 - 4 = -4, floored to 0)
        assert!(result.output_token_details.is_some());
        assert_eq!(
            result.output_token_details.as_ref().unwrap().reasoning,
            Some(0)
        );
    }

    #[test]
    fn test_subtract_usage_floor_at_zero() {
        let left = UsageMetadata::new(5, 5);
        let right = UsageMetadata::new(10, 10);

        let result = subtract_usage(Some(&left), Some(&right));

        // Should floor at 0, not go negative
        assert_eq!(result.input_tokens, 0);
        assert_eq!(result.output_tokens, 0);
        assert_eq!(result.total_tokens, 0);
    }

    #[test]
    fn test_subtract_usage_none_cases() {
        let usage = UsageMetadata::new(10, 20);

        // Both None
        let result = subtract_usage(None, None);
        assert_eq!(result.input_tokens, 0);

        // Left Some, Right None - should return left unchanged
        let result = subtract_usage(Some(&usage), None);
        assert_eq!(result.input_tokens, 10);
        assert_eq!(result.output_tokens, 20);

        // Left None, Right Some - should return default (zeroes)
        let result = subtract_usage(None, Some(&usage));
        assert_eq!(result.input_tokens, 0);
        assert_eq!(result.output_tokens, 0);
    }

    #[test]
    fn test_backwards_compat_tool_calls_for_message() {
        let mut additional_kwargs = HashMap::new();
        additional_kwargs.insert(
            "tool_calls".to_string(),
            json!([
                {
                    "id": "call_123",
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"city\": \"London\"}"
                    }
                }
            ]),
        );

        let (tool_calls, invalid_tool_calls, tool_call_chunks) =
            backwards_compat_tool_calls(&additional_kwargs, false);

        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name(), "get_weather");
        assert!(invalid_tool_calls.is_empty());
        assert!(tool_call_chunks.is_empty());
    }

    #[test]
    fn test_backwards_compat_tool_calls_for_chunk() {
        let mut additional_kwargs = HashMap::new();
        additional_kwargs.insert(
            "tool_calls".to_string(),
            json!([
                {
                    "id": "call_123",
                    "index": 0,
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"city\":"
                    }
                }
            ]),
        );

        let (tool_calls, invalid_tool_calls, tool_call_chunks) =
            backwards_compat_tool_calls(&additional_kwargs, true);

        assert!(tool_calls.is_empty());
        assert!(invalid_tool_calls.is_empty());
        assert_eq!(tool_call_chunks.len(), 1);
        assert_eq!(tool_call_chunks[0].name, Some("get_weather".to_string()));
        assert_eq!(tool_call_chunks[0].index, Some(0));
    }

    #[test]
    fn test_backwards_compat_tool_calls_empty() {
        let additional_kwargs = HashMap::new();

        let (tool_calls, invalid_tool_calls, tool_call_chunks) =
            backwards_compat_tool_calls(&additional_kwargs, false);

        assert!(tool_calls.is_empty());
        assert!(invalid_tool_calls.is_empty());
        assert!(tool_call_chunks.is_empty());
    }

    #[test]
    fn test_backwards_compat_tool_calls_invalid_json() {
        let mut additional_kwargs = HashMap::new();
        additional_kwargs.insert(
            "tool_calls".to_string(),
            json!([
                {
                    "id": "call_123",
                    "function": {
                        "name": "get_weather",
                        "arguments": "invalid json {"
                    }
                }
            ]),
        );

        let (tool_calls, invalid_tool_calls, _tool_call_chunks) =
            backwards_compat_tool_calls(&additional_kwargs, false);

        // Should be invalid because the JSON is malformed
        assert!(tool_calls.is_empty());
        assert_eq!(invalid_tool_calls.len(), 1);
        assert_eq!(invalid_tool_calls[0].name, Some("get_weather".to_string()));
    }

    #[test]
    fn test_ai_message_chunk_add() {
        let chunk1 = AIMessageChunk::new("Hello ");
        let chunk2 = AIMessageChunk::new("world!");

        let result = chunk1 + chunk2;

        assert_eq!(result.content(), "Hello world!");
    }

    #[test]
    fn test_ai_message_chunk_sum() {
        let chunks = vec![
            AIMessageChunk::new("Hello "),
            AIMessageChunk::new("beautiful "),
            AIMessageChunk::new("world!"),
        ];

        let result: AIMessageChunk = chunks.into_iter().sum();

        assert_eq!(result.content(), "Hello beautiful world!");
    }

    #[test]
    fn test_ai_message_chunk_sum_empty() {
        // Summing nothing yields an empty chunk rather than panicking.
        let result: AIMessageChunk = Vec::<AIMessageChunk>::new().into_iter().sum();

        assert_eq!(result.content(), "");
    }

    #[test]
    fn test_add_ai_message_chunks_with_usage() {
        let mut chunk1 = AIMessageChunk::new("Hello ");
        chunk1.usage_metadata = Some(UsageMetadata::new(5, 0));

        let mut chunk2 = AIMessageChunk::new("world!");
        chunk2.usage_metadata = Some(UsageMetadata::new(0, 10));

        let result = add_ai_message_chunks(chunk1, vec![chunk2]);

        assert_eq!(result.content(), "Hello world!");
        assert!(result.usage_metadata.is_some());
        let usage = result.usage_metadata.as_ref().unwrap();
        assert_eq!(usage.input_tokens, 5);
        assert_eq!(usage.output_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }

    #[test]
    fn test_add_ai_message_chunks_id_priority() {
        // Provider-assigned ID should take priority
        let chunk1 = AIMessageChunk::with_id("lc_auto123", "");
        let chunk2 = AIMessageChunk::with_id("provider_id_456", "");
        let chunk3 = AIMessageChunk::with_id("lc_run-789", "");

        let result = add_ai_message_chunks(chunk1, vec![chunk2, chunk3]);

        // Provider ID should be selected (not lc_* or lc_run-*)
        assert_eq!(result.id(), Some("provider_id_456"));
    }

    #[test]
    fn test_add_ai_message_chunks_lc_run_priority() {
        // lc_run-* should take priority over lc_*
        let chunk1 = AIMessageChunk::with_id("lc_auto123", "");
        let chunk2 = AIMessageChunk::with_id("lc_run-789", "");

        let result = add_ai_message_chunks(chunk1, vec![chunk2]);

        assert_eq!(result.id(), Some("lc_run-789"));
    }

    #[test]
    fn test_add_ai_message_chunks_auto_id_fallback() {
        // With only auto-generated IDs, the first one encountered is kept.
        let chunk1 = AIMessageChunk::with_id("lc_auto1", "");
        let chunk2 = AIMessageChunk::with_id("lc_auto2", "");

        let result = add_ai_message_chunks(chunk1, vec![chunk2]);

        assert_eq!(result.id(), Some("lc_auto1"));
    }

    #[test]
    fn test_add_ai_message_chunks_last_position_propagates() {
        // If any merged chunk is marked Last, the result is marked Last.
        let chunk1 = AIMessageChunk::new("a");
        let mut chunk2 = AIMessageChunk::new("b");
        chunk2.chunk_position = Some(ChunkPosition::Last);

        let result = add_ai_message_chunks(chunk1, vec![chunk2]);

        assert!(result.chunk_position == Some(ChunkPosition::Last));
        assert_eq!(result.content(), "ab");
    }
}