1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
/// Model identifier string `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model identifier string `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model identifier string `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model identifier string `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model identifier string `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

// NOTE(review): the version constants below are not referenced in this part of
// the file; presumably the client sends one as the API version header — confirm
// in the client module.
/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest version constant defined here ([`ANTHROPIC_VERSION_2023_06_01`]).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Body of a successful Anthropic completion response.
///
/// This is the payload carried by the `Message` variant of `ApiResponse` when the
/// HTTP response body is deserialized in `completion`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model, in response order.
    pub content: Vec<Content>,
    /// Response identifier; reported to telemetry as the response id.
    pub id: String,
    /// Model name reported by the provider; reported to telemetry as the response model.
    pub model: String,
    /// Role string attached to the response (kept as a raw string, not parsed into [`Role`]).
    pub role: String,
    /// Reason generation stopped, if the provider supplied one.
    pub stop_reason: Option<String>,
    /// Stop sequence reported by the provider, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this response.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52 type OutputMessage = Content;
53 type Usage = Usage;
54
55 fn get_response_id(&self) -> Option<String> {
56 Some(self.id.to_owned())
57 }
58
59 fn get_response_model_name(&self) -> Option<String> {
60 Some(self.model.to_owned())
61 }
62
63 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64 self.content.clone()
65 }
66
67 fn get_text_response(&self) -> Option<String> {
68 let res = self
69 .content
70 .iter()
71 .filter_map(|x| {
72 if let Content::Text { text, .. } = x {
73 Some(text.to_owned())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<String>>()
79 .join("\n");
80
81 if res.is_empty() { None } else { Some(res) }
82 }
83
84 fn get_usage(&self) -> Option<Self::Usage> {
85 Some(self.usage.clone())
86 }
87}
88
/// Token counters reported with a completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input token count as reported in the response.
    pub input_tokens: u64,
    /// Input tokens read from the prompt cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output token count as reported in the response.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(
100 f,
101 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102 self.input_tokens,
103 match self.cache_read_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 match self.cache_creation_input_tokens {
108 Some(token) => token.to_string(),
109 None => "n/a".to_string(),
110 },
111 self.output_tokens
112 )
113 }
114}
115
116impl GetTokenUsage for Usage {
117 fn token_usage(&self) -> Option<crate::completion::Usage> {
118 let mut usage = crate::completion::Usage::new();
119
120 usage.input_tokens = self.input_tokens
121 + self.cache_creation_input_tokens.unwrap_or_default()
122 + self.cache_read_input_tokens.unwrap_or_default();
123 usage.output_tokens = self.output_tokens;
124 usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126 Some(usage)
127 }
128}
129
/// A tool description as placed in the `tools` array of the request body.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// Schema of the tool's input, carried as arbitrary JSON.
    pub input_schema: serde_json::Value,
}
136
/// Cache-control marker attachable to system and message content blocks.
///
/// Internally tagged on `type` with snake_case names, so `Ephemeral` is
/// serialized as `{"type": "ephemeral"}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// The only cache-control kind defined here.
    Ephemeral,
}
143
/// One block of the request's `system` array.
///
/// Internally tagged on `type`; `Text` is serialized with `"type": "text"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    /// A text system block.
    Text {
        /// The system text.
        text: String,
        /// Optional cache marker; left out of the JSON entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156 type Error = CompletionError;
157
158 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159 let content = response
160 .content
161 .iter()
162 .map(|content| content.clone().try_into())
163 .collect::<Result<Vec<_>, _>>()?;
164
165 let choice = OneOrMany::many(content).map_err(|_| {
166 CompletionError::ResponseError(
167 "Response contained no message or tool call (empty)".to_owned(),
168 )
169 })?;
170
171 let usage = completion::Usage {
172 input_tokens: response.usage.input_tokens,
173 output_tokens: response.usage.output_tokens,
174 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175 };
176
177 Ok(completion::CompletionResponse {
178 choice,
179 usage,
180 raw_response: response,
181 })
182 }
183}
184
/// A single conversation message in Anthropic wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// One or more content blocks. On deserialization a bare JSON string is
    /// also accepted (via `string_or_one_or_many`) and becomes a single text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
