1use std::{convert::Infallible, str::FromStr};
4
5use crate::{
6 completion::{self, CompletionError},
7 json_utils,
8 message::{self, MessageError},
9 one_or_many::string_or_one_or_many,
10 OneOrMany,
11};
12
13use serde::{Deserialize, Serialize};
14use serde_json::json;
15
16use super::client::Client;
17
/// Model id `claude-3-5-sonnet-latest` (Claude 3.5 Sonnet, `-latest` alias).
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Model id `claude-3-5-haiku-latest` (Claude 3.5 Haiku, `-latest` alias).
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Model id `claude-3-opus-latest` (Claude 3 Opus, `-latest` alias).
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Model id for the dated `20240229` Claude 3 Sonnet snapshot.
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Model id for the dated `20240307` Claude 3 Haiku snapshot.
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// API version string `2023-01-01`.
///
/// NOTE(review): presumably sent as the `anthropic-version` header by the client; that code is
/// not visible here.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Newest API version known to this module (an alias of [`ANTHROPIC_VERSION_2023_06_01`]).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Body of a successful `/v1/messages` response (the `message`-typed variant of the API reply).
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model, in order.
    pub content: Vec<Content>,
    /// Identifier the API assigned to this response.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Author role as reported by the API (kept as a raw string; not validated here).
    pub role: String,
    /// Reason generation stopped, if reported.
    pub stop_reason: Option<String>,
    /// Stop sequence reported by the API, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request; logged after every completion.
    pub usage: Usage,
}

/// Token usage reported with a completion.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens consumed.
    pub input_tokens: u64,
    /// Input tokens read from the prompt cache; `None` when the field is absent.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache; `None` when the field is absent.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output tokens generated.
    pub output_tokens: u64,
}
58
59impl std::fmt::Display for Usage {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(
62 f,
63 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
64 self.input_tokens,
65 match self.cache_read_input_tokens {
66 Some(token) => token.to_string(),
67 None => "n/a".to_string(),
68 },
69 match self.cache_creation_input_tokens {
70 Some(token) => token.to_string(),
71 None => "n/a".to_string(),
72 },
73 self.output_tokens
74 )
75 }
76}
77
/// Tool description in the shape placed under the request's `"tools"` key.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name the model uses when calling the tool.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// JSON schema for the tool input (filled from the generic tool's `parameters`).
    pub input_schema: serde_json::Value,
}

/// Cache-control marker; serializes as `{"type": "ephemeral"}`.
///
/// NOTE(review): not referenced anywhere in this file.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
90
91impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
92 type Error = CompletionError;
93
94 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
95 let content = response
96 .content
97 .iter()
98 .map(|content| {
99 Ok(match content {
100 Content::Text { text } => completion::AssistantContent::text(text),
101 Content::ToolUse { id, name, input } => {
102 completion::AssistantContent::tool_call(id, name, input.clone())
103 }
104 _ => {
105 return Err(CompletionError::ResponseError(
106 "Response did not contain a message or tool call".into(),
107 ))
108 }
109 })
110 })
111 .collect::<Result<Vec<_>, _>>()?;
112
113 let choice = OneOrMany::many(content).map_err(|_| {
114 CompletionError::ResponseError(
115 "Response contained no message or tool call (empty)".to_owned(),
116 )
117 })?;
118
119 Ok(completion::CompletionResponse {
120 choice,
121 raw_response: response,
122 })
123 }
124}
125
/// One Anthropic chat message: an author role plus its content blocks.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    pub role: Role,
    /// Content blocks. On deserialization a bare JSON string is also accepted and becomes a
    /// single text block (via `string_or_one_or_many`).
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}

/// Message author; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}

