rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    completion::{self, CompletionError},
5    json_utils,
6    message::{self, DocumentMediaType, MessageError},
7    one_or_many::string_or_one_or_many,
8    OneOrMany,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
// ================================================================
// Anthropic Completion API
// ================================================================

/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";

/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";

/// The newest API version known to this module (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
42
/// Body of a successful, non-streaming response from Anthropic's `/v1/messages` endpoint.
///
/// `completion` decodes it as the `message` variant of `ApiResponse` and then converts it
/// into rig's provider-agnostic `completion::CompletionResponse`, keeping this value as the
/// raw response.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model; only `text` and `tool_use` blocks are accepted
    /// by the conversion into `completion::CompletionResponse`.
    pub content: Vec<Content>,
    pub id: String,
    pub model: String,
    pub role: String,
    /// Stop reason reported by the API, if present.
    pub stop_reason: Option<String>,
    /// Stop sequence reported by the API, if present.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request; logged by `completion`.
    pub usage: Usage,
}
53
/// Token accounting reported by Anthropic alongside a completion.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    pub input_tokens: u64,
    /// `None` when the response omits this counter (rendered as `n/a` by `Display`).
    pub cache_read_input_tokens: Option<u64>,
    /// `None` when the response omits this counter (rendered as `n/a` by `Display`).
    pub cache_creation_input_tokens: Option<u64>,
    pub output_tokens: u64,
}
61
62impl std::fmt::Display for Usage {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        write!(
65            f,
66            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
67            self.input_tokens,
68            match self.cache_read_input_tokens {
69                Some(token) => token.to_string(),
70                None => "n/a".to_string(),
71            },
72            match self.cache_creation_input_tokens {
73                Some(token) => token.to_string(),
74                None => "n/a".to_string(),
75            },
76            self.output_tokens
77        )
78    }
79}
80
/// A tool offered to the model, serialized into the request's `tools` array by `completion`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    /// Always `Some` when built by `completion` from a rig tool definition.
    pub description: Option<String>,
    /// Schema for the tool's arguments, copied from the rig tool definition's `parameters`.
    pub input_schema: serde_json::Value,
}
87
/// Cache-control marker; with the `type` tag it serializes as `{"type":"ephemeral"}`.
///
/// NOTE(review): not referenced in this part of the file — confirm where it is attached.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
93
94impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
95    type Error = CompletionError;
96
97    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
98        let content = response
99            .content
100            .iter()
101            .map(|content| {
102                Ok(match content {
103                    Content::Text { text } => completion::AssistantContent::text(text),
104                    Content::ToolUse { id, name, input } => {
105                        completion::AssistantContent::tool_call(id, name, input.clone())
106                    }
107                    _ => {
108                        return Err(CompletionError::ResponseError(
109                            "Response did not contain a message or tool call".into(),
110                        ))
111                    }
112                })
113            })
114            .collect::<Result<Vec<_>, _>>()?;
115
116        let choice = OneOrMany::many(content).map_err(|_| {
117            CompletionError::ResponseError(
118                "Response contained no message or tool call (empty)".to_owned(),
119            )
120        })?;
121
122        Ok(completion::CompletionResponse {
123            choice,
124            raw_response: response,
125        })
126    }
127}
128
/// A chat message in Anthropic's wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    pub role: Role,
    /// Content blocks. When deserializing, a bare string is also accepted and becomes a
    /// single text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
