1use crate::completion::CompletionRequest;
4use crate::providers::anthropic::streaming::StreamingCompletionResponse;
5use crate::{
6 OneOrMany,
7 client::Provider,
8 completion::{self, CompletionError, GetTokenUsage},
9 http_client::HttpClientExt,
10 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
11 one_or_many::string_or_one_or_many,
12 telemetry::{ProviderResponseExt, SpanCombinator},
13 wasm_compat::*,
14};
15use bytes::Bytes;
16use serde::{Deserialize, Serialize};
17use std::{convert::Infallible, str::FromStr};
18use tracing::{Instrument, Level, enabled, info_span};
19
/// Model identifier for Claude Opus 4.6.
pub const CLAUDE_OPUS_4_6: &str = "claude-opus-4-6";
/// Model identifier for Claude Opus 4.7.
pub const CLAUDE_OPUS_4_7: &str = "claude-opus-4-7";
/// Model identifier for Claude Sonnet 4.6.
pub const CLAUDE_SONNET_4_6: &str = "claude-sonnet-4-6";
/// Model identifier for Claude Haiku 4.5.
pub const CLAUDE_HAIKU_4_5: &str = "claude-haiku-4-5";

/// Anthropic API version string `2023-01-01`
/// (presumably sent as the `anthropic-version` header by the client — not shown here).
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version known to this module (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
// Error text used when a response carries no convertible content.
const EMPTY_RESPONSE_ERROR: &str = "Response contained no message or tool call (empty)";
37
38pub trait AnthropicCompatibleProvider: Provider {
39 const PROVIDER_NAME: &'static str;
40
41 fn default_max_tokens(model: &str) -> Option<u64> {
42 let _ = model;
43 None
44 }
45}
46
47impl AnthropicCompatibleProvider for super::client::AnthropicExt {
48 const PROVIDER_NAME: &'static str = "anthropic";
49
50 fn default_max_tokens(model: &str) -> Option<u64> {
51 default_max_tokens_for_model(model)
52 }
53}
54
/// Raw response body of an Anthropic Messages completion.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model (text, tool use, thinking, ...).
    pub content: Vec<Content>,
    /// Identifier of the response message.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Role of the generated message, kept as the raw string.
    pub role: String,
    /// Why generation stopped (e.g. `"end_turn"`), if reported.
    pub stop_reason: Option<String>,
    /// The stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this exchange.
    pub usage: Usage,
}
65
66impl ProviderResponseExt for CompletionResponse {
67 type OutputMessage = Content;
68 type Usage = Usage;
69
70 fn get_response_id(&self) -> Option<String> {
71 Some(self.id.to_owned())
72 }
73
74 fn get_response_model_name(&self) -> Option<String> {
75 Some(self.model.to_owned())
76 }
77
78 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
79 self.content.clone()
80 }
81
82 fn get_text_response(&self) -> Option<String> {
83 let res = self
84 .content
85 .iter()
86 .filter_map(|x| {
87 if let Content::Text { text, .. } = x {
88 Some(text.to_owned())
89 } else {
90 None
91 }
92 })
93 .collect::<Vec<String>>()
94 .join("\n");
95
96 if res.is_empty() { None } else { Some(res) }
97 }
98
99 fn get_usage(&self) -> Option<Self::Usage> {
100 Some(self.usage.clone())
101 }
102}
103
/// Token counters reported by the Anthropic API.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens counted by the API.
    pub input_tokens: u64,
    /// Input tokens read from the prompt cache, if reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, if reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Generated output tokens.
    pub output_tokens: u64,
}
111
112impl std::fmt::Display for Usage {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(
115 f,
116 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
117 self.input_tokens,
118 match self.cache_read_input_tokens {
119 Some(token) => token.to_string(),
120 None => "n/a".to_string(),
121 },
122 match self.cache_creation_input_tokens {
123 Some(token) => token.to_string(),
124 None => "n/a".to_string(),
125 },
126 self.output_tokens
127 )
128 }
129}
130
131impl GetTokenUsage for Usage {
132 fn token_usage(&self) -> Option<crate::completion::Usage> {
133 let mut usage = crate::completion::Usage::new();
134
135 usage.input_tokens = self.input_tokens;
136 usage.output_tokens = self.output_tokens;
137 usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or_default();
138 usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or_default();
139 usage.total_tokens = self.input_tokens
140 + self.cache_read_input_tokens.unwrap_or_default()
141 + self.cache_creation_input_tokens.unwrap_or_default()
142 + self.output_tokens;
143
144 Some(usage)
145 }
146}
147
/// A tool declaration as placed in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name the model uses to call it.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// JSON schema of the tool's input.
    pub input_schema: serde_json::Value,
}

/// Lifetime of a prompt-cache entry; serialized as `"5m"` or `"1h"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub enum CacheTtl {
    /// Five minutes; this is the `Default`.
    #[default]
    #[serde(rename = "5m")]
    FiveMinutes,
    /// One hour.
    #[serde(rename = "1h")]
    OneHour,
}

/// A `cache_control` marker, tagged by `type` (only `"ephemeral"` is modeled).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral {
        /// Optional TTL; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        ttl: Option<CacheTtl>,
    },
}
184
185impl CacheControl {
186 pub fn ephemeral() -> Self {
188 Self::Ephemeral { ttl: None }
189 }
190
191 pub fn ephemeral_1h() -> Self {
193 Self::Ephemeral {
194 ttl: Some(CacheTtl::OneHour),
195 }
196 }
197}
198
/// One block of the request's top-level `system` prompt; only `text` blocks are modeled.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    Text {
        text: String,
        /// Optional cache breakpoint; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
209
210impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
211 type Error = CompletionError;
212
213 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
214 let content = response
215 .content
216 .iter()
217 .map(|content| content.clone().try_into())
218 .collect::<Result<Vec<_>, _>>()?;
219
220 let choice = if content.is_empty() {
221 if response.stop_reason.as_deref() == Some("end_turn") {
225 OneOrMany::one(completion::AssistantContent::text(""))
226 } else {
227 return Err(CompletionError::ResponseError(
228 EMPTY_RESPONSE_ERROR.to_owned(),
229 ));
230 }
231 } else {
232 OneOrMany::many(content)
233 .map_err(|_| CompletionError::ResponseError(EMPTY_RESPONSE_ERROR.to_owned()))?
234 };
235
236 let usage = completion::Usage {
237 input_tokens: response.usage.input_tokens,
238 output_tokens: response.usage.output_tokens,
239 total_tokens: response.usage.input_tokens
240 + response.usage.cache_read_input_tokens.unwrap_or(0)
241 + response.usage.cache_creation_input_tokens.unwrap_or(0)
242 + response.usage.output_tokens,
243 cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
244 cache_creation_input_tokens: response.usage.cache_creation_input_tokens.unwrap_or(0),
245 };
246
247 Ok(completion::CompletionResponse {
248 choice,
249 usage,
250 raw_response: response,
251 message_id: None,
252 })
253 }
254}
255
/// One entry of the Anthropic `messages` array.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Content blocks. Deserialized with `string_or_one_or_many`, which presumably also
    /// accepts a bare string (note `Content: FromStr`) — verify against that helper.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}

