1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
/// Model identifier for Claude Opus 4.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model identifier for Claude Sonnet 4.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model identifier for Claude 3.7 Sonnet (`-latest` alias).
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model identifier for Claude 3.5 Sonnet (`-latest` alias).
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model identifier for Claude 3.5 Haiku (`-latest` alias).
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest API version constant defined here (currently `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Body of a successful `/v1/messages` response, as deserialized from the API.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model (text, tool calls, thinking, ...).
    pub content: Vec<Content>,
    /// Message id assigned by the API.
    pub id: String,
    /// Model that actually served the request.
    pub model: String,
    /// Role of the returned message.
    pub role: String,
    /// Why generation stopped, when reported.
    pub stop_reason: Option<String>,
    /// The stop sequence that ended generation, when one did.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52 type OutputMessage = Content;
53 type Usage = Usage;
54
55 fn get_response_id(&self) -> Option<String> {
56 Some(self.id.to_owned())
57 }
58
59 fn get_response_model_name(&self) -> Option<String> {
60 Some(self.model.to_owned())
61 }
62
63 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64 self.content.clone()
65 }
66
67 fn get_text_response(&self) -> Option<String> {
68 let res = self
69 .content
70 .iter()
71 .filter_map(|x| {
72 if let Content::Text { text, .. } = x {
73 Some(text.to_owned())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<String>>()
79 .join("\n");
80
81 if res.is_empty() { None } else { Some(res) }
82 }
83
84 fn get_usage(&self) -> Option<Self::Usage> {
85 Some(self.usage.clone())
86 }
87}
88
/// Token usage reported by the Anthropic API for one request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens as reported by the API; the two cache counters below are
    /// reported separately and are added on top when converting to rig usage.
    pub input_tokens: u64,
    /// Input tokens served from the prompt cache, if reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, if reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Generated output tokens.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(
100 f,
101 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102 self.input_tokens,
103 match self.cache_read_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 match self.cache_creation_input_tokens {
108 Some(token) => token.to_string(),
109 None => "n/a".to_string(),
110 },
111 self.output_tokens
112 )
113 }
114}
115
116impl GetTokenUsage for Usage {
117 fn token_usage(&self) -> Option<crate::completion::Usage> {
118 let mut usage = crate::completion::Usage::new();
119
120 usage.input_tokens = self.input_tokens
121 + self.cache_creation_input_tokens.unwrap_or_default()
122 + self.cache_read_input_tokens.unwrap_or_default();
123 usage.output_tokens = self.output_tokens;
124 usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126 Some(usage)
127 }
128}
129
/// A tool advertised to the model in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name the model uses when calling it.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: Option<String>,
    /// JSON Schema of the tool's input, taken from the rig tool's `parameters`.
    pub input_schema: serde_json::Value,
}
136
/// Prompt-cache breakpoint marker; serializes as `{"type": "ephemeral"}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
143
/// One block of the request's `system` prompt, tagged on the wire by `type`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    /// Plain-text system instructions.
    Text {
        text: String,
        /// Cache breakpoint; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156 type Error = CompletionError;
157
158 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159 let content = response
160 .content
161 .iter()
162 .map(|content| content.clone().try_into())
163 .collect::<Result<Vec<_>, _>>()?;
164
165 let choice = OneOrMany::many(content).map_err(|_| {
166 CompletionError::ResponseError(
167 "Response contained no message or tool call (empty)".to_owned(),
168 )
169 })?;
170
171 let usage = completion::Usage {
172 input_tokens: response.usage.input_tokens,
173 output_tokens: response.usage.output_tokens,
174 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175 cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
176 };
177
178 Ok(completion::CompletionResponse {
179 choice,
180 usage,
181 raw_response: response,
182 message_id: None,
183 })
184 }
185}
186
/// A single message in an Anthropic conversation.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// Content blocks. Deserialized via `string_or_one_or_many`, so a bare string
    /// is also accepted and becomes a single text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
