// llm_kit_azure/provider.rs

1use llm_kit_openai_compatible::{
2    OpenAICompatibleChatConfig, OpenAICompatibleChatLanguageModel,
3    OpenAICompatibleCompletionConfig, OpenAICompatibleCompletionLanguageModel,
4    OpenAICompatibleEmbeddingConfig, OpenAICompatibleEmbeddingModel, OpenAICompatibleImageModel,
5    OpenAICompatibleImageModelConfig,
6};
7use llm_kit_provider::error::ProviderError;
8use llm_kit_provider::language_model::LanguageModel;
9use llm_kit_provider::provider::Provider;
10use llm_kit_provider::{EmbeddingModel, ImageModel};
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use crate::settings::AzureOpenAIProviderSettings;
15
/// Azure OpenAI provider implementation.
///
/// This provider creates language models, embedding models, and image models
/// that use Azure OpenAI endpoints. It handles Azure-specific authentication
/// (api-key header) and URL formatting (deployment-based or v1 API).
///
/// # Examples
///
/// ## Using Builder Pattern (Recommended)
///
/// ```no_run
/// use llm_kit_azure::AzureClient;
///
/// let provider = AzureClient::new()
///     .resource_name("my-azure-resource")
///     .api_key("your-api-key")
///     .build();
///
/// let model = provider.chat_model("gpt-4-deployment");
/// ```
///
/// ## Direct Instantiation (Alternative)
///
/// ```no_run
/// use llm_kit_azure::{AzureOpenAIProvider, AzureOpenAIProviderSettings};
///
/// let provider = AzureOpenAIProvider::new(
///     AzureOpenAIProviderSettings::new()
///         .with_resource_name("my-azure-resource")
///         .with_api_key("your-api-key")
/// );
///
/// let model = provider.chat_model("gpt-4-deployment");
/// ```
pub struct AzureOpenAIProvider {
    /// Provider settings (API key, base URL / resource name, api-version,
    /// custom headers); validated once in [`AzureOpenAIProvider::new`].
    settings: AzureOpenAIProviderSettings,
}
53
54impl AzureOpenAIProvider {
55    /// Creates a new Azure OpenAI provider.
56    ///
57    /// # Panics
58    ///
59    /// Panics if the settings are invalid (missing both base_url and resource_name).
60    pub fn new(settings: AzureOpenAIProviderSettings) -> Self {
61        // Validate settings on creation
62        if let Err(e) = settings.validate() {
63            panic!("Invalid Azure OpenAI provider settings: {}", e);
64        }
65        Self { settings }
66    }
67
68    /// Helper function to build URLs for Azure OpenAI API calls.
69    ///
70    /// Azure supports two URL formats:
71    /// 1. Deployment-based (legacy): `{base_url}/deployments/{deployment_id}{path}?api-version={version}`
72    /// 2. V1 API (default): `{base_url}/v1{path}?api-version={version}`
73    fn build_url(
74        base_url: &str,
75        deployment_id: &str,
76        path: &str,
77        api_version: &str,
78        use_deployment_based: bool,
79    ) -> String {
80        let full_path = if use_deployment_based {
81            format!("{}/deployments/{}{}", base_url, deployment_id, path)
82        } else {
83            format!("{}/v1{}", base_url, path)
84        };
85
86        // Add api-version query parameter
87        match url::Url::parse(&full_path) {
88            Ok(mut url) => {
89                url.query_pairs_mut()
90                    .append_pair("api-version", api_version);
91                url.to_string()
92            }
93            Err(_) => full_path,
94        }
95    }
96
97    /// Creates a language model with the given deployment ID.
98    ///
99    /// This is an alias for `chat_model()`.
100    pub fn model(&self, deployment_id: impl Into<String>) -> Arc<dyn LanguageModel> {
101        self.chat_model(deployment_id)
102    }
103
104    /// Creates a chat language model with the given deployment ID.
105    ///
106    /// # Arguments
107    ///
108    /// * `deployment_id` - The deployment name/ID in Azure OpenAI
109    pub fn chat_model(&self, deployment_id: impl Into<String>) -> Arc<dyn LanguageModel> {
110        let deployment_id = deployment_id.into();
111        let config = self.create_chat_config();
112        Arc::new(OpenAICompatibleChatLanguageModel::new(
113            deployment_id,
114            config,
115        ))
116    }
117
118    /// Creates a completion language model with the given deployment ID.
119    ///
120    /// # Arguments
121    ///
122    /// * `deployment_id` - The deployment name/ID in Azure OpenAI
123    pub fn completion_model(&self, deployment_id: impl Into<String>) -> Arc<dyn LanguageModel> {
124        let deployment_id = deployment_id.into();
125        let config = self.create_completion_config();
126        Arc::new(OpenAICompatibleCompletionLanguageModel::new(
127            deployment_id,
128            config,
129        ))
130    }
131
132    /// Creates a text embedding model with the given deployment ID.
133    ///
134    /// # Arguments
135    ///
136    /// * `deployment_id` - The deployment name/ID in Azure OpenAI
137    pub fn text_embedding_model(
138        &self,
139        deployment_id: impl Into<String>,
140    ) -> Arc<dyn EmbeddingModel<String>> {
141        let deployment_id = deployment_id.into();
142        let config = self.create_embedding_config();
143        Arc::new(OpenAICompatibleEmbeddingModel::new(deployment_id, config))
144    }
145
146    /// Creates an image model with the given deployment ID.
147    ///
148    /// # Arguments
149    ///
150    /// * `deployment_id` - The deployment name/ID in Azure OpenAI (e.g., dall-e-3 deployment)
151    pub fn image_model(&self, deployment_id: impl Into<String>) -> Arc<dyn ImageModel> {
152        let deployment_id = deployment_id.into();
153        let config = self.create_image_config();
154        Arc::new(OpenAICompatibleImageModel::new(deployment_id, config))
155    }
156
157    /// Creates the configuration for chat models.
158    fn create_chat_config(&self) -> OpenAICompatibleChatConfig {
159        let api_key = self.settings.api_key.clone();
160        let custom_headers = self.settings.headers.clone().unwrap_or_default();
161        let base_url = self
162            .settings
163            .get_base_url()
164            .expect("Base URL should be validated");
165        let api_version = self.settings.api_version.clone();
166        let use_deployment_based = self.settings.use_deployment_based_urls;
167
168        OpenAICompatibleChatConfig {
169            provider: "azure.chat".to_string(),
170            headers: Box::new(move || {
171                let mut headers = HashMap::new();
172
173                // Azure uses 'api-key' header instead of 'Authorization: Bearer'
174                if let Some(ref key) = api_key {
175                    headers.insert("api-key".to_string(), key.clone());
176                }
177
178                // Add custom headers
179                for (key, value) in &custom_headers {
180                    headers.insert(key.clone(), value.clone());
181                }
182
183                headers
184            }),
185            url: Box::new(move |model_id: &str, path: &str| {
186                Self::build_url(
187                    &base_url,
188                    model_id,
189                    path,
190                    &api_version,
191                    use_deployment_based,
192                )
193            }),
194            include_usage: true,
195            supports_structured_outputs: false,
196            supported_urls: None,
197        }
198    }
199
200    /// Creates the configuration for completion models.
201    fn create_completion_config(&self) -> OpenAICompatibleCompletionConfig {
202        let api_key = self.settings.api_key.clone();
203        let custom_headers = self.settings.headers.clone().unwrap_or_default();
204        let base_url = self
205            .settings
206            .get_base_url()
207            .expect("Base URL should be validated");
208        let api_version = self.settings.api_version.clone();
209        let use_deployment_based = self.settings.use_deployment_based_urls;
210
211        OpenAICompatibleCompletionConfig {
212            provider: "azure.completion".to_string(),
213            headers: Box::new(move || {
214                let mut headers = HashMap::new();
215
216                if let Some(ref key) = api_key {
217                    headers.insert("api-key".to_string(), key.clone());
218                }
219
220                for (key, value) in &custom_headers {
221                    headers.insert(key.clone(), value.clone());
222                }
223
224                headers
225            }),
226            url: Box::new(move |model_id: &str, path: &str| {
227                Self::build_url(
228                    &base_url,
229                    model_id,
230                    path,
231                    &api_version,
232                    use_deployment_based,
233                )
234            }),
235            include_usage: true,
236        }
237    }
238
239    /// Creates the configuration for embedding models.
240    fn create_embedding_config(&self) -> OpenAICompatibleEmbeddingConfig {
241        let api_key = self.settings.api_key.clone();
242        let custom_headers = self.settings.headers.clone().unwrap_or_default();
243        let base_url = self
244            .settings
245            .get_base_url()
246            .expect("Base URL should be validated");
247        let api_version = self.settings.api_version.clone();
248        let use_deployment_based = self.settings.use_deployment_based_urls;
249
250        OpenAICompatibleEmbeddingConfig {
251            provider: "azure.embedding".to_string(),
252            headers: Box::new(move || {
253                let mut headers = HashMap::new();
254
255                if let Some(ref key) = api_key {
256                    headers.insert("api-key".to_string(), key.clone());
257                }
258
259                for (key, value) in &custom_headers {
260                    headers.insert(key.clone(), value.clone());
261                }
262
263                headers
264            }),
265            url: Box::new(move |model_id: &str, path: &str| {
266                Self::build_url(
267                    &base_url,
268                    model_id,
269                    path,
270                    &api_version,
271                    use_deployment_based,
272                )
273            }),
274            max_embeddings_per_call: None,
275            supports_parallel_calls: None,
276        }
277    }
278
279    /// Creates the configuration for image models.
280    fn create_image_config(&self) -> OpenAICompatibleImageModelConfig {
281        let api_key = self.settings.api_key.clone();
282        let custom_headers = self.settings.headers.clone().unwrap_or_default();
283        let base_url = self
284            .settings
285            .get_base_url()
286            .expect("Base URL should be validated");
287        let api_version = self.settings.api_version.clone();
288        let use_deployment_based = self.settings.use_deployment_based_urls;
289
290        OpenAICompatibleImageModelConfig {
291            provider: "azure.image".to_string(),
292            headers: Box::new(move || {
293                let mut headers = HashMap::new();
294
295                if let Some(ref key) = api_key {
296                    headers.insert("api-key".to_string(), key.clone());
297                }
298
299                for (key, value) in &custom_headers {
300                    headers.insert(key.clone(), value.clone());
301                }
302
303                headers
304            }),
305            url: Box::new(move |model_id: &str, path: &str| {
306                Self::build_url(
307                    &base_url,
308                    model_id,
309                    path,
310                    &api_version,
311                    use_deployment_based,
312                )
313            }),
314        }
315    }
316
317    /// Gets the provider name.
318    pub fn name(&self) -> &str {
319        "azure"
320    }
321}
322
323// Implement the Provider trait
324impl Provider for AzureOpenAIProvider {
325    fn language_model(&self, deployment_id: &str) -> Result<Arc<dyn LanguageModel>, ProviderError> {
326        Ok(self.chat_model(deployment_id))
327    }
328
329    fn text_embedding_model(
330        &self,
331        deployment_id: &str,
332    ) -> Result<Arc<dyn EmbeddingModel<String>>, ProviderError> {
333        Ok(self.text_embedding_model(deployment_id))
334    }
335
336    fn image_model(&self, deployment_id: &str) -> Result<Arc<dyn ImageModel>, ProviderError> {
337        Ok(self.image_model(deployment_id))
338    }
339
340    fn transcription_model(
341        &self,
342        deployment_id: &str,
343    ) -> Result<Arc<dyn llm_kit_provider::TranscriptionModel>, ProviderError> {
344        Err(ProviderError::no_such_model(
345            deployment_id,
346            "azure.transcription-model-not-supported",
347        ))
348    }
349
350    fn speech_model(
351        &self,
352        deployment_id: &str,
353    ) -> Result<Arc<dyn llm_kit_provider::SpeechModel>, ProviderError> {
354        Err(ProviderError::no_such_model(
355            deployment_id,
356            "azure.speech-model-not-supported",
357        ))
358    }
359
360    fn reranking_model(
361        &self,
362        deployment_id: &str,
363    ) -> Result<Arc<dyn llm_kit_provider::RerankingModel>, ProviderError> {
364        Err(ProviderError::no_such_model(
365            deployment_id,
366            "azure.reranking-model-not-supported",
367        ))
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a provider from the minimal valid settings: a resource name
    // plus an API key. Most tests below share this fixture.
    fn create_test_provider() -> AzureOpenAIProvider {
        AzureOpenAIProvider::new(
            AzureOpenAIProviderSettings::new()
                .with_resource_name("test-resource")
                .with_api_key("test-key"),
        )
    }

    #[test]
    fn test_create_azure_provider() {
        let provider = create_test_provider();
        assert_eq!(provider.name(), "azure");
    }

    #[test]
    fn test_chat_model() {
        // Chat models report the "azure.chat" provider tag and echo the
        // deployment id as their model id.
        let provider = create_test_provider();
        let model = provider.chat_model("gpt-4-deployment");
        assert_eq!(model.provider(), "azure.chat");
        assert_eq!(model.model_id(), "gpt-4-deployment");
    }

    #[test]
    fn test_completion_model() {
        let provider = create_test_provider();
        let model = provider.completion_model("gpt-35-turbo-instruct");
        assert_eq!(model.provider(), "azure.completion");
        assert_eq!(model.model_id(), "gpt-35-turbo-instruct");
    }

    #[test]
    fn test_text_embedding_model() {
        let provider = create_test_provider();
        let model = provider.text_embedding_model("text-embedding-ada-002");
        assert_eq!(model.provider(), "azure.embedding");
        assert_eq!(model.model_id(), "text-embedding-ada-002");
    }

    #[test]
    fn test_image_model() {
        let provider = create_test_provider();
        let model = provider.image_model("dall-e-3");
        assert_eq!(model.provider(), "azure.image");
        assert_eq!(model.model_id(), "dall-e-3");
    }

    #[test]
    fn test_model_alias() {
        // `model()` is documented as an alias for `chat_model()`.
        let provider = create_test_provider();
        let model = provider.model("gpt-4-deployment");
        assert_eq!(model.provider(), "azure.chat");
        assert_eq!(model.model_id(), "gpt-4-deployment");
    }

    #[test]
    fn test_provider_trait_implementation() {
        // Exercises the provider through the `dyn Provider` interface to
        // confirm the trait impl delegates to the inherent methods.
        let provider = create_test_provider();
        let provider_trait: &dyn Provider = &provider;

        // Test language model
        let model = provider_trait.language_model("gpt-4-deployment").unwrap();
        assert_eq!(model.provider(), "azure.chat");
        assert_eq!(model.model_id(), "gpt-4-deployment");

        // Test text embedding model
        let embedding_model = provider_trait
            .text_embedding_model("text-embedding-ada-002")
            .unwrap();
        assert_eq!(embedding_model.provider(), "azure.embedding");
        assert_eq!(embedding_model.model_id(), "text-embedding-ada-002");

        // Test image model
        let image_model = provider_trait.image_model("dall-e-3").unwrap();
        assert_eq!(image_model.provider(), "azure.image");
        assert_eq!(image_model.model_id(), "dall-e-3");

        // Test unsupported models
        assert!(provider_trait.transcription_model("whisper").is_err());
        assert!(provider_trait.speech_model("tts-1").is_err());
        assert!(provider_trait.reranking_model("rerank-1").is_err());
    }

    #[test]
    fn test_build_url_v1_format() {
        // Default (v1) URL format: path goes under /v1, deployment id unused.
        let url = AzureOpenAIProvider::build_url(
            "https://test.openai.azure.com/openai",
            "gpt-4-deployment",
            "/chat/completions",
            "2024-02-15-preview",
            false,
        );

        assert!(url.contains("/v1/chat/completions"));
        assert!(url.contains("api-version=2024-02-15-preview"));
    }

    #[test]
    fn test_build_url_deployment_based_format() {
        // Legacy URL format: path goes under /deployments/{deployment_id}.
        let url = AzureOpenAIProvider::build_url(
            "https://test.openai.azure.com/openai",
            "gpt-4-deployment",
            "/chat/completions",
            "2024-02-15-preview",
            true,
        );

        assert!(url.contains("/deployments/gpt-4-deployment/chat/completions"));
        assert!(url.contains("api-version=2024-02-15-preview"));
    }

    #[test]
    fn test_with_base_url() {
        // A base URL may be supplied directly instead of a resource name.
        let provider = AzureOpenAIProvider::new(
            AzureOpenAIProviderSettings::new()
                .with_base_url("https://custom.endpoint.com/openai")
                .with_api_key("test-key"),
        );

        let model = provider.chat_model("gpt-4");
        assert_eq!(model.provider(), "azure.chat");
    }

    #[test]
    fn test_with_custom_api_version() {
        let provider = AzureOpenAIProvider::new(
            AzureOpenAIProviderSettings::new()
                .with_resource_name("test-resource")
                .with_api_key("test-key")
                .with_api_version("2023-05-15"),
        );

        let model = provider.chat_model("gpt-4");
        assert_eq!(model.provider(), "azure.chat");
    }

    #[test]
    fn test_with_deployment_based_urls() {
        let provider = AzureOpenAIProvider::new(
            AzureOpenAIProviderSettings::new()
                .with_resource_name("test-resource")
                .with_api_key("test-key")
                .with_use_deployment_based_urls(true),
        );

        let model = provider.chat_model("gpt-4");
        assert_eq!(model.provider(), "azure.chat");
    }

    #[test]
    #[should_panic(expected = "Invalid Azure OpenAI provider settings")]
    fn test_provider_without_url_or_resource_panics() {
        // `new` panics when neither a base URL nor a resource name is given.
        AzureOpenAIProvider::new(AzureOpenAIProviderSettings::new().with_api_key("test-key"));
    }
}