openai_ergonomic/
config.rs

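//! Configuration for the `OpenAI` ergonomic client.
//!
//! This module provides configuration options for the `OpenAI` client,
//! including API key management, base URLs, organization and project IDs,
//! retry settings, the default model, and an optional custom HTTP client
//! (request timeouts are configured on that client).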
5
6use crate::{errors::Result, Error};
7use reqwest_middleware::ClientWithMiddleware;
8use std::env;
9
10/// Configuration for the `OpenAI` client.
11///
12/// The configuration can be created from environment variables or
13/// manually constructed with the builder pattern.
14///
15/// # Environment Variables
16///
17/// - `OPENAI_API_KEY`: The `OpenAI` API key (required)
18/// - `OPENAI_API_BASE`: Custom base URL for the API (optional)
19/// - `OPENAI_ORGANIZATION`: Organization ID (optional)
20/// - `OPENAI_PROJECT`: Project ID (optional)
21/// - `OPENAI_MAX_RETRIES`: Maximum number of retries (optional, default: 3)
22///
23/// # Example
24///
25/// ```rust,ignore
26/// # use openai_ergonomic::Config;
27/// // From environment variables
28/// let config = Config::from_env().unwrap();
29///
30/// // Manual configuration
31/// let config = Config::builder()
32///     .api_key("your-api-key")
33///     .max_retries(5)
34///     .build();
35/// ```
36#[derive(Clone)]
37pub struct Config {
38    api_key: String,
39    api_base: String,
40    organization: Option<String>,
41    project: Option<String>,
42    max_retries: u32,
43    default_model: String,
44    http_client: Option<ClientWithMiddleware>,
45}
46
47impl std::fmt::Debug for Config {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("Config")
50            .field("api_key", &"***")
51            .field("api_base", &self.api_base)
52            .field("organization", &self.organization)
53            .field("project", &self.project)
54            .field("max_retries", &self.max_retries)
55            .field("default_model", &self.default_model)
56            .field(
57                "http_client",
58                &self.http_client.as_ref().map(|_| "<ClientWithMiddleware>"),
59            )
60            .finish()
61    }
62}
63
64impl Config {
65    /// Create a new configuration builder.
66    #[must_use]
67    pub fn builder() -> ConfigBuilder {
68        ConfigBuilder::default()
69    }
70
71    /// Create configuration from environment variables.
72    pub fn from_env() -> Result<Self> {
73        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
74            Error::Config("OPENAI_API_KEY environment variable is required".to_string())
75        })?;
76
77        let api_base =
78            env::var("OPENAI_API_BASE").unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
79
80        let organization = env::var("OPENAI_ORGANIZATION").ok();
81        let project = env::var("OPENAI_PROJECT").ok();
82
83        let max_retries = env::var("OPENAI_MAX_RETRIES")
84            .ok()
85            .and_then(|s| s.parse().ok())
86            .unwrap_or(3);
87
88        let default_model =
89            env::var("OPENAI_DEFAULT_MODEL").unwrap_or_else(|_| "gpt-4".to_string());
90
91        Ok(Self {
92            api_key,
93            api_base,
94            organization,
95            project,
96            max_retries,
97            default_model,
98            http_client: None,
99        })
100    }
101
102    /// Get the API key.
103    pub fn api_key(&self) -> &str {
104        &self.api_key
105    }
106
107    /// Get the API base URL.
108    pub fn api_base(&self) -> &str {
109        &self.api_base
110    }
111
112    /// Get the organization ID, if set.
113    pub fn organization(&self) -> Option<&str> {
114        self.organization.as_deref()
115    }
116
117    /// Get the project ID, if set.
118    pub fn project(&self) -> Option<&str> {
119        self.project.as_deref()
120    }
121
122    /// Get the maximum number of retries.
123    pub fn max_retries(&self) -> u32 {
124        self.max_retries
125    }
126
127    /// Get the default model to use.
128    pub fn default_model(&self) -> Option<&str> {
129        if self.default_model.is_empty() {
130            None
131        } else {
132            Some(&self.default_model)
133        }
134    }
135
136    /// Get the base URL, if different from default.
137    pub fn base_url(&self) -> Option<&str> {
138        if self.api_base == "https://api.openai.com/v1" {
139            None
140        } else {
141            Some(&self.api_base)
142        }
143    }
144
145    /// Get the organization ID, if set.
146    pub fn organization_id(&self) -> Option<&str> {
147        self.organization.as_deref()
148    }
149
150    /// Create an authorization header value.
151    pub fn auth_header(&self) -> String {
152        format!("Bearer {}", self.api_key)
153    }
154
155    /// Get the custom HTTP client, if set.
156    pub fn http_client(&self) -> Option<&ClientWithMiddleware> {
157        self.http_client.as_ref()
158    }
159}
160
161impl Default for Config {
162    fn default() -> Self {
163        Self {
164            api_key: String::new(),
165            api_base: "https://api.openai.com/v1".to_string(),
166            organization: None,
167            project: None,
168            max_retries: 3,
169            default_model: "gpt-4".to_string(),
170            http_client: None,
171        }
172    }
173}
174
175/// Builder for creating `OpenAI` client configuration.
176#[derive(Clone, Default)]
177pub struct ConfigBuilder {
178    api_key: Option<String>,
179    api_base: Option<String>,
180    organization: Option<String>,
181    project: Option<String>,
182    max_retries: Option<u32>,
183    default_model: Option<String>,
184    http_client: Option<ClientWithMiddleware>,
185}
186
187impl ConfigBuilder {
188    /// Set the API key.
189    #[must_use]
190    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
191        self.api_key = Some(api_key.into());
192        self
193    }
194
195    /// Set the API base URL.
196    #[must_use]
197    pub fn api_base(mut self, api_base: impl Into<String>) -> Self {
198        self.api_base = Some(api_base.into());
199        self
200    }
201
202    /// Set the organization ID.
203    #[must_use]
204    pub fn organization(mut self, organization: impl Into<String>) -> Self {
205        self.organization = Some(organization.into());
206        self
207    }
208
209    /// Set the project ID.
210    #[must_use]
211    pub fn project(mut self, project: impl Into<String>) -> Self {
212        self.project = Some(project.into());
213        self
214    }
215
216    /// Set the maximum number of retries.
217    #[must_use]
218    pub fn max_retries(mut self, max_retries: u32) -> Self {
219        self.max_retries = Some(max_retries);
220        self
221    }
222
223    /// Set the default model to use.
224    #[must_use]
225    pub fn default_model(mut self, default_model: impl Into<String>) -> Self {
226        self.default_model = Some(default_model.into());
227        self
228    }
229
230    /// Set a custom HTTP client.
231    ///
232    /// This allows you to provide a pre-configured `ClientWithMiddleware` with
233    /// custom settings like retry policies, connection pooling, proxies, etc.
234    ///
235    /// # Example
236    ///
237    /// ```rust,ignore
238    /// use reqwest_middleware::ClientBuilder;
239    /// use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
240    ///
241    /// let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
242    /// let client = ClientBuilder::new(reqwest::Client::new())
243    ///     .with(RetryTransientMiddleware::new_with_policy(retry_policy))
244    ///     .build();
245    ///
246    /// let config = Config::builder()
247    ///     .api_key("sk-...")
248    ///     .http_client(client)
249    ///     .build();
250    /// ```
251    #[must_use]
252    pub fn http_client(mut self, client: ClientWithMiddleware) -> Self {
253        self.http_client = Some(client);
254        self
255    }
256
257    /// Build the configuration.
258    #[must_use]
259    pub fn build(self) -> Config {
260        Config {
261            api_key: self.api_key.unwrap_or_default(),
262            api_base: self
263                .api_base
264                .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
265            organization: self.organization,
266            project: self.project,
267            max_retries: self.max_retries.unwrap_or(3),
268            default_model: self.default_model.unwrap_or_else(|| "gpt-4".to_string()),
269            http_client: self.http_client,
270        }
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_config_builder() {
        let config = Config::builder().api_key("test-key").max_retries(5).build();

        assert_eq!(config.api_key(), "test-key");
        assert_eq!(config.max_retries(), 5);
        assert_eq!(config.api_base(), "https://api.openai.com/v1");
    }

    #[test]
    fn test_auth_header() {
        let config = Config::builder().api_key("test-key").build();

        assert_eq!(config.auth_header(), "Bearer test-key");
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.max_retries(), 3);
        assert_eq!(config.default_model(), Some("gpt-4"));
    }

    // The defaults are spelled out separately in `Default` and in the builder;
    // this keeps the two from drifting apart.
    #[test]
    fn test_builder_defaults_match_default_config() {
        let built = Config::builder().build();
        let default = Config::default();

        assert_eq!(built.api_key(), default.api_key());
        assert_eq!(built.api_base(), default.api_base());
        assert_eq!(built.max_retries(), default.max_retries());
        assert_eq!(built.default_model(), default.default_model());
        assert_eq!(built.organization(), default.organization());
        assert_eq!(built.project(), default.project());
    }

    #[test]
    fn test_base_url_only_reported_when_custom() {
        let standard = Config::builder().api_key("test-key").build();
        assert_eq!(standard.base_url(), None);

        let custom = Config::builder()
            .api_key("test-key")
            .api_base("http://localhost:8080/v1")
            .build();
        assert_eq!(custom.api_base(), "http://localhost:8080/v1");
        assert_eq!(custom.base_url(), Some("http://localhost:8080/v1"));
    }

    #[test]
    fn test_empty_default_model_is_none() {
        let config = Config::builder().api_key("test-key").default_model("").build();

        assert_eq!(config.default_model(), None);
    }

    #[test]
    fn test_organization_and_project() {
        let config = Config::builder()
            .api_key("test-key")
            .organization("org-1")
            .project("proj-1")
            .build();
        assert_eq!(config.organization(), Some("org-1"));
        assert_eq!(config.organization_id(), Some("org-1"));
        assert_eq!(config.project(), Some("proj-1"));

        let bare = Config::builder().api_key("test-key").build();
        assert_eq!(bare.organization(), None);
        assert_eq!(bare.organization_id(), None);
        assert_eq!(bare.project(), None);
    }

    #[test]
    fn test_config_with_custom_http_client() {
        let http_client = reqwest_middleware::ClientBuilder::new(
            reqwest::Client::builder()
                .timeout(Duration::from_secs(30))
                .build()
                .unwrap(),
        )
        .build();

        let config = Config::builder()
            .api_key("test-key")
            .http_client(http_client)
            .build();

        assert!(config.http_client().is_some());
    }

    #[test]
    fn test_config_without_custom_http_client() {
        let config = Config::builder().api_key("test-key").build();

        assert!(config.http_client().is_none());
    }

    #[test]
    fn test_config_debug_hides_sensitive_data() {
        let config = Config::builder().api_key("secret-key-12345").build();

        let debug_output = format!("{config:?}");

        // Should not contain the actual API key
        assert!(!debug_output.contains("secret-key-12345"));
        // Should contain the masked version
        assert!(debug_output.contains("***"));
    }

    #[test]
    fn test_config_debug_with_http_client() {
        let http_client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
        let config = Config::builder()
            .api_key("test-key")
            .http_client(http_client)
            .build();

        let debug_output = format!("{config:?}");

        // Should show placeholder for HTTP client
        assert!(debug_output.contains("<ClientWithMiddleware>"));
    }
}
350        assert!(debug_output.contains("<ClientWithMiddleware>"));
351    }
352}