openai_ergonomic/
config.rs1use crate::{errors::Result, Error};
7use reqwest_middleware::ClientWithMiddleware;
8use std::env;
9
10#[derive(Clone)]
52pub struct Config {
53 api_key: String,
54 api_base: String,
55 organization: Option<String>,
56 project: Option<String>,
57 max_retries: u32,
58 default_model: String,
59 http_client: Option<ClientWithMiddleware>,
60 azure_deployment: Option<String>,
61 azure_api_version: Option<String>,
62}
63
64impl std::fmt::Debug for Config {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("Config")
67 .field("api_key", &"***")
68 .field("api_base", &self.api_base)
69 .field("organization", &self.organization)
70 .field("project", &self.project)
71 .field("max_retries", &self.max_retries)
72 .field("default_model", &self.default_model)
73 .field(
74 "http_client",
75 &self.http_client.as_ref().map(|_| "<ClientWithMiddleware>"),
76 )
77 .field("azure_deployment", &self.azure_deployment)
78 .field("azure_api_version", &self.azure_api_version)
79 .finish()
80 }
81}
82
83impl Config {
84 #[must_use]
86 pub fn builder() -> ConfigBuilder {
87 ConfigBuilder::default()
88 }
89
90 pub fn from_env() -> Result<Self> {
95 let azure_endpoint = env::var("AZURE_OPENAI_ENDPOINT").ok();
97 let azure_deployment = env::var("AZURE_OPENAI_DEPLOYMENT").ok();
98 let azure_api_version = env::var("AZURE_OPENAI_API_VERSION").ok();
99
100 let (api_key, api_base) = if let Some(endpoint) = azure_endpoint {
101 let key = env::var("AZURE_OPENAI_API_KEY")
103 .or_else(|_| env::var("OPENAI_API_KEY"))
104 .map_err(|_| {
105 Error::Config(
106 "AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable is required"
107 .to_string(),
108 )
109 })?;
110 let endpoint = endpoint.trim_end_matches('/').to_string();
112 (key, endpoint)
113 } else {
114 let key = env::var("OPENAI_API_KEY").map_err(|_| {
116 Error::Config("OPENAI_API_KEY environment variable is required".to_string())
117 })?;
118 let base = env::var("OPENAI_API_BASE")
119 .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
120 (key, base)
121 };
122
123 let organization = env::var("OPENAI_ORGANIZATION").ok();
124 let project = env::var("OPENAI_PROJECT").ok();
125
126 let max_retries = env::var("OPENAI_MAX_RETRIES")
127 .ok()
128 .and_then(|s| s.parse().ok())
129 .unwrap_or(3);
130
131 let default_model =
132 env::var("OPENAI_DEFAULT_MODEL").unwrap_or_else(|_| "gpt-4".to_string());
133
134 Ok(Self {
135 api_key,
136 api_base,
137 organization,
138 project,
139 max_retries,
140 default_model,
141 http_client: None,
142 azure_deployment,
143 azure_api_version,
144 })
145 }
146
147 pub fn api_key(&self) -> &str {
149 &self.api_key
150 }
151
152 pub fn api_base(&self) -> &str {
154 &self.api_base
155 }
156
157 pub fn organization(&self) -> Option<&str> {
159 self.organization.as_deref()
160 }
161
162 pub fn project(&self) -> Option<&str> {
164 self.project.as_deref()
165 }
166
167 pub fn max_retries(&self) -> u32 {
169 self.max_retries
170 }
171
172 pub fn default_model(&self) -> Option<&str> {
174 if self.default_model.is_empty() {
175 None
176 } else {
177 Some(&self.default_model)
178 }
179 }
180
181 pub fn base_url(&self) -> Option<&str> {
183 if self.api_base == "https://api.openai.com/v1" {
184 None
185 } else {
186 Some(&self.api_base)
187 }
188 }
189
190 pub fn organization_id(&self) -> Option<&str> {
192 self.organization.as_deref()
193 }
194
195 pub fn auth_header(&self) -> String {
197 format!("Bearer {}", self.api_key)
198 }
199
200 pub fn http_client(&self) -> Option<&ClientWithMiddleware> {
202 self.http_client.as_ref()
203 }
204
205 pub fn azure_deployment(&self) -> Option<&str> {
207 self.azure_deployment.as_deref()
208 }
209
210 pub fn azure_api_version(&self) -> Option<&str> {
212 self.azure_api_version.as_deref()
213 }
214
215 pub fn is_azure(&self) -> bool {
217 self.azure_deployment.is_some() || self.api_base.contains(".openai.azure.com")
218 }
219}
220
221impl Default for Config {
222 fn default() -> Self {
223 Self {
224 api_key: String::new(),
225 api_base: "https://api.openai.com/v1".to_string(),
226 organization: None,
227 project: None,
228 max_retries: 3,
229 default_model: "gpt-4".to_string(),
230 http_client: None,
231 azure_deployment: None,
232 azure_api_version: None,
233 }
234 }
235}
236
237#[derive(Clone, Default)]
239pub struct ConfigBuilder {
240 api_key: Option<String>,
241 api_base: Option<String>,
242 organization: Option<String>,
243 project: Option<String>,
244 max_retries: Option<u32>,
245 default_model: Option<String>,
246 http_client: Option<ClientWithMiddleware>,
247 azure_deployment: Option<String>,
248 azure_api_version: Option<String>,
249}
250
251impl ConfigBuilder {
252 #[must_use]
254 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
255 self.api_key = Some(api_key.into());
256 self
257 }
258
259 #[must_use]
261 pub fn api_base(mut self, api_base: impl Into<String>) -> Self {
262 self.api_base = Some(api_base.into());
263 self
264 }
265
266 #[must_use]
268 pub fn organization(mut self, organization: impl Into<String>) -> Self {
269 self.organization = Some(organization.into());
270 self
271 }
272
273 #[must_use]
275 pub fn project(mut self, project: impl Into<String>) -> Self {
276 self.project = Some(project.into());
277 self
278 }
279
280 #[must_use]
282 pub fn max_retries(mut self, max_retries: u32) -> Self {
283 self.max_retries = Some(max_retries);
284 self
285 }
286
287 #[must_use]
289 pub fn default_model(mut self, default_model: impl Into<String>) -> Self {
290 self.default_model = Some(default_model.into());
291 self
292 }
293
294 #[must_use]
316 pub fn http_client(mut self, client: ClientWithMiddleware) -> Self {
317 self.http_client = Some(client);
318 self
319 }
320
321 #[must_use]
325 pub fn azure_deployment(mut self, deployment: impl Into<String>) -> Self {
326 self.azure_deployment = Some(deployment.into());
327 self
328 }
329
330 #[must_use]
334 pub fn azure_api_version(mut self, version: impl Into<String>) -> Self {
335 self.azure_api_version = Some(version.into());
336 self
337 }
338
339 #[must_use]
341 pub fn build(self) -> Config {
342 Config {
343 api_key: self.api_key.unwrap_or_default(),
344 api_base: self
345 .api_base
346 .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
347 organization: self.organization,
348 project: self.project,
349 max_retries: self.max_retries.unwrap_or(3),
350 default_model: self.default_model.unwrap_or_else(|| "gpt-4".to_string()),
351 http_client: self.http_client,
352 azure_deployment: self.azure_deployment,
353 azure_api_version: self.azure_api_version,
354 }
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builder values override defaults; unset fields keep them.
    #[test]
    fn test_config_builder() {
        let config = Config::builder().api_key("test-key").max_retries(5).build();

        assert_eq!(config.api_key(), "test-key");
        assert_eq!(config.max_retries(), 5);
        assert_eq!(config.api_base(), "https://api.openai.com/v1");
    }

    #[test]
    fn test_auth_header() {
        let config = Config::builder().api_key("test-key").build();

        assert_eq!(config.auth_header(), "Bearer test-key");
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.max_retries(), 3);
        assert_eq!(config.default_model(), Some("gpt-4"));
    }

    /// An explicitly empty default model is reported as "no default".
    #[test]
    fn test_empty_default_model_is_none() {
        let config = Config::builder().default_model("").build();
        assert_eq!(config.default_model(), None);
    }

    /// `base_url` is hidden for the stock endpoint and exposed for custom ones.
    #[test]
    fn test_base_url_only_for_custom_endpoint() {
        let stock = Config::builder().api_key("test-key").build();
        assert_eq!(stock.base_url(), None);

        let custom = Config::builder()
            .api_key("test-key")
            .api_base("http://localhost:8080/v1")
            .build();
        assert_eq!(custom.base_url(), Some("http://localhost:8080/v1"));
    }

    #[test]
    fn test_organization_and_project_accessors() {
        let config = Config::builder()
            .organization("org-123")
            .project("proj-456")
            .build();

        assert_eq!(config.organization(), Some("org-123"));
        assert_eq!(config.organization_id(), Some("org-123"));
        assert_eq!(config.project(), Some("proj-456"));
    }

    /// Azure is detected either from a deployment name or from the endpoint host.
    #[test]
    fn test_is_azure_detection() {
        let plain = Config::builder().api_key("k").build();
        assert!(!plain.is_azure());

        let by_deployment = Config::builder()
            .api_key("k")
            .azure_deployment("my-deployment")
            .azure_api_version("2024-02-01")
            .build();
        assert!(by_deployment.is_azure());
        assert_eq!(by_deployment.azure_deployment(), Some("my-deployment"));
        assert_eq!(by_deployment.azure_api_version(), Some("2024-02-01"));

        let by_host = Config::builder()
            .api_key("k")
            .api_base("https://my-resource.openai.azure.com")
            .build();
        assert!(by_host.is_azure());
        assert_eq!(by_host.azure_deployment(), None);
    }

    #[test]
    fn test_config_with_custom_http_client() {
        let http_client = reqwest_middleware::ClientBuilder::new(
            reqwest::Client::builder()
                .timeout(Duration::from_secs(30))
                .build()
                .unwrap(),
        )
        .build();

        let config = Config::builder()
            .api_key("test-key")
            .http_client(http_client)
            .build();

        assert!(config.http_client().is_some());
    }

    #[test]
    fn test_config_without_custom_http_client() {
        let config = Config::builder().api_key("test-key").build();

        assert!(config.http_client().is_none());
    }

    /// The API key must never appear in `Debug` output.
    #[test]
    fn test_config_debug_hides_sensitive_data() {
        let config = Config::builder().api_key("secret-key-12345").build();

        let debug_output = format!("{config:?}");

        assert!(!debug_output.contains("secret-key-12345"));
        assert!(debug_output.contains("***"));
    }

    #[test]
    fn test_config_debug_with_http_client() {
        let http_client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
        let config = Config::builder()
            .api_key("test-key")
            .http_client(http_client)
            .build();

        let debug_output = format!("{config:?}");

        assert!(debug_output.contains("<ClientWithMiddleware>"));
    }
}