// rig/providers/azure.rs

//! Azure OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::azure;
//!
//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
//!
//! let gpt4o = client.completion_model(azure::GPT_4O);
//! ```
11
12use std::fmt::Debug;
13
14use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
15#[cfg(feature = "image")]
16use crate::client::Nothing;
17use crate::client::{
18    self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
19};
20use crate::completion::GetTokenUsage;
21use crate::http_client::{self, HttpClientExt, bearer_auth_header};
22use crate::streaming::StreamingCompletionResponse;
23use crate::transcription::TranscriptionError;
24use crate::{
25    completion::{self, CompletionError, CompletionRequest},
26    embeddings::{self, EmbeddingError},
27    json_utils,
28    providers::openai,
29    telemetry::SpanCombinator,
30    transcription::{self},
31};
32use bytes::Bytes;
33use reqwest::multipart::Part;
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36// ================================================================
37// Main Azure OpenAI Client
38// ================================================================
39
/// Default Azure OpenAI REST API version (the 2024-10-21 GA release), used
/// when the builder is not given an explicit version.
const DEFAULT_API_VERSION: &str = "2024-10-21";
41
/// Provider-specific configuration for Azure OpenAI.
///
/// Unlike most providers, Azure requires a per-resource endpoint and an
/// `api-version` query parameter on every request.
#[derive(Debug, Clone)]
pub struct AzureExt {
    // Resource endpoint, e.g. `https://{resource}.openai.azure.com`.
    endpoint: String,
    // Value of the `?api-version=...` query parameter on every request URL.
    api_version: String,
}
47
48impl DebugExt for AzureExt {
49    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
50        [
51            ("endpoint", (&self.endpoint as &dyn Debug)),
52            ("api_version", (&self.api_version as &dyn Debug)),
53        ]
54        .into_iter()
55    }
56}
57
// TODO: @FayCarsons - this should be a type-safe builder,
// but that would require extending the `ProviderBuilder`
// to have some notion of complete vs incomplete states in a
// given extension builder
/// Builder state for [`AzureExt`]: the endpoint is mandatory (enforced at
/// build time), the API version defaults to [`DEFAULT_API_VERSION`].
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    // `None` until the caller supplies an endpoint; building fails otherwise.
    endpoint: Option<String>,
    api_version: String,
}
67
68impl Default for AzureExtBuilder {
69    fn default() -> Self {
70        Self {
71            endpoint: None,
72            api_version: DEFAULT_API_VERSION.into(),
73        }
74    }
75}
76
/// Azure OpenAI client, generic over the HTTP transport (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
/// Builder for [`Client`], fixed to the Azure auth scheme.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
80
81impl Provider for AzureExt {
82    type Builder = AzureExtBuilder;
83
84    /// Verifying Azure auth without consuming tokens is not supported
85    const VERIFY_PATH: &'static str = "";
86
87    fn build<H>(
88        builder: &client::ClientBuilder<
89            Self::Builder,
90            <Self::Builder as ProviderBuilder>::ApiKey,
91            H,
92        >,
93    ) -> http_client::Result<Self> {
94        let AzureExtBuilder {
95            endpoint,
96            api_version,
97            ..
98        } = builder.ext().clone();
99
100        match endpoint {
101            Some(endpoint) => Ok(Self {
102                endpoint,
103                api_version,
104            }),
105            None => Err(http_client::Error::Instance(
106                "Azure client must be provided an endpoint prior to building".into(),
107            )),
108        }
109    }
110}
111
// Capability map: which model kinds this provider can construct.
// NOTE(review): `ImageGeneration` is `Nothing` even though an
// `ImageGenerationModel` is defined below under the "image" feature --
// confirm whether this should be `Capable<ImageGenerationModel<H>>`.
impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
121
impl ProviderBuilder for AzureExtBuilder {
    type Output = AzureExt;
    type ApiKey = AzureOpenAIAuth;

    // Azure has no fixed base URL; the per-resource endpoint is supplied by
    // the user instead.
    const BASE_URL: &'static str = "";

