rig/providers/anthropic/
completion.rs

//! Anthropic completion API implementation
2
3use std::{convert::Infallible, str::FromStr};
4
5use crate::{
6    completion::{self, CompletionError},
7    json_utils,
8    message::{self, MessageError},
9    one_or_many::string_or_one_or_many,
10    OneOrMany,
11};
12
13use serde::{Deserialize, Serialize};
14use serde_json::json;
15
16use super::client::Client;
17
// ================================================================
// Anthropic Completion API
// ================================================================
/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version string `2023-01-01`
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest version constant declared in this module
/// (currently [`ANTHROPIC_VERSION_2023_06_01`]).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
40#[derive(Debug, Deserialize)]
41pub struct CompletionResponse {
42    pub content: Vec<Content>,
43    pub id: String,
44    pub model: String,
45    pub role: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51#[derive(Debug, Deserialize, Serialize)]
52pub struct Usage {
53    pub input_tokens: u64,
54    pub cache_read_input_tokens: Option<u64>,
55    pub cache_creation_input_tokens: Option<u64>,
56    pub output_tokens: u64,
57}
58
59impl std::fmt::Display for Usage {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(
62            f,
63            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
64            self.input_tokens,
65            match self.cache_read_input_tokens {
66                Some(token) => token.to_string(),
67                None => "n/a".to_string(),
68            },
69            match self.cache_creation_input_tokens {
70                Some(token) => token.to_string(),
71                None => "n/a".to_string(),
72            },
73            self.output_tokens
74        )
75    }
76}
77
78#[derive(Debug, Deserialize, Serialize)]
79pub struct ToolDefinition {
80    pub name: String,
81    pub description: Option<String>,
82    pub input_schema: serde_json::Value,
83}
84
85#[derive(Debug, Deserialize, Serialize)]
86#[serde(tag = "type", rename_all = "snake_case")]
87pub enum CacheControl {
88    Ephemeral,
89}
90
91impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
92    type Error = CompletionError;
93
94    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
95        let content = response
96            .content
97            .iter()
98            .map(|content| {
99                Ok(match content {
100                    Content::Text { text } => completion::AssistantContent::text(text),
101                    Content::ToolUse { id, name, input } => {
102                        completion::AssistantContent::tool_call(id, name, input.clone())
103                    }
104                    _ => {
105                        return Err(CompletionError::ResponseError(
106                            "Response did not contain a message or tool call".into(),
107                        ))
108                    }
109                })
110            })
111            .collect::<Result<Vec<_>, _>>()?;
112
113        let choice = OneOrMany::many(content).map_err(|_| {
114            CompletionError::ResponseError(
115                "Response contained no message or tool call (empty)".to_owned(),
116            )
117        })?;
118
119        Ok(completion::CompletionResponse {
120            choice,
121            raw_response: response,
122        })
123    }
124}
125
126#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
127pub struct Message {
128    pub role: Role,
129    #[serde(deserialize_with = "string_or_one_or_many")]
130    pub content: OneOrMany<Content>,
131}
132
133#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
134#[serde(rename_all = "lowercase")]
135pub enum Role {
136    User,
137    Assistant,
138}
139
140#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
141#[serde(tag = "type", rename_all = "snake_case")]
142pub enum Content {
143    Text {
144        text: String,
145    },
146    Image {
147        source: ImageSource,
148    },
149    ToolUse {
150        id: String,
151        name: String,
152        input: serde_json::Value,
153    },
154    ToolResult {
155        tool_use_id: String,
156        #[serde(deserialize_with = "string_or_one_or_many")]
157        content: OneOrMany<ToolResultContent>,
158        #[serde(skip_serializing_if = "Option::is_none")]
159        is_error: Option<bool>,
160    },
161    Document {
162        source: DocumentSource,
163    },
164}
165
166impl FromStr for Content {
167    type Err = Infallible;
168
169    fn from_str(s: &str) -> Result<Self, Self::Err> {
170        Ok(Content::Text { text: s.to_owned() })
171    }
172}
173
174#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
175#[serde(tag = "type", rename_all = "snake_case")]
176pub enum ToolResultContent {
177    Text { text: String },
178    Image(ImageSource),
179}
180
181impl FromStr for ToolResultContent {
182    type Err = Infallible;
183
184    fn from_str(s: &str) -> Result<Self, Self::Err> {
185        Ok(ToolResultContent::Text { text: s.to_owned() })
186    }
187}
188
189#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
190pub struct ImageSource {
191    pub data: String,
192    pub media_type: ImageFormat,
193    pub r#type: SourceType,
194}
195
196#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
197pub struct DocumentSource {
198    pub data: String,
199    pub media_type: DocumentFormat,
200    pub r#type: SourceType,
201}
202
203#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
204#[serde(rename_all = "lowercase")]
205pub enum ImageFormat {
206    #[serde(rename = "image/jpeg")]
207    JPEG,
208    #[serde(rename = "image/png")]
209    PNG,
210    #[serde(rename = "image/gif")]
211    GIF,
212    #[serde(rename = "image/webp")]
213    WEBP,
214}
215
216#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
217#[serde(rename_all = "lowercase")]
218pub enum DocumentFormat {
219    #[serde(rename = "application/pdf")]
220    PDF,
221}
222
223#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
224#[serde(rename_all = "lowercase")]
225pub enum SourceType {
226    BASE64,
227}
228
229impl From<String> for Content {
230    fn from(text: String) -> Self {
231        Content::Text { text }
232    }
233}
234
235impl From<String> for ToolResultContent {
236    fn from(text: String) -> Self {
237        ToolResultContent::Text { text }
238    }
239}
240
241impl TryFrom<message::ContentFormat> for SourceType {
242    type Error = MessageError;
243
244    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
245        match format {
246            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
247            message::ContentFormat::String => Err(MessageError::ConversionError(
248                "Image urls are not supported in Anthropic".to_owned(),
249            )),
250        }
251    }
252}
253
254impl From<SourceType> for message::ContentFormat {
255    fn from(source_type: SourceType) -> Self {
256        match source_type {
257            SourceType::BASE64 => message::ContentFormat::Base64,
258        }
259    }
260}
261
262impl TryFrom<message::ImageMediaType> for ImageFormat {
263    type Error = MessageError;
264
265    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
266        Ok(match media_type {
267            message::ImageMediaType::JPEG => ImageFormat::JPEG,
268            message::ImageMediaType::PNG => ImageFormat::PNG,
269            message::ImageMediaType::GIF => ImageFormat::GIF,
270            message::ImageMediaType::WEBP => ImageFormat::WEBP,
271            _ => {
272                return Err(MessageError::ConversionError(
273                    format!("Unsupported image media type: {:?}", media_type).to_owned(),
274                ))
275            }
276        })
277    }
278}
279
280impl From<ImageFormat> for message::ImageMediaType {
281    fn from(format: ImageFormat) -> Self {
282        match format {
283            ImageFormat::JPEG => message::ImageMediaType::JPEG,
284            ImageFormat::PNG => message::ImageMediaType::PNG,
285            ImageFormat::GIF => message::ImageMediaType::GIF,
286            ImageFormat::WEBP => message::ImageMediaType::WEBP,
287        }
288    }
289}
290
291impl From<message::AssistantContent> for Content {
292    fn from(text: message::AssistantContent) -> Self {
293        match text {
294            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
295            message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
296                Content::ToolUse {
297                    id,
298                    name: function.name,
299                    input: function.arguments,
300                }
301            }
302        }
303    }
304}
305
306impl TryFrom<message::Message> for Message {
307    type Error = MessageError;
308
309    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
310        Ok(match message {
311            message::Message::User { content } => Message {
312                role: Role::User,
313                content: content.try_map(|content| match content {
314                    message::UserContent::Text(message::Text { text }) => {
315                        Ok(Content::Text { text })
316                    }
317                    message::UserContent::ToolResult(message::ToolResult { id, content }) => {
318                        Ok(Content::ToolResult {
319                            tool_use_id: id,
320                            content: content.try_map(|content| match content {
321                                message::ToolResultContent::Text(message::Text { text }) => {
322                                    Ok(ToolResultContent::Text { text })
323                                }
324                                message::ToolResultContent::Image(image) => {
325                                    let media_type =
326                                        image.media_type.ok_or(MessageError::ConversionError(
327                                            "Image media type is required".to_owned(),
328                                        ))?;
329                                    let format =
330                                        image.format.ok_or(MessageError::ConversionError(
331                                            "Image format is required".to_owned(),
332                                        ))?;
333                                    Ok(ToolResultContent::Image(ImageSource {
334                                        data: image.data,
335                                        media_type: media_type.try_into()?,
336                                        r#type: format.try_into()?,
337                                    }))
338                                }
339                            })?,
340                            is_error: None,
341                        })
342                    }
343                    message::UserContent::Image(message::Image {
344                        data,
345                        format,
346                        media_type,
347                        ..
348                    }) => {
349                        let source = ImageSource {
350                            data,
351                            media_type: match media_type {
352                                Some(media_type) => media_type.try_into()?,
353                                None => {
354                                    return Err(MessageError::ConversionError(
355                                        "Image media type is required".to_owned(),
356                                    ))
357                                }
358                            },
359                            r#type: match format {
360                                Some(format) => format.try_into()?,
361                                None => SourceType::BASE64,
362                            },
363                        };
364                        Ok(Content::Image { source })
365                    }
366                    message::UserContent::Document(message::Document { data, format, .. }) => {
367                        let source = DocumentSource {
368                            data,
369                            media_type: DocumentFormat::PDF,
370                            r#type: match format {
371                                Some(format) => format.try_into()?,
372                                None => SourceType::BASE64,
373                            },
374                        };
375                        Ok(Content::Document { source })
376                    }
377                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
378                        "Audio is not supported in Anthropic".to_owned(),
379                    )),
380                })?,
381            },
382
383            message::Message::Assistant { content } => Message {
384                content: content.map(|content| content.into()),
385                role: Role::Assistant,
386            },
387        })
388    }
389}
390
391impl TryFrom<Content> for message::AssistantContent {
392    type Error = MessageError;
393
394    fn try_from(content: Content) -> Result<Self, Self::Error> {
395        Ok(match content {
396            Content::Text { text } => message::AssistantContent::text(text),
397            Content::ToolUse { id, name, input } => {
398                message::AssistantContent::tool_call(id, name, input)
399            }
400            _ => {
401                return Err(MessageError::ConversionError(
402                    format!("Unsupported content type for Assistant role: {:?}", content)
403                        .to_owned(),
404                ))
405            }
406        })
407    }
408}
409
410impl From<ToolResultContent> for message::ToolResultContent {
411    fn from(content: ToolResultContent) -> Self {
412        match content {
413            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
414            ToolResultContent::Image(ImageSource {
415                data,
416                media_type: format,
417                r#type,
418            }) => message::ToolResultContent::image(
419                data,
420                Some(r#type.into()),
421                Some(format.into()),
422                None,
423            ),
424        }
425    }
426}
427
428impl TryFrom<Message> for message::Message {
429    type Error = MessageError;
430
431    fn try_from(message: Message) -> Result<Self, Self::Error> {
432        Ok(match message.role {
433            Role::User => message::Message::User {
434                content: message.content.try_map(|content| {
435                    Ok(match content {
436                        Content::Text { text } => message::UserContent::text(text),
437                        Content::ToolResult {
438                            tool_use_id,
439                            content,
440                            ..
441                        } => message::UserContent::tool_result(
442                            tool_use_id,
443                            content.map(|content| content.into()),
444                        ),
445                        Content::Image { source } => message::UserContent::Image(message::Image {
446                            data: source.data,
447                            format: Some(message::ContentFormat::Base64),
448                            media_type: Some(source.media_type.into()),
449                            detail: None,
450                        }),
451                        Content::Document { source } => message::UserContent::document(
452                            source.data,
453                            Some(message::ContentFormat::Base64),
454                            Some(message::DocumentMediaType::PDF),
455                        ),
456                        _ => {
457                            return Err(MessageError::ConversionError(
458                                "Unsupported content type for User role".to_owned(),
459                            ))
460                        }
461                    })
462                })?,
463            },
464            Role::Assistant => match message.content.first() {
465                Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
466                    content: message.content.try_map(|content| content.try_into())?,
467                },
468
469                _ => {
470                    return Err(MessageError::ConversionError(
471                        format!("Unsupported message for Assistant role: {:?}", message).to_owned(),
472                    ))
473                }
474            },
475        })
476    }
477}
478
479#[derive(Clone)]
480pub struct CompletionModel {
481    pub(crate) client: Client,
482    pub model: String,
483    pub default_max_tokens: Option<u64>,
484}
485
486impl CompletionModel {
487    pub fn new(client: Client, model: &str) -> Self {
488        Self {
489            client,
490            model: model.to_string(),
491            default_max_tokens: calculate_max_tokens(model),
492        }
493    }
494}
495
/// Returns the default `max_tokens` for a model name, or `None` when the name matches no
/// known model family.
///
/// Anthropic requires `max_tokens` on every request and the accepted ceiling depends on the
/// model, so requests fail when it is missing or too high. The limits below reflect the
/// models available at the time of writing.
///
/// Dev Note: This is really bad design, I'm not sure why they did it like this..
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, default max_tokens). No prefix here is a prefix of another,
    // so the lookup order does not affect the result.
    const LIMITS: [(&str, u64); 5] = [
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    LIMITS
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, limit)| limit)
}
513
514#[derive(Debug, Deserialize, Serialize)]
515struct Metadata {
516    user_id: Option<String>,
517}
518
519#[derive(Default, Debug, Serialize, Deserialize)]
520#[serde(tag = "type", rename_all = "snake_case")]
521pub enum ToolChoice {
522    #[default]
523    Auto,
524    Any,
525    Tool {
526        name: String,
527    },
528}
529
530impl completion::CompletionModel for CompletionModel {
531    type Response = CompletionResponse;
532
533    #[cfg_attr(feature = "worker", worker::send)]
534    async fn completion(
535        &self,
536        completion_request: completion::CompletionRequest,
537    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
538        // Note: Ideally we'd introduce provider-specific Request models to handle the
539        // specific requirements of each provider. For now, we just manually check while
540        // building the request as a raw JSON document.
541
542        // Check if max_tokens is set, required for Anthropic
543        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
544            tokens
545        } else if let Some(tokens) = self.default_max_tokens {
546            tokens
547        } else {
548            return Err(CompletionError::RequestError(
549                "`max_tokens` must be set for Anthropic".into(),
550            ));
551        };
552
553        let prompt_message: Message = completion_request
554            .prompt_with_context()
555            .try_into()
556            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
557
558        let mut messages = completion_request
559            .chat_history
560            .into_iter()
561            .map(|message| {
562                message
563                    .try_into()
564                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
565            })
566            .collect::<Result<Vec<Message>, _>>()?;
567
568        messages.push(prompt_message);
569
570        let mut request = json!({
571            "model": self.model,
572            "messages": messages,
573            "max_tokens": max_tokens,
574            "system": completion_request.preamble.unwrap_or("".to_string()),
575        });
576
577        if let Some(temperature) = completion_request.temperature {
578            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
579        }
580
581        if !completion_request.tools.is_empty() {
582            json_utils::merge_inplace(
583                &mut request,
584                json!({
585                    "tools": completion_request
586                        .tools
587                        .into_iter()
588                        .map(|tool| ToolDefinition {
589                            name: tool.name,
590                            description: Some(tool.description),
591                            input_schema: tool.parameters,
592                        })
593                        .collect::<Vec<_>>(),
594                    "tool_choice": ToolChoice::Auto,
595                }),
596            );
597        }
598
599        if let Some(ref params) = completion_request.additional_params {
600            json_utils::merge_inplace(&mut request, params.clone())
601        }
602
603        tracing::debug!("Anthropic completion request: {request}");
604
605        let response = self
606            .client
607            .post("/v1/messages")
608            .json(&request)
609            .send()
610            .await?;
611
612        if response.status().is_success() {
613            match response.json::<ApiResponse<CompletionResponse>>().await? {
614                ApiResponse::Message(completion) => {
615                    tracing::info!(target: "rig",
616                        "Anthropic completion token usage: {}",
617                        completion.usage
618                    );
619                    completion.try_into()
620                }
621                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
622            }
623        } else {
624            Err(CompletionError::ProviderError(response.text().await?))
625        }
626    }
627}
628
629#[derive(Debug, Deserialize)]
630struct ApiErrorResponse {
631    message: String,
632}
633
634#[derive(Debug, Deserialize)]
635#[serde(tag = "type", rename_all = "snake_case")]
636enum ApiResponse<T> {
637    Message(T),
638    Error(ApiErrorResponse),
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Parses `json` into a [`Message`], panicking with the failing JSON path on error.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    /// Messages deserialize from both the bare-string and the block-array content forms.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let user_message = parse_message(user_message_json);

        // Bare string content becomes a single text block.
        {
            let Message { role, content } = assistant_message;
            assert_eq!(role, Role::Assistant);
            assert_eq!(
                content.first(),
                Content::Text {
                    text: "\n\nHello there, how may I assist you today?".to_owned()
                }
            );
        }

        // Array content: text followed by a tool use.
        {
            let Message { role, content } = assistant_message2;
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut parts = content.into_iter();

            match parts.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match parts.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(parts.next(), None);
        }

        // User message: image, text, then a tool result with bare-string content.
        {
            let Message { role, content } = user_message;
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut parts = content.into_iter();

            match parts.next().unwrap() {
                Content::Image { source } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match parts.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match parts.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(parts.next(), None);
        }
    }

    /// Anthropic messages survive a round trip through the generic message type.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        // Forward conversion: Anthropic -> generic.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        {
            let message::Message::User { content } = converted_user_message.clone() else {
                panic!("Expected user message")
            };
            assert_eq!(content.len(), 3);

            let mut parts = content.into_iter();

            match parts.next().unwrap() {
                message::UserContent::Image(message::Image {
                    data,
                    format,
                    media_type,
                    ..
                }) => {
                    assert_eq!(data, "/9j/4AAQSkZJRg...");
                    assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                    assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                }
                _ => panic!("Expected image content"),
            }

            match parts.next().unwrap() {
                message::UserContent::Text(message::Text { text }) => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match parts.next().unwrap() {
                message::UserContent::Document(message::Document {
                    data,
                    format,
                    media_type,
                }) => {
                    assert_eq!(data, "base64_encoded_pdf_data");
                    assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                    assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                }
                _ => panic!("Expected document content"),
            }

            assert_eq!(parts.next(), None);
        }

        {
            let message::Message::User { content } = converted_tool_message.clone() else {
                panic!("Expected tool result content")
            };
            let message::UserContent::ToolResult(message::ToolResult { id, content, .. }) =
                content.first()
            else {
                panic!("Expected tool result content")
            };
            assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");

            let message::ToolResultContent::Text(message::Text { text }) = content.first() else {
                panic!("Expected text content")
            };
            assert_eq!(text, "15 degrees");
        }

        {
            let message::Message::Assistant { content } = converted_assistant_message.clone()
            else {
                panic!("Expected assistant message")
            };
            assert_eq!(content.len(), 1);

            let message::AssistantContent::ToolCall(message::ToolCall { id, function }) =
                content.first()
            else {
                panic!("Expected tool call content")
            };
            assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
            assert_eq!(function.name, "get_weather");
            assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
        }

        // Reverse conversion: generic -> Anthropic must reproduce the originals.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
}
953}