Skip to main content

rig_core/providers/
azure.rs

1//! Azure OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::providers::azure;
6//! use rig_core::client::CompletionClient;
7//!
8//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
9//! let client = azure::Client::builder()
10//!     .api_key("test")
11//!     .azure_endpoint("test".to_string()) // add your endpoint here!
12//!     .build()?;
13//!
14//! let gpt4o = client.completion_model(azure::GPT_4O);
15//! # Ok(())
16//! # }
17//! ```
18//!
//! ## Authentication
//! The authentication type used by the `azure` module is [`AzureOpenAIAuth`].
//!
//! By default, any value implementing `Into<String>` that is passed to the client builder's
//! `api_key` method is converted into a bearer auth token ([`AzureOpenAIAuth::Token`]).
//! To authenticate with an API key instead (sent in the `api-key` header), pass
//! [`AzureOpenAIAuth::ApiKey`] explicitly.
24
25use std::fmt::Debug;
26
27use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
28use crate::client::{
29    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
30    ProviderClient,
31};
32use crate::completion::GetTokenUsage;
33use crate::http_client::multipart::Part;
34use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
35use crate::streaming::StreamingCompletionResponse;
36use crate::transcription::TranscriptionError;
37use crate::{
38    completion::{self, CompletionError, CompletionRequest},
39    embeddings::{self, EmbeddingError},
40    json_utils,
41    providers::openai,
42    telemetry::SpanCombinator,
43    transcription::{self},
44};
45use bytes::Bytes;
46use serde::{Deserialize, Serialize};
47use serde_json::json;
48// ================================================================
49// Main Azure OpenAI Client
50// ================================================================
51
52const DEFAULT_API_VERSION: &str = "2024-10-21";
53
54#[derive(Debug, Clone)]
55pub struct AzureExt {
56    endpoint: String,
57    api_version: String,
58}
59
60impl DebugExt for AzureExt {
61    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
62        [
63            ("endpoint", (&self.endpoint as &dyn Debug)),
64            ("api_version", (&self.api_version as &dyn Debug)),
65        ]
66        .into_iter()
67    }
68}
69
70// TODO: @FayCarsons - this should be a type-safe builder,
71// but that would require extending the `ProviderBuilder`
72// to have some notion of complete vs incomplete states in a
73// given extension builder
74#[derive(Debug, Clone)]
75pub struct AzureExtBuilder {
76    endpoint: Option<String>,
77    api_version: String,
78}
79
80impl Default for AzureExtBuilder {
81    fn default() -> Self {
82        Self {
83            endpoint: None,
84            api_version: DEFAULT_API_VERSION.into(),
85        }
86    }
87}
88
89pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
90pub type ClientBuilder<H = crate::markers::Missing> =
91    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
92
93impl Provider for AzureExt {
94    type Builder = AzureExtBuilder;
95
96    /// Verifying Azure auth without consuming tokens is not supported
97    const VERIFY_PATH: &'static str = "";
98}
99
100impl<H> Capabilities<H> for AzureExt {
101    type Completion = Capable<CompletionModel<H>>;
102    type Embeddings = Capable<EmbeddingModel<H>>;
103    type Transcription = Capable<TranscriptionModel<H>>;
104    type ModelListing = Nothing;
105    #[cfg(feature = "image")]
106    type ImageGeneration = Nothing;
107    #[cfg(feature = "audio")]
108    type AudioGeneration = Capable<AudioGenerationModel<H>>;
109}
110
111impl ProviderBuilder for AzureExtBuilder {
112    type Extension<H>
113        = AzureExt
114    where
115        H: HttpClientExt;
116    type ApiKey = AzureOpenAIAuth;
117
118    const BASE_URL: &'static str = "";
119
120    fn build<H>(
121        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
122    ) -> http_client::Result<Self::Extension<H>>
123    where
124        H: HttpClientExt,
125    {
126        let AzureExtBuilder {
127            endpoint,
128            api_version,
129            ..
130        } = builder.ext().clone();
131
132        match endpoint {
133            Some(endpoint) => Ok(AzureExt {
134                endpoint,
135                api_version,
136            }),
137            None => Err(http_client::Error::Instance(
138                "Azure client must be provided an endpoint prior to building".into(),
139            )),
140        }
141    }
142
143    fn finish<H>(
144        &self,
145        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
146    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
147        use AzureOpenAIAuth::*;
148
149        let auth = builder.get_api_key().clone();
150
151        match auth {
152            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
153            ApiKey(key) => {
154                let k = http::HeaderName::from_static("api-key");
155                let v = http::HeaderValue::from_str(key.as_str())?;
156
157                builder.headers_mut().insert(k, v);
158            }
159        }
160
161        Ok(builder)
162    }
163}
164
165impl<H> ClientBuilder<H> {
166    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
167    pub fn api_version(mut self, api_version: &str) -> Self {
168        self.ext_mut().api_version = api_version.into();
169
170        self
171    }
172}
173
174impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
175    /// Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
176    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
177        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
178            endpoint: Some(endpoint),
179            api_version,
180        })
181    }
182}
183
184/// The authentication type for Azure OpenAI. Can either be an API key or a token.
185/// String types will automatically be coerced to a bearer auth token by default.
186#[derive(Clone)]
187pub enum AzureOpenAIAuth {
188    ApiKey(String),
189    Token(String),
190}
191
192impl ApiKey for AzureOpenAIAuth {}
193
194impl std::fmt::Debug for AzureOpenAIAuth {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        match self {
197            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
198            Self::Token(_) => write!(f, "Token <REDACTED>"),
199        }
200    }
201}
202
203impl<S> From<S> for AzureOpenAIAuth
204where
205    S: Into<String>,
206{
207    fn from(token: S) -> Self {
208        AzureOpenAIAuth::Token(token.into())
209    }
210}
211
212impl<T> Client<T>
213where
214    T: HttpClientExt,
215{
216    fn endpoint(&self) -> &str {
217        &self.ext().endpoint
218    }
219
220    fn api_version(&self) -> &str {
221        &self.ext().api_version
222    }
223
224    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
225        let url = format!(
226            "{}/openai/deployments/{}/embeddings?api-version={}",
227            self.endpoint(),
228            deployment_id.trim_start_matches('/'),
229            self.api_version()
230        );
231
232        self.post(&url)
233    }
234
235    #[cfg(feature = "audio")]
236    fn post_audio_generation(
237        &self,
238        deployment_id: &str,
239    ) -> http_client::Result<http_client::Builder> {
240        let url = format!(
241            "{}/openai/deployments/{}/audio/speech?api-version={}",
242            self.endpoint(),
243            deployment_id.trim_start_matches('/'),
244            self.api_version()
245        );
246
247        self.post(url)
248    }
249
250    fn post_chat_completion(
251        &self,
252        deployment_id: &str,
253    ) -> http_client::Result<http_client::Builder> {
254        let url = format!(
255            "{}/openai/deployments/{}/chat/completions?api-version={}",
256            self.endpoint(),
257            deployment_id.trim_start_matches('/'),
258            self.api_version()
259        );
260
261        self.post(&url)
262    }
263
264    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
265        let url = format!(
266            "{}/openai/deployments/{}/audio/translations?api-version={}",
267            self.endpoint(),
268            deployment_id.trim_start_matches('/'),
269            self.api_version()
270        );
271
272        self.post(&url)
273    }
274
275    #[cfg(feature = "image")]
276    fn post_image_generation(
277        &self,
278        deployment_id: &str,
279    ) -> http_client::Result<http_client::Builder> {
280        let url = format!(
281            "{}/openai/deployments/{}/images/generations?api-version={}",
282            self.endpoint(),
283            deployment_id.trim_start_matches('/'),
284            self.api_version()
285        );
286
287        self.post(&url)
288    }
289}
290
291pub struct AzureOpenAIClientParams {
292    api_key: String,
293    version: String,
294    header: String,
295}
296
297impl ProviderClient for Client {
298    type Input = AzureOpenAIClientParams;
299    type Error = crate::client::ProviderClientError;
300
301    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
302    fn from_env() -> Result<Self, Self::Error> {
303        let auth = if let Some(api_key) = crate::client::optional_env_var("AZURE_API_KEY")? {
304            AzureOpenAIAuth::ApiKey(api_key)
305        } else if let Some(token) = crate::client::optional_env_var("AZURE_TOKEN")? {
306            AzureOpenAIAuth::Token(token)
307        } else {
308            return Err(crate::client::ProviderClientError::InvalidConfiguration(
309                "either `AZURE_API_KEY` or `AZURE_TOKEN` must be set",
310            ));
311        };
312
313        let api_version = crate::client::required_env_var("AZURE_API_VERSION")?;
314        let azure_endpoint = crate::client::required_env_var("AZURE_ENDPOINT")?;
315
316        Self::builder()
317            .api_key(auth)
318            .azure_endpoint(azure_endpoint)
319            .api_version(&api_version)
320            .build()
321            .map_err(Into::into)
322    }
323
324    fn from_val(
325        AzureOpenAIClientParams {
326            api_key,
327            version,
328            header,
329        }: Self::Input,
330    ) -> Result<Self, Self::Error> {
331        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
332
333        Self::builder()
334            .api_key(auth)
335            .azure_endpoint(header)
336            .api_version(&version)
337            .build()
338            .map_err(Into::into)
339    }
340}
341
342#[derive(Debug, Deserialize)]
343struct ApiErrorResponse {
344    message: String,
345}
346
347#[derive(Debug, Deserialize)]
348#[serde(untagged)]
349enum ApiResponse<T> {
350    Ok(T),
351    Err(ApiErrorResponse),
352}
353
354// ================================================================
355// Azure OpenAI Embedding API
356// ================================================================
357
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