    /// Install the default auth header on the builder: `Authorization: Bearer`
    /// for Entra tokens, or the Azure-specific `api-key` header for API keys.
    fn finish<H>(
        &self,
        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
        use AzureOpenAIAuth::*;

        let auth = builder.get_api_key().clone();

        match auth {
            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
            ApiKey(key) => {
                // Azure expects the raw key under a custom `api-key` header.
                let k = http::HeaderName::from_static("api-key");
                let v = http::HeaderValue::from_str(key.as_str())?;

                builder.headers_mut().insert(k, v);
            }
        }

        Ok(builder)
    }
}
149
150impl<H> ClientBuilder<H> {
151    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
152    pub fn api_version(mut self, api_version: &str) -> Self {
153        self.ext_mut().api_version = api_version.into();
154
155        self
156    }
157}
158
159impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
160    /// Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
161    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
162        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
163            endpoint: Some(endpoint),
164            api_version,
165        })
166    }
167}
168
/// Authentication schemes accepted by Azure OpenAI.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Static API key, sent via the Azure-specific `api-key` request header.
    ApiKey(String),
    /// Bearer token (e.g. Entra ID / OAuth), sent via the `Authorization` header.
    Token(String),
}

impl ApiKey for AzureOpenAIAuth {}

// Manual impl so the credential is never printed in debug/log output.
impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

// NOTE: a plain string converts to the `Token` variant; construct
// `AzureOpenAIAuth::ApiKey` explicitly for key-based auth.
impl<S> From<S> for AzureOpenAIAuth
where
    S: Into<String>,
{
    fn from(token: S) -> Self {
        AzureOpenAIAuth::Token(token.into())
    }
}
194
195impl<T> Client<T>
196where
197    T: HttpClientExt,
198{
199    fn endpoint(&self) -> &str {
200        &self.ext().endpoint
201    }
202
203    fn api_version(&self) -> &str {
204        &self.ext().api_version
205    }
206
207    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
208        let url = format!(
209            "{}/openai/deployments/{}/embeddings?api-version={}",
210            self.endpoint(),
211            deployment_id.trim_start_matches('/'),
212            self.api_version()
213        );
214
215        self.post(&url)
216    }
217
218    #[cfg(feature = "audio")]
219    fn post_audio_generation(
220        &self,
221        deployment_id: &str,
222    ) -> http_client::Result<http_client::Builder> {
223        let url = format!(
224            "{}/openai/deployments/{}/audio/speech?api-version={}",
225            self.endpoint(),
226            deployment_id.trim_start_matches('/'),
227            self.api_version()
228        );
229
230        self.post(url)
231    }
232
233    fn post_chat_completion(
234        &self,
235        deployment_id: &str,
236    ) -> http_client::Result<http_client::Builder> {
237        let url = format!(
238            "{}/openai/deployments/{}/chat/completions?api-version={}",
239            self.endpoint(),
240            deployment_id.trim_start_matches('/'),
241            self.api_version()
242        );
243
244        self.post(&url)
245    }
246
247    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
248        let url = format!(
249            "{}/openai/deployments/{}/audio/translations?api-version={}",
250            self.endpoint(),
251            deployment_id.trim_start_matches('/'),
252            self.api_version()
253        );
254
255        self.post(&url)
256    }
257
258    #[cfg(feature = "image")]
259    fn post_image_generation(
260        &self,
261        deployment_id: &str,
262    ) -> http_client::Result<http_client::Builder> {
263        let url = format!(
264            "{}/openai/deployments/{}/images/generations?api-version={}",
265            self.endpoint(),
266            deployment_id.trim_start_matches('/'),
267            self.api_version()
268        );
269
270        self.post(&url)
271    }
272}
273
/// Parameters accepted by [`ProviderClient::from_val`] for Azure OpenAI.
pub struct AzureOpenAIClientParams {
    // API key used for `api-key` header auth.
    api_key: String,
    // `api-version` query parameter value.
    version: String,
    // NOTE(review): despite the name, this carries the Azure endpoint URL --
    // `from_val` passes it to `azure_endpoint`. Confirm and consider renaming.
    header: String,
}
279
280impl ProviderClient for Client {
281    type Input = AzureOpenAIClientParams;
282
283    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
284    fn from_env() -> Self {
285        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
286            AzureOpenAIAuth::ApiKey(api_key)
287        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
288            AzureOpenAIAuth::Token(token)
289        } else {
290            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
291        };
292
293        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
294        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
295
296        Self::builder()
297            .api_key(auth)
298            .azure_endpoint(azure_endpoint)
299            .api_version(&api_version)
300            .build()
301            .unwrap()
302    }
303
304    fn from_val(
305        AzureOpenAIClientParams {
306            api_key,
307            version,
308            header,
309        }: Self::Input,
310    ) -> Self {
311        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
312
313        Self::builder()
314            .api_key(auth)
315            .azure_endpoint(header)
316            .api_version(&version)
317            .build()
318            .unwrap()
319    }
320}
321
/// Error body returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an Azure API error.
/// `untagged`: serde attempts the `Ok(T)` shape first, then falls back to
/// the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
333
334// ================================================================
335// Azure OpenAI Embedding API
336// ================================================================
337
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

