Skip to main content

rig_core/providers/
azure.rs

1//! Azure OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::providers::azure;
6//! use rig_core::client::CompletionClient;
7//!
8//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
9//! let client = azure::Client::builder()
10//!     .api_key("test")
11//!     .azure_endpoint("test".to_string()) // add your endpoint here!
12//!     .build()?;
13//!
14//! let gpt4o = client.completion_model(azure::GPT_4O);
15//! # Ok(())
16//! # }
17//! ```
18//!
19//! ## Authentication
20//! The authentication type used for the `azure` module is [`AzureOpenAIAuth`].
21//!
//! By default, passing any type that implements `Into<String>` to the client builder's `api_key` method produces a bearer token ([`AzureOpenAIAuth::Token`]).
//! To authenticate with an API key instead, construct [`AzureOpenAIAuth::ApiKey`] explicitly and pass that value.
24
25use std::fmt::Debug;
26
27use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
28use crate::client::{
29    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
30    ProviderClient,
31};
32use crate::completion::GetTokenUsage;
33use crate::http_client::multipart::Part;
34use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
35use crate::streaming::StreamingCompletionResponse;
36use crate::transcription::TranscriptionError;
37use crate::{
38    completion::{self, CompletionError, CompletionRequest},
39    embeddings::{self, EmbeddingError},
40    json_utils,
41    providers::openai,
42    telemetry::SpanCombinator,
43    transcription::{self},
44};
45use bytes::Bytes;
46use serde::{Deserialize, Serialize};
47use serde_json::json;
48// ================================================================
49// Main Azure OpenAI Client
50// ================================================================
51
52const DEFAULT_API_VERSION: &str = "2024-10-21";
53
/// Azure-specific state carried by a built client.
///
/// Both values are interpolated into every request URL produced by this
/// module.
#[derive(Debug, Clone)]
pub struct AzureExt {
    /// Base URL of the Azure OpenAI resource; used as the URL prefix.
    endpoint: String,
    /// Value sent as the `api-version` query parameter.
    api_version: String,
}
59
60impl DebugExt for AzureExt {
61    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
62        [
63            ("endpoint", (&self.endpoint as &dyn Debug)),
64            ("api_version", (&self.api_version as &dyn Debug)),
65        ]
66        .into_iter()
67    }
68}
69
// TODO: @FayCarsons - this should be a type-safe builder,
// but that would require extending the `ProviderBuilder`
// to have some notion of complete vs incomplete states in a
// given extension builder
/// Builder-side counterpart of [`AzureExt`].
///
/// The endpoint starts out unset and must be supplied before the client is
/// built; building without it is reported as an error at build time.
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    /// Azure resource endpoint, `None` until the caller provides one.
    endpoint: Option<String>,
    /// API version to use for requests.
    api_version: String,
}
79
80impl Default for AzureExtBuilder {
81    fn default() -> Self {
82        Self {
83            endpoint: None,
84            api_version: DEFAULT_API_VERSION.into(),
85        }
86    }
87}
88
89pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
90pub type ClientBuilder<H = crate::markers::Missing> =
91    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
92
93impl Provider for AzureExt {
94    type Builder = AzureExtBuilder;
95
96    /// Verifying Azure auth without consuming tokens is not supported
97    const VERIFY_PATH: &'static str = "";
98}
99
100impl<H> Capabilities<H> for AzureExt {
101    type Completion = Capable<CompletionModel<H>>;
102    type Embeddings = Capable<EmbeddingModel<H>>;
103    type Transcription = Capable<TranscriptionModel<H>>;
104    type ModelListing = Nothing;
105    #[cfg(feature = "image")]
106    type ImageGeneration = Nothing;
107    #[cfg(feature = "audio")]
108    type AudioGeneration = Capable<AudioGenerationModel<H>>;
109    type Rerank = Nothing;
110}
111
112impl ProviderBuilder for AzureExtBuilder {
113    type Extension<H>
114        = AzureExt
115    where
116        H: HttpClientExt;
117    type ApiKey = AzureOpenAIAuth;
118
119    const BASE_URL: &'static str = "";
120
121    fn build<H>(
122        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
123    ) -> http_client::Result<Self::Extension<H>>
124    where
125        H: HttpClientExt,
126    {
127        let AzureExtBuilder {
128            endpoint,
129            api_version,
130            ..
131        } = builder.ext().clone();
132
133        match endpoint {
134            Some(endpoint) => Ok(AzureExt {
135                endpoint,
136                api_version,
137            }),
138            None => Err(http_client::Error::Instance(
139                "Azure client must be provided an endpoint prior to building".into(),
140            )),
141        }
142    }
143
144    fn finish<H>(
145        &self,
146        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
147    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
148        use AzureOpenAIAuth::*;
149
150        let auth = builder.get_api_key().clone();
151
152        match auth {
153            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
154            ApiKey(key) => {
155                let k = http::HeaderName::from_static("api-key");
156                let v = http::HeaderValue::from_str(key.as_str())?;
157
158                builder.headers_mut().insert(k, v);
159            }
160        }
161
162        Ok(builder)
163    }
164}
165
166impl<H> ClientBuilder<H> {
167    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
168    pub fn api_version(mut self, api_version: &str) -> Self {
169        self.ext_mut().api_version = api_version.into();
170
171        self
172    }
173}
174
175impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
176    /// Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
177    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
178        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
179            endpoint: Some(endpoint),
180            api_version,
181        })
182    }
183}
184
/// Credential used to authenticate against Azure OpenAI.
///
/// Plain string-like values convert into [`AzureOpenAIAuth::Token`] (bearer
/// auth) by default; build [`AzureOpenAIAuth::ApiKey`] explicitly to
/// authenticate with an API key.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// An API key.
    ApiKey(String),
    /// A bearer token.
    Token(String),
}
192
193impl ApiKey for AzureOpenAIAuth {}
194
195impl std::fmt::Debug for AzureOpenAIAuth {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        match self {
198            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
199            Self::Token(_) => write!(f, "Token <REDACTED>"),
200        }
201    }
202}
203
204impl<S> From<S> for AzureOpenAIAuth
205where
206    S: Into<String>,
207{
208    fn from(token: S) -> Self {
209        AzureOpenAIAuth::Token(token.into())
210    }
211}
212
213impl<T> Client<T>
214where
215    T: HttpClientExt,
216{
217    fn endpoint(&self) -> &str {
218        &self.ext().endpoint
219    }
220
221    fn api_version(&self) -> &str {
222        &self.ext().api_version
223    }
224
225    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
226        let url = format!(
227            "{}/openai/deployments/{}/embeddings?api-version={}",
228            self.endpoint(),
229            deployment_id.trim_start_matches('/'),
230            self.api_version()
231        );
232
233        self.post(&url)
234    }
235
236    #[cfg(feature = "audio")]
237    fn post_audio_generation(
238        &self,
239        deployment_id: &str,
240    ) -> http_client::Result<http_client::Builder> {
241        let url = format!(
242            "{}/openai/deployments/{}/audio/speech?api-version={}",
243            self.endpoint(),
244            deployment_id.trim_start_matches('/'),
245            self.api_version()
246        );
247
248        self.post(url)
249    }
250
251    fn post_chat_completion(
252        &self,
253        deployment_id: &str,
254    ) -> http_client::Result<http_client::Builder> {
255        let url = format!(
256            "{}/openai/deployments/{}/chat/completions?api-version={}",
257            self.endpoint(),
258            deployment_id.trim_start_matches('/'),
259            self.api_version()
260        );
261
262        self.post(&url)
263    }
264
265    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
266        let url = format!(
267            "{}/openai/deployments/{}/audio/translations?api-version={}",
268            self.endpoint(),
269            deployment_id.trim_start_matches('/'),
270            self.api_version()
271        );
272
273        self.post(&url)
274    }
275
276    #[cfg(feature = "image")]
277    fn post_image_generation(
278        &self,
279        deployment_id: &str,
280    ) -> http_client::Result<http_client::Builder> {
281        let url = format!(
282            "{}/openai/deployments/{}/images/generations?api-version={}",
283            self.endpoint(),
284            deployment_id.trim_start_matches('/'),
285            self.api_version()
286        );
287
288        self.post(&url)
289    }
290}
291
/// Explicit inputs for constructing a [`Client`] via `ProviderClient::from_val`.
pub struct AzureOpenAIClientParams {
    /// API key; always used as an `AzureOpenAIAuth::ApiKey` credential.
    api_key: String,
    /// API version passed to the builder.
    version: String,
    /// Despite its name, this value is passed to the builder as the Azure endpoint.
    header: String,
}
297
298impl ProviderClient for Client {
299    type Input = AzureOpenAIClientParams;
300    type Error = crate::client::ProviderClientError;
301
302    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
303    fn from_env() -> Result<Self, Self::Error> {
304        let auth = if let Some(api_key) = crate::client::optional_env_var("AZURE_API_KEY")? {
305            AzureOpenAIAuth::ApiKey(api_key)
306        } else if let Some(token) = crate::client::optional_env_var("AZURE_TOKEN")? {
307            AzureOpenAIAuth::Token(token)
308        } else {
309            return Err(crate::client::ProviderClientError::InvalidConfiguration(
310                "either `AZURE_API_KEY` or `AZURE_TOKEN` must be set",
311            ));
312        };
313
314        let api_version = crate::client::required_env_var("AZURE_API_VERSION")?;
315        let azure_endpoint = crate::client::required_env_var("AZURE_ENDPOINT")?;
316
317        Self::builder()
318            .api_key(auth)
319            .azure_endpoint(azure_endpoint)
320            .api_version(&api_version)
321            .build()
322            .map_err(Into::into)
323    }
324
325    fn from_val(
326        AzureOpenAIClientParams {
327            api_key,
328            version,
329            header,
330        }: Self::Input,
331    ) -> Result<Self, Self::Error> {
332        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
333
334        Self::builder()
335            .api_key(auth)
336            .azure_endpoint(header)
337            .api_version(&version)
338            .build()
339            .map_err(Into::into)
340    }
341}
342
343#[derive(Debug, Deserialize)]
344struct ApiErrorResponse {
345    message: String,
346}
347
348#[derive(Debug, Deserialize)]
349#[serde(untagged)]
350enum ApiResponse<T> {
351    Ok(T),
352    Err(ApiErrorResponse),
353}
354
355// ================================================================
356// Azure OpenAI Embedding API
357// ================================================================
358
/// Identifier of the `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Identifier of the `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Identifier of the `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

