1use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
13
14use crate::json_utils::merge;
15use crate::streaming::StreamingCompletionResponse;
16use crate::{
17 completion::{self, CompletionError, CompletionRequest},
18 embeddings::{self, EmbeddingError},
19 json_utils,
20 providers::openai,
21 transcription::{self, TranscriptionError},
22};
23use reqwest::multipart::Part;
24use serde::Deserialize;
25use serde_json::json;
/// Azure OpenAI API client.
///
/// Authentication is baked into `http_client` as default headers at
/// construction time (see `Client::new`), so request builders below never
/// handle credentials directly.
#[derive(Debug, Clone)]
pub struct Client {
    // API version placed in every request's `api-version` query parameter.
    api_version: String,
    // Base endpoint of the Azure OpenAI resource (scheme + host).
    azure_endpoint: String,
    // HTTP client pre-configured with the auth default headers.
    http_client: reqwest::Client,
}
36
/// Credential used to authenticate against Azure OpenAI.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// API key, sent via the `api-key` header.
    ApiKey(String),
    /// Bearer token, sent via the `Authorization: Bearer …` header.
    Token(String),
}
42
/// Converts a bare string into an auth value.
///
/// NOTE(review): a plain `String` is interpreted as a bearer *token*, not an
/// API key — callers holding an API key must construct
/// `AzureOpenAIAuth::ApiKey` explicitly. Confirm this default is intended.
impl From<String> for AzureOpenAIAuth {
    fn from(token: String) -> Self {
        AzureOpenAIAuth::Token(token)
    }
}
48
49impl Client {
50 pub fn new(auth: impl Into<AzureOpenAIAuth>, api_version: &str, azure_endpoint: &str) -> Self {
58 let mut headers = reqwest::header::HeaderMap::new();
59 match auth.into() {
60 AzureOpenAIAuth::ApiKey(api_key) => {
61 headers.insert("api-key", api_key.parse().expect("API key should parse"));
62 }
63 AzureOpenAIAuth::Token(token) => {
64 headers.insert(
65 "Authorization",
66 format!("Bearer {token}")
67 .parse()
68 .expect("Token should parse"),
69 );
70 }
71 }
72
73 Self {
74 api_version: api_version.to_string(),
75 azure_endpoint: azure_endpoint.to_string(),
76 http_client: reqwest::Client::builder()
77 .default_headers(headers)
78 .build()
79 .expect("Azure OpenAI reqwest client should build"),
80 }
81 }
82
83 pub fn from_api_key(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
91 Self::new(
92 AzureOpenAIAuth::ApiKey(api_key.to_string()),
93 api_version,
94 azure_endpoint,
95 )
96 }
97
98 pub fn from_token(token: &str, api_version: &str, azure_endpoint: &str) -> Self {
106 Self::new(
107 AzureOpenAIAuth::Token(token.to_string()),
108 api_version,
109 azure_endpoint,
110 )
111 }
112
113 fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
114 let url = format!(
115 "{}/openai/deployments/{}/embeddings?api-version={}",
116 self.azure_endpoint, deployment_id, self.api_version
117 )
118 .replace("//", "/");
119 self.http_client.post(url)
120 }
121
122 fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
123 let url = format!(
124 "{}/openai/deployments/{}/chat/completions?api-version={}",
125 self.azure_endpoint, deployment_id, self.api_version
126 )
127 .replace("//", "/");
128 self.http_client.post(url)
129 }
130
131 fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
132 let url = format!(
133 "{}/openai/deployments/{}/audio/translations?api-version={}",
134 self.azure_endpoint, deployment_id, self.api_version
135 )
136 .replace("//", "/");
137 self.http_client.post(url)
138 }
139
140 #[cfg(feature = "image")]
141 fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
142 let url = format!(
143 "{}/openai/deployments/{}/images/generations?api-version={}",
144 self.azure_endpoint, deployment_id, self.api_version
145 )
146 .replace("//", "/");
147 self.http_client.post(url)
148 }
149
150 #[cfg(feature = "audio")]
151 fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
152 let url = format!(
153 "{}/openai/deployments/{}/audio/speech?api-version={}",
154 self.azure_endpoint, deployment_id, self.api_version
155 )
156 .replace("//", "/");
157 self.http_client.post(url)
158 }
159}
160
161impl ProviderClient for Client {
162 fn from_env() -> Self {
164 let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
165 AzureOpenAIAuth::ApiKey(api_key)
166 } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
167 AzureOpenAIAuth::Token(token)
168 } else {
169 panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
170 };
171
172 let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
173 let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
174
175 Self::new(auth, &api_version, &azure_endpoint)
176 }
177}
178
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Creates a completion model handle for the given Azure deployment id.
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
197
198impl EmbeddingsClient for Client {
199 type EmbeddingModel = EmbeddingModel;
200
201 fn embedding_model(&self, model: &str) -> EmbeddingModel {
215 let ndims = match model {
216 TEXT_EMBEDDING_3_LARGE => 3072,
217 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
218 _ => 0,
219 };
220 EmbeddingModel::new(self.clone(), model, ndims)
221 }
222
223 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
235 EmbeddingModel::new(self.clone(), model, ndims)
236 }
237}
238
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;

    /// Creates a transcription model handle for the given Azure deployment id.
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
257
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
262
/// Untagged response wrapper: deserialization first attempts the success
/// payload `T`, then falls back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
269
/// `text-embedding-3-large` embedding model (3072 default dimensions; see
/// `embedding_model`).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 default dimensions).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 default dimensions).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
279
/// Successful response body of the Azure OpenAI embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
287
/// Maps a provider error payload onto the generic embedding error type.
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}
293
294impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
295 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
296 match value {
297 ApiResponse::Ok(response) => Ok(response),
298 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
299 }
300 }
301}
302
/// One embedding vector in an `EmbeddingResponse`, paired with its input index.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
309
/// Token usage reported by the embeddings endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
315
316impl std::fmt::Display for Usage {
317 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 write!(
319 f,
320 "Prompt tokens: {} Total tokens: {}",
321 self.prompt_tokens, self.total_tokens
322 )
323 }
324}
325
/// Azure OpenAI embedding model bound to a specific deployment.
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Deployment id / model name used in the request URL.
    pub model: String,
    // Dimension count reported by `ndims()`.
    ndims: usize,
}
332
333impl embeddings::EmbeddingModel for EmbeddingModel {
334 const MAX_DOCUMENTS: usize = 1024;
335
336 fn ndims(&self) -> usize {
337 self.ndims
338 }
339
340 #[cfg_attr(feature = "worker", worker::send)]
341 async fn embed_texts(
342 &self,
343 documents: impl IntoIterator<Item = String>,
344 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
345 let documents = documents.into_iter().collect::<Vec<_>>();
346
347 let response = self
348 .client
349 .post_embedding(&self.model)
350 .json(&json!({
351 "input": documents,
352 }))
353 .send()
354 .await?;
355
356 if response.status().is_success() {
357 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
358 ApiResponse::Ok(response) => {
359 tracing::info!(target: "rig",
360 "Azure embedding token usage: {}",
361 response.usage
362 );
363
364 if response.data.len() != documents.len() {
365 return Err(EmbeddingError::ResponseError(
366 "Response data length does not match input length".into(),
367 ));
368 }
369
370 Ok(response
371 .data
372 .into_iter()
373 .zip(documents.into_iter())
374 .map(|(embedding, document)| embeddings::Embedding {
375 document,
376 vec: embedding.embedding,
377 })
378 .collect())
379 }
380 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
381 }
382 } else {
383 Err(EmbeddingError::ProviderError(response.text().await?))
384 }
385 }
386}
387
388impl EmbeddingModel {
389 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
390 Self {
391 client,
392 model: model.to_string(),
393 ndims,
394 }
395 }
396}
397
/// `o1` completion model.
pub const O1: &str = "o1";
/// `o1-preview` completion model.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model.
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model.
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
// NOTE(review): same literal as `GPT_4` ("gpt-4") — confirm this should not
// be "gpt-4-turbo".
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model.
pub const GPT_4_32K: &str = "gpt-4-32k";
// NOTE(review): same literal as `GPT_4_32K` — confirm this should not be
// "gpt-4-32k-0613".
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model.
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
427
/// Azure OpenAI chat-completion model bound to a specific deployment.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Deployment id / model name used in the request URL and payload.
    pub model: String,
}
434
435impl CompletionModel {
436 pub fn new(client: Client, model: &str) -> Self {
437 Self {
438 client,
439 model: model.to_string(),
440 }
441 }
442
443 fn create_completion_request(
444 &self,
445 completion_request: CompletionRequest,
446 ) -> Result<serde_json::Value, CompletionError> {
447 let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
448 Some(preamble) => vec![openai::Message::system(preamble)],
449 None => vec![],
450 };
451 if let Some(docs) = completion_request.normalized_documents() {
452 let docs: Vec<openai::Message> = docs.try_into()?;
453 full_history.extend(docs);
454 }
455 let chat_history: Vec<openai::Message> = completion_request
456 .chat_history
457 .into_iter()
458 .map(|message| message.try_into())
459 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
460 .into_iter()
461 .flatten()
462 .collect();
463
464 full_history.extend(chat_history);
465
466 let request = if completion_request.tools.is_empty() {
467 json!({
468 "model": self.model,
469 "messages": full_history,
470 "temperature": completion_request.temperature,
471 })
472 } else {
473 json!({
474 "model": self.model,
475 "messages": full_history,
476 "temperature": completion_request.temperature,
477 "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
478 "tool_choice": "auto",
479 })
480 };
481
482 let request = if let Some(params) = completion_request.additional_params {
483 json_utils::merge(request, params)
484 } else {
485 request
486 };
487
488 Ok(request)
489 }
490}
491
492impl completion::CompletionModel for CompletionModel {
493 type Response = openai::CompletionResponse;
494 type StreamingResponse = openai::StreamingCompletionResponse;
495
496 #[cfg_attr(feature = "worker", worker::send)]
497 async fn completion(
498 &self,
499 completion_request: CompletionRequest,
500 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
501 let request = self.create_completion_request(completion_request)?;
502
503 let response = self
504 .client
505 .post_chat_completion(&self.model)
506 .json(&request)
507 .send()
508 .await?;
509
510 if response.status().is_success() {
511 let t = response.text().await?;
512 tracing::debug!(target: "rig", "Azure completion error: {}", t);
513
514 match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
515 ApiResponse::Ok(response) => {
516 tracing::info!(target: "rig",
517 "Azure completion token usage: {:?}",
518 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
519 );
520 response.try_into()
521 }
522 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
523 }
524 } else {
525 Err(CompletionError::ProviderError(response.text().await?))
526 }
527 }
528
529 #[cfg_attr(feature = "worker", worker::send)]
530 async fn stream(
531 &self,
532 request: CompletionRequest,
533 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
534 let mut request = self.create_completion_request(request)?;
535
536 request = merge(
537 request,
538 json!({"stream": true, "stream_options": {"include_usage": true}}),
539 );
540
541 let builder = self
542 .client
543 .post_chat_completion(self.model.as_str())
544 .json(&request);
545
546 send_compatible_streaming_request(builder).await
547 }
548}
549
/// Azure OpenAI audio transcription model bound to a specific deployment.
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Deployment id / model name used in the request URL.
    pub model: String,
}
560
561impl TranscriptionModel {
562 pub fn new(client: Client, model: &str) -> Self {
563 Self {
564 client,
565 model: model.to_string(),
566 }
567 }
568}
569
570impl transcription::TranscriptionModel for TranscriptionModel {
571 type Response = TranscriptionResponse;
572
573 #[cfg_attr(feature = "worker", worker::send)]
574 async fn transcription(
575 &self,
576 request: transcription::TranscriptionRequest,
577 ) -> Result<
578 transcription::TranscriptionResponse<Self::Response>,
579 transcription::TranscriptionError,
580 > {
581 let data = request.data;
582
583 let mut body = reqwest::multipart::Form::new().part(
584 "file",
585 Part::bytes(data).file_name(request.filename.clone()),
586 );
587
588 if let Some(prompt) = request.prompt {
589 body = body.text("prompt", prompt.clone());
590 }
591
592 if let Some(ref temperature) = request.temperature {
593 body = body.text("temperature", temperature.to_string());
594 }
595
596 if let Some(ref additional_params) = request.additional_params {
597 for (key, value) in additional_params
598 .as_object()
599 .expect("Additional Parameters to OpenAI Transcription should be a map")
600 {
601 body = body.text(key.to_owned(), value.to_string());
602 }
603 }
604
605 let response = self
606 .client
607 .post_transcription(&self.model)
608 .multipart(body)
609 .send()
610 .await?;
611
612 if response.status().is_success() {
613 match response
614 .json::<ApiResponse<TranscriptionResponse>>()
615 .await?
616 {
617 ApiResponse::Ok(response) => response.try_into(),
618 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
619 api_error_response.message,
620 )),
621 }
622 } else {
623 Err(TranscriptionError::ProviderError(response.text().await?))
624 }
625 }
626}
627
628#[cfg(feature = "image")]
632pub use image_generation::*;
#[cfg(feature = "image")]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use serde_json::json;

    /// Azure OpenAI image generation model bound to a specific deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        pub model: String,
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Requests an image as base64 (`b64_json`) and converts the response.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let payload = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .json(&payload)
                .send()
                .await?;

            // Non-2xx: surface status + raw body as a provider error.
            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let body_text = response.text().await?;

            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&body_text)? {
                ApiResponse::Ok(ok) => ok.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Creates an image generation model handle for the given deployment id.
        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
697use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
702#[cfg(feature = "audio")]
703pub use audio_generation::*;
704
#[cfg(feature = "audio")]
mod audio_generation {
    use super::Client;
    use crate::audio_generation;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use bytes::Bytes;
    use serde_json::json;

    /// Azure OpenAI text-to-speech model bound to a specific deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = Bytes;

        /// Generates speech audio and returns the raw bytes.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            // Fixed: `post_audio_generation` takes the deployment id; the
            // previous call passed the literal route "/audio/speech", which
            // produced a malformed `deployments//audio/speech/...` URL.
            let response = self
                .client
                .post_audio_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let bytes = response.bytes().await?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Creates an audio generation model handle for the given deployment id.
        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}
771
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;
    use crate::OneOrMany;

    // Live integration test against a real Azure deployment; requires the
    // AZURE_* environment variables, hence `#[ignore]`.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    // Live integration test for chat completion; requires AZURE_* env vars.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}