/// Message author; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}

/// A content block, tagged by `type` in snake_case
/// (`text`, `image`, `tool_use`, `tool_result`, `document`, `thinking`, `redacted_thinking`).
/// Optional `cache_control` / `is_error` / `signature` fields are omitted from the JSON when `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation emitted by the model.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, referencing the originating `tool_use` id.
    ToolResult {
        tool_use_id: String,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Extended-thinking text, optionally with its signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Thinking content returned only as an opaque `data` payload.
    RedactedThinking {
        data: String,
    },
}
311
312impl FromStr for Content {
313 type Err = Infallible;
314
315 fn from_str(s: &str) -> Result<Self, Self::Err> {
316 Ok(Content::Text {
317 text: s.to_owned(),
318 cache_control: None,
319 })
320 }
321}
322
323#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
324#[serde(tag = "type", rename_all = "snake_case")]
325pub enum ToolResultContent {
326 Text { text: String },
327 Image(ImageSource),
328}
329
330impl FromStr for ToolResultContent {
331 type Err = Infallible;
332
333 fn from_str(s: &str) -> Result<Self, Self::Err> {
334 Ok(ToolResultContent::Text { text: s.to_owned() })
335 }
336}
337
/// Source of an image block, tagged by `type` (`"base64"` or `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline base64 data with its media type.
    #[serde(rename = "base64")]
    Base64 {
        data: String,
        media_type: ImageFormat,
    },
    /// An image referenced by URL.
    #[serde(rename = "url")]
    Url { url: String },
}

/// Source of a document block, tagged by `type` (`"base64"`, `"text"` or `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
    /// Base64 payload; only PDF is representable.
    Base64 {
        data: String,
        media_type: DocumentFormat,
    },
    /// Plain-text payload (`text/plain`).
    Text {
        data: String,
        media_type: PlainTextMediaType,
    },
    /// A document referenced by URL.
    Url {
        url: String,
    },
}
382
/// Image media type of a base64 image source, serialized as its MIME string.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}

/// Media type of a base64 document source; only `application/pdf` is modeled.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}

/// Media type of a plain-text document source (`text/plain`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum PlainTextMediaType {
    #[serde(rename = "text/plain")]
    Plain,
}

/// Kind of payload in a source; serialized lowercase (`base64`, `url`, `text`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
    TEXT,
}
427
428impl From<String> for Content {
429 fn from(text: String) -> Self {
430 Content::Text {
431 text,
432 cache_control: None,
433 }
434 }
435}
436
437impl From<String> for ToolResultContent {
438 fn from(text: String) -> Self {
439 ToolResultContent::Text { text }
440 }
441}
442
443impl TryFrom<message::ContentFormat> for SourceType {
444 type Error = MessageError;
445
446 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
447 match format {
448 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
449 message::ContentFormat::Url => Ok(SourceType::URL),
450 message::ContentFormat::String => Ok(SourceType::TEXT),
451 }
452 }
453}
454
455impl From<SourceType> for message::ContentFormat {
456 fn from(source_type: SourceType) -> Self {
457 match source_type {
458 SourceType::BASE64 => message::ContentFormat::Base64,
459 SourceType::URL => message::ContentFormat::Url,
460 SourceType::TEXT => message::ContentFormat::String,
461 }
462 }
463}
464
465impl TryFrom<message::ImageMediaType> for ImageFormat {
466 type Error = MessageError;
467
468 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
469 Ok(match media_type {
470 message::ImageMediaType::JPEG => ImageFormat::JPEG,
471 message::ImageMediaType::PNG => ImageFormat::PNG,
472 message::ImageMediaType::GIF => ImageFormat::GIF,
473 message::ImageMediaType::WEBP => ImageFormat::WEBP,
474 _ => {
475 return Err(MessageError::ConversionError(
476 format!("Unsupported image media type: {media_type:?}").to_owned(),
477 ));
478 }
479 })
480 }
481}
482
483impl From<ImageFormat> for message::ImageMediaType {
484 fn from(format: ImageFormat) -> Self {
485 match format {
486 ImageFormat::JPEG => message::ImageMediaType::JPEG,
487 ImageFormat::PNG => message::ImageMediaType::PNG,
488 ImageFormat::GIF => message::ImageMediaType::GIF,
489 ImageFormat::WEBP => message::ImageMediaType::WEBP,
490 }
491 }
492}
493
494impl TryFrom<DocumentMediaType> for DocumentFormat {
495 type Error = MessageError;
496 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
497 match value {
498 DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
499 other => Err(MessageError::ConversionError(format!(
500 "DocumentFormat only supports PDF for base64 sources, got: {}",
501 other.to_mime_type()
502 ))),
503 }
504 }
505}
506
507impl TryFrom<message::AssistantContent> for Content {
508 type Error = MessageError;
509 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
510 match text {
511 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
512 text,
513 cache_control: None,
514 }),
515 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
516 "Anthropic currently doesn't support images.".to_string(),
517 )),
518 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
519 Ok(Content::ToolUse {
520 id,
521 name: function.name,
522 input: function.arguments,
523 })
524 }
525 message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
526 thinking: reasoning.display_text(),
527 signature: reasoning.first_signature().map(str::to_owned),
528 }),
529 }
530 }
531}
532
533fn anthropic_content_from_assistant_content(
534 content: message::AssistantContent,
535) -> Result<Vec<Content>, MessageError> {
536 match content {
537 message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
538 text,
539 cache_control: None,
540 }]),
541 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
542 "Anthropic currently doesn't support images.".to_string(),
543 )),
544 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
545 Ok(vec![Content::ToolUse {
546 id,
547 name: function.name,
548 input: function.arguments,
549 }])
550 }
551 message::AssistantContent::Reasoning(reasoning) => {
552 let mut converted = Vec::new();
553 for block in reasoning.content {
554 match block {
555 message::ReasoningContent::Text { text, signature } => {
556 converted.push(Content::Thinking {
557 thinking: text,
558 signature,
559 });
560 }
561 message::ReasoningContent::Summary(summary) => {
562 converted.push(Content::Thinking {
563 thinking: summary,
564 signature: None,
565 });
566 }
567 message::ReasoningContent::Redacted { data }
568 | message::ReasoningContent::Encrypted(data) => {
569 converted.push(Content::RedactedThinking { data });
570 }
571 }
572 }
573
574 if converted.is_empty() {
575 return Err(MessageError::ConversionError(
576 "Cannot convert empty reasoning content to Anthropic format".to_string(),
577 ));
578 }
579
580 Ok(converted)
581 }
582 }
583}
584
impl TryFrom<message::Message> for Message {
    type Error = MessageError;