/// Looks up the embedding dimensionality for a known model identifier.
///
/// Returns `None` for any identifier that is not one of the constants above;
/// the comparison is exact and case-sensitive.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    const KNOWN_DIMENSIONS: [(&str, usize); 3] = [
        (TEXT_EMBEDDING_3_LARGE, 3_072),
        (TEXT_EMBEDDING_3_SMALL, 1_536),
        (TEXT_EMBEDDING_ADA_002, 1_536),
    ];

    KNOWN_DIMENSIONS
        .iter()
        .find(|(name, _)| *name == identifier)
        .map(|&(_, dimensions)| dimensions)
}
373
374#[derive(Debug, Deserialize)]
375pub struct EmbeddingResponse {
376    pub object: String,
377    pub data: Vec<EmbeddingData>,
378    pub model: String,
379    pub usage: Usage,
380}
381
382impl From<ApiErrorResponse> for EmbeddingError {
383    fn from(err: ApiErrorResponse) -> Self {
384        EmbeddingError::ProviderError(err.message)
385    }
386}
387
388impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
389    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
390        match value {
391            ApiResponse::Ok(response) => Ok(response),
392            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
393        }
394    }
395}
396
397#[derive(Debug, Deserialize)]
398pub struct EmbeddingData {
399    pub object: String,
400    pub embedding: Vec<f64>,
401    pub index: usize,
402}
403
404#[derive(Clone, Debug, Deserialize)]
405pub struct Usage {
406    pub prompt_tokens: usize,
407    pub total_tokens: usize,
408}
409
410impl GetTokenUsage for Usage {
411    fn token_usage(&self) -> crate::completion::Usage {
412        let mut usage = crate::completion::Usage::new();
413
414        usage.input_tokens = self.prompt_tokens as u64;
415        usage.total_tokens = self.total_tokens as u64;
416        usage.output_tokens = usage.total_tokens - usage.input_tokens;
417
418        usage
419    }
420}
421
422impl std::fmt::Display for Usage {
423    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424        write!(
425            f,
426            "Prompt tokens: {} Total tokens: {}",
427            self.prompt_tokens, self.total_tokens
428        )
429    }
430}
431
432#[derive(Clone)]
433pub struct EmbeddingModel<T = reqwest::Client> {
434    client: Client<T>,
435    pub model: String,
436    ndims: usize,
437}
438
439impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
440where
441    T: HttpClientExt + Default + Clone + 'static,
442{
443    const MAX_DOCUMENTS: usize = 1024;
444
445    type Client = Client<T>;
446
447    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
448        Self::new(client.clone(), model, dims)
449    }
450
451    fn ndims(&self) -> usize {
452        self.ndims
453    }
454
455    async fn embed_texts(
456        &self,
457        documents: impl IntoIterator<Item = String>,
458    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
459        let documents = documents.into_iter().collect::<Vec<_>>();
460
461        let mut body = json!({
462            "input": documents,
463        });
464
465        let body_object = body.as_object_mut().ok_or_else(|| {
466            EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
467        })?;
468
469        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
470            body_object.insert("dimensions".to_owned(), json!(self.ndims));
471        }
472
473        let body = serde_json::to_vec(&body)?;
474
475        let req = self
476            .client
477            .post_embedding(self.model.as_str())?
478            .body(body)
479            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
480
481        let response = self.client.send(req).await?;
482
483        if response.status().is_success() {
484            let body: Vec<u8> = response.into_body().await?;
485            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
486
487            match body {
488                ApiResponse::Ok(response) => {
489                    tracing::info!(target: "rig",
490                        "Azure embedding token usage: {}",
491                        response.usage
492                    );
493
494                    if response.data.len() != documents.len() {
495                        return Err(EmbeddingError::ResponseError(
496                            "Response data length does not match input length".into(),
497                        ));
498                    }
499
500                    Ok(response
501                        .data
502                        .into_iter()
503                        .zip(documents.into_iter())
504                        .map(|(embedding, document)| embeddings::Embedding {
505                            document,
506                            vec: embedding.embedding,
507                        })
508                        .collect())
509                }
510                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
511            }
512        } else {
513            let text = http_client::text(response).await?;
514            Err(EmbeddingError::ProviderError(text))
515        }
516    }
517}
518
519impl<T> EmbeddingModel<T> {
520    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
521        let model = model.into();
522        let ndims = ndims
523            .or(model_dimensions_from_identifier(&model))
524            .unwrap_or_default();
525
526        Self {
527            client,
528            model,
529            ndims,
530        }
531    }
532
533    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
534        let ndims = ndims.unwrap_or_default();
535
536        Self {
537            client,
538            model: model.into(),
539            ndims,
540        }
541    }
542}
543
544// ================================================================
545// Azure OpenAI Completion API
546// ================================================================
547
/// Identifier `o1`
pub const O1: &str = "o1";
/// Identifier `o1-preview`
pub const O1_PREVIEW: &str = "o1-preview";
/// Identifier `o1-mini`
pub const O1_MINI: &str = "o1-mini";
/// Identifier `gpt-4o`
pub const GPT_4O: &str = "gpt-4o";
/// Identifier `gpt-4o-mini`
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// Identifier `gpt-4o-realtime-preview`
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// GPT-4 Turbo; note the identifier string is `gpt-4`, the same value as [`GPT_4`].
// NOTE(review): confirm the shared value with `GPT_4` is intended.
pub const GPT_4_TURBO: &str = "gpt-4";
/// Identifier `gpt-4`
pub const GPT_4: &str = "gpt-4";
/// Identifier `gpt-4-32k`
pub const GPT_4_32K: &str = "gpt-4-32k";
/// GPT-4 32k (0613); note the identifier string is `gpt-4-32k`, the same value as [`GPT_4_32K`].
// NOTE(review): confirm the shared value with `GPT_4_32K` is intended.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// Identifier `gpt-3.5-turbo`
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Identifier `gpt-3.5-turbo-instruct`
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// Identifier `gpt-3.5-turbo-16k`
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
574
575#[derive(Debug, Serialize, Deserialize)]
576pub(super) struct AzureOpenAICompletionRequest {
577    model: String,
578    pub messages: Vec<openai::Message>,
579    #[serde(skip_serializing_if = "Option::is_none")]
580    temperature: Option<f64>,
581    #[serde(skip_serializing_if = "Vec::is_empty")]
582    tools: Vec<openai::ToolDefinition>,
583    #[serde(skip_serializing_if = "Option::is_none")]
584    tool_choice: Option<crate::providers::openai::ToolChoice>,
585    #[serde(flatten, skip_serializing_if = "Option::is_none")]
586    pub additional_params: Option<serde_json::Value>,
587}
588
589impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
590    type Error = CompletionError;
591
592    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
593        let chat_history = req.chat_history_with_documents();
594        let model = req.model.clone().unwrap_or_else(|| model.to_string());
595        if req.tool_choice.is_some() {
596            tracing::warn!("Tool choice is currently not supported in Azure OpenAI.");
597        }
598
599        let mut full_history: Vec<openai::Message> = match &req.preamble {
600            Some(preamble) => vec![openai::Message::system(preamble)],
601            None => vec![],
602        };
603
604        let chat_history: Vec<openai::Message> = chat_history
605            .into_iter()
606            .map(|message| message.try_into())
607            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
608            .into_iter()
609            .flatten()
610            .collect();
611
612        full_history.extend(chat_history);
613
614        let tool_choice = req
615            .tool_choice
616            .clone()
617            .map(crate::providers::openai::ToolChoice::try_from)
618            .transpose()?;
619
620        let additional_params = if let Some(schema) = req.output_schema {
621            let name = schema
622                .as_object()
623                .and_then(|o| o.get("title"))
624                .and_then(|v| v.as_str())
625                .unwrap_or("response_schema")
626                .to_string();
627            let mut schema_value = schema.to_value();
628            openai::sanitize_schema(&mut schema_value);
629            let response_format = serde_json::json!({
630                "response_format": {
631                    "type": "json_schema",
632                    "json_schema": {
633                        "name": name,
634                        "strict": true,
635                        "schema": schema_value
636                    }
637                }
638            });
639            Some(match req.additional_params {
640                Some(existing) => json_utils::merge(existing, response_format),
641                None => response_format,
642            })
643        } else {
644            req.additional_params
645        };
646
647        Ok(Self {
648            model: model.to_string(),
649            messages: full_history,
650            temperature: req.temperature,
651            tools: req
652                .tools
653                .clone()
654                .into_iter()
655                .map(openai::ToolDefinition::from)
656                .collect::<Vec<_>>(),
657            tool_choice,
658            additional_params,
659        })
660    }
661}
662
663#[derive(Clone)]
664pub struct CompletionModel<T = reqwest::Client> {
665    client: Client<T>,
666    /// Name of the model (e.g.: gpt-4o-mini)
667    pub model: String,
668}
669
670impl<T> CompletionModel<T> {
671    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
672        Self {
673            client,
674            model: model.into(),
675        }
676    }
677}
678
679impl<T> completion::CompletionModel for CompletionModel<T>
680where
681    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
682{
683    type Response = openai::CompletionResponse;
684    type StreamingResponse = openai::StreamingCompletionResponse;
685    type Client = Client<T>;
686
687    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
688        Self::new(client.clone(), model.into())
689    }
690
691    async fn completion(
692        &self,
693        completion_request: CompletionRequest,
694    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
695        let span = if tracing::Span::current().is_disabled() {
696            info_span!(
697                target: "rig::completions",
698                "chat",
699                gen_ai.operation.name = "chat",
700                gen_ai.provider.name = "azure.openai",
701                gen_ai.request.model = self.model,
702                gen_ai.system_instructions = &completion_request.preamble,
703                gen_ai.response.id = tracing::field::Empty,
704                gen_ai.response.model = tracing::field::Empty,
705                gen_ai.usage.output_tokens = tracing::field::Empty,
706                gen_ai.usage.input_tokens = tracing::field::Empty,
707                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
708            )
709        } else {
710            tracing::Span::current()
711        };
712
713        let request =
714            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
715
716        if enabled!(Level::TRACE) {
717            tracing::trace!(target: "rig::completions",
718                "Azure OpenAI completion request: {}",
719                serde_json::to_string_pretty(&request)?
720            );
721        }
722
723        let body = serde_json::to_vec(&request)?;
724
725        let req = self
726            .client
727            .post_chat_completion(&self.model)?
728            .body(body)
729            .map_err(http_client::Error::from)?;
730
731        async move {
732            let response = self.client.send::<_, Bytes>(req).await?;
733
734            let status = response.status();
735            let response_body = response.into_body().into_future().await?.to_vec();
736
737            if status.is_success() {
738                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
739                    &response_body,
740                )? {
741                    ApiResponse::Ok(response) => {
742                        let span = tracing::Span::current();
743                        span.record_response_metadata(&response);
744                        span.record_token_usage(&response.usage);
745                        if enabled!(Level::TRACE) {
746                            tracing::trace!(target: "rig::completions",
747                                "Azure OpenAI completion response: {}",
748                                serde_json::to_string_pretty(&response)?
749                            );
750                        }
751                        response.try_into()
752                    }
753                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
754                }
755            } else {
756                Err(CompletionError::ProviderError(
757                    String::from_utf8_lossy(&response_body).to_string(),
758                ))
759            }
760        }
761        .instrument(span)
762        .await
763    }
764
765    async fn stream(
766        &self,
767        completion_request: CompletionRequest,
768    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
769        let preamble = completion_request.preamble.clone();
770        let mut request =
771            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
772
773        let params = json_utils::merge(
774            request.additional_params.unwrap_or(serde_json::json!({})),
775            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
776        );
777
778        request.additional_params = Some(params);
779
780        if enabled!(Level::TRACE) {
781            tracing::trace!(target: "rig::completions",
782                "Azure OpenAI completion request: {}",
783                serde_json::to_string_pretty(&request)?
784            );
785        }
786
787        let body = serde_json::to_vec(&request)?;
788
789        let req = self
790            .client
791            .post_chat_completion(&self.model)?
792            .body(body)
793            .map_err(http_client::Error::from)?;
794
795        let span = if tracing::Span::current().is_disabled() {
796            info_span!(
797                target: "rig::completions",
798                "chat_streaming",
799                gen_ai.operation.name = "chat_streaming",
800                gen_ai.provider.name = "azure.openai",
801                gen_ai.request.model = self.model,
802                gen_ai.system_instructions = &preamble,
803                gen_ai.response.id = tracing::field::Empty,
804                gen_ai.response.model = tracing::field::Empty,
805                gen_ai.usage.output_tokens = tracing::field::Empty,
806                gen_ai.usage.input_tokens = tracing::field::Empty,
807                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
808            )
809        } else {
810            tracing::Span::current()
811        };
812
813        tracing_futures::Instrument::instrument(
814            send_compatible_streaming_request(self.client.clone(), req),
815            span,
816        )
817        .await
818    }
819}
820
821// ================================================================
822// Azure OpenAI Transcription API
823// ================================================================
824
825#[derive(Clone)]
826pub struct TranscriptionModel<T = reqwest::Client> {
827    client: Client<T>,
828    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
829    pub model: String,
830}
831
832impl<T> TranscriptionModel<T> {
833    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
834        Self {
835            client,
836            model: model.into(),
837        }
838    }
839}
840
841impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
842where
843    T: HttpClientExt + Clone + 'static,
844{
845    type Response = TranscriptionResponse;
846    type Client = Client<T>;
847
848    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
849        Self::new(client.clone(), model)
850    }
851
852    async fn transcription(
853        &self,
854        request: transcription::TranscriptionRequest,
855    ) -> Result<
856        transcription::TranscriptionResponse<Self::Response>,
857        transcription::TranscriptionError,
858    > {
859        let data = request.data;
860
861        let mut body =
862            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));
863
864        if let Some(prompt) = request.prompt {
865            body = body.text("prompt", prompt.clone());
866        }
867
868        if let Some(ref temperature) = request.temperature {
869            body = body.text("temperature", temperature.to_string());
870        }
871
872        if let Some(ref additional_params) = request.additional_params {
873            let params = additional_params.as_object().ok_or_else(|| {
874                TranscriptionError::RequestError(Box::new(std::io::Error::new(
875                    std::io::ErrorKind::InvalidInput,
876                    "additional transcription parameters must be a JSON object",
877                )))
878            })?;
879
880            for (key, value) in params {
881                body = body.text(key.to_owned(), value.to_string());
882            }
883        }
884
885        let req = self
886            .client
887            .post_transcription(&self.model)?
888            .body(body)
889            .map_err(|e| TranscriptionError::HttpError(e.into()))?;
890
891        let response = self.client.send_multipart::<Bytes>(req).await?;
892        let status = response.status();
893        let response_body = response.into_body().into_future().await?.to_vec();
894
895        if status.is_success() {
896            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
897                ApiResponse::Ok(response) => response.try_into(),
898                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
899                    api_error_response.message,
900                )),
901            }
902        } else {
903            Err(TranscriptionError::ProviderError(
904                String::from_utf8_lossy(&response_body).to_string(),
905            ))
906        }
907    }
908}
909
910// ================================================================
911// Azure OpenAI Image Generation API
912// ================================================================
913#[cfg(feature = "image")]
914pub use image_generation::*;
915use tracing::{Instrument, Level, enabled, info_span};
916#[cfg(feature = "image")]
917#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
918mod image_generation {
919    use crate::http_client::HttpClientExt;
920    use crate::image_generation;
921    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
922    use crate::providers::azure::{ApiResponse, Client};
923    use crate::providers::openai::ImageGenerationResponse;
924    use bytes::Bytes;
925    use serde_json::json;
926
927    #[derive(Clone)]
928    pub struct ImageGenerationModel<T = reqwest::Client> {
929        client: Client<T>,
930        pub model: String,
931    }
932
933    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
934    where
935        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
936    {
937        type Response = ImageGenerationResponse;
938
939        type Client = Client<T>;
940
941        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
942            Self {
943                client: client.clone(),
944                model: model.into(),
945            }
946        }
947
948        async fn image_generation(
949            &self,
950            generation_request: ImageGenerationRequest,
951        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
952        {
953            let request = json!({
954                "model": self.model,
955                "prompt": generation_request.prompt,
956                "size": format!("{}x{}", generation_request.width, generation_request.height),
957                "response_format": "b64_json"
958            });
959
960            let body = serde_json::to_vec(&request)?;
961
962            let req = self
963                .client
964                .post_image_generation(&self.model)?
965                .body(body)
966                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
967
968            let response = self.client.send::<_, Bytes>(req).await?;
969            let status = response.status();
970            let response_body = response.into_body().into_future().await?.to_vec();
971
972            if !status.is_success() {
973                return Err(ImageGenerationError::ProviderError(format!(
974                    "{status}: {}",
975                    String::from_utf8_lossy(&response_body)
976                )));
977            }
978
979            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
980                ApiResponse::Ok(response) => response.try_into(),
981                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
982            }
983        }
984    }
985}
986// ================================================================
987// Azure OpenAI Audio Generation API
988// ================================================================
989
990#[cfg(feature = "audio")]
991pub use audio_generation::*;
992
993#[cfg(feature = "audio")]
994#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
995mod audio_generation {
996    use super::Client;
997    use crate::audio_generation::{
998        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
999    };
1000    use crate::http_client::HttpClientExt;
1001    use bytes::Bytes;
1002    use serde_json::json;
1003
1004    #[derive(Clone)]
1005    pub struct AudioGenerationModel<T = reqwest::Client> {
1006        client: Client<T>,
1007        model: String,
1008    }
1009
1010    impl<T> AudioGenerationModel<T> {
1011        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
1012            Self {
1013                client,
1014                model: deployment_name.into(),
1015            }
1016        }
1017    }
1018
1019    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
1020    where
1021        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
1022    {
1023        type Response = Bytes;
1024        type Client = Client<T>;
1025
1026        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1027            Self::new(client.clone(), model)
1028        }
1029
1030        async fn audio_generation(
1031            &self,
1032            request: AudioGenerationRequest,
1033        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
1034            let request = json!({
1035                "model": self.model,
1036                "input": request.text,
1037                "voice": request.voice,
1038                "speed": request.speed,
1039            });
1040
1041            let body = serde_json::to_vec(&request)?;
1042
1043            let req = self
1044                .client
1045                .post_audio_generation("/audio/speech")?
1046                .header("Content-Type", "application/json")
1047                .body(body)
1048                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;
1049
1050            let response = self.client.send::<_, Bytes>(req).await?;
1051            let status = response.status();
1052            let response_body = response.into_body().into_future().await?;
1053
1054            if !status.is_success() {
1055                return Err(AudioGenerationError::ProviderError(format!(
1056                    "{status}: {}",
1057                    String::from_utf8_lossy(&response_body)
1058                )));
1059            }
1060
1061            Ok(AudioGenerationResponse {
1062                audio: response_body.to_vec(),
1063                response: response_body,
1064            })
1065        }
1066    }
1067}
1068
1069#[cfg(test)]
1070mod azure_tests {
1071    use schemars::JsonSchema;
1072
1073    use super::*;
1074
1075    use crate::OneOrMany;
1076    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
1077    use crate::completion::CompletionModel;
1078    use crate::embeddings::EmbeddingModel;
1079    use crate::prelude::TypedPrompt;
1080    use crate::providers::openai::GPT_5_MINI;
1081
1082    #[tokio::test]
1083    #[ignore]
1084    async fn test_azure_embedding() -> anyhow::Result<()> {
1085        let _ = tracing_subscriber::fmt::try_init();
1086
1087        let client = Client::from_env()?;
1088        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
1089        let embeddings = model.embed_texts(vec!["Hello, world!".to_string()]).await?;
1090
1091        tracing::info!("Azure embedding: {:?}", embeddings);
1092        Ok(())
1093    }
1094
1095    #[tokio::test]
1096    #[ignore]
1097    async fn test_azure_embedding_dimensions() -> anyhow::Result<()> {
1098        let _ = tracing_subscriber::fmt::try_init();
1099
1100        let ndims = 256;
1101        let client = Client::from_env()?;
1102        let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims);
1103        let embedding = model.embed_text("Hello, world!").await?;
1104
1105        anyhow::ensure!(
1106            embedding.vec.len() == ndims,
1107            "expected embedding dimensions {ndims}, got {}",
1108            embedding.vec.len()
1109        );
1110
1111        tracing::info!("Azure dimensions embedding: {:?}", embedding);
1112        Ok(())
1113    }
1114
1115    #[tokio::test]
1116    #[ignore]
1117    async fn test_azure_completion() -> anyhow::Result<()> {
1118        let _ = tracing_subscriber::fmt::try_init();
1119
1120        let client = Client::from_env()?;
1121        let model = client.completion_model(GPT_4O_MINI);
1122        let completion = model
1123            .completion(CompletionRequest {
1124                model: None,
1125                preamble: Some("You are a helpful assistant.".to_string()),
1126                chat_history: OneOrMany::one("Hello!".into()),
1127                documents: vec![],
1128                max_tokens: Some(100),
1129                temperature: Some(0.0),
1130                tools: vec![],
1131                tool_choice: None,
1132                additional_params: None,
1133                output_schema: None,
1134            })
1135            .await?;
1136
1137        tracing::info!("Azure completion: {:?}", completion);
1138        Ok(())
1139    }
1140
1141    #[tokio::test]
1142    #[ignore]
1143    async fn test_azure_structured_output() -> anyhow::Result<()> {
1144        let _ = tracing_subscriber::fmt::try_init();
1145
1146        #[derive(Debug, Deserialize, JsonSchema)]
1147        struct Person {
1148            name: String,
1149            age: u32,
1150        }
1151
1152        let client = Client::from_env()?;
1153        let agent = client
1154            .agent(GPT_5_MINI)
1155            .preamble("You are a helpful assistant that extracts personal details.")
1156            .max_tokens(100)
1157            .output_schema::<Person>()
1158            .build();
1159
1160        let result: Person = agent
1161            .prompt_typed("Hello! My name is John Doe and I'm 54 years old.")
1162            .await?;
1163
1164        anyhow::ensure!(
1165            result.name == "John Doe",
1166            "expected name John Doe, got {}",
1167            result.name
1168        );
1169        anyhow::ensure!(result.age == 54, "expected age 54, got {}", result.age);
1170
1171        tracing::info!("Extracted person: {:?}", result);
1172        Ok(())
1173    }
1174
1175    #[tokio::test]
1176    async fn test_client_initialization() {
1177        let _client = crate::providers::azure::Client::builder()
1178            .api_key("test")
1179            .azure_endpoint("test".to_string()) // add your endpoint here!
1180            .build()
1181            .expect("Client::builder() failed");
1182    }
1183}