async_openai/
config.rs

1//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
2use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
6use crate::error::OpenAIError;
7
/// Default v1 API base url, used by [OpenAIConfig] unless replaced
/// through [OpenAIConfig::with_api_base]
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the HTTP header that carries the organization id
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the HTTP header that carries the project id
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Calls to the Assistants API require that you pass a Beta header;
/// this is the name of that header
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
17
/// [crate::Client] relies on this for every API call on OpenAI
/// or Azure OpenAI service
///
/// The `Send + Sync` supertraits allow a configuration to be shared between
/// threads. This module also implements the trait for `Box<dyn Config>` and
/// `std::sync::Arc<dyn Config>`, so a type-erased configuration can be used
/// wherever a [Config] is expected.
pub trait Config: Send + Sync {
    /// Headers for a request, such as the authentication header.
    fn headers(&self) -> HeaderMap;
    /// Full request URL for the API `path`.
    fn url(&self, path: &str) -> String;
    /// Query parameters for a request, as `(name, value)` pairs.
    fn query(&self) -> Vec<(&str, &str)>;

    /// Base URL this configuration points at.
    fn api_base(&self) -> &str;

    /// API key, kept wrapped in a [SecretString].
    fn api_key(&self) -> &SecretString;
}
29
30/// Macro to implement Config trait for pointer types with dyn objects
31macro_rules! impl_config_for_ptr {
32    ($t:ty) => {
33        impl Config for $t {
34            fn headers(&self) -> HeaderMap {
35                self.as_ref().headers()
36            }
37            fn url(&self, path: &str) -> String {
38                self.as_ref().url(path)
39            }
40            fn query(&self) -> Vec<(&str, &str)> {
41                self.as_ref().query()
42            }
43            fn api_base(&self) -> &str {
44                self.as_ref().api_base()
45            }
46            fn api_key(&self) -> &SecretString {
47                self.as_ref().api_key()
48            }
49        }
50    };
51}
52
53impl_config_for_ptr!(Box<dyn Config>);
54impl_config_for_ptr!(std::sync::Arc<dyn Config>);
55
/// Configuration for OpenAI API
///
/// Because of `#[serde(default)]`, fields missing from deserialized input are
/// filled in from this type's [Default] implementation.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
    /// Base URL that request paths are appended to
    api_base: String,
    /// API key, sent as the `Bearer` credential in the `Authorization` header
    api_key: SecretString,
    /// Organization id; the organization header is only sent when this is non-empty
    org_id: String,
    /// Project id; the project header is only sent when this is non-empty
    project_id: String,
    /// Extra headers registered through `with_header`; skipped by serde
    #[serde(skip)]
    custom_headers: HeaderMap,
}
67
68impl Default for OpenAIConfig {
69    fn default() -> Self {
70        Self {
71            api_base: OPENAI_API_BASE.to_string(),
72            api_key: std::env::var("OPENAI_API_KEY")
73                .or_else(|_| {
74                    std::env::var("OPENAI_ADMIN_KEY").map(|admin_key| {
75                        tracing::warn!("Using OPENAI_ADMIN_KEY, OPENAI_API_KEY not set");
76                        admin_key
77                    })
78                })
79                .unwrap_or_default()
80                .into(),
81            org_id: Default::default(),
82            project_id: Default::default(),
83            custom_headers: HeaderMap::new(),
84        }
85    }
86}
87
88impl OpenAIConfig {
89    /// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
90    pub fn new() -> Self {
91        Default::default()
92    }
93
94    /// To use a different organization id other than default
95    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
96        self.org_id = org_id.into();
97        self
98    }
99
100    /// Non default project id
101    pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
102        self.project_id = project_id.into();
103        self
104    }
105
106    /// To use a different API key different from default OPENAI_API_KEY env var
107    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
108        self.api_key = SecretString::from(api_key.into());
109        self
110    }
111
112    /// To use a API base url different from default [OPENAI_API_BASE]
113    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
114        self.api_base = api_base.into();
115        self
116    }
117
118    /// Add a custom header that will be included in all requests.
119    /// Headers are merged with existing headers, with custom headers taking precedence.
120    pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
121    where
122        K: reqwest::header::IntoHeaderName,
123        V: TryInto<reqwest::header::HeaderValue>,
124        V::Error: Into<reqwest::header::InvalidHeaderValue>,
125    {
126        let header_value = value.try_into().map_err(|e| {
127            OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
128        })?;
129        self.custom_headers.insert(key, header_value);
130        Ok(self)
131    }
132
133    pub fn org_id(&self) -> &str {
134        &self.org_id
135    }
136}
137
138impl Config for OpenAIConfig {
139    fn headers(&self) -> HeaderMap {
140        let mut headers = HeaderMap::new();
141        if !self.org_id.is_empty() {
142            headers.insert(
143                OPENAI_ORGANIZATION_HEADER,
144                self.org_id.as_str().parse().unwrap(),
145            );
146        }
147
148        if !self.project_id.is_empty() {
149            headers.insert(
150                OPENAI_PROJECT_HEADER,
151                self.project_id.as_str().parse().unwrap(),
152            );
153        }
154
155        headers.insert(
156            AUTHORIZATION,
157            format!("Bearer {}", self.api_key.expose_secret())
158                .as_str()
159                .parse()
160                .unwrap(),
161        );
162
163        // Merge custom headers, with custom headers taking precedence
164        for (key, value) in self.custom_headers.iter() {
165            headers.insert(key, value.clone());
166        }
167
168        headers
169    }
170
171    fn url(&self, path: &str) -> String {
172        format!("{}{}", self.api_base, path)
173    }
174
175    fn api_base(&self) -> &str {
176        &self.api_base
177    }
178
179    fn api_key(&self) -> &SecretString {
180        &self.api_key
181    }
182
183    fn query(&self) -> Vec<(&str, &str)> {
184        vec![]
185    }
186}
187
/// Configuration for Azure OpenAI Service
///
/// Because of `#[serde(default)]`, fields missing from deserialized input are
/// filled in from this type's [Default] implementation.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
    /// Value sent as the `api-version` query parameter
    api_version: String,
    /// Deployment id placed in the request URL after `/openai/deployments/`
    deployment_id: String,
    /// Base URL that the deployment path is appended to
    api_base: String,
    /// API key, sent in the `api-key` header
    api_key: SecretString,
}
197
198impl Default for AzureConfig {
199    fn default() -> Self {
200        Self {
201            api_base: Default::default(),
202            api_key: std::env::var("OPENAI_API_KEY")
203                .unwrap_or_else(|_| "".to_string())
204                .into(),
205            deployment_id: Default::default(),
206            api_version: Default::default(),
207        }
208    }
209}
210
211impl AzureConfig {
212    pub fn new() -> Self {
213        Default::default()
214    }
215
216    pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
217        self.api_version = api_version.into();
218        self
219    }
220
221    pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
222        self.deployment_id = deployment_id.into();
223        self
224    }
225
226    /// To use a different API key different from default OPENAI_API_KEY env var
227    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
228        self.api_key = SecretString::from(api_key.into());
229        self
230    }
231
232    /// API base url in form of <https://your-resource-name.openai.azure.com>
233    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
234        self.api_base = api_base.into();
235        self
236    }
237}
238
239impl Config for AzureConfig {
240    fn headers(&self) -> HeaderMap {
241        let mut headers = HeaderMap::new();
242
243        headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
244
245        headers
246    }
247
248    fn url(&self, path: &str) -> String {
249        format!(
250            "{}/openai/deployments/{}{}",
251            self.api_base, self.deployment_id, path
252        )
253    }
254
255    fn api_base(&self) -> &str {
256        &self.api_base
257    }
258
259    fn api_key(&self) -> &SecretString {
260        &self.api_key
261    }
262
263    fn query(&self) -> Vec<(&str, &str)> {
264        vec![("api-version", &self.api_version)]
265    }
266}
267
// Tests that a client can be built from a type-erased (`dyn Config`) configuration.
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;
    /// A client can be created from both `Box<dyn Config>` and `Arc<dyn Config>`,
    /// and the default OpenAI base url ends with `/v1`.
    #[test]
    fn test_client_creation() {
        // NOTE(review): `set_var` is unsafe because the process environment is shared,
        // global state, and the test harness may run other tests on parallel threads —
        // confirm that nothing accesses the environment in a way that makes this unsound.
        unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
        let openai_config = OpenAIConfig::default();
        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        // The client built on an `Arc<dyn Config>` must stay usable after cloning.
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Builds a chat completion call on a client with a boxed `dyn Config`.
    ///
    /// The value returned by `create` is discarded without `.await`; presumably it
    /// is a lazy future, so no request is sent and this only needs to type-check.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    /// The same function accepts clients backed by either configuration type, and
    /// such clients can be moved into spawned tasks.
    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let openai_config = OpenAIConfig::default();
        let azure_config = AzureConfig::default();

        let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);

        let _ = dynamic_dispatch_compiles(&azure_client).await;
        let _ = dynamic_dispatch_compiles(&oai_client).await;

        // `tokio::spawn` requires a `Send + 'static` future, so these two lines check
        // that the clients can be moved across tasks. The join handles are dropped,
        // so the test does not wait for the spawned tasks.
        let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
    }
}