/// Look up the native output dimensionality of a known embedding model.
/// Returns `None` for unrecognized identifiers (e.g. custom deployments).
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    if identifier == TEXT_EMBEDDING_3_LARGE {
        Some(3_072)
    } else if identifier == TEXT_EMBEDDING_3_SMALL || identifier == TEXT_EMBEDDING_ADA_002 {
        Some(1_536)
    } else {
        None
    }
}
352
/// Successful payload of Azure's embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
360
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

// Convenience conversion collapsing the untagged API response into a
// standard `Result`.
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
375
/// A single embedding vector and its position in the request batch.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

/// Token usage reported by the embeddings endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
388
389impl GetTokenUsage for Usage {
390    fn token_usage(&self) -> Option<crate::completion::Usage> {
391        let mut usage = crate::completion::Usage::new();
392
393        usage.input_tokens = self.prompt_tokens as u64;
394        usage.total_tokens = self.total_tokens as u64;
395        usage.output_tokens = usage.total_tokens - usage.input_tokens;
396
397        Some(usage)
398    }
399}
400
401impl std::fmt::Display for Usage {
402    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403        write!(
404            f,
405            "Prompt tokens: {} Total tokens: {}",
406            self.prompt_tokens, self.total_tokens
407        )
408    }
409}
410
/// Embedding model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    // Output dimensionality; 0 when unknown (custom deployment, no override).
    ndims: usize,
}
417
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone + 'static,
{
    /// Maximum number of input documents accepted per embedding request.
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed a batch of texts with a single request to the deployment.
    ///
    /// # Errors
    /// Fails on transport errors, provider API errors, non-success statuses,
    /// or when the provider returns a different number of embeddings than
    /// documents submitted.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize so the batch can be both sent and zipped back with the
        // returned vectors below.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // One embedding per input, in order; a length mismatch
                    // means the pairs below cannot be trusted.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}
488
489impl<T> EmbeddingModel<T> {
490    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
491        let model = model.into();
492        let ndims = ndims
493            .or(model_dimensions_from_identifier(&model))
494            .unwrap_or_default();
495
496        Self {
497            client,
498            model,
499            ndims,
500        }
501    }
502
503    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
504        let ndims = ndims.unwrap_or_default();
505
506        Self {
507            client,
508            model: model.into(),
509            ndims,
510        }
511    }
512}
513
514// ================================================================
515// Azure OpenAI Completion API
516// ================================================================
517
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
// NOTE(review): value is "gpt-4", not "gpt-4-turbo" -- looks like a
// copy/paste slip, but the value is public API; confirm before changing.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model
// NOTE(review): same value as `GPT_4_32K`; presumably intended to be
// "gpt-4-32k-0613" -- confirm before changing the public value.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
544
545#[derive(Debug, Serialize, Deserialize)]
546pub(super) struct AzureOpenAICompletionRequest {
547    model: String,
548    pub messages: Vec<openai::Message>,
549    #[serde(flatten, skip_serializing_if = "Option::is_none")]
550    temperature: Option<f64>,
551    #[serde(skip_serializing_if = "Vec::is_empty")]
552    tools: Vec<openai::ToolDefinition>,
553    #[serde(flatten, skip_serializing_if = "Option::is_none")]
554    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
555    #[serde(flatten, skip_serializing_if = "Option::is_none")]
556    pub additional_params: Option<serde_json::Value>,
557}
558
impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    /// Build the Azure wire-format request from Rig's generic request for the
    /// deployment named by `model`.
    ///
    /// Message order: system preamble (if any), normalized documents, then
    /// the chat history flattened to OpenAI-style messages.
    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        //FIXME: Must fix!
        if req.tool_choice.is_some() {
            tracing::warn!(
                "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
            );
        }

        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        // Each Rig message may expand to several OpenAI messages (e.g. tool
        // results), hence the Vec<Vec<_>> collection then flatten.
        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}
