1use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
13use crate::json_utils::merge;
14use crate::streaming::{StreamingCompletionModel, StreamingResult};
15use crate::{
16 agent::AgentBuilder,
17 completion::{self, CompletionError, CompletionRequest},
18 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
19 extractor::ExtractorBuilder,
20 json_utils,
21 providers::openai,
22 transcription::{self, TranscriptionError},
23 Embed,
24};
25use reqwest::multipart::Part;
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use serde_json::json;
29
/// Azure OpenAI client: API version, resource endpoint, and a reqwest client
/// pre-configured with the auth header.
#[derive(Clone)]
pub struct Client {
    // API version string appended to every request as the `api-version` query parameter.
    api_version: String,
    // Azure resource endpoint, e.g. `https://{resource}.openai.azure.com`.
    azure_endpoint: String,
    // HTTP client with the auth header (api-key or Bearer token) set as a default header.
    http_client: reqwest::Client,
}
40
/// Credential used to authenticate against Azure OpenAI.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Static API key, sent in the `api-key` header.
    ApiKey(String),
    /// Token sent as `Authorization: Bearer {token}`.
    Token(String),
}
46
47impl From<String> for AzureOpenAIAuth {
48 fn from(token: String) -> Self {
49 AzureOpenAIAuth::Token(token)
50 }
51}
52
53impl Client {
54 pub fn new(auth: impl Into<AzureOpenAIAuth>, api_version: &str, azure_endpoint: &str) -> Self {
62 let mut headers = reqwest::header::HeaderMap::new();
63 match auth.into() {
64 AzureOpenAIAuth::ApiKey(api_key) => {
65 headers.insert("api-key", api_key.parse().expect("API key should parse"));
66 }
67 AzureOpenAIAuth::Token(token) => {
68 headers.insert(
69 "Authorization",
70 format!("Bearer {}", token)
71 .parse()
72 .expect("Token should parse"),
73 );
74 }
75 }
76
77 Self {
78 api_version: api_version.to_string(),
79 azure_endpoint: azure_endpoint.to_string(),
80 http_client: reqwest::Client::builder()
81 .default_headers(headers)
82 .build()
83 .expect("Azure OpenAI reqwest client should build"),
84 }
85 }
86
87 pub fn from_api_key(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
95 Self::new(
96 AzureOpenAIAuth::ApiKey(api_key.to_string()),
97 api_version,
98 azure_endpoint,
99 )
100 }
101
102 pub fn from_token(token: &str, api_version: &str, azure_endpoint: &str) -> Self {
110 Self::new(
111 AzureOpenAIAuth::Token(token.to_string()),
112 api_version,
113 azure_endpoint,
114 )
115 }
116
117 pub fn from_env() -> Self {
119 let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
120 AzureOpenAIAuth::ApiKey(api_key)
121 } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
122 AzureOpenAIAuth::Token(token)
123 } else {
124 panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
125 };
126
127 let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
128 let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
129
130 Self::new(auth, &api_version, &azure_endpoint)
131 }
132
133 fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
134 let url = format!(
135 "{}/openai/deployments/{}/embeddings?api-version={}",
136 self.azure_endpoint, deployment_id, self.api_version
137 )
138 .replace("//", "/");
139 self.http_client.post(url)
140 }
141
142 fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
143 let url = format!(
144 "{}/openai/deployments/{}/chat/completions?api-version={}",
145 self.azure_endpoint, deployment_id, self.api_version
146 )
147 .replace("//", "/");
148 self.http_client.post(url)
149 }
150
151 fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
152 let url = format!(
153 "{}/openai/deployments/{}/audio/translations?api-version={}",
154 self.azure_endpoint, deployment_id, self.api_version
155 )
156 .replace("//", "/");
157 self.http_client.post(url)
158 }
159
160 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
174 let ndims = match model {
175 TEXT_EMBEDDING_3_LARGE => 3072,
176 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
177 _ => 0,
178 };
179 EmbeddingModel::new(self.clone(), model, ndims)
180 }
181
182 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
194 EmbeddingModel::new(self.clone(), model, ndims)
195 }
196
197 pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
214 EmbeddingsBuilder::new(self.embedding_model(model))
215 }
216
217 pub fn completion_model(&self, model: &str) -> CompletionModel {
229 CompletionModel::new(self.clone(), model)
230 }
231
232 pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
244 TranscriptionModel::new(self.clone(), model)
245 }
246
247 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
262 AgentBuilder::new(self.completion_model(model))
263 }
264
265 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
267 &self,
268 model: &str,
269 ) -> ExtractorBuilder<T, CompletionModel> {
270 ExtractorBuilder::new(self.completion_model(model))
271 }
272}
273
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
278
/// Untagged response wrapper: deserialization tries the success shape `T`
/// first and falls back to the provider error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
285
/// `text-embedding-3-large` embedding model (3072 dimensions).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 dimensions).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 dimensions).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
295
/// Response body of an embeddings request.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    // One entry per input document, in request order (paired with inputs by `zip`).
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
303
304impl From<ApiErrorResponse> for EmbeddingError {
305 fn from(err: ApiErrorResponse) -> Self {
306 EmbeddingError::ProviderError(err.message)
307 }
308}
309
310impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
311 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
312 match value {
313 ApiResponse::Ok(response) => Ok(response),
314 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
315 }
316 }
317}
318
/// A single embedding vector from an embeddings response.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    // Position of the corresponding input document in the request.
    pub index: usize,
}
325
/// Token-usage accounting returned by the embeddings API.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
331
332impl std::fmt::Display for Usage {
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 write!(
335 f,
336 "Prompt tokens: {} Total tokens: {}",
337 self.prompt_tokens, self.total_tokens
338 )
339 }
340}
341
/// Azure OpenAI embedding model bound to a specific deployment.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    // Deployment name used in the request path.
    pub model: String,
    // Embedding dimensionality reported via `ndims()`; may be 0 for unknown models.
    ndims: usize,
}
348
impl embeddings::EmbeddingModel for EmbeddingModel {
    // Maximum number of documents accepted per request.
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds a batch of texts with a single request to the deployment's
    /// embeddings route.
    ///
    /// # Errors
    /// * `EmbeddingError::ProviderError` for API-level errors or non-success
    ///   HTTP statuses (body text is used as the message).
    /// * `EmbeddingError::ResponseError` if the response contains a different
    ///   number of embeddings than documents sent.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize so we can both send and later pair results with inputs.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let response = self
            .client
            .post_embedding(&self.model)
            .json(&json!({
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // Guard before zip: zip would silently truncate a mismatch.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each returned vector with its source document, in order.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}
403
impl EmbeddingModel {
    /// Creates an embedding model for the given deployment with the stated
    /// number of dimensions.
    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_string(),
            ndims,
        }
    }
}
413
/// `o1` completion model.
pub const O1: &str = "o1";
/// `o1-preview` completion model.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model.
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model.
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
// NOTE(review): GPT_4_TURBO resolves to "gpt-4" (same as GPT_4) and
// GPT_4_32K_0613 resolves to "gpt-4-32k" (same as GPT_4_32K). Confirm these
// duplicate deployment names are intentional rather than copy-paste slips.
/// `gpt-4` completion model (turbo alias).
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model (0613 alias).
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model.
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
443
/// Azure OpenAI chat-completion model bound to a specific deployment.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    // Deployment name used in the request path (and echoed in the JSON body).
    pub model: String,
}
450
451impl CompletionModel {
452 pub fn new(client: Client, model: &str) -> Self {
453 Self {
454 client,
455 model: model.to_string(),
456 }
457 }
458
459 fn create_completion_request(
460 &self,
461 completion_request: CompletionRequest,
462 ) -> Result<serde_json::Value, CompletionError> {
463 let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
465 Some(preamble) => vec![openai::Message::system(preamble)],
466 None => vec![],
467 };
468
469 let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;
471
472 let chat_history: Vec<openai::Message> = completion_request
474 .chat_history
475 .into_iter()
476 .map(|message| message.try_into())
477 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
478 .into_iter()
479 .flatten()
480 .collect();
481
482 full_history.extend(chat_history);
484 full_history.extend(prompt);
485
486 let request = if completion_request.tools.is_empty() {
487 json!({
488 "model": self.model,
489 "messages": full_history,
490 "temperature": completion_request.temperature,
491 })
492 } else {
493 json!({
494 "model": self.model,
495 "messages": full_history,
496 "temperature": completion_request.temperature,
497 "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
498 "tool_choice": "auto",
499 })
500 };
501
502 let request = if let Some(params) = completion_request.additional_params {
503 json_utils::merge(request, params)
504 } else {
505 request
506 };
507
508 Ok(request)
509 }
510}
511
512impl completion::CompletionModel for CompletionModel {
513 type Response = openai::CompletionResponse;
514
515 #[cfg_attr(feature = "worker", worker::send)]
516 async fn completion(
517 &self,
518 completion_request: CompletionRequest,
519 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
520 let request = self.create_completion_request(completion_request)?;
521
522 let response = self
523 .client
524 .post_chat_completion(&self.model)
525 .json(&request)
526 .send()
527 .await?;
528
529 if response.status().is_success() {
530 let t = response.text().await?;
531 tracing::debug!(target: "rig", "Azure completion error: {}", t);
532
533 match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
534 ApiResponse::Ok(response) => {
535 tracing::info!(target: "rig",
536 "Azure completion token usage: {:?}",
537 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
538 );
539 response.try_into()
540 }
541 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
542 }
543 } else {
544 Err(CompletionError::ProviderError(response.text().await?))
545 }
546 }
547}
548
impl StreamingCompletionModel for CompletionModel {
    /// Streams a completion: builds the same body as `completion`, merges in
    /// `"stream": true`, and delegates chunk handling to the shared
    /// OpenAI-compatible streaming helper.
    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
        let mut request = self.create_completion_request(request)?;

