// rig/providers/azure.rs

1//! Azure OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::azure;
6//!
7//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
8//!
9//! let gpt4o = client.completion_model(azure::GPT_4O);
10//! ```
11
12use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
13
14use crate::json_utils::merge;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    completion::{self, CompletionError, CompletionRequest},
18    embeddings::{self, EmbeddingError},
19    json_utils,
20    providers::openai,
21    transcription::{self, TranscriptionError},
22};
23use reqwest::header::AUTHORIZATION;
24use reqwest::multipart::Part;
25use serde::Deserialize;
26use serde_json::json;
27// ================================================================
28// Main Azure OpenAI Client
29// ================================================================
30
/// Azure OpenAI API client.
///
/// Holds the endpoint, API version, authentication material, and the
/// underlying HTTP client used by all request builders below.
#[derive(Clone)]
pub struct Client {
    // Sent as the `api-version` query parameter (e.g. "2024-10-21").
    api_version: String,
    // Base endpoint, e.g. https://{your-resource-name}.openai.azure.com
    azure_endpoint: String,
    // API key or bearer token; redacted by the manual `Debug` impl below.
    auth: AzureOpenAIAuth,
    http_client: reqwest::Client,
}
38
// Manual `Debug` implementation so the credential is never printed in logs.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("http_client", &self.http_client)
            // Secret material is intentionally replaced with a placeholder.
            .field("auth", &"<REDACTED>")
            .field("api_version", &self.api_version)
            .finish()
    }
}
49
/// Credential accepted by the Azure OpenAI API; see `as_header` for how each
/// variant is transmitted.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Static API key, sent via the `api-key` header.
    ApiKey(String),
    /// Bearer token, sent via the `Authorization` header.
    Token(String),
}
55
56impl std::fmt::Debug for AzureOpenAIAuth {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        match self {
59            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
60            Self::Token(_) => write!(f, "Token <REDACTED>"),
61        }
62    }
63}
64
65impl From<String> for AzureOpenAIAuth {
66    fn from(token: String) -> Self {
67        AzureOpenAIAuth::Token(token)
68    }
69}
70
71impl AzureOpenAIAuth {
72    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
73        match self {
74            AzureOpenAIAuth::ApiKey(api_key) => (
75                "api-key".parse().expect("Header value should parse"),
76                api_key.parse().expect("API key should parse"),
77            ),
78            AzureOpenAIAuth::Token(token) => (
79                AUTHORIZATION,
80                format!("Bearer {token}")
81                    .parse()
82                    .expect("Token should parse"),
83            ),
84        }
85    }
86}
87
88impl Client {
89    /// Creates a new Azure OpenAI client.
90    ///
91    /// # Arguments
92    ///
93    /// * `auth` - Azure OpenAI API key or token required for authentication
94    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
95    /// * `azure_endpoint` - Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
96    pub fn new(auth: impl Into<AzureOpenAIAuth>, api_version: &str, azure_endpoint: &str) -> Self {
97        Self {
98            api_version: api_version.to_string(),
99            auth: auth.into(),
100            azure_endpoint: azure_endpoint.to_string(),
101            http_client: reqwest::Client::builder()
102                .build()
103                .expect("Azure OpenAI reqwest client should build"),
104        }
105    }
106
107    /// Use your own `reqwest::Client`.
108    /// The required headers will be automatically attached upon trying to make a request.
109    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
110        self.http_client = client;
111
112        self
113    }
114
115    /// Creates a new Azure OpenAI client from an API key.
116    ///
117    /// # Arguments
118    ///
119    /// * `api_key` - Azure OpenAI API key required for authentication
120    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
121    /// * `azure_endpoint` - Azure OpenAI endpoint URL
122    pub fn from_api_key(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
123        Self::new(
124            AzureOpenAIAuth::ApiKey(api_key.to_string()),
125            api_version,
126            azure_endpoint,
127        )
128    }
129
130    /// Creates a new Azure OpenAI client from a token.
131    ///
132    /// # Arguments
133    ///
134    /// * `token` - Azure OpenAI token required for authentication
135    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
136    /// * `azure_endpoint` - Azure OpenAI endpoint URL
137    pub fn from_token(token: &str, api_version: &str, azure_endpoint: &str) -> Self {
138        Self::new(
139            AzureOpenAIAuth::Token(token.to_string()),
140            api_version,
141            azure_endpoint,
142        )
143    }
144
145    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
146        let url = format!(
147            "{}/openai/deployments/{}/embeddings?api-version={}",
148            self.azure_endpoint, deployment_id, self.api_version
149        )
150        .replace("//", "/");
151
152        let (key, value) = self.auth.as_header();
153        self.http_client.post(url).header(key, value)
154    }
155
156    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
157        let url = format!(
158            "{}/openai/deployments/{}/chat/completions?api-version={}",
159            self.azure_endpoint, deployment_id, self.api_version
160        )
161        .replace("//", "/");
162        let (key, value) = self.auth.as_header();
163        self.http_client.post(url).header(key, value)
164    }
165
166    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
167        let url = format!(
168            "{}/openai/deployments/{}/audio/translations?api-version={}",
169            self.azure_endpoint, deployment_id, self.api_version
170        )
171        .replace("//", "/");
172        let (key, value) = self.auth.as_header();
173        self.http_client.post(url).header(key, value)
174    }
175
176    #[cfg(feature = "image")]
177    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
178        let url = format!(
179            "{}/openai/deployments/{}/images/generations?api-version={}",
180            self.azure_endpoint, deployment_id, self.api_version
181        )
182        .replace("//", "/");
183        let (key, value) = self.auth.as_header();
184        self.http_client.post(url).header(key, value)
185    }
186
187    #[cfg(feature = "audio")]
188    fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
189        let url = format!(
190            "{}/openai/deployments/{}/audio/speech?api-version={}",
191            self.azure_endpoint, deployment_id, self.api_version
192        )
193        .replace("//", "/");
194        let (key, value) = self.auth.as_header();
195        self.http_client.post(url).header(key, value)
196    }
197}
198
199impl ProviderClient for Client {
200    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
201    fn from_env() -> Self {
202        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
203            AzureOpenAIAuth::ApiKey(api_key)
204        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
205            AzureOpenAIAuth::Token(token)
206        } else {
207            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
208        };
209
210        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
211        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
212
213        Self::new(auth, &api_version, &azure_endpoint)
214    }
215
216    fn from_val(input: crate::client::ProviderValue) -> Self {
217        let crate::client::ProviderValue::ApiKeyWithVersionAndHeader(api_key, version, header) =
218            input
219        else {
220            panic!("Incorrect provider value type")
221        };
222        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
223        Self::new(auth, &version, &header)
224    }
225}
226
227impl CompletionClient for Client {
228    type CompletionModel = CompletionModel;
229
230    /// Create a completion model with the given name.
231    ///
232    /// # Example
233    /// ```
234    /// use rig::providers::azure::{Client, self};
235    ///
236    /// // Initialize the Azure OpenAI client
237    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
238    ///
239    /// let gpt4 = azure.completion_model(azure::GPT_4);
240    /// ```
241    fn completion_model(&self, model: &str) -> CompletionModel {
242        CompletionModel::new(self.clone(), model)
243    }
244}
245
246impl EmbeddingsClient for Client {
247    type EmbeddingModel = EmbeddingModel;
248
249    /// Create an embedding model with the given name.
250    /// Note: default embedding dimension of 0 will be used if model is not known.
251    /// If this is the case, it's better to use function `embedding_model_with_ndims`
252    ///
253    /// # Example
254    /// ```
255    /// use rig::providers::azure::{Client, self};
256    ///
257    /// // Initialize the Azure OpenAI client
258    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
259    ///
260    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
261    /// ```
262    fn embedding_model(&self, model: &str) -> EmbeddingModel {
263        let ndims = match model {
264            TEXT_EMBEDDING_3_LARGE => 3072,
265            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
266            _ => 0,
267        };
268        EmbeddingModel::new(self.clone(), model, ndims)
269    }
270
271    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
272    ///
273    /// # Example
274    /// ```
275    /// use rig::providers::azure::{Client, self};
276    ///
277    /// // Initialize the Azure OpenAI client
278    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
279    ///
280    /// let embedding_model = azure.embedding_model("model-unknown-to-rig", 3072);
281    /// ```
282    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
283        EmbeddingModel::new(self.clone(), model, ndims)
284    }
285}
286
287impl TranscriptionClient for Client {
288    type TranscriptionModel = TranscriptionModel;
289
290    /// Create a transcription model with the given name.
291    ///
292    /// # Example
293    /// ```
294    /// use rig::providers::azure::{Client, self};
295    ///
296    /// // Initialize the Azure OpenAI client
297    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
298    ///
299    /// let whisper = azure.transcription_model("model-unknown-to-rig");
300    /// ```
301    fn transcription_model(&self, model: &str) -> TranscriptionModel {
302        TranscriptionModel::new(self.clone(), model)
303    }
304}
305
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the provider.
    message: String,
}
310
/// Untagged wrapper: a response body deserializes either as the expected
/// payload `T` or, failing that, as an API error.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
317
318// ================================================================
319// Azure OpenAI Embedding API
320// ================================================================
// Default Azure deployment names for the OpenAI embedding models.
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
327
/// Response body of the Azure OpenAI embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    /// One entry per input document, in request order.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
335
336impl From<ApiErrorResponse> for EmbeddingError {
337    fn from(err: ApiErrorResponse) -> Self {
338        EmbeddingError::ProviderError(err.message)
339    }
340}
341
342impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
343    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
344        match value {
345            ApiResponse::Ok(response) => Ok(response),
346            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
347        }
348    }
349}
350
/// A single embedding entry within an `EmbeddingResponse`.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    /// The embedding vector itself.
    pub embedding: Vec<f64>,
    /// Position of the corresponding input document.
    pub index: usize,
}
357
/// Token accounting reported by the embeddings endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
363
364impl std::fmt::Display for Usage {
365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        write!(
367            f,
368            "Prompt tokens: {} Total tokens: {}",
369            self.prompt_tokens, self.total_tokens
370        )
371    }
372}
373
/// An Azure OpenAI embedding model bound to a specific deployment.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Deployment/model name (e.g. text-embedding-3-small).
    pub model: String,
    // Number of dimensions the model produces; 0 when unknown to rig.
    ndims: usize,
}
380
381impl embeddings::EmbeddingModel for EmbeddingModel {
382    const MAX_DOCUMENTS: usize = 1024;
383
384    fn ndims(&self) -> usize {
385        self.ndims
386    }
387
388    #[cfg_attr(feature = "worker", worker::send)]
389    async fn embed_texts(
390        &self,
391        documents: impl IntoIterator<Item = String>,
392    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
393        let documents = documents.into_iter().collect::<Vec<_>>();
394
395        let response = self
396            .client
397            .post_embedding(&self.model)
398            .json(&json!({
399                "input": documents,
400            }))
401            .send()
402            .await?;
403
404        if response.status().is_success() {
405            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
406                ApiResponse::Ok(response) => {
407                    tracing::info!(target: "rig",
408                        "Azure embedding token usage: {}",
409                        response.usage
410                    );
411
412                    if response.data.len() != documents.len() {
413                        return Err(EmbeddingError::ResponseError(
414                            "Response data length does not match input length".into(),
415                        ));
416                    }
417
418                    Ok(response
419                        .data
420                        .into_iter()
421                        .zip(documents.into_iter())
422                        .map(|(embedding, document)| embeddings::Embedding {
423                            document,
424                            vec: embedding.embedding,
425                        })
426                        .collect())
427                }
428                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
429            }
430        } else {
431            Err(EmbeddingError::ProviderError(response.text().await?))
432        }
433    }
434}
435
436impl EmbeddingModel {
437    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
438        Self {
439            client,
440            model: model.to_string(),
441            ndims,
442        }
443    }
444}
445
446// ================================================================
447// Azure OpenAI Completion API
448// ================================================================
// Default Azure deployment names for the OpenAI completion models.
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
// NOTE(review): resolves to "gpt-4" — Azure exposes GPT-4 Turbo under the
// `gpt-4` model name (turbo is a model *version*); confirm this is intentional.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model
// NOTE(review): same value as GPT_4_32K — the 0613 suffix is a model version
// on Azure, not part of the model name; confirm this is intentional.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
475
/// An Azure OpenAI chat-completion model bound to a specific deployment.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}
482
483impl CompletionModel {
484    pub fn new(client: Client, model: &str) -> Self {
485        Self {
486            client,
487            model: model.to_string(),
488        }
489    }
490
491    fn create_completion_request(
492        &self,
493        completion_request: CompletionRequest,
494    ) -> Result<serde_json::Value, CompletionError> {
495        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
496            Some(preamble) => vec![openai::Message::system(preamble)],
497            None => vec![],
498        };
499        if let Some(docs) = completion_request.normalized_documents() {
500            let docs: Vec<openai::Message> = docs.try_into()?;
501            full_history.extend(docs);
502        }
503        let chat_history: Vec<openai::Message> = completion_request
504            .chat_history
505            .into_iter()
506            .map(|message| message.try_into())
507            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
508            .into_iter()
509            .flatten()
510            .collect();
511
512        full_history.extend(chat_history);
513
514        let request = if completion_request.tools.is_empty() {
515            json!({
516                "model": self.model,
517                "messages": full_history,
518                "temperature": completion_request.temperature,
519            })
520        } else {
521            json!({
522                "model": self.model,
523                "messages": full_history,
524                "temperature": completion_request.temperature,
525                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
526                "tool_choice": "auto",
527            })
528        };
529
530        let request = if let Some(params) = completion_request.additional_params {
531            json_utils::merge(request, params)
532        } else {
533            request
534        };
535
536        Ok(request)
537    }
538}
539
540impl completion::CompletionModel for CompletionModel {
541    type Response = openai::CompletionResponse;
542    type StreamingResponse = openai::StreamingCompletionResponse;
543
544    #[cfg_attr(feature = "worker", worker::send)]
545    async fn completion(
546        &self,
547        completion_request: CompletionRequest,
548    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
549        let request = self.create_completion_request(completion_request)?;
550
551        let response = self
552            .client
553            .post_chat_completion(&self.model)
554            .json(&request)
555            .send()
556            .await?;
557
558        if response.status().is_success() {
559            let t = response.text().await?;
560            tracing::debug!(target: "rig", "Azure completion error: {}", t);
561
562            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
563                ApiResponse::Ok(response) => {
564                    tracing::info!(target: "rig",
565                        "Azure completion token usage: {:?}",
566                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
567                    );
568                    response.try_into()
569                }
570                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
571            }
572        } else {
573            Err(CompletionError::ProviderError(response.text().await?))
574        }
575    }
576
577    #[cfg_attr(feature = "worker", worker::send)]
578    async fn stream(
579        &self,
580        request: CompletionRequest,
581    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
582        let mut request = self.create_completion_request(request)?;
583
584        request = merge(
585            request,
586            json!({"stream": true, "stream_options": {"include_usage": true}}),
587        );
588
589        let builder = self
590            .client
591            .post_chat_completion(self.model.as_str())
592            .json(&request);
593
594        send_compatible_streaming_request(builder).await
595    }
596}
597
598// ================================================================
599// Azure OpenAI Transcription API
600// ================================================================
601
/// An Azure OpenAI audio transcription model bound to a specific deployment.
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the model (e.g.: whisper) — original comment said
    /// "gpt-3.5-turbo-1106", which looks copy-pasted from the chat model.
    pub model: String,
}
608
609impl TranscriptionModel {
610    pub fn new(client: Client, model: &str) -> Self {
611        Self {
612            client,
613            model: model.to_string(),
614        }
615    }
616}
617
618impl transcription::TranscriptionModel for TranscriptionModel {
619    type Response = TranscriptionResponse;
620
621    #[cfg_attr(feature = "worker", worker::send)]
622    async fn transcription(
623        &self,
624        request: transcription::TranscriptionRequest,
625    ) -> Result<
626        transcription::TranscriptionResponse<Self::Response>,
627        transcription::TranscriptionError,
628    > {
629        let data = request.data;
630
631        let mut body = reqwest::multipart::Form::new().part(
632            "file",
633            Part::bytes(data).file_name(request.filename.clone()),
634        );
635
636        if let Some(prompt) = request.prompt {
637            body = body.text("prompt", prompt.clone());
638        }
639
640        if let Some(ref temperature) = request.temperature {
641            body = body.text("temperature", temperature.to_string());
642        }
643
644        if let Some(ref additional_params) = request.additional_params {
645            for (key, value) in additional_params
646                .as_object()
647                .expect("Additional Parameters to OpenAI Transcription should be a map")
648            {
649                body = body.text(key.to_owned(), value.to_string());
650            }
651        }
652
653        let response = self
654            .client
655            .post_transcription(&self.model)
656            .multipart(body)
657            .send()
658            .await?;
659
660        if response.status().is_success() {
661            match response
662                .json::<ApiResponse<TranscriptionResponse>>()
663                .await?
664            {
665                ApiResponse::Ok(response) => response.try_into(),
666                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
667                    api_error_response.message,
668                )),
669            }
670        } else {
671            Err(TranscriptionError::ProviderError(response.text().await?))
672        }
673    }
674}
675
676// ================================================================
677// Azure OpenAI Image Generation API
678// ================================================================
#[cfg(feature = "image")]
pub use image_generation::*;
/// Image generation support, gated behind the `image` feature.
#[cfg(feature = "image")]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use serde_json::json;

    /// An Azure OpenAI image generation model bound to a specific deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        pub model: String,
    }
    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Requests a base64-encoded image for the given prompt and size.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // b64_json keeps the image inline in the response body rather
            // than returning a URL.
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let t = response.text().await?;

            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Create an image generation model handle for the given deployment name.
        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
