// rig/providers/azure.rs

//! Azure OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::azure;
//!
//! let client = azure::Client::new("YOUR_API_KEY".to_string(), "YOUR_ENDPOINT");
//!
//! let gpt4o = client.completion_model(azure::GPT_4O);
//! ```
11
12use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
13
14use crate::json_utils::merge;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    client::ClientBuilderError,
18    completion::{self, CompletionError, CompletionRequest},
19    embeddings::{self, EmbeddingError},
20    json_utils,
21    providers::openai,
22    transcription::{self, TranscriptionError},
23};
24use reqwest::header::AUTHORIZATION;
25use reqwest::multipart::Part;
26use serde::Deserialize;
27use serde_json::json;
28// ================================================================
29// Main Azure OpenAI Client
30// ================================================================
31
/// Default Azure OpenAI API version used when the builder is not given one
/// explicitly (the 2024-10-21 GA release).
const DEFAULT_API_VERSION: &str = "2024-10-21";

/// Builder for [`Client`]; configure the API version and HTTP client before calling `build`.
pub struct ClientBuilder<'a> {
    // Credential attached to every request (API key or bearer token).
    auth: AzureOpenAIAuth,
    // Optional override; `build` falls back to DEFAULT_API_VERSION.
    api_version: Option<&'a str>,
    // e.g. https://{your-resource-name}.openai.azure.com
    azure_endpoint: &'a str,
    // Custom reqwest client; a default one is constructed when absent.
    http_client: Option<reqwest::Client>,
}
40
41impl<'a> ClientBuilder<'a> {
42    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &'a str) -> Self {
43        Self {
44            auth: auth.into(),
45            api_version: None,
46            azure_endpoint: endpoint,
47            http_client: None,
48        }
49    }
50
51    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
52    pub fn api_version(mut self, api_version: &'a str) -> Self {
53        self.api_version = Some(api_version);
54        self
55    }
56
57    /// Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
58    pub fn azure_endpoint(mut self, azure_endpoint: &'a str) -> Self {
59        self.azure_endpoint = azure_endpoint;
60        self
61    }
62
63    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
64        self.http_client = Some(client);
65        self
66    }
67
68    pub fn build(self) -> Result<Client, ClientBuilderError> {
69        let http_client = if let Some(http_client) = self.http_client {
70            http_client
71        } else {
72            reqwest::Client::builder().build()?
73        };
74
75        let api_version = self.api_version.unwrap_or(DEFAULT_API_VERSION);
76
77        Ok(Client {
78            api_version: api_version.to_string(),
79            azure_endpoint: self.azure_endpoint.to_string(),
80            auth: self.auth,
81            http_client,
82        })
83    }
84}
85
/// Azure OpenAI API client; cloning is cheap (`reqwest::Client` is reference-counted).
#[derive(Clone)]
pub struct Client {
    api_version: String,
    azure_endpoint: String,
    auth: AzureOpenAIAuth,
    http_client: reqwest::Client,
}
93
94impl std::fmt::Debug for Client {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_struct("Client")
97            .field("azure_endpoint", &self.azure_endpoint)
98            .field("http_client", &self.http_client)
99            .field("auth", &"<REDACTED>")
100            .field("api_version", &self.api_version)
101            .finish()
102    }
103}
104
/// Credential for Azure OpenAI: either an API key or a bearer token.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Sent as the `api-key` request header.
    ApiKey(String),
    /// Sent as `Authorization: Bearer <token>`.
    Token(String),
}
110
111impl std::fmt::Debug for AzureOpenAIAuth {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        match self {
114            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
115            Self::Token(_) => write!(f, "Token <REDACTED>"),
116        }
117    }
118}
119
impl From<String> for AzureOpenAIAuth {
    /// Converts a bare string into a credential.
    ///
    /// NOTE(review): a plain `String` is treated as a bearer *token* (sent as
    /// `Authorization: Bearer ...`), not as an API key, even though the crate
    /// examples pass "YOUR_API_KEY" through this path — confirm this default
    /// is intentional.
    fn from(token: String) -> Self {
        AzureOpenAIAuth::Token(token)
    }
}
125
126impl AzureOpenAIAuth {
127    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
128        match self {
129            AzureOpenAIAuth::ApiKey(api_key) => (
130                "api-key".parse().expect("Header value should parse"),
131                api_key.parse().expect("API key should parse"),
132            ),
133            AzureOpenAIAuth::Token(token) => (
134                AUTHORIZATION,
135                format!("Bearer {token}")
136                    .parse()
137                    .expect("Token should parse"),
138            ),
139        }
140    }
141}
142
143impl Client {
144    /// Create a new Azure OpenAI client builder.
145    ///
146    /// # Example
147    /// ```
148    /// use rig::providers::azure::{ClientBuilder, self};
149    ///
150    /// // Initialize the Azure OpenAI client
151    /// let azure = Client::builder("your-azure-api-key", "https://{your-resource-name}.openai.azure.com")
152    ///    .build()
153    /// ```
154    pub fn builder(auth: impl Into<AzureOpenAIAuth>, endpoint: &str) -> ClientBuilder<'_> {
155        ClientBuilder::new(auth, endpoint)
156    }
157
158    /// Creates a new Azure OpenAI client. For more control, use the `builder` method.
159    ///
160    /// # Panics
161    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
162    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &str) -> Self {
163        Self::builder(auth, endpoint)
164            .build()
165            .expect("Azure OpenAI client should build")
166    }
167
168    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
169        let url = format!(
170            "{}/openai/deployments/{}/embeddings?api-version={}",
171            self.azure_endpoint, deployment_id, self.api_version
172        )
173        .replace("//", "/");
174
175        let (key, value) = self.auth.as_header();
176        self.http_client.post(url).header(key, value)
177    }
178
179    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
180        let url = format!(
181            "{}/openai/deployments/{}/chat/completions?api-version={}",
182            self.azure_endpoint, deployment_id, self.api_version
183        )
184        .replace("//", "/");
185        let (key, value) = self.auth.as_header();
186        self.http_client.post(url).header(key, value)
187    }
188
189    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
190        let url = format!(
191            "{}/openai/deployments/{}/audio/translations?api-version={}",
192            self.azure_endpoint, deployment_id, self.api_version
193        )
194        .replace("//", "/");
195        let (key, value) = self.auth.as_header();
196        self.http_client.post(url).header(key, value)
197    }
198
199    #[cfg(feature = "image")]
200    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
201        let url = format!(
202            "{}/openai/deployments/{}/images/generations?api-version={}",
203            self.azure_endpoint, deployment_id, self.api_version
204        )
205        .replace("//", "/");
206        let (key, value) = self.auth.as_header();
207        self.http_client.post(url).header(key, value)
208    }
209
210    #[cfg(feature = "audio")]
211    fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
212        let url = format!(
213            "{}/openai/deployments/{}/audio/speech?api-version={}",
214            self.azure_endpoint, deployment_id, self.api_version
215        )
216        .replace("//", "/");
217        let (key, value) = self.auth.as_header();
218        self.http_client.post(url).header(key, value)
219    }
220}
221
222impl ProviderClient for Client {
223    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
224    fn from_env() -> Self {
225        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
226            AzureOpenAIAuth::ApiKey(api_key)
227        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
228            AzureOpenAIAuth::Token(token)
229        } else {
230            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
231        };
232
233        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
234        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
235
236        Self::builder(auth, &azure_endpoint)
237            .api_version(&api_version)
238            .build()
239            .expect("Azure OpenAI client should build")
240    }
241
242    fn from_val(input: crate::client::ProviderValue) -> Self {
243        let crate::client::ProviderValue::ApiKeyWithVersionAndHeader(api_key, version, header) =
244            input
245        else {
246            panic!("Incorrect provider value type")
247        };
248        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
249        Self::builder(auth, &header)
250            .api_version(&version)
251            .build()
252            .expect("Azure OpenAI client should build")
253    }
254}
255
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY".to_string(), "YOUR_ENDPOINT");
    ///
    /// let gpt4 = azure.completion_model(azure::GPT_4);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