193
/// Message author; serialized as `"user"` / `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
200
/// One content block of an Anthropic message, tagged on the wire by `type`
/// (`text`, `image`, `tool_use`, `tool_result`, `document`, `thinking`,
/// `redacted_thinking`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        /// Cache breakpoint; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image, given as base64 data or a URL.
    Image {
        source: ImageSource,
        /// Cache breakpoint; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool call made by the model; `input` carries the call arguments.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The output of a tool call, linked to it by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        /// Result blocks; a bare string is also accepted when deserializing.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Whether the tool reported an error; omitted when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        /// Cache breakpoint; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document (PDF or plain text) supplied as context.
    Document {
        source: DocumentSource,
        /// Cache breakpoint; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Model reasoning text, optionally with a signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Redacted reasoning; `data` is carried as an opaque string.
    RedactedThinking {
        data: String,
    },
}
242
243impl FromStr for Content {
244 type Err = Infallible;
245
246 fn from_str(s: &str) -> Result<Self, Self::Err> {
247 Ok(Content::Text {
248 text: s.to_owned(),
249 cache_control: None,
250 })
251 }
252}
253
254#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
255#[serde(tag = "type", rename_all = "snake_case")]
256pub enum ToolResultContent {
257 Text { text: String },
258 Image(ImageSource),
259}
260
261impl FromStr for ToolResultContent {
262 type Err = Infallible;
263
264 fn from_str(s: &str) -> Result<Self, Self::Err> {
265 Ok(ToolResultContent::Text { text: s.to_owned() })
266 }
267}
268
/// Where an image comes from. Serializes as
/// `{"type": "base64", "data": ..., "media_type": ...}` or `{"type": "url", "url": ...}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline image bytes, base64-encoded.
    #[serde(rename = "base64")]
    Base64 {
        data: String,
        media_type: ImageFormat,
    },
    /// Image referenced by URL.
    #[serde(rename = "url")]
    Url { url: String },
}
287
/// Where a document comes from, tagged by `type` (`base64`, `text`, `url`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
    /// Base64-encoded document bytes (only PDF is representable).
    Base64 {
        data: String,
        media_type: DocumentFormat,
    },
    /// Plain-text document content.
    Text {
        data: String,
        media_type: PlainTextMediaType,
    },
    /// Document referenced by URL.
    Url {
        url: String,
    },
}
313
/// Image MIME types accepted by Anthropic. Each variant's explicit `rename`
/// (the MIME string) takes precedence over the container's `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
326
/// Media types allowed for base64 document sources; serialized as the MIME string.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
339
/// Media type for plain-text document sources; serialized as `"text/plain"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum PlainTextMediaType {
    #[serde(rename = "text/plain")]
    Plain,
}
350
/// Kind of content source; serialized as `"base64"`, `"url"` or `"text"`.
/// Converts to and from rig's `message::ContentFormat`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
    TEXT,
}
358
359impl From<String> for Content {
360 fn from(text: String) -> Self {
361 Content::Text {
362 text,
363 cache_control: None,
364 }
365 }
366}
367
368impl From<String> for ToolResultContent {
369 fn from(text: String) -> Self {
370 ToolResultContent::Text { text }
371 }
372}
373
374impl TryFrom<message::ContentFormat> for SourceType {
375 type Error = MessageError;
376
377 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
378 match format {
379 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
380 message::ContentFormat::Url => Ok(SourceType::URL),
381 message::ContentFormat::String => Ok(SourceType::TEXT),
382 }
383 }
384}
385
386impl From<SourceType> for message::ContentFormat {
387 fn from(source_type: SourceType) -> Self {
388 match source_type {
389 SourceType::BASE64 => message::ContentFormat::Base64,
390 SourceType::URL => message::ContentFormat::Url,
391 SourceType::TEXT => message::ContentFormat::String,
392 }
393 }
394}
395
396impl TryFrom<message::ImageMediaType> for ImageFormat {
397 type Error = MessageError;
398
399 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
400 Ok(match media_type {
401 message::ImageMediaType::JPEG => ImageFormat::JPEG,
402 message::ImageMediaType::PNG => ImageFormat::PNG,
403 message::ImageMediaType::GIF => ImageFormat::GIF,
404 message::ImageMediaType::WEBP => ImageFormat::WEBP,
405 _ => {
406 return Err(MessageError::ConversionError(
407 format!("Unsupported image media type: {media_type:?}").to_owned(),
408 ));
409 }
410 })
411 }
412}
413
414impl From<ImageFormat> for message::ImageMediaType {
415 fn from(format: ImageFormat) -> Self {
416 match format {
417 ImageFormat::JPEG => message::ImageMediaType::JPEG,
418 ImageFormat::PNG => message::ImageMediaType::PNG,
419 ImageFormat::GIF => message::ImageMediaType::GIF,
420 ImageFormat::WEBP => message::ImageMediaType::WEBP,
421 }
422 }
423}
424
425impl TryFrom<DocumentMediaType> for DocumentFormat {
426 type Error = MessageError;
427 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
428 match value {
429 DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
430 other => Err(MessageError::ConversionError(format!(
431 "DocumentFormat only supports PDF for base64 sources, got: {}",
432 other.to_mime_type()
433 ))),
434 }
435 }
436}
437
438impl TryFrom<message::AssistantContent> for Content {
439 type Error = MessageError;
440 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
441 match text {
442 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
443 text,
444 cache_control: None,
445 }),
446 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
447 "Anthropic currently doesn't support images.".to_string(),
448 )),
449 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
450 Ok(Content::ToolUse {
451 id,
452 name: function.name,
453 input: function.arguments,
454 })
455 }
456 message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
457 thinking: reasoning.display_text(),
458 signature: reasoning.first_signature().map(str::to_owned),
459 }),
460 }
461 }
462}
463
464fn anthropic_content_from_assistant_content(
465 content: message::AssistantContent,
466) -> Result<Vec<Content>, MessageError> {
467 match content {
468 message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
469 text,
470 cache_control: None,
471 }]),
472 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
473 "Anthropic currently doesn't support images.".to_string(),
474 )),
475 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
476 Ok(vec![Content::ToolUse {
477 id,
478 name: function.name,
479 input: function.arguments,
480 }])
481 }
482 message::AssistantContent::Reasoning(reasoning) => {
483 let mut converted = Vec::new();
484 for block in reasoning.content {
485 match block {
486 message::ReasoningContent::Text { text, signature } => {
487 converted.push(Content::Thinking {
488 thinking: text,
489 signature,
490 });
491 }
492 message::ReasoningContent::Summary(summary) => {
493 converted.push(Content::Thinking {
494 thinking: summary,
495 signature: None,
496 });
497 }
498 message::ReasoningContent::Redacted { data }
499 | message::ReasoningContent::Encrypted(data) => {
500 converted.push(Content::RedactedThinking { data });
501 }
502 }
503 }
504
505 if converted.is_empty() {
506 return Err(MessageError::ConversionError(
507 "Cannot convert empty reasoning content to Anthropic format".to_string(),
508 ));
509 }
510
511 Ok(converted)
512 }
513 }
514}
515
516impl TryFrom<message::Message> for Message {
517 type Error = MessageError;
518
519 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
520 Ok(match message {
521 message::Message::User { content } => Message {
522 role: Role::User,
523 content: content.try_map(|content| match content {
524 message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
525 text,
526 cache_control: None,
527 }),
528 message::UserContent::ToolResult(message::ToolResult {
529 id, content, ..
530 }) => Ok(Content::ToolResult {
531 tool_use_id: id,
532 content: content.try_map(|content| match content {
533 message::ToolResultContent::Text(message::Text { text }) => {
534 Ok(ToolResultContent::Text { text })
535 }
536 message::ToolResultContent::Image(image) => {
537 let DocumentSourceKind::Base64(data) = image.data else {
538 return Err(MessageError::ConversionError(
539 "Only base64 strings can be used with the Anthropic API"
540 .to_string(),
541 ));
542 };
543 let media_type =
544 image.media_type.ok_or(MessageError::ConversionError(
545 "Image media type is required".to_owned(),
546 ))?;
547 Ok(ToolResultContent::Image(ImageSource::Base64 {
548 data,
549 media_type: media_type.try_into()?,
550 }))
551 }
552 })?,
553 is_error: None,
554 cache_control: None,
555 }),
556 message::UserContent::Image(message::Image {
557 data, media_type, ..
558 }) => {
559 let source = match data {
560 DocumentSourceKind::Base64(data) => {
561 let media_type =
562 media_type.ok_or(MessageError::ConversionError(
563 "Image media type is required for Claude API".to_string(),
564 ))?;
565 ImageSource::Base64 {
566 data,
567 media_type: ImageFormat::try_from(media_type)?,
568 }
569 }
570 DocumentSourceKind::Url(url) => ImageSource::Url { url },
571 DocumentSourceKind::Unknown => {
572 return Err(MessageError::ConversionError(
573 "Image content has no body".into(),
574 ));
575 }
576 doc => {
577 return Err(MessageError::ConversionError(format!(
578 "Unsupported document type: {doc:?}"
579 )));
580 }
581 };
582
583 Ok(Content::Image {
584 source,
585 cache_control: None,
586 })
587 }
588 message::UserContent::Document(message::Document {
589 data, media_type, ..
590 }) => {
591 let media_type = media_type.ok_or(MessageError::ConversionError(
592 "Document media type is required".to_string(),
593 ))?;
594
595 let source = match media_type {
596 DocumentMediaType::PDF => {
597 let data = match data {
598 DocumentSourceKind::Base64(data)
599 | DocumentSourceKind::String(data) => data,
600 _ => {
601 return Err(MessageError::ConversionError(
602 "Only base64 encoded data is supported for PDF documents".into(),
603 ));
604 }
605 };
606 DocumentSource::Base64 {
607 data,
608 media_type: DocumentFormat::PDF,
609 }
610 }
611 DocumentMediaType::TXT => {
612 let data = match data {
613 DocumentSourceKind::String(data)
614 | DocumentSourceKind::Base64(data) => data,
615 _ => {
616 return Err(MessageError::ConversionError(
617 "Only string or base64 data is supported for plain text documents".into(),
618 ));
619 }
620 };
621 DocumentSource::Text {
622 data,
623 media_type: PlainTextMediaType::Plain,
624 }
625 }
626 other => {
627 return Err(MessageError::ConversionError(format!(
628 "Anthropic only supports PDF and plain text documents, got: {}",
629 other.to_mime_type()
630 )));
631 }
632 };
633
634 Ok(Content::Document {
635 source,
636 cache_control: None,
637 })
638 }
639 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
640 "Audio is not supported in Anthropic".to_owned(),
641 )),
642 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
643 "Video is not supported in Anthropic".to_owned(),
644 )),
645 })?,
646 },
647
648 message::Message::Assistant { content, .. } => {
649 let converted_content = content.into_iter().try_fold(
650 Vec::new(),
651 |mut accumulated, assistant_content| {
652 accumulated
653 .extend(anthropic_content_from_assistant_content(assistant_content)?);
654 Ok::<Vec<Content>, MessageError>(accumulated)
655 },
656 )?;
657
658 Message {
659 content: OneOrMany::many(converted_content).map_err(|_| {
660 MessageError::ConversionError(
661 "Assistant message did not contain Anthropic-compatible content"
662 .to_owned(),
663 )
664 })?,
665 role: Role::Assistant,
666 }
667 }
668 })
669 }
670}
671
672impl TryFrom<Content> for message::AssistantContent {
673 type Error = MessageError;
674
675 fn try_from(content: Content) -> Result<Self, Self::Error> {
676 Ok(match content {
677 Content::Text { text, .. } => message::AssistantContent::text(text),
678 Content::ToolUse { id, name, input } => {
679 message::AssistantContent::tool_call(id, name, input)
680 }
681 Content::Thinking {
682 thinking,
683 signature,
684 } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
685 &thinking, signature,
686 )),
687 Content::RedactedThinking { data } => {
688 message::AssistantContent::Reasoning(Reasoning::redacted(data))
689 }
690 _ => {
691 return Err(MessageError::ConversionError(
692 "Content did not contain a message, tool call, or reasoning".to_owned(),
693 ));
694 }
695 })
696 }
697}
698
699impl From<ToolResultContent> for message::ToolResultContent {
700 fn from(content: ToolResultContent) -> Self {
701 match content {
702 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
703 ToolResultContent::Image(source) => match source {
704 ImageSource::Base64 { data, media_type } => {
705 message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
706 }
707 ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
708 },
709 }
710 }
711}
712
713impl TryFrom<Message> for message::Message {
714 type Error = MessageError;
715
716 fn try_from(message: Message) -> Result<Self, Self::Error> {
717 Ok(match message.role {
718 Role::User => message::Message::User {
719 content: message.content.try_map(|content| {
720 Ok(match content {
721 Content::Text { text, .. } => message::UserContent::text(text),
722 Content::ToolResult {
723 tool_use_id,
724 content,
725 ..
726 } => message::UserContent::tool_result(
727 tool_use_id,
728 content.map(|content| content.into()),
729 ),
730 Content::Image { source, .. } => match source {
731 ImageSource::Base64 { data, media_type } => {
732 message::UserContent::Image(message::Image {
733 data: DocumentSourceKind::Base64(data),
734 media_type: Some(media_type.into()),
735 detail: None,
736 additional_params: None,
737 })
738 }
739 ImageSource::Url { url } => {
740 message::UserContent::Image(message::Image {
741 data: DocumentSourceKind::Url(url),
742 media_type: None,
743 detail: None,
744 additional_params: None,
745 })
746 }
747 },
748 Content::Document { source, .. } => match source {
749 DocumentSource::Base64 { data, media_type } => {
750 let rig_media_type = match media_type {
751 DocumentFormat::PDF => message::DocumentMediaType::PDF,
752 };
753 message::UserContent::document(data, Some(rig_media_type))
754 }
755 DocumentSource::Text { data, .. } => message::UserContent::document(
756 data,
757 Some(message::DocumentMediaType::TXT),
758 ),
759 DocumentSource::Url { url } => {
760 message::UserContent::document_url(url, None)
761 }
762 },
763 _ => {
764 return Err(MessageError::ConversionError(
765 "Unsupported content type for User role".to_owned(),
766 ));
767 }
768 })
769 })?,
770 },
771 Role::Assistant => message::Message::Assistant {
772 id: None,
773 content: message.content.try_map(|content| content.try_into())?,
774 },
775 })
776 }
777}
778
/// Anthropic chat model handle bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model name used when a request does not override it.
    pub model: String,
    /// Used as `max_tokens` when a request leaves it unset.
    pub default_max_tokens: Option<u64>,
    /// When true, `apply_cache_control` adds cache breakpoints while building requests.
    pub prompt_caching: bool,
}
787
788impl<T> CompletionModel<T>
789where
790 T: HttpClientExt,
791{
792 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
793 let model = model.into();
794 let default_max_tokens = calculate_max_tokens(&model);
795
796 Self {
797 client,
798 model,
799 default_max_tokens,
800 prompt_caching: false, }
802 }
803
804 pub fn with_model(client: Client<T>, model: &str) -> Self {
805 Self {
806 client,
807 model: model.to_string(),
808 default_max_tokens: Some(calculate_max_tokens_custom(model)),
809 prompt_caching: false, }
811 }
812
813 pub fn with_prompt_caching(mut self) -> Self {
821 self.prompt_caching = true;
822 self
823 }
824}
825
/// Fallback `max_tokens` used by [`calculate_max_tokens_custom`] for unknown models.
const FALLBACK_MAX_TOKENS: u64 = 2048;

