1use std::time::Duration;
5
6use crate::azure::AzureAuth;
7use crate::client::Client;
8use crate::error::OpenAIError;
9
/// Base URL used for non-Azure clients when neither `ClientBuilder::base_url`
/// nor the `OPENAI_BASE_URL` environment variable supplies one.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Value of the config `timeout` when the builder leaves it unset (ten minutes).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// Value of the config `connect_timeout` when the builder leaves it unset.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Value of the config `max_retries` when the builder leaves it unset.
pub const DEFAULT_MAX_RETRIES: u32 = 2;
17
18#[derive(Clone)]
20pub(crate) struct AzureSettings {
21 pub auth: AzureAuth,
22 pub deployment: Option<String>,
28}
29
30#[derive(Clone)]
32pub struct Config {
33 pub(crate) api_key: String,
34 pub(crate) base_url: String,
35 pub(crate) organization: Option<String>,
36 pub(crate) project: Option<String>,
37 pub(crate) timeout: Duration,
38 pub(crate) connect_timeout: Duration,
39 pub(crate) max_retries: u32,
40 pub(crate) default_headers: Vec<(String, String)>,
41 pub(crate) default_query: Vec<(String, String)>,
42 pub(crate) azure: Option<AzureSettings>,
43}
44
45impl std::fmt::Debug for Config {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("Config")
48 .field("api_key", &"[REDACTED]")
49 .field("base_url", &self.base_url)
50 .field("organization", &self.organization)
51 .field("project", &self.project)
52 .field("timeout", &self.timeout)
53 .field("connect_timeout", &self.connect_timeout)
54 .field("max_retries", &self.max_retries)
55 .finish()
56 }
57}
58
/// Builder for [`Client`]; every setter consumes and returns the builder.
///
/// Unset options are resolved in `build` from environment variables or the
/// `DEFAULT_*` constants. Marked `#[must_use]` because calling a setter and
/// discarding the returned builder silently loses the setting.
#[derive(Default, Clone)]
#[must_use = "a ClientBuilder does nothing until `.build()` is called"]
pub struct ClientBuilder {
    /// Explicit API key; falls back to environment variables in `build`.
    api_key: Option<String>,
    /// Explicit base URL (non-Azure only).
    base_url: Option<String>,
    /// Explicit organization identifier.
    organization: Option<String>,
    /// Explicit project identifier.
    project: Option<String>,
    /// Timeout override.
    timeout: Option<Duration>,
    /// Connect-timeout override.
    connect_timeout: Option<Duration>,
    /// Retry-count override.
    max_retries: Option<u32>,
    /// Header pairs accumulated by `header`, in call order.
    default_headers: Vec<(String, String)>,
    /// Azure endpoint; `Some` switches `build` into Azure mode.
    azure_endpoint: Option<String>,
    /// Azure API version given to `azure` (empty strings are not stored).
    azure_api_version: Option<String>,
    /// Azure deployment name.
    azure_deployment: Option<String>,
    /// Azure AD bearer token; preferred over the API key in Azure mode.
    azure_ad_token: Option<String>,
}
77
78impl std::fmt::Debug for ClientBuilder {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("ClientBuilder")
81 .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
82 .field("base_url", &self.base_url)
83 .field("organization", &self.organization)
84 .field("project", &self.project)
85 .field("timeout", &self.timeout)
86 .field("connect_timeout", &self.connect_timeout)
87 .field("max_retries", &self.max_retries)
88 .field("azure_endpoint", &self.azure_endpoint)
89 .field("azure_api_version", &self.azure_api_version)
90 .field("azure_deployment", &self.azure_deployment)
91 .field(
92 "azure_ad_token",
93 &self.azure_ad_token.as_ref().map(|_| "[REDACTED]"),
94 )
95 .finish()
96 }
97}
98
99impl ClientBuilder {
100 pub fn new() -> Self {
101 Self::default()
102 }
103
104 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
106 self.api_key = Some(api_key.into());
107 self
108 }
109
110 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
113 self.base_url = Some(base_url.into());
114 self
115 }
116
117 pub fn organization(mut self, organization: impl Into<String>) -> Self {
119 self.organization = Some(organization.into());
120 self
121 }
122
123 pub fn project(mut self, project: impl Into<String>) -> Self {
125 self.project = Some(project.into());
126 self
127 }
128
129 pub fn timeout(mut self, timeout: Duration) -> Self {
133 self.timeout = Some(timeout);
134 self
135 }
136
137 pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
139 self.connect_timeout = Some(connect_timeout);
140 self
141 }
142
143 pub fn max_retries(mut self, max_retries: u32) -> Self {
145 self.max_retries = Some(max_retries);
146 self
147 }
148
149 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
151 self.default_headers.push((name.into(), value.into()));
152 self
153 }
154
155 pub fn azure(mut self, endpoint: impl Into<String>, api_version: impl Into<String>) -> Self {
163 self.azure_endpoint = Some(endpoint.into());
164 let api_version = api_version.into();
165 if !api_version.is_empty() {
166 self.azure_api_version = Some(api_version);
167 }
168 self
169 }
170
171 pub fn azure_deployment(mut self, deployment: impl Into<String>) -> Self {
175 self.azure_deployment = Some(deployment.into());
176 self
177 }
178
179 pub fn azure_ad_token(mut self, token: impl Into<String>) -> Self {
182 self.azure_ad_token = Some(token.into());
183 self
184 }
185
186 pub fn build(self) -> Result<Client, OpenAIError> {
188 let is_azure = self.azure_endpoint.is_some();
189 let api_key = self
190 .api_key
191 .or_else(|| {
192 if is_azure {
193 std::env::var("AZURE_OPENAI_API_KEY").ok()
194 } else {
195 None
196 }
197 })
198 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
199 .filter(|k| !k.trim().is_empty());
200
201 let azure_ad_token = self
202 .azure_ad_token
203 .or_else(|| {
204 if is_azure {
205 std::env::var("AZURE_OPENAI_AD_TOKEN").ok()
206 } else {
207 None
208 }
209 })
210 .filter(|t| !t.trim().is_empty());
211
212 let (api_key, base_url, default_query, azure) = if let Some(endpoint) =
213 self.azure_endpoint
214 {
215 let api_version = self
216 .azure_api_version
217 .or_else(|| std::env::var("OPENAI_API_VERSION").ok())
218 .filter(|v| !v.trim().is_empty())
219 .ok_or_else(|| {
220 OpenAIError::Config(
221 "Azure requires an api_version: pass it to `.azure()` or set OPENAI_API_VERSION"
222 .into(),
223 )
224 })?;
225 let auth = match (&azure_ad_token, &api_key) {
226 (Some(token), _) => AzureAuth::BearerToken(token.clone()),
227 (None, Some(key)) => AzureAuth::ApiKey(key.clone()),
228 (None, None) => {
229 return Err(OpenAIError::Config(
230 "missing Azure credentials: pass `api_key`/`azure_ad_token` or set AZURE_OPENAI_API_KEY / AZURE_OPENAI_AD_TOKEN"
231 .into(),
232 ))
233 }
234 };
235 let base_url = crate::azure::azure_base_url(&endpoint, None);
238 (
239 api_key.unwrap_or_default(),
240 base_url,
241 vec![("api-version".to_string(), api_version)],
242 Some(AzureSettings {
243 auth,
244 deployment: self.azure_deployment,
245 }),
246 )
247 } else {
248 let api_key = api_key.ok_or_else(|| {
249 OpenAIError::Config(
250 "missing API key: pass `api_key` or set the OPENAI_API_KEY environment variable"
251 .into(),
252 )
253 })?;
254 let base_url = self
255 .base_url
256 .or_else(|| std::env::var("OPENAI_BASE_URL").ok())
257 .filter(|u| !u.trim().is_empty())
258 .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
259 let base_url = base_url.trim_end_matches('/').to_string();
260 (api_key, base_url, Vec::new(), None)
261 };
262
263 let config = Config {
264 api_key,
265 base_url,
266 organization: self
267 .organization
268 .or_else(|| std::env::var("OPENAI_ORG_ID").ok()),
269 project: self
270 .project
271 .or_else(|| std::env::var("OPENAI_PROJECT_ID").ok()),
272 timeout: self.timeout.unwrap_or(DEFAULT_TIMEOUT),
273 connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT),
274 max_retries: self.max_retries.unwrap_or(DEFAULT_MAX_RETRIES),
275 default_headers: self.default_headers,
276 default_query,
277 azure,
278 };
279 Client::from_config(config)
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// With no key anywhere, `build` must fail with a configuration error.
    #[test]
    fn missing_api_key_is_config_error() {
        // A key present in the environment makes the "missing key" path
        // unreachable, so the test has nothing to check in that case.
        if std::env::var("OPENAI_API_KEY").is_ok() {
            return;
        }
        let outcome = ClientBuilder::new().build();
        assert!(matches!(outcome, Err(OpenAIError::Config(_))));
    }

    /// Trailing slashes on an explicit base URL are removed.
    #[test]
    fn base_url_trailing_slash_is_trimmed() {
        let built = ClientBuilder::new()
            .base_url("https://example.com/v1/")
            .api_key("sk-test")
            .build()
            .expect("explicit key and base URL should build");
        assert_eq!(built.base_url(), "https://example.com/v1");
    }

    /// The Debug rendering of the resolved config never shows the API key.
    #[test]
    fn config_debug_redacts_api_key() {
        let secret = "sk-secret";
        let built = ClientBuilder::new().api_key(secret).build().unwrap();
        let rendered = format!("{:?}", built.config());
        assert!(!rendered.contains(secret), "Debug output leaked the API key");
        assert!(rendered.contains("[REDACTED]"));
    }
}