    /// Converts a rig message into an Anthropic wire message.
    ///
    /// # Errors
    /// Returns `MessageError::ConversionError` for content that cannot be expressed here:
    /// - audio and video;
    /// - tool-result images that are not base64 or lack a media type;
    /// - user images without a body, with an unsupported source kind, or base64 without
    ///   a media type;
    /// - documents without a media type, or of a type other than PDF / plain text;
    /// - assistant messages that convert to no content.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        Ok(match message {
            message::Message::User { content } => Message {
                role: Role::User,
                content: content.try_map(|content| match content {
                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
                        text,
                        cache_control: None,
                    }),
                    // Other `ToolResult` fields are ignored; `is_error` is always sent as `None`.
                    message::UserContent::ToolResult(message::ToolResult {
                        id, content, ..
                    }) => Ok(Content::ToolResult {
                        tool_use_id: id,
                        content: content.try_map(|content| match content {
                            message::ToolResultContent::Text(message::Text { text }) => {
                                Ok(ToolResultContent::Text { text })
                            }
                            // Tool-result images must be inline base64 with a known media type.
                            message::ToolResultContent::Image(image) => {
                                let DocumentSourceKind::Base64(data) = image.data else {
                                    return Err(MessageError::ConversionError(
                                        "Only base64 strings can be used with the Anthropic API"
                                            .to_string(),
                                    ));
                                };
                                let media_type =
                                    image.media_type.ok_or(MessageError::ConversionError(
                                        "Image media type is required".to_owned(),
                                    ))?;
                                Ok(ToolResultContent::Image(ImageSource::Base64 {
                                    data,
                                    media_type: media_type.try_into()?,
                                }))
                            }
                        })?,
                        is_error: None,
                        cache_control: None,
                    }),
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        // Base64 images need an explicit media type; URLs pass through as-is.
                        let source = match data {
                            DocumentSourceKind::Base64(data) => {
                                let media_type =
                                    media_type.ok_or(MessageError::ConversionError(
                                        "Image media type is required for Claude API".to_string(),
                                    ))?;
                                ImageSource::Base64 {
                                    data,
                                    media_type: ImageFormat::try_from(media_type)?,
                                }
                            }
                            DocumentSourceKind::Url(url) => ImageSource::Url { url },
                            DocumentSourceKind::Unknown => {
                                return Err(MessageError::ConversionError(
                                    "Image content has no body".into(),
                                ));
                            }
                            doc => {
                                return Err(MessageError::ConversionError(format!(
                                    "Unsupported document type: {doc:?}"
                                )));
                            }
                        };

                        Ok(Content::Image {
                            source,
                            cache_control: None,
                        })
                    }
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        let media_type = media_type.ok_or(MessageError::ConversionError(
                            "Document media type is required".to_string(),
                        ))?;

                        let source = match media_type {
                            DocumentMediaType::PDF => {
                                // `String` payloads are forwarded unchanged as base64 data.
                                // NOTE(review): assumes callers already base64-encoded them — confirm.
                                let data = match data {
                                    DocumentSourceKind::Base64(data)
                                    | DocumentSourceKind::String(data) => data,
                                    _ => {
                                        return Err(MessageError::ConversionError(
                                            "Only base64 encoded data is supported for PDF documents".into(),
                                        ));
                                    }
                                };
                                DocumentSource::Base64 {
                                    data,
                                    media_type: DocumentFormat::PDF,
                                }
                            }
                            // Plain-text documents are sent as a `text` source; the payload is not decoded.
                            DocumentMediaType::TXT => {
                                let data = match data {
                                    DocumentSourceKind::String(data)
                                    | DocumentSourceKind::Base64(data) => data,
                                    _ => {
                                        return Err(MessageError::ConversionError(
                                            "Only string or base64 data is supported for plain text documents".into(),
                                        ));
                                    }
                                };
                                DocumentSource::Text {
                                    data,
                                    media_type: PlainTextMediaType::Plain,
                                }
                            }
                            other => {
                                return Err(MessageError::ConversionError(format!(
                                    "Anthropic only supports PDF and plain text documents, got: {}",
                                    other.to_mime_type()
                                )));
                            }
                        };

                        Ok(Content::Document {
                            source,
                            cache_control: None,
                        })
                    }
                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
                        "Audio is not supported in Anthropic".to_owned(),
                    )),
                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
                        "Video is not supported in Anthropic".to_owned(),
                    )),
                })?,
            },

            // A `System` message converted on its own becomes a user-role text block.
            // The request builder in this module removes system messages from the history
            // (`split_system_messages_from_history`) before converting the rest.
            message::Message::System { content } => Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: content,
                    cache_control: None,
                }),
            },

            // Each assistant item may expand into several blocks (one per reasoning part),
            // so the results are flattened before rebuilding the `OneOrMany`.
            message::Message::Assistant { content, .. } => {
                let converted_content = content.into_iter().try_fold(
                    Vec::new(),
                    |mut accumulated, assistant_content| {
                        accumulated
                            .extend(anthropic_content_from_assistant_content(assistant_content)?);
                        Ok::<Vec<Content>, MessageError>(accumulated)
                    },
                )?;

                Message {
                    content: OneOrMany::many(converted_content).map_err(|_| {
                        MessageError::ConversionError(
                            "Assistant message did not contain Anthropic-compatible content"
                                .to_owned(),
                        )
                    })?,
                    role: Role::Assistant,
                }
            }
        })
    }
}
748
749impl TryFrom<Content> for message::AssistantContent {
750 type Error = MessageError;
751
752 fn try_from(content: Content) -> Result<Self, Self::Error> {
753 Ok(match content {
754 Content::Text { text, .. } => message::AssistantContent::text(text),
755 Content::ToolUse { id, name, input } => {
756 message::AssistantContent::tool_call(id, name, input)
757 }
758 Content::Thinking {
759 thinking,
760 signature,
761 } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
762 &thinking, signature,
763 )),
764 Content::RedactedThinking { data } => {
765 message::AssistantContent::Reasoning(Reasoning::redacted(data))
766 }
767 _ => {
768 return Err(MessageError::ConversionError(
769 "Content did not contain a message, tool call, or reasoning".to_owned(),
770 ));
771 }
772 })
773 }
774}
775
776impl From<ToolResultContent> for message::ToolResultContent {
777 fn from(content: ToolResultContent) -> Self {
778 match content {
779 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
780 ToolResultContent::Image(source) => match source {
781 ImageSource::Base64 { data, media_type } => {
782 message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
783 }
784 ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
785 },
786 }
787 }
788}
789
790impl TryFrom<Message> for message::Message {
791 type Error = MessageError;
792
793 fn try_from(message: Message) -> Result<Self, Self::Error> {
794 Ok(match message.role {
795 Role::User => message::Message::User {
796 content: message.content.try_map(|content| {
797 Ok(match content {
798 Content::Text { text, .. } => message::UserContent::text(text),
799 Content::ToolResult {
800 tool_use_id,
801 content,
802 ..
803 } => message::UserContent::tool_result(
804 tool_use_id,
805 content.map(|content| content.into()),
806 ),
807 Content::Image { source, .. } => match source {
808 ImageSource::Base64 { data, media_type } => {
809 message::UserContent::Image(message::Image {
810 data: DocumentSourceKind::Base64(data),
811 media_type: Some(media_type.into()),
812 detail: None,
813 additional_params: None,
814 })
815 }
816 ImageSource::Url { url } => {
817 message::UserContent::Image(message::Image {
818 data: DocumentSourceKind::Url(url),
819 media_type: None,
820 detail: None,
821 additional_params: None,
822 })
823 }
824 },
825 Content::Document { source, .. } => match source {
826 DocumentSource::Base64 { data, media_type } => {
827 let rig_media_type = match media_type {
828 DocumentFormat::PDF => message::DocumentMediaType::PDF,
829 };
830 message::UserContent::document(data, Some(rig_media_type))
831 }
832 DocumentSource::Text { data, .. } => message::UserContent::document(
833 data,
834 Some(message::DocumentMediaType::TXT),
835 ),
836 DocumentSource::Url { url } => {
837 message::UserContent::document_url(url, None)
838 }
839 },
840 _ => {
841 return Err(MessageError::ConversionError(
842 "Unsupported content type for User role".to_owned(),
843 ));
844 }
845 })
846 })?,
847 },
848 Role::Assistant => message::Message::Assistant {
849 id: None,
850 content: message.content.try_map(|content| content.try_into())?,
851 },
852 })
853 }
854}
855
/// Completion model handle for an Anthropic-compatible provider `Ext`, using HTTP client `T`.
#[doc(hidden)]
#[derive(Clone)]
pub struct GenericCompletionModel<Ext = super::client::AnthropicExt, T = reqwest::Client> {
    pub(crate) client: crate::client::Client<Ext, T>,
    /// Model identifier.
    pub model: String,
    /// Default `max_tokens` for this model, if one is known.
    pub default_max_tokens: Option<u64>,
    /// Whether explicit prompt-cache breakpoints are requested
    /// (presumably forwarded as `AnthropicRequestParams::prompt_caching` — not shown in this chunk).
    pub prompt_caching: bool,
    /// Whether top-level automatic caching is requested
    /// (presumably forwarded as `AnthropicRequestParams::automatic_caching`).
    pub automatic_caching: bool,
    /// TTL for automatic caching; `None` leaves the TTL unspecified.
    pub automatic_caching_ttl: Option<CacheTtl>,
}

