Skip to main content

rig/providers/
azure.rs

//! Azure OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::azure;
//! use rig::client::CompletionClient;
//!
//! let client = azure::Client::builder()
//!     .api_key("test") // a plain string becomes a bearer token (see "Authentication")
//!     .azure_endpoint("test".to_string()) // add your endpoint here!
//!     .build()?;
//!
//! let gpt4o = client.completion_model(azure::GPT_4O);
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
//!
//! ## Authentication
//! The authentication type used for the `azure` module is [`AzureOpenAIAuth`].
//!
//! By default, passing any type that implements `Into<String>` to the client builder's `api_key`
//! turns it into a bearer auth token ([`AzureOpenAIAuth::Token`]).
//! To authenticate with an API key (sent in the `api-key` header) instead, pass
//! [`AzureOpenAIAuth::ApiKey`] explicitly.
21
22use std::fmt::Debug;
23
24use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
25use crate::client::{
26    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
27    ProviderClient,
28};
29use crate::completion::GetTokenUsage;
30use crate::http_client::multipart::Part;
31use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
32use crate::streaming::StreamingCompletionResponse;
33use crate::transcription::TranscriptionError;
34use crate::{
35    completion::{self, CompletionError, CompletionRequest},
36    embeddings::{self, EmbeddingError},
37    json_utils,
38    providers::openai,
39    telemetry::SpanCombinator,
40    transcription::{self},
41};
42use bytes::Bytes;
43use serde::{Deserialize, Serialize};
44use serde_json::json;
45// ================================================================
46// Main Azure OpenAI Client
47// ================================================================
48
/// Azure OpenAI REST API version used when the builder is not given one explicitly.
const DEFAULT_API_VERSION: &str = "2024-10-21";