/// A content block, discriminated by its `"type"` field (`text`, `image`, `tool_use`,
/// `tool_result`, `document`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// Image payload described by `source`.
    Image {
        source: ImageSource,
    },
    /// A tool invocation requested by the model.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a previous tool invocation, matched by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        /// Result blocks; a bare JSON string is accepted on deserialization.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// Document payload described by `source`.
    Document {
        source: DocumentSource,
    },
}
165
166impl FromStr for Content {
167 type Err = Infallible;
168
169 fn from_str(s: &str) -> Result<Self, Self::Err> {
170 Ok(Content::Text { text: s.to_owned() })
171 }
172}
173
/// One block inside a `tool_result`'s content, discriminated by `"type"` (`text` / `image`).
///
/// NOTE(review): `Image(ImageSource)` is a newtype variant of an internally tagged enum, so the
/// fields of `ImageSource` are flattened next to the `"type": "image"` tag. `ImageSource` has its
/// own `type` field, which gives two `type` keys, and no `source` wrapper object is produced.
/// Confirm against Anthropic's `tool_result` image schema, which nests image data under `source`.
/// Something like `Image { source: ImageSource }` may be needed, matching `Content::Image`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Plain-text tool output.
    Text { text: String },
    /// Image tool output (see the note above about its serialized shape).
    Image(ImageSource),
}
180
181impl FromStr for ToolResultContent {
182 type Err = Infallible;
183
184 fn from_str(s: &str) -> Result<Self, Self::Err> {
185 Ok(ToolResultContent::Text { text: s.to_owned() })
186 }
187}
188
/// Inline image payload: encoded `data`, its MIME type and the encoding kind (`type`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    pub data: String,
    pub media_type: ImageFormat,
    pub r#type: SourceType,
}

/// Inline document payload: encoded `data`, its MIME type and the encoding kind (`type`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    pub data: String,
    pub media_type: DocumentFormat,
    pub r#type: SourceType,
}

/// Supported image MIME types.
///
/// The explicit per-variant `rename`s override `rename_all`, so values are serialized as full
/// MIME strings such as `"image/jpeg"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}

/// Supported document MIME types; currently only PDF (`"application/pdf"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}

