1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
/// Model alias `claude-opus-4-0` (Claude Opus 4).
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model alias `claude-sonnet-4-0` (Claude Sonnet 4).
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model alias `claude-3-7-sonnet-latest` (Claude 3.7 Sonnet).
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model alias `claude-3-5-sonnet-latest` (Claude 3.5 Sonnet).
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model alias `claude-3-5-haiku-latest` (Claude 3.5 Haiku).
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version identifier `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version identifier `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The version this module treats as "latest"; currently an alias for
/// [`ANTHROPIC_VERSION_2023_06_01`].
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Raw, non-streaming response body from the Anthropic messages endpoint.
///
/// `CompletionModel::completion` deserializes it from the `"type": "message"`
/// case of the response envelope.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Ordered content blocks (text, tool use, thinking, ...) produced by the model.
    pub content: Vec<Content>,
    /// Response identifier.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// Role of the message author, kept as the raw string from the API.
    pub role: String,
    /// Stop reason reported by the API, if any.
    pub stop_reason: Option<String>,
    /// Stop sequence reported by the API, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52 type OutputMessage = Content;
53 type Usage = Usage;
54
55 fn get_response_id(&self) -> Option<String> {
56 Some(self.id.to_owned())
57 }
58
59 fn get_response_model_name(&self) -> Option<String> {
60 Some(self.model.to_owned())
61 }
62
63 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64 self.content.clone()
65 }
66
67 fn get_text_response(&self) -> Option<String> {
68 let res = self
69 .content
70 .iter()
71 .filter_map(|x| {
72 if let Content::Text { text, .. } = x {
73 Some(text.to_owned())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<String>>()
79 .join("\n");
80
81 if res.is_empty() { None } else { Some(res) }
82 }
83
84 fn get_usage(&self) -> Option<Self::Usage> {
85 Some(self.usage.clone())
86 }
87}
88
/// Token usage as reported by Anthropic for a single request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Value of the API's `input_tokens` field. The two cache counters below are
    /// reported separately; `token_usage` adds them on top of this value.
    pub input_tokens: u64,
    /// Input tokens read from the prompt cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output tokens generated by the model.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(
100 f,
101 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102 self.input_tokens,
103 match self.cache_read_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 match self.cache_creation_input_tokens {
108 Some(token) => token.to_string(),
109 None => "n/a".to_string(),
110 },
111 self.output_tokens
112 )
113 }
114}
115
116impl GetTokenUsage for Usage {
117 fn token_usage(&self) -> Option<crate::completion::Usage> {
118 let mut usage = crate::completion::Usage::new();
119
120 usage.input_tokens = self.input_tokens
121 + self.cache_creation_input_tokens.unwrap_or_default()
122 + self.cache_read_input_tokens.unwrap_or_default();
123 usage.output_tokens = self.output_tokens;
124 usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126 Some(usage)
127 }
128}
129
/// A tool advertised to the model in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Optional description. Serialized as `null` when `None` (there is no
    /// `skip_serializing_if` on this field).
    pub description: Option<String>,
    /// JSON schema for the tool input, taken from the rig tool's `parameters`.
    pub input_schema: serde_json::Value,
}
136
/// Prompt-caching marker attached to system and message content blocks.
///
/// Serialized as `{"type": "ephemeral"}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
143
/// A block of the request's top-level `system` prompt.
///
/// Serialized as `{"type": "text", "text": ..., "cache_control": ...}`;
/// `cache_control` is omitted when `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156 type Error = CompletionError;
157
158 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159 let content = response
160 .content
161 .iter()
162 .map(|content| content.clone().try_into())
163 .collect::<Result<Vec<_>, _>>()?;
164
165 let choice = OneOrMany::many(content).map_err(|_| {
166 CompletionError::ResponseError(
167 "Response contained no message or tool call (empty)".to_owned(),
168 )
169 })?;
170
171 let usage = completion::Usage {
172 input_tokens: response.usage.input_tokens,
173 output_tokens: response.usage.output_tokens,
174 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175 cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
176 };
177
178 Ok(completion::CompletionResponse {
179 choice,
180 usage,
181 raw_response: response,
182 })
183 }
184}
185
/// A single conversation turn in Anthropic's wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the turn.
    pub role: Role,
    /// Content blocks. Deserialization also accepts a bare string, which becomes
    /// a single uncached text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