/// Resolved Azure-specific client settings, produced from `AzureExtBuilder` at build time.
#[derive(Clone, Debug)]
pub struct AzureExt {
    /// Resource endpoint, e.g. `https://{your-resource-name}.openai.azure.com`.
    endpoint: String,
    /// Value sent as the `api-version` query parameter of every request.
    api_version: String,
}
56
57impl DebugExt for AzureExt {
58    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
59        [
60            ("endpoint", (&self.endpoint as &dyn Debug)),
61            ("api_version", (&self.api_version as &dyn Debug)),
62        ]
63        .into_iter()
64    }
65}
66
67// TODO: @FayCarsons - this should be a type-safe builder,
68// but that would require extending the `ProviderBuilder`
69// to have some notion of complete vs incomplete states in a
70// given extension builder
71#[derive(Debug, Clone)]
72pub struct AzureExtBuilder {
73    endpoint: Option<String>,
74    api_version: String,
75}
76
77impl Default for AzureExtBuilder {
78    fn default() -> Self {
79        Self {
80            endpoint: None,
81            api_version: DEFAULT_API_VERSION.into(),
82        }
83    }
84}
85
86pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
87pub type ClientBuilder<H = reqwest::Client> =
88    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
89
90impl Provider for AzureExt {
91    type Builder = AzureExtBuilder;
92
93    /// Verifying Azure auth without consuming tokens is not supported
94    const VERIFY_PATH: &'static str = "";
95}
96
97impl<H> Capabilities<H> for AzureExt {
98    type Completion = Capable<CompletionModel<H>>;
99    type Embeddings = Capable<EmbeddingModel<H>>;
100    type Transcription = Capable<TranscriptionModel<H>>;
101    type ModelListing = Nothing;
102    #[cfg(feature = "image")]
103    type ImageGeneration = Nothing;
104    #[cfg(feature = "audio")]
105    type AudioGeneration = Capable<AudioGenerationModel<H>>;
106}
107
108impl ProviderBuilder for AzureExtBuilder {
109    type Extension<H>
110        = AzureExt
111    where
112        H: HttpClientExt;
113    type ApiKey = AzureOpenAIAuth;
114
115    const BASE_URL: &'static str = "";
116
117    fn build<H>(
118        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
119    ) -> http_client::Result<Self::Extension<H>>
120    where
121        H: HttpClientExt,
122    {
123        let AzureExtBuilder {
124            endpoint,
125            api_version,
126            ..
127        } = builder.ext().clone();
128
129        match endpoint {
130            Some(endpoint) => Ok(AzureExt {
131                endpoint,
132                api_version,
133            }),
134            None => Err(http_client::Error::Instance(
135                "Azure client must be provided an endpoint prior to building".into(),
136            )),
137        }
138    }
139
140    fn finish<H>(
141        &self,
142        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
143    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
144        use AzureOpenAIAuth::*;
145
146        let auth = builder.get_api_key().clone();
147
148        match auth {
149            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
150            ApiKey(key) => {
151                let k = http::HeaderName::from_static("api-key");
152                let v = http::HeaderValue::from_str(key.as_str())?;
153
154                builder.headers_mut().insert(k, v);
155            }
156        }
157
158        Ok(builder)
159    }
160}
161
162impl<H> ClientBuilder<H> {
163    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
164    pub fn api_version(mut self, api_version: &str) -> Self {
165        self.ext_mut().api_version = api_version.into();
166
167        self
168    }
169}
170
171impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
172    /// Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
173    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
174        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
175            endpoint: Some(endpoint),
176            api_version,
177        })
178    }
179}
180
/// Credentials for Azure OpenAI: either an API key or a bearer token.
///
/// Anything convertible into a `String` becomes [`AzureOpenAIAuth::Token`] through the `From`
/// conversion; construct [`AzureOpenAIAuth::ApiKey`] explicitly to authenticate with a key.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Sent in the `api-key` request header.
    ApiKey(String),
    /// Sent via `bearer_auth_header` (bearer-token authentication).
    Token(String),
}
188
189impl ApiKey for AzureOpenAIAuth {}
190
191impl std::fmt::Debug for AzureOpenAIAuth {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        match self {
194            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
195            Self::Token(_) => write!(f, "Token <REDACTED>"),
196        }
197    }
198}
199
200impl<S> From<S> for AzureOpenAIAuth
201where
202    S: Into<String>,
203{
204    fn from(token: S) -> Self {
205        AzureOpenAIAuth::Token(token.into())
206    }
207}
208
209impl<T> Client<T>
210where
211    T: HttpClientExt,
212{
213    fn endpoint(&self) -> &str {
214        &self.ext().endpoint
215    }
216
217    fn api_version(&self) -> &str {
218        &self.ext().api_version
219    }
220
221    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
222        let url = format!(
223            "{}/openai/deployments/{}/embeddings?api-version={}",
224            self.endpoint(),
225            deployment_id.trim_start_matches('/'),
226            self.api_version()
227        );
228
229        self.post(&url)
230    }
231
232    #[cfg(feature = "audio")]
233    fn post_audio_generation(
234        &self,
235        deployment_id: &str,
236    ) -> http_client::Result<http_client::Builder> {
237        let url = format!(
238            "{}/openai/deployments/{}/audio/speech?api-version={}",
239            self.endpoint(),
240            deployment_id.trim_start_matches('/'),
241            self.api_version()
242        );
243
244        self.post(url)
245    }
246
247    fn post_chat_completion(
248        &self,
249        deployment_id: &str,
250    ) -> http_client::Result<http_client::Builder> {
251        let url = format!(
252            "{}/openai/deployments/{}/chat/completions?api-version={}",
253            self.endpoint(),
254            deployment_id.trim_start_matches('/'),
255            self.api_version()
256        );
257
258        self.post(&url)
259    }
260
261    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
262        let url = format!(
263            "{}/openai/deployments/{}/audio/translations?api-version={}",
264            self.endpoint(),
265            deployment_id.trim_start_matches('/'),
266            self.api_version()
267        );
268
269        self.post(&url)
270    }
271
272    #[cfg(feature = "image")]
273    fn post_image_generation(
274        &self,
275        deployment_id: &str,
276    ) -> http_client::Result<http_client::Builder> {
277        let url = format!(
278            "{}/openai/deployments/{}/images/generations?api-version={}",
279            self.endpoint(),
280            deployment_id.trim_start_matches('/'),
281            self.api_version()
282        );
283
284        self.post(&url)
285    }
286}
287
/// Explicit configuration accepted by `ProviderClient::from_val` for the Azure [`Client`].
///
/// The fields are public so the struct can be constructed outside this module (it has no
/// constructor); `from_val` authenticates with `api_key` as an Azure API key, not a bearer token.
pub struct AzureOpenAIClientParams {
    /// Azure OpenAI API key.
    pub api_key: String,
    /// API version, e.g. `"2024-10-21"`.
    pub version: String,
    /// Azure OpenAI endpoint URL, e.g. `https://{your-resource-name}.openai.azure.com`.
    ///
    /// Despite its name this is the endpoint: `from_val` passes it to `azure_endpoint`.
    pub header: String,
}
293
294impl ProviderClient for Client {
295    type Input = AzureOpenAIClientParams;
296
297    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
298    fn from_env() -> Self {
299        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
300            AzureOpenAIAuth::ApiKey(api_key)
301        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
302            AzureOpenAIAuth::Token(token)
303        } else {
304            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
305        };
306
307        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
308        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
309
310        Self::builder()
311            .api_key(auth)
312            .azure_endpoint(azure_endpoint)
313            .api_version(&api_version)
314            .build()
315            .unwrap()
316    }
317
318    fn from_val(
319        AzureOpenAIClientParams {
320            api_key,
321            version,
322            header,
323        }: Self::Input,
324    ) -> Self {
325        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
326
327        Self::builder()
328            .api_key(auth)
329            .azure_endpoint(header)
330            .api_version(&version)
331            .build()
332            .unwrap()
333    }
334}
335
336#[derive(Debug, Deserialize)]
337struct ApiErrorResponse {
338    message: String,
339}
340
341#[derive(Debug, Deserialize)]
342#[serde(untagged)]
343enum ApiResponse<T> {
344    Ok(T),
345    Err(ApiErrorResponse),
346}
347
348// ================================================================
349// Azure OpenAI Embedding API
350// ================================================================
351
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

