1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
/// Model identifier string `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model identifier string `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model identifier string `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model identifier string `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model identifier string `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";
35
/// API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest version string defined in this module.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Provider-specific completion response body, (de)serialized with serde.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks that make up the reply.
    pub content: Vec<Content>,
    /// Identifier of this response.
    pub id: String,
    /// Name of the model reported in the response.
    pub model: String,
    /// Role string reported in the response.
    pub role: String,
    /// Reason generation stopped, when present.
    pub stop_reason: Option<String>,
    /// Stop sequence reported in the response, when present.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request/response pair.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52 type OutputMessage = Content;
53 type Usage = Usage;
54
55 fn get_response_id(&self) -> Option<String> {
56 Some(self.id.to_owned())
57 }
58
59 fn get_response_model_name(&self) -> Option<String> {
60 Some(self.model.to_owned())
61 }
62
63 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64 self.content.clone()
65 }
66
67 fn get_text_response(&self) -> Option<String> {
68 let res = self
69 .content
70 .iter()
71 .filter_map(|x| {
72 if let Content::Text { text, .. } = x {
73 Some(text.to_owned())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<String>>()
79 .join("\n");
80
81 if res.is_empty() { None } else { Some(res) }
82 }
83
84 fn get_usage(&self) -> Option<Self::Usage> {
85 Some(self.usage.clone())
86 }
87}
88
/// Token counts attached to a completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input token count.
    pub input_tokens: u64,
    /// Input tokens read from cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens used to create cache entries, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output token count.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(
100 f,
101 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102 self.input_tokens,
103 match self.cache_read_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 match self.cache_creation_input_tokens {
108 Some(token) => token.to_string(),
109 None => "n/a".to_string(),
110 },
111 self.output_tokens
112 )
113 }
114}
115
116impl GetTokenUsage for Usage {
117 fn token_usage(&self) -> Option<crate::completion::Usage> {
118 let mut usage = crate::completion::Usage::new();
119
120 usage.input_tokens = self.input_tokens
121 + self.cache_creation_input_tokens.unwrap_or_default()
122 + self.cache_read_input_tokens.unwrap_or_default();
123 usage.output_tokens = self.output_tokens;
124 usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126 Some(usage)
127 }
128}
129
/// Serializable description of a tool offered to the model.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// JSON value describing the tool's input.
    pub input_schema: serde_json::Value,
}
136
/// Cache-control marker attached to content blocks.
///
/// Serialized as an object tagged with `type` in snake_case, i.e.
/// `{"type": "ephemeral"}` for the only variant.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// The `ephemeral` cache-control kind.
    Ephemeral,
}
143
/// One block of the request's `system` array, tagged with `type`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    /// A text block.
    Text {
        /// The system text.
        text: String,
        /// Optional cache marker; left out of the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
154
155impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
156 type Error = CompletionError;
157
158 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
159 let content = response
160 .content
161 .iter()
162 .map(|content| content.clone().try_into())
163 .collect::<Result<Vec<_>, _>>()?;
164
165 let choice = OneOrMany::many(content).map_err(|_| {
166 CompletionError::ResponseError(
167 "Response contained no message or tool call (empty)".to_owned(),
168 )
169 })?;
170
171 let usage = completion::Usage {
172 input_tokens: response.usage.input_tokens,
173 output_tokens: response.usage.output_tokens,
174 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
175 cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
176 };
177
178 Ok(completion::CompletionResponse {
179 choice,
180 usage,
181 raw_response: response,
182 message_id: None,
183 })
184 }
185}
186
/// A single conversation message in the provider's wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// Non-empty list of content blocks; deserialized through
    /// `string_or_one_or_many`.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
