1use llm_kit_openai_compatible::{
2 OpenAICompatibleChatConfig, OpenAICompatibleChatLanguageModel,
3 OpenAICompatibleCompletionConfig, OpenAICompatibleCompletionLanguageModel,
4 OpenAICompatibleEmbeddingConfig, OpenAICompatibleEmbeddingModel, OpenAICompatibleImageModel,
5 OpenAICompatibleImageModelConfig,
6};
7use llm_kit_provider::error::ProviderError;
8use llm_kit_provider::language_model::LanguageModel;
9use llm_kit_provider::provider::Provider;
10use llm_kit_provider::{EmbeddingModel, ImageModel};
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use crate::settings::AzureOpenAIProviderSettings;
15
/// Provider for Azure OpenAI Service deployments.
///
/// Wraps validated [`AzureOpenAIProviderSettings`] and hands out
/// OpenAI-compatible chat, completion, embedding, and image models that
/// target Azure deployment endpoints.
pub struct AzureOpenAIProvider {
    // Connection settings (endpoint/resource, api key, api version, custom
    // headers); validated once in `new`.
    settings: AzureOpenAIProviderSettings,
}
53
54impl AzureOpenAIProvider {
55 pub fn new(settings: AzureOpenAIProviderSettings) -> Self {
61 if let Err(e) = settings.validate() {
63 panic!("Invalid Azure OpenAI provider settings: {}", e);
64 }
65 Self { settings }
66 }
67
68 fn build_url(
74 base_url: &str,
75 deployment_id: &str,
76 path: &str,
77 api_version: &str,
78 use_deployment_based: bool,
79 ) -> String {
80 let full_path = if use_deployment_based {
81 format!("{}/deployments/{}{}", base_url, deployment_id, path)
82 } else {
83 format!("{}/v1{}", base_url, path)
84 };
85
86 match url::Url::parse(&full_path) {
88 Ok(mut url) => {
89 url.query_pairs_mut()
90 .append_pair("api-version", api_version);
91 url.to_string()
92 }
93 Err(_) => full_path,
94 }
95 }
96
97 pub fn model(&self, deployment_id: impl Into<String>) -> Arc<dyn LanguageModel> {
101 self.chat_model(deployment_id)
102 }
103
104 pub fn chat_model(&self, deployment_id: impl Into<String>) -> Arc<dyn LanguageModel> {
110 let deployment_id = deployment_id.into();
111 let config = self.create_chat_config();
112 Arc::new(OpenAICompatibleChatLanguageModel::new(
113 deployment_id,
114 config,
115 ))
116 }
117
118 pub fn completion_model(&self, deployment_id: impl Into<String>) -> Arc<dyn LanguageModel> {
124 let deployment_id = deployment_id.into();
125 let config = self.create_completion_config();
126 Arc::new(OpenAICompatibleCompletionLanguageModel::new(
127 deployment_id,
128 config,
129 ))
130 }
131
132 pub fn text_embedding_model(
138 &self,
139 deployment_id: impl Into<String>,
140 ) -> Arc<dyn EmbeddingModel<String>> {
141 let deployment_id = deployment_id.into();
142 let config = self.create_embedding_config();
143 Arc::new(OpenAICompatibleEmbeddingModel::new(deployment_id, config))
144 }
145
146 pub fn image_model(&self, deployment_id: impl Into<String>) -> Arc<dyn ImageModel> {
152 let deployment_id = deployment_id.into();
153 let config = self.create_image_config();
154 Arc::new(OpenAICompatibleImageModel::new(deployment_id, config))
155 }
156
157 fn create_chat_config(&self) -> OpenAICompatibleChatConfig {
159 let api_key = self.settings.api_key.clone();
160 let custom_headers = self.settings.headers.clone().unwrap_or_default();
161 let base_url = self
162 .settings
163 .get_base_url()
164 .expect("Base URL should be validated");
165 let api_version = self.settings.api_version.clone();
166 let use_deployment_based = self.settings.use_deployment_based_urls;
167
168 OpenAICompatibleChatConfig {
169 provider: "azure.chat".to_string(),
170 headers: Box::new(move || {
171 let mut headers = HashMap::new();
172
173 if let Some(ref key) = api_key {
175 headers.insert("api-key".to_string(), key.clone());
176 }
177
178 for (key, value) in &custom_headers {
180 headers.insert(key.clone(), value.clone());
181 }
182
183 headers
184 }),
185 url: Box::new(move |model_id: &str, path: &str| {
186 Self::build_url(
187 &base_url,
188 model_id,
189 path,
190 &api_version,
191 use_deployment_based,
192 )
193 }),
194 include_usage: true,
195 supports_structured_outputs: false,
196 supported_urls: None,
197 }
198 }
199
200 fn create_completion_config(&self) -> OpenAICompatibleCompletionConfig {
202 let api_key = self.settings.api_key.clone();
203 let custom_headers = self.settings.headers.clone().unwrap_or_default();
204 let base_url = self
205 .settings
206 .get_base_url()
207 .expect("Base URL should be validated");
208 let api_version = self.settings.api_version.clone();
209 let use_deployment_based = self.settings.use_deployment_based_urls;
210
211 OpenAICompatibleCompletionConfig {
212 provider: "azure.completion".to_string(),
213 headers: Box::new(move || {
214 let mut headers = HashMap::new();
215
216 if let Some(ref key) = api_key {
217 headers.insert("api-key".to_string(), key.clone());
218 }
219
220 for (key, value) in &custom_headers {
221 headers.insert(key.clone(), value.clone());
222 }
223
224 headers
225 }),
226 url: Box::new(move |model_id: &str, path: &str| {
227 Self::build_url(
228 &base_url,
229 model_id,
230 path,
231 &api_version,
232 use_deployment_based,
233 )
234 }),
235 include_usage: true,
236 }
237 }
238
239 fn create_embedding_config(&self) -> OpenAICompatibleEmbeddingConfig {
241 let api_key = self.settings.api_key.clone();
242 let custom_headers = self.settings.headers.clone().unwrap_or_default();
243 let base_url = self
244 .settings
245 .get_base_url()
246 .expect("Base URL should be validated");
247 let api_version = self.settings.api_version.clone();
248 let use_deployment_based = self.settings.use_deployment_based_urls;
249
250 OpenAICompatibleEmbeddingConfig {
251 provider: "azure.embedding".to_string(),
252 headers: Box::new(move || {
253 let mut headers = HashMap::new();
254
255 if let Some(ref key) = api_key {
256 headers.insert("api-key".to_string(), key.clone());
257 }
258
259 for (key, value) in &custom_headers {
260 headers.insert(key.clone(), value.clone());
261 }
262
263 headers
264 }),
265 url: Box::new(move |model_id: &str, path: &str| {
266 Self::build_url(
267 &base_url,
268 model_id,
269 path,
270 &api_version,
271 use_deployment_based,
272 )
273 }),
274 max_embeddings_per_call: None,
275 supports_parallel_calls: None,
276 }
277 }
278
279 fn create_image_config(&self) -> OpenAICompatibleImageModelConfig {
281 let api_key = self.settings.api_key.clone();
282 let custom_headers = self.settings.headers.clone().unwrap_or_default();
283 let base_url = self
284 .settings
285 .get_base_url()
286 .expect("Base URL should be validated");
287 let api_version = self.settings.api_version.clone();
288 let use_deployment_based = self.settings.use_deployment_based_urls;
289
290 OpenAICompatibleImageModelConfig {
291 provider: "azure.image".to_string(),
292 headers: Box::new(move || {
293 let mut headers = HashMap::new();
294
295 if let Some(ref key) = api_key {
296 headers.insert("api-key".to_string(), key.clone());
297 }
298
299 for (key, value) in &custom_headers {
300 headers.insert(key.clone(), value.clone());
301 }
302
303 headers
304 }),
305 url: Box::new(move |model_id: &str, path: &str| {
306 Self::build_url(
307 &base_url,
308 model_id,
309 path,
310 &api_version,
311 use_deployment_based,
312 )
313 }),
314 }
315 }
316
317 pub fn name(&self) -> &str {
319 "azure"
320 }
321}
322
impl Provider for AzureOpenAIProvider {
    /// Resolves a deployment id to a chat language model (delegates to
    /// `chat_model`); infallible for this provider.
    fn language_model(&self, deployment_id: &str) -> Result<Arc<dyn LanguageModel>, ProviderError> {
        Ok(self.chat_model(deployment_id))
    }

