use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info, warn};
/// Provider adapter for the Anthropic Messages API.
///
/// Translates the crate's OpenAI-style request/response types to and from
/// Anthropic's wire format (system prompt lifted out of the message list,
/// content blocks, `x-api-key` auth header, etc.).
#[derive(Debug, Clone)]
pub struct AnthropicProvider {
    // Shared HTTP client, credentials, and base-URL handling.
    base: BaseProvider,
    // Value sent in the `anthropic-version` header; pinned at construction.
    api_version: String,
    // Static per-model pricing snapshot built once in `new`.
    pricing_cache: HashMap<String, ModelPricing>,
}
impl AnthropicProvider {
    /// Builds a new Anthropic provider from `config`.
    ///
    /// Falls back to the public `https://api.anthropic.com` endpoint when no
    /// base URL is configured, pins the Messages API version, and validates
    /// the configuration before returning.
    ///
    /// # Errors
    /// Returns an error when the underlying `BaseProvider` cannot be built or
    /// when validation fails (e.g. an empty API key).
    pub async fn new(config: &ProviderConfig) -> Result<Self> {
        let base = BaseProvider::new(config)?;
        let base_url = config
            .base_url
            .clone()
            .unwrap_or_else(|| "https://api.anthropic.com".to_string());
        let provider = Self {
            base: BaseProvider { base_url, ..base },
            api_version: "2023-06-01".to_string(),
            pricing_cache: Self::initialize_pricing_cache(),
        };
        provider.validate_config().await?;
        info!(
            "Anthropic provider '{}' initialized successfully",
            config.name
        );
        Ok(provider)
    }

    /// Validates the provider configuration.
    ///
    /// Currently only checks that an API key is present; no network request
    /// is made here.
    async fn validate_config(&self) -> Result<()> {
        if self.base.api_key.is_empty() {
            return Err(
                ProviderError::Authentication("Anthropic API key is required".to_string()).into(),
            );
        }
        debug!("Anthropic provider configuration validated successfully");
        Ok(())
    }

