async_openai/
config.rs

1//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
2use reqwest::header::{HeaderMap, AUTHORIZATION};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
6use crate::error::OpenAIError;
7
/// Base URL of version 1 of the OpenAI REST API; used unless another base is configured.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the request header that selects the OpenAI organization.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the request header that selects the OpenAI project.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Name of the beta opt-in header that Assistants API calls are required to carry.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
17
18/// [crate::Client] relies on this for every API call on OpenAI
19/// or Azure OpenAI service
20pub trait Config: Send + Sync {
21    fn headers(&self) -> HeaderMap;
22    fn url(&self, path: &str) -> String;
23    fn query(&self) -> Vec<(&str, &str)>;
24
25    fn api_base(&self) -> &str;
26
27    fn api_key(&self) -> &SecretString;
28}
29
30/// Macro to implement Config trait for pointer types with dyn objects
31macro_rules! impl_config_for_ptr {
32    ($t:ty) => {
33        impl Config for $t {
34            fn headers(&self) -> HeaderMap {
35                self.as_ref().headers()
36            }
37            fn url(&self, path: &str) -> String {
38                self.as_ref().url(path)
39            }
40            fn query(&self) -> Vec<(&str, &str)> {
41                self.as_ref().query()
42            }
43            fn api_base(&self) -> &str {
44                self.as_ref().api_base()
45            }
46            fn api_key(&self) -> &SecretString {
47                self.as_ref().api_key()
48            }
49        }
50    };
51}
52
53impl_config_for_ptr!(Box<dyn Config>);
54impl_config_for_ptr!(std::sync::Arc<dyn Config>);
55
56/// Configuration for OpenAI API
57#[derive(Clone, Debug, Deserialize)]
58#[serde(default)]
59pub struct OpenAIConfig {
60    api_base: String,
61    api_key: SecretString,
62    org_id: String,
63    project_id: String,
64    #[serde(skip)]
65    custom_headers: HeaderMap,
66}
67
68impl Default for OpenAIConfig {
69    fn default() -> Self {
70        Self {
71            api_base: OPENAI_API_BASE.to_string(),
72            api_key: std::env::var("OPENAI_API_KEY")
73                .unwrap_or_else(|_| "".to_string())
74                .into(),
75            org_id: Default::default(),
76            project_id: Default::default(),
77            custom_headers: HeaderMap::new(),
78        }
79    }
80}
81
82impl OpenAIConfig {
83    /// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
84    pub fn new() -> Self {
85        Default::default()
86    }
87
88    /// To use a different organization id other than default
89    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
90        self.org_id = org_id.into();
91        self
92    }
93
94    /// Non default project id
95    pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
96        self.project_id = project_id.into();
97        self
98    }
99
100    /// To use a different API key different from default OPENAI_API_KEY env var
101    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
102        self.api_key = SecretString::from(api_key.into());
103        self
104    }
105
106    /// To use a API base url different from default [OPENAI_API_BASE]
107    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
108        self.api_base = api_base.into();
109        self
110    }
111
112    /// Add a custom header that will be included in all requests.
113    /// Headers are merged with existing headers, with custom headers taking precedence.
114    pub fn with_header<K, V>(mut self, key: K, value: V) -> Result<Self, OpenAIError>
115    where
116        K: reqwest::header::IntoHeaderName,
117        V: TryInto<reqwest::header::HeaderValue>,
118        V::Error: Into<reqwest::header::InvalidHeaderValue>,
119    {
120        let header_value = value.try_into().map_err(|e| {
121            OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into()))
122        })?;
123        self.custom_headers.insert(key, header_value);
124        Ok(self)
125    }
126
127    pub fn org_id(&self) -> &str {
128        &self.org_id
129    }
130}
131
132impl Config for OpenAIConfig {
133    fn headers(&self) -> HeaderMap {
134        let mut headers = HeaderMap::new();
135        if !self.org_id.is_empty() {
136            headers.insert(
137                OPENAI_ORGANIZATION_HEADER,
138                self.org_id.as_str().parse().unwrap(),
139            );
140        }
141
142        if !self.project_id.is_empty() {
143            headers.insert(
144                OPENAI_PROJECT_HEADER,
145                self.project_id.as_str().parse().unwrap(),
146            );
147        }
148
149        headers.insert(
150            AUTHORIZATION,
151            format!("Bearer {}", self.api_key.expose_secret())
152                .as_str()
153                .parse()
154                .unwrap(),
155        );
156
157        // Merge custom headers, with custom headers taking precedence
158        for (key, value) in self.custom_headers.iter() {
159            headers.insert(key, value.clone());
160        }
161
162        headers
163    }
164
165    fn url(&self, path: &str) -> String {
166        format!("{}{}", self.api_base, path)
167    }
168
169    fn api_base(&self) -> &str {
170        &self.api_base
171    }
172
173    fn api_key(&self) -> &SecretString {
174        &self.api_key
175    }
176
177    fn query(&self) -> Vec<(&str, &str)> {
178        vec![]
179    }
180}
181
182/// Configuration for Azure OpenAI Service
183#[derive(Clone, Debug, Deserialize)]
184#[serde(default)]
185pub struct AzureConfig {
186    api_version: String,
187    deployment_id: String,
188    api_base: String,
189    api_key: SecretString,
190}
191
192impl Default for AzureConfig {
193    fn default() -> Self {
194        Self {
195            api_base: Default::default(),
196            api_key: std::env::var("OPENAI_API_KEY")
197                .unwrap_or_else(|_| "".to_string())
198                .into(),
199            deployment_id: Default::default(),
200            api_version: Default::default(),
201        }
202    }
203}
204
205impl AzureConfig {
206    pub fn new() -> Self {
207        Default::default()
208    }
209
210    pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
211        self.api_version = api_version.into();
212        self
213    }
214
215    pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
216        self.deployment_id = deployment_id.into();
217        self
218    }
219
220    /// To use a different API key different from default OPENAI_API_KEY env var
221    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
222        self.api_key = SecretString::from(api_key.into());
223        self
224    }
225
226    /// API base url in form of <https://your-resource-name.openai.azure.com>
227    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
228        self.api_base = api_base.into();
229        self
230    }
231}
232
233impl Config for AzureConfig {
234    fn headers(&self) -> HeaderMap {
235        let mut headers = HeaderMap::new();
236
237        headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
238
239        headers
240    }
241
242    fn url(&self, path: &str) -> String {
243        format!(
244            "{}/openai/deployments/{}{}",
245            self.api_base, self.deployment_id, path
246        )
247    }
248
249    fn api_base(&self) -> &str {
250        &self.api_base
251    }
252
253    fn api_key(&self) -> &SecretString {
254        &self.api_key
255    }
256
257    fn query(&self) -> Vec<(&str, &str)> {
258        vec![("api-version", &self.api_version)]
259    }
260}
261
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    #[test]
    fn test_client_creation() {
        // Set the key on the config instead of via `std::env::set_var`: tests run on
        // parallel threads and `test_dynamic_dispatch` reads the environment concurrently,
        // which `set_var`'s safety contract forbids.
        let openai_config = OpenAIConfig::default().with_api_key("test");
        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Checks that a chat request can be built through a `Box<dyn Config>` client.
    /// The result of `create` is discarded without `.await`.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let azure_client =
            Client::with_config(Box::new(AzureConfig::default()) as Box<dyn Config>);
        let oai_client =
            Client::with_config(Box::new(OpenAIConfig::default()) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // Spawning proves the futures are `Send + 'static`; awaiting the handles makes sure
        // the tasks actually run (and surface panics) before the test returns.
        let azure_task =
            tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let oai_task = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
        azure_task.await.expect("azure dispatch task panicked");
        oai_task.await.expect("openai dispatch task panicked");
    }
}