use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use super::{CompletionRequest, LLMProvider, ProviderError};
use crate::constants::{LLM_PROMPT_BYTES_MAX, LLM_RESPONSE_BYTES_MAX};
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
const DEFAULT_MODEL: &str = "claude-sonnet-4-20250514";
const DEFAULT_MAX_TOKENS: usize = 4096;
const ANTHROPIC_VERSION: &str = "2023-06-01";
#[derive(Debug, Serialize)]
struct AnthropicRequest {
model: String,
max_tokens: usize,
messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
}
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct AnthropicResponse {
content: Vec<ContentBlock>,
#[serde(default)]
stop_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ContentBlock {
#[serde(rename = "type")]
content_type: String,
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct AnthropicError {
error: ErrorDetail,
}
#[derive(Debug, Deserialize)]
struct ErrorDetail {
#[serde(rename = "type")]
error_type: String,
message: String,
}
#[derive(Debug, Clone)]
pub struct AnthropicProvider {
client: reqwest::Client,
api_key: String,
model: String,
api_url: String,
}
impl AnthropicProvider {
#[must_use]
pub fn new(api_key: impl Into<String>) -> Self {
Self {
client: reqwest::Client::new(),
api_key: api_key.into(),
model: DEFAULT_MODEL.to_string(),
api_url: ANTHROPIC_API_URL.to_string(),
}
}
#[must_use]
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
#[must_use]
pub fn with_api_url(mut self, url: impl Into<String>) -> Self {
self.api_url = url.into();
self
}
#[must_use]
pub fn model(&self) -> &str {
&self.model
}
fn build_request(&self, request: &CompletionRequest) -> AnthropicRequest {
AnthropicRequest {
model: self.model.clone(),
max_tokens: request.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
messages: vec![Message {
role: "user".to_string(),
content: request.prompt.clone(),
}],
system: request.system.clone(),
temperature: request.temperature,
}
}
fn parse_error(status: reqwest::StatusCode, body: &str) -> ProviderError {
if let Ok(err) = serde_json::from_str::<AnthropicError>(body) {
return match err.error.error_type.as_str() {
"authentication_error" => ProviderError::AuthenticationFailed,
"rate_limit_error" => ProviderError::rate_limit(None),
"overloaded_error" => {
ProviderError::service_unavailable("Anthropic API overloaded")
}
"invalid_request_error" => ProviderError::invalid_request(err.error.message),
_ => ProviderError::invalid_response(err.error.message),
};
}
match status {
reqwest::StatusCode::UNAUTHORIZED => ProviderError::AuthenticationFailed,
reqwest::StatusCode::TOO_MANY_REQUESTS => ProviderError::rate_limit(None),
reqwest::StatusCode::SERVICE_UNAVAILABLE | reqwest::StatusCode::BAD_GATEWAY => {
ProviderError::service_unavailable("Anthropic API unavailable")
}
reqwest::StatusCode::REQUEST_TIMEOUT | reqwest::StatusCode::GATEWAY_TIMEOUT => {
ProviderError::Timeout
}
_ => ProviderError::invalid_response(format!("HTTP {}: {}", status, body)),
}
}
}
#[async_trait]
impl LLMProvider for AnthropicProvider {
async fn complete(&self, request: &CompletionRequest) -> Result<String, ProviderError> {
debug_assert!(
request.prompt.len() <= LLM_PROMPT_BYTES_MAX,
"prompt exceeds limit"
);
let body = self.build_request(request);
let response = self
.client
.post(&self.api_url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
ProviderError::Timeout
} else if e.is_connect() {
ProviderError::network_error("Connection failed")
} else {
ProviderError::network_error(e.to_string())
}
})?;
let status = response.status();
let response_body = response
.text()
.await
.map_err(|e| ProviderError::network_error(e.to_string()))?;
if !status.is_success() {
return Err(Self::parse_error(status, &response_body));
}
let parsed: AnthropicResponse = serde_json::from_str(&response_body)
.map_err(|e| ProviderError::json_error(format!("Failed to parse response: {}", e)))?;
let text = parsed
.content
.into_iter()
.filter_map(|block| {
if block.content_type == "text" {
block.text
} else {
None
}
})
.collect::<Vec<_>>()
.join("");
debug_assert!(
text.len() <= LLM_RESPONSE_BYTES_MAX,
"response exceeds limit"
);
Ok(text)
}
fn name(&self) -> &'static str {
"anthropic"
}
fn is_simulation(&self) -> bool {
false
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new() {
let provider = AnthropicProvider::new("test-key");
assert_eq!(provider.model(), DEFAULT_MODEL);
assert!(!provider.is_simulation());
assert_eq!(provider.name(), "anthropic");
}
#[test]
fn test_with_model() {
let provider = AnthropicProvider::new("test-key").with_model("claude-opus-4-20250514");
assert_eq!(provider.model(), "claude-opus-4-20250514");
}
#[test]
fn test_with_api_url() {
let provider =
AnthropicProvider::new("test-key").with_api_url("http://localhost:8080/v1/messages");
assert_eq!(provider.api_url, "http://localhost:8080/v1/messages");
}
#[test]
fn test_build_request() {
let provider = AnthropicProvider::new("test-key");
let request = CompletionRequest::new("Hello")
.with_system("You are helpful")
.with_max_tokens(100)
.with_temperature(0.5);
let built = provider.build_request(&request);
assert_eq!(built.model, DEFAULT_MODEL);
assert_eq!(built.max_tokens, 100);
assert_eq!(built.messages.len(), 1);
assert_eq!(built.messages[0].role, "user");
assert_eq!(built.messages[0].content, "Hello");
assert_eq!(built.system, Some("You are helpful".to_string()));
assert_eq!(built.temperature, Some(0.5));
}
#[test]
fn test_parse_error_auth() {
let body = r#"{"error":{"type":"authentication_error","message":"Invalid API key"}}"#;
let err = AnthropicProvider::parse_error(reqwest::StatusCode::UNAUTHORIZED, body);
assert!(matches!(err, ProviderError::AuthenticationFailed));
}
#[test]
fn test_parse_error_rate_limit() {
let body = r#"{"error":{"type":"rate_limit_error","message":"Rate limit exceeded"}}"#;
let err = AnthropicProvider::parse_error(reqwest::StatusCode::TOO_MANY_REQUESTS, body);
assert!(matches!(err, ProviderError::RateLimit { .. }));
}
#[test]
fn test_parse_error_overloaded() {
let body = r#"{"error":{"type":"overloaded_error","message":"API overloaded"}}"#;
let err = AnthropicProvider::parse_error(reqwest::StatusCode::SERVICE_UNAVAILABLE, body);
assert!(matches!(err, ProviderError::ServiceUnavailable { .. }));
}
#[test]
fn test_parse_error_invalid_request() {
let body = r#"{"error":{"type":"invalid_request_error","message":"Bad prompt"}}"#;
let err = AnthropicProvider::parse_error(reqwest::StatusCode::BAD_REQUEST, body);
assert!(matches!(err, ProviderError::InvalidRequest { .. }));
}
#[test]
fn test_parse_error_fallback() {
let body = "Internal server error";
let err = AnthropicProvider::parse_error(reqwest::StatusCode::INTERNAL_SERVER_ERROR, body);
assert!(matches!(err, ProviderError::InvalidResponse { .. }));
}
}