// async_openai/config.rs
//! Client configurations: [`OpenAIConfig`] for OpenAI and [`AzureConfig`] for Azure OpenAI Service.
2use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
6use crate::error::OpenAIError;
7
/// Base URL of the OpenAI v1 REST API, used when no other base url is configured.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";

/// Name of the header that carries the OpenAI organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";

/// Name of the header that carries the OpenAI project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Name of the beta opt-in header; calls to the Assistants API must include it.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
17
18/// [crate::Client] relies on this for every API call on OpenAI
19/// or Azure OpenAI service
20pub trait Config: Send + Sync {
21    fn headers(&self) -> HeaderMap;
22    fn url(&self, path: &str) -> String;
23    fn query(&self) -> Vec<(&str, &str)>;
24
25    fn api_base(&self) -> &str;
26
27    fn api_key(&self) -> &SecretString;
28}
29
30/// Macro to implement Config trait for pointer types with dyn objects
31macro_rules! impl_config_for_ptr {
32    ($t:ty) => {
33        impl Config for $t {
34            fn headers(&self) -> HeaderMap {
35                self.as_ref().headers()
36            }
37            fn url(&self, path: &str) -> String {
38                self.as_ref().url(path)
39            }
40            fn query(&self) -> Vec<(&str, &str)> {
41                self.as_ref().query()
42            }
43            fn api_base(&self) -> &str {
44                self.as_ref().api_base()
45            }
46            fn api_key(&self) -> &SecretString {
47                self.as_ref().api_key()
48            }
49        }
50    };
51}
52
53impl_config_for_ptr!(Box<dyn Config>);
54impl_config_for_ptr!(std::sync::Arc<dyn Config>);
55
56/// Configuration for OpenAI API
57#[derive(Clone, Debug, Deserialize)]
58#[serde(default)]
59pub struct OpenAIConfig {
60    api_base: String,
61    api_key: SecretString,
62    org_id: String,
63    project_id: String,
64    #[serde(skip)]
65    custom_headers: HeaderMap,
66}
67
68impl Default for OpenAIConfig {
69    fn default() -> Self {
70        Self {
71            api_base: std::env::var("OPENAI_BASE_URL")
72                .unwrap_or_else(|_| OPENAI_API_BASE.to_string()),
73            api_key: std::env::var("OPENAI_API_KEY")
74                .or_else(|_| {
75                    std::env::var("OPENAI_ADMIN_KEY").map(|admin_key| {
76                        tracing::warn!("Using OPENAI_ADMIN_KEY, OPENAI_API_KEY not set");
77                        admin_key
78                    })
79                })
80                .unwrap_or_default()
81                .into(),
82            org_id: std::env::var("OPENAI_ORG_ID").unwrap_or_default(),
83            project_id: std::env::var("OPENAI_PROJECT_ID").unwrap_or_default(),
84            custom_headers: HeaderMap::new(),
85        }
86    }
87}
88
89impl OpenAIConfig {
90    /// Create client with default [OPENAI_API_BASE] url (can also be changed with OPENAI_BASE_URL env var) and default API key from OPENAI_API_KEY env var
91    pub fn new() -> Self {
92        Default::default()
93    }
94
95    /// To use a different organization id other than default
96    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
97        self.org_id = org_id.into();
98        self
99    }
100
101    /// Non default project id
102    pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
103        self.project_id = project_id.into();
104        self
105    }
106
107    /// To use a different API key different from default OPENAI_API_KEY env var
108    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
109        self.api_key = SecretString::from(api_key.into());
110        self
111    }
112
113    /// To use a API base url different from default [OPENAI_API_BASE]
114    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
115        self.api_base = api_base.into();
116        self
117    }
118
119    /// Add a custom header that will be included in all requests.
120    /// Headers are merged with existing headers, with custom headers taking precedence.
121    pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
122    where
123        K: reqwest::header::IntoHeaderName,
124        V: TryInto<reqwest::header::HeaderValue>,
125        V::Error: Into<reqwest::header::InvalidHeaderValue>,
126    {
127        let header_value = value.try_into().map_err(|e| {
128            OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
129        })?;
130        self.custom_headers.insert(key, header_value);
131        Ok(self)
132    }
133
134    pub fn org_id(&self) -> &str {
135        &self.org_id
136    }
137}
138
139impl Config for OpenAIConfig {
140    fn headers(&self) -> HeaderMap {
141        let mut headers = HeaderMap::new();
142        if !self.org_id.is_empty() {
143            headers.insert(
144                OPENAI_ORGANIZATION_HEADER,
145                self.org_id.as_str().parse().unwrap(),
146            );
147        }
148
149        if !self.project_id.is_empty() {
150            headers.insert(
151                OPENAI_PROJECT_HEADER,
152                self.project_id.as_str().parse().unwrap(),
153            );
154        }
155
156        headers.insert(
157            AUTHORIZATION,
158            format!("Bearer {}", self.api_key.expose_secret())
159                .as_str()
160                .parse()
161                .unwrap(),
162        );
163
164        // Merge custom headers, with custom headers taking precedence
165        for (key, value) in self.custom_headers.iter() {
166            headers.insert(key, value.clone());
167        }
168
169        headers
170    }
171
172    fn url(&self, path: &str) -> String {
173        format!("{}{}", self.api_base, path)
174    }
175
176    fn api_base(&self) -> &str {
177        &self.api_base
178    }
179
180    fn api_key(&self) -> &SecretString {
181        &self.api_key
182    }
183
184    fn query(&self) -> Vec<(&str, &str)> {
185        vec![]
186    }
187}
188
189/// Configuration for Azure OpenAI Service
190#[derive(Clone, Debug, Deserialize)]
191#[serde(default)]
192pub struct AzureConfig {
193    api_version: String,
194    deployment_id: String,
195    api_base: String,
196    api_key: SecretString,
197}
198
199impl Default for AzureConfig {
200    fn default() -> Self {
201        Self {
202            api_base: Default::default(),
203            api_key: std::env::var("OPENAI_API_KEY")
204                .unwrap_or_else(|_| "".to_string())
205                .into(),
206            deployment_id: Default::default(),
207            api_version: Default::default(),
208        }
209    }
210}
211
212impl AzureConfig {
213    pub fn new() -> Self {
214        Default::default()
215    }
216
217    pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
218        self.api_version = api_version.into();
219        self
220    }
221
222    pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
223        self.deployment_id = deployment_id.into();
224        self
225    }
226
227    /// To use a different API key different from default OPENAI_API_KEY env var
228    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
229        self.api_key = SecretString::from(api_key.into());
230        self
231    }
232
233    /// API base url in form of <https://your-resource-name.openai.azure.com>
234    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
235        self.api_base = api_base.into();
236        self
237    }
238}
239
240impl Config for AzureConfig {
241    fn headers(&self) -> HeaderMap {
242        let mut headers = HeaderMap::new();
243
244        headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
245
246        headers
247    }
248
249    fn url(&self, path: &str) -> String {
250        format!(
251            "{}/openai/deployments/{}{}",
252            self.api_base, self.deployment_id, path
253        )
254    }
255
256    fn api_base(&self) -> &str {
257        &self.api_base
258    }
259
260    fn api_key(&self) -> &SecretString {
261        &self.api_key
262    }
263
264    fn query(&self) -> Vec<(&str, &str)> {
265        vec![("api-version", &self.api_version)]
266    }
267}
268
#[cfg(all(test, feature = "chat-completion"))]
mod test {
    use super::*;
    use crate::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    /// `Box<dyn Config>` and `Arc<dyn Config>` must forward to the wrapped config.
    ///
    /// The key and base url are set through the builders, not with
    /// `std::env::set_var`. Mutating the process environment while other tests
    /// run in parallel and read it via `OpenAIConfig::default()` is a data race.
    /// Pinning the base url also keeps the assertions independent of any
    /// `OPENAI_BASE_URL` set in the environment.
    #[test]
    fn test_client_creation() {
        let openai_config = OpenAIConfig::default()
            .with_api_key("test")
            .with_api_base(OPENAI_API_BASE);

        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Checks that a chat request can be built through `Client<Box<dyn Config>>`.
    ///
    /// The value returned by `create` is dropped without `.await`, so this is
    /// a type-level check only.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let openai_config = OpenAIConfig::default();
        let azure_config = AzureConfig::default();

        let azure_client = Client::with_config(Box::new(azure_config) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(openai_config) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // `tokio::spawn` requires the futures to be `Send + 'static`. Awaiting
        // the handles makes a panic inside a spawned task fail this test.
        tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await })
            .await
            .expect("azure dynamic dispatch task panicked");
        tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await })
            .await
            .expect("openai dynamic dispatch task panicked");
    }
}