// rig/providers/azure.rs

1//! Azure OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::azure;
6//!
7//! let client = azure::Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
8//!
9//! let gpt4o = client.completion_model(azure::GPT_4O);
10//! ```
11use crate::{
12    agent::AgentBuilder,
13    completion::{self, CompletionError, CompletionRequest},
14    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
15    extractor::ExtractorBuilder,
16    json_utils,
17    providers::openai,
18    Embed,
19};
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize};
22use serde_json::json;
23
24// ================================================================
25// Main Azure OpenAI Client
26// ================================================================
27
/// Azure OpenAI API client.
///
/// Stores the API version and endpoint used to build per-deployment request
/// URLs; the `api-key` auth header is baked into `http_client` as a default
/// header (see `Client::new`).
#[derive(Clone)]
pub struct Client {
    // API version appended to every request URL as the `api-version` query parameter.
    api_version: String,
    // Base endpoint, e.g. https://{your-resource-name}.openai.azure.com
    azure_endpoint: String,
    // Pre-configured HTTP client carrying the `api-key` default header.
    http_client: reqwest::Client,
}
34
35impl Client {
36    /// Creates a new Azure OpenAI client.
37    ///
38    /// # Arguments
39    ///
40    /// * `api_key` - Azure OpenAI API key required for authentication
41    /// * `api_version` - API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview)
42    /// * `azure_endpoint` - Azure OpenAI endpoint URL, for example: https://{your-resource-name}.openai.azure.com
43    pub fn new(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
44        Self {
45            api_version: api_version.to_string(),
46            azure_endpoint: azure_endpoint.to_string(),
47            http_client: reqwest::Client::builder()
48                .default_headers({
49                    let mut headers = reqwest::header::HeaderMap::new();
50                    headers.insert("api-key", api_key.parse().expect("API key should parse"));
51                    headers
52                })
53                .build()
54                .expect("Azure OpenAI reqwest client should build"),
55        }
56    }
57
58    /// Create a new Azure OpenAI client from the `AZURE_API_KEY`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
59    /// Panics if these environment variables are not set.
60    pub fn from_env() -> Self {
61        let api_key = std::env::var("AZURE_API_KEY").expect("AZURE_API_KEY not set");
62        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
63        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
64        Self::new(&api_key, &api_version, &azure_endpoint)
65    }
66
67    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
68        let url = format!(
69            "{}/openai/deployments/{}/embeddings?api-version={}",
70            self.azure_endpoint, deployment_id, self.api_version
71        )
72        .replace("//", "/");
73        self.http_client.post(url)
74    }
75
76    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
77        let url = format!(
78            "{}/openai/deployments/{}/chat/completions?api-version={}",
79            self.azure_endpoint, deployment_id, self.api_version
80        )
81        .replace("//", "/");
82        self.http_client.post(url)
83    }
84
85    /// Create an embedding model with the given name.
86    /// Note: default embedding dimension of 0 will be used if model is not known.
87    /// If this is the case, it's better to use function `embedding_model_with_ndims`
88    ///
89    /// # Example
90    /// ```
91    /// use rig::providers::azure::{Client, self};
92    ///
93    /// // Initialize the Azure OpenAI client
94    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
95    ///
96    /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE);
97    /// ```
98    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
99        let ndims = match model {
100            TEXT_EMBEDDING_3_LARGE => 3072,
101            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
102            _ => 0,
103        };
104        EmbeddingModel::new(self.clone(), model, ndims)
105    }
106
107    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
108    ///
109    /// # Example
110    /// ```
111    /// use rig::providers::azure::{Client, self};
112    ///
113    /// // Initialize the Azure OpenAI client
114    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
115    ///
116    /// let embedding_model = azure.embedding_model("model-unknown-to-rig", 3072);
117    /// ```
118    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
119        EmbeddingModel::new(self.clone(), model, ndims)
120    }
121
122    /// Create an embedding builder with the given embedding model.
123    ///
124    /// # Example
125    /// ```
126    /// use rig::providers::azure::{Client, self};
127    ///
128    /// // Initialize the Azure OpenAI client
129    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
130    ///
131    /// let embeddings = azure.embeddings(azure::TEXT_EMBEDDING_3_LARGE)
132    ///     .simple_document("doc0", "Hello, world!")
133    ///     .simple_document("doc1", "Goodbye, world!")
134    ///     .build()
135    ///     .await
136    ///     .expect("Failed to embed documents");
137    /// ```
138    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
139        EmbeddingsBuilder::new(self.embedding_model(model))
140    }
141
142    /// Create a completion model with the given name.
143    ///
144    /// # Example
145    /// ```
146    /// use rig::providers::azure::{Client, self};
147    ///
148    /// // Initialize the Azure OpenAI client
149    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
150    ///
151    /// let gpt4 = azure.completion_model(azure::GPT_4);
152    /// ```
153    pub fn completion_model(&self, model: &str) -> CompletionModel {
154        CompletionModel::new(self.clone(), model)
155    }
156
157    /// Create an agent builder with the given completion model.
158    ///
159    /// # Example
160    /// ```
161    /// use rig::providers::azure::{Client, self};
162    ///
163    /// // Initialize the Azure OpenAI client
164    /// let azure = Client::new("YOUR_API_KEY", "YOUR_API_VERSION", "YOUR_ENDPOINT");
165    ///
166    /// let agent = azure.agent(azure::GPT_4)
167    ///    .preamble("You are comedian AI with a mission to make people laugh.")
168    ///    .temperature(0.0)
169    ///    .build();
170    /// ```
171    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
172        AgentBuilder::new(self.completion_model(model))
173    }
174
175    /// Create an extractor builder with the given completion model.
176    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
177        &self,
178        model: &str,
179    ) -> ExtractorBuilder<T, CompletionModel> {
180        ExtractorBuilder::new(self.completion_model(model))
181    }
182}
183
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
188
/// Untagged response wrapper: deserialization first tries the success
/// payload `T`, then falls back to the error payload.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
195
196// ================================================================
197// Azure OpenAI Embedding API
198// ================================================================
199/// `text-embedding-3-large` embedding model
200pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
201/// `text-embedding-3-small` embedding model
202pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
203/// `text-embedding-ada-002` embedding model
204pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
205
/// Successful response body from the embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
213
214impl From<ApiErrorResponse> for EmbeddingError {
215    fn from(err: ApiErrorResponse) -> Self {
216        EmbeddingError::ProviderError(err.message)
217    }
218}
219
220impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
221    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
222        match value {
223            ApiResponse::Ok(response) => Ok(response),
224            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
225        }
226    }
227}
228
/// A single embedding vector and its position in the input batch.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
235
/// Token usage accounting reported by the API.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
241
242impl std::fmt::Display for Usage {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        write!(
245            f,
246            "Prompt tokens: {} Total tokens: {}",
247            self.prompt_tokens, self.total_tokens
248        )
249    }
250}
251
/// Azure OpenAI embedding model bound to a client and a deployment name.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    // Embedding dimensionality; 0 when the model is unknown to rig
    // (see `Client::embedding_model`).
    ndims: usize,
}
258
impl embeddings::EmbeddingModel for EmbeddingModel {
    /// Maximum number of documents accepted in a single embeddings request.
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed all `documents` in one batched API call.
    ///
    /// Returns one `Embedding` per input document. Errors if the HTTP call
    /// fails, the API reports an error payload, or the response contains a
    /// different number of vectors than inputs.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize up front: we need the collection for serialization,
        // the length check, and the final zip.
        let documents = documents.into_iter().collect::<Vec<_>>();

