rig/providers/anthropic/completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError},
6    json_utils,
7    message::{self, DocumentMediaType, MessageError},
8    one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
18// ================================================================
19// Anthropic Completion API
20// ================================================================
/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version identifier `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version identifier `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version identifier defined in this module (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
42
/// Body of a successful (non-streaming) response from Anthropic's `/v1/messages` endpoint,
/// as decoded by `CompletionModel::completion`.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model; only text and tool-use blocks are accepted when
    /// converting into rig's generic completion response.
    pub content: Vec<Content>,
    /// Identifier of this response.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// Author role, kept as a free-form string rather than [`Role`].
    pub role: String,
    /// Reason generation stopped, if reported.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request (logged by `completion`).
    pub usage: Usage,
}
53
/// Token usage reported by Anthropic for a single request.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Input tokens read from the prompt cache; `None` when the field is absent or null.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache; `None` when the field is absent or null.
    pub cache_creation_input_tokens: Option<u64>,
    /// Number of output tokens.
    pub output_tokens: u64,
}
61
62impl std::fmt::Display for Usage {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        write!(
65            f,
66            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
67            self.input_tokens,
68            match self.cache_read_input_tokens {
69                Some(token) => token.to_string(),
70                None => "n/a".to_string(),
71            },
72            match self.cache_creation_input_tokens {
73                Some(token) => token.to_string(),
74                None => "n/a".to_string(),
75            },
76            self.output_tokens
77        )
78    }
79}
80
/// A tool offered to the model, in the shape sent in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name the model uses to refer to the tool.
    pub name: String,
    /// Optional description; `None` serializes as JSON `null` (no `skip_serializing_if`).
    pub description: Option<String>,
    /// JSON describing the tool's input (filled from the rig tool's `parameters`).
    pub input_schema: serde_json::Value,
}
87
/// Cache-control marker; serializes as `{"type": "ephemeral"}`.
///
/// NOTE(review): presumably intended for Anthropic prompt caching; it is not referenced in
/// the visible part of this file — confirm where it is attached.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
93
94impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
95    type Error = CompletionError;
96
97    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
98        let content = response
99            .content
100            .iter()
101            .map(|content| {
102                Ok(match content {
103                    Content::Text { text } => completion::AssistantContent::text(text),
104                    Content::ToolUse { id, name, input } => {
105                        completion::AssistantContent::tool_call(id, name, input.clone())
106                    }
107                    _ => {
108                        return Err(CompletionError::ResponseError(
109                            "Response did not contain a message or tool call".into(),
110                        ));
111                    }
112                })
113            })
114            .collect::<Result<Vec<_>, _>>()?;
115
116        let choice = OneOrMany::many(content).map_err(|_| {
117            CompletionError::ResponseError(
118                "Response contained no message or tool call (empty)".to_owned(),
119            )
120        })?;
121
122        Ok(completion::CompletionResponse {
123            choice,
124            raw_response: response,
125        })
126    }
127}
128
/// A single conversation turn in Anthropic's Messages API format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the turn.
    pub role: Role,
    /// Content blocks of the turn. On deserialization a bare JSON string is also accepted and
    /// becomes a single text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
