1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, Level, enabled, info_span};
20
/// Model identifier for Claude Opus 4.6.
pub const CLAUDE_OPUS_4_6: &str = "claude-opus-4-6";
/// Model identifier for Claude Sonnet 4.6.
pub const CLAUDE_SONNET_4_6: &str = "claude-sonnet-4-6";
/// Model identifier for Claude Haiku 4.5.
pub const CLAUDE_HAIKU_4_5: &str = "claude-haiku-4-5";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Version treated as "latest"; currently an alias for [`ANTHROPIC_VERSION_2023_06_01`].
// NOTE(review): presumably sent as the `anthropic-version` request header by the client;
// confirm in `client.rs`.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
35
/// Body of a non-streaming response from Anthropic's Messages API.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model (text, tool calls, thinking, ...).
    pub content: Vec<Content>,
    /// Response id reported by Anthropic.
    pub id: String,
    /// Model name echoed back in the response.
    pub model: String,
    /// Role of the generated message, kept as a raw string.
    // NOTE(review): presumably always "assistant"; it is not validated here.
    pub role: String,
    /// Reason generation stopped, as a raw string, when reported.
    pub stop_reason: Option<String>,
    /// Matched stop sequence, as a raw string, when reported.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
45}
46
47impl ProviderResponseExt for CompletionResponse {
48 type OutputMessage = Content;
49 type Usage = Usage;
50
51 fn get_response_id(&self) -> Option<String> {
52 Some(self.id.to_owned())
53 }
54
55 fn get_response_model_name(&self) -> Option<String> {
56 Some(self.model.to_owned())
57 }
58
59 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
60 self.content.clone()
61 }
62
63 fn get_text_response(&self) -> Option<String> {
64 let res = self
65 .content
66 .iter()
67 .filter_map(|x| {
68 if let Content::Text { text, .. } = x {
69 Some(text.to_owned())
70 } else {
71 None
72 }
73 })
74 .collect::<Vec<String>>()
75 .join("\n");
76
77 if res.is_empty() { None } else { Some(res) }
78 }
79
80 fn get_usage(&self) -> Option<Self::Usage> {
81 Some(self.usage.clone())
82 }
83}
84
/// Token usage reported by Anthropic for a single request.
///
/// The two cache counters are optional in the payload. They count as `0` when converted to
/// rig's generic usage and are shown as `n/a` by the `Display` impl.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Uncached input tokens.
    pub input_tokens: u64,
    /// Input tokens served from the prompt cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Generated output tokens.
    pub output_tokens: u64,
}
92
93impl std::fmt::Display for Usage {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 write!(
96 f,
97 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
98 self.input_tokens,
99 match self.cache_read_input_tokens {
100 Some(token) => token.to_string(),
101 None => "n/a".to_string(),
102 },
103 match self.cache_creation_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 self.output_tokens
108 )
109 }
110}
111
112impl GetTokenUsage for Usage {
113 fn token_usage(&self) -> Option<crate::completion::Usage> {
114 let mut usage = crate::completion::Usage::new();
115
116 usage.input_tokens = self.input_tokens;
117 usage.output_tokens = self.output_tokens;
118 usage.cached_input_tokens = self.cache_read_input_tokens.unwrap_or_default();
119 usage.cache_creation_input_tokens = self.cache_creation_input_tokens.unwrap_or_default();
120 usage.total_tokens = self.input_tokens
121 + self.cache_read_input_tokens.unwrap_or_default()
122 + self.cache_creation_input_tokens.unwrap_or_default()
123 + self.output_tokens;
124
125 Some(usage)
126 }
127}
128
/// A tool definition in the shape sent in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name the model uses to call the tool.
    pub name: String,
    /// Optional description; serialized as `null` when `None` (no skip attribute).
    pub description: Option<String>,
    /// JSON Schema describing the tool's input.
    pub input_schema: serde_json::Value,
}
135
/// Lifetime of a prompt-cache entry, serialized as `"5m"` or `"1h"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub enum CacheTtl {
    /// Five minutes (`"5m"`); this is the `Default` variant.
    #[default]
    #[serde(rename = "5m")]
    FiveMinutes,
    /// One hour (`"1h"`).
    #[serde(rename = "1h")]
    OneHour,
}
151
/// A `cache_control` marker, serialized as `{"type": "ephemeral"}` plus an optional `ttl`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// Ephemeral cache breakpoint.
    Ephemeral {
        /// Cache lifetime; the key is omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        ttl: Option<CacheTtl>,
    },
}
165
166impl CacheControl {
167 pub fn ephemeral() -> Self {
169 Self::Ephemeral { ttl: None }
170 }
171
172 pub fn ephemeral_1h() -> Self {
174 Self::Ephemeral {
175 ttl: Some(CacheTtl::OneHour),
176 }
177 }
178}
179
/// One block of the request's top-level `system` field.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    /// A text system prompt block (`{"type": "text", ...}`).
    Text {
        text: String,
        /// Optional cache breakpoint; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
190
191impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
192 type Error = CompletionError;
193
194 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
195 let content = response
196 .content
197 .iter()
198 .map(|content| content.clone().try_into())
199 .collect::<Result<Vec<_>, _>>()?;
200
201 let choice = OneOrMany::many(content).map_err(|_| {
202 CompletionError::ResponseError(
203 "Response contained no message or tool call (empty)".to_owned(),
204 )
205 })?;
206
207 let usage = completion::Usage {
208 input_tokens: response.usage.input_tokens,
209 output_tokens: response.usage.output_tokens,
210 total_tokens: response.usage.input_tokens
211 + response.usage.cache_read_input_tokens.unwrap_or(0)
212 + response.usage.cache_creation_input_tokens.unwrap_or(0)
213 + response.usage.output_tokens,
214 cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
215 cache_creation_input_tokens: response.usage.cache_creation_input_tokens.unwrap_or(0),
216 };
217
218 Ok(completion::CompletionResponse {
219 choice,
220 usage,
221 raw_response: response,
222 message_id: None,
223 })
224 }
225}
226
/// A single entry of the request's `messages` array.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the message (`user` or `assistant`).
    pub role: Role,
    /// One or more content blocks.
    // NOTE(review): `string_or_one_or_many` presumably also accepts a bare JSON string and
    // turns it into a text block via `FromStr for Content`; confirm in `one_or_many`.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