191
/// Author of a [`Message`]; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message supplied by the caller.
    User,
    /// Message produced by the model.
    Assistant,
}
198
/// A content block inside a [`Message`] or a completion response.
///
/// Internally tagged on `type` with snake_case variant names (`text`, `image`,
/// `tool_use`, `tool_result`, `document`, `thinking`). Every `cache_control`
/// field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image described by an [`ImageSource`].
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation: call id, tool name and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result for an earlier tool invocation, referenced by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        /// Result payload; a bare JSON string is accepted on deserialization.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Optional error flag; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document described by a [`DocumentSource`].
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Model reasoning text with an optional signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
}
237
238impl FromStr for Content {
239 type Err = Infallible;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 Ok(Content::Text {
243 text: s.to_owned(),
244 cache_control: None,
245 })
246 }
247}
248
/// One piece of a tool result: text or an image.
///
/// Internally tagged on `type` with snake_case names (`text`, `image`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Textual tool output.
    Text { text: String },
    /// Image tool output.
    // NOTE(review): as a newtype variant of an internally tagged enum, the
    // `ImageSource` fields are serialized alongside the `type` tag, and
    // `ImageSource` has its own `type` field — confirm the resulting JSON shape.
    Image(ImageSource),
}
255
256impl FromStr for ToolResultContent {
257 type Err = Infallible;
258
259 fn from_str(s: &str) -> Result<Self, Self::Err> {
260 Ok(ToolResultContent::Text { text: s.to_owned() })
261 }
262}
263
/// Payload of an image source: either base64 data or a URL, both held as strings.
///
/// The enum is `untagged`, so both variants serialize as a bare JSON string.
// NOTE(review): serde tries untagged variants in declaration order, and both
// variants wrap a `String`, so deserialization should always produce `Base64`
// and never `Url` — confirm this is acceptable for callers that round-trip.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ImageSourceData {
    /// Base64-encoded image bytes.
    Base64(String),
    /// Image URL.
    Url(String),
}
270
271impl From<ImageSourceData> for DocumentSourceKind {
272 fn from(value: ImageSourceData) -> Self {
273 match value {
274 ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
275 ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
276 }
277 }
278}
279
280impl TryFrom<DocumentSourceKind> for ImageSourceData {
281 type Error = MessageError;
282
283 fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
284 match value {
285 DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
286 DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
287 _ => Err(MessageError::ConversionError("Content has no body".into())),
288 }
289 }
290}
291
292impl From<ImageSourceData> for String {
293 fn from(value: ImageSourceData) -> Self {
294 match value {
295 ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
296 }
297 }
298}
299
/// Source descriptor of an image content block.
// NOTE(review): URL sources are serialized with the same `data`/`media_type`
// fields as base64 sources — confirm this matches the field names the
// Anthropic Messages API expects for URL image sources.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// The image payload (base64 data or URL).
    pub data: ImageSourceData,
    /// MIME type of the image.
    pub media_type: ImageFormat,
    /// How `data` is to be interpreted; serialized under the JSON key `type`.
    pub r#type: SourceType,
}
306
/// Source descriptor of a document content block.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// The document payload as a string.
    pub data: String,
    /// MIME type of the document.
    pub media_type: DocumentFormat,
    /// How `data` is to be interpreted; serialized under the JSON key `type`.
    pub r#type: SourceType,
}
313
/// Image MIME types representable in an [`ImageSource`].
///
/// Each variant has an explicit `rename`, so the serialized value is the full
/// MIME string (for example `"image/jpeg"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    /// `image/jpeg`
    #[serde(rename = "image/jpeg")]
    JPEG,
    /// `image/png`
    #[serde(rename = "image/png")]
    PNG,
    /// `image/gif`
    #[serde(rename = "image/gif")]
    GIF,
    /// `image/webp`
    #[serde(rename = "image/webp")]
    WEBP,
}
326
/// Document MIME types representable in a [`DocumentSource`]; only PDF is defined.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    /// `application/pdf`
    #[serde(rename = "application/pdf")]
    PDF,
}
336
/// How a source payload is encoded; serialized in lowercase (`"base64"` / `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    /// The payload is base64-encoded data.
    BASE64,
    /// The payload is a URL.
    URL,
}
343
344impl From<String> for Content {
345 fn from(text: String) -> Self {
346 Content::Text {
347 text,
348 cache_control: None,
349 }
350 }
351}
352
353impl From<String> for ToolResultContent {
354 fn from(text: String) -> Self {
355 ToolResultContent::Text { text }
356 }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360 type Error = MessageError;
361
362 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363 match format {
364 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365 message::ContentFormat::Url => Ok(SourceType::URL),
366 message::ContentFormat::String => Err(MessageError::ConversionError(
367 "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368 )),
369 }
370 }
371}
372
373impl From<SourceType> for message::ContentFormat {
374 fn from(source_type: SourceType) -> Self {
375 match source_type {
376 SourceType::BASE64 => message::ContentFormat::Base64,
377 SourceType::URL => message::ContentFormat::Url,
378 }
379 }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383 type Error = MessageError;
384
385 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386 Ok(match media_type {
387 message::ImageMediaType::JPEG => ImageFormat::JPEG,
388 message::ImageMediaType::PNG => ImageFormat::PNG,
389 message::ImageMediaType::GIF => ImageFormat::GIF,
390 message::ImageMediaType::WEBP => ImageFormat::WEBP,
391 _ => {
392 return Err(MessageError::ConversionError(
393 format!("Unsupported image media type: {media_type:?}").to_owned(),
394 ));
395 }
396 })
397 }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401 fn from(format: ImageFormat) -> Self {
402 match format {
403 ImageFormat::JPEG => message::ImageMediaType::JPEG,
404 ImageFormat::PNG => message::ImageMediaType::PNG,
405 ImageFormat::GIF => message::ImageMediaType::GIF,
406 ImageFormat::WEBP => message::ImageMediaType::WEBP,
407 }
408 }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412 type Error = MessageError;
413 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414 if !matches!(value, DocumentMediaType::PDF) {
415 return Err(MessageError::ConversionError(
416 "Anthropic only supports PDF documents".to_string(),
417 ));
418 };
419
420 Ok(DocumentFormat::PDF)
421 }
422}
423
424impl TryFrom<message::AssistantContent> for Content {
425 type Error = MessageError;
426 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
427 match text {
428 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
429 text,
430 cache_control: None,
431 }),
432 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
433 "Anthropic currently doesn't support images.".to_string(),
434 )),
435 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
436 Ok(Content::ToolUse {
437 id,
438 name: function.name,
439 input: function.arguments,
440 })
441 }
442 message::AssistantContent::Reasoning(Reasoning {
443 reasoning,
444 signature,
445 ..
446 }) => Ok(Content::Thinking {
447 thinking: reasoning.first().cloned().unwrap_or(String::new()),
448 signature,
449 }),
450 }
451 }
452}
453
/// Converts a crate-wide [`message::Message`] into the Anthropic wire [`Message`].
///
/// User content is converted block by block in the closure below; assistant
/// content is delegated to `TryFrom<message::AssistantContent> for Content`.
/// Every produced block has `cache_control: None`.
///
/// # Errors
/// Returns a `MessageError::ConversionError` for user audio or video, for
/// images or documents without a media type, for image/document payload kinds
/// that are not handled below, and for anything the nested conversions reject.
impl TryFrom<message::Message> for Message {
    type Error = MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        Ok(match message {
            message::Message::User { content } => Message {
                role: Role::User,
                content: content.try_map(|content| match content {
                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
                        text,
                        cache_control: None,
                    }),
                    message::UserContent::ToolResult(message::ToolResult {
                        id, content, ..
                    }) => Ok(Content::ToolResult {
                        tool_use_id: id,
                        content: content.try_map(|content| match content {
                            message::ToolResultContent::Text(message::Text { text }) => {
                                Ok(ToolResultContent::Text { text })
                            }
                            message::ToolResultContent::Image(image) => {
                                // Tool-result images are accepted only as base64 payloads.
                                let DocumentSourceKind::Base64(data) = image.data else {
                                    return Err(MessageError::ConversionError(
                                        "Only base64 strings can be used with the Anthropic API"
                                            .to_string(),
                                    ));
                                };
                                let media_type =
                                    image.media_type.ok_or(MessageError::ConversionError(
                                        "Image media type is required".to_owned(),
                                    ))?;
                                Ok(ToolResultContent::Image(ImageSource {
                                    data: ImageSourceData::Base64(data),
                                    media_type: media_type.try_into()?,
                                    r#type: SourceType::BASE64,
                                }))
                            }
                        })?,
                        // The error flag is never set by this conversion.
                        is_error: None,
                        cache_control: None,
                    }),
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        let media_type = media_type.ok_or(MessageError::ConversionError(
                            "Image media type is required for Claude API".to_string(),
                        ))?;

                        // User images may be base64 or URL; the source type follows the payload kind.
                        let source = match data {
                            DocumentSourceKind::Base64(data) => ImageSource {
                                data: ImageSourceData::Base64(data),
                                r#type: SourceType::BASE64,
                                media_type: ImageFormat::try_from(media_type)?,
                            },
                            DocumentSourceKind::Url(url) => ImageSource {
                                data: ImageSourceData::Url(url),
                                r#type: SourceType::URL,
                                media_type: ImageFormat::try_from(media_type)?,
                            },
                            DocumentSourceKind::Unknown => {
                                return Err(MessageError::ConversionError(
                                    "Image content has no body".into(),
                                ));
                            }
                            // Any remaining payload kind is rejected, with its debug form in the message.
                            doc => {
                                return Err(MessageError::ConversionError(format!(
                                    "Unsupported document type: {doc:?}"
                                )));
                            }
                        };

                        Ok(Content::Image {
                            source,
                            cache_control: None,
                        })
                    }
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        let media_type = media_type.ok_or(MessageError::ConversionError(
                            "Document media type is required".to_string(),
                        ))?;

                        // Both `Base64` and `String` payloads are forwarded unchanged and
                        // labelled as base64 below.
                        let data = match data {
                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
                                data
                            }
                            _ => {
                                return Err(MessageError::ConversionError(
                                    "Only base64 encoded documents currently supported".into(),
                                ));
                            }
                        };

                        let source = DocumentSource {
                            data,
                            media_type: media_type.try_into()?,
                            r#type: SourceType::BASE64,
                        };
                        Ok(Content::Document {
                            source,
                            cache_control: None,
                        })
                    }
                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
                        "Audio is not supported in Anthropic".to_owned(),
                    )),
                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
                        "Video is not supported in Anthropic".to_owned(),
                    )),
                })?,
            },

            message::Message::Assistant { content, .. } => Message {
                content: content.try_map(|content| content.try_into())?,
                role: Role::Assistant,
            },
        })
    }
}
574
575impl TryFrom<Content> for message::AssistantContent {
576 type Error = MessageError;
577
578 fn try_from(content: Content) -> Result<Self, Self::Error> {
579 Ok(match content {
580 Content::Text { text, .. } => message::AssistantContent::text(text),
581 Content::ToolUse { id, name, input } => {
582 message::AssistantContent::tool_call(id, name, input)
583 }
584 Content::Thinking {
585 thinking,
586 signature,
587 } => message::AssistantContent::Reasoning(
588 Reasoning::new(&thinking).with_signature(signature),
589 ),
590 _ => {
591 return Err(MessageError::ConversionError(
592 "Content did not contain a message, tool call, or reasoning".to_owned(),
593 ));
594 }
595 })
596 }
597}
598
599impl From<ToolResultContent> for message::ToolResultContent {
600 fn from(content: ToolResultContent) -> Self {
601 match content {
602 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
603 ToolResultContent::Image(ImageSource {
604 data,
605 media_type: format,
606 ..
607 }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
608 }
609 }
610}
611
612impl TryFrom<Message> for message::Message {
613 type Error = MessageError;
614
615 fn try_from(message: Message) -> Result<Self, Self::Error> {
616 Ok(match message.role {
617 Role::User => message::Message::User {
618 content: message.content.try_map(|content| {
619 Ok(match content {
620 Content::Text { text, .. } => message::UserContent::text(text),
621 Content::ToolResult {
622 tool_use_id,
623 content,
624 ..
625 } => message::UserContent::tool_result(
626 tool_use_id,
627 content.map(|content| content.into()),
628 ),
629 Content::Image { source, .. } => {
630 message::UserContent::Image(message::Image {
631 data: source.data.into(),
632 media_type: Some(source.media_type.into()),
633 detail: None,
634 additional_params: None,
635 })
636 }
637 Content::Document { source, .. } => message::UserContent::document(
638 source.data,
639 Some(message::DocumentMediaType::PDF),
640 ),
641 _ => {
642 return Err(MessageError::ConversionError(
643 "Unsupported content type for User role".to_owned(),
644 ));
645 }
646 })
647 })?,
648 },
649 Role::Assistant => match message.content.first() {
650 Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
651 message::Message::Assistant {
652 id: None,
653 content: message.content.try_map(|content| content.try_into())?,
654 }
655 }
656
657 _ => {
658 return Err(MessageError::ConversionError(
659 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
660 ));
661 }
662 },
663 })
664 }
665}
666
/// Completion model handle bound to an Anthropic [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name placed in each request.
    pub model: String,
    /// `max_tokens` used when a request does not set its own; when this is
    /// also `None`, `completion` rejects such requests.
    pub default_max_tokens: Option<u64>,
    /// When `true`, cache-control markers are applied while building the request.
    pub prompt_caching: bool,
}
675
676impl<T> CompletionModel<T>
677where
678 T: HttpClientExt,
679{
680 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
681 let model = model.into();
682 let default_max_tokens = calculate_max_tokens(&model);
683
684 Self {
685 client,
686 model,
687 default_max_tokens,
688 prompt_caching: false, }
690 }
691
692 pub fn with_model(client: Client<T>, model: &str) -> Self {
693 Self {
694 client,
695 model: model.to_string(),
696 default_max_tokens: Some(calculate_max_tokens_custom(model)),
697 prompt_caching: false, }
699 }
700
701 pub fn with_prompt_caching(mut self) -> Self {
709 self.prompt_caching = true;
710 self
711 }
712}
713
/// Returns the default `max_tokens` for a known model family, or `None` when
/// the model name matches none of the known prefixes.
///
/// Matching is by name prefix, so dated or `-latest` variants of a family
/// resolve to the same limit.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Single source of truth for the per-family defaults: (name prefixes, limit).
    const LIMITS: &[(&[&str], u64)] = &[
        (&["claude-opus-4"], 32000),
        (&["claude-sonnet-4", "claude-3-7-sonnet"], 64000),
        (&["claude-3-5-sonnet", "claude-3-5-haiku"], 8192),
        (&["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"], 4096),
    ];

    LIMITS
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|prefix| model.starts_with(*prefix)))
        .map(|&(_, limit)| limit)
}