135
/// Author of a [`Message`], serialized in lowercase (`"user"` / `"assistant"`).
///
/// There is no system role here: `completion` sends the preamble as a separate top-level
/// `system` field.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
142
/// A single content block inside an Anthropic [`Message`].
///
/// Internally tagged by `type` in snake_case: `text`, `image`, `tool_use`, `tool_result`,
/// `document`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An inline image.
    Image {
        source: ImageSource,
    },
    /// A tool invocation produced by the assistant.
    ToolUse {
        /// Identifier that the matching `ToolResult` echoes back as `tool_use_id`.
        id: String,
        name: String,
        /// Tool arguments as arbitrary JSON.
        input: serde_json::Value,
    },
    /// The outcome of a tool invocation, carried in a user message.
    ToolResult {
        tool_use_id: String,
        /// Deserialization also accepts a bare string, which becomes a single text block.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// An inline document (PDF only, see [`DocumentFormat`]).
    Document {
        source: DocumentSource,
    },
}
168
169impl FromStr for Content {
170    type Err = Infallible;
171
172    fn from_str(s: &str) -> Result<Self, Self::Err> {
173        Ok(Content::Text { text: s.to_owned() })
174    }
175}
176
/// A block inside the `content` of a `tool_result` [`Content`] block.
///
/// Internally tagged by `type` in snake_case, like [`Content`].
///
/// NOTE(review): `Image` is a newtype variant around [`ImageSource`], which has its own
/// `type` field. With `#[serde(tag = "type")]` the variant tag and that field both target the
/// `type` key. This variant therefore likely neither round-trips nor produces Anthropic's
/// `{"type":"image","source":{...}}` shape (compare `Content::Image`). Confirm this.
/// The probable fix is `Image { source: ImageSource }`, together with the conversions that
/// construct and destructure this variant.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    Text { text: String },
    Image(ImageSource),
}
183
184impl FromStr for ToolResultContent {
185    type Err = Infallible;
186
187    fn from_str(s: &str) -> Result<Self, Self::Err> {
188        Ok(ToolResultContent::Text { text: s.to_owned() })
189    }
190}
191
/// Inline image payload used by `image` content blocks.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Encoded image data (encoding given by `r#type`).
    pub data: String,
    pub media_type: ImageFormat,
    /// Encoding of `data`; serialized under the `type` key.
    pub r#type: SourceType,
}
198
/// Inline document payload used by `document` content blocks.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Encoded document data (encoding given by `r#type`).
    pub data: String,
    pub media_type: DocumentFormat,
    /// Encoding of `data`; serialized under the `type` key.
    pub r#type: SourceType,
}
205
/// Image media types accepted in an [`ImageSource`]; each variant serializes as its MIME
/// type string.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
// Every variant has an explicit `rename`, which takes precedence, so `rename_all` has no
// effect here.
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
218
/// The document format to be used.
///
/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
///
/// Each variant serializes as its MIME type string.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    /// `application/pdf`
    #[serde(rename = "application/pdf")]
    PDF,
}
228
/// Encoding of an inline source's `data`. Only base64 is supported; it serializes as
/// `"base64"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
}
234
235impl From<String> for Content {
236    fn from(text: String) -> Self {
237        Content::Text { text }
238    }
239}
240
241impl From<String> for ToolResultContent {
242    fn from(text: String) -> Self {
243        ToolResultContent::Text { text }
244    }
245}
246
247impl TryFrom<message::ContentFormat> for SourceType {
248    type Error = MessageError;
249
250    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
251        match format {
252            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
253            message::ContentFormat::String => Err(MessageError::ConversionError(
254                "Image urls are not supported in Anthropic".to_owned(),
255            )),
256        }
257    }
258}
259
260impl From<SourceType> for message::ContentFormat {
261    fn from(source_type: SourceType) -> Self {
262        match source_type {
263            SourceType::BASE64 => message::ContentFormat::Base64,
264        }
265    }
266}
267
268impl TryFrom<message::ImageMediaType> for ImageFormat {
269    type Error = MessageError;
270
271    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
272        Ok(match media_type {
273            message::ImageMediaType::JPEG => ImageFormat::JPEG,
274            message::ImageMediaType::PNG => ImageFormat::PNG,
275            message::ImageMediaType::GIF => ImageFormat::GIF,
276            message::ImageMediaType::WEBP => ImageFormat::WEBP,
277            _ => {
278                return Err(MessageError::ConversionError(
279                    format!("Unsupported image media type: {media_type:?}").to_owned(),
280                ))
281            }
282        })
283    }
284}
285
286impl From<ImageFormat> for message::ImageMediaType {
287    fn from(format: ImageFormat) -> Self {
288        match format {
289            ImageFormat::JPEG => message::ImageMediaType::JPEG,
290            ImageFormat::PNG => message::ImageMediaType::PNG,
291            ImageFormat::GIF => message::ImageMediaType::GIF,
292            ImageFormat::WEBP => message::ImageMediaType::WEBP,
293        }
294    }
295}
296
297impl TryFrom<DocumentMediaType> for DocumentFormat {
298    type Error = MessageError;
299    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
300        if !matches!(value, DocumentMediaType::PDF) {
301            return Err(MessageError::ConversionError(
302                "Anthropic only supports PDF documents".to_string(),
303            ));
304        };
305
306        Ok(DocumentFormat::PDF)
307    }
308}
309
310impl From<message::AssistantContent> for Content {
311    fn from(text: message::AssistantContent) -> Self {
312        match text {
313            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
314            message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
315                Content::ToolUse {
316                    id,
317                    name: function.name,
318                    input: function.arguments,
319                }
320            }
321        }
322    }
323}
324
325impl TryFrom<message::Message> for Message {
326    type Error = MessageError;
327
328    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
329        Ok(match message {
330            message::Message::User { content } => Message {
331                role: Role::User,
332                content: content.try_map(|content| match content {
333                    message::UserContent::Text(message::Text { text }) => {
334                        Ok(Content::Text { text })
335                    }
336                    message::UserContent::ToolResult(message::ToolResult { id, content }) => {
337                        Ok(Content::ToolResult {
338                            tool_use_id: id,
339                            content: content.try_map(|content| match content {
340                                message::ToolResultContent::Text(message::Text { text }) => {
341                                    Ok(ToolResultContent::Text { text })
342                                }
343                                message::ToolResultContent::Image(image) => {
344                                    let media_type =
345                                        image.media_type.ok_or(MessageError::ConversionError(
346                                            "Image media type is required".to_owned(),
347                                        ))?;
348                                    let format =
349                                        image.format.ok_or(MessageError::ConversionError(
350                                            "Image format is required".to_owned(),
351                                        ))?;
352                                    Ok(ToolResultContent::Image(ImageSource {
353                                        data: image.data,
354                                        media_type: media_type.try_into()?,
355                                        r#type: format.try_into()?,
356                                    }))
357                                }
358                            })?,
359                            is_error: None,
360                        })
361                    }
362                    message::UserContent::Image(message::Image {
363                        data,
364                        format,
365                        media_type,
366                        ..
367                    }) => {
368                        let source = ImageSource {
369                            data,
370                            media_type: match media_type {
371                                Some(media_type) => media_type.try_into()?,
372                                None => {
373                                    return Err(MessageError::ConversionError(
374                                        "Image media type is required".to_owned(),
375                                    ))
376                                }
377                            },
378                            r#type: match format {
379                                Some(format) => format.try_into()?,
380                                None => SourceType::BASE64,
381                            },
382                        };
383                        Ok(Content::Image { source })
384                    }
385                    message::UserContent::Document(message::Document {
386                        data,
387                        format,
388                        media_type,
389                    }) => {
390                        let Some(media_type) = media_type else {
391                            return Err(MessageError::ConversionError(
392                                "Document media type is required".to_string(),
393                            ));
394                        };
395
396                        let source = DocumentSource {
397                            data,
398                            media_type: media_type.try_into()?,
399                            r#type: match format {
400                                Some(format) => format.try_into()?,
401                                None => SourceType::BASE64,
402                            },
403                        };
404                        Ok(Content::Document { source })
405                    }
406                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
407                        "Audio is not supported in Anthropic".to_owned(),
408                    )),
409                })?,
410            },
411
412            message::Message::Assistant { content } => Message {
413                content: content.map(|content| content.into()),
414                role: Role::Assistant,
415            },
416        })
417    }
418}
419
420impl TryFrom<Content> for message::AssistantContent {
421    type Error = MessageError;
422
423    fn try_from(content: Content) -> Result<Self, Self::Error> {
424        Ok(match content {
425            Content::Text { text } => message::AssistantContent::text(text),
426            Content::ToolUse { id, name, input } => {
427                message::AssistantContent::tool_call(id, name, input)
428            }
429            _ => {
430                return Err(MessageError::ConversionError(
431                    format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
432                ))
433            }
434        })
435    }
436}
437
438impl From<ToolResultContent> for message::ToolResultContent {
439    fn from(content: ToolResultContent) -> Self {
440        match content {
441            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
442            ToolResultContent::Image(ImageSource {
443                data,
444                media_type: format,
445                r#type,
446            }) => message::ToolResultContent::image(
447                data,
448                Some(r#type.into()),
449                Some(format.into()),
450                None,
451            ),
452        }
453    }
454}
455
456impl TryFrom<Message> for message::Message {
457    type Error = MessageError;
458
459    fn try_from(message: Message) -> Result<Self, Self::Error> {
460        Ok(match message.role {
461            Role::User => message::Message::User {
462                content: message.content.try_map(|content| {
463                    Ok(match content {
464                        Content::Text { text } => message::UserContent::text(text),
465                        Content::ToolResult {
466                            tool_use_id,
467                            content,
468                            ..
469                        } => message::UserContent::tool_result(
470                            tool_use_id,
471                            content.map(|content| content.into()),
472                        ),
473                        Content::Image { source } => message::UserContent::Image(message::Image {
474                            data: source.data,
475                            format: Some(message::ContentFormat::Base64),
476                            media_type: Some(source.media_type.into()),
477                            detail: None,
478                        }),
479                        Content::Document { source } => message::UserContent::document(
480                            source.data,
481                            Some(message::ContentFormat::Base64),
482                            Some(message::DocumentMediaType::PDF),
483                        ),
484                        _ => {
485                            return Err(MessageError::ConversionError(
486                                "Unsupported content type for User role".to_owned(),
487                            ))
488                        }
489                    })
490                })?,
491            },
492            Role::Assistant => match message.content.first() {
493                Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
494                    content: message.content.try_map(|content| content.try_into())?,
495                },
496
497                _ => {
498                    return Err(MessageError::ConversionError(
499                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
500                    ))
501                }
502            },
503        })
504    }
505}
506
/// Anthropic completion model handle: an API client plus the model name to request.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests (`completion` posts to `/v1/messages`).
    pub(crate) client: Client,
    /// Model identifier sent as the `model` field of each request.
    pub model: String,
    /// Fallback `max_tokens` used when a request does not set one. `None` means every
    /// request must provide it.
    pub default_max_tokens: Option<u64>,
}
513
514impl CompletionModel {
515    pub fn new(client: Client, model: &str) -> Self {
516        Self {
517            client,
518            model: model.to_string(),
519            default_max_tokens: calculate_max_tokens(model),
520        }
521    }
522}
523
/// Returns a default `max_tokens` value for known Claude model families.
///
/// Anthropic requires every request to carry `max_tokens`, and the allowed ceiling depends
/// on the model; if it is missing or set too high, the request fails.
///
/// Models are matched by name prefix, so `*-latest` aliases and dated snapshot names resolve
/// to the same default. The values reflect the models available at the time of writing.
/// Unknown models return `None`, in which case the caller must set `max_tokens` on the
/// request.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, maximum output tokens); the first matching prefix wins.
    const LIMITS: &[(&str, u64)] = &[
        ("claude-3-7-sonnet", 64_000),
        ("claude-3-5-sonnet", 8_192),
        ("claude-3-5-haiku", 8_192),
        ("claude-3-opus", 4_096),
        ("claude-3-sonnet", 4_096),
        ("claude-3-haiku", 4_096),
    ];

    LIMITS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}