/// Returns the default output dimensionality of one of the embedding model constants above,
/// or `None` for any other identifier.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    const KNOWN_DIMENSIONS: [(&str, usize); 3] = [
        (TEXT_EMBEDDING_3_LARGE, 3_072),
        (TEXT_EMBEDDING_3_SMALL, 1_536),
        (TEXT_EMBEDDING_ADA_002, 1_536),
    ];

    KNOWN_DIMENSIONS
        .iter()
        .find(|(name, _)| *name == identifier)
        .map(|&(_, dims)| dims)
}
372
373#[derive(Debug, Deserialize)]
374pub struct EmbeddingResponse {
375    pub object: String,
376    pub data: Vec<EmbeddingData>,
377    pub model: String,
378    pub usage: Usage,
379}
380
381impl From<ApiErrorResponse> for EmbeddingError {
382    fn from(err: ApiErrorResponse) -> Self {
383        EmbeddingError::ProviderError(err.message)
384    }
385}
386
387impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
388    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
389        match value {
390            ApiResponse::Ok(response) => Ok(response),
391            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
392        }
393    }
394}
395
396#[derive(Debug, Deserialize)]
397pub struct EmbeddingData {
398    pub object: String,
399    pub embedding: Vec<f64>,
400    pub index: usize,
401}
402
403#[derive(Clone, Debug, Deserialize)]
404pub struct Usage {
405    pub prompt_tokens: usize,
406    pub total_tokens: usize,
407}
408
409impl GetTokenUsage for Usage {
410    fn token_usage(&self) -> Option<crate::completion::Usage> {
411        let mut usage = crate::completion::Usage::new();
412
413        usage.input_tokens = self.prompt_tokens as u64;
414        usage.total_tokens = self.total_tokens as u64;
415        usage.output_tokens = usage.total_tokens - usage.input_tokens;
416
417        Some(usage)
418    }
419}
420
421impl std::fmt::Display for Usage {
422    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423        write!(
424            f,
425            "Prompt tokens: {} Total tokens: {}",
426            self.prompt_tokens, self.total_tokens
427        )
428    }
429}
430
431#[derive(Clone)]
432pub struct EmbeddingModel<T = reqwest::Client> {
433    client: Client<T>,
434    pub model: String,
435    ndims: usize,
436}
437
438impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
439where
440    T: HttpClientExt + Default + Clone + 'static,
441{
442    const MAX_DOCUMENTS: usize = 1024;
443
444    type Client = Client<T>;
445
446    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
447        Self::new(client.clone(), model, dims)
448    }
449
450    fn ndims(&self) -> usize {
451        self.ndims
452    }
453
454    async fn embed_texts(
455        &self,
456        documents: impl IntoIterator<Item = String>,
457    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
458        let documents = documents.into_iter().collect::<Vec<_>>();
459
460        let mut body = json!({
461            "input": documents,
462        });
463
464        let body_object = body.as_object_mut().ok_or_else(|| {
465            EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
466        })?;
467
468        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
469            body_object.insert("dimensions".to_owned(), json!(self.ndims));
470        }
471
472        let body = serde_json::to_vec(&body)?;
473
474        let req = self
475            .client
476            .post_embedding(self.model.as_str())?
477            .body(body)
478            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
479
480        let response = self.client.send(req).await?;
481
482        if response.status().is_success() {
483            let body: Vec<u8> = response.into_body().await?;
484            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
485
486            match body {
487                ApiResponse::Ok(response) => {
488                    tracing::info!(target: "rig",
489                        "Azure embedding token usage: {}",
490                        response.usage
491                    );
492
493                    if response.data.len() != documents.len() {
494                        return Err(EmbeddingError::ResponseError(
495                            "Response data length does not match input length".into(),
496                        ));
497                    }
498
499                    Ok(response
500                        .data
501                        .into_iter()
502                        .zip(documents.into_iter())
503                        .map(|(embedding, document)| embeddings::Embedding {
504                            document,
505                            vec: embedding.embedding,
506                        })
507                        .collect())
508                }
509                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
510            }
511        } else {
512            let text = http_client::text(response).await?;
513            Err(EmbeddingError::ProviderError(text))
514        }
515    }
516}
517
518impl<T> EmbeddingModel<T> {
519    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
520        let model = model.into();
521        let ndims = ndims
522            .or(model_dimensions_from_identifier(&model))
523            .unwrap_or_default();
524
525        Self {
526            client,
527            model,
528            ndims,
529        }
530    }
531
532    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
533        let ndims = ndims.unwrap_or_default();
534
535        Self {
536            client,
537            model: model.into(),
538            ndims,
539        }
540    }
541}
542
543// ================================================================
544// Azure OpenAI Completion API
545// ================================================================
546
// These names are sent as the deployment id in request URLs, so they only resolve when the
// Azure deployment is named after the model.

