1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
/// Claude Opus 4 model id (`claude-opus-4-0`).
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Claude Sonnet 4 model id (`claude-sonnet-4-0`).
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Claude 3.7 Sonnet, `-latest` alias.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Claude 3.5 Sonnet, `-latest` alias.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Claude 3.5 Haiku, `-latest` alias.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";
35
/// Anthropic API version string `2023-01-01`.
/// NOTE(review): presumably sent as the `anthropic-version` header by the client — confirm there.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version known to this module (currently an alias of `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Body of a successful (`"type": "message"`) Anthropic `/v1/messages` response.
///
/// Kept verbatim as `raw_response` when converted into rig's generic
/// `completion::CompletionResponse`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model (text, tool use, thinking, ...).
    pub content: Vec<Content>,
    /// Message id assigned by the API.
    pub id: String,
    /// Model that actually served the request.
    pub model: String,
    /// Role of the generated message, as a raw string.
    pub role: String,
    /// Why generation stopped, if reported.
    pub stop_reason: Option<String>,
    /// The stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52 type OutputMessage = Content;
53 type Usage = Usage;
54
55 fn get_response_id(&self) -> Option<String> {
56 Some(self.id.to_owned())
57 }
58
59 fn get_response_model_name(&self) -> Option<String> {
60 Some(self.model.to_owned())
61 }
62
63 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64 self.content.clone()
65 }
66
67 fn get_text_response(&self) -> Option<String> {
68 let res = self
69 .content
70 .iter()
71 .filter_map(|x| {
72 if let Content::Text { text, .. } = x {
73 Some(text.to_owned())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<String>>()
79 .join("\n");
80
81 if res.is_empty() { None } else { Some(res) }
82 }
83
84 fn get_usage(&self) -> Option<Self::Usage> {
85 Some(self.usage.clone())
86 }
87}
88
/// Token counts reported in an Anthropic response.
///
/// The two cache counters are `Option`s, so they deserialize to `None` when the field is
/// absent from the response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens as reported in the `input_tokens` field.
    pub input_tokens: u64,
    /// Input tokens served from the prompt cache, if reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, if reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Generated output tokens.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(
100 f,
101 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102 self.input_tokens,
103 match self.cache_read_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 match self.cache_creation_input_tokens {
108 Some(token) => token.to_string(),
109 None => "n/a".to_string(),
110 },
111 self.output_tokens
112 )
113 }
114}
115
116impl GetTokenUsage for Usage {
117 fn token_usage(&self) -> Option<crate::completion::Usage> {
118 let mut usage = crate::completion::Usage::new();
119
120 usage.input_tokens = self.input_tokens
121 + self.cache_creation_input_tokens.unwrap_or_default()
122 + self.cache_read_input_tokens.unwrap_or_default();
123 usage.output_tokens = self.output_tokens;
124 usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126 Some(usage)
127 }
128}
129
/// A tool advertised to the model in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name the model uses in `tool_use` blocks.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: Option<String>,
    /// JSON Schema for the tool's input (taken from rig's tool `parameters`).
    pub input_schema: serde_json::Value,
}
136
/// Prompt-cache breakpoint marker; serializes as `{"type": "ephemeral"}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
143
/// One block of the top-level `system` prompt.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    /// Plain-text system prompt, serialized as `{"type": "text", ...}`.
    Text {
        text: String,
        /// Optional cache breakpoint; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156 type Error = CompletionError;
157
158 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159 let content = response
160 .content
161 .iter()
162 .map(|content| content.clone().try_into())
163 .collect::<Result<Vec<_>, _>>()?;
164
165 let choice = OneOrMany::many(content).map_err(|_| {
166 CompletionError::ResponseError(
167 "Response contained no message or tool call (empty)".to_owned(),
168 )
169 })?;
170
171 let usage = completion::Usage {
172 input_tokens: response.usage.input_tokens,
173 output_tokens: response.usage.output_tokens,
174 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175 };
176
177 Ok(completion::CompletionResponse {
178 choice,
179 usage,
180 raw_response: response,
181 })
182 }
183}
184
/// A single message in the Anthropic `messages` array.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    pub role: Role,
    /// Content blocks. When deserializing, a bare string is also accepted and becomes a
    /// single text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