541
/// Request metadata carrying an optional `user_id`.
///
/// NOTE(review): not referenced anywhere in this part of the file. Confirm whether it is
/// used elsewhere or is dead code.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    user_id: Option<String>,
}
546
/// How the model may use the supplied tools.
///
/// Serialized as an object tagged by `type`: `{"type":"auto"}`, `{"type":"any"}`, or
/// `{"type":"tool","name":...}`. `completion` sends `Auto` (also the `Default`) whenever
/// tools are present.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    /// A specific tool, identified by name.
    Tool {
        name: String,
    },
}
557
558impl completion::CompletionModel for CompletionModel {
559    type Response = CompletionResponse;
560    type StreamingResponse = StreamingCompletionResponse;
561
562    #[cfg_attr(feature = "worker", worker::send)]
563    async fn completion(
564        &self,
565        completion_request: completion::CompletionRequest,
566    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
567        // Note: Ideally we'd introduce provider-specific Request models to handle the
568        // specific requirements of each provider. For now, we just manually check while
569        // building the request as a raw JSON document.
570
571        // Check if max_tokens is set, required for Anthropic
572        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
573            tokens
574        } else if let Some(tokens) = self.default_max_tokens {
575            tokens
576        } else {
577            return Err(CompletionError::RequestError(
578                "`max_tokens` must be set for Anthropic".into(),
579            ));
580        };
581
582        let mut full_history = vec![];
583        if let Some(docs) = completion_request.normalized_documents() {
584            full_history.push(docs);
585        }
586        full_history.extend(completion_request.chat_history);
587
588        let full_history = full_history
589            .into_iter()
590            .map(Message::try_from)
591            .collect::<Result<Vec<Message>, _>>()?;
592
593        let mut request = json!({
594            "model": self.model,
595            "messages": full_history,
596            "max_tokens": max_tokens,
597            "system": completion_request.preamble.unwrap_or("".to_string()),
598        });
599
600        if let Some(temperature) = completion_request.temperature {
601            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
602        }
603
604        if !completion_request.tools.is_empty() {
605            json_utils::merge_inplace(
606                &mut request,
607                json!({
608                    "tools": completion_request
609                        .tools
610                        .into_iter()
611                        .map(|tool| ToolDefinition {
612                            name: tool.name,
613                            description: Some(tool.description),
614                            input_schema: tool.parameters,
615                        })
616                        .collect::<Vec<_>>(),
617                    "tool_choice": ToolChoice::Auto,
618                }),
619            );
620        }
621
622        if let Some(ref params) = completion_request.additional_params {
623            json_utils::merge_inplace(&mut request, params.clone())
624        }
625
626        tracing::debug!("Anthropic completion request: {request}");
627
628        let response = self
629            .client
630            .post("/v1/messages")
631            .json(&request)
632            .send()
633            .await?;
634
635        if response.status().is_success() {
636            match response.json::<ApiResponse<CompletionResponse>>().await? {
637                ApiResponse::Message(completion) => {
638                    tracing::info!(target: "rig",
639                        "Anthropic completion token usage: {}",
640                        completion.usage
641                    );
642                    completion.try_into()
643                }
644                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
645            }
646        } else {
647            Err(CompletionError::ProviderError(response.text().await?))
648        }
649    }
650
651    #[cfg_attr(feature = "worker", worker::send)]
652    async fn stream(
653        &self,
654        request: CompletionRequest,
655    ) -> Result<
656        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
657        CompletionError,
658    > {
659        CompletionModel::stream(self, request).await
660    }
661}
662
/// Payload of the `error` variant of [`ApiResponse`]; only its `message` is surfaced, via
/// `CompletionError::ProviderError`.
///
/// NOTE(review): Anthropic documents error bodies as
/// `{"type":"error","error":{"type":...,"message":...}}`, with `message` nested under
/// `error`. This struct reads `message` from the top level, so such a body would likely fail
/// to decode. Confirm against the API reference.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
667
/// Envelope for a success-status response body, discriminated by its `type` field.
///
/// `"message"` decodes into `T`; `"error"` decodes into [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678    use serde_path_to_error::deserialize;
679
680    #[test]
681    fn test_deserialize_message() {
682        let assistant_message_json = r#"
683        {
684            "role": "assistant",
685            "content": "\n\nHello there, how may I assist you today?"
686        }
687        "#;
688
689        let assistant_message_json2 = r#"
690        {
691            "role": "assistant",
692            "content": [
693                {
694                    "type": "text",
695                    "text": "\n\nHello there, how may I assist you today?"
696                },
697                {
698                    "type": "tool_use",
699                    "id": "toolu_01A09q90qw90lq917835lq9",
700                    "name": "get_weather",
701                    "input": {"location": "San Francisco, CA"}
702                }
703            ]
704        }
705        "#;
706
707        let user_message_json = r#"
708        {
709            "role": "user",
710            "content": [
711                {
712                    "type": "image",
713                    "source": {
714                        "type": "base64",
715                        "media_type": "image/jpeg",
716                        "data": "/9j/4AAQSkZJRg..."
717                    }
718                },
719                {
720                    "type": "text",
721                    "text": "What is in this image?"
722                },
723                {
724                    "type": "tool_result",
725                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
726                    "content": "15 degrees"
727                }
728            ]
729        }
730        "#;
731
732        let assistant_message: Message = {
733            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
734            deserialize(jd).unwrap_or_else(|err| {
735                panic!("Deserialization error at {}: {}", err.path(), err);
736            })
737        };
738
739        let assistant_message2: Message = {
740            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
741            deserialize(jd).unwrap_or_else(|err| {
742                panic!("Deserialization error at {}: {}", err.path(), err);
743            })
744        };
745
746        let user_message: Message = {
747            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
748            deserialize(jd).unwrap_or_else(|err| {
749                panic!("Deserialization error at {}: {}", err.path(), err);
750            })
751        };
752
753        match assistant_message {
754            Message { role, content } => {
755                assert_eq!(role, Role::Assistant);
756                assert_eq!(
757                    content.first(),
758                    Content::Text {
759                        text: "\n\nHello there, how may I assist you today?".to_owned()
760                    }
761                );
762            }
763        }
764
765        match assistant_message2 {
766            Message { role, content } => {
767                assert_eq!(role, Role::Assistant);
768                assert_eq!(content.len(), 2);
769
770                let mut iter = content.into_iter();
771
772                match iter.next().unwrap() {
773                    Content::Text { text } => {
774                        assert_eq!(text, "\n\nHello there, how may I assist you today?");
775                    }
776                    _ => panic!("Expected text content"),
777                }
778
779                match iter.next().unwrap() {
780                    Content::ToolUse { id, name, input } => {
781                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
782                        assert_eq!(name, "get_weather");
783                        assert_eq!(input, json!({"location": "San Francisco, CA"}));
784                    }
785                    _ => panic!("Expected tool use content"),
786                }
787
788                assert_eq!(iter.next(), None);
789            }
790        }
791
792        match user_message {
793            Message { role, content } => {
794                assert_eq!(role, Role::User);
795                assert_eq!(content.len(), 3);
796
797                let mut iter = content.into_iter();
798
799                match iter.next().unwrap() {
800                    Content::Image { source } => {
801                        assert_eq!(
802                            source,
803                            ImageSource {
804                                data: "/9j/4AAQSkZJRg...".to_owned(),
805                                media_type: ImageFormat::JPEG,
806                                r#type: SourceType::BASE64,
807                            }
808                        );
809                    }
810                    _ => panic!("Expected image content"),
811                }
812
813                match iter.next().unwrap() {
814                    Content::Text { text } => {
815                        assert_eq!(text, "What is in this image?");
816                    }
817                    _ => panic!("Expected text content"),
818                }
819
820                match iter.next().unwrap() {
821                    Content::ToolResult {
822                        tool_use_id,
823                        content,
824                        is_error,
825                    } => {
826                        assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
827                        assert_eq!(
828                            content.first(),
829                            ToolResultContent::Text {
830                                text: "15 degrees".to_owned()
831                            }
832                        );
833                        assert_eq!(is_error, None);
834                    }
835                    _ => panic!("Expected tool result content"),
836                }
837
838                assert_eq!(iter.next(), None);
839            }
840        }
841    }
842
843    #[test]
844    fn test_message_to_message_conversion() {
845        let user_message: Message = serde_json::from_str(
846            r#"
847        {
848            "role": "user",
849            "content": [
850                {
851                    "type": "image",
852                    "source": {
853                        "type": "base64",
854                        "media_type": "image/jpeg",
855                        "data": "/9j/4AAQSkZJRg..."
856                    }
857                },
858                {
859                    "type": "text",
860                    "text": "What is in this image?"
861                },
862                {
863                    "type": "document",
864                    "source": {
865                        "type": "base64",
866                        "data": "base64_encoded_pdf_data",
867                        "media_type": "application/pdf"
868                    }
869                }
870            ]
871        }
872        "#,
873        )
874        .unwrap();
875
876        let assistant_message = Message {
877            role: Role::Assistant,
878            content: OneOrMany::one(Content::ToolUse {
879                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
880                name: "get_weather".to_string(),
881                input: json!({"location": "San Francisco, CA"}),
882            }),
883        };
884
885        let tool_message = Message {
886            role: Role::User,
887            content: OneOrMany::one(Content::ToolResult {
888                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
889                content: OneOrMany::one(ToolResultContent::Text {
890                    text: "15 degrees".to_string(),
891                }),
892                is_error: None,
893            }),
894        };
895
896        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
897        let converted_assistant_message: message::Message =
898            assistant_message.clone().try_into().unwrap();
899        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
900
901        match converted_user_message.clone() {
902            message::Message::User { content } => {
903                assert_eq!(content.len(), 3);
904
905                let mut iter = content.into_iter();
906
907                match iter.next().unwrap() {
908                    message::UserContent::Image(message::Image {
909                        data,
910                        format,
911                        media_type,
912                        ..
913                    }) => {
914                        assert_eq!(data, "/9j/4AAQSkZJRg...");
915                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
916                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
917                    }
918                    _ => panic!("Expected image content"),
919                }
920
921                match iter.next().unwrap() {
922                    message::UserContent::Text(message::Text { text }) => {
923                        assert_eq!(text, "What is in this image?");
924                    }
925                    _ => panic!("Expected text content"),
926                }
927
928                match iter.next().unwrap() {
929                    message::UserContent::Document(message::Document {
930                        data,
931                        format,
932                        media_type,
933                    }) => {
934                        assert_eq!(data, "base64_encoded_pdf_data");
935                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
936                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
937                    }
938                    _ => panic!("Expected document content"),
939                }
940
941                assert_eq!(iter.next(), None);
942            }
943            _ => panic!("Expected user message"),
944        }
945
946        match converted_tool_message.clone() {
947            message::Message::User { content } => {
948                let message::ToolResult { id, content, .. } = match content.first() {
949                    message::UserContent::ToolResult(tool_result) => tool_result,
950                    _ => panic!("Expected tool result content"),
951                };
952                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
953                match content.first() {
954                    message::ToolResultContent::Text(message::Text { text }) => {
955                        assert_eq!(text, "15 degrees");
956                    }
957                    _ => panic!("Expected text content"),
958                }
959            }
960            _ => panic!("Expected tool result content"),
961        }
962
963        match converted_assistant_message.clone() {
964            message::Message::Assistant { content } => {
965                assert_eq!(content.len(), 1);
966
967                match content.first() {
968                    message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
969                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
970                        assert_eq!(function.name, "get_weather");
971                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
972                    }
973                    _ => panic!("Expected tool call content"),
974                }
975            }
976            _ => panic!("Expected assistant message"),
977        }
978
979        let original_user_message: Message = converted_user_message.try_into().unwrap();
980        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
981        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
982
983        assert_eq!(user_message, original_user_message);
984        assert_eq!(assistant_message, original_assistant_message);
985        assert_eq!(tool_message, original_tool_message);
986    }
987}