/// Returns the default `max_tokens` for `model`, falling back to `2048` for
/// model names that [`calculate_max_tokens`] does not recognize.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    // Delegating keeps the two lookups from drifting apart.
    calculate_max_tokens(model).unwrap_or(2048)
}
750
/// Request metadata carrying an optional user identifier.
///
/// Not referenced elsewhere in the visible portion of this file.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    /// Optional user identifier.
    user_id: Option<String>,
}
755
/// Tool-selection setting placed in the request's `tool_choice` field.
///
/// Internally tagged on `type` with snake_case names (`auto`, `any`, `none`,
/// `tool`). Defaults to [`ToolChoice::Auto`].
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// Counterpart of `message::ToolChoice::Auto`.
    #[default]
    Auto,
    /// Counterpart of `message::ToolChoice::Required`.
    Any,
    /// Counterpart of `message::ToolChoice::None`.
    None,
    /// Selects one specific tool by name.
    Tool {
        name: String,
    },
}
767impl TryFrom<message::ToolChoice> for ToolChoice {
768 type Error = CompletionError;
769
770 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
771 let res = match value {
772 message::ToolChoice::Auto => Self::Auto,
773 message::ToolChoice::None => Self::None,
774 message::ToolChoice::Required => Self::Any,
775 message::ToolChoice::Specific { function_names } => {
776 if function_names.len() != 1 {
777 return Err(CompletionError::ProviderError(
778 "Only one tool may be specified to be used by Claude".into(),
779 ));
780 }
781
782 Self::Tool {
783 name: function_names.first().unwrap().to_string(),
784 }
785 }
786 };
787
788 Ok(res)
789 }
790}
791
/// JSON body serialized and sent to `/v1/messages` by `completion`.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    /// Model name.
    model: String,
    /// Conversation messages in wire format.
    messages: Vec<Message>,
    /// Maximum token count for the request; always present.
    max_tokens: u64,
    /// System blocks; the field is left out of the JSON when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool-selection setting; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Tool definitions; the field is left out of the JSON when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Extra caller-supplied JSON; flattened, so its keys are merged into the
    /// top level of the request body. Omitted when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
