llm_stack_anthropic/
provider.rs1use std::collections::HashSet;
4
5use llm_stack::ChatResponse;
6use llm_stack::error::LlmError;
7use llm_stack::provider::{Capability, ChatParams, Provider, ProviderMetadata};
8use llm_stack::stream::ChatStream;
9use reqwest::header::{HeaderMap, HeaderValue};
10use tracing::instrument;
11
12use crate::config::AnthropicConfig;
13use crate::convert;
14
15#[derive(Debug)]
40pub struct AnthropicProvider {
41 config: AnthropicConfig,
42 client: reqwest::Client,
43}
44
45impl AnthropicProvider {
46 pub fn new(config: AnthropicConfig) -> Self {
51 let client = config.client.clone().unwrap_or_else(|| {
52 let mut builder = reqwest::Client::builder();
53 if let Some(timeout) = config.timeout {
54 builder = builder.timeout(timeout);
55 }
56 builder.build().expect("failed to build HTTP client")
57 });
58 Self { config, client }
59 }
60
61 fn default_headers(&self) -> Result<HeaderMap, LlmError> {
63 let mut headers = HeaderMap::new();
64 headers.insert(
65 "x-api-key",
66 HeaderValue::from_str(&self.config.api_key)
67 .map_err(|_| LlmError::Auth("API key contains invalid header characters".into()))?,
68 );
69 headers.insert(
70 "anthropic-version",
71 HeaderValue::from_str(&self.config.api_version).map_err(|_| {
72 LlmError::InvalidRequest("API version contains invalid header characters".into())
73 })?,
74 );
75 headers.insert("content-type", HeaderValue::from_static("application/json"));
76 Ok(headers)
77 }
78
79 fn messages_url(&self) -> String {
81 let base = self.config.base_url.trim_end_matches('/');
82 format!("{base}/v1/messages")
83 }
84
85 async fn send_request(
88 &self,
89 params: &ChatParams,
90 stream: bool,
91 ) -> Result<reqwest::Response, LlmError> {
92 let request_body = convert::build_request(params, &self.config, stream)?;
93
94 let mut headers = self.default_headers()?;
95 if let Some(extra) = ¶ms.extra_headers {
96 headers.extend(extra.iter().map(|(k, v)| (k.clone(), v.clone())));
97 }
98
99 let mut req = self
100 .client
101 .post(self.messages_url())
102 .headers(headers)
103 .json(&request_body);
104
105 if let Some(timeout) = params.timeout {
106 req = req.timeout(timeout);
107 }
108
109 let response = req.send().await.map_err(|e| {
110 if e.is_timeout() {
111 LlmError::Timeout {
112 elapsed_ms: params
113 .timeout
114 .or(self.config.timeout)
115 .map_or(0, |d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX)),
116 }
117 } else {
118 LlmError::Http {
119 status: e.status().map(|s| {
120 http::StatusCode::from_u16(s.as_u16())
121 .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR)
122 }),
123 message: e.to_string(),
124 retryable: e.is_connect() || e.is_timeout(),
125 }
126 }
127 })?;
128
129 let status = response.status();
130 if !status.is_success() {
131 let body = response.text().await.unwrap_or_default();
132 let http_status = http::StatusCode::from_u16(status.as_u16())
133 .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
134 return Err(convert::convert_error(http_status, &body));
135 }
136
137 Ok(response)
138 }
139}
140
141impl Provider for AnthropicProvider {
142 #[instrument(skip_all, fields(model = %self.config.model))]
143 async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
144 let response = self.send_request(params, false).await?;
145
146 let api_response: crate::types::Response =
147 response
148 .json()
149 .await
150 .map_err(|e| LlmError::ResponseFormat {
151 message: format!("Failed to parse Anthropic response: {e}"),
152 raw: String::new(),
153 })?;
154
155 Ok(convert::convert_response(api_response))
156 }
157
158 #[instrument(skip_all, fields(model = %self.config.model))]
159 async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
160 let response = self.send_request(params, true).await?;
161 Ok(crate::stream::into_stream(response))
162 }
163
164 fn metadata(&self) -> ProviderMetadata {
165 let mut capabilities = HashSet::new();
166 capabilities.insert(Capability::Tools);
167 capabilities.insert(Capability::Vision);
168 capabilities.insert(Capability::Reasoning);
169 capabilities.insert(Capability::Caching);
170 capabilities.insert(Capability::StructuredOutput);
171
172 ProviderMetadata {
173 name: "anthropic".into(),
174 model: self.config.model.clone(),
175 context_window: context_window_for_model(&self.config.model),
176 capabilities,
177 }
178 }
179}
180
/// Returns the context-window size reported in [`ProviderMetadata`] for `model`.
///
/// Any identifier containing `"claude"` maps to 200,000. Every other identifier falls
/// back to 100,000.
fn context_window_for_model(model: &str) -> u64 {
    const CLAUDE_WINDOW: u64 = 200_000;
    const FALLBACK_WINDOW: u64 = 100_000;

    let is_claude = model.contains("claude");
    if !is_claude {
        return FALLBACK_WINDOW;
    }
    CLAUDE_WINDOW
}
190
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn test_metadata() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            model: "claude-sonnet-4-20250514".into(),
            ..Default::default()
        });
        let meta = provider.metadata();

        assert_eq!(meta.name, "anthropic");
        assert_eq!(meta.model, "claude-sonnet-4-20250514");
        assert_eq!(meta.context_window, 200_000);
        assert!(meta.capabilities.contains(&Capability::Tools));
        assert!(meta.capabilities.contains(&Capability::Vision));
        assert!(meta.capabilities.contains(&Capability::Reasoning));
        assert!(meta.capabilities.contains(&Capability::Caching));
        assert!(meta.capabilities.contains(&Capability::StructuredOutput));
        // Pin the exact set so an accidental extra capability is caught too.
        assert_eq!(meta.capabilities.len(), 5);
    }

    #[test]
    fn test_context_window_claude_3_5() {
        assert_eq!(
            context_window_for_model("claude-3-5-haiku-20241022"),
            200_000
        );
        assert_eq!(
            context_window_for_model("claude-3-5-sonnet-20241022"),
            200_000
        );
    }

    #[test]
    fn test_context_window_claude_4() {
        assert_eq!(
            context_window_for_model("claude-sonnet-4-20250514"),
            200_000
        );
        assert_eq!(context_window_for_model("claude-opus-4-20250514"), 200_000);
    }

    #[test]
    fn test_context_window_unknown() {
        assert_eq!(context_window_for_model("some-future-model"), 100_000);
    }

    #[test]
    fn test_messages_url() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            base_url: "https://api.anthropic.com".into(),
            ..Default::default()
        });
        assert_eq!(
            provider.messages_url(),
            "https://api.anthropic.com/v1/messages"
        );
    }

    #[test]
    fn test_messages_url_custom_base() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            base_url: "http://localhost:8080".into(),
            ..Default::default()
        });
        assert_eq!(provider.messages_url(), "http://localhost:8080/v1/messages");
    }

    #[test]
    fn test_messages_url_trailing_slash() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            base_url: "https://proxy.example.com/".into(),
            ..Default::default()
        });
        assert_eq!(
            provider.messages_url(),
            "https://proxy.example.com/v1/messages"
        );
    }

    #[test]
    fn test_messages_url_multiple_trailing_slashes() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            base_url: "https://proxy.example.com///".into(),
            ..Default::default()
        });
        assert_eq!(
            provider.messages_url(),
            "https://proxy.example.com/v1/messages"
        );
    }

    #[test]
    fn test_default_headers() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            api_key: "sk-ant-test123".into(),
            api_version: "2023-06-01".into(),
            ..Default::default()
        });
        let headers = provider.default_headers().unwrap();

        assert_eq!(headers.get("x-api-key").unwrap(), "sk-ant-test123");
        assert_eq!(headers.get("anthropic-version").unwrap(), "2023-06-01");
        assert_eq!(headers.get("content-type").unwrap(), "application/json");
        assert_eq!(headers.len(), 3);
    }

    #[test]
    fn test_default_headers_invalid_api_key() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            api_key: "invalid\nkey".into(),
            ..Default::default()
        });
        let err = provider.default_headers().unwrap_err();
        assert!(matches!(err, llm_stack::LlmError::Auth(_)));
    }

    #[test]
    fn test_default_headers_invalid_api_version() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            api_key: "sk-ant-test123".into(),
            api_version: "2023-06-01\n".into(),
            ..Default::default()
        });
        let err = provider.default_headers().unwrap_err();
        assert!(matches!(err, LlmError::InvalidRequest(_)));
    }

    // Smoke test: the provider accepts a caller-built client. The client is private, so
    // this only checks that construction succeeds.
    #[test]
    fn test_new_with_custom_client() {
        let custom_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .unwrap();

        let provider = AnthropicProvider::new(AnthropicConfig {
            client: Some(custom_client),
            ..Default::default()
        });

        assert_eq!(provider.metadata().name, "anthropic");
    }

    // Smoke test: building the default client with a timeout must not panic.
    #[test]
    fn test_new_with_timeout() {
        let provider = AnthropicProvider::new(AnthropicConfig {
            timeout: Some(Duration::from_secs(30)),
            ..Default::default()
        });
        assert_eq!(provider.metadata().name, "anthropic");
    }
}