use crate::modules::{
chat::Chat, completions::Completions, embeddings::Embeddings, models::Models,
};
use crate::{config::Config, service::client::HttpClient};
use http::HeaderValue;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, RwLockReadGuard};
/// Top-level client that bundles the `Chat`, `Completions`, `Models` and `Embeddings`
/// API modules.
///
/// Every module is constructed from a clone of the same [`HttpClient`], and `config` is the
/// lock-protected configuration handle obtained from that client.
pub struct OpenAI {
    /// Shared configuration handle, taken from `http_client.config()` at construction time.
    config: Arc<RwLock<Config>>,
    /// The HTTP client the API modules were cloned from; `refresh_client()` is called on it
    /// after `update_config` mutates the configuration.
    http_client: HttpClient,
    /// Chat API module.
    chat: Chat,
    /// Completions API module.
    completions: Completions,
    /// Models API module.
    models: Models,
    /// Embeddings API module.
    embeddings: Embeddings,
}
impl OpenAI {
    /// Creates a client from an API key and a base URL, leaving every other setting at
    /// whatever `Config::new` chooses.
    ///
    /// Shorthand for `OpenAI::with_config(Config::new(api_key, base_url))`.
    #[must_use]
    pub fn new(api_key: &str, base_url: &str) -> OpenAI {
        Self::with_config(Config::new(api_key.to_string(), base_url.to_string()))
    }

    /// Creates a client from a fully prepared [`Config`].
    ///
    /// A single [`HttpClient`] is built from `config`. Each API module receives a clone of
    /// it, and the configuration handle stored on `OpenAI` is obtained from
    /// `http_client.config()`. This presumes `HttpClient` clones share that handle, which is
    /// not visible here.
    #[must_use]
    pub fn with_config(config: Config) -> OpenAI {
        let http_client = HttpClient::new(config);
        OpenAI {
            chat: Chat::new(http_client.clone()),
            completions: Completions::new(http_client.clone()),
            models: Models::new(http_client.clone()),
            embeddings: Embeddings::new(http_client.clone()),
            config: http_client.config(),
            http_client,
        }
    }

    // Public documentation lives in docs/from_env.md.
    //
    // Returns `Err` when `OPENAI_API_KEY` is unset or not valid Unicode, or when
    // `OPENAI_USER_AGENT` is set but is not a valid HTTP header value. Earlier versions
    // panicked in the latter case.
    #[doc = include_str!("../docs/from_env.md")]
    pub fn from_env() -> Result<Self, String> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| "The `OPENAI_API_KEY` environment variable is not set.")?;
        let base_url = std::env::var("OPENAI_BASE_URL")
            .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
        let mut config = Config::new(api_key, base_url);

        // Numeric settings are best-effort. A missing or unparsable value leaves whatever
        // `Config::new` chose, rather than failing construction.
        if let Some(secs) = Self::parse_env::<u64>("OPENAI_TIMEOUT") {
            config.with_timeout(Duration::from_secs(secs));
        }
        if let Some(secs) = Self::parse_env::<u64>("OPENAI_CONNECT_TIMEOUT") {
            config.with_connect_timeout(Duration::from_secs(secs));
        }
        if let Some(retry_count) = Self::parse_env::<usize>("OPENAI_RETRY_COUNT") {
            config.with_retry_count(retry_count);
        }
        if let Ok(proxy) = std::env::var("OPENAI_PROXY") {
            config.with_proxy(proxy);
        }
        if let Ok(user_agent) = std::env::var("OPENAI_USER_AGENT") {
            // The function already returns a Result, so report bad input instead of panicking.
            let header = HeaderValue::from_str(&user_agent).map_err(|_| {
                format!("Cannot convert the value `{user_agent}` of environment variable `OPENAI_USER_AGENT` to HeaderValue, please check if the value is valid.")
            })?;
            config.with_user_agent(header);
        }
        Ok(Self::with_config(config))
    }

    /// Reads environment variable `key` and parses it as `T`.
    ///
    /// Returns `None` when the variable is unset, is not valid Unicode, or fails to parse.
    fn parse_env<T: std::str::FromStr>(key: &str) -> Option<T> {
        std::env::var(key).ok()?.parse::<T>().ok()
    }
}
impl OpenAI {
    /// Mutates the shared [`Config`] under its write lock, then calls
    /// `http_client.refresh_client()`.
    ///
    /// The write guard is released before the refresh is awaited, so the lock is not held
    /// while `refresh_client` runs.
    pub async fn update_config<F>(&self, update_fn: F)
    where
        F: FnOnce(&mut Config),
    {
        let mut guard = self.config.write().await;
        update_fn(&mut *guard);
        // Release the write lock before refreshing the client.
        drop(guard);
        self.http_client.refresh_client().await;
    }

