use crate::config::Config;
use crate::error::{Error, Result};
use reqwest::{header, Method, RequestBuilder, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::time::Duration;
/// Authenticated JSON HTTP client configured from a [`Config`].
///
/// Built by `Client::new`, which installs a `Bearer` `Authorization` default header and the
/// configured request timeout on the inner `reqwest::Client`.
#[derive(Debug, Clone)]
pub struct Client {
// Preconfigured reqwest client (auth header + timeout); every request is sent through it.
http: reqwest::Client,
// Supplies the base URL that request paths are joined onto, and the retry limit.
config: Config,
}
impl Client {
pub fn new(config: Config) -> Result<Self> {
let mut headers = header::HeaderMap::new();
let auth_value = format!("Bearer {}", config.api_key());
headers.insert(
header::AUTHORIZATION,
header::HeaderValue::from_str(&auth_value)
.map_err(|_| Error::Config("Invalid API key format".into()))?,
);
let http = reqwest::Client::builder()
.default_headers(headers)
.timeout(config.timeout())
.build()
.map_err(|e| Error::Config(format!("Failed to create HTTP client: {}", e)))?;
Ok(Self { http, config })
}
pub(crate) async fn get<T>(&self, path: &str) -> Result<T>
where
T: DeserializeOwned + Send + 'static,
{
self.request(Method::GET, path).send().await
}
pub(crate) async fn post<B, T>(&self, path: &str, body: &B) -> Result<T>
where
B: Serialize,
T: DeserializeOwned + Send + 'static,
{
self.request(Method::POST, path).json(body).send().await
}
pub(crate) async fn put<B, T>(&self, path: &str, body: &B) -> Result<T>
where
B: Serialize,
T: DeserializeOwned + Send + 'static,
{
self.request(Method::PUT, path).json(body).send().await
}
fn request(&self, method: Method, path: &str) -> Request<'_> {
let url = self
.config
.base_url()
.join(path.trim_start_matches('/'))
.unwrap();
let builder = self.http.request(method, url);
Request {
client: self,
builder,
attempt: 0,
}
}
fn execute<T>(
&self,
builder: RequestBuilder,
attempt: u32,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send + '_>>
where
T: DeserializeOwned + Send + 'static,
{
Box::pin(async move {
let response = builder
.try_clone()
.ok_or_else(|| Error::Unknown("Failed to clone request".into()))?
.send()
.await?;
let status = response.status();
if status == StatusCode::TOO_MANY_REQUESTS {
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.map(Duration::from_secs);
return Err(Error::RateLimit { retry_after });
}
if status.is_success() {
let text = response.text().await?;
if text.is_empty() {
return serde_json::from_str("{}").map_err(Into::into);
}
return serde_json::from_str(&text).map_err(Error::Serialization);
}
let error_body = response.text().await.unwrap_or_default();
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
let message = json
.get("message")
.or_else(|| json.get("error"))
.and_then(|v| v.as_str())
.unwrap_or("Unknown error");
let code = json
.get("code")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let error = if let Some(code) = code {
Error::api_error_with_code(status.as_u16(), message, code)
} else {
Error::api_error(status.as_u16(), message)
};
if error.is_retryable() && attempt < self.config.max_retries() {
let backoff = Duration::from_millis(100 * 2_u64.pow(attempt));
tokio::time::sleep(backoff).await;
return self.execute(builder, attempt + 1).await;
}
return Err(error);
}
Err(Error::api_error(
status.as_u16(),
format!("Request failed with status {}: {}", status, error_body),
))
})
}
}
struct Request<'a> {
client: &'a Client,
builder: RequestBuilder,
attempt: u32,
}
impl<'a> Request<'a> {
fn json<T: Serialize>(mut self, json: &T) -> Self {
self.builder = self.builder.json(json);
self
}
async fn send<T>(self) -> Result<T>
where
T: DeserializeOwned + Send + 'static,
{
self.client.execute(self.builder, self.attempt).await
}
}
#[cfg(test)]
mod tests {
    //! Construction tests for `Client`.
    use super::*;

    /// A well-formed key yields a usable client.
    #[test]
    fn test_client_creation() {
        assert!(Client::new(Config::new("sk_test_123")).is_ok());
    }

    /// A key containing a newline cannot become a header value, so construction fails.
    #[test]
    fn test_client_invalid_api_key() {
        let bad_key = "invalid\nkey";
        assert!(Client::new(Config::new(bad_key)).is_err());
    }
}