rig/providers/anthropic/
completion.rs

//! Anthropic completion API implementation.
2
3use std::{convert::Infallible, str::FromStr};
4
5use crate::{
6    completion::{self, CompletionError},
7    json_utils,
8    message::{self, MessageError},
9    one_or_many::string_or_one_or_many,
10    OneOrMany,
11};
12
13use serde::{Deserialize, Serialize};
14use serde_json::json;
15
16use super::client::Client;
17
// ================================================================
// Anthropic Completion API
// ================================================================
/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version `2023-01-01`.
///
/// Presumably sent by the client as the `anthropic-version` header — confirm in `client.rs`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version defined in this module (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
42
43#[derive(Debug, Deserialize)]
44pub struct CompletionResponse {
45    pub content: Vec<Content>,
46    pub id: String,
47    pub model: String,
48    pub role: String,
49    pub stop_reason: Option<String>,
50    pub stop_sequence: Option<String>,
51    pub usage: Usage,
52}
53
54#[derive(Debug, Deserialize, Serialize)]
55pub struct Usage {
56    pub input_tokens: u64,
57    pub cache_read_input_tokens: Option<u64>,
58    pub cache_creation_input_tokens: Option<u64>,
59    pub output_tokens: u64,
60}
61
62impl std::fmt::Display for Usage {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        write!(
65            f,
66            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
67            self.input_tokens,
68            match self.cache_read_input_tokens {
69                Some(token) => token.to_string(),
70                None => "n/a".to_string(),
71            },
72            match self.cache_creation_input_tokens {
73                Some(token) => token.to_string(),
74                None => "n/a".to_string(),
75            },
76            self.output_tokens
77        )
78    }
79}
80
81#[derive(Debug, Deserialize, Serialize)]
82pub struct ToolDefinition {
83    pub name: String,
84    pub description: Option<String>,
85    pub input_schema: serde_json::Value,
86}
87
88#[derive(Debug, Deserialize, Serialize)]
89#[serde(tag = "type", rename_all = "snake_case")]
90pub enum CacheControl {
91    Ephemeral,
92}
93
94impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
95    type Error = CompletionError;
96
97    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
98        let content = response
99            .content
100            .iter()
101            .map(|content| {
102                Ok(match content {
103                    Content::Text { text } => completion::AssistantContent::text(text),
104                    Content::ToolUse { id, name, input } => {
105                        completion::AssistantContent::tool_call(id, name, input.clone())
106                    }
107                    _ => {
108                        return Err(CompletionError::ResponseError(
109                            "Response did not contain a message or tool call".into(),
110                        ))
111                    }
112                })
113            })
114            .collect::<Result<Vec<_>, _>>()?;
115
116        let choice = OneOrMany::many(content).map_err(|_| {
117            CompletionError::ResponseError(
118                "Response contained no message or tool call (empty)".to_owned(),
119            )
120        })?;
121
122        Ok(completion::CompletionResponse {
123            choice,
124            raw_response: response,
125        })
126    }
127}
128
129#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
130pub struct Message {
131    pub role: Role,
132    #[serde(deserialize_with = "string_or_one_or_many")]
133    pub content: OneOrMany<Content>,
134}
135
136#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
137#[serde(rename_all = "lowercase")]
138pub enum Role {
139    User,
140    Assistant,
141}
142
143#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
144#[serde(tag = "type", rename_all = "snake_case")]
145pub enum Content {
146    Text {
147        text: String,
148    },
149    Image {
150        source: ImageSource,
151    },
152    ToolUse {
153        id: String,
154        name: String,
155        input: serde_json::Value,
156    },
157    ToolResult {
158        tool_use_id: String,
159        #[serde(deserialize_with = "string_or_one_or_many")]
160        content: OneOrMany<ToolResultContent>,
161        #[serde(skip_serializing_if = "Option::is_none")]
162        is_error: Option<bool>,
163    },
164    Document {
165        source: DocumentSource,
166    },
167}
168
169impl FromStr for Content {
170    type Err = Infallible;
171
172    fn from_str(s: &str) -> Result<Self, Self::Err> {
173        Ok(Content::Text { text: s.to_owned() })
174    }
175}
176
177#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
178#[serde(tag = "type", rename_all = "snake_case")]
179pub enum ToolResultContent {
180    Text { text: String },
181    Image(ImageSource),
182}
183
184impl FromStr for ToolResultContent {
185    type Err = Infallible;
186
187    fn from_str(s: &str) -> Result<Self, Self::Err> {
188        Ok(ToolResultContent::Text { text: s.to_owned() })
189    }
190}
191
192#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
193pub struct ImageSource {
194    pub data: String,
195    pub media_type: ImageFormat,
196    pub r#type: SourceType,
197}
198
199#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
200pub struct DocumentSource {
201    pub data: String,
202    pub media_type: DocumentFormat,
203    pub r#type: SourceType,
204}
205
206#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
207#[serde(rename_all = "lowercase")]
208pub enum ImageFormat {
209    #[serde(rename = "image/jpeg")]
210    JPEG,
211    #[serde(rename = "image/png")]
212    PNG,
213    #[serde(rename = "image/gif")]
214    GIF,
215    #[serde(rename = "image/webp")]
216    WEBP,
217}
218
219#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
220#[serde(rename_all = "lowercase")]
221pub enum DocumentFormat {
222    #[serde(rename = "application/pdf")]
223    PDF,
224}
225
226#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
227#[serde(rename_all = "lowercase")]
228pub enum SourceType {
229    BASE64,
230}
231
232impl From<String> for Content {
233    fn from(text: String) -> Self {
234        Content::Text { text }
235    }
236}
237
238impl From<String> for ToolResultContent {
239    fn from(text: String) -> Self {
240        ToolResultContent::Text { text }
241    }
242}
243
244impl TryFrom<message::ContentFormat> for SourceType {
245    type Error = MessageError;
246
247    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
248        match format {
249            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
250            message::ContentFormat::String => Err(MessageError::ConversionError(
251                "Image urls are not supported in Anthropic".to_owned(),
252            )),
253        }
254    }
255}
256
257impl From<SourceType> for message::ContentFormat {
258    fn from(source_type: SourceType) -> Self {
259        match source_type {
260            SourceType::BASE64 => message::ContentFormat::Base64,
261        }
262    }
263}
264
265impl TryFrom<message::ImageMediaType> for ImageFormat {
266    type Error = MessageError;
267
268    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
269        Ok(match media_type {
270            message::ImageMediaType::JPEG => ImageFormat::JPEG,
271            message::ImageMediaType::PNG => ImageFormat::PNG,
272            message::ImageMediaType::GIF => ImageFormat::GIF,
273            message::ImageMediaType::WEBP => ImageFormat::WEBP,
274            _ => {
275                return Err(MessageError::ConversionError(
276                    format!("Unsupported image media type: {:?}", media_type).to_owned(),
277                ))
278            }
279        })
280    }
281}
282
283impl From<ImageFormat> for message::ImageMediaType {
284    fn from(format: ImageFormat) -> Self {
285        match format {
286            ImageFormat::JPEG => message::ImageMediaType::JPEG,
287            ImageFormat::PNG => message::ImageMediaType::PNG,
288            ImageFormat::GIF => message::ImageMediaType::GIF,
289            ImageFormat::WEBP => message::ImageMediaType::WEBP,
290        }
291    }
292}
293
294impl From<message::AssistantContent> for Content {
295    fn from(text: message::AssistantContent) -> Self {
296        match text {
297            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
298            message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
299                Content::ToolUse {
300                    id,
301                    name: function.name,
302                    input: function.arguments,
303                }
304            }
305        }
306    }
307}
308
309impl TryFrom<message::Message> for Message {
310    type Error = MessageError;
311
312    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
313        Ok(match message {
314            message::Message::User { content } => Message {
315                role: Role::User,
316                content: content.try_map(|content| match content {
317                    message::UserContent::Text(message::Text { text }) => {
318                        Ok(Content::Text { text })
319                    }
320                    message::UserContent::ToolResult(message::ToolResult { id, content }) => {
321                        Ok(Content::ToolResult {
322                            tool_use_id: id,
323                            content: content.try_map(|content| match content {
324                                message::ToolResultContent::Text(message::Text { text }) => {
325                                    Ok(ToolResultContent::Text { text })
326                                }
327                                message::ToolResultContent::Image(image) => {
328                                    let media_type =
329                                        image.media_type.ok_or(MessageError::ConversionError(
330                                            "Image media type is required".to_owned(),
331                                        ))?;
332                                    let format =
333                                        image.format.ok_or(MessageError::ConversionError(
334                                            "Image format is required".to_owned(),
335                                        ))?;
336                                    Ok(ToolResultContent::Image(ImageSource {
337                                        data: image.data,
338                                        media_type: media_type.try_into()?,
339                                        r#type: format.try_into()?,
340                                    }))
341                                }
342                            })?,
343                            is_error: None,
344                        })
345                    }
346                    message::UserContent::Image(message::Image {
347                        data,
348                        format,
349                        media_type,
350                        ..
351                    }) => {
352                        let source = ImageSource {
353                            data,
354                            media_type: match media_type {
355                                Some(media_type) => media_type.try_into()?,
356                                None => {
357                                    return Err(MessageError::ConversionError(
358                                        "Image media type is required".to_owned(),
359                                    ))
360                                }
361                            },
362                            r#type: match format {
363                                Some(format) => format.try_into()?,
364                                None => SourceType::BASE64,
365                            },
366                        };
367                        Ok(Content::Image { source })
368                    }
369                    message::UserContent::Document(message::Document { data, format, .. }) => {
370                        let source = DocumentSource {
371                            data,
372                            media_type: DocumentFormat::PDF,
373                            r#type: match format {
374                                Some(format) => format.try_into()?,
375                                None => SourceType::BASE64,
376                            },
377                        };
378                        Ok(Content::Document { source })
379                    }
380                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
381                        "Audio is not supported in Anthropic".to_owned(),
382                    )),
383                })?,
384            },
385
386            message::Message::Assistant { content } => Message {
387                content: content.map(|content| content.into()),
388                role: Role::Assistant,
389            },
390        })
391    }
392}
393
394impl TryFrom<Content> for message::AssistantContent {
395    type Error = MessageError;
396
397    fn try_from(content: Content) -> Result<Self, Self::Error> {
398        Ok(match content {
399            Content::Text { text } => message::AssistantContent::text(text),
400            Content::ToolUse { id, name, input } => {
401                message::AssistantContent::tool_call(id, name, input)
402            }
403            _ => {
404                return Err(MessageError::ConversionError(
405                    format!("Unsupported content type for Assistant role: {:?}", content)
406                        .to_owned(),
407                ))
408            }
409        })
410    }
411}
412
413impl From<ToolResultContent> for message::ToolResultContent {
414    fn from(content: ToolResultContent) -> Self {
415        match content {
416            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
417            ToolResultContent::Image(ImageSource {
418                data,
419                media_type: format,
420                r#type,
421            }) => message::ToolResultContent::image(
422                data,
423                Some(r#type.into()),
424                Some(format.into()),
425                None,
426            ),
427        }
428    }
429}
430
431impl TryFrom<Message> for message::Message {
432    type Error = MessageError;
433
434    fn try_from(message: Message) -> Result<Self, Self::Error> {
435        Ok(match message.role {
436            Role::User => message::Message::User {
437                content: message.content.try_map(|content| {
438                    Ok(match content {
439                        Content::Text { text } => message::UserContent::text(text),
440                        Content::ToolResult {
441                            tool_use_id,
442                            content,
443                            ..
444                        } => message::UserContent::tool_result(
445                            tool_use_id,
446                            content.map(|content| content.into()),
447                        ),
448                        Content::Image { source } => message::UserContent::Image(message::Image {
449                            data: source.data,
450                            format: Some(message::ContentFormat::Base64),
451                            media_type: Some(source.media_type.into()),
452                            detail: None,
453                        }),
454                        Content::Document { source } => message::UserContent::document(
455                            source.data,
456                            Some(message::ContentFormat::Base64),
457                            Some(message::DocumentMediaType::PDF),
458                        ),
459                        _ => {
460                            return Err(MessageError::ConversionError(
461                                "Unsupported content type for User role".to_owned(),
462                            ))
463                        }
464                    })
465                })?,
466            },
467            Role::Assistant => match message.content.first() {
468                Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
469                    content: message.content.try_map(|content| content.try_into())?,
470                },
471
472                _ => {
473                    return Err(MessageError::ConversionError(
474                        format!("Unsupported message for Assistant role: {:?}", message).to_owned(),
475                    ))
476                }
477            },
478        })
479    }
480}
481
482#[derive(Clone)]
483pub struct CompletionModel {
484    pub(crate) client: Client,
485    pub model: String,
486    pub default_max_tokens: Option<u64>,
487}
488
489impl CompletionModel {
490    pub fn new(client: Client, model: &str) -> Self {
491        Self {
492            client,
493            model: model.to_string(),
494            default_max_tokens: calculate_max_tokens(model),
495        }
496    }
497}
498
/// Returns the default `max_tokens` for `model`, or `None` if the model is not recognised.
///
/// Anthropic requires a `max_tokens` parameter on every request, and the largest accepted
/// value depends on the model. A request without it, or with a value that is too high, fails.
///
/// Models are matched by name prefix, so both `-latest` aliases and dated snapshots (e.g.
/// `claude-3-haiku-20240307`) resolve to the same default.
///
/// Claude 3.7 Sonnet shares the 3.5 generation's 8192 default. That is a conservative value;
/// callers needing longer outputs can set `max_tokens` on the request explicitly.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    if model.starts_with("claude-3-7-sonnet")
        || model.starts_with("claude-3-5-sonnet")
        || model.starts_with("claude-3-5-haiku")
    {
        Some(8192)
    } else if model.starts_with("claude-3-opus")
        || model.starts_with("claude-3-sonnet")
        || model.starts_with("claude-3-haiku")
    {
        Some(4096)
    } else {
        None
    }
}
516
517#[derive(Debug, Deserialize, Serialize)]
518struct Metadata {
519    user_id: Option<String>,
520}
521
522#[derive(Default, Debug, Serialize, Deserialize)]
523#[serde(tag = "type", rename_all = "snake_case")]
524pub enum ToolChoice {
525    #[default]
526    Auto,
527    Any,
528    Tool {
529        name: String,
530    },
531}
532
533impl completion::CompletionModel for CompletionModel {
534    type Response = CompletionResponse;
535
536    #[cfg_attr(feature = "worker", worker::send)]
537    async fn completion(
538        &self,
539        completion_request: completion::CompletionRequest,
540    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
541        // Note: Ideally we'd introduce provider-specific Request models to handle the
542        // specific requirements of each provider. For now, we just manually check while
543        // building the request as a raw JSON document.
544
545        // Check if max_tokens is set, required for Anthropic
546        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
547            tokens
548        } else if let Some(tokens) = self.default_max_tokens {
549            tokens
550        } else {
551            return Err(CompletionError::RequestError(
552                "`max_tokens` must be set for Anthropic".into(),
553            ));
554        };
555
556        let mut full_history = vec![];
557        if let Some(docs) = completion_request.normalized_documents() {
558            full_history.push(docs);
559        }
560        full_history.extend(completion_request.chat_history);
561
562        let full_history = full_history
563            .into_iter()
564            .map(Message::try_from)
565            .collect::<Result<Vec<Message>, _>>()?;
566
567        let mut request = json!({
568            "model": self.model,
569            "messages": full_history,
570            "max_tokens": max_tokens,
571            "system": completion_request.preamble.unwrap_or("".to_string()),
572        });
573
574        if let Some(temperature) = completion_request.temperature {
575            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
576        }
577
578        if !completion_request.tools.is_empty() {
579            json_utils::merge_inplace(
580                &mut request,
581                json!({
582                    "tools": completion_request
583                        .tools
584                        .into_iter()
585                        .map(|tool| ToolDefinition {
586                            name: tool.name,
587                            description: Some(tool.description),
588                            input_schema: tool.parameters,
589                        })
590                        .collect::<Vec<_>>(),
591                    "tool_choice": ToolChoice::Auto,
592                }),
593            );
594        }
595
596        if let Some(ref params) = completion_request.additional_params {
597            json_utils::merge_inplace(&mut request, params.clone())
598        }
599
600        tracing::debug!("Anthropic completion request: {request}");
601
602        let response = self
603            .client
604            .post("/v1/messages")
605            .json(&request)
606            .send()
607            .await?;
608
609        if response.status().is_success() {
610            match response.json::<ApiResponse<CompletionResponse>>().await? {
611                ApiResponse::Message(completion) => {
612                    tracing::info!(target: "rig",
613                        "Anthropic completion token usage: {}",
614                        completion.usage
615                    );
616                    completion.try_into()
617                }
618                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
619            }
620        } else {
621            Err(CompletionError::ProviderError(response.text().await?))
622        }
623    }
624}
625
626#[derive(Debug, Deserialize)]
627struct ApiErrorResponse {
628    message: String,
629}
630
631#[derive(Debug, Deserialize)]
632#[serde(tag = "type", rename_all = "snake_case")]
633enum ApiResponse<T> {
634    Message(T),
635    Error(ApiErrorResponse),
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Tool-use id shared by the fixtures below.
    const TOOL_USE_ID: &str = "toolu_01A09q90qw90lq917835lq9";

    /// Deserializes a [`Message`] from `json`.
    ///
    /// On failure it panics with the JSON path of the first error, so a broken fixture points
    /// at the offending field.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    #[test]
    fn test_deserialize_message() {
        let greeting = "\n\nHello there, how may I assist you today?";

        // A bare string `content` becomes a single text block.
        let Message { role, content } = parse_message(
            r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#,
        );
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: greeting.to_owned()
            }
        );

        // A content array keeps all of its blocks, in order.
        let Message { role, content } = parse_message(
            r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#,
        );
        assert_eq!(role, Role::Assistant);
        assert_eq!(content.len(), 2);

        let mut blocks = content.into_iter();
        match blocks.next().unwrap() {
            Content::Text { text } => assert_eq!(text, greeting),
            _ => panic!("Expected text content"),
        }
        match blocks.next().unwrap() {
            Content::ToolUse { id, name, input } => {
                assert_eq!(id, TOOL_USE_ID);
                assert_eq!(name, "get_weather");
                assert_eq!(input, json!({"location": "San Francisco, CA"}));
            }
            _ => panic!("Expected tool use content"),
        }
        assert_eq!(blocks.next(), None);

        // User content may mix images, text and string-valued tool results.
        let Message { role, content } = parse_message(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#,
        );
        assert_eq!(role, Role::User);
        assert_eq!(content.len(), 3);

        let mut blocks = content.into_iter();
        match blocks.next().unwrap() {
            Content::Image { source } => assert_eq!(
                source,
                ImageSource {
                    data: "/9j/4AAQSkZJRg...".to_owned(),
                    media_type: ImageFormat::JPEG,
                    r#type: SourceType::BASE64,
                }
            ),
            _ => panic!("Expected image content"),
        }
        match blocks.next().unwrap() {
            Content::Text { text } => assert_eq!(text, "What is in this image?"),
            _ => panic!("Expected text content"),
        }
        match blocks.next().unwrap() {
            Content::ToolResult {
                tool_use_id,
                content,
                is_error,
            } => {
                assert_eq!(tool_use_id, TOOL_USE_ID);
                assert_eq!(
                    content.first(),
                    ToolResultContent::Text {
                        text: "15 degrees".to_owned()
                    }
                );
                assert_eq!(is_error, None);
            }
            _ => panic!("Expected tool result content"),
        }
        assert_eq!(blocks.next(), None);
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: TOOL_USE_ID.to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: TOOL_USE_ID.to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        // Anthropic -> rig.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut items = content.into_iter();
                match items.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data,
                        format,
                        media_type,
                        ..
                    }) => {
                        assert_eq!(data, "/9j/4AAQSkZJRg...");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }
                match items.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }
                match items.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data,
                        format,
                        media_type,
                    }) => {
                        assert_eq!(data, "base64_encoded_pdf_data");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }
                assert_eq!(items.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, TOOL_USE_ID);
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content } => {
                assert_eq!(content.len(), 1);
                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
                        assert_eq!(id, TOOL_USE_ID);
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // rig -> Anthropic must reproduce the original messages exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
}