1use std::env;
8
9use crate::client::OpenAI;
10use crate::config::ClientConfig;
11use crate::error::OpenAIError;
12
/// `api-version` query parameter value used by `AzureConfig::build` when no version is set.
const DEFAULT_API_VERSION: &str = "2024-10-21";
15
/// Builder-style configuration for an Azure OpenAI client.
///
/// Fill the fields through the builder methods (or `AzureConfig::from_env`) and call
/// `AzureConfig::build` to obtain a client. Exactly one of `api_key` / `azure_ad_token`
/// must be set.
///
/// `Debug` is implemented by hand so that credentials are never printed.
#[derive(Clone, Default)]
pub struct AzureConfig {
    /// Resource endpoint, e.g. `https://my-resource.openai.azure.com`.
    /// Trailing slashes are stripped and `/openai` is appended when building.
    pub azure_endpoint: Option<String>,

    /// Optional deployment name; when set the base URL becomes
    /// `{endpoint}/openai/deployments/{deployment}`.
    pub azure_deployment: Option<String>,

    /// Value of the `api-version` query parameter; a built-in default is used when unset.
    pub api_version: Option<String>,

    /// API key, sent in the `api-key` header. Mutually exclusive with `azure_ad_token`.
    pub api_key: Option<String>,

    /// Azure AD token, sent as `Authorization: Bearer <token>`. Mutually exclusive with
    /// `api_key`.
    pub azure_ad_token: Option<String>,
}