/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// GPT-4 Turbo completion model.
///
/// The value is `gpt-4` (identical to [`GPT_4`]): Azure lists GPT-4 Turbo under the `gpt-4`
/// model name and distinguishes it by model version (e.g. `turbo-2024-04-09`).
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model, version `0613`.
///
/// The value is `gpt-4-32k` (identical to [`GPT_4_32K`]); on Azure the model version is
/// selected on the deployment rather than in the model name.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
573
574#[derive(Debug, Serialize, Deserialize)]
575pub(super) struct AzureOpenAICompletionRequest {
576    model: String,
577    pub messages: Vec<openai::Message>,
578    #[serde(skip_serializing_if = "Option::is_none")]
579    temperature: Option<f64>,
580    #[serde(skip_serializing_if = "Vec::is_empty")]
581    tools: Vec<openai::ToolDefinition>,
582    #[serde(skip_serializing_if = "Option::is_none")]
583    tool_choice: Option<crate::providers::openai::ToolChoice>,
584    #[serde(flatten, skip_serializing_if = "Option::is_none")]
585    pub additional_params: Option<serde_json::Value>,
586}
587
588impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
589    type Error = CompletionError;
590
591    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
592        let model = req.model.clone().unwrap_or_else(|| model.to_string());
593        if req.tool_choice.is_some() {
594            tracing::warn!("Tool choice is currently not supported in Azure OpenAI.");
595        }
596
597        let mut full_history: Vec<openai::Message> = match &req.preamble {
598            Some(preamble) => vec![openai::Message::system(preamble)],
599            None => vec![],
600        };
601
602        if let Some(docs) = req.normalized_documents() {
603            let docs: Vec<openai::Message> = docs.try_into()?;
604            full_history.extend(docs);
605        }
606
607        let chat_history: Vec<openai::Message> = req
608            .chat_history
609            .clone()
610            .into_iter()
611            .map(|message| message.try_into())
612            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
613            .into_iter()
614            .flatten()
615            .collect();
616
617        full_history.extend(chat_history);
618
619        let tool_choice = req
620            .tool_choice
621            .clone()
622            .map(crate::providers::openai::ToolChoice::try_from)
623            .transpose()?;
624
625        let additional_params = if let Some(schema) = req.output_schema {
626            let name = schema
627                .as_object()
628                .and_then(|o| o.get("title"))
629                .and_then(|v| v.as_str())
630                .unwrap_or("response_schema")
631                .to_string();
632            let mut schema_value = schema.to_value();
633            openai::sanitize_schema(&mut schema_value);
634            let response_format = serde_json::json!({
635                "response_format": {
636                    "type": "json_schema",
637                    "json_schema": {
638                        "name": name,
639                        "strict": true,
640                        "schema": schema_value
641                    }
642                }
643            });
644            Some(match req.additional_params {
645                Some(existing) => json_utils::merge(existing, response_format),
646                None => response_format,
647            })
648        } else {
649            req.additional_params
650        };
651
652        Ok(Self {
653            model: model.to_string(),
654            messages: full_history,
655            temperature: req.temperature,
656            tools: req
657                .tools
658                .clone()
659                .into_iter()
660                .map(openai::ToolDefinition::from)
661                .collect::<Vec<_>>(),
662            tool_choice,
663            additional_params,
664        })
665    }
666}
667
668#[derive(Clone)]
669pub struct CompletionModel<T = reqwest::Client> {
670    client: Client<T>,
671    /// Name of the model (e.g.: gpt-4o-mini)
672    pub model: String,
673}
674
675impl<T> CompletionModel<T> {
676    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
677        Self {
678            client,
679            model: model.into(),
680        }
681    }
682}
683
684impl<T> completion::CompletionModel for CompletionModel<T>
685where
686    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
687{
688    type Response = openai::CompletionResponse;
689    type StreamingResponse = openai::StreamingCompletionResponse;
690    type Client = Client<T>;
691
692    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
693        Self::new(client.clone(), model.into())
694    }
695
696    async fn completion(
697        &self,
698        completion_request: CompletionRequest,
699    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
700        let span = if tracing::Span::current().is_disabled() {
701            info_span!(
702                target: "rig::completions",
703                "chat",
704                gen_ai.operation.name = "chat",
705                gen_ai.provider.name = "azure.openai",
706                gen_ai.request.model = self.model,
707                gen_ai.system_instructions = &completion_request.preamble,
708                gen_ai.response.id = tracing::field::Empty,
709                gen_ai.response.model = tracing::field::Empty,
710                gen_ai.usage.output_tokens = tracing::field::Empty,
711                gen_ai.usage.input_tokens = tracing::field::Empty,
712                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
713            )
714        } else {
715            tracing::Span::current()
716        };
717
718        let request =
719            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
720
721        if enabled!(Level::TRACE) {
722            tracing::trace!(target: "rig::completions",
723                "Azure OpenAI completion request: {}",
724                serde_json::to_string_pretty(&request)?
725            );
726        }
727
728        let body = serde_json::to_vec(&request)?;
729
730        let req = self
731            .client
732            .post_chat_completion(&self.model)?
733            .body(body)
734            .map_err(http_client::Error::from)?;
735
736        async move {
737            let response = self.client.send::<_, Bytes>(req).await?;
738
739            let status = response.status();
740            let response_body = response.into_body().into_future().await?.to_vec();
741
742            if status.is_success() {
743                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
744                    &response_body,
745                )? {
746                    ApiResponse::Ok(response) => {
747                        let span = tracing::Span::current();
748                        span.record_response_metadata(&response);
749                        span.record_token_usage(&response.usage);
750                        if enabled!(Level::TRACE) {
751                            tracing::trace!(target: "rig::completions",
752                                "Azure OpenAI completion response: {}",
753                                serde_json::to_string_pretty(&response)?
754                            );
755                        }
756                        response.try_into()
757                    }
758                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
759                }
760            } else {
761                Err(CompletionError::ProviderError(
762                    String::from_utf8_lossy(&response_body).to_string(),
763                ))
764            }
765        }
766        .instrument(span)
767        .await
768    }
769
770    async fn stream(
771        &self,
772        completion_request: CompletionRequest,
773    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
774        let preamble = completion_request.preamble.clone();
775        let mut request =
776            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
777
778        let params = json_utils::merge(
779            request.additional_params.unwrap_or(serde_json::json!({})),
780            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
781        );
782
783        request.additional_params = Some(params);
784
785        if enabled!(Level::TRACE) {
786            tracing::trace!(target: "rig::completions",
787                "Azure OpenAI completion request: {}",
788                serde_json::to_string_pretty(&request)?
789            );
790        }
791
792        let body = serde_json::to_vec(&request)?;
793
794        let req = self
795            .client
796            .post_chat_completion(&self.model)?
797            .body(body)
798            .map_err(http_client::Error::from)?;
799
800        let span = if tracing::Span::current().is_disabled() {
801            info_span!(
802                target: "rig::completions",
803                "chat_streaming",
804                gen_ai.operation.name = "chat_streaming",
805                gen_ai.provider.name = "azure.openai",
806                gen_ai.request.model = self.model,
807                gen_ai.system_instructions = &preamble,
808                gen_ai.response.id = tracing::field::Empty,
809                gen_ai.response.model = tracing::field::Empty,
810                gen_ai.usage.output_tokens = tracing::field::Empty,
811                gen_ai.usage.input_tokens = tracing::field::Empty,
812                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
813            )
814        } else {
815            tracing::Span::current()
816        };
817
818        tracing_futures::Instrument::instrument(
819            send_compatible_streaming_request(self.client.clone(), req),
820            span,
821        )
822        .await
823    }
824}
825
826// ================================================================
827// Azure OpenAI Transcription API
828// ================================================================
829
830#[derive(Clone)]
831pub struct TranscriptionModel<T = reqwest::Client> {
832    client: Client<T>,
833    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
834    pub model: String,
835}
836
837impl<T> TranscriptionModel<T> {
838    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
839        Self {
840            client,
841            model: model.into(),
842        }
843    }
844}
845
846impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
847where
848    T: HttpClientExt + Clone + 'static,
849{
850    type Response = TranscriptionResponse;
851    type Client = Client<T>;
852
853    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
854        Self::new(client.clone(), model)
855    }
856
857    async fn transcription(
858        &self,
859        request: transcription::TranscriptionRequest,
860    ) -> Result<
861        transcription::TranscriptionResponse<Self::Response>,
862        transcription::TranscriptionError,
863    > {
864        let data = request.data;
865
866        let mut body =
867            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));
868
869        if let Some(prompt) = request.prompt {
870            body = body.text("prompt", prompt.clone());
871        }
872
873        if let Some(ref temperature) = request.temperature {
874            body = body.text("temperature", temperature.to_string());
875        }
876
877        if let Some(ref additional_params) = request.additional_params {
878            let params = additional_params.as_object().ok_or_else(|| {
879                TranscriptionError::RequestError(Box::new(std::io::Error::new(
880                    std::io::ErrorKind::InvalidInput,
881                    "additional transcription parameters must be a JSON object",
882                )))
883            })?;
884
885            for (key, value) in params {
886                body = body.text(key.to_owned(), value.to_string());
887            }
888        }
889
890        let req = self
891            .client
892            .post_transcription(&self.model)?
893            .body(body)
894            .map_err(|e| TranscriptionError::HttpError(e.into()))?;
895
896        let response = self.client.send_multipart::<Bytes>(req).await?;
897        let status = response.status();
898        let response_body = response.into_body().into_future().await?.to_vec();
899
900        if status.is_success() {
901            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
902                ApiResponse::Ok(response) => response.try_into(),
903                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
904                    api_error_response.message,
905                )),
906            }
907        } else {
908            Err(TranscriptionError::ProviderError(
909                String::from_utf8_lossy(&response_body).to_string(),
910            ))
911        }
912    }
913}
914
915// ================================================================
916// Azure OpenAI Image Generation API
917// ================================================================
918#[cfg(feature = "image")]
919pub use image_generation::*;
920use tracing::{Instrument, Level, enabled, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    /// Image generation model that sends requests through an Azure OpenAI [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        /// Client used to build and send the HTTP request.
        client: Client<T>,
        /// Model identifier: handed to `post_image_generation` and also sent as the
        /// `model` field of the JSON payload.
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        /// Builds a model handle from a clone of `client` and the given identifier.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            let client = client.clone();
            let model = model.into();
            Self { client, model }
        }

        /// Requests a single image of `width`x`height`, asking for a `b64_json` payload.
        ///
        /// # Errors
        /// - A non-2xx status becomes `ProviderError("<status>: <body>")`.
        /// - A 2xx body that decodes to an API error becomes `ProviderError(<message>)`.
        /// - Serialization, transport and body-decoding errors are propagated with `?`.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let dimensions = format!("{}x{}", generation_request.width, generation_request.height);
            let payload = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": dimensions,
                "response_format": "b64_json"
            });
            let encoded = serde_json::to_vec(&payload)?;

            let http_request = self
                .client
                .post_image_generation(&self.model)?
                .body(encoded)
                .map_err(|err| ImageGenerationError::HttpError(err.into()))?;

            let http_response = self.client.send::<_, Bytes>(http_request).await?;
            let status = http_response.status();
            let raw = http_response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                let decoded: ApiResponse<ImageGenerationResponse> = serde_json::from_slice(&raw)?;
                match decoded {
                    ApiResponse::Ok(body) => body.try_into(),
                    ApiResponse::Err(api_err) => {
                        Err(ImageGenerationError::ProviderError(api_err.message))
                    }
                }
            } else {
                Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&raw)
                )))
            }
        }
    }
}
991// ================================================================
992// Azure OpenAI Audio Generation API
993// ================================================================
994
995#[cfg(feature = "audio")]
996pub use audio_generation::*;
997
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    /// Text-to-speech model bound to an Azure OpenAI deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        /// Client used to build and send the HTTP request.
        client: Client<T>,
        /// Deployment name. Passed to `post_audio_generation` to select the target
        /// deployment and also sent as the `model` field of the JSON payload.
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates a model that sends speech requests to `deployment_name`.
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        /// Builds a model handle from a clone of `client`; `model` is the deployment name.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Synthesizes `request.text` using the requested voice and speed.
        ///
        /// On success the raw response bytes are returned both as `audio` and as the
        /// provider `response`.
        ///
        /// # Errors
        /// - A non-2xx status becomes `ProviderError("<status>: <body>")`.
        /// - Serialization, transport and body-decoding errors are propagated with `?`.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // `post_audio_generation` takes the deployment to target, just like
            // `post_image_generation` / `post_transcription` receive `&self.model`
            // elsewhere in this file. Passing a URL path such as "/audio/speech" here
            // would ignore the configured deployment name entirely.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}
