use crate::llm::retry;
use crate::llm::traits::AiProvider;
use crate::llm::types::{
ChatCompletionParams, Message, ProviderExchange, ProviderResponse, TokenUsage, ToolCall,
};
use crate::llm::utils::normalize_model_name;
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
/// Per-model pricing table: (model-name fragment, input USD per MTok, output USD per MTok).
///
/// Lookups match by substring in declaration order, so more specific fragments
/// (e.g. "claude-opus-4-1") must appear before their prefixes ("claude-opus-4").
const PRICING: &[(&str, f64, f64)] = &[
    ("claude-opus-4-5", 5.00, 25.00),
    ("claude-haiku-4-5", 1.00, 5.00),
    ("claude-sonnet-4-5", 3.00, 15.00),
    ("claude-opus-4-1", 15.00, 75.00),
    ("claude-opus-4-0", 15.00, 75.00),
    ("claude-opus-4", 15.00, 75.00),
    ("claude-sonnet-4-0", 3.00, 15.00),
    ("claude-sonnet-4", 3.00, 15.00),
    ("claude-3-7-sonnet", 3.00, 15.00),
    ("claude-3-5-sonnet", 3.00, 15.00),
    ("claude-3-5-haiku", 0.80, 4.00),
    ("claude-3-opus", 15.00, 75.00),
    ("claude-3-sonnet", 3.00, 15.00),
    ("claude-3-haiku", 0.25, 1.25),
];
/// Token counts broken out by cache billing category, consumed by
/// `calculate_cost_with_cache`.
struct CacheTokenUsage {
    /// Input tokens billed at the normal input rate (no cache involvement).
    regular_input_tokens: u64,
    /// Tokens written to the short-TTL cache (billed at 1.25x the input rate).
    cache_creation_tokens: u64,
    /// Tokens written to the 1-hour-TTL cache (billed at 2.0x the input rate).
    cache_creation_tokens_1h: u64,
    /// Tokens served from cache (billed at 0.1x the input rate).
    cache_read_tokens: u64,
    /// Completion (output) tokens.
    output_tokens: u64,
}
/// Returns `true` when the model accepts both `temperature` and `top_p` in the
/// same request. Newer model families reject requests that set both, so the
/// caller omits `top_p` for them.
fn supports_temperature_and_top_p(model: &str) -> bool {
    // Matched by substring against the (possibly aliased) model id.
    const UNSUPPORTED_FRAGMENTS: [&str; 4] = [
        "opus-4-1",
        "sonnet-4-5",
        "claude-haiku-4-5",
        "claude-opus-4-5",
    ];
    UNSUPPORTED_FRAGMENTS
        .iter()
        .all(|fragment| !model.contains(fragment))
}
/// Computes the total USD cost for a request, applying cache-tier multipliers
/// to the input price: 1.25x for short-TTL cache writes, 2.0x for 1-hour cache
/// writes, and 0.1x for cache reads.
///
/// Returns `None` when the model matches no entry in [`PRICING`]; the first
/// substring match wins, so table ordering matters.
fn calculate_cost_with_cache(model: &str, usage: CacheTokenUsage) -> Option<f64> {
    const TOKENS_PER_MTOK: f64 = 1_000_000.0;
    let &(_, input_price, output_price) = PRICING
        .iter()
        .find(|(fragment, _, _)| model.contains(*fragment))?;
    let regular_input_cost = (usage.regular_input_tokens as f64 / TOKENS_PER_MTOK) * input_price;
    let cache_creation_cost =
        (usage.cache_creation_tokens as f64 / TOKENS_PER_MTOK) * input_price * 1.25;
    let cache_creation_cost_1h =
        (usage.cache_creation_tokens_1h as f64 / TOKENS_PER_MTOK) * input_price * 2.0;
    let cache_read_cost = (usage.cache_read_tokens as f64 / TOKENS_PER_MTOK) * input_price * 0.1;
    let output_cost = (usage.output_tokens as f64 / TOKENS_PER_MTOK) * output_price;
    Some(
        regular_input_cost
            + cache_creation_cost
            + cache_creation_cost_1h
            + cache_read_cost
            + output_cost,
    )
}
/// Computes the request cost from the raw usage counters reported by the
/// Anthropic API, or `None` when the model has no pricing entry.
///
/// Anthropic's `input_tokens` total includes cache-write and cache-read
/// tokens; those are subtracted out so each bucket is billed exactly once.
fn calculate_anthropic_cost(
    model: &str,
    input_tokens: u32,
    output_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
) -> Option<f64> {
    // Chained saturating subtractions: the previous `a.saturating_sub(b + c)`
    // form could overflow u32 in the `b + c` addition itself. Saturating also
    // guards against malformed payloads where cache counters exceed the total.
    let regular_input_tokens = input_tokens
        .saturating_sub(cache_creation_input_tokens)
        .saturating_sub(cache_read_input_tokens);
    let usage = CacheTokenUsage {
        regular_input_tokens: regular_input_tokens as u64,
        // This call path bills all cache-creation tokens at the 1-hour-TTL
        // rate; the short-TTL bucket is intentionally left empty.
        cache_creation_tokens: 0,
        cache_creation_tokens_1h: cache_creation_input_tokens as u64,
        cache_read_tokens: cache_read_input_tokens as u64,
        output_tokens: output_tokens as u64,
    };
    calculate_cost_with_cache(model, usage)
}
/// Stateless provider for Anthropic's Messages API. Credentials and endpoint
/// are resolved from the environment at request time, so the struct itself
/// carries no data.
#[derive(Debug, Clone)]
pub struct AnthropicProvider;