193
/// Author of a [`Message`]; serialized in lowercase (`user`, `assistant`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message authored by the user.
    User,
    /// Message authored by the assistant.
    Assistant,
}
200
/// A content block inside a [`Message`] or a response.
///
/// Serialized as an object tagged with `type`, with variant names in
/// snake_case (`text`, `image`, `tool_use`, `tool_result`, `document`,
/// `thinking`, `redacted_thinking`). Every `cache_control` field is left out
/// of the serialized form when it is `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image, given by its source.
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation with its id, tool name and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, linked by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        /// Result payload; deserialized through `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Optional error flag; omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document, given by its source.
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Reasoning text with an optional signature (omitted when `None`).
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Reasoning carried as an opaque `data` string.
    RedactedThinking {
        data: String,
    },
}
242
243impl FromStr for Content {
244 type Err = Infallible;
245
246 fn from_str(s: &str) -> Result<Self, Self::Err> {
247 Ok(Content::Text {
248 text: s.to_owned(),
249 cache_control: None,
250 })
251 }
252}
253
/// Payload of a `Content::ToolResult` block, tagged with `type`.
///
/// NOTE(review): `Image` is a newtype variant wrapping [`ImageSource`], which
/// is itself tagged with the same `type` key. Confirm that the serialized
/// shape of an image tool result is what the API expects.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Textual tool output.
    Text { text: String },
    /// Image tool output.
    Image(ImageSource),
}
260
261impl FromStr for ToolResultContent {
262 type Err = Infallible;
263
264 fn from_str(s: &str) -> Result<Self, Self::Err> {
265 Ok(ToolResultContent::Text { text: s.to_owned() })
266 }
267}
268
/// Where an image's bytes come from, tagged with `type` (`base64` or `url`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline image data together with its media type.
    #[serde(rename = "base64")]
    Base64 {
        data: String,
        media_type: ImageFormat,
    },
    /// Image referenced by URL.
    #[serde(rename = "url")]
    Url { url: String },
}
287
/// Where a document's content comes from, tagged with `type`
/// (`base64`, `text` or `url`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
    /// Inline data with a [`DocumentFormat`] media type.
    Base64 {
        data: String,
        media_type: DocumentFormat,
    },
    /// Inline text with a [`PlainTextMediaType`] media type.
    Text {
        data: String,
        media_type: PlainTextMediaType,
    },
    /// Document referenced by URL.
    Url {
        url: String,
    },
}
313
/// Image media types; each variant is serialized as its MIME string.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    /// `image/jpeg`
    #[serde(rename = "image/jpeg")]
    JPEG,
    /// `image/png`
    #[serde(rename = "image/png")]
    PNG,
    /// `image/gif`
    #[serde(rename = "image/gif")]
    GIF,
    /// `image/webp`
    #[serde(rename = "image/webp")]
    WEBP,
}
326
/// Media type of a `DocumentSource::Base64` document.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    /// `application/pdf`
    #[serde(rename = "application/pdf")]
    PDF,
}
339
/// Media type of a `DocumentSource::Text` document.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum PlainTextMediaType {
    /// `text/plain`
    #[serde(rename = "text/plain")]
    Plain,
}
350
/// Kind of content source; serialized in lowercase (`base64`, `url`, `text`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    /// Corresponds to `message::ContentFormat::Base64`.
    BASE64,
    /// Corresponds to `message::ContentFormat::Url`.
    URL,
    /// Corresponds to `message::ContentFormat::String`.
    TEXT,
}
358
359impl From<String> for Content {
360 fn from(text: String) -> Self {
361 Content::Text {
362 text,
363 cache_control: None,
364 }
365 }
366}
367
368impl From<String> for ToolResultContent {
369 fn from(text: String) -> Self {
370 ToolResultContent::Text { text }
371 }
372}
373
374impl TryFrom<message::ContentFormat> for SourceType {
375 type Error = MessageError;
376
377 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
378 match format {
379 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
380 message::ContentFormat::Url => Ok(SourceType::URL),
381 message::ContentFormat::String => Ok(SourceType::TEXT),
382 }
383 }
384}
385
386impl From<SourceType> for message::ContentFormat {
387 fn from(source_type: SourceType) -> Self {
388 match source_type {
389 SourceType::BASE64 => message::ContentFormat::Base64,
390 SourceType::URL => message::ContentFormat::Url,
391 SourceType::TEXT => message::ContentFormat::String,
392 }
393 }
394}
395
396impl TryFrom<message::ImageMediaType> for ImageFormat {
397 type Error = MessageError;
398
399 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
400 Ok(match media_type {
401 message::ImageMediaType::JPEG => ImageFormat::JPEG,
402 message::ImageMediaType::PNG => ImageFormat::PNG,
403 message::ImageMediaType::GIF => ImageFormat::GIF,
404 message::ImageMediaType::WEBP => ImageFormat::WEBP,
405 _ => {
406 return Err(MessageError::ConversionError(
407 format!("Unsupported image media type: {media_type:?}").to_owned(),
408 ));
409 }
410 })
411 }
412}
413
414impl From<ImageFormat> for message::ImageMediaType {
415 fn from(format: ImageFormat) -> Self {
416 match format {
417 ImageFormat::JPEG => message::ImageMediaType::JPEG,
418 ImageFormat::PNG => message::ImageMediaType::PNG,
419 ImageFormat::GIF => message::ImageMediaType::GIF,
420 ImageFormat::WEBP => message::ImageMediaType::WEBP,
421 }
422 }
423}
424
425impl TryFrom<DocumentMediaType> for DocumentFormat {
426 type Error = MessageError;
427 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
428 match value {
429 DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
430 other => Err(MessageError::ConversionError(format!(
431 "DocumentFormat only supports PDF for base64 sources, got: {}",
432 other.to_mime_type()
433 ))),
434 }
435 }
436}
437
438impl TryFrom<message::AssistantContent> for Content {
439 type Error = MessageError;
440 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
441 match text {
442 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
443 text,
444 cache_control: None,
445 }),
446 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
447 "Anthropic currently doesn't support images.".to_string(),
448 )),
449 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
450 Ok(Content::ToolUse {
451 id,
452 name: function.name,
453 input: function.arguments,
454 })
455 }
456 message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
457 thinking: reasoning.display_text(),
458 signature: reasoning.first_signature().map(str::to_owned),
459 }),
460 }
461 }
462}
463
464fn anthropic_content_from_assistant_content(
465 content: message::AssistantContent,
466) -> Result<Vec<Content>, MessageError> {
467 match content {
468 message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
469 text,
470 cache_control: None,
471 }]),
472 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
473 "Anthropic currently doesn't support images.".to_string(),
474 )),
475 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
476 Ok(vec![Content::ToolUse {
477 id,
478 name: function.name,
479 input: function.arguments,
480 }])
481 }
482 message::AssistantContent::Reasoning(reasoning) => {
483 let mut converted = Vec::new();
484 for block in reasoning.content {
485 match block {
486 message::ReasoningContent::Text { text, signature } => {
487 converted.push(Content::Thinking {
488 thinking: text,
489 signature,
490 });
491 }
492 message::ReasoningContent::Summary(summary) => {
493 converted.push(Content::Thinking {
494 thinking: summary,
495 signature: None,
496 });
497 }
498 message::ReasoningContent::Redacted { data }
499 | message::ReasoningContent::Encrypted(data) => {
500 converted.push(Content::RedactedThinking { data });
501 }
502 }
503 }
504
505 if converted.is_empty() {
506 return Err(MessageError::ConversionError(
507 "Cannot convert empty reasoning content to Anthropic format".to_string(),
508 ));
509 }
510
511 Ok(converted)
512 }
513 }
514}
515
516impl TryFrom<message::Message> for Message {
517 type Error = MessageError;
518
519 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
520 Ok(match message {
521 message::Message::User { content } => Message {
522 role: Role::User,
523 content: content.try_map(|content| match content {
524 message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
525 text,
526 cache_control: None,
527 }),
528 message::UserContent::ToolResult(message::ToolResult {
529 id, content, ..
530 }) => Ok(Content::ToolResult {
531 tool_use_id: id,
532 content: content.try_map(|content| match content {
533 message::ToolResultContent::Text(message::Text { text }) => {
534 Ok(ToolResultContent::Text { text })
535 }
536 message::ToolResultContent::Image(image) => {
537 let DocumentSourceKind::Base64(data) = image.data else {
538 return Err(MessageError::ConversionError(
539 "Only base64 strings can be used with the Anthropic API"
540 .to_string(),
541 ));
542 };
543 let media_type =
544 image.media_type.ok_or(MessageError::ConversionError(
545 "Image media type is required".to_owned(),
546 ))?;
547 Ok(ToolResultContent::Image(ImageSource::Base64 {
548 data,
549 media_type: media_type.try_into()?,
550 }))
551 }
552 })?,
553 is_error: None,
554 cache_control: None,
555 }),
556 message::UserContent::Image(message::Image {
557 data, media_type, ..
558 }) => {
559 let source = match data {
560 DocumentSourceKind::Base64(data) => {
561 let media_type =
562 media_type.ok_or(MessageError::ConversionError(
563 "Image media type is required for Claude API".to_string(),
564 ))?;
565 ImageSource::Base64 {
566 data,
567 media_type: ImageFormat::try_from(media_type)?,
568 }
569 }
570 DocumentSourceKind::Url(url) => ImageSource::Url { url },
571 DocumentSourceKind::Unknown => {
572 return Err(MessageError::ConversionError(
573 "Image content has no body".into(),
574 ));
575 }
576 doc => {
577 return Err(MessageError::ConversionError(format!(
578 "Unsupported document type: {doc:?}"
579 )));
580 }
581 };
582
583 Ok(Content::Image {
584 source,
585 cache_control: None,
586 })
587 }
588 message::UserContent::Document(message::Document {
589 data, media_type, ..
590 }) => {
591 let media_type = media_type.ok_or(MessageError::ConversionError(
592 "Document media type is required".to_string(),
593 ))?;
594
595 let source = match media_type {
596 DocumentMediaType::PDF => {
597 let data = match data {
598 DocumentSourceKind::Base64(data)
599 | DocumentSourceKind::String(data) => data,
600 _ => {
601 return Err(MessageError::ConversionError(
602 "Only base64 encoded data is supported for PDF documents".into(),
603 ));
604 }
605 };
606 DocumentSource::Base64 {
607 data,
608 media_type: DocumentFormat::PDF,
609 }
610 }
611 DocumentMediaType::TXT => {
612 let data = match data {
613 DocumentSourceKind::String(data)
614 | DocumentSourceKind::Base64(data) => data,
615 _ => {
616 return Err(MessageError::ConversionError(
617 "Only string or base64 data is supported for plain text documents".into(),
618 ));
619 }
620 };
621 DocumentSource::Text {
622 data,
623 media_type: PlainTextMediaType::Plain,
624 }
625 }
626 other => {
627 return Err(MessageError::ConversionError(format!(
628 "Anthropic only supports PDF and plain text documents, got: {}",
629 other.to_mime_type()
630 )));
631 }
632 };
633
634 Ok(Content::Document {
635 source,
636 cache_control: None,
637 })
638 }
639 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
640 "Audio is not supported in Anthropic".to_owned(),
641 )),
642 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
643 "Video is not supported in Anthropic".to_owned(),
644 )),
645 })?,
646 },
647
648 message::Message::System { content } => Message {
649 role: Role::User,
650 content: OneOrMany::one(Content::Text {
651 text: content,
652 cache_control: None,
653 }),
654 },
655
656 message::Message::Assistant { content, .. } => {
657 let converted_content = content.into_iter().try_fold(
658 Vec::new(),
659 |mut accumulated, assistant_content| {
660 accumulated
661 .extend(anthropic_content_from_assistant_content(assistant_content)?);
662 Ok::<Vec<Content>, MessageError>(accumulated)
663 },
664 )?;
665
666 Message {
667 content: OneOrMany::many(converted_content).map_err(|_| {
668 MessageError::ConversionError(
669 "Assistant message did not contain Anthropic-compatible content"
670 .to_owned(),
671 )
672 })?,
673 role: Role::Assistant,
674 }
675 }
676 })
677 }
678}
679
680impl TryFrom<Content> for message::AssistantContent {
681 type Error = MessageError;
682
683 fn try_from(content: Content) -> Result<Self, Self::Error> {
684 Ok(match content {
685 Content::Text { text, .. } => message::AssistantContent::text(text),
686 Content::ToolUse { id, name, input } => {
687 message::AssistantContent::tool_call(id, name, input)
688 }
689 Content::Thinking {
690 thinking,
691 signature,
692 } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
693 &thinking, signature,
694 )),
695 Content::RedactedThinking { data } => {
696 message::AssistantContent::Reasoning(Reasoning::redacted(data))
697 }
698 _ => {
699 return Err(MessageError::ConversionError(
700 "Content did not contain a message, tool call, or reasoning".to_owned(),
701 ));
702 }
703 })
704 }
705}
706
707impl From<ToolResultContent> for message::ToolResultContent {
708 fn from(content: ToolResultContent) -> Self {
709 match content {
710 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
711 ToolResultContent::Image(source) => match source {
712 ImageSource::Base64 { data, media_type } => {
713 message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
714 }
715 ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
716 },
717 }
718 }
719}
720
721impl TryFrom<Message> for message::Message {
722 type Error = MessageError;
723
724 fn try_from(message: Message) -> Result<Self, Self::Error> {
725 Ok(match message.role {
726 Role::User => message::Message::User {
727 content: message.content.try_map(|content| {
728 Ok(match content {
729 Content::Text { text, .. } => message::UserContent::text(text),
730 Content::ToolResult {
731 tool_use_id,
732 content,
733 ..
734 } => message::UserContent::tool_result(
735 tool_use_id,
736 content.map(|content| content.into()),
737 ),
738 Content::Image { source, .. } => match source {
739 ImageSource::Base64 { data, media_type } => {
740 message::UserContent::Image(message::Image {
741 data: DocumentSourceKind::Base64(data),
742 media_type: Some(media_type.into()),
743 detail: None,
744 additional_params: None,
745 })
746 }
747 ImageSource::Url { url } => {
748 message::UserContent::Image(message::Image {
749 data: DocumentSourceKind::Url(url),
750 media_type: None,
751 detail: None,
752 additional_params: None,
753 })
754 }
755 },
756 Content::Document { source, .. } => match source {
757 DocumentSource::Base64 { data, media_type } => {
758 let rig_media_type = match media_type {
759 DocumentFormat::PDF => message::DocumentMediaType::PDF,
760 };
761 message::UserContent::document(data, Some(rig_media_type))
762 }
763 DocumentSource::Text { data, .. } => message::UserContent::document(
764 data,
765 Some(message::DocumentMediaType::TXT),
766 ),
767 DocumentSource::Url { url } => {
768 message::UserContent::document_url(url, None)
769 }
770 },
771 _ => {
772 return Err(MessageError::ConversionError(
773 "Unsupported content type for User role".to_owned(),
774 ));
775 }
776 })
777 })?,
778 },
779 Role::Assistant => message::Message::Assistant {
780 id: None,
781 content: message.content.try_map(|content| content.try_into())?,
782 },
783 })
784 }
785}
786
/// Completion model handle bound to a client and a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    pub(crate) client: Client<T>,
    /// Model name used when a request does not name one itself.
    pub model: String,
    /// `max_tokens` applied when a request does not set one; `None` means
    /// such requests are rejected.
    pub default_max_tokens: Option<u64>,
    /// Whether cache-control markers are added to outgoing requests.
    pub prompt_caching: bool,
}
795
796impl<T> CompletionModel<T>
797where
798 T: HttpClientExt,
799{
800 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
801 let model = model.into();
802 let default_max_tokens = calculate_max_tokens(&model);
803
804 Self {
805 client,
806 model,
807 default_max_tokens,
808 prompt_caching: false, }
810 }
811
812 pub fn with_model(client: Client<T>, model: &str) -> Self {
813 Self {
814 client,
815 model: model.to_string(),
816 default_max_tokens: Some(calculate_max_tokens_custom(model)),
817 prompt_caching: false, }
819 }
820
821 pub fn with_prompt_caching(mut self) -> Self {
829 self.prompt_caching = true;
830 self
831 }
832}
833
/// Looks up the default `max_tokens` for a model name by prefix.
///
/// Returns `None` when the name matches none of the known prefixes.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Single source of truth for the prefix → limit mapping; the first
    // matching prefix wins.
    const PREFIX_LIMITS: [(&str, u64); 8] = [
        ("claude-opus-4", 32000),
        ("claude-sonnet-4", 64000),
        ("claude-3-7-sonnet", 64000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    PREFIX_LIMITS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}

/// Same lookup as [`calculate_max_tokens`], falling back to `2048` for model
/// names that match no known prefix.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    calculate_max_tokens(model).unwrap_or(2048)
}
870
/// Request metadata.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    /// Optional user identifier.
    user_id: Option<String>,
}
875
/// Tool-choice setting for a request, tagged with `type` in snake_case
/// (`auto`, `any`, `none`, `tool`). Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// The default variant.
    #[default]
    Auto,
    /// Converted from `message::ToolChoice::Required`.
    Any,
    /// Converted from `message::ToolChoice::None`.
    None,
    /// A single tool selected by name.
    Tool {
        name: String,
    },
}
887impl TryFrom<message::ToolChoice> for ToolChoice {
888 type Error = CompletionError;
889
890 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
891 let res = match value {
892 message::ToolChoice::Auto => Self::Auto,
893 message::ToolChoice::None => Self::None,
894 message::ToolChoice::Required => Self::Any,
895 message::ToolChoice::Specific { function_names } => {
896 if function_names.len() != 1 {
897 return Err(CompletionError::ProviderError(
898 "Only one tool may be specified to be used by Claude".into(),
899 ));
900 }
901
902 Self::Tool {
903 name: function_names.first().unwrap().to_string(),
904 }
905 }
906 };
907
908 Ok(res)
909 }
910}
911
912fn sanitize_schema(schema: &mut serde_json::Value) {
918 use serde_json::Value;
919
920 if let Value::Object(obj) = schema {
921 let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
922 || obj.contains_key("properties");
923
924 if is_object_schema && !obj.contains_key("additionalProperties") {
925 obj.insert("additionalProperties".to_string(), Value::Bool(false));
926 }
927
928 if let Some(Value::Object(properties)) = obj.get("properties") {
929 let prop_keys = properties.keys().cloned().map(Value::String).collect();
930 obj.insert("required".to_string(), Value::Array(prop_keys));
931 }
932
933 let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
935 || obj.get("type") == Some(&Value::String("number".to_string()));
936
937 if is_numeric_schema {
938 for key in [
939 "minimum",
940 "maximum",
941 "exclusiveMinimum",
942 "exclusiveMaximum",
943 "multipleOf",
944 ] {
945 obj.remove(key);
946 }
947 }
948
949 if let Some(defs) = obj.get_mut("$defs")
950 && let Value::Object(defs_obj) = defs
951 {
952 for (_, def_schema) in defs_obj.iter_mut() {
953 sanitize_schema(def_schema);
954 }
955 }
956
957 if let Some(properties) = obj.get_mut("properties")
958 && let Value::Object(props) = properties
959 {
960 for (_, prop_value) in props.iter_mut() {
961 sanitize_schema(prop_value);
962 }
963 }
964
965 if let Some(items) = obj.get_mut("items") {
966 sanitize_schema(items);
967 }
968
969 for key in ["anyOf", "oneOf", "allOf"] {
970 if let Some(variants) = obj.get_mut(key)
971 && let Value::Array(variants_array) = variants
972 {
973 for variant in variants_array.iter_mut() {
974 sanitize_schema(variant);
975 }
976 }
977 }
978 }
979}
980
/// Output format of a request, tagged with `type` in snake_case.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    /// Serialized with `type: "json_schema"` plus the schema value.
    JsonSchema { schema: serde_json::Value },
}
989
/// Value of the request's `output_config` field.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    /// The requested output format.
    format: OutputFormat,
}
995
/// Serializable body of a completion request.
///
/// Empty vectors and `None` options are left out of the serialized form, as
/// marked on each field.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    /// Model name.
    model: String,
    /// Conversation messages.
    messages: Vec<Message>,
    /// Maximum number of tokens to generate.
    max_tokens: u64,
    /// System blocks; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool-choice setting; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Tool definitions as raw JSON values; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<serde_json::Value>,
    /// Output configuration; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    /// Extra parameters flattened into the top-level object; omitted when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