/// Completion model for Anthropic itself.
pub type CompletionModel<T = reqwest::Client> =
    GenericCompletionModel<super::client::AnthropicExt, T>;
880
881impl<Ext, T> GenericCompletionModel<Ext, T>
882where
883 T: HttpClientExt,
884 Ext: AnthropicCompatibleProvider + Clone + 'static,
885{
886 pub fn new(client: crate::client::Client<Ext, T>, model: impl Into<String>) -> Self {
887 let model = model.into();
888 let default_max_tokens = Ext::default_max_tokens(&model);
889
890 Self {
891 client,
892 model,
893 default_max_tokens,
894 prompt_caching: false,
895 automatic_caching: false,
896 automatic_caching_ttl: None,
897 }
898 }
899
900 pub fn with_model(client: crate::client::Client<Ext, T>, model: &str) -> Self {
901 Self {
902 client,
903 model: model.to_string(),
904 default_max_tokens: Ext::default_max_tokens(model)
905 .or_else(|| Some(default_max_tokens_with_fallback(model))),
906 prompt_caching: false,
907 automatic_caching: false,
908 automatic_caching_ttl: None,
909 }
910 }
911
912 pub fn with_prompt_caching(mut self) -> Self {
920 self.prompt_caching = true;
921 self
922 }
923
924 pub fn with_automatic_caching(mut self) -> Self {
957 self.automatic_caching = true;
958 self
959 }
960
961 pub fn with_automatic_caching_1h(mut self) -> Self {
978 self.automatic_caching = true;
979 self.automatic_caching_ttl = Some(CacheTtl::OneHour);
980 self
981 }
982}
983
/// Maximum output tokens for known Claude 4.x model ids, matched by prefix.
///
/// - Opus 4.6 / 4.7: 128k.
/// - Opus 4 and Opus 4.1 (`claude-opus-4-0`, `claude-opus-4-1*`, dated `claude-opus-4-2025…`): 32k.
/// - Other Opus 4.x, Sonnet 4.x and Haiku 4.5: 64k.
///
/// Returns `None` for any other id so callers can choose their own default.
fn default_max_tokens_for_model(model: &str) -> Option<u64> {
    /// True when `model` starts with any of `prefixes`.
    fn starts_with_any(model: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|&prefix| model.starts_with(prefix))
    }

    if starts_with_any(model, &["claude-opus-4-7", "claude-opus-4-6"]) {
        Some(128_000)
    } else if starts_with_any(
        model,
        // The original Opus 4 / Opus 4.1 releases cap output at 32k tokens. They share the
        // generic `claude-opus-4` prefix with later 64k releases, so they must be matched
        // first; otherwise the API rejects the oversized `max_tokens`.
        &["claude-opus-4-0", "claude-opus-4-1", "claude-opus-4-2025"],
    ) {
        Some(32_000)
    } else if starts_with_any(model, &["claude-opus-4", "claude-sonnet-4", "claude-haiku-4-5"]) {
        Some(64_000)
    } else {
        None
    }
}

/// Like `default_max_tokens_for_model`, but falls back to 2048 tokens for unknown models.
fn default_max_tokens_with_fallback(model: &str) -> u64 {
    default_max_tokens_for_model(model).unwrap_or(2_048)
}
1003
/// Request metadata carrying an optional `user_id`.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}