/// Default embedding dimensionality for the well-known model names above, or `None` for any
/// other identifier.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    const KNOWN_DIMENSIONS: [(&str, usize); 3] = [
        (TEXT_EMBEDDING_3_LARGE, 3_072),
        (TEXT_EMBEDDING_3_SMALL, 1_536),
        (TEXT_EMBEDDING_ADA_002, 1_536),
    ];

    KNOWN_DIMENSIONS
        .iter()
        .find(|(name, _)| *name == identifier)
        .map(|&(_, dims)| dims)
}
366
367#[derive(Debug, Deserialize)]
368pub struct EmbeddingResponse {
369    pub object: String,
370    pub data: Vec<EmbeddingData>,
371    pub model: String,
372    pub usage: Usage,
373}
374
375impl From<ApiErrorResponse> for EmbeddingError {
376    fn from(err: ApiErrorResponse) -> Self {
377        EmbeddingError::ProviderError(err.message)
378    }
379}
380
381impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
382    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
383        match value {
384            ApiResponse::Ok(response) => Ok(response),
385            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
386        }
387    }
388}
389
390#[derive(Debug, Deserialize)]
391pub struct EmbeddingData {
392    pub object: String,
393    pub embedding: Vec<f64>,
394    pub index: usize,
395}
396
397#[derive(Clone, Debug, Deserialize)]
398pub struct Usage {
399    pub prompt_tokens: usize,
400    pub total_tokens: usize,
401}
402
403impl GetTokenUsage for Usage {
404    fn token_usage(&self) -> Option<crate::completion::Usage> {
405        let mut usage = crate::completion::Usage::new();
406
407        usage.input_tokens = self.prompt_tokens as u64;
408        usage.total_tokens = self.total_tokens as u64;
409        usage.output_tokens = usage.total_tokens - usage.input_tokens;
410
411        Some(usage)
412    }
413}
414
415impl std::fmt::Display for Usage {
416    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        write!(
418            f,
419            "Prompt tokens: {} Total tokens: {}",
420            self.prompt_tokens, self.total_tokens
421        )
422    }
423}
424
425#[derive(Clone)]
426pub struct EmbeddingModel<T = reqwest::Client> {
427    client: Client<T>,
428    pub model: String,
429    ndims: usize,
430}
431
432impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
433where
434    T: HttpClientExt + Default + Clone + 'static,
435{
436    const MAX_DOCUMENTS: usize = 1024;
437
438    type Client = Client<T>;
439
440    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
441        Self::new(client.clone(), model, dims)
442    }
443
444    fn ndims(&self) -> usize {
445        self.ndims
446    }
447
448    async fn embed_texts(
449        &self,
450        documents: impl IntoIterator<Item = String>,
451    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
452        let documents = documents.into_iter().collect::<Vec<_>>();
453
454        let mut body = json!({
455            "input": documents,
456        });
457
458        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
459            body["dimensions"] = json!(self.ndims);
460        }
461
462        let body = serde_json::to_vec(&body)?;
463
464        let req = self
465            .client
466            .post_embedding(self.model.as_str())?
467            .body(body)
468            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
469
470        let response = self.client.send(req).await?;
471
472        if response.status().is_success() {
473            let body: Vec<u8> = response.into_body().await?;
474            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
475
476            match body {
477                ApiResponse::Ok(response) => {
478                    tracing::info!(target: "rig",
479                        "Azure embedding token usage: {}",
480                        response.usage
481                    );
482
483                    if response.data.len() != documents.len() {
484                        return Err(EmbeddingError::ResponseError(
485                            "Response data length does not match input length".into(),
486                        ));
487                    }
488
489                    Ok(response
490                        .data
491                        .into_iter()
492                        .zip(documents.into_iter())
493                        .map(|(embedding, document)| embeddings::Embedding {
494                            document,
495                            vec: embedding.embedding,
496                        })
497                        .collect())
498                }
499                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
500            }
501        } else {
502            let text = http_client::text(response).await?;
503            Err(EmbeddingError::ProviderError(text))
504        }
505    }
506}
507
508impl<T> EmbeddingModel<T> {
509    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
510        let model = model.into();
511        let ndims = ndims
512            .or(model_dimensions_from_identifier(&model))
513            .unwrap_or_default();
514
515        Self {
516            client,
517            model,
518            ndims,
519        }
520    }
521
522    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
523        let ndims = ndims.unwrap_or_default();
524
525        Self {
526            client,
527            model: model.into(),
528            ndims,
529        }
530    }
531}
532
533// ================================================================
534// Azure OpenAI Completion API
535// ================================================================
536
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";

