async_openai_wasm/
config.rs1use reqwest::header::{AUTHORIZATION, HeaderMap};
3use secrecy::{ExposeSecret, SecretString};
4use serde::Deserialize;

/// Base URL of the OpenAI v1 REST API; the default `api_base` of [`OpenAIConfig`].
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Header carrying the organization id (sent only when one is configured).
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Header carrying the project id (sent only when one is configured).
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Header used to opt into beta API surfaces (set to `assistants=v2` by [`OpenAIConfig`]).
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";

16pub trait Config: Send + Sync {
19 fn headers(&self) -> HeaderMap;
20 fn url(&self, path: &str) -> String;
21 fn query(&self) -> Vec<(&str, &str)>;
22
23 fn api_base(&self) -> &str;
24
25 fn api_key(&self) -> &SecretString;
26}
27
28macro_rules! impl_config_for_ptr {
30 ($t:ty) => {
31 impl Config for $t {
32 fn headers(&self) -> HeaderMap {
33 self.as_ref().headers()
34 }
35 fn url(&self, path: &str) -> String {
36 self.as_ref().url(path)
37 }
38 fn query(&self) -> Vec<(&str, &str)> {
39 self.as_ref().query()
40 }
41 fn api_base(&self) -> &str {
42 self.as_ref().api_base()
43 }
44 fn api_key(&self) -> &SecretString {
45 self.as_ref().api_key()
46 }
47 }
48 };
49}
50
51impl_config_for_ptr!(Box<dyn Config>);
52impl_config_for_ptr!(std::sync::Arc<dyn Config>);
53
54#[derive(Clone, Debug, Deserialize)]
56#[serde(default)]
57pub struct OpenAIConfig {
58 api_base: String,
59 api_key: SecretString,
60 org_id: String,
61 project_id: String,
62}
63
64impl Default for OpenAIConfig {
65 fn default() -> Self {
66 Self {
67 api_base: OPENAI_API_BASE.to_string(),
68 api_key: std::env::var("OPENAI_API_KEY")
69 .unwrap_or_else(|_| "".to_string())
70 .into(),
71 org_id: Default::default(),
72 project_id: Default::default(),
73 }
74 }
75}
76
77impl OpenAIConfig {
78 pub fn new() -> Self {
80 Default::default()
81 }
82
83 pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
85 self.org_id = org_id.into();
86 self
87 }
88
89 pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
91 self.project_id = project_id.into();
92 self
93 }
94
95 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
97 self.api_key = SecretString::from(api_key.into());
98 self
99 }
100
101 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
103 self.api_base = api_base.into();
104 self
105 }
106
107 pub fn org_id(&self) -> &str {
108 &self.org_id
109 }
110}
111
112impl Config for OpenAIConfig {
113 fn headers(&self) -> HeaderMap {
114 let mut headers = HeaderMap::new();
115 if !self.org_id.is_empty() {
116 headers.insert(
117 OPENAI_ORGANIZATION_HEADER,
118 self.org_id.as_str().parse().unwrap(),
119 );
120 }
121
122 if !self.project_id.is_empty() {
123 headers.insert(
124 OPENAI_PROJECT_HEADER,
125 self.project_id.as_str().parse().unwrap(),
126 );
127 }
128
129 headers.insert(
130 AUTHORIZATION,
131 format!("Bearer {}", self.api_key.expose_secret())
132 .as_str()
133 .parse()
134 .unwrap(),
135 );
136
137 headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
140
141 headers
142 }
143
144 fn url(&self, path: &str) -> String {
145 format!("{}{}", self.api_base, path)
146 }
147
148 fn api_base(&self) -> &str {
149 &self.api_base
150 }
151
152 fn api_key(&self) -> &SecretString {
153 &self.api_key
154 }
155
156 fn query(&self) -> Vec<(&str, &str)> {
157 vec![]
158 }
159}
160
161#[derive(Clone, Debug, Deserialize)]
163#[serde(default)]
164pub struct AzureConfig {
165 api_version: String,
166 deployment_id: String,
167 api_base: String,
168 api_key: SecretString,
169}
170
171impl Default for AzureConfig {
172 fn default() -> Self {
173 Self {
174 api_base: Default::default(),
175 api_key: std::env::var("OPENAI_API_KEY")
176 .unwrap_or_else(|_| "".to_string())
177 .into(),
178 deployment_id: Default::default(),
179 api_version: Default::default(),
180 }
181 }
182}
183
184impl AzureConfig {
185 pub fn new() -> Self {
186 Default::default()
187 }
188
189 pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
190 self.api_version = api_version.into();
191 self
192 }
193
194 pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
195 self.deployment_id = deployment_id.into();
196 self
197 }
198
199 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
201 self.api_key = SecretString::from(api_key.into());
202 self
203 }
204
205 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
207 self.api_base = api_base.into();
208 self
209 }
210}
211
212impl Config for AzureConfig {
213 fn headers(&self) -> HeaderMap {
214 let mut headers = HeaderMap::new();
215
216 headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
217
218 headers
219 }
220
221 fn url(&self, path: &str) -> String {
222 format!(
223 "{}/openai/deployments/{}{}",
224 self.api_base, self.deployment_id, path
225 )
226 }
227
228 fn api_base(&self) -> &str {
229 &self.api_base
230 }
231
232 fn api_key(&self) -> &SecretString {
233 &self.api_key
234 }
235
236 fn query(&self) -> Vec<(&str, &str)> {
237 vec![("api-version", &self.api_version)]
238 }
239}
240
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use crate::Client;
    use std::sync::Arc;

    /// Boxed and `Arc`-wrapped configs must both be usable by `Client`, and cloning
    /// the client must preserve its config.
    #[test]
    fn test_client_creation() {
        // Set the key on the config instead of via `std::env::set_var`. Mutating the
        // process environment is `unsafe` and races with other tests that read
        // `OPENAI_API_KEY` through `Default` on parallel test threads.
        let openai_config = OpenAIConfig::default().with_api_key("test");
        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Compile-time check that a request can be built through `Box<dyn Config>`.
    /// The returned future is dropped without being polled, so no request is sent.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let azure_client =
            Client::with_config(Box::new(AzureConfig::default()) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(OpenAIConfig::default()) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // Spawning requires the futures to be `Send`. Await the handles so a panic in a
        // task fails this test instead of being dropped when the runtime shuts down.
        let azure_task =
            tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let oai_task = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
        azure_task.await.expect("azure dispatch task panicked");
        oai_task.await.expect("openai dispatch task panicked");
    }
}