use reqwest::header::{HeaderMap, AUTHORIZATION};
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
use crate::error::OpenAIError;
/// Default base URL of the OpenAI API; used when `OPENAI_BASE_URL` is not set
/// (and always on wasm targets).
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the HTTP header that carries the organization ID.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the HTTP header that carries the project ID.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";
/// Name of the `OpenAI-Beta` HTTP header.
// NOTE(review): not referenced in this part of the file; presumably used by
// callers that opt in to beta endpoints — confirm at the use sites.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
/// Connection settings a client needs to address and authenticate against an
/// API endpoint.
///
/// Implementors must be `Send + Sync`. The trait is also implemented for
/// `Box<dyn Config>` and `std::sync::Arc<dyn Config>`, so it can be used
/// through dynamic dispatch.
pub trait Config: Send + Sync {
/// Returns the HTTP headers (authentication and any extras) for a request.
fn headers(&self) -> HeaderMap;
/// Builds the full request URL for the given endpoint `path`.
fn url(&self, path: &str) -> String;
/// Returns query-string parameters as `(name, value)` pairs.
fn query(&self) -> Vec<(&str, &str)>;
/// Returns the configured API base URL.
fn api_base(&self) -> &str;
/// Returns the configured API key, still wrapped as a secret.
fn api_key(&self) -> &SecretString;
}
/// Implements [`Config`] for a smart-pointer type wrapping `dyn Config` by
/// forwarding every trait method to the pointee.
///
/// This is what allows a client to be parameterised over a boxed or
/// reference-counted trait object instead of a concrete configuration type.
macro_rules! impl_config_for_ptr {
    ($ptr:ty) => {
        impl Config for $ptr {
            fn headers(&self) -> HeaderMap {
                (**self).headers()
            }

            fn url(&self, path: &str) -> String {
                (**self).url(path)
            }

            fn query(&self) -> Vec<(&str, &str)> {
                (**self).query()
            }

            fn api_base(&self) -> &str {
                (**self).api_base()
            }

            fn api_key(&self) -> &SecretString {
                (**self).api_key()
            }
        }
    };
}

// Dynamic-dispatch entry points: owned box and shared reference-counted pointer.
impl_config_for_ptr!(Box<dyn Config>);
impl_config_for_ptr!(std::sync::Arc<dyn Config>);
/// Resolves the API base URL used by the default OpenAI configuration.
///
/// On non-wasm targets the `OPENAI_BASE_URL` environment variable wins; when
/// it cannot be read, [`OPENAI_API_BASE`] is used. On wasm targets the
/// environment is never consulted and [`OPENAI_API_BASE`] is always returned.
fn default_api_base() -> String {
    #[cfg(not(target_family = "wasm"))]
    let base = match std::env::var("OPENAI_BASE_URL") {
        Ok(from_env) => from_env,
        Err(_) => OPENAI_API_BASE.to_string(),
    };
    #[cfg(target_family = "wasm")]
    let base = OPENAI_API_BASE.to_string();

    base
}
/// Resolves the API key used by the default configurations.
///
/// On non-wasm targets the lookup order is:
/// 1. the `OPENAI_API_KEY` environment variable;
/// 2. the `OPENAI_ADMIN_KEY` environment variable (a warning is logged when
///    this fallback is taken);
/// 3. the empty string.
///
/// On wasm targets the environment is never consulted and the empty string is
/// returned.
fn default_api_key() -> String {
    #[cfg(not(target_family = "wasm"))]
    let key = match std::env::var("OPENAI_API_KEY") {
        Ok(api_key) => api_key,
        Err(_) => match std::env::var("OPENAI_ADMIN_KEY") {
            Ok(admin_key) => {
                tracing::warn!("Using OPENAI_ADMIN_KEY, OPENAI_API_KEY not set");
                admin_key
            }
            Err(_) => String::new(),
        },
    };
    #[cfg(target_family = "wasm")]
    let key = String::new();

    key
}
/// Resolves the default organization ID.
///
/// On non-wasm targets this is the value of the `OPENAI_ORG_ID` environment
/// variable, or the empty string when it cannot be read. On wasm targets it is
/// always the empty string.
fn default_org_id() -> String {
    #[cfg(not(target_family = "wasm"))]
    let org_id = match std::env::var("OPENAI_ORG_ID") {
        Ok(from_env) => from_env,
        Err(_) => String::new(),
    };
    #[cfg(target_family = "wasm")]
    let org_id = String::new();

    org_id
}
/// Resolves the default project ID.
///
/// On non-wasm targets this is the value of the `OPENAI_PROJECT_ID`
/// environment variable, or the empty string when it cannot be read. On wasm
/// targets it is always the empty string.
fn default_project_id() -> String {
    #[cfg(not(target_family = "wasm"))]
    let project_id = match std::env::var("OPENAI_PROJECT_ID") {
        Ok(from_env) => from_env,
        Err(_) => String::new(),
    };
    #[cfg(target_family = "wasm")]
    let project_id = String::new();

    project_id
}
/// Configuration for the OpenAI API: base URL, credentials and optional
/// organization/project scoping.
///
/// Because of `#[serde(default)]`, any field missing during deserialization is
/// taken from this type's `Default` implementation.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
/// Base URL that endpoint paths are appended to.
api_base: String,
/// API key, sent as a bearer token in the `Authorization` header.
api_key: SecretString,
/// Organization ID; the organization header is omitted when this is empty.
org_id: String,
/// Project ID; the project header is omitted when this is empty.
project_id: String,
/// Extra headers added through `with_header`; skipped by serde.
#[serde(skip)]
custom_headers: HeaderMap,
}
impl Default for OpenAIConfig {
fn default() -> Self {
Self {
api_base: default_api_base(),
api_key: default_api_key().into(),
org_id: default_org_id(),
project_id: default_project_id(),
custom_headers: HeaderMap::new(),
}
}
}
impl OpenAIConfig {
    /// Creates a configuration with the default values; equivalent to
    /// `OpenAIConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with its organization ID replaced by `org_id`.
    pub fn with_org_id<S: Into<String>>(self, org_id: S) -> Self {
        Self {
            org_id: org_id.into(),
            ..self
        }
    }

