1use crate::core::error::{ApiError, ConnectionError, ConnectionTimeoutError, Error};
4use crate::internal::backoff::{default_retry_timeout_ms, retry_after_ms};
5use crate::internal::headers::{default_headers, merge_headers};
6use crate::resources::beta::Beta;
7use crate::resources::completions::Completions;
8use crate::resources::messages::Messages;
9use crate::resources::models::Models;
10use http::HeaderMap;
11use reqwest::{Client as HttpClient, Method, Response};
12use serde::de::DeserializeOwned;
13use serde::Serialize;
14use std::collections::HashMap;
15use std::time::Duration;
16
/// Turn marker placed before the human's text in a text-completions prompt.
pub const HUMAN_PROMPT: &str = "\n\nHuman:";
/// Turn marker placed before the assistant's text in a text-completions prompt.
pub const AI_PROMPT: &str = "\n\nAssistant:";

/// Origin used when neither `ClientOptions::base_url` nor `ANTHROPIC_BASE_URL` is set.
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// HTTP client timeout used when `ClientOptions::timeout` is `None` (ten minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// Retry budget used when `ClientOptions::max_retries` is `None`.
const DEFAULT_MAX_RETRIES: u32 = 2;
24
/// Configuration accepted by `Anthropic::with_options`.
///
/// Every field is optional or may be left empty; unset values are resolved by
/// the client constructor.
#[derive(Clone)]
pub struct ClientOptions {
    /// API key; when `None`, the constructor falls back to `ANTHROPIC_API_KEY`.
    pub api_key: Option<String>,
    /// Optional bearer token, sent as an `authorization: Bearer <token>` header.
    pub auth_token: Option<String>,
    /// Base URL override; when `None`, `ANTHROPIC_BASE_URL` or `DEFAULT_BASE_URL` is used.
    pub base_url: Option<String>,
    /// HTTP client timeout; `DEFAULT_TIMEOUT` when `None`.
    pub timeout: Option<Duration>,
    /// Maximum number of retries per request; `DEFAULT_MAX_RETRIES` when `None`.
    pub max_retries: Option<u32>,
    /// Extra headers merged into every request's headers.
    pub default_headers: HashMap<String, String>,
    /// Client-wide query parameters.
    pub default_query: HashMap<String, String>,
}

