1use std::{convert::Infallible, str::FromStr};
4
5use crate::{
6 completion::{self, CompletionError},
7 json_utils,
8 message::{self, MessageError},
9 one_or_many::string_or_one_or_many,
10 OneOrMany,
11};
12
13use serde::{Deserialize, Serialize};
14use serde_json::json;
15
16use super::client::Client;
17
// ---------------------------------------------------------------
// Model identifiers accepted by `CompletionModel::new`
// ---------------------------------------------------------------

/// `claude-3-7-sonnet-latest` model identifier.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` model identifier.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` model identifier.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` model identifier.
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` model identifier (pinned to a dated snapshot).
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` model identifier (pinned to a dated snapshot).
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

// NOTE(review): these version strings are not used in this part of the file;
// presumably the client sends them as the `anthropic-version` header — confirm in `client`.

/// Anthropic version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest version string defined here (`2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
42
/// Successful response body of a `/v1/messages` request, as deserialized by
/// `CompletionModel::completion`.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model; only `Text` and `ToolUse` blocks
    /// are accepted when converting into the generic completion response.
    pub content: Vec<Content>,
    /// Identifier of the response.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// Role of the message author, kept as the raw string from the payload.
    pub role: String,
    /// Reason generation stopped, if the payload provides one.
    pub stop_reason: Option<String>,
    /// Stop sequence that was hit, if the payload provides one.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request.
    pub usage: Usage,
}
53
/// Token usage reported alongside a completion response.
///
/// The cache counters are optional: they deserialize to `None` when the
/// payload omits them, and `Display` renders them as `n/a` in that case.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of input tokens read from the cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Number of input tokens used to create a cache entry, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Number of output tokens.
    pub output_tokens: u64,
}
61
62impl std::fmt::Display for Usage {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(
65 f,
66 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
67 self.input_tokens,
68 match self.cache_read_input_tokens {
69 Some(token) => token.to_string(),
70 None => "n/a".to_string(),
71 },
72 match self.cache_creation_input_tokens {
73 Some(token) => token.to_string(),
74 None => "n/a".to_string(),
75 },
76 self.output_tokens
77 )
78 }
79}
80
/// Tool description in the shape sent under the `tools` key of a completion request.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// Schema of the tool's input, kept as raw JSON.
    pub input_schema: serde_json::Value,
}
87
/// Cache-control marker, serialized as an object tagged by `type`
/// (the only variant becomes `{"type": "ephemeral"}`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// Serialized with the tag value `ephemeral`.
    Ephemeral,
}
93
94impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
95 type Error = CompletionError;
96
97 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
98 let content = response
99 .content
100 .iter()
101 .map(|content| {
102 Ok(match content {
103 Content::Text { text } => completion::AssistantContent::text(text),
104 Content::ToolUse { id, name, input } => {
105 completion::AssistantContent::tool_call(id, name, input.clone())
106 }
107 _ => {
108 return Err(CompletionError::ResponseError(
109 "Response did not contain a message or tool call".into(),
110 ))
111 }
112 })
113 })
114 .collect::<Result<Vec<_>, _>>()?;
115
116 let choice = OneOrMany::many(content).map_err(|_| {
117 CompletionError::ResponseError(
118 "Response contained no message or tool call (empty)".to_owned(),
119 )
120 })?;
121
122 Ok(completion::CompletionResponse {
123 choice,
124 raw_response: response,
125 })
126 }
127}
128
/// A single chat message in Anthropic's wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// One or more content blocks; on input this may also be a plain string,
    /// which `string_or_one_or_many` turns into content via `FromStr`.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
