use super::openai::{TranscriptionResponse, send_compatible_streaming_request};

use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
use crate::json_utils::merge;
use crate::streaming::StreamingCompletionResponse;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    transcription::{self, TranscriptionError},
};
use reqwest::header::AUTHORIZATION;
use reqwest::multipart::Part;
use serde::Deserialize;
use serde_json::json;

/// Client for the Azure OpenAI API.
#[derive(Clone)]
pub struct Client {
    api_version: String,
    azure_endpoint: String,
    auth: AzureOpenAIAuth,
    http_client: reqwest::Client,
}

impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("http_client", &self.http_client)
            .field("auth", &"<REDACTED>")
            .field("api_version", &self.api_version)
            .finish()
    }
}

/// Authentication for Azure OpenAI: either an `api-key` header or a bearer token.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl From<String> for AzureOpenAIAuth {
    // A bare string is treated as a bearer token, not an API key.
    fn from(token: String) -> Self {
        AzureOpenAIAuth::Token(token)
    }
}

impl AzureOpenAIAuth {
    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
        match self {
            AzureOpenAIAuth::ApiKey(api_key) => (
                "api-key".parse().expect("Header name should parse"),
                api_key.parse().expect("API key should parse"),
            ),
            AzureOpenAIAuth::Token(token) => (
                AUTHORIZATION,
                format!("Bearer {token}")
                    .parse()
                    .expect("Token should parse"),
            ),
        }
    }
}

impl Client {
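    /// Creates a new Azure OpenAI client.
    ///
    /// # Example
    /// A minimal sketch; the API version and endpoint strings are placeholders,
    /// and note that a bare `String` converts to bearer-token auth via `From<String>`:
    /// ```no_run
    /// use rig::providers::azure::{AzureOpenAIAuth, Client};
    ///
    /// // Use `AzureOpenAIAuth::ApiKey` explicitly for api-key auth.
    /// let azure = Client::new(
    ///     AzureOpenAIAuth::ApiKey("your-azure-api-key".to_string()),
    ///     "2024-10-21",
    ///     "https://your-resource.openai.azure.com",
    /// );
    /// ```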
    pub fn new(auth: impl Into<AzureOpenAIAuth>, api_version: &str, azure_endpoint: &str) -> Self {
        Self {
            api_version: api_version.to_string(),
            auth: auth.into(),
            azure_endpoint: azure_endpoint.to_string(),
            http_client: reqwest::Client::builder()
                .build()
                .expect("Azure OpenAI reqwest client should build"),
        }
    }

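    /// Replaces the default HTTP client with a caller-supplied `reqwest::Client`.
    ///
    /// # Example
    /// A sketch using a custom timeout (the duration is an arbitrary choice):
    /// ```no_run
    /// use rig::providers::azure::Client;
    ///
    /// let http_client = reqwest::Client::builder()
    ///     .timeout(std::time::Duration::from_secs(30))
    ///     .build()
    ///     .unwrap();
    /// let azure = Client::from_api_key("your-azure-api-key", "2024-10-21", "https://your-resource.openai.azure.com")
    ///     .with_custom_client(http_client);
    /// ```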
    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
        self.http_client = client;

        self
    }

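    /// Creates a new Azure OpenAI client using an API key.
    ///
    /// # Example
    /// A sketch with placeholder credentials:
    /// ```no_run
    /// use rig::providers::azure::Client;
    ///
    /// let azure = Client::from_api_key(
    ///     "your-azure-api-key",
    ///     "2024-10-21",
    ///     "https://your-resource.openai.azure.com",
    /// );
    /// ```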
    pub fn from_api_key(api_key: &str, api_version: &str, azure_endpoint: &str) -> Self {
        Self::new(
            AzureOpenAIAuth::ApiKey(api_key.to_string()),
            api_version,
            azure_endpoint,
        )
    }

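    /// Creates a new Azure OpenAI client using a bearer token (e.g. one issued
    /// by Microsoft Entra ID).
    ///
    /// # Example
    /// A sketch with a placeholder token:
    /// ```no_run
    /// use rig::providers::azure::Client;
    ///
    /// let azure = Client::from_token(
    ///     "your-bearer-token",
    ///     "2024-10-21",
    ///     "https://your-resource.openai.azure.com",
    /// );
    /// ```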
    pub fn from_token(token: &str, api_version: &str, azure_endpoint: &str) -> Self {
        Self::new(
            AzureOpenAIAuth::Token(token.to_string()),
            api_version,
            azure_endpoint,
        )
    }

    fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        // Trim any trailing slash from the endpoint so the path never contains "//".
        // A blanket `.replace("//", "/")` would also mangle the "https://" scheme.
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );

        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }

    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );
        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }

    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        // `audio/transcriptions` is Azure's speech-to-text endpoint;
        // `audio/translations` is the separate translate-to-English variant.
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );
        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );
        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.azure_endpoint.trim_end_matches('/'),
            deployment_id,
            self.api_version
        );
        let (key, value) = self.auth.as_header();
        self.http_client.post(url).header(key, value)
    }
}

impl ProviderClient for Client {
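    /// Creates a client from environment variables: `AZURE_API_KEY` or
    /// `AZURE_TOKEN` for credentials, plus `AZURE_API_VERSION` and
    /// `AZURE_ENDPOINT`. Panics if any of them is missing.
    ///
    /// # Example
    /// A sketch, assuming the variables above are set:
    /// ```no_run
    /// use rig::client::ProviderClient;
    /// use rig::providers::azure::Client;
    ///
    /// let azure = Client::from_env();
    /// ```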
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::new(auth, &api_version, &azure_endpoint)
    }
}

impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

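    /// Creates a completion model for the given Azure deployment.
    ///
    /// # Example
    /// A sketch; the credentials are placeholders and the deployment name must
    /// match one configured in your Azure resource:
    /// ```no_run
    /// use rig::client::CompletionClient;
    /// use rig::providers::azure::{self, Client};
    ///
    /// let client = Client::from_api_key("key", "2024-10-21", "https://your-resource.openai.azure.com");
    /// let gpt4o = client.completion_model(azure::GPT_4O);
    /// ```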
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}

impl EmbeddingsClient for Client {
    type EmbeddingModel = EmbeddingModel;

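    /// Creates an embedding model for the given Azure deployment. The number of
    /// dimensions is inferred for the known OpenAI embedding models and
    /// defaults to 0 otherwise.
    ///
    /// # Example
    /// A sketch with placeholder credentials:
    /// ```no_run
    /// use rig::client::EmbeddingsClient;
    /// use rig::providers::azure::{self, Client};
    ///
    /// let client = Client::from_api_key("key", "2024-10-21", "https://your-resource.openai.azure.com");
    /// let embedder = client.embedding_model(azure::TEXT_EMBEDDING_3_SMALL);
    /// ```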
    fn embedding_model(&self, model: &str) -> EmbeddingModel {
        let ndims = match model {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            // Unknown models default to 0 dimensions; use
            // `embedding_model_with_ndims` to set the dimensionality explicitly.
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

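    /// Creates an embedding model with an explicitly specified number of
    /// dimensions, for deployments this client cannot infer.
    ///
    /// # Example
    /// A sketch; the deployment name and dimension count are placeholders:
    /// ```no_run
    /// use rig::client::EmbeddingsClient;
    /// use rig::providers::azure::Client;
    ///
    /// let client = Client::from_api_key("key", "2024-10-21", "https://your-resource.openai.azure.com");
    /// let embedder = client.embedding_model_with_ndims("my-embedding-deployment", 1024);
    /// ```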
    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
}

impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;

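    /// Creates a transcription model for the given Azure deployment.
    ///
    /// # Example
    /// A sketch; "whisper" stands in for whatever speech deployment you have:
    /// ```no_run
    /// use rig::client::TranscriptionClient;
    /// use rig::providers::azure::Client;
    ///
    /// let client = Client::from_api_key("key", "2024-10-21", "https://your-resource.openai.azure.com");
    /// let whisper = client.transcription_model("whisper");
    /// ```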
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    ndims: usize,
}

impl embeddings::EmbeddingModel for EmbeddingModel {
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let response = self
            .client
            .post_embedding(&self.model)
            .json(&json!({
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}

impl EmbeddingModel {
    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_string(),
            ndims,
        }
    }
}

/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// GPT-4 Turbo; Azure exposes it under the base `gpt-4` model name
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` completion model, 0613 snapshot (Azure pins the snapshot at deployment time)
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    pub model: String,
}

impl CompletionModel {
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<serde_json::Value, CompletionError> {
        // Start the message history with the system preamble, if any.
        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }
        // Each chat message may expand into several OpenAI messages (e.g. tool results).
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}

impl completion::CompletionModel for CompletionModel {
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post_chat_completion(&self.model)
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            let t = response.text().await?;
            tracing::debug!(target: "rig", "Azure completion response: {}", t);

            match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );
                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(request)?;

        // Enable SSE streaming and ask Azure to include token usage in the final chunk.
        request = merge(
            request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let builder = self
            .client
            .post_chat_completion(self.model.as_str())
            .json(&request);

        send_compatible_streaming_request(builder).await
    }
}

#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    pub model: String,
}

impl TranscriptionModel {
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl transcription::TranscriptionModel for TranscriptionModel {
    type Response = TranscriptionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new().part(
            "file",
            Part::bytes(data).file_name(request.filename.clone()),
        );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional parameters to Azure OpenAI transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let response = self
            .client
            .post_transcription(&self.model)
            .multipart(body)
            .send()
            .await?;

        if response.status().is_success() {
            match response
                .json::<ApiResponse<TranscriptionResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(response.text().await?))
        }
    }
}

#[cfg(feature = "image")]
pub use image_generation::*;
#[cfg(feature = "image")]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use serde_json::json;

    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        pub model: String,
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let t = response.text().await?;

            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}

#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
mod audio_generation {
    use super::Client;
    use crate::audio_generation;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use bytes::Bytes;
    use serde_json::json;

    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = Bytes;

        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            // `post_audio_generation` expects the deployment name, not a URL path.
            let response = self
                .client
                .post_audio_generation(&self.model)
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            let bytes = response.bytes().await?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}