1use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
6use crate::error::OpenAIError;
7
/// Default base URL for the OpenAI REST API; request paths are appended to it.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the header that carries the organization id when one is configured.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the header that carries the project id when one is configured.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Name of the `OpenAI-Beta` header. Not referenced in this file — presumably
/// attached by callers for beta endpoints (NOTE(review): verify usage).
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
17
18pub trait Config: Send + Sync {
21 fn headers(&self) -> HeaderMap;
22 fn url(&self, path: &str) -> String;
23 fn query(&self) -> Vec<(&str, &str)>;
24
25 fn api_base(&self) -> &str;
26
27 fn api_key(&self) -> &SecretString;
28}
29
30macro_rules! impl_config_for_ptr {
32 ($t:ty) => {
33 impl Config for $t {
34 fn headers(&self) -> HeaderMap {
35 self.as_ref().headers()
36 }
37 fn url(&self, path: &str) -> String {
38 self.as_ref().url(path)
39 }
40 fn query(&self) -> Vec<(&str, &str)> {
41 self.as_ref().query()
42 }
43 fn api_base(&self) -> &str {
44 self.as_ref().api_base()
45 }
46 fn api_key(&self) -> &SecretString {
47 self.as_ref().api_key()
48 }
49 }
50 };
51}
52
53impl_config_for_ptr!(Box<dyn Config>);
54impl_config_for_ptr!(std::sync::Arc<dyn Config>);
55
/// Configuration for the OpenAI API.
///
/// Deserializable with serde. Because of `#[serde(default)]`, any field that is
/// missing from the input takes its value from [`OpenAIConfig::default`], which
/// reads the API key from the environment.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
    /// Base URL that request paths are appended to.
    api_base: String,
    /// Secret key, sent as a `Bearer` token in the `Authorization` header.
    api_key: SecretString,
    /// Organization id; sent as `OpenAI-Organization` when non-empty.
    org_id: String,
    /// Project id; sent as `OpenAI-Project` when non-empty.
    project_id: String,
    /// Extra headers added through `with_header`. Skipped by serde because
    /// `HeaderMap` is not deserializable.
    #[serde(skip)]
    custom_headers: HeaderMap,
}
67
68impl Default for OpenAIConfig {
69 fn default() -> Self {
70 Self {
71 api_base: OPENAI_API_BASE.to_string(),
72 api_key: std::env::var("OPENAI_API_KEY")
73 .or_else(|_| {
74 std::env::var("OPENAI_ADMIN_KEY").map(|admin_key| {
75 tracing::warn!("Using OPENAI_ADMIN_KEY, OPENAI_API_KEY not set");
76 admin_key
77 })
78 })
79 .unwrap_or_default()
80 .into(),
81 org_id: Default::default(),
82 project_id: Default::default(),
83 custom_headers: HeaderMap::new(),
84 }
85 }
86}
87
88impl OpenAIConfig {
89 pub fn new() -> Self {
91 Default::default()
92 }
93
94 pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
96 self.org_id = org_id.into();
97 self
98 }
99
100 pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
102 self.project_id = project_id.into();
103 self
104 }
105
106 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
108 self.api_key = SecretString::from(api_key.into());
109 self
110 }
111
112 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
114 self.api_base = api_base.into();
115 self
116 }
117
118 pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
121 where
122 K: reqwest::header::IntoHeaderName,
123 V: TryInto<reqwest::header::HeaderValue>,
124 V::Error: Into<reqwest::header::InvalidHeaderValue>,
125 {
126 let header_value = value.try_into().map_err(|e| {
127 OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
128 })?;
129 self.custom_headers.insert(key, header_value);
130 Ok(self)
131 }
132
133 pub fn org_id(&self) -> &str {
134 &self.org_id
135 }
136}
137
138impl Config for OpenAIConfig {
139 fn headers(&self) -> HeaderMap {
140 let mut headers = HeaderMap::new();
141 if !self.org_id.is_empty() {
142 headers.insert(
143 OPENAI_ORGANIZATION_HEADER,
144 self.org_id.as_str().parse().unwrap(),
145 );
146 }
147
148 if !self.project_id.is_empty() {
149 headers.insert(
150 OPENAI_PROJECT_HEADER,
151 self.project_id.as_str().parse().unwrap(),
152 );
153 }
154
155 headers.insert(
156 AUTHORIZATION,
157 format!("Bearer {}", self.api_key.expose_secret())
158 .as_str()
159 .parse()
160 .unwrap(),
161 );
162
163 for (key, value) in self.custom_headers.iter() {
165 headers.insert(key, value.clone());
166 }
167
168 headers
169 }
170
171 fn url(&self, path: &str) -> String {
172 format!("{}{}", self.api_base, path)
173 }
174
175 fn api_base(&self) -> &str {
176 &self.api_base
177 }
178
179 fn api_key(&self) -> &SecretString {
180 &self.api_key
181 }
182
183 fn query(&self) -> Vec<(&str, &str)> {
184 vec![]
185 }
186}
187
/// Configuration for Azure OpenAI deployments.
///
/// Deserializable with serde. Because of `#[serde(default)]`, fields missing
/// from the input fall back to [`AzureConfig::default`].
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
    /// Sent as the `api-version` query parameter.
    api_version: String,
    /// Deployment name inserted into the URL as `/openai/deployments/{deployment_id}`.
    deployment_id: String,
    /// Base URL of the Azure resource; `/openai/deployments/...` is appended to it.
    api_base: String,
    /// Key sent in the `api-key` header.
    api_key: SecretString,
}
197
198impl Default for AzureConfig {
199 fn default() -> Self {
200 Self {
201 api_base: Default::default(),
202 api_key: std::env::var("OPENAI_API_KEY")
203 .unwrap_or_else(|_| "".to_string())
204 .into(),
205 deployment_id: Default::default(),
206 api_version: Default::default(),
207 }
208 }
209}
210
211impl AzureConfig {
212 pub fn new() -> Self {
213 Default::default()
214 }
215
216 pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
217 self.api_version = api_version.into();
218 self
219 }
220
221 pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
222 self.deployment_id = deployment_id.into();
223 self
224 }
225
226 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
228 self.api_key = SecretString::from(api_key.into());
229 self
230 }
231
232 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
234 self.api_base = api_base.into();
235 self
236 }
237}
238
239impl Config for AzureConfig {
240 fn headers(&self) -> HeaderMap {
241 let mut headers = HeaderMap::new();
242
243 headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
244
245 headers
246 }
247
248 fn url(&self, path: &str) -> String {
249 format!(
250 "{}/openai/deployments/{}{}",
251 self.api_base, self.deployment_id, path
252 )
253 }
254
255 fn api_base(&self) -> &str {
256 &self.api_base
257 }
258
259 fn api_key(&self) -> &SecretString {
260 &self.api_key
261 }
262
263 fn query(&self) -> Vec<(&str, &str)> {
264 vec![("api-version", &self.api_version)]
265 }
266}
267
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    #[test]
    fn test_client_creation() {
        // Set the key on the config instead of via `std::env::set_var`. Mutating
        // the process environment while other test threads read it (e.g.
        // `OpenAIConfig::default()` in `test_dynamic_dispatch`) is unsound.
        let openai_config = OpenAIConfig::default().with_api_key("test");
        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Compile-time check that a `Box<dyn Config>` client exposes the chat API.
    /// The request future is dropped without being polled, so nothing is sent.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let openai_config = OpenAIConfig::default();
        let azure_config = AzureConfig::default();

        let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // Spawning proves the futures are `Send + 'static`. Awaiting the handles
        // surfaces any panic instead of silently detaching the tasks.
        let azure_task =
            tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let oai_task = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
        azure_task.await.expect("azure dispatch task panicked");
        oai_task.await.expect("openai dispatch task panicked");
    }
}