135
/// Author of a [`Message`]; serialized in lowercase (`"user"`, `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message authored by the user.
    User,
    /// Message authored by the assistant.
    Assistant,
}
142
/// One content block of a [`Message`].
///
/// Serialized as an object tagged by `type`, with snake_case variant names
/// (`text`, `image`, `tool_use`, `tool_result`, `document`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image and its encoded source.
    Image {
        source: ImageSource,
    },
    /// A tool invocation: call id, tool name and JSON arguments.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, linked to it through `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        /// Accepts either a plain string or one/many result blocks on input.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// A document and its encoded source.
    Document {
        source: DocumentSource,
    },
}
168
169impl FromStr for Content {
170 type Err = Infallible;
171
172 fn from_str(s: &str) -> Result<Self, Self::Err> {
173 Ok(Content::Text { text: s.to_owned() })
174 }
175}
176
/// Content carried inside a `Content::ToolResult` block, tagged by `type`
/// with snake_case variant names.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Plain text result.
    Text { text: String },
    /// Image result.
    // NOTE(review): this newtype variant is internally tagged with `type`, and
    // `ImageSource` also has a field named `type` — confirm the serialized shape
    // is what the API expects.
    Image(ImageSource),
}
183
184impl FromStr for ToolResultContent {
185 type Err = Infallible;
186
187 fn from_str(s: &str) -> Result<Self, Self::Err> {
188 Ok(ToolResultContent::Text { text: s.to_owned() })
189 }
190}
191
/// Encoded image payload used by `Content::Image` and `ToolResultContent::Image`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Encoded image data.
    pub data: String,
    /// MIME type of the image.
    pub media_type: ImageFormat,
    /// How `data` is encoded.
    pub r#type: SourceType,
}
198
/// Encoded document payload used by `Content::Document`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Encoded document data.
    pub data: String,
    /// MIME type of the document.
    pub media_type: DocumentFormat,
    /// How `data` is encoded.
    pub r#type: SourceType,
}
205
/// Image MIME types; each variant is serialized as its full MIME string.
///
/// Every variant carries an explicit `rename`, which takes precedence over
/// the container-level `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    /// `image/jpeg`
    #[serde(rename = "image/jpeg")]
    JPEG,
    /// `image/png`
    #[serde(rename = "image/png")]
    PNG,
    /// `image/gif`
    #[serde(rename = "image/gif")]
    GIF,
    /// `image/webp`
    #[serde(rename = "image/webp")]
    WEBP,
}
218
/// Document MIME types; PDF is the only one defined here.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    /// `application/pdf`
    #[serde(rename = "application/pdf")]
    PDF,
}
225
/// Encoding of the `data` field in an image or document source.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    /// Serialized as `"base64"`.
    BASE64,
}
231
232impl From<String> for Content {
233 fn from(text: String) -> Self {
234 Content::Text { text }
235 }
236}
237
238impl From<String> for ToolResultContent {
239 fn from(text: String) -> Self {
240 ToolResultContent::Text { text }
241 }
242}
243
244impl TryFrom<message::ContentFormat> for SourceType {
245 type Error = MessageError;
246
247 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
248 match format {
249 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
250 message::ContentFormat::String => Err(MessageError::ConversionError(
251 "Image urls are not supported in Anthropic".to_owned(),
252 )),
253 }
254 }
255}
256
257impl From<SourceType> for message::ContentFormat {
258 fn from(source_type: SourceType) -> Self {
259 match source_type {
260 SourceType::BASE64 => message::ContentFormat::Base64,
261 }
262 }
263}
264
265impl TryFrom<message::ImageMediaType> for ImageFormat {
266 type Error = MessageError;
267
268 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
269 Ok(match media_type {
270 message::ImageMediaType::JPEG => ImageFormat::JPEG,
271 message::ImageMediaType::PNG => ImageFormat::PNG,
272 message::ImageMediaType::GIF => ImageFormat::GIF,
273 message::ImageMediaType::WEBP => ImageFormat::WEBP,
274 _ => {
275 return Err(MessageError::ConversionError(
276 format!("Unsupported image media type: {:?}", media_type).to_owned(),
277 ))
278 }
279 })
280 }
281}
282
283impl From<ImageFormat> for message::ImageMediaType {
284 fn from(format: ImageFormat) -> Self {
285 match format {
286 ImageFormat::JPEG => message::ImageMediaType::JPEG,
287 ImageFormat::PNG => message::ImageMediaType::PNG,
288 ImageFormat::GIF => message::ImageMediaType::GIF,
289 ImageFormat::WEBP => message::ImageMediaType::WEBP,
290 }
291 }
292}
293
294impl From<message::AssistantContent> for Content {
295 fn from(text: message::AssistantContent) -> Self {
296 match text {
297 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
298 message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
299 Content::ToolUse {
300 id,
301 name: function.name,
302 input: function.arguments,
303 }
304 }
305 }
306 }
307}
308
309impl TryFrom<message::Message> for Message {
310 type Error = MessageError;
311
312 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
313 Ok(match message {
314 message::Message::User { content } => Message {
315 role: Role::User,
316 content: content.try_map(|content| match content {
317 message::UserContent::Text(message::Text { text }) => {
318 Ok(Content::Text { text })
319 }
320 message::UserContent::ToolResult(message::ToolResult { id, content }) => {
321 Ok(Content::ToolResult {
322 tool_use_id: id,
323 content: content.try_map(|content| match content {
324 message::ToolResultContent::Text(message::Text { text }) => {
325 Ok(ToolResultContent::Text { text })
326 }
327 message::ToolResultContent::Image(image) => {
328 let media_type =
329 image.media_type.ok_or(MessageError::ConversionError(
330 "Image media type is required".to_owned(),
331 ))?;
332 let format =
333 image.format.ok_or(MessageError::ConversionError(
334 "Image format is required".to_owned(),
335 ))?;
336 Ok(ToolResultContent::Image(ImageSource {
337 data: image.data,
338 media_type: media_type.try_into()?,
339 r#type: format.try_into()?,
340 }))
341 }
342 })?,
343 is_error: None,
344 })
345 }
346 message::UserContent::Image(message::Image {
347 data,
348 format,
349 media_type,
350 ..
351 }) => {
352 let source = ImageSource {
353 data,
354 media_type: match media_type {
355 Some(media_type) => media_type.try_into()?,
356 None => {
357 return Err(MessageError::ConversionError(
358 "Image media type is required".to_owned(),
359 ))
360 }
361 },
362 r#type: match format {
363 Some(format) => format.try_into()?,
364 None => SourceType::BASE64,
365 },
366 };
367 Ok(Content::Image { source })
368 }
369 message::UserContent::Document(message::Document { data, format, .. }) => {
370 let source = DocumentSource {
371 data,
372 media_type: DocumentFormat::PDF,
373 r#type: match format {
374 Some(format) => format.try_into()?,
375 None => SourceType::BASE64,
376 },
377 };
378 Ok(Content::Document { source })
379 }
380 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
381 "Audio is not supported in Anthropic".to_owned(),
382 )),
383 })?,
384 },
385
386 message::Message::Assistant { content } => Message {
387 content: content.map(|content| content.into()),
388 role: Role::Assistant,
389 },
390 })
391 }
392}
393
394impl TryFrom<Content> for message::AssistantContent {
395 type Error = MessageError;
396
397 fn try_from(content: Content) -> Result<Self, Self::Error> {
398 Ok(match content {
399 Content::Text { text } => message::AssistantContent::text(text),
400 Content::ToolUse { id, name, input } => {
401 message::AssistantContent::tool_call(id, name, input)
402 }
403 _ => {
404 return Err(MessageError::ConversionError(
405 format!("Unsupported content type for Assistant role: {:?}", content)
406 .to_owned(),
407 ))
408 }
409 })
410 }
411}
412
413impl From<ToolResultContent> for message::ToolResultContent {
414 fn from(content: ToolResultContent) -> Self {
415 match content {
416 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
417 ToolResultContent::Image(ImageSource {
418 data,
419 media_type: format,
420 r#type,
421 }) => message::ToolResultContent::image(
422 data,
423 Some(r#type.into()),
424 Some(format.into()),
425 None,
426 ),
427 }
428 }
429}
430
431impl TryFrom<Message> for message::Message {
432 type Error = MessageError;
433
434 fn try_from(message: Message) -> Result<Self, Self::Error> {
435 Ok(match message.role {
436 Role::User => message::Message::User {
437 content: message.content.try_map(|content| {
438 Ok(match content {
439 Content::Text { text } => message::UserContent::text(text),
440 Content::ToolResult {
441 tool_use_id,
442 content,
443 ..
444 } => message::UserContent::tool_result(
445 tool_use_id,
446 content.map(|content| content.into()),
447 ),
448 Content::Image { source } => message::UserContent::Image(message::Image {
449 data: source.data,
450 format: Some(message::ContentFormat::Base64),
451 media_type: Some(source.media_type.into()),
452 detail: None,
453 }),
454 Content::Document { source } => message::UserContent::document(
455 source.data,
456 Some(message::ContentFormat::Base64),
457 Some(message::DocumentMediaType::PDF),
458 ),
459 _ => {
460 return Err(MessageError::ConversionError(
461 "Unsupported content type for User role".to_owned(),
462 ))
463 }
464 })
465 })?,
466 },
467 Role::Assistant => match message.content.first() {
468 Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
469 content: message.content.try_map(|content| content.try_into())?,
470 },
471
472 _ => {
473 return Err(MessageError::ConversionError(
474 format!("Unsupported message for Assistant role: {:?}", message).to_owned(),
475 ))
476 }
477 },
478 })
479 }
480}
481
/// Completion model bound to a client and a model name.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub(crate) client: Client,
    /// Model name sent as the `model` field of each request.
    pub model: String,
    /// `max_tokens` used when a request does not set its own; `None` when
    /// `calculate_max_tokens` has no value for the model.
    pub default_max_tokens: Option<u64>,
}
488
489impl CompletionModel {
490 pub fn new(client: Client, model: &str) -> Self {
491 Self {
492 client,
493 model: model.to_string(),
494 default_max_tokens: calculate_max_tokens(model),
495 }
496 }
497}
498
/// Returns the default `max_tokens` for a known Claude model family.
///
/// The model name is matched by prefix, so both `-latest` aliases and dated
/// snapshots of a family resolve to the same value. Unknown models return
/// `None`; `completion` then requires the request to set `max_tokens` itself.
///
/// The values are the per-model output-token limits published by Anthropic.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, default `max_tokens`). No prefix here is a prefix of
    // another, so the order of the entries does not affect the result.
    const DEFAULTS: &[(&str, u64)] = &[
        ("claude-3-7-sonnet", 64_000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    DEFAULTS
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, max_tokens)| max_tokens)
}
516
/// Request metadata holding an optional user identifier.
// NOTE(review): not referenced anywhere in this part of the file.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    /// Optional user identifier.
    user_id: Option<String>,
}
521
/// Value of the `tool_choice` request field, serialized as an object tagged
/// by `type` (`auto`, `any`, `tool`).
///
/// `completion` always sends `Auto` when the request carries tools.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// The default; serialized as `{"type": "auto"}`.
    #[default]
    Auto,
    /// Serialized as `{"type": "any"}`.
    Any,
    /// Serialized as `{"type": "tool", "name": ...}`.
    Tool {
        /// Name of the tool.
        name: String,
    },
}
532
533impl completion::CompletionModel for CompletionModel {
534 type Response = CompletionResponse;
535
536 #[cfg_attr(feature = "worker", worker::send)]
537 async fn completion(
538 &self,
539 completion_request: completion::CompletionRequest,
540 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
541 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
547 tokens
548 } else if let Some(tokens) = self.default_max_tokens {
549 tokens
550 } else {
551 return Err(CompletionError::RequestError(
552 "`max_tokens` must be set for Anthropic".into(),
553 ));
554 };
555
556 let mut full_history = vec![];
557 if let Some(docs) = completion_request.normalized_documents() {
558 full_history.push(docs);
559 }
560 full_history.extend(completion_request.chat_history);
561
562 let full_history = full_history
563 .into_iter()
564 .map(Message::try_from)
565 .collect::<Result<Vec<Message>, _>>()?;
566
567 let mut request = json!({
568 "model": self.model,
569 "messages": full_history,
570 "max_tokens": max_tokens,
571 "system": completion_request.preamble.unwrap_or("".to_string()),
572 });
573
574 if let Some(temperature) = completion_request.temperature {
575 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
576 }
577
578 if !completion_request.tools.is_empty() {
579 json_utils::merge_inplace(
580 &mut request,
581 json!({
582 "tools": completion_request
583 .tools
584 .into_iter()
585 .map(|tool| ToolDefinition {
586 name: tool.name,
587 description: Some(tool.description),
588 input_schema: tool.parameters,
589 })
590 .collect::<Vec<_>>(),
591 "tool_choice": ToolChoice::Auto,
592 }),
593 );
594 }
595
596 if let Some(ref params) = completion_request.additional_params {
597 json_utils::merge_inplace(&mut request, params.clone())
598 }
599
600 tracing::debug!("Anthropic completion request: {request}");
601
602 let response = self
603 .client
604 .post("/v1/messages")
605 .json(&request)
606 .send()
607 .await?;
608
609 if response.status().is_success() {
610 match response.json::<ApiResponse<CompletionResponse>>().await? {
611 ApiResponse::Message(completion) => {
612 tracing::info!(target: "rig",
613 "Anthropic completion token usage: {}",
614 completion.usage
615 );
616 completion.try_into()
617 }
618 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
619 }
620 } else {
621 Err(CompletionError::ProviderError(response.text().await?))
622 }
623 }
624}
625
/// Error payload carried by `ApiResponse::Error`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
630
/// Body of a response with a success status, discriminated by its `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// `"type": "message"` — the payload itself.
    Message(T),
    /// `"type": "error"` — an error payload.
    Error(ApiErrorResponse),
}
637
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes a wire-format message, panicking with the failing JSON
    /// path when the input does not match.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    /// Messages deserialize from both plain-string and block-list content.
    #[test]
    fn test_deserialize_message() {
        let assistant_message = parse_message(
            r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#,
        );

        let assistant_message2 = parse_message(
            r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#,
        );

        let user_message = parse_message(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#,
        );

        // Plain-string content becomes a single text block.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        // Text followed by a tool use.
        let Message { role, content } = assistant_message2;
        assert_eq!(role, Role::Assistant);
        assert_eq!(content.len(), 2);

        let mut iter = content.into_iter();

        let Content::Text { text } = iter.next().unwrap() else {
            panic!("Expected text content")
        };
        assert_eq!(text, "\n\nHello there, how may I assist you today?");

        let Content::ToolUse { id, name, input } = iter.next().unwrap() else {
            panic!("Expected tool use content")
        };
        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
        assert_eq!(name, "get_weather");
        assert_eq!(input, json!({"location": "San Francisco, CA"}));

        assert_eq!(iter.next(), None);

        // Image, text and tool result in a user message.
        let Message { role, content } = user_message;
        assert_eq!(role, Role::User);
        assert_eq!(content.len(), 3);

        let mut iter = content.into_iter();

        let Content::Image { source } = iter.next().unwrap() else {
            panic!("Expected image content")
        };
        assert_eq!(
            source,
            ImageSource {
                data: "/9j/4AAQSkZJRg...".to_owned(),
                media_type: ImageFormat::JPEG,
                r#type: SourceType::BASE64,
            }
        );

        let Content::Text { text } = iter.next().unwrap() else {
            panic!("Expected text content")
        };
        assert_eq!(text, "What is in this image?");

        let Content::ToolResult {
            tool_use_id,
            content,
            is_error,
        } = iter.next().unwrap()
        else {
            panic!("Expected tool result content")
        };
        assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
        assert_eq!(
            content.first(),
            ToolResultContent::Text {
                text: "15 degrees".to_owned()
            }
        );
        assert_eq!(is_error, None);

        assert_eq!(iter.next(), None);
    }

    /// Wire-format messages survive a round trip through the generic message type.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        // User message: image, text, document.
        {
            let message::Message::User { content } = converted_user_message.clone() else {
                panic!("Expected user message")
            };
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            let message::UserContent::Image(message::Image {
                data,
                format,
                media_type,
                ..
            }) = iter.next().unwrap()
            else {
                panic!("Expected image content")
            };
            assert_eq!(data, "/9j/4AAQSkZJRg...");
            assert_eq!(format.unwrap(), message::ContentFormat::Base64);
            assert_eq!(media_type, Some(message::ImageMediaType::JPEG));

            let message::UserContent::Text(message::Text { text }) = iter.next().unwrap() else {
                panic!("Expected text content")
            };
            assert_eq!(text, "What is in this image?");

            let message::UserContent::Document(message::Document {
                data,
                format,
                media_type,
            }) = iter.next().unwrap()
            else {
                panic!("Expected document content")
            };
            assert_eq!(data, "base64_encoded_pdf_data");
            assert_eq!(format.unwrap(), message::ContentFormat::Base64);
            assert_eq!(media_type, Some(message::DocumentMediaType::PDF));

            assert_eq!(iter.next(), None);
        }

        // Tool-result message.
        {
            let message::Message::User { content } = converted_tool_message.clone() else {
                panic!("Expected tool result content")
            };
            let message::UserContent::ToolResult(message::ToolResult { id, content, .. }) =
                content.first()
            else {
                panic!("Expected tool result content")
            };
            assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");

            let message::ToolResultContent::Text(message::Text { text }) = content.first() else {
                panic!("Expected text content")
            };
            assert_eq!(text, "15 degrees");
        }

        // Assistant message carrying a single tool call.
        {
            let message::Message::Assistant { content } = converted_assistant_message.clone()
            else {
                panic!("Expected assistant message")
            };
            assert_eq!(content.len(), 1);

            let message::AssistantContent::ToolCall(message::ToolCall { id, function }) =
                content.first()
            else {
                panic!("Expected tool call content")
            };
            assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
            assert_eq!(function.name, "get_weather");
            assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
        }

        // Converting back must reproduce the original wire-format messages.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
}