Skip to main content

llmkit_anthropic/
provider.rs

1//! [`AnthropicProvider`] — implements [`LlmProvider`] against the Messages API.
2
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use llmkit_core::{
7    pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
8    LlmError, LlmProvider, LlmResult,
9};
10
11use crate::types::{ApiError, MessagesResponse};
12use crate::{chat, stream};
13
/// Base URL a freshly constructed provider targets; `/messages` is appended
/// to it when a request is built.
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com/v1";
/// Model a freshly constructed provider falls back to when a request does not
/// name one.
const DEFAULT_MODEL: &str = "claude-opus-4-8";
/// Default value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
17
/// Anthropic (Claude) provider over the `/v1/messages` API.
///
/// Create one with [`AnthropicProvider::new`] or
/// [`AnthropicProvider::from_env`], then adjust it with the chained setters
/// (`model`, `base_url`, `version`, `with_client`).
#[derive(Clone)]
pub struct AnthropicProvider {
    /// HTTP client every request is sent through.
    http: reqwest::Client,
    /// Secret sent in the `x-api-key` header.
    api_key: String,
    /// URL prefix that `/messages` is appended to.
    base_url: String,
    /// Model used when a request does not specify its own.
    model: String,
    /// Value sent in the `anthropic-version` header.
    version: String,
}
27
28impl AnthropicProvider {
29    /// Construct with an explicit API key.
30    pub fn new(api_key: impl Into<String>) -> Self {
31        Self {
32            http: reqwest::Client::new(),
33            api_key: api_key.into(),
34            base_url: DEFAULT_BASE_URL.to_string(),
35            model: DEFAULT_MODEL.to_string(),
36            version: ANTHROPIC_VERSION.to_string(),
37        }
38    }
39
40    /// Construct from the `ANTHROPIC_API_KEY` environment variable.
41    pub fn from_env() -> LlmResult<Self> {
42        let key = std::env::var("ANTHROPIC_API_KEY")
43            .map_err(|_| LlmError::Auth("ANTHROPIC_API_KEY not set".into()))?;
44        Ok(Self::new(key))
45    }
46
47    /// Set the default model.
48    pub fn model(mut self, model: impl Into<String>) -> Self {
49        self.model = model.into();
50        self
51    }
52
53    /// Override the base URL.
54    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
55        self.base_url = base_url.into();
56        self
57    }
58
59    /// Override the `anthropic-version` header.
60    pub fn version(mut self, version: impl Into<String>) -> Self {
61        self.version = version.into();
62        self
63    }
64
65    /// Provide a custom [`reqwest::Client`].
66    pub fn with_client(mut self, client: reqwest::Client) -> Self {
67        self.http = client;
68        self
69    }
70
71    fn resolved_model(&self, req: &ChatRequest) -> String {
72        req.model.clone().unwrap_or_else(|| self.model.clone())
73    }
74
75    fn request(&self, body: &impl serde::Serialize) -> reqwest::RequestBuilder {
76        self.http
77            .post(format!("{}/messages", self.base_url))
78            .header("x-api-key", &self.api_key)
79            .header("anthropic-version", &self.version)
80            .json(body)
81    }
82}
83
84#[async_trait]
85impl LlmProvider for AnthropicProvider {
86    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
87        let model = self.resolved_model(&req);
88        let body = chat::build_request(&req, model, false);
89
90        let start = Instant::now();
91        let resp = self.request(&body).send().await.map_err(map_reqwest_err)?;
92        let resp = check_status(resp).await?;
93        let parsed: MessagesResponse = resp.json().await.map_err(map_reqwest_err)?;
94
95        let mut out = chat::map_response(parsed, start.elapsed().as_millis() as u64)?;
96        out.cost = pricing::pricing_for(&out.model).map(|p| p.cost_for(out.usage));
97        Ok(out)
98    }
99
100    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
101        let model = self.resolved_model(&req);
102        let body = chat::build_request(&req, model, true);
103
104        let resp = self.request(&body).send().await.map_err(map_reqwest_err)?;
105        let resp = check_status(resp).await?;
106        Ok(stream::parse(resp))
107    }
108
109    async fn embed(&self, _req: EmbedRequest) -> LlmResult<EmbedResponse> {
110        Err(LlmError::Unsupported(
111            "Anthropic does not provide an embeddings endpoint".into(),
112        ))
113    }
114
115    fn name(&self) -> &'static str {
116        "anthropic"
117    }
118
119    fn model(&self) -> &str {
120        &self.model
121    }
122
123    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
124        let model = self.resolved_model(req);
125        let pricing = pricing::pricing_for(&model)?;
126        let prompt_chars: usize = req
127            .messages
128            .iter()
129            .filter_map(|m| m.content.as_text())
130            .map(|t| t.len())
131            .sum::<usize>()
132            + req.system.as_deref().map(str::len).unwrap_or(0);
133        let prompt_tokens = (prompt_chars / 4) as u32;
134        let completion_tokens = req.max_tokens.unwrap_or(256);
135        Some(pricing.cost_for(llmkit_core::TokenUsage::new(prompt_tokens, completion_tokens)))
136    }
137}
138
139fn map_reqwest_err(e: reqwest::Error) -> LlmError {
140    if e.is_timeout() {
141        LlmError::Timeout
142    } else if e.is_decode() {
143        LlmError::Serialization(e.to_string())
144    } else {
145        LlmError::Transport(e.to_string())
146    }
147}
148
149async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
150    let status = resp.status();
151    if status.is_success() {
152        return Ok(resp);
153    }
154
155    let retry_after = resp
156        .headers()
157        .get(reqwest::header::RETRY_AFTER)
158        .and_then(|v| v.to_str().ok())
159        .and_then(|s| s.parse::<u64>().ok())
160        .map(Duration::from_secs);
161
162    let body = resp.text().await.unwrap_or_default();
163    let message = serde_json::from_str::<ApiError>(&body)
164        .map(|e| e.error.message)
165        .unwrap_or(body);
166
167    Err(match status.as_u16() {
168        401 | 403 => LlmError::Auth(message),
169        429 => LlmError::RateLimited { retry_after, message },
170        400 | 404 | 422 => LlmError::InvalidRequest(message),
171        code => LlmError::Provider { status: code, message },
172    })
173}