impl AnthropicProvider {
    /// Creates a new provider instance.
    pub fn new() -> Self {
        AnthropicProvider
    }
}

impl Default for AnthropicProvider {
    fn default() -> Self {
        AnthropicProvider::new()
    }
}
/// Environment variable holding the Anthropic API key (used as the `x-api-key` header).
const ANTHROPIC_API_KEY_ENV: &str = "ANTHROPIC_API_KEY";
/// Environment variable holding an OAuth access token; when set, it takes
/// priority over the API key (sent as a `Bearer` Authorization header).
const ANTHROPIC_OAUTH_TOKEN_ENV: &str = "ANTHROPIC_OAUTH_ACCESS_TOKEN";
/// Optional environment override for the Messages API endpoint URL.
const ANTHROPIC_API_URL_ENV: &str = "ANTHROPIC_API_URL";
/// Default Messages API endpoint.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
#[async_trait::async_trait]
impl AiProvider for AnthropicProvider {
fn name(&self) -> &str {
"anthropic"
}
fn supports_model(&self, model: &str) -> bool {
let normalized = normalize_model_name(model);
normalized.starts_with("claude-") || normalized.contains("claude")
}
fn get_api_key(&self) -> Result<String> {
if env::var(ANTHROPIC_OAUTH_TOKEN_ENV).is_ok() {
return Err(anyhow::anyhow!(
"Using OAuth authentication. API key not available when {} is set.",
ANTHROPIC_OAUTH_TOKEN_ENV
));
}
match env::var(ANTHROPIC_API_KEY_ENV) {
Ok(key) => Ok(key),
Err(_) => Err(anyhow::anyhow!(
"Anthropic API key not found in environment variable: {}. Set either {} for API key auth or {} for OAuth.",
ANTHROPIC_API_KEY_ENV,
ANTHROPIC_API_KEY_ENV,
ANTHROPIC_OAUTH_TOKEN_ENV
)),
}
}
fn supports_caching(&self, _model: &str) -> bool {
true
}
fn supports_vision(&self, model: &str) -> bool {
let model_lower = normalize_model_name(model);
model_lower.contains("claude-3")
|| model_lower.contains("claude-4")
|| model_lower.contains("claude-3.5")
|| model_lower.contains("claude-3.7")
}
fn get_max_input_tokens(&self, model: &str) -> usize {
let model_lower = normalize_model_name(model);
if model_lower.contains("claude-opus-4")
|| model_lower.contains("claude-sonnet-4")
|| model_lower.contains("claude-haiku-4")
{
200_000
} else if model_lower.contains("claude-3-7") {
200_000
} else if model_lower.contains("claude-3-5") {
200_000
} else if model_lower.contains("claude-3") {
200_000
} else {
100_000
}
}
async fn chat_completion(&self, params: ChatCompletionParams) -> Result<ProviderResponse> {
let (auth_header_name, auth_header_value) =
if let Ok(oauth_token) = env::var(ANTHROPIC_OAUTH_TOKEN_ENV) {
(
"Authorization".to_string(),
format!("Bearer {}", oauth_token),
)
} else {
let api_key = self.get_api_key()?;
("x-api-key".to_string(), api_key)
};
let anthropic_messages = convert_messages(¶ms.messages);
let system_message = params
.messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.clone())
.unwrap_or_else(|| "You are a helpful assistant.".to_string());
let system_cached = params
.messages
.iter()
.any(|m| m.role == "system" && m.cached);
let mut request_body = serde_json::json!({
"model": params.model,
"messages": anthropic_messages,
});
request_body["temperature"] = serde_json::json!(params.temperature);
if supports_temperature_and_top_p(¶ms.model) {
request_body["top_p"] = serde_json::json!(params.top_p);
}
request_body["top_k"] = serde_json::json!(params.top_k);
if params.max_tokens > 0 {
request_body["max_tokens"] = serde_json::json!(params.max_tokens);
}
if system_cached {
let cache_ttl = crate::llm::config::CacheTTL::short();
request_body["system"] = serde_json::json!([{
"type": "text",
"text": system_message,
"cache_control": {
"type": "ephemeral",
"ttl": cache_ttl.to_string()
}
}]);
} else {
request_body["system"] = serde_json::json!(system_message);
}
if let Some(tools) = ¶ms.tools {
if !tools.is_empty() {
let mut sorted_tools = tools.clone();
sorted_tools.sort_by(|a, b| a.name.cmp(&b.name));
let anthropic_tools = sorted_tools
.iter()
.map(|f| {
let mut tool = serde_json::json!({
"name": f.name,
"description": f.description,
"input_schema": f.parameters
});
if let Some(ref cache_control) = f.cache_control {
tool["cache_control"] = cache_control.clone();
}
tool
})
.collect::<Vec<_>>();
request_body["tools"] = serde_json::json!(anthropic_tools);
}
}
let api_url =
env::var(ANTHROPIC_API_URL_ENV).unwrap_or_else(|_| ANTHROPIC_API_URL.to_string());
let response = execute_anthropic_request(
auth_header_name,
auth_header_value,
api_url,
request_body,
params.max_retries,
params.retry_timeout,
params.cancellation_token.as_ref(),
)
.await?;
Ok(response)
}
}
/// A single message in Anthropic's Messages API wire format.
#[derive(Serialize, Deserialize, Debug)]
struct AnthropicMessage {
    /// "user" or "assistant"; system content travels in the request's
    /// top-level `system` field instead.
    role: String,
    /// Ordered content blocks (text, images, tool use/results).
    content: Vec<AnthropicContent>,
}
/// A request content block, tagged by `"type"` on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
enum AnthropicContent {
    /// Plain text, optionally marked for prompt caching.
    #[serde(rename = "text")]
    Text {
        text: String,
        // e.g. {"type": "ephemeral"}; omitted from JSON when None.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<serde_json::Value>,
    },
    /// An inline image attachment.
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<serde_json::Value>,
    },
    /// The result of a previously issued tool call, echoed back to the model.
    #[serde(rename = "tool_result")]
    ToolResult {
        // Must match the `id` of the corresponding ToolUse block.
        tool_use_id: String,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<serde_json::Value>,
    },
    /// A tool invocation emitted by the assistant in a prior turn.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
