Skip to main content

rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
8    one_or_many::string_or_one_or_many,
9    telemetry::{ProviderResponseExt, SpanCombinator},
10    wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
21// ================================================================
22// Anthropic Completion API
23// ================================================================
24
/// Model id `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model id `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model id `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model id `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model id `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest API version defined here (currently [`ANTHROPIC_VERSION_2023_06_01`]).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
40#[derive(Debug, Deserialize, Serialize)]
41pub struct CompletionResponse {
42    pub content: Vec<Content>,
43    pub id: String,
44    pub model: String,
45    pub role: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51impl ProviderResponseExt for CompletionResponse {
52    type OutputMessage = Content;
53    type Usage = Usage;
54
55    fn get_response_id(&self) -> Option<String> {
56        Some(self.id.to_owned())
57    }
58
59    fn get_response_model_name(&self) -> Option<String> {
60        Some(self.model.to_owned())
61    }
62
63    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64        self.content.clone()
65    }
66
67    fn get_text_response(&self) -> Option<String> {
68        let res = self
69            .content
70            .iter()
71            .filter_map(|x| {
72                if let Content::Text { text, .. } = x {
73                    Some(text.to_owned())
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<String>>()
79            .join("\n");
80
81        if res.is_empty() { None } else { Some(res) }
82    }
83
84    fn get_usage(&self) -> Option<Self::Usage> {
85        Some(self.usage.clone())
86    }
87}
88
89#[derive(Clone, Debug, Deserialize, Serialize)]
90pub struct Usage {
91    pub input_tokens: u64,
92    pub cache_read_input_tokens: Option<u64>,
93    pub cache_creation_input_tokens: Option<u64>,
94    pub output_tokens: u64,
95}
96
97impl std::fmt::Display for Usage {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(
100            f,
101            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102            self.input_tokens,
103            match self.cache_read_input_tokens {
104                Some(token) => token.to_string(),
105                None => "n/a".to_string(),
106            },
107            match self.cache_creation_input_tokens {
108                Some(token) => token.to_string(),
109                None => "n/a".to_string(),
110            },
111            self.output_tokens
112        )
113    }
114}
115
116impl GetTokenUsage for Usage {
117    fn token_usage(&self) -> Option<crate::completion::Usage> {
118        let mut usage = crate::completion::Usage::new();
119
120        usage.input_tokens = self.input_tokens
121            + self.cache_creation_input_tokens.unwrap_or_default()
122            + self.cache_read_input_tokens.unwrap_or_default();
123        usage.output_tokens = self.output_tokens;
124        usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126        Some(usage)
127    }
128}
129
130#[derive(Debug, Deserialize, Serialize)]
131pub struct ToolDefinition {
132    pub name: String,
133    pub description: Option<String>,
134    pub input_schema: serde_json::Value,
135}
136
137/// Cache control directive for Anthropic prompt caching
138#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
139#[serde(tag = "type", rename_all = "snake_case")]
140pub enum CacheControl {
141    Ephemeral,
142}
143
144/// System message content block with optional cache control
145#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
146#[serde(tag = "type", rename_all = "snake_case")]
147pub enum SystemContent {
148    Text {
149        text: String,
150        #[serde(skip_serializing_if = "Option::is_none")]
151        cache_control: Option<CacheControl>,
152    },
153}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156    type Error = CompletionError;
157
158    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159        let content = response
160            .content
161            .iter()
162            .map(|content| content.clone().try_into())
163            .collect::<Result<Vec<_>, _>>()?;
164
165        let choice = OneOrMany::many(content).map_err(|_| {
166            CompletionError::ResponseError(
167                "Response contained no message or tool call (empty)".to_owned(),
168            )
169        })?;
170
171        let usage = completion::Usage {
172            input_tokens: response.usage.input_tokens,
173            output_tokens: response.usage.output_tokens,
174            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175            cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
176        };
177
178        Ok(completion::CompletionResponse {
179            choice,
180            usage,
181            raw_response: response,
182            message_id: None,
183        })
184    }
185}
186
187#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
188pub struct Message {
189    pub role: Role,
190    #[serde(deserialize_with = "string_or_one_or_many")]
191    pub content: OneOrMany<Content>,
192}
193
194#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
195#[serde(rename_all = "lowercase")]
196pub enum Role {
197    User,
198    Assistant,
199}
200
201#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
202#[serde(tag = "type", rename_all = "snake_case")]
203pub enum Content {
204    Text {
205        text: String,
206        #[serde(skip_serializing_if = "Option::is_none")]
207        cache_control: Option<CacheControl>,
208    },
209    Image {
210        source: ImageSource,
211        #[serde(skip_serializing_if = "Option::is_none")]
212        cache_control: Option<CacheControl>,
213    },
214    ToolUse {
215        id: String,
216        name: String,
217        input: serde_json::Value,
218    },
219    ToolResult {
220        tool_use_id: String,
221        #[serde(deserialize_with = "string_or_one_or_many")]
222        content: OneOrMany<ToolResultContent>,
223        #[serde(skip_serializing_if = "Option::is_none")]
224        is_error: Option<bool>,
225        #[serde(skip_serializing_if = "Option::is_none")]
226        cache_control: Option<CacheControl>,
227    },
228    Document {
229        source: DocumentSource,
230        #[serde(skip_serializing_if = "Option::is_none")]
231        cache_control: Option<CacheControl>,
232    },
233    Thinking {
234        thinking: String,
235        #[serde(skip_serializing_if = "Option::is_none")]
236        signature: Option<String>,
237    },
238    RedactedThinking {
239        data: String,
240    },
241}
242
243impl FromStr for Content {
244    type Err = Infallible;
245
246    fn from_str(s: &str) -> Result<Self, Self::Err> {
247        Ok(Content::Text {
248            text: s.to_owned(),
249            cache_control: None,
250        })
251    }
252}
253
254#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
255#[serde(tag = "type", rename_all = "snake_case")]
256pub enum ToolResultContent {
257    Text { text: String },
258    Image(ImageSource),
259}
260
261impl FromStr for ToolResultContent {
262    type Err = Infallible;
263
264    fn from_str(s: &str) -> Result<Self, Self::Err> {
265        Ok(ToolResultContent::Text { text: s.to_owned() })
266    }
267}
268
269/// The source of an image content block.
270///
271/// Anthropic supports two source types for images:
272/// - `Base64`: Base64-encoded image data with media type
273/// - `Url`: URL reference to an image
274///
275/// See: <https://docs.anthropic.com/en/api/messages>
276#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
277#[serde(tag = "type", rename_all = "snake_case")]
278pub enum ImageSource {
279    #[serde(rename = "base64")]
280    Base64 {
281        data: String,
282        media_type: ImageFormat,
283    },
284    #[serde(rename = "url")]
285    Url { url: String },
286}
287
288/// The source of a document content block.
289///
290/// Anthropic supports multiple source types for documents. Currently implemented:
291/// - `Base64`: Base64-encoded document data (used for PDFs)
292/// - `Text`: Plain text document data
293///
294/// Future variants (not yet implemented):
295/// - URL-based PDF sources
296/// - Content block sources
297/// - File API sources
298#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
299#[serde(tag = "type", rename_all = "snake_case")]
300pub enum DocumentSource {
301    Base64 {
302        data: String,
303        media_type: DocumentFormat,
304    },
305    Text {
306        data: String,
307        media_type: PlainTextMediaType,
308    },
309    Url {
310        url: String,
311    },
312}
313
314#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
315#[serde(rename_all = "lowercase")]
316pub enum ImageFormat {
317    #[serde(rename = "image/jpeg")]
318    JPEG,
319    #[serde(rename = "image/png")]
320    PNG,
321    #[serde(rename = "image/gif")]
322    GIF,
323    #[serde(rename = "image/webp")]
324    WEBP,
325}
326
327/// The media type for base64-encoded documents.
328///
329/// Used with the `DocumentSource::Base64` variant. Currently only PDF is supported
330/// for base64-encoded document sources.
331///
332/// See: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
333#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
334#[serde(rename_all = "lowercase")]
335pub enum DocumentFormat {
336    #[serde(rename = "application/pdf")]
337    PDF,
338}
339
340/// The media type for plain text document sources.
341///
342/// Used with the `DocumentSource::Text` variant.
343///
344/// See: <https://docs.anthropic.com/en/api/messages>
345#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
346pub enum PlainTextMediaType {
347    #[serde(rename = "text/plain")]
348    Plain,
349}
350
351#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
352#[serde(rename_all = "lowercase")]
353pub enum SourceType {
354    BASE64,
355    URL,
356    TEXT,
357}
358
359impl From<String> for Content {
360    fn from(text: String) -> Self {
361        Content::Text {
362            text,
363            cache_control: None,
364        }
365    }
366}
367
368impl From<String> for ToolResultContent {
369    fn from(text: String) -> Self {
370        ToolResultContent::Text { text }
371    }
372}
373
374impl TryFrom<message::ContentFormat> for SourceType {
375    type Error = MessageError;
376
377    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
378        match format {
379            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
380            message::ContentFormat::Url => Ok(SourceType::URL),
381            message::ContentFormat::String => Ok(SourceType::TEXT),
382        }
383    }
384}
385
386impl From<SourceType> for message::ContentFormat {
387    fn from(source_type: SourceType) -> Self {
388        match source_type {
389            SourceType::BASE64 => message::ContentFormat::Base64,
390            SourceType::URL => message::ContentFormat::Url,
391            SourceType::TEXT => message::ContentFormat::String,
392        }
393    }
394}
395
396impl TryFrom<message::ImageMediaType> for ImageFormat {
397    type Error = MessageError;
398
399    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
400        Ok(match media_type {
401            message::ImageMediaType::JPEG => ImageFormat::JPEG,
402            message::ImageMediaType::PNG => ImageFormat::PNG,
403            message::ImageMediaType::GIF => ImageFormat::GIF,
404            message::ImageMediaType::WEBP => ImageFormat::WEBP,
405            _ => {
406                return Err(MessageError::ConversionError(
407                    format!("Unsupported image media type: {media_type:?}").to_owned(),
408                ));
409            }
410        })
411    }
412}
413
414impl From<ImageFormat> for message::ImageMediaType {
415    fn from(format: ImageFormat) -> Self {
416        match format {
417            ImageFormat::JPEG => message::ImageMediaType::JPEG,
418            ImageFormat::PNG => message::ImageMediaType::PNG,
419            ImageFormat::GIF => message::ImageMediaType::GIF,
420            ImageFormat::WEBP => message::ImageMediaType::WEBP,
421        }
422    }
423}
424
425impl TryFrom<DocumentMediaType> for DocumentFormat {
426    type Error = MessageError;
427    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
428        match value {
429            DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
430            other => Err(MessageError::ConversionError(format!(
431                "DocumentFormat only supports PDF for base64 sources, got: {}",
432                other.to_mime_type()
433            ))),
434        }
435    }
436}
437
438impl TryFrom<message::AssistantContent> for Content {
439    type Error = MessageError;
440    fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
441        match text {
442            message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
443                text,
444                cache_control: None,
445            }),
446            message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
447                "Anthropic currently doesn't support images.".to_string(),
448            )),
449            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
450                Ok(Content::ToolUse {
451                    id,
452                    name: function.name,
453                    input: function.arguments,
454                })
455            }
456            message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
457                thinking: reasoning.display_text(),
458                signature: reasoning.first_signature().map(str::to_owned),
459            }),
460        }
461    }
462}
463
464fn anthropic_content_from_assistant_content(
465    content: message::AssistantContent,
466) -> Result<Vec<Content>, MessageError> {
467    match content {
468        message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
469            text,
470            cache_control: None,
471        }]),
472        message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
473            "Anthropic currently doesn't support images.".to_string(),
474        )),
475        message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
476            Ok(vec![Content::ToolUse {
477                id,
478                name: function.name,
479                input: function.arguments,
480            }])
481        }
482        message::AssistantContent::Reasoning(reasoning) => {
483            let mut converted = Vec::new();
484            for block in reasoning.content {
485                match block {
486                    message::ReasoningContent::Text { text, signature } => {
487                        converted.push(Content::Thinking {
488                            thinking: text,
489                            signature,
490                        });
491                    }
492                    message::ReasoningContent::Summary(summary) => {
493                        converted.push(Content::Thinking {
494                            thinking: summary,
495                            signature: None,
496                        });
497                    }
498                    message::ReasoningContent::Redacted { data }
499                    | message::ReasoningContent::Encrypted(data) => {
500                        converted.push(Content::RedactedThinking { data });
501                    }
502                }
503            }
504
505            if converted.is_empty() {
506                return Err(MessageError::ConversionError(
507                    "Cannot convert empty reasoning content to Anthropic format".to_string(),
508                ));
509            }
510
511            Ok(converted)
512        }
513    }
514}
515
516impl TryFrom<message::Message> for Message {
517    type Error = MessageError;
518
519    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
520        Ok(match message {
521            message::Message::User { content } => Message {
522                role: Role::User,
523                content: content.try_map(|content| match content {
524                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
525                        text,
526                        cache_control: None,
527                    }),
528                    message::UserContent::ToolResult(message::ToolResult {
529                        id, content, ..
530                    }) => Ok(Content::ToolResult {
531                        tool_use_id: id,
532                        content: content.try_map(|content| match content {
533                            message::ToolResultContent::Text(message::Text { text }) => {
534                                Ok(ToolResultContent::Text { text })
535                            }
536                            message::ToolResultContent::Image(image) => {
537                                let DocumentSourceKind::Base64(data) = image.data else {
538                                    return Err(MessageError::ConversionError(
539                                        "Only base64 strings can be used with the Anthropic API"
540                                            .to_string(),
541                                    ));
542                                };
543                                let media_type =
544                                    image.media_type.ok_or(MessageError::ConversionError(
545                                        "Image media type is required".to_owned(),
546                                    ))?;
547                                Ok(ToolResultContent::Image(ImageSource::Base64 {
548                                    data,
549                                    media_type: media_type.try_into()?,
550                                }))
551                            }
552                        })?,
553                        is_error: None,
554                        cache_control: None,
555                    }),
556                    message::UserContent::Image(message::Image {
557                        data, media_type, ..
558                    }) => {
559                        let source = match data {
560                            DocumentSourceKind::Base64(data) => {
561                                let media_type =
562                                    media_type.ok_or(MessageError::ConversionError(
563                                        "Image media type is required for Claude API".to_string(),
564                                    ))?;
565                                ImageSource::Base64 {
566                                    data,
567                                    media_type: ImageFormat::try_from(media_type)?,
568                                }
569                            }
570                            DocumentSourceKind::Url(url) => ImageSource::Url { url },
571                            DocumentSourceKind::Unknown => {
572                                return Err(MessageError::ConversionError(
573                                    "Image content has no body".into(),
574                                ));
575                            }
576                            doc => {
577                                return Err(MessageError::ConversionError(format!(
578                                    "Unsupported document type: {doc:?}"
579                                )));
580                            }
581                        };
582
583                        Ok(Content::Image {
584                            source,
585                            cache_control: None,
586                        })
587                    }
588                    message::UserContent::Document(message::Document {
589                        data, media_type, ..
590                    }) => {
591                        let media_type = media_type.ok_or(MessageError::ConversionError(
592                            "Document media type is required".to_string(),
593                        ))?;
594
595                        let source = match media_type {
596                            DocumentMediaType::PDF => {
597                                let data = match data {
598                                    DocumentSourceKind::Base64(data)
599                                    | DocumentSourceKind::String(data) => data,
600                                    _ => {
601                                        return Err(MessageError::ConversionError(
602                                            "Only base64 encoded data is supported for PDF documents".into(),
603                                        ));
604                                    }
605                                };
606                                DocumentSource::Base64 {
607                                    data,
608                                    media_type: DocumentFormat::PDF,
609                                }
610                            }
611                            DocumentMediaType::TXT => {
612                                let data = match data {
613                                    DocumentSourceKind::String(data)
614                                    | DocumentSourceKind::Base64(data) => data,
615                                    _ => {
616                                        return Err(MessageError::ConversionError(
617                                            "Only string or base64 data is supported for plain text documents".into(),
618                                        ));
619                                    }
620                                };
621                                DocumentSource::Text {
622                                    data,
623                                    media_type: PlainTextMediaType::Plain,
624                                }
625                            }
626                            other => {
627                                return Err(MessageError::ConversionError(format!(
628                                    "Anthropic only supports PDF and plain text documents, got: {}",
629                                    other.to_mime_type()
630                                )));
631                            }
632                        };
633
634                        Ok(Content::Document {
635                            source,
636                            cache_control: None,
637                        })
638                    }
639                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
640                        "Audio is not supported in Anthropic".to_owned(),
641                    )),
642                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
643                        "Video is not supported in Anthropic".to_owned(),
644                    )),
645                })?,
646            },
647
648            message::Message::System { content } => Message {
649                role: Role::User,
650                content: OneOrMany::one(Content::Text {
651                    text: content,
652                    cache_control: None,
653                }),
654            },
655
656            message::Message::Assistant { content, .. } => {
657                let converted_content = content.into_iter().try_fold(
658                    Vec::new(),
659                    |mut accumulated, assistant_content| {
660                        accumulated
661                            .extend(anthropic_content_from_assistant_content(assistant_content)?);
662                        Ok::<Vec<Content>, MessageError>(accumulated)
663                    },
664                )?;
665
666                Message {
667                    content: OneOrMany::many(converted_content).map_err(|_| {
668                        MessageError::ConversionError(
669                            "Assistant message did not contain Anthropic-compatible content"
670                                .to_owned(),
671                        )
672                    })?,
673                    role: Role::Assistant,
674                }
675            }
676        })
677    }
678}
679
680impl TryFrom<Content> for message::AssistantContent {
681    type Error = MessageError;
682
683    fn try_from(content: Content) -> Result<Self, Self::Error> {
684        Ok(match content {
685            Content::Text { text, .. } => message::AssistantContent::text(text),
686            Content::ToolUse { id, name, input } => {
687                message::AssistantContent::tool_call(id, name, input)
688            }
689            Content::Thinking {
690                thinking,
691                signature,
692            } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
693                &thinking, signature,
694            )),
695            Content::RedactedThinking { data } => {
696                message::AssistantContent::Reasoning(Reasoning::redacted(data))
697            }
698            _ => {
699                return Err(MessageError::ConversionError(
700                    "Content did not contain a message, tool call, or reasoning".to_owned(),
701                ));
702            }
703        })
704    }
705}
706
707impl From<ToolResultContent> for message::ToolResultContent {
708    fn from(content: ToolResultContent) -> Self {
709        match content {
710            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
711            ToolResultContent::Image(source) => match source {
712                ImageSource::Base64 { data, media_type } => {
713                    message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
714                }
715                ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
716            },
717        }
718    }
719}
720
721impl TryFrom<Message> for message::Message {
722    type Error = MessageError;
723
724    fn try_from(message: Message) -> Result<Self, Self::Error> {
725        Ok(match message.role {
726            Role::User => message::Message::User {
727                content: message.content.try_map(|content| {
728                    Ok(match content {
729                        Content::Text { text, .. } => message::UserContent::text(text),
730                        Content::ToolResult {
731                            tool_use_id,
732                            content,
733                            ..
734                        } => message::UserContent::tool_result(
735                            tool_use_id,
736                            content.map(|content| content.into()),
737                        ),
738                        Content::Image { source, .. } => match source {
739                            ImageSource::Base64 { data, media_type } => {
740                                message::UserContent::Image(message::Image {
741                                    data: DocumentSourceKind::Base64(data),
742                                    media_type: Some(media_type.into()),
743                                    detail: None,
744                                    additional_params: None,
745                                })
746                            }
747                            ImageSource::Url { url } => {
748                                message::UserContent::Image(message::Image {
749                                    data: DocumentSourceKind::Url(url),
750                                    media_type: None,
751                                    detail: None,
752                                    additional_params: None,
753                                })
754                            }
755                        },
756                        Content::Document { source, .. } => match source {
757                            DocumentSource::Base64 { data, media_type } => {
758                                let rig_media_type = match media_type {
759                                    DocumentFormat::PDF => message::DocumentMediaType::PDF,
760                                };
761                                message::UserContent::document(data, Some(rig_media_type))
762                            }
763                            DocumentSource::Text { data, .. } => message::UserContent::document(
764                                data,
765                                Some(message::DocumentMediaType::TXT),
766                            ),
767                            DocumentSource::Url { url } => {
768                                message::UserContent::document_url(url, None)
769                            }
770                        },
771                        _ => {
772                            return Err(MessageError::ConversionError(
773                                "Unsupported content type for User role".to_owned(),
774                            ));
775                        }
776                    })
777                })?,
778            },
779            Role::Assistant => message::Message::Assistant {
780                id: None,
781                content: message.content.try_map(|content| content.try_into())?,
782            },
783        })
784    }
785}
786
787#[derive(Clone)]
788pub struct CompletionModel<T = reqwest::Client> {
789    pub(crate) client: Client<T>,
790    pub model: String,
791    pub default_max_tokens: Option<u64>,
792    /// Enable automatic prompt caching (adds cache_control breakpoints to system prompt and messages)
793    pub prompt_caching: bool,
794}
795
796impl<T> CompletionModel<T>
797where
798    T: HttpClientExt,
799{
800    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
801        let model = model.into();
802        let default_max_tokens = calculate_max_tokens(&model);
803
804        Self {
805            client,
806            model,
807            default_max_tokens,
808            prompt_caching: false, // Default to off
809        }
810    }
811
812    pub fn with_model(client: Client<T>, model: &str) -> Self {
813        Self {
814            client,
815            model: model.to_string(),
816            default_max_tokens: Some(calculate_max_tokens_custom(model)),
817            prompt_caching: false, // Default to off
818        }
819    }
820
821    /// Enable automatic prompt caching.
822    ///
823    /// When enabled, cache_control breakpoints are automatically added to:
824    /// - The system prompt (marked with ephemeral cache)
825    /// - The last content block of the last message (marked with ephemeral cache)
826    ///
827    /// This allows Anthropic to cache the conversation history for cost savings.
828    pub fn with_prompt_caching(mut self) -> Self {
829        self.prompt_caching = true;
830        self
831    }
832}
833
/// Maximum output-token limits, keyed by model-name prefix.
///
/// The first entry whose prefix matches the model name wins. No prefix here is a prefix of
/// another entry, so order only matters if such an entry is ever added (put the more
/// specific one first).
const MODEL_MAX_TOKENS: &[(&str, u64)] = &[
    ("claude-opus-4", 32000),
    ("claude-sonnet-4", 64000),
    ("claude-haiku-4", 64000),
    ("claude-3-7-sonnet", 64000),
    ("claude-3-5-sonnet", 8192),
    ("claude-3-5-haiku", 8192),
    ("claude-3-opus", 4096),
    ("claude-3-sonnet", 4096),
    ("claude-3-haiku", 4096),
];