/// Anthropic `tool_choice`, tagged by `type` in snake_case
/// (`auto` is the default, plus `any`, `none`, and `tool` with a `name`).
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    /// The model must call some tool.
    Any,
    /// The model must not call tools.
    None,
    /// The model must call the named tool.
    Tool {
        name: String,
    },
}
1020impl TryFrom<message::ToolChoice> for ToolChoice {
1021 type Error = CompletionError;
1022
1023 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1024 let res = match value {
1025 message::ToolChoice::Auto => Self::Auto,
1026 message::ToolChoice::None => Self::None,
1027 message::ToolChoice::Required => Self::Any,
1028 message::ToolChoice::Specific { function_names } => {
1029 if function_names.len() != 1 {
1030 return Err(CompletionError::ProviderError(
1031 "Only one tool may be specified to be used by Claude".into(),
1032 ));
1033 }
1034
1035 let Some(name) = function_names.into_iter().next() else {
1036 return Err(CompletionError::ProviderError(
1037 "Only one tool may be specified to be used by Claude".into(),
1038 ));
1039 };
1040
1041 Self::Tool { name }
1042 }
1043 };
1044
1045 Ok(res)
1046 }
1047}
1048
1049fn sanitize_schema(schema: &mut serde_json::Value) {
1055 use serde_json::Value;
1056
1057 if let Value::Object(obj) = schema {
1058 let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
1059 || obj.contains_key("properties");
1060
1061 if is_object_schema && !obj.contains_key("additionalProperties") {
1062 obj.insert("additionalProperties".to_string(), Value::Bool(false));
1063 }
1064
1065 if let Some(Value::Object(properties)) = obj.get("properties") {
1066 let prop_keys = properties.keys().cloned().map(Value::String).collect();
1067 obj.insert("required".to_string(), Value::Array(prop_keys));
1068 }
1069
1070 let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
1072 || obj.get("type") == Some(&Value::String("number".to_string()));
1073
1074 if is_numeric_schema {
1075 for key in [
1076 "minimum",
1077 "maximum",
1078 "exclusiveMinimum",
1079 "exclusiveMaximum",
1080 "multipleOf",
1081 ] {
1082 obj.remove(key);
1083 }
1084 }
1085
1086 if let Some(defs) = obj.get_mut("$defs")
1087 && let Value::Object(defs_obj) = defs
1088 {
1089 for (_, def_schema) in defs_obj.iter_mut() {
1090 sanitize_schema(def_schema);
1091 }
1092 }
1093
1094 if let Some(properties) = obj.get_mut("properties")
1095 && let Value::Object(props) = properties
1096 {
1097 for (_, prop_value) in props.iter_mut() {
1098 sanitize_schema(prop_value);
1099 }
1100 }
1101
1102 if let Some(items) = obj.get_mut("items") {
1103 sanitize_schema(items);
1104 }
1105
1106 if let Some(one_of) = obj.remove("oneOf") {
1108 match obj.get_mut("anyOf") {
1109 Some(Value::Array(existing)) => {
1110 if let Value::Array(mut incoming) = one_of {
1111 existing.append(&mut incoming);
1112 }
1113 }
1114 _ => {
1115 obj.insert("anyOf".to_string(), one_of);
1116 }
1117 }
1118 }
1119
1120 for key in ["anyOf", "allOf"] {
1121 if let Some(variants) = obj.get_mut(key)
1122 && let Value::Array(variants_array) = variants
1123 {
1124 for variant in variants_array.iter_mut() {
1125 sanitize_schema(variant);
1126 }
1127 }
1128 }
1129 }
1130}
1131
/// Structured-output format; serialized as `{"type": "json_schema", "schema": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    JsonSchema { schema: serde_json::Value },
}

/// The request's `output_config` object.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    format: OutputFormat,
}

