openai_ergonomic/
config.rs

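//! Configuration for the `OpenAI` ergonomic client.
//!
//! This module provides configuration options for the `OpenAI` client,
//! including API key management, base URLs, retry settings, Azure `OpenAI`
//! deployment settings, and an optional custom HTTP client (through which
//! timeouts, proxies and similar transport options are configured).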
5
6use crate::{errors::Result, Error};
7use reqwest_middleware::ClientWithMiddleware;
8use std::env;
9
10/// Configuration for the `OpenAI` client.
11///
12/// The configuration can be created from environment variables or
13/// manually constructed with the builder pattern.
14///
15/// # Environment Variables
16///
17/// ## Standard `OpenAI`
18/// - `OPENAI_API_KEY`: The `OpenAI` API key (required)
19/// - `OPENAI_API_BASE`: Custom base URL for the API (optional)
20/// - `OPENAI_ORGANIZATION`: Organization ID (optional)
21/// - `OPENAI_PROJECT`: Project ID (optional)
22/// - `OPENAI_MAX_RETRIES`: Maximum number of retries (optional, default: 3)
23///
24/// ## Azure `OpenAI`
25/// - `AZURE_OPENAI_API_KEY`: The Azure `OpenAI` API key (alternative to `OPENAI_API_KEY`)
26/// - `AZURE_OPENAI_ENDPOINT`: Azure `OpenAI` endpoint (e.g., `<https://my-resource.openai.azure.com>`)
27/// - `AZURE_OPENAI_DEPLOYMENT`: Deployment name (required for Azure)
28/// - `AZURE_OPENAI_API_VERSION`: API version (optional, default: 2024-02-01)
29///
30/// # Example
31///
32/// ```rust,ignore
33/// # use openai_ergonomic::Config;
34/// // From environment variables
35/// let config = Config::from_env().unwrap();
36///
37/// // Manual configuration for OpenAI
38/// let config = Config::builder()
39///     .api_key("your-api-key")
40///     .max_retries(5)
41///     .build();
42///
43/// // Manual configuration for Azure OpenAI
44/// let config = Config::builder()
45///     .api_key("your-azure-api-key")
46///     .api_base("https://my-resource.openai.azure.com")
47///     .azure_deployment("my-deployment")
48///     .azure_api_version("2024-02-01")
49///     .build();
50/// ```
51#[derive(Clone)]
52pub struct Config {
53    api_key: String,
54    api_base: String,
55    organization: Option<String>,
56    project: Option<String>,
57    max_retries: u32,
58    default_model: String,
59    http_client: Option<ClientWithMiddleware>,
60    azure_deployment: Option<String>,
61    azure_api_version: Option<String>,
62}
63
64impl std::fmt::Debug for Config {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("Config")
67            .field("api_key", &"***")
68            .field("api_base", &self.api_base)
69            .field("organization", &self.organization)
70            .field("project", &self.project)
71            .field("max_retries", &self.max_retries)
72            .field("default_model", &self.default_model)
73            .field(
74                "http_client",
75                &self.http_client.as_ref().map(|_| "<ClientWithMiddleware>"),
76            )
77            .field("azure_deployment", &self.azure_deployment)
78            .field("azure_api_version", &self.azure_api_version)
79            .finish()
80    }
81}
82
83impl Config {
84    /// Create a new configuration builder.
85    #[must_use]
86    pub fn builder() -> ConfigBuilder {
87        ConfigBuilder::default()
88    }
89
90    /// Create configuration from environment variables.
91    ///
92    /// Supports both standard `OpenAI` and Azure `OpenAI` configurations.
93    /// For Azure `OpenAI`, set `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_KEY`, and `AZURE_OPENAI_DEPLOYMENT`.
94    pub fn from_env() -> Result<Self> {
95        // Check for Azure OpenAI configuration first
96        let azure_endpoint = env::var("AZURE_OPENAI_ENDPOINT").ok();
97        let azure_deployment = env::var("AZURE_OPENAI_DEPLOYMENT").ok();
98        let azure_api_version = env::var("AZURE_OPENAI_API_VERSION").ok();
99
100        let (api_key, api_base) = if let Some(endpoint) = azure_endpoint {
101            // Azure OpenAI configuration
102            let key = env::var("AZURE_OPENAI_API_KEY")
103                .or_else(|_| env::var("OPENAI_API_KEY"))
104                .map_err(|_| {
105                    Error::Config(
106                        "AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable is required"
107                            .to_string(),
108                    )
109                })?;
110            // Trim trailing slash from Azure endpoint
111            let endpoint = endpoint.trim_end_matches('/').to_string();
112            (key, endpoint)
113        } else {
114            // Standard OpenAI configuration
115            let key = env::var("OPENAI_API_KEY").map_err(|_| {
116                Error::Config("OPENAI_API_KEY environment variable is required".to_string())
117            })?;
118            let base = env::var("OPENAI_API_BASE")
119                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
120            (key, base)
121        };
122
123        let organization = env::var("OPENAI_ORGANIZATION").ok();
124        let project = env::var("OPENAI_PROJECT").ok();
125
126        let max_retries = env::var("OPENAI_MAX_RETRIES")
127            .ok()
128            .and_then(|s| s.parse().ok())
129            .unwrap_or(3);
130
131        let default_model =
132            env::var("OPENAI_DEFAULT_MODEL").unwrap_or_else(|_| "gpt-4".to_string());
133
134        Ok(Self {
135            api_key,
136            api_base,
137            organization,
138            project,
139            max_retries,
140            default_model,
141            http_client: None,
142            azure_deployment,
143            azure_api_version,
144        })
145    }
146
147    /// Get the API key.
148    pub fn api_key(&self) -> &str {
149        &self.api_key
150    }
151
152    /// Get the API base URL.
153    pub fn api_base(&self) -> &str {
154        &self.api_base
155    }
156
157    /// Get the organization ID, if set.
158    pub fn organization(&self) -> Option<&str> {
159        self.organization.as_deref()
160    }
161
162    /// Get the project ID, if set.
163    pub fn project(&self) -> Option<&str> {
164        self.project.as_deref()
165    }
166
167    /// Get the maximum number of retries.
168    pub fn max_retries(&self) -> u32 {
169        self.max_retries
170    }
171
172    /// Get the default model to use.
173    pub fn default_model(&self) -> Option<&str> {
174        if self.default_model.is_empty() {
175            None
176        } else {
177            Some(&self.default_model)
178        }
179    }
180
181    /// Get the base URL, if different from default.
182    pub fn base_url(&self) -> Option<&str> {
183        if self.api_base == "https://api.openai.com/v1" {
184            None
185        } else {
186            Some(&self.api_base)
187        }
188    }
189
190    /// Get the organization ID, if set.
191    pub fn organization_id(&self) -> Option<&str> {
192        self.organization.as_deref()
193    }
194
195    /// Create an authorization header value.
196    pub fn auth_header(&self) -> String {
197        format!("Bearer {}", self.api_key)
198    }
199
200    /// Get the custom HTTP client, if set.
201    pub fn http_client(&self) -> Option<&ClientWithMiddleware> {
202        self.http_client.as_ref()
203    }
204
205    /// Get the Azure deployment name, if set.
206    pub fn azure_deployment(&self) -> Option<&str> {
207        self.azure_deployment.as_deref()
208    }
209
210    /// Get the Azure API version, if set.
211    pub fn azure_api_version(&self) -> Option<&str> {
212        self.azure_api_version.as_deref()
213    }
214
215    /// Check if this configuration is for Azure `OpenAI`.
216    pub fn is_azure(&self) -> bool {
217        self.azure_deployment.is_some() || self.api_base.contains(".openai.azure.com")
218    }
219}
220
221impl Default for Config {
222    fn default() -> Self {
223        Self {
224            api_key: String::new(),
225            api_base: "https://api.openai.com/v1".to_string(),
226            organization: None,
227            project: None,
228            max_retries: 3,
229            default_model: "gpt-4".to_string(),
230            http_client: None,
231            azure_deployment: None,
232            azure_api_version: None,
233        }
234    }
235}
236
237/// Builder for creating `OpenAI` client configuration.
238#[derive(Clone, Default)]
239pub struct ConfigBuilder {
240    api_key: Option<String>,
241    api_base: Option<String>,
242    organization: Option<String>,
243    project: Option<String>,
244    max_retries: Option<u32>,
245    default_model: Option<String>,
246    http_client: Option<ClientWithMiddleware>,
247    azure_deployment: Option<String>,
248    azure_api_version: Option<String>,
249}
250
251impl ConfigBuilder {
252    /// Set the API key.
253    #[must_use]
254    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
255        self.api_key = Some(api_key.into());
256        self
257    }
258
259    /// Set the API base URL.
260    #[must_use]
261    pub fn api_base(mut self, api_base: impl Into<String>) -> Self {
262        self.api_base = Some(api_base.into());
263        self
264    }
265
266    /// Set the organization ID.
267    #[must_use]
268    pub fn organization(mut self, organization: impl Into<String>) -> Self {
269        self.organization = Some(organization.into());
270        self
271    }
272
273    /// Set the project ID.
274    #[must_use]
275    pub fn project(mut self, project: impl Into<String>) -> Self {
276        self.project = Some(project.into());
277        self
278    }
279
280    /// Set the maximum number of retries.
281    #[must_use]
282    pub fn max_retries(mut self, max_retries: u32) -> Self {
283        self.max_retries = Some(max_retries);
284        self
285    }
286
287    /// Set the default model to use.
288    #[must_use]
289    pub fn default_model(mut self, default_model: impl Into<String>) -> Self {
290        self.default_model = Some(default_model.into());
291        self
292    }
293
294    /// Set a custom HTTP client.
295    ///
296    /// This allows you to provide a pre-configured `ClientWithMiddleware` with
297    /// custom settings like retry policies, connection pooling, proxies, etc.
298    ///
299    /// # Example
300    ///
301    /// ```rust,ignore
302    /// use reqwest_middleware::ClientBuilder;
303    /// use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
304    ///
305    /// let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
306    /// let client = ClientBuilder::new(reqwest::Client::new())
307    ///     .with(RetryTransientMiddleware::new_with_policy(retry_policy))
308    ///     .build();
309    ///
310    /// let config = Config::builder()
311    ///     .api_key("sk-...")
312    ///     .http_client(client)
313    ///     .build();
314    /// ```
315    #[must_use]
316    pub fn http_client(mut self, client: ClientWithMiddleware) -> Self {
317        self.http_client = Some(client);
318        self
319    }
320
321    /// Set the Azure deployment name.
322    ///
323    /// Required when using Azure `OpenAI`.
324    #[must_use]
325    pub fn azure_deployment(mut self, deployment: impl Into<String>) -> Self {
326        self.azure_deployment = Some(deployment.into());
327        self
328    }
329
330    /// Set the Azure API version.
331    ///
332    /// Defaults to "2024-02-01" if not specified.
333    #[must_use]
334    pub fn azure_api_version(mut self, version: impl Into<String>) -> Self {
335        self.azure_api_version = Some(version.into());
336        self
337    }
338
339    /// Build the configuration.
340    #[must_use]
341    pub fn build(self) -> Config {
342        Config {
343            api_key: self.api_key.unwrap_or_default(),
344            api_base: self
345                .api_base
346                .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
347            organization: self.organization,
348            project: self.project,
349            max_retries: self.max_retries.unwrap_or(3),
350            default_model: self.default_model.unwrap_or_else(|| "gpt-4".to_string()),
351            http_client: self.http_client,
352            azure_deployment: self.azure_deployment,
353            azure_api_version: self.azure_api_version,
354        }
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_config_builder() {
        let config = Config::builder().api_key("test-key").max_retries(5).build();

        assert_eq!(config.api_key(), "test-key");
        assert_eq!(config.max_retries(), 5);
        assert_eq!(config.api_base(), "https://api.openai.com/v1");
    }

    #[test]
    fn test_builder_sets_optional_fields() {
        let config = Config::builder()
            .api_key("k")
            .api_base("https://example.com/v1")
            .organization("org-1")
            .project("proj-1")
            .default_model("gpt-4o")
            .build();

        assert_eq!(config.api_base(), "https://example.com/v1");
        assert_eq!(config.base_url(), Some("https://example.com/v1"));
        assert_eq!(config.organization(), Some("org-1"));
        assert_eq!(config.organization_id(), Some("org-1"));
        assert_eq!(config.project(), Some("proj-1"));
        assert_eq!(config.default_model(), Some("gpt-4o"));
        assert!(!config.is_azure());
    }

    #[test]
    fn test_base_url_is_none_for_default_endpoint() {
        let config = Config::builder().api_key("k").build();

        assert_eq!(config.base_url(), None);
        assert_eq!(config.organization(), None);
        assert_eq!(config.organization_id(), None);
        assert_eq!(config.project(), None);
    }

    #[test]
    fn test_empty_default_model_is_none() {
        let config = Config::builder().default_model("").build();

        assert_eq!(config.default_model(), None);
    }

    #[test]
    fn test_azure_detection() {
        let by_deployment = Config::builder()
            .api_key("k")
            .azure_deployment("dep")
            .build();
        assert!(by_deployment.is_azure());
        assert_eq!(by_deployment.azure_deployment(), Some("dep"));

        let by_endpoint = Config::builder()
            .api_key("k")
            .api_base("https://my-resource.openai.azure.com")
            .azure_api_version("2024-02-01")
            .build();
        assert!(by_endpoint.is_azure());
        assert_eq!(by_endpoint.azure_deployment(), None);
        assert_eq!(by_endpoint.azure_api_version(), Some("2024-02-01"));

        assert!(!Config::default().is_azure());
    }

    #[test]
    fn test_auth_header() {
        let config = Config::builder().api_key("test-key").build();

        assert_eq!(config.auth_header(), "Bearer test-key");
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.max_retries(), 3);
        assert_eq!(config.default_model(), Some("gpt-4"));
    }

    #[test]
    fn test_config_with_custom_http_client() {
        let http_client = reqwest_middleware::ClientBuilder::new(
            reqwest::Client::builder()
                .timeout(Duration::from_secs(30))
                .build()
                .unwrap(),
        )
        .build();

        let config = Config::builder()
            .api_key("test-key")
            .http_client(http_client)
            .build();

        assert!(config.http_client().is_some());
    }

    #[test]
    fn test_config_without_custom_http_client() {
        let config = Config::builder().api_key("test-key").build();

        assert!(config.http_client().is_none());
    }

    #[test]
    fn test_config_debug_hides_sensitive_data() {
        let config = Config::builder().api_key("secret-key-12345").build();

        let debug_output = format!("{config:?}");

        // Should not contain the actual API key
        assert!(!debug_output.contains("secret-key-12345"));
        // Should contain the masked version
        assert!(debug_output.contains("***"));
    }

    #[test]
    fn test_config_debug_with_http_client() {
        let http_client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
        let config = Config::builder()
            .api_key("test-key")
            .http_client(http_client)
            .build();

        let debug_output = format!("{config:?}");

        // Should show placeholder for HTTP client
        assert!(debug_output.contains("<ClientWithMiddleware>"));
    }
}