    /// Returns the configuration with its project ID replaced by `project_id`.
    pub fn with_project_id<S: Into<String>>(self, project_id: S) -> Self {
        Self {
            project_id: project_id.into(),
            ..self
        }
    }

    /// Returns the configuration with its API key replaced by `api_key`.
    pub fn with_api_key<S: Into<String>>(self, api_key: S) -> Self {
        let raw_key: String = api_key.into();
        Self {
            api_key: SecretString::from(raw_key),
            ..self
        }
    }

    /// Returns the configuration with its base URL replaced by `api_base`.
    pub fn with_api_base<S: Into<String>>(self, api_base: S) -> Self {
        Self {
            api_base: api_base.into(),
            ..self
        }
    }

    /// Adds a custom header that will be included in the headers produced by
    /// this configuration.
    ///
    /// # Errors
    ///
    /// Returns [`OpenAIError::InvalidArgument`] when `value` cannot be
    /// converted into a header value.
    pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
    where
        K: reqwest::header::IntoHeaderName,
        V: TryInto<reqwest::header::HeaderValue>,
        V::Error: Into<reqwest::header::InvalidHeaderValue>,
    {
        let converted: reqwest::header::HeaderValue = match value.try_into() {
            Ok(header_value) => header_value,
            Err(e) => {
                return Err(OpenAIError::InvalidArgument(format!(
                    "Invalid header value: {}",
                    e.into()
                )));
            }
        };
        self.custom_headers.insert(key, converted);
        Ok(self)
    }