    /// Resolves a deployment id to a text-embedding model (delegates to the
    /// inherent `text_embedding_model`); infallible for this provider.
    fn text_embedding_model(
        &self,
        deployment_id: &str,
    ) -> Result<Arc<dyn EmbeddingModel<String>>, ProviderError> {
        Ok(self.text_embedding_model(deployment_id))
    }

    /// Resolves a deployment id to an image model (delegates to the inherent
    /// `image_model`); infallible for this provider.
    fn image_model(&self, deployment_id: &str) -> Result<Arc<dyn ImageModel>, ProviderError> {
        Ok(self.image_model(deployment_id))
    }

    /// Transcription models are not offered through this provider; always
    /// returns a `no_such_model` error.
    fn transcription_model(
        &self,
        deployment_id: &str,
    ) -> Result<Arc<dyn llm_kit_provider::TranscriptionModel>, ProviderError> {
        Err(ProviderError::no_such_model(
            deployment_id,
            "azure.transcription-model-not-supported",
        ))
    }

    /// Speech (TTS) models are not offered through this provider; always
    /// returns a `no_such_model` error.
    fn speech_model(
        &self,
        deployment_id: &str,
    ) -> Result<Arc<dyn llm_kit_provider::SpeechModel>, ProviderError> {
        Err(ProviderError::no_such_model(
            deployment_id,
            "azure.speech-model-not-supported",
        ))
    }