809
810fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
812 match content {
813 Content::Text { cache_control, .. } => *cache_control = value,
814 Content::Image { cache_control, .. } => *cache_control = value,
815 Content::ToolResult { cache_control, .. } => *cache_control = value,
816 Content::Document { cache_control, .. } => *cache_control = value,
817 _ => {}
818 }
819}
820
821pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
826 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
828 *cache_control = Some(CacheControl::Ephemeral);
829 }
830
831 for msg in messages.iter_mut() {
833 for content in msg.content.iter_mut() {
834 set_content_cache_control(content, None);
835 }
836 }
837
838 if let Some(last_msg) = messages.last_mut() {
840 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
841 }
842}
843
/// Inputs needed to build an `AnthropicCompletionRequest`.
pub struct AnthropicRequestParams<'a> {
    /// Model name to place in the request.
    pub model: &'a str,
    /// The crate-wide completion request to translate.
    pub request: CompletionRequest,
    /// Whether cache-control markers should be applied to the built request.
    pub prompt_caching: bool,
}
850
851impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
852 type Error = CompletionError;
853
854 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
855 let AnthropicRequestParams {
856 model,
857 request: req,
858 prompt_caching,
859 } = params;
860
861 let Some(max_tokens) = req.max_tokens else {
863 return Err(CompletionError::RequestError(
864 "`max_tokens` must be set for Anthropic".into(),
865 ));
866 };
867
868 let mut full_history = vec![];
869 if let Some(docs) = req.normalized_documents() {
870 full_history.push(docs);
871 }
872 full_history.extend(req.chat_history);
873
874 let mut messages = full_history
875 .into_iter()
876 .map(Message::try_from)
877 .collect::<Result<Vec<Message>, _>>()?;
878
879 let tools = req
880 .tools
881 .into_iter()
882 .map(|tool| ToolDefinition {
883 name: tool.name,
884 description: Some(tool.description),
885 input_schema: tool.parameters,
886 })
887 .collect::<Vec<_>>();
888
889 let mut system = if let Some(preamble) = req.preamble {
891 if preamble.is_empty() {
892 vec![]
893 } else {
894 vec![SystemContent::Text {
895 text: preamble,
896 cache_control: None,
897 }]
898 }
899 } else {
900 vec![]
901 };
902
903 if prompt_caching {
905 apply_cache_control(&mut system, &mut messages);
906 }
907
908 Ok(Self {
909 model: model.to_string(),
910 messages,
911 max_tokens,
912 system,
913 temperature: req.temperature,
914 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
915 tools,
916 additional_params: req.additional_params,
917 })
918 }
919}
920
/// Implements the crate-wide completion trait on top of the Anthropic
/// `/v1/messages` endpoint.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;
    type Client = Client<T>;

    /// Builds a model handle from a borrowed client by cloning it.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends a non-streaming completion request and converts the response.
    ///
    /// # Errors
    /// - `RequestError` when neither the request nor the model supplies `max_tokens`.
    /// - Errors from building or serializing the request body.
    /// - `HttpError` when building the HTTP request, sending it, or reading the body fails.
    /// - `ResponseError` when a successful HTTP response carries an API error payload.
    /// - `ProviderError` with the (lossily decoded) response body on a non-success status.
    async fn completion(
        &self,
        mut completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span when one is active; otherwise open a new
        // "chat" span whose response fields are filled in later.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "anthropic",
                gen_ai.request.model = &self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Fall back to the model's default `max_tokens` when the request has none.
        if completion_request.max_tokens.is_none() {
            if let Some(tokens) = self.default_max_tokens {
                completion_request.max_tokens = Some(tokens);
            } else {
                return Err(CompletionError::RequestError(
                    "`max_tokens` must be set for Anthropic".into(),
                ));
            }
        }

        let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
            model: &self.model,
            request: completion_request,
            prompt_caching: self.prompt_caching,
        })?;

        // Pretty-printing the request is only done when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Anthropic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        async move {
            let request: Vec<u8> = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/messages")?
                .body(request)
                .map_err(|e| CompletionError::HttpError(e.into()))?;

            let response = self
                .client
                .send::<_, Bytes>(req)
                .await
                .map_err(CompletionError::HttpError)?;

            if response.status().is_success() {
                // A successful status can still carry an API error payload;
                // `ApiResponse` distinguishes the two by its `type` tag.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
                    response
                        .into_body()
                        .await
                        .map_err(CompletionError::HttpError)?
                        .to_vec()
                        .as_slice(),
                )? {
                    ApiResponse::Message(completion) => {
                        // This block runs inside `span` (see `.instrument(span)` below),
                        // so the metadata is recorded on that span.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&completion);
                        span.record_token_usage(&completion.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "Anthropic completion response: {}",
                                serde_json::to_string_pretty(&completion)?
                            );
                        }
                        completion.try_into()
                    }
                    ApiResponse::Error(ApiErrorResponse { message }) => {
                        Err(CompletionError::ResponseError(message))
                    }
                }
            } else {
                // Non-success status: return the raw body text as the error.
                let text: String = String::from_utf8_lossy(
                    &response
                        .into_body()
                        .await
                        .map_err(CompletionError::HttpError)?,
                )
                .into();
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Starts a streaming completion.
    // NOTE(review): `CompletionModel::stream` names the struct, so this
    // presumably resolves to an inherent `stream` method defined elsewhere
    // (likely the streaming module) — confirm it exists; without it this call
    // would resolve back to this trait method.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}
1045
/// Error payload of an API response; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, surfaced to callers as `CompletionError::ResponseError`.
    message: String,
}
1050
/// Top-level shape of a response body, discriminated by its `type` field.
///
/// Internally tagged with snake_case names: `"type": "message"` selects
/// `Message`, `"type": "error"` selects `Error`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// A successful payload of type `T`.
    Message(T),
    /// An error payload.
    Error(ApiErrorResponse),
}
1057
1058#[cfg(test)]
1059mod tests {
1060 use super::*;
1061 use serde_json::json;
1062 use serde_path_to_error::deserialize;
1063
1064 #[test]
1065 fn test_deserialize_message() {
1066 let assistant_message_json = r#"
1067 {
1068 "role": "assistant",
1069 "content": "\n\nHello there, how may I assist you today?"
1070 }
1071 "#;
1072
1073 let assistant_message_json2 = r#"
1074 {
1075 "role": "assistant",
1076 "content": [
1077 {
1078 "type": "text",
1079 "text": "\n\nHello there, how may I assist you today?"
1080 },
1081 {
1082 "type": "tool_use",
1083 "id": "toolu_01A09q90qw90lq917835lq9",
1084 "name": "get_weather",
1085 "input": {"location": "San Francisco, CA"}
1086 }
1087 ]
1088 }
1089 "#;
1090
1091 let user_message_json = r#"
1092 {
1093 "role": "user",
1094 "content": [
1095 {
1096 "type": "image",
1097 "source": {
1098 "type": "base64",
1099 "media_type": "image/jpeg",
1100 "data": "/9j/4AAQSkZJRg..."
1101 }
1102 },
1103 {
1104 "type": "text",
1105 "text": "What is in this image?"
1106 },
1107 {
1108 "type": "tool_result",
1109 "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
1110 "content": "15 degrees"
1111 }
1112 ]
1113 }
1114 "#;
1115
1116 let assistant_message: Message = {
1117 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1118 deserialize(jd).unwrap_or_else(|err| {
1119 panic!("Deserialization error at {}: {}", err.path(), err);
1120 })
1121 };
1122
1123 let assistant_message2: Message = {
1124 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1125 deserialize(jd).unwrap_or_else(|err| {
1126 panic!("Deserialization error at {}: {}", err.path(), err);
1127 })
1128 };
1129
1130 let user_message: Message = {
1131 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1132 deserialize(jd).unwrap_or_else(|err| {
1133 panic!("Deserialization error at {}: {}", err.path(), err);
1134 })
1135 };
1136
1137 let Message { role, content } = assistant_message;
1138 assert_eq!(role, Role::Assistant);
1139 assert_eq!(
1140 content.first(),
1141 Content::Text {
1142 text: "\n\nHello there, how may I assist you today?".to_owned(),
1143 cache_control: None,
1144 }
1145 );
1146
1147 let Message { role, content } = assistant_message2;
1148 {
1149 assert_eq!(role, Role::Assistant);
1150 assert_eq!(content.len(), 2);
1151
1152 let mut iter = content.into_iter();
1153
1154 match iter.next().unwrap() {
1155 Content::Text { text, .. } => {
1156 assert_eq!(text, "\n\nHello there, how may I assist you today?");
1157 }
1158 _ => panic!("Expected text content"),
1159 }
1160
1161 match iter.next().unwrap() {
1162 Content::ToolUse { id, name, input } => {
1163 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1164 assert_eq!(name, "get_weather");
1165 assert_eq!(input, json!({"location": "San Francisco, CA"}));
1166 }
1167 _ => panic!("Expected tool use content"),
1168 }
1169
1170 assert_eq!(iter.next(), None);
1171 }
1172
1173 let Message { role, content } = user_message;
1174 {
1175 assert_eq!(role, Role::User);
1176 assert_eq!(content.len(), 3);
1177
1178 let mut iter = content.into_iter();
1179
1180 match iter.next().unwrap() {
1181 Content::Image { source, .. } => {
1182 assert_eq!(
1183 source,
1184 ImageSource {
1185 data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
1186 media_type: ImageFormat::JPEG,
1187 r#type: SourceType::BASE64,
1188 }
1189 );
1190 }
1191 _ => panic!("Expected image content"),
1192 }
1193
1194 match iter.next().unwrap() {
1195 Content::Text { text, .. } => {
1196 assert_eq!(text, "What is in this image?");
1197 }
1198 _ => panic!("Expected text content"),
1199 }
1200
1201 match iter.next().unwrap() {
1202 Content::ToolResult {
1203 tool_use_id,
1204 content,
1205 is_error,
1206 ..
1207 } => {
1208 assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1209 assert_eq!(
1210 content.first(),
1211 ToolResultContent::Text {
1212 text: "15 degrees".to_owned()
1213 }
1214 );
1215 assert_eq!(is_error, None);
1216 }
1217 _ => panic!("Expected tool result content"),
1218 }
1219
1220 assert_eq!(iter.next(), None);
1221 }
1222 }
1223
1224 #[test]
1225 fn test_message_to_message_conversion() {
1226 let user_message: Message = serde_json::from_str(
1227 r#"
1228 {
1229 "role": "user",
1230 "content": [
1231 {
1232 "type": "image",
1233 "source": {
1234 "type": "base64",
1235 "media_type": "image/jpeg",
1236 "data": "/9j/4AAQSkZJRg..."
1237 }
1238 },
1239 {
1240 "type": "text",
1241 "text": "What is in this image?"
1242 },
1243 {
1244 "type": "document",
1245 "source": {
1246 "type": "base64",
1247 "data": "base64_encoded_pdf_data",
1248 "media_type": "application/pdf"
1249 }
1250 }
1251 ]
1252 }
1253 "#,
1254 )
1255 .unwrap();
1256
1257 let assistant_message = Message {
1258 role: Role::Assistant,
1259 content: OneOrMany::one(Content::ToolUse {
1260 id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1261 name: "get_weather".to_string(),
1262 input: json!({"location": "San Francisco, CA"}),
1263 }),
1264 };
1265
1266 let tool_message = Message {
1267 role: Role::User,
1268 content: OneOrMany::one(Content::ToolResult {
1269 tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1270 content: OneOrMany::one(ToolResultContent::Text {
1271 text: "15 degrees".to_string(),
1272 }),
1273 is_error: None,
1274 cache_control: None,
1275 }),
1276 };
1277
1278 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1279 let converted_assistant_message: message::Message =
1280 assistant_message.clone().try_into().unwrap();
1281 let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1282
1283 match converted_user_message.clone() {
1284 message::Message::User { content } => {
1285 assert_eq!(content.len(), 3);
1286
1287 let mut iter = content.into_iter();
1288
1289 match iter.next().unwrap() {
1290 message::UserContent::Image(message::Image {
1291 data, media_type, ..
1292 }) => {
1293 assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1294 assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1295 }
1296 _ => panic!("Expected image content"),
1297 }
1298
1299 match iter.next().unwrap() {
1300 message::UserContent::Text(message::Text { text }) => {
1301 assert_eq!(text, "What is in this image?");
1302 }
1303 _ => panic!("Expected text content"),
1304 }
1305
1306 match iter.next().unwrap() {
1307 message::UserContent::Document(message::Document {
1308 data, media_type, ..
1309 }) => {
1310 assert_eq!(
1311 data,
1312 DocumentSourceKind::String("base64_encoded_pdf_data".into())
1313 );
1314 assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1315 }
1316 _ => panic!("Expected document content"),
1317 }
1318
1319 assert_eq!(iter.next(), None);
1320 }
1321 _ => panic!("Expected user message"),
1322 }
1323
1324 match converted_tool_message.clone() {
1325 message::Message::User { content } => {
1326 let message::ToolResult { id, content, .. } = match content.first() {
1327 message::UserContent::ToolResult(tool_result) => tool_result,
1328 _ => panic!("Expected tool result content"),
1329 };
1330 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1331 match content.first() {
1332 message::ToolResultContent::Text(message::Text { text }) => {
1333 assert_eq!(text, "15 degrees");
1334 }
1335 _ => panic!("Expected text content"),
1336 }
1337 }
1338 _ => panic!("Expected tool result content"),
1339 }
1340
1341 match converted_assistant_message.clone() {
1342 message::Message::Assistant { content, .. } => {
1343 assert_eq!(content.len(), 1);
1344
1345 match content.first() {
1346 message::AssistantContent::ToolCall(message::ToolCall {
1347 id, function, ..
1348 }) => {
1349 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1350 assert_eq!(function.name, "get_weather");
1351 assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1352 }
1353 _ => panic!("Expected tool call content"),
1354 }
1355 }
1356 _ => panic!("Expected assistant message"),
1357 }
1358
1359 let original_user_message: Message = converted_user_message.try_into().unwrap();
1360 let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1361 let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1362
1363 assert_eq!(user_message, original_user_message);
1364 assert_eq!(assistant_message, original_assistant_message);
1365 assert_eq!(tool_message, original_tool_message);
1366 }
1367
1368 #[test]
1369 fn test_content_format_conversion() {
1370 use crate::completion::message::ContentFormat;
1371
1372 let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1373 assert_eq!(source_type, SourceType::URL);
1374
1375 let content_format: ContentFormat = SourceType::URL.into();
1376 assert_eq!(content_format, ContentFormat::Url);
1377
1378 let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1379 assert_eq!(source_type, SourceType::BASE64);
1380
1381 let content_format: ContentFormat = SourceType::BASE64.into();
1382 assert_eq!(content_format, ContentFormat::Base64);
1383
1384 let result: Result<SourceType, _> = ContentFormat::String.try_into();
1385 assert!(result.is_err());
1386 assert!(
1387 result
1388 .unwrap_err()
1389 .to_string()
1390 .contains("ContentFormat::String is deprecated")
1391 );
1392 }
1393
1394 #[test]
1395 fn test_cache_control_serialization() {
1396 let system = SystemContent::Text {
1398 text: "You are a helpful assistant.".to_string(),
1399 cache_control: Some(CacheControl::Ephemeral),
1400 };
1401 let json = serde_json::to_string(&system).unwrap();
1402 assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
1403 assert!(json.contains(r#""type":"text""#));
1404
1405 let system_no_cache = SystemContent::Text {
1407 text: "Hello".to_string(),
1408 cache_control: None,
1409 };
1410 let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
1411 assert!(!json_no_cache.contains("cache_control"));
1412
1413 let content = Content::Text {
1415 text: "Test message".to_string(),
1416 cache_control: Some(CacheControl::Ephemeral),
1417 };
1418 let json_content = serde_json::to_string(&content).unwrap();
1419 assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));
1420
1421 let mut system_vec = vec![SystemContent::Text {
1423 text: "System prompt".to_string(),
1424 cache_control: None,
1425 }];
1426 let mut messages = vec![
1427 Message {
1428 role: Role::User,
1429 content: OneOrMany::one(Content::Text {
1430 text: "First message".to_string(),
1431 cache_control: None,
1432 }),
1433 },
1434 Message {
1435 role: Role::Assistant,
1436 content: OneOrMany::one(Content::Text {
1437 text: "Response".to_string(),
1438 cache_control: None,
1439 }),
1440 },
1441 ];
1442
1443 apply_cache_control(&mut system_vec, &mut messages);
1444
1445 match &system_vec[0] {
1447 SystemContent::Text { cache_control, .. } => {
1448 assert!(cache_control.is_some());
1449 }
1450 }
1451
1452 for content in messages[0].content.iter() {
1455 if let Content::Text { cache_control, .. } = content {
1456 assert!(cache_control.is_none());
1457 }
1458 }
1459
1460 for content in messages[1].content.iter() {
1462 if let Content::Text { cache_control, .. } = content {
1463 assert!(cache_control.is_some());
1464 }
1465 }
1466 }
1467}