rig/providers/anthropic/
completion.rs

//! Anthropic completion API implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError},
6    json_utils,
7    message::{self, DocumentMediaType, MessageError, Reasoning},
8    one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
18// ================================================================
19// Anthropic Completion API
20// ================================================================
21
/// `claude-opus-4-0` completion model
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// `claude-sonnet-4-0` completion model
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest Anthropic API version known to this module (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
49
50#[derive(Debug, Deserialize, Serialize)]
51pub struct CompletionResponse {
52    pub content: Vec<Content>,
53    pub id: String,
54    pub model: String,
55    pub role: String,
56    pub stop_reason: Option<String>,
57    pub stop_sequence: Option<String>,
58    pub usage: Usage,
59}
60
61#[derive(Debug, Deserialize, Serialize)]
62pub struct Usage {
63    pub input_tokens: u64,
64    pub cache_read_input_tokens: Option<u64>,
65    pub cache_creation_input_tokens: Option<u64>,
66    pub output_tokens: u64,
67}
68
69impl std::fmt::Display for Usage {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(
72            f,
73            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
74            self.input_tokens,
75            match self.cache_read_input_tokens {
76                Some(token) => token.to_string(),
77                None => "n/a".to_string(),
78            },
79            match self.cache_creation_input_tokens {
80                Some(token) => token.to_string(),
81                None => "n/a".to_string(),
82            },
83            self.output_tokens
84        )
85    }
86}
87
88#[derive(Debug, Deserialize, Serialize)]
89pub struct ToolDefinition {
90    pub name: String,
91    pub description: Option<String>,
92    pub input_schema: serde_json::Value,
93}
94
95#[derive(Debug, Deserialize, Serialize)]
96#[serde(tag = "type", rename_all = "snake_case")]
97pub enum CacheControl {
98    Ephemeral,
99}
100
101impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
102    type Error = CompletionError;
103
104    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
105        let content = response
106            .content
107            .iter()
108            .map(|content| {
109                Ok(match content {
110                    Content::Text { text } => completion::AssistantContent::text(text),
111                    Content::ToolUse { id, name, input } => {
112                        completion::AssistantContent::tool_call(id, name, input.clone())
113                    }
114                    _ => {
115                        return Err(CompletionError::ResponseError(
116                            "Response did not contain a message or tool call".into(),
117                        ));
118                    }
119                })
120            })
121            .collect::<Result<Vec<_>, _>>()?;
122
123        let choice = OneOrMany::many(content).map_err(|_| {
124            CompletionError::ResponseError(
125                "Response contained no message or tool call (empty)".to_owned(),
126            )
127        })?;
128
129        let usage = completion::Usage {
130            input_tokens: response.usage.input_tokens,
131            output_tokens: response.usage.output_tokens,
132            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
133        };
134
135        Ok(completion::CompletionResponse {
136            choice,
137            usage,
138            raw_response: response,
139        })
140    }
141}
142
143#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
144pub struct Message {
145    pub role: Role,
146    #[serde(deserialize_with = "string_or_one_or_many")]
147    pub content: OneOrMany<Content>,
148}
149
150#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
151#[serde(rename_all = "lowercase")]
152pub enum Role {
153    User,
154    Assistant,
155}
156
157#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
158#[serde(tag = "type", rename_all = "snake_case")]
159pub enum Content {
160    Text {
161        text: String,
162    },
163    Image {
164        source: ImageSource,
165    },
166    ToolUse {
167        id: String,
168        name: String,
169        input: serde_json::Value,
170    },
171    ToolResult {
172        tool_use_id: String,
173        #[serde(deserialize_with = "string_or_one_or_many")]
174        content: OneOrMany<ToolResultContent>,
175        #[serde(skip_serializing_if = "Option::is_none")]
176        is_error: Option<bool>,
177    },
178    Document {
179        source: DocumentSource,
180    },
181    Thinking {
182        thinking: String,
183        signature: Option<String>,
184    },
185}
186
187impl FromStr for Content {
188    type Err = Infallible;
189
190    fn from_str(s: &str) -> Result<Self, Self::Err> {
191        Ok(Content::Text { text: s.to_owned() })
192    }
193}
194
195#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
196#[serde(tag = "type", rename_all = "snake_case")]
197pub enum ToolResultContent {
198    Text { text: String },
199    Image(ImageSource),
200}
201
202impl FromStr for ToolResultContent {
203    type Err = Infallible;
204
205    fn from_str(s: &str) -> Result<Self, Self::Err> {
206        Ok(ToolResultContent::Text { text: s.to_owned() })
207    }
208}
209
210#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
211pub struct ImageSource {
212    pub data: String,
213    pub media_type: ImageFormat,
214    pub r#type: SourceType,
215}
216
217#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
218pub struct DocumentSource {
219    pub data: String,
220    pub media_type: DocumentFormat,
221    pub r#type: SourceType,
222}
223
224#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
225#[serde(rename_all = "lowercase")]
226pub enum ImageFormat {
227    #[serde(rename = "image/jpeg")]
228    JPEG,
229    #[serde(rename = "image/png")]
230    PNG,
231    #[serde(rename = "image/gif")]
232    GIF,
233    #[serde(rename = "image/webp")]
234    WEBP,
235}
236
237/// The document format to be used.
238///
239/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
240#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
241#[serde(rename_all = "lowercase")]
242pub enum DocumentFormat {
243    #[serde(rename = "application/pdf")]
244    PDF,
245}
246
247#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
248#[serde(rename_all = "lowercase")]
249pub enum SourceType {
250    BASE64,
251}
252
253impl From<String> for Content {
254    fn from(text: String) -> Self {
255        Content::Text { text }
256    }
257}
258
259impl From<String> for ToolResultContent {
260    fn from(text: String) -> Self {
261        ToolResultContent::Text { text }
262    }
263}
264
265impl TryFrom<message::ContentFormat> for SourceType {
266    type Error = MessageError;
267
268    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
269        match format {
270            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
271            message::ContentFormat::String => Err(MessageError::ConversionError(
272                "Image urls are not supported in Anthropic".to_owned(),
273            )),
274        }
275    }
276}
277
278impl From<SourceType> for message::ContentFormat {
279    fn from(source_type: SourceType) -> Self {
280        match source_type {
281            SourceType::BASE64 => message::ContentFormat::Base64,
282        }
283    }
284}
285
286impl TryFrom<message::ImageMediaType> for ImageFormat {
287    type Error = MessageError;
288
289    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
290        Ok(match media_type {
291            message::ImageMediaType::JPEG => ImageFormat::JPEG,
292            message::ImageMediaType::PNG => ImageFormat::PNG,
293            message::ImageMediaType::GIF => ImageFormat::GIF,
294            message::ImageMediaType::WEBP => ImageFormat::WEBP,
295            _ => {
296                return Err(MessageError::ConversionError(
297                    format!("Unsupported image media type: {media_type:?}").to_owned(),
298                ));
299            }
300        })
301    }
302}
303
304impl From<ImageFormat> for message::ImageMediaType {
305    fn from(format: ImageFormat) -> Self {
306        match format {
307            ImageFormat::JPEG => message::ImageMediaType::JPEG,
308            ImageFormat::PNG => message::ImageMediaType::PNG,
309            ImageFormat::GIF => message::ImageMediaType::GIF,
310            ImageFormat::WEBP => message::ImageMediaType::WEBP,
311        }
312    }
313}
314
315impl TryFrom<DocumentMediaType> for DocumentFormat {
316    type Error = MessageError;
317    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
318        if !matches!(value, DocumentMediaType::PDF) {
319            return Err(MessageError::ConversionError(
320                "Anthropic only supports PDF documents".to_string(),
321            ));
322        };
323
324        Ok(DocumentFormat::PDF)
325    }
326}
327
328impl From<message::AssistantContent> for Content {
329    fn from(text: message::AssistantContent) -> Self {
330        match text {
331            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
332            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
333                Content::ToolUse {
334                    id,
335                    name: function.name,
336                    input: function.arguments,
337                }
338            }
339            message::AssistantContent::Reasoning(Reasoning { reasoning }) => Content::Thinking {
340                thinking: reasoning,
341                signature: None,
342            },
343        }
344    }
345}
346
347impl TryFrom<message::Message> for Message {
348    type Error = MessageError;
349
350    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
351        Ok(match message {
352            message::Message::User { content } => Message {
353                role: Role::User,
354                content: content.try_map(|content| match content {
355                    message::UserContent::Text(message::Text { text }) => {
356                        Ok(Content::Text { text })
357                    }
358                    message::UserContent::ToolResult(message::ToolResult {
359                        id, content, ..
360                    }) => Ok(Content::ToolResult {
361                        tool_use_id: id,
362                        content: content.try_map(|content| match content {
363                            message::ToolResultContent::Text(message::Text { text }) => {
364                                Ok(ToolResultContent::Text { text })
365                            }
366                            message::ToolResultContent::Image(image) => {
367                                let media_type =
368                                    image.media_type.ok_or(MessageError::ConversionError(
369                                        "Image media type is required".to_owned(),
370                                    ))?;
371                                let format = image.format.ok_or(MessageError::ConversionError(
372                                    "Image format is required".to_owned(),
373                                ))?;
374                                Ok(ToolResultContent::Image(ImageSource {
375                                    data: image.data,
376                                    media_type: media_type.try_into()?,
377                                    r#type: format.try_into()?,
378                                }))
379                            }
380                        })?,
381                        is_error: None,
382                    }),
383                    message::UserContent::Image(message::Image {
384                        data,
385                        format,
386                        media_type,
387                        ..
388                    }) => {
389                        let source = ImageSource {
390                            data,
391                            media_type: match media_type {
392                                Some(media_type) => media_type.try_into()?,
393                                None => {
394                                    return Err(MessageError::ConversionError(
395                                        "Image media type is required".to_owned(),
396                                    ));
397                                }
398                            },
399                            r#type: match format {
400                                Some(format) => format.try_into()?,
401                                None => SourceType::BASE64,
402                            },
403                        };
404                        Ok(Content::Image { source })
405                    }
406                    message::UserContent::Document(message::Document {
407                        data,
408                        format,
409                        media_type,
410                    }) => {
411                        let Some(media_type) = media_type else {
412                            return Err(MessageError::ConversionError(
413                                "Document media type is required".to_string(),
414                            ));
415                        };
416
417                        let source = DocumentSource {
418                            data,
419                            media_type: media_type.try_into()?,
420                            r#type: match format {
421                                Some(format) => format.try_into()?,
422                                None => SourceType::BASE64,
423                            },
424                        };
425                        Ok(Content::Document { source })
426                    }
427                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
428                        "Audio is not supported in Anthropic".to_owned(),
429                    )),
430                })?,
431            },
432
433            message::Message::Assistant { content, .. } => Message {
434                content: content.map(|content| content.into()),
435                role: Role::Assistant,
436            },
437        })
438    }
439}
440
441impl TryFrom<Content> for message::AssistantContent {
442    type Error = MessageError;
443
444    fn try_from(content: Content) -> Result<Self, Self::Error> {
445        Ok(match content {
446            Content::Text { text } => message::AssistantContent::text(text),
447            Content::ToolUse { id, name, input } => {
448                message::AssistantContent::tool_call(id, name, input)
449            }
450            _ => {
451                return Err(MessageError::ConversionError(
452                    format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
453                ));
454            }
455        })
456    }
457}
458
459impl From<ToolResultContent> for message::ToolResultContent {
460    fn from(content: ToolResultContent) -> Self {
461        match content {
462            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
463            ToolResultContent::Image(ImageSource {
464                data,
465                media_type: format,
466                r#type,
467            }) => message::ToolResultContent::image(
468                data,
469                Some(r#type.into()),
470                Some(format.into()),
471                None,
472            ),
473        }
474    }
475}
476
477impl TryFrom<Message> for message::Message {
478    type Error = MessageError;
479
480    fn try_from(message: Message) -> Result<Self, Self::Error> {
481        Ok(match message.role {
482            Role::User => message::Message::User {
483                content: message.content.try_map(|content| {
484                    Ok(match content {
485                        Content::Text { text } => message::UserContent::text(text),
486                        Content::ToolResult {
487                            tool_use_id,
488                            content,
489                            ..
490                        } => message::UserContent::tool_result(
491                            tool_use_id,
492                            content.map(|content| content.into()),
493                        ),
494                        Content::Image { source } => message::UserContent::Image(message::Image {
495                            data: source.data,
496                            format: Some(message::ContentFormat::Base64),
497                            media_type: Some(source.media_type.into()),
498                            detail: None,
499                        }),
500                        Content::Document { source } => message::UserContent::document(
501                            source.data,
502                            Some(message::ContentFormat::Base64),
503                            Some(message::DocumentMediaType::PDF),
504                        ),
505                        _ => {
506                            return Err(MessageError::ConversionError(
507                                "Unsupported content type for User role".to_owned(),
508                            ));
509                        }
510                    })
511                })?,
512            },
513            Role::Assistant => match message.content.first() {
514                Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
515                    id: None,
516                    content: message.content.try_map(|content| content.try_into())?,
517                },
518
519                _ => {
520                    return Err(MessageError::ConversionError(
521                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
522                    ));
523                }
524            },
525        })
526    }
527}
528
529#[derive(Clone)]
530pub struct CompletionModel {
531    pub(crate) client: Client,
532    pub model: String,
533    pub default_max_tokens: Option<u64>,
534}
535
536impl CompletionModel {
537    pub fn new(client: Client, model: &str) -> Self {
538        Self {
539            client,
540            model: model.to_string(),
541            default_max_tokens: calculate_max_tokens(model),
542        }
543    }
544}
545
/// Returns a default `max_tokens` for a model, matched by name prefix.
///
/// Anthropic requires every request to carry `max_tokens`, and the accepted
/// maximum depends on the model. Requests fail if the value is missing or too
/// high. The limits below reflect the models available when this was written.
/// Unknown model names yield `None`, so callers must then set `max_tokens`
/// themselves.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefixes, max output tokens), checked in order.
    const LIMITS: &[(&[&str], u64)] = &[
        (&["claude-opus-4"], 32000),
        (&["claude-sonnet-4", "claude-3-7-sonnet"], 64000),
        (&["claude-3-5-sonnet", "claude-3-5-haiku"], 8192),
        (&["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"], 4096),
    ];

    LIMITS
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|prefix| model.starts_with(*prefix)))
        .map(|&(_, limit)| limit)
}
567
568#[derive(Debug, Deserialize, Serialize)]
569struct Metadata {
570    user_id: Option<String>,
571}
572
573#[derive(Default, Debug, Serialize, Deserialize)]
574#[serde(tag = "type", rename_all = "snake_case")]
575pub enum ToolChoice {
576    #[default]
577    Auto,
578    Any,
579    Tool {
580        name: String,
581    },
582}
583
584impl completion::CompletionModel for CompletionModel {
585    type Response = CompletionResponse;
586    type StreamingResponse = StreamingCompletionResponse;
587
588    #[cfg_attr(feature = "worker", worker::send)]
589    async fn completion(
590        &self,
591        completion_request: completion::CompletionRequest,
592    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
593        // Note: Ideally we'd introduce provider-specific Request models to handle the
594        // specific requirements of each provider. For now, we just manually check while
595        // building the request as a raw JSON document.
596
597        // Check if max_tokens is set, required for Anthropic
598        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
599            tokens
600        } else if let Some(tokens) = self.default_max_tokens {
601            tokens
602        } else {
603            return Err(CompletionError::RequestError(
604                "`max_tokens` must be set for Anthropic".into(),
605            ));
606        };
607
608        let mut full_history = vec![];
609        if let Some(docs) = completion_request.normalized_documents() {
610            full_history.push(docs);
611        }
612        full_history.extend(completion_request.chat_history);
613
614        let full_history = full_history
615            .into_iter()
616            .map(Message::try_from)
617            .collect::<Result<Vec<Message>, _>>()?;
618
619        let mut request = json!({
620            "model": self.model,
621            "messages": full_history,
622            "max_tokens": max_tokens,
623            "system": completion_request.preamble.unwrap_or("".to_string()),
624        });
625
626        if let Some(temperature) = completion_request.temperature {
627            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
628        }
629
630        if !completion_request.tools.is_empty() {
631            json_utils::merge_inplace(
632                &mut request,
633                json!({
634                    "tools": completion_request
635                        .tools
636                        .into_iter()
637                        .map(|tool| ToolDefinition {
638                            name: tool.name,
639                            description: Some(tool.description),
640                            input_schema: tool.parameters,
641                        })
642                        .collect::<Vec<_>>(),
643                    "tool_choice": ToolChoice::Auto,
644                }),
645            );
646        }
647
648        if let Some(ref params) = completion_request.additional_params {
649            json_utils::merge_inplace(&mut request, params.clone())
650        }
651
652        tracing::debug!("Anthropic completion request: {request}");
653
654        let response = self
655            .client
656            .post("/v1/messages")
657            .json(&request)
658            .send()
659            .await?;
660
661        if response.status().is_success() {
662            match response.json::<ApiResponse<CompletionResponse>>().await? {
663                ApiResponse::Message(completion) => {
664                    tracing::info!(target: "rig",
665                        "Anthropic completion token usage: {}",
666                        completion.usage
667                    );
668                    completion.try_into()
669                }
670                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
671            }
672        } else {
673            Err(CompletionError::ProviderError(response.text().await?))
674        }
675    }
676
677    #[cfg_attr(feature = "worker", worker::send)]
678    async fn stream(
679        &self,
680        request: CompletionRequest,
681    ) -> Result<
682        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
683        CompletionError,
684    > {
685        CompletionModel::stream(self, request).await
686    }
687}
688
689#[derive(Debug, Deserialize)]
690struct ApiErrorResponse {
691    message: String,
692}
693
694#[derive(Debug, Deserialize)]
695#[serde(tag = "type", rename_all = "snake_case")]
696enum ApiResponse<T> {
697    Message(T),
698    Error(ApiErrorResponse),
699}
700
701#[cfg(test)]
702mod tests {
703    use super::*;
704    use serde_path_to_error::deserialize;
705
706    #[test]
707    fn test_deserialize_message() {
708        let assistant_message_json = r#"
709        {
710            "role": "assistant",
711            "content": "\n\nHello there, how may I assist you today?"
712        }
713        "#;
714
715        let assistant_message_json2 = r#"
716        {
717            "role": "assistant",
718            "content": [
719                {
720                    "type": "text",
721                    "text": "\n\nHello there, how may I assist you today?"
722                },
723                {
724                    "type": "tool_use",
725                    "id": "toolu_01A09q90qw90lq917835lq9",
726                    "name": "get_weather",
727                    "input": {"location": "San Francisco, CA"}
728                }
729            ]
730        }
731        "#;
732
733        let user_message_json = r#"
734        {
735            "role": "user",
736            "content": [
737                {
738                    "type": "image",
739                    "source": {
740                        "type": "base64",
741                        "media_type": "image/jpeg",
742                        "data": "/9j/4AAQSkZJRg..."
743                    }
744                },
745                {
746                    "type": "text",
747                    "text": "What is in this image?"
748                },
749                {
750                    "type": "tool_result",
751                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
752                    "content": "15 degrees"
753                }
754            ]
755        }
756        "#;
757
758        let assistant_message: Message = {
759            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
760            deserialize(jd).unwrap_or_else(|err| {
761                panic!("Deserialization error at {}: {}", err.path(), err);
762            })
763        };
764
765        let assistant_message2: Message = {
766            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
767            deserialize(jd).unwrap_or_else(|err| {
768                panic!("Deserialization error at {}: {}", err.path(), err);
769            })
770        };
771
772        let user_message: Message = {
773            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
774            deserialize(jd).unwrap_or_else(|err| {
775                panic!("Deserialization error at {}: {}", err.path(), err);
776            })
777        };
778
779        let Message { role, content } = assistant_message;
780        assert_eq!(role, Role::Assistant);
781        assert_eq!(
782            content.first(),
783            Content::Text {
784                text: "\n\nHello there, how may I assist you today?".to_owned()
785            }
786        );
787
788        let Message { role, content } = assistant_message2;
789        {
790            assert_eq!(role, Role::Assistant);
791            assert_eq!(content.len(), 2);
792
793            let mut iter = content.into_iter();
794
795            match iter.next().unwrap() {
796                Content::Text { text } => {
797                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
798                }
799                _ => panic!("Expected text content"),
800            }
801
802            match iter.next().unwrap() {
803                Content::ToolUse { id, name, input } => {
804                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
805                    assert_eq!(name, "get_weather");
806                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
807                }
808                _ => panic!("Expected tool use content"),
809            }
810
811            assert_eq!(iter.next(), None);
812        }
813
814        let Message { role, content } = user_message;
815        {
816            assert_eq!(role, Role::User);
817            assert_eq!(content.len(), 3);
818
819            let mut iter = content.into_iter();
820
821            match iter.next().unwrap() {
822                Content::Image { source } => {
823                    assert_eq!(
824                        source,
825                        ImageSource {
826                            data: "/9j/4AAQSkZJRg...".to_owned(),
827                            media_type: ImageFormat::JPEG,
828                            r#type: SourceType::BASE64,
829                        }
830                    );
831                }
832                _ => panic!("Expected image content"),
833            }
834
835            match iter.next().unwrap() {
836                Content::Text { text } => {
837                    assert_eq!(text, "What is in this image?");
838                }
839                _ => panic!("Expected text content"),
840            }
841
842            match iter.next().unwrap() {
843                Content::ToolResult {
844                    tool_use_id,
845                    content,
846                    is_error,
847                } => {
848                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
849                    assert_eq!(
850                        content.first(),
851                        ToolResultContent::Text {
852                            text: "15 degrees".to_owned()
853                        }
854                    );
855                    assert_eq!(is_error, None);
856                }
857                _ => panic!("Expected tool result content"),
858            }
859
860            assert_eq!(iter.next(), None);
861        }
862    }
863
864    #[test]
865    fn test_message_to_message_conversion() {
866        let user_message: Message = serde_json::from_str(
867            r#"
868        {
869            "role": "user",
870            "content": [
871                {
872                    "type": "image",
873                    "source": {
874                        "type": "base64",
875                        "media_type": "image/jpeg",
876                        "data": "/9j/4AAQSkZJRg..."
877                    }
878                },
879                {
880                    "type": "text",
881                    "text": "What is in this image?"
882                },
883                {
884                    "type": "document",
885                    "source": {
886                        "type": "base64",
887                        "data": "base64_encoded_pdf_data",
888                        "media_type": "application/pdf"
889                    }
890                }
891            ]
892        }
893        "#,
894        )
895        .unwrap();
896
897        let assistant_message = Message {
898            role: Role::Assistant,
899            content: OneOrMany::one(Content::ToolUse {
900                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
901                name: "get_weather".to_string(),
902                input: json!({"location": "San Francisco, CA"}),
903            }),
904        };
905
906        let tool_message = Message {
907            role: Role::User,
908            content: OneOrMany::one(Content::ToolResult {
909                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
910                content: OneOrMany::one(ToolResultContent::Text {
911                    text: "15 degrees".to_string(),
912                }),
913                is_error: None,
914            }),
915        };
916
917        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
918        let converted_assistant_message: message::Message =
919            assistant_message.clone().try_into().unwrap();
920        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
921
922        match converted_user_message.clone() {
923            message::Message::User { content } => {
924                assert_eq!(content.len(), 3);
925
926                let mut iter = content.into_iter();
927
928                match iter.next().unwrap() {
929                    message::UserContent::Image(message::Image {
930                        data,
931                        format,
932                        media_type,
933                        ..
934                    }) => {
935                        assert_eq!(data, "/9j/4AAQSkZJRg...");
936                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
937                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
938                    }
939                    _ => panic!("Expected image content"),
940                }
941
942                match iter.next().unwrap() {
943                    message::UserContent::Text(message::Text { text }) => {
944                        assert_eq!(text, "What is in this image?");
945                    }
946                    _ => panic!("Expected text content"),
947                }
948
949                match iter.next().unwrap() {
950                    message::UserContent::Document(message::Document {
951                        data,
952                        format,
953                        media_type,
954                    }) => {
955                        assert_eq!(data, "base64_encoded_pdf_data");
956                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
957                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
958                    }
959                    _ => panic!("Expected document content"),
960                }
961
962                assert_eq!(iter.next(), None);
963            }
964            _ => panic!("Expected user message"),
965        }
966
967        match converted_tool_message.clone() {
968            message::Message::User { content } => {
969                let message::ToolResult { id, content, .. } = match content.first() {
970                    message::UserContent::ToolResult(tool_result) => tool_result,
971                    _ => panic!("Expected tool result content"),
972                };
973                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
974                match content.first() {
975                    message::ToolResultContent::Text(message::Text { text }) => {
976                        assert_eq!(text, "15 degrees");
977                    }
978                    _ => panic!("Expected text content"),
979                }
980            }
981            _ => panic!("Expected tool result content"),
982        }
983
984        match converted_assistant_message.clone() {
985            message::Message::Assistant { content, .. } => {
986                assert_eq!(content.len(), 1);
987
988                match content.first() {
989                    message::AssistantContent::ToolCall(message::ToolCall {
990                        id, function, ..
991                    }) => {
992                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
993                        assert_eq!(function.name, "get_weather");
994                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
995                    }
996                    _ => panic!("Expected tool call content"),
997                }
998            }
999            _ => panic!("Expected assistant message"),
1000        }
1001
1002        let original_user_message: Message = converted_user_message.try_into().unwrap();
1003        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1004        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1005
1006        assert_eq!(user_message, original_user_message);
1007        assert_eq!(assistant_message, original_assistant_message);
1008        assert_eq!(tool_message, original_tool_message);
1009    }
1010}