use crate::completion::CompletionRequest;
use crate::providers::anthropic::streaming::StreamingCompletionResponse;
use crate::{
OneOrMany,
client::Provider,
completion::{self, CompletionError, GetTokenUsage},
http_client::HttpClientExt,
message::{self, DocumentMediaType, DocumentSourceKind, MessageError, MimeType, Reasoning},
one_or_many::string_or_one_or_many,
telemetry::{ProviderResponseExt, SpanCombinator},
wasm_compat::*,
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, str::FromStr};
use tracing::{Instrument, Level, enabled, info_span};
/// Model id for Claude Opus 4.6.
pub const CLAUDE_OPUS_4_6: &str = "claude-opus-4-6";
/// Model id for Claude Opus 4.7.
pub const CLAUDE_OPUS_4_7: &str = "claude-opus-4-7";
/// Model id for Claude Sonnet 4.6.
pub const CLAUDE_SONNET_4_6: &str = "claude-sonnet-4-6";
/// Model id for Claude Haiku 4.5.
pub const CLAUDE_HAIKU_4_5: &str = "claude-haiku-4-5";
/// Anthropic API version string `2023-01-01` (presumably sent as the `anthropic-version`
/// header by the client, which is not shown here).
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// The newest API version string defined in this module.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
/// Error text used when a response yields no assistant content at all.
const EMPTY_RESPONSE_ERROR: &str = "Response contained no message or tool call (empty)";
/// A provider that speaks the Anthropic Messages API and can back a
/// [`GenericCompletionModel`].
pub trait AnthropicCompatibleProvider: Provider {
    /// Provider name recorded as `gen_ai.provider.name` on completion spans.
    const PROVIDER_NAME: &'static str;

    /// Default `max_tokens` for `model` when a request leaves it unset.
    ///
    /// The default implementation knows no limits and answers `None` for every model.
    fn default_max_tokens(_model: &str) -> Option<u64> {
        None
    }
}
impl AnthropicCompatibleProvider for super::client::AnthropicExt {
const PROVIDER_NAME: &'static str = "anthropic";
fn default_max_tokens(model: &str) -> Option<u64> {
default_max_tokens_for_model(model)
}
}
/// Body of a successful (non-streaming) Anthropic Messages API response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks returned by the model, in order.
    pub content: Vec<Content>,
    /// Response id as reported by the API.
    pub id: String,
    /// Model name as reported by the API.
    pub model: String,
    /// Role string as reported by the API (kept raw, not parsed into [`Role`]).
    pub role: String,
    /// Why generation stopped; `"end_turn"` lets an empty `content` convert successfully.
    pub stop_reason: Option<String>,
    /// The stop sequence that ended generation, if one was reported.
    pub stop_sequence: Option<String>,
    /// Token accounting for this call.
    pub usage: Usage,
}
/// Telemetry accessors used to annotate completion spans.
impl ProviderResponseExt for CompletionResponse {
    type OutputMessage = Content;
    type Usage = Usage;

    fn get_response_id(&self) -> Option<String> {
        Some(self.id.clone())
    }

    fn get_response_model_name(&self) -> Option<String> {
        Some(self.model.clone())
    }

    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
        self.content.to_vec()
    }

    /// Concatenates every text block with `\n`; `None` when the joined text is empty.
    fn get_text_response(&self) -> Option<String> {
        let texts: Vec<&str> = self
            .content
            .iter()
            .filter_map(|block| match block {
                Content::Text { text, .. } => Some(text.as_str()),
                _ => None,
            })
            .collect();
        let joined = texts.join("\n");
        (!joined.is_empty()).then_some(joined)
    }

    fn get_usage(&self) -> Option<Self::Usage> {
        Some(self.usage.clone())
    }
}
/// Token counts reported by Anthropic for one request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens as reported. The conversions in this module add the cache counters on
    /// top of this when computing totals, i.e. they treat it as excluding cached tokens.
    pub input_tokens: u64,
    /// Input tokens served from the prompt cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Generated output tokens.
    pub output_tokens: u64,
}
impl std::fmt::Display for Usage {
    /// Writes one `label: value` line per counter (no trailing newline); missing cache
    /// counters are shown as `n/a`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let or_na = |count: Option<u64>| count.map_or_else(|| String::from("n/a"), |n| n.to_string());
        let cache_read = or_na(self.cache_read_input_tokens);
        let cache_creation = or_na(self.cache_creation_input_tokens);
        write!(
            f,
            "Input tokens: {}\nCache read input tokens: {cache_read}\nCache creation input tokens: {cache_creation}\nOutput tokens: {}",
            self.input_tokens, self.output_tokens
        )
    }
}
impl GetTokenUsage for Usage {
    /// Maps Anthropic's counters onto rig's `Usage`. Missing cache counters count as 0, and
    /// `total_tokens` is the sum of all four counters. Always returns `Some`.
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let cache_read = self.cache_read_input_tokens.unwrap_or_default();
        let cache_creation = self.cache_creation_input_tokens.unwrap_or_default();

        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.input_tokens;
        usage.output_tokens = self.output_tokens;
        usage.cached_input_tokens = cache_read;
        usage.cache_creation_input_tokens = cache_creation;
        usage.total_tokens = self.input_tokens + cache_read + cache_creation + self.output_tokens;
        Some(usage)
    }
}
/// One entry of the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Tool description; serialized as `null` when `None` (there is no skip attribute).
    pub description: Option<String>,
    /// JSON Schema for the tool input, passed through unchanged.
    pub input_schema: serde_json::Value,
}
/// Lifetime of a prompt-cache entry, serialized as `"5m"` or `"1h"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub enum CacheTtl {
    /// Five minutes (`"5m"`); this is the `Default`.
    #[default]
    #[serde(rename = "5m")]
    FiveMinutes,
    /// One hour (`"1h"`).
    #[serde(rename = "1h")]
    OneHour,
}
/// Prompt-cache breakpoint marker, serialized as `{"type": "ephemeral"}` with an optional
/// `ttl` field.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral {
        // Omitted from the JSON when `None`, leaving the TTL to the API default.
        #[serde(skip_serializing_if = "Option::is_none")]
        ttl: Option<CacheTtl>,
    },
}
impl CacheControl {
pub fn ephemeral() -> Self {
Self::Ephemeral { ttl: None }
}
pub fn ephemeral_1h() -> Self {
Self::Ephemeral {
ttl: Some(CacheTtl::OneHour),
}
}
}
/// One block of the request's top-level `system` field, serialized as
/// `{"type": "text", "text": ...}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemContent {
    Text {
        text: String,
        // Set by `apply_cache_control` on the last system block when prompt caching is on.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
let content = response
.content
.iter()
.map(|content| content.clone().try_into())
.collect::<Result<Vec<_>, _>>()?;
let choice = if content.is_empty() {
if response.stop_reason.as_deref() == Some("end_turn") {
OneOrMany::one(completion::AssistantContent::text(""))
} else {
return Err(CompletionError::ResponseError(
EMPTY_RESPONSE_ERROR.to_owned(),
));
}
} else {
OneOrMany::many(content)
.map_err(|_| CompletionError::ResponseError(EMPTY_RESPONSE_ERROR.to_owned()))?
};
let usage = completion::Usage {
input_tokens: response.usage.input_tokens,
output_tokens: response.usage.output_tokens,
total_tokens: response.usage.input_tokens
+ response.usage.cache_read_input_tokens.unwrap_or(0)
+ response.usage.cache_creation_input_tokens.unwrap_or(0)
+ response.usage.output_tokens,
cached_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
cache_creation_input_tokens: response.usage.cache_creation_input_tokens.unwrap_or(0),
};
Ok(completion::CompletionResponse {
choice,
usage,
raw_response: response,
message_id: None,
})
}
}
/// One entry of the request's `messages` array.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    pub role: Role,
    // `string_or_one_or_many` also lets a plain JSON string deserialize as the content.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