192
/// Author of a [`Message`]; serialized as `"user"` or `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
199
/// A content block inside a [`Message`] or a [`CompletionResponse`].
///
/// Internally tagged by `"type"` with snake_case variant names (`"text"`,
/// `"image"`, `"tool_use"`, `"tool_result"`, `"document"`, `"thinking"`).
/// Optional fields marked `skip_serializing_if` are omitted when `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image described by an [`ImageSource`].
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation; `input` carries the call's JSON arguments.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool call, linked to the call through `tool_use_id`.
    /// `content` also deserializes from a bare string.
    ToolResult {
        tool_use_id: String,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document; only PDF is representable (see [`DocumentFormat`]).
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Model reasoning text, with its signature when one was returned.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
}
238
239impl FromStr for Content {
240 type Err = Infallible;
241
242 fn from_str(s: &str) -> Result<Self, Self::Err> {
243 Ok(Content::Text {
244 text: s.to_owned(),
245 cache_control: None,
246 })
247 }
248}
249
250#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
251#[serde(tag = "type", rename_all = "snake_case")]
252pub enum ToolResultContent {
253 Text { text: String },
254 Image(ImageSource),
255}
256
257impl FromStr for ToolResultContent {
258 type Err = Infallible;
259
260 fn from_str(s: &str) -> Result<Self, Self::Err> {
261 Ok(ToolResultContent::Text { text: s.to_owned() })
262 }
263}
264
/// Payload of an [`ImageSource`]: base64-encoded data or a URL.
///
/// Serialized `untagged`, i.e. as a bare JSON string.
///
/// NOTE(review): both variants wrap a `String` and serde tries untagged variants
/// in order, so deserialization always yields `Base64`. `Url` is only produced
/// by Rust code, and on the wire the two are told apart only by the sibling
/// `SourceType` field of `ImageSource`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ImageSourceData {
    Base64(String),
    Url(String),
}
271
272impl From<ImageSourceData> for DocumentSourceKind {
273 fn from(value: ImageSourceData) -> Self {
274 match value {
275 ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
276 ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
277 }
278 }
279}
280
281impl TryFrom<DocumentSourceKind> for ImageSourceData {
282 type Error = MessageError;
283
284 fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
285 match value {
286 DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
287 DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
288 _ => Err(MessageError::ConversionError("Content has no body".into())),
289 }
290 }
291}
292
293impl From<ImageSourceData> for String {
294 fn from(value: ImageSourceData) -> Self {
295 match value {
296 ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
297 }
298 }
299}
300
/// Source descriptor of an image block, serialized as
/// `{"data": ..., "media_type": ..., "type": "base64" | "url"}`.
///
/// NOTE(review): URL images are serialized with the URL under `data` plus a
/// `media_type`. Anthropic documents the URL source shape as
/// `{"type": "url", "url": ...}`; confirm the API accepts this form.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Base64 data or URL.
    pub data: ImageSourceData,
    /// Image MIME type.
    pub media_type: ImageFormat,
    /// Whether `data` is base64 or a URL; serialized as `"type"`.
    pub r#type: SourceType,
}
307
/// Source descriptor of a document block, serialized as
/// `{"data": ..., "media_type": ..., "type": ...}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Document payload. Conversions from rig messages always label it base64.
    pub data: String,
    /// Document MIME type (PDF only).
    pub media_type: DocumentFormat,
    /// Encoding of `data`; serialized as `"type"`.
    pub r#type: SourceType,
}
314
/// Image MIME types accepted in an [`ImageSource`].
///
/// Every variant has an explicit `rename` to its MIME string, which overrides
/// the enum-level `rename_all = "lowercase"`, so that attribute has no effect.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
327
/// Document MIME types accepted in a [`DocumentSource`]; only PDF
/// (`"application/pdf"`) is supported.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
337
/// Encoding of an image/document payload; serialized as `"base64"` or `"url"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
}
344
345impl From<String> for Content {
346 fn from(text: String) -> Self {
347 Content::Text {
348 text,
349 cache_control: None,
350 }
351 }
352}
353
354impl From<String> for ToolResultContent {
355 fn from(text: String) -> Self {
356 ToolResultContent::Text { text }
357 }
358}
359
360impl TryFrom<message::ContentFormat> for SourceType {
361 type Error = MessageError;
362
363 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
364 match format {
365 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
366 message::ContentFormat::Url => Ok(SourceType::URL),
367 message::ContentFormat::String => Err(MessageError::ConversionError(
368 "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
369 )),
370 }
371 }
372}
373
374impl From<SourceType> for message::ContentFormat {
375 fn from(source_type: SourceType) -> Self {
376 match source_type {
377 SourceType::BASE64 => message::ContentFormat::Base64,
378 SourceType::URL => message::ContentFormat::Url,
379 }
380 }
381}
382
383impl TryFrom<message::ImageMediaType> for ImageFormat {
384 type Error = MessageError;
385
386 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
387 Ok(match media_type {
388 message::ImageMediaType::JPEG => ImageFormat::JPEG,
389 message::ImageMediaType::PNG => ImageFormat::PNG,
390 message::ImageMediaType::GIF => ImageFormat::GIF,
391 message::ImageMediaType::WEBP => ImageFormat::WEBP,
392 _ => {
393 return Err(MessageError::ConversionError(
394 format!("Unsupported image media type: {media_type:?}").to_owned(),
395 ));
396 }
397 })
398 }
399}
400
401impl From<ImageFormat> for message::ImageMediaType {
402 fn from(format: ImageFormat) -> Self {
403 match format {
404 ImageFormat::JPEG => message::ImageMediaType::JPEG,
405 ImageFormat::PNG => message::ImageMediaType::PNG,
406 ImageFormat::GIF => message::ImageMediaType::GIF,
407 ImageFormat::WEBP => message::ImageMediaType::WEBP,
408 }
409 }
410}
411
412impl TryFrom<DocumentMediaType> for DocumentFormat {
413 type Error = MessageError;
414 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
415 if !matches!(value, DocumentMediaType::PDF) {
416 return Err(MessageError::ConversionError(
417 "Anthropic only supports PDF documents".to_string(),
418 ));
419 };
420
421 Ok(DocumentFormat::PDF)
422 }
423}
424
425impl TryFrom<message::AssistantContent> for Content {
426 type Error = MessageError;
427 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
428 match text {
429 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
430 text,
431 cache_control: None,
432 }),
433 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
434 "Anthropic currently doesn't support images.".to_string(),
435 )),
436 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
437 Ok(Content::ToolUse {
438 id,
439 name: function.name,
440 input: function.arguments,
441 })
442 }
443 message::AssistantContent::Reasoning(Reasoning {
444 reasoning,
445 signature,
446 ..
447 }) => Ok(Content::Thinking {
448 thinking: reasoning.first().cloned().unwrap_or(String::new()),
449 signature,
450 }),
451 }
452 }
453}
454
455impl TryFrom<message::Message> for Message {
456 type Error = MessageError;
457
458 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
459 Ok(match message {
460 message::Message::User { content } => Message {
461 role: Role::User,
462 content: content.try_map(|content| match content {
463 message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
464 text,
465 cache_control: None,
466 }),
467 message::UserContent::ToolResult(message::ToolResult {
468 id, content, ..
469 }) => Ok(Content::ToolResult {
470 tool_use_id: id,
471 content: content.try_map(|content| match content {
472 message::ToolResultContent::Text(message::Text { text }) => {
473 Ok(ToolResultContent::Text { text })
474 }
475 message::ToolResultContent::Image(image) => {
476 let DocumentSourceKind::Base64(data) = image.data else {
477 return Err(MessageError::ConversionError(
478 "Only base64 strings can be used with the Anthropic API"
479 .to_string(),
480 ));
481 };
482 let media_type =
483 image.media_type.ok_or(MessageError::ConversionError(
484 "Image media type is required".to_owned(),
485 ))?;
486 Ok(ToolResultContent::Image(ImageSource {
487 data: ImageSourceData::Base64(data),
488 media_type: media_type.try_into()?,
489 r#type: SourceType::BASE64,
490 }))
491 }
492 })?,
493 is_error: None,
494 cache_control: None,
495 }),
496 message::UserContent::Image(message::Image {
497 data, media_type, ..
498 }) => {
499 let media_type = media_type.ok_or(MessageError::ConversionError(
500 "Image media type is required for Claude API".to_string(),
501 ))?;
502
503 let source = match data {
504 DocumentSourceKind::Base64(data) => ImageSource {
505 data: ImageSourceData::Base64(data),
506 r#type: SourceType::BASE64,
507 media_type: ImageFormat::try_from(media_type)?,
508 },
509 DocumentSourceKind::Url(url) => ImageSource {
510 data: ImageSourceData::Url(url),
511 r#type: SourceType::URL,
512 media_type: ImageFormat::try_from(media_type)?,
513 },
514 DocumentSourceKind::Unknown => {
515 return Err(MessageError::ConversionError(
516 "Image content has no body".into(),
517 ));
518 }
519 doc => {
520 return Err(MessageError::ConversionError(format!(
521 "Unsupported document type: {doc:?}"
522 )));
523 }
524 };
525
526 Ok(Content::Image {
527 source,
528 cache_control: None,
529 })
530 }
531 message::UserContent::Document(message::Document {
532 data, media_type, ..
533 }) => {
534 let media_type = media_type.ok_or(MessageError::ConversionError(
535 "Document media type is required".to_string(),
536 ))?;
537
538 let data = match data {
539 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
540 data
541 }
542 _ => {
543 return Err(MessageError::ConversionError(
544 "Only base64 encoded documents currently supported".into(),
545 ));
546 }
547 };
548
549 let source = DocumentSource {
550 data,
551 media_type: media_type.try_into()?,
552 r#type: SourceType::BASE64,
553 };
554 Ok(Content::Document {
555 source,
556 cache_control: None,
557 })
558 }
559 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
560 "Audio is not supported in Anthropic".to_owned(),
561 )),
562 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
563 "Video is not supported in Anthropic".to_owned(),
564 )),
565 })?,
566 },
567
568 message::Message::Assistant { content, .. } => Message {
569 content: content.try_map(|content| content.try_into())?,
570 role: Role::Assistant,
571 },
572 })
573 }
574}
575
576impl TryFrom<Content> for message::AssistantContent {
577 type Error = MessageError;
578
579 fn try_from(content: Content) -> Result<Self, Self::Error> {
580 Ok(match content {
581 Content::Text { text, .. } => message::AssistantContent::text(text),
582 Content::ToolUse { id, name, input } => {
583 message::AssistantContent::tool_call(id, name, input)
584 }
585 Content::Thinking {
586 thinking,
587 signature,
588 } => message::AssistantContent::Reasoning(
589 Reasoning::new(&thinking).with_signature(signature),
590 ),
591 _ => {
592 return Err(MessageError::ConversionError(
593 "Content did not contain a message, tool call, or reasoning".to_owned(),
594 ));
595 }
596 })
597 }
598}
599
600impl From<ToolResultContent> for message::ToolResultContent {
601 fn from(content: ToolResultContent) -> Self {
602 match content {
603 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
604 ToolResultContent::Image(ImageSource {
605 data,
606 media_type: format,
607 ..
608 }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
609 }
610 }
611}
612
613impl TryFrom<Message> for message::Message {
614 type Error = MessageError;
615
616 fn try_from(message: Message) -> Result<Self, Self::Error> {
617 Ok(match message.role {
618 Role::User => message::Message::User {
619 content: message.content.try_map(|content| {
620 Ok(match content {
621 Content::Text { text, .. } => message::UserContent::text(text),
622 Content::ToolResult {
623 tool_use_id,
624 content,
625 ..
626 } => message::UserContent::tool_result(
627 tool_use_id,
628 content.map(|content| content.into()),
629 ),
630 Content::Image { source, .. } => {
631 message::UserContent::Image(message::Image {
632 data: source.data.into(),
633 media_type: Some(source.media_type.into()),
634 detail: None,
635 additional_params: None,
636 })
637 }
638 Content::Document { source, .. } => message::UserContent::document(
639 source.data,
640 Some(message::DocumentMediaType::PDF),
641 ),
642 _ => {
643 return Err(MessageError::ConversionError(
644 "Unsupported content type for User role".to_owned(),
645 ));
646 }
647 })
648 })?,
649 },
650 Role::Assistant => match message.content.first() {
651 Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
652 message::Message::Assistant {
653 id: None,
654 content: message.content.try_map(|content| content.try_into())?,
655 }
656 }
657
658 _ => {
659 return Err(MessageError::ConversionError(
660 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
661 ));
662 }
663 },
664 })
665 }
666}
667
/// Anthropic completion model handle.
///
/// `T` is the HTTP client backend used by `Client` (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier sent in each request.
    pub model: String,
    /// `max_tokens` used when a request leaves it unset; when this is also
    /// `None`, `completion` returns a `RequestError`.
    pub default_max_tokens: Option<u64>,
    /// When `true`, requests get prompt-cache breakpoints via `apply_cache_control`.
    pub prompt_caching: bool,
}
676
677impl<T> CompletionModel<T>
678where
679 T: HttpClientExt,
680{
681 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
682 let model = model.into();
683 let default_max_tokens = calculate_max_tokens(&model);
684
685 Self {
686 client,
687 model,
688 default_max_tokens,
689 prompt_caching: false, }
691 }
692
693 pub fn with_model(client: Client<T>, model: &str) -> Self {
694 Self {
695 client,
696 model: model.to_string(),
697 default_max_tokens: Some(calculate_max_tokens_custom(model)),
698 prompt_caching: false, }
700 }
701
702 pub fn with_prompt_caching(mut self) -> Self {
710 self.prompt_caching = true;
711 self
712 }
713}
714
/// Known model-name prefixes and the maximum output tokens for each.
///
/// Checked in order; the first matching prefix wins.
const MAX_TOKENS_BY_MODEL_PREFIX: &[(&str, u64)] = &[
    ("claude-opus-4", 32000),
    ("claude-sonnet-4", 64000),
    ("claude-3-7-sonnet", 64000),
    ("claude-3-5-sonnet", 8192),
    ("claude-3-5-haiku", 8192),
    ("claude-3-opus", 4096),
    ("claude-3-sonnet", 4096),
    ("claude-3-haiku", 4096),
];

