1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
// Model identifier strings that can be passed as the `model` of a `CompletionModel`.

/// Model identifier `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model identifier `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model identifier `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model identifier `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model identifier `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";
35
/// API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest version constant defined here (currently `2023-06-01`).
// NOTE(review): presumably used as the API version header value by the client — not visible in this file.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Body of a successful completion response, as parsed by `completion` from
/// the `/v1/messages` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model.
    pub content: Vec<Content>,
    /// Identifier of the response.
    pub id: String,
    /// Name of the model reported in the response.
    pub model: String,
    /// Role string of the response message.
    pub role: String,
    /// Reason generation stopped, if reported.
    pub stop_reason: Option<String>,
    /// Stop sequence that was hit, if reported.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52 type OutputMessage = Content;
53 type Usage = Usage;
54
55 fn get_response_id(&self) -> Option<String> {
56 Some(self.id.to_owned())
57 }
58
59 fn get_response_model_name(&self) -> Option<String> {
60 Some(self.model.to_owned())
61 }
62
63 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64 self.content.clone()
65 }
66
67 fn get_text_response(&self) -> Option<String> {
68 let res = self
69 .content
70 .iter()
71 .filter_map(|x| {
72 if let Content::Text { text, .. } = x {
73 Some(text.to_owned())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<String>>()
79 .join("\n");
80
81 if res.is_empty() { None } else { Some(res) }
82 }
83
84 fn get_usage(&self) -> Option<Self::Usage> {
85 Some(self.usage.clone())
86 }
87}
88
/// Token usage figures carried in a completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input token count.
    pub input_tokens: u64,
    /// Input tokens read from the cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the cache, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output token count.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(
100 f,
101 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102 self.input_tokens,
103 match self.cache_read_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 match self.cache_creation_input_tokens {
108 Some(token) => token.to_string(),
109 None => "n/a".to_string(),
110 },
111 self.output_tokens
112 )
113 }
114}
115
116impl GetTokenUsage for Usage {
117 fn token_usage(&self) -> Option<crate::completion::Usage> {
118 let mut usage = crate::completion::Usage::new();
119
120 usage.input_tokens = self.input_tokens
121 + self.cache_creation_input_tokens.unwrap_or_default()
122 + self.cache_read_input_tokens.unwrap_or_default();
123 usage.output_tokens = self.output_tokens;
124 usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126 Some(usage)
127 }
128}
129
/// Tool description as serialized into the `tools` list of a request.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Optional free-text description of the tool.
    pub description: Option<String>,
    /// JSON value describing the tool's input (a schema, judging by the field name).
    pub input_schema: serde_json::Value,
}
136
/// Cache-control marker attached to content blocks.
///
/// Serialized as an object tagged by `type` with a snake_case variant name
/// (so `Ephemeral` becomes `{"type": "ephemeral"}`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// The only cache-control kind defined here.
    Ephemeral,
}
143
/// One block of the request's `system` list, tagged by `type` in snake_case.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    /// A text block.
    Text {
        /// The system text.
        text: String,
        /// Optional cache marker; omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156 type Error = CompletionError;
157
158 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159 let content = response
160 .content
161 .iter()
162 .map(|content| content.clone().try_into())
163 .collect::<Result<Vec<_>, _>>()?;
164
165 let choice = OneOrMany::many(content).map_err(|_| {
166 CompletionError::ResponseError(
167 "Response contained no message or tool call (empty)".to_owned(),
168 )
169 })?;
170
171 let usage = completion::Usage {
172 input_tokens: response.usage.input_tokens,
173 output_tokens: response.usage.output_tokens,
174 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175 cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
176 };
177
178 Ok(completion::CompletionResponse {
179 choice,
180 usage,
181 raw_response: response,
182 message_id: None,
183 })
184 }
185}
186
/// A single conversation message in Anthropic's wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// One or more content blocks. Deserialization goes through
    /// `string_or_one_or_many`, so (as the helper name and the tests in this
    /// file suggest) a bare string is accepted as well as a list of blocks.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
