1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
/// Model identifier `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model identifier `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model alias `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model alias `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model alias `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Newest API version known to this module; currently an alias for
/// [`ANTHROPIC_VERSION_2023_06_01`].
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Non-streaming response body returned by the Anthropic Messages API.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model, in order.
    pub content: Vec<Content>,
    /// Response identifier (exposed via `get_response_id`).
    pub id: String,
    /// Model name echoed by the API (exposed via `get_response_model_name`).
    pub model: String,
    /// Author role string as returned in the response body.
    pub role: String,
    /// Stop reason reported by the API, if any.
    pub stop_reason: Option<String>,
    /// Stop sequence reported by the API, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this response.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52 type OutputMessage = Content;
53 type Usage = Usage;
54
55 fn get_response_id(&self) -> Option<String> {
56 Some(self.id.to_owned())
57 }
58
59 fn get_response_model_name(&self) -> Option<String> {
60 Some(self.model.to_owned())
61 }
62
63 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64 self.content.clone()
65 }
66
67 fn get_text_response(&self) -> Option<String> {
68 let res = self
69 .content
70 .iter()
71 .filter_map(|x| {
72 if let Content::Text { text, .. } = x {
73 Some(text.to_owned())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<String>>()
79 .join("\n");
80
81 if res.is_empty() { None } else { Some(res) }
82 }
83
84 fn get_usage(&self) -> Option<Self::Usage> {
85 Some(self.usage.clone())
86 }
87}
88
/// Token usage block of an Anthropic response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens reported for the request.
    pub input_tokens: u64,
    /// Input tokens served from the prompt cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output tokens reported for the response.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(
100 f,
101 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102 self.input_tokens,
103 match self.cache_read_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 match self.cache_creation_input_tokens {
108 Some(token) => token.to_string(),
109 None => "n/a".to_string(),
110 },
111 self.output_tokens
112 )
113 }
114}
115
116impl GetTokenUsage for Usage {
117 fn token_usage(&self) -> Option<crate::completion::Usage> {
118 let mut usage = crate::completion::Usage::new();
119
120 usage.input_tokens = self.input_tokens;
121 usage.output_tokens = self.output_tokens;
122 usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or_default();
123 usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or_default();
124 usage.total_tokens = self.input_tokens
125 + self.cache_read_input_tokens.unwrap_or_default()
126 + self.cache_creation_input_tokens.unwrap_or_default()
127 + self.output_tokens;
128
129 Some(usage)
130 }
131}
132
/// Tool definition in the shape serialized into the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Optional description (serialized as `null` when `None`).
    pub description: Option<String>,
    /// JSON Schema describing the tool's input.
    pub input_schema: serde_json::Value,
}
139
/// Lifetime of an ephemeral cache marker, serialized as `"5m"` or `"1h"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub enum CacheTtl {
    /// Five minutes (`"5m"`); the Rust `Default`.
    #[default]
    #[serde(rename = "5m")]
    FiveMinutes,
    /// One hour (`"1h"`).
    #[serde(rename = "1h")]
    OneHour,
}
155
/// A `cache_control` marker attached to request content or to the request itself.
///
/// Internally tagged on `type`, e.g. `{"type":"ephemeral"}` or
/// `{"type":"ephemeral","ttl":"1h"}`; `ttl` is omitted when `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// Ephemeral cache marker with an optional explicit TTL.
    Ephemeral {
        #[serde(skip_serializing_if = "Option::is_none")]
        ttl: Option<CacheTtl>,
    },
}
169
170impl CacheControl {
171 pub fn ephemeral() -> Self {
173 Self::Ephemeral { ttl: None }
174 }
175
176 pub fn ephemeral_1h() -> Self {
178 Self::Ephemeral {
179 ttl: Some(CacheTtl::OneHour),
180 }
181 }
182}
183
/// A block of the request's top-level `system` prompt, tagged `"type": "text"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    /// System prompt text with an optional cache marker (omitted when `None`).
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
194
195impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
196 type Error = CompletionError;
197
198 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
199 let content = response
200 .content
201 .iter()
202 .map(|content| content.clone().try_into())
203 .collect::<Result<Vec<_>, _>>()?;
204
205 let choice = OneOrMany::many(content).map_err(|_| {
206 CompletionError::ResponseError(
207 "Response contained no message or tool call (empty)".to_owned(),
208 )
209 })?;
210
211 let usage = completion::Usage {
212 input_tokens: response.usage.input_tokens,
213 output_tokens: response.usage.output_tokens,
214 total_tokens: response.usage.input_tokens
215 + response.usage.cache_read_input_tokens.unwrap_or(0)
216 + response.usage.cache_creation_input_tokens.unwrap_or(0)
217 + response.usage.output_tokens,
218 cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
219 cache_creation_input_tokens: response.usage.cache_creation_input_tokens.unwrap_or(0),
220 };
221
222 Ok(completion::CompletionResponse {
223 choice,
224 usage,
225 raw_response: response,
226 message_id: None,
227 })
228 }
229}
230
/// One turn of an Anthropic conversation.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the turn.
    pub role: Role,
    /// One or more content blocks. Deserialized via `string_or_one_or_many`
    /// (presumably also accepting a bare string — confirm in `one_or_many`).
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}

/// Message author, serialized as `"user"` or `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
244
/// A content block inside an Anthropic message.
///
/// Internally tagged on `type` (`text`, `image`, `tool_use`, `tool_result`, `document`,
/// `thinking`, `redacted_thinking`). Optional fields marked with `skip_serializing_if`
/// are omitted from the JSON when `None`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image whose source is nested under `source`.
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation: call id, tool name, and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, referencing the invocation by `tool_use_id`.
    /// `content` is deserialized via `string_or_one_or_many`.
    ToolResult {
        tool_use_id: String,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document whose source is nested under `source`.
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Model reasoning text with an optional signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Reasoning available only as an opaque `data` string.
    RedactedThinking {
        data: String,
    },
}
286
287impl FromStr for Content {
288 type Err = Infallible;
289
290 fn from_str(s: &str) -> Result<Self, Self::Err> {
291 Ok(Content::Text {
292 text: s.to_owned(),
293 cache_control: None,
294 })
295 }
296}
297
298#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
299#[serde(tag = "type", rename_all = "snake_case")]
300pub enum ToolResultContent {
301 Text { text: String },
302 Image(ImageSource),
303}
304
305impl FromStr for ToolResultContent {
306 type Err = Infallible;
307
308 fn from_str(s: &str) -> Result<Self, Self::Err> {
309 Ok(ToolResultContent::Text { text: s.to_owned() })
310 }
311}
312
/// Source of an image, internally tagged on `type`: `base64` or `url`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline base64 data with its media type.
    #[serde(rename = "base64")]
    Base64 {
        data: String,
        media_type: ImageFormat,
    },
    /// An image referenced by URL.
    #[serde(rename = "url")]
    Url { url: String },
}

