rig/providers/
azure.rs

1//! Azure OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::azure;
6//!
7//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
8//!
9//! let gpt4o = client.completion_model(azure::GPT_4O);
10//! ```
11
12use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
13
14use crate::completion::GetTokenUsage;
15use crate::http_client::{self, HttpClientExt};
16use crate::json_utils::merge;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19    completion::{self, CompletionError, CompletionRequest},
20    embeddings::{self, EmbeddingError},
21    json_utils,
22    providers::openai,
23    telemetry::SpanCombinator,
24    transcription::{self, TranscriptionError},
25};
26use bytes::Bytes;
27use reqwest::header::AUTHORIZATION;
28use reqwest::multipart::Part;
29use serde::Deserialize;
30use serde_json::json;
31// ================================================================
32// Main Azure OpenAI Client
33// ================================================================
34
35const DEFAULT_API_VERSION: &str = "2024-10-21";
36
/// Builder for [`Client`].
///
/// Holds borrowed configuration until [`ClientBuilder::build`] copies it into
/// an owned [`Client`]. `T` is the underlying HTTP client type (defaults to
/// `reqwest::Client`).
pub struct ClientBuilder<'a, T = reqwest::Client> {
    // API key or bearer token attached to every request.
    auth: AzureOpenAIAuth,
    // REST API version; `build` falls back to `DEFAULT_API_VERSION` when unset.
    api_version: Option<&'a str>,
    // Base resource URL, e.g. `https://{resource}.openai.azure.com`.
    azure_endpoint: &'a str,
    http_client: T,
}
43
44impl<'a, T> ClientBuilder<'a, T>
45where
46    T: Default,
47{
48    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &'a str) -> Self {
49        Self {
50            auth: auth.into(),
51            api_version: None,
52            azure_endpoint: endpoint,
53            http_client: Default::default(),
54        }
55    }
56}
57
58impl<'a, T> ClientBuilder<'a, T> {
59    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
60    pub fn api_version(mut self, api_version: &'a str) -> Self {
61        self.api_version = Some(api_version);
62        self
63    }
64
65    /// Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
66    pub fn azure_endpoint(mut self, azure_endpoint: &'a str) -> Self {
67        self.azure_endpoint = azure_endpoint;
68        self
69    }
70
71    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
72        ClientBuilder {
73            auth: self.auth,
74            api_version: self.api_version,
75            azure_endpoint: self.azure_endpoint,
76            http_client,
77        }
78    }
79
80    pub fn build(self) -> Client<T> {
81        let api_version = self.api_version.unwrap_or(DEFAULT_API_VERSION);
82
83        Client {
84            api_version: api_version.to_string(),
85            azure_endpoint: self.azure_endpoint.to_string(),
86            auth: self.auth,
87            http_client: self.http_client,
88        }
89    }
90}
91
/// Azure OpenAI API client.
///
/// Construct via [`Client::new`] or [`Client::builder`]. Cloning is cheap for
/// the default `reqwest::Client` transport (internally reference-counted).
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    // REST API version appended to every request URL as `api-version=…`.
    api_version: String,
    // Base resource URL, e.g. `https://{resource}.openai.azure.com`.
    azure_endpoint: String,
    // Credential rendered into a header on every request.
    auth: AzureOpenAIAuth,
    http_client: T,
}
99
// Manual `Debug` so the credential is never printed; every other field is
// shown as-is.
impl<T> std::fmt::Debug for Client<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("http_client", &self.http_client)
            // Auth material is intentionally redacted from debug output.
            .field("auth", &"<REDACTED>")
            .field("api_version", &self.api_version)
            .finish()
    }
}
113
/// Credential used to authenticate against Azure OpenAI.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Azure API key, sent as the `api-key` header.
    ApiKey(String),
    /// OAuth token, sent as `Authorization: Bearer <token>`.
    Token(String),
}
119
120impl std::fmt::Debug for AzureOpenAIAuth {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
124            Self::Token(_) => write!(f, "Token <REDACTED>"),
125        }
126    }
127}
128
impl From<String> for AzureOpenAIAuth {
    // NOTE: a bare `String` is interpreted as a bearer *token*, not an API
    // key. Use `AzureOpenAIAuth::ApiKey` explicitly for key-based auth.
    fn from(token: String) -> Self {
        AzureOpenAIAuth::Token(token)
    }
}
134
135impl AzureOpenAIAuth {
136    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
137        match self {
138            AzureOpenAIAuth::ApiKey(api_key) => (
139                "api-key".parse().expect("Header value should parse"),
140                api_key.parse().expect("API key should parse"),
141            ),
142            AzureOpenAIAuth::Token(token) => (
143                AUTHORIZATION,
144                format!("Bearer {token}")
145                    .parse()
146                    .expect("Token should parse"),
147            ),
148        }
149    }
150}
151
impl<T> Client<T>
where
    T: Default,
{
    /// Create a new Azure OpenAI client builder.
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::Client;
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::builder("your-azure-api-key", "https://{your-resource-name}.openai.azure.com")
    ///     .build();
    /// ```
    pub fn builder(auth: impl Into<AzureOpenAIAuth>, endpoint: &str) -> ClientBuilder<'_, T> {
        ClientBuilder::new(auth, endpoint)
    }

    /// Creates a new Azure OpenAI client. For more control, use the `builder` method.
    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &str) -> Self {
        Self::builder(auth, endpoint).build()
    }
}
175
impl<T> Client<T>
where
    T: HttpClientExt,
{
    /// Start a POST request to `url` with the auth header already attached.
    fn post(&self, url: String) -> http_client::Builder {
        let (key, value) = self.auth.as_header();

        http_client::Request::post(url).header(key, value)
    }

    /// Build a POST to the embeddings endpoint of the given deployment.
    fn post_embedding(&self, deployment_id: &str) -> http_client::Builder {
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.azure_endpoint,
            // Avoid a double slash if the caller passes a leading `/`.
            deployment_id.trim_start_matches('/'),
            self.api_version
        );

        self.post(url)
    }

    /// Send a request through the generic HTTP client, returning the raw
    /// response with a lazily-materialized body.
    async fn send<U, R>(
        &self,
        req: http_client::Request<U>,
    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
    where
        U: Into<Bytes> + Send,
        R: From<Bytes> + Send + 'static,
    {
        self.http_client.send(req).await
    }
}
208
209impl Client<reqwest::Client> {
210    fn reqwest_post(&self, url: String) -> reqwest::RequestBuilder {
211        let (key, val) = self.auth.as_header();
212
213        self.http_client.post(url).header(key, val)
214    }
215
216    #[cfg(feature = "audio")]
217    fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
218        let url = format!(
219            "{}/openai/deployments/{}/audio/speech?api-version={}",
220            self.azure_endpoint, deployment_id, self.api_version
221        )
222        .replace("//", "/");
223
224        self.reqwest_post(url)
225    }
226
227    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
228        let url = format!(
229            "{}/openai/deployments/{}/chat/completions?api-version={}",
230            self.azure_endpoint, deployment_id, self.api_version
231        )
232        .replace("//", "/");
233
234        self.reqwest_post(url)
235    }
236
237    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
238        let url = format!(
239            "{}/openai/deployments/{}/audio/translations?api-version={}",
240            self.azure_endpoint, deployment_id, self.api_version
241        )
242        .replace("//", "/");
243
244        self.reqwest_post(url)
245    }
246
247    #[cfg(feature = "image")]
248    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
249        let url = format!(
250            "{}/openai/deployments/{}/images/generations?api-version={}",
251            self.azure_endpoint, deployment_id, self.api_version
252        )
253        .replace("//", "/");
254
255        self.reqwest_post(url)
256    }
257}
258
impl ProviderClient for Client<reqwest::Client> {
    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
    ///
    /// # Panics
    /// Panics if neither `AZURE_API_KEY` nor `AZURE_TOKEN` is set, or if
    /// `AZURE_API_VERSION` or `AZURE_ENDPOINT` is missing.
    fn from_env() -> Self {
        // Prefer an API key; fall back to a bearer token.
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::builder(auth, &azure_endpoint)
            .api_version(&api_version)
            .build()
    }