/// `max_tokens` used by [`calculate_max_tokens_custom`] when the model name is not recognised.
const FALLBACK_MAX_TOKENS: u64 = 2048;

/// Anthropic requires a `max_tokens` parameter to be set, which is dependent on the model. If not
/// set or if set too high, the request will fail. The following values are based on the models
/// available at the time of writing.
///
/// Returns `None` when `model` does not start with any prefix in [`MODEL_MAX_TOKENS`].
fn calculate_max_tokens(model: &str) -> Option<u64> {
    MODEL_MAX_TOKENS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, max_tokens)| max_tokens)
}

/// Like [`calculate_max_tokens`], but returns [`FALLBACK_MAX_TOKENS`] for unrecognised models
/// instead of `None`.
///
/// Both functions share the same lookup table so the two can never disagree on known models.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    calculate_max_tokens(model).unwrap_or(FALLBACK_MAX_TOKENS)
}
870
/// Request metadata object.
///
/// NOTE(review): presumably mirrors the Messages API `metadata` field, but it is not a field of
/// `AnthropicCompletionRequest` in this file — confirm where (or whether) it is used.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    // No `skip_serializing_if`, so `None` is serialized as `"user_id": null`.
    user_id: Option<String>,
}
875
/// Anthropic `tool_choice` request parameter.
///
/// Serialized as an internally tagged object with a snake_case `type`, e.g. `{"type": "auto"}`
/// or `{"type": "tool", "name": "get_weather"}`. Defaults to [`ToolChoice::Auto`].
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// `{"type": "auto"}`; converted from `message::ToolChoice::Auto`.
    #[default]
    Auto,
    /// `{"type": "any"}`; converted from `message::ToolChoice::Required`.
    Any,
    /// `{"type": "none"}`; converted from `message::ToolChoice::None`.
    None,
    /// `{"type": "tool", "name": ...}`; converted from `message::ToolChoice::Specific` when it
    /// names exactly one function.
    Tool {
        name: String,
    },
}
887impl TryFrom<message::ToolChoice> for ToolChoice {
888    type Error = CompletionError;
889
890    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
891        let res = match value {
892            message::ToolChoice::Auto => Self::Auto,
893            message::ToolChoice::None => Self::None,
894            message::ToolChoice::Required => Self::Any,
895            message::ToolChoice::Specific { function_names } => {
896                if function_names.len() != 1 {
897                    return Err(CompletionError::ProviderError(
898                        "Only one tool may be specified to be used by Claude".into(),
899                    ));
900                }
901
902                Self::Tool {
903                    name: function_names.first().unwrap().to_string(),
904                }
905            }
906        };
907
908        Ok(res)
909    }
910}
911
912/// Recursively ensures all object schemas respect Anthropic structured output restrictions:
913/// - `additionalProperties` must be explicitly set to `false` on every object
914/// - All properties must be listed in `required`
915///
916/// Source: <https://docs.anthropic.com/en/docs/build-with-claude/structured-outputs#json-schema-limitations>
917fn sanitize_schema(schema: &mut serde_json::Value) {
918    use serde_json::Value;
919
920    if let Value::Object(obj) = schema {
921        let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
922            || obj.contains_key("properties");
923
924        if is_object_schema && !obj.contains_key("additionalProperties") {
925            obj.insert("additionalProperties".to_string(), Value::Bool(false));
926        }
927
928        if let Some(Value::Object(properties)) = obj.get("properties") {
929            let prop_keys = properties.keys().cloned().map(Value::String).collect();
930            obj.insert("required".to_string(), Value::Array(prop_keys));
931        }
932
933        // Anthropic does not support numerical constraints on integer/number types.
934        let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
935            || obj.get("type") == Some(&Value::String("number".to_string()));
936
937        if is_numeric_schema {
938            for key in [
939                "minimum",
940                "maximum",
941                "exclusiveMinimum",
942                "exclusiveMaximum",
943                "multipleOf",
944            ] {
945                obj.remove(key);
946            }
947        }
948
949        if let Some(defs) = obj.get_mut("$defs")
950            && let Value::Object(defs_obj) = defs
951        {
952            for (_, def_schema) in defs_obj.iter_mut() {
953                sanitize_schema(def_schema);
954            }
955        }
956
957        if let Some(properties) = obj.get_mut("properties")
958            && let Value::Object(props) = properties
959        {
960            for (_, prop_value) in props.iter_mut() {
961                sanitize_schema(prop_value);
962            }
963        }
964
965        if let Some(items) = obj.get_mut("items") {
966            sanitize_schema(items);
967        }
968
969        for key in ["anyOf", "oneOf", "allOf"] {
970            if let Some(variants) = obj.get_mut(key)
971                && let Value::Array(variants_array) = variants
972            {
973                for variant in variants_array.iter_mut() {
974                    sanitize_schema(variant);
975                }
976            }
977        }
978    }
979}
980
/// Output format specifier for Anthropic's structured output.
/// Source: <https://docs.anthropic.com/en/api/messages>
///
/// Serialized with an internal snake_case `type` tag, e.g.
/// `{"type": "json_schema", "schema": {...}}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    /// Constrains the model's response to conform to the provided JSON schema.
    ///
    /// When built from a request's output schema, the schema has already been passed through
    /// `sanitize_schema`.
    JsonSchema { schema: serde_json::Value },
}

