async_openai_wasm/
config.rs

1//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
2use reqwest::header::{AUTHORIZATION, HeaderMap};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;
5
/// Default base URL of the OpenAI v1 API (no trailing slash).
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the HTTP header that carries the organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the HTTP header that carries the project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Name of the beta opt-in HTTP header.
///
/// Calls to the Assistants API require that you pass a Beta header.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
15
16/// [crate::Client] relies on this for every API call on OpenAI
17/// or Azure OpenAI service
18pub trait Config {
19    fn headers(&self) -> HeaderMap;
20    fn url(&self, path: &str) -> String;
21    fn query(&self) -> Vec<(&str, &str)>;
22
23    fn api_base(&self) -> &str;
24
25    fn api_key(&self) -> &SecretString;
26}
27
28/// Macro to implement Config trait for pointer types with dyn objects
29macro_rules! impl_config_for_ptr {
30    ($t:ty) => {
31        impl Config for $t {
32            fn headers(&self) -> HeaderMap {
33                self.as_ref().headers()
34            }
35            fn url(&self, path: &str) -> String {
36                self.as_ref().url(path)
37            }
38            fn query(&self) -> Vec<(&str, &str)> {
39                self.as_ref().query()
40            }
41            fn api_base(&self) -> &str {
42                self.as_ref().api_base()
43            }
44            fn api_key(&self) -> &SecretString {
45                self.as_ref().api_key()
46            }
47        }
48    };
49}
50
51impl_config_for_ptr!(Box<dyn Config>);
52impl_config_for_ptr!(std::rc::Rc<dyn Config>);
53impl_config_for_ptr!(std::sync::Arc<dyn Config>);
54
55/// Configuration for OpenAI API
56#[derive(Clone, Debug, Deserialize)]
57#[serde(default)]
58pub struct OpenAIConfig {
59    api_base: String,
60    api_key: SecretString,
61    org_id: String,
62    project_id: String,
63}
64
65impl Default for OpenAIConfig {
66    fn default() -> Self {
67        Self {
68            api_base: OPENAI_API_BASE.to_string(),
69            api_key: std::env::var("OPENAI_API_KEY")
70                .unwrap_or_else(|_| "".to_string())
71                .into(),
72            org_id: Default::default(),
73            project_id: Default::default(),
74        }
75    }
76}
77
78impl OpenAIConfig {
79    /// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
80    pub fn new() -> Self {
81        Default::default()
82    }
83
84    /// To use a different organization id other than default
85    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
86        self.org_id = org_id.into();
87        self
88    }
89
90    /// Non default project id
91    pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
92        self.project_id = project_id.into();
93        self
94    }
95
96    /// To use a different API key different from default OPENAI_API_KEY env var
97    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
98        self.api_key = SecretString::from(api_key.into());
99        self
100    }
101
102    /// To use a API base url different from default [OPENAI_API_BASE]
103    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
104        self.api_base = api_base.into();
105        self
106    }
107
108    pub fn org_id(&self) -> &str {
109        &self.org_id
110    }
111}
112
113impl Config for OpenAIConfig {
114    fn headers(&self) -> HeaderMap {
115        let mut headers = HeaderMap::new();
116        if !self.org_id.is_empty() {
117            headers.insert(
118                OPENAI_ORGANIZATION_HEADER,
119                self.org_id.as_str().parse().unwrap(),
120            );
121        }
122
123        if !self.project_id.is_empty() {
124            headers.insert(
125                OPENAI_PROJECT_HEADER,
126                self.project_id.as_str().parse().unwrap(),
127            );
128        }
129
130        headers.insert(
131            AUTHORIZATION,
132            format!("Bearer {}", self.api_key.expose_secret())
133                .as_str()
134                .parse()
135                .unwrap(),
136        );
137
138        // hack for Assistants APIs
139        // Calls to the Assistants API require that you pass a Beta header
140        headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
141
142        headers
143    }
144
145    fn url(&self, path: &str) -> String {
146        format!("{}{}", self.api_base, path)
147    }
148
149    fn api_base(&self) -> &str {
150        &self.api_base
151    }
152
153    fn api_key(&self) -> &SecretString {
154        &self.api_key
155    }
156
157    fn query(&self) -> Vec<(&str, &str)> {
158        vec![]
159    }
160}
161
162/// Configuration for Azure OpenAI Service
163#[derive(Clone, Debug, Deserialize)]
164#[serde(default)]
165pub struct AzureConfig {
166    api_version: String,
167    deployment_id: String,
168    api_base: String,
169    api_key: SecretString,
170}
171
172impl Default for AzureConfig {
173    fn default() -> Self {
174        Self {
175            api_base: Default::default(),
176            api_key: std::env::var("OPENAI_API_KEY")
177                .unwrap_or_else(|_| "".to_string())
178                .into(),
179            deployment_id: Default::default(),
180            api_version: Default::default(),
181        }
182    }
183}
184
185impl AzureConfig {
186    pub fn new() -> Self {
187        Default::default()
188    }
189
190    pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
191        self.api_version = api_version.into();
192        self
193    }
194
195    pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
196        self.deployment_id = deployment_id.into();
197        self
198    }
199
200    /// To use a different API key different from default OPENAI_API_KEY env var
201    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
202        self.api_key = SecretString::from(api_key.into());
203        self
204    }
205
206    /// API base url in form of <https://your-resource-name.openai.azure.com>
207    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
208        self.api_base = api_base.into();
209        self
210    }
211}
212
213impl Config for AzureConfig {
214    fn headers(&self) -> HeaderMap {
215        let mut headers = HeaderMap::new();
216
217        headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
218
219        headers
220    }
221
222    fn url(&self, path: &str) -> String {
223        format!(
224            "{}/openai/deployments/{}{}",
225            self.api_base, self.deployment_id, path
226        )
227    }
228
229    fn api_base(&self) -> &str {
230        &self.api_base
231    }
232
233    fn api_key(&self) -> &SecretString {
234        &self.api_key
235    }
236
237    fn query(&self) -> Vec<(&str, &str)> {
238        vec![("api-version", &self.api_version)]
239    }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use crate::Client;
    use std::rc::Rc;
    use std::sync::Arc;

    /// A client can be built from a boxed, `Rc`'d or `Arc`'d `dyn Config`,
    /// and clones of the `Rc`/`Arc` clients still resolve the default base URL.
    #[test]
    fn test_client_creation() {
        // SAFETY: `set_var` is unsound if another thread accesses the
        // environment at the same time; this test does not spawn threads.
        // NOTE(review): other tests in the same binary may run in parallel.
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "test");
        }
        let base_config = OpenAIConfig::default();

        // Box<dyn Config>
        let boxed: Box<dyn Config> = Box::new(base_config.clone());
        let boxed_client = Client::with_config(boxed);
        assert!(boxed_client.config().url("").ends_with("/v1"));

        // Rc<dyn Config>, plus a clone of the client
        let rc_config: Rc<dyn Config> = Rc::new(base_config.clone());
        let rc_client = Client::with_config(rc_config);
        assert!(rc_client.config().url("").ends_with("/v1"));
        let rc_client_clone = rc_client.clone();
        assert!(rc_client_clone.config().url("").ends_with("/v1"));

        // Arc<dyn Config>, plus a clone of the client
        let arc_config: Arc<dyn Config> = Arc::new(base_config);
        let arc_client = Client::with_config(arc_config);
        assert!(arc_client.config().url("").ends_with("/v1"));
        let arc_client_clone = arc_client.clone();
        assert!(arc_client_clone.config().url("").ends_with("/v1"));
    }
}