233
/// Message author, serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
240
/// A content block inside a [`Message`] or a [`CompletionResponse`].
///
/// Serialized as an object with a snake_case `type` tag (`text`, `image`, `tool_use`,
/// `tool_result`, `document`, `thinking`, `redacted_thinking`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text; `cache_control` is omitted from the JSON when `None`.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image with its source nested under `source`.
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation produced by the model; `input` holds the call arguments.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool call, linked to the call by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document (PDF, plain text, or URL) with its source nested under `source`.
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Visible model reasoning, optionally with a signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Reasoning whose content is carried as an opaque `data` string.
    RedactedThinking {
        data: String,
    },
}
282
283impl FromStr for Content {
284 type Err = Infallible;
285
286 fn from_str(s: &str) -> Result<Self, Self::Err> {
287 Ok(Content::Text {
288 text: s.to_owned(),
289 cache_control: None,
290 })
291 }
292}
293
294#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
295#[serde(tag = "type", rename_all = "snake_case")]
296pub enum ToolResultContent {
297 Text { text: String },
298 Image(ImageSource),
299}
300
301impl FromStr for ToolResultContent {
302 type Err = Infallible;
303
304 fn from_str(s: &str) -> Result<Self, Self::Err> {
305 Ok(ToolResultContent::Text { text: s.to_owned() })
306 }
307}
308
/// Where image bytes come from, serialized with a `type` tag of `"base64"` or `"url"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline base64 data with its MIME type.
    #[serde(rename = "base64")]
    Base64 {
        data: String,
        media_type: ImageFormat,
    },
    /// A URL the image is fetched from.
    #[serde(rename = "url")]
    Url { url: String },
}
327
/// Where document content comes from, serialized with a `type` tag of `"base64"`, `"text"`,
/// or `"url"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
    /// Inline base64 data; only PDF is representable (see [`DocumentFormat`]).
    Base64 {
        data: String,
        media_type: DocumentFormat,
    },
    /// Inline plain text (`text/plain`).
    Text {
        data: String,
        media_type: PlainTextMediaType,
    },
    /// A URL the document is fetched from.
    Url {
        url: String,
    },
}
353
/// Image MIME types, serialized as their full MIME strings (e.g. `"image/png"`).
///
/// Each variant has an explicit `rename`, so the container-level `rename_all` has no effect.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
366
/// MIME types allowed for base64 document sources; only PDF (`"application/pdf"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
379
/// MIME type for text document sources; serialized as `"text/plain"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum PlainTextMediaType {
    #[serde(rename = "text/plain")]
    Plain,
}
390
/// Kind of a content source, serialized in lowercase (`"base64"`, `"url"`, `"text"`).
///
/// Converts to and from `message::ContentFormat` (see the impls below). None of the source
/// enums in this view reference it.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
    TEXT,
}
398
399impl From<String> for Content {
400 fn from(text: String) -> Self {
401 Content::Text {
402 text,
403 cache_control: None,
404 }
405 }
406}
407
408impl From<String> for ToolResultContent {
409 fn from(text: String) -> Self {
410 ToolResultContent::Text { text }
411 }
412}
413
414impl TryFrom<message::ContentFormat> for SourceType {
415 type Error = MessageError;
416
417 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
418 match format {
419 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
420 message::ContentFormat::Url => Ok(SourceType::URL),
421 message::ContentFormat::String => Ok(SourceType::TEXT),
422 }
423 }
424}
425
426impl From<SourceType> for message::ContentFormat {
427 fn from(source_type: SourceType) -> Self {
428 match source_type {
429 SourceType::BASE64 => message::ContentFormat::Base64,
430 SourceType::URL => message::ContentFormat::Url,
431 SourceType::TEXT => message::ContentFormat::String,
432 }
433 }
434}
435
436impl TryFrom<message::ImageMediaType> for ImageFormat {
437 type Error = MessageError;
438
439 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
440 Ok(match media_type {
441 message::ImageMediaType::JPEG => ImageFormat::JPEG,
442 message::ImageMediaType::PNG => ImageFormat::PNG,
443 message::ImageMediaType::GIF => ImageFormat::GIF,
444 message::ImageMediaType::WEBP => ImageFormat::WEBP,
445 _ => {
446 return Err(MessageError::ConversionError(
447 format!("Unsupported image media type: {media_type:?}").to_owned(),
448 ));
449 }
450 })
451 }
452}
453
454impl From<ImageFormat> for message::ImageMediaType {
455 fn from(format: ImageFormat) -> Self {
456 match format {
457 ImageFormat::JPEG => message::ImageMediaType::JPEG,
458 ImageFormat::PNG => message::ImageMediaType::PNG,
459 ImageFormat::GIF => message::ImageMediaType::GIF,
460 ImageFormat::WEBP => message::ImageMediaType::WEBP,
461 }
462 }
463}
464
465impl TryFrom<DocumentMediaType> for DocumentFormat {
466 type Error = MessageError;
467 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
468 match value {
469 DocumentMediaType::PDF => Ok(DocumentFormat::PDF),
470 other => Err(MessageError::ConversionError(format!(
471 "DocumentFormat only supports PDF for base64 sources, got: {}",
472 other.to_mime_type()
473 ))),
474 }
475 }
476}
477
478impl TryFrom<message::AssistantContent> for Content {
479 type Error = MessageError;
480 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
481 match text {
482 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
483 text,
484 cache_control: None,
485 }),
486 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
487 "Anthropic currently doesn't support images.".to_string(),
488 )),
489 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
490 Ok(Content::ToolUse {
491 id,
492 name: function.name,
493 input: function.arguments,
494 })
495 }
496 message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
497 thinking: reasoning.display_text(),
498 signature: reasoning.first_signature().map(str::to_owned),
499 }),
500 }
501 }
502}
503
504fn anthropic_content_from_assistant_content(
505 content: message::AssistantContent,
506) -> Result<Vec<Content>, MessageError> {
507 match content {
508 message::AssistantContent::Text(message::Text { text }) => Ok(vec![Content::Text {
509 text,
510 cache_control: None,
511 }]),
512 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
513 "Anthropic currently doesn't support images.".to_string(),
514 )),
515 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
516 Ok(vec![Content::ToolUse {
517 id,
518 name: function.name,
519 input: function.arguments,
520 }])
521 }
522 message::AssistantContent::Reasoning(reasoning) => {
523 let mut converted = Vec::new();
524 for block in reasoning.content {
525 match block {
526 message::ReasoningContent::Text { text, signature } => {
527 converted.push(Content::Thinking {
528 thinking: text,
529 signature,
530 });
531 }
532 message::ReasoningContent::Summary(summary) => {
533 converted.push(Content::Thinking {
534 thinking: summary,
535 signature: None,
536 });
537 }
538 message::ReasoningContent::Redacted { data }
539 | message::ReasoningContent::Encrypted(data) => {
540 converted.push(Content::RedactedThinking { data });
541 }
542 }
543 }
544
545 if converted.is_empty() {
546 return Err(MessageError::ConversionError(
547 "Cannot convert empty reasoning content to Anthropic format".to_string(),
548 ));
549 }
550
551 Ok(converted)
552 }
553 }
554}
555
/// Converts a rig [`message::Message`] into an Anthropic wire [`Message`].
///
/// * `User`: every content block is converted on its own. Text, tool results, images, and
///   documents are supported; audio and video yield [`MessageError::ConversionError`].
/// * `System`: becomes a `user`-role message holding a single text block. When requests are
///   built in this module, `split_system_messages_from_history` removes system messages
///   first, so this arm acts as a fallback for direct conversions.
/// * `Assistant`: every item is expanded through `anthropic_content_from_assistant_content`,
///   so one reasoning item may become several blocks. A message that converts to zero
///   blocks is rejected.
impl TryFrom<message::Message> for Message {
    type Error = MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        Ok(match message {
            message::Message::User { content } => Message {
                role: Role::User,
                content: content.try_map(|content| match content {
                    message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
                        text,
                        cache_control: None,
                    }),
                    // Other `ToolResult` fields are ignored; `is_error` and `cache_control`
                    // are never set by this conversion.
                    message::UserContent::ToolResult(message::ToolResult {
                        id, content, ..
                    }) => Ok(Content::ToolResult {
                        tool_use_id: id,
                        content: content.try_map(|content| match content {
                            message::ToolResultContent::Text(message::Text { text }) => {
                                Ok(ToolResultContent::Text { text })
                            }
                            // Tool-result images must be base64 with a known media type. URL
                            // images are rejected here even though `ImageSource::Url` exists.
                            message::ToolResultContent::Image(image) => {
                                let DocumentSourceKind::Base64(data) = image.data else {
                                    return Err(MessageError::ConversionError(
                                        "Only base64 strings can be used with the Anthropic API"
                                            .to_string(),
                                    ));
                                };
                                let media_type =
                                    image.media_type.ok_or(MessageError::ConversionError(
                                        "Image media type is required".to_owned(),
                                    ))?;
                                Ok(ToolResultContent::Image(ImageSource::Base64 {
                                    data,
                                    media_type: media_type.try_into()?,
                                }))
                            }
                        })?,
                        is_error: None,
                        cache_control: None,
                    }),
                    // Base64 images require a media type; URLs pass through unchanged; any
                    // other source kind is rejected.
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        let source = match data {
                            DocumentSourceKind::Base64(data) => {
                                let media_type =
                                    media_type.ok_or(MessageError::ConversionError(
                                        "Image media type is required for Claude API".to_string(),
                                    ))?;
                                ImageSource::Base64 {
                                    data,
                                    media_type: ImageFormat::try_from(media_type)?,
                                }
                            }
                            DocumentSourceKind::Url(url) => ImageSource::Url { url },
                            DocumentSourceKind::Unknown => {
                                return Err(MessageError::ConversionError(
                                    "Image content has no body".into(),
                                ));
                            }
                            doc => {
                                return Err(MessageError::ConversionError(format!(
                                    "Unsupported document type: {doc:?}"
                                )));
                            }
                        };

                        Ok(Content::Image {
                            source,
                            cache_control: None,
                        })
                    }
                    // Documents need an explicit media type; only PDF and plain text are
                    // accepted, and URL sources are rejected for both.
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        let media_type = media_type.ok_or(MessageError::ConversionError(
                            "Document media type is required".to_string(),
                        ))?;

                        let source = match media_type {
                            // NOTE(review): `String` data is sent as-is as a base64 PDF source,
                            // so it presumably must already be base64-encoded; confirm callers.
                            DocumentMediaType::PDF => {
                                let data = match data {
                                    DocumentSourceKind::Base64(data)
                                    | DocumentSourceKind::String(data) => data,
                                    _ => {
                                        return Err(MessageError::ConversionError(
                                            "Only base64 encoded data is supported for PDF documents".into(),
                                        ));
                                    }
                                };
                                DocumentSource::Base64 {
                                    data,
                                    media_type: DocumentFormat::PDF,
                                }
                            }
                            // NOTE(review): `Base64` data is placed verbatim into a `text`
                            // source without decoding, so the model would see the encoded
                            // string; confirm this is intended.
                            DocumentMediaType::TXT => {
                                let data = match data {
                                    DocumentSourceKind::String(data)
                                    | DocumentSourceKind::Base64(data) => data,
                                    _ => {
                                        return Err(MessageError::ConversionError(
                                            "Only string or base64 data is supported for plain text documents".into(),
                                        ));
                                    }
                                };
                                DocumentSource::Text {
                                    data,
                                    media_type: PlainTextMediaType::Plain,
                                }
                            }
                            other => {
                                return Err(MessageError::ConversionError(format!(
                                    "Anthropic only supports PDF and plain text documents, got: {}",
                                    other.to_mime_type()
                                )));
                            }
                        };

                        Ok(Content::Document {
                            source,
                            cache_control: None,
                        })
                    }
                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
                        "Audio is not supported in Anthropic".to_owned(),
                    )),
                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
                        "Video is not supported in Anthropic".to_owned(),
                    )),
                })?,
            },

            // Anthropic messages have no system role; this fallback sends the text as a user
            // turn.
            message::Message::System { content } => Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: content,
                    cache_control: None,
                }),
            },

            message::Message::Assistant { content, .. } => {
                // Flatten: each assistant item may expand into several Anthropic blocks.
                let converted_content = content.into_iter().try_fold(
                    Vec::new(),
                    |mut accumulated, assistant_content| {
                        accumulated
                            .extend(anthropic_content_from_assistant_content(assistant_content)?);
                        Ok::<Vec<Content>, MessageError>(accumulated)
                    },
                )?;

                Message {
                    content: OneOrMany::many(converted_content).map_err(|_| {
                        MessageError::ConversionError(
                            "Assistant message did not contain Anthropic-compatible content"
                                .to_owned(),
                        )
                    })?,
                    role: Role::Assistant,
                }
            }
        })
    }
}
719
720impl TryFrom<Content> for message::AssistantContent {
721 type Error = MessageError;
722
723 fn try_from(content: Content) -> Result<Self, Self::Error> {
724 Ok(match content {
725 Content::Text { text, .. } => message::AssistantContent::text(text),
726 Content::ToolUse { id, name, input } => {
727 message::AssistantContent::tool_call(id, name, input)
728 }
729 Content::Thinking {
730 thinking,
731 signature,
732 } => message::AssistantContent::Reasoning(Reasoning::new_with_signature(
733 &thinking, signature,
734 )),
735 Content::RedactedThinking { data } => {
736 message::AssistantContent::Reasoning(Reasoning::redacted(data))
737 }
738 _ => {
739 return Err(MessageError::ConversionError(
740 "Content did not contain a message, tool call, or reasoning".to_owned(),
741 ));
742 }
743 })
744 }
745}
746
747impl From<ToolResultContent> for message::ToolResultContent {
748 fn from(content: ToolResultContent) -> Self {
749 match content {
750 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
751 ToolResultContent::Image(source) => match source {
752 ImageSource::Base64 { data, media_type } => {
753 message::ToolResultContent::image_base64(data, Some(media_type.into()), None)
754 }
755 ImageSource::Url { url } => message::ToolResultContent::image_url(url, None, None),
756 },
757 }
758 }
759}
760
761impl TryFrom<Message> for message::Message {
762 type Error = MessageError;
763
764 fn try_from(message: Message) -> Result<Self, Self::Error> {
765 Ok(match message.role {
766 Role::User => message::Message::User {
767 content: message.content.try_map(|content| {
768 Ok(match content {
769 Content::Text { text, .. } => message::UserContent::text(text),
770 Content::ToolResult {
771 tool_use_id,
772 content,
773 ..
774 } => message::UserContent::tool_result(
775 tool_use_id,
776 content.map(|content| content.into()),
777 ),
778 Content::Image { source, .. } => match source {
779 ImageSource::Base64 { data, media_type } => {
780 message::UserContent::Image(message::Image {
781 data: DocumentSourceKind::Base64(data),
782 media_type: Some(media_type.into()),
783 detail: None,
784 additional_params: None,
785 })
786 }
787 ImageSource::Url { url } => {
788 message::UserContent::Image(message::Image {
789 data: DocumentSourceKind::Url(url),
790 media_type: None,
791 detail: None,
792 additional_params: None,
793 })
794 }
795 },
796 Content::Document { source, .. } => match source {
797 DocumentSource::Base64 { data, media_type } => {
798 let rig_media_type = match media_type {
799 DocumentFormat::PDF => message::DocumentMediaType::PDF,
800 };
801 message::UserContent::document(data, Some(rig_media_type))
802 }
803 DocumentSource::Text { data, .. } => message::UserContent::document(
804 data,
805 Some(message::DocumentMediaType::TXT),
806 ),
807 DocumentSource::Url { url } => {
808 message::UserContent::document_url(url, None)
809 }
810 },
811 _ => {
812 return Err(MessageError::ConversionError(
813 "Unsupported content type for User role".to_owned(),
814 ));
815 }
816 })
817 })?,
818 },
819 Role::Assistant => message::Message::Assistant {
820 id: None,
821 content: message.content.try_map(|content| content.try_into())?,
822 },
823 })
824 }
825}
826
/// Anthropic completion model: an HTTP client plus model name and caching options.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier, e.g. [`CLAUDE_SONNET_4_6`].
    pub model: String,
    /// Default `max_tokens` derived from the model name.
    // NOTE(review): presumably used when a request leaves `max_tokens` unset; that wiring
    // lives in `completion`, outside this view.
    pub default_max_tokens: Option<u64>,
    /// Whether per-block cache breakpoints are placed (see `apply_cache_control`).
    pub prompt_caching: bool,
    /// Whether a top-level `cache_control` marker is added to the request body.
    pub automatic_caching: bool,
    /// TTL for the top-level marker; the `ttl` key is omitted when `None`.
    pub automatic_caching_ttl: Option<CacheTtl>,
}
843
844impl<T> CompletionModel<T>
845where
846 T: HttpClientExt,
847{
848 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
849 let model = model.into();
850 let default_max_tokens = default_max_tokens_for_model(&model);
851
852 Self {
853 client,
854 model,
855 default_max_tokens,
856 prompt_caching: false,
857 automatic_caching: false,
858 automatic_caching_ttl: None,
859 }
860 }
861
862 pub fn with_model(client: Client<T>, model: &str) -> Self {
863 Self {
864 client,
865 model: model.to_string(),
866 default_max_tokens: Some(default_max_tokens_with_fallback(model)),
867 prompt_caching: false,
868 automatic_caching: false,
869 automatic_caching_ttl: None,
870 }
871 }
872
873 pub fn with_prompt_caching(mut self) -> Self {
881 self.prompt_caching = true;
882 self
883 }
884
885 pub fn with_automatic_caching(mut self) -> Self {
918 self.automatic_caching = true;
919 self
920 }
921
922 pub fn with_automatic_caching_1h(mut self) -> Self {
939 self.automatic_caching = true;
940 self.automatic_caching_ttl = Some(CacheTtl::OneHour);
941 self
942 }
943}
944
/// Default `max_tokens` for a model, matched by model-name prefix.
///
/// - `claude-opus-4-6*` → `128_000`
/// - other `claude-opus-4*`, `claude-sonnet-4*`, and `claude-haiku-4-5*` → `64_000`
/// - anything else → `None`
fn default_max_tokens_for_model(model: &str) -> Option<u64> {
    const SIXTY_FOUR_K_PREFIXES: [&str; 3] =
        ["claude-opus-4", "claude-sonnet-4", "claude-haiku-4-5"];

    // Checked first because it also matches the broader `claude-opus-4` prefix.
    if model.starts_with("claude-opus-4-6") {
        return Some(128_000);
    }

    SIXTY_FOUR_K_PREFIXES
        .iter()
        .any(|prefix| model.starts_with(prefix))
        .then_some(64_000)
}
960
961fn default_max_tokens_with_fallback(model: &str) -> u64 {
962 default_max_tokens_for_model(model).unwrap_or(2_048)
963}
964
/// Request `metadata` object.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    /// Opaque end-user identifier. The field is private and has no skip attribute, so it
    /// serializes as `null` when `None`.
    user_id: Option<String>,
}
969
/// Anthropic `tool_choice`, serialized with a snake_case `type` tag (`auto`, `any`, `none`,
/// `tool`).
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// The model decides whether to call a tool (the `Default`).
    #[default]
    Auto,
    /// The model must call some tool.
    Any,
    /// The model must not call a tool.
    None,
    /// The model must call the named tool.
    Tool {
        name: String,
    },
}
981impl TryFrom<message::ToolChoice> for ToolChoice {
982 type Error = CompletionError;
983
984 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
985 let res = match value {
986 message::ToolChoice::Auto => Self::Auto,
987 message::ToolChoice::None => Self::None,
988 message::ToolChoice::Required => Self::Any,
989 message::ToolChoice::Specific { function_names } => {
990 if function_names.len() != 1 {
991 return Err(CompletionError::ProviderError(
992 "Only one tool may be specified to be used by Claude".into(),
993 ));
994 }
995
996 Self::Tool {
997 name: function_names.first().unwrap().to_string(),
998 }
999 }
1000 };
1001
1002 Ok(res)
1003 }
1004}
1005
1006fn sanitize_schema(schema: &mut serde_json::Value) {
1012 use serde_json::Value;
1013
1014 if let Value::Object(obj) = schema {
1015 let is_object_schema = obj.get("type") == Some(&Value::String("object".to_string()))
1016 || obj.contains_key("properties");
1017
1018 if is_object_schema && !obj.contains_key("additionalProperties") {
1019 obj.insert("additionalProperties".to_string(), Value::Bool(false));
1020 }
1021
1022 if let Some(Value::Object(properties)) = obj.get("properties") {
1023 let prop_keys = properties.keys().cloned().map(Value::String).collect();
1024 obj.insert("required".to_string(), Value::Array(prop_keys));
1025 }
1026
1027 let is_numeric_schema = obj.get("type") == Some(&Value::String("integer".to_string()))
1029 || obj.get("type") == Some(&Value::String("number".to_string()));
1030
1031 if is_numeric_schema {
1032 for key in [
1033 "minimum",
1034 "maximum",
1035 "exclusiveMinimum",
1036 "exclusiveMaximum",
1037 "multipleOf",
1038 ] {
1039 obj.remove(key);
1040 }
1041 }
1042
1043 if let Some(defs) = obj.get_mut("$defs")
1044 && let Value::Object(defs_obj) = defs
1045 {
1046 for (_, def_schema) in defs_obj.iter_mut() {
1047 sanitize_schema(def_schema);
1048 }
1049 }
1050
1051 if let Some(properties) = obj.get_mut("properties")
1052 && let Value::Object(props) = properties
1053 {
1054 for (_, prop_value) in props.iter_mut() {
1055 sanitize_schema(prop_value);
1056 }
1057 }
1058
1059 if let Some(items) = obj.get_mut("items") {
1060 sanitize_schema(items);
1061 }
1062
1063 if let Some(one_of) = obj.remove("oneOf") {
1065 match obj.get_mut("anyOf") {
1066 Some(Value::Array(existing)) => {
1067 if let Value::Array(mut incoming) = one_of {
1068 existing.append(&mut incoming);
1069 }
1070 }
1071 _ => {
1072 obj.insert("anyOf".to_string(), one_of);
1073 }
1074 }
1075 }
1076
1077 for key in ["anyOf", "allOf"] {
1078 if let Some(variants) = obj.get_mut(key)
1079 && let Value::Array(variants_array) = variants
1080 {
1081 for variant in variants_array.iter_mut() {
1082 sanitize_schema(variant);
1083 }
1084 }
1085 }
1086 }
1087}
1088
/// Structured-output format, serialized as `{"type": "json_schema", "schema": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    JsonSchema { schema: serde_json::Value },
}
1097
/// The request's `output_config` object, wrapping the structured-output format.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    format: OutputFormat,
}
1103
/// JSON body of a Messages API request.
///
/// Empty `system`/`tools` vectors and `None` options are omitted from the JSON.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u64,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Tool definitions as raw JSON (typed tools plus any from `additional_params.tools`).
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    /// Provider-specific extras, flattened into the top level of the body.
    // NOTE(review): flattening presumably requires a JSON object here. Keys that collide with
    // the explicit fields above would be emitted twice; confirm callers avoid that.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
    /// Top-level automatic-caching marker.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<CacheControl>,
}
1128
1129fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
1131 match content {
1132 Content::Text { cache_control, .. } => *cache_control = value,
1133 Content::Image { cache_control, .. } => *cache_control = value,
1134 Content::ToolResult { cache_control, .. } => *cache_control = value,
1135 Content::Document { cache_control, .. } => *cache_control = value,
1136 _ => {}
1137 }
1138}
1139
1140pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
1145 if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
1147 *cache_control = Some(CacheControl::ephemeral());
1148 }
1149
1150 for msg in messages.iter_mut() {
1152 for content in msg.content.iter_mut() {
1153 set_content_cache_control(content, None);
1154 }
1155 }
1156
1157 if let Some(last_msg) = messages.last_mut() {
1159 set_content_cache_control(last_msg.content.last_mut(), Some(CacheControl::ephemeral()));
1160 }
1161}
1162
1163pub(super) fn split_system_messages_from_history(
1164 history: Vec<message::Message>,
1165) -> (Vec<SystemContent>, Vec<message::Message>) {
1166 let mut system = Vec::new();
1167 let mut remaining = Vec::new();
1168
1169 for message in history {
1170 match message {
1171 message::Message::System { content } => {
1172 if !content.is_empty() {
1173 system.push(SystemContent::Text {
1174 text: content,
1175 cache_control: None,
1176 });
1177 }
1178 }
1179 other => remaining.push(other),
1180 }
1181 }
1182
1183 (system, remaining)
1184}
1185
/// Inputs for building an `AnthropicCompletionRequest` through its `TryFrom` impl.
pub struct AnthropicRequestParams<'a> {
    /// Model identifier written into the request body.
    pub model: &'a str,
    /// The generic rig request to translate.
    pub request: CompletionRequest,
    /// When `true`, `apply_cache_control` places breakpoints on the last system block and
    /// the last message block.
    pub prompt_caching: bool,
    /// When `true`, a top-level `cache_control` marker is added to the body.
    pub automatic_caching: bool,
    /// TTL for the top-level marker; the `ttl` key is omitted when `None`.
    pub automatic_caching_ttl: Option<CacheTtl>,
}
1196
1197impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
1198 type Error = CompletionError;
1199
1200 fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
1201 let AnthropicRequestParams {
1202 model,
1203 request: mut req,
1204 prompt_caching,
1205 automatic_caching,
1206 automatic_caching_ttl,
1207 } = params;
1208
1209 let Some(max_tokens) = req.max_tokens else {
1211 return Err(CompletionError::RequestError(
1212 "`max_tokens` must be set for Anthropic".into(),
1213 ));
1214 };
1215
1216 let mut full_history = vec![];
1217 if let Some(docs) = req.normalized_documents() {
1218 full_history.push(docs);
1219 }
1220 full_history.extend(req.chat_history);
1221 let (history_system, full_history) = split_system_messages_from_history(full_history);
1222
1223 let mut messages = full_history
1224 .into_iter()
1225 .map(Message::try_from)
1226 .collect::<Result<Vec<Message>, _>>()?;
1227
1228 let mut additional_params_payload = req
1229 .additional_params
1230 .take()
1231 .unwrap_or(serde_json::Value::Null);
1232 let mut additional_tools =
1233 extract_tools_from_additional_params(&mut additional_params_payload)?;
1234
1235 let mut tools = req
1236 .tools
1237 .into_iter()
1238 .map(|tool| ToolDefinition {
1239 name: tool.name,
1240 description: Some(tool.description),
1241 input_schema: tool.parameters,
1242 })
1243 .map(serde_json::to_value)
1244 .collect::<Result<Vec<_>, _>>()?;
1245 tools.append(&mut additional_tools);
1246
1247 let mut system = if let Some(preamble) = req.preamble {
1249 if preamble.is_empty() {
1250 vec![]
1251 } else {
1252 vec![SystemContent::Text {
1253 text: preamble,
1254 cache_control: None,
1255 }]
1256 }
1257 } else {
1258 vec![]
1259 };
1260 system.extend(history_system);
1261
1262 if prompt_caching {
1264 apply_cache_control(&mut system, &mut messages);
1265 }
1266
1267 let output_config = if let Some(schema) = req.output_schema {
1268 let mut schema_value = schema.to_value();
1269 sanitize_schema(&mut schema_value);
1270 Some(OutputConfig {
1271 format: OutputFormat::JsonSchema {
1272 schema: schema_value,
1273 },
1274 })
1275 } else {
1276 None
1277 };
1278
1279 Ok(Self {
1280 model: model.to_string(),
1281 messages,
1282 max_tokens,
1283 system,
1284 temperature: req.temperature,
1285 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
1286 tools,
1287 output_config,
1288 cache_control: if automatic_caching {
1290 Some(CacheControl::Ephemeral {
1291 ttl: automatic_caching_ttl,
1292 })
1293 } else {
1294 None
1295 },
1296 additional_params: if additional_params_payload.is_null() {
1297 None
1298 } else {
1299 Some(additional_params_payload)
1300 },
1301 })
1302 }
1303}
1304
1305fn extract_tools_from_additional_params(
1306 additional_params: &mut serde_json::Value,
1307) -> Result<Vec<serde_json::Value>, CompletionError> {
1308 if let Some(map) = additional_params.as_object_mut()
1309 && let Some(raw_tools) = map.remove("tools")
1310 {
1311 return serde_json::from_value::<Vec<serde_json::Value>>(raw_tools).map_err(|err| {
1312 CompletionError::RequestError(
1313 format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
1314 )
1315 });
1316 }
1317
1318 Ok(Vec::new())
1319}
1320
1321impl<T> completion::CompletionModel for CompletionModel<T>
1322where
1323 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
1324{
1325 type Response = CompletionResponse;
1326 type StreamingResponse = StreamingCompletionResponse;
1327 type Client = Client<T>;
1328
1329 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1330 Self::new(client.clone(), model.into())
1331 }
1332
1333 async fn completion(
1334 &self,
1335 mut completion_request: completion::CompletionRequest,
1336 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1337 let request_model = completion_request
1338 .model
1339 .clone()
1340 .unwrap_or_else(|| self.model.clone());
1341 let span = if tracing::Span::current().is_disabled() {
1342 info_span!(
1343 target: "rig::completions",
1344 "chat",
1345 gen_ai.operation.name = "chat",
1346 gen_ai.provider.name = "anthropic",
1347 gen_ai.request.model = &request_model,
1348 gen_ai.system_instructions = &completion_request.preamble,
1349 gen_ai.response.id = tracing::field::Empty,
1350 gen_ai.response.model = tracing::field::Empty,
1351 gen_ai.usage.output_tokens = tracing::field::Empty,
1352 gen_ai.usage.input_tokens = tracing::field::Empty,
1353 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
1354 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
1355 )
1356 } else {
1357 tracing::Span::current()
1358 };
1359
1360 if completion_request.max_tokens.is_none() {
1362 if let Some(tokens) = self.default_max_tokens {
1363 completion_request.max_tokens = Some(tokens);
1364 } else {
1365 return Err(CompletionError::RequestError(
1366 "`max_tokens` must be set for Anthropic".into(),
1367 ));
1368 }
1369 }
1370
1371 let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
1372 model: &request_model,
1373 request: completion_request,
1374 prompt_caching: self.prompt_caching,
1375 automatic_caching: self.automatic_caching,
1376 automatic_caching_ttl: self.automatic_caching_ttl.clone(),
1377 })?;
1378
1379 if enabled!(Level::TRACE) {
1380 tracing::trace!(
1381 target: "rig::completions",
1382 "Anthropic completion request: {}",
1383 serde_json::to_string_pretty(&request)?
1384 );
1385 }
1386
1387 async move {
1388 let request: Vec<u8> = serde_json::to_vec(&request)?;
1389
1390 let req = self
1391 .client
1392 .post("/v1/messages")?
1393 .body(request)
1394 .map_err(|e| CompletionError::HttpError(e.into()))?;
1395
1396 let response = self
1397 .client
1398 .send::<_, Bytes>(req)
1399 .await
1400 .map_err(CompletionError::HttpError)?;
1401
1402 if response.status().is_success() {
1403 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
1404 response
1405 .into_body()
1406 .await
1407 .map_err(CompletionError::HttpError)?
1408 .to_vec()
1409 .as_slice(),
1410 )? {
1411 ApiResponse::Message(completion) => {
1412 let span = tracing::Span::current();
1413 span.record_response_metadata(&completion);
1414 span.record_token_usage(&completion.usage);
1415 if enabled!(Level::TRACE) {
1416 tracing::trace!(
1417 target: "rig::completions",
1418 "Anthropic completion response: {}",
1419 serde_json::to_string_pretty(&completion)?
1420 );
1421 }
1422 completion.try_into()
1423 }
1424 ApiResponse::Error(ApiErrorResponse { message }) => {
1425 Err(CompletionError::ResponseError(message))
1426 }
1427 }
1428 } else {
1429 let text: String = String::from_utf8_lossy(
1430 &response
1431 .into_body()
1432 .await
1433 .map_err(CompletionError::HttpError)?,
1434 )
1435 .into();
1436 Err(CompletionError::ProviderError(text))
1437 }
1438 }
1439 .instrument(span)
1440 .await
1441 }
1442
1443 async fn stream(
1444 &self,
1445 request: CompletionRequest,
1446 ) -> Result<
1447 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1448 CompletionError,
1449 > {
1450 CompletionModel::stream(self, request).await
1451 }
1452}
1453
1454#[derive(Debug, Deserialize)]
1455struct ApiErrorResponse {
1456 message: String,
1457}
1458
1459#[derive(Debug, Deserialize)]
1460#[serde(tag = "type", rename_all = "snake_case")]
1461enum ApiResponse<T> {
1462 Message(T),
1463 Error(ApiErrorResponse),
1464}
1465
1466#[cfg(test)]
1467mod tests {
1468 use super::*;
1469 use serde_json::json;
1470 use serde_path_to_error::deserialize;
1471
    /// Deserializes Anthropic `Message` JSON in three shapes:
    /// - an assistant message whose `content` is a bare string;
    /// - an assistant message with `text` and `tool_use` blocks;
    /// - a user message mixing `image`, `text` and `tool_result` blocks.
    ///
    /// `serde_path_to_error::deserialize` is used so a failure reports the offending JSON path.
    #[test]
    fn test_deserialize_message() {
        // `content` given as a plain string instead of an array of blocks.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant turn: a text block followed by a tool call.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        // User turn: a base64 image, a question, and a tool result whose `content` is a string.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // The bare-string form yields a first content block that is uncached text.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned(),
                cache_control: None,
            }
        );

        // Blocks keep their order: text first, then the tool call.
        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source, .. } => {
                    assert_eq!(
                        source,
                        ImageSource::Base64 {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text, .. } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    // The string tool-result content surfaces as a text item.
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    // `is_error` is absent from the payload, so it stays `None`.
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }
1630
    /// Converts Anthropic user, assistant (tool call) and tool-result messages into rig
    /// `message::Message`s, checks the converted content, then converts back and expects
    /// each message to equal its original.
    #[test]
    fn test_message_to_message_conversion() {
        // User message with an image, a question and a base64 PDF document.
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        // Tool results travel in a user-role message.
        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
                cache_control: None,
            }),
        };

        // Originals are cloned so they can be compared after the round trip below.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        // The base64 PDF payload is expected as `DocumentSourceKind::String`.
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                // Anthropic's `tool_use_id` becomes the rig tool result `id`.
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                // Anthropic `tool_use` maps to a rig tool call; `input` becomes `function.arguments`.
                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original Anthropic messages exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
