Skip to main content

openai_oxide/
azure.rs

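//! Azure OpenAI client configuration.
//!
//! Provides the `AzureConfig` builder for constructing an `OpenAI` client that
//! targets Azure OpenAI deployments, following the Python SDK's `AzureOpenAI`
//! constructor pattern. Unlike the Python constructor, `AzureConfig::build()`
//! does not consult environment variables; use `AzureConfig::from_env()` for that.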
6
7use std::env;
8
9use crate::client::OpenAI;
10use crate::config::ClientConfig;
11use crate::error::OpenAIError;
12
/// Azure API version used for the `api-version` query parameter when the
/// caller does not configure one.
const DEFAULT_API_VERSION: &str = "2024-10-21";
15
/// Configuration builder for Azure OpenAI deployments.
///
/// Azure OpenAI uses different URL construction and authentication compared
/// to the standard OpenAI API:
/// - URL: `{endpoint}/openai/deployments/{deployment}` or `{endpoint}/openai`
/// - Auth: `api-key` header when an API key is used, or
///   `Authorization: Bearer {token}` when an Azure AD token is used
/// - Query: `api-version` parameter on every request
///
/// The `Debug` representation redacts `api_key` and `azure_ad_token`, so a
/// configuration can be printed or logged without exposing credentials.
///
/// # Examples
///
/// ```ignore
/// use openai_oxide::{OpenAI, AzureConfig};
///
/// let client = OpenAI::azure(
///     AzureConfig::new()
///         .azure_endpoint("https://my-resource.openai.azure.com")
///         .azure_deployment("gpt-4")
///         .api_version("2024-10-21")
///         .api_key("my-azure-api-key")
/// )?;
///
/// // All resources work the same as with the standard client
/// let response = client.chat().completions().create(request).await?;
/// ```
#[derive(Clone, Default)]
pub struct AzureConfig {
    /// Azure endpoint URL, e.g. `https://my-resource.openai.azure.com`.
    pub azure_endpoint: Option<String>,

    /// Azure deployment name, e.g. `gpt-4`.
    pub azure_deployment: Option<String>,

    /// Azure API version, e.g. `2024-10-21`.
    pub api_version: Option<String>,

    /// Azure API key (mutually exclusive with `azure_ad_token`).
    pub api_key: Option<String>,

    /// Azure AD token for authentication (mutually exclusive with `api_key`).
    pub azure_ad_token: Option<String>,
}

impl std::fmt::Debug for AzureConfig {
    /// Formats every field, but reduces the credentials to a presence marker
    /// so that `{:?}` output never contains the secret values themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("AzureConfig")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("azure_deployment", &self.azure_deployment)
            .field("api_version", &self.api_version)
            .field("api_key", &self.api_key.as_ref().map(|_| REDACTED))
            .field(
                "azure_ad_token",
                &self.azure_ad_token.as_ref().map(|_| REDACTED),
            )
            .finish()
    }
}
57
58impl AzureConfig {
59    /// Create a new empty Azure configuration.
60    #[must_use]
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    /// Set the Azure endpoint URL.
66    ///
67    /// Example: `https://my-resource.openai.azure.com`
68    #[must_use]
69    pub fn azure_endpoint(mut self, endpoint: impl Into<String>) -> Self {
70        self.azure_endpoint = Some(endpoint.into());
71        self
72    }
73
74    /// Set the Azure deployment name.
75    ///
76    /// When set, the base URL becomes `{endpoint}/openai/deployments/{deployment}`.
77    /// When not set, the base URL is `{endpoint}/openai`.
78    #[must_use]
79    pub fn azure_deployment(mut self, deployment: impl Into<String>) -> Self {
80        self.azure_deployment = Some(deployment.into());
81        self
82    }
83
84    /// Set the Azure API version.
85    ///
86    /// Defaults to `2024-10-21` if not set and not in environment.
87    #[must_use]
88    pub fn api_version(mut self, version: impl Into<String>) -> Self {
89        self.api_version = Some(version.into());
90        self
91    }
92
93    /// Set the Azure API key.
94    ///
95    /// Mutually exclusive with `azure_ad_token`.
96    #[must_use]
97    pub fn api_key(mut self, key: impl Into<String>) -> Self {
98        self.api_key = Some(key.into());
99        self
100    }
101
102    /// Set the Azure AD token for authentication.
103    ///
104    /// Mutually exclusive with `api_key`. When using AD token auth,
105    /// requests use `Authorization: Bearer {token}` instead of `api-key` header.
106    #[must_use]
107    pub fn azure_ad_token(mut self, token: impl Into<String>) -> Self {
108        self.azure_ad_token = Some(token.into());
109        self
110    }
111
112    /// Build an `OpenAI` client from this Azure configuration.
113    ///
114    /// # Errors
115    ///
116    /// Returns `OpenAIError::InvalidArgument` if:
117    /// - No endpoint is provided (and `AZURE_OPENAI_ENDPOINT` is not set)
118    /// - No credentials are provided (neither API key nor AD token)
119    /// - Both `api_key` and `azure_ad_token` are set (mutually exclusive)
120    pub fn build(self) -> Result<OpenAI, OpenAIError> {
121        let endpoint = self.azure_endpoint.ok_or_else(|| {
122            OpenAIError::InvalidArgument(
123                "Azure endpoint is required. Set azure_endpoint() or AZURE_OPENAI_ENDPOINT env var"
124                    .to_string(),
125            )
126        })?;
127
128        let api_version = self
129            .api_version
130            .unwrap_or_else(|| DEFAULT_API_VERSION.to_string());
131
132        // Validate mutual exclusivity
133        if self.api_key.is_some() && self.azure_ad_token.is_some() {
134            return Err(OpenAIError::InvalidArgument(
135                "api_key and azure_ad_token are mutually exclusive; only one can be set"
136                    .to_string(),
137            ));
138        }
139
140        // Determine auth mode
141        let (auth_key, use_azure_api_key_header) = match (&self.api_key, &self.azure_ad_token) {
142            (Some(key), None) => (key.clone(), true),
143            (None, Some(token)) => (token.clone(), false),
144            (None, None) => {
145                return Err(OpenAIError::InvalidArgument(
146                    "Azure credentials required. Set api_key() or azure_ad_token()".to_string(),
147                ));
148            }
149            _ => unreachable!(), // already checked above
150        };
151
152        // Build base URL
153        let base_url = {
154            let trimmed = endpoint.trim_end_matches('/');
155            match &self.azure_deployment {
156                Some(deployment) => format!("{trimmed}/openai/deployments/{deployment}"),
157                None => format!("{trimmed}/openai"),
158            }
159        };
160
161        // Build config with api-version as default query
162        let config = ClientConfig::new(auth_key)
163            .base_url(base_url)
164            .default_query(vec![("api-version".to_string(), api_version)])
165            .use_azure_api_key_header(use_azure_api_key_header);
166
167        Ok(OpenAI::with_config(config))
168    }
169
170    /// Build an `OpenAI` client from environment variables.
171    ///
172    /// Reads:
173    /// - `AZURE_OPENAI_API_KEY` — API key
174    /// - `AZURE_OPENAI_ENDPOINT` — Azure endpoint URL
175    /// - `OPENAI_API_VERSION` — API version (defaults to `2024-10-21`)
176    /// - `AZURE_OPENAI_AD_TOKEN` — Azure AD token (alternative to API key)
177    pub fn from_env() -> Result<OpenAI, OpenAIError> {
178        let mut config = Self::new();
179
180        if let Ok(endpoint) = env::var("AZURE_OPENAI_ENDPOINT") {
181            config = config.azure_endpoint(endpoint);
182        }
183
184        if let Ok(key) = env::var("AZURE_OPENAI_API_KEY") {
185            config = config.api_key(key);
186        }
187
188        if let Ok(token) = env::var("AZURE_OPENAI_AD_TOKEN") {
189            config = config.azure_ad_token(token);
190        }
191
192        if let Ok(version) = env::var("OPENAI_API_VERSION") {
193            config = config.api_version(version);
194        }
195
196        config.build()
197    }
198}
199
#[cfg(test)]
mod tests {
    use std::ffi::OsString;

    use super::*;

    /// Minimal JSON body returned by the `/test` endpoint mocks.
    #[derive(serde::Deserialize)]
    struct OkResp {
        ok: bool,
    }

    /// Snapshot of selected environment variables, restored when dropped.
    ///
    /// Restoring in `Drop` puts the variables back even if an assertion in the
    /// test panics. It also restores values that existed before the test
    /// instead of deleting them.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvGuard {
        /// Record the current value (or absence) of each named variable.
        fn capture(names: &[&'static str]) -> Self {
            Self {
                saved: names
                    .iter()
                    .map(|&name| (name, env::var_os(name)))
                    .collect(),
            }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, value) in &self.saved {
                // SAFETY: same caveat as in `test_from_env_reads_variables`.
                unsafe {
                    match value {
                        Some(previous) => env::set_var(name, previous),
                        None => env::remove_var(name),
                    }
                }
            }
        }
    }

    // --- Task 2.1: AzureConfig builder URL construction ---

    #[test]
    fn test_azure_url_with_deployment() {
        let client = AzureConfig::new()
            .azure_endpoint("https://my-resource.openai.azure.com")
            .azure_deployment("gpt-4")
            .api_key("test-key")
            .build()
            .unwrap();

        assert_eq!(
            client.config.base_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4"
        );
    }

    #[test]
    fn test_azure_url_without_deployment() {
        let client = AzureConfig::new()
            .azure_endpoint("https://my-resource.openai.azure.com")
            .api_key("test-key")
            .build()
            .unwrap();

        assert_eq!(
            client.config.base_url(),
            "https://my-resource.openai.azure.com/openai"
        );
    }

    #[test]
    fn test_azure_url_trailing_slash_stripped() {
        let client = AzureConfig::new()
            .azure_endpoint("https://my-resource.openai.azure.com/")
            .azure_deployment("gpt-4")
            .api_key("test-key")
            .build()
            .unwrap();

        assert_eq!(
            client.config.base_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4"
        );
    }

    #[test]
    fn test_azure_default_api_version() {
        let client = AzureConfig::new()
            .azure_endpoint("https://example.openai.azure.com")
            .api_key("test-key")
            .build()
            .unwrap();

        let query = client.options.query.as_ref().unwrap();
        assert!(
            query
                .iter()
                .any(|(k, v)| k == "api-version" && v == "2024-10-21")
        );
    }

    #[test]
    fn test_azure_custom_api_version() {
        let client = AzureConfig::new()
            .azure_endpoint("https://example.openai.azure.com")
            .api_key("test-key")
            .api_version("2024-06-01")
            .build()
            .unwrap();

        let query = client.options.query.as_ref().unwrap();
        assert!(
            query
                .iter()
                .any(|(k, v)| k == "api-version" && v == "2024-06-01")
        );
    }

    // --- Task 2.2: api-version query param on requests ---

    #[tokio::test]
    async fn test_azure_sends_api_version_query_param() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/openai/models")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "api-version".into(),
                "2024-10-21".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"data":[],"object":"list"}"#)
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(server.url())
            .api_key("test-key")
            .build()
            .unwrap();

        #[derive(serde::Deserialize)]
        struct ListResp {
            object: String,
        }

        let resp: ListResp = client.get("/models").await.unwrap();
        assert_eq!(resp.object, "list");
        mock.assert_async().await;
    }

    // --- Task 2.3: Azure api-key header ---

    #[tokio::test]
    async fn test_azure_sends_api_key_header() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/openai/test")
            .match_header("api-key", "my-azure-key")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "api-version".into(),
                "2024-10-21".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(server.url())
            .api_key("my-azure-key")
            .build()
            .unwrap();

        let resp: OkResp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_azure_does_not_send_bearer_auth() {
        let mut server = mockito::Server::new_async().await;
        // Ensure no Authorization header is sent for api-key mode
        let mock = server
            .mock("GET", "/openai/test")
            .match_header("api-key", "my-azure-key")
            .match_header("authorization", mockito::Matcher::Missing)
            .match_query(mockito::Matcher::Any)
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(server.url())
            .api_key("my-azure-key")
            .build()
            .unwrap();

        let resp: OkResp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // --- Task 2.4: Azure AD token auth ---

    #[tokio::test]
    async fn test_azure_ad_token_sends_bearer() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/openai/test")
            .match_header("authorization", "Bearer my-ad-token")
            .match_query(mockito::Matcher::Any)
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(server.url())
            .azure_ad_token("my-ad-token")
            .build()
            .unwrap();

        let resp: OkResp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // --- Task 2.5: Mutual exclusivity validation ---

    #[test]
    fn test_mutual_exclusivity_error() {
        let result = AzureConfig::new()
            .azure_endpoint("https://example.openai.azure.com")
            .api_key("key")
            .azure_ad_token("token")
            .build();

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("mutually exclusive"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_no_credentials_error() {
        let result = AzureConfig::new()
            .azure_endpoint("https://example.openai.azure.com")
            .build();

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("credentials required"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_no_endpoint_error() {
        let result = AzureConfig::new().api_key("key").build();

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("endpoint is required"),
            "unexpected error: {err}"
        );
    }

    // --- Task 2.6: from_env() ---

    #[test]
    fn test_from_env_reads_variables() {
        // Restores the previous values on drop, including when an assertion panics.
        let _env = EnvGuard::capture(&[
            "AZURE_OPENAI_ENDPOINT",
            "AZURE_OPENAI_API_KEY",
            "OPENAI_API_VERSION",
            "AZURE_OPENAI_AD_TOKEN",
        ]);

        // SAFETY / NOTE(review): `cargo test` runs tests on multiple threads, so
        // mutating the process environment is only sound if no other thread
        // reads or writes it concurrently. No other test in this module touches
        // these variables (`build()` never reads the environment), but that does
        // not cover the whole test binary. Consider `--test-threads=1`, or give
        // `from_env` an injectable variable lookup so this test needs no `unsafe`.
        unsafe {
            env::set_var("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com");
            env::set_var("AZURE_OPENAI_API_KEY", "env-key");
            env::set_var("OPENAI_API_VERSION", "2024-06-01");
            env::remove_var("AZURE_OPENAI_AD_TOKEN");
        }

        let client = AzureConfig::from_env().unwrap();

        assert_eq!(
            client.config.base_url(),
            "https://test.openai.azure.com/openai"
        );
        assert_eq!(client.config.api_key(), "env-key");

        let query = client.options.query.as_ref().unwrap();
        assert!(
            query
                .iter()
                .any(|(k, v)| k == "api-version" && v == "2024-06-01")
        );
    }

    // --- Task 2.7: End-to-end chat completion through Azure client ---

    #[tokio::test]
    async fn test_azure_chat_completion_e2e() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/openai/deployments/gpt-4/chat/completions")
            .match_header("api-key", "azure-key")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "api-version".into(),
                "2024-10-21".into(),
            )]))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                    "id": "chatcmpl-azure-123",
                    "object": "chat.completion",
                    "created": 1700000000,
                    "model": "gpt-4",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello from Azure!"
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 10,
                        "completion_tokens": 5,
                        "total_tokens": 15
                    }
                }"#,
            )
            .create_async()
            .await;

        let client = AzureConfig::new()
            .azure_endpoint(server.url())
            .azure_deployment("gpt-4")
            .api_key("azure-key")
            .build()
            .unwrap();

        use crate::types::chat::{ChatCompletionMessageParam, ChatCompletionRequest, UserContent};

        let request = ChatCompletionRequest::new(
            "gpt-4",
            vec![ChatCompletionMessageParam::User {
                content: UserContent::Text("Hello!".into()),
                name: None,
            }],
        );

        let response = client.chat().completions().create(request).await.unwrap();
        assert_eq!(response.id, "chatcmpl-azure-123");
        assert_eq!(
            response.choices[0].message.content.as_deref().unwrap_or(""),
            "Hello from Azure!"
        );
        mock.assert_async().await;
    }
}