rig/providers/azure.rs

//! Azure OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::azure;
//!
//! let auth = azure::AzureOpenAIAuth::ApiKey("YOUR_API_KEY".to_string());
//!
//! let client = azure::Client::builder()
//!     .api_key(auth)
//!     .azure_endpoint("YOUR_ENDPOINT".to_string())
//!     .api_version("YOUR_API_VERSION")
//!     .build()
//!     .expect("failed to build Azure OpenAI client");
//!
//! let gpt4o = client.completion_model(azure::GPT_4O);
//! ```

use std::fmt::Debug;

use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
#[cfg(feature = "image")]
use crate::client::Nothing;
use crate::client::{
    self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt, bearer_auth_header};
use crate::streaming::StreamingCompletionResponse;
use crate::transcription::TranscriptionError;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    telemetry::SpanCombinator,
    transcription::{self},
};
use bytes::Bytes;
use reqwest::multipart::Part;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{Instrument, info_span};

// ================================================================
// Main Azure OpenAI Client
// ================================================================

const DEFAULT_API_VERSION: &str = "2024-10-21";

#[derive(Debug, Clone)]
pub struct AzureExt {
    endpoint: String,
    api_version: String,
}

impl DebugExt for AzureExt {
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
        [
            ("endpoint", (&self.endpoint as &dyn Debug)),
            ("api_version", (&self.api_version as &dyn Debug)),
        ]
        .into_iter()
    }
}

// TODO: @FayCarsons - this should be a type-safe builder,
// but that would require extending the `ProviderBuilder`
// to have some notion of complete vs incomplete states in a
// given extension builder
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    endpoint: Option<String>,
    api_version: String,
}

impl Default for AzureExtBuilder {
    fn default() -> Self {
        Self {
            endpoint: None,
            api_version: DEFAULT_API_VERSION.into(),
        }
    }
}

pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;

impl Provider for AzureExt {
    type Builder = AzureExtBuilder;

    /// Verifying Azure auth without consuming tokens is not supported
    const VERIFY_PATH: &'static str = "";

    fn build<H>(
        builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        let AzureExtBuilder {
            endpoint,
            api_version,
            ..
        } = builder.ext().clone();

        match endpoint {
            Some(endpoint) => Ok(Self {
                endpoint,
                api_version,
            }),
            None => Err(http_client::Error::Instance(
                "Azure client must be provided an endpoint prior to building".into(),
            )),
        }
    }
}

impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

impl ProviderBuilder for AzureExtBuilder {
    type Output = AzureExt;
    type ApiKey = AzureOpenAIAuth;

    const BASE_URL: &'static str = "";

    fn finish<H>(
        &self,
        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
        use AzureOpenAIAuth::*;

        let auth = builder.get_api_key().clone();

        match auth {
            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
            ApiKey(key) => {
                let k = http::HeaderName::from_static("api-key");
                let v = http::HeaderValue::from_str(key.as_str())?;

                builder.headers_mut().insert(k, v);
            }
        }

        Ok(builder)
    }
}

impl<H> ClientBuilder<H> {
    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
    pub fn api_version(mut self, api_version: &str) -> Self {
        self.ext_mut().api_version = api_version.into();

        self
    }
}

impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
    /// Azure OpenAI endpoint URL, for example: `https://{your-resource-name}.openai.azure.com`
    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
            endpoint: Some(endpoint),
            api_version,
        })
    }
}
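
// A minimal construction sketch (placeholder values). The endpoint is required;
// the API version falls back to `DEFAULT_API_VERSION` when not set explicitly.
//
// let client: Client = Client::builder()
//     .api_key(AzureOpenAIAuth::ApiKey("YOUR_API_KEY".into()))
//     .azure_endpoint("https://my-resource.openai.azure.com".to_string())
//     .api_version("2024-10-21")
//     .build()
//     .expect("endpoint was provided, so building should succeed");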

#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl ApiKey for AzureOpenAIAuth {}

impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl<S> From<S> for AzureOpenAIAuth
where
    S: Into<String>,
{
    fn from(token: S) -> Self {
        AzureOpenAIAuth::Token(token.into())
    }
}
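
// Note: the blanket `From` impl above turns any plain string into a bearer
// `Token` (Entra ID auth). To authenticate with an Azure API key instead,
// construct the variant explicitly (sketch):
//
// let token_auth: AzureOpenAIAuth = "YOUR_ENTRA_TOKEN".into(); // Token variant
// let key_auth = AzureOpenAIAuth::ApiKey("YOUR_API_KEY".into()); // api-key header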

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn endpoint(&self) -> &str {
        &self.ext().endpoint
    }

    fn api_version(&self) -> &str {
        &self.ext().api_version
    }

    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    fn post_chat_completion(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        // Transcription posts to the `audio/transcriptions` operation;
        // `audio/translations` is a different Azure OpenAI endpoint.
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }
}
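
// The helpers above all produce deployment-scoped URLs of the form
// `{endpoint}/openai/deployments/{deployment_id}/{operation}?api-version={version}`,
// e.g. (with placeholder values):
//
//   https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-10-21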

pub struct AzureOpenAIClientParams {
    api_key: String,
    version: String,
    endpoint: String,
}

impl ProviderClient for Client {
    type Input = AzureOpenAIClientParams;

    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::builder()
            .api_key(auth)
            .azure_endpoint(azure_endpoint)
            .api_version(&api_version)
            .build()
            .unwrap()
    }

    fn from_val(
        AzureOpenAIClientParams {
            api_key,
            version,
            endpoint,
        }: Self::Input,
    ) -> Self {
        let auth = AzureOpenAIAuth::ApiKey(api_key);

        Self::builder()
            .api_key(auth)
            .azure_endpoint(endpoint)
            .api_version(&version)
            .build()
            .unwrap()
    }
}
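
// Sketch: the environment consumed by `from_env` (shell):
//
//   export AZURE_API_KEY=...        # or AZURE_TOKEN for Entra ID auth
//   export AZURE_API_VERSION=2024-10-21
//   export AZURE_ENDPOINT=https://my-resource.openai.azure.com
//
// let client = Client::from_env();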

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ================================================================
// Azure OpenAI Embedding API
// ================================================================

/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        TEXT_EMBEDDING_3_LARGE => Some(3_072),
        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
        _ => None,
    }
}
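
// For example, `model_dimensions_from_identifier(TEXT_EMBEDDING_3_LARGE)` is
// `Some(3_072)`, while an unrecognized deployment name yields `None` and the
// caller falls back to the explicitly provided dimension count (or 0).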

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_tokens as u64;
        usage.total_tokens = self.total_tokens as u64;
        // Embedding responses report no completion tokens; saturate to avoid
        // underflow if a response ever reports total < prompt.
        usage.output_tokens = usage.total_tokens.saturating_sub(usage.input_tokens);

        Some(usage)
    }
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents)
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let ndims = ndims
            .or_else(|| model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self {
            client,
            model,
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
        let ndims = ndims.unwrap_or_default();

        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}
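
// Sketch of typical embedding usage (assumes `AZURE_*` env vars are set and the
// deployment is named after the model, as in the tests at the bottom of this file):
//
// let client = Client::from_env();
// let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL); // 1536 dims inferred
// let embeddings = model.embed_texts(vec!["Hello, world!".to_string()]).await?;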

// ================================================================
// Azure OpenAI Completion API
// ================================================================

/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    // Extra parameters are merged into the top level of the request body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}

impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}
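
// For reference, a request with a preamble and a single user turn serializes
// to JSON along these lines (sketch; `additional_params` entries, if any, are
// flattened into the top-level object):
//
// {
//   "model": "gpt-4o",
//   "messages": [
//     { "role": "system", "content": "You are a helpful assistant." },
//     { "role": "user", "content": "Hello!" }
//   ],
//   "temperature": 0.0
// }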

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        span.record_model_input(&request.messages);
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;
        async move {
            // Propagate transport errors rather than panicking on them.
            let response = self.client.send::<_, Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        tracing::trace!(
                            target: "rig::completions",
                            "Azure completion response: {}",
                            serde_json::to_string_pretty(&response)?
                        );
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let mut request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing_futures::Instrument::instrument(
            send_compatible_streaming_request(self.client.http_client().clone(), req),
            span,
        )
        .await
    }
}
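
// Sketch: driving the streaming API (assumes rig's `StreamingCompletionResponse`
// implements `futures::Stream`, as with the other OpenAI-compatible providers):
//
// let mut stream = model.stream(request).await?;
// while let Some(chunk) = stream.next().await {
//     match chunk? { /* handle text deltas, tool calls, final usage, ... */ }
// }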

// ================================================================
// Azure OpenAI Transcription API
// ================================================================

#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper-1)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new().part(
            "file",
            Part::bytes(data).file_name(request.filename.clone()),
        );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to Azure OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self
            .client
            .http_client()
            .send_multipart::<Bytes>(req)
            .await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
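
// Sketch: the multipart form built above carries the audio bytes under "file"
// plus optional "prompt"/"temperature" fields, mirroring the OpenAI-compatible
// transcription API. Typical usage via rig's transcription client trait
// (assumes a whisper deployment name):
//
// let model = client.transcription_model("whisper-1");
// let response = model.transcription(request).await?;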

// ================================================================
// Azure OpenAI Image Generation API
// ================================================================
#[cfg(feature = "image")]
pub use image_generation::*;

#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}

// ================================================================
// Azure OpenAI Audio Generation API
// ================================================================

#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // `post_audio_generation` expects the deployment name; it appends
            // the `/audio/speech` path itself.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}