1774
    /// `ContentFormat` and `SourceType` convert into each other for each supported pair:
    /// `Url` <-> `URL`, `Base64` <-> `BASE64`, `String` <-> `TEXT`.
    /// `ContentFormat` -> `SourceType` is fallible (`try_into`); the reverse uses `into`.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let source_type: SourceType = ContentFormat::String.try_into().unwrap();
        assert_eq!(source_type, SourceType::TEXT);

        let content_format: ContentFormat = SourceType::TEXT.into();
        assert_eq!(content_format, ContentFormat::String);
    }
1797
    /// Checks two things:
    /// - how `cache_control` serializes on system and message content;
    /// - what `apply_cache_control` does to a system prompt plus a two-message history:
    ///   it marks the system prompt and the second message, and leaves the first unmarked.
    #[test]
    fn test_cache_control_serialization() {
        // System content carrying an ephemeral cache marker.
        let system = SystemContent::Text {
            text: "You are a helpful assistant.".to_string(),
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json = serde_json::to_string(&system).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
        assert!(json.contains(r#""type":"text""#));

        // Without a marker, the `cache_control` key is omitted from the JSON.
        let system_no_cache = SystemContent::Text {
            text: "Hello".to_string(),
            cache_control: None,
        };
        let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
        assert!(!json_no_cache.contains("cache_control"));

        // Message content serializes the marker the same way.
        let content = Content::Text {
            text: "Test message".to_string(),
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json_content = serde_json::to_string(&content).unwrap();
        assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));

        // Start from unmarked content and let `apply_cache_control` place the markers.
        let mut system_vec = vec![SystemContent::Text {
            text: "System prompt".to_string(),
            cache_control: None,
        }];
        let mut messages = vec![
            Message {
                role: Role::User,
                content: OneOrMany::one(Content::Text {
                    text: "First message".to_string(),
                    cache_control: None,
                }),
            },
            Message {
                role: Role::Assistant,
                content: OneOrMany::one(Content::Text {
                    text: "Response".to_string(),
                    cache_control: None,
                }),
            },
        ];

        apply_cache_control(&mut system_vec, &mut messages);

        // The system prompt is marked.
        match &system_vec[0] {
            SystemContent::Text { cache_control, .. } => {
                assert!(cache_control.is_some());
            }
        }

        // The first message stays unmarked.
        for content in messages[0].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_none());
            }
        }

        // The second (final) message is marked.
        for content in messages[1].content.iter() {
            if let Content::Text { cache_control, .. } = content {
                assert!(cache_control.is_some());
            }
        }
    }