impl std::fmt::Debug for AzureConfig {
    /// Formats the configuration with credentials replaced by `<redacted>`, so that
    /// logging or `dbg!`-ing a config can never leak an API key or token.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a secret is present, never its value.
        let redact = |secret: &Option<String>| secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("AzureConfig")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("azure_deployment", &self.azure_deployment)
            .field("api_version", &self.api_version)
            .field("api_key", &redact(&self.api_key))
            .field("azure_ad_token", &redact(&self.azure_ad_token))
            .finish()
    }
}
57
58impl AzureConfig {
59 #[must_use]
61 pub fn new() -> Self {
62 Self::default()
63 }
64
65 #[must_use]
69 pub fn azure_endpoint(mut self, endpoint: impl Into<String>) -> Self {
70 self.azure_endpoint = Some(endpoint.into());
71 self
72 }
73
74 #[must_use]
79 pub fn azure_deployment(mut self, deployment: impl Into<String>) -> Self {
80 self.azure_deployment = Some(deployment.into());
81 self
82 }
83
84 #[must_use]
88 pub fn api_version(mut self, version: impl Into<String>) -> Self {
89 self.api_version = Some(version.into());
90 self
91 }
92
93 #[must_use]
97 pub fn api_key(mut self, key: impl Into<String>) -> Self {
98 self.api_key = Some(key.into());
99 self
100 }
101
102 #[must_use]
107 pub fn azure_ad_token(mut self, token: impl Into<String>) -> Self {
108 self.azure_ad_token = Some(token.into());
109 self
110 }
111
112 pub fn build(self) -> Result<OpenAI, OpenAIError> {
121 let endpoint = self.azure_endpoint.ok_or_else(|| {
122 OpenAIError::InvalidArgument(
123 "Azure endpoint is required. Set azure_endpoint() or AZURE_OPENAI_ENDPOINT env var"
124 .to_string(),
125 )
126 })?;
127
128 let api_version = self
129 .api_version
130 .unwrap_or_else(|| DEFAULT_API_VERSION.to_string());
131
132 if self.api_key.is_some() && self.azure_ad_token.is_some() {
134 return Err(OpenAIError::InvalidArgument(
135 "api_key and azure_ad_token are mutually exclusive; only one can be set"
136 .to_string(),
137 ));
138 }
139
140 let (auth_key, use_azure_api_key_header) = match (&self.api_key, &self.azure_ad_token) {
142 (Some(key), None) => (key.clone(), true),
143 (None, Some(token)) => (token.clone(), false),
144 (None, None) => {
145 return Err(OpenAIError::InvalidArgument(
146 "Azure credentials required. Set api_key() or azure_ad_token()".to_string(),
147 ));
148 }
149 _ => unreachable!(), };
151
152 let base_url = {
154 let trimmed = endpoint.trim_end_matches('/');
155 match &self.azure_deployment {
156 Some(deployment) => format!("{trimmed}/openai/deployments/{deployment}"),
157 None => format!("{trimmed}/openai"),
158 }
159 };
160
161 let config = ClientConfig::new(auth_key)
163 .base_url(base_url)
164 .default_query(vec![("api-version".to_string(), api_version)])
165 .use_azure_api_key_header(use_azure_api_key_header);
166
167 Ok(OpenAI::with_config(config))
168 }
169
170 pub fn from_env() -> Result<OpenAI, OpenAIError> {
178 let mut config = Self::new();
179
180 if let Ok(endpoint) = env::var("AZURE_OPENAI_ENDPOINT") {
181 config = config.azure_endpoint(endpoint);
182 }
183
184 if let Ok(key) = env::var("AZURE_OPENAI_API_KEY") {
185 config = config.api_key(key);
186 }
187
188 if let Ok(token) = env::var("AZURE_OPENAI_AD_TOKEN") {
189 config = config.azure_ad_token(token);
190 }
191
192 if let Ok(version) = env::var("OPENAI_API_VERSION") {
193 config = config.api_version(version);
194 }
195
196 config.build()
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// A deployment-scoped client uses `{endpoint}/openai/deployments/{deployment}`.
    #[test]
    fn test_azure_url_with_deployment() {
        let client = AzureConfig::new()
            .azure_endpoint("https://my-resource.openai.azure.com")
            .azure_deployment("gpt-4")
            .api_key("test-key")
            .build()
            .unwrap();

        assert_eq!(
            client.config.base_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4"
        );
    }

    /// Without a deployment the base URL is just `{endpoint}/openai`.
    #[test]
    fn test_azure_url_without_deployment() {
        let client = AzureConfig::new()
            .azure_endpoint("https://my-resource.openai.azure.com")
            .api_key("test-key")
            .build()
            .unwrap();

        assert_eq!(
            client.config.base_url(),
            "https://my-resource.openai.azure.com/openai"
        );
    }

    /// A trailing `/` on the endpoint must not produce a `//openai` double slash.
    #[test]
    fn test_azure_url_trailing_slash_stripped() {
        let client = AzureConfig::new()
            .azure_endpoint("https://my-resource.openai.azure.com/")
            .azure_deployment("gpt-4")
            .api_key("test-key")
            .build()
            .unwrap();

        assert_eq!(
            client.config.base_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4"
        );
    }

    /// With no explicit version, the `api-version` default query parameter is the
    /// built-in default (the literal here must be kept in sync with `DEFAULT_API_VERSION`).
    #[test]
    fn test_azure_default_api_version() {
        let client = AzureConfig::new()
            .azure_endpoint("https://example.openai.azure.com")
            .api_key("test-key")
            .build()
            .unwrap();

        let query = client.options.query.as_ref().unwrap();
        assert!(
            query
                .iter()
                .any(|(k, v)| k == "api-version" && v == "2024-10-21")
        );
    }

    /// An explicit `api_version()` overrides the default query parameter value.
    #[test]
    fn test_azure_custom_api_version() {
        let client = AzureConfig::new()
            .azure_endpoint("https://example.openai.azure.com")
            .api_key("test-key")
            .api_version("2024-06-01")
            .build()
            .unwrap();

        let query = client.options.query.as_ref().unwrap();
        assert!(
            query
                .iter()
                .any(|(k, v)| k == "api-version" && v == "2024-06-01")
        );
    }

    /// Over real HTTP (mock server), a GET request carries the `api-version` query
    /// parameter and is routed under `/openai`.
    #[tokio::test]
    async fn test_azure_sends_api_version_query_param() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/openai/models")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "api-version".into(),
                "2024-10-21".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"data":[],"object":"list"}"#)
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(&server.url())
            .api_key("test-key")
            .build()
            .unwrap();

        #[derive(serde::Deserialize)]
        struct ListResp {
            object: String,
        }

        let resp: ListResp = client.get("/models").await.unwrap();
        assert_eq!(resp.object, "list");
        mock.assert_async().await;
    }

    /// API-key auth sends the key in the `api-key` header (alongside `api-version`).
    #[tokio::test]
    async fn test_azure_sends_api_key_header() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/openai/test")
            .match_header("api-key", "my-azure-key")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "api-version".into(),
                "2024-10-21".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(&server.url())
            .api_key("my-azure-key")
            .build()
            .unwrap();

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    /// API-key auth must not also send an `Authorization` header.
    #[tokio::test]
    async fn test_azure_does_not_send_bearer_auth() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            // Matches only if `api-key` is present AND `authorization` is absent.
            .mock("GET", "/openai/test")
            .match_header("api-key", "my-azure-key")
            .match_header("authorization", mockito::Matcher::Missing)
            .match_query(mockito::Matcher::Any)
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(&server.url())
            .api_key("my-azure-key")
            .build()
            .unwrap();

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    /// Azure AD token auth sends `Authorization: Bearer <token>`.
    #[tokio::test]
    async fn test_azure_ad_token_sends_bearer() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/openai/test")
            .match_header("authorization", "Bearer my-ad-token")
            .match_query(mockito::Matcher::Any)
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(&server.url())
            .azure_ad_token("my-ad-token")
            .build()
            .unwrap();

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    /// Setting both an API key and an AD token is rejected by `build()`.
    #[test]
    fn test_mutual_exclusivity_error() {
        let result = AzureConfig::new()
            .azure_endpoint("https://example.openai.azure.com")
            .api_key("key")
            .azure_ad_token("token")
            .build();

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("mutually exclusive"),
            "unexpected error: {err}"
        );
    }

    /// Setting neither credential is rejected by `build()`.
    #[test]
    fn test_no_credentials_error() {
        let result = AzureConfig::new()
            .azure_endpoint("https://example.openai.azure.com")
            .build();

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("credentials required"),
            "unexpected error: {err}"
        );
    }

    /// A missing endpoint is rejected by `build()`.
    #[test]
    fn test_no_endpoint_error() {
        let result = AzureConfig::new().api_key("key").build();

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("endpoint is required"),
            "unexpected error: {err}"
        );
    }

    /// `from_env()` picks up endpoint, API key and version from the process environment.
    #[test]
    fn test_from_env_reads_variables() {
        // SAFETY: mutating the process environment is unsafe because it races with any
        // other thread reading it. This is the only test here that calls `from_env`, but
        // the default test harness runs tests on parallel threads. Soundness therefore
        // relies on no other concurrently running code reading the environment.
        // NOTE(review): consider serialising env-dependent tests.
        unsafe {
            env::set_var("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com");
            env::set_var("AZURE_OPENAI_API_KEY", "env-key");
            env::set_var("OPENAI_API_VERSION", "2024-06-01");
            // Ensure a token from the developer's shell cannot trigger the
            // mutual-exclusivity error.
            env::remove_var("AZURE_OPENAI_AD_TOKEN");
        }

        let client = AzureConfig::from_env().unwrap();

        assert_eq!(
            client.config.base_url(),
            "https://test.openai.azure.com/openai"
        );
        assert_eq!(client.config.api_key(), "env-key");

        let query = client.options.query.as_ref().unwrap();
        assert!(
            query
                .iter()
                .any(|(k, v)| k == "api-version" && v == "2024-06-01")
        );

        // NOTE(review): this cleanup is skipped if any assertion above panics; a drop
        // guard would make the restoration panic-safe.
        // SAFETY: same constraint as the block above — no concurrent env readers assumed.
        unsafe {
            env::remove_var("AZURE_OPENAI_ENDPOINT");
            env::remove_var("AZURE_OPENAI_API_KEY");
            env::remove_var("OPENAI_API_VERSION");
        }
    }

    /// End-to-end chat completion against a deployment-scoped URL with API-key auth.
    #[tokio::test]
    async fn test_azure_chat_completion_e2e() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/openai/deployments/gpt-4/chat/completions")
            .match_header("api-key", "azure-key")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "api-version".into(),
                "2024-10-21".into(),
            )]))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "id": "chatcmpl-azure-123",
                "object": "chat.completion",
                "created": 1700000000,
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello from Azure!"
                    },
                    "finish_reason": "stop"
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "total_tokens": 15
                }
            }"#,
            )
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(&server.url())
            .azure_deployment("gpt-4")
            .api_key("azure-key")
            .build()
            .unwrap();

        use crate::types::chat::{ChatCompletionMessageParam, ChatCompletionRequest, UserContent};

        let request = ChatCompletionRequest::new(
            "gpt-4",
            vec![ChatCompletionMessageParam::User {
                content: UserContent::Text("Hello!".into()),
                name: None,
            }],
        );

        let response = client.chat().completions().create(request).await.unwrap();
        assert_eq!(response.id, "chatcmpl-azure-123");
        assert_eq!(
            response.choices[0].message.content.as_deref().unwrap_or(""),
            "Hello from Azure!"
        );
        mock.assert_async().await;
    }
}