// rig/providers/azure.rs

1//! Azure OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::azure;
6//!
7//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
8//!
9//! let gpt4o = client.completion_model(azure::GPT_4O);
10//! ```
11
12use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
13
14use crate::json_utils::merge;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17    completion::{self, CompletionError, CompletionRequest},
18    embeddings::{self, EmbeddingError},
19    json_utils,
20    providers::openai,
21    transcription::{self, TranscriptionError},
22};
23use reqwest::header::AUTHORIZATION;
24use reqwest::multipart::Part;
25use serde::Deserialize;
26use serde_json::json;
27// ================================================================
28// Main Azure OpenAI Client
29// ================================================================
30
/// Client for the Azure OpenAI service.
///
/// Holds the settings shared by all models created from it; cheap to clone.
#[derive(Clone)]
pub struct Client {
    // API version appended as a query parameter to every request URL
    // (e.g. "2024-10-21").
    api_version: String,
    // Base endpoint, e.g. https://{your-resource-name}.openai.azure.com
    azure_endpoint: String,
    // API key or bearer token attached to every request (see `as_header`).
    auth: AzureOpenAIAuth,
    // Underlying HTTP client; replaceable via `with_custom_client`.
    http_client: reqwest::Client,
}
38
39impl std::fmt::Debug for Client {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("Client")
42            .field("azure_endpoint", &self.azure_endpoint)
43            .field("http_client", &self.http_client)
44            .field("auth", &"<REDACTED>")
45            .field("api_version", &self.api_version)
46            .finish()
47    }
48}
49
/// Authentication method for Azure OpenAI.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Static API key, sent in the custom `api-key` header.
    ApiKey(String),
    /// OAuth/Entra token, sent as `Authorization: Bearer …`.
    Token(String),
}
55
56impl std::fmt::Debug for AzureOpenAIAuth {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        match self {
59            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
60            Self::Token(_) => write!(f, "Token <REDACTED>"),
61        }
62    }
63}
64
/// Converts a plain string into credentials.
///
/// Note: a bare `String` becomes a bearer *token*, not an API key; use
/// [`AzureOpenAIAuth::ApiKey`] (or `Client::from_api_key`) for API keys.
impl From<String> for AzureOpenAIAuth {
    fn from(token: String) -> Self {
        AzureOpenAIAuth::Token(token)
    }
}
70
71impl AzureOpenAIAuth {
72    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
73        match self {
74            AzureOpenAIAuth::ApiKey(api_key) => (
75                "api-key".parse().expect("Header value should parse"),
76                api_key.parse().expect("API key should parse"),
77            ),
78            AzureOpenAIAuth::Token(token) => (
79                AUTHORIZATION,
80                format!("Bearer {token}")
81                    .parse()
82                    .expect("Token should parse"),
83            ),
84        }
85    }
86}
87
88impl Client {
89    /// Creates a new Azure OpenAI client.
90    ///
91    /// # Arguments
92    ///
93    /// * `auth` - Azure OpenAI API key or token required for authentication
94    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
95    /// * `azure_endpoint` - Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
96    pub fn new(auth: impl Into<AzureOpenAIAuth>, api_version: &str, azure_endpoint: &str) -> Self {
97        Self {
98            api_version: api_version.to_string(),
99            auth: auth.into(),
100            azure_endpoint: azure_endpoint.to_string(),
101            http_client: reqwest::Client::builder()
102                .build()
103                .expect("Azure OpenAI reqwest client should build"),
104        }
105    }
106
107    /// Use your own `reqwest::Client`.
108    /// The required headers will be automatically attached upon trying to make a request.
109    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
110        self.http_client = client;
111
112        self
113    }
114
115    /// Creates a new Azure OpenAI client from an API key.
116    ///
117    /// # Arguments
118    ///
119    /// * `api_key` - Azure OpenAI API key required for authentication
120    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
121    /// * `azure_endpoint` - Azure OpenAI endpoint URL
122    pub fn from_api_key(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
123        Self::new(
124            AzureOpenAIAuth::ApiKey(api_key.to_string()),
125            api_version,
126            azure_endpoint,
127        )
128    }
129
130    /// Creates a new Azure OpenAI client from a token.
131    ///
132    /// # Arguments
133    ///
134    /// * `token` - Azure OpenAI token required for authentication
135    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
136    /// * `azure_endpoint` - Azure OpenAI endpoint URL
137    pub fn from_token(token: &str, api_version: &str, azure_endpoint: &str) -> Self {
138        Self::new(
139            AzureOpenAIAuth::Token(token.to_string()),
140            api_version,
141            azure_endpoint,
142        )
143    }
144
145    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
146        let url = format!(
147            "{}/openai/deployments/{}/embeddings?api-version={}",
148            self.azure_endpoint, deployment_id, self.api_version
149        )
150        .replace("//", "/");
151
152        let (key, value) = self.auth.as_header();
153        self.http_client.post(url).header(key, value)
154    }
155
156    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
157        let url = format!(
158            "{}/openai/deployments/{}/chat/completions?api-version={}",
159            self.azure_endpoint, deployment_id, self.api_version
160        )
161        .replace("//", "/");
162        let (key, value) = self.auth.as_header();
163        self.http_client.post(url).header(key, value)
164    }
165
166    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
167        let url = format!(
168            "{}/openai/deployments/{}/audio/translations?api-version={}",
169            self.azure_endpoint, deployment_id, self.api_version
170        )
171        .replace("//", "/");
172        let (key, value) = self.auth.as_header();
173        self.http_client.post(url).header(key, value)
174    }
175
176    #[cfg(feature = "image")]
177    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
178        let url = format!(
179            "{}/openai/deployments/{}/images/generations?api-version={}",
180            self.azure_endpoint, deployment_id, self.api_version
181        )
182        .replace("//", "/");
183        let (key, value) = self.auth.as_header();
184        self.http_client.post(url).header(key, value)
185    }
186
187    #[cfg(feature = "audio")]
188    fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
189        let url = format!(
190            "{}/openai/deployments/{}/audio/speech?api-version={}",
191            self.azure_endpoint, deployment_id, self.api_version
192        )
193        .replace("//", "/");
194        let (key, value) = self.auth.as_header();
195        self.http_client.post(url).header(key, value)
196    }
197}
198
199impl ProviderClient for Client {
200    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
201    fn from_env() -> Self {
202        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
203            AzureOpenAIAuth::ApiKey(api_key)
204        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
205            AzureOpenAIAuth::Token(token)
206        } else {
207            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
208        };
209
210        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
211        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
212
213        Self::new(auth, &api_version, &azure_endpoint)
214    }
215}
216
217impl CompletionClient for Client {
218    type CompletionModel = CompletionModel;
219
220    /// Create a completion model with the given name.
221    ///
222    /// # Example
223    /// ```
224    /// use rig::providers::azure::{Client, self};
225    ///
226    /// // Initialize the Azure OpenAI client
227    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
228    ///
229    /// let gpt4 = azure.completion_model(azure::GPT_4);
230    /// ```
231    fn completion_model(&self, model: &str) -> CompletionModel {
232        CompletionModel::new(self.clone(), model)
233    }
234}
235
236impl EmbeddingsClient for Client {
237    type EmbeddingModel = EmbeddingModel;
238
239    /// Create an embedding model with the given name.
240    /// Note: default embedding dimension of 0 will be used if model is not known.
241    /// If this is the case, it's better to use function `embedding_model_with_ndims`
242    ///
243    /// # Example
244    /// ```
245    /// use rig::providers::azure::{Client, self};
246    ///
247    /// // Initialize the Azure OpenAI client
248    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
249    ///
250    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
251    /// ```
252    fn embedding_model(&self, model: &str) -> EmbeddingModel {
253        let ndims = match model {
254            TEXT_EMBEDDING_3_LARGE => 3072,
255            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
256            _ => 0,
257        };
258        EmbeddingModel::new(self.clone(), model, ndims)
259    }
260
261    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
262    ///
263    /// # Example
264    /// ```
265    /// use rig::providers::azure::{Client, self};
266    ///
267    /// // Initialize the Azure OpenAI client
268    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
269    ///
270    /// let embedding_model = azure.embedding_model("model-unknown-to-rig", 3072);
271    /// ```
272    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
273        EmbeddingModel::new(self.clone(), model, ndims)
274    }
275}
276
277impl TranscriptionClient for Client {
278    type TranscriptionModel = TranscriptionModel;
279
280    /// Create a transcription model with the given name.
281    ///
282    /// # Example
283    /// ```
284    /// use rig::providers::azure::{Client, self};
285    ///
286    /// // Initialize the Azure OpenAI client
287    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
288    ///
289    /// let whisper = azure.transcription_model("model-unknown-to-rig");
290    /// ```
291    fn transcription_model(&self, model: &str) -> TranscriptionModel {
292        TranscriptionModel::new(self.clone(), model)
293    }
294}
295
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable description of the failure.
    message: String,
}
300
/// Either a successful payload of type `T` or an API-level error.
///
/// `#[serde(untagged)]` makes serde try the variants in order: the
/// success shape `T` first, then the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
307
308// ================================================================
309// Azure OpenAI Embedding API
310// ================================================================
/// `text-embedding-3-large` embedding model (3072 dimensions by default).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 dimensions by default).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 dimensions).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
317
/// Successful response body of the embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    // Object type discriminator returned by the API.
    pub object: String,
    // One entry per input document, in request order.
    pub data: Vec<EmbeddingData>,
    // Model that produced the embeddings.
    pub model: String,
    // Token accounting for the request.
    pub usage: Usage,
}
325
326impl From<ApiErrorResponse> for EmbeddingError {
327    fn from(err: ApiErrorResponse) -> Self {
328        EmbeddingError::ProviderError(err.message)
329    }
330}
331
332impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
333    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
334        match value {
335            ApiResponse::Ok(response) => Ok(response),
336            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
337        }
338    }
339}
340
/// A single embedding entry from the embeddings response.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    // Object type discriminator returned by the API.
    pub object: String,
    // The embedding vector itself.
    pub embedding: Vec<f64>,
    // Position of the corresponding input document in the request.
    pub index: usize,
}
347
/// Token usage accounting returned by the API.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    // Tokens consumed by the input.
    pub prompt_tokens: usize,
    // Total tokens billed for the request.
    pub total_tokens: usize,
}
353
354impl std::fmt::Display for Usage {
355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        write!(
357            f,
358            "Prompt tokens: {} Total tokens: {}",
359            self.prompt_tokens, self.total_tokens
360        )
361    }
362}
363
/// An Azure OpenAI embedding model bound to a specific deployment.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    // Output dimension count; 0 when the model is unknown to rig.
    ndims: usize,
}
370
impl embeddings::EmbeddingModel for EmbeddingModel {
    // Maximum number of documents accepted in a single request.
    const MAX_DOCUMENTS: usize = 1024;

