1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[cfg(feature = "specta")]
10use specta::Type;
11
12use super::tool::{
13 InvalidToolCall, ToolCall, ToolCallChunk, default_tool_chunk_parser, default_tool_parser,
14 invalid_tool_call, tool_call,
15};
16use crate::utils::json::parse_partial_json;
17use crate::utils::merge::{merge_dicts, merge_lists};
18use crate::utils::usage::{dict_int_add_json, dict_int_sub_floor_json};
19use crate::utils::uuid::{LC_AUTO_PREFIX, LC_ID_PREFIX, uuid7};
20
/// Optional per-category breakdown stored in `UsageMetadata::input_token_details`.
///
/// Every field is `None` when the value was not reported; `None` fields are omitted
/// from the serialized output. `UsageMetadata::add` sums matching fields.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct InputTokenDetails {
    /// Audio input tokens, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<i64>,
    /// Input tokens attributed to cache creation, if reported
    /// (presumably prompt-cache writes — confirm against provider mappings).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_creation: Option<i64>,
    /// Input tokens attributed to cache reads, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read: Option<i64>,
}
37
/// Optional per-category breakdown stored in `UsageMetadata::output_token_details`.
///
/// Every field is `None` when the value was not reported; `None` fields are omitted
/// from the serialized output. `UsageMetadata::add` sums matching fields.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct OutputTokenDetails {
    /// Audio output tokens, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<i64>,
    /// Output tokens attributed to reasoning, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<i64>,
}
51
/// Token usage reported for a model response.
///
/// The three counts carry no `#[serde(default)]`, so they are required when
/// deserializing. `total_tokens` is stored independently: `UsageMetadata::new` sets it
/// to `input_tokens + output_tokens`, but nothing re-validates it afterwards.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct UsageMetadata {
    /// Number of input tokens.
    pub input_tokens: i64,
    /// Number of output tokens.
    pub output_tokens: i64,
    /// Total token count (normally `input_tokens + output_tokens`).
    pub total_tokens: i64,
    /// Optional breakdown of `input_tokens`; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_token_details: Option<InputTokenDetails>,
    /// Optional breakdown of `output_tokens`; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_token_details: Option<OutputTokenDetails>,
}
71
72impl UsageMetadata {
73 pub fn new(input_tokens: i64, output_tokens: i64) -> Self {
75 Self {
76 input_tokens,
77 output_tokens,
78 total_tokens: input_tokens + output_tokens,
79 input_token_details: None,
80 output_token_details: None,
81 }
82 }
83
84 pub fn add(&self, other: &UsageMetadata) -> Self {
86 Self {
87 input_tokens: self.input_tokens + other.input_tokens,
88 output_tokens: self.output_tokens + other.output_tokens,
89 total_tokens: self.total_tokens + other.total_tokens,
90 input_token_details: match (&self.input_token_details, &other.input_token_details) {
91 (Some(a), Some(b)) => Some(InputTokenDetails {
92 audio: match (a.audio, b.audio) {
93 (Some(x), Some(y)) => Some(x + y),
94 (Some(x), None) | (None, Some(x)) => Some(x),
95 (None, None) => None,
96 },
97 cache_creation: match (a.cache_creation, b.cache_creation) {
98 (Some(x), Some(y)) => Some(x + y),
99 (Some(x), None) | (None, Some(x)) => Some(x),
100 (None, None) => None,
101 },
102 cache_read: match (a.cache_read, b.cache_read) {
103 (Some(x), Some(y)) => Some(x + y),
104 (Some(x), None) | (None, Some(x)) => Some(x),
105 (None, None) => None,
106 },
107 }),
108 (Some(a), None) => Some(a.clone()),
109 (None, Some(b)) => Some(b.clone()),
110 (None, None) => None,
111 },
112 output_token_details: match (&self.output_token_details, &other.output_token_details) {
113 (Some(a), Some(b)) => Some(OutputTokenDetails {
114 audio: match (a.audio, b.audio) {
115 (Some(x), Some(y)) => Some(x + y),
116 (Some(x), None) | (None, Some(x)) => Some(x),
117 (None, None) => None,
118 },
119 reasoning: match (a.reasoning, b.reasoning) {
120 (Some(x), Some(y)) => Some(x + y),
121 (Some(x), None) | (None, Some(x)) => Some(x),
122 (None, None) => None,
123 },
124 }),
125 (Some(a), None) => Some(a.clone()),
126 (None, Some(b)) => Some(b.clone()),
127 (None, None) => None,
128 },
129 }
130 }
131}
132
/// A complete (non-streaming) message produced by an AI model.
///
/// Fields are private; use the constructors and accessor methods on `AIMessage`.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AIMessage {
    /// Text content of the message.
    content: String,
    /// Message id. Constructors without an explicit id fill it from `uuid7`.
    /// Note: no `skip_serializing_if`, so `None` serializes as `null`.
    id: Option<String>,
    /// Optional sender name; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Successfully parsed tool calls; defaults to empty when deserializing.
    #[serde(default)]
    tool_calls: Vec<ToolCall>,
    /// Tool calls that could not be parsed; defaults to empty when deserializing.
    #[serde(default)]
    invalid_tool_calls: Vec<InvalidToolCall>,
    /// Optional token usage; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    usage_metadata: Option<UsageMetadata>,
    /// Extra provider-specific payload (e.g. `"annotations"`, legacy `"tool_calls"`).
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
    /// Response-level metadata; defaults to empty when deserializing.
    #[serde(default)]
    response_metadata: HashMap<String, serde_json::Value>,
}
167
168impl AIMessage {
169 pub fn new(content: impl Into<String>) -> Self {
171 Self {
172 content: content.into(),
173 id: Some(uuid7(None).to_string()),
174 name: None,
175 tool_calls: Vec::new(),
176 invalid_tool_calls: Vec::new(),
177 usage_metadata: None,
178 additional_kwargs: HashMap::new(),
179 response_metadata: HashMap::new(),
180 }
181 }
182
183 pub fn with_id(id: impl Into<String>, content: impl Into<String>) -> Self {
187 Self {
188 content: content.into(),
189 id: Some(id.into()),
190 name: None,
191 tool_calls: Vec::new(),
192 invalid_tool_calls: Vec::new(),
193 usage_metadata: None,
194 additional_kwargs: HashMap::new(),
195 response_metadata: HashMap::new(),
196 }
197 }
198
199 pub fn with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
201 Self {
202 content: content.into(),
203 id: Some(uuid7(None).to_string()),
204 name: None,
205 tool_calls,
206 invalid_tool_calls: Vec::new(),
207 usage_metadata: None,
208 additional_kwargs: HashMap::new(),
209 response_metadata: HashMap::new(),
210 }
211 }
212
213 pub fn with_id_and_tool_calls(
217 id: impl Into<String>,
218 content: impl Into<String>,
219 tool_calls: Vec<ToolCall>,
220 ) -> Self {
221 Self {
222 content: content.into(),
223 id: Some(id.into()),
224 name: None,
225 tool_calls,
226 invalid_tool_calls: Vec::new(),
227 usage_metadata: None,
228 additional_kwargs: HashMap::new(),
229 response_metadata: HashMap::new(),
230 }
231 }
232
233 pub fn with_all_tool_calls(
235 content: impl Into<String>,
236 tool_calls: Vec<ToolCall>,
237 invalid_tool_calls: Vec<InvalidToolCall>,
238 ) -> Self {
239 Self {
240 content: content.into(),
241 id: Some(uuid7(None).to_string()),
242 name: None,
243 tool_calls,
244 invalid_tool_calls,
245 usage_metadata: None,
246 additional_kwargs: HashMap::new(),
247 response_metadata: HashMap::new(),
248 }
249 }
250
251 pub fn with_name(mut self, name: impl Into<String>) -> Self {
253 self.name = Some(name.into());
254 self
255 }
256
257 pub fn with_invalid_tool_calls(mut self, invalid_tool_calls: Vec<InvalidToolCall>) -> Self {
259 self.invalid_tool_calls = invalid_tool_calls;
260 self
261 }
262
263 pub fn with_usage_metadata(mut self, usage_metadata: UsageMetadata) -> Self {
265 self.usage_metadata = Some(usage_metadata);
266 self
267 }
268
269 pub fn content(&self) -> &str {
271 &self.content
272 }
273
274 pub fn id(&self) -> Option<&str> {
276 self.id.as_deref()
277 }
278
279 pub fn name(&self) -> Option<&str> {
281 self.name.as_deref()
282 }
283
284 pub fn tool_calls(&self) -> &[ToolCall] {
286 &self.tool_calls
287 }
288
289 pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
291 &self.invalid_tool_calls
292 }
293
294 pub fn usage_metadata(&self) -> Option<&UsageMetadata> {
296 self.usage_metadata.as_ref()
297 }
298
299 pub fn with_annotations<T: Serialize>(mut self, annotations: Vec<T>) -> Self {
302 if let Ok(value) = serde_json::to_value(&annotations) {
303 self.additional_kwargs
304 .insert("annotations".to_string(), value);
305 }
306 self
307 }
308
309 pub fn annotations(&self) -> Option<&serde_json::Value> {
311 self.additional_kwargs.get("annotations")
312 }
313
314 pub fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
316 &self.additional_kwargs
317 }
318
319 pub fn response_metadata(&self) -> &HashMap<String, serde_json::Value> {
321 &self.response_metadata
322 }
323
324 pub fn with_response_metadata(
326 mut self,
327 response_metadata: HashMap<String, serde_json::Value>,
328 ) -> Self {
329 self.response_metadata = response_metadata;
330 self
331 }
332
333 pub fn with_additional_kwargs(
335 mut self,
336 additional_kwargs: HashMap<String, serde_json::Value>,
337 ) -> Self {
338 self.additional_kwargs = additional_kwargs;
339 self
340 }
341
342 pub fn pretty_repr(&self, _html: bool) -> String {
346 let title = "AI Message";
347 let sep_len = (80 - title.len() - 2) / 2;
348 let sep: String = "=".repeat(sep_len);
349 let header = format!("{} {} {}", sep, title, sep);
350
351 let mut lines = vec![header];
352
353 if let Some(name) = &self.name {
354 lines.push(format!("Name: {}", name));
355 }
356
357 lines.push(String::new());
358 lines.push(self.content.clone());
359
360 format_tool_calls_repr(&self.tool_calls, &self.invalid_tool_calls, &mut lines);
361
362 lines.join("\n").trim().to_string()
363 }
364}
365
366fn format_tool_calls_repr(
368 tool_calls: &[ToolCall],
369 invalid_tool_calls: &[InvalidToolCall],
370 lines: &mut Vec<String>,
371) {
372 if !tool_calls.is_empty() {
373 lines.push("Tool Calls:".to_string());
374 for tc in tool_calls {
375 lines.push(format!(" {} ({})", tc.name(), tc.id()));
376 lines.push(format!(" Call ID: {}", tc.id()));
377 lines.push(" Args:".to_string());
378 if let serde_json::Value::Object(args) = tc.args() {
379 for (arg, value) in args {
380 lines.push(format!(" {}: {}", arg, value));
381 }
382 } else {
383 lines.push(format!(" {}", tc.args()));
384 }
385 }
386 }
387 if !invalid_tool_calls.is_empty() {
388 lines.push("Invalid Tool Calls:".to_string());
389 for itc in invalid_tool_calls {
390 let name = itc.name.as_deref().unwrap_or("Tool");
391 let id = itc.id.as_deref().unwrap_or("unknown");
392 lines.push(format!(" {} ({})", name, id));
393 lines.push(format!(" Call ID: {}", id));
394 if let Some(error) = &itc.error {
395 lines.push(format!(" Error: {}", error));
396 }
397 lines.push(" Args:".to_string());
398 if let Some(args) = &itc.args {
399 lines.push(format!(" {}", args));
400 }
401 }
402 }
403}
404
/// Position marker for a streamed `AIMessageChunk`.
///
/// Serialized in lowercase (`Last` <-> `"last"`). When chunks are merged with
/// `add_ai_message_chunks`, a `Last` marker on any input makes the result `Last` and
/// triggers `AIMessageChunk::init_tool_calls` on it.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChunkPosition {
    /// The final chunk of a stream.
    Last,
}
416
/// One streamed fragment of an AI message. Fragments can be concatenated with `+`,
/// `Sum`, `concat`, or `add_ai_message_chunks`, and converted with `to_message`.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AIMessageChunk {
    /// Text content of this fragment.
    content: String,
    /// Chunk id. The constructors leave it `None` unless given; serializes as `null`
    /// when `None`.
    id: Option<String>,
    /// Optional sender name; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Tool calls parsed from `tool_call_chunks` (see `init_tool_calls`).
    #[serde(default)]
    tool_calls: Vec<ToolCall>,
    /// Tool-call chunks whose arguments did not parse to a JSON object.
    #[serde(default)]
    invalid_tool_calls: Vec<InvalidToolCall>,
    /// Raw streamed tool-call fragments (arguments kept as unparsed strings).
    #[serde(default)]
    tool_call_chunks: Vec<ToolCallChunk>,
    /// Optional token usage; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    usage_metadata: Option<UsageMetadata>,
    /// Extra provider-specific payload; merged with `merge_dicts` when concatenating.
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
    /// Response-level metadata; merged with `merge_dicts` when concatenating.
    #[serde(default)]
    response_metadata: HashMap<String, serde_json::Value>,
    /// `Some(ChunkPosition::Last)` marks the final chunk of a stream; omitted from
    /// serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    chunk_position: Option<ChunkPosition>,
}
458
459impl AIMessageChunk {
460 pub fn new(content: impl Into<String>) -> Self {
462 Self {
463 content: content.into(),
464 id: None,
465 name: None,
466 tool_calls: Vec::new(),
467 invalid_tool_calls: Vec::new(),
468 tool_call_chunks: Vec::new(),
469 usage_metadata: None,
470 additional_kwargs: HashMap::new(),
471 response_metadata: HashMap::new(),
472 chunk_position: None,
473 }
474 }
475
476 pub fn with_id(id: impl Into<String>, content: impl Into<String>) -> Self {
478 Self {
479 content: content.into(),
480 id: Some(id.into()),
481 name: None,
482 tool_calls: Vec::new(),
483 invalid_tool_calls: Vec::new(),
484 tool_call_chunks: Vec::new(),
485 usage_metadata: None,
486 additional_kwargs: HashMap::new(),
487 response_metadata: HashMap::new(),
488 chunk_position: None,
489 }
490 }
491
492 pub fn with_tool_call_chunks(
494 content: impl Into<String>,
495 tool_call_chunks: Vec<ToolCallChunk>,
496 ) -> Self {
497 Self {
498 content: content.into(),
499 id: None,
500 name: None,
501 tool_calls: Vec::new(),
502 invalid_tool_calls: Vec::new(),
503 tool_call_chunks,
504 usage_metadata: None,
505 additional_kwargs: HashMap::new(),
506 response_metadata: HashMap::new(),
507 chunk_position: None,
508 }
509 }
510
511 pub fn content(&self) -> &str {
513 &self.content
514 }
515
516 pub fn id(&self) -> Option<&str> {
518 self.id.as_deref()
519 }
520
521 pub fn name(&self) -> Option<&str> {
523 self.name.as_deref()
524 }
525
526 pub fn tool_calls(&self) -> &[ToolCall] {
528 &self.tool_calls
529 }
530
531 pub fn invalid_tool_calls(&self) -> &[InvalidToolCall] {
533 &self.invalid_tool_calls
534 }
535
536 pub fn tool_call_chunks(&self) -> &[ToolCallChunk] {
538 &self.tool_call_chunks
539 }
540
541 pub fn usage_metadata(&self) -> Option<&UsageMetadata> {
543 self.usage_metadata.as_ref()
544 }
545
546 pub fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
548 &self.additional_kwargs
549 }
550
551 pub fn response_metadata(&self) -> &HashMap<String, serde_json::Value> {
553 &self.response_metadata
554 }
555
556 pub fn chunk_position(&self) -> Option<&ChunkPosition> {
558 self.chunk_position.as_ref()
559 }
560
561 pub fn set_chunk_position(&mut self, position: Option<ChunkPosition>) {
563 self.chunk_position = position;
564 }
565
566 pub fn set_tool_calls(&mut self, tool_calls: Vec<ToolCall>) {
568 self.tool_calls = tool_calls;
569 }
570
571 pub fn set_invalid_tool_calls(&mut self, invalid_tool_calls: Vec<InvalidToolCall>) {
573 self.invalid_tool_calls = invalid_tool_calls;
574 }
575
576 pub fn set_tool_call_chunks(&mut self, tool_call_chunks: Vec<ToolCallChunk>) {
578 self.tool_call_chunks = tool_call_chunks;
579 }
580
581 pub fn init_tool_calls(&mut self) {
586 if self.tool_call_chunks.is_empty() {
587 if !self.tool_calls.is_empty() {
588 self.tool_call_chunks = self
589 .tool_calls
590 .iter()
591 .map(|tc| ToolCallChunk {
592 name: Some(tc.name().to_string()),
593 args: Some(tc.args().to_string()),
594 id: Some(tc.id().to_string()),
595 index: None,
596 })
597 .collect();
598 }
599 if !self.invalid_tool_calls.is_empty() {
600 self.tool_call_chunks
601 .extend(self.invalid_tool_calls.iter().map(|tc| ToolCallChunk {
602 name: tc.name.clone(),
603 args: tc.args.clone(),
604 id: tc.id.clone(),
605 index: None,
606 }));
607 }
608 return;
609 }
610
611 let mut new_tool_calls = Vec::new();
612 let mut new_invalid_tool_calls = Vec::new();
613
614 for chunk in &self.tool_call_chunks {
615 let args_result = if let Some(args_str) = &chunk.args {
616 if args_str.is_empty() {
617 Ok(serde_json::Value::Object(serde_json::Map::new()))
618 } else {
619 parse_partial_json(args_str, false)
620 }
621 } else {
622 Ok(serde_json::Value::Object(serde_json::Map::new()))
623 };
624
625 match args_result {
626 Ok(args) if args.is_object() => {
627 new_tool_calls.push(tool_call(
628 chunk.name.clone().unwrap_or_default(),
629 args,
630 chunk.id.clone(),
631 ));
632 }
633 _ => {
634 new_invalid_tool_calls.push(invalid_tool_call(
635 chunk.name.clone(),
636 chunk.args.clone(),
637 chunk.id.clone(),
638 None,
639 ));
640 }
641 }
642 }
643
644 self.tool_calls = new_tool_calls;
645 self.invalid_tool_calls = new_invalid_tool_calls;
646 }
647
648 pub fn concat(&self, other: &AIMessageChunk) -> AIMessageChunk {
653 add_ai_message_chunks(self.clone(), vec![other.clone()])
654 }
655
656 pub fn to_message(&self) -> AIMessage {
658 AIMessage {
659 content: self.content.clone(),
660 id: self.id.clone(),
661 name: self.name.clone(),
662 tool_calls: self.tool_calls.clone(),
663 invalid_tool_calls: self.invalid_tool_calls.clone(),
664 usage_metadata: self.usage_metadata.clone(),
665 additional_kwargs: self.additional_kwargs.clone(),
666 response_metadata: self.response_metadata.clone(),
667 }
668 }
669
670 pub fn pretty_repr(&self, _html: bool) -> String {
674 let title = "AIMessageChunk";
675 let sep_len = (80 - title.len() - 2) / 2;
676 let sep: String = "=".repeat(sep_len);
677 let header = format!("{} {} {}", sep, title, sep);
678
679 let mut lines = vec![header];
680
681 if let Some(name) = &self.name {
682 lines.push(format!("Name: {}", name));
683 }
684
685 lines.push(String::new());
686 lines.push(self.content.clone());
687
688 format_tool_calls_repr(&self.tool_calls, &self.invalid_tool_calls, &mut lines);
689
690 lines.join("\n").trim().to_string()
691 }
692}
693
694pub fn add_ai_message_chunks(left: AIMessageChunk, others: Vec<AIMessageChunk>) -> AIMessageChunk {
707 let mut content = left.content.clone();
709 for other in &others {
710 content.push_str(&other.content);
711 }
712
713 let additional_kwargs = {
715 let left_val = serde_json::to_value(&left.additional_kwargs).unwrap_or_default();
716 let other_vals: Vec<serde_json::Value> = others
717 .iter()
718 .map(|o| serde_json::to_value(&o.additional_kwargs).unwrap_or_default())
719 .collect();
720 match merge_dicts(left_val, other_vals) {
721 Ok(merged) => serde_json::from_value(merged).unwrap_or_default(),
722 Err(_) => left.additional_kwargs.clone(),
723 }
724 };
725
726 let response_metadata = {
728 let left_val = serde_json::to_value(&left.response_metadata).unwrap_or_default();
729 let other_vals: Vec<serde_json::Value> = others
730 .iter()
731 .map(|o| serde_json::to_value(&o.response_metadata).unwrap_or_default())
732 .collect();
733 match merge_dicts(left_val, other_vals) {
734 Ok(merged) => serde_json::from_value(merged).unwrap_or_default(),
735 Err(_) => left.response_metadata.clone(),
736 }
737 };
738
739 let tool_call_chunks = {
741 let left_chunks: Vec<serde_json::Value> = left
742 .tool_call_chunks
743 .iter()
744 .filter_map(|tc| serde_json::to_value(tc).ok())
745 .collect();
746 let other_chunks: Vec<Option<Vec<serde_json::Value>>> = others
747 .iter()
748 .map(|o| {
749 Some(
750 o.tool_call_chunks
751 .iter()
752 .filter_map(|tc| serde_json::to_value(tc).ok())
753 .collect(),
754 )
755 })
756 .collect();
757
758 match merge_lists(Some(left_chunks), other_chunks) {
759 Ok(Some(merged)) => merged
760 .into_iter()
761 .map(|v| {
762 let name = v.get("name").and_then(|n| n.as_str()).map(String::from);
763 let args = v.get("args").and_then(|a| a.as_str()).map(String::from);
764 let id = v.get("id").and_then(|i| i.as_str()).map(String::from);
765 let index = v.get("index").and_then(|i| i.as_i64()).map(|i| i as i32);
766 ToolCallChunk {
767 name,
768 args,
769 id,
770 index,
771 }
772 })
773 .collect(),
774 _ => {
775 let mut chunks = left.tool_call_chunks.clone();
776 for other in &others {
777 chunks.extend(other.tool_call_chunks.clone());
778 }
779 chunks
780 }
781 }
782 };
783
784 let usage_metadata =
786 if left.usage_metadata.is_some() || others.iter().any(|o| o.usage_metadata.is_some()) {
787 let mut result = left.usage_metadata.clone();
788 for other in &others {
789 result = Some(add_usage(result.as_ref(), other.usage_metadata.as_ref()));
790 }
791 result
792 } else {
793 None
794 };
795
796 let chunk_id = {
798 let mut candidates = vec![left.id.as_deref()];
799 candidates.extend(others.iter().map(|o| o.id.as_deref()));
800
801 let mut selected_id: Option<&str> = None;
803 for id_str in candidates.iter().flatten() {
804 if !id_str.starts_with(LC_ID_PREFIX) && !id_str.starts_with(LC_AUTO_PREFIX) {
805 selected_id = Some(id_str);
806 break;
807 }
808 }
809
810 if selected_id.is_none() {
812 for id_str in candidates.iter().flatten() {
813 if id_str.starts_with(LC_ID_PREFIX) {
814 selected_id = Some(id_str);
815 break;
816 }
817 }
818 }
819
820 if selected_id.is_none()
822 && let Some(id_str) = candidates.iter().flatten().next()
823 {
824 selected_id = Some(id_str);
825 }
826
827 selected_id.map(String::from)
828 };
829
830 let chunk_position = if left.chunk_position == Some(ChunkPosition::Last)
832 || others
833 .iter()
834 .any(|o| o.chunk_position == Some(ChunkPosition::Last))
835 {
836 Some(ChunkPosition::Last)
837 } else {
838 None
839 };
840
841 let mut result = AIMessageChunk {
842 content,
843 id: chunk_id,
844 name: left
845 .name
846 .clone()
847 .or_else(|| others.iter().find_map(|o| o.name.clone())),
848 tool_calls: left.tool_calls.clone(),
849 invalid_tool_calls: left.invalid_tool_calls.clone(),
850 tool_call_chunks,
851 usage_metadata,
852 additional_kwargs,
853 response_metadata,
854 chunk_position,
855 };
856
857 if result.chunk_position == Some(ChunkPosition::Last) {
859 result.init_tool_calls();
860 }
861
862 result
863}
864
865impl std::ops::Add for AIMessageChunk {
866 type Output = AIMessageChunk;
867
868 fn add(self, other: AIMessageChunk) -> AIMessageChunk {
869 add_ai_message_chunks(self, vec![other])
870 }
871}
872
873impl std::iter::Sum for AIMessageChunk {
874 fn sum<I: Iterator<Item = AIMessageChunk>>(iter: I) -> AIMessageChunk {
875 let chunks: Vec<AIMessageChunk> = iter.collect();
876 if chunks.is_empty() {
877 AIMessageChunk::new("")
878 } else {
879 let first = chunks[0].clone();
880 let rest = chunks[1..].to_vec();
881 add_ai_message_chunks(first, rest)
882 }
883 }
884}
885
886pub fn add_usage(left: Option<&UsageMetadata>, right: Option<&UsageMetadata>) -> UsageMetadata {
921 match (left, right) {
922 (None, None) => UsageMetadata::default(),
923 (Some(l), None) => l.clone(),
924 (None, Some(r)) => r.clone(),
925 (Some(l), Some(r)) => {
926 let left_json = serde_json::to_value(l).unwrap_or_default();
927 let right_json = serde_json::to_value(r).unwrap_or_default();
928
929 match dict_int_add_json(&left_json, &right_json) {
930 Ok(merged) => serde_json::from_value(merged).unwrap_or_else(|_| l.add(r)),
931 Err(_) => l.add(r),
932 }
933 }
934 }
935}
936
937pub fn subtract_usage(
972 left: Option<&UsageMetadata>,
973 right: Option<&UsageMetadata>,
974) -> UsageMetadata {
975 match (left, right) {
976 (None, None) => UsageMetadata::default(),
977 (Some(l), None) => l.clone(),
978 (None, Some(_)) => UsageMetadata::default(),
979 (Some(l), Some(r)) => {
980 let left_json = serde_json::to_value(l).unwrap_or_default();
981 let right_json = serde_json::to_value(r).unwrap_or_default();
982
983 match dict_int_sub_floor_json(&left_json, &right_json) {
984 Ok(subtracted) => {
985 serde_json::from_value(subtracted).unwrap_or_else(|_| subtract_manual(l, r))
986 }
987 Err(_) => subtract_manual(l, r),
988 }
989 }
990 }
991}
992
993fn subtract_manual(l: &UsageMetadata, r: &UsageMetadata) -> UsageMetadata {
995 UsageMetadata {
996 input_tokens: (l.input_tokens - r.input_tokens).max(0),
997 output_tokens: (l.output_tokens - r.output_tokens).max(0),
998 total_tokens: (l.total_tokens - r.total_tokens).max(0),
999 input_token_details: match (&l.input_token_details, &r.input_token_details) {
1000 (Some(a), Some(b)) => Some(InputTokenDetails {
1001 audio: a.audio.map(|x| (x - b.audio.unwrap_or(0)).max(0)),
1002 cache_creation: a
1003 .cache_creation
1004 .map(|x| (x - b.cache_creation.unwrap_or(0)).max(0)),
1005 cache_read: a.cache_read.map(|x| (x - b.cache_read.unwrap_or(0)).max(0)),
1006 }),
1007 (Some(a), None) => Some(a.clone()),
1008 (None, Some(b)) => Some(InputTokenDetails {
1009 audio: b.audio.map(|_| 0),
1010 cache_creation: b.cache_creation.map(|_| 0),
1011 cache_read: b.cache_read.map(|_| 0),
1012 }),
1013 (None, None) => None,
1014 },
1015 output_token_details: match (&l.output_token_details, &r.output_token_details) {
1016 (Some(a), Some(b)) => Some(OutputTokenDetails {
1017 audio: a.audio.map(|x| (x - b.audio.unwrap_or(0)).max(0)),
1018 reasoning: a.reasoning.map(|x| (x - b.reasoning.unwrap_or(0)).max(0)),
1019 }),
1020 (Some(a), None) => Some(a.clone()),
1021 (None, Some(b)) => Some(OutputTokenDetails {
1022 audio: b.audio.map(|_| 0),
1023 reasoning: b.reasoning.map(|_| 0),
1024 }),
1025 (None, None) => None,
1026 },
1027 }
1028}
1029
1030pub fn backwards_compat_tool_calls(
1047 additional_kwargs: &HashMap<String, serde_json::Value>,
1048 is_chunk: bool,
1049) -> (Vec<ToolCall>, Vec<InvalidToolCall>, Vec<ToolCallChunk>) {
1050 let mut tool_calls = Vec::new();
1051 let mut invalid_tool_calls = Vec::new();
1052 let mut tool_call_chunks = Vec::new();
1053
1054 if let Some(raw_tool_calls) = additional_kwargs.get("tool_calls")
1055 && let Some(raw_array) = raw_tool_calls.as_array()
1056 {
1057 if is_chunk {
1058 tool_call_chunks = default_tool_chunk_parser(raw_array);
1059 } else {
1060 let (parsed_calls, parsed_invalid) = default_tool_parser(raw_array);
1061 tool_calls = parsed_calls;
1062 invalid_tool_calls = parsed_invalid;
1063 }
1064 }
1065
1066 (tool_calls, invalid_tool_calls, tool_call_chunks)
1067}
1068
/// Unit tests for usage arithmetic, legacy tool-call parsing and chunk merging.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_add_usage_basic() {
        let left = UsageMetadata {
            input_tokens: 5,
            output_tokens: 0,
            total_tokens: 5,
            input_token_details: Some(InputTokenDetails {
                audio: None,
                cache_creation: None,
                cache_read: Some(3),
            }),
            output_token_details: None,
        };
        let right = UsageMetadata {
            input_tokens: 0,
            output_tokens: 10,
            total_tokens: 10,
            input_token_details: None,
            output_token_details: Some(OutputTokenDetails {
                audio: None,
                reasoning: Some(4),
            }),
        };

        let result = add_usage(Some(&left), Some(&right));

        assert_eq!(result.input_tokens, 5);
        assert_eq!(result.output_tokens, 10);
        assert_eq!(result.total_tokens, 15);
        assert!(result.input_token_details.is_some());
        assert_eq!(
            result.input_token_details.as_ref().unwrap().cache_read,
            Some(3)
        );
        assert!(result.output_token_details.is_some());
        assert_eq!(
            result.output_token_details.as_ref().unwrap().reasoning,
            Some(4)
        );
    }

    #[test]
    fn test_add_usage_none_cases() {
        let usage = UsageMetadata::new(10, 20);

        let result = add_usage(None, None);
        assert_eq!(result.input_tokens, 0);
        assert_eq!(result.output_tokens, 0);
        assert_eq!(result.total_tokens, 0);

        let result = add_usage(Some(&usage), None);
        assert_eq!(result.input_tokens, 10);
        assert_eq!(result.output_tokens, 20);

        let result = add_usage(None, Some(&usage));
        assert_eq!(result.input_tokens, 10);
        assert_eq!(result.output_tokens, 20);
    }

    #[test]
    fn test_usage_metadata_add_merges_details() {
        let left = UsageMetadata {
            input_tokens: 1,
            output_tokens: 2,
            total_tokens: 3,
            input_token_details: Some(InputTokenDetails {
                audio: None,
                cache_creation: None,
                cache_read: Some(3),
            }),
            output_token_details: None,
        };
        let right = UsageMetadata {
            input_tokens: 4,
            output_tokens: 5,
            total_tokens: 9,
            input_token_details: Some(InputTokenDetails {
                audio: Some(1),
                cache_creation: None,
                cache_read: Some(2),
            }),
            output_token_details: Some(OutputTokenDetails {
                audio: None,
                reasoning: Some(7),
            }),
        };

        let sum = left.add(&right);

        assert_eq!(sum.input_tokens, 5);
        assert_eq!(sum.output_tokens, 7);
        assert_eq!(sum.total_tokens, 12);
        assert_eq!(
            sum.input_token_details,
            Some(InputTokenDetails {
                audio: Some(1),
                cache_creation: None,
                cache_read: Some(5),
            })
        );
        assert_eq!(
            sum.output_token_details,
            Some(OutputTokenDetails {
                audio: None,
                reasoning: Some(7),
            })
        );
    }

    #[test]
    fn test_subtract_usage_basic() {
        let left = UsageMetadata {
            input_tokens: 5,
            output_tokens: 10,
            total_tokens: 15,
            input_token_details: Some(InputTokenDetails {
                audio: None,
                cache_creation: None,
                cache_read: Some(4),
            }),
            output_token_details: None,
        };
        let right = UsageMetadata {
            input_tokens: 3,
            output_tokens: 8,
            total_tokens: 11,
            input_token_details: None,
            output_token_details: Some(OutputTokenDetails {
                audio: None,
                reasoning: Some(4),
            }),
        };

        let result = subtract_usage(Some(&left), Some(&right));

        assert_eq!(result.input_tokens, 2);
        assert_eq!(result.output_tokens, 2);
        assert_eq!(result.total_tokens, 4);
        assert!(result.input_token_details.is_some());
        assert_eq!(
            result.input_token_details.as_ref().unwrap().cache_read,
            Some(4)
        );
        assert!(result.output_token_details.is_some());
        assert_eq!(
            result.output_token_details.as_ref().unwrap().reasoning,
            Some(0)
        );
    }

    #[test]
    fn test_subtract_usage_floor_at_zero() {
        let left = UsageMetadata::new(5, 5);
        let right = UsageMetadata::new(10, 10);

        let result = subtract_usage(Some(&left), Some(&right));

        assert_eq!(result.input_tokens, 0);
        assert_eq!(result.output_tokens, 0);
        assert_eq!(result.total_tokens, 0);
    }

    #[test]
    fn test_subtract_usage_none_cases() {
        let usage = UsageMetadata::new(10, 20);

        let result = subtract_usage(None, None);
        assert_eq!(result.input_tokens, 0);

        let result = subtract_usage(Some(&usage), None);
        assert_eq!(result.input_tokens, 10);
        assert_eq!(result.output_tokens, 20);

        let result = subtract_usage(None, Some(&usage));
        assert_eq!(result.input_tokens, 0);
        assert_eq!(result.output_tokens, 0);
    }

    #[test]
    fn test_backwards_compat_tool_calls_for_message() {
        let mut additional_kwargs = HashMap::new();
        additional_kwargs.insert(
            "tool_calls".to_string(),
            json!([
                {
                    "id": "call_123",
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"city\": \"London\"}"
                    }
                }
            ]),
        );

        let (tool_calls, invalid_tool_calls, tool_call_chunks) =
            backwards_compat_tool_calls(&additional_kwargs, false);

        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name(), "get_weather");
        assert!(invalid_tool_calls.is_empty());
        assert!(tool_call_chunks.is_empty());
    }

    #[test]
    fn test_backwards_compat_tool_calls_for_chunk() {
        let mut additional_kwargs = HashMap::new();
        additional_kwargs.insert(
            "tool_calls".to_string(),
            json!([
                {
                    "id": "call_123",
                    "index": 0,
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"city\":"
                    }
                }
            ]),
        );

        let (tool_calls, invalid_tool_calls, tool_call_chunks) =
            backwards_compat_tool_calls(&additional_kwargs, true);

        assert!(tool_calls.is_empty());
        assert!(invalid_tool_calls.is_empty());
        assert_eq!(tool_call_chunks.len(), 1);
        assert_eq!(tool_call_chunks[0].name, Some("get_weather".to_string()));
        assert_eq!(tool_call_chunks[0].index, Some(0));
    }

    #[test]
    fn test_backwards_compat_tool_calls_empty() {
        let additional_kwargs = HashMap::new();

        let (tool_calls, invalid_tool_calls, tool_call_chunks) =
            backwards_compat_tool_calls(&additional_kwargs, false);

        assert!(tool_calls.is_empty());
        assert!(invalid_tool_calls.is_empty());
        assert!(tool_call_chunks.is_empty());
    }

    #[test]
    fn test_backwards_compat_tool_calls_invalid_json() {
        let mut additional_kwargs = HashMap::new();
        additional_kwargs.insert(
            "tool_calls".to_string(),
            json!([
                {
                    "id": "call_123",
                    "function": {
                        "name": "get_weather",
                        "arguments": "invalid json {"
                    }
                }
            ]),
        );

        let (tool_calls, invalid_tool_calls, _tool_call_chunks) =
            backwards_compat_tool_calls(&additional_kwargs, false);

        assert!(tool_calls.is_empty());
        assert_eq!(invalid_tool_calls.len(), 1);
        assert_eq!(invalid_tool_calls[0].name, Some("get_weather".to_string()));
    }

    #[test]
    fn test_ai_message_chunk_add() {
        let chunk1 = AIMessageChunk::new("Hello ");
        let chunk2 = AIMessageChunk::new("world!");

        let result = chunk1 + chunk2;

        assert_eq!(result.content(), "Hello world!");
    }

    #[test]
    fn test_ai_message_chunk_sum() {
        let chunks = vec![
            AIMessageChunk::new("Hello "),
            AIMessageChunk::new("beautiful "),
            AIMessageChunk::new("world!"),
        ];

        let result: AIMessageChunk = chunks.into_iter().sum();

        assert_eq!(result.content(), "Hello beautiful world!");
    }

    #[test]
    fn test_ai_message_chunk_sum_empty() {
        let result: AIMessageChunk = Vec::<AIMessageChunk>::new().into_iter().sum();

        assert_eq!(result.content(), "");
        assert_eq!(result.id(), None);
    }

    #[test]
    fn test_add_ai_message_chunks_with_usage() {
        let mut chunk1 = AIMessageChunk::new("Hello ");
        chunk1.usage_metadata = Some(UsageMetadata::new(5, 0));

        let mut chunk2 = AIMessageChunk::new("world!");
        chunk2.usage_metadata = Some(UsageMetadata::new(0, 10));

        let result = add_ai_message_chunks(chunk1, vec![chunk2]);

        assert_eq!(result.content(), "Hello world!");
        assert!(result.usage_metadata.is_some());
        let usage = result.usage_metadata.as_ref().unwrap();
        assert_eq!(usage.input_tokens, 5);
        assert_eq!(usage.output_tokens, 10);
        assert_eq!(usage.total_tokens, 15);
    }

    #[test]
    fn test_add_ai_message_chunks_id_priority() {
        let chunk1 = AIMessageChunk::with_id("lc_auto123", "");
        let chunk2 = AIMessageChunk::with_id("provider_id_456", "");
        let chunk3 = AIMessageChunk::with_id("lc_run-789", "");

        let result = add_ai_message_chunks(chunk1, vec![chunk2, chunk3]);

        assert_eq!(result.id(), Some("provider_id_456"));
    }

    #[test]
    fn test_add_ai_message_chunks_lc_run_priority() {
        let chunk1 = AIMessageChunk::with_id("lc_auto123", "");
        let chunk2 = AIMessageChunk::with_id("lc_run-789", "");

        let result = add_ai_message_chunks(chunk1, vec![chunk2]);

        assert_eq!(result.id(), Some("lc_run-789"));
    }

    #[test]
    fn test_add_ai_message_chunks_propagates_last_position() {
        let first = AIMessageChunk::new("a");
        let mut second = AIMessageChunk::new("b");
        second.set_chunk_position(Some(ChunkPosition::Last));

        let result = first + second;

        assert_eq!(result.content(), "ab");
        assert_eq!(result.chunk_position(), Some(&ChunkPosition::Last));
        assert!(result.tool_calls().is_empty());
    }

    #[test]
    fn test_init_tool_calls_parses_chunks() {
        let mut chunk = AIMessageChunk::with_tool_call_chunks(
            "",
            vec![
                ToolCallChunk {
                    name: Some("get_weather".to_string()),
                    args: Some("{\"city\": \"London\"}".to_string()),
                    id: Some("call_1".to_string()),
                    index: Some(0),
                },
                ToolCallChunk {
                    name: Some("list_tool".to_string()),
                    args: Some("[1, 2]".to_string()),
                    id: Some("call_2".to_string()),
                    index: Some(1),
                },
            ],
        );

        chunk.init_tool_calls();

        // Object arguments become tool calls; non-object arguments become invalid.
        assert_eq!(chunk.tool_calls().len(), 1);
        assert_eq!(chunk.tool_calls()[0].name(), "get_weather");
        assert_eq!(chunk.invalid_tool_calls().len(), 1);
        assert_eq!(
            chunk.invalid_tool_calls()[0].name,
            Some("list_tool".to_string())
        );
    }

    #[test]
    fn test_ai_message_pretty_repr() {
        let message = AIMessage::new("hi").with_name("bot");

        let repr = message.pretty_repr(false);

        let header = format!("{} AI Message {}", "=".repeat(34), "=".repeat(34));
        assert_eq!(repr, format!("{}\nName: bot\n\nhi", header));
    }
}