use super::error::{ProviderError, Result};
use super::r#trait::{Provider, ProviderStream};
use super::types::*;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Messages API endpoint used for both blocking and streaming completions.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Model-listing endpoint queried by `fetch_models`.
const ANTHROPIC_MODELS_URL: &str = "https://api.anthropic.com/v1/models";
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Overall per-request timeout (five minutes) for the client built by `AnthropicProvider::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Upper bound on establishing a TCP/TLS connection.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// How long an idle pooled connection is kept before being closed.
const DEFAULT_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(90);
/// `Provider` implementation backed by Anthropic's Messages API.
///
/// There is no `Debug` derive; a derived one would print `api_key` verbatim.
#[derive(Clone)]
pub struct AnthropicProvider {
    /// Secret sent in the `x-api-key` header of every request.
    api_key: String,
    /// HTTP client used for all API calls.
    client: Client,
    /// When set, returned by `default_model` instead of the built-in default.
    custom_default_model: Option<String>,
}
impl AnthropicProvider {
pub fn new(api_key: String) -> Self {
let client = Client::builder()
.timeout(DEFAULT_TIMEOUT) .connect_timeout(DEFAULT_CONNECT_TIMEOUT) .pool_idle_timeout(DEFAULT_POOL_IDLE_TIMEOUT) .pool_max_idle_per_host(2) .build()
.expect("Failed to create HTTP client");
Self {
api_key,
client,
custom_default_model: None,
}
}
pub fn with_client(api_key: String, client: Client) -> Self {
Self {
api_key,
client,
custom_default_model: None,
}
}
pub fn with_default_model(mut self, model: String) -> Self {
self.custom_default_model = Some(model);
self
}
fn headers(&self) -> reqwest::header::HeaderMap {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"x-api-key",
self.api_key.parse().expect("Invalid API key format"),
);
headers.insert(
"anthropic-beta",
"prompt-caching-2024-07-31"
.parse()
.expect("Invalid beta header"),
);
headers.insert(
"anthropic-version",
ANTHROPIC_VERSION.parse().expect("Invalid version"),
);
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().expect("valid content-type"),
);
headers
}
fn request_headers(&self, _request: &LLMRequest) -> reqwest::header::HeaderMap {
self.headers()
}
fn to_anthropic_request(&self, request: LLMRequest) -> AnthropicRequest {
let cache = AnthropicCacheControl {
cache_type: "ephemeral".to_string(),
};
let system = request.system.map(|s| {
AnthropicSystem::Blocks(vec![AnthropicSystemBlock {
block_type: "text".to_string(),
text: s,
cache_control: Some(cache.clone()),
}])
});
let tools = request.tools.map(|tools| {
let len = tools.len();
tools
.into_iter()
.enumerate()
.map(|(i, t)| AnthropicTool {
name: t.name,
description: t.description,
input_schema: t.input_schema,
cache_control: if i == len - 1 {
Some(cache.clone())
} else {
None
},
})
.collect()
});
AnthropicRequest {
model: request.model,
messages: request.messages,
system,
max_tokens: request.max_tokens.unwrap_or(16384),
temperature: request.temperature,
tools,
stream: Some(request.stream),
metadata: request.metadata,
}
}
#[allow(clippy::wrong_self_convention)]
fn from_anthropic_response(&self, response: AnthropicResponse) -> LLMResponse {
LLMResponse {
id: response.id,
model: response.model,
content: response.content,
stop_reason: response.stop_reason,
usage: response.usage,
}
}
async fn handle_error(&self, response: reqwest::Response) -> ProviderError {
let status = response.status().as_u16();
let retry_after = response.headers().get("retry-after").and_then(|v| {
v.to_str().ok().and_then(|s| {
s.parse::<u64>().ok()
})
});
let body_bytes = response.bytes().await.unwrap_or_default();
tracing::debug!(
"Anthropic error response ({}): {}",
status,
String::from_utf8_lossy(&body_bytes)
.chars()
.take(500)
.collect::<String>()
);
if let Ok(error_body) = serde_json::from_slice::<AnthropicError>(&body_bytes) {
let message = if status == 429 {
if let Some(secs) = retry_after {
format!(
"{} (retry after {} seconds)",
error_body.error.message, secs
)
} else {
format!(
"{} (rate limited, please retry later)",
error_body.error.message
)
}
} else {
error_body.error.message
};
return if status == 429 {
ProviderError::RateLimitExceeded(message)
} else {
ProviderError::ApiError {
status,
message,
error_type: Some(error_body.error.error_type),
}
};
}
if status == 429 {
let message = if let Some(secs) = retry_after {
format!("Rate limit exceeded (retry after {} seconds)", secs)
} else {
"Rate limit exceeded, please retry later".to_string()
};
ProviderError::RateLimitExceeded(message)
} else {
ProviderError::ApiError {
status,
message: "Unknown error".to_string(),
error_type: None,
}
}
}
}
#[async_trait]
impl Provider for AnthropicProvider {
    /// Sends a non-streaming Messages API request and retries transient
    /// failures via `retry_with_backoff`.
    ///
    /// The outgoing body always carries `stream: false`. This path decodes a
    /// single JSON document, so an SSE reply (which `request.stream == true`
    /// would have requested) could never be parsed here.
    async fn complete(&self, request: LLMRequest) -> Result<LLMResponse> {
        use super::retry::{RetryConfig, retry_with_backoff};

        tracing::info!(
            "Anthropic API request: model={}, messages={}, max_tokens={}",
            request.model,
            request.messages.len(),
            request.max_tokens.unwrap_or(16384)
        );
        let req_headers = self.request_headers(&request);
        let mut anthropic_request = self.to_anthropic_request(request);
        anthropic_request.stream = Some(false);
        let retry_config = RetryConfig::default();
        let result = retry_with_backoff(
            || async {
                tracing::debug!("Sending request to Anthropic API");
                let response = self
                    .client
                    .post(ANTHROPIC_API_URL)
                    .headers(req_headers.clone())
                    .json(&anthropic_request)
                    .send()
                    .await?;
                let status = response.status();
                tracing::debug!("Anthropic API response status: {}", status);
                if !status.is_success() {
                    return Err(self.handle_error(response).await);
                }
                let anthropic_response: AnthropicResponse = response.json().await?;
                let llm_response = self.from_anthropic_response(anthropic_response);
                tracing::info!(
                    "Anthropic API response: input_tokens={}, output_tokens={}, stop_reason={:?}",
                    llm_response.usage.input_tokens,
                    llm_response.usage.output_tokens,
                    llm_response.stop_reason
                );
                Ok(llm_response)
            },
            &retry_config,
        )
        .await;
        if let Err(ref e) = result {
            tracing::error!("Anthropic API request failed: {}", e);
        }
        result
    }

    /// Opens a streaming Messages API request and yields parsed SSE `data:`
    /// events.
    ///
    /// Only establishing the connection is retried. Once bytes flow,
    /// transport errors are yielded as `ProviderError::StreamError` items. A
    /// chunk that completes no event yields a synthetic `StreamEvent::Ping`.
    /// Unparseable event JSON is logged and skipped.
    async fn stream(&self, request: LLMRequest) -> Result<ProviderStream> {
        use super::retry::{RetryConfig, retry_with_backoff};

        tracing::info!(
            "Anthropic streaming request: model={}, messages={}",
            request.model,
            request.messages.len()
        );
        let req_headers = self.request_headers(&request);
        let mut anthropic_request = self.to_anthropic_request(request);
        anthropic_request.stream = Some(true);
        let retry_config = RetryConfig::default();
        let response = retry_with_backoff(
            || async {
                let response = self
                    .client
                    .post(ANTHROPIC_API_URL)
                    .headers(req_headers.clone())
                    .json(&anthropic_request)
                    .send()
                    .await?;
                if !response.status().is_success() {
                    return Err(self.handle_error(response).await);
                }
                Ok(response)
            },
            &retry_config,
        )
        .await?;

        // Bytes received but not yet terminated by '\n'. They are kept as raw
        // bytes, not a String, so a multi-byte UTF-8 character split across two
        // network chunks is decoded only after both halves have arrived.
        // Decoding each chunk separately would turn it into U+FFFD.
        let mut pending: Vec<u8> = Vec::new();
        let event_stream = response
            .bytes_stream()
            .map(
                move |chunk_result| -> Vec<std::result::Result<StreamEvent, ProviderError>> {
                    let chunk = match chunk_result {
                        Ok(chunk) => chunk,
                        Err(e) => return vec![Err(ProviderError::StreamError(e.to_string()))],
                    };
                    pending.extend_from_slice(&chunk);

                    let mut events = Vec::new();
                    if let Some(last_newline) = pending.iter().rposition(|&b| b == b'\n') {
                        // Everything after the final newline is an incomplete line; keep it.
                        let incomplete = pending.split_off(last_newline + 1);
                        let complete_lines = std::mem::replace(&mut pending, incomplete);
                        let text = String::from_utf8_lossy(&complete_lines);
                        for line in text.lines() {
                            let line = line.trim();
                            // SSE allows "data:" with or without a single following space.
                            let json_str = match line.strip_prefix("data:") {
                                Some(rest) => rest.strip_prefix(' ').unwrap_or(rest),
                                None => continue,
                            };
                            if json_str == "[DONE]" {
                                continue;
                            }
                            match serde_json::from_str::<StreamEvent>(json_str) {
                                Ok(event) => events.push(Ok(event)),
                                Err(e) => {
                                    tracing::warn!(
                                        "Failed to parse SSE event JSON: {}. Data: {}",
                                        e,
                                        json_str.chars().take(200).collect::<String>()
                                    );
                                }
                            }
                        }
                    }

                    if events.is_empty() {
                        vec![Ok(StreamEvent::Ping)]
                    } else {
                        events
                    }
                },
            )
            .flat_map(futures::stream::iter);
        Ok(Box::pin(event_stream))
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_vision(&self) -> bool {
        true
    }

    fn name(&self) -> &str {
        "anthropic"
    }

    /// The model configured via `with_default_model`, else `claude-sonnet-4-5`.
    fn default_model(&self) -> &str {
        self.custom_default_model
            .as_deref()
            .unwrap_or("claude-sonnet-4-5")
    }

    /// Built-in model list, used when the live listing is unavailable.
    fn supported_models(&self) -> Vec<String> {
        vec![
            "claude-opus-4-6".to_string(),
            "claude-sonnet-4-5-20250929".to_string(),
            "claude-haiku-4-5-20251001".to_string(),
            "claude-3-5-sonnet-20241022".to_string(),
            "claude-3-5-haiku-20241022".to_string(),
            "claude-3-opus-20240229".to_string(),
            "claude-3-sonnet-20240229".to_string(),
            "claude-3-5-sonnet-20240620".to_string(),
            "claude-3-haiku-20240307".to_string(),
        ]
    }

    /// Lists model ids from the Models API, sorted alphabetically.
    ///
    /// Best effort: on a transport error, a non-2xx status, an undecodable
    /// body, or an empty list, the failure is logged at debug level and
    /// `supported_models()` is returned instead.
    async fn fetch_models(&self) -> Vec<String> {
        #[derive(Deserialize)]
        struct ModelEntry {
            id: String,
        }
        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<ModelEntry>,
        }

        let req = self
            .client
            .get(ANTHROPIC_MODELS_URL)
            // The listing is paginated (20 entries per page by default). Ask for
            // the largest page the API allows (1000) so one call covers the catalogue.
            .query(&[("limit", "1000")])
            .header("anthropic-version", ANTHROPIC_VERSION)
            .header("x-api-key", self.api_key.trim())
            .header("anthropic-beta", "prompt-caching-2024-07-31");

        let response = match req.send().await {
            Ok(resp) if resp.status().is_success() => resp,
            Ok(resp) => {
                tracing::debug!(
                    "Anthropic model listing returned HTTP {}; using built-in model list",
                    resp.status()
                );
                return self.supported_models();
            }
            Err(e) => {
                tracing::debug!(
                    "Anthropic model listing request failed: {}; using built-in model list",
                    e
                );
                return self.supported_models();
            }
        };

        match response.json::<ModelsResponse>().await {
            Ok(body) if !body.data.is_empty() => {
                let mut models: Vec<String> = body.data.into_iter().map(|m| m.id).collect();
                models.sort();
                models
            }
            Ok(_) => self.supported_models(),
            Err(e) => {
                tracing::debug!(
                    "Could not decode Anthropic model listing: {}; using built-in model list",
                    e
                );
                self.supported_models()
            }
        }
    }

    /// Standard context window for known Claude families, `None` otherwise.
    ///
    /// Matching is by family prefix rather than by exact id. The default
    /// alias (`claude-sonnet-4-5`) and dated ids returned by `fetch_models`
    /// therefore resolve too, not just the entries of `supported_models()`.
    fn context_window(&self, model: &str) -> Option<u32> {
        const FAMILIES_200K: [&str; 4] = [
            "claude-3",
            "claude-opus-4",
            "claude-sonnet-4",
            "claude-haiku-4",
        ];
        FAMILIES_200K
            .iter()
            .any(|family| model.starts_with(family))
            .then_some(200_000)
    }

    /// Cost in dollars from the pricing config; `0.0` if the config cannot be loaded.
    fn calculate_cost(&self, model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
        crate::usage::pricing::PricingConfig::load()
            .map(|cfg| cfg.calculate_cost(model, input_tokens, output_tokens))
            .unwrap_or(0.0)
    }
}
/// JSON body for `POST /v1/messages`.
///
/// Fields marked `skip_serializing_if = "Option::is_none"` are omitted from
/// the payload entirely when `None`.
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<AnthropicSystem>,
    /// Always serialized (no `skip_serializing_if`).
    max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AnthropicTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    // NOTE(review): forwarded verbatim from the caller's request; confirm the
    // Messages API accepts every key callers put here.
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<std::collections::HashMap<String, String>>,
}
/// System prompt in either of the two shapes the API accepts. `untagged`
/// means a plain JSON string or an array of blocks is emitted, with no
/// variant name.
#[derive(Debug, Serialize)]
#[serde(untagged)]
// The `String` variant is never constructed in this file (requests always use
// `Blocks` so a cache breakpoint can be attached), hence the allow.
#[allow(dead_code)]
enum AnthropicSystem {
    String(String),
    Blocks(Vec<AnthropicSystemBlock>),
}
/// One content block of a block-form system prompt.
#[derive(Debug, Serialize)]
struct AnthropicSystemBlock {
    /// Serialized as `"type"`; this file only ever sets `"text"`.
    #[serde(rename = "type")]
    block_type: String,
    text: String,
    /// Optional prompt-cache breakpoint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<AnthropicCacheControl>,
}
/// Tool definition in the Messages API wire format.
#[derive(Debug, Serialize)]
struct AnthropicTool {
    name: String,
    description: String,
    /// Schema describing the tool's input, passed through as raw JSON.
    input_schema: serde_json::Value,
    /// Optional prompt-cache breakpoint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<AnthropicCacheControl>,
}
/// Prompt-cache marker, serialized as `{"type": "<cache_type>"}`.
/// This file only uses `"ephemeral"`.
#[derive(Debug, Serialize, Clone)]
struct AnthropicCacheControl {
    #[serde(rename = "type")]
    cache_type: String,
}
/// Subset of a non-streaming Messages API response that this provider reads.
/// Fields not listed here are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    id: String,
    model: String,
    content: Vec<ContentBlock>,
    /// `None` when the field is null or absent.
    stop_reason: Option<StopReason>,
    usage: TokenUsage,
}
/// Error envelope of the form `{"error": {"type": ..., "message": ...}}`.
#[derive(Debug, Deserialize)]
struct AnthropicError {
    error: AnthropicErrorDetail,
}
/// Inner object of an API error body.
#[derive(Debug, Deserialize)]
struct AnthropicErrorDetail {
    /// Machine-readable error category, read from the JSON `"type"` key.
    #[serde(rename = "type")]
    error_type: String,
    /// Human-readable description.
    message: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares dollar amounts with a small tolerance. Costs are `f64` values
    /// produced by the pricing config, so exact `==` is brittle.
    fn assert_cost_eq(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected cost {expected}, got {actual}"
        );
    }

    #[test]
    fn test_anthropic_provider_creation() {
        let provider = AnthropicProvider::new("test-key".to_string());
        assert_eq!(provider.name(), "anthropic");
        assert_eq!(provider.default_model(), "claude-sonnet-4-5");
    }

    #[test]
    fn test_custom_default_model() {
        let provider = AnthropicProvider::new("test-key".to_string())
            .with_default_model("claude-opus-4-6".to_string());
        assert_eq!(provider.default_model(), "claude-opus-4-6");
    }

    #[test]
    fn test_supported_models() {
        let provider = AnthropicProvider::new("test-key".to_string());
        let models = provider.supported_models();
        assert!(models.contains(&"claude-opus-4-6".to_string()));
        assert!(models.contains(&"claude-sonnet-4-5-20250929".to_string()));
        assert!(models.contains(&"claude-haiku-4-5-20251001".to_string()));
        assert!(models.contains(&"claude-3-opus-20240229".to_string()));
    }

    #[test]
    fn test_context_window() {
        let provider = AnthropicProvider::new("test-key".to_string());
        assert_eq!(provider.context_window("claude-opus-4-6"), Some(200_000));
        assert_eq!(
            provider.context_window("claude-3-opus-20240229"),
            Some(200_000)
        );
        assert_eq!(provider.context_window("unknown-model"), None);
    }

    #[test]
    fn test_cost_calculation() {
        let provider = AnthropicProvider::new("test-key".to_string());
        let cases = [
            ("claude-opus-4-6", 30.0),
            ("claude-sonnet-4-6", 18.0),
            ("claude-3-opus-20240229", 90.0),
            ("claude-haiku-4-5-20251001", 6.0),
            ("claude-3-5-haiku-20241022", 4.8),
            ("claude-3-haiku-20240307", 1.5),
        ];
        for (model, expected) in cases {
            assert_cost_eq(
                provider.calculate_cost(model, 1_000_000, 1_000_000),
                expected,
            );
        }
    }

    #[test]
    fn test_standard_headers() {
        let provider = AnthropicProvider::new("sk-ant-api-key".to_string());
        let headers = provider.headers();
        assert!(headers.contains_key("x-api-key"));
        assert!(!headers.contains_key(reqwest::header::AUTHORIZATION));
        assert!(headers.contains_key("anthropic-beta"));
        assert_eq!(
            headers
                .get("anthropic-version")
                .and_then(|v| v.to_str().ok()),
            Some(ANTHROPIC_VERSION)
        );
        assert_eq!(
            headers
                .get(reqwest::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
    }

    #[test]
    fn test_capabilities() {
        let provider = AnthropicProvider::new("test-key".to_string());
        assert!(provider.supports_streaming());
        assert!(provider.supports_tools());
        assert!(provider.supports_vision());
    }
}