rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8    one_or_many::string_or_one_or_many,
9    telemetry::{ProviderResponseExt, SpanCombinator},
10    wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
21// ================================================================
22// Anthropic Completion API
23// ================================================================
24
/// Model identifier `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model identifier `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model identifier `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model identifier `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model identifier `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version known to this module (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
40#[derive(Debug, Deserialize, Serialize)]
41pub struct CompletionResponse {
42    pub content: Vec<Content>,
43    pub id: String,
44    pub model: String,
45    pub role: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51impl ProviderResponseExt for CompletionResponse {
52    type OutputMessage = Content;
53    type Usage = Usage;
54
55    fn get_response_id(&self) -> Option<String> {
56        Some(self.id.to_owned())
57    }
58
59    fn get_response_model_name(&self) -> Option<String> {
60        Some(self.model.to_owned())
61    }
62
63    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64        self.content.clone()
65    }
66
67    fn get_text_response(&self) -> Option<String> {
68        let res = self
69            .content
70            .iter()
71            .filter_map(|x| {
72                if let Content::Text { text, .. } = x {
73                    Some(text.to_owned())
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<String>>()
79            .join("\n");
80
81        if res.is_empty() { None } else { Some(res) }
82    }
83
84    fn get_usage(&self) -> Option<Self::Usage> {
85        Some(self.usage.clone())
86    }
87}
88
89#[derive(Clone, Debug, Deserialize, Serialize)]
90pub struct Usage {
91    pub input_tokens: u64,
92    pub cache_read_input_tokens: Option<u64>,
93    pub cache_creation_input_tokens: Option<u64>,
94    pub output_tokens: u64,
95}
96
97impl std::fmt::Display for Usage {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(
100            f,
101            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102            self.input_tokens,
103            match self.cache_read_input_tokens {
104                Some(token) => token.to_string(),
105                None => "n/a".to_string(),
106            },
107            match self.cache_creation_input_tokens {
108                Some(token) => token.to_string(),
109                None => "n/a".to_string(),
110            },
111            self.output_tokens
112        )
113    }
114}
115
116impl GetTokenUsage for Usage {
117    fn token_usage(&self) -> Option<crate::completion::Usage> {
118        let mut usage = crate::completion::Usage::new();
119
120        usage.input_tokens = self.input_tokens
121            + self.cache_creation_input_tokens.unwrap_or_default()
122            + self.cache_read_input_tokens.unwrap_or_default();
123        usage.output_tokens = self.output_tokens;
124        usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126        Some(usage)
127    }
128}
129
130#[derive(Debug, Deserialize, Serialize)]
131pub struct ToolDefinition {
132    pub name: String,
133    pub description: Option<String>,
134    pub input_schema: serde_json::Value,
135}
136
137/// Cache control directive for Anthropic prompt caching
138#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
139#[serde(tag = "type", rename_all = "snake_case")]
140pub enum CacheControl {
141    Ephemeral,
142}
143
144/// System message content block with optional cache control
145#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
146#[serde(tag = "type", rename_all = "snake_case")]
147pub enum SystemContent {
148    Text {
149        text: String,
150        #[serde(skip_serializing_if = "Option::is_none")]
151        cache_control: Option<CacheControl>,
152    },
153}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156    type Error = CompletionError;
157
158    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159        let content = response
160            .content
161            .iter()
162            .map(|content| content.clone().try_into())
163            .collect::<Result<Vec<_>, _>>()?;
164
165        let choice = OneOrMany::many(content).map_err(|_| {
166            CompletionError::ResponseError(
167                "Response contained no message or tool call (empty)".to_owned(),
168            )
169        })?;
170
171        let usage = completion::Usage {
172            input_tokens: response.usage.input_tokens,
173            output_tokens: response.usage.output_tokens,
174            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175        };
176
177        Ok(completion::CompletionResponse {
178            choice,
179            usage,
180            raw_response: response,
181        })
182    }
183}
184
185#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
186pub struct Message {
187    pub role: Role,
188    #[serde(deserialize_with = "string_or_one_or_many")]
189    pub content: OneOrMany<Content>,
190}
191
192#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
193#[serde(rename_all = "lowercase")]
194pub enum Role {
195    User,
196    Assistant,
197}
198
199#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
200#[serde(tag = "type", rename_all = "snake_case")]
201pub enum Content {
202    Text {
203        text: String,
204        #[serde(skip_serializing_if = "Option::is_none")]
205        cache_control: Option<CacheControl>,
206    },
207    Image {
208        source: ImageSource,
209        #[serde(skip_serializing_if = "Option::is_none")]
210        cache_control: Option<CacheControl>,
211    },
212    ToolUse {
213        id: String,
214        name: String,
215        input: serde_json::Value,
216    },
217    ToolResult {
218        tool_use_id: String,
219        #[serde(deserialize_with = "string_or_one_or_many")]
220        content: OneOrMany<ToolResultContent>,
221        #[serde(skip_serializing_if = "Option::is_none")]
222        is_error: Option<bool>,
223        #[serde(skip_serializing_if = "Option::is_none")]
224        cache_control: Option<CacheControl>,
225    },
226    Document {
227        source: DocumentSource,
228        #[serde(skip_serializing_if = "Option::is_none")]
229        cache_control: Option<CacheControl>,
230    },
231    Thinking {
232        thinking: String,
233        #[serde(skip_serializing_if = "Option::is_none")]
234        signature: Option<String>,
235    },
236}
237
238impl FromStr for Content {
239    type Err = Infallible;
240
241    fn from_str(s: &str) -> Result<Self, Self::Err> {
242        Ok(Content::Text {
243            text: s.to_owned(),
244            cache_control: None,
245        })
246    }
247}
248
249#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
250#[serde(tag = "type", rename_all = "snake_case")]
251pub enum ToolResultContent {
252    Text { text: String },
253    Image(ImageSource),
254}
255
256impl FromStr for ToolResultContent {
257    type Err = Infallible;
258
259    fn from_str(s: &str) -> Result<Self, Self::Err> {
260        Ok(ToolResultContent::Text { text: s.to_owned() })
261    }
262}
263
264#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
265#[serde(untagged)]
266pub enum ImageSourceData {
267    Base64(String),
268    Url(String),
269}
270
271impl From<ImageSourceData> for DocumentSourceKind {
272    fn from(value: ImageSourceData) -> Self {
273        match value {
274            ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
275            ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
276        }
277    }
278}
279
280impl TryFrom<DocumentSourceKind> for ImageSourceData {
281    type Error = MessageError;
282
283    fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
284        match value {
285            DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
286            DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
287            _ => Err(MessageError::ConversionError("Content has no body".into())),
288        }
289    }
290}
291
292impl From<ImageSourceData> for String {
293    fn from(value: ImageSourceData) -> Self {
294        match value {
295            ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
296        }
297    }
298}
299
300#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
301pub struct ImageSource {
302    pub data: ImageSourceData,
303    pub media_type: ImageFormat,
304    pub r#type: SourceType,
305}
306
307#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
308pub struct DocumentSource {
309    pub data: String,
310    pub media_type: DocumentFormat,
311    pub r#type: SourceType,
312}
313
314#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
315#[serde(rename_all = "lowercase")]
316pub enum ImageFormat {
317    #[serde(rename = "image/jpeg")]
318    JPEG,
319    #[serde(rename = "image/png")]
320    PNG,
321    #[serde(rename = "image/gif")]
322    GIF,
323    #[serde(rename = "image/webp")]
324    WEBP,
325}
326
327/// The document format to be used.
328///
329/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
330#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
331#[serde(rename_all = "lowercase")]
332pub enum DocumentFormat {
333    #[serde(rename = "application/pdf")]
334    PDF,
335}
336
337#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
338#[serde(rename_all = "lowercase")]
339pub enum SourceType {
340    BASE64,
341    URL,
342}
343
344impl From<String> for Content {
345    fn from(text: String) -> Self {
346        Content::Text {
347            text,
348            cache_control: None,
349        }
350    }
351}
352
353impl From<String> for ToolResultContent {
354    fn from(text: String) -> Self {
355        ToolResultContent::Text { text }
356    }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360    type Error = MessageError;
361
362    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363        match format {
364            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365            message::ContentFormat::Url => Ok(SourceType::URL),
366            message::ContentFormat::String => Err(MessageError::ConversionError(
367                "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368            )),
369        }
370    }
371}
372
373impl From<SourceType> for message::ContentFormat {
374    fn from(source_type: SourceType) -> Self {
375        match source_type {
376            SourceType::BASE64 => message::ContentFormat::Base64,
377            SourceType::URL => message::ContentFormat::Url,
378        }
379    }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383    type Error = MessageError;
384
385    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386        Ok(match media_type {
387            message::ImageMediaType::JPEG => ImageFormat::JPEG,
388            message::ImageMediaType::PNG => ImageFormat::PNG,
389            message::ImageMediaType::GIF => ImageFormat::GIF,
390            message::ImageMediaType::WEBP => ImageFormat::WEBP,
391            _ => {
392                return Err(MessageError::ConversionError(
393                    format!("Unsupported image media type: {media_type:?}").to_owned(),
394                ));
395            }
396        })
397    }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401    fn from(format: ImageFormat) -> Self {
402        match format {
403            ImageFormat::JPEG => message::ImageMediaType::JPEG,
404            ImageFormat::PNG => message::ImageMediaType::PNG,
405            ImageFormat::GIF => message::ImageMediaType::GIF,
406            ImageFormat::WEBP => message::ImageMediaType::WEBP,
407        }
408    }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412    type Error = MessageError;
413    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414        if !matches!(value, DocumentMediaType::PDF) {
415            return Err(MessageError::ConversionError(
416                "Anthropic only supports PDF documents".to_string(),
417            ));
418        };
419
420        Ok(DocumentFormat::PDF)
421    }
422}
423
424impl TryFrom<message::AssistantContent> for Content {
425    type Error = MessageError;
426    fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
427        match text {
428            message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
429                text,
430                cache_control: None,
431            }),
432            message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
433                "Anthropic currently doesn't support images.".to_string(),
434            )),
435            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
436                Ok(Content::ToolUse {
437                    id,
438                    name: function.name,
439                    input: function.arguments,
440                })
441            }
442            message::AssistantContent::Reasoning(Reasoning {
443                reasoning,
444                signature,
445                ..
446            }) => Ok(Content::Thinking {
447                thinking: reasoning.first().cloned().unwrap_or(String::new()),
448                signature,
449            }),
450        }
451    }
452}
453
454impl TryFrom<message::Message> for Message {
455    type Error = MessageError;
456
457    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
458        Ok(match message {
459            message::Message::User { content } => Message {
460                role: Role::User,
461                content: content.try_map(|content| match content {
462                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
463                        text,
464                        cache_control: None,
465                    }),
466                    message::UserContent::ToolResult(message::ToolResult {
467                        id, content, ..
468                    }) => Ok(Content::ToolResult {
469                        tool_use_id: id,
470                        content: content.try_map(|content| match content {
471                            message::ToolResultContent::Text(message::Text { text }) => {
472                                Ok(ToolResultContent::Text { text })
473                            }
474                            message::ToolResultContent::Image(image) => {
475                                let DocumentSourceKind::Base64(data) = image.data else {
476                                    return Err(MessageError::ConversionError(
477                                        "Only base64 strings can be used with the Anthropic API"
478                                            .to_string(),
479                                    ));
480                                };
481                                let media_type =
482                                    image.media_type.ok_or(MessageError::ConversionError(
483                                        "Image media type is required".to_owned(),
484                                    ))?;
485                                Ok(ToolResultContent::Image(ImageSource {
486                                    data: ImageSourceData::Base64(data),
487                                    media_type: media_type.try_into()?,
488                                    r#type: SourceType::BASE64,
489                                }))
490                            }
491                        })?,
492                        is_error: None,
493                        cache_control: None,
494                    }),
495                    message::UserContent::Image(message::Image {
496                        data, media_type, ..
497                    }) => {
498                        let media_type = media_type.ok_or(MessageError::ConversionError(
499                            "Image media type is required for Claude API".to_string(),
500                        ))?;
501
502                        let source = match data {
503                            DocumentSourceKind::Base64(data) => ImageSource {
504                                data: ImageSourceData::Base64(data),
505                                r#type: SourceType::BASE64,
506                                media_type: ImageFormat::try_from(media_type)?,
507                            },
508                            DocumentSourceKind::Url(url) => ImageSource {
509                                data: ImageSourceData::Url(url),
510                                r#type: SourceType::URL,
511                                media_type: ImageFormat::try_from(media_type)?,
512                            },
513                            DocumentSourceKind::Unknown => {
514                                return Err(MessageError::ConversionError(
515                                    "Image content has no body".into(),
516                                ));
517                            }
518                            doc => {
519                                return Err(MessageError::ConversionError(format!(
520                                    "Unsupported document type: {doc:?}"
521                                )));
522                            }
523                        };
524
525                        Ok(Content::Image {
526                            source,
527                            cache_control: None,
528                        })
529                    }
530                    message::UserContent::Document(message::Document {
531                        data, media_type, ..
532                    }) => {
533                        let media_type = media_type.ok_or(MessageError::ConversionError(
534                            "Document media type is required".to_string(),
535                        ))?;
536
537                        let data = match data {
538                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
539                                data
540                            }
541                            _ => {
542                                return Err(MessageError::ConversionError(
543                                    "Only base64 encoded documents currently supported".into(),
544                                ));
545                            }
546                        };
547
548                        let source = DocumentSource {
549                            data,
550                            media_type: media_type.try_into()?,
551                            r#type: SourceType::BASE64,
552                        };
553                        Ok(Content::Document {
554                            source,
555                            cache_control: None,
556                        })
557                    }
558                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
559                        "Audio is not supported in Anthropic".to_owned(),
560                    )),
561                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
562                        "Video is not supported in Anthropic".to_owned(),
563                    )),
564                })?,
565            },
566
567            message::Message::Assistant { content, .. } => Message {
568                content: content.try_map(|content| content.try_into())?,
569                role: Role::Assistant,
570            },
571        })
572    }
573}
574
575impl TryFrom<Content> for message::AssistantContent {
576    type Error = MessageError;
577
578    fn try_from(content: Content) -> Result<Self, Self::Error> {
579        Ok(match content {
580            Content::Text { text, .. } => message::AssistantContent::text(text),
581            Content::ToolUse { id, name, input } => {
582                message::AssistantContent::tool_call(id, name, input)
583            }
584            Content::Thinking {
585                thinking,
586                signature,
587            } => message::AssistantContent::Reasoning(
588                Reasoning::new(&thinking).with_signature(signature),
589            ),
590            _ => {
591                return Err(MessageError::ConversionError(
592                    "Content did not contain a message, tool call, or reasoning".to_owned(),
593                ));
594            }
595        })
596    }
597}
598
599impl From<ToolResultContent> for message::ToolResultContent {
600    fn from(content: ToolResultContent) -> Self {
601        match content {
602            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
603            ToolResultContent::Image(ImageSource {
604                data,
605                media_type: format,
606                ..
607            }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
608        }
609    }
610}
611
612impl TryFrom<Message> for message::Message {
613    type Error = MessageError;
614
615    fn try_from(message: Message) -> Result<Self, Self::Error> {
616        Ok(match message.role {
617            Role::User => message::Message::User {
618                content: message.content.try_map(|content| {
619                    Ok(match content {
620                        Content::Text { text, .. } => message::UserContent::text(text),
621                        Content::ToolResult {
622                            tool_use_id,
623                            content,
624                            ..
625                        } => message::UserContent::tool_result(
626                            tool_use_id,
627                            content.map(|content| content.into()),
628                        ),
629                        Content::Image { source, .. } => {
630                            message::UserContent::Image(message::Image {
631                                data: source.data.into(),
632                                media_type: Some(source.media_type.into()),
633                                detail: None,
634                                additional_params: None,
635                            })
636                        }
637                        Content::Document { source, .. } => message::UserContent::document(
638                            source.data,
639                            Some(message::DocumentMediaType::PDF),
640                        ),
641                        _ => {
642                            return Err(MessageError::ConversionError(
643                                "Unsupported content type for User role".to_owned(),
644                            ));
645                        }
646                    })
647                })?,
648            },
649            Role::Assistant => match message.content.first() {
650                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
651                    message::Message::Assistant {
652                        id: None,
653                        content: message.content.try_map(|content| content.try_into())?,
654                    }
655                }
656
657                _ => {
658                    return Err(MessageError::ConversionError(
659                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
660                    ));
661                }
662            },
663        })
664    }
665}
666
667#[derive(Clone)]
668pub struct CompletionModel<T = reqwest::Client> {
669    pub(crate) client: Client<T>,
670    pub model: String,
671    pub default_max_tokens: Option<u64>,
672    /// Enable automatic prompt caching (adds cache_control breakpoints to system prompt and messages)
673    pub prompt_caching: bool,
674}
675
676impl<T> CompletionModel<T>
677where
678    T: HttpClientExt,
679{
680    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
681        let model = model.into();
682        let default_max_tokens = calculate_max_tokens(&model);
683
684        Self {
685            client,
686            model,
687            default_max_tokens,
688            prompt_caching: false, // Default to off
689        }
690    }
691
692    pub fn with_model(client: Client<T>, model: &str) -> Self {
693        Self {
694            client,
695            model: model.to_string(),
696            default_max_tokens: Some(calculate_max_tokens_custom(model)),
697            prompt_caching: false, // Default to off
698        }
699    }
700
701    /// Enable automatic prompt caching.
702    ///
703    /// When enabled, cache_control breakpoints are automatically added to:
704    /// - The system prompt (marked with ephemeral cache)
705    /// - The last content block of the last message (marked with ephemeral cache)
706    ///
707    /// This allows Anthropic to cache the conversation history for cost savings.
708    pub fn with_prompt_caching(mut self) -> Self {
709        self.prompt_caching = true;
710        self
711    }
712}
713
714/// Anthropic requires a `max_tokens` parameter to be set, which is dependent on the model. If not
715/// set or if set too high, the request will fail. The following values are based on the models
716/// available at the time of writing.
717fn calculate_max_tokens(model: &str) -> Option<u64> {
718    match model {
719        CLAUDE_4_OPUS => Some(32_000),
720        CLAUDE_4_SONNET | CLAUDE_3_7_SONNET => Some(64_000),
721        CLAUDE_3_5_SONNET | CLAUDE_3_5_HAIKU => Some(8_192),
722        _ => None,
723    }
724}
725
/// Looks up the default `max_tokens` for a short model alias.
///
/// Unlike `calculate_max_tokens`, this matches alias spellings such as
/// `claude-4-opus` and falls back to 4096 for anything unrecognised.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    if model == "claude-4-opus" {
        32_000
    } else if matches!(model, "claude-4-sonnet" | "claude-3.7-sonnet") {
        64_000
    } else if matches!(model, "claude-3.5-sonnet" | "claude-3.5-haiku") {
        8_192
    } else {
        4_096
    }
}
734
735#[derive(Debug, Deserialize, Serialize)]
736pub struct Metadata {
737    user_id: Option<String>,
738}
739
740#[derive(Default, Debug, Serialize, Deserialize)]
741#[serde(tag = "type", rename_all = "snake_case")]
742pub enum ToolChoice {
743    #[default]
744    Auto,
745    Any,
746    None,
747    Tool {
748        name: String,
749    },
750}
751impl TryFrom<message::ToolChoice> for ToolChoice {
752    type Error = CompletionError;
753
754    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
755        let res = match value {
756            message::ToolChoice::Auto => Self::Auto,
757            message::ToolChoice::None => Self::None,
758            message::ToolChoice::Required => Self::Any,
759            message::ToolChoice::Specific { function_names } => {
760                if function_names.len() != 1 {
761                    return Err(CompletionError::ProviderError(
762                        "Only one tool may be specified to be used by Claude".into(),
763                    ));
764                }
765
766                Self::Tool {
767                    name: function_names.first().unwrap().to_string(),
768                }
769            }
770        };
771
772        Ok(res)
773    }
774}
775
776#[derive(Debug, Deserialize, Serialize)]
777struct AnthropicCompletionRequest {
778    model: String,
779    messages: Vec<Message>,
780    max_tokens: u64,
781    /// System prompt as array of content blocks to support cache_control
782    #[serde(skip_serializing_if = "Vec::is_empty")]
783    system: Vec<SystemContent>,
784    #[serde(skip_serializing_if = "Option::is_none")]
785    temperature: Option<f64>,
786    #[serde(skip_serializing_if = "Option::is_none")]
787    tool_choice: Option<ToolChoice>,
788    #[serde(skip_serializing_if = "Vec::is_empty")]
789    tools: Vec<ToolDefinition>,
790    #[serde(flatten, skip_serializing_if = "Option::is_none")]
791    additional_params: Option<serde_json::Value>,
792}
793
794/// Helper to set cache_control on a Content block
795fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
796    match content {
797        Content::Text { cache_control, .. } => *cache_control = value,
798        Content::Image { cache_control, .. } => *cache_control = value,
799        Content::ToolResult { cache_control, .. } => *cache_control = value,
800        Content::Document { cache_control, .. } => *cache_control = value,
801        _ => {}
802    }
803}
804
805/// Apply cache control breakpoints to system prompt and messages.
806/// Strategy: cache the system prompt, and mark the last content block of the last message
807/// for caching. This allows the conversation history to be cached while new messages
808/// are added.
809pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
810    // Add cache_control to the system prompt (if non-empty)
811    if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
812        *cache_control = Some(CacheControl::Ephemeral);
813    }
814
815    // Clear any existing cache_control from all message content blocks
816    for msg in messages.iter_mut() {
817        for content in msg.content.iter_mut() {
818            set_content_cache_control(content, None);
819        }
820    }
821
822    // Add cache_control to the last content block of the last message
823    if let Some(last_msg) = messages.last_mut() {
824        set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
825    }
826}
827
828/// Parameters for building an AnthropicCompletionRequest
829pub struct AnthropicRequestParams<'a> {
830    pub model: &'a str,
831    pub request: CompletionRequest,
832    pub prompt_caching: bool,
833}
834
835impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
836    type Error = CompletionError;
837
838    fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
839        let AnthropicRequestParams {
840            model,
841            request: req,
842            prompt_caching,
843        } = params;
844
845        // Check if max_tokens is set, required for Anthropic
846        let Some(max_tokens) = req.max_tokens else {
847            return Err(CompletionError::RequestError(
848                "`max_tokens` must be set for Anthropic".into(),
849            ));
850        };
851
852        let mut full_history = vec![];
853        if let Some(docs) = req.normalized_documents() {
854            full_history.push(docs);
855        }
856        full_history.extend(req.chat_history);
857
858        let mut messages = full_history
859            .into_iter()
860            .map(Message::try_from)
861            .collect::<Result<Vec<Message>, _>>()?;
862
863        let tools = req
864            .tools
865            .into_iter()
866            .map(|tool| ToolDefinition {
867                name: tool.name,
868                description: Some(tool.description),
869                input_schema: tool.parameters,
870            })
871            .collect::<Vec<_>>();
872
873        // Convert system prompt to array format for cache_control support
874        let mut system = if let Some(preamble) = req.preamble {
875            if preamble.is_empty() {
876                vec![]
877            } else {
878                vec![SystemContent::Text {
879                    text: preamble,
880                    cache_control: None,
881                }]
882            }
883        } else {
884            vec![]
885        };
886
887        // Apply cache control breakpoints only if prompt_caching is enabled
888        if prompt_caching {
889            apply_cache_control(&mut system, &mut messages);
890        }
891
892        Ok(Self {
893            model: model.to_string(),
894            messages,
895            max_tokens,
896            system,
897            temperature: req.temperature,
898            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
899            tools,
900            additional_params: req.additional_params,
901        })
902    }
903}
904
905impl<T> completion::CompletionModel for CompletionModel<T>
906where
907    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
908{
909    type Response = CompletionResponse;
910    type StreamingResponse = StreamingCompletionResponse;
911    type Client = Client<T>;
912
913    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
914        Self::new(client.clone(), model.into())
915    }
916
917    async fn completion(
918        &self,
919        mut completion_request: completion::CompletionRequest,
920    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
921        let span = if tracing::Span::current().is_disabled() {
922            info_span!(
923                target: "rig::completions",
924                "chat",
925                gen_ai.operation.name = "chat",
926                gen_ai.provider.name = "anthropic",
927                gen_ai.request.model = &self.model,
928                gen_ai.system_instructions = &completion_request.preamble,
929                gen_ai.response.id = tracing::field::Empty,
930                gen_ai.response.model = tracing::field::Empty,
931                gen_ai.usage.output_tokens = tracing::field::Empty,
932                gen_ai.usage.input_tokens = tracing::field::Empty,
933            )
934        } else {
935            tracing::Span::current()
936        };
937
938        // Check if max_tokens is set, required for Anthropic
939        if completion_request.max_tokens.is_none() {
940            if let Some(tokens) = self.default_max_tokens {
941                completion_request.max_tokens = Some(tokens);
942            } else {
943                return Err(CompletionError::RequestError(
944                    "`max_tokens` must be set for Anthropic".into(),
945                ));
946            }
947        }
948
949        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
950            model: &self.model,
951            request: completion_request,
952            prompt_caching: self.prompt_caching,
953        })?;
954
955        if enabled!(Level::TRACE) {
956            tracing::trace!(
957                target: "rig::completions",
958                "Anthropic completion request: {}",
959                serde_json::to_string_pretty(&request)?
960            );
961        }
962
963        async move {
964            let request: Vec<u8> = serde_json::to_vec(&request)?;
965
966            let req = self
967                .client
968                .post("/v1/messages")?
969                .body(request)
970                .map_err(|e| CompletionError::HttpError(e.into()))?;
971
972            let response = self
973                .client
974                .send::<_, Bytes>(req)
975                .await
976                .map_err(CompletionError::HttpError)?;
977
978            if response.status().is_success() {
979                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
980                    response
981                        .into_body()
982                        .await
983                        .map_err(CompletionError::HttpError)?
984                        .to_vec()
985                        .as_slice(),
986                )? {
987                    ApiResponse::Message(completion) => {
988                        let span = tracing::Span::current();
989                        span.record_response_metadata(&completion);
990                        span.record_token_usage(&completion.usage);
991                        if enabled!(Level::TRACE) {
992                            tracing::trace!(
993                                target: "rig::completions",
994                                "Anthropic completion response: {}",
995                                serde_json::to_string_pretty(&completion)?
996                            );
997                        }
998                        completion.try_into()
999                    }
1000                    ApiResponse::Error(ApiErrorResponse { message }) => {
1001                        Err(CompletionError::ResponseError(message))
1002                    }
1003                }
1004            } else {
1005                let text: String = String::from_utf8_lossy(
1006                    &response
1007                        .into_body()
1008                        .await
1009                        .map_err(CompletionError::HttpError)?,
1010                )
1011                .into();
1012                Err(CompletionError::ProviderError(text))
1013            }
1014        }
1015        .instrument(span)
1016        .await
1017    }
1018
1019    async fn stream(
1020        &self,
1021        request: CompletionRequest,
1022    ) -> Result<
1023        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1024        CompletionError,
1025    > {
1026        CompletionModel::stream(self, request).await
1027    }
1028}
1029
1030#[derive(Debug, Deserialize)]
1031struct ApiErrorResponse {
1032    message: String,
1033}
1034
1035#[derive(Debug, Deserialize)]
1036#[serde(tag = "type", rename_all = "snake_case")]
1037enum ApiResponse<T> {
1038    Message(T),
1039    Error(ApiErrorResponse),
1040}
1041
1042#[cfg(test)]
1043mod tests {
1044    use super::*;
1045    use serde_json::json;
1046    use serde_path_to_error::deserialize;
1047
1048    #[test]
1049    fn test_deserialize_message() {
1050        let assistant_message_json = r#"
1051        {
1052            "role": "assistant",
1053            "content": "\n\nHello there, how may I assist you today?"
1054        }
1055        "#;
1056
1057        let assistant_message_json2 = r#"
1058        {
1059            "role": "assistant",
1060            "content": [
1061                {
1062                    "type": "text",
1063                    "text": "\n\nHello there, how may I assist you today?"
1064                },
1065                {
1066                    "type": "tool_use",
1067                    "id": "toolu_01A09q90qw90lq917835lq9",
1068                    "name": "get_weather",
1069                    "input": {"location": "San Francisco, CA"}
1070                }
1071            ]
1072        }
1073        "#;
1074
1075        let user_message_json = r#"
1076        {
1077            "role": "user",
1078            "content": [
1079                {
1080                    "type": "image",
1081                    "source": {
1082                        "type": "base64",
1083                        "media_type": "image/jpeg",
1084                        "data": "/9j/4AAQSkZJRg..."
1085                    }
1086                },
1087                {
1088                    "type": "text",
1089                    "text": "What is in this image?"
1090                },
1091                {
1092                    "type": "tool_result",
1093                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
1094                    "content": "15 degrees"
1095                }
1096            ]
1097        }
1098        "#;
1099
1100        let assistant_message: Message = {
1101            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1102            deserialize(jd).unwrap_or_else(|err| {
1103                panic!("Deserialization error at {}: {}", err.path(), err);
1104            })
1105        };
1106
1107        let assistant_message2: Message = {
1108            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1109            deserialize(jd).unwrap_or_else(|err| {
1110                panic!("Deserialization error at {}: {}", err.path(), err);
1111            })
1112        };
1113
1114        let user_message: Message = {
1115            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1116            deserialize(jd).unwrap_or_else(|err| {
1117                panic!("Deserialization error at {}: {}", err.path(), err);
1118            })
1119        };
1120
1121        let Message { role, content } = assistant_message;
1122        assert_eq!(role, Role::Assistant);
1123        assert_eq!(
1124            content.first(),
1125            Content::Text {
1126                text: "\n\nHello there, how may I assist you today?".to_owned(),
1127                cache_control: None,
1128            }
1129        );
1130
1131        let Message { role, content } = assistant_message2;
1132        {
1133            assert_eq!(role, Role::Assistant);
1134            assert_eq!(content.len(), 2);
1135
1136            let mut iter = content.into_iter();
1137
1138            match iter.next().unwrap() {
1139                Content::Text { text, .. } => {
1140                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
1141                }
1142                _ => panic!("Expected text content"),
1143            }
1144
1145            match iter.next().unwrap() {
1146                Content::ToolUse { id, name, input } => {
1147                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1148                    assert_eq!(name, "get_weather");
1149                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
1150                }
1151                _ => panic!("Expected tool use content"),
1152            }
1153
1154            assert_eq!(iter.next(), None);
1155        }
1156
1157        let Message { role, content } = user_message;
1158        {
1159            assert_eq!(role, Role::User);
1160            assert_eq!(content.len(), 3);
1161
1162            let mut iter = content.into_iter();
1163
1164            match iter.next().unwrap() {
1165                Content::Image { source, .. } => {
1166                    assert_eq!(
1167                        source,
1168                        ImageSource {
1169                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
1170                            media_type: ImageFormat::JPEG,
1171                            r#type: SourceType::BASE64,
1172                        }
1173                    );
1174                }
1175                _ => panic!("Expected image content"),
1176            }
1177
1178            match iter.next().unwrap() {
1179                Content::Text { text, .. } => {
1180                    assert_eq!(text, "What is in this image?");
1181                }
1182                _ => panic!("Expected text content"),
1183            }
1184
1185            match iter.next().unwrap() {
1186                Content::ToolResult {
1187                    tool_use_id,
1188                    content,
1189                    is_error,
1190                    ..
1191                } => {
1192                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1193                    assert_eq!(
1194                        content.first(),
1195                        ToolResultContent::Text {
1196                            text: "15 degrees".to_owned()
1197                        }
1198                    );
1199                    assert_eq!(is_error, None);
1200                }
1201                _ => panic!("Expected tool result content"),
1202            }
1203
1204            assert_eq!(iter.next(), None);
1205        }
1206    }
1207
1208    #[test]
1209    fn test_message_to_message_conversion() {
1210        let user_message: Message = serde_json::from_str(
1211            r#"
1212        {
1213            "role": "user",
1214            "content": [
1215                {
1216                    "type": "image",
1217                    "source": {
1218                        "type": "base64",
1219                        "media_type": "image/jpeg",
1220                        "data": "/9j/4AAQSkZJRg..."
1221                    }
1222                },
1223                {
1224                    "type": "text",
1225                    "text": "What is in this image?"
1226                },
1227                {
1228                    "type": "document",
1229                    "source": {
1230                        "type": "base64",
1231                        "data": "base64_encoded_pdf_data",
1232                        "media_type": "application/pdf"
1233                    }
1234                }
1235            ]
1236        }
1237        "#,
1238        )
1239        .unwrap();
1240
1241        let assistant_message = Message {
1242            role: Role::Assistant,
1243            content: OneOrMany::one(Content::ToolUse {
1244                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1245                name: "get_weather".to_string(),
1246                input: json!({"location": "San Francisco, CA"}),
1247            }),
1248        };
1249
1250        let tool_message = Message {
1251            role: Role::User,
1252            content: OneOrMany::one(Content::ToolResult {
1253                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1254                content: OneOrMany::one(ToolResultContent::Text {
1255                    text: "15 degrees".to_string(),
1256                }),
1257                is_error: None,
1258                cache_control: None,
1259            }),
1260        };
1261
1262        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1263        let converted_assistant_message: message::Message =
1264            assistant_message.clone().try_into().unwrap();
1265        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1266
1267        match converted_user_message.clone() {
1268            message::Message::User { content } => {
1269                assert_eq!(content.len(), 3);
1270
1271                let mut iter = content.into_iter();
1272
1273                match iter.next().unwrap() {
1274                    message::UserContent::Image(message::Image {
1275                        data, media_type, ..
1276                    }) => {
1277                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1278                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1279                    }
1280                    _ => panic!("Expected image content"),
1281                }
1282
1283                match iter.next().unwrap() {
1284                    message::UserContent::Text(message::Text { text }) => {
1285                        assert_eq!(text, "What is in this image?");
1286                    }
1287                    _ => panic!("Expected text content"),
1288                }
1289
1290                match iter.next().unwrap() {
1291                    message::UserContent::Document(message::Document {
1292                        data, media_type, ..
1293                    }) => {
1294                        assert_eq!(
1295                            data,
1296                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
1297                        );
1298                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1299                    }
1300                    _ => panic!("Expected document content"),
1301                }
1302
1303                assert_eq!(iter.next(), None);
1304            }
1305            _ => panic!("Expected user message"),
1306        }
1307
1308        match converted_tool_message.clone() {
1309            message::Message::User { content } => {
1310                let message::ToolResult { id, content, .. } = match content.first() {
1311                    message::UserContent::ToolResult(tool_result) => tool_result,
1312                    _ => panic!("Expected tool result content"),
1313                };
1314                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1315                match content.first() {
1316                    message::ToolResultContent::Text(message::Text { text }) => {
1317                        assert_eq!(text, "15 degrees");
1318                    }
1319                    _ => panic!("Expected text content"),
1320                }
1321            }
1322            _ => panic!("Expected tool result content"),
1323        }
1324
1325        match converted_assistant_message.clone() {
1326            message::Message::Assistant { content, .. } => {
1327                assert_eq!(content.len(), 1);
1328
1329                match content.first() {
1330                    message::AssistantContent::ToolCall(message::ToolCall {
1331                        id, function, ..
1332                    }) => {
1333                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1334                        assert_eq!(function.name, "get_weather");
1335                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1336                    }
1337                    _ => panic!("Expected tool call content"),
1338                }
1339            }
1340            _ => panic!("Expected assistant message"),
1341        }
1342
1343        let original_user_message: Message = converted_user_message.try_into().unwrap();
1344        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1345        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1346
1347        assert_eq!(user_message, original_user_message);
1348        assert_eq!(assistant_message, original_assistant_message);
1349        assert_eq!(tool_message, original_tool_message);
1350    }
1351
1352    #[test]
1353    fn test_content_format_conversion() {
1354        use crate::completion::message::ContentFormat;
1355
1356        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1357        assert_eq!(source_type, SourceType::URL);
1358
1359        let content_format: ContentFormat = SourceType::URL.into();
1360        assert_eq!(content_format, ContentFormat::Url);
1361
1362        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1363        assert_eq!(source_type, SourceType::BASE64);
1364
1365        let content_format: ContentFormat = SourceType::BASE64.into();
1366        assert_eq!(content_format, ContentFormat::Base64);
1367
1368        let result: Result<SourceType, _> = ContentFormat::String.try_into();
1369        assert!(result.is_err());
1370        assert!(
1371            result
1372                .unwrap_err()
1373                .to_string()
1374                .contains("ContentFormat::String is deprecated")
1375        );
1376    }
1377
1378    #[test]
1379    fn test_cache_control_serialization() {
1380        // Test SystemContent with cache_control
1381        let system = SystemContent::Text {
1382            text: "You are a helpful assistant.".to_string(),
1383            cache_control: Some(CacheControl::Ephemeral),
1384        };
1385        let json = serde_json::to_string(&system).unwrap();
1386        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
1387        assert!(json.contains(r#""type":"text""#));
1388
1389        // Test SystemContent without cache_control (should not have cache_control field)
1390        let system_no_cache = SystemContent::Text {
1391            text: "Hello".to_string(),
1392            cache_control: None,
1393        };
1394        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
1395        assert!(!json_no_cache.contains("cache_control"));
1396
1397        // Test Content::Text with cache_control
1398        let content = Content::Text {
1399            text: "Test message".to_string(),
1400            cache_control: Some(CacheControl::Ephemeral),
1401        };
1402        let json_content = serde_json::to_string(&content).unwrap();
1403        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));
1404
1405        // Test apply_cache_control function
1406        let mut system_vec = vec![SystemContent::Text {
1407            text: "System prompt".to_string(),
1408            cache_control: None,
1409        }];
1410        let mut messages = vec![
1411            Message {
1412                role: Role::User,
1413                content: OneOrMany::one(Content::Text {
1414                    text: "First message".to_string(),
1415                    cache_control: None,
1416                }),
1417            },
1418            Message {
1419                role: Role::Assistant,
1420                content: OneOrMany::one(Content::Text {
1421                    text: "Response".to_string(),
1422                    cache_control: None,
1423                }),
1424            },
1425        ];
1426
1427        apply_cache_control(&mut system_vec, &mut messages);
1428
1429        // System should have cache_control
1430        match &system_vec[0] {
1431            SystemContent::Text { cache_control, .. } => {
1432                assert!(cache_control.is_some());
1433            }
1434        }
1435
1436        // Only the last content block of last message should have cache_control
1437        // First message should NOT have cache_control
1438        for content in messages[0].content.iter() {
1439            if let Content::Text { cache_control, .. } = content {
1440                assert!(cache_control.is_none());
1441            }
1442        }
1443
1444        // Last message SHOULD have cache_control
1445        for content in messages[1].content.iter() {
1446            if let Content::Text { cache_control, .. } = content {
1447                assert!(cache_control.is_some());
1448            }
1449        }
1450    }
1451}