claudius/
client.rs

1use futures::Stream;
2use reqwest::header::{HeaderMap, HeaderValue};
3use reqwest::{Client as ReqwestClient, Response, header};
4use serde::Deserialize;
5use std::env;
6use std::fs;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::time::sleep;
10
11use crate::backoff::ExponentialBackoff;
12use crate::error::{Error, Result};
13use crate::sse::process_sse;
14use crate::types::{
15    Message, MessageCountTokensParams, MessageCreateParams, MessageStreamEvent, MessageTokensCount,
16    ModelInfo, ModelListParams, ModelListResponse,
17};
18
/// Default base URL of the Anthropic API. Endpoint paths such as `messages` are appended
/// directly to it, which is why it ends with a slash.
const DEFAULT_API_URL: &str = "https://api.anthropic.com/v1/";
/// Value sent in the `anthropic-version` header of every request.
const ANTHROPIC_API_VERSION: &str = "2023-06-01";
/// Timeout configured by `Anthropic::new` (one minute).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(60_000);
22
23/// Client for the Anthropic API with performance optimizations.
24#[derive(Debug, Clone)]
25pub struct Anthropic {
26    api_key: String,
27    client: ReqwestClient,
28    base_url: String,
29    timeout: Duration,
30    max_retries: usize,
31    throughput_ops_sec: f64,
32    reserve_capacity: f64,
33    /// Cached headers for performance - Arc for cheap cloning
34    cached_headers: Arc<HeaderMap>,
35}
36
37impl Anthropic {
38    /// Resolve an API key value, handling file:// URLs
39    fn resolve_api_key(key_value: &str) -> Result<String> {
40        if let Some(stripped) = key_value.strip_prefix("file://") {
41            // Handle file:// URLs
42            let path = if stripped.starts_with('/') {
43                // Absolute path: file:///root/.env -> /root/.env
44                stripped.to_string()
45            } else {
46                // Relative path: file://../foo -> ../foo
47                stripped.to_string()
48            };
49
50            fs::read_to_string(&path)
51                .map(|content| content.trim().to_string())
52                .map_err(|e| {
53                    Error::validation(
54                        format!("Failed to read API key from file '{}': {}", path, e),
55                        Some("api_key".to_string()),
56                    )
57                })
58        } else {
59            // Regular API key value
60            Ok(key_value.to_string())
61        }
62    }
63
64    /// Create a new Anthropic client.
65    ///
66    /// The API key can be provided directly or read from the CLAUDIUS_API_KEY or ANTHROPIC_API_KEY
67    /// environment variables. If an environment variable value starts with "file://", it will be
68    /// treated as a file path and the API key will be read from that file.
69    pub fn new(api_key: Option<String>) -> Result<Self> {
70        let api_key = match api_key {
71            Some(key) => Self::resolve_api_key(&key)?,
72            None => match env::var("CLAUDIUS_API_KEY").ok() {
73                Some(key) => Self::resolve_api_key(&key)?,
74                None => {
75                    let env_key = env::var("ANTHROPIC_API_KEY").map_err(|_| {
76                        Error::authentication(
77                            "API key not provided and ANTHROPIC_API_KEY environment variable not set",
78                        )
79                    })?;
80                    Self::resolve_api_key(&env_key)?
81                }
82            },
83        };
84
85        let timeout = DEFAULT_TIMEOUT;
86        let client = ReqwestClient::builder()
87            .timeout(timeout)
88            .pool_max_idle_per_host(10) // Connection pooling optimization
89            .pool_idle_timeout(Duration::from_secs(90))
90            .tcp_keepalive(Duration::from_secs(60))
91            .build()
92            .map_err(|e| {
93                Error::http_client(
94                    format!("Failed to build HTTP client: {e}"),
95                    Some(Box::new(e)),
96                )
97            })?;
98
99        // Pre-build headers for performance
100        let cached_headers = Arc::new(Self::build_default_headers(&api_key)?);
101
102        Ok(Self {
103            api_key,
104            client,
105            base_url: DEFAULT_API_URL.to_string(),
106            timeout,
107            max_retries: 3,
108            throughput_ops_sec: 1.0 / 60.0,
109            reserve_capacity: 1.0 / 60.0,
110            cached_headers,
111        })
112    }
113
114    /// Set a custom base URL for this client.
115    ///
116    /// This method allows you to specify a different API endpoint for the client.
117    pub fn with_base_url(mut self, base_url: String) -> Self {
118        self.base_url = base_url;
119        self
120    }
121
122    /// Set a custom timeout for this client.
123    ///
124    /// This method allows you to specify a different timeout for API requests.
125    pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
126        self.timeout = timeout;
127
128        // Recreate the client with the new timeout and performance optimizations
129        let client = ReqwestClient::builder()
130            .timeout(timeout)
131            .pool_max_idle_per_host(10)
132            .pool_idle_timeout(Duration::from_secs(90))
133            .tcp_keepalive(Duration::from_secs(60))
134            .build()
135            .map_err(|e| {
136                Error::http_client(
137                    "Failed to build HTTP client with new timeout",
138                    Some(Box::new(e)),
139                )
140            })?;
141
142        self.client = client;
143        Ok(self)
144    }
145
146    /// Set the maximum number of retries for this client.
147    ///
148    /// This method allows you to specify how many times to retry failed requests.
149    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
150        self.max_retries = max_retries;
151        self
152    }
153
154    /// Get the API key being used by this client.
155    pub fn api_key(&self) -> &str {
156        &self.api_key
157    }
158
159    /// Set the backoff parameters for this client.
160    ///
161    /// This method allows you to configure the exponential backoff algorithm.
162    pub fn with_backoff_params(mut self, throughput_ops_sec: f64, reserve_capacity: f64) -> Self {
163        self.throughput_ops_sec = throughput_ops_sec;
164        self.reserve_capacity = reserve_capacity;
165        self
166    }
167
168    /// Set both a custom base URL and timeout for this client.
169    ///
170    /// This is a convenience method that chains with_base_url and with_timeout.
171    pub fn with_base_url_and_timeout(self, base_url: String, timeout: Duration) -> Result<Self> {
172        self.with_base_url(base_url).with_timeout(timeout)
173    }
174
175    /// Build default headers for API requests (static method for initialization).
176    fn build_default_headers(api_key: &str) -> Result<HeaderMap> {
177        let mut headers = HeaderMap::new();
178        headers.insert(
179            header::CONTENT_TYPE,
180            HeaderValue::from_static("application/json"),
181        );
182        headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
183        headers.insert(
184            "x-api-key",
185            HeaderValue::from_str(api_key).map_err(|e| {
186                Error::validation(
187                    format!("Invalid API key format: {e}"),
188                    Some("api_key".to_string()),
189                )
190            })?,
191        );
192        headers.insert(
193            "anthropic-version",
194            HeaderValue::from_static(ANTHROPIC_API_VERSION),
195        );
196        Ok(headers)
197    }
198
199    /// Get cached headers for performance (no allocation needed).
200    fn default_headers(&self) -> HeaderMap {
201        (*self.cached_headers).clone()
202    }
203
204    /// Retry wrapper that implements exponential backoff with header-based retry-after
205    async fn retry_with_backoff<F, Fut, T>(&self, operation: F) -> Result<T>
206    where
207        F: Fn() -> Fut,
208        Fut: std::future::Future<Output = Result<T>>,
209    {
210        let backoff = ExponentialBackoff::new(self.throughput_ops_sec, self.reserve_capacity);
211        let mut last_error = None;
212
213        for attempt in 0..=self.max_retries {
214            match operation().await {
215                Ok(result) => return Ok(result),
216                Err(error) => {
217                    // Check if error is retryable
218                    if !error.is_retryable() {
219                        return Err(error);
220                    }
221
222                    // Don't sleep on the last attempt
223                    if attempt == self.max_retries {
224                        last_error = Some(error);
225                        break;
226                    }
227
228                    // Calculate backoff duration
229                    let exp_backoff_duration = backoff.next();
230
231                    // Get retry-after from error if available
232                    let header_backoff_duration = match &error {
233                        Error::RateLimit {
234                            retry_after: Some(seconds),
235                            ..
236                        } => Some(Duration::from_secs(*seconds)),
237                        Error::ServiceUnavailable {
238                            retry_after: Some(seconds),
239                            ..
240                        } => Some(Duration::from_secs(*seconds)),
241                        _ => None,
242                    };
243
244                    // Take the maximum of exponential backoff and header-based backoff
245                    let sleep_duration = match header_backoff_duration {
246                        Some(header_duration) => exp_backoff_duration.max(header_duration),
247                        None => exp_backoff_duration,
248                    };
249
250                    sleep(sleep_duration).await;
251                    last_error = Some(error);
252                }
253            }
254        }
255
256        Err(last_error
257            .unwrap_or_else(|| Error::unknown("Failed after retries without capturing error")))
258    }
259
260    /// Process API response errors and convert to our Error type
261    async fn process_error_response(response: Response) -> Error {
262        let status = response.status();
263        let status_code = status.as_u16();
264
265        // Get headers we might need for error processing
266        let request_id = response
267            .headers()
268            .get("x-request-id")
269            .and_then(|val| val.to_str().ok())
270            .map(String::from);
271
272        let retry_after = response
273            .headers()
274            .get("retry-after")
275            .and_then(|val| val.to_str().ok())
276            .and_then(|val| val.parse::<u64>().ok());
277
278        // Try to parse error response body
279        #[derive(Deserialize)]
280        struct ErrorResponse {
281            error: Option<ErrorDetail>,
282        }
283
284        #[derive(Deserialize)]
285        struct ErrorDetail {
286            #[serde(rename = "type")]
287            error_type: Option<String>,
288            message: Option<String>,
289            param: Option<String>,
290        }
291
292        let error_body = match response.text().await {
293            Ok(body) => body,
294            Err(e) => {
295                return Error::http_client(
296                    format!("Failed to read error response: {e}"),
297                    Some(Box::new(e)),
298                );
299            }
300        };
301
302        // Try to parse as JSON first
303        let parsed_error = serde_json::from_str::<ErrorResponse>(&error_body).ok();
304        let error_type = parsed_error
305            .as_ref()
306            .and_then(|e| e.error.as_ref())
307            .and_then(|e| e.error_type.clone());
308        let error_message = parsed_error
309            .as_ref()
310            .and_then(|e| e.error.as_ref())
311            .and_then(|e| e.message.clone())
312            .unwrap_or_else(|| error_body.clone());
313        let error_param = parsed_error
314            .as_ref()
315            .and_then(|e| e.error.as_ref())
316            .and_then(|e| e.param.clone());
317
318        // Map HTTP status code to appropriate error type
319        match status_code {
320            400 => Error::bad_request(error_message, error_param),
321            401 => Error::authentication(error_message),
322            403 => Error::permission(error_message),
323            404 => Error::not_found(error_message, None, None),
324            408 => Error::timeout(error_message, None),
325            429 => Error::rate_limit(error_message, retry_after),
326            500 => Error::internal_server(error_message, request_id),
327            502..=504 => Error::service_unavailable(error_message, retry_after),
328            529 => Error::rate_limit(error_message, retry_after),
329            _ => Error::api(status_code, error_type, error_message, request_id),
330        }
331    }
332
333    /// Convert reqwest errors to appropriate Error types
334    fn map_request_error(&self, e: reqwest::Error) -> Error {
335        if e.is_timeout() {
336            Error::timeout(
337                format!("Request timed out: {e}"),
338                Some(self.timeout.as_secs_f64()),
339            )
340        } else if e.is_connect() {
341            Error::connection(format!("Connection error: {e}"), Some(Box::new(e)))
342        } else {
343            Error::http_client(format!("Request failed: {e}"), Some(Box::new(e)))
344        }
345    }
346
347    /// Execute a POST request with error handling
348    async fn execute_post_request<T: serde::de::DeserializeOwned>(
349        &self,
350        url: &str,
351        body: &impl serde::Serialize,
352        headers: Option<HeaderMap>,
353    ) -> Result<T> {
354        let headers = headers.unwrap_or_else(|| self.default_headers());
355
356        let response = self
357            .client
358            .post(url)
359            .headers(headers)
360            .json(body)
361            .send()
362            .await
363            .map_err(|e| self.map_request_error(e))?;
364
365        if !response.status().is_success() {
366            return Err(Self::process_error_response(response).await);
367        }
368
369        response.json::<T>().await.map_err(|e| {
370            Error::serialization(format!("Failed to parse response: {e}"), Some(Box::new(e)))
371        })
372    }
373
374    /// Execute a GET request with error handling
375    async fn execute_get_request<T: serde::de::DeserializeOwned>(
376        &self,
377        url: &str,
378        query_params: Option<&[(String, String)]>,
379    ) -> Result<T> {
380        let mut request = self.client.get(url).headers(self.default_headers());
381
382        if let Some(params) = query_params {
383            for (key, value) in params {
384                request = request.query(&[(key, value)]);
385            }
386        }
387
388        let response = request
389            .send()
390            .await
391            .map_err(|e| self.map_request_error(e))?;
392
393        if !response.status().is_success() {
394            return Err(Self::process_error_response(response).await);
395        }
396
397        response.json::<T>().await.map_err(|e| {
398            Error::serialization(format!("Failed to parse response: {e}"), Some(Box::new(e)))
399        })
400    }
401
402    /// Send a message to the API and get a non-streaming response.
403    pub async fn send(&self, mut params: MessageCreateParams) -> Result<Message> {
404        // Validate parameters first
405        params.validate()?;
406
407        // Ensure stream is disabled
408        params.stream = false;
409
410        self.retry_with_backoff(|| async {
411            let url = format!("{}messages", self.base_url);
412            self.execute_post_request(&url, &params, None).await
413        })
414        .await
415    }
416
417    /// Send a message to the API and get a streaming response.
418    ///
419    /// Returns a stream of MessageStreamEvent objects that can be processed incrementally.
420    pub async fn stream(
421        &self,
422        mut params: MessageCreateParams,
423    ) -> Result<impl Stream<Item = Result<MessageStreamEvent>>> {
424        // Validate parameters first
425        params.validate()?;
426
427        // Ensure stream is enabled
428        params.stream = true;
429
430        let response = self
431            .retry_with_backoff(|| async {
432                let url = format!("{}messages", self.base_url);
433
434                let mut headers = self.default_headers();
435                headers.insert(
436                    header::ACCEPT,
437                    HeaderValue::from_static("text/event-stream"),
438                );
439
440                let response = self
441                    .client
442                    .post(&url)
443                    .headers(headers)
444                    .json(&params)
445                    .send()
446                    .await
447                    .map_err(|e| self.map_request_error(e))?;
448
449                if !response.status().is_success() {
450                    return Err(Self::process_error_response(response).await);
451                }
452
453                Ok(response)
454            })
455            .await?;
456
457        // Get the byte stream from the response
458        let stream = response.bytes_stream();
459
460        // Create an SSE processor
461        Ok(process_sse(stream))
462    }
463
464    /// Count tokens for a message.
465    ///
466    /// This method counts the number of tokens that would be used by a message with the given parameters.
467    /// It's useful for estimating costs or making sure your messages fit within the model's context window.
468    pub async fn count_tokens(
469        &self,
470        params: MessageCountTokensParams,
471    ) -> Result<MessageTokensCount> {
472        self.retry_with_backoff(|| async {
473            let url = format!("{}messages/count_tokens", self.base_url);
474            self.execute_post_request(&url, &params, None).await
475        })
476        .await
477    }
478
479    /// List available models from the API.
480    ///
481    /// Returns a paginated list of all available models. Use the parameters to control
482    /// pagination and filter results.
483    pub async fn list_models(&self, params: Option<ModelListParams>) -> Result<ModelListResponse> {
484        self.retry_with_backoff(|| async {
485            let url = format!("{}models", self.base_url);
486
487            let query_params = params.as_ref().map(|p| {
488                let mut params = Vec::new();
489                if let Some(ref after_id) = p.after_id {
490                    params.push(("after_id".to_string(), after_id.clone()));
491                }
492                if let Some(ref before_id) = p.before_id {
493                    params.push(("before_id".to_string(), before_id.clone()));
494                }
495                if let Some(limit) = p.limit {
496                    params.push(("limit".to_string(), limit.to_string()));
497                }
498                params
499            });
500
501            self.execute_get_request(&url, query_params.as_deref())
502                .await
503        })
504        .await
505    }
506
507    /// Retrieve information about a specific model.
508    ///
509    /// Returns detailed information about the specified model, including its
510    /// ID, creation date, display name, and type.
511    pub async fn get_model(&self, model_id: &str) -> Result<ModelInfo> {
512        self.retry_with_backoff(|| async {
513            let url = format!("{}models/{}", self.base_url, model_id);
514            self.execute_get_request(&url, None).await
515        })
516        .await
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Backoff parameters used by most retry tests (the same defaults `Anthropic::new` uses).
    const DEFAULT_RATE: f64 = 1.0 / 60.0;

    /// Build a client directly from its fields, bypassing `Anthropic::new`, so retry tests
    /// do not depend on environment variables, key files, or real headers.
    fn test_client(max_retries: usize, throughput_ops_sec: f64, reserve_capacity: f64) -> Anthropic {
        Anthropic {
            api_key: "test".to_string(),
            client: ReqwestClient::new(),
            base_url: "http://localhost".to_string(),
            timeout: Duration::from_secs(1),
            max_retries,
            throughput_ops_sec,
            reserve_capacity,
            cached_headers: Arc::new(HeaderMap::new()),
        }
    }

    #[tokio::test]
    async fn retry_logic_with_backoff() {
        let client = test_client(2, DEFAULT_RATE, DEFAULT_RATE);

        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = attempt_counter.clone();

        let result = client
            .retry_with_backoff(|| {
                let counter = counter_clone.clone();
                async move {
                    let attempt = counter.fetch_add(1, Ordering::SeqCst);
                    match attempt {
                        0 | 1 => Err(Error::rate_limit("Rate limited", Some(1))),
                        _ => Ok("success".to_string()),
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_logic_with_non_retryable_error() {
        let client = test_client(2, DEFAULT_RATE, DEFAULT_RATE);

        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = attempt_counter.clone();

        let result: Result<String> = client
            .retry_with_backoff(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(Error::authentication("Invalid API key"))
                }
            })
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().is_authentication());
        // Authentication errors are not retryable, so there is exactly one attempt.
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retry_logic_max_retries_exceeded() {
        let client = test_client(2, DEFAULT_RATE, DEFAULT_RATE);

        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = attempt_counter.clone();

        let result: Result<String> = client
            .retry_with_backoff(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(Error::rate_limit("Always rate limited", Some(1)))
                }
            })
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().is_rate_limit());
        // max_retries + 1 attempts in total: the initial one plus 2 retries.
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn error_529_is_retryable() {
        // A 529 "overloaded" API error must be retried like a rate limit.
        let client = test_client(2, DEFAULT_RATE, DEFAULT_RATE);

        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = attempt_counter.clone();

        let result = client
            .retry_with_backoff(|| {
                let counter = counter_clone.clone();
                async move {
                    let attempt = counter.fetch_add(1, Ordering::SeqCst);
                    match attempt {
                        0 | 1 => Err(Error::api(
                            529,
                            Some("overloaded_error".to_string()),
                            "Overloaded".to_string(),
                            None,
                        )),
                        _ => Ok("success".to_string()),
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        // The initial attempt plus 2 retries.
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn error_529_mapped_correctly() {
        // A raw 529 API error is retryable...
        let error = Error::api(
            529,
            Some("overloaded_error".to_string()),
            "Overloaded".to_string(),
            None,
        );
        assert!(error.is_retryable());

        // ...and so is the rate_limit error that process_error_response maps 529 to.
        let rate_limit_error = Error::rate_limit("Overloaded", Some(5));
        assert!(rate_limit_error.is_retryable());
    }

    #[test]
    fn resolve_api_key_regular_value() {
        let result = Anthropic::resolve_api_key("sk-test-key-123");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-test-key-123");
    }

    #[test]
    fn resolve_api_key_file_url_absolute() {
        let test_dir = std::env::temp_dir().join(format!("claudius_test_{}", std::process::id()));
        std::fs::create_dir_all(&test_dir).unwrap();
        let test_file = test_dir.join("test_api_key.txt");
        std::fs::write(&test_file, "sk-test-from-file-123\n").unwrap();

        let file_url = format!("file://{}", test_file.display());
        let result = Anthropic::resolve_api_key(&file_url);

        std::fs::remove_dir_all(&test_dir).unwrap();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-test-from-file-123");
    }

    #[test]
    fn resolve_api_key_file_url_relative() {
        let test_file = "test_relative_key.txt";
        std::fs::write(test_file, "sk-relative-key-456\n").unwrap();

        let file_url = format!("file://{}", test_file);
        let result = Anthropic::resolve_api_key(&file_url);

        std::fs::remove_file(test_file).unwrap();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-relative-key-456");
    }

    #[test]
    fn resolve_api_key_file_url_nonexistent() {
        let result = Anthropic::resolve_api_key("file:///nonexistent/path/to/key.txt");
        assert!(result.is_err());

        let error = result.unwrap_err();
        assert!(error.is_validation());
        assert!(format!("{}", error).contains("Failed to read API key from file"));
    }

    #[test]
    fn resolve_api_key_file_url_with_whitespace() {
        let test_file = "test_whitespace_key.txt";
        std::fs::write(test_file, "  sk-whitespace-key-789  \n  ").unwrap();

        let file_url = format!("file://{}", test_file);
        let result = Anthropic::resolve_api_key(&file_url);

        std::fs::remove_file(test_file).unwrap();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-whitespace-key-789");
    }

    #[test]
    fn client_builder_methods() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();

        let configured_client = client
            .with_base_url("https://custom.api.com/v1/".to_string())
            .with_max_retries(5)
            .with_backoff_params(2.0, 1.0);

        assert_eq!(configured_client.base_url, "https://custom.api.com/v1/");
        assert_eq!(configured_client.max_retries, 5);
        assert_eq!(configured_client.throughput_ops_sec, 2.0);
        assert_eq!(configured_client.reserve_capacity, 1.0);
    }

    #[test]
    fn client_timeout_configuration() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();
        let timeout = Duration::from_secs(30);

        let configured_client = client.with_timeout(timeout).unwrap();
        assert_eq!(configured_client.timeout, timeout);
    }

    #[test]
    fn client_cached_headers_performance() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();

        // Every call returns an equivalent copy of the cached defaults.
        let headers1 = client.default_headers();
        let headers2 = client.default_headers();

        assert_eq!(headers1.len(), headers2.len());
        assert!(headers1.contains_key("x-api-key"));
        assert!(headers1.contains_key("anthropic-version"));
        assert!(headers1.contains_key("content-type"));
    }

    #[test]
    fn request_error_mapping() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();

        // map_request_error reports the client's configured timeout for timed-out requests.
        // A fresh client must start with DEFAULT_TIMEOUT.
        assert_eq!(client.timeout, DEFAULT_TIMEOUT);
    }

    #[tokio::test]
    async fn concurrent_retry_safety() {
        let client = test_client(1, 1.0, 1.0);

        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();

        // Run several retry loops concurrently on clones of the same client.
        for _ in 0..3 {
            let client_clone = client.clone();
            let counter_clone = attempt_counter.clone();

            handles.push(tokio::spawn(async move {
                client_clone
                    .retry_with_backoff(|| {
                        let counter = counter_clone.clone();
                        async move {
                            counter.fetch_add(1, Ordering::SeqCst);
                            Ok::<String, Error>("success".to_string())
                        }
                    })
                    .await
            }));
        }

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok());
        }

        // Each operation succeeded on its first attempt.
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }
}