1use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
6use crate::error::OpenAIError;
7
/// Default base URL for the OpenAI API, used by `OpenAIConfig::default()`
/// when the `OPENAI_BASE_URL` environment variable is not set.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Header carrying the organization ID (sent by `OpenAIConfig` when `org_id` is non-empty).
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Header carrying the project ID (sent by `OpenAIConfig` when `project_id` is non-empty).
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Name of the `OpenAI-Beta` header.
/// NOTE(review): not referenced in this file; presumably set by callers of beta endpoints — confirm.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
17
/// Provider-specific settings used to build API requests.
///
/// Implemented in this file for `OpenAIConfig` and `AzureConfig`, and for
/// `Box<dyn Config>` / `Arc<dyn Config>` so the provider can be chosen at runtime.
/// The `Send + Sync` supertraits allow a config to be shared across threads and tasks.
pub trait Config: Send + Sync {
    /// Headers to attach to every request (authentication plus any provider-specific headers).
    fn headers(&self) -> HeaderMap;
    /// Full request URL for the API endpoint `path`, built from the configured base.
    fn url(&self, path: &str) -> String;
    /// Query parameters to append to every request (may be empty).
    fn query(&self) -> Vec<(&str, &str)>;

    /// The configured base URL.
    fn api_base(&self) -> &str;