274
impl EmbeddingsClient for Client {
    type EmbeddingModel = EmbeddingModel;

    /// Create an embedding model with the given name.
    /// Note: default embedding dimension of 0 will be used if model is not known.
    /// If this is the case, it's better to use function `embedding_model_with_ndims`
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY".to_string(), "YOUR_ENDPOINT");
    ///
    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
    /// ```
    fn embedding_model(&self, model: &str) -> EmbeddingModel {
        // Known first-party models get their documented dimension count;
        // anything else falls back to 0 (see note above).
        let ndims = match model {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY".to_string(), "YOUR_ENDPOINT");
    ///
    /// let embedding_model = azure.embedding_model_with_ndims("model-unknown-to-rig", 3072);
    /// ```
    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
}
315
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY".to_string(), "YOUR_ENDPOINT");
    ///
    /// let whisper = azure.transcription_model("model-unknown-to-rig");
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
334
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
339
/// Untagged response envelope: deserialization first tries the success
/// payload `T`, then falls back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
346
// ================================================================
// Azure OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
356
/// Successful response body of the embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
364
impl From<ApiErrorResponse> for EmbeddingError {
    /// Maps an Azure error payload onto the provider-error variant.
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}
370
371impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
372    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
373        match value {
374            ApiResponse::Ok(response) => Ok(response),
375            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
376        }
377    }
378}
379
/// A single embedding vector plus its position within the request batch.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
386
/// Token-usage accounting reported by the embeddings endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
392
393impl std::fmt::Display for Usage {
394    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395        write!(
396            f,
397            "Prompt tokens: {} Total tokens: {}",
398            self.prompt_tokens, self.total_tokens
399        )
400    }
401}
402
/// Azure OpenAI embedding model bound to a deployment and dimension count.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    // Number of dimensions the model produces (0 when unknown to rig).
    ndims: usize,
}
409
410impl embeddings::EmbeddingModel for EmbeddingModel {
411    const MAX_DOCUMENTS: usize = 1024;
412
413    fn ndims(&self) -> usize {
414        self.ndims
415    }
416
417    #[cfg_attr(feature = "worker", worker::send)]
418    async fn embed_texts(
419        &self,
420        documents: impl IntoIterator<Item = String>,
421    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
422        let documents = documents.into_iter().collect::<Vec<_>>();
423
424        let response = self
425            .client
426            .post_embedding(&self.model)
427            .json(&json!({
428                "input": documents,
429            }))
430            .send()
431            .await?;
432
433        if response.status().is_success() {
434            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
435                ApiResponse::Ok(response) => {
436                    tracing::info!(target: "rig",
437                        "Azure embedding token usage: {}",
438                        response.usage
439                    );
440
441                    if response.data.len() != documents.len() {
442                        return Err(EmbeddingError::ResponseError(
443                            "Response data length does not match input length".into(),
444                        ));
445                    }
446
447                    Ok(response
448                        .data
449                        .into_iter()
450                        .zip(documents.into_iter())
451                        .map(|(embedding, document)| embeddings::Embedding {
452                            document,
453                            vec: embedding.embedding,
454                        })
455                        .collect())
456                }
457                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
458            }
459        } else {
460            Err(EmbeddingError::ProviderError(response.text().await?))
461        }
462    }
463}
464
impl EmbeddingModel {
    /// Binds `client` to the `model` deployment with a fixed dimension count.
    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_string(),
            ndims,
        }
    }
}
474
// ================================================================
// Azure OpenAI Completion API
// ================================================================
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
// NOTE(review): value is "gpt-4", not "gpt-4-turbo" — looks like a
// copy-paste slip; confirm the intended deployment name before changing.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model
// NOTE(review): identical value to GPT_4_32K; the `-0613` suffix is
// missing from the value — confirm intent.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
504
/// Azure OpenAI chat-completion model bound to a deployment.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}
511
512impl CompletionModel {
513    pub fn new(client: Client, model: &str) -> Self {
514        Self {
515            client,
516            model: model.to_string(),
517        }
518    }
519
520    fn create_completion_request(
521        &self,
522        completion_request: CompletionRequest,
523    ) -> Result<serde_json::Value, CompletionError> {
524        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
525            Some(preamble) => vec![openai::Message::system(preamble)],
526            None => vec![],
527        };
528        if let Some(docs) = completion_request.normalized_documents() {
529            let docs: Vec<openai::Message> = docs.try_into()?;
530            full_history.extend(docs);
531        }
532        let chat_history: Vec<openai::Message> = completion_request
533            .chat_history
534            .into_iter()
535            .map(|message| message.try_into())
536            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
537            .into_iter()
538            .flatten()
539            .collect();
540
541        full_history.extend(chat_history);
542
543        let request = if completion_request.tools.is_empty() {
544            json!({
545                "model": self.model,
546                "messages": full_history,
547                "temperature": completion_request.temperature,
548            })
549        } else {
550            json!({
551                "model": self.model,
552                "messages": full_history,
553                "temperature": completion_request.temperature,
554                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
555                "tool_choice": "auto",
556            })
557        };
558
559        let request = if let Some(params) = completion_request.additional_params {
560            json_utils::merge(request, params)
561        } else {
562            request
563        };
564
565        Ok(request)
566    }
567}
568
569impl completion::CompletionModel for CompletionModel {
570    type Response = openai::CompletionResponse;
571    type StreamingResponse = openai::StreamingCompletionResponse;
572
573    #[cfg_attr(feature = "worker", worker::send)]
574    async fn completion(
575        &self,
576        completion_request: CompletionRequest,
577    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
578        let request = self.create_completion_request(completion_request)?;
579
580        let response = self
581            .client
582            .post_chat_completion(&self.model)
583            .json(&request)
584            .send()
585            .await?;
586
587        if response.status().is_success() {
588            let t = response.text().await?;
589            tracing::debug!(target: "rig", "Azure completion error: {}", t);
590
591            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
592                ApiResponse::Ok(response) => {
593                    tracing::info!(target: "rig",
594                        "Azure completion token usage: {:?}",
595                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
596                    );
597                    response.try_into()
598                }
599                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
600            }
601        } else {
602            Err(CompletionError::ProviderError(response.text().await?))
603        }
604    }
605
606    #[cfg_attr(feature = "worker", worker::send)]
607    async fn stream(
608        &self,
609        request: CompletionRequest,
610    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
611        let mut request = self.create_completion_request(request)?;
612
613        request = merge(
614            request,
615            json!({"stream": true, "stream_options": {"include_usage": true}}),
616        );
617
618        let builder = self
619            .client
620            .post_chat_completion(self.model.as_str())
621            .json(&request);
622
623        send_compatible_streaming_request(builder).await
624    }
625}
626
627// ================================================================
628// Azure OpenAI Transcription API
629// ================================================================
630
/// Azure OpenAI transcription (speech-to-text) model bound to a deployment.
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}
637
impl TranscriptionModel {
    /// Binds `client` to the `model` deployment.
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}
646
647impl transcription::TranscriptionModel for TranscriptionModel {
648    type Response = TranscriptionResponse;
649
650    #[cfg_attr(feature = "worker", worker::send)]
651    async fn transcription(
652        &self,
653        request: transcription::TranscriptionRequest,
654    ) -> Result<
655        transcription::TranscriptionResponse<Self::Response>,
656        transcription::TranscriptionError,
657    > {
658        let data = request.data;
659
660        let mut body = reqwest::multipart::Form::new().part(
661            "file",
662            Part::bytes(data).file_name(request.filename.clone()),
663        );
664
665        if let Some(prompt) = request.prompt {
666            body = body.text("prompt", prompt.clone());
667        }
668
669        if let Some(ref temperature) = request.temperature {
670            body = body.text("temperature", temperature.to_string());
671        }
672
673        if let Some(ref additional_params) = request.additional_params {
674            for (key, value) in additional_params
675                .as_object()
676                .expect("Additional Parameters to OpenAI Transcription should be a map")
677            {
678                body = body.text(key.to_owned(), value.to_string());
679            }
680        }
681
682        let response = self
683            .client
684            .post_transcription(&self.model)
685            .multipart(body)
686            .send()
687            .await?;
688
689        if response.status().is_success() {
690            match response
691                .json::<ApiResponse<TranscriptionResponse>>()
692                .await?
693            {
694                ApiResponse::Ok(response) => response.try_into(),
695                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
696                    api_error_response.message,
697                )),
698            }
699        } else {
700            Err(TranscriptionError::ProviderError(response.text().await?))
701        }
702    }
703}
704
705// ================================================================
706// Azure OpenAI Image Generation API
707// ================================================================
#[cfg(feature = "image")]
pub use image_generation::*;
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use serde_json::json;

    /// Azure OpenAI image-generation model bound to a deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        pub model: String,
    }
    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Requests a base64-encoded image of the requested dimensions.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let size = format!("{}x{}", generation_request.width, generation_request.height);
            let body = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": size,
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .json(&body)
                .send()
                .await?;

            let status = response.status();
            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    status,
                    response.text().await?
                )));
            }

            let raw = response.text().await?;
            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&raw)? {
                ApiResponse::Ok(parsed) => parsed.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Creates an image-generation model bound to the given deployment name.
        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
776// ================================================================
777// Azure OpenAI Audio Generation API
778// ================================================================
779
780use crate::client::{
781    CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
782    VerifyError,
783};
#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use bytes::Bytes;
    use serde_json::json;

    /// Azure OpenAI text-to-speech model bound to a deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = Bytes;

        /// Synthesizes speech for the request text, returning the raw audio bytes.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let body = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            // BUG FIX: previously the literal "/audio/speech" was passed as the
            // deployment id, producing ".../deployments/audio/speech/audio/speech".
            // The deployment id must be the model/deployment name.
            let response = self
                .client
                .post_audio_generation(&self.model)
                .json(&body)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let bytes = response.bytes().await?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Creates an audio-generation model bound to the given deployment name.
        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
855
impl VerifyClient for Client {
    /// Always succeeds: Azure OpenAI exposes no endpoint that validates a
    /// credential without consuming tokens, so verification is a no-op.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        // There is currently no way to verify the Azure OpenAI API key or token without
        // consuming tokens
        Ok(())
    }
}
864
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    // Both tests are #[ignore]d: they call the live Azure API and require the
    // AZURE_* environment variables read by ProviderClient::from_env.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}