rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError},
6    json_utils,
7    message::{self, DocumentMediaType, MessageError},
8    one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
// ================================================================
// Anthropic Completion API
// ================================================================

/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";

/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";

/// Newest API version known to this module; currently an alias for
/// [`ANTHROPIC_VERSION_2023_06_01`].
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
42
43#[derive(Debug, Deserialize)]
44pub struct CompletionResponse {
45    pub content: Vec<Content>,
46    pub id: String,
47    pub model: String,
48    pub role: String,
49    pub stop_reason: Option<String>,
50    pub stop_sequence: Option<String>,
51    pub usage: Usage,
52}
53
54#[derive(Debug, Deserialize, Serialize)]
55pub struct Usage {
56    pub input_tokens: u64,
57    pub cache_read_input_tokens: Option<u64>,
58    pub cache_creation_input_tokens: Option<u64>,
59    pub output_tokens: u64,
60}
61
62impl std::fmt::Display for Usage {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        write!(
65            f,
66            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
67            self.input_tokens,
68            match self.cache_read_input_tokens {
69                Some(token) => token.to_string(),
70                None => "n/a".to_string(),
71            },
72            match self.cache_creation_input_tokens {
73                Some(token) => token.to_string(),
74                None => "n/a".to_string(),
75            },
76            self.output_tokens
77        )
78    }
79}
80
81#[derive(Debug, Deserialize, Serialize)]
82pub struct ToolDefinition {
83    pub name: String,
84    pub description: Option<String>,
85    pub input_schema: serde_json::Value,
86}
87
88#[derive(Debug, Deserialize, Serialize)]
89#[serde(tag = "type", rename_all = "snake_case")]
90pub enum CacheControl {
91    Ephemeral,
92}
93
94impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
95    type Error = CompletionError;
96
97    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
98        let content = response
99            .content
100            .iter()
101            .map(|content| {
102                Ok(match content {
103                    Content::Text { text } => completion::AssistantContent::text(text),
104                    Content::ToolUse { id, name, input } => {
105                        completion::AssistantContent::tool_call(id, name, input.clone())
106                    }
107                    _ => {
108                        return Err(CompletionError::ResponseError(
109                            "Response did not contain a message or tool call".into(),
110                        ));
111                    }
112                })
113            })
114            .collect::<Result<Vec<_>, _>>()?;
115
116        let choice = OneOrMany::many(content).map_err(|_| {
117            CompletionError::ResponseError(
118                "Response contained no message or tool call (empty)".to_owned(),
119            )
120        })?;
121
122        let usage = completion::Usage {
123            input_tokens: response.usage.input_tokens,
124            output_tokens: response.usage.output_tokens,
125            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
126        };
127
128        Ok(completion::CompletionResponse {
129            choice,
130            usage,
131            raw_response: response,
132        })
133    }
134}
135
136#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
137pub struct Message {
138    pub role: Role,
139    #[serde(deserialize_with = "string_or_one_or_many")]
140    pub content: OneOrMany<Content>,
141}
142
143#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
144#[serde(rename_all = "lowercase")]
145pub enum Role {
146    User,
147    Assistant,
148}
149
150#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
151#[serde(tag = "type", rename_all = "snake_case")]
152pub enum Content {
153    Text {
154        text: String,
155    },
156    Image {
157        source: ImageSource,
158    },
159    ToolUse {
160        id: String,
161        name: String,
162        input: serde_json::Value,
163    },
164    ToolResult {
165        tool_use_id: String,
166        #[serde(deserialize_with = "string_or_one_or_many")]
167        content: OneOrMany<ToolResultContent>,
168        #[serde(skip_serializing_if = "Option::is_none")]
169        is_error: Option<bool>,
170    },
171    Document {
172        source: DocumentSource,
173    },
174}
175
176impl FromStr for Content {
177    type Err = Infallible;
178
179    fn from_str(s: &str) -> Result<Self, Self::Err> {
180        Ok(Content::Text { text: s.to_owned() })
181    }
182}
183
184#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
185#[serde(tag = "type", rename_all = "snake_case")]
186pub enum ToolResultContent {
187    Text { text: String },
188    Image(ImageSource),
189}
190
191impl FromStr for ToolResultContent {
192    type Err = Infallible;
193
194    fn from_str(s: &str) -> Result<Self, Self::Err> {
195        Ok(ToolResultContent::Text { text: s.to_owned() })
196    }
197}
198
199#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
200pub struct ImageSource {
201    pub data: String,
202    pub media_type: ImageFormat,
203    pub r#type: SourceType,
204}
205
206#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
207pub struct DocumentSource {
208    pub data: String,
209    pub media_type: DocumentFormat,
210    pub r#type: SourceType,
211}
212
213#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
214#[serde(rename_all = "lowercase")]
215pub enum ImageFormat {
216    #[serde(rename = "image/jpeg")]
217    JPEG,
218    #[serde(rename = "image/png")]
219    PNG,
220    #[serde(rename = "image/gif")]
221    GIF,
222    #[serde(rename = "image/webp")]
223    WEBP,
224}
225
226/// The document format to be used.
227///
228/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
229#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
230#[serde(rename_all = "lowercase")]
231pub enum DocumentFormat {
232    #[serde(rename = "application/pdf")]
233    PDF,
234}
235
236#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
237#[serde(rename_all = "lowercase")]
238pub enum SourceType {
239    BASE64,
240}
241
242impl From<String> for Content {
243    fn from(text: String) -> Self {
244        Content::Text { text }
245    }
246}
247
248impl From<String> for ToolResultContent {
249    fn from(text: String) -> Self {
250        ToolResultContent::Text { text }
251    }
252}
253
254impl TryFrom<message::ContentFormat> for SourceType {
255    type Error = MessageError;
256
257    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
258        match format {
259            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
260            message::ContentFormat::String => Err(MessageError::ConversionError(
261                "Image urls are not supported in Anthropic".to_owned(),
262            )),
263        }
264    }
265}
266
267impl From<SourceType> for message::ContentFormat {
268    fn from(source_type: SourceType) -> Self {
269        match source_type {
270            SourceType::BASE64 => message::ContentFormat::Base64,
271        }
272    }
273}
274
275impl TryFrom<message::ImageMediaType> for ImageFormat {
276    type Error = MessageError;
277
278    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
279        Ok(match media_type {
280            message::ImageMediaType::JPEG => ImageFormat::JPEG,
281            message::ImageMediaType::PNG => ImageFormat::PNG,
282            message::ImageMediaType::GIF => ImageFormat::GIF,
283            message::ImageMediaType::WEBP => ImageFormat::WEBP,
284            _ => {
285                return Err(MessageError::ConversionError(
286                    format!("Unsupported image media type: {media_type:?}").to_owned(),
287                ));
288            }
289        })
290    }
291}
292
293impl From<ImageFormat> for message::ImageMediaType {
294    fn from(format: ImageFormat) -> Self {
295        match format {
296            ImageFormat::JPEG => message::ImageMediaType::JPEG,
297            ImageFormat::PNG => message::ImageMediaType::PNG,
298            ImageFormat::GIF => message::ImageMediaType::GIF,
299            ImageFormat::WEBP => message::ImageMediaType::WEBP,
300        }
301    }
302}
303
304impl TryFrom<DocumentMediaType> for DocumentFormat {
305    type Error = MessageError;
306    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
307        if !matches!(value, DocumentMediaType::PDF) {
308            return Err(MessageError::ConversionError(
309                "Anthropic only supports PDF documents".to_string(),
310            ));
311        };
312
313        Ok(DocumentFormat::PDF)
314    }
315}
316
317impl From<message::AssistantContent> for Content {
318    fn from(text: message::AssistantContent) -> Self {
319        match text {
320            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
321            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
322                Content::ToolUse {
323                    id,
324                    name: function.name,
325                    input: function.arguments,
326                }
327            }
328        }
329    }
330}
331
332impl TryFrom<message::Message> for Message {
333    type Error = MessageError;
334
335    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
336        Ok(match message {
337            message::Message::User { content } => Message {
338                role: Role::User,
339                content: content.try_map(|content| match content {
340                    message::UserContent::Text(message::Text { text }) => {
341                        Ok(Content::Text { text })
342                    }
343                    message::UserContent::ToolResult(message::ToolResult {
344                        id, content, ..
345                    }) => Ok(Content::ToolResult {
346                        tool_use_id: id,
347                        content: content.try_map(|content| match content {
348                            message::ToolResultContent::Text(message::Text { text }) => {
349                                Ok(ToolResultContent::Text { text })
350                            }
351                            message::ToolResultContent::Image(image) => {
352                                let media_type =
353                                    image.media_type.ok_or(MessageError::ConversionError(
354                                        "Image media type is required".to_owned(),
355                                    ))?;
356                                let format = image.format.ok_or(MessageError::ConversionError(
357                                    "Image format is required".to_owned(),
358                                ))?;
359                                Ok(ToolResultContent::Image(ImageSource {
360                                    data: image.data,
361                                    media_type: media_type.try_into()?,
362                                    r#type: format.try_into()?,
363                                }))
364                            }
365                        })?,
366                        is_error: None,
367                    }),
368                    message::UserContent::Image(message::Image {
369                        data,
370                        format,
371                        media_type,
372                        ..
373                    }) => {
374                        let source = ImageSource {
375                            data,
376                            media_type: match media_type {
377                                Some(media_type) => media_type.try_into()?,
378                                None => {
379                                    return Err(MessageError::ConversionError(
380                                        "Image media type is required".to_owned(),
381                                    ));
382                                }
383                            },
384                            r#type: match format {
385                                Some(format) => format.try_into()?,
386                                None => SourceType::BASE64,
387                            },
388                        };
389                        Ok(Content::Image { source })
390                    }
391                    message::UserContent::Document(message::Document {
392                        data,
393                        format,
394                        media_type,
395                    }) => {
396                        let Some(media_type) = media_type else {
397                            return Err(MessageError::ConversionError(
398                                "Document media type is required".to_string(),
399                            ));
400                        };
401
402                        let source = DocumentSource {
403                            data,
404                            media_type: media_type.try_into()?,
405                            r#type: match format {
406                                Some(format) => format.try_into()?,
407                                None => SourceType::BASE64,
408                            },
409                        };
410                        Ok(Content::Document { source })
411                    }
412                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
413                        "Audio is not supported in Anthropic".to_owned(),
414                    )),
415                })?,
416            },
417
418            message::Message::Assistant { content, .. } => Message {
419                content: content.map(|content| content.into()),
420                role: Role::Assistant,
421            },
422        })
423    }
424}
425
426impl TryFrom<Content> for message::AssistantContent {
427    type Error = MessageError;
428
429    fn try_from(content: Content) -> Result<Self, Self::Error> {
430        Ok(match content {
431            Content::Text { text } => message::AssistantContent::text(text),
432            Content::ToolUse { id, name, input } => {
433                message::AssistantContent::tool_call(id, name, input)
434            }
435            _ => {
436                return Err(MessageError::ConversionError(
437                    format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
438                ));
439            }
440        })
441    }
442}
443
444impl From<ToolResultContent> for message::ToolResultContent {
445    fn from(content: ToolResultContent) -> Self {
446        match content {
447            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
448            ToolResultContent::Image(ImageSource {
449                data,
450                media_type: format,
451                r#type,
452            }) => message::ToolResultContent::image(
453                data,
454                Some(r#type.into()),
455                Some(format.into()),
456                None,
457            ),
458        }
459    }
460}
461
462impl TryFrom<Message> for message::Message {
463    type Error = MessageError;
464
465    fn try_from(message: Message) -> Result<Self, Self::Error> {
466        Ok(match message.role {
467            Role::User => message::Message::User {
468                content: message.content.try_map(|content| {
469                    Ok(match content {
470                        Content::Text { text } => message::UserContent::text(text),
471                        Content::ToolResult {
472                            tool_use_id,
473                            content,
474                            ..
475                        } => message::UserContent::tool_result(
476                            tool_use_id,
477                            content.map(|content| content.into()),
478                        ),
479                        Content::Image { source } => message::UserContent::Image(message::Image {
480                            data: source.data,
481                            format: Some(message::ContentFormat::Base64),
482                            media_type: Some(source.media_type.into()),
483                            detail: None,
484                        }),
485                        Content::Document { source } => message::UserContent::document(
486                            source.data,
487                            Some(message::ContentFormat::Base64),
488                            Some(message::DocumentMediaType::PDF),
489                        ),
490                        _ => {
491                            return Err(MessageError::ConversionError(
492                                "Unsupported content type for User role".to_owned(),
493                            ));
494                        }
495                    })
496                })?,
497            },
498            Role::Assistant => match message.content.first() {
499                Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
500                    id: None,
501                    content: message.content.try_map(|content| content.try_into())?,
502                },
503
504                _ => {
505                    return Err(MessageError::ConversionError(
506                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
507                    ));
508                }
509            },
510        })
511    }
512}
513
514#[derive(Clone)]
515pub struct CompletionModel {
516    pub(crate) client: Client,
517    pub model: String,
518    pub default_max_tokens: Option<u64>,
519}
520
521impl CompletionModel {
522    pub fn new(client: Client, model: &str) -> Self {
523        Self {
524            client,
525            model: model.to_string(),
526            default_max_tokens: calculate_max_tokens(model),
527        }
528    }
529}
530
/// Anthropic requires a `max_tokens` parameter, and the allowed maximum depends on the model.
/// If it is missing or set too high, the request fails. The values below are the maximum
/// output sizes documented for each model family at the time of writing. Unrecognized models
/// get `None`, so callers must set `max_tokens` explicitly for them.
///
/// Dev note: the API has no server-side default, so this table has to be kept in sync with
/// Anthropic's model list by hand.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    if model.starts_with("claude-3-7-sonnet") {
        // Claude 3.7 Sonnet supports up to 64k output tokens (without the extended-output beta).
        Some(64_000)
    } else if model.starts_with("claude-3-5-sonnet") || model.starts_with("claude-3-5-haiku") {
        Some(8192)
    } else if model.starts_with("claude-3-opus")
        || model.starts_with("claude-3-sonnet")
        || model.starts_with("claude-3-haiku")
    {
        Some(4096)
    } else {
        None
    }
}
548
549#[derive(Debug, Deserialize, Serialize)]
550struct Metadata {
551    user_id: Option<String>,
552}
553
554#[derive(Default, Debug, Serialize, Deserialize)]
555#[serde(tag = "type", rename_all = "snake_case")]
556pub enum ToolChoice {
557    #[default]
558    Auto,
559    Any,
560    Tool {
561        name: String,
562    },
563}
564
565impl completion::CompletionModel for CompletionModel {
566    type Response = CompletionResponse;
567    type StreamingResponse = StreamingCompletionResponse;
568
569    #[cfg_attr(feature = "worker", worker::send)]
570    async fn completion(
571        &self,
572        completion_request: completion::CompletionRequest,
573    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
574        // Note: Ideally we'd introduce provider-specific Request models to handle the
575        // specific requirements of each provider. For now, we just manually check while
576        // building the request as a raw JSON document.
577
578        // Check if max_tokens is set, required for Anthropic
579        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
580            tokens
581        } else if let Some(tokens) = self.default_max_tokens {
582            tokens
583        } else {
584            return Err(CompletionError::RequestError(
585                "`max_tokens` must be set for Anthropic".into(),
586            ));
587        };
588
589        let mut full_history = vec![];
590        if let Some(docs) = completion_request.normalized_documents() {
591            full_history.push(docs);
592        }
593        full_history.extend(completion_request.chat_history);
594
595        let full_history = full_history
596            .into_iter()
597            .map(Message::try_from)
598            .collect::<Result<Vec<Message>, _>>()?;
599
600        let mut request = json!({
601            "model": self.model,
602            "messages": full_history,
603            "max_tokens": max_tokens,
604            "system": completion_request.preamble.unwrap_or("".to_string()),
605        });
606
607        if let Some(temperature) = completion_request.temperature {
608            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
609        }
610
611        if !completion_request.tools.is_empty() {
612            json_utils::merge_inplace(
613                &mut request,
614                json!({
615                    "tools": completion_request
616                        .tools
617                        .into_iter()
618                        .map(|tool| ToolDefinition {
619                            name: tool.name,
620                            description: Some(tool.description),
621                            input_schema: tool.parameters,
622                        })
623                        .collect::<Vec<_>>(),
624                    "tool_choice": ToolChoice::Auto,
625                }),
626            );
627        }
628
629        if let Some(ref params) = completion_request.additional_params {
630            json_utils::merge_inplace(&mut request, params.clone())
631        }
632
633        tracing::debug!("Anthropic completion request: {request}");
634
635        let response = self
636            .client
637            .post("/v1/messages")
638            .json(&request)
639            .send()
640            .await?;
641
642        if response.status().is_success() {
643            match response.json::<ApiResponse<CompletionResponse>>().await? {
644                ApiResponse::Message(completion) => {
645                    tracing::info!(target: "rig",
646                        "Anthropic completion token usage: {}",
647                        completion.usage
648                    );
649                    completion.try_into()
650                }
651                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
652            }
653        } else {
654            Err(CompletionError::ProviderError(response.text().await?))
655        }
656    }
657
658    #[cfg_attr(feature = "worker", worker::send)]
659    async fn stream(
660        &self,
661        request: CompletionRequest,
662    ) -> Result<
663        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
664        CompletionError,
665    > {
666        CompletionModel::stream(self, request).await
667    }
668}
669
670#[derive(Debug, Deserialize)]
671struct ApiErrorResponse {
672    message: String,
673}
674
675#[derive(Debug, Deserialize)]
676#[serde(tag = "type", rename_all = "snake_case")]
677enum ApiResponse<T> {
678    Message(T),
679    Error(ApiErrorResponse),
680}
681
682#[cfg(test)]
683mod tests {
684    use super::*;
685    use serde_path_to_error::deserialize;
686
687    #[test]
688    fn test_deserialize_message() {
689        let assistant_message_json = r#"
690        {
691            "role": "assistant",
692            "content": "\n\nHello there, how may I assist you today?"
693        }
694        "#;
695
696        let assistant_message_json2 = r#"
697        {
698            "role": "assistant",
699            "content": [
700                {
701                    "type": "text",
702                    "text": "\n\nHello there, how may I assist you today?"
703                },
704                {
705                    "type": "tool_use",
706                    "id": "toolu_01A09q90qw90lq917835lq9",
707                    "name": "get_weather",
708                    "input": {"location": "San Francisco, CA"}
709                }
710            ]
711        }
712        "#;
713
714        let user_message_json = r#"
715        {
716            "role": "user",
717            "content": [
718                {
719                    "type": "image",
720                    "source": {
721                        "type": "base64",
722                        "media_type": "image/jpeg",
723                        "data": "/9j/4AAQSkZJRg..."
724                    }
725                },
726                {
727                    "type": "text",
728                    "text": "What is in this image?"
729                },
730                {
731                    "type": "tool_result",
732                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
733                    "content": "15 degrees"
734                }
735            ]
736        }
737        "#;
738
739        let assistant_message: Message = {
740            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
741            deserialize(jd).unwrap_or_else(|err| {
742                panic!("Deserialization error at {}: {}", err.path(), err);
743            })
744        };
745
746        let assistant_message2: Message = {
747            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
748            deserialize(jd).unwrap_or_else(|err| {
749                panic!("Deserialization error at {}: {}", err.path(), err);
750            })
751        };
752
753        let user_message: Message = {
754            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
755            deserialize(jd).unwrap_or_else(|err| {
756                panic!("Deserialization error at {}: {}", err.path(), err);
757            })
758        };
759
760        let Message { role, content } = assistant_message;
761        assert_eq!(role, Role::Assistant);
762        assert_eq!(
763            content.first(),
764            Content::Text {
765                text: "\n\nHello there, how may I assist you today?".to_owned()
766            }
767        );
768
769        let Message { role, content } = assistant_message2;
770        {
771            assert_eq!(role, Role::Assistant);
772            assert_eq!(content.len(), 2);
773
774            let mut iter = content.into_iter();
775
776            match iter.next().unwrap() {
777                Content::Text { text } => {
778                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
779                }
780                _ => panic!("Expected text content"),
781            }
782
783            match iter.next().unwrap() {
784                Content::ToolUse { id, name, input } => {
785                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
786                    assert_eq!(name, "get_weather");
787                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
788                }
789                _ => panic!("Expected tool use content"),
790            }
791
792            assert_eq!(iter.next(), None);
793        }
794
795        let Message { role, content } = user_message;
796        {
797            assert_eq!(role, Role::User);
798            assert_eq!(content.len(), 3);
799
800            let mut iter = content.into_iter();
801
802            match iter.next().unwrap() {
803                Content::Image { source } => {
804                    assert_eq!(
805                        source,
806                        ImageSource {
807                            data: "/9j/4AAQSkZJRg...".to_owned(),
808                            media_type: ImageFormat::JPEG,
809                            r#type: SourceType::BASE64,
810                        }
811                    );
812                }
813                _ => panic!("Expected image content"),
814            }
815
816            match iter.next().unwrap() {
817                Content::Text { text } => {
818                    assert_eq!(text, "What is in this image?");
819                }
820                _ => panic!("Expected text content"),
821            }
822
823            match iter.next().unwrap() {
824                Content::ToolResult {
825                    tool_use_id,
826                    content,
827                    is_error,
828                } => {
829                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
830                    assert_eq!(
831                        content.first(),
832                        ToolResultContent::Text {
833                            text: "15 degrees".to_owned()
834                        }
835                    );
836                    assert_eq!(is_error, None);
837                }
838                _ => panic!("Expected tool result content"),
839            }
840
841            assert_eq!(iter.next(), None);
842        }
843    }
844
845    #[test]
846    fn test_message_to_message_conversion() {
847        let user_message: Message = serde_json::from_str(
848            r#"
849        {
850            "role": "user",
851            "content": [
852                {
853                    "type": "image",
854                    "source": {
855                        "type": "base64",
856                        "media_type": "image/jpeg",
857                        "data": "/9j/4AAQSkZJRg..."
858                    }
859                },
860                {
861                    "type": "text",
862                    "text": "What is in this image?"
863                },
864                {
865                    "type": "document",
866                    "source": {
867                        "type": "base64",
868                        "data": "base64_encoded_pdf_data",
869                        "media_type": "application/pdf"
870                    }
871                }
872            ]
873        }
874        "#,
875        )
876        .unwrap();
877
878        let assistant_message = Message {
879            role: Role::Assistant,
880            content: OneOrMany::one(Content::ToolUse {
881                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
882                name: "get_weather".to_string(),
883                input: json!({"location": "San Francisco, CA"}),
884            }),
885        };
886
887        let tool_message = Message {
888            role: Role::User,
889            content: OneOrMany::one(Content::ToolResult {
890                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
891                content: OneOrMany::one(ToolResultContent::Text {
892                    text: "15 degrees".to_string(),
893                }),
894                is_error: None,
895            }),
896        };
897
898        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
899        let converted_assistant_message: message::Message =
900            assistant_message.clone().try_into().unwrap();
901        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
902
903        match converted_user_message.clone() {
904            message::Message::User { content } => {
905                assert_eq!(content.len(), 3);
906
907                let mut iter = content.into_iter();
908
909                match iter.next().unwrap() {
910                    message::UserContent::Image(message::Image {
911                        data,
912                        format,
913                        media_type,
914                        ..
915                    }) => {
916                        assert_eq!(data, "/9j/4AAQSkZJRg...");
917                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
918                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
919                    }
920                    _ => panic!("Expected image content"),
921                }
922
923                match iter.next().unwrap() {
924                    message::UserContent::Text(message::Text { text }) => {
925                        assert_eq!(text, "What is in this image?");
926                    }
927                    _ => panic!("Expected text content"),
928                }
929
930                match iter.next().unwrap() {
931                    message::UserContent::Document(message::Document {
932                        data,
933                        format,
934                        media_type,
935                    }) => {
936                        assert_eq!(data, "base64_encoded_pdf_data");
937                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
938                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
939                    }
940                    _ => panic!("Expected document content"),
941                }
942
943                assert_eq!(iter.next(), None);
944            }
945            _ => panic!("Expected user message"),
946        }
947
948        match converted_tool_message.clone() {
949            message::Message::User { content } => {
950                let message::ToolResult { id, content, .. } = match content.first() {
951                    message::UserContent::ToolResult(tool_result) => tool_result,
952                    _ => panic!("Expected tool result content"),
953                };
954                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
955                match content.first() {
956                    message::ToolResultContent::Text(message::Text { text }) => {
957                        assert_eq!(text, "15 degrees");
958                    }
959                    _ => panic!("Expected text content"),
960                }
961            }
962            _ => panic!("Expected tool result content"),
963        }
964
965        match converted_assistant_message.clone() {
966            message::Message::Assistant { content, .. } => {
967                assert_eq!(content.len(), 1);
968
969                match content.first() {
970                    message::AssistantContent::ToolCall(message::ToolCall {
971                        id, function, ..
972                    }) => {
973                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
974                        assert_eq!(function.name, "get_weather");
975                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
976                    }
977                    _ => panic!("Expected tool call content"),
978                }
979            }
980            _ => panic!("Expected assistant message"),
981        }
982
983        let original_user_message: Message = converted_user_message.try_into().unwrap();
984        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
985        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
986
987        assert_eq!(user_message, original_user_message);
988        assert_eq!(assistant_message, original_assistant_message);
989        assert_eq!(tool_message, original_tool_message);
990    }
991}