135
/// Author of a [`Message`]; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
142
/// A single content block of an Anthropic [`Message`].
///
/// Serialized as an internally tagged object whose `type` field is the snake_case variant
/// name (`"text"`, `"image"`, `"tool_use"`, `"tool_result"`, `"document"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// Inline image.
    Image {
        source: ImageSource,
    },
    /// A tool invocation requested by the model (appears in assistant turns).
    ToolUse {
        /// Identifier that a later `tool_result` block echoes as `tool_use_id`.
        id: String,
        /// Name of the tool to call.
        name: String,
        /// Tool arguments as arbitrary JSON.
        input: serde_json::Value,
    },
    /// The result of a tool invocation (appears in user turns).
    ToolResult {
        /// `id` of the `tool_use` block this result answers.
        tool_use_id: String,
        /// Result payload; a bare JSON string is also accepted on deserialization and
        /// becomes a single text item.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// Inline document (only PDF is defined, see [`DocumentFormat`]).
    Document {
        source: DocumentSource,
    },
}
168
169impl FromStr for Content {
170    type Err = Infallible;
171
172    fn from_str(s: &str) -> Result<Self, Self::Err> {
173        Ok(Content::Text { text: s.to_owned() })
174    }
175}
176
/// One item of a [`Content::ToolResult`] payload.
///
/// NOTE(review): with `tag = "type"`, the newtype variant `Image(ImageSource)` writes its
/// struct fields into the same tagged object, and `ImageSource::r#type` is also serialized
/// under the key `type`. Depending on the serializer, one `type` key clobbers or duplicates
/// the other. [`Content::Image`] nests the payload under `source` instead — confirm which
/// shape Anthropic expects for images inside tool results.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Plain-text result.
    Text { text: String },
    /// Inline image result.
    Image(ImageSource),
}
183
184impl FromStr for ToolResultContent {
185    type Err = Infallible;
186
187    fn from_str(s: &str) -> Result<Self, Self::Err> {
188        Ok(ToolResultContent::Text { text: s.to_owned() })
189    }
190}
191
/// Inline image payload, used by [`Content::Image`] and [`ToolResultContent::Image`].
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Encoded image data (base64 is the only [`SourceType`] defined here).
    pub data: String,
    /// MIME type of the image, e.g. `image/jpeg`.
    pub media_type: ImageFormat,
    /// Encoding of `data`; serialized under the key `type`.
    pub r#type: SourceType,
}
198
/// Inline document payload, used by [`Content::Document`].
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Encoded document data (base64 is the only [`SourceType`] defined here).
    pub data: String,
    /// MIME type of the document (`application/pdf`).
    pub media_type: DocumentFormat,
    /// Encoding of `data`; serialized under the key `type`.
    pub r#type: SourceType,
}
205
/// Image MIME types accepted in [`ImageSource::media_type`].
///
/// Every variant has an explicit `rename` to its MIME string, so the container-level
/// `rename_all = "lowercase"` has no effect.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
218
/// The MIME type of an inline document.
///
/// Anthropic currently accepts only PDF documents inline in a message over the API. See
/// <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support> for details.
///
/// Serialized as `"application/pdf"`; the explicit `rename` overrides the container-level
/// `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
228
/// Encoding of an inline source payload. Only base64 is defined; it serializes as `"base64"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
}
234
235impl From<String> for Content {
236    fn from(text: String) -> Self {
237        Content::Text { text }
238    }
239}
240
241impl From<String> for ToolResultContent {
242    fn from(text: String) -> Self {
243        ToolResultContent::Text { text }
244    }
245}
246
247impl TryFrom<message::ContentFormat> for SourceType {
248    type Error = MessageError;
249
250    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
251        match format {
252            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
253            message::ContentFormat::String => Err(MessageError::ConversionError(
254                "Image urls are not supported in Anthropic".to_owned(),
255            )),
256        }
257    }
258}
259
260impl From<SourceType> for message::ContentFormat {
261    fn from(source_type: SourceType) -> Self {
262        match source_type {
263            SourceType::BASE64 => message::ContentFormat::Base64,
264        }
265    }
266}
267
268impl TryFrom<message::ImageMediaType> for ImageFormat {
269    type Error = MessageError;
270
271    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
272        Ok(match media_type {
273            message::ImageMediaType::JPEG => ImageFormat::JPEG,
274            message::ImageMediaType::PNG => ImageFormat::PNG,
275            message::ImageMediaType::GIF => ImageFormat::GIF,
276            message::ImageMediaType::WEBP => ImageFormat::WEBP,
277            _ => {
278                return Err(MessageError::ConversionError(
279                    format!("Unsupported image media type: {media_type:?}").to_owned(),
280                ));
281            }
282        })
283    }
284}
285
286impl From<ImageFormat> for message::ImageMediaType {
287    fn from(format: ImageFormat) -> Self {
288        match format {
289            ImageFormat::JPEG => message::ImageMediaType::JPEG,
290            ImageFormat::PNG => message::ImageMediaType::PNG,
291            ImageFormat::GIF => message::ImageMediaType::GIF,
292            ImageFormat::WEBP => message::ImageMediaType::WEBP,
293        }
294    }
295}
296
297impl TryFrom<DocumentMediaType> for DocumentFormat {
298    type Error = MessageError;
299    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
300        if !matches!(value, DocumentMediaType::PDF) {
301            return Err(MessageError::ConversionError(
302                "Anthropic only supports PDF documents".to_string(),
303            ));
304        };
305
306        Ok(DocumentFormat::PDF)
307    }
308}
309
310impl From<message::AssistantContent> for Content {
311    fn from(text: message::AssistantContent) -> Self {
312        match text {
313            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
314            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
315                Content::ToolUse {
316                    id,
317                    name: function.name,
318                    input: function.arguments,
319                }
320            }
321        }
322    }
323}
324
325impl TryFrom<message::Message> for Message {
326    type Error = MessageError;
327
328    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
329        Ok(match message {
330            message::Message::User { content } => Message {
331                role: Role::User,
332                content: content.try_map(|content| match content {
333                    message::UserContent::Text(message::Text { text }) => {
334                        Ok(Content::Text { text })
335                    }
336                    message::UserContent::ToolResult(message::ToolResult {
337                        id, content, ..
338                    }) => Ok(Content::ToolResult {
339                        tool_use_id: id,
340                        content: content.try_map(|content| match content {
341                            message::ToolResultContent::Text(message::Text { text }) => {
342                                Ok(ToolResultContent::Text { text })
343                            }
344                            message::ToolResultContent::Image(image) => {
345                                let media_type =
346                                    image.media_type.ok_or(MessageError::ConversionError(
347                                        "Image media type is required".to_owned(),
348                                    ))?;
349                                let format = image.format.ok_or(MessageError::ConversionError(
350                                    "Image format is required".to_owned(),
351                                ))?;
352                                Ok(ToolResultContent::Image(ImageSource {
353                                    data: image.data,
354                                    media_type: media_type.try_into()?,
355                                    r#type: format.try_into()?,
356                                }))
357                            }
358                        })?,
359                        is_error: None,
360                    }),
361                    message::UserContent::Image(message::Image {
362                        data,
363                        format,
364                        media_type,
365                        ..
366                    }) => {
367                        let source = ImageSource {
368                            data,
369                            media_type: match media_type {
370                                Some(media_type) => media_type.try_into()?,
371                                None => {
372                                    return Err(MessageError::ConversionError(
373                                        "Image media type is required".to_owned(),
374                                    ));
375                                }
376                            },
377                            r#type: match format {
378                                Some(format) => format.try_into()?,
379                                None => SourceType::BASE64,
380                            },
381                        };
382                        Ok(Content::Image { source })
383                    }
384                    message::UserContent::Document(message::Document {
385                        data,
386                        format,
387                        media_type,
388                    }) => {
389                        let Some(media_type) = media_type else {
390                            return Err(MessageError::ConversionError(
391                                "Document media type is required".to_string(),
392                            ));
393                        };
394
395                        let source = DocumentSource {
396                            data,
397                            media_type: media_type.try_into()?,
398                            r#type: match format {
399                                Some(format) => format.try_into()?,
400                                None => SourceType::BASE64,
401                            },
402                        };
403                        Ok(Content::Document { source })
404                    }
405                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
406                        "Audio is not supported in Anthropic".to_owned(),
407                    )),
408                })?,
409            },
410
411            message::Message::Assistant { content, .. } => Message {
412                content: content.map(|content| content.into()),
413                role: Role::Assistant,
414            },
415        })
416    }
417}
418
419impl TryFrom<Content> for message::AssistantContent {
420    type Error = MessageError;
421
422    fn try_from(content: Content) -> Result<Self, Self::Error> {
423        Ok(match content {
424            Content::Text { text } => message::AssistantContent::text(text),
425            Content::ToolUse { id, name, input } => {
426                message::AssistantContent::tool_call(id, name, input)
427            }
428            _ => {
429                return Err(MessageError::ConversionError(
430                    format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
431                ));
432            }
433        })
434    }
435}
436
437impl From<ToolResultContent> for message::ToolResultContent {
438    fn from(content: ToolResultContent) -> Self {
439        match content {
440            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
441            ToolResultContent::Image(ImageSource {
442                data,
443                media_type: format,
444                r#type,
445            }) => message::ToolResultContent::image(
446                data,
447                Some(r#type.into()),
448                Some(format.into()),
449                None,
450            ),
451        }
452    }
453}
454
455impl TryFrom<Message> for message::Message {
456    type Error = MessageError;
457
458    fn try_from(message: Message) -> Result<Self, Self::Error> {
459        Ok(match message.role {
460            Role::User => message::Message::User {
461                content: message.content.try_map(|content| {
462                    Ok(match content {
463                        Content::Text { text } => message::UserContent::text(text),
464                        Content::ToolResult {
465                            tool_use_id,
466                            content,
467                            ..
468                        } => message::UserContent::tool_result(
469                            tool_use_id,
470                            content.map(|content| content.into()),
471                        ),
472                        Content::Image { source } => message::UserContent::Image(message::Image {
473                            data: source.data,
474                            format: Some(message::ContentFormat::Base64),
475                            media_type: Some(source.media_type.into()),
476                            detail: None,
477                        }),
478                        Content::Document { source } => message::UserContent::document(
479                            source.data,
480                            Some(message::ContentFormat::Base64),
481                            Some(message::DocumentMediaType::PDF),
482                        ),
483                        _ => {
484                            return Err(MessageError::ConversionError(
485                                "Unsupported content type for User role".to_owned(),
486                            ));
487                        }
488                    })
489                })?,
490            },
491            Role::Assistant => match message.content.first() {
492                Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
493                    id: None,
494                    content: message.content.try_map(|content| content.try_into())?,
495                },
496
497                _ => {
498                    return Err(MessageError::ConversionError(
499                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
500                    ));
501                }
502            },
503        })
504    }
505}
506
507#[derive(Clone)]
508pub struct CompletionModel {
509    pub(crate) client: Client,
510    pub model: String,
511    pub default_max_tokens: Option<u64>,
512}
513
514impl CompletionModel {
515    pub fn new(client: Client, model: &str) -> Self {
516        Self {
517            client,
518            model: model.to_string(),
519            default_max_tokens: calculate_max_tokens(model),
520        }
521    }
522}
523
/// Returns the default `max_tokens` for `model`, or `None` when the model is not recognised.
///
/// Anthropic's Messages API requires every request to carry `max_tokens`, and the largest
/// accepted value depends on the model. If the value is missing or above the model's output
/// limit, the request fails. The values below are each model family's maximum output size
/// at the time of writing. A request that sets `max_tokens` explicitly overrides them.
///
/// Models are matched by name prefix. Both `-latest` aliases and dated snapshots
/// (e.g. `claude-3-5-sonnet-20241022`) therefore resolve to the same limit.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, maximum output tokens). No prefix here is a prefix of another, so
    // the order of entries does not matter.
    const LIMITS: &[(&str, u64)] = &[
        // Claude 3.7 Sonnet allows up to 64k output tokens by default.
        ("claude-3-7-sonnet", 64_000),
        ("claude-3-5-sonnet", 8_192),
        ("claude-3-5-haiku", 8_192),
        ("claude-3-opus", 4_096),
        ("claude-3-sonnet", 4_096),
        ("claude-3-haiku", 4_096),
    ];

    LIMITS
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, limit)| limit)
}
541
/// Request metadata carrying an optional user identifier.
///
/// NOTE(review): not referenced in the visible part of this file — confirm it is used
/// (e.g. by the streaming module) or remove it.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    user_id: Option<String>,
}
546
/// Tool-selection directive sent as the request's `tool_choice`.
///
/// Serialized as an internally tagged object, e.g. `{"type": "auto"}`, `{"type": "any"}` or
/// `{"type": "tool", "name": "..."}`. `completion` always sends `Auto` when tools are present.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// `{"type": "auto"}`; the `Default` variant.
    #[default]
    Auto,
    /// `{"type": "any"}`.
    Any,
    /// Names one specific tool: `{"type": "tool", "name": ...}`.
    Tool {
        name: String,
    },
}
557
558impl completion::CompletionModel for CompletionModel {
559    type Response = CompletionResponse;
560    type StreamingResponse = StreamingCompletionResponse;
561
562    #[cfg_attr(feature = "worker", worker::send)]
563    async fn completion(
564        &self,
565        completion_request: completion::CompletionRequest,
566    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
567        // Note: Ideally we'd introduce provider-specific Request models to handle the
568        // specific requirements of each provider. For now, we just manually check while
569        // building the request as a raw JSON document.
570
571        // Check if max_tokens is set, required for Anthropic
572        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
573            tokens
574        } else if let Some(tokens) = self.default_max_tokens {
575            tokens
576        } else {
577            return Err(CompletionError::RequestError(
578                "`max_tokens` must be set for Anthropic".into(),
579            ));
580        };
581
582        let mut full_history = vec![];
583        if let Some(docs) = completion_request.normalized_documents() {
584            full_history.push(docs);
585        }
586        full_history.extend(completion_request.chat_history);
587
588        let full_history = full_history
589            .into_iter()
590            .map(Message::try_from)
591            .collect::<Result<Vec<Message>, _>>()?;
592
593        let mut request = json!({
594            "model": self.model,
595            "messages": full_history,
596            "max_tokens": max_tokens,
597            "system": completion_request.preamble.unwrap_or("".to_string()),
598        });
599
600        if let Some(temperature) = completion_request.temperature {
601            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
602        }
603
604        if !completion_request.tools.is_empty() {
605            json_utils::merge_inplace(
606                &mut request,
607                json!({
608                    "tools": completion_request
609                        .tools
610                        .into_iter()
611                        .map(|tool| ToolDefinition {
612                            name: tool.name,
613                            description: Some(tool.description),
614                            input_schema: tool.parameters,
615                        })
616                        .collect::<Vec<_>>(),
617                    "tool_choice": ToolChoice::Auto,
618                }),
619            );
620        }
621
622        if let Some(ref params) = completion_request.additional_params {
623            json_utils::merge_inplace(&mut request, params.clone())
624        }
625
626        tracing::debug!("Anthropic completion request: {request}");
627
628        let response = self
629            .client
630            .post("/v1/messages")
631            .json(&request)
632            .send()
633            .await?;
634
635        if response.status().is_success() {
636            match response.json::<ApiResponse<CompletionResponse>>().await? {
637                ApiResponse::Message(completion) => {
638                    tracing::info!(target: "rig",
639                        "Anthropic completion token usage: {}",
640                        completion.usage
641                    );
642                    completion.try_into()
643                }
644                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
645            }
646        } else {
647            Err(CompletionError::ProviderError(response.text().await?))
648        }
649    }
650
651    #[cfg_attr(feature = "worker", worker::send)]
652    async fn stream(
653        &self,
654        request: CompletionRequest,
655    ) -> Result<
656        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
657        CompletionError,
658    > {
659        CompletionModel::stream(self, request).await
660    }
661}
662
/// Error payload variant of [`ApiResponse`]; only the top-level `message` field is read.
///
/// NOTE(review): Anthropic's documented error body appears to nest the text under
/// `error.message` (`{"type": "error", "error": {"type": ..., "message": ...}}`). Such a body
/// would fail to deserialize into this shape — confirm against the API reference.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
667
/// Envelope for a successful HTTP response body, discriminated by its `type` field.
///
/// `"message"` deserializes the remaining fields into `T`; `"error"` deserializes them into
/// [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678    use serde_path_to_error::deserialize;
679
680    #[test]
681    fn test_deserialize_message() {
682        let assistant_message_json = r#"
683        {
684            "role": "assistant",
685            "content": "\n\nHello there, how may I assist you today?"
686        }
687        "#;
688
689        let assistant_message_json2 = r#"
690        {
691            "role": "assistant",
692            "content": [
693                {
694                    "type": "text",
695                    "text": "\n\nHello there, how may I assist you today?"
696                },
697                {
698                    "type": "tool_use",
699                    "id": "toolu_01A09q90qw90lq917835lq9",
700                    "name": "get_weather",
701                    "input": {"location": "San Francisco, CA"}
702                }
703            ]
704        }
705        "#;
706
707        let user_message_json = r#"
708        {
709            "role": "user",
710            "content": [
711                {
712                    "type": "image",
713                    "source": {
714                        "type": "base64",
715                        "media_type": "image/jpeg",
716                        "data": "/9j/4AAQSkZJRg..."
717                    }
718                },
719                {
720                    "type": "text",
721                    "text": "What is in this image?"
722                },
723                {
724                    "type": "tool_result",
725                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
726                    "content": "15 degrees"
727                }
728            ]
729        }
730        "#;
731
732        let assistant_message: Message = {
733            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
734            deserialize(jd).unwrap_or_else(|err| {
735                panic!("Deserialization error at {}: {}", err.path(), err);
736            })
737        };
738
739        let assistant_message2: Message = {
740            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
741            deserialize(jd).unwrap_or_else(|err| {
742                panic!("Deserialization error at {}: {}", err.path(), err);
743            })
744        };
745
746        let user_message: Message = {
747            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
748            deserialize(jd).unwrap_or_else(|err| {
749                panic!("Deserialization error at {}: {}", err.path(), err);
750            })
751        };
752
753        let Message { role, content } = assistant_message;
754        assert_eq!(role, Role::Assistant);
755        assert_eq!(
756            content.first(),
757            Content::Text {
758                text: "\n\nHello there, how may I assist you today?".to_owned()
759            }
760        );
761
762        let Message { role, content } = assistant_message2;
763        {
764            assert_eq!(role, Role::Assistant);
765            assert_eq!(content.len(), 2);
766
767            let mut iter = content.into_iter();
768
769            match iter.next().unwrap() {
770                Content::Text { text } => {
771                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
772                }
773                _ => panic!("Expected text content"),
774            }
775
776            match iter.next().unwrap() {
777                Content::ToolUse { id, name, input } => {
778                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
779                    assert_eq!(name, "get_weather");
780                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
781                }
782                _ => panic!("Expected tool use content"),
783            }
784
785            assert_eq!(iter.next(), None);
786        }
787
788        let Message { role, content } = user_message;
789        {
790            assert_eq!(role, Role::User);
791            assert_eq!(content.len(), 3);
792
793            let mut iter = content.into_iter();
794
795            match iter.next().unwrap() {
796                Content::Image { source } => {
797                    assert_eq!(
798                        source,
799                        ImageSource {
800                            data: "/9j/4AAQSkZJRg...".to_owned(),
801                            media_type: ImageFormat::JPEG,
802                            r#type: SourceType::BASE64,
803                        }
804                    );
805                }
806                _ => panic!("Expected image content"),
807            }
808
809            match iter.next().unwrap() {
810                Content::Text { text } => {
811                    assert_eq!(text, "What is in this image?");
812                }
813                _ => panic!("Expected text content"),
814            }
815
816            match iter.next().unwrap() {
817                Content::ToolResult {
818                    tool_use_id,
819                    content,
820                    is_error,
821                } => {
822                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
823                    assert_eq!(
824                        content.first(),
825                        ToolResultContent::Text {
826                            text: "15 degrees".to_owned()
827                        }
828                    );
829                    assert_eq!(is_error, None);
830                }
831                _ => panic!("Expected tool result content"),
832            }
833
834            assert_eq!(iter.next(), None);
835        }
836    }
837
838    #[test]
839    fn test_message_to_message_conversion() {
840        let user_message: Message = serde_json::from_str(
841            r#"
842        {
843            "role": "user",
844            "content": [
845                {
846                    "type": "image",
847                    "source": {
848                        "type": "base64",
849                        "media_type": "image/jpeg",
850                        "data": "/9j/4AAQSkZJRg..."
851                    }
852                },
853                {
854                    "type": "text",
855                    "text": "What is in this image?"
856                },
857                {
858                    "type": "document",
859                    "source": {
860                        "type": "base64",
861                        "data": "base64_encoded_pdf_data",
862                        "media_type": "application/pdf"
863                    }
864                }
865            ]
866        }
867        "#,
868        )
869        .unwrap();
870
871        let assistant_message = Message {
872            role: Role::Assistant,
873            content: OneOrMany::one(Content::ToolUse {
874                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
875                name: "get_weather".to_string(),
876                input: json!({"location": "San Francisco, CA"}),
877            }),
878        };
879
880        let tool_message = Message {
881            role: Role::User,
882            content: OneOrMany::one(Content::ToolResult {
883                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
884                content: OneOrMany::one(ToolResultContent::Text {
885                    text: "15 degrees".to_string(),
886                }),
887                is_error: None,
888            }),
889        };
890
891        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
892        let converted_assistant_message: message::Message =
893            assistant_message.clone().try_into().unwrap();
894        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
895
896        match converted_user_message.clone() {
897            message::Message::User { content } => {
898                assert_eq!(content.len(), 3);
899
900                let mut iter = content.into_iter();
901
902                match iter.next().unwrap() {
903                    message::UserContent::Image(message::Image {
904                        data,
905                        format,
906                        media_type,
907                        ..
908                    }) => {
909                        assert_eq!(data, "/9j/4AAQSkZJRg...");
910                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
911                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
912                    }
913                    _ => panic!("Expected image content"),
914                }
915
916                match iter.next().unwrap() {
917                    message::UserContent::Text(message::Text { text }) => {
918                        assert_eq!(text, "What is in this image?");
919                    }
920                    _ => panic!("Expected text content"),
921                }
922
923                match iter.next().unwrap() {
924                    message::UserContent::Document(message::Document {
925                        data,
926                        format,
927                        media_type,
928                    }) => {
929                        assert_eq!(data, "base64_encoded_pdf_data");
930                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
931                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
932                    }
933                    _ => panic!("Expected document content"),
934                }
935
936                assert_eq!(iter.next(), None);
937            }
938            _ => panic!("Expected user message"),
939        }
940
941        match converted_tool_message.clone() {
942            message::Message::User { content } => {
943                let message::ToolResult { id, content, .. } = match content.first() {
944                    message::UserContent::ToolResult(tool_result) => tool_result,
945                    _ => panic!("Expected tool result content"),
946                };
947                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
948                match content.first() {
949                    message::ToolResultContent::Text(message::Text { text }) => {
950                        assert_eq!(text, "15 degrees");
951                    }
952                    _ => panic!("Expected text content"),
953                }
954            }
955            _ => panic!("Expected tool result content"),
956        }
957
958        match converted_assistant_message.clone() {
959            message::Message::Assistant { content, .. } => {
960                assert_eq!(content.len(), 1);
961
962                match content.first() {
963                    message::AssistantContent::ToolCall(message::ToolCall {
964                        id, function, ..
965                    }) => {
966                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
967                        assert_eq!(function.name, "get_weather");
968                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
969                    }
970                    _ => panic!("Expected tool call content"),
971                }
972            }
973            _ => panic!("Expected assistant message"),
974        }
975
976        let original_user_message: Message = converted_user_message.try_into().unwrap();
977        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
978        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
979
980        assert_eq!(user_message, original_user_message);
981        assert_eq!(assistant_message, original_assistant_message);
982        assert_eq!(tool_message, original_tool_message);
983    }
984}