use crate::auth::AuthStrategy;
use crate::error::{OpenAiError, Result};
use crate::types::{
CreateChatCompletionRequest, CreateChatCompletionResponse, ErrorResponse, Model, ModelList,
};
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
use std::sync::Arc;
/// Base URL used when no custom base URL is configured on the builder.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

/// Asynchronous client for the OpenAI REST API.
///
/// Construct one with [`Client::builder`]. Cloning is cheap: `reqwest::Client` keeps its
/// connection pool behind an internal `Arc`, and the auth strategy is shared through an
/// [`Arc`]. Clones therefore reuse the same connection pool and credentials.
#[derive(Clone)]
pub struct Client {
    /// Underlying HTTP client used for every request.
    http: reqwest::Client,
    /// Strategy that adds authentication headers to each outgoing request.
    auth: Arc<dyn AuthStrategy>,
    /// Base URL that endpoint paths (e.g. `models`) are appended to.
    base_url: String,
}
impl Client {
pub fn builder() -> ClientBuilder<()> {
ClientBuilder::new()
}
pub async fn create_chat_completion(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse> {
let url = format!("{}/chat/completions", self.base_url);
self.post(&url, &request).await
}
pub async fn list_models(&self) -> Result<ModelList> {
let url = format!("{}/models", self.base_url);
self.get(&url).await
}
pub async fn get_model(&self, model_id: &str) -> Result<Model> {
let url = format!("{}/models/{}", self.base_url, model_id);
self.get(&url).await
}
async fn get<T>(&self, url: &str) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let mut headers = HeaderMap::new();
self.auth.apply(&mut headers).await?;
tracing::debug!(url = %url, "GET request");
let response = self.http.get(url).headers(headers).send().await?;
self.handle_response(response).await
}
async fn post<T, B>(&self, url: &str, body: &B) -> Result<T>
where
T: serde::de::DeserializeOwned,
B: serde::Serialize,
{
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
self.auth.apply(&mut headers).await?;
tracing::debug!(url = %url, "POST request");
let response = self
.http
.post(url)
.headers(headers)
.json(body)
.send()
.await?;
self.handle_response(response).await
}
async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let status = response.status();
let status_code = status.as_u16();
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok());
if status.is_success() {
let body = response.text().await?;
tracing::debug!(status = %status_code, "Response received");
serde_json::from_str(&body).map_err(OpenAiError::from)
} else {
let body = response.text().await?;
tracing::warn!(status = %status_code, body = %body, "API error");
if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(&body) {
let message = error_response.error.message;
let code = error_response.error.code.as_deref();
return Err(match status_code {
401 => OpenAiError::Unauthorized,
403 => OpenAiError::Forbidden(message),
404 => OpenAiError::NotFound(message),
429 => OpenAiError::RateLimited {
retry_after: retry_after.unwrap_or(60),
},
500..=599 => OpenAiError::ServerError(message),
_ => match code {
Some("context_length_exceeded") => {
OpenAiError::ContextLengthExceeded(message)
}
Some("invalid_request_error") => OpenAiError::InvalidRequest(message),
_ => OpenAiError::Api {
status: status_code,
message,
},
},
});
}
Err(OpenAiError::Api {
status: status_code,
message: body,
})
}
}
}
/// Builder for [`Client`].
///
/// The type parameter `A` holds the authentication strategy. It is `()` for a fresh builder
/// and becomes the strategy's type once `auth` is called. `build` is only available when `A`
/// implements `AuthStrategy`.
#[must_use = "a ClientBuilder does nothing until `.build()` is called"]
pub struct ClientBuilder<A> {
    /// The authentication strategy, or `()` if none has been set yet.
    auth: A,
    /// Base URL that every endpoint path is appended to.
    base_url: String,
}
impl ClientBuilder<()> {
    /// Creates a builder pointed at the default OpenAI endpoint, with no auth strategy yet.
    pub fn new() -> Self {
        let base_url = String::from(DEFAULT_BASE_URL);
        ClientBuilder { auth: (), base_url }
    }

    /// Attaches `strategy` as the authentication strategy.
    ///
    /// Any base URL configured so far is kept.
    pub fn auth<S: AuthStrategy + 'static>(self, strategy: S) -> ClientBuilder<S> {
        let Self { base_url, .. } = self;
        ClientBuilder {
            base_url,
            auth: strategy,
        }
    }
}
impl Default for ClientBuilder<()> {
fn default() -> Self {
Self::new()
}
}
impl<A> ClientBuilder<A> {
    /// Overrides the base URL that endpoint paths are appended to.
    ///
    /// This is available whether or not an auth strategy has been set yet, so `base_url` and
    /// `auth` may be called in either order.
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }
}

impl<A: AuthStrategy + 'static> ClientBuilder<A> {
    /// Finalizes the builder into a [`Client`] backed by a new `reqwest::Client`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `reqwest::Client::new`: a TLS backend cannot be
    /// initialized, or the resolver cannot load the system configuration.
    #[must_use]
    pub fn build(self) -> Client {
        Client {
            http: reqwest::Client::new(),
            auth: Arc::new(self.auth),
            base_url: self.base_url,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyAuth;
    use crate::types::Message;

    /// An explicitly configured base URL is stored unchanged on the built client.
    #[test]
    fn test_builder() {
        let custom_url = "https://custom.api.com";
        let auth = ApiKeyAuth::new("test-key");
        let client = Client::builder().auth(auth).base_url(custom_url).build();
        assert_eq!(client.base_url, custom_url);
    }

    /// The request builder records the model and the optional sampling parameters.
    #[test]
    fn test_create_chat_completion_request() {
        let model = "gpt-4o";
        let messages = vec![Message::user("Hello")];
        let request = CreateChatCompletionRequest::new(model, messages)
            .with_max_tokens(1024)
            .with_temperature(0.7);

        assert_eq!(request.model, model);
        assert_eq!(request.max_tokens, Some(1024));
        assert_eq!(request.temperature, Some(0.7));
    }
}