/// Payload descriptor for an image content block.
#[derive(Serialize, Deserialize, Debug)]
struct ImageSource {
    // Serialized as "type"; this codebase only produces "base64".
    #[serde(rename = "type")]
    source_type: String,
    /// MIME type of the image, e.g. "image/png".
    media_type: String,
    /// Base64-encoded image bytes.
    data: String,
}
/// Top-level Messages API response body (only the fields this client reads).
#[derive(Deserialize, Debug)]
struct AnthropicResponse {
    /// Ordered output blocks: text and/or tool-use requests.
    content: Vec<AnthropicResponseContent>,
    usage: AnthropicUsage,
    // Why the model stopped, e.g. end of turn or tool use; absent in some
    // responses, hence the default.
    #[serde(default)]
    stop_reason: Option<String>,
}
/// A response content block, tagged by `"type"` on the wire.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum AnthropicResponseContent {
    /// A chunk of generated text.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        /// Tool arguments as free-form JSON matching the tool's input schema.
        input: serde_json::Value,
    },
}
/// Token accounting returned with every response.
#[derive(Deserialize, Debug)]
struct AnthropicUsage {
    input_tokens: u64,
    output_tokens: u64,
    // Cache counters are only present when prompt caching was in effect,
    // hence Option + default.
    #[serde(default)]
    cache_creation_input_tokens: Option<u64>,
    #[serde(default)]
    cache_read_input_tokens: Option<u64>,
}
/// Builds the `cache_control` marker for a content block, or `None` when the
/// message is not flagged for prompt caching.
fn ephemeral_cache_control(cached: bool) -> Option<serde_json::Value> {
    cached.then(|| serde_json::json!({"type": "ephemeral"}))
}