/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";

/// GPT-4 Turbo completion model.
///
/// NOTE(review): the value is `"gpt-4"`, not `"gpt-4-turbo"` — confirm whether this is
/// intended (Azure publishes GPT-4 Turbo under the `gpt-4` model name) or a typo.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model (0613 version).
///
/// NOTE(review): the value equals [`GPT_4_32K`]; confirm whether `"gpt-4-32k-0613"` was meant.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";

/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
563
564#[derive(Debug, Serialize, Deserialize)]
565pub(super) struct AzureOpenAICompletionRequest {
566    model: String,
567    pub messages: Vec<openai::Message>,
568    #[serde(skip_serializing_if = "Option::is_none")]
569    temperature: Option<f64>,
570    #[serde(skip_serializing_if = "Vec::is_empty")]
571    tools: Vec<openai::ToolDefinition>,
572    #[serde(skip_serializing_if = "Option::is_none")]
573    tool_choice: Option<crate::providers::openai::ToolChoice>,
574    #[serde(flatten, skip_serializing_if = "Option::is_none")]
575    pub additional_params: Option<serde_json::Value>,
576}
577
578impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
579    type Error = CompletionError;
580
581    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
582        let model = req.model.clone().unwrap_or_else(|| model.to_string());
583        //FIXME: Must fix!
584        if req.tool_choice.is_some() {
585            tracing::warn!(
586                "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
587            );
588        }
589
590        let mut full_history: Vec<openai::Message> = match &req.preamble {
591            Some(preamble) => vec![openai::Message::system(preamble)],
592            None => vec![],
593        };
594
595        if let Some(docs) = req.normalized_documents() {
596            let docs: Vec<openai::Message> = docs.try_into()?;
597            full_history.extend(docs);
598        }
599
600        let chat_history: Vec<openai::Message> = req
601            .chat_history
602            .clone()
603            .into_iter()
604            .map(|message| message.try_into())
605            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
606            .into_iter()
607            .flatten()
608            .collect();
609
610        full_history.extend(chat_history);
611
612        let tool_choice = req
613            .tool_choice
614            .clone()
615            .map(crate::providers::openai::ToolChoice::try_from)
616            .transpose()?;
617
618        let additional_params = if let Some(schema) = req.output_schema {
619            let name = schema
620                .as_object()
621                .and_then(|o| o.get("title"))
622                .and_then(|v| v.as_str())
623                .unwrap_or("response_schema")
624                .to_string();
625            let mut schema_value = schema.to_value();
626            openai::sanitize_schema(&mut schema_value);
627            let response_format = serde_json::json!({
628                "response_format": {
629                    "type": "json_schema",
630                    "json_schema": {
631                        "name": name,
632                        "strict": true,
633                        "schema": schema_value
634                    }
635                }
636            });
637            Some(match req.additional_params {
638                Some(existing) => json_utils::merge(existing, response_format),
639                None => response_format,
640            })
641        } else {
642            req.additional_params
643        };
644
645        Ok(Self {
646            model: model.to_string(),
647            messages: full_history,
648            temperature: req.temperature,
649            tools: req
650                .tools
651                .clone()
652                .into_iter()
653                .map(openai::ToolDefinition::from)
654                .collect::<Vec<_>>(),
655            tool_choice,
656            additional_params,
657        })
658    }
659}
660
661#[derive(Clone)]
662pub struct CompletionModel<T = reqwest::Client> {
663    client: Client<T>,
664    /// Name of the model (e.g.: gpt-4o-mini)
665    pub model: String,
666}
667
668impl<T> CompletionModel<T> {
669    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
670        Self {
671            client,
672            model: model.into(),
673        }
674    }
675}
676
677impl<T> completion::CompletionModel for CompletionModel<T>
678where
679    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
680{
681    type Response = openai::CompletionResponse;
682    type StreamingResponse = openai::StreamingCompletionResponse;
683    type Client = Client<T>;
684
685    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
686        Self::new(client.clone(), model.into())
687    }
688
689    async fn completion(
690        &self,
691        completion_request: CompletionRequest,
692    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
693        let span = if tracing::Span::current().is_disabled() {
694            info_span!(
695                target: "rig::completions",
696                "chat",
697                gen_ai.operation.name = "chat",
698                gen_ai.provider.name = "azure.openai",
699                gen_ai.request.model = self.model,
700                gen_ai.system_instructions = &completion_request.preamble,
701                gen_ai.response.id = tracing::field::Empty,
702                gen_ai.response.model = tracing::field::Empty,
703                gen_ai.usage.output_tokens = tracing::field::Empty,
704                gen_ai.usage.input_tokens = tracing::field::Empty,
705                gen_ai.usage.cached_tokens = tracing::field::Empty,
706            )
707        } else {
708            tracing::Span::current()
709        };
710
711        let request =
712            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
713
714        if enabled!(Level::TRACE) {
715            tracing::trace!(target: "rig::completions",
716                "Azure OpenAI completion request: {}",
717                serde_json::to_string_pretty(&request)?
718            );
719        }
720
721        let body = serde_json::to_vec(&request)?;
722
723        let req = self
724            .client
725            .post_chat_completion(&self.model)?
726            .body(body)
727            .map_err(http_client::Error::from)?;
728
729        async move {
730            let response = self.client.send::<_, Bytes>(req).await?;
731
732            let status = response.status();
733            let response_body = response.into_body().into_future().await?.to_vec();
734
735            if status.is_success() {
736                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
737                    &response_body,
738                )? {
739                    ApiResponse::Ok(response) => {
740                        let span = tracing::Span::current();
741                        span.record_response_metadata(&response);
742                        span.record_token_usage(&response.usage);
743                        if enabled!(Level::TRACE) {
744                            tracing::trace!(target: "rig::completions",
745                                "Azure OpenAI completion response: {}",
746                                serde_json::to_string_pretty(&response)?
747                            );
748                        }
749                        response.try_into()
750                    }
751                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
752                }
753            } else {
754                Err(CompletionError::ProviderError(
755                    String::from_utf8_lossy(&response_body).to_string(),
756                ))
757            }
758        }
759        .instrument(span)
760        .await
761    }
762
763    async fn stream(
764        &self,
765        completion_request: CompletionRequest,
766    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
767        let preamble = completion_request.preamble.clone();
768        let mut request =
769            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
770
771        let params = json_utils::merge(
772            request.additional_params.unwrap_or(serde_json::json!({})),
773            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
774        );
775
776        request.additional_params = Some(params);
777
778        if enabled!(Level::TRACE) {
779            tracing::trace!(target: "rig::completions",
780                "Azure OpenAI completion request: {}",
781                serde_json::to_string_pretty(&request)?
782            );
783        }
784
785        let body = serde_json::to_vec(&request)?;
786
787        let req = self
788            .client
789            .post_chat_completion(&self.model)?
790            .body(body)
791            .map_err(http_client::Error::from)?;
792
793        let span = if tracing::Span::current().is_disabled() {
794            info_span!(
795                target: "rig::completions",
796                "chat_streaming",
797                gen_ai.operation.name = "chat_streaming",
798                gen_ai.provider.name = "azure.openai",
799                gen_ai.request.model = self.model,
800                gen_ai.system_instructions = &preamble,
801                gen_ai.response.id = tracing::field::Empty,
802                gen_ai.response.model = tracing::field::Empty,
803                gen_ai.usage.output_tokens = tracing::field::Empty,
804                gen_ai.usage.input_tokens = tracing::field::Empty,
805                gen_ai.usage.cached_tokens = tracing::field::Empty,
806            )
807        } else {
808            tracing::Span::current()
809        };
810
811        tracing_futures::Instrument::instrument(
812            send_compatible_streaming_request(self.client.clone(), req),
813            span,
814        )
815        .await
816    }
817}
818
819// ================================================================
820// Azure OpenAI Transcription API
821// ================================================================
822
823#[derive(Clone)]
824pub struct TranscriptionModel<T = reqwest::Client> {
825    client: Client<T>,
826    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
827    pub model: String,
828}
829
830impl<T> TranscriptionModel<T> {
831    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
832        Self {
833            client,
834            model: model.into(),
835        }
836    }
837}
838
839impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
840where
841    T: HttpClientExt + Clone + 'static,
842{
843    type Response = TranscriptionResponse;
844    type Client = Client<T>;
845
846    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
847        Self::new(client.clone(), model)
848    }
849
850    async fn transcription(
851        &self,
852        request: transcription::TranscriptionRequest,
853    ) -> Result<
854        transcription::TranscriptionResponse<Self::Response>,
855        transcription::TranscriptionError,
856    > {
857        let data = request.data;
858
859        let mut body =
860            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));
861
862        if let Some(prompt) = request.prompt {
863            body = body.text("prompt", prompt.clone());
864        }
865
866        if let Some(ref temperature) = request.temperature {
867            body = body.text("temperature", temperature.to_string());
868        }
869
870        if let Some(ref additional_params) = request.additional_params {
871            for (key, value) in additional_params
872                .as_object()
873                .expect("Additional Parameters to OpenAI Transcription should be a map")
874            {
875                body = body.text(key.to_owned(), value.to_string());
876            }
877        }
878
879        let req = self
880            .client
881            .post_transcription(&self.model)?
882            .body(body)
883            .map_err(|e| TranscriptionError::HttpError(e.into()))?;
884
885        let response = self.client.send_multipart::<Bytes>(req).await?;
886        let status = response.status();
887        let response_body = response.into_body().into_future().await?.to_vec();
888
889        if status.is_success() {
890            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
891                ApiResponse::Ok(response) => response.try_into(),
892                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
893                    api_error_response.message,
894                )),
895            }
896        } else {
897            Err(TranscriptionError::ProviderError(
898                String::from_utf8_lossy(&response_body).to_string(),
899            ))
900        }
901    }
902}
903
904// ================================================================
905// Azure OpenAI Image Generation API
906// ================================================================
907#[cfg(feature = "image")]
908pub use image_generation::*;
909use tracing::{Instrument, Level, enabled, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation::{self, ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    /// Image generation model backed by an Azure OpenAI deployment.
    ///
    /// `model` is handed to `Client::post_image_generation` when building the
    /// request (presumably as the deployment id) and is also sent as the
    /// `model` field of the JSON payload.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        /// Builds a model bound to a clone of `client` and the given model name.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            let model = model.into();
            let client = client.clone();
            Self { client, model }
        }

        /// Requests a single image, asking for a base64-encoded (`b64_json`) payload.
        ///
        /// # Errors
        /// - `ImageGenerationError::ProviderError("<status>: <body>")` for a non-2xx status.
        /// - `ImageGenerationError::ProviderError(message)` when a 2xx body is an API error.
        /// - JSON (de)serialization and HTTP errors are propagated via `?`.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // Azure expects the size as "<width>x<height>".
            let size = format!("{}x{}", generation_request.width, generation_request.height);
            let payload = serde_json::to_vec(&json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": size,
                "response_format": "b64_json"
            }))?;

            let http_request = self
                .client
                .post_image_generation(&self.model)?
                .body(payload)
                .map_err(|err| ImageGenerationError::HttpError(err.into()))?;

            let http_response = self.client.send::<_, Bytes>(http_request).await?;
            let status = http_response.status();
            let bytes = http_response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&bytes)? {
                    ApiResponse::Ok(parsed) => parsed.try_into(),
                    ApiResponse::Err(api_err) => {
                        Err(ImageGenerationError::ProviderError(api_err.message))
                    }
                }
            } else {
                Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&bytes)
                )))
            }
        }
    }
}
980// ================================================================
981// Azure OpenAI Audio Generation API
982// ================================================================
983
984#[cfg(feature = "audio")]
985pub use audio_generation::*;
986
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    /// Text-to-speech model bound to a single Azure OpenAI deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        // Azure deployment name. It is passed to `Client::post_audio_generation`
        // (the same way the image/transcription models pass theirs) and is also
        // sent as the `model` field of the JSON payload.
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates a model that targets the deployment `deployment_name` on `client`.
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        /// Builds a model bound to a clone of `client` and the given deployment name.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Synthesizes speech for `request.text` using `request.voice` and `request.speed`.
        ///
        /// On success, the raw response bytes are returned both as `audio` (a `Vec<u8>`)
        /// and as the untouched `response` body.
        ///
        /// # Errors
        /// - `AudioGenerationError::ProviderError("<status>: <body>")` for a non-2xx status.
        /// - JSON serialization and HTTP errors are propagated via `?`.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // Azure routes text-to-speech by deployment id, so pass the configured
            // deployment name. The original code passed the path literal
            // "/audio/speech", which ignored the deployment entirely.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}
