llm_kit_azure/
client.rs

1use std::collections::HashMap;
2
3use crate::provider::AzureOpenAIProvider;
4use crate::settings::AzureOpenAIProviderSettings;
5
/// Builder for creating an Azure OpenAI provider.
///
/// Offers a fluent, by-value API for assembling an `AzureOpenAIProvider`: each setter consumes
/// the builder and returns it, and [`AzureClient::build`] turns the accumulated configuration
/// into a provider.
///
/// # Examples
///
/// ## Basic Usage with Resource Name
///
/// ```no_run
/// use llm_kit_azure::AzureClient;
///
/// let provider = AzureClient::new()
///     .resource_name("my-azure-resource")
///     .api_key("your-api-key")
///     .build();
///
/// let model = provider.chat_model("gpt-4-deployment");
/// ```
///
/// ## Custom Base URL
///
/// ```no_run
/// use llm_kit_azure::AzureClient;
///
/// let provider = AzureClient::new()
///     .base_url("https://my-resource.openai.azure.com/openai")
///     .api_key("your-api-key")
///     .build();
///
/// let model = provider.chat_model("gpt-4-deployment");
/// ```
///
/// ## With Custom Headers and API Version
///
/// ```no_run
/// use llm_kit_azure::AzureClient;
///
/// let provider = AzureClient::new()
///     .resource_name("my-resource")
///     .api_key("your-api-key")
///     .api_version("2024-02-15-preview")
///     .header("X-Custom-Header", "value")
///     .build();
///
/// let model = provider.chat_model("gpt-4-deployment");
/// ```
///
/// ## Deployment-Based URLs (Legacy Format)
///
/// ```no_run
/// use llm_kit_azure::AzureClient;
///
/// let provider = AzureClient::new()
///     .resource_name("my-resource")
///     .api_key("your-api-key")
///     .use_deployment_based_urls(true)
///     .build();
///
/// let model = provider.chat_model("gpt-4-deployment");
/// ```
// Every setter returns the updated builder by value; discarding that value silently loses
// the configuration, so make the compiler warn about it.
#[must_use = "builder methods return the updated builder; call `.build()` to create the provider"]
#[derive(Debug, Clone, Default)]
pub struct AzureClient {
    /// Azure OpenAI resource name; `None` until set via `resource_name`.
    resource_name: Option<String>,
    /// Explicit endpoint URL; `None` until set via `base_url`.
    base_url: Option<String>,
    /// API key; when `None`, `build` falls back to the `AZURE_API_KEY` environment variable.
    api_key: Option<String>,
    /// Extra request headers accumulated by `header` / `headers`.
    headers: HashMap<String, String>,
    /// API version override; only forwarded to the settings when set.
    api_version: Option<String>,
    /// Whether to use the legacy deployment-based URL layout (`false` by default).
    use_deployment_based_urls: bool,
}
75
76impl AzureClient {
77    /// Creates a new client builder with default settings.
78    ///
79    /// The default API version is "v1" and uses the v1 API URL format.
80    /// If no API key is provided, it will attempt to use the `AZURE_API_KEY` environment variable.
81    /// If no resource name or base URL is provided, it will attempt to use the `AZURE_RESOURCE_NAME`
82    /// or `AZURE_BASE_URL` environment variables.
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    /// Sets the Azure OpenAI resource name.
88    ///
89    /// The resource name is used to construct the base URL:
90    /// `https://{resource_name}.openai.azure.com/openai`
91    ///
92    /// # Arguments
93    ///
94    /// * `resource_name` - The Azure OpenAI resource name
95    ///
96    /// # Examples
97    ///
98    /// ```no_run
99    /// use llm_kit_azure::AzureClient;
100    ///
101    /// let client = AzureClient::new()
102    ///     .resource_name("my-azure-resource");
103    /// ```
104    pub fn resource_name(mut self, resource_name: impl Into<String>) -> Self {
105        self.resource_name = Some(resource_name.into());
106        self
107    }
108
109    /// Sets a custom base URL for API calls.
110    ///
111    /// When set, this takes precedence over `resource_name`.
112    ///
113    /// # Arguments
114    ///
115    /// * `base_url` - The base URL (e.g., "<https://my-resource.openai.azure.com/openai>")
116    ///
117    /// # Examples
118    ///
119    /// ```no_run
120    /// use llm_kit_azure::AzureClient;
121    ///
122    /// let client = AzureClient::new()
123    ///     .base_url("https://custom.endpoint.com/openai");
124    /// ```
125    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
126        self.base_url = Some(base_url.into());
127        self
128    }
129
130    /// Sets the API key for authentication.
131    ///
132    /// If not specified, the client will attempt to use the `AZURE_API_KEY` environment variable.
133    ///
134    /// # Arguments
135    ///
136    /// * `api_key` - The API key
137    ///
138    /// # Examples
139    ///
140    /// ```no_run
141    /// use llm_kit_azure::AzureClient;
142    ///
143    /// let client = AzureClient::new()
144    ///     .api_key("your-api-key");
145    /// ```
146    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
147        self.api_key = Some(api_key.into());
148        self
149    }
150
151    /// Adds a custom header to include in requests.
152    ///
153    /// # Arguments
154    ///
155    /// * `key` - The header name
156    /// * `value` - The header value
157    ///
158    /// # Examples
159    ///
160    /// ```no_run
161    /// use llm_kit_azure::AzureClient;
162    ///
163    /// let client = AzureClient::new()
164    ///     .header("X-Custom-Header", "value");
165    /// ```
166    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
167        self.headers.insert(key.into(), value.into());
168        self
169    }
170
171    /// Sets multiple custom headers at once.
172    ///
173    /// # Arguments
174    ///
175    /// * `headers` - A HashMap of header names to values
176    ///
177    /// # Examples
178    ///
179    /// ```no_run
180    /// use llm_kit_azure::AzureClient;
181    /// use std::collections::HashMap;
182    ///
183    /// let mut headers = HashMap::new();
184    /// headers.insert("X-Custom-1".to_string(), "value1".to_string());
185    /// headers.insert("X-Custom-2".to_string(), "value2".to_string());
186    ///
187    /// let client = AzureClient::new()
188    ///     .headers(headers);
189    /// ```
190    pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
191        self.headers.extend(headers);
192        self
193    }
194
195    /// Sets the API version to use in requests.
196    ///
197    /// Azure OpenAI requires an API version parameter in all requests.
198    /// Defaults to "v1" if not specified.
199    ///
200    /// Common versions include:
201    /// - "2023-05-15"
202    /// - "2024-02-15-preview"
203    /// - "v1" (default)
204    ///
205    /// # Arguments
206    ///
207    /// * `api_version` - The API version string
208    ///
209    /// # Examples
210    ///
211    /// ```no_run
212    /// use llm_kit_azure::AzureClient;
213    ///
214    /// let client = AzureClient::new()
215    ///     .api_version("2024-02-15-preview");
216    /// ```
217    pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
218        self.api_version = Some(api_version.into());
219        self
220    }
221
222    /// Sets whether to use deployment-based URLs.
223    ///
224    /// When `true`, uses legacy format:
225    /// `{base_url}/deployments/{deployment_id}{path}?api-version={version}`
226    ///
227    /// When `false` (default), uses v1 API format:
228    /// `{base_url}/v1{path}?api-version={version}`
229    ///
230    /// # Arguments
231    ///
232    /// * `use_deployment_based_urls` - Whether to use deployment-based URLs
233    ///
234    /// # Examples
235    ///
236    /// ```no_run
237    /// use llm_kit_azure::AzureClient;
238    ///
239    /// let client = AzureClient::new()
240    ///     .use_deployment_based_urls(true);
241    /// ```
242    pub fn use_deployment_based_urls(mut self, use_deployment_based_urls: bool) -> Self {
243        self.use_deployment_based_urls = use_deployment_based_urls;
244        self
245    }
246
247    /// Builds the Azure OpenAI provider.
248    ///
249    /// This method constructs the provider settings from the builder configuration
250    /// and creates an `AzureOpenAIProvider` instance.
251    ///
252    /// # Panics
253    ///
254    /// Panics if neither resource name nor base URL is provided (either via builder or environment variables).
255    ///
256    /// # Examples
257    ///
258    /// ```no_run
259    /// use llm_kit_azure::AzureClient;
260    ///
261    /// let provider = AzureClient::new()
262    ///     .resource_name("my-resource")
263    ///     .api_key("your-api-key")
264    ///     .build();
265    /// ```
266    pub fn build(self) -> AzureOpenAIProvider {
267        let mut settings = AzureOpenAIProviderSettings::new();
268
269        // Set resource name or base URL (prefer explicit values, fall back to env vars)
270        if let Some(resource_name) = self.resource_name {
271            settings = settings.with_resource_name(resource_name);
272        } else if let Some(base_url) = self.base_url {
273            settings = settings.with_base_url(base_url);
274        } else {
275            // Try environment variables
276            if let Ok(resource_name) = std::env::var("AZURE_RESOURCE_NAME") {
277                settings = settings.with_resource_name(resource_name);
278            } else if let Ok(base_url) = std::env::var("AZURE_BASE_URL") {
279                settings = settings.with_base_url(base_url);
280            }
281        }
282
283        // Set API key (prefer explicit value, fall back to env var)
284        if let Some(api_key) = self.api_key {
285            settings = settings.with_api_key(api_key);
286        } else if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
287            settings = settings.with_api_key(api_key);
288        }
289
290        // Set custom headers if provided
291        if !self.headers.is_empty() {
292            settings = settings.with_headers(self.headers);
293        }
294
295        // Set API version if provided
296        if let Some(api_version) = self.api_version {
297            settings = settings.with_api_version(api_version);
298        }
299
300        // Set deployment-based URLs flag
301        settings = settings.with_use_deployment_based_urls(self.use_deployment_based_urls);
302
303        AzureOpenAIProvider::new(settings)
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_with_resource_name() {
        let provider = AzureClient::new()
            .resource_name("test-resource")
            .api_key("test-key")
            .build();

        assert_eq!(provider.name(), "azure");
    }

    #[test]
    fn test_builder_with_base_url() {
        let provider = AzureClient::new()
            .base_url("https://custom.endpoint.com/openai")
            .api_key("test-key")
            .build();

        assert_eq!(provider.name(), "azure");
    }

    #[test]
    fn test_builder_with_api_version() {
        let provider = AzureClient::new()
            .resource_name("test-resource")
            .api_key("test-key")
            .api_version("2024-02-15-preview")
            .build();

        assert_eq!(provider.name(), "azure");
    }

    #[test]
    fn test_builder_with_headers() {
        let provider = AzureClient::new()
            .resource_name("test-resource")
            .api_key("test-key")
            .header("X-Custom-1", "value1")
            .header("X-Custom-2", "value2")
            .build();

        assert_eq!(provider.name(), "azure");
    }

    #[test]
    fn test_builder_with_multiple_headers() {
        let mut headers = HashMap::new();
        headers.insert("X-Custom-1".to_string(), "value1".to_string());
        headers.insert("X-Custom-2".to_string(), "value2".to_string());

        let provider = AzureClient::new()
            .resource_name("test-resource")
            .api_key("test-key")
            .headers(headers)
            .build();

        assert_eq!(provider.name(), "azure");
    }

    #[test]
    fn test_builder_with_deployment_based_urls() {
        let provider = AzureClient::new()
            .resource_name("test-resource")
            .api_key("test-key")
            .use_deployment_based_urls(true)
            .build();

        assert_eq!(provider.name(), "azure");
    }

    #[test]
    fn test_builder_full_configuration() {
        let provider = AzureClient::new()
            .resource_name("test-resource")
            .api_key("test-key")
            .api_version("2024-02-15-preview")
            .header("X-Custom", "value")
            .use_deployment_based_urls(true)
            .build();

        assert_eq!(provider.name(), "azure");
    }

    #[test]
    fn test_builder_creates_working_models() {
        let provider = AzureClient::new()
            .resource_name("test-resource")
            .api_key("test-key")
            .build();

        // Test chat model
        let chat_model = provider.chat_model("gpt-4");
        assert_eq!(chat_model.provider(), "azure.chat");
        assert_eq!(chat_model.model_id(), "gpt-4");

        // Test completion model
        let completion_model = provider.completion_model("gpt-35-turbo-instruct");
        assert_eq!(completion_model.provider(), "azure.completion");
        assert_eq!(completion_model.model_id(), "gpt-35-turbo-instruct");

        // Test embedding model
        let embedding_model = provider.text_embedding_model("text-embedding-ada-002");
        assert_eq!(embedding_model.provider(), "azure.embedding");
        assert_eq!(embedding_model.model_id(), "text-embedding-ada-002");

        // Test image model
        let image_model = provider.image_model("dall-e-3");
        assert_eq!(image_model.provider(), "azure.image");
        assert_eq!(image_model.model_id(), "dall-e-3");
    }

    #[test]
    fn test_builder_without_url_or_resource_panics() {
        // `build()` falls back to these environment variables when the builder has no
        // endpoint, so the panic can only be expected when neither is readable. Skip instead
        // of failing spuriously on machines (e.g. CI) that export them.
        if std::env::var("AZURE_RESOURCE_NAME").is_ok() || std::env::var("AZURE_BASE_URL").is_ok()
        {
            return;
        }

        let result = std::panic::catch_unwind(|| AzureClient::new().api_key("test-key").build());
        let payload = match result {
            Ok(_) => panic!("build() should panic without a resource name or base URL"),
            Err(payload) => payload,
        };

        // Panic payloads are a `String` for formatted messages and a `&str` for literals.
        let message = payload
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| payload.downcast_ref::<&str>().copied())
            .unwrap_or_default();
        assert!(
            message.contains("Invalid Azure OpenAI provider settings"),
            "unexpected panic message: {message}"
        );
    }
}