    /// Build a client from a generic provider value.
    ///
    /// # Panics
    /// Panics unless `input` is `ApiKeyWithVersionAndHeader`.
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::ApiKeyWithVersionAndHeader(api_key, version, header) =
            input
        else {
            panic!("Incorrect provider value type")
        };
        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
        // The `header` slot carries the Azure endpoint URL for this provider.
        Self::builder(auth, &header).api_version(&version).build()
    }
}
288
impl CompletionClient for Client<reqwest::Client> {
    type CompletionModel = CompletionModel<reqwest::Client>;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
    ///
    /// let gpt4 = azure.completion_model(azure::GPT_4);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
        CompletionModel::new(self.clone(), model)
    }
}
307
impl EmbeddingsClient for Client<reqwest::Client> {
    type EmbeddingModel = EmbeddingModel<reqwest::Client>;

    /// Create an embedding model with the given name.
    /// Note: default embedding dimension of 0 will be used if model is not known.
    /// If this is the case, it's better to use function `embedding_model_with_ndims`
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
    ///
    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
    /// ```
    fn embedding_model(&self, model: &str) -> EmbeddingModel<reqwest::Client> {
        // Known Azure/OpenAI embedding dimensions; unknown models get 0.
        let ndims = match model {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
    ///
    /// let embedding_model = azure.embedding_model_with_ndims("model-unknown-to-rig", 3072);
    /// ```
    fn embedding_model_with_ndims(
        &self,
        model: &str,
        ndims: usize,
    ) -> EmbeddingModel<reqwest::Client> {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
}
352
impl TranscriptionClient for Client<reqwest::Client> {
    type TranscriptionModel = TranscriptionModel<reqwest::Client>;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
    ///
    /// let whisper = azure.transcription_model("model-unknown-to-rig");
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel<reqwest::Client> {
        TranscriptionModel::new(self.clone(), model)
    }
}
371
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API error body.
///
/// `untagged`: serde tries to deserialize `Ok(T)` first and falls back to the
/// error shape, so `T`'s fields must not overlap ambiguously with the error's.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
383
// ================================================================
// Azure OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model (3072 dimensions)
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 dimensions)
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 dimensions)
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
393
/// Successful response body of the embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    // Always "list" per the OpenAI wire format — assumption, TODO confirm.
    pub object: String,
    // One entry per input document, carrying the embedding vector.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
401
402impl From<ApiErrorResponse> for EmbeddingError {
403    fn from(err: ApiErrorResponse) -> Self {
404        EmbeddingError::ProviderError(err.message)
405    }
406}
407
408impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
409    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
410        match value {
411            ApiResponse::Ok(response) => Ok(response),
412            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
413        }
414    }
415}
416
/// A single embedding entry within an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    // The embedding vector itself.
    pub embedding: Vec<f64>,
    // Position of the corresponding input document.
    pub index: usize,
}