/// Encoding of an inline source payload; only base64 (serialized as `"base64"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
}
228
229impl From<String> for Content {
230 fn from(text: String) -> Self {
231 Content::Text { text }
232 }
233}
234
235impl From<String> for ToolResultContent {
236 fn from(text: String) -> Self {
237 ToolResultContent::Text { text }
238 }
239}
240
241impl TryFrom<message::ContentFormat> for SourceType {
242 type Error = MessageError;
243
244 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
245 match format {
246 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
247 message::ContentFormat::String => Err(MessageError::ConversionError(
248 "Image urls are not supported in Anthropic".to_owned(),
249 )),
250 }
251 }
252}
253
254impl From<SourceType> for message::ContentFormat {
255 fn from(source_type: SourceType) -> Self {
256 match source_type {
257 SourceType::BASE64 => message::ContentFormat::Base64,
258 }
259 }
260}
261
262impl TryFrom<message::ImageMediaType> for ImageFormat {
263 type Error = MessageError;
264
265 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
266 Ok(match media_type {
267 message::ImageMediaType::JPEG => ImageFormat::JPEG,
268 message::ImageMediaType::PNG => ImageFormat::PNG,
269 message::ImageMediaType::GIF => ImageFormat::GIF,
270 message::ImageMediaType::WEBP => ImageFormat::WEBP,
271 _ => {
272 return Err(MessageError::ConversionError(
273 format!("Unsupported image media type: {:?}", media_type).to_owned(),
274 ))
275 }
276 })
277 }
278}
279
280impl From<ImageFormat> for message::ImageMediaType {
281 fn from(format: ImageFormat) -> Self {
282 match format {
283 ImageFormat::JPEG => message::ImageMediaType::JPEG,
284 ImageFormat::PNG => message::ImageMediaType::PNG,
285 ImageFormat::GIF => message::ImageMediaType::GIF,
286 ImageFormat::WEBP => message::ImageMediaType::WEBP,
287 }
288 }
289}
290
291impl From<message::AssistantContent> for Content {
292 fn from(text: message::AssistantContent) -> Self {
293 match text {
294 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
295 message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
296 Content::ToolUse {
297 id,
298 name: function.name,
299 input: function.arguments,
300 }
301 }
302 }
303 }
304}
305
306impl TryFrom<message::Message> for Message {
307 type Error = MessageError;
308
309 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
310 Ok(match message {
311 message::Message::User { content } => Message {
312 role: Role::User,
313 content: content.try_map(|content| match content {
314 message::UserContent::Text(message::Text { text }) => {
315 Ok(Content::Text { text })
316 }
317 message::UserContent::ToolResult(message::ToolResult { id, content }) => {
318 Ok(Content::ToolResult {
319 tool_use_id: id,
320 content: content.try_map(|content| match content {
321 message::ToolResultContent::Text(message::Text { text }) => {
322 Ok(ToolResultContent::Text { text })
323 }
324 message::ToolResultContent::Image(image) => {
325 let media_type =
326 image.media_type.ok_or(MessageError::ConversionError(
327 "Image media type is required".to_owned(),
328 ))?;
329 let format =
330 image.format.ok_or(MessageError::ConversionError(
331 "Image format is required".to_owned(),
332 ))?;
333 Ok(ToolResultContent::Image(ImageSource {
334 data: image.data,
335 media_type: media_type.try_into()?,
336 r#type: format.try_into()?,
337 }))
338 }
339 })?,
340 is_error: None,
341 })
342 }
343 message::UserContent::Image(message::Image {
344 data,
345 format,
346 media_type,
347 ..
348 }) => {
349 let source = ImageSource {
350 data,
351 media_type: match media_type {
352 Some(media_type) => media_type.try_into()?,
353 None => {
354 return Err(MessageError::ConversionError(
355 "Image media type is required".to_owned(),
356 ))
357 }
358 },
359 r#type: match format {
360 Some(format) => format.try_into()?,
361 None => SourceType::BASE64,
362 },
363 };
364 Ok(Content::Image { source })
365 }
366 message::UserContent::Document(message::Document { data, format, .. }) => {
367 let source = DocumentSource {
368 data,
369 media_type: DocumentFormat::PDF,
370 r#type: match format {
371 Some(format) => format.try_into()?,
372 None => SourceType::BASE64,
373 },
374 };
375 Ok(Content::Document { source })
376 }
377 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
378 "Audio is not supported in Anthropic".to_owned(),
379 )),
380 })?,
381 },
382
383 message::Message::Assistant { content } => Message {
384 content: content.map(|content| content.into()),
385 role: Role::Assistant,
386 },
387 })
388 }
389}
390
391impl TryFrom<Content> for message::AssistantContent {
392 type Error = MessageError;
393
394 fn try_from(content: Content) -> Result<Self, Self::Error> {
395 Ok(match content {
396 Content::Text { text } => message::AssistantContent::text(text),
397 Content::ToolUse { id, name, input } => {
398 message::AssistantContent::tool_call(id, name, input)
399 }
400 _ => {
401 return Err(MessageError::ConversionError(
402 format!("Unsupported content type for Assistant role: {:?}", content)
403 .to_owned(),
404 ))
405 }
406 })
407 }
408}
409
410impl From<ToolResultContent> for message::ToolResultContent {
411 fn from(content: ToolResultContent) -> Self {
412 match content {
413 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
414 ToolResultContent::Image(ImageSource {
415 data,
416 media_type: format,
417 r#type,
418 }) => message::ToolResultContent::image(
419 data,
420 Some(r#type.into()),
421 Some(format.into()),
422 None,
423 ),
424 }
425 }
426}
427
428impl TryFrom<Message> for message::Message {
429 type Error = MessageError;
430
431 fn try_from(message: Message) -> Result<Self, Self::Error> {
432 Ok(match message.role {
433 Role::User => message::Message::User {
434 content: message.content.try_map(|content| {
435 Ok(match content {
436 Content::Text { text } => message::UserContent::text(text),
437 Content::ToolResult {
438 tool_use_id,
439 content,
440 ..
441 } => message::UserContent::tool_result(
442 tool_use_id,
443 content.map(|content| content.into()),
444 ),
445 Content::Image { source } => message::UserContent::Image(message::Image {
446 data: source.data,
447 format: Some(message::ContentFormat::Base64),
448 media_type: Some(source.media_type.into()),
449 detail: None,
450 }),
451 Content::Document { source } => message::UserContent::document(
452 source.data,
453 Some(message::ContentFormat::Base64),
454 Some(message::DocumentMediaType::PDF),
455 ),
456 _ => {
457 return Err(MessageError::ConversionError(
458 "Unsupported content type for User role".to_owned(),
459 ))
460 }
461 })
462 })?,
463 },
464 Role::Assistant => match message.content.first() {
465 Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
466 content: message.content.try_map(|content| content.try_into())?,
467 },
468
469 _ => {
470 return Err(MessageError::ConversionError(
471 format!("Unsupported message for Assistant role: {:?}", message).to_owned(),
472 ))
473 }
474 },
475 })
476 }
477}
478
479#[derive(Clone)]
480pub struct CompletionModel {
481 pub(crate) client: Client,
482 pub model: String,
483 pub default_max_tokens: Option<u64>,
484}
485
486impl CompletionModel {
487 pub fn new(client: Client, model: &str) -> Self {
488 Self {
489 client,
490 model: model.to_string(),
491 default_max_tokens: calculate_max_tokens(model),
492 }
493 }
494}
495
/// Returns the default output-token ceiling for a model, chosen by its name prefix.
///
/// The result is `Some(8192)` for Claude 3.5 Sonnet/Haiku and `Some(4096)` for Claude 3
/// Opus/Sonnet/Haiku. Any other model yields `None`, so such requests must set `max_tokens`
/// themselves.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Checked in order; the first family with a matching prefix wins.
    const LIMITS: [(&[&str], u64); 2] = [
        (&["claude-3-5-sonnet", "claude-3-5-haiku"], 8192),
        (&["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"], 4096),
    ];

    LIMITS
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|prefix| model.starts_with(*prefix)))
        .map(|&(_, limit)| limit)
}
513
/// Request metadata carrying an optional end-user identifier.
///
/// NOTE(review): not referenced anywhere in this file; as a private item it likely triggers a
/// dead-code warning.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    user_id: Option<String>,
}