    /// The configured API key.
    fn api_key(&self) -> &SecretString;
}
29
30macro_rules! impl_config_for_ptr {
32 ($t:ty) => {
33 impl Config for $t {
34 fn headers(&self) -> HeaderMap {
35 self.as_ref().headers()
36 }
37 fn url(&self, path: &str) -> String {
38 self.as_ref().url(path)
39 }
40 fn query(&self) -> Vec<(&str, &str)> {
41 self.as_ref().query()
42 }
43 fn api_base(&self) -> &str {
44 self.as_ref().api_base()
45 }
46 fn api_key(&self) -> &SecretString {
47 self.as_ref().api_key()
48 }
49 }
50 };
51}
52
53impl_config_for_ptr!(Box<dyn Config>);
54impl_config_for_ptr!(std::sync::Arc<dyn Config>);
55
/// Configuration for the OpenAI API.
///
/// `Default` reads `OPENAI_BASE_URL`, `OPENAI_API_KEY` (falling back to
/// `OPENAI_ADMIN_KEY`), `OPENAI_ORG_ID` and `OPENAI_PROJECT_ID` from the
/// environment. Because of `#[serde(default)]`, fields missing from
/// deserialized input take the values produced by that `Default` impl.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
    // Base URL that request paths are appended to.
    api_base: String,
    // Sent as `Authorization: Bearer <key>`.
    api_key: SecretString,
    // Sent as the `OpenAI-Organization` header when non-empty.
    org_id: String,
    // Sent as the `OpenAI-Project` header when non-empty.
    project_id: String,
    // Extra headers added via `with_header`; applied after the built-in ones,
    // so they replace a built-in header of the same name. Not deserialized.
    #[serde(skip)]
    custom_headers: HeaderMap,
}
67
68impl Default for OpenAIConfig {
69 fn default() -> Self {
70 Self {
71 api_base: std::env::var("OPENAI_BASE_URL")
72 .unwrap_or_else(|_| OPENAI_API_BASE.to_string()),
73 api_key: std::env::var("OPENAI_API_KEY")
74 .or_else(|_| {
75 std::env::var("OPENAI_ADMIN_KEY").map(|admin_key| {
76 tracing::warn!("Using OPENAI_ADMIN_KEY, OPENAI_API_KEY not set");
77 admin_key
78 })
79 })
80 .unwrap_or_default()
81 .into(),
82 org_id: std::env::var("OPENAI_ORG_ID").unwrap_or_default(),
83 project_id: std::env::var("OPENAI_PROJECT_ID").unwrap_or_default(),
84 custom_headers: HeaderMap::new(),
85 }
86 }
87}
88
89impl OpenAIConfig {
90 pub fn new() -> Self {
92 Default::default()
93 }
94
95 pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
97 self.org_id = org_id.into();
98 self
99 }
100
101 pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
103 self.project_id = project_id.into();
104 self
105 }
106
107 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
109 self.api_key = SecretString::from(api_key.into());
110 self
111 }
112
113 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
115 self.api_base = api_base.into();
116 self
117 }
118
119 pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
122 where
123 K: reqwest::header::IntoHeaderName,
124 V: TryInto<reqwest::header::HeaderValue>,
125 V::Error: Into<reqwest::header::InvalidHeaderValue>,
126 {
127 let header_value = value.try_into().map_err(|e| {
128 OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
129 })?;
130 self.custom_headers.insert(key, header_value);
131 Ok(self)
132 }
133
134 pub fn org_id(&self) -> &str {
135 &self.org_id
136 }
137}
138
139impl Config for OpenAIConfig {
140 fn headers(&self) -> HeaderMap {
141 let mut headers = HeaderMap::new();
142 if !self.org_id.is_empty() {
143 headers.insert(
144 OPENAI_ORGANIZATION_HEADER,
145 self.org_id.as_str().parse().unwrap(),
146 );
147 }
148
149 if !self.project_id.is_empty() {
150 headers.insert(
151 OPENAI_PROJECT_HEADER,
152 self.project_id.as_str().parse().unwrap(),
153 );
154 }
155
156 headers.insert(
157 AUTHORIZATION,
158 format!("Bearer {}", self.api_key.expose_secret())
159 .as_str()
160 .parse()
161 .unwrap(),
162 );
163
164 for (key, value) in self.custom_headers.iter() {
166 headers.insert(key, value.clone());
167 }
168
169 headers
170 }
171
172 fn url(&self, path: &str) -> String {
173 format!("{}{}", self.api_base, path)
174 }
175
176 fn api_base(&self) -> &str {
177 &self.api_base
178 }
179
180 fn api_key(&self) -> &SecretString {
181 &self.api_key
182 }
183
184 fn query(&self) -> Vec<(&str, &str)> {
185 vec![]
186 }
187}
188
/// Configuration for Azure-hosted OpenAI deployments.
///
/// Requests go to `{api_base}/openai/deployments/{deployment_id}{path}` with an
/// `api-version` query parameter and are authenticated via the `api-key` header.
/// Because of `#[serde(default)]`, fields missing from deserialized input take
/// the values produced by `AzureConfig::default()`.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
    // Sent as the `api-version` query parameter on every request.
    api_version: String,
    // Deployment name inserted into the request URL path.
    deployment_id: String,
    // Base URL that `/openai/deployments/...` is appended to.
    api_base: String,
    // Sent in the `api-key` header.
    api_key: SecretString,
}
198
199impl Default for AzureConfig {
200 fn default() -> Self {
201 Self {
202 api_base: Default::default(),
203 api_key: std::env::var("OPENAI_API_KEY")
204 .unwrap_or_else(|_| "".to_string())
205 .into(),
206 deployment_id: Default::default(),
207 api_version: Default::default(),
208 }
209 }
210}
211
212impl AzureConfig {
213 pub fn new() -> Self {
214 Default::default()
215 }
216
217 pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
218 self.api_version = api_version.into();
219 self
220 }
221
222 pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
223 self.deployment_id = deployment_id.into();
224 self
225 }
226
227 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
229 self.api_key = SecretString::from(api_key.into());
230 self
231 }
232
233 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
235 self.api_base = api_base.into();
236 self
237 }
238}
239
240impl Config for AzureConfig {
241 fn headers(&self) -> HeaderMap {
242 let mut headers = HeaderMap::new();
243
244 headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
245
246 headers
247 }
248
249 fn url(&self, path: &str) -> String {
250 format!(
251 "{}/openai/deployments/{}{}",
252 self.api_base, self.deployment_id, path
253 )
254 }
255
256 fn api_base(&self) -> &str {
257 &self.api_base
258 }
259
260 fn api_key(&self) -> &SecretString {
261 &self.api_key
262 }
263
264 fn query(&self) -> Vec<(&str, &str)> {
265 vec![("api-version", &self.api_version)]
266 }
267}
268
#[cfg(all(test, feature = "chat-completion"))]
mod test {
    use super::*;
    use crate::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    /// A client can be built from boxed and reference-counted trait objects,
    /// and a cloned client keeps the same configuration.
    #[test]
    fn test_client_creation() {
        // Set the key on the config instead of calling `std::env::set_var`: the
        // test harness runs tests on parallel threads, and mutating the process
        // environment while another test reads it (e.g. via `OpenAIConfig::default()`)
        // is a data race — which is why `set_var` is `unsafe`.
        let openai_config = OpenAIConfig::default().with_api_key("test");
        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Compile-time check that a request can be built through `Box<dyn Config>`.
    /// The request future is dropped without being polled, so nothing is sent.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let azure_client =
            Client::with_config(Box::new(AzureConfig::default()) as Box<dyn Config>);
        let oai_client =
            Client::with_config(Box::new(OpenAIConfig::default()) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // Spawning proves the futures are `Send + 'static`. Awaiting the handles
        // makes the test wait for the tasks and surface any panic, instead of
        // detaching them and possibly exiting before they run.
        let azure_task =
            tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let oai_task = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
        azure_task
            .await
            .expect("Azure dynamic-dispatch task panicked");
        oai_task
            .await
            .expect("OpenAI dynamic-dispatch task panicked");
    }
}