1062
#[cfg(test)]
mod azure_tests {
    use schemars::JsonSchema;

    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;
    use crate::prelude::TypedPrompt;
    use crate::providers::openai::GPT_5_MINI;

    // The `#[ignore]`d tests below hit a live Azure OpenAI service. They need
    // whatever credentials/endpoint `Client::from_env()` reads, so run them
    // explicitly with `cargo test -- --ignored`.

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding_dimensions() {
        let _ = tracing_subscriber::fmt::try_init();

        let ndims = 256;
        let client = Client::from_env();
        let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims);
        let embedding = model.embed_text("Hello, world!").await.unwrap();

        // `assert_eq!` reports the actual length on failure.
        assert_eq!(embedding.vec.len(), ndims);

        tracing::info!("Azure dimensions embedding: {:?}", embedding);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                model: None,
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
                output_schema: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_structured_output() {
        let _ = tracing_subscriber::fmt::try_init();

        #[derive(Debug, Deserialize, JsonSchema)]
        struct Person {
            name: String,
            age: u32,
        }

        let client = Client::from_env();
        let agent = client
            .agent(GPT_5_MINI)
            .preamble("You are a helpful assistant that extracts personal details.")
            .max_tokens(100)
            .output_schema::<Person>()
            .build();

        let result: Person = agent
            .prompt_typed("Hello! My name is John Doe and I'm 54 years old.")
            .await
            .expect("failed to extract person");

        // `assert_eq!` shows what the model actually extracted when it mismatches.
        assert_eq!(result.name, "John Doe");
        assert_eq!(result.age, 54);

        tracing::info!("Extracted person: {:?}", result);
    }

    // Offline smoke test: the builder accepts a plain key and an endpoint.
    #[tokio::test]
    async fn test_client_initialization() {
        let _client = crate::providers::azure::Client::builder()
            .api_key("test")
            .azure_endpoint("test".to_string()) // add your endpoint here!
            .build()
            .expect("Client::builder() failed");
    }
}