/// Token-usage counters reported by the embeddings endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
429
430impl GetTokenUsage for Usage {
431    fn token_usage(&self) -> Option<crate::completion::Usage> {
432        let mut usage = crate::completion::Usage::new();
433
434        usage.input_tokens = self.prompt_tokens as u64;
435        usage.total_tokens = self.total_tokens as u64;
436        usage.output_tokens = usage.total_tokens - usage.input_tokens;
437
438        Some(usage)
439    }
440}
441
442impl std::fmt::Display for Usage {
443    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        write!(
445            f,
446            "Prompt tokens: {} Total tokens: {}",
447            self.prompt_tokens, self.total_tokens
448        )
449    }
450}
451
/// Handle to a single Azure embedding deployment.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    /// Deployment name used in the request URL.
    pub model: String,
    // Reported embedding dimensionality (0 when unknown).
    ndims: usize,
}
458
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone,
{
    /// Maximum number of inputs accepted in a single embeddings request.
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed a batch of texts in one API call.
    ///
    /// The deployment name travels in the URL, so the JSON body only carries
    /// the `input` array. Returns one embedding per input, in order; errors
    /// if the API returns a different count.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize the iterator: we need the count and a second pass below.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(&self.model)
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // The API must echo exactly one embedding per input.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each returned vector with its source document.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: surface the raw body as a provider error.
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}
524
525impl<T> EmbeddingModel<T> {
526    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
527        Self {
528            client,
529            model: model.to_string(),
530            ndims,
531        }
532    }
533}
534
// ================================================================
// Azure OpenAI Completion API
// ================================================================
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
// NOTE(review): value is "gpt-4", not "gpt-4-turbo" — confirm this matches
// the intended Azure model name; changing the value would affect callers.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model (0613 snapshot)
// NOTE(review): value lacks the "-0613" suffix and is identical to `GPT_4_32K`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
564
/// Handle to a single Azure chat-completions deployment.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}
571
572impl<T> CompletionModel<T> {
573    pub fn new(client: Client<T>, model: &str) -> Self {
574        Self {
575            client,
576            model: model.to_string(),
577        }
578    }
579
580    fn create_completion_request(
581        &self,
582        completion_request: CompletionRequest,
583    ) -> Result<serde_json::Value, CompletionError> {
584        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
585            Some(preamble) => vec![openai::Message::system(preamble)],
586            None => vec![],
587        };
588        if let Some(docs) = completion_request.normalized_documents() {
589            let docs: Vec<openai::Message> = docs.try_into()?;
590            full_history.extend(docs);
591        }
592        let chat_history: Vec<openai::Message> = completion_request
593            .chat_history
594            .into_iter()
595            .map(|message| message.try_into())
596            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
597            .into_iter()
598            .flatten()
599            .collect();
600
601        full_history.extend(chat_history);
602
603        let request = if completion_request.tools.is_empty() {
604            json!({
605                "model": self.model,
606                "messages": full_history,
607                "temperature": completion_request.temperature,
608            })
609        } else {
610            json!({
611                "model": self.model,
612                "messages": full_history,
613                "temperature": completion_request.temperature,
614                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
615                "tool_choice": "auto",
616            })
617        };
618
619        let request = if let Some(params) = completion_request.additional_params {
620            json_utils::merge(request, params)
621        } else {
622            request
623        };
624
625        Ok(request)
626    }
627}
628
629impl completion::CompletionModel for CompletionModel<reqwest::Client> {
630    type Response = openai::CompletionResponse;
631    type StreamingResponse = openai::StreamingCompletionResponse;
632
633    #[cfg_attr(feature = "worker", worker::send)]
634    async fn completion(
635        &self,
636        completion_request: CompletionRequest,
637    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
638        let span = if tracing::Span::current().is_disabled() {
639            info_span!(
640                target: "rig::completions",
641                "chat",
642                gen_ai.operation.name = "chat",
643                gen_ai.provider.name = "azure.openai",
644                gen_ai.request.model = self.model,
645                gen_ai.system_instructions = &completion_request.preamble,
646                gen_ai.response.id = tracing::field::Empty,
647                gen_ai.response.model = tracing::field::Empty,
648                gen_ai.usage.output_tokens = tracing::field::Empty,
649                gen_ai.usage.input_tokens = tracing::field::Empty,
650                gen_ai.input.messages = tracing::field::Empty,
651                gen_ai.output.messages = tracing::field::Empty,
652            )
653        } else {
654            tracing::Span::current()
655        };
656        let request = self.create_completion_request(completion_request)?;
657        span.record_model_input(
658            &request
659                .get("messages")
660                .expect("Converting JSON should not fail"),
661        );
662
663        async move {
664            let response = self
665                .client
666                .post_chat_completion(&self.model)
667                .json(&request)
668                .send()
669                .await
670                .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?;
671
672            if response.status().is_success() {
673                let t = response.text().await.map_err(|e| {
674                    CompletionError::HttpError(http_client::Error::Instance(e.into()))
675                })?;
676                tracing::debug!(target: "rig", "Azure completion error: {}", t);
677
678                match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
679                    ApiResponse::Ok(response) => {
680                        let span = tracing::Span::current();
681                        span.record_model_output(&response.choices);
682                        span.record_response_metadata(&response);
683                        span.record_token_usage(&response.usage);
684                        response.try_into()
685                    }
686                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
687                }
688            } else {
689                Err(CompletionError::ProviderError(
690                    response.text().await.map_err(|e| {
691                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
692                    })?,
693                ))
694            }
695        }
696        .instrument(span)
697        .await
698    }
699
700    #[cfg_attr(feature = "worker", worker::send)]
701    async fn stream(
702        &self,
703        request: CompletionRequest,
704    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
705        let preamble = request.preamble.clone();
706        let mut request = self.create_completion_request(request)?;
707
708        request = merge(
709            request,
710            json!({"stream": true, "stream_options": {"include_usage": true}}),
711        );
712
713        let builder = self
714            .client
715            .post_chat_completion(self.model.as_str())
716            .header("Content-Type", "application/json")
717            .json(&request);
718
719        let span = if tracing::Span::current().is_disabled() {
720            info_span!(
721                target: "rig::completions",
722                "chat_streaming",
723                gen_ai.operation.name = "chat_streaming",
724                gen_ai.provider.name = "azure.openai",
725                gen_ai.request.model = self.model,
726                gen_ai.system_instructions = &preamble,
727                gen_ai.response.id = tracing::field::Empty,
728                gen_ai.response.model = tracing::field::Empty,
729                gen_ai.usage.output_tokens = tracing::field::Empty,
730                gen_ai.usage.input_tokens = tracing::field::Empty,
731                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
732                gen_ai.output.messages = tracing::field::Empty,
733            )
734        } else {
735            tracing::Span::current()
736        };
737
738        tracing_futures::Instrument::instrument(send_compatible_streaming_request(builder), span)
739            .await
740    }
741}
742
743// ================================================================
744// Azure OpenAI Transcription API
745// ================================================================
746
/// Handle to a single Azure speech-to-text deployment.
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    /// Construct a transcription-model handle for the given deployment name.
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}
762
763impl transcription::TranscriptionModel for TranscriptionModel<reqwest::Client> {
764    type Response = TranscriptionResponse;
765
766    #[cfg_attr(feature = "worker", worker::send)]
767    async fn transcription(
768        &self,
769        request: transcription::TranscriptionRequest,
770    ) -> Result<
771        transcription::TranscriptionResponse<Self::Response>,
772        transcription::TranscriptionError,
773    > {
774        let data = request.data;
775
776        let mut body = reqwest::multipart::Form::new().part(
777            "file",
778            Part::bytes(data).file_name(request.filename.clone()),
779        );
780
781        if let Some(prompt) = request.prompt {
782            body = body.text("prompt", prompt.clone());
783        }
784
785        if let Some(ref temperature) = request.temperature {
786            body = body.text("temperature", temperature.to_string());
787        }
788
789        if let Some(ref additional_params) = request.additional_params {
790            for (key, value) in additional_params
791                .as_object()
792                .expect("Additional Parameters to OpenAI Transcription should be a map")
793            {
794                body = body.text(key.to_owned(), value.to_string());
795            }
796        }
797
798        let response = self
799            .client
800            .post_transcription(&self.model)
801            .header("Content-Type", "application/json")
802            .multipart(body)
803            .send()
804            .await
805            .map_err(|e| TranscriptionError::HttpError(http_client::Error::Instance(e.into())))?;
806
807        if response.status().is_success() {
808            match response
809                .json::<ApiResponse<TranscriptionResponse>>()
810                .await
811                .map_err(|e| {
812                    TranscriptionError::HttpError(http_client::Error::Instance(e.into()))
813                })? {
814                ApiResponse::Ok(response) => response.try_into(),
815                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
816                    api_error_response.message,
817                )),
818            }
819        } else {
820            Err(TranscriptionError::ProviderError(
821                response.text().await.map_err(|e| {
822                    TranscriptionError::HttpError(http_client::Error::Instance(e.into()))
823                })?,
824            ))
825        }
826    }
827}
828
829// ================================================================
830// Azure OpenAI Image Generation API
831// ================================================================
832#[cfg(feature = "image")]
833pub use image_generation::*;
834use tracing::{Instrument, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use crate::{http_client, image_generation};
    use serde_json::json;

    /// Handle to a single Azure image-generation deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        /// Deployment name used in the request URL.
        pub model: String,
    }
    impl image_generation::ImageGenerationModel for ImageGenerationModel<reqwest::Client> {
        type Response = ImageGenerationResponse;

        /// Generate an image and return it base64-encoded
        /// (`response_format: "b64_json"`).
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?;

            // Non-2xx: surface status + raw body as a provider error.
            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await.map_err(|e| {
                        ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                    })?
                )));
            }

            let t = response.text().await.map_err(|e| {
                ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
            })?;

            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client<reqwest::Client> {
        type ImageGenerationModel = ImageGenerationModel<reqwest::Client>;

        /// Create an image-generation model with the given deployment name.
        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
909// ================================================================
910// Azure OpenAI Audio Generation API
911// ================================================================
912
913use crate::client::{
914    CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
915    VerifyError,
916};
917#[cfg(feature = "audio")]
918pub use audio_generation::*;
919
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    //! Text-to-speech support for the Azure OpenAI provider.

    use super::Client;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use crate::{audio_generation, http_client};
    use bytes::Bytes;
    use serde_json::json;

    /// Azure OpenAI text-to-speech model, bound to a [`Client`] and a
    /// deployment/model name. Construct via `Client::audio_generation_model`.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel<reqwest::Client> {
        type Response = Bytes;

        /// Synthesizes speech for `request.text` and returns the raw audio bytes.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            // Single place to wrap transport failures; used for send and body reads alike.
            let as_http_err = |e: reqwest::Error| {
                AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
            };

            let body = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let response = self
                .client
                .post_audio_generation("/audio/speech")
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await
                .map_err(as_http_err)?;

            let status = response.status();
            if !status.is_success() {
                // Non-2xx responses carry a textual error body worth surfacing.
                let details = response.text().await.map_err(as_http_err)?;
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    status, details
                )));
            }

            // Successful responses carry the raw audio in the body.
            let bytes = response.bytes().await.map_err(as_http_err)?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client<reqwest::Client> {
        type AudioGenerationModel = AudioGenerationModel<reqwest::Client>;

        /// Builds an [`AudioGenerationModel`] targeting the given Azure deployment.
        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                model: model.to_owned(),
                client: self.clone(),
            }
        }
    }
}
996
997impl VerifyClient for Client<reqwest::Client> {
998    #[cfg_attr(feature = "worker", worker::send)]
999    async fn verify(&self) -> Result<(), VerifyError> {
1000        // There is currently no way to verify the Azure OpenAI API key or token without
1001        // consuming tokens
1002        Ok(())
1003    }
1004}
1005
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    /// Live smoke test against a real Azure deployment; needs credentials in
    /// the environment, hence `#[ignore]`.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let embeddings = Client::from_env()
            .embedding_model(TEXT_EMBEDDING_3_SMALL)
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    /// Live smoke test for chat completions; needs credentials in the
    /// environment, hence `#[ignore]`.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let request = CompletionRequest {
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one("Hello!".into()),
            documents: vec![],
            max_tokens: Some(100),
            temperature: Some(0.0),
            tools: vec![],
            tool_choice: None,
            additional_params: None,
        };

        let completion = Client::from_env()
            .completion_model(GPT_4O_MINI)
            .completion(request)
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}