/// Known model-name prefixes and the default `max_tokens` for each.
/// Checked in order; the first matching prefix wins.
const MAX_TOKENS_BY_PREFIX: &[(&str, u64)] = &[
    ("claude-opus-4", 32000),
    ("claude-sonnet-4", 64000),
    ("claude-3-7-sonnet", 64000),
    ("claude-3-5-sonnet", 8192),
    ("claude-3-5-haiku", 8192),
    ("claude-3-opus", 4096),
    ("claude-3-sonnet", 4096),
    ("claude-3-haiku", 4096),
];

/// Default `max_tokens` for a known model name, or `None` if it matches no prefix
/// in [`MAX_TOKENS_BY_PREFIX`].
fn calculate_max_tokens(model: &str) -> Option<u64> {
    MAX_TOKENS_BY_PREFIX
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, tokens)| tokens)
}

/// Like [`calculate_max_tokens`], but falls back to [`FALLBACK_MAX_TOKENS`] so a
/// value is always available.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    calculate_max_tokens(model).unwrap_or(FALLBACK_MAX_TOKENS)
}
862
/// Optional request metadata carrying an end-user identifier.
///
/// NOTE(review): not referenced by `AnthropicCompletionRequest` in this file;
/// confirm whether it is still used elsewhere.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}
867
/// Anthropic tool-choice setting, tagged by `type` (`auto`, `any`, `none`, `tool`).
/// Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// The model decides whether to call a tool.
    #[default]
    Auto,
    /// Rig's `Required` maps here.
    Any,
    /// Rig's `None` maps here.
    None,
    /// A single named tool, from rig's `Specific` with exactly one name.
    Tool {
        name: String,
    },
}
879impl TryFrom<message::ToolChoice> for ToolChoice {
880 type Error = CompletionError;
881
882 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
883 let res = match value {
884 message::ToolChoice::Auto => Self::Auto,
885 message::ToolChoice::None => Self::None,
886 message::ToolChoice::Required => Self::Any,
887 message::ToolChoice::Specific { function_names } => {
888 if function_names.len() != 1 {
889 return Err(CompletionError::ProviderError(
890 "Only one tool may be specified to be used by Claude".into(),
891 ));
892 }
893
894 Self::Tool {
895 name: function_names.first().unwrap().to_string(),
896 }
897 }
898 };
899
900 Ok(res)
901 }
902}
903
904fn sanitize_schema(schema: &mut serde_json::Value) {
910 use serde_json::Value;
911
912 if let Value::Object(obj) = schema {
913 let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
914 || obj.contains_key("properties");
915
916 if is_object_schema && !obj.contains_key("additionalProperties") {
917 obj.insert("additionalProperties".to_string(), Value::Bool(false));
918 }
919
920 if let Some(Value::Object(properties)) = obj.get("properties") {
921 let prop_keys = properties.keys().cloned().map(Value::String).collect();
922 obj.insert("required".to_string(), Value::Array(prop_keys));
923 }
924
925 let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
927 || obj.get("type") == Some(&Value::String("number".to_string()));
928
929 if is_numeric_schema {
930 for key in [
931 "minimum",
932 "maximum",
933 "exclusiveMinimum",
934 "exclusiveMaximum",
935 "multipleOf",
936 ] {
937 obj.remove(key);
938 }
939 }
940
941 if let Some(defs) = obj.get_mut("$defs")
942 && let Value::Object(defs_obj) = defs
943 {
944 for (_, def_schema) in defs_obj.iter_mut() {
945 sanitize_schema(def_schema);
946 }
947 }
948
949 if let Some(properties) = obj.get_mut("properties")
950 && let Value::Object(props) = properties
951 {
952 for (_, prop_value) in props.iter_mut() {
953 sanitize_schema(prop_value);
954 }
955 }
956
957 if let Some(items) = obj.get_mut("items") {
958 sanitize_schema(items);
959 }
960
961 for key in ["anyOf", "oneOf", "allOf"] {
962 if let Some(variants) = obj.get_mut(key)
963 && let Value::Array(variants_array) = variants
964 {
965 for variant in variants_array.iter_mut() {
966 sanitize_schema(variant);
967 }
968 }
969 }
970 }
971}
972
/// Structured-output format; serializes as `{"type": "json_schema", "schema": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    /// Output constrained by a (sanitized) JSON Schema.
    JsonSchema { schema: serde_json::Value },
}
981
/// Request `output_config` object wrapping the structured-output format.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    format: OutputFormat,
}
987
/// JSON body sent to `/v1/messages`. Empty collections and `None` options are
/// omitted from the serialized request.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u64,
    /// System prompt blocks; omitted when empty (e.g. an empty preamble).
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    /// Extra provider parameters merged into the top-level JSON object via
    /// `flatten` (presumably must be a JSON object — confirm with callers).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
