use crate::error::TokenError;
use reqwest;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Client for the Anthropic token-counting HTTP endpoint.
///
/// Construct it with `ClaudeApiClient::new` and call `count_tokens` to obtain
/// the input-token count the API reports for a piece of text.
pub struct ClaudeApiClient {
/// Reusable `reqwest` client; configured in `new` with a timeout and a user agent.
client: reqwest::Client,
/// API key sent in the `x-api-key` header of every request.
api_key: String,
}
impl ClaudeApiClient {
    /// URL every token-count request is POSTed to.
    const API_ENDPOINT: &'static str = "https://api.anthropic.com/v1/messages/count_tokens";
    /// Value sent in the `anthropic-version` request header.
    const API_VERSION: &'static str = "2023-06-01";
    /// Per-request timeout, in seconds, configured on the HTTP client.
    const TIMEOUT_SECS: u64 = 30;
    /// Upper bound on the total number of attempts `count_tokens` makes
    /// (the initial request included).
    const MAX_RETRIES: u32 = 3;
    /// Delay before the first retry, in milliseconds; doubled for each later retry.
    const BASE_BACKOFF_MS: u64 = 1000;

    /// Builds a client that authenticates with `api_key`.
    ///
    /// # Errors
    ///
    /// * `TokenError::InvalidApiKey` if `api_key` cannot be used as an HTTP
    ///   header value (for example because it contains a newline).
    /// * `TokenError::ApiError` if the underlying HTTP client cannot be built.
    pub fn new(api_key: String) -> Result<Self, TokenError> {
        // The key is sent as the `x-api-key` header. reqwest only reports an
        // invalid header value when the request is sent, where it would be
        // wrapped as a "Network error"; reject such keys up front instead.
        if reqwest::header::HeaderValue::from_str(&api_key).is_err() {
            return Err(TokenError::InvalidApiKey);
        }

        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(Self::TIMEOUT_SECS))
            .user_agent(format!("token-count/{}", env!("CARGO_PKG_VERSION")))
            .build()
            .map_err(|e| TokenError::ApiError(format!("Failed to create HTTP client: {}", e)))?;

