openai_ergonomic/
config.rs1use crate::{errors::Result, Error};
7use reqwest_middleware::ClientWithMiddleware;
8use std::env;
9
10#[derive(Clone)]
37pub struct Config {
38 api_key: String,
39 api_base: String,
40 organization: Option<String>,
41 project: Option<String>,
42 max_retries: u32,
43 default_model: String,
44 http_client: Option<ClientWithMiddleware>,
45}
46
47impl std::fmt::Debug for Config {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("Config")
50 .field("api_key", &"***")
51 .field("api_base", &self.api_base)
52 .field("organization", &self.organization)
53 .field("project", &self.project)
54 .field("max_retries", &self.max_retries)
55 .field("default_model", &self.default_model)
56 .field(
57 "http_client",
58 &self.http_client.as_ref().map(|_| "<ClientWithMiddleware>"),
59 )
60 .finish()
61 }
62}
63
64impl Config {
65 #[must_use]
67 pub fn builder() -> ConfigBuilder {
68 ConfigBuilder::default()
69 }
70
71 pub fn from_env() -> Result<Self> {
73 let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
74 Error::Config("OPENAI_API_KEY environment variable is required".to_string())
75 })?;
76
77 let api_base =
78 env::var("OPENAI_API_BASE").unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
79
80 let organization = env::var("OPENAI_ORGANIZATION").ok();
81 let project = env::var("OPENAI_PROJECT").ok();
82
83 let max_retries = env::var("OPENAI_MAX_RETRIES")
84 .ok()
85 .and_then(|s| s.parse().ok())
86 .unwrap_or(3);
87
88 let default_model =
89 env::var("OPENAI_DEFAULT_MODEL").unwrap_or_else(|_| "gpt-4".to_string());
90
91 Ok(Self {
92 api_key,
93 api_base,
94 organization,
95 project,
96 max_retries,
97 default_model,
98 http_client: None,
99 })
100 }
101
102 pub fn api_key(&self) -> &str {
104 &self.api_key
105 }
106
107 pub fn api_base(&self) -> &str {
109 &self.api_base
110 }
111
112 pub fn organization(&self) -> Option<&str> {
114 self.organization.as_deref()
115 }
116
117 pub fn project(&self) -> Option<&str> {
119 self.project.as_deref()
120 }
121
122 pub fn max_retries(&self) -> u32 {
124 self.max_retries
125 }
126
127 pub fn default_model(&self) -> Option<&str> {
129 if self.default_model.is_empty() {
130 None
131 } else {
132 Some(&self.default_model)
133 }
134 }
135
136 pub fn base_url(&self) -> Option<&str> {
138 if self.api_base == "https://api.openai.com/v1" {
139 None
140 } else {
141 Some(&self.api_base)
142 }
143 }
144
145 pub fn organization_id(&self) -> Option<&str> {
147 self.organization.as_deref()
148 }
149
150 pub fn auth_header(&self) -> String {
152 format!("Bearer {}", self.api_key)
153 }
154
155 pub fn http_client(&self) -> Option<&ClientWithMiddleware> {
157 self.http_client.as_ref()
158 }
159}
160
161impl Default for Config {
162 fn default() -> Self {
163 Self {
164 api_key: String::new(),
165 api_base: "https://api.openai.com/v1".to_string(),
166 organization: None,
167 project: None,
168 max_retries: 3,
169 default_model: "gpt-4".to_string(),
170 http_client: None,
171 }
172 }
173}
174
175#[derive(Clone, Default)]
177pub struct ConfigBuilder {
178 api_key: Option<String>,
179 api_base: Option<String>,
180 organization: Option<String>,
181 project: Option<String>,
182 max_retries: Option<u32>,
183 default_model: Option<String>,
184 http_client: Option<ClientWithMiddleware>,
185}
186
187impl ConfigBuilder {
188 #[must_use]
190 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
191 self.api_key = Some(api_key.into());
192 self
193 }
194
195 #[must_use]
197 pub fn api_base(mut self, api_base: impl Into<String>) -> Self {
198 self.api_base = Some(api_base.into());
199 self
200 }
201
202 #[must_use]
204 pub fn organization(mut self, organization: impl Into<String>) -> Self {
205 self.organization = Some(organization.into());
206 self
207 }
208
209 #[must_use]
211 pub fn project(mut self, project: impl Into<String>) -> Self {
212 self.project = Some(project.into());
213 self
214 }
215
216 #[must_use]
218 pub fn max_retries(mut self, max_retries: u32) -> Self {
219 self.max_retries = Some(max_retries);
220 self
221 }
222
223 #[must_use]
225 pub fn default_model(mut self, default_model: impl Into<String>) -> Self {
226 self.default_model = Some(default_model.into());
227 self
228 }
229
230 #[must_use]
252 pub fn http_client(mut self, client: ClientWithMiddleware) -> Self {
253 self.http_client = Some(client);
254 self
255 }
256
257 #[must_use]
259 pub fn build(self) -> Config {
260 Config {
261 api_key: self.api_key.unwrap_or_default(),
262 api_base: self
263 .api_base
264 .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
265 organization: self.organization,
266 project: self.project,
267 max_retries: self.max_retries.unwrap_or(3),
268 default_model: self.default_model.unwrap_or_else(|| "gpt-4".to_string()),
269 http_client: self.http_client,
270 }
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_config_builder() {
        let config = Config::builder().api_key("test-key").max_retries(5).build();

        assert_eq!(config.api_key(), "test-key");
        assert_eq!(config.max_retries(), 5);
        assert_eq!(config.api_base(), "https://api.openai.com/v1");
    }

    #[test]
    fn test_auth_header() {
        let config = Config::builder().api_key("test-key").build();

        assert_eq!(config.auth_header(), "Bearer test-key");
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.max_retries(), 3);
        assert_eq!(config.default_model(), Some("gpt-4"));
        assert_eq!(config.api_key(), "");
        assert_eq!(config.base_url(), None);
    }

    #[test]
    fn test_base_url_is_none_for_default_endpoint() {
        let config = Config::builder().api_key("test-key").build();

        assert_eq!(config.base_url(), None);
    }

    #[test]
    fn test_base_url_returns_custom_endpoint() {
        let config = Config::builder()
            .api_key("test-key")
            .api_base("http://localhost:8080/v1")
            .build();

        assert_eq!(config.api_base(), "http://localhost:8080/v1");
        assert_eq!(config.base_url(), Some("http://localhost:8080/v1"));
    }

    #[test]
    fn test_empty_default_model_is_none() {
        let config = Config::builder().default_model("").build();

        assert_eq!(config.default_model(), None);
    }

    #[test]
    fn test_organization_and_project_accessors() {
        let config = Config::builder()
            .organization("org-123")
            .project("proj-456")
            .build();

        assert_eq!(config.organization(), Some("org-123"));
        assert_eq!(config.organization_id(), Some("org-123"));
        assert_eq!(config.project(), Some("proj-456"));
    }

    #[test]
    fn test_config_with_custom_http_client() {
        let http_client = reqwest_middleware::ClientBuilder::new(
            reqwest::Client::builder()
                .timeout(Duration::from_secs(30))
                .build()
                .unwrap(),
        )
        .build();

        let config = Config::builder()
            .api_key("test-key")
            .http_client(http_client)
            .build();

        assert!(config.http_client().is_some());
    }

    #[test]
    fn test_config_without_custom_http_client() {
        let config = Config::builder().api_key("test-key").build();

        assert!(config.http_client().is_none());
    }

    #[test]
    fn test_config_debug_hides_sensitive_data() {
        let config = Config::builder().api_key("secret-key-12345").build();

        let debug_output = format!("{config:?}");

        assert!(!debug_output.contains("secret-key-12345"));
        assert!(debug_output.contains("***"));
    }

    #[test]
    fn test_config_debug_with_http_client() {
        let http_client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
        let config = Config::builder()
            .api_key("test-key")
            .http_client(http_client)
            .build();

        let debug_output = format!("{config:?}");

        assert!(debug_output.contains("<ClientWithMiddleware>"));
    }
}