Skip to main content

rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
8    one_or_many::string_or_one_or_many,
9    telemetry::{ProviderResponseExt, SpanCombinator},
10    wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
21// ================================================================
22// Anthropic Completion API
23// ================================================================
24
// Model identifiers accepted by the completion endpoint.

/// Model id `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model id `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model id `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model id `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model id `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

// API version strings.
// NOTE(review): presumably sent as the `anthropic-version` request header; confirm in the client module.

/// API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest version string defined here (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
40#[derive(Debug, Deserialize, Serialize)]
41pub struct CompletionResponse {
42    pub content: Vec<Content>,
43    pub id: String,
44    pub model: String,
45    pub role: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51impl ProviderResponseExt for CompletionResponse {
52    type OutputMessage = Content;
53    type Usage = Usage;
54
55    fn get_response_id(&self) -> Option<String> {
56        Some(self.id.to_owned())
57    }
58
59    fn get_response_model_name(&self) -> Option<String> {
60        Some(self.model.to_owned())
61    }
62
63    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64        self.content.clone()
65    }
66
67    fn get_text_response(&self) -> Option<String> {
68        let res = self
69            .content
70            .iter()
71            .filter_map(|x| {
72                if let Content::Text { text, .. } = x {
73                    Some(text.to_owned())
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<String>>()
79            .join("\n");
80
81        if res.is_empty() { None } else { Some(res) }
82    }
83
84    fn get_usage(&self) -> Option<Self::Usage> {
85        Some(self.usage.clone())
86    }
87}
88
89#[derive(Clone, Debug, Deserialize, Serialize)]
90pub struct Usage {
91    pub input_tokens: u64,
92    pub cache_read_input_tokens: Option<u64>,
93    pub cache_creation_input_tokens: Option<u64>,
94    pub output_tokens: u64,
95}
96
97impl std::fmt::Display for Usage {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(
100            f,
101            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102            self.input_tokens,
103            match self.cache_read_input_tokens {
104                Some(token) => token.to_string(),
105                None => "n/a".to_string(),
106            },
107            match self.cache_creation_input_tokens {
108                Some(token) => token.to_string(),
109                None => "n/a".to_string(),
110            },
111            self.output_tokens
112        )
113    }
114}
115
116impl GetTokenUsage for Usage {
117    fn token_usage(&self) -> Option<crate::completion::Usage> {
118        let mut usage = crate::completion::Usage::new();
119
120        usage.input_tokens = self.input_tokens
121            + self.cache_creation_input_tokens.unwrap_or_default()
122            + self.cache_read_input_tokens.unwrap_or_default();
123        usage.output_tokens = self.output_tokens;
124        usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126        Some(usage)
127    }
128}
129
130#[derive(Debug, Deserialize, Serialize)]
131pub struct ToolDefinition {
132    pub name: String,
133    pub description: Option<String>,
134    pub input_schema: serde_json::Value,
135}
136
137/// Cache control directive for Anthropic prompt caching
138#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
139#[serde(tag = "type", rename_all = "snake_case")]
140pub enum CacheControl {
141    Ephemeral,
142}
143
144/// System message content block with optional cache control
145#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
146#[serde(tag = "type", rename_all = "snake_case")]
147pub enum SystemContent {
148    Text {
149        text: String,
150        #[serde(skip_serializing_if = "Option::is_none")]
151        cache_control: Option<CacheControl>,
152    },
153}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156    type Error = CompletionError;
157
158    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159        let content = response
160            .content
161            .iter()
162            .map(|content| content.clone().try_into())
163            .collect::<Result<Vec<_>, _>>()?;
164
165        let choice = OneOrMany::many(content).map_err(|_| {
166            CompletionError::ResponseError(
167                "Response contained no message or tool call (empty)".to_owned(),
168            )
169        })?;
170
171        let usage = completion::Usage {
172            input_tokens: response.usage.input_tokens,
173            output_tokens: response.usage.output_tokens,
174            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175            cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
176        };
177
178        Ok(completion::CompletionResponse {
179            choice,
180            usage,
181            raw_response: response,
182            message_id: None,
183        })
184    }
185}
186
187#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
188pub struct Message {
189    pub role: Role,
190    #[serde(deserialize_with = "string_or_one_or_many")]
191    pub content: OneOrMany<Content>,
192}
193
194#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
195#[serde(rename_all = "lowercase")]
196pub enum Role {
197    User,
198    Assistant,
199}
200
201#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
202#[serde(tag = "type", rename_all = "snake_case")]
203pub enum Content {
204    Text {
205        text: String,
206        #[serde(skip_serializing_if = "Option::is_none")]
207        cache_control: Option<CacheControl>,
208    },
209    Image {
210        source: ImageSource,
211        #[serde(skip_serializing_if = "Option::is_none")]
212        cache_control: Option<CacheControl>,
213    },
214    ToolUse {
215        id: String,
216        name: String,
217        input: serde_json::Value,
218    },
219    ToolResult {
220        tool_use_id: String,
221        #[serde(deserialize_with = "string_or_one_or_many")]
222        content: OneOrMany<ToolResultContent>,
223        #[serde(skip_serializing_if = "Option::is_none")]
224        is_error: Option<bool>,
225        #[serde(skip_serializing_if = "Option::is_none")]
226        cache_control: Option<CacheControl>,
227    },
228    Document {
229        source: DocumentSource,
230        #[serde(skip_serializing_if = "Option::is_none")]
231        cache_control: Option<CacheControl>,
232    },
233    Thinking {
234        thinking: String,
235        #[serde(skip_serializing_if = "Option::is_none")]
236        signature: Option<String>,
237    },
238    RedactedThinking {
239        data: String,
240    },
241}
242
243impl FromStr for Content {
244    type Err = Infallible;
245
246    fn from_str(s: &str) -> Result<Self, Self::Err> {
247        Ok(Content::Text {
248            text: s.to_owned(),
249            cache_control: None,
250        })
251    }
252}
253
254#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
255#[serde(tag = "type", rename_all = "snake_case")]
256pub enum ToolResultContent {
257    Text { text: String },
258    Image(ImageSource),
259}
260
261impl FromStr for ToolResultContent {
262    type Err = Infallible;
263
264    fn from_str(s: &str) -> Result<Self, Self::Err> {
265        Ok(ToolResultContent::Text { text: s.to_owned() })
266    }
267}
268
269/// The source of an image content block.
270///
271/// Anthropic supports two source types for images:
272/// - `Base64`: Base64-encoded image data with media type
273/// - `Url`: URL reference to an image
274///
275/// See: <https://docs.anthropic.com/en/api/messages>
276#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
277#[serde(tag = "type", rename_all = "snake_case")]
278pub enum ImageSource {
279    #[serde(rename = "base64")]
280    Base64 {
281        data: String,
282        media_type: ImageFormat,
283    },
284    #[serde(rename = "url")]
285    Url { url: String },
286}
287
288/// The source of a document content block.
289///
290/// Anthropic supports multiple source types for documents. Currently implemented:
291/// - `Base64`: Base64-encoded document data (used for PDFs)
292/// - `Text`: Plain text document data
293///
294/// Future variants (not yet implemented):
295/// - URL-based PDF sources
296/// - Content block sources
297/// - File API sources
298#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
299#[serde(tag = "type", rename_all = "snake_case")]
300pub enum DocumentSource {
301    Base64 {
302        data: String,
303        media_type: DocumentFormat,
304    },
305    Text {
306        data: String,
307        media_type: PlainTextMediaType,
308    },
309    Url {
310        url: String,
311    },
312}
313
314#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
315#[serde(rename_all = "lowercase")]
316pub enum ImageFormat {
317    #[serde(rename = "image/jpeg")]
318    JPEG,
319    #[serde(rename = "image/png")]
320    PNG,
321    #[serde(rename = "image/gif")]
322    GIF,
323    #[serde(rename = "image/webp")]
324    WEBP,
325}
326
327/// The media type for base64-encoded documents.
328///
329/// Used with the `DocumentSource::Base64` variant. Currently only PDF is supported
330/// for base64-encoded document sources.
331///
332/// See: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
333#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
334#[serde(rename_all = "lowercase")]
335pub enum DocumentFormat {
336    #[serde(rename = "application/pdf")]
337    PDF,
338}
339
340/// The media type for plain text document sources.
341///
342/// Used with the `DocumentSource::Text` variant.
343///
344/// See: <https://docs.anthropic.com/en/api/messages>
345#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
346pub enum PlainTextMediaType {
347    #[serde(rename = "text/plain")]
348    Plain,
349}
350
351#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
352#[serde(rename_all = "lowercase")]
353pub enum SourceType {
354    BASE64,
355    URL,
356    TEXT,
357}
358
359impl From<String> for Content {
360    fn from(text: String) -> Self {
361        Content::Text {
362            text,
363            cache_control: None,
364        }
365    }
366}
367
368impl From<String> for ToolResultContent {
369    fn from(text: String) -> Self {
370        ToolResultContent::Text { text }
371    }
372}
373
374impl TryFrom<message::ContentFormat> for SourceType {
375    type Error = MessageError;
376
377    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
378        match format {
379            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
380            message::ContentFormat::Url => Ok(SourceType::URL),
381            message::ContentFormat::String => Ok(SourceType::TEXT),
382        }
383    }
384}
385
386impl From<SourceType> for message::ContentFormat {
387    fn from(source_type: SourceType) -> Self {
388        match source_type {
389            SourceType::BASE64 => message::ContentFormat::Base64,
390            SourceType::URL => message::ContentFormat::Url,
391            SourceType::TEXT => message::ContentFormat::String,
392        }
393    }
394}
395
396impl TryFrom<message::ImageMediaType> for ImageFormat {
397    type Error = MessageError;
398
399    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
400        Ok(match media_type {
401            message::ImageMediaType::JPEG => ImageFormat::JPEG,
402            message::ImageMediaType::PNG => ImageFormat::PNG,
403            message::ImageMediaType::GIF => ImageFormat::GIF,
404            message::ImageMediaType::WEBP => ImageFormat::WEBP,
405            _ => {
406                return Err(MessageError::ConversionError(
407                    format!("Unsupported image media type: {media_type:?}").to_owned(),
408                ));
409            }
410        })
411    }
412}
413
414impl From<ImageFormat> for message::ImageMediaType {
415    fn from(format: ImageFormat) -> Self {
416        match format {
417            ImageFormat::JPEG => message::ImageMediaType::JPEG,
418            ImageFormat::PNG => message::ImageMediaType::PNG,
419            ImageFormat::GIF => message::ImageMediaType::GIF,
420            ImageFormat::WEBP => message::ImageMediaType::WEBP,
421        }
422    }
423}
424
425impl TryFrom<DocumentMediaType> for DocumentFormat {
426    type Error = MessageError;
427    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
428        match value {
429            DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
430            other => Err(MessageError::ConversionError(format!(
431                "DocumentFormat only supports PDF for base64 sources, got: {}",
432                other.to_mime_type()
433            ))),
434        }
435    }
436}
437
438impl TryFrom<message::AssistantContent> for Content {
439    type Error = MessageError;
440    fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
441        match text {
442            message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
443                text,
444                cache_control: None,
445            }),
446            message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
447                "Anthropic currently doesn't support images.".to_string(),
448            )),
449            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
450                Ok(Content::ToolUse {
451                    id,
452                    name: function.name,
453                    input: function.arguments,
454                })
455            }
456            message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
457                thinking: reasoning.display_text(),
458                signature: reasoning.first_signature().map(str::to_owned),
459            }),
460        }
461    }
462}
463
464fn anthropic_content_from_assistant_content(
465    content: message::AssistantContent,
466) -> Result<Vec<Content>, MessageError> {
467    match content {
468        message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
469            text,
470            cache_control: None,
471        }]),
472        message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
473            "Anthropic currently doesn't support images.".to_string(),
474        )),
475        message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
476            Ok(vec![Content::ToolUse {
477                id,
478                name: function.name,
479                input: function.arguments,
480            }])
481        }
482        message::AssistantContent::Reasoning(reasoning) => {
483            let mut converted = Vec::new();
484            for block in reasoning.content {
485                match block {
486                    message::ReasoningContent::Text { text, signature } => {
487                        converted.push(Content::Thinking {
488                            thinking: text,
489                            signature,
490                        });
491                    }
492                    message::ReasoningContent::Summary(summary) => {
493                        converted.push(Content::Thinking {
494                            thinking: summary,
495                            signature: None,
496                        });
497                    }
498                    message::ReasoningContent::Redacted { data }
499                    | message::ReasoningContent::Encrypted(data) => {
500                        converted.push(Content::RedactedThinking { data });
501                    }
502                }
503            }
504
505            if converted.is_empty() {
506                return Err(MessageError::ConversionError(
507                    "Cannot convert empty reasoning content to Anthropic format".to_string(),
508                ));
509            }
510
511            Ok(converted)
512        }
513    }
514}
515
516impl TryFrom<message::Message> for Message {
517    type Error = MessageError;
518
519    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
520        Ok(match message {
521            message::Message::User { content } => Message {
522                role: Role::User,
523                content: content.try_map(|content| match content {
524                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
525                        text,
526                        cache_control: None,
527                    }),
528                    message::UserContent::ToolResult(message::ToolResult {
529                        id, content, ..
530                    }) => Ok(Content::ToolResult {
531                        tool_use_id: id,
532                        content: content.try_map(|content| match content {
533                            message::ToolResultContent::Text(message::Text { text }) => {
534                                Ok(ToolResultContent::Text { text })
535                            }
536                            message::ToolResultContent::Image(image) => {
537                                let DocumentSourceKind::Base64(data) = image.data else {
538                                    return Err(MessageError::ConversionError(
539                                        "Only base64 strings can be used with the Anthropic API"
540                                            .to_string(),
541                                    ));
542                                };
543                                let media_type =
544                                    image.media_type.ok_or(MessageError::ConversionError(
545                                        "Image media type is required".to_owned(),
546                                    ))?;
547                                Ok(ToolResultContent::Image(ImageSource::Base64 {
548                                    data,
549                                    media_type: media_type.try_into()?,
550                                }))
551                            }
552                        })?,
553                        is_error: None,
554                        cache_control: None,
555                    }),
556                    message::UserContent::Image(message::Image {
557                        data, media_type, ..
558                    }) => {
559                        let source = match data {
560                            DocumentSourceKind::Base64(data) => {
561                                let media_type =
562                                    media_type.ok_or(MessageError::ConversionError(
563                                        "Image media type is required for Claude API".to_string(),
564                                    ))?;
565                                ImageSource::Base64 {
566                                    data,
567                                    media_type: ImageFormat::try_from(media_type)?,
568                                }
569                            }
570                            DocumentSourceKind::Url(url) => ImageSource::Url { url },
571                            DocumentSourceKind::Unknown => {
572                                return Err(MessageError::ConversionError(
573                                    "Image content has no body".into(),
574                                ));
575                            }
576                            doc => {
577                                return Err(MessageError::ConversionError(format!(
578                                    "Unsupported document type: {doc:?}"
579                                )));
580                            }
581                        };
582
583                        Ok(Content::Image {
584                            source,
585                            cache_control: None,
586                        })
587                    }
588                    message::UserContent::Document(message::Document {
589                        data, media_type, ..
590                    }) => {
591                        let media_type = media_type.ok_or(MessageError::ConversionError(
592                            "Document media type is required".to_string(),
593                        ))?;
594
595                        let source = match media_type {
596                            DocumentMediaType::PDF => {
597                                let data = match data {
598                                    DocumentSourceKind::Base64(data)
599                                    | DocumentSourceKind::String(data) => data,
600                                    _ => {
601                                        return Err(MessageError::ConversionError(
602                                            "Only base64 encoded data is supported for PDF documents".into(),
603                                        ));
604                                    }
605                                };
606                                DocumentSource::Base64 {
607                                    data,
608                                    media_type: DocumentFormat::PDF,
609                                }
610                            }
611                            DocumentMediaType::TXT => {
612                                let data = match data {
613                                    DocumentSourceKind::String(data)
614                                    | DocumentSourceKind::Base64(data) => data,
615                                    _ => {
616                                        return Err(MessageError::ConversionError(
617                                            "Only string or base64 data is supported for plain text documents".into(),
618                                        ));
619                                    }
620                                };
621                                DocumentSource::Text {
622                                    data,
623                                    media_type: PlainTextMediaType::Plain,
624                                }
625                            }
626                            other => {
627                                return Err(MessageError::ConversionError(format!(
628                                    "Anthropic only supports PDF and plain text documents, got: {}",
629                                    other.to_mime_type()
630                                )));
631                            }
632                        };
633
634                        Ok(Content::Document {
635                            source,
636                            cache_control: None,
637                        })
638                    }
639                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
640                        "Audio is not supported in Anthropic".to_owned(),
641                    )),
642                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
643                        "Video is not supported in Anthropic".to_owned(),
644                    )),
645                })?,
646            },
647
648            message::Message::Assistant { content, .. } => {
649                let converted_content = content.into_iter().try_fold(
650                    Vec::new(),
651                    |mut accumulated, assistant_content| {
652                        accumulated
653                            .extend(anthropic_content_from_assistant_content(assistant_content)?);
654                        Ok::<Vec<Content>, MessageError>(accumulated)
655                    },
656                )?;
657
658                Message {
659                    content: OneOrMany::many(converted_content).map_err(|_| {
660                        MessageError::ConversionError(
661                            "Assistant message did not contain Anthropic-compatible content"
662                                .to_owned(),
663                        )
664                    })?,
665                    role: Role::Assistant,
666                }
667            }
668        })
669    }
670}
671
672impl TryFrom<Content> for message::AssistantContent {
673    type Error = MessageError;
674
675    fn try_from(content: Content) -> Result<Self, Self::Error> {
676        Ok(match content {
677            Content::Text { text, .. } => message::AssistantContent::text(text),
678            Content::ToolUse { id, name, input } => {
679                message::AssistantContent::tool_call(id, name, input)
680            }
681            Content::Thinking {
682                thinking,
683                signature,
684            } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
685                &thinking, signature,
686            )),
687            Content::RedactedThinking { data } => {
688                message::AssistantContent::Reasoning(Reasoning::redacted(data))
689            }
690            _ => {
691                return Err(MessageError::ConversionError(
692                    "Content did not contain a message, tool call, or reasoning".to_owned(),
693                ));
694            }
695        })
696    }
697}
698
699impl From<ToolResultContent> for message::ToolResultContent {
700    fn from(content: ToolResultContent) -> Self {
701        match content {
702            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
703            ToolResultContent::Image(source) => match source {
704                ImageSource::Base64 { data, media_type } => {
705                    message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
706                }
707                ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
708            },
709        }
710    }
711}
712
713impl TryFrom<Message> for message::Message {
714    type Error = MessageError;
715
716    fn try_from(message: Message) -> Result<Self, Self::Error> {
717        Ok(match message.role {
718            Role::User => message::Message::User {
719                content: message.content.try_map(|content| {
720                    Ok(match content {
721                        Content::Text { text, .. } => message::UserContent::text(text),
722                        Content::ToolResult {
723                            tool_use_id,
724                            content,
725                            ..
726                        } => message::UserContent::tool_result(
727                            tool_use_id,
728                            content.map(|content| content.into()),
729                        ),
730                        Content::Image { source, .. } => match source {
731                            ImageSource::Base64 { data, media_type } => {
732                                message::UserContent::Image(message::Image {
733                                    data: DocumentSourceKind::Base64(data),
734                                    media_type: Some(media_type.into()),
735                                    detail: None,
736                                    additional_params: None,
737                                })
738                            }
739                            ImageSource::Url { url } => {
740                                message::UserContent::Image(message::Image {
741                                    data: DocumentSourceKind::Url(url),
742                                    media_type: None,
743                                    detail: None,
744                                    additional_params: None,
745                                })
746                            }
747                        },
748                        Content::Document { source, .. } => match source {
749                            DocumentSource::Base64 { data, media_type } => {
750                                let rig_media_type = match media_type {
751                                    DocumentFormat::PDF => message::DocumentMediaType::PDF,
752                                };
753                                message::UserContent::document(data, Some(rig_media_type))
754                            }
755                            DocumentSource::Text { data, .. } => message::UserContent::document(
756                                data,
757                                Some(message::DocumentMediaType::TXT),
758                            ),
759                            DocumentSource::Url { url } => {
760                                message::UserContent::document_url(url, None)
761                            }
762                        },
763                        _ => {
764                            return Err(MessageError::ConversionError(
765                                "Unsupported content type for User role".to_owned(),
766                            ));
767                        }
768                    })
769                })?,
770            },
771            Role::Assistant => message::Message::Assistant {
772                id: None,
773                content: message.content.try_map(|content| content.try_into())?,
774            },
775        })
776    }
777}
778
779#[derive(Clone)]
780pub struct CompletionModel<T = reqwest::Client> {
781    pub(crate) client: Client<T>,
782    pub model: String,
783    pub default_max_tokens: Option<u64>,
784    /// Enable automatic prompt caching (adds cache_control breakpoints to system prompt and messages)
785    pub prompt_caching: bool,
786}
787
788impl<T> CompletionModel<T>
789where
790    T: HttpClientExt,
791{
792    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
793        let model = model.into();
794        let default_max_tokens = calculate_max_tokens(&model);
795
796        Self {
797            client,
798            model,
799            default_max_tokens,
800            prompt_caching: false, // Default to off
801        }
802    }
803
804    pub fn with_model(client: Client<T>, model: &str) -> Self {
805        Self {
806            client,
807            model: model.to_string(),
808            default_max_tokens: Some(calculate_max_tokens_custom(model)),
809            prompt_caching: false, // Default to off
810        }
811    }
812
813    /// Enable automatic prompt caching.
814    ///
815    /// When enabled, cache_control breakpoints are automatically added to:
816    /// - The system prompt (marked with ephemeral cache)
817    /// - The last content block of the last message (marked with ephemeral cache)
818    ///
819    /// This allows Anthropic to cache the conversation history for cost savings.
820    pub fn with_prompt_caching(mut self) -> Self {
821        self.prompt_caching = true;
822        self
823    }
824}
825
/// Returns the default `max_tokens` for a known Anthropic model family.
///
/// Anthropic requires a `max_tokens` parameter to be set, which is dependent on the model. If not
/// set or if set too high, the request will fail. The values below are based on the models
/// available at the time of writing.
///
/// Matching is by model-name prefix. Returns `None` when the model is not recognised.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, max output tokens). No prefix is a prefix of another entry,
    // so the order of the rows does not affect the result.
    const KNOWN_LIMITS: [(&str, u64); 8] = [
        ("claude-opus-4", 32000),
        ("claude-sonnet-4", 64000),
        ("claude-3-7-sonnet", 64000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    KNOWN_LIMITS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}

/// Returns the default `max_tokens` for any model name.
///
/// Uses the same table as [`calculate_max_tokens`] and falls back to `2048` for model names
/// that are not recognised, so the two functions cannot drift apart.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    calculate_max_tokens(model).unwrap_or(2048)
}
862
863#[derive(Debug, Deserialize, Serialize)]
864pub struct Metadata {
865    user_id: Option<String>,
866}
867
868#[derive(Default, Debug, Serialize, Deserialize)]
869#[serde(tag = "type", rename_all = "snake_case")]
870pub enum ToolChoice {
871    #[default]
872    Auto,
873    Any,
874    None,
875    Tool {
876        name: String,
877    },
878}
879impl TryFrom<message::ToolChoice> for ToolChoice {
880    type Error = CompletionError;
881
882    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
883        let res = match value {
884            message::ToolChoice::Auto => Self::Auto,
885            message::ToolChoice::None => Self::None,
886            message::ToolChoice::Required => Self::Any,
887            message::ToolChoice::Specific { function_names } => {
888                if function_names.len() != 1 {
889                    return Err(CompletionError::ProviderError(
890                        "Only one tool may be specified to be used by Claude".into(),
891                    ));
892                }
893
894                Self::Tool {
895                    name: function_names.first().unwrap().to_string(),
896                }
897            }
898        };
899
900        Ok(res)
901    }
902}
903
904/// Recursively ensures all object schemas respect Anthropic structured output restrictions:
905/// - `additionalProperties` must be explicitly set to `false` on every object
906/// - All properties must be listed in `required`
907///
908/// Source: <https://docs.anthropic.com/en/docs/build-with-claude/structured-outputs#json-schema-limitations>
909fn sanitize_schema(schema: &mut serde_json::Value) {
910    use serde_json::Value;
911
912    if let Value::Object(obj) = schema {
913        let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
914            || obj.contains_key("properties");
915
916        if is_object_schema && !obj.contains_key("additionalProperties") {
917            obj.insert("additionalProperties".to_string(), Value::Bool(false));
918        }
919
920        if let Some(Value::Object(properties)) = obj.get("properties") {
921            let prop_keys = properties.keys().cloned().map(Value::String).collect();
922            obj.insert("required".to_string(), Value::Array(prop_keys));
923        }
924
925        // Anthropic does not support numerical constraints on integer/number types.
926        let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
927            || obj.get("type") == Some(&Value::String("number".to_string()));
928
929        if is_numeric_schema {
930            for key in [
931                "minimum",
932                "maximum",
933                "exclusiveMinimum",
934                "exclusiveMaximum",
935                "multipleOf",
936            ] {
937                obj.remove(key);
938            }
939        }
940
941        if let Some(defs) = obj.get_mut("$defs")
942            && let Value::Object(defs_obj) = defs
943        {
944            for (_, def_schema) in defs_obj.iter_mut() {
945                sanitize_schema(def_schema);
946            }
947        }
948
949        if let Some(properties) = obj.get_mut("properties")
950            && let Value::Object(props) = properties
951        {
952            for (_, prop_value) in props.iter_mut() {
953                sanitize_schema(prop_value);
954            }
955        }
956
957        if let Some(items) = obj.get_mut("items") {
958            sanitize_schema(items);
959        }
960
961        for key in ["anyOf", "oneOf", "allOf"] {
962            if let Some(variants) = obj.get_mut(key)
963                && let Value::Array(variants_array) = variants
964            {
965                for variant in variants_array.iter_mut() {
966                    sanitize_schema(variant);
967                }
968            }
969        }
970    }
971}
972
973/// Output format specifier for Anthropic's structured output.
974/// Source: <https://docs.anthropic.com/en/api/messages>
975#[derive(Debug, Deserialize, Serialize)]
976#[serde(tag = "type", rename_all = "snake_case")]
977enum OutputFormat {
978    /// Constrains the model's response to conform to the provided JSON schema.
979    JsonSchema { schema: serde_json::Value },
980}
981
982/// Configuration for the model's output format.
983#[derive(Debug, Deserialize, Serialize)]
984struct OutputConfig {
985    format: OutputFormat,
986}
987
988#[derive(Debug, Deserialize, Serialize)]
989struct AnthropicCompletionRequest {
990    model: String,
991    messages: Vec<Message>,
992    max_tokens: u64,
993    /// System prompt as array of content blocks to support cache_control
994    #[serde(skip_serializing_if = "Vec::is_empty")]
995    system: Vec<SystemContent>,
996    #[serde(skip_serializing_if = "Option::is_none")]
997    temperature: Option<f64>,
998    #[serde(skip_serializing_if = "Option::is_none")]
999    tool_choice: Option<ToolChoice>,
1000    #[serde(skip_serializing_if = "Vec::is_empty")]
1001    tools: Vec<ToolDefinition>,
1002    #[serde(skip_serializing_if = "Option::is_none")]
1003    output_config: Option<OutputConfig>,
1004    #[serde(flatten, skip_serializing_if = "Option::is_none")]
1005    additional_params: Option<serde_json::Value>,
1006}
1007
1008/// Helper to set cache_control on a Content block
1009fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1010    match content {
1011        Content::Text { cache_control, .. } => *cache_control = value,
1012        Content::Image { cache_control, .. } => *cache_control = value,
1013        Content::ToolResult { cache_control, .. } => *cache_control = value,
1014        Content::Document { cache_control, .. } => *cache_control = value,
1015        _ => {}
1016    }
1017}
1018
1019/// Apply cache control breakpoints to system prompt and messages.
1020/// Strategy: cache the system prompt, and mark the last content block of the last message
1021/// for caching. This allows the conversation history to be cached while new messages
1022/// are added.
1023pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1024    // Add cache_control to the system prompt (if non-empty)
1025    if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1026        *cache_control = Some(CacheControl::Ephemeral);
1027    }
1028
1029    // Clear any existing cache_control from all message content blocks
1030    for msg in messages.iter_mut() {
1031        for content in msg.content.iter_mut() {
1032            set_content_cache_control(content, None);
1033        }
1034    }
1035
1036    // Add cache_control to the last content block of the last message
1037    if let Some(last_msg) = messages.last_mut() {
1038        set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
1039    }
1040}
1041
1042/// Parameters for building an AnthropicCompletionRequest
1043pub struct AnthropicRequestParams<'a> {
1044    pub model: &'a str,
1045    pub request: CompletionRequest,
1046    pub prompt_caching: bool,
1047}
1048
1049impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1050    type Error = CompletionError;
1051
1052    fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1053        let AnthropicRequestParams {
1054            model,
1055            request: req,
1056            prompt_caching,
1057        } = params;
1058
1059        // Check if max_tokens is set, required for Anthropic
1060        let Some(max_tokens) = req.max_tokens else {
1061            return Err(CompletionError::RequestError(
1062                "`max_tokens` must be set for Anthropic".into(),
1063            ));
1064        };
1065
1066        let mut full_history = vec![];
1067        if let Some(docs) = req.normalized_documents() {
1068            full_history.push(docs);
1069        }
1070        full_history.extend(req.chat_history);
1071
1072        let mut messages = full_history
1073            .into_iter()
1074            .map(Message::try_from)
1075            .collect::<Result<Vec<Message>, _>>()?;
1076
1077        let tools = req
1078            .tools
1079            .into_iter()
1080            .map(|tool| ToolDefinition {
1081                name: tool.name,
1082                description: Some(tool.description),
1083                input_schema: tool.parameters,
1084            })
1085            .collect::<Vec<_>>();
1086
1087        // Convert system prompt to array format for cache_control support
1088        let mut system = if let Some(preamble) = req.preamble {
1089            if preamble.is_empty() {
1090                vec![]
1091            } else {
1092                vec![SystemContent::Text {
1093                    text: preamble,
1094                    cache_control: None,
1095                }]
1096            }
1097        } else {
1098            vec![]
1099        };
1100
1101        // Apply cache control breakpoints only if prompt_caching is enabled
1102        if prompt_caching {
1103            apply_cache_control(&mut system, &mut messages);
1104        }
1105
1106        // Map output_schema to Anthropic's output_config field
1107        let output_config = req.output_schema.map(|schema| {
1108            let mut schema_value = schema.to_value();
1109            sanitize_schema(&mut schema_value);
1110            OutputConfig {
1111                format: OutputFormat::JsonSchema {
1112                    schema: schema_value,
1113                },
1114            }
1115        });
1116
1117        Ok(Self {
1118            model: model.to_string(),
1119            messages,
1120            max_tokens,
1121            system,
1122            temperature: req.temperature,
1123            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1124            tools,
1125            output_config,
1126            additional_params: req.additional_params,
1127        })
1128    }
1129}
1130
1131impl<T> completion::CompletionModel for CompletionModel<T>
1132where
1133    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
1134{
1135    type Response = CompletionResponse;
1136    type StreamingResponse = StreamingCompletionResponse;
1137    type Client = Client<T>;
1138
1139    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1140        Self::new(client.clone(), model.into())
1141    }
1142
1143    async fn completion(
1144        &self,
1145        mut completion_request: completion::CompletionRequest,
1146    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1147        let request_model = completion_request
1148            .model
1149            .clone()
1150            .unwrap_or_else(|| self.model.clone());
1151        let span = if tracing::Span::current().is_disabled() {
1152            info_span!(
1153                target: "rig::completions",
1154                "chat",
1155                gen_ai.operation.name = "chat",
1156                gen_ai.provider.name = "anthropic",
1157                gen_ai.request.model = &request_model,
1158                gen_ai.system_instructions = &completion_request.preamble,
1159                gen_ai.response.id = tracing::field::Empty,
1160                gen_ai.response.model = tracing::field::Empty,
1161                gen_ai.usage.output_tokens = tracing::field::Empty,
1162                gen_ai.usage.input_tokens = tracing::field::Empty,
1163            )
1164        } else {
1165            tracing::Span::current()
1166        };
1167
1168        // Check if max_tokens is set, required for Anthropic
1169        if completion_request.max_tokens.is_none() {
1170            if let Some(tokens) = self.default_max_tokens {
1171                completion_request.max_tokens = Some(tokens);
1172            } else {
1173                return Err(CompletionError::RequestError(
1174                    "`max_tokens` must be set for Anthropic".into(),
1175                ));
1176            }
1177        }
1178
1179        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
1180            model: &request_model,
1181            request: completion_request,
1182            prompt_caching: self.prompt_caching,
1183        })?;
1184
1185        if enabled!(Level::TRACE) {
1186            tracing::trace!(
1187                target: "rig::completions",
1188                "Anthropic completion request: {}",
1189                serde_json::to_string_pretty(&request)?
1190            );
1191        }
1192
1193        async move {
1194            let request: Vec<u8> = serde_json::to_vec(&request)?;
1195
1196            let req = self
1197                .client
1198                .post("/v1/messages")?
1199                .body(request)
1200                .map_err(|e| CompletionError::HttpError(e.into()))?;
1201
1202            let response = self
1203                .client
1204                .send::<_, Bytes>(req)
1205                .await
1206                .map_err(CompletionError::HttpError)?;
1207
1208            if response.status().is_success() {
1209                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
1210                    response
1211                        .into_body()
1212                        .await
1213                        .map_err(CompletionError::HttpError)?
1214                        .to_vec()
1215                        .as_slice(),
1216                )? {
1217                    ApiResponse::Message(completion) => {
1218                        let span = tracing::Span::current();
1219                        span.record_response_metadata(&completion);
1220                        span.record_token_usage(&completion.usage);
1221                        if enabled!(Level::TRACE) {
1222                            tracing::trace!(
1223                                target: "rig::completions",
1224                                "Anthropic completion response: {}",
1225                                serde_json::to_string_pretty(&completion)?
1226                            );
1227                        }
1228                        completion.try_into()
1229                    }
1230                    ApiResponse::Error(ApiErrorResponse { message }) => {
1231                        Err(CompletionError::ResponseError(message))
1232                    }
1233                }
1234            } else {
1235                let text: String = String::from_utf8_lossy(
1236                    &response
1237                        .into_body()
1238                        .await
1239                        .map_err(CompletionError::HttpError)?,
1240                )
1241                .into();
1242                Err(CompletionError::ProviderError(text))
1243            }
1244        }
1245        .instrument(span)
1246        .await
1247    }
1248
1249    async fn stream(
1250        &self,
1251        request: CompletionRequest,
1252    ) -> Result<
1253        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1254        CompletionError,
1255    > {
1256        CompletionModel::stream(self, request).await
1257    }
1258}
1259
1260#[derive(Debug, Deserialize)]
1261struct ApiErrorResponse {
1262    message: String,
1263}
1264
1265#[derive(Debug, Deserialize)]
1266#[serde(tag = "type", rename_all = "snake_case")]
1267enum ApiResponse<T> {
1268    Message(T),
1269    Error(ApiErrorResponse),
1270}
1271
1272#[cfg(test)]
1273mod tests {
1274    use super::*;
1275    use serde_json::json;
1276    use serde_path_to_error::deserialize;
1277
1278    #[test]
1279    fn test_deserialize_message() {
1280        let assistant_message_json = r#"
1281        {
1282            "role": "assistant",
1283            "content": "\n\nHello there, how may I assist you today?"
1284        }
1285        "#;
1286
1287        let assistant_message_json2 = r#"
1288        {
1289            "role": "assistant",
1290            "content": [
1291                {
1292                    "type": "text",
1293                    "text": "\n\nHello there, how may I assist you today?"
1294                },
1295                {
1296                    "type": "tool_use",
1297                    "id": "toolu_01A09q90qw90lq917835lq9",
1298                    "name": "get_weather",
1299                    "input": {"location": "San Francisco, CA"}
1300                }
1301            ]
1302        }
1303        "#;
1304
1305        let user_message_json = r#"
1306        {
1307            "role": "user",
1308            "content": [
1309                {
1310                    "type": "image",
1311                    "source": {
1312                        "type": "base64",
1313                        "media_type": "image/jpeg",
1314                        "data": "/9j/4AAQSkZJRg..."
1315                    }
1316                },
1317                {
1318                    "type": "text",
1319                    "text": "What is in this image?"
1320                },
1321                {
1322                    "type": "tool_result",
1323                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
1324                    "content": "15 degrees"
1325                }
1326            ]
1327        }
1328        "#;
1329
1330        let assistant_message: Message = {
1331            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1332            deserialize(jd).unwrap_or_else(|err| {
1333                panic!("Deserialization error at {}: {}", err.path(), err);
1334            })
1335        };
1336
1337        let assistant_message2: Message = {
1338            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1339            deserialize(jd).unwrap_or_else(|err| {
1340                panic!("Deserialization error at {}: {}", err.path(), err);
1341            })
1342        };
1343
1344        let user_message: Message = {
1345            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1346            deserialize(jd).unwrap_or_else(|err| {
1347                panic!("Deserialization error at {}: {}", err.path(), err);
1348            })
1349        };
1350
1351        let Message { role, content } = assistant_message;
1352        assert_eq!(role, Role::Assistant);
1353        assert_eq!(
1354            content.first(),
1355            Content::Text {
1356                text: "\n\nHello there, how may I assist you today?".to_owned(),
1357                cache_control: None,
1358            }
1359        );
1360
1361        let Message { role, content } = assistant_message2;
1362        {
1363            assert_eq!(role, Role::Assistant);
1364            assert_eq!(content.len(), 2);
1365
1366            let mut iter = content.into_iter();
1367
1368            match iter.next().unwrap() {
1369                Content::Text { text, .. } => {
1370                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
1371                }
1372                _ => panic!("Expected text content"),
1373            }
1374
1375            match iter.next().unwrap() {
1376                Content::ToolUse { id, name, input } => {
1377                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1378                    assert_eq!(name, "get_weather");
1379                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
1380                }
1381                _ => panic!("Expected tool use content"),
1382            }
1383
1384            assert_eq!(iter.next(), None);
1385        }
1386
1387        let Message { role, content } = user_message;
1388        {
1389            assert_eq!(role, Role::User);
1390            assert_eq!(content.len(), 3);
1391
1392            let mut iter = content.into_iter();
1393
1394            match iter.next().unwrap() {
1395                Content::Image { source, .. } => {
1396                    assert_eq!(
1397                        source,
1398                        ImageSource::Base64 {
1399                            data: "/9j/4AAQSkZJRg...".to_owned(),
1400                            media_type: ImageFormat::JPEG,
1401                        }
1402                    );
1403                }
1404                _ => panic!("Expected image content"),
1405            }
1406
1407            match iter.next().unwrap() {
1408                Content::Text { text, .. } => {
1409                    assert_eq!(text, "What is in this image?");
1410                }
1411                _ => panic!("Expected text content"),
1412            }
1413
1414            match iter.next().unwrap() {
1415                Content::ToolResult {
1416                    tool_use_id,
1417                    content,
1418                    is_error,
1419                    ..
1420                } => {
1421                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1422                    assert_eq!(
1423                        content.first(),
1424                        ToolResultContent::Text {
1425                            text: "15 degrees".to_owned()
1426                        }
1427                    );
1428                    assert_eq!(is_error, None);
1429                }
1430                _ => panic!("Expected tool result content"),
1431            }
1432
1433            assert_eq!(iter.next(), None);
1434        }
1435    }
1436
1437    #[test]
1438    fn test_message_to_message_conversion() {
1439        let user_message: Message = serde_json::from_str(
1440            r#"
1441        {
1442            "role": "user",
1443            "content": [
1444                {
1445                    "type": "image",
1446                    "source": {
1447                        "type": "base64",
1448                        "media_type": "image/jpeg",
1449                        "data": "/9j/4AAQSkZJRg..."
1450                    }
1451                },
1452                {
1453                    "type": "text",
1454                    "text": "What is in this image?"
1455                },
1456                {
1457                    "type": "document",
1458                    "source": {
1459                        "type": "base64",
1460                        "data": "base64_encoded_pdf_data",
1461                        "media_type": "application/pdf"
1462                    }
1463                }
1464            ]
1465        }
1466        "#,
1467        )
1468        .unwrap();
1469
1470        let assistant_message = Message {
1471            role: Role::Assistant,
1472            content: OneOrMany::one(Content::ToolUse {
1473                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1474                name: "get_weather".to_string(),
1475                input: json!({"location": "San Francisco, CA"}),
1476            }),
1477        };
1478
1479        let tool_message = Message {
1480            role: Role::User,
1481            content: OneOrMany::one(Content::ToolResult {
1482                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1483                content: OneOrMany::one(ToolResultContent::Text {
1484                    text: "15 degrees".to_string(),
1485                }),
1486                is_error: None,
1487                cache_control: None,
1488            }),
1489        };
1490
1491        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1492        let converted_assistant_message: message::Message =
1493            assistant_message.clone().try_into().unwrap();
1494        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1495
1496        match converted_user_message.clone() {
1497            message::Message::User { content } => {
1498                assert_eq!(content.len(), 3);
1499
1500                let mut iter = content.into_iter();
1501
1502                match iter.next().unwrap() {
1503                    message::UserContent::Image(message::Image {
1504                        data, media_type, ..
1505                    }) => {
1506                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1507                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1508                    }
1509                    _ => panic!("Expected image content"),
1510                }
1511
1512                match iter.next().unwrap() {
1513                    message::UserContent::Text(message::Text { text }) => {
1514                        assert_eq!(text, "What is in this image?");
1515                    }
1516                    _ => panic!("Expected text content"),
1517                }
1518
1519                match iter.next().unwrap() {
1520                    message::UserContent::Document(message::Document {
1521                        data, media_type, ..
1522                    }) => {
1523                        assert_eq!(
1524                            data,
1525                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
1526                        );
1527                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1528                    }
1529                    _ => panic!("Expected document content"),
1530                }
1531
1532                assert_eq!(iter.next(), None);
1533            }
1534            _ => panic!("Expected user message"),
1535        }
1536
1537        match converted_tool_message.clone() {
1538            message::Message::User { content } => {
1539                let message::ToolResult { id, content, .. } = match content.first() {
1540                    message::UserContent::ToolResult(tool_result) => tool_result,
1541                    _ => panic!("Expected tool result content"),
1542                };
1543                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1544                match content.first() {
1545                    message::ToolResultContent::Text(message::Text { text }) => {
1546                        assert_eq!(text, "15 degrees");
1547                    }
1548                    _ => panic!("Expected text content"),
1549                }
1550            }
1551            _ => panic!("Expected tool result content"),
1552        }
1553
1554        match converted_assistant_message.clone() {
1555            message::Message::Assistant { content, .. } => {
1556                assert_eq!(content.len(), 1);
1557
1558                match content.first() {
1559                    message::AssistantContent::ToolCall(message::ToolCall {
1560                        id, function, ..
1561                    }) => {
1562                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1563                        assert_eq!(function.name, "get_weather");
1564                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1565                    }
1566                    _ => panic!("Expected tool call content"),
1567                }
1568            }
1569            _ => panic!("Expected assistant message"),
1570        }
1571
1572        let original_user_message: Message = converted_user_message.try_into().unwrap();
1573        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1574        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1575
1576        assert_eq!(user_message, original_user_message);
1577        assert_eq!(assistant_message, original_assistant_message);
1578        assert_eq!(tool_message, original_tool_message);
1579    }
1580
1581    #[test]
1582    fn test_content_format_conversion() {
1583        use crate::completion::message::ContentFormat;
1584
1585        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1586        assert_eq!(source_type, SourceType::URL);
1587
1588        let content_format: ContentFormat = SourceType::URL.into();
1589        assert_eq!(content_format, ContentFormat::Url);
1590
1591        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1592        assert_eq!(source_type, SourceType::BASE64);
1593
1594        let content_format: ContentFormat = SourceType::BASE64.into();
1595        assert_eq!(content_format, ContentFormat::Base64);
1596
1597        let source_type: SourceType = ContentFormat::String.try_into().unwrap();
1598        assert_eq!(source_type, SourceType::TEXT);
1599
1600        let content_format: ContentFormat = SourceType::TEXT.into();
1601        assert_eq!(content_format, ContentFormat::String);
1602    }
1603
/// Checks the JSON shape of `cache_control` on system and message content, and
/// which blocks `apply_cache_control` marks in a two-message conversation.
#[test]
fn test_cache_control_serialization() {
    // A system block with a cache marker serializes it alongside its `text` tag.
    let system = SystemContent::Text {
        text: "You are a helpful assistant.".to_string(),
        cache_control: Some(CacheControl::Ephemeral),
    };
    let json = serde_json::to_string(&system).unwrap();
    assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
    assert!(json.contains(r#""type":"text""#));

    // Without a marker the field must be absent from the output altogether.
    let system_no_cache = SystemContent::Text {
        text: "Hello".to_string(),
        cache_control: None,
    };
    let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
    assert!(!json_no_cache.contains("cache_control"));

    // Message content uses the same wire format for the marker.
    let content = Content::Text {
        text: "Test message".to_string(),
        cache_control: Some(CacheControl::Ephemeral),
    };
    let json_content = serde_json::to_string(&content).unwrap();
    assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

    // Exercise `apply_cache_control` on an unmarked system prompt and conversation.
    let mut system_vec = vec![SystemContent::Text {
        text: "System prompt".to_string(),
        cache_control: None,
    }];
    let mut messages = vec![
        Message {
            role: Role::User,
            content: OneOrMany::one(Content::Text {
                text: "First message".to_string(),
                cache_control: None,
            }),
        },
        Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::Text {
                text: "Response".to_string(),
                cache_control: None,
            }),
        },
    ];

    apply_cache_control(&mut system_vec, &mut messages);

    // The system prompt should now carry a marker.
    match &system_vec[0] {
        SystemContent::Text { cache_control, .. } => {
            assert!(cache_control.is_some());
        }
    }

    // The first message must stay unmarked. Any non-text block fails the test
    // rather than being skipped, so this check can never pass vacuously.
    for content in messages[0].content.iter() {
        match content {
            Content::Text { cache_control, .. } => assert!(cache_control.is_none()),
            other => panic!("Expected Text content, got: {other:?}"),
        }
    }

    // The last message must be marked; again, unexpected variants are an error.
    for content in messages[1].content.iter() {
        match content {
            Content::Text { cache_control, .. } => assert!(cache_control.is_some()),
            other => panic!("Expected Text content, got: {other:?}"),
        }
    }
}
1677
/// A plain-text document block serializes as a `document` whose source is
/// tagged `text` with the `text/plain` media type.
#[test]
fn test_plaintext_document_serialization() {
    let source = DocumentSource::Text {
        data: "Hello, world!".to_string(),
        media_type: PlainTextMediaType::Plain,
    };
    let document = Content::Document {
        source,
        cache_control: None,
    };

    let serialized = serde_json::to_value(&document).unwrap();
    let serialized_source = &serialized["source"];

    assert_eq!(serialized["type"], "document");
    assert_eq!(serialized_source["type"], "text");
    assert_eq!(serialized_source["media_type"], "text/plain");
    assert_eq!(serialized_source["data"], "Hello, world!");
}
1694
/// A `document` block with a `text` source parses into `Content::Document`
/// holding a plain-text source and no cache marker.
#[test]
fn test_plaintext_document_deserialization() {
    let payload = r#"
        {
            "type": "document",
            "source": {
                "type": "text",
                "media_type": "text/plain",
                "data": "Hello, world!"
            }
        }
        "#;
    let expected_source = DocumentSource::Text {
        data: "Hello, world!".to_string(),
        media_type: PlainTextMediaType::Plain,
    };

    let parsed: Content = serde_json::from_str(payload).unwrap();

    // Any variant other than `Document` is a test failure.
    let Content::Document {
        source,
        cache_control,
    } = parsed
    else {
        panic!("Expected Document content");
    };

    assert_eq!(source, expected_source);
    assert_eq!(cache_control, None);
}
1726
/// A base64 PDF document block serializes with a `base64` source tagged
/// `application/pdf`.
#[test]
fn test_base64_pdf_document_serialization() {
    let source = DocumentSource::Base64 {
        data: "base64data".to_string(),
        media_type: DocumentFormat::PDF,
    };
    let document = Content::Document {
        source,
        cache_control: None,
    };

    let serialized = serde_json::to_value(&document).unwrap();
    let serialized_source = &serialized["source"];

    assert_eq!(serialized["type"], "document");
    assert_eq!(serialized_source["type"], "base64");
    assert_eq!(serialized_source["media_type"], "application/pdf");
    assert_eq!(serialized_source["data"], "base64data");
}
1743
/// A `document` block with a `base64` PDF source parses into
/// `Content::Document` holding the matching `DocumentSource::Base64`.
#[test]
fn test_base64_pdf_document_deserialization() {
    let payload = r#"
        {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "application/pdf",
                "data": "base64data"
            }
        }
        "#;
    let expected_source = DocumentSource::Base64 {
        data: "base64data".to_string(),
        media_type: DocumentFormat::PDF,
    };

    let parsed: Content = serde_json::from_str(payload).unwrap();

    // Any variant other than `Document` is a test failure.
    let Content::Document { source, .. } = parsed else {
        panic!("Expected Document content");
    };

    assert_eq!(source, expected_source);
}
1771
/// A rig user message holding a TXT document converts into an Anthropic user
/// message whose first block is a plain-text document with the same text.
#[test]
fn test_plaintext_rig_to_anthropic_conversion() {
    use crate::completion::message as msg;

    let document = msg::UserContent::document(
        "Some plain text content".to_string(),
        Some(msg::DocumentMediaType::TXT),
    );
    let rig_message = msg::Message::User {
        content: OneOrMany::one(document),
    };
    let expected_source = DocumentSource::Text {
        data: "Some plain text content".to_string(),
        media_type: PlainTextMediaType::Plain,
    };

    let anthropic_message: Message = rig_message.try_into().unwrap();
    assert_eq!(anthropic_message.role, Role::User);

    // Only the first converted block is inspected.
    match anthropic_message.content.into_iter().next().unwrap() {
        Content::Document { source, .. } => assert_eq!(source, expected_source),
        other => panic!("Expected Document content, got: {other:?}"),
    }
}
1800
/// An Anthropic user message holding a plain-text document converts into a rig
/// user message whose first block is a TXT document carrying the text as a string.
#[test]
fn test_plaintext_anthropic_to_rig_conversion() {
    use crate::completion::message as msg;

    let document = Content::Document {
        source: DocumentSource::Text {
            data: "Some plain text content".to_string(),
            media_type: PlainTextMediaType::Plain,
        },
        cache_control: None,
    };
    let anthropic_message = Message {
        role: Role::User,
        content: OneOrMany::one(document),
    };

    let rig_message: msg::Message = anthropic_message.try_into().unwrap();

    // The converted message has to stay on the user side.
    let msg::Message::User { content } = rig_message else {
        panic!("Expected User message");
    };

    // Only the first converted block is inspected.
    match content.into_iter().next().unwrap() {
        msg::UserContent::Document(msg::Document {
            data, media_type, ..
        }) => {
            assert_eq!(
                data,
                DocumentSourceKind::String("Some plain text content".into())
            );
            assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
        }
        other => panic!("Expected Document content, got: {other:?}"),
    }
}
1836
/// Converts a rig TXT document message to the Anthropic representation and
/// back, then checks that both the media type and the document text survive.
#[test]
fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
    use crate::completion::message as msg;

    let original = msg::Message::User {
        content: OneOrMany::one(msg::UserContent::document(
            "Round trip text".to_string(),
            Some(msg::DocumentMediaType::TXT),
        )),
    };

    let anthropic: Message = original.clone().try_into().unwrap();
    let back: msg::Message = anthropic.try_into().unwrap();

    match (&original, &back) {
        (
            msg::Message::User {
                content: orig_content,
            },
            msg::Message::User {
                content: back_content,
            },
        ) => match (orig_content.first(), back_content.first()) {
            (
                msg::UserContent::Document(msg::Document {
                    media_type: orig_mt,
                    ..
                }),
                msg::UserContent::Document(msg::Document {
                    data: back_data,
                    media_type: back_mt,
                    ..
                }),
            ) => {
                assert_eq!(orig_mt, back_mt);
                // Comparing media types alone would not notice a lost or altered
                // payload, so also require the text to come back as string data.
                assert!(
                    matches!(
                        &back_data,
                        DocumentSourceKind::String(text) if text == "Round trip text"
                    ),
                    "Document text was not preserved, got: {back_data:?}"
                );
            }
            _ => panic!("Expected Document content in both"),
        },
        _ => panic!("Expected User messages"),
    }
}
1877
/// Converting a rig message with an HTML document must fail with the
/// "PDF and plain text only" error.
#[test]
fn test_unsupported_document_type_returns_error() {
    use crate::completion::message as msg;

    let html_document = msg::Document {
        data: DocumentSourceKind::String("data".into()),
        media_type: Some(msg::DocumentMediaType::HTML),
        additional_params: None,
    };
    let rig_message = msg::Message::User {
        content: OneOrMany::one(msg::UserContent::Document(html_document)),
    };

    let conversion: Result<Message, _> = rig_message.try_into();
    assert!(conversion.is_err());

    let err = conversion.unwrap_err().to_string();
    let expected_fragment = "Anthropic only supports PDF and plain text documents";
    assert!(err.contains(expected_fragment), "Unexpected error: {err}");
}
1898
/// Converting a rig TXT document whose data is a URL must fail with the
/// "string or base64 only" error.
#[test]
fn test_plaintext_document_url_source_returns_error() {
    use crate::completion::message as msg;

    let url_document = msg::Document {
        data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
        media_type: Some(msg::DocumentMediaType::TXT),
        additional_params: None,
    };
    let rig_message = msg::Message::User {
        content: OneOrMany::one(msg::UserContent::Document(url_document)),
    };

    let conversion: Result<Message, _> = rig_message.try_into();
    assert!(conversion.is_err());

    let err = conversion.unwrap_err().to_string();
    let expected_fragment = "Only string or base64 data is supported for plain text documents";
    assert!(err.contains(expected_fragment), "Unexpected error: {err}");
}
1919
/// A plain-text document block with an ephemeral cache marker serializes both
/// the `text` source and the `cache_control` object.
#[test]
fn test_plaintext_document_with_cache_control() {
    let source = DocumentSource::Text {
        data: "cached text".to_string(),
        media_type: PlainTextMediaType::Plain,
    };
    let document = Content::Document {
        source,
        cache_control: Some(CacheControl::Ephemeral),
    };

    let serialized = serde_json::to_value(&document).unwrap();
    let serialized_source = &serialized["source"];

    assert_eq!(serialized_source["type"], "text");
    assert_eq!(serialized_source["media_type"], "text/plain");
    assert_eq!(serialized["cache_control"]["type"], "ephemeral");
}
1935
/// A user message with a plain-text document followed by a text block parses
/// into two content blocks in that order.
#[test]
fn test_message_with_plaintext_document_deserialization() {
    let payload = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "Hello from a text file"
                    }
                },
                {
                    "type": "text",
                    "text": "Summarize this document."
                }
            ]
        }
        "#;
    let expected_source = DocumentSource::Text {
        data: "Hello from a text file".to_string(),
        media_type: PlainTextMediaType::Plain,
    };

    let message: Message = serde_json::from_str(payload).unwrap();
    assert_eq!(message.role, Role::User);
    assert_eq!(message.content.len(), 2);

    let mut blocks = message.content.into_iter();

    // First block: the document.
    let Content::Document { source, .. } = blocks.next().unwrap() else {
        panic!("Expected Document content");
    };
    assert_eq!(source, expected_source);

    // Second block: the text prompt.
    let Content::Text { text, .. } = blocks.next().unwrap() else {
        panic!("Expected Text content");
    };
    assert_eq!(text, "Summarize this document.");
}
1984
/// Each reasoning block of an assistant message maps, in order, onto its own
/// Anthropic content block: text and summary become `Thinking`, redacted
/// becomes `RedactedThinking`.
#[test]
fn test_assistant_reasoning_multiblock_to_anthropic_content() {
    let reasoning_blocks = vec![
        message::ReasoningContent::Text {
            text: "step one".to_string(),
            signature: Some("sig-1".to_string()),
        },
        message::ReasoningContent::Summary("summary".to_string()),
        message::ReasoningContent::Text {
            text: "step two".to_string(),
            signature: Some("sig-2".to_string()),
        },
        message::ReasoningContent::Redacted {
            data: "redacted block".to_string(),
        },
    ];
    let assistant = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::Reasoning(message::Reasoning {
            id: None,
            content: reasoning_blocks,
        })),
    };

    let converted: Message = assistant.try_into().expect("convert assistant message");
    let produced: Vec<_> = converted.content.iter().cloned().collect();

    assert_eq!(converted.role, Role::Assistant);
    assert_eq!(produced.len(), 4);

    // Signed text keeps both its text and its signature.
    assert!(matches!(
        produced.first(),
        Some(Content::Thinking { thinking, signature: Some(signature) })
            if thinking == "step one" && signature == "sig-1"
    ));
    // A summary becomes an unsigned thinking block.
    assert!(matches!(
        produced.get(1),
        Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
    ));
    assert!(matches!(
        produced.get(2),
        Some(Content::Thinking { thinking, signature: Some(signature) })
            if thinking == "step two" && signature == "sig-2"
    ));
    // Redacted reasoning is passed through as redacted thinking.
    assert!(matches!(
        produced.get(3),
        Some(Content::RedactedThinking { data }) if data == "redacted block"
    ));
}
2033
/// An Anthropic `RedactedThinking` block converts into assistant reasoning
/// whose first entry is redacted content with the same data.
#[test]
fn test_redacted_thinking_content_to_assistant_reasoning() {
    let redacted = Content::RedactedThinking {
        data: "opaque-redacted".to_string(),
    };

    let converted: message::AssistantContent =
        redacted.try_into().expect("convert redacted thinking");

    // Anything other than reasoning with a matching redacted first entry fails.
    let is_redacted_reasoning = match &converted {
        message::AssistantContent::Reasoning(message::Reasoning { content, .. }) => matches!(
            content.first(),
            Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
        ),
        _ => false,
    };
    assert!(is_redacted_reasoning);
}
2051
/// Encrypted reasoning on an assistant message converts into a single
/// Anthropic `RedactedThinking` block carrying the same payload.
#[test]
fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
    let encrypted = message::ReasoningContent::Encrypted("ciphertext".to_string());
    let assistant = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::Reasoning(message::Reasoning {
            id: None,
            content: vec![encrypted],
        })),
    };

    let converted: Message = assistant.try_into().expect("convert assistant message");
    let produced: Vec<_> = converted.content.iter().cloned().collect();

    assert_eq!(produced.len(), 1);
    assert!(matches!(
        produced.first(),
        Some(Content::RedactedThinking { data }) if data == "ciphertext"
    ));
}
2074}