1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError},
6 json_utils,
7 message::{self, DocumentMediaType, MessageError, Reasoning},
8 one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
/// Model id for Claude 4 Opus (`claude-opus-4-0`).
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// Model id for Claude 4 Sonnet (`claude-sonnet-4-0`).
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// Model id for Claude 3.7 Sonnet (the `-latest` alias).
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// Model id for Claude 3.5 Sonnet (the `-latest` alias).
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Model id for Claude 3.5 Haiku (the `-latest` alias).
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Model id for Claude 3 Opus (the `-latest` alias).
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Dated model id for Claude 3 Sonnet.
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Dated model id for Claude 3 Haiku.
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

// Anthropic API version strings.
// NOTE(review): presumably sent as the `anthropic-version` request header by `Client`;
// that code is not visible in this file.
/// API version `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// API version `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version known to this module; an alias for
/// [`ANTHROPIC_VERSION_2023_06_01`].
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
49
/// Body of a successful (`"type": "message"`) response from `POST /v1/messages`.
///
/// Converted into rig's generic `completion::CompletionResponse` through `TryFrom`;
/// the original value is kept there as `raw_response`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model.
    pub content: Vec<Content>,
    pub id: String,
    pub model: String,
    pub role: String,
    /// Reason generation stopped, if reported.
    pub stop_reason: Option<String>,
    /// Stop sequence reported by the API, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
60
/// Token usage reported by Anthropic for a single request.
///
/// The cache counters are `Option`s, so they deserialize to `None` when the response
/// omits them (printed as `n/a` by the `Display` impl).
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    pub input_tokens: u64,
    pub cache_read_input_tokens: Option<u64>,
    pub cache_creation_input_tokens: Option<u64>,
    pub output_tokens: u64,
}
68
69impl std::fmt::Display for Usage {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(
72 f,
73 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
74 self.input_tokens,
75 match self.cache_read_input_tokens {
76 Some(token) => token.to_string(),
77 None => "n/a".to_string(),
78 },
79 match self.cache_creation_input_tokens {
80 Some(token) => token.to_string(),
81 None => "n/a".to_string(),
82 },
83 self.output_tokens
84 )
85 }
86}
87
/// A tool advertised to the model in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    /// Serialized as `null` when `None` (there is no `skip_serializing_if` here).
    pub description: Option<String>,
    /// JSON schema describing the tool's arguments.
    pub input_schema: serde_json::Value,
}
94
/// Prompt-cache control marker; `Ephemeral` serializes as `{"type": "ephemeral"}`.
///
/// NOTE(review): not referenced elsewhere in this file.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
100
101impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
102 type Error = CompletionError;
103
104 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
105 let content = response
106 .content
107 .iter()
108 .map(|content| {
109 Ok(match content {
110 Content::Text { text } => completion::AssistantContent::text(text),
111 Content::ToolUse { id, name, input } => {
112 completion::AssistantContent::tool_call(id, name, input.clone())
113 }
114 _ => {
115 return Err(CompletionError::ResponseError(
116 "Response did not contain a message or tool call".into(),
117 ));
118 }
119 })
120 })
121 .collect::<Result<Vec<_>, _>>()?;
122
123 let choice = OneOrMany::many(content).map_err(|_| {
124 CompletionError::ResponseError(
125 "Response contained no message or tool call (empty)".to_owned(),
126 )
127 })?;
128
129 let usage = completion::Usage {
130 input_tokens: response.usage.input_tokens,
131 output_tokens: response.usage.output_tokens,
132 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
133 };
134
135 Ok(completion::CompletionResponse {
136 choice,
137 usage,
138 raw_response: response,
139 })
140 }
141}
142
/// A single message in an Anthropic conversation (`messages` array entry).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    pub role: Role,
    // Deserialized with `string_or_one_or_many`: a bare JSON string becomes a single text
    // block (presumably via `Content`'s `FromStr` impl; see `test_deserialize_message`),
    // otherwise one or many content blocks are accepted.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