613
/// Chat-completion model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}
620
621impl<T> CompletionModel<T> {
622    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
623        Self {
624            client,
625            model: model.into(),
626        }
627    }
628}
629
630impl<T> completion::CompletionModel for CompletionModel<T>
631where
632    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
633{
634    type Response = openai::CompletionResponse;
635    type StreamingResponse = openai::StreamingCompletionResponse;
636    type Client = Client<T>;
637
638    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
639        Self::new(client.clone(), model.into())
640    }
641
642    #[cfg_attr(feature = "worker", worker::send)]
643    async fn completion(
644        &self,
645        completion_request: CompletionRequest,
646    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
647        let span = if tracing::Span::current().is_disabled() {
648            info_span!(
649                target: "rig::completions",
650                "chat",
651                gen_ai.operation.name = "chat",
652                gen_ai.provider.name = "azure.openai",
653                gen_ai.request.model = self.model,
654                gen_ai.system_instructions = &completion_request.preamble,
655                gen_ai.response.id = tracing::field::Empty,
656                gen_ai.response.model = tracing::field::Empty,
657                gen_ai.usage.output_tokens = tracing::field::Empty,
658                gen_ai.usage.input_tokens = tracing::field::Empty,
659            )
660        } else {
661            tracing::Span::current()
662        };
663
664        let request =
665            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
666
667        if enabled!(Level::TRACE) {
668            tracing::trace!(target: "rig::completions",
669                "Azure OpenAI completion request: {}",
670                serde_json::to_string_pretty(&request)?
671            );
672        }
673
674        let body = serde_json::to_vec(&request)?;
675
676        let req = self
677            .client
678            .post_chat_completion(&self.model)?
679            .body(body)
680            .map_err(http_client::Error::from)?;
681
682        async move {
683            let response = self.client.send::<_, Bytes>(req).await.unwrap();
684
685            let status = response.status();
686            let response_body = response.into_body().into_future().await?.to_vec();
687
688            if status.is_success() {
689                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
690                    &response_body,
691                )? {
692                    ApiResponse::Ok(response) => {
693                        let span = tracing::Span::current();
694                        span.record_response_metadata(&response);
695                        span.record_token_usage(&response.usage);
696                        if enabled!(Level::TRACE) {
697                            tracing::trace!(target: "rig::completions",
698                                "Azure OpenAI completion response: {}",
699                                serde_json::to_string_pretty(&response)?
700                            );
701                        }
702                        response.try_into()
703                    }
704                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
705                }
706            } else {
707                Err(CompletionError::ProviderError(
708                    String::from_utf8_lossy(&response_body).to_string(),
709                ))
710            }
711        }
712        .instrument(span)
713        .await
714    }
715
716    #[cfg_attr(feature = "worker", worker::send)]
717    async fn stream(
718        &self,
719        completion_request: CompletionRequest,
720    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
721        let preamble = completion_request.preamble.clone();
722        let mut request =
723            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
724
725        let params = json_utils::merge(
726            request.additional_params.unwrap_or(serde_json::json!({})),
727            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
728        );
729
730        request.additional_params = Some(params);
731
732        if enabled!(Level::TRACE) {
733            tracing::trace!(target: "rig::completions",
734                "Azure OpenAI completion request: {}",
735                serde_json::to_string_pretty(&request)?
736            );
737        }
738
739        let body = serde_json::to_vec(&request)?;
740
741        let req = self
742            .client
743            .post_chat_completion(&self.model)?
744            .body(body)
745            .map_err(http_client::Error::from)?;
746
747        let span = if tracing::Span::current().is_disabled() {
748            info_span!(
749                target: "rig::completions",
750                "chat_streaming",
751                gen_ai.operation.name = "chat_streaming",
752                gen_ai.provider.name = "azure.openai",
753                gen_ai.request.model = self.model,
754                gen_ai.system_instructions = &preamble,
755                gen_ai.response.id = tracing::field::Empty,
756                gen_ai.response.model = tracing::field::Empty,
757                gen_ai.usage.output_tokens = tracing::field::Empty,
758                gen_ai.usage.input_tokens = tracing::field::Empty,
759            )
760        } else {
761            tracing::Span::current()
762        };
763
764        tracing_futures::Instrument::instrument(
765            send_compatible_streaming_request(self.client.clone(), req),
766            span,
767        )
768        .await
769    }
770}
771
772// ================================================================
773// Azure OpenAI Transcription API
774// ================================================================
775
/// Speech-to-text model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    /// Create a transcription model for the given deployment name.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
791
792impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
793where
794    T: HttpClientExt + Clone + 'static,
795{
796    type Response = TranscriptionResponse;
797    type Client = Client<T>;
798
799    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
800        Self::new(client.clone(), model)
801    }
802
803    #[cfg_attr(feature = "worker", worker::send)]
804    async fn transcription(
805        &self,
806        request: transcription::TranscriptionRequest,
807    ) -> Result<
808        transcription::TranscriptionResponse<Self::Response>,
809        transcription::TranscriptionError,
810    > {
811        let data = request.data;
812
813        let mut body = reqwest::multipart::Form::new().part(
814            "file",
815            Part::bytes(data).file_name(request.filename.clone()),
816        );
817
818        if let Some(prompt) = request.prompt {
819            body = body.text("prompt", prompt.clone());
820        }
821
822        if let Some(ref temperature) = request.temperature {
823            body = body.text("temperature", temperature.to_string());
824        }
825
826        if let Some(ref additional_params) = request.additional_params {
827            for (key, value) in additional_params
828                .as_object()
829                .expect("Additional Parameters to OpenAI Transcription should be a map")
830            {
831                body = body.text(key.to_owned(), value.to_string());
832            }
833        }
834
835        let req = self
836            .client
837            .post_transcription(&self.model)?
838            .body(body)
839            .map_err(|e| TranscriptionError::HttpError(e.into()))?;
840
841        let response = self.client.send_multipart::<Bytes>(req).await?;
842        let status = response.status();
843        let response_body = response.into_body().into_future().await?.to_vec();
844
845        if status.is_success() {
846            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
847                ApiResponse::Ok(response) => response.try_into(),
848                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
849                    api_error_response.message,
850                )),
851            }
852        } else {
853            Err(TranscriptionError::ProviderError(
854                String::from_utf8_lossy(&response_body).to_string(),
855            ))
856        }
857    }
858}
859
860// ================================================================
861// Azure OpenAI Image Generation API
862// ================================================================
863#[cfg(feature = "image")]
864pub use image_generation::*;
865use tracing::{Instrument, Level, enabled, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    /// Image-generation model bound to an Azure OpenAI deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        /// Generate an image for the request's prompt and dimensions.
        ///
        /// # Errors
        /// Returns an [`ImageGenerationError`] on transport failures,
        /// non-success statuses, or provider-reported API errors.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // `b64_json` keeps the image bytes inline in the response rather
            // than returning a temporary URL.
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}
937// ================================================================
938// Azure OpenAI Audio Generation API
939// ================================================================
940
941#[cfg(feature = "audio")]
942pub use audio_generation::*;
943
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    //! Text-to-speech support backed by the Azure OpenAI `/audio/speech` route.

    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    /// Speech-synthesis model bound to a specific Azure deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        // Azure deployment name used as the `model` field of each request.
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates a model handle for the given client and deployment name.
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Synthesizes speech for `request.text` and returns the raw audio
        /// bytes both as the decoded `audio` vector and as the provider
        /// `response`.
        ///
        /// # Errors
        /// Returns `AudioGenerationError` on request-building failures,
        /// transport errors, or non-2xx responses (the body is included in
        /// the error message).
        // `worker::send` matches the image_generation impl so the future is
        // usable under the `worker` feature as well.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_audio_generation("/audio/speech")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            // Failed requests carry the error payload in the body; surface it.
            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}
1019
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    /// Live smoke test for embeddings; run manually with
    /// `cargo test -- --ignored` and Azure credentials in the environment.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let azure = Client::<reqwest::Client>::from_env();
        let embedder = azure.embedding_model(TEXT_EMBEDDING_3_SMALL);

        let texts = vec!["Hello, world!".to_string()];
        let embeddings = embedder.embed_texts(texts).await.unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    /// Live smoke test for chat completions; requires Azure credentials.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let azure = Client::<reqwest::Client>::from_env();
        let gpt = azure.completion_model(GPT_4O_MINI);

        let request = CompletionRequest {
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one("Hello!".into()),
            documents: vec![],
            max_tokens: Some(100),
            temperature: Some(0.0),
            tools: vec![],
            tool_choice: None,
            additional_params: None,
        };
        let completion = gpt.completion(request).await.unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}