1871
1872 #[test]
1873 fn test_plaintext_document_serialization() {
1874 let content = Content::Document {
1875 source: DocumentSource::Text {
1876 data: "Hello, world!".to_string(),
1877 media_type: PlainTextMediaType::Plain,
1878 },
1879 cache_control: None,
1880 };
1881
1882 let json = serde_json::to_value(&content).unwrap();
1883 assert_eq!(json["type"], "document");
1884 assert_eq!(json["source"]["type"], "text");
1885 assert_eq!(json["source"]["media_type"], "text/plain");
1886 assert_eq!(json["source"]["data"], "Hello, world!");
1887 }
1888
    /// A `document` block with a `text` source deserializes into
    /// `Content::Document { source: DocumentSource::Text { .. }, .. }`. With no
    /// `cache_control` key in the JSON, `cache_control` is `None`.
    #[test]
    fn test_plaintext_document_deserialization() {
        let json = r#"
        {
            "type": "document",
            "source": {
                "type": "text",
                "media_type": "text/plain",
                "data": "Hello, world!"
            }
        }
        "#;

        let content: Content = serde_json::from_str(json).unwrap();
        match content {
            Content::Document {
                source,
                cache_control,
            } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Hello, world!".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
                assert_eq!(cache_control, None);
            }
            _ => panic!("Expected Document content"),
        }
    }
