use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use std::time::Duration;
use crate::error::Error;
use crate::models::Models;
use crate::types::{ChatCompletionRequest, ChatCompletionResponse};
use crate::{DEFAULT_BASE_URL, DEFAULT_MAX_RETRIES, DEFAULT_TIMEOUT};
/// Asynchronous API client holding the credentials, endpoint and retry policy.
///
/// Construct one with [`Client::new`] or [`Client::with_config`].
pub struct Client {
    /// Secret key sent as a bearer token; redacted in the `Debug` output.
    api_key: String,
    /// Base URL that endpoint paths are appended to.
    base_url: String,
    /// Underlying HTTP client used for every request.
    http_client: reqwest::Client,
    /// Maximum number of retries for errors that report themselves retryable.
    max_retries: u32,
}

// Hand-written instead of derived so the API key never ends up in logs or panic messages.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
impl Client {
    /// Creates a client with [`ClientConfig::default`] settings.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` HTTP client cannot be built.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_config(api_key, ClientConfig::default())
    }

    /// Creates a client from an explicit [`ClientConfig`].
    ///
    /// The HTTP client sends `Content-Type: application/json` by default and applies a
    /// timeout of `config.timeout` seconds to each request.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` HTTP client cannot be built.
    pub fn with_config(api_key: impl Into<String>, config: ClientConfig) -> Self {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(config.timeout))
            .default_headers(headers)
            .build()
            .expect("Failed to create HTTP client");
        Self {
            api_key: api_key.into(),
            base_url: config.base_url,
            http_client,
            max_retries: config.max_retries,
        }
    }

    /// Returns the chat API namespace, borrowing this client.
    pub fn chat(&self) -> Chat<'_> {
        Chat { client: self }
    }

    /// Returns the models API namespace.
    pub fn models(&self) -> Models {
        Models
    }

    /// Sends an authenticated request to `base_url + path` and decodes the JSON reply.
    ///
    /// When `body` is `Some`, it is serialized as the JSON request body. When the server
    /// answers with a non-success status, the response text becomes an [`Error`] variant
    /// chosen by status code. If that error reports itself retryable
    /// (`Error::is_retryable`) and fewer than `max_retries` retries have been made, the
    /// request is re-sent after an exponential backoff of 2, 4, 8, ... seconds.
    ///
    /// # Errors
    ///
    /// - Transport failures from `send` are returned immediately, without retrying.
    /// - Failures to decode a successful response are also returned immediately.
    /// - Error statuses map to `BadRequest` (400), `Authentication` (401),
    ///   `InsufficientBalance` (402), `RateLimit` (429), `InternalServer` (5xx), or
    ///   `Unknown` (anything else). Each variant carries the message extracted from the
    ///   response body.
    async fn request<T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&impl serde::Serialize>,
    ) -> Result<T, Error> {
        // Paths start with `/`. Trimming a trailing `/` from the base URL keeps
        // `base_url("https://host/")` from producing `https://host//v1/...`, which some
        // servers reject.
        let url = format!("{}{}", self.base_url.trim_end_matches('/'), path);
        // The bearer token is identical on every attempt, so build it once.
        let auth = format!("Bearer {}", self.api_key);
        let mut retries: u32 = 0;
        loop {
            let mut request = self
                .http_client
                .request(method.clone(), &url)
                .header(AUTHORIZATION, auth.as_str());
            if let Some(body) = body {
                request = request.json(body);
            }
            let response = request.send().await?;
            let status = response.status();
            if status.is_success() {
                return Ok(response.json().await?);
            }
            // Best effort: an unreadable body just yields an empty message.
            let error_text = response.text().await.unwrap_or_default();
            let error_message = extract_error_message(&error_text);
            let error = match status.as_u16() {
                401 => Error::Authentication(error_message),
                402 => Error::InsufficientBalance(error_message),
                429 => Error::RateLimit(error_message),
                400 => Error::BadRequest(error_message),
                500..=599 => Error::InternalServer(error_message),
                _ => Error::Unknown(error_message),
            };
            if error.is_retryable() && retries < self.max_retries {
                retries += 1;
                // `saturating_pow` avoids an overflow panic when `max_retries` is very
                // large (`2u64.pow(64)` overflows).
                let backoff = Duration::from_secs(2u64.saturating_pow(retries));
                tokio::time::sleep(backoff).await;
                continue;
            }
            return Err(error);
        }
    }
}
/// Configuration accepted by `Client::with_config`.
///
/// Start from `ClientConfig::default()` and adjust it with the builder methods.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Base URL that endpoint paths (e.g. `/v1/chat/completions`) are appended to.
    pub base_url: String,
    /// Per-request timeout, in seconds.
    pub timeout: u64,
    /// Maximum number of retries for errors that report themselves retryable.
    pub max_retries: u32,
}
impl Default for ClientConfig {
    /// Builds a configuration from the crate-level `DEFAULT_BASE_URL`,
    /// `DEFAULT_TIMEOUT` and `DEFAULT_MAX_RETRIES` constants.
    fn default() -> Self {
        let base_url = DEFAULT_BASE_URL.to_owned();
        Self {
            max_retries: DEFAULT_MAX_RETRIES,
            timeout: DEFAULT_TIMEOUT,
            base_url,
        }
    }
}
impl ClientConfig {
    /// Sets the base URL that endpoint paths are appended to.
    #[must_use]
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Sets the per-request timeout, in seconds.
    #[must_use]
    pub fn timeout(mut self, timeout: u64) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the maximum number of retries for errors that report themselves retryable.
    #[must_use]
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}
/// Chat API namespace, obtained from `Client::chat`.
pub struct Chat<'a> {
    client: &'a Client,
}

impl<'a> Chat<'a> {
    /// Returns the chat-completions endpoint group, sharing the same borrowed client.
    pub fn completions(&self) -> Completions<'a> {
        let client: &'a Client = self.client;
        Completions { client }
    }
}
/// Chat-completions endpoint group, obtained from `Chat::completions`.
pub struct Completions<'a> {
    client: &'a Client,
}

impl<'a> Completions<'a> {
    /// Sends `request` as JSON to `POST /v1/chat/completions` and decodes the reply.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when any of the following happens:
    /// - the request cannot be sent;
    /// - the server keeps answering with a non-success status;
    /// - the response body cannot be decoded into a [`ChatCompletionResponse`].
    pub async fn create(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, Error> {
        let body = Some(&request);
        let method = reqwest::Method::POST;
        self.client
            .request(method, "/v1/chat/completions", body)
            .await
    }
}
fn extract_error_message(text: &str) -> String {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
if let Some(message) = json.get("error").and_then(|e| e.get("message")).and_then(|m| m.as_str()) {
return message.to_string();
}
}
text.to_string()
}