/// Hand-written so credentials never appear in `{:?}` output (e.g. logs):
/// `api_key` and `auth_token` render as `Some("<redacted>")` or `None`.
impl std::fmt::Debug for ClientOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn redact(secret: Option<&str>) -> Option<&'static str> {
            secret.map(|_| "<redacted>")
        }

        f.debug_struct("ClientOptions")
            .field("api_key", &redact(self.api_key.as_deref()))
            .field("auth_token", &redact(self.auth_token.as_deref()))
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .field("default_headers", &self.default_headers)
            .field("default_query", &self.default_query)
            .finish()
    }
}
36
37impl Default for ClientOptions {
38 fn default() -> Self {
39 Self {
40 api_key: std::env::var("ANTHROPIC_API_KEY").ok(),
41 auth_token: std::env::var("ANTHROPIC_AUTH_TOKEN").ok(),
42 base_url: std::env::var("ANTHROPIC_BASE_URL").ok(),
43 timeout: None,
44 max_retries: None,
45 default_headers: HashMap::new(),
46 default_query: HashMap::new(),
47 }
48 }
49}
50
51#[derive(Clone)]
53pub struct Anthropic {
54 http: HttpClient,
55 api_key: String,
56 auth_token: Option<String>,
57 base_url: String,
58 #[allow(dead_code)]
59 timeout: Duration,
60 max_retries: u32,
61 default_headers: HashMap<String, String>,
62 #[allow(dead_code)]
63 default_query: HashMap<String, String>,
64 #[allow(dead_code)]
65 middleware: Vec<std::sync::Arc<dyn crate::core::middleware::Middleware>>,
66}
67
68impl Anthropic {
69 pub fn new() -> Result<Self, Error> {
70 Self::with_options(ClientOptions::default())
71 }
72
73 pub fn with_options(options: ClientOptions) -> Result<Self, Error> {
74 let api_key = options
75 .api_key
76 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
77 .ok_or_else(|| {
78 Error::Anthropic(crate::core::error::AnthropicError(
79 "Missing API key: set ANTHROPIC_API_KEY or pass api_key in ClientOptions".into(),
80 ))
81 })?;
82
83 let base_url = options
84 .base_url
85 .or_else(|| std::env::var("ANTHROPIC_BASE_URL").ok())
86 .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
87
88 let timeout = options.timeout.unwrap_or(DEFAULT_TIMEOUT);
89 let max_retries = options.max_retries.unwrap_or(DEFAULT_MAX_RETRIES);
90
91 let http = HttpClient::builder()
92 .timeout(timeout)
93 .build()
94 .map_err(|e| {
95 Error::Connection(ConnectionError {
96 message: e.to_string(),
97 source: Some(Box::new(e)),
98 })
99 })?;
100
101 Ok(Self {
102 http,
103 api_key,
104 auth_token: options.auth_token,
105 base_url: base_url.trim_end_matches('/').to_string(),
106 timeout,
107 max_retries,
108 default_headers: options.default_headers,
109 default_query: options.default_query,
110 middleware: Vec::new(),
111 })
112 }
113
114 pub fn with_api_key(api_key: impl Into<String>) -> Result<Self, Error> {
115 Self::with_options(ClientOptions {
116 api_key: Some(api_key.into()),
117 ..Default::default()
118 })
119 }
120
121 pub fn base_url(&self) -> &str {
122 &self.base_url
123 }
124
125 pub fn max_retries(&self) -> u32 {
126 self.max_retries
127 }
128
129 pub fn messages(&self) -> Messages<'_> {
130 Messages::new(self)
131 }
132
133 pub fn models(&self) -> Models<'_> {
134 Models::new(self)
135 }
136
137 pub fn completions(&self) -> Completions<'_> {
138 Completions::new(self)
139 }
140
141 pub fn beta(&self) -> Beta<'_> {
142 Beta::new(self)
143 }
144
145 pub(crate) async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T, Error> {
146 self.request(Method::GET, path, None::<&()>, false).await
147 }
148
149 pub(crate) async fn get_with_query<T: DeserializeOwned>(
150 &self,
151 path: &str,
152 query: Option<&[(&str, &str)]>,
153 ) -> Result<T, Error> {
154 let mut url = self.build_url(path);
155 if let Some(q) = query {
156 let qs: Vec<String> = q
157 .iter()
158 .map(|(k, v)| format!("{}={}", k, urlencoding_encode(v)))
159 .collect();
160 if !qs.is_empty() {
161 url.push('?');
162 url.push_str(&qs.join("&"));
163 }
164 }
165 self.request_url(Method::GET, &url, None::<&()>, false)
166 .await
167 }
168
169 pub(crate) async fn post<T, B>(&self, path: &str, body: &B) -> Result<T, Error>
170 where
171 T: DeserializeOwned,
172 B: Serialize + ?Sized,
173 {
174 self.request(Method::POST, path, Some(body), false).await
175 }
176
177 pub(crate) async fn post_streaming<B>(
178 &self,
179 path: &str,
180 body: &B,
181 ) -> Result<Response, Error>
182 where
183 B: Serialize + ?Sized,
184 {
185 self.request_raw(Method::POST, path, Some(body), true)
186 .await
187 }
188
189 pub(crate) async fn post_empty<T>(&self, path: &str) -> Result<T, Error>
190 where
191 T: DeserializeOwned,
192 {
193 self.request(Method::POST, path, None::<&()>, false).await
194 }
195
196 #[allow(dead_code)]
197 pub(crate) async fn delete<T>(&self, path: &str) -> Result<T, Error>
198 where
199 T: DeserializeOwned,
200 {
201 self.request(Method::DELETE, path, None::<&()>, false).await
202 }
203
204 #[allow(dead_code)]
205 pub(crate) async fn patch<T, B>(&self, path: &str, body: &B) -> Result<T, Error>
206 where
207 T: DeserializeOwned,
208 B: Serialize + ?Sized,
209 {
210 self.request(Method::PATCH, path, Some(body), false).await
211 }
212
213 pub(crate) async fn get_beta<T>(
214 &self,
215 path: &str,
216 beta_headers: &[String],
217 query: Option<&[(&str, &str)]>,
218 ) -> Result<T, Error>
219 where
220 T: DeserializeOwned,
221 {
222 self.request_beta(Method::GET, path, None::<&()>, beta_headers, query, false)
223 .await
224 }
225
226 pub(crate) async fn post_beta<T, B>(
227 &self,
228 path: &str,
229 body: &B,
230 beta_headers: &[String],
231 ) -> Result<T, Error>
232 where
233 T: DeserializeOwned,
234 B: Serialize + ?Sized,
235 {
236 self.request_beta(Method::POST, path, Some(body), beta_headers, None, false)
237 .await
238 }
239
240 pub(crate) async fn delete_beta<T>(
241 &self,
242 path: &str,
243 beta_headers: &[String],
244 ) -> Result<T, Error>
245 where
246 T: DeserializeOwned,
247 {
248 self.request_beta(
249 Method::DELETE,
250 path,
251 None::<&()>,
252 beta_headers,
253 None,
254 false,
255 )
256 .await
257 }
258
259 async fn request_beta<T, B>(
260 &self,
261 method: Method,
262 path: &str,
263 body: Option<&B>,
264 beta_headers: &[String],
265 query: Option<&[(&str, &str)]>,
266 stream: bool,
267 ) -> Result<T, Error>
268 where
269 T: DeserializeOwned,
270 B: Serialize + ?Sized,
271 {
272 let mut extra_headers = self.default_headers.clone();
273 if !beta_headers.is_empty() {
274 extra_headers.insert("anthropic-beta".to_string(), beta_headers.join(","));
275 }
276 let url = self.build_url(path);
277 let mut full_url = url;
278 if let Some(q) = query {
279 let qs: Vec<String> = q
280 .iter()
281 .map(|(k, v)| format!("{}={}", k, urlencoding_encode(v)))
282 .collect();
283 if !qs.is_empty() {
284 full_url.push('?');
285 full_url.push_str(&qs.join("&"));
286 }
287 }
288
289 let response = self
290 .make_request_with_retries_beta(method, &full_url, body, stream, self.max_retries, &extra_headers)
291 .await?;
292 let status = response.status().as_u16();
293 let headers = response.headers().clone();
294 let bytes = response.bytes().await.map_err(|e| {
295 Error::Connection(ConnectionError {
296 message: e.to_string(),
297 source: Some(Box::new(e)),
298 })
299 })?;
300
301 if !(200..300).contains(&status) {
302 let body_json = serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null);
303 return Err(ApiError::generate(
304 Some(status),
305 body_json,
306 None,
307 header_map_from_reqwest(&headers),
308 ));
309 }
310
311 serde_json::from_slice(&bytes).map_err(|e| {
312 Error::Anthropic(crate::core::error::AnthropicError(format!(
313 "failed to parse response JSON: {e}"
314 )))
315 })
316 }
317
318 async fn make_request_with_retries_beta<B>(
319 &self,
320 method: Method,
321 url: &str,
322 body: Option<&B>,
323 stream: bool,
324 mut retries_remaining: u32,
325 extra_headers: &HashMap<String, String>,
326 ) -> Result<Response, Error>
327 where
328 B: Serialize + ?Sized,
329 {
330 loop {
331 let mut headers = self.build_headers(stream)?;
332 merge_headers(&mut headers, extra_headers);
333 let mut req = self.http.request(method.clone(), url).headers(headers);
334
335 if let Some(b) = body {
336 req = req.json(b);
337 }
338
339 let response = match req.send().await {
340 Ok(r) => r,
341 Err(e) => {
342 if e.is_timeout() {
343 return Err(Error::ConnectionTimeout(ConnectionTimeoutError(
344 e.to_string(),
345 )));
346 }
347 if retries_remaining == 0 {
348 return Err(Error::Connection(ConnectionError {
349 message: e.to_string(),
350 source: Some(Box::new(e)),
351 }));
352 }
353 retries_remaining -= 1;
354 tokio::time::sleep(Duration::from_millis(default_retry_timeout_ms(
355 retries_remaining,
356 self.max_retries,
357 )))
358 .await;
359 continue;
360 }
361 };
362
363 let status = response.status().as_u16();
364 if (200..300).contains(&status) {
365 return Ok(response);
366 }
367
368 if retries_remaining == 0 || !should_retry(status, response.headers()) {
369 return Ok(response);
370 }
371
372 let wait = retry_after_ms(response.headers()).unwrap_or_else(|| {
373 default_retry_timeout_ms(retries_remaining - 1, self.max_retries)
374 });
375 retries_remaining -= 1;
376 tokio::time::sleep(Duration::from_millis(wait)).await;
377 }
378 }
379
380 async fn request<T, B>(
381 &self,
382 method: Method,
383 path: &str,
384 body: Option<&B>,
385 stream: bool,
386 ) -> Result<T, Error>
387 where
388 T: DeserializeOwned,
389 B: Serialize + ?Sized,
390 {
391 let response = self
392 .request_raw(method.clone(), path, body, stream)
393 .await?;
394 let status = response.status().as_u16();
395 let headers = response.headers().clone();
396 let bytes = response.bytes().await.map_err(|e| {
397 Error::Connection(ConnectionError {
398 message: e.to_string(),
399 source: Some(Box::new(e)),
400 })
401 })?;
402
403 if !(200..300).contains(&status) {
404 let body_json = serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null);
405 return Err(ApiError::generate(
406 Some(status),
407 body_json,
408 None,
409 header_map_from_reqwest(&headers),
410 ));
411 }
412
413 serde_json::from_slice(&bytes).map_err(|e| {
414 Error::Anthropic(crate::core::error::AnthropicError(format!(
415 "failed to parse response JSON: {e}"
416 )))
417 })
418 }
419
420 async fn request_url<T, B>(
421 &self,
422 method: Method,
423 url: &str,
424 body: Option<&B>,
425 stream: bool,
426 ) -> Result<T, Error>
427 where
428 T: DeserializeOwned,
429 B: Serialize + ?Sized,
430 {
431 let response = self
432 .make_request_with_retries(method, url, body, stream, self.max_retries)
433 .await?;
434 let status = response.status().as_u16();
435 let headers = response.headers().clone();
436 let bytes = response.bytes().await.map_err(|e| {
437 Error::Connection(ConnectionError {
438 message: e.to_string(),
439 source: Some(Box::new(e)),
440 })
441 })?;
442
443 if !(200..300).contains(&status) {
444 let body_json = serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null);
445 return Err(ApiError::generate(
446 Some(status),
447 body_json,
448 None,
449 header_map_from_reqwest(&headers),
450 ));
451 }
452
453 serde_json::from_slice(&bytes).map_err(|e| {
454 Error::Anthropic(crate::core::error::AnthropicError(format!(
455 "failed to parse response JSON: {e}"
456 )))
457 })
458 }
459
460 async fn request_raw<B>(
461 &self,
462 method: Method,
463 path: &str,
464 body: Option<&B>,
465 stream: bool,
466 ) -> Result<Response, Error>
467 where
468 B: Serialize + ?Sized,
469 {
470 let url = self.build_url(path);
471 self.make_request_with_retries(method, &url, body, stream, self.max_retries)
472 .await
473 }
474
475 fn build_url(&self, path: &str) -> String {
476 format!("{}{}", self.base_url, path)
477 }
478
479 fn build_headers(&self, stream: bool) -> Result<HeaderMap, Error> {
480 let mut headers = default_headers(&self.api_key);
481 if let Some(token) = &self.auth_token {
482 headers.insert(
483 "authorization",
484 format!("Bearer {token}").parse().unwrap(),
485 );
486 }
487 if stream {
488 headers.insert("accept", "text/event-stream".parse().unwrap());
489 } else {
490 headers.insert("accept", "application/json".parse().unwrap());
491 }
492 headers.insert("content-type", "application/json".parse().unwrap());
493 merge_headers(&mut headers, &self.default_headers);
494 Ok(headers)
495 }
496
497 async fn make_request_with_retries<B>(
498 &self,
499 method: Method,
500 url: &str,
501 body: Option<&B>,
502 stream: bool,
503 mut retries_remaining: u32,
504 ) -> Result<Response, Error>
505 where
506 B: Serialize + ?Sized,
507 {
508 loop {
509 let headers = self.build_headers(stream)?;
510 let mut req = self.http.request(method.clone(), url).headers(headers);
511
512 if let Some(b) = body {
513 req = req.json(b);
514 }
515
516 let response = match req.send().await {
517 Ok(r) => r,
518 Err(e) => {
519 if e.is_timeout() {
520 return Err(Error::ConnectionTimeout(ConnectionTimeoutError(
521 e.to_string(),
522 )));
523 }
524 if retries_remaining == 0 {
525 return Err(Error::Connection(ConnectionError {
526 message: e.to_string(),
527 source: Some(Box::new(e)),
528 }));
529 }
530 retries_remaining -= 1;
531 tokio::time::sleep(Duration::from_millis(default_retry_timeout_ms(
532 retries_remaining,
533 self.max_retries,
534 )))
535 .await;
536 continue;
537 }
538 };
539
540 let status = response.status().as_u16();
541 if (200..300).contains(&status) {
542 return Ok(response);
543 }
544
545 if retries_remaining == 0 || !should_retry(status, response.headers()) {
546 return Ok(response);
547 }
548
549 let wait = retry_after_ms(response.headers()).unwrap_or_else(|| {
550 default_retry_timeout_ms(retries_remaining - 1, self.max_retries)
551 });
552 retries_remaining -= 1;
553 tokio::time::sleep(Duration::from_millis(wait)).await;
554 }
555 }
556}
557
558impl Default for Anthropic {
559 fn default() -> Self {
560 Self::new().expect("ANTHROPIC_API_KEY must be set for default client")
561 }
562}
563
564fn should_retry(status: u16, headers: &reqwest::header::HeaderMap) -> bool {
565 if let Some(v) = headers.get("x-should-retry") {
566 if let Ok(s) = v.to_str() {
567 if s == "true" {
568 return true;
569 }
570 if s == "false" {
571 return false;
572 }
573 }
574 }
575
576 matches!(status, 408 | 409 | 429) || (500..600).contains(&status)
577}
578
579fn header_map_from_reqwest(headers: &reqwest::header::HeaderMap) -> HeaderMap {
580 let mut map = HeaderMap::new();
581 for (k, v) in headers.iter() {
582 if let Ok(val) = http::HeaderValue::from_bytes(v.as_bytes()) {
583 map.insert(k.clone(), val);
584 }
585 }
586 map
587}
588
/// Percent-encodes `s` for use as a URL query key or value.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other byte of the string's UTF-8 encoding becomes `%XX`
/// with uppercase hex. Encoding bytes rather than `char`s keeps non-ASCII input
/// correct: `é` becomes `%C3%A9` instead of the truncated `%E9`.
fn urlencoding_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
597
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderValue;

    /// Header map carrying a single `x-should-retry` hint.
    fn retry_hint(value: &'static str) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-should-retry", HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn should_retry_rate_limit() {
        let headers = reqwest::header::HeaderMap::new();
        assert!(should_retry(429, &headers));
    }

    #[test]
    fn should_not_retry_bad_request() {
        let headers = reqwest::header::HeaderMap::new();
        assert!(!should_retry(400, &headers));
    }

    #[test]
    fn should_retry_timeouts_conflicts_and_server_errors() {
        let headers = reqwest::header::HeaderMap::new();
        for status in [408, 409, 500, 502, 503, 529, 599] {
            assert!(should_retry(status, &headers), "status {status} should retry");
        }
        for status in [200, 401, 403, 404, 422, 600] {
            assert!(!should_retry(status, &headers), "status {status} should not retry");
        }
    }

    #[test]
    fn x_should_retry_header_overrides_status() {
        assert!(should_retry(400, &retry_hint("true")));
        assert!(!should_retry(503, &retry_hint("false")));
        // Unrecognised hints fall back to the status-code rule.
        assert!(should_retry(503, &retry_hint("maybe")));
        assert!(!should_retry(400, &retry_hint("maybe")));
    }

    #[test]
    fn urlencoding_encodes_reserved_ascii() {
        assert_eq!(urlencoding_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(urlencoding_encode("a b&c=d"), "a%20b%26c%3Dd");
        assert_eq!(urlencoding_encode(""), "");
    }
}