rig/providers/
azure.rs

1//! Azure OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::azure;
6//!
7//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
8//!
9//! let gpt4o = client.completion_model(azure::GPT_4O);
10//! ```
11
12use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
13use crate::json_utils::merge;
14use crate::streaming::{StreamingCompletionModel, StreamingResult};
15use crate::{
16    agent::AgentBuilder,
17    completion::{self, CompletionError, CompletionRequest},
18    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
19    extractor::ExtractorBuilder,
20    json_utils,
21    providers::openai,
22    transcription::{self, TranscriptionError},
23    Embed,
24};
25use reqwest::multipart::Part;
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use serde_json::json;
29
30// ================================================================
31// Main Azure OpenAI Client
32// ================================================================
33
/// Client for the Azure OpenAI service.
///
/// Authentication is baked into `http_client` as a default header at
/// construction time, so individual requests only need the URL.
#[derive(Clone)]
pub struct Client {
    // Sent as the `api-version` query parameter on every request.
    api_version: String,
    // Base resource URL, e.g. https://{your-resource-name}.openai.azure.com
    azure_endpoint: String,
    // Pre-configured reqwest client carrying the auth header on every request.
    http_client: reqwest::Client,
}
40
/// Authentication mechanisms supported by the Azure OpenAI service.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Static API key, sent via the Azure-specific `api-key` header.
    ApiKey(String),
    /// Bearer token (e.g. from Entra ID), sent via the `Authorization` header.
    Token(String),
}
46
47impl From<String> for AzureOpenAIAuth {
48    fn from(token: String) -> Self {
49        AzureOpenAIAuth::Token(token)
50    }
51}
52
53impl Client {
54    /// Creates a new Azure OpenAI client.
55    ///
56    /// # Arguments
57    ///
58    /// * `auth` - Azure OpenAI API key or token required for authentication
59    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
60    /// * `azure_endpoint` - Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
61    pub fn new(auth: impl Into<AzureOpenAIAuth>, api_version: &str, azure_endpoint: &str) -> Self {
62        let mut headers = reqwest::header::HeaderMap::new();
63        match auth.into() {
64            AzureOpenAIAuth::ApiKey(api_key) => {
65                headers.insert("api-key", api_key.parse().expect("API key should parse"));
66            }
67            AzureOpenAIAuth::Token(token) => {
68                headers.insert(
69                    "Authorization",
70                    format!("Bearer {}", token)
71                        .parse()
72                        .expect("Token should parse"),
73                );
74            }
75        }
76
77        Self {
78            api_version: api_version.to_string(),
79            azure_endpoint: azure_endpoint.to_string(),
80            http_client: reqwest::Client::builder()
81                .default_headers(headers)
82                .build()
83                .expect("Azure OpenAI reqwest client should build"),
84        }
85    }
86
87    /// Creates a new Azure OpenAI client from an API key.
88    ///
89    /// # Arguments
90    ///
91    /// * `api_key` - Azure OpenAI API key required for authentication
92    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
93    /// * `azure_endpoint` - Azure OpenAI endpoint URL
94    pub fn from_api_key(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
95        Self::new(
96            AzureOpenAIAuth::ApiKey(api_key.to_string()),
97            api_version,
98            azure_endpoint,
99        )
100    }
101
102    /// Creates a new Azure OpenAI client from a token.
103    ///
104    /// # Arguments
105    ///
106    /// * `token` - Azure OpenAI token required for authentication
107    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
108    /// * `azure_endpoint` - Azure OpenAI endpoint URL
109    pub fn from_token(token: &str, api_version: &str, azure_endpoint: &str) -> Self {
110        Self::new(
111            AzureOpenAIAuth::Token(token.to_string()),
112            api_version,
113            azure_endpoint,
114        )
115    }
116
117    /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
118    pub fn from_env() -> Self {
119        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
120            AzureOpenAIAuth::ApiKey(api_key)
121        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
122            AzureOpenAIAuth::Token(token)
123        } else {
124            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
125        };
126
127        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
128        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
129
130        Self::new(auth, &api_version, &azure_endpoint)
131    }
132
133    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
134        let url = format!(
135            "{}/openai/deployments/{}/embeddings?api-version={}",
136            self.azure_endpoint, deployment_id, self.api_version
137        )
138        .replace("//", "/");
139        self.http_client.post(url)
140    }
141
142    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
143        let url = format!(
144            "{}/openai/deployments/{}/chat/completions?api-version={}",
145            self.azure_endpoint, deployment_id, self.api_version
146        )
147        .replace("//", "/");
148        self.http_client.post(url)
149    }
150
151    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
152        let url = format!(
153            "{}/openai/deployments/{}/audio/translations?api-version={}",
154            self.azure_endpoint, deployment_id, self.api_version
155        )
156        .replace("//", "/");
157        self.http_client.post(url)
158    }
159
160    /// Create an embedding model with the given name.
161    /// Note: default embedding dimension of 0 will be used if model is not known.
162    /// If this is the case, it's better to use function `embedding_model_with_ndims`
163    ///
164    /// # Example
165    /// ```
166    /// use rig::providers::azure::{Client, self};
167    ///
168    /// // Initialize the Azure OpenAI client
169    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
170    ///
171    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
172    /// ```
173    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
174        let ndims = match model {
175            TEXT_EMBEDDING_3_LARGE => 3072,
176            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
177            _ => 0,
178        };
179        EmbeddingModel::new(self.clone(), model, ndims)
180    }
181
182    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
183    ///
184    /// # Example
185    /// ```
186    /// use rig::providers::azure::{Client, self};
187    ///
188    /// // Initialize the Azure OpenAI client
189    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
190    ///
191    /// let embedding_model = azure.embedding_model("model-unknown-to-rig", 3072);
192    /// ```
193    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
194        EmbeddingModel::new(self.clone(), model, ndims)
195    }
196
197    /// Create an embedding builder with the given embedding model.
198    ///
199    /// # Example
200    /// ```
201    /// use rig::providers::azure::{Client, self};
202    ///
203    /// // Initialize the Azure OpenAI client
204    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
205    ///
206    /// let embeddings = azure.embeddings(azure::TEXT_EMBEDDING_3_LARGE)
207    ///     .simple_document("doc0", "Hello, world!")
208    ///     .simple_document("doc1", "Goodbye, world!")
209    ///     .build()
210    ///     .await
211    ///     .expect("Failed to embed documents");
212    /// ```
213    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
214        EmbeddingsBuilder::new(self.embedding_model(model))
215    }
216
217    /// Create a completion model with the given name.
218    ///
219    /// # Example
220    /// ```
221    /// use rig::providers::azure::{Client, self};
222    ///
223    /// // Initialize the Azure OpenAI client
224    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
225    ///
226    /// let gpt4 = azure.completion_model(azure::GPT_4);
227    /// ```
228    pub fn completion_model(&self, model: &str) -> CompletionModel {
229        CompletionModel::new(self.clone(), model)
230    }
231
232    /// Create a transcription model with the given name.
233    ///
234    /// # Example
235    /// ```
236    /// use rig::providers::azure::{Client, self};
237    ///
238    /// // Initialize the Azure OpenAI client
239    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
240    ///
241    /// let whisper = azure.transcription_model("model-unknown-to-rig");
242    /// ```
243    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
244        TranscriptionModel::new(self.clone(), model)
245    }
246
247    /// Create an agent builder with the given completion model.
248    ///
249    /// # Example
250    /// ```
251    /// use rig::providers::azure::{Client, self};
252    ///
253    /// // Initialize the Azure OpenAI client
254    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
255    ///
256    /// let agent = azure.agent(azure::GPT_4)
257    ///    .preamble("You are comedian AI with a mission to make people laugh.")
258    ///    .temperature(0.0)
259    ///    .build();
260    /// ```
261    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
262        AgentBuilder::new(self.completion_model(model))
263    }
264
265    /// Create an extractor builder with the given completion model.
266    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
267        &self,
268        model: &str,
269    ) -> ExtractorBuilder<T, CompletionModel> {
270        ExtractorBuilder::new(self.completion_model(model))
271    }
272}
273
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the service.
    message: String,
}
278
/// Envelope for Azure responses: either the expected payload or an error body.
///
/// `untagged` makes serde try to deserialize `T` first and fall back to the
/// error shape, since the API carries no discriminant field.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
285
286// ================================================================
287// Azure OpenAI Embedding API
288// ================================================================
289/// `text-embedding-3-large` embedding model
290pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
291/// `text-embedding-3-small` embedding model
292pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
293/// `text-embedding-ada-002` embedding model
294pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
295
/// Response body of the Azure embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    // Object type marker returned by the API.
    pub object: String,
    // One entry per input document, in request order.
    pub data: Vec<EmbeddingData>,
    // Model name the service reports having used.
    pub model: String,
    // Token accounting for the request.
    pub usage: Usage,
}
303
304impl From<ApiErrorResponse> for EmbeddingError {
305    fn from(err: ApiErrorResponse) -> Self {
306        EmbeddingError::ProviderError(err.message)
307    }
308}
309
310impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
311    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
312        match value {
313            ApiResponse::Ok(response) => Ok(response),
314            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
315        }
316    }
317}
318
/// A single embedding vector within an `EmbeddingResponse`.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    // Object type marker returned by the API.
    pub object: String,
    // The embedding vector itself.
    pub embedding: Vec<f64>,
    // Position of the corresponding input document in the request.
    pub index: usize,
}
325
/// Token-usage accounting reported by the Azure OpenAI API.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    // Tokens consumed by the input.
    pub prompt_tokens: usize,
    // Total tokens billed for the request.
    pub total_tokens: usize,
}
331
332impl std::fmt::Display for Usage {
333    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334        write!(
335            f,
336            "Prompt tokens: {} Total tokens: {}",
337            self.prompt_tokens, self.total_tokens
338        )
339    }
340}
341
/// Embedding model handle tied to a specific Azure deployment.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Name of the Azure deployment serving this embedding model.
    pub model: String,
    // Number of dimensions reported for produced embeddings.
    ndims: usize,
}
348
349impl embeddings::EmbeddingModel for EmbeddingModel {
350    const MAX_DOCUMENTS: usize = 1024;
351
352    fn ndims(&self) -> usize {
353        self.ndims
354    }
355
356    #[cfg_attr(feature = "worker", worker::send)]
357    async fn embed_texts(
358        &self,
359        documents: impl IntoIterator<Item = String>,
360    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
361        let documents = documents.into_iter().collect::<Vec<_>>();
362
363        let response = self
364            .client
365            .post_embedding(&self.model)
366            .json(&json!({
367                "input": documents,
368            }))
369            .send()
370            .await?;
371
372        if response.status().is_success() {
373            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
374                ApiResponse::Ok(response) => {
375                    tracing::info!(target: "rig",
376                        "Azure embedding token usage: {}",
377                        response.usage
378                    );
379
380                    if response.data.len() != documents.len() {
381                        return Err(EmbeddingError::ResponseError(
382                            "Response data length does not match input length".into(),
383                        ));
384                    }
385
386                    Ok(response
387                        .data
388                        .into_iter()
389                        .zip(documents.into_iter())
390                        .map(|(embedding, document)| embeddings::Embedding {
391                            document,
392                            vec: embedding.embedding,
393                        })
394                        .collect())
395                }
396                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
397            }
398        } else {
399            Err(EmbeddingError::ProviderError(response.text().await?))
400        }
401    }
402}
403
404impl EmbeddingModel {
405    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
406        Self {
407            client,
408            model: model.to_string(),
409            ndims,
410        }
411    }
412}
413
414// ================================================================
415// Azure OpenAI Completion API
416// ================================================================
417/// `o1` completion model
418pub const O1: &str = "o1";
419/// `o1-preview` completion model
420pub const O1_PREVIEW: &str = "o1-preview";
421/// `o1-mini` completion model
422pub const O1_MINI: &str = "o1-mini";
423/// `gpt-4o` completion model
424pub const GPT_4O: &str = "gpt-4o";
425/// `gpt-4o-mini` completion model
426pub const GPT_4O_MINI: &str = "gpt-4o-mini";
427/// `gpt-4o-realtime-preview` completion model
428pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
429/// `gpt-4-turbo` completion model
430pub const GPT_4_TURBO: &str = "gpt-4";
431/// `gpt-4` completion model
432pub const GPT_4: &str = "gpt-4";
433/// `gpt-4-32k` completion model
434pub const GPT_4_32K: &str = "gpt-4-32k";
435/// `gpt-4-32k` completion model
436pub const GPT_4_32K_0613: &str = "gpt-4-32k";
437/// `gpt-3.5-turbo` completion model
438pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
439/// `gpt-3.5-turbo-instruct` completion model
440pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
441/// `gpt-3.5-turbo-16k` completion model
442pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
443
/// Completion (chat) model handle tied to a specific Azure deployment.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}
450
451impl CompletionModel {
452    pub fn new(client: Client, model: &str) -> Self {
453        Self {
454            client,
455            model: model.to_string(),
456        }
457    }
458
459    fn create_completion_request(
460        &self,
461        completion_request: CompletionRequest,
462    ) -> Result<serde_json::Value, CompletionError> {
463        // Add preamble to chat history (if available)
464        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
465            Some(preamble) => vec![openai::Message::system(preamble)],
466            None => vec![],
467        };
468
469        // Convert prompt to user message
470        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
471
472        // Convert existing chat history
473        let chat_history: Vec<openai::Message> = completion_request
474            .chat_history
475            .into_iter()
476            .map(|message| message.try_into())
477            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
478            .into_iter()
479            .flatten()
480            .collect();
481
482        // Combine all messages into a single history
483        full_history.extend(chat_history);
484        full_history.extend(prompt);
485
486        let request = if completion_request.tools.is_empty() {
487            json!({
488                "model": self.model,
489                "messages": full_history,
490                "temperature": completion_request.temperature,
491            })
492        } else {
493            json!({
494                "model": self.model,
495                "messages": full_history,
496                "temperature": completion_request.temperature,
497                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
498                "tool_choice": "auto",
499            })
500        };
501
502        let request = if let Some(params) = completion_request.additional_params {
503            json_utils::merge(request, params)
504        } else {
505            request
506        };
507
508        Ok(request)
509    }
510}
511
512impl completion::CompletionModel for CompletionModel {
513    type Response = openai::CompletionResponse;
514
515    #[cfg_attr(feature = "worker", worker::send)]
516    async fn completion(
517        &self,
518        completion_request: CompletionRequest,
519    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
520        let request = self.create_completion_request(completion_request)?;
521
522        let response = self
523            .client
524            .post_chat_completion(&self.model)
525            .json(&request)
526            .send()
527            .await?;
528
529        if response.status().is_success() {
530            let t = response.text().await?;
531            tracing::debug!(target: "rig", "Azure completion error: {}", t);
532
533            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
534                ApiResponse::Ok(response) => {
535                    tracing::info!(target: "rig",
536                        "Azure completion token usage: {:?}",
537                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
538                    );
539                    response.try_into()
540                }
541                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
542            }
543        } else {
544            Err(CompletionError::ProviderError(response.text().await?))
545        }
546    }
547}
548
549// -----------------------------------------------------
550// Azure OpenAI Streaming API
551// -----------------------------------------------------
552impl StreamingCompletionModel for CompletionModel {
553    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
554        let mut request = self.create_completion_request(request)?;
555
556        request = merge(request, json!({"stream": true}));
557
558        let builder = self
559            .client
560            .post_chat_completion(self.model.as_str())
561            .json(&request);
562
563        send_compatible_streaming_request(builder).await
564    }
565}
566
567// ================================================================
568// Azure OpenAI Transcription API
569// ================================================================
570
/// Transcription model handle tied to a specific Azure deployment.
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the Azure deployment serving this transcription model
    /// (e.g. a Whisper deployment — the previous example here cited a chat model).
    pub model: String,
}
577
578impl TranscriptionModel {
579    pub fn new(client: Client, model: &str) -> Self {
580        Self {
581            client,
582            model: model.to_string(),
583        }
584    }
585}
586
587impl transcription::TranscriptionModel for TranscriptionModel {
588    type Response = TranscriptionResponse;
589
590    #[cfg_attr(feature = "worker", worker::send)]
591    async fn transcription(
592        &self,
593        request: transcription::TranscriptionRequest,
594    ) -> Result<
595        transcription::TranscriptionResponse<Self::Response>,
596        transcription::TranscriptionError,
597    > {
598        let data = request.data;
599
600        let mut body = reqwest::multipart::Form::new().part(
601            "file",
602            Part::bytes(data).file_name(request.filename.clone()),
603        );
604
605        if let Some(prompt) = request.prompt {
606            body = body.text("prompt", prompt.clone());
607        }
608
609        if let Some(ref temperature) = request.temperature {
610            body = body.text("temperature", temperature.to_string());
611        }
612
613        if let Some(ref additional_params) = request.additional_params {
614            for (key, value) in additional_params
615                .as_object()
616                .expect("Additional Parameters to OpenAI Transcription should be a map")
617            {
618                body = body.text(key.to_owned(), value.to_string());
619            }
620        }
621
622        let response = self
623            .client
624            .post_transcription(&self.model)
625            .multipart(body)
626            .send()
627            .await?;
628
629        if response.status().is_success() {
630            match response
631                .json::<ApiResponse<TranscriptionResponse>>()
632                .await?
633            {
634                ApiResponse::Ok(response) => response.try_into(),
635                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
636                    api_error_response.message,
637                )),
638            }
639        } else {
640            Err(TranscriptionError::ProviderError(response.text().await?))
641        }
642    }
643}
644
// Integration tests — `#[ignore]`d by default because they hit the live Azure
// API and need AZURE_API_KEY/AZURE_TOKEN, AZURE_API_VERSION, and
// AZURE_ENDPOINT set in the environment.
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    /// Round-trips a single string through the embeddings endpoint.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    /// Sends a minimal chat-completion request with a system preamble.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: vec![],
                prompt: "Hello, world!".into(),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}