/// Source of a document, internally tagged on `type`: `base64`, `text`, or `url`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
    /// Inline base64 data; only PDF is representable (see `DocumentFormat`).
    Base64 {
        data: String,
        media_type: DocumentFormat,
    },
    /// Inline plain text (`text/plain`).
    Text {
        data: String,
        media_type: PlainTextMediaType,
    },
    /// A document referenced by URL.
    Url {
        url: String,
    },
}
357
/// Image media types accepted in base64 image sources, serialized as MIME strings.
///
/// Every variant has an explicit `rename`, so `rename_all` has no effect here.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}

/// Media types for base64 document sources; only `application/pdf`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}

/// Media type for text document sources, serialized as `text/plain`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum PlainTextMediaType {
    #[serde(rename = "text/plain")]
    Plain,
}

/// Source-kind tag (`base64`, `url`, `text`), convertible to and from
/// `message::ContentFormat`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
    TEXT,
}
402
403impl From<String> for Content {
404 fn from(text: String) -> Self {
405 Content::Text {
406 text,
407 cache_control: None,
408 }
409 }
410}
411
412impl From<String> for ToolResultContent {
413 fn from(text: String) -> Self {
414 ToolResultContent::Text { text }
415 }
416}
417
418impl TryFrom<message::ContentFormat> for SourceType {
419 type Error = MessageError;
420
421 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
422 match format {
423 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
424 message::ContentFormat::Url => Ok(SourceType::URL),
425 message::ContentFormat::String => Ok(SourceType::TEXT),
426 }
427 }
428}
429
430impl From<SourceType> for message::ContentFormat {
431 fn from(source_type: SourceType) -> Self {
432 match source_type {
433 SourceType::BASE64 => message::ContentFormat::Base64,
434 SourceType::URL => message::ContentFormat::Url,
435 SourceType::TEXT => message::ContentFormat::String,
436 }
437 }
438}
439
440impl TryFrom<message::ImageMediaType> for ImageFormat {
441 type Error = MessageError;
442
443 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
444 Ok(match media_type {
445 message::ImageMediaType::JPEG => ImageFormat::JPEG,
446 message::ImageMediaType::PNG => ImageFormat::PNG,
447 message::ImageMediaType::GIF => ImageFormat::GIF,
448 message::ImageMediaType::WEBP => ImageFormat::WEBP,
449 _ => {
450 return Err(MessageError::ConversionError(
451 format!("Unsupported image media type: {media_type:?}").to_owned(),
452 ));
453 }
454 })
455 }
456}
457
458impl From<ImageFormat> for message::ImageMediaType {
459 fn from(format: ImageFormat) -> Self {
460 match format {
461 ImageFormat::JPEG => message::ImageMediaType::JPEG,
462 ImageFormat::PNG => message::ImageMediaType::PNG,
463 ImageFormat::GIF => message::ImageMediaType::GIF,
464 ImageFormat::WEBP => message::ImageMediaType::WEBP,
465 }
466 }
467}
468
469impl TryFrom<DocumentMediaType> for DocumentFormat {
470 type Error = MessageError;
471 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
472 match value {
473 DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
474 other => Err(MessageError::ConversionError(format!(
475 "DocumentFormat only supports PDF for base64 sources, got: {}",
476 other.to_mime_type()
477 ))),
478 }
479 }
480}
481
482impl TryFrom<message::AssistantContent> for Content {
483 type Error = MessageError;
484 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
485 match text {
486 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
487 text,
488 cache_control: None,
489 }),
490 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
491 "Anthropic currently doesn't support images.".to_string(),
492 )),
493 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
494 Ok(Content::ToolUse {
495 id,
496 name: function.name,
497 input: function.arguments,
498 })
499 }
500 message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
501 thinking: reasoning.display_text(),
502 signature: reasoning.first_signature().map(str::to_owned),
503 }),
504 }
505 }
506}
507
508fn anthropic_content_from_assistant_content(
509 content: message::AssistantContent,
510) -> Result<Vec<Content>, MessageError> {
511 match content {
512 message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
513 text,
514 cache_control: None,
515 }]),
516 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
517 "Anthropic currently doesn't support images.".to_string(),
518 )),
519 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
520 Ok(vec![Content::ToolUse {
521 id,
522 name: function.name,
523 input: function.arguments,
524 }])
525 }
526 message::AssistantContent::Reasoning(reasoning) => {
527 let mut converted = Vec::new();
528 for block in reasoning.content {
529 match block {
530 message::ReasoningContent::Text { text, signature } => {
531 converted.push(Content::Thinking {
532 thinking: text,
533 signature,
534 });
535 }
536 message::ReasoningContent::Summary(summary) => {
537 converted.push(Content::Thinking {
538 thinking: summary,
539 signature: None,
540 });
541 }
542 message::ReasoningContent::Redacted { data }
543 | message::ReasoningContent::Encrypted(data) => {
544 converted.push(Content::RedactedThinking { data });
545 }
546 }
547 }
548
549 if converted.is_empty() {
550 return Err(MessageError::ConversionError(
551 "Cannot convert empty reasoning content to Anthropic format".to_string(),
552 ));
553 }
554
555 Ok(converted)
556 }
557 }
558}
559
560impl TryFrom<message::Message> for Message {
561 type Error = MessageError;
562
563 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
564 Ok(match message {
565 message::Message::User { content } => Message {
566 role: Role::User,
567 content: content.try_map(|content| match content {
568 message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
569 text,
570 cache_control: None,
571 }),
572 message::UserContent::ToolResult(message::ToolResult {
573 id, content, ..
574 }) => Ok(Content::ToolResult {
575 tool_use_id: id,
576 content: content.try_map(|content| match content {
577 message::ToolResultContent::Text(message::Text { text }) => {
578 Ok(ToolResultContent::Text { text })
579 }
580 message::ToolResultContent::Image(image) => {
581 let DocumentSourceKind::Base64(data) = image.data else {
582 return Err(MessageError::ConversionError(
583 "Only base64 strings can be used with the Anthropic API"
584 .to_string(),
585 ));
586 };
587 let media_type =
588 image.media_type.ok_or(MessageError::ConversionError(
589 "Image media type is required".to_owned(),
590 ))?;
591 Ok(ToolResultContent::Image(ImageSource::Base64 {
592 data,
593 media_type: media_type.try_into()?,
594 }))
595 }
596 })?,
597 is_error: None,
598 cache_control: None,
599 }),
600 message::UserContent::Image(message::Image {
601 data, media_type, ..
602 }) => {
603 let source = match data {
604 DocumentSourceKind::Base64(data) => {
605 let media_type =
606 media_type.ok_or(MessageError::ConversionError(
607 "Image media type is required for Claude API".to_string(),
608 ))?;
609 ImageSource::Base64 {
610 data,
611 media_type: ImageFormat::try_from(media_type)?,
612 }
613 }
614 DocumentSourceKind::Url(url) => ImageSource::Url { url },
615 DocumentSourceKind::Unknown => {
616 return Err(MessageError::ConversionError(
617 "Image content has no body".into(),
618 ));
619 }
620 doc => {
621 return Err(MessageError::ConversionError(format!(
622 "Unsupported document type: {doc:?}"
623 )));
624 }
625 };
626
627 Ok(Content::Image {
628 source,
629 cache_control: None,
630 })
631 }
632 message::UserContent::Document(message::Document {
633 data, media_type, ..
634 }) => {
635 let media_type = media_type.ok_or(MessageError::ConversionError(
636 "Document media type is required".to_string(),
637 ))?;
638
639 let source = match media_type {
640 DocumentMediaType::PDF => {
641 let data = match data {
642 DocumentSourceKind::Base64(data)
643 | DocumentSourceKind::String(data) => data,
644 _ => {
645 return Err(MessageError::ConversionError(
646 "Only base64 encoded data is supported for PDF documents".into(),
647 ));
648 }
649 };
650 DocumentSource::Base64 {
651 data,
652 media_type: DocumentFormat::PDF,
653 }
654 }
655 DocumentMediaType::TXT => {
656 let data = match data {
657 DocumentSourceKind::String(data)
658 | DocumentSourceKind::Base64(data) => data,
659 _ => {
660 return Err(MessageError::ConversionError(
661 "Only string or base64 data is supported for plain text documents".into(),
662 ));
663 }
664 };
665 DocumentSource::Text {
666 data,
667 media_type: PlainTextMediaType::Plain,
668 }
669 }
670 other => {
671 return Err(MessageError::ConversionError(format!(
672 "Anthropic only supports PDF and plain text documents, got: {}",
673 other.to_mime_type()
674 )));
675 }
676 };
677
678 Ok(Content::Document {
679 source,
680 cache_control: None,
681 })
682 }
683 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
684 "Audio is not supported in Anthropic".to_owned(),
685 )),
686 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
687 "Video is not supported in Anthropic".to_owned(),
688 )),
689 })?,
690 },
691
692 message::Message::System { content } => Message {
693 role: Role::User,
694 content: OneOrMany::one(Content::Text {
695 text: content,
696 cache_control: None,
697 }),
698 },
699
700 message::Message::Assistant { content, .. } => {
701 let converted_content = content.into_iter().try_fold(
702 Vec::new(),
703 |mut accumulated, assistant_content| {
704 accumulated
705 .extend(anthropic_content_from_assistant_content(assistant_content)?);
706 Ok::<Vec<Content>, MessageError>(accumulated)
707 },
708 )?;
709
710 Message {
711 content: OneOrMany::many(converted_content).map_err(|_| {
712 MessageError::ConversionError(
713 "Assistant message did not contain Anthropic-compatible content"
714 .to_owned(),
715 )
716 })?,
717 role: Role::Assistant,
718 }
719 }
720 })
721 }
722}
723
724impl TryFrom<Content> for message::AssistantContent {
725 type Error = MessageError;
726
727 fn try_from(content: Content) -> Result<Self, Self::Error> {
728 Ok(match content {
729 Content::Text { text, .. } => message::AssistantContent::text(text),
730 Content::ToolUse { id, name, input } => {
731 message::AssistantContent::tool_call(id, name, input)
732 }
733 Content::Thinking {
734 thinking,
735 signature,
736 } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
737 &thinking, signature,
738 )),
739 Content::RedactedThinking { data } => {
740 message::AssistantContent::Reasoning(Reasoning::redacted(data))
741 }
742 _ => {
743 return Err(MessageError::ConversionError(
744 "Content did not contain a message, tool call, or reasoning".to_owned(),
745 ));
746 }
747 })
748 }
749}
750
751impl From<ToolResultContent> for message::ToolResultContent {
752 fn from(content: ToolResultContent) -> Self {
753 match content {
754 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
755 ToolResultContent::Image(source) => match source {
756 ImageSource::Base64 { data, media_type } => {
757 message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
758 }
759 ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
760 },
761 }
762 }
763}
764
765impl TryFrom<Message> for message::Message {
766 type Error = MessageError;
767
768 fn try_from(message: Message) -> Result<Self, Self::Error> {
769 Ok(match message.role {
770 Role::User => message::Message::User {
771 content: message.content.try_map(|content| {
772 Ok(match content {
773 Content::Text { text, .. } => message::UserContent::text(text),
774 Content::ToolResult {
775 tool_use_id,
776 content,
777 ..
778 } => message::UserContent::tool_result(
779 tool_use_id,
780 content.map(|content| content.into()),
781 ),
782 Content::Image { source, .. } => match source {
783 ImageSource::Base64 { data, media_type } => {
784 message::UserContent::Image(message::Image {
785 data: DocumentSourceKind::Base64(data),
786 media_type: Some(media_type.into()),
787 detail: None,
788 additional_params: None,
789 })
790 }
791 ImageSource::Url { url } => {
792 message::UserContent::Image(message::Image {
793 data: DocumentSourceKind::Url(url),
794 media_type: None,
795 detail: None,
796 additional_params: None,
797 })
798 }
799 },
800 Content::Document { source, .. } => match source {
801 DocumentSource::Base64 { data, media_type } => {
802 let rig_media_type = match media_type {
803 DocumentFormat::PDF => message::DocumentMediaType::PDF,
804 };
805 message::UserContent::document(data, Some(rig_media_type))
806 }
807 DocumentSource::Text { data, .. } => message::UserContent::document(
808 data,
809 Some(message::DocumentMediaType::TXT),
810 ),
811 DocumentSource::Url { url } => {
812 message::UserContent::document_url(url, None)
813 }
814 },
815 _ => {
816 return Err(MessageError::ConversionError(
817 "Unsupported content type for User role".to_owned(),
818 ));
819 }
820 })
821 })?,
822 },
823 Role::Assistant => message::Message::Assistant {
824 id: None,
825 content: message.content.try_map(|content| content.try_into())?,
826 },
827 })
828 }
829}
830
/// Anthropic completion model: a client plus a model name and request defaults.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier, e.g. [`CLAUDE_4_SONNET`].
    pub model: String,
    /// Default `max_tokens` derived from the model name by the constructors. `new` yields
    /// `None` for unrecognised names; `with_model` always yields `Some`.
    pub default_max_tokens: Option<u64>,
    /// Set by `with_prompt_caching`. Presumably forwarded as
    /// `AnthropicRequestParams::prompt_caching` — confirm in `completion`.
    pub prompt_caching: bool,
    /// Set by `with_automatic_caching` and `with_automatic_caching_1h`. Presumably
    /// forwarded as `AnthropicRequestParams::automatic_caching`.
    pub automatic_caching: bool,
    /// TTL for automatic caching; `Some(CacheTtl::OneHour)` after
    /// `with_automatic_caching_1h`, otherwise `None`.
    pub automatic_caching_ttl: Option<CacheTtl>,
}
847
848impl<T> CompletionModel<T>
849where
850 T: HttpClientExt,
851{
852 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
853 let model = model.into();
854 let default_max_tokens = calculate_max_tokens(&model);
855
856 Self {
857 client,
858 model,
859 default_max_tokens,
860 prompt_caching: false,
861 automatic_caching: false,
862 automatic_caching_ttl: None,
863 }
864 }
865
866 pub fn with_model(client: Client<T>, model: &str) -> Self {
867 Self {
868 client,
869 model: model.to_string(),
870 default_max_tokens: Some(calculate_max_tokens_custom(model)),
871 prompt_caching: false,
872 automatic_caching: false,
873 automatic_caching_ttl: None,
874 }
875 }
876
877 pub fn with_prompt_caching(mut self) -> Self {
885 self.prompt_caching = true;
886 self
887 }
888
889 pub fn with_automatic_caching(mut self) -> Self {
923 self.automatic_caching = true;
924 self
925 }
926
927 pub fn with_automatic_caching_1h(mut self) -> Self {
944 self.automatic_caching = true;
945 self.automatic_caching_ttl = Some(CacheTtl::OneHour);
946 self
947 }
948}
949
/// Default `max_tokens` per model-name prefix. Prefixes do not overlap; the first match wins.
const MAX_TOKENS_BY_MODEL_PREFIX: &[(&str, u64)] = &[
    ("claude-opus-4", 32_000),
    ("claude-sonnet-4", 64_000),
    ("claude-3-7-sonnet", 64_000),
    ("claude-3-5-sonnet", 8_192),
    ("claude-3-5-haiku", 8_192),
    ("claude-3-opus", 4_096),
    ("claude-3-sonnet", 4_096),
    ("claude-3-haiku", 4_096),
];

