use super::openai::{TranscriptionResponse, send_compatible_streaming_request};

use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
use crate::json_utils::merge;
use crate::streaming::StreamingCompletionResponse;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    transcription::{self, TranscriptionError},
};
use reqwest::header::AUTHORIZATION;
use reqwest::multipart::Part;
use serde::Deserialize;
use serde_json::json;
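/// A client for the Azure OpenAI Service.
///
/// Requests are routed to `{azure_endpoint}/openai/deployments/{deployment-id}/...`
/// with the configured `api-version`, and are authenticated with either an
/// `api-key` header or an `Authorization: Bearer` token.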
#[derive(Clone)]
pub struct Client {
    api_version: String,
    azure_endpoint: String,
    auth: AzureOpenAIAuth,
    http_client: reqwest::Client,
}

impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("http_client", &self.http_client)
            .field("auth", &"<REDACTED>")
            .field("api_version", &self.api_version)
            .finish()
    }
}

#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl From<String> for AzureOpenAIAuth {
    fn from(token: String) -> Self {
        AzureOpenAIAuth::Token(token)
    }
}

impl AzureOpenAIAuth {
    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
        match self {
            AzureOpenAIAuth::ApiKey(api_key) => (
                "api-key".parse().expect("Header name should parse"),
                api_key.parse().expect("API key should parse"),
            ),
            AzureOpenAIAuth::Token(token) => (
                AUTHORIZATION,
                format!("Bearer {token}")
                    .parse()
                    .expect("Token should parse"),
            ),
        }
    }
}

impl Client {
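    /// Creates a new Azure OpenAI client.
    ///
    /// A minimal construction sketch (the endpoint and `api-version` below are
    /// placeholders; substitute your own resource values):
    /// ```ignore
    /// use rig::providers::azure::{AzureOpenAIAuth, Client};
    ///
    /// let azure = Client::new(
    ///     AzureOpenAIAuth::ApiKey("your-api-key".to_string()),
    ///     "2024-10-21",
    ///     "https://{your-resource-name}.openai.azure.com",
    /// );
    /// ```
    /// Note that a plain `String` converts into `AzureOpenAIAuth::Token`, so pass
    /// `AzureOpenAIAuth::ApiKey` explicitly when authenticating with an API key.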
    pub fn new(auth: impl Into<AzureOpenAIAuth>, api_version: &str, azure_endpoint: &str) -> Self {
        Self {
            api_version: api_version.to_string(),
            auth: auth.into(),
            azure_endpoint: azure_endpoint.to_string(),
            http_client: reqwest::Client::builder()
                .build()
                .expect("Azure OpenAI reqwest client should build"),
        }
    }

    /// Replaces the default HTTP client with a caller-supplied `reqwest::Client`.
    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
        self.http_client = client;

        self
    }

    /// Creates a client that authenticates with an API key (sent as the `api-key` header).
    pub fn from_api_key(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
        Self::new(
            AzureOpenAIAuth::ApiKey(api_key.to_string()),
            api_version,
            azure_endpoint,
        )
    }

    /// Creates a client that authenticates with a bearer token.
    pub fn from_token(token: &str, api_version: &str, azure_endpoint: &str) -> Self {
        Self::new(
            AzureOpenAIAuth::Token(token.to_string()),
            api_version,
            azure_endpoint,
        )
    }

    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        // Trim any trailing slash from the endpoint rather than doing a blanket
        // "//" -> "/" replacement, which would also mangle the "https://" scheme.
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );

        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }

    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );
        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }

    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        // This targets the transcriptions endpoint; `audio/translations` is a
        // different operation (translate-to-English).
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );
        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );
        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );
        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }
}

impl ProviderClient for Client {
    /// Builds a client from the environment: `AZURE_API_KEY` or `AZURE_TOKEN`
    /// for auth, plus `AZURE_API_VERSION` and `AZURE_ENDPOINT`.
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::new(auth, &api_version, &azure_endpoint)
    }

    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::ApiKeyWithVersionAndHeader(api_key, version, header) =
            input
        else {
            panic!("Incorrect provider value type")
        };
        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
        Self::new(auth, &version, &header)
    }
}

impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

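    /// Creates a completion model for the given deployment name.
    ///
    /// Usage sketch (assumes the crate is consumed as `rig` and the `AZURE_*`
    /// environment variables are set):
    /// ```ignore
    /// use rig::client::{CompletionClient, ProviderClient};
    /// use rig::providers::azure::{Client, GPT_4O};
    ///
    /// let azure = Client::from_env();
    /// let gpt4o = azure.completion_model(GPT_4O);
    /// ```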
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}

impl EmbeddingsClient for Client {
    type EmbeddingModel = EmbeddingModel;

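    /// Creates an embedding model for the given deployment name. Known models
    /// get their default dimensions; unknown models default to 0.
    ///
    /// Usage sketch (same assumptions as `completion_model` above):
    /// ```ignore
    /// use rig::client::{EmbeddingsClient, ProviderClient};
    /// use rig::providers::azure::{Client, TEXT_EMBEDDING_3_SMALL};
    ///
    /// let azure = Client::from_env();
    /// let embedding_model = azure.embedding_model(TEXT_EMBEDDING_3_SMALL);
    /// ```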
    fn embedding_model(&self, model: &str) -> EmbeddingModel {
        let ndims = match model {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            // Unknown model: default to 0 dims; use `embedding_model_with_ndims` instead.
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

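    /// Creates an embedding model with an explicit number of dimensions, for
    /// deployments whose defaults this client does not know.
    ///
    /// Usage sketch (the deployment name is a placeholder):
    /// ```ignore
    /// let embedding_model = azure.embedding_model_with_ndims("my-embedding-deployment", 1024);
    /// ```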
    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
}

impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;

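    /// Creates a transcription model for the given deployment name.
    ///
    /// Usage sketch (the deployment name is a placeholder):
    /// ```ignore
    /// use rig::client::{ProviderClient, TranscriptionClient};
    /// use rig::providers::azure::Client;
    ///
    /// let azure = Client::from_env();
    /// let whisper = azure.transcription_model("whisper");
    /// ```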
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    ndims: usize,
}

impl embeddings::EmbeddingModel for EmbeddingModel {
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let response = self
            .client
            .post_embedding(&self.model)
            .json(&json!({
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents)
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}

impl EmbeddingModel {
    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_string(),
            ndims,
        }
    }
}

/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// GPT-4 Turbo (Azure exposes it under the `gpt-4` model name)
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` 0613 snapshot (same base model name on Azure)
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-4o-mini)
    pub model: String,
}

impl CompletionModel {
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<serde_json::Value, CompletionError> {
        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}

impl completion::CompletionModel for CompletionModel {
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post_chat_completion(&self.model)
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            let t = response.text().await?;
            tracing::debug!(target: "rig", "Azure completion response: {}", t);

            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure completion token usage: {:?}",
                        response
                            .usage
                            .clone()
                            .map(|usage| format!("{usage}"))
                            .unwrap_or("N/A".to_string())
                    );
                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }

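    /// Streams the completion. The request from `create_completion_request` is
    /// augmented with `"stream": true` and `"stream_options": {"include_usage": true}`
    /// so that Azure reports token usage on the final chunk.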
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(request)?;

        request = merge(
            request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let builder = self
            .client
            .post_chat_completion(self.model.as_str())
            .json(&request);

        send_compatible_streaming_request(builder).await
    }
}

/// A transcription model backed by an Azure OpenAI audio deployment.
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the model (e.g.: whisper)
    pub model: String,
}

impl TranscriptionModel {
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl transcription::TranscriptionModel for TranscriptionModel {
    type Response = TranscriptionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new().part(
            "file",
            Part::bytes(data).file_name(request.filename.clone()),
        );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let response = self
            .client
            .post_transcription(&self.model)
            .multipart(body)
            .send()
            .await?;

        if response.status().is_success() {
            match response
                .json::<ApiResponse<TranscriptionResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(response.text().await?))
        }
    }
}

#[cfg(feature = "image")]
pub use image_generation::*;
#[cfg(feature = "image")]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use serde_json::json;

    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        pub model: String,
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let t = response.text().await?;

            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

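        /// Creates an image generation model for the given deployment name
        /// (e.g. a `dall-e-3` deployment).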
        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}

#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
mod audio_generation {
    use super::Client;
    use crate::audio_generation;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use bytes::Bytes;
    use serde_json::json;

    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = Bytes;

        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let response = self
                .client
                // Pass the deployment name; `post_audio_generation` already
                // appends the `/audio/speech` path to the URL.
                .post_audio_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let bytes = response.bytes().await?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

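        /// Creates an audio generation (text-to-speech) model for the given
        /// deployment name (e.g. a `tts-1` deployment).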
        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}