    /// Number of dimensions of the produced embeddings (0 when unknown).
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds a batch of documents with one API call.
    ///
    /// # Errors
    /// * `EmbeddingError::ResponseError` if the API returns a different
    ///   number of embeddings than documents sent.
    /// * `EmbeddingError::ProviderError` for HTTP or API-level failures.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize so we can both send and later zip against the response.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let response = self
            .client
            .post_embedding(&self.model)
            .json(&json!({
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // Guard the zip below: a length mismatch would silently
                    // drop documents or embeddings.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each returned embedding with its source document.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}
425
426impl EmbeddingModel {
427    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
428        Self {
429            client,
430            model: model.to_string(),
431            ndims,
432        }
433    }
434}
435
436// ================================================================
437// Azure OpenAI Completion API
438// ================================================================
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
// NOTE(review): resolves to "gpt-4", not a turbo-specific deployment name —
// confirm whether this should be "gpt-4-turbo".
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` (0613 snapshot) completion model
// NOTE(review): same value as `GPT_4_32K` with no "-0613" suffix — confirm intended.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
465
/// An Azure OpenAI chat-completion model bound to a specific deployment.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}
472
473impl CompletionModel {
474    pub fn new(client: Client, model: &str) -> Self {
475        Self {
476            client,
477            model: model.to_string(),
478        }
479    }
480
481    fn create_completion_request(
482        &self,
483        completion_request: CompletionRequest,
484    ) -> Result<serde_json::Value, CompletionError> {
485        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
486            Some(preamble) => vec![openai::Message::system(preamble)],
487            None => vec![],
488        };
489        if let Some(docs) = completion_request.normalized_documents() {
490            let docs: Vec<openai::Message> = docs.try_into()?;
491            full_history.extend(docs);
492        }
493        let chat_history: Vec<openai::Message> = completion_request
494            .chat_history
495            .into_iter()
496            .map(|message| message.try_into())
497            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
498            .into_iter()
499            .flatten()
500            .collect();
501
502        full_history.extend(chat_history);
503
504        let request = if completion_request.tools.is_empty() {
505            json!({
506                "model": self.model,
507                "messages": full_history,
508                "temperature": completion_request.temperature,
509            })
510        } else {
511            json!({
512                "model": self.model,
513                "messages": full_history,
514                "temperature": completion_request.temperature,
515                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
516                "tool_choice": "auto",
517            })
518        };
519
520        let request = if let Some(params) = completion_request.additional_params {
521            json_utils::merge(request, params)
522        } else {
523            request
524        };
525
526        Ok(request)
527    }
528}
529
530impl completion::CompletionModel for CompletionModel {
531    type Response = openai::CompletionResponse;
532    type StreamingResponse = openai::StreamingCompletionResponse;
533
534    #[cfg_attr(feature = "worker", worker::send)]
535    async fn completion(
536        &self,
537        completion_request: CompletionRequest,
538    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
539        let request = self.create_completion_request(completion_request)?;
540
541        let response = self
542            .client
543            .post_chat_completion(&self.model)
544            .json(&request)
545            .send()
546            .await?;
547
548        if response.status().is_success() {
549            let t = response.text().await?;
550            tracing::debug!(target: "rig", "Azure completion error: {}", t);
551
552            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
553                ApiResponse::Ok(response) => {
554                    tracing::info!(target: "rig",
555                        "Azure completion token usage: {:?}",
556                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
557                    );
558                    response.try_into()
559                }
560                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
561            }
562        } else {
563            Err(CompletionError::ProviderError(response.text().await?))
564        }
565    }
566
567    #[cfg_attr(feature = "worker", worker::send)]
568    async fn stream(
569        &self,
570        request: CompletionRequest,
571    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
572        let mut request = self.create_completion_request(request)?;
573
574        request = merge(
575            request,
576            json!({"stream": true, "stream_options": {"include_usage": true}}),
577        );
578
579        let builder = self
580            .client
581            .post_chat_completion(self.model.as_str())
582            .json(&request);
583
584        send_compatible_streaming_request(builder).await
585    }
586}
587
588// ================================================================
589// Azure OpenAI Transcription API
590// ================================================================
591
/// An Azure OpenAI speech-to-text model bound to a specific deployment.
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the transcription model deployment (e.g.: whisper)
    pub model: String,
}
598
599impl TranscriptionModel {
600    pub fn new(client: Client, model: &str) -> Self {
601        Self {
602            client,
603            model: model.to_string(),
604        }
605    }
606}
607
608impl transcription::TranscriptionModel for TranscriptionModel {
609    type Response = TranscriptionResponse;
610
611    #[cfg_attr(feature = "worker", worker::send)]
612    async fn transcription(
613        &self,
614        request: transcription::TranscriptionRequest,
615    ) -> Result<
616        transcription::TranscriptionResponse<Self::Response>,
617        transcription::TranscriptionError,
618    > {
619        let data = request.data;
620
621        let mut body = reqwest::multipart::Form::new().part(
622            "file",
623            Part::bytes(data).file_name(request.filename.clone()),
624        );
625
626        if let Some(prompt) = request.prompt {
627            body = body.text("prompt", prompt.clone());
628        }
629
630        if let Some(ref temperature) = request.temperature {
631            body = body.text("temperature", temperature.to_string());
632        }
633
634        if let Some(ref additional_params) = request.additional_params {
635            for (key, value) in additional_params
636                .as_object()
637                .expect("Additional Parameters to OpenAI Transcription should be a map")
638            {
639                body = body.text(key.to_owned(), value.to_string());
640            }
641        }
642
643        let response = self
644            .client
645            .post_transcription(&self.model)
646            .multipart(body)
647            .send()
648            .await?;
649
650        if response.status().is_success() {
651            match response
652                .json::<ApiResponse<TranscriptionResponse>>()
653                .await?
654            {
655                ApiResponse::Ok(response) => response.try_into(),
656                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
657                    api_error_response.message,
658                )),
659            }
660        } else {
661            Err(TranscriptionError::ProviderError(response.text().await?))
662        }
663    }
664}
665
666// ================================================================
667// Azure OpenAI Image Generation API
668// ================================================================
#[cfg(feature = "image")]
pub use image_generation::*;
#[cfg(feature = "image")]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use serde_json::json;

    /// An Azure OpenAI image-generation model bound to a specific deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        /// Deployment name; used both in the URL and in the request body.
        pub model: String,
    }
    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Generates an image for the prompt. Requests base64 output
        /// (`"response_format": "b64_json"`); the size is formatted as
        /// "{width}x{height}".
        ///
        /// # Errors
        /// Returns `ImageGenerationError::ProviderError` on non-success HTTP
        /// status or an API-level error payload.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let t = response.text().await?;

            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Create an image generation model with the given deployment name.
        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
735// ================================================================
736// Azure OpenAI Audio Generation API
737// ================================================================
738
739use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
740#[cfg(feature = "audio")]
741pub use audio_generation::*;
742
#[cfg(feature = "audio")]
mod audio_generation {
    use super::Client;
    use crate::audio_generation;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use bytes::Bytes;
    use serde_json::json;

    /// An Azure OpenAI text-to-speech model bound to a specific deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = Bytes;

        /// Generates speech audio for `request.text` and returns the raw
        /// audio bytes.
        ///
        /// # Errors
        /// Returns `AudioGenerationError::ProviderError` on non-success
        /// HTTP status.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let body = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let response = self
                .client
                // Fix: previously passed the literal path "/audio/speech" as
                // the deployment id, producing a malformed URL like
                // ".../deployments//audio/speech/audio/speech?...". The
                // deployment name (`self.model`) belongs here.
                .post_audio_generation(&self.model)
                .json(&body)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let bytes = response.bytes().await?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Create an audio generation model with the given deployment name.
        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
809
// Live-service smoke tests. Both are `#[ignore]`d because they require real
// Azure credentials supplied via environment variables (see
// `ProviderClient::from_env`); run with `cargo test -- --ignored`.
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        // Ignore the error if a subscriber is already installed.
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        // Ignore the error if a subscriber is already installed.
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}