async_openai/
config.rs

1//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
2use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
6use crate::error::OpenAIError;
7
/// Default base URL of the OpenAI v1 REST API.
pub const OPENAI_API_BASE: &'static str = "https://api.openai.com/v1";

/// Name of the header that carries the organization id.
pub const OPENAI_ORGANIZATION_HEADER: &'static str = "OpenAI-Organization";
/// Name of the header that carries the project id.
pub const OPENAI_PROJECT_HEADER: &'static str = "OpenAI-Project";

/// Name of the beta opt-in header; calls to the Assistants API must send it.
pub const OPENAI_BETA_HEADER: &'static str = "OpenAI-Beta";
17
18/// [crate::Client] relies on this for every API call on OpenAI
19/// or Azure OpenAI service
20pub trait Config: Send + Sync {
21    fn headers(&self) -> HeaderMap;
22    fn url(&self, path: &str) -> String;
23    fn query(&self) -> Vec<(&str, &str)>;
24
25    fn api_base(&self) -> &str;
26
27    fn api_key(&self) -> &SecretString;
28}
29
30/// Macro to implement Config trait for pointer types with dyn objects
31macro_rules! impl_config_for_ptr {
32    ($t:ty) => {
33        impl Config for $t {
34            fn headers(&self) -> HeaderMap {
35                self.as_ref().headers()
36            }
37            fn url(&self, path: &str) -> String {
38                self.as_ref().url(path)
39            }
40            fn query(&self) -> Vec<(&str, &str)> {
41                self.as_ref().query()
42            }
43            fn api_base(&self) -> &str {
44                self.as_ref().api_base()
45            }
46            fn api_key(&self) -> &SecretString {
47                self.as_ref().api_key()
48            }
49        }
50    };
51}
52
53impl_config_for_ptr!(Box<dyn Config>);
54impl_config_for_ptr!(std::sync::Arc<dyn Config>);
55
56// Private helper functions for default values
57fn default_api_base() -> String {
58    #[cfg(not(target_family = "wasm"))]
59    {
60        std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| OPENAI_API_BASE.to_string())
61    }
62
63    #[cfg(target_family = "wasm")]
64    {
65        OPENAI_API_BASE.to_string()
66    }
67}
68
69fn default_api_key() -> String {
70    #[cfg(not(target_family = "wasm"))]
71    {
72        std::env::var("OPENAI_API_KEY")
73            .or_else(|_| {
74                std::env::var("OPENAI_ADMIN_KEY").map(|admin_key| {
75                    tracing::warn!("Using OPENAI_ADMIN_KEY, OPENAI_API_KEY not set");
76                    admin_key
77                })
78            })
79            .unwrap_or_default()
80    }
81
82    #[cfg(target_family = "wasm")]
83    {
84        String::new()
85    }
86}
87
/// Organization id taken from `OPENAI_ORG_ID`; empty when the variable is unset,
/// not valid Unicode, or on wasm targets (where the environment is never read).
fn default_org_id() -> String {
    if cfg!(target_family = "wasm") {
        String::new()
    } else {
        std::env::var("OPENAI_ORG_ID").unwrap_or_default()
    }
}
99
/// Project id taken from `OPENAI_PROJECT_ID`; empty when the variable is unset,
/// not valid Unicode, or on wasm targets (where the environment is never read).
fn default_project_id() -> String {
    if cfg!(target_family = "wasm") {
        String::new()
    } else {
        std::env::var("OPENAI_PROJECT_ID").unwrap_or_default()
    }
}
111
112/// Configuration for OpenAI API
113#[derive(Clone, Debug, Deserialize)]
114#[serde(default)]
115pub struct OpenAIConfig {
116    api_base: String,
117    api_key: SecretString,
118    org_id: String,
119    project_id: String,
120    #[serde(skip)]
121    custom_headers: HeaderMap,
122}
123
124impl Default for OpenAIConfig {
125    fn default() -> Self {
126        Self {
127            api_base: default_api_base(),
128            api_key: default_api_key().into(),
129            org_id: default_org_id(),
130            project_id: default_project_id(),
131            custom_headers: HeaderMap::new(),
132        }
133    }
134}
135
136impl OpenAIConfig {
137    /// Create client with default [OPENAI_API_BASE] url (can also be changed with OPENAI_BASE_URL env var) and default API key from OPENAI_API_KEY env var
138    pub fn new() -> Self {
139        Default::default()
140    }
141
142    /// To use a different organization id other than default
143    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
144        self.org_id = org_id.into();
145        self
146    }
147
148    /// Non default project id
149    pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
150        self.project_id = project_id.into();
151        self
152    }
153
154    /// To use a different API key different from default OPENAI_API_KEY env var
155    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
156        self.api_key = SecretString::from(api_key.into());
157        self
158    }
159
160    /// To use a API base url different from default [OPENAI_API_BASE]
161    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
162        self.api_base = api_base.into();
163        self
164    }
165
166    /// Add a custom header that will be included in all requests.
167    /// Headers are merged with existing headers, with custom headers taking precedence.
168    pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
169    where
170        K: reqwest::header::IntoHeaderName,
171        V: TryInto<reqwest::header::HeaderValue>,
172        V::Error: Into<reqwest::header::InvalidHeaderValue>,
173    {
174        let header_value = value.try_into().map_err(|e| {
175            OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
176        })?;
177        self.custom_headers.insert(key, header_value);
178        Ok(self)
179    }
180
181    pub fn org_id(&self) -> &str {
182        &self.org_id
183    }
184}
185
186impl Config for OpenAIConfig {
187    fn headers(&self) -> HeaderMap {
188        let mut headers = HeaderMap::new();
189        if !self.org_id.is_empty() {
190            headers.insert(
191                OPENAI_ORGANIZATION_HEADER,
192                self.org_id.as_str().parse().unwrap(),
193            );
194        }
195
196        if !self.project_id.is_empty() {
197            headers.insert(
198                OPENAI_PROJECT_HEADER,
199                self.project_id.as_str().parse().unwrap(),
200            );
201        }
202
203        headers.insert(
204            AUTHORIZATION,
205            format!("Bearer {}", self.api_key.expose_secret())
206                .as_str()
207                .parse()
208                .unwrap(),
209        );
210
211        // Merge custom headers, with custom headers taking precedence
212        for (key, value) in self.custom_headers.iter() {
213            headers.insert(key, value.clone());
214        }
215
216        headers
217    }
218
219    fn url(&self, path: &str) -> String {
220        format!("{}{}", self.api_base, path)
221    }
222
223    fn api_base(&self) -> &str {
224        &self.api_base
225    }
226
227    fn api_key(&self) -> &SecretString {
228        &self.api_key
229    }
230
231    fn query(&self) -> Vec<(&str, &str)> {
232        vec![]
233    }
234}
235
236/// Configuration for Azure OpenAI Service
237#[derive(Clone, Debug, Deserialize)]
238#[serde(default)]
239pub struct AzureConfig {
240    api_version: String,
241    deployment_id: String,
242    api_base: String,
243    api_key: SecretString,
244}
245
246impl Default for AzureConfig {
247    fn default() -> Self {
248        Self {
249            api_base: Default::default(),
250            api_key: default_api_key().into(),
251            deployment_id: Default::default(),
252            api_version: Default::default(),
253        }
254    }
255}
256
257impl AzureConfig {
258    pub fn new() -> Self {
259        Default::default()
260    }
261
262    pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
263        self.api_version = api_version.into();
264        self
265    }
266
267    pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
268        self.deployment_id = deployment_id.into();
269        self
270    }
271
272    /// To use a different API key different from default OPENAI_API_KEY env var
273    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
274        self.api_key = SecretString::from(api_key.into());
275        self
276    }
277
278    /// API base url in form of <https://your-resource-name.openai.azure.com>
279    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
280        self.api_base = api_base.into();
281        self
282    }
283}
284
285impl Config for AzureConfig {
286    fn headers(&self) -> HeaderMap {
287        let mut headers = HeaderMap::new();
288
289        headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
290
291        headers
292    }
293
294    fn url(&self, path: &str) -> String {
295        format!(
296            "{}/openai/deployments/{}{}",
297            self.api_base, self.deployment_id, path
298        )
299    }
300
301    fn api_base(&self) -> &str {
302        &self.api_base
303    }
304
305    fn api_key(&self) -> &SecretString {
306        &self.api_key
307    }
308
309    fn query(&self) -> Vec<(&str, &str)> {
310        vec![("api-version", &self.api_version)]
311    }
312}
313
#[cfg(all(test, feature = "chat-completion"))]
mod test {
    use super::*;
    use crate::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    /// `Box<dyn Config>` and `Arc<dyn Config>` forward `url` to the wrapped config,
    /// and a cloned client keeps the same config.
    #[test]
    fn test_client_creation() {
        // SAFETY: `set_var` is unsafe because it can race with other threads that
        // read the environment. NOTE(review): other tests here (e.g.
        // `test_dynamic_dispatch`) read it concurrently via `OpenAIConfig::default()`;
        // this relies on those reads going through `std::env`, which serialises with
        // `set_var`, and on nothing reading it through libc directly — confirm.
        unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
        // NOTE: the `/v1` assertions assume OPENAI_BASE_URL is not set in the
        // environment running the tests.
        let openai_config = OpenAIConfig::default();
        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Type-checks a chat request through a `Box<dyn Config>` client. The request
    /// future is bound to `_` and never awaited (presumably `create` is async, so no
    /// network call is made).
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let openai_config = OpenAIConfig::default();
        let azure_config = AzureConfig::default();

        let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // Spawning proves the futures are `Send + 'static`; awaiting the handles
        // makes sure the tasks actually run and that a panic inside them fails the
        // test instead of being silently detached.
        tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await })
            .await
            .expect("azure dispatch task panicked");
        tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await })
            .await
            .expect("openai dispatch task panicked");
    }
}