/// `max_tokens` used by [`calculate_max_tokens_custom`] for unrecognised model names.
const FALLBACK_MAX_TOKENS: u64 = 2048;

/// Returns the default `max_tokens` for `model`, or `None` if no known prefix matches.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    MAX_TOKENS_BY_MODEL_PREFIX
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, tokens)| tokens)
}

/// Like [`calculate_max_tokens`], but falls back to [`FALLBACK_MAX_TOKENS`] for unknown
/// models.
///
/// Both functions share one table, so their limits cannot drift apart.
fn calculate_max_tokens_custom(model: &str) -> u64 {
    calculate_max_tokens(model).unwrap_or(FALLBACK_MAX_TOKENS)
}
986
/// Request metadata. `user_id` has no `skip_serializing_if`, so `None` serializes as `null`.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}

/// Anthropic tool-choice setting, internally tagged on `type`.
///
/// Variants: `auto` (the Rust `Default`), `any`, `none`, or `tool` with the name of the
/// single tool to use.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    None,
    Tool {
        name: String,
    },
}
1003impl TryFrom<message::ToolChoice> for ToolChoice {
1004 type Error = CompletionError;
1005
1006 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
1007 let res = match value {
1008 message::ToolChoice::Auto => Self::Auto,
1009 message::ToolChoice::None => Self::None,
1010 message::ToolChoice::Required => Self::Any,
1011 message::ToolChoice::Specific { function_names } => {
1012 if function_names.len() != 1 {
1013 return Err(CompletionError::ProviderError(
1014 "Only one tool may be specified to be used by Claude".into(),
1015 ));
1016 }
1017
1018 Self::Tool {
1019 name: function_names.first().unwrap().to_string(),
1020 }
1021 }
1022 };
1023
1024 Ok(res)
1025 }
1026}
1027
1028fn sanitize_schema(schema: &mut serde_json::Value) {
1034 use serde_json::Value;
1035
1036 if let Value::Object(obj) = schema {
1037 let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
1038 || obj.contains_key("properties");
1039
1040 if is_object_schema && !obj.contains_key("additionalProperties") {
1041 obj.insert("additionalProperties".to_string(), Value::Bool(false));
1042 }
1043
1044 if let Some(Value::Object(properties)) = obj.get("properties") {
1045 let prop_keys = properties.keys().cloned().map(Value::String).collect();
1046 obj.insert("required".to_string(), Value::Array(prop_keys));
1047 }
1048
1049 let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
1051 || obj.get("type") == Some(&Value::String("number".to_string()));
1052
1053 if is_numeric_schema {
1054 for key in [
1055 "minimum",
1056 "maximum",
1057 "exclusiveMinimum",
1058 "exclusiveMaximum",
1059 "multipleOf",
1060 ] {
1061 obj.remove(key);
1062 }
1063 }
1064
1065 if let Some(defs) = obj.get_mut("$defs")
1066 && let Value::Object(defs_obj) = defs
1067 {
1068 for (_, def_schema) in defs_obj.iter_mut() {
1069 sanitize_schema(def_schema);
1070 }
1071 }
1072
1073 if let Some(properties) = obj.get_mut("properties")
1074 && let Value::Object(props) = properties
1075 {
1076 for (_, prop_value) in props.iter_mut() {
1077 sanitize_schema(prop_value);
1078 }
1079 }
1080
1081 if let Some(items) = obj.get_mut("items") {
1082 sanitize_schema(items);
1083 }
1084
1085 if let Some(one_of) = obj.remove("oneOf") {
1087 match obj.get_mut("anyOf") {
1088 Some(Value::Array(existing)) => {
1089 if let Value::Array(mut incoming) = one_of {
1090 existing.append(&mut incoming);
1091 }
1092 }
1093 _ => {
1094 obj.insert("anyOf".to_string(), one_of);
1095 }
1096 }
1097 }
1098
1099 for key in ["anyOf", "allOf"] {
1100 if let Some(variants) = obj.get_mut(key)
1101 && let Value::Array(variants_array) = variants
1102 {
1103 for variant in variants_array.iter_mut() {
1104 sanitize_schema(variant);
1105 }
1106 }
1107 }
1108 }
1109}
1110
/// Structured-output format, serialized as `{"type":"json_schema","schema":...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    /// JSON Schema for the response; the request builder passes it through
    /// `sanitize_schema` first.
    JsonSchema { schema: serde_json::Value },
}