        request = merge(request, json!({"stream": true}));

        let builder = self
            .client
            .post_chat_completion(self.model.as_str())
            .json(&request);

        send_compatible_streaming_request(builder).await
    }
}
566
/// Azure OpenAI audio transcription model bound to a specific deployment.
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    // Deployment name used in the request path.
    pub model: String,
}
577
578impl TranscriptionModel {
579 pub fn new(client: Client, model: &str) -> Self {
580 Self {
581 client,
582 model: model.to_string(),
583 }
584 }
585}
586
587impl transcription::TranscriptionModel for TranscriptionModel {
588 type Response = TranscriptionResponse;
589
590 #[cfg_attr(feature = "worker", worker::send)]
591 async fn transcription(
592 &self,
593 request: transcription::TranscriptionRequest,
594 ) -> Result<
595 transcription::TranscriptionResponse<Self::Response>,
596 transcription::TranscriptionError,
597 > {
598 let data = request.data;
599
600 let mut body = reqwest::multipart::Form::new().part(
601 "file",
602 Part::bytes(data).file_name(request.filename.clone()),
603 );
604
605 if let Some(prompt) = request.prompt {
606 body = body.text("prompt", prompt.clone());
607 }
608
609 if let Some(ref temperature) = request.temperature {
610 body = body.text("temperature", temperature.to_string());
611 }
612
613 if let Some(ref additional_params) = request.additional_params {
614 for (key, value) in additional_params
615 .as_object()
616 .expect("Additional Parameters to OpenAI Transcription should be a map")
617 {
618 body = body.text(key.to_owned(), value.to_string());
619 }
620 }
621
622 let response = self
623 .client
624 .post_transcription(&self.model)
625 .multipart(body)
626 .send()
627 .await?;
628
629 if response.status().is_success() {
630 match response
631 .json::<ApiResponse<TranscriptionResponse>>()
632 .await?
633 {
634 ApiResponse::Ok(response) => response.try_into(),
635 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
636 api_error_response.message,
637 )),
638 }
639 } else {
640 Err(TranscriptionError::ProviderError(response.text().await?))
641 }
642 }
643}
644
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    // Live-API smoke test: requires AZURE_API_KEY/AZURE_TOKEN,
    // AZURE_API_VERSION, and AZURE_ENDPOINT; `#[ignore]`d so it never runs
    // in normal CI.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    // Live-API smoke test for chat completion; same env-var requirements.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: vec![],
                prompt: "Hello, world!".into(),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}