1use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
6use crate::error::OpenAIError;
7
/// Base URL that [`OpenAIConfig`] uses when no other base is configured.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";

/// Name of the header that carries the organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the header that carries the project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";
/// Name of the `OpenAI-Beta` header (not set by any config in this file).
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
17
18pub trait Config: Send + Sync {
21 fn headers(&self) -> HeaderMap;
22 fn url(&self, path: &str) -> String;
23 fn query(&self) -> Vec<(&str, &str)>;
24
25 fn api_base(&self) -> &str;
26
27 fn api_key(&self) -> &SecretString;
28}
29
30macro_rules! impl_config_for_ptr {
32 ($t:ty) => {
33 impl Config for $t {
34 fn headers(&self) -> HeaderMap {
35 self.as_ref().headers()
36 }
37 fn url(&self, path: &str) -> String {
38 self.as_ref().url(path)
39 }
40 fn query(&self) -> Vec<(&str, &str)> {
41 self.as_ref().query()
42 }
43 fn api_base(&self) -> &str {
44 self.as_ref().api_base()
45 }
46 fn api_key(&self) -> &SecretString {
47 self.as_ref().api_key()
48 }
49 }
50 };
51}
52
53impl_config_for_ptr!(Box<dyn Config>);
54impl_config_for_ptr!(std::sync::Arc<dyn Config>);
55
56#[derive(Clone, Debug, Deserialize)]
58#[serde(default)]
59pub struct OpenAIConfig {
60 api_base: String,
61 api_key: SecretString,
62 org_id: String,
63 project_id: String,
64 #[serde(skip)]
65 custom_headers: HeaderMap,
66}
67
68impl Default for OpenAIConfig {
69 fn default() -> Self {
70 Self {
71 api_base: OPENAI_API_BASE.to_string(),
72 api_key: std::env::var("OPENAI_API_KEY")
73 .unwrap_or_else(|_| "".to_string())
74 .into(),
75 org_id: Default::default(),
76 project_id: Default::default(),
77 custom_headers: HeaderMap::new(),
78 }
79 }
80}
81
82impl OpenAIConfig {
83 pub fn new() -> Self {
85 Default::default()
86 }
87
88 pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
90 self.org_id = org_id.into();
91 self
92 }
93
94 pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
96 self.project_id = project_id.into();
97 self
98 }
99
100 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
102 self.api_key = SecretString::from(api_key.into());
103 self
104 }
105
106 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
108 self.api_base = api_base.into();
109 self
110 }
111
112 pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
115 where
116 K: reqwest::header::IntoHeaderName,
117 V: TryInto<reqwest::header::HeaderValue>,
118 V::Error: Into<reqwest::header::InvalidHeaderValue>,
119 {
120 let header_value = value.try_into().map_err(|e| {
121 OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
122 })?;
123 self.custom_headers.insert(key, header_value);
124 Ok(self)
125 }
126
127 pub fn org_id(&self) -> &str {
128 &self.org_id
129 }
130}
131
132impl Config for OpenAIConfig {
133 fn headers(&self) -> HeaderMap {
134 let mut headers = HeaderMap::new();
135 if !self.org_id.is_empty() {
136 headers.insert(
137 OPENAI_ORGANIZATION_HEADER,
138 self.org_id.as_str().parse().unwrap(),
139 );
140 }
141
142 if !self.project_id.is_empty() {
143 headers.insert(
144 OPENAI_PROJECT_HEADER,
145 self.project_id.as_str().parse().unwrap(),
146 );
147 }
148
149 headers.insert(
150 AUTHORIZATION,
151 format!("Bearer {}", self.api_key.expose_secret())
152 .as_str()
153 .parse()
154 .unwrap(),
155 );
156
157 for (key, value) in self.custom_headers.iter() {
159 headers.insert(key, value.clone());
160 }
161
162 headers
163 }
164
165 fn url(&self, path: &str) -> String {
166 format!("{}{}", self.api_base, path)
167 }
168
169 fn api_base(&self) -> &str {
170 &self.api_base
171 }
172
173 fn api_key(&self) -> &SecretString {
174 &self.api_key
175 }
176
177 fn query(&self) -> Vec<(&str, &str)> {
178 vec![]
179 }
180}
181
182#[derive(Clone, Debug, Deserialize)]
184#[serde(default)]
185pub struct AzureConfig {
186 api_version: String,
187 deployment_id: String,
188 api_base: String,
189 api_key: SecretString,
190}
191
192impl Default for AzureConfig {
193 fn default() -> Self {
194 Self {
195 api_base: Default::default(),
196 api_key: std::env::var("OPENAI_API_KEY")
197 .unwrap_or_else(|_| "".to_string())
198 .into(),
199 deployment_id: Default::default(),
200 api_version: Default::default(),
201 }
202 }
203}
204
205impl AzureConfig {
206 pub fn new() -> Self {
207 Default::default()
208 }
209
210 pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
211 self.api_version = api_version.into();
212 self
213 }
214
215 pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
216 self.deployment_id = deployment_id.into();
217 self
218 }
219
220 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
222 self.api_key = SecretString::from(api_key.into());
223 self
224 }
225
226 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
228 self.api_base = api_base.into();
229 self
230 }
231}
232
233impl Config for AzureConfig {
234 fn headers(&self) -> HeaderMap {
235 let mut headers = HeaderMap::new();
236
237 headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
238
239 headers
240 }
241
242 fn url(&self, path: &str) -> String {
243 format!(
244 "{}/openai/deployments/{}{}",
245 self.api_base, self.deployment_id, path
246 )
247 }
248
249 fn api_base(&self) -> &str {
250 &self.api_base
251 }
252
253 fn api_key(&self) -> &SecretString {
254 &self.api_key
255 }
256
257 fn query(&self) -> Vec<(&str, &str)> {
258 vec![("api-version", &self.api_version)]
259 }
260}
261
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    /// A client can be built from both `Box<dyn Config>` and
    /// `Arc<dyn Config>`, and stays usable after being cloned.
    #[test]
    fn test_client_creation() {
        // NOTE(review): `set_var` is only sound while no other thread reads or
        // writes the environment; confirm this holds when tests run in parallel.
        unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
        let base_config = OpenAIConfig::default();

        let boxed: Box<dyn Config> = Box::new(base_config.clone());
        let boxed_client = Client::with_config(boxed);
        assert!(boxed_client.config().url("").ends_with("/v1"));

        let shared: Arc<dyn Config> = Arc::new(base_config);
        let shared_client = Client::with_config(shared);
        assert!(shared_client.config().url("").ends_with("/v1"));

        let shared_client_copy = shared_client.clone();
        assert!(shared_client_copy.config().url("").ends_with("/v1"));
    }

    /// Compile-time check only: the request future is created and dropped
    /// without being awaited, so nothing is sent.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let user_message = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
            content: "Hello, world!".into(),
            ..Default::default()
        });
        let request = CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![user_message],
            ..Default::default()
        };

        let _ = client.chat().create(request);
    }

    /// Clients over a boxed `dyn Config` work for both config types and can
    /// be moved into spawned tasks.
    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let azure_boxed: Box<dyn Config> = Box::new(AzureConfig::default());
        let openai_boxed: Box<dyn Config> = Box::new(OpenAIConfig::default());

        let azure_client = Client::with_config(azure_boxed);
        let oai_client = Client::with_config(openai_boxed);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // The join handles are discarded on purpose; spawning is the check.
        let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
    }
}