rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError},
6    json_utils,
7    message::{self, DocumentMediaType, MessageError, Reasoning},
8    one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
18// ================================================================
19// Anthropic Completion API
20// ================================================================
21
/// `claude-opus-4-0` completion model
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// `claude-sonnet-4-0` completion model
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version `2023-01-01`.
///
/// NOTE(review): presumably sent as the `anthropic-version` header by the client — confirm in
/// `client.rs`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";

/// Anthropic API version `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";

/// The most recent API version known to this module (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
49
50#[derive(Debug, Deserialize, Serialize)]
51pub struct CompletionResponse {
52    pub content: Vec<Content>,
53    pub id: String,
54    pub model: String,
55    pub role: String,
56    pub stop_reason: Option<String>,
57    pub stop_sequence: Option<String>,
58    pub usage: Usage,
59}
60
61#[derive(Debug, Deserialize, Serialize)]
62pub struct Usage {
63    pub input_tokens: u64,
64    pub cache_read_input_tokens: Option<u64>,
65    pub cache_creation_input_tokens: Option<u64>,
66    pub output_tokens: u64,
67}
68
69impl std::fmt::Display for Usage {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(
72            f,
73            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
74            self.input_tokens,
75            match self.cache_read_input_tokens {
76                Some(token) => token.to_string(),
77                None => "n/a".to_string(),
78            },
79            match self.cache_creation_input_tokens {
80                Some(token) => token.to_string(),
81                None => "n/a".to_string(),
82            },
83            self.output_tokens
84        )
85    }
86}
87
88#[derive(Debug, Deserialize, Serialize)]
89pub struct ToolDefinition {
90    pub name: String,
91    pub description: Option<String>,
92    pub input_schema: serde_json::Value,
93}
94
95#[derive(Debug, Deserialize, Serialize)]
96#[serde(tag = "type", rename_all = "snake_case")]
97pub enum CacheControl {
98    Ephemeral,
99}
100
101impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
102    type Error = CompletionError;
103
104    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
105        let content = response
106            .content
107            .iter()
108            .map(|content| {
109                Ok(match content {
110                    Content::Text { text } => completion::AssistantContent::text(text),
111                    Content::ToolUse { id, name, input } => {
112                        completion::AssistantContent::tool_call(id, name, input.clone())
113                    }
114                    _ => {
115                        return Err(CompletionError::ResponseError(
116                            "Response did not contain a message or tool call".into(),
117                        ));
118                    }
119                })
120            })
121            .collect::<Result<Vec<_>, _>>()?;
122
123        let choice = OneOrMany::many(content).map_err(|_| {
124            CompletionError::ResponseError(
125                "Response contained no message or tool call (empty)".to_owned(),
126            )
127        })?;
128
129        let usage = completion::Usage {
130            input_tokens: response.usage.input_tokens,
131            output_tokens: response.usage.output_tokens,
132            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
133        };
134
135        Ok(completion::CompletionResponse {
136            choice,
137            usage,
138            raw_response: response,
139        })
140    }
141}
142
143#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
144pub struct Message {
145    pub role: Role,
146    #[serde(deserialize_with = "string_or_one_or_many")]
147    pub content: OneOrMany<Content>,
148}
149
150#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
151#[serde(rename_all = "lowercase")]
152pub enum Role {
153    User,
154    Assistant,
155}
156
157#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
158#[serde(tag = "type", rename_all = "snake_case")]
159pub enum Content {
160    Text {
161        text: String,
162    },
163    Image {
164        source: ImageSource,
165    },
166    ToolUse {
167        id: String,
168        name: String,
169        input: serde_json::Value,
170    },
171    ToolResult {
172        tool_use_id: String,
173        #[serde(deserialize_with = "string_or_one_or_many")]
174        content: OneOrMany<ToolResultContent>,
175        #[serde(skip_serializing_if = "Option::is_none")]
176        is_error: Option<bool>,
177    },
178    Document {
179        source: DocumentSource,
180    },
181    Thinking {
182        thinking: String,
183        signature: Option<String>,
184    },
185}
186
187impl FromStr for Content {
188    type Err = Infallible;
189
190    fn from_str(s: &str) -> Result<Self, Self::Err> {
191        Ok(Content::Text { text: s.to_owned() })
192    }
193}
194
195#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
196#[serde(tag = "type", rename_all = "snake_case")]
197pub enum ToolResultContent {
198    Text { text: String },
199    Image(ImageSource),
200}
201
202impl FromStr for ToolResultContent {
203    type Err = Infallible;
204
205    fn from_str(s: &str) -> Result<Self, Self::Err> {
206        Ok(ToolResultContent::Text { text: s.to_owned() })
207    }
208}
209
210#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
211pub struct ImageSource {
212    pub data: String,
213    pub media_type: ImageFormat,
214    pub r#type: SourceType,
215}
216
217#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
218pub struct DocumentSource {
219    pub data: String,
220    pub media_type: DocumentFormat,
221    pub r#type: SourceType,
222}
223
224#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
225#[serde(rename_all = "lowercase")]
226pub enum ImageFormat {
227    #[serde(rename = "image/jpeg")]
228    JPEG,
229    #[serde(rename = "image/png")]
230    PNG,
231    #[serde(rename = "image/gif")]
232    GIF,
233    #[serde(rename = "image/webp")]
234    WEBP,
235}
236
237/// The document format to be used.
238///
239/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
240#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
241#[serde(rename_all = "lowercase")]
242pub enum DocumentFormat {
243    #[serde(rename = "application/pdf")]
244    PDF,
245}
246
247#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
248#[serde(rename_all = "lowercase")]
249pub enum SourceType {
250    BASE64,
251}
252
253impl From<String> for Content {
254    fn from(text: String) -> Self {
255        Content::Text { text }
256    }
257}
258
259impl From<String> for ToolResultContent {
260    fn from(text: String) -> Self {
261        ToolResultContent::Text { text }
262    }
263}
264
265impl TryFrom<message::ContentFormat> for SourceType {
266    type Error = MessageError;
267
268    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
269        match format {
270            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
271            message::ContentFormat::String => Err(MessageError::ConversionError(
272                "Image urls are not supported in Anthropic".to_owned(),
273            )),
274        }
275    }
276}
277
278impl From<SourceType> for message::ContentFormat {
279    fn from(source_type: SourceType) -> Self {
280        match source_type {
281            SourceType::BASE64 => message::ContentFormat::Base64,
282        }
283    }
284}
285
286impl TryFrom<message::ImageMediaType> for ImageFormat {
287    type Error = MessageError;
288
289    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
290        Ok(match media_type {
291            message::ImageMediaType::JPEG => ImageFormat::JPEG,
292            message::ImageMediaType::PNG => ImageFormat::PNG,
293            message::ImageMediaType::GIF => ImageFormat::GIF,
294            message::ImageMediaType::WEBP => ImageFormat::WEBP,
295            _ => {
296                return Err(MessageError::ConversionError(
297                    format!("Unsupported image media type: {media_type:?}").to_owned(),
298                ));
299            }
300        })
301    }
302}
303
304impl From<ImageFormat> for message::ImageMediaType {
305    fn from(format: ImageFormat) -> Self {
306        match format {
307            ImageFormat::JPEG => message::ImageMediaType::JPEG,
308            ImageFormat::PNG => message::ImageMediaType::PNG,
309            ImageFormat::GIF => message::ImageMediaType::GIF,
310            ImageFormat::WEBP => message::ImageMediaType::WEBP,
311        }
312    }
313}
314
315impl TryFrom<DocumentMediaType> for DocumentFormat {
316    type Error = MessageError;
317    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
318        if !matches!(value, DocumentMediaType::PDF) {
319            return Err(MessageError::ConversionError(
320                "Anthropic only supports PDF documents".to_string(),
321            ));
322        };
323
324        Ok(DocumentFormat::PDF)
325    }
326}
327
328impl From<message::AssistantContent> for Content {
329    fn from(text: message::AssistantContent) -> Self {
330        match text {
331            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
332            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
333                Content::ToolUse {
334                    id,
335                    name: function.name,
336                    input: function.arguments,
337                }
338            }
339            message::AssistantContent::Reasoning(Reasoning { reasoning, id }) => {
340                Content::Thinking {
341                    thinking: reasoning.first().cloned().unwrap_or(String::new()),
342                    signature: id,
343                }
344            }
345        }
346    }
347}
348
349impl TryFrom<message::Message> for Message {
350    type Error = MessageError;
351
352    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
353        Ok(match message {
354            message::Message::User { content } => Message {
355                role: Role::User,
356                content: content.try_map(|content| match content {
357                    message::UserContent::Text(message::Text { text }) => {
358                        Ok(Content::Text { text })
359                    }
360                    message::UserContent::ToolResult(message::ToolResult {
361                        id, content, ..
362                    }) => Ok(Content::ToolResult {
363                        tool_use_id: id,
364                        content: content.try_map(|content| match content {
365                            message::ToolResultContent::Text(message::Text { text }) => {
366                                Ok(ToolResultContent::Text { text })
367                            }
368                            message::ToolResultContent::Image(image) => {
369                                let media_type =
370                                    image.media_type.ok_or(MessageError::ConversionError(
371                                        "Image media type is required".to_owned(),
372                                    ))?;
373                                let format = image.format.ok_or(MessageError::ConversionError(
374                                    "Image format is required".to_owned(),
375                                ))?;
376                                Ok(ToolResultContent::Image(ImageSource {
377                                    data: image.data,
378                                    media_type: media_type.try_into()?,
379                                    r#type: format.try_into()?,
380                                }))
381                            }
382                        })?,
383                        is_error: None,
384                    }),
385                    message::UserContent::Image(message::Image {
386                        data,
387                        format,
388                        media_type,
389                        ..
390                    }) => {
391                        let source = ImageSource {
392                            data,
393                            media_type: match media_type {
394                                Some(media_type) => media_type.try_into()?,
395                                None => {
396                                    return Err(MessageError::ConversionError(
397                                        "Image media type is required".to_owned(),
398                                    ));
399                                }
400                            },
401                            r#type: match format {
402                                Some(format) => format.try_into()?,
403                                None => SourceType::BASE64,
404                            },
405                        };
406                        Ok(Content::Image { source })
407                    }
408                    message::UserContent::Document(message::Document {
409                        data,
410                        format,
411                        media_type,
412                        ..
413                    }) => {
414                        let Some(media_type) = media_type else {
415                            return Err(MessageError::ConversionError(
416                                "Document media type is required".to_string(),
417                            ));
418                        };
419
420                        let source = DocumentSource {
421                            data,
422                            media_type: media_type.try_into()?,
423                            r#type: match format {
424                                Some(format) => format.try_into()?,
425                                None => SourceType::BASE64,
426                            },
427                        };
428                        Ok(Content::Document { source })
429                    }
430                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
431                        "Audio is not supported in Anthropic".to_owned(),
432                    )),
433                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
434                        "Audio is not supported in Anthropic".to_owned(),
435                    )),
436                })?,
437            },
438
439            message::Message::Assistant { content, .. } => Message {
440                content: content.map(|content| content.into()),
441                role: Role::Assistant,
442            },
443        })
444    }
445}
446
447impl TryFrom<Content> for message::AssistantContent {
448    type Error = MessageError;
449
450    fn try_from(content: Content) -> Result<Self, Self::Error> {
451        Ok(match content {
452            Content::Text { text } => message::AssistantContent::text(text),
453            Content::ToolUse { id, name, input } => {
454                message::AssistantContent::tool_call(id, name, input)
455            }
456            Content::Thinking {
457                thinking,
458                signature,
459            } => message::AssistantContent::Reasoning(
460                Reasoning::new(&thinking).optional_id(signature),
461            ),
462            _ => {
463                return Err(MessageError::ConversionError(
464                    format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
465                ));
466            }
467        })
468    }
469}
470
471impl From<ToolResultContent> for message::ToolResultContent {
472    fn from(content: ToolResultContent) -> Self {
473        match content {
474            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
475            ToolResultContent::Image(ImageSource {
476                data,
477                media_type: format,
478                r#type,
479            }) => message::ToolResultContent::image(
480                data,
481                Some(r#type.into()),
482                Some(format.into()),
483                None,
484            ),
485        }
486    }
487}
488
489impl TryFrom<Message> for message::Message {
490    type Error = MessageError;
491
492    fn try_from(message: Message) -> Result<Self, Self::Error> {
493        Ok(match message.role {
494            Role::User => message::Message::User {
495                content: message.content.try_map(|content| {
496                    Ok(match content {
497                        Content::Text { text } => message::UserContent::text(text),
498                        Content::ToolResult {
499                            tool_use_id,
500                            content,
501                            ..
502                        } => message::UserContent::tool_result(
503                            tool_use_id,
504                            content.map(|content| content.into()),
505                        ),
506                        Content::Image { source } => message::UserContent::Image(message::Image {
507                            data: source.data,
508                            format: Some(message::ContentFormat::Base64),
509                            media_type: Some(source.media_type.into()),
510                            detail: None,
511                            additional_params: None,
512                        }),
513                        Content::Document { source } => message::UserContent::document(
514                            source.data,
515                            Some(message::ContentFormat::Base64),
516                            Some(message::DocumentMediaType::PDF),
517                        ),
518                        _ => {
519                            return Err(MessageError::ConversionError(
520                                "Unsupported content type for User role".to_owned(),
521                            ));
522                        }
523                    })
524                })?,
525            },
526            Role::Assistant => match message.content.first() {
527                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
528                    message::Message::Assistant {
529                        id: None,
530                        content: message.content.try_map(|content| content.try_into())?,
531                    }
532                }
533
534                _ => {
535                    return Err(MessageError::ConversionError(
536                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
537                    ));
538                }
539            },
540        })
541    }
542}
543
544#[derive(Clone)]
545pub struct CompletionModel {
546    pub(crate) client: Client,
547    pub model: String,
548    pub default_max_tokens: Option<u64>,
549}
550
551impl CompletionModel {
552    pub fn new(client: Client, model: &str) -> Self {
553        Self {
554            client,
555            model: model.to_string(),
556            default_max_tokens: calculate_max_tokens(model),
557        }
558    }
559}
560
/// Returns the default `max_tokens` for a known Claude model, or `None` for unrecognized names.
///
/// Anthropic requires every request to set `max_tokens`, and the allowed maximum depends on the
/// model: requests that omit it, or set it too high, fail. The limits below reflect the models
/// available when this table was written.
///
/// Models are matched by name prefix, so dated and `-latest` aliases share a limit. Rows are
/// checked in order and the first match wins.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    const LIMITS: &[(&[&str], u64)] = &[
        (&["claude-opus-4"], 32000),
        (&["claude-sonnet-4", "claude-3-7-sonnet"], 64000),
        (&["claude-3-5-sonnet", "claude-3-5-haiku"], 8192),
        (&["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"], 4096),
    ];

    LIMITS
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|prefix| model.starts_with(*prefix)))
        .map(|&(_, limit)| limit)
}
582
583#[derive(Debug, Deserialize, Serialize)]
584struct Metadata {
585    user_id: Option<String>,
586}
587
588#[derive(Default, Debug, Serialize, Deserialize)]
589#[serde(tag = "type", rename_all = "snake_case")]
590pub enum ToolChoice {
591    #[default]
592    Auto,
593    Any,
594    Tool {
595        name: String,
596    },
597}
598
599impl completion::CompletionModel for CompletionModel {
600    type Response = CompletionResponse;
601    type StreamingResponse = StreamingCompletionResponse;
602
603    #[cfg_attr(feature = "worker", worker::send)]
604    async fn completion(
605        &self,
606        completion_request: completion::CompletionRequest,
607    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
608        // Note: Ideally we'd introduce provider-specific Request models to handle the
609        // specific requirements of each provider. For now, we just manually check while
610        // building the request as a raw JSON document.
611
612        // Check if max_tokens is set, required for Anthropic
613        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
614            tokens
615        } else if let Some(tokens) = self.default_max_tokens {
616            tokens
617        } else {
618            return Err(CompletionError::RequestError(
619                "`max_tokens` must be set for Anthropic".into(),
620            ));
621        };
622
623        let mut full_history = vec![];
624        if let Some(docs) = completion_request.normalized_documents() {
625            full_history.push(docs);
626        }
627        full_history.extend(completion_request.chat_history);
628
629        let full_history = full_history
630            .into_iter()
631            .map(Message::try_from)
632            .collect::<Result<Vec<Message>, _>>()?;
633
634        let mut request = json!({
635            "model": self.model,
636            "messages": full_history,
637            "max_tokens": max_tokens,
638            "system": completion_request.preamble.unwrap_or("".to_string()),
639        });
640
641        if let Some(temperature) = completion_request.temperature {
642            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
643        }
644
645        if !completion_request.tools.is_empty() {
646            json_utils::merge_inplace(
647                &mut request,
648                json!({
649                    "tools": completion_request
650                        .tools
651                        .into_iter()
652                        .map(|tool| ToolDefinition {
653                            name: tool.name,
654                            description: Some(tool.description),
655                            input_schema: tool.parameters,
656                        })
657                        .collect::<Vec<_>>(),
658                    "tool_choice": ToolChoice::Auto,
659                }),
660            );
661        }
662
663        if let Some(ref params) = completion_request.additional_params {
664            json_utils::merge_inplace(&mut request, params.clone())
665        }
666
667        tracing::debug!("Anthropic completion request: {request}");
668
669        let response = self
670            .client
671            .post("/v1/messages")
672            .json(&request)
673            .send()
674            .await?;
675
676        if response.status().is_success() {
677            match response.json::<ApiResponse<CompletionResponse>>().await? {
678                ApiResponse::Message(completion) => {
679                    tracing::info!(target: "rig",
680                        "Anthropic completion token usage: {}",
681                        completion.usage
682                    );
683                    completion.try_into()
684                }
685                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
686            }
687        } else {
688            Err(CompletionError::ProviderError(response.text().await?))
689        }
690    }
691
692    #[cfg_attr(feature = "worker", worker::send)]
693    async fn stream(
694        &self,
695        request: CompletionRequest,
696    ) -> Result<
697        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
698        CompletionError,
699    > {
700        CompletionModel::stream(self, request).await
701    }
702}
703
704#[derive(Debug, Deserialize)]
705struct ApiErrorResponse {
706    message: String,
707}
708
709#[derive(Debug, Deserialize)]
710#[serde(tag = "type", rename_all = "snake_case")]
711enum ApiResponse<T> {
712    Message(T),
713    Error(ApiErrorResponse),
714}
715
716#[cfg(test)]
717mod tests {
718    use super::*;
719    use serde_path_to_error::deserialize;
720
721    #[test]
722    fn test_deserialize_message() {
723        let assistant_message_json = r#"
724        {
725            "role": "assistant",
726            "content": "\n\nHello there, how may I assist you today?"
727        }
728        "#;
729
730        let assistant_message_json2 = r#"
731        {
732            "role": "assistant",
733            "content": [
734                {
735                    "type": "text",
736                    "text": "\n\nHello there, how may I assist you today?"
737                },
738                {
739                    "type": "tool_use",
740                    "id": "toolu_01A09q90qw90lq917835lq9",
741                    "name": "get_weather",
742                    "input": {"location": "San Francisco, CA"}
743                }
744            ]
745        }
746        "#;
747
748        let user_message_json = r#"
749        {
750            "role": "user",
751            "content": [
752                {
753                    "type": "image",
754                    "source": {
755                        "type": "base64",
756                        "media_type": "image/jpeg",
757                        "data": "/9j/4AAQSkZJRg..."
758                    }
759                },
760                {
761                    "type": "text",
762                    "text": "What is in this image?"
763                },
764                {
765                    "type": "tool_result",
766                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
767                    "content": "15 degrees"
768                }
769            ]
770        }
771        "#;
772
773        let assistant_message: Message = {
774            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
775            deserialize(jd).unwrap_or_else(|err| {
776                panic!("Deserialization error at {}: {}", err.path(), err);
777            })
778        };
779
780        let assistant_message2: Message = {
781            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
782            deserialize(jd).unwrap_or_else(|err| {
783                panic!("Deserialization error at {}: {}", err.path(), err);
784            })
785        };
786
787        let user_message: Message = {
788            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
789            deserialize(jd).unwrap_or_else(|err| {
790                panic!("Deserialization error at {}: {}", err.path(), err);
791            })
792        };
793
794        let Message { role, content } = assistant_message;
795        assert_eq!(role, Role::Assistant);
796        assert_eq!(
797            content.first(),
798            Content::Text {
799                text: "\n\nHello there, how may I assist you today?".to_owned()
800            }
801        );
802
803        let Message { role, content } = assistant_message2;
804        {
805            assert_eq!(role, Role::Assistant);
806            assert_eq!(content.len(), 2);
807
808            let mut iter = content.into_iter();
809
810            match iter.next().unwrap() {
811                Content::Text { text } => {
812                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
813                }
814                _ => panic!("Expected text content"),
815            }
816
817            match iter.next().unwrap() {
818                Content::ToolUse { id, name, input } => {
819                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
820                    assert_eq!(name, "get_weather");
821                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
822                }
823                _ => panic!("Expected tool use content"),
824            }
825
826            assert_eq!(iter.next(), None);
827        }
828
829        let Message { role, content } = user_message;
830        {
831            assert_eq!(role, Role::User);
832            assert_eq!(content.len(), 3);
833
834            let mut iter = content.into_iter();
835
836            match iter.next().unwrap() {
837                Content::Image { source } => {
838                    assert_eq!(
839                        source,
840                        ImageSource {
841                            data: "/9j/4AAQSkZJRg...".to_owned(),
842                            media_type: ImageFormat::JPEG,
843                            r#type: SourceType::BASE64,
844                        }
845                    );
846                }
847                _ => panic!("Expected image content"),
848            }
849
850            match iter.next().unwrap() {
851                Content::Text { text } => {
852                    assert_eq!(text, "What is in this image?");
853                }
854                _ => panic!("Expected text content"),
855            }
856
857            match iter.next().unwrap() {
858                Content::ToolResult {
859                    tool_use_id,
860                    content,
861                    is_error,
862                } => {
863                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
864                    assert_eq!(
865                        content.first(),
866                        ToolResultContent::Text {
867                            text: "15 degrees".to_owned()
868                        }
869                    );
870                    assert_eq!(is_error, None);
871                }
872                _ => panic!("Expected tool result content"),
873            }
874
875            assert_eq!(iter.next(), None);
876        }
877    }
878
879    #[test]
880    fn test_message_to_message_conversion() {
881        let user_message: Message = serde_json::from_str(
882            r#"
883        {
884            "role": "user",
885            "content": [
886                {
887                    "type": "image",
888                    "source": {
889                        "type": "base64",
890                        "media_type": "image/jpeg",
891                        "data": "/9j/4AAQSkZJRg..."
892                    }
893                },
894                {
895                    "type": "text",
896                    "text": "What is in this image?"
897                },
898                {
899                    "type": "document",
900                    "source": {
901                        "type": "base64",
902                        "data": "base64_encoded_pdf_data",
903                        "media_type": "application/pdf"
904                    }
905                }
906            ]
907        }
908        "#,
909        )
910        .unwrap();
911
912        let assistant_message = Message {
913            role: Role::Assistant,
914            content: OneOrMany::one(Content::ToolUse {
915                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
916                name: "get_weather".to_string(),
917                input: json!({"location": "San Francisco, CA"}),
918            }),
919        };
920
921        let tool_message = Message {
922            role: Role::User,
923            content: OneOrMany::one(Content::ToolResult {
924                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
925                content: OneOrMany::one(ToolResultContent::Text {
926                    text: "15 degrees".to_string(),
927                }),
928                is_error: None,
929            }),
930        };
931
932        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
933        let converted_assistant_message: message::Message =
934            assistant_message.clone().try_into().unwrap();
935        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
936
937        match converted_user_message.clone() {
938            message::Message::User { content } => {
939                assert_eq!(content.len(), 3);
940
941                let mut iter = content.into_iter();
942
943                match iter.next().unwrap() {
944                    message::UserContent::Image(message::Image {
945                        data,
946                        format,
947                        media_type,
948                        ..
949                    }) => {
950                        assert_eq!(data, "/9j/4AAQSkZJRg...");
951                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
952                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
953                    }
954                    _ => panic!("Expected image content"),
955                }
956
957                match iter.next().unwrap() {
958                    message::UserContent::Text(message::Text { text }) => {
959                        assert_eq!(text, "What is in this image?");
960                    }
961                    _ => panic!("Expected text content"),
962                }
963
964                match iter.next().unwrap() {
965                    message::UserContent::Document(message::Document {
966                        data,
967                        format,
968                        media_type,
969                        ..
970                    }) => {
971                        assert_eq!(data, "base64_encoded_pdf_data");
972                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
973                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
974                    }
975                    _ => panic!("Expected document content"),
976                }
977
978                assert_eq!(iter.next(), None);
979            }
980            _ => panic!("Expected user message"),
981        }
982
983        match converted_tool_message.clone() {
984            message::Message::User { content } => {
985                let message::ToolResult { id, content, .. } = match content.first() {
986                    message::UserContent::ToolResult(tool_result) => tool_result,
987                    _ => panic!("Expected tool result content"),
988                };
989                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
990                match content.first() {
991                    message::ToolResultContent::Text(message::Text { text }) => {
992                        assert_eq!(text, "15 degrees");
993                    }
994                    _ => panic!("Expected text content"),
995                }
996            }
997            _ => panic!("Expected tool result content"),
998        }
999
1000        match converted_assistant_message.clone() {
1001            message::Message::Assistant { content, .. } => {
1002                assert_eq!(content.len(), 1);
1003
1004                match content.first() {
1005                    message::AssistantContent::ToolCall(message::ToolCall {
1006                        id, function, ..
1007                    }) => {
1008                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1009                        assert_eq!(function.name, "get_weather");
1010                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1011                    }
1012                    _ => panic!("Expected tool call content"),
1013                }
1014            }
1015            _ => panic!("Expected assistant message"),
1016        }
1017
1018        let original_user_message: Message = converted_user_message.try_into().unwrap();
1019        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1020        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1021
1022        assert_eq!(user_message, original_user_message);
1023        assert_eq!(assistant_message, original_assistant_message);
1024        assert_eq!(tool_message, original_tool_message);
1025    }
1026}