1use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
6use crate::error::OpenAIError;
7
/// Default base URL of the OpenAI REST API.
///
/// On non-wasm targets `default_api_base` prefers the `OPENAI_BASE_URL` environment variable
/// and falls back to this value.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Header name carrying the organization id (emitted by `OpenAIConfig` when the id is non-empty).
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Header name carrying the project id (emitted by `OpenAIConfig` when the id is non-empty).
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Header name `OpenAI-Beta`.
///
/// NOTE(review): not referenced in this file; presumably set by endpoints that opt in to beta
/// API features — confirm at the call sites.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
17
18pub trait Config: Send + Sync {
21 fn headers(&self) -> HeaderMap;
22 fn url(&self, path: &str) -> String;
23 fn query(&self) -> Vec<(&str, &str)>;
24
25 fn api_base(&self) -> &str;
26
27 fn api_key(&self) -> &SecretString;
28}
29
30macro_rules! impl_config_for_ptr {
32 ($t:ty) => {
33 impl Config for $t {
34 fn headers(&self) -> HeaderMap {
35 self.as_ref().headers()
36 }
37 fn url(&self, path: &str) -> String {
38 self.as_ref().url(path)
39 }
40 fn query(&self) -> Vec<(&str, &str)> {
41 self.as_ref().query()
42 }
43 fn api_base(&self) -> &str {
44 self.as_ref().api_base()
45 }
46 fn api_key(&self) -> &SecretString {
47 self.as_ref().api_key()
48 }
49 }
50 };
51}
52
53impl_config_for_ptr!(Box<dyn Config>);
54impl_config_for_ptr!(std::sync::Arc<dyn Config>);
55
56fn default_api_base() -> String {
58 #[cfg(not(target_family = "wasm"))]
59 {
60 std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| OPENAI_API_BASE.to_string())
61 }
62
63 #[cfg(target_family = "wasm")]
64 {
65 OPENAI_API_BASE.to_string()
66 }
67}
68
69fn default_api_key() -> String {
70 #[cfg(not(target_family = "wasm"))]
71 {
72 std::env::var("OPENAI_API_KEY")
73 .or_else(|_| {
74 std::env::var("OPENAI_ADMIN_KEY").map(|admin_key| {
75 tracing::warn!("Using OPENAI_ADMIN_KEY, OPENAI_API_KEY not set");
76 admin_key
77 })
78 })
79 .unwrap_or_default()
80 }
81
82 #[cfg(target_family = "wasm")]
83 {
84 String::new()
85 }
86}
87
/// Organization id used by [`OpenAIConfig::default`].
///
/// Reads `OPENAI_ORG_ID` on non-wasm targets. An unset or non-Unicode variable, and every wasm
/// build, yields an empty string; `OpenAIConfig` then sends no organization header.
fn default_org_id() -> String {
    #[cfg(not(target_family = "wasm"))]
    {
        match std::env::var("OPENAI_ORG_ID") {
            Ok(org_id) => org_id,
            Err(_) => String::new(),
        }
    }

    #[cfg(target_family = "wasm")]
    {
        String::new()
    }
}
99
/// Project id used by [`OpenAIConfig::default`].
///
/// Reads `OPENAI_PROJECT_ID` on non-wasm targets. Anything unreadable, and every wasm build,
/// yields an empty string; `OpenAIConfig` then sends no project header.
fn default_project_id() -> String {
    #[cfg(not(target_family = "wasm"))]
    {
        std::env::var("OPENAI_PROJECT_ID").unwrap_or_else(|_| String::new())
    }

    #[cfg(target_family = "wasm")]
    {
        String::new()
    }
}
111
/// Configuration for the OpenAI API.
///
/// Because of `#[serde(default)]`, fields missing from deserialized input fall back to
/// `OpenAIConfig::default()`, which reads environment variables on non-wasm targets.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
    /// Base URL that request paths are appended to by `Config::url`.
    api_base: String,
    /// Sent as `Authorization: Bearer <key>`.
    api_key: SecretString,
    /// Sent as the `OpenAI-Organization` header when non-empty.
    org_id: String,
    /// Sent as the `OpenAI-Project` header when non-empty.
    project_id: String,
    /// Extra headers added with `with_header`; applied after the built-in ones, never deserialized.
    #[serde(skip)]
    custom_headers: HeaderMap,
}
123
124impl Default for OpenAIConfig {
125 fn default() -> Self {
126 Self {
127 api_base: default_api_base(),
128 api_key: default_api_key().into(),
129 org_id: default_org_id(),
130 project_id: default_project_id(),
131 custom_headers: HeaderMap::new(),
132 }
133 }
134}
135
136impl OpenAIConfig {
137 pub fn new() -> Self {
139 Default::default()
140 }
141
142 pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
144 self.org_id = org_id.into();
145 self
146 }
147
148 pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
150 self.project_id = project_id.into();
151 self
152 }
153
154 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
156 self.api_key = SecretString::from(api_key.into());
157 self
158 }
159
160 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
162 self.api_base = api_base.into();
163 self
164 }
165
166 pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
169 where
170 K: reqwest::header::IntoHeaderName,
171 V: TryInto<reqwest::header::HeaderValue>,
172 V::Error: Into<reqwest::header::InvalidHeaderValue>,
173 {
174 let header_value = value.try_into().map_err(|e| {
175 OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
176 })?;
177 self.custom_headers.insert(key, header_value);
178 Ok(self)
179 }
180
181 pub fn org_id(&self) -> &str {
182 &self.org_id
183 }
184}
185
186impl Config for OpenAIConfig {
187 fn headers(&self) -> HeaderMap {
188 let mut headers = HeaderMap::new();
189 if !self.org_id.is_empty() {
190 headers.insert(
191 OPENAI_ORGANIZATION_HEADER,
192 self.org_id.as_str().parse().unwrap(),
193 );
194 }
195
196 if !self.project_id.is_empty() {
197 headers.insert(
198 OPENAI_PROJECT_HEADER,
199 self.project_id.as_str().parse().unwrap(),
200 );
201 }
202
203 headers.insert(
204 AUTHORIZATION,
205 format!("Bearer {}", self.api_key.expose_secret())
206 .as_str()
207 .parse()
208 .unwrap(),
209 );
210
211 for (key, value) in self.custom_headers.iter() {
213 headers.insert(key, value.clone());
214 }
215
216 headers
217 }
218
219 fn url(&self, path: &str) -> String {
220 format!("{}{}", self.api_base, path)
221 }
222
223 fn api_base(&self) -> &str {
224 &self.api_base
225 }
226
227 fn api_key(&self) -> &SecretString {
228 &self.api_key
229 }
230
231 fn query(&self) -> Vec<(&str, &str)> {
232 vec![]
233 }
234}
235
/// Configuration for Azure OpenAI deployments.
///
/// Request URLs have the form `{api_base}/openai/deployments/{deployment_id}{path}` (see
/// `Config::url`), carry an `api-version` query parameter, and authenticate with an `api-key`
/// header. Because of `#[serde(default)]`, fields missing from deserialized input fall back to
/// `AzureConfig::default()`.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
    /// Value of the `api-version` query parameter.
    api_version: String,
    /// Deployment name inserted into the request path.
    deployment_id: String,
    /// Resource endpoint that the deployment path is appended to.
    api_base: String,
    /// Sent verbatim in the `api-key` header.
    api_key: SecretString,
}
245
246impl Default for AzureConfig {
247 fn default() -> Self {
248 Self {
249 api_base: Default::default(),
250 api_key: default_api_key().into(),
251 deployment_id: Default::default(),
252 api_version: Default::default(),
253 }
254 }
255}
256
257impl AzureConfig {
258 pub fn new() -> Self {
259 Default::default()
260 }
261
262 pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
263 self.api_version = api_version.into();
264 self
265 }
266
267 pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
268 self.deployment_id = deployment_id.into();
269 self
270 }
271
272 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
274 self.api_key = SecretString::from(api_key.into());
275 self
276 }
277
278 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
280 self.api_base = api_base.into();
281 self
282 }
283}
284
285impl Config for AzureConfig {
286 fn headers(&self) -> HeaderMap {
287 let mut headers = HeaderMap::new();
288
289 headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
290
291 headers
292 }
293
294 fn url(&self, path: &str) -> String {
295 format!(
296 "{}/openai/deployments/{}{}",
297 self.api_base, self.deployment_id, path
298 )
299 }
300
301 fn api_base(&self) -> &str {
302 &self.api_base
303 }
304
305 fn api_key(&self) -> &SecretString {
306 &self.api_key
307 }
308
309 fn query(&self) -> Vec<(&str, &str)> {
310 vec![("api-version", &self.api_version)]
311 }
312}
313
#[cfg(all(test, feature = "chat-completion"))]
mod test {
    use super::*;
    use crate::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    /// `Box<dyn Config>` and `Arc<dyn Config>` must forward to the wrapped config, including
    /// through a cloned client.
    #[test]
    fn test_client_creation() {
        // Configure the key and base URL explicitly instead of calling `std::env::set_var`.
        // Mutating the environment while other tests (e.g. `test_dynamic_dispatch`) read it on
        // other harness threads is a data race. Pinning the base URL also keeps the assertions
        // independent of a developer's own `OPENAI_BASE_URL`.
        let openai_config = OpenAIConfig::default()
            .with_api_key("test")
            .with_api_base(OPENAI_API_BASE);

        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert_eq!(client.config().url(""), OPENAI_API_BASE);

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert_eq!(client.config().url(""), OPENAI_API_BASE);
        let cloned_client = client.clone();
        assert_eq!(cloned_client.config().url(""), OPENAI_API_BASE);
    }

    /// Type-checks a chat request through a `Box<dyn Config>` client. The value returned by
    /// `create` is discarded without being awaited, so no request is sent.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let openai_config = OpenAIConfig::default();
        let azure_config = AzureConfig::default();

        let azure_client = Client::with_config(Box::new(azure_config) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(openai_config) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // Spawning proves the futures are `Send + 'static`; joining the handles surfaces any
        // panic instead of silently detaching the tasks.
        let azure_task =
            tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let oai_task = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
        azure_task.await.expect("azure dispatch task panicked");
        oai_task.await.expect("openai dispatch task panicked");
    }
}