/// Converts internal chat messages into Anthropic's content-block format.
///
/// System messages are skipped (they are sent via the request's top-level
/// `system` field), tool results become `tool_result` blocks on a
/// user-role message, assistant tool calls expand into `tool_use` blocks, and
/// base64 images become `image` blocks.
fn convert_messages(messages: &[Message]) -> Vec<AnthropicMessage> {
    let mut result = Vec::new();
    for message in messages {
        match message.role.as_str() {
            // Handled separately via the request's `system` field.
            "system" => {}
            "tool" => {
                let tool_call_id = message.tool_call_id.as_deref().unwrap_or("");
                let content = vec![AnthropicContent::ToolResult {
                    tool_use_id: tool_call_id.to_string(),
                    content: message.content.clone(),
                    cache_control: ephemeral_cache_control(message.cached),
                }];
                // Anthropic expects tool results on a user-role message.
                result.push(AnthropicMessage {
                    role: "user".to_string(),
                    content,
                });
            }
            _ => {
                if message.role == "assistant" && message.tool_calls.is_some() {
                    let mut content = Vec::new();
                    // Skip empty assistant text so no empty text block is sent.
                    if !message.content.trim().is_empty() {
                        content.push(AnthropicContent::Text {
                            text: message.content.clone(),
                            cache_control: ephemeral_cache_control(message.cached),
                        });
                    }
                    if let Some(ref tool_calls_data) = message.tool_calls {
                        // Tool calls are stored as provider-agnostic JSON;
                        // decode failures are skipped (best effort, as before).
                        if let Ok(generic_calls) = serde_json::from_value::<
                            Vec<crate::llm::tool_calls::GenericToolCall>,
                        >(tool_calls_data.clone())
                        {
                            for call in generic_calls {
                                content.push(AnthropicContent::ToolUse {
                                    id: call.id,
                                    name: call.name,
                                    input: call.arguments,
                                });
                            }
                        }
                    }
                    result.push(AnthropicMessage {
                        role: message.role.clone(),
                        content,
                    });
                } else {
                    let mut content = vec![AnthropicContent::Text {
                        text: message.content.clone(),
                        cache_control: ephemeral_cache_control(message.cached),
                    }];
                    if let Some(images) = &message.images {
                        for image in images {
                            // Only base64-encoded image payloads are supported
                            // here; other variants are ignored.
                            if let crate::llm::types::ImageData::Base64(data) = &image.data {
                                content.push(AnthropicContent::Image {
                                    source: ImageSource {
                                        source_type: "base64".to_string(),
                                        media_type: image.media_type.clone(),
                                        data: data.clone(),
                                    },
                                    cache_control: None,
                                });
                            }
                        }
                    }
                    result.push(AnthropicMessage {
                        role: message.role.clone(),
                        content,
                    });
                }
            }
        }
    }
    result
}
/// Sends a Messages API request with exponential-backoff retries and converts
/// the HTTP response into a [`ProviderResponse`].
///
/// Rate-limit headers are captured into the exchange metadata when present.
///
/// # Errors
/// Fails when the request cannot be sent after `max_retries`, the API returns
/// a non-success status, or the response body cannot be parsed.
async fn execute_anthropic_request(
    auth_header_name: String,
    auth_header_value: String,
    api_url: String,
    request_body: serde_json::Value,
    max_retries: u32,
    base_timeout: std::time::Duration,
    cancellation_token: Option<&tokio::sync::watch::Receiver<bool>>,
) -> Result<ProviderResponse> {
    let client = Client::new();
    let start_time = std::time::Instant::now();
    let response = retry::retry_with_exponential_backoff(
        || {
            // Clone per attempt: the retry helper may invoke this factory
            // several times and each future must own its inputs.
            let client = client.clone();
            let auth_header_name = auth_header_name.clone();
            let auth_header_value = auth_header_value.clone();
            let api_url = api_url.clone();
            let request_body = request_body.clone();
            Box::pin(async move {
                client
                    .post(&api_url)
                    .header("Content-Type", "application/json")
                    .header(&auth_header_name, &auth_header_value)
                    .header("anthropic-version", "2023-06-01")
                    // Opt in to the prompt-caching beta feature.
                    .header("anthropic-beta", "prompt-caching-2024-07-31")
                    .json(&request_body)
                    .send()
                    .await
            })
        },
        max_retries,
        base_timeout,
        cancellation_token,
    )
    .await?;
    let request_time_ms = start_time.elapsed().as_millis() as u64;

    // Capture rate-limit headers; table-driven to avoid six copies of the
    // same extraction stanza.
    const RATE_LIMIT_HEADERS: [(&str, &str); 6] = [
        ("anthropic-ratelimit-tokens-limit", "tokens_limit"),
        ("anthropic-ratelimit-tokens-remaining", "tokens_remaining"),
        ("anthropic-ratelimit-input-tokens-limit", "input_tokens_limit"),
        (
            "anthropic-ratelimit-input-tokens-remaining",
            "input_tokens_remaining",
        ),
        (
            "anthropic-ratelimit-output-tokens-limit",
            "output_tokens_limit",
        ),
        (
            "anthropic-ratelimit-output-tokens-remaining",
            "output_tokens_remaining",
        ),
    ];
    let mut rate_limit_headers = std::collections::HashMap::new();
    let headers = response.headers();
    for (header_name, key) in RATE_LIMIT_HEADERS {
        if let Some(value) = headers.get(header_name).and_then(|h| h.to_str().ok()) {
            rate_limit_headers.insert(key.to_string(), value.to_string());
        }
    }

    if !response.status().is_success() {
        let status = response.status();
        // Best-effort body read: the error text is informational only.
        let error_text = response.text().await.unwrap_or_default();
        return Err(anyhow::anyhow!(
            "Anthropic API error {}: {}",
            status,
            error_text
        ));
    }
    let response_text = response.text().await?;
    let anthropic_response: AnthropicResponse = serde_json::from_str(&response_text)?;

    // Split response blocks into plain text and tool invocations.
    let mut content_parts = Vec::new();
    let mut tool_calls = Vec::new();
    for content in anthropic_response.content {
        match content {
            AnthropicResponseContent::Text { text } => content_parts.push(text),
            AnthropicResponseContent::ToolUse { id, name, input } => {
                // Move the owned fields directly; the previous clones were
                // redundant since the originals were immediately dropped.
                tool_calls.push(ToolCall {
                    id,
                    name,
                    arguments: input,
                });
            }
        }
    }
    let content = content_parts.join("\n");

    let cached_tokens = anthropic_response
        .usage
        .cache_read_input_tokens
        .unwrap_or(0);
    let cache_creation_tokens = anthropic_response
        .usage
        .cache_creation_input_tokens
        .unwrap_or(0);
    let cost = calculate_anthropic_cost(
        request_body["model"].as_str().unwrap_or(""),
        anthropic_response.usage.input_tokens as u32,
        anthropic_response.usage.output_tokens as u32,
        cache_creation_tokens as u32,
        cached_tokens as u32,
    );
    let usage = TokenUsage {
        prompt_tokens: anthropic_response.usage.input_tokens,
        output_tokens: anthropic_response.usage.output_tokens,
        reasoning_tokens: 0,
        total_tokens: anthropic_response.usage.input_tokens
            + anthropic_response.usage.output_tokens,
        cached_tokens,
        cost,
        request_time_ms: Some(request_time_ms),
    };
    // Keep the raw response JSON for the exchange record, mirroring tool
    // calls into it in provider-agnostic form.
    let mut response_json: serde_json::Value = serde_json::from_str(&response_text)?;
    if !tool_calls.is_empty() {
        let generic_calls: Vec<crate::llm::tool_calls::GenericToolCall> = tool_calls
            .iter()
            .map(|tc| crate::llm::tool_calls::GenericToolCall {
                id: tc.id.clone(),
                name: tc.name.clone(),
                arguments: tc.arguments.clone(),
                meta: None,
            })
            .collect();
        response_json["tool_calls"] = serde_json::to_value(&generic_calls).unwrap_or_default();
    }
    let exchange = if rate_limit_headers.is_empty() {
        ProviderExchange::new(request_body, response_json, Some(usage), "anthropic")
    } else {
        ProviderExchange::with_rate_limit_headers(
            request_body,
            response_json,
            Some(usage),
            "anthropic",
            rate_limit_headers,
        )
    };
    Ok(ProviderResponse {
        content,
        thinking: None,
        exchange,
        tool_calls: if tool_calls.is_empty() {
            None
        } else {
            Some(tool_calls)
        },
        finish_reason: anthropic_response.stop_reason,
        structured_output: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    // Tests that mutate process-wide environment variables are marked
    // #[serial] to prevent cross-test interference.

    #[test]
    #[serial]
    fn test_oauth_token_priority() {
        let provider = AnthropicProvider::new();
        env::set_var(ANTHROPIC_OAUTH_TOKEN_ENV, "test-oauth-token");
        let result = provider.get_api_key();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("OAuth authentication"));
        env::remove_var(ANTHROPIC_OAUTH_TOKEN_ENV);
    }

    #[test]
    #[serial]
    fn test_api_key_fallback() {
        let provider = AnthropicProvider::new();
        env::remove_var(ANTHROPIC_OAUTH_TOKEN_ENV);
        env::set_var(ANTHROPIC_API_KEY_ENV, "test-api-key");
        let result = provider.get_api_key();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "test-api-key");
        env::remove_var(ANTHROPIC_API_KEY_ENV);
    }

    #[test]
    #[serial]
    fn test_no_auth_error() {
        let provider = AnthropicProvider::new();
        env::remove_var(ANTHROPIC_OAUTH_TOKEN_ENV);
        env::remove_var(ANTHROPIC_API_KEY_ENV);
        let result = provider.get_api_key();
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        // Check against the real env-var name constants. The previous literal
        // "ANTHROPIC_OAUTH_TOKEN" could never match: the message names
        // ANTHROPIC_OAUTH_ACCESS_TOKEN, which does not contain that substring.
        assert!(
            error_msg.contains(ANTHROPIC_API_KEY_ENV)
                || error_msg.contains(ANTHROPIC_OAUTH_TOKEN_ENV)
        );
    }

    #[test]
    fn test_supports_model_case_insensitive() {
        let provider = AnthropicProvider::new();
        assert!(provider.supports_model("claude-3-haiku"));
        assert!(provider.supports_model("claude-3-5-sonnet"));
        assert!(provider.supports_model("CLAUDE-3-HAIKU"));
        assert!(provider.supports_model("CLAUDE-3-5-SONNET"));
        assert!(provider.supports_model("ClaUde-3-Haiku"));
        assert!(provider.supports_model("CLAUDE-3-7-sonnet"));
    }

    #[test]
    fn test_supports_vision_case_insensitive() {
        let provider = AnthropicProvider::new();
        assert!(provider.supports_vision("claude-3-haiku"));
        assert!(provider.supports_vision("claude-3-5-sonnet"));
        assert!(provider.supports_vision("CLAUDE-3-HAIKU"));
        assert!(provider.supports_vision("CLAUDE-3-5-SONNET"));
        assert!(provider.supports_vision("ClaUde-3-7"));
    }
}