rig/providers/azure.rs

//! Azure OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::completion::CompletionClient;
//! use rig::providers::azure;
//!
//! let client = azure::Client::builder()
//!     .api_key(azure::AzureOpenAIAuth::ApiKey("YOUR_API_KEY".to_string()))
//!     .azure_endpoint("YOUR_ENDPOINT".to_string())
//!     .api_version("YOUR_API_VERSION")
//!     .build()
//!     .unwrap();
//!
//! let gpt4o = client.completion_model(azure::GPT_4O);
//! ```

use std::fmt::Debug;

use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
#[cfg(feature = "image")]
use crate::client::Nothing;
use crate::client::{
    self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
use crate::streaming::StreamingCompletionResponse;
use crate::transcription::TranscriptionError;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    telemetry::SpanCombinator,
    transcription::{self},
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{Instrument, Level, enabled, info_span};
// ================================================================
// Main Azure OpenAI Client
// ================================================================

const DEFAULT_API_VERSION: &str = "2024-10-21";

/// Provider extension holding Azure-specific state: the resource endpoint and
/// the `api-version` query parameter sent with every request.
#[derive(Debug, Clone)]
pub struct AzureExt {
    endpoint: String,
    api_version: String,
}

impl DebugExt for AzureExt {
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
        [
            ("endpoint", (&self.endpoint as &dyn Debug)),
            ("api_version", (&self.api_version as &dyn Debug)),
        ]
        .into_iter()
    }
}

// TODO: @FayCarsons - this should be a type-safe builder,
// but that would require extending the `ProviderBuilder`
// to have some notion of complete vs incomplete states in a
// given extension builder
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    endpoint: Option<String>,
    api_version: String,
}

impl Default for AzureExtBuilder {
    fn default() -> Self {
        Self {
            endpoint: None,
            api_version: DEFAULT_API_VERSION.into(),
        }
    }
}

pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;

impl Provider for AzureExt {
    type Builder = AzureExtBuilder;

    /// Verifying Azure auth without consuming tokens is not supported
    const VERIFY_PATH: &'static str = "";

    fn build<H>(
        builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        let AzureExtBuilder {
            endpoint,
            api_version,
            ..
        } = builder.ext().clone();

        match endpoint {
            Some(endpoint) => Ok(Self {
                endpoint,
                api_version,
            }),
            None => Err(http_client::Error::Instance(
                "Azure client must be provided an endpoint prior to building".into(),
            )),
        }
    }
}

impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

impl ProviderBuilder for AzureExtBuilder {
    type Output = AzureExt;
    type ApiKey = AzureOpenAIAuth;

    const BASE_URL: &'static str = "";

    fn finish<H>(
        &self,
        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
        use AzureOpenAIAuth::*;

        let auth = builder.get_api_key().clone();

        match auth {
            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
            ApiKey(key) => {
                let k = http::HeaderName::from_static("api-key");
                let v = http::HeaderValue::from_str(key.as_str())?;

                builder.headers_mut().insert(k, v);
            }
        }

        Ok(builder)
    }
}

impl<H> ClientBuilder<H> {
    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
    pub fn api_version(mut self, api_version: &str) -> Self {
        self.ext_mut().api_version = api_version.into();

        self
    }
}

impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
    /// Azure OpenAI endpoint URL, for example: `https://{your-resource-name}.openai.azure.com`
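    ///
    /// A minimal sketch of wiring up an endpoint explicitly (the resource name
    /// below is a placeholder):
    /// ```
    /// use rig::providers::azure::{self, AzureOpenAIAuth};
    ///
    /// let client = azure::Client::builder()
    ///     .api_key(AzureOpenAIAuth::ApiKey("YOUR_API_KEY".to_string()))
    ///     .azure_endpoint("https://my-resource.openai.azure.com".to_string())
    ///     .build()
    ///     .unwrap();
    /// ```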
    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
            endpoint: Some(endpoint),
            api_version,
        })
    }
}

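/// Authentication for Azure OpenAI: either a static API key (sent in the
/// `api-key` header) or a bearer token (e.g. issued by Microsoft Entra ID),
/// sent as an `Authorization: Bearer` header.
///
/// Note that the blanket `From<impl Into<String>>` conversion yields the
/// `Token` variant; construct `ApiKey` explicitly for portal keys:
/// ```
/// use rig::providers::azure::AzureOpenAIAuth;
///
/// // A static key from the Azure portal:
/// let key = AzureOpenAIAuth::ApiKey("YOUR_API_KEY".to_string());
///
/// // A bearer token via the blanket conversion:
/// let token: AzureOpenAIAuth = "YOUR_BEARER_TOKEN".into();
/// ```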
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl ApiKey for AzureOpenAIAuth {}

impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl<S> From<S> for AzureOpenAIAuth
where
    S: Into<String>,
{
    fn from(token: S) -> Self {
        AzureOpenAIAuth::Token(token.into())
    }
}

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn endpoint(&self) -> &str {
        &self.ext().endpoint
    }

    fn api_version(&self) -> &str {
        &self.ext().api_version
    }

    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    fn post_chat_completion(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }
    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        // `audio/transcriptions` is the transcription route; `audio/translations`
        // is the separate translate-to-English endpoint.
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }
}

pub struct AzureOpenAIClientParams {
    api_key: String,
    version: String,
    endpoint: String,
}

impl ProviderClient for Client {
    type Input = AzureOpenAIClientParams;

    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
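    ///
    /// Panics if neither `AZURE_API_KEY` nor `AZURE_TOKEN` is set, or if
    /// `AZURE_API_VERSION` or `AZURE_ENDPOINT` is missing. A sketch:
    /// ```ignore
    /// // AZURE_API_KEY=... AZURE_API_VERSION=2024-10-21 \
    /// // AZURE_ENDPOINT=https://my-resource.openai.azure.com
    /// let client = rig::providers::azure::Client::from_env();
    /// ```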
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::builder()
            .api_key(auth)
            .azure_endpoint(azure_endpoint)
            .api_version(&api_version)
            .build()
            .unwrap()
    }

    fn from_val(
        AzureOpenAIClientParams {
            api_key,
            version,
            endpoint,
        }: Self::Input,
    ) -> Self {
        let auth = AzureOpenAIAuth::ApiKey(api_key);

        Self::builder()
            .api_key(auth)
            .azure_endpoint(endpoint)
            .api_version(&version)
            .build()
            .unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ================================================================
// Azure OpenAI Embedding API
// ================================================================

/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

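/// Default embedding dimensions for the known model identifiers above; unknown
/// models yield `None`, letting callers fall back to an explicit dimension count.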
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        TEXT_EMBEDDING_3_LARGE => Some(3_072),
        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
        _ => None,
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_tokens as u64;
        usage.total_tokens = self.total_tokens as u64;
        // Saturate rather than subtract directly so a malformed response
        // (total < prompt) cannot cause an integer-overflow panic.
        usage.output_tokens = usage.total_tokens.saturating_sub(usage.input_tokens);

        Some(usage)
    }
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

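/// Azure OpenAI embedding model bound to a deployment.
///
/// A minimal sketch, assuming the deployment is named after the model id and
/// the relevant client trait is in scope:
/// ```ignore
/// use rig::providers::azure;
///
/// let client = azure::Client::from_env();
/// let model = client.embedding_model(azure::TEXT_EMBEDDING_3_SMALL);
/// let embeddings = model
///     .embed_texts(vec!["Hello, world!".to_string()])
///     .await?;
/// ```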
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let ndims = ndims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self {
            client,
            model,
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
        let ndims = ndims.unwrap_or_default();

        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}


// ================================================================
// Azure OpenAI Completion API
// ================================================================

/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}

impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        //FIXME: Must fix!
        if req.tool_choice.is_some() {
            tracing::warn!(
                "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
            );
        }

        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}

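/// Azure OpenAI chat completion model bound to a deployment.
///
/// A minimal sketch, assuming a deployment named `gpt-4o` (see the tests at the
/// bottom of this file for a complete request):
/// ```ignore
/// use rig::providers::azure;
///
/// let client = azure::Client::from_env();
/// let gpt4o = client.completion_model(azure::GPT_4O);
/// ```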
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model deployment (e.g.: `gpt-4o-mini`)
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        async move {
            // Propagate transport errors instead of panicking on them.
            let response = self.client.send::<_, Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Azure OpenAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let mut request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing_futures::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}

// ================================================================
// Azure OpenAI Transcription API
// ================================================================

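/// Azure OpenAI transcription model bound to a deployment.
///
/// A minimal sketch, assuming a Whisper deployment named `whisper` (deployment
/// names are chosen when deploying the model in Azure):
/// ```ignore
/// use rig::providers::azure;
///
/// let client = azure::Client::from_env();
/// let whisper = client.transcription_model("whisper");
/// ```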
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model deployment (e.g.: `whisper`)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body =
            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to Azure OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

// ================================================================
// Azure OpenAI Image Generation API
// ================================================================
#[cfg(feature = "image")]
pub use image_generation::*;
860#[cfg(feature = "image")]
861#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
862mod image_generation {
863    use crate::http_client::HttpClientExt;
864    use crate::image_generation;
865    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
866    use crate::providers::azure::{ApiResponse, Client};
867    use crate::providers::openai::ImageGenerationResponse;
868    use bytes::Bytes;
869    use serde_json::json;
870
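    /// Azure OpenAI image generation model bound to a deployment (e.g. a
    /// `dall-e-3` deployment).
    ///
    /// Note that `Capabilities` for this provider currently maps image
    /// generation to `Nothing`, so this type is only constructible through the
    /// `ImageGenerationModel` trait's `make` method.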
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}

// ================================================================
// Azure OpenAI Audio Generation API
// ================================================================

#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

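    /// Azure OpenAI text-to-speech model bound to a deployment.
    ///
    /// A minimal sketch, assuming a TTS deployment named `tts` (hypothetical;
    /// deployment names are chosen when deploying the model in Azure):
    /// ```ignore
    /// use rig::providers::azure;
    ///
    /// let client = azure::Client::from_env();
    /// let tts = azure::AudioGenerationModel::new(client, "tts");
    /// ```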
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // `post_audio_generation` expects the deployment id and appends
            // `/audio/speech` itself, so pass the deployment, not a path.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}