// async_openai_wasm/config.rs

//! Client configurations: [OpenAIConfig] for OpenAI and [AzureConfig] for the Azure OpenAI Service.
2use reqwest::header::{AUTHORIZATION, HeaderMap};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
/// Base URL of the OpenAI v1 REST API, used unless another base is configured.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";

/// Name of the header that carries the OpenAI organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";

/// Name of the header that carries the OpenAI project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Name of the beta opt-in header; the Assistants API rejects calls without it.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";

16/// [crate::Client] relies on this for every API call on OpenAI
17/// or Azure OpenAI service
18pub trait Config: Send + Sync {
19    fn headers(&self) -> HeaderMap;
20    fn url(&self, path: &str) -> String;
21    fn query(&self) -> Vec<(&str, &str)>;
22
23    fn api_base(&self) -> &str;
24
25    fn api_key(&self) -> &SecretString;
26}
27
28/// Macro to implement Config trait for pointer types with dyn objects
29macro_rules! impl_config_for_ptr {
30    ($t:ty) => {
31        impl Config for $t {
32            fn headers(&self) -> HeaderMap {
33                self.as_ref().headers()
34            }
35            fn url(&self, path: &str) -> String {
36                self.as_ref().url(path)
37            }
38            fn query(&self) -> Vec<(&str, &str)> {
39                self.as_ref().query()
40            }
41            fn api_base(&self) -> &str {
42                self.as_ref().api_base()
43            }
44            fn api_key(&self) -> &SecretString {
45                self.as_ref().api_key()
46            }
47        }
48    };
49}
50
51impl_config_for_ptr!(Box<dyn Config>);
52impl_config_for_ptr!(std::sync::Arc<dyn Config>);
53
54/// Configuration for OpenAI API
55#[derive(Clone, Debug, Deserialize)]
56#[serde(default)]
57pub struct OpenAIConfig {
58    api_base: String,
59    api_key: SecretString,
60    org_id: String,
61    project_id: String,
62}
63
64impl Default for OpenAIConfig {
65    fn default() -> Self {
66        Self {
67            api_base: OPENAI_API_BASE.to_string(),
68            api_key: std::env::var("OPENAI_API_KEY")
69                .unwrap_or_else(|_| "".to_string())
70                .into(),
71            org_id: Default::default(),
72            project_id: Default::default(),
73        }
74    }
75}
76
77impl OpenAIConfig {
78    /// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
79    pub fn new() -> Self {
80        Default::default()
81    }
82
83    /// To use a different organization id other than default
84    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
85        self.org_id = org_id.into();
86        self
87    }
88
89    /// Non default project id
90    pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
91        self.project_id = project_id.into();
92        self
93    }
94
95    /// To use a different API key different from default OPENAI_API_KEY env var
96    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
97        self.api_key = SecretString::from(api_key.into());
98        self
99    }
100
101    /// To use a API base url different from default [OPENAI_API_BASE]
102    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
103        self.api_base = api_base.into();
104        self
105    }
106
107    pub fn org_id(&self) -> &str {
108        &self.org_id
109    }
110}
111
112impl Config for OpenAIConfig {
113    fn headers(&self) -> HeaderMap {
114        let mut headers = HeaderMap::new();
115        if !self.org_id.is_empty() {
116            headers.insert(
117                OPENAI_ORGANIZATION_HEADER,
118                self.org_id.as_str().parse().unwrap(),
119            );
120        }
121
122        if !self.project_id.is_empty() {
123            headers.insert(
124                OPENAI_PROJECT_HEADER,
125                self.project_id.as_str().parse().unwrap(),
126            );
127        }
128
129        headers.insert(
130            AUTHORIZATION,
131            format!("Bearer {}", self.api_key.expose_secret())
132                .as_str()
133                .parse()
134                .unwrap(),
135        );
136
137        // hack for Assistants APIs
138        // Calls to the Assistants API require that you pass a Beta header
139        headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
140
141        headers
142    }
143
144    fn url(&self, path: &str) -> String {
145        format!("{}{}", self.api_base, path)
146    }
147
148    fn api_base(&self) -> &str {
149        &self.api_base
150    }
151
152    fn api_key(&self) -> &SecretString {
153        &self.api_key
154    }
155
156    fn query(&self) -> Vec<(&str, &str)> {
157        vec![]
158    }
159}
160
161/// Configuration for Azure OpenAI Service
162#[derive(Clone, Debug, Deserialize)]
163#[serde(default)]
164pub struct AzureConfig {
165    api_version: String,
166    deployment_id: String,
167    api_base: String,
168    api_key: SecretString,
169}
170
171impl Default for AzureConfig {
172    fn default() -> Self {
173        Self {
174            api_base: Default::default(),
175            api_key: std::env::var("OPENAI_API_KEY")
176                .unwrap_or_else(|_| "".to_string())
177                .into(),
178            deployment_id: Default::default(),
179            api_version: Default::default(),
180        }
181    }
182}
183
184impl AzureConfig {
185    pub fn new() -> Self {
186        Default::default()
187    }
188
189    pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
190        self.api_version = api_version.into();
191        self
192    }
193
194    pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
195        self.deployment_id = deployment_id.into();
196        self
197    }
198
199    /// To use a different API key different from default OPENAI_API_KEY env var
200    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
201        self.api_key = SecretString::from(api_key.into());
202        self
203    }
204
205    /// API base url in form of <https://your-resource-name.openai.azure.com>
206    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
207        self.api_base = api_base.into();
208        self
209    }
210}
211
212impl Config for AzureConfig {
213    fn headers(&self) -> HeaderMap {
214        let mut headers = HeaderMap::new();
215
216        headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
217
218        headers
219    }
220
221    fn url(&self, path: &str) -> String {
222        format!(
223            "{}/openai/deployments/{}{}",
224            self.api_base, self.deployment_id, path
225        )
226    }
227
228    fn api_base(&self) -> &str {
229        &self.api_base
230    }
231
232    fn api_key(&self) -> &SecretString {
233        &self.api_key
234    }
235
236    fn query(&self) -> Vec<(&str, &str)> {
237        vec![("api-version", &self.api_version)]
238    }
239}
240
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    /// `Box<dyn Config>` and `Arc<dyn Config>` clients must both forward `url`
    /// to the wrapped [OpenAIConfig], including after the client is cloned.
    #[test]
    fn test_client_creation() {
        // Supply the key directly instead of calling `std::env::set_var`.
        // Mutating the process environment is `unsafe` because the harness
        // runs other tests on parallel threads, and those tests read
        // `OPENAI_API_KEY` through `Default`.
        let openai_config = OpenAIConfig::default().with_api_key("test");
        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Type-checks building a chat request through a `Box<dyn Config>` client.
    /// The value returned by `create` is discarded without being awaited.
    /// NOTE(review): this assumes `create` is lazy (async), so no request is sent.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let openai_config = OpenAIConfig::default();
        let azure_config = AzureConfig::default();

        let azure_client = Client::with_config(Box::new(azure_config) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(openai_config) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // Spawning additionally requires the futures to be `Send + 'static`.
        // Await the handles so that a panic inside either task fails the test
        // instead of being dropped along with a detached task.
        let azure_task =
            tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let oai_task = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
        azure_task.await.expect("Azure dispatch task panicked");
        oai_task.await.expect("OpenAI dispatch task panicked");
    }
}