/// Message author, serialized as `"user"` or `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
/// A content block inside a [`Message`], tagged by `type` in snake_case
/// (`text`, `image`, `tool_use`, `tool_result`, `document`, `thinking`, `redacted_thinking`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image given by base64 data or URL.
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation requested by the model (no cache-control field here).
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, referencing the `ToolUse` id.
    ToolResult {
        tool_use_id: String,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A document (PDF base64, plain text, or URL).
    Document {
        source: DocumentSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Visible model reasoning, with an optional signature.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Reasoning carried as opaque `data`.
    RedactedThinking {
        data: String,
    },
}
impl FromStr for Content {
    type Err = Infallible;

    /// Any string parses into an uncached text block.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let text = String::from(s);
        Ok(Self::Text {
            cache_control: None,
            text,
        })
    }
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
Text { text: String },
Image(ImageSource),
}
impl FromStr for ToolResultContent {
    type Err = Infallible;

    /// Any string parses into a text item.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let text = String::from(s);
        Ok(Self::Text { text })
    }
}
/// Where an image's bytes come from, tagged by `type` (`"base64"` or `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Inline base64 data with its MIME type.
    #[serde(rename = "base64")]
    Base64 {
        data: String,
        media_type: ImageFormat,
    },
    /// A remote image URL.
    #[serde(rename = "url")]
    Url { url: String },
}
/// Where a document's content comes from, tagged by `type` (`"base64"`, `"text"`, `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
    /// Base64 data; only PDF is representable via [`DocumentFormat`].
    Base64 {
        data: String,
        media_type: DocumentFormat,
    },
    /// Plain text data (`text/plain`).
    Text {
        data: String,
        media_type: PlainTextMediaType,
    },
    /// A remote document URL.
    Url {
        url: String,
    },
}
/// Image MIME types accepted by Anthropic. The explicit per-variant renames take precedence
/// over `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
/// MIME types allowed for base64 document sources (currently only PDF).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
/// MIME type for plain-text document sources, serialized as `"text/plain"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum PlainTextMediaType {
    #[serde(rename = "text/plain")]
    Plain,
}
/// Source kind names, serialized lowercase (`"base64"`, `"url"`, `"text"`); convertible to
/// and from rig's `ContentFormat`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
    TEXT,
}
impl From<String> for Content {
fn from(text: String) -> Self {
Content::Text {
text,
cache_control: None,
}
}
}
impl From<String> for ToolResultContent {
fn from(text: String) -> Self {
ToolResultContent::Text { text }
}
}
impl TryFrom<message::ContentFormat> for SourceType {
    type Error = MessageError;

    /// Maps rig's content format onto the matching source kind; every format is supported.
    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
        let source_type = match format {
            message::ContentFormat::String => Self::TEXT,
            message::ContentFormat::Url => Self::URL,
            message::ContentFormat::Base64 => Self::BASE64,
        };
        Ok(source_type)
    }
}
impl From<SourceType> for message::ContentFormat {
    /// Inverse of `TryFrom<message::ContentFormat> for SourceType`.
    fn from(source_type: SourceType) -> Self {
        match source_type {
            SourceType::TEXT => Self::String,
            SourceType::URL => Self::Url,
            SourceType::BASE64 => Self::Base64,
        }
    }
}
impl TryFrom<message::ImageMediaType> for ImageFormat {
    type Error = MessageError;

    /// Accepts JPEG, PNG, GIF and WEBP; any other media type is a conversion error.
    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
        match media_type {
            message::ImageMediaType::JPEG => Ok(Self::JPEG),
            message::ImageMediaType::PNG => Ok(Self::PNG),
            message::ImageMediaType::GIF => Ok(Self::GIF),
            message::ImageMediaType::WEBP => Ok(Self::WEBP),
            unsupported => Err(MessageError::ConversionError(format!(
                "Unsupported image media type: {unsupported:?}"
            ))),
        }
    }
}
impl From<ImageFormat> for message::ImageMediaType {
    /// Maps each Anthropic image format to rig's identically named media type.
    fn from(format: ImageFormat) -> Self {
        match format {
            ImageFormat::WEBP => Self::WEBP,
            ImageFormat::GIF => Self::GIF,
            ImageFormat::PNG => Self::PNG,
            ImageFormat::JPEG => Self::JPEG,
        }
    }
}
impl TryFrom<DocumentMediaType> for DocumentFormat {
    type Error = MessageError;

    /// Only PDF is valid for base64 document sources; anything else is rejected with its MIME type.
    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
        if matches!(value, DocumentMediaType::PDF) {
            return Ok(Self::PDF);
        }
        Err(MessageError::ConversionError(format!(
            "DocumentFormat only supports PDF for base64 sources, got: {}",
            value.to_mime_type()
        )))
    }
}
impl TryFrom<message::AssistantContent> for Content {
type Error = MessageError;
fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
match text {
message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text {
text,
cache_control: None,
}),
message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
"Anthropic currently doesn't support images.".to_string(),
)),
message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
Ok(Content::ToolUse {
id,
name: function.name,
input: function.arguments,
})
}
message::AssistantContent::Reasoning(reasoning) => Ok(Content::Thinking {
thinking: reasoning.display_text(),
signature: reasoning.first_signature().map(str::to_owned),
}),
}
}
}
/// Converts one rig assistant item into the Anthropic block(s) that represent it.
///
/// Reasoning is expanded block by block:
/// - text reasoning keeps its signature;
/// - summaries become unsigned `Thinking` blocks;
/// - redacted/encrypted reasoning becomes `RedactedThinking`.
///
/// # Errors
/// Images are rejected, as is reasoning that contains no blocks.
fn anthropic_content_from_assistant_content(
    content: message::AssistantContent,
) -> Result<Vec<Content>, MessageError> {
    let reasoning = match content {
        message::AssistantContent::Text(message::Text { text }) => {
            return Ok(vec![Content::Text {
                text,
                cache_control: None,
            }]);
        }
        message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
            return Ok(vec![Content::ToolUse {
                id,
                name: function.name,
                input: function.arguments,
            }]);
        }
        message::AssistantContent::Image(_) => {
            return Err(MessageError::ConversionError(
                "Anthropic currently doesn't support images.".to_string(),
            ));
        }
        message::AssistantContent::Reasoning(reasoning) => reasoning,
    };

    let blocks: Vec<Content> = reasoning
        .content
        .into_iter()
        .map(|block| match block {
            message::ReasoningContent::Text { text, signature } => Content::Thinking {
                thinking: text,
                signature,
            },
            message::ReasoningContent::Summary(summary) => Content::Thinking {
                thinking: summary,
                signature: None,
            },
            message::ReasoningContent::Redacted { data }
            | message::ReasoningContent::Encrypted(data) => Content::RedactedThinking { data },
        })
        .collect();

    if blocks.is_empty() {
        return Err(MessageError::ConversionError(
            "Cannot convert empty reasoning content to Anthropic format".to_string(),
        ));
    }
    Ok(blocks)
}
impl TryFrom<message::Message> for Message {
type Error = MessageError;
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
Ok(match message {
message::Message::User { content } => Message {
role: Role::User,
content: content.try_map(|content| match content {
message::UserContent::Text(message::Text { text }) => Ok(Content::Text {
text,
cache_control: None,
}),
message::UserContent::ToolResult(message::ToolResult {
id, content, ..
}) => Ok(Content::ToolResult {
tool_use_id: id,
content: content.try_map(|content| match content {
message::ToolResultContent::Text(message::Text { text }) => {
Ok(ToolResultContent::Text { text })
}
message::ToolResultContent::Image(image) => {
let DocumentSourceKind::Base64(data) = image.data else {
return Err(MessageError::ConversionError(
"Only base64 strings can be used with the Anthropic API"
.to_string(),
));
};
let media_type =
image.media_type.ok_or(MessageError::ConversionError(
"Image media type is required".to_owned(),
))?;
Ok(ToolResultContent::Image(ImageSource::Base64 {
data,
media_type: media_type.try_into()?,
}))
}
})?,
is_error: None,
cache_control: None,
}),
message::UserContent::Image(message::Image {
data, media_type, ..
}) => {
let source = match data {
DocumentSourceKind::Base64(data) => {
let media_type =
media_type.ok_or(MessageError::ConversionError(
"Image media type is required for Claude API".to_string(),
))?;
ImageSource::Base64 {
data,
media_type: ImageFormat::try_from(media_type)?,
}
}
DocumentSourceKind::Url(url) => ImageSource::Url { url },
DocumentSourceKind::Unknown => {
return Err(MessageError::ConversionError(
"Image content has no body".into(),
));
}
doc => {
return Err(MessageError::ConversionError(format!(
"Unsupported document type: {doc:?}"
)));
}
};
Ok(Content::Image {
source,
cache_control: None,
})
}
message::UserContent::Document(message::Document {
data, media_type, ..
}) => {
let media_type = media_type.ok_or(MessageError::ConversionError(
"Document media type is required".to_string(),
))?;
let source = match media_type {
DocumentMediaType::PDF => {
let data = match data {
DocumentSourceKind::Base64(data)
| DocumentSourceKind::String(data) => data,
_ => {
return Err(MessageError::ConversionError(
"Only base64 encoded data is supported for PDF documents".into(),
));
}
};
DocumentSource::Base64 {
data,
media_type: DocumentFormat::PDF,
}
}
DocumentMediaType::TXT => {
let data = match data {
DocumentSourceKind::String(data)
| DocumentSourceKind::Base64(data) => data,
_ => {
return Err(MessageError::ConversionError(
"Only string or base64 data is supported for plain text documents".into(),
));
}
};
DocumentSource::Text {
data,
media_type: PlainTextMediaType::Plain,
}
}
other => {
return Err(MessageError::ConversionError(format!(
"Anthropic only supports PDF and plain text documents, got: {}",
other.to_mime_type()
)));
}
};
Ok(Content::Document {
source,
cache_control: None,
})
}
message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
"Audio is not supported in Anthropic".to_owned(),
)),
message::UserContent::Video { .. } => Err(MessageError::ConversionError(
"Video is not supported in Anthropic".to_owned(),
)),
})?,
},
message::Message::System { content } => Message {
role: Role::User,
content: OneOrMany::one(Content::Text {
text: content,
cache_control: None,
}),
},
message::Message::Assistant { content, .. } => {
let converted_content = content.into_iter().try_fold(
Vec::new(),
|mut accumulated, assistant_content| {
accumulated
.extend(anthropic_content_from_assistant_content(assistant_content)?);
Ok::<Vec<Content>, MessageError>(accumulated)
},
)?;
Message {
content: OneOrMany::many(converted_content).map_err(|_| {
MessageError::ConversionError(
"Assistant message did not contain Anthropic-compatible content"
.to_owned(),
)
})?,
role: Role::Assistant,
}
}
})
}
}
impl TryFrom<Content> for message::AssistantContent {
    type Error = MessageError;