    /// Reranking models are not offered through this provider; always returns
    /// a `no_such_model` error.
    fn reranking_model(
        &self,
        deployment_id: &str,
    ) -> Result<Arc<dyn llm_kit_provider::RerankingModel>, ProviderError> {
        Err(ProviderError::no_such_model(
            deployment_id,
            "azure.reranking-model-not-supported",
        ))
    }
}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a provider configured with a dummy resource name and API key.
    fn create_test_provider() -> AzureOpenAIProvider {
        let settings = AzureOpenAIProviderSettings::new()
            .with_resource_name("test-resource")
            .with_api_key("test-key");
        AzureOpenAIProvider::new(settings)
    }

    #[test]
    fn test_create_azure_provider() {
        assert_eq!(create_test_provider().name(), "azure");
    }

    #[test]
    fn test_chat_model() {
        let model = create_test_provider().chat_model("gpt-4-deployment");
        assert_eq!(model.provider(), "azure.chat");
        assert_eq!(model.model_id(), "gpt-4-deployment");
    }

    #[test]
    fn test_completion_model() {
        let model = create_test_provider().completion_model("gpt-35-turbo-instruct");
        assert_eq!(model.provider(), "azure.completion");
        assert_eq!(model.model_id(), "gpt-35-turbo-instruct");
    }

    #[test]
    fn test_text_embedding_model() {
        let model = create_test_provider().text_embedding_model("text-embedding-ada-002");
        assert_eq!(model.provider(), "azure.embedding");
        assert_eq!(model.model_id(), "text-embedding-ada-002");
    }

    #[test]
    fn test_image_model() {
        let model = create_test_provider().image_model("dall-e-3");
        assert_eq!(model.provider(), "azure.image");
        assert_eq!(model.model_id(), "dall-e-3");
    }

    #[test]
    fn test_model_alias() {
        // `model` is an alias for `chat_model`.
        let model = create_test_provider().model("gpt-4-deployment");
        assert_eq!(model.provider(), "azure.chat");
        assert_eq!(model.model_id(), "gpt-4-deployment");
    }

    #[test]
    fn test_provider_trait_implementation() {
        let provider = create_test_provider();
        let p: &dyn Provider = &provider;

        let lm = p.language_model("gpt-4-deployment").unwrap();
        assert_eq!(lm.provider(), "azure.chat");
        assert_eq!(lm.model_id(), "gpt-4-deployment");

        let em = p.text_embedding_model("text-embedding-ada-002").unwrap();
        assert_eq!(em.provider(), "azure.embedding");
        assert_eq!(em.model_id(), "text-embedding-ada-002");

        let im = p.image_model("dall-e-3").unwrap();
        assert_eq!(im.provider(), "azure.image");
        assert_eq!(im.model_id(), "dall-e-3");

        // Unsupported model kinds surface as errors.
        assert!(p.transcription_model("whisper").is_err());
        assert!(p.speech_model("tts-1").is_err());
        assert!(p.reranking_model("rerank-1").is_err());
    }

    #[test]
    fn test_build_url_v1_format() {
        let base = "https://test.openai.azure.com/openai";
        let url =
            AzureOpenAIProvider::build_url(base, "gpt-4-deployment", "/chat/completions", "2024-02-15-preview", false);

        assert!(url.contains("/v1/chat/completions"));
        assert!(url.contains("api-version=2024-02-15-preview"));
    }

    #[test]
    fn test_build_url_deployment_based_format() {
        let base = "https://test.openai.azure.com/openai";
        let url =
            AzureOpenAIProvider::build_url(base, "gpt-4-deployment", "/chat/completions", "2024-02-15-preview", true);

        assert!(url.contains("/deployments/gpt-4-deployment/chat/completions"));
        assert!(url.contains("api-version=2024-02-15-preview"));
    }

    #[test]
    fn test_with_base_url() {
        let settings = AzureOpenAIProviderSettings::new()
            .with_base_url("https://custom.endpoint.com/openai")
            .with_api_key("test-key");
        let provider = AzureOpenAIProvider::new(settings);

        assert_eq!(provider.chat_model("gpt-4").provider(), "azure.chat");
    }

    #[test]
    fn test_with_custom_api_version() {
        let settings = AzureOpenAIProviderSettings::new()
            .with_resource_name("test-resource")
            .with_api_key("test-key")
            .with_api_version("2023-05-15");
        let provider = AzureOpenAIProvider::new(settings);

        assert_eq!(provider.chat_model("gpt-4").provider(), "azure.chat");
    }

    #[test]
    fn test_with_deployment_based_urls() {
        let settings = AzureOpenAIProviderSettings::new()
            .with_resource_name("test-resource")
            .with_api_key("test-key")
            .with_use_deployment_based_urls(true);
        let provider = AzureOpenAIProvider::new(settings);

        assert_eq!(provider.chat_model("gpt-4").provider(), "azure.chat");
    }

    #[test]
    #[should_panic(expected = "Invalid Azure OpenAI provider settings")]
    fn test_provider_without_url_or_resource_panics() {
        // Neither a base URL nor a resource name: validation must fail.
        AzureOpenAIProvider::new(AzureOpenAIProviderSettings::new().with_api_key("test-key"));
    }
}