1use futures::Stream;
2use reqwest::header::{HeaderMap, HeaderValue};
3use reqwest::{Client as ReqwestClient, Response, header};
4use serde::Deserialize;
5use std::env;
6use std::fs;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::time::sleep;
10
11use crate::backoff::ExponentialBackoff;
12use crate::error::{Error, Result};
13use crate::sse::process_sse;
14use crate::types::{
15 Message, MessageCountTokensParams, MessageCreateParams, MessageStreamEvent, MessageTokensCount,
16 ModelInfo, ModelListParams, ModelListResponse,
17};
18
/// Base URL of the Anthropic API. Endpoint paths such as `messages` are appended to it
/// directly with `format!`, which is why it ends with a trailing slash.
const DEFAULT_API_URL: &str = "https://api.anthropic.com/v1/";
/// Value sent in the `anthropic-version` header of every request.
const ANTHROPIC_API_VERSION: &str = "2023-06-01";
/// Timeout used for the HTTP client built by `Anthropic::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
22
23#[derive(Debug, Clone)]
25pub struct Anthropic {
26 api_key: String,
27 client: ReqwestClient,
28 base_url: String,
29 timeout: Duration,
30 max_retries: usize,
31 throughput_ops_sec: f64,
32 reserve_capacity: f64,
33 cached_headers: Arc<HeaderMap>,
35}
36
37impl Anthropic {
38 fn resolve_api_key(key_value: &str) -> Result<String> {
40 if let Some(stripped) = key_value.strip_prefix("file://") {
41 let path = if stripped.starts_with('/') {
43 stripped.to_string()
45 } else {
46 stripped.to_string()
48 };
49
50 fs::read_to_string(&path)
51 .map(|content| content.trim().to_string())
52 .map_err(|e| {
53 Error::validation(
54 format!("Failed to read API key from file '{}': {}", path, e),
55 Some("api_key".to_string()),
56 )
57 })
58 } else {
59 Ok(key_value.to_string())
61 }
62 }
63
64 pub fn new(api_key: Option<String>) -> Result<Self> {
70 let api_key = match api_key {
71 Some(key) => Self::resolve_api_key(&key)?,
72 None => match env::var("CLAUDIUS_API_KEY").ok() {
73 Some(key) => Self::resolve_api_key(&key)?,
74 None => {
75 let env_key = env::var("ANTHROPIC_API_KEY").map_err(|_| {
76 Error::authentication(
77 "API key not provided and ANTHROPIC_API_KEY environment variable not set",
78 )
79 })?;
80 Self::resolve_api_key(&env_key)?
81 }
82 },
83 };
84
85 let timeout = DEFAULT_TIMEOUT;
86 let client = ReqwestClient::builder()
87 .timeout(timeout)
88 .pool_max_idle_per_host(10) .pool_idle_timeout(Duration::from_secs(90))
90 .tcp_keepalive(Duration::from_secs(60))
91 .build()
92 .map_err(|e| {
93 Error::http_client(
94 format!("Failed to build HTTP client: {e}"),
95 Some(Box::new(e)),
96 )
97 })?;
98
99 let cached_headers = Arc::new(Self::build_default_headers(&api_key)?);
101
102 Ok(Self {
103 api_key,
104 client,
105 base_url: DEFAULT_API_URL.to_string(),
106 timeout,
107 max_retries: 3,
108 throughput_ops_sec: 1.0 / 60.0,
109 reserve_capacity: 1.0 / 60.0,
110 cached_headers,
111 })
112 }
113
114 pub fn with_base_url(mut self, base_url: String) -> Self {
118 self.base_url = base_url;
119 self
120 }
121
122 pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
126 self.timeout = timeout;
127
128 let client = ReqwestClient::builder()
130 .timeout(timeout)
131 .pool_max_idle_per_host(10)
132 .pool_idle_timeout(Duration::from_secs(90))
133 .tcp_keepalive(Duration::from_secs(60))
134 .build()
135 .map_err(|e| {
136 Error::http_client(
137 "Failed to build HTTP client with new timeout",
138 Some(Box::new(e)),
139 )
140 })?;
141
142 self.client = client;
143 Ok(self)
144 }
145
146 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
150 self.max_retries = max_retries;
151 self
152 }
153
154 pub fn api_key(&self) -> &str {
156 &self.api_key
157 }
158
159 pub fn with_backoff_params(mut self, throughput_ops_sec: f64, reserve_capacity: f64) -> Self {
163 self.throughput_ops_sec = throughput_ops_sec;
164 self.reserve_capacity = reserve_capacity;
165 self
166 }
167
168 pub fn with_base_url_and_timeout(self, base_url: String, timeout: Duration) -> Result<Self> {
172 self.with_base_url(base_url).with_timeout(timeout)
173 }
174
175 fn build_default_headers(api_key: &str) -> Result<HeaderMap> {
177 let mut headers = HeaderMap::new();
178 headers.insert(
179 header::CONTENT_TYPE,
180 HeaderValue::from_static("application/json"),
181 );
182 headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
183 headers.insert(
184 "x-api-key",
185 HeaderValue::from_str(api_key).map_err(|e| {
186 Error::validation(
187 format!("Invalid API key format: {e}"),
188 Some("api_key".to_string()),
189 )
190 })?,
191 );
192 headers.insert(
193 "anthropic-version",
194 HeaderValue::from_static(ANTHROPIC_API_VERSION),
195 );
196 Ok(headers)
197 }
198
199 fn default_headers(&self) -> HeaderMap {
201 (*self.cached_headers).clone()
202 }
203
204 async fn retry_with_backoff<F, Fut, T>(&self, operation: F) -> Result<T>
206 where
207 F: Fn() -> Fut,
208 Fut: std::future::Future<Output = Result<T>>,
209 {
210 let backoff = ExponentialBackoff::new(self.throughput_ops_sec, self.reserve_capacity);
211 let mut last_error = None;
212
213 for attempt in 0..=self.max_retries {
214 match operation().await {
215 Ok(result) => return Ok(result),
216 Err(error) => {
217 if !error.is_retryable() {
219 return Err(error);
220 }
221
222 if attempt == self.max_retries {
224 last_error = Some(error);
225 break;
226 }
227
228 let exp_backoff_duration = backoff.next();
230
231 let header_backoff_duration = match &error {
233 Error::RateLimit {
234 retry_after: Some(seconds),
235 ..
236 } => Some(Duration::from_secs(*seconds)),
237 Error::ServiceUnavailable {
238 retry_after: Some(seconds),
239 ..
240 } => Some(Duration::from_secs(*seconds)),
241 _ => None,
242 };
243
244 let sleep_duration = match header_backoff_duration {
246 Some(header_duration) => exp_backoff_duration.max(header_duration),
247 None => exp_backoff_duration,
248 };
249
250 sleep(sleep_duration).await;
251 last_error = Some(error);
252 }
253 }
254 }
255
256 Err(last_error
257 .unwrap_or_else(|| Error::unknown("Failed after retries without capturing error")))
258 }
259
260 async fn process_error_response(response: Response) -> Error {
262 let status = response.status();
263 let status_code = status.as_u16();
264
265 let request_id = response
267 .headers()
268 .get("x-request-id")
269 .and_then(|val| val.to_str().ok())
270 .map(String::from);
271
272 let retry_after = response
273 .headers()
274 .get("retry-after")
275 .and_then(|val| val.to_str().ok())
276 .and_then(|val| val.parse::<u64>().ok());
277
278 #[derive(Deserialize)]
280 struct ErrorResponse {
281 error: Option<ErrorDetail>,
282 }
283
284 #[derive(Deserialize)]
285 struct ErrorDetail {
286 #[serde(rename = "type")]
287 error_type: Option<String>,
288 message: Option<String>,
289 param: Option<String>,
290 }
291
292 let error_body = match response.text().await {
293 Ok(body) => body,
294 Err(e) => {
295 return Error::http_client(
296 format!("Failed to read error response: {e}"),
297 Some(Box::new(e)),
298 );
299 }
300 };
301
302 let parsed_error = serde_json::from_str::<ErrorResponse>(&error_body).ok();
304 let error_type = parsed_error
305 .as_ref()
306 .and_then(|e| e.error.as_ref())
307 .and_then(|e| e.error_type.clone());
308 let error_message = parsed_error
309 .as_ref()
310 .and_then(|e| e.error.as_ref())
311 .and_then(|e| e.message.clone())
312 .unwrap_or_else(|| error_body.clone());
313 let error_param = parsed_error
314 .as_ref()
315 .and_then(|e| e.error.as_ref())
316 .and_then(|e| e.param.clone());
317
318 match status_code {
320 400 => Error::bad_request(error_message, error_param),
321 401 => Error::authentication(error_message),
322 403 => Error::permission(error_message),
323 404 => Error::not_found(error_message, None, None),
324 408 => Error::timeout(error_message, None),
325 429 => Error::rate_limit(error_message, retry_after),
326 500 => Error::internal_server(error_message, request_id),
327 502..=504 => Error::service_unavailable(error_message, retry_after),
328 529 => Error::rate_limit(error_message, retry_after),
329 _ => Error::api(status_code, error_type, error_message, request_id),
330 }
331 }
332
333 fn map_request_error(&self, e: reqwest::Error) -> Error {
335 if e.is_timeout() {
336 Error::timeout(
337 format!("Request timed out: {e}"),
338 Some(self.timeout.as_secs_f64()),
339 )
340 } else if e.is_connect() {
341 Error::connection(format!("Connection error: {e}"), Some(Box::new(e)))
342 } else {
343 Error::http_client(format!("Request failed: {e}"), Some(Box::new(e)))
344 }
345 }
346
347 async fn execute_post_request<T: serde::de::DeserializeOwned>(
349 &self,
350 url: &str,
351 body: &impl serde::Serialize,
352 headers: Option<HeaderMap>,
353 ) -> Result<T> {
354 let headers = headers.unwrap_or_else(|| self.default_headers());
355
356 let response = self
357 .client
358 .post(url)
359 .headers(headers)
360 .json(body)
361 .send()
362 .await
363 .map_err(|e| self.map_request_error(e))?;
364
365 if !response.status().is_success() {
366 return Err(Self::process_error_response(response).await);
367 }
368
369 response.json::<T>().await.map_err(|e| {
370 Error::serialization(format!("Failed to parse response: {e}"), Some(Box::new(e)))
371 })
372 }
373
374 async fn execute_get_request<T: serde::de::DeserializeOwned>(
376 &self,
377 url: &str,
378 query_params: Option<&[(String, String)]>,
379 ) -> Result<T> {
380 let mut request = self.client.get(url).headers(self.default_headers());
381
382 if let Some(params) = query_params {
383 for (key, value) in params {
384 request = request.query(&[(key, value)]);
385 }
386 }
387
388 let response = request
389 .send()
390 .await
391 .map_err(|e| self.map_request_error(e))?;
392
393 if !response.status().is_success() {
394 return Err(Self::process_error_response(response).await);
395 }
396
397 response.json::<T>().await.map_err(|e| {
398 Error::serialization(format!("Failed to parse response: {e}"), Some(Box::new(e)))
399 })
400 }
401
402 pub async fn send(&self, mut params: MessageCreateParams) -> Result<Message> {
404 params.validate()?;
406
407 params.stream = false;
409
410 self.retry_with_backoff(|| async {
411 let url = format!("{}messages", self.base_url);
412 self.execute_post_request(&url, ¶ms, None).await
413 })
414 .await
415 }
416
417 pub async fn stream(
421 &self,
422 mut params: MessageCreateParams,
423 ) -> Result<impl Stream<Item = Result<MessageStreamEvent>>> {
424 params.validate()?;
426
427 params.stream = true;
429
430 let response = self
431 .retry_with_backoff(|| async {
432 let url = format!("{}messages", self.base_url);
433
434 let mut headers = self.default_headers();
435 headers.insert(
436 header::ACCEPT,
437 HeaderValue::from_static("text/event-stream"),
438 );
439
440 let response = self
441 .client
442 .post(&url)
443 .headers(headers)
444 .json(¶ms)
445 .send()
446 .await
447 .map_err(|e| self.map_request_error(e))?;
448
449 if !response.status().is_success() {
450 return Err(Self::process_error_response(response).await);
451 }
452
453 Ok(response)
454 })
455 .await?;
456
457 let stream = response.bytes_stream();
459
460 Ok(process_sse(stream))
462 }
463
464 pub async fn count_tokens(
469 &self,
470 params: MessageCountTokensParams,
471 ) -> Result<MessageTokensCount> {
472 self.retry_with_backoff(|| async {
473 let url = format!("{}messages/count_tokens", self.base_url);
474 self.execute_post_request(&url, ¶ms, None).await
475 })
476 .await
477 }
478
479 pub async fn list_models(&self, params: Option<ModelListParams>) -> Result<ModelListResponse> {
484 self.retry_with_backoff(|| async {
485 let url = format!("{}models", self.base_url);
486
487 let query_params = params.as_ref().map(|p| {
488 let mut params = Vec::new();
489 if let Some(ref after_id) = p.after_id {
490 params.push(("after_id".to_string(), after_id.clone()));
491 }
492 if let Some(ref before_id) = p.before_id {
493 params.push(("before_id".to_string(), before_id.clone()));
494 }
495 if let Some(limit) = p.limit {
496 params.push(("limit".to_string(), limit.to_string()));
497 }
498 params
499 });
500
501 self.execute_get_request(&url, query_params.as_deref())
502 .await
503 })
504 .await
505 }
506
507 pub async fn get_model(&self, model_id: &str) -> Result<ModelInfo> {
512 self.retry_with_backoff(|| async {
513 let url = format!("{}models/{}", self.base_url, model_id);
514 self.execute_get_request(&url, None).await
515 })
516 .await
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a client directly from its fields, bypassing `Anthropic::new`, so that
    /// retry tests need neither an API key nor environment variables.
    fn test_client(max_retries: usize, throughput_ops_sec: f64, reserve_capacity: f64) -> Anthropic {
        Anthropic {
            api_key: "test".to_string(),
            client: ReqwestClient::new(),
            base_url: "http://localhost".to_string(),
            timeout: Duration::from_secs(1),
            max_retries,
            throughput_ops_sec,
            reserve_capacity,
            cached_headers: Arc::new(HeaderMap::new()),
        }
    }

    #[tokio::test]
    async fn retry_logic_with_backoff() {
        let client = test_client(2, 1.0 / 60.0, 1.0 / 60.0);
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = attempt_counter.clone();

        let result = client
            .retry_with_backoff(|| {
                let counter = counter_clone.clone();
                async move {
                    let attempt = counter.fetch_add(1, Ordering::SeqCst);
                    match attempt {
                        0 | 1 => Err(Error::rate_limit("Rate limited", Some(1))),
                        _ => Ok("success".to_string()),
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_logic_with_non_retryable_error() {
        let client = test_client(2, 1.0 / 60.0, 1.0 / 60.0);
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = attempt_counter.clone();

        let result: Result<String> = client
            .retry_with_backoff(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(Error::authentication("Invalid API key"))
                }
            })
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().is_authentication());
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retry_logic_max_retries_exceeded() {
        let client = test_client(2, 1.0 / 60.0, 1.0 / 60.0);
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = attempt_counter.clone();

        let result: Result<String> = client
            .retry_with_backoff(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(Error::rate_limit("Always rate limited", Some(1)))
                }
            })
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().is_rate_limit());
        // One initial attempt plus `max_retries` retries.
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn error_529_is_retryable() {
        let client = test_client(2, 1.0 / 60.0, 1.0 / 60.0);
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = attempt_counter.clone();

        let result = client
            .retry_with_backoff(|| {
                let counter = counter_clone.clone();
                async move {
                    let attempt = counter.fetch_add(1, Ordering::SeqCst);
                    match attempt {
                        0 | 1 => Err(Error::api(
                            529,
                            Some("overloaded_error".to_string()),
                            "Overloaded".to_string(),
                            None,
                        )),
                        _ => Ok("success".to_string()),
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn error_529_variants_are_retryable() {
        let error = Error::api(
            529,
            Some("overloaded_error".to_string()),
            "Overloaded".to_string(),
            None,
        );
        assert!(error.is_retryable());

        let rate_limit_error = Error::rate_limit("Overloaded", Some(5));
        assert!(rate_limit_error.is_retryable());
    }

    #[test]
    fn resolve_api_key_regular_value() {
        let result = Anthropic::resolve_api_key("sk-test-key-123");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-test-key-123");
    }

    #[test]
    fn resolve_api_key_file_url_absolute() {
        let test_dir = std::env::temp_dir().join(format!("claudius_test_{}", std::process::id()));
        std::fs::create_dir_all(&test_dir).unwrap();
        let test_file = test_dir.join("test_api_key.txt");
        std::fs::write(&test_file, "sk-test-from-file-123\n").unwrap();

        let file_url = format!("file://{}", test_file.display());
        let result = Anthropic::resolve_api_key(&file_url);

        std::fs::remove_dir_all(&test_dir).unwrap();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-test-from-file-123");
    }

    #[test]
    fn resolve_api_key_file_url_relative() {
        // Relative to the current working directory; the pid keeps concurrent runs apart.
        let test_file = format!("test_relative_key_{}.txt", std::process::id());
        std::fs::write(&test_file, "sk-relative-key-456\n").unwrap();

        let file_url = format!("file://{}", test_file);
        let result = Anthropic::resolve_api_key(&file_url);

        std::fs::remove_file(&test_file).unwrap();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-relative-key-456");
    }

    #[test]
    fn resolve_api_key_file_url_nonexistent() {
        let result = Anthropic::resolve_api_key("file:///nonexistent/path/to/key.txt");
        assert!(result.is_err());

        let error = result.unwrap_err();
        assert!(error.is_validation());
        assert!(format!("{}", error).contains("Failed to read API key from file"));
    }

    #[test]
    fn resolve_api_key_file_url_with_whitespace() {
        let test_file = format!("test_whitespace_key_{}.txt", std::process::id());
        std::fs::write(&test_file, "  sk-whitespace-key-789  \n  ").unwrap();

        let file_url = format!("file://{}", test_file);
        let result = Anthropic::resolve_api_key(&file_url);

        std::fs::remove_file(&test_file).unwrap();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-whitespace-key-789");
    }

    #[test]
    fn client_builder_methods() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();

        let configured_client = client
            .with_base_url("https://custom.api.com/v1/".to_string())
            .with_max_retries(5)
            .with_backoff_params(2.0, 1.0);

        assert_eq!(configured_client.base_url, "https://custom.api.com/v1/");
        assert_eq!(configured_client.max_retries, 5);
        assert_eq!(configured_client.throughput_ops_sec, 2.0);
        assert_eq!(configured_client.reserve_capacity, 1.0);
    }

    #[test]
    fn client_timeout_configuration() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();
        let timeout = Duration::from_secs(30);

        let configured_client = client.with_timeout(timeout).unwrap();
        assert_eq!(configured_client.timeout, timeout);
    }

    #[test]
    fn client_default_headers_are_stable() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();

        let headers1 = client.default_headers();
        let headers2 = client.default_headers();

        assert_eq!(headers1, headers2);
        assert!(headers1.contains_key("x-api-key"));
        assert!(headers1.contains_key("anthropic-version"));
        assert!(headers1.contains_key("content-type"));
    }

    #[test]
    fn new_applies_default_configuration() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();

        assert_eq!(client.api_key(), "test_key");
        assert_eq!(client.base_url, DEFAULT_API_URL);
        assert_eq!(client.timeout, DEFAULT_TIMEOUT);
        assert_eq!(client.max_retries, 3);
    }

    #[tokio::test]
    async fn concurrent_retry_safety() {
        let client = test_client(1, 1.0, 1.0);
        let attempt_counter = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let client_clone = client.clone();
                let counter_clone = attempt_counter.clone();
                tokio::spawn(async move {
                    client_clone
                        .retry_with_backoff(|| {
                            let counter = counter_clone.clone();
                            async move {
                                counter.fetch_add(1, Ordering::SeqCst);
                                Ok::<String, Error>("success".to_string())
                            }
                        })
                        .await
                })
            })
            .collect();

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok());
        }

        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }
}