dynamo_async_openai/
config.rs1use reqwest::header::{AUTHORIZATION, HeaderMap};
13use secrecy::{ExposeSecret, SecretString};
14use serde::Deserialize;
15
/// Default base URL of the OpenAI API; `OpenAIConfig::default()` uses it as `api_base`.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the HTTP header in which `OpenAIConfig` sends its organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the HTTP header in which `OpenAIConfig` sends its project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Name of the HTTP header that `OpenAIConfig::headers` sets to `assistants=v2`.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
25
/// Endpoint configuration: headers, URL construction, query parameters and
/// credentials.
///
/// Implemented in this file by `OpenAIConfig` and `AzureConfig`, and forwarded
/// through `Box<dyn Config>` and `Arc<dyn Config>`. Implementors must be
/// `Send + Sync`.
pub trait Config: Send + Sync {
    /// Returns the HTTP headers this configuration contributes.
    fn headers(&self) -> HeaderMap;
    /// Builds the full URL for the given `path`.
    fn url(&self, path: &str) -> String;
    /// Returns query parameters as `(name, value)` pairs borrowed from `self`.
    fn query(&self) -> Vec<(&str, &str)>;

    /// Returns the configured base URL.
    fn api_base(&self) -> &str;

    /// Returns the API key, still wrapped in `SecretString`.
    fn api_key(&self) -> &SecretString;
}
37
38macro_rules! impl_config_for_ptr {
40 ($t:ty) => {
41 impl Config for $t {
42 fn headers(&self) -> HeaderMap {
43 self.as_ref().headers()
44 }
45 fn url(&self, path: &str) -> String {
46 self.as_ref().url(path)
47 }
48 fn query(&self) -> Vec<(&str, &str)> {
49 self.as_ref().query()
50 }
51 fn api_base(&self) -> &str {
52 self.as_ref().api_base()
53 }
54 fn api_key(&self) -> &SecretString {
55 self.as_ref().api_key()
56 }
57 }
58 };
59}
60
61impl_config_for_ptr!(Box<dyn Config>);
62impl_config_for_ptr!(std::sync::Arc<dyn Config>);
63
/// Configuration for the OpenAI API: base URL, API key, and optional
/// organization and project ids.
///
/// `#[serde(default)]` fills any field missing from the input with the value
/// from this type's `Default` implementation.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
    /// Base URL that request paths are appended to.
    api_base: String,
    /// API key, sent as a `Bearer` token in the `Authorization` header.
    api_key: SecretString,
    /// Organization id; when empty, no organization header is sent.
    org_id: String,
    /// Project id; when empty, no project header is sent.
    project_id: String,
}
73
74impl Default for OpenAIConfig {
75 fn default() -> Self {
76 Self {
77 api_base: OPENAI_API_BASE.to_string(),
78 api_key: std::env::var("OPENAI_API_KEY")
79 .unwrap_or_else(|_| "".to_string())
80 .into(),
81 org_id: Default::default(),
82 project_id: Default::default(),
83 }
84 }
85}
86
87impl OpenAIConfig {
88 pub fn new() -> Self {
90 Default::default()
91 }
92
93 pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
95 self.org_id = org_id.into();
96 self
97 }
98
99 pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
101 self.project_id = project_id.into();
102 self
103 }
104
105 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
107 self.api_key = SecretString::from(api_key.into());
108 self
109 }
110
111 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
113 self.api_base = api_base.into();
114 self
115 }
116
117 pub fn org_id(&self) -> &str {
118 &self.org_id
119 }
120}
121
122impl Config for OpenAIConfig {
123 fn headers(&self) -> HeaderMap {
124 let mut headers = HeaderMap::new();
125 if !self.org_id.is_empty() {
126 headers.insert(
127 OPENAI_ORGANIZATION_HEADER,
128 self.org_id.as_str().parse().unwrap(),
129 );
130 }
131
132 if !self.project_id.is_empty() {
133 headers.insert(
134 OPENAI_PROJECT_HEADER,
135 self.project_id.as_str().parse().unwrap(),
136 );
137 }
138
139 headers.insert(
140 AUTHORIZATION,
141 format!("Bearer {}", self.api_key.expose_secret())
142 .as_str()
143 .parse()
144 .unwrap(),
145 );
146
147 headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
150
151 headers
152 }
153
154 fn url(&self, path: &str) -> String {
155 format!("{}{}", self.api_base, path)
156 }
157
158 fn api_base(&self) -> &str {
159 &self.api_base
160 }
161
162 fn api_key(&self) -> &SecretString {
163 &self.api_key
164 }
165
166 fn query(&self) -> Vec<(&str, &str)> {
167 vec![]
168 }
169}
170
/// Configuration for an Azure-hosted OpenAI deployment.
///
/// `#[serde(default)]` fills any field missing from the input with the value
/// from this type's `Default` implementation.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
    /// Value sent as the `api-version` query parameter.
    api_version: String,
    /// Deployment id inserted into the URL after `/openai/deployments/`.
    deployment_id: String,
    /// Base URL placed in front of `/openai/deployments/...`.
    api_base: String,
    /// API key, sent in the `api-key` header.
    api_key: SecretString,
}
180
181impl Default for AzureConfig {
182 fn default() -> Self {
183 Self {
184 api_base: Default::default(),
185 api_key: std::env::var("OPENAI_API_KEY")
186 .unwrap_or_else(|_| "".to_string())
187 .into(),
188 deployment_id: Default::default(),
189 api_version: Default::default(),
190 }
191 }
192}
193
194impl AzureConfig {
195 pub fn new() -> Self {
196 Default::default()
197 }
198
199 pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
200 self.api_version = api_version.into();
201 self
202 }
203
204 pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
205 self.deployment_id = deployment_id.into();
206 self
207 }
208
209 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
211 self.api_key = SecretString::from(api_key.into());
212 self
213 }
214
215 pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
217 self.api_base = api_base.into();
218 self
219 }
220}
221
222impl Config for AzureConfig {
223 fn headers(&self) -> HeaderMap {
224 let mut headers = HeaderMap::new();
225
226 headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
227
228 headers
229 }
230
231 fn url(&self, path: &str) -> String {
232 format!(
233 "{}/openai/deployments/{}{}",
234 self.api_base, self.deployment_id, path
235 )
236 }
237
238 fn api_base(&self) -> &str {
239 &self.api_base
240 }
241
242 fn api_key(&self) -> &SecretString {
243 &self.api_key
244 }
245
246 fn query(&self) -> Vec<(&str, &str)> {
247 vec![("api-version", &self.api_version)]
248 }
249}
250
#[cfg(test)]
mod test {
    use super::*;
    use crate::Client;
    use crate::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use std::sync::Arc;

    /// A client can be built from both `Box<dyn Config>` and `Arc<dyn Config>`,
    /// and a cloned client still sees the same configuration.
    #[test]
    fn test_client_creation() {
        // Set the key through the builder instead of `std::env::set_var`.
        // Mutating the process environment is `unsafe` because it races with
        // concurrent reads, and the other test in this module reads
        // `OPENAI_API_KEY` while the harness runs tests on several threads.
        let openai_config = OpenAIConfig::default().with_api_key("test");

        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Builds a chat-completion call on a type-erased client.
    ///
    /// The value returned by `create` is discarded without being awaited, so
    /// this presumably only checks that the call type-checks and sends no
    /// request. Confirm against `Client`.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    /// Clients built from different `Config` implementations share one type
    /// and can be moved into spawned tasks.
    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let openai_config = OpenAIConfig::default();
        let azure_config = AzureConfig::default();

        let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // The join handles are dropped on purpose; the tasks are not awaited.
        let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
    }
}