1073
#[cfg(test)]
mod azure_tests {
    // Tests for the Azure OpenAI provider. The `#[ignore]`d ones need a live
    // deployment configured via `Client::from_env()`; run them with
    // `cargo test -- --ignored`.
    use schemars::JsonSchema;

    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;
    use crate::prelude::TypedPrompt;
    use crate::providers::openai::GPT_5_MINI;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let azure = Client::from_env()?;
        let embedder = azure.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let inputs = vec!["Hello, world!".to_string()];
        let vectors = embedder.embed_texts(inputs).await?;

        tracing::info!("Azure embedding: {:?}", vectors);
        Ok(())
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding_dimensions() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let ndims = 256;
        let azure = Client::from_env()?;
        let embedder = azure.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims);
        let single = embedder.embed_text("Hello, world!").await?;
        let actual = single.vec.len();

        anyhow::ensure!(
            actual == ndims,
            "expected embedding dimensions {ndims}, got {}",
            actual
        );

        tracing::info!("Azure dimensions embedding: {:?}", single);
        Ok(())
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let azure = Client::from_env()?;
        let chat_model = azure.completion_model(GPT_4O_MINI);
        let request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one("Hello!".into()),
            documents: vec![],
            max_tokens: Some(100),
            temperature: Some(0.0),
            tools: vec![],
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };
        let reply = chat_model.completion(request).await?;

        tracing::info!("Azure completion: {:?}", reply);
        Ok(())
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_structured_output() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        // Target shape the agent must extract from the prompt.
        #[derive(Debug, Deserialize, JsonSchema)]
        struct Person {
            name: String,
            age: u32,
        }

        let azure = Client::from_env()?;
        let extractor = azure
            .agent(GPT_5_MINI)
            .preamble("You are a helpful assistant that extracts personal details.")
            .max_tokens(100)
            .output_schema::<Person>()
            .build();

        let person: Person = extractor
            .prompt_typed("Hello! My name is John Doe and I'm 54 years old.")
            .await?;

        anyhow::ensure!(
            person.name == "John Doe",
            "expected name John Doe, got {}",
            person.name
        );
        anyhow::ensure!(person.age == 54, "expected age 54, got {}", person.age);

        tracing::info!("Extracted person: {:?}", person);
        Ok(())
    }

    #[tokio::test]
    async fn test_client_initialization() {
        // Building with dummy credentials must succeed without any network access.
        let built = crate::providers::azure::Client::builder()
            .api_key("test")
            .azure_endpoint("test".to_string()) // add your endpoint here!
            .build();
        let _client = built.expect("Client::builder() failed");
    }
}