rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8    one_or_many::string_or_one_or_many,
9    telemetry::{ProviderResponseExt, SpanCombinator},
10    wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
21// ================================================================
22// Anthropic Completion API
23// ================================================================
24
/// Model id for Claude Opus 4 (`claude-opus-4-0`).
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model id for Claude Sonnet 4 (`claude-sonnet-4-0`).
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model id for Claude 3.7 Sonnet (`claude-3-7-sonnet-latest`).
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model id for Claude 3.5 Sonnet (`claude-3-5-sonnet-latest`).
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model id for Claude 3.5 Haiku (`claude-3-5-haiku-latest`).
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version defined in this module (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
40#[derive(Debug, Deserialize, Serialize)]
41pub struct CompletionResponse {
42    pub content: Vec<Content>,
43    pub id: String,
44    pub model: String,
45    pub role: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51impl ProviderResponseExt for CompletionResponse {
52    type OutputMessage = Content;
53    type Usage = Usage;
54
55    fn get_response_id(&self) -> Option<String> {
56        Some(self.id.to_owned())
57    }
58
59    fn get_response_model_name(&self) -> Option<String> {
60        Some(self.model.to_owned())
61    }
62
63    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64        self.content.clone()
65    }
66
67    fn get_text_response(&self) -> Option<String> {
68        let res = self
69            .content
70            .iter()
71            .filter_map(|x| {
72                if let Content::Text { text, .. } = x {
73                    Some(text.to_owned())
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<String>>()
79            .join("\n");
80
81        if res.is_empty() { None } else { Some(res) }
82    }
83
84    fn get_usage(&self) -> Option<Self::Usage> {
85        Some(self.usage.clone())
86    }
87}
88
89#[derive(Clone, Debug, Deserialize, Serialize)]
90pub struct Usage {
91    pub input_tokens: u64,
92    pub cache_read_input_tokens: Option<u64>,
93    pub cache_creation_input_tokens: Option<u64>,
94    pub output_tokens: u64,
95}
96
97impl std::fmt::Display for Usage {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(
100            f,
101            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102            self.input_tokens,
103            match self.cache_read_input_tokens {
104                Some(token) => token.to_string(),
105                None => "n/a".to_string(),
106            },
107            match self.cache_creation_input_tokens {
108                Some(token) => token.to_string(),
109                None => "n/a".to_string(),
110            },
111            self.output_tokens
112        )
113    }
114}
115
116impl GetTokenUsage for Usage {
117    fn token_usage(&self) -> Option<crate::completion::Usage> {
118        let mut usage = crate::completion::Usage::new();
119
120        usage.input_tokens = self.input_tokens
121            + self.cache_creation_input_tokens.unwrap_or_default()
122            + self.cache_read_input_tokens.unwrap_or_default();
123        usage.output_tokens = self.output_tokens;
124        usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126        Some(usage)
127    }
128}
129
130#[derive(Debug, Deserialize, Serialize)]
131pub struct ToolDefinition {
132    pub name: String,
133    pub description: Option<String>,
134    pub input_schema: serde_json::Value,
135}
136
137/// Cache control directive for Anthropic prompt caching
138#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
139#[serde(tag = "type", rename_all = "snake_case")]
140pub enum CacheControl {
141    Ephemeral,
142}
143
144/// System message content block with optional cache control
145#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
146#[serde(tag = "type", rename_all = "snake_case")]
147pub enum SystemContent {
148    Text {
149        text: String,
150        #[serde(skip_serializing_if = "Option::is_none")]
151        cache_control: Option<CacheControl>,
152    },
153}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156    type Error = CompletionError;
157
158    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159        let content = response
160            .content
161            .iter()
162            .map(|content| content.clone().try_into())
163            .collect::<Result<Vec<_>, _>>()?;
164
165        let choice = OneOrMany::many(content).map_err(|_| {
166            CompletionError::ResponseError(
167                "Response contained no message or tool call (empty)".to_owned(),
168            )
169        })?;
170
171        let usage = completion::Usage {
172            input_tokens: response.usage.input_tokens,
173            output_tokens: response.usage.output_tokens,
174            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175        };
176
177        Ok(completion::CompletionResponse {
178            choice,
179            usage,
180            raw_response: response,
181        })
182    }
183}
184
185#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
186pub struct Message {
187    pub role: Role,
188    #[serde(deserialize_with = "string_or_one_or_many")]
189    pub content: OneOrMany<Content>,
190}
191
192#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
193#[serde(rename_all = "lowercase")]
194pub enum Role {
195    User,
196    Assistant,
197}
198
199#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
200#[serde(tag = "type", rename_all = "snake_case")]
201pub enum Content {
202    Text {
203        text: String,
204        #[serde(skip_serializing_if = "Option::is_none")]
205        cache_control: Option<CacheControl>,
206    },
207    Image {
208        source: ImageSource,
209        #[serde(skip_serializing_if = "Option::is_none")]
210        cache_control: Option<CacheControl>,
211    },
212    ToolUse {
213        id: String,
214        name: String,
215        input: serde_json::Value,
216    },
217    ToolResult {
218        tool_use_id: String,
219        #[serde(deserialize_with = "string_or_one_or_many")]
220        content: OneOrMany<ToolResultContent>,
221        #[serde(skip_serializing_if = "Option::is_none")]
222        is_error: Option<bool>,
223        #[serde(skip_serializing_if = "Option::is_none")]
224        cache_control: Option<CacheControl>,
225    },
226    Document {
227        source: DocumentSource,
228        #[serde(skip_serializing_if = "Option::is_none")]
229        cache_control: Option<CacheControl>,
230    },
231    Thinking {
232        thinking: String,
233        #[serde(skip_serializing_if = "Option::is_none")]
234        signature: Option<String>,
235    },
236}
237
238impl FromStr for Content {
239    type Err = Infallible;
240
241    fn from_str(s: &str) -> Result<Self, Self::Err> {
242        Ok(Content::Text {
243            text: s.to_owned(),
244            cache_control: None,
245        })
246    }
247}
248
249#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
250#[serde(tag = "type", rename_all = "snake_case")]
251pub enum ToolResultContent {
252    Text { text: String },
253    Image(ImageSource),
254}
255
256impl FromStr for ToolResultContent {
257    type Err = Infallible;
258
259    fn from_str(s: &str) -> Result<Self, Self::Err> {
260        Ok(ToolResultContent::Text { text: s.to_owned() })
261    }
262}
263
264#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
265#[serde(untagged)]
266pub enum ImageSourceData {
267    Base64(String),
268    Url(String),
269}
270
271impl From<ImageSourceData> for DocumentSourceKind {
272    fn from(value: ImageSourceData) -> Self {
273        match value {
274            ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
275            ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
276        }
277    }
278}
279
280impl TryFrom<DocumentSourceKind> for ImageSourceData {
281    type Error = MessageError;
282
283    fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
284        match value {
285            DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
286            DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
287            _ => Err(MessageError::ConversionError("Content has no body".into())),
288        }
289    }
290}
291
292impl From<ImageSourceData> for String {
293    fn from(value: ImageSourceData) -> Self {
294        match value {
295            ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
296        }
297    }
298}
299
300#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
301pub struct ImageSource {
302    pub data: ImageSourceData,
303    pub media_type: ImageFormat,
304    pub r#type: SourceType,
305}
306
307#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
308pub struct DocumentSource {
309    pub data: String,
310    pub media_type: DocumentFormat,
311    pub r#type: SourceType,
312}
313
314#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
315#[serde(rename_all = "lowercase")]
316pub enum ImageFormat {
317    #[serde(rename = "image/jpeg")]
318    JPEG,
319    #[serde(rename = "image/png")]
320    PNG,
321    #[serde(rename = "image/gif")]
322    GIF,
323    #[serde(rename = "image/webp")]
324    WEBP,
325}
326
327/// The document format to be used.
328///
329/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
330#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
331#[serde(rename_all = "lowercase")]
332pub enum DocumentFormat {
333    #[serde(rename = "application/pdf")]
334    PDF,
335}
336
337#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
338#[serde(rename_all = "lowercase")]
339pub enum SourceType {
340    BASE64,
341    URL,
342}
343
344impl From<String> for Content {
345    fn from(text: String) -> Self {
346        Content::Text {
347            text,
348            cache_control: None,
349        }
350    }
351}
352
353impl From<String> for ToolResultContent {
354    fn from(text: String) -> Self {
355        ToolResultContent::Text { text }
356    }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360    type Error = MessageError;
361
362    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363        match format {
364            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365            message::ContentFormat::Url => Ok(SourceType::URL),
366            message::ContentFormat::String => Err(MessageError::ConversionError(
367                "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368            )),
369        }
370    }
371}
372
373impl From<SourceType> for message::ContentFormat {
374    fn from(source_type: SourceType) -> Self {
375        match source_type {
376            SourceType::BASE64 => message::ContentFormat::Base64,
377            SourceType::URL => message::ContentFormat::Url,
378        }
379    }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383    type Error = MessageError;
384
385    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386        Ok(match media_type {
387            message::ImageMediaType::JPEG => ImageFormat::JPEG,
388            message::ImageMediaType::PNG => ImageFormat::PNG,
389            message::ImageMediaType::GIF => ImageFormat::GIF,
390            message::ImageMediaType::WEBP => ImageFormat::WEBP,
391            _ => {
392                return Err(MessageError::ConversionError(
393                    format!("Unsupported image media type: {media_type:?}").to_owned(),
394                ));
395            }
396        })
397    }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401    fn from(format: ImageFormat) -> Self {
402        match format {
403            ImageFormat::JPEG => message::ImageMediaType::JPEG,
404            ImageFormat::PNG => message::ImageMediaType::PNG,
405            ImageFormat::GIF => message::ImageMediaType::GIF,
406            ImageFormat::WEBP => message::ImageMediaType::WEBP,
407        }
408    }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412    type Error = MessageError;
413    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414        if !matches!(value, DocumentMediaType::PDF) {
415            return Err(MessageError::ConversionError(
416                "Anthropic only supports PDF documents".to_string(),
417            ));
418        };
419
420        Ok(DocumentFormat::PDF)
421    }
422}
423
424impl TryFrom<message::AssistantContent> for Content {
425    type Error = MessageError;
426    fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
427        match text {
428            message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
429                text,
430                cache_control: None,
431            }),
432            message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
433                "Anthropic currently doesn't support images.".to_string(),
434            )),
435            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
436                Ok(Content::ToolUse {
437                    id,
438                    name: function.name,
439                    input: function.arguments,
440                })
441            }
442            message::AssistantContent::Reasoning(Reasoning {
443                reasoning,
444                signature,
445                ..
446            }) => Ok(Content::Thinking {
447                thinking: reasoning.first().cloned().unwrap_or(String::new()),
448                signature,
449            }),
450        }
451    }
452}
453
454impl TryFrom<message::Message> for Message {
455    type Error = MessageError;
456
457    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
458        Ok(match message {
459            message::Message::User { content } => Message {
460                role: Role::User,
461                content: content.try_map(|content| match content {
462                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
463                        text,
464                        cache_control: None,
465                    }),
466                    message::UserContent::ToolResult(message::ToolResult {
467                        id, content, ..
468                    }) => Ok(Content::ToolResult {
469                        tool_use_id: id,
470                        content: content.try_map(|content| match content {
471                            message::ToolResultContent::Text(message::Text { text }) => {
472                                Ok(ToolResultContent::Text { text })
473                            }
474                            message::ToolResultContent::Image(image) => {
475                                let DocumentSourceKind::Base64(data) = image.data else {
476                                    return Err(MessageError::ConversionError(
477                                        "Only base64 strings can be used with the Anthropic API"
478                                            .to_string(),
479                                    ));
480                                };
481                                let media_type =
482                                    image.media_type.ok_or(MessageError::ConversionError(
483                                        "Image media type is required".to_owned(),
484                                    ))?;
485                                Ok(ToolResultContent::Image(ImageSource {
486                                    data: ImageSourceData::Base64(data),
487                                    media_type: media_type.try_into()?,
488                                    r#type: SourceType::BASE64,
489                                }))
490                            }
491                        })?,
492                        is_error: None,
493                        cache_control: None,
494                    }),
495                    message::UserContent::Image(message::Image {
496                        data, media_type, ..
497                    }) => {
498                        let media_type = media_type.ok_or(MessageError::ConversionError(
499                            "Image media type is required for Claude API".to_string(),
500                        ))?;
501
502                        let source = match data {
503                            DocumentSourceKind::Base64(data) => ImageSource {
504                                data: ImageSourceData::Base64(data),
505                                r#type: SourceType::BASE64,
506                                media_type: ImageFormat::try_from(media_type)?,
507                            },
508                            DocumentSourceKind::Url(url) => ImageSource {
509                                data: ImageSourceData::Url(url),
510                                r#type: SourceType::URL,
511                                media_type: ImageFormat::try_from(media_type)?,
512                            },
513                            DocumentSourceKind::Unknown => {
514                                return Err(MessageError::ConversionError(
515                                    "Image content has no body".into(),
516                                ));
517                            }
518                            doc => {
519                                return Err(MessageError::ConversionError(format!(
520                                    "Unsupported document type: {doc:?}"
521                                )));
522                            }
523                        };
524
525                        Ok(Content::Image {
526                            source,
527                            cache_control: None,
528                        })
529                    }
530                    message::UserContent::Document(message::Document {
531                        data, media_type, ..
532                    }) => {
533                        let media_type = media_type.ok_or(MessageError::ConversionError(
534                            "Document media type is required".to_string(),
535                        ))?;
536
537                        let data = match data {
538                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
539                                data
540                            }
541                            _ => {
542                                return Err(MessageError::ConversionError(
543                                    "Only base64 encoded documents currently supported".into(),
544                                ));
545                            }
546                        };
547
548                        let source = DocumentSource {
549                            data,
550                            media_type: media_type.try_into()?,
551                            r#type: SourceType::BASE64,
552                        };
553                        Ok(Content::Document {
554                            source,
555                            cache_control: None,
556                        })
557                    }
558                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
559                        "Audio is not supported in Anthropic".to_owned(),
560                    )),
561                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
562                        "Video is not supported in Anthropic".to_owned(),
563                    )),
564                })?,
565            },
566
567            message::Message::Assistant { content, .. } => Message {
568                content: content.try_map(|content| content.try_into())?,
569                role: Role::Assistant,
570            },
571        })
572    }
573}
574
575impl TryFrom<Content> for message::AssistantContent {
576    type Error = MessageError;
577
578    fn try_from(content: Content) -> Result<Self, Self::Error> {
579        Ok(match content {
580            Content::Text { text, .. } => message::AssistantContent::text(text),
581            Content::ToolUse { id, name, input } => {
582                message::AssistantContent::tool_call(id, name, input)
583            }
584            Content::Thinking {
585                thinking,
586                signature,
587            } => message::AssistantContent::Reasoning(
588                Reasoning::new(&thinking).with_signature(signature),
589            ),
590            _ => {
591                return Err(MessageError::ConversionError(
592                    "Content did not contain a message, tool call, or reasoning".to_owned(),
593                ));
594            }
595        })
596    }
597}
598
599impl From<ToolResultContent> for message::ToolResultContent {
600    fn from(content: ToolResultContent) -> Self {
601        match content {
602            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
603            ToolResultContent::Image(ImageSource {
604                data,
605                media_type: format,
606                ..
607            }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
608        }
609    }
610}
611
612impl TryFrom<Message> for message::Message {
613    type Error = MessageError;
614
615    fn try_from(message: Message) -> Result<Self, Self::Error> {
616        Ok(match message.role {
617            Role::User => message::Message::User {
618                content: message.content.try_map(|content| {
619                    Ok(match content {
620                        Content::Text { text, .. } => message::UserContent::text(text),
621                        Content::ToolResult {
622                            tool_use_id,
623                            content,
624                            ..
625                        } => message::UserContent::tool_result(
626                            tool_use_id,
627                            content.map(|content| content.into()),
628                        ),
629                        Content::Image { source, .. } => {
630                            message::UserContent::Image(message::Image {
631                                data: source.data.into(),
632                                media_type: Some(source.media_type.into()),
633                                detail: None,
634                                additional_params: None,
635                            })
636                        }
637                        Content::Document { source, .. } => message::UserContent::document(
638                            source.data,
639                            Some(message::DocumentMediaType::PDF),
640                        ),
641                        _ => {
642                            return Err(MessageError::ConversionError(
643                                "Unsupported content type for User role".to_owned(),
644                            ));
645                        }
646                    })
647                })?,
648            },
649            Role::Assistant => match message.content.first() {
650                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
651                    message::Message::Assistant {
652                        id: None,
653                        content: message.content.try_map(|content| content.try_into())?,
654                    }
655                }
656
657                _ => {
658                    return Err(MessageError::ConversionError(
659                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
660                    ));
661                }
662            },
663        })
664    }
665}
666
667#[derive(Clone)]
668pub struct CompletionModel<T = reqwest::Client> {
669    pub(crate) client: Client<T>,
670    pub model: String,
671    pub default_max_tokens: Option<u64>,
672    /// Enable automatic prompt caching (adds cache_control breakpoints to system prompt and messages)
673    pub prompt_caching: bool,
674}
675
676impl<T> CompletionModel<T>
677where
678    T: HttpClientExt,
679{
680    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
681        let model = model.into();
682        let default_max_tokens = calculate_max_tokens(&model);
683
684        Self {
685            client,
686            model,
687            default_max_tokens,
688            prompt_caching: false, // Default to off
689        }
690    }
691
692    pub fn with_model(client: Client<T>, model: &str) -> Self {
693        Self {
694            client,
695            model: model.to_string(),
696            default_max_tokens: Some(calculate_max_tokens_custom(model)),
697            prompt_caching: false, // Default to off
698        }
699    }
700
701    /// Enable automatic prompt caching.
702    ///
703    /// When enabled, cache_control breakpoints are automatically added to:
704    /// - The system prompt (marked with ephemeral cache)
705    /// - The last content block of the last message (marked with ephemeral cache)
706    ///
707    /// This allows Anthropic to cache the conversation history for cost savings.
708    pub fn with_prompt_caching(mut self) -> Self {
709        self.prompt_caching = true;
710        self
711    }
712}
713
714/// Anthropic requires a `max_tokens` parameter to be set, which is dependent on the model. If not
715/// set or if set too high, the request will fail. The following values are based on the models
716/// available at the time of writing.
717fn calculate_max_tokens(model: &str) -> Option<u64> {
718    match model {
719        CLAUDE_4_OPUS => Some(32_000),
720        CLAUDE_4_SONNET | CLAUDE_3_7_SONNET => Some(64_000),
721        CLAUDE_3_5_SONNET | CLAUDE_3_5_HAIKU => Some(8_192),
722        _ => None,
723    }
724}
725
/// Default `max_tokens` for the dotted model names in the table below; any other name
/// gets 4,096.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    const LIMITS: [(&str, u64); 5] = [
        ("claude-4-opus", 32_000),
        ("claude-4-sonnet", 64_000),
        ("claude-3.7-sonnet", 64_000),
        ("claude-3.5-sonnet", 8_192),
        ("claude-3.5-haiku", 8_192),
    ];

    LIMITS
        .iter()
        .find(|(name, _)| *name == model)
        .map_or(4_096, |&(_, limit)| limit)
}
734
735#[derive(Debug, Deserialize, Serialize)]
736pub struct Metadata {
737    user_id: Option<String>,
738}
739
740#[derive(Default, Debug, Serialize, Deserialize)]
741#[serde(tag = "type", rename_all = "snake_case")]
742pub enum ToolChoice {
743    #[default]
744    Auto,
745    Any,
746    None,
747    Tool {
748        name: String,
749    },
750}
751impl TryFrom<message::ToolChoice> for ToolChoice {
752    type Error = CompletionError;
753
754    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
755        let res = match value {
756            message::ToolChoice::Auto => Self::Auto,
757            message::ToolChoice::None => Self::None,
758            message::ToolChoice::Required => Self::Any,
759            message::ToolChoice::Specific { function_names } => {
760                if function_names.len() != 1 {
761                    return Err(CompletionError::ProviderError(
762                        "Only one tool may be specified to be used by Claude".into(),
763                    ));
764                }
765
766                Self::Tool {
767                    name: function_names.first().unwrap().to_string(),
768                }
769            }
770        };
771
772        Ok(res)
773    }
774}
775
776#[derive(Debug, Deserialize, Serialize)]
777struct AnthropicCompletionRequest {
778    model: String,
779    messages: Vec<Message>,
780    max_tokens: u64,
781    /// System prompt as array of content blocks to support cache_control
782    #[serde(skip_serializing_if = "Vec::is_empty")]
783    system: Vec<SystemContent>,
784    #[serde(skip_serializing_if = "Option::is_none")]
785    temperature: Option<f64>,
786    #[serde(skip_serializing_if = "Option::is_none")]
787    tool_choice: Option<ToolChoice>,
788    #[serde(skip_serializing_if = "Vec::is_empty")]
789    tools: Vec<ToolDefinition>,
790    #[serde(flatten, skip_serializing_if = "Option::is_none")]
791    additional_params: Option<serde_json::Value>,
792}
793
794/// Helper to set cache_control on a Content block
795fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
796    match content {
797        Content::Text { cache_control, .. } => *cache_control = value,
798        Content::Image { cache_control, .. } => *cache_control = value,
799        Content::ToolResult { cache_control, .. } => *cache_control = value,
800        Content::Document { cache_control, .. } => *cache_control = value,
801        _ => {}
802    }
803}
804
805/// Apply cache control breakpoints to system prompt and messages.
806/// Strategy: cache the system prompt, and mark the last content block of the last message
807/// for caching. This allows the conversation history to be cached while new messages
808/// are added.
809pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
810    // Add cache_control to the system prompt (if non-empty)
811    if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
812        *cache_control = Some(CacheControl::Ephemeral);
813    }
814
815    // Clear any existing cache_control from all message content blocks
816    for msg in messages.iter_mut() {
817        for content in msg.content.iter_mut() {
818            set_content_cache_control(content, None);
819        }
820    }
821
822    // Add cache_control to the last content block of the last message
823    if let Some(last_msg) = messages.last_mut() {
824        set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
825    }
826}
827
828/// Parameters for building an AnthropicCompletionRequest
829pub struct AnthropicRequestParams<'a> {
830    pub model: &'a str,
831    pub request: CompletionRequest,
832    pub prompt_caching: bool,
833}
834
835impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
836    type Error = CompletionError;
837
838    fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
839        let AnthropicRequestParams {
840            model,
841            request: req,
842            prompt_caching,
843        } = params;
844
845        // Check if max_tokens is set, required for Anthropic
846        let Some(max_tokens) = req.max_tokens else {
847            return Err(CompletionError::RequestError(
848                "`max_tokens` must be set for Anthropic".into(),
849            ));
850        };
851
852        let mut full_history = vec![];
853        if let Some(docs) = req.normalized_documents() {
854            full_history.push(docs);
855        }
856        full_history.extend(req.chat_history);
857
858        let mut messages = full_history
859            .into_iter()
860            .map(Message::try_from)
861            .collect::<Result<Vec<Message>, _>>()?;
862
863        let tools = req
864            .tools
865            .into_iter()
866            .map(|tool| ToolDefinition {
867                name: tool.name,
868                description: Some(tool.description),
869                input_schema: tool.parameters,
870            })
871            .collect::<Vec<_>>();
872
873        // Convert system prompt to array format for cache_control support
874        let mut system = if let Some(preamble) = req.preamble {
875            if preamble.is_empty() {
876                vec![]
877            } else {
878                vec![SystemContent::Text {
879                    text: preamble,
880                    cache_control: None,
881                }]
882            }
883        } else {
884            vec![]
885        };
886
887        // Apply cache control breakpoints only if prompt_caching is enabled
888        if prompt_caching {
889            apply_cache_control(&mut system, &mut messages);
890        }
891
892        Ok(Self {
893            model: model.to_string(),
894            messages,
895            max_tokens,
896            system,
897            temperature: req.temperature,
898            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
899            tools,
900            additional_params: req.additional_params,
901        })
902    }
903}
904
905impl<T> completion::CompletionModel for CompletionModel<T>
906where
907    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
908{
909    type Response = CompletionResponse;
910    type StreamingResponse = StreamingCompletionResponse;
911    type Client = Client<T>;
912
913    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
914        Self::new(client.clone(), model.into())
915    }
916
917    #[cfg_attr(feature = "worker", worker::send)]
918    async fn completion(
919        &self,
920        mut completion_request: completion::CompletionRequest,
921    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
922        let span = if tracing::Span::current().is_disabled() {
923            info_span!(
924                target: "rig::completions",
925                "chat",
926                gen_ai.operation.name = "chat",
927                gen_ai.provider.name = "anthropic",
928                gen_ai.request.model = &self.model,
929                gen_ai.system_instructions = &completion_request.preamble,
930                gen_ai.response.id = tracing::field::Empty,
931                gen_ai.response.model = tracing::field::Empty,
932                gen_ai.usage.output_tokens = tracing::field::Empty,
933                gen_ai.usage.input_tokens = tracing::field::Empty,
934            )
935        } else {
936            tracing::Span::current()
937        };
938
939        // Check if max_tokens is set, required for Anthropic
940        if completion_request.max_tokens.is_none() {
941            if let Some(tokens) = self.default_max_tokens {
942                completion_request.max_tokens = Some(tokens);
943            } else {
944                return Err(CompletionError::RequestError(
945                    "`max_tokens` must be set for Anthropic".into(),
946                ));
947            }
948        }
949
950        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
951            model: &self.model,
952            request: completion_request,
953            prompt_caching: self.prompt_caching,
954        })?;
955
956        if enabled!(Level::TRACE) {
957            tracing::trace!(
958                target: "rig::completions",
959                "Anthropic completion request: {}",
960                serde_json::to_string_pretty(&request)?
961            );
962        }
963
964        async move {
965            let request: Vec<u8> = serde_json::to_vec(&request)?;
966
967            let req = self
968                .client
969                .post("/v1/messages")?
970                .body(request)
971                .map_err(|e| CompletionError::HttpError(e.into()))?;
972
973            let response = self
974                .client
975                .send::<_, Bytes>(req)
976                .await
977                .map_err(CompletionError::HttpError)?;
978
979            if response.status().is_success() {
980                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
981                    response
982                        .into_body()
983                        .await
984                        .map_err(CompletionError::HttpError)?
985                        .to_vec()
986                        .as_slice(),
987                )? {
988                    ApiResponse::Message(completion) => {
989                        let span = tracing::Span::current();
990                        span.record_response_metadata(&completion);
991                        span.record_token_usage(&completion.usage);
992                        if enabled!(Level::TRACE) {
993                            tracing::trace!(
994                                target: "rig::completions",
995                                "Anthropic completion response: {}",
996                                serde_json::to_string_pretty(&completion)?
997                            );
998                        }
999                        completion.try_into()
1000                    }
1001                    ApiResponse::Error(ApiErrorResponse { message }) => {
1002                        Err(CompletionError::ResponseError(message))
1003                    }
1004                }
1005            } else {
1006                let text: String = String::from_utf8_lossy(
1007                    &response
1008                        .into_body()
1009                        .await
1010                        .map_err(CompletionError::HttpError)?,
1011                )
1012                .into();
1013                Err(CompletionError::ProviderError(text))
1014            }
1015        }
1016        .instrument(span)
1017        .await
1018    }
1019
1020    #[cfg_attr(feature = "worker", worker::send)]
1021    async fn stream(
1022        &self,
1023        request: CompletionRequest,
1024    ) -> Result<
1025        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1026        CompletionError,
1027    > {
1028        CompletionModel::stream(self, request).await
1029    }
1030}
1031
1032#[derive(Debug, Deserialize)]
1033struct ApiErrorResponse {
1034    message: String,
1035}
1036
1037#[derive(Debug, Deserialize)]
1038#[serde(tag = "type", rename_all = "snake_case")]
1039enum ApiResponse<T> {
1040    Message(T),
1041    Error(ApiErrorResponse),
1042}
1043
// Unit tests for Anthropic wire-type (de)serialization, conversions to and from the
// provider-agnostic `message` types, and prompt-cache breakpoint placement.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // Wraps serde deserialization so failures report the JSON path that broke.
    use serde_path_to_error::deserialize;

    /// Parses three wire-format messages (bare-string content, text + `tool_use` blocks,
    /// and image + text + `tool_result` blocks) and checks every parsed block.
    #[test]
    fn test_deserialize_message() {
        // `content` given as a plain string rather than an array of blocks.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // A bare-string `content` parses into a single text block with no cache control.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned(),
                cache_control: None,
            }
        );

        // Array content keeps block order: text first, then the tool call.
        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source, .. } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // A string tool_result `content` becomes a single text item; an absent
            // `is_error` stays `None`.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    /// Round-trips Anthropic messages through the provider-agnostic `message::Message`
    /// type and back, checking the intermediate representation along the way.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
                cache_control: None,
            }),
        };

        // Anthropic -> provider-agnostic conversion.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                // NOTE(review): the wire source is typed `base64`, yet the expected data is
                // `DocumentSourceKind::String` (unlike the image above) — confirm intended.
                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        // Tool results surface as user content in the provider-agnostic model.
        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original Anthropic messages exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    /// Checks `ContentFormat` <-> `SourceType` conversions in both directions, including
    /// rejection of the deprecated `ContentFormat::String`.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let result: Result<SourceType, _> = ContentFormat::String.try_into();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("ContentFormat::String is deprecated")
        );
    }

    /// Checks `cache_control` serialization (present when set, omitted when `None`) and that
    /// `apply_cache_control` marks the system prompt and only the last message.
    #[test]
    fn test_cache_control_serialization() {
        // Test SystemContent with cache_control
        let system = SystemContent::Text {
            text: "You are a helpful assistant.".to_string(),
            cache_control: Some(CacheControl::Ephemeral),
        };
        let json = serde_json::to_string(&system).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
        assert!(json.contains(r#""type":"text""#));

        // Test SystemContent without cache_control (should not have cache_control field)
        let system_no_cache = SystemContent::Text {
            text: "Hello".to_string(),
            cache_control: None,
        };
        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
        assert!(!json_no_cache.contains("cache_control"));

        // Test Content::Text with cache_control
        let content = Content::Text {
            text: "Test message".to_string(),
            cache_control: Some(CacheControl::Ephemeral),
        };
        let json_content = serde_json::to_string(&content).unwrap();
        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

        // Test apply_cache_control function
        let mut system_vec = vec![SystemContent::Text {
            text: "System prompt".to_string(),
            cache_control: None,
        }];
        let mut messages = vec![
            Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: "First message".to_string(),
                    cache_control: None,
                }),
            },
            Message {
                role: Role::Assistant,
                content: OneOrMany::one(Content::Text {
                    text: "Response".to_string(),
                    cache_control: None,
                }),
            },
        ];

        apply_cache_control(&mut system_vec, &mut messages);

        // System should have cache_control
        match &system_vec[0] {
            SystemContent::Text { cache_control, .. } => {
                assert!(cache_control.is_some());
            }
        }

        // Only the last content block of last message should have cache_control
        // First message should NOT have cache_control
        for content in messages[0].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_none());
            }
        }

        // Last message SHOULD have cache_control
        for content in messages[1].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_some());
            }
        }
    }
}