use futures::Stream;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client as ReqwestClient, Response, header};
use serde::Deserialize;
use std::env;
use std::fs;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use crate::backoff::ExponentialBackoff;
use crate::error::{Error, Result};
use crate::sse::process_sse;
use crate::types::{
Message, MessageCountTokensParams, MessageCreateParams, MessageStreamEvent, MessageTokensCount,
ModelInfo, ModelListParams, ModelListResponse,
};
/// Base URL used by `Anthropic::new`; endpoint paths such as `messages` are
/// appended directly to it, which is why it ends with a slash.
const DEFAULT_API_URL: &str = "https://api.anthropic.com/v1/";
/// Value sent in the `anthropic-version` header of every request.
const ANTHROPIC_API_VERSION: &str = "2023-06-01";
/// Request timeout configured on clients created by `Anthropic::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
/// Client for the Anthropic HTTP API.
///
/// Cloning is supported; the default header map is shared between clones
/// through an `Arc`.
///
/// `Debug` is implemented by hand so that the API key is never written to
/// logs or panic messages.
#[derive(Clone)]
pub struct Anthropic {
    /// API key after `file://` resolution; sent as the `x-api-key` header.
    api_key: String,
    /// Underlying HTTP client; rebuilt when the timeout is changed.
    client: ReqwestClient,
    /// Prefix to which endpoint paths such as `messages` are appended.
    base_url: String,
    /// Request timeout configured on `client`; also reported in timeout errors.
    timeout: Duration,
    /// Number of retries allowed after the first attempt.
    max_retries: usize,
    /// First argument passed to `ExponentialBackoff::new` for every retried call.
    throughput_ops_sec: f64,
    /// Second argument passed to `ExponentialBackoff::new` for every retried call.
    reserve_capacity: f64,
    /// Default request headers, built once at construction and cloned per request.
    cached_headers: Arc<HeaderMap>,
}