191
/// Message author; serialized as `"user"` / `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
198
/// A content block in an Anthropic message.
///
/// Serialized as an object whose `"type"` is the snake_case variant name (`"text"`,
/// `"image"`, `"tool_use"`, `"tool_result"`, `"document"`, `"thinking"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image given by base64 data or URL.
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation requested by the model.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, sent back in a user message.
    ToolResult {
        /// Id of the `ToolUse` block this result answers.
        tool_use_id: String,
        /// A bare string is accepted when deserializing and becomes one text block.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document (only PDF is representable via `DocumentFormat`).
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Extended-thinking output, optionally with its signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
}
237
238impl FromStr for Content {
239 type Err = Infallible;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 Ok(Content::Text {
243 text: s.to_owned(),
244 cache_control: None,
245 })
246 }
247}
248
/// One block inside a `tool_result` content block.
///
/// NOTE(review): this enum is internally tagged on `"type"`, and `Image` is a newtype
/// variant whose `ImageSource` also has a field serialized as `"type"` (`r#type`). Serde
/// flattens newtype-variant fields next to the tag, so an `Image` likely serializes with
/// two `"type"` keys. Also, the Anthropic API appears to expect `{"type": "image",
/// "source": {...}}`. Confirm against the API before relying on image tool results.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Text output of the tool.
    Text { text: String },
    /// Image output of the tool.
    Image(ImageSource),
}
255
256impl FromStr for ToolResultContent {
257 type Err = Infallible;
258
259 fn from_str(s: &str) -> Result<Self, Self::Err> {
260 Ok(ToolResultContent::Text { text: s.to_owned() })
261 }
262}
263
/// Payload of an image source: base64 data or a URL. Serialized as a bare string.
///
/// NOTE(review): because the enum is `untagged` and both variants hold a `String`, serde
/// tries `Base64` first, so deserialization always produces `Base64`, even for URLs.
/// Use `ImageSource::r#type` to tell the two apart after deserializing.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ImageSourceData {
    Base64(String),
    Url(String),
}
270
271impl From<ImageSourceData> for DocumentSourceKind {
272 fn from(value: ImageSourceData) -> Self {
273 match value {
274 ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
275 ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
276 }
277 }
278}
279
280impl TryFrom<DocumentSourceKind> for ImageSourceData {
281 type Error = MessageError;
282
283 fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
284 match value {
285 DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
286 DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
287 _ => Err(MessageError::ConversionError("Content has no body".into())),
288 }
289 }
290}
291
292impl From<ImageSourceData> for String {
293 fn from(value: ImageSourceData) -> Self {
294 match value {
295 ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
296 }
297 }
298}
299
/// Source of an image block: the payload, its MIME type and how the payload is encoded.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    pub data: ImageSourceData,
    pub media_type: ImageFormat,
    /// Serialized as `"type"` (`"base64"` or `"url"`).
    pub r#type: SourceType,
}
306
/// Source of a document block.
///
/// NOTE(review): conversions in this file always set `r#type` to `BASE64`, including when
/// `data` came from `DocumentSourceKind::String`. Confirm that such strings are actually
/// base64 encoded.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    pub data: String,
    pub media_type: DocumentFormat,
    /// Serialized as `"type"`.
    pub r#type: SourceType,
}
313
/// Image MIME types accepted in Anthropic image sources.
///
/// Each variant carries an explicit `rename` to its MIME string. Per-variant renames take
/// precedence, so the enum-level `rename_all` has no effect.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
326
/// Document MIME types accepted in Anthropic document sources (only PDF).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
336
/// Encoding of a source payload; serialized as `"base64"` / `"url"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
}
343
344impl From<String> for Content {
345 fn from(text: String) -> Self {
346 Content::Text {
347 text,
348 cache_control: None,
349 }
350 }
351}
352
353impl From<String> for ToolResultContent {
354 fn from(text: String) -> Self {
355 ToolResultContent::Text { text }
356 }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360 type Error = MessageError;
361
362 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363 match format {
364 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365 message::ContentFormat::Url => Ok(SourceType::URL),
366 message::ContentFormat::String => Err(MessageError::ConversionError(
367 "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368 )),
369 }
370 }
371}
372
373impl From<SourceType> for message::ContentFormat {
374 fn from(source_type: SourceType) -> Self {
375 match source_type {
376 SourceType::BASE64 => message::ContentFormat::Base64,
377 SourceType::URL => message::ContentFormat::Url,
378 }
379 }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383 type Error = MessageError;
384
385 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386 Ok(match media_type {
387 message::ImageMediaType::JPEG => ImageFormat::JPEG,
388 message::ImageMediaType::PNG => ImageFormat::PNG,
389 message::ImageMediaType::GIF => ImageFormat::GIF,
390 message::ImageMediaType::WEBP => ImageFormat::WEBP,
391 _ => {
392 return Err(MessageError::ConversionError(
393 format!("Unsupported image media type: {media_type:?}").to_owned(),
394 ));
395 }
396 })
397 }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401 fn from(format: ImageFormat) -> Self {
402 match format {
403 ImageFormat::JPEG => message::ImageMediaType::JPEG,
404 ImageFormat::PNG => message::ImageMediaType::PNG,
405 ImageFormat::GIF => message::ImageMediaType::GIF,
406 ImageFormat::WEBP => message::ImageMediaType::WEBP,
407 }
408 }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412 type Error = MessageError;
413 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414 if !matches!(value, DocumentMediaType::PDF) {
415 return Err(MessageError::ConversionError(
416 "Anthropic only supports PDF documents".to_string(),
417 ));
418 };
419
420 Ok(DocumentFormat::PDF)
421 }
422}
423
424impl TryFrom<message::AssistantContent> for Content {
425 type Error = MessageError;
426 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
427 match text {
428 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
429 text,
430 cache_control: None,
431 }),
432 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
433 "Anthropic currently doesn't support images.".to_string(),
434 )),
435 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
436 Ok(Content::ToolUse {
437 id,
438 name: function.name,
439 input: function.arguments,
440 })
441 }
442 message::AssistantContent::Reasoning(Reasoning {
443 reasoning,
444 signature,
445 ..
446 }) => Ok(Content::Thinking {
447 thinking: reasoning.first().cloned().unwrap_or(String::new()),
448 signature,
449 }),
450 }
451 }
452}
453
454impl TryFrom<message::Message> for Message {
455 type Error = MessageError;
456
457 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
458 Ok(match message {
459 message::Message::User { content } => Message {
460 role: Role::User,
461 content: content.try_map(|content| match content {
462 message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
463 text,
464 cache_control: None,
465 }),
466 message::UserContent::ToolResult(message::ToolResult {
467 id, content, ..
468 }) => Ok(Content::ToolResult {
469 tool_use_id: id,
470 content: content.try_map(|content| match content {
471 message::ToolResultContent::Text(message::Text { text }) => {
472 Ok(ToolResultContent::Text { text })
473 }
474 message::ToolResultContent::Image(image) => {
475 let DocumentSourceKind::Base64(data) = image.data else {
476 return Err(MessageError::ConversionError(
477 "Only base64 strings can be used with the Anthropic API"
478 .to_string(),
479 ));
480 };
481 let media_type =
482 image.media_type.ok_or(MessageError::ConversionError(
483 "Image media type is required".to_owned(),
484 ))?;
485 Ok(ToolResultContent::Image(ImageSource {
486 data: ImageSourceData::Base64(data),
487 media_type: media_type.try_into()?,
488 r#type: SourceType::BASE64,
489 }))
490 }
491 })?,
492 is_error: None,
493 cache_control: None,
494 }),
495 message::UserContent::Image(message::Image {
496 data, media_type, ..
497 }) => {
498 let media_type = media_type.ok_or(MessageError::ConversionError(
499 "Image media type is required for Claude API".to_string(),
500 ))?;
501
502 let source = match data {
503 DocumentSourceKind::Base64(data) => ImageSource {
504 data: ImageSourceData::Base64(data),
505 r#type: SourceType::BASE64,
506 media_type: ImageFormat::try_from(media_type)?,
507 },
508 DocumentSourceKind::Url(url) => ImageSource {
509 data: ImageSourceData::Url(url),
510 r#type: SourceType::URL,
511 media_type: ImageFormat::try_from(media_type)?,
512 },
513 DocumentSourceKind::Unknown => {
514 return Err(MessageError::ConversionError(
515 "Image content has no body".into(),
516 ));
517 }
518 doc => {
519 return Err(MessageError::ConversionError(format!(
520 "Unsupported document type: {doc:?}"
521 )));
522 }
523 };
524
525 Ok(Content::Image {
526 source,
527 cache_control: None,
528 })
529 }
530 message::UserContent::Document(message::Document {
531 data, media_type, ..
532 }) => {
533 let media_type = media_type.ok_or(MessageError::ConversionError(
534 "Document media type is required".to_string(),
535 ))?;
536
537 let data = match data {
538 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
539 data
540 }
541 _ => {
542 return Err(MessageError::ConversionError(
543 "Only base64 encoded documents currently supported".into(),
544 ));
545 }
546 };
547
548 let source = DocumentSource {
549 data,
550 media_type: media_type.try_into()?,
551 r#type: SourceType::BASE64,
552 };
553 Ok(Content::Document {
554 source,
555 cache_control: None,
556 })
557 }
558 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
559 "Audio is not supported in Anthropic".to_owned(),
560 )),
561 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
562 "Video is not supported in Anthropic".to_owned(),
563 )),
564 })?,
565 },
566
567 message::Message::Assistant { content, .. } => Message {
568 content: content.try_map(|content| content.try_into())?,
569 role: Role::Assistant,
570 },
571 })
572 }
573}
574
575impl TryFrom<Content> for message::AssistantContent {
576 type Error = MessageError;
577
578 fn try_from(content: Content) -> Result<Self, Self::Error> {
579 Ok(match content {
580 Content::Text { text, .. } => message::AssistantContent::text(text),
581 Content::ToolUse { id, name, input } => {
582 message::AssistantContent::tool_call(id, name, input)
583 }
584 Content::Thinking {
585 thinking,
586 signature,
587 } => message::AssistantContent::Reasoning(
588 Reasoning::new(&thinking).with_signature(signature),
589 ),
590 _ => {
591 return Err(MessageError::ConversionError(
592 "Content did not contain a message, tool call, or reasoning".to_owned(),
593 ));
594 }
595 })
596 }
597}
598
599impl From<ToolResultContent> for message::ToolResultContent {
600 fn from(content: ToolResultContent) -> Self {
601 match content {
602 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
603 ToolResultContent::Image(ImageSource {
604 data,
605 media_type: format,
606 ..
607 }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
608 }
609 }
610}
611
612impl TryFrom<Message> for message::Message {
613 type Error = MessageError;
614
615 fn try_from(message: Message) -> Result<Self, Self::Error> {
616 Ok(match message.role {
617 Role::User => message::Message::User {
618 content: message.content.try_map(|content| {
619 Ok(match content {
620 Content::Text { text, .. } => message::UserContent::text(text),
621 Content::ToolResult {
622 tool_use_id,
623 content,
624 ..
625 } => message::UserContent::tool_result(
626 tool_use_id,
627 content.map(|content| content.into()),
628 ),
629 Content::Image { source, .. } => {
630 message::UserContent::Image(message::Image {
631 data: source.data.into(),
632 media_type: Some(source.media_type.into()),
633 detail: None,
634 additional_params: None,
635 })
636 }
637 Content::Document { source, .. } => message::UserContent::document(
638 source.data,
639 Some(message::DocumentMediaType::PDF),
640 ),
641 _ => {
642 return Err(MessageError::ConversionError(
643 "Unsupported content type for User role".to_owned(),
644 ));
645 }
646 })
647 })?,
648 },
649 Role::Assistant => match message.content.first() {
650 Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
651 message::Message::Assistant {
652 id: None,
653 content: message.content.try_map(|content| content.try_into())?,
654 }
655 }
656
657 _ => {
658 return Err(MessageError::ConversionError(
659 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
660 ));
661 }
662 },
663 })
664 }
665}
666
/// Anthropic completion model handle bound to a client and a model id.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model id sent in every request.
    pub model: String,
    /// Used as `max_tokens` when a request leaves it unset.
    pub default_max_tokens: Option<u64>,
    /// When `true`, requests get prompt-cache breakpoints via `apply_cache_control`.
    pub prompt_caching: bool,
}
675
676impl<T> CompletionModel<T>
677where
678 T: HttpClientExt,
679{
680 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
681 let model = model.into();
682 let default_max_tokens = calculate_max_tokens(&model);
683
684 Self {
685 client,
686 model,
687 default_max_tokens,
688 prompt_caching: false, }
690 }
691
692 pub fn with_model(client: Client<T>, model: &str) -> Self {
693 Self {
694 client,
695 model: model.to_string(),
696 default_max_tokens: Some(calculate_max_tokens_custom(model)),
697 prompt_caching: false, }
699 }
700
701 pub fn with_prompt_caching(mut self) -> Self {
709 self.prompt_caching = true;
710 self
711 }
712}
713
714fn calculate_max_tokens(model: &str) -> Option<u64> {
718 match model {
719 CLAUDE_4_OPUS => Some(32_000),
720 CLAUDE_4_SONNET | CLAUDE_3_7_SONNET => Some(64_000),
721 CLAUDE_3_5_SONNET | CLAUDE_3_5_HAIKU => Some(8_192),
722 _ => None,
723 }
724}
725
726fn calculate_max_tokens_custom(model: &str) -> u64 {
727 match model {
728 "claude-4-opus" => 32_000,
729 "claude-4-sonnet" | "claude-3.7-sonnet" => 64_000,
730 "claude-3.5-sonnet" | "claude-3.5-haiku" => 8_192,
731 _ => 4_096,
732 }
733}
734
/// Request metadata (`user_id`).
/// NOTE(review): not referenced by the visible request-building code — confirm it is used.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}
739
/// Anthropic `tool_choice`, serialized as `{"type": "auto" | "any" | "none" | "tool", ...}`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// The model decides whether to call a tool (default).
    #[default]
    Auto,
    /// The model must call some tool (rig's `Required`).
    Any,
    /// Tools must not be called.
    None,
    /// The model must call the named tool.
    Tool {
        name: String,
    },
}
751impl TryFrom<message::ToolChoice> for ToolChoice {
752 type Error = CompletionError;
753
754 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
755 let res = match value {
756 message::ToolChoice::Auto => Self::Auto,
757 message::ToolChoice::None => Self::None,
758 message::ToolChoice::Required => Self::Any,
759 message::ToolChoice::Specific { function_names } => {
760 if function_names.len() != 1 {
761 return Err(CompletionError::ProviderError(
762 "Only one tool may be specified to be used by Claude".into(),
763 ));
764 }
765
766 Self::Tool {
767 name: function_names.first().unwrap().to_string(),
768 }
769 }
770 };
771
772 Ok(res)
773 }
774}
775
/// JSON body sent to `POST /v1/messages`.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    /// Required by Anthropic; request building fails without it.
    max_tokens: u64,
    /// System prompt blocks; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Omitted when no tools are offered.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Caller-supplied extra fields, flattened into the top-level JSON object.
    /// NOTE(review): flattening presumably requires this to be a JSON object — confirm.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
