use super::traits::{
ChatRequest, ChatResponse, FinishReason, LlmProvider, MessageRole, TokenUsage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// LLM provider backed by the Anthropic Messages API.
///
/// Cheap to clone pieces: all string fields are `Arc<str>` so sharing the
/// provider's configuration is a refcount bump, not a copy.
pub struct AnthropicProvider {
    // Human-readable instance name returned by `LlmProvider::name`
    // (distinct from the provider type, which is always "anthropic").
    name: Arc<str>,
    // Anthropic API key; sent on every request via the `x-api-key` header.
    api_key: Arc<str>,
    // API root without a trailing slash, e.g. "https://api.anthropic.com/v1".
    base_url: Arc<str>,
    // Shared reqwest client, reused across requests for connection pooling.
    client: Client,
}
impl AnthropicProvider {
    /// Builds a provider instance.
    ///
    /// When `base_url` is `None`, the public Anthropic endpoint
    /// (`https://api.anthropic.com/v1`) is used.
    pub fn new(name: Arc<str>, api_key: Arc<str>, base_url: Option<Arc<str>>) -> Self {
        const DEFAULT_BASE_URL: &str = "https://api.anthropic.com/v1";

        let base_url = match base_url {
            Some(url) => url,
            None => Arc::from(DEFAULT_BASE_URL),
        };

        Self {
            name,
            api_key,
            base_url,
            client: Client::new(),
        }
    }
}
#[async_trait]
impl LlmProvider for AnthropicProvider {
fn name(&self) -> &str {
&self.name
}
fn provider_type(&self) -> &str {
"anthropic"
}
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
let url = format!("{}/messages", self.base_url);
let system_message = request
.messages
.iter()
.find(|m| m.role == MessageRole::System)
.map(|m| m.content.clone());
let messages: Vec<AnthropicMessage> = request
.messages
.into_iter()
.filter(|m| m.role != MessageRole::System)
.map(|m| AnthropicMessage {
role: match m.role {
MessageRole::User => "user".to_string(),
MessageRole::Assistant => "assistant".to_string(),
_ => "user".to_string(),
},
content: m.content,
})
.collect();
let body = AnthropicRequest {
model: request.model,
messages,
max_tokens: request.max_tokens.unwrap_or(4096),
system: system_message,
};
let response = self
.client
.post(&url)
.header("x-api-key", &*self.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to send request to Anthropic")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("Anthropic API error: {}", error_text);
}
let api_response: AnthropicResponse = response
.json()
.await
.context("Failed to parse Anthropic response")?;
let content = api_response
.content
.into_iter()
.filter(|c| c.content_type == "text")
.map(|c| c.text)
.collect::<Vec<String>>()
.join("");
Ok(ChatResponse {
id: api_response.id,
model: api_response.model,
content,
finish_reason: match api_response.stop_reason.as_deref() {
Some("end_turn") => FinishReason::Stop,
Some("max_tokens") => FinishReason::Length,
Some("tool_use") => FinishReason::ToolCalls,
_ => FinishReason::Stop,
},
usage: TokenUsage {
prompt_tokens: api_response.usage.input_tokens,
completion_tokens: api_response.usage.output_tokens,
total_tokens: api_response.usage.input_tokens + api_response.usage.output_tokens,
},
tool_calls: None,
})
}
async fn is_available(&self) -> bool {
if self.api_key.is_empty() {
return false;
}
self.api_key.starts_with("sk-ant-")
}
}
/// Request body for `POST {base_url}/messages`.
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    messages: Vec<AnthropicMessage>,
    // Required by the Anthropic Messages API (unlike OpenAI, where the
    // equivalent field is optional).
    max_tokens: u32,
    // Top-level system prompt; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
}
/// A single conversational turn; `role` is "user" or "assistant".
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}
/// The subset of the Messages API response that this provider consumes;
/// unknown fields are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    id: String,
    model: String,
    // Ordered content blocks; only blocks of type "text" are used here.
    content: Vec<AnthropicContent>,
    // e.g. "end_turn", "max_tokens", "tool_use"; may be absent.
    stop_reason: Option<String>,
    usage: AnthropicUsage,
}
/// One content block from a Messages API response.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    // Block discriminator ("text", "tool_use", ...); `type` is a Rust
    // keyword, hence the serde rename.
    #[serde(rename = "type")]
    content_type: String,
    // Non-text blocks (e.g. "tool_use") have no `text` field in the JSON;
    // without `default`, a single such block would fail deserialization of
    // the whole response. Defaulting to "" keeps parsing robust — callers
    // filter on `content_type == "text"` anyway.
    #[serde(default)]
    text: String,
}
/// Token accounting reported by the API.
#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    // Tokens consumed by the prompt (mapped to `prompt_tokens`).
    input_tokens: u32,
    // Tokens generated in the completion (mapped to `completion_tokens`).
    output_tokens: u32,
}