193
/// Author of a `Message`; serialized in lowercase (`"user"`, `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message written by the user.
    User,
    /// Message written by the assistant.
    Assistant,
}
200
/// A content block inside a `Message` or a `CompletionResponse`.
///
/// Serialized as an object tagged by `type` with snake_case variant names
/// (`text`, `image`, `tool_use`, `tool_result`, `document`, `thinking`,
/// `redacted_thinking`). Every `Option` field is omitted when `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image, given by a base64 payload or a URL (see `ImageSource`).
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation with its id, tool name and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result for a tool invocation identified by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        /// Deserialized via `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document (see `DocumentSource`).
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Reasoning text with an optional signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Opaque reasoning payload.
    RedactedThinking {
        data: String,
    },
}
242
243impl FromStr for Content {
244 type Err = Infallible;
245
246 fn from_str(s: &str) -> Result<Self, Self::Err> {
247 Ok(Content::Text {
248 text: s.to_owned(),
249 cache_control: None,
250 })
251 }
252}
253
/// Payload of a `Content::ToolResult`, tagged by `type` in snake_case.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Textual tool output.
    Text { text: String },
    /// Image tool output.
    // NOTE(review): `ImageSource` is itself tagged by `type`; confirm the
    // serialized shape of this newtype variant matches what the API expects.
    Image(ImageSource),
}
260
261impl FromStr for ToolResultContent {
262 type Err = Infallible;
263
264 fn from_str(s: &str) -> Result<Self, Self::Err> {
265 Ok(ToolResultContent::Text { text: s.to_owned() })
266 }
267}
268
/// Where an image's bytes come from; tagged by `type` (`"base64"` or `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline payload plus its media type.
    #[serde(rename = "base64")]
    Base64 {
        data: String,
        media_type: ImageFormat,
    },
    /// Image referenced by URL.
    #[serde(rename = "url")]
    Url { url: String },
}
287
/// Where a document's contents come from; tagged by `type` in snake_case
/// (`"base64"` or `"text"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
    /// Inline payload with a `DocumentFormat` media type (only PDF is defined).
    Base64 {
        data: String,
        media_type: DocumentFormat,
    },
    /// Inline text with the `text/plain` media type.
    Text {
        data: String,
        media_type: PlainTextMediaType,
    },
}
310
/// Image media types; each variant serializes as its MIME string.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    /// `image/jpeg`
    #[serde(rename = "image/jpeg")]
    JPEG,
    /// `image/png`
    #[serde(rename = "image/png")]
    PNG,
    /// `image/gif`
    #[serde(rename = "image/gif")]
    GIF,
    /// `image/webp`
    #[serde(rename = "image/webp")]
    WEBP,
}
323
/// Media types allowed for `DocumentSource::Base64`; serialized as a MIME string.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    /// `application/pdf`
    #[serde(rename = "application/pdf")]
    PDF,
}
336
/// Media type used by `DocumentSource::Text`; serialized as a MIME string.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum PlainTextMediaType {
    /// `text/plain`
    #[serde(rename = "text/plain")]
    Plain,
}
347
/// Kind of content source; serialized in lowercase (`"base64"`, `"url"`, `"text"`).
///
/// Converts to and from `message::ContentFormat` via the impls in this file.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    /// Corresponds to `message::ContentFormat::Base64`.
    BASE64,
    /// Corresponds to `message::ContentFormat::Url`.
    URL,
    /// Corresponds to `message::ContentFormat::String`.
    TEXT,
}
355
356impl From<String> for Content {
357 fn from(text: String) -> Self {
358 Content::Text {
359 text,
360 cache_control: None,
361 }
362 }
363}
364
365impl From<String> for ToolResultContent {
366 fn from(text: String) -> Self {
367 ToolResultContent::Text { text }
368 }
369}
370
371impl TryFrom<message::ContentFormat> for SourceType {
372 type Error = MessageError;
373
374 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
375 match format {
376 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
377 message::ContentFormat::Url => Ok(SourceType::URL),
378 message::ContentFormat::String => Ok(SourceType::TEXT),
379 }
380 }
381}
382
383impl From<SourceType> for message::ContentFormat {
384 fn from(source_type: SourceType) -> Self {
385 match source_type {
386 SourceType::BASE64 => message::ContentFormat::Base64,
387 SourceType::URL => message::ContentFormat::Url,
388 SourceType::TEXT => message::ContentFormat::String,
389 }
390 }
391}
392
393impl TryFrom<message::ImageMediaType> for ImageFormat {
394 type Error = MessageError;
395
396 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
397 Ok(match media_type {
398 message::ImageMediaType::JPEG => ImageFormat::JPEG,
399 message::ImageMediaType::PNG => ImageFormat::PNG,
400 message::ImageMediaType::GIF => ImageFormat::GIF,
401 message::ImageMediaType::WEBP => ImageFormat::WEBP,
402 _ => {
403 return Err(MessageError::ConversionError(
404 format!("Unsupported image media type: {media_type:?}").to_owned(),
405 ));
406 }
407 })
408 }
409}
410
411impl From<ImageFormat> for message::ImageMediaType {
412 fn from(format: ImageFormat) -> Self {
413 match format {
414 ImageFormat::JPEG => message::ImageMediaType::JPEG,
415 ImageFormat::PNG => message::ImageMediaType::PNG,
416 ImageFormat::GIF => message::ImageMediaType::GIF,
417 ImageFormat::WEBP => message::ImageMediaType::WEBP,
418 }
419 }
420}
421
422impl TryFrom<DocumentMediaType> for DocumentFormat {
423 type Error = MessageError;
424 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
425 match value {
426 DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
427 other => Err(MessageError::ConversionError(format!(
428 "DocumentFormat only supports PDF for base64 sources, got: {}",
429 other.to_mime_type()
430 ))),
431 }
432 }
433}
434
435impl TryFrom<message::AssistantContent> for Content {
436 type Error = MessageError;
437 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
438 match text {
439 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
440 text,
441 cache_control: None,
442 }),
443 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
444 "Anthropic currently doesn't support images.".to_string(),
445 )),
446 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
447 Ok(Content::ToolUse {
448 id,
449 name: function.name,
450 input: function.arguments,
451 })
452 }
453 message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
454 thinking: reasoning.display_text(),
455 signature: reasoning.first_signature().map(str::to_owned),
456 }),
457 }
458 }
459}
460
461fn anthropic_content_from_assistant_content(
462 content: message::AssistantContent,
463) -> Result<Vec<Content>, MessageError> {
464 match content {
465 message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
466 text,
467 cache_control: None,
468 }]),
469 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
470 "Anthropic currently doesn't support images.".to_string(),
471 )),
472 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
473 Ok(vec![Content::ToolUse {
474 id,
475 name: function.name,
476 input: function.arguments,
477 }])
478 }
479 message::AssistantContent::Reasoning(reasoning) => {
480 let mut converted = Vec::new();
481 for block in reasoning.content {
482 match block {
483 message::ReasoningContent::Text { text, signature } => {
484 converted.push(Content::Thinking {
485 thinking: text,
486 signature,
487 });
488 }
489 message::ReasoningContent::Summary(summary) => {
490 converted.push(Content::Thinking {
491 thinking: summary,
492 signature: None,
493 });
494 }
495 message::ReasoningContent::Redacted { data }
496 | message::ReasoningContent::Encrypted(data) => {
497 converted.push(Content::RedactedThinking { data });
498 }
499 }
500 }
501
502 if converted.is_empty() {
503 return Err(MessageError::ConversionError(
504 "Cannot convert empty reasoning content to Anthropic format".to_string(),
505 ));
506 }
507
508 Ok(converted)
509 }
510 }
511}
512
/// Converts a generic `message::Message` into the Anthropic wire `Message`.
///
/// User content is mapped block by block; assistant content is expanded via
/// `anthropic_content_from_assistant_content`, so one assistant item may
/// yield several blocks.
///
/// # Errors
/// Returns `MessageError::ConversionError` when a piece of content has no
/// Anthropic representation here: audio, video, images without a usable
/// source or media type, documents that are neither PDF nor plain text, or an
/// assistant message that converts to no blocks.
impl TryFrom<message::Message> for Message {
    type Error = MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        Ok(match message {
            message::Message::User { content } => Message {
                role: Role::User,
                content: content.try_map(|content| match content {
                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
                        text,
                        cache_control: None,
                    }),
                    // Tool results: text passes through; images must be
                    // base64 with a known media type.
                    message::UserContent::ToolResult(message::ToolResult {
                        id, content, ..
                    }) => Ok(Content::ToolResult {
                        tool_use_id: id,
                        content: content.try_map(|content| match content {
                            message::ToolResultContent::Text(message::Text { text }) => {
                                Ok(ToolResultContent::Text { text })
                            }
                            message::ToolResultContent::Image(image) => {
                                let DocumentSourceKind::Base64(data) = image.data else {
                                    return Err(MessageError::ConversionError(
                                        "Only base64 strings can be used with the Anthropic API"
                                            .to_string(),
                                    ));
                                };
                                let media_type =
                                    image.media_type.ok_or(MessageError::ConversionError(
                                        "Image media type is required".to_owned(),
                                    ))?;
                                Ok(ToolResultContent::Image(ImageSource::Base64 {
                                    data,
                                    media_type: media_type.try_into()?,
                                }))
                            }
                        })?,
                        // The generic tool result's other fields are not
                        // carried over; both optional fields start unset.
                        is_error: None,
                        cache_control: None,
                    }),
                    // User images: base64 (media type required) or URL.
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        let source = match data {
                            DocumentSourceKind::Base64(data) => {
                                let media_type =
                                    media_type.ok_or(MessageError::ConversionError(
                                        "Image media type is required for Claude API".to_string(),
                                    ))?;
                                ImageSource::Base64 {
                                    data,
                                    media_type: ImageFormat::try_from(media_type)?,
                                }
                            }
                            DocumentSourceKind::Url(url) => ImageSource::Url { url },
                            DocumentSourceKind::Unknown => {
                                return Err(MessageError::ConversionError(
                                    "Image content has no body".into(),
                                ));
                            }
                            doc => {
                                return Err(MessageError::ConversionError(format!(
                                    "Unsupported document type: {doc:?}"
                                )));
                            }
                        };

                        Ok(Content::Image {
                            source,
                            cache_control: None,
                        })
                    }
                    // Documents: the media type is mandatory and selects the
                    // source variant (PDF -> Base64, TXT -> Text).
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        let media_type = media_type.ok_or(MessageError::ConversionError(
                            "Document media type is required".to_string(),
                        ))?;

                        let source = match media_type {
                            DocumentMediaType::PDF => {
                                // Both Base64 and String payloads are passed
                                // through unchanged as the base64 data.
                                let data = match data {
                                    DocumentSourceKind::Base64(data)
                                    | DocumentSourceKind::String(data) => data,
                                    _ => {
                                        return Err(MessageError::ConversionError(
                                            "Only base64 encoded data is supported for PDF documents".into(),
                                        ));
                                    }
                                };
                                DocumentSource::Base64 {
                                    data,
                                    media_type: DocumentFormat::PDF,
                                }
                            }
                            DocumentMediaType::TXT => {
                                // Payload is forwarded as-is in either case;
                                // no decoding happens here.
                                let data = match data {
                                    DocumentSourceKind::String(data)
                                    | DocumentSourceKind::Base64(data) => data,
                                    _ => {
                                        return Err(MessageError::ConversionError(
                                            "Only string or base64 data is supported for plain text documents".into(),
                                        ));
                                    }
                                };
                                DocumentSource::Text {
                                    data,
                                    media_type: PlainTextMediaType::Plain,
                                }
                            }
                            other => {
                                return Err(MessageError::ConversionError(format!(
                                    "Anthropic only supports PDF and plain text documents, got: {}",
                                    other.to_mime_type()
                                )));
                            }
                        };

                        Ok(Content::Document {
                            source,
                            cache_control: None,
                        })
                    }
                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
                        "Audio is not supported in Anthropic".to_owned(),
                    )),
                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
                        "Video is not supported in Anthropic".to_owned(),
                    )),
                })?,
            },

            message::Message::Assistant { content, .. } => {
                // Flatten: each assistant item may expand to several blocks.
                let converted_content = content.into_iter().try_fold(
                    Vec::new(),
                    |mut accumulated, assistant_content| {
                        accumulated
                            .extend(anthropic_content_from_assistant_content(assistant_content)?);
                        Ok::<Vec<Content>, MessageError>(accumulated)
                    },
                )?;

                Message {
                    // `OneOrMany::many` fails on an empty list, which is
                    // reported as a conversion error.
                    content: OneOrMany::many(converted_content).map_err(|_| {
                        MessageError::ConversionError(
                            "Assistant message did not contain Anthropic-compatible content"
                                .to_owned(),
                        )
                    })?,
                    role: Role::Assistant,
                }
            }
        })
    }
}
668
669impl TryFrom<Content> for message::AssistantContent {
670 type Error = MessageError;
671
672 fn try_from(content: Content) -> Result<Self, Self::Error> {
673 Ok(match content {
674 Content::Text { text, .. } => message::AssistantContent::text(text),
675 Content::ToolUse { id, name, input } => {
676 message::AssistantContent::tool_call(id, name, input)
677 }
678 Content::Thinking {
679 thinking,
680 signature,
681 } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
682 &thinking, signature,
683 )),
684 Content::RedactedThinking { data } => {
685 message::AssistantContent::Reasoning(Reasoning::redacted(data))
686 }
687 _ => {
688 return Err(MessageError::ConversionError(
689 "Content did not contain a message, tool call, or reasoning".to_owned(),
690 ));
691 }
692 })
693 }
694}
695
696impl From<ToolResultContent> for message::ToolResultContent {
697 fn from(content: ToolResultContent) -> Self {
698 match content {
699 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
700 ToolResultContent::Image(source) => match source {
701 ImageSource::Base64 { data, media_type } => {
702 message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
703 }
704 ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
705 },
706 }
707 }
708}
709
/// Converts an Anthropic wire `Message` back into a generic `message::Message`.
///
/// # Errors
/// Returns `MessageError::ConversionError` when a user message contains a
/// block other than text, tool result, image or document, or when an
/// assistant block cannot be converted to `message::AssistantContent`.
impl TryFrom<Message> for message::Message {
    type Error = MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        Ok(match message.role {
            Role::User => message::Message::User {
                content: message.content.try_map(|content| {
                    Ok(match content {
                        Content::Text { text, .. } => message::UserContent::text(text),
                        // Only the id and payload are passed on; the block's
                        // remaining fields are ignored by the `..` pattern.
                        Content::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => message::UserContent::tool_result(
                            tool_use_id,
                            content.map(|content| content.into()),
                        ),
                        Content::Image { source, .. } => match source {
                            ImageSource::Base64 { data, media_type } => {
                                message::UserContent::Image(message::Image {
                                    data: DocumentSourceKind::Base64(data),
                                    media_type: Some(media_type.into()),
                                    detail: None,
                                    additional_params: None,
                                })
                            }
                            // URL images carry no media type on the wire.
                            ImageSource::Url { url } => {
                                message::UserContent::Image(message::Image {
                                    data: DocumentSourceKind::Url(url),
                                    media_type: None,
                                    detail: None,
                                    additional_params: None,
                                })
                            }
                        },
                        Content::Document { source, .. } => match source {
                            DocumentSource::Base64 { data, media_type } => {
                                let rig_media_type = match media_type {
                                    DocumentFormat::PDF => message::DocumentMediaType::PDF,
                                };
                                message::UserContent::document(data, Some(rig_media_type))
                            }
                            DocumentSource::Text { data, .. } => message::UserContent::document(
                                data,
                                Some(message::DocumentMediaType::TXT),
                            ),
                        },
                        _ => {
                            return Err(MessageError::ConversionError(
                                "Unsupported content type for User role".to_owned(),
                            ));
                        }
                    })
                })?,
            },
            Role::Assistant => message::Message::Assistant {
                id: None,
                content: message.content.try_map(|content| content.try_into())?,
            },
        })
    }
}
772
/// Completion model handle bound to a `Client` and a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name used unless a request names its own model.
    pub model: String,
    /// `max_tokens` applied by `completion` when a request does not set one.
    pub default_max_tokens: Option<u64>,
    /// When true, cache-control markers are added while building requests.
    pub prompt_caching: bool,
}
781
782impl<T> CompletionModel<T>
783where
784 T: HttpClientExt,
785{
786 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
787 let model = model.into();
788 let default_max_tokens = calculate_max_tokens(&model);
789
790 Self {
791 client,
792 model,
793 default_max_tokens,
794 prompt_caching: false, }
796 }
797
798 pub fn with_model(client: Client<T>, model: &str) -> Self {
799 Self {
800 client,
801 model: model.to_string(),
802 default_max_tokens: Some(calculate_max_tokens_custom(model)),
803 prompt_caching: false, }
805 }
806
807 pub fn with_prompt_caching(mut self) -> Self {
815 self.prompt_caching = true;
816 self
817 }
818}
819
/// Returns the default `max_tokens` for a known model-name prefix.
///
/// The table is scanned in order and the first prefix that `model` starts
/// with wins. Returns `None` when no prefix matches.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, default max_tokens)
    const DEFAULT_LIMITS: &[(&str, u64)] = &[
        ("claude-opus-4", 32000),
        ("claude-sonnet-4", 64000),
        ("claude-3-7-sonnet", 64000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    DEFAULT_LIMITS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}

/// Same lookup as [`calculate_max_tokens`], with a fallback of 2048 for
/// model names that match no known prefix.
///
/// Delegating keeps the two functions from drifting apart when the table
/// changes.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    calculate_max_tokens(model).unwrap_or(2048)
}
856
/// Request metadata holding an optional user identifier.
// NOTE(review): not referenced by any request type visible in this chunk.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    /// Optional user id.
    user_id: Option<String>,
}
861
/// Tool-selection mode sent as `tool_choice`; tagged by `type` in snake_case
/// (`auto`, `any`, `none`, `tool`). Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// Converted from `message::ToolChoice::Auto`.
    #[default]
    Auto,
    /// Converted from `message::ToolChoice::Required`.
    Any,
    /// Converted from `message::ToolChoice::None`.
    None,
    /// Names one specific tool.
    Tool {
        name: String,
    },
}
873impl TryFrom<message::ToolChoice> for ToolChoice {
874 type Error = CompletionError;
875
876 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
877 let res = match value {
878 message::ToolChoice::Auto => Self::Auto,
879 message::ToolChoice::None => Self::None,
880 message::ToolChoice::Required => Self::Any,
881 message::ToolChoice::Specific { function_names } => {
882 if function_names.len() != 1 {
883 return Err(CompletionError::ProviderError(
884 "Only one tool may be specified to be used by Claude".into(),
885 ));
886 }
887
888 Self::Tool {
889 name: function_names.first().unwrap().to_string(),
890 }
891 }
892 };
893
894 Ok(res)
895 }
896}
897
898fn sanitize_schema(schema: &mut serde_json::Value) {
904 use serde_json::Value;
905
906 if let Value::Object(obj) = schema {
907 let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
908 || obj.contains_key("properties");
909
910 if is_object_schema && !obj.contains_key("additionalProperties") {
911 obj.insert("additionalProperties".to_string(), Value::Bool(false));
912 }
913
914 if let Some(Value::Object(properties)) = obj.get("properties") {
915 let prop_keys = properties.keys().cloned().map(Value::String).collect();
916 obj.insert("required".to_string(), Value::Array(prop_keys));
917 }
918
919 let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
921 || obj.get("type") == Some(&Value::String("number".to_string()));
922
923 if is_numeric_schema {
924 for key in [
925 "minimum",
926 "maximum",
927 "exclusiveMinimum",
928 "exclusiveMaximum",
929 "multipleOf",
930 ] {
931 obj.remove(key);
932 }
933 }
934
935 if let Some(defs) = obj.get_mut("$defs")
936 && let Value::Object(defs_obj) = defs
937 {
938 for (_, def_schema) in defs_obj.iter_mut() {
939 sanitize_schema(def_schema);
940 }
941 }
942
943 if let Some(properties) = obj.get_mut("properties")
944 && let Value::Object(props) = properties
945 {
946 for (_, prop_value) in props.iter_mut() {
947 sanitize_schema(prop_value);
948 }
949 }
950
951 if let Some(items) = obj.get_mut("items") {
952 sanitize_schema(items);
953 }
954
955 for key in ["anyOf", "oneOf", "allOf"] {
956 if let Some(variants) = obj.get_mut(key)
957 && let Value::Array(variants_array) = variants
958 {
959 for variant in variants_array.iter_mut() {
960 sanitize_schema(variant);
961 }
962 }
963 }
964 }
965}
966
/// Output format of a request; tagged by `type` in snake_case, so the only
/// variant serializes with `"type": "json_schema"`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    /// Structured output described by a JSON schema value.
    JsonSchema { schema: serde_json::Value },
}
975
/// Value of the request's `output_config` field; wraps the output format.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    /// Requested output format.
    format: OutputFormat,
}
981
/// JSON body sent by `completion` to `/v1/messages`.
///
/// Empty lists and `None` options are omitted from the serialized form.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    /// Model name.
    model: String,
    /// Conversation messages.
    messages: Vec<Message>,
    /// Token limit for the response; always present.
    max_tokens: u64,
    /// System blocks; skipped when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Tool definitions; skipped when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    /// Extra caller-supplied JSON, flattened into the top level of the body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
