Skip to main content

openai_compat/
config.rs

1//! Client configuration and builder, mirroring `_client.py` env-var
2//! fallbacks and `_constants.py` defaults.
3
4use std::time::Duration;
5
6use crate::azure::AzureAuth;
7use crate::client::Client;
8use crate::error::OpenAIError;
9
/// Default API root, used when neither [`ClientBuilder::base_url`] nor the
/// `OPENAI_BASE_URL` environment variable supplies a non-blank value.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Default read timeout (600s), from `_constants.py::DEFAULT_TIMEOUT`.
/// Applied per read operation rather than as a total deadline; see
/// [`ClientBuilder::timeout`].
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);
/// Default connect timeout (5s), from `_constants.py::DEFAULT_TIMEOUT`.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Default retry count, from `_constants.py::DEFAULT_MAX_RETRIES`.
pub const DEFAULT_MAX_RETRIES: u32 = 2;
17
18/// Azure-specific settings resolved at build time.
19#[derive(Clone)]
20pub(crate) struct AzureSettings {
21    pub auth: AzureAuth,
22    /// Pinned deployment. Deployments endpoints are routed to
23    /// `/deployments/{deployment}{path}`; when unset, the deployment is
24    /// derived from the request body's `model`. Non-deployment endpoints
25    /// (e.g. `/models`, `/files`) are never given a deployment segment,
26    /// mirroring `lib/azure.py::_prepare_url`.
27    pub deployment: Option<String>,
28}
29
30/// Resolved client configuration.
31#[derive(Clone)]
32pub struct Config {
33    pub(crate) api_key: String,
34    pub(crate) base_url: String,
35    pub(crate) organization: Option<String>,
36    pub(crate) project: Option<String>,
37    pub(crate) timeout: Duration,
38    pub(crate) connect_timeout: Duration,
39    pub(crate) max_retries: u32,
40    pub(crate) default_headers: Vec<(String, String)>,
41    pub(crate) default_query: Vec<(String, String)>,
42    pub(crate) azure: Option<AzureSettings>,
43}
44
45impl std::fmt::Debug for Config {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("Config")
48            .field("api_key", &"[REDACTED]")
49            .field("base_url", &self.base_url)
50            .field("organization", &self.organization)
51            .field("project", &self.project)
52            .field("timeout", &self.timeout)
53            .field("connect_timeout", &self.connect_timeout)
54            .field("max_retries", &self.max_retries)
55            .finish()
56    }
57}
58
/// Builder for [`Client`]. Unset fields fall back to the same environment
/// variables the Python SDK uses: `OPENAI_API_KEY`, `OPENAI_BASE_URL`,
/// `OPENAI_ORG_ID`, `OPENAI_PROJECT_ID`. Azure mode (see
/// [`ClientBuilder::azure`]) has its own fallbacks, documented there.
///
/// Every setter consumes and returns the builder, so a chain that is never
/// passed to [`ClientBuilder::build`] does nothing; `#[must_use]` flags that.
#[must_use = "a ClientBuilder does nothing until `.build()` is called"]
#[derive(Default, Clone)]
pub struct ClientBuilder {
    // Explicit settings; `None` means "resolve from env / default at build".
    api_key: Option<String>,
    base_url: Option<String>,
    organization: Option<String>,
    project: Option<String>,
    timeout: Option<Duration>,
    connect_timeout: Option<Duration>,
    max_retries: Option<u32>,
    // Accumulated in call order by `header()`.
    default_headers: Vec<(String, String)>,
    // Azure mode is selected by `azure_endpoint` being set.
    azure_endpoint: Option<String>,
    azure_api_version: Option<String>,
    azure_deployment: Option<String>,
    azure_ad_token: Option<String>,
}
77
78impl std::fmt::Debug for ClientBuilder {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("ClientBuilder")
81            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
82            .field("base_url", &self.base_url)
83            .field("organization", &self.organization)
84            .field("project", &self.project)
85            .field("timeout", &self.timeout)
86            .field("connect_timeout", &self.connect_timeout)
87            .field("max_retries", &self.max_retries)
88            .field("azure_endpoint", &self.azure_endpoint)
89            .field("azure_api_version", &self.azure_api_version)
90            .field("azure_deployment", &self.azure_deployment)
91            .field(
92                "azure_ad_token",
93                &self.azure_ad_token.as_ref().map(|_| "[REDACTED]"),
94            )
95            .finish()
96    }
97}
98
99impl ClientBuilder {
100    pub fn new() -> Self {
101        Self::default()
102    }
103
104    /// API key used in the `Authorization: Bearer` header.
105    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
106        self.api_key = Some(api_key.into());
107        self
108    }
109
110    /// Base URL of the API, e.g. `https://api.openai.com/v1` or any
111    /// OpenAI-compatible provider endpoint.
112    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
113        self.base_url = Some(base_url.into());
114        self
115    }
116
117    /// Sent as the `OpenAI-Organization` header.
118    pub fn organization(mut self, organization: impl Into<String>) -> Self {
119        self.organization = Some(organization.into());
120        self
121    }
122
123    /// Sent as the `OpenAI-Project` header.
124    pub fn project(mut self, project: impl Into<String>) -> Self {
125        self.project = Some(project.into());
126        self
127    }
128
129    /// Read timeout, applied per read operation (mirroring the Python SDK's
130    /// httpx behavior) so streaming responses are not cut off by a total
131    /// deadline. Defaults to 600 seconds.
132    pub fn timeout(mut self, timeout: Duration) -> Self {
133        self.timeout = Some(timeout);
134        self
135    }
136
137    /// Connection timeout. Defaults to 5 seconds.
138    pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
139        self.connect_timeout = Some(connect_timeout);
140        self
141    }
142
143    /// Maximum number of retries for retryable failures. Defaults to 2.
144    pub fn max_retries(mut self, max_retries: u32) -> Self {
145        self.max_retries = Some(max_retries);
146        self
147    }
148
149    /// Add a header sent with every request.
150    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
151        self.default_headers.push((name.into(), value.into()));
152        self
153    }
154
155    /// Target an Azure OpenAI resource, mirroring the Python `AzureOpenAI`
156    /// client: requests go to `{endpoint}/openai[...]` with an `api-version`
157    /// query parameter, authenticated via `api-key` header (or a bearer
158    /// token set with [`ClientBuilder::azure_ad_token`]).
159    ///
160    /// The API key falls back to the `AZURE_OPENAI_API_KEY` environment
161    /// variable; `api_version` may be empty to use `OPENAI_API_VERSION`.
162    pub fn azure(mut self, endpoint: impl Into<String>, api_version: impl Into<String>) -> Self {
163        self.azure_endpoint = Some(endpoint.into());
164        let api_version = api_version.into();
165        if !api_version.is_empty() {
166            self.azure_api_version = Some(api_version);
167        }
168        self
169    }
170
171    /// Pin all requests to a specific Azure deployment
172    /// (`{endpoint}/openai/deployments/{deployment}`). Without this, the
173    /// deployment is derived per request from the body's `model` field.
174    pub fn azure_deployment(mut self, deployment: impl Into<String>) -> Self {
175        self.azure_deployment = Some(deployment.into());
176        self
177    }
178
179    /// Authenticate to Azure with an Entra ID (Azure AD) bearer token
180    /// instead of an API key. Falls back to `AZURE_OPENAI_AD_TOKEN`.
181    pub fn azure_ad_token(mut self, token: impl Into<String>) -> Self {
182        self.azure_ad_token = Some(token.into());
183        self
184    }
185
186    /// Resolve environment fallbacks and construct the [`Client`].
187    pub fn build(self) -> Result<Client, OpenAIError> {
188        let is_azure = self.azure_endpoint.is_some();
189        let api_key = self
190            .api_key
191            .or_else(|| {
192                if is_azure {
193                    std::env::var("AZURE_OPENAI_API_KEY").ok()
194                } else {
195                    None
196                }
197            })
198            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
199            .filter(|k| !k.trim().is_empty());
200
201        let azure_ad_token = self
202            .azure_ad_token
203            .or_else(|| {
204                if is_azure {
205                    std::env::var("AZURE_OPENAI_AD_TOKEN").ok()
206                } else {
207                    None
208                }
209            })
210            .filter(|t| !t.trim().is_empty());
211
212        let (api_key, base_url, default_query, azure) = if let Some(endpoint) =
213            self.azure_endpoint
214        {
215            let api_version = self
216                .azure_api_version
217                .or_else(|| std::env::var("OPENAI_API_VERSION").ok())
218                .filter(|v| !v.trim().is_empty())
219                .ok_or_else(|| {
220                    OpenAIError::Config(
221                        "Azure requires an api_version: pass it to `.azure()` or set OPENAI_API_VERSION"
222                            .into(),
223                    )
224                })?;
225            let auth = match (&azure_ad_token, &api_key) {
226                (Some(token), _) => AzureAuth::BearerToken(token.clone()),
227                (None, Some(key)) => AzureAuth::ApiKey(key.clone()),
228                (None, None) => {
229                    return Err(OpenAIError::Config(
230                        "missing Azure credentials: pass `api_key`/`azure_ad_token` or set AZURE_OPENAI_API_KEY / AZURE_OPENAI_AD_TOKEN"
231                            .into(),
232                    ))
233                }
234            };
235            // Base URL is always `{endpoint}/openai`; any deployment segment
236            // is added per request so non-deployment endpoints stay correct.
237            let base_url = crate::azure::azure_base_url(&endpoint, None);
238            (
239                api_key.unwrap_or_default(),
240                base_url,
241                vec![("api-version".to_string(), api_version)],
242                Some(AzureSettings {
243                    auth,
244                    deployment: self.azure_deployment,
245                }),
246            )
247        } else {
248            let api_key = api_key.ok_or_else(|| {
249                OpenAIError::Config(
250                    "missing API key: pass `api_key` or set the OPENAI_API_KEY environment variable"
251                        .into(),
252                )
253            })?;
254            let base_url = self
255                .base_url
256                .or_else(|| std::env::var("OPENAI_BASE_URL").ok())
257                .filter(|u| !u.trim().is_empty())
258                .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
259            let base_url = base_url.trim_end_matches('/').to_string();
260            (api_key, base_url, Vec::new(), None)
261        };
262
263        let config = Config {
264            api_key,
265            base_url,
266            organization: self
267                .organization
268                .or_else(|| std::env::var("OPENAI_ORG_ID").ok()),
269            project: self
270                .project
271                .or_else(|| std::env::var("OPENAI_PROJECT_ID").ok()),
272            timeout: self.timeout.unwrap_or(DEFAULT_TIMEOUT),
273            connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT),
274            max_retries: self.max_retries.unwrap_or(DEFAULT_MAX_RETRIES),
275            default_headers: self.default_headers,
276            default_query,
277            azure,
278        };
279        Client::from_config(config)
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn missing_api_key_is_config_error() {
        // Only meaningful when the env var is absent; skip otherwise.
        if std::env::var("OPENAI_API_KEY").is_ok() {
            return;
        }
        let err = ClientBuilder::new().build().unwrap_err();
        assert!(matches!(err, OpenAIError::Config(_)));
    }

    #[test]
    fn explicit_blank_api_key_is_config_error() {
        // An explicit value is never replaced by the environment, so this
        // holds whether or not OPENAI_API_KEY is set.
        let err = ClientBuilder::new().api_key("   ").build().unwrap_err();
        assert!(matches!(err, OpenAIError::Config(_)));
    }

    #[test]
    fn azure_without_api_version_is_config_error() {
        // Only meaningful when the env fallback is absent; skip otherwise.
        if std::env::var("OPENAI_API_VERSION").is_ok() {
            return;
        }
        let err = ClientBuilder::new()
            .azure("https://example.openai.azure.com", "")
            .api_key("az-test")
            .build()
            .unwrap_err();
        assert!(matches!(err, OpenAIError::Config(_)));
    }

    #[test]
    fn base_url_trailing_slash_is_trimmed() {
        let client = ClientBuilder::new()
            .api_key("sk-test")
            .base_url("https://example.com/v1/")
            .build()
            .unwrap();
        assert_eq!(client.base_url(), "https://example.com/v1");
    }

    #[test]
    fn config_debug_redacts_api_key() {
        let client = ClientBuilder::new().api_key("sk-secret").build().unwrap();
        let debug = format!("{:?}", client.config());
        assert!(!debug.contains("sk-secret"));
        assert!(debug.contains("[REDACTED]"));
    }

    #[test]
    fn builder_debug_redacts_secrets() {
        let builder = ClientBuilder::new()
            .api_key("sk-secret")
            .azure_ad_token("ad-secret");
        let debug = format!("{:?}", builder);
        assert!(!debug.contains("sk-secret"));
        assert!(!debug.contains("ad-secret"));
        assert!(debug.contains("[REDACTED]"));
    }
}