745// ================================================================
746// Azure OpenAI Audio Generation API
747// ================================================================
748
749use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
750#[cfg(feature = "audio")]
751pub use audio_generation::*;
752
/// Audio (text-to-speech) support, gated behind the `audio` feature.
#[cfg(feature = "audio")]
mod audio_generation {
    use super::Client;
    use crate::audio_generation;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use bytes::Bytes;
    use serde_json::json;

    /// An Azure OpenAI text-to-speech model bound to a specific deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = Bytes;

        /// Synthesizes speech for the request text and returns the raw audio bytes.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            // Fixed: the literal "/audio/speech" was previously passed as the
            // deployment id, producing ".../deployments//audio/speech/audio/speech".
            // The deployment (model) name belongs here.
            let response = self
                .client
                .post_audio_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let bytes = response.bytes().await?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Create an audio generation model handle for the given deployment name.
        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
819
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    // Integration test against a live deployment; run manually with
    // credentials in the environment (see `Client::from_env`).
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let embedding_model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = embedding_model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    // Integration test against a live deployment; run manually with
    // credentials in the environment (see `Client::from_env`).
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let completion_model = client.completion_model(GPT_4O_MINI);
        let request = CompletionRequest {
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one("Hello!".into()),
            documents: vec![],
            max_tokens: Some(100),
            temperature: Some(0.0),
            tools: vec![],
            additional_params: None,
        };
        let completion = completion_model.completion(request).await.unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}
865}