/// Request body for the Anthropic Messages API.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u64,
    /// System prompt blocks; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Tool declarations as raw JSON; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    /// Extra provider parameters, flattened into the top level of the body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
    /// Top-level cache marker, set when automatic caching is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<CacheControl>,
}
1171
1172fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1174 match content {
1175 Content::Text { cache_control, .. } => *cache_control = value,
1176 Content::Image { cache_control, .. } => *cache_control = value,
1177 Content::ToolResult { cache_control, .. } => *cache_control = value,
1178 Content::Document { cache_control, .. } => *cache_control = value,
1179 _ => {}
1180 }
1181}
1182
1183pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1188 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1190 *cache_control = Some(CacheControl::ephemeral());
1191 }
1192
1193 for msg in messages.iter_mut() {
1195 for content in msg.content.iter_mut() {
1196 set_content_cache_control(content, None);
1197 }
1198 }
1199
1200 if let Some(last_msg) = messages.last_mut() {
1202 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::ephemeral()));
1203 }
1204}
1205
1206pub(super) fn split_system_messages_from_history(
1207 history: Vec<message::Message>,
1208) -> (Vec<SystemContent>, Vec<message::Message>) {
1209 let mut system = Vec::new();
1210 let mut remaining = Vec::new();
1211
1212 for message in history {
1213 match message {
1214 message::Message::System { content } => {
1215 if !content.is_empty() {
1216 system.push(SystemContent::Text {
1217 text: content,
1218 cache_control: None,
1219 });
1220 }
1221 }
1222 other => remaining.push(other),
1223 }
1224 }
1225
1226 (system, remaining)
1227}
1228
/// Inputs for building an `AnthropicCompletionRequest` via `TryFrom`.
pub struct AnthropicRequestParams<'a> {
    /// Model identifier copied into the request.
    pub model: &'a str,
    /// The provider-agnostic completion request to translate.
    pub request: CompletionRequest,
    /// When true, `apply_cache_control` marks the last system block and the last
    /// message's final block (for block kinds that carry `cache_control`).
    pub prompt_caching: bool,
    /// When true, a top-level `cache_control` marker is added to the request.
    pub automatic_caching: bool,
    /// TTL for that top-level marker; `None` omits it.
    pub automatic_caching_ttl: Option<CacheTtl>,
}
1239
1240impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1241 type Error = CompletionError;
1242
1243 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1244 let AnthropicRequestParams {
1245 model,
1246 request: mut req,
1247 prompt_caching,
1248 automatic_caching,
1249 automatic_caching_ttl,
1250 } = params;
1251
1252 let Some(max_tokens) = req.max_tokens else {
1254 return Err(CompletionError::RequestError(
1255 "`max_tokens` must be set for Anthropic".into(),
1256 ));
1257 };
1258
1259 let mut full_history = vec![];
1260 if let Some(docs) = req.normalized_documents() {
1261 full_history.push(docs);
1262 }
1263 full_history.extend(req.chat_history);
1264 let (history_system, full_history) = split_system_messages_from_history(full_history);
1265
1266 let mut messages = full_history
1267 .into_iter()
1268 .map(Message::try_from)
1269 .collect::<Result<Vec<Message>, _>>()?;
1270
1271 let mut additional_params_payload = req
1272 .additional_params
1273 .take()
1274 .unwrap_or(serde_json::Value::Null);
1275 let mut additional_tools =
1276 extract_tools_from_additional_params(&mut additional_params_payload)?;
1277
1278 let mut tools = req
1279 .tools
1280 .into_iter()
1281 .map(|tool| ToolDefinition {
1282 name: tool.name,
1283 description: Some(tool.description),
1284 input_schema: tool.parameters,
1285 })
1286 .map(serde_json::to_value)
1287 .collect::<Result<Vec<_>, _>>()?;
1288 tools.append(&mut additional_tools);
1289
1290 let mut system = if let Some(preamble) = req.preamble {
1292 if preamble.is_empty() {
1293 vec![]
1294 } else {
1295 vec![SystemContent::Text {
1296 text: preamble,
1297 cache_control: None,
1298 }]
1299 }
1300 } else {
1301 vec![]
1302 };
1303 system.extend(history_system);
1304
1305 if prompt_caching {
1307 apply_cache_control(&mut system, &mut messages);
1308 }
1309
1310 let output_config = if let Some(schema) = req.output_schema {
1311 let mut schema_value = schema.to_value();
1312 sanitize_schema(&mut schema_value);
1313 Some(OutputConfig {
1314 format: OutputFormat::JsonSchema {
1315 schema: schema_value,
1316 },
1317 })
1318 } else {
1319 None
1320 };
1321
1322 Ok(Self {
1323 model: model.to_string(),
1324 messages,
1325 max_tokens,
1326 system,
1327 temperature: req.temperature,
1328 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1329 tools,
1330 output_config,
1331 cache_control: if automatic_caching {
1333 Some(CacheControl::Ephemeral {
1334 ttl: automatic_caching_ttl,
1335 })
1336 } else {
1337 None
1338 },
1339 additional_params: if additional_params_payload.is_null() {
1340 None
1341 } else {
1342 Some(additional_params_payload)
1343 },
1344 })
1345 }
1346}
1347
1348fn extract_tools_from_additional_params(
1349 additional_params: &mut serde_json::Value,
1350) -> Result<Vec<serde_json::Value>, CompletionError> {
1351 if let Some(map) = additional_params.as_object_mut()
1352 && let Some(raw_tools) = map.remove("tools")
1353 {
1354 return serde_json::from_value::<Vec<serde_json::Value>>(raw_tools).map_err(|err| {
1355 CompletionError::RequestError(
1356 format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
1357 )
1358 });
1359 }
1360
1361 Ok(Vec::new())
1362}
1363
1364impl<Ext, T> completion::CompletionModel for GenericCompletionModel<Ext, T>
1365where
1366 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
1367 Ext: AnthropicCompatibleProvider + Clone + WasmCompatSend + WasmCompatSync + 'static,
1368{
1369 type Response = CompletionResponse;
1370 type StreamingResponse = StreamingCompletionResponse;
1371 type Client = crate::client::Client<Ext, T>;
1372
1373 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1374 Self::new(client.clone(), model.into())
1375 }
1376
1377 async fn completion(
1378 &self,
1379 mut completion_request: completion::CompletionRequest,
1380 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1381 let request_model = completion_request
1382 .model
1383 .clone()
1384 .unwrap_or_else(|| self.model.clone());
1385 let span = if tracing::Span::current().is_disabled() {
1386 info_span!(
1387 target: "rig::completions",
1388 "chat",
1389 gen_ai.operation.name = "chat",
1390 gen_ai.provider.name = Ext::PROVIDER_NAME,
1391 gen_ai.request.model = &request_model,
1392 gen_ai.system_instructions = &completion_request.preamble,
1393 gen_ai.response.id = tracing::field::Empty,
1394 gen_ai.response.model = tracing::field::Empty,
1395 gen_ai.usage.output_tokens = tracing::field::Empty,
1396 gen_ai.usage.input_tokens = tracing::field::Empty,
1397 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
1398 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
1399 )
1400 } else {
1401 tracing::Span::current()
1402 };
1403
1404 if completion_request.max_tokens.is_none() {
1406 if let Some(tokens) = self.default_max_tokens {
1407 completion_request.max_tokens = Some(tokens);
1408 } else {
1409 return Err(CompletionError::RequestError(
1410 "`max_tokens` must be set for Anthropic".into(),
1411 ));
1412 }
1413 }
1414
1415 let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
1416 model: &request_model,
1417 request: completion_request,
1418 prompt_caching: self.prompt_caching,
1419 automatic_caching: self.automatic_caching,
1420 automatic_caching_ttl: self.automatic_caching_ttl.clone(),
1421 })?;
1422
1423 if enabled!(Level::TRACE) {
1424 tracing::trace!(
1425 target: "rig::completions",
1426 "Anthropic completion request: {}",
1427 serde_json::to_string_pretty(&request)?
1428 );
1429 }
1430
1431 async move {
1432 let request: Vec<u8> = serde_json::to_vec(&request)?;
1433
1434 let req = self
1435 .client
1436 .post("/v1/messages")?
1437 .body(request)
1438 .map_err(|e| CompletionError::HttpError(e.into()))?;
1439
1440 let response = self
1441 .client
1442 .send::<_, Bytes>(req)
1443 .await
1444 .map_err(CompletionError::HttpError)?;
1445
1446 if response.status().is_success() {
1447 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
1448 response
1449 .into_body()
1450 .await
1451 .map_err(CompletionError::HttpError)?
1452 .to_vec()
1453 .as_slice(),
1454 )? {
1455 ApiResponse::Message(completion) => {
1456 let span = tracing::Span::current();
1457 span.record_response_metadata(&completion);
1458 span.record_token_usage(&completion.usage);
1459 if enabled!(Level::TRACE) {
1460 tracing::trace!(
1461 target: "rig::completions",
1462 "Anthropic completion response: {}",
1463 serde_json::to_string_pretty(&completion)?
1464 );
1465 }
1466 completion.try_into()
1467 }
1468 ApiResponse::Error(ApiErrorResponse { message }) => {
1469 Err(CompletionError::ResponseError(message))
1470 }
1471 }
1472 } else {
1473 let text: String = String::from_utf8_lossy(
1474 &response
1475 .into_body()
1476 .await
1477 .map_err(CompletionError::HttpError)?,
1478 )
1479 .into();
1480 Err(CompletionError::ProviderError(text))
1481 }
1482 }
1483 .instrument(span)
1484 .await
1485 }
1486
1487 async fn stream(
1488 &self,
1489 request: CompletionRequest,
1490 ) -> Result<
1491 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1492 CompletionError,
1493 > {
1494 GenericCompletionModel::stream(self, request).await
1495 }
1496}
1497
1498#[derive(Debug, Deserialize)]
1499struct ApiErrorResponse {
1500 message: String,
1501}
1502
1503#[derive(Debug, Deserialize)]
1504#[serde(tag = "type", rename_all = "snake_case")]
1505enum ApiResponse<T> {
1506 Message(T),
1507 Error(ApiErrorResponse),
1508}
1509
#[cfg(test)]
mod tests {
    // Unit tests for the Anthropic provider: per-model max_tokens defaults, (de)serialization
    // of Anthropic messages and content blocks, conversions to and from rig's generic
    // `message` types, cache-control handling, and empty-response normalization.
    use super::*;
    use serde_json::json;
    use serde_path_to_error::deserialize;

    // Current Claude model constants map to their expected per-model `max_tokens` defaults.
    #[test]
    fn current_model_default_max_tokens_match_anthropic_limits() {
        assert_eq!(default_max_tokens_for_model(CLAUDE_OPUS_4_7), Some(128_000));
        assert_eq!(default_max_tokens_for_model(CLAUDE_OPUS_4_6), Some(128_000));
        assert_eq!(
            default_max_tokens_for_model(CLAUDE_SONNET_4_6),
            Some(64_000)
        );
        assert_eq!(default_max_tokens_for_model(CLAUDE_HAIKU_4_5), Some(64_000));
    }