1920
1921 #[test]
1922 fn test_base64_pdf_document_serialization() {
1923 let content = Content::Document {
1924 source: DocumentSource::Base64 {
1925 data: "base64data".to_string(),
1926 media_type: DocumentFormat::PDF,
1927 },
1928 cache_control: None,
1929 };
1930
1931 let json = serde_json::to_value(&content).unwrap();
1932 assert_eq!(json["type"], "document");
1933 assert_eq!(json["source"]["type"], "base64");
1934 assert_eq!(json["source"]["media_type"], "application/pdf");
1935 assert_eq!(json["source"]["data"], "base64data");
1936 }
1937
    /// A `document` block with a base64 `application/pdf` source deserializes into
    /// `DocumentSource::Base64` with `DocumentFormat::PDF`.
    #[test]
    fn test_base64_pdf_document_deserialization() {
        let json = r#"
        {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "application/pdf",
                "data": "base64data"
            }
        }
        "#;

        let content: Content = serde_json::from_str(json).unwrap();
        match content {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Base64 {
                        data: "base64data".to_string(),
                        media_type: DocumentFormat::PDF,
                    }
                );
            }
            _ => panic!("Expected Document content"),
        }
    }
1965
    /// A rig user document tagged `DocumentMediaType::TXT` converts into an Anthropic
    /// document whose source is `DocumentSource::Text` with `text/plain`.
    #[test]
    fn test_plaintext_rig_to_anthropic_conversion() {
        use crate::completion::message as msg;

        let rig_message = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::document(
                "Some plain text content".to_string(),
                Some(msg::DocumentMediaType::TXT),
            )),
        };

        let anthropic_message: Message = rig_message.try_into().unwrap();
        assert_eq!(anthropic_message.role, Role::User);

        let mut iter = anthropic_message.content.into_iter();
        match iter.next().unwrap() {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Some plain text content".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
            }
            other => panic!("Expected Document content, got: {other:?}"),
        }
    }
