// llmkit_anthropic/provider.rs
use std::time::{Duration, Instant};

use async_trait::async_trait;
use llmkit_core::{
    pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
    LlmError, LlmProvider, LlmResult,
};

use crate::types::{ApiError, MessagesResponse};
use crate::{chat, stream};
13
/// Base URL used by [`AnthropicProvider::new`]; the `/messages` path is appended to it
/// when a request is built.
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com/v1";
/// Model configured by [`AnthropicProvider::new`]; used when a request does not name one.
const DEFAULT_MODEL: &str = "claude-opus-4-8";
/// Value sent in the `anthropic-version` request header unless overridden on the provider.
const ANTHROPIC_VERSION: &str = "2023-06-01";
17
18#[derive(Clone)]
20pub struct AnthropicProvider {
21 http: reqwest::Client,
22 api_key: String,
23 base_url: String,
24 model: String,
25 version: String,
26}
27
28impl AnthropicProvider {
29 pub fn new(api_key: impl Into<String>) -> Self {
31 Self {
32 http: reqwest::Client::new(),
33 api_key: api_key.into(),
34 base_url: DEFAULT_BASE_URL.to_string(),
35 model: DEFAULT_MODEL.to_string(),
36 version: ANTHROPIC_VERSION.to_string(),
37 }
38 }
39
40 pub fn from_env() -> LlmResult<Self> {
42 let key = std::env::var("ANTHROPIC_API_KEY")
43 .map_err(|_| LlmError::Auth("ANTHROPIC_API_KEY not set".into()))?;
44 Ok(Self::new(key))
45 }
46
47 pub fn model(mut self, model: impl Into<String>) -> Self {
49 self.model = model.into();
50 self
51 }
52
53 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
55 self.base_url = base_url.into();
56 self
57 }
58
59 pub fn version(mut self, version: impl Into<String>) -> Self {
61 self.version = version.into();
62 self
63 }
64
65 pub fn with_client(mut self, client: reqwest::Client) -> Self {
67 self.http = client;
68 self
69 }
70
71 fn resolved_model(&self, req: &ChatRequest) -> String {
72 req.model.clone().unwrap_or_else(|| self.model.clone())
73 }
74
75 fn request(&self, body: &impl serde::Serialize) -> reqwest::RequestBuilder {
76 self.http
77 .post(format!("{}/messages", self.base_url))
78 .header("x-api-key", &self.api_key)
79 .header("anthropic-version", &self.version)
80 .json(body)
81 }
82}
83
84#[async_trait]
85impl LlmProvider for AnthropicProvider {
86 async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
87 let model = self.resolved_model(&req);
88 let body = chat::build_request(&req, model, false);
89
90 let start = Instant::now();
91 let resp = self.request(&body).send().await.map_err(map_reqwest_err)?;
92 let resp = check_status(resp).await?;
93 let parsed: MessagesResponse = resp.json().await.map_err(map_reqwest_err)?;
94
95 let mut out = chat::map_response(parsed, start.elapsed().as_millis() as u64)?;
96 out.cost = pricing::pricing_for(&out.model).map(|p| p.cost_for(out.usage));
97 Ok(out)
98 }
99
100 async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
101 let model = self.resolved_model(&req);
102 let body = chat::build_request(&req, model, true);
103
104 let resp = self.request(&body).send().await.map_err(map_reqwest_err)?;
105 let resp = check_status(resp).await?;
106 Ok(stream::parse(resp))
107 }
108
109 async fn embed(&self, _req: EmbedRequest) -> LlmResult<EmbedResponse> {
110 Err(LlmError::Unsupported(
111 "Anthropic does not provide an embeddings endpoint".into(),
112 ))
113 }
114
115 fn name(&self) -> &'static str {
116 "anthropic"
117 }
118
119 fn model(&self) -> &str {
120 &self.model
121 }
122
123 fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
124 let model = self.resolved_model(req);
125 let pricing = pricing::pricing_for(&model)?;
126 let prompt_chars: usize = req
127 .messages
128 .iter()
129 .filter_map(|m| m.content.as_text())
130 .map(|t| t.len())
131 .sum::<usize>()
132 + req.system.as_deref().map(str::len).unwrap_or(0);
133 let prompt_tokens = (prompt_chars / 4) as u32;
134 let completion_tokens = req.max_tokens.unwrap_or(256);
135 Some(pricing.cost_for(llmkit_core::TokenUsage::new(prompt_tokens, completion_tokens)))
136 }
137}
138
139fn map_reqwest_err(e: reqwest::Error) -> LlmError {
140 if e.is_timeout() {
141 LlmError::Timeout
142 } else if e.is_decode() {
143 LlmError::Serialization(e.to_string())
144 } else {
145 LlmError::Transport(e.to_string())
146 }
147}
148
149async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
150 let status = resp.status();
151 if status.is_success() {
152 return Ok(resp);
153 }
154
155 let retry_after = resp
156 .headers()
157 .get(reqwest::header::RETRY_AFTER)
158 .and_then(|v| v.to_str().ok())
159 .and_then(|s| s.parse::<u64>().ok())
160 .map(Duration::from_secs);
161
162 let body = resp.text().await.unwrap_or_default();
163 let message = serde_json::from_str::<ApiError>(&body)
164 .map(|e| e.error.message)
165 .unwrap_or(body);
166
167 Err(match status.as_u16() {
168 401 | 403 => LlmError::Auth(message),
169 429 => LlmError::RateLimited { retry_after, message },
170 400 | 404 | 422 => LlmError::InvalidRequest(message),
171 code => LlmError::Provider { status: code, message },
172 })
173}