/// Configuration for the model's output format.
///
/// Serialized as the `output_config` field of the request body.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    format: OutputFormat,
}
995
/// JSON body posted to `/v1/messages`.
///
/// Built from a provider-agnostic [`CompletionRequest`] via
/// `TryFrom<AnthropicRequestParams<'_>>`.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    // Always present: the conversion fails when the source request has no `max_tokens`.
    max_tokens: u64,
    /// System prompt as array of content blocks to support cache_control
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Request tool definitions followed by any raw tools taken from `additional_params.tools`.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<serde_json::Value>,
    /// Structured-output settings derived from the request's output schema.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    /// Remaining `additional_params` keys (after `tools` is removed), flattened into the
    /// top level of the JSON body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
1015
1016/// Helper to set cache_control on a Content block
1017fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1018    match content {
1019        Content::Text { cache_control, .. } => *cache_control = value,
1020        Content::Image { cache_control, .. } => *cache_control = value,
1021        Content::ToolResult { cache_control, .. } => *cache_control = value,
1022        Content::Document { cache_control, .. } => *cache_control = value,
1023        _ => {}
1024    }
1025}
1026
1027/// Apply cache control breakpoints to system prompt and messages.
1028/// Strategy: cache the system prompt, and mark the last content block of the last message
1029/// for caching. This allows the conversation history to be cached while new messages
1030/// are added.
1031pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1032    // Add cache_control to the system prompt (if non-empty)
1033    if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1034        *cache_control = Some(CacheControl::Ephemeral);
1035    }
1036
1037    // Clear any existing cache_control from all message content blocks
1038    for msg in messages.iter_mut() {
1039        for content in msg.content.iter_mut() {
1040            set_content_cache_control(content, None);
1041        }
1042    }
1043
1044    // Add cache_control to the last content block of the last message
1045    if let Some(last_msg) = messages.last_mut() {
1046        set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
1047    }
1048}
1049
1050pub(super) fn split_system_messages_from_history(
1051    history: Vec<message::Message>,
1052) -> (Vec<SystemContent>, Vec<message::Message>) {
1053    let mut system = Vec::new();
1054    let mut remaining = Vec::new();
1055
1056    for message in history {
1057        match message {
1058            message::Message::System { content } => {
1059                if !content.is_empty() {
1060                    system.push(SystemContent::Text {
1061                        text: content,
1062                        cache_control: None,
1063                    });
1064                }
1065            }
1066            other => remaining.push(other),
1067        }
1068    }
1069
1070    (system, remaining)
1071}
1072
/// Parameters for building an AnthropicCompletionRequest
///
/// Consumed by `TryFrom<AnthropicRequestParams<'_>>` for `AnthropicCompletionRequest`.
pub struct AnthropicRequestParams<'a> {
    /// Model name written into the request body.
    pub model: &'a str,
    /// Provider-agnostic request to convert; the conversion fails if its `max_tokens` is unset.
    pub request: CompletionRequest,
    /// When `true`, cache-control breakpoints are added with `apply_cache_control`.
    pub prompt_caching: bool,
}
1079
1080impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1081    type Error = CompletionError;
1082
1083    fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1084        let AnthropicRequestParams {
1085            model,
1086            request: mut req,
1087            prompt_caching,
1088        } = params;
1089
1090        // Check if max_tokens is set, required for Anthropic
1091        let Some(max_tokens) = req.max_tokens else {
1092            return Err(CompletionError::RequestError(
1093                "`max_tokens` must be set for Anthropic".into(),
1094            ));
1095        };
1096
1097        let mut full_history = vec![];
1098        if let Some(docs) = req.normalized_documents() {
1099            full_history.push(docs);
1100        }
1101        full_history.extend(req.chat_history);
1102        let (history_system, full_history) = split_system_messages_from_history(full_history);
1103
1104        let mut messages = full_history
1105            .into_iter()
1106            .map(Message::try_from)
1107            .collect::<Result<Vec<Message>, _>>()?;
1108
1109        let mut additional_params_payload = req
1110            .additional_params
1111            .take()
1112            .unwrap_or(serde_json::Value::Null);
1113        let mut additional_tools =
1114            extract_tools_from_additional_params(&mut additional_params_payload)?;
1115
1116        let mut tools = req
1117            .tools
1118            .into_iter()
1119            .map(|tool| ToolDefinition {
1120                name: tool.name,
1121                description: Some(tool.description),
1122                input_schema: tool.parameters,
1123            })
1124            .map(serde_json::to_value)
1125            .collect::<Result<Vec<_>, _>>()?;
1126        tools.append(&mut additional_tools);
1127
1128        // Convert system prompt to array format for cache_control support
1129        let mut system = if let Some(preamble) = req.preamble {
1130            if preamble.is_empty() {
1131                vec![]
1132            } else {
1133                vec![SystemContent::Text {
1134                    text: preamble,
1135                    cache_control: None,
1136                }]
1137            }
1138        } else {
1139            vec![]
1140        };
1141        system.extend(history_system);
1142
1143        // Apply cache control breakpoints only if prompt_caching is enabled
1144        if prompt_caching {
1145            apply_cache_control(&mut system, &mut messages);
1146        }
1147
1148        // Map output_schema to Anthropic's output_config field
1149        let output_config = req.output_schema.map(|schema| {
1150            let mut schema_value = schema.to_value();
1151            sanitize_schema(&mut schema_value);
1152            OutputConfig {
1153                format: OutputFormat::JsonSchema {
1154                    schema: schema_value,
1155                },
1156            }
1157        });
1158
1159        Ok(Self {
1160            model: model.to_string(),
1161            messages,
1162            max_tokens,
1163            system,
1164            temperature: req.temperature,
1165            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1166            tools,
1167            output_config,
1168            additional_params: if additional_params_payload.is_null() {
1169                None
1170            } else {
1171                Some(additional_params_payload)
1172            },
1173        })
1174    }
1175}
1176
1177fn extract_tools_from_additional_params(
1178    additional_params: &mut serde_json::Value,
1179) -> Result<Vec<serde_json::Value>, CompletionError> {
1180    if let Some(map) = additional_params.as_object_mut()
1181        && let Some(raw_tools) = map.remove("tools")
1182    {
1183        return serde_json::from_value::<Vec<serde_json::Value>>(raw_tools).map_err(|err| {
1184            CompletionError::RequestError(
1185                format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
1186            )
1187        });
1188    }
1189
1190    Ok(Vec::new())
1191}
1192
1193impl<T> completion::CompletionModel for CompletionModel<T>
1194where
1195    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
1196{
1197    type Response = CompletionResponse;
1198    type StreamingResponse = StreamingCompletionResponse;
1199    type Client = Client<T>;
1200
1201    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1202        Self::new(client.clone(), model.into())
1203    }
1204
1205    async fn completion(
1206        &self,
1207        mut completion_request: completion::CompletionRequest,
1208    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1209        let request_model = completion_request
1210            .model
1211            .clone()
1212            .unwrap_or_else(|| self.model.clone());
1213        let span = if tracing::Span::current().is_disabled() {
1214            info_span!(
1215                target: "rig::completions",
1216                "chat",
1217                gen_ai.operation.name = "chat",
1218                gen_ai.provider.name = "anthropic",
1219                gen_ai.request.model = &request_model,
1220                gen_ai.system_instructions = &completion_request.preamble,
1221                gen_ai.response.id = tracing::field::Empty,
1222                gen_ai.response.model = tracing::field::Empty,
1223                gen_ai.usage.output_tokens = tracing::field::Empty,
1224                gen_ai.usage.input_tokens = tracing::field::Empty,
1225                gen_ai.usage.cached_tokens = tracing::field::Empty,
1226            )
1227        } else {
1228            tracing::Span::current()
1229        };
1230
1231        // Check if max_tokens is set, required for Anthropic
1232        if completion_request.max_tokens.is_none() {
1233            if let Some(tokens) = self.default_max_tokens {
1234                completion_request.max_tokens = Some(tokens);
1235            } else {
1236                return Err(CompletionError::RequestError(
1237                    "`max_tokens` must be set for Anthropic".into(),
1238                ));
1239            }
1240        }
1241
1242        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
1243            model: &request_model,
1244            request: completion_request,
1245            prompt_caching: self.prompt_caching,
1246        })?;
1247
1248        if enabled!(Level::TRACE) {
1249            tracing::trace!(
1250                target: "rig::completions",
1251                "Anthropic completion request: {}",
1252                serde_json::to_string_pretty(&request)?
1253            );
1254        }
1255
1256        async move {
1257            let request: Vec<u8> = serde_json::to_vec(&request)?;
1258
1259            let req = self
1260                .client
1261                .post("/v1/messages")?
1262                .body(request)
1263                .map_err(|e| CompletionError::HttpError(e.into()))?;
1264
1265            let response = self
1266                .client
1267                .send::<_, Bytes>(req)
1268                .await
1269                .map_err(CompletionError::HttpError)?;
1270
1271            if response.status().is_success() {
1272                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
1273                    response
1274                        .into_body()
1275                        .await
1276                        .map_err(CompletionError::HttpError)?
1277                        .to_vec()
1278                        .as_slice(),
1279                )? {
1280                    ApiResponse::Message(completion) => {
1281                        let span = tracing::Span::current();
1282                        span.record_response_metadata(&completion);
1283                        span.record_token_usage(&completion.usage);
1284                        if enabled!(Level::TRACE) {
1285                            tracing::trace!(
1286                                target: "rig::completions",
1287                                "Anthropic completion response: {}",
1288                                serde_json::to_string_pretty(&completion)?
1289                            );
1290                        }
1291                        completion.try_into()
1292                    }
1293                    ApiResponse::Error(ApiErrorResponse { message }) => {
1294                        Err(CompletionError::ResponseError(message))
1295                    }
1296                }
1297            } else {
1298                let text: String = String::from_utf8_lossy(
1299                    &response
1300                        .into_body()
1301                        .await
1302                        .map_err(CompletionError::HttpError)?,
1303                )
1304                .into();
1305                Err(CompletionError::ProviderError(text))
1306            }
1307        }
1308        .instrument(span)
1309        .await
1310    }
1311
1312    async fn stream(
1313        &self,
1314        request: CompletionRequest,
1315    ) -> Result<
1316        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1317        CompletionError,
1318    > {
1319        CompletionModel::stream(self, request).await
1320    }
1321}
1322
/// Error payload deserialized from a body whose `type` is `"error"`.
///
/// NOTE(review): this expects `message` at the top level, next to `type`. Anthropic's
/// documented error envelope nests it as `{"type": "error", "error": {"type": ..., "message": ...}}`.
/// Confirm which shape actually reaches this type.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Response body discriminated by its snake_case `type` field: `"message"` carries the
/// payload `T`, and `"error"` carries an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
1334
1335#[cfg(test)]
1336mod tests {
1337    use super::*;
1338    use serde_json::json;
1339    use serde_path_to_error::deserialize;
1340
1341    #[test]
1342    fn test_deserialize_message() {
1343        let assistant_message_json = r#"
1344        {
1345            "role": "assistant",
1346            "content": "\n\nHello there, how may I assist you today?"
1347        }
1348        "#;
1349
1350        let assistant_message_json2 = r#"
1351        {
1352            "role": "assistant",
1353            "content": [
1354                {
1355                    "type": "text",
1356                    "text": "\n\nHello there, how may I assist you today?"
1357                },
1358                {
1359                    "type": "tool_use",
1360                    "id": "toolu_01A09q90qw90lq917835lq9",
1361                    "name": "get_weather",
1362                    "input": {"location": "San Francisco, CA"}
1363                }
1364            ]
1365        }
1366        "#;
1367
1368        let user_message_json = r#"
1369        {
1370            "role": "user",
1371            "content": [
1372                {
1373                    "type": "image",
1374                    "source": {
1375                        "type": "base64",
1376                        "media_type": "image/jpeg",
1377                        "data": "/9j/4AAQSkZJRg..."
1378                    }
1379                },
1380                {
1381                    "type": "text",
1382                    "text": "What is in this image?"
1383                },
1384                {
1385                    "type": "tool_result",
1386                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
1387                    "content": "15 degrees"
1388                }
1389            ]
1390        }
1391        "#;
1392
1393        let assistant_message: Message = {
1394            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1395            deserialize(jd).unwrap_or_else(|err| {
1396                panic!("Deserialization error at {}: {}", err.path(), err);
1397            })
1398        };
1399
1400        let assistant_message2: Message = {
1401            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1402            deserialize(jd).unwrap_or_else(|err| {
1403                panic!("Deserialization error at {}: {}", err.path(), err);
1404            })
1405        };
1406
1407        let user_message: Message = {
1408            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1409            deserialize(jd).unwrap_or_else(|err| {
1410                panic!("Deserialization error at {}: {}", err.path(), err);
1411            })
1412        };
1413
1414        let Message { role, content } = assistant_message;
1415        assert_eq!(role, Role::Assistant);
1416        assert_eq!(
1417            content.first(),
1418            Content::Text {
1419                text: "\n\nHello there, how may I assist you today?".to_owned(),
1420                cache_control: None,
1421            }
1422        );
1423
1424        let Message { role, content } = assistant_message2;
1425        {
1426            assert_eq!(role, Role::Assistant);
1427            assert_eq!(content.len(), 2);
1428
1429            let mut iter = content.into_iter();
1430
1431            match iter.next().unwrap() {
1432                Content::Text { text, .. } => {
1433                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
1434                }
1435                _ => panic!("Expected text content"),
1436            }
1437
1438            match iter.next().unwrap() {
1439                Content::ToolUse { id, name, input } => {
1440                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1441                    assert_eq!(name, "get_weather");
1442                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
1443                }
1444                _ => panic!("Expected tool use content"),
1445            }
1446
1447            assert_eq!(iter.next(), None);
1448        }
1449
1450        let Message { role, content } = user_message;
1451        {
1452            assert_eq!(role, Role::User);
1453            assert_eq!(content.len(), 3);
1454
1455            let mut iter = content.into_iter();
1456
1457            match iter.next().unwrap() {
1458                Content::Image { source, .. } => {
1459                    assert_eq!(
1460                        source,
1461                        ImageSource::Base64 {
1462                            data: "/9j/4AAQSkZJRg...".to_owned(),
1463                            media_type: ImageFormat::JPEG,
1464                        }
1465                    );
1466                }
1467                _ => panic!("Expected image content"),
1468            }
1469
1470            match iter.next().unwrap() {
1471                Content::Text { text, .. } => {
1472                    assert_eq!(text, "What is in this image?");
1473                }
1474                _ => panic!("Expected text content"),
1475            }
1476
1477            match iter.next().unwrap() {
1478                Content::ToolResult {
1479                    tool_use_id,
1480                    content,
1481                    is_error,
1482                    ..
1483                } => {
1484                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1485                    assert_eq!(
1486                        content.first(),
1487                        ToolResultContent::Text {
1488                            text: "15 degrees".to_owned()
1489                        }
1490                    );
1491                    assert_eq!(is_error, None);
1492                }
1493                _ => panic!("Expected tool result content"),
1494            }
1495
1496            assert_eq!(iter.next(), None);
1497        }
1498    }
1499
1500    #[test]
1501    fn test_message_to_message_conversion() {
1502        let user_message: Message = serde_json::from_str(
1503            r#"
1504        {
1505            "role": "user",
1506            "content": [
1507                {
1508                    "type": "image",
1509                    "source": {
1510                        "type": "base64",
1511                        "media_type": "image/jpeg",
1512                        "data": "/9j/4AAQSkZJRg..."
1513                    }
1514                },
1515                {
1516                    "type": "text",
1517                    "text": "What is in this image?"
1518                },
1519                {
1520                    "type": "document",
1521                    "source": {
1522                        "type": "base64",
1523                        "data": "base64_encoded_pdf_data",
1524                        "media_type": "application/pdf"
1525                    }
1526                }
1527            ]
1528        }
1529        "#,
1530        )
1531        .unwrap();
1532
1533        let assistant_message = Message {
1534            role: Role::Assistant,
1535            content: OneOrMany::one(Content::ToolUse {
1536                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1537                name: "get_weather".to_string(),
1538                input: json!({"location": "San Francisco, CA"}),
1539            }),
1540        };
1541
1542        let tool_message = Message {
1543            role: Role::User,
1544            content: OneOrMany::one(Content::ToolResult {
1545                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1546                content: OneOrMany::one(ToolResultContent::Text {
1547                    text: "15 degrees".to_string(),
1548                }),
1549                is_error: None,
1550                cache_control: None,
1551            }),
1552        };
1553
1554        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1555        let converted_assistant_message: message::Message =
1556            assistant_message.clone().try_into().unwrap();
1557        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1558
1559        match converted_user_message.clone() {
1560            message::Message::User { content } => {
1561                assert_eq!(content.len(), 3);
1562
1563                let mut iter = content.into_iter();
1564
1565                match iter.next().unwrap() {
1566                    message::UserContent::Image(message::Image {
1567                        data, media_type, ..
1568                    }) => {
1569                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1570                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1571                    }
1572                    _ => panic!("Expected image content"),
1573                }
1574
1575                match iter.next().unwrap() {
1576                    message::UserContent::Text(message::Text { text }) => {
1577                        assert_eq!(text, "What is in this image?");
1578                    }
1579                    _ => panic!("Expected text content"),
1580                }
1581
1582                match iter.next().unwrap() {
1583                    message::UserContent::Document(message::Document {
1584                        data, media_type, ..
1585                    }) => {
1586                        assert_eq!(
1587                            data,
1588                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
1589                        );
1590                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1591                    }
1592                    _ => panic!("Expected document content"),
1593                }
1594
1595                assert_eq!(iter.next(), None);
1596            }
1597            _ => panic!("Expected user message"),
1598        }
1599
1600        match converted_tool_message.clone() {
1601            message::Message::User { content } => {
1602                let message::ToolResult { id, content, .. } = match content.first() {
1603                    message::UserContent::ToolResult(tool_result) => tool_result,
1604                    _ => panic!("Expected tool result content"),
1605                };
1606                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1607                match content.first() {
1608                    message::ToolResultContent::Text(message::Text { text }) => {
1609                        assert_eq!(text, "15 degrees");
1610                    }
1611                    _ => panic!("Expected text content"),
1612                }
1613            }
1614            _ => panic!("Expected tool result content"),
1615        }
1616
1617        match converted_assistant_message.clone() {
1618            message::Message::Assistant { content, .. } => {
1619                assert_eq!(content.len(), 1);
1620
1621                match content.first() {
1622                    message::AssistantContent::ToolCall(message::ToolCall {
1623                        id, function, ..
1624                    }) => {
1625                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1626                        assert_eq!(function.name, "get_weather");
1627                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1628                    }
1629                    _ => panic!("Expected tool call content"),
1630                }
1631            }
1632            _ => panic!("Expected assistant message"),
1633        }
1634
1635        let original_user_message: Message = converted_user_message.try_into().unwrap();
1636        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1637        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1638
1639        assert_eq!(user_message, original_user_message);
1640        assert_eq!(assistant_message, original_assistant_message);
1641        assert_eq!(tool_message, original_tool_message);
1642    }
1643
1644    #[test]
1645    fn test_content_format_conversion() {
1646        use crate::completion::message::ContentFormat;
1647
1648        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1649        assert_eq!(source_type, SourceType::URL);
1650
1651        let content_format: ContentFormat = SourceType::URL.into();
1652        assert_eq!(content_format, ContentFormat::Url);
1653
1654        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1655        assert_eq!(source_type, SourceType::BASE64);
1656
1657        let content_format: ContentFormat = SourceType::BASE64.into();
1658        assert_eq!(content_format, ContentFormat::Base64);
1659
1660        let source_type: SourceType = ContentFormat::String.try_into().unwrap();
1661        assert_eq!(source_type, SourceType::TEXT);
1662
1663        let content_format: ContentFormat = SourceType::TEXT.into();
1664        assert_eq!(content_format, ContentFormat::String);
1665    }
1666
1667    #[test]
1668    fn test_cache_control_serialization() {
1669        // Test SystemContent with cache_control
1670        let system = SystemContent::Text {
1671            text: "You are a helpful assistant.".to_string(),
1672            cache_control: Some(CacheControl::Ephemeral),
1673        };
1674        let json = serde_json::to_string(&system).unwrap();
1675        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
1676        assert!(json.contains(r#""type":"text""#));
1677
1678        // Test SystemContent without cache_control (should not have cache_control field)
1679        let system_no_cache = SystemContent::Text {
1680            text: "Hello".to_string(),
1681            cache_control: None,
1682        };
1683        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
1684        assert!(!json_no_cache.contains("cache_control"));
1685
1686        // Test Content::Text with cache_control
1687        let content = Content::Text {
1688            text: "Test message".to_string(),
1689            cache_control: Some(CacheControl::Ephemeral),
1690        };
1691        let json_content = serde_json::to_string(&content).unwrap();
1692        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));
1693
1694        // Test apply_cache_control function
1695        let mut system_vec = vec![SystemContent::Text {
1696            text: "System prompt".to_string(),
1697            cache_control: None,
1698        }];
1699        let mut messages = vec![
1700            Message {
1701                role: Role::User,
1702                content: OneOrMany::one(Content::Text {
1703                    text: "First message".to_string(),
1704                    cache_control: None,
1705                }),
1706            },
1707            Message {
1708                role: Role::Assistant,
1709                content: OneOrMany::one(Content::Text {
1710                    text: "Response".to_string(),
1711                    cache_control: None,
1712                }),
1713            },
1714        ];
1715
1716        apply_cache_control(&mut system_vec, &mut messages);
1717
1718        // System should have cache_control
1719        match &system_vec[0] {
1720            SystemContent::Text { cache_control, .. } => {
1721                assert!(cache_control.is_some());
1722            }
1723        }
1724
1725        // Only the last content block of last message should have cache_control
1726        // First message should NOT have cache_control
1727        for content in messages[0].content.iter() {
1728            if let Content::Text { cache_control, .. } = content {
1729                assert!(cache_control.is_none());
1730            }
1731        }
1732
1733        // Last message SHOULD have cache_control
1734        for content in messages[1].content.iter() {
1735            if let Content::Text { cache_control, .. } = content {
1736                assert!(cache_control.is_some());
1737            }
1738        }
1739    }
1740
1741    #[test]
1742    fn test_plaintext_document_serialization() {
1743        let content = Content::Document {
1744            source: DocumentSource::Text {
1745                data: "Hello, world!".to_string(),
1746                media_type: PlainTextMediaType::Plain,
1747            },
1748            cache_control: None,
1749        };
1750
1751        let json = serde_json::to_value(&content).unwrap();
1752        assert_eq!(json["type"], "document");
1753        assert_eq!(json["source"]["type"], "text");
1754        assert_eq!(json["source"]["media_type"], "text/plain");
1755        assert_eq!(json["source"]["data"], "Hello, world!");
1756    }
1757
1758    #[test]
1759    fn test_plaintext_document_deserialization() {
1760        let json = r#"
1761        {
1762            "type": "document",
1763            "source": {
1764                "type": "text",
1765                "media_type": "text/plain",
1766                "data": "Hello, world!"
1767            }
1768        }
1769        "#;
1770
1771        let content: Content = serde_json::from_str(json).unwrap();
1772        match content {
1773            Content::Document {
1774                source,
1775                cache_control,
1776            } => {
1777                assert_eq!(
1778                    source,
1779                    DocumentSource::Text {
1780                        data: "Hello, world!".to_string(),
1781                        media_type: PlainTextMediaType::Plain,
1782                    }
1783                );
1784                assert_eq!(cache_control, None);
1785            }
1786            _ => panic!("Expected Document content"),
1787        }
1788    }
1789
1790    #[test]
1791    fn test_base64_pdf_document_serialization() {
1792        let content = Content::Document {
1793            source: DocumentSource::Base64 {
1794                data: "base64data".to_string(),
1795                media_type: DocumentFormat::PDF,
1796            },
1797            cache_control: None,
1798        };
1799
1800        let json = serde_json::to_value(&content).unwrap();
1801        assert_eq!(json["type"], "document");
1802        assert_eq!(json["source"]["type"], "base64");
1803        assert_eq!(json["source"]["media_type"], "application/pdf");
1804        assert_eq!(json["source"]["data"], "base64data");
1805    }
1806
1807    #[test]
1808    fn test_base64_pdf_document_deserialization() {
1809        let json = r#"
1810        {
1811            "type": "document",
1812            "source": {
1813                "type": "base64",
1814                "media_type": "application/pdf",
1815                "data": "base64data"
1816            }
1817        }
1818        "#;
1819
1820        let content: Content = serde_json::from_str(json).unwrap();
1821        match content {
1822            Content::Document { source, .. } => {
1823                assert_eq!(
1824                    source,
1825                    DocumentSource::Base64 {
1826                        data: "base64data".to_string(),
1827                        media_type: DocumentFormat::PDF,
1828                    }
1829                );
1830            }
1831            _ => panic!("Expected Document content"),
1832        }
1833    }
1834
1835    #[test]
1836    fn test_plaintext_rig_to_anthropic_conversion() {
1837        use crate::completion::message as msg;
1838
1839        let rig_message = msg::Message::User {
1840            content: OneOrMany::one(msg::UserContent::document(
1841                "Some plain text content".to_string(),
1842                Some(msg::DocumentMediaType::TXT),
1843            )),
1844        };
1845
1846        let anthropic_message: Message = rig_message.try_into().unwrap();
1847        assert_eq!(anthropic_message.role, Role::User);
1848
1849        let mut iter = anthropic_message.content.into_iter();
1850        match iter.next().unwrap() {
1851            Content::Document { source, .. } => {
1852                assert_eq!(
1853                    source,
1854                    DocumentSource::Text {
1855                        data: "Some plain text content".to_string(),
1856                        media_type: PlainTextMediaType::Plain,
1857                    }
1858                );
1859            }
1860            other => panic!("Expected Document content, got: {other:?}"),
1861        }
1862    }
1863
1864    #[test]
1865    fn test_plaintext_anthropic_to_rig_conversion() {
1866        use crate::completion::message as msg;
1867
1868        let anthropic_message = Message {
1869            role: Role::User,
1870            content: OneOrMany::one(Content::Document {
1871                source: DocumentSource::Text {
1872                    data: "Some plain text content".to_string(),
1873                    media_type: PlainTextMediaType::Plain,
1874                },
1875                cache_control: None,
1876            }),
1877        };
1878
1879        let rig_message: msg::Message = anthropic_message.try_into().unwrap();
1880        match rig_message {
1881            msg::Message::User { content } => {
1882                let mut iter = content.into_iter();
1883                match iter.next().unwrap() {
1884                    msg::UserContent::Document(msg::Document {
1885                        data, media_type, ..
1886                    }) => {
1887                        assert_eq!(
1888                            data,
1889                            DocumentSourceKind::String("Some plain text content".into())
1890                        );
1891                        assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
1892                    }
1893                    other => panic!("Expected Document content, got: {other:?}"),
1894                }
1895            }
1896            _ => panic!("Expected User message"),
1897        }
1898    }
1899
1900    #[test]
1901    fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
1902        use crate::completion::message as msg;
1903
1904        let original = msg::Message::User {
1905            content: OneOrMany::one(msg::UserContent::document(
1906                "Round trip text".to_string(),
1907                Some(msg::DocumentMediaType::TXT),
1908            )),
1909        };
1910
1911        let anthropic: Message = original.clone().try_into().unwrap();
1912        let back: msg::Message = anthropic.try_into().unwrap();
1913
1914        match (&original, &back) {
1915            (
1916                msg::Message::User {
1917                    content: orig_content,
1918                },
1919                msg::Message::User {
1920                    content: back_content,
1921                },
1922            ) => match (orig_content.first(), back_content.first()) {
1923                (
1924                    msg::UserContent::Document(msg::Document {
1925                        media_type: orig_mt,
1926                        ..
1927                    }),
1928                    msg::UserContent::Document(msg::Document {
1929                        media_type: back_mt,
1930                        ..
1931                    }),
1932                ) => {
1933                    assert_eq!(orig_mt, back_mt);
1934                }
1935                _ => panic!("Expected Document content in both"),
1936            },
1937            _ => panic!("Expected User messages"),
1938        }
1939    }
1940
1941    #[test]
1942    fn test_unsupported_document_type_returns_error() {
1943        use crate::completion::message as msg;
1944
1945        let rig_message = msg::Message::User {
1946            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
1947                data: DocumentSourceKind::String("data".into()),
1948                media_type: Some(msg::DocumentMediaType::HTML),
1949                additional_params: None,
1950            })),
1951        };
1952
1953        let result: Result<Message, _> = rig_message.try_into();
1954        assert!(result.is_err());
1955        let err = result.unwrap_err().to_string();
1956        assert!(
1957            err.contains("Anthropic only supports PDF and plain text documents"),
1958            "Unexpected error: {err}"
1959        );
1960    }
1961
1962    #[test]
1963    fn test_plaintext_document_url_source_returns_error() {
1964        use crate::completion::message as msg;
1965
1966        let rig_message = msg::Message::User {
1967            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
1968                data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
1969                media_type: Some(msg::DocumentMediaType::TXT),
1970                additional_params: None,
1971            })),
1972        };
1973
1974        let result: Result<Message, _> = rig_message.try_into();
1975        assert!(result.is_err());
1976        let err = result.unwrap_err().to_string();
1977        assert!(
1978            err.contains("Only string or base64 data is supported for plain text documents"),
1979            "Unexpected error: {err}"
1980        );
1981    }
1982
1983    #[test]
1984    fn test_plaintext_document_with_cache_control() {
1985        let content = Content::Document {
1986            source: DocumentSource::Text {
1987                data: "cached text".to_string(),
1988                media_type: PlainTextMediaType::Plain,
1989            },
1990            cache_control: Some(CacheControl::Ephemeral),
1991        };
1992
1993        let json = serde_json::to_value(&content).unwrap();
1994        assert_eq!(json["source"]["type"], "text");
1995        assert_eq!(json["source"]["media_type"], "text/plain");
1996        assert_eq!(json["cache_control"]["type"], "ephemeral");
1997    }
1998
1999    #[test]
2000    fn test_message_with_plaintext_document_deserialization() {
2001        let json = r#"
2002        {
2003            "role": "user",
2004            "content": [
2005                {
2006                    "type": "document",
2007                    "source": {
2008                        "type": "text",
2009                        "media_type": "text/plain",
2010                        "data": "Hello from a text file"
2011                    }
2012                },
2013                {
2014                    "type": "text",
2015                    "text": "Summarize this document."
2016                }
2017            ]
2018        }
2019        "#;
2020
2021        let message: Message = serde_json::from_str(json).unwrap();
2022        assert_eq!(message.role, Role::User);
2023        assert_eq!(message.content.len(), 2);
2024
2025        let mut iter = message.content.into_iter();
2026
2027        match iter.next().unwrap() {
2028            Content::Document { source, .. } => {
2029                assert_eq!(
2030                    source,
2031                    DocumentSource::Text {
2032                        data: "Hello from a text file".to_string(),
2033                        media_type: PlainTextMediaType::Plain,
2034                    }
2035                );
2036            }
2037            _ => panic!("Expected Document content"),
2038        }
2039
2040        match iter.next().unwrap() {
2041            Content::Text { text, .. } => {
2042                assert_eq!(text, "Summarize this document.");
2043            }
2044            _ => panic!("Expected Text content"),
2045        }
2046    }
2047
2048    #[test]
2049    fn test_assistant_reasoning_multiblock_to_anthropic_content() {
2050        let reasoning = message::Reasoning {
2051            id: None,
2052            content: vec![
2053                message::ReasoningContent::Text {
2054                    text: "step one".to_string(),
2055                    signature: Some("sig-1".to_string()),
2056                },
2057                message::ReasoningContent::Summary("summary".to_string()),
2058                message::ReasoningContent::Text {
2059                    text: "step two".to_string(),
2060                    signature: Some("sig-2".to_string()),
2061                },
2062                message::ReasoningContent::Redacted {
2063                    data: "redacted block".to_string(),
2064                },
2065            ],
2066        };
2067
2068        let msg = message::Message::Assistant {
2069            id: None,
2070            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
2071        };
2072        let converted: Message = msg.try_into().expect("convert assistant message");
2073        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();
2074
2075        assert_eq!(converted.role, Role::Assistant);
2076        assert_eq!(converted_content.len(), 4);
2077        assert!(matches!(
2078            converted_content.first(),
2079            Some(Content::Thinking { thinking, signature: Some(signature) })
2080                if thinking == "step one" && signature == "sig-1"
2081        ));
2082        assert!(matches!(
2083            converted_content.get(1),
2084            Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
2085        ));
2086        assert!(matches!(
2087            converted_content.get(2),
2088            Some(Content::Thinking { thinking, signature: Some(signature) })
2089                if thinking == "step two" && signature == "sig-2"
2090        ));
2091        assert!(matches!(
2092            converted_content.get(3),
2093            Some(Content::RedactedThinking { data }) if data == "redacted block"
2094        ));
2095    }
2096
2097    #[test]
2098    fn test_redacted_thinking_content_to_assistant_reasoning() {
2099        let content = Content::RedactedThinking {
2100            data: "opaque-redacted".to_string(),
2101        };
2102        let converted: message::AssistantContent =
2103            content.try_into().expect("convert redacted thinking");
2104
2105        assert!(matches!(
2106            converted,
2107            message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2108                if matches!(
2109                    content.first(),
2110                    Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
2111                )
2112        ));
2113    }
2114
2115    #[test]
2116    fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
2117        let reasoning = message::Reasoning {
2118            id: None,
2119            content: vec![message::ReasoningContent::Encrypted(
2120                "ciphertext".to_string(),
2121            )],
2122        };
2123        let msg = message::Message::Assistant {
2124            id: None,
2125            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
2126        };
2127
2128        let converted: Message = msg.try_into().expect("convert assistant message");
2129        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();
2130
2131        assert_eq!(converted_content.len(), 1);
2132        assert!(matches!(
2133            converted_content.first(),
2134            Some(Content::RedactedThinking { data }) if data == "ciphertext"
2135        ));
2136    }
2137}