rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError},
6    json_utils,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8    one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
18// ================================================================
19// Anthropic Completion API
20// ================================================================
21
/// `claude-opus-4-0` completion model
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// `claude-sonnet-4-0` completion model
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version identifier `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version identifier `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version defined in this module (currently [`ANTHROPIC_VERSION_2023_06_01`]).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
49
50#[derive(Debug, Deserialize, Serialize)]
51pub struct CompletionResponse {
52    pub content: Vec<Content>,
53    pub id: String,
54    pub model: String,
55    pub role: String,
56    pub stop_reason: Option<String>,
57    pub stop_sequence: Option<String>,
58    pub usage: Usage,
59}
60
61#[derive(Debug, Deserialize, Serialize)]
62pub struct Usage {
63    pub input_tokens: u64,
64    pub cache_read_input_tokens: Option<u64>,
65    pub cache_creation_input_tokens: Option<u64>,
66    pub output_tokens: u64,
67}
68
69impl std::fmt::Display for Usage {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(
72            f,
73            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
74            self.input_tokens,
75            match self.cache_read_input_tokens {
76                Some(token) => token.to_string(),
77                None => "n/a".to_string(),
78            },
79            match self.cache_creation_input_tokens {
80                Some(token) => token.to_string(),
81                None => "n/a".to_string(),
82            },
83            self.output_tokens
84        )
85    }
86}
87
88#[derive(Debug, Deserialize, Serialize)]
89pub struct ToolDefinition {
90    pub name: String,
91    pub description: Option<String>,
92    pub input_schema: serde_json::Value,
93}
94
95#[derive(Debug, Deserialize, Serialize)]
96#[serde(tag = "type", rename_all = "snake_case")]
97pub enum CacheControl {
98    Ephemeral,
99}
100
101impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
102    type Error = CompletionError;
103
104    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
105        let content = response
106            .content
107            .iter()
108            .map(|content| {
109                Ok(match content {
110                    Content::Text { text } => completion::AssistantContent::text(text),
111                    Content::ToolUse { id, name, input } => {
112                        completion::AssistantContent::tool_call(id, name, input.clone())
113                    }
114                    _ => {
115                        return Err(CompletionError::ResponseError(
116                            "Response did not contain a message or tool call".into(),
117                        ));
118                    }
119                })
120            })
121            .collect::<Result<Vec<_>, _>>()?;
122
123        let choice = OneOrMany::many(content).map_err(|_| {
124            CompletionError::ResponseError(
125                "Response contained no message or tool call (empty)".to_owned(),
126            )
127        })?;
128
129        let usage = completion::Usage {
130            input_tokens: response.usage.input_tokens,
131            output_tokens: response.usage.output_tokens,
132            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
133        };
134
135        Ok(completion::CompletionResponse {
136            choice,
137            usage,
138            raw_response: response,
139        })
140    }
141}
142
143#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
144pub struct Message {
145    pub role: Role,
146    #[serde(deserialize_with = "string_or_one_or_many")]
147    pub content: OneOrMany<Content>,
148}
149
150#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
151#[serde(rename_all = "lowercase")]
152pub enum Role {
153    User,
154    Assistant,
155}
156
157#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
158#[serde(tag = "type", rename_all = "snake_case")]
159pub enum Content {
160    Text {
161        text: String,
162    },
163    Image {
164        source: ImageSource,
165    },
166    ToolUse {
167        id: String,
168        name: String,
169        input: serde_json::Value,
170    },
171    ToolResult {
172        tool_use_id: String,
173        #[serde(deserialize_with = "string_or_one_or_many")]
174        content: OneOrMany<ToolResultContent>,
175        #[serde(skip_serializing_if = "Option::is_none")]
176        is_error: Option<bool>,
177    },
178    Document {
179        source: DocumentSource,
180    },
181    Thinking {
182        thinking: String,
183        signature: Option<String>,
184    },
185}
186
187impl FromStr for Content {
188    type Err = Infallible;
189
190    fn from_str(s: &str) -> Result<Self, Self::Err> {
191        Ok(Content::Text { text: s.to_owned() })
192    }
193}
194
195#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
196#[serde(tag = "type", rename_all = "snake_case")]
197pub enum ToolResultContent {
198    Text { text: String },
199    Image(ImageSource),
200}
201
202impl FromStr for ToolResultContent {
203    type Err = Infallible;
204
205    fn from_str(s: &str) -> Result<Self, Self::Err> {
206        Ok(ToolResultContent::Text { text: s.to_owned() })
207    }
208}
209
210#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
211pub struct ImageSource {
212    pub data: String,
213    pub media_type: ImageFormat,
214    pub r#type: SourceType,
215}
216
217#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
218pub struct DocumentSource {
219    pub data: String,
220    pub media_type: DocumentFormat,
221    pub r#type: SourceType,
222}
223
224#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
225#[serde(rename_all = "lowercase")]
226pub enum ImageFormat {
227    #[serde(rename = "image/jpeg")]
228    JPEG,
229    #[serde(rename = "image/png")]
230    PNG,
231    #[serde(rename = "image/gif")]
232    GIF,
233    #[serde(rename = "image/webp")]
234    WEBP,
235}
236
237/// The document format to be used.
238///
239/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
240#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
241#[serde(rename_all = "lowercase")]
242pub enum DocumentFormat {
243    #[serde(rename = "application/pdf")]
244    PDF,
245}
246
247#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
248#[serde(rename_all = "lowercase")]
249pub enum SourceType {
250    BASE64,
251}
252
253impl From<String> for Content {
254    fn from(text: String) -> Self {
255        Content::Text { text }
256    }
257}
258
259impl From<String> for ToolResultContent {
260    fn from(text: String) -> Self {
261        ToolResultContent::Text { text }
262    }
263}
264
265impl TryFrom<message::ContentFormat> for SourceType {
266    type Error = MessageError;
267
268    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
269        match format {
270            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
271            message::ContentFormat::String => Err(MessageError::ConversionError(
272                "Image urls are not supported in Anthropic".to_owned(),
273            )),
274        }
275    }
276}
277
278impl From<SourceType> for message::ContentFormat {
279    fn from(source_type: SourceType) -> Self {
280        match source_type {
281            SourceType::BASE64 => message::ContentFormat::Base64,
282        }
283    }
284}
285
286impl TryFrom<message::ImageMediaType> for ImageFormat {
287    type Error = MessageError;
288
289    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
290        Ok(match media_type {
291            message::ImageMediaType::JPEG => ImageFormat::JPEG,
292            message::ImageMediaType::PNG => ImageFormat::PNG,
293            message::ImageMediaType::GIF => ImageFormat::GIF,
294            message::ImageMediaType::WEBP => ImageFormat::WEBP,
295            _ => {
296                return Err(MessageError::ConversionError(
297                    format!("Unsupported image media type: {media_type:?}").to_owned(),
298                ));
299            }
300        })
301    }
302}
303
304impl From<ImageFormat> for message::ImageMediaType {
305    fn from(format: ImageFormat) -> Self {
306        match format {
307            ImageFormat::JPEG => message::ImageMediaType::JPEG,
308            ImageFormat::PNG => message::ImageMediaType::PNG,
309            ImageFormat::GIF => message::ImageMediaType::GIF,
310            ImageFormat::WEBP => message::ImageMediaType::WEBP,
311        }
312    }
313}
314
315impl TryFrom<DocumentMediaType> for DocumentFormat {
316    type Error = MessageError;
317    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
318        if !matches!(value, DocumentMediaType::PDF) {
319            return Err(MessageError::ConversionError(
320                "Anthropic only supports PDF documents".to_string(),
321            ));
322        };
323
324        Ok(DocumentFormat::PDF)
325    }
326}
327
328impl From<message::AssistantContent> for Content {
329    fn from(text: message::AssistantContent) -> Self {
330        match text {
331            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
332            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
333                Content::ToolUse {
334                    id,
335                    name: function.name,
336                    input: function.arguments,
337                }
338            }
339            message::AssistantContent::Reasoning(Reasoning { reasoning, id }) => {
340                Content::Thinking {
341                    thinking: reasoning.first().cloned().unwrap_or(String::new()),
342                    signature: id,
343                }
344            }
345        }
346    }
347}
348
349impl TryFrom<message::Message> for Message {
350    type Error = MessageError;
351
352    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
353        Ok(match message {
354            message::Message::User { content } => Message {
355                role: Role::User,
356                content: content.try_map(|content| match content {
357                    message::UserContent::Text(message::Text { text }) => {
358                        Ok(Content::Text { text })
359                    }
360                    message::UserContent::ToolResult(message::ToolResult {
361                        id, content, ..
362                    }) => Ok(Content::ToolResult {
363                        tool_use_id: id,
364                        content: content.try_map(|content| match content {
365                            message::ToolResultContent::Text(message::Text { text }) => {
366                                Ok(ToolResultContent::Text { text })
367                            }
368                            message::ToolResultContent::Image(image) => {
369                                let DocumentSourceKind::Base64(data) = image.data else {
370                                    return Err(MessageError::ConversionError(
371                                        "Only base64 strings can be used with the Anthropic API"
372                                            .to_string(),
373                                    ));
374                                };
375                                let media_type =
376                                    image.media_type.ok_or(MessageError::ConversionError(
377                                        "Image media type is required".to_owned(),
378                                    ))?;
379                                Ok(ToolResultContent::Image(ImageSource {
380                                    data,
381                                    media_type: media_type.try_into()?,
382                                    r#type: SourceType::BASE64,
383                                }))
384                            }
385                        })?,
386                        is_error: None,
387                    }),
388                    message::UserContent::Image(message::Image {
389                        data, media_type, ..
390                    }) => {
391                        let DocumentSourceKind::Base64(data) = data else {
392                            return Err(MessageError::ConversionError(
393                                "Only base64 strings are allowed in the Anthropic API".to_string(),
394                            ));
395                        };
396
397                        let source = ImageSource {
398                            data,
399                            media_type: match media_type {
400                                Some(media_type) => media_type.try_into()?,
401                                None => {
402                                    return Err(MessageError::ConversionError(
403                                        "Image media type is required".to_owned(),
404                                    ));
405                                }
406                            },
407                            r#type: SourceType::BASE64,
408                        };
409                        Ok(Content::Image { source })
410                    }
411                    message::UserContent::Document(message::Document {
412                        data,
413                        format,
414                        media_type,
415                        ..
416                    }) => {
417                        let Some(media_type) = media_type else {
418                            return Err(MessageError::ConversionError(
419                                "Document media type is required".to_string(),
420                            ));
421                        };
422
423                        let source = DocumentSource {
424                            data,
425                            media_type: media_type.try_into()?,
426                            r#type: match format {
427                                Some(format) => format.try_into()?,
428                                None => SourceType::BASE64,
429                            },
430                        };
431                        Ok(Content::Document { source })
432                    }
433                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
434                        "Audio is not supported in Anthropic".to_owned(),
435                    )),
436                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
437                        "Audio is not supported in Anthropic".to_owned(),
438                    )),
439                })?,
440            },
441
442            message::Message::Assistant { content, .. } => Message {
443                content: content.map(|content| content.into()),
444                role: Role::Assistant,
445            },
446        })
447    }
448}
449
450impl TryFrom<Content> for message::AssistantContent {
451    type Error = MessageError;
452
453    fn try_from(content: Content) -> Result<Self, Self::Error> {
454        Ok(match content {
455            Content::Text { text } => message::AssistantContent::text(text),
456            Content::ToolUse { id, name, input } => {
457                message::AssistantContent::tool_call(id, name, input)
458            }
459            Content::Thinking {
460                thinking,
461                signature,
462            } => message::AssistantContent::Reasoning(
463                Reasoning::new(&thinking).optional_id(signature),
464            ),
465            _ => {
466                return Err(MessageError::ConversionError(
467                    format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
468                ));
469            }
470        })
471    }
472}
473
474impl From<ToolResultContent> for message::ToolResultContent {
475    fn from(content: ToolResultContent) -> Self {
476        match content {
477            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
478            ToolResultContent::Image(ImageSource {
479                data,
480                media_type: format,
481                ..
482            }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
483        }
484    }
485}
486
487impl TryFrom<Message> for message::Message {
488    type Error = MessageError;
489
490    fn try_from(message: Message) -> Result<Self, Self::Error> {
491        Ok(match message.role {
492            Role::User => message::Message::User {
493                content: message.content.try_map(|content| {
494                    Ok(match content {
495                        Content::Text { text } => message::UserContent::text(text),
496                        Content::ToolResult {
497                            tool_use_id,
498                            content,
499                            ..
500                        } => message::UserContent::tool_result(
501                            tool_use_id,
502                            content.map(|content| content.into()),
503                        ),
504                        Content::Image { source } => message::UserContent::Image(message::Image {
505                            data: DocumentSourceKind::base64(&source.data),
506                            media_type: Some(source.media_type.into()),
507                            detail: None,
508                            additional_params: None,
509                        }),
510                        Content::Document { source } => message::UserContent::document(
511                            source.data,
512                            Some(message::ContentFormat::Base64),
513                            Some(message::DocumentMediaType::PDF),
514                        ),
515                        _ => {
516                            return Err(MessageError::ConversionError(
517                                "Unsupported content type for User role".to_owned(),
518                            ));
519                        }
520                    })
521                })?,
522            },
523            Role::Assistant => match message.content.first() {
524                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
525                    message::Message::Assistant {
526                        id: None,
527                        content: message.content.try_map(|content| content.try_into())?,
528                    }
529                }
530
531                _ => {
532                    return Err(MessageError::ConversionError(
533                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
534                    ));
535                }
536            },
537        })
538    }
539}
540
541#[derive(Clone)]
542pub struct CompletionModel {
543    pub(crate) client: Client,
544    pub model: String,
545    pub default_max_tokens: Option<u64>,
546}
547
548impl CompletionModel {
549    pub fn new(client: Client, model: &str) -> Self {
550        Self {
551            client,
552            model: model.to_string(),
553            default_max_tokens: calculate_max_tokens(model),
554        }
555    }
556}
557
/// Returns the default `max_tokens` for `model`, or `None` if the model is not recognized.
///
/// Anthropic requires every request to set `max_tokens`, and the allowed maximum
/// depends on the model: a missing or too-large value makes the request fail.
/// The limits below reflect the models available when this was written. They are
/// matched by name prefix, so dated IDs and `-latest` aliases get their family's limit.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, max output tokens), checked in order.
    const LIMITS: &[(&str, u64)] = &[
        ("claude-opus-4", 32000),
        ("claude-sonnet-4", 64000),
        ("claude-3-7-sonnet", 64000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    LIMITS
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, limit)| limit)
}
579
580#[derive(Debug, Deserialize, Serialize)]
581struct Metadata {
582    user_id: Option<String>,
583}
584
585#[derive(Default, Debug, Serialize, Deserialize)]
586#[serde(tag = "type", rename_all = "snake_case")]
587pub enum ToolChoice {
588    #[default]
589    Auto,
590    Any,
591    Tool {
592        name: String,
593    },
594}
595
596impl completion::CompletionModel for CompletionModel {
597    type Response = CompletionResponse;
598    type StreamingResponse = StreamingCompletionResponse;
599
600    #[cfg_attr(feature = "worker", worker::send)]
601    async fn completion(
602        &self,
603        completion_request: completion::CompletionRequest,
604    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
605        // Note: Ideally we'd introduce provider-specific Request models to handle the
606        // specific requirements of each provider. For now, we just manually check while
607        // building the request as a raw JSON document.
608
609        // Check if max_tokens is set, required for Anthropic
610        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
611            tokens
612        } else if let Some(tokens) = self.default_max_tokens {
613            tokens
614        } else {
615            return Err(CompletionError::RequestError(
616                "`max_tokens` must be set for Anthropic".into(),
617            ));
618        };
619
620        let mut full_history = vec![];
621        if let Some(docs) = completion_request.normalized_documents() {
622            full_history.push(docs);
623        }
624        full_history.extend(completion_request.chat_history);
625
626        let full_history = full_history
627            .into_iter()
628            .map(Message::try_from)
629            .collect::<Result<Vec<Message>, _>>()?;
630
631        let mut request = json!({
632            "model": self.model,
633            "messages": full_history,
634            "max_tokens": max_tokens,
635            "system": completion_request.preamble.unwrap_or("".to_string()),
636        });
637
638        if let Some(temperature) = completion_request.temperature {
639            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
640        }
641
642        if !completion_request.tools.is_empty() {
643            json_utils::merge_inplace(
644                &mut request,
645                json!({
646                    "tools": completion_request
647                        .tools
648                        .into_iter()
649                        .map(|tool| ToolDefinition {
650                            name: tool.name,
651                            description: Some(tool.description),
652                            input_schema: tool.parameters,
653                        })
654                        .collect::<Vec<_>>(),
655                    "tool_choice": ToolChoice::Auto,
656                }),
657            );
658        }
659
660        if let Some(ref params) = completion_request.additional_params {
661            json_utils::merge_inplace(&mut request, params.clone())
662        }
663
664        tracing::debug!("Anthropic completion request: {request}");
665
666        let response = self
667            .client
668            .post("/v1/messages")
669            .json(&request)
670            .send()
671            .await?;
672
673        if response.status().is_success() {
674            match response.json::<ApiResponse<CompletionResponse>>().await? {
675                ApiResponse::Message(completion) => {
676                    tracing::info!(target: "rig",
677                        "Anthropic completion token usage: {}",
678                        completion.usage
679                    );
680                    completion.try_into()
681                }
682                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
683            }
684        } else {
685            Err(CompletionError::ProviderError(response.text().await?))
686        }
687    }
688
689    #[cfg_attr(feature = "worker", worker::send)]
690    async fn stream(
691        &self,
692        request: CompletionRequest,
693    ) -> Result<
694        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
695        CompletionError,
696    > {
697        CompletionModel::stream(self, request).await
698    }
699}
700
701#[derive(Debug, Deserialize)]
702struct ApiErrorResponse {
703    message: String,
704}
705
706#[derive(Debug, Deserialize)]
707#[serde(tag = "type", rename_all = "snake_case")]
708enum ApiResponse<T> {
709    Message(T),
710    Error(ApiErrorResponse),
711}
712
713#[cfg(test)]
714mod tests {
715    use super::*;
716    use serde_path_to_error::deserialize;
717
718    #[test]
719    fn test_deserialize_message() {
720        let assistant_message_json = r#"
721        {
722            "role": "assistant",
723            "content": "\n\nHello there, how may I assist you today?"
724        }
725        "#;
726
727        let assistant_message_json2 = r#"
728        {
729            "role": "assistant",
730            "content": [
731                {
732                    "type": "text",
733                    "text": "\n\nHello there, how may I assist you today?"
734                },
735                {
736                    "type": "tool_use",
737                    "id": "toolu_01A09q90qw90lq917835lq9",
738                    "name": "get_weather",
739                    "input": {"location": "San Francisco, CA"}
740                }
741            ]
742        }
743        "#;
744
745        let user_message_json = r#"
746        {
747            "role": "user",
748            "content": [
749                {
750                    "type": "image",
751                    "source": {
752                        "type": "base64",
753                        "media_type": "image/jpeg",
754                        "data": "/9j/4AAQSkZJRg..."
755                    }
756                },
757                {
758                    "type": "text",
759                    "text": "What is in this image?"
760                },
761                {
762                    "type": "tool_result",
763                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
764                    "content": "15 degrees"
765                }
766            ]
767        }
768        "#;
769
770        let assistant_message: Message = {
771            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
772            deserialize(jd).unwrap_or_else(|err| {
773                panic!("Deserialization error at {}: {}", err.path(), err);
774            })
775        };
776
777        let assistant_message2: Message = {
778            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
779            deserialize(jd).unwrap_or_else(|err| {
780                panic!("Deserialization error at {}: {}", err.path(), err);
781            })
782        };
783
784        let user_message: Message = {
785            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
786            deserialize(jd).unwrap_or_else(|err| {
787                panic!("Deserialization error at {}: {}", err.path(), err);
788            })
789        };
790
791        let Message { role, content } = assistant_message;
792        assert_eq!(role, Role::Assistant);
793        assert_eq!(
794            content.first(),
795            Content::Text {
796                text: "\n\nHello there, how may I assist you today?".to_owned()
797            }
798        );
799
800        let Message { role, content } = assistant_message2;
801        {
802            assert_eq!(role, Role::Assistant);
803            assert_eq!(content.len(), 2);
804
805            let mut iter = content.into_iter();
806
807            match iter.next().unwrap() {
808                Content::Text { text } => {
809                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
810                }
811                _ => panic!("Expected text content"),
812            }
813
814            match iter.next().unwrap() {
815                Content::ToolUse { id, name, input } => {
816                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
817                    assert_eq!(name, "get_weather");
818                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
819                }
820                _ => panic!("Expected tool use content"),
821            }
822
823            assert_eq!(iter.next(), None);
824        }
825
826        let Message { role, content } = user_message;
827        {
828            assert_eq!(role, Role::User);
829            assert_eq!(content.len(), 3);
830
831            let mut iter = content.into_iter();
832
833            match iter.next().unwrap() {
834                Content::Image { source } => {
835                    assert_eq!(
836                        source,
837                        ImageSource {
838                            data: "/9j/4AAQSkZJRg...".to_owned(),
839                            media_type: ImageFormat::JPEG,
840                            r#type: SourceType::BASE64,
841                        }
842                    );
843                }
844                _ => panic!("Expected image content"),
845            }
846
847            match iter.next().unwrap() {
848                Content::Text { text } => {
849                    assert_eq!(text, "What is in this image?");
850                }
851                _ => panic!("Expected text content"),
852            }
853
854            match iter.next().unwrap() {
855                Content::ToolResult {
856                    tool_use_id,
857                    content,
858                    is_error,
859                } => {
860                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
861                    assert_eq!(
862                        content.first(),
863                        ToolResultContent::Text {
864                            text: "15 degrees".to_owned()
865                        }
866                    );
867                    assert_eq!(is_error, None);
868                }
869                _ => panic!("Expected tool result content"),
870            }
871
872            assert_eq!(iter.next(), None);
873        }
874    }
875
876    #[test]
877    fn test_message_to_message_conversion() {
878        let user_message: Message = serde_json::from_str(
879            r#"
880        {
881            "role": "user",
882            "content": [
883                {
884                    "type": "image",
885                    "source": {
886                        "type": "base64",
887                        "media_type": "image/jpeg",
888                        "data": "/9j/4AAQSkZJRg..."
889                    }
890                },
891                {
892                    "type": "text",
893                    "text": "What is in this image?"
894                },
895                {
896                    "type": "document",
897                    "source": {
898                        "type": "base64",
899                        "data": "base64_encoded_pdf_data",
900                        "media_type": "application/pdf"
901                    }
902                }
903            ]
904        }
905        "#,
906        )
907        .unwrap();
908
909        let assistant_message = Message {
910            role: Role::Assistant,
911            content: OneOrMany::one(Content::ToolUse {
912                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
913                name: "get_weather".to_string(),
914                input: json!({"location": "San Francisco, CA"}),
915            }),
916        };
917
918        let tool_message = Message {
919            role: Role::User,
920            content: OneOrMany::one(Content::ToolResult {
921                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
922                content: OneOrMany::one(ToolResultContent::Text {
923                    text: "15 degrees".to_string(),
924                }),
925                is_error: None,
926            }),
927        };
928
929        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
930        let converted_assistant_message: message::Message =
931            assistant_message.clone().try_into().unwrap();
932        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
933
934        match converted_user_message.clone() {
935            message::Message::User { content } => {
936                assert_eq!(content.len(), 3);
937
938                let mut iter = content.into_iter();
939
940                match iter.next().unwrap() {
941                    message::UserContent::Image(message::Image {
942                        data, media_type, ..
943                    }) => {
944                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
945                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
946                    }
947                    _ => panic!("Expected image content"),
948                }
949
950                match iter.next().unwrap() {
951                    message::UserContent::Text(message::Text { text }) => {
952                        assert_eq!(text, "What is in this image?");
953                    }
954                    _ => panic!("Expected text content"),
955                }
956
957                match iter.next().unwrap() {
958                    message::UserContent::Document(message::Document {
959                        data, media_type, ..
960                    }) => {
961                        assert_eq!(data, "base64_encoded_pdf_data");
962                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
963                    }
964                    _ => panic!("Expected document content"),
965                }
966
967                assert_eq!(iter.next(), None);
968            }
969            _ => panic!("Expected user message"),
970        }
971
972        match converted_tool_message.clone() {
973            message::Message::User { content } => {
974                let message::ToolResult { id, content, .. } = match content.first() {
975                    message::UserContent::ToolResult(tool_result) => tool_result,
976                    _ => panic!("Expected tool result content"),
977                };
978                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
979                match content.first() {
980                    message::ToolResultContent::Text(message::Text { text }) => {
981                        assert_eq!(text, "15 degrees");
982                    }
983                    _ => panic!("Expected text content"),
984                }
985            }
986            _ => panic!("Expected tool result content"),
987        }
988
989        match converted_assistant_message.clone() {
990            message::Message::Assistant { content, .. } => {
991                assert_eq!(content.len(), 1);
992
993                match content.first() {
994                    message::AssistantContent::ToolCall(message::ToolCall {
995                        id, function, ..
996                    }) => {
997                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
998                        assert_eq!(function.name, "get_weather");
999                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1000                    }
1001                    _ => panic!("Expected tool call content"),
1002                }
1003            }
1004            _ => panic!("Expected assistant message"),
1005        }
1006
1007        let original_user_message: Message = converted_user_message.try_into().unwrap();
1008        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1009        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1010
1011        assert_eq!(user_message, original_user_message);
1012        assert_eq!(assistant_message, original_assistant_message);
1013        assert_eq!(tool_message, original_tool_message);
1014    }
1015}