    /// Builds the static per-model price table (USD per 1K tokens).
    ///
    /// NOTE(review): prices are a hard-coded snapshot and may drift from
    /// Anthropic's published pricing — confirm before using for billing.
    fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
        // (model id, input $/1K tokens, output $/1K tokens)
        const PRICES: &[(&str, f64, f64)] = &[
            ("claude-3-opus-20240229", 0.015, 0.075),
            ("claude-3-sonnet-20240229", 0.003, 0.015),
            ("claude-3-haiku-20240307", 0.00025, 0.00125),
            ("claude-2.1", 0.008, 0.024),
            ("claude-2.0", 0.008, 0.024),
        ];
        PRICES
            .iter()
            .map(|&(model, input_cost, output_cost)| {
                (
                    model.to_string(),
                    ModelPricing {
                        model: model.to_string(),
                        input_cost_per_1k: input_cost,
                        output_cost_per_1k: output_cost,
                        currency: "USD".to_string(),
                        updated_at: chrono::Utc::now(),
                    },
                )
            })
            .collect()
    }

    /// Headers required by every Anthropic API request: the API key, the
    /// pinned API version, and a JSON content type.
    ///
    /// # Panics
    /// Panics if the configured API key or version string contains bytes that
    /// are not valid in an HTTP header value.
    fn create_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            "x-api-key",
            self.base
                .api_key
                .parse()
                .expect("Anthropic API key is not a valid HTTP header value"),
        );
        headers.insert(
            "anthropic-version",
            self.api_version
                .parse()
                .expect("Anthropic API version is not a valid HTTP header value"),
        );
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            "application/json".parse().unwrap(),
        );
        headers
    }

    /// Splits OpenAI-style messages into the Anthropic shape: an optional
    /// top-level system prompt plus a list of user/assistant turns.
    ///
    /// When several system messages are present, the last one wins. Roles
    /// Anthropic does not support are dropped with a warning.
    fn convert_messages_to_anthropic(
        &self,
        messages: &[ChatMessage],
    ) -> (Option<String>, Vec<serde_json::Value>) {
        let mut system_message = None;
        let mut anthropic_messages = Vec::new();
        for message in messages {
            match message.role {
                MessageRole::System => {
                    // Anthropic takes the system prompt as a separate request
                    // field, not as a message; only plain text is honored.
                    if let Some(MessageContent::Text(text)) = &message.content {
                        system_message = Some(text.clone());
                    }
                }
                MessageRole::User => {
                    anthropic_messages.push(json!({
                        "role": "user",
                        "content": self.convert_message_content(message.content.as_ref())
                    }));
                }
                MessageRole::Assistant => {
                    anthropic_messages.push(json!({
                        "role": "assistant",
                        "content": self.convert_message_content(message.content.as_ref())
                    }));
                }
                _ => {
                    warn!("Unsupported message role for Anthropic: {:?}", message.role);
                }
            }
        }
        (system_message, anthropic_messages)
    }

    /// Converts OpenAI message content into Anthropic content blocks.
    ///
    /// Plain text maps to a JSON string; multi-part content maps to an array
    /// of typed blocks. Audio parts have no Anthropic equivalent and are
    /// skipped with a warning.
    fn convert_message_content(&self, content: Option<&MessageContent>) -> serde_json::Value {
        match content {
            Some(MessageContent::Text(text)) => json!(text),
            Some(MessageContent::Parts(parts)) => {
                let mut anthropic_content = Vec::new();
                for part in parts {
                    match part {
                        ContentPart::Text { text } => {
                            anthropic_content.push(json!({
                                "type": "text",
                                "text": text
                            }));
                        }
                        ContentPart::ImageUrl { image_url } => {
                            // Data URLs look like "data:<media_type>;base64,<payload>".
                            // The media type was previously hard-coded to image/jpeg
                            // and only the JPEG prefix was stripped, which produced
                            // a wrong media_type and invalid base64 for PNG/GIF/WebP
                            // images. Parse both pieces; fall back to the old
                            // behavior for anything that is not a base64 data URL.
                            let (media_type, data) = image_url
                                .url
                                .strip_prefix("data:")
                                .and_then(|rest| rest.split_once(";base64,"))
                                .unwrap_or(("image/jpeg", image_url.url.as_str()));
                            anthropic_content.push(json!({
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": media_type,
                                    "data": data
                                }
                            }));
                        }
                        ContentPart::Audio { .. } => {
                            // Drop loudly instead of silently so callers can
                            // diagnose missing content.
                            warn!("Audio content is not supported by Anthropic; part skipped");
                        }
                    }
                }
                json!(anthropic_content)
            }
            None => json!(""),
        }
    }

    /// Maps an Anthropic Messages API response onto the OpenAI
    /// `ChatCompletionResponse` shape used throughout this crate.
    ///
    /// Only the first text content block is extracted. Anthropic's
    /// `stop_reason` is translated to OpenAI's `finish_reason` vocabulary
    /// ("max_tokens" → "length", everything else → "stop").
    fn convert_anthropic_response_to_openai(
        &self,
        anthropic_response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse> {
        let id = anthropic_response
            .get("id")
            .and_then(|v| v.as_str())
            .unwrap_or("chatcmpl-anthropic")
            .to_string();
        let content = anthropic_response
            .get("content")
            .and_then(|v| v.as_array())
            .and_then(|arr| arr.first())
            .and_then(|item| item.get("text"))
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        // Previously finish_reason was always "stop", which hid truncation
        // (max_tokens) from callers.
        let finish_reason = match anthropic_response
            .get("stop_reason")
            .and_then(|v| v.as_str())
        {
            Some("max_tokens") => "length",
            _ => "stop",
        };
        let usage = anthropic_response.get("usage").map(|u| Usage {
            prompt_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
            completion_tokens: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
            // Filled in below once both counts are known.
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        });
        let mut usage = usage.unwrap_or_default();
        usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
        Ok(ChatCompletionResponse {
            id,
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: vec![ChatChoice {
                index: 0,
                message: ChatMessage {
                    role: MessageRole::Assistant,
                    content: Some(MessageContent::Text(content)),
                    name: None,
                    function_call: None,
                    tool_calls: None,
                    tool_call_id: None,
                    audio: None,
                },
                finish_reason: Some(finish_reason.to_string()),
                logprobs: None,
            }],
            usage: Some(usage),
            system_fingerprint: None,
        })
    }

    /// POSTs `body` to `endpoint` under the provider's base URL.
    ///
    /// # Errors
    /// Network failures become `ProviderError::Network`; non-2xx statuses are
    /// mapped to the matching `ProviderError` variant, so an `Ok` return
    /// always carries a success-status response.
    async fn make_anthropic_request(
        &self,
        endpoint: &str,
        body: serde_json::Value,
    ) -> Result<reqwest::Response> {
        let url = format!(
            "{}/{}",
            self.base.base_url.trim_end_matches('/'),
            endpoint.trim_start_matches('/')
        );
        let response = self
            .base
            .client
            .post(&url)
            .headers(self.create_headers())
            .json(&body)
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(match status.as_u16() {
                401 => ProviderError::Authentication(error_text),
                429 => ProviderError::RateLimit(error_text),
                404 => ProviderError::ModelNotFound(error_text),
                400 => ProviderError::InvalidRequest(error_text),
                503 => ProviderError::Unavailable(error_text),
                _ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
            }
            .into());
        }
        Ok(response)
    }
}
#[async_trait]
impl Provider for AnthropicProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::Anthropic
}
async fn supports_model(&self, model: &str) -> bool {
self.base.is_model_supported(model) || model.starts_with("claude-")
}
async fn supports_images(&self) -> bool {
true }
async fn supports_embeddings(&self) -> bool {
false }
async fn supports_streaming(&self) -> bool {
true
}
async fn list_models(&self) -> Result<Vec<Model>> {
let known_models = vec![
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"claude-2.1",
"claude-2.0",
];
let models = known_models
.into_iter()
.map(|model| Model {
id: model.to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "anthropic".to_string(),
})
.collect();
Ok(models)
}
async fn health_check(&self) -> Result<()> {
debug!("Performing Anthropic health check");
let body = json!({
"model": "claude-3-haiku-20240307",
"max_tokens": 1,
"messages": [
{
"role": "user",
"content": "Hi"
}
]
});
let response = self.make_anthropic_request("v1/messages", body).await?;
if response.status().is_success() {
Ok(())
} else {
Err(ProviderError::Unavailable(format!(
"Health check failed with status: {}",
response.status()
))
.into())
}
}
async fn chat_completion(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
debug!("Anthropic chat completion for model: {}", request.model);
if request.stream.unwrap_or(false) {
return Err(ProviderError::InvalidRequest(
"Streaming requests should use chat_completion_stream method".to_string(),
)
.into());
}
let (system_message, messages) = self.convert_messages_to_anthropic(&request.messages);
let mut body = json!({
"model": request.model,
"messages": messages,
"max_tokens": request.max_tokens.unwrap_or(4096),
});
if let Some(system) = system_message {
body["system"] = json!(system);
}
if let Some(temp) = request.temperature {
body["temperature"] = json!(temp);
}
if let Some(top_p) = request.top_p {
body["top_p"] = json!(top_p);
}
if let Some(stop) = request.stop {
body["stop_sequences"] = json!(stop);
}
let response = self.make_anthropic_request("v1/messages", body).await?;
let anthropic_response: serde_json::Value = self.base.parse_json_response(response).await?;
self.convert_anthropic_response_to_openai(anthropic_response, &request.model)
}
async fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin + 'static>> {
debug!(
"Anthropic streaming chat completion for model: {}",
request.model
);
let (system_message, messages) = self.convert_messages_to_anthropic(&request.messages);
let mut body = json!({
"model": request.model,
"messages": messages,
"max_tokens": request.max_tokens.unwrap_or(4096),
"stream": true
});
if let Some(system) = system_message {
body["system"] = json!(system);
}
if let Some(temp) = request.temperature {
body["temperature"] = json!(temp);
}
if let Some(top_p) = request.top_p {
body["top_p"] = json!(top_p);
}
if let Some(stop) = request.stop {
body["stop_sequences"] = json!(stop);
}
let url = format!(
"{}/{}",
self.base.base_url.trim_end_matches('/'),
"v1/messages"
);
let mut headers = self.create_headers();
headers.insert(
reqwest::header::ACCEPT,
"text/event-stream".parse().unwrap(),
);
let response = self
.base
.client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
401 => ProviderError::Authentication(error_text),
429 => ProviderError::RateLimit(error_text),
404 => ProviderError::ModelNotFound(error_text),
400 => ProviderError::InvalidRequest(error_text),
503 => ProviderError::Unavailable(error_text),
_ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
}
.into());
}
let stream = crate::core::streaming::providers::AnthropicStreaming::create_stream(response);
Ok(Box::new(stream))
}
async fn completion(
&self,
_request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
Err(ProviderError::InvalidRequest(
"Anthropic does not support legacy completion endpoint".to_string(),
)
.into())
}
async fn embedding(
&self,
_request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
Err(
ProviderError::InvalidRequest("Anthropic does not support embeddings".to_string())
.into(),
)
}
async fn image_generation(
&self,
_request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
Err(ProviderError::InvalidRequest(
"Anthropic does not support image generation".to_string(),
)
.into())
}
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
if let Some(pricing) = self.pricing_cache.get(model) {
Ok(pricing.clone())
} else {
Ok(ModelPricing {
model: model.to_string(),
input_cost_per_1k: 0.008, output_cost_per_1k: 0.024,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
}
async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64> {
let pricing = self.get_model_pricing(model).await?;
let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
Ok(input_cost + output_cost)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ProviderConfig;

    /// Minimal Anthropic provider configuration shared by every test below.
    fn create_test_config() -> ProviderConfig {
        ProviderConfig {
            name: "test-anthropic".to_string(),
            provider_type: "anthropic".to_string(),
            api_key: "test-key".to_string(),
            base_url: Some("https://api.anthropic.com".to_string()),
            models: vec!["claude-3-sonnet-20240229".to_string()],
            timeout: 30,
            max_retries: 3,
            organization: None,
            api_version: None,
            project: None,
            weight: 1.0,
            rpm: 1000,
            tpm: 10000,
            enabled: true,
            max_concurrent_requests: 10,
            retry: crate::config::RetryConfig::default(),
            health_check: crate::config::HealthCheckConfig::default(),
            settings: std::collections::HashMap::new(),
            tags: vec![],
        }
    }

    #[tokio::test]
    async fn test_anthropic_provider_creation() {
        // Construction only validates the config locally (non-empty API key),
        // so no network access is needed here.
        assert!(AnthropicProvider::new(&create_test_config()).await.is_ok());
    }

    #[tokio::test]
    async fn test_model_support() {
        if let Ok(provider) = AnthropicProvider::new(&create_test_config()).await {
            // Configured model, claude-* prefix fallback, and a foreign model.
            assert!(provider.supports_model("claude-3-sonnet-20240229").await);
            assert!(provider.supports_model("claude-2.1").await);
            assert!(!provider.supports_model("gpt-4").await);
        }
    }

    #[tokio::test]
    async fn test_message_conversion() {
        if let Ok(provider) = AnthropicProvider::new(&create_test_config()).await {
            // Local builder keeps the fixture free of repeated None fields.
            let text_message = |role: MessageRole, text: &str| ChatMessage {
                role,
                content: Some(MessageContent::Text(text.to_string())),
                name: None,
                function_call: None,
                tool_calls: None,
                tool_call_id: None,
                audio: None,
            };
            let messages = vec![
                text_message(MessageRole::System, "You are a helpful assistant"),
                text_message(MessageRole::User, "Hello"),
            ];
            let (system, converted) = provider.convert_messages_to_anthropic(&messages);
            // The system prompt is lifted out; only the user turn remains.
            assert_eq!(system, Some("You are a helpful assistant".to_string()));
            assert_eq!(converted.len(), 1);
        }
    }

    #[test]
    fn test_pricing_cache() {
        let cache = AnthropicProvider::initialize_pricing_cache();
        for model in [
            "claude-3-opus-20240229",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307",
        ] {
            assert!(cache.contains_key(model), "missing pricing for {}", model);
        }
    }
}