1994
    /// An Anthropic `text/plain` document converts into a rig document whose data is
    /// `DocumentSourceKind::String` and whose media type is `DocumentMediaType::TXT`.
    #[test]
    fn test_plaintext_anthropic_to_rig_conversion() {
        use crate::completion::message as msg;

        let anthropic_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::Document {
                source: DocumentSource::Text {
                    data: "Some plain text content".to_string(),
                    media_type: PlainTextMediaType::Plain,
                },
                cache_control: None,
            }),
        };

        let rig_message: msg::Message = anthropic_message.try_into().unwrap();
        match rig_message {
            msg::Message::User { content } => {
                let mut iter = content.into_iter();
                match iter.next().unwrap() {
                    msg::UserContent::Document(msg::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("Some plain text content".into())
                        );
                        assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
                    }
                    other => panic!("Expected Document content, got: {other:?}"),
                }
            }
            _ => panic!("Expected User message"),
        }
    }
2030
    /// Round-trips a rig TXT document through the Anthropic representation and checks
    /// that the document media type survives.
    #[test]
    fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
        use crate::completion::message as msg;

        let original = msg::Message::User {
            content: OneOrMany::one(msg::UserContent::document(
                "Round trip text".to_string(),
                Some(msg::DocumentMediaType::TXT),
            )),
        };

        let anthropic: Message = original.clone().try_into().unwrap();
        let back: msg::Message = anthropic.try_into().unwrap();

        // Only the media type is compared here; the data's `DocumentSourceKind` is not asserted.
        match (&original, &back) {
            (
                msg::Message::User {
                    content: orig_content,
                },
                msg::Message::User {
                    content: back_content,
                },
            ) => match (orig_content.first(), back_content.first()) {
                (
                    msg::UserContent::Document(msg::Document {
                        media_type: orig_mt,
                        ..
                    }),
                    msg::UserContent::Document(msg::Document {
                        media_type: back_mt,
                        ..
                    }),
                ) => {
                    assert_eq!(orig_mt, back_mt);
                }
                _ => panic!("Expected Document content in both"),
            },
            _ => panic!("Expected User messages"),
        }
    }