/// `max_tokens` used by `calculate_max_tokens_custom` for unrecognised models.
const FALLBACK_MAX_TOKENS: u64 = 2048;

/// Returns the maximum output tokens for a known Claude model name, or `None`
/// when the name matches no known prefix.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    MAX_TOKENS_BY_MODEL_PREFIX
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, tokens)| tokens)
}

/// Like `calculate_max_tokens`, but falls back to 2048 for unknown models.
///
/// Both functions share one table so their limits cannot drift apart.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    calculate_max_tokens(model).unwrap_or(FALLBACK_MAX_TOKENS)
}
751
/// Request metadata carrying an optional user id.
///
/// NOTE(review): not referenced in the visible part of this module, and
/// `user_id` is private, so it cannot be constructed from outside this module.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}
756
/// Anthropic's `tool_choice` request field.
///
/// Internally tagged by `"type"`: `{"type": "auto"}`, `{"type": "any"}`,
/// `{"type": "none"}` or `{"type": "tool", "name": ...}`. Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// Mapped from rig's `ToolChoice::Auto`.
    #[default]
    Auto,
    /// Mapped from rig's `ToolChoice::Required`.
    Any,
    /// Mapped from rig's `ToolChoice::None`.
    None,
    /// A single named tool, mapped from rig's `ToolChoice::Specific`.
    Tool {
        name: String,
    },
}
768impl TryFrom<message::ToolChoice> for ToolChoice {
769 type Error = CompletionError;
770
771 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
772 let res = match value {
773 message::ToolChoice::Auto => Self::Auto,
774 message::ToolChoice::None => Self::None,
775 message::ToolChoice::Required => Self::Any,
776 message::ToolChoice::Specific { function_names } => {
777 if function_names.len() != 1 {
778 return Err(CompletionError::ProviderError(
779 "Only one tool may be specified to be used by Claude".into(),
780 ));
781 }
782
783 Self::Tool {
784 name: function_names.first().unwrap().to_string(),
785 }
786 }
787 };
788
789 Ok(res)
790 }
791}
792
/// JSON body sent to `/v1/messages`.
///
/// Empty `system`/`tools` lists and unset optional fields are omitted.
/// `additional_params` is flattened into the top-level object, so its keys
/// appear next to the fields below (presumably it must be a JSON object —
/// confirm with callers).
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u64,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
810
811fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
813 match content {
814 Content::Text { cache_control, .. } => *cache_control = value,
815 Content::Image { cache_control, .. } => *cache_control = value,
816 Content::ToolResult { cache_control, .. } => *cache_control = value,
817 Content::Document { cache_control, .. } => *cache_control = value,
818 _ => {}
819 }
820}
821
822pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
827 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
829 *cache_control = Some(CacheControl::Ephemeral);
830 }
831
832 for msg in messages.iter_mut() {
834 for content in msg.content.iter_mut() {
835 set_content_cache_control(content, None);
836 }
837 }
838
839 if let Some(last_msg) = messages.last_mut() {
841 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
842 }
843}
844
/// Inputs for building the Anthropic request body from a rig request.
pub struct AnthropicRequestParams<'a> {
    /// Model identifier placed in the request.
    pub model: &'a str,
    /// The provider-agnostic completion request to convert.
    pub request: CompletionRequest,
    /// Whether to add prompt-cache breakpoints via [`apply_cache_control`].
    pub prompt_caching: bool,
}
851
852impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
853 type Error = CompletionError;
854
855 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
856 let AnthropicRequestParams {
857 model,
858 request: req,
859 prompt_caching,
860 } = params;
861
862 let Some(max_tokens) = req.max_tokens else {
864 return Err(CompletionError::RequestError(
865 "`max_tokens` must be set for Anthropic".into(),
866 ));
867 };
868
869 let mut full_history = vec![];
870 if let Some(docs) = req.normalized_documents() {
871 full_history.push(docs);
872 }
873 full_history.extend(req.chat_history);
874
875 let mut messages = full_history
876 .into_iter()
877 .map(Message::try_from)
878 .collect::<Result<Vec<Message>, _>>()?;
879
880 let tools = req
881 .tools
882 .into_iter()
883 .map(|tool| ToolDefinition {
884 name: tool.name,
885 description: Some(tool.description),
886 input_schema: tool.parameters,
887 })
888 .collect::<Vec<_>>();
889
890 let mut system = if let Some(preamble) = req.preamble {
892 if preamble.is_empty() {
893 vec![]
894 } else {
895 vec![SystemContent::Text {
896 text: preamble,
897 cache_control: None,
898 }]
899 }
900 } else {
901 vec![]
902 };
903
904 if prompt_caching {
906 apply_cache_control(&mut system, &mut messages);
907 }
908
909 Ok(Self {
910 model: model.to_string(),
911 messages,
912 max_tokens,
913 system,
914 temperature: req.temperature,
915 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
916 tools,
917 additional_params: req.additional_params,
918 })
919 }
920}
921
922impl<T> completion::CompletionModel for CompletionModel<T>
923where
924 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
925{
926 type Response = CompletionResponse;
927 type StreamingResponse = StreamingCompletionResponse;
928 type Client = Client<T>;
929
930 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
931 Self::new(client.clone(), model.into())
932 }
933
934 async fn completion(
935 &self,
936 mut completion_request: completion::CompletionRequest,
937 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
938 let span = if tracing::Span::current().is_disabled() {
939 info_span!(
940 target: "rig::completions",
941 "chat",
942 gen_ai.operation.name = "chat",
943 gen_ai.provider.name = "anthropic",
944 gen_ai.request.model = &self.model,
945 gen_ai.system_instructions = &completion_request.preamble,
946 gen_ai.response.id = tracing::field::Empty,
947 gen_ai.response.model = tracing::field::Empty,
948 gen_ai.usage.output_tokens = tracing::field::Empty,
949 gen_ai.usage.input_tokens = tracing::field::Empty,
950 )
951 } else {
952 tracing::Span::current()
953 };
954
955 if completion_request.max_tokens.is_none() {
957 if let Some(tokens) = self.default_max_tokens {
958 completion_request.max_tokens = Some(tokens);
959 } else {
960 return Err(CompletionError::RequestError(
961 "`max_tokens` must be set for Anthropic".into(),
962 ));
963 }
964 }
965
966 let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
967 model: &self.model,
968 request: completion_request,
969 prompt_caching: self.prompt_caching,
970 })?;
971
972 if enabled!(Level::TRACE) {
973 tracing::trace!(
974 target: "rig::completions",
975 "Anthropic completion request: {}",
976 serde_json::to_string_pretty(&request)?
977 );
978 }
979
980 async move {
981 let request: Vec<u8> = serde_json::to_vec(&request)?;
982
983 let req = self
984 .client
985 .post("/v1/messages")?
986 .body(request)
987 .map_err(|e| CompletionError::HttpError(e.into()))?;
988
989 let response = self
990 .client
991 .send::<_, Bytes>(req)
992 .await
993 .map_err(CompletionError::HttpError)?;
994
995 if response.status().is_success() {
996 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
997 response
998 .into_body()
999 .await
1000 .map_err(CompletionError::HttpError)?
1001 .to_vec()
1002 .as_slice(),
1003 )? {
1004 ApiResponse::Message(completion) => {
1005 let span = tracing::Span::current();
1006 span.record_response_metadata(&completion);
1007 span.record_token_usage(&completion.usage);
1008 if enabled!(Level::TRACE) {
1009 tracing::trace!(
1010 target: "rig::completions",
1011 "Anthropic completion response: {}",
1012 serde_json::to_string_pretty(&completion)?
1013 );
1014 }
1015 completion.try_into()
1016 }
1017 ApiResponse::Error(ApiErrorResponse { message }) => {
1018 Err(CompletionError::ResponseError(message))
1019 }
1020 }
1021 } else {
1022 let text: String = String::from_utf8_lossy(
1023 &response
1024 .into_body()
1025 .await
1026 .map_err(CompletionError::HttpError)?,
1027 )
1028 .into();
1029 Err(CompletionError::ProviderError(text))
1030 }
1031 }
1032 .instrument(span)
1033 .await
1034 }
1035
1036 async fn stream(
1037 &self,
1038 request: CompletionRequest,
1039 ) -> Result<
1040 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1041 CompletionError,
1042 > {
1043 CompletionModel::stream(self, request).await
1044 }
1045}
1046
/// Error payload of a `{"type": "error", ...}` response; only `message` is read.
///
/// NOTE(review): this expects `message` at the top level of the envelope.
/// Anthropic documents its errors as `{"type": "error", "error": {"type", "message"}}`;
/// verify this shape is actually received on a 2xx status. Non-2xx bodies never
/// reach this type: `completion` returns them as raw text.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
1051
/// Successful-status response envelope, internally tagged by `"type"`:
/// `"message"` carries the payload `T`, `"error"` carries an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
1058
1059#[cfg(test)]
1060mod tests {
1061 use super::*;
1062 use serde_json::json;
1063 use serde_path_to_error::deserialize;
1064
1065 #[test]
1066 fn test_deserialize_message() {
1067 let assistant_message_json = r#"
1068 {
1069 "role": "assistant",
1070 "content": "\n\nHello there, how may I assist you today?"
1071 }
1072 "#;
1073
1074 let assistant_message_json2 = r#"
1075 {
1076 "role": "assistant",
1077 "content": [
1078 {
1079 "type": "text",
1080 "text": "\n\nHello there, how may I assist you today?"
1081 },
1082 {
1083 "type": "tool_use",
1084 "id": "toolu_01A09q90qw90lq917835lq9",
1085 "name": "get_weather",
1086 "input": {"location": "San Francisco, CA"}
1087 }
1088 ]
1089 }
1090 "#;
1091
1092 let user_message_json = r#"
1093 {
1094 "role": "user",
1095 "content": [
1096 {
1097 "type": "image",
1098 "source": {
1099 "type": "base64",
1100 "media_type": "image/jpeg",
1101 "data": "/9j/4AAQSkZJRg..."
1102 }
1103 },
1104 {
1105 "type": "text",
1106 "text": "What is in this image?"
1107 },
1108 {
1109 "type": "tool_result",
1110 "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
1111 "content": "15 degrees"
1112 }
1113 ]
1114 }
1115 "#;
1116
1117 let assistant_message: Message = {
1118 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1119 deserialize(jd).unwrap_or_else(|err| {
1120 panic!("Deserialization error at {}: {}", err.path(), err);
1121 })
1122 };
1123
1124 let assistant_message2: Message = {
1125 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1126 deserialize(jd).unwrap_or_else(|err| {
1127 panic!("Deserialization error at {}: {}", err.path(), err);
1128 })
1129 };
1130
1131 let user_message: Message = {
1132 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1133 deserialize(jd).unwrap_or_else(|err| {
1134 panic!("Deserialization error at {}: {}", err.path(), err);
1135 })
1136 };
1137
1138 let Message { role, content } = assistant_message;
1139 assert_eq!(role, Role::Assistant);
1140 assert_eq!(
1141 content.first(),
1142 Content::Text {
1143 text: "\n\nHello there, how may I assist you today?".to_owned(),
1144 cache_control: None,
1145 }
1146 );
1147
1148 let Message { role, content } = assistant_message2;
1149 {
1150 assert_eq!(role, Role::Assistant);
1151 assert_eq!(content.len(), 2);
1152
1153 let mut iter = content.into_iter();
1154
1155 match iter.next().unwrap() {
1156 Content::Text { text, .. } => {
1157 assert_eq!(text, "\n\nHello there, how may I assist you today?");
1158 }
1159 _ => panic!("Expected text content"),
1160 }
1161
1162 match iter.next().unwrap() {
1163 Content::ToolUse { id, name, input } => {
1164 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1165 assert_eq!(name, "get_weather");
1166 assert_eq!(input, json!({"location": "San Francisco, CA"}));
1167 }
1168 _ => panic!("Expected tool use content"),
1169 }
1170
1171 assert_eq!(iter.next(), None);
1172 }
1173
1174 let Message { role, content } = user_message;
1175 {
1176 assert_eq!(role, Role::User);
1177 assert_eq!(content.len(), 3);
1178
1179 let mut iter = content.into_iter();
1180
1181 match iter.next().unwrap() {
1182 Content::Image { source, .. } => {
1183 assert_eq!(
1184 source,
1185 ImageSource {
1186 data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
1187 media_type: ImageFormat::JPEG,
1188 r#type: SourceType::BASE64,
1189 }
1190 );
1191 }
1192 _ => panic!("Expected image content"),
1193 }
1194
1195 match iter.next().unwrap() {
1196 Content::Text { text, .. } => {
1197 assert_eq!(text, "What is in this image?");
1198 }
1199 _ => panic!("Expected text content"),
1200 }
1201
1202 match iter.next().unwrap() {
1203 Content::ToolResult {
1204 tool_use_id,
1205 content,
1206 is_error,
1207 ..
1208 } => {
1209 assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1210 assert_eq!(
1211 content.first(),
1212 ToolResultContent::Text {
1213 text: "15 degrees".to_owned()
1214 }
1215 );
1216 assert_eq!(is_error, None);
1217 }
1218 _ => panic!("Expected tool result content"),
1219 }
1220
1221 assert_eq!(iter.next(), None);
1222 }
1223 }
1224
/// Round-trips Anthropic [`Message`]s through the provider-agnostic [`message::Message`].
///
/// Three messages are exercised:
/// - a user message deserialized from JSON (image, text and document blocks);
/// - an assistant message carrying a single tool call;
/// - a user message carrying a single tool result.
///
/// Each is converted to `message::Message` and its content is checked block by block.
/// It is then converted back and compared with the original.
#[test]
fn test_message_to_message_conversion() {
    let user_message: Message = serde_json::from_str(
        r#"
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/jpeg",
                    "data": "/9j/4AAQSkZJRg..."
                }
            },
            {
                "type": "text",
                "text": "What is in this image?"
            },
            {
                "type": "document",
                "source": {
                    "type": "base64",
                    "data": "base64_encoded_pdf_data",
                    "media_type": "application/pdf"
                }
            }
        ]
    }
    "#,
    )
    .unwrap();

    let assistant_message = Message {
        role: Role::Assistant,
        content: OneOrMany::one(Content::ToolUse {
            id: "toolu_01A09q90qw90lq917835lq9".to_string(),
            name: "get_weather".to_string(),
            input: json!({"location": "San Francisco, CA"}),
        }),
    };

    let tool_message = Message {
        role: Role::User,
        content: OneOrMany::one(Content::ToolResult {
            tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
            content: OneOrMany::one(ToolResultContent::Text {
                text: "15 degrees".to_string(),
            }),
            is_error: None,
            cache_control: None,
        }),
    };

    let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
    let converted_assistant_message: message::Message =
        assistant_message.clone().try_into().unwrap();
    let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

    // The converted messages are cloned here because the originals are consumed by the
    // reverse conversion at the end of the test.
    match converted_user_message.clone() {
        message::Message::User { content } => {
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                message::UserContent::Image(message::Image {
                    data, media_type, ..
                }) => {
                    assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                    assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                message::UserContent::Text(message::Text { text }) => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                message::UserContent::Document(message::Document {
                    data, media_type, ..
                }) => {
                    assert_eq!(
                        data,
                        DocumentSourceKind::String("base64_encoded_pdf_data".into())
                    );
                    assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                }
                _ => panic!("Expected document content"),
            }

            assert_eq!(iter.next(), None);
        }
        _ => panic!("Expected user message"),
    }

    match converted_tool_message.clone() {
        message::Message::User { content } => {
            // The single ToolResult block should map to exactly one user content item,
            // mirroring the length checks on the other two messages.
            assert_eq!(content.len(), 1);

            let message::ToolResult { id, content, .. } = match content.first() {
                message::UserContent::ToolResult(tool_result) => tool_result,
                _ => panic!("Expected tool result content"),
            };
            assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
            match content.first() {
                message::ToolResultContent::Text(message::Text { text }) => {
                    assert_eq!(text, "15 degrees");
                }
                _ => panic!("Expected text content"),
            }
        }
        // A tool result travels in a user-role message; any other role is the failure here.
        _ => panic!("Expected user message"),
    }

    match converted_assistant_message.clone() {
        message::Message::Assistant { content, .. } => {
            assert_eq!(content.len(), 1);

            match content.first() {
                message::AssistantContent::ToolCall(message::ToolCall {
                    id, function, ..
                }) => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(function.name, "get_weather");
                    assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool call content"),
            }
        }
        _ => panic!("Expected assistant message"),
    }

    // Converting back must reproduce the original Anthropic messages exactly.
    let original_user_message: Message = converted_user_message.try_into().unwrap();
    let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
    let original_tool_message: Message = converted_tool_message.try_into().unwrap();

    assert_eq!(user_message, original_user_message);
    assert_eq!(assistant_message, original_assistant_message);
    assert_eq!(tool_message, original_tool_message);
}
1368
/// Exercises both directions of the `ContentFormat` <-> `SourceType` conversions.
///
/// `Url` and `Base64` map to `SourceType::URL` and `SourceType::BASE64` and back again.
/// The deprecated `ContentFormat::String` is rejected, and the error text mentions the
/// deprecation.
#[test]
fn test_content_format_conversion() {
    use crate::completion::message::ContentFormat;

    // ContentFormat -> SourceType
    let url_source: SourceType = ContentFormat::Url.try_into().unwrap();
    let base64_source: SourceType = ContentFormat::Base64.try_into().unwrap();
    assert_eq!(url_source, SourceType::URL);
    assert_eq!(base64_source, SourceType::BASE64);

    // SourceType -> ContentFormat
    let url_format: ContentFormat = SourceType::URL.into();
    let base64_format: ContentFormat = SourceType::BASE64.into();
    assert_eq!(url_format, ContentFormat::Url);
    assert_eq!(base64_format, ContentFormat::Base64);

    // The legacy String format has no Anthropic source type and must produce an error.
    let rejected: Result<SourceType, _> = ContentFormat::String.try_into();
    let error_text = rejected.unwrap_err().to_string();
    assert!(error_text.contains("ContentFormat::String is deprecated"));
}
1394
/// Checks how `cache_control` is serialized and where [`apply_cache_control`] places it.
///
/// Serialization:
/// - `Some(CacheControl::Ephemeral)` is emitted as `"cache_control":{"type":"ephemeral"}`
///   on both [`SystemContent`] and [`Content`] text blocks.
/// - A `None` marker is omitted from the serialized `SystemContent`.
///
/// Placement: after `apply_cache_control` on one system block and two messages, the
/// system block and the second (last) message carry a marker; the first message does not.
#[test]
fn test_cache_control_serialization() {
    let system = SystemContent::Text {
        text: "You are a helpful assistant.".to_string(),
        cache_control: Some(CacheControl::Ephemeral),
    };
    let json = serde_json::to_string(&system).unwrap();
    assert!(
        json.contains(r#""cache_control":{"type":"ephemeral"}"#),
        "missing ephemeral cache_control in {json}"
    );
    assert!(json.contains(r#""type":"text""#), "missing text type tag in {json}");

    let system_no_cache = SystemContent::Text {
        text: "Hello".to_string(),
        cache_control: None,
    };
    let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
    assert!(
        !json_no_cache.contains("cache_control"),
        "cache_control should be omitted when None, got {json_no_cache}"
    );

    let content = Content::Text {
        text: "Test message".to_string(),
        cache_control: Some(CacheControl::Ephemeral),
    };
    let json_content = serde_json::to_string(&content).unwrap();
    assert!(
        json_content.contains(r#""cache_control":{"type":"ephemeral"}"#),
        "missing ephemeral cache_control in {json_content}"
    );

    let mut system_vec = vec![SystemContent::Text {
        text: "System prompt".to_string(),
        cache_control: None,
    }];
    let mut messages = vec![
        Message {
            role: Role::User,
            content: OneOrMany::one(Content::Text {
                text: "First message".to_string(),
                cache_control: None,
            }),
        },
        Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::Text {
                text: "Response".to_string(),
                cache_control: None,
            }),
        },
    ];

    apply_cache_control(&mut system_vec, &mut messages);

    // Irrefutable: `Text` is the only `SystemContent` variant (the original single-arm match
    // was exhaustive).
    let SystemContent::Text { cache_control, .. } = &system_vec[0];
    assert!(cache_control.is_some(), "system prompt should be marked for caching");

    // Match exhaustively so an unexpected non-text block fails loudly instead of silently
    // skipping the assertion.
    for content in messages[0].content.iter() {
        match content {
            Content::Text { cache_control, .. } => {
                assert!(cache_control.is_none(), "first message should not be cached");
            }
            other => panic!("Expected text content, got {other:?}"),
        }
    }

    for content in messages[1].content.iter() {
        match content {
            Content::Text { cache_control, .. } => {
                assert!(cache_control.is_some(), "last message should be cached");
            }
            other => panic!("Expected text content, got {other:?}"),
        }
    }
}
1468}