rig/providers/azure.rs

//! Azure OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::azure;
//! use rig::client::CompletionClient;
//!
//! let client: azure::Client<reqwest::Client> = azure::Client::builder()
//!     .api_key("test")
//!     .azure_endpoint("test".to_string()) // add your endpoint here!
//!     .build()?;
//!
//! let gpt4o = client.completion_model(azure::GPT_4O);
//! ```
//!
//! ## Authentication
//! The authentication type used for the `azure` module is [`AzureOpenAIAuth`].
//!
//! By default, any value that implements `Into<String>` passed to the client builder is coerced into a bearer auth token.
//! If you want to use an API key instead, construct the [`AzureOpenAIAuth::ApiKey`] variant explicitly.
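//!
//! For example (a minimal sketch; the key and endpoint values are placeholders):
//! ```
//! use rig::providers::azure::{self, AzureOpenAIAuth};
//!
//! let auth = AzureOpenAIAuth::ApiKey("my-api-key".to_string());
//! let client: azure::Client<reqwest::Client> = azure::Client::builder()
//!     .api_key(auth)
//!     .azure_endpoint("https://my-resource.openai.azure.com".to_string())
//!     .build()?;
//! ```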

use std::fmt::Debug;

use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
use crate::client::{
    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
use crate::streaming::StreamingCompletionResponse;
use crate::transcription::TranscriptionError;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    telemetry::SpanCombinator,
    transcription::{self},
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{Instrument, Level, enabled, info_span};

// ================================================================
// Main Azure OpenAI Client
// ================================================================

const DEFAULT_API_VERSION: &str = "2024-10-21";

#[derive(Debug, Clone)]
pub struct AzureExt {
    endpoint: String,
    api_version: String,
}

impl DebugExt for AzureExt {
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
        [
            ("endpoint", (&self.endpoint as &dyn Debug)),
            ("api_version", (&self.api_version as &dyn Debug)),
        ]
        .into_iter()
    }
}

// TODO: @FayCarsons - this should be a type-safe builder,
// but that would require extending the `ProviderBuilder`
// to have some notion of complete vs incomplete states in a
// given extension builder
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    endpoint: Option<String>,
    api_version: String,
}

impl Default for AzureExtBuilder {
    fn default() -> Self {
        Self {
            endpoint: None,
            api_version: DEFAULT_API_VERSION.into(),
        }
    }
}

pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;

impl Provider for AzureExt {
    type Builder = AzureExtBuilder;

    /// Verifying Azure auth without consuming tokens is not supported
    const VERIFY_PATH: &'static str = "";

    fn build<H>(
        builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        let AzureExtBuilder {
            endpoint,
            api_version,
            ..
        } = builder.ext().clone();

        match endpoint {
            Some(endpoint) => Ok(Self {
                endpoint,
                api_version,
            }),
            None => Err(http_client::Error::Instance(
                "Azure client must be provided an endpoint prior to building".into(),
            )),
        }
    }
}

impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

impl ProviderBuilder for AzureExtBuilder {
    type Output = AzureExt;
    type ApiKey = AzureOpenAIAuth;

    const BASE_URL: &'static str = "";

    fn finish<H>(
        &self,
        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
        use AzureOpenAIAuth::*;

        let auth = builder.get_api_key().clone();

        match auth {
            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
            ApiKey(key) => {
                let k = http::HeaderName::from_static("api-key");
                let v = http::HeaderValue::from_str(key.as_str())?;

                builder.headers_mut().insert(k, v);
            }
        }

        Ok(builder)
    }
}

impl<H> ClientBuilder<H> {
    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
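    ///
    /// A minimal sketch (the endpoint value is a placeholder):
    /// ```ignore
    /// let client: azure::Client<reqwest::Client> = azure::Client::builder()
    ///     .api_key("test")
    ///     .azure_endpoint("https://my-resource.openai.azure.com".to_string())
    ///     .api_version("2024-10-01-preview")
    ///     .build()?;
    /// ```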
    pub fn api_version(mut self, api_version: &str) -> Self {
        self.ext_mut().api_version = api_version.into();

        self
    }
}

impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
    /// Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
            endpoint: Some(endpoint),
            api_version,
        })
    }
}

/// The authentication type for Azure OpenAI. Can either be an API key or a token.
/// String types will automatically be coerced to a bearer auth token by default.
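///
/// A minimal sketch of both variants (the values are placeholders):
/// ```
/// use rig::providers::azure::AzureOpenAIAuth;
///
/// // Any `Into<String>` value is coerced into a bearer token by default:
/// let token: AzureOpenAIAuth = "my-token".into();
/// // To send an `api-key` header instead, construct the variant explicitly:
/// let api_key = AzureOpenAIAuth::ApiKey("my-key".to_string());
/// ```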
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl ApiKey for AzureOpenAIAuth {}

impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl<S> From<S> for AzureOpenAIAuth
where
    S: Into<String>,
{
    fn from(token: S) -> Self {
        AzureOpenAIAuth::Token(token.into())
    }
}

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn endpoint(&self) -> &str {
        &self.ext().endpoint
    }

    fn api_version(&self) -> &str {
        &self.ext().api_version
    }

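    // Each request helper below targets a single deployment; the resulting URL
    // has the shape (illustrative values):
    //   {endpoint}/openai/deployments/{deployment-id}/{route}?api-version={api-version}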
    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(url)
    }

    fn post_chat_completion(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }
}

pub struct AzureOpenAIClientParams {
    api_key: String,
    version: String,
    endpoint: String,
}

impl ProviderClient for Client {
    type Input = AzureOpenAIClientParams;

    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
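    ///
    /// A sketch of the expected environment (values are placeholders):
    /// ```ignore
    /// // AZURE_API_KEY=...            (or AZURE_TOKEN=... for bearer auth)
    /// // AZURE_API_VERSION=2024-10-21
    /// // AZURE_ENDPOINT=https://my-resource.openai.azure.com
    /// let client = Client::from_env();
    /// ```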
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::builder()
            .api_key(auth)
            .azure_endpoint(azure_endpoint)
            .api_version(&api_version)
            .build()
            .unwrap()
    }

    fn from_val(
        AzureOpenAIClientParams {
            api_key,
            version,
            endpoint,
        }: Self::Input,
    ) -> Self {
        let auth = AzureOpenAIAuth::ApiKey(api_key);

        Self::builder()
            .api_key(auth)
            .azure_endpoint(endpoint)
            .api_version(&version)
            .build()
            .unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ================================================================
// Azure OpenAI Embedding API
// ================================================================

/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        TEXT_EMBEDDING_3_LARGE => Some(3_072),
        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
        _ => None,
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_tokens as u64;
        usage.total_tokens = self.total_tokens as u64;
        // The API only reports prompt and total counts; derive output tokens from
        // their difference (saturating to guard against inconsistent data).
        usage.output_tokens = usage.total_tokens.saturating_sub(usage.input_tokens);

        Some(usage)
    }
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

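/// An Azure OpenAI embedding model, addressed by its deployment name.
///
/// A minimal usage sketch (assumes a configured `client`, as in the module docs):
/// ```ignore
/// let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
/// let embeddings = model
///     .embed_texts(vec!["Hello, world!".to_string()])
///     .await?;
/// ```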
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let mut body = json!({
            "input": documents,
        });

        // `text-embedding-ada-002` does not accept the `dimensions` parameter,
        // so only forward it for the newer embedding models.
        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body["dimensions"] = json!(self.ndims);
        }

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents)
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let ndims = ndims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self {
            client,
            model,
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
        let ndims = ndims.unwrap_or_default();

        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

// ================================================================
// Azure OpenAI Completion API
// ================================================================

/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// GPT-4 Turbo completion model (Azure model name `gpt-4`; the Turbo variant is selected by the deployment's model version)
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model, `0613` version (Azure model name `gpt-4-32k`)
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}

impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        // FIXME: Must fix!
        if req.tool_choice.is_some() {
            tracing::warn!(
                "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
            );
        }

        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model,
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}

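/// An Azure OpenAI chat completion model, addressed by its deployment name.
///
/// A minimal usage sketch (assumes a configured `client`, as in the module docs):
/// ```ignore
/// let gpt4o = client.completion_model(GPT_4O);
/// ```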
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Azure OpenAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let mut request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing_futures::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

// ================================================================
// Azure OpenAI Transcription API
// ================================================================

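/// An Azure OpenAI transcription model, addressed by its deployment name.
///
/// A minimal usage sketch, assuming a configured `client` that exposes a
/// `transcription_model` accessor mirroring `completion_model` (the deployment
/// name is a placeholder):
/// ```ignore
/// let whisper = client.transcription_model("whisper");
/// ```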
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body =
            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to Azure OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

// ================================================================
// Azure OpenAI Image Generation API
// ================================================================
#[cfg(feature = "image")]
pub use image_generation::*;
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

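    /// An Azure OpenAI image generation model, addressed by its deployment name.
    ///
    /// A minimal usage sketch, assuming a configured `client` that exposes an
    /// `image_generation_model` accessor mirroring `completion_model` (the
    /// deployment name is a placeholder):
    /// ```ignore
    /// let dalle = client.image_generation_model("dall-e-3");
    /// ```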
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}
// ================================================================
// Azure OpenAI Audio Generation API
// ================================================================

#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

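    /// An Azure OpenAI audio (speech) generation model, addressed by its
    /// deployment name.
    ///
    /// A minimal usage sketch, assuming a configured `client` that exposes an
    /// `audio_generation_model` accessor mirroring `completion_model` (the
    /// deployment name is a placeholder):
    /// ```ignore
    /// let tts = client.audio_generation_model("tts-1");
    /// ```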
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                // Target the configured deployment; the helper appends the
                // `audio/speech` route itself.
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding_dimensions() {
        let _ = tracing_subscriber::fmt::try_init();

        let ndims = 256;
        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims);
        let embedding = model.embed_text("Hello, world!").await.unwrap();

        assert_eq!(embedding.vec.len(), ndims);

        tracing::info!("Azure dimensions embedding: {:?}", embedding);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                model: None,
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
                output_schema: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }

    #[tokio::test]
    async fn test_client_initialization() {
        let _client: crate::providers::azure::Client<reqwest::Client> =
            crate::providers::azure::Client::builder()
                .api_key("test")
                .azure_endpoint("test".to_string()) // add your endpoint here!
                .build()
                .expect("Client::builder() failed");
    }
}