2071
2072 #[test]
2073 fn test_unsupported_document_type_returns_error() {
2074 use crate::completion::message as msg;
2075
2076 let rig_message = msg::Message::User {
2077 content: OneOrMany::one(msg::UserContent::Document(msg::Document {
2078 data: DocumentSourceKind::String("data".into()),
2079 media_type: Some(msg::DocumentMediaType::HTML),
2080 additional_params: None,
2081 })),
2082 };
2083
2084 let result: Result<Message, _> = rig_message.try_into();
2085 assert!(result.is_err());
2086 let err = result.unwrap_err().to_string();
2087 assert!(
2088 err.contains("Anthropic only supports PDF and plain text documents"),
2089 "Unexpected error: {err}"
2090 );
2091 }
2092
2093 #[test]
2094 fn test_plaintext_document_url_source_returns_error() {
2095 use crate::completion::message as msg;
2096
2097 let rig_message = msg::Message::User {
2098 content: OneOrMany::one(msg::UserContent::Document(msg::Document {
2099 data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
2100 media_type: Some(msg::DocumentMediaType::TXT),
2101 additional_params: None,
2102 })),
2103 };
2104
2105 let result: Result<Message, _> = rig_message.try_into();
2106 assert!(result.is_err());
2107 let err = result.unwrap_err().to_string();
2108 assert!(
2109 err.contains("Only string or base64 data is supported for plain text documents"),
2110 "Unexpected error: {err}"
2111 );
2112 }
2113
2114 #[test]
2115 fn test_plaintext_document_with_cache_control() {
2116 let content = Content::Document {
2117 source: DocumentSource::Text {
2118 data: "cached text".to_string(),
2119 media_type: PlainTextMediaType::Plain,
2120 },
2121 cache_control: Some(CacheControl::ephemeral()),
2122 };
2123
2124 let json = serde_json::to_value(&content).unwrap();
2125 assert_eq!(json["source"]["type"], "text");
2126 assert_eq!(json["source"]["media_type"], "text/plain");
2127 assert_eq!(json["cache_control"]["type"], "ephemeral");
2128 }
2129
    /// A user message holding a plain-text document followed by a text prompt deserializes
    /// into two content blocks in the same order.
    #[test]
    fn test_message_with_plaintext_document_deserialization() {
        let json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "Hello from a text file"
                    }
                },
                {
                    "type": "text",
                    "text": "Summarize this document."
                }
            ]
        }
        "#;

        let message: Message = serde_json::from_str(json).unwrap();
        assert_eq!(message.role, Role::User);
        assert_eq!(message.content.len(), 2);

        let mut iter = message.content.into_iter();

        match iter.next().unwrap() {
            Content::Document { source, .. } => {
                assert_eq!(
                    source,
                    DocumentSource::Text {
                        data: "Hello from a text file".to_string(),
                        media_type: PlainTextMediaType::Plain,
                    }
                );
            }
            _ => panic!("Expected Document content"),
        }

        match iter.next().unwrap() {
            Content::Text { text, .. } => {
                assert_eq!(text, "Summarize this document.");
            }
            _ => panic!("Expected Text content"),
        }
    }