    // Unknown models have no per-model default; the fallback helper yields 2_048 for them.
    #[test]
    fn unknown_model_uses_conservative_default_max_tokens_fallback() {
        assert_eq!(default_max_tokens_for_model("claude-unknown"), None);
        assert_eq!(default_max_tokens_with_fallback("claude-unknown"), 2_048);
    }

    // Deserializes three Anthropic-format messages:
    // - plain string content, which becomes a single text block;
    // - assistant text plus tool_use blocks;
    // - user image, text and tool_result blocks.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        // serde_path_to_error reports the JSON path of any failure, which eases debugging.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned(),
                cache_control: None,
            }
        );

        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source, .. } => {
                    assert_eq!(
                        source,
                        ImageSource::Base64 {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // A string tool_result `content` becomes a single text block; `is_error` stays unset.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    // Anthropic messages convert to rig messages and back without loss, for three cases:
    // - user image/text/document content;
    // - an assistant tool call;
    // - a user tool result.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
                cache_control: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                // The base64 PDF payload is expected back as `DocumentSourceKind::String`.
                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Round-trip back to Anthropic messages must reproduce the originals exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    // ContentFormat <-> SourceType conversions are symmetric for URL, base64 and string sources.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let source_type: SourceType = ContentFormat::String.try_into().unwrap();
        assert_eq!(source_type, SourceType::TEXT);

        let content_format: ContentFormat = SourceType::TEXT.into();
        assert_eq!(content_format, ContentFormat::String);
    }

    // Checks `cache_control` serialization: present when set, omitted when `None`.
    // Also checks that `apply_cache_control` marks the system block and the last message
    // but not earlier messages.
    #[test]
    fn test_cache_control_serialization() {
        let system = SystemContent::Text {
            text: "You are a helpful assistant.".to_string(),
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json = serde_json::to_string(&system).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
        assert!(json.contains(r#""type":"text""#));

        let system_no_cache = SystemContent::Text {
            text: "Hello".to_string(),
            cache_control: None,
        };
        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
        assert!(!json_no_cache.contains("cache_control"));

        let content = Content::Text {
            text: "Test message".to_string(),
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json_content = serde_json::to_string(&content).unwrap();
        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

        let mut system_vec = vec![SystemContent::Text {
            text: "System prompt".to_string(),
            cache_control: None,
        }];
        let mut messages = vec![
            Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: "First message".to_string(),
                    cache_control: None,
                }),
            },
            Message {
                role: Role::Assistant,
                content: OneOrMany::one(Content::Text {
                    text: "Response".to_string(),
                    cache_control: None,
                }),
            },
        ];

        apply_cache_control(&mut system_vec, &mut messages);

        match &system_vec[0] {
            SystemContent::Text { cache_control, .. } => {
                assert!(cache_control.is_some());
            }
        }

        // Earlier messages are left uncached.
        for content in messages[0].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_none());
            }
        }

        // The final message carries the cache marker.
        for content in messages[1].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_some());
            }
        }
    }

    // A plain-text document serializes with a `text` source and `text/plain` media type.
    #[test]
    fn test_plaintext_document_serialization() {
        let content = Content::Document {
            source: DocumentSource::Text {
                data: "Hello, world!".to_string(),
                media_type: PlainTextMediaType::Plain,
            },
            cache_control: None,
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "document");
        assert_eq!(json["source"]["type"], "text");
        assert_eq!(json["source"]["media_type"], "text/plain");
        assert_eq!(json["source"]["data"], "Hello, world!");
    }

    // A `text` document source deserializes back into `DocumentSource::Text` with no cache control.
    #[test]
    fn test_plaintext_document_deserialization() {
        let json = r#"
        {
            "type": "document",
            "source": {
                "type": "text",
                "media_type": "text/plain",
                "data": "Hello, world!"
            }
        }
        "#;

        let content: Content = serde_json::from_str(json).unwrap();
        match content {
            Content::Document {
                source,
                cache_control,
            } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Hello, world!".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
                assert_eq!(cache_control, None);
            }
            _ => panic!("Expected Document content"),
        }
    }

    // A base64 PDF document serializes with a `base64` source and `application/pdf` media type.
    #[test]
    fn test_base64_pdf_document_serialization() {
        let content = Content::Document {
            source: DocumentSource::Base64 {
                data: "base64data".to_string(),
                media_type: DocumentFormat::PDF,
            },
            cache_control: None,
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "document");
        assert_eq!(json["source"]["type"], "base64");
        assert_eq!(json["source"]["media_type"], "application/pdf");
        assert_eq!(json["source"]["data"], "base64data");
    }

    // A base64 PDF document source deserializes into `DocumentSource::Base64`.
    #[test]
    fn test_base64_pdf_document_deserialization() {
        let json = r#"
        {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "application/pdf",
                "data": "base64data"
            }
        }
        "#;

        let content: Content = serde_json::from_str(json).unwrap();
        match content {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Base64 {
                        data: "base64data".to_string(),
                        media_type: DocumentFormat::PDF,
                    }
                );
            }
            _ => panic!("Expected Document content"),
        }
    }

    // A rig TXT document converts to an Anthropic plain-text document source.
    #[test]
    fn test_plaintext_rig_to_anthropic_conversion() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::document(
                "Some plain text content".to_string(),
                Some(msg::DocumentMediaType::TXT),
            )),
        };

        let anthropic_message: Message = rig_message.try_into().unwrap();
        assert_eq!(anthropic_message.role, Role::User);

        let mut iter = anthropic_message.content.into_iter();
        match iter.next().unwrap() {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Some plain text content".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
            }
            other => panic!("Expected Document content, got: {other:?}"),
        }
    }

    // An Anthropic plain-text document converts to a rig string document with TXT media type.
    #[test]
    fn test_plaintext_anthropic_to_rig_conversion() {
        use crate::completion::message as msg;

        let anthropic_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::Document {
                source: DocumentSource::Text {
                    data: "Some plain text content".to_string(),
                    media_type: PlainTextMediaType::Plain,
                },
                cache_control: None,
            }),
        };

        let rig_message: msg::Message = anthropic_message.try_into().unwrap();
        match rig_message {
            msg::Message::User { content } => {
                let mut iter = content.into_iter();
                match iter.next().unwrap() {
                    msg::UserContent::Document(msg::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("Some plain text content".into())
                        );
                        assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
                    }
                    other => panic!("Expected Document content, got: {other:?}"),
                }
            }
            _ => panic!("Expected User message"),
        }
    }

    // A rig TXT document survives rig -> Anthropic -> rig with its media type intact.
    #[test]
    fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
        use crate::completion::message as msg;

        let original = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::document(
                "Round trip text".to_string(),
                Some(msg::DocumentMediaType::TXT),
            )),
        };

        let anthropic: Message = original.clone().try_into().unwrap();
        let back: msg::Message = anthropic.try_into().unwrap();

        match (&original, &back) {
            (
                msg::Message::User {
                    content: orig_content,
                },
                msg::Message::User {
                    content: back_content,
                },
            ) => match (orig_content.first(), back_content.first()) {
                (
                    msg::UserContent::Document(msg::Document {
                        media_type: orig_mt,
                        ..
                    }),
                    msg::UserContent::Document(msg::Document {
                        media_type: back_mt,
                        ..
                    }),
                ) => {
                    assert_eq!(orig_mt, back_mt);
                }
                _ => panic!("Expected Document content in both"),
            },
            _ => panic!("Expected User messages"),
        }
    }

    // Document media types other than PDF and plain text (here HTML) are rejected.
    #[test]
    fn test_unsupported_document_type_returns_error() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
                data: DocumentSourceKind::String("data".into()),
                media_type: Some(msg::DocumentMediaType::HTML),
                additional_params: None,
            })),
        };

        let result: Result<Message, _> = rig_message.try_into();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Anthropic only supports PDF and plain text documents"),
            "Unexpected error: {err}"
        );
    }

    // Plain-text documents given by URL are rejected; only string or base64 data is accepted.
    #[test]
    fn test_plaintext_document_url_source_returns_error() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
                data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
                media_type: Some(msg::DocumentMediaType::TXT),
                additional_params: None,
            })),
        };

        let result: Result<Message, _> = rig_message.try_into();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Only string or base64 data is supported for plain text documents"),
            "Unexpected error: {err}"
        );
    }

    // A plain-text document with cache control serializes the ephemeral marker alongside the source.
    #[test]
    fn test_plaintext_document_with_cache_control() {
        let content = Content::Document {
            source: DocumentSource::Text {
                data: "cached text".to_string(),
                media_type: PlainTextMediaType::Plain,
            },
            cache_control: Some(CacheControl::ephemeral()),
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["source"]["type"], "text");
        assert_eq!(json["source"]["media_type"], "text/plain");
        assert_eq!(json["cache_control"]["type"], "ephemeral");
    }

    // A user message mixing a plain-text document and a text block deserializes in order.
    #[test]
    fn test_message_with_plaintext_document_deserialization() {
        let json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "Hello from a text file"
                    }
                },
                {
                    "type": "text",
                    "text": "Summarize this document."
                }
            ]
        }
        "#;

        let message: Message = serde_json::from_str(json).unwrap();
        assert_eq!(message.role, Role::User);
        assert_eq!(message.content.len(), 2);

        let mut iter = message.content.into_iter();

        match iter.next().unwrap() {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Hello from a text file".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
            }
            _ => panic!("Expected Document content"),
        }

        match iter.next().unwrap() {
            Content::Text { text, .. } => {
                assert_eq!(text, "Summarize this document.");
            }
            _ => panic!("Expected Text content"),
        }
    }

    // Multi-part assistant reasoning becomes one Anthropic block per part, in order:
    // - signed text becomes Thinking with its signature;
    // - a summary becomes unsigned Thinking;
    // - redacted data becomes RedactedThinking.
    #[test]
    fn test_assistant_reasoning_multiblock_to_anthropic_content() {
        let reasoning = message::Reasoning {
            id: None,
            content: vec![
                message::ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig-1".to_string()),
                },
                message::ReasoningContent::Summary("summary".to_string()),
                message::ReasoningContent::Text {
                    text: "step two".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                message::ReasoningContent::Redacted {
                    data: "redacted block".to_string(),
                },
            ],
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
        };
        let converted: Message = msg.try_into().expect("convert assistant message");
        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();

        assert_eq!(converted.role, Role::Assistant);
        assert_eq!(converted_content.len(), 4);
        assert!(matches!(
            converted_content.first(),
            Some(Content::Thinking { thinking, signature: Some(signature) })
                if thinking == "step one" && signature == "sig-1"
        ));
        assert!(matches!(
            converted_content.get(1),
            Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
        ));
        assert!(matches!(
            converted_content.get(2),
            Some(Content::Thinking { thinking, signature: Some(signature) })
                if thinking == "step two" && signature == "sig-2"
        ));
        assert!(matches!(
            converted_content.get(3),
            Some(Content::RedactedThinking { data }) if data == "redacted block"
        ));
    }

    // An Anthropic RedactedThinking block converts to rig reasoning with Redacted content.
    #[test]
    fn test_redacted_thinking_content_to_assistant_reasoning() {
        let content = Content::RedactedThinking {
            data: "opaque-redacted".to_string(),
        };
        let converted: message::AssistantContent =
            content.try_into().expect("convert redacted thinking");

        assert!(matches!(
            converted,
            message::AssistantContent::Reasoning(message::Reasoning { content, .. })
                if matches!(
                    content.first(),
                    Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
                )
        ));
    }

    // Encrypted rig reasoning is sent to Anthropic as a RedactedThinking block.
    #[test]
    fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
        let reasoning = message::Reasoning {
            id: None,
            content: vec![message::ReasoningContent::Encrypted(
                "ciphertext".to_string(),
            )],
        };
        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
        };

        let converted: Message = msg.try_into().expect("convert assistant message");
        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();

        assert_eq!(converted_content.len(), 1);
        assert!(matches!(
            converted_content.first(),
            Some(Content::RedactedThinking { data }) if data == "ciphertext"
        ));
    }

    // An empty response with stop_reason "end_turn" becomes a single empty text choice.
    #[test]
    fn empty_end_turn_response_normalizes_to_empty_text_choice() {
        let response = CompletionResponse {
            content: vec![],
            id: "msg_123".to_string(),
            model: CLAUDE_SONNET_4_6.to_string(),
            role: "assistant".to_string(),
            stop_reason: Some("end_turn".to_string()),
            stop_sequence: None,
            usage: Usage {
                input_tokens: 7,
                cache_read_input_tokens: None,
                cache_creation_input_tokens: None,
                output_tokens: 2,
            },
        };

        let parsed: completion::CompletionResponse<CompletionResponse> = response
            .try_into()
            .expect("empty end_turn should not error");

        assert_eq!(parsed.choice.len(), 1);
        assert!(matches!(
            parsed.choice.first(),
            completion::AssistantContent::Text(text) if text.text.is_empty()
        ));
    }

    // An empty response with any other stop_reason still yields EMPTY_RESPONSE_ERROR.
    #[test]
    fn empty_non_end_turn_response_still_errors() {
        let response = CompletionResponse {
            content: vec![],
            id: "msg_123".to_string(),
            model: CLAUDE_SONNET_4_6.to_string(),
            role: "assistant".to_string(),
            stop_reason: Some("tool_use".to_string()),
            stop_sequence: None,
            usage: Usage {
                input_tokens: 7,
                cache_read_input_tokens: None,
                cache_creation_input_tokens: None,
                output_tokens: 2,
            },
        };

        let err = completion::CompletionResponse::<CompletionResponse>::try_from(response)
            .expect_err("empty non-end_turn should remain an error");

        assert!(matches!(
            err,
            CompletionError::ResponseError(message) if message == EMPTY_RESPONSE_ERROR
        ));
    }
}