        // The deployment id doubles as the model selector in the URL, so the
        // body only carries the input texts.
        let response = self
            .client
            .post_embedding(&self.model)
            .json(&json!({
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // Guard against a mismatched batch before zipping.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each returned vector with its source document.
                    // NOTE(review): this assumes the API returns vectors in
                    // input order rather than consulting `data[i].index`.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx with an unparsed body: surface the raw text.
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}
313
314impl EmbeddingModel {
315    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
316        Self {
317            client,
318            model: model.to_string(),
319            ndims,
320        }
321    }
322}
323
324// ================================================================
325// Azure OpenAI Completion API
326// ================================================================
327/// `o1` completion model
328pub const O1: &str = "o1";
329/// `o1-preview` completion model
330pub const O1_PREVIEW: &str = "o1-preview";
331/// `o1-mini` completion model
332pub const O1_MINI: &str = "o1-mini";
333/// `gpt-4o` completion model
334pub const GPT_4O: &str = "gpt-4o";
335/// `gpt-4o-mini` completion model
336pub const GPT_4O_MINI: &str = "gpt-4o-mini";
337/// `gpt-4o-realtime-preview` completion model
338pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
339/// `gpt-4-turbo` completion model
340pub const GPT_4_TURBO: &str = "gpt-4";
341/// `gpt-4` completion model
342pub const GPT_4: &str = "gpt-4";
343/// `gpt-4-32k` completion model
344pub const GPT_4_32K: &str = "gpt-4-32k";
345/// `gpt-4-32k` completion model
346pub const GPT_4_32K_0613: &str = "gpt-4-32k";
347/// `gpt-3.5-turbo` completion model
348pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
349/// `gpt-3.5-turbo-instruct` completion model
350pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
351/// `gpt-3.5-turbo-16k` completion model
352pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
353
/// Azure OpenAI completion model bound to a client and a deployment name.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}
360
361impl CompletionModel {
362    pub fn new(client: Client, model: &str) -> Self {
363        Self {
364            client,
365            model: model.to_string(),
366        }
367    }
368}
369
370impl completion::CompletionModel for CompletionModel {
371    type Response = openai::CompletionResponse;
372
373    #[cfg_attr(feature = "worker", worker::send)]
374    async fn completion(
375        &self,
376        completion_request: CompletionRequest,
377    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
378        // Add preamble to chat history (if available)
379        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
380            Some(preamble) => vec![openai::Message::system(preamble)],
381            None => vec![],
382        };
383
384        // Convert prompt to user message
385        let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
386
387        // Convert existing chat history
388        let chat_history: Vec<openai::Message> = completion_request
389            .chat_history
390            .into_iter()
391            .map(|message| message.try_into())
392            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
393            .into_iter()
394            .flatten()
395            .collect();
396
397        // Combine all messages into a single history
398        full_history.extend(chat_history);
399        full_history.extend(prompt);
400
401        let request = if completion_request.tools.is_empty() {
402            json!({
403                "model": self.model,
404                "messages": full_history,
405                "temperature": completion_request.temperature,
406            })
407        } else {
408            json!({
409                "model": self.model,
410                "messages": full_history,
411                "temperature": completion_request.temperature,
412                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
413                "tool_choice": "auto",
414            })
415        };
416
417        let response = self
418            .client
419            .post_chat_completion(&self.model)
420            .json(
421                &if let Some(params) = completion_request.additional_params {
422                    json_utils::merge(request, params)
423                } else {
424                    request
425                },
426            )
427            .send()
428            .await?;
429
430        if response.status().is_success() {
431            let t = response.text().await?;
432            tracing::debug!(target: "rig", "Azure completion error: {}", t);
433
434            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
435                ApiResponse::Ok(response) => {
436                    tracing::info!(target: "rig",
437                        "Azure completion token usage: {:?}",
438                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
439                    );
440                    response.try_into()
441                }
442                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
443            }
444        } else {
445            Err(CompletionError::ProviderError(response.text().await?))
446        }
447    }
448}
449
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    /// Live smoke test against a real Azure deployment. `#[ignore]`d because
    /// it needs `AZURE_API_KEY` / `AZURE_API_VERSION` / `AZURE_ENDPOINT` set
    /// and makes a network call; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    /// Live smoke test for chat completions; same env-var and network
    /// requirements as `test_azure_embedding`.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: vec![],
                prompt: "Hello, world!".into(),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}