rig/providers/azure.rs

//! Azure OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::azure;
//!
//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
//!
//! let gpt4o = client.completion_model(azure::GPT_4O);
//! ```

use super::openai::{TranscriptionResponse, send_compatible_streaming_request};

use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::merge;
use crate::streaming::StreamingCompletionResponse;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    telemetry::SpanCombinator,
    transcription::{self, TranscriptionError},
};
use bytes::Bytes;
use reqwest::header::AUTHORIZATION;
use reqwest::multipart::Part;
use serde::Deserialize;
use serde_json::json;

// ================================================================
// Main Azure OpenAI Client
// ================================================================

const DEFAULT_API_VERSION: &str = "2024-10-21";

pub struct ClientBuilder<'a, T = reqwest::Client> {
    auth: AzureOpenAIAuth,
    api_version: Option<&'a str>,
    azure_endpoint: &'a str,
    http_client: T,
}

impl<'a, T> ClientBuilder<'a, T>
where
    T: Default,
{
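    /// Create a new builder from auth credentials and an Azure OpenAI endpoint,
    /// using the default HTTP client.
    ///
    /// A minimal sketch (the endpoint below is a placeholder):
    /// ```
    /// use rig::providers::azure::{AzureOpenAIAuth, ClientBuilder};
    ///
    /// let client = ClientBuilder::<reqwest::Client>::new(
    ///     AzureOpenAIAuth::ApiKey("YOUR_API_KEY".to_string()),
    ///     "https://{your-resource-name}.openai.azure.com",
    /// )
    /// .build();
    /// ```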
    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &'a str) -> Self {
        Self {
            auth: auth.into(),
            api_version: None,
            azure_endpoint: endpoint,
            http_client: Default::default(),
        }
    }
}

impl<'a, T> ClientBuilder<'a, T> {
    pub fn new_with_client(
        auth: impl Into<AzureOpenAIAuth>,
        endpoint: &'a str,
        http_client: T,
    ) -> Self {
        Self {
            auth: auth.into(),
            api_version: None,
            azure_endpoint: endpoint,
            http_client,
        }
    }

    /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
    pub fn api_version(mut self, api_version: &'a str) -> Self {
        self.api_version = Some(api_version);
        self
    }

    /// Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
    pub fn azure_endpoint(mut self, azure_endpoint: &'a str) -> Self {
        self.azure_endpoint = azure_endpoint;
        self
    }

    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
        ClientBuilder {
            auth: self.auth,
            api_version: self.api_version,
            azure_endpoint: self.azure_endpoint,
            http_client,
        }
    }

    pub fn build(self) -> Client<T> {
        let api_version = self.api_version.unwrap_or(DEFAULT_API_VERSION);

        Client {
            api_version: api_version.to_string(),
            azure_endpoint: self.azure_endpoint.to_string(),
            auth: self.auth,
            http_client: self.http_client,
        }
    }
}

#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    api_version: String,
    azure_endpoint: String,
    auth: AzureOpenAIAuth,
    http_client: T,
}

impl<T> std::fmt::Debug for Client<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("http_client", &self.http_client)
            .field("auth", &"<REDACTED>")
            .field("api_version", &self.api_version)
            .finish()
    }
}

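/// Authentication for Azure OpenAI: either an API key sent in the `api-key`
/// header, or a token (e.g. a Microsoft Entra ID access token) sent as an
/// `Authorization: Bearer ...` header.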
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl From<String> for AzureOpenAIAuth {
    fn from(token: String) -> Self {
        AzureOpenAIAuth::Token(token)
    }
}
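
/// Convenience conversion (added so the doc examples can pass string literals);
/// mirrors the `From<String>` impl by treating a bare string as a bearer token.
impl From<&str> for AzureOpenAIAuth {
    fn from(token: &str) -> Self {
        AzureOpenAIAuth::Token(token.to_string())
    }
}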

impl AzureOpenAIAuth {
    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
        match self {
            AzureOpenAIAuth::ApiKey(api_key) => (
                "api-key".parse().expect("Header name should parse"),
                api_key.parse().expect("API key should parse"),
            ),
            AzureOpenAIAuth::Token(token) => (
                AUTHORIZATION,
                format!("Bearer {token}")
                    .parse()
                    .expect("Token should parse"),
            ),
        }
    }
}

impl Client<reqwest::Client> {
    /// Create a new Azure OpenAI client builder.
    ///
    /// # Example
    /// ```
    /// use rig::providers::azure::Client;
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::builder("your-azure-api-key", "https://{your-resource-name}.openai.azure.com")
    ///     .build();
    /// ```
    pub fn builder(
        auth: impl Into<AzureOpenAIAuth>,
        endpoint: &str,
    ) -> ClientBuilder<'_, reqwest::Client> {
        ClientBuilder::new(auth, endpoint)
    }

    /// Creates a new Azure OpenAI client. For more control, use the `builder` method.
    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &str) -> Self {
        Self::builder(auth, endpoint).build()
    }

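    /// Create a client from the `AZURE_API_KEY` (or `AZURE_TOKEN`),
    /// `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
    ///
    /// A minimal sketch; assumes the variables above are set:
    /// ```no_run
    /// use rig::client::CompletionClient;
    /// use rig::providers::azure;
    ///
    /// let client = azure::Client::from_env();
    /// let gpt4o = client.completion_model(azure::GPT_4O);
    /// ```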
    pub fn from_env() -> Self {
        <Self as ProviderClient>::from_env()
    }
}

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn post(&self, url: String) -> http_client::Builder {
        let (key, value) = self.auth.as_header();

        http_client::Request::post(url).header(key, value)
    }

    fn post_embedding(&self, deployment_id: &str) -> http_client::Builder {
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.azure_endpoint,
            deployment_id.trim_start_matches('/'),
            self.api_version
        );

        self.post(url)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(&self, deployment_id: &str) -> http_client::Builder {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.azure_endpoint,
            deployment_id.trim_start_matches('/'),
            self.api_version
        );

        self.post(url)
    }

    fn post_chat_completion(&self, deployment_id: &str) -> http_client::Builder {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.azure_endpoint,
            deployment_id.trim_start_matches('/'),
            self.api_version
        );

        self.post(url)
    }

    fn post_transcription(&self, deployment_id: &str) -> http_client::Builder {
        // `audio/transcriptions` transcribes audio in its source language;
        // `audio/translations` is a separate endpoint that translates to English.
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.azure_endpoint,
            deployment_id.trim_start_matches('/'),
            self.api_version
        );

        self.post(url)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(&self, deployment_id: &str) -> http_client::Builder {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.azure_endpoint,
            deployment_id.trim_start_matches('/'),
            self.api_version
        );

        self.post(url)
    }

    async fn send<U, R>(
        &self,
        req: http_client::Request<U>,
    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
    where
        U: Into<Bytes> + Send,
        R: From<Bytes> + Send + 'static,
    {
        self.http_client.send(req).await
    }
}

impl<T> ProviderClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        ClientBuilder::<T>::new(auth, &azure_endpoint)
            .api_version(&api_version)
            .build()
    }

    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::ApiKeyWithVersionAndHeader(api_key, version, header) =
            input
        else {
            panic!("Incorrect provider value type")
        };
        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
        ClientBuilder::<T>::new(auth, &header)
            .api_version(&version)
            .build()
    }
}

impl<T> CompletionClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type CompletionModel = CompletionModel<T>;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::CompletionClient;
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
    ///
    /// let gpt4 = azure.completion_model(azure::GPT_4);
    /// ```
    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}

impl<T> EmbeddingsClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type EmbeddingModel = EmbeddingModel<T>;

    /// Create an embedding model with the given name.
    /// Note: if the model is not known to Rig, a default embedding dimension of 0
    /// will be used. In that case, prefer `embedding_model_with_ndims` instead.
    ///
    /// # Example
    /// ```
    /// use rig::client::EmbeddingsClient;
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
    ///
    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
    /// ```
    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
        let ndims = match model {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
    ///
    /// # Example
    /// ```
    /// use rig::client::EmbeddingsClient;
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
    ///
    /// let embedding_model = azure.embedding_model_with_ndims("model-unknown-to-rig", 3072);
    /// ```
    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
}

impl<T> TranscriptionClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type TranscriptionModel = TranscriptionModel<T>;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::TranscriptionClient;
    /// use rig::providers::azure::{Client, self};
    ///
    /// // Initialize the Azure OpenAI client
    /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
    ///
    /// let whisper = azure.transcription_model("whisper");
    /// ```
    fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ================================================================
// Azure OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

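        // The embeddings API only reports prompt and total tokens, so output
        // tokens are derived as the difference between the two.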
        usage.input_tokens = self.prompt_tokens as u64;
        usage.total_tokens = self.total_tokens as u64;
        usage.output_tokens = usage.total_tokens - usage.input_tokens;

        Some(usage)
    }
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

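/// An Azure OpenAI embedding model, usually obtained via [`Client::embedding_model`].
///
/// A minimal usage sketch (assumes a deployment named after the model exists):
/// ```no_run
/// use rig::client::EmbeddingsClient;
/// use rig::embeddings::EmbeddingModel as _;
/// use rig::providers::azure;
///
/// async fn embed() {
///     let client = azure::Client::new("YOUR_API_KEY", "YOUR_ENDPOINT");
///     let model = client.embedding_model(azure::TEXT_EMBEDDING_3_SMALL);
///     let embeddings = model
///         .embed_texts(vec!["Hello, world!".to_string()])
///         .await
///         .expect("embedding request failed");
/// }
/// ```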
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone,
{
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(&self.model)
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_string(),
            ndims,
        }
    }
}

// ================================================================
// Azure OpenAI Completion API
// ================================================================
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

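/// An Azure OpenAI chat completion model, usually obtained via
/// [`Client::completion_model`]. The model name must match an Azure OpenAI
/// deployment name, since it is used as the deployment id in the request URL.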
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

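    /// Build the JSON body for a chat completion request: the system preamble
    /// first, then any normalized documents, then the flattened chat history.
    /// When tools are present, they are serialized along with
    /// `"tool_choice": "auto"`. Any `additional_params` are merged in last,
    /// so they override the generated fields on key conflicts.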
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<serde_json::Value, CompletionError> {
        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = self.create_completion_request(completion_request)?;
        span.record_model_input(
            &request
                .get("messages")
                .expect("completion request should always have messages"),
        );
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(http_client::Error::from)?;

        async move {
            let response = self.client.http_client.send::<_, Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        tracing::debug!(target: "rig", "Azure completion output: {}", serde_json::to_string_pretty(&response)?);
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = request.preamble.clone();
        let mut request = self.create_completion_request(request)?;

        request = merge(
            request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(http_client::Error::from)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing_futures::Instrument::instrument(
            send_compatible_streaming_request(self.client.http_client.clone(), req),
            span,
        )
        .await
    }
}

// ================================================================
// Azure OpenAI Transcription API
// ================================================================

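/// An Azure OpenAI transcription model (e.g. a `whisper` deployment), usually
/// obtained via [`Client::transcription_model`].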
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: whisper)
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new().part(
            "file",
            Part::bytes(data).file_name(request.filename.clone()),
        );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to Azure OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.http_client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

// ================================================================
// Azure OpenAI Image Generation API
// ================================================================
#[cfg(feature = "image")]
pub use image_generation::*;
use tracing::{Instrument, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

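    /// An Azure OpenAI image generation model (e.g. a `dall-e-3` deployment),
    /// usually obtained via [`Client::image_generation_model`].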
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.http_client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl<T> ImageGenerationClient for Client<T>
    where
        T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
    {
        type ImageGenerationModel = ImageGenerationModel<T>;

        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
// ================================================================
// Azure OpenAI Audio Generation API
// ================================================================

use crate::client::{
    CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
    VerifyError,
};
#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

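    /// An Azure OpenAI audio generation model (e.g. a `tts-1` deployment),
    /// usually obtained via [`Client::audio_generation_model`].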
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;

        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // The deployment id is the model name; `post_audio_generation`
            // appends the `/audio/speech` path itself.
            let req = self
                .client
                .post_audio_generation(&self.model)
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.http_client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }

    impl<T> AudioGenerationClient for Client<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type AudioGenerationModel = AudioGenerationModel<T>;

        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}

impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        // There is currently no way to verify the Azure OpenAI API key or token without
        // consuming tokens
        Ok(())
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}