        Ok(Self { client, api_key })
    }

    /// Returns the number of input tokens the API reports for `text`, sent as
    /// a single `user` message to `model`.
    ///
    /// Rate-limit and server errors (see `is_retryable`) are retried with an
    /// exponentially growing delay, up to `MAX_RETRIES` attempts in total.
    /// Every other error, including network failures, is returned immediately.
    ///
    /// # Errors
    ///
    /// Returns the error of the last attempt.
    pub async fn count_tokens(&self, model: &str, text: &str) -> Result<usize, TokenError> {
        let mut attempts_made: u32 = 0;
        loop {
            let err = match self.try_count_tokens(model, text).await {
                Ok(count) => return Ok(count),
                Err(err) => err,
            };

            attempts_made += 1;
            // Compare against the attempt count directly so the check stays
            // well-defined even if `MAX_RETRIES` is ever set to 0.
            if attempts_made >= Self::MAX_RETRIES || !Self::is_retryable(&err) {
                return Err(err);
            }

            tokio::time::sleep(Self::backoff_delay(attempts_made - 1)).await;
        }
    }

    /// Delay to wait before retry number `retry` (0-based):
    /// `BASE_BACKOFF_MS * 2^retry`, saturating instead of overflowing.
    fn backoff_delay(retry: u32) -> Duration {
        let millis = 2u64
            .saturating_pow(retry)
            .saturating_mul(Self::BASE_BACKOFF_MS);
        Duration::from_millis(millis)
    }

    /// Whether `error` is one that `count_tokens` retries: a rate limit or a
    /// server-side error.
    fn is_retryable(error: &TokenError) -> bool {
        matches!(error, TokenError::RateLimited | TokenError::ApiServerError(_))
    }

    /// Performs a single token-count request without any retry handling.
    ///
    /// # Errors
    ///
    /// * `TokenError::ApiError` if the request cannot be sent or the success
    ///   body cannot be parsed.
    /// * The result of `parse_api_error` for any non-success status.
    async fn try_count_tokens(&self, model: &str, text: &str) -> Result<usize, TokenError> {
        let request = CountTokensRequest {
            model: model.to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: text.to_string(),
            }],
        };

        let response = self
            .client
            .post(Self::API_ENDPOINT)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", Self::API_VERSION)
            .json(&request)
            .send()
            .await
            .map_err(|e| TokenError::ApiError(format!("Network error: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            return Err(Self::parse_api_error(status.as_u16(), response).await);
        }

        let body: CountTokensResponse = response
            .json()
            .await
            .map_err(|e| TokenError::ApiError(format!("Failed to parse response: {}", e)))?;
        Ok(body.input_tokens)
    }

    /// Converts a non-success HTTP response into a `TokenError`.
    ///
    /// The body is read and parsed as an `ApiErrorResponse`; the resulting
    /// message (or `"HTTP <status>"` when the body does not parse) is only
    /// carried by the fallback `TokenError::ApiError` variant.
    async fn parse_api_error(status_code: u16, response: reqwest::Response) -> TokenError {
        let error_msg = match response.json::<ApiErrorResponse>().await {
            Ok(err_resp) => format!("{}: {}", err_resp.error.error_type, err_resp.error.message),
            Err(_) => format!("HTTP {}", status_code),
        };

        match status_code {
            401 => TokenError::InvalidApiKey,
            429 => TokenError::RateLimited,
            500..=599 => TokenError::ApiServerError(status_code),
            _ => TokenError::ApiError(error_msg),
        }
    }
}
/// JSON request body sent to the token-counting endpoint.
#[derive(Debug, Serialize)]
struct CountTokensRequest {
/// Model identifier, passed through unchanged from the caller.
model: String,
/// Messages whose tokens are counted; this file always sends exactly one.
messages: Vec<Message>,
}
/// One message inside a `CountTokensRequest`.
#[derive(Debug, Serialize)]
struct Message {
/// Author role of the message; this file only ever sends `"user"`.
role: String,
/// Text of the message.
content: String,
}
/// Success response body; only the field this client reads is declared.
#[derive(Debug, Deserialize)]
struct CountTokensResponse {
/// Token count reported by the API, returned to the caller as-is.
input_tokens: usize,
}
/// Body that `parse_api_error` attempts to parse from a non-success response.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
/// Error details nested under the `error` key.
error: ApiError,
}
/// Error details carried inside an `ApiErrorResponse`.
#[derive(Debug, Deserialize)]
struct ApiError {
// `type` is a Rust keyword, so the JSON key is mapped onto `error_type`.
#[serde(rename = "type")]
error_type: String,
/// Message text taken from the `message` key of the error object.
message: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A success body deserializes into the reported token count.
    #[test]
    fn test_parse_success_response() {
        let parsed =
            serde_json::from_str::<CountTokensResponse>(r#"{"input_tokens": 42}"#).unwrap();
        assert_eq!(parsed.input_tokens, 42);
    }

    /// An error body exposes both the renamed `type` field and the message.
    #[test]
    fn test_parse_error_response() {
        let payload = r#"{"error": {"type": "authentication_error", "message": "invalid key"}}"#;
        let parsed = serde_json::from_str::<ApiErrorResponse>(payload).unwrap();
        let ApiError {
            error_type,
            message,
        } = parsed.error;
        assert_eq!(error_type, "authentication_error");
        assert_eq!(message, "invalid key");
    }

    /// The serialized request contains the model, the role and the text.
    #[test]
    fn test_serialize_request() {
        let user_message = Message {
            role: "user".to_string(),
            content: "Hello, Claude!".to_string(),
        };
        let body = CountTokensRequest {
            model: "claude-sonnet-4-6".to_string(),
            messages: vec![user_message],
        };
        let serialized = serde_json::to_string(&body).unwrap();
        for expected in &["claude-sonnet-4-6", "Hello, Claude!", "user"] {
            assert!(serialized.contains(*expected));
        }
    }

    /// Only rate-limit and server errors are classified as retryable.
    #[test]
    fn test_is_retryable() {
        let retried = [TokenError::RateLimited, TokenError::ApiServerError(500)];
        for err in &retried {
            assert!(ClaudeApiClient::is_retryable(err));
        }

        let not_retried = [
            TokenError::InvalidApiKey,
            TokenError::ApiError("test".to_string()),
        ];
        for err in &not_retried {
            assert!(!ClaudeApiClient::is_retryable(err));
        }
    }

    /// Building a client from an ordinary key succeeds.
    #[test]
    fn test_client_creation() {
        assert!(ClaudeApiClient::new("test-key".to_string()).is_ok());
    }
}