1001
1002fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1004 match content {
1005 Content::Text { cache_control, .. } => *cache_control = value,
1006 Content::Image { cache_control, .. } => *cache_control = value,
1007 Content::ToolResult { cache_control, .. } => *cache_control = value,
1008 Content::Document { cache_control, .. } => *cache_control = value,
1009 _ => {}
1010 }
1011}
1012
1013pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1018 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1020 *cache_control = Some(CacheControl::Ephemeral);
1021 }
1022
1023 for msg in messages.iter_mut() {
1025 for content in msg.content.iter_mut() {
1026 set_content_cache_control(content, None);
1027 }
1028 }
1029
1030 if let Some(last_msg) = messages.last_mut() {
1032 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
1033 }
1034}
1035
/// Inputs needed to build an `AnthropicCompletionRequest`.
pub struct AnthropicRequestParams<'a> {
    /// Model name to place in the request.
    pub model: &'a str,
    /// The generic completion request to translate.
    pub request: CompletionRequest,
    /// Whether to apply cache-control markers (see `apply_cache_control`).
    pub prompt_caching: bool,
}
1042
1043impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1044 type Error = CompletionError;
1045
1046 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1047 let AnthropicRequestParams {
1048 model,
1049 request: req,
1050 prompt_caching,
1051 } = params;
1052
1053 let Some(max_tokens) = req.max_tokens else {
1055 return Err(CompletionError::RequestError(
1056 "`max_tokens` must be set for Anthropic".into(),
1057 ));
1058 };
1059
1060 let mut full_history = vec![];
1061 if let Some(docs) = req.normalized_documents() {
1062 full_history.push(docs);
1063 }
1064 full_history.extend(req.chat_history);
1065
1066 let mut messages = full_history
1067 .into_iter()
1068 .map(Message::try_from)
1069 .collect::<Result<Vec<Message>, _>>()?;
1070
1071 let tools = req
1072 .tools
1073 .into_iter()
1074 .map(|tool| ToolDefinition {
1075 name: tool.name,
1076 description: Some(tool.description),
1077 input_schema: tool.parameters,
1078 })
1079 .collect::<Vec<_>>();
1080
1081 let mut system = if let Some(preamble) = req.preamble {
1083 if preamble.is_empty() {
1084 vec![]
1085 } else {
1086 vec![SystemContent::Text {
1087 text: preamble,
1088 cache_control: None,
1089 }]
1090 }
1091 } else {
1092 vec![]
1093 };
1094
1095 if prompt_caching {
1097 apply_cache_control(&mut system, &mut messages);
1098 }
1099
1100 let output_config = req.output_schema.map(|schema| {
1102 let mut schema_value = schema.to_value();
1103 sanitize_schema(&mut schema_value);
1104 OutputConfig {
1105 format: OutputFormat::JsonSchema {
1106 schema: schema_value,
1107 },
1108 }
1109 });
1110
1111 Ok(Self {
1112 model: model.to_string(),
1113 messages,
1114 max_tokens,
1115 system,
1116 temperature: req.temperature,
1117 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1118 tools,
1119 output_config,
1120 additional_params: req.additional_params,
1121 })
1122 }
1123}
1124
1125impl<T> completion::CompletionModel for CompletionModel<T>
1126where
1127 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
1128{
1129 type Response = CompletionResponse;
1130 type StreamingResponse = StreamingCompletionResponse;
1131 type Client = Client<T>;
1132
1133 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1134 Self::new(client.clone(), model.into())
1135 }
1136
1137 async fn completion(
1138 &self,
1139 mut completion_request: completion::CompletionRequest,
1140 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1141 let request_model = completion_request
1142 .model
1143 .clone()
1144 .unwrap_or_else(|| self.model.clone());
1145 let span = if tracing::Span::current().is_disabled() {
1146 info_span!(
1147 target: "rig::completions",
1148 "chat",
1149 gen_ai.operation.name = "chat",
1150 gen_ai.provider.name = "anthropic",
1151 gen_ai.request.model = &request_model,
1152 gen_ai.system_instructions = &completion_request.preamble,
1153 gen_ai.response.id = tracing::field::Empty,
1154 gen_ai.response.model = tracing::field::Empty,
1155 gen_ai.usage.output_tokens = tracing::field::Empty,
1156 gen_ai.usage.input_tokens = tracing::field::Empty,
1157 )
1158 } else {
1159 tracing::Span::current()
1160 };
1161
1162 if completion_request.max_tokens.is_none() {
1164 if let Some(tokens) = self.default_max_tokens {
1165 completion_request.max_tokens = Some(tokens);
1166 } else {
1167 return Err(CompletionError::RequestError(
1168 "`max_tokens` must be set for Anthropic".into(),
1169 ));
1170 }
1171 }
1172
1173 let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
1174 model: &request_model,
1175 request: completion_request,
1176 prompt_caching: self.prompt_caching,
1177 })?;
1178
1179 if enabled!(Level::TRACE) {
1180 tracing::trace!(
1181 target: "rig::completions",
1182 "Anthropic completion request: {}",
1183 serde_json::to_string_pretty(&request)?
1184 );
1185 }
1186
1187 async move {
1188 let request: Vec<u8> = serde_json::to_vec(&request)?;
1189
1190 let req = self
1191 .client
1192 .post("/v1/messages")?
1193 .body(request)
1194 .map_err(|e| CompletionError::HttpError(e.into()))?;
1195
1196 let response = self
1197 .client
1198 .send::<_, Bytes>(req)
1199 .await
1200 .map_err(CompletionError::HttpError)?;
1201
1202 if response.status().is_success() {
1203 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
1204 response
1205 .into_body()
1206 .await
1207 .map_err(CompletionError::HttpError)?
1208 .to_vec()
1209 .as_slice(),
1210 )? {
1211 ApiResponse::Message(completion) => {
1212 let span = tracing::Span::current();
1213 span.record_response_metadata(&completion);
1214 span.record_token_usage(&completion.usage);
1215 if enabled!(Level::TRACE) {
1216 tracing::trace!(
1217 target: "rig::completions",
1218 "Anthropic completion response: {}",
1219 serde_json::to_string_pretty(&completion)?
1220 );
1221 }
1222 completion.try_into()
1223 }
1224 ApiResponse::Error(ApiErrorResponse { message }) => {
1225 Err(CompletionError::ResponseError(message))
1226 }
1227 }
1228 } else {
1229 let text: String = String::from_utf8_lossy(
1230 &response
1231 .into_body()
1232 .await
1233 .map_err(CompletionError::HttpError)?,
1234 )
1235 .into();
1236 Err(CompletionError::ProviderError(text))
1237 }
1238 }
1239 .instrument(span)
1240 .await
1241 }
1242
1243 async fn stream(
1244 &self,
1245 request: CompletionRequest,
1246 ) -> Result<
1247 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1248 CompletionError,
1249 > {
1250 CompletionModel::stream(self, request).await
1251 }
1252}
1253
1254#[derive(Debug, Deserialize)]
1255struct ApiErrorResponse {
1256 message: String,
1257}
1258
1259#[derive(Debug, Deserialize)]
1260#[serde(tag = "type", rename_all = "snake_case")]
1261enum ApiResponse<T> {
1262 Message(T),
1263 Error(ApiErrorResponse),
1264}
1265
1266#[cfg(test)]
1267mod tests {
1268 use super::*;
1269 use serde_json::json;
1270 use serde_path_to_error::deserialize;
1271
/// Exercises `Message` deserialization from three wire shapes: an assistant
/// reply whose `content` is a bare string, an assistant reply made of a text
/// block plus a tool-use block, and a user message mixing image, text and
/// tool-result blocks.
#[test]
fn test_deserialize_message() {
    /// Parses `raw` as a [`Message`], reporting the offending JSON path on failure.
    fn parse_message(raw: &str) -> Message {
        let mut json_deserializer = serde_json::Deserializer::from_str(raw);
        deserialize(&mut json_deserializer).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    const GREETING: &str = "\n\nHello there, how may I assist you today?";
    const TOOL_USE_ID: &str = "toolu_01A09q90qw90lq917835lq9";

    let plain_assistant_json = r#"
    {
        "role": "assistant",
        "content": "\n\nHello there, how may I assist you today?"
    }
    "#;

    let block_assistant_json = r#"
    {
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": "\n\nHello there, how may I assist you today?"
            },
            {
                "type": "tool_use",
                "id": "toolu_01A09q90qw90lq917835lq9",
                "name": "get_weather",
                "input": {"location": "San Francisco, CA"}
            }
        ]
    }
    "#;

    let user_json = r#"
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/jpeg",
                    "data": "/9j/4AAQSkZJRg..."
                }
            },
            {
                "type": "text",
                "text": "What is in this image?"
            },
            {
                "type": "tool_result",
                "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                "content": "15 degrees"
            }
        ]
    }
    "#;

    // Parse every fixture up front so a malformed one fails before any assertion runs.
    let plain_assistant = parse_message(plain_assistant_json);
    let block_assistant = parse_message(block_assistant_json);
    let user = parse_message(user_json);

    // A bare-string `content` must surface as a single text block.
    let Message { role, content } = plain_assistant;
    assert_eq!(role, Role::Assistant);
    assert_eq!(
        content.first(),
        Content::Text {
            text: GREETING.to_owned(),
            cache_control: None,
        }
    );

    // Text followed by tool use, in order, and nothing else.
    let Message { role, content } = block_assistant;
    assert_eq!(role, Role::Assistant);
    assert_eq!(content.len(), 2);
    let mut assistant_blocks = content.into_iter();

    let Content::Text { text, .. } = assistant_blocks.next().unwrap() else {
        panic!("Expected text content");
    };
    assert_eq!(text, GREETING);

    let Content::ToolUse { id, name, input } = assistant_blocks.next().unwrap() else {
        panic!("Expected tool use content");
    };
    assert_eq!(id, TOOL_USE_ID);
    assert_eq!(name, "get_weather");
    assert_eq!(input, json!({"location": "San Francisco, CA"}));

    assert_eq!(assistant_blocks.next(), None);

    // Image, text and tool result, in order, and nothing else.
    let Message { role, content } = user;
    assert_eq!(role, Role::User);
    assert_eq!(content.len(), 3);
    let mut user_blocks = content.into_iter();

    let Content::Image { source, .. } = user_blocks.next().unwrap() else {
        panic!("Expected image content");
    };
    assert_eq!(
        source,
        ImageSource::Base64 {
            data: "/9j/4AAQSkZJRg...".to_owned(),
            media_type: ImageFormat::JPEG,
        }
    );

    let Content::Text { text, .. } = user_blocks.next().unwrap() else {
        panic!("Expected text content");
    };
    assert_eq!(text, "What is in this image?");

    let Content::ToolResult {
        tool_use_id,
        content: result_content,
        is_error,
        ..
    } = user_blocks.next().unwrap()
    else {
        panic!("Expected tool result content");
    };
    assert_eq!(tool_use_id, TOOL_USE_ID);
    assert_eq!(
        result_content.first(),
        ToolResultContent::Text {
            text: "15 degrees".to_owned()
        }
    );
    assert_eq!(is_error, None);

    assert_eq!(user_blocks.next(), None);
}
1430
/// Converts provider-level `Message`s (user content, assistant tool use, tool
/// result) into the crate-level `message::Message`, inspects the converted
/// values, then converts back and checks the round trip is lossless.
#[test]
fn test_message_to_message_conversion() {
    const TOOL_USE_ID: &str = "toolu_01A09q90qw90lq917835lq9";

    let user_message: Message = serde_json::from_str(
        r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
    )
    .unwrap();

    let tool_use_block = Content::ToolUse {
        id: TOOL_USE_ID.to_string(),
        name: "get_weather".to_string(),
        input: json!({"location": "San Francisco, CA"}),
    };
    let assistant_message = Message {
        role: Role::Assistant,
        content: OneOrMany::one(tool_use_block),
    };

    let tool_result_block = Content::ToolResult {
        tool_use_id: TOOL_USE_ID.to_string(),
        content: OneOrMany::one(ToolResultContent::Text {
            text: "15 degrees".to_string(),
        }),
        is_error: None,
        cache_control: None,
    };
    let tool_message = Message {
        role: Role::User,
        content: OneOrMany::one(tool_result_block),
    };

    // Provider -> crate-level conversions; the originals are kept for the round trip.
    let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
    let converted_assistant_message: message::Message =
        assistant_message.clone().try_into().unwrap();
    let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

    // User message: image, text, document.
    let message::Message::User { content } = converted_user_message.clone() else {
        panic!("Expected user message");
    };
    assert_eq!(content.len(), 3);
    let mut user_parts = content.into_iter();

    let message::UserContent::Image(message::Image {
        data, media_type, ..
    }) = user_parts.next().unwrap()
    else {
        panic!("Expected image content");
    };
    assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
    assert_eq!(media_type, Some(message::ImageMediaType::JPEG));

    let message::UserContent::Text(message::Text { text }) = user_parts.next().unwrap() else {
        panic!("Expected text content");
    };
    assert_eq!(text, "What is in this image?");

    let message::UserContent::Document(message::Document {
        data, media_type, ..
    }) = user_parts.next().unwrap()
    else {
        panic!("Expected document content");
    };
    assert_eq!(
        data,
        DocumentSourceKind::String("base64_encoded_pdf_data".into())
    );
    assert_eq!(media_type, Some(message::DocumentMediaType::PDF));

    assert_eq!(user_parts.next(), None);

    // Tool result message: carried as user content.
    let message::Message::User { content } = converted_tool_message.clone() else {
        panic!("Expected tool result content");
    };
    let message::UserContent::ToolResult(message::ToolResult {
        id,
        content: result_parts,
        ..
    }) = content.first()
    else {
        panic!("Expected tool result content");
    };
    assert_eq!(id, TOOL_USE_ID);
    let message::ToolResultContent::Text(message::Text { text }) = result_parts.first() else {
        panic!("Expected text content");
    };
    assert_eq!(text, "15 degrees");

    // Assistant message: a single tool call.
    let message::Message::Assistant { content, .. } = converted_assistant_message.clone() else {
        panic!("Expected assistant message");
    };
    assert_eq!(content.len(), 1);
    let message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) =
        content.first()
    else {
        panic!("Expected tool call content");
    };
    assert_eq!(id, TOOL_USE_ID);
    assert_eq!(function.name, "get_weather");
    assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));

    // Crate-level -> provider conversions must reproduce the starting values.
    let original_user_message: Message = converted_user_message.try_into().unwrap();
    let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
    let original_tool_message: Message = converted_tool_message.try_into().unwrap();

    assert_eq!(user_message, original_user_message);
    assert_eq!(assistant_message, original_assistant_message);
    assert_eq!(tool_message, original_tool_message);
}
1574
/// Verifies the `ContentFormat` <-> `SourceType` conversions in both directions
/// for the URL, base64 and string/text pairs.
#[test]
fn test_content_format_conversion() {
    use crate::completion::message::ContentFormat;

    // URL
    assert_eq!(
        TryInto::<SourceType>::try_into(ContentFormat::Url).unwrap(),
        SourceType::URL
    );
    assert_eq!(
        Into::<ContentFormat>::into(SourceType::URL),
        ContentFormat::Url
    );

    // Base64
    assert_eq!(
        TryInto::<SourceType>::try_into(ContentFormat::Base64).unwrap(),
        SourceType::BASE64
    );
    assert_eq!(
        Into::<ContentFormat>::into(SourceType::BASE64),
        ContentFormat::Base64
    );

    // Plain string <-> text
    assert_eq!(
        TryInto::<SourceType>::try_into(ContentFormat::String).unwrap(),
        SourceType::TEXT
    );
    assert_eq!(
        Into::<ContentFormat>::into(SourceType::TEXT),
        ContentFormat::String
    );
}
1597
/// Covers `cache_control` handling: how it serializes on system and message
/// content (present when set, omitted when `None`), and which blocks
/// `apply_cache_control` marks in a system prompt plus a two-message exchange.
#[test]
fn test_cache_control_serialization() {
    const EPHEMERAL_FRAGMENT: &str = r#""cache_control":{"type":"ephemeral"}"#;

    // System content with a cache marker serializes both the marker and the block type.
    let cached_system = SystemContent::Text {
        text: "You are a helpful assistant.".to_string(),
        cache_control: Some(CacheControl::Ephemeral),
    };
    let cached_system_json = serde_json::to_string(&cached_system).unwrap();
    assert!(cached_system_json.contains(EPHEMERAL_FRAGMENT));
    assert!(cached_system_json.contains(r#""type":"text""#));

    // Without a marker the key must be absent from the output altogether.
    let uncached_system = SystemContent::Text {
        text: "Hello".to_string(),
        cache_control: None,
    };
    let uncached_system_json = serde_json::to_string(&uncached_system).unwrap();
    assert!(!uncached_system_json.contains("cache_control"));

    // Message content serializes the marker the same way.
    let cached_content = Content::Text {
        text: "Test message".to_string(),
        cache_control: Some(CacheControl::Ephemeral),
    };
    let cached_content_json = serde_json::to_string(&cached_content).unwrap();
    assert!(cached_content_json.contains(EPHEMERAL_FRAGMENT));

    // Start from an unmarked system prompt and an unmarked user/assistant exchange.
    let unmarked_text = |text: &str| Content::Text {
        text: text.to_string(),
        cache_control: None,
    };
    let mut system_vec = vec![SystemContent::Text {
        text: "System prompt".to_string(),
        cache_control: None,
    }];
    let mut messages = vec![
        Message {
            role: Role::User,
            content: OneOrMany::one(unmarked_text("First message")),
        },
        Message {
            role: Role::Assistant,
            content: OneOrMany::one(unmarked_text("Response")),
        },
    ];

    apply_cache_control(&mut system_vec, &mut messages);

    // The system block gets a marker.
    let SystemContent::Text { cache_control, .. } = &system_vec[0];
    assert!(cache_control.is_some());

    // The first message stays unmarked.
    for block in messages[0].content.iter() {
        let Content::Text { cache_control, .. } = block else {
            continue;
        };
        assert!(cache_control.is_none());
    }

    // The second (final) message gets a marker.
    for block in messages[1].content.iter() {
        let Content::Text { cache_control, .. } = block else {
            continue;
        };
        assert!(cache_control.is_some());
    }
}
1671
/// A plain-text document block must serialize as a `document` whose source is
/// tagged `text` with the `text/plain` media type and the raw data.
#[test]
fn test_plaintext_document_serialization() {
    let plain_source = DocumentSource::Text {
        data: "Hello, world!".to_string(),
        media_type: PlainTextMediaType::Plain,
    };
    let document = Content::Document {
        source: plain_source,
        cache_control: None,
    };

    let serialized = serde_json::to_value(&document).unwrap();
    assert_eq!(serialized["type"], "document");

    let source = &serialized["source"];
    assert_eq!(source["type"], "text");
    assert_eq!(source["media_type"], "text/plain");
    assert_eq!(source["data"], "Hello, world!");
}
1688
/// A `document` block with a `text` source must deserialize into
/// `Content::Document` holding a plain-text source and no cache marker.
#[test]
fn test_plaintext_document_deserialization() {
    let raw_block = r#"
    {
        "type": "document",
        "source": {
            "type": "text",
            "media_type": "text/plain",
            "data": "Hello, world!"
        }
    }
    "#;

    let parsed: Content = serde_json::from_str(raw_block).unwrap();
    let Content::Document {
        source,
        cache_control,
    } = parsed
    else {
        panic!("Expected Document content");
    };

    assert_eq!(
        source,
        DocumentSource::Text {
            data: "Hello, world!".to_string(),
            media_type: PlainTextMediaType::Plain,
        }
    );
    assert_eq!(cache_control, None);
}
1720
/// A base64 PDF document block must serialize as a `document` whose source is
/// tagged `base64` with the `application/pdf` media type and the encoded data.
#[test]
fn test_base64_pdf_document_serialization() {
    let pdf_source = DocumentSource::Base64 {
        data: "base64data".to_string(),
        media_type: DocumentFormat::PDF,
    };
    let document = Content::Document {
        source: pdf_source,
        cache_control: None,
    };

    let serialized = serde_json::to_value(&document).unwrap();
    assert_eq!(serialized["type"], "document");

    let source = &serialized["source"];
    assert_eq!(source["type"], "base64");
    assert_eq!(source["media_type"], "application/pdf");
    assert_eq!(source["data"], "base64data");
}
1737
/// A `document` block with a `base64` / `application/pdf` source must
/// deserialize into `Content::Document` holding a base64 PDF source.
#[test]
fn test_base64_pdf_document_deserialization() {
    let raw_block = r#"
    {
        "type": "document",
        "source": {
            "type": "base64",
            "media_type": "application/pdf",
            "data": "base64data"
        }
    }
    "#;

    let parsed: Content = serde_json::from_str(raw_block).unwrap();
    let Content::Document { source, .. } = parsed else {
        panic!("Expected Document content");
    };

    let expected_source = DocumentSource::Base64 {
        data: "base64data".to_string(),
        media_type: DocumentFormat::PDF,
    };
    assert_eq!(source, expected_source);
}
1765
/// A crate-level user message holding a TXT document must convert into a
/// provider `Message` with role `User` whose first block is a plain-text
/// document source carrying the same text.
#[test]
fn test_plaintext_rig_to_anthropic_conversion() {
    use crate::completion::message as msg;

    let text_document = msg::UserContent::document(
        "Some plain text content".to_string(),
        Some(msg::DocumentMediaType::TXT),
    );
    let rig_message = msg::Message::User {
        content: OneOrMany::one(text_document),
    };

    let anthropic_message: Message = rig_message.try_into().unwrap();
    assert_eq!(anthropic_message.role, Role::User);

    let first_block = anthropic_message.content.into_iter().next().unwrap();
    match first_block {
        Content::Document { source, .. } => assert_eq!(
            source,
            DocumentSource::Text {
                data: "Some plain text content".to_string(),
                media_type: PlainTextMediaType::Plain,
            }
        ),
        other => panic!("Expected Document content, got: {other:?}"),
    }
}
1794
/// A provider user message holding a plain-text document must convert into a
/// crate-level user message whose first part is a TXT document with string data.
#[test]
fn test_plaintext_anthropic_to_rig_conversion() {
    use crate::completion::message as msg;

    let plain_document = Content::Document {
        source: DocumentSource::Text {
            data: "Some plain text content".to_string(),
            media_type: PlainTextMediaType::Plain,
        },
        cache_control: None,
    };
    let anthropic_message = Message {
        role: Role::User,
        content: OneOrMany::one(plain_document),
    };

    let rig_message: msg::Message = anthropic_message.try_into().unwrap();
    let msg::Message::User { content } = rig_message else {
        panic!("Expected User message");
    };

    let first_part = content.into_iter().next().unwrap();
    match first_part {
        msg::UserContent::Document(msg::Document {
            data, media_type, ..
        }) => {
            assert_eq!(
                data,
                DocumentSourceKind::String("Some plain text content".into())
            );
            assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
        }
        other => panic!("Expected Document content, got: {other:?}"),
    }
}
1830
/// Sends a TXT document through crate-level -> provider -> crate-level
/// conversion and checks that the document media type survives the round trip.
#[test]
fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
    use crate::completion::message as msg;

    let text_document = msg::UserContent::document(
        "Round trip text".to_string(),
        Some(msg::DocumentMediaType::TXT),
    );
    let original = msg::Message::User {
        content: OneOrMany::one(text_document),
    };

    let anthropic: Message = original.clone().try_into().unwrap();
    let back: msg::Message = anthropic.try_into().unwrap();

    // Both ends of the round trip must be user messages...
    let (
        msg::Message::User {
            content: orig_content,
        },
        msg::Message::User {
            content: back_content,
        },
    ) = (&original, &back)
    else {
        panic!("Expected User messages");
    };

    // ...whose first part is a document.
    let (
        msg::UserContent::Document(msg::Document {
            media_type: orig_mt,
            ..
        }),
        msg::UserContent::Document(msg::Document {
            media_type: back_mt,
            ..
        }),
    ) = (orig_content.first(), back_content.first())
    else {
        panic!("Expected Document content in both");
    };

    assert_eq!(orig_mt, back_mt);
}
1871
/// Converting a user message that carries an HTML document must fail with an
/// error stating that only PDF and plain-text documents are supported.
#[test]
fn test_unsupported_document_type_returns_error() {
    use crate::completion::message as msg;

    let html_document = msg::Document {
        data: DocumentSourceKind::String("data".into()),
        media_type: Some(msg::DocumentMediaType::HTML),
        additional_params: None,
    };
    let rig_message = msg::Message::User {
        content: OneOrMany::one(msg::UserContent::Document(html_document)),
    };

    let conversion = TryInto::<Message>::try_into(rig_message);
    assert!(conversion.is_err());

    let err = conversion.unwrap_err().to_string();
    assert!(
        err.contains("Anthropic only supports PDF and plain text documents"),
        "Unexpected error: {err}"
    );
}
1892
/// Converting a user message whose TXT document is given as a URL must fail
/// with an error stating that only string or base64 data is supported.
#[test]
fn test_plaintext_document_url_source_returns_error() {
    use crate::completion::message as msg;

    let url_document = msg::Document {
        data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
        media_type: Some(msg::DocumentMediaType::TXT),
        additional_params: None,
    };
    let rig_message = msg::Message::User {
        content: OneOrMany::one(msg::UserContent::Document(url_document)),
    };

    let conversion = TryInto::<Message>::try_into(rig_message);
    assert!(conversion.is_err());

    let err = conversion.unwrap_err().to_string();
    assert!(
        err.contains("Only string or base64 data is supported for plain text documents"),
        "Unexpected error: {err}"
    );
}
1913
/// A plain-text document block with an ephemeral cache marker must serialize
/// both the text source fields and the `cache_control` type.
#[test]
fn test_plaintext_document_with_cache_control() {
    let cached_document = Content::Document {
        source: DocumentSource::Text {
            data: "cached text".to_string(),
            media_type: PlainTextMediaType::Plain,
        },
        cache_control: Some(CacheControl::Ephemeral),
    };

    let serialized = serde_json::to_value(&cached_document).unwrap();

    let source = &serialized["source"];
    assert_eq!(source["type"], "text");
    assert_eq!(source["media_type"], "text/plain");
    assert_eq!(serialized["cache_control"]["type"], "ephemeral");
}
1929
/// A user message made of a plain-text document block followed by a text block
/// must deserialize into two content blocks in that order.
#[test]
fn test_message_with_plaintext_document_deserialization() {
    let raw_message = r#"
    {
        "role": "user",
        "content": [
            {
                "type": "document",
                "source": {
                    "type": "text",
                    "media_type": "text/plain",
                    "data": "Hello from a text file"
                }
            },
            {
                "type": "text",
                "text": "Summarize this document."
            }
        ]
    }
    "#;

    let parsed: Message = serde_json::from_str(raw_message).unwrap();
    assert_eq!(parsed.role, Role::User);
    assert_eq!(parsed.content.len(), 2);

    let mut blocks = parsed.content.into_iter();

    let Content::Document { source, .. } = blocks.next().unwrap() else {
        panic!("Expected Document content");
    };
    assert_eq!(
        source,
        DocumentSource::Text {
            data: "Hello from a text file".to_string(),
            media_type: PlainTextMediaType::Plain,
        }
    );

    let Content::Text { text, .. } = blocks.next().unwrap() else {
        panic!("Expected Text content");
    };
    assert_eq!(text, "Summarize this document.");
}
1978
/// An assistant reasoning item with four parts (signed text, summary, signed
/// text, redacted) must convert into four provider blocks in the same order:
/// thinking with signature, thinking without signature, thinking with
/// signature, and redacted thinking.
#[test]
fn test_assistant_reasoning_multiblock_to_anthropic_content() {
    let signed_step = |text: &str, signature: &str| message::ReasoningContent::Text {
        text: text.to_string(),
        signature: Some(signature.to_string()),
    };

    let reasoning = message::Reasoning {
        id: None,
        content: vec![
            signed_step("step one", "sig-1"),
            message::ReasoningContent::Summary("summary".to_string()),
            signed_step("step two", "sig-2"),
            message::ReasoningContent::Redacted {
                data: "redacted block".to_string(),
            },
        ],
    };
    let assistant = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
    };

    let converted: Message = assistant.try_into().expect("convert assistant message");
    let converted_content: Vec<_> = converted.content.iter().cloned().collect();

    assert_eq!(converted.role, Role::Assistant);
    assert_eq!(converted_content.len(), 4);

    // Walk the blocks in order; each `next()` corresponds to one reasoning part.
    let mut blocks = converted_content.iter();
    assert!(matches!(
        blocks.next(),
        Some(Content::Thinking { thinking, signature: Some(signature) })
            if thinking == "step one" && signature == "sig-1"
    ));
    assert!(matches!(
        blocks.next(),
        Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
    ));
    assert!(matches!(
        blocks.next(),
        Some(Content::Thinking { thinking, signature: Some(signature) })
            if thinking == "step two" && signature == "sig-2"
    ));
    assert!(matches!(
        blocks.next(),
        Some(Content::RedactedThinking { data }) if data == "redacted block"
    ));
}
2027
/// A provider `RedactedThinking` block must convert into assistant reasoning
/// whose first part is a redacted entry carrying the same opaque data.
#[test]
fn test_redacted_thinking_content_to_assistant_reasoning() {
    let redacted_block = Content::RedactedThinking {
        data: "opaque-redacted".to_string(),
    };

    let converted = TryInto::<message::AssistantContent>::try_into(redacted_block)
        .expect("convert redacted thinking");

    assert!(matches!(
        converted,
        message::AssistantContent::Reasoning(message::Reasoning { content: parts, .. })
            if matches!(
                parts.first(),
                Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
            )
    ));
}
2045
/// Assistant reasoning consisting of a single encrypted part must convert into
/// exactly one provider `RedactedThinking` block carrying the same payload.
#[test]
fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
    let encrypted_part = message::ReasoningContent::Encrypted("ciphertext".to_string());
    let reasoning = message::Reasoning {
        id: None,
        content: vec![encrypted_part],
    };
    let assistant = message::Message::Assistant {
        id: None,
        content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
    };

    let converted: Message = assistant.try_into().expect("convert assistant message");
    let converted_content: Vec<_> = converted.content.iter().cloned().collect();

    assert_eq!(converted_content.len(), 1);
    assert!(matches!(
        converted_content.first(),
        Some(Content::RedactedThinking { data }) if data == "ciphertext"
    ));
}
2068}