1use crate::{
12 agent::AgentBuilder,
13 completion::{self, CompletionError, CompletionRequest},
14 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
15 extractor::ExtractorBuilder,
16 json_utils,
17 providers::openai,
18 Embed,
19};
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize};
22use serde_json::json;
23
/// Client for the Azure OpenAI service.
///
/// Holds the REST API version, the resource endpoint, and a pre-configured
/// HTTP client that sends the `api-key` header on every request.
#[derive(Clone)]
pub struct Client {
    /// REST API version, appended as the `api-version` query parameter.
    api_version: String,
    /// Base URL of the Azure OpenAI resource.
    azure_endpoint: String,
    /// reqwest client with the `api-key` default header installed.
    http_client: reqwest::Client,
}
34
35impl Client {
36 pub fn new(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
44 Self {
45 api_version: api_version.to_string(),
46 azure_endpoint: azure_endpoint.to_string(),
47 http_client: reqwest::Client::builder()
48 .default_headers({
49 let mut headers = reqwest::header::HeaderMap::new();
50 headers.insert("api-key", api_key.parse().expect("API key should parse"));
51 headers
52 })
53 .build()
54 .expect("Azure OpenAI reqwest client should build"),
55 }
56 }
57
58 pub fn from_env() -> Self {
61 let api_key = std::env::var("AZURE_API_KEY").expect("AZURE_API_KEY not set");
62 let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
63 let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
64 Self::new(&api_key, &api_version, &azure_endpoint)
65 }
66
67 fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
68 let url = format!(
69 "{}/openai/deployments/{}/embeddings?api-version={}",
70 self.azure_endpoint, deployment_id, self.api_version
71 )
72 .replace("//", "/");
73 self.http_client.post(url)
74 }
75
76 fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
77 let url = format!(
78 "{}/openai/deployments/{}/chat/completions?api-version={}",
79 self.azure_endpoint, deployment_id, self.api_version
80 )
81 .replace("//", "/");
82 self.http_client.post(url)
83 }
84
85 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
99 let ndims = match model {
100 TEXT_EMBEDDING_3_LARGE => 3072,
101 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
102 _ => 0,
103 };
104 EmbeddingModel::new(self.clone(), model, ndims)
105 }
106
107 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
119 EmbeddingModel::new(self.clone(), model, ndims)
120 }
121
122 pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
139 EmbeddingsBuilder::new(self.embedding_model(model))
140 }
141
142 pub fn completion_model(&self, model: &str) -> CompletionModel {
154 CompletionModel::new(self.clone(), model)
155 }
156
157 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
172 AgentBuilder::new(self.completion_model(model))
173 }
174
175 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
177 &self,
178 model: &str,
179 ) -> ExtractorBuilder<T, CompletionModel> {
180 ExtractorBuilder::new(self.completion_model(model))
181 }
182}
183
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error message from the provider.
    message: String,
}
188
/// Untagged wrapper over API responses: deserialization first attempts the
/// success payload `T`, then falls back to the error payload.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
195
/// `text-embedding-3-large` embedding model (3072 dimensions).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 dimensions).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 dimensions).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
205
/// Response body of the Azure embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    /// One entry per input document.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    /// Token accounting reported by the provider.
    pub usage: Usage,
}
213
214impl From<ApiErrorResponse> for EmbeddingError {
215 fn from(err: ApiErrorResponse) -> Self {
216 EmbeddingError::ProviderError(err.message)
217 }
218}
219
220impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
221 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
222 match value {
223 ApiResponse::Ok(response) => Ok(response),
224 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
225 }
226 }
227}
228
/// A single embedding entry within an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    /// The embedding vector for one input document.
    pub embedding: Vec<f64>,
    /// Position of the corresponding document in the request input.
    pub index: usize,
}
235
/// Token-usage accounting returned by the provider.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
241
242impl std::fmt::Display for Usage {
243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 write!(
245 f,
246 "Prompt tokens: {} Total tokens: {}",
247 self.prompt_tokens, self.total_tokens
248 )
249 }
250}
251
/// Handle to an Azure OpenAI embedding deployment.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Azure deployment id used in the request URL.
    pub model: String,
    /// Dimensionality of the vectors this deployment produces.
    ndims: usize,
}
258
259impl embeddings::EmbeddingModel for EmbeddingModel {
260 const MAX_DOCUMENTS: usize = 1024;
261
262 fn ndims(&self) -> usize {
263 self.ndims
264 }
265
266 #[cfg_attr(feature = "worker", worker::send)]
267 async fn embed_texts(
268 &self,
269 documents: impl IntoIterator<Item = String>,
270 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
271 let documents = documents.into_iter().collect::<Vec<_>>();
272
273 let response = self
274 .client
275 .post_embedding(&self.model)
276 .json(&json!({
277 "input": documents,
278 }))
279 .send()
280 .await?;
281
282 if response.status().is_success() {
283 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
284 ApiResponse::Ok(response) => {
285 tracing::info!(target: "rig",
286 "Azure embedding token usage: {}",
287 response.usage
288 );
289
290 if response.data.len() != documents.len() {
291 return Err(EmbeddingError::ResponseError(
292 "Response data length does not match input length".into(),
293 ));
294 }
295
296 Ok(response
297 .data
298 .into_iter()
299 .zip(documents.into_iter())
300 .map(|(embedding, document)| embeddings::Embedding {
301 document,
302 vec: embedding.embedding,
303 })
304 .collect())
305 }
306 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
307 }
308 } else {
309 Err(EmbeddingError::ProviderError(response.text().await?))
310 }
311 }
312}
313
314impl EmbeddingModel {
315 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
316 Self {
317 client,
318 model: model.to_string(),
319 ndims,
320 }
321 }
322}
323
/// `o1` completion model.
pub const O1: &str = "o1";
/// `o1-preview` completion model.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model.
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model.
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// NOTE(review): value is "gpt-4", identical to [`GPT_4`] — looks like a
/// copy-paste; confirm whether "gpt-4-turbo" was intended.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// NOTE(review): value is "gpt-4-32k", identical to [`GPT_4_32K`] — confirm
/// whether the "gpt-4-32k-0613" snapshot name was intended.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model.
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
353
/// Handle to an Azure OpenAI chat-completion deployment.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Azure deployment id used in the request URL.
    pub model: String,
}
360
361impl CompletionModel {
362 pub fn new(client: Client, model: &str) -> Self {
363 Self {
364 client,
365 model: model.to_string(),
366 }
367 }
368}
369
370impl completion::CompletionModel for CompletionModel {
371 type Response = openai::CompletionResponse;
372
373 #[cfg_attr(feature = "worker", worker::send)]
374 async fn completion(
375 &self,
376 completion_request: CompletionRequest,
377 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
378 let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
380 Some(preamble) => vec![openai::Message::system(preamble)],
381 None => vec![],
382 };
383
384 let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
386
387 let chat_history: Vec<openai::Message> = completion_request
389 .chat_history
390 .into_iter()
391 .map(|message| message.try_into())
392 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
393 .into_iter()
394 .flatten()
395 .collect();
396
397 full_history.extend(chat_history);
399 full_history.extend(prompt);
400
401 let request = if completion_request.tools.is_empty() {
402 json!({
403 "model": self.model,
404 "messages": full_history,
405 "temperature": completion_request.temperature,
406 })
407 } else {
408 json!({
409 "model": self.model,
410 "messages": full_history,
411 "temperature": completion_request.temperature,
412 "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
413 "tool_choice": "auto",
414 })
415 };
416
417 let response = self
418 .client
419 .post_chat_completion(&self.model)
420 .json(
421 &if let Some(params) = completion_request.additional_params {
422 json_utils::merge(request, params)
423 } else {
424 request
425 },
426 )
427 .send()
428 .await?;
429
430 if response.status().is_success() {
431 let t = response.text().await?;
432 tracing::debug!(target: "rig", "Azure completion error: {}", t);
433
434 match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
435 ApiResponse::Ok(response) => {
436 tracing::info!(target: "rig",
437 "Azure completion token usage: {:?}",
438 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
439 );
440 response.try_into()
441 }
442 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
443 }
444 } else {
445 Err(CompletionError::ProviderError(response.text().await?))
446 }
447 }
448}
449
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    /// Integration test against the live embeddings endpoint; run manually
    /// with `cargo test -- --ignored` and the AZURE_* env vars set.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let model = Client::from_env().embedding_model(TEXT_EMBEDDING_3_SMALL);
        let result = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", result);
    }

    /// Integration test against the live chat-completions endpoint; run
    /// manually with `cargo test -- --ignored` and the AZURE_* env vars set.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let request = CompletionRequest {
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: vec![],
            prompt: "Hello, world!".into(),
            documents: vec![],
            max_tokens: Some(100),
            temperature: Some(0.0),
            tools: vec![],
            additional_params: None,
        };
        let response = Client::from_env()
            .completion_model(GPT_4O_MINI)
            .completion(request)
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", response);
    }
}