793
794fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
796 match content {
797 Content::Text { cache_control, .. } => *cache_control = value,
798 Content::Image { cache_control, .. } => *cache_control = value,
799 Content::ToolResult { cache_control, .. } => *cache_control = value,
800 Content::Document { cache_control, .. } => *cache_control = value,
801 _ => {}
802 }
803}
804
805pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
810 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
812 *cache_control = Some(CacheControl::Ephemeral);
813 }
814
815 for msg in messages.iter_mut() {
817 for content in msg.content.iter_mut() {
818 set_content_cache_control(content, None);
819 }
820 }
821
822 if let Some(last_msg) = messages.last_mut() {
824 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
825 }
826}
827
/// Inputs needed to build an `AnthropicCompletionRequest`.
pub struct AnthropicRequestParams<'a> {
    /// Model id to send.
    pub model: &'a str,
    /// The provider-agnostic request to convert.
    pub request: CompletionRequest,
    /// Whether to apply prompt-cache breakpoints.
    pub prompt_caching: bool,
}
834
835impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
836 type Error = CompletionError;
837
838 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
839 let AnthropicRequestParams {
840 model,
841 request: req,
842 prompt_caching,
843 } = params;
844
845 let Some(max_tokens) = req.max_tokens else {
847 return Err(CompletionError::RequestError(
848 "`max_tokens` must be set for Anthropic".into(),
849 ));
850 };
851
852 let mut full_history = vec![];
853 if let Some(docs) = req.normalized_documents() {
854 full_history.push(docs);
855 }
856 full_history.extend(req.chat_history);
857
858 let mut messages = full_history
859 .into_iter()
860 .map(Message::try_from)
861 .collect::<Result<Vec<Message>, _>>()?;
862
863 let tools = req
864 .tools
865 .into_iter()
866 .map(|tool| ToolDefinition {
867 name: tool.name,
868 description: Some(tool.description),
869 input_schema: tool.parameters,
870 })
871 .collect::<Vec<_>>();
872
873 let mut system = if let Some(preamble) = req.preamble {
875 if preamble.is_empty() {
876 vec![]
877 } else {
878 vec![SystemContent::Text {
879 text: preamble,
880 cache_control: None,
881 }]
882 }
883 } else {
884 vec![]
885 };
886
887 if prompt_caching {
889 apply_cache_control(&mut system, &mut messages);
890 }
891
892 Ok(Self {
893 model: model.to_string(),
894 messages,
895 max_tokens,
896 system,
897 temperature: req.temperature,
898 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
899 tools,
900 additional_params: req.additional_params,
901 })
902 }
903}
904
/// rig's provider-agnostic completion interface, implemented on top of Anthropic's
/// `POST /v1/messages` endpoint.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;
    type Client = Client<T>;

    /// Builds a model handle for `model` from a clone of `client` via `CompletionModel::new`.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends a single, non-streaming completion request.
    ///
    /// Fills in `max_tokens` from `default_max_tokens` when the request leaves it unset.
    /// It then converts the request to the Anthropic wire format, posts it, and converts a
    /// successful `message` body back into a rig completion response.
    ///
    /// # Errors
    /// - `RequestError` if no `max_tokens` is set and the model has no default;
    /// - any error from building the Anthropic request body;
    /// - `HttpError` for transport failures, and JSON (de)serialization errors;
    /// - `ResponseError` when a success-status body is an `error` object;
    /// - `ProviderError` carrying the raw body text for non-success statuses.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        mut completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span if one is active. Otherwise open a `chat` span whose
        // response/usage fields start empty and are recorded once the response arrives.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "anthropic",
                gen_ai.request.model = &self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Anthropic requires `max_tokens`; fall back to the per-model default if there is one.
        if completion_request.max_tokens.is_none() {
            if let Some(tokens) = self.default_max_tokens {
                completion_request.max_tokens = Some(tokens);
            } else {
                return Err(CompletionError::RequestError(
                    "`max_tokens` must be set for Anthropic".into(),
                ));
            }
        }

        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
            model: &self.model,
            request: completion_request,
            prompt_caching: self.prompt_caching,
        })?;

        // Pretty-printing is only done when TRACE is enabled, to avoid the cost otherwise.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Anthropic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        // The HTTP exchange runs inside `span` so the recordings below attach to it.
        async move {
            let request: Vec<u8> = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/messages")?
                .body(request)
                .map_err(|e| CompletionError::HttpError(e.into()))?;

            let response = self
                .client
                .send::<_, Bytes>(req)
                .await
                .map_err(CompletionError::HttpError)?;

            if response.status().is_success() {
                // A success status can still carry an `error` body, so decode the tagged union.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
                    response
                        .into_body()
                        .await
                        .map_err(CompletionError::HttpError)?
                        .to_vec()
                        .as_slice(),
                )? {
                    ApiResponse::Message(completion) => {
                        let span = tracing::Span::current();
                        span.record_response_metadata(&completion);
                        span.record_token_usage(&completion.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "Anthropic completion response: {}",
                                serde_json::to_string_pretty(&completion)?
                            );
                        }
                        completion.try_into()
                    }
                    ApiResponse::Error(ApiErrorResponse { message }) => {
                        Err(CompletionError::ResponseError(message))
                    }
                }
            } else {
                // Non-success statuses surface the raw body (lossily decoded as UTF-8).
                let text: String = String::from_utf8_lossy(
                    &response
                        .into_body()
                        .await
                        .map_err(CompletionError::HttpError)?,
                )
                .into();
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Delegates to the inherent `CompletionModel::stream` implementation.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}
1031
/// Payload of an `"type": "error"` API body; only the human-readable message is kept.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
1036
/// Top-level API body, discriminated by its `"type"` field (`"message"` or `"error"`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
1043
1044#[cfg(test)]
1045mod tests {
1046 use super::*;
1047 use serde_json::json;
1048 use serde_path_to_error::deserialize;
1049
1050 #[test]
1051 fn test_deserialize_message() {
1052 let assistant_message_json = r#"
1053 {
1054 "role": "assistant",
1055 "content": "\n\nHello there, how may I assist you today?"
1056 }
1057 "#;
1058
1059 let assistant_message_json2 = r#"
1060 {
1061 "role": "assistant",
1062 "content": [
1063 {
1064 "type": "text",
1065 "text": "\n\nHello there, how may I assist you today?"
1066 },
1067 {
1068 "type": "tool_use",
1069 "id": "toolu_01A09q90qw90lq917835lq9",
1070 "name": "get_weather",
1071 "input": {"location": "San Francisco, CA"}
1072 }
1073 ]
1074 }
1075 "#;
1076
1077 let user_message_json = r#"
1078 {
1079 "role": "user",
1080 "content": [
1081 {
1082 "type": "image",
1083 "source": {
1084 "type": "base64",
1085 "media_type": "image/jpeg",
1086 "data": "/9j/4AAQSkZJRg..."
1087 }
1088 },
1089 {
1090 "type": "text",
1091 "text": "What is in this image?"
1092 },
1093 {
1094 "type": "tool_result",
1095 "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
1096 "content": "15 degrees"
1097 }
1098 ]
1099 }
1100 "#;
1101
1102 let assistant_message: Message = {
1103 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1104 deserialize(jd).unwrap_or_else(|err| {
1105 panic!("Deserialization error at {}: {}", err.path(), err);
1106 })
1107 };
1108
1109 let assistant_message2: Message = {
1110 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1111 deserialize(jd).unwrap_or_else(|err| {
1112 panic!("Deserialization error at {}: {}", err.path(), err);
1113 })
1114 };
1115
1116 let user_message: Message = {
1117 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1118 deserialize(jd).unwrap_or_else(|err| {
1119 panic!("Deserialization error at {}: {}", err.path(), err);
1120 })
1121 };
1122
1123 let Message { role, content } = assistant_message;
1124 assert_eq!(role, Role::Assistant);
1125 assert_eq!(
1126 content.first(),
1127 Content::Text {
1128 text: "\n\nHello there, how may I assist you today?".to_owned(),
1129 cache_control: None,
1130 }
1131 );
1132
1133 let Message { role, content } = assistant_message2;
1134 {
1135 assert_eq!(role, Role::Assistant);
1136 assert_eq!(content.len(), 2);
1137
1138 let mut iter = content.into_iter();
1139
1140 match iter.next().unwrap() {
1141 Content::Text { text, .. } => {
1142 assert_eq!(text, "\n\nHello there, how may I assist you today?");
1143 }
1144 _ => panic!("Expected text content"),
1145 }
1146
1147 match iter.next().unwrap() {
1148 Content::ToolUse { id, name, input } => {
1149 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1150 assert_eq!(name, "get_weather");
1151 assert_eq!(input, json!({"location": "San Francisco, CA"}));
1152 }
1153 _ => panic!("Expected tool use content"),
1154 }
1155
1156 assert_eq!(iter.next(), None);
1157 }
1158
1159 let Message { role, content } = user_message;
1160 {
1161 assert_eq!(role, Role::User);
1162 assert_eq!(content.len(), 3);
1163
1164 let mut iter = content.into_iter();
1165
1166 match iter.next().unwrap() {
1167 Content::Image { source, .. } => {
1168 assert_eq!(
1169 source,
1170 ImageSource {
1171 data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
1172 media_type: ImageFormat::JPEG,
1173 r#type: SourceType::BASE64,
1174 }
1175 );
1176 }
1177 _ => panic!("Expected image content"),
1178 }
1179
1180 match iter.next().unwrap() {
1181 Content::Text { text, .. } => {
1182 assert_eq!(text, "What is in this image?");
1183 }
1184 _ => panic!("Expected text content"),
1185 }
1186
1187 match iter.next().unwrap() {
1188 Content::ToolResult {
1189 tool_use_id,
1190 content,
1191 is_error,
1192 ..
1193 } => {
1194 assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1195 assert_eq!(
1196 content.first(),
1197 ToolResultContent::Text {
1198 text: "15 degrees".to_owned()
1199 }
1200 );
1201 assert_eq!(is_error, None);
1202 }
1203 _ => panic!("Expected tool result content"),
1204 }
1205
1206 assert_eq!(iter.next(), None);
1207 }
1208 }
1209
#[test]
fn test_message_to_message_conversion() {
    // Round-trip check: provider `Message` -> generic `message::Message` -> provider `Message`
    // must reproduce the original value for user, assistant (tool call) and tool-result messages.
    let user_message: Message = serde_json::from_str(
        r#"
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/jpeg",
                    "data": "/9j/4AAQSkZJRg..."
                }
            },
            {
                "type": "text",
                "text": "What is in this image?"
            },
            {
                "type": "document",
                "source": {
                    "type": "base64",
                    "data": "base64_encoded_pdf_data",
                    "media_type": "application/pdf"
                }
            }
        ]
    }
    "#,
    )
    .unwrap();

    let assistant_message = Message {
        role: Role::Assistant,
        content: OneOrMany::one(Content::ToolUse {
            id: "toolu_01A09q90qw90lq917835lq9".to_string(),
            name: "get_weather".to_string(),
            input: json!({"location": "San Francisco, CA"}),
        }),
    };

    let tool_message = Message {
        role: Role::User,
        content: OneOrMany::one(Content::ToolResult {
            tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
            content: OneOrMany::one(ToolResultContent::Text {
                text: "15 degrees".to_string(),
            }),
            is_error: None,
            cache_control: None,
        }),
    };

    let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
    let converted_assistant_message: message::Message =
        assistant_message.clone().try_into().unwrap();
    let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

    // The converted messages are cloned for inspection because the originals are converted
    // back further below.
    match converted_user_message.clone() {
        message::Message::User { content } => {
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                message::UserContent::Image(message::Image {
                    data, media_type, ..
                }) => {
                    assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                    assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                message::UserContent::Text(message::Text { text }) => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                message::UserContent::Document(message::Document {
                    data, media_type, ..
                }) => {
                    assert_eq!(
                        data,
                        DocumentSourceKind::String("base64_encoded_pdf_data".into())
                    );
                    assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                }
                _ => panic!("Expected document content"),
            }

            assert_eq!(iter.next(), None);
        }
        _ => panic!("Expected user message"),
    }

    // Tool results travel inside a *user* message in the generic representation.
    match converted_tool_message.clone() {
        message::Message::User { content } => {
            // One provider tool result must map to exactly one generic content item.
            assert_eq!(content.len(), 1);

            let message::ToolResult { id, content, .. } = match content.first() {
                message::UserContent::ToolResult(tool_result) => tool_result,
                _ => panic!("Expected tool result content"),
            };
            assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
            match content.first() {
                message::ToolResultContent::Text(message::Text { text }) => {
                    assert_eq!(text, "15 degrees");
                }
                _ => panic!("Expected text content"),
            }
        }
        _ => panic!("Expected user message"),
    }

    match converted_assistant_message.clone() {
        message::Message::Assistant { content, .. } => {
            assert_eq!(content.len(), 1);

            match content.first() {
                message::AssistantContent::ToolCall(message::ToolCall {
                    id, function, ..
                }) => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(function.name, "get_weather");
                    assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool call content"),
            }
        }
        _ => panic!("Expected assistant message"),
    }

    // Converting back must be lossless.
    let original_user_message: Message = converted_user_message.try_into().unwrap();
    let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
    let original_tool_message: Message = converted_tool_message.try_into().unwrap();

    assert_eq!(user_message, original_user_message);
    assert_eq!(assistant_message, original_assistant_message);
    assert_eq!(tool_message, original_tool_message);
}
1353
#[test]
fn test_content_format_conversion() {
    use crate::completion::message::ContentFormat;

    // Forward direction: generic content formats that have a provider source type.
    let from_url: SourceType = ContentFormat::Url.try_into().unwrap();
    assert_eq!(from_url, SourceType::URL);
    let from_base64: SourceType = ContentFormat::Base64.try_into().unwrap();
    assert_eq!(from_base64, SourceType::BASE64);

    // Backward direction: every provider source type maps onto a generic format.
    let back_from_url: ContentFormat = SourceType::URL.into();
    assert_eq!(back_from_url, ContentFormat::Url);
    let back_from_base64: ContentFormat = SourceType::BASE64.into();
    assert_eq!(back_from_base64, ContentFormat::Base64);

    // `ContentFormat::String` is expected to be rejected with a deprecation error.
    let rejected: Result<SourceType, _> = ContentFormat::String.try_into();
    let error_text = rejected
        .expect_err("ContentFormat::String should not convert to a SourceType")
        .to_string();
    assert!(error_text.contains("ContentFormat::String is deprecated"));
}
1379
#[test]
fn test_cache_control_serialization() {
    // A system block carrying a cache-control marker serializes it as `{"type":"ephemeral"}`.
    let system = SystemContent::Text {
        text: "You are a helpful assistant.".to_string(),
        cache_control: Some(CacheControl::Ephemeral),
    };
    let json = serde_json::to_string(&system).unwrap();
    assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
    assert!(json.contains(r#""type":"text""#));

    // Without a marker the `cache_control` key must be absent from the output.
    let system_no_cache = SystemContent::Text {
        text: "Hello".to_string(),
        cache_control: None,
    };
    let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
    assert!(!json_no_cache.contains("cache_control"));

    // Message content blocks use the same marker representation.
    let content = Content::Text {
        text: "Test message".to_string(),
        cache_control: Some(CacheControl::Ephemeral),
    };
    let json_content = serde_json::to_string(&content).unwrap();
    assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

    let mut system_vec = vec![SystemContent::Text {
        text: "System prompt".to_string(),
        cache_control: None,
    }];
    let mut messages = vec![
        Message {
            role: Role::User,
            content: OneOrMany::one(Content::Text {
                text: "First message".to_string(),
                cache_control: None,
            }),
        },
        Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::Text {
                text: "Response".to_string(),
                cache_control: None,
            }),
        },
    ];

    apply_cache_control(&mut system_vec, &mut messages);

    // The system block must now carry a marker.
    let Some(SystemContent::Text { cache_control, .. }) = system_vec.first() else {
        panic!("Expected a text system block after apply_cache_control");
    };
    assert!(
        cache_control.is_some(),
        "system block should carry cache_control"
    );

    // `let … else` (rather than `if let`) makes any non-text block fail the test instead of
    // being skipped, so these checks cannot pass vacuously.
    for content in messages[0].content.iter() {
        let Content::Text { cache_control, .. } = content else {
            panic!("Expected text content in the first message");
        };
        assert!(
            cache_control.is_none(),
            "first message should not carry cache_control"
        );
    }

    for content in messages[1].content.iter() {
        let Content::Text { cache_control, .. } = content else {
            panic!("Expected text content in the last message");
        };
        assert!(
            cache_control.is_some(),
            "last message should carry cache_control"
        );
    }
}
1453}