    /// Returns the configured organization ID (empty when none is set).
    pub fn org_id(&self) -> &str {
        self.org_id.as_str()
    }
}
impl Config for OpenAIConfig {
fn headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
if !self.org_id.is_empty() {
headers.insert(
OPENAI_ORGANIZATION_HEADER,
self.org_id.as_str().parse().unwrap(),
);
}
if !self.project_id.is_empty() {
headers.insert(
OPENAI_PROJECT_HEADER,
self.project_id.as_str().parse().unwrap(),
);
}
headers.insert(
AUTHORIZATION,
format!("Bearer {}", self.api_key.expose_secret())
.as_str()
.parse()
.unwrap(),
);
for (key, value) in self.custom_headers.iter() {
headers.insert(key, value.clone());
}
headers
}
fn url(&self, path: &str) -> String {
format!("{}{}", self.api_base, path)
}
fn api_base(&self) -> &str {
&self.api_base
}
fn api_key(&self) -> &SecretString {
&self.api_key
}
fn query(&self) -> Vec<(&str, &str)> {
vec![]
}
}
/// Configuration for an Azure-hosted OpenAI deployment.
///
/// Because of `#[serde(default)]`, any field missing during deserialization is
/// taken from this type's `Default` implementation.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
/// Value sent as the `api-version` query parameter.
api_version: String,
/// Deployment ID, placed in the URL path after `/openai/deployments/`.
deployment_id: String,
/// Base URL that the deployment path is appended to.
api_base: String,
/// API key, sent in the `api-key` header.
api_key: SecretString,
}
impl Default for AzureConfig {
fn default() -> Self {
Self {
api_base: Default::default(),
api_key: default_api_key().into(),
deployment_id: Default::default(),
api_version: Default::default(),
}
}
}
impl AzureConfig {
    /// Creates a configuration with the default values; equivalent to
    /// `AzureConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with its API version replaced by
    /// `api_version`.
    pub fn with_api_version<S: Into<String>>(self, api_version: S) -> Self {
        Self {
            api_version: api_version.into(),
            ..self
        }
    }

    /// Returns the configuration with its deployment ID replaced by
    /// `deployment_id`.
    pub fn with_deployment_id<S: Into<String>>(self, deployment_id: S) -> Self {
        Self {
            deployment_id: deployment_id.into(),
            ..self
        }
    }

    /// Returns the configuration with its API key replaced by `api_key`.
    pub fn with_api_key<S: Into<String>>(self, api_key: S) -> Self {
        let raw_key: String = api_key.into();
        Self {
            api_key: SecretString::from(raw_key),
            ..self
        }
    }

    /// Returns the configuration with its base URL replaced by `api_base`.
    pub fn with_api_base<S: Into<String>>(self, api_base: S) -> Self {
        Self {
            api_base: api_base.into(),
            ..self
        }
    }
}
impl Config for AzureConfig {
fn headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
headers
}
fn url(&self, path: &str) -> String {
format!(
"{}/openai/deployments/{}{}",
self.api_base, self.deployment_id, path
)
}
fn api_base(&self) -> &str {
&self.api_base
}
fn api_key(&self) -> &SecretString {
&self.api_key
}
fn query(&self) -> Vec<(&str, &str)> {
vec![("api-version", &self.api_version)]
}
}
// Tests that `Client` can be built from `Box<dyn Config>` and
// `Arc<dyn Config>`. Compiled only for test builds with the
// `chat-completion` feature enabled.
#[cfg(all(test, feature = "chat-completion"))]
mod test {
use super::*;
use crate::types::chat::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
};
use crate::Client;
use std::sync::Arc;
// A client built from a boxed or reference-counted `dyn Config` (and a clone
// of that client) must expose a URL ending in `/v1` for the empty path.
#[test]
fn test_client_creation() {
// NOTE(review): `set_var` mutates process-wide state and other tests in this
// binary may run on different threads — confirm nothing reads the
// environment concurrently.
unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
let openai_config = OpenAIConfig::default();
let config = Box::new(openai_config.clone()) as Box<dyn Config>;
let client = Client::with_config(config);
assert!(client.config().url("").ends_with("/v1"));
let config = Arc::new(openai_config) as Arc<dyn Config>;
let client = Client::with_config(config);
assert!(client.config().url("").ends_with("/v1"));
let cloned_client = client.clone();
assert!(cloned_client.config().url("").ends_with("/v1"));
}
// Builds a chat-completion request against a dynamically dispatched client.
// The value returned by `create` is discarded without `.await`, so presumably
// no request is actually sent; the point is that the call type-checks.
async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
let _ = client.chat().create(CreateChatCompletionRequest {
model: "gpt-4o".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: "Hello, world!".into(),
..Default::default()
},
)],
..Default::default()
});
}
// Runs the helper with both configuration types behind `Box<dyn Config>`,
// first by awaiting it in place and then by moving each client into a spawned
// task. The spawned tasks' join handles are dropped without being awaited.
#[tokio::test]
async fn test_dynamic_dispatch() {
let openai_config = OpenAIConfig::default();
let azure_config = AzureConfig::default();
let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);
let _ = dynamic_dispatch_compiles(&azure_client).await;
let _ = dynamic_dispatch_compiles(&oai_client).await;
let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
}
}