Skip to main content

rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8    one_or_many::string_or_one_or_many,
9    telemetry::{ProviderResponseExt, SpanCombinator},
10    wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
21// ================================================================
22// Anthropic Completion API
23// ================================================================
24
/// `claude-opus-4-0` completion model
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// `claude-sonnet-4-0` completion model
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Version identifier `2023-01-01`.
// NOTE(review): presumably the value sent as the `anthropic-version` header — confirm in the client.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Version identifier `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest version identifier defined here (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Body of a completion response, (de)serializable with serde.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks carried by the response.
    pub content: Vec<Content>,
    /// Response identifier (returned by `get_response_id`).
    pub id: String,
    /// Model name (returned by `get_response_model_name`).
    pub model: String,
    /// Role string exactly as it appears in the payload.
    pub role: String,
    /// Optional stop reason; `None` when absent or null.
    pub stop_reason: Option<String>,
    /// Optional stop sequence; `None` when absent or null.
    pub stop_sequence: Option<String>,
    /// Token counters for this call.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52    type OutputMessage = Content;
53    type Usage = Usage;
54
55    fn get_response_id(&self) -> Option<String> {
56        Some(self.id.to_owned())
57    }
58
59    fn get_response_model_name(&self) -> Option<String> {
60        Some(self.model.to_owned())
61    }
62
63    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64        self.content.clone()
65    }
66
67    fn get_text_response(&self) -> Option<String> {
68        let res = self
69            .content
70            .iter()
71            .filter_map(|x| {
72                if let Content::Text { text, .. } = x {
73                    Some(text.to_owned())
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<String>>()
79            .join("\n");
80
81        if res.is_empty() { None } else { Some(res) }
82    }
83
84    fn get_usage(&self) -> Option<Self::Usage> {
85        Some(self.usage.clone())
86    }
87}
88
/// Token counters reported with a response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input token count as reported in the payload.
    pub input_tokens: u64,
    /// Cache-read input tokens; `None` when the payload omits the field.
    pub cache_read_input_tokens: Option<u64>,
    /// Cache-creation input tokens; `None` when the payload omits the field.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output token count as reported in the payload.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(
100            f,
101            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102            self.input_tokens,
103            match self.cache_read_input_tokens {
104                Some(token) => token.to_string(),
105                None => "n/a".to_string(),
106            },
107            match self.cache_creation_input_tokens {
108                Some(token) => token.to_string(),
109                None => "n/a".to_string(),
110            },
111            self.output_tokens
112        )
113    }
114}
115
116impl GetTokenUsage for Usage {
117    fn token_usage(&self) -> Option<crate::completion::Usage> {
118        let mut usage = crate::completion::Usage::new();
119
120        usage.input_tokens = self.input_tokens
121            + self.cache_creation_input_tokens.unwrap_or_default()
122            + self.cache_read_input_tokens.unwrap_or_default();
123        usage.output_tokens = self.output_tokens;
124        usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126        Some(usage)
127    }
128}
129
/// Tool description in the shape serialized into the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Arbitrary JSON value describing the tool input (presumably a JSON Schema — confirm).
    pub input_schema: serde_json::Value,
}
136
/// Cache control directive for Anthropic prompt caching.
///
/// Internally tagged with `type` in snake_case, so `Ephemeral` serializes as
/// `{"type": "ephemeral"}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// The only directive defined here.
    Ephemeral,
}
143
/// System message content block with optional cache control.
///
/// Internally tagged with `type` in snake_case (`Text` becomes `"text"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    /// A plain-text system block.
    Text {
        /// The system prompt text.
        text: String,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156    type Error = CompletionError;
157
158    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159        let content = response
160            .content
161            .iter()
162            .map(|content| content.clone().try_into())
163            .collect::<Result<Vec<_>, _>>()?;
164
165        let choice = OneOrMany::many(content).map_err(|_| {
166            CompletionError::ResponseError(
167                "Response contained no message or tool call (empty)".to_owned(),
168            )
169        })?;
170
171        let usage = completion::Usage {
172            input_tokens: response.usage.input_tokens,
173            output_tokens: response.usage.output_tokens,
174            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175            cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
176        };
177
178        Ok(completion::CompletionResponse {
179            choice,
180            usage,
181            raw_response: response,
182        })
183    }
184}
185
/// A single conversation message in the Anthropic wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// One or more content blocks; deserialized through `string_or_one_or_many`
    /// (presumably accepting a bare string as well as blocks — confirm in `one_or_many`).
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
192
/// Message author; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
199
/// A content block inside a [`Message`] or a response.
///
/// Internally tagged with `type` in snake_case (`text`, `image`, `tool_use`,
/// `tool_result`, `document`, `thinking`). Every `Option` field marked with
/// `skip_serializing_if` is omitted from the output when `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image described by an [`ImageSource`].
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation: call id, tool name and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, linked back through `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document described by a [`DocumentSource`].
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Model reasoning text with an optional signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
}
238
239impl FromStr for Content {
240    type Err = Infallible;
241
242    fn from_str(s: &str) -> Result<Self, Self::Err> {
243        Ok(Content::Text {
244            text: s.to_owned(),
245            cache_control: None,
246        })
247    }
248}
249
/// A single piece of a tool result: text or an image.
///
/// Internally tagged with `type` in snake_case.
// NOTE(review): `Image` wraps an `ImageSource`, which itself declares a `type` field, while
// this enum is also tagged with `type` — confirm the serialized shape is what the API expects.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    Text { text: String },
    Image(ImageSource),
}
256
257impl FromStr for ToolResultContent {
258    type Err = Infallible;
259
260    fn from_str(s: &str) -> Result<Self, Self::Err> {
261        Ok(ToolResultContent::Text { text: s.to_owned() })
262    }
263}
264
/// Image payload: base64 data or a URL, both carried as a plain `String`.
///
/// Untagged, so either variant serializes as a bare string.
// NOTE(review): both variants hold a `String`, so untagged deserialization presumably always
// yields `Base64` (the first variant) — confirm that readers rely on `ImageSource::type` instead.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ImageSourceData {
    Base64(String),
    Url(String),
}
271
272impl From<ImageSourceData> for DocumentSourceKind {
273    fn from(value: ImageSourceData) -> Self {
274        match value {
275            ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
276            ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
277        }
278    }
279}
280
281impl TryFrom<DocumentSourceKind> for ImageSourceData {
282    type Error = MessageError;
283
284    fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
285        match value {
286            DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
287            DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
288            _ => Err(MessageError::ConversionError("Content has no body".into())),
289        }
290    }
291}
292
293impl From<ImageSourceData> for String {
294    fn from(value: ImageSourceData) -> Self {
295        match value {
296            ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
297        }
298    }
299}
300
/// Where an image comes from and how it is encoded.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// The payload itself (base64 data or a URL).
    pub data: ImageSourceData,
    /// MIME type of the image.
    pub media_type: ImageFormat,
    /// Whether `data` is base64 or a URL; serialized under the key `type`.
    pub r#type: SourceType,
}
307
/// Where a document comes from and how it is encoded.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// The document payload as a string.
    pub data: String,
    /// MIME type of the document.
    pub media_type: DocumentFormat,
    /// Encoding of `data`; serialized under the key `type`.
    pub r#type: SourceType,
}
314
/// Image MIME types understood by this module.
///
/// Each variant carries an explicit `rename`, so it serializes as the full MIME string
/// (for example `image/png`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
327
/// The document format to be used.
///
/// Currently, Anthropic only supports PDF for text documents over the API (within a message).
/// More information: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
///
/// `PDF` serializes as `application/pdf` through its explicit `rename`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
337
/// How a source payload is delivered; serialized in lowercase (`"base64"` / `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
}
344
345impl From<String> for Content {
346    fn from(text: String) -> Self {
347        Content::Text {
348            text,
349            cache_control: None,
350        }
351    }
352}
353
354impl From<String> for ToolResultContent {
355    fn from(text: String) -> Self {
356        ToolResultContent::Text { text }
357    }
358}
359
360impl TryFrom<message::ContentFormat> for SourceType {
361    type Error = MessageError;
362
363    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
364        match format {
365            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
366            message::ContentFormat::Url => Ok(SourceType::URL),
367            message::ContentFormat::String => Err(MessageError::ConversionError(
368                "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
369            )),
370        }
371    }
372}
373
374impl From<SourceType> for message::ContentFormat {
375    fn from(source_type: SourceType) -> Self {
376        match source_type {
377            SourceType::BASE64 => message::ContentFormat::Base64,
378            SourceType::URL => message::ContentFormat::Url,
379        }
380    }
381}
382
383impl TryFrom<message::ImageMediaType> for ImageFormat {
384    type Error = MessageError;
385
386    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
387        Ok(match media_type {
388            message::ImageMediaType::JPEG => ImageFormat::JPEG,
389            message::ImageMediaType::PNG => ImageFormat::PNG,
390            message::ImageMediaType::GIF => ImageFormat::GIF,
391            message::ImageMediaType::WEBP => ImageFormat::WEBP,
392            _ => {
393                return Err(MessageError::ConversionError(
394                    format!("Unsupported image media type: {media_type:?}").to_owned(),
395                ));
396            }
397        })
398    }
399}
400
401impl From<ImageFormat> for message::ImageMediaType {
402    fn from(format: ImageFormat) -> Self {
403        match format {
404            ImageFormat::JPEG => message::ImageMediaType::JPEG,
405            ImageFormat::PNG => message::ImageMediaType::PNG,
406            ImageFormat::GIF => message::ImageMediaType::GIF,
407            ImageFormat::WEBP => message::ImageMediaType::WEBP,
408        }
409    }
410}
411
412impl TryFrom<DocumentMediaType> for DocumentFormat {
413    type Error = MessageError;
414    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
415        if !matches!(value, DocumentMediaType::PDF) {
416            return Err(MessageError::ConversionError(
417                "Anthropic only supports PDF documents".to_string(),
418            ));
419        };
420
421        Ok(DocumentFormat::PDF)
422    }
423}
424
425impl TryFrom<message::AssistantContent> for Content {
426    type Error = MessageError;
427    fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
428        match text {
429            message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
430                text,
431                cache_control: None,
432            }),
433            message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
434                "Anthropic currently doesn't support images.".to_string(),
435            )),
436            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
437                Ok(Content::ToolUse {
438                    id,
439                    name: function.name,
440                    input: function.arguments,
441                })
442            }
443            message::AssistantContent::Reasoning(Reasoning {
444                reasoning,
445                signature,
446                ..
447            }) => Ok(Content::Thinking {
448                thinking: reasoning.first().cloned().unwrap_or(String::new()),
449                signature,
450            }),
451        }
452    }
453}
454
455impl TryFrom<message::Message> for Message {
456    type Error = MessageError;
457
458    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
459        Ok(match message {
460            message::Message::User { content } => Message {
461                role: Role::User,
462                content: content.try_map(|content| match content {
463                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
464                        text,
465                        cache_control: None,
466                    }),
467                    message::UserContent::ToolResult(message::ToolResult {
468                        id, content, ..
469                    }) => Ok(Content::ToolResult {
470                        tool_use_id: id,
471                        content: content.try_map(|content| match content {
472                            message::ToolResultContent::Text(message::Text { text }) => {
473                                Ok(ToolResultContent::Text { text })
474                            }
475                            message::ToolResultContent::Image(image) => {
476                                let DocumentSourceKind::Base64(data) = image.data else {
477                                    return Err(MessageError::ConversionError(
478                                        "Only base64 strings can be used with the Anthropic API"
479                                            .to_string(),
480                                    ));
481                                };
482                                let media_type =
483                                    image.media_type.ok_or(MessageError::ConversionError(
484                                        "Image media type is required".to_owned(),
485                                    ))?;
486                                Ok(ToolResultContent::Image(ImageSource {
487                                    data: ImageSourceData::Base64(data),
488                                    media_type: media_type.try_into()?,
489                                    r#type: SourceType::BASE64,
490                                }))
491                            }
492                        })?,
493                        is_error: None,
494                        cache_control: None,
495                    }),
496                    message::UserContent::Image(message::Image {
497                        data, media_type, ..
498                    }) => {
499                        let media_type = media_type.ok_or(MessageError::ConversionError(
500                            "Image media type is required for Claude API".to_string(),
501                        ))?;
502
503                        let source = match data {
504                            DocumentSourceKind::Base64(data) => ImageSource {
505                                data: ImageSourceData::Base64(data),
506                                r#type: SourceType::BASE64,
507                                media_type: ImageFormat::try_from(media_type)?,
508                            },
509                            DocumentSourceKind::Url(url) => ImageSource {
510                                data: ImageSourceData::Url(url),
511                                r#type: SourceType::URL,
512                                media_type: ImageFormat::try_from(media_type)?,
513                            },
514                            DocumentSourceKind::Unknown => {
515                                return Err(MessageError::ConversionError(
516                                    "Image content has no body".into(),
517                                ));
518                            }
519                            doc => {
520                                return Err(MessageError::ConversionError(format!(
521                                    "Unsupported document type: {doc:?}"
522                                )));
523                            }
524                        };
525
526                        Ok(Content::Image {
527                            source,
528                            cache_control: None,
529                        })
530                    }
531                    message::UserContent::Document(message::Document {
532                        data, media_type, ..
533                    }) => {
534                        let media_type = media_type.ok_or(MessageError::ConversionError(
535                            "Document media type is required".to_string(),
536                        ))?;
537
538                        let data = match data {
539                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
540                                data
541                            }
542                            _ => {
543                                return Err(MessageError::ConversionError(
544                                    "Only base64 encoded documents currently supported".into(),
545                                ));
546                            }
547                        };
548
549                        let source = DocumentSource {
550                            data,
551                            media_type: media_type.try_into()?,
552                            r#type: SourceType::BASE64,
553                        };
554                        Ok(Content::Document {
555                            source,
556                            cache_control: None,
557                        })
558                    }
559                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
560                        "Audio is not supported in Anthropic".to_owned(),
561                    )),
562                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
563                        "Video is not supported in Anthropic".to_owned(),
564                    )),
565                })?,
566            },
567
568            message::Message::Assistant { content, .. } => Message {
569                content: content.try_map(|content| content.try_into())?,
570                role: Role::Assistant,
571            },
572        })
573    }
574}
575
576impl TryFrom<Content> for message::AssistantContent {
577    type Error = MessageError;
578
579    fn try_from(content: Content) -> Result<Self, Self::Error> {
580        Ok(match content {
581            Content::Text { text, .. } => message::AssistantContent::text(text),
582            Content::ToolUse { id, name, input } => {
583                message::AssistantContent::tool_call(id, name, input)
584            }
585            Content::Thinking {
586                thinking,
587                signature,
588            } => message::AssistantContent::Reasoning(
589                Reasoning::new(&thinking).with_signature(signature),
590            ),
591            _ => {
592                return Err(MessageError::ConversionError(
593                    "Content did not contain a message, tool call, or reasoning".to_owned(),
594                ));
595            }
596        })
597    }
598}
599
600impl From<ToolResultContent> for message::ToolResultContent {
601    fn from(content: ToolResultContent) -> Self {
602        match content {
603            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
604            ToolResultContent::Image(ImageSource {
605                data,
606                media_type: format,
607                ..
608            }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
609        }
610    }
611}
612
613impl TryFrom<Message> for message::Message {
614    type Error = MessageError;
615
616    fn try_from(message: Message) -> Result<Self, Self::Error> {
617        Ok(match message.role {
618            Role::User => message::Message::User {
619                content: message.content.try_map(|content| {
620                    Ok(match content {
621                        Content::Text { text, .. } => message::UserContent::text(text),
622                        Content::ToolResult {
623                            tool_use_id,
624                            content,
625                            ..
626                        } => message::UserContent::tool_result(
627                            tool_use_id,
628                            content.map(|content| content.into()),
629                        ),
630                        Content::Image { source, .. } => {
631                            message::UserContent::Image(message::Image {
632                                data: source.data.into(),
633                                media_type: Some(source.media_type.into()),
634                                detail: None,
635                                additional_params: None,
636                            })
637                        }
638                        Content::Document { source, .. } => message::UserContent::document(
639                            source.data,
640                            Some(message::DocumentMediaType::PDF),
641                        ),
642                        _ => {
643                            return Err(MessageError::ConversionError(
644                                "Unsupported content type for User role".to_owned(),
645                            ));
646                        }
647                    })
648                })?,
649            },
650            Role::Assistant => match message.content.first() {
651                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
652                    message::Message::Assistant {
653                        id: None,
654                        content: message.content.try_map(|content| content.try_into())?,
655                    }
656                }
657
658                _ => {
659                    return Err(MessageError::ConversionError(
660                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
661                    ));
662                }
663            },
664        })
665    }
666}
667
/// A completion model bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client this model is bound to.
    pub(crate) client: Client<T>,
    /// Model name as passed to the constructor.
    pub model: String,
    /// Output-token ceiling derived from the model name at construction time;
    /// `None` when `new` does not recognize the model.
    pub default_max_tokens: Option<u64>,
    /// Enable automatic prompt caching (adds cache_control breakpoints to system prompt and messages)
    pub prompt_caching: bool,
}
676
677impl<T> CompletionModel<T>
678where
679    T: HttpClientExt,
680{
681    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
682        let model = model.into();
683        let default_max_tokens = calculate_max_tokens(&model);
684
685        Self {
686            client,
687            model,
688            default_max_tokens,
689            prompt_caching: false, // Default to off
690        }
691    }
692
693    pub fn with_model(client: Client<T>, model: &str) -> Self {
694        Self {
695            client,
696            model: model.to_string(),
697            default_max_tokens: Some(calculate_max_tokens_custom(model)),
698            prompt_caching: false, // Default to off
699        }
700    }
701
702    /// Enable automatic prompt caching.
703    ///
704    /// When enabled, cache_control breakpoints are automatically added to:
705    /// - The system prompt (marked with ephemeral cache)
706    /// - The last content block of the last message (marked with ephemeral cache)
707    ///
708    /// This allows Anthropic to cache the conversation history for cost savings.
709    pub fn with_prompt_caching(mut self) -> Self {
710        self.prompt_caching = true;
711        self
712    }
713}
714
/// Anthropic requires a `max_tokens` parameter to be set, which is dependent on the model. If not
/// set or if set too high, the request will fail. The following values are based on the models
/// available at the time of writing.
///
/// Returns the ceiling for the first known prefix that `model` starts with, or `None`
/// when the model name is not recognized.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Single source of truth for the prefix -> limit mapping. No prefix here is a prefix
    // of another, so the order of the rows does not affect the result.
    const LIMITS: &[(&str, u64)] = &[
        ("claude-opus-4", 32000),
        ("claude-sonnet-4", 64000),
        ("claude-3-7-sonnet", 64000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    LIMITS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}

/// Same lookup as [`calculate_max_tokens`], but falls back to 2048 for unrecognized models
/// so that a value is always available.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    calculate_max_tokens(model).unwrap_or(2048)
}
751
/// Request metadata.
///
/// The only field is private, so values can only be built inside this module or by
/// deserialization.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    // Optional user identifier (presumably forwarded to the API — confirm at the call site).
    user_id: Option<String>,
}
756
/// Tool selection policy in the Anthropic wire format.
///
/// Internally tagged with `type` in snake_case (`auto`, `any`, `none`, `tool`).
/// `Auto` is the default.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    None,
    /// Selects one specific tool by name.
    Tool {
        name: String,
    },
}
768impl TryFrom<message::ToolChoice> for ToolChoice {
769    type Error = CompletionError;
770
771    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
772        let res = match value {
773            message::ToolChoice::Auto => Self::Auto,
774            message::ToolChoice::None => Self::None,
775            message::ToolChoice::Required => Self::Any,
776            message::ToolChoice::Specific { function_names } => {
777                if function_names.len() != 1 {
778                    return Err(CompletionError::ProviderError(
779                        "Only one tool may be specified to be used by Claude".into(),
780                    ));
781                }
782
783                Self::Tool {
784                    name: function_names.first().unwrap().to_string(),
785                }
786            }
787        };
788
789        Ok(res)
790    }
791}
792
/// Request body serialized for a completion call.
///
/// Empty vectors and `None` options marked with `skip_serializing_if` are left out of
/// the serialized output.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    /// Model name.
    model: String,
    /// Conversation messages in wire format.
    messages: Vec<Message>,
    /// Output-token ceiling; always serialized.
    max_tokens: u64,
    /// System prompt as array of content blocks to support cache_control
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Extra parameters, flattened into the top level of the request object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
810
811/// Helper to set cache_control on a Content block
812fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
813    match content {
814        Content::Text { cache_control, .. } => *cache_control = value,
815        Content::Image { cache_control, .. } => *cache_control = value,
816        Content::ToolResult { cache_control, .. } => *cache_control = value,
817        Content::Document { cache_control, .. } => *cache_control = value,
818        _ => {}
819    }
820}
821
822/// Apply cache control breakpoints to system prompt and messages.
823/// Strategy: cache the system prompt, and mark the last content block of the last message
824/// for caching. This allows the conversation history to be cached while new messages
825/// are added.
826pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
827    // Add cache_control to the system prompt (if non-empty)
828    if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
829        *cache_control = Some(CacheControl::Ephemeral);
830    }
831
832    // Clear any existing cache_control from all message content blocks
833    for msg in messages.iter_mut() {
834        for content in msg.content.iter_mut() {
835            set_content_cache_control(content, None);
836        }
837    }
838
839    // Add cache_control to the last content block of the last message
840    if let Some(last_msg) = messages.last_mut() {
841        set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
842    }
843}
844
/// Inputs consumed by the `TryFrom` conversion that builds an
/// `AnthropicCompletionRequest`.
pub struct AnthropicRequestParams<'a> {
    /// Model identifier; copied verbatim into the request's `model` field.
    pub model: &'a str,
    /// Provider-agnostic completion request to translate. Its `max_tokens` must be
    /// set, otherwise the conversion fails with a request error.
    pub request: CompletionRequest,
    /// When `true`, `apply_cache_control` is run over the system prompt and the
    /// converted messages before the request is returned.
    pub prompt_caching: bool,
}
851
852impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
853    type Error = CompletionError;
854
855    fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
856        let AnthropicRequestParams {
857            model,
858            request: req,
859            prompt_caching,
860        } = params;
861
862        // Check if max_tokens is set, required for Anthropic
863        let Some(max_tokens) = req.max_tokens else {
864            return Err(CompletionError::RequestError(
865                "`max_tokens` must be set for Anthropic".into(),
866            ));
867        };
868
869        let mut full_history = vec![];
870        if let Some(docs) = req.normalized_documents() {
871            full_history.push(docs);
872        }
873        full_history.extend(req.chat_history);
874
875        let mut messages = full_history
876            .into_iter()
877            .map(Message::try_from)
878            .collect::<Result<Vec<Message>, _>>()?;
879
880        let tools = req
881            .tools
882            .into_iter()
883            .map(|tool| ToolDefinition {
884                name: tool.name,
885                description: Some(tool.description),
886                input_schema: tool.parameters,
887            })
888            .collect::<Vec<_>>();
889
890        // Convert system prompt to array format for cache_control support
891        let mut system = if let Some(preamble) = req.preamble {
892            if preamble.is_empty() {
893                vec![]
894            } else {
895                vec![SystemContent::Text {
896                    text: preamble,
897                    cache_control: None,
898                }]
899            }
900        } else {
901            vec![]
902        };
903
904        // Apply cache control breakpoints only if prompt_caching is enabled
905        if prompt_caching {
906            apply_cache_control(&mut system, &mut messages);
907        }
908
909        Ok(Self {
910            model: model.to_string(),
911            messages,
912            max_tokens,
913            system,
914            temperature: req.temperature,
915            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
916            tools,
917            additional_params: req.additional_params,
918        })
919    }
920}
921
922impl<T> completion::CompletionModel for CompletionModel<T>
923where
924    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
925{
926    type Response = CompletionResponse;
927    type StreamingResponse = StreamingCompletionResponse;
928    type Client = Client<T>;
929
930    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
931        Self::new(client.clone(), model.into())
932    }
933
934    async fn completion(
935        &self,
936        mut completion_request: completion::CompletionRequest,
937    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
938        let span = if tracing::Span::current().is_disabled() {
939            info_span!(
940                target: "rig::completions",
941                "chat",
942                gen_ai.operation.name = "chat",
943                gen_ai.provider.name = "anthropic",
944                gen_ai.request.model = &self.model,
945                gen_ai.system_instructions = &completion_request.preamble,
946                gen_ai.response.id = tracing::field::Empty,
947                gen_ai.response.model = tracing::field::Empty,
948                gen_ai.usage.output_tokens = tracing::field::Empty,
949                gen_ai.usage.input_tokens = tracing::field::Empty,
950            )
951        } else {
952            tracing::Span::current()
953        };
954
955        // Check if max_tokens is set, required for Anthropic
956        if completion_request.max_tokens.is_none() {
957            if let Some(tokens) = self.default_max_tokens {
958                completion_request.max_tokens = Some(tokens);
959            } else {
960                return Err(CompletionError::RequestError(
961                    "`max_tokens` must be set for Anthropic".into(),
962                ));
963            }
964        }
965
966        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
967            model: &self.model,
968            request: completion_request,
969            prompt_caching: self.prompt_caching,
970        })?;
971
972        if enabled!(Level::TRACE) {
973            tracing::trace!(
974                target: "rig::completions",
975                "Anthropic completion request: {}",
976                serde_json::to_string_pretty(&request)?
977            );
978        }
979
980        async move {
981            let request: Vec<u8> = serde_json::to_vec(&request)?;
982
983            let req = self
984                .client
985                .post("/v1/messages")?
986                .body(request)
987                .map_err(|e| CompletionError::HttpError(e.into()))?;
988
989            let response = self
990                .client
991                .send::<_, Bytes>(req)
992                .await
993                .map_err(CompletionError::HttpError)?;
994
995            if response.status().is_success() {
996                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
997                    response
998                        .into_body()
999                        .await
1000                        .map_err(CompletionError::HttpError)?
1001                        .to_vec()
1002                        .as_slice(),
1003                )? {
1004                    ApiResponse::Message(completion) => {
1005                        let span = tracing::Span::current();
1006                        span.record_response_metadata(&completion);
1007                        span.record_token_usage(&completion.usage);
1008                        if enabled!(Level::TRACE) {
1009                            tracing::trace!(
1010                                target: "rig::completions",
1011                                "Anthropic completion response: {}",
1012                                serde_json::to_string_pretty(&completion)?
1013                            );
1014                        }
1015                        completion.try_into()
1016                    }
1017                    ApiResponse::Error(ApiErrorResponse { message }) => {
1018                        Err(CompletionError::ResponseError(message))
1019                    }
1020                }
1021            } else {
1022                let text: String = String::from_utf8_lossy(
1023                    &response
1024                        .into_body()
1025                        .await
1026                        .map_err(CompletionError::HttpError)?,
1027                )
1028                .into();
1029                Err(CompletionError::ProviderError(text))
1030            }
1031        }
1032        .instrument(span)
1033        .await
1034    }
1035
1036    async fn stream(
1037        &self,
1038        request: CompletionRequest,
1039    ) -> Result<
1040        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1041        CompletionError,
1042    > {
1043        CompletionModel::stream(self, request).await
1044    }
1045}
1046
/// Payload of the `error` variant of `ApiResponse`; only the message is read.
// NOTE(review): this expects `message` at the same level as the `type` tag. If the
// provider nests it under an `error` object, this will not deserialize — confirm
// against the API's documented error shape.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error description returned by the API.
    message: String,
}
1051
/// Top-level response envelope, discriminated by the JSON `type` field.
///
/// The tag is matched in snake_case: `"message"` selects [`ApiResponse::Message`]
/// and `"error"` selects [`ApiResponse::Error`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// A successful payload, deserialized as `T`.
    Message(T),
    /// An error payload.
    Error(ApiErrorResponse),
}
1058
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use serde_path_to_error::deserialize;

    /// Parses an Anthropic wire-format message, panicking with the offending JSON
    /// path when deserialization fails.
    fn parse_message(raw: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(raw);
        deserialize(jd).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    /// Messages deserialize from both the plain-string and the block-array content
    /// forms, for assistant and user roles.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let user_message = parse_message(user_message_json);

        // Plain-string content becomes a single text block.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned(),
                cache_control: None,
            }
        );

        // Block-array content: text followed by a tool call.
        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        // User message: image, text, and a tool result with string content.
        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source, .. } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    /// Provider messages convert to the generic message type and back without loss.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = parse_message(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        );

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
                cache_control: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Round trip: converting back must reproduce the provider messages exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    /// `ContentFormat` and `SourceType` map onto each other for URL and base64, and
    /// the deprecated string format is rejected.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let result: Result<SourceType, _> = ContentFormat::String.try_into();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("ContentFormat::String is deprecated")
        );
    }

    /// `cache_control` serializes only when set, and `apply_cache_control` leaves a
    /// marker on the system prompt and on the last message only.
    #[test]
    fn test_cache_control_serialization() {
        // Test SystemContent with cache_control
        let system = SystemContent::Text {
            text: "You are a helpful assistant.".to_string(),
            cache_control: Some(CacheControl::Ephemeral),
        };
        let json = serde_json::to_string(&system).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
        assert!(json.contains(r#""type":"text""#));

        // Test SystemContent without cache_control (should not have cache_control field)
        let system_no_cache = SystemContent::Text {
            text: "Hello".to_string(),
            cache_control: None,
        };
        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
        assert!(!json_no_cache.contains("cache_control"));

        // Test Content::Text with cache_control
        let content = Content::Text {
            text: "Test message".to_string(),
            cache_control: Some(CacheControl::Ephemeral),
        };
        let json_content = serde_json::to_string(&content).unwrap();
        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

        // Test apply_cache_control function
        let mut system_vec = vec![SystemContent::Text {
            text: "System prompt".to_string(),
            cache_control: None,
        }];
        let mut messages = vec![
            Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: "First message".to_string(),
                    // Starts out marked so that the clearing pass is actually exercised.
                    cache_control: Some(CacheControl::Ephemeral),
                }),
            },
            Message {
                role: Role::Assistant,
                content: OneOrMany::one(Content::Text {
                    text: "Response".to_string(),
                    cache_control: None,
                }),
            },
        ];

        apply_cache_control(&mut system_vec, &mut messages);

        // System should have cache_control
        match &system_vec[0] {
            SystemContent::Text { cache_control, .. } => {
                assert!(cache_control.is_some());
            }
        }

        // The stale marker on the first message must have been cleared. A strict
        // match ensures the assertion cannot be skipped by an unexpected variant.
        for content in messages[0].content.iter() {
            match content {
                Content::Text { cache_control, .. } => assert!(cache_control.is_none()),
                _ => panic!("Expected text content"),
            }
        }

        // The last message must carry the breakpoint.
        for content in messages[1].content.iter() {
            match content {
                Content::Text { cache_control, .. } => assert!(cache_control.is_some()),
                _ => panic!("Expected text content"),
            }
        }
    }
}