    #[doc = include_str!("../docs/chat.md")]
    #[inline]
    pub fn chat(&self) -> &Chat {
        &self.chat
    }

    #[doc = include_str!("../docs/completions.md")]
    #[inline]
    pub fn completions(&self) -> &Completions {
        &self.completions
    }

    #[doc = include_str!("../docs/models.md")]
    #[inline]
    pub fn models(&self) -> &Models {
        &self.models
    }

    #[doc = include_str!("../docs/embeddings.md")]
    #[inline]
    pub fn embeddings(&self) -> &Embeddings {
        &self.embeddings
    }

    /// Returns an owned copy of the configured base URL.
    pub async fn base_url(&self) -> String {
        let guard = self.config.read().await;
        guard.base_url().to_string()
    }

    /// Returns an owned copy of the configured API key.
    pub async fn api_key(&self) -> String {
        let guard = self.config.read().await;
        guard.api_key().to_string()
    }

    /// Returns a read guard over the shared configuration.
    ///
    /// While the guard is alive, `update_config` and the `with_*` setters (which take the
    /// write lock) will wait, so drop it promptly.
    pub async fn config(&self) -> RwLockReadGuard<'_, Config> {
        let guard = self.config.read().await;
        guard
    }

    /// Replaces the base URL in the shared configuration.
    ///
    /// Unlike `with_timeout`/`with_proxy`, this does not call `refresh_client`; presumably
    /// the HTTP client reads this value per request. NOTE(review): confirm.
    pub async fn with_base_url(&self, base_url: impl Into<String>) {
        let mut guard = self.config.write().await;
        guard.with_base_url(base_url);
    }

    /// Replaces the API key in the shared configuration.
    ///
    /// Does not call `refresh_client`. NOTE(review): confirm the HTTP client picks up the
    /// new key without being refreshed.
    pub async fn with_api_key(&self, api_key: impl Into<String>) {
        let mut guard = self.config.write().await;
        guard.with_api_key(api_key);
    }

    /// Sets the request timeout and refreshes the HTTP client.
    pub async fn with_timeout(&self, timeout: Duration) {
        let apply = |cfg: &mut Config| {
            cfg.with_timeout(timeout);
        };
        self.update_config(apply).await;
    }

    /// Sets the connect timeout and refreshes the HTTP client.
    pub async fn with_connect_timeout(&self, connect_timeout: Duration) {
        let apply = |cfg: &mut Config| {
            cfg.with_connect_timeout(connect_timeout);
        };
        self.update_config(apply).await;
    }

    /// Sets the retry count in the shared configuration without refreshing the HTTP client.
    pub async fn with_retry_count(&self, retry_count: usize) {
        let mut guard = self.config.write().await;
        guard.with_retry_count(retry_count);
    }

    /// Sets the proxy and refreshes the HTTP client.
    pub async fn with_proxy(&self, proxy: impl Into<String>) {
        let apply = move |cfg: &mut Config| {
            cfg.with_proxy(proxy);
        };
        self.update_config(apply).await;
    }

    /// Sets the `User-Agent` header value and refreshes the HTTP client.
    pub async fn with_user_agent(&self, user_agent: HeaderValue) {
        let apply = move |cfg: &mut Config| {
            cfg.with_user_agent(user_agent);
        };
        self.update_config(apply).await;
    }
}