/// Tool-selection policy, serialized with a `"type"` tag.
///
/// The serialized forms are `{"type": "auto"}`, `{"type": "any"}` and
/// `{"type": "tool", "name": ...}`. Requests that include tools always send `Auto` (the default).
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    Tool {
        name: String,
    },
}
529
530impl completion::CompletionModel for CompletionModel {
531 type Response = CompletionResponse;
532
533 #[cfg_attr(feature = "worker", worker::send)]
534 async fn completion(
535 &self,
536 completion_request: completion::CompletionRequest,
537 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
538 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
544 tokens
545 } else if let Some(tokens) = self.default_max_tokens {
546 tokens
547 } else {
548 return Err(CompletionError::RequestError(
549 "`max_tokens` must be set for Anthropic".into(),
550 ));
551 };
552
553 let prompt_message: Message = completion_request
554 .prompt_with_context()
555 .try_into()
556 .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
557
558 let mut messages = completion_request
559 .chat_history
560 .into_iter()
561 .map(|message| {
562 message
563 .try_into()
564 .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
565 })
566 .collect::<Result<Vec<Message>, _>>()?;
567
568 messages.push(prompt_message);
569
570 let mut request = json!({
571 "model": self.model,
572 "messages": messages,
573 "max_tokens": max_tokens,
574 "system": completion_request.preamble.unwrap_or("".to_string()),
575 });
576
577 if let Some(temperature) = completion_request.temperature {
578 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
579 }
580
581 if !completion_request.tools.is_empty() {
582 json_utils::merge_inplace(
583 &mut request,
584 json!({
585 "tools": completion_request
586 .tools
587 .into_iter()
588 .map(|tool| ToolDefinition {
589 name: tool.name,
590 description: Some(tool.description),
591 input_schema: tool.parameters,
592 })
593 .collect::<Vec<_>>(),
594 "tool_choice": ToolChoice::Auto,
595 }),
596 );
597 }
598
599 if let Some(ref params) = completion_request.additional_params {
600 json_utils::merge_inplace(&mut request, params.clone())
601 }
602
603 tracing::debug!("Anthropic completion request: {request}");
604
605 let response = self
606 .client
607 .post("/v1/messages")
608 .json(&request)
609 .send()
610 .await?;
611
612 if response.status().is_success() {
613 match response.json::<ApiResponse<CompletionResponse>>().await? {
614 ApiResponse::Message(completion) => {
615 tracing::info!(target: "rig",
616 "Anthropic completion token usage: {}",
617 completion.usage
618 );
619 completion.try_into()
620 }
621 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
622 }
623 } else {
624 Err(CompletionError::ProviderError(response.text().await?))
625 }
626 }
627}
628
/// Payload of an `error`-typed response body.
///
/// NOTE(review): this reads `message` from the top level of the body. Confirm against Anthropic's
/// error schema, which may nest it inside an `error` object. Note that non-2xx responses bypass
/// this type entirely and surface the raw body text.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error description.
    message: String,
}