/// Value of the request's `output_config` field.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    format: OutputFormat,
}
1125
/// Request body for the Anthropic Messages API, as built from `AnthropicRequestParams`.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    /// Mandatory: building the request fails when the completion request has none.
    max_tokens: u64,
    /// Top-level system prompt blocks; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Tool definitions as raw JSON (typed tools followed by passthrough tools from
    /// `additional_params.tools`); omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    /// Extra provider parameters, flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
    /// Top-level cache marker, set when automatic caching is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<CacheControl>,
}
1150
1151fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1153 match content {
1154 Content::Text { cache_control, .. } => *cache_control = value,
1155 Content::Image { cache_control, .. } => *cache_control = value,
1156 Content::ToolResult { cache_control, .. } => *cache_control = value,
1157 Content::Document { cache_control, .. } => *cache_control = value,
1158 _ => {}
1159 }
1160}
1161
1162pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1167 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1169 *cache_control = Some(CacheControl::ephemeral());
1170 }
1171
1172 for msg in messages.iter_mut() {
1174 for content in msg.content.iter_mut() {
1175 set_content_cache_control(content, None);
1176 }
1177 }
1178
1179 if let Some(last_msg) = messages.last_mut() {
1181 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::ephemeral()));
1182 }
1183}
1184
1185pub(super) fn split_system_messages_from_history(
1186 history: Vec<message::Message>,
1187) -> (Vec<SystemContent>, Vec<message::Message>) {
1188 let mut system = Vec::new();
1189 let mut remaining = Vec::new();
1190
1191 for message in history {
1192 match message {
1193 message::Message::System { content } => {
1194 if !content.is_empty() {
1195 system.push(SystemContent::Text {
1196 text: content,
1197 cache_control: None,
1198 });
1199 }
1200 }
1201 other => remaining.push(other),
1202 }
1203 }
1204
1205 (system, remaining)
1206}
1207
/// Inputs for building an Anthropic request body from a generic completion request.
pub struct AnthropicRequestParams<'a> {
    /// Model identifier to send.
    pub model: &'a str,
    /// The provider-agnostic request to translate.
    pub request: CompletionRequest,
    /// When set, `apply_cache_control` marks the last system block and the last message
    /// block.
    pub prompt_caching: bool,
    /// When set, a top-level ephemeral `cache_control` is added to the request.
    pub automatic_caching: bool,
    /// TTL for the automatic top-level marker; `None` omits `ttl`.
    pub automatic_caching_ttl: Option<CacheTtl>,
}
1218
1219impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1220 type Error = CompletionError;
1221
1222 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1223 let AnthropicRequestParams {
1224 model,
1225 request: mut req,
1226 prompt_caching,
1227 automatic_caching,
1228 automatic_caching_ttl,
1229 } = params;
1230
1231 let Some(max_tokens) = req.max_tokens else {
1233 return Err(CompletionError::RequestError(
1234 "`max_tokens` must be set for Anthropic".into(),
1235 ));
1236 };
1237
1238 let mut full_history = vec![];
1239 if let Some(docs) = req.normalized_documents() {
1240 full_history.push(docs);
1241 }
1242 full_history.extend(req.chat_history);
1243 let (history_system, full_history) = split_system_messages_from_history(full_history);
1244
1245 let mut messages = full_history
1246 .into_iter()
1247 .map(Message::try_from)
1248 .collect::<Result<Vec<Message>, _>>()?;
1249
1250 let mut additional_params_payload = req
1251 .additional_params
1252 .take()
1253 .unwrap_or(serde_json::Value::Null);
1254 let mut additional_tools =
1255 extract_tools_from_additional_params(&mut additional_params_payload)?;
1256
1257 let mut tools = req
1258 .tools
1259 .into_iter()
1260 .map(|tool| ToolDefinition {
1261 name: tool.name,
1262 description: Some(tool.description),
1263 input_schema: tool.parameters,
1264 })
1265 .map(serde_json::to_value)
1266 .collect::<Result<Vec<_>, _>>()?;
1267 tools.append(&mut additional_tools);
1268
1269 let mut system = if let Some(preamble) = req.preamble {
1271 if preamble.is_empty() {
1272 vec![]
1273 } else {
1274 vec![SystemContent::Text {
1275 text: preamble,
1276 cache_control: None,
1277 }]
1278 }
1279 } else {
1280 vec![]
1281 };
1282 system.extend(history_system);
1283
1284 if prompt_caching {
1286 apply_cache_control(&mut system, &mut messages);
1287 }
1288
1289 let output_config = req.output_schema.map(|schema| {
1291 let mut schema_value = schema.to_value();
1292 sanitize_schema(&mut schema_value);
1293 OutputConfig {
1294 format: OutputFormat::JsonSchema {
1295 schema: schema_value,
1296 },
1297 }
1298 });
1299
1300 Ok(Self {
1301 model: model.to_string(),
1302 messages,
1303 max_tokens,
1304 system,
1305 temperature: req.temperature,
1306 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1307 tools,
1308 output_config,
1309 cache_control: if automatic_caching {
1311 Some(CacheControl::Ephemeral {
1312 ttl: automatic_caching_ttl,
1313 })
1314 } else {
1315 None
1316 },
1317 additional_params: if additional_params_payload.is_null() {
1318 None
1319 } else {
1320 Some(additional_params_payload)
1321 },
1322 })
1323 }
1324}
1325
1326fn extract_tools_from_additional_params(
1327 additional_params: &mut serde_json::Value,
1328) -> Result<Vec<serde_json::Value>, CompletionError> {
1329 if let Some(map) = additional_params.as_object_mut()
1330 && let Some(raw_tools) = map.remove("tools")
1331 {
1332 return serde_json::from_value::<Vec<serde_json::Value>>(raw_tools).map_err(|err| {
1333 CompletionError::RequestError(
1334 format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
1335 )
1336 });
1337 }
1338
1339 Ok(Vec::new())
1340}
1341
1342impl<T> completion::CompletionModel for CompletionModel<T>
1343where
1344 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
1345{
1346 type Response = CompletionResponse;
1347 type StreamingResponse = StreamingCompletionResponse;
1348 type Client = Client<T>;
1349
1350 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1351 Self::new(client.clone(), model.into())
1352 }
1353
1354 async fn completion(
1355 &self,
1356 mut completion_request: completion::CompletionRequest,
1357 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1358 let request_model = completion_request
1359 .model
1360 .clone()
1361 .unwrap_or_else(|| self.model.clone());
1362 let span = if tracing::Span::current().is_disabled() {
1363 info_span!(
1364 target: "rig::completions",
1365 "chat",
1366 gen_ai.operation.name = "chat",
1367 gen_ai.provider.name = "anthropic",
1368 gen_ai.request.model = &request_model,
1369 gen_ai.system_instructions = &completion_request.preamble,
1370 gen_ai.response.id = tracing::field::Empty,
1371 gen_ai.response.model = tracing::field::Empty,
1372 gen_ai.usage.output_tokens = tracing::field::Empty,
1373 gen_ai.usage.input_tokens = tracing::field::Empty,
1374 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
1375 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
1376 )
1377 } else {
1378 tracing::Span::current()
1379 };
1380
1381 if completion_request.max_tokens.is_none() {
1383 if let Some(tokens) = self.default_max_tokens {
1384 completion_request.max_tokens = Some(tokens);
1385 } else {
1386 return Err(CompletionError::RequestError(
1387 "`max_tokens` must be set for Anthropic".into(),
1388 ));
1389 }
1390 }
1391
1392 let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
1393 model: &request_model,
1394 request: completion_request,
1395 prompt_caching: self.prompt_caching,
1396 automatic_caching: self.automatic_caching,
1397 automatic_caching_ttl: self.automatic_caching_ttl.clone(),
1398 })?;
1399
1400 if enabled!(Level::TRACE) {
1401 tracing::trace!(
1402 target: "rig::completions",
1403 "Anthropic completion request: {}",
1404 serde_json::to_string_pretty(&request)?
1405 );
1406 }
1407
1408 async move {
1409 let request: Vec<u8> = serde_json::to_vec(&request)?;
1410
1411 let req = self
1412 .client
1413 .post("/v1/messages")?
1414 .body(request)
1415 .map_err(|e| CompletionError::HttpError(e.into()))?;
1416
1417 let response = self
1418 .client
1419 .send::<_, Bytes>(req)
1420 .await
1421 .map_err(CompletionError::HttpError)?;
1422
1423 if response.status().is_success() {
1424 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
1425 response
1426 .into_body()
1427 .await
1428 .map_err(CompletionError::HttpError)?
1429 .to_vec()
1430 .as_slice(),
1431 )? {
1432 ApiResponse::Message(completion) => {
1433 let span = tracing::Span::current();
1434 span.record_response_metadata(&completion);
1435 span.record_token_usage(&completion.usage);
1436 if enabled!(Level::TRACE) {
1437 tracing::trace!(
1438 target: "rig::completions",
1439 "Anthropic completion response: {}",
1440 serde_json::to_string_pretty(&completion)?
1441 );
1442 }
1443 completion.try_into()
1444 }
1445 ApiResponse::Error(ApiErrorResponse { message }) => {
1446 Err(CompletionError::ResponseError(message))
1447 }
1448 }
1449 } else {
1450 let text: String = String::from_utf8_lossy(
1451 &response
1452 .into_body()
1453 .await
1454 .map_err(CompletionError::HttpError)?,
1455 )
1456 .into();
1457 Err(CompletionError::ProviderError(text))
1458 }
1459 }
1460 .instrument(span)
1461 .await
1462 }
1463
1464 async fn stream(
1465 &self,
1466 request: CompletionRequest,
1467 ) -> Result<
1468 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1469 CompletionError,
1470 > {
1471 CompletionModel::stream(self, request).await
1472 }
1473}
1474
1475#[derive(Debug, Deserialize)]
1476struct ApiErrorResponse {
1477 message: String,
1478}
1479
/// Top-level body of a Messages API response, discriminated by its `type`
/// field: `"message"` selects `Message(T)` and `"error"` selects `Error(..)`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// A successful response body (`"type": "message"`).
    Message(T),
    /// An error body (`"type": "error"`).
    Error(ApiErrorResponse),
}
1486
#[cfg(test)]
mod tests {
    // Unit tests for the Anthropic wire types:
    // - message/content (de)serialization;
    // - conversion to and from rig's generic `message` types;
    // - prompt-cache marker placement;
    // - document sources;
    // - reasoning/thinking blocks.
    use super::*;
    use serde_json::json;
    use serde_path_to_error::deserialize;

