Skip to main content

rig/providers/anthropic/
completion.rs

//! Anthropic completion API implementation.

3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
8    one_or_many::string_or_one_or_many,
9    telemetry::{ProviderResponseExt, SpanCombinator},
10    wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
// ================================================================
// Anthropic Completion API
// ================================================================

/// `claude-opus-4-0` completion model
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// `claude-sonnet-4-0` completion model
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version string `2023-01-01`.
///
/// NOTE(review): presumably sent as the `anthropic-version` request header by the
/// client; confirm against the client module.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version string defined here; an alias for
/// [`ANTHROPIC_VERSION_2023_06_01`].
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
40#[derive(Debug, Deserialize, Serialize)]
41pub struct CompletionResponse {
42    pub content: Vec<Content>,
43    pub id: String,
44    pub model: String,
45    pub role: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51impl ProviderResponseExt for CompletionResponse {
52    type OutputMessage = Content;
53    type Usage = Usage;
54
55    fn get_response_id(&self) -> Option<String> {
56        Some(self.id.to_owned())
57    }
58
59    fn get_response_model_name(&self) -> Option<String> {
60        Some(self.model.to_owned())
61    }
62
63    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64        self.content.clone()
65    }
66
67    fn get_text_response(&self) -> Option<String> {
68        let res = self
69            .content
70            .iter()
71            .filter_map(|x| {
72                if let Content::Text { text, .. } = x {
73                    Some(text.to_owned())
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<String>>()
79            .join("\n");
80
81        if res.is_empty() { None } else { Some(res) }
82    }
83
84    fn get_usage(&self) -> Option<Self::Usage> {
85        Some(self.usage.clone())
86    }
87}
88
89#[derive(Clone, Debug, Deserialize, Serialize)]
90pub struct Usage {
91    pub input_tokens: u64,
92    pub cache_read_input_tokens: Option<u64>,
93    pub cache_creation_input_tokens: Option<u64>,
94    pub output_tokens: u64,
95}
96
97impl std::fmt::Display for Usage {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(
100            f,
101            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102            self.input_tokens,
103            match self.cache_read_input_tokens {
104                Some(token) => token.to_string(),
105                None => "n/a".to_string(),
106            },
107            match self.cache_creation_input_tokens {
108                Some(token) => token.to_string(),
109                None => "n/a".to_string(),
110            },
111            self.output_tokens
112        )
113    }
114}
115
116impl GetTokenUsage for Usage {
117    fn token_usage(&self) -> Option<crate::completion::Usage> {
118        let mut usage = crate::completion::Usage::new();
119
120        usage.input_tokens = self.input_tokens
121            + self.cache_creation_input_tokens.unwrap_or_default()
122            + self.cache_read_input_tokens.unwrap_or_default();
123        usage.output_tokens = self.output_tokens;
124        usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126        Some(usage)
127    }
128}
129
130#[derive(Debug, Deserialize, Serialize)]
131pub struct ToolDefinition {
132    pub name: String,
133    pub description: Option<String>,
134    pub input_schema: serde_json::Value,
135}
136
137/// Cache control directive for Anthropic prompt caching
138#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
139#[serde(tag = "type", rename_all = "snake_case")]
140pub enum CacheControl {
141    Ephemeral,
142}
143
144/// System message content block with optional cache control
145#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
146#[serde(tag = "type", rename_all = "snake_case")]
147pub enum SystemContent {
148    Text {
149        text: String,
150        #[serde(skip_serializing_if = "Option::is_none")]
151        cache_control: Option<CacheControl>,
152    },
153}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156    type Error = CompletionError;
157
158    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159        let content = response
160            .content
161            .iter()
162            .map(|content| content.clone().try_into())
163            .collect::<Result<Vec<_>, _>>()?;
164
165        let choice = OneOrMany::many(content).map_err(|_| {
166            CompletionError::ResponseError(
167                "Response contained no message or tool call (empty)".to_owned(),
168            )
169        })?;
170
171        let usage = completion::Usage {
172            input_tokens: response.usage.input_tokens,
173            output_tokens: response.usage.output_tokens,
174            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175            cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
176        };
177
178        Ok(completion::CompletionResponse {
179            choice,
180            usage,
181            raw_response: response,
182            message_id: None,
183        })
184    }
185}
186
187#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
188pub struct Message {
189    pub role: Role,
190    #[serde(deserialize_with = "string_or_one_or_many")]
191    pub content: OneOrMany<Content>,
192}
193
194#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
195#[serde(rename_all = "lowercase")]
196pub enum Role {
197    User,
198    Assistant,
199}
200
201#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
202#[serde(tag = "type", rename_all = "snake_case")]
203pub enum Content {
204    Text {
205        text: String,
206        #[serde(skip_serializing_if = "Option::is_none")]
207        cache_control: Option<CacheControl>,
208    },
209    Image {
210        source: ImageSource,
211        #[serde(skip_serializing_if = "Option::is_none")]
212        cache_control: Option<CacheControl>,
213    },
214    ToolUse {
215        id: String,
216        name: String,
217        input: serde_json::Value,
218    },
219    ToolResult {
220        tool_use_id: String,
221        #[serde(deserialize_with = "string_or_one_or_many")]
222        content: OneOrMany<ToolResultContent>,
223        #[serde(skip_serializing_if = "Option::is_none")]
224        is_error: Option<bool>,
225        #[serde(skip_serializing_if = "Option::is_none")]
226        cache_control: Option<CacheControl>,
227    },
228    Document {
229        source: DocumentSource,
230        #[serde(skip_serializing_if = "Option::is_none")]
231        cache_control: Option<CacheControl>,
232    },
233    Thinking {
234        thinking: String,
235        #[serde(skip_serializing_if = "Option::is_none")]
236        signature: Option<String>,
237    },
238    RedactedThinking {
239        data: String,
240    },
241}
242
243impl FromStr for Content {
244    type Err = Infallible;
245
246    fn from_str(s: &str) -> Result<Self, Self::Err> {
247        Ok(Content::Text {
248            text: s.to_owned(),
249            cache_control: None,
250        })
251    }
252}
253
254#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
255#[serde(tag = "type", rename_all = "snake_case")]
256pub enum ToolResultContent {
257    Text { text: String },
258    Image(ImageSource),
259}
260
261impl FromStr for ToolResultContent {
262    type Err = Infallible;
263
264    fn from_str(s: &str) -> Result<Self, Self::Err> {
265        Ok(ToolResultContent::Text { text: s.to_owned() })
266    }
267}
268
269/// The source of an image content block.
270///
271/// Anthropic supports two source types for images:
272/// - `Base64`: Base64-encoded image data with media type
273/// - `Url`: URL reference to an image
274///
275/// See: <https://docs.anthropic.com/en/api/messages>
276#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
277#[serde(tag = "type", rename_all = "snake_case")]
278pub enum ImageSource {
279    #[serde(rename = "base64")]
280    Base64 {
281        data: String,
282        media_type: ImageFormat,
283    },
284    #[serde(rename = "url")]
285    Url { url: String },
286}
287
288/// The source of a document content block.
289///
290/// Anthropic supports multiple source types for documents. Currently implemented:
291/// - `Base64`: Base64-encoded document data (used for PDFs)
292/// - `Text`: Plain text document data
293///
294/// Future variants (not yet implemented):
295/// - URL-based PDF sources
296/// - Content block sources
297/// - File API sources
298#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
299#[serde(tag = "type", rename_all = "snake_case")]
300pub enum DocumentSource {
301    Base64 {
302        data: String,
303        media_type: DocumentFormat,
304    },
305    Text {
306        data: String,
307        media_type: PlainTextMediaType,
308    },
309}
310
311#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
312#[serde(rename_all = "lowercase")]
313pub enum ImageFormat {
314    #[serde(rename = "image/jpeg")]
315    JPEG,
316    #[serde(rename = "image/png")]
317    PNG,
318    #[serde(rename = "image/gif")]
319    GIF,
320    #[serde(rename = "image/webp")]
321    WEBP,
322}
323
324/// The media type for base64-encoded documents.
325///
326/// Used with the `DocumentSource::Base64` variant. Currently only PDF is supported
327/// for base64-encoded document sources.
328///
329/// See: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
330#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
331#[serde(rename_all = "lowercase")]
332pub enum DocumentFormat {
333    #[serde(rename = "application/pdf")]
334    PDF,
335}
336
337/// The media type for plain text document sources.
338///
339/// Used with the `DocumentSource::Text` variant.
340///
341/// See: <https://docs.anthropic.com/en/api/messages>
342#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
343pub enum PlainTextMediaType {
344    #[serde(rename = "text/plain")]
345    Plain,
346}
347
348#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
349#[serde(rename_all = "lowercase")]
350pub enum SourceType {
351    BASE64,
352    URL,
353    TEXT,
354}
355
356impl From<String> for Content {
357    fn from(text: String) -> Self {
358        Content::Text {
359            text,
360            cache_control: None,
361        }
362    }
363}
364
365impl From<String> for ToolResultContent {
366    fn from(text: String) -> Self {
367        ToolResultContent::Text { text }
368    }
369}
370
371impl TryFrom<message::ContentFormat> for SourceType {
372    type Error = MessageError;
373
374    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
375        match format {
376            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
377            message::ContentFormat::Url => Ok(SourceType::URL),
378            message::ContentFormat::String => Ok(SourceType::TEXT),
379        }
380    }
381}
382
383impl From<SourceType> for message::ContentFormat {
384    fn from(source_type: SourceType) -> Self {
385        match source_type {
386            SourceType::BASE64 => message::ContentFormat::Base64,
387            SourceType::URL => message::ContentFormat::Url,
388            SourceType::TEXT => message::ContentFormat::String,
389        }
390    }
391}
392
393impl TryFrom<message::ImageMediaType> for ImageFormat {
394    type Error = MessageError;
395
396    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
397        Ok(match media_type {
398            message::ImageMediaType::JPEG => ImageFormat::JPEG,
399            message::ImageMediaType::PNG => ImageFormat::PNG,
400            message::ImageMediaType::GIF => ImageFormat::GIF,
401            message::ImageMediaType::WEBP => ImageFormat::WEBP,
402            _ => {
403                return Err(MessageError::ConversionError(
404                    format!("Unsupported image media type: {media_type:?}").to_owned(),
405                ));
406            }
407        })
408    }
409}
410
411impl From<ImageFormat> for message::ImageMediaType {
412    fn from(format: ImageFormat) -> Self {
413        match format {
414            ImageFormat::JPEG => message::ImageMediaType::JPEG,
415            ImageFormat::PNG => message::ImageMediaType::PNG,
416            ImageFormat::GIF => message::ImageMediaType::GIF,
417            ImageFormat::WEBP => message::ImageMediaType::WEBP,
418        }
419    }
420}
421
422impl TryFrom<DocumentMediaType> for DocumentFormat {
423    type Error = MessageError;
424    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
425        match value {
426            DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
427            other => Err(MessageError::ConversionError(format!(
428                "DocumentFormat only supports PDF for base64 sources, got: {}",
429                other.to_mime_type()
430            ))),
431        }
432    }
433}
434
435impl TryFrom<message::AssistantContent> for Content {
436    type Error = MessageError;
437    fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
438        match text {
439            message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
440                text,
441                cache_control: None,
442            }),
443            message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
444                "Anthropic currently doesn't support images.".to_string(),
445            )),
446            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
447                Ok(Content::ToolUse {
448                    id,
449                    name: function.name,
450                    input: function.arguments,
451                })
452            }
453            message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
454                thinking: reasoning.display_text(),
455                signature: reasoning.first_signature().map(str::to_owned),
456            }),
457        }
458    }
459}
460
461fn anthropic_content_from_assistant_content(
462    content: message::AssistantContent,
463) -> Result<Vec<Content>, MessageError> {
464    match content {
465        message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
466            text,
467            cache_control: None,
468        }]),
469        message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
470            "Anthropic currently doesn't support images.".to_string(),
471        )),
472        message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
473            Ok(vec![Content::ToolUse {
474                id,
475                name: function.name,
476                input: function.arguments,
477            }])
478        }
479        message::AssistantContent::Reasoning(reasoning) => {
480            let mut converted = Vec::new();
481            for block in reasoning.content {
482                match block {
483                    message::ReasoningContent::Text { text, signature } => {
484                        converted.push(Content::Thinking {
485                            thinking: text,
486                            signature,
487                        });
488                    }
489                    message::ReasoningContent::Summary(summary) => {
490                        converted.push(Content::Thinking {
491                            thinking: summary,
492                            signature: None,
493                        });
494                    }
495                    message::ReasoningContent::Redacted { data }
496                    | message::ReasoningContent::Encrypted(data) => {
497                        converted.push(Content::RedactedThinking { data });
498                    }
499                }
500            }
501
502            if converted.is_empty() {
503                return Err(MessageError::ConversionError(
504                    "Cannot convert empty reasoning content to Anthropic format".to_string(),
505                ));
506            }
507
508            Ok(converted)
509        }
510    }
511}
512
513impl TryFrom<message::Message> for Message {
514    type Error = MessageError;
515
516    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
517        Ok(match message {
518            message::Message::User { content } => Message {
519                role: Role::User,
520                content: content.try_map(|content| match content {
521                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
522                        text,
523                        cache_control: None,
524                    }),
525                    message::UserContent::ToolResult(message::ToolResult {
526                        id, content, ..
527                    }) => Ok(Content::ToolResult {
528                        tool_use_id: id,
529                        content: content.try_map(|content| match content {
530                            message::ToolResultContent::Text(message::Text { text }) => {
531                                Ok(ToolResultContent::Text { text })
532                            }
533                            message::ToolResultContent::Image(image) => {
534                                let DocumentSourceKind::Base64(data) = image.data else {
535                                    return Err(MessageError::ConversionError(
536                                        "Only base64 strings can be used with the Anthropic API"
537                                            .to_string(),
538                                    ));
539                                };
540                                let media_type =
541                                    image.media_type.ok_or(MessageError::ConversionError(
542                                        "Image media type is required".to_owned(),
543                                    ))?;
544                                Ok(ToolResultContent::Image(ImageSource::Base64 {
545                                    data,
546                                    media_type: media_type.try_into()?,
547                                }))
548                            }
549                        })?,
550                        is_error: None,
551                        cache_control: None,
552                    }),
553                    message::UserContent::Image(message::Image {
554                        data, media_type, ..
555                    }) => {
556                        let source = match data {
557                            DocumentSourceKind::Base64(data) => {
558                                let media_type =
559                                    media_type.ok_or(MessageError::ConversionError(
560                                        "Image media type is required for Claude API".to_string(),
561                                    ))?;
562                                ImageSource::Base64 {
563                                    data,
564                                    media_type: ImageFormat::try_from(media_type)?,
565                                }
566                            }
567                            DocumentSourceKind::Url(url) => ImageSource::Url { url },
568                            DocumentSourceKind::Unknown => {
569                                return Err(MessageError::ConversionError(
570                                    "Image content has no body".into(),
571                                ));
572                            }
573                            doc => {
574                                return Err(MessageError::ConversionError(format!(
575                                    "Unsupported document type: {doc:?}"
576                                )));
577                            }
578                        };
579
580                        Ok(Content::Image {
581                            source,
582                            cache_control: None,
583                        })
584                    }
585                    message::UserContent::Document(message::Document {
586                        data, media_type, ..
587                    }) => {
588                        let media_type = media_type.ok_or(MessageError::ConversionError(
589                            "Document media type is required".to_string(),
590                        ))?;
591
592                        let source = match media_type {
593                            DocumentMediaType::PDF => {
594                                let data = match data {
595                                    DocumentSourceKind::Base64(data)
596                                    | DocumentSourceKind::String(data) => data,
597                                    _ => {
598                                        return Err(MessageError::ConversionError(
599                                            "Only base64 encoded data is supported for PDF documents".into(),
600                                        ));
601                                    }
602                                };
603                                DocumentSource::Base64 {
604                                    data,
605                                    media_type: DocumentFormat::PDF,
606                                }
607                            }
608                            DocumentMediaType::TXT => {
609                                let data = match data {
610                                    DocumentSourceKind::String(data)
611                                    | DocumentSourceKind::Base64(data) => data,
612                                    _ => {
613                                        return Err(MessageError::ConversionError(
614                                            "Only string or base64 data is supported for plain text documents".into(),
615                                        ));
616                                    }
617                                };
618                                DocumentSource::Text {
619                                    data,
620                                    media_type: PlainTextMediaType::Plain,
621                                }
622                            }
623                            other => {
624                                return Err(MessageError::ConversionError(format!(
625                                    "Anthropic only supports PDF and plain text documents, got: {}",
626                                    other.to_mime_type()
627                                )));
628                            }
629                        };
630
631                        Ok(Content::Document {
632                            source,
633                            cache_control: None,
634                        })
635                    }
636                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
637                        "Audio is not supported in Anthropic".to_owned(),
638                    )),
639                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
640                        "Video is not supported in Anthropic".to_owned(),
641                    )),
642                })?,
643            },
644
645            message::Message::Assistant { content, .. } => {
646                let converted_content = content.into_iter().try_fold(
647                    Vec::new(),
648                    |mut accumulated, assistant_content| {
649                        accumulated
650                            .extend(anthropic_content_from_assistant_content(assistant_content)?);
651                        Ok::<Vec<Content>, MessageError>(accumulated)
652                    },
653                )?;
654
655                Message {
656                    content: OneOrMany::many(converted_content).map_err(|_| {
657                        MessageError::ConversionError(
658                            "Assistant message did not contain Anthropic-compatible content"
659                                .to_owned(),
660                        )
661                    })?,
662                    role: Role::Assistant,
663                }
664            }
665        })
666    }
667}
668
669impl TryFrom<Content> for message::AssistantContent {
670    type Error = MessageError;
671
672    fn try_from(content: Content) -> Result<Self, Self::Error> {
673        Ok(match content {
674            Content::Text { text, .. } => message::AssistantContent::text(text),
675            Content::ToolUse { id, name, input } => {
676                message::AssistantContent::tool_call(id, name, input)
677            }
678            Content::Thinking {
679                thinking,
680                signature,
681            } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
682                &thinking, signature,
683            )),
684            Content::RedactedThinking { data } => {
685                message::AssistantContent::Reasoning(Reasoning::redacted(data))
686            }
687            _ => {
688                return Err(MessageError::ConversionError(
689                    "Content did not contain a message, tool call, or reasoning".to_owned(),
690                ));
691            }
692        })
693    }
694}
695
696impl From<ToolResultContent> for message::ToolResultContent {
697    fn from(content: ToolResultContent) -> Self {
698        match content {
699            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
700            ToolResultContent::Image(source) => match source {
701                ImageSource::Base64 { data, media_type } => {
702                    message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
703                }
704                ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
705            },
706        }
707    }
708}
709
710impl TryFrom<Message> for message::Message {
711    type Error = MessageError;
712
713    fn try_from(message: Message) -> Result<Self, Self::Error> {
714        Ok(match message.role {
715            Role::User => message::Message::User {
716                content: message.content.try_map(|content| {
717                    Ok(match content {
718                        Content::Text { text, .. } => message::UserContent::text(text),
719                        Content::ToolResult {
720                            tool_use_id,
721                            content,
722                            ..
723                        } => message::UserContent::tool_result(
724                            tool_use_id,
725                            content.map(|content| content.into()),
726                        ),
727                        Content::Image { source, .. } => match source {
728                            ImageSource::Base64 { data, media_type } => {
729                                message::UserContent::Image(message::Image {
730                                    data: DocumentSourceKind::Base64(data),
731                                    media_type: Some(media_type.into()),
732                                    detail: None,
733                                    additional_params: None,
734                                })
735                            }
736                            ImageSource::Url { url } => {
737                                message::UserContent::Image(message::Image {
738                                    data: DocumentSourceKind::Url(url),
739                                    media_type: None,
740                                    detail: None,
741                                    additional_params: None,
742                                })
743                            }
744                        },
745                        Content::Document { source, .. } => match source {
746                            DocumentSource::Base64 { data, media_type } => {
747                                let rig_media_type = match media_type {
748                                    DocumentFormat::PDF => message::DocumentMediaType::PDF,
749                                };
750                                message::UserContent::document(data, Some(rig_media_type))
751                            }
752                            DocumentSource::Text { data, .. } => message::UserContent::document(
753                                data,
754                                Some(message::DocumentMediaType::TXT),
755                            ),
756                        },
757                        _ => {
758                            return Err(MessageError::ConversionError(
759                                "Unsupported content type for User role".to_owned(),
760                            ));
761                        }
762                    })
763                })?,
764            },
765            Role::Assistant => message::Message::Assistant {
766                id: None,
767                content: message.content.try_map(|content| content.try_into())?,
768            },
769        })
770    }
771}
772
773#[derive(Clone)]
774pub struct CompletionModel<T = reqwest::Client> {
775    pub(crate) client: Client<T>,
776    pub model: String,
777    pub default_max_tokens: Option<u64>,
778    /// Enable automatic prompt caching (adds cache_control breakpoints to system prompt and messages)
779    pub prompt_caching: bool,
780}
781
782impl<T> CompletionModel<T>
783where
784    T: HttpClientExt,
785{
786    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
787        let model = model.into();
788        let default_max_tokens = calculate_max_tokens(&model);
789
790        Self {
791            client,
792            model,
793            default_max_tokens,
794            prompt_caching: false, // Default to off
795        }
796    }
797
798    pub fn with_model(client: Client<T>, model: &str) -> Self {
799        Self {
800            client,
801            model: model.to_string(),
802            default_max_tokens: Some(calculate_max_tokens_custom(model)),
803            prompt_caching: false, // Default to off
804        }
805    }
806
807    /// Enable automatic prompt caching.
808    ///
809    /// When enabled, cache_control breakpoints are automatically added to:
810    /// - The system prompt (marked with ephemeral cache)
811    /// - The last content block of the last message (marked with ephemeral cache)
812    ///
813    /// This allows Anthropic to cache the conversation history for cost savings.
814    pub fn with_prompt_caching(mut self) -> Self {
815        self.prompt_caching = true;
816        self
817    }
818}
819
/// Anthropic requires a `max_tokens` parameter to be set, which is dependent on the model. If not
/// set or if set too high, the request will fail. The following values are based on the models
/// available at the time of writing.
///
/// Model names are matched by prefix, so dated snapshots and aliases such as `-latest` or `-0`
/// resolve to their family's limit.
///
/// Returns `None` for unrecognised models. The caller must then supply `max_tokens`
/// explicitly on each request.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Checked in order; the first matching prefix wins.
    const MAX_OUTPUT_TOKENS_BY_PREFIX: [(&str, u64); 9] = [
        ("claude-opus-4", 32_000),
        ("claude-sonnet-4", 64_000),
        // Claude Haiku 4.x (e.g. `claude-haiku-4-5`) supports 64K output tokens.
        ("claude-haiku-4", 64_000),
        ("claude-3-7-sonnet", 64_000),
        ("claude-3-5-sonnet", 8_192),
        ("claude-3-5-haiku", 8_192),
        ("claude-3-opus", 4_096),
        ("claude-3-sonnet", 4_096),
        ("claude-3-haiku", 4_096),
    ];

    MAX_OUTPUT_TOKENS_BY_PREFIX
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, limit)| limit)
}
839
840fn calculate_max_tokens_custom(model: &str) -> u64 {
841    if model.starts_with("claude-opus-4") {
842        32000
843    } else if model.starts_with("claude-sonnet-4") || model.starts_with("claude-3-7-sonnet") {
844        64000
845    } else if model.starts_with("claude-3-5-sonnet") || model.starts_with("claude-3-5-haiku") {
846        8192
847    } else if model.starts_with("claude-3-opus")
848        || model.starts_with("claude-3-sonnet")
849        || model.starts_with("claude-3-haiku")
850    {
851        4096
852    } else {
853        2048
854    }
855}
856
/// Request metadata object carrying an optional user identifier.
///
/// NOTE(review): not referenced in this part of the file. There is no
/// `skip_serializing_if`, so a `None` `user_id` serializes as `"user_id": null`.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}
861
/// Anthropic `tool_choice` request parameter.
///
/// Internally tagged on `type` with snake_case names. Variants serialize as
/// `{"type":"auto"}`, `{"type":"any"}`, `{"type":"none"}` and `{"type":"tool","name":"..."}`.
/// Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    None,
    /// Selects exactly one tool, by name.
    Tool {
        name: String,
    },
}
873impl TryFrom<message::ToolChoice> for ToolChoice {
874    type Error = CompletionError;
875
876    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
877        let res = match value {
878            message::ToolChoice::Auto => Self::Auto,
879            message::ToolChoice::None => Self::None,
880            message::ToolChoice::Required => Self::Any,
881            message::ToolChoice::Specific { function_names } => {
882                if function_names.len() != 1 {
883                    return Err(CompletionError::ProviderError(
884                        "Only one tool may be specified to be used by Claude".into(),
885                    ));
886                }
887
888                Self::Tool {
889                    name: function_names.first().unwrap().to_string(),
890                }
891            }
892        };
893
894        Ok(res)
895    }
896}
897
898/// Recursively ensures all object schemas respect Anthropic structured output restrictions:
899/// - `additionalProperties` must be explicitly set to `false` on every object
900/// - All properties must be listed in `required`
901///
902/// Source: <https://docs.anthropic.com/en/docs/build-with-claude/structured-outputs#json-schema-limitations>
903fn sanitize_schema(schema: &mut serde_json::Value) {
904    use serde_json::Value;
905
906    if let Value::Object(obj) = schema {
907        let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
908            || obj.contains_key("properties");
909
910        if is_object_schema && !obj.contains_key("additionalProperties") {
911            obj.insert("additionalProperties".to_string(), Value::Bool(false));
912        }
913
914        if let Some(Value::Object(properties)) = obj.get("properties") {
915            let prop_keys = properties.keys().cloned().map(Value::String).collect();
916            obj.insert("required".to_string(), Value::Array(prop_keys));
917        }
918
919        // Anthropic does not support numerical constraints on integer/number types.
920        let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
921            || obj.get("type") == Some(&Value::String("number".to_string()));
922
923        if is_numeric_schema {
924            for key in [
925                "minimum",
926                "maximum",
927                "exclusiveMinimum",
928                "exclusiveMaximum",
929                "multipleOf",
930            ] {
931                obj.remove(key);
932            }
933        }
934
935        if let Some(defs) = obj.get_mut("$defs")
936            && let Value::Object(defs_obj) = defs
937        {
938            for (_, def_schema) in defs_obj.iter_mut() {
939                sanitize_schema(def_schema);
940            }
941        }
942
943        if let Some(properties) = obj.get_mut("properties")
944            && let Value::Object(props) = properties
945        {
946            for (_, prop_value) in props.iter_mut() {
947                sanitize_schema(prop_value);
948            }
949        }
950
951        if let Some(items) = obj.get_mut("items") {
952            sanitize_schema(items);
953        }
954
955        for key in ["anyOf", "oneOf", "allOf"] {
956            if let Some(variants) = obj.get_mut(key)
957                && let Value::Array(variants_array) = variants
958            {
959                for variant in variants_array.iter_mut() {
960                    sanitize_schema(variant);
961                }
962            }
963        }
964    }
965}
966
/// Output format specifier for Anthropic's structured output.
///
/// Internally tagged on `type` with snake_case names. `JsonSchema` serializes as
/// `{"type":"json_schema","schema":{...}}`.
/// Source: <https://docs.anthropic.com/en/api/messages>
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    /// Constrains the model's response to conform to the provided JSON schema.
    JsonSchema { schema: serde_json::Value },
}
975
/// Configuration for the model's output format.
///
/// Sent as the request's `output_config` field, shaped as `{"format": {...}}`.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    format: OutputFormat,
}
981
/// JSON body serialized and posted to `/v1/messages` by the completion call.
///
/// Optional fields are omitted when `None`/empty. `additional_params` is flattened into the
/// top level of the body; NOTE(review): this presumably expects a JSON object — confirm.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    /// Always sent; Anthropic requires it (the `TryFrom` conversion rejects requests without it).
    max_tokens: u64,
    /// System prompt as array of content blocks to support cache_control
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Structured-output configuration built from the request's `output_schema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    /// Extra provider-specific fields merged into the top-level JSON body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
1001
1002/// Helper to set cache_control on a Content block
1003fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1004    match content {
1005        Content::Text { cache_control, .. } => *cache_control = value,
1006        Content::Image { cache_control, .. } => *cache_control = value,
1007        Content::ToolResult { cache_control, .. } => *cache_control = value,
1008        Content::Document { cache_control, .. } => *cache_control = value,
1009        _ => {}
1010    }
1011}
1012
1013/// Apply cache control breakpoints to system prompt and messages.
1014/// Strategy: cache the system prompt, and mark the last content block of the last message
1015/// for caching. This allows the conversation history to be cached while new messages
1016/// are added.
1017pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1018    // Add cache_control to the system prompt (if non-empty)
1019    if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1020        *cache_control = Some(CacheControl::Ephemeral);
1021    }
1022
1023    // Clear any existing cache_control from all message content blocks
1024    for msg in messages.iter_mut() {
1025        for content in msg.content.iter_mut() {
1026            set_content_cache_control(content, None);
1027        }
1028    }
1029
1030    // Add cache_control to the last content block of the last message
1031    if let Some(last_msg) = messages.last_mut() {
1032        set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
1033    }
1034}
1035
/// Parameters for building an `AnthropicCompletionRequest` via `TryFrom`.
pub struct AnthropicRequestParams<'a> {
    /// Model identifier written to the request body's `model` field.
    pub model: &'a str,
    /// Provider-agnostic request to convert; conversion fails if its `max_tokens` is unset.
    pub request: CompletionRequest,
    /// When `true`, cache-control breakpoints are added to the system prompt and messages.
    pub prompt_caching: bool,
}
1042
1043impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1044    type Error = CompletionError;
1045
1046    fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1047        let AnthropicRequestParams {
1048            model,
1049            request: req,
1050            prompt_caching,
1051        } = params;
1052
1053        // Check if max_tokens is set, required for Anthropic
1054        let Some(max_tokens) = req.max_tokens else {
1055            return Err(CompletionError::RequestError(
1056                "`max_tokens` must be set for Anthropic".into(),
1057            ));
1058        };
1059
1060        let mut full_history = vec![];
1061        if let Some(docs) = req.normalized_documents() {
1062            full_history.push(docs);
1063        }
1064        full_history.extend(req.chat_history);
1065
1066        let mut messages = full_history
1067            .into_iter()
1068            .map(Message::try_from)
1069            .collect::<Result<Vec<Message>, _>>()?;
1070
1071        let tools = req
1072            .tools
1073            .into_iter()
1074            .map(|tool| ToolDefinition {
1075                name: tool.name,
1076                description: Some(tool.description),
1077                input_schema: tool.parameters,
1078            })
1079            .collect::<Vec<_>>();
1080
1081        // Convert system prompt to array format for cache_control support
1082        let mut system = if let Some(preamble) = req.preamble {
1083            if preamble.is_empty() {
1084                vec![]
1085            } else {
1086                vec![SystemContent::Text {
1087                    text: preamble,
1088                    cache_control: None,
1089                }]
1090            }
1091        } else {
1092            vec![]
1093        };
1094
1095        // Apply cache control breakpoints only if prompt_caching is enabled
1096        if prompt_caching {
1097            apply_cache_control(&mut system, &mut messages);
1098        }
1099
1100        // Map output_schema to Anthropic's output_config field
1101        let output_config = req.output_schema.map(|schema| {
1102            let mut schema_value = schema.to_value();
1103            sanitize_schema(&mut schema_value);
1104            OutputConfig {
1105                format: OutputFormat::JsonSchema {
1106                    schema: schema_value,
1107                },
1108            }
1109        });
1110
1111        Ok(Self {
1112            model: model.to_string(),
1113            messages,
1114            max_tokens,
1115            system,
1116            temperature: req.temperature,
1117            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1118            tools,
1119            output_config,
1120            additional_params: req.additional_params,
1121        })
1122    }
1123}
1124
/// rig's `CompletionModel` trait implemented on top of Anthropic's `/v1/messages` endpoint.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;
    type Client = Client<T>;

    /// Builds a model from a clone of `client`; delegates to `CompletionModel::new`.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends a non-streaming completion request to `/v1/messages`.
    ///
    /// The request's `model` overrides `self.model`, and `self.default_max_tokens` fills in a
    /// missing `max_tokens`. The HTTP exchange runs inside a `chat` tracing span. A new span
    /// is created only when the current one is disabled; otherwise the current span is
    /// reused. On success the span receives response metadata and token usage.
    ///
    /// # Errors
    ///
    /// - `CompletionError::RequestError` if `max_tokens` is unset and there is no default.
    /// - Any error from building the Anthropic request body, and JSON (de)serialization errors.
    /// - `CompletionError::HttpError` on request-building or transport failures.
    /// - `CompletionError::ResponseError` when a success-status body is an `error` object.
    /// - `CompletionError::ProviderError` carrying the raw body text on a non-success status.
    async fn completion(
        &self,
        mut completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Per-request model override wins over the configured model.
        let request_model = completion_request
            .model
            .clone()
            .unwrap_or_else(|| self.model.clone());
        // Open a new `chat` span only if there is no usable current span. Response/usage
        // fields start `Empty` and are recorded once the response arrives.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "anthropic",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Check if max_tokens is set, required for Anthropic; fall back to the model's default.
        if completion_request.max_tokens.is_none() {
            if let Some(tokens) = self.default_max_tokens {
                completion_request.max_tokens = Some(tokens);
            } else {
                return Err(CompletionError::RequestError(
                    "`max_tokens` must be set for Anthropic".into(),
                ));
            }
        }

        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
            model: &request_model,
            request: completion_request,
            prompt_caching: self.prompt_caching,
        })?;

        // Pretty-printing is only paid for when TRACE logging is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Anthropic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        // Everything below runs inside `span` (see `.instrument(span)`).
        async move {
            let request: Vec<u8> = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/messages")?
                .body(request)
                .map_err(|e| CompletionError::HttpError(e.into()))?;

            let response = self
                .client
                .send::<_, Bytes>(req)
                .await
                .map_err(CompletionError::HttpError)?;

            if response.status().is_success() {
                // A success-status body is either a `message` or an `error`, tagged by `type`.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
                    response
                        .into_body()
                        .await
                        .map_err(CompletionError::HttpError)?
                        .to_vec()
                        .as_slice(),
                )? {
                    ApiResponse::Message(completion) => {
                        // `current()` here is the span this future is instrumented with.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&completion);
                        span.record_token_usage(&completion.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "Anthropic completion response: {}",
                                serde_json::to_string_pretty(&completion)?
                            );
                        }
                        completion.try_into()
                    }
                    ApiResponse::Error(ApiErrorResponse { message }) => {
                        Err(CompletionError::ResponseError(message))
                    }
                }
            } else {
                // Non-success status: surface the raw body (lossily decoded as UTF-8).
                let text: String = String::from_utf8_lossy(
                    &response
                        .into_body()
                        .await
                        .map_err(CompletionError::HttpError)?,
                )
                .into();
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streaming completion.
    ///
    /// The trait itself is not imported by name here, so this call is expected to resolve to
    /// an inherent `stream` method on `CompletionModel`. NOTE(review): presumably defined in
    /// the streaming module — confirm.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}
1253
/// Payload of a success-status response whose `type` is `"error"`; only `message` is read.
///
/// NOTE(review): Anthropic's documented error envelope appears to nest the message as
/// `{"type":"error","error":{"type":...,"message":...}}`. This struct reads a top-level
/// `message`, so such bodies may fail to deserialize here — confirm.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
1258
/// Response body discriminated by its `type` field.
///
/// `"message"` deserializes as `Message(T)` and `"error"` as `Error(ApiErrorResponse)`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
1265
1266#[cfg(test)]
1267mod tests {
1268    use super::*;
1269    use serde_json::json;
1270    use serde_path_to_error::deserialize;
1271
1272    #[test]
1273    fn test_deserialize_message() {
1274        let assistant_message_json = r#"
1275        {
1276            "role": "assistant",
1277            "content": "\n\nHello there, how may I assist you today?"
1278        }
1279        "#;
1280
1281        let assistant_message_json2 = r#"
1282        {
1283            "role": "assistant",
1284            "content": [
1285                {
1286                    "type": "text",
1287                    "text": "\n\nHello there, how may I assist you today?"
1288                },
1289                {
1290                    "type": "tool_use",
1291                    "id": "toolu_01A09q90qw90lq917835lq9",
1292                    "name": "get_weather",
1293                    "input": {"location": "San Francisco, CA"}
1294                }
1295            ]
1296        }
1297        "#;
1298
1299        let user_message_json = r#"
1300        {
1301            "role": "user",
1302            "content": [
1303                {
1304                    "type": "image",
1305                    "source": {
1306                        "type": "base64",
1307                        "media_type": "image/jpeg",
1308                        "data": "/9j/4AAQSkZJRg..."
1309                    }
1310                },
1311                {
1312                    "type": "text",
1313                    "text": "What is in this image?"
1314                },
1315                {
1316                    "type": "tool_result",
1317                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
1318                    "content": "15 degrees"
1319                }
1320            ]
1321        }
1322        "#;
1323
1324        let assistant_message: Message = {
1325            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1326            deserialize(jd).unwrap_or_else(|err| {
1327                panic!("Deserialization error at {}: {}", err.path(), err);
1328            })
1329        };
1330
1331        let assistant_message2: Message = {
1332            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1333            deserialize(jd).unwrap_or_else(|err| {
1334                panic!("Deserialization error at {}: {}", err.path(), err);
1335            })
1336        };
1337
1338        let user_message: Message = {
1339            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1340            deserialize(jd).unwrap_or_else(|err| {
1341                panic!("Deserialization error at {}: {}", err.path(), err);
1342            })
1343        };
1344
1345        let Message { role, content } = assistant_message;
1346        assert_eq!(role, Role::Assistant);
1347        assert_eq!(
1348            content.first(),
1349            Content::Text {
1350                text: "\n\nHello there, how may I assist you today?".to_owned(),
1351                cache_control: None,
1352            }
1353        );
1354
1355        let Message { role, content } = assistant_message2;
1356        {
1357            assert_eq!(role, Role::Assistant);
1358            assert_eq!(content.len(), 2);
1359
1360            let mut iter = content.into_iter();
1361
1362            match iter.next().unwrap() {
1363                Content::Text { text, .. } => {
1364                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
1365                }
1366                _ => panic!("Expected text content"),
1367            }
1368
1369            match iter.next().unwrap() {
1370                Content::ToolUse { id, name, input } => {
1371                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1372                    assert_eq!(name, "get_weather");
1373                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
1374                }
1375                _ => panic!("Expected tool use content"),
1376            }
1377
1378            assert_eq!(iter.next(), None);
1379        }
1380
1381        let Message { role, content } = user_message;
1382        {
1383            assert_eq!(role, Role::User);
1384            assert_eq!(content.len(), 3);
1385
1386            let mut iter = content.into_iter();
1387
1388            match iter.next().unwrap() {
1389                Content::Image { source, .. } => {
1390                    assert_eq!(
1391                        source,
1392                        ImageSource::Base64 {
1393                            data: "/9j/4AAQSkZJRg...".to_owned(),
1394                            media_type: ImageFormat::JPEG,
1395                        }
1396                    );
1397                }
1398                _ => panic!("Expected image content"),
1399            }
1400
1401            match iter.next().unwrap() {
1402                Content::Text { text, .. } => {
1403                    assert_eq!(text, "What is in this image?");
1404                }
1405                _ => panic!("Expected text content"),
1406            }
1407
1408            match iter.next().unwrap() {
1409                Content::ToolResult {
1410                    tool_use_id,
1411                    content,
1412                    is_error,
1413                    ..
1414                } => {
1415                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1416                    assert_eq!(
1417                        content.first(),
1418                        ToolResultContent::Text {
1419                            text: "15 degrees".to_owned()
1420                        }
1421                    );
1422                    assert_eq!(is_error, None);
1423                }
1424                _ => panic!("Expected tool result content"),
1425            }
1426
1427            assert_eq!(iter.next(), None);
1428        }
1429    }
1430
1431    #[test]
1432    fn test_message_to_message_conversion() {
1433        let user_message: Message = serde_json::from_str(
1434            r#"
1435        {
1436            "role": "user",
1437            "content": [
1438                {
1439                    "type": "image",
1440                    "source": {
1441                        "type": "base64",
1442                        "media_type": "image/jpeg",
1443                        "data": "/9j/4AAQSkZJRg..."
1444                    }
1445                },
1446                {
1447                    "type": "text",
1448                    "text": "What is in this image?"
1449                },
1450                {
1451                    "type": "document",
1452                    "source": {
1453                        "type": "base64",
1454                        "data": "base64_encoded_pdf_data",
1455                        "media_type": "application/pdf"
1456                    }
1457                }
1458            ]
1459        }
1460        "#,
1461        )
1462        .unwrap();
1463
1464        let assistant_message = Message {
1465            role: Role::Assistant,
1466            content: OneOrMany::one(Content::ToolUse {
1467                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1468                name: "get_weather".to_string(),
1469                input: json!({"location": "San Francisco, CA"}),
1470            }),
1471        };
1472
1473        let tool_message = Message {
1474            role: Role::User,
1475            content: OneOrMany::one(Content::ToolResult {
1476                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1477                content: OneOrMany::one(ToolResultContent::Text {
1478                    text: "15 degrees".to_string(),
1479                }),
1480                is_error: None,
1481                cache_control: None,
1482            }),
1483        };
1484
1485        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1486        let converted_assistant_message: message::Message =
1487            assistant_message.clone().try_into().unwrap();
1488        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1489
1490        match converted_user_message.clone() {
1491            message::Message::User { content } => {
1492                assert_eq!(content.len(), 3);
1493
1494                let mut iter = content.into_iter();
1495
1496                match iter.next().unwrap() {
1497                    message::UserContent::Image(message::Image {
1498                        data, media_type, ..
1499                    }) => {
1500                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1501                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1502                    }
1503                    _ => panic!("Expected image content"),
1504                }
1505
1506                match iter.next().unwrap() {
1507                    message::UserContent::Text(message::Text { text }) => {
1508                        assert_eq!(text, "What is in this image?");
1509                    }
1510                    _ => panic!("Expected text content"),
1511                }
1512
1513                match iter.next().unwrap() {
1514                    message::UserContent::Document(message::Document {
1515                        data, media_type, ..
1516                    }) => {
1517                        assert_eq!(
1518                            data,
1519                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
1520                        );
1521                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1522                    }
1523                    _ => panic!("Expected document content"),
1524                }
1525
1526                assert_eq!(iter.next(), None);
1527            }
1528            _ => panic!("Expected user message"),
1529        }
1530
1531        match converted_tool_message.clone() {
1532            message::Message::User { content } => {
1533                let message::ToolResult { id, content, .. } = match content.first() {
1534                    message::UserContent::ToolResult(tool_result) => tool_result,
1535                    _ => panic!("Expected tool result content"),
1536                };
1537                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1538                match content.first() {
1539                    message::ToolResultContent::Text(message::Text { text }) => {
1540                        assert_eq!(text, "15 degrees");
1541                    }
1542                    _ => panic!("Expected text content"),
1543                }
1544            }
1545            _ => panic!("Expected tool result content"),
1546        }
1547
1548        match converted_assistant_message.clone() {
1549            message::Message::Assistant { content, .. } => {
1550                assert_eq!(content.len(), 1);
1551
1552                match content.first() {
1553                    message::AssistantContent::ToolCall(message::ToolCall {
1554                        id, function, ..
1555                    }) => {
1556                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1557                        assert_eq!(function.name, "get_weather");
1558                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1559                    }
1560                    _ => panic!("Expected tool call content"),
1561                }
1562            }
1563            _ => panic!("Expected assistant message"),
1564        }
1565
1566        let original_user_message: Message = converted_user_message.try_into().unwrap();
1567        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1568        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1569
1570        assert_eq!(user_message, original_user_message);
1571        assert_eq!(assistant_message, original_assistant_message);
1572        assert_eq!(tool_message, original_tool_message);
1573    }
1574
1575    #[test]
1576    fn test_content_format_conversion() {
1577        use crate::completion::message::ContentFormat;
1578
1579        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1580        assert_eq!(source_type, SourceType::URL);
1581
1582        let content_format: ContentFormat = SourceType::URL.into();
1583        assert_eq!(content_format, ContentFormat::Url);
1584
1585        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1586        assert_eq!(source_type, SourceType::BASE64);
1587
1588        let content_format: ContentFormat = SourceType::BASE64.into();
1589        assert_eq!(content_format, ContentFormat::Base64);
1590
1591        let source_type: SourceType = ContentFormat::String.try_into().unwrap();
1592        assert_eq!(source_type, SourceType::TEXT);
1593
1594        let content_format: ContentFormat = SourceType::TEXT.into();
1595        assert_eq!(content_format, ContentFormat::String);
1596    }
1597
1598    #[test]
1599    fn test_cache_control_serialization() {
1600        // Test SystemContent with cache_control
1601        let system = SystemContent::Text {
1602            text: "You are a helpful assistant.".to_string(),
1603            cache_control: Some(CacheControl::Ephemeral),
1604        };
1605        let json = serde_json::to_string(&system).unwrap();
1606        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
1607        assert!(json.contains(r#""type":"text""#));
1608
1609        // Test SystemContent without cache_control (should not have cache_control field)
1610        let system_no_cache = SystemContent::Text {
1611            text: "Hello".to_string(),
1612            cache_control: None,
1613        };
1614        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
1615        assert!(!json_no_cache.contains("cache_control"));
1616
1617        // Test Content::Text with cache_control
1618        let content = Content::Text {
1619            text: "Test message".to_string(),
1620            cache_control: Some(CacheControl::Ephemeral),
1621        };
1622        let json_content = serde_json::to_string(&content).unwrap();
1623        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));
1624
1625        // Test apply_cache_control function
1626        let mut system_vec = vec![SystemContent::Text {
1627            text: "System prompt".to_string(),
1628            cache_control: None,
1629        }];
1630        let mut messages = vec![
1631            Message {
1632                role: Role::User,
1633                content: OneOrMany::one(Content::Text {
1634                    text: "First message".to_string(),
1635                    cache_control: None,
1636                }),
1637            },
1638            Message {
1639                role: Role::Assistant,
1640                content: OneOrMany::one(Content::Text {
1641                    text: "Response".to_string(),
1642                    cache_control: None,
1643                }),
1644            },
1645        ];
1646
1647        apply_cache_control(&mut system_vec, &mut messages);
1648
1649        // System should have cache_control
1650        match &system_vec[0] {
1651            SystemContent::Text { cache_control, .. } => {
1652                assert!(cache_control.is_some());
1653            }
1654        }
1655
1656        // Only the last content block of last message should have cache_control
1657        // First message should NOT have cache_control
1658        for content in messages[0].content.iter() {
1659            if let Content::Text { cache_control, .. } = content {
1660                assert!(cache_control.is_none());
1661            }
1662        }
1663
1664        // Last message SHOULD have cache_control
1665        for content in messages[1].content.iter() {
1666            if let Content::Text { cache_control, .. } = content {
1667                assert!(cache_control.is_some());
1668            }
1669        }
1670    }
1671
1672    #[test]
1673    fn test_plaintext_document_serialization() {
1674        let content = Content::Document {
1675            source: DocumentSource::Text {
1676                data: "Hello, world!".to_string(),
1677                media_type: PlainTextMediaType::Plain,
1678            },
1679            cache_control: None,
1680        };
1681
1682        let json = serde_json::to_value(&content).unwrap();
1683        assert_eq!(json["type"], "document");
1684        assert_eq!(json["source"]["type"], "text");
1685        assert_eq!(json["source"]["media_type"], "text/plain");
1686        assert_eq!(json["source"]["data"], "Hello, world!");
1687    }
1688
1689    #[test]
1690    fn test_plaintext_document_deserialization() {
1691        let json = r#"
1692        {
1693            "type": "document",
1694            "source": {
1695                "type": "text",
1696                "media_type": "text/plain",
1697                "data": "Hello, world!"
1698            }
1699        }
1700        "#;
1701
1702        let content: Content = serde_json::from_str(json).unwrap();
1703        match content {
1704            Content::Document {
1705                source,
1706                cache_control,
1707            } => {
1708                assert_eq!(
1709                    source,
1710                    DocumentSource::Text {
1711                        data: "Hello, world!".to_string(),
1712                        media_type: PlainTextMediaType::Plain,
1713                    }
1714                );
1715                assert_eq!(cache_control, None);
1716            }
1717            _ => panic!("Expected Document content"),
1718        }
1719    }
1720
1721    #[test]
1722    fn test_base64_pdf_document_serialization() {
1723        let content = Content::Document {
1724            source: DocumentSource::Base64 {
1725                data: "base64data".to_string(),
1726                media_type: DocumentFormat::PDF,
1727            },
1728            cache_control: None,
1729        };
1730
1731        let json = serde_json::to_value(&content).unwrap();
1732        assert_eq!(json["type"], "document");
1733        assert_eq!(json["source"]["type"], "base64");
1734        assert_eq!(json["source"]["media_type"], "application/pdf");
1735        assert_eq!(json["source"]["data"], "base64data");
1736    }
1737
1738    #[test]
1739    fn test_base64_pdf_document_deserialization() {
1740        let json = r#"
1741        {
1742            "type": "document",
1743            "source": {
1744                "type": "base64",
1745                "media_type": "application/pdf",
1746                "data": "base64data"
1747            }
1748        }
1749        "#;
1750
1751        let content: Content = serde_json::from_str(json).unwrap();
1752        match content {
1753            Content::Document { source, .. } => {
1754                assert_eq!(
1755                    source,
1756                    DocumentSource::Base64 {
1757                        data: "base64data".to_string(),
1758                        media_type: DocumentFormat::PDF,
1759                    }
1760                );
1761            }
1762            _ => panic!("Expected Document content"),
1763        }
1764    }
1765
1766    #[test]
1767    fn test_plaintext_rig_to_anthropic_conversion() {
1768        use crate::completion::message as msg;
1769
1770        let rig_message = msg::Message::User {
1771            content: OneOrMany::one(msg::UserContent::document(
1772                "Some plain text content".to_string(),
1773                Some(msg::DocumentMediaType::TXT),
1774            )),
1775        };
1776
1777        let anthropic_message: Message = rig_message.try_into().unwrap();
1778        assert_eq!(anthropic_message.role, Role::User);
1779
1780        let mut iter = anthropic_message.content.into_iter();
1781        match iter.next().unwrap() {
1782            Content::Document { source, .. } => {
1783                assert_eq!(
1784                    source,
1785                    DocumentSource::Text {
1786                        data: "Some plain text content".to_string(),
1787                        media_type: PlainTextMediaType::Plain,
1788                    }
1789                );
1790            }
1791            other => panic!("Expected Document content, got: {other:?}"),
1792        }
1793    }
1794
1795    #[test]
1796    fn test_plaintext_anthropic_to_rig_conversion() {
1797        use crate::completion::message as msg;
1798
1799        let anthropic_message = Message {
1800            role: Role::User,
1801            content: OneOrMany::one(Content::Document {
1802                source: DocumentSource::Text {
1803                    data: "Some plain text content".to_string(),
1804                    media_type: PlainTextMediaType::Plain,
1805                },
1806                cache_control: None,
1807            }),
1808        };
1809
1810        let rig_message: msg::Message = anthropic_message.try_into().unwrap();
1811        match rig_message {
1812            msg::Message::User { content } => {
1813                let mut iter = content.into_iter();
1814                match iter.next().unwrap() {
1815                    msg::UserContent::Document(msg::Document {
1816                        data, media_type, ..
1817                    }) => {
1818                        assert_eq!(
1819                            data,
1820                            DocumentSourceKind::String("Some plain text content".into())
1821                        );
1822                        assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
1823                    }
1824                    other => panic!("Expected Document content, got: {other:?}"),
1825                }
1826            }
1827            _ => panic!("Expected User message"),
1828        }
1829    }
1830
1831    #[test]
1832    fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
1833        use crate::completion::message as msg;
1834
1835        let original = msg::Message::User {
1836            content: OneOrMany::one(msg::UserContent::document(
1837                "Round trip text".to_string(),
1838                Some(msg::DocumentMediaType::TXT),
1839            )),
1840        };
1841
1842        let anthropic: Message = original.clone().try_into().unwrap();
1843        let back: msg::Message = anthropic.try_into().unwrap();
1844
1845        match (&original, &back) {
1846            (
1847                msg::Message::User {
1848                    content: orig_content,
1849                },
1850                msg::Message::User {
1851                    content: back_content,
1852                },
1853            ) => match (orig_content.first(), back_content.first()) {
1854                (
1855                    msg::UserContent::Document(msg::Document {
1856                        media_type: orig_mt,
1857                        ..
1858                    }),
1859                    msg::UserContent::Document(msg::Document {
1860                        media_type: back_mt,
1861                        ..
1862                    }),
1863                ) => {
1864                    assert_eq!(orig_mt, back_mt);
1865                }
1866                _ => panic!("Expected Document content in both"),
1867            },
1868            _ => panic!("Expected User messages"),
1869        }
1870    }
1871
1872    #[test]
1873    fn test_unsupported_document_type_returns_error() {
1874        use crate::completion::message as msg;
1875
1876        let rig_message = msg::Message::User {
1877            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
1878                data: DocumentSourceKind::String("data".into()),
1879                media_type: Some(msg::DocumentMediaType::HTML),
1880                additional_params: None,
1881            })),
1882        };
1883
1884        let result: Result<Message, _> = rig_message.try_into();
1885        assert!(result.is_err());
1886        let err = result.unwrap_err().to_string();
1887        assert!(
1888            err.contains("Anthropic only supports PDF and plain text documents"),
1889            "Unexpected error: {err}"
1890        );
1891    }
1892
1893    #[test]
1894    fn test_plaintext_document_url_source_returns_error() {
1895        use crate::completion::message as msg;
1896
1897        let rig_message = msg::Message::User {
1898            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
1899                data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
1900                media_type: Some(msg::DocumentMediaType::TXT),
1901                additional_params: None,
1902            })),
1903        };
1904
1905        let result: Result<Message, _> = rig_message.try_into();
1906        assert!(result.is_err());
1907        let err = result.unwrap_err().to_string();
1908        assert!(
1909            err.contains("Only string or base64 data is supported for plain text documents"),
1910            "Unexpected error: {err}"
1911        );
1912    }
1913
1914    #[test]
1915    fn test_plaintext_document_with_cache_control() {
1916        let content = Content::Document {
1917            source: DocumentSource::Text {
1918                data: "cached text".to_string(),
1919                media_type: PlainTextMediaType::Plain,
1920            },
1921            cache_control: Some(CacheControl::Ephemeral),
1922        };
1923
1924        let json = serde_json::to_value(&content).unwrap();
1925        assert_eq!(json["source"]["type"], "text");
1926        assert_eq!(json["source"]["media_type"], "text/plain");
1927        assert_eq!(json["cache_control"]["type"], "ephemeral");
1928    }
1929
1930    #[test]
1931    fn test_message_with_plaintext_document_deserialization() {
1932        let json = r#"
1933        {
1934            "role": "user",
1935            "content": [
1936                {
1937                    "type": "document",
1938                    "source": {
1939                        "type": "text",
1940                        "media_type": "text/plain",
1941                        "data": "Hello from a text file"
1942                    }
1943                },
1944                {
1945                    "type": "text",
1946                    "text": "Summarize this document."
1947                }
1948            ]
1949        }
1950        "#;
1951
1952        let message: Message = serde_json::from_str(json).unwrap();
1953        assert_eq!(message.role, Role::User);
1954        assert_eq!(message.content.len(), 2);
1955
1956        let mut iter = message.content.into_iter();
1957
1958        match iter.next().unwrap() {
1959            Content::Document { source, .. } => {
1960                assert_eq!(
1961                    source,
1962                    DocumentSource::Text {
1963                        data: "Hello from a text file".to_string(),
1964                        media_type: PlainTextMediaType::Plain,
1965                    }
1966                );
1967            }
1968            _ => panic!("Expected Document content"),
1969        }
1970
1971        match iter.next().unwrap() {
1972            Content::Text { text, .. } => {
1973                assert_eq!(text, "Summarize this document.");
1974            }
1975            _ => panic!("Expected Text content"),
1976        }
1977    }
1978
1979    #[test]
1980    fn test_assistant_reasoning_multiblock_to_anthropic_content() {
1981        let reasoning = message::Reasoning {
1982            id: None,
1983            content: vec![
1984                message::ReasoningContent::Text {
1985                    text: "step one".to_string(),
1986                    signature: Some("sig-1".to_string()),
1987                },
1988                message::ReasoningContent::Summary("summary".to_string()),
1989                message::ReasoningContent::Text {
1990                    text: "step two".to_string(),
1991                    signature: Some("sig-2".to_string()),
1992                },
1993                message::ReasoningContent::Redacted {
1994                    data: "redacted block".to_string(),
1995                },
1996            ],
1997        };
1998
1999        let msg = message::Message::Assistant {
2000            id: None,
2001            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
2002        };
2003        let converted: Message = msg.try_into().expect("convert assistant message");
2004        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();
2005
2006        assert_eq!(converted.role, Role::Assistant);
2007        assert_eq!(converted_content.len(), 4);
2008        assert!(matches!(
2009            converted_content.first(),
2010            Some(Content::Thinking { thinking, signature: Some(signature) })
2011                if thinking == "step one" && signature == "sig-1"
2012        ));
2013        assert!(matches!(
2014            converted_content.get(1),
2015            Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
2016        ));
2017        assert!(matches!(
2018            converted_content.get(2),
2019            Some(Content::Thinking { thinking, signature: Some(signature) })
2020                if thinking == "step two" && signature == "sig-2"
2021        ));
2022        assert!(matches!(
2023            converted_content.get(3),
2024            Some(Content::RedactedThinking { data }) if data == "redacted block"
2025        ));
2026    }
2027
2028    #[test]
2029    fn test_redacted_thinking_content_to_assistant_reasoning() {
2030        let content = Content::RedactedThinking {
2031            data: "opaque-redacted".to_string(),
2032        };
2033        let converted: message::AssistantContent =
2034            content.try_into().expect("convert redacted thinking");
2035
2036        assert!(matches!(
2037            converted,
2038            message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2039                if matches!(
2040                    content.first(),
2041                    Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
2042                )
2043        ));
2044    }
2045
2046    #[test]
2047    fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
2048        let reasoning = message::Reasoning {
2049            id: None,
2050            content: vec![message::ReasoningContent::Encrypted(
2051                "ciphertext".to_string(),
2052            )],
2053        };
2054        let msg = message::Message::Assistant {
2055            id: None,
2056            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
2057        };
2058
2059        let converted: Message = msg.try_into().expect("convert assistant message");
2060        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();
2061
2062        assert_eq!(converted_content.len(), 1);
2063        assert!(matches!(
2064            converted_content.first(),
2065            Some(Content::RedactedThinking { data }) if data == "ciphertext"
2066        ));
2067    }
2068}