    /// Converts an Anthropic block from a model response into rig assistant content.
    /// Only text, tool use, thinking and redacted thinking are valid on the assistant side.
    fn try_from(content: Content) -> Result<Self, Self::Error> {
        match content {
            Content::Text { text, .. } => Ok(Self::text(text)),
            Content::ToolUse { id, name, input } => Ok(Self::tool_call(id, name, input)),
            Content::Thinking {
                thinking,
                signature,
            } => Ok(Self::Reasoning(Reasoning::new_with_signature(
                &thinking, signature,
            ))),
            Content::RedactedThinking { data } => Ok(Self::Reasoning(Reasoning::redacted(data))),
            Content::Image { .. } | Content::ToolResult { .. } | Content::Document { .. } => {
                Err(MessageError::ConversionError(
                    "Content did not contain a message, tool call, or reasoning".to_owned(),
                ))
            }
        }
    }
}
impl From<ToolResultContent> for message::ToolResultContent {
    /// Maps a tool-result item back to rig. Images keep their media type when base64;
    /// URL images carry none.
    fn from(content: ToolResultContent) -> Self {
        match content {
            ToolResultContent::Text { text } => Self::text(text),
            ToolResultContent::Image(ImageSource::Base64 { data, media_type }) => {
                Self::image_base64(data, Some(media_type.into()), None)
            }
            ToolResultContent::Image(ImageSource::Url { url }) => Self::image_url(url, None, None),
        }
    }
}
impl TryFrom<Message> for message::Message {
type Error = MessageError;
fn try_from(message: Message) -> Result<Self, Self::Error> {
Ok(match message.role {
Role::User => message::Message::User {
content: message.content.try_map(|content| {
Ok(match content {
Content::Text { text, .. } => message::UserContent::text(text),
Content::ToolResult {
tool_use_id,
content,
..
} => message::UserContent::tool_result(
tool_use_id,
content.map(|content| content.into()),
),
Content::Image { source, .. } => match source {
ImageSource::Base64 { data, media_type } => {
message::UserContent::Image(message::Image {
data: DocumentSourceKind::Base64(data),
media_type: Some(media_type.into()),
detail: None,
additional_params: None,
})
}
ImageSource::Url { url } => {
message::UserContent::Image(message::Image {
data: DocumentSourceKind::Url(url),
media_type: None,
detail: None,
additional_params: None,
})
}
},
Content::Document { source, .. } => match source {
DocumentSource::Base64 { data, media_type } => {
let rig_media_type = match media_type {
DocumentFormat::PDF => message::DocumentMediaType::PDF,
};
message::UserContent::document(data, Some(rig_media_type))
}
DocumentSource::Text { data, .. } => message::UserContent::document(
data,
Some(message::DocumentMediaType::TXT),
),
DocumentSource::Url { url } => {
message::UserContent::document_url(url, None)
}
},
_ => {
return Err(MessageError::ConversionError(
"Unsupported content type for User role".to_owned(),
));
}
})
})?,
},
Role::Assistant => message::Message::Assistant {
id: None,
content: message.content.try_map(|content| content.try_into())?,
},
})
}
}
/// Completion model handle for an Anthropic-compatible provider `Ext` over HTTP client `T`.
#[doc(hidden)]
#[derive(Clone)]
pub struct GenericCompletionModel<Ext = super::client::AnthropicExt, T = reqwest::Client> {
    /// Provider client used to build and send requests.
    pub(crate) client: crate::client::Client<Ext, T>,
    /// Default model id; a request's own `model` overrides it.
    pub model: String,
    /// `max_tokens` applied when a request sets none; `None` makes such requests fail.
    pub default_max_tokens: Option<u64>,
    /// When true, `apply_cache_control` marks the last system block and last message block.
    pub prompt_caching: bool,
    /// When true, a top-level `cache_control` marker is added to every request.
    pub automatic_caching: bool,
    /// TTL for the top-level automatic-caching marker; `None` omits `ttl`.
    pub automatic_caching_ttl: Option<CacheTtl>,
}
/// Completion model for the first-party Anthropic provider.
pub type CompletionModel<T = reqwest::Client> =
    GenericCompletionModel<super::client::AnthropicExt, T>;
