async_openai_wasm/
use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
5
/// Default base URL used by [`OpenAIConfig`]; request paths are appended to it.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";

/// Header carrying the organization id. [`OpenAIConfig`] sends it only when an
/// organization id has been set.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";

/// Header carrying the project id. [`OpenAIConfig`] sends it only when a project
/// id has been set.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Header used to opt into beta API surfaces. [`OpenAIConfig`] sends
/// `assistants=v2` on every request.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
15
16pub trait Config {
19 fn headers(&self) -> HeaderMap;
20 fn url(&self, path: &str) -> String;
21 fn query(&self) -> Vec<(&str, &str)>;
22
23 fn api_base(&self) -> &str;
24
25 fn api_key(&self) -> &SecretString;
26}
27
28macro_rules! impl_config_for_ptr {
30 ($t:ty) => {
31 impl Config for $t {
32 fn headers(&self) -> HeaderMap {
33 self.as_ref().headers()
34 }
35 fn url(&self, path: &str) -> String {
36 self.as_ref().url(path)
37 }
38 fn query(&self) -> Vec<(&str, &str)> {
39 self.as_ref().query()
40 }
41 fn api_base(&self) -> &str {
42 self.as_ref().api_base()
43 }
44 fn api_key(&self) -> &SecretString {
45 self.as_ref().api_key()
46 }
47 }
48 };
49}
50
51impl_config_for_ptr!(Box<dyn Config>);
52impl_config_for_ptr!(std::rc::Rc<dyn Config>);
53impl_config_for_ptr!(std::sync::Arc<dyn Config>);
54
/// Configuration for the OpenAI API.
///
/// Create it with [`OpenAIConfig::new`] (or `Default`) and customise it with the
/// `with_*` builder methods. Because of `#[serde(default)]`, fields missing from
/// deserialized input fall back to `OpenAIConfig::default()`.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
    /// Base URL that request paths are appended to; defaults to [`OPENAI_API_BASE`].
    api_base: String,
    /// API key, sent as a `Bearer` token in the `Authorization` header.
    /// Defaults to the `OPENAI_API_KEY` environment variable, or empty if unset.
    api_key: SecretString,
    /// Organization id; sent as the `OpenAI-Organization` header when non-empty.
    org_id: String,
    /// Project id; sent as the `OpenAI-Project` header when non-empty.
    project_id: String,
}
64
65impl Default for OpenAIConfig {
66 fn default() -> Self {
67 Self {
68 api_base: OPENAI_API_BASE.to_string(),
69 api_key: std::env::var("OPENAI_API_KEY")
70 .unwrap_or_else(|_| "".to_string())
71 .into(),
72 org_id: Default::default(),
73 project_id: Default::default(),
74 }
75 }
76}
77
78impl OpenAIConfig {
79 pub fn new() -> Self {
81 Default::default()
82 }
83
84 pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
86 self.org_id = org_id.into();
87 self
88 }
89
90 pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
92 self.project_id = project_id.into();
93 self
94 }
95
96 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
98 self.api_key = SecretString::from(api_key.into());
99 self
100 }
101
102 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
104 self.api_base = api_base.into();
105 self
106 }
107
108 pub fn org_id(&self) -> &str {
109 &self.org_id
110 }
111}
112
113impl Config for OpenAIConfig {
114 fn headers(&self) -> HeaderMap {
115 let mut headers = HeaderMap::new();
116 if !self.org_id.is_empty() {
117 headers.insert(
118 OPENAI_ORGANIZATION_HEADER,
119 self.org_id.as_str().parse().unwrap(),
120 );
121 }
122
123 if !self.project_id.is_empty() {
124 headers.insert(
125 OPENAI_PROJECT_HEADER,
126 self.project_id.as_str().parse().unwrap(),
127 );
128 }
129
130 headers.insert(
131 AUTHORIZATION,
132 format!("Bearer {}", self.api_key.expose_secret())
133 .as_str()
134 .parse()
135 .unwrap(),
136 );
137
138 headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
141
142 headers
143 }
144
145 fn url(&self, path: &str) -> String {
146 format!("{}{}", self.api_base, path)
147 }
148
149 fn api_base(&self) -> &str {
150 &self.api_base
151 }
152
153 fn api_key(&self) -> &SecretString {
154 &self.api_key
155 }
156
157 fn query(&self) -> Vec<(&str, &str)> {
158 vec![]
159 }
160}
161
/// Configuration for Azure OpenAI deployments.
///
/// Requests go to `{api_base}/openai/deployments/{deployment_id}{path}`. They carry
/// an `api-version` query parameter and send the key in an `api-key` header.
/// Because of `#[serde(default)]`, fields missing from deserialized input fall back
/// to `AzureConfig::default()`.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
    /// Value of the `api-version` query parameter sent with every request.
    api_version: String,
    /// Deployment name inserted into the request URL path.
    deployment_id: String,
    /// Resource base URL; the deployment path is appended to it.
    api_base: String,
    /// API key, sent as the `api-key` header. Defaults to the `OPENAI_API_KEY`
    /// environment variable, or empty if unset.
    api_key: SecretString,
}
171
172impl Default for AzureConfig {
173 fn default() -> Self {
174 Self {
175 api_base: Default::default(),
176 api_key: std::env::var("OPENAI_API_KEY")
177 .unwrap_or_else(|_| "".to_string())
178 .into(),
179 deployment_id: Default::default(),
180 api_version: Default::default(),
181 }
182 }
183}
184
185impl AzureConfig {
186 pub fn new() -> Self {
187 Default::default()
188 }
189
190 pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
191 self.api_version = api_version.into();
192 self
193 }
194
195 pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
196 self.deployment_id = deployment_id.into();
197 self
198 }
199
200 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
202 self.api_key = SecretString::from(api_key.into());
203 self
204 }
205
206 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
208 self.api_base = api_base.into();
209 self
210 }
211}
212
213impl Config for AzureConfig {
214 fn headers(&self) -> HeaderMap {
215 let mut headers = HeaderMap::new();
216
217 headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
218
219 headers
220 }
221
222 fn url(&self, path: &str) -> String {
223 format!(
224 "{}/openai/deployments/{}{}",
225 self.api_base, self.deployment_id, path
226 )
227 }
228
229 fn api_base(&self) -> &str {
230 &self.api_base
231 }
232
233 fn api_key(&self) -> &SecretString {
234 &self.api_key
235 }
236
237 fn query(&self) -> Vec<(&str, &str)> {
238 vec![("api-version", &self.api_version)]
239 }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use crate::Client;
    use std::rc::Rc;
    use std::sync::Arc;

    /// A client can be built from a `Box`, `Rc` or `Arc` of `dyn Config`. Every such
    /// client, including clones of the shared ones, keeps the default `/v1` base URL.
    #[test]
    fn test_client_creation() {
        // SAFETY: `set_var` is unsafe because other threads may read the environment
        // at the same time. NOTE(review): the test harness runs tests on parallel
        // threads; this assumes no concurrently running test reads or writes the
        // environment — confirm.
        unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
        let openai_config = OpenAIConfig::default();
        let has_v1_suffix = |url: String| url.ends_with("/v1");

        let boxed_client =
            Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);
        assert!(has_v1_suffix(boxed_client.config().url("")));

        let rc_client = Client::with_config(Rc::new(openai_config.clone()) as Rc<dyn Config>);
        assert!(has_v1_suffix(rc_client.config().url("")));
        let rc_clone = rc_client.clone();
        assert!(has_v1_suffix(rc_clone.config().url("")));

        let arc_client = Client::with_config(Arc::new(openai_config) as Arc<dyn Config>);
        assert!(has_v1_suffix(arc_client.config().url("")));
        let arc_clone = arc_client.clone();
        assert!(has_v1_suffix(arc_clone.config().url("")));
    }
}