149
/// Author of a [`Message`]; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
156
/// A single Anthropic content block.
///
/// Blocks are discriminated on the wire by a `"type"` field in snake_case: `text`,
/// `image`, `tool_use`, `tool_result`, `document` or `thinking`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An inline image.
    Image {
        source: ImageSource,
    },
    /// A tool invocation requested by the model; `input` holds the raw JSON arguments.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, carried in a user turn.
    ToolResult {
        /// Id of the `ToolUse` block this result answers.
        tool_use_id: String,
        // A bare JSON string is accepted here too (see `test_deserialize_message`).
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// An inline document (only PDF is representable, see [`DocumentFormat`]).
    Document {
        source: DocumentSource,
    },
    /// Model thinking text; corresponds to rig's [`Reasoning`].
    /// `signature` is serialized as `null` when `None`.
    Thinking {
        thinking: String,
        signature: Option<String>,
    },
}
186
187impl FromStr for Content {
188 type Err = Infallible;
189
190 fn from_str(s: &str) -> Result<Self, Self::Err> {
191 Ok(Content::Text { text: s.to_owned() })
192 }
193}
194
195#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
196#[serde(tag = "type", rename_all = "snake_case")]
197pub enum ToolResultContent {
198 Text { text: String },
199 Image(ImageSource),
200}
201
202impl FromStr for ToolResultContent {
203 type Err = Infallible;
204
205 fn from_str(s: &str) -> Result<Self, Self::Err> {
206 Ok(ToolResultContent::Text { text: s.to_owned() })
207 }
208}
209
/// Inline image payload: `{"data": ..., "media_type": ..., "type": ...}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Encoded image data (base64 is the only [`SourceType`]).
    pub data: String,
    pub media_type: ImageFormat,
    // Raw identifier because `type` is a keyword; serialized as `"type"`.
    pub r#type: SourceType,
}
216
/// Inline document payload: `{"data": ..., "media_type": ..., "type": ...}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Encoded document data (base64 is the only [`SourceType`]).
    pub data: String,
    pub media_type: DocumentFormat,
    // Raw identifier because `type` is a keyword; serialized as `"type"`.
    pub r#type: SourceType,
}
223
/// Image media types accepted by Anthropic, serialized as MIME strings.
///
/// The per-variant `rename`s take precedence over the container-level `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
236
/// Document media types accepted by Anthropic; only PDF (`"application/pdf"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
246
/// Encoding of an inline image/document payload; only `"base64"` is supported.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
}
252
253impl From<String> for Content {
254 fn from(text: String) -> Self {
255 Content::Text { text }
256 }
257}
258
259impl From<String> for ToolResultContent {
260 fn from(text: String) -> Self {
261 ToolResultContent::Text { text }
262 }
263}
264
265impl TryFrom<message::ContentFormat> for SourceType {
266 type Error = MessageError;
267
268 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
269 match format {
270 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
271 message::ContentFormat::String => Err(MessageError::ConversionError(
272 "Image urls are not supported in Anthropic".to_owned(),
273 )),
274 }
275 }
276}
277
278impl From<SourceType> for message::ContentFormat {
279 fn from(source_type: SourceType) -> Self {
280 match source_type {
281 SourceType::BASE64 => message::ContentFormat::Base64,
282 }
283 }
284}
285
286impl TryFrom<message::ImageMediaType> for ImageFormat {
287 type Error = MessageError;
288
289 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
290 Ok(match media_type {
291 message::ImageMediaType::JPEG => ImageFormat::JPEG,
292 message::ImageMediaType::PNG => ImageFormat::PNG,
293 message::ImageMediaType::GIF => ImageFormat::GIF,
294 message::ImageMediaType::WEBP => ImageFormat::WEBP,
295 _ => {
296 return Err(MessageError::ConversionError(
297 format!("Unsupported image media type: {media_type:?}").to_owned(),
298 ));
299 }
300 })
301 }
302}
303
304impl From<ImageFormat> for message::ImageMediaType {
305 fn from(format: ImageFormat) -> Self {
306 match format {
307 ImageFormat::JPEG => message::ImageMediaType::JPEG,
308 ImageFormat::PNG => message::ImageMediaType::PNG,
309 ImageFormat::GIF => message::ImageMediaType::GIF,
310 ImageFormat::WEBP => message::ImageMediaType::WEBP,
311 }
312 }
313}
314
315impl TryFrom<DocumentMediaType> for DocumentFormat {
316 type Error = MessageError;
317 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
318 if !matches!(value, DocumentMediaType::PDF) {
319 return Err(MessageError::ConversionError(
320 "Anthropic only supports PDF documents".to_string(),
321 ));
322 };
323
324 Ok(DocumentFormat::PDF)
325 }
326}
327
328impl From<message::AssistantContent> for Content {
329 fn from(text: message::AssistantContent) -> Self {
330 match text {
331 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
332 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
333 Content::ToolUse {
334 id,
335 name: function.name,
336 input: function.arguments,
337 }
338 }
339 message::AssistantContent::Reasoning(Reasoning { reasoning }) => Content::Thinking {
340 thinking: reasoning,
341 signature: None,
342 },
343 }
344 }
345}
346
347impl TryFrom<message::Message> for Message {
348 type Error = MessageError;
349
350 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
351 Ok(match message {
352 message::Message::User { content } => Message {
353 role: Role::User,
354 content: content.try_map(|content| match content {
355 message::UserContent::Text(message::Text { text }) => {
356 Ok(Content::Text { text })
357 }
358 message::UserContent::ToolResult(message::ToolResult {
359 id, content, ..
360 }) => Ok(Content::ToolResult {
361 tool_use_id: id,
362 content: content.try_map(|content| match content {
363 message::ToolResultContent::Text(message::Text { text }) => {
364 Ok(ToolResultContent::Text { text })
365 }
366 message::ToolResultContent::Image(image) => {
367 let media_type =
368 image.media_type.ok_or(MessageError::ConversionError(
369 "Image media type is required".to_owned(),
370 ))?;
371 let format = image.format.ok_or(MessageError::ConversionError(
372 "Image format is required".to_owned(),
373 ))?;
374 Ok(ToolResultContent::Image(ImageSource {
375 data: image.data,
376 media_type: media_type.try_into()?,
377 r#type: format.try_into()?,
378 }))
379 }
380 })?,
381 is_error: None,
382 }),
383 message::UserContent::Image(message::Image {
384 data,
385 format,
386 media_type,
387 ..
388 }) => {
389 let source = ImageSource {
390 data,
391 media_type: match media_type {
392 Some(media_type) => media_type.try_into()?,
393 None => {
394 return Err(MessageError::ConversionError(
395 "Image media type is required".to_owned(),
396 ));
397 }
398 },
399 r#type: match format {
400 Some(format) => format.try_into()?,
401 None => SourceType::BASE64,
402 },
403 };
404 Ok(Content::Image { source })
405 }
406 message::UserContent::Document(message::Document {
407 data,
408 format,
409 media_type,
410 }) => {
411 let Some(media_type) = media_type else {
412 return Err(MessageError::ConversionError(
413 "Document media type is required".to_string(),
414 ));
415 };
416
417 let source = DocumentSource {
418 data,
419 media_type: media_type.try_into()?,
420 r#type: match format {
421 Some(format) => format.try_into()?,
422 None => SourceType::BASE64,
423 },
424 };
425 Ok(Content::Document { source })
426 }
427 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
428 "Audio is not supported in Anthropic".to_owned(),
429 )),
430 })?,
431 },
432
433 message::Message::Assistant { content, .. } => Message {
434 content: content.map(|content| content.into()),
435 role: Role::Assistant,
436 },
437 })
438 }
439}
440
441impl TryFrom<Content> for message::AssistantContent {
442 type Error = MessageError;
443
444 fn try_from(content: Content) -> Result<Self, Self::Error> {
445 Ok(match content {
446 Content::Text { text } => message::AssistantContent::text(text),
447 Content::ToolUse { id, name, input } => {
448 message::AssistantContent::tool_call(id, name, input)
449 }
450 _ => {
451 return Err(MessageError::ConversionError(
452 format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
453 ));
454 }
455 })
456 }
457}
458
459impl From<ToolResultContent> for message::ToolResultContent {
460 fn from(content: ToolResultContent) -> Self {
461 match content {
462 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
463 ToolResultContent::Image(ImageSource {
464 data,
465 media_type: format,
466 r#type,
467 }) => message::ToolResultContent::image(
468 data,
469 Some(r#type.into()),
470 Some(format.into()),
471 None,
472 ),
473 }
474 }
475}
476
477impl TryFrom<Message> for message::Message {
478 type Error = MessageError;
479
480 fn try_from(message: Message) -> Result<Self, Self::Error> {
481 Ok(match message.role {
482 Role::User => message::Message::User {
483 content: message.content.try_map(|content| {
484 Ok(match content {
485 Content::Text { text } => message::UserContent::text(text),
486 Content::ToolResult {
487 tool_use_id,
488 content,
489 ..
490 } => message::UserContent::tool_result(
491 tool_use_id,
492 content.map(|content| content.into()),
493 ),
494 Content::Image { source } => message::UserContent::Image(message::Image {
495 data: source.data,
496 format: Some(message::ContentFormat::Base64),
497 media_type: Some(source.media_type.into()),
498 detail: None,
499 }),
500 Content::Document { source } => message::UserContent::document(
501 source.data,
502 Some(message::ContentFormat::Base64),
503 Some(message::DocumentMediaType::PDF),
504 ),
505 _ => {
506 return Err(MessageError::ConversionError(
507 "Unsupported content type for User role".to_owned(),
508 ));
509 }
510 })
511 })?,
512 },
513 Role::Assistant => match message.content.first() {
514 Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
515 id: None,
516 content: message.content.try_map(|content| content.try_into())?,
517 },
518
519 _ => {
520 return Err(MessageError::ConversionError(
521 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
522 ));
523 }
524 },
525 })
526 }
527}
528
/// An Anthropic chat model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    pub(crate) client: Client,
    /// Model id sent as the request's `model` field.
    pub model: String,
    /// Fallback `max_tokens` used when a request does not set one; when this is also
    /// `None`, `completion` returns a `RequestError`.
    pub default_max_tokens: Option<u64>,
}
535
536impl CompletionModel {
537 pub fn new(client: Client, model: &str) -> Self {
538 Self {
539 client,
540 model: model.to_string(),
541 default_max_tokens: calculate_max_tokens(model),
542 }
543 }
544}
545
/// Returns the default `max_tokens` for a known Anthropic model family, matched by
/// model-name prefix, or `None` when the model is not recognized.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (prefixes, limit) pairs checked in order; the first family with a matching
    // prefix wins.
    const LIMITS: &[(&[&str], u64)] = &[
        (&["claude-opus-4"], 32000),
        (&["claude-sonnet-4", "claude-3-7-sonnet"], 64000),
        (&["claude-3-5-sonnet", "claude-3-5-haiku"], 8192),
        (&["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"], 4096),
    ];

    LIMITS
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|&prefix| model.starts_with(prefix)))
        .map(|&(_, limit)| limit)
}
567
/// Request metadata carrying an optional `user_id`.
///
/// NOTE(review): not referenced in the visible code; confirm whether it is still needed.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    user_id: Option<String>,
}
572
/// How the model may choose tools, serialized as `{"type": "auto"}`, `{"type": "any"}`
/// or `{"type": "tool", "name": ...}`.
///
/// Defaults to `Auto`, which is what `completion` sends whenever tools are present.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    Tool {
        name: String,
    },
}
583
584impl completion::CompletionModel for CompletionModel {
585 type Response = CompletionResponse;
586 type StreamingResponse = StreamingCompletionResponse;
587
588 #[cfg_attr(feature = "worker", worker::send)]
589 async fn completion(
590 &self,
591 completion_request: completion::CompletionRequest,
592 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
593 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
599 tokens
600 } else if let Some(tokens) = self.default_max_tokens {
601 tokens
602 } else {
603 return Err(CompletionError::RequestError(
604 "`max_tokens` must be set for Anthropic".into(),
605 ));
606 };
607
608 let mut full_history = vec![];
609 if let Some(docs) = completion_request.normalized_documents() {
610 full_history.push(docs);
611 }
612 full_history.extend(completion_request.chat_history);
613
614 let full_history = full_history
615 .into_iter()
616 .map(Message::try_from)
617 .collect::<Result<Vec<Message>, _>>()?;
618
619 let mut request = json!({
620 "model": self.model,
621 "messages": full_history,
622 "max_tokens": max_tokens,
623 "system": completion_request.preamble.unwrap_or("".to_string()),
624 });
625
626 if let Some(temperature) = completion_request.temperature {
627 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
628 }
629
630 if !completion_request.tools.is_empty() {
631 json_utils::merge_inplace(
632 &mut request,
633 json!({
634 "tools": completion_request
635 .tools
636 .into_iter()
637 .map(|tool| ToolDefinition {
638 name: tool.name,
639 description: Some(tool.description),
640 input_schema: tool.parameters,
641 })
642 .collect::<Vec<_>>(),
643 "tool_choice": ToolChoice::Auto,
644 }),
645 );
646 }
647
648 if let Some(ref params) = completion_request.additional_params {
649 json_utils::merge_inplace(&mut request, params.clone())
650 }
651
652 tracing::debug!("Anthropic completion request: {request}");
653
654 let response = self
655 .client
656 .post("/v1/messages")
657 .json(&request)
658 .send()
659 .await?;
660
661 if response.status().is_success() {
662 match response.json::<ApiResponse<CompletionResponse>>().await? {
663 ApiResponse::Message(completion) => {
664 tracing::info!(target: "rig",
665 "Anthropic completion token usage: {}",
666 completion.usage
667 );
668 completion.try_into()
669 }
670 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
671 }
672 } else {
673 Err(CompletionError::ProviderError(response.text().await?))
674 }
675 }
676
677 #[cfg_attr(feature = "worker", worker::send)]
678 async fn stream(
679 &self,
680 request: CompletionRequest,
681 ) -> Result<
682 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
683 CompletionError,
684 > {
685 CompletionModel::stream(self, request).await
686 }
687}
688
/// Payload of an `{"type": "error", ...}` body, read as a top-level `message` field.
///
/// NOTE(review): Anthropic's error bodies appear to nest the text under
/// `error.message`; if so, this flat shape will fail to deserialize. Confirm against
/// the API reference.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
693
/// A response body discriminated by its `"type"` field.
///
/// `"message"` deserializes into `Message(T)`; `"error"` deserializes into `Error`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
700
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes Anthropic-shaped JSON into `Message` and checks every content block,
    /// covering both the bare-string and the array form of `content`.
    #[test]
    fn test_deserialize_message() {
        // `content` given as a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // `content` given as an array mixing a text block and a tool_use block.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        // User content with an image, text, and a tool_result whose content is a bare string.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        // `serde_path_to_error` puts the JSON path of any failure into the panic message.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // A bare string becomes a single text block.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        // Array form: blocks are preserved in order.
        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // The tool_result's bare-string content becomes a single text item, and the
            // missing `is_error` deserializes to `None`.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    /// Round-trips Anthropic messages through rig's generic `message::Message` and back,
    /// checking the intermediate form and that the final value equals the original.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        // Anthropic -> rig.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data,
                        format,
                        media_type,
                        ..
                    }) => {
                        assert_eq!(data, "/9j/4AAQSkZJRg...");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data,
                        format,
                        media_type,
                    }) => {
                        assert_eq!(data, "base64_encoded_pdf_data");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // rig -> Anthropic: the round trip must be lossless for these messages.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
}