1015
1016fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1018 match content {
1019 Content::Text { cache_control, .. } => *cache_control = value,
1020 Content::Image { cache_control, .. } => *cache_control = value,
1021 Content::ToolResult { cache_control, .. } => *cache_control = value,
1022 Content::Document { cache_control, .. } => *cache_control = value,
1023 _ => {}
1024 }
1025}
1026
1027pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1032 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1034 *cache_control = Some(CacheControl::Ephemeral);
1035 }
1036
1037 for msg in messages.iter_mut() {
1039 for content in msg.content.iter_mut() {
1040 set_content_cache_control(content, None);
1041 }
1042 }
1043
1044 if let Some(last_msg) = messages.last_mut() {
1046 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::Ephemeral));
1047 }
1048}
1049
1050pub(super) fn split_system_messages_from_history(
1051 history: Vec<message::Message>,
1052) -> (Vec<SystemContent>, Vec<message::Message>) {
1053 let mut system = Vec::new();
1054 let mut remaining = Vec::new();
1055
1056 for message in history {
1057 match message {
1058 message::Message::System { content } => {
1059 if !content.is_empty() {
1060 system.push(SystemContent::Text {
1061 text: content,
1062 cache_control: None,
1063 });
1064 }
1065 }
1066 other => remaining.push(other),
1067 }
1068 }
1069
1070 (system, remaining)
1071}
1072
/// Inputs needed to build an `AnthropicCompletionRequest`.
pub struct AnthropicRequestParams<'a> {
    /// Model name to put in the request.
    pub model: &'a str,
    /// The crate-level completion request to convert.
    pub request: CompletionRequest,
    /// Whether cache-control markers are applied to the built request.
    pub prompt_caching: bool,
}
1079
1080impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1081 type Error = CompletionError;
1082
1083 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1084 let AnthropicRequestParams {
1085 model,
1086 request: mut req,
1087 prompt_caching,
1088 } = params;
1089
1090 let Some(max_tokens) = req.max_tokens else {
1092 return Err(CompletionError::RequestError(
1093 "`max_tokens` must be set for Anthropic".into(),
1094 ));
1095 };
1096
1097 let mut full_history = vec![];
1098 if let Some(docs) = req.normalized_documents() {
1099 full_history.push(docs);
1100 }
1101 full_history.extend(req.chat_history);
1102 let (history_system, full_history) = split_system_messages_from_history(full_history);
1103
1104 let mut messages = full_history
1105 .into_iter()
1106 .map(Message::try_from)
1107 .collect::<Result<Vec<Message>, _>>()?;
1108
1109 let mut additional_params_payload = req
1110 .additional_params
1111 .take()
1112 .unwrap_or(serde_json::Value::Null);
1113 let mut additional_tools =
1114 extract_tools_from_additional_params(&mut additional_params_payload)?;
1115
1116 let mut tools = req
1117 .tools
1118 .into_iter()
1119 .map(|tool| ToolDefinition {
1120 name: tool.name,
1121 description: Some(tool.description),
1122 input_schema: tool.parameters,
1123 })
1124 .map(serde_json::to_value)
1125 .collect::<Result<Vec<_>, _>>()?;
1126 tools.append(&mut additional_tools);
1127
1128 let mut system = if let Some(preamble) = req.preamble {
1130 if preamble.is_empty() {
1131 vec![]
1132 } else {
1133 vec![SystemContent::Text {
1134 text: preamble,
1135 cache_control: None,
1136 }]
1137 }
1138 } else {
1139 vec![]
1140 };
1141 system.extend(history_system);
1142
1143 if prompt_caching {
1145 apply_cache_control(&mut system, &mut messages);
1146 }
1147
1148 let output_config = req.output_schema.map(|schema| {
1150 let mut schema_value = schema.to_value();
1151 sanitize_schema(&mut schema_value);
1152 OutputConfig {
1153 format: OutputFormat::JsonSchema {
1154 schema: schema_value,
1155 },
1156 }
1157 });
1158
1159 Ok(Self {
1160 model: model.to_string(),
1161 messages,
1162 max_tokens,
1163 system,
1164 temperature: req.temperature,
1165 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1166 tools,
1167 output_config,
1168 additional_params: if additional_params_payload.is_null() {
1169 None
1170 } else {
1171 Some(additional_params_payload)
1172 },
1173 })
1174 }
1175}
1176
1177fn extract_tools_from_additional_params(
1178 additional_params: &mut serde_json::Value,
1179) -> Result<Vec<serde_json::Value>, CompletionError> {
1180 if let Some(map) = additional_params.as_object_mut()
1181 && let Some(raw_tools) = map.remove("tools")
1182 {
1183 return serde_json::from_value::<Vec<serde_json::Value>>(raw_tools).map_err(|err| {
1184 CompletionError::RequestError(
1185 format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
1186 )
1187 });
1188 }
1189
1190 Ok(Vec::new())
1191}
1192
1193impl<T> completion::CompletionModel for CompletionModel<T>
1194where
1195 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
1196{
1197 type Response = CompletionResponse;
1198 type StreamingResponse = StreamingCompletionResponse;
1199 type Client = Client<T>;
1200
1201 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1202 Self::new(client.clone(), model.into())
1203 }
1204
1205 async fn completion(
1206 &self,
1207 mut completion_request: completion::CompletionRequest,
1208 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1209 let request_model = completion_request
1210 .model
1211 .clone()
1212 .unwrap_or_else(|| self.model.clone());
1213 let span = if tracing::Span::current().is_disabled() {
1214 info_span!(
1215 target: "rig::completions",
1216 "chat",
1217 gen_ai.operation.name = "chat",
1218 gen_ai.provider.name = "anthropic",
1219 gen_ai.request.model = &request_model,
1220 gen_ai.system_instructions = &completion_request.preamble,
1221 gen_ai.response.id = tracing::field::Empty,
1222 gen_ai.response.model = tracing::field::Empty,
1223 gen_ai.usage.output_tokens = tracing::field::Empty,
1224 gen_ai.usage.input_tokens = tracing::field::Empty,
1225 gen_ai.usage.cached_tokens = tracing::field::Empty,
1226 )
1227 } else {
1228 tracing::Span::current()
1229 };
1230
1231 if completion_request.max_tokens.is_none() {
1233 if let Some(tokens) = self.default_max_tokens {
1234 completion_request.max_tokens = Some(tokens);
1235 } else {
1236 return Err(CompletionError::RequestError(
1237 "`max_tokens` must be set for Anthropic".into(),
1238 ));
1239 }
1240 }
1241
1242 let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
1243 model: &request_model,
1244 request: completion_request,
1245 prompt_caching: self.prompt_caching,
1246 })?;
1247
1248 if enabled!(Level::TRACE) {
1249 tracing::trace!(
1250 target: "rig::completions",
1251 "Anthropic completion request: {}",
1252 serde_json::to_string_pretty(&request)?
1253 );
1254 }
1255
1256 async move {
1257 let request: Vec<u8> = serde_json::to_vec(&request)?;
1258
1259 let req = self
1260 .client
1261 .post("/v1/messages")?
1262 .body(request)
1263 .map_err(|e| CompletionError::HttpError(e.into()))?;
1264
1265 let response = self
1266 .client
1267 .send::<_, Bytes>(req)
1268 .await
1269 .map_err(CompletionError::HttpError)?;
1270
1271 if response.status().is_success() {
1272 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
1273 response
1274 .into_body()
1275 .await
1276 .map_err(CompletionError::HttpError)?
1277 .to_vec()
1278 .as_slice(),
1279 )? {
1280 ApiResponse::Message(completion) => {
1281 let span = tracing::Span::current();
1282 span.record_response_metadata(&completion);
1283 span.record_token_usage(&completion.usage);
1284 if enabled!(Level::TRACE) {
1285 tracing::trace!(
1286 target: "rig::completions",
1287 "Anthropic completion response: {}",
1288 serde_json::to_string_pretty(&completion)?
1289 );
1290 }
1291 completion.try_into()
1292 }
1293 ApiResponse::Error(ApiErrorResponse { message }) => {
1294 Err(CompletionError::ResponseError(message))
1295 }
1296 }
1297 } else {
1298 let text: String = String::from_utf8_lossy(
1299 &response
1300 .into_body()
1301 .await
1302 .map_err(CompletionError::HttpError)?,
1303 )
1304 .into();
1305 Err(CompletionError::ProviderError(text))
1306 }
1307 }
1308 .instrument(span)
1309 .await
1310 }
1311
1312 async fn stream(
1313 &self,
1314 request: CompletionRequest,
1315 ) -> Result<
1316 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1317 CompletionError,
1318 > {
1319 CompletionModel::stream(self, request).await
1320 }
1321}
1322
/// Error payload of an API reply; `message` is the only field that is deserialized.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; `completion` forwards it as `CompletionError::ResponseError`.
    message: String,
}
1327
/// Top-level reply body, discriminated by its JSON `type` field.
///
/// With `rename_all = "snake_case"` the accepted tags are `"message"` and `"error"`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// Regular payload, decoded as `T`.
    Message(T),
    /// Error payload, decoded as [`ApiErrorResponse`].
    Error(ApiErrorResponse),
}
1334
1335#[cfg(test)]
1336mod tests {
1337 use super::*;
1338 use serde_json::json;
1339 use serde_path_to_error::deserialize;
1340
1341 #[test]
1342 fn test_deserialize_message() {
1343 let assistant_message_json = r#"
1344 {
1345 "role": "assistant",
1346 "content": "\n\nHello there, how may I assist you today?"
1347 }
1348 "#;
1349
1350 let assistant_message_json2 = r#"
1351 {
1352 "role": "assistant",
1353 "content": [
1354 {
1355 "type": "text",
1356 "text": "\n\nHello there, how may I assist you today?"
1357 },
1358 {
1359 "type": "tool_use",
1360 "id": "toolu_01A09q90qw90lq917835lq9",
1361 "name": "get_weather",
1362 "input": {"location": "San Francisco, CA"}
1363 }
1364 ]
1365 }
1366 "#;
1367
1368 let user_message_json = r#"
1369 {
1370 "role": "user",
1371 "content": [
1372 {
1373 "type": "image",
1374 "source": {
1375 "type": "base64",
1376 "media_type": "image/jpeg",
1377 "data": "/9j/4AAQSkZJRg..."
1378 }
1379 },
1380 {
1381 "type": "text",
1382 "text": "What is in this image?"
1383 },
1384 {
1385 "type": "tool_result",
1386 "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
1387 "content": "15 degrees"
1388 }
1389 ]
1390 }
1391 "#;
1392
1393 let assistant_message: Message = {
1394 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1395 deserialize(jd).unwrap_or_else(|err| {
1396 panic!("Deserialization error at {}: {}", err.path(), err);
1397 })
1398 };
1399
1400 let assistant_message2: Message = {
1401 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1402 deserialize(jd).unwrap_or_else(|err| {
1403 panic!("Deserialization error at {}: {}", err.path(), err);
1404 })
1405 };
1406
1407 let user_message: Message = {
1408 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1409 deserialize(jd).unwrap_or_else(|err| {
1410 panic!("Deserialization error at {}: {}", err.path(), err);
1411 })
1412 };
1413
1414 let Message { role, content } = assistant_message;
1415 assert_eq!(role, Role::Assistant);
1416 assert_eq!(
1417 content.first(),
1418 Content::Text {
1419 text: "\n\nHello there, how may I assist you today?".to_owned(),
1420 cache_control: None,
1421 }
1422 );
1423
1424 let Message { role, content } = assistant_message2;
1425 {
1426 assert_eq!(role, Role::Assistant);
1427 assert_eq!(content.len(), 2);
1428
1429 let mut iter = content.into_iter();
1430
1431 match iter.next().unwrap() {
1432 Content::Text { text, .. } => {
1433 assert_eq!(text, "\n\nHello there, how may I assist you today?");
1434 }
1435 _ => panic!("Expected text content"),
1436 }
1437
1438 match iter.next().unwrap() {
1439 Content::ToolUse { id, name, input } => {
1440 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1441 assert_eq!(name, "get_weather");
1442 assert_eq!(input, json!({"location": "San Francisco, CA"}));
1443 }
1444 _ => panic!("Expected tool use content"),
1445 }
1446
1447 assert_eq!(iter.next(), None);
1448 }
1449
1450 let Message { role, content } = user_message;
1451 {
1452 assert_eq!(role, Role::User);
1453 assert_eq!(content.len(), 3);
1454
1455 let mut iter = content.into_iter();
1456
1457 match iter.next().unwrap() {
1458 Content::Image { source, .. } => {
1459 assert_eq!(
1460 source,
1461 ImageSource::Base64 {
1462 data: "/9j/4AAQSkZJRg...".to_owned(),
1463 media_type: ImageFormat::JPEG,
1464 }
1465 );
1466 }
1467 _ => panic!("Expected image content"),
1468 }
1469
1470 match iter.next().unwrap() {
1471 Content::Text { text, .. } => {
1472 assert_eq!(text, "What is in this image?");
1473 }
1474 _ => panic!("Expected text content"),
1475 }
1476
1477 match iter.next().unwrap() {
1478 Content::ToolResult {
1479 tool_use_id,
1480 content,
1481 is_error,
1482 ..
1483 } => {
1484 assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1485 assert_eq!(
1486 content.first(),
1487 ToolResultContent::Text {
1488 text: "15 degrees".to_owned()
1489 }
1490 );
1491 assert_eq!(is_error, None);
1492 }
1493 _ => panic!("Expected tool result content"),
1494 }
1495
1496 assert_eq!(iter.next(), None);
1497 }
1498 }
1499
1500 #[test]
1501 fn test_message_to_message_conversion() {
1502 let user_message: Message = serde_json::from_str(
1503 r#"
1504 {
1505 "role": "user",
1506 "content": [
1507 {
1508 "type": "image",
1509 "source": {
1510 "type": "base64",
1511 "media_type": "image/jpeg",
1512 "data": "/9j/4AAQSkZJRg..."
1513 }
1514 },
1515 {
1516 "type": "text",
1517 "text": "What is in this image?"
1518 },
1519 {
1520 "type": "document",
1521 "source": {
1522 "type": "base64",
1523 "data": "base64_encoded_pdf_data",
1524 "media_type": "application/pdf"
1525 }
1526 }
1527 ]
1528 }
1529 "#,
1530 )
1531 .unwrap();
1532
1533 let assistant_message = Message {
1534 role: Role::Assistant,
1535 content: OneOrMany::one(Content::ToolUse {
1536 id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1537 name: "get_weather".to_string(),
1538 input: json!({"location": "San Francisco, CA"}),
1539 }),
1540 };
1541
1542 let tool_message = Message {
1543 role: Role::User,
1544 content: OneOrMany::one(Content::ToolResult {
1545 tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1546 content: OneOrMany::one(ToolResultContent::Text {
1547 text: "15 degrees".to_string(),
1548 }),
1549 is_error: None,
1550 cache_control: None,
1551 }),
1552 };
1553
1554 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1555 let converted_assistant_message: message::Message =
1556 assistant_message.clone().try_into().unwrap();
1557 let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1558
1559 match converted_user_message.clone() {
1560 message::Message::User { content } => {
1561 assert_eq!(content.len(), 3);
1562
1563 let mut iter = content.into_iter();
1564
1565 match iter.next().unwrap() {
1566 message::UserContent::Image(message::Image {
1567 data, media_type, ..
1568 }) => {
1569 assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1570 assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1571 }
1572 _ => panic!("Expected image content"),
1573 }
1574
1575 match iter.next().unwrap() {
1576 message::UserContent::Text(message::Text { text }) => {
1577 assert_eq!(text, "What is in this image?");
1578 }
1579 _ => panic!("Expected text content"),
1580 }
1581
1582 match iter.next().unwrap() {
1583 message::UserContent::Document(message::Document {
1584 data, media_type, ..
1585 }) => {
1586 assert_eq!(
1587 data,
1588 DocumentSourceKind::String("base64_encoded_pdf_data".into())
1589 );
1590 assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1591 }
1592 _ => panic!("Expected document content"),
1593 }
1594
1595 assert_eq!(iter.next(), None);
1596 }
1597 _ => panic!("Expected user message"),
1598 }
1599
1600 match converted_tool_message.clone() {
1601 message::Message::User { content } => {
1602 let message::ToolResult { id, content, .. } = match content.first() {
1603 message::UserContent::ToolResult(tool_result) => tool_result,
1604 _ => panic!("Expected tool result content"),
1605 };
1606 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1607 match content.first() {
1608 message::ToolResultContent::Text(message::Text { text }) => {
1609 assert_eq!(text, "15 degrees");
1610 }
1611 _ => panic!("Expected text content"),
1612 }
1613 }
1614 _ => panic!("Expected tool result content"),
1615 }
1616
1617 match converted_assistant_message.clone() {
1618 message::Message::Assistant { content, .. } => {
1619 assert_eq!(content.len(), 1);
1620
1621 match content.first() {
1622 message::AssistantContent::ToolCall(message::ToolCall {
1623 id, function, ..
1624 }) => {
1625 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1626 assert_eq!(function.name, "get_weather");
1627 assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1628 }
1629 _ => panic!("Expected tool call content"),
1630 }
1631 }
1632 _ => panic!("Expected assistant message"),
1633 }
1634
1635 let original_user_message: Message = converted_user_message.try_into().unwrap();
1636 let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1637 let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1638
1639 assert_eq!(user_message, original_user_message);
1640 assert_eq!(assistant_message, original_assistant_message);
1641 assert_eq!(tool_message, original_tool_message);
1642 }
1643
1644 #[test]
1645 fn test_content_format_conversion() {
1646 use crate::completion::message::ContentFormat;
1647
1648 let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1649 assert_eq!(source_type, SourceType::URL);
1650
1651 let content_format: ContentFormat = SourceType::URL.into();
1652 assert_eq!(content_format, ContentFormat::Url);
1653
1654 let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1655 assert_eq!(source_type, SourceType::BASE64);
1656
1657 let content_format: ContentFormat = SourceType::BASE64.into();
1658 assert_eq!(content_format, ContentFormat::Base64);
1659
1660 let source_type: SourceType = ContentFormat::String.try_into().unwrap();
1661 assert_eq!(source_type, SourceType::TEXT);
1662
1663 let content_format: ContentFormat = SourceType::TEXT.into();
1664 assert_eq!(content_format, ContentFormat::String);
1665 }
1666
1667 #[test]
1668 fn test_cache_control_serialization() {
1669 let system = SystemContent::Text {
1671 text: "You are a helpful assistant.".to_string(),
1672 cache_control: Some(CacheControl::Ephemeral),
1673 };
1674 let json = serde_json::to_string(&system).unwrap();
1675 assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
1676 assert!(json.contains(r#""type":"text""#));
1677
1678 let system_no_cache = SystemContent::Text {
1680 text: "Hello".to_string(),
1681 cache_control: None,
1682 };
1683 let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
1684 assert!(!json_no_cache.contains("cache_control"));
1685
1686 let content = Content::Text {
1688 text: "Test message".to_string(),
1689 cache_control: Some(CacheControl::Ephemeral),
1690 };
1691 let json_content = serde_json::to_string(&content).unwrap();
1692 assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));
1693
1694 let mut system_vec = vec![SystemContent::Text {
1696 text: "System prompt".to_string(),
1697 cache_control: None,
1698 }];
1699 let mut messages = vec![
1700 Message {
1701 role: Role::User,
1702 content: OneOrMany::one(Content::Text {
1703 text: "First message".to_string(),
1704 cache_control: None,
1705 }),
1706 },
1707 Message {
1708 role: Role::Assistant,
1709 content: OneOrMany::one(Content::Text {
1710 text: "Response".to_string(),
1711 cache_control: None,
1712 }),
1713 },
1714 ];
1715
1716 apply_cache_control(&mut system_vec, &mut messages);
1717
1718 match &system_vec[0] {
1720 SystemContent::Text { cache_control, .. } => {
1721 assert!(cache_control.is_some());
1722 }
1723 }
1724
1725 for content in messages[0].content.iter() {
1728 if let Content::Text { cache_control, .. } = content {
1729 assert!(cache_control.is_none());
1730 }
1731 }
1732
1733 for content in messages[1].content.iter() {
1735 if let Content::Text { cache_control, .. } = content {
1736 assert!(cache_control.is_some());
1737 }
1738 }
1739 }
1740
1741 #[test]
1742 fn test_plaintext_document_serialization() {
1743 let content = Content::Document {
1744 source: DocumentSource::Text {
1745 data: "Hello, world!".to_string(),
1746 media_type: PlainTextMediaType::Plain,
1747 },
1748 cache_control: None,
1749 };
1750
1751 let json = serde_json::to_value(&content).unwrap();
1752 assert_eq!(json["type"], "document");
1753 assert_eq!(json["source"]["type"], "text");
1754 assert_eq!(json["source"]["media_type"], "text/plain");
1755 assert_eq!(json["source"]["data"], "Hello, world!");
1756 }
1757
1758 #[test]
1759 fn test_plaintext_document_deserialization() {
1760 let json = r#"
1761 {
1762 "type": "document",
1763 "source": {
1764 "type": "text",
1765 "media_type": "text/plain",
1766 "data": "Hello, world!"
1767 }
1768 }
1769 "#;
1770
1771 let content: Content = serde_json::from_str(json).unwrap();
1772 match content {
1773 Content::Document {
1774 source,
1775 cache_control,
1776 } => {
1777 assert_eq!(
1778 source,
1779 DocumentSource::Text {
1780 data: "Hello, world!".to_string(),
1781 media_type: PlainTextMediaType::Plain,
1782 }
1783 );
1784 assert_eq!(cache_control, None);
1785 }
1786 _ => panic!("Expected Document content"),
1787 }
1788 }
1789
1790 #[test]
1791 fn test_base64_pdf_document_serialization() {
1792 let content = Content::Document {
1793 source: DocumentSource::Base64 {
1794 data: "base64data".to_string(),
1795 media_type: DocumentFormat::PDF,
1796 },
1797 cache_control: None,
1798 };
1799
1800 let json = serde_json::to_value(&content).unwrap();
1801 assert_eq!(json["type"], "document");
1802 assert_eq!(json["source"]["type"], "base64");
1803 assert_eq!(json["source"]["media_type"], "application/pdf");
1804 assert_eq!(json["source"]["data"], "base64data");
1805 }
1806
1807 #[test]
1808 fn test_base64_pdf_document_deserialization() {
1809 let json = r#"
1810 {
1811 "type": "document",
1812 "source": {
1813 "type": "base64",
1814 "media_type": "application/pdf",
1815 "data": "base64data"
1816 }
1817 }
1818 "#;
1819
1820 let content: Content = serde_json::from_str(json).unwrap();
1821 match content {
1822 Content::Document { source, .. } => {
1823 assert_eq!(
1824 source,
1825 DocumentSource::Base64 {
1826 data: "base64data".to_string(),
1827 media_type: DocumentFormat::PDF,
1828 }
1829 );
1830 }
1831 _ => panic!("Expected Document content"),
1832 }
1833 }
1834
1835 #[test]
1836 fn test_plaintext_rig_to_anthropic_conversion() {
1837 use crate::completion::message as msg;
1838
1839 let rig_message = msg::Message::User {
1840 content: OneOrMany::one(msg::UserContent::document(
1841 "Some plain text content".to_string(),
1842 Some(msg::DocumentMediaType::TXT),
1843 )),
1844 };
1845
1846 let anthropic_message: Message = rig_message.try_into().unwrap();
1847 assert_eq!(anthropic_message.role, Role::User);
1848
1849 let mut iter = anthropic_message.content.into_iter();
1850 match iter.next().unwrap() {
1851 Content::Document { source, .. } => {
1852 assert_eq!(
1853 source,
1854 DocumentSource::Text {
1855 data: "Some plain text content".to_string(),
1856 media_type: PlainTextMediaType::Plain,
1857 }
1858 );
1859 }
1860 other => panic!("Expected Document content, got: {other:?}"),
1861 }
1862 }
1863
1864 #[test]
1865 fn test_plaintext_anthropic_to_rig_conversion() {
1866 use crate::completion::message as msg;
1867
1868 let anthropic_message = Message {
1869 role: Role::User,
1870 content: OneOrMany::one(Content::Document {
1871 source: DocumentSource::Text {
1872 data: "Some plain text content".to_string(),
1873 media_type: PlainTextMediaType::Plain,
1874 },
1875 cache_control: None,
1876 }),
1877 };
1878
1879 let rig_message: msg::Message = anthropic_message.try_into().unwrap();
1880 match rig_message {
1881 msg::Message::User { content } => {
1882 let mut iter = content.into_iter();
1883 match iter.next().unwrap() {
1884 msg::UserContent::Document(msg::Document {
1885 data, media_type, ..
1886 }) => {
1887 assert_eq!(
1888 data,
1889 DocumentSourceKind::String("Some plain text content".into())
1890 );
1891 assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
1892 }
1893 other => panic!("Expected Document content, got: {other:?}"),
1894 }
1895 }
1896 _ => panic!("Expected User message"),
1897 }
1898 }
1899
1900 #[test]
1901 fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
1902 use crate::completion::message as msg;
1903
1904 let original = msg::Message::User {
1905 content: OneOrMany::one(msg::UserContent::document(
1906 "Round trip text".to_string(),
1907 Some(msg::DocumentMediaType::TXT),
1908 )),
1909 };
1910
1911 let anthropic: Message = original.clone().try_into().unwrap();
1912 let back: msg::Message = anthropic.try_into().unwrap();
1913
1914 match (&original, &back) {
1915 (
1916 msg::Message::User {
1917 content: orig_content,
1918 },
1919 msg::Message::User {
1920 content: back_content,
1921 },
1922 ) => match (orig_content.first(), back_content.first()) {
1923 (
1924 msg::UserContent::Document(msg::Document {
1925 media_type: orig_mt,
1926 ..
1927 }),
1928 msg::UserContent::Document(msg::Document {
1929 media_type: back_mt,
1930 ..
1931 }),
1932 ) => {
1933 assert_eq!(orig_mt, back_mt);
1934 }
1935 _ => panic!("Expected Document content in both"),
1936 },
1937 _ => panic!("Expected User messages"),
1938 }
1939 }
1940
1941 #[test]
1942 fn test_unsupported_document_type_returns_error() {
1943 use crate::completion::message as msg;
1944
1945 let rig_message = msg::Message::User {
1946 content: OneOrMany::one(msg::UserContent::Document(msg::Document {
1947 data: DocumentSourceKind::String("data".into()),
1948 media_type: Some(msg::DocumentMediaType::HTML),
1949 additional_params: None,
1950 })),
1951 };
1952
1953 let result: Result<Message, _> = rig_message.try_into();
1954 assert!(result.is_err());
1955 let err = result.unwrap_err().to_string();
1956 assert!(
1957 err.contains("Anthropic only supports PDF and plain text documents"),
1958 "Unexpected error: {err}"
1959 );
1960 }
1961
1962 #[test]
1963 fn test_plaintext_document_url_source_returns_error() {
1964 use crate::completion::message as msg;
1965
1966 let rig_message = msg::Message::User {
1967 content: OneOrMany::one(msg::UserContent::Document(msg::Document {
1968 data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
1969 media_type: Some(msg::DocumentMediaType::TXT),
1970 additional_params: None,
1971 })),
1972 };
1973
1974 let result: Result<Message, _> = rig_message.try_into();
1975 assert!(result.is_err());
1976 let err = result.unwrap_err().to_string();
1977 assert!(
1978 err.contains("Only string or base64 data is supported for plain text documents"),
1979 "Unexpected error: {err}"
1980 );
1981 }
1982
1983 #[test]
1984 fn test_plaintext_document_with_cache_control() {
1985 let content = Content::Document {
1986 source: DocumentSource::Text {
1987 data: "cached text".to_string(),
1988 media_type: PlainTextMediaType::Plain,
1989 },
1990 cache_control: Some(CacheControl::Ephemeral),
1991 };
1992
1993 let json = serde_json::to_value(&content).unwrap();
1994 assert_eq!(json["source"]["type"], "text");
1995 assert_eq!(json["source"]["media_type"], "text/plain");
1996 assert_eq!(json["cache_control"]["type"], "ephemeral");
1997 }
1998
1999 #[test]
2000 fn test_message_with_plaintext_document_deserialization() {
2001 let json = r#"
2002 {
2003 "role": "user",
2004 "content": [
2005 {
2006 "type": "document",
2007 "source": {
2008 "type": "text",
2009 "media_type": "text/plain",
2010 "data": "Hello from a text file"
2011 }
2012 },
2013 {
2014 "type": "text",
2015 "text": "Summarize this document."
2016 }
2017 ]
2018 }
2019 "#;
2020
2021 let message: Message = serde_json::from_str(json).unwrap();
2022 assert_eq!(message.role, Role::User);
2023 assert_eq!(message.content.len(), 2);
2024
2025 let mut iter = message.content.into_iter();
2026
2027 match iter.next().unwrap() {
2028 Content::Document { source, .. } => {
2029 assert_eq!(
2030 source,
2031 DocumentSource::Text {
2032 data: "Hello from a text file".to_string(),
2033 media_type: PlainTextMediaType::Plain,
2034 }
2035 );
2036 }
2037 _ => panic!("Expected Document content"),
2038 }
2039
2040 match iter.next().unwrap() {
2041 Content::Text { text, .. } => {
2042 assert_eq!(text, "Summarize this document.");
2043 }
2044 _ => panic!("Expected Text content"),
2045 }
2046 }
2047
2048 #[test]
2049 fn test_assistant_reasoning_multiblock_to_anthropic_content() {
2050 let reasoning = message::Reasoning {
2051 id: None,
2052 content: vec![
2053 message::ReasoningContent::Text {
2054 text: "step one".to_string(),
2055 signature: Some("sig-1".to_string()),
2056 },
2057 message::ReasoningContent::Summary("summary".to_string()),
2058 message::ReasoningContent::Text {
2059 text: "step two".to_string(),
2060 signature: Some("sig-2".to_string()),
2061 },
2062 message::ReasoningContent::Redacted {
2063 data: "redacted block".to_string(),
2064 },
2065 ],
2066 };
2067
2068 let msg = message::Message::Assistant {
2069 id: None,
2070 content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
2071 };
2072 let converted: Message = msg.try_into().expect("convert assistant message");
2073 let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();
2074
2075 assert_eq!(converted.role, Role::Assistant);
2076 assert_eq!(converted_content.len(), 4);
2077 assert!(matches!(
2078 converted_content.first(),
2079 Some(Content::Thinking { thinking, signature: Some(signature) })
2080 if thinking == "step one" && signature == "sig-1"
2081 ));
2082 assert!(matches!(
2083 converted_content.get(1),
2084 Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
2085 ));
2086 assert!(matches!(
2087 converted_content.get(2),
2088 Some(Content::Thinking { thinking, signature: Some(signature) })
2089 if thinking == "step two" && signature == "sig-2"
2090 ));
2091 assert!(matches!(
2092 converted_content.get(3),
2093 Some(Content::RedactedThinking { data }) if data == "redacted block"
2094 ));
2095 }
2096
2097 #[test]
2098 fn test_redacted_thinking_content_to_assistant_reasoning() {
2099 let content = Content::RedactedThinking {
2100 data: "opaque-redacted".to_string(),
2101 };
2102 let converted: message::AssistantContent =
2103 content.try_into().expect("convert redacted thinking");
2104
2105 assert!(matches!(
2106 converted,
2107 message::AssistantContent::Reasoning(message::Reasoning { content, .. })
2108 if matches!(
2109 content.first(),
2110 Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
2111 )
2112 ));
2113 }
2114
2115 #[test]
2116 fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
2117 let reasoning = message::Reasoning {
2118 id: None,
2119 content: vec![message::ReasoningContent::Encrypted(
2120 "ciphertext".to_string(),
2121 )],
2122 };
2123 let msg = message::Message::Assistant {
2124 id: None,
2125 content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
2126 };
2127
2128 let converted: Message = msg.try_into().expect("convert assistant message");
2129 let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();
2130
2131 assert_eq!(converted_content.len(), 1);
2132 assert!(matches!(
2133 converted_content.first(),
2134 Some(Content::RedactedThinking { data }) if data == "ciphertext"
2135 ));
2136 }
2137}