// rig/providers/azure.rs
1//! Azure OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::azure;
6//!
7//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
8//!
9//! let gpt4o = client.completion_model(azure::GPT_4O);
10//! ```
11
12use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
13
14use crate::json_utils::merge;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    completion::{self, CompletionError, CompletionRequest},
18    embeddings::{self, EmbeddingError},
19    json_utils,
20    providers::openai,
21    transcription::{self, TranscriptionError},
22};
23use reqwest::multipart::Part;
24use serde::Deserialize;
25use serde_json::json;
26// ================================================================
27// Main Azure OpenAI Client
28// ================================================================
29
/// Azure OpenAI API client.
///
/// Authentication is baked into `http_client` as default headers at
/// construction time (see [`Client::new`]), so no credential is stored here.
#[derive(Debug, Clone)]
pub struct Client {
    // API version string appended to every request URL as `api-version`,
    // e.g. "2024-10-21".
    api_version: String,
    // Base endpoint, e.g. https://{your-resource-name}.openai.azure.com
    azure_endpoint: String,
    // reqwest client pre-configured with the auth headers.
    http_client: reqwest::Client,
}
36
/// Credentials accepted by the Azure OpenAI client.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Sent as the `api-key` request header.
    ApiKey(String),
    /// Sent as an `Authorization: Bearer {token}` request header.
    Token(String),
}
42
43impl From<String> for AzureOpenAIAuth {
44    fn from(token: String) -> Self {
45        AzureOpenAIAuth::Token(token)
46    }
47}
48
49impl Client {
50    /// Creates a new Azure OpenAI client.
51    ///
52    /// # Arguments
53    ///
54    /// * `auth` - Azure OpenAI API key or token required for authentication
55    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
56    /// * `azure_endpoint` - Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
57    pub fn new(auth: impl Into<AzureOpenAIAuth>, api_version: &str, azure_endpoint: &str) -> Self {
58        let mut headers = reqwest::header::HeaderMap::new();
59        match auth.into() {
60            AzureOpenAIAuth::ApiKey(api_key) => {
61                headers.insert("api-key", api_key.parse().expect("API key should parse"));
62            }
63            AzureOpenAIAuth::Token(token) => {
64                headers.insert(
65                    "Authorization",
66                    format!("Bearer {token}")
67                        .parse()
68                        .expect("Token should parse"),
69                );
70            }
71        }
72
73        Self {
74            api_version: api_version.to_string(),
75            azure_endpoint: azure_endpoint.to_string(),
76            http_client: reqwest::Client::builder()
77                .default_headers(headers)
78                .build()
79                .expect("Azure OpenAI reqwest client should build"),
80        }
81    }
82
83    /// Creates a new Azure OpenAI client from an API key.
84    ///
85    /// # Arguments
86    ///
87    /// * `api_key` - Azure OpenAI API key required for authentication
88    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
89    /// * `azure_endpoint` - Azure OpenAI endpoint URL
90    pub fn from_api_key(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
91        Self::new(
92            AzureOpenAIAuth::ApiKey(api_key.to_string()),
93            api_version,
94            azure_endpoint,
95        )
96    }
97
98    /// Creates a new Azure OpenAI client from a token.
99    ///
100    /// # Arguments
101    ///
102    /// * `token` - Azure OpenAI token required for authentication
103    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
104    /// * `azure_endpoint` - Azure OpenAI endpoint URL
105    pub fn from_token(token: &str, api_version: &str, azure_endpoint: &str) -> Self {
106        Self::new(
107            AzureOpenAIAuth::Token(token.to_string()),
108            api_version,
109            azure_endpoint,
110        )
111    }
112
113    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
114        let url = format!(
115            "{}/openai/deployments/{}/embeddings?api-version={}",
116            self.azure_endpoint, deployment_id, self.api_version
117        )
118        .replace("//", "/");
119        self.http_client.post(url)
120    }
121
122    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
123        let url = format!(
124            "{}/openai/deployments/{}/chat/completions?api-version={}",
125            self.azure_endpoint, deployment_id, self.api_version
126        )
127        .replace("//", "/");
128        self.http_client.post(url)
129    }
130
131    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
132        let url = format!(
133            "{}/openai/deployments/{}/audio/translations?api-version={}",
134            self.azure_endpoint, deployment_id, self.api_version
135        )
136        .replace("//", "/");
137        self.http_client.post(url)
138    }
139
140    #[cfg(feature = "image")]
141    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
142        let url = format!(
143            "{}/openai/deployments/{}/images/generations?api-version={}",
144            self.azure_endpoint, deployment_id, self.api_version
145        )
146        .replace("//", "/");
147        self.http_client.post(url)
148    }
149
150    #[cfg(feature = "audio")]
151    fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
152        let url = format!(
153            "{}/openai/deployments/{}/audio/speech?api-version={}",
154            self.azure_endpoint, deployment_id, self.api_version
155        )
156        .replace("//", "/");
157        self.http_client.post(url)
158    }
159}
160
161impl ProviderClient for Client {
162    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
163    fn from_env() -> Self {
164        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
165            AzureOpenAIAuth::ApiKey(api_key)
166        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
167            AzureOpenAIAuth::Token(token)
168        } else {
169            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
170        };
171
172        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
173        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
174
175        Self::new(auth, &api_version, &azure_endpoint)
176    }
177}
178
179impl CompletionClient for Client {
180    type CompletionModel = CompletionModel;
181
182    /// Create a completion model with the given name.
183    ///
184    /// # Example
185    /// ```
186    /// use rig::providers::azure::{Client, self};
187    ///
188    /// // Initialize the Azure OpenAI client
189    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
190    ///
191    /// let gpt4 = azure.completion_model(azure::GPT_4);
192    /// ```
193    fn completion_model(&self, model: &str) -> CompletionModel {
194        CompletionModel::new(self.clone(), model)
195    }
196}
197
198impl EmbeddingsClient for Client {
199    type EmbeddingModel = EmbeddingModel;
200
201    /// Create an embedding model with the given name.
202    /// Note: default embedding dimension of 0 will be used if model is not known.
203    /// If this is the case, it's better to use function `embedding_model_with_ndims`
204    ///
205    /// # Example
206    /// ```
207    /// use rig::providers::azure::{Client, self};
208    ///
209    /// // Initialize the Azure OpenAI client
210    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
211    ///
212    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
213    /// ```
214    fn embedding_model(&self, model: &str) -> EmbeddingModel {
215        let ndims = match model {
216            TEXT_EMBEDDING_3_LARGE => 3072,
217            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
218            _ => 0,
219        };
220        EmbeddingModel::new(self.clone(), model, ndims)
221    }
222
223    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
224    ///
225    /// # Example
226    /// ```
227    /// use rig::providers::azure::{Client, self};
228    ///
229    /// // Initialize the Azure OpenAI client
230    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
231    ///
232    /// let embedding_model = azure.embedding_model("model-unknown-to-rig", 3072);
233    /// ```
234    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
235        EmbeddingModel::new(self.clone(), model, ndims)
236    }
237}
238
239impl TranscriptionClient for Client {
240    type TranscriptionModel = TranscriptionModel;
241
242    /// Create a transcription model with the given name.
243    ///
244    /// # Example
245    /// ```
246    /// use rig::providers::azure::{Client, self};
247    ///
248    /// // Initialize the Azure OpenAI client
249    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
250    ///
251    /// let whisper = azure.transcription_model("model-unknown-to-rig");
252    /// ```
253    fn transcription_model(&self, model: &str) -> TranscriptionModel {
254        TranscriptionModel::new(self.clone(), model)
255    }
256}
257
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error text; surfaced to callers via the
    // `*Error::ProviderError` variants.
    message: String,
}
262
/// Either a successful payload of type `T` or an API error body.
/// `#[serde(untagged)]` tries the variants in order: `T` first, then the
/// error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
269
// ================================================================
// Azure OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model (3072 dimensions)
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 dimensions)
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 dimensions)
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
279
/// Response body of the Azure OpenAI embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    // One entry per input document (length is validated in `embed_texts`).
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
287
288impl From<ApiErrorResponse> for EmbeddingError {
289    fn from(err: ApiErrorResponse) -> Self {
290        EmbeddingError::ProviderError(err.message)
291    }
292}
293
294impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
295    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
296        match value {
297            ApiResponse::Ok(response) => Ok(response),
298            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
299        }
300    }
301}
302
/// One embedding entry in an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    // The embedding vector itself.
    pub embedding: Vec<f64>,
    // Position of the corresponding input document.
    pub index: usize,
}
309
/// Token usage reported by the Azure OpenAI embeddings endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
315
316impl std::fmt::Display for Usage {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        write!(
319            f,
320            "Prompt tokens: {} Total tokens: {}",
321            self.prompt_tokens, self.total_tokens
322        )
323    }
324}
325
/// Embedding model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Model name; also used as the deployment id in the request URL.
    pub model: String,
    // Dimension count reported by `ndims()`; 0 when the model is unknown to rig.
    ndims: usize,
}
332
333impl embeddings::EmbeddingModel for EmbeddingModel {
334    const MAX_DOCUMENTS: usize = 1024;
335
336    fn ndims(&self) -> usize {
337        self.ndims
338    }
339
340    #[cfg_attr(feature = "worker", worker::send)]
341    async fn embed_texts(
342        &self,
343        documents: impl IntoIterator<Item = String>,
344    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
345        let documents = documents.into_iter().collect::<Vec<_>>();
346
347        let response = self
348            .client
349            .post_embedding(&self.model)
350            .json(&json!({
351                "input": documents,
352            }))
353            .send()
354            .await?;
355
356        if response.status().is_success() {
357            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
358                ApiResponse::Ok(response) => {
359                    tracing::info!(target: "rig",
360                        "Azure embedding token usage: {}",
361                        response.usage
362                    );
363
364                    if response.data.len() != documents.len() {
365                        return Err(EmbeddingError::ResponseError(
366                            "Response data length does not match input length".into(),
367                        ));
368                    }
369
370                    Ok(response
371                        .data
372                        .into_iter()
373                        .zip(documents.into_iter())
374                        .map(|(embedding, document)| embeddings::Embedding {
375                            document,
376                            vec: embedding.embedding,
377                        })
378                        .collect())
379                }
380                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
381            }
382        } else {
383            Err(EmbeddingError::ProviderError(response.text().await?))
384        }
385    }
386}
387
388impl EmbeddingModel {
389    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
390        Self {
391            client,
392            model: model.to_string(),
393            ndims,
394        }
395    }
396}
397
// ================================================================
// Azure OpenAI Completion API
// ================================================================
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
// NOTE(review): value is "gpt-4", not "gpt-4-turbo" — confirm this is the
// intended Azure model name before relying on it.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
// NOTE(review): value duplicates GPT_4_32K ("gpt-4-32k"); was "gpt-4-32k-0613"
// intended? Verify before changing — the value is part of the public API.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
427
/// Chat-completion model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-4o-mini); also used as the deployment id
    /// in the request URL.
    pub model: String,
}
434
435impl CompletionModel {
436    pub fn new(client: Client, model: &str) -> Self {
437        Self {
438            client,
439            model: model.to_string(),
440        }
441    }
442
443    fn create_completion_request(
444        &self,
445        completion_request: CompletionRequest,
446    ) -> Result<serde_json::Value, CompletionError> {
447        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
448            Some(preamble) => vec![openai::Message::system(preamble)],
449            None => vec![],
450        };
451        if let Some(docs) = completion_request.normalized_documents() {
452            let docs: Vec<openai::Message> = docs.try_into()?;
453            full_history.extend(docs);
454        }
455        let chat_history: Vec<openai::Message> = completion_request
456            .chat_history
457            .into_iter()
458            .map(|message| message.try_into())
459            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
460            .into_iter()
461            .flatten()
462            .collect();
463
464        full_history.extend(chat_history);
465
466        let request = if completion_request.tools.is_empty() {
467            json!({
468                "model": self.model,
469                "messages": full_history,
470                "temperature": completion_request.temperature,
471            })
472        } else {
473            json!({
474                "model": self.model,
475                "messages": full_history,
476                "temperature": completion_request.temperature,
477                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
478                "tool_choice": "auto",
479            })
480        };
481
482        let request = if let Some(params) = completion_request.additional_params {
483            json_utils::merge(request, params)
484        } else {
485            request
486        };
487
488        Ok(request)
489    }
490}
491
492impl completion::CompletionModel for CompletionModel {
493    type Response = openai::CompletionResponse;
494    type StreamingResponse = openai::StreamingCompletionResponse;
495
496    #[cfg_attr(feature = "worker", worker::send)]
497    async fn completion(
498        &self,
499        completion_request: CompletionRequest,
500    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
501        let request = self.create_completion_request(completion_request)?;
502
503        let response = self
504            .client
505            .post_chat_completion(&self.model)
506            .json(&request)
507            .send()
508            .await?;
509
510        if response.status().is_success() {
511            let t = response.text().await?;
512            tracing::debug!(target: "rig", "Azure completion error: {}", t);
513
514            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
515                ApiResponse::Ok(response) => {
516                    tracing::info!(target: "rig",
517                        "Azure completion token usage: {:?}",
518                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
519                    );
520                    response.try_into()
521                }
522                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
523            }
524        } else {
525            Err(CompletionError::ProviderError(response.text().await?))
526        }
527    }
528
529    #[cfg_attr(feature = "worker", worker::send)]
530    async fn stream(
531        &self,
532        request: CompletionRequest,
533    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
534        let mut request = self.create_completion_request(request)?;
535
536        request = merge(
537            request,
538            json!({"stream": true, "stream_options": {"include_usage": true}}),
539        );
540
541        let builder = self
542            .client
543            .post_chat_completion(self.model.as_str())
544            .json(&request);
545
546        send_compatible_streaming_request(builder).await
547    }
548}
549
550// ================================================================
551// Azure OpenAI Transcription API
552// ================================================================
553
/// Speech-to-text model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the transcription model/deployment (e.g. a whisper deployment);
    /// also used as the deployment id in the request URL.
    pub model: String,
}
560
561impl TranscriptionModel {
562    pub fn new(client: Client, model: &str) -> Self {
563        Self {
564            client,
565            model: model.to_string(),
566        }
567    }
568}
569
570impl transcription::TranscriptionModel for TranscriptionModel {
571    type Response = TranscriptionResponse;
572
573    #[cfg_attr(feature = "worker", worker::send)]
574    async fn transcription(
575        &self,
576        request: transcription::TranscriptionRequest,
577    ) -> Result<
578        transcription::TranscriptionResponse<Self::Response>,
579        transcription::TranscriptionError,
580    > {
581        let data = request.data;
582
583        let mut body = reqwest::multipart::Form::new().part(
584            "file",
585            Part::bytes(data).file_name(request.filename.clone()),
586        );
587
588        if let Some(prompt) = request.prompt {
589            body = body.text("prompt", prompt.clone());
590        }
591
592        if let Some(ref temperature) = request.temperature {
593            body = body.text("temperature", temperature.to_string());
594        }
595
596        if let Some(ref additional_params) = request.additional_params {
597            for (key, value) in additional_params
598                .as_object()
599                .expect("Additional Parameters to OpenAI Transcription should be a map")
600            {
601                body = body.text(key.to_owned(), value.to_string());
602            }
603        }
604
605        let response = self
606            .client
607            .post_transcription(&self.model)
608            .multipart(body)
609            .send()
610            .await?;
611
612        if response.status().is_success() {
613            match response
614                .json::<ApiResponse<TranscriptionResponse>>()
615                .await?
616            {
617                ApiResponse::Ok(response) => response.try_into(),
618                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
619                    api_error_response.message,
620                )),
621            }
622        } else {
623            Err(TranscriptionError::ProviderError(response.text().await?))
624        }
625    }
626}
627
628// ================================================================
629// Azure OpenAI Image Generation API
630// ================================================================
#[cfg(feature = "image")]
pub use image_generation::*;
#[cfg(feature = "image")]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use serde_json::json;

    /// Image generation model bound to an Azure OpenAI deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        // Model name; used both in the request URL (as deployment id) and in
        // the request body.
        pub model: String,
    }
    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Requests a base64-encoded image for the given prompt and size.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // b64_json makes the API return the image inline rather than as a URL.
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            // Non-2xx: surface status and raw body as a provider error.
            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let t = response.text().await?;

            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Create an image generation model with the given deployment name.
        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
697// ================================================================
698// Azure OpenAI Audio Generation API
699// ================================================================
700
701use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
702#[cfg(feature = "audio")]
703pub use audio_generation::*;
704
#[cfg(feature = "audio")]
mod audio_generation {
    use super::Client;
    use crate::audio_generation;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use bytes::Bytes;
    use serde_json::json;

    /// Text-to-speech model bound to an Azure OpenAI deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = Bytes;

        /// Synthesizes speech for the request text and returns the raw audio bytes.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            // BUG FIX: previously the literal "/audio/speech" was passed as the
            // deployment id, producing a malformed URL like
            // `.../deployments//audio/speech/audio/speech?...`. The deployment
            // id must be the model name.
            let response = self
                .client
                .post_audio_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let bytes = response.bytes().await?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Create an audio generation (TTS) model with the given deployment name.
        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
771
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;
    use crate::OneOrMany;

    // Live-API smoke test: #[ignore]d by default; requires the AZURE_*
    // environment variables consumed by `Client::from_env`.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    // Live-API smoke test for non-streaming completion; same env requirements.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}