1use std::env;
2use std::fs;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use futures::Stream;
9use reqwest::header::{HeaderMap, HeaderValue};
10use reqwest::{Client as ReqwestClient, Response, header};
11use serde::Deserialize;
12use tokio::time::sleep;
13
14use crate::AccumulatingStream;
15use crate::backoff::ExponentialBackoff;
16use crate::client_logger::ClientLogger;
17use crate::error::{Error, Result};
18use crate::observability::{
19 CLIENT_REQUEST_DURATION, CLIENT_REQUEST_ERRORS, CLIENT_REQUEST_RETRIES, CLIENT_REQUESTS,
20 CLIENT_RETRY_BACKOFF,
21};
22use crate::sse::process_sse;
23use crate::types::{
24 BatchRequest, BatchResultItem, FileObject, Message, MessageBatch, MessageCountTokensParams,
25 MessageCreateParams, MessageStreamEvent, MessageTokensCount, ModelInfo, ModelListParams,
26 ModelListResponse, PaginatedList, SkillObject, ThinkingConfig,
27};
28
29use base64::Engine as _;
30
31fn base64_encode(data: &[u8]) -> String {
33 base64::engine::general_purpose::STANDARD.encode(data)
34}
35
36pub struct LoggingStream<'a> {
42 inner: AccumulatingStream,
43 logger: &'a dyn ClientLogger,
44 receiver: Option<tokio::sync::oneshot::Receiver<Result<Message>>>,
45}
46
47impl<'a> LoggingStream<'a> {
48 fn new(
50 inner: AccumulatingStream,
51 receiver: tokio::sync::oneshot::Receiver<Result<Message>>,
52 logger: &'a dyn ClientLogger,
53 ) -> Self {
54 Self { inner, logger, receiver: Some(receiver) }
55 }
56}
57
58impl Stream for LoggingStream<'_> {
59 type Item = Result<MessageStreamEvent>;
60
61 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
62 let inner = Pin::new(&mut self.inner);
63 match inner.poll_next(cx) {
64 Poll::Ready(Some(Ok(event))) => {
65 self.logger.log_stream_event(&event);
66 Poll::Ready(Some(Ok(event)))
67 }
68 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
69 Poll::Ready(None) => {
70 if let Some(mut receiver) = self.receiver.take()
72 && let Ok(Ok(ref message)) = receiver.try_recv()
73 {
74 self.logger.log_stream_message(message);
75 }
76 Poll::Ready(None)
77 }
78 Poll::Pending => Poll::Pending,
79 }
80 }
81}
82
/// API endpoint used when `ANTHROPIC_BASE_URL` does not override it.
const DEFAULT_API_URL: &str = "https://api.anthropic.com";
/// Value sent in the `anthropic-version` header on every request.
const ANTHROPIC_API_VERSION: &str = "2023-06-01";
/// `anthropic-beta` identifier added when a request requires structured outputs.
const STRUCTURED_OUTPUTS_BETA: &str = "structured-outputs-2025-11-13";
/// Per-request timeout installed on the HTTP client by `Anthropic::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
87
88#[derive(Debug, Clone)]
90pub struct Anthropic {
91 api_key: String,
92 client: ReqwestClient,
93 base_url: String,
94 timeout: Duration,
95 max_retries: usize,
96 throughput_ops_sec: f64,
97 reserve_capacity: f64,
98 cached_headers: Arc<HeaderMap>,
100}
101
102impl Anthropic {
103 fn resolve_api_key(key_value: &str) -> Result<String> {
105 if let Some(stripped) = key_value.strip_prefix("file://") {
106 let path = if stripped.starts_with('/') {
108 stripped.to_string()
110 } else {
111 stripped.to_string()
113 };
114
115 fs::read_to_string(&path).map(|content| content.trim().to_string()).map_err(|e| {
116 Error::validation(
117 format!("Failed to read API key from file '{}': {}", path, e),
118 Some("api_key".to_string()),
119 )
120 })
121 } else {
122 Ok(key_value.to_string())
124 }
125 }
126
127 pub fn new(api_key: Option<String>) -> Result<Self> {
136 let api_key = match api_key {
137 Some(key) => Self::resolve_api_key(&key)?,
138 None => {
139 let env_key = env::var("ANTHROPIC_API_KEY").map_err(|_| {
140 Error::authentication(
141 "API key not provided and ANTHROPIC_API_KEY environment variable not set",
142 )
143 })?;
144 Self::resolve_api_key(&env_key)?
145 }
146 };
147
148 let timeout = DEFAULT_TIMEOUT;
149 let client = ReqwestClient::builder()
150 .timeout(timeout)
151 .pool_max_idle_per_host(10) .pool_idle_timeout(Duration::from_secs(90))
153 .tcp_keepalive(Duration::from_secs(60))
154 .build()
155 .map_err(|e| {
156 Error::http_client(format!("Failed to build HTTP client: {e}"), Some(Box::new(e)))
157 })?;
158
159 let cached_headers = Arc::new(Self::build_default_headers(&api_key)?);
161
162 let base_url =
164 env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_API_URL.to_string());
165
166 Ok(Self {
167 api_key,
168 client,
169 base_url,
170 timeout,
171 max_retries: 3,
172 throughput_ops_sec: 1.0 / 60.0,
173 reserve_capacity: 1.0 / 60.0,
174 cached_headers,
175 })
176 }
177
178 pub fn with_base_url(mut self, base_url: String) -> Self {
202 self.base_url = base_url;
203 self
204 }
205
206 pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
210 self.timeout = timeout;
211
212 let client = ReqwestClient::builder()
214 .timeout(timeout)
215 .pool_max_idle_per_host(10)
216 .pool_idle_timeout(Duration::from_secs(90))
217 .tcp_keepalive(Duration::from_secs(60))
218 .build()
219 .map_err(|e| {
220 Error::http_client(
221 "Failed to build HTTP client with new timeout",
222 Some(Box::new(e)),
223 )
224 })?;
225
226 self.client = client;
227 Ok(self)
228 }
229
230 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
234 self.max_retries = max_retries;
235 self
236 }
237
238 pub fn api_key(&self) -> &str {
240 &self.api_key
241 }
242
243 pub fn with_backoff_params(mut self, throughput_ops_sec: f64, reserve_capacity: f64) -> Self {
247 self.throughput_ops_sec = throughput_ops_sec;
248 self.reserve_capacity = reserve_capacity;
249 self
250 }
251
252 pub fn with_base_url_and_timeout(self, base_url: String, timeout: Duration) -> Result<Self> {
256 self.with_base_url(base_url).with_timeout(timeout)
257 }
258
259 fn build_default_headers(api_key: &str) -> Result<HeaderMap> {
261 let mut headers = HeaderMap::new();
262 headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("application/json"));
263 headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
264 headers.insert(
265 "x-api-key",
266 HeaderValue::from_str(api_key).map_err(|e| {
267 Error::validation(
268 format!("Invalid API key format: {e}"),
269 Some("api_key".to_string()),
270 )
271 })?,
272 );
273 headers.insert("anthropic-version", HeaderValue::from_static(ANTHROPIC_API_VERSION));
274 Ok(headers)
275 }
276
277 fn default_headers(&self) -> HeaderMap {
279 (*self.cached_headers).clone()
280 }
281
282 fn build_url(&self, endpoint: &str) -> String {
295 let base = self.base_url.trim_end_matches('/');
296 format!("{}/v1/{}", base, endpoint)
297 }
298
299 async fn retry_with_backoff<F, Fut, T>(&self, operation: F) -> Result<T>
301 where
302 F: Fn() -> Fut,
303 Fut: std::future::Future<Output = Result<T>>,
304 {
305 let backoff = ExponentialBackoff::new(self.throughput_ops_sec, self.reserve_capacity);
306 let mut last_error = None;
307
308 for attempt in 0..=self.max_retries {
309 match operation().await {
310 Ok(result) => return Ok(result),
311 Err(error) => {
312 if !error.is_retryable() {
314 return Err(error);
315 }
316
317 if attempt == self.max_retries {
319 last_error = Some(error);
320 break;
321 }
322
323 let exp_backoff_duration = backoff.next();
325
326 let header_backoff_duration = match &error {
328 Error::RateLimit { retry_after: Some(seconds), .. } => {
329 Some(Duration::from_secs(*seconds))
330 }
331 Error::ServiceUnavailable { retry_after: Some(seconds), .. } => {
332 Some(Duration::from_secs(*seconds))
333 }
334 _ => None,
335 };
336
337 let sleep_duration = match header_backoff_duration {
339 Some(header_duration) => exp_backoff_duration.max(header_duration),
340 None => exp_backoff_duration,
341 };
342
343 CLIENT_REQUEST_RETRIES.click();
344 CLIENT_RETRY_BACKOFF.add(sleep_duration.as_secs_f64());
345 sleep(sleep_duration).await;
346 last_error = Some(error);
347 }
348 }
349 }
350
351 Err(last_error
352 .unwrap_or_else(|| Error::unknown("Failed after retries without capturing error")))
353 }
354
355 async fn process_error_response(response: Response) -> Error {
357 let status = response.status();
358 let status_code = status.as_u16();
359
360 let request_id = response
362 .headers()
363 .get("x-request-id")
364 .and_then(|val| val.to_str().ok())
365 .map(String::from);
366
367 let retry_after = response
368 .headers()
369 .get("retry-after")
370 .and_then(|val| val.to_str().ok())
371 .and_then(|val| val.parse::<u64>().ok());
372
373 #[derive(Deserialize)]
375 struct ErrorResponse {
376 error: Option<ErrorDetail>,
377 }
378
379 #[derive(Deserialize)]
380 struct ErrorDetail {
381 #[serde(rename = "type")]
382 error_type: Option<String>,
383 message: Option<String>,
384 param: Option<String>,
385 }
386
387 let error_body = match response.text().await {
388 Ok(body) => body,
389 Err(e) => {
390 return Error::http_client(
391 format!("Failed to read error response: {e}"),
392 Some(Box::new(e)),
393 );
394 }
395 };
396
397 let parsed_error = serde_json::from_str::<ErrorResponse>(&error_body).ok();
399 let error_type =
400 parsed_error.as_ref().and_then(|e| e.error.as_ref()).and_then(|e| e.error_type.clone());
401 let error_message = parsed_error
402 .as_ref()
403 .and_then(|e| e.error.as_ref())
404 .and_then(|e| e.message.clone())
405 .unwrap_or_else(|| error_body.clone());
406 let error_param =
407 parsed_error.as_ref().and_then(|e| e.error.as_ref()).and_then(|e| e.param.clone());
408
409 match status_code {
411 400 => Error::bad_request(error_message, error_param),
412 401 => Error::authentication(error_message),
413 403 => Error::permission(error_message),
414 404 => Error::not_found(error_message, None, None),
415 408 => Error::timeout(error_message, None),
416 429 => Error::rate_limit(error_message, retry_after),
417 500 => Error::internal_server(error_message, request_id),
418 502..=504 => Error::service_unavailable(error_message, retry_after),
419 529 => Error::rate_limit(error_message, retry_after),
420 _ => Error::api(status_code, error_type, error_message, request_id),
421 }
422 }
423
424 fn map_request_error(&self, e: reqwest::Error) -> Error {
426 if e.is_timeout() {
427 Error::timeout(format!("Request timed out: {e}"), Some(self.timeout.as_secs_f64()))
428 } else if e.is_connect() {
429 Error::connection(format!("Connection error: {e}"), Some(Box::new(e)))
430 } else {
431 Error::http_client(format!("Request failed: {e}"), Some(Box::new(e)))
432 }
433 }
434
435 async fn execute_post_request<T: serde::de::DeserializeOwned>(
437 &self,
438 url: &str,
439 body: &impl serde::Serialize,
440 headers: Option<HeaderMap>,
441 ) -> Result<T> {
442 let headers = headers.unwrap_or_else(|| self.default_headers());
443
444 let response = self
445 .client
446 .post(url)
447 .headers(headers)
448 .json(body)
449 .send()
450 .await
451 .map_err(|e| self.map_request_error(e))?;
452
453 if !response.status().is_success() {
454 return Err(Self::process_error_response(response).await);
455 }
456
457 response.json::<T>().await.map_err(|e| {
458 Error::serialization(format!("Failed to parse response: {e}"), Some(Box::new(e)))
459 })
460 }
461
462 async fn execute_get_request<T: serde::de::DeserializeOwned>(
464 &self,
465 url: &str,
466 query_params: Option<&[(String, String)]>,
467 ) -> Result<T> {
468 let mut request = self.client.get(url).headers(self.default_headers());
469
470 if let Some(params) = query_params {
471 for (key, value) in params {
472 request = request.query(&[(key, value)]);
473 }
474 }
475
476 let response = request.send().await.map_err(|e| self.map_request_error(e))?;
477
478 if !response.status().is_success() {
479 return Err(Self::process_error_response(response).await);
480 }
481
482 response.json::<T>().await.map_err(|e| {
483 Error::serialization(format!("Failed to parse response: {e}"), Some(Box::new(e)))
484 })
485 }
486
487 pub async fn send(&self, mut params: MessageCreateParams) -> Result<Message> {
489 let start = Instant::now();
490 CLIENT_REQUESTS.click();
491
492 if let Err(err) = params.validate() {
494 CLIENT_REQUEST_ERRORS.click();
495 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
496 return Err(err);
497 }
498
499 params.stream = false;
501
502 if matches!(params.thinking, Some(ThinkingConfig::Enabled { .. })) {
504 params.temperature = Some(1.0);
505 }
506
507 let mut headers = self.default_headers();
509
510 if params.requires_structured_outputs_beta() {
512 headers.insert("anthropic-beta", HeaderValue::from_static(STRUCTURED_OUTPUTS_BETA));
513 }
514
515 if params.context_management.is_some() {
517 let existing =
518 headers.get("anthropic-beta").and_then(|v| v.to_str().ok()).unwrap_or("");
519 let new_val = if existing.is_empty() {
520 "context-management-2025-06-27".to_string()
521 } else {
522 format!("{existing},context-management-2025-06-27")
523 };
524 headers.insert(
525 "anthropic-beta",
526 HeaderValue::from_str(&new_val)
527 .unwrap_or_else(|_| HeaderValue::from_static("context-management-2025-06-27")),
528 );
529 }
530
531 if params.speed.is_some() {
533 let existing =
534 headers.get("anthropic-beta").and_then(|v| v.to_str().ok()).unwrap_or("");
535 let new_val = if existing.is_empty() {
536 "fast-mode-2026-02-01".to_string()
537 } else {
538 format!("{existing},fast-mode-2026-02-01")
539 };
540 headers.insert(
541 "anthropic-beta",
542 HeaderValue::from_str(&new_val)
543 .unwrap_or_else(|_| HeaderValue::from_static("fast-mode-2026-02-01")),
544 );
545 }
546
547 let result = self
548 .retry_with_backoff(|| async {
549 let url = self.build_url("messages");
550 self.execute_post_request(&url, ¶ms, Some(headers.clone())).await
551 })
552 .await;
553
554 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
555 if result.is_err() {
556 CLIENT_REQUEST_ERRORS.click();
557 }
558 result
559 }
560
561 pub async fn send_with_logger(
566 &self,
567 params: MessageCreateParams,
568 logger: &dyn ClientLogger,
569 ) -> Result<Message> {
570 let result = self.send(params).await;
571 if let Ok(ref message) = result {
572 logger.log_response(message);
573 }
574 result
575 }
576
577 pub async fn stream(
581 &self,
582 params: &MessageCreateParams,
583 ) -> Result<impl Stream<Item = Result<MessageStreamEvent>> + use<>> {
584 let start = Instant::now();
585 CLIENT_REQUESTS.click();
586
587 if let Err(err) = params.validate() {
589 CLIENT_REQUEST_ERRORS.click();
590 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
591 return Err(err);
592 }
593
594 let mut params = params.clone();
596 params.stream = true;
597
598 if matches!(params.thinking, Some(ThinkingConfig::Enabled { .. })) {
600 params.temperature = Some(1.0);
601 }
602
603 let needs_beta = params.requires_structured_outputs_beta();
605
606 let needs_context_mgmt = params.context_management.is_some();
608
609 let needs_fast_mode = params.speed.is_some();
611
612 let response = self
613 .retry_with_backoff(|| async {
614 let url = self.build_url("messages");
615
616 let mut headers = self.default_headers();
617 headers.insert(header::ACCEPT, HeaderValue::from_static("text/event-stream"));
618
619 let mut betas = Vec::new();
621 if needs_beta {
622 betas.push(STRUCTURED_OUTPUTS_BETA);
623 }
624 if needs_context_mgmt {
625 betas.push("context-management-2025-06-27");
626 }
627 if needs_fast_mode {
628 betas.push("fast-mode-2026-02-01");
629 }
630 if !betas.is_empty() {
631 let beta_val = betas.join(",");
632 headers.insert(
633 "anthropic-beta",
634 HeaderValue::from_str(&beta_val)
635 .unwrap_or_else(|_| HeaderValue::from_static(STRUCTURED_OUTPUTS_BETA)),
636 );
637 }
638
639 let response = self
640 .client
641 .post(&url)
642 .headers(headers)
643 .json(¶ms)
644 .send()
645 .await
646 .map_err(|e| self.map_request_error(e))?;
647
648 if !response.status().is_success() {
649 return Err(Self::process_error_response(response).await);
650 }
651
652 Ok(response)
653 })
654 .await;
655
656 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
657 let response = match response {
658 Ok(response) => response,
659 Err(err) => {
660 CLIENT_REQUEST_ERRORS.click();
661 return Err(err);
662 }
663 };
664
665 let stream = response.bytes_stream();
667
668 Ok(process_sse(stream))
670 }
671
672 pub async fn stream_with_logger<'a>(
682 &self,
683 params: &MessageCreateParams,
684 logger: &'a dyn ClientLogger,
685 ) -> Result<LoggingStream<'a>> {
686 let raw_stream = self.stream(params).await?;
687 let (accumulating_stream, receiver) = AccumulatingStream::new(raw_stream);
688 Ok(LoggingStream::new(accumulating_stream, receiver, logger))
689 }
690
691 pub async fn count_tokens(
696 &self,
697 params: MessageCountTokensParams,
698 ) -> Result<MessageTokensCount> {
699 let start = Instant::now();
700 CLIENT_REQUESTS.click();
701 let result = self
702 .retry_with_backoff(|| async {
703 let url = self.build_url("messages/count_tokens");
704 self.execute_post_request(&url, ¶ms, None).await
705 })
706 .await;
707
708 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
709 if result.is_err() {
710 CLIENT_REQUEST_ERRORS.click();
711 }
712 result
713 }
714
715 pub async fn list_models(&self, params: Option<ModelListParams>) -> Result<ModelListResponse> {
720 let start = Instant::now();
721 CLIENT_REQUESTS.click();
722 let result = self
723 .retry_with_backoff(|| async {
724 let url = self.build_url("models");
725
726 let query_params = params.as_ref().map(|p| {
727 let mut params = Vec::new();
728 if let Some(ref after_id) = p.after_id {
729 params.push(("after_id".to_string(), after_id.clone()));
730 }
731 if let Some(ref before_id) = p.before_id {
732 params.push(("before_id".to_string(), before_id.clone()));
733 }
734 if let Some(limit) = p.limit {
735 params.push(("limit".to_string(), limit.to_string()));
736 }
737 params
738 });
739
740 self.execute_get_request(&url, query_params.as_deref()).await
741 })
742 .await;
743
744 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
745 if result.is_err() {
746 CLIENT_REQUEST_ERRORS.click();
747 }
748 result
749 }
750
751 pub async fn get_model(&self, model_id: &str) -> Result<ModelInfo> {
756 let start = Instant::now();
757 CLIENT_REQUESTS.click();
758 let result = self
759 .retry_with_backoff(|| async {
760 let url = self.build_url(&format!("models/{}", model_id));
761 self.execute_get_request(&url, None).await
762 })
763 .await;
764
765 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
766 if result.is_err() {
767 CLIENT_REQUEST_ERRORS.click();
768 }
769 result
770 }
771
772 async fn execute_delete_request(&self, url: &str) -> Result<()> {
776 let response = self
777 .client
778 .delete(url)
779 .headers(self.default_headers())
780 .send()
781 .await
782 .map_err(|e| self.map_request_error(e))?;
783
784 if !response.status().is_success() {
785 return Err(Self::process_error_response(response).await);
786 }
787
788 Ok(())
789 }
790
791 fn pagination_params(
793 before_id: Option<&str>,
794 after_id: Option<&str>,
795 limit: Option<u32>,
796 ) -> Option<Vec<(String, String)>> {
797 let mut params = Vec::new();
798 if let Some(before) = before_id {
799 params.push(("before_id".to_string(), before.to_string()));
800 }
801 if let Some(after) = after_id {
802 params.push(("after_id".to_string(), after.to_string()));
803 }
804 if let Some(lim) = limit {
805 params.push(("limit".to_string(), lim.to_string()));
806 }
807 if params.is_empty() { None } else { Some(params) }
808 }
809
810 pub async fn create_batch(&self, requests: Vec<BatchRequest>) -> Result<MessageBatch> {
814 let start = Instant::now();
815 CLIENT_REQUESTS.click();
816 let body = serde_json::json!({ "requests": requests });
817 let result = self
818 .retry_with_backoff(|| async {
819 let url = self.build_url("messages/batches");
820 self.execute_post_request(&url, &body, None).await
821 })
822 .await;
823 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
824 if result.is_err() {
825 CLIENT_REQUEST_ERRORS.click();
826 }
827 result
828 }
829
830 pub async fn get_batch(&self, batch_id: &str) -> Result<MessageBatch> {
832 let start = Instant::now();
833 CLIENT_REQUESTS.click();
834 let result = self
835 .retry_with_backoff(|| async {
836 let url = self.build_url(&format!("messages/batches/{batch_id}"));
837 self.execute_get_request(&url, None).await
838 })
839 .await;
840 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
841 if result.is_err() {
842 CLIENT_REQUEST_ERRORS.click();
843 }
844 result
845 }
846
847 pub async fn batch_results(&self, batch_id: &str) -> Result<Vec<BatchResultItem>> {
849 let start = Instant::now();
850 CLIENT_REQUESTS.click();
851 let result = self
852 .retry_with_backoff(|| async {
853 let url = self.build_url(&format!("messages/batches/{batch_id}/results"));
854 let response = self
855 .client
856 .get(&url)
857 .headers(self.default_headers())
858 .send()
859 .await
860 .map_err(|e| self.map_request_error(e))?;
861
862 if !response.status().is_success() {
863 return Err(Self::process_error_response(response).await);
864 }
865
866 let text = response.text().await.map_err(|e| {
867 Error::serialization(
868 format!("Failed to read batch results: {e}"),
869 Some(Box::new(e)),
870 )
871 })?;
872
873 let mut items = Vec::new();
874 for line in text.lines() {
875 let trimmed = line.trim();
876 if trimmed.is_empty() {
877 continue;
878 }
879 let item: BatchResultItem = serde_json::from_str(trimmed)?;
880 items.push(item);
881 }
882 Ok(items)
883 })
884 .await;
885 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
886 if result.is_err() {
887 CLIENT_REQUEST_ERRORS.click();
888 }
889 result
890 }
891
892 pub async fn cancel_batch(&self, batch_id: &str) -> Result<MessageBatch> {
894 let start = Instant::now();
895 CLIENT_REQUESTS.click();
896 let result = self
897 .retry_with_backoff(|| async {
898 let url = self.build_url(&format!("messages/batches/{batch_id}/cancel"));
899 self.execute_post_request(&url, &serde_json::json!({}), None).await
900 })
901 .await;
902 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
903 if result.is_err() {
904 CLIENT_REQUEST_ERRORS.click();
905 }
906 result
907 }
908
909 pub async fn delete_batch(&self, batch_id: &str) -> Result<()> {
911 let start = Instant::now();
912 CLIENT_REQUESTS.click();
913 let result = self
914 .retry_with_backoff(|| async {
915 let url = self.build_url(&format!("messages/batches/{batch_id}"));
916 self.execute_delete_request(&url).await
917 })
918 .await;
919 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
920 if result.is_err() {
921 CLIENT_REQUEST_ERRORS.click();
922 }
923 result
924 }
925
926 pub async fn list_batches(
928 &self,
929 before_id: Option<&str>,
930 after_id: Option<&str>,
931 limit: Option<u32>,
932 ) -> Result<PaginatedList<MessageBatch>> {
933 let start = Instant::now();
934 CLIENT_REQUESTS.click();
935 let result = self
936 .retry_with_backoff(|| async {
937 let url = self.build_url("messages/batches");
938 let query = Self::pagination_params(before_id, after_id, limit);
939 self.execute_get_request(&url, query.as_deref()).await
940 })
941 .await;
942 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
943 if result.is_err() {
944 CLIENT_REQUEST_ERRORS.click();
945 }
946 result
947 }
948
949 pub async fn upload_file(
953 &self,
954 data: Vec<u8>,
955 mime_type: &str,
956 filename: &str,
957 purpose: &str,
958 ) -> Result<FileObject> {
959 let start = Instant::now();
960 CLIENT_REQUESTS.click();
961
962 let mime_type = mime_type.to_string();
963 let filename = filename.to_string();
964 let purpose = purpose.to_string();
965
966 let result = self
967 .retry_with_backoff(|| {
968 let data = data.clone();
969 let mime_type = mime_type.clone();
970 let filename = filename.clone();
971 let purpose = purpose.clone();
972 async move {
973 let url = self.build_url("files");
974 let part = reqwest::multipart::Part::bytes(data)
975 .file_name(filename)
976 .mime_str(&mime_type)
977 .map_err(|e| {
978 Error::validation(
979 format!("Invalid MIME type: {e}"),
980 Some("mime_type".to_string()),
981 )
982 })?;
983 let form =
984 reqwest::multipart::Form::new().text("purpose", purpose).part("file", part);
985
986 let response = self
987 .client
988 .post(&url)
989 .headers(self.default_headers())
990 .multipart(form)
991 .send()
992 .await
993 .map_err(|e| self.map_request_error(e))?;
994
995 if !response.status().is_success() {
996 return Err(Self::process_error_response(response).await);
997 }
998
999 response.json::<FileObject>().await.map_err(|e| {
1000 Error::serialization(
1001 format!("Failed to parse file response: {e}"),
1002 Some(Box::new(e)),
1003 )
1004 })
1005 }
1006 })
1007 .await;
1008 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1009 if result.is_err() {
1010 CLIENT_REQUEST_ERRORS.click();
1011 }
1012 result
1013 }
1014
1015 pub async fn get_file(&self, file_id: &str) -> Result<FileObject> {
1017 let start = Instant::now();
1018 CLIENT_REQUESTS.click();
1019 let result = self
1020 .retry_with_backoff(|| async {
1021 let url = self.build_url(&format!("files/{file_id}"));
1022 self.execute_get_request(&url, None).await
1023 })
1024 .await;
1025 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1026 if result.is_err() {
1027 CLIENT_REQUEST_ERRORS.click();
1028 }
1029 result
1030 }
1031
1032 pub async fn delete_file(&self, file_id: &str) -> Result<()> {
1034 let start = Instant::now();
1035 CLIENT_REQUESTS.click();
1036 let result = self
1037 .retry_with_backoff(|| async {
1038 let url = self.build_url(&format!("files/{file_id}"));
1039 self.execute_delete_request(&url).await
1040 })
1041 .await;
1042 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1043 if result.is_err() {
1044 CLIENT_REQUEST_ERRORS.click();
1045 }
1046 result
1047 }
1048
1049 pub async fn list_files(
1051 &self,
1052 before_id: Option<&str>,
1053 after_id: Option<&str>,
1054 limit: Option<u32>,
1055 ) -> Result<PaginatedList<FileObject>> {
1056 let start = Instant::now();
1057 CLIENT_REQUESTS.click();
1058 let result = self
1059 .retry_with_backoff(|| async {
1060 let url = self.build_url("files");
1061 let query = Self::pagination_params(before_id, after_id, limit);
1062 self.execute_get_request(&url, query.as_deref()).await
1063 })
1064 .await;
1065 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1066 if result.is_err() {
1067 CLIENT_REQUEST_ERRORS.click();
1068 }
1069 result
1070 }
1071
1072 pub async fn create_skill(
1076 &self,
1077 name: &str,
1078 description: &str,
1079 content: Vec<u8>,
1080 ) -> Result<SkillObject> {
1081 let start = Instant::now();
1082 CLIENT_REQUESTS.click();
1083 let body = serde_json::json!({
1084 "name": name,
1085 "description": description,
1086 "content": base64_encode(&content),
1087 });
1088 let result = self
1089 .retry_with_backoff(|| async {
1090 let url = self.build_url("skills");
1091 self.execute_post_request(&url, &body, None).await
1092 })
1093 .await;
1094 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1095 if result.is_err() {
1096 CLIENT_REQUEST_ERRORS.click();
1097 }
1098 result
1099 }
1100
1101 pub async fn get_skill(&self, skill_id: &str) -> Result<SkillObject> {
1103 let start = Instant::now();
1104 CLIENT_REQUESTS.click();
1105 let result = self
1106 .retry_with_backoff(|| async {
1107 let url = self.build_url(&format!("skills/{skill_id}"));
1108 self.execute_get_request(&url, None).await
1109 })
1110 .await;
1111 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1112 if result.is_err() {
1113 CLIENT_REQUEST_ERRORS.click();
1114 }
1115 result
1116 }
1117
1118 pub async fn update_skill(&self, skill_id: &str, content: Vec<u8>) -> Result<SkillObject> {
1120 let start = Instant::now();
1121 CLIENT_REQUESTS.click();
1122 let body = serde_json::json!({
1123 "content": base64_encode(&content),
1124 });
1125 let result = self
1126 .retry_with_backoff(|| async {
1127 let url = self.build_url(&format!("skills/{skill_id}"));
1128 self.execute_post_request(&url, &body, None).await
1129 })
1130 .await;
1131 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1132 if result.is_err() {
1133 CLIENT_REQUEST_ERRORS.click();
1134 }
1135 result
1136 }
1137
1138 pub async fn delete_skill(&self, skill_id: &str) -> Result<()> {
1140 let start = Instant::now();
1141 CLIENT_REQUESTS.click();
1142 let result = self
1143 .retry_with_backoff(|| async {
1144 let url = self.build_url(&format!("skills/{skill_id}"));
1145 self.execute_delete_request(&url).await
1146 })
1147 .await;
1148 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1149 if result.is_err() {
1150 CLIENT_REQUEST_ERRORS.click();
1151 }
1152 result
1153 }
1154
1155 pub async fn list_skills(
1157 &self,
1158 before_id: Option<&str>,
1159 after_id: Option<&str>,
1160 limit: Option<u32>,
1161 ) -> Result<PaginatedList<SkillObject>> {
1162 let start = Instant::now();
1163 CLIENT_REQUESTS.click();
1164 let result = self
1165 .retry_with_backoff(|| async {
1166 let url = self.build_url("skills");
1167 let query = Self::pagination_params(before_id, after_id, limit);
1168 self.execute_get_request(&url, query.as_deref()).await
1169 })
1170 .await;
1171 CLIENT_REQUEST_DURATION.add(start.elapsed().as_secs_f64());
1172 if result.is_err() {
1173 CLIENT_REQUEST_ERRORS.click();
1174 }
1175 result
1176 }
1177}
1178
1179#[cfg(test)]
1180mod tests {
1181 use super::*;
1182 use std::sync::Arc;
1183 use std::sync::atomic::{AtomicUsize, Ordering};
1184
1185 #[tokio::test]
1186 async fn retry_logic_with_backoff() {
1187 let client = Anthropic {
1188 api_key: "test".to_string(),
1189 client: ReqwestClient::new(),
1190 base_url: "http://localhost".to_string(),
1191 timeout: Duration::from_secs(1),
1192 max_retries: 2,
1193 throughput_ops_sec: 1.0 / 60.0,
1194 reserve_capacity: 1.0 / 60.0,
1195 cached_headers: Arc::new(HeaderMap::new()),
1196 };
1197
1198 let attempt_counter = Arc::new(AtomicUsize::new(0));
1199 let counter_clone = attempt_counter.clone();
1200
1201 let result = client
1202 .retry_with_backoff(|| {
1203 let counter = counter_clone.clone();
1204 async move {
1205 let attempt = counter.fetch_add(1, Ordering::SeqCst);
1206 match attempt {
1207 0 | 1 => Err(Error::rate_limit("Rate limited", Some(1))),
1208 _ => Ok("success".to_string()),
1209 }
1210 }
1211 })
1212 .await;
1213
1214 assert!(result.is_ok());
1215 assert_eq!(result.unwrap(), "success");
1216 assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
1217 }
1218
1219 #[tokio::test]
1220 async fn retry_logic_with_non_retryable_error() {
1221 let client = Anthropic {
1222 api_key: "test".to_string(),
1223 client: ReqwestClient::new(),
1224 base_url: "http://localhost".to_string(),
1225 timeout: Duration::from_secs(1),
1226 max_retries: 2,
1227 throughput_ops_sec: 1.0 / 60.0,
1228 reserve_capacity: 1.0 / 60.0,
1229 cached_headers: Arc::new(HeaderMap::new()),
1230 };
1231
1232 let attempt_counter = Arc::new(AtomicUsize::new(0));
1233 let counter_clone = attempt_counter.clone();
1234
1235 let result: Result<String> = client
1236 .retry_with_backoff(|| {
1237 let counter = counter_clone.clone();
1238 async move {
1239 counter.fetch_add(1, Ordering::SeqCst);
1240 Err(Error::authentication("Invalid API key"))
1241 }
1242 })
1243 .await;
1244
1245 assert!(result.is_err());
1246 assert!(result.unwrap_err().is_authentication());
1247 assert_eq!(attempt_counter.load(Ordering::SeqCst), 1);
1249 }
1250
1251 #[tokio::test]
1252 async fn retry_logic_max_retries_exceeded() {
1253 let client = Anthropic {
1254 api_key: "test".to_string(),
1255 client: ReqwestClient::new(),
1256 base_url: "http://localhost".to_string(),
1257 timeout: Duration::from_secs(1),
1258 max_retries: 2,
1259 throughput_ops_sec: 1.0 / 60.0,
1260 reserve_capacity: 1.0 / 60.0,
1261 cached_headers: Arc::new(HeaderMap::new()),
1262 };
1263
1264 let attempt_counter = Arc::new(AtomicUsize::new(0));
1265 let counter_clone = attempt_counter.clone();
1266
1267 let result: Result<String> = client
1268 .retry_with_backoff(|| {
1269 let counter = counter_clone.clone();
1270 async move {
1271 counter.fetch_add(1, Ordering::SeqCst);
1272 Err(Error::rate_limit("Always rate limited", Some(1)))
1273 }
1274 })
1275 .await;
1276
1277 assert!(result.is_err());
1278 assert!(result.unwrap_err().is_rate_limit());
1279 assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
1281 }
1282
1283 #[tokio::test]
1284 async fn error_529_is_retryable() {
1285 let client = Anthropic {
1287 api_key: "test".to_string(),
1288 client: ReqwestClient::new(),
1289 base_url: "http://localhost".to_string(),
1290 timeout: Duration::from_secs(1),
1291 max_retries: 2,
1292 throughput_ops_sec: 1.0 / 60.0,
1293 reserve_capacity: 1.0 / 60.0,
1294 cached_headers: Arc::new(HeaderMap::new()),
1295 };
1296
1297 let attempt_counter = Arc::new(AtomicUsize::new(0));
1298 let counter_clone = attempt_counter.clone();
1299
1300 let result = client
1301 .retry_with_backoff(|| {
1302 let counter = counter_clone.clone();
1303 async move {
1304 let attempt = counter.fetch_add(1, Ordering::SeqCst);
1305 match attempt {
1306 0 | 1 => {
1307 Err(Error::api(
1309 529,
1310 Some("overloaded_error".to_string()),
1311 "Overloaded".to_string(),
1312 None,
1313 ))
1314 }
1315 _ => Ok("success".to_string()),
1316 }
1317 }
1318 })
1319 .await;
1320
1321 assert!(result.is_ok());
1322 assert_eq!(result.unwrap(), "success");
1323 assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
1325 }
1326
1327 #[test]
1328 fn error_529_mapped_correctly() {
1329 let error =
1331 Error::api(529, Some("overloaded_error".to_string()), "Overloaded".to_string(), None);
1332 assert!(error.is_retryable());
1333
1334 let rate_limit_error = Error::rate_limit("Overloaded", Some(5));
1336 assert!(rate_limit_error.is_retryable());
1337 }
1338
1339 #[test]
1340 fn resolve_api_key_regular_value() {
1341 let result = Anthropic::resolve_api_key("sk-test-key-123");
1342 assert!(result.is_ok());
1343 assert_eq!(result.unwrap(), "sk-test-key-123");
1344 }
1345
1346 #[test]
1347 fn resolve_api_key_file_url_absolute() {
1348 let test_dir =
1349 std::env::temp_dir().join(format!("adk_anthropic_test_{}", std::process::id()));
1350 std::fs::create_dir_all(&test_dir).unwrap();
1351 let test_file = test_dir.join("test_api_key.txt");
1352 std::fs::write(&test_file, "sk-test-from-file-123\n").unwrap();
1353
1354 let file_url = format!("file://{}", test_file.display());
1355 let result = Anthropic::resolve_api_key(&file_url);
1356
1357 std::fs::remove_dir_all(&test_dir).unwrap();
1358
1359 assert!(result.is_ok());
1360 assert_eq!(result.unwrap(), "sk-test-from-file-123");
1361 }
1362
1363 #[test]
1364 fn resolve_api_key_file_url_relative() {
1365 let test_file = "test_relative_key.txt";
1366 std::fs::write(test_file, "sk-relative-key-456\n").unwrap();
1367
1368 let file_url = format!("file://{}", test_file);
1369 let result = Anthropic::resolve_api_key(&file_url);
1370
1371 std::fs::remove_file(test_file).unwrap();
1372
1373 assert!(result.is_ok());
1374 assert_eq!(result.unwrap(), "sk-relative-key-456");
1375 }
1376
1377 #[test]
1378 fn resolve_api_key_file_url_nonexistent() {
1379 let result = Anthropic::resolve_api_key("file:///nonexistent/path/to/key.txt");
1380 assert!(result.is_err());
1381
1382 let error = result.unwrap_err();
1383 assert!(error.is_validation());
1384 assert!(format!("{}", error).contains("Failed to read API key from file"));
1385 }
1386
1387 #[test]
1388 fn resolve_api_key_file_url_with_whitespace() {
1389 let test_file = "test_whitespace_key.txt";
1390 std::fs::write(test_file, " sk-whitespace-key-789 \n ").unwrap();
1391
1392 let file_url = format!("file://{}", test_file);
1393 let result = Anthropic::resolve_api_key(&file_url);
1394
1395 std::fs::remove_file(test_file).unwrap();
1396
1397 assert!(result.is_ok());
1398 assert_eq!(result.unwrap(), "sk-whitespace-key-789");
1399 }
1400
1401 #[test]
1402 fn client_builder_methods() {
1403 let client = Anthropic::new(Some("test_key".to_string())).unwrap();
1404
1405 let configured_client = client
1407 .with_base_url("https://custom.api.com".to_string())
1408 .with_max_retries(5)
1409 .with_backoff_params(2.0, 1.0);
1410
1411 assert_eq!(configured_client.base_url, "https://custom.api.com");
1412 assert_eq!(configured_client.max_retries, 5);
1413 assert_eq!(configured_client.throughput_ops_sec, 2.0);
1414 assert_eq!(configured_client.reserve_capacity, 1.0);
1415 }
1416
1417 #[test]
1418 fn build_url_default_base() {
1419 let client = Anthropic::new(Some("test_key".to_string())).unwrap();
1420 assert_eq!(client.build_url("messages"), "https://api.anthropic.com/v1/messages");
1422 assert_eq!(
1423 client.build_url("messages/count_tokens"),
1424 "https://api.anthropic.com/v1/messages/count_tokens"
1425 );
1426 assert_eq!(client.build_url("models"), "https://api.anthropic.com/v1/models");
1427 }
1428
1429 #[test]
1430 fn build_url_custom_base_without_trailing_slash() {
1431 let client = Anthropic::new(Some("test_key".to_string()))
1432 .unwrap()
1433 .with_base_url("https://api.minimax.io/anthropic".to_string());
1434 assert_eq!(client.build_url("messages"), "https://api.minimax.io/anthropic/v1/messages");
1435 }
1436
1437 #[test]
1438 fn build_url_custom_base_with_trailing_slash() {
1439 let client = Anthropic::new(Some("test_key".to_string()))
1440 .unwrap()
1441 .with_base_url("https://api.minimax.io/anthropic/".to_string());
1442 assert_eq!(client.build_url("messages"), "https://api.minimax.io/anthropic/v1/messages");
1443 }
1444
1445 #[test]
1446 fn build_url_minimax_china() {
1447 let client = Anthropic::new(Some("test_key".to_string()))
1448 .unwrap()
1449 .with_base_url("https://api.minimaxi.com/anthropic".to_string());
1450 assert_eq!(client.build_url("messages"), "https://api.minimaxi.com/anthropic/v1/messages");
1451 assert_eq!(
1452 client.build_url(&format!("models/{}", "claude-3-opus")),
1453 "https://api.minimaxi.com/anthropic/v1/models/claude-3-opus"
1454 );
1455 }
1456
1457 #[test]
1458 fn client_timeout_configuration() {
1459 let client = Anthropic::new(Some("test_key".to_string())).unwrap();
1460 let timeout = Duration::from_secs(30);
1461
1462 let configured_client = client.with_timeout(timeout).unwrap();
1463 assert_eq!(configured_client.timeout, timeout);
1464 }
1465
1466 #[test]
1467 fn client_cached_headers_performance() {
1468 let client = Anthropic::new(Some("test_key".to_string())).unwrap();
1469
1470 let headers1 = client.default_headers();
1472 let headers2 = client.default_headers();
1473
1474 assert_eq!(headers1.len(), headers2.len());
1475 assert!(headers1.contains_key("x-api-key"));
1476 assert!(headers1.contains_key("anthropic-version"));
1477 assert!(headers1.contains_key("content-type"));
1478 }
1479
1480 #[test]
1481 fn request_error_mapping() {
1482 let client = Anthropic::new(Some("test_key".to_string())).unwrap();
1483
1484 let _timeout = Duration::from_secs(30);
1487 assert_eq!(client.timeout, DEFAULT_TIMEOUT); }
1489
1490 #[tokio::test]
1491 async fn concurrent_retry_safety() {
1492 use std::sync::atomic::{AtomicUsize, Ordering};
1493 use tokio::spawn;
1494
1495 let client = Anthropic {
1496 api_key: "test".to_string(),
1497 client: ReqwestClient::new(),
1498 base_url: "http://localhost".to_string(),
1499 timeout: Duration::from_secs(1),
1500 max_retries: 1,
1501 throughput_ops_sec: 1.0,
1502 reserve_capacity: 1.0,
1503 cached_headers: Arc::new(HeaderMap::new()),
1504 };
1505
1506 let attempt_counter = Arc::new(AtomicUsize::new(0));
1507 let mut handles = vec![];
1508
1509 for _ in 0..3 {
1511 let client_clone = client.clone();
1512 let counter_clone = attempt_counter.clone();
1513
1514 let handle = spawn(async move {
1515 client_clone
1516 .retry_with_backoff(|| {
1517 let counter = counter_clone.clone();
1518 async move {
1519 counter.fetch_add(1, Ordering::SeqCst);
1520 Ok::<String, Error>("success".to_string())
1521 }
1522 })
1523 .await
1524 });
1525 handles.push(handle);
1526 }
1527
1528 for handle in handles {
1530 let result = handle.await.unwrap();
1531 assert!(result.is_ok());
1532 }
1533
1534 assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
1536 }
1537}