1007
1008fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1010 match content {
1011 Content::Text { cache_control, .. } => *cache_control = value,
1012 Content::Image { cache_control, .. } => *cache_control = value,
1013 Content::ToolResult { cache_control, .. } => *cache_control = value,
1014 Content::Document { cache_control, .. } => *cache_control = value,
1015 _ => {}
1016 }
1017}
1018
1019pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1024 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1026 *cache_control = Some(CacheControl::Ephemeral);
1027 }
1028
1029 for msg in messages.iter_mut() {
1031 for content in msg.content.iter_mut() {
1032 set_content_cache_control(content, None);
1033 }
1034 }
1035
1036 if let Some(last_msg) = messages.last_mut() {
1038 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
1039 }
1040}
1041
/// Inputs for building an `AnthropicCompletionRequest`.
pub struct AnthropicRequestParams<'a> {
    /// Effective model name for this request.
    pub model: &'a str,
    /// The provider-agnostic rig request to convert.
    pub request: CompletionRequest,
    /// Whether to place prompt-cache breakpoints via `apply_cache_control`.
    pub prompt_caching: bool,
}
1048
1049impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1050 type Error = CompletionError;
1051
1052 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1053 let AnthropicRequestParams {
1054 model,
1055 request: req,
1056 prompt_caching,
1057 } = params;
1058
1059 let Some(max_tokens) = req.max_tokens else {
1061 return Err(CompletionError::RequestError(
1062 "`max_tokens` must be set for Anthropic".into(),
1063 ));
1064 };
1065
1066 let mut full_history = vec![];
1067 if let Some(docs) = req.normalized_documents() {
1068 full_history.push(docs);
1069 }
1070 full_history.extend(req.chat_history);
1071
1072 let mut messages = full_history
1073 .into_iter()
1074 .map(Message::try_from)
1075 .collect::<Result<Vec<Message>, _>>()?;
1076
1077 let tools = req
1078 .tools
1079 .into_iter()
1080 .map(|tool| ToolDefinition {
1081 name: tool.name,
1082 description: Some(tool.description),
1083 input_schema: tool.parameters,
1084 })
1085 .collect::<Vec<_>>();
1086
1087 let mut system = if let Some(preamble) = req.preamble {
1089 if preamble.is_empty() {
1090 vec![]
1091 } else {
1092 vec![SystemContent::Text {
1093 text: preamble,
1094 cache_control: None,
1095 }]
1096 }
1097 } else {
1098 vec![]
1099 };
1100
1101 if prompt_caching {
1103 apply_cache_control(&mut system, &mut messages);
1104 }
1105
1106 let output_config = req.output_schema.map(|schema| {
1108 let mut schema_value = schema.to_value();
1109 sanitize_schema(&mut schema_value);
1110 OutputConfig {
1111 format: OutputFormat::JsonSchema {
1112 schema: schema_value,
1113 },
1114 }
1115 });
1116
1117 Ok(Self {
1118 model: model.to_string(),
1119 messages,
1120 max_tokens,
1121 system,
1122 temperature: req.temperature,
1123 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1124 tools,
1125 output_config,
1126 additional_params: req.additional_params,
1127 })
1128 }
1129}
1130
1131impl<T> completion::CompletionModel for CompletionModel<T>
1132where
1133 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
1134{
1135 type Response = CompletionResponse;
1136 type StreamingResponse = StreamingCompletionResponse;
1137 type Client = Client<T>;
1138
1139 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1140 Self::new(client.clone(), model.into())
1141 }
1142
1143 async fn completion(
1144 &self,
1145 mut completion_request: completion::CompletionRequest,
1146 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1147 let request_model = completion_request
1148 .model
1149 .clone()
1150 .unwrap_or_else(|| self.model.clone());
1151 let span = if tracing::Span::current().is_disabled() {
1152 info_span!(
1153 target: "rig::completions",
1154 "chat",
1155 gen_ai.operation.name = "chat",
1156 gen_ai.provider.name = "anthropic",
1157 gen_ai.request.model = &request_model,
1158 gen_ai.system_instructions = &completion_request.preamble,
1159 gen_ai.response.id = tracing::field::Empty,
1160 gen_ai.response.model = tracing::field::Empty,
1161 gen_ai.usage.output_tokens = tracing::field::Empty,
1162 gen_ai.usage.input_tokens = tracing::field::Empty,
1163 )
1164 } else {
1165 tracing::Span::current()
1166 };
1167
1168 if completion_request.max_tokens.is_none() {
1170 if let Some(tokens) = self.default_max_tokens {
1171 completion_request.max_tokens = Some(tokens);
1172 } else {
1173 return Err(CompletionError::RequestError(
1174 "`max_tokens` must be set for Anthropic".into(),
1175 ));
1176 }
1177 }
1178
1179 let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
1180 model: &request_model,
1181 request: completion_request,
1182 prompt_caching: self.prompt_caching,
1183 })?;
1184
1185 if enabled!(Level::TRACE) {
1186 tracing::trace!(
1187 target: "rig::completions",
1188 "Anthropic completion request: {}",
1189 serde_json::to_string_pretty(&request)?
1190 );
1191 }
1192
1193 async move {
1194 let request: Vec<u8> = serde_json::to_vec(&request)?;
1195
1196 let req = self
1197 .client
1198 .post("/v1/messages")?
1199 .body(request)
1200 .map_err(|e| CompletionError::HttpError(e.into()))?;
1201
1202 let response = self
1203 .client
1204 .send::<_, Bytes>(req)
1205 .await
1206 .map_err(CompletionError::HttpError)?;
1207
1208 if response.status().is_success() {
1209 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
1210 response
1211 .into_body()
1212 .await
1213 .map_err(CompletionError::HttpError)?
1214 .to_vec()
1215 .as_slice(),
1216 )? {
1217 ApiResponse::Message(completion) => {
1218 let span = tracing::Span::current();
1219 span.record_response_metadata(&completion);
1220 span.record_token_usage(&completion.usage);
1221 if enabled!(Level::TRACE) {
1222 tracing::trace!(
1223 target: "rig::completions",
1224 "Anthropic completion response: {}",
1225 serde_json::to_string_pretty(&completion)?
1226 );
1227 }
1228 completion.try_into()
1229 }
1230 ApiResponse::Error(ApiErrorResponse { message }) => {
1231 Err(CompletionError::ResponseError(message))
1232 }
1233 }
1234 } else {
1235 let text: String = String::from_utf8_lossy(
1236 &response
1237 .into_body()
1238 .await
1239 .map_err(CompletionError::HttpError)?,
1240 )
1241 .into();
1242 Err(CompletionError::ProviderError(text))
1243 }
1244 }
1245 .instrument(span)
1246 .await
1247 }
1248
1249 async fn stream(
1250 &self,
1251 request: CompletionRequest,
1252 ) -> Result<
1253 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1254 CompletionError,
1255 > {
1256 CompletionModel::stream(self, request).await
1257 }
1258}
1259
1260#[derive(Debug, Deserialize)]
1261struct ApiErrorResponse {
1262 message: String,
1263}
1264
1265#[derive(Debug, Deserialize)]
1266#[serde(tag = "type", rename_all = "snake_case")]
1267enum ApiResponse<T> {
1268 Message(T),
1269 Error(ApiErrorResponse),
1270}
1271
// Unit tests for the Anthropic completion types in this module: serde
// (de)serialization of messages and content blocks, conversions to and from the
// provider-agnostic `message` types, and prompt-cache control handling.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use serde_path_to_error::deserialize;

    /// Deserializes three representative Anthropic `Message` payloads and checks
    /// every parsed content block:
    /// - an assistant message whose `content` is a bare string,
    /// - an assistant message with a `text` block followed by a `tool_use` block,
    /// - a user message with `image`, `text` and `tool_result` blocks.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        // `serde_path_to_error::deserialize` is used so that a failure reports the
        // JSON path of the offending field, not just the bare serde error.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // A bare-string `content` is expected to become a single text block
        // with no cache control.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned(),
                cache_control: None,
            }
        );

        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source, .. } => {
                    assert_eq!(
                        source,
                        ImageSource::Base64 {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // A string `content` on a tool result is expected to become a single
            // text item, with `is_error` left unset.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    /// Converts Anthropic `Message`s (user with image/text/document, assistant
    /// tool use, user tool result) into `message::Message`. It checks the
    /// converted structure, then converts back and asserts the round trip
    /// reproduces the original values.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
                cache_control: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                // NOTE(review): the base64 PDF payload is expected to surface as
                // `DocumentSourceKind::String` rather than a base64 kind; this pins
                // the current conversion behavior. Confirm it is intended.
                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Convert back to the Anthropic representation and require exact equality
        // with the inputs.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    /// Checks the `ContentFormat` <-> `SourceType` mapping in both directions:
    /// `Url` <-> `URL`, `Base64` <-> `BASE64`, `String` <-> `TEXT`.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let source_type: SourceType = ContentFormat::String.try_into().unwrap();
        assert_eq!(source_type, SourceType::TEXT);

        let content_format: ContentFormat = SourceType::TEXT.into();
        assert_eq!(content_format, ContentFormat::String);
    }

    /// Checks that `cache_control` serializes as `{"type":"ephemeral"}` when set
    /// and is omitted when `None`. Then runs `apply_cache_control` over one
    /// system block and two messages and checks which blocks end up marked.
    #[test]
    fn test_cache_control_serialization() {
        let system = SystemContent::Text {
            text: "You are a helpful assistant.".to_string(),
            cache_control: Some(CacheControl::Ephemeral),
        };
        let json = serde_json::to_string(&system).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
        assert!(json.contains(r#""type":"text""#));

        let system_no_cache = SystemContent::Text {
            text: "Hello".to_string(),
            cache_control: None,
        };
        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
        assert!(!json_no_cache.contains("cache_control"));

        let content = Content::Text {
            text: "Test message".to_string(),
            cache_control: Some(CacheControl::Ephemeral),
        };
        let json_content = serde_json::to_string(&content).unwrap();
        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

        let mut system_vec = vec![SystemContent::Text {
            text: "System prompt".to_string(),
            cache_control: None,
        }];
        let mut messages = vec![
            Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: "First message".to_string(),
                    cache_control: None,
                }),
            },
            Message {
                role: Role::Assistant,
                content: OneOrMany::one(Content::Text {
                    text: "Response".to_string(),
                    cache_control: None,
                }),
            },
        ];

        apply_cache_control(&mut system_vec, &mut messages);

        // Expected outcome: the system block is marked...
        match &system_vec[0] {
            SystemContent::Text { cache_control, .. } => {
                assert!(cache_control.is_some());
            }
        }

        // ...the first message is left untouched...
        for content in messages[0].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_none());
            }
        }

        // ...and the second (final) message is marked.
        for content in messages[1].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_some());
            }
        }
    }

    /// A plain-text document source serializes with `source.type == "text"` and
    /// `media_type == "text/plain"`.
    #[test]
    fn test_plaintext_document_serialization() {
        let content = Content::Document {
            source: DocumentSource::Text {
                data: "Hello, world!".to_string(),
                media_type: PlainTextMediaType::Plain,
            },
            cache_control: None,
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "document");
        assert_eq!(json["source"]["type"], "text");
        assert_eq!(json["source"]["media_type"], "text/plain");
        assert_eq!(json["source"]["data"], "Hello, world!");
    }

    /// A `text`/`text/plain` document source deserializes into
    /// `DocumentSource::Text`, with `cache_control` defaulting to `None`.
    #[test]
    fn test_plaintext_document_deserialization() {
        let json = r#"
        {
            "type": "document",
            "source": {
                "type": "text",
                "media_type": "text/plain",
                "data": "Hello, world!"
            }
        }
        "#;

        let content: Content = serde_json::from_str(json).unwrap();
        match content {
            Content::Document {
                source,
                cache_control,
            } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Hello, world!".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
                assert_eq!(cache_control, None);
            }
            _ => panic!("Expected Document content"),
        }
    }

    /// A base64 PDF document source serializes with `source.type == "base64"`
    /// and `media_type == "application/pdf"`.
    #[test]
    fn test_base64_pdf_document_serialization() {
        let content = Content::Document {
            source: DocumentSource::Base64 {
                data: "base64data".to_string(),
                media_type: DocumentFormat::PDF,
            },
            cache_control: None,
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "document");
        assert_eq!(json["source"]["type"], "base64");
        assert_eq!(json["source"]["media_type"], "application/pdf");
        assert_eq!(json["source"]["data"], "base64data");
    }

    /// A `base64`/`application/pdf` document source deserializes into
    /// `DocumentSource::Base64` with `DocumentFormat::PDF`.
    #[test]
    fn test_base64_pdf_document_deserialization() {
        let json = r#"
        {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "application/pdf",
                "data": "base64data"
            }
        }
        "#;

        let content: Content = serde_json::from_str(json).unwrap();
        match content {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Base64 {
                        data: "base64data".to_string(),
                        media_type: DocumentFormat::PDF,
                    }
                );
            }
            _ => panic!("Expected Document content"),
        }
    }

    /// A rig `TXT` document converts into an Anthropic `DocumentSource::Text`
    /// with the plain media type.
    #[test]
    fn test_plaintext_rig_to_anthropic_conversion() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::document(
                "Some plain text content".to_string(),
                Some(msg::DocumentMediaType::TXT),
            )),
        };

        let anthropic_message: Message = rig_message.try_into().unwrap();
        assert_eq!(anthropic_message.role, Role::User);

        let mut iter = anthropic_message.content.into_iter();
        match iter.next().unwrap() {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Some plain text content".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
            }
            other => panic!("Expected Document content, got: {other:?}"),
        }
    }

    /// An Anthropic plain-text document converts into a rig document carrying
    /// `DocumentSourceKind::String` data and the `TXT` media type.
    #[test]
    fn test_plaintext_anthropic_to_rig_conversion() {
        use crate::completion::message as msg;

        let anthropic_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::Document {
                source: DocumentSource::Text {
                    data: "Some plain text content".to_string(),
                    media_type: PlainTextMediaType::Plain,
                },
                cache_control: None,
            }),
        };

        let rig_message: msg::Message = anthropic_message.try_into().unwrap();
        match rig_message {
            msg::Message::User { content } => {
                let mut iter = content.into_iter();
                match iter.next().unwrap() {
                    msg::UserContent::Document(msg::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("Some plain text content".into())
                        );
                        assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
                    }
                    other => panic!("Expected Document content, got: {other:?}"),
                }
            }
            _ => panic!("Expected User message"),
        }
    }

    /// Converts a rig `TXT` document to Anthropic and back. Only the document
    /// media type is compared; the data payload kind is not asserted here.
    #[test]
    fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
        use crate::completion::message as msg;

        let original = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::document(
                "Round trip text".to_string(),
                Some(msg::DocumentMediaType::TXT),
            )),
        };

        let anthropic: Message = original.clone().try_into().unwrap();
        let back: msg::Message = anthropic.try_into().unwrap();

        match (&original, &back) {
            (
                msg::Message::User {
                    content: orig_content,
                },
                msg::Message::User {
                    content: back_content,
                },
            ) => match (orig_content.first(), back_content.first()) {
                (
                    msg::UserContent::Document(msg::Document {
                        media_type: orig_mt,
                        ..
                    }),
                    msg::UserContent::Document(msg::Document {
                        media_type: back_mt,
                        ..
                    }),
                ) => {
                    assert_eq!(orig_mt, back_mt);
                }
                _ => panic!("Expected Document content in both"),
            },
            _ => panic!("Expected User messages"),
        }
    }

    /// Converting a rig document with an unsupported media type (HTML) fails
    /// with an error mentioning the supported document kinds.
    #[test]
    fn test_unsupported_document_type_returns_error() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
                data: DocumentSourceKind::String("data".into()),
                media_type: Some(msg::DocumentMediaType::HTML),
                additional_params: None,
            })),
        };

        let result: Result<Message, _> = rig_message.try_into();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Anthropic only supports PDF and plain text documents"),
            "Unexpected error: {err}"
        );
    }

    /// A plain-text rig document backed by a URL is rejected; only string or
    /// base64 data is accepted for plain text.
    #[test]
    fn test_plaintext_document_url_source_returns_error() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
                data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
                media_type: Some(msg::DocumentMediaType::TXT),
                additional_params: None,
            })),
        };

        let result: Result<Message, _> = rig_message.try_into();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Only string or base64 data is supported for plain text documents"),
            "Unexpected error: {err}"
        );
    }

    /// A plain-text document with `cache_control` set serializes both the text
    /// source and the ephemeral cache-control marker.
    #[test]
    fn test_plaintext_document_with_cache_control() {
        let content = Content::Document {
            source: DocumentSource::Text {
                data: "cached text".to_string(),
                media_type: PlainTextMediaType::Plain,
            },
            cache_control: Some(CacheControl::Ephemeral),
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["source"]["type"], "text");
        assert_eq!(json["source"]["media_type"], "text/plain");
        assert_eq!(json["cache_control"]["type"], "ephemeral");
    }

    /// A user message that mixes a plain-text document block and a text block
    /// deserializes both blocks in order.
    #[test]
    fn test_message_with_plaintext_document_deserialization() {
        let json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "Hello from a text file"
                    }
                },
                {
                    "type": "text",
                    "text": "Summarize this document."
                }
            ]
        }
        "#;

        let message: Message = serde_json::from_str(json).unwrap();
        assert_eq!(message.role, Role::User);
        assert_eq!(message.content.len(), 2);

        let mut iter = message.content.into_iter();

        match iter.next().unwrap() {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Hello from a text file".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
            }
            _ => panic!("Expected Document content"),
        }

        match iter.next().unwrap() {
            Content::Text { text, .. } => {
                assert_eq!(text, "Summarize this document.");
            }
            _ => panic!("Expected Text content"),
        }
    }

    /// A single multi-part `Reasoning` is expected to expand into one Anthropic
    /// content block per part, in order:
    /// - signed text -> `Thinking` with its signature,
    /// - summary -> `Thinking` without a signature,
    /// - redacted data -> `RedactedThinking`.
    #[test]
    fn test_assistant_reasoning_multiblock_to_anthropic_content() {
        let reasoning = message::Reasoning {
            id: None,
            content: vec![
                message::ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig-1".to_string()),
                },
                message::ReasoningContent::Summary("summary".to_string()),
                message::ReasoningContent::Text {
                    text: "step two".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                message::ReasoningContent::Redacted {
                    data: "redacted block".to_string(),
                },
            ],
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
        };
        let converted: Message = msg.try_into().expect("convert assistant message");
        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();

        assert_eq!(converted.role, Role::Assistant);
        assert_eq!(converted_content.len(), 4);
        assert!(matches!(
            converted_content.first(),
            Some(Content::Thinking { thinking, signature: Some(signature) })
                if thinking == "step one" && signature == "sig-1"
        ));
        assert!(matches!(
            converted_content.get(1),
            Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
        ));
        assert!(matches!(
            converted_content.get(2),
            Some(Content::Thinking { thinking, signature: Some(signature) })
                if thinking == "step two" && signature == "sig-2"
        ));
        assert!(matches!(
            converted_content.get(3),
            Some(Content::RedactedThinking { data }) if data == "redacted block"
        ));
    }

    /// An Anthropic `RedactedThinking` block converts into an assistant
    /// `Reasoning` whose first part is `ReasoningContent::Redacted` carrying the
    /// same data.
    #[test]
    fn test_redacted_thinking_content_to_assistant_reasoning() {
        let content = Content::RedactedThinking {
            data: "opaque-redacted".to_string(),
        };
        let converted: message::AssistantContent =
            content.try_into().expect("convert redacted thinking");

        assert!(matches!(
            converted,
            message::AssistantContent::Reasoning(message::Reasoning { content, .. })
                if matches!(
                    content.first(),
                    Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
                )
        ));
    }

    /// Encrypted reasoning content is expected to map onto a single Anthropic
    /// `RedactedThinking` block carrying the ciphertext.
    #[test]
    fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
        let reasoning = message::Reasoning {
            id: None,
            content: vec![message::ReasoningContent::Encrypted(
                "ciphertext".to_string(),
            )],
        };
        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
        };

        let converted: Message = msg.try_into().expect("convert assistant message");
        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();

        assert_eq!(converted_content.len(), 1);
        assert!(matches!(
            converted_content.first(),
            Some(Content::RedactedThinking { data }) if data == "ciphertext"
        ));
    }
}