2178
    /// A rig `Reasoning` with several items expands into one Anthropic content block per
    /// item, in order:
    /// - signed text becomes `Thinking` with its signature;
    /// - a summary becomes `Thinking` without a signature;
    /// - redacted data becomes `RedactedThinking`.
    #[test]
    fn test_assistant_reasoning_multiblock_to_anthropic_content() {
        let reasoning = message::Reasoning {
            id: None,
            content: vec![
                message::ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig-1".to_string()),
                },
                message::ReasoningContent::Summary("summary".to_string()),
                message::ReasoningContent::Text {
                    text: "step two".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                message::ReasoningContent::Redacted {
                    data: "redacted block".to_string(),
                },
            ],
        };

        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
        };
        let converted: Message = msg.try_into().expect("convert assistant message");
        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();

        assert_eq!(converted.role, Role::Assistant);
        assert_eq!(converted_content.len(), 4);
        assert!(matches!(
            converted_content.first(),
            Some(Content::Thinking { thinking, signature: Some(signature) })
                if thinking == "step one" && signature == "sig-1"
        ));
        assert!(matches!(
            converted_content.get(1),
            Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
        ));
        assert!(matches!(
            converted_content.get(2),
            Some(Content::Thinking { thinking, signature: Some(signature) })
                if thinking == "step two" && signature == "sig-2"
        ));
        assert!(matches!(
            converted_content.get(3),
            Some(Content::RedactedThinking { data }) if data == "redacted block"
        ));
    }
2227
    /// `Content::RedactedThinking` converts into rig assistant reasoning whose first item
    /// is `ReasoningContent::Redacted` carrying the same data.
    #[test]
    fn test_redacted_thinking_content_to_assistant_reasoning() {
        let content = Content::RedactedThinking {
            data: "opaque-redacted".to_string(),
        };
        let converted: message::AssistantContent =
            content.try_into().expect("convert redacted thinking");

        assert!(matches!(
            converted,
            message::AssistantContent::Reasoning(message::Reasoning { content, .. })
                if matches!(
                    content.first(),
                    Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
                )
        ));
    }
2245
    /// Encrypted rig reasoning is sent to Anthropic as a single `RedactedThinking` block
    /// carrying the ciphertext.
    #[test]
    fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
        let reasoning = message::Reasoning {
            id: None,
            content: vec![message::ReasoningContent::Encrypted(
                "ciphertext".to_string(),
            )],
        };
        let msg = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
        };

        let converted: Message = msg.try_into().expect("convert assistant message");
        let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();

        assert_eq!(converted_content.len(), 1);
        assert!(matches!(
            converted_content.first(),
            Some(Content::RedactedThinking { data }) if data == "ciphertext"
        ));
    }
2268}