rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8    one_or_many::string_or_one_or_many,
9    telemetry::{ProviderResponseExt, SpanCombinator},
10    wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
// ================================================================
// Anthropic Completion API
// ================================================================

/// Model id for Claude Opus 4: `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model id for Claude Sonnet 4: `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model id alias for the latest Claude 3.7 Sonnet: `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model id alias for the latest Claude 3.5 Sonnet: `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model id alias for the latest Claude 3.5 Haiku: `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version string `2023-01-01`
/// (presumably sent as the `anthropic-version` header by the client — confirm there).
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version this module knows about; currently an alias of
/// [`ANTHROPIC_VERSION_2023_06_01`].
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
40#[derive(Debug, Deserialize, Serialize)]
41pub struct CompletionResponse {
42    pub content: Vec<Content>,
43    pub id: String,
44    pub model: String,
45    pub role: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51impl ProviderResponseExt for CompletionResponse {
52    type OutputMessage = Content;
53    type Usage = Usage;
54
55    fn get_response_id(&self) -> Option<String> {
56        Some(self.id.to_owned())
57    }
58
59    fn get_response_model_name(&self) -> Option<String> {
60        Some(self.model.to_owned())
61    }
62
63    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64        self.content.clone()
65    }
66
67    fn get_text_response(&self) -> Option<String> {
68        let res = self
69            .content
70            .iter()
71            .filter_map(|x| {
72                if let Content::Text { text, .. } = x {
73                    Some(text.to_owned())
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<String>>()
79            .join("\n");
80
81        if res.is_empty() { None } else { Some(res) }
82    }
83
84    fn get_usage(&self) -> Option<Self::Usage> {
85        Some(self.usage.clone())
86    }
87}
88
89#[derive(Clone, Debug, Deserialize, Serialize)]
90pub struct Usage {
91    pub input_tokens: u64,
92    pub cache_read_input_tokens: Option<u64>,
93    pub cache_creation_input_tokens: Option<u64>,
94    pub output_tokens: u64,
95}
96
97impl std::fmt::Display for Usage {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(
100            f,
101            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102            self.input_tokens,
103            match self.cache_read_input_tokens {
104                Some(token) => token.to_string(),
105                None => "n/a".to_string(),
106            },
107            match self.cache_creation_input_tokens {
108                Some(token) => token.to_string(),
109                None => "n/a".to_string(),
110            },
111            self.output_tokens
112        )
113    }
114}
115
116impl GetTokenUsage for Usage {
117    fn token_usage(&self) -> Option<crate::completion::Usage> {
118        let mut usage = crate::completion::Usage::new();
119
120        usage.input_tokens = self.input_tokens
121            + self.cache_creation_input_tokens.unwrap_or_default()
122            + self.cache_read_input_tokens.unwrap_or_default();
123        usage.output_tokens = self.output_tokens;
124        usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126        Some(usage)
127    }
128}
129
130#[derive(Debug, Deserialize, Serialize)]
131pub struct ToolDefinition {
132    pub name: String,
133    pub description: Option<String>,
134    pub input_schema: serde_json::Value,
135}
136
137/// Cache control directive for Anthropic prompt caching
138#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
139#[serde(tag = "type", rename_all = "snake_case")]
140pub enum CacheControl {
141    Ephemeral,
142}
143
144/// System message content block with optional cache control
145#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
146#[serde(tag = "type", rename_all = "snake_case")]
147pub enum SystemContent {
148    Text {
149        text: String,
150        #[serde(skip_serializing_if = "Option::is_none")]
151        cache_control: Option<CacheControl>,
152    },
153}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156    type Error = CompletionError;
157
158    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159        let content = response
160            .content
161            .iter()
162            .map(|content| content.clone().try_into())
163            .collect::<Result<Vec<_>, _>>()?;
164
165        let choice = OneOrMany::many(content).map_err(|_| {
166            CompletionError::ResponseError(
167                "Response contained no message or tool call (empty)".to_owned(),
168            )
169        })?;
170
171        let usage = completion::Usage {
172            input_tokens: response.usage.input_tokens,
173            output_tokens: response.usage.output_tokens,
174            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175        };
176
177        Ok(completion::CompletionResponse {
178            choice,
179            usage,
180            raw_response: response,
181        })
182    }
183}
184
185#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
186pub struct Message {
187    pub role: Role,
188    #[serde(deserialize_with = "string_or_one_or_many")]
189    pub content: OneOrMany<Content>,
190}
191
192#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
193#[serde(rename_all = "lowercase")]
194pub enum Role {
195    User,
196    Assistant,
197}
198
199#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
200#[serde(tag = "type", rename_all = "snake_case")]
201pub enum Content {
202    Text {
203        text: String,
204        #[serde(skip_serializing_if = "Option::is_none")]
205        cache_control: Option<CacheControl>,
206    },
207    Image {
208        source: ImageSource,
209        #[serde(skip_serializing_if = "Option::is_none")]
210        cache_control: Option<CacheControl>,
211    },
212    ToolUse {
213        id: String,
214        name: String,
215        input: serde_json::Value,
216    },
217    ToolResult {
218        tool_use_id: String,
219        #[serde(deserialize_with = "string_or_one_or_many")]
220        content: OneOrMany<ToolResultContent>,
221        #[serde(skip_serializing_if = "Option::is_none")]
222        is_error: Option<bool>,
223        #[serde(skip_serializing_if = "Option::is_none")]
224        cache_control: Option<CacheControl>,
225    },
226    Document {
227        source: DocumentSource,
228        #[serde(skip_serializing_if = "Option::is_none")]
229        cache_control: Option<CacheControl>,
230    },
231    Thinking {
232        thinking: String,
233        #[serde(skip_serializing_if = "Option::is_none")]
234        signature: Option<String>,
235    },
236}
237
238impl FromStr for Content {
239    type Err = Infallible;
240
241    fn from_str(s: &str) -> Result<Self, Self::Err> {
242        Ok(Content::Text {
243            text: s.to_owned(),
244            cache_control: None,
245        })
246    }
247}
248
249#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
250#[serde(tag = "type", rename_all = "snake_case")]
251pub enum ToolResultContent {
252    Text { text: String },
253    Image(ImageSource),
254}
255
256impl FromStr for ToolResultContent {
257    type Err = Infallible;
258
259    fn from_str(s: &str) -> Result<Self, Self::Err> {
260        Ok(ToolResultContent::Text { text: s.to_owned() })
261    }
262}
263
264#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
265#[serde(untagged)]
266pub enum ImageSourceData {
267    Base64(String),
268    Url(String),
269}
270
271impl From<ImageSourceData> for DocumentSourceKind {
272    fn from(value: ImageSourceData) -> Self {
273        match value {
274            ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
275            ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
276        }
277    }
278}
279
280impl TryFrom<DocumentSourceKind> for ImageSourceData {
281    type Error = MessageError;
282
283    fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
284        match value {
285            DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
286            DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
287            _ => Err(MessageError::ConversionError("Content has no body".into())),
288        }
289    }
290}
291
292impl From<ImageSourceData> for String {
293    fn from(value: ImageSourceData) -> Self {
294        match value {
295            ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
296        }
297    }
298}
299
300#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
301pub struct ImageSource {
302    pub data: ImageSourceData,
303    pub media_type: ImageFormat,
304    pub r#type: SourceType,
305}
306
307#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
308pub struct DocumentSource {
309    pub data: String,
310    pub media_type: DocumentFormat,
311    pub r#type: SourceType,
312}
313
314#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
315#[serde(rename_all = "lowercase")]
316pub enum ImageFormat {
317    #[serde(rename = "image/jpeg")]
318    JPEG,
319    #[serde(rename = "image/png")]
320    PNG,
321    #[serde(rename = "image/gif")]
322    GIF,
323    #[serde(rename = "image/webp")]
324    WEBP,
325}
326
327/// The document format to be used.
328///
329/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
330#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
331#[serde(rename_all = "lowercase")]
332pub enum DocumentFormat {
333    #[serde(rename = "application/pdf")]
334    PDF,
335}
336
337#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
338#[serde(rename_all = "lowercase")]
339pub enum SourceType {
340    BASE64,
341    URL,
342}
343
344impl From<String> for Content {
345    fn from(text: String) -> Self {
346        Content::Text {
347            text,
348            cache_control: None,
349        }
350    }
351}
352
353impl From<String> for ToolResultContent {
354    fn from(text: String) -> Self {
355        ToolResultContent::Text { text }
356    }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360    type Error = MessageError;
361
362    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363        match format {
364            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365            message::ContentFormat::Url => Ok(SourceType::URL),
366            message::ContentFormat::String => Err(MessageError::ConversionError(
367                "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368            )),
369        }
370    }
371}
372
373impl From<SourceType> for message::ContentFormat {
374    fn from(source_type: SourceType) -> Self {
375        match source_type {
376            SourceType::BASE64 => message::ContentFormat::Base64,
377            SourceType::URL => message::ContentFormat::Url,
378        }
379    }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383    type Error = MessageError;
384
385    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386        Ok(match media_type {
387            message::ImageMediaType::JPEG => ImageFormat::JPEG,
388            message::ImageMediaType::PNG => ImageFormat::PNG,
389            message::ImageMediaType::GIF => ImageFormat::GIF,
390            message::ImageMediaType::WEBP => ImageFormat::WEBP,
391            _ => {
392                return Err(MessageError::ConversionError(
393                    format!("Unsupported image media type: {media_type:?}").to_owned(),
394                ));
395            }
396        })
397    }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401    fn from(format: ImageFormat) -> Self {
402        match format {
403            ImageFormat::JPEG => message::ImageMediaType::JPEG,
404            ImageFormat::PNG => message::ImageMediaType::PNG,
405            ImageFormat::GIF => message::ImageMediaType::GIF,
406            ImageFormat::WEBP => message::ImageMediaType::WEBP,
407        }
408    }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412    type Error = MessageError;
413    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414        if !matches!(value, DocumentMediaType::PDF) {
415            return Err(MessageError::ConversionError(
416                "Anthropic only supports PDF documents".to_string(),
417            ));
418        };
419
420        Ok(DocumentFormat::PDF)
421    }
422}
423
424impl TryFrom<message::AssistantContent> for Content {
425    type Error = MessageError;
426    fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
427        match text {
428            message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
429                text,
430                cache_control: None,
431            }),
432            message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
433                "Anthropic currently doesn't support images.".to_string(),
434            )),
435            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
436                Ok(Content::ToolUse {
437                    id,
438                    name: function.name,
439                    input: function.arguments,
440                })
441            }
442            message::AssistantContent::Reasoning(Reasoning {
443                reasoning,
444                signature,
445                ..
446            }) => Ok(Content::Thinking {
447                thinking: reasoning.first().cloned().unwrap_or(String::new()),
448                signature,
449            }),
450        }
451    }
452}
453
454impl TryFrom<message::Message> for Message {
455    type Error = MessageError;
456
457    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
458        Ok(match message {
459            message::Message::User { content } => Message {
460                role: Role::User,
461                content: content.try_map(|content| match content {
462                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
463                        text,
464                        cache_control: None,
465                    }),
466                    message::UserContent::ToolResult(message::ToolResult {
467                        id, content, ..
468                    }) => Ok(Content::ToolResult {
469                        tool_use_id: id,
470                        content: content.try_map(|content| match content {
471                            message::ToolResultContent::Text(message::Text { text }) => {
472                                Ok(ToolResultContent::Text { text })
473                            }
474                            message::ToolResultContent::Image(image) => {
475                                let DocumentSourceKind::Base64(data) = image.data else {
476                                    return Err(MessageError::ConversionError(
477                                        "Only base64 strings can be used with the Anthropic API"
478                                            .to_string(),
479                                    ));
480                                };
481                                let media_type =
482                                    image.media_type.ok_or(MessageError::ConversionError(
483                                        "Image media type is required".to_owned(),
484                                    ))?;
485                                Ok(ToolResultContent::Image(ImageSource {
486                                    data: ImageSourceData::Base64(data),
487                                    media_type: media_type.try_into()?,
488                                    r#type: SourceType::BASE64,
489                                }))
490                            }
491                        })?,
492                        is_error: None,
493                        cache_control: None,
494                    }),
495                    message::UserContent::Image(message::Image {
496                        data, media_type, ..
497                    }) => {
498                        let media_type = media_type.ok_or(MessageError::ConversionError(
499                            "Image media type is required for Claude API".to_string(),
500                        ))?;
501
502                        let source = match data {
503                            DocumentSourceKind::Base64(data) => ImageSource {
504                                data: ImageSourceData::Base64(data),
505                                r#type: SourceType::BASE64,
506                                media_type: ImageFormat::try_from(media_type)?,
507                            },
508                            DocumentSourceKind::Url(url) => ImageSource {
509                                data: ImageSourceData::Url(url),
510                                r#type: SourceType::URL,
511                                media_type: ImageFormat::try_from(media_type)?,
512                            },
513                            DocumentSourceKind::Unknown => {
514                                return Err(MessageError::ConversionError(
515                                    "Image content has no body".into(),
516                                ));
517                            }
518                            doc => {
519                                return Err(MessageError::ConversionError(format!(
520                                    "Unsupported document type: {doc:?}"
521                                )));
522                            }
523                        };
524
525                        Ok(Content::Image {
526                            source,
527                            cache_control: None,
528                        })
529                    }
530                    message::UserContent::Document(message::Document {
531                        data, media_type, ..
532                    }) => {
533                        let media_type = media_type.ok_or(MessageError::ConversionError(
534                            "Document media type is required".to_string(),
535                        ))?;
536
537                        let data = match data {
538                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
539                                data
540                            }
541                            _ => {
542                                return Err(MessageError::ConversionError(
543                                    "Only base64 encoded documents currently supported".into(),
544                                ));
545                            }
546                        };
547
548                        let source = DocumentSource {
549                            data,
550                            media_type: media_type.try_into()?,
551                            r#type: SourceType::BASE64,
552                        };
553                        Ok(Content::Document {
554                            source,
555                            cache_control: None,
556                        })
557                    }
558                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
559                        "Audio is not supported in Anthropic".to_owned(),
560                    )),
561                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
562                        "Video is not supported in Anthropic".to_owned(),
563                    )),
564                })?,
565            },
566
567            message::Message::Assistant { content, .. } => Message {
568                content: content.try_map(|content| content.try_into())?,
569                role: Role::Assistant,
570            },
571        })
572    }
573}
574
575impl TryFrom<Content> for message::AssistantContent {
576    type Error = MessageError;
577
578    fn try_from(content: Content) -> Result<Self, Self::Error> {
579        Ok(match content {
580            Content::Text { text, .. } => message::AssistantContent::text(text),
581            Content::ToolUse { id, name, input } => {
582                message::AssistantContent::tool_call(id, name, input)
583            }
584            Content::Thinking {
585                thinking,
586                signature,
587            } => message::AssistantContent::Reasoning(
588                Reasoning::new(&thinking).with_signature(signature),
589            ),
590            _ => {
591                return Err(MessageError::ConversionError(
592                    "Content did not contain a message, tool call, or reasoning".to_owned(),
593                ));
594            }
595        })
596    }
597}
598
599impl From<ToolResultContent> for message::ToolResultContent {
600    fn from(content: ToolResultContent) -> Self {
601        match content {
602            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
603            ToolResultContent::Image(ImageSource {
604                data,
605                media_type: format,
606                ..
607            }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
608        }
609    }
610}
611
612impl TryFrom<Message> for message::Message {
613    type Error = MessageError;
614
615    fn try_from(message: Message) -> Result<Self, Self::Error> {
616        Ok(match message.role {
617            Role::User => message::Message::User {
618                content: message.content.try_map(|content| {
619                    Ok(match content {
620                        Content::Text { text, .. } => message::UserContent::text(text),
621                        Content::ToolResult {
622                            tool_use_id,
623                            content,
624                            ..
625                        } => message::UserContent::tool_result(
626                            tool_use_id,
627                            content.map(|content| content.into()),
628                        ),
629                        Content::Image { source, .. } => {
630                            message::UserContent::Image(message::Image {
631                                data: source.data.into(),
632                                media_type: Some(source.media_type.into()),
633                                detail: None,
634                                additional_params: None,
635                            })
636                        }
637                        Content::Document { source, .. } => message::UserContent::document(
638                            source.data,
639                            Some(message::DocumentMediaType::PDF),
640                        ),
641                        _ => {
642                            return Err(MessageError::ConversionError(
643                                "Unsupported content type for User role".to_owned(),
644                            ));
645                        }
646                    })
647                })?,
648            },
649            Role::Assistant => match message.content.first() {
650                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
651                    message::Message::Assistant {
652                        id: None,
653                        content: message.content.try_map(|content| content.try_into())?,
654                    }
655                }
656
657                _ => {
658                    return Err(MessageError::ConversionError(
659                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
660                    ));
661                }
662            },
663        })
664    }
665}
666
667#[derive(Clone)]
668pub struct CompletionModel<T = reqwest::Client> {
669    pub(crate) client: Client<T>,
670    pub model: String,
671    pub default_max_tokens: Option<u64>,
672    /// Enable automatic prompt caching (adds cache_control breakpoints to system prompt and messages)
673    pub prompt_caching: bool,
674}
675
676impl<T> CompletionModel<T>
677where
678    T: HttpClientExt,
679{
680    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
681        let model = model.into();
682        let default_max_tokens = calculate_max_tokens(&model);
683
684        Self {
685            client,
686            model,
687            default_max_tokens,
688            prompt_caching: false, // Default to off
689        }
690    }
691
692    pub fn with_model(client: Client<T>, model: &str) -> Self {
693        Self {
694            client,
695            model: model.to_string(),
696            default_max_tokens: Some(calculate_max_tokens_custom(model)),
697            prompt_caching: false, // Default to off
698        }
699    }
700
701    /// Enable automatic prompt caching.
702    ///
703    /// When enabled, cache_control breakpoints are automatically added to:
704    /// - The system prompt (marked with ephemeral cache)
705    /// - The last content block of the last message (marked with ephemeral cache)
706    ///
707    /// This allows Anthropic to cache the conversation history for cost savings.
708    pub fn with_prompt_caching(mut self) -> Self {
709        self.prompt_caching = true;
710        self
711    }
712}
713
/// Anthropic requires a `max_tokens` parameter, and the allowed maximum depends on the
/// model; omitting it or setting it too high makes the request fail. Returns the maximum
/// for a recognised model-name prefix (values as of the time of writing), or `None` when
/// the model is not recognised.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Checked in order; the first matching prefix wins. Keeping a single table ensures
    // `calculate_max_tokens_custom` cannot drift out of sync with this function.
    const LIMITS: &[(&str, u64)] = &[
        ("claude-opus-4", 32000),
        ("claude-sonnet-4", 64000),
        ("claude-3-7-sonnet", 64000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    LIMITS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}

/// Same as [`calculate_max_tokens`], but falls back to a conservative 2048 tokens for
/// unrecognised (e.g. custom or newer) model names instead of returning `None`.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    const UNKNOWN_MODEL_MAX_TOKENS: u64 = 2048;
    calculate_max_tokens(model).unwrap_or(UNKNOWN_MODEL_MAX_TOKENS)
}
750
751#[derive(Debug, Deserialize, Serialize)]
752pub struct Metadata {
753    user_id: Option<String>,
754}
755
756#[derive(Default, Debug, Serialize, Deserialize)]
757#[serde(tag = "type", rename_all = "snake_case")]
758pub enum ToolChoice {
759    #[default]
760    Auto,
761    Any,
762    None,
763    Tool {
764        name: String,
765    },
766}
767impl TryFrom<message::ToolChoice> for ToolChoice {
768    type Error = CompletionError;
769
770    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
771        let res = match value {
772            message::ToolChoice::Auto => Self::Auto,
773            message::ToolChoice::None => Self::None,
774            message::ToolChoice::Required => Self::Any,
775            message::ToolChoice::Specific { function_names } => {
776                if function_names.len() != 1 {
777                    return Err(CompletionError::ProviderError(
778                        "Only one tool may be specified to be used by Claude".into(),
779                    ));
780                }
781
782                Self::Tool {
783                    name: function_names.first().unwrap().to_string(),
784                }
785            }
786        };
787
788        Ok(res)
789    }
790}
791
792#[derive(Debug, Deserialize, Serialize)]
793struct AnthropicCompletionRequest {
794    model: String,
795    messages: Vec<Message>,
796    max_tokens: u64,
797    /// System prompt as array of content blocks to support cache_control
798    #[serde(skip_serializing_if = "Vec::is_empty")]
799    system: Vec<SystemContent>,
800    #[serde(skip_serializing_if = "Option::is_none")]
801    temperature: Option<f64>,
802    #[serde(skip_serializing_if = "Option::is_none")]
803    tool_choice: Option<ToolChoice>,
804    #[serde(skip_serializing_if = "Vec::is_empty")]
805    tools: Vec<ToolDefinition>,
806    #[serde(flatten, skip_serializing_if = "Option::is_none")]
807    additional_params: Option<serde_json::Value>,
808}
809
810/// Helper to set cache_control on a Content block
811fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
812    match content {
813        Content::Text { cache_control, .. } => *cache_control = value,
814        Content::Image { cache_control, .. } => *cache_control = value,
815        Content::ToolResult { cache_control, .. } => *cache_control = value,
816        Content::Document { cache_control, .. } => *cache_control = value,
817        _ => {}
818    }
819}
820
821/// Apply cache control breakpoints to system prompt and messages.
822/// Strategy: cache the system prompt, and mark the last content block of the last message
823/// for caching. This allows the conversation history to be cached while new messages
824/// are added.
825pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
826    // Add cache_control to the system prompt (if non-empty)
827    if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
828        *cache_control = Some(CacheControl::Ephemeral);
829    }
830
831    // Clear any existing cache_control from all message content blocks
832    for msg in messages.iter_mut() {
833        for content in msg.content.iter_mut() {
834            set_content_cache_control(content, None);
835        }
836    }
837
838    // Add cache_control to the last content block of the last message
839    if let Some(last_msg) = messages.last_mut() {
840        set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
841    }
842}
843
/// Parameters for building an AnthropicCompletionRequest
///
/// Bundles the provider-agnostic [`CompletionRequest`] with the Anthropic-specific
/// settings that the `TryFrom` conversion into `AnthropicCompletionRequest` needs.
pub struct AnthropicRequestParams<'a> {
    /// Model identifier; copied into the request's `model` field.
    pub model: &'a str,
    /// The generic completion request to convert (consumed by the conversion).
    pub request: CompletionRequest,
    /// When `true`, `apply_cache_control` is run on the system prompt and messages
    /// to insert `cache_control` breakpoints.
    pub prompt_caching: bool,
}
850
851impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
852    type Error = CompletionError;
853
854    fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
855        let AnthropicRequestParams {
856            model,
857            request: req,
858            prompt_caching,
859        } = params;
860
861        // Check if max_tokens is set, required for Anthropic
862        let Some(max_tokens) = req.max_tokens else {
863            return Err(CompletionError::RequestError(
864                "`max_tokens` must be set for Anthropic".into(),
865            ));
866        };
867
868        let mut full_history = vec![];
869        if let Some(docs) = req.normalized_documents() {
870            full_history.push(docs);
871        }
872        full_history.extend(req.chat_history);
873
874        let mut messages = full_history
875            .into_iter()
876            .map(Message::try_from)
877            .collect::<Result<Vec<Message>, _>>()?;
878
879        let tools = req
880            .tools
881            .into_iter()
882            .map(|tool| ToolDefinition {
883                name: tool.name,
884                description: Some(tool.description),
885                input_schema: tool.parameters,
886            })
887            .collect::<Vec<_>>();
888
889        // Convert system prompt to array format for cache_control support
890        let mut system = if let Some(preamble) = req.preamble {
891            if preamble.is_empty() {
892                vec![]
893            } else {
894                vec![SystemContent::Text {
895                    text: preamble,
896                    cache_control: None,
897                }]
898            }
899        } else {
900            vec![]
901        };
902
903        // Apply cache control breakpoints only if prompt_caching is enabled
904        if prompt_caching {
905            apply_cache_control(&mut system, &mut messages);
906        }
907
908        Ok(Self {
909            model: model.to_string(),
910            messages,
911            max_tokens,
912            system,
913            temperature: req.temperature,
914            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
915            tools,
916            additional_params: req.additional_params,
917        })
918    }
919}
920
921impl<T> completion::CompletionModel for CompletionModel<T>
922where
923    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
924{
925    type Response = CompletionResponse;
926    type StreamingResponse = StreamingCompletionResponse;
927    type Client = Client<T>;
928
929    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
930        Self::new(client.clone(), model.into())
931    }
932
933    async fn completion(
934        &self,
935        mut completion_request: completion::CompletionRequest,
936    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
937        let span = if tracing::Span::current().is_disabled() {
938            info_span!(
939                target: "rig::completions",
940                "chat",
941                gen_ai.operation.name = "chat",
942                gen_ai.provider.name = "anthropic",
943                gen_ai.request.model = &self.model,
944                gen_ai.system_instructions = &completion_request.preamble,
945                gen_ai.response.id = tracing::field::Empty,
946                gen_ai.response.model = tracing::field::Empty,
947                gen_ai.usage.output_tokens = tracing::field::Empty,
948                gen_ai.usage.input_tokens = tracing::field::Empty,
949            )
950        } else {
951            tracing::Span::current()
952        };
953
954        // Check if max_tokens is set, required for Anthropic
955        if completion_request.max_tokens.is_none() {
956            if let Some(tokens) = self.default_max_tokens {
957                completion_request.max_tokens = Some(tokens);
958            } else {
959                return Err(CompletionError::RequestError(
960                    "`max_tokens` must be set for Anthropic".into(),
961                ));
962            }
963        }
964
965        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
966            model: &self.model,
967            request: completion_request,
968            prompt_caching: self.prompt_caching,
969        })?;
970
971        if enabled!(Level::TRACE) {
972            tracing::trace!(
973                target: "rig::completions",
974                "Anthropic completion request: {}",
975                serde_json::to_string_pretty(&request)?
976            );
977        }
978
979        async move {
980            let request: Vec<u8> = serde_json::to_vec(&request)?;
981
982            let req = self
983                .client
984                .post("/v1/messages")?
985                .body(request)
986                .map_err(|e| CompletionError::HttpError(e.into()))?;
987
988            let response = self
989                .client
990                .send::<_, Bytes>(req)
991                .await
992                .map_err(CompletionError::HttpError)?;
993
994            if response.status().is_success() {
995                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
996                    response
997                        .into_body()
998                        .await
999                        .map_err(CompletionError::HttpError)?
1000                        .to_vec()
1001                        .as_slice(),
1002                )? {
1003                    ApiResponse::Message(completion) => {
1004                        let span = tracing::Span::current();
1005                        span.record_response_metadata(&completion);
1006                        span.record_token_usage(&completion.usage);
1007                        if enabled!(Level::TRACE) {
1008                            tracing::trace!(
1009                                target: "rig::completions",
1010                                "Anthropic completion response: {}",
1011                                serde_json::to_string_pretty(&completion)?
1012                            );
1013                        }
1014                        completion.try_into()
1015                    }
1016                    ApiResponse::Error(ApiErrorResponse { message }) => {
1017                        Err(CompletionError::ResponseError(message))
1018                    }
1019                }
1020            } else {
1021                let text: String = String::from_utf8_lossy(
1022                    &response
1023                        .into_body()
1024                        .await
1025                        .map_err(CompletionError::HttpError)?,
1026                )
1027                .into();
1028                Err(CompletionError::ProviderError(text))
1029            }
1030        }
1031        .instrument(span)
1032        .await
1033    }
1034
1035    async fn stream(
1036        &self,
1037        request: CompletionRequest,
1038    ) -> Result<
1039        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1040        CompletionError,
1041    > {
1042        CompletionModel::stream(self, request).await
1043    }
1044}
1045
1046#[derive(Debug, Deserialize)]
1047struct ApiErrorResponse {
1048    message: String,
1049}
1050
/// Top-level body of a `/v1/messages` response, dispatched on its `"type"` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// `"type": "message"` — the body is deserialized as `T`.
    Message(T),
    /// `"type": "error"` — the body is deserialized as an [`ApiErrorResponse`].
    Error(ApiErrorResponse),
}
1057
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use serde_path_to_error::deserialize;

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned(),
                cache_control: None,
            }
        );

        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source, .. } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
                cache_control: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let result: Result<SourceType, _> = ContentFormat::String.try_into();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("ContentFormat::String is deprecated")
        );
    }

    #[test]
    fn test_cache_control_serialization() {
        // Test SystemContent with cache_control
        let system = SystemContent::Text {
            text: "You are a helpful assistant.".to_string(),
            cache_control: Some(CacheControl::Ephemeral),
        };
        let json = serde_json::to_string(&system).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
        assert!(json.contains(r#""type":"text""#));

        // Test SystemContent without cache_control (should not have cache_control field)
        let system_no_cache = SystemContent::Text {
            text: "Hello".to_string(),
            cache_control: None,
        };
        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
        assert!(!json_no_cache.contains("cache_control"));

        // Test Content::Text with cache_control
        let content = Content::Text {
            text: "Test message".to_string(),
            cache_control: Some(CacheControl::Ephemeral),
        };
        let json_content = serde_json::to_string(&content).unwrap();
        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

        // Test apply_cache_control function
        let mut system_vec = vec![SystemContent::Text {
            text: "System prompt".to_string(),
            cache_control: None,
        }];
        let mut messages = vec![
            Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: "First message".to_string(),
                    cache_control: None,
                }),
            },
            Message {
                role: Role::Assistant,
                content: OneOrMany::one(Content::Text {
                    text: "Response".to_string(),
                    cache_control: None,
                }),
            },
        ];

        apply_cache_control(&mut system_vec, &mut messages);

        // System should have cache_control
        match &system_vec[0] {
            SystemContent::Text { cache_control, .. } => {
                assert!(cache_control.is_some());
            }
        }

        // Only the last content block of the last message should have cache_control.
        // Whole blocks are compared so the check cannot pass vacuously if the variant changes.
        assert_eq!(
            messages[0].content.first(),
            Content::Text {
                text: "First message".to_string(),
                cache_control: None,
            }
        );
        assert_eq!(
            messages[1].content.first(),
            Content::Text {
                text: "Response".to_string(),
                cache_control: Some(CacheControl::Ephemeral),
            }
        );
    }

    #[test]
    fn test_apply_cache_control_clears_stale_markers() {
        // An empty system prompt must be tolerated and left empty.
        let mut system_vec: Vec<SystemContent> = Vec::new();
        let mut messages = vec![
            Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: "Earlier turn".to_string(),
                    cache_control: Some(CacheControl::Ephemeral),
                }),
            },
            Message {
                role: Role::User,
                content: OneOrMany::many(vec![
                    Content::Text {
                        text: "a".to_string(),
                        cache_control: Some(CacheControl::Ephemeral),
                    },
                    Content::Text {
                        text: "b".to_string(),
                        cache_control: None,
                    },
                ])
                .expect("content list is non-empty"),
            },
        ];

        apply_cache_control(&mut system_vec, &mut messages);

        assert!(system_vec.is_empty());

        // Stale breakpoint on an earlier turn is removed.
        assert_eq!(
            messages[0].content.first(),
            Content::Text {
                text: "Earlier turn".to_string(),
                cache_control: None,
            }
        );

        // Within the last message, only the final block is marked.
        let last_blocks: Vec<Content> = messages[1].content.iter().cloned().collect();
        assert_eq!(
            last_blocks,
            vec![
                Content::Text {
                    text: "a".to_string(),
                    cache_control: None,
                },
                Content::Text {
                    text: "b".to_string(),
                    cache_control: Some(CacheControl::Ephemeral),
                },
            ]
        );
    }
}