    // Parses three message shapes and checks every block:
    // - string content;
    // - text + tool_use blocks;
    // - image + text + tool_result blocks.
    // Deserialization goes through `serde_path_to_error`, so a failure reports
    // the offending JSON path.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // A bare string `content` is expected to parse as a single text block.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned(),
                cache_control: None,
            }
        );

        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source, .. } => {
                    assert_eq!(
                        source,
                        ImageSource::Base64 {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // A plain string tool_result `content` becomes a single text item,
            // and the omitted `is_error` defaults to `None`.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    // Converts user, assistant (tool_use), and tool-result messages into rig's
    // generic `message::Message`, checks the converted content, then converts
    // back and asserts each round trip reproduces the original message.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": "/9j/4AAQSkZJRg..."
                        }
                    },
                    {
                        "type": "text",
                        "text": "What is in this image?"
                    },
                    {
                        "type": "document",
                        "source": {
                            "type": "base64",
                            "data": "base64_encoded_pdf_data",
                            "media_type": "application/pdf"
                        }
                    }
                ]
            }
            "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
                cache_control: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                // The assertion pins the current mapping: the base64 PDF payload
                // is exposed as `DocumentSourceKind::String`.
                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Round trip back to the Anthropic types must be lossless.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    // `ContentFormat` <-> `SourceType` conversions must map Url/Base64/String
    // onto URL/BASE64/TEXT in both directions.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let source_type: SourceType = ContentFormat::String.try_into().unwrap();
        assert_eq!(source_type, SourceType::TEXT);

        let content_format: ContentFormat = SourceType::TEXT.into();
        assert_eq!(content_format, ContentFormat::String);
    }

    // Checks `cache_control` serialization and where `apply_cache_control`
    // places cache markers.
    #[test]
    fn test_cache_control_serialization() {
        // A system block with an ephemeral marker serializes `cache_control`.
        let system = SystemContent::Text {
            text: "You are a helpful assistant.".to_string(),
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json = serde_json::to_string(&system).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
        assert!(json.contains(r#""type":"text""#));

        // `cache_control: None` must be omitted from the JSON entirely.
        let system_no_cache = SystemContent::Text {
            text: "Hello".to_string(),
            cache_control: None,
        };
        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
        assert!(!json_no_cache.contains("cache_control"));

        // Message content blocks carry the same marker shape.
        let content = Content::Text {
            text: "Test message".to_string(),
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json_content = serde_json::to_string(&content).unwrap();
        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

        // One system block plus a two-message conversation, all unmarked.
        let mut system_vec = vec![SystemContent::Text {
            text: "System prompt".to_string(),
            cache_control: None,
        }];
        let mut messages = vec![
            Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: "First message".to_string(),
                    cache_control: None,
                }),
            },
            Message {
                role: Role::Assistant,
                content: OneOrMany::one(Content::Text {
                    text: "Response".to_string(),
                    cache_control: None,
                }),
            },
        ];

        apply_cache_control(&mut system_vec, &mut messages);

        // The system block is expected to be marked.
        match &system_vec[0] {
            SystemContent::Text { cache_control, .. } => {
                assert!(cache_control.is_some());
            }
        }

        // The earlier message is expected to stay unmarked...
        for content in messages[0].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_none());
            }
        }

        // ...while the last message is expected to be marked.
        for content in messages[1].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_some());
            }
        }
    }

    // A plain-text document serializes with a `text` source and `text/plain`.
    #[test]
    fn test_plaintext_document_serialization() {
        let content = Content::Document {
            source: DocumentSource::Text {
                data: "Hello, world!".to_string(),
                media_type: PlainTextMediaType::Plain,
            },
            cache_control: None,
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "document");
        assert_eq!(json["source"]["type"], "text");
        assert_eq!(json["source"]["media_type"], "text/plain");
        assert_eq!(json["source"]["data"], "Hello, world!");
    }

    // A `text` document source parses back into `DocumentSource::Text`, and a
    // missing `cache_control` defaults to `None`.
    #[test]
    fn test_plaintext_document_deserialization() {
        let json = r#"
        {
            "type": "document",
            "source": {
                "type": "text",
                "media_type": "text/plain",
                "data": "Hello, world!"
            }
        }
        "#;

        let content: Content = serde_json::from_str(json).unwrap();
        match content {
            Content::Document {
                source,
                cache_control,
            } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Hello, world!".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
                assert_eq!(cache_control, None);
            }
            _ => panic!("Expected Document content"),
        }
    }

    // A base64 PDF document serializes with `base64` / `application/pdf`.
    #[test]
    fn test_base64_pdf_document_serialization() {
        let content = Content::Document {
            source: DocumentSource::Base64 {
                data: "base64data".to_string(),
                media_type: DocumentFormat::PDF,
            },
            cache_control: None,
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "document");
        assert_eq!(json["source"]["type"], "base64");
        assert_eq!(json["source"]["media_type"], "application/pdf");
        assert_eq!(json["source"]["data"], "base64data");
    }

    // A base64 PDF document source parses into `DocumentSource::Base64`.
    #[test]
    fn test_base64_pdf_document_deserialization() {
        let json = r#"
        {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "application/pdf",
                "data": "base64data"
            }
        }
        "#;

        let content: Content = serde_json::from_str(json).unwrap();
        match content {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Base64 {
                        data: "base64data".to_string(),
                        media_type: DocumentFormat::PDF,
                    }
                );
            }
            _ => panic!("Expected Document content"),
        }
    }

    // A rig TXT document becomes an Anthropic `DocumentSource::Text`.
    #[test]
    fn test_plaintext_rig_to_anthropic_conversion() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::document(
                "Some plain text content".to_string(),
                Some(msg::DocumentMediaType::TXT),
            )),
        };

        let anthropic_message: Message = rig_message.try_into().unwrap();
        assert_eq!(anthropic_message.role, Role::User);

        let mut iter = anthropic_message.content.into_iter();
        match iter.next().unwrap() {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Some plain text content".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
            }
            other => panic!("Expected Document content, got: {other:?}"),
        }
    }

    // An Anthropic `DocumentSource::Text` becomes a rig TXT document with
    // string data.
    #[test]
    fn test_plaintext_anthropic_to_rig_conversion() {
        use crate::completion::message as msg;

        let anthropic_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::Document {
                source: DocumentSource::Text {
                    data: "Some plain text content".to_string(),
                    media_type: PlainTextMediaType::Plain,
                },
                cache_control: None,
            }),
        };

        let rig_message: msg::Message = anthropic_message.try_into().unwrap();
        match rig_message {
            msg::Message::User { content } => {
                let mut iter = content.into_iter();
                match iter.next().unwrap() {
                    msg::UserContent::Document(msg::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("Some plain text content".into())
                        );
                        assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
                    }
                    other => panic!("Expected Document content, got: {other:?}"),
                }
            }
            _ => panic!("Expected User message"),
        }
    }

    // rig -> Anthropic -> rig keeps the TXT media type of a document.
    #[test]
    fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
        use crate::completion::message as msg;

        let original = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::document(
                "Round trip text".to_string(),
                Some(msg::DocumentMediaType::TXT),
            )),
        };

        let anthropic: Message = original.clone().try_into().unwrap();
        let back: msg::Message = anthropic.try_into().unwrap();

        match (&original, &back) {
            (
                msg::Message::User {
                    content: orig_content,
                },
                msg::Message::User {
                    content: back_content,
                },
            ) => match (orig_content.first(), back_content.first()) {
                (
                    msg::UserContent::Document(msg::Document {
                        media_type: orig_mt,
                        ..
                    }),
                    msg::UserContent::Document(msg::Document {
                        media_type: back_mt,
                        ..
                    }),
                ) => {
                    assert_eq!(orig_mt, back_mt);
                }
                _ => panic!("Expected Document content in both"),
            },
            _ => panic!("Expected User messages"),
        }
    }

    // Document media types other than PDF and plain text are rejected.
    #[test]
    fn test_unsupported_document_type_returns_error() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
                data: DocumentSourceKind::String("data".into()),
                media_type: Some(msg::DocumentMediaType::HTML),
                additional_params: None,
            })),
        };

        let result: Result<Message, _> = rig_message.try_into();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Anthropic only supports PDF and plain text documents"),
            "Unexpected error: {err}"
        );
    }

    // Plain-text documents referenced by URL are rejected.
    #[test]
    fn test_plaintext_document_url_source_returns_error() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::Document(msg::Document {
                data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
                media_type: Some(msg::DocumentMediaType::TXT),
                additional_params: None,
            })),
        };

        let result: Result<Message, _> = rig_message.try_into();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Only string or base64 data is supported for plain text documents"),
            "Unexpected error: {err}"
        );
    }

    // A plain-text document with a cache marker serializes the marker next to
    // the `text` source.
    #[test]
    fn test_plaintext_document_with_cache_control() {
        let content = Content::Document {
            source: DocumentSource::Text {
                data: "cached text".to_string(),
                media_type: PlainTextMediaType::Plain,
            },
            cache_control: Some(CacheControl::ephemeral()),
        };

        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["source"]["type"], "text");
        assert_eq!(json["source"]["media_type"], "text/plain");
        assert_eq!(json["cache_control"]["type"], "ephemeral");
    }

    // A user message mixing a plain-text document and a text block parses
    // both blocks in order.
    #[test]
    fn test_message_with_plaintext_document_deserialization() {
        let json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "Hello from a text file"
                    }
                },
                {
                    "type": "text",
                    "text": "Summarize this document."
                }
            ]
        }
        "#;

        let message: Message = serde_json::from_str(json).unwrap();
        assert_eq!(message.role, Role::User);
        assert_eq!(message.content.len(), 2);

        let mut iter = message.content.into_iter();

        match iter.next().unwrap() {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Hello from a text file".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
            }
            _ => panic!("Expected Document content"),
        }

        match iter.next().unwrap() {
            Content::Text { text, .. } => {
                assert_eq!(text, "Summarize this document.");
            }
            _ => panic!("Expected Text content"),
        }
    }

    // One multi-part rig `Reasoning` expands into one Anthropic block per part,
    // in order. The expected mapping is:
    // - `Text` (with signature) becomes `Thinking` with that signature;
    // - `Summary` becomes `Thinking` with no signature;
    // - `Redacted` becomes `RedactedThinking`.
    #[test]
    fn test_assistant_reasoning_multiblock_to_anthropic_content() {
        let reasoning = message::Reasoning {
            id: None,
            content: vec![
                message::ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig-1".to_string()),
                },
                message::ReasoningContent::Summary("summary".to_string()),
                message::ReasoningContent::Text {
                    text: "step two".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                message::ReasoningContent::Redacted {
                    data: "redacted block".to_string(),
                },
            ],
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
        };
        let converted: Message = msg.try_into().expect("convert assistant message");
        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();

        assert_eq!(converted.role, Role::Assistant);
        assert_eq!(converted_content.len(), 4);
        assert!(matches!(
            converted_content.first(),
            Some(Content::Thinking { thinking, signature: Some(signature) })
                if thinking == "step one" && signature == "sig-1"
        ));
        assert!(matches!(
            converted_content.get(1),
            Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
        ));
        assert!(matches!(
            converted_content.get(2),
            Some(Content::Thinking { thinking, signature: Some(signature) })
                if thinking == "step two" && signature == "sig-2"
        ));
        assert!(matches!(
            converted_content.get(3),
            Some(Content::RedactedThinking { data }) if data == "redacted block"
        ));
    }

    // An Anthropic `RedactedThinking` block becomes rig reasoning whose first
    // part is `Redacted` with the same data.
    #[test]
    fn test_redacted_thinking_content_to_assistant_reasoning() {
        let content = Content::RedactedThinking {
            data: "opaque-redacted".to_string(),
        };
        let converted: message::AssistantContent =
            content.try_into().expect("convert redacted thinking");

        assert!(matches!(
            converted,
            message::AssistantContent::Reasoning(message::Reasoning { content, .. })
                if matches!(
                    content.first(),
                    Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
                )
        ));
    }

    // rig `Encrypted` reasoning is sent as an Anthropic `RedactedThinking` block.
    #[test]
    fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
        let reasoning = message::Reasoning {
            id: None,
            content: vec![message::ReasoningContent::Encrypted(
                "ciphertext".to_string(),
            )],
        };
        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
        };

        let converted: Message = msg.try_into().expect("convert assistant message");
        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();

        assert_eq!(converted_content.len(), 1);
        assert!(matches!(
            converted_content.first(),
            Some(Content::RedactedThinking { data }) if data == "ciphertext"
        ));
    }
}