impl<Ext, T> GenericCompletionModel<Ext, T>
where
    T: HttpClientExt,
    Ext: AnthropicCompatibleProvider + Clone + 'static,
{
    /// Creates a handle for `model`, with all caching disabled.
    ///
    /// `default_max_tokens` comes only from `Ext::default_max_tokens`, so it may be `None`.
    pub fn new(client: crate::client::Client<Ext, T>, model: impl Into<String>) -> Self {
        let model: String = model.into();
        Self {
            default_max_tokens: Ext::default_max_tokens(&model),
            client,
            model,
            prompt_caching: false,
            automatic_caching: false,
            automatic_caching_ttl: None,
        }
    }

    /// Like [`Self::new`], but always sets `default_max_tokens`. When the provider has no
    /// limit for `model`, it falls back to the Anthropic table and then to 2_048 tokens.
    pub fn with_model(client: crate::client::Client<Ext, T>, model: &str) -> Self {
        let mut built = Self::new(client, model);
        if built.default_max_tokens.is_none() {
            built.default_max_tokens = Some(default_max_tokens_with_fallback(model));
        }
        built
    }

    /// Enables explicit cache breakpoints on the last system and last message blocks.
    pub fn with_prompt_caching(self) -> Self {
        Self {
            prompt_caching: true,
            ..self
        }
    }

    /// Enables the top-level automatic cache marker (TTL left unchanged).
    pub fn with_automatic_caching(self) -> Self {
        Self {
            automatic_caching: true,
            ..self
        }
    }

    /// Enables the top-level automatic cache marker with a one-hour TTL.
    pub fn with_automatic_caching_1h(self) -> Self {
        Self {
            automatic_caching: true,
            automatic_caching_ttl: Some(CacheTtl::OneHour),
            ..self
        }
    }
}
/// Returns the maximum output-token limit for a first-party Claude model id, or `None`
/// when the model is not in this table.
///
/// Matching is by prefix, so dated snapshots (`claude-sonnet-4-20250514`) and aliases
/// (`claude-sonnet-4-5`) share an entry. Claude Opus 4 and Opus 4.1 allow only 32k output
/// tokens, fewer than later Opus 4.x releases. They are matched before the generic
/// `claude-opus-4` prefix; otherwise the default would exceed their limit and the API
/// would reject the request.
fn default_max_tokens_for_model(model: &str) -> Option<u64> {
    const LIMIT_128K: [&str; 2] = ["claude-opus-4-7", "claude-opus-4-6"];
    const LIMIT_32K: [&str; 3] = ["claude-opus-4-0", "claude-opus-4-1", "claude-opus-4-20250514"];
    const LIMIT_64K: [&str; 3] = ["claude-opus-4", "claude-sonnet-4", "claude-haiku-4-5"];

    let matches_any = |prefixes: &[&str]| prefixes.iter().any(|&prefix| model.starts_with(prefix));

    if matches_any(&LIMIT_128K) {
        Some(128_000)
    } else if matches_any(&LIMIT_32K) {
        Some(32_000)
    } else if matches_any(&LIMIT_64K) {
        Some(64_000)
    } else {
        None
    }
}
/// The table limit from `default_max_tokens_for_model`, or a conservative 2_048 tokens
/// for models that table does not list.
fn default_max_tokens_with_fallback(model: &str) -> u64 {
    const FALLBACK_MAX_TOKENS: u64 = 2_048;
    default_max_tokens_for_model(model).unwrap_or(FALLBACK_MAX_TOKENS)
}
/// Request metadata carrying an optional `user_id` (not referenced elsewhere in the
/// visible part of this file).
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}
/// Anthropic `tool_choice`, serialized as `{"type": "auto" | "any" | "none"}` or
/// `{"type": "tool", "name": ...}`. Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    /// The model must call some tool (rig's `Required`).
    Any,
    None,
    /// The model must call the named tool.
    Tool {
        name: String,
    },
}
impl TryFrom<message::ToolChoice> for ToolChoice {
    type Error = CompletionError;