impl std::fmt::Debug for Anthropic {
    /// Formats the client with the API key redacted.
    ///
    /// `cached_headers` is left out entirely because it carries the same key
    /// in its `x-api-key` entry.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Anthropic")
            .field("api_key", &"<redacted>")
            .field("client", &self.client)
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .field("throughput_ops_sec", &self.throughput_ops_sec)
            .field("reserve_capacity", &self.reserve_capacity)
            .finish_non_exhaustive()
    }
}
impl Anthropic {
/// Resolves an API key value, reading it from disk when it is given as
/// `file://<path>`.
///
/// A value without the `file://` prefix is returned unchanged. With the
/// prefix, everything after it is used verbatim as a filesystem path
/// (absolute or relative), and the file's contents with surrounding
/// whitespace trimmed become the key.
///
/// # Errors
///
/// Returns a validation error tagged with the `api_key` parameter when the
/// file cannot be read.
fn resolve_api_key(key_value: &str) -> Result<String> {
    match key_value.strip_prefix("file://") {
        // Absolute and relative paths are handled the same way, so the
        // stripped remainder is passed straight to the filesystem.
        Some(path) => fs::read_to_string(path)
            .map(|content| content.trim().to_string())
            .map_err(|e| {
                Error::validation(
                    format!("Failed to read API key from file '{}': {}", path, e),
                    Some("api_key".to_string()),
                )
            }),
        None => Ok(key_value.to_string()),
    }
}
pub fn new(api_key: Option<String>) -> Result<Self> {
let api_key = match api_key {
Some(key) => Self::resolve_api_key(&key)?,
None => match env::var("CLAUDIUS_API_KEY").ok() {
Some(key) => Self::resolve_api_key(&key)?,
None => {
let env_key = env::var("ANTHROPIC_API_KEY").map_err(|_| {
Error::authentication(
"API key not provided and ANTHROPIC_API_KEY environment variable not set",
)
})?;
Self::resolve_api_key(&env_key)?
}
},
};
let timeout = DEFAULT_TIMEOUT;
let client = ReqwestClient::builder()
.timeout(timeout)
.pool_max_idle_per_host(10) .pool_idle_timeout(Duration::from_secs(90))
.tcp_keepalive(Duration::from_secs(60))
.build()
.map_err(|e| {
Error::http_client(
format!("Failed to build HTTP client: {e}"),
Some(Box::new(e)),
)
})?;
let cached_headers = Arc::new(Self::build_default_headers(&api_key)?);
Ok(Self {
api_key,
client,
base_url: DEFAULT_API_URL.to_string(),
timeout,
max_retries: 3,
throughput_ops_sec: 1.0 / 60.0,
reserve_capacity: 1.0 / 60.0,
cached_headers,
})
}
/// Replaces the base URL that endpoint paths are appended to.
///
/// Endpoint URLs are formed by appending a path such as `messages` directly
/// to the base URL, so a missing trailing slash would fuse the last path
/// segment with the endpoint name (`.../v1messages`). A slash is therefore
/// added when a non-empty `base_url` does not already end with one.
pub fn with_base_url(mut self, base_url: String) -> Self {
    self.base_url = if base_url.is_empty() || base_url.ends_with('/') {
        base_url
    } else {
        format!("{base_url}/")
    };
    self
}
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.timeout = timeout;
let client = ReqwestClient::builder()
.timeout(timeout)
.pool_max_idle_per_host(10)
.pool_idle_timeout(Duration::from_secs(90))
.tcp_keepalive(Duration::from_secs(60))
.build()
.map_err(|e| {
Error::http_client(
"Failed to build HTTP client with new timeout",
Some(Box::new(e)),
)
})?;
self.client = client;
Ok(self)
}
pub fn with_max_retries(mut self, max_retries: usize) -> Self {
self.max_retries = max_retries;
self
}
/// Returns the resolved API key this client authenticates with.
pub fn api_key(&self) -> &str {
    self.api_key.as_str()
}
pub fn with_backoff_params(mut self, throughput_ops_sec: f64, reserve_capacity: f64) -> Self {
self.throughput_ops_sec = throughput_ops_sec;
self.reserve_capacity = reserve_capacity;
self
}
/// Applies `with_base_url` followed by `with_timeout`.
///
/// # Errors
///
/// Returns whatever `with_timeout` returns when the HTTP client cannot be
/// rebuilt.
pub fn with_base_url_and_timeout(self, base_url: String, timeout: Duration) -> Result<Self> {
    let rebased = self.with_base_url(base_url);
    rebased.with_timeout(timeout)
}
/// Builds the headers sent with every request: JSON `Content-Type` and
/// `Accept`, the `x-api-key` credential and the `anthropic-version` marker.
///
/// The API key value is flagged as sensitive so that it is not printed when
/// the header map is formatted for debugging.
///
/// # Errors
///
/// Returns a validation error tagged with the `api_key` parameter when the
/// key contains bytes that are not allowed in a header value.
fn build_default_headers(api_key: &str) -> Result<HeaderMap> {
    let mut api_key_value = HeaderValue::from_str(api_key).map_err(|e| {
        Error::validation(
            format!("Invalid API key format: {e}"),
            Some("api_key".to_string()),
        )
    })?;
    // Keeps the credential out of `Debug` output of the header map.
    api_key_value.set_sensitive(true);

    let mut headers = HeaderMap::new();
    headers.insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("application/json"),
    );
    headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
    headers.insert("x-api-key", api_key_value);
    headers.insert(
        "anthropic-version",
        HeaderValue::from_static(ANTHROPIC_API_VERSION),
    );
    Ok(headers)
}
/// Returns an owned copy of the header map built at construction time.
fn default_headers(&self) -> HeaderMap {
    let shared: &HeaderMap = &self.cached_headers;
    shared.clone()
}
async fn retry_with_backoff<F, Fut, T>(&self, operation: F) -> Result<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let backoff = ExponentialBackoff::new(self.throughput_ops_sec, self.reserve_capacity);
let mut last_error = None;
for attempt in 0..=self.max_retries {
match operation().await {
Ok(result) => return Ok(result),
Err(error) => {
if !error.is_retryable() {
return Err(error);
}
if attempt == self.max_retries {
last_error = Some(error);
break;
}
let exp_backoff_duration = backoff.next();
let header_backoff_duration = match &error {
Error::RateLimit {
retry_after: Some(seconds),
..
} => Some(Duration::from_secs(*seconds)),
Error::ServiceUnavailable {
retry_after: Some(seconds),
..
} => Some(Duration::from_secs(*seconds)),
_ => None,
};
let sleep_duration = match header_backoff_duration {
Some(header_duration) => exp_backoff_duration.max(header_duration),
None => exp_backoff_duration,
};
sleep(sleep_duration).await;
last_error = Some(error);
}
}
}
Err(last_error
.unwrap_or_else(|| Error::unknown("Failed after retries without capturing error")))
}
/// Converts a non-success HTTP response into the matching `Error`.
///
/// The `x-request-id` header and a whole-second `retry-after` header are
/// captured before the body is consumed. The body is parsed as
/// `{"error": {"type", "message", "param"}}`; when that fails or there is no
/// message, the raw body text is used as the error message. If the body
/// itself cannot be read, an HTTP-client error is returned instead.
async fn process_error_response(response: Response) -> Error {
    #[derive(Deserialize)]
    struct ErrorResponse {
        error: Option<ErrorDetail>,
    }
    #[derive(Deserialize)]
    struct ErrorDetail {
        #[serde(rename = "type")]
        error_type: Option<String>,
        message: Option<String>,
        param: Option<String>,
    }

    let status_code = response.status().as_u16();

    // Headers must be read before `text()` consumes the response.
    let response_headers = response.headers();
    let request_id = response_headers
        .get("x-request-id")
        .and_then(|val| val.to_str().ok())
        .map(String::from);
    let retry_after = response_headers
        .get("retry-after")
        .and_then(|val| val.to_str().ok())
        .and_then(|val| val.parse::<u64>().ok());

    let error_body = match response.text().await {
        Ok(body) => body,
        Err(e) => {
            return Error::http_client(
                format!("Failed to read error response: {e}"),
                Some(Box::new(e)),
            );
        }
    };

    let detail = serde_json::from_str::<ErrorResponse>(&error_body)
        .ok()
        .and_then(|parsed| parsed.error);
    let (error_type, message, error_param) = match detail {
        Some(ErrorDetail {
            error_type,
            message,
            param,
        }) => (error_type, message, param),
        None => (None, None, None),
    };
    // Fall back to the raw body when no structured message is present.
    let error_message = message.unwrap_or(error_body);

    match status_code {
        400 => Error::bad_request(error_message, error_param),
        401 => Error::authentication(error_message),
        403 => Error::permission(error_message),
        404 => Error::not_found(error_message, None, None),
        408 => Error::timeout(error_message, None),
        429 => Error::rate_limit(error_message, retry_after),
        500 => Error::internal_server(error_message, request_id),
        502..=504 => Error::service_unavailable(error_message, retry_after),
        // 529 is reported with the same error kind as 429.
        529 => Error::rate_limit(error_message, retry_after),
        _ => Error::api(status_code, error_type, error_message, request_id),
    }
}
fn map_request_error(&self, e: reqwest::Error) -> Error {
if e.is_timeout() {
Error::timeout(
format!("Request timed out: {e}"),
Some(self.timeout.as_secs_f64()),
)
} else if e.is_connect() {
Error::connection(format!("Connection error: {e}"), Some(Box::new(e)))
} else {
Error::http_client(format!("Request failed: {e}"), Some(Box::new(e)))
}
}
/// POSTs `body` as JSON to `url` and deserializes a successful response
/// into `T`.
///
/// `headers` replaces the default header set when provided.
///
/// # Errors
///
/// Transport failures go through `map_request_error`, non-success statuses
/// through `process_error_response`, and an undecodable success body becomes
/// a serialization error.
async fn execute_post_request<T: serde::de::DeserializeOwned>(
    &self,
    url: &str,
    body: &impl serde::Serialize,
    headers: Option<HeaderMap>,
) -> Result<T> {
    let request = self
        .client
        .post(url)
        .headers(headers.unwrap_or_else(|| self.default_headers()))
        .json(body);
    let response = match request.send().await {
        Ok(response) => response,
        Err(e) => return Err(self.map_request_error(e)),
    };
    if response.status().is_success() {
        response.json::<T>().await.map_err(|e| {
            Error::serialization(format!("Failed to parse response: {e}"), Some(Box::new(e)))
        })
    } else {
        Err(Self::process_error_response(response).await)
    }
}
/// GETs `url` with the default headers and deserializes a successful
/// response into `T`.
///
/// `query_params`, when given, are appended to the URL as query-string pairs
/// in slice order.
///
/// # Errors
///
/// Transport failures go through `map_request_error`, non-success statuses
/// through `process_error_response`, and an undecodable success body becomes
/// a serialization error.
async fn execute_get_request<T: serde::de::DeserializeOwned>(
    &self,
    url: &str,
    query_params: Option<&[(String, String)]>,
) -> Result<T> {
    let mut request = self.client.get(url).headers(self.default_headers());
    if let Some(params) = query_params {
        // The whole slice serializes as a sequence of key/value pairs, so a
        // single call replaces one builder round-trip per pair.
        request = request.query(params);
    }
    let response = request
        .send()
        .await
        .map_err(|e| self.map_request_error(e))?;
    if !response.status().is_success() {
        return Err(Self::process_error_response(response).await);
    }
    response.json::<T>().await.map_err(|e| {
        Error::serialization(format!("Failed to parse response: {e}"), Some(Box::new(e)))
    })
}
/// Sends a non-streaming message request and returns the completed message.
///
/// The parameters are validated first and `stream` is forced to `false`.
/// The POST to the `messages` endpoint goes through `retry_with_backoff`.
///
/// # Errors
///
/// Returns the validation error when `params` is invalid, otherwise the
/// error of the last attempt.
pub async fn send(&self, mut params: MessageCreateParams) -> Result<Message> {
    params.validate()?;
    params.stream = false;
    self.retry_with_backoff(|| async {
        let url = format!("{}messages", self.base_url);
        self.execute_post_request(&url, &params, None).await
    })
    .await
}
/// Sends a streaming message request and returns the stream of events.
///
/// The parameters are validated first and `stream` is forced to `true`.
/// The request uses the default headers with `Accept` replaced by
/// `text/event-stream`. Only establishing the response goes through
/// `retry_with_backoff`; the response's byte stream is then handed to
/// `process_sse`.
///
/// # Errors
///
/// Returns the validation error when `params` is invalid, otherwise the
/// error of the last attempt to obtain a successful response.
pub async fn stream(
    &self,
    mut params: MessageCreateParams,
) -> Result<impl Stream<Item = Result<MessageStreamEvent>>> {
    params.validate()?;
    params.stream = true;
    let response = self
        .retry_with_backoff(|| async {
            let url = format!("{}messages", self.base_url);
            let mut headers = self.default_headers();
            headers.insert(
                header::ACCEPT,
                HeaderValue::from_static("text/event-stream"),
            );
            let response = self
                .client
                .post(&url)
                .headers(headers)
                .json(&params)
                .send()
                .await
                .map_err(|e| self.map_request_error(e))?;
            if !response.status().is_success() {
                return Err(Self::process_error_response(response).await);
            }
            Ok(response)
        })
        .await?;
    let stream = response.bytes_stream();
    Ok(process_sse(stream))
}
/// Asks the API for the token count of a prospective message request.
///
/// POSTs `params` to the `messages/count_tokens` endpoint through
/// `retry_with_backoff`.
///
/// # Errors
///
/// Returns the error of the last attempt.
pub async fn count_tokens(
    &self,
    params: MessageCountTokensParams,
) -> Result<MessageTokensCount> {
    self.retry_with_backoff(|| async {
        let url = format!("{}messages/count_tokens", self.base_url);
        self.execute_post_request(&url, &params, None).await
    })
    .await
}
/// Lists available models, optionally filtered or paginated by `params`.
///
/// Whichever of `after_id`, `before_id` and `limit` are set are sent, in
/// that order, as query-string pairs to the `models` endpoint. The request
/// goes through `retry_with_backoff`.
///
/// # Errors
///
/// Returns the error of the last attempt.
pub async fn list_models(&self, params: Option<ModelListParams>) -> Result<ModelListResponse> {
    let endpoint = format!("{}models", self.base_url);
    let query_pairs: Option<Vec<(String, String)>> = params.as_ref().map(|p| {
        let after = p
            .after_id
            .as_ref()
            .map(|id| ("after_id".to_string(), id.clone()));
        let before = p
            .before_id
            .as_ref()
            .map(|id| ("before_id".to_string(), id.clone()));
        let limit = p.limit.map(|n| ("limit".to_string(), n.to_string()));
        after.into_iter().chain(before).chain(limit).collect()
    });
    self.retry_with_backoff(|| async {
        self.execute_get_request(&endpoint, query_pairs.as_deref())
            .await
    })
    .await
}
/// Fetches the description of a single model.
///
/// `model_id` is appended as-is to the `models/` endpoint path. The request
/// goes through `retry_with_backoff`.
///
/// # Errors
///
/// Returns the error of the last attempt.
pub async fn get_model(&self, model_id: &str) -> Result<ModelInfo> {
    let endpoint = format!("{}models/{}", self.base_url, model_id);
    self.retry_with_backoff(|| async { self.execute_get_request(&endpoint, None).await })
        .await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a client directly from its fields, bypassing `Anthropic::new`,
    /// so the retry tests control the retry count and backoff parameters and
    /// do not depend on environment variables.
    fn test_client(
        max_retries: usize,
        throughput_ops_sec: f64,
        reserve_capacity: f64,
    ) -> Anthropic {
        Anthropic {
            api_key: "test".to_string(),
            client: ReqwestClient::new(),
            base_url: "http://localhost".to_string(),
            timeout: Duration::from_secs(1),
            max_retries,
            throughput_ops_sec,
            reserve_capacity,
            cached_headers: Arc::new(HeaderMap::new()),
        }
    }

    /// Client allowing two retries with the same backoff parameters as
    /// `Anthropic::new`.
    fn retrying_client() -> Anthropic {
        test_client(2, 1.0 / 60.0, 1.0 / 60.0)
    }

    /// Two rate-limit failures followed by a success use exactly three attempts.
    #[tokio::test]
    async fn retry_logic_with_backoff() {
        let client = retrying_client();
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let result = client
            .retry_with_backoff(|| {
                let counter = Arc::clone(&attempt_counter);
                async move {
                    let attempt = counter.fetch_add(1, Ordering::SeqCst);
                    match attempt {
                        0 | 1 => Err(Error::rate_limit("Rate limited", Some(1))),
                        _ => Ok("success".to_string()),
                    }
                }
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    /// An authentication error is returned after a single attempt.
    #[tokio::test]
    async fn retry_logic_with_non_retryable_error() {
        let client = retrying_client();
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let result: Result<String> = client
            .retry_with_backoff(|| {
                let counter = Arc::clone(&attempt_counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(Error::authentication("Invalid API key"))
                }
            })
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_authentication());
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 1);
    }

    /// A permanently rate-limited operation stops after `max_retries + 1`
    /// attempts and surfaces the rate-limit error.
    #[tokio::test]
    async fn retry_logic_max_retries_exceeded() {
        let client = retrying_client();
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let result: Result<String> = client
            .retry_with_backoff(|| {
                let counter = Arc::clone(&attempt_counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(Error::rate_limit("Always rate limited", Some(1)))
                }
            })
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_rate_limit());
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    /// A generic API error with status 529 is retried until it succeeds.
    #[tokio::test]
    async fn error_529_is_retryable() {
        let client = retrying_client();
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let result = client
            .retry_with_backoff(|| {
                let counter = Arc::clone(&attempt_counter);
                async move {
                    let attempt = counter.fetch_add(1, Ordering::SeqCst);
                    match attempt {
                        0 | 1 => Err(Error::api(
                            529,
                            Some("overloaded_error".to_string()),
                            "Overloaded".to_string(),
                            None,
                        )),
                        _ => Ok("success".to_string()),
                    }
                }
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }

    /// Both representations of an overloaded response are retryable.
    #[test]
    fn error_529_mapped_correctly() {
        let error = Error::api(
            529,
            Some("overloaded_error".to_string()),
            "Overloaded".to_string(),
            None,
        );
        assert!(error.is_retryable());
        let rate_limit_error = Error::rate_limit("Overloaded", Some(5));
        assert!(rate_limit_error.is_retryable());
    }

    /// A plain key is returned unchanged.
    #[test]
    fn resolve_api_key_regular_value() {
        let result = Anthropic::resolve_api_key("sk-test-key-123");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-test-key-123");
    }

    /// A `file://` value with an absolute path is read and trimmed.
    #[test]
    fn resolve_api_key_file_url_absolute() {
        let test_dir = std::env::temp_dir().join(format!("claudius_test_{}", std::process::id()));
        std::fs::create_dir_all(&test_dir).unwrap();
        let test_file = test_dir.join("test_api_key.txt");
        std::fs::write(&test_file, "sk-test-from-file-123\n").unwrap();
        let file_url = format!("file://{}", test_file.display());
        let result = Anthropic::resolve_api_key(&file_url);
        // Clean up before asserting so a failure does not leave files behind.
        std::fs::remove_dir_all(&test_dir).unwrap();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-test-from-file-123");
    }

    /// A `file://` value with a relative path is resolved against the
    /// current working directory.
    #[test]
    fn resolve_api_key_file_url_relative() {
        let test_file = "test_relative_key.txt";
        std::fs::write(test_file, "sk-relative-key-456\n").unwrap();
        let file_url = format!("file://{}", test_file);
        let result = Anthropic::resolve_api_key(&file_url);
        std::fs::remove_file(test_file).unwrap();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-relative-key-456");
    }

    /// A missing key file yields a validation error naming the failure.
    #[test]
    fn resolve_api_key_file_url_nonexistent() {
        let result = Anthropic::resolve_api_key("file:///nonexistent/path/to/key.txt");
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.is_validation());
        assert!(format!("{}", error).contains("Failed to read API key from file"));
    }

    /// Whitespace around the key in the file is stripped.
    #[test]
    fn resolve_api_key_file_url_with_whitespace() {
        let test_file = "test_whitespace_key.txt";
        std::fs::write(test_file, " sk-whitespace-key-789 \n ").unwrap();
        let file_url = format!("file://{}", test_file);
        let result = Anthropic::resolve_api_key(&file_url);
        std::fs::remove_file(test_file).unwrap();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "sk-whitespace-key-789");
    }

    /// The chained builder methods store the values they are given.
    #[test]
    fn client_builder_methods() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();
        let configured_client = client
            .with_base_url("https://custom.api.com/v1/".to_string())
            .with_max_retries(5)
            .with_backoff_params(2.0, 1.0);
        assert_eq!(configured_client.base_url, "https://custom.api.com/v1/");
        assert_eq!(configured_client.max_retries, 5);
        assert_eq!(configured_client.throughput_ops_sec, 2.0);
        assert_eq!(configured_client.reserve_capacity, 1.0);
    }

    /// `with_timeout` records the new timeout.
    #[test]
    fn client_timeout_configuration() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();
        let timeout = Duration::from_secs(30);
        let configured_client = client.with_timeout(timeout).unwrap();
        assert_eq!(configured_client.timeout, timeout);
    }

    /// Repeated calls to `default_headers` return the same header set.
    #[test]
    fn client_cached_headers_performance() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();
        let headers1 = client.default_headers();
        let headers2 = client.default_headers();
        assert_eq!(headers1.len(), headers2.len());
        assert!(headers1.contains_key("x-api-key"));
        assert!(headers1.contains_key("anthropic-version"));
        assert!(headers1.contains_key("content-type"));
    }

    /// A freshly built client carries the default timeout, which is the
    /// value `map_request_error` reports in timeout errors.
    #[test]
    fn request_error_mapping() {
        let client = Anthropic::new(Some("test_key".to_string())).unwrap();
        assert_eq!(client.timeout, DEFAULT_TIMEOUT);
    }

    /// Cloned clients can run `retry_with_backoff` from separate tasks.
    #[tokio::test]
    async fn concurrent_retry_safety() {
        use tokio::spawn;

        let client = test_client(1, 1.0, 1.0);
        let attempt_counter = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for _ in 0..3 {
            let client_clone = client.clone();
            let counter_clone = attempt_counter.clone();
            let handle = spawn(async move {
                client_clone
                    .retry_with_backoff(|| {
                        let counter = counter_clone.clone();
                        async move {
                            counter.fetch_add(1, Ordering::SeqCst);
                            Ok::<String, Error>("success".to_string())
                        }
                    })
                    .await
            });
            handles.push(handle);
        }
        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok());
        }
        assert_eq!(attempt_counter.load(Ordering::SeqCst), 3);
    }
}