/// Response body discriminated by its `"type"` field.
///
/// A `"message"` body carries `T`; an `"error"` body carries [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
640
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into a [`Message`], panicking with the failing JSON path on error.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    #[test]
    fn test_deserialize_message() {
        // Bare-string content must be accepted and become a single text block.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let user_message = parse_message(user_message_json);

        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        let Message { role, content } = assistant_message2;
        assert_eq!(role, Role::Assistant);
        assert_eq!(content.len(), 2);

        let mut iter = content.into_iter();

        match iter.next().unwrap() {
            Content::Text { text } => {
                assert_eq!(text, "\n\nHello there, how may I assist you today?");
            }
            _ => panic!("Expected text content"),
        }

        match iter.next().unwrap() {
            Content::ToolUse { id, name, input } => {
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                assert_eq!(name, "get_weather");
                assert_eq!(input, json!({"location": "San Francisco, CA"}));
            }
            _ => panic!("Expected tool use content"),
        }

        assert_eq!(iter.next(), None);

        let Message { role, content } = user_message;
        assert_eq!(role, Role::User);
        assert_eq!(content.len(), 3);

        let mut iter = content.into_iter();

        match iter.next().unwrap() {
            Content::Image { source } => {
                assert_eq!(
                    source,
                    ImageSource {
                        data: "/9j/4AAQSkZJRg...".to_owned(),
                        media_type: ImageFormat::JPEG,
                        r#type: SourceType::BASE64,
                    }
                );
            }
            _ => panic!("Expected image content"),
        }

        match iter.next().unwrap() {
            Content::Text { text } => {
                assert_eq!(text, "What is in this image?");
            }
            _ => panic!("Expected text content"),
        }

        // A bare-string tool result must become a single text block with no error flag.
        match iter.next().unwrap() {
            Content::ToolResult {
                tool_use_id,
                content,
                is_error,
            } => {
                assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                assert_eq!(
                    content.first(),
                    ToolResultContent::Text {
                        text: "15 degrees".to_owned()
                    }
                );
                assert_eq!(is_error, None);
            }
            _ => panic!("Expected tool result content"),
        }

        assert_eq!(iter.next(), None);
    }

    /// Round-trips Anthropic messages through the generic `message::Message` and back.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data,
                        format,
                        media_type,
                        ..
                    }) => {
                        assert_eq!(data, "/9j/4AAQSkZJRg...");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data,
                        format,
                        media_type,
                    }) => {
                        assert_eq!(data, "base64_encoded_pdf_data");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original Anthropic messages exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    /// Known model families get a default `max_tokens`; unknown ones do not.
    #[test]
    fn test_calculate_max_tokens() {
        assert_eq!(calculate_max_tokens(CLAUDE_3_5_SONNET), Some(8192));
        assert_eq!(calculate_max_tokens(CLAUDE_3_5_HAIKU), Some(8192));
        assert_eq!(calculate_max_tokens(CLAUDE_3_OPUS), Some(4096));
        assert_eq!(calculate_max_tokens(CLAUDE_3_SONNET), Some(4096));
        assert_eq!(calculate_max_tokens(CLAUDE_3_HAIKU), Some(4096));
        assert_eq!(calculate_max_tokens("unknown-model"), None);
    }
}
953}