    /// Maps rig's tool choice onto Anthropic's. `Required` becomes `Any`, and `Specific`
    /// must name exactly one tool.
    ///
    /// # Errors
    /// `ProviderError` when `Specific` lists zero or several function names.
    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
        match value {
            message::ToolChoice::Auto => Ok(Self::Auto),
            message::ToolChoice::None => Ok(Self::None),
            message::ToolChoice::Required => Ok(Self::Any),
            message::ToolChoice::Specific { function_names } => {
                let mut names = function_names.into_iter();
                match (names.next(), names.next()) {
                    (Some(name), None) => Ok(Self::Tool { name }),
                    _ => Err(CompletionError::ProviderError(
                        "Only one tool may be specified to be used by Claude".into(),
                    )),
                }
            }
        }
    }
}
/// Rewrites a JSON Schema in place into the subset accepted by Anthropic structured
/// outputs (`output_config.format`).
///
/// Applied recursively:
/// - object schemas get `additionalProperties: false` unless they already declare
///   `additionalProperties`. A schema counts as an object when its `type` is or includes
///   `"object"`, or when it has a `properties` key.
/// - when `properties` is an object, `required` is overwritten with every property name;
/// - numeric schemas lose `minimum`, `maximum`, `exclusiveMinimum`, `exclusiveMaximum`
///   and `multipleOf`. A schema counts as numeric when its `type` is or includes
///   `"integer"`/`"number"`, including nullable forms such as `["integer", "null"]`.
/// - `oneOf` is folded into `anyOf`;
/// - sub-schemas under `$defs`, `definitions`, `properties`, `items`, `anyOf` and `allOf`
///   are sanitized as well.
///
/// Non-object JSON values are left untouched.
fn sanitize_schema(schema: &mut serde_json::Value) {
    use serde_json::Value;

    /// True when the schema's `type` equals `name` or is an array containing `name`.
    fn type_includes(obj: &serde_json::Map<String, serde_json::Value>, name: &str) -> bool {
        match obj.get("type") {
            Some(serde_json::Value::String(ty)) => ty == name,
            Some(serde_json::Value::Array(types)) => {
                types.iter().any(|ty| ty.as_str() == Some(name))
            }
            _ => false,
        }
    }

    let Value::Object(obj) = schema else {
        return;
    };

    if (type_includes(obj, "object") || obj.contains_key("properties"))
        && !obj.contains_key("additionalProperties")
    {
        obj.insert("additionalProperties".to_string(), Value::Bool(false));
    }

    // Every declared property becomes required; any existing `required` list is replaced.
    if let Some(Value::Object(properties)) = obj.get("properties") {
        let prop_keys = properties.keys().cloned().map(Value::String).collect();
        obj.insert("required".to_string(), Value::Array(prop_keys));
    }

    // Checking array-valued `type` too catches nullable numbers (e.g. `Option<u32>`
    // yields `["integer", "null"]` with `minimum: 0`).
    if type_includes(obj, "integer") || type_includes(obj, "number") {
        for key in [
            "minimum",
            "maximum",
            "exclusiveMinimum",
            "exclusiveMaximum",
            "multipleOf",
        ] {
            obj.remove(key);
        }
    }

    // `$defs` (2019-09+) and `definitions` (draft-07 and earlier) hold referenced schemas.
    for defs_key in ["$defs", "definitions"] {
        if let Some(Value::Object(defs)) = obj.get_mut(defs_key) {
            for def_schema in defs.values_mut() {
                sanitize_schema(def_schema);
            }
        }
    }

    if let Some(Value::Object(props)) = obj.get_mut("properties") {
        for prop_schema in props.values_mut() {
            sanitize_schema(prop_schema);
        }
    }

    if let Some(items) = obj.get_mut("items") {
        sanitize_schema(items);
    }

    // Fold `oneOf` into `anyOf`: append to an existing `anyOf` array, otherwise replace it.
    if let Some(one_of) = obj.remove("oneOf") {
        match obj.get_mut("anyOf") {
            Some(Value::Array(existing)) => {
                if let Value::Array(mut incoming) = one_of {
                    existing.append(&mut incoming);
                }
            }
            _ => {
                obj.insert("anyOf".to_string(), one_of);
            }
        }
    }

    for key in ["anyOf", "allOf"] {
        if let Some(Value::Array(variants)) = obj.get_mut(key) {
            for variant in variants.iter_mut() {
                sanitize_schema(variant);
            }
        }
    }
}
/// Structured-output format, serialized as `{"type": "json_schema", "schema": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputFormat {
    JsonSchema { schema: serde_json::Value },
}
/// The request's `output_config` object, wrapping the structured-output format.
#[derive(Debug, Deserialize, Serialize)]
struct OutputConfig {
    format: OutputFormat,
}
/// JSON body sent to `POST /v1/messages`.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u64,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<SystemContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    // Serialized tool definitions: rig tools plus any raw `additional_params.tools`.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_config: Option<OutputConfig>,
    // Remaining caller-supplied parameters, merged into the top level of the body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
    // Top-level automatic-caching marker.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<CacheControl>,
}
/// Overwrites the `cache_control` slot of `content` with `value`.
/// Block kinds without a cache-control field (tool use, thinking) are left unchanged.
fn set_content_cache_control(content: &mut Content, value: Option<CacheControl>) {
    let slot = match content {
        Content::Text { cache_control, .. }
        | Content::Image { cache_control, .. }
        | Content::ToolResult { cache_control, .. }
        | Content::Document { cache_control, .. } => cache_control,
        Content::ToolUse { .. } | Content::Thinking { .. } | Content::RedactedThinking { .. } => {
            return;
        }
    };
    *slot = value;
}
/// Places ephemeral cache breakpoints on the last system block and the last block of the
/// last message. Any markers previously set on message blocks are cleared first.
pub fn apply_cache_control(system: &mut [SystemContent], messages: &mut [Message]) {
    if let Some(SystemContent::Text { cache_control, .. }) = system.last_mut() {
        *cache_control = Some(CacheControl::ephemeral());
    }

    for message in messages.iter_mut() {
        message
            .content
            .iter_mut()
            .for_each(|block| set_content_cache_control(block, None));
    }

    // No-op when the final block has no cache-control field (e.g. tool use or thinking).
    if let Some(final_message) = messages.last_mut() {
        let final_block = final_message.content.last_mut();
        set_content_cache_control(final_block, Some(CacheControl::ephemeral()));
    }
}
/// Separates system messages from the rest of the history.
///
/// Non-empty system messages become uncached `SystemContent` blocks, and empty ones are
/// dropped. All other messages keep their original order.
pub(super) fn split_system_messages_from_history(
    history: Vec<message::Message>,
) -> (Vec<SystemContent>, Vec<message::Message>) {
    history.into_iter().fold(
        (Vec::new(), Vec::new()),
        |(mut system, mut remaining), entry| {
            match entry {
                message::Message::System { content } if content.is_empty() => {}
                message::Message::System { content } => system.push(SystemContent::Text {
                    text: content,
                    cache_control: None,
                }),
                other => remaining.push(other),
            }
            (system, remaining)
        },
    )
}
/// Inputs for building an [`AnthropicCompletionRequest`] from a rig request.
pub struct AnthropicRequestParams<'a> {
    /// Model id to send.
    pub model: &'a str,
    /// The rig request; its `max_tokens` must already be set.
    pub request: CompletionRequest,
    /// Whether to place explicit cache breakpoints via `apply_cache_control`.
    pub prompt_caching: bool,
    /// Whether to send a top-level `cache_control` marker.
    pub automatic_caching: bool,
    /// TTL for the top-level marker; `None` omits `ttl`.
    pub automatic_caching_ttl: Option<CacheTtl>,
}
impl TryFrom<AnthropicRequestParams<'_>> for AnthropicCompletionRequest {
    type Error = CompletionError;

    /// Builds the `/v1/messages` body from a rig completion request.
    ///
    /// # Errors
    /// - `RequestError` when `max_tokens` is unset, or when `additional_params.tools` is not
    ///   a JSON array.
    /// - `ProviderError` when `tool_choice` cannot be expressed for Anthropic (e.g. several
    ///   specific tools). It is reported, not silently dropped.
    /// - Message-conversion and tool-serialization errors, propagated via `?`.
    fn try_from(params: AnthropicRequestParams<'_>) -> Result<Self, Self::Error> {
        let AnthropicRequestParams {
            model,
            request: mut req,
            prompt_caching,
            automatic_caching,
            automatic_caching_ttl,
        } = params;

        let Some(max_tokens) = req.max_tokens else {
            return Err(CompletionError::RequestError(
                "`max_tokens` must be set for Anthropic".into(),
            ));
        };

        // Attached documents (if any) go first, ahead of the chat history.
        let mut full_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            full_history.push(docs);
        }
        full_history.extend(req.chat_history);

        // System messages move into the top-level `system` field.
        let (history_system, full_history) = split_system_messages_from_history(full_history);
        let mut messages = full_history
            .into_iter()
            .map(Message::try_from)
            .collect::<Result<Vec<Message>, _>>()?;

        // Raw `additional_params.tools` are appended after the rig-defined tools. The rest
        // of `additional_params` is flattened into the body.
        let mut additional_params_payload = req
            .additional_params
            .take()
            .unwrap_or(serde_json::Value::Null);
        let mut additional_tools =
            extract_tools_from_additional_params(&mut additional_params_payload)?;
        let mut tools = req
            .tools
            .into_iter()
            .map(|tool| ToolDefinition {
                name: tool.name,
                description: Some(tool.description),
                input_schema: tool.parameters,
            })
            .map(serde_json::to_value)
            .collect::<Result<Vec<_>, _>>()?;
        tools.append(&mut additional_tools);

        // A non-empty preamble comes first, followed by system messages from the history.
        let mut system = match req.preamble {
            Some(preamble) if !preamble.is_empty() => vec![SystemContent::Text {
                text: preamble,
                cache_control: None,
            }],
            _ => vec![],
        };
        system.extend(history_system);

        if prompt_caching {
            apply_cache_control(&mut system, &mut messages);
        }

        let output_config = if let Some(schema) = req.output_schema {
            let mut schema_value = schema.to_value();
            sanitize_schema(&mut schema_value);
            Some(OutputConfig {
                format: OutputFormat::JsonSchema {
                    schema: schema_value,
                },
            })
        } else {
            None
        };

        // Surface unsupported tool choices instead of silently letting the API fall back to
        // its default behaviour.
        let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages,
            max_tokens,
            system,
            temperature: req.temperature,
            tool_choice,
            tools,
            output_config,
            cache_control: if automatic_caching {
                Some(CacheControl::Ephemeral {
                    ttl: automatic_caching_ttl,
                })
            } else {
                None
            },
            additional_params: if additional_params_payload.is_null() {
                None
            } else {
                Some(additional_params_payload)
            },
        })
    }
}
/// Removes a `tools` array from `additional_params` (when it is an object) and returns its
/// entries; returns an empty list when there is no such key.
///
/// # Errors
/// `RequestError` when `tools` is present but is not a JSON array.
fn extract_tools_from_additional_params(
    additional_params: &mut serde_json::Value,
) -> Result<Vec<serde_json::Value>, CompletionError> {
    let Some(raw_tools) = additional_params
        .as_object_mut()
        .and_then(|map| map.remove("tools"))
    else {
        return Ok(Vec::new());
    };

    serde_json::from_value::<Vec<serde_json::Value>>(raw_tools).map_err(|err| {
        CompletionError::RequestError(
            format!("Invalid Anthropic `additional_params.tools` payload: {err}").into(),
        )
    })
}
impl<Ext, T> completion::CompletionModel for GenericCompletionModel<Ext, T>
where
T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
Ext: AnthropicCompatibleProvider + Clone + WasmCompatSend + WasmCompatSync + 'static,
{
type Response = CompletionResponse;
type StreamingResponse = StreamingCompletionResponse;
type Client = crate::client::Client<Ext, T>;
fn make(client: &Self::Client, model: impl Into<String>) -> Self {
Self::new(client.clone(), model.into())
}
async fn completion(
&self,
mut completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let request_model = completion_request
.model
.clone()
.unwrap_or_else(|| self.model.clone());
let span = if tracing::Span::current().is_disabled() {
info_span!(
target: "rig::completions",
"chat",
gen_ai.operation.name = "chat",
gen_ai.provider.name = Ext::PROVIDER_NAME,
gen_ai.request.model = &request_model,
gen_ai.system_instructions = &completion_request.preamble,
gen_ai.response.id = tracing::field::Empty,
gen_ai.response.model = tracing::field::Empty,
gen_ai.usage.output_tokens = tracing::field::Empty,
gen_ai.usage.input_tokens = tracing::field::Empty,
gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
)
} else {
tracing::Span::current()
};
if completion_request.max_tokens.is_none() {
if let Some(tokens) = self.default_max_tokens {
completion_request.max_tokens = Some(tokens);
} else {
return Err(CompletionError::RequestError(
"`max_tokens` must be set for Anthropic".into(),
));
}
}
let request = AnthropicCompletionRequest::try_from(AnthropicRequestParams {
model: &request_model,
request: completion_request,
prompt_caching: self.prompt_caching,
automatic_caching: self.automatic_caching,
automatic_caching_ttl: self.automatic_caching_ttl.clone(),
})?;
if enabled!(Level::TRACE) {
tracing::trace!(
target: "rig::completions",
"Anthropic completion request: {}",
serde_json::to_string_pretty(&request)?
);
}
async move {
let request: Vec<u8> = serde_json::to_vec(&request)?;
let req = self
.client
.post("/v1/messages")?
.body(request)
.map_err(|e| CompletionError::HttpError(e.into()))?;
let response = self
.client
.send::<_, Bytes>(req)
.await
.map_err(CompletionError::HttpError)?;
if response.status().is_success() {
match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
response
.into_body()
.await
.map_err(CompletionError::HttpError)?
.to_vec()
.as_slice(),
)? {
ApiResponse::Message(completion) => {
let span = tracing::Span::current();
span.record_response_metadata(&completion);
span.record_token_usage(&completion.usage);
if enabled!(Level::TRACE) {
tracing::trace!(
target: "rig::completions",
"Anthropic completion response: {}",
serde_json::to_string_pretty(&completion)?
);
}
completion.try_into()
}
ApiResponse::Error(ApiErrorResponse { message }) => {
Err(CompletionError::ResponseError(message))
}
}
} else {
let text: String = String::from_utf8_lossy(
&response
.into_body()
.await
.map_err(CompletionError::HttpError)?,
)
.into();
Err(CompletionError::ProviderError(text))
}
}
.instrument(span)
.await
}
/// Streams a completion for `request`.
///
/// Delegates to `GenericCompletionModel::stream` and returns its result
/// unchanged, so any `CompletionError` it produces is propagated as-is.
async fn stream(
&self,
request: CompletionRequest,
) -> Result<
crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
CompletionError,
> {
// NOTE(review): presumably this fully-qualified path resolves to an inherent
// `stream` on the model type (inherent methods take precedence for this kind
// of path), so it does not recurse into this trait method — confirm against
// the impl header, which is outside this chunk.
GenericCompletionModel::stream(self, request).await
}
}
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
Message(T),
Error(ApiErrorResponse),
}
// Unit tests for this module: per-model `max_tokens` defaults, wire-format
// (de)serialization of messages and documents, conversion to and from rig's
// `message` types, cache-control placement, reasoning/thinking mapping, and
// handling of responses with empty `content`.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Path-tracking deserializer: on failure the error reports the JSON path.
use serde_path_to_error::deserialize;
// Pins the per-model defaults: 128k for Opus 4.7/4.6, 64k for Sonnet 4.6 and Haiku 4.5.
#[test]
fn current_model_default_max_tokens_match_anthropic_limits() {
assert_eq!(default_max_tokens_for_model(CLAUDE_OPUS_4_7), Some(128_000));
assert_eq!(default_max_tokens_for_model(CLAUDE_OPUS_4_6), Some(128_000));
assert_eq!(
default_max_tokens_for_model(CLAUDE_SONNET_4_6),
Some(64_000)
);
assert_eq!(default_max_tokens_for_model(CLAUDE_HAIKU_4_5), Some(64_000));
}
// An unrecognised model has no table entry (`None`); the `_with_fallback`
// variant substitutes 2,048 for it.
#[test]
fn unknown_model_uses_conservative_default_max_tokens_fallback() {
assert_eq!(default_max_tokens_for_model("claude-unknown"), None);
assert_eq!(default_max_tokens_with_fallback("claude-unknown"), 2_048);
}
// Deserializes three wire-format messages and checks every content block:
// a plain-string assistant message, an assistant text + tool_use array, and a
// user image + text + tool_result array.
#[test]
fn test_deserialize_message() {
let assistant_message_json = r#"
{
"role": "assistant",
"content": "\n\nHello there, how may I assist you today?"
}
"#;
let assistant_message_json2 = r#"
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "\n\nHello there, how may I assist you today?"
},
{
"type": "tool_use",
"id": "toolu_01A09q90qw90lq917835lq9",
"name": "get_weather",
"input": {"location": "San Francisco, CA"}
}
]
}
"#;
let user_message_json = r#"
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "/9j/4AAQSkZJRg..."
}
},
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "tool_result",
"tool_use_id": "toolu_01A09q90qw90lq917835lq9",
"content": "15 degrees"
}
]
}
"#;
// Each fixture is parsed via `serde_path_to_error` so a failure names the offending path.
let assistant_message: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
deserialize(jd).unwrap_or_else(|err| {
panic!("Deserialization error at {}: {}", err.path(), err);
})
};
let assistant_message2: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
deserialize(jd).unwrap_or_else(|err| {
panic!("Deserialization error at {}: {}", err.path(), err);
})
};
let user_message: Message = {
let jd = &mut serde_json::Deserializer::from_str(user_message_json);
deserialize(jd).unwrap_or_else(|err| {
panic!("Deserialization error at {}: {}", err.path(), err);
})
};
// A plain-string `content` becomes a single `Content::Text` with no cache control.
let Message { role, content } = assistant_message;
assert_eq!(role, Role::Assistant);
assert_eq!(
content.first(),
Content::Text {
text: "\n\nHello there, how may I assist you today?".to_owned(),
cache_control: None,
}
);
// Array content keeps block order: text first, then tool_use.
let Message { role, content } = assistant_message2;
{
assert_eq!(role, Role::Assistant);
assert_eq!(content.len(), 2);
let mut iter = content.into_iter();
match iter.next().unwrap() {
Content::Text { text, .. } => {
assert_eq!(text, "\n\nHello there, how may I assist you today?");
}
_ => panic!("Expected text content"),
}
match iter.next().unwrap() {
Content::ToolUse { id, name, input } => {
assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
assert_eq!(name, "get_weather");
assert_eq!(input, json!({"location": "San Francisco, CA"}));
}
_ => panic!("Expected tool use content"),
}
assert_eq!(iter.next(), None);
}
// User content: base64 image, text, then a tool_result whose string content
// becomes one `ToolResultContent::Text` and whose absent `is_error` is `None`.
let Message { role, content } = user_message;
{
assert_eq!(role, Role::User);
assert_eq!(content.len(), 3);
let mut iter = content.into_iter();
match iter.next().unwrap() {
Content::Image { source, .. } => {
assert_eq!(
source,
ImageSource::Base64 {
data: "/9j/4AAQSkZJRg...".to_owned(),
media_type: ImageFormat::JPEG,
}
);
}
_ => panic!("Expected image content"),
}
match iter.next().unwrap() {
Content::Text { text, .. } => {
assert_eq!(text, "What is in this image?");
}
_ => panic!("Expected text content"),
}
match iter.next().unwrap() {
Content::ToolResult {
tool_use_id,
content,
is_error,
..
} => {
assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
assert_eq!(
content.first(),
ToolResultContent::Text {
text: "15 degrees".to_owned()
}
);
assert_eq!(is_error, None);
}
_ => panic!("Expected tool result content"),
}
assert_eq!(iter.next(), None);
}
}
// Converts Anthropic messages (user with image/text/document, assistant
// tool_use, user tool_result) into rig `message::Message`, checks the rig
// form, then converts back and asserts equality with the originals.
#[test]
fn test_message_to_message_conversion() {
let user_message: Message = serde_json::from_str(
r#"
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "/9j/4AAQSkZJRg..."
}
},
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "document",
"source": {
"type": "base64",
"data": "base64_encoded_pdf_data",
"media_type": "application/pdf"
}
}
]
}
"#,
)
.unwrap();
let assistant_message = Message {
role: Role::Assistant,
content: OneOrMany::one(Content::ToolUse {
id: "toolu_01A09q90qw90lq917835lq9".to_string(),
name: "get_weather".to_string(),
input: json!({"location": "San Francisco, CA"}),
}),
};
let tool_message = Message {
role: Role::User,
content: OneOrMany::one(Content::ToolResult {
tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
content: OneOrMany::one(ToolResultContent::Text {
text: "15 degrees".to_string(),
}),
is_error: None,
cache_control: None,
}),
};
let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
let converted_assistant_message: message::Message =
assistant_message.clone().try_into().unwrap();
let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
match converted_user_message.clone() {
message::Message::User { content } => {
assert_eq!(content.len(), 3);
let mut iter = content.into_iter();
match iter.next().unwrap() {
message::UserContent::Image(message::Image {
data, media_type, ..
}) => {
assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
}
_ => panic!("Expected image content"),
}
match iter.next().unwrap() {
message::UserContent::Text(message::Text { text }) => {
assert_eq!(text, "What is in this image?");
}
_ => panic!("Expected text content"),
}
match iter.next().unwrap() {
message::UserContent::Document(message::Document {
data, media_type, ..
}) => {
// NOTE(review): the base64 PDF source surfaces as `DocumentSourceKind::String`
// (not `Base64`); this pins current conversion behaviour — confirm it is intended.
assert_eq!(
data,
DocumentSourceKind::String("base64_encoded_pdf_data".into())
);
assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
}
_ => panic!("Expected document content"),
}
assert_eq!(iter.next(), None);
}
_ => panic!("Expected user message"),
}
match converted_tool_message.clone() {
message::Message::User { content } => {
let message::ToolResult { id, content, .. } = match content.first() {
message::UserContent::ToolResult(tool_result) => tool_result,
_ => panic!("Expected tool result content"),
};
assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
match content.first() {
message::ToolResultContent::Text(message::Text { text }) => {
assert_eq!(text, "15 degrees");
}
_ => panic!("Expected text content"),
}
}
_ => panic!("Expected tool result content"),
}
match converted_assistant_message.clone() {
message::Message::Assistant { content, .. } => {
assert_eq!(content.len(), 1);
match content.first() {
message::AssistantContent::ToolCall(message::ToolCall {
id, function, ..
}) => {
assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
assert_eq!(function.name, "get_weather");
assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
}
_ => panic!("Expected tool call content"),
}
}
_ => panic!("Expected assistant message"),
}
// Converting back must reproduce the original Anthropic messages exactly.
let original_user_message: Message = converted_user_message.try_into().unwrap();
let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
let original_tool_message: Message = converted_tool_message.try_into().unwrap();
assert_eq!(user_message, original_user_message);
assert_eq!(assistant_message, original_assistant_message);
assert_eq!(tool_message, original_tool_message);
}
// Each `ContentFormat` maps to a `SourceType` and back to itself:
// Url <-> URL, Base64 <-> BASE64, String <-> TEXT.
#[test]
fn test_content_format_conversion() {
use crate::completion::message::ContentFormat;
let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
assert_eq!(source_type, SourceType::URL);
let content_format: ContentFormat = SourceType::URL.into();
assert_eq!(content_format, ContentFormat::Url);
let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
assert_eq!(source_type, SourceType::BASE64);
let content_format: ContentFormat = SourceType::BASE64.into();
assert_eq!(content_format, ContentFormat::Base64);
let source_type: SourceType = ContentFormat::String.try_into().unwrap();
assert_eq!(source_type, SourceType::TEXT);
let content_format: ContentFormat = SourceType::TEXT.into();
assert_eq!(content_format, ContentFormat::String);
}
// `cache_control` serializes as `{"type":"ephemeral"}` when set and is omitted
// when `None`. `apply_cache_control` then marks the system prompt and the
// second (final) message here, but leaves the first message untouched.
#[test]
fn test_cache_control_serialization() {
let system = SystemContent::Text {
text: "You are a helpful assistant.".to_string(),
cache_control: Some(CacheControl::ephemeral()),
};
let json = serde_json::to_string(&system).unwrap();
assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
assert!(json.contains(r#""type":"text""#));
let system_no_cache = SystemContent::Text {
text: "Hello".to_string(),
cache_control: None,
};
let json_no_cache = serde_json::to_string(&system_no_cache).unwrap();
assert!(!json_no_cache.contains("cache_control"));
let content = Content::Text {
text: "Test message".to_string(),
cache_control: Some(CacheControl::ephemeral()),
};
let json_content = serde_json::to_string(&content).unwrap();
assert!(json_content.contains(r#""cache_control":{"type":"ephemeral"}"#));
let mut system_vec = vec![SystemContent::Text {
text: "System prompt".to_string(),
cache_control: None,
}];
let mut messages = vec![
Message {
role: Role::User,
content: OneOrMany::one(Content::Text {
text: "First message".to_string(),
cache_control: None,
}),
},
Message {
role: Role::Assistant,
content: OneOrMany::one(Content::Text {
text: "Response".to_string(),
cache_control: None,
}),
},
];
apply_cache_control(&mut system_vec, &mut messages);
match &system_vec[0] {
SystemContent::Text { cache_control, .. } => {
assert!(cache_control.is_some());
}
}
for content in messages[0].content.iter() {
if let Content::Text { cache_control, .. } = content {
assert!(cache_control.is_none());
}
}
for content in messages[1].content.iter() {
if let Content::Text { cache_control, .. } = content {
assert!(cache_control.is_some());
}
}
}
// A plain-text document serializes with source type `text` and media type `text/plain`.
#[test]
fn test_plaintext_document_serialization() {
let content = Content::Document {
source: DocumentSource::Text {
data: "Hello, world!".to_string(),
media_type: PlainTextMediaType::Plain,
},
cache_control: None,
};
let json = serde_json::to_value(&content).unwrap();
assert_eq!(json["type"], "document");
assert_eq!(json["source"]["type"], "text");
assert_eq!(json["source"]["media_type"], "text/plain");
assert_eq!(json["source"]["data"], "Hello, world!");
}
// The inverse: a `text` document source parses to `DocumentSource::Text`, with no cache control.
#[test]
fn test_plaintext_document_deserialization() {
let json = r#"
{
"type": "document",
"source": {
"type": "text",
"media_type": "text/plain",
"data": "Hello, world!"
}
}
"#;
let content: Content = serde_json::from_str(json).unwrap();
match content {
Content::Document {
source,
cache_control,
} => {
assert_eq!(
source,
DocumentSource::Text {
data: "Hello, world!".to_string(),
media_type: PlainTextMediaType::Plain,
}
);
assert_eq!(cache_control, None);
}
_ => panic!("Expected Document content"),
}
}
// A base64 PDF document serializes with source type `base64` and media type `application/pdf`.
#[test]
fn test_base64_pdf_document_serialization() {
let content = Content::Document {
source: DocumentSource::Base64 {
data: "base64data".to_string(),
media_type: DocumentFormat::PDF,
},
cache_control: None,
};
let json = serde_json::to_value(&content).unwrap();
assert_eq!(json["type"], "document");
assert_eq!(json["source"]["type"], "base64");
assert_eq!(json["source"]["media_type"], "application/pdf");
assert_eq!(json["source"]["data"], "base64data");
}
// The inverse: a `base64` PDF source parses to `DocumentSource::Base64`.
#[test]
fn test_base64_pdf_document_deserialization() {
let json = r#"
{
"type": "document",
"source": {
"type": "base64",
"media_type": "application/pdf",
"data": "base64data"
}
}
"#;
let content: Content = serde_json::from_str(json).unwrap();
match content {
Content::Document { source, .. } => {
assert_eq!(
source,
DocumentSource::Base64 {
data: "base64data".to_string(),
media_type: DocumentFormat::PDF,
}
);
}
_ => panic!("Expected Document content"),
}
}
// A rig document tagged `TXT` converts to an Anthropic `DocumentSource::Text`.
#[test]
fn test_plaintext_rig_to_anthropic_conversion() {
use crate::completion::message as msg;
let rig_message = msg::Message::User {
content: OneOrMany::one(msg::UserContent::document(
"Some plain text content".to_string(),
Some(msg::DocumentMediaType::TXT),
)),
};
let anthropic_message: Message = rig_message.try_into().unwrap();
assert_eq!(anthropic_message.role, Role::User);
let mut iter = anthropic_message.content.into_iter();
match iter.next().unwrap() {
Content::Document { source, .. } => {
assert_eq!(
source,
DocumentSource::Text {
data: "Some plain text content".to_string(),
media_type: PlainTextMediaType::Plain,
}
);
}
other => panic!("Expected Document content, got: {other:?}"),
}
}
// An Anthropic text document converts to a rig document with
// `DocumentSourceKind::String` data and media type `TXT`.
#[test]
fn test_plaintext_anthropic_to_rig_conversion() {
use crate::completion::message as msg;
let anthropic_message = Message {
role: Role::User,
content: OneOrMany::one(Content::Document {
source: DocumentSource::Text {
data: "Some plain text content".to_string(),
media_type: PlainTextMediaType::Plain,
},
cache_control: None,
}),
};
let rig_message: msg::Message = anthropic_message.try_into().unwrap();
match rig_message {
msg::Message::User { content } => {
let mut iter = content.into_iter();
match iter.next().unwrap() {
msg::UserContent::Document(msg::Document {
data, media_type, ..
}) => {
assert_eq!(
data,
DocumentSourceKind::String("Some plain text content".into())
);
assert_eq!(media_type, Some(msg::DocumentMediaType::TXT));
}
other => panic!("Expected Document content, got: {other:?}"),
}
}
_ => panic!("Expected User message"),
}
}
// rig -> Anthropic -> rig preserves the document media type. Only the media
// type is compared here; the data payload is not asserted.
#[test]
fn test_plaintext_roundtrip_rig_to_anthropic_and_back() {
use crate::completion::message as msg;
let original = msg::Message::User {
content: OneOrMany::one(msg::UserContent::document(
"Round trip text".to_string(),
Some(msg::DocumentMediaType::TXT),
)),
};
let anthropic: Message = original.clone().try_into().unwrap();
let back: msg::Message = anthropic.try_into().unwrap();
match (&original, &back) {
(
msg::Message::User {
content: orig_content,
},
msg::Message::User {
content: back_content,
},
) => match (orig_content.first(), back_content.first()) {
(
msg::UserContent::Document(msg::Document {
media_type: orig_mt,
..
}),
msg::UserContent::Document(msg::Document {
media_type: back_mt,
..
}),
) => {
assert_eq!(orig_mt, back_mt);
}
_ => panic!("Expected Document content in both"),
},
_ => panic!("Expected User messages"),
}
}
// Converting a document with an unsupported media type (HTML) fails with a
// message naming the supported kinds.
#[test]
fn test_unsupported_document_type_returns_error() {
use crate::completion::message as msg;
let rig_message = msg::Message::User {
content: OneOrMany::one(msg::UserContent::Document(msg::Document {
data: DocumentSourceKind::String("data".into()),
media_type: Some(msg::DocumentMediaType::HTML),
additional_params: None,
})),
};
let result: Result<Message, _> = rig_message.try_into();
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Anthropic only supports PDF and plain text documents"),
"Unexpected error: {err}"
);
}
// A plain-text (TXT) document whose data is a URL is rejected; only string or
// base64 data is accepted for that media type.
#[test]
fn test_plaintext_document_url_source_returns_error() {
use crate::completion::message as msg;
let rig_message = msg::Message::User {
content: OneOrMany::one(msg::UserContent::Document(msg::Document {
data: DocumentSourceKind::Url("https://example.com/doc.txt".into()),
media_type: Some(msg::DocumentMediaType::TXT),
additional_params: None,
})),
};
let result: Result<Message, _> = rig_message.try_into();
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Only string or base64 data is supported for plain text documents"),
"Unexpected error: {err}"
);
}
// `cache_control` on a text document serializes alongside its `text` source.
#[test]
fn test_plaintext_document_with_cache_control() {
let content = Content::Document {
source: DocumentSource::Text {
data: "cached text".to_string(),
media_type: PlainTextMediaType::Plain,
},
cache_control: Some(CacheControl::ephemeral()),
};
let json = serde_json::to_value(&content).unwrap();
assert_eq!(json["source"]["type"], "text");
assert_eq!(json["source"]["media_type"], "text/plain");
assert_eq!(json["cache_control"]["type"], "ephemeral");
}
// A full user message mixing a text document and a text block parses both, in order.
#[test]
fn test_message_with_plaintext_document_deserialization() {
let json = r#"
{
"role": "user",
"content": [
{
"type": "document",
"source": {
"type": "text",
"media_type": "text/plain",
"data": "Hello from a text file"
}
},
{
"type": "text",
"text": "Summarize this document."
}
]
}
"#;
let message: Message = serde_json::from_str(json).unwrap();
assert_eq!(message.role, Role::User);
assert_eq!(message.content.len(), 2);
let mut iter = message.content.into_iter();
match iter.next().unwrap() {
Content::Document { source, .. } => {
assert_eq!(
source,
DocumentSource::Text {
data: "Hello from a text file".to_string(),
media_type: PlainTextMediaType::Plain,
}
);
}
_ => panic!("Expected Document content"),
}
match iter.next().unwrap() {
Content::Text { text, .. } => {
assert_eq!(text, "Summarize this document.");
}
_ => panic!("Expected Text content"),
}
}
// A multi-part rig `Reasoning` expands to one Anthropic block per part, in order:
// - signed text becomes `Thinking` with its signature;
// - a summary becomes `Thinking` without a signature;
// - redacted data becomes `RedactedThinking`.
#[test]
fn test_assistant_reasoning_multiblock_to_anthropic_content() {
let reasoning = message::Reasoning {
id: None,
content: vec![
message::ReasoningContent::Text {
text: "step one".to_string(),
signature: Some("sig-1".to_string()),
},
message::ReasoningContent::Summary("summary".to_string()),
message::ReasoningContent::Text {
text: "step two".to_string(),
signature: Some("sig-2".to_string()),
},
message::ReasoningContent::Redacted {
data: "redacted block".to_string(),
},
],
};
let msg = message::Message::Assistant {
id: None,
content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
};
let converted: Message = msg.try_into().expect("convert assistant message");
let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();
assert_eq!(converted.role, Role::Assistant);
assert_eq!(converted_content.len(), 4);
assert!(matches!(
converted_content.first(),
Some(Content::Thinking { thinking, signature: Some(signature) })
if thinking == "step one" && signature == "sig-1"
));
assert!(matches!(
converted_content.get(1),
Some(Content::Thinking { thinking, signature: None }) if thinking == "summary"
));
assert!(matches!(
converted_content.get(2),
Some(Content::Thinking { thinking, signature: Some(signature) })
if thinking == "step two" && signature == "sig-2"
));
assert!(matches!(
converted_content.get(3),
Some(Content::RedactedThinking { data }) if data == "redacted block"
));
}
// Anthropic `RedactedThinking` converts back to a rig `Reasoning` whose first
// item is `ReasoningContent::Redacted` carrying the same data.
#[test]
fn test_redacted_thinking_content_to_assistant_reasoning() {
let content = Content::RedactedThinking {
data: "opaque-redacted".to_string(),
};
let converted: message::AssistantContent =
content.try_into().expect("convert redacted thinking");
assert!(matches!(
converted,
message::AssistantContent::Reasoning(message::Reasoning { content, .. })
if matches!(
content.first(),
Some(message::ReasoningContent::Redacted { data }) if data == "opaque-redacted"
)
));
}
// rig `ReasoningContent::Encrypted` has no direct Anthropic counterpart and is
// sent as a single `RedactedThinking` block with the ciphertext as its data.
#[test]
fn test_assistant_encrypted_reasoning_maps_to_redacted_thinking() {
let reasoning = message::Reasoning {
id: None,
content: vec![message::ReasoningContent::Encrypted(
"ciphertext".to_string(),
)],
};
let msg = message::Message::Assistant {
id: None,
content: OneOrMany::one(message::AssistantContent::Reasoning(reasoning)),
};
let converted: Message = msg.try_into().expect("convert assistant message");
let converted_content = converted.content.iter().cloned().collect::<Vec<_>>();
assert_eq!(converted_content.len(), 1);
assert!(matches!(
converted_content.first(),
Some(Content::RedactedThinking { data }) if data == "ciphertext"
));
}
// An empty `content` with stop_reason `end_turn` converts to a single empty
// text choice rather than an error.
#[test]
fn empty_end_turn_response_normalizes_to_empty_text_choice() {
let response = CompletionResponse {
content: vec![],
id: "msg_123".to_string(),
model: CLAUDE_SONNET_4_6.to_string(),
role: "assistant".to_string(),
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
usage: Usage {
input_tokens: 7,
cache_read_input_tokens: None,
cache_creation_input_tokens: None,
output_tokens: 2,
},
};
let parsed: completion::CompletionResponse<CompletionResponse> = response
.try_into()
.expect("empty end_turn should not error");
assert_eq!(parsed.choice.len(), 1);
assert!(matches!(
parsed.choice.first(),
completion::AssistantContent::Text(text) if text.text.is_empty()
));
}
// An empty `content` with a stop reason other than `end_turn` (here `tool_use`)
// still fails with `ResponseError(EMPTY_RESPONSE_ERROR)`.
#[test]
fn empty_non_end_turn_response_still_errors() {
let response = CompletionResponse {
content: vec![],
id: "msg_123".to_string(),
model: CLAUDE_SONNET_4_6.to_string(),
role: "assistant".to_string(),
stop_reason: Some("tool_use".to_string()),
stop_sequence: None,
usage: Usage {
input_tokens: 7,
cache_read_input_tokens: None,
cache_creation_input_tokens: None,
output_tokens: 2,
},
};
let err = completion::CompletionResponse::<CompletionResponse>::try_from(response)
.expect_err("empty non-end_turn should remain an error");
assert!(matches!(
err,
CompletionError::ResponseError(message) if message == EMPTY_RESPONSE_ERROR
));
}
}